@@ -8,6 +8,7 @@ package net
8
8
9
9
import (
10
10
"context"
11
+ "errors"
11
12
"fmt"
12
13
"internal/poll"
13
14
"io/ioutil"
@@ -43,11 +44,14 @@ var dnsTransportFallbackTests = []struct {
43
44
44
45
func TestDNSTransportFallback (t * testing.T ) {
45
46
fake := fakeDNSServer {
46
- rh : func (n , _ string , _ * dnsMsg , _ time.Time ) (* dnsMsg , error ) {
47
+ rh : func (n , _ string , q * dnsMsg , _ time.Time ) (* dnsMsg , error ) {
47
48
r := & dnsMsg {
48
49
dnsMsgHdr : dnsMsgHdr {
49
- rcode : dnsRcodeSuccess ,
50
+ id : q .id ,
51
+ response : true ,
52
+ rcode : dnsRcodeSuccess ,
50
53
},
54
+ question : q .question ,
51
55
}
52
56
if n == "udp" {
53
57
r .truncated = true
@@ -98,8 +102,10 @@ func TestSpecialDomainName(t *testing.T) {
98
102
fake := fakeDNSServer {func (_ , _ string , q * dnsMsg , _ time.Time ) (* dnsMsg , error ) {
99
103
r := & dnsMsg {
100
104
dnsMsgHdr : dnsMsgHdr {
101
- id : q .id ,
105
+ id : q .id ,
106
+ response : true ,
102
107
},
108
+ question : q .question ,
103
109
}
104
110
105
111
switch q .question [0 ].Name {
@@ -612,8 +618,10 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
612
618
fake := fakeDNSServer {func (_ , _ string , q * dnsMsg , _ time.Time ) (* dnsMsg , error ) {
613
619
r := & dnsMsg {
614
620
dnsMsgHdr : dnsMsgHdr {
615
- id : q .id ,
621
+ id : q .id ,
622
+ response : true ,
616
623
},
624
+ question : q .question ,
617
625
}
618
626
619
627
switch q .question [0 ].Name {
@@ -751,30 +759,61 @@ type fakeDNSServer struct {
751
759
}
752
760
753
761
func (server * fakeDNSServer ) DialContext (_ context.Context , n , s string ) (Conn , error ) {
754
- return & fakeDNSConn {nil , server , n , s , time.Time {}}, nil
762
+ return & fakeDNSConn {nil , server , n , s , nil , time.Time {}}, nil
755
763
}
756
764
757
765
type fakeDNSConn struct {
758
766
Conn
759
767
server * fakeDNSServer
760
768
n string
761
769
s string
770
+ q * dnsMsg
762
771
t time.Time
763
772
}
764
773
765
774
func (f * fakeDNSConn ) Close () error {
766
775
return nil
767
776
}
768
777
778
+ func (f * fakeDNSConn ) Read (b []byte ) (int , error ) {
779
+ resp , err := f .server .rh (f .n , f .s , f .q , f .t )
780
+ if err != nil {
781
+ return 0 , err
782
+ }
783
+
784
+ bb , ok := resp .Pack ()
785
+ if ! ok {
786
+ return 0 , errors .New ("cannot marshal DNS message" )
787
+ }
788
+ if len (b ) < len (bb ) {
789
+ return 0 , errors .New ("read would fragment DNS message" )
790
+ }
791
+
792
+ copy (b , bb )
793
+ return len (bb ), nil
794
+ }
795
+
796
+ func (f * fakeDNSConn ) ReadFrom (b []byte ) (int , Addr , error ) {
797
+ return 0 , nil , nil
798
+ }
799
+
800
+ func (f * fakeDNSConn ) Write (b []byte ) (int , error ) {
801
+ f .q = new (dnsMsg )
802
+ if ! f .q .Unpack (b ) {
803
+ return 0 , errors .New ("cannot unmarshal DNS message" )
804
+ }
805
+ return len (b ), nil
806
+ }
807
+
808
+ func (f * fakeDNSConn ) WriteTo (b []byte , addr Addr ) (int , error ) {
809
+ return 0 , nil
810
+ }
811
+
769
812
func (f * fakeDNSConn ) SetDeadline (t time.Time ) error {
770
813
f .t = t
771
814
return nil
772
815
}
773
816
774
- func (f * fakeDNSConn ) dnsRoundTrip (q * dnsMsg ) (* dnsMsg , error ) {
775
- return f .server .rh (f .n , f .s , q , f .t )
776
- }
777
-
778
817
// UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281).
779
818
func TestIgnoreDNSForgeries (t * testing.T ) {
780
819
c , s := Pipe ()
@@ -837,7 +876,8 @@ func TestIgnoreDNSForgeries(t *testing.T) {
837
876
},
838
877
}
839
878
840
- resp , err := dnsRoundTripUDP (c , msg )
879
+ dc := & dnsPacketConn {c }
880
+ resp , err := dc .dnsRoundTrip (msg )
841
881
if err != nil {
842
882
t .Fatalf ("dnsRoundTripUDP failed: %v" , err )
843
883
}
@@ -1113,7 +1153,14 @@ func TestStrictErrorsLookupIP(t *testing.T) {
1113
1153
case resolveOpError :
1114
1154
return nil , & OpError {Op : "write" , Err : fmt .Errorf ("socket on fire" )}
1115
1155
case resolveServfail :
1116
- return & dnsMsg {dnsMsgHdr : dnsMsgHdr {id : q .id , rcode : dnsRcodeServerFailure }}, nil
1156
+ return & dnsMsg {
1157
+ dnsMsgHdr : dnsMsgHdr {
1158
+ id : q .id ,
1159
+ response : true ,
1160
+ rcode : dnsRcodeServerFailure ,
1161
+ },
1162
+ question : q .question ,
1163
+ }, nil
1117
1164
case resolveTimeout :
1118
1165
return nil , poll .ErrTimeout
1119
1166
default :
@@ -1123,7 +1170,14 @@ func TestStrictErrorsLookupIP(t *testing.T) {
1123
1170
switch q .question [0 ].Name {
1124
1171
case searchX , name + "." :
1125
1172
// Return NXDOMAIN to utilize the search list.
1126
- return & dnsMsg {dnsMsgHdr : dnsMsgHdr {id : q .id , rcode : dnsRcodeNameError }}, nil
1173
+ return & dnsMsg {
1174
+ dnsMsgHdr : dnsMsgHdr {
1175
+ id : q .id ,
1176
+ response : true ,
1177
+ rcode : dnsRcodeNameError ,
1178
+ },
1179
+ question : q .question ,
1180
+ }, nil
1127
1181
case searchY :
1128
1182
// Return records below.
1129
1183
default :
0 commit comments