Skip to content

Commit

Permalink
Merge pull request #46 from ncdc/conn-close-deadlock
Browse files Browse the repository at this point in the history
Don't try to read/write if connection is closed
  • Loading branch information
dmcgowan committed Mar 9, 2015
2 parents e9bf991 + 63b695f commit e731c8f
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 37 deletions.
44 changes: 24 additions & 20 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type AuthHandler func(header http.Header, slot uint8, parent uint32) bool
type idleAwareFramer struct {
f *spdy.Framer
conn *Connection
writeLock sync.Mutex
resetChan chan struct{}
setTimeoutChan chan time.Duration
timeout time.Duration
Expand All @@ -47,8 +48,9 @@ func newIdleAwareFramer(framer *spdy.Framer) *idleAwareFramer {

func (i *idleAwareFramer) monitor() {
var (
timer *time.Timer
expired <-chan time.Time
timer *time.Timer
expired <-chan time.Time
resetChan = i.resetChan
)
Loop:
for {
Expand All @@ -67,7 +69,7 @@ Loop:
timer.Reset(timeout)
}
}
case <-i.resetChan:
case <-resetChan:
if timer != nil && i.timeout > 0 {
timer.Reset(i.timeout)
}
Expand All @@ -87,12 +89,25 @@ Loop:
if timer != nil {
timer.Stop()
}
i.writeLock.Lock()
close(resetChan)
i.resetChan = nil
i.writeLock.Unlock()
break Loop
}
}

// Drain resetChan
for _ = range resetChan {
}
}

func (i *idleAwareFramer) WriteFrame(frame spdy.Frame) error {
i.writeLock.Lock()
defer i.writeLock.Unlock()
if i.resetChan == nil {
return io.EOF
}
err := i.f.WriteFrame(frame)
if err != nil {
return err
Expand All @@ -109,15 +124,18 @@ func (i *idleAwareFramer) ReadFrame() (spdy.Frame, error) {
return nil, err
}

// resetChan should never be closed since it is only closed
// when the connection has closed its closeChan. This closure
// only occurs after all Reads have finished
// TODO (dmcgowan): refactor relationship into connection
i.resetChan <- struct{}{}

return frame, nil
}

type Connection struct {
conn net.Conn
framer *idleAwareFramer
writeLock sync.Mutex
conn net.Conn
framer *idleAwareFramer

closeChan chan bool
goneAway bool
Expand Down Expand Up @@ -209,9 +227,7 @@ func (s *Connection) Ping() (time.Duration, error) {

frame := &spdy.PingFrame{Id: pid}
startTime := time.Now()
s.writeLock.Lock()
writeErr := s.framer.WriteFrame(frame)
s.writeLock.Unlock()
if writeErr != nil {
return time.Duration(0), writeErr
}
Expand Down Expand Up @@ -512,8 +528,6 @@ func (s *Connection) handleDataFrame(frame *spdy.DataFrame) error {

func (s *Connection) handlePingFrame(frame *spdy.PingFrame) error {
if s.pingId&0x01 != frame.Id&0x01 {
s.writeLock.Lock()
defer s.writeLock.Unlock()
return s.framer.WriteFrame(frame)
}
pingChan, pingOk := s.pingChans[frame.Id]
Expand Down Expand Up @@ -663,9 +677,7 @@ func (s *Connection) Close() error {
Status: spdy.GoAwayOK,
}

s.writeLock.Lock()
err := s.framer.WriteFrame(goAwayFrame)
s.writeLock.Unlock()
if err != nil {
return err
}
Expand Down Expand Up @@ -750,8 +762,6 @@ func (s *Connection) sendHeaders(headers http.Header, stream *Stream, fin bool)
CFHeader: spdy.ControlFrameHeader{Flags: flags},
}

s.writeLock.Lock()
defer s.writeLock.Unlock()
return s.framer.WriteFrame(headerFrame)
}

Expand All @@ -767,8 +777,6 @@ func (s *Connection) sendReply(headers http.Header, stream *Stream, fin bool) er
CFHeader: spdy.ControlFrameHeader{Flags: flags},
}

s.writeLock.Lock()
defer s.writeLock.Unlock()
return s.framer.WriteFrame(replyFrame)
}

Expand All @@ -778,8 +786,6 @@ func (s *Connection) sendResetFrame(status spdy.RstStreamStatus, streamId spdy.S
Status: status,
}

s.writeLock.Lock()
defer s.writeLock.Unlock()
return s.framer.WriteFrame(resetFrame)
}

Expand All @@ -806,8 +812,6 @@ func (s *Connection) sendStream(stream *Stream, fin bool) error {
CFHeader: spdy.ControlFrameHeader{Flags: flags},
}

s.writeLock.Lock()
defer s.writeLock.Unlock()
return s.framer.WriteFrame(streamFrame)
}

Expand Down
119 changes: 106 additions & 13 deletions spdy_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package spdystream

import (
"bufio"
"bytes"
"io"
"net"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -322,8 +324,6 @@ func TestUnexpectedRemoteConnectionClosed(t *testing.T) {
if e == nil || e != io.EOF {
t.Fatalf("(%d) Expected to get an EOF stream error", tix)
}
case <-time.After(500 * time.Millisecond):
t.Fatalf("(%d) Timeout waiting for stream closure", tix)
}

closeErr = conn.Close()
Expand Down Expand Up @@ -381,8 +381,6 @@ func TestCloseNotification(t *testing.T) {
var serverConn net.Conn
select {
case serverConn = <-serverConnChan:
case <-time.After(500 * time.Millisecond):
t.Fatal("Timed out waiting for connection closed notification")
}

err = serverConn.Close()
Expand Down Expand Up @@ -522,11 +520,7 @@ func TestIdleNoData(t *testing.T) {
go spdyConn.Serve(NoOpStreamHandler)

spdyConn.SetIdleTimeout(10 * time.Millisecond)
select {
case <-spdyConn.CloseChan():
case <-time.After(20 * time.Millisecond):
t.Fatal("Timed out waiting for idle connection closure")
}
<-spdyConn.CloseChan()

closeErr := server.Close()
if closeErr != nil {
Expand Down Expand Up @@ -577,8 +571,6 @@ func TestIdleWithData(t *testing.T) {

writesFinished := false

expired := time.NewTimer(200 * time.Millisecond)

Loop:
for {
select {
Expand All @@ -589,8 +581,6 @@ Loop:
t.Fatal("Connection closed before all writes finished")
}
break Loop
case <-expired.C:
t.Fatal("Timed out waiting for idle connection closure")
}
}

Expand Down Expand Up @@ -784,6 +774,109 @@ func TestStreamResetWithDataRemaining(t *testing.T) {
wg.Wait()
}

type roundTripper struct {
conn net.Conn
}

func (s *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
r := *req
req = &r

conn, err := net.Dial("tcp", req.URL.Host)
if err != nil {
return nil, err
}

err = req.Write(conn)
if err != nil {
return nil, err
}

resp, err := http.ReadResponse(bufio.NewReader(conn), req)
if err != nil {
return nil, err
}

s.conn = conn

return resp, nil
}

// see https://github.com/GoogleCloudPlatform/kubernetes/issues/4882
func TestFramingAfterRemoteConnectionClosed(t *testing.T) {
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
streamCh := make(chan *Stream)

w.WriteHeader(http.StatusSwitchingProtocols)

netconn, _, _ := w.(http.Hijacker).Hijack()
conn, _ := NewConnection(netconn, true)
go conn.Serve(func(s *Stream) {
s.SendReply(http.Header{}, false)
streamCh <- s
})

stream := <-streamCh
io.Copy(stream, stream)

closeChan := make(chan struct{})
go func() {
stream.Reset()
conn.Close()
close(closeChan)
}()

<-closeChan
}))

server.Start()
defer server.Close()

req, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
t.Fatalf("Error creating request: %s", err)
}

rt := &roundTripper{}
client := &http.Client{Transport: rt}

_, err = client.Do(req)
if err != nil {
t.Fatalf("unexpected error from client.Do: %s", err)
}

conn, err := NewConnection(rt.conn, false)
go conn.Serve(NoOpStreamHandler)

stream, err := conn.CreateStream(http.Header{}, nil, false)
if err != nil {
t.Fatalf("error creating client stream: %s", err)
}

n, err := stream.Write([]byte("hello"))
if err != nil {
t.Fatalf("error writing to stream: %s", err)
}
if n != 5 {
t.Fatalf("Expected to write 5 bytes, but actually wrote %d", n)
}

b := make([]byte, 5)
n, err = stream.Read(b)
if err != nil {
t.Fatalf("error reading from stream: %s", err)
}
if n != 5 {
t.Fatalf("Expected to read 5 bytes, but actually read %d", n)
}
if e, a := "hello", string(b[0:n]); e != a {
t.Fatalf("expected '%s', got '%s'", e, a)
}

stream.Reset()
conn.Close()
}

var authenticated bool

func authStreamHandler(stream *Stream) {
Expand Down
4 changes: 0 additions & 4 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ func (s *Stream) WriteData(data []byte, fin bool) error {
Data: data,
}

s.conn.writeLock.Lock()
defer s.conn.writeLock.Unlock()
debugMessage("(%p) (%d) Writing data frame", s, s.streamId)
return s.conn.framer.WriteFrame(dataFrame)
}
Expand Down Expand Up @@ -186,8 +184,6 @@ func (s *Stream) resetStream() error {
StreamId: s.streamId,
Status: spdy.Cancel,
}
s.conn.writeLock.Lock()
defer s.conn.writeLock.Unlock()
return s.conn.framer.WriteFrame(resetFrame)
}

Expand Down

0 comments on commit e731c8f

Please sign in to comment.