Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 116 additions & 14 deletions internal/toolsets/vulnerability/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"strings"
"sync"

"github.com/google/jsonschema-go/jsonschema"
"github.com/modelcontextprotocol/go-sdk/mcp"
Expand All @@ -30,11 +31,12 @@ const (

// getDeploymentsForCVEInput defines the input parameters for get_deployments_for_cve tool.
type getDeploymentsForCVEInput struct {
CVEName string `json:"cveName"`
FilterClusterID string `json:"filterClusterId,omitempty"`
FilterNamespace string `json:"filterNamespace,omitempty"`
FilterPlatform filterPlatformType `json:"filterPlatform,omitempty"`
Cursor string `json:"cursor,omitempty"`
CVEName string `json:"cveName"`
FilterClusterID string `json:"filterClusterId,omitempty"`
FilterNamespace string `json:"filterNamespace,omitempty"`
FilterPlatform filterPlatformType `json:"filterPlatform,omitempty"`
IncludeAffectedImages bool `json:"includeAffectedImages,omitempty"`
Cursor string `json:"cursor,omitempty"`
}

func (input *getDeploymentsForCVEInput) validate() error {
Expand All @@ -45,12 +47,16 @@ func (input *getDeploymentsForCVEInput) validate() error {
return nil
}

// DeploymentResult contains deployment information.
// DeploymentResult contains deployment information with optional image data.
type DeploymentResult struct {
Name string `json:"name"`
Namespace string `json:"namespace"`
ClusterID string `json:"clusterId"`
ClusterName string `json:"clusterName"`
id string

Name string `json:"name"`
Namespace string `json:"namespace"`
ClusterID string `json:"clusterId"`
ClusterName string `json:"clusterName"`
AffectedImages []string `json:"affectedImages,omitempty"`
ImageFetchError string `json:"imageFetchError,omitempty"`
}

// getDeploymentsForCVEOutput defines the output structure for get_deployments_for_cve tool.
Expand Down Expand Up @@ -118,6 +124,11 @@ func getDeploymentsForCVEInputSchema() *jsonschema.Schema {
filterPlatformPlatform,
}

schema.Properties["includeAffectedImages"].Description =
"Whether to include affected image names for each deployment.\n" +
"WARNING: This may significantly increase response time."
schema.Properties["includeAffectedImages"].Default = toolsets.MustJSONMarshal(false)

schema.Properties["cursor"].Description = "Cursor for next page provided by server"

return schema
Expand Down Expand Up @@ -168,7 +179,84 @@ func getCursor(input *getDeploymentsForCVEInput) (*cursor.Cursor, error) {
return currCursor, nil
}

const defaultMaxFetchImageConcurrency = 10

// deploymentEnricher handles parallel enrichment of deployments with image data.
type deploymentEnricher struct {
imageClient v1.ImageServiceClient
cveName string
semaphore chan struct{}
wg sync.WaitGroup
}

// newDeploymentEnricher creates a new enricher with max concurrency limit.
func newDeploymentEnricher(
imageClient v1.ImageServiceClient,
cveName string,
maxConcurrency int,
) *deploymentEnricher {
return &deploymentEnricher{
imageClient: imageClient,
cveName: cveName,
semaphore: make(chan struct{}, maxConcurrency),
}
}

// enrich enriches a single deployment result with image data in a goroutine.
// Must be called before wait().
func (e *deploymentEnricher) enrich(
ctx context.Context,
deployment *DeploymentResult,
) {
e.wg.Go(func() {
e.semaphore <- struct{}{}

defer func() { <-e.semaphore }()

// Enrich the result in-place.
images, err := fetchImagesForDeployment(ctx, e.imageClient, deployment, e.cveName)
if err != nil {
deployment.ImageFetchError = err.Error()

return
}

deployment.AffectedImages = images
})
}

// wait waits for all enrichment workers to complete.
func (e *deploymentEnricher) wait() {
e.wg.Wait()
}

// fetchImagesForDeployment fetches images for a single deployment.
// It queries the images API filtered by CVE and Deployment ID.
func fetchImagesForDeployment(
ctx context.Context,
imageClient v1.ImageServiceClient,
deployment *DeploymentResult,
cveName string,
) ([]string, error) {
query := fmt.Sprintf("CVE:%q+Deployment ID:%q", cveName, deployment.id)

resp, err := imageClient.ListImages(ctx, &v1.RawQuery{Query: query})
if err != nil {
return nil, errors.Wrapf(err, "failed to fetch images for deployment %q in namespace %q",
deployment.Name, deployment.Namespace)
}

images := make([]string, 0, len(resp.GetImages()))
for _, img := range resp.GetImages() {
images = append(images, img.GetName())
}

return images, nil
}

// handle is the handler for get_deployments_for_cve tool.
//
//nolint:funlen
func (t *getDeploymentsForCVETool) handle(
ctx context.Context,
req *mcp.CallToolRequest,
Expand Down Expand Up @@ -205,14 +293,28 @@ func (t *getDeploymentsForCVETool) handle(
return nil, nil, client.NewError(err, "ListDeployments")
}

deployments := make([]DeploymentResult, 0, len(resp.GetDeployments()))
for _, deployment := range resp.GetDeployments() {
deployments = append(deployments, DeploymentResult{
rawDeployments := resp.GetDeployments()

deployments := make([]DeploymentResult, len(rawDeployments))
for i, deployment := range rawDeployments {
deployments[i] = DeploymentResult{
id: deployment.GetId(),
Name: deployment.GetName(),
Namespace: deployment.GetNamespace(),
ClusterID: deployment.GetClusterId(),
ClusterName: deployment.GetCluster(),
})
}
}

if input.IncludeAffectedImages {
imageClient := v1.NewImageServiceClient(conn)
enricher := newDeploymentEnricher(imageClient, input.CVEName, defaultMaxFetchImageConcurrency)

for i := range deployments {
enricher.enrich(callCtx, &deployments[i])
}

enricher.wait()
}

// We always fetch limit+1 - if we do not have one additional element we can end paging.
Expand Down
Loading
Loading