diff --git a/cmd/cone/flags.go b/cmd/cone/flags.go index e4ed674a..bddfb2ad 100644 --- a/cmd/cone/flags.go +++ b/cmd/cone/flags.go @@ -21,6 +21,7 @@ const ( rawTokenFlag = "raw" appDisplayNameFlag = "app" showEncryptedFlag = "show-encrypted" + formDataFlag = "form-data" ) func addWaitFlag(cmd *cobra.Command) { @@ -83,3 +84,7 @@ func addAppDisplayNameFlag(cmd *cobra.Command) { func addShowEncryptedFlag(cmd *cobra.Command) { cmd.Flags().Bool(showEncryptedFlag, false, "Show credentials we could not decrypt.") } + +func addFormDataFlag(cmd *cobra.Command) { + cmd.Flags().String(formDataFlag, "", `Form field data as JSON (e.g., '{"field1":"value1","field2":"value2"}'). Required fields will be prompted interactively if not provided.`) +} diff --git a/cmd/cone/form_fields.go b/cmd/cone/form_fields.go new file mode 100644 index 00000000..97ffd45d --- /dev/null +++ b/cmd/cone/form_fields.go @@ -0,0 +1,468 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "strings" + + "github.com/pterm/pterm" + "github.com/spf13/viper" + + "github.com/conductorone/conductorone-sdk-go/pkg/models/shared" + + "github.com/conductorone/cone/pkg/client" + "github.com/conductorone/cone/pkg/logging" + "github.com/conductorone/cone/pkg/output" +) + +// collectFormFields collects form field values from the user based on the form definition. +// Returns a map of field names to their values, or nil if no form fields are present. +func collectFormFields(ctx context.Context, v *viper.Viper, form *shared.FormInput) (map[string]any, error) { + if form == nil || len(form.Fields) == 0 { + return nil, nil + } + + requestData := make(map[string]any) + isNonInteractive := v.GetBool(nonInteractiveFlag) + + // Collect form data from command-line flags if provided + formDataFlagValue := v.GetString(formDataFlag) + formDataMap, err := parseFormDataFlag(formDataFlagValue) + if err != nil { + return nil, err + } + + for _, field := range form.Fields { + fieldName := client.StringFromPtr(field.Name) + if fieldName == "" { + continue + } + + displayName := client.StringFromPtr(field.DisplayName) + if displayName == "" { + displayName = fieldName + } + + description := client.StringFromPtr(field.Description) + + // Check if value was provided via flag + if val, ok := formDataMap[fieldName]; ok { + requestData[fieldName] = val + continue + } + + // Skip if non-interactive and no flag value provided + if isNonInteractive { + // Use default value if available + if defaultValue := getFieldDefaultValue(field); defaultValue != nil { + requestData[fieldName] = defaultValue + } + continue + } + + // Collect value interactively + value, err := collectFieldValue(ctx, field, displayName, description) + if err != nil { + return nil, fmt.Errorf("error collecting field %s: %w", fieldName, err) + } + + if value != nil { + requestData[fieldName] = value + } + } + + if len(requestData) == 0 { + return nil, nil + } + + return requestData, nil +} + +// collectFieldValue collects a single field value from the user based on field type. +func collectFieldValue(ctx context.Context, field shared.Field, displayName, description string) (any, error) { + // Check for default value first + if defaultValue := getFieldDefaultValue(field); defaultValue != nil { + // Show default value and ask for confirmation + pterm.Info.Printf("Field '%s' has default value: %v\n", displayName, defaultValue) + if description != "" { + pterm.Println(description) + } + useDefault, err := pterm.DefaultInteractiveConfirm.Show("Use default value?") + if err != nil { + return nil, err + } + if useDefault { + return defaultValue, nil + } + } + + // Collect based on field type + switch { + case field.StringField != nil: + return collectStringField(ctx, field.StringField, displayName, description) + case field.BoolField != nil: + return collectBoolField(ctx, field.BoolField, displayName, description) + case field.Int64Field != nil: + return collectInt64Field(ctx, field.Int64Field, displayName, description) + case field.StringSliceField != nil: + return collectStringSliceField(ctx, field.StringSliceField, displayName, description) + default: + // Unknown field type - warn and skip to avoid breaking on new field types + logging.Warnf("Skipping field '%s': unsupported field type. You may need to update cone to handle this field.", displayName) + return nil, nil + } +} + +// collectStringField collects a string field value with validation. +func collectStringField(ctx context.Context, field *shared.StringField, displayName, description string) (string, error) { + validator := StringFieldValidator{ + field: field, + displayName: displayName, + description: description, + } + + defaultValue := "" + if field.DefaultValue != nil { + defaultValue = *field.DefaultValue + } + + value, err := output.GetValidInput(ctx, defaultValue, validator) + if err != nil { + return "", err + } + + return value, nil +} + +// collectBoolField collects a boolean field value. +func collectBoolField(ctx context.Context, field *shared.BoolField, displayName, description string) (bool, error) { + select { + case <-ctx.Done(): + return false, ctx.Err() + default: + } + + if description != "" { + pterm.Info.Println(description) + } + + prompt := fmt.Sprintf("Enter value for '%s' (true/false)", displayName) + if field.DefaultValue != nil { + prompt = fmt.Sprintf("Enter value for '%s' (true/false, default: %v)", displayName, *field.DefaultValue) + } + + result, err := pterm.DefaultInteractiveConfirm.Show(prompt) + if err != nil { + return false, err + } + + return result, nil +} + +// collectInt64Field collects an int64 field value with validation. +func collectInt64Field(ctx context.Context, field *shared.Int64Field, displayName, description string) (int64, error) { + validator := Int64FieldValidator{ + field: field, + displayName: displayName, + description: description, + } + + defaultValue := "" + if field.DefaultValue != nil { + defaultValue = strconv.FormatInt(*field.DefaultValue, 10) + } + + value, err := output.GetValidInput(ctx, defaultValue, validator) + if err != nil { + return 0, err + } + + return value, nil +} + +// collectStringSliceField collects a string slice field value using a multi-entry loop. +// User enters one value per line, empty line finishes input. +func collectStringSliceField(ctx context.Context, field *shared.StringSliceField, displayName, description string) ([]string, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + if description != "" { + pterm.Info.Println(description) + } + + if len(field.DefaultValues) > 0 { + pterm.Info.Printf("Default values for '%s': %s\n", displayName, strings.Join(field.DefaultValues, ", ")) + pterm.Info.Println("Press enter with no input to use defaults, or enter new values below.") + } + + pterm.Info.Printf("Enter values for '%s' (one per line, empty line to finish):\n", displayName) + + var result []string + userInput := pterm.DefaultInteractiveTextInput.WithMultiLine(false) + + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + input, err := userInput.Show(fmt.Sprintf(" [%d]", len(result)+1)) + if err != nil { + return nil, err + } + + trimmed := strings.TrimSpace(input) + if trimmed == "" { + // Empty line ends input + break + } + + result = append(result, trimmed) + } + + // If no values entered and defaults exist, use defaults + if len(result) == 0 && len(field.DefaultValues) > 0 { + return field.DefaultValues, nil + } + + return result, nil +} + +// getFieldDefaultValue extracts the default value from a field based on its type. +func getFieldDefaultValue(field shared.Field) any { + switch { + case field.StringField != nil && field.StringField.DefaultValue != nil: + return *field.StringField.DefaultValue + case field.BoolField != nil && field.BoolField.DefaultValue != nil: + return *field.BoolField.DefaultValue + case field.Int64Field != nil && field.Int64Field.DefaultValue != nil: + return *field.Int64Field.DefaultValue + case field.StringSliceField != nil: + if len(field.StringSliceField.DefaultValues) > 0 { + return field.StringSliceField.DefaultValues + } + return nil + default: + return nil + } +} + +// parseFormDataFlag parses the --form-data flag value as JSON. +// Expected format: '{"field1":"value1","field2":"value2"}'. +func parseFormDataFlag(formDataFlag string) (map[string]any, error) { + if formDataFlag == "" { + return nil, nil + } + + result := make(map[string]any) + if err := json.Unmarshal([]byte(formDataFlag), &result); err != nil { + return nil, fmt.Errorf("invalid JSON in --form-data flag: %w", err) + } + + return result, nil +} + +// StringFieldValidator validates string field input. +type StringFieldValidator struct { + field *shared.StringField + displayName string + description string +} + +func (v StringFieldValidator) IsValid(txt string) (string, bool) { + if txt == "" { + // Check if field is required + if v.field.StringRules != nil { + // StringRules might have required field, but we'll be lenient here + // and allow empty if no explicit requirement + return txt, true + } + return txt, true + } + + // Apply validation rules if present + if v.field.StringRules != nil { + rules := v.field.StringRules + if rules.MinLen != nil { + minLen, err := strconv.Atoi(*rules.MinLen) + if err == nil && len(txt) < minLen { + return txt, false + } + } + if rules.MaxLen != nil { + maxLen, err := strconv.Atoi(*rules.MaxLen) + if err == nil && len(txt) > maxLen { + return txt, false + } + } + // Additional validations (email, URI, etc.) could be added here + } + + return txt, true +} + +func (v StringFieldValidator) Prompt(isFirstRun bool) { + if isFirstRun { + if v.description != "" { + pterm.Info.Println(v.description) + } + prompt := fmt.Sprintf("Enter value for '%s'", v.displayName) + if v.field.Placeholder != nil { + prompt = fmt.Sprintf("%s (e.g., %s)", prompt, *v.field.Placeholder) + } + if v.field.DefaultValue != nil { + prompt = fmt.Sprintf("%s (default: %s)", prompt, *v.field.DefaultValue) + } + output.InputNeeded.Println(prompt) + } else { + output.InputNeeded.Println("Invalid input. Please try again.") + } +} + +// Int64FieldValidator validates int64 field input. +type Int64FieldValidator struct { + field *shared.Int64Field + displayName string + description string +} + +func (v Int64FieldValidator) IsValid(txt string) (int64, bool) { + if txt == "" { + // Empty is invalid - default values are handled by passing them as initial value + return 0, false + } + + value, err := strconv.ParseInt(txt, 10, 64) + if err != nil { + return 0, false + } + + // Apply validation rules if present + if v.field.Int64Rules != nil { + rules := v.field.Int64Rules + if rules.Const != nil && value != *rules.Const { + return 0, false + } + if rules.Lt != nil && value >= *rules.Lt { + return 0, false + } + if rules.Lte != nil && value > *rules.Lte { + return 0, false + } + if rules.Gt != nil && value <= *rules.Gt { + return 0, false + } + if rules.Gte != nil && value < *rules.Gte { + return 0, false + } + } + + return value, true +} + +func (v Int64FieldValidator) Prompt(isFirstRun bool) { + if isFirstRun { + if v.description != "" { + pterm.Info.Println(v.description) + } + prompt := fmt.Sprintf("Enter integer value for '%s'", v.displayName) + if v.field.Placeholder != nil { + prompt = fmt.Sprintf("%s (e.g., %s)", prompt, *v.field.Placeholder) + } + if v.field.DefaultValue != nil { + prompt = fmt.Sprintf("%s (default: %d)", prompt, *v.field.DefaultValue) + } + output.InputNeeded.Println(prompt) + } else { + output.InputNeeded.Println("Invalid integer input. Please try again.") + } +} + +// isFieldRequired checks if a field is required based on its validation rules. +func isFieldRequired(field shared.Field) bool { + switch { + case field.StringField != nil: + rules := field.StringField.StringRules + if rules == nil { + return false + } + // Field is required if IgnoreEmpty is not explicitly true and MinLen >= 1 + ignoreEmpty := rules.IgnoreEmpty != nil && *rules.IgnoreEmpty + if ignoreEmpty { + return false + } + if rules.MinLen != nil { + minLen, err := strconv.Atoi(*rules.MinLen) + if err == nil && minLen >= 1 { + return true + } + } + return false + case field.Int64Field != nil: + // Int64 fields don't have an IgnoreEmpty concept in the same way + // Consider required if there are validation rules but no default + return field.Int64Field.Int64Rules != nil && field.Int64Field.DefaultValue == nil + case field.BoolField != nil: + // Bool fields typically have a default (false), so rarely required + return false + case field.StringSliceField != nil: + // String slice fields - check if there are rules requiring items + return false + default: + return false + } +} + +// validateFormData validates that all required form fields are present. +func validateFormData(form *shared.FormInput, requestData map[string]any) error { + if form == nil { + return nil + } + + for _, field := range form.Fields { + fieldName := client.StringFromPtr(field.Name) + if fieldName == "" { + continue + } + + displayName := client.StringFromPtr(field.DisplayName) + if displayName == "" { + displayName = fieldName + } + + // Check if field is required + if !isFieldRequired(field) { + continue + } + + // Check if value was provided + val, hasValue := requestData[fieldName] + if !hasValue { + // Check if there's a default value + if getFieldDefaultValue(field) == nil { + return fmt.Errorf("required field '%s' is missing", displayName) + } + continue + } + + // Check if value is empty + switch v := val.(type) { + case string: + if v == "" { + return fmt.Errorf("required field '%s' cannot be empty", displayName) + } + case []string: + if len(v) == 0 { + return fmt.Errorf("required field '%s' cannot be empty", displayName) + } + } + } + + return nil +} diff --git a/cmd/cone/get_drop_task.go b/cmd/cone/get_drop_task.go index 46096e9a..c51c0c5f 100644 --- a/cmd/cone/get_drop_task.go +++ b/cmd/cone/get_drop_task.go @@ -15,6 +15,7 @@ import ( "github.com/conductorone/conductorone-sdk-go/pkg/models/shared" "github.com/conductorone/cone/pkg/client" + "github.com/conductorone/cone/pkg/logging" "github.com/conductorone/cone/pkg/output" ) @@ -29,7 +30,17 @@ func getCmd() *cobra.Command { cmd := &cobra.Command{ Use: "get [flags]\n cone get --query [flags]\n cone get --app-id --entitlement-id [flags]", Short: "Create an access request for an entitlement by alias", - RunE: runGet, + Long: `Create an access request for an entitlement by alias, query, or explicit app/entitlement IDs. + +Some entitlements may require custom form fields to be filled out when making an access request. +If form fields are required, you will be prompted interactively to provide them, or you can +provide them via the --form-data flag as JSON. + +Examples: + cone get my-entitlement-alias + cone get --query "GitHub Admin" --justification "Need admin access" + cone get --app-id app123 --entitlement-id ent456 --form-data '{"reason":"project-work","duration":"2w"}'`, + RunE: runGet, } addGrantDurationFlag(cmd) addEmergencyAccessFlag(cmd) @@ -54,6 +65,7 @@ func taskCmd(cmd *cobra.Command) *cobra.Command { addEntitlementAliasFlag(cmd) addForceTaskCreateFlag(cmd) addEntitlementDetailsFlag(cmd) + addFormDataFlag(cmd) return cmd } @@ -231,7 +243,21 @@ func runGet(cmd *cobra.Command, args []string) error { apiDuration = fmt.Sprintf("%ds", seconds) } - accessRequest, err := c.CreateGrantTask(ctx, appId, entitlementId, userId, appUserId, justification, apiDuration, emergencyAccess) + // Collect form data if provided via flags + var requestData map[string]any + formDataFlagValue := v.GetString(formDataFlag) + + // Only send requestData if user explicitly provided form data + if formDataFlagValue != "" { + var err error + requestData, err = parseFormDataFlag(formDataFlagValue) + if err != nil { + return nil, err + } + } + + // Create the task with initial form data (if any) + accessRequest, err := c.CreateGrantTask(ctx, appId, entitlementId, userId, appUserId, justification, apiDuration, emergencyAccess, requestData) if err != nil { errorBody := err.Error() if strings.Contains(errorBody, durationErrorMessage) { @@ -241,7 +267,43 @@ func runGet(cmd *cobra.Command, args []string) error { } return nil, err } - return accessRequest.TaskView.Task, nil + + task := accessRequest.TaskView.Task + + // Check if the task has form fields + hasFormFields := task.Form != nil + if hasFormFields { + hasFormFields = len(task.Form.Fields) > 0 + } + + if hasFormFields { + // Collect form fields if not already provided + if len(requestData) == 0 { + collectedData, err := collectFormFields(ctx, v, task.Form) + if err != nil { + return nil, fmt.Errorf("error collecting form fields: %w", err) + } + if len(collectedData) > 0 { + // Update the task with the collected form data + taskID := client.StringFromPtr(task.ID) + _, err := c.UpdateTaskRequestData(ctx, taskID, collectedData) + if err != nil { + return nil, fmt.Errorf("error updating task with form data: %w", err) + } + } + } else { + // Validate that provided form data matches the form structure + if err := validateFormData(task.Form, requestData); err != nil { + pterm.Warning.Printf("Form data validation warning: %v\n", err) + } + } + } else if formDataFlagValue != "" { + // Form data was provided but task doesn't have form fields + // The data was already sent on task creation and will be ignored by the API + logging.Debugf("Form data was provided via --form-data flag, but this entitlement does not require form fields. The data was sent but may be ignored by the API.") + } + + return task, nil }) } diff --git a/cmd/cone/main.go b/cmd/cone/main.go index 1bfa7425..dc264b2e 100644 --- a/cmd/cone/main.go +++ b/cmd/cone/main.go @@ -8,8 +8,10 @@ import ( "syscall" "github.com/spf13/cobra" + "github.com/spf13/viper" "github.com/conductorone/cone/pkg/client" + "github.com/conductorone/cone/pkg/logging" ) var version = "dev" @@ -54,7 +56,8 @@ func runCli(ctx context.Context) int { cliCmd.PersistentFlags().String("client-secret", "", "Client secret") cliCmd.PersistentFlags().String("api-endpoint", "", "Override the API endpoint") cliCmd.PersistentFlags().StringP("output", "o", "table", "Output format. Valid values: table, json, json-pretty, wide.") - cliCmd.PersistentFlags().Bool("debug", false, "Enable debug logging") + cliCmd.PersistentFlags().Bool("debug", false, "Enable HTTP debug logging") + cliCmd.PersistentFlags().String("log-level", "", "Set log level (debug, info, warn, error)") err := initConfig(cliCmd) if err != nil { @@ -62,6 +65,11 @@ func runCli(ctx context.Context) int { return 1 } + // Initialize logging based on --log-level flag + if logLevel := viper.GetString("log-level"); logLevel != "" { + logging.Init(logging.Level(logLevel)) + } + cliCmd.AddCommand(getCmd()) cliCmd.AddCommand(dropCmd()) cliCmd.AddCommand(whoAmICmd()) diff --git a/pkg/client/client.go b/pkg/client/client.go index 1edc0447..fd1dccd5 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -76,6 +76,7 @@ type C1Client interface { justification string, duration string, emergencyAccess bool, + requestData map[string]any, ) (*shared.TaskServiceCreateGrantResponse, error) CreateRevokeTask( ctx context.Context, @@ -90,6 +91,7 @@ type C1Client interface { ApproveTask(ctx context.Context, taskId string, comment string, policyId string) (*shared.TaskActionsServiceApproveResponse, error) DenyTask(ctx context.Context, taskId string, comment string, policyId string) (*shared.TaskActionsServiceDenyResponse, error) EscalateTask(ctx context.Context, taskId string) (*shared.TaskServiceActionResponse, error) + UpdateTaskRequestData(ctx context.Context, taskID string, requestData map[string]any) (*shared.TaskServiceActionResponse, error) ListApps(ctx context.Context) ([]shared.App, error) ListAppUsers(ctx context.Context, appID string) ([]shared.AppUser, error) ListAppUsersForUser(ctx context.Context, appID string, userID string) ([]shared.AppUser, error) diff --git a/pkg/client/task.go b/pkg/client/task.go index 0a9f54fb..fd345b90 100644 --- a/pkg/client/task.go +++ b/pkg/client/task.go @@ -29,6 +29,7 @@ func (c *client) CreateGrantTask( justification string, duration string, emergencyAccess bool, + requestData map[string]any, ) (*shared.TaskServiceCreateGrantResponse, error) { req := shared.TaskServiceCreateGrantRequest{ AppEntitlementID: appEntitlementId, @@ -41,6 +42,9 @@ func (c *client) CreateGrantTask( if duration != "" { req.GrantDuration = &duration } + if len(requestData) > 0 { + req.RequestData = requestData + } resp, err := c.sdk.Task.CreateGrantTask(ctx, &req) if err != nil { return nil, err @@ -158,3 +162,22 @@ func (c *client) EscalateTask(ctx context.Context, taskID string) (*shared.TaskS } return resp.TaskServiceActionResponse, nil } + +func (c *client) UpdateTaskRequestData(ctx context.Context, taskID string, requestData map[string]any) (*shared.TaskServiceActionResponse, error) { + req := shared.TaskActionsServiceUpdateRequestDataRequest{} + if len(requestData) > 0 { + req.Data = requestData + } + resp, err := c.sdk.TaskActions.UpdateRequestData(ctx, operations.C1APITaskV1TaskActionsServiceUpdateRequestDataRequest{ + TaskActionsServiceUpdateRequestDataRequest: &req, + TaskID: taskID, + }) + if err != nil { + return nil, err + } + + if err := NewHTTPError(resp.RawResponse); err != nil { + return nil, err + } + return resp.TaskServiceActionResponse, nil +} diff --git a/pkg/logging/logging.go b/pkg/logging/logging.go new file mode 100644 index 00000000..d01da2ae --- /dev/null +++ b/pkg/logging/logging.go @@ -0,0 +1,118 @@ +package logging + +import ( + "sync" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +var ( + logger *zap.SugaredLogger + once sync.Once +) + +// Level represents log levels. +type Level string + +const ( + LevelDebug Level = "debug" + LevelInfo Level = "info" + LevelWarn Level = "warn" + LevelError Level = "error" +) + +// Init initializes the global logger with the specified level. +// This should be called once at application startup. +func Init(level Level) { + once.Do(func() { + logger = newLogger(level) + }) +} + +// Get returns the global logger. If Init hasn't been called, +// it returns a no-op logger. +func Get() *zap.SugaredLogger { + if logger == nil { + // Return a no-op logger if not initialized + return zap.NewNop().Sugar() + } + return logger +} + +func newLogger(level Level) *zap.SugaredLogger { + var zapLevel zapcore.Level + switch level { + case LevelDebug: + zapLevel = zapcore.DebugLevel + case LevelInfo: + zapLevel = zapcore.InfoLevel + case LevelWarn: + zapLevel = zapcore.WarnLevel + case LevelError: + zapLevel = zapcore.ErrorLevel + default: + zapLevel = zapcore.InfoLevel + } + + config := zap.Config{ + Level: zap.NewAtomicLevelAt(zapLevel), + Development: false, + Encoding: "console", + EncoderConfig: zap.NewDevelopmentEncoderConfig(), + OutputPaths: []string{"stderr"}, + ErrorOutputPaths: []string{"stderr"}, + } + + // Simplify the output format + config.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder + config.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder + + zapLogger, err := config.Build() + if err != nil { + // Fall back to nop logger on error + return zap.NewNop().Sugar() + } + + return zapLogger.Sugar() +} + +// Debug logs a debug message. +func Debug(args ...interface{}) { + Get().Debug(args...) +} + +// Debugf logs a formatted debug message. +func Debugf(template string, args ...interface{}) { + Get().Debugf(template, args...) +} + +// Info logs an info message. +func Info(args ...interface{}) { + Get().Info(args...) +} + +// Infof logs a formatted info message. +func Infof(template string, args ...interface{}) { + Get().Infof(template, args...) +} + +// Warn logs a warning message. +func Warn(args ...interface{}) { + Get().Warn(args...) +} + +// Warnf logs a formatted warning message. +func Warnf(template string, args ...interface{}) { + Get().Warnf(template, args...) +} + +// Error logs an error message. +func Error(args ...interface{}) { + Get().Error(args...) +} + +// Errorf logs a formatted error message. +func Errorf(template string, args ...interface{}) { + Get().Errorf(template, args...) +}