From ae89361f275086a16c11fa4f03438e59f85f4327 Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Mon, 28 Jul 2025 15:41:42 +0100 Subject: [PATCH 1/6] initial impl of pull request draft state update --- .../__toolsnaps__/update_pull_request.snap | 4 + pkg/github/pullrequests.go | 146 ++++++++++++++---- pkg/github/pullrequests_test.go | 13 +- pkg/github/tools.go | 2 +- 4 files changed, 136 insertions(+), 29 deletions(-) diff --git a/pkg/github/__toolsnaps__/update_pull_request.snap b/pkg/github/__toolsnaps__/update_pull_request.snap index 765983afd..c44d8afa0 100644 --- a/pkg/github/__toolsnaps__/update_pull_request.snap +++ b/pkg/github/__toolsnaps__/update_pull_request.snap @@ -14,6 +14,10 @@ "description": "New description", "type": "string" }, + "draft": { + "description": "Mark pull request as draft (true) or ready for review (false)", + "type": "boolean" + }, "maintainer_can_modify": { "description": "Allow maintainer edits", "type": "boolean" diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 47b7c6bd2..aeb0551c9 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -203,7 +203,7 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu } // UpdatePullRequest creates a tool to update an existing pull request. -func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { +func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { return mcp.NewTool("update_pull_request", mcp.WithDescription(t("TOOL_UPDATE_PULL_REQUEST_DESCRIPTION", "Update an existing pull request in a GitHub repository.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ @@ -232,6 +232,9 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu mcp.Description("New state"), mcp.Enum("open", "closed"), ), + mcp.WithBoolean("draft", + mcp.Description("Mark pull request as draft (true) or ready for review (false)"), + ), mcp.WithString("base", mcp.Description("New base branch name"), ), @@ -253,74 +256,165 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu return mcp.NewToolResultError(err.Error()), nil } - // Build the update struct only with provided fields + draftProvided := request.GetArguments()["draft"] != nil + var draftValue bool + if draftProvided { + draftValue, err = OptionalParam[bool](request, "draft") + if err != nil { + return nil, err + } + } + update := &github.PullRequest{} - updateNeeded := false + restUpdateNeeded := false if title, ok, err := OptionalParamOK[string](request, "title"); err != nil { return mcp.NewToolResultError(err.Error()), nil } else if ok { update.Title = github.Ptr(title) - updateNeeded = true + restUpdateNeeded = true } if body, ok, err := OptionalParamOK[string](request, "body"); err != nil { return mcp.NewToolResultError(err.Error()), nil } else if ok { update.Body = github.Ptr(body) - updateNeeded = true + restUpdateNeeded = true } if state, ok, err := OptionalParamOK[string](request, "state"); err != nil { return mcp.NewToolResultError(err.Error()), nil } else if ok { update.State = github.Ptr(state) - updateNeeded = true + restUpdateNeeded = true } if base, ok, err := OptionalParamOK[string](request, "base"); err != nil { return mcp.NewToolResultError(err.Error()), nil } else if ok { update.Base = &github.PullRequestBranch{Ref: github.Ptr(base)} - updateNeeded = true + restUpdateNeeded = true } if maintainerCanModify, ok, err := OptionalParamOK[bool](request, "maintainer_can_modify"); err != nil { return mcp.NewToolResultError(err.Error()), nil } else if ok { update.MaintainerCanModify = github.Ptr(maintainerCanModify) - updateNeeded = true + restUpdateNeeded = true } - if !updateNeeded { + if !restUpdateNeeded && !draftProvided { return mcp.NewToolResultError("No update parameters provided."), nil } + if restUpdateNeeded { + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + _, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to update pull request", + resp, + err, + ), nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil + } + } + + if draftProvided { + gqlClient, err := getGQLClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub GraphQL client: %w", err) + } + + var prQuery struct { + Repository struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + + err = gqlClient.Query(ctx, &prQuery, map[string]interface{}{ + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), + "prNum": githubv4.Int(pullNumber), + }) + if err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find pull request", err), nil + } + + currentIsDraft := bool(prQuery.Repository.PullRequest.IsDraft) + + if currentIsDraft != draftValue { + if draftValue { + // Convert to draft + var mutation struct { + ConvertPullRequestToDraft struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } + } `graphql:"convertPullRequestToDraft(input: $input)"` + } + + err = gqlClient.Mutate(ctx, &mutation, githubv4.ConvertPullRequestToDraftInput{ + PullRequestID: prQuery.Repository.PullRequest.ID, + }, nil) + if err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to convert pull request to draft", err), nil + } + } else { + // Mark as ready for review + var mutation struct { + MarkPullRequestReadyForReview struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } + } `graphql:"markPullRequestReadyForReview(input: $input)"` + } + + err = gqlClient.Mutate(ctx, &mutation, githubv4.MarkPullRequestReadyForReviewInput{ + PullRequestID: prQuery.Repository.PullRequest.ID, + }, nil) + if err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to mark pull request ready for review", err), nil + } + } + } + } + client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, err } - pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) + + finalPR, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to update pull request", - resp, - err, - ), nil + return ghErrors.NewGitHubAPIErrorResponse(ctx, "Failed to get pull request", resp, err), nil } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() } - return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil - } + }() - r, err := json.Marshal(pr) + r, err := json.Marshal(finalPR) if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to marshal response: %v", err), nil } return mcp.NewToolResultText(string(r)), nil diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 42fd5bf03..823a87525 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -137,7 +137,7 @@ func Test_GetPullRequest(t *testing.T) { func Test_UpdatePullRequest(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := UpdatePullRequest(stubGetClientFn(mockClient), translations.NullTranslationHelper) + tool, _ := UpdatePullRequest(stubGetClientFn(mockClient), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "update_pull_request", tool.Name) @@ -145,6 +145,7 @@ func Test_UpdatePullRequest(t *testing.T) { assert.Contains(t, tool.InputSchema.Properties, "owner") assert.Contains(t, tool.InputSchema.Properties, "repo") assert.Contains(t, tool.InputSchema.Properties, "pullNumber") + assert.Contains(t, tool.InputSchema.Properties, "draft") assert.Contains(t, tool.InputSchema.Properties, "title") assert.Contains(t, tool.InputSchema.Properties, "body") assert.Contains(t, tool.InputSchema.Properties, "state") @@ -194,6 +195,10 @@ func Test_UpdatePullRequest(t *testing.T) { mockResponse(t, http.StatusOK, mockUpdatedPR), ), ), + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockUpdatedPR, + ), ), requestArgs: map[string]interface{}{ "owner": "owner", @@ -218,6 +223,10 @@ func Test_UpdatePullRequest(t *testing.T) { mockResponse(t, http.StatusOK, mockClosedPR), ), ), + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockClosedPR, + ), ), requestArgs: map[string]interface{}{ "owner": "owner", @@ -266,7 +275,7 @@ func Test_UpdatePullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := UpdatePullRequest(stubGetClientFn(client), translations.NullTranslationHelper) + _, handler := UpdatePullRequest(stubGetClientFn(client), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/tools.go b/pkg/github/tools.go index e01b7cc40..caa4f9cfe 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -87,7 +87,7 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG toolsets.NewServerTool(MergePullRequest(getClient, t)), toolsets.NewServerTool(UpdatePullRequestBranch(getClient, t)), toolsets.NewServerTool(CreatePullRequest(getClient, t)), - toolsets.NewServerTool(UpdatePullRequest(getClient, t)), + toolsets.NewServerTool(UpdatePullRequest(getClient, getGQLClient, t)), toolsets.NewServerTool(RequestCopilotReview(getClient, t)), // Reviews From ffd5798e7aaa2d3566136b29002ec4a2732deb6c Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Mon, 28 Jul 2025 16:12:03 +0100 Subject: [PATCH 2/6] appease linter --- pkg/github/pullrequests.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index aeb0551c9..080887f35 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -350,7 +350,7 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra err = gqlClient.Query(ctx, &prQuery, map[string]interface{}{ "owner": githubv4.String(owner), "repo": githubv4.String(repo), - "prNum": githubv4.Int(pullNumber), + "prNum": githubv4.Int(int32(pullNumber)), }) if err != nil { return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find pull request", err), nil From e0d67eb7ac8ea4c23124651a40fa61d56c5bdbe4 Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Mon, 28 Jul 2025 16:13:15 +0100 Subject: [PATCH 3/6] update README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index be9288e40..34edc7b33 100644 --- a/README.md +++ b/README.md @@ -736,6 +736,7 @@ The following sets of tools are available (all are on by default): - **update_pull_request** - Edit pull request - `base`: New base branch name (string, optional) - `body`: New description (string, optional) + - `draft`: Mark pull request as draft (true) or ready for review (false) (boolean, optional) - `maintainer_can_modify`: Allow maintainer edits (boolean, optional) - `owner`: Repository owner (string, required) - `pullNumber`: Pull request number to update (number, required) From 93fa25b9039cb7198c2d8baac22f0fa2ce4158d7 Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Mon, 28 Jul 2025 16:19:44 +0100 Subject: [PATCH 4/6] add nosec --- pkg/github/pullrequests.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 080887f35..e30f856d2 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -350,7 +350,7 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra err = gqlClient.Query(ctx, &prQuery, map[string]interface{}{ "owner": githubv4.String(owner), "repo": githubv4.String(repo), - "prNum": githubv4.Int(int32(pullNumber)), + "prNum": githubv4.Int(pullNumber), // #nosec G115 - pull request numbers are always small positive integers }) if err != nil { return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find pull request", err), nil From ec521d58096db7620bfae61deef3eac0bda1e8a0 Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Mon, 28 Jul 2025 16:27:07 +0100 Subject: [PATCH 5/6] fixed err return type for json marshalling --- pkg/github/pullrequests.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index e30f856d2..f380843b0 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -414,7 +414,7 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra r, err := json.Marshal(finalPR) if err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to marshal response: %v", err), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to marshal response: %v", err)), nil } return mcp.NewToolResultText(string(r)), nil From 812048bfe2fafd56adfab13db20e5806d73f6804 Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Tue, 29 Jul 2025 10:31:22 +0100 Subject: [PATCH 6/6] add gql test --- pkg/github/pullrequests_test.go | 211 ++++++++++++++++++++++++++++++++ 1 file changed, 211 insertions(+) diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 823a87525..179458115 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -161,6 +161,7 @@ func Test_UpdatePullRequest(t *testing.T) { HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"), Body: github.Ptr("Updated test PR body."), MaintainerCanModify: github.Ptr(false), + Draft: github.Ptr(false), Base: &github.PullRequestBranch{ Ref: github.Ptr("develop"), }, @@ -237,6 +238,31 @@ func Test_UpdatePullRequest(t *testing.T) { expectError: false, expectedPR: mockClosedPR, }, + { + name: "successful PR update (title only)", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PatchReposPullsByOwnerByRepoByPullNumber, + expectRequestBody(t, map[string]interface{}{ + "title": "Updated Test PR Title", + }).andThen( + mockResponse(t, http.StatusOK, mockUpdatedPR), + ), + ), + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockUpdatedPR, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "title": "Updated Test PR Title", + }, + expectError: false, + expectedPR: mockUpdatedPR, + }, { name: "no update parameters provided", mockedClient: mock.NewMockedHTTPClient(), // No API call expected @@ -325,6 +351,191 @@ func Test_UpdatePullRequest(t *testing.T) { } } +func Test_UpdatePullRequest_Draft(t *testing.T) { + // Setup mock PR for success case + mockUpdatedPR := &github.PullRequest{ + Number: github.Ptr(42), + Title: github.Ptr("Test PR Title"), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"), + Body: github.Ptr("Test PR body."), + MaintainerCanModify: github.Ptr(false), + Draft: github.Ptr(false), // Updated to ready for review + Base: &github.PullRequestBranch{ + Ref: github.Ptr("main"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedPR *github.PullRequest + expectedErrMsg string + }{ + { + name: "successful draft update to ready for review", + mockedClient: githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + struct { + Repository struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + }{}, + map[string]any{ + "owner": githubv4.String("owner"), + "repo": githubv4.String("repo"), + "prNum": githubv4.Int(42), + }, + githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "pullRequest": map[string]any{ + "id": "PR_kwDOA0xdyM50BPaO", + "isDraft": true, // Current state is draft + }, + }, + }), + ), + githubv4mock.NewMutationMatcher( + struct { + MarkPullRequestReadyForReview struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } + } `graphql:"markPullRequestReadyForReview(input: $input)"` + }{}, + githubv4.MarkPullRequestReadyForReviewInput{ + PullRequestID: "PR_kwDOA0xdyM50BPaO", + }, + nil, + githubv4mock.DataResponse(map[string]any{ + "markPullRequestReadyForReview": map[string]any{ + "pullRequest": map[string]any{ + "id": "PR_kwDOA0xdyM50BPaO", + "isDraft": false, + }, + }, + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "draft": false, + }, + expectError: false, + expectedPR: mockUpdatedPR, + }, + { + name: "successful convert pull request to draft", + mockedClient: githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + struct { + Repository struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + }{}, + map[string]any{ + "owner": githubv4.String("owner"), + "repo": githubv4.String("repo"), + "prNum": githubv4.Int(42), + }, + githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "pullRequest": map[string]any{ + "id": "PR_kwDOA0xdyM50BPaO", + "isDraft": false, // Current state is draft + }, + }, + }), + ), + githubv4mock.NewMutationMatcher( + struct { + ConvertPullRequestToDraft struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } + } `graphql:"convertPullRequestToDraft(input: $input)"` + }{}, + githubv4.ConvertPullRequestToDraftInput{ + PullRequestID: "PR_kwDOA0xdyM50BPaO", + }, + nil, + githubv4mock.DataResponse(map[string]any{ + "convertPullRequestToDraft": map[string]any{ + "pullRequest": map[string]any{ + "id": "PR_kwDOA0xdyM50BPaO", + "isDraft": true, + }, + }, + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "draft": true, + }, + expectError: false, + expectedPR: mockUpdatedPR, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // For draft-only tests, we need to mock both GraphQL and the final REST GET call + restClient := github.NewClient(mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockUpdatedPR, + ), + )) + gqlClient := githubv4.NewClient(tc.mockedClient) + + _, handler := UpdatePullRequest(stubGetClientFn(restClient), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + + request := createMCPRequest(tc.requestArgs) + + result, err := handler(context.Background(), request) + + if tc.expectError || tc.expectedErrMsg != "" { + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + if tc.expectedErrMsg != "" { + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) + } + return + } + + require.NoError(t, err) + require.False(t, result.IsError) + + textContent := getTextResult(t, result) + + // Unmarshal and verify the successful result + var returnedPR github.PullRequest + err = json.Unmarshal([]byte(textContent.Text), &returnedPR) + require.NoError(t, err) + assert.Equal(t, *tc.expectedPR.Number, *returnedPR.Number) + if tc.expectedPR.Draft != nil { + assert.Equal(t, *tc.expectedPR.Draft, *returnedPR.Draft) + } + }) + } +} + func Test_ListPullRequests(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil)