Skip to content

Commit dd0ea41

Browse files
authored
feat: Add special snowflake path for internal dns usage (#52)
1 parent 338241b commit dd0ea41

File tree

3 files changed

+46
-37
lines changed

3 files changed

+46
-37
lines changed

examples/snowflake_native_app_example.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,46 +2,45 @@
22

33
from contextual import ContextualAI
44

5-
SF_BASE_URL = 'xxxxx-xxxxx-xxxxx.snowflakecomputing.app'
6-
BASE_URL = f'https://{SF_BASE_URL}/v1'
5+
SF_BASE_URL = "xxxxx-xxxxx-xxxxx.snowflakecomputing.app"
6+
BASE_URL = f"https://{SF_BASE_URL}/v1"
77

8-
SAMPLE_MESSAGE = 'Can you tell me about XYZ'
8+
SAMPLE_MESSAGE = "Can you tell me about XYZ"
99

10-
ctx = snowflake.connector.connect( # type: ignore
11-
user="",# snowflake account user
12-
password='', # snowflake account password
13-
account="organization-account", # snowflake organization and account <Organization>-<Account>
14-
session_parameters={
15-
'PYTHON_CONNECTOR_QUERY_RESULT_FORMAT': 'json'
16-
})
10+
ctx = snowflake.connector.connect( # type: ignore
11+
user="", # snowflake account user
12+
password="", # snowflake account password
13+
account="organization-account", # snowflake organization and account <Organization>-<Account>
14+
session_parameters={"PYTHON_CONNECTOR_QUERY_RESULT_FORMAT": "json"},
15+
)
1716

1817
# Obtain a session token.
19-
token_data = ctx._rest._token_request('ISSUE') # type: ignore
20-
token_extract = token_data['data']['sessionToken'] # type: ignore
18+
token_data = ctx._rest._token_request("ISSUE") # type: ignore
19+
token_extract = token_data["data"]["sessionToken"] # type: ignore
2120

2221
# Create a request to the ingress endpoint with authz.
23-
api_key = f'\"{token_extract}\"'
22+
api_key = f'"{token_extract}"'
2423

2524
client = ContextualAI(api_key=api_key, base_url=BASE_URL)
2625

27-
agents = [a for a in client.agents.list() ]
26+
agents = [a for a in client.agents.list()]
2827

2928
agent = agents[0] if agents else None
3029

3130
if agent is None:
32-
print('No agents found')
31+
print("No agents found")
3332
exit()
3433
print(f"Found agent {agent.name} with id {agent.id}")
3534

3635
messages = [
3736
{
38-
'content': SAMPLE_MESSAGE,
39-
'role': 'user',
37+
"content": SAMPLE_MESSAGE,
38+
"role": "user",
4039
}
4140
]
4241

43-
res = client.agents.query.create(agent.id, messages=messages) # type: ignore
42+
res = client.agents.query.create(agent.id, messages=messages) # type: ignore
4443

45-
output = res.message.content # type: ignore
44+
output = res.message.content # type: ignore
4645

47-
print(output)
46+
print(output)

src/contextual/_client.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,9 @@ class ContextualAI(SyncAPIClient):
5858
with_streaming_response: ContextualAIWithStreamedResponse
5959

6060
# client options
61-
api_key: str
62-
is_snowflake: bool
61+
api_key: str | None = None
62+
is_snowflake: bool = False
63+
is_snowflake_internal: bool = False
6364

6465
def __init__(
6566
self,
@@ -91,9 +92,12 @@ def __init__(
9192
if api_key is None:
9293
api_key = os.environ.get("CONTEXTUAL_API_KEY")
9394
if api_key is None:
94-
raise ContextualAIError(
95-
"The api_key client option must be set either by passing api_key to the client or by setting the CONTEXTUAL_API_KEY environment variable"
96-
)
95+
if os.getenv('SNOWFLAKE_INTERNAL_API_SERVICE', False):
96+
self.is_snowflake_internal = True
97+
else:
98+
raise ContextualAIError(
99+
"The api_key client option must be set either by passing api_key to the client or by setting the CONTEXTUAL_API_KEY environment variable"
100+
)
97101
self.api_key = api_key
98102

99103
if base_url is None:
@@ -103,8 +107,6 @@ def __init__(
103107

104108
if 'snowflakecomputing.app' in str(base_url):
105109
self.is_snowflake = True
106-
else:
107-
self.is_snowflake = False
108110

109111
super().__init__(
110112
version=__version__,
@@ -137,6 +139,8 @@ def auth_headers(self) -> dict[str, str]:
137139
api_key = self.api_key
138140
if self.is_snowflake:
139141
return {"Authorization": f"Snowflake Token={api_key}"}
142+
elif self.is_snowflake_internal:
143+
return {}
140144
else:
141145
return {"Authorization": f"Bearer {api_key}"}
142146

@@ -245,8 +249,9 @@ class AsyncContextualAI(AsyncAPIClient):
245249
with_streaming_response: AsyncContextualAIWithStreamedResponse
246250

247251
# client options
248-
api_key: str
249-
is_snowflake: bool
252+
api_key: str | None = None
253+
is_snowflake: bool = False
254+
is_snowflake_internal: bool = False
250255

251256
def __init__(
252257
self,
@@ -278,9 +283,12 @@ def __init__(
278283
if api_key is None:
279284
api_key = os.environ.get("CONTEXTUAL_API_KEY")
280285
if api_key is None:
281-
raise ContextualAIError(
282-
"The api_key client option must be set either by passing api_key to the client or by setting the CONTEXTUAL_API_KEY environment variable"
283-
)
286+
if os.getenv('SNOWFLAKE_INTERNAL_API_SERVICE', False):
287+
self.is_snowflake_internal = True
288+
else:
289+
raise ContextualAIError(
290+
"The api_key client option must be set either by passing api_key to the client or by setting the CONTEXTUAL_API_KEY environment variable"
291+
)
284292
self.api_key = api_key
285293

286294
if base_url is None:
@@ -290,8 +298,6 @@ def __init__(
290298

291299
if 'snowflakecomputing.app' in str(base_url):
292300
self.is_snowflake = True
293-
else:
294-
self.is_snowflake = False
295301

296302
super().__init__(
297303
version=__version__,
@@ -319,11 +325,13 @@ def qs(self) -> Querystring:
319325
return Querystring(array_format="repeat")
320326

321327
@property
322-
@override
328+
@override
323329
def auth_headers(self) -> dict[str, str]:
324330
api_key = self.api_key
325331
if self.is_snowflake:
326332
return {"Authorization": f"Snowflake Token={api_key}"}
333+
elif self.is_snowflake_internal:
334+
return {}
327335
else:
328336
return {"Authorization": f"Bearer {api_key}"}
329337

tests/test_client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1617,7 +1617,8 @@ def test_get_platform(self) -> None:
16171617
#
16181618
# Since nest_asyncio.apply() is global and cannot be un-applied, this
16191619
# test is run in a separate process to avoid affecting other tests.
1620-
test_code = dedent("""
1620+
test_code = dedent(
1621+
"""
16211622
import asyncio
16221623
import nest_asyncio
16231624
import threading
@@ -1633,7 +1634,8 @@ async def test_main() -> None:
16331634
16341635
nest_asyncio.apply()
16351636
asyncio.run(test_main())
1636-
""")
1637+
"""
1638+
)
16371639
with subprocess.Popen(
16381640
[sys.executable, "-c", test_code],
16391641
text=True,

0 commit comments

Comments
 (0)