diff --git a/internal/toolsets/vulnerability/clusters.go b/internal/toolsets/vulnerability/clusters.go new file mode 100644 index 0000000..f5852f4 --- /dev/null +++ b/internal/toolsets/vulnerability/clusters.go @@ -0,0 +1,159 @@ +package vulnerability + +import ( + "context" + "fmt" + "sort" + "strings" + + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/pkg/errors" + v1 "github.com/stackrox/rox/generated/api/v1" + "github.com/stackrox/stackrox-mcp/internal/client" + "github.com/stackrox/stackrox-mcp/internal/client/auth" + "github.com/stackrox/stackrox-mcp/internal/logging" + "github.com/stackrox/stackrox-mcp/internal/toolsets" +) + +// getClustersForCVEInput defines the input parameters for get_clusters_for_cve tool. +type getClustersForCVEInput struct { + CVEName string `json:"cveName"` + FilterClusterID string `json:"filterClusterId,omitempty"` +} + +func (input *getClustersForCVEInput) validate() error { + if input.CVEName == "" { + return errors.New("CVE name is required") + } + + return nil +} + +// ClusterResult contains cluster information. +type ClusterResult struct { + ClusterID string `json:"clusterId"` + ClusterName string `json:"clusterName"` +} + +// getClustersForCVEOutput defines the output structure for get_clusters_for_cve tool. +type getClustersForCVEOutput struct { + Clusters []ClusterResult `json:"clusters"` +} + +// getClustersForCVETool implements the get_clusters_for_cve tool. +type getClustersForCVETool struct { + name string + client *client.Client +} + +// NewGetClustersForCVETool creates a new get_clusters_for_cve tool. +func NewGetClustersForCVETool(c *client.Client) toolsets.Tool { + return &getClustersForCVETool{ + name: "get_clusters_for_cve", + client: c, + } +} + +// IsReadOnly returns true as this tool only reads data. +func (t *getClustersForCVETool) IsReadOnly() bool { + return true +} + +// GetName returns the tool name. +func (t *getClustersForCVETool) GetName() string { + return t.name +} + +// GetTool returns the MCP Tool definition. +func (t *getClustersForCVETool) GetTool() *mcp.Tool { + return &mcp.Tool{ + Name: t.name, + Description: "Get list of clusters affected by a specific CVE", + InputSchema: getClustersForCVEInputSchema(), + } +} + +// getClustersForCVEInputSchema returns the JSON schema for input validation. +func getClustersForCVEInputSchema() *jsonschema.Schema { + schema, err := jsonschema.For[getClustersForCVEInput](nil) + if err != nil { + logging.Fatal("Could not get jsonschema for get_clusters_for_cve input", err) + + return nil + } + + // CVE name is required. + schema.Required = []string{"cveName"} + + schema.Properties["cveName"].Description = "CVE name to filter clusters (e.g., CVE-2021-44228)" + schema.Properties["filterClusterId"].Description = "Optional cluster ID to verify if a specific cluster is affected" + + return schema +} + +// RegisterWith registers the get_clusters_for_cve tool handler with the MCP server. +func (t *getClustersForCVETool) RegisterWith(server *mcp.Server) { + mcp.AddTool(server, t.GetTool(), t.handle) +} + +// buildClusterQuery builds query string for filtering clusters by CVE. +// We quote values for exact match (CVE-2025-10 won't match CVE-2025-101). +func buildClusterQuery(input getClustersForCVEInput) string { + queryParts := []string{fmt.Sprintf("CVE:%q", input.CVEName)} + + if input.FilterClusterID != "" { + queryParts = append(queryParts, fmt.Sprintf("Cluster ID:%q", input.FilterClusterID)) + } + + return strings.Join(queryParts, "+") +} + +// handle is the handler for get_clusters_for_cve tool. +func (t *getClustersForCVETool) handle( + ctx context.Context, + req *mcp.CallToolRequest, + input getClustersForCVEInput, +) (*mcp.CallToolResult, *getClustersForCVEOutput, error) { + err := input.validate() + if err != nil { + return nil, nil, err + } + + conn, err := t.client.ReadyConn(ctx) + if err != nil { + return nil, nil, errors.Wrap(err, "unable to connect to server") + } + + callCtx := auth.WithMCPRequestContext(ctx, req) + + clustersClient := v1.NewClustersServiceClient(conn) + + query := buildClusterQuery(input) + + resp, err := clustersClient.GetClusters(callCtx, &v1.GetClustersRequest{ + Query: query, + }) + if err != nil { + return nil, nil, client.NewError(err, "GetClusters") + } + + clusters := make([]ClusterResult, 0, len(resp.GetClusters())) + for _, cluster := range resp.GetClusters() { + clusters = append(clusters, ClusterResult{ + ClusterID: cluster.GetId(), + ClusterName: cluster.GetName(), + }) + } + + // Sort by cluster ID for deterministic output. + sort.Slice(clusters, func(i, j int) bool { + return clusters[i].ClusterID < clusters[j].ClusterID + }) + + output := &getClustersForCVEOutput{ + Clusters: clusters, + } + + return nil, output, nil +} diff --git a/internal/toolsets/vulnerability/clusters_test.go b/internal/toolsets/vulnerability/clusters_test.go new file mode 100644 index 0000000..5d34a61 --- /dev/null +++ b/internal/toolsets/vulnerability/clusters_test.go @@ -0,0 +1,339 @@ +package vulnerability + +import ( + "context" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" + v1 "github.com/stackrox/rox/generated/api/v1" + "github.com/stackrox/rox/generated/storage" + "github.com/stackrox/stackrox-mcp/internal/client" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestNewGetClustersForCVETool(t *testing.T) { + tool := NewGetClustersForCVETool(&client.Client{}) + require.NotNil(t, tool) + assert.Equal(t, "get_clusters_for_cve", tool.GetName()) +} + +func TestGetClustersForCVETool_IsReadOnly(t *testing.T) { + c := &client.Client{} + tool := NewGetClustersForCVETool(c) + + assert.True(t, tool.IsReadOnly(), "get_clusters_for_cve should be read-only") +} + +func TestGetClustersForCVETool_GetTool(t *testing.T) { + c := &client.Client{} + tool := NewGetClustersForCVETool(c) + + mcpTool := tool.GetTool() + + require.NotNil(t, mcpTool) + assert.Equal(t, "get_clusters_for_cve", mcpTool.Name) + assert.Contains(t, mcpTool.Description, "clusters affected") + assert.NotNil(t, mcpTool.InputSchema) +} + +func TestGetClustersForCVETool_RegisterWith(t *testing.T) { + c := &client.Client{} + tool := NewGetClustersForCVETool(c) + server := mcp.NewServer( + &mcp.Implementation{ + Name: "test-server", + Version: "1.0.0", + }, + &mcp.ServerOptions{}, + ) + + // Should not panic. + assert.NotPanics(t, func() { + tool.RegisterWith(server) + }) +} + +// Unit tests for input validate method. +func TestClusterInputValidate(t *testing.T) { + tests := map[string]struct { + input getClustersForCVEInput + expectError bool + errorMsg string + }{ + "valid input with CVE only": { + input: getClustersForCVEInput{CVEName: "CVE-2021-44228"}, + expectError: false, + }, + "missing CVE name (empty string)": { + input: getClustersForCVEInput{CVEName: ""}, + expectError: true, + errorMsg: "CVE name is required", + }, + "missing CVE name (zero value)": { + input: getClustersForCVEInput{}, + expectError: true, + errorMsg: "CVE name is required", + }, + } + + for testName, testCase := range tests { + t.Run(testName, func(t *testing.T) { + err := testCase.input.validate() + + if !testCase.expectError { + require.NoError(t, err) + + return + } + + require.Error(t, err) + assert.Contains(t, err.Error(), testCase.errorMsg) + }) + } +} + +// Mock infrastructure for gRPC testing. + +// mockClustersService implements v1.ClustersServiceServer for testing. +type mockClustersService struct { + v1.UnimplementedClustersServiceServer + + clusters []*storage.Cluster + err error + lastCallQuery string +} + +func (m *mockClustersService) GetClusters( + _ context.Context, + req *v1.GetClustersRequest, +) (*v1.ClustersList, error) { + m.lastCallQuery = req.GetQuery() + + if m.err != nil { + return nil, m.err + } + + return &v1.ClustersList{ + Clusters: m.clusters, + }, nil +} + +// Integration tests for handle method. +func TestClusterHandle_MissingCVE(t *testing.T) { + mockService := &mockClustersService{ + clusters: []*storage.Cluster{}, + } + + grpcServer, listener := setupMockDeploymentServer(&mockDeploymentService{}) + defer grpcServer.Stop() + + v1.RegisterClustersServiceServer(grpcServer, mockService) + + testClient := createTestClient(t, listener) + tool, ok := NewGetClustersForCVETool(testClient).(*getClustersForCVETool) + require.True(t, ok) + + ctx := context.Background() + req := &mcp.CallToolRequest{} + inputWithoutCVEName := getClustersForCVEInput{} + + result, output, err := tool.handle(ctx, req, inputWithoutCVEName) + + require.Error(t, err) + assert.Nil(t, result) + assert.Nil(t, output) + assert.Contains(t, err.Error(), "CVE name is required") +} + +func TestClusterHandle_EmptyResults(t *testing.T) { + mockService := &mockClustersService{ + clusters: []*storage.Cluster{}, + } + + grpcServer, listener := setupMockDeploymentServer(&mockDeploymentService{}) + defer grpcServer.Stop() + + v1.RegisterClustersServiceServer(grpcServer, mockService) + + testClient := createTestClient(t, listener) + tool, ok := NewGetClustersForCVETool(testClient).(*getClustersForCVETool) + require.True(t, ok) + + ctx := context.Background() + req := &mcp.CallToolRequest{} + input := getClustersForCVEInput{ + CVEName: "CVE-9999-99999", + } + + result, output, err := tool.handle(ctx, req, input) + + require.NoError(t, err) + require.NotNil(t, output) + assert.Nil(t, result) + assert.Empty(t, output.Clusters, "Should have empty clusters array") +} + +func TestClusterHandle_GetClustersError(t *testing.T) { + mockService := &mockClustersService{ + err: status.Error(codes.Internal, "database error"), + } + + grpcServer, listener := setupMockDeploymentServer(&mockDeploymentService{}) + defer grpcServer.Stop() + + v1.RegisterClustersServiceServer(grpcServer, mockService) + + testClient := createTestClient(t, listener) + tool, ok := NewGetClustersForCVETool(testClient).(*getClustersForCVETool) + require.True(t, ok) + + ctx := context.Background() + req := &mcp.CallToolRequest{} + input := getClustersForCVEInput{ + CVEName: "CVE-2021-44228", + } + + result, output, err := tool.handle(ctx, req, input) + + require.Error(t, err) + assert.Nil(t, result) + assert.Nil(t, output) + assert.Contains(t, err.Error(), "database error") +} + +func TestClusterHandle_MultipleResults(t *testing.T) { + mockService := &mockClustersService{ + clusters: []*storage.Cluster{ + {Id: "cluster-z", Name: "Production"}, + {Id: "cluster-a", Name: "Development"}, + {Id: "cluster-m", Name: "Testing"}, + }, + } + + grpcServer, listener := setupMockDeploymentServer(&mockDeploymentService{}) + defer grpcServer.Stop() + + v1.RegisterClustersServiceServer(grpcServer, mockService) + + testClient := createTestClient(t, listener) + tool, ok := NewGetClustersForCVETool(testClient).(*getClustersForCVETool) + require.True(t, ok) + + ctx := context.Background() + req := &mcp.CallToolRequest{} + input := getClustersForCVEInput{ + CVEName: "CVE-2021-44228", + } + + result, output, err := tool.handle(ctx, req, input) + + require.NoError(t, err) + require.NotNil(t, output) + assert.Nil(t, result) + + // Should have 3 clusters, sorted by ID. + require.Len(t, output.Clusters, 3) + assert.Equal(t, "cluster-a", output.Clusters[0].ClusterID) + assert.Equal(t, "Development", output.Clusters[0].ClusterName) + + assert.Equal(t, "cluster-m", output.Clusters[1].ClusterID) + assert.Equal(t, "Testing", output.Clusters[1].ClusterName) + + assert.Equal(t, "cluster-z", output.Clusters[2].ClusterID) + assert.Equal(t, "Production", output.Clusters[2].ClusterName) + + // Verify query was built correctly. + assert.Equal(t, `CVE:"CVE-2021-44228"`, mockService.lastCallQuery) +} + +func TestClusterHandle_Sorting(t *testing.T) { + mockService := &mockClustersService{ + clusters: []*storage.Cluster{ + {Id: "z-cluster", Name: "A"}, + {Id: "a-cluster", Name: "Z"}, + {Id: "m-cluster", Name: "M"}, + }, + } + + grpcServer, listener := setupMockDeploymentServer(&mockDeploymentService{}) + defer grpcServer.Stop() + + v1.RegisterClustersServiceServer(grpcServer, mockService) + + testClient := createTestClient(t, listener) + tool, ok := NewGetClustersForCVETool(testClient).(*getClustersForCVETool) + require.True(t, ok) + + ctx := context.Background() + req := &mcp.CallToolRequest{} + input := getClustersForCVEInput{ + CVEName: "CVE-2021-44228", + } + + result, output, err := tool.handle(ctx, req, input) + + require.NoError(t, err) + require.NotNil(t, output) + assert.Nil(t, result) + + require.Len(t, output.Clusters, 3) + + // Should be sorted by cluster ID alphabetically. + // Expected order: a-cluster, m-cluster, z-cluster + assert.Equal(t, "a-cluster", output.Clusters[0].ClusterID) + assert.Equal(t, "Z", output.Clusters[0].ClusterName) + + assert.Equal(t, "m-cluster", output.Clusters[1].ClusterID) + assert.Equal(t, "M", output.Clusters[1].ClusterName) + + assert.Equal(t, "z-cluster", output.Clusters[2].ClusterID) + assert.Equal(t, "A", output.Clusters[2].ClusterName) +} + +func TestClusterHandle_WithFilters(t *testing.T) { + mockService := &mockClustersService{ + clusters: []*storage.Cluster{ + {Id: "cluster-1", Name: "C1"}, + }, + } + + grpcServer, listener := setupMockDeploymentServer(&mockDeploymentService{}) + defer grpcServer.Stop() + + v1.RegisterClustersServiceServer(grpcServer, mockService) + + tool, ok := NewGetClustersForCVETool(createTestClient(t, listener)).(*getClustersForCVETool) + require.True(t, ok) + + tests := map[string]struct { + input getClustersForCVEInput + expectedQuery string + }{ + "CVE only": { + input: getClustersForCVEInput{CVEName: "CVE-2021-44228"}, + expectedQuery: `CVE:"CVE-2021-44228"`, + }, + "CVE with cluster": { + input: getClustersForCVEInput{ + CVEName: "CVE-2021-44228", + FilterClusterID: "cluster-123", + }, + expectedQuery: `CVE:"CVE-2021-44228"+Cluster ID:"cluster-123"`, + }, + } + + for testName, testCase := range tests { + t.Run(testName, func(t *testing.T) { + result, output, err := tool.handle(context.Background(), &mcp.CallToolRequest{}, testCase.input) + + require.NoError(t, err) + require.NotNil(t, output) + assert.Nil(t, result) + assert.Len(t, output.Clusters, 1) + assert.Equal(t, testCase.expectedQuery, mockService.lastCallQuery) + }) + } +} diff --git a/internal/toolsets/vulnerability/toolset.go b/internal/toolsets/vulnerability/toolset.go index 0ab8f0d..5e6a9f5 100644 --- a/internal/toolsets/vulnerability/toolset.go +++ b/internal/toolsets/vulnerability/toolset.go @@ -14,12 +14,13 @@ type Toolset struct { } // NewToolset creates a new vulnerability management toolset. -func NewToolset(cfg *config.Config, c *client.Client) *Toolset { +func NewToolset(cfg *config.Config, client *client.Client) *Toolset { return &Toolset{ cfg: cfg, tools: []toolsets.Tool{ - NewGetDeploymentsForCVETool(c), - NewGetNodesForCVETool(c), + NewGetDeploymentsForCVETool(client), + NewGetNodesForCVETool(client), + NewGetClustersForCVETool(client), }, } } diff --git a/internal/toolsets/vulnerability/toolset_test.go b/internal/toolsets/vulnerability/toolset_test.go index 9eccf0d..8dca67d 100644 --- a/internal/toolsets/vulnerability/toolset_test.go +++ b/internal/toolsets/vulnerability/toolset_test.go @@ -38,9 +38,10 @@ func TestToolset_IsEnabled_True(t *testing.T) { tools := toolset.GetTools() require.NotEmpty(t, tools, "Should return tools when enabled") - require.Len(t, tools, 2, "Should have all vulnerability tools") + require.Len(t, tools, 3, "Should have all vulnerability tools") assert.Equal(t, "get_deployments_for_cve", tools[0].GetName()) assert.Equal(t, "get_nodes_for_cve", tools[1].GetName()) + assert.Equal(t, "get_clusters_for_cve", tools[2].GetName()) } func TestToolset_IsEnabled_False(t *testing.T) {