diff --git a/ssh.go b/ssh.go index eb3cefb..7328f6f 100644 --- a/ssh.go +++ b/ssh.go @@ -5,6 +5,7 @@ import ( "io" "io/ioutil" "net" + "net/url" "os" "os/user" "path/filepath" @@ -43,16 +44,29 @@ func (e ErrConnect) Error() string { // parseHost parses and normalizes @ from a given string. func (c *SSHClient) parseHost(host string) error { - c.host = host + // https://golang.org/pkg/net/url/#URL + // [scheme:][//[userinfo@]host][/]path[?query][#fragment] + if !strings.Contains(host, "://") && !strings.HasPrefix(host, "//") { + host = "ssh://" + host + } - // Remove extra "ssh://" schema - if len(c.host) > 6 && c.host[:6] == "ssh://" { - c.host = c.host[6:] + hostURL, err := url.Parse(host) + if err != nil { + return err } - if at := strings.Index(c.host, "@"); at != -1 { - c.user = c.host[:at] - c.host = c.host[at+1:] + // Add default port, if not set + hostname := hostURL.Hostname() + port := hostURL.Port() + if port == "" { + port = "22" + } + c.host = net.JoinHostPort(hostname, port) + + if hostURL.User != nil { + if u := hostURL.User.Username(); u != "" { + c.user = u + } } // Add default user, if not set @@ -64,14 +78,7 @@ func (c *SSHClient) parseHost(host string) error { c.user = u.Username } - if strings.Index(c.host, "/") != -1 { - return ErrConnect{c.user, c.host, "unexpected slash in the host URL"} - } - - // Add default port, if not set - if strings.Index(c.host, ":") == -1 { - c.host += ":22" - } + c.env = c.env + `export SUP_HOST="` + c.host + `";` return nil } diff --git a/sup.go b/sup.go index d815068..43cddb6 100644 --- a/sup.go +++ b/sup.go @@ -70,7 +70,7 @@ func (sup *Stackup) Run(network *Network, envVars EnvList, commands ...*Command) // SSH client. remote := &SSHClient{ - env: env + `export SUP_HOST="` + host + `";`, + env: env, user: network.User, color: Colors[i%len(Colors)], }