Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transport: Discard the buffer when empty after http connect handshake #7424

Merged
merged 4 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions internal/transport/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,14 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr stri
}
return nil, fmt.Errorf("failed to do connect handshake, response: %q", dump)
}

return &bufConn{Conn: conn, r: r}, nil
// The buffer could contain extra bytes from the target server, so we can't
// discard it. However, in many cases where the server waits for the client
// to send the first message (e.g. when TLS is being used), the buffer will
// be empty, so we can avoid the overhead of reading through this buffer.
if r.Buffered() != 0 {
return &bufConn{Conn: conn, r: r}, nil
}
return conn, nil
}

// proxyDial dials, connecting to a proxy first if necessary. Checks if a proxy
Expand Down
82 changes: 68 additions & 14 deletions internal/transport/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ package transport

import (
"bufio"
"bytes"
"context"
"encoding/base64"
"fmt"
Expand Down Expand Up @@ -58,7 +59,7 @@ type proxyServer struct {
requestCheck func(*http.Request) error
}

func (p *proxyServer) run() {
func (p *proxyServer) run(waitForServerHello bool) {
in, err := p.lis.Accept()
if err != nil {
return
Expand All @@ -83,8 +84,24 @@ func (p *proxyServer) run() {
p.t.Errorf("failed to dial to server: %v", err)
return
}
out.SetReadDeadline(time.Now().Add(defaultTestTimeout))
resp := http.Response{StatusCode: http.StatusOK, Proto: "HTTP/1.0"}
resp.Write(p.in)
var buf bytes.Buffer
resp.Write(&buf)
if waitForServerHello {
// Batch the first message from the server with the http connect
// response. This is done to test the cases in which the grpc client has
// the response to the connect request and proxied packets from the
// destination server when it reads the transport.
b := make([]byte, 50)
bytesRead, err := out.Read(b)
if err != nil {
p.t.Errorf("Got error while reading server hello: %v", err)
return
}
buf.Write(b[0:bytesRead])
}
p.in.Write(buf.Bytes())
p.out = out
go io.Copy(p.in, p.out)
go io.Copy(p.out, p.in)
Expand All @@ -100,17 +117,23 @@ func (p *proxyServer) stop() {
}
}

func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxyReqCheck func(*http.Request) error) {
type testArgs struct {
proxyURLModify func(*url.URL) *url.URL
proxyReqCheck func(*http.Request) error
serverMessage []byte
}

func testHTTPConnect(t *testing.T, args testArgs) {
plis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("failed to listen: %v", err)
}
p := &proxyServer{
t: t,
lis: plis,
requestCheck: proxyReqCheck,
requestCheck: args.proxyReqCheck,
}
go p.run()
go p.run(len(args.serverMessage) > 0)
defer p.stop()

blis, err := net.Listen("tcp", "localhost:0")
Expand All @@ -128,13 +151,14 @@ func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxy
return
}
defer in.Close()
in.Write(args.serverMessage)
in.Read(recvBuf)
done <- nil
}()

// Overwrite the function in the test and restore them in defer.
hpfe := func(req *http.Request) (*url.URL, error) {
return proxyURLModify(&url.URL{Host: plis.Addr().String()}), nil
return args.proxyURLModify(&url.URL{Host: plis.Addr().String()}), nil
}
defer overwrite(hpfe)()

Expand All @@ -157,33 +181,62 @@ func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxy
if string(recvBuf) != string(msg) {
t.Fatalf("received msg: %v, want %v", recvBuf, msg)
}

if len(args.serverMessage) > 0 {
c.SetReadDeadline(time.Now().Add(defaultTestTimeout))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's now 3 places where we're setting this deadline.

Maybe instead when the conn is created in proxyDial we should just do c.SetDeadline(....defaultTestTimeout) (note Deadline not Read to also abort writes)?

Copy link
Contributor Author

@arjan-bal arjan-bal Jul 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

During testing, I noticed that the test can hang in proxyDial if the proxy server waits for a server hello when the server doesn't intend to send one. This was due to a logical error on my part which I fixed.

To be safe, I added read deadlines in both the places where reads could hang.

Maybe instead when the conn is created in proxyDial we should just do c.SetDeadline(....defaultTestTimeout) (note Deadline not Read to also abort writes)?

Do you mean adding a deadline in the actual proxyDial() implementation by introducing a timeout that only the tests use?
Or are you referring to adding a deadline on the conn returned by proxyDial in the test code?

The latter would not prevent proxyDial from hanging when the proxy server is stuck reading the server hello. The former would work but it involves changing non-test code, which I was trying to avoid.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I didn't mean in proxyDial... I meant the connection it returns, we could immediately, always, do a SetDeadline(...), instead of doing it later and conditionally.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed SetReadDeadline to SetDeadline. Moved the SetDeadline immediately after the proxyDial call. The SetDeadline in proxyServer.run() is still present to avoid proxyDial from hanging indefinitely in case of test failures.

gotServerMessage := make([]byte, len(args.serverMessage))
if _, err := c.Read(gotServerMessage); err != nil {
t.Errorf("Got error while reading message from server: %v", err)
return
}
if string(gotServerMessage) != string(args.serverMessage) {
t.Errorf("message from server: %v, want %v", gotServerMessage, args.serverMessage)
}
}
}

func (s) TestHTTPConnect(t *testing.T) {
testHTTPConnect(t,
func(in *url.URL) *url.URL {
args := testArgs{
proxyURLModify: func(in *url.URL) *url.URL {
return in
},
func(req *http.Request) error {
proxyReqCheck: func(req *http.Request) error {
if req.Method != http.MethodConnect {
return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
}
return nil
},
)
}
testHTTPConnect(t, args)
}

func (s) TestHTTPConnectWithServerHello(t *testing.T) {
args := testArgs{
proxyURLModify: func(in *url.URL) *url.URL {
return in
},
proxyReqCheck: func(req *http.Request) error {
if req.Method != http.MethodConnect {
return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
}
return nil
},
serverMessage: []byte("server-hello"),
}
testHTTPConnect(t, args)
}

func (s) TestHTTPConnectBasicAuth(t *testing.T) {
const (
user = "notAUser"
password = "notAPassword"
)
testHTTPConnect(t,
func(in *url.URL) *url.URL {
args := testArgs{
proxyURLModify: func(in *url.URL) *url.URL {
in.User = url.UserPassword(user, password)
return in
},
func(req *http.Request) error {
proxyReqCheck: func(req *http.Request) error {
if req.Method != http.MethodConnect {
return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
}
Expand All @@ -195,7 +248,8 @@ func (s) TestHTTPConnectBasicAuth(t *testing.T) {
}
return nil
},
)
}
testHTTPConnect(t, args)
}

func (s) TestMapAddressEnv(t *testing.T) {
Expand Down