diff --git a/core/clients/key_flow_continuous_refresh_test.go b/core/clients/key_flow_continuous_refresh_test.go index 76b29f55a..7c7ee9565 100644 --- a/core/clients/key_flow_continuous_refresh_test.go +++ b/core/clients/key_flow_continuous_refresh_test.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "net/url" "testing" "time" @@ -21,27 +22,28 @@ func TestContinuousRefreshToken(t *testing.T) { jwt.TimePrecision = time.Millisecond // Refresher settings - timeStartBeforeTokenExpiration := 100 * time.Millisecond - timeBetweenContextCheck := 5 * time.Millisecond - timeBetweenTries := 40 * time.Millisecond + timeStartBeforeTokenExpiration := 500 * time.Millisecond + timeBetweenContextCheck := 10 * time.Millisecond + timeBetweenTries := 100 * time.Millisecond // All generated acess tokens will have this time to live - accessTokensTimeToLive := 200 * time.Millisecond + accessTokensTimeToLive := 1 * time.Second tests := []struct { desc string contextClosesIn time.Duration doError error expectedNumberDoCalls int + expectedCallRange []int // Optional: for tests that can have variable call counts }{ { desc: "update access token once", - contextClosesIn: 150 * time.Millisecond, + contextClosesIn: 700 * time.Millisecond, // Should allow one refresh expectedNumberDoCalls: 1, }, { desc: "update access token twice", - contextClosesIn: 250 * time.Millisecond, + contextClosesIn: 1300 * time.Millisecond, // Should allow two refreshes expectedNumberDoCalls: 2, }, { @@ -61,13 +63,13 @@ func TestContinuousRefreshToken(t *testing.T) { }, { desc: "refresh token fails - non-API error", - contextClosesIn: 250 * time.Millisecond, + contextClosesIn: 700 * time.Millisecond, doError: fmt.Errorf("something went wrong"), expectedNumberDoCalls: 1, }, { desc: "refresh token fails - API non-5xx error", - contextClosesIn: 250 * time.Millisecond, + contextClosesIn: 700 * time.Millisecond, doError: &oapierror.GenericOpenAPIError{ StatusCode: http.StatusBadRequest, }, @@ -75,11 +77,12 @@ func TestContinuousRefreshToken(t *testing.T) { }, { desc: "refresh token fails - API 5xx error", - contextClosesIn: 200 * time.Millisecond, + contextClosesIn: 800 * time.Millisecond, doError: &oapierror.GenericOpenAPIError{ StatusCode: http.StatusInternalServerError, }, expectedNumberDoCalls: 3, + expectedCallRange: []int{3, 4}, // Allow 3 or 4 calls due to timing race condition }, } @@ -101,19 +104,16 @@ func TestContinuousRefreshToken(t *testing.T) { numberDoCalls := 0 mockDo := func(_ *http.Request) (resp *http.Response, err error) { - numberDoCalls++ - + numberDoCalls++ // count refresh attempts if tt.doError != nil { return nil, tt.doError } - newAccessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokensTimeToLive)), }).SignedString([]byte("test")) if err != nil { t.Fatalf("Do call: failed to create access token: %v", err) } - responseBodyStruct := TokenResponseBody{ AccessToken: newAccessToken, RefreshToken: refreshToken, @@ -133,19 +133,34 @@ func TestContinuousRefreshToken(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, tt.contextClosesIn) defer cancel() - keyFlow := &KeyFlow{ - config: &KeyFlowConfig{ - BackgroundTokenRefreshContext: ctx, - }, - authClient: &http.Client{ + keyFlow := &KeyFlow{} + privateKeyBytes, err := generatePrivateKey() + if err != nil { + t.Fatalf("Error generating private key: %s", err) + } + keyFlowConfig := &KeyFlowConfig{ + ServiceAccountKey: fixtureServiceAccountKey(), + PrivateKey: string(privateKeyBytes), + AuthHTTPClient: &http.Client{ Transport: mockTransportFn{mockDo}, }, - token: &TokenResponseBody{ - AccessToken: accessToken, - RefreshToken: refreshToken, - }, + HTTPTransport: mockTransportFn{mockDo}, + BackgroundTokenRefreshContext: nil, + } + err = keyFlow.Init(keyFlowConfig) + if err != nil { + t.Fatalf("failed to initialize key flow: %v", err) } + // Set the token after initialization + err = keyFlow.SetToken(accessToken, refreshToken) + if err != nil { + t.Fatalf("failed to set token: %v", err) + } + + // Set the context for continuous refresh + keyFlow.config.BackgroundTokenRefreshContext = ctx + refresher := &continuousTokenRefresher{ keyFlow: keyFlow, timeStartBeforeTokenExpiration: timeStartBeforeTokenExpiration, @@ -157,7 +172,13 @@ func TestContinuousRefreshToken(t *testing.T) { if err == nil { t.Fatalf("routine finished with non-nil error") } - if numberDoCalls != tt.expectedNumberDoCalls { + + // Check if we have a range of expected calls (for timing-sensitive tests) + if tt.expectedCallRange != nil { + if !contains(tt.expectedCallRange, numberDoCalls) { + t.Fatalf("expected %v calls to API to refresh token, got %d", tt.expectedCallRange, numberDoCalls) + } + } else if numberDoCalls != tt.expectedNumberDoCalls { t.Fatalf("expected %d calls to API to refresh token, got %d", tt.expectedNumberDoCalls, numberDoCalls) } }) @@ -194,7 +215,7 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) { // The access token at the start accessTokenFirst, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(100 * time.Millisecond)), + ExpiresAt: jwt.NewNumericDate(time.Now().Add(10 * time.Second)), }).SignedString([]byte("token-first")) if err != nil { t.Fatalf("failed to create first access token: %v", err) @@ -225,60 +246,98 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) { ctx, cancel := context.WithCancel(ctx) defer cancel() // This cancels the refresher goroutine + // Extract host from tokenAPI constant for consistency + tokenURL, _ := url.Parse(tokenAPI) + tokenHost := tokenURL.Host + // The Do() routine, that both the keyFlow and continuousRefreshToken() use to make their requests // The bools are used to make sure only one request goes through on each test phase doTestPhase1RequestDone := false doTestPhase2RequestDone := false doTestPhase4RequestDone := false mockDo := func(req *http.Request) (resp *http.Response, err error) { - switch currentTestPhase { - default: - t.Fatalf("Do call: unexpected request during test phase %d", currentTestPhase) - return nil, nil - case 1: // Call by continuousRefreshToken() - if doTestPhase1RequestDone { - t.Fatalf("Do call: multiple requests during test phase 1") - } - doTestPhase1RequestDone = true + // Handle auth requests (token refresh) + if req.URL.Host == tokenHost { + switch currentTestPhase { + default: + // After phase 1, allow additional auth requests but don't fail the test + // This handles the continuous nature of the refresh routine + if currentTestPhase > 1 { + // Return a valid response for any additional auth requests + newAccessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), + }).SignedString([]byte("additional-token")) + if err != nil { + t.Fatalf("Do call: failed to create additional access token: %v", err) + } + responseBodyStruct := TokenResponseBody{ + AccessToken: newAccessToken, + RefreshToken: refreshToken, + } + responseBody, err := json.Marshal(responseBodyStruct) + if err != nil { + t.Fatalf("Do call: failed to marshal additional response: %v", err) + } + response := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(responseBody)), + } + return response, nil + } + t.Fatalf("Do call: unexpected request during test phase %d", currentTestPhase) + return nil, nil + case 1: // Call by continuousRefreshToken() + if doTestPhase1RequestDone { + t.Fatalf("Do call: multiple requests during test phase 1") + } + doTestPhase1RequestDone = true - currentTestPhase = 2 - chanBlockContinuousRefreshToken <- true + currentTestPhase = 2 + chanBlockContinuousRefreshToken <- true - // Wait until continuousRefreshToken() is to be unblocked - <-chanUnblockContinuousRefreshToken + // Wait until continuousRefreshToken() is to be unblocked + <-chanUnblockContinuousRefreshToken - if currentTestPhase != 3 { - t.Fatalf("Do call: after unlocking refreshToken(), expected test phase to be 3, got %d", currentTestPhase) - } + if currentTestPhase != 3 { + t.Fatalf("Do call: after unlocking refreshToken(), expected test phase to be 3, got %d", currentTestPhase) + } - // Check required fields are passed - err = req.ParseForm() - if err != nil { - t.Fatalf("Do call: failed to parse body form: %v", err) - } - reqGrantType := req.Form.Get("grant_type") - if reqGrantType != "refresh_token" { - t.Fatalf("Do call: failed request to refresh token: call to refresh access expected to have grant type as %q, found %q instead", "refresh_token", reqGrantType) - } - reqRefreshToken := req.Form.Get("refresh_token") - if reqRefreshToken != refreshToken { - t.Fatalf("Do call: failed request to refresh token: call to refresh access token did not have the expected refresh token set") - } + // Check required fields are passed + err = req.ParseForm() + if err != nil { + t.Fatalf("Do call: failed to parse body form: %v", err) + } + reqGrantType := req.Form.Get("grant_type") + if reqGrantType != "refresh_token" { + t.Fatalf("Do call: failed request to refresh token: call to refresh access expected to have grant type as %q, found %q instead", "refresh_token", reqGrantType) + } + reqRefreshToken := req.Form.Get("refresh_token") + if reqRefreshToken != refreshToken { + t.Fatalf("Do call: failed request to refresh token: call to refresh access token did not have the expected refresh token set") + } - // Return response with accessTokenSecond - responseBodyStruct := TokenResponseBody{ - AccessToken: accessTokenSecond, - RefreshToken: refreshToken, - } - responseBody, err := json.Marshal(responseBodyStruct) - if err != nil { - t.Fatalf("Do call: failed request to refresh token: marshal access token response: %v", err) - } - response := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(responseBody)), + // Return response with accessTokenSecond + responseBodyStruct := TokenResponseBody{ + AccessToken: accessTokenSecond, + RefreshToken: refreshToken, + } + responseBody, err := json.Marshal(responseBodyStruct) + if err != nil { + t.Fatalf("Do call: failed request to refresh token: marshal access token response: %v", err) + } + response := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(responseBody)), + } + return response, nil } - return response, nil + } + + // Handle regular HTTP requests + switch currentTestPhase { + default: + t.Fatalf("Do call: unexpected request during test phase %d", currentTestPhase) + return nil, nil case 2: // Call by tokenFlow, first request if doTestPhase2RequestDone { t.Fatalf("Do call: multiple requests during test phase 2") @@ -292,8 +351,9 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) { t.Fatalf("Do call: first request expected to have host %q, found %q", expectedHost, host) } authHeader := req.Header.Get("Authorization") - if authHeader != fmt.Sprintf("Bearer %s", accessTokenFirst) { - t.Fatalf("Do call: first request didn't carry first access token") + expectedAuthHeader := fmt.Sprintf("Bearer %s", accessTokenFirst) + if authHeader != expectedAuthHeader { + t.Fatalf("Do call: first request didn't carry first access token. Expected: %s, Got: %s", expectedAuthHeader, authHeader) } // Return empty response @@ -328,23 +388,49 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) { } } - keyFlow := &KeyFlow{ - config: &KeyFlowConfig{ - BackgroundTokenRefreshContext: ctx, - }, - authClient: &http.Client{ + keyFlow := &KeyFlow{} + privateKeyBytes, err := generatePrivateKey() + if err != nil { + t.Fatalf("Error generating private key: %s", err) + } + keyFlowConfig := &KeyFlowConfig{ + ServiceAccountKey: fixtureServiceAccountKey(), + PrivateKey: string(privateKeyBytes), + AuthHTTPClient: &http.Client{ Transport: mockTransportFn{mockDo}, }, - rt: mockTransportFn{mockDo}, - token: &TokenResponseBody{ - AccessToken: accessTokenFirst, - RefreshToken: refreshToken, - }, + HTTPTransport: mockTransportFn{mockDo}, // Use same mock for regular requests + // Don't start continuous refresh automatically + BackgroundTokenRefreshContext: nil, + } + err = keyFlow.Init(keyFlowConfig) + if err != nil { + t.Fatalf("failed to initialize key flow: %v", err) + } + + // Set the token after initialization + err = keyFlow.SetToken(accessTokenFirst, refreshToken) + if err != nil { + t.Fatalf("failed to set token: %v", err) + } + + // Set the context for continuous refresh + keyFlow.config.BackgroundTokenRefreshContext = ctx + + // Create a custom refresher with shorter timing for the test + refresher := &continuousTokenRefresher{ + keyFlow: keyFlow, + timeStartBeforeTokenExpiration: 9 * time.Second, // Start 9 seconds before expiration + timeBetweenContextCheck: 5 * time.Millisecond, + timeBetweenTries: 40 * time.Millisecond, } // TEST START currentTestPhase = 1 - go continuousRefreshToken(keyFlow) + // Ignore returned error as expected in test + go func() { + _ = refresher.continuousRefreshToken() + }() // Wait until continuousRefreshToken() is blocked <-chanBlockContinuousRefreshToken @@ -389,3 +475,12 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) { t.Fatalf("Second request body failed to close: %v", err) } } + +func contains(arr []int, val int) bool { + for _, v := range arr { + if v == val { + return true + } + } + return false +} diff --git a/core/clients/key_flow_test.go b/core/clients/key_flow_test.go index 78dc43a3b..9803f24ee 100644 --- a/core/clients/key_flow_test.go +++ b/core/clients/key_flow_test.go @@ -105,25 +105,25 @@ func TestKeyFlowInit(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := &KeyFlow{} - cfg := &KeyFlowConfig{} + keyFlow := &KeyFlow{} + keyFlowConfig := &KeyFlowConfig{} t.Setenv("STACKIT_SERVICE_ACCOUNT_KEY", "") if tt.genPrivateKey { privateKeyBytes, err := generatePrivateKey() if err != nil { t.Fatalf("Error generating private key: %s", err) } - cfg.PrivateKey = string(privateKeyBytes) + keyFlowConfig.PrivateKey = string(privateKeyBytes) } if tt.invalidPrivateKey { - cfg.PrivateKey = "invalid_key" + keyFlowConfig.PrivateKey = "invalid_key" } - cfg.ServiceAccountKey = tt.serviceAccountKey - if err := c.Init(cfg); (err != nil) != tt.wantErr { + keyFlowConfig.ServiceAccountKey = tt.serviceAccountKey + if err := keyFlow.Init(keyFlowConfig); (err != nil) != tt.wantErr { t.Errorf("KeyFlow.Init() error = %v, wantErr %v", err, tt.wantErr) } - if c.config == nil { + if keyFlow.config == nil { t.Error("config is nil") } }) @@ -167,8 +167,8 @@ func TestSetToken(t *testing.T) { } } - c := &KeyFlow{} - err = c.SetToken(accessToken, tt.refreshToken) + keyFlow := &KeyFlow{} + err = keyFlow.SetToken(accessToken, tt.refreshToken) if (err != nil) != tt.wantErr { t.Errorf("KeyFlow.SetToken() error = %v, wantErr %v", err, tt.wantErr) @@ -181,8 +181,8 @@ func TestSetToken(t *testing.T) { Scope: defaultScope, TokenType: defaultTokenType, } - if !cmp.Equal(expectedKeyFlowToken, c.token) { - t.Errorf("The returned result is wrong. Expected %+v, got %+v", expectedKeyFlowToken, c.token) + if !cmp.Equal(expectedKeyFlowToken, keyFlow.token) { + t.Errorf("The returned result is wrong. Expected %+v, got %+v", expectedKeyFlowToken, keyFlow.token) } } }) @@ -282,17 +282,27 @@ func TestRequestToken(t *testing.T) { for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { - c := &KeyFlow{ - authClient: &http.Client{ + keyFlow := &KeyFlow{} + privateKeyBytes, err := generatePrivateKey() + if err != nil { + t.Fatalf("Error generating private key: %s", err) + } + keyFlowConfig := &KeyFlowConfig{ + AuthHTTPClient: &http.Client{ Transport: mockTransportFn{func(_ *http.Request) (*http.Response, error) { return tt.mockResponse, tt.mockError }}, }, - config: &KeyFlowConfig{}, - rt: http.DefaultTransport, + ServiceAccountKey: fixtureServiceAccountKey(), + PrivateKey: string(privateKeyBytes), + HTTPTransport: http.DefaultTransport, + } + err = keyFlow.Init(keyFlowConfig) + if err != nil { + t.Fatalf("failed to initialize key flow: %v", err) } - res, err := c.requestToken(tt.grant, tt.assertion) + res, err := keyFlow.requestToken(tt.grant, tt.assertion) defer func() { if res != nil { tempErr := res.Body.Close() @@ -324,14 +334,12 @@ func TestKeyFlow_Do(t *testing.T) { tests := []struct { name string - keyFlow *KeyFlow handlerFn func(tb testing.TB) http.HandlerFunc want int wantErr bool }{ { - name: "success", - keyFlow: &KeyFlow{rt: http.DefaultTransport, config: &KeyFlowConfig{}}, + name: "success", handlerFn: func(tb testing.TB) http.HandlerFunc { tb.Helper() @@ -349,8 +357,7 @@ func TestKeyFlow_Do(t *testing.T) { wantErr: false, }, { - name: "success with code 500", - keyFlow: &KeyFlow{rt: http.DefaultTransport, config: &KeyFlowConfig{}}, + name: "success with code 500", handlerFn: func(_ testing.TB) http.HandlerFunc { return func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/html") @@ -363,16 +370,6 @@ func TestKeyFlow_Do(t *testing.T) { }, { name: "success with custom transport", - keyFlow: &KeyFlow{ - rt: mockTransportFn{ - fn: func(req *http.Request) (*http.Response, error) { - req.Header.Set("User-Agent", "custom_transport") - - return http.DefaultTransport.RoundTrip(req) - }, - }, - config: &KeyFlowConfig{}, - }, handlerFn: func(tb testing.TB) http.HandlerFunc { tb.Helper() @@ -391,14 +388,6 @@ func TestKeyFlow_Do(t *testing.T) { }, { name: "fail with custom proxy", - keyFlow: &KeyFlow{ - rt: &http.Transport{ - Proxy: func(_ *http.Request) (*url.URL, error) { - return nil, fmt.Errorf("proxy error") - }, - }, - config: &KeyFlowConfig{}, - }, handlerFn: func(testing.TB) http.HandlerFunc { return func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") @@ -421,37 +410,59 @@ func TestKeyFlow_Do(t *testing.T) { t.Errorf("no error is expected, but got %v", err) } - tt.keyFlow.config.ServiceAccountKey = fixtureServiceAccountKey() - tt.keyFlow.config.PrivateKey = string(privateKeyBytes) - tt.keyFlow.config.BackgroundTokenRefreshContext = ctx - tt.keyFlow.authClient = &http.Client{ - Transport: mockTransportFn{ - fn: func(_ *http.Request) (*http.Response, error) { - res := httptest.NewRecorder() - res.WriteHeader(http.StatusOK) - res.Header().Set("Content-Type", "application/json") - - token := &TokenResponseBody{ - AccessToken: testBearerToken, - ExpiresIn: 2147483647, - RefreshToken: testBearerToken, - TokenType: "Bearer", + keyFlow := &KeyFlow{} + keyFlowConfig := &KeyFlowConfig{ + ServiceAccountKey: fixtureServiceAccountKey(), + PrivateKey: string(privateKeyBytes), + BackgroundTokenRefreshContext: ctx, + HTTPTransport: func() http.RoundTripper { + switch tt.name { + case "success with custom transport": + return mockTransportFn{ + fn: func(req *http.Request) (*http.Response, error) { + req.Header.Set("User-Agent", "custom_transport") + return http.DefaultTransport.RoundTrip(req) + }, } - - if err := json.NewEncoder(res.Body).Encode(token); err != nil { - t.Logf("no error is expected, but got %v", err) + case "fail with custom proxy": + return &http.Transport{ + Proxy: func(_ *http.Request) (*url.URL, error) { + return nil, fmt.Errorf("proxy error") + }, } - - return res.Result(), nil + default: + return http.DefaultTransport + } + }(), + AuthHTTPClient: &http.Client{ + Transport: mockTransportFn{ + fn: func(_ *http.Request) (*http.Response, error) { + res := httptest.NewRecorder() + res.WriteHeader(http.StatusOK) + res.Header().Set("Content-Type", "application/json") + + token := &TokenResponseBody{ + AccessToken: testBearerToken, + ExpiresIn: 2147483647, + RefreshToken: testBearerToken, + TokenType: "Bearer", + } + + if err := json.NewEncoder(res.Body).Encode(token); err != nil { + t.Logf("no error is expected, but got %v", err) + } + + return res.Result(), nil + }, }, }, } - - if err := tt.keyFlow.validate(); err != nil { - t.Errorf("no error is expected, but got %v", err) + err = keyFlow.Init(keyFlowConfig) + if err != nil { + t.Fatalf("failed to initialize key flow: %v", err) } - go continuousRefreshToken(tt.keyFlow) + go continuousRefreshToken(keyFlow) tokenCtx, tokenCancel := context.WithTimeout(context.Background(), 1*time.Second) @@ -461,14 +472,14 @@ func TestKeyFlow_Do(t *testing.T) { case <-tokenCtx.Done(): t.Error(tokenCtx.Err()) case <-time.After(50 * time.Millisecond): - tt.keyFlow.tokenMutex.RLock() - if tt.keyFlow.token != nil { - tt.keyFlow.tokenMutex.RUnlock() + keyFlow.tokenMutex.RLock() + if keyFlow.token != nil { + keyFlow.tokenMutex.RUnlock() tokenCancel() break token } - tt.keyFlow.tokenMutex.RUnlock() + keyFlow.tokenMutex.RUnlock() } } @@ -486,7 +497,7 @@ func TestKeyFlow_Do(t *testing.T) { } httpClient := &http.Client{ - Transport: tt.keyFlow, + Transport: keyFlow, } res, err := httpClient.Do(req)