Skip to content
This repository was archived by the owner on Jun 12, 2021. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions src/oidcendpoint/oidc/add_on/pkce.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,11 @@ def post_authn_parse(request, client_id, endpoint_context, **kwargs):
request["code_challenge_method"] = "plain"

if (
request["code_challenge_method"]
not in endpoint_context.args["pkce"]["code_challenge_methods"]
"code_challenge" in request
and (
request["code_challenge_method"]
not in endpoint_context.args["pkce"]["code_challenge_methods"]
)
):
return AuthorizationErrorResponse(
error="invalid_request",
Expand Down Expand Up @@ -119,7 +122,21 @@ def post_token_parse(request, client_id, endpoint_context, **kwargs):


def add_pkce_support(endpoint, **kwargs):
endpoint["authorization"].post_parse_request.append(post_authn_parse)
authn_endpoint = endpoint.get("authorization")
if authn_endpoint is None:
LOGGER.warning(
"No authorization endpoint found, skipping PKCE configuration"
)
return

token_endpoint = endpoint.get("token")
if token_endpoint is None:
LOGGER.warning(
"No token endpoint found, skipping PKCE configuration"
)
return

authn_endpoint.post_parse_request.append(post_authn_parse)

if "essential" not in kwargs:
kwargs["essential"] = False
Expand All @@ -134,6 +151,6 @@ def add_pkce_support(endpoint, **kwargs):
raise ValueError("Unsupported method: {}".format(method))
kwargs["code_challenge_methods"][method] = CC_METHOD[method]

endpoint["authorization"].endpoint_context.args["pkce"] = kwargs
authn_endpoint.endpoint_context.args["pkce"] = kwargs

endpoint["token"].post_parse_request.append(post_token_parse)
token_endpoint.post_parse_request.append(post_token_parse)
51 changes: 51 additions & 0 deletions tests/test_33_pkce.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,3 +371,54 @@ def test_no_code_verifier(self):
assert isinstance(resp, TokenErrorResponse)
assert resp["error"] == "invalid_grant"
assert resp["error_description"] == "Missing code_verifier"

def test_no_authorization_endpoint(self, conf, caplog):
"""
Test that PKCE configuration does not crash when there is no authorization
endpoint and a warning is logged.
"""
del conf["endpoint"]["authorization"]
create_endpoint(conf)
assert "WARNING" in caplog.text
assert (
"No authorization endpoint found, skipping PKCE configuration"
in caplog.text
)

def test_no_token_endpoint(self, conf, caplog):
"""
Test that PKCE configuration does not crash when there is no token endpoint
and a warning is logged.
"""
del conf["endpoint"]["token"]
create_endpoint(conf)
assert "WARNING" in caplog.text
assert "No token endpoint found, skipping PKCE configuration" in caplog.text

def test_plain_challenge_method_not_supported_and_PKCE_not_essential(self, conf):
"""
Test that an authentication request without PKCE parameters does not fail when
"plain" code_challenge_method is not supported and PKCE is not essential.
"""
conf["add_on"]["pkce"]["kwargs"]["code_challenge_methods"] = ["S256"]
conf["add_on"]["pkce"]["kwargs"]["essential"] = False
endpoint_context = create_endpoint(conf)
authn_endpoint = endpoint_context.endpoint["authorization"]
token_endpoint = endpoint_context.endpoint["token"]

authentication_request = AUTH_REQ.copy()

parsed_request = authn_endpoint.parse_request(authentication_request.to_dict())

assert not isinstance(parsed_request, AuthorizationErrorResponse)
assert isinstance(parsed_request, AuthorizationRequest)

response = authn_endpoint.process_request(parsed_request)

assert isinstance(response["response_args"], AuthorizationResponse)

token_request = TOKEN_REQ.copy()
token_request["code"] = response["response_args"]["code"]
parsed_token_request = token_endpoint.parse_request(token_request)

assert isinstance(parsed_token_request, AccessTokenRequest)