Skip to content

Commit

Permalink
ssh: fail keyboard-interactive auth with unexpectedMessageError() whe…
Browse files Browse the repository at this point in the history
…n auth fails before receiving the UserAuthInfoRequest from server
  • Loading branch information
samiponkanen committed Jun 16, 2024
1 parent d4e7c9c commit 95b457c
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 0 deletions.
5 changes: 5 additions & 0 deletions ssh/client_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,7 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
}

gotMsgExtInfo := false
gotUserAuthInfoRequest := false
for {
packet, err := c.readPacket()
if err != nil {
Expand Down Expand Up @@ -585,6 +586,9 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
if msg.PartialSuccess {
return authPartialSuccess, msg.Methods, nil
}
if !gotUserAuthInfoRequest {
return authFailure, msg.Methods, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
}
return authFailure, msg.Methods, nil
case msgUserAuthSuccess:
return authSuccess, nil, nil
Expand All @@ -596,6 +600,7 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
if err := Unmarshal(packet, &msg); err != nil {
return authFailure, nil, err
}
gotUserAuthInfoRequest = true

// Manually unpack the prompt/echo pairs.
rest := msg.Prompts
Expand Down
144 changes: 144 additions & 0 deletions ssh/client_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"os"
"runtime"
"strings"
"sync"
"testing"
)

Expand Down Expand Up @@ -1282,3 +1283,146 @@ func TestCertAuthOpenSSHCompat(t *testing.T) {
t.Fatalf("unable to dial remote side: %s", err)
}
}

func TestKeyboardInteractiveAuthEarlyFail(t *testing.T) {
const maxAuthTries = 2

// Start testserver
dst, err := func() (string, error) {
var serverAddr string
var serverErr error
var wg sync.WaitGroup
wg.Add(1)

go func() {
config := &ServerConfig{
MaxAuthTries: maxAuthTries,
KeyboardInteractiveCallback: func(c ConnMetadata,
client KeyboardInteractiveChallenge) (*Permissions, error) {
// Fail keyboard-interactive authentication early before
// any prompt is sent to client.
return nil, errors.New("keyboard-interactive auth failed")
},
PasswordCallback: func(c ConnMetadata,
pass []byte) (*Permissions, error) {
if string(pass) == clientPassword {
return nil, nil
}
return nil, errors.New("password auth failed")
},
}

// Use a static hostkey.
// This key has been generated with following ssh-keygen command
// and used exclusively in this unit test:
// $ ssh-keygen -t RSA -b 2048 -f /tmp/static_hostkey \
// -C "Static RSA hostkey for golang.org/x/crypto/ssh unit test"

const privKeyData = `-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABFwAAAAdzc2gtcn
NhAAAAAwEAAQAAAQEAsg9ZsQ3vbWppRLe2NzzUIV5NcPbpO5EBvLyzfItURKmYHmwa6aoy
P34fmEG3BVVx5f1pgw54Rdaic4hG2p2nvGIijTktDxFz+tREwwMfywpwrlJbGslUTi0TKO
jTWkyDACjMwo65yXbsSZLq+8rGD3uinf3Vq1bVlaEckmClhWMLTsynr/YpdF2I/+InPCep
1AuaWj1dHFNL8fbWXd8xNONumkMS1i6xtP3PnzdUqN+DuoGy26x5ic3qxWVrUp69/J2J42
/B0WEYbATAfCQiL8iGeeM7Ll45GASI4r93uDnXropnHQy+RThG5BFBRiAqmzN6kncri/k5
65p63Jb33QAAA/AX6WXzF+ll8wAAAAdzc2gtcnNhAAABAQCyD1mxDe9tamlEt7Y3PNQhXk
1w9uk7kQG8vLN8i1REqZgebBrpqjI/fh+YQbcFVXHl/WmDDnhF1qJziEbanae8YiKNOS0P
EXP61ETDAx/LCnCuUlsayVROLRMo6NNaTIMAKMzCjrnJduxJkur7ysYPe6Kd/dWrVtWVoR
ySYKWFYwtOzKev9il0XYj/4ic8J6nUC5paPV0cU0vx9tZd3zE0426aQxLWLrG0/c+fN1So
34O6gbLbrHmJzerFZWtSnr38nYnjb8HRYRhsBMB8JCIvyIZ54zsuXjkYBIjiv3e4Odeuim
cdDL5FOEbkEUFGICqbM3qSdyuL+TnrmnrclvfdAAAAAwEAAQAAAQAJjuVjqbnWh8XK2InB
gVRpziQeEkMG3YvYU9DWuKv3W5s81tTDAk3cNqr/g1eNw75veCD31gkCxrjFtuUGyzu70x
DDv/P5QRiWuFpQlZRZU+Akm2skjvYllCnZIlZmHIFTutzy/LJgbC/W6zoN9h6Xqi1aicu0
fN7OP23HNcTs2gNAhzidpDMGOAxdzpcnXeQ3JCFOcv4LSi7TgmJHvLv1AgXQggSHeB8Nf1
DvS+E86O9Wm6Xj+OKRiEgrRlngNQQ0om9yhLmUMat7Nw3hn2ZSb2+ByaaYuDfQ6FAG9nno
HjxaqSHF83/8fKXJW/wku6ee3hvjTBNuuvUCLkLF48DBAAAAgQDUV28GyoWTqIR+uOa/4t
OyFjfTtTdQ6fnLaqxbiwOaPz6SXCiwE6qIEZF5Ll5QK+7tMPDKWcOak2uGaSUHcjaEXVh0
kaKwFqiFIBY3IGwCuyixjJITl+g+48SrFgLWNrNpwVrw1NOv+iBz+z6SGqcsi5f030qyzv
O2P2wkSZkQqgAAAIEA5UB0F2+Lh63JHvdoUDZutBvTgrsjpwiReIuy+7WgxeGHe54DaXTY
HMOORZM6unDRvi6uBBul7ON9Cs20UGeER+ZMA2SKXTicb0UwJuCYKKO8AIlMP0ykyfaF1p
ZJw9DciKEu92jx+e2MdhEOnIIt1jQ9e5UIMLsI/SicnseTDLUAAACBAMbV19NIEhjLczBp
MEYSBGDnyN6HWyrHCuCFLSpnHWePd6/apExGE049YRmAgLYaycmnjX9VJRKoumk9zYv1zu
W9WTuewZVuLjLpOq/4mO5/jOQortL1dUigiiA7ZTGazTFMHwG+fZdfVqgSxbMfEU2rYhND
S0UghNmRaqbzNl+JAAAAOFN0YXRpYyBSU0EgaG9zdGtleSBmb3IgZ29sYW5nLm9yZy94L2
NyeXB0by9zc2ggdW5pdCB0ZXN0AQI=
-----END OPENSSH PRIVATE KEY-----`

private, err := ParsePrivateKey([]byte(privKeyData))
if err != nil {
serverErr = err
wg.Done()
return
}
config.AddHostKey(private)

listener, err := net.Listen("tcp", "127.0.0.1:")
if err != nil {
serverErr = err
wg.Done()
return
}
serverAddr = listener.Addr().String()
wg.Done()

nConn, err := listener.Accept()
if err != nil {
return
}

conn, chans, reqs, err := NewServerConn(nConn, config)
if err != nil {
return
}
_ = conn.Close()

var connWg sync.WaitGroup
connWg.Add(1)
go func() {
defer connWg.Done()
DiscardRequests(reqs)
}()
for newChannel := range chans {
newChannel.Reject(Prohibited,
"testserver not accepting requests")
}
connWg.Wait()
}()

wg.Wait()
return serverAddr, serverErr
}()
if err != nil {
t.Fatalf("failed to start testserver: %v", err)
}

// Connect to testserver expect KeyboardInteractive() to be not called and
// PasswordCallback() to be called.
passwordCallbackCalled := false
cfg := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
RetryableAuthMethod(KeyboardInteractive(func(name,
instruction string, questions []string,
echos []bool) ([]string, error) {
t.Errorf("unexpected call to KeyboardInteractive()")
return []string{clientPassword}, nil
}), maxAuthTries),
RetryableAuthMethod(PasswordCallback(func() (secret string,
err error) {
t.Logf("PasswordCallback()")
passwordCallbackCalled = true
return clientPassword, nil
}), maxAuthTries),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}

conn, _ := Dial("tcp", dst, cfg)
if conn != nil {
conn.Close()
}

if !passwordCallbackCalled {
t.Errorf("expected PasswordCallback() to be called")
}
}

0 comments on commit 95b457c

Please sign in to comment.