Skip to content

Commit

Permalink
Fix possible blocking on socket write or connection close (when using…
Browse files Browse the repository at this point in the history
… TLS)

Ensure that all socket writes are protected with deadlines.
For connection Close(), also use deadlines since in case of TLS,
the Close() will send an alert (do a write) if the handshake was
completed. If the peer is not reading, this would cause the Close()
to hang.
  • Loading branch information
kozlovic committed May 24, 2016
1 parent 7d79fa1 commit 188f7bf
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 24 deletions.
50 changes: 32 additions & 18 deletions server/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,26 +363,40 @@ func (c *client) maxPayloadViolation(sz int) {
c.closeConnection()
}

// sendProto writes info to the client's buffered writer and, when doFlush is
// true, flushes the writer afterwards. Whenever the call may reach the socket
// (an explicit flush, or a payload larger than the space left in the buffer),
// a write deadline of DEFAULT_FLUSH_DEADLINE is set on the connection for the
// duration of the call and cleared again before returning.
//
// It returns the first error from the write or the flush, and nil when the
// client has no buffered writer or no connection.
//
// Assume the lock is held upon entry.
func (c *client) sendProto(info []byte, doFlush bool) error {
	if c.bw == nil || c.nc == nil {
		return nil
	}

	// Only arm the deadline when bufio may actually write to the connection.
	mayHitSocket := doFlush || c.bw.Available() < len(info)
	if mayHitSocket {
		c.nc.SetWriteDeadline(time.Now().Add(DEFAULT_FLUSH_DEADLINE))
		// Clear the deadline once the write (and optional flush) is done.
		defer c.nc.SetWriteDeadline(time.Time{})
	}

	if _, werr := c.bw.Write(info); werr != nil {
		return werr
	}
	if !doFlush {
		return nil
	}
	return c.bw.Flush()
}

// sendInfo writes info to the client and flushes it right away.
// NOTE(review): this listing comes from a diff view and appears to show both
// the pre-change direct c.bw calls and the sendProto call that replaces them;
// confirm against the repository.
// Assume the lock is held upon entry.
func (c *client) sendInfo(info []byte) {
c.bw.Write(info)
c.bw.Flush()
// doFlush=true: sendProto writes and flushes under a write deadline.
c.sendProto(info, true)
}

// sendErr formats err as a "-ERR '<err>'\r\n" protocol line and sends it to
// the client, flushing immediately. It takes c.mu for the duration of the call.
// NOTE(review): this listing comes from a diff view and appears to show both
// the pre-change block guarded by c.bw != nil and the sendProto call that
// replaces it; confirm against the repository.
func (c *client) sendErr(err string) {
c.mu.Lock()
if c.bw != nil {
c.bw.WriteString(fmt.Sprintf("-ERR '%s'\r\n", err))
// Flush errors in place.
c.bw.Flush()
//c.pcd[c] = needFlush
}
// doFlush=true: the error line is flushed now, under sendProto's write deadline.
c.sendProto([]byte(fmt.Sprintf("-ERR '%s'\r\n", err)), true)
c.mu.Unlock()
}

// sendOK queues a "+OK\r\n" protocol line for the client without flushing it,
// and records needFlush for this client in c.pcd. It takes c.mu for the
// duration of the call.
// NOTE(review): this listing comes from a diff view and appears to show both
// the pre-change c.bw.WriteString call and the sendProto call that replaces
// it; confirm against the repository.
func (c *client) sendOK() {
c.mu.Lock()
c.bw.WriteString("+OK\r\n")
// doFlush=false: the line is only buffered here.
c.sendProto([]byte("+OK\r\n"), false)
// Presumably a later pass over c.pcd performs the flush; confirm against the caller.
c.pcd[c] = needFlush
c.mu.Unlock()
}
Expand All @@ -395,8 +409,7 @@ func (c *client) processPing() {
return
}
c.traceOutOp("PONG", nil)
c.bw.WriteString("PONG\r\n")
err := c.bw.Flush()
err := c.sendProto([]byte("PONG\r\n"), true)
if err != nil {
c.clearConnection()
c.Debugf("Error on Flush, error %s", err.Error())
Expand Down Expand Up @@ -734,7 +747,7 @@ func (c *client) deliverMsg(sub *subscription, mh, msg []byte) {
// will wait for flush to complete.

deadlineSet := false
if client.bw.Available() < (len(mh) + len(msg) + len(CR_LF)) {
if client.bw.Available() < (len(mh) + len(msg)) {
client.wfc += 1
client.nc.SetWriteDeadline(time.Now().Add(DEFAULT_FLUSH_DEADLINE))
deadlineSet = true
Expand Down Expand Up @@ -929,19 +942,15 @@ func (c *client) processPingTimer() {
c.pout++
if c.pout > c.srv.opts.MaxPingsOut {
c.Debugf("Stale Client Connection - Closing")
if c.bw != nil {
c.bw.WriteString(fmt.Sprintf("-ERR '%s'\r\n", "Stale Connection"))
c.bw.Flush()
}
c.sendProto([]byte(fmt.Sprintf("-ERR '%s'\r\n", "Stale Connection")), true)
c.clearConnection()
return
}

c.traceOutOp("PING", nil)

// Send PING
c.bw.WriteString("PING\r\n")
err := c.bw.Flush()
err := c.sendProto([]byte("PING\r\n"), true)
if err != nil {
c.Debugf("Error on Client Ping Flush, error %s", err)
c.clearConnection()
Expand Down Expand Up @@ -994,8 +1003,13 @@ func (c *client) clearConnection() {
if c.nc == nil {
return
}
// With TLS, Close() sends an alert (i.e. it performs a write).
// Set a write deadline first; otherwise the server could block there
// if the peer is not reading from the socket.
c.nc.SetWriteDeadline(time.Now().Add(DEFAULT_FLUSH_DEADLINE))
c.bw.Flush()
c.nc.Close()
c.nc.SetWriteDeadline(time.Time{})
}

func (c *client) typeString() string {
Expand Down
89 changes: 89 additions & 0 deletions server/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"testing"
"time"

"crypto/tls"
"github.com/nats-io/nats"
)

Expand Down Expand Up @@ -609,3 +610,91 @@ func TestUnsubRace(t *testing.T) {

wg.Wait()
}

// TestTLSCloseClientConnection checks that closing a client connection does
// not block when TLS is in use and the peer has stopped reading.
//
// The test connects a raw TLS client, completes CONNECT/PING/PONG, then stops
// reading. It fills the server-side socket buffer until writes time out and
// finally calls closeConnection on the server-side client. With TLS, closing
// writes an alert to the socket, so without a write deadline the close would
// hang. The test fails if closeConnection does not return within 3 seconds.
func TestTLSCloseClientConnection(t *testing.T) {
	opts, err := ProcessConfigFile("./configs/tls.conf")
	if err != nil {
		t.Fatalf("Error processing config file: %v", err)
	}
	opts.Authorization = ""
	opts.TLSTimeout = 100
	s := New(opts)
	if s == nil {
		t.Fatal("No NATS Server object returned.")
	}
	// Run server in Go routine.
	go s.Start()
	defer s.Shutdown()

	endpoint := fmt.Sprintf("%s:%d", opts.Host, opts.Port)
	conn, err := net.DialTimeout("tcp", endpoint, 2*time.Second)
	if err != nil {
		t.Fatalf("Unexpected error on dial: %v", err)
	}
	defer conn.Close()
	// The server sends INFO in clear text before the TLS handshake.
	br := bufio.NewReaderSize(conn, 100)
	if _, err := br.ReadString('\n'); err != nil {
		t.Fatalf("Unexpected error reading INFO: %v", err)
	}

	tlsConn := tls.Client(conn, &tls.Config{InsecureSkipVerify: true})
	defer tlsConn.Close()
	if err := tlsConn.Handshake(); err != nil {
		t.Fatalf("Unexpected error during handshake: %v", err)
	}
	br = bufio.NewReaderSize(tlsConn, 100)
	connectOp := []byte("CONNECT {\"verbose\":false,\"pedantic\":false,\"tls_required\":true}\r\n")
	if _, err := tlsConn.Write(connectOp); err != nil {
		t.Fatalf("Unexpected error writing CONNECT: %v", err)
	}
	if _, err := tlsConn.Write([]byte("PING\r\n")); err != nil {
		t.Fatalf("Unexpected error writing PING: %v", err)
	}
	if _, err := br.ReadString('\n'); err != nil {
		t.Fatalf("Unexpected error reading PONG: %v", err)
	}

	// getClient returns any client currently registered with the server.
	getClient := func() *client {
		s.mu.Lock()
		defer s.mu.Unlock()
		for _, c := range s.clients {
			return c
		}
		return nil
	}
	// Wait for client to be registered, polling instead of busy-spinning.
	var cli *client
	deadline := time.Now().Add(5 * time.Second)
	for time.Now().Before(deadline) {
		if cli = getClient(); cli != nil {
			break
		}
		time.Sleep(10 * time.Millisecond)
	}
	if cli == nil {
		t.Fatal("Did not register client on time")
	}

	// Fill the socket buffer. The test client no longer reads, so writing one
	// byte at a time eventually hits the write deadline. From that point on,
	// any further write, including the one done by nc.Close() with TLS,
	// cannot complete.
	filler := []byte("a")
	for {
		cli.nc.SetWriteDeadline(time.Now().Add(time.Second))
		_, werr := cli.nc.Write(filler)
		cli.nc.SetWriteDeadline(time.Time{})
		if werr != nil {
			break
		}
	}

	// Close the client in a separate goroutine so that a blocked close is
	// reported as a test failure instead of hanging the whole test run.
	closed := make(chan struct{})
	go func() {
		cli.closeConnection()
		close(closed)
	}()
	select {
	case <-closed:
	case <-time.After(3 * time.Second):
		t.Fatal("closeConnection is blocked")
	}
}
10 changes: 4 additions & 6 deletions server/route.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ func (c *client) sendConnect(tlsRequired bool) {
c.closeConnection()
return
}
c.bw.WriteString(fmt.Sprintf(ConProto, b))
c.bw.Flush()
c.sendProto([]byte(fmt.Sprintf(ConProto, b)), true)
}

// Process the info message if we are a route.
Expand Down Expand Up @@ -249,8 +248,7 @@ func (s *Server) sendLocalSubsToRoute(route *client) {

route.mu.Lock()
defer route.mu.Unlock()
route.bw.Write(b.Bytes())
route.bw.Flush()
route.sendProto(b.Bytes(), true)

route.Debugf("Route sent local subscriptions")
}
Expand Down Expand Up @@ -495,12 +493,12 @@ func (s *Server) broadcastInterestToRoutes(proto string) {
if atomic.LoadInt32(&trace) == 1 {
arg = []byte(proto[:len(proto)-LEN_CR_LF])
}
protoAsBytes := []byte(proto)
s.mu.Lock()
for _, route := range s.routes {
// FIXME(dlc) - Make same logic as deliverMsg
route.mu.Lock()
route.bw.WriteString(proto)
route.bw.Flush()
route.sendProto(protoAsBytes, true)
route.mu.Unlock()
route.traceOutOp("", arg)
}
Expand Down

0 comments on commit 188f7bf

Please sign in to comment.