Skip to content

Commit 3411d63

Browse files
committed
net: keep waiting for valid DNS response until timeout
Prevents denial of service attacks from bogus UDP packets. Fixes #13281. Change-Id: Ifb51b17a1b0807bfd27b144d6037431701184e7b Reviewed-on: https://go-review.googlesource.com/22126 Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org> Run-TryBot: Matthew Dempsky <mdempsky@google.com> TryBot-Result: Gobot Gobot <gobot@golang.org>
1 parent 9f1ccd6 commit 3411d63

File tree

4 files changed

+260
-60
lines changed

4 files changed

+260
-60
lines changed

src/net/dnsclient_unix.go

+52-47
Original file line numberDiff line numberDiff line change
@@ -38,71 +38,82 @@ type dnsConn interface {
3838

3939
SetDeadline(time.Time) error
4040

41-
// readDNSResponse reads a DNS response message from the DNS
42-
// transport endpoint and returns the received DNS response
43-
// message.
44-
readDNSResponse() (*dnsMsg, error)
45-
46-
// writeDNSQuery writes a DNS query message to the DNS
47-
// connection endpoint.
48-
writeDNSQuery(*dnsMsg) error
41+
// dnsRoundTrip executes a single DNS transaction, returning a
42+
// DNS response message for the provided DNS query message.
43+
dnsRoundTrip(query *dnsMsg) (*dnsMsg, error)
4944
}
5045

51-
func (c *UDPConn) readDNSResponse() (*dnsMsg, error) {
52-
b := make([]byte, 512) // see RFC 1035
53-
n, err := c.Read(b)
54-
if err != nil {
46+
func (c *UDPConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
47+
return dnsRoundTripUDP(c, query)
48+
}
49+
50+
// dnsRoundTripUDP implements the dnsRoundTrip interface for RFC 1035's
51+
// "UDP usage" transport mechanism. c should be a packet-oriented connection,
52+
// such as a *UDPConn.
53+
func dnsRoundTripUDP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
54+
b, ok := query.Pack()
55+
if !ok {
56+
return nil, errors.New("cannot marshal DNS message")
57+
}
58+
if _, err := c.Write(b); err != nil {
5559
return nil, err
5660
}
57-
msg := &dnsMsg{}
58-
if !msg.Unpack(b[:n]) {
59-
return nil, errors.New("cannot unmarshal DNS message")
61+
62+
b = make([]byte, 512) // see RFC 1035
63+
for {
64+
n, err := c.Read(b)
65+
if err != nil {
66+
return nil, err
67+
}
68+
resp := &dnsMsg{}
69+
if !resp.Unpack(b[:n]) || !resp.IsResponseTo(query) {
70+
// Ignore invalid responses as they may be malicious
71+
// forgery attempts. Instead continue waiting until
72+
// timeout. See golang.org/issue/13281.
73+
continue
74+
}
75+
return resp, nil
6076
}
61-
return msg, nil
6277
}
6378

64-
func (c *UDPConn) writeDNSQuery(msg *dnsMsg) error {
65-
b, ok := msg.Pack()
79+
func (c *TCPConn) dnsRoundTrip(out *dnsMsg) (*dnsMsg, error) {
80+
return dnsRoundTripTCP(c, out)
81+
}
82+
83+
// dnsRoundTripTCP implements the dnsRoundTrip interface for RFC 1035's
84+
// "TCP usage" transport mechanism. c should be a stream-oriented connection,
85+
// such as a *TCPConn.
86+
func dnsRoundTripTCP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
87+
b, ok := query.Pack()
6688
if !ok {
67-
return errors.New("cannot marshal DNS message")
89+
return nil, errors.New("cannot marshal DNS message")
6890
}
91+
l := len(b)
92+
b = append([]byte{byte(l >> 8), byte(l)}, b...)
6993
if _, err := c.Write(b); err != nil {
70-
return err
94+
return nil, err
7195
}
72-
return nil
73-
}
7496

75-
func (c *TCPConn) readDNSResponse() (*dnsMsg, error) {
76-
b := make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
97+
b = make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
7798
if _, err := io.ReadFull(c, b[:2]); err != nil {
7899
return nil, err
79100
}
80-
l := int(b[0])<<8 | int(b[1])
101+
l = int(b[0])<<8 | int(b[1])
81102
if l > len(b) {
82103
b = make([]byte, l)
83104
}
84105
n, err := io.ReadFull(c, b[:l])
85106
if err != nil {
86107
return nil, err
87108
}
88-
msg := &dnsMsg{}
89-
if !msg.Unpack(b[:n]) {
109+
resp := &dnsMsg{}
110+
if !resp.Unpack(b[:n]) {
90111
return nil, errors.New("cannot unmarshal DNS message")
91112
}
92-
return msg, nil
93-
}
94-
95-
func (c *TCPConn) writeDNSQuery(msg *dnsMsg) error {
96-
b, ok := msg.Pack()
97-
if !ok {
98-
return errors.New("cannot marshal DNS message")
113+
if !resp.IsResponseTo(query) {
114+
return nil, errors.New("invalid DNS response")
99115
}
100-
l := uint16(len(b))
101-
b = append([]byte{byte(l >> 8), byte(l)}, b...)
102-
if _, err := c.Write(b); err != nil {
103-
return err
104-
}
105-
return nil
116+
return resp, nil
106117
}
107118

108119
func (d *Dialer) dialDNS(ctx context.Context, network, server string) (dnsConn, error) {
@@ -150,16 +161,10 @@ func exchange(ctx context.Context, server, name string, qtype uint16) (*dnsMsg,
150161
c.SetDeadline(d)
151162
}
152163
out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
153-
if err := c.writeDNSQuery(&out); err != nil {
154-
return nil, mapErr(err)
155-
}
156-
in, err := c.readDNSResponse()
164+
in, err := c.dnsRoundTrip(&out)
157165
if err != nil {
158166
return nil, mapErr(err)
159167
}
160-
if in.id != out.id {
161-
return nil, errors.New("DNS message ID mismatch")
162-
}
163168
if in.truncated { // see RFC 5966
164169
continue
165170
}

src/net/dnsclient_unix_test.go

+70-13
Original file line numberDiff line numberDiff line change
@@ -567,9 +567,6 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
567567
}
568568

569569
type fakeDNSConn struct {
570-
// last query
571-
qmu sync.Mutex // guards q
572-
q *dnsMsg
573570
// reply handler
574571
rh func(*dnsMsg) (*dnsMsg, error)
575572
}
@@ -586,16 +583,76 @@ func (f *fakeDNSConn) SetDeadline(time.Time) error {
586583
return nil
587584
}
588585

589-
func (f *fakeDNSConn) writeDNSQuery(q *dnsMsg) error {
590-
f.qmu.Lock()
591-
defer f.qmu.Unlock()
592-
f.q = q
593-
return nil
586+
func (f *fakeDNSConn) dnsRoundTrip(q *dnsMsg) (*dnsMsg, error) {
587+
return f.rh(q)
594588
}
595589

596-
func (f *fakeDNSConn) readDNSResponse() (*dnsMsg, error) {
597-
f.qmu.Lock()
598-
q := f.q
599-
f.qmu.Unlock()
600-
return f.rh(q)
590+
// UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281).
591+
func TestIgnoreDNSForgeries(t *testing.T) {
592+
const TestAddr uint32 = 0x80420001
593+
594+
c, s := Pipe()
595+
go func() {
596+
b := make([]byte, 512)
597+
n, err := s.Read(b)
598+
if err != nil {
599+
t.Fatal(err)
600+
}
601+
602+
msg := &dnsMsg{}
603+
if !msg.Unpack(b[:n]) {
604+
t.Fatal("invalid DNS query")
605+
}
606+
607+
s.Write([]byte("garbage DNS response packet"))
608+
609+
msg.response = true
610+
msg.id++ // make invalid ID
611+
b, ok := msg.Pack()
612+
if !ok {
613+
t.Fatal("failed to pack DNS response")
614+
}
615+
s.Write(b)
616+
617+
msg.id-- // restore original ID
618+
msg.answer = []dnsRR{
619+
&dnsRR_A{
620+
Hdr: dnsRR_Header{
621+
Name: "www.example.com.",
622+
Rrtype: dnsTypeA,
623+
Class: dnsClassINET,
624+
Rdlength: 4,
625+
},
626+
A: TestAddr,
627+
},
628+
}
629+
630+
b, ok = msg.Pack()
631+
if !ok {
632+
t.Fatal("failed to pack DNS response")
633+
}
634+
s.Write(b)
635+
}()
636+
637+
msg := &dnsMsg{
638+
dnsMsgHdr: dnsMsgHdr{
639+
id: 42,
640+
},
641+
question: []dnsQuestion{
642+
{
643+
Name: "www.example.com.",
644+
Qtype: dnsTypeA,
645+
Qclass: dnsClassINET,
646+
},
647+
},
648+
}
649+
650+
resp, err := dnsRoundTripUDP(c, msg)
651+
if err != nil {
652+
t.Fatalf("dnsRoundTripUDP failed: %v", err)
653+
}
654+
655+
if got := resp.answer[0].(*dnsRR_A).A; got != TestAddr {
656+
t.Error("got address %v, want %v", got, TestAddr)
657+
}
601658
}

src/net/dnsmsg.go

+20
Original file line numberDiff line numberDiff line change
@@ -934,3 +934,23 @@ func (dns *dnsMsg) String() string {
934934
}
935935
return s
936936
}
937+
938+
// IsResponseTo reports whether m is an acceptable response to query.
939+
func (m *dnsMsg) IsResponseTo(query *dnsMsg) bool {
940+
if !m.response {
941+
return false
942+
}
943+
if m.id != query.id {
944+
return false
945+
}
946+
if len(m.question) != len(query.question) {
947+
return false
948+
}
949+
for i, q := range m.question {
950+
q2 := query.question[i]
951+
if !equalASCIILabel(q.Name, q2.Name) || q.Qtype != q2.Qtype || q.Qclass != q2.Qclass {
952+
return false
953+
}
954+
}
955+
return true
956+
}

src/net/dnsmsg_test.go

+118
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,124 @@ func TestDNSParseTXTCorruptTXTLengthReply(t *testing.T) {
280280
}
281281
}
282282

283+
func TestIsResponseTo(t *testing.T) {
284+
// Sample DNS query.
285+
query := dnsMsg{
286+
dnsMsgHdr: dnsMsgHdr{
287+
id: 42,
288+
},
289+
question: []dnsQuestion{
290+
{
291+
Name: "www.example.com.",
292+
Qtype: dnsTypeA,
293+
Qclass: dnsClassINET,
294+
},
295+
},
296+
}
297+
298+
resp := query
299+
resp.response = true
300+
if !resp.IsResponseTo(&query) {
301+
t.Error("got false, want true")
302+
}
303+
304+
badResponses := []dnsMsg{
305+
// Different ID.
306+
{
307+
dnsMsgHdr: dnsMsgHdr{
308+
id: 43,
309+
response: true,
310+
},
311+
question: []dnsQuestion{
312+
{
313+
Name: "www.example.com.",
314+
Qtype: dnsTypeA,
315+
Qclass: dnsClassINET,
316+
},
317+
},
318+
},
319+
320+
// Different query name.
321+
{
322+
dnsMsgHdr: dnsMsgHdr{
323+
id: 42,
324+
response: true,
325+
},
326+
question: []dnsQuestion{
327+
{
328+
Name: "www.google.com.",
329+
Qtype: dnsTypeA,
330+
Qclass: dnsClassINET,
331+
},
332+
},
333+
},
334+
335+
// Different query type.
336+
{
337+
dnsMsgHdr: dnsMsgHdr{
338+
id: 42,
339+
response: true,
340+
},
341+
question: []dnsQuestion{
342+
{
343+
Name: "www.example.com.",
344+
Qtype: dnsTypeAAAA,
345+
Qclass: dnsClassINET,
346+
},
347+
},
348+
},
349+
350+
// Different query class.
351+
{
352+
dnsMsgHdr: dnsMsgHdr{
353+
id: 42,
354+
response: true,
355+
},
356+
question: []dnsQuestion{
357+
{
358+
Name: "www.example.com.",
359+
Qtype: dnsTypeA,
360+
Qclass: dnsClassCSNET,
361+
},
362+
},
363+
},
364+
365+
// No questions.
366+
{
367+
dnsMsgHdr: dnsMsgHdr{
368+
id: 42,
369+
response: true,
370+
},
371+
},
372+
373+
// Extra questions.
374+
{
375+
dnsMsgHdr: dnsMsgHdr{
376+
id: 42,
377+
response: true,
378+
},
379+
question: []dnsQuestion{
380+
{
381+
Name: "www.example.com.",
382+
Qtype: dnsTypeA,
383+
Qclass: dnsClassINET,
384+
},
385+
{
386+
Name: "www.golang.org.",
387+
Qtype: dnsTypeAAAA,
388+
Qclass: dnsClassINET,
389+
},
390+
},
391+
},
392+
}
393+
394+
for i := range badResponses {
395+
if badResponses[i].IsResponseTo(&query) {
396+
t.Error("%v: got true, want false", i)
397+
}
398+
}
399+
}
400+
283401
// Valid DNS SRV reply
284402
const dnsSRVReply = "0901818000010005000000000c5f786d70702d736572766572045f74637006676f6f67" +
285403
"6c6503636f6d0000210001c00c002100010000012c00210014000014950c786d70702d" +

0 commit comments

Comments
 (0)