From b0e2b5d9e40719b4b3688cc2f6d3ad19bcf5aa64 Mon Sep 17 00:00:00 2001 From: Arjan Bal Date: Thu, 18 Jul 2024 12:16:21 +0530 Subject: [PATCH 1/4] Discard the buffer when empty after http connect handshake --- internal/transport/proxy.go | 10 ++++-- internal/transport/proxy_test.go | 52 ++++++++++++++++++++++++++++++-- 2 files changed, 58 insertions(+), 4 deletions(-) diff --git a/internal/transport/proxy.go b/internal/transport/proxy.go index 24fa1032574c..54b224436544 100644 --- a/internal/transport/proxy.go +++ b/internal/transport/proxy.go @@ -107,8 +107,14 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr stri } return nil, fmt.Errorf("failed to do connect handshake, response: %q", dump) } - - return &bufConn{Conn: conn, r: r}, nil + // The buffer could contain extra bytes from the target server, so we can't + // discard it. However, in many cases where the server waits for the client + // to send the first message (e.g. when TLS is being used), the buffer will + // be empty, so we can avoid the overhead of reading through this buffer. + if r.Buffered() != 0 { + return &bufConn{Conn: conn, r: r}, nil + } + return conn, nil } // proxyDial dials, connecting to a proxy first if necessary. Checks if a proxy diff --git a/internal/transport/proxy_test.go b/internal/transport/proxy_test.go index 8abee1e7b383..8790d20c146b 100644 --- a/internal/transport/proxy_test.go +++ b/internal/transport/proxy_test.go @@ -23,6 +23,7 @@ package transport import ( "bufio" + "bytes" "context" "encoding/base64" "fmt" @@ -84,7 +85,24 @@ func (p *proxyServer) run() { return } resp := http.Response{StatusCode: http.StatusOK, Proto: "HTTP/1.0"} - resp.Write(p.in) + var buf bytes.Buffer + resp.Write(&buf) + // Batch the first message from the server with the http connect + // response. This is done to test the cases in which the grpc client has + // the response to the connect request and proxied packets from the + // destination server when it reads the transport. + out.SetReadDeadline(time.Now().Add(20 * time.Millisecond)) + b := make([]byte, 50) + bytesRead, err := out.Read(b) + // The read is expected to fail with deadline exceeded if the server doesn't + // have a message to send. + if err != nil { + p.t.Logf("Got error while reading message from server: %v", err) + } + // reset the deadline. + out.SetReadDeadline(time.Time{}) + buf.Write(b[0:bytesRead]) + p.in.Write(buf.Bytes()) p.out = out go io.Copy(p.in, p.out) go io.Copy(p.out, p.in) @@ -100,7 +118,7 @@ func (p *proxyServer) stop() { } } -func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxyReqCheck func(*http.Request) error) { +func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxyReqCheck func(*http.Request) error, serverMessage []byte) { plis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("failed to listen: %v", err) @@ -128,6 +146,7 @@ func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxy return } defer in.Close() + in.Write(serverMessage) in.Read(recvBuf) done <- nil }() @@ -157,6 +176,18 @@ func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxy if string(recvBuf) != string(msg) { t.Fatalf("received msg: %v, want %v", recvBuf, msg) } + + c.SetReadDeadline(time.Now().Add(20 * time.Millisecond)) + + gotServerMessage := make([]byte, len(serverMessage)) + // This call will return a deadline exceeded error if the server has nothing + // to send. This is expected. + if _, err := c.Read(gotServerMessage); err != nil { + t.Logf("Got error while reading message from server: %v", err) + } + if string(gotServerMessage) != string(serverMessage) { + t.Fatalf("message from server: %v, want %v", gotServerMessage, serverMessage) + } } func (s) TestHTTPConnect(t *testing.T) { @@ -170,6 +201,22 @@ func (s) TestHTTPConnect(t *testing.T) { } return nil }, + nil, + ) +} + +func (s) TestHTTPConnectWithServerHello(t *testing.T) { + testHTTPConnect(t, + func(in *url.URL) *url.URL { + return in + }, + func(req *http.Request) error { + if req.Method != http.MethodConnect { + return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect) + } + return nil + }, + []byte("server-hello"), ) } @@ -195,6 +242,7 @@ func (s) TestHTTPConnectBasicAuth(t *testing.T) { } return nil }, + nil, ) } From 3cf3bd6776830a74803adf358ff4f76488f04340 Mon Sep 17 00:00:00 2001 From: Arjan Bal Date: Fri, 26 Jul 2024 15:54:39 +0530 Subject: [PATCH 2/4] configure the proxy to wait for server hello --- internal/transport/proxy_test.go | 51 ++++++++++++++++---------------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/internal/transport/proxy_test.go b/internal/transport/proxy_test.go index 8790d20c146b..3cd9fee8d21d 100644 --- a/internal/transport/proxy_test.go +++ b/internal/transport/proxy_test.go @@ -59,7 +59,7 @@ type proxyServer struct { requestCheck func(*http.Request) error } -func (p *proxyServer) run() { +func (p *proxyServer) run(waitForServerHello bool) { in, err := p.lis.Accept() if err != nil { return @@ -84,24 +84,23 @@ func (p *proxyServer) run() { p.t.Errorf("failed to dial to server: %v", err) return } + out.SetReadDeadline(time.Now().Add(defaultTestTimeout)) resp := http.Response{StatusCode: http.StatusOK, Proto: "HTTP/1.0"} var buf bytes.Buffer resp.Write(&buf) - // Batch the first message from the server with the http connect - // response. This is done to test the cases in which the grpc client has - // the response to the connect request and proxied packets from the - // destination server when it reads the transport. - out.SetReadDeadline(time.Now().Add(20 * time.Millisecond)) - b := make([]byte, 50) - bytesRead, err := out.Read(b) - // The read is expected to fail with deadline exceeded if the server doesn't - // have a message to send. - if err != nil { - p.t.Logf("Got error while reading message from server: %v", err) + if waitForServerHello { + // Batch the first message from the server with the http connect + // response. This is done to test the cases in which the grpc client has + // the response to the connect request and proxied packets from the + // destination server when it reads the transport. + b := make([]byte, 50) + bytesRead, err := out.Read(b) + if err != nil { + p.t.Errorf("Got error while reading server hello: %v", err) + return + } + buf.Write(b[0:bytesRead]) } - // reset the deadline. - out.SetReadDeadline(time.Time{}) - buf.Write(b[0:bytesRead]) p.in.Write(buf.Bytes()) p.out = out go io.Copy(p.in, p.out) @@ -128,7 +127,7 @@ func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxy lis: plis, requestCheck: proxyReqCheck, } - go p.run() + go p.run(len(serverMessage) > 0) defer p.stop() blis, err := net.Listen("tcp", "localhost:0") @@ -177,16 +176,16 @@ func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxy t.Fatalf("received msg: %v, want %v", recvBuf, msg) } - c.SetReadDeadline(time.Now().Add(20 * time.Millisecond)) - - gotServerMessage := make([]byte, len(serverMessage)) - // This call will return a deadline exceeded error if the server has nothing - // to send. This is expected. - if _, err := c.Read(gotServerMessage); err != nil { - t.Logf("Got error while reading message from server: %v", err) - } - if string(gotServerMessage) != string(serverMessage) { - t.Fatalf("message from server: %v, want %v", gotServerMessage, serverMessage) + if len(serverMessage) > 0 { + c.SetReadDeadline(time.Now().Add(defaultTestTimeout)) + gotServerMessage := make([]byte, len(serverMessage)) + if _, err := c.Read(gotServerMessage); err != nil { + t.Errorf("Got error while reading message from server: %v", err) + return + } + if string(gotServerMessage) != string(serverMessage) { + t.Fatalf("message from server: %v, want %v", gotServerMessage, serverMessage) + } } } From c17668c71c8b7a99923690eae03c0581271073c4 Mon Sep 17 00:00:00 2001 From: Arjan Bal Date: Fri, 26 Jul 2024 16:07:19 +0530 Subject: [PATCH 3/4] Extract test args to a struct --- internal/transport/proxy_test.go | 55 ++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/internal/transport/proxy_test.go b/internal/transport/proxy_test.go index 3cd9fee8d21d..59cc528abb88 100644 --- a/internal/transport/proxy_test.go +++ b/internal/transport/proxy_test.go @@ -117,7 +117,13 @@ func (p *proxyServer) stop() { } } -func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxyReqCheck func(*http.Request) error, serverMessage []byte) { +type testArgs struct { + proxyURLModify func(*url.URL) *url.URL + proxyReqCheck func(*http.Request) error + serverMessage []byte +} + +func testHTTPConnect(t *testing.T, args testArgs) { plis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("failed to listen: %v", err) @@ -125,9 +131,9 @@ func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxy p := &proxyServer{ t: t, lis: plis, - requestCheck: proxyReqCheck, + requestCheck: args.proxyReqCheck, } - go p.run(len(serverMessage) > 0) + go p.run(len(args.serverMessage) > 0) defer p.stop() blis, err := net.Listen("tcp", "localhost:0") @@ -145,14 +151,14 @@ func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxy return } defer in.Close() - in.Write(serverMessage) + in.Write(args.serverMessage) in.Read(recvBuf) done <- nil }() // Overwrite the function in the test and restore them in defer. hpfe := func(req *http.Request) (*url.URL, error) { - return proxyURLModify(&url.URL{Host: plis.Addr().String()}), nil + return args.proxyURLModify(&url.URL{Host: plis.Addr().String()}), nil } defer overwrite(hpfe)() @@ -176,47 +182,48 @@ func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxy t.Fatalf("received msg: %v, want %v", recvBuf, msg) } - if len(serverMessage) > 0 { + if len(args.serverMessage) > 0 { c.SetReadDeadline(time.Now().Add(defaultTestTimeout)) - gotServerMessage := make([]byte, len(serverMessage)) + gotServerMessage := make([]byte, len(args.serverMessage)) if _, err := c.Read(gotServerMessage); err != nil { t.Errorf("Got error while reading message from server: %v", err) return } - if string(gotServerMessage) != string(serverMessage) { - t.Fatalf("message from server: %v, want %v", gotServerMessage, serverMessage) + if string(gotServerMessage) != string(args.serverMessage) { + t.Errorf("message from server: %v, want %v", gotServerMessage, args.serverMessage) } } } func (s) TestHTTPConnect(t *testing.T) { - testHTTPConnect(t, - func(in *url.URL) *url.URL { + args := testArgs{ + proxyURLModify: func(in *url.URL) *url.URL { return in }, - func(req *http.Request) error { + proxyReqCheck: func(req *http.Request) error { if req.Method != http.MethodConnect { return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect) } return nil }, - nil, - ) + } + testHTTPConnect(t, args) } func (s) TestHTTPConnectWithServerHello(t *testing.T) { - testHTTPConnect(t, - func(in *url.URL) *url.URL { + args := testArgs{ + proxyURLModify: func(in *url.URL) *url.URL { return in }, - func(req *http.Request) error { + proxyReqCheck: func(req *http.Request) error { if req.Method != http.MethodConnect { return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect) } return nil }, - []byte("server-hello"), - ) + serverMessage: []byte("server-hello"), + } + testHTTPConnect(t, args) } func (s) TestHTTPConnectBasicAuth(t *testing.T) { @@ -224,12 +231,12 @@ func (s) TestHTTPConnectBasicAuth(t *testing.T) { user = "notAUser" password = "notAPassword" ) - testHTTPConnect(t, - func(in *url.URL) *url.URL { + args := testArgs{ + proxyURLModify: func(in *url.URL) *url.URL { in.User = url.UserPassword(user, password) return in }, - func(req *http.Request) error { + proxyReqCheck: func(req *http.Request) error { if req.Method != http.MethodConnect { return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect) } @@ -241,8 +248,8 @@ func (s) TestHTTPConnectBasicAuth(t *testing.T) { } return nil }, - nil, - ) + } + testHTTPConnect(t, args) } func (s) TestMapAddressEnv(t *testing.T) { From db50122c8456c398e0061888936ba6b108bc7c83 Mon Sep 17 00:00:00 2001 From: Arjan Bal Date: Tue, 30 Jul 2024 18:54:35 +0530 Subject: [PATCH 4/4] Change deadline sets --- internal/transport/proxy_test.go | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/internal/transport/proxy_test.go b/internal/transport/proxy_test.go index 59cc528abb88..9fdd662ddd82 100644 --- a/internal/transport/proxy_test.go +++ b/internal/transport/proxy_test.go @@ -84,7 +84,7 @@ func (p *proxyServer) run(waitForServerHello bool) { p.t.Errorf("failed to dial to server: %v", err) return } - out.SetReadDeadline(time.Now().Add(defaultTestTimeout)) + out.SetDeadline(time.Now().Add(defaultTestTimeout)) resp := http.Response{StatusCode: http.StatusOK, Proto: "HTTP/1.0"} var buf bytes.Buffer resp.Write(&buf) @@ -97,6 +97,8 @@ func (p *proxyServer) run(waitForServerHello bool) { bytesRead, err := out.Read(b) if err != nil { p.t.Errorf("Got error while reading server hello: %v", err) + in.Close() + out.Close() return } buf.Write(b[0:bytesRead]) @@ -167,30 +169,30 @@ func testHTTPConnect(t *testing.T, args testArgs) { defer cancel() c, err := proxyDial(ctx, blis.Addr().String(), "test") if err != nil { - t.Fatalf("http connect Dial failed: %v", err) + t.Fatalf("HTTP connect Dial failed: %v", err) } defer c.Close() + c.SetDeadline(time.Now().Add(defaultTestTimeout)) // Send msg on the connection. c.Write(msg) if err := <-done; err != nil { - t.Fatalf("failed to accept: %v", err) + t.Fatalf("Failed to accept: %v", err) } // Check received msg. if string(recvBuf) != string(msg) { - t.Fatalf("received msg: %v, want %v", recvBuf, msg) + t.Fatalf("Received msg: %v, want %v", recvBuf, msg) } if len(args.serverMessage) > 0 { - c.SetReadDeadline(time.Now().Add(defaultTestTimeout)) gotServerMessage := make([]byte, len(args.serverMessage)) if _, err := c.Read(gotServerMessage); err != nil { t.Errorf("Got error while reading message from server: %v", err) return } if string(gotServerMessage) != string(args.serverMessage) { - t.Errorf("message from server: %v, want %v", gotServerMessage, args.serverMessage) + t.Errorf("Message from server: %v, want %v", gotServerMessage, args.serverMessage) } } }