diff --git a/pkg/github/context_tools.go b/pkg/github/context_tools.go index 3fe622379..e8043731a 100644 --- a/pkg/github/context_tools.go +++ b/pkg/github/context_tools.go @@ -6,6 +6,7 @@ import ( "time" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/jsonschema-go/jsonschema" @@ -36,8 +37,9 @@ type UserDetails struct { } // GetMe creates a tool to get details of the authenticated user. -func GetMe(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetMe(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "get_me", Description: t("TOOL_GET_ME_DESCRIPTION", "Get details of the authenticated GitHub user. Use this when a request is about the user's own profile for GitHub. Or when information is missing to build other tool calls."), Annotations: &mcp.ToolAnnotations{ @@ -48,50 +50,53 @@ func GetMe(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Too // OpenAI strict mode requires the properties field to be present. InputSchema: json.RawMessage(`{"type":"object","properties":{}}`), }, - mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, _ map[string]any) (*mcp.CallToolResult, any, error) { - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, _ map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - user, res, err := client.Users.Get(ctx, "") - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get user", - res, - err, - ), nil, err - } + user, res, err := client.Users.Get(ctx, "") + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get user", + res, + err, + ), nil, nil + } - // Create minimal user representation instead of returning full user object - minimalUser := MinimalUser{ - Login: user.GetLogin(), - ID: user.GetID(), - ProfileURL: user.GetHTMLURL(), - AvatarURL: user.GetAvatarURL(), - Details: &UserDetails{ - Name: user.GetName(), - Company: user.GetCompany(), - Blog: user.GetBlog(), - Location: user.GetLocation(), - Email: user.GetEmail(), - Hireable: user.GetHireable(), - Bio: user.GetBio(), - TwitterUsername: user.GetTwitterUsername(), - PublicRepos: user.GetPublicRepos(), - PublicGists: user.GetPublicGists(), - Followers: user.GetFollowers(), - Following: user.GetFollowing(), - CreatedAt: user.GetCreatedAt().Time, - UpdatedAt: user.GetUpdatedAt().Time, - PrivateGists: user.GetPrivateGists(), - TotalPrivateRepos: user.GetTotalPrivateRepos(), - OwnedPrivateRepos: user.GetOwnedPrivateRepos(), - }, - } + // Create minimal user representation instead of returning full user object + minimalUser := MinimalUser{ + Login: user.GetLogin(), + ID: user.GetID(), + ProfileURL: user.GetHTMLURL(), + AvatarURL: user.GetAvatarURL(), + Details: &UserDetails{ + Name: user.GetName(), + Company: user.GetCompany(), + Blog: user.GetBlog(), + Location: user.GetLocation(), + Email: user.GetEmail(), + Hireable: user.GetHireable(), + Bio: user.GetBio(), + TwitterUsername: user.GetTwitterUsername(), + PublicRepos: user.GetPublicRepos(), + PublicGists: user.GetPublicGists(), + Followers: user.GetFollowers(), + Following: user.GetFollowing(), + CreatedAt: user.GetCreatedAt().Time, + UpdatedAt: user.GetUpdatedAt().Time, + PrivateGists: user.GetPrivateGists(), + TotalPrivateRepos: user.GetTotalPrivateRepos(), + OwnedPrivateRepos: user.GetOwnedPrivateRepos(), + }, + } - return MarshalledTextResult(minimalUser), nil, nil - }) + return MarshalledTextResult(minimalUser), nil, nil + } + }, + ) } type TeamInfo struct { @@ -105,8 +110,9 @@ type OrganizationTeams struct { Teams []TeamInfo `json:"teams"` } -func GetTeams(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetTeams(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "get_teams", Description: t("TOOL_GET_TEAMS_DESCRIPTION", "Get details of the teams the user is a member of. Limited to organizations accessible with current credentials"), Annotations: &mcp.ToolAnnotations{ @@ -123,84 +129,88 @@ func GetTeams(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations }, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - user, err := OptionalParam[string](args, "user") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - var username string - if user != "" { - username = user - } else { - client, err := getClient(ctx) + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + user, err := OptionalParam[string](args, "user") if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + return utils.NewToolResultError(err.Error()), nil, nil } - userResp, res, err := client.Users.Get(ctx, "") + var username string + if user != "" { + username = user + } else { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + + userResp, res, err := client.Users.Get(ctx, "") + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get user", + res, + err, + ), nil, nil + } + username = userResp.GetLogin() + } + + gqlClient, err := deps.GetGQLClient(ctx) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get user", - res, - err, - ), nil, nil + return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil } - username = userResp.GetLogin() - } - gqlClient, err := getGQLClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil - } + var q struct { + User struct { + Organizations struct { + Nodes []struct { + Login githubv4.String + Teams struct { + Nodes []struct { + Name githubv4.String + Slug githubv4.String + Description githubv4.String + } + } `graphql:"teams(first: 100, userLogins: [$login])"` + } + } `graphql:"organizations(first: 100)"` + } `graphql:"user(login: $login)"` + } + vars := map[string]interface{}{ + "login": githubv4.String(username), + } + if err := gqlClient.Query(ctx, &q, vars); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find teams", err), nil, nil + } - var q struct { - User struct { - Organizations struct { - Nodes []struct { - Login githubv4.String - Teams struct { - Nodes []struct { - Name githubv4.String - Slug githubv4.String - Description githubv4.String - } - } `graphql:"teams(first: 100, userLogins: [$login])"` - } - } `graphql:"organizations(first: 100)"` - } `graphql:"user(login: $login)"` - } - vars := map[string]interface{}{ - "login": githubv4.String(username), - } - if err := gqlClient.Query(ctx, &q, vars); err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find teams", err), nil, nil - } + var organizations []OrganizationTeams + for _, org := range q.User.Organizations.Nodes { + orgTeams := OrganizationTeams{ + Org: string(org.Login), + Teams: make([]TeamInfo, 0, len(org.Teams.Nodes)), + } - var organizations []OrganizationTeams - for _, org := range q.User.Organizations.Nodes { - orgTeams := OrganizationTeams{ - Org: string(org.Login), - Teams: make([]TeamInfo, 0, len(org.Teams.Nodes)), - } + for _, team := range org.Teams.Nodes { + orgTeams.Teams = append(orgTeams.Teams, TeamInfo{ + Name: string(team.Name), + Slug: string(team.Slug), + Description: string(team.Description), + }) + } - for _, team := range org.Teams.Nodes { - orgTeams.Teams = append(orgTeams.Teams, TeamInfo{ - Name: string(team.Name), - Slug: string(team.Slug), - Description: string(team.Description), - }) + organizations = append(organizations, orgTeams) } - organizations = append(organizations, orgTeams) + return MarshalledTextResult(organizations), nil, nil } - - return MarshalledTextResult(organizations), nil, nil - } + }, + ) } -func GetTeamMembers(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetTeamMembers(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "get_team_members", Description: t("TOOL_GET_TEAM_MEMBERS_DESCRIPTION", "Get member usernames of a specific team in an organization. Limited to organizations accessible with current credentials"), Annotations: &mcp.ToolAnnotations{ @@ -222,46 +232,49 @@ func GetTeamMembers(getGQLClient GetGQLClientFn, t translations.TranslationHelpe Required: []string{"org", "team_slug"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - org, err := RequiredParam[string](args, "org") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + org, err := RequiredParam[string](args, "org") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - teamSlug, err := RequiredParam[string](args, "team_slug") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + teamSlug, err := RequiredParam[string](args, "team_slug") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - gqlClient, err := getGQLClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil - } + gqlClient, err := deps.GetGQLClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil + } - var q struct { - Organization struct { - Team struct { - Members struct { - Nodes []struct { - Login githubv4.String - } - } `graphql:"members(first: 100)"` - } `graphql:"team(slug: $teamSlug)"` - } `graphql:"organization(login: $org)"` - } - vars := map[string]interface{}{ - "org": githubv4.String(org), - "teamSlug": githubv4.String(teamSlug), - } - if err := gqlClient.Query(ctx, &q, vars); err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to get team members", err), nil, nil - } + var q struct { + Organization struct { + Team struct { + Members struct { + Nodes []struct { + Login githubv4.String + } + } `graphql:"members(first: 100)"` + } `graphql:"team(slug: $teamSlug)"` + } `graphql:"organization(login: $org)"` + } + vars := map[string]interface{}{ + "org": githubv4.String(org), + "teamSlug": githubv4.String(teamSlug), + } + if err := gqlClient.Query(ctx, &q, vars); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to get team members", err), nil, nil + } - var members []string - for _, member := range q.Organization.Team.Members.Nodes { - members = append(members, string(member.Login)) - } + var members []string + for _, member := range q.Organization.Team.Members.Nodes { + members = append(members, string(member.Login)) + } - return MarshalledTextResult(members), nil, nil - } + return MarshalledTextResult(members), nil, nil + } + }, + ) } diff --git a/pkg/github/context_tools_test.go b/pkg/github/context_tools_test.go index 96e21c233..0e28aad49 100644 --- a/pkg/github/context_tools_test.go +++ b/pkg/github/context_tools_test.go @@ -20,7 +20,8 @@ import ( func Test_GetMe(t *testing.T) { t.Parallel() - tool, _ := GetMe(nil, translations.NullTranslationHelper) + serverTool := GetMe(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) // Verify some basic very important properties @@ -108,21 +109,28 @@ func Test_GetMe(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - _, handler := GetMe(tc.stubbedGetClientFn, translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: tc.stubbedGetClientFn, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, _ := handler(context.Background(), &request, tc.requestArgs) - textContent := getTextResult(t, result) + result, err := handler(context.Background(), &request) + require.NoError(t, err) if tc.expectToolError { - assert.True(t, result.IsError, "expected tool call result to be an error") - assert.Contains(t, textContent.Text, tc.expectedToolErrMsg) + require.True(t, result.IsError, "expected tool call result to be an error") + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedToolErrMsg) return } + require.False(t, result.IsError) + textContent := getTextResult(t, result) + // Unmarshal and verify the result var returnedUser MinimalUser - err := json.Unmarshal([]byte(textContent.Text), &returnedUser) + err = json.Unmarshal([]byte(textContent.Text), &returnedUser) require.NoError(t, err) // Verify minimal user details @@ -145,7 +153,8 @@ func Test_GetMe(t *testing.T) { func Test_GetTeams(t *testing.T) { t.Parallel() - tool, _ := GetTeams(nil, nil, translations.NullTranslationHelper) + serverTool := GetTeams(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "get_teams", tool.Name) @@ -331,19 +340,26 @@ func Test_GetTeams(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - _, handler := GetTeams(tc.stubbedGetClientFn, tc.stubbedGetGQLClientFn, translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: tc.stubbedGetClientFn, + GetGQLClient: tc.stubbedGetGQLClientFn, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) - textContent := getTextResult(t, result) if tc.expectToolError { - assert.True(t, result.IsError, "expected tool call result to be an error") - assert.Contains(t, textContent.Text, tc.expectedToolErrMsg) + require.True(t, result.IsError, "expected tool call result to be an error") + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedToolErrMsg) return } + require.False(t, result.IsError) + textContent := getTextResult(t, result) + var organizations []OrganizationTeams err = json.Unmarshal([]byte(textContent.Text), &organizations) require.NoError(t, err) @@ -372,7 +388,8 @@ func Test_GetTeams(t *testing.T) { func Test_GetTeamMembers(t *testing.T) { t.Parallel() - tool, _ := GetTeamMembers(nil, translations.NullTranslationHelper) + serverTool := GetTeamMembers(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "get_team_members", tool.Name) @@ -467,19 +484,25 @@ func Test_GetTeamMembers(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - _, handler := GetTeamMembers(tc.stubbedGetGQLClientFn, translations.NullTranslationHelper) + deps := ToolDependencies{ + GetGQLClient: tc.stubbedGetGQLClientFn, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) - textContent := getTextResult(t, result) if tc.expectToolError { - assert.True(t, result.IsError, "expected tool call result to be an error") - assert.Contains(t, textContent.Text, tc.expectedToolErrMsg) + require.True(t, result.IsError, "expected tool call result to be an error") + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedToolErrMsg) return } + require.False(t, result.IsError) + textContent := getTextResult(t, result) + var members []string err = json.Unmarshal([]byte(textContent.Text), &members) require.NoError(t, err) diff --git a/pkg/github/gists.go b/pkg/github/gists.go index b54553aac..baca42399 100644 --- a/pkg/github/gists.go +++ b/pkg/github/gists.go @@ -7,6 +7,7 @@ import ( "io" "net/http" + "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -15,346 +16,350 @@ import ( ) // ListGists creates a tool to list gists for a user -func ListGists(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "list_gists", - Description: t("TOOL_LIST_GISTS_DESCRIPTION", "List gists for a user"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_LIST_GISTS", "List Gists"), - ReadOnlyHint: true, - }, - InputSchema: WithPagination(&jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "username": { - Type: "string", - Description: "GitHub username (omit for authenticated user's gists)", - }, - "since": { - Type: "string", - Description: "Only gists updated after this time (ISO 8601 timestamp)", - }, +func ListGists(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ + Name: "list_gists", + Description: t("TOOL_LIST_GISTS_DESCRIPTION", "List gists for a user"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_LIST_GISTS", "List Gists"), + ReadOnlyHint: true, }, - }), - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - username, err := OptionalParam[string](args, "username") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - since, err := OptionalParam[string](args, "since") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - opts := &github.GistListOptions{ - ListOptions: github.ListOptions{ - Page: pagination.Page, - PerPage: pagination.PerPage, - }, - } - - // Parse since timestamp if provided - if since != "" { - sinceTime, err := parseISOTimestamp(since) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid since timestamp: %v", err)), nil, nil - } - opts.Since = sinceTime - } - - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - gists, resp, err := client.Gists.List(ctx, username, opts) - if err != nil { - return nil, nil, fmt.Errorf("failed to list gists: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + InputSchema: WithPagination(&jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "username": { + Type: "string", + Description: "GitHub username (omit for authenticated user's gists)", + }, + "since": { + Type: "string", + Description: "Only gists updated after this time (ISO 8601 timestamp)", + }, + }, + }), + }, + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + username, err := OptionalParam[string](args, "username") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + since, err := OptionalParam[string](args, "since") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + opts := &github.GistListOptions{ + ListOptions: github.ListOptions{ + Page: pagination.Page, + PerPage: pagination.PerPage, + }, + } + + // Parse since timestamp if provided + if since != "" { + sinceTime, err := parseISOTimestamp(since) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid since timestamp: %v", err)), nil, nil + } + opts.Since = sinceTime + } + + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + + gists, resp, err := client.Gists.List(ctx, username, opts) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to list gists", err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to list gists: %s", string(body))), nil, nil + } + + r, err := json.Marshal(gists) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + } + + return utils.NewToolResultText(string(r)), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to list gists: %s", string(body))), nil, nil - } - - r, err := json.Marshal(gists) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return utils.NewToolResultText(string(r)), nil, nil - }) - - return tool, handler + }, + ) } // GetGist creates a tool to get the content of a gist -func GetGist(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "get_gist", - Description: t("TOOL_GET_GIST_DESCRIPTION", "Get gist content of a particular gist, by gist ID"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_GET_GIST", "Get Gist Content"), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "gist_id": { - Type: "string", - Description: "The ID of the gist", +func GetGist(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ + Name: "get_gist", + Description: t("TOOL_GET_GIST_DESCRIPTION", "Get gist content of a particular gist, by gist ID"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_GET_GIST", "Get Gist Content"), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "gist_id": { + Type: "string", + Description: "The ID of the gist", + }, }, + Required: []string{"gist_id"}, }, - Required: []string{"gist_id"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - gistID, err := RequiredParam[string](args, "gist_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - gist, resp, err := client.Gists.Get(ctx, gistID) - if err != nil { - return nil, nil, fmt.Errorf("failed to get gist: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + gistID, err := RequiredParam[string](args, "gist_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + + gist, resp, err := client.Gists.Get(ctx, gistID) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get gist", err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to get gist: %s", string(body))), nil, nil + } + + r, err := json.Marshal(gist) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + } + + return utils.NewToolResultText(string(r)), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to get gist: %s", string(body))), nil, nil - } - - r, err := json.Marshal(gist) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return utils.NewToolResultText(string(r)), nil, nil - }) - - return tool, handler + }, + ) } // CreateGist creates a tool to create a new gist -func CreateGist(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "create_gist", - Description: t("TOOL_CREATE_GIST_DESCRIPTION", "Create a new gist"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_CREATE_GIST", "Create Gist"), - ReadOnlyHint: false, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "description": { - Type: "string", - Description: "Description of the gist", - }, - "filename": { - Type: "string", - Description: "Filename for simple single-file gist creation", - }, - "content": { - Type: "string", - Description: "Content for simple single-file gist creation", - }, - "public": { - Type: "boolean", - Description: "Whether the gist is public", - Default: json.RawMessage(`false`), +func CreateGist(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ + Name: "create_gist", + Description: t("TOOL_CREATE_GIST_DESCRIPTION", "Create a new gist"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_CREATE_GIST", "Create Gist"), + ReadOnlyHint: false, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "description": { + Type: "string", + Description: "Description of the gist", + }, + "filename": { + Type: "string", + Description: "Filename for simple single-file gist creation", + }, + "content": { + Type: "string", + Description: "Content for simple single-file gist creation", + }, + "public": { + Type: "boolean", + Description: "Whether the gist is public", + Default: json.RawMessage(`false`), + }, }, + Required: []string{"filename", "content"}, }, - Required: []string{"filename", "content"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - description, err := OptionalParam[string](args, "description") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - filename, err := RequiredParam[string](args, "filename") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - content, err := RequiredParam[string](args, "content") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - public, err := OptionalParam[bool](args, "public") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - files := make(map[github.GistFilename]github.GistFile) - files[github.GistFilename(filename)] = github.GistFile{ - Filename: github.Ptr(filename), - Content: github.Ptr(content), - } - - gist := &github.Gist{ - Files: files, - Public: github.Ptr(public), - Description: github.Ptr(description), - } - - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - createdGist, resp, err := client.Gists.Create(ctx, gist) - if err != nil { - return nil, nil, fmt.Errorf("failed to create gist: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusCreated { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + description, err := OptionalParam[string](args, "description") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + filename, err := RequiredParam[string](args, "filename") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + content, err := RequiredParam[string](args, "content") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + public, err := OptionalParam[bool](args, "public") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + files := make(map[github.GistFilename]github.GistFile) + files[github.GistFilename(filename)] = github.GistFile{ + Filename: github.Ptr(filename), + Content: github.Ptr(content), + } + + gist := &github.Gist{ + Files: files, + Public: github.Ptr(public), + Description: github.Ptr(description), + } + + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + + createdGist, resp, err := client.Gists.Create(ctx, gist) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to create gist", err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusCreated { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to create gist: %s", string(body))), nil, nil + } + + minimalResponse := MinimalResponse{ + ID: createdGist.GetID(), + URL: createdGist.GetHTMLURL(), + } + + r, err := json.Marshal(minimalResponse) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + } + + return utils.NewToolResultText(string(r)), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to create gist: %s", string(body))), nil, nil - } - - minimalResponse := MinimalResponse{ - ID: createdGist.GetID(), - URL: createdGist.GetHTMLURL(), - } - - r, err := json.Marshal(minimalResponse) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return utils.NewToolResultText(string(r)), nil, nil - }) - - return tool, handler + }, + ) } // UpdateGist creates a tool to edit an existing gist -func UpdateGist(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "update_gist", - Description: t("TOOL_UPDATE_GIST_DESCRIPTION", "Update an existing gist"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_UPDATE_GIST", "Update Gist"), - ReadOnlyHint: false, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "gist_id": { - Type: "string", - Description: "ID of the gist to update", - }, - "description": { - Type: "string", - Description: "Updated description of the gist", - }, - "filename": { - Type: "string", - Description: "Filename to update or create", - }, - "content": { - Type: "string", - Description: "Content for the file", +func UpdateGist(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ + Name: "update_gist", + Description: t("TOOL_UPDATE_GIST_DESCRIPTION", "Update an existing gist"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_UPDATE_GIST", "Update Gist"), + ReadOnlyHint: false, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "gist_id": { + Type: "string", + Description: "ID of the gist to update", + }, + "description": { + Type: "string", + Description: "Updated description of the gist", + }, + "filename": { + Type: "string", + Description: "Filename to update or create", + }, + "content": { + Type: "string", + Description: "Content for the file", + }, }, + Required: []string{"gist_id", "filename", "content"}, }, - Required: []string{"gist_id", "filename", "content"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - gistID, err := RequiredParam[string](args, "gist_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - description, err := OptionalParam[string](args, "description") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - filename, err := RequiredParam[string](args, "filename") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - content, err := RequiredParam[string](args, "content") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - files := make(map[github.GistFilename]github.GistFile) - files[github.GistFilename(filename)] = github.GistFile{ - Filename: github.Ptr(filename), - Content: github.Ptr(content), - } - - gist := &github.Gist{ - Files: files, - Description: github.Ptr(description), - } - - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - updatedGist, resp, err := client.Gists.Edit(ctx, gistID, gist) - if err != nil { - return nil, nil, fmt.Errorf("failed to update gist: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + gistID, err := RequiredParam[string](args, "gist_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + description, err := OptionalParam[string](args, "description") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + filename, err := RequiredParam[string](args, "filename") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + content, err := RequiredParam[string](args, "content") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + files := make(map[github.GistFilename]github.GistFile) + files[github.GistFilename(filename)] = github.GistFile{ + Filename: github.Ptr(filename), + Content: github.Ptr(content), + } + + gist := &github.Gist{ + Files: files, + Description: github.Ptr(description), + } + + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + + updatedGist, resp, err := client.Gists.Edit(ctx, gistID, gist) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to update gist", err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to update gist: %s", string(body))), nil, nil + } + + minimalResponse := MinimalResponse{ + ID: updatedGist.GetID(), + URL: updatedGist.GetHTMLURL(), + } + + r, err := json.Marshal(minimalResponse) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + } + + return utils.NewToolResultText(string(r)), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to update gist: %s", string(body))), nil, nil - } - - minimalResponse := MinimalResponse{ - ID: updatedGist.GetID(), - URL: updatedGist.GetHTMLURL(), - } - - r, err := json.Marshal(minimalResponse) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return utils.NewToolResultText(string(r)), nil, nil - }) - - return tool, handler + }, + ) } diff --git a/pkg/github/gists_test.go b/pkg/github/gists_test.go index f0f62f420..44b294eb6 100644 --- a/pkg/github/gists_test.go +++ b/pkg/github/gists_test.go @@ -18,8 +18,8 @@ import ( func Test_ListGists(t *testing.T) { // Verify tool definition - mockClient := github.NewClient(nil) - tool, _ := ListGists(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ListGists(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) @@ -158,28 +158,27 @@ func Test_ListGists(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ListGists(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) + require.NoError(t, err) // Verify results if tc.expectError { - if err != nil { - assert.Contains(t, err.Error(), tc.expectedErrMsg) - } else { - // For errors returned as part of the result, not as an error - assert.NotNil(t, result) - textContent := getTextResult(t, result) - assert.Contains(t, textContent.Text, tc.expectedErrMsg) - } + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } - require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -202,8 +201,8 @@ func Test_ListGists(t *testing.T) { func Test_GetGist(t *testing.T) { // Verify tool definition - mockClient := github.NewClient(nil) - tool, _ := GetGist(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := GetGist(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) @@ -276,28 +275,27 @@ func Test_GetGist(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetGist(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) + require.NoError(t, err) // Verify results if tc.expectError { - if err != nil { - assert.Contains(t, err.Error(), tc.expectedErrMsg) - } else { - // For errors returned as part of the result, not as an error - assert.NotNil(t, result) - textContent := getTextResult(t, result) - assert.Contains(t, textContent.Text, tc.expectedErrMsg) - } + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } - require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -317,8 +315,8 @@ func Test_GetGist(t *testing.T) { func Test_CreateGist(t *testing.T) { // Verify tool definition - mockClient := github.NewClient(nil) - tool, _ := CreateGist(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := CreateGist(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) @@ -423,28 +421,27 @@ func Test_CreateGist(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := CreateGist(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) + require.NoError(t, err) // Verify results if tc.expectError { - if err != nil { - assert.Contains(t, err.Error(), tc.expectedErrMsg) - } else { - // For errors returned as part of the result, not as an error - assert.NotNil(t, result) - textContent := getTextResult(t, result) - assert.Contains(t, textContent.Text, tc.expectedErrMsg) - } + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } - require.NoError(t, err) + require.False(t, result.IsError) assert.NotNil(t, result) // Parse the result and get the text content @@ -462,8 +459,8 @@ func Test_CreateGist(t *testing.T) { func Test_UpdateGist(t *testing.T) { // Verify tool definition - mockClient := github.NewClient(nil) - tool, _ := UpdateGist(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := UpdateGist(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) @@ -583,28 +580,27 @@ func Test_UpdateGist(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := UpdateGist(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) + require.NoError(t, err) // Verify results if tc.expectError { - if err != nil { - assert.Contains(t, err.Error(), tc.expectedErrMsg) - } else { - // For errors returned as part of the result, not as an error - assert.NotNil(t, result) - textContent := getTextResult(t, result) - assert.Contains(t, textContent.Text, tc.expectedErrMsg) - } + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } - require.NoError(t, err) + require.False(t, result.IsError) assert.NotNil(t, result) // Parse the result and get the text content diff --git a/pkg/github/search.go b/pkg/github/search.go index cffd0bf15..eaaf49369 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -8,6 +8,7 @@ import ( "net/http" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -16,7 +17,7 @@ import ( ) // SearchRepositories creates a tool to search for GitHub repositories. -func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func SearchRepositories(t translations.TranslationHelperFunc) toolsets.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -44,7 +45,8 @@ func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperF } WithPagination(schema) - return mcp.Tool{ + return NewTool( + mcp.Tool{ Name: "search_repositories", Description: t("TOOL_SEARCH_REPOSITORIES_DESCRIPTION", "Find GitHub repositories by name, description, readme, topics, or other metadata. Perfect for discovering projects, finding examples, or locating specific repositories across GitHub."), Annotations: &mcp.ToolAnnotations{ @@ -53,115 +55,118 @@ func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperF }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - query, err := RequiredParam[string](args, "query") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sort, err := OptionalParam[string](args, "sort") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - order, err := OptionalParam[string](args, "order") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - minimalOutput, err := OptionalBoolParamWithDefault(args, "minimal_output", true) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - opts := &github.SearchOptions{ - Sort: sort, - Order: order, - ListOptions: github.ListOptions{ - Page: pagination.Page, - PerPage: pagination.PerPage, - }, - } - - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - result, resp, err := client.Search.Repositories(ctx, query, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to search repositories with query '%s'", query), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + query, err := RequiredParam[string](args, "query") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + sort, err := OptionalParam[string](args, "sort") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + order, err := OptionalParam[string](args, "order") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + minimalOutput, err := OptionalBoolParamWithDefault(args, "minimal_output", true) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + opts := &github.SearchOptions{ + Sort: sort, + Order: order, + ListOptions: github.ListOptions{ + Page: pagination.Page, + PerPage: pagination.PerPage, + }, + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + client, err := deps.GetClient(ctx) if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to search repositories: %s", string(body))), nil, nil - } + result, resp, err := client.Search.Repositories(ctx, query, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to search repositories with query '%s'", query), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - // Return either minimal or full response based on parameter - var r []byte - if minimalOutput { - minimalRepos := make([]MinimalRepository, 0, len(result.Repositories)) - for _, repo := range result.Repositories { - minimalRepo := MinimalRepository{ - ID: repo.GetID(), - Name: repo.GetName(), - FullName: repo.GetFullName(), - Description: repo.GetDescription(), - HTMLURL: repo.GetHTMLURL(), - Language: repo.GetLanguage(), - Stars: repo.GetStargazersCount(), - Forks: repo.GetForksCount(), - OpenIssues: repo.GetOpenIssuesCount(), - Private: repo.GetPrivate(), - Fork: repo.GetFork(), - Archived: repo.GetArchived(), - DefaultBranch: repo.GetDefaultBranch(), + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } + return utils.NewToolResultError(fmt.Sprintf("failed to search repositories: %s", string(body))), nil, nil + } - if repo.UpdatedAt != nil { - minimalRepo.UpdatedAt = repo.UpdatedAt.Format("2006-01-02T15:04:05Z") - } - if repo.CreatedAt != nil { - minimalRepo.CreatedAt = repo.CreatedAt.Format("2006-01-02T15:04:05Z") - } - if repo.Topics != nil { - minimalRepo.Topics = repo.Topics + // Return either minimal or full response based on parameter + var r []byte + if minimalOutput { + minimalRepos := make([]MinimalRepository, 0, len(result.Repositories)) + for _, repo := range result.Repositories { + minimalRepo := MinimalRepository{ + ID: repo.GetID(), + Name: repo.GetName(), + FullName: repo.GetFullName(), + Description: repo.GetDescription(), + HTMLURL: repo.GetHTMLURL(), + Language: repo.GetLanguage(), + Stars: repo.GetStargazersCount(), + Forks: repo.GetForksCount(), + OpenIssues: repo.GetOpenIssuesCount(), + Private: repo.GetPrivate(), + Fork: repo.GetFork(), + Archived: repo.GetArchived(), + DefaultBranch: repo.GetDefaultBranch(), + } + + if repo.UpdatedAt != nil { + minimalRepo.UpdatedAt = repo.UpdatedAt.Format("2006-01-02T15:04:05Z") + } + if repo.CreatedAt != nil { + minimalRepo.CreatedAt = repo.CreatedAt.Format("2006-01-02T15:04:05Z") + } + if repo.Topics != nil { + minimalRepo.Topics = repo.Topics + } + + minimalRepos = append(minimalRepos, minimalRepo) } - minimalRepos = append(minimalRepos, minimalRepo) - } + minimalResult := &MinimalSearchRepositoriesResult{ + TotalCount: result.GetTotal(), + IncompleteResults: result.GetIncompleteResults(), + Items: minimalRepos, + } - minimalResult := &MinimalSearchRepositoriesResult{ - TotalCount: result.GetTotal(), - IncompleteResults: result.GetIncompleteResults(), - Items: minimalRepos, + r, err = json.Marshal(minimalResult) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal minimal response", err), nil, nil + } + } else { + r, err = json.Marshal(result) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal full response", err), nil, nil + } } - r, err = json.Marshal(minimalResult) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal minimal response", err), nil, nil - } - } else { - r, err = json.Marshal(result) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal full response", err), nil, nil - } + return utils.NewToolResultText(string(r)), nil, nil } - - return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) } // SearchCode creates a tool to search for code across GitHub repositories. -func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func SearchCode(t translations.TranslationHelperFunc) toolsets.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -183,7 +188,8 @@ func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (mc } WithPagination(schema) - return mcp.Tool{ + return NewTool( + mcp.Tool{ Name: "search_code", Description: t("TOOL_SEARCH_CODE_DESCRIPTION", "Fast and precise code search across ALL GitHub repositories using GitHub's native search engine. Best for finding exact symbols, functions, classes, or specific code patterns."), Annotations: &mcp.ToolAnnotations{ @@ -192,66 +198,69 @@ func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (mc }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - query, err := RequiredParam[string](args, "query") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sort, err := OptionalParam[string](args, "sort") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - order, err := OptionalParam[string](args, "order") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + query, err := RequiredParam[string](args, "query") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + sort, err := OptionalParam[string](args, "sort") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + order, err := OptionalParam[string](args, "order") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - opts := &github.SearchOptions{ - Sort: sort, - Order: order, - ListOptions: github.ListOptions{ - PerPage: pagination.PerPage, - Page: pagination.Page, - }, - } + opts := &github.SearchOptions{ + Sort: sort, + Order: order, + ListOptions: github.ListOptions{ + PerPage: pagination.PerPage, + Page: pagination.Page, + }, + } - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - result, resp, err := client.Search.Code(ctx, query, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to search code with query '%s'", query), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + result, resp, err := client.Search.Code(ctx, query, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to search code with query '%s'", query), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to search code: %s", string(body))), nil, nil + } + + r, err := json.Marshal(result) if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to search code: %s", string(body))), nil, nil - } - r, err := json.Marshal(result) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + return utils.NewToolResultText(string(r)), nil, nil } - - return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) } -func userOrOrgHandler(accountType string, getClient GetClientFn) mcp.ToolHandlerFor[map[string]any, any] { +func userOrOrgHandler(accountType string, deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { query, err := RequiredParam[string](args, "query") if err != nil { @@ -279,7 +288,7 @@ func userOrOrgHandler(accountType string, getClient GetClientFn) mcp.ToolHandler }, } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } @@ -340,7 +349,7 @@ func userOrOrgHandler(accountType string, getClient GetClientFn) mcp.ToolHandler } // SearchUsers creates a tool to search for GitHub users. -func SearchUsers(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func SearchUsers(t translations.TranslationHelperFunc) toolsets.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -363,19 +372,24 @@ func SearchUsers(getClient GetClientFn, t translations.TranslationHelperFunc) (m } WithPagination(schema) - return mcp.Tool{ - Name: "search_users", - Description: t("TOOL_SEARCH_USERS_DESCRIPTION", "Find GitHub users by username, real name, or other profile information. Useful for locating developers, contributors, or team members."), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_SEARCH_USERS_USER_TITLE", "Search users"), - ReadOnlyHint: true, + return NewTool( + mcp.Tool{ + Name: "search_users", + Description: t("TOOL_SEARCH_USERS_DESCRIPTION", "Find GitHub users by username, real name, or other profile information. Useful for locating developers, contributors, or team members."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_SEARCH_USERS_USER_TITLE", "Search users"), + ReadOnlyHint: true, + }, + InputSchema: schema, + }, + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return userOrOrgHandler("user", deps) }, - InputSchema: schema, - }, userOrOrgHandler("user", getClient) + ) } // SearchOrgs creates a tool to search for GitHub organizations. -func SearchOrgs(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func SearchOrgs(t translations.TranslationHelperFunc) toolsets.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -398,13 +412,18 @@ func SearchOrgs(getClient GetClientFn, t translations.TranslationHelperFunc) (mc } WithPagination(schema) - return mcp.Tool{ - Name: "search_orgs", - Description: t("TOOL_SEARCH_ORGS_DESCRIPTION", "Find GitHub organizations by name, location, or other organization metadata. Ideal for discovering companies, open source foundations, or teams."), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_SEARCH_ORGS_USER_TITLE", "Search organizations"), - ReadOnlyHint: true, + return NewTool( + mcp.Tool{ + Name: "search_orgs", + Description: t("TOOL_SEARCH_ORGS_DESCRIPTION", "Find GitHub organizations by name, location, or other organization metadata. Ideal for discovering companies, open source foundations, or teams."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_SEARCH_ORGS_USER_TITLE", "Search organizations"), + ReadOnlyHint: true, + }, + InputSchema: schema, + }, + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return userOrOrgHandler("org", deps) }, - InputSchema: schema, - }, userOrOrgHandler("org", getClient) + ) } diff --git a/pkg/github/search_test.go b/pkg/github/search_test.go index 0b923edcd..41d12df1b 100644 --- a/pkg/github/search_test.go +++ b/pkg/github/search_test.go @@ -17,8 +17,8 @@ import ( func Test_SearchRepositories(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SearchRepositories(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := SearchRepositories(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "search_repositories", tool.Name) @@ -134,13 +134,16 @@ func Test_SearchRepositories(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchRepositories(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -205,7 +208,11 @@ func Test_SearchRepositories_FullOutput(t *testing.T) { ) client := github.NewClient(mockedClient) - _, handlerTest := SearchRepositories(stubGetClientFn(client), translations.NullTranslationHelper) + serverTool := SearchRepositories(translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) args := map[string]interface{}{ "query": "golang test", @@ -214,7 +221,7 @@ func Test_SearchRepositories_FullOutput(t *testing.T) { request := createMCPRequest(args) - result, _, err := handlerTest(context.Background(), &request, args) + result, err := handler(context.Background(), &request) require.NoError(t, err) require.False(t, result.IsError) @@ -236,8 +243,8 @@ func Test_SearchRepositories_FullOutput(t *testing.T) { func Test_SearchCode(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SearchCode(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := SearchCode(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "search_code", tool.Name) @@ -351,13 +358,16 @@ func Test_SearchCode(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchCode(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -394,8 +404,8 @@ func Test_SearchCode(t *testing.T) { func Test_SearchUsers(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SearchUsers(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := SearchUsers(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "search_users", tool.Name) @@ -548,13 +558,16 @@ func Test_SearchUsers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchUsers(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -592,8 +605,8 @@ func Test_SearchUsers(t *testing.T) { func Test_SearchOrgs(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SearchOrgs(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := SearchOrgs(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) @@ -720,13 +733,16 @@ func Test_SearchOrgs(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchOrgs(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { diff --git a/pkg/github/tools.go b/pkg/github/tools.go index efabfc92f..d9008a912 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -179,10 +179,10 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG repos := toolsets.NewToolset(ToolsetMetadataRepos.ID, ToolsetMetadataRepos.Description). SetDependencies(deps). AddReadTools( - toolsets.NewServerToolLegacy(SearchRepositories(getClient, t)), + SearchRepositories(t), toolsets.NewServerToolLegacy(GetFileContents(getClient, getRawClient, t)), toolsets.NewServerToolLegacy(ListCommits(getClient, t)), - toolsets.NewServerToolLegacy(SearchCode(getClient, t)), + SearchCode(t), toolsets.NewServerToolLegacy(GetCommit(getClient, t)), toolsets.NewServerToolLegacy(ListBranches(getClient, t)), toolsets.NewServerToolLegacy(ListTags(getClient, t)), @@ -232,12 +232,12 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG users := toolsets.NewToolset(ToolsetMetadataUsers.ID, ToolsetMetadataUsers.Description). SetDependencies(deps). AddReadTools( - toolsets.NewServerToolLegacy(SearchUsers(getClient, t)), + SearchUsers(t), ) orgs := toolsets.NewToolset(ToolsetMetadataOrgs.ID, ToolsetMetadataOrgs.Description). SetDependencies(deps). AddReadTools( - toolsets.NewServerToolLegacy(SearchOrgs(getClient, t)), + SearchOrgs(t), ) pullRequests := toolsets.NewToolset(ToolsetMetadataPullRequests.ID, ToolsetMetadataPullRequests.Description). SetDependencies(deps). @@ -334,20 +334,20 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG contextTools := toolsets.NewToolset(ToolsetMetadataContext.ID, ToolsetMetadataContext.Description). SetDependencies(deps). AddReadTools( - toolsets.NewServerToolLegacy(GetMe(getClient, t)), - toolsets.NewServerToolLegacy(GetTeams(getClient, getGQLClient, t)), - toolsets.NewServerToolLegacy(GetTeamMembers(getGQLClient, t)), + GetMe(t), + GetTeams(t), + GetTeamMembers(t), ) gists := toolsets.NewToolset(ToolsetMetadataGists.ID, ToolsetMetadataGists.Description). SetDependencies(deps). AddReadTools( - toolsets.NewServerToolLegacy(ListGists(getClient, t)), - toolsets.NewServerToolLegacy(GetGist(getClient, t)), + ListGists(t), + GetGist(t), ). AddWriteTools( - toolsets.NewServerToolLegacy(CreateGist(getClient, t)), - toolsets.NewServerToolLegacy(UpdateGist(getClient, t)), + CreateGist(t), + UpdateGist(t), ) projects := toolsets.NewToolset(ToolsetMetadataProjects.ID, ToolsetMetadataProjects.Description).