diff --git a/cmd/sup/main.go b/cmd/sup/main.go index 3a36928..08ac48a 100644 --- a/cmd/sup/main.go +++ b/cmd/sup/main.go @@ -3,6 +3,7 @@ package main import ( "flag" "fmt" + "io" "io/ioutil" "os" "os/user" @@ -12,20 +13,14 @@ import ( "text/tabwriter" "time" + "github.com/adammck/venv" "github.com/mikkeloscar/sshconfig" "github.com/pkg/errors" "github.com/pressly/sup" ) var ( - supfile string - envVars flagStringSlice - sshConfig string - onlyHosts string - exceptHosts string - - debug bool - disablePrefix bool + flags options showVersion bool showHelp bool @@ -49,158 +44,33 @@ func (f *flagStringSlice) Set(value string) error { return nil } -func init() { - flag.StringVar(&supfile, "f", "", "Custom path to ./Supfile[.yml]") - flag.Var(&envVars, "e", "Set environment variables") - flag.Var(&envVars, "env", "Set environment variables") - flag.StringVar(&sshConfig, "sshconfig", "", "Read SSH Config file, ie. ~/.ssh/config file") - flag.StringVar(&onlyHosts, "only", "", "Filter hosts using regexp") - flag.StringVar(&exceptHosts, "except", "", "Filter out hosts using regexp") - - flag.BoolVar(&debug, "D", false, "Enable debug mode") - flag.BoolVar(&debug, "debug", false, "Enable debug mode") - flag.BoolVar(&disablePrefix, "disable-prefix", false, "Disable hostname prefix") +type options struct { + dirname string + supfile string + envVars flagStringSlice + sshConfig string + onlyHosts string + exceptHosts string + debug bool + disablePrefix bool + env venv.Env +} +func init() { flag.BoolVar(&showVersion, "v", false, "Print version") flag.BoolVar(&showVersion, "version", false, "Print version") flag.BoolVar(&showHelp, "h", false, "Show help") flag.BoolVar(&showHelp, "help", false, "Show help") -} -func networkUsage(conf *sup.Supfile) { - w := &tabwriter.Writer{} - w.Init(os.Stderr, 4, 4, 2, ' ', 0) - defer w.Flush() - - // Print available networks/hosts. - fmt.Fprintln(w, "Networks:\t") - for _, name := range conf.Networks.Names { - fmt.Fprintf(w, "- %v\n", name) - network, _ := conf.Networks.Get(name) - for _, host := range network.Hosts { - fmt.Fprintf(w, "\t- %v\n", host) - } - } - fmt.Fprintln(w) -} - -func cmdUsage(conf *sup.Supfile) { - w := &tabwriter.Writer{} - w.Init(os.Stderr, 4, 4, 2, ' ', 0) - defer w.Flush() - - // Print available targets/commands. - fmt.Fprintln(w, "Targets:\t") - for _, name := range conf.Targets.Names { - cmds, _ := conf.Targets.Get(name) - fmt.Fprintf(w, "- %v\t%v\n", name, strings.Join(cmds, " ")) - } - fmt.Fprintln(w, "\t") - fmt.Fprintln(w, "Commands:\t") - for _, name := range conf.Commands.Names { - cmd, _ := conf.Commands.Get(name) - fmt.Fprintf(w, "- %v\t%v\n", name, cmd.Desc) - } - fmt.Fprintln(w) -} - -// parseArgs parses args and returns network and commands to be run. -// On error, it prints usage and exits. -func parseArgs(conf *sup.Supfile) (*sup.Network, []*sup.Command, error) { - var commands []*sup.Command - - args := flag.Args() - if len(args) < 1 { - networkUsage(conf) - return nil, nil, ErrUsage - } - - // Does the exist? - network, ok := conf.Networks.Get(args[0]) - if !ok { - networkUsage(conf) - return nil, nil, ErrUnknownNetwork - } - - hosts, err := network.ParseInventory() - if err != nil { - return nil, nil, err - } - network.Hosts = append(network.Hosts, hosts...) - - // Does the have at least one host? - if len(network.Hosts) == 0 { - networkUsage(conf) - return nil, nil, ErrNetworkNoHosts - } - - // Check for the second argument - if len(args) < 2 { - cmdUsage(conf) - return nil, nil, ErrUsage - } - - // In case of the network.Env needs an initialization - if network.Env == nil { - network.Env = make(sup.EnvList, 0) - } - - // Add default env variable with current network - network.Env.Set("SUP_NETWORK", args[0]) - - // Add default nonce - network.Env.Set("SUP_TIME", time.Now().UTC().Format(time.RFC3339)) - if os.Getenv("SUP_TIME") != "" { - network.Env.Set("SUP_TIME", os.Getenv("SUP_TIME")) - } - - // Add user - if os.Getenv("SUP_USER") != "" { - network.Env.Set("SUP_USER", os.Getenv("SUP_USER")) - } else { - network.Env.Set("SUP_USER", os.Getenv("USER")) - } - - for _, cmd := range args[1:] { - // Target? - target, isTarget := conf.Targets.Get(cmd) - if isTarget { - // Loop over target's commands. - for _, cmd := range target { - command, isCommand := conf.Commands.Get(cmd) - if !isCommand { - cmdUsage(conf) - return nil, nil, fmt.Errorf("%v: %v", ErrCmd, cmd) - } - command.Name = cmd - commands = append(commands, &command) - } - } - - // Command? - command, isCommand := conf.Commands.Get(cmd) - if isCommand { - command.Name = cmd - commands = append(commands, &command) - } - - if !isTarget && !isCommand { - cmdUsage(conf) - return nil, nil, fmt.Errorf("%v: %v", ErrCmd, cmd) - } - } - - return &network, commands, nil -} - -func resolvePath(path string) string { - if path[:2] == "~/" { - usr, err := user.Current() - if err == nil { - path = filepath.Join(usr.HomeDir, path[2:]) - } - } - return path + flag.StringVar(&flags.supfile, "f", "", "Custom path to ./Supfile[.yml]") + flag.Var(&flags.envVars, "e", "Set environment variables") + flag.Var(&flags.envVars, "env", "Set environment variables") + flag.StringVar(&flags.sshConfig, "sshconfig", "", "Read SSH Config file, ie. ~/.ssh/config file") + flag.StringVar(&flags.onlyHosts, "only", "", "Filter hosts using regexp") + flag.StringVar(&flags.exceptHosts, "except", "", "Filter out hosts using regexp") + flag.BoolVar(&flags.debug, "D", false, "Enable debug mode") + flag.BoolVar(&flags.debug, "debug", false, "Enable debug mode") + flag.BoolVar(&flags.disablePrefix, "disable-prefix", false, "Disable hostname prefix") } func main() { @@ -217,38 +87,31 @@ func main() { return } - if supfile == "" { - supfile = "./Supfile" - } - data, err := ioutil.ReadFile(resolvePath(supfile)) - if err != nil { - firstErr := err - data, err = ioutil.ReadFile("./Supfile.yml") // Alternative to ./Supfile. - if err != nil { - fmt.Fprintln(os.Stderr, firstErr) - fmt.Fprintln(os.Stderr, err) - os.Exit(1) - } - } - conf, err := sup.NewSupfile(data) - if err != nil { + flags.env = venv.OS() + if err := runSupfile(os.Stderr, flags, flag.Args()); err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } +} + +func runSupfile(errStream io.Writer, options options, args []string) error { + data, err := readSupfile(options) - // Parse network and commands to be run from args. - network, commands, err := parseArgs(conf) + conf, err := sup.NewSupfile(errStream, data) if err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) + return err + } + // Parse network and commands to be run from flags. + network, commands, err := parseArgs(errStream, options, args, conf) + if err != nil { + return err } // --only flag filters hosts - if onlyHosts != "" { - expr, err := regexp.CompilePOSIX(onlyHosts) + if options.onlyHosts != "" { + expr, err := regexp.CompilePOSIX(options.onlyHosts) if err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) + return err } var hosts []string @@ -258,18 +121,15 @@ func main() { } } if len(hosts) == 0 { - fmt.Fprintln(os.Stderr, fmt.Errorf("no hosts match --only '%v' regexp", onlyHosts)) - os.Exit(1) + return fmt.Errorf("no hosts match --only '%v' regexp", options.onlyHosts) } network.Hosts = hosts } - // --except flag filters out hosts - if exceptHosts != "" { - expr, err := regexp.CompilePOSIX(exceptHosts) + if options.exceptHosts != "" { + expr, err := regexp.CompilePOSIX(options.exceptHosts) if err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) + return err } var hosts []string @@ -279,18 +139,15 @@ func main() { } } if len(hosts) == 0 { - fmt.Fprintln(os.Stderr, fmt.Errorf("no hosts left after --except '%v' regexp", onlyHosts)) - os.Exit(1) + return fmt.Errorf("no hosts left after --except '%v' regexp", options.onlyHosts) } network.Hosts = hosts } - // --sshconfig flag location for ssh_config file - if sshConfig != "" { - confHosts, err := sshconfig.ParseSSHConfig(resolvePath(sshConfig)) + if options.sshConfig != "" { + confHosts, err := sshconfig.ParseSSHConfig(resolvePath(options.sshConfig)) if err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) + return err } // flatten Host -> *SSHHost, not the prettiest @@ -303,28 +160,25 @@ func main() { } // check network.Hosts for match - for _, host := range network.Hosts { + for i, host := range network.Hosts { conf, found := confMap[host] if found { network.User = conf.User network.IdentityFile = resolvePath(conf.IdentityFile) - network.Hosts = []string{fmt.Sprintf("%s:%d", conf.HostName, conf.Port)} + network.Hosts[i] = fmt.Sprintf("%s:%d", conf.HostName, conf.Port) } } } - var vars sup.EnvList for _, val := range append(conf.Env, network.Env...) { vars.Set(val.Key, val.Value) } if err := vars.ResolveValues(); err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) + return err } - // Parse CLI --env flag env vars, define $SUP_ENV and override values defined in Supfile. var cliVars sup.EnvList - for _, env := range envVars { + for _, env := range options.envVars { if len(env) == 0 { continue } @@ -338,7 +192,6 @@ func main() { vars.Set(env[:i], env[i+1:]) cliVars.Set(env[:i], env[i+1:]) } - // SUP_ENV is generated only from CLI env vars. // Separate loop to omit duplicates. supEnv := "" @@ -346,20 +199,176 @@ func main() { supEnv += fmt.Sprintf(" -e %v=%q", v.Key, v.Value) } vars.Set("SUP_ENV", strings.TrimSpace(supEnv)) - // Create new Stackup app. app, err := sup.New(conf) if err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) + return err } - app.Debug(debug) - app.Prefix(!disablePrefix) + app.Debug(options.debug) + app.Prefix(!options.disablePrefix) // Run all the commands in the given network. - err = app.Run(network, vars, commands...) + return app.Run(network, vars, commands...) +} + +func readSupfile(options options) ([]byte, error) { + filename := options.supfile + if filename == "" { + filename = "Supfile" + } + + data, err := ioutil.ReadFile(defaultPath(options.dirname, filename)) if err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) + firstErr := err + // Alternative to ./Supfile. + data, err = ioutil.ReadFile(defaultPath(options.dirname, "Supfile.yml")) + if err != nil { + return []byte{}, fmt.Errorf("%v\n%v", firstErr, err) + } } + + return data, nil +} + +func defaultPath(dir, path string) string { + if dir == "" { + return resolvePath(path) + } + return filepath.Join(dir, path) + +} + +func resolvePath(path string) string { + if path[:2] == "~/" { + usr, err := user.Current() + if err == nil { + path = filepath.Join(usr.HomeDir, path[2:]) + } + } + return path +} + +// parseArgs parses args and returns network and commands to be run. +// On error, it prints usage and exits. +func parseArgs(errStream io.Writer, options options, args []string, conf *sup.Supfile) (*sup.Network, []*sup.Command, error) { + var commands []*sup.Command + + if len(args) < 1 { + networkUsage(errStream, conf) + return nil, nil, ErrUsage + } + + // Does the exist? + network, ok := conf.Networks.Get(args[0]) + if !ok { + networkUsage(errStream, conf) + return nil, nil, ErrUnknownNetwork + } + + hosts, err := network.ParseInventory(errStream) + if err != nil { + return nil, nil, err + } + network.Hosts = append(network.Hosts, hosts...) + + // Does the have at least one host? + if len(network.Hosts) == 0 { + networkUsage(errStream, conf) + return nil, nil, ErrNetworkNoHosts + } + + // Check for the second argument + if len(args) < 2 { + cmdUsage(errStream, conf) + return nil, nil, ErrUsage + } + + // In case of the network.Env needs an initialization + if network.Env == nil { + network.Env = make(sup.EnvList, 0) + } + + // Add default env variable with current network + network.Env.Set("SUP_NETWORK", args[0]) + + // Add default nonce + network.Env.Set("SUP_TIME", time.Now().UTC().Format(time.RFC3339)) + if options.env.Getenv("SUP_TIME") != "" { + network.Env.Set("SUP_TIME", options.env.Getenv("SUP_TIME")) + } + + // Add user + if options.env.Getenv("SUP_USER") != "" { + network.Env.Set("SUP_USER", options.env.Getenv("SUP_USER")) + } else { + network.Env.Set("SUP_USER", options.env.Getenv("USER")) + } + + for _, cmd := range args[1:] { + // Target? + target, isTarget := conf.Targets.Get(cmd) + if isTarget { + // Loop over target's commands. + for _, cmd := range target { + command, isCommand := conf.Commands.Get(cmd) + if !isCommand { + cmdUsage(errStream, conf) + return nil, nil, fmt.Errorf("%v: %v", ErrCmd, cmd) + } + command.Name = cmd + commands = append(commands, &command) + } + } + + // Command? + command, isCommand := conf.Commands.Get(cmd) + if isCommand { + command.Name = cmd + commands = append(commands, &command) + } + + if !isTarget && !isCommand { + cmdUsage(errStream, conf) + return nil, nil, fmt.Errorf("%v: %v", ErrCmd, cmd) + } + } + + return &network, commands, nil +} + +func networkUsage(errStream io.Writer, conf *sup.Supfile) { + w := &tabwriter.Writer{} + w.Init(errStream, 4, 4, 2, ' ', 0) + defer w.Flush() + + // Print available networks/hosts. + fmt.Fprintln(w, "Networks:\t") + for _, name := range conf.Networks.Names { + fmt.Fprintf(w, "- %v\n", name) + network, _ := conf.Networks.Get(name) + for _, host := range network.Hosts { + fmt.Fprintf(w, "\t- %v\n", host) + } + } + fmt.Fprintln(w) +} + +func cmdUsage(errStream io.Writer, conf *sup.Supfile) { + w := &tabwriter.Writer{} + w.Init(errStream, 4, 4, 2, ' ', 0) + defer w.Flush() + + // Print available targets/commands. + fmt.Fprintln(w, "Targets:\t") + for _, name := range conf.Targets.Names { + cmds, _ := conf.Targets.Get(name) + fmt.Fprintf(w, "- %v\t%v\n", name, strings.Join(cmds, " ")) + } + fmt.Fprintln(w, "\t") + fmt.Fprintln(w, "Commands:\t") + for _, name := range conf.Commands.Names { + cmd, _ := conf.Commands.Get(name) + fmt.Fprintf(w, "- %v\t%v\n", name, cmd.Desc) + } + fmt.Fprintln(w) } diff --git a/cmd/sup/main_test.go b/cmd/sup/main_test.go new file mode 100644 index 0000000..2b8d52b --- /dev/null +++ b/cmd/sup/main_test.go @@ -0,0 +1,1129 @@ +package main + +import ( + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/adammck/venv" +) + +const ( + envTestUser = "sup_test_user" +) + +var ( + testErrStream = ioutil.Discard +) + +func TestInvalidYaml(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 +efewf we +we kp re +` + + withTmpDir(t, input, func(dirname string) { + options := options{ + dirname: dirname, + env: testEnv(), + } + if err := runSupfile(testErrStream, options, []string{}); err == nil { + t.Fatal("Expected an error") + } + }) +} + +func TestNoNetworkSpecified(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server2 + +commands: + step1: + run: echo "Hey over there" +` + withTmpDir(t, input, func(dirname string) { + options := options{ + dirname: dirname, + env: testEnv(), + } + if err := runSupfile(testErrStream, options, []string{}); err != ErrUsage { + t.Fatal(err) + } + }) +} + +func TestUnknownNetwork(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server2 + +commands: + step1: + run: echo "Hey over there" +` + withTmpDir(t, input, func(dirname string) { + options := options{ + dirname: dirname, + } + if err := runSupfile(testErrStream, options, []string{"production"}); err != ErrUnknownNetwork { + t.Fatal(err) + } + }) +} + +func TestNoHosts(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + +commands: + step1: + run: echo "Hey over there" +` + withTmpDir(t, input, func(dirname string) { + options := options{ + dirname: dirname, + env: testEnv(), + } + if err := runSupfile(testErrStream, options, []string{"staging"}); err != ErrNetworkNoHosts { + t.Fatal(err) + } + }) +} + +func TestNoCommand(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server2 + +commands: + step1: + run: echo "Hey over there" +` + withTmpDir(t, input, func(dirname string) { + options := options{ + dirname: dirname, + env: testEnv(), + } + if err := runSupfile(testErrStream, options, []string{"staging"}); err != ErrUsage { + t.Fatal(err) + } + }) +} + +func TestNonexistentCommandOrTarget(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server2 + +commands: + step1: + run: echo "Hey over there" + step2: + run: echo "Hey again" + +targets: + walk: + - step1 + - step2 +` + withTmpDir(t, input, func(dirname string) { + options := options{ + dirname: dirname, + env: testEnv(), + } + if err := runSupfile(testErrStream, options, []string{"staging", "step5"}); err == nil { + t.Fatal("Expected an error") + } else if !strings.Contains(err.Error(), ErrCmd.Error()) { + t.Fatalf("Expected `%v` error, got `%v`", ErrCmd, err) + } + }) +} + +func TestTargetReferencingNonexistentCommand(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server2 + +commands: + step1: + run: echo "Hey over there" + step2: + run: echo "Hey again" + +targets: + walk: + - step1 + - step2 + - step3 +` + withTmpDir(t, input, func(dirname string) { + options := options{ + dirname: dirname, + env: testEnv(), + } + if err := runSupfile(testErrStream, options, []string{"staging", "walk"}); err == nil { + t.Fatal("Expected an error") + } else if !strings.Contains(err.Error(), ErrCmd.Error()) { + t.Fatalf("Expected `%v` error, got `%v`", ErrCmd, err) + } + }) +} + +func TestOneCommand(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server2 + +commands: + step1: + run: echo "Hey over there" +` + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 3) + if err != nil { + t.Fatal(err) + } + + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 2) + m.expectNoActivityOnServers(1) + m.expectCommandOnActiveServers(`echo "Hey over there"`) + }) +} + +func TestSequenceOfCommands(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server2 + +commands: + step1: + run: echo "Hey over there" + step2: + run: echo "Hey again" +` + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 3) + if err != nil { + t.Fatal(err) + } + + args := []string{"staging", "step1", "step2"} + if err := runSupfile(testErrStream, options, args); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 2) + m.expectNoActivityOnServers(1) + m.expectCommandOnActiveServers(`echo "Hey over there"`) + m.expectCommandOnActiveServers(`echo "Hey again"`) + }) +} + +func TestTarget(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server2 + +commands: + step1: + run: echo "Hey over there" + step2: + run: echo "Hey again" + +targets: + walk: + - step1 + - step2 +` + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 3) + if err != nil { + t.Fatal(err) + } + + args := []string{"staging", "walk"} + if err := runSupfile(testErrStream, options, args); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 2) + m.expectNoActivityOnServers(1) + m.expectCommandOnActiveServers(`echo "Hey over there"`) + m.expectCommandOnActiveServers(`echo "Hey again"`) + }) +} + +func TestOnlyHosts(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server1 + - server2 + +commands: + step1: + run: echo "Hey over there" +` + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 3) + if err != nil { + t.Fatal(err) + } + + options.onlyHosts = "server2" + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(2) + m.expectNoActivityOnServers(0, 1) + m.expectCommandOnActiveServers(`echo "Hey over there"`) + }) +} + +func TestOnlyHostsEmpty(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server1 + - server2 + +commands: + step1: + run: echo "Hey over there" +` + withTmpDir(t, input, func(dirname string) { + options := options{ + dirname: dirname, + onlyHosts: "server42", + env: venv.Mock(), + } + if err := runSupfile(testErrStream, options, []string{"staging", "step1"}); err == nil { + t.Fatal("Expected an error") + } else if !strings.Contains(err.Error(), "no hosts match") { + t.Fatalf("Expected a different error, got `%v`", err) + } + }) +} + +func TestOnlyHostsInvalid(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server1 + - server2 + +commands: + step1: + run: echo "Hey over there" +` + withTmpDir(t, input, func(dirname string) { + options := options{ + dirname: dirname, + onlyHosts: "server(", + env: venv.Mock(), + } + if err := runSupfile(testErrStream, options, []string{"staging", "step1"}); err == nil { + t.Fatal("Expected an error") + } else if !strings.Contains(err.Error(), "error parsing regexp") { + t.Fatalf("Expected a different error, got `%v`", err) + } + }) +} + +func TestExceptHosts(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server1 + - server2 + +commands: + step1: + run: echo "Hey over there" +` + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 3) + if err != nil { + t.Fatal(err) + } + options.exceptHosts = "server(1|2)" + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0) + m.expectNoActivityOnServers(1, 2) + m.expectCommandOnActiveServers(`echo "Hey over there"`) + }) +} + +func TestExceptHostsEmpty(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server1 + - server2 + +commands: + step1: + run: echo "Hey over there" +` + withTmpDir(t, input, func(dirname string) { + options := options{ + dirname: dirname, + exceptHosts: "server", + env: venv.Mock(), + } + if err := runSupfile(testErrStream, options, []string{"staging", "step1"}); err == nil { + t.Fatal("Expected an error") + } else if !strings.Contains(err.Error(), "no hosts left") { + t.Fatalf("Expected a different error, got `%v`", err) + } + }) +} + +func TestExceptHostsInvalid(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server1 + - server2 + +commands: + step1: + run: echo "Hey over there" +` + withTmpDir(t, input, func(dirname string) { + options := options{ + dirname: dirname, + exceptHosts: "server(", + env: venv.Mock(), + } + if err := runSupfile(testErrStream, options, []string{"staging", "step1"}); err == nil { + t.Fatal("Expected an error") + } else if !strings.Contains(err.Error(), "error parsing regexp") { + t.Fatalf("Expected a different error, got `%v`", err) + } + }) +} + +func TestInventory(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + inventory: printf "server0\n# comment\n\nserver2\n" + +commands: + step1: + run: echo "Hey over there" +` + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 3) + if err != nil { + t.Fatal(err) + } + + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 2) + m.expectNoActivityOnServers(1) + m.expectCommandOnActiveServers(`echo "Hey over there"`) + }) +} + +func TestFailedInventory(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + inventory: this won't compile + +commands: + step1: + run: echo "Hey over there" +` + withTmpDir(t, input, func(dirname string) { + options := options{ + dirname: dirname, + env: testEnv(), + } + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err == nil { + t.Fatal("Expected an error") + } + }) +} + +func TestSupVariables(t *testing.T) { + t.Parallel() + + // these tests need to run in order because they mess with env vars + t.Run("default", func(t *testing.T) { + if time.Now().Hour() == 23 && time.Now().Minute() == 59 { + t.Skip("Skipping test") + } + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server1 + +commands: + step1: + run: echo "Hey over there" +` + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 2) + if err != nil { + t.Fatal(err) + } + + if err := runSupfile(testErrStream, options, []string{"staging", "step1"}); err != nil { + t.Fatal(err) + } + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 1) + m.expectExportOnActiveServers(`SUP_NETWORK="staging"`) + m.expectExportOnActiveServers(`SUP_ENV=""`) + m.expectExportOnActiveServers(fmt.Sprintf(`SUP_USER="%s"`, envTestUser)) + m.expectExportRegexpOnActiveServers(`SUP_HOST="localhost:\d+"`) + }) + }) + + t.Run("default SUP_TIME", func(t *testing.T) { + + if time.Now().Hour() == 23 && time.Now().Minute() == 59 { + t.Skip("Skipping SUP_TIME test because it might fail around midnight") + } + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server1 + +commands: + step1: + run: echo "Hey over there" +` + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 2) + if err != nil { + t.Fatal(err) + } + + if err := runSupfile(testErrStream, options, []string{"staging", "step1"}); err != nil { + t.Fatal(err) + } + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 1) + m.expectExportRegexpOnActiveServers( + fmt.Sprintf( + `SUP_TIME="%4d-%02d-%02dT[0-2][0-9]:[0-5][0-9]:[0-5][0-9]Z"`, + time.Now().Year(), + time.Now().Month(), + time.Now().Day(), + ), + ) + }) + }) + + t.Run("SUP_TIME env var set", func(t *testing.T) { + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server1 + +commands: + step1: + run: echo "Hey over there" +` + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 2) + if err != nil { + t.Fatal(err) + } + options.env.Setenv("SUP_TIME", "now") + + if err := runSupfile(testErrStream, options, []string{"staging", "step1"}); err != nil { + t.Fatal(err) + } + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 1) + m.expectExportOnActiveServers(`SUP_TIME="now"`) + }) + }) + + t.Run("SUP_USER env var set", func(t *testing.T) { + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server1 + +commands: + step1: + run: echo "Hey over there" +` + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 2) + if err != nil { + t.Fatal(err) + } + options.env.Setenv("SUP_USER", "sup_rules") + + if err := runSupfile(testErrStream, options, []string{"staging", "step1"}); err != nil { + t.Fatal(err) + } + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 1) + m.expectExportOnActiveServers(`SUP_USER="sup_rules"`) + }) + }) +} + +func TestInvalidVars(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +env: + TODAYS_SPECIAL: this won't compile + +networks: + staging: + hosts: + - server0 + - server1 + +commands: + step1: + run: echo "Hey over there" +` + withTmpDir(t, input, func(dirname string) { + _, options, err := setupMockEnv(dirname, 2) + if err != nil { + t.Fatal(err) + } + + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err == nil { + t.Fatal("Expected an error") + } + + }) +} + +func TestEnvLevelVars(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +env: + TODAYS_SPECIAL: "dog milk" + +networks: + staging: + hosts: + - server0 + - server1 + +commands: + step1: + run: echo "Hey over there" +` + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 2) + if err != nil { + t.Fatal(err) + } + + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 1) + m.expectExportOnActiveServers(`TODAYS_SPECIAL="dog milk"`) + m.expectCommandOnActiveServers(`echo "Hey over there"`) + }) +} + +func TestNetworkLevelVars(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +env: + TODAYS_SPECIAL: "dog milk" + +networks: + staging: + env: + TODAYS_SPECIAL: "Trout a la Crème" + hosts: + - server0 + - server1 + +commands: + step1: + run: echo "Hey over there" +` + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 2) + if err != nil { + t.Fatal(err) + } + + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 1) + m.expectExportOnActiveServers(`TODAYS_SPECIAL="Trout a la Crème"`) + m.expectCommandOnActiveServers(`echo "Hey over there"`) + }) +} + +func TestCommandLineLevelVars(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +env: + TODAYS_SPECIAL: "dog milk" + +networks: + staging: + env: + TODAYS_SPECIAL: "Trout a la Crème" + hosts: + - server0 + - server1 + +commands: + step1: + run: echo "Hey over there" +` + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 2) + if err != nil { + t.Fatal(err) + } + options.envVars = []string{"IM_HERE", "TODAYS_SPECIAL=Gazpacho"} + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 1) + m.expectExportOnActiveServers(`IM_HERE=""`) + m.expectExportOnActiveServers(`TODAYS_SPECIAL="Gazpacho"`) + m.expectExportOnActiveServers(`SUP_ENV="-e TODAYS_SPECIAL="Gazpacho""`) + m.expectCommandOnActiveServers(`echo "Hey over there"`) + }) +} + +func TestEmptyCommandLineLevelVars(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +env: + TODAYS_SPECIAL: "dog milk" + +networks: + staging: + env: + TODAYS_SPECIAL: "Trout a la Crème" + hosts: + - server0 + - server1 + +commands: + step1: + run: echo "Hey over there" +` + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 2) + if err != nil { + t.Fatal(err) + } + options.envVars = []string{""} + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 1) + m.expectExportOnActiveServers(`TODAYS_SPECIAL="Trout a la Crème"`) + m.expectExportOnActiveServers(`SUP_ENV=""`) + m.expectCommandOnActiveServers(`echo "Hey over there"`) + }) +} + +func TestFileOption(t *testing.T) { + t.Parallel() + + t.Run("fallbacks to Supfile.yml", func(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + +commands: + step1: + run: echo "Hey over there" +` + withTmpDirWithoutSupfile(t, func(dirname string) { + writeSupfileAs(dirname, "Supfile.yml", input) + + outputs, options, err := setupMockEnv(dirname, 1) + if err != nil { + t.Fatal(err) + } + + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0) + m.expectCommandOnActiveServers(`echo "Hey over there"`) + }) + }) + + t.Run("prefers Supfile over Supfile.yml when not specified", func(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + +commands: + step1: + run: echo "Default Supfile" +` + withTmpDir(t, input, func(dirname string) { + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + +commands: + step1: + run: echo "Supfile.yml" +` + writeSupfileAs(dirname, "Supfile.yml", input) + + outputs, options, err := setupMockEnv(dirname, 1) + if err != nil { + t.Fatal(err) + } + + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0) + m.expectCommandOnActiveServers(`echo "Default Supfile"`) + }) + }) + + t.Run("can specify arbitrary file", func(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + +commands: + step1: + run: echo "Hey over there" +` + withTmpDirWithoutSupfile(t, func(dirname string) { + writeSupfileAs(dirname, "different_file_name", input) + + outputs, options, err := setupMockEnv(dirname, 1) + if err != nil { + t.Fatal(err) + } + + options.supfile = "different_file_name" + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0) + m.expectCommandOnActiveServers(`echo "Hey over there"`) + }) + }) + + t.Run("fails without a Supfile", func(t *testing.T) { + t.Parallel() + + withTmpDirWithoutSupfile(t, func(dirname string) { + options := options{ + dirname: dirname, + } + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err == nil { + t.Fatal("Expected an error") + } + }) + }) +} + +func withTmpDir(t *testing.T, input string, test func(dirname string)) { + dirname, err := tmpDir() + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dirname) + + if err := writeDefaultSupfile(dirname, input); err != nil { + t.Fatal(err) + } + + test(dirname) +} + +func withTmpDirWithoutSupfile(t *testing.T, test func(dirname string)) { + dirname, err := tmpDir() + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dirname) + + test(dirname) +} + +func tmpDir() (string, error) { + return ioutil.TempDir("", "suptest") +} + +func writeDefaultSupfile(dirname, input string) error { + return writeSupfileAs(dirname, "Supfile", input) +} + +func writeSupfileAs(dirname, filename, input string) error { + return ioutil.WriteFile( + filepath.Join(dirname, filename), + []byte(input), + 0666, + ) +} + +func testEnv() venv.Env { + env := venv.Mock() + env.Setenv("USER", envTestUser) + return env +} diff --git a/cmd/sup/matchers_test.go b/cmd/sup/matchers_test.go new file mode 100644 index 0000000..fcd266a --- /dev/null +++ b/cmd/sup/matchers_test.go @@ -0,0 +1,102 @@ +package main + +import ( + "bytes" + "fmt" + "regexp" + "strings" + "testing" +) + +// matcher defines high level expectations over a collection of output buffers +type matcher struct { + outputs []bytes.Buffer + t *testing.T + activeServers []int +} + +func newMatcher(outputs []bytes.Buffer, t *testing.T) matcher { + return matcher{ + outputs: outputs, + t: t, + } +} + +func (m *matcher) expectActivityOnServers(servers ...int) { + m.activeServers = servers + m.onEachActiveServer(func(server int, output string) { + if len(output) == 0 { + m.t.Errorf("expected activity on server #%d", server) + } + }) +} +func (m *matcher) expectNoActivityOnServers(servers ...int) { + for _, server := range servers { + if server >= len(m.outputs) || server < 0 { + m.t.Errorf("output from server #%d not provided", server) + return + } + output := m.outputs[server] + if output.Len() > 0 { + m.t.Errorf("expected no activity on server #%d:\n%s", server, output.String()) + } + } +} + +func (m matcher) expectExportOnActiveServers(export string) { + m.onEachActiveServer(func(server int, output string) { + for i, executed := range strings.Split(output, "\n") { + if !strings.Contains(executed, fmt.Sprintf("export %s;", export)) { + m.t.Errorf( + "command #%d on server #%d does not export `%s`:\n%s", + i, + server, + export, + executed, + ) + } + } + }) +} +func (m matcher) expectExportRegexpOnActiveServers(export string) { + m.onEachActiveServer(func(server int, output string) { + for i, executed := range strings.Split(output, "\n") { + re, err := regexp.Compile(fmt.Sprintf("export %s;", export)) + if err != nil { + m.t.Fatal(err) + } + if !re.MatchString(executed) { + m.t.Errorf( + "command #%d on server #%d does not export `%s`:\n%s", + i, + server, + export, + executed, + ) + } + } + }) +} + +func (m matcher) expectCommandOnActiveServers(command string) { + m.onEachActiveServer(func(server int, output string) { + for _, executed := range strings.Split(output, "\n") { + if strings.HasSuffix(executed, fmt.Sprintf(";%s", command)) { + return + } + } + m.t.Errorf("no command on server #%d executed `%s`", server, command) + }) +} + +func (m matcher) onEachActiveServer(expectation func(server int, output string)) { + for _, server := range m.activeServers { + if server >= len(m.outputs) || server < 0 { + m.t.Errorf("output from server #%d not provided", server) + return + } + + output := m.outputs[server] + expectation(server, strings.TrimSpace(output.String())) + } +} diff --git a/cmd/sup/mock_server_test.go b/cmd/sup/mock_server_test.go new file mode 100644 index 0000000..4443e0f --- /dev/null +++ b/cmd/sup/mock_server_test.go @@ -0,0 +1,289 @@ +package main + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/pem" + "fmt" + "io" + "io/ioutil" + "net" + "os" + "path" + "strings" + "text/template" + + "github.com/pkg/errors" + "golang.org/x/crypto/ssh" +) + +// setupMockEnv prepares testing environment, it +// +// - creates a temporary directory for all files +// - generates RSA key pair; the private key is written into a file, +// fingerprint of the public key is written into a file as an authorized key +// - spins up mock SSH servers with the same authorized key +// - writes an SSH config file with entries for all servers, naming them +// server0, server1 etc. +func setupMockEnv(dirname string, count int) ([]bytes.Buffer, options, error) { + + privateKeyPath := path.Join(dirname, "gotest_private_key") + authorizedKeysPath := path.Join(dirname, "authorized_keys") + sshConfigPath := path.Join(dirname, "ssh_config") + + if err := generateKeyPair(privateKeyPath, authorizedKeysPath); err != nil { + return nil, options{}, err + } + + outputs := make([]bytes.Buffer, count) + addresses := make([]string, count) + for i := 0; i < count; i++ { + runTestServer(authorizedKeysPath, &addresses[i], &outputs[i]) + } + + err := writeSSHConfigFile(privateKeyPath, sshConfigPath, addresses) + if err != nil { + return nil, options{}, err + } + + options := options{ + sshConfig: sshConfigPath, + dirname: dirname, + env: testEnv(), + } + return outputs, options, nil +} + +// generateKeyPair generates a pair of keys, the private key is written into +// a file and the fingerprint of the public key into authorized_keys file for +// the server to use +func generateKeyPair(privateKeyPath, authorizedKeysPath string) error { + privateKey, err := generatePrivateRSAKey() + if err != nil { + return err + } + if err := writePrivateKeyToFile(privateKey, privateKeyPath); err != nil { + return err + } + + publicKey := privateKey.PublicKey + pub, err := ssh.NewPublicKey(&publicKey) + if err != nil { + return err + } + + return ioutil.WriteFile( + authorizedKeysPath, + ssh.MarshalAuthorizedKey(pub), + 0666, + ) +} + +func generatePrivateRSAKey() (*rsa.PrivateKey, error) { + return rsa.GenerateKey(rand.Reader, 2014) +} + +func writePrivateKeyToFile(privateKey *rsa.PrivateKey, filepath string) error { + privateKeyBlock := pem.Block{ + Type: "RSA PRIVATE KEY", + Headers: nil, + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + } + return ioutil.WriteFile( + filepath, + pem.EncodeToMemory(&privateKeyBlock), + 0666, + ) +} + +func runTestServer(authorizedKeysPath string, addr *string, out io.Writer) error { + authorizedKeysMap, err := loadAuthorizedKeys(authorizedKeysPath) + if err != nil { + return err + } + + config, err := buildServerConfig(authorizedKeysMap) + if err != nil { + return err + } + + listener, err := net.Listen("tcp", "localhost:") + if err != nil { + return errors.Wrap(err, "failed to listen for connection") + } + *addr = listener.Addr().String() + + go sshListen(config, listener, out) + + return nil +} + +func buildServerConfig(authorizedKeysMap map[string]bool) (*ssh.ServerConfig, error) { + // An SSH server is represented by a ServerConfig, which holds + // certificate details and handles authentication of ServerConns. + config := &ssh.ServerConfig{ + // Remove to disable public key auth. + PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { + if authorizedKeysMap[string(pubKey.Marshal())] { + return &ssh.Permissions{ + // Record the public key used for authentication. + Extensions: map[string]string{ + "pubkey-fp": fingerprintSHA256(pubKey), + }, + }, nil + } + return nil, fmt.Errorf("unknown public key for %q", c.User()) + }, + } + + key, err := generatePrivateRSAKey() + if err != nil { + return nil, err + } + + private, err := ssh.NewSignerFromKey(key) + if err != nil { + return nil, err + } + + config.AddHostKey(private) + return config, nil +} + +func sshListen(config *ssh.ServerConfig, listener net.Listener, out io.Writer) { + func() { + nConn, err := listener.Accept() + if err != nil { + panic(errors.Wrap(err, "failed to accept incoming connection")) + } + + // Before use, a handshake must be performed on the incoming + // net.Conn. + _, chans, reqs, err := ssh.NewServerConn(nConn, config) + if err != nil { + panic(errors.Wrap(err, "failed to handshake")) + } + + // The incoming Request channel must be serviced. + go ssh.DiscardRequests(reqs) + + // Service the incoming Channel channel. + for newChannel := range chans { + // Channels have a type, depending on the application level + // protocol intended. In the case of a shell, the type is + // "session" and ServerShell may be used to present a simple + // terminal interface. + if newChannel.ChannelType() != "session" { + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") + continue + } + channel, requests, err := newChannel.Accept() + if err != nil { + panic(errors.Wrap(err, "Could not accept channel")) + } + + go func(in <-chan *ssh.Request) { + defer channel.Close() + + for req := range in { + // reply to pty-req with success + if req.Type == "pty-req" { + req.Reply(true, []byte{}) + + // read exec command, write it to output and respond with success + } else if req.Type == "exec" { + type execMsg struct { + Command string + } + var payload execMsg + ssh.Unmarshal(req.Payload, &payload) + out.Write([]byte(payload.Command + "\n")) + req.Reply(true, nil) + + channel.SendRequest("exit-status", false, []byte{0, 0, 0, 0}) + if err := channel.Close(); err != nil { + panic(err) + } + } + } + }(requests) + } + }() +} + +func fingerprintSHA256(pubKey ssh.PublicKey) string { + sha256sum := sha256.Sum256(pubKey.Marshal()) + hash := base64.RawStdEncoding.EncodeToString(sha256sum[:]) + return "SHA256:" + hash +} + +func loadAuthorizedKeys(filepath string) (map[string]bool, error) { + authorizedKeysBytes, err := ioutil.ReadFile(filepath) + if err != nil { + return nil, errors.Wrapf(err, "failed to load %sv", filepath) + } + authorizedKeysMap := map[string]bool{} + for len(authorizedKeysBytes) > 0 { + pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes) + if err != nil { + return nil, err + } + + authorizedKeysMap[string(pubKey.Marshal())] = true + authorizedKeysBytes = rest + } + return authorizedKeysMap, nil +} + +// writes simple SSH config file for the given servers naming them server0, +// server1 etc. +func writeSSHConfigFile(privateKeyPath, sshConfigPath string, addresses []string) error { + type sshRecord struct { + Host string + Port string + IdentityFilename string + } + records := make([]sshRecord, len(addresses)) + for i, addr := range addresses { + records[i].Host = fmt.Sprintf("server%d", i) + records[i].IdentityFilename = privateKeyPath + records[i].Port = strings.Split(addr, ":")[1] + } + + sshConfigTemplate := ` +{{range .Records}} +Host {{.Host}} + HostName localhost + Port {{.Port}} + IdentityFile {{.IdentityFilename}} +{{end}} +` + + tmpl := template.New("ssh_config") + tmpl, err := tmpl.Parse(sshConfigTemplate) + if err != nil { + return err + } + + file, err := os.Create(sshConfigPath) + if err != nil { + return err + } + defer file.Close() + + data := struct { + Records []sshRecord + }{ + Records: records, + } + + if err := tmpl.Execute(file, data); err != nil { + return err + } + + return nil +} diff --git a/ssh.go b/ssh.go index eb3cefb..225b13e 100644 --- a/ssh.go +++ b/ssh.go @@ -9,7 +9,6 @@ import ( "os/user" "path/filepath" "strings" - "sync" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" @@ -21,6 +20,7 @@ type SSHClient struct { sess *ssh.Session user string host string + identityFile string remoteStdin io.WriteCloser remoteStdout io.Reader remoteStderr io.Reader @@ -29,6 +29,7 @@ type SSHClient struct { running bool env string //export FOO="bar"; export BAR="baz"; color string + authMethod ssh.AuthMethod } type ErrConnect struct { @@ -76,11 +77,12 @@ func (c *SSHClient) parseHost(host string) error { return nil } -var initAuthMethodOnce sync.Once -var authMethod ssh.AuthMethod +// ensureAuthMethod initiates SSH authentication method. +func (c *SSHClient) ensureAuthMethod() { + if c.authMethod != nil { + return + } -// initAuthMethod initiates SSH authentication method. -func initAuthMethod() { var signers []ssh.Signer // If there's a running SSH Agent, try to use its Private keys. @@ -92,6 +94,10 @@ func initAuthMethod() { // Try to read user's SSH private keys form the standard paths. files, _ := filepath.Glob(os.Getenv("HOME") + "/.ssh/id_*") + // Add nonstandard path + if c.identityFile != "" { + files = append(files, c.identityFile) + } for _, file := range files { if strings.HasSuffix(file, ".pub") { continue // Skip public keys. @@ -107,7 +113,7 @@ func initAuthMethod() { signers = append(signers, signer) } - authMethod = ssh.PublicKeys(signers...) + c.authMethod = ssh.PublicKeys(signers...) } // SSHDialFunc can dial an ssh server and return a client @@ -126,8 +132,7 @@ func (c *SSHClient) ConnectWith(host string, dialer SSHDialFunc) error { if c.connOpened { return fmt.Errorf("Already connected") } - - initAuthMethodOnce.Do(initAuthMethod) + c.ensureAuthMethod() err := c.parseHost(host) if err != nil { @@ -137,7 +142,7 @@ func (c *SSHClient) ConnectWith(host string, dialer SSHDialFunc) error { config := &ssh.ClientConfig{ User: c.user, Auth: []ssh.AuthMethod{ - authMethod, + c.authMethod, }, } diff --git a/sup.go b/sup.go index d815068..2f6db1f 100644 --- a/sup.go +++ b/sup.go @@ -70,9 +70,10 @@ func (sup *Stackup) Run(network *Network, envVars EnvList, commands ...*Command) // SSH client. remote := &SSHClient{ - env: env + `export SUP_HOST="` + host + `";`, - user: network.User, - color: Colors[i%len(Colors)], + env: env + `export SUP_HOST="` + host + `";`, + user: network.User, + color: Colors[i%len(Colors)], + identityFile: network.IdentityFile, } if bastion != nil { diff --git a/supfile.go b/supfile.go index a116ae5..06864e8 100644 --- a/supfile.go +++ b/supfile.go @@ -255,7 +255,7 @@ func (e ErrUnsupportedSupfileVersion) Error() string { } // NewSupfile parses configuration file and returns Supfile or error. -func NewSupfile(data []byte) (*Supfile, error) { +func NewSupfile(errStream io.Writer, data []byte) (*Supfile, error) { var conf Supfile if err := yaml.Unmarshal(data, &conf); err != nil { @@ -305,7 +305,7 @@ func NewSupfile(data []byte) (*Supfile, error) { } } if warning != "" { - fmt.Fprintf(os.Stderr, warning) + fmt.Fprintf(errStream, warning) } fallthrough @@ -321,13 +321,13 @@ func NewSupfile(data []byte) (*Supfile, error) { // ParseInventory runs the inventory command, if provided, and appends // the command's output lines to the manually defined list of hosts. -func (n Network) ParseInventory() ([]string, error) { +func (n Network) ParseInventory(errStream io.Writer) ([]string, error) { if n.Inventory == "" { return nil, nil } cmd := exec.Command("/bin/sh", "-c", n.Inventory) - cmd.Stderr = os.Stderr + cmd.Stderr = errStream output, err := cmd.Output() if err != nil { return nil, err diff --git a/vendor/github.com/adammck/venv/README.md b/vendor/github.com/adammck/venv/README.md new file mode 100644 index 0000000..09a10ba --- /dev/null +++ b/vendor/github.com/adammck/venv/README.md @@ -0,0 +1,40 @@ +# venv + +[![GoDoc](https://godoc.org/github.com/adammck/venv?status.svg)](https://godoc.org/github.com/adammck/venv) +[![Build Status](https://travis-ci.org/adammck/venv.svg?branch=master)](https://travis-ci.org/adammck/venv) + +This is a Go library to abstract access to environment variables. +Like [spf13/afero][afero] or [blang/vfs][vfs], but for the env. + +## Usage + +```go +package main + +import ( + "fmt" + "github.com/adammck/venv" +) + +func main() { + var e venv.Env + + // Use the real environment + + e = venv.OS() + fmt.Printf("Hello, %s!\n", e.Getenv("USER")) + + // Or use a mock + + e = venv.Mock() + e.Setenv("USER", "fred") + fmt.Printf("Hello, %s!\n", e.Getenv("USER")) +} +``` + +## License + +MIT. + +[afero]: https://github.com/spf13/afero +[vfs]: https://github.com/blang/vfs diff --git a/vendor/github.com/adammck/venv/mock/mock.go b/vendor/github.com/adammck/venv/mock/mock.go new file mode 100644 index 0000000..4ca535d --- /dev/null +++ b/vendor/github.com/adammck/venv/mock/mock.go @@ -0,0 +1,38 @@ +package mock + +import ( + "fmt" +) + +type MockEnv struct { + data map[string]string +} + +func New() *MockEnv { + e := &MockEnv{} + e.Clearenv() + return e +} + +func (e *MockEnv) Environ() []string { + out := []string{} + + for k, v := range e.data { + out = append(out, fmt.Sprintf("%s=%s", k, v)) + } + + return out +} + +func (e *MockEnv) Getenv(key string) string { + return e.data[key] +} + +func (e *MockEnv) Setenv(key, value string) error { + e.data[key] = value + return nil +} + +func (e *MockEnv) Clearenv() { + e.data = make(map[string]string) +} diff --git a/vendor/github.com/adammck/venv/os/os.go b/vendor/github.com/adammck/venv/os/os.go new file mode 100644 index 0000000..3f62e83 --- /dev/null +++ b/vendor/github.com/adammck/venv/os/os.go @@ -0,0 +1,28 @@ +package os + +import ( + "os" +) + +type OsEnv struct { +} + +func New() *OsEnv { + return &OsEnv{} +} + +func (e *OsEnv) Getenv(key string) string { + return os.Getenv(key) +} + +func (e *OsEnv) Setenv(key, value string) error { + return os.Setenv(key, value) +} + +func (e *OsEnv) Environ() []string { + return os.Environ() +} + +func (e *OsEnv) Clearenv() { + os.Clearenv() +} diff --git a/vendor/github.com/adammck/venv/venv.go b/vendor/github.com/adammck/venv/venv.go new file mode 100644 index 0000000..a873cc9 --- /dev/null +++ b/vendor/github.com/adammck/venv/venv.go @@ -0,0 +1,21 @@ +package venv + +import ( + "github.com/adammck/venv/mock" + "github.com/adammck/venv/os" +) + +type Env interface { + Environ() []string + Getenv(key string) string + Setenv(key, value string) error + Clearenv() +} + +func OS() Env { + return os.New() +} + +func Mock() Env { + return mock.New() +} diff --git a/vendor/vendor.json b/vendor/vendor.json index 30fae56..202b0d7 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -2,6 +2,24 @@ "comment": "", "ignore": "test appengine", "package": [ + { + "checksumSHA1": "JpbTntZjOHcUo5P9VrdcTFEGZ9Y=", + "path": "github.com/adammck/venv", + "revision": "8a9c907a37d36a8f34fa1c5b81aaf80c2554a306", + "revisionTime": "2016-08-19T02:56:05Z" + }, + { + "checksumSHA1": "5yJLvX3vSDfIqaCqSEaIwM0B3co=", + "path": "github.com/adammck/venv/mock", + "revision": "8a9c907a37d36a8f34fa1c5b81aaf80c2554a306", + "revisionTime": "2016-08-19T02:56:05Z" + }, + { + "checksumSHA1": "FIK6KYgGeXhdfcNHH80Ub1fBvwM=", + "path": "github.com/adammck/venv/os", + "revision": "8a9c907a37d36a8f34fa1c5b81aaf80c2554a306", + "revisionTime": "2016-08-19T02:56:05Z" + }, { "path": "github.com/goware/prefixer", "revision": "395022866408d928fc2439f7eac73dd8d370ec1d",