package github import ( "context" "encoding/json" "net/http" "strings" "testing" "github.com/github/github-mcp-server/internal/toolsnaps" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v82/github" "github.com/google/jsonschema-go/jsonschema" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func Test_GetRepositoryTree(t *testing.T) { // Verify tool definition once toolDef := GetRepositoryTree(translations.NullTranslationHelper) require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) assert.Equal(t, "get_repository_tree", toolDef.Tool.Name) assert.NotEmpty(t, toolDef.Tool.Description) // Type assert the InputSchema to access its properties inputSchema, ok := toolDef.Tool.InputSchema.(*jsonschema.Schema) require.True(t, ok, "expected InputSchema to be *jsonschema.Schema") assert.Contains(t, inputSchema.Properties, "owner") assert.Contains(t, inputSchema.Properties, "repo") assert.Contains(t, inputSchema.Properties, "tree_sha") assert.Contains(t, inputSchema.Properties, "recursive") assert.Contains(t, inputSchema.Properties, "path_filter") assert.ElementsMatch(t, inputSchema.Required, []string{"owner", "repo"}) // Setup mock data mockRepo := &github.Repository{ DefaultBranch: github.Ptr("main"), } mockTree := &github.Tree{ SHA: github.Ptr("abc123"), Truncated: github.Ptr(false), Entries: []*github.TreeEntry{ { Path: github.Ptr("README.md"), Mode: github.Ptr("100644"), Type: github.Ptr("blob"), SHA: github.Ptr("file1sha"), Size: github.Ptr(123), URL: github.Ptr("https://api.github.com/repos/owner/repo/git/blobs/file1sha"), }, { Path: github.Ptr("src/main.go"), Mode: github.Ptr("100644"), Type: github.Ptr("blob"), SHA: github.Ptr("file2sha"), Size: github.Ptr(456), URL: github.Ptr("https://api.github.com/repos/owner/repo/git/blobs/file2sha"), }, }, } tests := []struct { name string mockedClient *http.Client requestArgs map[string]any expectError bool expectedErrMsg string }{ { name: "successfully get repository tree", mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ GetReposByOwnerByRepo: mockResponse(t, http.StatusOK, mockRepo), GetReposGitTreesByOwnerByRepoByTree: mockResponse(t, http.StatusOK, mockTree), }), requestArgs: map[string]any{ "owner": "owner", "repo": "repo", }, }, { name: "successfully get repository tree with path filter", mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ GetReposByOwnerByRepo: mockResponse(t, http.StatusOK, mockRepo), GetReposGitTreesByOwnerByRepoByTree: mockResponse(t, http.StatusOK, mockTree), }), requestArgs: map[string]any{ "owner": "owner", "repo": "repo", "path_filter": "src/", }, }, { name: "repository not found", mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ GetReposByOwnerByRepo: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusNotFound) _, _ = w.Write([]byte(`{"message": "Not Found"}`)) }), }), requestArgs: map[string]any{ "owner": "owner", "repo": "nonexistent", }, expectError: true, expectedErrMsg: "failed to get repository info", }, { name: "tree not found", mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ GetReposByOwnerByRepo: mockResponse(t, http.StatusOK, mockRepo), GetReposGitTreesByOwnerByRepoByTree: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusNotFound) _, _ = w.Write([]byte(`{"message": "Not Found"}`)) }), }), requestArgs: map[string]any{ "owner": "owner", "repo": "repo", }, expectError: true, expectedErrMsg: "failed to get repository tree", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) deps := BaseDeps{ Client: client, } handler := toolDef.Handler(deps) // Create the tool request request := createMCPRequest(tc.requestArgs) result, err := handler(ContextWithDeps(context.Background(), deps), &request) if tc.expectError { require.NoError(t, err) require.True(t, result.IsError) errorContent := getErrorResult(t, result) assert.Contains(t, errorContent.Text, tc.expectedErrMsg) } else { require.NoError(t, err) require.False(t, result.IsError) // Parse the result and get the text content textContent := getTextResult(t, result) // Parse the JSON response var treeResponse map[string]any err := json.Unmarshal([]byte(textContent.Text), &treeResponse) require.NoError(t, err) // Verify response structure assert.Equal(t, "owner", treeResponse["owner"]) assert.Equal(t, "repo", treeResponse["repo"]) assert.Contains(t, treeResponse, "tree") assert.Contains(t, treeResponse, "count") assert.Contains(t, treeResponse, "sha") assert.Contains(t, treeResponse, "truncated") // Check filtering if path_filter was provided if pathFilter, exists := tc.requestArgs["path_filter"]; exists { tree := treeResponse["tree"].([]any) for _, entry := range tree { entryMap := entry.(map[string]any) path := entryMap["path"].(string) assert.True(t, strings.HasPrefix(path, pathFilter.(string)), "Path %s should start with filter %s", path, pathFilter) } } } }) } }