Skip to content

Commit

Permalink
sessionrecording: implement v2 recording endpoint support
Browse files Browse the repository at this point in the history
The v2 endpoint supports HTTP/2 bidirectional streaming and acks for
received bytes. This is used to detect when a recorder disappears to
more quickly terminate the session.

Updates tailscale/corp#24023

Signed-off-by: Andrew Lytvynov <awly@tailscale.com>
  • Loading branch information
awly committed Nov 15, 2024
1 parent 8fd471c commit 144fbe1
Show file tree
Hide file tree
Showing 4 changed files with 415 additions and 60 deletions.
2 changes: 1 addition & 1 deletion k8s-operator/sessionrecording/hijacker.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ type Hijacker struct {
// connection succeeds. In case of success, returns a list with a single
// successful recording attempt and an error channel. If the connection errors
// after having been established, an error is sent down the channel.
type RecorderDialFn func(context.Context, []netip.AddrPort, func(context.Context, string, string) (net.Conn, error)) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error)
type RecorderDialFn func(context.Context, []netip.AddrPort, sessionrecording.DialFunc) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error)

// Hijack hijacks a 'kubectl exec' session and configures for the session
// contents to be sent to a recorder.
Expand Down
282 changes: 224 additions & 58 deletions sessionrecording/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,26 @@ package sessionrecording

import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptrace"
"net/netip"
"sync/atomic"
"time"

"golang.org/x/net/http2"
"tailscale.com/tailcfg"
"tailscale.com/util/multierr"
)

// DialFunc is a function for dialing the recorder. It is invoked with a
// context bounding the dial attempt, followed by the network and the address
// to connect to, and returns the established connection or an error.
type DialFunc func(context.Context, string, string) (net.Conn, error)

// ConnectToRecorder connects to the recorder at any of the provided addresses.
// It returns the first successful response, or a multierr if all attempts fail.
//
Expand All @@ -32,7 +39,7 @@ import (
// attempts are in the order the recorder(s) were attempted. If a
// successful connection is made, the last attempt in the slice is the
// attempt for the connected recorder.
func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial func(context.Context, string, string) (net.Conn, error)) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error) {
func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial DialFunc) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error) {
if len(recs) == 0 {
return nil, nil, nil, errors.New("no recorders configured")
}
Expand All @@ -41,10 +48,6 @@ func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial func(con
// unbounded context for the upload.
dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second)
defer dialCancel()
hc, err := SessionRecordingClientForDialer(dialCtx, dial)
if err != nil {
return nil, nil, nil, err
}

var errs []error
var attempts []*tailcfg.SSHRecordingAttempt
Expand All @@ -54,72 +57,210 @@ func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial func(con
}
attempts = append(attempts, attempt)

// We dial the recorder and wait for it to send a 100-continue
// response before returning from this function. This ensures that
// the recorder is ready to accept the recording.

// got100 is closed when we receive the 100-continue response.
got100 := make(chan struct{})
ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
Got100Continue: func() {
close(got100)
},
})

pr, pw := io.Pipe()
req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s:%d/record", ap.Addr(), ap.Port()), pr)
var pw io.WriteCloser
var errChan <-chan error
var err error
hc := clientHTTP1(dialCtx, dial)
// Note: we need to use the HTTP/1 client for probing the recorder
// because older recorders don't support HTTP/2.
if supportsV2(ctx, hc, ap) {
pw, errChan, err = connectV2(ctx, clientHTTP2(dialCtx, dial), ap)
} else {
pw, errChan, err = connectV1(ctx, hc, ap)
}
if err != nil {
err = fmt.Errorf("recording: error starting recording: %w", err)
err = fmt.Errorf("recording: error starting recording on %q: %w", ap, err)
attempt.FailureMessage = err.Error()
errs = append(errs, err)
continue
}
// We set the Expect header to 100-continue, so that the recorder
// will send a 100-continue response before it starts reading the
// request body.
req.Header.Set("Expect", "100-continue")
return pw, attempts, errChan, nil
}
return nil, attempts, nil, multierr.New(errs...)
}

// errChan is used to indicate the result of the request.
errChan := make(chan error, 1)
go func() {
resp, err := hc.Do(req)
if err != nil {
errChan <- fmt.Errorf("recording: error starting recording: %w", err)
// supportsV2 reports whether the recorder at ap serves the /v2/record
// endpoint. It probes the endpoint with a HEAD request sent through hc and
// treats only a 200 OK answer as support; a request that cannot be built or
// sent, or any other status code, yields false.
func supportsV2(ctx context.Context, hc *http.Client, ap netip.AddrPort) bool {
	probeURL := "http://" + ap.String() + "/v2/record"
	probe, err := http.NewRequestWithContext(ctx, http.MethodHead, probeURL, nil)
	if err != nil {
		return false
	}

	res, err := hc.Do(probe)
	if err != nil {
		return false
	}
	defer res.Body.Close()

	supported := res.StatusCode == http.StatusOK
	return supported
}

// connectV1 starts a recording upload against the legacy /record endpoint of
// the recorder at ap. It is kept for backwards-compatibility with older
// tsrecorder instances.
//
// The POST request is sent with "Expect: 100-continue" and connectV1 blocks
// until the recorder answers with a 100-continue response, so that a
// successful return means the recorder is ready to accept the recording.
//
// On success it returns the write end of the pipe feeding the request body
// and a channel describing the outcome of the request: an error is delivered
// if the request fails or ends with a non-200 status, and the channel is
// closed once the request has finished. If the request ends before the
// 100-continue response is seen, connectV1 returns that error (or an
// "unexpected EOF" error when the request finished without one).
func connectV1(ctx context.Context, hc *http.Client, ap netip.AddrPort) (io.WriteCloser, <-chan error, error) {
	// continued is closed as soon as the 100-continue response arrives.
	continued := make(chan struct{})
	trace := &httptrace.ClientTrace{
		Got100Continue: func() { close(continued) },
	}
	ctx = httptrace.WithClientTrace(ctx, trace)

	bodyR, bodyW := io.Pipe()
	req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s/record", ap), bodyR)
	if err != nil {
		return nil, nil, err
	}
	// Ask the recorder to confirm with 100-continue before it starts
	// reading the request body.
	req.Header.Set("Expect", "100-continue")

	// result carries the outcome of the request; it is buffered so that the
	// goroutine below never blocks on a receiver that went away.
	result := make(chan error, 1)
	go func() {
		defer close(result)
		resp, doErr := hc.Do(req)
		if doErr != nil {
			result <- doErr
			return
		}
		defer resp.Body.Close()
		if resp.StatusCode != http.StatusOK {
			result <- fmt.Errorf("recording: unexpected status: %v", resp.Status)
		}
	}()

	select {
	case <-continued:
		return bodyW, result, nil
	case reqErr := <-result:
		// The request ended before the 100-continue response, so the
		// caller needs to try another recorder.
		if reqErr != nil {
			return nil, nil, reqErr
		}
		// A nil value means the channel was closed after a 200 response,
		// which is unexpected as no data has been sent yet.
		return nil, nil, errors.New("recording: unexpected EOF")
	}
}

// connectV2 connects to the /v2/record endpoint on the recorder over HTTP/2.
// It explicitly tracks ack frames sent in the response and terminates the
// connection if sent recording data is un-acked for uploadAckWindow.
func connectV2(ctx context.Context, hc *http.Client, ap netip.AddrPort) (io.WriteCloser, <-chan error, error) {
pr, pw := io.Pipe()
upload := &readCounter{r: pr}
req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s/v2/record", ap), upload)
if err != nil {
return nil, nil, err
}

// With HTTP/2, hc.Do will not block while the request body is being sent.
// It will return immediately and allow us to consume the response body at
// the same time.
resp, err := hc.Do(req)
if err != nil {
return nil, nil, err
}
if resp.StatusCode != 200 {
resp.Body.Close()
return nil, nil, fmt.Errorf("recording: unexpected status: %v", resp.Status)
}

errChan := make(chan error, 1)
acks := make(chan int64)
// Read acks from the response and send them to the acks channel.
go func() {
defer close(errChan)
defer close(acks)
defer resp.Body.Close()
dec := json.NewDecoder(resp.Body)
for {
var frame ackFrame
if err := dec.Decode(&frame); err != nil {
if !errors.Is(err, io.EOF) {
errChan <- fmt.Errorf("recording: unexpected error receiving acks: %w", err)
}
return
}
if resp.StatusCode != 200 {
errChan <- fmt.Errorf("recording: unexpected status: %v", resp.Status)
select {
case acks <- frame.Ack:
case <-ctx.Done():
return
}
errChan <- nil
}()
select {
case <-got100:
case err := <-errChan:
// If we get an error before we get the 100-continue response,
// we need to try another recorder.
if err == nil {
// If the error is nil, we got a 200 response, which
// is unexpected as we haven't sent any data yet.
err = errors.New("recording: unexpected EOF")
}
}()
// Track acks from the acks channel.
go func() {
// Hack for tests: some tests modify uploadAckWindow and reset it when
// the test ends. This can race with t.Reset call below. Making a copy
// here is a lazy workaround to not wait for this goroutine to exit in
// the test cases.
window := uploadAckWindow
// This timer fires if we didn't receive an ack for too long.
t := time.NewTimer(window)
defer t.Stop()
for {
select {
case <-t.C:
// Close the pipe which terminates the connection and cleans up
// other goroutines. Note that tsrecorder will send us ack
// frames even if there is no new data to ack. This helps
// detect broken recorder connection if the session is idle.
pr.CloseWithError(errNoAcks)
resp.Body.Close()
return
case _, ok := <-acks:
if !ok {
// acks channel closed means that the goroutine reading them
// finished, which means that the request has ended.
return
}
// TODO(awly): limit how far behind the received acks can be. This
// should handle scenarios where a session suddenly dumps a lot of
// output.
t.Reset(window)
case <-ctx.Done():
return
}
attempt.FailureMessage = err.Error()
errs = append(errs, err)
continue // try the next recorder
}
return pw, attempts, errChan, nil
}
return nil, attempts, nil, multierr.New(errs...)
}()

return pw, errChan, nil
}

// SessionRecordingClientForDialer returns an http.Client that uses a clone of
// the provided Dialer's PeerTransport to dial connections. This is used to make
// requests to the session recording server to upload session recordings. It
// uses the provided dialCtx to dial connections, and limits a single dial to 5
// seconds.
func SessionRecordingClientForDialer(dialCtx context.Context, dial func(context.Context, string, string) (net.Conn, error)) (*http.Client, error) {
tr := http.DefaultTransport.(*http.Transport).Clone()
// uploadAckWindow is the period of time to wait for an ackFrame from recorder
// before terminating the connection. This is a variable to allow overriding it
// in tests.
var uploadAckWindow = 30 * time.Second

// errNoAcks is the error the upload pipe is closed with when no ack frame
// arrives within the ack window. Note that the message text states a fixed
// "30s", which matches the default value of uploadAckWindow but does not
// follow it if that variable is overridden.
var errNoAcks = errors.New("did not receive ack frames from the recorder in 30s")

// ackFrame is a single JSON message decoded from the response body of a
// /v2/record request. The recorder sends a stream of these frames to
// acknowledge the recording data it has received.
type ackFrame struct {
// Ack is the acknowledgement value carried in the "ack" JSON field.
// NOTE(review): presumably the number of bytes received so far — confirm
// against the recorder implementation.
Ack int64 `json:"ack"`
}

// readCounter wraps an io.Reader and keeps a running total of the bytes that
// have been read through it.
type readCounter struct {
	// r is the underlying reader that all reads are forwarded to.
	r io.Reader
	// sent is the total number of bytes returned by Read so far. It is
	// atomic so it can be loaded while another goroutine is reading.
	sent atomic.Int64
}

// Read forwards to the wrapped reader and adds the number of bytes it
// returned to the running total, regardless of whether an error was also
// returned.
func (u *readCounter) Read(buf []byte) (n int, err error) {
	n, err = u.r.Read(buf)
	u.sent.Add(int64(n))
	return n, err
}

// clientHTTP1 returns a classic http.Client with a per-dial context. It uses
// dialCtx and adds a 5s timeout to it.
func clientHTTP1(dialCtx context.Context, dial DialFunc) *http.Client {
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
perAttemptCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
Expand All @@ -132,7 +273,32 @@ func SessionRecordingClientForDialer(dialCtx context.Context, dial func(context.
}()
return dial(perAttemptCtx, network, addr)
}
return &http.Client{Transport: tr}
}

// clientHTTP2 is like clientHTTP1 but returns an http.Client suitable for h2c
// requests (HTTP/2 over plaintext). Unfortunately the same client does not
// work for HTTP/1 so we need to split these up.
func clientHTTP2(dialCtx context.Context, dial DialFunc) *http.Client {
return &http.Client{
Transport: tr,
}, nil
Transport: &http2.Transport{
// Allow "http://" scheme in URLs.
AllowHTTP: true,
// Pretend like we're using TLS, but actually use the provided
// DialFunc underneath. This is necessary to convince the transport
// to actually dial.
DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
perAttemptCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
go func() {
select {
case <-perAttemptCtx.Done():
case <-dialCtx.Done():
cancel()
}
}()
return dial(perAttemptCtx, network, addr)
},
},
}
}
Loading

0 comments on commit 144fbe1

Please sign in to comment.