|
1 | 1 | import base64 |
2 | 2 | import contextlib |
3 | 3 | import json |
| 4 | +import logging |
4 | 5 | import re |
| 6 | +import sys |
| 7 | +from typing import Any, Literal |
5 | 8 | import uuid |
| 9 | + |
| 10 | +if sys.version_info >= (3, 11): |
| 11 | + from typing import Never |
| 12 | +else: |
| 13 | + from typing_extensions import Never |
6 | 14 | import warnings |
7 | 15 | from urllib.parse import quote |
8 | 16 |
|
9 | 17 | import google.oauth2.credentials |
10 | 18 | import numpy as np |
| 19 | +from pydantic import BaseModel, ValidationError |
11 | 20 | import requests |
12 | 21 | from cryptography.hazmat.backends import default_backend |
13 | 22 | from cryptography.hazmat.primitives import serialization |
|
33 | 42 | from deepnote_toolkit.sql.sql_utils import is_single_select_query |
34 | 43 | from deepnote_toolkit.sql.url_utils import replace_user_pass_in_pg_url |
35 | 44 |
|
| 45 | +logger = logging.getLogger(__name__) |
| 46 | + |
| 47 | + |
| 48 | +class IntegrationFederatedAuthParams(BaseModel): |
| 49 | + integrationType: Literal["trino", "big-query"] |
| 50 | + integrationId: str |
| 51 | + userId: str |
| 52 | + |
| 53 | + |
| 54 | +class FederatedAuthResponseData(BaseModel): |
| 55 | + accessToken: str |
| 56 | + |
36 | 57 |
|
37 | 58 | def compile_sql_query( |
38 | 59 | skip_jinja_template_render, |
@@ -247,6 +268,68 @@ def _generate_temporary_credentials(integration_id): |
247 | 268 | return quote(data["username"]), quote(data["password"]) |
248 | 269 |
|
249 | 270 |
|
| 271 | +def _get_federated_auth_credentials(integration_id: str, user_id: str) -> str: |
| 272 | + url = get_absolute_userpod_api_url( |
| 273 | + f"integrations/federated-auth-token/{integration_id}" |
| 274 | + ) |
| 275 | + |
| 276 | + # Add project credentials in detached mode |
| 277 | + headers = get_project_auth_headers() |
| 278 | + |
| 279 | + response = requests.post(url, json={"userId": user_id}, timeout=10, headers=headers) |
| 280 | + |
| 281 | + data = FederatedAuthResponseData.model_validate_json(response.json()) |
| 282 | + |
| 283 | + return data.accessToken |
| 284 | + |
| 285 | + |
| 286 | +def _handle_iam_params(sql_alchemy_dict: dict[str, Any]) -> None: |
| 287 | + if "iamParams" not in sql_alchemy_dict: |
| 288 | + return |
| 289 | + |
| 290 | + integration_id = sql_alchemy_dict["iamParams"]["integrationId"] |
| 291 | + |
| 292 | + temporaryUsername, temporaryPassword = _generate_temporary_credentials( |
| 293 | + integration_id |
| 294 | + ) |
| 295 | + |
| 296 | + sql_alchemy_dict["url"] = replace_user_pass_in_pg_url( |
| 297 | + sql_alchemy_dict["url"], temporaryUsername, temporaryPassword |
| 298 | + ) |
| 299 | + |
| 300 | + |
| 301 | +def _handle_federated_auth_params(sql_alchemy_dict: dict[str, Any]) -> None: |
| 302 | + if "federatedAuthParams" not in sql_alchemy_dict: |
| 303 | + return |
| 304 | + |
| 305 | + try: |
| 306 | + federated_auth_params = IntegrationFederatedAuthParams.model_validate( |
| 307 | + sql_alchemy_dict["federatedAuthParams"] |
| 308 | + ) |
| 309 | + except ValidationError as e: |
| 310 | + logger.error( |
| 311 | + f"Invalid federated auth params, try updating toolkit version: {e}" |
| 312 | + ) |
| 313 | + return |
| 314 | + |
| 315 | + access_token = _get_federated_auth_credentials( |
| 316 | + federated_auth_params.integrationId, federated_auth_params.userId |
| 317 | + ) |
| 318 | + |
| 319 | + match federated_auth_params.integrationType: |
| 320 | + case "trino": |
| 321 | + sql_alchemy_dict["params"]["connect_args"]["http_headers"][ |
| 322 | + "Authorization" |
| 323 | + ] = f"Bearer {access_token}" |
| 324 | + case "big-query": |
| 325 | + sql_alchemy_dict["params"]["access_token"] = access_token |
| 326 | + case _: |
| 327 | + _check_never: Never = federated_auth_params.integrationType |
| 328 | + raise ValueError( |
| 329 | + f"Unsupported integration type: {federated_auth_params.integrationType}" |
| 330 | + ) |
| 331 | + |
| 332 | + |
250 | 333 | @contextlib.contextmanager |
251 | 334 | def _create_sql_ssh_uri(ssh_enabled, sql_alchemy_dict): |
252 | 335 | server = None |
@@ -346,16 +429,9 @@ def _query_data_source( |
346 | 429 | ): |
347 | 430 | sshEnabled = sql_alchemy_dict.get("ssh_options", {}).get("enabled", False) |
348 | 431 |
|
349 | | - if "iamParams" in sql_alchemy_dict: |
350 | | - integration_id = sql_alchemy_dict["iamParams"]["integrationId"] |
351 | | - |
352 | | - temporaryUsername, temporaryPassword = _generate_temporary_credentials( |
353 | | - integration_id |
354 | | - ) |
| 432 | + _handle_iam_params(sql_alchemy_dict) |
355 | 433 |
|
356 | | - sql_alchemy_dict["url"] = replace_user_pass_in_pg_url( |
357 | | - sql_alchemy_dict["url"], temporaryUsername, temporaryPassword |
358 | | - ) |
| 434 | + _handle_federated_auth_params(sql_alchemy_dict) |
359 | 435 |
|
360 | 436 | with _create_sql_ssh_uri(sshEnabled, sql_alchemy_dict) as url: |
361 | 437 | if url is None: |
|
0 commit comments