From 829e6de049e3b1c482aa9c1d9d325dd845a67952 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0t=C4=9Bp=C3=A1n=20Pila=C5=99?= Date: Wed, 7 Mar 2018 17:37:28 +0100 Subject: [PATCH 01/12] Use parsed identity file https://github.com/pressly/sup/issues/128 --- ssh.go | 57 ++++++++++++++++++++++++++++++++------------------------- sup.go | 7 ++++--- 2 files changed, 36 insertions(+), 28 deletions(-) diff --git a/ssh.go b/ssh.go index eb3cefb..43c4546 100644 --- a/ssh.go +++ b/ssh.go @@ -21,6 +21,7 @@ type SSHClient struct { sess *ssh.Session user string host string + identityFile string remoteStdin io.WriteCloser remoteStdout io.Reader remoteStderr io.Reader @@ -80,34 +81,40 @@ var initAuthMethodOnce sync.Once var authMethod ssh.AuthMethod // initAuthMethod initiates SSH authentication method. -func initAuthMethod() { - var signers []ssh.Signer - - // If there's a running SSH Agent, try to use its Private keys. - sock, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")) - if err == nil { - agent := agent.NewClient(sock) - signers, _ = agent.Signers() - } - - // Try to read user's SSH private keys form the standard paths. - files, _ := filepath.Glob(os.Getenv("HOME") + "/.ssh/id_*") - for _, file := range files { - if strings.HasSuffix(file, ".pub") { - continue // Skip public keys. - } - data, err := ioutil.ReadFile(file) - if err != nil { - continue +func initAuthMethod(identityFilePath string) func() { + return func() { + var signers []ssh.Signer + + // If there's a running SSH Agent, try to use its Private keys. + sock, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")) + if err == nil { + agent := agent.NewClient(sock) + signers, _ = agent.Signers() } - signer, err := ssh.ParsePrivateKey(data) - if err != nil { - continue + + // Try to read user's SSH private keys form the standard paths. + files, _ := filepath.Glob(os.Getenv("HOME") + "/.ssh/id_*") + // Add nonstandard path + if identityFilePath != "" { + files = append(files, identityFilePath) } - signers = append(signers, signer) + for _, file := range files { + if strings.HasSuffix(file, ".pub") { + continue // Skip public keys. + } + data, err := ioutil.ReadFile(file) + if err != nil { + continue + } + signer, err := ssh.ParsePrivateKey(data) + if err != nil { + continue + } + signers = append(signers, signer) + } + authMethod = ssh.PublicKeys(signers...) } - authMethod = ssh.PublicKeys(signers...) } // SSHDialFunc can dial an ssh server and return a client @@ -127,7 +134,7 @@ func (c *SSHClient) ConnectWith(host string, dialer SSHDialFunc) error { return fmt.Errorf("Already connected") } - initAuthMethodOnce.Do(initAuthMethod) + initAuthMethodOnce.Do(initAuthMethod(c.identityFile)) err := c.parseHost(host) if err != nil { diff --git a/sup.go b/sup.go index d815068..2f6db1f 100644 --- a/sup.go +++ b/sup.go @@ -70,9 +70,10 @@ func (sup *Stackup) Run(network *Network, envVars EnvList, commands ...*Command) // SSH client. remote := &SSHClient{ - env: env + `export SUP_HOST="` + host + `";`, - user: network.User, - color: Colors[i%len(Colors)], + env: env + `export SUP_HOST="` + host + `";`, + user: network.User, + color: Colors[i%len(Colors)], + identityFile: network.IdentityFile, } if bastion != nil { From 7e3e816d61a72bbdb3165913dc52c76a36d06c4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0t=C4=9Bp=C3=A1n=20Pila=C5=99?= Date: Wed, 7 Mar 2018 20:24:15 +0100 Subject: [PATCH 02/12] Apply SSH config host only to the matched host https://github.com/pressly/sup/issues/128 --- cmd/sup/main.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/sup/main.go b/cmd/sup/main.go index 3a36928..da99e77 100644 --- a/cmd/sup/main.go +++ b/cmd/sup/main.go @@ -303,12 +303,12 @@ func main() { } // check network.Hosts for match - for _, host := range network.Hosts { + for i, host := range network.Hosts { conf, found := confMap[host] if found { network.User = conf.User network.IdentityFile = resolvePath(conf.IdentityFile) - network.Hosts = []string{fmt.Sprintf("%s:%d", conf.HostName, conf.Port)} + network.Hosts[i] = fmt.Sprintf("%s:%d", conf.HostName, conf.Port) } } } From 60218c8ab4c72d2ee8a209539d6d122cb1490ff9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0t=C4=9Bp=C3=A1n=20Pila=C5=99?= Date: Wed, 7 Mar 2018 15:35:25 +0100 Subject: [PATCH 03/12] Setup mock server for tests and minimal code to run it --- cmd/sup/main_test.go | 99 +++++++++++++ cmd/sup/mock_server_test.go | 275 ++++++++++++++++++++++++++++++++++++ 2 files changed, 374 insertions(+) create mode 100644 cmd/sup/main_test.go create mode 100644 cmd/sup/mock_server_test.go diff --git a/cmd/sup/main_test.go b/cmd/sup/main_test.go new file mode 100644 index 0000000..8c8f6e7 --- /dev/null +++ b/cmd/sup/main_test.go @@ -0,0 +1,99 @@ +package main + +import ( + "flag" + "fmt" + "os" + "testing" + + "github.com/mikkeloscar/sshconfig" + "github.com/pressly/sup" +) + +func TestSSH(t *testing.T) { + outputs, err := setupMockEnv(sshConfigFilename, 3) + if err != nil { + t.Fatal(err) + } + + flag.CommandLine = flag.NewFlagSet("test", flag.ExitOnError) + flag.CommandLine.Parse([]string{"local", "t"}) + + input := ` +--- +version: 0.4 + +networks: + local: + hosts: + - server0 + - server1 + - server2 + +commands: + test: + run: echo "Hey over there" + test2: + run: echo "Hey again" + +targets: + t: + - test + - test2 +` + conf, err := sup.NewSupfile([]byte(input)) + if err != nil { + t.Fatal(err) + } + + network, commands, err := parseArgs(conf) + if err != nil { + t.Fatal(err) + } + + confHosts, err := sshconfig.ParseSSHConfig(sshConfigFilename) + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } + + // flatten Host -> *SSHHost, not the prettiest + // but will do + confMap := map[string]*sshconfig.SSHHost{} + for _, conf := range confHosts { + for _, host := range conf.Host { + confMap[host] = conf + } + } + + // check network.Hosts for match + for i, host := range network.Hosts { + conf, found := confMap[host] + if found { + network.User = conf.User + network.IdentityFile = resolvePath(conf.IdentityFile) + network.Hosts[i] = fmt.Sprintf("%s:%d", conf.HostName, conf.Port) + } + } + + var vars sup.EnvList + for _, val := range append(conf.Env, network.Env...) { + vars.Set(val.Key, val.Value) + } + if err := vars.ResolveValues(); err != nil { + t.Fatal(err) + } + + app, err := sup.New(conf) + if err != nil { + t.Fatal(err) + } + + err = app.Run(network, vars, commands...) + if err != nil { + t.Fatal(err) + } + for i, output := range outputs { + t.Log(i, output.String()) + } +} diff --git a/cmd/sup/mock_server_test.go b/cmd/sup/mock_server_test.go new file mode 100644 index 0000000..fe8e7e8 --- /dev/null +++ b/cmd/sup/mock_server_test.go @@ -0,0 +1,275 @@ +package main + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/pem" + "fmt" + "io" + "io/ioutil" + "net" + "os" + "strings" + "text/template" + + "github.com/pkg/errors" + "golang.org/x/crypto/ssh" +) + +const privateKeyFilename = "gotest_private_key" +const authorizedKeysFilename = "authorized_keys" +const sshConfigFilename = "ssh_config" + +// setupMockEnv prepares testing environment, it +// +// - generates RSA key pair +// - the private key is written into a file +// - fingerprint of the public key is written into a file as an authorized key +// - spins up mock SSH servers with the same authorized key +// - writes an SSH config file with entries for all servers +func setupMockEnv(sshConfigFilename string, count int) ([]bytes.Buffer, error) { + if err := generateKeyPair(privateKeyFilename, authorizedKeysFilename); err != nil { + return nil, err + } + + outputs := make([]bytes.Buffer, count) + addresses := make([]string, count) + for i := 0; i < count; i++ { + runTestServer(authorizedKeysFilename, &addresses[i], &outputs[i]) + } + + if err := writeSSHConfigFile(privateKeyFilename, sshConfigFilename, addresses); err != nil { + return nil, err + } + return outputs, nil +} + +// generateKeyPair generates a pair of keys, the private key is written into +// a file and the fingerprint of the public key into authorized_keys file for +// the server to use +func generateKeyPair(privateKeyFilename, authorizedKeysFilename string) error { + privateKey, err := generatePrivateRSAKey() + if err != nil { + return err + } + if err := writePrivateKeyToFile(privateKey, privateKeyFilename); err != nil { + return err + } + + publicKey := privateKey.PublicKey + pub, err := ssh.NewPublicKey(&publicKey) + if err != nil { + return err + } + + return ioutil.WriteFile(authorizedKeysFilename, ssh.MarshalAuthorizedKey(pub), os.ModePerm) +} + +func generatePrivateRSAKey() (*rsa.PrivateKey, error) { + return rsa.GenerateKey(rand.Reader, 2014) +} + +func writePrivateKeyToFile(privateKey *rsa.PrivateKey, filename string) error { + privateKeyBlock := pem.Block{ + Type: "RSA PRIVATE KEY", + Headers: nil, + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + } + return ioutil.WriteFile( + filename, + pem.EncodeToMemory(&privateKeyBlock), + os.ModePerm, + ) +} + +func runTestServer(authorizedKeysFilename string, addr *string, out io.Writer) error { + authorizedKeysMap, err := loadAuthorizedKeys(authorizedKeysFilename) + if err != nil { + return err + } + + config, err := buildServerConfig(authorizedKeysMap) + if err != nil { + return err + } + + listener, err := net.Listen("tcp", "localhost:") + if err != nil { + return errors.Wrap(err, "failed to listen for connection") + } + *addr = listener.Addr().String() + + go sshListen(config, listener, out) + + return nil +} + +func buildServerConfig(authorizedKeysMap map[string]bool) (*ssh.ServerConfig, error) { + // An SSH server is represented by a ServerConfig, which holds + // certificate details and handles authentication of ServerConns. + config := &ssh.ServerConfig{ + // Remove to disable public key auth. + PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { + if authorizedKeysMap[string(pubKey.Marshal())] { + return &ssh.Permissions{ + // Record the public key used for authentication. + Extensions: map[string]string{ + "pubkey-fp": fingerprintSHA256(pubKey), + }, + }, nil + } + return nil, fmt.Errorf("unknown public key for %q", c.User()) + }, + } + + key, err := generatePrivateRSAKey() + if err != nil { + return nil, err + } + + private, err := ssh.NewSignerFromKey(key) + if err != nil { + return nil, err + } + + config.AddHostKey(private) + return config, nil +} + +func sshListen(config *ssh.ServerConfig, listener net.Listener, out io.Writer) { + func() { + nConn, err := listener.Accept() + if err != nil { + panic(errors.Wrap(err, "failed to accept incoming connection")) + } + + // Before use, a handshake must be performed on the incoming + // net.Conn. + _, chans, reqs, err := ssh.NewServerConn(nConn, config) + if err != nil { + panic(errors.Wrap(err, "failed to handshake")) + } + + // The incoming Request channel must be serviced. + go ssh.DiscardRequests(reqs) + + // Service the incoming Channel channel. + for newChannel := range chans { + // Channels have a type, depending on the application level + // protocol intended. In the case of a shell, the type is + // "session" and ServerShell may be used to present a simple + // terminal interface. + if newChannel.ChannelType() != "session" { + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") + continue + } + channel, requests, err := newChannel.Accept() + if err != nil { + panic(errors.Wrap(err, "Could not accept channel")) + } + + go func(in <-chan *ssh.Request) { + defer channel.Close() + + for req := range in { + // reply to pty-req with success + if req.Type == "pty-req" { + req.Reply(true, []byte{}) + + // read exec command, write it to output and respond with success + } else if req.Type == "exec" { + type execMsg struct { + Command string + } + var payload execMsg + ssh.Unmarshal(req.Payload, &payload) + out.Write([]byte(payload.Command + "\n")) + req.Reply(true, nil) + + channel.SendRequest("exit-status", false, []byte{0, 0, 0, 0}) + if err := channel.Close(); err != nil { + panic(err) + } + } + } + }(requests) + } + }() +} + +func fingerprintSHA256(pubKey ssh.PublicKey) string { + sha256sum := sha256.Sum256(pubKey.Marshal()) + hash := base64.RawStdEncoding.EncodeToString(sha256sum[:]) + return "SHA256:" + hash +} + +func loadAuthorizedKeys(filename string) (map[string]bool, error) { + authorizedKeysBytes, err := ioutil.ReadFile(filename) + if err != nil { + return nil, errors.Wrapf(err, "failed to load %sv", filename) + } + authorizedKeysMap := map[string]bool{} + for len(authorizedKeysBytes) > 0 { + pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes) + if err != nil { + return nil, err + } + + authorizedKeysMap[string(pubKey.Marshal())] = true + authorizedKeysBytes = rest + } + return authorizedKeysMap, nil +} + +// writes simple SSH config file for the given servers naming them server0, +// server1 etc. +func writeSSHConfigFile(privateKeyFilename, sshConfigFilename string, addresses []string) error { + type sshRecord struct { + Host string + Port string + IdentityFilename string + } + records := make([]sshRecord, len(addresses)) + for i, addr := range addresses { + records[i].Host = fmt.Sprintf("server%d", i) + records[i].IdentityFilename = privateKeyFilename + records[i].Port = strings.Split(addr, ":")[1] + } + + sshConfigTemplate := ` +{{range .Records}} +Host {{.Host}} + HostName localhost + Port {{.Port}} + IdentityFile {{.IdentityFilename}} +{{end}} +` + + tmpl := template.New("ssh_config") + tmpl, err := tmpl.Parse(sshConfigTemplate) + if err != nil { + return err + } + + file, err := os.Create(sshConfigFilename) + if err != nil { + return err + } + defer file.Close() + + data := struct { + Records []sshRecord + }{ + Records: records, + } + + if err := tmpl.Execute(file, data); err != nil { + return err + } + + return nil +} From 53024daa06eef979956fbd9abfec418400a708f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0t=C4=9Bp=C3=A1n=20Pila=C5=99?= Date: Wed, 14 Mar 2018 20:25:08 +0100 Subject: [PATCH 04/12] Add matcher and expectations --- cmd/sup/main_test.go | 10 +++-- cmd/sup/matchers_test.go | 82 +++++++++++++++++++++++++++++++++++++ cmd/sup/mock_server_test.go | 8 ++-- 3 files changed, 92 insertions(+), 8 deletions(-) create mode 100644 cmd/sup/matchers_test.go diff --git a/cmd/sup/main_test.go b/cmd/sup/main_test.go index 8c8f6e7..10a78fc 100644 --- a/cmd/sup/main_test.go +++ b/cmd/sup/main_test.go @@ -27,7 +27,6 @@ networks: local: hosts: - server0 - - server1 - server2 commands: @@ -93,7 +92,10 @@ targets: if err != nil { t.Fatal(err) } - for i, output := range outputs { - t.Log(i, output.String()) - } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 2) + m.expectNoActivityOnServers(1) + m.expectExportOnActiveServers(`SUP_NETWORK="local"`) + m.expectCommandOnActiveServers(`echo "Hey over there"`) } diff --git a/cmd/sup/matchers_test.go b/cmd/sup/matchers_test.go new file mode 100644 index 0000000..7266402 --- /dev/null +++ b/cmd/sup/matchers_test.go @@ -0,0 +1,82 @@ +package main + +import ( + "bytes" + "fmt" + "strings" + "testing" +) + +// matcher defines high level expectations over a collection of output buffers +type matcher struct { + outputs []bytes.Buffer + t *testing.T + activeServers []int +} + +func newMatcher(outputs []bytes.Buffer, t *testing.T) matcher { + return matcher{ + outputs: outputs, + t: t, + } +} + +func (m *matcher) expectActivityOnServers(servers ...int) { + m.activeServers = servers + m.onEachActiveServer(func(server int, output string) { + if len(output) == 0 { + m.t.Errorf("expected activity on server #%d", server) + } + }) +} +func (m *matcher) expectNoActivityOnServers(servers ...int) { + for _, server := range servers { + if server >= len(m.outputs) || server < 0 { + m.t.Errorf("output from server #%d not provided", server) + return + } + output := m.outputs[server] + if output.Len() > 0 { + m.t.Errorf("expected no activity on server #%d:\n%s", server, output.String()) + } + } +} + +func (m matcher) expectExportOnActiveServers(export string) { + m.onEachActiveServer(func(server int, output string) { + for i, executed := range strings.Split(output, "\n") { + if !strings.Contains(executed, fmt.Sprintf("export %s;", export)) { + m.t.Errorf( + "command #%d on server #%d does not export `%s`:\n%s", + i, + server, + export, + executed, + ) + } + } + }) +} + +func (m matcher) expectCommandOnActiveServers(command string) { + m.onEachActiveServer(func(server int, output string) { + for _, executed := range strings.Split(output, "\n") { + if strings.HasSuffix(executed, fmt.Sprintf(";%s", command)) { + return + } + } + m.t.Errorf("no command on server #%d executed `%s`", server, command) + }) +} + +func (m matcher) onEachActiveServer(expectation func(server int, output string)) { + for _, server := range m.activeServers { + if server >= len(m.outputs) || server < 0 { + m.t.Errorf("output from server #%d not provided", server) + return + } + + output := m.outputs[server] + expectation(server, strings.TrimSpace(output.String())) + } +} diff --git a/cmd/sup/mock_server_test.go b/cmd/sup/mock_server_test.go index fe8e7e8..b1b5010 100644 --- a/cmd/sup/mock_server_test.go +++ b/cmd/sup/mock_server_test.go @@ -26,11 +26,11 @@ const sshConfigFilename = "ssh_config" // setupMockEnv prepares testing environment, it // -// - generates RSA key pair -// - the private key is written into a file -// - fingerprint of the public key is written into a file as an authorized key +// - generates RSA key pair; the private key is written into a file, +// fingerprint of the public key is written into a file as an authorized key // - spins up mock SSH servers with the same authorized key -// - writes an SSH config file with entries for all servers +// - writes an SSH config file with entries for all servers, naming them +// server0, server1 etc. func setupMockEnv(sshConfigFilename string, count int) ([]bytes.Buffer, error) { if err := generateKeyPair(privateKeyFilename, authorizedKeysFilename); err != nil { return nil, err From 8717c27c6893141a7b368fce15cc3b8473dad725 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0t=C4=9Bp=C3=A1n=20Pila=C5=99?= Date: Wed, 14 Mar 2018 20:53:27 +0100 Subject: [PATCH 05/12] Generate all test files into a tmp directory --- cmd/sup/main_test.go | 5 +-- cmd/sup/mock_server_test.go | 63 ++++++++++++++++++++++--------------- 2 files changed, 41 insertions(+), 27 deletions(-) diff --git a/cmd/sup/main_test.go b/cmd/sup/main_test.go index 10a78fc..00227bd 100644 --- a/cmd/sup/main_test.go +++ b/cmd/sup/main_test.go @@ -11,10 +11,11 @@ import ( ) func TestSSH(t *testing.T) { - outputs, err := setupMockEnv(sshConfigFilename, 3) + outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 3) if err != nil { t.Fatal(err) } + defer cleanup() flag.CommandLine = flag.NewFlagSet("test", flag.ExitOnError) flag.CommandLine.Parse([]string{"local", "t"}) @@ -50,7 +51,7 @@ targets: t.Fatal(err) } - confHosts, err := sshconfig.ParseSSHConfig(sshConfigFilename) + confHosts, err := sshconfig.ParseSSHConfig(sshConfigPath) if err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) diff --git a/cmd/sup/mock_server_test.go b/cmd/sup/mock_server_test.go index b1b5010..ed2cfec 100644 --- a/cmd/sup/mock_server_test.go +++ b/cmd/sup/mock_server_test.go @@ -18,45 +18,54 @@ import ( "github.com/pkg/errors" "golang.org/x/crypto/ssh" + "path" ) -const privateKeyFilename = "gotest_private_key" -const authorizedKeysFilename = "authorized_keys" -const sshConfigFilename = "ssh_config" - // setupMockEnv prepares testing environment, it // +// - creates a temporary directory for all files // - generates RSA key pair; the private key is written into a file, // fingerprint of the public key is written into a file as an authorized key // - spins up mock SSH servers with the same authorized key // - writes an SSH config file with entries for all servers, naming them // server0, server1 etc. -func setupMockEnv(sshConfigFilename string, count int) ([]bytes.Buffer, error) { - if err := generateKeyPair(privateKeyFilename, authorizedKeysFilename); err != nil { - return nil, err +func setupMockEnv(sshConfigFilename string, count int) ([]bytes.Buffer, string, func(), error) { + dir, err := ioutil.TempDir("", "suptest") + if err != nil { + return nil, "", nil, err + } + + privateKeyPath := path.Join(dir, "gotest_private_key") + authorizedKeysPath := path.Join(dir, "authorized_keys") + sshConfigPath := path.Join(dir, sshConfigFilename) + + if err := generateKeyPair(privateKeyPath, authorizedKeysPath); err != nil { + return nil, "", nil, err } outputs := make([]bytes.Buffer, count) addresses := make([]string, count) for i := 0; i < count; i++ { - runTestServer(authorizedKeysFilename, &addresses[i], &outputs[i]) + runTestServer(authorizedKeysPath, &addresses[i], &outputs[i]) } - if err := writeSSHConfigFile(privateKeyFilename, sshConfigFilename, addresses); err != nil { - return nil, err + err = writeSSHConfigFile(privateKeyPath, sshConfigPath, addresses) + if err != nil { + return nil, "", nil, err } - return outputs, nil + + return outputs, sshConfigPath, func() { os.RemoveAll(dir) }, nil } // generateKeyPair generates a pair of keys, the private key is written into // a file and the fingerprint of the public key into authorized_keys file for // the server to use -func generateKeyPair(privateKeyFilename, authorizedKeysFilename string) error { +func generateKeyPair(privateKeyPath, authorizedKeysPath string) error { privateKey, err := generatePrivateRSAKey() if err != nil { return err } - if err := writePrivateKeyToFile(privateKey, privateKeyFilename); err != nil { + if err := writePrivateKeyToFile(privateKey, privateKeyPath); err != nil { return err } @@ -66,28 +75,32 @@ func generateKeyPair(privateKeyFilename, authorizedKeysFilename string) error { return err } - return ioutil.WriteFile(authorizedKeysFilename, ssh.MarshalAuthorizedKey(pub), os.ModePerm) + return ioutil.WriteFile( + authorizedKeysPath, + ssh.MarshalAuthorizedKey(pub), + 0666, + ) } func generatePrivateRSAKey() (*rsa.PrivateKey, error) { return rsa.GenerateKey(rand.Reader, 2014) } -func writePrivateKeyToFile(privateKey *rsa.PrivateKey, filename string) error { +func writePrivateKeyToFile(privateKey *rsa.PrivateKey, filepath string) error { privateKeyBlock := pem.Block{ Type: "RSA PRIVATE KEY", Headers: nil, Bytes: x509.MarshalPKCS1PrivateKey(privateKey), } return ioutil.WriteFile( - filename, + filepath, pem.EncodeToMemory(&privateKeyBlock), - os.ModePerm, + 0666, ) } -func runTestServer(authorizedKeysFilename string, addr *string, out io.Writer) error { - authorizedKeysMap, err := loadAuthorizedKeys(authorizedKeysFilename) +func runTestServer(authorizedKeysPath string, addr *string, out io.Writer) error { + authorizedKeysMap, err := loadAuthorizedKeys(authorizedKeysPath) if err != nil { return err } @@ -207,10 +220,10 @@ func fingerprintSHA256(pubKey ssh.PublicKey) string { return "SHA256:" + hash } -func loadAuthorizedKeys(filename string) (map[string]bool, error) { - authorizedKeysBytes, err := ioutil.ReadFile(filename) +func loadAuthorizedKeys(filepath string) (map[string]bool, error) { + authorizedKeysBytes, err := ioutil.ReadFile(filepath) if err != nil { - return nil, errors.Wrapf(err, "failed to load %sv", filename) + return nil, errors.Wrapf(err, "failed to load %sv", filepath) } authorizedKeysMap := map[string]bool{} for len(authorizedKeysBytes) > 0 { @@ -227,7 +240,7 @@ func loadAuthorizedKeys(filename string) (map[string]bool, error) { // writes simple SSH config file for the given servers naming them server0, // server1 etc. -func writeSSHConfigFile(privateKeyFilename, sshConfigFilename string, addresses []string) error { +func writeSSHConfigFile(privateKeyPath, sshConfigPath string, addresses []string) error { type sshRecord struct { Host string Port string @@ -236,7 +249,7 @@ func writeSSHConfigFile(privateKeyFilename, sshConfigFilename string, addresses records := make([]sshRecord, len(addresses)) for i, addr := range addresses { records[i].Host = fmt.Sprintf("server%d", i) - records[i].IdentityFilename = privateKeyFilename + records[i].IdentityFilename = privateKeyPath records[i].Port = strings.Split(addr, ":")[1] } @@ -255,7 +268,7 @@ Host {{.Host}} return err } - file, err := os.Create(sshConfigFilename) + file, err := os.Create(sshConfigPath) if err != nil { return err } From 19696dad98d3534242399b63ec7761b43cd341f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0t=C4=9Bp=C3=A1n=20Pila=C5=99?= Date: Fri, 23 Mar 2018 13:51:28 +0100 Subject: [PATCH 06/12] Separate main processing code so it can be tested repeatedly --- cmd/sup/main.go | 113 +++++++++++++++++------------------- cmd/sup/main_test.go | 61 +------------------ cmd/sup/mock_server_test.go | 2 +- ssh.go | 70 +++++++++++----------- 4 files changed, 90 insertions(+), 156 deletions(-) diff --git a/cmd/sup/main.go b/cmd/sup/main.go index da99e77..a4fa7de 100644 --- a/cmd/sup/main.go +++ b/cmd/sup/main.go @@ -18,14 +18,8 @@ import ( ) var ( - supfile string - envVars flagStringSlice - sshConfig string - onlyHosts string - exceptHosts string - - debug bool - disablePrefix bool + supfile string + flags options showVersion bool showHelp bool @@ -49,22 +43,32 @@ func (f *flagStringSlice) Set(value string) error { return nil } +type options struct { + envVars flagStringSlice + sshConfig string + onlyHosts string + exceptHosts string + debug bool + disablePrefix bool +} + func init() { flag.StringVar(&supfile, "f", "", "Custom path to ./Supfile[.yml]") - flag.Var(&envVars, "e", "Set environment variables") - flag.Var(&envVars, "env", "Set environment variables") - flag.StringVar(&sshConfig, "sshconfig", "", "Read SSH Config file, ie. ~/.ssh/config file") - flag.StringVar(&onlyHosts, "only", "", "Filter hosts using regexp") - flag.StringVar(&exceptHosts, "except", "", "Filter out hosts using regexp") - - flag.BoolVar(&debug, "D", false, "Enable debug mode") - flag.BoolVar(&debug, "debug", false, "Enable debug mode") - flag.BoolVar(&disablePrefix, "disable-prefix", false, "Disable hostname prefix") flag.BoolVar(&showVersion, "v", false, "Print version") flag.BoolVar(&showVersion, "version", false, "Print version") flag.BoolVar(&showHelp, "h", false, "Show help") flag.BoolVar(&showHelp, "help", false, "Show help") + + flag.Var(&flags.envVars, "e", "Set environment variables") + flag.Var(&flags.envVars, "env", "Set environment variables") + flag.StringVar(&flags.sshConfig, "sshconfig", "", "Read SSH Config file, ie. ~/.ssh/config file") + flag.StringVar(&flags.onlyHosts, "only", "", "Filter hosts using regexp") + flag.StringVar(&flags.exceptHosts, "except", "", "Filter out hosts using regexp") + flag.BoolVar(&flags.debug, "D", false, "Enable debug mode") + flag.BoolVar(&flags.debug, "debug", false, "Enable debug mode") + flag.BoolVar(&flags.disablePrefix, "disable-prefix", false, "Disable hostname prefix") + } func networkUsage(conf *sup.Supfile) { @@ -106,10 +110,9 @@ func cmdUsage(conf *sup.Supfile) { // parseArgs parses args and returns network and commands to be run. // On error, it prints usage and exits. -func parseArgs(conf *sup.Supfile) (*sup.Network, []*sup.Command, error) { +func parseArgs(args []string, conf *sup.Supfile) (*sup.Network, []*sup.Command, error) { var commands []*sup.Command - args := flag.Args() if len(args) < 1 { networkUsage(conf) return nil, nil, ErrUsage @@ -230,25 +233,29 @@ func main() { os.Exit(1) } } - conf, err := sup.NewSupfile(data) - if err != nil { + + if err := runSupfile(flags, flag.Args(), data); err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } +} - // Parse network and commands to be run from args. - network, commands, err := parseArgs(conf) +func runSupfile(options options, args []string, data []byte) error { + conf, err := sup.NewSupfile(data) if err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) + return err } - + // Parse network and commands to be run from flags. + network, commands, err := parseArgs(args, conf) + if err != nil { + return err + } + fmt.Println(network.Hosts) // --only flag filters hosts - if onlyHosts != "" { - expr, err := regexp.CompilePOSIX(onlyHosts) + if options.onlyHosts != "" { + expr, err := regexp.CompilePOSIX(options.onlyHosts) if err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) + return err } var hosts []string @@ -258,39 +265,34 @@ func main() { } } if len(hosts) == 0 { - fmt.Fprintln(os.Stderr, fmt.Errorf("no hosts match --only '%v' regexp", onlyHosts)) - os.Exit(1) + return fmt.Errorf("no hosts match --only '%v' regexp", options.onlyHosts) } network.Hosts = hosts } - // --except flag filters out hosts - if exceptHosts != "" { - expr, err := regexp.CompilePOSIX(exceptHosts) + if options.exceptHosts != "" { + expr, err := regexp.CompilePOSIX(options.exceptHosts) if err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) + return err } var hosts []string for _, host := range network.Hosts { + fmt.Println(host) if !expr.MatchString(host) { hosts = append(hosts, host) } } if len(hosts) == 0 { - fmt.Fprintln(os.Stderr, fmt.Errorf("no hosts left after --except '%v' regexp", onlyHosts)) - os.Exit(1) + return fmt.Errorf("no hosts left after --except '%v' regexp", options.onlyHosts) } network.Hosts = hosts } - // --sshconfig flag location for ssh_config file - if sshConfig != "" { - confHosts, err := sshconfig.ParseSSHConfig(resolvePath(sshConfig)) + if options.sshConfig != "" { + confHosts, err := sshconfig.ParseSSHConfig(resolvePath(options.sshConfig)) if err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) + return err } // flatten Host -> *SSHHost, not the prettiest @@ -312,19 +314,16 @@ func main() { } } } - var vars sup.EnvList for _, val := range append(conf.Env, network.Env...) { vars.Set(val.Key, val.Value) } if err := vars.ResolveValues(); err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) + return err } - // Parse CLI --env flag env vars, define $SUP_ENV and override values defined in Supfile. var cliVars sup.EnvList - for _, env := range envVars { + for _, env := range options.envVars { if len(env) == 0 { continue } @@ -338,7 +337,6 @@ func main() { vars.Set(env[:i], env[i+1:]) cliVars.Set(env[:i], env[i+1:]) } - // SUP_ENV is generated only from CLI env vars. // Separate loop to omit duplicates. supEnv := "" @@ -346,20 +344,15 @@ func main() { supEnv += fmt.Sprintf(" -e %v=%q", v.Key, v.Value) } vars.Set("SUP_ENV", strings.TrimSpace(supEnv)) - // Create new Stackup app. app, err := sup.New(conf) if err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) + return err } - app.Debug(debug) - app.Prefix(!disablePrefix) + app.Debug(options.debug) + app.Prefix(!options.disablePrefix) // Run all the commands in the given network. - err = app.Run(network, vars, commands...) - if err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) - } + return app.Run(network, vars, commands...) + } diff --git a/cmd/sup/main_test.go b/cmd/sup/main_test.go index 00227bd..c334da2 100644 --- a/cmd/sup/main_test.go +++ b/cmd/sup/main_test.go @@ -1,13 +1,7 @@ package main import ( - "flag" - "fmt" - "os" "testing" - - "github.com/mikkeloscar/sshconfig" - "github.com/pressly/sup" ) func TestSSH(t *testing.T) { @@ -17,9 +11,6 @@ func TestSSH(t *testing.T) { } defer cleanup() - flag.CommandLine = flag.NewFlagSet("test", flag.ExitOnError) - flag.CommandLine.Parse([]string{"local", "t"}) - input := ` --- version: 0.4 @@ -41,56 +32,8 @@ targets: - test - test2 ` - conf, err := sup.NewSupfile([]byte(input)) - if err != nil { - t.Fatal(err) - } - - network, commands, err := parseArgs(conf) - if err != nil { - t.Fatal(err) - } - - confHosts, err := sshconfig.ParseSSHConfig(sshConfigPath) - if err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) - } - - // flatten Host -> *SSHHost, not the prettiest - // but will do - confMap := map[string]*sshconfig.SSHHost{} - for _, conf := range confHosts { - for _, host := range conf.Host { - confMap[host] = conf - } - } - - // check network.Hosts for match - for i, host := range network.Hosts { - conf, found := confMap[host] - if found { - network.User = conf.User - network.IdentityFile = resolvePath(conf.IdentityFile) - network.Hosts[i] = fmt.Sprintf("%s:%d", conf.HostName, conf.Port) - } - } - - var vars sup.EnvList - for _, val := range append(conf.Env, network.Env...) { - vars.Set(val.Key, val.Value) - } - if err := vars.ResolveValues(); err != nil { - t.Fatal(err) - } - - app, err := sup.New(conf) - if err != nil { - t.Fatal(err) - } - - err = app.Run(network, vars, commands...) - if err != nil { + args := []string{"--sshconfig", sshConfigPath, "local", "t"} + if err := runSupfile(args, []byte(input)); err != nil { t.Fatal(err) } diff --git a/cmd/sup/mock_server_test.go b/cmd/sup/mock_server_test.go index ed2cfec..202e264 100644 --- a/cmd/sup/mock_server_test.go +++ b/cmd/sup/mock_server_test.go @@ -13,12 +13,12 @@ import ( "io/ioutil" "net" "os" + "path" "strings" "text/template" "github.com/pkg/errors" "golang.org/x/crypto/ssh" - "path" ) // setupMockEnv prepares testing environment, it diff --git a/ssh.go b/ssh.go index 43c4546..225b13e 100644 --- a/ssh.go +++ b/ssh.go @@ -9,7 +9,6 @@ import ( "os/user" "path/filepath" "strings" - "sync" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" @@ -30,6 +29,7 @@ type SSHClient struct { running bool env string //export FOO="bar"; export BAR="baz"; color string + authMethod ssh.AuthMethod } type ErrConnect struct { @@ -77,44 +77,43 @@ func (c *SSHClient) parseHost(host string) error { return nil } -var initAuthMethodOnce sync.Once -var authMethod ssh.AuthMethod +// ensureAuthMethod initiates SSH authentication method. +func (c *SSHClient) ensureAuthMethod() { + if c.authMethod != nil { + return + } -// initAuthMethod initiates SSH authentication method. -func initAuthMethod(identityFilePath string) func() { - return func() { - var signers []ssh.Signer + var signers []ssh.Signer - // If there's a running SSH Agent, try to use its Private keys. - sock, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")) - if err == nil { - agent := agent.NewClient(sock) - signers, _ = agent.Signers() - } + // If there's a running SSH Agent, try to use its Private keys. + sock, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")) + if err == nil { + agent := agent.NewClient(sock) + signers, _ = agent.Signers() + } - // Try to read user's SSH private keys form the standard paths. - files, _ := filepath.Glob(os.Getenv("HOME") + "/.ssh/id_*") - // Add nonstandard path - if identityFilePath != "" { - files = append(files, identityFilePath) + // Try to read user's SSH private keys form the standard paths. + files, _ := filepath.Glob(os.Getenv("HOME") + "/.ssh/id_*") + // Add nonstandard path + if c.identityFile != "" { + files = append(files, c.identityFile) + } + for _, file := range files { + if strings.HasSuffix(file, ".pub") { + continue // Skip public keys. } - for _, file := range files { - if strings.HasSuffix(file, ".pub") { - continue // Skip public keys. - } - data, err := ioutil.ReadFile(file) - if err != nil { - continue - } - signer, err := ssh.ParsePrivateKey(data) - if err != nil { - continue - } - signers = append(signers, signer) - + data, err := ioutil.ReadFile(file) + if err != nil { + continue + } + signer, err := ssh.ParsePrivateKey(data) + if err != nil { + continue } - authMethod = ssh.PublicKeys(signers...) + signers = append(signers, signer) + } + c.authMethod = ssh.PublicKeys(signers...) } // SSHDialFunc can dial an ssh server and return a client @@ -133,8 +132,7 @@ func (c *SSHClient) ConnectWith(host string, dialer SSHDialFunc) error { if c.connOpened { return fmt.Errorf("Already connected") } - - initAuthMethodOnce.Do(initAuthMethod(c.identityFile)) + c.ensureAuthMethod() err := c.parseHost(host) if err != nil { @@ -144,7 +142,7 @@ func (c *SSHClient) ConnectWith(host string, dialer SSHDialFunc) error { config := &ssh.ClientConfig{ User: c.user, Auth: []ssh.AuthMethod{ - authMethod, + c.authMethod, }, } From 53c5d28ecf66e78b09cc27aca89cc1b29aa69e20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0t=C4=9Bp=C3=A1n=20Pila=C5=99?= Date: Fri, 23 Mar 2018 14:52:16 +0100 Subject: [PATCH 07/12] Reorder functions for readability --- cmd/sup/main.go | 269 ++++++++++++++++++++++++------------------------ 1 file changed, 134 insertions(+), 135 deletions(-) diff --git a/cmd/sup/main.go b/cmd/sup/main.go index a4fa7de..ccc3e43 100644 --- a/cmd/sup/main.go +++ b/cmd/sup/main.go @@ -71,141 +71,6 @@ func init() { } -func networkUsage(conf *sup.Supfile) { - w := &tabwriter.Writer{} - w.Init(os.Stderr, 4, 4, 2, ' ', 0) - defer w.Flush() - - // Print available networks/hosts. - fmt.Fprintln(w, "Networks:\t") - for _, name := range conf.Networks.Names { - fmt.Fprintf(w, "- %v\n", name) - network, _ := conf.Networks.Get(name) - for _, host := range network.Hosts { - fmt.Fprintf(w, "\t- %v\n", host) - } - } - fmt.Fprintln(w) -} - -func cmdUsage(conf *sup.Supfile) { - w := &tabwriter.Writer{} - w.Init(os.Stderr, 4, 4, 2, ' ', 0) - defer w.Flush() - - // Print available targets/commands. - fmt.Fprintln(w, "Targets:\t") - for _, name := range conf.Targets.Names { - cmds, _ := conf.Targets.Get(name) - fmt.Fprintf(w, "- %v\t%v\n", name, strings.Join(cmds, " ")) - } - fmt.Fprintln(w, "\t") - fmt.Fprintln(w, "Commands:\t") - for _, name := range conf.Commands.Names { - cmd, _ := conf.Commands.Get(name) - fmt.Fprintf(w, "- %v\t%v\n", name, cmd.Desc) - } - fmt.Fprintln(w) -} - -// parseArgs parses args and returns network and commands to be run. -// On error, it prints usage and exits. -func parseArgs(args []string, conf *sup.Supfile) (*sup.Network, []*sup.Command, error) { - var commands []*sup.Command - - if len(args) < 1 { - networkUsage(conf) - return nil, nil, ErrUsage - } - - // Does the exist? - network, ok := conf.Networks.Get(args[0]) - if !ok { - networkUsage(conf) - return nil, nil, ErrUnknownNetwork - } - - hosts, err := network.ParseInventory() - if err != nil { - return nil, nil, err - } - network.Hosts = append(network.Hosts, hosts...) - - // Does the have at least one host? - if len(network.Hosts) == 0 { - networkUsage(conf) - return nil, nil, ErrNetworkNoHosts - } - - // Check for the second argument - if len(args) < 2 { - cmdUsage(conf) - return nil, nil, ErrUsage - } - - // In case of the network.Env needs an initialization - if network.Env == nil { - network.Env = make(sup.EnvList, 0) - } - - // Add default env variable with current network - network.Env.Set("SUP_NETWORK", args[0]) - - // Add default nonce - network.Env.Set("SUP_TIME", time.Now().UTC().Format(time.RFC3339)) - if os.Getenv("SUP_TIME") != "" { - network.Env.Set("SUP_TIME", os.Getenv("SUP_TIME")) - } - - // Add user - if os.Getenv("SUP_USER") != "" { - network.Env.Set("SUP_USER", os.Getenv("SUP_USER")) - } else { - network.Env.Set("SUP_USER", os.Getenv("USER")) - } - - for _, cmd := range args[1:] { - // Target? - target, isTarget := conf.Targets.Get(cmd) - if isTarget { - // Loop over target's commands. - for _, cmd := range target { - command, isCommand := conf.Commands.Get(cmd) - if !isCommand { - cmdUsage(conf) - return nil, nil, fmt.Errorf("%v: %v", ErrCmd, cmd) - } - command.Name = cmd - commands = append(commands, &command) - } - } - - // Command? - command, isCommand := conf.Commands.Get(cmd) - if isCommand { - command.Name = cmd - commands = append(commands, &command) - } - - if !isTarget && !isCommand { - cmdUsage(conf) - return nil, nil, fmt.Errorf("%v: %v", ErrCmd, cmd) - } - } - - return &network, commands, nil -} - -func resolvePath(path string) string { - if path[:2] == "~/" { - usr, err := user.Current() - if err == nil { - path = filepath.Join(usr.HomeDir, path[2:]) - } - } - return path -} - func main() { flag.Parse() @@ -240,6 +105,16 @@ func main() { } } +func resolvePath(path string) string { + if path[:2] == "~/" { + usr, err := user.Current() + if err == nil { + path = filepath.Join(usr.HomeDir, path[2:]) + } + } + return path +} + func runSupfile(options options, args []string, data []byte) error { conf, err := sup.NewSupfile(data) if err != nil { @@ -354,5 +229,129 @@ func runSupfile(options options, args []string, data []byte) error { // Run all the commands in the given network. return app.Run(network, vars, commands...) +} + +// parseArgs parses args and returns network and commands to be run. +// On error, it prints usage and exits. +func parseArgs(args []string, conf *sup.Supfile) (*sup.Network, []*sup.Command, error) { + var commands []*sup.Command + + if len(args) < 1 { + networkUsage(conf) + return nil, nil, ErrUsage + } + + // Does the exist? + network, ok := conf.Networks.Get(args[0]) + if !ok { + networkUsage(conf) + return nil, nil, ErrUnknownNetwork + } + + hosts, err := network.ParseInventory() + if err != nil { + return nil, nil, err + } + network.Hosts = append(network.Hosts, hosts...) + + // Does the have at least one host? + if len(network.Hosts) == 0 { + networkUsage(conf) + return nil, nil, ErrNetworkNoHosts + } + + // Check for the second argument + if len(args) < 2 { + cmdUsage(conf) + return nil, nil, ErrUsage + } + + // In case of the network.Env needs an initialization + if network.Env == nil { + network.Env = make(sup.EnvList, 0) + } + + // Add default env variable with current network + network.Env.Set("SUP_NETWORK", args[0]) + + // Add default nonce + network.Env.Set("SUP_TIME", time.Now().UTC().Format(time.RFC3339)) + if os.Getenv("SUP_TIME") != "" { + network.Env.Set("SUP_TIME", os.Getenv("SUP_TIME")) + } + + // Add user + if os.Getenv("SUP_USER") != "" { + network.Env.Set("SUP_USER", os.Getenv("SUP_USER")) + } else { + network.Env.Set("SUP_USER", os.Getenv("USER")) + } + for _, cmd := range args[1:] { + // Target? + target, isTarget := conf.Targets.Get(cmd) + if isTarget { + // Loop over target's commands. + for _, cmd := range target { + command, isCommand := conf.Commands.Get(cmd) + if !isCommand { + cmdUsage(conf) + return nil, nil, fmt.Errorf("%v: %v", ErrCmd, cmd) + } + command.Name = cmd + commands = append(commands, &command) + } + } + + // Command? + command, isCommand := conf.Commands.Get(cmd) + if isCommand { + command.Name = cmd + commands = append(commands, &command) + } + + if !isTarget && !isCommand { + cmdUsage(conf) + return nil, nil, fmt.Errorf("%v: %v", ErrCmd, cmd) + } + } + + return &network, commands, nil +} + +func networkUsage(conf *sup.Supfile) { + w := &tabwriter.Writer{} + w.Init(os.Stderr, 4, 4, 2, ' ', 0) + defer w.Flush() + + // Print available networks/hosts. + fmt.Fprintln(w, "Networks:\t") + for _, name := range conf.Networks.Names { + fmt.Fprintf(w, "- %v\n", name) + network, _ := conf.Networks.Get(name) + for _, host := range network.Hosts { + fmt.Fprintf(w, "\t- %v\n", host) + } + } + fmt.Fprintln(w) +} + +func cmdUsage(conf *sup.Supfile) { + w := &tabwriter.Writer{} + w.Init(os.Stderr, 4, 4, 2, ' ', 0) + defer w.Flush() + + // Print available targets/commands. + fmt.Fprintln(w, "Targets:\t") + for _, name := range conf.Targets.Names { + cmds, _ := conf.Targets.Get(name) + fmt.Fprintf(w, "- %v\t%v\n", name, strings.Join(cmds, " ")) + } + fmt.Fprintln(w, "\t") + fmt.Fprintln(w, "Commands:\t") + for _, name := range conf.Commands.Names { + cmd, _ := conf.Commands.Get(name) + fmt.Fprintf(w, "- %v\t%v\n", name, cmd.Desc) + } + fmt.Fprintln(w) } From 32b524882040fb5006dcdd14d45e7612a668befa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0t=C4=9Bp=C3=A1n=20Pila=C5=99?= Date: Sat, 24 Mar 2018 11:42:07 +0100 Subject: [PATCH 08/12] Add tests for existing functionality as a sanity check --- cmd/sup/main.go | 3 +- cmd/sup/main_test.go | 881 ++++++++++++++++++++++++++++++++++++++- cmd/sup/matchers_test.go | 20 + 3 files changed, 892 insertions(+), 12 deletions(-) diff --git a/cmd/sup/main.go b/cmd/sup/main.go index ccc3e43..4b34c8f 100644 --- a/cmd/sup/main.go +++ b/cmd/sup/main.go @@ -125,7 +125,7 @@ func runSupfile(options options, args []string, data []byte) error { if err != nil { return err } - fmt.Println(network.Hosts) + // --only flag filters hosts if options.onlyHosts != "" { expr, err := regexp.CompilePOSIX(options.onlyHosts) @@ -153,7 +153,6 @@ func runSupfile(options options, args []string, data []byte) error { var hosts []string for _, host := range network.Hosts { - fmt.Println(host) if !expr.MatchString(host) { hosts = append(hosts, host) } diff --git a/cmd/sup/main_test.go b/cmd/sup/main_test.go index c334da2..0042460 100644 --- a/cmd/sup/main_test.go +++ b/cmd/sup/main_test.go @@ -1,10 +1,257 @@ package main import ( + "fmt" + "os" + "os/user" + "strings" "testing" + "time" ) -func TestSSH(t *testing.T) { +func TestInvalidYaml(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 +efewf we +we kp re +` + if err := runSupfile(options{}, []string{}, []byte(input)); err == nil { + t.Fatal("Expected an error") + } +} + +func TestNoNetworkSpecified(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server2 + +commands: + step1: + run: echo "Hey over there" +` + if err := runSupfile(options{}, []string{}, []byte(input)); err != ErrUsage { + t.Fatal(err) + } +} + +func TestUnknownNetwork(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server2 + +commands: + step1: + run: echo "Hey over there" +` + if err := runSupfile(options{}, []string{"production"}, []byte(input)); err != ErrUnknownNetwork { + t.Fatal(err) + } +} + +func TestNoHosts(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + +commands: + step1: + run: echo "Hey over there" +` + if err := runSupfile(options{}, []string{"staging"}, []byte(input)); err != ErrNetworkNoHosts { + t.Fatal(err) + } +} + +func TestNoCommand(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server2 + +commands: + step1: + run: echo "Hey over there" +` + if err := runSupfile(options{}, []string{"staging"}, []byte(input)); err != ErrUsage { + t.Fatal(err) + } +} + +func TestNonexistentCommandOrTarget(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server2 + +commands: + step1: + run: echo "Hey over there" + step2: + run: echo "Hey again" + +targets: + walk: + - step1 + - step2 +` + if err := runSupfile(options{}, []string{"staging", "step5"}, []byte(input)); err == nil { + t.Fatal("Expected an error") + } else if !strings.Contains(err.Error(), ErrCmd.Error()) { + t.Fatalf("Expected `%v` error, got `%v`", ErrCmd, err) + } +} + +func TestTargetReferencingNonexistentCommand(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server2 + +commands: + step1: + run: echo "Hey over there" + step2: + run: echo "Hey again" + +targets: + walk: + - step1 + - step2 + - step3 +` + if err := runSupfile(options{}, []string{"staging", "walk"}, []byte(input)); err == nil { + t.Fatal("Expected an error") + } else if !strings.Contains(err.Error(), ErrCmd.Error()) { + t.Fatalf("Expected `%v` error, got `%v`", ErrCmd, err) + } +} + +func TestOneCommand(t *testing.T) { + t.Parallel() + + outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 3) + if err != nil { + t.Fatal(err) + } + defer cleanup() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server2 + +commands: + step1: + run: echo "Hey over there" +` + options := options{ + sshConfig: sshConfigPath, + } + args := []string{"staging", "step1"} + if err := runSupfile(options, args, []byte(input)); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 2) + m.expectNoActivityOnServers(1) + m.expectCommandOnActiveServers(`echo "Hey over there"`) +} + +func TestSequenceOfCommands(t *testing.T) { + t.Parallel() + + outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 3) + if err != nil { + t.Fatal(err) + } + defer cleanup() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server2 + +commands: + step1: + run: echo "Hey over there" + step2: + run: echo "Hey again" +` + options := options{ + sshConfig: sshConfigPath, + } + args := []string{"staging", "step1", "step2"} + if err := runSupfile(options, args, []byte(input)); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 2) + m.expectNoActivityOnServers(1) + m.expectCommandOnActiveServers(`echo "Hey over there"`) + m.expectCommandOnActiveServers(`echo "Hey again"`) +} + +func TestTarget(t *testing.T) { + t.Parallel() + outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 3) if err != nil { t.Fatal(err) @@ -16,30 +263,644 @@ func TestSSH(t *testing.T) { version: 0.4 networks: - local: + staging: hosts: - server0 - server2 commands: - test: + step1: run: echo "Hey over there" - test2: + step2: run: echo "Hey again" targets: - t: - - test - - test2 + walk: + - step1 + - step2 ` - args := []string{"--sshconfig", sshConfigPath, "local", "t"} - if err := runSupfile(args, []byte(input)); err != nil { + options := options{ + sshConfig: sshConfigPath, + } + args := []string{"staging", "walk"} + if err := runSupfile(options, args, []byte(input)); err != nil { t.Fatal(err) } m := newMatcher(outputs, t) m.expectActivityOnServers(0, 2) m.expectNoActivityOnServers(1) - m.expectExportOnActiveServers(`SUP_NETWORK="local"`) + m.expectCommandOnActiveServers(`echo "Hey over there"`) + m.expectCommandOnActiveServers(`echo "Hey again"`) +} + +func TestOnlyHosts(t *testing.T) { + t.Parallel() + + outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 3) + if err != nil { + t.Fatal(err) + } + defer cleanup() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server1 + - server2 + +commands: + step1: + run: echo "Hey over there" +` + options := options{ + sshConfig: sshConfigPath, + onlyHosts: "server2", + } + args := []string{"staging", "step1"} + if err := runSupfile(options, args, []byte(input)); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(2) + m.expectNoActivityOnServers(0, 1) + m.expectCommandOnActiveServers(`echo "Hey over there"`) +} + +func TestOnlyHostsEmpty(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server1 + - server2 + +commands: + step1: + run: echo "Hey over there" +` + options := options{ + onlyHosts: "server42", + } + if err := runSupfile(options, []string{"staging", "step1"}, []byte(input)); err == nil { + t.Fatal("Expected an error") + } else if !strings.Contains(err.Error(), "no hosts match") { + t.Fatalf("Expected a different error, got `%v`", err) + } +} + +func TestOnlyHostsInvalid(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server1 + - server2 + +commands: + step1: + run: echo "Hey over there" +` + options := options{ + onlyHosts: "server(", + } + if err := runSupfile(options, []string{"staging", "step1"}, []byte(input)); err == nil { + t.Fatal("Expected an error") + } else if !strings.Contains(err.Error(), "error parsing regexp") { + t.Fatalf("Expected a different error, got `%v`", err) + } +} + +func TestExceptHosts(t *testing.T) { + t.Parallel() + + outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 3) + if err != nil { + t.Fatal(err) + } + defer cleanup() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server1 + - server2 + +commands: + step1: + run: echo "Hey over there" +` + options := options{ + sshConfig: sshConfigPath, + exceptHosts: "server(1|2)", + } + args := []string{"staging", "step1"} + if err := runSupfile(options, args, []byte(input)); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0) + m.expectNoActivityOnServers(1, 2) + m.expectCommandOnActiveServers(`echo "Hey over there"`) +} + +func TestExceptHostsEmpty(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server1 + - server2 + +commands: + step1: + run: echo "Hey over there" +` + options := options{ + exceptHosts: "server", + } + if err := runSupfile(options, []string{"staging", "step1"}, []byte(input)); err == nil { + t.Fatal("Expected an error") + } else if !strings.Contains(err.Error(), "no hosts left") { + t.Fatalf("Expected a different error, got `%v`", err) + } +} + +func TestExceptHostsInvalid(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server1 + - server2 + +commands: + step1: + run: echo "Hey over there" +` + options := options{ + exceptHosts: "server(", + } + if err := runSupfile(options, []string{"staging", "step1"}, []byte(input)); err == nil { + t.Fatal("Expected an error") + } else if !strings.Contains(err.Error(), "error parsing regexp") { + t.Fatalf("Expected a different error, got `%v`", err) + } +} + +func TestInventory(t *testing.T) { + t.Parallel() + + outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 3) + if err != nil { + t.Fatal(err) + } + defer cleanup() + + input := ` +--- +version: 0.4 + +networks: + staging: + inventory: array=( 0 2 ); for i in "${array[@]}"; do printf "server$i\n\n# comment\n"; done + +commands: + step1: + run: echo "Hey over there" +` + options := options{ + sshConfig: sshConfigPath, + } + args := []string{"staging", "step1"} + if err := runSupfile(options, args, []byte(input)); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 2) + m.expectNoActivityOnServers(1) + m.expectCommandOnActiveServers(`echo "Hey over there"`) +} + +func TestFailedInventory(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + inventory: this won't compile + +commands: + step1: + run: echo "Hey over there" +` + args := []string{"staging", "step1"} + if err := runSupfile(options{}, args, []byte(input)); err == nil { + t.Fatal("Expected an error") + } +} + +func TestSupVariables(t *testing.T) { + t.Parallel() + + // these tests need to run in order because they mess with env vars + t.Run("default", func(t *testing.T) { + if time.Now().Hour() == 23 && time.Now().Minute() == 59 { + t.Skip("Skipping test") + } + + outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 2) + if err != nil { + t.Fatal(err) + } + defer cleanup() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server1 + +commands: + step1: + run: echo "Hey over there" +` + options := options{ + sshConfig: sshConfigPath, + } + if err := runSupfile(options, []string{"staging", "step1"}, []byte(input)); err != nil { + t.Fatal(err) + } + currentUser, err := user.Current() + if err != nil { + t.Fatal(err) + } + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 1) + m.expectExportOnActiveServers(`SUP_NETWORK="staging"`) + m.expectExportOnActiveServers(`SUP_ENV=""`) + m.expectExportOnActiveServers(fmt.Sprintf(`SUP_USER="%s"`, currentUser.Name)) + m.expectExportRegexpOnActiveServers(`SUP_HOST="localhost:\d+"`) + }) + + t.Run("default SUP_TIME", func(t *testing.T) { + + if time.Now().Hour() == 23 && time.Now().Minute() == 59 { + t.Skip("Skipping SUP_TIME test because it might fail around midnight") + } + + outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 2) + if err != nil { + t.Fatal(err) + } + defer cleanup() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server1 + +commands: + step1: + run: echo "Hey over there" +` + options := options{ + sshConfig: sshConfigPath, + } + if err := runSupfile(options, []string{"staging", "step1"}, []byte(input)); err != nil { + t.Fatal(err) + } + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 1) + m.expectExportRegexpOnActiveServers( + fmt.Sprintf( + `SUP_TIME="%4d-%02d-%02dT[0-2][0-9]:[0-5][0-9]:[0-5][0-9]Z"`, + time.Now().Year(), + time.Now().Month(), + time.Now().Day(), + ), + ) + }) + + t.Run("SUP_TIME env var set", func(t *testing.T) { + outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 2) + if err != nil { + t.Fatal(err) + } + defer cleanup() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server1 + +commands: + step1: + run: echo "Hey over there" +` + os.Setenv("SUP_TIME", "now") + options := options{ + sshConfig: sshConfigPath, + } + if err := runSupfile(options, []string{"staging", "step1"}, []byte(input)); err != nil { + t.Fatal(err) + } + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 1) + m.expectExportOnActiveServers(`SUP_TIME="now"`) + }) + + t.Run("SUP_USER env var set", func(t *testing.T) { + outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 2) + if err != nil { + t.Fatal(err) + } + defer cleanup() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + - server1 + +commands: + step1: + run: echo "Hey over there" +` + os.Setenv("SUP_USER", "sup_rules") + options := options{ + sshConfig: sshConfigPath, + } + if err := runSupfile(options, []string{"staging", "step1"}, []byte(input)); err != nil { + t.Fatal(err) + } + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 1) + m.expectExportOnActiveServers(`SUP_USER="sup_rules"`) + }) +} + +func TestInvalidVars(t *testing.T) { + t.Parallel() + + _, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 2) + if err != nil { + t.Fatal(err) + } + defer cleanup() + + input := ` +--- +version: 0.4 + +env: + TODAYS_SPECIAL: this won't compile + +networks: + staging: + hosts: + - server0 + - server1 + +commands: + step1: + run: echo "Hey over there" +` + options := options{ + sshConfig: sshConfigPath, + } + args := []string{"staging", "step1"} + if err := runSupfile(options, args, []byte(input)); err == nil { + t.Fatal("Expected an error") + } + +} + +func TestEnvLevelVars(t *testing.T) { + t.Parallel() + + outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 2) + if err != nil { + t.Fatal(err) + } + defer cleanup() + + input := ` +--- +version: 0.4 + +env: + TODAYS_SPECIAL: "dog milk" + +networks: + staging: + hosts: + - server0 + - server1 + +commands: + step1: + run: echo "Hey over there" +` + options := options{ + sshConfig: sshConfigPath, + } + args := []string{"staging", "step1"} + if err := runSupfile(options, args, []byte(input)); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 1) + m.expectExportOnActiveServers(`TODAYS_SPECIAL="dog milk"`) + m.expectCommandOnActiveServers(`echo "Hey over there"`) +} + +func TestNetworkLevelVars(t *testing.T) { + t.Parallel() + + outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 2) + if err != nil { + t.Fatal(err) + } + defer cleanup() + + input := ` +--- +version: 0.4 + +env: + TODAYS_SPECIAL: "dog milk" + +networks: + staging: + env: + TODAYS_SPECIAL: "Trout a la Crème" + hosts: + - server0 + - server1 + +commands: + step1: + run: echo "Hey over there" +` + options := options{ + sshConfig: sshConfigPath, + } + args := []string{"staging", "step1"} + if err := runSupfile(options, args, []byte(input)); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 1) + m.expectExportOnActiveServers(`TODAYS_SPECIAL="Trout a la Crème"`) + m.expectCommandOnActiveServers(`echo "Hey over there"`) +} + +func TestCommandLineLevelVars(t *testing.T) { + t.Parallel() + + outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 2) + if err != nil { + t.Fatal(err) + } + defer cleanup() + + input := ` +--- +version: 0.4 + +env: + TODAYS_SPECIAL: "dog milk" + +networks: + staging: + env: + TODAYS_SPECIAL: "Trout a la Crème" + hosts: + - server0 + - server1 + +commands: + step1: + run: echo "Hey over there" +` + options := options{ + sshConfig: sshConfigPath, + envVars: []string{"IM_HERE", "TODAYS_SPECIAL=Gazpacho"}, + } + args := []string{"staging", "step1"} + if err := runSupfile(options, args, []byte(input)); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 1) + m.expectExportOnActiveServers(`IM_HERE=""`) + m.expectExportOnActiveServers(`TODAYS_SPECIAL="Gazpacho"`) + m.expectExportOnActiveServers(`SUP_ENV="-e TODAYS_SPECIAL="Gazpacho""`) + m.expectCommandOnActiveServers(`echo "Hey over there"`) +} + +func TestEmptyCommandLineLevelVars(t *testing.T) { + t.Parallel() + + outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 2) + if err != nil { + t.Fatal(err) + } + defer cleanup() + + input := ` +--- +version: 0.4 + +env: + TODAYS_SPECIAL: "dog milk" + +networks: + staging: + env: + TODAYS_SPECIAL: "Trout a la Crème" + hosts: + - server0 + - server1 + +commands: + step1: + run: echo "Hey over there" +` + options := options{ + sshConfig: sshConfigPath, + envVars: []string{""}, + } + args := []string{"staging", "step1"} + if err := runSupfile(options, args, []byte(input)); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 1) + m.expectExportOnActiveServers(`TODAYS_SPECIAL="Trout a la Crème"`) + m.expectExportOnActiveServers(`SUP_ENV=""`) m.expectCommandOnActiveServers(`echo "Hey over there"`) } diff --git a/cmd/sup/matchers_test.go b/cmd/sup/matchers_test.go index 7266402..fcd266a 100644 --- a/cmd/sup/matchers_test.go +++ b/cmd/sup/matchers_test.go @@ -3,6 +3,7 @@ package main import ( "bytes" "fmt" + "regexp" "strings" "testing" ) @@ -57,6 +58,25 @@ func (m matcher) expectExportOnActiveServers(export string) { } }) } +func (m matcher) expectExportRegexpOnActiveServers(export string) { + m.onEachActiveServer(func(server int, output string) { + for i, executed := range strings.Split(output, "\n") { + re, err := regexp.Compile(fmt.Sprintf("export %s;", export)) + if err != nil { + m.t.Fatal(err) + } + if !re.MatchString(executed) { + m.t.Errorf( + "command #%d on server #%d does not export `%s`:\n%s", + i, + server, + export, + executed, + ) + } + } + }) +} func (m matcher) expectCommandOnActiveServers(command string) { m.onEachActiveServer(func(server int, output string) { From 9332cda5949a3f373c457cb90aa6a7939b1d8f43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0t=C4=9Bp=C3=A1n=20Pila=C5=99?= Date: Sat, 24 Mar 2018 18:10:11 +0100 Subject: [PATCH 09/12] Pass error stream as an explicit dependency This allows for a clean test output. --- cmd/sup/main.go | 33 +++++++++++++------------ cmd/sup/main_test.go | 59 ++++++++++++++++++++++++-------------------- supfile.go | 8 +++--- 3 files changed, 53 insertions(+), 47 deletions(-) diff --git a/cmd/sup/main.go b/cmd/sup/main.go index 4b34c8f..bf8675f 100644 --- a/cmd/sup/main.go +++ b/cmd/sup/main.go @@ -15,6 +15,7 @@ import ( "github.com/mikkeloscar/sshconfig" "github.com/pkg/errors" "github.com/pressly/sup" + "io" ) var ( @@ -99,7 +100,7 @@ func main() { } } - if err := runSupfile(flags, flag.Args(), data); err != nil { + if err := runSupfile(os.Stderr, flags, flag.Args(), data); err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } @@ -115,13 +116,13 @@ func resolvePath(path string) string { return path } -func runSupfile(options options, args []string, data []byte) error { - conf, err := sup.NewSupfile(data) +func runSupfile(errStream io.Writer, options options, args []string, data []byte) error { + conf, err := sup.NewSupfile(errStream, data) if err != nil { return err } // Parse network and commands to be run from flags. - network, commands, err := parseArgs(args, conf) + network, commands, err := parseArgs(errStream, args, conf) if err != nil { return err } @@ -232,22 +233,22 @@ func runSupfile(options options, args []string, data []byte) error { // parseArgs parses args and returns network and commands to be run. // On error, it prints usage and exits. -func parseArgs(args []string, conf *sup.Supfile) (*sup.Network, []*sup.Command, error) { +func parseArgs(errStream io.Writer, args []string, conf *sup.Supfile) (*sup.Network, []*sup.Command, error) { var commands []*sup.Command if len(args) < 1 { - networkUsage(conf) + networkUsage(errStream, conf) return nil, nil, ErrUsage } // Does the exist? network, ok := conf.Networks.Get(args[0]) if !ok { - networkUsage(conf) + networkUsage(errStream, conf) return nil, nil, ErrUnknownNetwork } - hosts, err := network.ParseInventory() + hosts, err := network.ParseInventory(errStream) if err != nil { return nil, nil, err } @@ -255,13 +256,13 @@ func parseArgs(args []string, conf *sup.Supfile) (*sup.Network, []*sup.Command, // Does the have at least one host? if len(network.Hosts) == 0 { - networkUsage(conf) + networkUsage(errStream, conf) return nil, nil, ErrNetworkNoHosts } // Check for the second argument if len(args) < 2 { - cmdUsage(conf) + cmdUsage(errStream, conf) return nil, nil, ErrUsage } @@ -294,7 +295,7 @@ func parseArgs(args []string, conf *sup.Supfile) (*sup.Network, []*sup.Command, for _, cmd := range target { command, isCommand := conf.Commands.Get(cmd) if !isCommand { - cmdUsage(conf) + cmdUsage(errStream, conf) return nil, nil, fmt.Errorf("%v: %v", ErrCmd, cmd) } command.Name = cmd @@ -310,7 +311,7 @@ func parseArgs(args []string, conf *sup.Supfile) (*sup.Network, []*sup.Command, } if !isTarget && !isCommand { - cmdUsage(conf) + cmdUsage(errStream, conf) return nil, nil, fmt.Errorf("%v: %v", ErrCmd, cmd) } } @@ -318,9 +319,9 @@ func parseArgs(args []string, conf *sup.Supfile) (*sup.Network, []*sup.Command, return &network, commands, nil } -func networkUsage(conf *sup.Supfile) { +func networkUsage(errStream io.Writer, conf *sup.Supfile) { w := &tabwriter.Writer{} - w.Init(os.Stderr, 4, 4, 2, ' ', 0) + w.Init(errStream, 4, 4, 2, ' ', 0) defer w.Flush() // Print available networks/hosts. @@ -335,9 +336,9 @@ func networkUsage(conf *sup.Supfile) { fmt.Fprintln(w) } -func cmdUsage(conf *sup.Supfile) { +func cmdUsage(errStream io.Writer, conf *sup.Supfile) { w := &tabwriter.Writer{} - w.Init(os.Stderr, 4, 4, 2, ' ', 0) + w.Init(errStream, 4, 4, 2, ' ', 0) defer w.Flush() // Print available targets/commands. diff --git a/cmd/sup/main_test.go b/cmd/sup/main_test.go index 0042460..1ca7dec 100644 --- a/cmd/sup/main_test.go +++ b/cmd/sup/main_test.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "io/ioutil" "os" "os/user" "strings" @@ -9,6 +10,10 @@ import ( "time" ) +var ( + testErrStream = ioutil.Discard +) + func TestInvalidYaml(t *testing.T) { t.Parallel() @@ -18,7 +23,7 @@ version: 0.4 efewf we we kp re ` - if err := runSupfile(options{}, []string{}, []byte(input)); err == nil { + if err := runSupfile(testErrStream, options{}, []string{}, []byte(input)); err == nil { t.Fatal("Expected an error") } } @@ -40,7 +45,7 @@ commands: step1: run: echo "Hey over there" ` - if err := runSupfile(options{}, []string{}, []byte(input)); err != ErrUsage { + if err := runSupfile(testErrStream, options{}, []string{}, []byte(input)); err != ErrUsage { t.Fatal(err) } } @@ -62,7 +67,7 @@ commands: step1: run: echo "Hey over there" ` - if err := runSupfile(options{}, []string{"production"}, []byte(input)); err != ErrUnknownNetwork { + if err := runSupfile(testErrStream, options{}, []string{"production"}, []byte(input)); err != ErrUnknownNetwork { t.Fatal(err) } } @@ -82,7 +87,7 @@ commands: step1: run: echo "Hey over there" ` - if err := runSupfile(options{}, []string{"staging"}, []byte(input)); err != ErrNetworkNoHosts { + if err := runSupfile(testErrStream, options{}, []string{"staging"}, []byte(input)); err != ErrNetworkNoHosts { t.Fatal(err) } } @@ -104,7 +109,7 @@ commands: step1: run: echo "Hey over there" ` - if err := runSupfile(options{}, []string{"staging"}, []byte(input)); err != ErrUsage { + if err := runSupfile(testErrStream, options{}, []string{"staging"}, []byte(input)); err != ErrUsage { t.Fatal(err) } } @@ -133,7 +138,7 @@ targets: - step1 - step2 ` - if err := runSupfile(options{}, []string{"staging", "step5"}, []byte(input)); err == nil { + if err := runSupfile(testErrStream, options{}, []string{"staging", "step5"}, []byte(input)); err == nil { t.Fatal("Expected an error") } else if !strings.Contains(err.Error(), ErrCmd.Error()) { t.Fatalf("Expected `%v` error, got `%v`", ErrCmd, err) @@ -165,7 +170,7 @@ targets: - step2 - step3 ` - if err := runSupfile(options{}, []string{"staging", "walk"}, []byte(input)); err == nil { + if err := runSupfile(testErrStream, options{}, []string{"staging", "walk"}, []byte(input)); err == nil { t.Fatal("Expected an error") } else if !strings.Contains(err.Error(), ErrCmd.Error()) { t.Fatalf("Expected `%v` error, got `%v`", ErrCmd, err) @@ -199,7 +204,7 @@ commands: sshConfig: sshConfigPath, } args := []string{"staging", "step1"} - if err := runSupfile(options, args, []byte(input)); err != nil { + if err := runSupfile(testErrStream, options, args, []byte(input)); err != nil { t.Fatal(err) } @@ -238,7 +243,7 @@ commands: sshConfig: sshConfigPath, } args := []string{"staging", "step1", "step2"} - if err := runSupfile(options, args, []byte(input)); err != nil { + if err := runSupfile(testErrStream, options, args, []byte(input)); err != nil { t.Fatal(err) } @@ -283,7 +288,7 @@ targets: sshConfig: sshConfigPath, } args := []string{"staging", "walk"} - if err := runSupfile(options, args, []byte(input)); err != nil { + if err := runSupfile(testErrStream, options, args, []byte(input)); err != nil { t.Fatal(err) } @@ -323,7 +328,7 @@ commands: onlyHosts: "server2", } args := []string{"staging", "step1"} - if err := runSupfile(options, args, []byte(input)); err != nil { + if err := runSupfile(testErrStream, options, args, []byte(input)); err != nil { t.Fatal(err) } @@ -354,7 +359,7 @@ commands: options := options{ onlyHosts: "server42", } - if err := runSupfile(options, []string{"staging", "step1"}, []byte(input)); err == nil { + if err := runSupfile(testErrStream, options, []string{"staging", "step1"}, []byte(input)); err == nil { t.Fatal("Expected an error") } else if !strings.Contains(err.Error(), "no hosts match") { t.Fatalf("Expected a different error, got `%v`", err) @@ -382,7 +387,7 @@ commands: options := options{ onlyHosts: "server(", } - if err := runSupfile(options, []string{"staging", "step1"}, []byte(input)); err == nil { + if err := runSupfile(testErrStream, options, []string{"staging", "step1"}, []byte(input)); err == nil { t.Fatal("Expected an error") } else if !strings.Contains(err.Error(), "error parsing regexp") { t.Fatalf("Expected a different error, got `%v`", err) @@ -418,7 +423,7 @@ commands: exceptHosts: "server(1|2)", } args := []string{"staging", "step1"} - if err := runSupfile(options, args, []byte(input)); err != nil { + if err := runSupfile(testErrStream, options, args, []byte(input)); err != nil { t.Fatal(err) } @@ -449,7 +454,7 @@ commands: options := options{ exceptHosts: "server", } - if err := runSupfile(options, []string{"staging", "step1"}, []byte(input)); err == nil { + if err := runSupfile(testErrStream, options, []string{"staging", "step1"}, []byte(input)); err == nil { t.Fatal("Expected an error") } else if !strings.Contains(err.Error(), "no hosts left") { t.Fatalf("Expected a different error, got `%v`", err) @@ -477,7 +482,7 @@ commands: options := options{ exceptHosts: "server(", } - if err := runSupfile(options, []string{"staging", "step1"}, []byte(input)); err == nil { + if err := runSupfile(testErrStream, options, []string{"staging", "step1"}, []byte(input)); err == nil { t.Fatal("Expected an error") } else if !strings.Contains(err.Error(), "error parsing regexp") { t.Fatalf("Expected a different error, got `%v`", err) @@ -509,7 +514,7 @@ commands: sshConfig: sshConfigPath, } args := []string{"staging", "step1"} - if err := runSupfile(options, args, []byte(input)); err != nil { + if err := runSupfile(testErrStream, options, args, []byte(input)); err != nil { t.Fatal(err) } @@ -535,7 +540,7 @@ commands: run: echo "Hey over there" ` args := []string{"staging", "step1"} - if err := runSupfile(options{}, args, []byte(input)); err == nil { + if err := runSupfile(testErrStream, options{}, args, []byte(input)); err == nil { t.Fatal("Expected an error") } } @@ -572,7 +577,7 @@ commands: options := options{ sshConfig: sshConfigPath, } - if err := runSupfile(options, []string{"staging", "step1"}, []byte(input)); err != nil { + if err := runSupfile(testErrStream, options, []string{"staging", "step1"}, []byte(input)); err != nil { t.Fatal(err) } currentUser, err := user.Current() @@ -616,7 +621,7 @@ commands: options := options{ sshConfig: sshConfigPath, } - if err := runSupfile(options, []string{"staging", "step1"}, []byte(input)); err != nil { + if err := runSupfile(testErrStream, options, []string{"staging", "step1"}, []byte(input)); err != nil { t.Fatal(err) } m := newMatcher(outputs, t) @@ -656,7 +661,7 @@ commands: options := options{ sshConfig: sshConfigPath, } - if err := runSupfile(options, []string{"staging", "step1"}, []byte(input)); err != nil { + if err := runSupfile(testErrStream, options, []string{"staging", "step1"}, []byte(input)); err != nil { t.Fatal(err) } m := newMatcher(outputs, t) @@ -689,7 +694,7 @@ commands: options := options{ sshConfig: sshConfigPath, } - if err := runSupfile(options, []string{"staging", "step1"}, []byte(input)); err != nil { + if err := runSupfile(testErrStream, options, []string{"staging", "step1"}, []byte(input)); err != nil { t.Fatal(err) } m := newMatcher(outputs, t) @@ -728,7 +733,7 @@ commands: sshConfig: sshConfigPath, } args := []string{"staging", "step1"} - if err := runSupfile(options, args, []byte(input)); err == nil { + if err := runSupfile(testErrStream, options, args, []byte(input)); err == nil { t.Fatal("Expected an error") } @@ -764,7 +769,7 @@ commands: sshConfig: sshConfigPath, } args := []string{"staging", "step1"} - if err := runSupfile(options, args, []byte(input)); err != nil { + if err := runSupfile(testErrStream, options, args, []byte(input)); err != nil { t.Fatal(err) } @@ -806,7 +811,7 @@ commands: sshConfig: sshConfigPath, } args := []string{"staging", "step1"} - if err := runSupfile(options, args, []byte(input)); err != nil { + if err := runSupfile(testErrStream, options, args, []byte(input)); err != nil { t.Fatal(err) } @@ -849,7 +854,7 @@ commands: envVars: []string{"IM_HERE", "TODAYS_SPECIAL=Gazpacho"}, } args := []string{"staging", "step1"} - if err := runSupfile(options, args, []byte(input)); err != nil { + if err := runSupfile(testErrStream, options, args, []byte(input)); err != nil { t.Fatal(err) } @@ -894,7 +899,7 @@ commands: envVars: []string{""}, } args := []string{"staging", "step1"} - if err := runSupfile(options, args, []byte(input)); err != nil { + if err := runSupfile(testErrStream, options, args, []byte(input)); err != nil { t.Fatal(err) } diff --git a/supfile.go b/supfile.go index a116ae5..06864e8 100644 --- a/supfile.go +++ b/supfile.go @@ -255,7 +255,7 @@ func (e ErrUnsupportedSupfileVersion) Error() string { } // NewSupfile parses configuration file and returns Supfile or error. -func NewSupfile(data []byte) (*Supfile, error) { +func NewSupfile(errStream io.Writer, data []byte) (*Supfile, error) { var conf Supfile if err := yaml.Unmarshal(data, &conf); err != nil { @@ -305,7 +305,7 @@ func NewSupfile(data []byte) (*Supfile, error) { } } if warning != "" { - fmt.Fprintf(os.Stderr, warning) + fmt.Fprintf(errStream, warning) } fallthrough @@ -321,13 +321,13 @@ func NewSupfile(data []byte) (*Supfile, error) { // ParseInventory runs the inventory command, if provided, and appends // the command's output lines to the manually defined list of hosts. -func (n Network) ParseInventory() ([]string, error) { +func (n Network) ParseInventory(errStream io.Writer) ([]string, error) { if n.Inventory == "" { return nil, nil } cmd := exec.Command("/bin/sh", "-c", n.Inventory) - cmd.Stderr = os.Stderr + cmd.Stderr = errStream output, err := cmd.Output() if err != nil { return nil, err From 798e6c93c9589be3646a5db65d14f7960d9279b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0t=C4=9Bp=C3=A1n=20Pila=C5=99?= Date: Sun, 25 Mar 2018 11:04:57 +0200 Subject: [PATCH 10/12] Move file reading into testable zone --- cmd/sup/main.go | 74 ++-- cmd/sup/main_test.go | 848 ++++++++++++++++++++++-------------- cmd/sup/mock_server_test.go | 24 +- 3 files changed, 580 insertions(+), 366 deletions(-) diff --git a/cmd/sup/main.go b/cmd/sup/main.go index bf8675f..28a7113 100644 --- a/cmd/sup/main.go +++ b/cmd/sup/main.go @@ -19,8 +19,7 @@ import ( ) var ( - supfile string - flags options + flags options showVersion bool showHelp bool @@ -45,6 +44,8 @@ func (f *flagStringSlice) Set(value string) error { } type options struct { + dirname string + supfile string envVars flagStringSlice sshConfig string onlyHosts string @@ -54,13 +55,12 @@ type options struct { } func init() { - flag.StringVar(&supfile, "f", "", "Custom path to ./Supfile[.yml]") - flag.BoolVar(&showVersion, "v", false, "Print version") flag.BoolVar(&showVersion, "version", false, "Print version") flag.BoolVar(&showHelp, "h", false, "Show help") flag.BoolVar(&showHelp, "help", false, "Show help") + flag.StringVar(&flags.supfile, "f", "", "Custom path to ./Supfile[.yml]") flag.Var(&flags.envVars, "e", "Set environment variables") flag.Var(&flags.envVars, "env", "Set environment variables") flag.StringVar(&flags.sshConfig, "sshconfig", "", "Read SSH Config file, ie. ~/.ssh/config file") @@ -69,7 +69,6 @@ func init() { flag.BoolVar(&flags.debug, "D", false, "Enable debug mode") flag.BoolVar(&flags.debug, "debug", false, "Enable debug mode") flag.BoolVar(&flags.disablePrefix, "disable-prefix", false, "Disable hostname prefix") - } func main() { @@ -86,37 +85,15 @@ func main() { return } - if supfile == "" { - supfile = "./Supfile" - } - data, err := ioutil.ReadFile(resolvePath(supfile)) - if err != nil { - firstErr := err - data, err = ioutil.ReadFile("./Supfile.yml") // Alternative to ./Supfile. - if err != nil { - fmt.Fprintln(os.Stderr, firstErr) - fmt.Fprintln(os.Stderr, err) - os.Exit(1) - } - } - - if err := runSupfile(os.Stderr, flags, flag.Args(), data); err != nil { + if err := runSupfile(os.Stderr, flags, flag.Args()); err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } } -func resolvePath(path string) string { - if path[:2] == "~/" { - usr, err := user.Current() - if err == nil { - path = filepath.Join(usr.HomeDir, path[2:]) - } - } - return path -} +func runSupfile(errStream io.Writer, options options, args []string) error { + data, err := readSupfile(options) -func runSupfile(errStream io.Writer, options options, args []string, data []byte) error { conf, err := sup.NewSupfile(errStream, data) if err != nil { return err @@ -231,6 +208,43 @@ func runSupfile(errStream io.Writer, options options, args []string, data []byte return app.Run(network, vars, commands...) } +func readSupfile(options options) ([]byte, error) { + filename := options.supfile + if filename == "" { + filename = "Supfile" + } + + data, err := ioutil.ReadFile(defaultPath(options.dirname, filename)) + if err != nil { + firstErr := err + // Alternative to ./Supfile. + data, err = ioutil.ReadFile(defaultPath(options.dirname, "Supfile.yml")) + if err != nil { + return []byte{}, fmt.Errorf("%v\n%v", firstErr, err) + } + } + + return data, nil +} + +func defaultPath(dir, path string) string { + if dir == "" { + return resolvePath(path) + } + return filepath.Join(dir, path) + +} + +func resolvePath(path string) string { + if path[:2] == "~/" { + usr, err := user.Current() + if err == nil { + path = filepath.Join(usr.HomeDir, path[2:]) + } + } + return path +} + // parseArgs parses args and returns network and commands to be run. // On error, it prints usage and exits. func parseArgs(errStream io.Writer, args []string, conf *sup.Supfile) (*sup.Network, []*sup.Command, error) { diff --git a/cmd/sup/main_test.go b/cmd/sup/main_test.go index 1ca7dec..a966a37 100644 --- a/cmd/sup/main_test.go +++ b/cmd/sup/main_test.go @@ -5,6 +5,7 @@ import ( "io/ioutil" "os" "os/user" + "path/filepath" "strings" "testing" "time" @@ -23,9 +24,15 @@ version: 0.4 efewf we we kp re ` - if err := runSupfile(testErrStream, options{}, []string{}, []byte(input)); err == nil { - t.Fatal("Expected an error") - } + + withTmpDir(t, input, func(dirname string) { + options := options{ + dirname: dirname, + } + if err := runSupfile(testErrStream, options, []string{}); err == nil { + t.Fatal("Expected an error") + } + }) } func TestNoNetworkSpecified(t *testing.T) { @@ -45,9 +52,14 @@ commands: step1: run: echo "Hey over there" ` - if err := runSupfile(testErrStream, options{}, []string{}, []byte(input)); err != ErrUsage { - t.Fatal(err) - } + withTmpDir(t, input, func(dirname string) { + options := options{ + dirname: dirname, + } + if err := runSupfile(testErrStream, options, []string{}); err != ErrUsage { + t.Fatal(err) + } + }) } func TestUnknownNetwork(t *testing.T) { @@ -67,9 +79,14 @@ commands: step1: run: echo "Hey over there" ` - if err := runSupfile(testErrStream, options{}, []string{"production"}, []byte(input)); err != ErrUnknownNetwork { - t.Fatal(err) - } + withTmpDir(t, input, func(dirname string) { + options := options{ + dirname: dirname, + } + if err := runSupfile(testErrStream, options, []string{"production"}); err != ErrUnknownNetwork { + t.Fatal(err) + } + }) } func TestNoHosts(t *testing.T) { @@ -87,9 +104,14 @@ commands: step1: run: echo "Hey over there" ` - if err := runSupfile(testErrStream, options{}, []string{"staging"}, []byte(input)); err != ErrNetworkNoHosts { - t.Fatal(err) - } + withTmpDir(t, input, func(dirname string) { + options := options{ + dirname: dirname, + } + if err := runSupfile(testErrStream, options, []string{"staging"}); err != ErrNetworkNoHosts { + t.Fatal(err) + } + }) } func TestNoCommand(t *testing.T) { @@ -109,9 +131,14 @@ commands: step1: run: echo "Hey over there" ` - if err := runSupfile(testErrStream, options{}, []string{"staging"}, []byte(input)); err != ErrUsage { - t.Fatal(err) - } + withTmpDir(t, input, func(dirname string) { + options := options{ + dirname: dirname, + } + if err := runSupfile(testErrStream, options, []string{"staging"}); err != ErrUsage { + t.Fatal(err) + } + }) } func TestNonexistentCommandOrTarget(t *testing.T) { @@ -138,11 +165,16 @@ targets: - step1 - step2 ` - if err := runSupfile(testErrStream, options{}, []string{"staging", "step5"}, []byte(input)); err == nil { - t.Fatal("Expected an error") - } else if !strings.Contains(err.Error(), ErrCmd.Error()) { - t.Fatalf("Expected `%v` error, got `%v`", ErrCmd, err) - } + withTmpDir(t, input, func(dirname string) { + options := options{ + dirname: dirname, + } + if err := runSupfile(testErrStream, options, []string{"staging", "step5"}); err == nil { + t.Fatal("Expected an error") + } else if !strings.Contains(err.Error(), ErrCmd.Error()) { + t.Fatalf("Expected `%v` error, got `%v`", ErrCmd, err) + } + }) } func TestTargetReferencingNonexistentCommand(t *testing.T) { @@ -170,22 +202,21 @@ targets: - step2 - step3 ` - if err := runSupfile(testErrStream, options{}, []string{"staging", "walk"}, []byte(input)); err == nil { - t.Fatal("Expected an error") - } else if !strings.Contains(err.Error(), ErrCmd.Error()) { - t.Fatalf("Expected `%v` error, got `%v`", ErrCmd, err) - } + withTmpDir(t, input, func(dirname string) { + options := options{ + dirname: dirname, + } + if err := runSupfile(testErrStream, options, []string{"staging", "walk"}); err == nil { + t.Fatal("Expected an error") + } else if !strings.Contains(err.Error(), ErrCmd.Error()) { + t.Fatalf("Expected `%v` error, got `%v`", ErrCmd, err) + } + }) } func TestOneCommand(t *testing.T) { t.Parallel() - outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 3) - if err != nil { - t.Fatal(err) - } - defer cleanup() - input := ` --- version: 0.4 @@ -200,29 +231,27 @@ commands: step1: run: echo "Hey over there" ` - options := options{ - sshConfig: sshConfigPath, - } - args := []string{"staging", "step1"} - if err := runSupfile(testErrStream, options, args, []byte(input)); err != nil { - t.Fatal(err) - } + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 3) + if err != nil { + t.Fatal(err) + } + + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err != nil { + t.Fatal(err) + } - m := newMatcher(outputs, t) - m.expectActivityOnServers(0, 2) - m.expectNoActivityOnServers(1) - m.expectCommandOnActiveServers(`echo "Hey over there"`) + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 2) + m.expectNoActivityOnServers(1) + m.expectCommandOnActiveServers(`echo "Hey over there"`) + }) } func TestSequenceOfCommands(t *testing.T) { t.Parallel() - outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 3) - if err != nil { - t.Fatal(err) - } - defer cleanup() - input := ` --- version: 0.4 @@ -239,30 +268,28 @@ commands: step2: run: echo "Hey again" ` - options := options{ - sshConfig: sshConfigPath, - } - args := []string{"staging", "step1", "step2"} - if err := runSupfile(testErrStream, options, args, []byte(input)); err != nil { - t.Fatal(err) - } + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 3) + if err != nil { + t.Fatal(err) + } + + args := []string{"staging", "step1", "step2"} + if err := runSupfile(testErrStream, options, args); err != nil { + t.Fatal(err) + } - m := newMatcher(outputs, t) - m.expectActivityOnServers(0, 2) - m.expectNoActivityOnServers(1) - m.expectCommandOnActiveServers(`echo "Hey over there"`) - m.expectCommandOnActiveServers(`echo "Hey again"`) + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 2) + m.expectNoActivityOnServers(1) + m.expectCommandOnActiveServers(`echo "Hey over there"`) + m.expectCommandOnActiveServers(`echo "Hey again"`) + }) } func TestTarget(t *testing.T) { t.Parallel() - outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 3) - if err != nil { - t.Fatal(err) - } - defer cleanup() - input := ` --- version: 0.4 @@ -284,30 +311,28 @@ targets: - step1 - step2 ` - options := options{ - sshConfig: sshConfigPath, - } - args := []string{"staging", "walk"} - if err := runSupfile(testErrStream, options, args, []byte(input)); err != nil { - t.Fatal(err) - } + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 3) + if err != nil { + t.Fatal(err) + } - m := newMatcher(outputs, t) - m.expectActivityOnServers(0, 2) - m.expectNoActivityOnServers(1) - m.expectCommandOnActiveServers(`echo "Hey over there"`) - m.expectCommandOnActiveServers(`echo "Hey again"`) + args := []string{"staging", "walk"} + if err := runSupfile(testErrStream, options, args); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 2) + m.expectNoActivityOnServers(1) + m.expectCommandOnActiveServers(`echo "Hey over there"`) + m.expectCommandOnActiveServers(`echo "Hey again"`) + }) } func TestOnlyHosts(t *testing.T) { t.Parallel() - outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 3) - if err != nil { - t.Fatal(err) - } - defer cleanup() - input := ` --- version: 0.4 @@ -323,19 +348,23 @@ commands: step1: run: echo "Hey over there" ` - options := options{ - sshConfig: sshConfigPath, - onlyHosts: "server2", - } - args := []string{"staging", "step1"} - if err := runSupfile(testErrStream, options, args, []byte(input)); err != nil { - t.Fatal(err) - } + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 3) + if err != nil { + t.Fatal(err) + } + + options.onlyHosts = "server2" + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err != nil { + t.Fatal(err) + } - m := newMatcher(outputs, t) - m.expectActivityOnServers(2) - m.expectNoActivityOnServers(0, 1) - m.expectCommandOnActiveServers(`echo "Hey over there"`) + m := newMatcher(outputs, t) + m.expectActivityOnServers(2) + m.expectNoActivityOnServers(0, 1) + m.expectCommandOnActiveServers(`echo "Hey over there"`) + }) } func TestOnlyHostsEmpty(t *testing.T) { @@ -356,14 +385,17 @@ commands: step1: run: echo "Hey over there" ` - options := options{ - onlyHosts: "server42", - } - if err := runSupfile(testErrStream, options, []string{"staging", "step1"}, []byte(input)); err == nil { - t.Fatal("Expected an error") - } else if !strings.Contains(err.Error(), "no hosts match") { - t.Fatalf("Expected a different error, got `%v`", err) - } + withTmpDir(t, input, func(dirname string) { + options := options{ + dirname: dirname, + onlyHosts: "server42", + } + if err := runSupfile(testErrStream, options, []string{"staging", "step1"}); err == nil { + t.Fatal("Expected an error") + } else if !strings.Contains(err.Error(), "no hosts match") { + t.Fatalf("Expected a different error, got `%v`", err) + } + }) } func TestOnlyHostsInvalid(t *testing.T) { @@ -384,25 +416,22 @@ commands: step1: run: echo "Hey over there" ` - options := options{ - onlyHosts: "server(", - } - if err := runSupfile(testErrStream, options, []string{"staging", "step1"}, []byte(input)); err == nil { - t.Fatal("Expected an error") - } else if !strings.Contains(err.Error(), "error parsing regexp") { - t.Fatalf("Expected a different error, got `%v`", err) - } + withTmpDir(t, input, func(dirname string) { + options := options{ + dirname: dirname, + onlyHosts: "server(", + } + if err := runSupfile(testErrStream, options, []string{"staging", "step1"}); err == nil { + t.Fatal("Expected an error") + } else if !strings.Contains(err.Error(), "error parsing regexp") { + t.Fatalf("Expected a different error, got `%v`", err) + } + }) } func TestExceptHosts(t *testing.T) { t.Parallel() - outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 3) - if err != nil { - t.Fatal(err) - } - defer cleanup() - input := ` --- version: 0.4 @@ -418,19 +447,22 @@ commands: step1: run: echo "Hey over there" ` - options := options{ - sshConfig: sshConfigPath, - exceptHosts: "server(1|2)", - } - args := []string{"staging", "step1"} - if err := runSupfile(testErrStream, options, args, []byte(input)); err != nil { - t.Fatal(err) - } + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 3) + if err != nil { + t.Fatal(err) + } + options.exceptHosts = "server(1|2)" + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err != nil { + t.Fatal(err) + } - m := newMatcher(outputs, t) - m.expectActivityOnServers(0) - m.expectNoActivityOnServers(1, 2) - m.expectCommandOnActiveServers(`echo "Hey over there"`) + m := newMatcher(outputs, t) + m.expectActivityOnServers(0) + m.expectNoActivityOnServers(1, 2) + m.expectCommandOnActiveServers(`echo "Hey over there"`) + }) } func TestExceptHostsEmpty(t *testing.T) { @@ -451,14 +483,17 @@ commands: step1: run: echo "Hey over there" ` - options := options{ - exceptHosts: "server", - } - if err := runSupfile(testErrStream, options, []string{"staging", "step1"}, []byte(input)); err == nil { - t.Fatal("Expected an error") - } else if !strings.Contains(err.Error(), "no hosts left") { - t.Fatalf("Expected a different error, got `%v`", err) - } + withTmpDir(t, input, func(dirname string) { + options := options{ + dirname: dirname, + exceptHosts: "server", + } + if err := runSupfile(testErrStream, options, []string{"staging", "step1"}); err == nil { + t.Fatal("Expected an error") + } else if !strings.Contains(err.Error(), "no hosts left") { + t.Fatalf("Expected a different error, got `%v`", err) + } + }) } func TestExceptHostsInvalid(t *testing.T) { @@ -479,25 +514,22 @@ commands: step1: run: echo "Hey over there" ` - options := options{ - exceptHosts: "server(", - } - if err := runSupfile(testErrStream, options, []string{"staging", "step1"}, []byte(input)); err == nil { - t.Fatal("Expected an error") - } else if !strings.Contains(err.Error(), "error parsing regexp") { - t.Fatalf("Expected a different error, got `%v`", err) - } + withTmpDir(t, input, func(dirname string) { + options := options{ + dirname: dirname, + exceptHosts: "server(", + } + if err := runSupfile(testErrStream, options, []string{"staging", "step1"}); err == nil { + t.Fatal("Expected an error") + } else if !strings.Contains(err.Error(), "error parsing regexp") { + t.Fatalf("Expected a different error, got `%v`", err) + } + }) } func TestInventory(t *testing.T) { t.Parallel() - outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 3) - if err != nil { - t.Fatal(err) - } - defer cleanup() - input := ` --- version: 0.4 @@ -510,18 +542,22 @@ commands: step1: run: echo "Hey over there" ` - options := options{ - sshConfig: sshConfigPath, - } - args := []string{"staging", "step1"} - if err := runSupfile(testErrStream, options, args, []byte(input)); err != nil { - t.Fatal(err) - } + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 3) + if err != nil { + t.Fatal(err) + } - m := newMatcher(outputs, t) - m.expectActivityOnServers(0, 2) - m.expectNoActivityOnServers(1) - m.expectCommandOnActiveServers(`echo "Hey over there"`) + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 2) + m.expectNoActivityOnServers(1) + m.expectCommandOnActiveServers(`echo "Hey over there"`) + }) } func TestFailedInventory(t *testing.T) { @@ -539,10 +575,15 @@ commands: step1: run: echo "Hey over there" ` - args := []string{"staging", "step1"} - if err := runSupfile(testErrStream, options{}, args, []byte(input)); err == nil { - t.Fatal("Expected an error") - } + withTmpDir(t, input, func(dirname string) { + options := options{ + dirname: dirname, + } + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err == nil { + t.Fatal("Expected an error") + } + }) } func TestSupVariables(t *testing.T) { @@ -554,12 +595,6 @@ func TestSupVariables(t *testing.T) { t.Skip("Skipping test") } - outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 2) - if err != nil { - t.Fatal(err) - } - defer cleanup() - input := ` --- version: 0.4 @@ -574,22 +609,26 @@ commands: step1: run: echo "Hey over there" ` - options := options{ - sshConfig: sshConfigPath, - } - if err := runSupfile(testErrStream, options, []string{"staging", "step1"}, []byte(input)); err != nil { - t.Fatal(err) - } - currentUser, err := user.Current() - if err != nil { - t.Fatal(err) - } - m := newMatcher(outputs, t) - m.expectActivityOnServers(0, 1) - m.expectExportOnActiveServers(`SUP_NETWORK="staging"`) - m.expectExportOnActiveServers(`SUP_ENV=""`) - m.expectExportOnActiveServers(fmt.Sprintf(`SUP_USER="%s"`, currentUser.Name)) - m.expectExportRegexpOnActiveServers(`SUP_HOST="localhost:\d+"`) + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 2) + if err != nil { + t.Fatal(err) + } + + if err := runSupfile(testErrStream, options, []string{"staging", "step1"}); err != nil { + t.Fatal(err) + } + currentUser, err := user.Current() + if err != nil { + t.Fatal(err) + } + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 1) + m.expectExportOnActiveServers(`SUP_NETWORK="staging"`) + m.expectExportOnActiveServers(`SUP_ENV=""`) + m.expectExportOnActiveServers(fmt.Sprintf(`SUP_USER="%s"`, currentUser.Name)) + m.expectExportRegexpOnActiveServers(`SUP_HOST="localhost:\d+"`) + }) }) t.Run("default SUP_TIME", func(t *testing.T) { @@ -598,12 +637,6 @@ commands: t.Skip("Skipping SUP_TIME test because it might fail around midnight") } - outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 2) - if err != nil { - t.Fatal(err) - } - defer cleanup() - input := ` --- version: 0.4 @@ -618,30 +651,29 @@ commands: step1: run: echo "Hey over there" ` - options := options{ - sshConfig: sshConfigPath, - } - if err := runSupfile(testErrStream, options, []string{"staging", "step1"}, []byte(input)); err != nil { - t.Fatal(err) - } - m := newMatcher(outputs, t) - m.expectActivityOnServers(0, 1) - m.expectExportRegexpOnActiveServers( - fmt.Sprintf( - `SUP_TIME="%4d-%02d-%02dT[0-2][0-9]:[0-5][0-9]:[0-5][0-9]Z"`, - time.Now().Year(), - time.Now().Month(), - time.Now().Day(), - ), - ) + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 2) + if err != nil { + t.Fatal(err) + } + + if err := runSupfile(testErrStream, options, []string{"staging", "step1"}); err != nil { + t.Fatal(err) + } + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 1) + m.expectExportRegexpOnActiveServers( + fmt.Sprintf( + `SUP_TIME="%4d-%02d-%02dT[0-2][0-9]:[0-5][0-9]:[0-5][0-9]Z"`, + time.Now().Year(), + time.Now().Month(), + time.Now().Day(), + ), + ) + }) }) t.Run("SUP_TIME env var set", func(t *testing.T) { - outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 2) - if err != nil { - t.Fatal(err) - } - defer cleanup() input := ` --- @@ -657,24 +689,23 @@ commands: step1: run: echo "Hey over there" ` - os.Setenv("SUP_TIME", "now") - options := options{ - sshConfig: sshConfigPath, - } - if err := runSupfile(testErrStream, options, []string{"staging", "step1"}, []byte(input)); err != nil { - t.Fatal(err) - } - m := newMatcher(outputs, t) - m.expectActivityOnServers(0, 1) - m.expectExportOnActiveServers(`SUP_TIME="now"`) + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 2) + if err != nil { + t.Fatal(err) + } + os.Setenv("SUP_TIME", "now") + + if err := runSupfile(testErrStream, options, []string{"staging", "step1"}); err != nil { + t.Fatal(err) + } + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 1) + m.expectExportOnActiveServers(`SUP_TIME="now"`) + }) }) t.Run("SUP_USER env var set", func(t *testing.T) { - outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 2) - if err != nil { - t.Fatal(err) - } - defer cleanup() input := ` --- @@ -690,28 +721,26 @@ commands: step1: run: echo "Hey over there" ` - os.Setenv("SUP_USER", "sup_rules") - options := options{ - sshConfig: sshConfigPath, - } - if err := runSupfile(testErrStream, options, []string{"staging", "step1"}, []byte(input)); err != nil { - t.Fatal(err) - } - m := newMatcher(outputs, t) - m.expectActivityOnServers(0, 1) - m.expectExportOnActiveServers(`SUP_USER="sup_rules"`) + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 2) + if err != nil { + t.Fatal(err) + } + os.Setenv("SUP_USER", "sup_rules") + + if err := runSupfile(testErrStream, options, []string{"staging", "step1"}); err != nil { + t.Fatal(err) + } + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 1) + m.expectExportOnActiveServers(`SUP_USER="sup_rules"`) + }) }) } func TestInvalidVars(t *testing.T) { t.Parallel() - _, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 2) - if err != nil { - t.Fatal(err) - } - defer cleanup() - input := ` --- version: 0.4 @@ -729,25 +758,23 @@ commands: step1: run: echo "Hey over there" ` - options := options{ - sshConfig: sshConfigPath, - } - args := []string{"staging", "step1"} - if err := runSupfile(testErrStream, options, args, []byte(input)); err == nil { - t.Fatal("Expected an error") - } + withTmpDir(t, input, func(dirname string) { + _, options, err := setupMockEnv(dirname, 2) + if err != nil { + t.Fatal(err) + } + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err == nil { + t.Fatal("Expected an error") + } + + }) } func TestEnvLevelVars(t *testing.T) { t.Parallel() - outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 2) - if err != nil { - t.Fatal(err) - } - defer cleanup() - input := ` --- version: 0.4 @@ -765,29 +792,27 @@ commands: step1: run: echo "Hey over there" ` - options := options{ - sshConfig: sshConfigPath, - } - args := []string{"staging", "step1"} - if err := runSupfile(testErrStream, options, args, []byte(input)); err != nil { - t.Fatal(err) - } + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 2) + if err != nil { + t.Fatal(err) + } + + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err != nil { + t.Fatal(err) + } - m := newMatcher(outputs, t) - m.expectActivityOnServers(0, 1) - m.expectExportOnActiveServers(`TODAYS_SPECIAL="dog milk"`) - m.expectCommandOnActiveServers(`echo "Hey over there"`) + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 1) + m.expectExportOnActiveServers(`TODAYS_SPECIAL="dog milk"`) + m.expectCommandOnActiveServers(`echo "Hey over there"`) + }) } func TestNetworkLevelVars(t *testing.T) { t.Parallel() - outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 2) - if err != nil { - t.Fatal(err) - } - defer cleanup() - input := ` --- version: 0.4 @@ -807,29 +832,27 @@ commands: step1: run: echo "Hey over there" ` - options := options{ - sshConfig: sshConfigPath, - } - args := []string{"staging", "step1"} - if err := runSupfile(testErrStream, options, args, []byte(input)); err != nil { - t.Fatal(err) - } + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 2) + if err != nil { + t.Fatal(err) + } - m := newMatcher(outputs, t) - m.expectActivityOnServers(0, 1) - m.expectExportOnActiveServers(`TODAYS_SPECIAL="Trout a la Crème"`) - m.expectCommandOnActiveServers(`echo "Hey over there"`) + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 1) + m.expectExportOnActiveServers(`TODAYS_SPECIAL="Trout a la Crème"`) + m.expectCommandOnActiveServers(`echo "Hey over there"`) + }) } func TestCommandLineLevelVars(t *testing.T) { t.Parallel() - outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 2) - if err != nil { - t.Fatal(err) - } - defer cleanup() - input := ` --- version: 0.4 @@ -849,32 +872,29 @@ commands: step1: run: echo "Hey over there" ` - options := options{ - sshConfig: sshConfigPath, - envVars: []string{"IM_HERE", "TODAYS_SPECIAL=Gazpacho"}, - } - args := []string{"staging", "step1"} - if err := runSupfile(testErrStream, options, args, []byte(input)); err != nil { - t.Fatal(err) - } + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 2) + if err != nil { + t.Fatal(err) + } + options.envVars = []string{"IM_HERE", "TODAYS_SPECIAL=Gazpacho"} + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err != nil { + t.Fatal(err) + } - m := newMatcher(outputs, t) - m.expectActivityOnServers(0, 1) - m.expectExportOnActiveServers(`IM_HERE=""`) - m.expectExportOnActiveServers(`TODAYS_SPECIAL="Gazpacho"`) - m.expectExportOnActiveServers(`SUP_ENV="-e TODAYS_SPECIAL="Gazpacho""`) - m.expectCommandOnActiveServers(`echo "Hey over there"`) + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 1) + m.expectExportOnActiveServers(`IM_HERE=""`) + m.expectExportOnActiveServers(`TODAYS_SPECIAL="Gazpacho"`) + m.expectExportOnActiveServers(`SUP_ENV="-e TODAYS_SPECIAL="Gazpacho""`) + m.expectCommandOnActiveServers(`echo "Hey over there"`) + }) } func TestEmptyCommandLineLevelVars(t *testing.T) { t.Parallel() - outputs, sshConfigPath, cleanup, err := setupMockEnv("ssh_config", 2) - if err != nil { - t.Fatal(err) - } - defer cleanup() - input := ` --- version: 0.4 @@ -894,18 +914,198 @@ commands: step1: run: echo "Hey over there" ` - options := options{ - sshConfig: sshConfigPath, - envVars: []string{""}, + withTmpDir(t, input, func(dirname string) { + outputs, options, err := setupMockEnv(dirname, 2) + if err != nil { + t.Fatal(err) + } + options.envVars = []string{""} + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0, 1) + m.expectExportOnActiveServers(`TODAYS_SPECIAL="Trout a la Crème"`) + m.expectExportOnActiveServers(`SUP_ENV=""`) + m.expectCommandOnActiveServers(`echo "Hey over there"`) + }) +} + +func TestFileOption(t *testing.T) { + t.Parallel() + + t.Run("fallbacks to Supfile.yml", func(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + +commands: + step1: + run: echo "Hey over there" +` + withTmpDirWithoutSupfile(t, func(dirname string) { + writeSupfileAs(dirname, "Supfile.yml", input) + + outputs, options, err := setupMockEnv(dirname, 1) + if err != nil { + t.Fatal(err) + } + + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0) + m.expectCommandOnActiveServers(`echo "Hey over there"`) + }) + }) + + t.Run("prefers Supfile over Supfile.yml when not specified", func(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + +commands: + step1: + run: echo "Default Supfile" +` + withTmpDir(t, input, func(dirname string) { + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + +commands: + step1: + run: echo "Supfile.yml" +` + writeSupfileAs(dirname, "Supfile.yml", input) + + outputs, options, err := setupMockEnv(dirname, 1) + if err != nil { + t.Fatal(err) + } + + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0) + m.expectCommandOnActiveServers(`echo "Default Supfile"`) + }) + }) + + t.Run("can specify arbitrary file", func(t *testing.T) { + t.Parallel() + + input := ` +--- +version: 0.4 + +networks: + staging: + hosts: + - server0 + +commands: + step1: + run: echo "Hey over there" +` + withTmpDirWithoutSupfile(t, func(dirname string) { + writeSupfileAs(dirname, "different_file_name", input) + + outputs, options, err := setupMockEnv(dirname, 1) + if err != nil { + t.Fatal(err) + } + + options.supfile = "different_file_name" + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err != nil { + t.Fatal(err) + } + + m := newMatcher(outputs, t) + m.expectActivityOnServers(0) + m.expectCommandOnActiveServers(`echo "Hey over there"`) + }) + }) + + t.Run("fails without a Supfile", func(t *testing.T) { + t.Parallel() + + withTmpDirWithoutSupfile(t, func(dirname string) { + options := options{ + dirname: dirname, + } + args := []string{"staging", "step1"} + if err := runSupfile(testErrStream, options, args); err == nil { + t.Fatal("Expected an error") + } + }) + }) +} + +func withTmpDir(t *testing.T, input string, test func(dirname string)) { + dirname, err := tmpDir() + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dirname) + + if err := writeDefaultSupfile(dirname, input); err != nil { + t.Fatal(err) } - args := []string{"staging", "step1"} - if err := runSupfile(testErrStream, options, args, []byte(input)); err != nil { + + test(dirname) +} + +func withTmpDirWithoutSupfile(t *testing.T, test func(dirname string)) { + dirname, err := tmpDir() + if err != nil { t.Fatal(err) } + defer os.RemoveAll(dirname) + + test(dirname) +} + +func tmpDir() (string, error) { + return ioutil.TempDir("", "suptest") +} + +func writeDefaultSupfile(dirname, input string) error { + return writeSupfileAs(dirname, "Supfile", input) +} - m := newMatcher(outputs, t) - m.expectActivityOnServers(0, 1) - m.expectExportOnActiveServers(`TODAYS_SPECIAL="Trout a la Crème"`) - m.expectExportOnActiveServers(`SUP_ENV=""`) - m.expectCommandOnActiveServers(`echo "Hey over there"`) +func writeSupfileAs(dirname, filename, input string) error { + return ioutil.WriteFile( + filepath.Join(dirname, filename), + []byte(input), + 0666, + ) } diff --git a/cmd/sup/mock_server_test.go b/cmd/sup/mock_server_test.go index 202e264..9fe30a8 100644 --- a/cmd/sup/mock_server_test.go +++ b/cmd/sup/mock_server_test.go @@ -29,18 +29,14 @@ import ( // - spins up mock SSH servers with the same authorized key // - writes an SSH config file with entries for all servers, naming them // server0, server1 etc. -func setupMockEnv(sshConfigFilename string, count int) ([]bytes.Buffer, string, func(), error) { - dir, err := ioutil.TempDir("", "suptest") - if err != nil { - return nil, "", nil, err - } +func setupMockEnv(dirname string, count int) ([]bytes.Buffer, options, error) { - privateKeyPath := path.Join(dir, "gotest_private_key") - authorizedKeysPath := path.Join(dir, "authorized_keys") - sshConfigPath := path.Join(dir, sshConfigFilename) + privateKeyPath := path.Join(dirname, "gotest_private_key") + authorizedKeysPath := path.Join(dirname, "authorized_keys") + sshConfigPath := path.Join(dirname, "ssh_config") if err := generateKeyPair(privateKeyPath, authorizedKeysPath); err != nil { - return nil, "", nil, err + return nil, options{}, err } outputs := make([]bytes.Buffer, count) @@ -49,12 +45,16 @@ func setupMockEnv(sshConfigFilename string, count int) ([]bytes.Buffer, string, runTestServer(authorizedKeysPath, &addresses[i], &outputs[i]) } - err = writeSSHConfigFile(privateKeyPath, sshConfigPath, addresses) + err := writeSSHConfigFile(privateKeyPath, sshConfigPath, addresses) if err != nil { - return nil, "", nil, err + return nil, options{}, err } - return outputs, sshConfigPath, func() { os.RemoveAll(dir) }, nil + options := options{ + sshConfig: sshConfigPath, + dirname: dirname, + } + return outputs, options, nil } // generateKeyPair generates a pair of keys, the private key is written into From bbaf7c50abd171e0bd51f089f98d9ef074007748 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0t=C4=9Bp=C3=A1n=20Pila=C5=99?= Date: Sun, 15 Apr 2018 11:43:58 +0200 Subject: [PATCH 11/12] Mock env in tests to simplify run on Travis CI --- cmd/sup/main.go | 19 +++++----- cmd/sup/main_test.go | 34 +++++++++++++----- cmd/sup/mock_server_test.go | 1 + vendor/github.com/adammck/venv/README.md | 40 +++++++++++++++++++++ vendor/github.com/adammck/venv/mock/mock.go | 38 ++++++++++++++++++++ vendor/github.com/adammck/venv/os/os.go | 28 +++++++++++++++ vendor/github.com/adammck/venv/venv.go | 21 +++++++++++ vendor/vendor.json | 18 ++++++++++ 8 files changed, 183 insertions(+), 16 deletions(-) create mode 100644 vendor/github.com/adammck/venv/README.md create mode 100644 vendor/github.com/adammck/venv/mock/mock.go create mode 100644 vendor/github.com/adammck/venv/os/os.go create mode 100644 vendor/github.com/adammck/venv/venv.go diff --git a/cmd/sup/main.go b/cmd/sup/main.go index 28a7113..08ac48a 100644 --- a/cmd/sup/main.go +++ b/cmd/sup/main.go @@ -3,6 +3,7 @@ package main import ( "flag" "fmt" + "io" "io/ioutil" "os" "os/user" @@ -12,10 +13,10 @@ import ( "text/tabwriter" "time" + "github.com/adammck/venv" "github.com/mikkeloscar/sshconfig" "github.com/pkg/errors" "github.com/pressly/sup" - "io" ) var ( @@ -52,6 +53,7 @@ type options struct { exceptHosts string debug bool disablePrefix bool + env venv.Env } func init() { @@ -85,6 +87,7 @@ func main() { return } + flags.env = venv.OS() if err := runSupfile(os.Stderr, flags, flag.Args()); err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) @@ -99,7 +102,7 @@ func runSupfile(errStream io.Writer, options options, args []string) error { return err } // Parse network and commands to be run from flags. - network, commands, err := parseArgs(errStream, args, conf) + network, commands, err := parseArgs(errStream, options, args, conf) if err != nil { return err } @@ -247,7 +250,7 @@ func resolvePath(path string) string { // parseArgs parses args and returns network and commands to be run. // On error, it prints usage and exits. -func parseArgs(errStream io.Writer, args []string, conf *sup.Supfile) (*sup.Network, []*sup.Command, error) { +func parseArgs(errStream io.Writer, options options, args []string, conf *sup.Supfile) (*sup.Network, []*sup.Command, error) { var commands []*sup.Command if len(args) < 1 { @@ -290,15 +293,15 @@ func parseArgs(errStream io.Writer, args []string, conf *sup.Supfile) (*sup.Netw // Add default nonce network.Env.Set("SUP_TIME", time.Now().UTC().Format(time.RFC3339)) - if os.Getenv("SUP_TIME") != "" { - network.Env.Set("SUP_TIME", os.Getenv("SUP_TIME")) + if options.env.Getenv("SUP_TIME") != "" { + network.Env.Set("SUP_TIME", options.env.Getenv("SUP_TIME")) } // Add user - if os.Getenv("SUP_USER") != "" { - network.Env.Set("SUP_USER", os.Getenv("SUP_USER")) + if options.env.Getenv("SUP_USER") != "" { + network.Env.Set("SUP_USER", options.env.Getenv("SUP_USER")) } else { - network.Env.Set("SUP_USER", os.Getenv("USER")) + network.Env.Set("SUP_USER", options.env.Getenv("USER")) } for _, cmd := range args[1:] { diff --git a/cmd/sup/main_test.go b/cmd/sup/main_test.go index a966a37..80d354c 100644 --- a/cmd/sup/main_test.go +++ b/cmd/sup/main_test.go @@ -4,11 +4,16 @@ import ( "fmt" "io/ioutil" "os" - "os/user" "path/filepath" "strings" "testing" "time" + + "github.com/adammck/venv" +) + +const ( + envTestUser = "sup_test_user" ) var ( @@ -28,6 +33,7 @@ we kp re withTmpDir(t, input, func(dirname string) { options := options{ dirname: dirname, + env: testEnv(), } if err := runSupfile(testErrStream, options, []string{}); err == nil { t.Fatal("Expected an error") @@ -55,6 +61,7 @@ commands: withTmpDir(t, input, func(dirname string) { options := options{ dirname: dirname, + env: testEnv(), } if err := runSupfile(testErrStream, options, []string{}); err != ErrUsage { t.Fatal(err) @@ -107,6 +114,7 @@ commands: withTmpDir(t, input, func(dirname string) { options := options{ dirname: dirname, + env: testEnv(), } if err := runSupfile(testErrStream, options, []string{"staging"}); err != ErrNetworkNoHosts { t.Fatal(err) @@ -134,6 +142,7 @@ commands: withTmpDir(t, input, func(dirname string) { options := options{ dirname: dirname, + env: testEnv(), } if err := runSupfile(testErrStream, options, []string{"staging"}); err != ErrUsage { t.Fatal(err) @@ -168,6 +177,7 @@ targets: withTmpDir(t, input, func(dirname string) { options := options{ dirname: dirname, + env: testEnv(), } if err := runSupfile(testErrStream, options, []string{"staging", "step5"}); err == nil { t.Fatal("Expected an error") @@ -205,6 +215,7 @@ targets: withTmpDir(t, input, func(dirname string) { options := options{ dirname: dirname, + env: testEnv(), } if err := runSupfile(testErrStream, options, []string{"staging", "walk"}); err == nil { t.Fatal("Expected an error") @@ -389,6 +400,7 @@ commands: options := options{ dirname: dirname, onlyHosts: "server42", + env: venv.Mock(), } if err := runSupfile(testErrStream, options, []string{"staging", "step1"}); err == nil { t.Fatal("Expected an error") @@ -420,6 +432,7 @@ commands: options := options{ dirname: dirname, onlyHosts: "server(", + env: venv.Mock(), } if err := runSupfile(testErrStream, options, []string{"staging", "step1"}); err == nil { t.Fatal("Expected an error") @@ -487,6 +500,7 @@ commands: options := options{ dirname: dirname, exceptHosts: "server", + env: venv.Mock(), } if err := runSupfile(testErrStream, options, []string{"staging", "step1"}); err == nil { t.Fatal("Expected an error") @@ -518,6 +532,7 @@ commands: options := options{ dirname: dirname, exceptHosts: "server(", + env: venv.Mock(), } if err := runSupfile(testErrStream, options, []string{"staging", "step1"}); err == nil { t.Fatal("Expected an error") @@ -578,6 +593,7 @@ commands: withTmpDir(t, input, func(dirname string) { options := options{ dirname: dirname, + env: testEnv(), } args := []string{"staging", "step1"} if err := runSupfile(testErrStream, options, args); err == nil { @@ -618,15 +634,11 @@ commands: if err := runSupfile(testErrStream, options, []string{"staging", "step1"}); err != nil { t.Fatal(err) } - currentUser, err := user.Current() - if err != nil { - t.Fatal(err) - } m := newMatcher(outputs, t) m.expectActivityOnServers(0, 1) m.expectExportOnActiveServers(`SUP_NETWORK="staging"`) m.expectExportOnActiveServers(`SUP_ENV=""`) - m.expectExportOnActiveServers(fmt.Sprintf(`SUP_USER="%s"`, currentUser.Name)) + m.expectExportOnActiveServers(fmt.Sprintf(`SUP_USER="%s"`, envTestUser)) m.expectExportRegexpOnActiveServers(`SUP_HOST="localhost:\d+"`) }) }) @@ -694,7 +706,7 @@ commands: if err != nil { t.Fatal(err) } - os.Setenv("SUP_TIME", "now") + options.env.Setenv("SUP_TIME", "now") if err := runSupfile(testErrStream, options, []string{"staging", "step1"}); err != nil { t.Fatal(err) @@ -726,7 +738,7 @@ commands: if err != nil { t.Fatal(err) } - os.Setenv("SUP_USER", "sup_rules") + options.env.Setenv("SUP_USER", "sup_rules") if err := runSupfile(testErrStream, options, []string{"staging", "step1"}); err != nil { t.Fatal(err) @@ -1109,3 +1121,9 @@ func writeSupfileAs(dirname, filename, input string) error { 0666, ) } + +func testEnv() venv.Env { + env := venv.Mock() + env.Setenv("USER", envTestUser) + return env +} diff --git a/cmd/sup/mock_server_test.go b/cmd/sup/mock_server_test.go index 9fe30a8..4443e0f 100644 --- a/cmd/sup/mock_server_test.go +++ b/cmd/sup/mock_server_test.go @@ -53,6 +53,7 @@ func setupMockEnv(dirname string, count int) ([]bytes.Buffer, options, error) { options := options{ sshConfig: sshConfigPath, dirname: dirname, + env: testEnv(), } return outputs, options, nil } diff --git a/vendor/github.com/adammck/venv/README.md b/vendor/github.com/adammck/venv/README.md new file mode 100644 index 0000000..09a10ba --- /dev/null +++ b/vendor/github.com/adammck/venv/README.md @@ -0,0 +1,40 @@ +# venv + +[![GoDoc](https://godoc.org/github.com/adammck/venv?status.svg)](https://godoc.org/github.com/adammck/venv) +[![Build Status](https://travis-ci.org/adammck/venv.svg?branch=master)](https://travis-ci.org/adammck/venv) + +This is a Go library to abstract access to environment variables. +Like [spf13/afero][afero] or [blang/vfs][vfs], but for the env. + +## Usage + +```go +package main + +import ( + "fmt" + "github.com/adammck/venv" +) + +func main() { + var e venv.Env + + // Use the real environment + + e = venv.OS() + fmt.Printf("Hello, %s!\n", e.Getenv("USER")) + + // Or use a mock + + e = venv.Mock() + e.Setenv("USER", "fred") + fmt.Printf("Hello, %s!\n", e.Getenv("USER")) +} +``` + +## License + +MIT. + +[afero]: https://github.com/spf13/afero +[vfs]: https://github.com/blang/vfs diff --git a/vendor/github.com/adammck/venv/mock/mock.go b/vendor/github.com/adammck/venv/mock/mock.go new file mode 100644 index 0000000..4ca535d --- /dev/null +++ b/vendor/github.com/adammck/venv/mock/mock.go @@ -0,0 +1,38 @@ +package mock + +import ( + "fmt" +) + +type MockEnv struct { + data map[string]string +} + +func New() *MockEnv { + e := &MockEnv{} + e.Clearenv() + return e +} + +func (e *MockEnv) Environ() []string { + out := []string{} + + for k, v := range e.data { + out = append(out, fmt.Sprintf("%s=%s", k, v)) + } + + return out +} + +func (e *MockEnv) Getenv(key string) string { + return e.data[key] +} + +func (e *MockEnv) Setenv(key, value string) error { + e.data[key] = value + return nil +} + +func (e *MockEnv) Clearenv() { + e.data = make(map[string]string) +} diff --git a/vendor/github.com/adammck/venv/os/os.go b/vendor/github.com/adammck/venv/os/os.go new file mode 100644 index 0000000..3f62e83 --- /dev/null +++ b/vendor/github.com/adammck/venv/os/os.go @@ -0,0 +1,28 @@ +package os + +import ( + "os" +) + +type OsEnv struct { +} + +func New() *OsEnv { + return &OsEnv{} +} + +func (e *OsEnv) Getenv(key string) string { + return os.Getenv(key) +} + +func (e *OsEnv) Setenv(key, value string) error { + return os.Setenv(key, value) +} + +func (e *OsEnv) Environ() []string { + return os.Environ() +} + +func (e *OsEnv) Clearenv() { + os.Clearenv() +} diff --git a/vendor/github.com/adammck/venv/venv.go b/vendor/github.com/adammck/venv/venv.go new file mode 100644 index 0000000..a873cc9 --- /dev/null +++ b/vendor/github.com/adammck/venv/venv.go @@ -0,0 +1,21 @@ +package venv + +import ( + "github.com/adammck/venv/mock" + "github.com/adammck/venv/os" +) + +type Env interface { + Environ() []string + Getenv(key string) string + Setenv(key, value string) error + Clearenv() +} + +func OS() Env { + return os.New() +} + +func Mock() Env { + return mock.New() +} diff --git a/vendor/vendor.json b/vendor/vendor.json index 30fae56..202b0d7 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -2,6 +2,24 @@ "comment": "", "ignore": "test appengine", "package": [ + { + "checksumSHA1": "JpbTntZjOHcUo5P9VrdcTFEGZ9Y=", + "path": "github.com/adammck/venv", + "revision": "8a9c907a37d36a8f34fa1c5b81aaf80c2554a306", + "revisionTime": "2016-08-19T02:56:05Z" + }, + { + "checksumSHA1": "5yJLvX3vSDfIqaCqSEaIwM0B3co=", + "path": "github.com/adammck/venv/mock", + "revision": "8a9c907a37d36a8f34fa1c5b81aaf80c2554a306", + "revisionTime": "2016-08-19T02:56:05Z" + }, + { + "checksumSHA1": "FIK6KYgGeXhdfcNHH80Ub1fBvwM=", + "path": "github.com/adammck/venv/os", + "revision": "8a9c907a37d36a8f34fa1c5b81aaf80c2554a306", + "revisionTime": "2016-08-19T02:56:05Z" + }, { "path": "github.com/goware/prefixer", "revision": "395022866408d928fc2439f7eac73dd8d370ec1d", From 1b4821b6f18296b8ddc1b4ecf8522f2f7af04d3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0t=C4=9Bp=C3=A1n=20Pila=C5=99?= Date: Sun, 15 Apr 2018 13:32:22 +0200 Subject: [PATCH 12/12] Fix inventory test --- cmd/sup/main_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/sup/main_test.go b/cmd/sup/main_test.go index 80d354c..2b8d52b 100644 --- a/cmd/sup/main_test.go +++ b/cmd/sup/main_test.go @@ -551,7 +551,7 @@ version: 0.4 networks: staging: - inventory: array=( 0 2 ); for i in "${array[@]}"; do printf "server$i\n\n# comment\n"; done + inventory: printf "server0\n# comment\n\nserver2\n" commands: step1: