Skip to content

Commit 1e76096

Browse files
feat: add mTLS to ads template (#384)
Co-authored-by: Dov Shlachter <dovs@google.com>
1 parent 164d6bc commit 1e76096

File tree

3 files changed

+343
-76
lines changed

3 files changed

+343
-76
lines changed

gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/client.py.j2

Lines changed: 79 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
{% block content %}
44
from collections import OrderedDict
5-
from typing import Dict, {% if service.any_server_streaming %}Iterable, {% endif %}{% if service.any_client_streaming %}Iterator, {% endif %}Sequence, Tuple, Type, Union
5+
import re
6+
from typing import Callable, Dict, {% if service.any_server_streaming %}Iterable, {% endif %}{% if service.any_client_streaming %}Iterator, {% endif %}Sequence, Tuple, Type, Union
67
import pkg_resources
78

89
import google.api_core.client_options as ClientOptions # type: ignore
@@ -57,7 +58,39 @@ class {{ service.client_name }}Meta(type):
5758
class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
5859
"""{{ service.meta.doc|rst(width=72, indent=4) }}"""
5960

60-
DEFAULT_OPTIONS = ClientOptions.ClientOptions({% if service.host %}api_endpoint='{{ service.host }}'{% endif %})
61+
@staticmethod
62+
def _get_default_mtls_endpoint(api_endpoint):
63+
"""Convert api endpoint to mTLS endpoint.
64+
Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
65+
"*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
66+
Args:
67+
api_endpoint (Optional[str]): the api endpoint to convert.
68+
Returns:
69+
str: converted mTLS api endpoint.
70+
"""
71+
if not api_endpoint:
72+
return api_endpoint
73+
74+
mtls_endpoint_re = re.compile(
75+
r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
76+
)
77+
78+
m = mtls_endpoint_re.match(api_endpoint)
79+
name, mtls, sandbox, googledomain = m.groups()
80+
if mtls or not googledomain:
81+
return api_endpoint
82+
83+
if sandbox:
84+
return api_endpoint.replace(
85+
"sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
86+
)
87+
88+
return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")
89+
90+
DEFAULT_ENDPOINT = {% if service.host %}'{{ service.host }}'{% else %}None{% endif %}
91+
DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore
92+
DEFAULT_ENDPOINT
93+
)
6194

6295
@classmethod
6396
def from_service_account_file(cls, filename: str, *args, **kwargs):
@@ -92,7 +125,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
92125
def __init__(self, *,
93126
credentials: credentials.Credentials = None,
94127
transport: Union[str, {{ service.name }}Transport] = None,
95-
client_options: ClientOptions = DEFAULT_OPTIONS,
128+
client_options: ClientOptions = None,
96129
) -> None:
97130
"""Instantiate the {{ (service.client_name|snake_case).replace('_', ' ') }}.
98131

@@ -106,6 +139,17 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
106139
transport to use. If set to None, a transport is chosen
107140
automatically.
108141
client_options (ClientOptions): Custom options for the client.
142+
(1) The ``api_endpoint`` property can be used to override the
143+
default endpoint provided by the client.
144+
(2) If ``transport`` argument is None, ``client_options`` can be
145+
used to create a mutual TLS transport. If ``client_cert_source``
146+
is provided, mutual TLS transport will be created with the given
147+
``api_endpoint`` or the default mTLS endpoint, and the client
148+
SSL credentials obtained from ``client_cert_source``.
149+
150+
Raises:
151+
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
152+
creation failed for any reason.
109153
"""
110154
if isinstance(client_options, dict):
111155
client_options = ClientOptions.from_dict(client_options)
@@ -114,16 +158,45 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
114158
# Ordinarily, we provide the transport, but allowing a custom transport
115159
# instance provides an extensibility point for unusual situations.
116160
if isinstance(transport, {{ service.name }}Transport):
161+
# transport is a {{ service.name }}Transport instance.
117162
if credentials:
118163
raise ValueError('When providing a transport instance, '
119164
'provide its credentials directly.')
120165
self._transport = transport
121-
else:
166+
elif client_options is None or (
167+
client_options.api_endpoint == None
168+
and client_options.client_cert_source is None
169+
):
170+
# Don't trigger mTLS if we get an empty ClientOptions.
122171
Transport = type(self).get_transport_class(transport)
123172
self._transport = Transport(
124-
credentials=credentials,
125-
host=client_options.api_endpoint{% if service.host %} or '{{ service.host }}'{% endif %},
173+
credentials=credentials, host=self.DEFAULT_ENDPOINT
126174
)
175+
else:
176+
# We have a non-empty ClientOptions. If client_cert_source is
177+
# provided, trigger mTLS with user provided endpoint or the default
178+
# mTLS endpoint.
179+
if client_options.client_cert_source:
180+
api_mtls_endpoint = (
181+
client_options.api_endpoint
182+
if client_options.api_endpoint
183+
else self.DEFAULT_MTLS_ENDPOINT
184+
)
185+
else:
186+
api_mtls_endpoint = None
187+
188+
api_endpoint = (
189+
client_options.api_endpoint
190+
if client_options.api_endpoint
191+
else self.DEFAULT_ENDPOINT
192+
)
193+
194+
self._transport = {{ service.name }}GrpcTransport(
195+
credentials=credentials,
196+
host=api_endpoint,
197+
api_mtls_endpoint=api_mtls_endpoint,
198+
client_cert_source=client_options.client_cert_source,
199+
)
127200

128201
{% for method in service.methods.values() -%}
129202
def {{ method.name|snake_case }}(self,

gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/grpc.py.j2

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
{% extends '_base.py.j2' %}
22

33
{% block content %}
4-
from typing import Callable, Dict
4+
from typing import Callable, Dict, Tuple
55

66
from google.api_core import grpc_helpers # type: ignore
77
{%- if service.has_lro %}
88
from google.api_core import operations_v1 # type: ignore
99
{%- endif %}
1010
from google.auth import credentials # type: ignore
11+
from google.auth.transport.grpc import SslCredentials # type: ignore
12+
1113

1214
import grpc # type: ignore
1315

@@ -35,7 +37,9 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
3537
def __init__(self, *,
3638
host: str{% if service.host %} = '{{ service.host }}'{% endif %},
3739
credentials: credentials.Credentials = None,
38-
channel: grpc.Channel = None) -> None:
40+
channel: grpc.Channel = None,
41+
api_mtls_endpoint: str = None,
42+
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None) -> None:
3943
"""Instantiate the transport.
4044

4145
Args:
@@ -49,19 +53,51 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
4953
This argument is ignored if ``channel`` is provided.
5054
channel (Optional[grpc.Channel]): A ``Channel`` instance through
5155
which to make calls.
56+
api_mtls_endpoint (Optional[str]): The mutual TLS endpoint. If
57+
provided, it overrides the ``host`` argument and tries to create
58+
a mutual TLS channel with client SSL credentials from
59+
``client_cert_source`` or applicatin default SSL credentials.
60+
client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): A
61+
callback to provide client SSL certificate bytes and private key
62+
bytes, both in PEM format. It is ignored if ``api_mtls_endpoint``
63+
is None.
64+
65+
Raises:
66+
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
67+
creation failed for any reason.
5268
"""
53-
# Sanity check: Ensure that channel and credentials are not both
54-
# provided.
5569
if channel:
70+
# Sanity check: Ensure that channel and credentials are not both
71+
# provided.
5672
credentials = False
5773

74+
# If a channel was explicitly provided, set it.
75+
self._grpc_channel = channel
76+
elif api_mtls_endpoint:
77+
host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"
78+
79+
# Create SSL credentials with client_cert_source or application
80+
# default SSL credentials.
81+
if client_cert_source:
82+
cert, key = client_cert_source()
83+
ssl_credentials = grpc.ssl_channel_credentials(
84+
certificate_chain=cert, private_key=key
85+
)
86+
else:
87+
ssl_credentials = SslCredentials().ssl_credentials
88+
89+
# create a new channel. The provided one is ignored.
90+
self._grpc_channel = grpc_helpers.create_channel(
91+
host,
92+
credentials=credentials,
93+
ssl_credentials=ssl_credentials,
94+
scopes=self.AUTH_SCOPES,
95+
)
96+
5897
# Run the base constructor.
5998
super().__init__(host=host, credentials=credentials)
60-
self._stubs = {} # type: Dict[str, Callable]
99+
self._stubs = {} # type: Dict[str, Callable]
61100

62-
# If a channel was explicitly provided, set it.
63-
if channel:
64-
self._grpc_channel = channel
65101

66102
@classmethod
67103
def create_channel(cls,

0 commit comments

Comments
 (0)