Skip to content

Commit

Permalink
Merge pull request kubernetes#23901 from cjcullen/automated-cherry-pi…
Browse files Browse the repository at this point in the history
…ck-of-#23843-upstream-release-1.2

Automated cherry pick of kubernetes#23843
  • Loading branch information
zmerlynn committed Apr 6, 2016
2 parents 0ec985f + 0efcbf0 commit eb5efc4
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 2 deletions.
36 changes: 34 additions & 2 deletions pkg/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func makeSSHTunnel(user string, signer ssh.Signer, host string) (*SSHTunnel, err

func (s *SSHTunnel) Open() error {
var err error
s.client, err = ssh.Dial("tcp", net.JoinHostPort(s.Host, s.SSHPort), s.Config)
s.client, err = realTimeoutDialer.Dial("tcp", net.JoinHostPort(s.Host, s.SSHPort), s.Config)
tunnelOpenCounter.Inc()
if err != nil {
tunnelOpenFailCounter.Inc()
Expand Down Expand Up @@ -163,11 +163,43 @@ func (d *realSSHDialer) Dial(network, addr string, config *ssh.ClientConfig) (*s
return ssh.Dial(network, addr, config)
}

// timeoutDialer wraps an sshDialer with a timeout around Dial(). The golang
// ssh library can hang indefinitely inside the Dial() call (see issue #23835).
// Wrapping all Dial() calls with a conservative timeout provides safety against
// getting stuck on that.
type timeoutDialer struct {
dialer sshDialer
timeout time.Duration
}

// 150 seconds is longer than the underlying default TCP backoff delay (127
// seconds). This timeout is only intended to catch otherwise uncaught hangs.
const sshDialTimeout = 150 * time.Second

var realTimeoutDialer sshDialer = &timeoutDialer{&realSSHDialer{}, sshDialTimeout}

func (d *timeoutDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
var client *ssh.Client
errCh := make(chan error, 1)
go func() {
defer runtime.HandleCrash()
var err error
client, err = d.dialer.Dial(network, addr, config)
errCh <- err
}()
select {
case err := <-errCh:
return client, err
case <-time.After(d.timeout):
return nil, fmt.Errorf("timed out dialing %s:%s", network, addr)
}
}

// RunSSHCommand returns the stdout, stderr, and exit code from running cmd on
// host as specific user, along with any SSH-level error.
// If user=="", it will default (like SSH) to os.Getenv("USER")
func RunSSHCommand(cmd, user, host string, signer ssh.Signer) (string, string, int, error) {
return runSSHCommand(&realSSHDialer{}, cmd, user, host, signer, true)
return runSSHCommand(realTimeoutDialer, cmd, user, host, signer, true)
}

// Internal implementation of runSSHCommand, for testing
Expand Down
36 changes: 36 additions & 0 deletions pkg/ssh/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,3 +328,39 @@ func TestSSHUser(t *testing.T) {
}

}

type slowDialer struct {
delay time.Duration
err error
}

func (s *slowDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
time.Sleep(s.delay)
if s.err != nil {
return nil, s.err
}
return &ssh.Client{}, nil
}

func TestTimeoutDialer(t *testing.T) {
testCases := []struct {
delay time.Duration
timeout time.Duration
err error
expectedErrString string
}{
// delay > timeout should cause ssh.Dial to timeout.
{1 * time.Second, 0, nil, "timed out dialing"},
// delay < timeout should return the result of the call to the dialer.
{0, 1 * time.Second, nil, ""},
{0, 1 * time.Second, fmt.Errorf("test dial error"), "test dial error"},
}
for _, tc := range testCases {
dialer := &timeoutDialer{&slowDialer{tc.delay, tc.err}, tc.timeout}
_, err := dialer.Dial("tcp", "addr:port", &ssh.ClientConfig{})
if len(tc.expectedErrString) == 0 && err != nil ||
!strings.Contains(fmt.Sprint(err), tc.expectedErrString) {
t.Errorf("Expected error to contain %q; got %v", tc.expectedErrString, err)
}
}
}

0 comments on commit eb5efc4

Please sign in to comment.