-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Add updating draft state to update_pull_request
tool
#774
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ae89361
ffd5798
e0d67eb
93fa25b
ec521d5
812048b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -203,7 +203,7 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu | |||||||||
} | ||||||||||
|
||||||||||
// UpdatePullRequest creates a tool to update an existing pull request. | ||||||||||
func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { | ||||||||||
func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { | ||||||||||
return mcp.NewTool("update_pull_request", | ||||||||||
mcp.WithDescription(t("TOOL_UPDATE_PULL_REQUEST_DESCRIPTION", "Update an existing pull request in a GitHub repository.")), | ||||||||||
mcp.WithToolAnnotation(mcp.ToolAnnotation{ | ||||||||||
|
@@ -232,6 +232,9 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu | |||||||||
mcp.Description("New state"), | ||||||||||
mcp.Enum("open", "closed"), | ||||||||||
), | ||||||||||
mcp.WithBoolean("draft", | ||||||||||
mcp.Description("Mark pull request as draft (true) or ready for review (false)"), | ||||||||||
), | ||||||||||
mcp.WithString("base", | ||||||||||
mcp.Description("New base branch name"), | ||||||||||
), | ||||||||||
|
@@ -253,74 +256,165 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu | |||||||||
return mcp.NewToolResultError(err.Error()), nil | ||||||||||
} | ||||||||||
|
||||||||||
// Build the update struct only with provided fields | ||||||||||
draftProvided := request.GetArguments()["draft"] != nil | ||||||||||
var draftValue bool | ||||||||||
if draftProvided { | ||||||||||
draftValue, err = OptionalParam[bool](request, "draft") | ||||||||||
if err != nil { | ||||||||||
return nil, err | ||||||||||
} | ||||||||||
} | ||||||||||
|
||||||||||
update := &github.PullRequest{} | ||||||||||
updateNeeded := false | ||||||||||
restUpdateNeeded := false | ||||||||||
|
||||||||||
if title, ok, err := OptionalParamOK[string](request, "title"); err != nil { | ||||||||||
return mcp.NewToolResultError(err.Error()), nil | ||||||||||
} else if ok { | ||||||||||
update.Title = github.Ptr(title) | ||||||||||
updateNeeded = true | ||||||||||
restUpdateNeeded = true | ||||||||||
} | ||||||||||
|
||||||||||
if body, ok, err := OptionalParamOK[string](request, "body"); err != nil { | ||||||||||
return mcp.NewToolResultError(err.Error()), nil | ||||||||||
} else if ok { | ||||||||||
update.Body = github.Ptr(body) | ||||||||||
updateNeeded = true | ||||||||||
restUpdateNeeded = true | ||||||||||
} | ||||||||||
|
||||||||||
if state, ok, err := OptionalParamOK[string](request, "state"); err != nil { | ||||||||||
return mcp.NewToolResultError(err.Error()), nil | ||||||||||
} else if ok { | ||||||||||
update.State = github.Ptr(state) | ||||||||||
updateNeeded = true | ||||||||||
restUpdateNeeded = true | ||||||||||
} | ||||||||||
|
||||||||||
if base, ok, err := OptionalParamOK[string](request, "base"); err != nil { | ||||||||||
return mcp.NewToolResultError(err.Error()), nil | ||||||||||
} else if ok { | ||||||||||
update.Base = &github.PullRequestBranch{Ref: github.Ptr(base)} | ||||||||||
updateNeeded = true | ||||||||||
restUpdateNeeded = true | ||||||||||
} | ||||||||||
|
||||||||||
if maintainerCanModify, ok, err := OptionalParamOK[bool](request, "maintainer_can_modify"); err != nil { | ||||||||||
return mcp.NewToolResultError(err.Error()), nil | ||||||||||
} else if ok { | ||||||||||
update.MaintainerCanModify = github.Ptr(maintainerCanModify) | ||||||||||
updateNeeded = true | ||||||||||
restUpdateNeeded = true | ||||||||||
} | ||||||||||
|
||||||||||
if !updateNeeded { | ||||||||||
if !restUpdateNeeded && !draftProvided { | ||||||||||
return mcp.NewToolResultError("No update parameters provided."), nil | ||||||||||
} | ||||||||||
|
||||||||||
if restUpdateNeeded { | ||||||||||
client, err := getClient(ctx) | ||||||||||
if err != nil { | ||||||||||
return nil, fmt.Errorf("failed to get GitHub client: %w", err) | ||||||||||
} | ||||||||||
|
||||||||||
_, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) | ||||||||||
if err != nil { | ||||||||||
return ghErrors.NewGitHubAPIErrorResponse(ctx, | ||||||||||
"failed to update pull request", | ||||||||||
resp, | ||||||||||
err, | ||||||||||
), nil | ||||||||||
} | ||||||||||
defer func() { _ = resp.Body.Close() }() | ||||||||||
|
||||||||||
if resp.StatusCode != http.StatusOK { | ||||||||||
body, err := io.ReadAll(resp.Body) | ||||||||||
if err != nil { | ||||||||||
return nil, fmt.Errorf("failed to read response body: %w", err) | ||||||||||
} | ||||||||||
return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil | ||||||||||
} | ||||||||||
} | ||||||||||
|
||||||||||
if draftProvided { | ||||||||||
gqlClient, err := getGQLClient(ctx) | ||||||||||
if err != nil { | ||||||||||
return nil, fmt.Errorf("failed to get GitHub GraphQL client: %w", err) | ||||||||||
} | ||||||||||
|
||||||||||
var prQuery struct { | ||||||||||
Repository struct { | ||||||||||
PullRequest struct { | ||||||||||
ID githubv4.ID | ||||||||||
IsDraft githubv4.Boolean | ||||||||||
} `graphql:"pullRequest(number: $prNum)"` | ||||||||||
} `graphql:"repository(owner: $owner, name: $repo)"` | ||||||||||
} | ||||||||||
|
||||||||||
err = gqlClient.Query(ctx, &prQuery, map[string]interface{}{ | ||||||||||
"owner": githubv4.String(owner), | ||||||||||
"repo": githubv4.String(repo), | ||||||||||
"prNum": githubv4.Int(pullNumber), // #nosec G115 - pull request numbers are always small positive integers | ||||||||||
}) | ||||||||||
if err != nil { | ||||||||||
return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find pull request", err), nil | ||||||||||
} | ||||||||||
|
||||||||||
currentIsDraft := bool(prQuery.Repository.PullRequest.IsDraft) | ||||||||||
|
||||||||||
if currentIsDraft != draftValue { | ||||||||||
if draftValue { | ||||||||||
// Convert to draft | ||||||||||
var mutation struct { | ||||||||||
ConvertPullRequestToDraft struct { | ||||||||||
PullRequest struct { | ||||||||||
ID githubv4.ID | ||||||||||
IsDraft githubv4.Boolean | ||||||||||
} | ||||||||||
} `graphql:"convertPullRequestToDraft(input: $input)"` | ||||||||||
} | ||||||||||
|
||||||||||
err = gqlClient.Mutate(ctx, &mutation, githubv4.ConvertPullRequestToDraftInput{ | ||||||||||
PullRequestID: prQuery.Repository.PullRequest.ID, | ||||||||||
}, nil) | ||||||||||
if err != nil { | ||||||||||
return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to convert pull request to draft", err), nil | ||||||||||
} | ||||||||||
} else { | ||||||||||
// Mark as ready for review | ||||||||||
var mutation struct { | ||||||||||
MarkPullRequestReadyForReview struct { | ||||||||||
PullRequest struct { | ||||||||||
ID githubv4.ID | ||||||||||
IsDraft githubv4.Boolean | ||||||||||
} | ||||||||||
} `graphql:"markPullRequestReadyForReview(input: $input)"` | ||||||||||
} | ||||||||||
|
||||||||||
err = gqlClient.Mutate(ctx, &mutation, githubv4.MarkPullRequestReadyForReviewInput{ | ||||||||||
PullRequestID: prQuery.Repository.PullRequest.ID, | ||||||||||
}, nil) | ||||||||||
if err != nil { | ||||||||||
return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to mark pull request ready for review", err), nil | ||||||||||
} | ||||||||||
} | ||||||||||
} | ||||||||||
} | ||||||||||
|
||||||||||
client, err := getClient(ctx) | ||||||||||
if err != nil { | ||||||||||
return nil, fmt.Errorf("failed to get GitHub client: %w", err) | ||||||||||
return nil, err | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The REST client is retrieved twice in this function - once for the conditional update and again at the end. Consider reusing the client instance or restructuring the code to avoid duplicate client retrieval.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||
} | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The REST client is retrieved twice in this function (lines 311 and 400). Consider retrieving it once at the beginning and reusing it to improve efficiency and reduce duplication.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||
pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) | ||||||||||
|
||||||||||
finalPR, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) | ||||||||||
if err != nil { | ||||||||||
return ghErrors.NewGitHubAPIErrorResponse(ctx, | ||||||||||
"failed to update pull request", | ||||||||||
resp, | ||||||||||
err, | ||||||||||
), nil | ||||||||||
return ghErrors.NewGitHubAPIErrorResponse(ctx, "Failed to get pull request", resp, err), nil | ||||||||||
} | ||||||||||
defer func() { _ = resp.Body.Close() }() | ||||||||||
|
||||||||||
if resp.StatusCode != http.StatusOK { | ||||||||||
body, err := io.ReadAll(resp.Body) | ||||||||||
if err != nil { | ||||||||||
return nil, fmt.Errorf("failed to read response body: %w", err) | ||||||||||
defer func() { | ||||||||||
if resp != nil && resp.Body != nil { | ||||||||||
_ = resp.Body.Close() | ||||||||||
} | ||||||||||
return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil | ||||||||||
} | ||||||||||
}() | ||||||||||
|
||||||||||
r, err := json.Marshal(pr) | ||||||||||
r, err := json.Marshal(finalPR) | ||||||||||
if err != nil { | ||||||||||
return nil, fmt.Errorf("failed to marshal response: %w", err) | ||||||||||
return mcp.NewToolResultError(fmt.Sprintf("Failed to marshal response: %v", err)), nil | ||||||||||
} | ||||||||||
|
||||||||||
return mcp.NewToolResultText(string(r)), nil | ||||||||||
|
Uh oh!
There was an error while loading. Please reload this page.