Skip to content

Commit eff96b6

Browse files
aldbrchrisburr
authored andcommitted
feat: introduce jwks
1 parent 587205b commit eff96b6

File tree

21 files changed

+466
-148
lines changed

21 files changed

+466
-148
lines changed

diracx-core/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ classifiers = [
1414
]
1515
dependencies = [
1616
"aiobotocore>=2.15",
17-
"authlib",
1817
"botocore>=1.35",
1918
"cachetools",
2019
"email_validator",
2120
"gitpython",
21+
"joserfc",
2222
"pydantic >=2.10",
2323
"pydantic-settings",
2424
"pyyaml",

diracx-core/src/diracx/core/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ class OpenIDConfiguration(TypedDict):
202202
authorization_endpoint: str
203203
device_authorization_endpoint: str
204204
revocation_endpoint: str
205+
jwks_uri: str
205206
grant_types_supported: list[str]
206207
scopes_supported: list[str]
207208
response_types_supported: list[str]

diracx-core/src/diracx/core/settings.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
import json
6+
57
from diracx.core.properties import SecurityProperty
68
from diracx.core.s3 import s3_bucket_exists
79

@@ -17,10 +19,10 @@
1719
from typing import TYPE_CHECKING, Annotated, Any, Self, TypeVar
1820

1921
from aiobotocore.session import get_session
20-
from authlib.jose import JsonWebKey
2122
from botocore.config import Config
2223
from botocore.errorfactory import ClientError
2324
from cryptography.fernet import Fernet
25+
from joserfc.jws import KeySet
2426
from pydantic import (
2527
AnyUrl,
2628
BeforeValidator,
@@ -51,28 +53,40 @@ class SqlalchemyDsn(AnyUrl):
5153
)
5254

5355

54-
class _TokenSigningKey(SecretStr):
55-
jwk: JsonWebKey
56+
class _TokenSigningKeyStore(SecretStr):
57+
jwks: KeySet
5658

5759
def __init__(self, data: str):
5860
super().__init__(data)
59-
self.jwk = JsonWebKey.import_key(self.get_secret_value())
61+
62+
# Load the keys from the JSON string
63+
try:
64+
keys = json.loads(self.get_secret_value())
65+
except json.JSONDecodeError as e:
66+
raise ValueError("Invalid JSON string") from e
67+
if not isinstance(keys, dict):
68+
raise ValueError("Invalid JSON string")
69+
self.jwks = KeySet.import_key_set(keys) # type: ignore
6070

6171

62-
def _maybe_load_key_from_file(value: Any) -> Any:
72+
def _maybe_load_keys_from_file(value: Any) -> Any:
6373
"""Load private keys from files if needed."""
64-
if isinstance(value, str) and not value.strip().startswith("-----BEGIN"):
65-
url = TypeAdapter(LocalFileUrl).validate_python(value)
66-
if not url.scheme == "file":
67-
raise ValueError("Only file:// URLs are supported")
68-
if url.path is None:
69-
raise ValueError("No path specified")
70-
value = Path(url.path).read_text()
74+
if isinstance(value, str):
75+
# If the value is a string, we need to check if it is a JSON string or a file URL
76+
if not (value.strip().startswith("{") or value.startswith("[")):
77+
# If it is not a JSON string, we assume it is a file URL
78+
url = TypeAdapter(LocalFileUrl).validate_python(value)
79+
if not url.scheme == "file":
80+
raise ValueError("Only file:// URLs are supported")
81+
if url.path is None:
82+
raise ValueError("No path specified")
83+
return Path(url.path).read_text()
84+
7185
return value
7286

7387

74-
TokenSigningKey = Annotated[
75-
_TokenSigningKey, BeforeValidator(_maybe_load_key_from_file)
88+
TokenSigningKeyStore = Annotated[
89+
_TokenSigningKeyStore, BeforeValidator(_maybe_load_keys_from_file)
7690
]
7791

7892

@@ -137,8 +151,8 @@ class AuthSettings(ServiceSettingsBase):
137151
state_key: FernetKey
138152

139153
token_issuer: str
140-
token_key: TokenSigningKey
141-
token_algorithm: str = "RS256" # noqa: S105
154+
token_keystore: TokenSigningKeyStore
155+
token_allowed_algorithms: list[str] = ["RS256", "EdDSA"] # noqa: S105
142156
access_token_expire_minutes: int = 20
143157
refresh_token_expire_minutes: int = 60
144158

diracx-core/tests/test_secrets.py

Lines changed: 32 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,45 @@
11
from __future__ import annotations
22

3-
from cryptography.hazmat.primitives import serialization
4-
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
3+
import json
4+
5+
from joserfc.jwk import KeySet, OKPKey
56
from pydantic import TypeAdapter
7+
from uuid_utils import uuid7
68

7-
from diracx.core.settings import TokenSigningKey
9+
from diracx.core.settings import TokenSigningKeyStore
810

911

10-
def compare_keys(key1, key2):
11-
"""Compare two keys by checking their public keys."""
12-
key1_public = key1.public_key().public_bytes(
13-
encoding=serialization.Encoding.PEM,
14-
format=serialization.PublicFormat.SubjectPublicKeyInfo,
15-
)
16-
key2_public = key2.public_key().public_bytes(
17-
encoding=serialization.Encoding.PEM,
18-
format=serialization.PublicFormat.SubjectPublicKeyInfo,
12+
def test_token_signing_key(tmp_path):
13+
keyset = KeySet(
14+
keys=[
15+
OKPKey.generate_key(
16+
parameters={
17+
"key_ops": ["sign", "verify"],
18+
"alg": "EdDSA",
19+
"kid": uuid7().hex,
20+
}
21+
)
22+
]
1923
)
20-
assert key1_public == key2_public
2124

25+
jwks_file = tmp_path / "jwks.json"
26+
jwks_file.write_text(json.dumps(keyset.as_dict(private=True)))
2227

23-
def test_token_signing_key(tmp_path):
24-
private_key = Ed25519PrivateKey.generate()
25-
private_key_pem = private_key.private_bytes(
26-
encoding=serialization.Encoding.PEM,
27-
format=serialization.PrivateFormat.PKCS8,
28-
encryption_algorithm=serialization.NoEncryption(),
29-
).decode("ascii")
30-
key_file = tmp_path / "private_key.pem"
31-
key_file.write_text(private_key_pem)
32-
33-
adapter = TypeAdapter(TokenSigningKey)
34-
35-
# Test that we can load a key from a file
36-
compare_keys(
37-
adapter.validate_python(f"{key_file}").jwk.get_private_key(), private_key
28+
adapter = TypeAdapter(TokenSigningKeyStore)
29+
30+
# Test that we can load a keystore from a file
31+
assert (
32+
adapter.validate_python(f"{jwks_file}").jwks.keys[0].kid == keyset.keys[0].kid
3833
)
39-
compare_keys(
40-
adapter.validate_python(f"file://{key_file}").jwk.get_private_key(),
41-
private_key,
34+
assert (
35+
adapter.validate_python(f"file://{jwks_file}").jwks.keys[0].kid
36+
== keyset.keys[0].kid
4237
)
4338

44-
# Test with can load the PEM data directly
45-
compare_keys(
46-
adapter.validate_python(private_key_pem).jwk.get_private_key(), private_key
39+
# Test with can load the keystore data directly from a JSON string
40+
assert (
41+
adapter.validate_python(json.dumps(keyset.as_dict(private=True)))
42+
.jwks.keys[0]
43+
.kid
44+
== keyset.keys[0].kid
4745
)

diracx-logic/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ dependencies = [
1717
"dirac",
1818
"diracx-core",
1919
"diracx-db",
20+
"joserfc",
2021
"pydantic >=2.10",
2122
"uuid-utils",
2223
]
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
"""JWKS key management scripts.
2+
3+
See https://datatracker.ietf.org/doc/html/rfc7517 for further details.
4+
"""
5+
6+
from __future__ import annotations
7+
8+
import argparse
9+
import asyncio
10+
import json
11+
import logging
12+
from pathlib import Path
13+
14+
from joserfc.jwk import JWKRegistry, Key, KeySet
15+
from uuid_utils import uuid7
16+
17+
logger = logging.getLogger(__name__)
18+
19+
# ---------- Helpers ----------------------------------------------------------
20+
21+
22+
def load_jwks(path: Path) -> KeySet:
23+
"""Return a (possibly empty) JWKSet."""
24+
if path.exists():
25+
return KeySet.import_key_set(json.loads(path.read_text()))
26+
logger.warning("JWKS file %s not found – creating a new one", path)
27+
return KeySet(keys=[])
28+
29+
30+
def save_jwks(path: Path, jwks: KeySet) -> None:
31+
"""Write JWKSet to disk *including* private parts."""
32+
path.write_text(json.dumps(jwks.as_dict(private=True), indent=2))
33+
logger.info("JWKS written to %s", path)
34+
35+
36+
def new_key(
37+
kty: str = "OKP",
38+
crv_or_size: str | int = "Ed25519",
39+
) -> Key:
40+
"""Create a fresh private signing key."""
41+
parameters = {
42+
"key_ops": ["sign", "verify"],
43+
"alg": "EdDSA",
44+
"kid": uuid7().hex,
45+
}
46+
return JWKRegistry.generate_key(
47+
key_type=kty, crv_or_size=crv_or_size, private=True, parameters=parameters # type: ignore[arg-type]
48+
)
49+
50+
51+
# ---------- CLI --------------------------------------------------------------
52+
53+
54+
async def rotate_jwk(args):
55+
"""Rotate keys in a JWKS file by inserting a new key at index 0 (active)."""
56+
logger.info("Rotating JWKs...")
57+
58+
crv_or_size = args.crv_or_size
59+
if isinstance(crv_or_size, str) and crv_or_size.isdigit():
60+
crv_or_size = int(crv_or_size)
61+
62+
jwks_path = Path(args.jwks_path)
63+
jwks = load_jwks(jwks_path)
64+
65+
# Current key (at index 0) is set to "verify" only
66+
if len(jwks.keys) > 0:
67+
active_key = jwks.keys[0]
68+
active_key_dict = active_key.as_dict(private=True)
69+
active_key_dict["key_ops"] = sorted(
70+
set(active_key_dict.get("key_ops", [])) - {"sign"}
71+
)
72+
jwks.keys[0] = JWKRegistry.import_key(active_key_dict)
73+
74+
jwk = new_key(args.kty, crv_or_size)
75+
jwks.keys.insert(0, jwk)
76+
77+
save_jwks(jwks_path, jwks)
78+
79+
80+
async def delete_jwk(args):
81+
"""Delete a JWK from a JWKS file."""
82+
logger.info("Deleting JWK...")
83+
84+
path = Path(args.jwks_path)
85+
jwks = load_jwks(path)
86+
jwks.keys = [k for k in jwks.keys if k.get("kid") != args.kid]
87+
save_jwks(path, jwks)
88+
89+
90+
def parse_args():
91+
parser = argparse.ArgumentParser()
92+
subparsers = parser.add_subparsers(dest="command", required=True)
93+
94+
rotate_jwk_parser = subparsers.add_parser(
95+
"rotate-jwk", help="Rotate JWK keys in a JWKS file"
96+
)
97+
rotate_jwk_parser.add_argument(
98+
"--jwks-path", required=True, help="Path to the existing (old) JWKS JSON file."
99+
)
100+
101+
rotate_jwk_parser.add_argument(
102+
"--kty", default="OKP", help="Key type for the new key."
103+
)
104+
rotate_jwk_parser.add_argument(
105+
"--crv-or-size", default="Ed25519", help="Curve or size for the new key."
106+
)
107+
rotate_jwk_parser.set_defaults(func=rotate_jwk)
108+
109+
delete_jwk_parser = subparsers.add_parser(
110+
"delete-jwk", help="Delete a JWK key from a JWKS file"
111+
)
112+
delete_jwk_parser.add_argument(
113+
"--jwks-path", required=True, help="Path to the JWKS JSON file."
114+
)
115+
delete_jwk_parser.add_argument(
116+
"--kid", required=True, help="Key ID (kid) of the key to delete."
117+
)
118+
delete_jwk_parser.set_defaults(func=delete_jwk)
119+
120+
args = parser.parse_args()
121+
logger.setLevel(logging.INFO)
122+
asyncio.run(args.func(args))
123+
124+
125+
if __name__ == "__main__":
126+
parse_args()

diracx-logic/src/diracx/logic/auth/token.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
import hashlib
77
import re
88
from datetime import datetime, timedelta, timezone
9+
from typing import cast
910

10-
from authlib.jose import JsonWebToken
11+
from joserfc import jwt
12+
from joserfc.jwt import Claims
1113
from uuid_utils import UUID, uuid7
1214

1315
from diracx.core.config import Config
@@ -356,11 +358,24 @@ async def exchange_token(
356358

357359

358360
def create_token(payload: TokenPayload, settings: AuthSettings) -> str:
359-
jwt = JsonWebToken(settings.token_algorithm)
360-
encoded_jwt = jwt.encode(
361-
{"alg": settings.token_algorithm}, payload, settings.token_key.jwk
361+
"""Create a JWT token with the given payload and settings."""
362+
signing_key = None
363+
for key in settings.token_keystore.jwks.keys:
364+
# TODO: https://github.com/authlib/joserfc/issues/52
365+
key_ops = cast(list[str] | None, key.get("key_ops"))
366+
if key_ops and "sign" in key_ops:
367+
signing_key = key
368+
break
369+
370+
if not signing_key:
371+
raise ValueError("No signing key found in JWKS")
372+
373+
return jwt.encode(
374+
header={"alg": signing_key.get("alg"), "kid": signing_key.get("kid")},
375+
claims=cast(Claims, payload),
376+
key=settings.token_keystore.jwks,
377+
algorithms=settings.token_allowed_algorithms,
362378
)
363-
return encoded_jwt.decode("ascii")
364379

365380

366381
async def insert_refresh_token(

0 commit comments

Comments
 (0)