Skip to content

Commit

Permalink
Merge pull request #41 from ncdc/idle-timeout
Browse files Browse the repository at this point in the history
Change idle implementation to be graceful
  • Loading branch information
dmcgowan committed Feb 2, 2015
2 parents 53b120e + 6e695ab commit 13f2d13
Show file tree
Hide file tree
Showing 3 changed files with 230 additions and 51 deletions.
105 changes: 64 additions & 41 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,60 @@ type StreamHandler func(stream *Stream)
type AuthHandler func(header http.Header, slot uint8, parent uint32) bool

type idleAwareFramer struct {
f *spdy.Framer
conn *Connection
f *spdy.Framer
conn *Connection
expired *time.Timer
startCh chan struct{}
}

func newIdleAwareFramer(framer *spdy.Framer) *idleAwareFramer {
iaf := &idleAwareFramer{
f: framer,
startCh: make(chan struct{}, 1),
}
go iaf.monitor()
return iaf
}

func (i *idleAwareFramer) monitor() {
// wait for a non-zero timeout
<-i.startCh

Loop:
for {
select {
case <-i.startCh:
// no-op
case <-i.expired.C:
for _, stream := range i.conn.streams {
stream.Reset()
}
i.conn.Close()
break Loop
case <-i.conn.closeChan:
break Loop
}
}
}

func (i *idleAwareFramer) setTimeout(timeout time.Duration) {
switch {
case timeout == 0:
if i.expired != nil {
i.expired.Stop()
}
case timeout > 0:
// TODO there may be a race condition with multiple goroutines calling this,
// as there's no lock around i.expired. A lock would probably result in
// decreased performance, but that needs to be explored.
if i.expired == nil {
i.expired = time.NewTimer(timeout)
// tell the monitor it can start waiting for a timeout
i.startCh <- struct{}{}
} else {
i.expired.Reset(timeout)
}
}
}

func (i *idleAwareFramer) WriteFrame(frame spdy.Frame) error {
Expand All @@ -40,7 +92,7 @@ func (i *idleAwareFramer) WriteFrame(frame spdy.Frame) error {
}

if i.conn.idleTimeout > 0 {
i.conn.conn.SetDeadline(time.Now().Add(i.conn.idleTimeout))
i.setTimeout(i.conn.idleTimeout)
}
return nil
}
Expand All @@ -52,7 +104,7 @@ func (i *idleAwareFramer) ReadFrame() (spdy.Frame, error) {
}

if i.conn.idleTimeout > 0 {
i.conn.conn.SetDeadline(time.Now().Add(i.conn.idleTimeout))
i.setTimeout(i.conn.idleTimeout)
}
return frame, nil
}
Expand Down Expand Up @@ -94,7 +146,7 @@ func NewConnection(conn net.Conn, server bool) (*Connection, error) {
if framerErr != nil {
return nil, framerErr
}
idleAwareFramer := &idleAwareFramer{f: framer}
idleAwareFramer := newIdleAwareFramer(framer)
var sid spdy.StreamId
var rid spdy.StreamId
var pid uint32
Expand Down Expand Up @@ -245,12 +297,7 @@ func (s *Connection) Serve(newHandler StreamHandler) {
// notify streams that they're now closed, which will
// unblock any stream Read() calls
for _, stream := range s.streams {
select {
case <-stream.closeChan:
// do nothing, stream is already closed
default:
close(stream.closeChan)
}
stream.closeRemoteChannels()
}
s.streams = make(map[spdy.StreamId]*Stream)
s.streamCond.Broadcast()
Expand Down Expand Up @@ -317,8 +364,7 @@ func (s *Connection) addStreamFrame(frame *spdy.SynStreamFrame) {
closeChan: make(chan bool),
}
if frame.CFHeader.Flags&spdy.ControlFlagFin != 0x00 {
close(stream.dataChan)
close(stream.closeChan)
stream.closeRemoteChannels()
}

s.addStream(stream)
Expand Down Expand Up @@ -388,15 +434,7 @@ func (s *Connection) handleResetFrame(frame *spdy.RstStreamFrame) error {
return nil
}
s.removeStream(stream)
stream.dataLock.Lock()
select {
case <-stream.closeChan:
break
default:
close(stream.dataChan)
close(stream.closeChan)
}
stream.dataLock.Unlock()
stream.closeRemoteChannels()

if !stream.replied {
stream.replied = true
Expand Down Expand Up @@ -455,10 +493,8 @@ func (s *Connection) handleDataFrame(frame *spdy.DataFrame) error {
stream.dataLock.RLock()
select {
case <-stream.closeChan:
break
default:
debugMessage("(%p) (%d) Data frame send chan", stream, stream.streamId)
stream.dataChan <- frame.Data
debugMessage("(%p) (%d) Data frame not sent (stream shut down)", stream, stream.streamId)
case stream.dataChan <- frame.Data:
debugMessage("(%p) (%d) Data frame sent", stream, stream.streamId)
}
stream.dataLock.RUnlock()
Expand Down Expand Up @@ -506,16 +542,7 @@ func (s *Connection) handleGoAwayFrame(frame *spdy.GoAwayFrame) error {
}

func (s *Connection) remoteStreamFinish(stream *Stream) {
// synchronize closing channel
stream.dataLock.Lock()
select {
case <-stream.closeChan:
break
default:
close(stream.dataChan)
close(stream.closeChan)
}
stream.dataLock.Unlock()
stream.closeRemoteChannels()

stream.finishLock.Lock()
if stream.finished {
Expand Down Expand Up @@ -704,11 +731,7 @@ func (s *Connection) SetCloseTimeout(timeout time.Duration) {
// it is forcefully terminated.
func (s *Connection) SetIdleTimeout(timeout time.Duration) {
s.idleTimeout = timeout
if timeout > 0 {
s.conn.SetDeadline(time.Now().Add(s.idleTimeout))
} else {
s.conn.SetDeadline(time.Time{})
}
s.framer.setTimeout(timeout)
}

func (s *Connection) sendHeaders(headers http.Header, stream *Stream, fin bool) error {
Expand Down
149 changes: 149 additions & 0 deletions spdy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,7 @@ func TestIdleWithData(t *testing.T) {

spdyConn.SetIdleTimeout(25 * time.Millisecond)

authenticated = true
stream, err := spdyConn.CreateStream(http.Header{}, nil, false)
if err != nil {
t.Fatalf("Error creating stream: %v", err)
Expand Down Expand Up @@ -559,6 +560,154 @@ Loop:
wg.Wait()
}

func TestHalfClosedIdleTimeout(t *testing.T) {
listener, listenErr := net.Listen("tcp", "localhost:0")
if listenErr != nil {
t.Fatalf("Error listening: %v", listenErr)
}
listen := listener.Addr().String()

var serverConn net.Conn
var serverSpdyConn *Connection
var err error
go func() {
serverConn, err = listener.Accept()
if err != nil {
t.Fatalf("Error accepting: %v", err)
}

serverSpdyConn, err = NewConnection(serverConn, true)
if err != nil {
t.Fatalf("Error creating server connection: %v", err)
}
go serverSpdyConn.Serve(func(s *Stream) {
s.SendReply(http.Header{}, true)
})
serverSpdyConn.SetIdleTimeout(10 * time.Millisecond)
}()

conn, dialErr := net.Dial("tcp", listen)
if dialErr != nil {
t.Fatalf("Error dialing server: %s", dialErr)
}

spdyConn, spdyErr := NewConnection(conn, false)
if spdyErr != nil {
t.Fatalf("Error creating spdy connection: %s", spdyErr)
}
go spdyConn.Serve(NoOpStreamHandler)

stream, err := spdyConn.CreateStream(http.Header{}, nil, false)
if err != nil {
t.Fatalf("Error creating stream: %v", err)
}

time.Sleep(20 * time.Millisecond)

stream.Reset()

err = spdyConn.Close()
if err != nil {
t.Fatalf("Error closing client spdy conn: %v", err)
}
}

func TestStreamReset(t *testing.T) {
var wg sync.WaitGroup
server, listen, serverErr := runServer(&wg)
if serverErr != nil {
t.Fatalf("Error initializing server: %s", serverErr)
}

conn, dialErr := net.Dial("tcp", listen)
if dialErr != nil {
t.Fatalf("Error dialing server: %s", dialErr)
}

spdyConn, spdyErr := NewConnection(conn, false)
if spdyErr != nil {
t.Fatalf("Error creating spdy connection: %s", spdyErr)
}
go spdyConn.Serve(NoOpStreamHandler)

authenticated = true
stream, streamErr := spdyConn.CreateStream(http.Header{}, nil, false)
if streamErr != nil {
t.Fatalf("Error creating stream: %s", streamErr)
}

buf := []byte("dskjahfkdusahfkdsahfkdsafdkas")
for i := 0; i < 10; i++ {
if _, err := stream.Write(buf); err != nil {
t.Fatalf("Error writing to stream: %s", err)
}
}
for i := 0; i < 10; i++ {
if _, err := stream.Read(buf); err != nil {
t.Fatalf("Error reading from stream: %s", err)
}
}

// fmt.Printf("Resetting...\n")
if err := stream.Reset(); err != nil {
t.Fatalf("Error reseting stream: %s", err)
}

closeErr := server.Close()
if closeErr != nil {
t.Fatalf("Error shutting down server: %s", closeErr)
}
wg.Wait()
}

func TestStreamResetWithDataRemaining(t *testing.T) {
var wg sync.WaitGroup
server, listen, serverErr := runServer(&wg)
if serverErr != nil {
t.Fatalf("Error initializing server: %s", serverErr)
}

conn, dialErr := net.Dial("tcp", listen)
if dialErr != nil {
t.Fatalf("Error dialing server: %s", dialErr)
}

spdyConn, spdyErr := NewConnection(conn, false)
if spdyErr != nil {
t.Fatalf("Error creating spdy connection: %s", spdyErr)
}
go spdyConn.Serve(NoOpStreamHandler)

authenticated = true
stream, streamErr := spdyConn.CreateStream(http.Header{}, nil, false)
if streamErr != nil {
t.Fatalf("Error creating stream: %s", streamErr)
}

buf := []byte("dskjahfkdusahfkdsahfkdsafdkas")
for i := 0; i < 10; i++ {
if _, err := stream.Write(buf); err != nil {
t.Fatalf("Error writing to stream: %s", err)
}
}

// read a bit to make sure a goroutine gets to <-dataChan
if _, err := stream.Read(buf); err != nil {
t.Fatalf("Error reading from stream: %s", err)
}

// fmt.Printf("Resetting...\n")
if err := stream.Reset(); err != nil {
t.Fatalf("Error reseting stream: %s", err)
}

closeErr := server.Close()
if closeErr != nil {
t.Fatalf("Error shutting down server: %s", closeErr)
}
wg.Wait()
}

var authenticated bool

func authStreamHandler(stream *Stream) {
Expand Down
27 changes: 17 additions & 10 deletions stream.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package spdystream

import (
"code.google.com/p/go.net/spdy"
"errors"
"fmt"
"io"
"net"
"net/http"
"sync"
"time"

"code.google.com/p/go.net/spdy"
)

var (
Expand All @@ -32,6 +33,7 @@ type Stream struct {
finished bool
replyCond *sync.Cond
replied bool
closeLock sync.Mutex
closeChan chan bool
}

Expand Down Expand Up @@ -175,15 +177,7 @@ func (s *Stream) Reset() error {
s.finished = true
s.finishLock.Unlock()

s.dataLock.Lock()
select {
case <-s.closeChan:
break
default:
close(s.dataChan)
close(s.closeChan)
}
s.dataLock.Unlock()
s.closeRemoteChannels()

resetFrame := &spdy.RstStreamFrame{
StreamId: s.streamId,
Expand Down Expand Up @@ -319,3 +313,16 @@ func (s *Stream) SetReadDeadline(t time.Time) error {
func (s *Stream) SetWriteDeadline(t time.Time) error {
return s.conn.conn.SetWriteDeadline(t)
}

func (s *Stream) closeRemoteChannels() {
s.closeLock.Lock()
defer s.closeLock.Unlock()
select {
case <-s.closeChan:
default:
close(s.closeChan)
s.dataLock.Lock()
defer s.dataLock.Unlock()
close(s.dataChan)
}
}

0 comments on commit 13f2d13

Please sign in to comment.