diff --git a/go.mod b/go.mod index 661778fc3..3276d7451 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/gorilla/mux v1.8.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect diff --git a/go.sum b/go.sum index e422a548c..528bd3796 100644 --- a/go.sum +++ b/go.sum @@ -88,6 +88,8 @@ github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3A github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= diff --git a/pkg/github/code_scanning_test.go b/pkg/github/code_scanning_test.go index 13e89fc30..4f9a00145 100644 --- a/pkg/github/code_scanning_test.go +++ b/pkg/github/code_scanning_test.go @@ -10,7 +10,6 @@ import ( "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" "github.com/google/jsonschema-go/jsonschema" - "github.com/migueleliasweb/go-github-mock/src/mock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -50,12 +49,9 @@ func Test_GetCodeScanningAlert(t *testing.T) { }{ { name: "successful alert fetch", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetReposCodeScanningAlertsByOwnerByRepoByAlertNumber, - mockAlert, - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposCodeScanningAlertsByOwnerByRepoByAlertNumber: mockResponse(t, http.StatusOK, mockAlert), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -66,15 +62,12 @@ func Test_GetCodeScanningAlert(t *testing.T) { }, { name: "alert fetch fails", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposCodeScanningAlertsByOwnerByRepoByAlertNumber, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNotFound) - _, _ = w.Write([]byte(`{"message": "Not Found"}`)) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposCodeScanningAlertsByOwnerByRepoByAlertNumber: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -171,19 +164,16 @@ func Test_ListCodeScanningAlerts(t *testing.T) { }{ { name: "successful alerts listing", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposCodeScanningAlertsByOwnerByRepo, - expectQueryParams(t, map[string]string{ - "ref": "main", - "state": "open", - "severity": "high", - "tool_name": "codeql", - }).andThen( - mockResponse(t, http.StatusOK, mockAlerts), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposCodeScanningAlertsByOwnerByRepo: expectQueryParams(t, map[string]string{ + "ref": "main", + "state": "open", + "severity": "high", + "tool_name": "codeql", + }).andThen( + mockResponse(t, http.StatusOK, mockAlerts), ), - ), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -197,15 +187,12 @@ func Test_ListCodeScanningAlerts(t *testing.T) { }, { name: "alerts listing fails", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposCodeScanningAlertsByOwnerByRepo, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusUnauthorized) - _, _ = w.Write([]byte(`{"message": "Unauthorized access"}`)) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposCodeScanningAlertsByOwnerByRepo: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"message": "Unauthorized access"}`)) + }), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", diff --git a/pkg/github/git_test.go b/pkg/github/git_test.go index 66cbccd6e..7a08326bf 100644 --- a/pkg/github/git_test.go +++ b/pkg/github/git_test.go @@ -11,7 +11,6 @@ import ( "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" "github.com/google/jsonschema-go/jsonschema" - "github.com/migueleliasweb/go-github-mock/src/mock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -71,16 +70,10 @@ func Test_GetRepositoryTree(t *testing.T) { }{ { name: "successfully get repository tree", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposByOwnerByRepo, - mockResponse(t, http.StatusOK, mockRepo), - ), - mock.WithRequestMatchHandler( - mock.GetReposGitTreesByOwnerByRepoByTreeSha, - mockResponse(t, http.StatusOK, mockTree), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposByOwnerByRepo: mockResponse(t, http.StatusOK, mockRepo), + GetReposGitTreesByOwnerByRepoByTree: mockResponse(t, http.StatusOK, mockTree), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -88,16 +81,10 @@ func Test_GetRepositoryTree(t *testing.T) { }, { name: "successfully get repository tree with path filter", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposByOwnerByRepo, - mockResponse(t, http.StatusOK, mockRepo), - ), - mock.WithRequestMatchHandler( - mock.GetReposGitTreesByOwnerByRepoByTreeSha, - mockResponse(t, http.StatusOK, mockTree), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposByOwnerByRepo: mockResponse(t, http.StatusOK, mockRepo), + GetReposGitTreesByOwnerByRepoByTree: mockResponse(t, http.StatusOK, mockTree), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -106,15 +93,12 @@ func Test_GetRepositoryTree(t *testing.T) { }, { name: "repository not found", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposByOwnerByRepo, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNotFound) - _, _ = w.Write([]byte(`{"message": "Not Found"}`)) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposByOwnerByRepo: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "nonexistent", @@ -124,19 +108,13 @@ func Test_GetRepositoryTree(t *testing.T) { }, { name: "tree not found", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposByOwnerByRepo, - mockResponse(t, http.StatusOK, mockRepo), - ), - mock.WithRequestMatchHandler( - mock.GetReposGitTreesByOwnerByRepoByTreeSha, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNotFound) - _, _ = w.Write([]byte(`{"message": "Not Found"}`)) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposByOwnerByRepo: mockResponse(t, http.StatusOK, mockRepo), + GetReposGitTreesByOwnerByRepoByTree: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", diff --git a/pkg/github/helper_test.go b/pkg/github/helper_test.go index 9c55ba841..871a13fe2 100644 --- a/pkg/github/helper_test.go +++ b/pkg/github/helper_test.go @@ -1,15 +1,34 @@ package github import ( + "bytes" "encoding/json" + "io" "net/http" + "net/url" + "strings" "testing" "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) +// GitHub API endpoint patterns for testing +// These constants define the URL patterns used in HTTP mocking for tests +const ( + // Repository endpoints + GetReposByOwnerByRepo = "GET /repos/{owner}/{repo}" + + // Git endpoints + GetReposGitTreesByOwnerByRepoByTree = "GET /repos/{owner}/{repo}/git/trees/{tree}" + + // Code scanning endpoints + GetReposCodeScanningAlertsByOwnerByRepo = "GET /repos/{owner}/{repo}/code-scanning/alerts" + GetReposCodeScanningAlertsByOwnerByRepoByAlertNumber = "GET /repos/{owner}/{repo}/code-scanning/alerts/{alert_number}" +) + type expectations struct { path string queryParams map[string]string @@ -272,3 +291,211 @@ func getResourceResult(t *testing.T, result *mcp.CallToolResult) *mcp.ResourceCo require.IsType(t, &mcp.ResourceContents{}, resource.Resource) return resource.Resource } + +// MockRoundTripper is a mock HTTP transport using testify/mock +type MockRoundTripper struct { + mock.Mock + handlers map[string]http.HandlerFunc +} + +// NewMockRoundTripper creates a new mock round tripper +func NewMockRoundTripper() *MockRoundTripper { + return &MockRoundTripper{ + handlers: make(map[string]http.HandlerFunc), + } +} + +// RoundTrip implements the http.RoundTripper interface +func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + // Normalize the request path and method for matching + key := req.Method + " " + req.URL.Path + + // Check if we have a specific handler for this request + if handler, ok := m.handlers[key]; ok { + // Use httptest.ResponseRecorder to capture the handler's response + recorder := &responseRecorder{ + header: make(http.Header), + body: &bytes.Buffer{}, + } + handler(recorder, req) + + return &http.Response{ + StatusCode: recorder.statusCode, + Header: recorder.header, + Body: io.NopCloser(bytes.NewReader(recorder.body.Bytes())), + Request: req, + }, nil + } + + // Fall back to mock.Mock assertions if defined + args := m.Called(req) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*http.Response), args.Error(1) +} + +// On registers an expectation using testify/mock +func (m *MockRoundTripper) OnRequest(method, path string, handler http.HandlerFunc) *MockRoundTripper { + key := method + " " + path + m.handlers[key] = handler + return m +} + +// NewMockHTTPClient creates an HTTP client with a mock transport +func NewMockHTTPClient() (*http.Client, *MockRoundTripper) { + transport := NewMockRoundTripper() + client := &http.Client{Transport: transport} + return client, transport +} + +// responseRecorder is a simple response recorder for the mock transport +type responseRecorder struct { + statusCode int + header http.Header + body *bytes.Buffer +} + +func (r *responseRecorder) Header() http.Header { + return r.header +} + +func (r *responseRecorder) Write(data []byte) (int, error) { + if r.statusCode == 0 { + r.statusCode = http.StatusOK + } + return r.body.Write(data) +} + +func (r *responseRecorder) WriteHeader(statusCode int) { + r.statusCode = statusCode +} + +// matchPath checks if a request path matches a pattern (supports simple wildcards) +func matchPath(pattern, path string) bool { + // Simple exact match for now + if pattern == path { + return true + } + + // Support for path parameters like /repos/{owner}/{repo}/issues/{issue_number} + patternParts := strings.Split(strings.Trim(pattern, "/"), "/") + pathParts := strings.Split(strings.Trim(path, "/"), "/") + + if len(patternParts) != len(pathParts) { + return false + } + + for i := range patternParts { + // Check if this is a path parameter (enclosed in {}) + if strings.HasPrefix(patternParts[i], "{") && strings.HasSuffix(patternParts[i], "}") { + continue // Path parameters match anything + } + if patternParts[i] != pathParts[i] { + return false + } + } + + return true +} + +// executeHandler executes an HTTP handler and returns the response +func executeHandler(handler http.HandlerFunc, req *http.Request) *http.Response { + recorder := &responseRecorder{ + header: make(http.Header), + body: &bytes.Buffer{}, + } + handler(recorder, req) + + return &http.Response{ + StatusCode: recorder.statusCode, + Header: recorder.header, + Body: io.NopCloser(bytes.NewReader(recorder.body.Bytes())), + Request: req, + } +} + +// MockHTTPClientWithHandler creates an HTTP client with a single handler function +func MockHTTPClientWithHandler(handler http.HandlerFunc) *http.Client { + handlers := map[string]http.HandlerFunc{ + "": handler, // Empty key acts as catch-all + } + return MockHTTPClientWithHandlers(handlers) +} + +// MockHTTPClientWithHandlers creates an HTTP client with multiple handlers for different paths +func MockHTTPClientWithHandlers(handlers map[string]http.HandlerFunc) *http.Client { + transport := &multiHandlerTransport{handlers: handlers} + return &http.Client{Transport: transport} +} + +type multiHandlerTransport struct { + handlers map[string]http.HandlerFunc +} + +func (m *multiHandlerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Check for catch-all handler + if handler, ok := m.handlers[""]; ok { + return executeHandler(handler, req), nil + } + + // Try to find a handler for this request + key := req.Method + " " + req.URL.Path + + // First try exact match + if handler, ok := m.handlers[key]; ok { + return executeHandler(handler, req), nil + } + + // Then try pattern matching + for pattern, handler := range m.handlers { + if pattern == "" { + continue // Skip catch-all + } + parts := strings.SplitN(pattern, " ", 2) + if len(parts) == 2 { + method, pathPattern := parts[0], parts[1] + if req.Method == method && matchPath(pathPattern, req.URL.Path) { + return executeHandler(handler, req), nil + } + } + } + + // No handler found + return &http.Response{ + StatusCode: http.StatusNotFound, + Body: io.NopCloser(bytes.NewReader([]byte("not found"))), + Request: req, + }, nil +} + +// extractPathParams extracts path parameters from a URL path given a pattern +func extractPathParams(pattern, path string) map[string]string { + params := make(map[string]string) + patternParts := strings.Split(strings.Trim(pattern, "/"), "/") + pathParts := strings.Split(strings.Trim(path, "/"), "/") + + if len(patternParts) != len(pathParts) { + return params + } + + for i := range patternParts { + if strings.HasPrefix(patternParts[i], "{") && strings.HasSuffix(patternParts[i], "}") { + paramName := strings.Trim(patternParts[i], "{}") + params[paramName] = pathParts[i] + } + } + + return params +} + +// ParseRequestPath is a helper to extract path parameters +func ParseRequestPath(t *testing.T, req *http.Request, pattern string) url.Values { + t.Helper() + params := extractPathParams(pattern, req.URL.Path) + values := url.Values{} + for k, v := range params { + values.Set(k, v) + } + return values +}