Skip to content

Commit 0c94af0

Browse files
authored
Merge c50220d into a8da274
2 parents a8da274 + c50220d commit 0c94af0

File tree

1 file changed

+85
-9
lines changed

1 file changed

+85
-9
lines changed

deepnote_toolkit/sql/sql_execution.py

Lines changed: 85 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,22 @@
11
import base64
22
import contextlib
33
import json
4+
import logging
45
import re
6+
import sys
7+
from typing import Any, Literal
58
import uuid
9+
10+
if sys.version_info >= (3, 11):
11+
from typing import Never
12+
else:
13+
from typing_extensions import Never
614
import warnings
715
from urllib.parse import quote
816

917
import google.oauth2.credentials
1018
import numpy as np
19+
from pydantic import BaseModel, ValidationError
1120
import requests
1221
from cryptography.hazmat.backends import default_backend
1322
from cryptography.hazmat.primitives import serialization
@@ -33,6 +42,18 @@
3342
from deepnote_toolkit.sql.sql_utils import is_single_select_query
3443
from deepnote_toolkit.sql.url_utils import replace_user_pass_in_pg_url
3544

45+
logger = logging.getLogger(__name__)
46+
47+
48+
class IntegrationFederatedAuthParams(BaseModel):
49+
integrationType: Literal["trino", "big-query"]
50+
integrationId: str
51+
userId: str
52+
53+
54+
class FederatedAuthResponseData(BaseModel):
55+
accessToken: str
56+
3657

3758
def compile_sql_query(
3859
skip_jinja_template_render,
@@ -247,6 +268,68 @@ def _generate_temporary_credentials(integration_id):
247268
return quote(data["username"]), quote(data["password"])
248269

249270

271+
def _get_federated_auth_credentials(integration_id: str, user_id: str) -> str:
272+
url = get_absolute_userpod_api_url(
273+
f"integrations/federated-auth-token/{integration_id}"
274+
)
275+
276+
# Add project credentials in detached mode
277+
headers = get_project_auth_headers()
278+
279+
response = requests.post(url, json={"userId": user_id}, timeout=10, headers=headers)
280+
281+
data = FederatedAuthResponseData.model_validate_json(response.json())
282+
283+
return data.accessToken
284+
285+
286+
def _handle_iam_params(sql_alchemy_dict: dict[str, Any]) -> None:
287+
if "iamParams" not in sql_alchemy_dict:
288+
return
289+
290+
integration_id = sql_alchemy_dict["iamParams"]["integrationId"]
291+
292+
temporaryUsername, temporaryPassword = _generate_temporary_credentials(
293+
integration_id
294+
)
295+
296+
sql_alchemy_dict["url"] = replace_user_pass_in_pg_url(
297+
sql_alchemy_dict["url"], temporaryUsername, temporaryPassword
298+
)
299+
300+
301+
def _handle_federated_auth_params(sql_alchemy_dict: dict[str, Any]) -> None:
302+
if "federatedAuthParams" not in sql_alchemy_dict:
303+
return
304+
305+
try:
306+
federated_auth_params = IntegrationFederatedAuthParams.model_validate(
307+
sql_alchemy_dict["federatedAuthParams"]
308+
)
309+
except ValidationError as e:
310+
logger.error(
311+
f"Invalid federated auth params, try updating toolkit version: {e}"
312+
)
313+
return
314+
315+
access_token = _get_federated_auth_credentials(
316+
federated_auth_params.integrationId, federated_auth_params.userId
317+
)
318+
319+
match federated_auth_params.integrationType:
320+
case "trino":
321+
sql_alchemy_dict["params"]["connect_args"]["http_headers"][
322+
"Authorization"
323+
] = f"Bearer {access_token}"
324+
case "big-query":
325+
sql_alchemy_dict["params"]["access_token"] = access_token
326+
case _:
327+
_check_never: Never = federated_auth_params.integrationType
328+
raise ValueError(
329+
f"Unsupported integration type: {federated_auth_params.integrationType}"
330+
)
331+
332+
250333
@contextlib.contextmanager
251334
def _create_sql_ssh_uri(ssh_enabled, sql_alchemy_dict):
252335
server = None
@@ -346,16 +429,9 @@ def _query_data_source(
346429
):
347430
sshEnabled = sql_alchemy_dict.get("ssh_options", {}).get("enabled", False)
348431

349-
if "iamParams" in sql_alchemy_dict:
350-
integration_id = sql_alchemy_dict["iamParams"]["integrationId"]
351-
352-
temporaryUsername, temporaryPassword = _generate_temporary_credentials(
353-
integration_id
354-
)
432+
_handle_iam_params(sql_alchemy_dict)
355433

356-
sql_alchemy_dict["url"] = replace_user_pass_in_pg_url(
357-
sql_alchemy_dict["url"], temporaryUsername, temporaryPassword
358-
)
434+
_handle_federated_auth_params(sql_alchemy_dict)
359435

360436
with _create_sql_ssh_uri(sshEnabled, sql_alchemy_dict) as url:
361437
if url is None:

0 commit comments

Comments
 (0)