Skip to content

Commit d8a7990

Browse files
benburkertbradfitz
authored andcommitted
net: support all PacketConn and Conn returned by Resolver.Dial
Allow the Resolver.Dial func to return instances of Conn other than *TCPConn and *UDPConn. If the Conn is also a PacketConn, assume DNS messages transmitted over the Conn adhere to section 4.2.1. "UDP usage". Otherwise, follow section 4.2.2. "TCP usage". Provides a hook mechanism so that DNS queries generated by the net package may be answered or modified before being sent to over the network. Updates #19910 Change-Id: Ib089a28ad4a1848bbeaf624ae889f1e82d56655b Reviewed-on: https://go-review.googlesource.com/45153 Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
1 parent d55d7b9 commit d8a7990

File tree

4 files changed

+86
-34
lines changed

4 files changed

+86
-34
lines changed

src/net/dnsclient_unix.go

+12-12
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,14 @@ type dnsConn interface {
3636
dnsRoundTrip(query *dnsMsg) (*dnsMsg, error)
3737
}
3838

39-
func (c *UDPConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
40-
return dnsRoundTripUDP(c, query)
39+
// dnsPacketConn implements the dnsConn interface for RFC 1035's
40+
// "UDP usage" transport mechanism. Conn is a packet-oriented connection,
41+
// such as a *UDPConn.
42+
type dnsPacketConn struct {
43+
Conn
4144
}
4245

43-
// dnsRoundTripUDP implements the dnsRoundTrip interface for RFC 1035's
44-
// "UDP usage" transport mechanism. c should be a packet-oriented connection,
45-
// such as a *UDPConn.
46-
func dnsRoundTripUDP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
46+
func (c *dnsPacketConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
4747
b, ok := query.Pack()
4848
if !ok {
4949
return nil, errors.New("cannot marshal DNS message")
@@ -69,14 +69,14 @@ func dnsRoundTripUDP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
6969
}
7070
}
7171

72-
func (c *TCPConn) dnsRoundTrip(out *dnsMsg) (*dnsMsg, error) {
73-
return dnsRoundTripTCP(c, out)
72+
// dnsStreamConn implements the dnsConn interface for RFC 1035's
73+
// "TCP usage" transport mechanism. Conn is a stream-oriented connection,
74+
// such as a *TCPConn.
75+
type dnsStreamConn struct {
76+
Conn
7477
}
7578

76-
// dnsRoundTripTCP implements the dnsRoundTrip interface for RFC 1035's
77-
// "TCP usage" transport mechanism. c should be a stream-oriented connection,
78-
// such as a *TCPConn.
79-
func dnsRoundTripTCP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
79+
func (c *dnsStreamConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
8080
b, ok := query.Pack()
8181
if !ok {
8282
return nil, errors.New("cannot marshal DNS message")

src/net/dnsclient_unix_test.go

+66-12
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package net
88

99
import (
1010
"context"
11+
"errors"
1112
"fmt"
1213
"internal/poll"
1314
"io/ioutil"
@@ -43,11 +44,14 @@ var dnsTransportFallbackTests = []struct {
4344

4445
func TestDNSTransportFallback(t *testing.T) {
4546
fake := fakeDNSServer{
46-
rh: func(n, _ string, _ *dnsMsg, _ time.Time) (*dnsMsg, error) {
47+
rh: func(n, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
4748
r := &dnsMsg{
4849
dnsMsgHdr: dnsMsgHdr{
49-
rcode: dnsRcodeSuccess,
50+
id: q.id,
51+
response: true,
52+
rcode: dnsRcodeSuccess,
5053
},
54+
question: q.question,
5155
}
5256
if n == "udp" {
5357
r.truncated = true
@@ -98,8 +102,10 @@ func TestSpecialDomainName(t *testing.T) {
98102
fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
99103
r := &dnsMsg{
100104
dnsMsgHdr: dnsMsgHdr{
101-
id: q.id,
105+
id: q.id,
106+
response: true,
102107
},
108+
question: q.question,
103109
}
104110

105111
switch q.question[0].Name {
@@ -612,8 +618,10 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
612618
fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
613619
r := &dnsMsg{
614620
dnsMsgHdr: dnsMsgHdr{
615-
id: q.id,
621+
id: q.id,
622+
response: true,
616623
},
624+
question: q.question,
617625
}
618626

619627
switch q.question[0].Name {
@@ -751,30 +759,61 @@ type fakeDNSServer struct {
751759
}
752760

753761
func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) {
754-
return &fakeDNSConn{nil, server, n, s, time.Time{}}, nil
762+
return &fakeDNSConn{nil, server, n, s, nil, time.Time{}}, nil
755763
}
756764

757765
type fakeDNSConn struct {
758766
Conn
759767
server *fakeDNSServer
760768
n string
761769
s string
770+
q *dnsMsg
762771
t time.Time
763772
}
764773

765774
func (f *fakeDNSConn) Close() error {
766775
return nil
767776
}
768777

778+
func (f *fakeDNSConn) Read(b []byte) (int, error) {
779+
resp, err := f.server.rh(f.n, f.s, f.q, f.t)
780+
if err != nil {
781+
return 0, err
782+
}
783+
784+
bb, ok := resp.Pack()
785+
if !ok {
786+
return 0, errors.New("cannot marshal DNS message")
787+
}
788+
if len(b) < len(bb) {
789+
return 0, errors.New("read would fragment DNS message")
790+
}
791+
792+
copy(b, bb)
793+
return len(bb), nil
794+
}
795+
796+
func (f *fakeDNSConn) ReadFrom(b []byte) (int, Addr, error) {
797+
return 0, nil, nil
798+
}
799+
800+
func (f *fakeDNSConn) Write(b []byte) (int, error) {
801+
f.q = new(dnsMsg)
802+
if !f.q.Unpack(b) {
803+
return 0, errors.New("cannot unmarshal DNS message")
804+
}
805+
return len(b), nil
806+
}
807+
808+
func (f *fakeDNSConn) WriteTo(b []byte, addr Addr) (int, error) {
809+
return 0, nil
810+
}
811+
769812
func (f *fakeDNSConn) SetDeadline(t time.Time) error {
770813
f.t = t
771814
return nil
772815
}
773816

774-
func (f *fakeDNSConn) dnsRoundTrip(q *dnsMsg) (*dnsMsg, error) {
775-
return f.server.rh(f.n, f.s, q, f.t)
776-
}
777-
778817
// UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281).
779818
func TestIgnoreDNSForgeries(t *testing.T) {
780819
c, s := Pipe()
@@ -837,7 +876,8 @@ func TestIgnoreDNSForgeries(t *testing.T) {
837876
},
838877
}
839878

840-
resp, err := dnsRoundTripUDP(c, msg)
879+
dc := &dnsPacketConn{c}
880+
resp, err := dc.dnsRoundTrip(msg)
841881
if err != nil {
842882
t.Fatalf("dnsRoundTripUDP failed: %v", err)
843883
}
@@ -1113,7 +1153,14 @@ func TestStrictErrorsLookupIP(t *testing.T) {
11131153
case resolveOpError:
11141154
return nil, &OpError{Op: "write", Err: fmt.Errorf("socket on fire")}
11151155
case resolveServfail:
1116-
return &dnsMsg{dnsMsgHdr: dnsMsgHdr{id: q.id, rcode: dnsRcodeServerFailure}}, nil
1156+
return &dnsMsg{
1157+
dnsMsgHdr: dnsMsgHdr{
1158+
id: q.id,
1159+
response: true,
1160+
rcode: dnsRcodeServerFailure,
1161+
},
1162+
question: q.question,
1163+
}, nil
11171164
case resolveTimeout:
11181165
return nil, poll.ErrTimeout
11191166
default:
@@ -1123,7 +1170,14 @@ func TestStrictErrorsLookupIP(t *testing.T) {
11231170
switch q.question[0].Name {
11241171
case searchX, name + ".":
11251172
// Return NXDOMAIN to utilize the search list.
1126-
return &dnsMsg{dnsMsgHdr: dnsMsgHdr{id: q.id, rcode: dnsRcodeNameError}}, nil
1173+
return &dnsMsg{
1174+
dnsMsgHdr: dnsMsgHdr{
1175+
id: q.id,
1176+
response: true,
1177+
rcode: dnsRcodeNameError,
1178+
},
1179+
question: q.question,
1180+
}, nil
11271181
case searchY:
11281182
// Return records below.
11291183
default:

src/net/lookup.go

+5-3
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,11 @@ type Resolver struct {
111111
// Go's built-in DNS resolver to make TCP and UDP connections
112112
// to DNS services. The provided addr will always be an IP
113113
// address and not a hostname.
114-
// The Conn returned must be a *TCPConn or *UDPConn as
115-
// requested by the network parameter. If nil, the default
116-
// dialer is used.
114+
// If the Conn returned is also a PacketConn, sent and received DNS
115+
// messages must adhere to section 4.2.1. "UDP usage" of RFC 1035.
116+
// Otherwise, DNS messages transmitted over Conn must adhere to section
117+
// 4.2.2. "TCP usage".
118+
// If nil, the default dialer is used.
117119
Dial func(ctx context.Context, network, addr string) (Conn, error)
118120

119121
// TODO(bradfitz): optional interface impl override hook

src/net/lookup_unix.go

+3-7
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ package net
88

99
import (
1010
"context"
11-
"errors"
12-
"reflect"
1311
"sync"
1412
)
1513

@@ -70,12 +68,10 @@ func (r *Resolver) dial(ctx context.Context, network, server string) (dnsConn, e
7068
if err != nil {
7169
return nil, mapErr(err)
7270
}
73-
dc, ok := c.(dnsConn)
74-
if !ok {
75-
c.Close()
76-
return nil, errors.New("net: Resolver.Dial returned unsupported connection type " + reflect.TypeOf(c).String())
71+
if _, ok := c.(PacketConn); ok {
72+
return &dnsPacketConn{c}, nil
7773
}
78-
return dc, nil
74+
return &dnsStreamConn{c}, nil
7975
}
8076

8177
func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) {

0 commit comments

Comments
 (0)