From 1905801d5fe58f205e6016f11f70ef844735df67 Mon Sep 17 00:00:00 2001 From: Mladen Todorovic Date: Thu, 18 Dec 2025 15:40:53 +0100 Subject: [PATCH] Refactor tools tests --- internal/toolsets/config/tools_test.go | 64 +--- internal/toolsets/mock/api_server.go | 283 ++++++++++++++++++ .../toolsets/vulnerability/clusters_test.go | 89 ++---- .../{tools.go => deployments.go} | 0 .../{tools_test.go => deployments_test.go} | 195 +++--------- internal/toolsets/vulnerability/nodes_test.go | 105 ++----- 6 files changed, 398 insertions(+), 338 deletions(-) create mode 100644 internal/toolsets/mock/api_server.go rename internal/toolsets/vulnerability/{tools.go => deployments.go} (100%) rename internal/toolsets/vulnerability/{tools_test.go => deployments_test.go} (75%) diff --git a/internal/toolsets/config/tools_test.go b/internal/toolsets/config/tools_test.go index 72c4854..c2b08ec 100644 --- a/internal/toolsets/config/tools_test.go +++ b/internal/toolsets/config/tools_test.go @@ -7,10 +7,10 @@ import ( "testing" "github.com/modelcontextprotocol/go-sdk/mcp" - v1 "github.com/stackrox/rox/generated/api/v1" "github.com/stackrox/rox/generated/storage" "github.com/stackrox/stackrox-mcp/internal/client" "github.com/stackrox/stackrox-mcp/internal/config" + "github.com/stackrox/stackrox-mcp/internal/toolsets/mock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" @@ -60,44 +60,6 @@ func TestListClustersTool_RegisterWith(t *testing.T) { }) } -// Mock infrastructure for gRPC testing. - -// mockClustersService implements v1.ClustersServiceServer for testing. -type mockClustersService struct { - v1.UnimplementedClustersServiceServer - - clusters []*storage.Cluster - err error -} - -func (m *mockClustersService) GetClusters( - _ context.Context, - _ *v1.GetClustersRequest, -) (*v1.ClustersList, error) { - if m.err != nil { - return nil, m.err - } - - return &v1.ClustersList{ - Clusters: m.clusters, - }, nil -} - -// setupMockServer creates an in-memory gRPC server using bufconn. -func setupMockServer(mockService *mockClustersService) (*grpc.Server, *bufconn.Listener) { - buffer := 1024 * 1024 - listener := bufconn.Listen(buffer) - - grpcServer := grpc.NewServer() - v1.RegisterClustersServiceServer(grpcServer, mockService) - - go func() { - _ = grpcServer.Serve(listener) - }() - - return grpcServer, listener -} - // bufDialer creates a dialer function for bufconn. func bufDialer(listener *bufconn.Listener) func(context.Context, string) (net.Conn, error) { return func(_ context.Context, _ string) (net.Conn, error) { @@ -129,17 +91,18 @@ func createTestClient(t *testing.T, listener *bufconn.Listener) *client.Client { } func TestHandle_DefaultLimit(t *testing.T) { - mockService := &mockClustersService{ - clusters: []*storage.Cluster{ + mockService := mock.NewClustersServiceMock( + []*storage.Cluster{ {Id: "c1", Name: "Cluster 1", Type: storage.ClusterType_KUBERNETES_CLUSTER}, {Id: "c2", Name: "Cluster 2", Type: storage.ClusterType_KUBERNETES_CLUSTER}, {Id: "c3", Name: "Cluster 3", Type: storage.ClusterType_KUBERNETES_CLUSTER}, {Id: "c4", Name: "Cluster 4", Type: storage.ClusterType_KUBERNETES_CLUSTER}, {Id: "c5", Name: "Cluster 5", Type: storage.ClusterType_KUBERNETES_CLUSTER}, }, - } + nil, + ) - grpcServer, listener := setupMockServer(mockService) + grpcServer, listener := mock.SetupClusterServer(mockService) defer grpcServer.Stop() testClient := createTestClient(t, listener) @@ -180,11 +143,9 @@ func TestHandle_WithPagination(t *testing.T) { } } - mockService := &mockClustersService{ - clusters: clusters, - } + mockService := mock.NewClustersServiceMock(clusters, nil) - grpcServer, listener := setupMockServer(mockService) + grpcServer, listener := mock.SetupClusterServer(mockService) defer grpcServer.Stop() testClient := createTestClient(t, listener) @@ -255,11 +216,12 @@ func TestHandle_WithPagination(t *testing.T) { } func TestHandle_GetClustersError(t *testing.T) { - mockService := &mockClustersService{ - err: status.Error(codes.Internal, "test"), - } + mockService := mock.NewClustersServiceMock( + []*storage.Cluster{}, + status.Error(codes.Internal, "test"), + ) - grpcServer, listener := setupMockServer(mockService) + grpcServer, listener := mock.SetupClusterServer(mockService) defer grpcServer.Stop() testClient := createTestClient(t, listener) diff --git a/internal/toolsets/mock/api_server.go b/internal/toolsets/mock/api_server.go new file mode 100644 index 0000000..2752e47 --- /dev/null +++ b/internal/toolsets/mock/api_server.go @@ -0,0 +1,283 @@ +package mock + +import ( + "context" + "strings" + "sync" + + "github.com/pkg/errors" + v1 "github.com/stackrox/rox/generated/api/v1" + "github.com/stackrox/rox/generated/storage" + "google.golang.org/grpc" + "google.golang.org/grpc/test/bufconn" +) + +const bufferSize = 1024 * 1024 + +// SetupAPIServer creates an in-memory gRPC Central server. +func SetupAPIServer( + deploymentService v1.DeploymentServiceServer, + imageService v1.ImageServiceServer, + nodeService v1.NodeServiceServer, + clusterService v1.ClustersServiceServer, +) (*grpc.Server, *bufconn.Listener) { + buffer := bufferSize + listener := bufconn.Listen(buffer) + + grpcServer := grpc.NewServer() + v1.RegisterDeploymentServiceServer(grpcServer, deploymentService) + v1.RegisterImageServiceServer(grpcServer, imageService) + v1.RegisterNodeServiceServer(grpcServer, nodeService) + v1.RegisterClustersServiceServer(grpcServer, clusterService) + + go func() { + _ = grpcServer.Serve(listener) + }() + + return grpcServer, listener +} + +// SetupNodeServer creates an in-memory gRPC server with node services. +func SetupNodeServer(nodeService v1.NodeServiceServer) (*grpc.Server, *bufconn.Listener) { + return SetupAPIServer( + v1.UnimplementedDeploymentServiceServer{}, + v1.UnimplementedImageServiceServer{}, + nodeService, + v1.UnimplementedClustersServiceServer{}, + ) +} + +// SetupDeploymentServer creates an in-memory gRPC server with deployment services. +func SetupDeploymentServer(mockService v1.DeploymentServiceServer) (*grpc.Server, *bufconn.Listener) { + return SetupAPIServer( + mockService, + v1.UnimplementedImageServiceServer{}, + v1.UnimplementedNodeServiceServer{}, + v1.UnimplementedClustersServiceServer{}, + ) +} + +// SetupClusterServer creates an in-memory gRPC server with cluster services. +func SetupClusterServer(mockService v1.ClustersServiceServer) (*grpc.Server, *bufconn.Listener) { + return SetupAPIServer( + v1.UnimplementedDeploymentServiceServer{}, + v1.UnimplementedImageServiceServer{}, + v1.UnimplementedNodeServiceServer{}, + mockService, + ) +} + +// ClustersService implements v1.ClustersServiceServer for testing. +type ClustersService struct { + v1.UnimplementedClustersServiceServer + + clusters []*storage.Cluster + err error + + lastCallQuery string +} + +// NewClustersServiceMock return mock to cluster service. +func NewClustersServiceMock(clusters []*storage.Cluster, err error) *ClustersService { + return &ClustersService{clusters: clusters, err: err} +} + +// GetLastCallQuery returns query used for the last call. +func (cs *ClustersService) GetLastCallQuery() string { + return cs.lastCallQuery +} + +// GetClusters implements v1.ClustersServiceServer.GetClusters for testing. +func (cs *ClustersService) GetClusters( + _ context.Context, + req *v1.GetClustersRequest, +) (*v1.ClustersList, error) { + cs.lastCallQuery = req.GetQuery() + + if cs.err != nil { + return nil, cs.err + } + + return &v1.ClustersList{ + Clusters: cs.clusters, + }, nil +} + +// NodeService implements v1.NodeServiceServer for testing. +type NodeService struct { + v1.UnimplementedNodeServiceServer + + nodes []*storage.Node + err error + + lastCallQuery string +} + +// NewNodeServiceMock return mock to node service. +func NewNodeServiceMock(nodes []*storage.Node, err error) *NodeService { + return &NodeService{nodes: nodes, err: err} +} + +// GetLastCallQuery returns query used for the last call. +func (ns *NodeService) GetLastCallQuery() string { + return ns.lastCallQuery +} + +// ExportNodes implements v1.NodeServiceServer.ExportNodes for testing. +func (ns *NodeService) ExportNodes( + req *v1.ExportNodeRequest, + stream grpc.ServerStreamingServer[v1.ExportNodeResponse], +) error { + ns.lastCallQuery = req.GetQuery() + + if ns.err != nil { + return ns.err + } + + // Send all nodes through the stream. + for _, node := range ns.nodes { + resp := &v1.ExportNodeResponse{Node: node} + if err := stream.Send(resp); err != nil { + return errors.Wrap(err, "sending node over stream failed") + } + } + + return nil +} + +// DeploymentService implements v1.DeploymentServiceServer for testing. +type DeploymentService struct { + v1.UnimplementedDeploymentServiceServer + + deployments []*storage.ListDeployment + err error + + // Mock call information. + lastCallQuery string + lastCallLimit int32 + lastCallOffset int32 +} + +// NewDeploymentServiceMock returns mock for deployment service. +func NewDeploymentServiceMock(deployments []*storage.ListDeployment, err error) *DeploymentService { + return &DeploymentService{ + deployments: deployments, + err: err, + } +} + +// GetLastCallQuery returns query used for the last call. +func (ds *DeploymentService) GetLastCallQuery() string { + return ds.lastCallQuery +} + +// GetLastCallLimit returns limit used for the last call. +func (ds *DeploymentService) GetLastCallLimit() int32 { + return ds.lastCallLimit +} + +// GetLastCallOffset returns offset used for the last call. +func (ds *DeploymentService) GetLastCallOffset() int32 { + return ds.lastCallOffset +} + +// ListDeployments implements v1.DeploymentServiceServer.ListDeployments for testing. +func (ds *DeploymentService) ListDeployments( + _ context.Context, + query *v1.RawQuery, +) (*v1.ListDeploymentsResponse, error) { + ds.lastCallQuery = query.GetQuery() + ds.lastCallLimit = query.GetPagination().GetLimit() + ds.lastCallOffset = query.GetPagination().GetOffset() + + if ds.err != nil { + return nil, ds.err + } + + return &v1.ListDeploymentsResponse{ + Deployments: ds.deployments, + }, nil +} + +// ImageService implements v1.ImageServiceServer for testing. +type ImageService struct { + v1.UnimplementedImageServiceServer + + images map[string][]*storage.ListImage // keyed by deploymentID + err error + + // We are requesting images in parallel requests. + lock sync.Mutex + + // Mock call information. + lastCallQuery string + lastCallLimit int32 + callCount int +} + +// NewImageServiceMock returns mock for image service. +func NewImageServiceMock(images map[string][]*storage.ListImage, err error) *ImageService { + return &ImageService{ + images: images, + err: err, + } +} + +// GetLastCallQuery returns query used for the last call. +func (is *ImageService) GetLastCallQuery() string { + return is.lastCallQuery +} + +// GetLastCallLimit returns limit used for the last call. +func (is *ImageService) GetLastCallLimit() int32 { + return is.lastCallLimit +} + +// GetCallCount returns count off all calls. +func (is *ImageService) GetCallCount() int { + return is.callCount +} + +// ListImages implements v1.ImageServiceServer.ListImages for testing. +func (is *ImageService) ListImages( + _ context.Context, + query *v1.RawQuery, +) (*v1.ListImagesResponse, error) { + is.lock.Lock() + defer is.lock.Unlock() + + is.callCount++ + is.lastCallQuery = query.GetQuery() + is.lastCallLimit = query.GetPagination().GetLimit() + + if is.err != nil { + return nil, is.err + } + + // Extract deployment ID from query. + // Query format: CVE:"CVE-2021-44228"+Deployment ID:"dep-1" + deploymentID := extractDeploymentIDFromQuery(query.GetQuery()) + + return &v1.ListImagesResponse{ + Images: is.images[deploymentID], + }, nil +} + +// extractDeploymentIDFromQuery extracts deployment ID from the query string. +func extractDeploymentIDFromQuery(query string) string { + const prefix = "Deployment ID:\"" + + start := strings.Index(query, prefix) + if start == -1 { + return "" + } + + start += len(prefix) + + end := strings.Index(query[start:], "\"") + if end == -1 { + return "" + } + + return query[start : start+end] +} diff --git a/internal/toolsets/vulnerability/clusters_test.go b/internal/toolsets/vulnerability/clusters_test.go index 5d34a61..795c388 100644 --- a/internal/toolsets/vulnerability/clusters_test.go +++ b/internal/toolsets/vulnerability/clusters_test.go @@ -5,9 +5,9 @@ import ( "testing" "github.com/modelcontextprotocol/go-sdk/mcp" - v1 "github.com/stackrox/rox/generated/api/v1" "github.com/stackrox/rox/generated/storage" "github.com/stackrox/stackrox-mcp/internal/client" + "github.com/stackrox/stackrox-mcp/internal/toolsets/mock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" @@ -95,43 +95,13 @@ func TestClusterInputValidate(t *testing.T) { } } -// Mock infrastructure for gRPC testing. - -// mockClustersService implements v1.ClustersServiceServer for testing. -type mockClustersService struct { - v1.UnimplementedClustersServiceServer - - clusters []*storage.Cluster - err error - lastCallQuery string -} - -func (m *mockClustersService) GetClusters( - _ context.Context, - req *v1.GetClustersRequest, -) (*v1.ClustersList, error) { - m.lastCallQuery = req.GetQuery() - - if m.err != nil { - return nil, m.err - } - - return &v1.ClustersList{ - Clusters: m.clusters, - }, nil -} - // Integration tests for handle method. func TestClusterHandle_MissingCVE(t *testing.T) { - mockService := &mockClustersService{ - clusters: []*storage.Cluster{}, - } + mockService := mock.NewClustersServiceMock([]*storage.Cluster{}, nil) - grpcServer, listener := setupMockDeploymentServer(&mockDeploymentService{}) + grpcServer, listener := mock.SetupClusterServer(mockService) defer grpcServer.Stop() - v1.RegisterClustersServiceServer(grpcServer, mockService) - testClient := createTestClient(t, listener) tool, ok := NewGetClustersForCVETool(testClient).(*getClustersForCVETool) require.True(t, ok) @@ -149,15 +119,11 @@ func TestClusterHandle_MissingCVE(t *testing.T) { } func TestClusterHandle_EmptyResults(t *testing.T) { - mockService := &mockClustersService{ - clusters: []*storage.Cluster{}, - } + mockService := mock.NewClustersServiceMock([]*storage.Cluster{}, nil) - grpcServer, listener := setupMockDeploymentServer(&mockDeploymentService{}) + grpcServer, listener := mock.SetupClusterServer(mockService) defer grpcServer.Stop() - v1.RegisterClustersServiceServer(grpcServer, mockService) - testClient := createTestClient(t, listener) tool, ok := NewGetClustersForCVETool(testClient).(*getClustersForCVETool) require.True(t, ok) @@ -177,15 +143,11 @@ func TestClusterHandle_EmptyResults(t *testing.T) { } func TestClusterHandle_GetClustersError(t *testing.T) { - mockService := &mockClustersService{ - err: status.Error(codes.Internal, "database error"), - } + mockService := mock.NewClustersServiceMock([]*storage.Cluster{}, status.Error(codes.Internal, "database error")) - grpcServer, listener := setupMockDeploymentServer(&mockDeploymentService{}) + grpcServer, listener := mock.SetupClusterServer(mockService) defer grpcServer.Stop() - v1.RegisterClustersServiceServer(grpcServer, mockService) - testClient := createTestClient(t, listener) tool, ok := NewGetClustersForCVETool(testClient).(*getClustersForCVETool) require.True(t, ok) @@ -205,19 +167,18 @@ func TestClusterHandle_GetClustersError(t *testing.T) { } func TestClusterHandle_MultipleResults(t *testing.T) { - mockService := &mockClustersService{ - clusters: []*storage.Cluster{ + mockService := mock.NewClustersServiceMock( + []*storage.Cluster{ {Id: "cluster-z", Name: "Production"}, {Id: "cluster-a", Name: "Development"}, {Id: "cluster-m", Name: "Testing"}, }, - } + nil, + ) - grpcServer, listener := setupMockDeploymentServer(&mockDeploymentService{}) + grpcServer, listener := mock.SetupClusterServer(mockService) defer grpcServer.Stop() - v1.RegisterClustersServiceServer(grpcServer, mockService) - testClient := createTestClient(t, listener) tool, ok := NewGetClustersForCVETool(testClient).(*getClustersForCVETool) require.True(t, ok) @@ -246,23 +207,22 @@ func TestClusterHandle_MultipleResults(t *testing.T) { assert.Equal(t, "Production", output.Clusters[2].ClusterName) // Verify query was built correctly. - assert.Equal(t, `CVE:"CVE-2021-44228"`, mockService.lastCallQuery) + assert.Equal(t, `CVE:"CVE-2021-44228"`, mockService.GetLastCallQuery()) } func TestClusterHandle_Sorting(t *testing.T) { - mockService := &mockClustersService{ - clusters: []*storage.Cluster{ + mockService := mock.NewClustersServiceMock( + []*storage.Cluster{ {Id: "z-cluster", Name: "A"}, {Id: "a-cluster", Name: "Z"}, {Id: "m-cluster", Name: "M"}, }, - } + nil, + ) - grpcServer, listener := setupMockDeploymentServer(&mockDeploymentService{}) + grpcServer, listener := mock.SetupClusterServer(mockService) defer grpcServer.Stop() - v1.RegisterClustersServiceServer(grpcServer, mockService) - testClient := createTestClient(t, listener) tool, ok := NewGetClustersForCVETool(testClient).(*getClustersForCVETool) require.True(t, ok) @@ -294,17 +254,16 @@ func TestClusterHandle_Sorting(t *testing.T) { } func TestClusterHandle_WithFilters(t *testing.T) { - mockService := &mockClustersService{ - clusters: []*storage.Cluster{ + mockService := mock.NewClustersServiceMock( + []*storage.Cluster{ {Id: "cluster-1", Name: "C1"}, }, - } + nil, + ) - grpcServer, listener := setupMockDeploymentServer(&mockDeploymentService{}) + grpcServer, listener := mock.SetupClusterServer(mockService) defer grpcServer.Stop() - v1.RegisterClustersServiceServer(grpcServer, mockService) - tool, ok := NewGetClustersForCVETool(createTestClient(t, listener)).(*getClustersForCVETool) require.True(t, ok) @@ -333,7 +292,7 @@ func TestClusterHandle_WithFilters(t *testing.T) { require.NotNil(t, output) assert.Nil(t, result) assert.Len(t, output.Clusters, 1) - assert.Equal(t, testCase.expectedQuery, mockService.lastCallQuery) + assert.Equal(t, testCase.expectedQuery, mockService.GetLastCallQuery()) }) } } diff --git a/internal/toolsets/vulnerability/tools.go b/internal/toolsets/vulnerability/deployments.go similarity index 100% rename from internal/toolsets/vulnerability/tools.go rename to internal/toolsets/vulnerability/deployments.go diff --git a/internal/toolsets/vulnerability/tools_test.go b/internal/toolsets/vulnerability/deployments_test.go similarity index 75% rename from internal/toolsets/vulnerability/tools_test.go rename to internal/toolsets/vulnerability/deployments_test.go index af6dd4c..4121135 100644 --- a/internal/toolsets/vulnerability/tools_test.go +++ b/internal/toolsets/vulnerability/deployments_test.go @@ -4,8 +4,6 @@ import ( "context" "fmt" "net" - "strings" - "sync" "testing" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -14,6 +12,7 @@ import ( "github.com/stackrox/stackrox-mcp/internal/client" "github.com/stackrox/stackrox-mcp/internal/config" "github.com/stackrox/stackrox-mcp/internal/cursor" + "github.com/stackrox/stackrox-mcp/internal/toolsets/mock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" @@ -104,118 +103,6 @@ func TestInputValidate(t *testing.T) { } } -// Mock infrastructure for gRPC testing. - -// mockDeploymentService implements v1.DeploymentServiceServer for testing. -type mockDeploymentService struct { - v1.UnimplementedDeploymentServiceServer - - deployments []*storage.ListDeployment - err error - - lastCallQuery string - lastCallLimit int32 - lastCallOffset int32 -} - -func (m *mockDeploymentService) ListDeployments( - _ context.Context, - query *v1.RawQuery, -) (*v1.ListDeploymentsResponse, error) { - m.lastCallQuery = query.GetQuery() - m.lastCallLimit = query.GetPagination().GetLimit() - m.lastCallOffset = query.GetPagination().GetOffset() - - if m.err != nil { - return nil, m.err - } - - return &v1.ListDeploymentsResponse{ - Deployments: m.deployments, - }, nil -} - -// mockImageService implements v1.ImageServiceServer for testing. -type mockImageService struct { - v1.UnimplementedImageServiceServer - - images map[string][]*storage.ListImage // keyed by deploymentID - err error - - // We are requesting images in parallel requests. - lock sync.Mutex - lastCallQuery string - lastCallLimit int32 - callCount int -} - -func (m *mockImageService) ListImages( - _ context.Context, - query *v1.RawQuery, -) (*v1.ListImagesResponse, error) { - m.lock.Lock() - defer m.lock.Unlock() - - m.callCount++ - m.lastCallQuery = query.GetQuery() - m.lastCallLimit = query.GetPagination().GetLimit() - - if m.err != nil { - return nil, m.err - } - - // Extract deployment ID from query. - // Query format: CVE:"CVE-2021-44228"+Deployment ID:"dep-1" - deploymentID := extractDeploymentIDFromQuery(query.GetQuery()) - - return &v1.ListImagesResponse{ - Images: m.images[deploymentID], - }, nil -} - -// extractDeploymentIDFromQuery extracts deployment ID from the query string. -func extractDeploymentIDFromQuery(query string) string { - const prefix = "Deployment ID:\"" - - start := strings.Index(query, prefix) - if start == -1 { - return "" - } - - start += len(prefix) - - end := strings.Index(query[start:], "\"") - if end == -1 { - return "" - } - - return query[start : start+end] -} - -// setupMockServer creates an in-memory gRPC server with both deployment and image services. -func setupMockServer( - deploymentService *mockDeploymentService, - imageService *mockImageService, -) (*grpc.Server, *bufconn.Listener) { - buffer := 1024 * 1024 - listener := bufconn.Listen(buffer) - - grpcServer := grpc.NewServer() - v1.RegisterDeploymentServiceServer(grpcServer, deploymentService) - v1.RegisterImageServiceServer(grpcServer, imageService) - - go func() { - _ = grpcServer.Serve(listener) - }() - - return grpcServer, listener -} - -// setupMockDeploymentServer creates an in-memory gRPC server using bufconn. -func setupMockDeploymentServer(mockService *mockDeploymentService) (*grpc.Server, *bufconn.Listener) { - return setupMockServer(mockService, &mockImageService{}) -} - // bufDialer creates a dialer function for bufconn. func bufDialer(listener *bufconn.Listener) func(context.Context, string) (net.Conn, error) { return func(_ context.Context, _ string) (net.Conn, error) { @@ -264,11 +151,12 @@ func getTestDeployments(totalDeployments int) []*storage.ListDeployment { // Integration tests for handle method. func TestHandle_MissingCVE(t *testing.T) { - mockService := &mockDeploymentService{ - deployments: []*storage.ListDeployment{}, - } + mockService := mock.NewDeploymentServiceMock( + []*storage.ListDeployment{}, + nil, + ) - grpcServer, listener := setupMockDeploymentServer(mockService) + grpcServer, listener := mock.SetupDeploymentServer(mockService) defer grpcServer.Stop() testClient := createTestClient(t, listener) @@ -288,11 +176,12 @@ func TestHandle_MissingCVE(t *testing.T) { } func TestHandle_WithPagination(t *testing.T) { - mockService := &mockDeploymentService{ - deployments: getTestDeployments(defaultLimit + 1), - } + mockService := mock.NewDeploymentServiceMock( + getTestDeployments(defaultLimit+1), + nil, + ) - grpcServer, listener := setupMockDeploymentServer(mockService) + grpcServer, listener := mock.SetupDeploymentServer(mockService) defer grpcServer.Stop() testClient := createTestClient(t, listener) @@ -319,8 +208,8 @@ func TestHandle_WithPagination(t *testing.T) { assert.Nil(t, result) assert.Len(t, output.Deployments, defaultLimit) - assert.Equal(t, int32(2), mockService.lastCallOffset) - assert.Equal(t, int32(defaultLimit+1), mockService.lastCallLimit) + assert.Equal(t, int32(2), mockService.GetLastCallOffset()) + assert.Equal(t, int32(defaultLimit+1), mockService.GetLastCallLimit()) nextCursor := currCursor.GetNextCursor(defaultLimit) returnedCursor, err := cursor.Decode(output.NextCursor) @@ -329,11 +218,12 @@ func TestHandle_WithPagination(t *testing.T) { } func TestHandle_EmptyResults(t *testing.T) { - mockService := &mockDeploymentService{ - deployments: []*storage.ListDeployment{}, - } + mockService := mock.NewDeploymentServiceMock( + []*storage.ListDeployment{}, + nil, + ) - grpcServer, listener := setupMockDeploymentServer(mockService) + grpcServer, listener := mock.SetupDeploymentServer(mockService) defer grpcServer.Stop() testClient := createTestClient(t, listener) @@ -355,11 +245,12 @@ func TestHandle_EmptyResults(t *testing.T) { } func TestHandle_ListDeploymentsError(t *testing.T) { - mockService := &mockDeploymentService{ - err: status.Error(codes.Internal, "database error"), - } + mockService := mock.NewDeploymentServiceMock( + []*storage.ListDeployment{}, + status.Error(codes.Internal, "database error"), + ) - grpcServer, listener := setupMockDeploymentServer(mockService) + grpcServer, listener := mock.SetupDeploymentServer(mockService) defer grpcServer.Stop() testClient := createTestClient(t, listener) @@ -381,9 +272,9 @@ func TestHandle_ListDeploymentsError(t *testing.T) { } func TestHandle_WithFilters(t *testing.T) { - mockService := &mockDeploymentService{deployments: getTestDeployments(1)} + mockService := mock.NewDeploymentServiceMock(getTestDeployments(1), nil) - grpcServer, listener := setupMockDeploymentServer(mockService) + grpcServer, listener := mock.SetupDeploymentServer(mockService) defer grpcServer.Stop() tool, ok := NewGetDeploymentsForCVETool(createTestClient(t, listener)).(*getDeploymentsForCVETool) @@ -438,7 +329,7 @@ func TestHandle_WithFilters(t *testing.T) { require.NotNil(t, output) assert.Nil(t, result) assert.Len(t, output.Deployments, 1) - assert.Equal(t, testCase.expectedQuery, mockService.lastCallQuery) + assert.Equal(t, testCase.expectedQuery, mockService.GetLastCallQuery()) }) } } @@ -515,15 +406,15 @@ func TestHandle_WithIncludeAffectedImages(t *testing.T) { for testName, testCase := range tests { t.Run(testName, func(t *testing.T) { - deploymentService := &mockDeploymentService{ - deployments: testCase.deployments, - } - imageService := &mockImageService{ - images: testCase.imagesByDeployment, - err: testCase.imageServiceError, - } - - grpcServer, listener := setupMockServer(deploymentService, imageService) + deploymentService := mock.NewDeploymentServiceMock(testCase.deployments, nil) + imageService := mock.NewImageServiceMock(testCase.imagesByDeployment, testCase.imageServiceError) + + grpcServer, listener := mock.SetupAPIServer( + deploymentService, + imageService, + v1.UnimplementedNodeServiceServer{}, + v1.UnimplementedClustersServiceServer{}, + ) defer grpcServer.Stop() testClient := createTestClient(t, listener) @@ -572,17 +463,23 @@ func TestHandle_ImageFetchPartialFailure(t *testing.T) { } // Create a mock service that returns error for dep-2. - imageService := &mockImageService{ - images: map[string][]*storage.ListImage{ + imageService := mock.NewImageServiceMock( + map[string][]*storage.ListImage{ "dep-1": {{Name: "nginx:1.19"}}, // dep-2 will be missing - which the mock will treat as empty }, - } + nil, + ) // Simpler approach: use the existing mock but verify the structure allows errors. - deploymentService := &mockDeploymentService{deployments: deployments} - grpcServer, listener := setupMockServer(deploymentService, imageService) + deploymentService := mock.NewDeploymentServiceMock(deployments, nil) + grpcServer, listener := mock.SetupAPIServer( + deploymentService, + imageService, + v1.UnimplementedNodeServiceServer{}, + v1.UnimplementedClustersServiceServer{}, + ) defer grpcServer.Stop() testClient := createTestClient(t, listener) diff --git a/internal/toolsets/vulnerability/nodes_test.go b/internal/toolsets/vulnerability/nodes_test.go index 1cd39ed..49cd8b6 100644 --- a/internal/toolsets/vulnerability/nodes_test.go +++ b/internal/toolsets/vulnerability/nodes_test.go @@ -5,13 +5,11 @@ import ( "testing" "github.com/modelcontextprotocol/go-sdk/mcp" - "github.com/pkg/errors" - v1 "github.com/stackrox/rox/generated/api/v1" "github.com/stackrox/rox/generated/storage" "github.com/stackrox/stackrox-mcp/internal/client" + "github.com/stackrox/stackrox-mcp/internal/toolsets/mock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -97,50 +95,16 @@ func TestNodeInputValidate(t *testing.T) { } } -// Mock infrastructure for gRPC testing. - -// mockNodeService implements v1.NodeServiceServer for testing. -type mockNodeService struct { - v1.UnimplementedNodeServiceServer - - nodes []*storage.Node - err error - - lastCallQuery string -} - -func (m *mockNodeService) ExportNodes( - req *v1.ExportNodeRequest, - stream grpc.ServerStreamingServer[v1.ExportNodeResponse], -) error { - m.lastCallQuery = req.GetQuery() - - if m.err != nil { - return m.err - } - - // Send all nodes through the stream. - for _, node := range m.nodes { - resp := &v1.ExportNodeResponse{Node: node} - if err := stream.Send(resp); err != nil { - return errors.Wrap(err, "sending node over stream failed") - } - } - - return nil -} - // Integration tests for handle method. func TestNodeHandle_MissingCVE(t *testing.T) { - mockService := &mockNodeService{ - nodes: []*storage.Node{}, - } + mockService := mock.NewNodeServiceMock( + []*storage.Node{}, + nil, + ) - grpcServer, listener := setupMockDeploymentServer(&mockDeploymentService{}) + grpcServer, listener := mock.SetupNodeServer(mockService) defer grpcServer.Stop() - v1.RegisterNodeServiceServer(grpcServer, mockService) - testClient := createTestClient(t, listener) tool, ok := NewGetNodesForCVETool(testClient).(*getNodesForCVETool) require.True(t, ok) @@ -158,15 +122,14 @@ func TestNodeHandle_MissingCVE(t *testing.T) { } func TestNodeHandle_EmptyResults(t *testing.T) { - mockService := &mockNodeService{ - nodes: []*storage.Node{}, - } + mockService := mock.NewNodeServiceMock( + []*storage.Node{}, + nil, + ) - grpcServer, listener := setupMockDeploymentServer(&mockDeploymentService{}) + grpcServer, listener := mock.SetupNodeServer(mockService) defer grpcServer.Stop() - v1.RegisterNodeServiceServer(grpcServer, mockService) - testClient := createTestClient(t, listener) tool, ok := NewGetNodesForCVETool(testClient).(*getNodesForCVETool) require.True(t, ok) @@ -186,15 +149,14 @@ func TestNodeHandle_EmptyResults(t *testing.T) { } func TestNodeHandle_ExportNodesError(t *testing.T) { - mockService := &mockNodeService{ - err: status.Error(codes.Internal, "database error"), - } + mockService := mock.NewNodeServiceMock( + []*storage.Node{}, + status.Error(codes.Internal, "database error"), + ) - grpcServer, listener := setupMockDeploymentServer(&mockDeploymentService{}) + grpcServer, listener := mock.SetupNodeServer(mockService) defer grpcServer.Stop() - v1.RegisterNodeServiceServer(grpcServer, mockService) - testClient := createTestClient(t, listener) tool, ok := NewGetNodesForCVETool(testClient).(*getNodesForCVETool) require.True(t, ok) @@ -214,8 +176,8 @@ func TestNodeHandle_ExportNodesError(t *testing.T) { } func TestNodeHandle_Aggregation(t *testing.T) { - mockService := &mockNodeService{ - nodes: []*storage.Node{ + mockService := mock.NewNodeServiceMock( + []*storage.Node{ {Name: "n1", ClusterId: "c1", ClusterName: "Prod", OsImage: "Ubuntu 20.04"}, {Name: "n2", ClusterId: "c1", ClusterName: "Prod", OsImage: "Ubuntu 20.04"}, {Name: "n3", ClusterId: "c1", ClusterName: "Prod", OsImage: "Ubuntu 22.04"}, @@ -223,13 +185,12 @@ func TestNodeHandle_Aggregation(t *testing.T) { {Name: "n5", ClusterId: "c2", ClusterName: "Dev", OsImage: "Ubuntu 20.04"}, {Name: "n6", ClusterId: "c1", ClusterName: "Prod", OsImage: "Ubuntu 20.04"}, }, - } + nil, + ) - grpcServer, listener := setupMockDeploymentServer(&mockDeploymentService{}) + grpcServer, listener := mock.SetupNodeServer(mockService) defer grpcServer.Stop() - v1.RegisterNodeServiceServer(grpcServer, mockService) - testClient := createTestClient(t, listener) tool, ok := NewGetNodesForCVETool(testClient).(*getNodesForCVETool) require.True(t, ok) @@ -272,20 +233,19 @@ func TestNodeHandle_Aggregation(t *testing.T) { } func TestNodeHandle_Sorting(t *testing.T) { - mockService := &mockNodeService{ - nodes: []*storage.Node{ + mockService := mock.NewNodeServiceMock( + []*storage.Node{ {Name: "n1", ClusterId: "z-cluster", ClusterName: "A", OsImage: "Ubuntu 20.04"}, {Name: "n2", ClusterId: "a-cluster", ClusterName: "Z", OsImage: "RHEL 8"}, {Name: "n3", ClusterId: "a-cluster", ClusterName: "Z", OsImage: "CentOS 7"}, {Name: "n4", ClusterId: "z-cluster", ClusterName: "A", OsImage: ""}, }, - } + nil, + ) - grpcServer, listener := setupMockDeploymentServer(&mockDeploymentService{}) + grpcServer, listener := mock.SetupNodeServer(mockService) defer grpcServer.Stop() - v1.RegisterNodeServiceServer(grpcServer, mockService) - testClient := createTestClient(t, listener) tool, ok := NewGetNodesForCVETool(testClient).(*getNodesForCVETool) require.True(t, ok) @@ -324,17 +284,16 @@ func TestNodeHandle_Sorting(t *testing.T) { } func TestNodeHandle_WithFilters(t *testing.T) { - mockService := &mockNodeService{ - nodes: []*storage.Node{ + mockService := mock.NewNodeServiceMock( + []*storage.Node{ {Name: "n1", ClusterId: "cluster-1", ClusterName: "C1", OsImage: "Ubuntu 20.04"}, }, - } + nil, + ) - grpcServer, listener := setupMockDeploymentServer(&mockDeploymentService{}) + grpcServer, listener := mock.SetupNodeServer(mockService) defer grpcServer.Stop() - v1.RegisterNodeServiceServer(grpcServer, mockService) - tool, ok := NewGetNodesForCVETool(createTestClient(t, listener)).(*getNodesForCVETool) require.True(t, ok) @@ -363,7 +322,7 @@ func TestNodeHandle_WithFilters(t *testing.T) { require.NotNil(t, output) assert.Nil(t, result) assert.Len(t, output.NodeGroups, 1) - assert.Equal(t, testCase.expectedQuery, mockService.lastCallQuery) + assert.Equal(t, testCase.expectedQuery, mockService.GetLastCallQuery()) }) } }