Skip to content

Commit 305ed34

Browse files
authored
feat: provide AsyncIO support for generated code (#365)
1 parent 6a1263c commit 305ed34

24 files changed

+1397
-73
lines changed

gapic/schema/wrappers.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,13 @@ def __getattr__(self, name):
566566

567567
@utils.cached_property
568568
def client_output(self):
569+
return self._client_output(enable_asyncio=False)
570+
571+
@utils.cached_property
572+
def client_output_async(self):
573+
return self._client_output(enable_asyncio=True)
574+
575+
def _client_output(self, enable_asyncio: bool):
569576
"""Return the output from the client layer.
570577
571578
This takes into account transformations made by the outer GAPIC
@@ -584,8 +591,8 @@ def client_output(self):
584591
if self.lro:
585592
return PythonType(meta=metadata.Metadata(
586593
address=metadata.Address(
587-
name='Operation',
588-
module='operation',
594+
name='AsyncOperation' if enable_asyncio else 'Operation',
595+
module='operation_async' if enable_asyncio else 'operation',
589596
package=('google', 'api_core'),
590597
collisions=self.lro.response_type.ident.collisions,
591598
),
@@ -603,7 +610,7 @@ def client_output(self):
603610
if self.paged_result_field:
604611
return PythonType(meta=metadata.Metadata(
605612
address=metadata.Address(
606-
name=f'{self.name}Pager',
613+
name=f'{self.name}AsyncPager' if enable_asyncio else f'{self.name}Pager',
607614
package=self.ident.api_naming.module_namespace + (self.ident.api_naming.versioned_module_name,) + self.ident.subpackage + (
608615
'services',
609616
utils.to_snake_case(self.ident.parent[-1]),
@@ -744,6 +751,8 @@ def _ref_types(self, recursive: bool) -> Sequence[Union[MessageType, EnumType]]:
744751
if not self.void:
745752
answer.append(self.client_output)
746753
answer.extend(self.client_output.field_types)
754+
answer.append(self.client_output_async)
755+
answer.extend(self.client_output_async.field_types)
747756

748757
# If this method has LRO, it is possible (albeit unlikely) that
749758
# the LRO messages reside in a different module.
@@ -801,6 +810,11 @@ def client_name(self) -> str:
801810
"""Returns the name of the generated client class"""
802811
return self.name + "Client"
803812

813+
@property
814+
def async_client_name(self) -> str:
815+
"""Returns the name of the generated AsyncIO client class"""
816+
return self.name + "AsyncClient"
817+
804818
@property
805819
def transport_name(self):
806820
return self.name + "Transport"
@@ -809,6 +823,10 @@ def transport_name(self):
809823
def grpc_transport_name(self):
810824
return self.name + "GrpcTransport"
811825

826+
@property
827+
def grpc_asyncio_transport_name(self):
828+
return self.name + "GrpcAsyncIOTransport"
829+
812830
@property
813831
def has_lro(self) -> bool:
814832
"""Return whether the service has a long-running method."""
@@ -856,7 +874,7 @@ def names(self) -> FrozenSet[str]:
856874
used for imports.
857875
"""
858876
# Put together a set of the service and method names.
859-
answer = {self.name, self.client_name}
877+
answer = {self.name, self.client_name, self.async_client_name}
860878
answer.update(
861879
utils.to_snake_case(i.name) for i in self.methods.values()
862880
)

gapic/templates/%namespace/%name/__init__.py.j2

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ from {% if api.naming.module_namespace %}{{ api.naming.module_namespace|join('.'
1212
if service.meta.address.subpackage == api.subpackage_view -%}
1313
from {% if api.naming.module_namespace %}{{ api.naming.module_namespace|join('.') }}.{% endif -%}
1414
{{ api.naming.versioned_module_name }}.services.{{ service.name|snake_case }}.client import {{ service.client_name }}
15+
from {% if api.naming.module_namespace %}{{ api.naming.module_namespace|join('.') }}.{% endif -%}
16+
{{ api.naming.versioned_module_name }}.services.{{ service.name|snake_case }}.async_client import {{ service.async_client_name }}
1517
{% endfor -%}
1618

1719
{# Import messages and enums from each proto.
@@ -48,6 +50,7 @@ __all__ = (
4850
{% for service in api.services.values()|sort(attribute='name')
4951
if service.meta.address.subpackage == api.subpackage_view -%}
5052
'{{ service.client_name }}',
53+
'{{ service.async_client_name }}',
5154
{% endfor -%}
5255
{% for proto in api.protos.values()|sort(attribute='module_name')
5356
if proto.meta.address.subpackage == api.subpackage_view -%}

gapic/templates/%namespace/%name_%version/%sub/services/%service/__init__.py.j2

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
{% block content %}
44
from .client import {{ service.client_name }}
5+
from .async_client import {{ service.async_client_name }}
56

67
__all__ = (
78
'{{ service.client_name }}',
9+
'{{ service.async_client_name }}',
810
)
911
{% endblock %}
Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
{% extends '_base.py.j2' %}
2+
3+
{% block content %}
4+
from collections import OrderedDict
5+
import functools
6+
import re
7+
from typing import Dict, {% if service.any_server_streaming %}AsyncIterable, {% endif %}{% if service.any_client_streaming %}AsyncIterator, {% endif %}Sequence, Tuple, Type, Union
8+
import pkg_resources
9+
10+
import google.api_core.client_options as ClientOptions # type: ignore
11+
from google.api_core import exceptions # type: ignore
12+
from google.api_core import gapic_v1 # type: ignore
13+
from google.api_core import retry as retries # type: ignore
14+
from google.auth import credentials # type: ignore
15+
from google.oauth2 import service_account # type: ignore
16+
17+
{% filter sort_lines -%}
18+
{% for method in service.methods.values() -%}
19+
{% for ref_type in method.flat_ref_types -%}
20+
{{ ref_type.ident.python_import }}
21+
{% endfor -%}
22+
{% endfor -%}
23+
{% endfilter %}
24+
from .transports.base import {{ service.name }}Transport
25+
from .transports.grpc_asyncio import {{ service.grpc_asyncio_transport_name }}
26+
from .client import {{ service.client_name }}
27+
28+
29+
class {{ service.async_client_name }}:
30+
"""{{ service.meta.doc|rst(width=72, indent=4) }}"""
31+
32+
_client: {{ service.client_name }}
33+
34+
DEFAULT_ENDPOINT = {{ service.client_name }}.DEFAULT_ENDPOINT
35+
DEFAULT_MTLS_ENDPOINT = {{ service.client_name }}.DEFAULT_MTLS_ENDPOINT
36+
37+
{% for message in service.resource_messages -%}
38+
{{ message.resource_type|snake_case }}_path = staticmethod({{ service.client_name }}.{{ message.resource_type|snake_case }}_path)
39+
40+
{% endfor %}
41+
42+
from_service_account_file = {{ service.client_name }}.from_service_account_file
43+
from_service_account_json = from_service_account_file
44+
45+
get_transport_class = functools.partial(type({{ service.client_name }}).get_transport_class, type({{ service.client_name }}))
46+
47+
def __init__(self, *,
48+
credentials: credentials.Credentials = None,
49+
transport: Union[str, {{ service.name }}Transport] = 'grpc_asyncio',
50+
client_options: ClientOptions = None,
51+
) -> None:
52+
"""Instantiate the {{ (service.client_name|snake_case).replace('_', ' ') }}.
53+
54+
Args:
55+
credentials (Optional[google.auth.credentials.Credentials]): The
56+
authorization credentials to attach to requests. These
57+
credentials identify the application to the service; if none
58+
are specified, the client will attempt to ascertain the
59+
credentials from the environment.
60+
transport (Union[str, ~.{{ service.name }}Transport]): The
61+
transport to use. If set to None, a transport is chosen
62+
automatically.
63+
client_options (ClientOptions): Custom options for the client. It
64+
won't take effect if a ``transport`` instance is provided.
65+
(1) The ``api_endpoint`` property can be used to override the
66+
default endpoint provided by the client. GOOGLE_API_USE_MTLS
67+
environment variable can also be used to override the endpoint:
68+
"always" (always use the default mTLS endpoint), "never" (always
69+
use the default regular endpoint, this is the default value for
70+
the environment variable) and "auto" (auto switch to the default
71+
mTLS endpoint if client SSL credentials is present). However,
72+
the ``api_endpoint`` property takes precedence if provided.
73+
(2) The ``client_cert_source`` property is used to provide client
74+
SSL credentials for mutual TLS transport. If not provided, the
75+
default SSL credentials will be used if present.
76+
77+
Raises:
78+
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
79+
creation failed for any reason.
80+
"""
81+
{# NOTE(lidiz) Not using kwargs since we want the docstring and types. #}
82+
self._client = {{ service.client_name }}(
83+
credentials=credentials,
84+
transport=transport,
85+
client_options=client_options,
86+
)
87+
88+
{% for method in service.methods.values() -%}
89+
{% if not method.server_streaming %}async {% endif -%}def {{ method.name|snake_case }}(self,
90+
{%- if not method.client_streaming %}
91+
request: {{ method.input.ident }} = None,
92+
*,
93+
{% for field in method.flattened_fields.values() -%}
94+
{{ field.name }}: {{ field.ident }} = None,
95+
{% endfor -%}
96+
{%- else %}
97+
requests: AsyncIterator[{{ method.input.ident }}] = None,
98+
*,
99+
{% endif -%}
100+
retry: retries.Retry = gapic_v1.method.DEFAULT,
101+
timeout: float = None,
102+
metadata: Sequence[Tuple[str, str]] = (),
103+
{%- if not method.server_streaming %}
104+
) -> {{ method.client_output_async.ident }}:
105+
{%- else %}
106+
) -> AsyncIterable[{{ method.client_output_async.ident }}]:
107+
{%- endif %}
108+
r"""{{ method.meta.doc|rst(width=72, indent=8) }}
109+
110+
Args:
111+
{%- if not method.client_streaming %}
112+
request (:class:`{{ method.input.ident.sphinx }}`):
113+
The request object.{{ ' ' -}}
114+
{{ method.input.meta.doc|wrap(width=72, offset=36, indent=16) }}
115+
{% for key, field in method.flattened_fields.items() -%}
116+
{{ field.name }} (:class:`{{ field.ident.sphinx }}`):
117+
{{ field.meta.doc|rst(width=72, indent=16, nl=False) }}
118+
This corresponds to the ``{{ key }}`` field
119+
on the ``request`` instance; if ``request`` is provided, this
120+
should not be set.
121+
{% endfor -%}
122+
{%- else %}
123+
requests (AsyncIterator[`{{ method.input.ident.sphinx }}`]):
124+
The request object AsyncIterator.{{ ' ' -}}
125+
{{ method.input.meta.doc|wrap(width=72, offset=36, indent=16) }}
126+
{%- endif %}
127+
retry (google.api_core.retry.Retry): Designation of what errors, if any,
128+
should be retried.
129+
timeout (float): The timeout for this request.
130+
metadata (Sequence[Tuple[str, str]]): Strings which should be
131+
sent along with the request as metadata.
132+
{%- if not method.void %}
133+
134+
Returns:
135+
{%- if not method.server_streaming %}
136+
{{ method.client_output_async.ident.sphinx }}:
137+
{%- else %}
138+
AsyncIterable[{{ method.client_output_async.ident.sphinx }}]:
139+
{%- endif %}
140+
{{ method.client_output_async.meta.doc|rst(width=72, indent=16) }}
141+
{%- endif %}
142+
"""
143+
{%- if not method.client_streaming %}
144+
# Create or coerce a protobuf request object.
145+
{% if method.flattened_fields -%}
146+
# Sanity check: If we got a request object, we should *not* have
147+
# gotten any keyword arguments that map to the request.
148+
if request is not None and any([{{ method.flattened_fields.values()|join(', ', attribute='name') }}]):
149+
raise ValueError('If the `request` argument is set, then none of '
150+
'the individual field arguments should be set.')
151+
152+
{% endif -%}
153+
{% if method.input.ident.package != method.ident.package -%} {# request lives in a different package, so there is no proto wrapper #}
154+
# The request isn't a proto-plus wrapped type,
155+
# so it must be constructed via keyword expansion.
156+
if isinstance(request, dict):
157+
request = {{ method.input.ident }}(**request)
158+
{% if method.flattened_fields -%}{# Cross-package req and flattened fields #}
159+
elif not request:
160+
request = {{ method.input.ident }}()
161+
{% endif -%}{# Cross-package req and flattened fields #}
162+
{%- else %}
163+
request = {{ method.input.ident }}(request)
164+
{% endif %} {# different request package #}
165+
166+
{#- Vanilla python protobuf wrapper types cannot _set_ repeated fields #}
167+
{% if method.flattened_fields -%}
168+
# If we have keyword arguments corresponding to fields on the
169+
# request, apply these.
170+
{% endif -%}
171+
{%- for key, field in method.flattened_fields.items() if not(field.repeated and method.input.ident.package != method.ident.package) %}
172+
if {{ field.name }} is not None:
173+
request.{{ key }} = {{ field.name }}
174+
{%- endfor %}
175+
{# They can be _extended_, however -#}
176+
{%- for key, field in method.flattened_fields.items() if (field.repeated and method.input.ident.package != method.ident.package) %}
177+
if {{ field.name }}:
178+
request.{{ key }}.extend({{ field.name }})
179+
{%- endfor %}
180+
{%- endif %}
181+
182+
# Wrap the RPC method; this adds retry and timeout information,
183+
# and friendly error handling.
184+
rpc = gapic_v1.method_async.wrap_method(
185+
self._client._transport.{{ method.name|snake_case }},
186+
{%- if method.retry %}
187+
default_retry=retries.Retry(
188+
{% if method.retry.initial_backoff %}initial={{ method.retry.initial_backoff }},{% endif %}
189+
{% if method.retry.max_backoff %}maximum={{ method.retry.max_backoff }},{% endif %}
190+
{% if method.retry.backoff_multiplier %}multiplier={{ method.retry.backoff_multiplier }},{% endif %}
191+
predicate=retries.if_exception_type(
192+
{%- filter sort_lines %}
193+
{%- for ex in method.retry.retryable_exceptions %}
194+
exceptions.{{ ex.__name__ }},
195+
{%- endfor %}
196+
{%- endfilter %}
197+
),
198+
),
199+
{%- endif %}
200+
default_timeout={{ method.timeout }},
201+
client_info=_client_info,
202+
)
203+
{%- if method.field_headers %}
204+
205+
# Certain fields should be provided within the metadata header;
206+
# add these here.
207+
metadata = tuple(metadata) + (
208+
gapic_v1.routing_header.to_grpc_metadata((
209+
{%- for field_header in method.field_headers %}
210+
{%- if not method.client_streaming %}
211+
('{{ field_header }}', request.{{ field_header }}),
212+
{%- endif %}
213+
{%- endfor %}
214+
)),
215+
)
216+
{%- endif %}
217+
218+
# Send the request.
219+
{% if not method.void %}response = {% endif %}
220+
{%- if not method.server_streaming %}await {% endif %}rpc(
221+
{%- if not method.client_streaming %}
222+
request,
223+
{%- else %}
224+
requests,
225+
{%- endif %}
226+
retry=retry,
227+
timeout=timeout,
228+
metadata=metadata,
229+
)
230+
{%- if method.lro %}
231+
232+
# Wrap the response in an operation future.
233+
response = operation_async.from_gapic(
234+
response,
235+
self._client._transport.operations_client,
236+
{{ method.lro.response_type.ident }},
237+
metadata_type={{ method.lro.metadata_type.ident }},
238+
)
239+
{%- elif method.paged_result_field %}
240+
241+
# This method is paged; wrap the response in a pager, which provides
242+
# an `__aiter__` convenience method.
243+
response = {{ method.client_output_async.ident }}(
244+
method=rpc,
245+
request=request,
246+
response=response,
247+
)
248+
{%- endif %}
249+
{%- if not method.void %}
250+
251+
# Done; return the response.
252+
return response
253+
{%- endif %}
254+
{{ '\n' }}
255+
{% endfor %}
256+
257+
258+
try:
259+
_client_info = gapic_v1.client_info.ClientInfo(
260+
gapic_version=pkg_resources.get_distribution(
261+
'{{ api.naming.warehouse_package_name }}',
262+
).version,
263+
)
264+
except pkg_resources.DistributionNotFound:
265+
_client_info = gapic_v1.client_info.ClientInfo()
266+
267+
268+
__all__ = (
269+
'{{ service.async_client_name }}',
270+
)
271+
{% endblock %}

0 commit comments

Comments
 (0)