diff --git a/.github/workflows/conformance.yml b/.github/workflows/conformance.yml new file mode 100644 index 000000000..92524ea17 --- /dev/null +++ b/.github/workflows/conformance.yml @@ -0,0 +1,69 @@ +name: Conformance Test + +on: + pull_request: + +permissions: + contents: read + +jobs: + conformance: + runs-on: ubuntu-latest + + steps: + - name: Check out code + uses: actions/checkout@v6 + with: + # Fetch full history to access merge-base + fetch-depth: 0 + + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version-file: "go.mod" + + - name: Download dependencies + run: go mod download + + - name: Run conformance test + id: conformance + run: | + # Run conformance test, capture stdout for summary + script/conformance-test > conformance-summary.txt 2>&1 || true + + # Output the summary + cat conformance-summary.txt + + # Check result + if grep -q "RESULT: ALL TESTS PASSED" conformance-summary.txt; then + echo "status=passed" >> $GITHUB_OUTPUT + else + echo "status=differences" >> $GITHUB_OUTPUT + fi + + - name: Generate Job Summary + run: | + # Add the full markdown report to the job summary + echo "# MCP Server Conformance Report" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "Comparing PR branch against merge-base with \`origin/main\`" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + + # Extract and append the report content (skip the header since we added our own) + tail -n +5 conformance-report/CONFORMANCE_REPORT.md >> $GITHUB_STEP_SUMMARY + + echo "" >> $GITHUB_STEP_SUMMARY + echo "---" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + + # Add interpretation note + if [ "${{ steps.conformance.outputs.status }}" = "passed" ]; then + echo "✅ **All conformance tests passed** - No behavioral differences detected." >> $GITHUB_STEP_SUMMARY + else + echo "⚠️ **Differences detected** - Review the diffs above to ensure changes are intentional." >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "Common expected differences:" >> $GITHUB_STEP_SUMMARY + echo "- New tools/toolsets added" >> $GITHUB_STEP_SUMMARY + echo "- Tool descriptions updated" >> $GITHUB_STEP_SUMMARY + echo "- Capability changes (intentional improvements)" >> $GITHUB_STEP_SUMMARY + fi diff --git a/.gitignore b/.gitignore index b018fafac..5684108b0 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,5 @@ bin/ # binary github-mcp-server -.history \ No newline at end of file +.history +conformance-report/ diff --git a/README.md b/README.md index 117bacacd..e8737cd25 100644 --- a/README.md +++ b/README.md @@ -718,6 +718,11 @@ The following sets of tools are available: - `owner`: Repository owner (string, required) - `repo`: Repository name (string, required) +- **get_label** - Get a specific label from a repository. + - `name`: Label name. (string, required) + - `owner`: Repository owner (username or organization name) (string, required) + - `repo`: Repository name (string, required) + - **issue_read** - Get issue details - `issue_number`: The number of the issue (number, required) - `method`: The read operation to perform on a single issue. diff --git a/cmd/github-mcp-server/generate_docs.go b/cmd/github-mcp-server/generate_docs.go index ddfcd10ba..8760c3f32 100644 --- a/cmd/github-mcp-server/generate_docs.go +++ b/cmd/github-mcp-server/generate_docs.go @@ -8,7 +8,7 @@ import ( "strings" "github.com/github/github-mcp-server/pkg/github" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -50,8 +50,8 @@ func generateReadmeDocs(readmePath string) error { // Create translation helper t, _ := translations.TranslationHelper() - // Create toolset group - stateless, no dependencies needed for doc generation - r := github.NewRegistry(t) + // Build registry - stateless, no dependencies needed for doc generation + r := github.NewRegistry(t).Build() // Generate toolsets documentation toolsetsDoc := generateToolsetsDoc(r) @@ -104,7 +104,7 @@ func generateRemoteServerDocs(docsPath string) error { return os.WriteFile(docsPath, []byte(updatedContent), 0600) //#nosec G306 } -func generateToolsetsDoc(r *toolsets.Registry) string { +func generateToolsetsDoc(r *registry.Registry) string { var buf strings.Builder // Add table header and separator @@ -123,7 +123,7 @@ func generateToolsetsDoc(r *toolsets.Registry) string { return strings.TrimSuffix(buf.String(), "\n") } -func generateToolsDoc(r *toolsets.Registry) string { +func generateToolsDoc(r *registry.Registry) string { // AllTools() returns tools sorted by toolset ID then tool name. // We iterate once, grouping by toolset as we encounter them. tools := r.AllTools() @@ -133,7 +133,7 @@ func generateToolsDoc(r *toolsets.Registry) string { var buf strings.Builder var toolBuf strings.Builder - var currentToolsetID toolsets.ToolsetID + var currentToolsetID registry.ToolsetID firstSection := true writeSection := func() { @@ -299,8 +299,8 @@ func generateRemoteToolsetsDoc() string { // Create translation helper t, _ := translations.TranslationHelper() - // Create toolset group - stateless - r := github.NewRegistry(t) + // Build registry - stateless + r := github.NewRegistry(t).Build() // Generate table header buf.WriteString("| Name | Description | API URL | 1-Click Install (VS Code) | Read-only Link | 1-Click Read-only Install (VS Code) |\n") diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index ad9ebb190..e286930a2 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -178,7 +178,7 @@ func setupMCPClient(t *testing.T, options ...clientOption) *mcp.ClientSession { // so that there is a shared setup mechanism, but let's wait till we feel more friction. enabledToolsets := opts.enabledToolsets if enabledToolsets == nil { - enabledToolsets = github.NewRegistry(translations.NullTranslationHelper).DefaultToolsetIDs() + enabledToolsets = github.NewRegistry(translations.NullTranslationHelper).Build().DefaultToolsetIDs() } ghServer, err := ghmcp.NewMCPServer(ghmcp.MCPServerConfig{ diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index 67fcad4a7..e98637067 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -18,7 +18,7 @@ import ( "github.com/github/github-mcp-server/pkg/lockdown" mcplog "github.com/github/github-mcp-server/pkg/log" "github.com/github/github-mcp-server/pkg/raw" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" gogithub "github.com/google/go-github/v79/github" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -69,146 +69,183 @@ type MCPServerConfig struct { RepoAccessTTL *time.Duration } -func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { - apiHost, err := parseAPIHost(cfg.Host) - if err != nil { - return nil, fmt.Errorf("failed to parse API host: %w", err) - } +// githubClients holds all the GitHub API clients created for a server instance. +type githubClients struct { + rest *gogithub.Client + gql *githubv4.Client + gqlHTTP *http.Client // retained for middleware to modify transport + raw *raw.Client + repoAccess *lockdown.RepoAccessCache +} - // Construct our REST client +// createGitHubClients creates all the GitHub API clients needed by the server. +func createGitHubClients(cfg MCPServerConfig, apiHost apiHost) (*githubClients, error) { + // Construct REST client restClient := gogithub.NewClient(nil).WithAuthToken(cfg.Token) restClient.UserAgent = fmt.Sprintf("github-mcp-server/%s", cfg.Version) restClient.BaseURL = apiHost.baseRESTURL restClient.UploadURL = apiHost.uploadURL - // Construct our GraphQL client - // We're using NewEnterpriseClient here unconditionally as opposed to NewClient because we already - // did the necessary API host parsing so that github.com will return the correct URL anyway. + // Construct GraphQL client + // We use NewEnterpriseClient unconditionally since we already parsed the API host gqlHTTPClient := &http.Client{ Transport: &bearerAuthTransport{ transport: http.DefaultTransport, token: cfg.Token, }, - } // We're going to wrap the Transport later in beforeInit - gqlClient := githubv4.NewEnterpriseClient(apiHost.graphqlURL.String(), gqlHTTPClient) - repoAccessOpts := []lockdown.RepoAccessOption{} - if cfg.RepoAccessTTL != nil { - repoAccessOpts = append(repoAccessOpts, lockdown.WithTTL(*cfg.RepoAccessTTL)) } + gqlClient := githubv4.NewEnterpriseClient(apiHost.graphqlURL.String(), gqlHTTPClient) + + // Create raw content client (shares REST client's HTTP transport) + rawClient := raw.NewClient(restClient, apiHost.rawURL) - repoAccessLogger := cfg.Logger.With("component", "lockdown") - repoAccessOpts = append(repoAccessOpts, lockdown.WithLogger(repoAccessLogger)) + // Set up repo access cache for lockdown mode var repoAccessCache *lockdown.RepoAccessCache if cfg.LockdownMode { - repoAccessCache = lockdown.GetInstance(gqlClient, repoAccessOpts...) + opts := []lockdown.RepoAccessOption{ + lockdown.WithLogger(cfg.Logger.With("component", "lockdown")), + } + if cfg.RepoAccessTTL != nil { + opts = append(opts, lockdown.WithTTL(*cfg.RepoAccessTTL)) + } + repoAccessCache = lockdown.GetInstance(gqlClient, opts...) } - // Determine enabled toolsets based on configuration: - // - nil means "use defaults" (unless dynamic mode without explicit toolsets) - // - empty slice means "no toolsets" (for dynamic mode to enable on demand) - // - explicit list means "use these toolsets" - var enabledToolsets []string - if cfg.EnabledToolsets != nil { - enabledToolsets = cfg.EnabledToolsets - } else if cfg.DynamicToolsets { - // Dynamic mode with no toolsets specified: start with no toolsets enabled - // so users can enable them on demand via the dynamic tools - enabledToolsets = []string{} - } - // else: enabledToolsets stays nil, which means "use defaults" in WithToolsets + return &githubClients{ + rest: restClient, + gql: gqlClient, + gqlHTTP: gqlHTTPClient, + raw: rawClient, + repoAccess: repoAccessCache, + }, nil +} - // Generate instructions based on enabled toolsets - instructions := github.GenerateInstructions(enabledToolsets) +// resolveEnabledToolsets determines which toolsets should be enabled based on config. +// Returns nil for "use defaults", empty slice for "none", or explicit list. +func resolveEnabledToolsets(cfg MCPServerConfig) []string { + enabledToolsets := cfg.EnabledToolsets - getClient := func(_ context.Context) (*gogithub.Client, error) { - return restClient, nil // closing over client + // In dynamic mode, remove "all" and "default" since users enable toolsets on demand + if cfg.DynamicToolsets && enabledToolsets != nil { + enabledToolsets = github.RemoveToolset(enabledToolsets, string(github.ToolsetMetadataAll.ID)) + enabledToolsets = github.RemoveToolset(enabledToolsets, string(github.ToolsetMetadataDefault.ID)) } - getGQLClient := func(_ context.Context) (*githubv4.Client, error) { - return gqlClient, nil // closing over client + if enabledToolsets != nil { + return enabledToolsets + } + if cfg.DynamicToolsets { + // Dynamic mode with no toolsets specified: start empty so users enable on demand + return []string{} + } + if len(cfg.EnabledTools) > 0 { + // When specific tools are requested but no toolsets, don't use default toolsets + // This matches the original behavior: --tools=X alone registers only X + return []string{} } + // nil means "use defaults" in WithToolsets + return nil +} - getRawClient := func(ctx context.Context) (*raw.Client, error) { - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - return raw.NewClient(client, apiHost.rawURL), nil // closing over client +func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { + apiHost, err := parseAPIHost(cfg.Host) + if err != nil { + return nil, fmt.Errorf("failed to parse API host: %w", err) } - ghServer := github.NewServer(cfg.Version, &mcp.ServerOptions{ - Instructions: instructions, - Logger: cfg.Logger, - CompletionHandler: github.CompletionsHandler(getClient), - }) + clients, err := createGitHubClients(cfg, apiHost) + if err != nil { + return nil, fmt.Errorf("failed to create GitHub clients: %w", err) + } - // Add middlewares - ghServer.AddReceivingMiddleware(addGitHubAPIErrorToContext) - ghServer.AddReceivingMiddleware(addUserAgentsMiddleware(cfg, restClient, gqlHTTPClient)) - - // Create the dependencies struct for tool handlers - deps := github.ToolDependencies{ - GetClient: getClient, - GetGQLClient: getGQLClient, - GetRawClient: getRawClient, - RepoAccessCache: repoAccessCache, - T: cfg.Translator, - Flags: github.FeatureFlags{LockdownMode: cfg.LockdownMode}, - ContentWindowSize: cfg.ContentWindowSize, + enabledToolsets := resolveEnabledToolsets(cfg) + + // For instruction generation, we need actual toolset names (not nil). + // nil means "use defaults" in registry, so expand it for instructions. + instructionToolsets := enabledToolsets + if instructionToolsets == nil { + instructionToolsets = github.GetDefaultToolsetIDs() + } + + // Create the MCP server + serverOpts := &mcp.ServerOptions{ + Instructions: github.GenerateInstructions(instructionToolsets), + Logger: cfg.Logger, + CompletionHandler: github.CompletionsHandler(func(_ context.Context) (*gogithub.Client, error) { + return clients.rest, nil + }), } - // Create toolset group with all tools, resources, and prompts (stateless) - r := github.NewRegistry(cfg.Translator) + // In dynamic mode, explicitly advertise capabilities since tools/resources/prompts + // may be enabled at runtime even if none are registered initially. + if cfg.DynamicToolsets { + serverOpts.HasTools = true + serverOpts.HasResources = true + serverOpts.HasPrompts = true + } - // Clean tool names (WithTools will resolve any deprecated aliases) - enabledTools := github.CleanTools(cfg.EnabledTools) + ghServer := github.NewServer(cfg.Version, serverOpts) - // Apply filters based on configuration - // - WithDeprecatedToolAliases: adds backward compatibility aliases - // - WithReadOnly: filters out write tools when true - // - WithToolsets: nil=defaults, empty=none, handles "all"/"default" keywords - // - WithTools: additional tools that bypass toolset filtering (additive, resolves aliases) - // - WithFeatureChecker: filters based on feature flags - filteredReg := r. - WithDeprecatedToolAliases(github.DeprecatedToolAliases). + // Add middlewares + ghServer.AddReceivingMiddleware(addGitHubAPIErrorToContext) + ghServer.AddReceivingMiddleware(addUserAgentsMiddleware(cfg, clients.rest, clients.gqlHTTP)) + + // Create dependencies for tool handlers + deps := github.NewBaseDeps( + clients.rest, + clients.gql, + clients.raw, + clients.repoAccess, + cfg.Translator, + github.FeatureFlags{LockdownMode: cfg.LockdownMode}, + cfg.ContentWindowSize, + ) + + // Build and register the tool/resource/prompt registry + registry := github.NewRegistry(cfg.Translator). + WithDeprecatedAliases(github.DeprecatedToolAliases). WithReadOnly(cfg.ReadOnly). WithToolsets(enabledToolsets). - WithTools(enabledTools). - WithFeatureChecker(createFeatureChecker(cfg.EnabledFeatures)) + WithTools(github.CleanTools(cfg.EnabledTools)). + WithFeatureChecker(createFeatureChecker(cfg.EnabledFeatures)). + Build() - // Warn about unrecognized toolset names (likely typos) - if unrecognized := filteredReg.UnrecognizedToolsets(); len(unrecognized) > 0 { + if unrecognized := registry.UnrecognizedToolsets(); len(unrecognized) > 0 { fmt.Fprintf(os.Stderr, "Warning: unrecognized toolsets ignored: %s\n", strings.Join(unrecognized, ", ")) } - // Register all mcp functionality with the server - // Use background context for local server (no per-request actor context) - filteredReg.RegisterAll(context.Background(), ghServer, deps) + // Register GitHub tools/resources/prompts from the registry. + // In dynamic mode with no explicit toolsets, this is a no-op since enabledToolsets + // is empty - users enable toolsets at runtime via the dynamic tools below (but can + // enable toolsets or tools explicitly that do need registration). + registry.RegisterAll(context.Background(), ghServer, deps) - // Register dynamic toolset management if configured - // Dynamic tools get access to the filtered toolset group which tracks enabled state. - // ToolsForToolset() returns all tools for a toolset regardless of enabled status, - // so dynamic tools can enable any toolset at runtime. + // Register dynamic toolset management tools (enable/disable) - these are separate + // meta-tools that control the registry, not part of the registry itself if cfg.DynamicToolsets { - dynamicDeps := github.DynamicToolDependencies{ - Server: ghServer, - Registry: filteredReg, - ToolDeps: deps, - T: cfg.Translator, - } - dynamicTools := github.DynamicTools(filteredReg) - for _, tool := range dynamicTools { - tool.RegisterFunc(ghServer, dynamicDeps) - } + registerDynamicTools(ghServer, registry, deps, cfg.Translator) } return ghServer, nil } +// registerDynamicTools adds the dynamic toolset enable/disable tools to the server. +func registerDynamicTools(server *mcp.Server, registry *registry.Registry, deps *github.BaseDeps, t translations.TranslationHelperFunc) { + dynamicDeps := github.DynamicToolDependencies{ + Server: server, + Registry: registry, + ToolDeps: deps, + T: t, + } + for _, tool := range github.DynamicTools(registry) { + tool.RegisterFunc(server, dynamicDeps) + } +} + // createFeatureChecker returns a FeatureFlagChecker that checks if a flag name // is present in the provided list of enabled features. For the local server, // this is populated from the --features CLI flag. -func createFeatureChecker(enabledFeatures []string) toolsets.FeatureFlagChecker { +func createFeatureChecker(enabledFeatures []string) registry.FeatureFlagChecker { // Build a set for O(1) lookup featureSet := make(map[string]bool, len(enabledFeatures)) for _, f := range enabledFeatures { diff --git a/pkg/github/actions.go b/pkg/github/actions.go index f29f75e99..584a23200 100644 --- a/pkg/github/actions.go +++ b/pkg/github/actions.go @@ -11,7 +11,7 @@ import ( "github.com/github/github-mcp-server/internal/profiler" buffer "github.com/github/github-mcp-server/pkg/buffer" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -25,7 +25,7 @@ const ( ) // ListWorkflows creates a tool to list workflows in a repository -func ListWorkflows(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListWorkflows(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ @@ -96,7 +96,7 @@ func ListWorkflows(t translations.TranslationHelperFunc) toolsets.ServerTool { } // ListWorkflowRuns creates a tool to list workflow runs for a specific workflow -func ListWorkflowRuns(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListWorkflowRuns(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ @@ -250,7 +250,7 @@ func ListWorkflowRuns(t translations.TranslationHelperFunc) toolsets.ServerTool } // RunWorkflow creates a tool to run an Actions workflow -func RunWorkflow(t translations.TranslationHelperFunc) toolsets.ServerTool { +func RunWorkflow(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ @@ -362,7 +362,7 @@ func RunWorkflow(t translations.TranslationHelperFunc) toolsets.ServerTool { } // GetWorkflowRun creates a tool to get details of a specific workflow run -func GetWorkflowRun(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetWorkflowRun(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ @@ -430,7 +430,7 @@ func GetWorkflowRun(t translations.TranslationHelperFunc) toolsets.ServerTool { } // GetWorkflowRunLogs creates a tool to download logs for a specific workflow run -func GetWorkflowRunLogs(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetWorkflowRunLogs(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ @@ -508,7 +508,7 @@ func GetWorkflowRunLogs(t translations.TranslationHelperFunc) toolsets.ServerToo } // ListWorkflowJobs creates a tool to list jobs for a specific workflow run -func ListWorkflowJobs(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListWorkflowJobs(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ @@ -608,7 +608,7 @@ func ListWorkflowJobs(t translations.TranslationHelperFunc) toolsets.ServerTool } // GetJobLogs creates a tool to download logs for a specific workflow job or efficiently get all failed job logs for a workflow run -func GetJobLogs(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetJobLogs(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ @@ -706,10 +706,10 @@ func GetJobLogs(t translations.TranslationHelperFunc) toolsets.ServerTool { if failedOnly && runID > 0 { // Handle failed-only mode: get logs for all failed jobs in the workflow run - return handleFailedJobLogs(ctx, client, owner, repo, int64(runID), returnContent, tailLines, deps.ContentWindowSize) + return handleFailedJobLogs(ctx, client, owner, repo, int64(runID), returnContent, tailLines, deps.GetContentWindowSize()) } else if jobID > 0 { // Handle single job mode - return handleSingleJobLogs(ctx, client, owner, repo, int64(jobID), returnContent, tailLines, deps.ContentWindowSize) + return handleSingleJobLogs(ctx, client, owner, repo, int64(jobID), returnContent, tailLines, deps.GetContentWindowSize()) } return utils.NewToolResultError("Either job_id must be provided for single job logs, or run_id with failed_only=true for failed job logs"), nil, nil @@ -873,7 +873,7 @@ func downloadLogContent(ctx context.Context, logURL string, tailLines int, maxLi } // RerunWorkflowRun creates a tool to re-run an entire workflow run -func RerunWorkflowRun(t translations.TranslationHelperFunc) toolsets.ServerTool { +func RerunWorkflowRun(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ @@ -948,7 +948,7 @@ func RerunWorkflowRun(t translations.TranslationHelperFunc) toolsets.ServerTool } // RerunFailedJobs creates a tool to re-run only the failed jobs in a workflow run -func RerunFailedJobs(t translations.TranslationHelperFunc) toolsets.ServerTool { +func RerunFailedJobs(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ @@ -1023,7 +1023,7 @@ func RerunFailedJobs(t translations.TranslationHelperFunc) toolsets.ServerTool { } // CancelWorkflowRun creates a tool to cancel a workflow run -func CancelWorkflowRun(t translations.TranslationHelperFunc) toolsets.ServerTool { +func CancelWorkflowRun(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ @@ -1100,7 +1100,7 @@ func CancelWorkflowRun(t translations.TranslationHelperFunc) toolsets.ServerTool } // ListWorkflowRunArtifacts creates a tool to list artifacts for a workflow run -func ListWorkflowRunArtifacts(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListWorkflowRunArtifacts(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ @@ -1180,7 +1180,7 @@ func ListWorkflowRunArtifacts(t translations.TranslationHelperFunc) toolsets.Ser } // DownloadWorkflowRunArtifact creates a tool to download a workflow run artifact -func DownloadWorkflowRunArtifact(t translations.TranslationHelperFunc) toolsets.ServerTool { +func DownloadWorkflowRunArtifact(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ @@ -1257,7 +1257,7 @@ func DownloadWorkflowRunArtifact(t translations.TranslationHelperFunc) toolsets. } // DeleteWorkflowRunLogs creates a tool to delete logs for a workflow run -func DeleteWorkflowRunLogs(t translations.TranslationHelperFunc) toolsets.ServerTool { +func DeleteWorkflowRunLogs(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ @@ -1333,7 +1333,7 @@ func DeleteWorkflowRunLogs(t translations.TranslationHelperFunc) toolsets.Server } // GetWorkflowRunUsage creates a tool to get usage metrics for a workflow run -func GetWorkflowRunUsage(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetWorkflowRunUsage(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ diff --git a/pkg/github/actions_test.go b/pkg/github/actions_test.go index 09ab3b2cc..4d56f01aa 100644 --- a/pkg/github/actions_test.go +++ b/pkg/github/actions_test.go @@ -105,8 +105,8 @@ func Test_ListWorkflows(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -194,8 +194,8 @@ func Test_RunWorkflow(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -290,8 +290,8 @@ func Test_RunWorkflow_WithFilename(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -398,8 +398,8 @@ func Test_CancelWorkflowRun(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -528,8 +528,8 @@ func Test_ListWorkflowRunArtifacts(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -618,8 +618,8 @@ func Test_DownloadWorkflowRunArtifact(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -704,8 +704,8 @@ func Test_DeleteWorkflowRunLogs(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -808,8 +808,8 @@ func Test_GetWorkflowRunUsage(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -1072,8 +1072,8 @@ func Test_GetJobLogs(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, ContentWindowSize: 5000, } handler := toolDef.Handler(deps) @@ -1136,8 +1136,8 @@ func Test_GetJobLogs_WithContentReturn(t *testing.T) { client := github.NewClient(mockedClient) toolDef := GetJobLogs(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, ContentWindowSize: 5000, } handler := toolDef.Handler(deps) @@ -1188,8 +1188,8 @@ func Test_GetJobLogs_WithContentReturnAndTailLines(t *testing.T) { client := github.NewClient(mockedClient) toolDef := GetJobLogs(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, ContentWindowSize: 5000, } handler := toolDef.Handler(deps) @@ -1240,8 +1240,8 @@ func Test_GetJobLogs_WithContentReturnAndLargeTailLines(t *testing.T) { client := github.NewClient(mockedClient) toolDef := GetJobLogs(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, ContentWindowSize: 5000, } handler := toolDef.Handler(deps) diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go index 888ad4fd2..8826e4cf6 100644 --- a/pkg/github/code_scanning.go +++ b/pkg/github/code_scanning.go @@ -8,7 +8,7 @@ import ( "net/http" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -16,7 +16,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func GetCodeScanningAlert(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetCodeScanningAlert(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataCodeSecurity, mcp.Tool{ @@ -94,7 +94,7 @@ func GetCodeScanningAlert(t translations.TranslationHelperFunc) toolsets.ServerT ) } -func ListCodeScanningAlerts(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListCodeScanningAlerts(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataCodeSecurity, mcp.Tool{ diff --git a/pkg/github/code_scanning_test.go b/pkg/github/code_scanning_test.go index 5e56e6788..44c7a7e95 100644 --- a/pkg/github/code_scanning_test.go +++ b/pkg/github/code_scanning_test.go @@ -88,8 +88,8 @@ func Test_GetCodeScanningAlert(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -220,8 +220,8 @@ func Test_ListCodeScanningAlerts(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) diff --git a/pkg/github/context_tools.go b/pkg/github/context_tools.go index d5e0cfee9..837de00f7 100644 --- a/pkg/github/context_tools.go +++ b/pkg/github/context_tools.go @@ -6,7 +6,7 @@ import ( "time" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/jsonschema-go/jsonschema" @@ -37,7 +37,7 @@ type UserDetails struct { } // GetMe creates a tool to get details of the authenticated user. -func GetMe(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetMe(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataContext, mcp.Tool{ @@ -111,7 +111,7 @@ type OrganizationTeams struct { Teams []TeamInfo `json:"teams"` } -func GetTeams(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetTeams(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataContext, mcp.Tool{ @@ -210,7 +210,7 @@ func GetTeams(t translations.TranslationHelperFunc) toolsets.ServerTool { ) } -func GetTeamMembers(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetTeamMembers(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataContext, mcp.Tool{ diff --git a/pkg/github/context_tools_test.go b/pkg/github/context_tools_test.go index 0e28aad49..e9faefc40 100644 --- a/pkg/github/context_tools_test.go +++ b/pkg/github/context_tools_test.go @@ -3,7 +3,7 @@ package github import ( "context" "encoding/json" - "fmt" + "net/http" "testing" "time" @@ -48,7 +48,8 @@ func Test_GetMe(t *testing.T) { tests := []struct { name string - stubbedGetClientFn GetClientFn + mockedClient *http.Client + clientErr string // if set, GetClient returns this error requestArgs map[string]any expectToolError bool expectedUser *github.User @@ -56,12 +57,10 @@ func Test_GetMe(t *testing.T) { }{ { name: "successful get user", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetUser, - mockUser, - ), + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetUser, + mockUser, ), ), requestArgs: map[string]any{}, @@ -70,12 +69,10 @@ func Test_GetMe(t *testing.T) { }, { name: "successful get user with reason", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetUser, - mockUser, - ), + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetUser, + mockUser, ), ), requestArgs: map[string]any{ @@ -86,19 +83,17 @@ func Test_GetMe(t *testing.T) { }, { name: "getting client fails", - stubbedGetClientFn: stubGetClientFnErr("expected test error"), + clientErr: "expected test error", requestArgs: map[string]any{}, expectToolError: true, expectedToolErrMsg: "failed to get GitHub client: expected test error", }, { name: "get user fails", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetUser, - badRequestHandler("expected test failure"), - ), + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetUser, + badRequestHandler("expected test failure"), ), ), requestArgs: map[string]any{}, @@ -109,8 +104,11 @@ func Test_GetMe(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - deps := ToolDependencies{ - GetClient: tc.stubbedGetClientFn, + var deps ToolDependencies + if tc.clientErr != "" { + deps = stubDeps{clientFn: stubClientFnErr(tc.clientErr)} + } else { + deps = BaseDeps{Client: github.NewClient(tc.mockedClient)} } handler := serverTool.Handler(deps) @@ -223,49 +221,83 @@ func Test_GetTeams(t *testing.T) { }, }) + // Create GQL clients for different test scenarios - these are factory functions + // to ensure each test gets a fresh client + gqlClientForTestuser := func() *githubv4.Client { + queryStr := "query($login:String!){user(login: $login){organizations(first: 100){nodes{login,teams(first: 100, userLogins: [$login]){nodes{name,slug,description}}}}}}" + vars := map[string]interface{}{ + "login": "testuser", + } + matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockTeamsResponse) + httpClient := githubv4mock.NewMockedHTTPClient(matcher) + return githubv4.NewClient(httpClient) + } + + gqlClientForSpecificuser := func() *githubv4.Client { + queryStr := "query($login:String!){user(login: $login){organizations(first: 100){nodes{login,teams(first: 100, userLogins: [$login]){nodes{name,slug,description}}}}}}" + vars := map[string]interface{}{ + "login": "specificuser", + } + matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockTeamsResponse) + httpClient := githubv4mock.NewMockedHTTPClient(matcher) + return githubv4.NewClient(httpClient) + } + + gqlClientNoTeams := func() *githubv4.Client { + queryStr := "query($login:String!){user(login: $login){organizations(first: 100){nodes{login,teams(first: 100, userLogins: [$login]){nodes{name,slug,description}}}}}}" + vars := map[string]interface{}{ + "login": "testuser", + } + matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockNoTeamsResponse) + httpClient := githubv4mock.NewMockedHTTPClient(matcher) + return githubv4.NewClient(httpClient) + } + + // Factory function for mock HTTP clients with user response + httpClientWithUser := func() *http.Client { + return mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetUser, + mockUser, + ), + ) + } + + httpClientUserFails := func() *http.Client { + return mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetUser, + badRequestHandler("expected test failure"), + ), + ) + } + tests := []struct { - name string - stubbedGetClientFn GetClientFn - stubbedGetGQLClientFn GetGQLClientFn - requestArgs map[string]any - expectToolError bool - expectedToolErrMsg string - expectedTeamsCount int + name string + makeDeps func() ToolDependencies + requestArgs map[string]any + expectToolError bool + expectedToolErrMsg string + expectedTeamsCount int }{ { name: "successful get teams", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetUser, - mockUser, - ), - ), - ), - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - queryStr := "query($login:String!){user(login: $login){organizations(first: 100){nodes{login,teams(first: 100, userLogins: [$login]){nodes{name,slug,description}}}}}}" - vars := map[string]interface{}{ - "login": "testuser", + makeDeps: func() ToolDependencies { + return BaseDeps{ + Client: github.NewClient(httpClientWithUser()), + GQLClient: gqlClientForTestuser(), } - matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockTeamsResponse) - httpClient := githubv4mock.NewMockedHTTPClient(matcher) - return githubv4.NewClient(httpClient), nil }, requestArgs: map[string]any{}, expectToolError: false, expectedTeamsCount: 2, }, { - name: "successful get teams for specific user", - stubbedGetClientFn: nil, - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - queryStr := "query($login:String!){user(login: $login){organizations(first: 100){nodes{login,teams(first: 100, userLogins: [$login]){nodes{name,slug,description}}}}}}" - vars := map[string]interface{}{ - "login": "specificuser", + name: "successful get teams for specific user", + makeDeps: func() ToolDependencies { + return BaseDeps{ + GQLClient: gqlClientForSpecificuser(), } - matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockTeamsResponse) - httpClient := githubv4mock.NewMockedHTTPClient(matcher) - return githubv4.NewClient(httpClient), nil }, requestArgs: map[string]any{ "user": "specificuser", @@ -275,62 +307,43 @@ func Test_GetTeams(t *testing.T) { }, { name: "no teams found", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetUser, - mockUser, - ), - ), - ), - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - queryStr := "query($login:String!){user(login: $login){organizations(first: 100){nodes{login,teams(first: 100, userLogins: [$login]){nodes{name,slug,description}}}}}}" - vars := map[string]interface{}{ - "login": "testuser", + makeDeps: func() ToolDependencies { + return BaseDeps{ + Client: github.NewClient(httpClientWithUser()), + GQLClient: gqlClientNoTeams(), } - matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockNoTeamsResponse) - httpClient := githubv4mock.NewMockedHTTPClient(matcher) - return githubv4.NewClient(httpClient), nil }, requestArgs: map[string]any{}, expectToolError: false, expectedTeamsCount: 0, }, { - name: "getting client fails", - stubbedGetClientFn: stubGetClientFnErr("expected test error"), - stubbedGetGQLClientFn: nil, - requestArgs: map[string]any{}, - expectToolError: true, - expectedToolErrMsg: "failed to get GitHub client: expected test error", + name: "getting client fails", + makeDeps: func() ToolDependencies { + return stubDeps{clientFn: stubClientFnErr("expected test error")} + }, + requestArgs: map[string]any{}, + expectToolError: true, + expectedToolErrMsg: "failed to get GitHub client: expected test error", }, { name: "get user fails", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetUser, - badRequestHandler("expected test failure"), - ), - ), - ), - stubbedGetGQLClientFn: nil, - requestArgs: map[string]any{}, - expectToolError: true, - expectedToolErrMsg: "expected test failure", + makeDeps: func() ToolDependencies { + return BaseDeps{ + Client: github.NewClient(httpClientUserFails()), + } + }, + requestArgs: map[string]any{}, + expectToolError: true, + expectedToolErrMsg: "expected test failure", }, { name: "getting GraphQL client fails", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetUser, - mockUser, - ), - ), - ), - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - return nil, fmt.Errorf("GraphQL client error") + makeDeps: func() ToolDependencies { + return stubDeps{ + clientFn: stubClientFnFromHTTP(httpClientWithUser()), + gqlClientFn: stubGQLClientFnErr("GraphQL client error"), + } }, requestArgs: map[string]any{}, expectToolError: true, @@ -340,11 +353,7 @@ func Test_GetTeams(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - deps := ToolDependencies{ - GetClient: tc.stubbedGetClientFn, - GetGQLClient: tc.stubbedGetGQLClientFn, - } - handler := serverTool.Handler(deps) + handler := serverTool.Handler(tc.makeDeps()) request := createMCPRequest(tc.requestArgs) result, err := handler(context.Background(), &request) @@ -422,26 +431,40 @@ func Test_GetTeamMembers(t *testing.T) { }, }) + // Create GQL clients for different test scenarios + gqlClientWithMembers := func() *githubv4.Client { + queryStr := "query($org:String!$teamSlug:String!){organization(login: $org){team(slug: $teamSlug){members(first: 100){nodes{login}}}}}" + vars := map[string]interface{}{ + "org": "testorg", + "teamSlug": "testteam", + } + matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockTeamMembersResponse) + httpClient := githubv4mock.NewMockedHTTPClient(matcher) + return githubv4.NewClient(httpClient) + } + + gqlClientNoMembers := func() *githubv4.Client { + queryStr := "query($org:String!$teamSlug:String!){organization(login: $org){team(slug: $teamSlug){members(first: 100){nodes{login}}}}}" + vars := map[string]interface{}{ + "org": "testorg", + "teamSlug": "emptyteam", + } + matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockNoMembersResponse) + httpClient := githubv4mock.NewMockedHTTPClient(matcher) + return githubv4.NewClient(httpClient) + } + tests := []struct { - name string - stubbedGetGQLClientFn GetGQLClientFn - requestArgs map[string]any - expectToolError bool - expectedToolErrMsg string - expectedMembersCount int + name string + deps ToolDependencies + requestArgs map[string]any + expectToolError bool + expectedToolErrMsg string + expectedMembersCount int }{ { name: "successful get team members", - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - queryStr := "query($org:String!$teamSlug:String!){organization(login: $org){team(slug: $teamSlug){members(first: 100){nodes{login}}}}}" - vars := map[string]interface{}{ - "org": "testorg", - "teamSlug": "testteam", - } - matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockTeamMembersResponse) - httpClient := githubv4mock.NewMockedHTTPClient(matcher) - return githubv4.NewClient(httpClient), nil - }, + deps: BaseDeps{GQLClient: gqlClientWithMembers()}, requestArgs: map[string]any{ "org": "testorg", "team_slug": "testteam", @@ -451,16 +474,7 @@ func Test_GetTeamMembers(t *testing.T) { }, { name: "team with no members", - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - queryStr := "query($org:String!$teamSlug:String!){organization(login: $org){team(slug: $teamSlug){members(first: 100){nodes{login}}}}}" - vars := map[string]interface{}{ - "org": "testorg", - "teamSlug": "emptyteam", - } - matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockNoMembersResponse) - httpClient := githubv4mock.NewMockedHTTPClient(matcher) - return githubv4.NewClient(httpClient), nil - }, + deps: BaseDeps{GQLClient: gqlClientNoMembers()}, requestArgs: map[string]any{ "org": "testorg", "team_slug": "emptyteam", @@ -470,9 +484,7 @@ func Test_GetTeamMembers(t *testing.T) { }, { name: "getting GraphQL client fails", - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - return nil, fmt.Errorf("GraphQL client error") - }, + deps: stubDeps{gqlClientFn: stubGQLClientFnErr("GraphQL client error")}, requestArgs: map[string]any{ "org": "testorg", "team_slug": "testteam", @@ -484,10 +496,7 @@ func Test_GetTeamMembers(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - deps := ToolDependencies{ - GetGQLClient: tc.stubbedGetGQLClientFn, - } - handler := serverTool.Handler(deps) + handler := serverTool.Handler(tc.deps) request := createMCPRequest(tc.requestArgs) result, err := handler(context.Background(), &request) diff --git a/pkg/github/dependabot.go b/pkg/github/dependabot.go index 1508d1382..daa2a124a 100644 --- a/pkg/github/dependabot.go +++ b/pkg/github/dependabot.go @@ -8,7 +8,7 @@ import ( "net/http" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -16,7 +16,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func GetDependabotAlert(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetDependabotAlert(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataDependabot, mcp.Tool{ @@ -94,7 +94,7 @@ func GetDependabotAlert(t translations.TranslationHelperFunc) toolsets.ServerToo ) } -func ListDependabotAlerts(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListDependabotAlerts(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataDependabot, mcp.Tool{ diff --git a/pkg/github/dependabot_test.go b/pkg/github/dependabot_test.go index ace0eb07a..614c6f383 100644 --- a/pkg/github/dependabot_test.go +++ b/pkg/github/dependabot_test.go @@ -81,7 +81,7 @@ func Test_GetDependabotAlert(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{GetClient: stubGetClientFn(client)} + deps := BaseDeps{Client: client} handler := toolDef.Handler(deps) // Create call request @@ -232,7 +232,7 @@ func Test_ListDependabotAlerts(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{GetClient: stubGetClientFn(client)} + deps := BaseDeps{Client: client} handler := toolDef.Handler(deps) request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/dependencies.go b/pkg/github/dependencies.go index 7dcc33f75..040e61883 100644 --- a/pkg/github/dependencies.go +++ b/pkg/github/dependencies.go @@ -1,53 +1,123 @@ package github import ( + "context" + "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/raw" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" + gogithub "github.com/google/go-github/v79/github" "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/shurcooL/githubv4" ) -// ToolDependencies contains all dependencies that tool handlers might need. -// This is a properly-typed struct that lives in pkg/github to avoid circular -// dependencies. The toolsets package uses `any` for deps and tool handlers -// type-assert to this struct. -type ToolDependencies struct { +// ToolDependencies defines the interface for dependencies that tool handlers need. +// This is an interface to allow different implementations: +// - Local server: stores closures that create clients on demand +// - Remote server: can store pre-created clients per-request for efficiency +// +// The toolsets package uses `any` for deps and tool handlers type-assert to this interface. +type ToolDependencies interface { // GetClient returns a GitHub REST API client - GetClient GetClientFn + GetClient(ctx context.Context) (*gogithub.Client, error) // GetGQLClient returns a GitHub GraphQL client - GetGQLClient GetGQLClientFn + GetGQLClient(ctx context.Context) (*githubv4.Client, error) + + // GetRawClient returns a raw content client for GitHub + GetRawClient(ctx context.Context) (*raw.Client, error) - // GetRawClient returns a raw HTTP client for GitHub - GetRawClient raw.GetRawClientFn + // GetRepoAccessCache returns the lockdown mode repo access cache + GetRepoAccessCache() *lockdown.RepoAccessCache - // RepoAccessCache is the lockdown mode repo access cache - RepoAccessCache *lockdown.RepoAccessCache + // GetT returns the translation helper function + GetT() translations.TranslationHelperFunc - // T is the translation helper function - T translations.TranslationHelperFunc + // GetFlags returns feature flags + GetFlags() FeatureFlags + + // GetContentWindowSize returns the content window size for log truncation + GetContentWindowSize() int +} - // Flags are feature flags - Flags FeatureFlags +// BaseDeps is the standard implementation of ToolDependencies for the local server. +// It stores pre-created clients. The remote server can create its own struct +// implementing ToolDependencies with different client creation strategies. +type BaseDeps struct { + // Pre-created clients + Client *gogithub.Client + GQLClient *githubv4.Client + RawClient *raw.Client - // ContentWindowSize is the size of the content window for log truncation + // Static dependencies + RepoAccessCache *lockdown.RepoAccessCache + T translations.TranslationHelperFunc + Flags FeatureFlags ContentWindowSize int } +// NewBaseDeps creates a BaseDeps with the provided clients and configuration. +func NewBaseDeps( + client *gogithub.Client, + gqlClient *githubv4.Client, + rawClient *raw.Client, + repoAccessCache *lockdown.RepoAccessCache, + t translations.TranslationHelperFunc, + flags FeatureFlags, + contentWindowSize int, +) *BaseDeps { + return &BaseDeps{ + Client: client, + GQLClient: gqlClient, + RawClient: rawClient, + RepoAccessCache: repoAccessCache, + T: t, + Flags: flags, + ContentWindowSize: contentWindowSize, + } +} + +// GetClient implements ToolDependencies. +func (d BaseDeps) GetClient(_ context.Context) (*gogithub.Client, error) { + return d.Client, nil +} + +// GetGQLClient implements ToolDependencies. +func (d BaseDeps) GetGQLClient(_ context.Context) (*githubv4.Client, error) { + return d.GQLClient, nil +} + +// GetRawClient implements ToolDependencies. +func (d BaseDeps) GetRawClient(_ context.Context) (*raw.Client, error) { + return d.RawClient, nil +} + +// GetRepoAccessCache implements ToolDependencies. +func (d BaseDeps) GetRepoAccessCache() *lockdown.RepoAccessCache { return d.RepoAccessCache } + +// GetT implements ToolDependencies. +func (d BaseDeps) GetT() translations.TranslationHelperFunc { return d.T } + +// GetFlags implements ToolDependencies. +func (d BaseDeps) GetFlags() FeatureFlags { return d.Flags } + +// GetContentWindowSize implements ToolDependencies. +func (d BaseDeps) GetContentWindowSize() int { return d.ContentWindowSize } + // NewTool creates a ServerTool with fully-typed ToolDependencies and toolset metadata. // This helper isolates the type assertion from `any` to `ToolDependencies`, // so tool implementations remain fully typed without assertions scattered throughout. -func NewTool[In, Out any](toolset toolsets.ToolsetMetadata, tool mcp.Tool, handler func(deps ToolDependencies) mcp.ToolHandlerFor[In, Out]) toolsets.ServerTool { - return toolsets.NewServerTool(tool, toolset, func(d any) mcp.ToolHandlerFor[In, Out] { +func NewTool[In, Out any](toolset registry.ToolsetMetadata, tool mcp.Tool, handler func(deps ToolDependencies) mcp.ToolHandlerFor[In, Out]) registry.ServerTool { + return registry.NewServerTool(tool, toolset, func(d any) mcp.ToolHandlerFor[In, Out] { return handler(d.(ToolDependencies)) }) } // NewToolFromHandler creates a ServerTool with fully-typed ToolDependencies and toolset metadata // for handlers that conform to mcp.ToolHandler directly. -func NewToolFromHandler(toolset toolsets.ToolsetMetadata, tool mcp.Tool, handler func(deps ToolDependencies) mcp.ToolHandler) toolsets.ServerTool { - return toolsets.NewServerToolFromHandler(tool, toolset, func(d any) mcp.ToolHandler { +func NewToolFromHandler(toolset registry.ToolsetMetadata, tool mcp.Tool, handler func(deps ToolDependencies) mcp.ToolHandler) registry.ServerTool { + return registry.NewServerToolFromHandler(tool, toolset, func(d any) mcp.ToolHandler { return handler(d.(ToolDependencies)) }) } diff --git a/pkg/github/discussions.go b/pkg/github/discussions.go index 5bbdb2b5f..50364fa58 100644 --- a/pkg/github/discussions.go +++ b/pkg/github/discussions.go @@ -5,7 +5,7 @@ import ( "encoding/json" "fmt" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/go-viper/mapstructure/v2" @@ -122,7 +122,7 @@ func getQueryType(useOrdering bool, categoryID *githubv4.ID) any { return &BasicNoOrder{} } -func ListDiscussions(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListDiscussions(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataDiscussions, mcp.Tool{ @@ -276,7 +276,7 @@ func ListDiscussions(t translations.TranslationHelperFunc) toolsets.ServerTool { ) } -func GetDiscussion(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetDiscussion(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataDiscussions, mcp.Tool{ @@ -381,7 +381,7 @@ func GetDiscussion(t translations.TranslationHelperFunc) toolsets.ServerTool { ) } -func GetDiscussionComments(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetDiscussionComments(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataDiscussions, mcp.Tool{ @@ -509,7 +509,7 @@ func GetDiscussionComments(t translations.TranslationHelperFunc) toolsets.Server ) } -func ListDiscussionCategories(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListDiscussionCategories(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataDiscussions, mcp.Tool{ diff --git a/pkg/github/discussions_test.go b/pkg/github/discussions_test.go index 758c82200..73ae66748 100644 --- a/pkg/github/discussions_test.go +++ b/pkg/github/discussions_test.go @@ -447,7 +447,7 @@ func Test_ListDiscussions(t *testing.T) { } gqlClient := githubv4.NewClient(httpClient) - deps := ToolDependencies{GetGQLClient: stubGetGQLClientFn(gqlClient)} + deps := BaseDeps{GQLClient: gqlClient} handler := toolDef.Handler(deps) req := createMCPRequest(tc.reqParams) @@ -559,7 +559,7 @@ func Test_GetDiscussion(t *testing.T) { matcher := githubv4mock.NewQueryMatcher(qGetDiscussion, vars, tc.response) httpClient := githubv4mock.NewMockedHTTPClient(matcher) gqlClient := githubv4.NewClient(httpClient) - deps := ToolDependencies{GetGQLClient: stubGetGQLClientFn(gqlClient)} + deps := BaseDeps{GQLClient: gqlClient} handler := toolDef.Handler(deps) reqParams := map[string]interface{}{"owner": "owner", "repo": "repo", "discussionNumber": int32(1)} @@ -639,7 +639,7 @@ func Test_GetDiscussionComments(t *testing.T) { matcher := githubv4mock.NewQueryMatcher(qGetComments, vars, mockResponse) httpClient := githubv4mock.NewMockedHTTPClient(matcher) gqlClient := githubv4.NewClient(httpClient) - deps := ToolDependencies{GetGQLClient: stubGetGQLClientFn(gqlClient)} + deps := BaseDeps{GQLClient: gqlClient} handler := toolDef.Handler(deps) reqParams := map[string]interface{}{ @@ -791,7 +791,7 @@ func Test_ListDiscussionCategories(t *testing.T) { httpClient := githubv4mock.NewMockedHTTPClient(matcher) gqlClient := githubv4.NewClient(httpClient) - deps := ToolDependencies{GetGQLClient: stubGetGQLClientFn(gqlClient)} + deps := BaseDeps{GQLClient: gqlClient} handler := toolDef.Handler(deps) req := createMCPRequest(tc.reqParams) diff --git a/pkg/github/dynamic_tools.go b/pkg/github/dynamic_tools.go index 93c24a07b..a749ecd1b 100644 --- a/pkg/github/dynamic_tools.go +++ b/pkg/github/dynamic_tools.go @@ -5,7 +5,7 @@ import ( "encoding/json" "fmt" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/jsonschema-go/jsonschema" @@ -19,7 +19,7 @@ type DynamicToolDependencies struct { // Server is the MCP server to register tools with Server *mcp.Server // Registry contains all available tools that can be enabled dynamically - Registry *toolsets.Registry + Registry *registry.Registry // ToolDeps are the dependencies passed to tools when they are registered ToolDeps any // T is the translation helper function @@ -27,14 +27,14 @@ type DynamicToolDependencies struct { } // NewDynamicTool creates a ServerTool with fully-typed DynamicToolDependencies. -func NewDynamicTool(toolset toolsets.ToolsetMetadata, tool mcp.Tool, handler func(deps DynamicToolDependencies) mcp.ToolHandlerFor[map[string]any, any]) toolsets.ServerTool { - return toolsets.NewServerTool(tool, toolset, func(d any) mcp.ToolHandlerFor[map[string]any, any] { +func NewDynamicTool(toolset registry.ToolsetMetadata, tool mcp.Tool, handler func(deps DynamicToolDependencies) mcp.ToolHandlerFor[map[string]any, any]) registry.ServerTool { + return registry.NewServerTool(tool, toolset, func(d any) mcp.ToolHandlerFor[map[string]any, any] { return handler(d.(DynamicToolDependencies)) }) } // toolsetIDsEnum returns the list of toolset IDs as an enum for JSON Schema. -func toolsetIDsEnum(r *toolsets.Registry) []any { +func toolsetIDsEnum(r *registry.Registry) []any { toolsetIDs := r.ToolsetIDs() result := make([]any, len(toolsetIDs)) for i, id := range toolsetIDs { @@ -44,10 +44,10 @@ func toolsetIDsEnum(r *toolsets.Registry) []any { } // DynamicTools returns the tools for dynamic toolset management. -// These tools allow runtime discovery and enablement of toolsets. +// These tools allow runtime discovery and enablement of registry. // The r parameter provides the available toolset IDs for JSON Schema enums. -func DynamicTools(r *toolsets.Registry) []toolsets.ServerTool { - return []toolsets.ServerTool{ +func DynamicTools(r *registry.Registry) []registry.ServerTool { + return []registry.ServerTool{ ListAvailableToolsets(), GetToolsetsTools(r), EnableToolset(r), @@ -55,7 +55,7 @@ func DynamicTools(r *toolsets.Registry) []toolsets.ServerTool { } // EnableToolset creates a tool that enables a toolset at runtime. -func EnableToolset(r *toolsets.Registry) toolsets.ServerTool { +func EnableToolset(r *registry.Registry) registry.ServerTool { return NewDynamicTool( ToolsetMetadataDynamic, mcp.Tool{ @@ -84,7 +84,7 @@ func EnableToolset(r *toolsets.Registry) toolsets.ServerTool { return utils.NewToolResultError(err.Error()), nil, nil } - toolsetID := toolsets.ToolsetID(toolsetName) + toolsetID := registry.ToolsetID(toolsetName) if !deps.Registry.HasToolset(toolsetID) { return utils.NewToolResultError(fmt.Sprintf("Toolset %s not found", toolsetName)), nil, nil @@ -109,8 +109,8 @@ func EnableToolset(r *toolsets.Registry) toolsets.ServerTool { ) } -// ListAvailableToolsets creates a tool that lists all available toolsets. -func ListAvailableToolsets() toolsets.ServerTool { +// ListAvailableToolsets creates a tool that lists all available registry. +func ListAvailableToolsets() registry.ServerTool { return NewDynamicTool( ToolsetMetadataDynamic, mcp.Tool{ @@ -153,7 +153,7 @@ func ListAvailableToolsets() toolsets.ServerTool { } // GetToolsetsTools creates a tool that lists all tools in a specific toolset. -func GetToolsetsTools(r *toolsets.Registry) toolsets.ServerTool { +func GetToolsetsTools(r *registry.Registry) registry.ServerTool { return NewDynamicTool( ToolsetMetadataDynamic, mcp.Tool{ @@ -182,7 +182,7 @@ func GetToolsetsTools(r *toolsets.Registry) toolsets.ServerTool { return utils.NewToolResultError(err.Error()), nil, nil } - toolsetID := toolsets.ToolsetID(toolsetName) + toolsetID := registry.ToolsetID(toolsetName) if !deps.Registry.HasToolset(toolsetID) { return utils.NewToolResultError(fmt.Sprintf("Toolset %s not found", toolsetName)), nil, nil diff --git a/pkg/github/dynamic_tools_test.go b/pkg/github/dynamic_tools_test.go new file mode 100644 index 000000000..4558204dc --- /dev/null +++ b/pkg/github/dynamic_tools_test.go @@ -0,0 +1,231 @@ +package github + +import ( + "context" + "encoding/json" + "testing" + + "github.com/github/github-mcp-server/pkg/registry" + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// createDynamicRequest creates an MCP request with the given arguments for dynamic tools. +func createDynamicRequest(args map[string]any) *mcp.CallToolRequest { + argsJSON, _ := json.Marshal(args) + return &mcp.CallToolRequest{ + Params: &mcp.CallToolParamsRaw{ + Arguments: json.RawMessage(argsJSON), + }, + } +} + +func TestDynamicTools_ListAvailableToolsets(t *testing.T) { + // Build a registry with no toolsets enabled (dynamic mode) + reg := NewRegistry(translations.NullTranslationHelper). + WithToolsets([]string{}). + Build() + + // Create a mock server + server := mcp.NewServer(&mcp.Implementation{Name: "test"}, nil) + + // Create dynamic tool dependencies + deps := DynamicToolDependencies{ + Server: server, + Registry: reg, + ToolDeps: nil, + T: translations.NullTranslationHelper, + } + + // Get the list_available_toolsets tool + tool := ListAvailableToolsets() + handler := tool.Handler(deps) + + // Call the handler + result, err := handler(context.Background(), createDynamicRequest(map[string]any{})) + require.NoError(t, err) + require.NotNil(t, result) + require.Len(t, result.Content, 1) + + // Parse the result + var toolsets []map[string]string + textContent := result.Content[0].(*mcp.TextContent) + err = json.Unmarshal([]byte(textContent.Text), &toolsets) + require.NoError(t, err) + + // Verify we got toolsets + assert.NotEmpty(t, toolsets, "should have available toolsets") + + // Find the repos toolset and verify it's not enabled + var reposToolset map[string]string + for _, ts := range toolsets { + if ts["name"] == "repos" { + reposToolset = ts + break + } + } + require.NotNil(t, reposToolset, "repos toolset should exist") + assert.Equal(t, "false", reposToolset["currently_enabled"], "repos should not be enabled initially") +} + +func TestDynamicTools_GetToolsetTools(t *testing.T) { + // Build a registry with no toolsets enabled (dynamic mode) + reg := NewRegistry(translations.NullTranslationHelper). + WithToolsets([]string{}). + Build() + + // Create a mock server + server := mcp.NewServer(&mcp.Implementation{Name: "test"}, nil) + + // Create dynamic tool dependencies + deps := DynamicToolDependencies{ + Server: server, + Registry: reg, + ToolDeps: nil, + T: translations.NullTranslationHelper, + } + + // Get the get_toolset_tools tool + tool := GetToolsetsTools(reg) + handler := tool.Handler(deps) + + // Call the handler for repos toolset + result, err := handler(context.Background(), createDynamicRequest(map[string]any{ + "toolset": "repos", + })) + require.NoError(t, err) + require.NotNil(t, result) + require.Len(t, result.Content, 1) + + // Parse the result + var tools []map[string]string + textContent := result.Content[0].(*mcp.TextContent) + err = json.Unmarshal([]byte(textContent.Text), &tools) + require.NoError(t, err) + + // Verify we got tools for the repos toolset + assert.NotEmpty(t, tools, "repos toolset should have tools") + + // Verify at least get_commit is there (a repos toolset tool) + var foundGetCommit bool + for _, tool := range tools { + if tool["name"] == "get_commit" { + foundGetCommit = true + break + } + } + assert.True(t, foundGetCommit, "get_commit should be in repos toolset") +} + +func TestDynamicTools_EnableToolset(t *testing.T) { + // Build a registry with no toolsets enabled (dynamic mode) + reg := NewRegistry(translations.NullTranslationHelper). + WithToolsets([]string{}). + Build() + + // Create a mock server + server := mcp.NewServer(&mcp.Implementation{Name: "test"}, nil) + + // Create dynamic tool dependencies + deps := DynamicToolDependencies{ + Server: server, + Registry: reg, + ToolDeps: NewBaseDeps(nil, nil, nil, nil, translations.NullTranslationHelper, FeatureFlags{}, 0), + T: translations.NullTranslationHelper, + } + + // Verify repos is not enabled initially + assert.False(t, reg.IsToolsetEnabled(registry.ToolsetID("repos"))) + + // Get the enable_toolset tool + tool := EnableToolset(reg) + handler := tool.Handler(deps) + + // Enable the repos toolset + result, err := handler(context.Background(), createDynamicRequest(map[string]any{ + "toolset": "repos", + })) + require.NoError(t, err) + require.NotNil(t, result) + require.Len(t, result.Content, 1) + + // Verify the toolset is now enabled + assert.True(t, reg.IsToolsetEnabled(registry.ToolsetID("repos")), "repos should be enabled after enable_toolset") + + // Verify the success message + textContent := result.Content[0].(*mcp.TextContent) + assert.Contains(t, textContent.Text, "enabled") + + // Try enabling again - should say already enabled + result2, err := handler(context.Background(), createDynamicRequest(map[string]any{ + "toolset": "repos", + })) + require.NoError(t, err) + textContent2 := result2.Content[0].(*mcp.TextContent) + assert.Contains(t, textContent2.Text, "already enabled") +} + +func TestDynamicTools_EnableToolset_InvalidToolset(t *testing.T) { + // Build a registry with no toolsets enabled (dynamic mode) + reg := NewRegistry(translations.NullTranslationHelper). + WithToolsets([]string{}). + Build() + + // Create a mock server + server := mcp.NewServer(&mcp.Implementation{Name: "test"}, nil) + + // Create dynamic tool dependencies + deps := DynamicToolDependencies{ + Server: server, + Registry: reg, + ToolDeps: nil, + T: translations.NullTranslationHelper, + } + + // Get the enable_toolset tool + tool := EnableToolset(reg) + handler := tool.Handler(deps) + + // Try to enable a non-existent toolset + result, err := handler(context.Background(), createDynamicRequest(map[string]any{ + "toolset": "nonexistent", + })) + require.NoError(t, err) + require.NotNil(t, result) + + // Should be an error result + textContent := result.Content[0].(*mcp.TextContent) + assert.Contains(t, textContent.Text, "not found") +} + +func TestDynamicTools_ToolsetsEnum(t *testing.T) { + // Build a registry + reg := NewRegistry(translations.NullTranslationHelper).Build() + + // Get tools to verify they have proper enum values + tools := DynamicTools(reg) + + // Find enable_toolset and get_toolset_tools + for _, tool := range tools { + if tool.Tool.Name == "enable_toolset" || tool.Tool.Name == "get_toolset_tools" { + // Verify the toolset property has an enum + schema := tool.Tool.InputSchema.(*jsonschema.Schema) + toolsetProp := schema.Properties["toolset"] + require.NotNil(t, toolsetProp, "toolset property should exist") + assert.NotEmpty(t, toolsetProp.Enum, "toolset property should have enum values") + + // Verify repos is in the enum + var foundRepos bool + for _, v := range toolsetProp.Enum { + if v == registry.ToolsetID("repos") { + foundRepos = true + break + } + } + assert.True(t, foundRepos, "repos should be in toolset enum for %s", tool.Tool.Name) + } + } +} diff --git a/pkg/github/gists.go b/pkg/github/gists.go index 03e5e1bc8..7b8313f37 100644 --- a/pkg/github/gists.go +++ b/pkg/github/gists.go @@ -7,7 +7,7 @@ import ( "io" "net/http" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -16,7 +16,7 @@ import ( ) // ListGists creates a tool to list gists for a user -func ListGists(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListGists(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataGists, mcp.Tool{ @@ -104,7 +104,7 @@ func ListGists(t translations.TranslationHelperFunc) toolsets.ServerTool { } // GetGist creates a tool to get the content of a gist -func GetGist(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetGist(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataGists, mcp.Tool{ @@ -163,7 +163,7 @@ func GetGist(t translations.TranslationHelperFunc) toolsets.ServerTool { } // CreateGist creates a tool to create a new gist -func CreateGist(t translations.TranslationHelperFunc) toolsets.ServerTool { +func CreateGist(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataGists, mcp.Tool{ @@ -267,7 +267,7 @@ func CreateGist(t translations.TranslationHelperFunc) toolsets.ServerTool { } // UpdateGist creates a tool to edit an existing gist -func UpdateGist(t translations.TranslationHelperFunc) toolsets.ServerTool { +func UpdateGist(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataGists, mcp.Tool{ diff --git a/pkg/github/gists_test.go b/pkg/github/gists_test.go index 44b294eb6..7c6f69833 100644 --- a/pkg/github/gists_test.go +++ b/pkg/github/gists_test.go @@ -158,8 +158,8 @@ func Test_ListGists(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -275,8 +275,8 @@ func Test_GetGist(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -421,8 +421,8 @@ func Test_CreateGist(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -580,8 +580,8 @@ func Test_UpdateGist(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) diff --git a/pkg/github/git.go b/pkg/github/git.go index e619afc34..4755e2eb0 100644 --- a/pkg/github/git.go +++ b/pkg/github/git.go @@ -7,7 +7,7 @@ import ( "strings" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -38,7 +38,7 @@ type TreeResponse struct { } // GetRepositoryTree creates a tool to get the tree structure of a GitHub repository. -func GetRepositoryTree(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetRepositoryTree(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataGit, mcp.Tool{ diff --git a/pkg/github/git_test.go b/pkg/github/git_test.go index 69442e312..c971995b2 100644 --- a/pkg/github/git_test.go +++ b/pkg/github/git_test.go @@ -148,8 +148,8 @@ func Test_GetRepositoryTree(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) diff --git a/pkg/github/helper_test.go b/pkg/github/helper_test.go index 9c55ba841..4fdf5f928 100644 --- a/pkg/github/helper_test.go +++ b/pkg/github/helper_test.go @@ -216,7 +216,7 @@ func TestOptionalParamOK(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Test with string type assertion if _, isString := tc.expectedVal.(string); isString || tc.errorMsg == "parameter myParam is not of type string, is bool" { - val, ok, err := OptionalParamOK[string, map[string]any](tc.args, tc.paramName) + val, ok, err := OptionalParamOK[string](tc.args, tc.paramName) if tc.expectError { require.Error(t, err) assert.Contains(t, err.Error(), tc.errorMsg) @@ -231,7 +231,7 @@ func TestOptionalParamOK(t *testing.T) { // Test with bool type assertion if _, isBool := tc.expectedVal.(bool); isBool || tc.errorMsg == "parameter myParam is not of type bool, is string" { - val, ok, err := OptionalParamOK[bool, map[string]any](tc.args, tc.paramName) + val, ok, err := OptionalParamOK[bool](tc.args, tc.paramName) if tc.expectError { require.Error(t, err) assert.Contains(t, err.Error(), tc.errorMsg) @@ -246,7 +246,7 @@ func TestOptionalParamOK(t *testing.T) { // Test with float64 type assertion (for number case) if _, isFloat := tc.expectedVal.(float64); isFloat { - val, ok, err := OptionalParamOK[float64, map[string]any](tc.args, tc.paramName) + val, ok, err := OptionalParamOK[float64](tc.args, tc.paramName) if tc.expectError { // This case shouldn't happen for float64 in the defined tests require.Fail(t, "Unexpected error case for float64") diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 1d0e3b2d5..3e449f8c5 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -11,8 +11,8 @@ import ( ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/lockdown" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/sanitize" - "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/go-viper/mapstructure/v2" @@ -230,7 +230,7 @@ func fragmentToIssue(fragment IssueFragment) *github.Issue { } // IssueRead creates a tool to get details of a specific issue in a GitHub repository. -func IssueRead(t translations.TranslationHelperFunc) toolsets.ServerTool { +func IssueRead(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -310,13 +310,13 @@ Options are: switch method { case "get": - result, err := GetIssue(ctx, client, deps.RepoAccessCache, owner, repo, issueNumber, deps.Flags) + result, err := GetIssue(ctx, client, deps.GetRepoAccessCache(), owner, repo, issueNumber, deps.GetFlags()) return result, nil, err case "get_comments": - result, err := GetIssueComments(ctx, client, deps.RepoAccessCache, owner, repo, issueNumber, pagination, deps.Flags) + result, err := GetIssueComments(ctx, client, deps.GetRepoAccessCache(), owner, repo, issueNumber, pagination, deps.GetFlags()) return result, nil, err case "get_sub_issues": - result, err := GetSubIssues(ctx, client, deps.RepoAccessCache, owner, repo, issueNumber, pagination, deps.Flags) + result, err := GetSubIssues(ctx, client, deps.GetRepoAccessCache(), owner, repo, issueNumber, pagination, deps.GetFlags()) return result, nil, err case "get_labels": result, err := GetIssueLabels(ctx, gqlClient, owner, repo, issueNumber) @@ -545,7 +545,7 @@ func GetIssueLabels(ctx context.Context, client *githubv4.Client, owner string, } // ListIssueTypes creates a tool to list defined issue types for an organization. This can be used to understand supported issue type values for creating or updating issues. -func ListIssueTypes(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListIssueTypes(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataIssues, mcp.Tool{ @@ -602,7 +602,7 @@ func ListIssueTypes(t translations.TranslationHelperFunc) toolsets.ServerTool { } // AddIssueComment creates a tool to add a comment to an issue. -func AddIssueComment(t translations.TranslationHelperFunc) toolsets.ServerTool { +func AddIssueComment(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataIssues, mcp.Tool{ @@ -687,7 +687,7 @@ func AddIssueComment(t translations.TranslationHelperFunc) toolsets.ServerTool { } // SubIssueWrite creates a tool to add a sub-issue to a parent issue. -func SubIssueWrite(t translations.TranslationHelperFunc) toolsets.ServerTool { +func SubIssueWrite(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataIssues, mcp.Tool{ @@ -916,7 +916,7 @@ func ReprioritizeSubIssue(ctx context.Context, client *github.Client, owner stri } // SearchIssues creates a tool to search for issues. -func SearchIssues(t translations.TranslationHelperFunc) toolsets.ServerTool { +func SearchIssues(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -979,7 +979,7 @@ func SearchIssues(t translations.TranslationHelperFunc) toolsets.ServerTool { } // IssueWrite creates a tool to create a new or update an existing issue in a GitHub repository. -func IssueWrite(t translations.TranslationHelperFunc) toolsets.ServerTool { +func IssueWrite(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataIssues, mcp.Tool{ @@ -1338,7 +1338,7 @@ func UpdateIssue(ctx context.Context, client *github.Client, gqlClient *githubv4 } // ListIssues creates a tool to list and filter repository issues -func ListIssues(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListIssues(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1606,7 +1606,7 @@ func (d *mvpDescription) String() string { return sb.String() } -func AssignCopilotToIssue(t translations.TranslationHelperFunc) toolsets.ServerTool { +func AssignCopilotToIssue(t translations.TranslationHelperFunc) registry.ServerTool { description := mvpDescription{ summary: "Assign Copilot to a specific issue in a GitHub repository.", outcomes: []string{ @@ -1805,8 +1805,8 @@ func parseISOTimestamp(timestamp string) (time.Time, error) { return time.Time{}, fmt.Errorf("invalid ISO 8601 timestamp: %s (supported formats: YYYY-MM-DDThh:mm:ssZ or YYYY-MM-DD)", timestamp) } -func AssignCodingAgentPrompt(t translations.TranslationHelperFunc) toolsets.ServerPrompt { - return toolsets.NewServerPrompt( +func AssignCodingAgentPrompt(t translations.TranslationHelperFunc) registry.ServerPrompt { + return registry.NewServerPrompt( ToolsetMetadataIssues, mcp.Prompt{ Name: "AssignCodingAgent", diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index c832f031a..4c686cc57 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -330,9 +330,9 @@ func Test_GetIssue(t *testing.T) { } flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, RepoAccessCache: cache, Flags: flags, } @@ -447,8 +447,8 @@ func Test_AddIssueComment(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -781,8 +781,8 @@ func Test_SearchIssues(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -952,9 +952,9 @@ func Test_CreateIssue(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) gqlClient := githubv4.NewClient(nil) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, } handler := serverTool.Handler(deps) @@ -1268,8 +1268,8 @@ func Test_ListIssues(t *testing.T) { } gqlClient := githubv4.NewClient(httpClient) - deps := ToolDependencies{ - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + GQLClient: gqlClient, } handler := serverTool.Handler(deps) @@ -1769,9 +1769,9 @@ func Test_UpdateIssue(t *testing.T) { // Setup clients with mocks restClient := github.NewClient(tc.mockedRESTClient) gqlClient := githubv4.NewClient(tc.mockedGQLClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(restClient), - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + Client: restClient, + GQLClient: gqlClient, } handler := serverTool.Handler(deps) @@ -2016,9 +2016,9 @@ func Test_GetIssueComments(t *testing.T) { } cache := stubRepoAccessCache(gqlClient, 15*time.Minute) flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, RepoAccessCache: cache, Flags: flags, } @@ -2136,9 +2136,9 @@ func Test_GetIssueLabels(t *testing.T) { t.Run(tc.name, func(t *testing.T) { gqlClient := githubv4.NewClient(tc.mockedClient) client := github.NewClient(nil) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, RepoAccessCache: stubRepoAccessCache(gqlClient, 15*time.Minute), Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), } @@ -2560,8 +2560,8 @@ func TestAssignCopilotToIssue(t *testing.T) { t.Parallel() // Setup client with mock client := githubv4.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetGQLClient: stubGetGQLClientFn(client), + deps := BaseDeps{ + GQLClient: client, } handler := serverTool.Handler(deps) @@ -2791,8 +2791,8 @@ func Test_AddSubIssue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -3035,9 +3035,9 @@ func Test_GetSubIssues(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) gqlClient := githubv4.NewClient(nil) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, RepoAccessCache: stubRepoAccessCache(gqlClient, 15*time.Minute), Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), } @@ -3275,8 +3275,8 @@ func Test_RemoveSubIssue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -3564,8 +3564,8 @@ func Test_ReprioritizeSubIssue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -3698,8 +3698,8 @@ func Test_ListIssueTypes(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) diff --git a/pkg/github/labels.go b/pkg/github/labels.go index a98468fae..6088aaa8e 100644 --- a/pkg/github/labels.go +++ b/pkg/github/labels.go @@ -7,7 +7,7 @@ import ( "strings" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/jsonschema-go/jsonschema" @@ -16,9 +16,9 @@ import ( ) // GetLabel retrieves a specific label by name from a GitHub repository -func GetLabel(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetLabel(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( - ToolsetLabels, + ToolsetMetadataIssues, mcp.Tool{ Name: "get_label", Description: t("TOOL_GET_LABEL_DESCRIPTION", "Get a specific label from a repository."), @@ -110,8 +110,16 @@ func GetLabel(t translations.TranslationHelperFunc) toolsets.ServerTool { ) } +// GetLabelForLabelsToolset returns the same GetLabel tool but registered in the labels toolset. +// This provides conformance with the original behavior where get_label was in both toolsets. +func GetLabelForLabelsToolset(t translations.TranslationHelperFunc) registry.ServerTool { + tool := GetLabel(t) + tool.Toolset = ToolsetLabels + return tool +} + // ListLabels lists labels from a repository -func ListLabels(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListLabels(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetLabels, mcp.Tool{ @@ -203,7 +211,7 @@ func ListLabels(t translations.TranslationHelperFunc) toolsets.ServerTool { } // LabelWrite handles create, update, and delete operations for GitHub labels -func LabelWrite(t translations.TranslationHelperFunc) toolsets.ServerTool { +func LabelWrite(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetLabels, mcp.Tool{ diff --git a/pkg/github/labels_test.go b/pkg/github/labels_test.go index 980395ff7..fa646e884 100644 --- a/pkg/github/labels_test.go +++ b/pkg/github/labels_test.go @@ -114,8 +114,8 @@ func TestGetLabel(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := githubv4.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetGQLClient: stubGetGQLClientFn(client), + deps := BaseDeps{ + GQLClient: client, } handler := serverTool.Handler(deps) @@ -212,8 +212,8 @@ func TestListLabels(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := githubv4.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetGQLClient: stubGetGQLClientFn(client), + deps := BaseDeps{ + GQLClient: client, } handler := serverTool.Handler(deps) @@ -463,8 +463,8 @@ func TestWriteLabel(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := githubv4.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetGQLClient: stubGetGQLClientFn(client), + deps := BaseDeps{ + GQLClient: client, } handler := serverTool.Handler(deps) diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go index 4eb2d7b5b..569bef002 100644 --- a/pkg/github/notifications.go +++ b/pkg/github/notifications.go @@ -10,7 +10,7 @@ import ( "time" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -25,7 +25,7 @@ const ( ) // ListNotifications creates a tool to list notifications for the current user. -func ListNotifications(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListNotifications(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataNotifications, mcp.Tool{ @@ -163,7 +163,7 @@ func ListNotifications(t translations.TranslationHelperFunc) toolsets.ServerTool } // DismissNotification creates a tool to mark a notification as read/done. -func DismissNotification(t translations.TranslationHelperFunc) toolsets.ServerTool { +func DismissNotification(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataNotifications, mcp.Tool{ @@ -246,7 +246,7 @@ func DismissNotification(t translations.TranslationHelperFunc) toolsets.ServerTo } // MarkAllNotificationsRead creates a tool to mark all notifications as read. -func MarkAllNotificationsRead(t translations.TranslationHelperFunc) toolsets.ServerTool { +func MarkAllNotificationsRead(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataNotifications, mcp.Tool{ @@ -339,7 +339,7 @@ func MarkAllNotificationsRead(t translations.TranslationHelperFunc) toolsets.Ser } // GetNotificationDetails creates a tool to get details for a specific notification. -func GetNotificationDetails(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetNotificationDetails(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataNotifications, mcp.Tool{ @@ -409,7 +409,7 @@ const ( ) // ManageNotificationSubscription creates a tool to manage a notification subscription (ignore, watch, delete) -func ManageNotificationSubscription(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ManageNotificationSubscription(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataNotifications, mcp.Tool{ @@ -506,7 +506,7 @@ const ( ) // ManageRepositoryNotificationSubscription creates a tool to manage a repository notification subscription (ignore, watch, delete) -func ManageRepositoryNotificationSubscription(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ManageRepositoryNotificationSubscription(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataNotifications, mcp.Tool{ diff --git a/pkg/github/notifications_test.go b/pkg/github/notifications_test.go index 0a330c316..f730654db 100644 --- a/pkg/github/notifications_test.go +++ b/pkg/github/notifications_test.go @@ -125,8 +125,8 @@ func Test_ListNotifications(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -258,8 +258,8 @@ func Test_ManageNotificationSubscription(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -421,8 +421,8 @@ func Test_ManageRepositoryNotificationSubscription(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -563,8 +563,8 @@ func Test_DismissNotification(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -688,8 +688,8 @@ func Test_MarkAllNotificationsRead(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -772,8 +772,8 @@ func Test_GetNotificationDetails(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/projects.go b/pkg/github/projects.go index a12aca7be..4c0b5c09d 100644 --- a/pkg/github/projects.go +++ b/pkg/github/projects.go @@ -9,7 +9,7 @@ import ( "strings" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -25,7 +25,7 @@ const ( MaxProjectsPerPage = 50 ) -func ListProjects(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListProjects(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataProjects, mcp.Tool{ @@ -144,7 +144,7 @@ func ListProjects(t translations.TranslationHelperFunc) toolsets.ServerTool { ) } -func GetProject(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetProject(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataProjects, mcp.Tool{ @@ -234,7 +234,7 @@ func GetProject(t translations.TranslationHelperFunc) toolsets.ServerTool { ) } -func ListProjectFields(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListProjectFields(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataProjects, mcp.Tool{ @@ -342,7 +342,7 @@ func ListProjectFields(t translations.TranslationHelperFunc) toolsets.ServerTool ) } -func GetProjectField(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetProjectField(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataProjects, mcp.Tool{ @@ -436,7 +436,7 @@ func GetProjectField(t translations.TranslationHelperFunc) toolsets.ServerTool { ) } -func ListProjectItems(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListProjectItems(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataProjects, mcp.Tool{ @@ -574,7 +574,7 @@ func ListProjectItems(t translations.TranslationHelperFunc) toolsets.ServerTool ) } -func GetProjectItem(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetProjectItem(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataProjects, mcp.Tool{ @@ -682,7 +682,7 @@ func GetProjectItem(t translations.TranslationHelperFunc) toolsets.ServerTool { ) } -func AddProjectItem(t translations.TranslationHelperFunc) toolsets.ServerTool { +func AddProjectItem(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataProjects, mcp.Tool{ @@ -795,7 +795,7 @@ func AddProjectItem(t translations.TranslationHelperFunc) toolsets.ServerTool { ) } -func UpdateProjectItem(t translations.TranslationHelperFunc) toolsets.ServerTool { +func UpdateProjectItem(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataProjects, mcp.Tool{ @@ -909,7 +909,7 @@ func UpdateProjectItem(t translations.TranslationHelperFunc) toolsets.ServerTool ) } -func DeleteProjectItem(t translations.TranslationHelperFunc) toolsets.ServerTool { +func DeleteProjectItem(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataProjects, mcp.Tool{ diff --git a/pkg/github/projects_test.go b/pkg/github/projects_test.go index 0c2e2ab52..67ecd8800 100644 --- a/pkg/github/projects_test.go +++ b/pkg/github/projects_test.go @@ -141,8 +141,8 @@ func Test_ListProjects(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -280,8 +280,8 @@ func Test_GetProject(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -432,8 +432,8 @@ func Test_ListProjectFields(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -592,8 +592,8 @@ func Test_GetProjectField(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -798,8 +798,8 @@ func Test_ListProjectItems(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -995,8 +995,8 @@ func Test_GetProjectItem(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -1224,8 +1224,8 @@ func Test_AddProjectItem(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -1509,8 +1509,8 @@ func Test_UpdateProjectItem(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -1676,8 +1676,8 @@ func Test_DeleteProjectItem(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/prompts.go b/pkg/github/prompts.go index 82d7bf514..229902d90 100644 --- a/pkg/github/prompts.go +++ b/pkg/github/prompts.go @@ -1,14 +1,14 @@ package github import ( - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" ) // AllPrompts returns all prompts with their embedded toolset metadata. // Prompt functions return ServerPrompt directly with toolset info. -func AllPrompts(t translations.TranslationHelperFunc) []toolsets.ServerPrompt { - return []toolsets.ServerPrompt{ +func AllPrompts(t translations.TranslationHelperFunc) []registry.ServerPrompt { + return []registry.ServerPrompt{ // Issue prompts AssignCodingAgentPrompt(t), IssueToFixWorkflowPrompt(t), diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 229e20e57..4e7ede755 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -15,14 +15,14 @@ import ( ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/lockdown" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/sanitize" - "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" ) // PullRequestRead creates a tool to get details of a specific pull request. -func PullRequestRead(t translations.TranslationHelperFunc) toolsets.ServerTool { +func PullRequestRead(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -99,7 +99,7 @@ Possible options: switch method { case "get": - result, err := GetPullRequest(ctx, client, deps.RepoAccessCache, owner, repo, pullNumber, deps.Flags) + result, err := GetPullRequest(ctx, client, deps.GetRepoAccessCache(), owner, repo, pullNumber, deps.GetFlags()) return result, nil, err case "get_diff": result, err := GetPullRequestDiff(ctx, client, owner, repo, pullNumber) @@ -111,13 +111,13 @@ Possible options: result, err := GetPullRequestFiles(ctx, client, owner, repo, pullNumber, pagination) return result, nil, err case "get_review_comments": - result, err := GetPullRequestReviewComments(ctx, client, deps.RepoAccessCache, owner, repo, pullNumber, pagination, deps.Flags) + result, err := GetPullRequestReviewComments(ctx, client, deps.GetRepoAccessCache(), owner, repo, pullNumber, pagination, deps.GetFlags()) return result, nil, err case "get_reviews": - result, err := GetPullRequestReviews(ctx, client, deps.RepoAccessCache, owner, repo, pullNumber, deps.Flags) + result, err := GetPullRequestReviews(ctx, client, deps.GetRepoAccessCache(), owner, repo, pullNumber, deps.GetFlags()) return result, nil, err case "get_comments": - result, err := GetIssueComments(ctx, client, deps.RepoAccessCache, owner, repo, pullNumber, pagination, deps.Flags) + result, err := GetIssueComments(ctx, client, deps.GetRepoAccessCache(), owner, repo, pullNumber, pagination, deps.GetFlags()) return result, nil, err default: return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil, nil @@ -390,7 +390,7 @@ func GetPullRequestReviews(ctx context.Context, client *github.Client, cache *lo } // CreatePullRequest creates a tool to create a new pull request. -func CreatePullRequest(t translations.TranslationHelperFunc) toolsets.ServerTool { +func CreatePullRequest(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -531,7 +531,7 @@ func CreatePullRequest(t translations.TranslationHelperFunc) toolsets.ServerTool } // UpdatePullRequest creates a tool to update an existing pull request. -func UpdatePullRequest(t translations.TranslationHelperFunc) toolsets.ServerTool { +func UpdatePullRequest(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -826,7 +826,7 @@ func UpdatePullRequest(t translations.TranslationHelperFunc) toolsets.ServerTool } // ListPullRequests creates a tool to list and filter repository pull requests. -func ListPullRequests(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListPullRequests(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -970,7 +970,7 @@ func ListPullRequests(t translations.TranslationHelperFunc) toolsets.ServerTool } // MergePullRequest creates a tool to merge a pull request. -func MergePullRequest(t translations.TranslationHelperFunc) toolsets.ServerTool { +func MergePullRequest(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1079,7 +1079,7 @@ func MergePullRequest(t translations.TranslationHelperFunc) toolsets.ServerTool } // SearchPullRequests creates a tool to search for pull requests. -func SearchPullRequests(t translations.TranslationHelperFunc) toolsets.ServerTool { +func SearchPullRequests(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1142,7 +1142,7 @@ func SearchPullRequests(t translations.TranslationHelperFunc) toolsets.ServerToo } // UpdatePullRequestBranch creates a tool to update a pull request branch with the latest changes from the base branch. -func UpdatePullRequestBranch(t translations.TranslationHelperFunc) toolsets.ServerTool { +func UpdatePullRequestBranch(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1247,7 +1247,7 @@ type PullRequestReviewWriteParams struct { CommitID *string } -func PullRequestReviewWrite(t translations.TranslationHelperFunc) toolsets.ServerTool { +func PullRequestReviewWrite(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1559,7 +1559,7 @@ func DeletePendingPullRequestReview(ctx context.Context, client *githubv4.Client } // AddCommentToPendingReview creates a tool to add a comment to a pull request review. -func AddCommentToPendingReview(t translations.TranslationHelperFunc) toolsets.ServerTool { +func AddCommentToPendingReview(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1752,7 +1752,7 @@ func AddCommentToPendingReview(t translations.TranslationHelperFunc) toolsets.Se // RequestCopilotReview creates a tool to request a Copilot review for a pull request. // Note that this tool will not work on GHES where this feature is unsupported. In future, we should not expose this // tool if the configured host does not support it. -func RequestCopilotReview(t translations.TranslationHelperFunc) toolsets.ServerTool { +func RequestCopilotReview(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 7531edf6d..a22245700 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -105,9 +105,9 @@ func Test_GetPullRequest(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) gqlClient := githubv4.NewClient(githubv4mock.NewMockedHTTPClient()) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, RepoAccessCache: stubRepoAccessCache(gqlClient, 5*time.Minute), Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), } @@ -370,9 +370,9 @@ func Test_UpdatePullRequest(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) gqlClient := githubv4.NewClient(nil) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, } handler := serverTool.Handler(deps) @@ -561,9 +561,9 @@ func Test_UpdatePullRequest_Draft(t *testing.T) { gqlClient := githubv4.NewClient(tc.mockedClient) serverTool := UpdatePullRequest(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(restClient), - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + Client: restClient, + GQLClient: gqlClient, } handler := serverTool.Handler(deps) @@ -693,8 +693,8 @@ func Test_ListPullRequests(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) serverTool := ListPullRequests(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -817,8 +817,8 @@ func Test_MergePullRequest(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) serverTool := MergePullRequest(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -1123,8 +1123,8 @@ func Test_SearchPullRequests(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) serverTool := SearchPullRequests(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -1276,8 +1276,8 @@ func Test_GetPullRequestFiles(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) serverTool := PullRequestRead(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, RepoAccessCache: stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), } @@ -1451,8 +1451,8 @@ func Test_GetPullRequestStatus(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) serverTool := PullRequestRead(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, RepoAccessCache: stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), } @@ -1589,8 +1589,8 @@ func Test_UpdatePullRequestBranch(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) serverTool := UpdatePullRequestBranch(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -1763,8 +1763,8 @@ func Test_GetPullRequestComments(t *testing.T) { cache := stubRepoAccessCache(gqlClient, 5*time.Minute) flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) serverTool := PullRequestRead(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, RepoAccessCache: cache, Flags: flags, } @@ -1951,8 +1951,8 @@ func Test_GetPullRequestReviews(t *testing.T) { cache := stubRepoAccessCache(gqlClient, 5*time.Minute) flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) serverTool := PullRequestRead(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, RepoAccessCache: cache, Flags: flags, } @@ -2115,8 +2115,8 @@ func Test_CreatePullRequest(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) serverTool := CreatePullRequest(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -2331,8 +2331,8 @@ func TestCreateAndSubmitPullRequestReview(t *testing.T) { // Setup client with mock client := githubv4.NewClient(tc.mockedClient) serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetGQLClient: stubGetGQLClientFn(client), + deps := BaseDeps{ + GQLClient: client, } handler := serverTool.Handler(deps) @@ -2447,8 +2447,8 @@ func Test_RequestCopilotReview(t *testing.T) { client := github.NewClient(tc.mockedClient) serverTool := RequestCopilotReview(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -2641,8 +2641,8 @@ func TestCreatePendingPullRequestReview(t *testing.T) { // Setup client with mock client := githubv4.NewClient(tc.mockedClient) serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetGQLClient: stubGetGQLClientFn(client), + deps := BaseDeps{ + GQLClient: client, } handler := serverTool.Handler(deps) @@ -2824,8 +2824,8 @@ func TestAddPullRequestReviewCommentToPendingReview(t *testing.T) { // Setup client with mock client := githubv4.NewClient(tc.mockedClient) serverTool := AddCommentToPendingReview(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetGQLClient: stubGetGQLClientFn(client), + deps := BaseDeps{ + GQLClient: client, } handler := serverTool.Handler(deps) @@ -2929,8 +2929,8 @@ func TestSubmitPendingPullRequestReview(t *testing.T) { // Setup client with mock client := githubv4.NewClient(tc.mockedClient) serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetGQLClient: stubGetGQLClientFn(client), + deps := BaseDeps{ + GQLClient: client, } handler := serverTool.Handler(deps) @@ -3028,8 +3028,8 @@ func TestDeletePendingPullRequestReview(t *testing.T) { // Setup client with mock client := githubv4.NewClient(tc.mockedClient) serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetGQLClient: stubGetGQLClientFn(client), + deps := BaseDeps{ + GQLClient: client, } handler := serverTool.Handler(deps) @@ -3119,8 +3119,8 @@ index 5d6e7b2..8a4f5c3 100644 // Setup client with mock client := github.NewClient(tc.mockedClient) serverTool := PullRequestRead(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, RepoAccessCache: stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), } diff --git a/pkg/github/toolset_group.go b/pkg/github/registry.go similarity index 78% rename from pkg/github/toolset_group.go rename to pkg/github/registry.go index 7330e08d3..88795ee1e 100644 --- a/pkg/github/toolset_group.go +++ b/pkg/github/registry.go @@ -1,7 +1,7 @@ package github import ( - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" ) @@ -10,8 +10,8 @@ import ( // This function is stateless - no dependencies are captured. // Handlers are generated on-demand during registration via RegisterAll(ctx, server, deps). // The "default" keyword in WithToolsets will expand to toolsets marked with Default: true. -func NewRegistry(t translations.TranslationHelperFunc) *toolsets.Registry { - return toolsets.NewRegistry(). +func NewRegistry(t translations.TranslationHelperFunc) *registry.Builder { + return registry.NewBuilder(). SetTools(AllTools(t)). SetResources(AllResources(t)). SetPrompts(AllPrompts(t)) diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index 81e5c3a8c..854da679a 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -11,7 +11,7 @@ import ( ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/raw" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -19,7 +19,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func GetCommit(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetCommit(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -118,7 +118,7 @@ func GetCommit(t translations.TranslationHelperFunc) toolsets.ServerTool { } // ListCommits creates a tool to get commits of a branch in a repository. -func ListCommits(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListCommits(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -227,7 +227,7 @@ func ListCommits(t translations.TranslationHelperFunc) toolsets.ServerTool { } // ListBranches creates a tool to list branches in a GitHub repository. -func ListBranches(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListBranches(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -315,7 +315,7 @@ func ListBranches(t translations.TranslationHelperFunc) toolsets.ServerTool { } // CreateOrUpdateFile creates a tool to create or update a file in a GitHub repository. -func CreateOrUpdateFile(t translations.TranslationHelperFunc) toolsets.ServerTool { +func CreateOrUpdateFile(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -443,7 +443,7 @@ func CreateOrUpdateFile(t translations.TranslationHelperFunc) toolsets.ServerToo } // CreateRepository creates a tool to create a new GitHub repository. -func CreateRepository(t translations.TranslationHelperFunc) toolsets.ServerTool { +func CreateRepository(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -550,7 +550,7 @@ func CreateRepository(t translations.TranslationHelperFunc) toolsets.ServerTool } // GetFileContents creates a tool to get the contents of a file or directory from a GitHub repository. -func GetFileContents(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetFileContents(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -770,7 +770,7 @@ func GetFileContents(t translations.TranslationHelperFunc) toolsets.ServerTool { } // ForkRepository creates a tool to fork a repository. -func ForkRepository(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ForkRepository(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -869,7 +869,7 @@ func ForkRepository(t translations.TranslationHelperFunc) toolsets.ServerTool { // unlike how the endpoint backing the create_or_update_files tool does. This appears to be a quirk of the API. // The approach implemented here gets automatic commit signing when used with either the github-actions user or as an app, // both of which suit an LLM well. -func DeleteFile(t translations.TranslationHelperFunc) toolsets.ServerTool { +func DeleteFile(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -1055,7 +1055,7 @@ func DeleteFile(t translations.TranslationHelperFunc) toolsets.ServerTool { } // CreateBranch creates a tool to create a new branch. -func CreateBranch(t translations.TranslationHelperFunc) toolsets.ServerTool { +func CreateBranch(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -1169,7 +1169,7 @@ func CreateBranch(t translations.TranslationHelperFunc) toolsets.ServerTool { } // PushFiles creates a tool to push multiple files in a single commit to a GitHub repository. -func PushFiles(t translations.TranslationHelperFunc) toolsets.ServerTool { +func PushFiles(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -1354,7 +1354,7 @@ func PushFiles(t translations.TranslationHelperFunc) toolsets.ServerTool { } // ListTags creates a tool to list tags in a GitHub repository. -func ListTags(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListTags(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -1434,7 +1434,7 @@ func ListTags(t translations.TranslationHelperFunc) toolsets.ServerTool { } // GetTag creates a tool to get details about a specific tag in a GitHub repository. -func GetTag(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetTag(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -1533,7 +1533,7 @@ func GetTag(t translations.TranslationHelperFunc) toolsets.ServerTool { } // ListReleases creates a tool to list releases in a GitHub repository. -func ListReleases(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListReleases(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -1609,7 +1609,7 @@ func ListReleases(t translations.TranslationHelperFunc) toolsets.ServerTool { } // GetLatestRelease creates a tool to get the latest release in a GitHub repository. -func GetLatestRelease(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetLatestRelease(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -1675,7 +1675,7 @@ func GetLatestRelease(t translations.TranslationHelperFunc) toolsets.ServerTool ) } -func GetReleaseByTag(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetReleaseByTag(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -1889,7 +1889,7 @@ func resolveGitReference(ctx context.Context, githubClient *github.Client, owner } // ListStarredRepositories creates a tool to list starred repositories for the authenticated user or a specified user. -func ListStarredRepositories(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListStarredRepositories(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataStargazers, mcp.Tool{ @@ -2022,7 +2022,7 @@ func ListStarredRepositories(t translations.TranslationHelperFunc) toolsets.Serv } // StarRepository creates a tool to star a repository. -func StarRepository(t translations.TranslationHelperFunc) toolsets.ServerTool { +func StarRepository(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataStargazers, mcp.Tool{ @@ -2088,7 +2088,7 @@ func StarRepository(t translations.TranslationHelperFunc) toolsets.ServerTool { } // UnstarRepository creates a tool to unstar a repository. -func UnstarRepository(t translations.TranslationHelperFunc) toolsets.ServerTool { +func UnstarRepository(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataStargazers, mcp.Tool{ diff --git a/pkg/github/repositories_test.go b/pkg/github/repositories_test.go index 949686d92..55f0866cb 100644 --- a/pkg/github/repositories_test.go +++ b/pkg/github/repositories_test.go @@ -286,9 +286,9 @@ func Test_GetFileContents(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) mockRawClient := raw.NewClient(client, &url.URL{Scheme: "https", Host: "raw.example.com", Path: "/"}) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), - GetRawClient: stubGetRawClientFn(mockRawClient), + deps := BaseDeps{ + Client: client, + RawClient: mockRawClient, } handler := serverTool.Handler(deps) @@ -410,8 +410,8 @@ func Test_ForkRepository(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -606,8 +606,8 @@ func Test_CreateBranch(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -738,8 +738,8 @@ func Test_GetCommit(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -964,8 +964,8 @@ func Test_ListCommits(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -1143,8 +1143,8 @@ func Test_CreateOrUpdateFile(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -1329,8 +1329,8 @@ func Test_CreateRepository(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -1668,8 +1668,8 @@ func Test_PushFiles(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -1789,8 +1789,8 @@ func Test_ListBranches(t *testing.T) { t.Run(tt.name, func(t *testing.T) { // Create mock client mockClient := github.NewClient(mock.NewMockedHTTPClient(tt.mockResponses...)) - deps := ToolDependencies{ - GetClient: stubGetClientFn(mockClient), + deps := BaseDeps{ + Client: mockClient, } handler := serverTool.Handler(deps) @@ -1977,8 +1977,8 @@ func Test_DeleteFile(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -2104,8 +2104,8 @@ func Test_ListTags(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -2264,8 +2264,8 @@ func Test_GetTag(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -2377,8 +2377,8 @@ func Test_ListReleases(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -2468,8 +2468,8 @@ func Test_GetLatestRelease(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -2616,8 +2616,8 @@ func Test_GetReleaseByTag(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -3128,8 +3128,8 @@ func Test_ListStarredRepositories(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -3229,8 +3229,8 @@ func Test_StarRepository(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -3320,8 +3320,8 @@ func Test_UnstarRepository(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) diff --git a/pkg/github/repository_resource.go b/pkg/github/repository_resource.go index 6dbbe90ec..af001af6f 100644 --- a/pkg/github/repository_resource.go +++ b/pkg/github/repository_resource.go @@ -14,7 +14,7 @@ import ( "strings" "github.com/github/github-mcp-server/pkg/raw" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -30,8 +30,8 @@ var ( ) // GetRepositoryResourceContent defines the resource template for getting repository content. -func GetRepositoryResourceContent(t translations.TranslationHelperFunc) toolsets.ServerResourceTemplate { - return toolsets.NewServerResourceTemplate( +func GetRepositoryResourceContent(t translations.TranslationHelperFunc) registry.ServerResourceTemplate { + return registry.NewServerResourceTemplate( ToolsetMetadataRepos, mcp.ResourceTemplate{ Name: "repository_content", @@ -43,8 +43,8 @@ func GetRepositoryResourceContent(t translations.TranslationHelperFunc) toolsets } // GetRepositoryResourceBranchContent defines the resource template for getting repository content for a branch. -func GetRepositoryResourceBranchContent(t translations.TranslationHelperFunc) toolsets.ServerResourceTemplate { - return toolsets.NewServerResourceTemplate( +func GetRepositoryResourceBranchContent(t translations.TranslationHelperFunc) registry.ServerResourceTemplate { + return registry.NewServerResourceTemplate( ToolsetMetadataRepos, mcp.ResourceTemplate{ Name: "repository_content_branch", @@ -56,8 +56,8 @@ func GetRepositoryResourceBranchContent(t translations.TranslationHelperFunc) to } // GetRepositoryResourceCommitContent defines the resource template for getting repository content for a commit. -func GetRepositoryResourceCommitContent(t translations.TranslationHelperFunc) toolsets.ServerResourceTemplate { - return toolsets.NewServerResourceTemplate( +func GetRepositoryResourceCommitContent(t translations.TranslationHelperFunc) registry.ServerResourceTemplate { + return registry.NewServerResourceTemplate( ToolsetMetadataRepos, mcp.ResourceTemplate{ Name: "repository_content_commit", @@ -69,8 +69,8 @@ func GetRepositoryResourceCommitContent(t translations.TranslationHelperFunc) to } // GetRepositoryResourceTagContent defines the resource template for getting repository content for a tag. -func GetRepositoryResourceTagContent(t translations.TranslationHelperFunc) toolsets.ServerResourceTemplate { - return toolsets.NewServerResourceTemplate( +func GetRepositoryResourceTagContent(t translations.TranslationHelperFunc) registry.ServerResourceTemplate { + return registry.NewServerResourceTemplate( ToolsetMetadataRepos, mcp.ResourceTemplate{ Name: "repository_content_tag", @@ -82,8 +82,8 @@ func GetRepositoryResourceTagContent(t translations.TranslationHelperFunc) tools } // GetRepositoryResourcePrContent defines the resource template for getting repository content for a pull request. -func GetRepositoryResourcePrContent(t translations.TranslationHelperFunc) toolsets.ServerResourceTemplate { - return toolsets.NewServerResourceTemplate( +func GetRepositoryResourcePrContent(t translations.TranslationHelperFunc) registry.ServerResourceTemplate { + return registry.NewServerResourceTemplate( ToolsetMetadataRepos, mcp.ResourceTemplate{ Name: "repository_content_pr", @@ -95,15 +95,15 @@ func GetRepositoryResourcePrContent(t translations.TranslationHelperFunc) toolse } // repositoryResourceContentsHandlerFunc returns a ResourceHandlerFunc that creates handlers on-demand. -func repositoryResourceContentsHandlerFunc(resourceURITemplate *uritemplate.Template) toolsets.ResourceHandlerFunc { +func repositoryResourceContentsHandlerFunc(resourceURITemplate *uritemplate.Template) registry.ResourceHandlerFunc { return func(deps any) mcp.ResourceHandler { d := deps.(ToolDependencies) - return RepositoryResourceContentsHandler(d.GetClient, d.GetRawClient, resourceURITemplate) + return RepositoryResourceContentsHandler(d, resourceURITemplate) } } // RepositoryResourceContentsHandler returns a handler function for repository content requests. -func RepositoryResourceContentsHandler(getClient GetClientFn, getRawClient raw.GetRawClientFn, resourceURITemplate *uritemplate.Template) mcp.ResourceHandler { +func RepositoryResourceContentsHandler(deps ToolDependencies, resourceURITemplate *uritemplate.Template) mcp.ResourceHandler { return func(ctx context.Context, request *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { // Match the URI to extract parameters uriValues := resourceURITemplate.Match(request.Params.URI) @@ -157,7 +157,7 @@ func RepositoryResourceContentsHandler(getClient GetClientFn, getRawClient raw.G prNumber := uriValues.Get("prNumber").String() if prNumber != "" { // fetch the PR from the API to get the latest commit and use SHA - githubClient, err := getClient(ctx) + githubClient, err := deps.GetClient(ctx) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -177,7 +177,7 @@ func RepositoryResourceContentsHandler(getClient GetClientFn, getRawClient raw.G if path == "" || strings.HasSuffix(path, "/") { return nil, fmt.Errorf("directories are not supported: %s", path) } - rawClient, err := getRawClient(ctx) + rawClient, err := deps.GetRawClient(ctx) if err != nil { return nil, fmt.Errorf("failed to get GitHub raw content client: %w", err) diff --git a/pkg/github/repository_resource_test.go b/pkg/github/repository_resource_test.go index f938a57f5..99c06cdd6 100644 --- a/pkg/github/repository_resource_test.go +++ b/pkg/github/repository_resource_test.go @@ -27,7 +27,7 @@ func Test_repositoryResourceContents(t *testing.T) { name string mockedClient *http.Client uri string - handlerFn func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler + handlerFn func(deps ToolDependencies) mcp.ResourceHandler expectedResponseType resourceResponseType expectError string expectedResult *mcp.ReadResourceResult @@ -45,8 +45,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo:///repo/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { - return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceContentURITemplate) + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceContentURITemplate) }, expectedResponseType: resourceResponseTypeText, // Ignored as error is expected expectError: "owner is required", @@ -64,8 +64,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner//refs/heads/main/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { - return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceBranchContentURITemplate) + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceBranchContentURITemplate) }, expectedResponseType: resourceResponseTypeText, // Ignored as error is expected expectError: "repo is required", @@ -83,8 +83,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/contents/data.png", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { - return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceContentURITemplate) + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceContentURITemplate) }, expectedResponseType: resourceResponseTypeBlob, expectedResult: &mcp.ReadResourceResult{ @@ -107,8 +107,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { - return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceContentURITemplate) + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -133,8 +133,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/contents/pkg/github/actions.go", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { - return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceContentURITemplate) + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -157,8 +157,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/refs/heads/main/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { - return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceBranchContentURITemplate) + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceBranchContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -181,8 +181,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/refs/tags/v1.0.0/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { - return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceTagContentURITemplate) + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceTagContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -205,8 +205,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/sha/abc123/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { - return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceCommitContentURITemplate) + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceCommitContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -237,8 +237,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/refs/pull/42/head/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { - return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourcePrContentURITemplate) + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourcePrContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -260,8 +260,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/contents/nonexistent.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { - return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceContentURITemplate) + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceContentURITemplate) }, expectedResponseType: resourceResponseTypeText, // Ignored as error is expected expectError: "404 Not Found", @@ -272,7 +272,11 @@ func Test_repositoryResourceContents(t *testing.T) { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) mockRawClient := raw.NewClient(client, base) - handler := tc.handlerFn(stubGetClientFn(client), stubGetRawClientFn(mockRawClient)) + deps := BaseDeps{ + Client: client, + RawClient: mockRawClient, + } + handler := tc.handlerFn(deps) request := &mcp.ReadResourceRequest{ Params: &mcp.ReadResourceParams{ diff --git a/pkg/github/resources.go b/pkg/github/resources.go index 253c4bc11..6acf5eb6a 100644 --- a/pkg/github/resources.go +++ b/pkg/github/resources.go @@ -1,14 +1,14 @@ package github import ( - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" ) // AllResources returns all resource templates with their embedded toolset metadata. // Resource definitions are stateless - handlers are generated on-demand during registration. -func AllResources(t translations.TranslationHelperFunc) []toolsets.ServerResourceTemplate { - return []toolsets.ServerResourceTemplate{ +func AllResources(t translations.TranslationHelperFunc) []registry.ServerResourceTemplate { + return []registry.ServerResourceTemplate{ // Repository resources GetRepositoryResourceContent(t), GetRepositoryResourceBranchContent(t), diff --git a/pkg/github/search.go b/pkg/github/search.go index 730435eba..4b35f3f0d 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -8,7 +8,7 @@ import ( "net/http" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -17,7 +17,7 @@ import ( ) // SearchRepositories creates a tool to search for GitHub repositories. -func SearchRepositories(t translations.TranslationHelperFunc) toolsets.ServerTool { +func SearchRepositories(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -167,7 +167,7 @@ func SearchRepositories(t translations.TranslationHelperFunc) toolsets.ServerToo } // SearchCode creates a tool to search for code across GitHub repositories. -func SearchCode(t translations.TranslationHelperFunc) toolsets.ServerTool { +func SearchCode(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -351,7 +351,7 @@ func userOrOrgHandler(accountType string, deps ToolDependencies) mcp.ToolHandler } // SearchUsers creates a tool to search for GitHub users. -func SearchUsers(t translations.TranslationHelperFunc) toolsets.ServerTool { +func SearchUsers(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -392,7 +392,7 @@ func SearchUsers(t translations.TranslationHelperFunc) toolsets.ServerTool { } // SearchOrgs creates a tool to search for GitHub organizations. -func SearchOrgs(t translations.TranslationHelperFunc) toolsets.ServerTool { +func SearchOrgs(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ diff --git a/pkg/github/search_test.go b/pkg/github/search_test.go index 41d12df1b..707b55349 100644 --- a/pkg/github/search_test.go +++ b/pkg/github/search_test.go @@ -134,8 +134,8 @@ func Test_SearchRepositories(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -209,8 +209,8 @@ func Test_SearchRepositories_FullOutput(t *testing.T) { client := github.NewClient(mockedClient) serverTool := SearchRepositories(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -358,8 +358,8 @@ func Test_SearchCode(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -558,8 +558,8 @@ func Test_SearchUsers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -733,8 +733,8 @@ func Test_SearchOrgs(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) diff --git a/pkg/github/secret_scanning.go b/pkg/github/secret_scanning.go index 7e842ded1..e840072b0 100644 --- a/pkg/github/secret_scanning.go +++ b/pkg/github/secret_scanning.go @@ -8,7 +8,7 @@ import ( "net/http" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -16,7 +16,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func GetSecretScanningAlert(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetSecretScanningAlert(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataSecretProtection, mcp.Tool{ @@ -94,7 +94,7 @@ func GetSecretScanningAlert(t translations.TranslationHelperFunc) toolsets.Serve ) } -func ListSecretScanningAlerts(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListSecretScanningAlerts(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataSecretProtection, mcp.Tool{ diff --git a/pkg/github/secret_scanning_test.go b/pkg/github/secret_scanning_test.go index 83de16409..b63617a46 100644 --- a/pkg/github/secret_scanning_test.go +++ b/pkg/github/secret_scanning_test.go @@ -87,8 +87,8 @@ func Test_GetSecretScanningAlert(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -228,8 +228,8 @@ func Test_ListSecretScanningAlerts(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) diff --git a/pkg/github/security_advisories.go b/pkg/github/security_advisories.go index cf507d17a..28acb8156 100644 --- a/pkg/github/security_advisories.go +++ b/pkg/github/security_advisories.go @@ -7,7 +7,7 @@ import ( "io" "net/http" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -15,7 +15,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func ListGlobalSecurityAdvisories(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListGlobalSecurityAdvisories(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataSecurityAdvisories, mcp.Tool{ @@ -207,7 +207,7 @@ func ListGlobalSecurityAdvisories(t translations.TranslationHelperFunc) toolsets ) } -func ListRepositorySecurityAdvisories(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListRepositorySecurityAdvisories(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataSecurityAdvisories, mcp.Tool{ @@ -312,7 +312,7 @@ func ListRepositorySecurityAdvisories(t translations.TranslationHelperFunc) tool ) } -func GetGlobalSecurityAdvisory(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetGlobalSecurityAdvisory(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataSecurityAdvisories, mcp.Tool{ @@ -370,7 +370,7 @@ func GetGlobalSecurityAdvisory(t translations.TranslationHelperFunc) toolsets.Se ) } -func ListOrgRepositorySecurityAdvisories(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListOrgRepositorySecurityAdvisories(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataSecurityAdvisories, mcp.Tool{ diff --git a/pkg/github/security_advisories_test.go b/pkg/github/security_advisories_test.go index 16506a3e8..3970949ec 100644 --- a/pkg/github/security_advisories_test.go +++ b/pkg/github/security_advisories_test.go @@ -103,7 +103,7 @@ func Test_ListGlobalSecurityAdvisories(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{GetClient: stubGetClientFn(client)} + deps := BaseDeps{Client: client} handler := toolDef.Handler(deps) // Create call request @@ -224,7 +224,7 @@ func Test_GetGlobalSecurityAdvisory(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{GetClient: stubGetClientFn(client)} + deps := BaseDeps{Client: client} handler := toolDef.Handler(deps) // Create call request @@ -372,7 +372,7 @@ func Test_ListRepositorySecurityAdvisories(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{GetClient: stubGetClientFn(client)} + deps := BaseDeps{Client: client} handler := toolDef.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -517,7 +517,7 @@ func Test_ListOrgRepositorySecurityAdvisories(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{GetClient: stubGetClientFn(client)} + deps := BaseDeps{Client: client} handler := toolDef.Handler(deps) request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/server.go b/pkg/github/server.go index 7432466d1..a9c9305a2 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -21,13 +21,6 @@ func NewServer(version string, opts *mcp.ServerOptions) *mcp.Server { opts = &mcp.ServerOptions{} } - // Always advertise capabilities so clients know we support list_changed notifications. - // This is important for dynamic toolsets mode where we start with few tools - // and add more at runtime. - opts.HasTools = true - opts.HasResources = true - opts.HasPrompts = true - // Create a new MCP server s := mcp.NewServer(&mcp.Implementation{ Name: "github-mcp-server", diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index 2e9ab43a3..a59cd9a93 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -11,32 +11,67 @@ import ( "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/raw" + "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" "github.com/shurcooL/githubv4" "github.com/stretchr/testify/assert" ) -func stubGetClientFn(client *github.Client) GetClientFn { - return func(_ context.Context) (*github.Client, error) { - return client, nil +// stubDeps is a test helper that implements ToolDependencies with configurable behavior. +// Use this when you need to test error paths or when you need closure-based client creation. +type stubDeps struct { + clientFn func(context.Context) (*github.Client, error) + gqlClientFn func(context.Context) (*githubv4.Client, error) + rawClientFn func(context.Context) (*raw.Client, error) + + repoAccessCache *lockdown.RepoAccessCache + t translations.TranslationHelperFunc + flags FeatureFlags + contentWindowSize int +} + +func (s stubDeps) GetClient(ctx context.Context) (*github.Client, error) { + if s.clientFn != nil { + return s.clientFn(ctx) } + return nil, nil } -func stubGetClientFromHTTPFn(client *http.Client) GetClientFn { +func (s stubDeps) GetGQLClient(ctx context.Context) (*githubv4.Client, error) { + if s.gqlClientFn != nil { + return s.gqlClientFn(ctx) + } + return nil, nil +} + +func (s stubDeps) GetRawClient(ctx context.Context) (*raw.Client, error) { + if s.rawClientFn != nil { + return s.rawClientFn(ctx) + } + return nil, nil +} + +func (s stubDeps) GetRepoAccessCache() *lockdown.RepoAccessCache { return s.repoAccessCache } +func (s stubDeps) GetT() translations.TranslationHelperFunc { return s.t } +func (s stubDeps) GetFlags() FeatureFlags { return s.flags } +func (s stubDeps) GetContentWindowSize() int { return s.contentWindowSize } + +// Helper functions to create stub client functions for error testing +func stubClientFnFromHTTP(httpClient *http.Client) func(context.Context) (*github.Client, error) { return func(_ context.Context) (*github.Client, error) { - return github.NewClient(client), nil + return github.NewClient(httpClient), nil } } -func stubGetClientFnErr(err string) GetClientFn { +func stubClientFnErr(errMsg string) func(context.Context) (*github.Client, error) { return func(_ context.Context) (*github.Client, error) { - return nil, errors.New(err) + return nil, errors.New(errMsg) } } -func stubGetGQLClientFn(client *githubv4.Client) GetGQLClientFn { +func stubGQLClientFnErr(errMsg string) func(context.Context) (*githubv4.Client, error) { return func(_ context.Context) (*githubv4.Client, error) { - return client, nil + return nil, errors.New(errMsg) } } @@ -51,12 +86,6 @@ func stubFeatureFlags(enabledFlags map[string]bool) FeatureFlags { } } -func stubGetRawClientFn(client *raw.Client) raw.GetRawClientFn { - return func(_ context.Context) (*raw.Client, error) { - return client, nil - } -} - func badRequestHandler(msg string) http.HandlerFunc { return func(w http.ResponseWriter, _ *http.Request) { structuredErrorResponse := github.ErrorResponse{ diff --git a/pkg/github/tools.go b/pkg/github/tools.go index dd2ad4ff4..1fde6e6bb 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -2,10 +2,9 @@ package github import ( "context" - "fmt" "strings" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" "github.com/shurcooL/githubv4" @@ -17,96 +16,96 @@ type GetGQLClientFn func(context.Context) (*githubv4.Client, error) // Toolset metadata constants - these define all available toolsets and their descriptions. // Tools use these constants to declare which toolset they belong to. var ( - ToolsetMetadataAll = toolsets.ToolsetMetadata{ + ToolsetMetadataAll = registry.ToolsetMetadata{ ID: "all", Description: "Special toolset that enables all available toolsets", } - ToolsetMetadataDefault = toolsets.ToolsetMetadata{ + ToolsetMetadataDefault = registry.ToolsetMetadata{ ID: "default", Description: "Special toolset that enables the default toolset configuration. When no toolsets are specified, this is the set that is enabled", } - ToolsetMetadataContext = toolsets.ToolsetMetadata{ + ToolsetMetadataContext = registry.ToolsetMetadata{ ID: "context", Description: "Tools that provide context about the current user and GitHub context you are operating in", Default: true, } - ToolsetMetadataRepos = toolsets.ToolsetMetadata{ + ToolsetMetadataRepos = registry.ToolsetMetadata{ ID: "repos", Description: "GitHub Repository related tools", Default: true, } - ToolsetMetadataGit = toolsets.ToolsetMetadata{ + ToolsetMetadataGit = registry.ToolsetMetadata{ ID: "git", Description: "GitHub Git API related tools for low-level Git operations", } - ToolsetMetadataIssues = toolsets.ToolsetMetadata{ + ToolsetMetadataIssues = registry.ToolsetMetadata{ ID: "issues", Description: "GitHub Issues related tools", Default: true, } - ToolsetMetadataPullRequests = toolsets.ToolsetMetadata{ + ToolsetMetadataPullRequests = registry.ToolsetMetadata{ ID: "pull_requests", Description: "GitHub Pull Request related tools", Default: true, } - ToolsetMetadataUsers = toolsets.ToolsetMetadata{ + ToolsetMetadataUsers = registry.ToolsetMetadata{ ID: "users", Description: "GitHub User related tools", Default: true, } - ToolsetMetadataOrgs = toolsets.ToolsetMetadata{ + ToolsetMetadataOrgs = registry.ToolsetMetadata{ ID: "orgs", Description: "GitHub Organization related tools", } - ToolsetMetadataActions = toolsets.ToolsetMetadata{ + ToolsetMetadataActions = registry.ToolsetMetadata{ ID: "actions", Description: "GitHub Actions workflows and CI/CD operations", } - ToolsetMetadataCodeSecurity = toolsets.ToolsetMetadata{ + ToolsetMetadataCodeSecurity = registry.ToolsetMetadata{ ID: "code_security", Description: "Code security related tools, such as GitHub Code Scanning", } - ToolsetMetadataSecretProtection = toolsets.ToolsetMetadata{ + ToolsetMetadataSecretProtection = registry.ToolsetMetadata{ ID: "secret_protection", Description: "Secret protection related tools, such as GitHub Secret Scanning", } - ToolsetMetadataDependabot = toolsets.ToolsetMetadata{ + ToolsetMetadataDependabot = registry.ToolsetMetadata{ ID: "dependabot", Description: "Dependabot tools", } - ToolsetMetadataNotifications = toolsets.ToolsetMetadata{ + ToolsetMetadataNotifications = registry.ToolsetMetadata{ ID: "notifications", Description: "GitHub Notifications related tools", } - ToolsetMetadataExperiments = toolsets.ToolsetMetadata{ + ToolsetMetadataExperiments = registry.ToolsetMetadata{ ID: "experiments", Description: "Experimental features that are not considered stable yet", } - ToolsetMetadataDiscussions = toolsets.ToolsetMetadata{ + ToolsetMetadataDiscussions = registry.ToolsetMetadata{ ID: "discussions", Description: "GitHub Discussions related tools", } - ToolsetMetadataGists = toolsets.ToolsetMetadata{ + ToolsetMetadataGists = registry.ToolsetMetadata{ ID: "gists", Description: "GitHub Gist related tools", } - ToolsetMetadataSecurityAdvisories = toolsets.ToolsetMetadata{ + ToolsetMetadataSecurityAdvisories = registry.ToolsetMetadata{ ID: "security_advisories", Description: "Security advisories related tools", } - ToolsetMetadataProjects = toolsets.ToolsetMetadata{ + ToolsetMetadataProjects = registry.ToolsetMetadata{ ID: "projects", Description: "GitHub Projects related tools", } - ToolsetMetadataStargazers = toolsets.ToolsetMetadata{ + ToolsetMetadataStargazers = registry.ToolsetMetadata{ ID: "stargazers", Description: "GitHub Stargazers related tools", } - ToolsetMetadataDynamic = toolsets.ToolsetMetadata{ + ToolsetMetadataDynamic = registry.ToolsetMetadata{ ID: "dynamic", Description: "Discover GitHub MCP tools that can help achieve tasks by enabling additional sets of tools, you can control the enablement of any toolset to access its tools when this toolset is enabled.", } - ToolsetLabels = toolsets.ToolsetMetadata{ + ToolsetLabels = registry.ToolsetMetadata{ ID: "labels", Description: "GitHub Labels related tools", } @@ -114,8 +113,8 @@ var ( // AllTools returns all tools with their embedded toolset metadata. // Tool functions return ServerTool directly with toolset info. -func AllTools(t translations.TranslationHelperFunc) []toolsets.ServerTool { - return []toolsets.ServerTool{ +func AllTools(t translations.TranslationHelperFunc) []registry.ServerTool { + return []registry.ServerTool{ // Context tools GetMe(t), GetTeams(t), @@ -241,6 +240,7 @@ func AllTools(t translations.TranslationHelperFunc) []toolsets.ServerTool { // Label tools GetLabel(t), + GetLabelForLabelsToolset(t), ListLabels(t), LabelWrite(t), } @@ -263,19 +263,21 @@ func ToStringPtr(s string) *string { // GenerateToolsetsHelp generates the help text for the toolsets flag func GenerateToolsetsHelp() string { // Get toolset group to derive defaults and available toolsets - r := NewRegistry(stubTranslator) + r := NewRegistry(stubTranslator).Build() - // Format default tools from metadata + // Format default tools from metadata using strings.Builder + var defaultBuf strings.Builder defaultIDs := r.DefaultToolsetIDs() - defaultStrings := make([]string, len(defaultIDs)) for i, id := range defaultIDs { - defaultStrings[i] = string(id) + if i > 0 { + defaultBuf.WriteString(", ") + } + defaultBuf.WriteString(string(id)) } - defaultTools := strings.Join(defaultStrings, ", ") // Get all available toolsets (excludes context and dynamic for display) allToolsets := r.AvailableToolsets("context", "dynamic") - var availableToolsLines []string + var availableBuf strings.Builder const maxLineLength = 70 currentLine := "" @@ -287,27 +289,37 @@ func GenerateToolsetsHelp() string { case len(currentLine)+len(id)+2 <= maxLineLength: currentLine += ", " + id default: - availableToolsLines = append(availableToolsLines, currentLine) + if availableBuf.Len() > 0 { + availableBuf.WriteString(",\n\t ") + } + availableBuf.WriteString(currentLine) currentLine = id } } if currentLine != "" { - availableToolsLines = append(availableToolsLines, currentLine) - } - - availableTools := strings.Join(availableToolsLines, ",\n\t ") - - toolsetsHelp := fmt.Sprintf("Comma-separated list of tool groups to enable (no spaces).\n"+ - "Available: %s\n", availableTools) + - "Special toolset keywords:\n" + - " - all: Enables all available toolsets\n" + - fmt.Sprintf(" - default: Enables the default toolset configuration of:\n\t %s\n", defaultTools) + - "Examples:\n" + - " - --toolsets=actions,gists,notifications\n" + - " - Default + additional: --toolsets=default,actions,gists\n" + - " - All tools: --toolsets=all" - - return toolsetsHelp + if availableBuf.Len() > 0 { + availableBuf.WriteString(",\n\t ") + } + availableBuf.WriteString(currentLine) + } + + // Build the complete help text using strings.Builder + var buf strings.Builder + buf.WriteString("Comma-separated list of tool groups to enable (no spaces).\n") + buf.WriteString("Available: ") + buf.WriteString(availableBuf.String()) + buf.WriteString("\n") + buf.WriteString("Special toolset keywords:\n") + buf.WriteString(" - all: Enables all available toolsets\n") + buf.WriteString(" - default: Enables the default toolset configuration of:\n\t ") + buf.WriteString(defaultBuf.String()) + buf.WriteString("\n") + buf.WriteString("Examples:\n") + buf.WriteString(" - --toolsets=actions,gists,notifications\n") + buf.WriteString(" - Default + additional: --toolsets=default,actions,gists\n") + buf.WriteString(" - All tools: --toolsets=all") + + return buf.String() } // stubTranslator is a passthrough translator for cases where we need a Registry @@ -333,7 +345,7 @@ func AddDefaultToolset(result []string) []string { result = RemoveToolset(result, string(ToolsetMetadataDefault.ID)) // Get default toolset IDs from the Registry - r := NewRegistry(stubTranslator) + r := NewRegistry(stubTranslator).Build() for _, id := range r.DefaultToolsetIDs() { if !seen[string(id)] { result = append(result, string(id)) @@ -381,3 +393,15 @@ func CleanTools(toolNames []string) []string { return result } + +// GetDefaultToolsetIDs returns the IDs of toolsets marked as Default. +// This is a convenience function that builds a registry to determine defaults. +func GetDefaultToolsetIDs() []string { + r := NewRegistry(stubTranslator).Build() + ids := r.DefaultToolsetIDs() + result := make([]string, len(ids)) + for i, id := range ids { + result[i] = string(id) + } + return result +} diff --git a/pkg/github/tools_test.go b/pkg/github/tools_test.go index 4e6d91980..80270d2bc 100644 --- a/pkg/github/tools_test.go +++ b/pkg/github/tools_test.go @@ -151,3 +151,34 @@ func TestContainsToolset(t *testing.T) { }) } } + +func TestGenerateToolsetsHelp(t *testing.T) { + // Generate the help text + helpText := GenerateToolsetsHelp() + + // Verify help text is not empty + require.NotEmpty(t, helpText) + + // Verify it contains expected sections + assert.Contains(t, helpText, "Comma-separated list of tool groups to enable") + assert.Contains(t, helpText, "Available:") + assert.Contains(t, helpText, "Special toolset keywords:") + assert.Contains(t, helpText, "all: Enables all available toolsets") + assert.Contains(t, helpText, "default: Enables the default toolset configuration") + assert.Contains(t, helpText, "Examples:") + assert.Contains(t, helpText, "--toolsets=actions,gists,notifications") + assert.Contains(t, helpText, "--toolsets=default,actions,gists") + assert.Contains(t, helpText, "--toolsets=all") + + // Verify it contains some expected default toolsets + assert.Contains(t, helpText, "context") + assert.Contains(t, helpText, "repos") + assert.Contains(t, helpText, "issues") + assert.Contains(t, helpText, "pull_requests") + assert.Contains(t, helpText, "users") + + // Verify it contains some expected available toolsets + assert.Contains(t, helpText, "actions") + assert.Contains(t, helpText, "gists") + assert.Contains(t, helpText, "notifications") +} diff --git a/pkg/github/tools_validation_test.go b/pkg/github/tools_validation_test.go index d53243b42..aa809dfa6 100644 --- a/pkg/github/tools_validation_test.go +++ b/pkg/github/tools_validation_test.go @@ -3,7 +3,7 @@ package github import ( "testing" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -102,10 +102,18 @@ func TestNoDuplicateToolNames(t *testing.T) { tools := AllTools(stubTranslation) seen := make(map[string]bool) + // get_label is intentionally in both issues and labels toolsets for conformance + // with original behavior where it was registered in both + allowedDuplicates := map[string]bool{ + "get_label": true, + } + for _, tool := range tools { name := tool.Tool.Name - assert.False(t, seen[name], - "Duplicate tool name found: %q", name) + if !allowedDuplicates[name] { + assert.False(t, seen[name], + "Duplicate tool name found: %q", name) + } seen[name] = true } } @@ -153,7 +161,7 @@ func TestAllToolsHaveHandlerFunc(t *testing.T) { // TestToolsetMetadataConsistency ensures tools in the same toolset have consistent descriptions func TestToolsetMetadataConsistency(t *testing.T) { tools := AllTools(stubTranslation) - toolsetDescriptions := make(map[toolsets.ToolsetID]string) + toolsetDescriptions := make(map[registry.ToolsetID]string) for _, tool := range tools { id := tool.Toolset.ID diff --git a/pkg/github/workflow_prompts.go b/pkg/github/workflow_prompts.go index cf972020d..603a98087 100644 --- a/pkg/github/workflow_prompts.go +++ b/pkg/github/workflow_prompts.go @@ -4,14 +4,14 @@ import ( "context" "fmt" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/modelcontextprotocol/go-sdk/mcp" ) // IssueToFixWorkflowPrompt provides a guided workflow for creating an issue and then generating a PR to fix it -func IssueToFixWorkflowPrompt(t translations.TranslationHelperFunc) toolsets.ServerPrompt { - return toolsets.NewServerPrompt( +func IssueToFixWorkflowPrompt(t translations.TranslationHelperFunc) registry.ServerPrompt { + return registry.NewServerPrompt( ToolsetMetadataIssues, mcp.Prompt{ Name: "issue_to_fix_workflow", diff --git a/pkg/registry/builder.go b/pkg/registry/builder.go new file mode 100644 index 000000000..813712064 --- /dev/null +++ b/pkg/registry/builder.go @@ -0,0 +1,274 @@ +package registry + +import ( + "context" + "sort" + "strings" +) + +// ToolFilter is a function that determines if a tool should be included. +// Returns true if the tool should be included, false to exclude it. +type ToolFilter func(ctx context.Context, tool *ServerTool) (bool, error) + +// Builder builds a Registry with the specified configuration. +// Use NewBuilder to create a builder, chain configuration methods, +// then call Build() to create the final Registry. +// +// Example: +// +// reg := NewBuilder(). +// SetTools(tools). +// SetResources(resources). +// SetPrompts(prompts). +// WithDeprecatedAliases(aliases). +// WithReadOnly(true). +// WithToolsets([]string{"repos", "issues"}). +// WithFeatureChecker(checker). +// WithFilter(myFilter). +// Build() +type Builder struct { + tools []ServerTool + resourceTemplates []ServerResourceTemplate + prompts []ServerPrompt + deprecatedAliases map[string]string + + // Configuration options (processed at Build time) + readOnly bool + toolsetIDs []string // raw input, processed at Build() + toolsetIDsIsNil bool // tracks if nil was passed (nil = defaults) + additionalTools []string // raw input, processed at Build() + featureChecker FeatureFlagChecker + filters []ToolFilter // filters to apply to all tools +} + +// NewBuilder creates a new Builder. +func NewBuilder() *Builder { + return &Builder{ + deprecatedAliases: make(map[string]string), + toolsetIDsIsNil: true, // default to nil (use defaults) + } +} + +// SetTools sets the tools for the registry. Returns self for chaining. +func (b *Builder) SetTools(tools []ServerTool) *Builder { + b.tools = tools + return b +} + +// SetResources sets the resource templates for the registry. Returns self for chaining. +func (b *Builder) SetResources(resources []ServerResourceTemplate) *Builder { + b.resourceTemplates = resources + return b +} + +// SetPrompts sets the prompts for the registry. Returns self for chaining. +func (b *Builder) SetPrompts(prompts []ServerPrompt) *Builder { + b.prompts = prompts + return b +} + +// WithDeprecatedAliases adds deprecated tool name aliases that map to canonical names. +// Returns self for chaining. +func (b *Builder) WithDeprecatedAliases(aliases map[string]string) *Builder { + for oldName, newName := range aliases { + b.deprecatedAliases[oldName] = newName + } + return b +} + +// WithReadOnly sets whether only read-only tools should be available. +// When true, write tools are filtered out. Returns self for chaining. +func (b *Builder) WithReadOnly(readOnly bool) *Builder { + b.readOnly = readOnly + return b +} + +// WithToolsets specifies which toolsets should be enabled. +// Special keywords: +// - "all": enables all toolsets +// - "default": expands to toolsets marked with Default: true in their metadata +// +// Input strings are trimmed of whitespace and duplicates are removed. +// Pass nil to use default toolsets. Pass an empty slice to disable all toolsets +// (useful for dynamic toolsets mode where tools are enabled on demand). +// Returns self for chaining. +func (b *Builder) WithToolsets(toolsetIDs []string) *Builder { + b.toolsetIDs = toolsetIDs + b.toolsetIDsIsNil = toolsetIDs == nil + return b +} + +// WithTools specifies additional tools that bypass toolset filtering. +// These tools are additive - they will be included even if their toolset is not enabled. +// Read-only filtering still applies to these tools. +// Deprecated tool aliases are automatically resolved to their canonical names during Build(). +// Returns self for chaining. +func (b *Builder) WithTools(toolNames []string) *Builder { + b.additionalTools = toolNames + return b +} + +// WithFeatureChecker sets the feature flag checker function. +// The checker receives a context (for actor extraction) and feature flag name, +// returns (enabled, error). If error occurs, it will be logged and treated as false. +// If checker is nil, all feature flag checks return false. +// Returns self for chaining. +func (b *Builder) WithFeatureChecker(checker FeatureFlagChecker) *Builder { + b.featureChecker = checker + return b +} + +// WithFilter adds a filter function that will be applied to all tools. +// Multiple filters can be added and are evaluated in order. +// If any filter returns false or an error, the tool is excluded. +// Returns self for chaining. +func (b *Builder) WithFilter(filter ToolFilter) *Builder { + b.filters = append(b.filters, filter) + return b +} + +// Build creates the final Registry with all configuration applied. +// This processes toolset filtering, tool name resolution, and sets up +// the registry for use. The returned Registry is ready for use with +// AvailableTools(), RegisterAll(), etc. +func (b *Builder) Build() *Registry { + r := &Registry{ + tools: b.tools, + resourceTemplates: b.resourceTemplates, + prompts: b.prompts, + deprecatedAliases: b.deprecatedAliases, + readOnly: b.readOnly, + featureChecker: b.featureChecker, + filters: b.filters, + } + + // Process toolsets and pre-compute metadata in a single pass + r.enabledToolsets, r.unrecognizedToolsets, r.toolsetIDs, r.toolsetIDSet, r.defaultToolsetIDs, r.toolsetDescriptions = b.processToolsets() + + // Process additional tools (resolve aliases) + if len(b.additionalTools) > 0 { + r.additionalTools = make(map[string]bool, len(b.additionalTools)) + for _, name := range b.additionalTools { + // Resolve deprecated aliases to canonical names + if canonical, isAlias := b.deprecatedAliases[name]; isAlias { + r.additionalTools[canonical] = true + } else { + r.additionalTools[name] = true + } + } + } + + return r +} + +// processToolsets processes the toolsetIDs configuration and returns: +// - enabledToolsets map (nil means all enabled) +// - unrecognizedToolsets list for warnings +// - allToolsetIDs sorted list of all toolset IDs +// - toolsetIDSet map for O(1) HasToolset lookup +// - defaultToolsetIDs sorted list of default toolset IDs +// - toolsetDescriptions map of toolset ID to description +func (b *Builder) processToolsets() (map[ToolsetID]bool, []string, []ToolsetID, map[ToolsetID]bool, []ToolsetID, map[ToolsetID]string) { + // Single pass: collect all toolset metadata together + validIDs := make(map[ToolsetID]bool) + defaultIDs := make(map[ToolsetID]bool) + descriptions := make(map[ToolsetID]string) + + for i := range b.tools { + t := &b.tools[i] + validIDs[t.Toolset.ID] = true + if t.Toolset.Default { + defaultIDs[t.Toolset.ID] = true + } + if t.Toolset.Description != "" { + descriptions[t.Toolset.ID] = t.Toolset.Description + } + } + for i := range b.resourceTemplates { + r := &b.resourceTemplates[i] + validIDs[r.Toolset.ID] = true + if r.Toolset.Default { + defaultIDs[r.Toolset.ID] = true + } + if r.Toolset.Description != "" { + descriptions[r.Toolset.ID] = r.Toolset.Description + } + } + for i := range b.prompts { + p := &b.prompts[i] + validIDs[p.Toolset.ID] = true + if p.Toolset.Default { + defaultIDs[p.Toolset.ID] = true + } + if p.Toolset.Description != "" { + descriptions[p.Toolset.ID] = p.Toolset.Description + } + } + + // Build sorted slices from the collected maps + allToolsetIDs := make([]ToolsetID, 0, len(validIDs)) + for id := range validIDs { + allToolsetIDs = append(allToolsetIDs, id) + } + sort.Slice(allToolsetIDs, func(i, j int) bool { return allToolsetIDs[i] < allToolsetIDs[j] }) + + defaultToolsetIDList := make([]ToolsetID, 0, len(defaultIDs)) + for id := range defaultIDs { + defaultToolsetIDList = append(defaultToolsetIDList, id) + } + sort.Slice(defaultToolsetIDList, func(i, j int) bool { return defaultToolsetIDList[i] < defaultToolsetIDList[j] }) + + toolsetIDs := b.toolsetIDs + + // Check for "all" keyword - enables all toolsets + for _, id := range toolsetIDs { + if strings.TrimSpace(id) == "all" { + return nil, nil, allToolsetIDs, validIDs, defaultToolsetIDList, descriptions // nil means all enabled + } + } + + // nil means use defaults, empty slice means no toolsets + if b.toolsetIDsIsNil { + toolsetIDs = []string{"default"} + } + + // Expand "default" keyword, trim whitespace, collect other IDs, and track unrecognized + seen := make(map[ToolsetID]bool) + expanded := make([]ToolsetID, 0, len(toolsetIDs)) + var unrecognized []string + + for _, id := range toolsetIDs { + trimmed := strings.TrimSpace(id) + if trimmed == "" { + continue + } + if trimmed == "default" { + for _, defaultID := range defaultToolsetIDList { + if !seen[defaultID] { + seen[defaultID] = true + expanded = append(expanded, defaultID) + } + } + } else { + tsID := ToolsetID(trimmed) + if !seen[tsID] { + seen[tsID] = true + expanded = append(expanded, tsID) + // Track if this toolset doesn't exist + if !validIDs[tsID] { + unrecognized = append(unrecognized, trimmed) + } + } + } + } + + if len(expanded) == 0 { + return make(map[ToolsetID]bool), unrecognized, allToolsetIDs, validIDs, defaultToolsetIDList, descriptions + } + + enabledToolsets := make(map[ToolsetID]bool, len(expanded)) + for _, id := range expanded { + enabledToolsets[id] = true + } + return enabledToolsets, unrecognized, allToolsetIDs, validIDs, defaultToolsetIDList, descriptions +} diff --git a/pkg/registry/errors.go b/pkg/registry/errors.go new file mode 100644 index 000000000..75cbb6f82 --- /dev/null +++ b/pkg/registry/errors.go @@ -0,0 +1,41 @@ +package registry + +import "fmt" + +// ToolsetDoesNotExistError is returned when a toolset is not found. +type ToolsetDoesNotExistError struct { + Name string +} + +func (e *ToolsetDoesNotExistError) Error() string { + return fmt.Sprintf("toolset %s does not exist", e.Name) +} + +func (e *ToolsetDoesNotExistError) Is(target error) bool { + if target == nil { + return false + } + if _, ok := target.(*ToolsetDoesNotExistError); ok { + return true + } + return false +} + +// NewToolsetDoesNotExistError creates a new ToolsetDoesNotExistError. +func NewToolsetDoesNotExistError(name string) *ToolsetDoesNotExistError { + return &ToolsetDoesNotExistError{Name: name} +} + +// ToolDoesNotExistError is returned when a tool is not found. +type ToolDoesNotExistError struct { + Name string +} + +func (e *ToolDoesNotExistError) Error() string { + return fmt.Sprintf("tool %s does not exist", e.Name) +} + +// NewToolDoesNotExistError creates a new ToolDoesNotExistError. +func NewToolDoesNotExistError(name string) *ToolDoesNotExistError { + return &ToolDoesNotExistError{Name: name} +} diff --git a/pkg/registry/filters.go b/pkg/registry/filters.go new file mode 100644 index 000000000..5bce63570 --- /dev/null +++ b/pkg/registry/filters.go @@ -0,0 +1,289 @@ +package registry + +import ( + "context" + "fmt" + "os" + "sort" +) + +// FeatureFlagChecker is a function that checks if a feature flag is enabled. +// The context can be used to extract actor/user information for flag evaluation. +// Returns (enabled, error). If error occurs, the caller should log and treat as false. +type FeatureFlagChecker func(ctx context.Context, flagName string) (bool, error) + +// isToolsetEnabled checks if a toolset is enabled based on current filters. +func (r *Registry) isToolsetEnabled(toolsetID ToolsetID) bool { + // Check enabled toolsets filter + if r.enabledToolsets != nil { + return r.enabledToolsets[toolsetID] + } + return true +} + +// checkFeatureFlag checks a feature flag using the feature checker. +// Returns false if checker is nil or returns an error (errors are logged). +func (r *Registry) checkFeatureFlag(ctx context.Context, flagName string) bool { + if r.featureChecker == nil || flagName == "" { + return false + } + enabled, err := r.featureChecker(ctx, flagName) + if err != nil { + fmt.Fprintf(os.Stderr, "Feature flag check error for %q: %v\n", flagName, err) + return false + } + return enabled +} + +// isFeatureFlagAllowed checks if an item passes feature flag filtering. +// - If FeatureFlagEnable is set, the item is only allowed if the flag is enabled +// - If FeatureFlagDisable is set, the item is excluded if the flag is enabled +func (r *Registry) isFeatureFlagAllowed(ctx context.Context, enableFlag, disableFlag string) bool { + // Check enable flag - item requires this flag to be on + if enableFlag != "" && !r.checkFeatureFlag(ctx, enableFlag) { + return false + } + // Check disable flag - item is excluded if this flag is on + if disableFlag != "" && r.checkFeatureFlag(ctx, disableFlag) { + return false + } + return true +} + +// isToolEnabled checks if a specific tool is enabled based on current filters. +// Filter evaluation order: +// 1. Tool.Enabled (tool self-filtering) +// 2. FeatureFlagEnable/FeatureFlagDisable +// 3. Read-only filter +// 4. Builder filters (via WithFilter) +// 5. Toolset/additional tools +func (r *Registry) isToolEnabled(ctx context.Context, tool *ServerTool) bool { + // 1. Check tool's own Enabled function first + if tool.Enabled != nil { + enabled, err := tool.Enabled(ctx) + if err != nil { + fmt.Fprintf(os.Stderr, "Tool.Enabled check error for %q: %v\n", tool.Tool.Name, err) + return false + } + if !enabled { + return false + } + } + // 2. Check feature flags + if !r.isFeatureFlagAllowed(ctx, tool.FeatureFlagEnable, tool.FeatureFlagDisable) { + return false + } + // 3. Check read-only filter (applies to all tools) + if r.readOnly && !tool.IsReadOnly() { + return false + } + // 4. Apply builder filters + for _, filter := range r.filters { + allowed, err := filter(ctx, tool) + if err != nil { + fmt.Fprintf(os.Stderr, "Builder filter error for tool %q: %v\n", tool.Tool.Name, err) + return false + } + if !allowed { + return false + } + } + // 5. Check if tool is in additionalTools (bypasses toolset filter) + if r.additionalTools != nil && r.additionalTools[tool.Tool.Name] { + return true + } + // 5. Check toolset filter + if !r.isToolsetEnabled(tool.Toolset.ID) { + return false + } + return true +} + +// AvailableTools returns the tools that pass all current filters, +// sorted deterministically by toolset ID, then tool name. +// The context is used for feature flag evaluation. +func (r *Registry) AvailableTools(ctx context.Context) []ServerTool { + var result []ServerTool + for i := range r.tools { + tool := &r.tools[i] + if r.isToolEnabled(ctx, tool) { + result = append(result, *tool) + } + } + + // Sort deterministically: by toolset ID, then by tool name + sort.Slice(result, func(i, j int) bool { + if result[i].Toolset.ID != result[j].Toolset.ID { + return result[i].Toolset.ID < result[j].Toolset.ID + } + return result[i].Tool.Name < result[j].Tool.Name + }) + + return result +} + +// AvailableResourceTemplates returns resource templates that pass all current filters, +// sorted deterministically by toolset ID, then template name. +// The context is used for feature flag evaluation. +func (r *Registry) AvailableResourceTemplates(ctx context.Context) []ServerResourceTemplate { + var result []ServerResourceTemplate + for i := range r.resourceTemplates { + res := &r.resourceTemplates[i] + // Check feature flags + if !r.isFeatureFlagAllowed(ctx, res.FeatureFlagEnable, res.FeatureFlagDisable) { + continue + } + if r.isToolsetEnabled(res.Toolset.ID) { + result = append(result, *res) + } + } + + // Sort deterministically: by toolset ID, then by template name + sort.Slice(result, func(i, j int) bool { + if result[i].Toolset.ID != result[j].Toolset.ID { + return result[i].Toolset.ID < result[j].Toolset.ID + } + return result[i].Template.Name < result[j].Template.Name + }) + + return result +} + +// AvailablePrompts returns prompts that pass all current filters, +// sorted deterministically by toolset ID, then prompt name. +// The context is used for feature flag evaluation. +func (r *Registry) AvailablePrompts(ctx context.Context) []ServerPrompt { + var result []ServerPrompt + for i := range r.prompts { + prompt := &r.prompts[i] + // Check feature flags + if !r.isFeatureFlagAllowed(ctx, prompt.FeatureFlagEnable, prompt.FeatureFlagDisable) { + continue + } + if r.isToolsetEnabled(prompt.Toolset.ID) { + result = append(result, *prompt) + } + } + + // Sort deterministically: by toolset ID, then by prompt name + sort.Slice(result, func(i, j int) bool { + if result[i].Toolset.ID != result[j].Toolset.ID { + return result[i].Toolset.ID < result[j].Toolset.ID + } + return result[i].Prompt.Name < result[j].Prompt.Name + }) + + return result +} + +// filterToolsByName returns tools matching the given name, checking deprecated aliases. +// Uses linear scan - optimized for single-lookup per-request scenarios (ForMCPRequest). +func (r *Registry) filterToolsByName(name string) []ServerTool { + // First check for exact match + for i := range r.tools { + if r.tools[i].Tool.Name == name { + return []ServerTool{r.tools[i]} + } + } + // Check if name is a deprecated alias + if canonical, isAlias := r.deprecatedAliases[name]; isAlias { + for i := range r.tools { + if r.tools[i].Tool.Name == canonical { + return []ServerTool{r.tools[i]} + } + } + } + return []ServerTool{} +} + +// filterResourcesByURI returns resource templates matching the given URI pattern. +// Uses linear scan - optimized for single-lookup per-request scenarios (ForMCPRequest). +func (r *Registry) filterResourcesByURI(uri string) []ServerResourceTemplate { + for i := range r.resourceTemplates { + if r.resourceTemplates[i].Template.URITemplate == uri { + return []ServerResourceTemplate{r.resourceTemplates[i]} + } + } + return []ServerResourceTemplate{} +} + +// filterPromptsByName returns prompts matching the given name. +// Uses linear scan - optimized for single-lookup per-request scenarios (ForMCPRequest). +func (r *Registry) filterPromptsByName(name string) []ServerPrompt { + for i := range r.prompts { + if r.prompts[i].Prompt.Name == name { + return []ServerPrompt{r.prompts[i]} + } + } + return []ServerPrompt{} +} + +// ToolsForToolset returns all tools belonging to a specific toolset. +// This method bypasses the toolset enabled filter (for dynamic toolset registration), +// but still respects the read-only filter. +func (r *Registry) ToolsForToolset(toolsetID ToolsetID) []ServerTool { + var result []ServerTool + for i := range r.tools { + tool := &r.tools[i] + // Only check read-only filter, not toolset enabled filter + if tool.Toolset.ID == toolsetID { + if r.readOnly && !tool.IsReadOnly() { + continue + } + result = append(result, *tool) + } + } + + // Sort by tool name for deterministic order + sort.Slice(result, func(i, j int) bool { + return result[i].Tool.Name < result[j].Tool.Name + }) + + return result +} + +// IsToolsetEnabled checks if a toolset is currently enabled based on filters. +func (r *Registry) IsToolsetEnabled(toolsetID ToolsetID) bool { + return r.isToolsetEnabled(toolsetID) +} + +// EnableToolset marks a toolset as enabled in this group. +// This is used by dynamic toolset management to track which toolsets have been enabled. +func (r *Registry) EnableToolset(toolsetID ToolsetID) { + if r.enabledToolsets == nil { + // nil means all enabled, so nothing to do + return + } + r.enabledToolsets[toolsetID] = true +} + +// EnabledToolsetIDs returns the list of enabled toolset IDs based on current filters. +// Returns all toolset IDs if no filter is set. +func (r *Registry) EnabledToolsetIDs() []ToolsetID { + if r.enabledToolsets == nil { + return r.ToolsetIDs() + } + + ids := make([]ToolsetID, 0, len(r.enabledToolsets)) + for id := range r.enabledToolsets { + if r.HasToolset(id) { + ids = append(ids, id) + } + } + sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] }) + return ids +} + +// FilteredTools returns tools filtered by the Enabled function and builder filters. +// This provides an explicit API for accessing filtered tools, currently implemented +// as an alias for AvailableTools. +// +// The error return is currently always nil but is included for future extensibility. +// Library consumers (e.g., remote server implementations) may need to surface +// recoverable filter errors rather than silently logging them. Having the error +// return in the API now avoids breaking changes later. +// +// The context is used for Enabled function evaluation and builder filter checks. +func (r *Registry) FilteredTools(ctx context.Context) ([]ServerTool, error) { + return r.AvailableTools(ctx), nil +} diff --git a/pkg/registry/prompts.go b/pkg/registry/prompts.go new file mode 100644 index 000000000..02dda6c9c --- /dev/null +++ b/pkg/registry/prompts.go @@ -0,0 +1,26 @@ +package registry + +import "github.com/modelcontextprotocol/go-sdk/mcp" + +// ServerPrompt pairs a prompt with its toolset metadata. +type ServerPrompt struct { + Prompt mcp.Prompt + Handler mcp.PromptHandler + // Toolset identifies which toolset this prompt belongs to + Toolset ToolsetMetadata + // FeatureFlagEnable specifies a feature flag that must be enabled for this prompt + // to be available. If set and the flag is not enabled, the prompt is omitted. + FeatureFlagEnable string + // FeatureFlagDisable specifies a feature flag that, when enabled, causes this prompt + // to be omitted. Used to disable prompts when a feature flag is on. + FeatureFlagDisable string +} + +// NewServerPrompt creates a new ServerPrompt with toolset metadata. +func NewServerPrompt(toolset ToolsetMetadata, prompt mcp.Prompt, handler mcp.PromptHandler) ServerPrompt { + return ServerPrompt{ + Prompt: prompt, + Handler: handler, + Toolset: toolset, + } +} diff --git a/pkg/registry/registry.go b/pkg/registry/registry.go new file mode 100644 index 000000000..7376a4c9c --- /dev/null +++ b/pkg/registry/registry.go @@ -0,0 +1,282 @@ +package registry + +import ( + "context" + "fmt" + "os" + "slices" + "sort" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// Registry holds a collection of tools, resources, and prompts with filtering applied. +// Create a Registry using Builder: +// +// reg := NewBuilder(). +// SetTools(tools). +// WithReadOnly(true). +// WithToolsets([]string{"repos"}). +// Build() +// +// The Registry is configured at build time and provides: +// - Filtered access to tools/resources/prompts via Available* methods +// - Deterministic ordering for documentation generation +// - Lazy dependency injection during registration via RegisterAll() +// - Runtime toolset enabling for dynamic toolsets mode +type Registry struct { + // tools holds all tools in this group (ordered for iteration) + tools []ServerTool + // resourceTemplates holds all resource templates in this group (ordered for iteration) + resourceTemplates []ServerResourceTemplate + // prompts holds all prompts in this group (ordered for iteration) + prompts []ServerPrompt + // deprecatedAliases maps old tool names to new canonical names + deprecatedAliases map[string]string + + // Pre-computed toolset metadata (set during Build) + toolsetIDs []ToolsetID // sorted list of all toolset IDs + toolsetIDSet map[ToolsetID]bool // set for O(1) HasToolset lookup + defaultToolsetIDs []ToolsetID // sorted list of default toolset IDs + toolsetDescriptions map[ToolsetID]string // toolset ID -> description + + // Filters - these control what's returned by Available* methods + // readOnly when true filters out write tools + readOnly bool + // enabledToolsets when non-nil, only include tools/resources/prompts from these toolsets + // when nil, all toolsets are enabled + enabledToolsets map[ToolsetID]bool + // additionalTools are specific tools that bypass toolset filtering (but still respect read-only) + // These are additive - a tool is included if it matches toolset filters OR is in this set + additionalTools map[string]bool + // featureChecker when non-nil, checks if a feature flag is enabled. + // Takes context and flag name, returns (enabled, error). If error, log and treat as false. + // If checker is nil, all flag checks return false. + featureChecker FeatureFlagChecker + // filters are functions that will be applied to all tools during filtering. + // If any filter returns false or an error, the tool is excluded. + filters []ToolFilter + // unrecognizedToolsets holds toolset IDs that were requested but don't match any registered toolsets + unrecognizedToolsets []string +} + +// UnrecognizedToolsets returns toolset IDs that were passed to WithToolsets but don't +// match any registered toolsets. This is useful for warning users about typos. +func (r *Registry) UnrecognizedToolsets() []string { + return r.unrecognizedToolsets +} + +// MCP method constants for use with ForMCPRequest. +const ( + MCPMethodInitialize = "initialize" + MCPMethodToolsList = "tools/list" + MCPMethodToolsCall = "tools/call" + MCPMethodResourcesList = "resources/list" + MCPMethodResourcesRead = "resources/read" + MCPMethodResourcesTemplatesList = "resources/templates/list" + MCPMethodPromptsList = "prompts/list" + MCPMethodPromptsGet = "prompts/get" +) + +// ForMCPRequest returns a Registry optimized for a specific MCP request. +// This is designed for servers that create a new instance per request (like the remote server), +// allowing them to only register the items needed for that specific request rather than all ~90 tools. +// +// Parameters: +// - method: The MCP method being called (use MCP* constants) +// - itemName: Name of specific item for call/get methods (tool name, resource URI, or prompt name) +// +// Returns a new Registry containing only the items relevant to the request: +// - MCPMethodInitialize: Empty (capabilities are set via ServerOptions, not registration) +// - MCPMethodToolsList: All available tools (no resources/prompts) +// - MCPMethodToolsCall: Only the named tool +// - MCPMethodResourcesList, MCPMethodResourcesTemplatesList: All available resources (no tools/prompts) +// - MCPMethodResourcesRead: Only the named resource template +// - MCPMethodPromptsList: All available prompts (no tools/resources) +// - MCPMethodPromptsGet: Only the named prompt +// - Unknown methods: Empty (no items registered) +// +// All existing filters (read-only, toolsets, etc.) still apply to the returned items. +func (r *Registry) ForMCPRequest(method string, itemName string) *Registry { + // Create a shallow copy with shared filter settings + // Note: lazy-init maps (toolsByName, etc.) are NOT copied - the new Registry + // will initialize its own maps on first use if needed + result := &Registry{ + tools: r.tools, + resourceTemplates: r.resourceTemplates, + prompts: r.prompts, + deprecatedAliases: r.deprecatedAliases, + readOnly: r.readOnly, + enabledToolsets: r.enabledToolsets, // shared, not modified + additionalTools: r.additionalTools, // shared, not modified + featureChecker: r.featureChecker, + filters: r.filters, // shared, not modified + unrecognizedToolsets: r.unrecognizedToolsets, + } + + // Helper to clear all item types + clearAll := func() { + result.tools = []ServerTool{} + result.resourceTemplates = []ServerResourceTemplate{} + result.prompts = []ServerPrompt{} + } + + switch method { + case MCPMethodInitialize: + clearAll() + case MCPMethodToolsList: + result.resourceTemplates, result.prompts = nil, nil + case MCPMethodToolsCall: + result.resourceTemplates, result.prompts = nil, nil + if itemName != "" { + result.tools = r.filterToolsByName(itemName) + } + case MCPMethodResourcesList, MCPMethodResourcesTemplatesList: + result.tools, result.prompts = nil, nil + case MCPMethodResourcesRead: + result.tools, result.prompts = nil, nil + if itemName != "" { + result.resourceTemplates = r.filterResourcesByURI(itemName) + } + case MCPMethodPromptsList: + result.tools, result.resourceTemplates = nil, nil + case MCPMethodPromptsGet: + result.tools, result.resourceTemplates = nil, nil + if itemName != "" { + result.prompts = r.filterPromptsByName(itemName) + } + default: + clearAll() + } + + return result +} + +// ToolsetIDs returns a sorted list of unique toolset IDs from all tools in this group. +func (r *Registry) ToolsetIDs() []ToolsetID { + return r.toolsetIDs +} + +// DefaultToolsetIDs returns the IDs of toolsets marked as Default in their metadata. +// The IDs are returned in sorted order for deterministic output. +func (r *Registry) DefaultToolsetIDs() []ToolsetID { + return r.defaultToolsetIDs +} + +// ToolsetDescriptions returns a map of toolset ID to description for all toolsets. +func (r *Registry) ToolsetDescriptions() map[ToolsetID]string { + return r.toolsetDescriptions +} + +// RegisterTools registers all available tools with the server using the provided dependencies. +// The context is used for feature flag evaluation. +func (r *Registry) RegisterTools(ctx context.Context, s *mcp.Server, deps any) { + for _, tool := range r.AvailableTools(ctx) { + tool.RegisterFunc(s, deps) + } +} + +// RegisterResourceTemplates registers all available resource templates with the server. +// The context is used for feature flag evaluation. +func (r *Registry) RegisterResourceTemplates(ctx context.Context, s *mcp.Server, deps any) { + for _, res := range r.AvailableResourceTemplates(ctx) { + s.AddResourceTemplate(&res.Template, res.Handler(deps)) + } +} + +// RegisterPrompts registers all available prompts with the server. +// The context is used for feature flag evaluation. +func (r *Registry) RegisterPrompts(ctx context.Context, s *mcp.Server) { + for _, prompt := range r.AvailablePrompts(ctx) { + s.AddPrompt(&prompt.Prompt, prompt.Handler) + } +} + +// RegisterAll registers all available tools, resources, and prompts with the server. +// The context is used for feature flag evaluation. +func (r *Registry) RegisterAll(ctx context.Context, s *mcp.Server, deps any) { + r.RegisterTools(ctx, s, deps) + r.RegisterResourceTemplates(ctx, s, deps) + r.RegisterPrompts(ctx, s) +} + +// ResolveToolAliases resolves deprecated tool aliases to their canonical names. +// It logs a warning to stderr for each deprecated alias that is resolved. +// Returns: +// - resolved: tool names with aliases replaced by canonical names +// - aliasesUsed: map of oldName → newName for each alias that was resolved +func (r *Registry) ResolveToolAliases(toolNames []string) (resolved []string, aliasesUsed map[string]string) { + resolved = make([]string, 0, len(toolNames)) + aliasesUsed = make(map[string]string) + for _, toolName := range toolNames { + if canonicalName, isAlias := r.deprecatedAliases[toolName]; isAlias { + fmt.Fprintf(os.Stderr, "Warning: tool %q is deprecated, use %q instead\n", toolName, canonicalName) + aliasesUsed[toolName] = canonicalName + resolved = append(resolved, canonicalName) + } else { + resolved = append(resolved, toolName) + } + } + return resolved, aliasesUsed +} + +// FindToolByName searches all tools for one matching the given name. +// Returns the tool, its toolset ID, and an error if not found. +// This searches ALL tools regardless of filters. +func (r *Registry) FindToolByName(toolName string) (*ServerTool, ToolsetID, error) { + for i := range r.tools { + if r.tools[i].Tool.Name == toolName { + return &r.tools[i], r.tools[i].Toolset.ID, nil + } + } + return nil, "", NewToolDoesNotExistError(toolName) +} + +// HasToolset checks if any tool/resource/prompt belongs to the given toolset. +func (r *Registry) HasToolset(toolsetID ToolsetID) bool { + return r.toolsetIDSet[toolsetID] +} + +// AllTools returns all tools without any filtering, sorted deterministically. +func (r *Registry) AllTools() []ServerTool { + result := slices.Clone(r.tools) + + // Sort deterministically: by toolset ID, then by tool name + sort.Slice(result, func(i, j int) bool { + if result[i].Toolset.ID != result[j].Toolset.ID { + return result[i].Toolset.ID < result[j].Toolset.ID + } + return result[i].Tool.Name < result[j].Tool.Name + }) + + return result +} + +// AvailableToolsets returns the unique toolsets that have tools, in sorted order. +// This is the ordered intersection of toolsets with reality - only toolsets that +// actually contain tools are returned, sorted by toolset ID. +// Optional exclude parameter filters out specific toolset IDs from the result. +func (r *Registry) AvailableToolsets(exclude ...ToolsetID) []ToolsetMetadata { + tools := r.AllTools() + if len(tools) == 0 { + return nil + } + + // Build exclude set for O(1) lookup + excludeSet := make(map[ToolsetID]bool, len(exclude)) + for _, id := range exclude { + excludeSet[id] = true + } + + var result []ToolsetMetadata + var lastID ToolsetID + for _, tool := range tools { + if tool.Toolset.ID != lastID { + lastID = tool.Toolset.ID + if !excludeSet[lastID] { + result = append(result, tool.Toolset) + } + } + } + return result +} diff --git a/pkg/toolsets/toolsets_test.go b/pkg/registry/registry_test.go similarity index 58% rename from pkg/toolsets/toolsets_test.go rename to pkg/registry/registry_test.go index 0d1c35e2e..44ed3d773 100644 --- a/pkg/toolsets/toolsets_test.go +++ b/pkg/registry/registry_test.go @@ -1,4 +1,4 @@ -package toolsets +package registry import ( "context" @@ -65,15 +65,15 @@ func mockTool(name string, toolsetID string, readOnly bool) ServerTool { } func TestNewRegistryEmpty(t *testing.T) { - tsg := NewRegistry() - if len(tsg.tools) != 0 { - t.Fatalf("Expected tools to be empty, got %d items", len(tsg.tools)) + reg := NewBuilder().Build() + if len(reg.AvailableTools(context.Background())) != 0 { + t.Fatalf("Expected tools to be empty") } - if len(tsg.resourceTemplates) != 0 { - t.Fatalf("Expected resourceTemplates to be empty, got %d items", len(tsg.resourceTemplates)) + if len(reg.AvailableResourceTemplates(context.Background())) != 0 { + t.Fatalf("Expected resourceTemplates to be empty") } - if len(tsg.prompts) != 0 { - t.Fatalf("Expected prompts to be empty, got %d items", len(tsg.prompts)) + if len(reg.AvailablePrompts(context.Background())) != 0 { + t.Fatalf("Expected prompts to be empty") } } @@ -84,10 +84,10 @@ func TestNewRegistryWithTools(t *testing.T) { mockTool("tool3", "toolset2", true), } - tsg := NewRegistry().SetTools(tools) + reg := NewBuilder().SetTools(tools).Build() - if len(tsg.tools) != 3 { - t.Errorf("Expected 3 tools, got %d", len(tsg.tools)) + if len(reg.AllTools()) != 3 { + t.Errorf("Expected 3 tools, got %d", len(reg.AllTools())) } } @@ -98,8 +98,8 @@ func TestAvailableTools_NoFilters(t *testing.T) { mockTool("tool_c", "toolset2", true), } - tsg := NewRegistry().SetTools(tools) - available := tsg.AvailableTools(context.Background()) + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + available := reg.AvailableTools(context.Background()) if len(available) != 3 { t.Fatalf("Expected 3 available tools, got %d", len(available)) @@ -120,29 +120,22 @@ func TestWithReadOnly(t *testing.T) { mockTool("write_tool", "toolset1", false), } - tsg := NewRegistry().SetTools(tools) - - // Original should have both tools - allTools := tsg.AvailableTools(context.Background()) + // Build without read-only - should have both tools + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + allTools := reg.AvailableTools(context.Background()) if len(allTools) != 2 { - t.Fatalf("Expected 2 tools in original, got %d", len(allTools)) + t.Fatalf("Expected 2 tools without read-only, got %d", len(allTools)) } - // Read-only should filter out write tools - readOnlyTsg := tsg.WithReadOnly(true) - readOnlyTools := readOnlyTsg.AvailableTools(context.Background()) + // Build with read-only - should filter out write tools + readOnlyReg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithReadOnly(true).Build() + readOnlyTools := readOnlyReg.AvailableTools(context.Background()) if len(readOnlyTools) != 1 { t.Fatalf("Expected 1 tool in read-only, got %d", len(readOnlyTools)) } if readOnlyTools[0].Tool.Name != "read_tool" { t.Errorf("Expected read_tool, got %s", readOnlyTools[0].Tool.Name) } - - // Original should still have both (immutability test) - allTools = tsg.AvailableTools(context.Background()) - if len(allTools) != 2 { - t.Fatalf("Original was mutated! Expected 2 tools, got %d", len(allTools)) - } } func TestWithToolsets(t *testing.T) { @@ -152,10 +145,15 @@ func TestWithToolsets(t *testing.T) { mockTool("tool3", "toolset3", true), } - tsg := NewRegistry().SetTools(tools) + // Build with all toolsets + allReg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + allTools := allReg.AvailableTools(context.Background()) + if len(allTools) != 3 { + t.Fatalf("Expected 3 tools without filter, got %d", len(allTools)) + } - // Filter to specific toolsets - filteredReg := tsg.WithToolsets([]string{"toolset1", "toolset3"}) + // Build with specific toolsets + filteredReg := NewBuilder().SetTools(tools).WithToolsets([]string{"toolset1", "toolset3"}).Build() filteredTools := filteredReg.AvailableTools(context.Background()) if len(filteredTools) != 2 { @@ -170,12 +168,6 @@ func TestWithToolsets(t *testing.T) { if !toolNames["tool1"] || !toolNames["tool3"] { t.Errorf("Expected tool1 and tool3, got %v", toolNames) } - - // Original should still have all 3 (immutability test) - allTools := tsg.AvailableTools(context.Background()) - if len(allTools) != 3 { - t.Fatalf("Original was mutated! Expected 3 tools, got %d", len(allTools)) - } } func TestWithToolsetsTrimsWhitespace(t *testing.T) { @@ -184,10 +176,8 @@ func TestWithToolsetsTrimsWhitespace(t *testing.T) { mockTool("tool2", "toolset2", true), } - tsg := NewRegistry().SetTools(tools) - // Whitespace should be trimmed - filteredReg := tsg.WithToolsets([]string{" toolset1 ", " toolset2 "}) + filteredReg := NewBuilder().SetTools(tools).WithToolsets([]string{" toolset1 ", " toolset2 "}).Build() filteredTools := filteredReg.AvailableTools(context.Background()) if len(filteredTools) != 2 { @@ -200,10 +190,8 @@ func TestWithToolsetsDeduplicates(t *testing.T) { mockTool("tool1", "toolset1", true), } - tsg := NewRegistry().SetTools(tools) - // Duplicates should be removed - filteredReg := tsg.WithToolsets([]string{"toolset1", "toolset1", " toolset1 "}) + filteredReg := NewBuilder().SetTools(tools).WithToolsets([]string{"toolset1", "toolset1", " toolset1 "}).Build() filteredTools := filteredReg.AvailableTools(context.Background()) if len(filteredTools) != 1 { @@ -216,10 +204,8 @@ func TestWithToolsetsIgnoresEmptyStrings(t *testing.T) { mockTool("tool1", "toolset1", true), } - tsg := NewRegistry().SetTools(tools) - // Empty strings should be ignored - filteredReg := tsg.WithToolsets([]string{"", "toolset1", " ", ""}) + filteredReg := NewBuilder().SetTools(tools).WithToolsets([]string{"", "toolset1", " ", ""}).Build() filteredTools := filteredReg.AvailableTools(context.Background()) if len(filteredTools) != 1 { @@ -233,8 +219,6 @@ func TestUnrecognizedToolsets(t *testing.T) { mockTool("tool2", "toolset2", true), } - tsg := NewRegistry().SetTools(tools) - tests := []struct { name string input []string @@ -269,7 +253,7 @@ func TestUnrecognizedToolsets(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - filtered := tsg.WithToolsets(tt.input) + filtered := NewBuilder().SetTools(tools).WithToolsets(tt.input).Build() unrecognized := filtered.UnrecognizedToolsets() if len(unrecognized) != len(tt.expectedUnrecognized) { @@ -293,11 +277,9 @@ func TestWithTools(t *testing.T) { mockTool("tool3", "toolset2", true), } - tsg := NewRegistry().SetTools(tools) - // WithTools adds additional tools that bypass toolset filtering // When combined with WithToolsets([]), only the additional tools should be available - filteredReg := tsg.WithToolsets([]string{}).WithTools([]string{"tool1", "tool3"}) + filteredReg := NewBuilder().SetTools(tools).WithToolsets([]string{}).WithTools([]string{"tool1", "tool3"}).Build() filteredTools := filteredReg.AvailableTools(context.Background()) if len(filteredTools) != 2 { @@ -321,10 +303,8 @@ func TestChainedFilters(t *testing.T) { mockTool("write2", "toolset2", false), } - tsg := NewRegistry().SetTools(tools) - // Chain read-only and toolset filter - filtered := tsg.WithReadOnly(true).WithToolsets([]string{"toolset1"}) + filtered := NewBuilder().SetTools(tools).WithReadOnly(true).WithToolsets([]string{"toolset1"}).Build() result := filtered.AvailableTools(context.Background()) if len(result) != 1 { @@ -342,8 +322,8 @@ func TestToolsetIDs(t *testing.T) { mockTool("tool3", "toolset_b", true), // duplicate toolset } - tsg := NewRegistry().SetTools(tools) - ids := tsg.ToolsetIDs() + reg := NewBuilder().SetTools(tools).Build() + ids := reg.ToolsetIDs() if len(ids) != 2 { t.Fatalf("Expected 2 unique toolset IDs, got %d", len(ids)) @@ -361,8 +341,8 @@ func TestToolsetDescriptions(t *testing.T) { mockTool("tool2", "toolset2", true), } - tsg := NewRegistry().SetTools(tools) - descriptions := tsg.ToolsetDescriptions() + reg := NewBuilder().SetTools(tools).Build() + descriptions := reg.ToolsetDescriptions() if len(descriptions) != 2 { t.Fatalf("Expected 2 descriptions, got %d", len(descriptions)) @@ -380,35 +360,31 @@ func TestToolsForToolset(t *testing.T) { mockTool("tool3", "toolset2", true), } - tsg := NewRegistry().SetTools(tools) - toolset1Tools := tsg.ToolsForToolset("toolset1") + reg := NewBuilder().SetTools(tools).Build() + toolset1Tools := reg.ToolsForToolset("toolset1") if len(toolset1Tools) != 2 { t.Fatalf("Expected 2 tools for toolset1, got %d", len(toolset1Tools)) } } -func TestWithDeprecatedToolAliases(t *testing.T) { +func TestWithDeprecatedAliases(t *testing.T) { tools := []ServerTool{ mockTool("new_name", "toolset1", true), } - tsg := NewRegistry().SetTools(tools) - tsgWithAliases := tsg.WithDeprecatedToolAliases(map[string]string{ + reg := NewBuilder().SetTools(tools).WithDeprecatedAliases(map[string]string{ "old_name": "new_name", "get_issue": "issue_read", - }) - - // Original should be unchanged (immutable) - if len(tsg.deprecatedAliases) != 0 { - t.Errorf("original should have 0 aliases, got %d", len(tsg.deprecatedAliases)) - } + }).Build() - if len(tsgWithAliases.deprecatedAliases) != 2 { - t.Errorf("expected 2 aliases, got %d", len(tsgWithAliases.deprecatedAliases)) + // Test resolving aliases + resolved, aliasesUsed := reg.ResolveToolAliases([]string{"old_name"}) + if len(resolved) != 1 || resolved[0] != "new_name" { + t.Errorf("expected alias to resolve to 'new_name', got %v", resolved) } - if tsgWithAliases.deprecatedAliases["old_name"] != "new_name" { - t.Errorf("expected alias 'old_name' -> 'new_name', got '%s'", tsgWithAliases.deprecatedAliases["old_name"]) + if len(aliasesUsed) != 1 || aliasesUsed["old_name"] != "new_name" { + t.Errorf("expected alias mapping, got %v", aliasesUsed) } } @@ -418,14 +394,14 @@ func TestResolveToolAliases(t *testing.T) { mockTool("some_tool", "toolset1", true), } - tsg := NewRegistry().SetTools(tools). - WithDeprecatedToolAliases(map[string]string{ + reg := NewBuilder().SetTools(tools). + WithDeprecatedAliases(map[string]string{ "get_issue": "issue_read", - }) + }).Build() // Test resolving a mix of aliases and canonical names input := []string{"get_issue", "some_tool"} - resolved, aliasesUsed := tsg.ResolveToolAliases(input) + resolved, aliasesUsed := reg.ResolveToolAliases(input) if len(resolved) != 2 { t.Fatalf("expected 2 resolved names, got %d", len(resolved)) @@ -450,10 +426,10 @@ func TestFindToolByName(t *testing.T) { mockTool("issue_read", "toolset1", true), } - tsg := NewRegistry().SetTools(tools) + reg := NewBuilder().SetTools(tools).Build() // Find by name - tool, toolsetID, err := tsg.FindToolByName("issue_read") + tool, toolsetID, err := reg.FindToolByName("issue_read") if err != nil { t.Fatalf("expected no error, got %v", err) } @@ -465,7 +441,7 @@ func TestFindToolByName(t *testing.T) { } // Non-existent tool - _, _, err = tsg.FindToolByName("nonexistent") + _, _, err = reg.FindToolByName("nonexistent") if err == nil { t.Error("expected error for non-existent tool") } @@ -478,11 +454,9 @@ func TestWithToolsAdditive(t *testing.T) { mockTool("repo_read", "toolset2", true), } - tsg := NewRegistry().SetTools(tools) - // Test WithTools bypasses toolset filtering // Enable only toolset2, but add issue_read as additional tool - filtered := tsg.WithToolsets([]string{"toolset2"}).WithTools([]string{"issue_read"}) + filtered := NewBuilder().SetTools(tools).WithToolsets([]string{"toolset2"}).WithTools([]string{"issue_read"}).Build() available := filtered.AvailableTools(context.Background()) if len(available) != 2 { @@ -502,7 +476,7 @@ func TestWithToolsAdditive(t *testing.T) { } // Test WithTools respects read-only mode - readOnlyFiltered := tsg.WithReadOnly(true).WithTools([]string{"issue_write"}) + readOnlyFiltered := NewBuilder().SetTools(tools).WithReadOnly(true).WithTools([]string{"issue_write"}).Build() available = readOnlyFiltered.AvailableTools(context.Background()) // issue_write should be excluded because read-only applies to additional tools too @@ -513,7 +487,7 @@ func TestWithToolsAdditive(t *testing.T) { } // Test WithTools with non-existent tool (should not error, just won't match anything) - nonexistent := tsg.WithToolsets([]string{}).WithTools([]string{"nonexistent"}) + nonexistent := NewBuilder().SetTools(tools).WithToolsets([]string{}).WithTools([]string{"nonexistent"}).Build() available = nonexistent.AvailableTools(context.Background()) if len(available) != 0 { t.Errorf("expected 0 tools for non-existent additional tool, got %d", len(available)) @@ -525,13 +499,14 @@ func TestWithToolsResolvesAliases(t *testing.T) { mockTool("issue_read", "toolset1", true), } - tsg := NewRegistry().SetTools(tools). - WithDeprecatedToolAliases(map[string]string{ - "get_issue": "issue_read", - }) - // Using deprecated alias should resolve to canonical name - filtered := tsg.WithToolsets([]string{}).WithTools([]string{"get_issue"}) + filtered := NewBuilder().SetTools(tools). + WithDeprecatedAliases(map[string]string{ + "get_issue": "issue_read", + }). + WithToolsets([]string{}). + WithTools([]string{"get_issue"}). + Build() available := filtered.AvailableTools(context.Background()) if len(available) != 1 { @@ -547,12 +522,12 @@ func TestHasToolset(t *testing.T) { mockTool("tool1", "toolset1", true), } - tsg := NewRegistry().SetTools(tools) + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() - if !tsg.HasToolset("toolset1") { + if !reg.HasToolset("toolset1") { t.Error("expected HasToolset to return true for existing toolset") } - if tsg.HasToolset("nonexistent") { + if reg.HasToolset("nonexistent") { t.Error("expected HasToolset to return false for non-existent toolset") } } @@ -563,16 +538,15 @@ func TestEnabledToolsetIDs(t *testing.T) { mockTool("tool2", "toolset2", true), } - tsg := NewRegistry().SetTools(tools) - // Without filter, all toolsets are enabled - ids := tsg.EnabledToolsetIDs() + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + ids := reg.EnabledToolsetIDs() if len(ids) != 2 { t.Fatalf("Expected 2 enabled toolset IDs, got %d", len(ids)) } // With filter - filtered := tsg.WithToolsets([]string{"toolset1"}) + filtered := NewBuilder().SetTools(tools).WithToolsets([]string{"toolset1"}).Build() filteredIDs := filtered.EnabledToolsetIDs() if len(filteredIDs) != 1 { t.Fatalf("Expected 1 enabled toolset ID, got %d", len(filteredIDs)) @@ -588,18 +562,16 @@ func TestAllTools(t *testing.T) { mockTool("write_tool", "toolset1", false), } - tsg := NewRegistry().SetTools(tools) - // Even with read-only filter, AllTools returns everything - readOnlyTsg := tsg.WithReadOnly(true) + readOnlyReg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithReadOnly(true).Build() - allTools := readOnlyTsg.AllTools() + allTools := readOnlyReg.AllTools() if len(allTools) != 2 { t.Fatalf("Expected 2 tools from AllTools, got %d", len(allTools)) } // But AvailableTools respects the filter - availableTools := readOnlyTsg.AvailableTools(context.Background()) + availableTools := readOnlyReg.AvailableTools(context.Background()) if len(availableTools) != 1 { t.Fatalf("Expected 1 tool from AvailableTools, got %d", len(availableTools)) } @@ -656,8 +628,8 @@ func TestForMCPRequest_Initialize(t *testing.T) { mockPrompt("prompt1", "repos"), } - tsg := NewRegistry().SetTools(tools).SetResources(resources).SetPrompts(prompts) - filtered := tsg.ForMCPRequest(MCPMethodInitialize, "") + reg := NewBuilder().SetTools(tools).SetResources(resources).SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodInitialize, "") // Initialize should return empty - capabilities come from ServerOptions if len(filtered.AvailableTools(context.Background())) != 0 { @@ -683,8 +655,8 @@ func TestForMCPRequest_ToolsList(t *testing.T) { mockPrompt("prompt1", "repos"), } - tsg := NewRegistry().SetTools(tools).SetResources(resources).SetPrompts(prompts) - filtered := tsg.ForMCPRequest(MCPMethodToolsList, "") + reg := NewBuilder().SetTools(tools).SetResources(resources).SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodToolsList, "") // tools/list should return all tools, no resources or prompts if len(filtered.AvailableTools(context.Background())) != 2 { @@ -705,8 +677,8 @@ func TestForMCPRequest_ToolsCall(t *testing.T) { mockTool("list_repos", "repos", true), } - tsg := NewRegistry().SetTools(tools) - filtered := tsg.ForMCPRequest(MCPMethodToolsCall, "get_me") + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodToolsCall, "get_me") available := filtered.AvailableTools(context.Background()) if len(available) != 1 { @@ -722,8 +694,8 @@ func TestForMCPRequest_ToolsCall_NotFound(t *testing.T) { mockTool("get_me", "context", true), } - tsg := NewRegistry().SetTools(tools) - filtered := tsg.ForMCPRequest(MCPMethodToolsCall, "nonexistent") + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodToolsCall, "nonexistent") if len(filtered.AvailableTools(context.Background())) != 0 { t.Errorf("Expected 0 tools for nonexistent tool, got %d", len(filtered.AvailableTools(context.Background()))) @@ -736,13 +708,14 @@ func TestForMCPRequest_ToolsCall_DeprecatedAlias(t *testing.T) { mockTool("list_commits", "repos", true), } - tsg := NewRegistry().SetTools(tools). - WithDeprecatedToolAliases(map[string]string{ + reg := NewBuilder().SetTools(tools). + WithToolsets([]string{"all"}). + WithDeprecatedAliases(map[string]string{ "old_get_me": "get_me", - }) + }).Build() // Request using the deprecated alias - filtered := tsg.ForMCPRequest(MCPMethodToolsCall, "old_get_me") + filtered := reg.ForMCPRequest(MCPMethodToolsCall, "old_get_me") available := filtered.AvailableTools(context.Background()) if len(available) != 1 { @@ -758,9 +731,9 @@ func TestForMCPRequest_ToolsCall_RespectsFilters(t *testing.T) { mockTool("create_issue", "issues", false), // write tool } - tsg := NewRegistry().SetTools(tools) - // Apply read-only filter, then ForMCPRequest - filtered := tsg.WithReadOnly(true).ForMCPRequest(MCPMethodToolsCall, "create_issue") + // Apply read-only filter at build time, then ForMCPRequest + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithReadOnly(true).Build() + filtered := reg.ForMCPRequest(MCPMethodToolsCall, "create_issue") // The tool exists in the filtered group, but AvailableTools respects read-only available := filtered.AvailableTools(context.Background()) @@ -781,8 +754,8 @@ func TestForMCPRequest_ResourcesList(t *testing.T) { mockPrompt("prompt1", "repos"), } - tsg := NewRegistry().SetTools(tools).SetResources(resources).SetPrompts(prompts) - filtered := tsg.ForMCPRequest(MCPMethodResourcesList, "") + reg := NewBuilder().SetTools(tools).SetResources(resources).SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodResourcesList, "") if len(filtered.AvailableTools(context.Background())) != 0 { t.Errorf("Expected 0 tools for resources/list, got %d", len(filtered.AvailableTools(context.Background()))) @@ -801,8 +774,8 @@ func TestForMCPRequest_ResourcesRead(t *testing.T) { mockResource("res2", "repos", "branch://{owner}/{repo}/{branch}"), } - tsg := NewRegistry().SetResources(resources) - filtered := tsg.ForMCPRequest(MCPMethodResourcesRead, "repo://{owner}/{repo}") + reg := NewBuilder().SetResources(resources).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodResourcesRead, "repo://{owner}/{repo}") available := filtered.AvailableResourceTemplates(context.Background()) if len(available) != 1 { @@ -825,8 +798,8 @@ func TestForMCPRequest_PromptsList(t *testing.T) { mockPrompt("prompt2", "issues"), } - tsg := NewRegistry().SetTools(tools).SetResources(resources).SetPrompts(prompts) - filtered := tsg.ForMCPRequest(MCPMethodPromptsList, "") + reg := NewBuilder().SetTools(tools).SetResources(resources).SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodPromptsList, "") if len(filtered.AvailableTools(context.Background())) != 0 { t.Errorf("Expected 0 tools for prompts/list, got %d", len(filtered.AvailableTools(context.Background()))) @@ -845,8 +818,8 @@ func TestForMCPRequest_PromptsGet(t *testing.T) { mockPrompt("prompt2", "issues"), } - tsg := NewRegistry().SetPrompts(prompts) - filtered := tsg.ForMCPRequest(MCPMethodPromptsGet, "prompt1") + reg := NewBuilder().SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodPromptsGet, "prompt1") available := filtered.AvailablePrompts(context.Background()) if len(available) != 1 { @@ -868,8 +841,8 @@ func TestForMCPRequest_UnknownMethod(t *testing.T) { mockPrompt("prompt1", "repos"), } - tsg := NewRegistry().SetTools(tools).SetResources(resources).SetPrompts(prompts) - filtered := tsg.ForMCPRequest("unknown/method", "") + reg := NewBuilder().SetTools(tools).SetResources(resources).SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest("unknown/method", "") // Unknown methods should return empty if len(filtered.AvailableTools(context.Background())) != 0 { @@ -883,7 +856,7 @@ func TestForMCPRequest_UnknownMethod(t *testing.T) { } } -func TestForMCPRequest_Immutability(t *testing.T) { +func TestForMCPRequest_DoesNotMutateOriginal(t *testing.T) { tools := []ServerTool{ mockTool("tool1", "repos", true), mockTool("tool2", "issues", true), @@ -895,7 +868,7 @@ func TestForMCPRequest_Immutability(t *testing.T) { mockPrompt("prompt1", "repos"), } - original := NewRegistry().SetTools(tools).SetResources(resources).SetPrompts(prompts) + original := NewBuilder().SetTools(tools).SetResources(resources).SetPrompts(prompts).WithToolsets([]string{"all"}).Build() filtered := original.ForMCPRequest(MCPMethodToolsCall, "tool1") // Original should be unchanged @@ -929,13 +902,12 @@ func TestForMCPRequest_ChainedWithOtherFilters(t *testing.T) { mockToolWithDefault("delete_repo", "repos", false, true), // default but write } - tsg := NewRegistry().SetTools(tools) - // Chain: default toolsets -> read-only -> specific method - filtered := tsg. + reg := NewBuilder().SetTools(tools). WithToolsets([]string{"default"}). WithReadOnly(true). - ForMCPRequest(MCPMethodToolsList, "") + Build() + filtered := reg.ForMCPRequest(MCPMethodToolsList, "") available := filtered.AvailableTools(context.Background()) @@ -972,8 +944,8 @@ func TestForMCPRequest_ResourcesTemplatesList(t *testing.T) { mockResource("res1", "repos", "repo://{owner}/{repo}"), } - tsg := NewRegistry().SetTools(tools).SetResources(resources) - filtered := tsg.ForMCPRequest(MCPMethodResourcesTemplatesList, "") + reg := NewBuilder().SetTools(tools).SetResources(resources).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodResourcesTemplatesList, "") // Same behavior as resources/list if len(filtered.AvailableTools(context.Background())) != 0 { @@ -1021,10 +993,9 @@ func TestFeatureFlagEnable(t *testing.T) { mockToolWithFlags("needs_flag", "toolset1", true, "my_feature", ""), } - tsg := NewRegistry().SetTools(tools) - // Without feature checker, tool with FeatureFlagEnable should be excluded - available := tsg.AvailableTools(context.Background()) + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + available := reg.AvailableTools(context.Background()) if len(available) != 1 { t.Fatalf("Expected 1 tool without feature checker, got %d", len(available)) } @@ -1034,8 +1005,8 @@ func TestFeatureFlagEnable(t *testing.T) { // With feature checker returning false, tool should still be excluded checkerFalse := func(_ context.Context, _ string) (bool, error) { return false, nil } - filteredFalse := tsg.WithFeatureChecker(checkerFalse) - availableFalse := filteredFalse.AvailableTools(context.Background()) + regFalse := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithFeatureChecker(checkerFalse).Build() + availableFalse := regFalse.AvailableTools(context.Background()) if len(availableFalse) != 1 { t.Fatalf("Expected 1 tool with false checker, got %d", len(availableFalse)) } @@ -1044,8 +1015,8 @@ func TestFeatureFlagEnable(t *testing.T) { checkerTrue := func(_ context.Context, flag string) (bool, error) { return flag == "my_feature", nil } - filteredTrue := tsg.WithFeatureChecker(checkerTrue) - availableTrue := filteredTrue.AvailableTools(context.Background()) + regTrue := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithFeatureChecker(checkerTrue).Build() + availableTrue := regTrue.AvailableTools(context.Background()) if len(availableTrue) != 2 { t.Fatalf("Expected 2 tools with true checker, got %d", len(availableTrue)) } @@ -1057,10 +1028,9 @@ func TestFeatureFlagDisable(t *testing.T) { mockToolWithFlags("disabled_by_flag", "toolset1", true, "", "kill_switch"), } - tsg := NewRegistry().SetTools(tools) - // Without feature checker, tool with FeatureFlagDisable should be included (flag is false) - available := tsg.AvailableTools(context.Background()) + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + available := reg.AvailableTools(context.Background()) if len(available) != 2 { t.Fatalf("Expected 2 tools without feature checker, got %d", len(available)) } @@ -1069,8 +1039,8 @@ func TestFeatureFlagDisable(t *testing.T) { checkerTrue := func(_ context.Context, flag string) (bool, error) { return flag == "kill_switch", nil } - filtered := tsg.WithFeatureChecker(checkerTrue) - availableFiltered := filtered.AvailableTools(context.Background()) + regFiltered := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithFeatureChecker(checkerTrue).Build() + availableFiltered := regFiltered.AvailableTools(context.Background()) if len(availableFiltered) != 1 { t.Fatalf("Expected 1 tool with kill_switch enabled, got %d", len(availableFiltered)) } @@ -1085,23 +1055,24 @@ func TestFeatureFlagBoth(t *testing.T) { mockToolWithFlags("complex_tool", "toolset1", true, "new_feature", "kill_switch"), } - tsg := NewRegistry().SetTools(tools) - // Enable flag not set -> excluded checker1 := func(_ context.Context, _ string) (bool, error) { return false, nil } - if len(tsg.WithFeatureChecker(checker1).AvailableTools(context.Background())) != 0 { + reg1 := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithFeatureChecker(checker1).Build() + if len(reg1.AvailableTools(context.Background())) != 0 { t.Error("Tool should be excluded when enable flag is false") } // Enable flag set, disable flag not set -> included checker2 := func(_ context.Context, flag string) (bool, error) { return flag == "new_feature", nil } - if len(tsg.WithFeatureChecker(checker2).AvailableTools(context.Background())) != 1 { + reg2 := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithFeatureChecker(checker2).Build() + if len(reg2.AvailableTools(context.Background())) != 1 { t.Error("Tool should be included when enable flag is true and disable flag is false") } // Enable flag set, disable flag also set -> excluded (disable wins) checker3 := func(_ context.Context, _ string) (bool, error) { return true, nil } - if len(tsg.WithFeatureChecker(checker3).AvailableTools(context.Background())) != 0 { + reg3 := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithFeatureChecker(checker3).Build() + if len(reg3.AvailableTools(context.Background())) != 0 { t.Error("Tool should be excluded when both flags are true (disable wins)") } } @@ -1111,14 +1082,12 @@ func TestFeatureFlagError(t *testing.T) { mockToolWithFlags("needs_flag", "toolset1", true, "my_feature", ""), } - tsg := NewRegistry().SetTools(tools) - // Checker that returns error should treat as false (tool excluded) checkerError := func(_ context.Context, _ string) (bool, error) { return false, fmt.Errorf("simulated error") } - filtered := tsg.WithFeatureChecker(checkerError) - available := filtered.AvailableTools(context.Background()) + reg := NewBuilder().SetTools(tools).WithFeatureChecker(checkerError).Build() + available := reg.AvailableTools(context.Background()) if len(available) != 0 { t.Errorf("Expected 0 tools when checker errors, got %d", len(available)) } @@ -1134,19 +1103,18 @@ func TestFeatureFlagResources(t *testing.T) { }, } - tsg := NewRegistry().SetResources(resources) - // Without checker, resource with enable flag should be excluded - available := tsg.AvailableResourceTemplates(context.Background()) + reg := NewBuilder().SetResources(resources).WithToolsets([]string{"all"}).Build() + available := reg.AvailableResourceTemplates(context.Background()) if len(available) != 1 { t.Fatalf("Expected 1 resource without checker, got %d", len(available)) } // With checker returning true, both should be included checker := func(_ context.Context, _ string) (bool, error) { return true, nil } - filtered := tsg.WithFeatureChecker(checker) - if len(filtered.AvailableResourceTemplates(context.Background())) != 2 { - t.Errorf("Expected 2 resources with checker, got %d", len(filtered.AvailableResourceTemplates(context.Background()))) + regWithChecker := NewBuilder().SetResources(resources).WithToolsets([]string{"all"}).WithFeatureChecker(checker).Build() + if len(regWithChecker.AvailableResourceTemplates(context.Background())) != 2 { + t.Errorf("Expected 2 resources with checker, got %d", len(regWithChecker.AvailableResourceTemplates(context.Background()))) } } @@ -1160,19 +1128,18 @@ func TestFeatureFlagPrompts(t *testing.T) { }, } - tsg := NewRegistry().SetPrompts(prompts) - // Without checker, prompt with enable flag should be excluded - available := tsg.AvailablePrompts(context.Background()) + reg := NewBuilder().SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + available := reg.AvailablePrompts(context.Background()) if len(available) != 1 { t.Fatalf("Expected 1 prompt without checker, got %d", len(available)) } // With checker returning true, both should be included checker := func(_ context.Context, _ string) (bool, error) { return true, nil } - filtered := tsg.WithFeatureChecker(checker) - if len(filtered.AvailablePrompts(context.Background())) != 2 { - t.Errorf("Expected 2 prompts with checker, got %d", len(filtered.AvailablePrompts(context.Background()))) + regWithChecker := NewBuilder().SetPrompts(prompts).WithToolsets([]string{"all"}).WithFeatureChecker(checker).Build() + if len(regWithChecker.AvailablePrompts(context.Background())) != 2 { + t.Errorf("Expected 2 prompts with checker, got %d", len(regWithChecker.AvailablePrompts(context.Background()))) } } @@ -1208,60 +1175,471 @@ func TestServerToolHandlerPanicOnNil(t *testing.T) { tool.Handler(nil) } -// TestRegistryCopyCopiesAllFields ensures the copy() method stays in sync with the struct. -// If you add a new field to Registry, this test will fail until you update copy(). -func TestRegistryCopyCopiesAllFields(t *testing.T) { - // Create a Registry with non-zero/non-nil values for ALL fields - original := &Registry{ - tools: []ServerTool{mockTool("t1", "ts1", true)}, - resourceTemplates: []ServerResourceTemplate{{Template: mcp.ResourceTemplate{Name: "r1"}}}, - prompts: []ServerPrompt{{Prompt: mcp.Prompt{Name: "p1"}}}, - deprecatedAliases: map[string]string{"old": "new"}, - readOnly: true, - enabledToolsets: map[ToolsetID]bool{"ts1": true}, - additionalTools: map[string]bool{"extra": true}, - featureChecker: func(_ context.Context, _ string) (bool, error) { return true, nil }, - unrecognizedToolsets: []string{"unknown"}, +// Tests for Enabled function on ServerTool +func TestServerToolEnabled(t *testing.T) { + tests := []struct { + name string + enabledFunc func(ctx context.Context) (bool, error) + expectedCount int + expectInResult bool + }{ + { + name: "nil Enabled function - tool included", + enabledFunc: nil, + expectedCount: 1, + expectInResult: true, + }, + { + name: "Enabled returns true - tool included", + enabledFunc: func(_ context.Context) (bool, error) { + return true, nil + }, + expectedCount: 1, + expectInResult: true, + }, + { + name: "Enabled returns false - tool excluded", + enabledFunc: func(_ context.Context) (bool, error) { + return false, nil + }, + expectedCount: 0, + expectInResult: false, + }, + { + name: "Enabled returns error - tool excluded", + enabledFunc: func(_ context.Context) (bool, error) { + return false, fmt.Errorf("simulated error") + }, + expectedCount: 0, + expectInResult: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tool := mockTool("test_tool", "toolset1", true) + tool.Enabled = tt.enabledFunc + + reg := NewBuilder().SetTools([]ServerTool{tool}).WithToolsets([]string{"all"}).Build() + available := reg.AvailableTools(context.Background()) + + if len(available) != tt.expectedCount { + t.Errorf("Expected %d tools, got %d", tt.expectedCount, len(available)) + } + + found := false + for _, t := range available { + if t.Tool.Name == "test_tool" { + found = true + break + } + } + if found != tt.expectInResult { + t.Errorf("Expected tool in result: %v, got: %v", tt.expectInResult, found) + } + }) + } +} + +func TestServerToolEnabledWithContext(t *testing.T) { + type contextKey string + const userKey contextKey = "user" + + // Tool that checks context for user + tool := mockTool("context_aware_tool", "toolset1", true) + tool.Enabled = func(ctx context.Context) (bool, error) { + user := ctx.Value(userKey) + return user != nil && user.(string) == "authorized", nil + } + + reg := NewBuilder().SetTools([]ServerTool{tool}).WithToolsets([]string{"all"}).Build() + + // Without user in context - tool should be excluded + available := reg.AvailableTools(context.Background()) + if len(available) != 0 { + t.Errorf("Expected 0 tools without user, got %d", len(available)) + } + + // With authorized user - tool should be included + ctxWithUser := context.WithValue(context.Background(), userKey, "authorized") + availableWithUser := reg.AvailableTools(ctxWithUser) + if len(availableWithUser) != 1 { + t.Errorf("Expected 1 tool with authorized user, got %d", len(availableWithUser)) + } + + // With unauthorized user - tool should be excluded + ctxWithBadUser := context.WithValue(context.Background(), userKey, "unauthorized") + availableWithBadUser := reg.AvailableTools(ctxWithBadUser) + if len(availableWithBadUser) != 0 { + t.Errorf("Expected 0 tools with unauthorized user, got %d", len(availableWithBadUser)) + } +} + +// Tests for WithFilter builder method +func TestBuilderWithFilter(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", true), + mockTool("tool3", "toolset1", true), + } + + // Filter that excludes tool2 + filter := func(_ context.Context, tool *ServerTool) (bool, error) { + return tool.Tool.Name != "tool2", nil } - copied := original.copy() + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}). + WithFilter(filter). + Build() - // Verify all fields are copied correctly - if len(copied.tools) != len(original.tools) || (len(copied.tools) > 0 && copied.tools[0].Tool.Name != original.tools[0].Tool.Name) { - t.Error("tools not copied correctly") + available := reg.AvailableTools(context.Background()) + if len(available) != 2 { + t.Fatalf("Expected 2 tools after filter, got %d", len(available)) } - if len(copied.resourceTemplates) != len(original.resourceTemplates) { - t.Error("resourceTemplates not copied correctly") + + for _, tool := range available { + if tool.Tool.Name == "tool2" { + t.Error("tool2 should have been filtered out") + } } - if len(copied.prompts) != len(original.prompts) { - t.Error("prompts not copied correctly") +} + +func TestBuilderWithMultipleFilters(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", true), + mockTool("tool3", "toolset1", true), + mockTool("tool4", "toolset1", true), } - if len(copied.deprecatedAliases) != len(original.deprecatedAliases) || copied.deprecatedAliases["old"] != "new" { - t.Error("deprecatedAliases not copied correctly") + + // First filter excludes tool2 + filter1 := func(_ context.Context, tool *ServerTool) (bool, error) { + return tool.Tool.Name != "tool2", nil } - if copied.readOnly != original.readOnly { - t.Error("readOnly not copied correctly") + + // Second filter excludes tool3 + filter2 := func(_ context.Context, tool *ServerTool) (bool, error) { + return tool.Tool.Name != "tool3", nil } - if len(copied.enabledToolsets) != len(original.enabledToolsets) || !copied.enabledToolsets["ts1"] { - t.Error("enabledToolsets not copied correctly") + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}). + WithFilter(filter1). + WithFilter(filter2). + Build() + + available := reg.AvailableTools(context.Background()) + if len(available) != 2 { + t.Fatalf("Expected 2 tools after multiple filters, got %d", len(available)) } - if len(copied.additionalTools) != len(original.additionalTools) || !copied.additionalTools["extra"] { - t.Error("additionalTools not copied correctly") + + toolNames := make(map[string]bool) + for _, tool := range available { + toolNames[tool.Tool.Name] = true } - if copied.featureChecker == nil { - t.Error("featureChecker not copied correctly") + + if !toolNames["tool1"] || !toolNames["tool4"] { + t.Error("Expected tool1 and tool4 to be available") } - if len(copied.unrecognizedToolsets) != len(original.unrecognizedToolsets) || copied.unrecognizedToolsets[0] != "unknown" { - t.Error("unrecognizedToolsets not copied correctly") + if toolNames["tool2"] || toolNames["tool3"] { + t.Error("tool2 and tool3 should have been filtered out") } +} - // Verify maps are deep copied (mutations don't affect original) - copied.enabledToolsets["ts2"] = true - if original.enabledToolsets["ts2"] { - t.Error("enabledToolsets should be deep copied, not shared") +func TestBuilderFilterError(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), } - copied.additionalTools["another"] = true - if original.additionalTools["another"] { - t.Error("additionalTools should be deep copied, not shared") + + // Filter that returns an error + filter := func(_ context.Context, _ *ServerTool) (bool, error) { + return false, fmt.Errorf("filter error") + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}). + WithFilter(filter). + Build() + + available := reg.AvailableTools(context.Background()) + if len(available) != 0 { + t.Errorf("Expected 0 tools when filter returns error, got %d", len(available)) + } +} + +func TestBuilderFilterWithContext(t *testing.T) { + type contextKey string + const scopeKey contextKey = "scope" + + tools := []ServerTool{ + mockTool("public_tool", "toolset1", true), + mockTool("private_tool", "toolset1", true), + } + + // Filter that checks context for scope + filter := func(ctx context.Context, tool *ServerTool) (bool, error) { + scope := ctx.Value(scopeKey) + if scope == "public" && tool.Tool.Name == "private_tool" { + return false, nil + } + return true, nil + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}). + WithFilter(filter). + Build() + + // With public scope - private_tool should be excluded + ctxPublic := context.WithValue(context.Background(), scopeKey, "public") + availablePublic := reg.AvailableTools(ctxPublic) + if len(availablePublic) != 1 { + t.Fatalf("Expected 1 tool with public scope, got %d", len(availablePublic)) + } + if availablePublic[0].Tool.Name != "public_tool" { + t.Error("Expected only public_tool to be available") + } + + // With private scope - both tools should be available + ctxPrivate := context.WithValue(context.Background(), scopeKey, "private") + availablePrivate := reg.AvailableTools(ctxPrivate) + if len(availablePrivate) != 2 { + t.Errorf("Expected 2 tools with private scope, got %d", len(availablePrivate)) + } +} + +// Tests for interaction between Enabled, feature flags, and filters +func TestEnabledAndFeatureFlagInteraction(t *testing.T) { + // Tool with both Enabled function and feature flag + tool := mockToolWithFlags("complex_tool", "toolset1", true, "my_feature", "") + tool.Enabled = func(_ context.Context) (bool, error) { + return true, nil + } + + // Feature flag not enabled - tool should be excluded despite Enabled returning true + reg1 := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + Build() + available1 := reg1.AvailableTools(context.Background()) + if len(available1) != 0 { + t.Error("Tool should be excluded when feature flag is not enabled") + } + + // Feature flag enabled - tool should be included + checker := func(_ context.Context, flag string) (bool, error) { + return flag == "my_feature", nil + } + reg2 := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithFeatureChecker(checker). + Build() + available2 := reg2.AvailableTools(context.Background()) + if len(available2) != 1 { + t.Error("Tool should be included when both Enabled and feature flag pass") + } + + // Enabled returns false - tool should be excluded despite feature flag + tool.Enabled = func(_ context.Context) (bool, error) { + return false, nil + } + reg3 := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithFeatureChecker(checker). + Build() + available3 := reg3.AvailableTools(context.Background()) + if len(available3) != 0 { + t.Error("Tool should be excluded when Enabled returns false") + } +} + +func TestEnabledAndBuilderFilterInteraction(t *testing.T) { + tool := mockTool("test_tool", "toolset1", true) + tool.Enabled = func(_ context.Context) (bool, error) { + return true, nil + } + + // Filter that excludes the tool + filter := func(_ context.Context, _ *ServerTool) (bool, error) { + return false, nil + } + + reg := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithFilter(filter). + Build() + + available := reg.AvailableTools(context.Background()) + if len(available) != 0 { + t.Error("Tool should be excluded when filter returns false, despite Enabled returning true") + } +} + +func TestAllFiltersInteraction(t *testing.T) { + // Tool with Enabled, feature flag, and subject to builder filter + tool := mockToolWithFlags("complex_tool", "toolset1", true, "my_feature", "") + tool.Enabled = func(_ context.Context) (bool, error) { + return true, nil + } + + filter := func(_ context.Context, _ *ServerTool) (bool, error) { + return true, nil + } + + checker := func(_ context.Context, flag string) (bool, error) { + return flag == "my_feature", nil + } + + // All conditions pass - tool should be included + reg := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithFeatureChecker(checker). + WithFilter(filter). + Build() + + available := reg.AvailableTools(context.Background()) + if len(available) != 1 { + t.Error("Tool should be included when all filters pass") + } + + // Change filter to return false - tool should be excluded + filterFalse := func(_ context.Context, _ *ServerTool) (bool, error) { + return false, nil + } + + reg2 := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithFeatureChecker(checker). + WithFilter(filterFalse). + Build() + + available2 := reg2.AvailableTools(context.Background()) + if len(available2) != 0 { + t.Error("Tool should be excluded when any filter fails") + } +} + +// Test FilteredTools method +func TestFilteredTools(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", true), + } + + filter := func(_ context.Context, tool *ServerTool) (bool, error) { + return tool.Tool.Name == "tool1", nil + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}). + WithFilter(filter). + Build() + + filtered, err := reg.FilteredTools(context.Background()) + if err != nil { + t.Fatalf("FilteredTools returned error: %v", err) + } + + if len(filtered) != 1 { + t.Fatalf("Expected 1 filtered tool, got %d", len(filtered)) + } + + if filtered[0].Tool.Name != "tool1" { + t.Errorf("Expected tool1, got %s", filtered[0].Tool.Name) + } +} + +func TestFilteredToolsMatchesAvailableTools(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", false), + mockTool("tool3", "toolset2", true), + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"toolset1"}). + WithReadOnly(true). + Build() + + ctx := context.Background() + filtered, err := reg.FilteredTools(ctx) + if err != nil { + t.Fatalf("FilteredTools returned error: %v", err) + } + + available := reg.AvailableTools(ctx) + + // Both methods should return the same results + if len(filtered) != len(available) { + t.Errorf("FilteredTools and AvailableTools returned different counts: %d vs %d", + len(filtered), len(available)) + } + + for i := range filtered { + if filtered[i].Tool.Name != available[i].Tool.Name { + t.Errorf("Tool at index %d differs: FilteredTools=%s, AvailableTools=%s", + i, filtered[i].Tool.Name, available[i].Tool.Name) + } + } +} + +func TestFilteringOrder(t *testing.T) { + // Test that filters are applied in the correct order: + // 1. Tool.Enabled + // 2. Feature flags + // 3. Read-only + // 4. Builder filters + // 5. Toolset/additional tools + + callOrder := []string{} + + tool := mockToolWithFlags("test_tool", "toolset1", false, "my_feature", "") + tool.Enabled = func(_ context.Context) (bool, error) { + callOrder = append(callOrder, "Enabled") + return true, nil + } + + filter := func(_ context.Context, _ *ServerTool) (bool, error) { + callOrder = append(callOrder, "Filter") + return true, nil + } + + checker := func(_ context.Context, _ string) (bool, error) { + callOrder = append(callOrder, "FeatureFlag") + return true, nil + } + + reg := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithReadOnly(true). // This will exclude the tool (it's not read-only) + WithFeatureChecker(checker). + WithFilter(filter). + Build() + + _ = reg.AvailableTools(context.Background()) + + // Expected order: Enabled, FeatureFlag, ReadOnly (stops here because it's write tool) + expectedOrder := []string{"Enabled", "FeatureFlag"} + if len(callOrder) != len(expectedOrder) { + t.Errorf("Expected %d checks, got %d: %v", len(expectedOrder), len(callOrder), callOrder) + } + + for i, expected := range expectedOrder { + if i >= len(callOrder) || callOrder[i] != expected { + t.Errorf("At position %d: expected %s, got %v", i, expected, callOrder) + } } } diff --git a/pkg/registry/resources.go b/pkg/registry/resources.go new file mode 100644 index 000000000..99e0240c5 --- /dev/null +++ b/pkg/registry/resources.go @@ -0,0 +1,48 @@ +package registry + +import "github.com/modelcontextprotocol/go-sdk/mcp" + +// ResourceHandlerFunc is a function that takes dependencies and returns an MCP resource handler. +// This allows resources to be defined statically while their handlers are generated +// on-demand with the appropriate dependencies. +type ResourceHandlerFunc func(deps any) mcp.ResourceHandler + +// ServerResourceTemplate pairs a resource template with its toolset metadata. +type ServerResourceTemplate struct { + Template mcp.ResourceTemplate + // HandlerFunc generates the handler when given dependencies. + // This allows resources to be passed around without handlers being set up, + // and handlers are only created when needed. + HandlerFunc ResourceHandlerFunc + // Toolset identifies which toolset this resource belongs to + Toolset ToolsetMetadata + // FeatureFlagEnable specifies a feature flag that must be enabled for this resource + // to be available. If set and the flag is not enabled, the resource is omitted. + FeatureFlagEnable string + // FeatureFlagDisable specifies a feature flag that, when enabled, causes this resource + // to be omitted. Used to disable resources when a feature flag is on. + FeatureFlagDisable string +} + +// HasHandler returns true if this resource has a handler function. +func (sr *ServerResourceTemplate) HasHandler() bool { + return sr.HandlerFunc != nil +} + +// Handler returns a resource handler by calling HandlerFunc with the given dependencies. +// Panics if HandlerFunc is nil - all resources should have handlers. +func (sr *ServerResourceTemplate) Handler(deps any) mcp.ResourceHandler { + if sr.HandlerFunc == nil { + panic("HandlerFunc is nil for resource: " + sr.Template.Name) + } + return sr.HandlerFunc(deps) +} + +// NewServerResourceTemplate creates a new ServerResourceTemplate with toolset metadata. +func NewServerResourceTemplate(toolset ToolsetMetadata, resourceTemplate mcp.ResourceTemplate, handlerFn ResourceHandlerFunc) ServerResourceTemplate { + return ServerResourceTemplate{ + Template: resourceTemplate, + HandlerFunc: handlerFn, + Toolset: toolset, + } +} diff --git a/pkg/toolsets/server_tool.go b/pkg/registry/server_tool.go similarity index 90% rename from pkg/toolsets/server_tool.go rename to pkg/registry/server_tool.go index eb30f01f4..3145b693d 100644 --- a/pkg/toolsets/server_tool.go +++ b/pkg/registry/server_tool.go @@ -1,4 +1,4 @@ -package toolsets +package registry import ( "context" @@ -52,6 +52,13 @@ type ServerTool struct { // FeatureFlagDisable specifies a feature flag that, when enabled, causes this tool // to be omitted. Used to disable tools when a feature flag is on. FeatureFlagDisable string + + // Enabled is an optional function called at build/filter time to determine + // if this tool should be available. If nil, the tool is considered enabled + // (subject to FeatureFlagEnable/FeatureFlagDisable checks). + // The context carries request-scoped information for the consumer to use. + // Returns (enabled, error). On error, the tool should be treated as disabled. + Enabled func(ctx context.Context) (bool, error) } // IsReadOnly returns true if this tool is marked as read-only via annotations. diff --git a/pkg/toolsets/toolsets.go b/pkg/toolsets/toolsets.go deleted file mode 100644 index 34e5fa923..000000000 --- a/pkg/toolsets/toolsets.go +++ /dev/null @@ -1,866 +0,0 @@ -package toolsets - -import ( - "context" - "fmt" - "os" - "slices" - "sort" - "strings" - - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -type ToolsetDoesNotExistError struct { - Name string -} - -func (e *ToolsetDoesNotExistError) Error() string { - return fmt.Sprintf("toolset %s does not exist", e.Name) -} - -func (e *ToolsetDoesNotExistError) Is(target error) bool { - if target == nil { - return false - } - if _, ok := target.(*ToolsetDoesNotExistError); ok { - return true - } - return false -} - -func NewToolsetDoesNotExistError(name string) *ToolsetDoesNotExistError { - return &ToolsetDoesNotExistError{Name: name} -} - -// ToolDoesNotExistError is returned when a tool is not found. -type ToolDoesNotExistError struct { - Name string -} - -func (e *ToolDoesNotExistError) Error() string { - return fmt.Sprintf("tool %s does not exist", e.Name) -} - -// NewToolDoesNotExistError creates a new ToolDoesNotExistError. -func NewToolDoesNotExistError(name string) *ToolDoesNotExistError { - return &ToolDoesNotExistError{Name: name} -} - -// ServerTool is defined in server_tool.go - -// ResourceHandlerFunc is a function that takes dependencies and returns an MCP resource handler. -// This allows resources to be defined statically while their handlers are generated -// on-demand with the appropriate dependencies. -type ResourceHandlerFunc func(deps any) mcp.ResourceHandler - -// ServerResourceTemplate pairs a resource template with its toolset metadata. -type ServerResourceTemplate struct { - Template mcp.ResourceTemplate - // HandlerFunc generates the handler when given dependencies. - // This allows resources to be passed around without handlers being set up, - // and handlers are only created when needed. - HandlerFunc ResourceHandlerFunc - // Toolset identifies which toolset this resource belongs to - Toolset ToolsetMetadata - // FeatureFlagEnable specifies a feature flag that must be enabled for this resource - // to be available. If set and the flag is not enabled, the resource is omitted. - FeatureFlagEnable string - // FeatureFlagDisable specifies a feature flag that, when enabled, causes this resource - // to be omitted. Used to disable resources when a feature flag is on. - FeatureFlagDisable string -} - -// HasHandler returns true if this resource has a handler function. -func (sr *ServerResourceTemplate) HasHandler() bool { - return sr.HandlerFunc != nil -} - -// Handler returns a resource handler by calling HandlerFunc with the given dependencies. -// Panics if HandlerFunc is nil - all resources should have handlers. -func (sr *ServerResourceTemplate) Handler(deps any) mcp.ResourceHandler { - if sr.HandlerFunc == nil { - panic("HandlerFunc is nil for resource: " + sr.Template.Name) - } - return sr.HandlerFunc(deps) -} - -// NewServerResourceTemplate creates a new ServerResourceTemplate with toolset metadata. -func NewServerResourceTemplate(toolset ToolsetMetadata, resourceTemplate mcp.ResourceTemplate, handlerFn ResourceHandlerFunc) ServerResourceTemplate { - return ServerResourceTemplate{ - Template: resourceTemplate, - HandlerFunc: handlerFn, - Toolset: toolset, - } -} - -// ServerPrompt pairs a prompt with its toolset metadata. -type ServerPrompt struct { - Prompt mcp.Prompt - Handler mcp.PromptHandler - // Toolset identifies which toolset this prompt belongs to - Toolset ToolsetMetadata - // FeatureFlagEnable specifies a feature flag that must be enabled for this prompt - // to be available. If set and the flag is not enabled, the prompt is omitted. - FeatureFlagEnable string - // FeatureFlagDisable specifies a feature flag that, when enabled, causes this prompt - // to be omitted. Used to disable prompts when a feature flag is on. - FeatureFlagDisable string -} - -// NewServerPrompt creates a new ServerPrompt with toolset metadata. -func NewServerPrompt(toolset ToolsetMetadata, prompt mcp.Prompt, handler mcp.PromptHandler) ServerPrompt { - return ServerPrompt{ - Prompt: prompt, - Handler: handler, - Toolset: toolset, - } -} - -// Registry holds a collection of tools, resources, and prompts. -// It supports immutable filtering operations that return new Registrys -// without modifying the original. This design allows for: -// - Building a full set of tools/resources/prompts once -// - Applying filters (read-only, feature flags, enabled toolsets) without mutation -// - Deterministic ordering for documentation generation -// - Lazy dependency injection only when registering with a server -type Registry struct { - // tools holds all tools in this group - tools []ServerTool - // resourceTemplates holds all resource templates in this group - resourceTemplates []ServerResourceTemplate - // prompts holds all prompts in this group - prompts []ServerPrompt - // deprecatedAliases maps old tool names to new canonical names - deprecatedAliases map[string]string - - // Filters - these control what's returned by Available* methods - // readOnly when true filters out write tools - readOnly bool - // enabledToolsets when non-nil, only include tools/resources/prompts from these toolsets - // when nil, all toolsets are enabled - enabledToolsets map[ToolsetID]bool - // additionalTools are specific tools that bypass toolset filtering (but still respect read-only) - // These are additive - a tool is included if it matches toolset filters OR is in this set - additionalTools map[string]bool - // featureChecker when non-nil, checks if a feature flag is enabled. - // Takes context and flag name, returns (enabled, error). If error, log and treat as false. - // If checker is nil, all flag checks return false. - featureChecker FeatureFlagChecker - // unrecognizedToolsets holds toolset IDs that were requested but don't match any registered toolsets - unrecognizedToolsets []string -} - -// FeatureFlagChecker is a function that checks if a feature flag is enabled. -// The context can be used to extract actor/user information for flag evaluation. -// Returns (enabled, error). If error occurs, the caller should log and treat as false. -type FeatureFlagChecker func(ctx context.Context, flagName string) (bool, error) - -// NewRegistry creates a new empty Registry. -// Use SetTools, SetResources, SetPrompts to populate it. -func NewRegistry() *Registry { - return &Registry{ - deprecatedAliases: make(map[string]string), - } -} - -// SetTools sets the tools for this group. Returns self for chaining. -func (r *Registry) SetTools(tools []ServerTool) *Registry { - r.tools = tools - return r -} - -// SetResources sets the resource templates for this group. Returns self for chaining. -func (r *Registry) SetResources(resources []ServerResourceTemplate) *Registry { - r.resourceTemplates = resources - return r -} - -// SetPrompts sets the prompts for this group. Returns self for chaining. -func (r *Registry) SetPrompts(prompts []ServerPrompt) *Registry { - r.prompts = prompts - return r -} - -// copy creates a shallow copy of the Registry for immutable operations. -func (r *Registry) copy() *Registry { - newTG := &Registry{ - tools: r.tools, // slices are shared (immutable) - resourceTemplates: r.resourceTemplates, - prompts: r.prompts, - deprecatedAliases: r.deprecatedAliases, - readOnly: r.readOnly, - featureChecker: r.featureChecker, - } - - // Copy maps if they exist - if r.enabledToolsets != nil { - newTG.enabledToolsets = make(map[ToolsetID]bool, len(r.enabledToolsets)) - for k, v := range r.enabledToolsets { - newTG.enabledToolsets[k] = v - } - } - if r.additionalTools != nil { - newTG.additionalTools = make(map[string]bool, len(r.additionalTools)) - for k, v := range r.additionalTools { - newTG.additionalTools[k] = v - } - } - newTG.unrecognizedToolsets = r.unrecognizedToolsets - - return newTG -} - -// WithReadOnly returns a new Registry with read-only mode set. -// When true, write tools are filtered out from Available* methods. -func (r *Registry) WithReadOnly(readOnly bool) *Registry { - newTG := r.copy() - newTG.readOnly = readOnly - return newTG -} - -// WithToolsets returns a new Registry that only includes items from the specified toolsets. -// Special keywords: -// - "all": enables all toolsets -// - "default": expands to toolsets marked with Default: true in their metadata -// -// Input strings are trimmed of whitespace and duplicates are removed. -// Toolset IDs that don't match any registered toolsets are tracked and can be -// retrieved via UnrecognizedToolsets() for warning purposes. -// -// Pass nil to use default toolsets. Pass an empty slice to disable all toolsets -// (useful for dynamic toolsets mode where tools are enabled on demand). -func (r *Registry) WithToolsets(toolsetIDs []string) *Registry { - newTG := r.copy() - newTG.unrecognizedToolsets = nil // reset for fresh calculation - - // Build a set of valid toolset IDs for validation - validIDs := make(map[ToolsetID]bool) - for _, t := range r.tools { - validIDs[t.Toolset.ID] = true - } - for _, r := range r.resourceTemplates { - validIDs[r.Toolset.ID] = true - } - for _, p := range r.prompts { - validIDs[p.Toolset.ID] = true - } - - // Check for "all" keyword - enables all toolsets - for _, id := range toolsetIDs { - if strings.TrimSpace(id) == "all" { - newTG.enabledToolsets = nil - return newTG - } - } - - // nil means use defaults, empty slice means no toolsets - if toolsetIDs == nil { - toolsetIDs = []string{"default"} - } - - // Expand "default" keyword, trim whitespace, collect other IDs, and track unrecognized - seen := make(map[ToolsetID]bool) - expanded := make([]ToolsetID, 0, len(toolsetIDs)) - var unrecognized []string - - for _, id := range toolsetIDs { - trimmed := strings.TrimSpace(id) - if trimmed == "" { - continue - } - if trimmed == "default" { - for _, defaultID := range r.DefaultToolsetIDs() { - if !seen[defaultID] { - seen[defaultID] = true - expanded = append(expanded, defaultID) - } - } - } else { - tsID := ToolsetID(trimmed) - if !seen[tsID] { - seen[tsID] = true - expanded = append(expanded, tsID) - // Track if this toolset doesn't exist - if !validIDs[tsID] { - unrecognized = append(unrecognized, trimmed) - } - } - } - } - - newTG.unrecognizedToolsets = unrecognized - - if len(expanded) == 0 { - newTG.enabledToolsets = make(map[ToolsetID]bool) - return newTG - } - - newTG.enabledToolsets = make(map[ToolsetID]bool, len(expanded)) - for _, id := range expanded { - newTG.enabledToolsets[id] = true - } - return newTG -} - -// UnrecognizedToolsets returns toolset IDs that were passed to WithToolsets but don't -// match any registered toolsets. This is useful for warning users about typos. -func (r *Registry) UnrecognizedToolsets() []string { - return r.unrecognizedToolsets -} - -// WithTools returns a new Registry with additional tools that bypass toolset filtering. -// These tools are additive - they will be included even if their toolset is not enabled. -// Read-only filtering still applies to these tools. -// Deprecated tool aliases are automatically resolved to their canonical names. -// Pass nil or empty slice to clear additional tools. -func (r *Registry) WithTools(toolNames []string) *Registry { - newTG := r.copy() - if len(toolNames) == 0 { - newTG.additionalTools = nil - return newTG - } - newTG.additionalTools = make(map[string]bool, len(toolNames)) - for _, name := range toolNames { - // Resolve deprecated aliases to canonical names - if canonical, isAlias := r.deprecatedAliases[name]; isAlias { - newTG.additionalTools[canonical] = true - } else { - newTG.additionalTools[name] = true - } - } - return newTG -} - -// WithFeatureChecker returns a new Registry with a feature checker function. -// The checker receives a context (for actor extraction) and feature flag name, returns (enabled, error). -// If error occurs, it will be logged and treated as false. -// If checker is nil, all feature flag checks return false (items with FeatureFlagEnable are excluded, -// items with FeatureFlagDisable are included). -func (r *Registry) WithFeatureChecker(checker FeatureFlagChecker) *Registry { - newTG := r.copy() - newTG.featureChecker = checker - return newTG -} - -// MCP method constants for use with ForMCPRequest. -const ( - MCPMethodInitialize = "initialize" - MCPMethodToolsList = "tools/list" - MCPMethodToolsCall = "tools/call" - MCPMethodResourcesList = "resources/list" - MCPMethodResourcesRead = "resources/read" - MCPMethodResourcesTemplatesList = "resources/templates/list" - MCPMethodPromptsList = "prompts/list" - MCPMethodPromptsGet = "prompts/get" -) - -// ForMCPRequest returns a Registry optimized for a specific MCP request. -// This is designed for servers that create a new instance per request (like the remote server), -// allowing them to only register the items needed for that specific request rather than all ~90 tools. -// -// Parameters: -// - method: The MCP method being called (use MCP* constants) -// - itemName: Name of specific item for call/get methods (tool name, resource URI, or prompt name) -// -// Returns a new Registry containing only the items relevant to the request: -// - MCPMethodInitialize: Empty (capabilities are set via ServerOptions, not registration) -// - MCPMethodToolsList: All available tools (no resources/prompts) -// - MCPMethodToolsCall: Only the named tool -// - MCPMethodResourcesList, MCPMethodResourcesTemplatesList: All available resources (no tools/prompts) -// - MCPMethodResourcesRead: Only the named resource template -// - MCPMethodPromptsList: All available prompts (no tools/resources) -// - MCPMethodPromptsGet: Only the named prompt -// - Unknown methods: Empty (no items registered) -// -// All existing filters (read-only, toolsets, etc.) still apply to the returned items. -func (r *Registry) ForMCPRequest(method string, itemName string) *Registry { - result := r.copy() - - // Helper to clear all item types - clearAll := func() { - result.tools = []ServerTool{} - result.resourceTemplates = []ServerResourceTemplate{} - result.prompts = []ServerPrompt{} - } - - switch method { - case MCPMethodInitialize: - clearAll() - case MCPMethodToolsList: - result.resourceTemplates, result.prompts = nil, nil - case MCPMethodToolsCall: - result.resourceTemplates, result.prompts = nil, nil - if itemName != "" { - result.tools = r.filterToolsByName(itemName) - } - case MCPMethodResourcesList, MCPMethodResourcesTemplatesList: - result.tools, result.prompts = nil, nil - case MCPMethodResourcesRead: - result.tools, result.prompts = nil, nil - if itemName != "" { - result.resourceTemplates = r.filterResourcesByURI(itemName) - } - case MCPMethodPromptsList: - result.tools, result.resourceTemplates = nil, nil - case MCPMethodPromptsGet: - result.tools, result.resourceTemplates = nil, nil - if itemName != "" { - result.prompts = r.filterPromptsByName(itemName) - } - default: - clearAll() - } - - return result -} - -// filterToolsByName returns tools matching the given name, checking deprecated aliases. -// Returns from the current tools slice (respects existing filter chain). -func (r *Registry) filterToolsByName(name string) []ServerTool { - // First check for exact match - for i := range r.tools { - if r.tools[i].Tool.Name == name { - return []ServerTool{r.tools[i]} - } - } - // Check if name is a deprecated alias - if canonical, isAlias := r.deprecatedAliases[name]; isAlias { - for i := range r.tools { - if r.tools[i].Tool.Name == canonical { - return []ServerTool{r.tools[i]} - } - } - } - return []ServerTool{} -} - -// filterResourcesByURI returns resource templates matching the given URI pattern. -func (r *Registry) filterResourcesByURI(uri string) []ServerResourceTemplate { - for i := range r.resourceTemplates { - // Check if URI matches the template pattern (exact match on URITemplate string) - if r.resourceTemplates[i].Template.URITemplate == uri { - return []ServerResourceTemplate{r.resourceTemplates[i]} - } - } - return []ServerResourceTemplate{} -} - -// filterPromptsByName returns prompts matching the given name. -func (r *Registry) filterPromptsByName(name string) []ServerPrompt { - for i := range r.prompts { - if r.prompts[i].Prompt.Name == name { - return []ServerPrompt{r.prompts[i]} - } - } - return []ServerPrompt{} -} - -// WithDeprecatedToolAliases returns a new Registry with the given deprecated aliases added. -// Aliases map old tool names to new canonical names. -func (r *Registry) WithDeprecatedToolAliases(aliases map[string]string) *Registry { - newTG := r.copy() - // Ensure we have a fresh map - newTG.deprecatedAliases = make(map[string]string, len(r.deprecatedAliases)+len(aliases)) - for k, v := range r.deprecatedAliases { - newTG.deprecatedAliases[k] = v - } - for oldName, newName := range aliases { - newTG.deprecatedAliases[oldName] = newName - } - return newTG -} - -// isToolsetEnabled checks if a toolset is enabled based on current filters. -func (r *Registry) isToolsetEnabled(toolsetID ToolsetID) bool { - // Check enabled toolsets filter - if r.enabledToolsets != nil { - return r.enabledToolsets[toolsetID] - } - return true -} - -// checkFeatureFlag checks a feature flag using the feature checker. -// Returns false if checker is nil or returns an error (errors are logged). -func (r *Registry) checkFeatureFlag(ctx context.Context, flagName string) bool { - if r.featureChecker == nil || flagName == "" { - return false - } - enabled, err := r.featureChecker(ctx, flagName) - if err != nil { - fmt.Fprintf(os.Stderr, "Feature flag check error for %q: %v\n", flagName, err) - return false - } - return enabled -} - -// isFeatureFlagAllowed checks if an item passes feature flag filtering. -// - If FeatureFlagEnable is set, the item is only allowed if the flag is enabled -// - If FeatureFlagDisable is set, the item is excluded if the flag is enabled -func (r *Registry) isFeatureFlagAllowed(ctx context.Context, enableFlag, disableFlag string) bool { - // Check enable flag - item requires this flag to be on - if enableFlag != "" && !r.checkFeatureFlag(ctx, enableFlag) { - return false - } - // Check disable flag - item is excluded if this flag is on - if disableFlag != "" && r.checkFeatureFlag(ctx, disableFlag) { - return false - } - return true -} - -// isToolEnabled checks if a specific tool is enabled based on current filters. -func (r *Registry) isToolEnabled(ctx context.Context, tool *ServerTool) bool { - // Check read-only filter first (applies to all tools) - if r.readOnly && !tool.IsReadOnly() { - return false - } - // Check feature flags - if !r.isFeatureFlagAllowed(ctx, tool.FeatureFlagEnable, tool.FeatureFlagDisable) { - return false - } - // Check if tool is in additionalTools (bypasses toolset filter) - if r.additionalTools != nil && r.additionalTools[tool.Tool.Name] { - return true - } - // Check toolset filter - if !r.isToolsetEnabled(tool.Toolset.ID) { - return false - } - return true -} - -// AvailableTools returns the tools that pass all current filters, -// sorted deterministically by toolset ID, then tool name. -// The context is used for feature flag evaluation. -func (r *Registry) AvailableTools(ctx context.Context) []ServerTool { - var result []ServerTool - for i := range r.tools { - tool := &r.tools[i] - if r.isToolEnabled(ctx, tool) { - result = append(result, *tool) - } - } - - // Sort deterministically: by toolset ID, then by tool name - sort.Slice(result, func(i, j int) bool { - if result[i].Toolset.ID != result[j].Toolset.ID { - return result[i].Toolset.ID < result[j].Toolset.ID - } - return result[i].Tool.Name < result[j].Tool.Name - }) - - return result -} - -// AvailableResourceTemplates returns resource templates that pass all current filters, -// sorted deterministically by toolset ID, then template name. -// The context is used for feature flag evaluation. -func (r *Registry) AvailableResourceTemplates(ctx context.Context) []ServerResourceTemplate { - var result []ServerResourceTemplate - for i := range r.resourceTemplates { - res := &r.resourceTemplates[i] - // Check feature flags - if !r.isFeatureFlagAllowed(ctx, res.FeatureFlagEnable, res.FeatureFlagDisable) { - continue - } - if r.isToolsetEnabled(res.Toolset.ID) { - result = append(result, *res) - } - } - - // Sort deterministically: by toolset ID, then by template name - sort.Slice(result, func(i, j int) bool { - if result[i].Toolset.ID != result[j].Toolset.ID { - return result[i].Toolset.ID < result[j].Toolset.ID - } - return result[i].Template.Name < result[j].Template.Name - }) - - return result -} - -// AvailablePrompts returns prompts that pass all current filters, -// sorted deterministically by toolset ID, then prompt name. -// The context is used for feature flag evaluation. -func (r *Registry) AvailablePrompts(ctx context.Context) []ServerPrompt { - var result []ServerPrompt - for i := range r.prompts { - prompt := &r.prompts[i] - // Check feature flags - if !r.isFeatureFlagAllowed(ctx, prompt.FeatureFlagEnable, prompt.FeatureFlagDisable) { - continue - } - if r.isToolsetEnabled(prompt.Toolset.ID) { - result = append(result, *prompt) - } - } - - // Sort deterministically: by toolset ID, then by prompt name - sort.Slice(result, func(i, j int) bool { - if result[i].Toolset.ID != result[j].Toolset.ID { - return result[i].Toolset.ID < result[j].Toolset.ID - } - return result[i].Prompt.Name < result[j].Prompt.Name - }) - - return result -} - -// ToolsetIDs returns a sorted list of unique toolset IDs from all tools in this group. -func (r *Registry) ToolsetIDs() []ToolsetID { - seen := make(map[ToolsetID]bool) - for i := range r.tools { - seen[r.tools[i].Toolset.ID] = true - } - for i := range r.resourceTemplates { - seen[r.resourceTemplates[i].Toolset.ID] = true - } - for i := range r.prompts { - seen[r.prompts[i].Toolset.ID] = true - } - - ids := make([]ToolsetID, 0, len(seen)) - for id := range seen { - ids = append(ids, id) - } - sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] }) - return ids -} - -// DefaultToolsetIDs returns the IDs of toolsets marked as Default in their metadata. -// The IDs are returned in sorted order for deterministic output. -func (r *Registry) DefaultToolsetIDs() []ToolsetID { - seen := make(map[ToolsetID]bool) - for i := range r.tools { - if r.tools[i].Toolset.Default { - seen[r.tools[i].Toolset.ID] = true - } - } - for i := range r.resourceTemplates { - if r.resourceTemplates[i].Toolset.Default { - seen[r.resourceTemplates[i].Toolset.ID] = true - } - } - for i := range r.prompts { - if r.prompts[i].Toolset.Default { - seen[r.prompts[i].Toolset.ID] = true - } - } - - ids := make([]ToolsetID, 0, len(seen)) - for id := range seen { - ids = append(ids, id) - } - sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] }) - return ids -} - -// ToolsetDescriptions returns a map of toolset ID to description for all toolsets. -func (r *Registry) ToolsetDescriptions() map[ToolsetID]string { - descriptions := make(map[ToolsetID]string) - for i := range r.tools { - t := &r.tools[i] - if t.Toolset.Description != "" { - descriptions[t.Toolset.ID] = t.Toolset.Description - } - } - for i := range r.resourceTemplates { - r := &r.resourceTemplates[i] - if r.Toolset.Description != "" { - descriptions[r.Toolset.ID] = r.Toolset.Description - } - } - for i := range r.prompts { - p := &r.prompts[i] - if p.Toolset.Description != "" { - descriptions[p.Toolset.ID] = p.Toolset.Description - } - } - return descriptions -} - -// ToolsForToolset returns all tools belonging to a specific toolset. -// This method bypasses the toolset enabled filter (for dynamic toolset registration), -// but still respects the read-only filter. -func (r *Registry) ToolsForToolset(toolsetID ToolsetID) []ServerTool { - var result []ServerTool - for i := range r.tools { - tool := &r.tools[i] - // Only check read-only filter, not toolset enabled filter - if tool.Toolset.ID == toolsetID { - if r.readOnly && !tool.IsReadOnly() { - continue - } - result = append(result, *tool) - } - } - - // Sort by tool name for deterministic order - sort.Slice(result, func(i, j int) bool { - return result[i].Tool.Name < result[j].Tool.Name - }) - - return result -} - -// RegisterTools registers all available tools with the server using the provided dependencies. -// The context is used for feature flag evaluation. -func (r *Registry) RegisterTools(ctx context.Context, s *mcp.Server, deps any) { - for _, tool := range r.AvailableTools(ctx) { - tool.RegisterFunc(s, deps) - } -} - -// RegisterResourceTemplates registers all available resource templates with the server. -// The context is used for feature flag evaluation. -func (r *Registry) RegisterResourceTemplates(ctx context.Context, s *mcp.Server, deps any) { - for _, res := range r.AvailableResourceTemplates(ctx) { - s.AddResourceTemplate(&res.Template, res.Handler(deps)) - } -} - -// RegisterPrompts registers all available prompts with the server. -// The context is used for feature flag evaluation. -func (r *Registry) RegisterPrompts(ctx context.Context, s *mcp.Server) { - for _, prompt := range r.AvailablePrompts(ctx) { - s.AddPrompt(&prompt.Prompt, prompt.Handler) - } -} - -// RegisterAll registers all available tools, resources, and prompts with the server. -// The context is used for feature flag evaluation. -func (r *Registry) RegisterAll(ctx context.Context, s *mcp.Server, deps any) { - r.RegisterTools(ctx, s, deps) - r.RegisterResourceTemplates(ctx, s, deps) - r.RegisterPrompts(ctx, s) -} - -// ResolveToolAliases resolves deprecated tool aliases to their canonical names. -// It logs a warning to stderr for each deprecated alias that is resolved. -// Returns: -// - resolved: tool names with aliases replaced by canonical names -// - aliasesUsed: map of oldName → newName for each alias that was resolved -func (r *Registry) ResolveToolAliases(toolNames []string) (resolved []string, aliasesUsed map[string]string) { - resolved = make([]string, 0, len(toolNames)) - aliasesUsed = make(map[string]string) - for _, toolName := range toolNames { - if canonicalName, isAlias := r.deprecatedAliases[toolName]; isAlias { - fmt.Fprintf(os.Stderr, "Warning: tool %q is deprecated, use %q instead\n", toolName, canonicalName) - aliasesUsed[toolName] = canonicalName - resolved = append(resolved, canonicalName) - } else { - resolved = append(resolved, toolName) - } - } - return resolved, aliasesUsed -} - -// FindToolByName searches all tools for one matching the given name. -// Returns the tool, its toolset ID, and an error if not found. -// This searches ALL tools regardless of filters. -func (r *Registry) FindToolByName(toolName string) (*ServerTool, ToolsetID, error) { - for i := range r.tools { - tool := &r.tools[i] - if tool.Tool.Name == toolName { - return tool, tool.Toolset.ID, nil - } - } - return nil, "", NewToolDoesNotExistError(toolName) -} - -// HasToolset checks if any tool/resource/prompt belongs to the given toolset. -func (r *Registry) HasToolset(toolsetID ToolsetID) bool { - for i := range r.tools { - if r.tools[i].Toolset.ID == toolsetID { - return true - } - } - for i := range r.resourceTemplates { - if r.resourceTemplates[i].Toolset.ID == toolsetID { - return true - } - } - for i := range r.prompts { - if r.prompts[i].Toolset.ID == toolsetID { - return true - } - } - return false -} - -// EnabledToolsetIDs returns the list of enabled toolset IDs based on current filters. -// Returns all toolset IDs if no filter is set. -func (r *Registry) EnabledToolsetIDs() []ToolsetID { - if r.enabledToolsets == nil { - return r.ToolsetIDs() - } - - ids := make([]ToolsetID, 0, len(r.enabledToolsets)) - for id := range r.enabledToolsets { - if r.HasToolset(id) { - ids = append(ids, id) - } - } - sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] }) - return ids -} - -// IsToolsetEnabled checks if a toolset is currently enabled based on filters. -func (r *Registry) IsToolsetEnabled(toolsetID ToolsetID) bool { - return r.isToolsetEnabled(toolsetID) -} - -// EnableToolset marks a toolset as enabled in this group. -// This is used by dynamic toolset management to track which toolsets have been enabled. -func (r *Registry) EnableToolset(toolsetID ToolsetID) { - if r.enabledToolsets == nil { - // nil means all enabled, so nothing to do - return - } - r.enabledToolsets[toolsetID] = true -} - -// AllTools returns all tools without any filtering, sorted deterministically. -func (r *Registry) AllTools() []ServerTool { - result := slices.Clone(r.tools) - - // Sort deterministically: by toolset ID, then by tool name - sort.Slice(result, func(i, j int) bool { - if result[i].Toolset.ID != result[j].Toolset.ID { - return result[i].Toolset.ID < result[j].Toolset.ID - } - return result[i].Tool.Name < result[j].Tool.Name - }) - - return result -} - -// AvailableToolsets returns the unique toolsets that have tools, in sorted order. -// This is the ordered intersection of toolsets with reality - only toolsets that -// actually contain tools are returned, sorted by toolset ID. -// Optional exclude parameter filters out specific toolset IDs from the result. -func (r *Registry) AvailableToolsets(exclude ...ToolsetID) []ToolsetMetadata { - tools := r.AllTools() - if len(tools) == 0 { - return nil - } - - // Build exclude set for O(1) lookup - excludeSet := make(map[ToolsetID]bool, len(exclude)) - for _, id := range exclude { - excludeSet[id] = true - } - - var result []ToolsetMetadata - var lastID ToolsetID - for _, tool := range tools { - if tool.Toolset.ID != lastID { - lastID = tool.Toolset.ID - if !excludeSet[lastID] { - result = append(result, tool.Toolset) - } - } - } - return result -} diff --git a/script/conformance-test b/script/conformance-test new file mode 100755 index 000000000..3ff0a55c2 --- /dev/null +++ b/script/conformance-test @@ -0,0 +1,432 @@ +#!/bin/bash +set -e + +# Conformance test script for comparing MCP server behavior between branches +# Builds both main and current branch, runs various flag combinations, +# and produces a conformance report with timing and diffs. +# +# Output: +# - Progress/status messages go to stderr (for visibility in CI) +# - Final report summary goes to stdout (for piping/capture) + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PROJECT_DIR="$(dirname "$SCRIPT_DIR")" +REPORT_DIR="$PROJECT_DIR/conformance-report" +CURRENT_BRANCH=$(git rev-parse --abbrev-ref HEAD) + +# Colors for output (only used on stderr) +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +# Helper to print to stderr +log() { + echo -e "$@" >&2 +} + +log "${BLUE}=== MCP Server Conformance Test ===${NC}" +log "Current branch: $CURRENT_BRANCH" +log "Report directory: $REPORT_DIR" + +# Find the common ancestor +MERGE_BASE=$(git merge-base HEAD origin/main) +log "Comparing against merge-base: $MERGE_BASE" +log "" + +# Create report directory +rm -rf "$REPORT_DIR" +mkdir -p "$REPORT_DIR"/{main,branch,diffs} + +# Build binaries +log "${YELLOW}Building binaries...${NC}" + +log "Building current branch ($CURRENT_BRANCH)..." +go build -o "$REPORT_DIR/branch/github-mcp-server" ./cmd/github-mcp-server +BRANCH_BUILD_OK=$? + +log "Building main branch (using temp worktree at merge-base)..." +TEMP_WORKTREE=$(mktemp -d) +git worktree add --quiet "$TEMP_WORKTREE" "$MERGE_BASE" +(cd "$TEMP_WORKTREE" && go build -o "$REPORT_DIR/main/github-mcp-server" ./cmd/github-mcp-server) +MAIN_BUILD_OK=$? +git worktree remove --force "$TEMP_WORKTREE" + +if [ $BRANCH_BUILD_OK -ne 0 ] || [ $MAIN_BUILD_OK -ne 0 ]; then + log "${RED}Build failed!${NC}" + exit 1 +fi + +log "${GREEN}Both binaries built successfully${NC}" +log "" + +# MCP JSON-RPC messages +INIT_MSG='{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"conformance-test","version":"1.0.0"}}}' +INITIALIZED_MSG='{"jsonrpc":"2.0","method":"notifications/initialized","params":{}}' +LIST_TOOLS_MSG='{"jsonrpc":"2.0","id":2,"method":"tools/list","params":{}}' +LIST_RESOURCES_MSG='{"jsonrpc":"2.0","id":3,"method":"resources/listTemplates","params":{}}' +LIST_PROMPTS_MSG='{"jsonrpc":"2.0","id":4,"method":"prompts/list","params":{}}' + +# Dynamic toolset management tool calls (for dynamic mode testing) +LIST_TOOLSETS_MSG='{"jsonrpc":"2.0","id":10,"method":"tools/call","params":{"name":"list_available_toolsets","arguments":{}}}' +GET_TOOLSET_TOOLS_MSG='{"jsonrpc":"2.0","id":11,"method":"tools/call","params":{"name":"get_toolset_tools","arguments":{"toolset":"repos"}}}' +ENABLE_TOOLSET_MSG='{"jsonrpc":"2.0","id":12,"method":"tools/call","params":{"name":"enable_toolset","arguments":{"toolset":"repos"}}}' +LIST_TOOLSETS_AFTER_MSG='{"jsonrpc":"2.0","id":13,"method":"tools/call","params":{"name":"list_available_toolsets","arguments":{}}}' + +# Function to normalize JSON for comparison +# Sorts all arrays (including nested ones) and formats consistently +# Also handles embedded JSON strings in "text" fields (from tool call responses) +normalize_json() { + local file="$1" + if [ -s "$file" ]; then + # First, try to parse and re-serialize any JSON embedded in text fields + # This handles tool call responses where the result is JSON-in-a-string + jq -S ' + # Function to sort arrays recursively + def deep_sort: + if type == "array" then + [.[] | deep_sort] | sort_by(tostring) + elif type == "object" then + to_entries | map(.value |= deep_sort) | from_entries + else + . + end; + + # Walk the structure, and for any "text" field that looks like JSON array/object, parse and sort it + walk( + if type == "object" and .text and (.text | type == "string") and ((.text | startswith("[")) or (.text | startswith("{"))) then + .text = ((.text | fromjson | deep_sort) | tojson) + else + . + end + ) | deep_sort + ' "$file" 2>/dev/null > "${file}.tmp" && mv "${file}.tmp" "$file" + fi +} + +# Function to run MCP server and capture output with timing +run_mcp_test() { + local binary="$1" + local name="$2" + local flags="$3" + local output_prefix="$4" + + local start_time end_time duration + start_time=$(date +%s.%N) + + # Run the server with all list commands - each response is on its own line + output=$( + ( + echo "$INIT_MSG" + echo "$INITIALIZED_MSG" + echo "$LIST_TOOLS_MSG" + echo "$LIST_RESOURCES_MSG" + echo "$LIST_PROMPTS_MSG" + sleep 0.5 + ) | GITHUB_PERSONAL_ACCESS_TOKEN=1 $binary stdio $flags 2>/dev/null + ) + + end_time=$(date +%s.%N) + duration=$(echo "$end_time - $start_time" | bc) + + # Parse and save each response by matching JSON-RPC id + # Each line is a separate JSON response + echo "$output" | while IFS= read -r line; do + id=$(echo "$line" | jq -r '.id // empty' 2>/dev/null) + case "$id" in + 1) echo "$line" | jq -S '.' > "${output_prefix}_initialize.json" 2>/dev/null ;; + 2) echo "$line" | jq -S '.' > "${output_prefix}_tools.json" 2>/dev/null ;; + 3) echo "$line" | jq -S '.' > "${output_prefix}_resources.json" 2>/dev/null ;; + 4) echo "$line" | jq -S '.' > "${output_prefix}_prompts.json" 2>/dev/null ;; + esac + done + + # Create empty files if not created (in case of errors or missing responses) + touch "${output_prefix}_initialize.json" "${output_prefix}_tools.json" \ + "${output_prefix}_resources.json" "${output_prefix}_prompts.json" + + # Normalize all JSON files for consistent comparison (sorts arrays, keys) + for endpoint in initialize tools resources prompts; do + normalize_json "${output_prefix}_${endpoint}.json" + done + + echo "$duration" +} + +# Function to run MCP server with dynamic tool calls (for dynamic mode testing) +run_mcp_dynamic_test() { + local binary="$1" + local name="$2" + local flags="$3" + local output_prefix="$4" + + local start_time end_time duration + start_time=$(date +%s.%N) + + # Run the server with dynamic tool calls in sequence: + # 1. Initialize + # 2. List available toolsets (before enable) + # 3. Get tools for repos toolset + # 4. Enable repos toolset + # 5. List available toolsets (after enable - should show repos as enabled) + output=$( + ( + echo "$INIT_MSG" + echo "$INITIALIZED_MSG" + echo "$LIST_TOOLSETS_MSG" + sleep 0.1 + echo "$GET_TOOLSET_TOOLS_MSG" + sleep 0.1 + echo "$ENABLE_TOOLSET_MSG" + sleep 0.1 + echo "$LIST_TOOLSETS_AFTER_MSG" + sleep 0.3 + ) | GITHUB_PERSONAL_ACCESS_TOKEN=1 $binary stdio $flags 2>/dev/null + ) + + end_time=$(date +%s.%N) + duration=$(echo "$end_time - $start_time" | bc) + + # Parse and save each response by matching JSON-RPC id + echo "$output" | while IFS= read -r line; do + id=$(echo "$line" | jq -r '.id // empty' 2>/dev/null) + case "$id" in + 1) echo "$line" | jq -S '.' > "${output_prefix}_initialize.json" 2>/dev/null ;; + 10) echo "$line" | jq -S '.' > "${output_prefix}_list_toolsets_before.json" 2>/dev/null ;; + 11) echo "$line" | jq -S '.' > "${output_prefix}_get_toolset_tools.json" 2>/dev/null ;; + 12) echo "$line" | jq -S '.' > "${output_prefix}_enable_toolset.json" 2>/dev/null ;; + 13) echo "$line" | jq -S '.' > "${output_prefix}_list_toolsets_after.json" 2>/dev/null ;; + esac + done + + # Create empty files if not created + touch "${output_prefix}_initialize.json" "${output_prefix}_list_toolsets_before.json" \ + "${output_prefix}_get_toolset_tools.json" "${output_prefix}_enable_toolset.json" \ + "${output_prefix}_list_toolsets_after.json" + + # Normalize all JSON files + for endpoint in initialize list_toolsets_before get_toolset_tools enable_toolset list_toolsets_after; do + normalize_json "${output_prefix}_${endpoint}.json" + done + + echo "$duration" +} + +# Test configurations - array of "name|flags|type" +# type can be "standard" or "dynamic" (for dynamic tool call testing) +declare -a TEST_CONFIGS=( + "default||standard" + "read-only|--read-only|standard" + "dynamic-toolsets|--dynamic-toolsets|standard" + "read-only+dynamic|--read-only --dynamic-toolsets|standard" + "toolsets-repos|--toolsets=repos|standard" + "toolsets-issues|--toolsets=issues|standard" + "toolsets-pull_requests|--toolsets=pull_requests|standard" + "toolsets-repos,issues|--toolsets=repos,issues|standard" + "toolsets-all|--toolsets=all|standard" + "tools-get_me|--tools=get_me|standard" + "tools-get_me,list_issues|--tools=get_me,list_issues|standard" + "toolsets-repos+read-only|--toolsets=repos --read-only|standard" + "toolsets-all+dynamic|--toolsets=all --dynamic-toolsets|standard" + "toolsets-repos+dynamic|--toolsets=repos --dynamic-toolsets|standard" + "toolsets-repos,issues+dynamic|--toolsets=repos,issues --dynamic-toolsets|standard" + "dynamic-tool-calls|--dynamic-toolsets|dynamic" +) + +# Summary arrays +declare -a TEST_NAMES +declare -a MAIN_TIMES +declare -a BRANCH_TIMES +declare -a DIFF_STATUS + +log "${YELLOW}Running conformance tests...${NC}" +log "" + +for config in "${TEST_CONFIGS[@]}"; do + IFS='|' read -r test_name flags test_type <<< "$config" + + log "${BLUE}Test: ${test_name}${NC}" + log " Flags: ${flags:-}" + log " Type: ${test_type}" + + # Create output directories + mkdir -p "$REPORT_DIR/main/$test_name" + mkdir -p "$REPORT_DIR/branch/$test_name" + mkdir -p "$REPORT_DIR/diffs/$test_name" + + if [ "$test_type" = "dynamic" ]; then + # Run dynamic tool call test + main_time=$(run_mcp_dynamic_test "$REPORT_DIR/main/github-mcp-server" "main" "$flags" "$REPORT_DIR/main/$test_name/output") + log " Main: ${main_time}s" + + branch_time=$(run_mcp_dynamic_test "$REPORT_DIR/branch/github-mcp-server" "branch" "$flags" "$REPORT_DIR/branch/$test_name/output") + log " Branch: ${branch_time}s" + + endpoints="initialize list_toolsets_before get_toolset_tools enable_toolset list_toolsets_after" + else + # Run standard test + main_time=$(run_mcp_test "$REPORT_DIR/main/github-mcp-server" "main" "$flags" "$REPORT_DIR/main/$test_name/output") + log " Main: ${main_time}s" + + branch_time=$(run_mcp_test "$REPORT_DIR/branch/github-mcp-server" "branch" "$flags" "$REPORT_DIR/branch/$test_name/output") + log " Branch: ${branch_time}s" + + endpoints="initialize tools resources prompts" + fi + + # Calculate time difference + time_diff=$(echo "$branch_time - $main_time" | bc) + if (( $(echo "$time_diff > 0" | bc -l) )); then + log " Δ Time: ${RED}+${time_diff}s (slower)${NC}" + else + log " Δ Time: ${GREEN}${time_diff}s (faster)${NC}" + fi + + # Generate diffs for each endpoint + has_diff=false + for endpoint in $endpoints; do + main_file="$REPORT_DIR/main/$test_name/output_${endpoint}.json" + branch_file="$REPORT_DIR/branch/$test_name/output_${endpoint}.json" + diff_file="$REPORT_DIR/diffs/$test_name/${endpoint}.diff" + + if ! diff -u "$main_file" "$branch_file" > "$diff_file" 2>/dev/null; then + has_diff=true + lines=$(wc -l < "$diff_file" | tr -d ' ') + log " ${YELLOW}${endpoint}: DIFF (${lines} lines)${NC}" + else + rm -f "$diff_file" # No diff, remove empty file + log " ${GREEN}${endpoint}: OK${NC}" + fi + done + + # Store results + TEST_NAMES+=("$test_name") + MAIN_TIMES+=("$main_time") + BRANCH_TIMES+=("$branch_time") + if [ "$has_diff" = true ]; then + DIFF_STATUS+=("DIFF") + else + DIFF_STATUS+=("OK") + fi + + log "" +done + +# Generate summary report +REPORT_FILE="$REPORT_DIR/CONFORMANCE_REPORT.md" + +cat > "$REPORT_FILE" << EOF +# MCP Server Conformance Report + +Generated: $(date) +Current Branch: $CURRENT_BRANCH +Compared Against: merge-base ($MERGE_BASE) + +## Summary + +| Test | Main Time | Branch Time | Δ Time | Status | +|------|-----------|-------------|--------|--------| +EOF + +total_main=0 +total_branch=0 +diff_count=0 +ok_count=0 + +for i in "${!TEST_NAMES[@]}"; do + name="${TEST_NAMES[$i]}" + main_t="${MAIN_TIMES[$i]}" + branch_t="${BRANCH_TIMES[$i]}" + status="${DIFF_STATUS[$i]}" + + delta=$(echo "$branch_t - $main_t" | bc) + if (( $(echo "$delta > 0" | bc -l) )); then + delta_str="+${delta}s" + else + delta_str="${delta}s" + fi + + if [ "$status" = "DIFF" ]; then + status_str="⚠️ DIFF" + ((diff_count++)) || true + else + status_str="✅ OK" + ((ok_count++)) || true + fi + + total_main=$(echo "$total_main + $main_t" | bc) + total_branch=$(echo "$total_branch + $branch_t" | bc) + + echo "| $name | ${main_t}s | ${branch_t}s | $delta_str | $status_str |" >> "$REPORT_FILE" +done + +total_delta=$(echo "$total_branch - $total_main" | bc) +if (( $(echo "$total_delta > 0" | bc -l) )); then + total_delta_str="+${total_delta}s" +else + total_delta_str="${total_delta}s" +fi + +cat >> "$REPORT_FILE" << EOF +| **TOTAL** | **${total_main}s** | **${total_branch}s** | **$total_delta_str** | **$ok_count OK / $diff_count DIFF** | + +## Statistics + +- **Tests Passed (no diff):** $ok_count +- **Tests with Differences:** $diff_count +- **Total Main Time:** ${total_main}s +- **Total Branch Time:** ${total_branch}s +- **Overall Time Delta:** $total_delta_str + +## Detailed Diffs + +EOF + +# Add diff details to report +for i in "${!TEST_NAMES[@]}"; do + name="${TEST_NAMES[$i]}" + status="${DIFF_STATUS[$i]}" + + if [ "$status" = "DIFF" ]; then + echo "### $name" >> "$REPORT_FILE" + echo "" >> "$REPORT_FILE" + + # Check all possible endpoints + for endpoint in initialize tools resources prompts list_toolsets_before get_toolset_tools enable_toolset list_toolsets_after; do + diff_file="$REPORT_DIR/diffs/$name/${endpoint}.diff" + if [ -f "$diff_file" ] && [ -s "$diff_file" ]; then + echo "#### ${endpoint}" >> "$REPORT_FILE" + echo '```diff' >> "$REPORT_FILE" + cat "$diff_file" >> "$REPORT_FILE" + echo '```' >> "$REPORT_FILE" + echo "" >> "$REPORT_FILE" + fi + done + fi +done + +log "${BLUE}=== Conformance Test Complete ===${NC}" +log "" +log "Report: ${GREEN}$REPORT_FILE${NC}" +log "" + +# Output summary to stdout (for CI capture) +echo "=== Conformance Test Summary ===" +echo "Tests passed: $ok_count" +echo "Tests with diffs: $diff_count" +echo "Total main time: ${total_main}s" +echo "Total branch time: ${total_branch}s" +echo "Time delta: $total_delta_str" + +if [ $diff_count -gt 0 ]; then + log "" + log "${YELLOW}⚠️ Some tests have differences. Review the diffs in:${NC}" + log " $REPORT_DIR/diffs/" + echo "" + echo "RESULT: DIFFERENCES FOUND" + # Don't exit with error - diffs may be intentional improvements +else + echo "" + echo "RESULT: ALL TESTS PASSED" +fi