Skip to content

Commit

Permalink
Update handling of account authentication expired error
Browse files Browse the repository at this point in the history
Signed-off-by: Waldemar Quevedo <wally@synadia.com>
  • Loading branch information
wallyqs committed Mar 30, 2021
1 parent 7013a58 commit bd7b51f
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 182 deletions.
7 changes: 7 additions & 0 deletions nats.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ const (

// AUTHENTICATION_REVOKED_ERR is for when user authorization has been revoked.
AUTHENTICATION_REVOKED_ERR = "user authentication revoked"

// ACCOUNT_AUTHENTICATION_EXPIRED_ERR is for when nats server account authorization has expired.
ACCOUNT_AUTHENTICATION_EXPIRED_ERR = "account authentication expired"
)

// Errors
Expand All @@ -98,6 +101,7 @@ var (
ErrAuthorization = errors.New("nats: authorization violation")
ErrAuthExpired = errors.New("nats: authentication expired")
ErrAuthRevoked = errors.New("nats: authentication revoked")
ErrAccountAuthExpired = errors.New("nats: account authentication expired")
ErrNoServers = errors.New("nats: no servers available for connection")
ErrJsonParse = errors.New("nats: connect message, json parse error")
ErrChanArg = errors.New("nats: argument needs to be a channel type")
Expand Down Expand Up @@ -2766,6 +2770,9 @@ func checkAuthError(e string) error {
if strings.HasPrefix(e, AUTHENTICATION_REVOKED_ERR) {
return ErrAuthRevoked
}
if strings.HasPrefix(e, ACCOUNT_AUTHENTICATION_EXPIRED_ERR) {
return ErrAccountAuthExpired
}
return nil
}

Expand Down
277 changes: 95 additions & 182 deletions nats_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1436,8 +1436,8 @@ func TestUserCredentialsChainedFile(t *testing.T) {
nc.Close()
}

func TestExpiredUserCredentials(t *testing.T) {
// The goal of this test was to check how a client with an expiring JWT
func TestExpiredAuthentication(t *testing.T) {
// The goal of these tests was to check how a client with an expiring JWT
// behaves. It should receive an async -ERR indicating that the auth
// has expired, which will trigger reconnects. There, the lib should
// received -ERR for auth violation in response to the CONNECT (instead
Expand All @@ -1451,204 +1451,117 @@ func TestExpiredUserCredentials(t *testing.T) {
// So for a deterministic test, we won't use an actual NATS Server.
// Instead, we will use a mock that simply returns appropriate -ERR and
// ensure the client behaves as expected.
l, e := net.Listen("tcp", "127.0.0.1:0")
if e != nil {
t.Fatal("Could not listen on an ephemeral port")
}
tl := l.(*net.TCPListener)
defer tl.Close()

addr := tl.Addr().(*net.TCPAddr)

wg := sync.WaitGroup{}
wg.Add(1)

go func() {
defer wg.Done()
connect := 0
for {
conn, err := l.Accept()
if err != nil {
return
for _, test := range []struct {
name string
expectedProto string
expectedErr error
}{
{"expired users credentials", AUTHENTICATION_EXPIRED_ERR, ErrAuthExpired},
{"revoked users credentials", AUTHENTICATION_REVOKED_ERR, ErrAuthRevoked},
{"expired account", ACCOUNT_AUTHENTICATION_EXPIRED_ERR, ErrAccountAuthExpired},
} {
t.Run(test.name, func(t *testing.T) {
l, e := net.Listen("tcp", "127.0.0.1:0")
if e != nil {
t.Fatal("Could not listen on an ephemeral port")
}
defer conn.Close()
tl := l.(*net.TCPListener)
defer tl.Close()

info := "INFO {\"server_id\":\"foobar\",\"nonce\":\"anonce\"}\r\n"
conn.Write([]byte(info))

// Read connect and ping commands sent from the client
br := bufio.NewReaderSize(conn, 10*1024)
br.ReadLine()
br.ReadLine()

if connect++; connect == 1 {
conn.Write([]byte(fmt.Sprintf("%s%s", _PONG_OP_, _CRLF_)))
time.Sleep(300 * time.Millisecond)
conn.Write([]byte(fmt.Sprintf("-ERR '%s'\r\n", AUTHENTICATION_EXPIRED_ERR)))
} else {
conn.Write([]byte(fmt.Sprintf("-ERR '%s'\r\n", AUTHORIZATION_ERR)))
}
conn.Close()
}
}()
addr := tl.Addr().(*net.TCPAddr)

ch := make(chan bool)
errCh := make(chan error, 10)
wg := sync.WaitGroup{}
wg.Add(1)

url := fmt.Sprintf("nats://127.0.0.1:%d", addr.Port)
nc, err := Connect(url,
ReconnectWait(25*time.Millisecond),
ReconnectJitter(0, 0),
MaxReconnects(-1),
ErrorHandler(func(_ *Conn, _ *Subscription, e error) {
select {
case errCh <- e:
default:
}
}),
ClosedHandler(func(nc *Conn) {
ch <- true
}),
)
if err != nil {
t.Fatalf("Expected to connect, got %v", err)
}
defer nc.Close()
go func() {
defer wg.Done()
connect := 0
for {
conn, err := l.Accept()
if err != nil {
return
}
defer conn.Close()

// We should give up since we get the same error on both tries.
if err := WaitTime(ch, 2*time.Second); err != nil {
t.Fatal("Should have closed after multiple failed attempts.")
}
if stats := nc.Stats(); stats.Reconnects > 2 {
t.Fatalf("Expected at most 2 reconnects, got %d", stats.Reconnects)
}
// We expect 3 errors, an AUTHENTICATION_EXPIRED_ERR, then 2 AUTHORIZATION_ERR
// before the connection is closed.
for i := 0; i < 3; i++ {
select {
case e := <-errCh:
if i == 0 && e != ErrAuthExpired {
t.Fatalf("Expected error %q, got %q", ErrAuthExpired, e)
} else if i > 0 && e != ErrAuthorization {
t.Fatalf("Expected error %q, got %q", ErrAuthorization, e)
}
default:
if i == 0 {
t.Fatalf("Missing %q error", ErrAuthExpired)
} else {
t.Fatalf("Missing %q error", ErrAuthorization)
}
}
}
// We should not have any more error
select {
case e := <-errCh:
t.Fatalf("Extra error: %v", e)
default:
}
// Close the listener and wait for go routine to end.
l.Close()
wg.Wait()
}
info := "INFO {\"server_id\":\"foobar\",\"nonce\":\"anonce\"}\r\n"
conn.Write([]byte(info))

func TestRevokedUserCredentials(t *testing.T) {
// Mock that the client connects and then is revoked.
l, e := net.Listen("tcp", "127.0.0.1:0")
if e != nil {
t.Fatal("Could not listen on an ephemeral port")
}
tl := l.(*net.TCPListener)
defer tl.Close()
// Read connect and ping commands sent from the client
br := bufio.NewReaderSize(conn, 10*1024)
br.ReadLine()
br.ReadLine()

addr := tl.Addr().(*net.TCPAddr)
if connect++; connect == 1 {
conn.Write([]byte(fmt.Sprintf("%s%s", _PONG_OP_, _CRLF_)))
time.Sleep(300 * time.Millisecond)
conn.Write([]byte(fmt.Sprintf("-ERR '%s'\r\n", test.expectedProto)))
} else {
conn.Write([]byte(fmt.Sprintf("-ERR '%s'\r\n", AUTHORIZATION_ERR)))
}
conn.Close()
}
}()

wg := sync.WaitGroup{}
wg.Add(1)
ch := make(chan bool)
errCh := make(chan error, 10)

go func() {
defer wg.Done()
connect := 0
for {
conn, err := l.Accept()
url := fmt.Sprintf("nats://127.0.0.1:%d", addr.Port)
nc, err := Connect(url,
ReconnectWait(25*time.Millisecond),
ReconnectJitter(0, 0),
MaxReconnects(-1),
ErrorHandler(func(_ *Conn, _ *Subscription, e error) {
select {
case errCh <- e:
default:
}
}),
ClosedHandler(func(nc *Conn) {
ch <- true
}),
)
if err != nil {
return
t.Fatalf("Expected to connect, got %v", err)
}
defer conn.Close()

info := "INFO {\"server_id\":\"foobar\",\"nonce\":\"anonce\"}\r\n"
conn.Write([]byte(info))

// Read connect and ping commands sent from the client
br := bufio.NewReaderSize(conn, 10*1024)
br.ReadLine()
br.ReadLine()
defer nc.Close()

if connect++; connect == 1 {
conn.Write([]byte(fmt.Sprintf("%s%s", _PONG_OP_, _CRLF_)))
time.Sleep(300 * time.Millisecond)
conn.Write([]byte(fmt.Sprintf("-ERR '%s'\r\n", AUTHENTICATION_REVOKED_ERR)))
} else {
conn.Write([]byte(fmt.Sprintf("-ERR '%s'\r\n", AUTHORIZATION_ERR)))
// We should give up since we get the same error on both tries.
if err := WaitTime(ch, 2*time.Second); err != nil {
t.Fatal("Should have closed after multiple failed attempts.")
}
conn.Close()
}
}()

ch := make(chan bool)
errCh := make(chan error, 10)

url := fmt.Sprintf("nats://127.0.0.1:%d", addr.Port)
nc, err := Connect(url,
ReconnectWait(25*time.Millisecond),
ReconnectJitter(0, 0),
MaxReconnects(-1),
ErrorHandler(func(_ *Conn, _ *Subscription, e error) {
select {
case errCh <- e:
default:
if stats := nc.Stats(); stats.Reconnects > 2 {
t.Fatalf("Expected at most 2 reconnects, got %d", stats.Reconnects)
}
}),
ClosedHandler(func(nc *Conn) {
ch <- true
}),
)
if err != nil {
t.Fatalf("Expected to connect, got %v", err)
}
defer nc.Close()

// We should give up since we get the same error on both tries.
if err := WaitTime(ch, 2*time.Second); err != nil {
t.Fatal("Should have closed after multiple failed attempts.")
}
if stats := nc.Stats(); stats.Reconnects > 2 {
t.Fatalf("Expected at most 2 reconnects, got %d", stats.Reconnects)
}
for i := 0; i < 3; i++ {
select {
case e := <-errCh:
if i == 0 && e != ErrAuthRevoked {
t.Fatalf("Expected error %q, got %q", ErrAuthRevoked, e)
} else if i > 0 && e != ErrAuthorization {
t.Fatalf("Expected error %q, got %q", ErrAuthorization, e)
// We expect 3 errors, the expired auth/revoke error, then 2 AUTHORIZATION_ERR
// before the connection is closed.
for i := 0; i < 3; i++ {
select {
case e := <-errCh:
if i == 0 && e != test.expectedErr {
t.Fatalf("Expected error %q, got %q", test.expectedErr, e)
} else if i > 0 && e != ErrAuthorization {
t.Fatalf("Expected error %q, got %q", ErrAuthorization, e)
}
default:
if i == 0 {
t.Fatalf("Missing %q error", test.expectedErr)
} else {
t.Fatalf("Missing %q error", ErrAuthorization)
}
}
}
default:
if i == 0 {
t.Fatalf("Missing %q error", ErrAuthRevoked)
} else {
t.Fatalf("Missing %q error", ErrAuthorization)
// We should not have any more error
select {
case e := <-errCh:
t.Fatalf("Extra error: %v", e)
default:
}
}
}
// We should not have any more error
select {
case e := <-errCh:
t.Fatalf("Extra error: %v", e)
default:
// Close the listener and wait for go routine to end.
l.Close()
wg.Wait()
})
}
// Close the listener and wait for go routine to end.
l.Close()
wg.Wait()
}

// If we are using TLS and have multiple servers we try to match the IP
Expand Down

0 comments on commit bd7b51f

Please sign in to comment.