Skip to content

Commit

Permalink
fix new-websocket race in wstuncli
Browse files Browse the repository at this point in the history
  • Loading branch information
tve committed Aug 10, 2015
1 parent 716a6e6 commit 4766a8b
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 88 deletions.
7 changes: 5 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ default: $(NAME)
$(NAME): *.go version
go build -o $(NAME) .

gopath:
@echo "export GOPATH=$(GOPATH)"

# the standard build produces a "local" executable, a linux tgz, and a darwin (macos) tgz
build: depend $(NAME) build/$(NAME)-linux-amd64.tgz
# build/$(NAME)-darwin-amd64.tgz build/$(NAME)-linux-arm.tgz build/$(NAME)-windows-amd64.zip
Expand Down Expand Up @@ -130,6 +133,6 @@ travis-test: lint
# that if there are errors the output of gingko refers to incorrect line numbers
# tip: if you don't like colors use ginkgo -r -noColor
test: lint
ginkgo -r
ginkgo -r -cover
ginkgo -r -race -v
ginkgo -r -race -cover
go tool cover -func=`basename $$PWD`.coverprofile
2 changes: 1 addition & 1 deletion tunnel/misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ var _ = Describe("Testing misc requests", func() {
Ω(resp.StatusCode).Should(Equal(200))

// break the tunnel
wstuncli.ws.Close()
wstuncli.conn.ws.Close()
time.Sleep(20 * time.Millisecond)

// second request
Expand Down
184 changes: 99 additions & 85 deletions tunnel/wstuncli.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,32 @@ import (

var _ fmt.Formatter

// WSTunnelClient represents a persistent tunnel that can cycle through many websockets. The
// fields in this struct are relatively static/constant. The conn field points to the latest
// websocket, but it's important to realize that there may be goroutines handling older
// websockets that are not fully closed yet running at any point in time
type WSTunnelClient struct {
Token string // Rendez-vous token
Tunnel string // websocket server to connect to (ws[s]://hostname:port)
Server string // local HTTP(S) server to send received requests to (default server)
InternalServer http.Handler // internal Server to dispatch HTTP requests to
Regexp *regexp.Regexp // regexp for allowed local HTTP(S) servers
Insecure bool // accept self-signed SSL certs from local HTTPS servers
Timeout time.Duration // timeout on websocket
Connected bool // true when we have an active connection to wstunsrv
Proxy *url.URL // if non-nil, external proxy to use
Log log15.Logger // logger with "pkg=WStuncli"
StatusFd *os.File // output periodic tunnel status information
exitChan chan struct{} // channel to tell the tunnel goroutines to end
ws *websocket.Conn // websocket connection
Token string // Rendez-vous token
Tunnel string // websocket server to connect to (ws[s]://hostname:port)
Server string // local HTTP(S) server to send received requests to (default server)
InternalServer http.Handler // internal Server to dispatch HTTP requests to
Regexp *regexp.Regexp // regexp for allowed local HTTP(S) servers
Insecure bool // accept self-signed SSL certs from local HTTPS servers
Timeout time.Duration // timeout on websocket
Proxy *url.URL // if non-nil, external proxy to use
Log log15.Logger // logger with "pkg=WStuncli"
StatusFd *os.File // output periodic tunnel status information
Connected bool // true when we have an active connection to wstunsrv
exitChan chan struct{} // channel to tell the tunnel goroutines to end
conn *WSConnection
//ws *websocket.Conn // websocket connection
}

// WSConnection represents a single websocket connection
type WSConnection struct {
Log log15.Logger // logger with "ws=0x1234"
ws *websocket.Conn // websocket connection
tun *WSTunnelClient // link back to tunnel
}

var httpClient http.Client // client used for all requests, gets special transport for -insecure
Expand Down Expand Up @@ -218,9 +230,7 @@ func (t *WSTunnelClient) Start() error {
url := fmt.Sprintf("%s/_tunnel", t.Tunnel)
timer := time.NewTimer(10 * time.Second)
t.Log.Info("WS Opening", "url", url, "token", t.Token[0:5]+"...")
var err error
var resp *http.Response
t.ws, resp, err = d.Dial(url, h)
ws, resp, err := d.Dial(url, h)
if err != nil {
extra := ""
if resp != nil {
Expand All @@ -235,16 +245,18 @@ func (t *WSTunnelClient) Start() error {
t.Log.Error("Error opening connection",
"err", err.Error(), "info", extra)
} else {
t.conn = &WSConnection{ws: ws, tun: t,
Log: t.Log.New("ws", fmt.Sprintf("%p", ws))}
// Safety setting
t.ws.SetReadLimit(100 * 1024 * 1024)
ws.SetReadLimit(100 * 1024 * 1024)
// Request Loop
srv := t.Server
if t.InternalServer != nil {
srv = "<internal>"
}
t.Log.Info("WS ready", "server", srv)
t.conn.Log.Info("WS ready", "server", srv)
t.Connected = true
t.handleWsRequests()
t.conn.handleRequests()
t.Connected = false
}
// check whether we need to exit
Expand All @@ -267,104 +279,106 @@ func (t *WSTunnelClient) Stop() {

// Main function to handle WS requests: it reads a request from the socket, then forks
// a goroutine to perform the actual http request and return the result
func (t *WSTunnelClient) handleWsRequests() {
go t.pinger()
func (wsc *WSConnection) handleRequests() {
go wsc.pinger()
for {
t.ws.SetReadDeadline(time.Time{}) // separate ping-pong routine does timeout
typ, r, err := t.ws.NextReader()
wsc.ws.SetReadDeadline(time.Time{}) // separate ping-pong routine does timeout
typ, r, err := wsc.ws.NextReader()
if err != nil {
t.Log.Info("WS ReadMessage", "err", err.Error())
wsc.Log.Info("WS ReadMessage", "err", err.Error())
break
}
if typ != websocket.BinaryMessage {
t.Log.Info("WS invalid message type", "type", typ)
wsc.Log.Info("WS invalid message type", "type", typ)
break
}
// give the sender a minute to produce the request
t.ws.SetReadDeadline(time.Now().Add(time.Minute))
wsc.ws.SetReadDeadline(time.Now().Add(time.Minute))
// read request id
var id int16
_, err = fmt.Fscanf(io.LimitReader(r, 4), "%04x", &id)
if err != nil {
t.Log.Info("WS cannot read request ID", "err", err.Error())
wsc.Log.Info("WS cannot read request ID", "err", err.Error())
break
}
// read request itself
req, err := http.ReadRequest(bufio.NewReader(r))
if err != nil {
t.Log.Info("WS cannot read request body", "id", id, "err", err.Error())
wsc.Log.Info("WS cannot read request body", "id", id, "err", err.Error())
break
}
// Hand off to goroutine to finish off while we read the next request
if t.InternalServer != nil {
go t.finishInternalRequest(id, req)
if wsc.tun.InternalServer != nil {
go wsc.finishInternalRequest(id, req)
} else {
go t.finishRequest(id, req)
go wsc.finishRequest(id, req)
}
}
// delay a few seconds to allow for writes to drain and then force-close the socket
go func() {
time.Sleep(5 * time.Second)
t.ws.Close()
wsc.ws.Close()
}()
}

//===== Keep-alive ping-pong =====

// Pinger that keeps connections alive and terminates them if they seem stuck
func (t *WSTunnelClient) pinger() {
func (wsc *WSConnection) pinger() {
defer func() {
// panics may occur in WriteControl (in unit tests at least) for closed
// websocket connections
if x := recover(); x != nil {
t.Log.Error("Panic in pinger", "err", x)
wsc.Log.Error("Panic in pinger", "err", x)
}
}()
t.Log.Info("pinger starting")
wsc.Log.Info("pinger starting")
tunTimeout := wsc.tun.Timeout

// timeout handler sends a close message, waits a few seconds, then kills the socket
timeout := func() {
if t.ws == nil {
if wsc.ws == nil {
return
}
t.ws.WriteControl(websocket.CloseMessage, nil, time.Now().Add(1*time.Second))
t.Log.Info("ping timeout, closing WS")
wsc.ws.WriteControl(websocket.CloseMessage, nil, time.Now().Add(1*time.Second))
wsc.Log.Info("ping timeout, closing WS")
time.Sleep(5 * time.Second)
if t.ws != nil {
t.ws.Close()
if wsc.ws != nil {
wsc.ws.Close()
}
}
// timeout timer
timer := time.AfterFunc(t.Timeout, timeout)
timer := time.AfterFunc(tunTimeout, timeout)
// pong handler resets last pong time
ph := func(message string) error {
timer.Reset(t.Timeout)
if t.StatusFd != nil {
t.StatusFd.Seek(0, 0)
t.writeStatus()
pos, _ := t.StatusFd.Seek(0, 1)
t.StatusFd.Truncate(pos)
timer.Reset(tunTimeout)
if sf := wsc.tun.StatusFd; sf != nil {
sf.Seek(0, 0)
wsc.writeStatus()
pos, _ := sf.Seek(0, 1)
sf.Truncate(pos)
}
return nil
}
t.ws.SetPongHandler(ph)
wsc.ws.SetPongHandler(ph)
// ping loop, ends when socket is closed...
for {
if t.ws == nil {
if wsc.ws == nil {
break
}
err := t.ws.WriteControl(websocket.PingMessage, nil, time.Now().Add(t.Timeout/3))
err := wsc.ws.WriteControl(websocket.PingMessage, nil, time.Now().Add(tunTimeout/3))
if err != nil {
break
}
time.Sleep(t.Timeout / 3)
time.Sleep(tunTimeout / 3)
}
t.Log.Info("pinger ending (WS errored or closed)")
t.ws.Close()
wsc.Log.Info("pinger ending (WS errored or closed)")
wsc.ws.Close()
}

func (t *WSTunnelClient) writeStatus() {
fmt.Fprintf(t.StatusFd, "Unix: %d\n", time.Now().Unix())
fmt.Fprintf(t.StatusFd, "Time: %s\n", time.Now().UTC().Format(time.RFC3339))
func (wsc *WSConnection) writeStatus() {
fmt.Fprintf(wsc.tun.StatusFd, "Unix: %d\n", time.Now().Unix())
fmt.Fprintf(wsc.tun.StatusFd, "Time: %s\n", time.Now().UTC().Format(time.RFC3339))
}

//===== Proxy support =====
Expand Down Expand Up @@ -507,8 +521,8 @@ var wsWriterMutex sync.Mutex // mutex to allow a single goroutine to send a resp
// Issue a request to an internal handler. This duplicates some logic found in
// net.http.serve http://golang.org/src/net/http/server.go?#L1124 and
// net.http.readRequest http://golang.org/src/net/http/server.go?#L
func (t *WSTunnelClient) finishInternalRequest(id int16, req *http.Request) {
log := t.Log.New("id", id, "verb", req.Method, "uri", req.RequestURI)
func (wsc *WSConnection) finishInternalRequest(id int16, req *http.Request) {
log := wsc.Log.New("id", id, "verb", req.Method, "uri", req.RequestURI)
log.Info("HTTP issuing internal request")

// Remove hop-by-hop headers
Expand Down Expand Up @@ -539,48 +553,49 @@ func (t *WSTunnelClient) finishInternalRequest(id int16, req *http.Request) {
rw := newResponseWriter(req)

// Issue the request to the HTTP server
t.InternalServer.ServeHTTP(rw, req)
wsc.tun.InternalServer.ServeHTTP(rw, req)

err := rw.finishResponse()
if err != nil {
//dump2, _ := httputil.DumpResponse(resp, true)
//log15.Info("handleWsRequests: request error", "err", err.Error(),
// "req", string(dump), "resp", string(dump2))
log.Info("HTTP request error", "err", err.Error())
writeResponseMessage(t, id, concoctResponse(req, err.Error(), 502))
wsc.writeResponseMessage(id, concoctResponse(req, err.Error(), 502))
return
}

log.Info("HTTP responded", "status", rw.resp.StatusCode)
writeResponseMessage(t, id, rw.resp)
wsc.writeResponseMessage(id, rw.resp)
}

func (t *WSTunnelClient) finishRequest(id int16, req *http.Request) {
func (wsc *WSConnection) finishRequest(id int16, req *http.Request) {

log := t.Log.New("id", id, "verb", req.Method, "uri", req.RequestURI)
log := wsc.Log.New("id", id, "verb", req.Method, "uri", req.RequestURI)

// Honor X-Host header
host := t.Server
host := wsc.tun.Server
xHost := req.Header.Get("X-Host")
if xHost != "" {
if t.Regexp == nil {
re := wsc.tun.Regexp
if re == nil {
log.Info("WS got x-host header but no regexp provided")
writeResponseMessage(t, id, concoctResponse(req,
wsc.writeResponseMessage(id, concoctResponse(req,
"X-Host header disallowed by wstunnel cli (no -regexp option)", 403))
return
} else if t.Regexp.FindString(xHost) == xHost {
} else if re.FindString(xHost) == xHost {
host = xHost
} else {
log.Info("WS x-host disallowed by regexp", "x-host", xHost, "regexp",
t.Regexp.String(), "match", t.Regexp.FindString(xHost))
writeResponseMessage(t, id, concoctResponse(req,
re.String(), "match", re.FindString(xHost))
wsc.writeResponseMessage(id, concoctResponse(req,
"X-Host header '"+xHost+"' does not match regexp in wstunnel cli",
403))
return
}
} else if host == "" {
log.Info("WS no x-host header and -server not specified")
writeResponseMessage(t, id, concoctResponse(req,
wsc.writeResponseMessage(id, concoctResponse(req,
"X-Host header required by wstunnel cli (no -server option)", 403))
return
}
Expand All @@ -591,8 +606,7 @@ func (t *WSTunnelClient) finishRequest(id int16, req *http.Request) {
req.URL, err = url.Parse(fmt.Sprintf("%s%s", host, req.RequestURI))
if err != nil {
log.Warn("WS cannot parse requestURI", "err", err.Error())
writeResponseMessage(t, id, concoctResponse(req,
"Cannot parse request URI", 400))
wsc.writeResponseMessage(id, concoctResponse(req, "Cannot parse request URI", 400))
return
}
req.Host = req.URL.Host // we delete req.Header["Host"] further down
Expand All @@ -612,51 +626,51 @@ func (t *WSTunnelClient) finishRequest(id int16, req *http.Request) {
//log15.Info("handleWsRequests: request error", "err", err.Error(),
// "req", string(dump), "resp", string(dump2))
log.Info("HTTP request error", "err", err.Error())
writeResponseMessage(t, id, concoctResponse(req, err.Error(), 502))
wsc.writeResponseMessage(id, concoctResponse(req, err.Error(), 502))
return
}
log.Info("HTTP responded", "status", resp.Status)
defer resp.Body.Close()

writeResponseMessage(t, id, resp)
wsc.writeResponseMessage(id, resp)
}

// Write the response message to the websocket
func writeResponseMessage(t *WSTunnelClient, id int16, resp *http.Response) {
func (wsc *WSConnection) writeResponseMessage(id int16, resp *http.Response) {
// Get writer's lock
wsWriterMutex.Lock()
defer wsWriterMutex.Unlock()
// Write response into the tunnel
t.ws.SetWriteDeadline(time.Now().Add(time.Minute))
w, err := t.ws.NextWriter(websocket.BinaryMessage)
wsc.ws.SetWriteDeadline(time.Now().Add(time.Minute))
w, err := wsc.ws.NextWriter(websocket.BinaryMessage)
// got an error, reply with a "hey, retry" to the request handler
if err != nil {
t.Log.Warn("WS NextWriter", "err", err.Error())
t.ws.Close()
wsc.Log.Warn("WS NextWriter", "err", err.Error())
wsc.ws.Close()
return
}

// write the request Id
_, err = fmt.Fprintf(w, "%04x", id)
if err != nil {
t.Log.Warn("WS cannot write request Id", "err", err.Error())
t.ws.Close()
wsc.Log.Warn("WS cannot write request Id", "err", err.Error())
wsc.ws.Close()
return
}

// write the response itself
err = resp.Write(w)
if err != nil {
t.Log.Warn("WS cannot write response", "err", err.Error())
t.ws.Close()
wsc.Log.Warn("WS cannot write response", "err", err.Error())
wsc.ws.Close()
return
}

// done
err = w.Close()
if err != nil {
t.Log.Warn("WS write-close failed", "err", err.Error())
t.ws.Close()
wsc.Log.Warn("WS write-close failed", "err", err.Error())
wsc.ws.Close()
return
}
}
Expand Down

0 comments on commit 4766a8b

Please sign in to comment.