Skip to content

Commit 57a6a70

Browse files
authored
RUST-911: Add srvServiceName URI option (#1235)
* RUST-911 Add srvServiceName URI option * fix format * Add prose test and unskip srvServiceName tests * use run_test_srv * fix lint * fix style nit * update test to align more closely with spec description
1 parent f32c18d commit 57a6a70

File tree

9 files changed

+260
-190
lines changed

9 files changed

+260
-190
lines changed

src/client/options.rs

+26-2
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ const URI_OPTIONS: &[&str] = &[
8989
"waitqueuetimeoutms",
9090
"wtimeoutms",
9191
"zlibcompressionlevel",
92+
"srvservicename",
9293
];
9394

9495
/// Reserved characters as defined by [Section 2.2 of RFC-3986](https://tools.ietf.org/html/rfc3986#section-2.2).
@@ -521,6 +522,9 @@ pub struct ClientOptions {
521522
/// By default, no default database is specified.
522523
pub default_database: Option<String>,
523524

525+
/// Overrides the default "mongodb" service name for SRV lookup in both discovery and polling
526+
pub srv_service_name: Option<String>,
527+
524528
#[builder(setter(skip))]
525529
#[derivative(Debug = "ignore")]
526530
pub(crate) socket_timeout: Option<Duration>,
@@ -676,6 +680,8 @@ impl Serialize for ClientOptions {
676680
loadbalanced: &'a Option<bool>,
677681

678682
srvmaxhosts: Option<i32>,
683+
684+
srvservicename: &'a Option<String>,
679685
}
680686

681687
let client_options = ClientOptionsHelper {
@@ -709,6 +715,7 @@ impl Serialize for ClientOptions {
709715
.map(|v| v.try_into())
710716
.transpose()
711717
.map_err(serde::ser::Error::custom)?,
718+
srvservicename: &self.srv_service_name,
712719
};
713720

714721
client_options.serialize(serializer)
@@ -865,6 +872,9 @@ pub struct ConnectionString {
865872
/// Limit on the number of mongos connections that may be created for sharded topologies.
866873
pub srv_max_hosts: Option<u32>,
867874

875+
/// Overrides the default "mongodb" service name for SRV lookup in both discovery and polling
876+
pub srv_service_name: Option<String>,
877+
868878
wait_queue_timeout: Option<Duration>,
869879
tls_insecure: Option<bool>,
870880

@@ -900,11 +910,16 @@ impl Default for HostInfo {
900910
}
901911

902912
impl HostInfo {
903-
async fn resolve(self, resolver_config: Option<ResolverConfig>) -> Result<ResolvedHostInfo> {
913+
async fn resolve(
914+
self,
915+
resolver_config: Option<ResolverConfig>,
916+
srv_service_name: Option<String>,
917+
) -> Result<ResolvedHostInfo> {
904918
Ok(match self {
905919
Self::HostIdentifiers(hosts) => ResolvedHostInfo::HostIdentifiers(hosts),
906920
Self::DnsRecord(hostname) => {
907-
let mut resolver = SrvResolver::new(resolver_config.clone()).await?;
921+
let mut resolver =
922+
SrvResolver::new(resolver_config.clone(), srv_service_name).await?;
908923
let config = resolver.resolve_client_options(&hostname).await?;
909924
ResolvedHostInfo::DnsRecord { hostname, config }
910925
}
@@ -1486,6 +1501,12 @@ impl ConnectionString {
14861501
ConnectionStringParts::default()
14871502
};
14881503

1504+
if conn_str.srv_service_name.is_some() && !srv {
1505+
return Err(Error::invalid_argument(
1506+
"srvServiceName cannot be specified with a non-SRV URI",
1507+
));
1508+
}
1509+
14891510
if let Some(srv_max_hosts) = conn_str.srv_max_hosts {
14901511
if !srv {
14911512
return Err(Error::invalid_argument(
@@ -1976,6 +1997,9 @@ impl ConnectionString {
19761997
k @ "srvmaxhosts" => {
19771998
self.srv_max_hosts = Some(get_u32!(value, k));
19781999
}
2000+
"srvservicename" => {
2001+
self.srv_service_name = Some(value.to_string());
2002+
}
19792003
k @ "tls" | k @ "ssl" => {
19802004
let tls = get_bool!(value, k);
19812005

src/client/options/parse.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ impl Action for ParseConnectionString {
2424
options.resolver_config.clone_from(&self.resolver_config);
2525
}
2626

27-
let resolved = host_info.resolve(self.resolver_config).await?;
27+
let resolved = host_info
28+
.resolve(self.resolver_config, options.srv_service_name.clone())
29+
.await?;
2830
options.hosts = match resolved {
2931
ResolvedHostInfo::HostIdentifiers(hosts) => hosts,
3032
ResolvedHostInfo::DnsRecord {
@@ -159,6 +161,7 @@ impl ClientOptions {
159161
#[cfg(feature = "tracing-unstable")]
160162
tracing_max_document_length_bytes: None,
161163
srv_max_hosts: conn_str.srv_max_hosts,
164+
srv_service_name: conn_str.srv_service_name,
162165
}
163166
}
164167
}

src/client/options/test.rs

-2
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ static SKIPPED_TESTS: Lazy<Vec<&'static str>> = Lazy::new(|| {
2222
"maxPoolSize=0 does not error",
2323
// TODO RUST-226: unskip this test
2424
"Valid tlsCertificateKeyFilePassword is parsed correctly",
25-
// TODO RUST-911: unskip this test
26-
"SRV URI with custom srvServiceName",
2725
// TODO RUST-229: unskip the following tests
2826
"Single IP literal host without port",
2927
"Single IP literal host with port",

src/sdam/srv_polling.rs

+10-2
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,11 @@ impl SrvPollingMonitor {
6262
}
6363

6464
fn rescan_interval(&self) -> Duration {
65-
std::cmp::max(self.rescan_interval, MIN_RESCAN_SRV_INTERVAL)
65+
if cfg!(test) {
66+
self.rescan_interval
67+
} else {
68+
std::cmp::max(self.rescan_interval, MIN_RESCAN_SRV_INTERVAL)
69+
}
6670
}
6771

6872
async fn execute(mut self) {
@@ -130,7 +134,11 @@ impl SrvPollingMonitor {
130134
return Ok(resolver);
131135
}
132136

133-
let resolver = SrvResolver::new(self.client_options.resolver_config().cloned()).await?;
137+
let resolver = SrvResolver::new(
138+
self.client_options.resolver_config().cloned(),
139+
self.client_options.srv_service_name.clone(),
140+
)
141+
.await?;
134142

135143
// Since the connection was not `Some` above, this will always insert the new connection and
136144
// return a reference to it.

src/sdam/srv_polling/test.rs

+24
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,27 @@ async fn srv_max_hosts_random() {
186186
assert_eq!(2, actual.len());
187187
assert!(actual.contains(&localhost_test_build_10gen(27017)));
188188
}
189+
190+
#[tokio::test]
191+
async fn srv_service_name() {
192+
let rescan_interval = Duration::from_secs(1);
193+
let new_hosts = vec![
194+
ServerAddress::Tcp {
195+
host: "localhost.test.build.10gen.cc".to_string(),
196+
port: Some(27019),
197+
},
198+
ServerAddress::Tcp {
199+
host: "localhost.test.build.10gen.cc".to_string(),
200+
port: Some(27020),
201+
},
202+
];
203+
let uri = "mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname";
204+
let mut options = ClientOptions::parse(uri).await.unwrap();
205+
// override the min_ttl to speed up lookup interval
206+
options.original_srv_info.as_mut().unwrap().min_ttl = rescan_interval;
207+
options.test_options_mut().mock_lookup_hosts = Some(make_lookup_hosts(new_hosts.clone()));
208+
let mut topology = Topology::new(options).unwrap();
209+
topology.watch().wait_until_initialized().await;
210+
tokio::time::sleep(rescan_interval * 2).await;
211+
assert_eq!(topology.server_addresses(), new_hosts.into_iter().collect());
212+
}

src/srv.rs

+14-3
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,21 @@ pub(crate) enum DomainMismatch {
9090
#[cfg(feature = "dns-resolver")]
9191
pub(crate) struct SrvResolver {
9292
resolver: crate::runtime::AsyncResolver,
93+
srv_service_name: Option<String>,
9394
}
9495

9596
#[cfg(feature = "dns-resolver")]
9697
impl SrvResolver {
97-
pub(crate) async fn new(config: Option<ResolverConfig>) -> Result<Self> {
98+
pub(crate) async fn new(
99+
config: Option<ResolverConfig>,
100+
srv_service_name: Option<String>,
101+
) -> Result<Self> {
98102
let resolver = crate::runtime::AsyncResolver::new(config.map(|c| c.inner)).await?;
99103

100-
Ok(Self { resolver })
104+
Ok(Self {
105+
resolver,
106+
srv_service_name,
107+
})
101108
}
102109

103110
pub(crate) async fn resolve_client_options(
@@ -149,7 +156,11 @@ impl SrvResolver {
149156
original_hostname: &str,
150157
dm: DomainMismatch,
151158
) -> Result<LookupHosts> {
152-
let lookup_hostname = format!("_mongodb._tcp.{}", original_hostname);
159+
let lookup_hostname = format!(
160+
"_{}._tcp.{}",
161+
self.srv_service_name.as_deref().unwrap_or("mongodb"),
162+
original_hostname
163+
);
153164
self.get_srv_hosts_unvalidated(&lookup_hostname)
154165
.await?
155166
.validate(original_hostname, dm)

src/test/spec/initial_dns_seedlist_discovery.rs

-18
Original file line numberDiff line numberDiff line change
@@ -62,24 +62,6 @@ struct ParsedOptions {
6262
}
6363

6464
async fn run_test(mut test_file: TestFile) {
65-
if let Some(ref options) = test_file.options {
66-
// TODO RUST-933: Remove this skip.
67-
let skip = if options.srv_service_name.is_some() {
68-
Some("srvServiceName")
69-
} else {
70-
None
71-
};
72-
73-
if let Some(skip) = skip {
74-
log_uncaptured(format!(
75-
"skipping initial_dns_seedlist_discovery test case due to unsupported connection \
76-
string option: {}",
77-
skip,
78-
));
79-
return;
80-
}
81-
}
82-
8365
// "encoded-userinfo-and-db.json" specifies a database name with a question mark which is
8466
// disallowed on Windows. See
8567
// <https://www.mongodb.com/docs/manual/reference/limits/#restrictions-on-db-names>

0 commit comments

Comments
 (0)