Skip to content

Commit

Permalink
Improve hlog handlers performance by switching to pointer logger
Browse files Browse the repository at this point in the history
Update request logger's context thru its pointer in order to avoid
multiple copies/allocations.
  • Loading branch information
rs committed Aug 30, 2017
1 parent 560e884 commit e26050b
Show file tree
Hide file tree
Showing 7 changed files with 163 additions and 48 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ matrix:
allow_failures:
- go: tip
script:
go test -v -race -cpu=1,2,4 ./...
go test -v -race -cpu=1,2,4 -bench . -benchmem ./...
15 changes: 12 additions & 3 deletions ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@ import (
"io/ioutil"
)

var disabledLogger = New(ioutil.Discard).Level(Disabled)
var disabledLogger *Logger

func init() {
l := New(ioutil.Discard).Level(Disabled)
disabledLogger = &l
}

type ctxKey struct{}

Expand All @@ -24,14 +29,18 @@ func (l Logger) WithContext(ctx context.Context) context.Context {
*lp = l
return ctx
}
if l.level == Disabled {
// Do not store disabled logger.
return ctx
}
return context.WithValue(ctx, ctxKey{}, &l)
}

// Ctx returns the Logger associated with the ctx. If no logger
// is associated, a disabled logger is returned.
func Ctx(ctx context.Context) Logger {
func Ctx(ctx context.Context) *Logger {
if l, ok := ctx.Value(ctxKey{}).(*Logger); ok {
return *l
return l
}
return disabledLogger
}
23 changes: 20 additions & 3 deletions ctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,37 @@ func TestCtx(t *testing.T) {
log := New(ioutil.Discard)
ctx := log.WithContext(context.Background())
log2 := Ctx(ctx)
if !reflect.DeepEqual(log, log2) {
if !reflect.DeepEqual(log, *log2) {
t.Error("Ctx did not return the expected logger")
}

// update
log = log.Level(InfoLevel)
ctx = log.WithContext(ctx)
log2 = Ctx(ctx)
if !reflect.DeepEqual(log, log2) {
if !reflect.DeepEqual(log, *log2) {
t.Error("Ctx did not return the expected logger")
}

log2 = Ctx(context.Background())
if !reflect.DeepEqual(log2, disabledLogger) {
if log2 != disabledLogger {
t.Error("Ctx did not return the expected logger")
}
}

func TestCtxDisabled(t *testing.T) {
ctx := disabledLogger.WithContext(context.Background())
if ctx != context.Background() {
t.Error("WithContext stored a disabled logger")
}

ctx = New(ioutil.Discard).WithContext(ctx)
if reflect.DeepEqual(Ctx(ctx), disabledLogger) {
t.Error("WithContext did not store logger")
}

ctx = disabledLogger.WithContext(ctx)
if !reflect.DeepEqual(Ctx(ctx), disabledLogger) {
t.Error("WithContext did not update logger pointer with disabled logger")
}
}
47 changes: 29 additions & 18 deletions hlog/hlog.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@ import (

// FromRequest gets the logger in the request's context.
// This is a shortcut for log.Ctx(r.Context())
func FromRequest(r *http.Request) zerolog.Logger {
func FromRequest(r *http.Request) *zerolog.Logger {
return log.Ctx(r.Context())
}

// NewHandler injects log into requests context.
func NewHandler(log zerolog.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r = r.WithContext(log.WithContext(r.Context()))
// Create a copy of the logger (including internal context slice)
// to prevent data race when using UpdateContext.
l := log.With().Logger()
r = r.WithContext(l.WithContext(r.Context()))
next.ServeHTTP(w, r)
})
}
Expand All @@ -35,8 +38,9 @@ func URLHandler(fieldKey string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log := zerolog.Ctx(r.Context())
log = log.With().Str(fieldKey, r.URL.String()).Logger()
r = r.WithContext(log.WithContext(r.Context()))
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str(fieldKey, r.URL.String())
})
next.ServeHTTP(w, r)
})
}
Expand All @@ -48,8 +52,9 @@ func MethodHandler(fieldKey string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log := zerolog.Ctx(r.Context())
log = log.With().Str(fieldKey, r.Method).Logger()
r = r.WithContext(log.WithContext(r.Context()))
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str(fieldKey, r.Method)
})
next.ServeHTTP(w, r)
})
}
Expand All @@ -61,8 +66,9 @@ func RequestHandler(fieldKey string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log := zerolog.Ctx(r.Context())
log = log.With().Str(fieldKey, r.Method+" "+r.URL.String()).Logger()
r = r.WithContext(log.WithContext(r.Context()))
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str(fieldKey, r.Method+" "+r.URL.String())
})
next.ServeHTTP(w, r)
})
}
Expand All @@ -75,8 +81,9 @@ func RemoteAddrHandler(fieldKey string) func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
log := zerolog.Ctx(r.Context())
log = log.With().Str(fieldKey, host).Logger()
r = r.WithContext(log.WithContext(r.Context()))
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str(fieldKey, host)
})
}
next.ServeHTTP(w, r)
})
Expand All @@ -90,8 +97,9 @@ func UserAgentHandler(fieldKey string) func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if ua := r.Header.Get("User-Agent"); ua != "" {
log := zerolog.Ctx(r.Context())
log = log.With().Str(fieldKey, ua).Logger()
r = r.WithContext(log.WithContext(r.Context()))
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str(fieldKey, ua)
})
}
next.ServeHTTP(w, r)
})
Expand All @@ -105,8 +113,9 @@ func RefererHandler(fieldKey string) func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if ref := r.Header.Get("Referer"); ref != "" {
log := zerolog.Ctx(r.Context())
log = log.With().Str(fieldKey, ref).Logger()
r = r.WithContext(log.WithContext(r.Context()))
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str(fieldKey, ref)
})
}
next.ServeHTTP(w, r)
})
Expand Down Expand Up @@ -136,16 +145,18 @@ func IDFromRequest(r *http.Request) (id xid.ID, ok bool) {
func RequestIDHandler(fieldKey, headerName string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
id, ok := IDFromRequest(r)
if !ok {
id = xid.New()
ctx := context.WithValue(r.Context(), idKey{}, id)
ctx = context.WithValue(ctx, idKey{}, id)
r = r.WithContext(ctx)
}
if fieldKey != "" {
log := zerolog.Ctx(r.Context())
log = log.With().Str(fieldKey, id.String()).Logger()
r = r.WithContext(log.WithContext(r.Context()))
log := zerolog.Ctx(ctx)
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str(fieldKey, id.String())
})
}
if headerName != "" {
w.Header().Set(headerName, id.String())
Expand Down
108 changes: 86 additions & 22 deletions hlog/hlog_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package hlog
import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"testing"
Expand All @@ -23,7 +24,7 @@ func TestNewHandler(t *testing.T) {
lh := NewHandler(log)
h := lh(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
l := FromRequest(r)
if !reflect.DeepEqual(l, log) {
if !reflect.DeepEqual(*l, log) {
t.Fail()
}
}))
Expand All @@ -38,12 +39,12 @@ func TestURLHandler(t *testing.T) {
h := URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
l := FromRequest(r)
l.Log().Msg("")
if want, got := `{"url":"/path?foo=bar"}`+"\n", out.String(); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}))
h = NewHandler(zerolog.New(out))(h)
h.ServeHTTP(nil, r)
if want, got := `{"url":"/path?foo=bar"}`+"\n", out.String(); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}

func TestMethodHandler(t *testing.T) {
Expand All @@ -54,12 +55,12 @@ func TestMethodHandler(t *testing.T) {
h := MethodHandler("method")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
l := FromRequest(r)
l.Log().Msg("")
if want, got := `{"method":"POST"}`+"\n", out.String(); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}))
h = NewHandler(zerolog.New(out))(h)
h.ServeHTTP(nil, r)
if want, got := `{"method":"POST"}`+"\n", out.String(); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}

func TestRequestHandler(t *testing.T) {
Expand All @@ -71,12 +72,12 @@ func TestRequestHandler(t *testing.T) {
h := RequestHandler("request")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
l := FromRequest(r)
l.Log().Msg("")
if want, got := `{"request":"POST /path?foo=bar"}`+"\n", out.String(); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}))
h = NewHandler(zerolog.New(out))(h)
h.ServeHTTP(nil, r)
if want, got := `{"request":"POST /path?foo=bar"}`+"\n", out.String(); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}

func TestRemoteAddrHandler(t *testing.T) {
Expand All @@ -87,12 +88,12 @@ func TestRemoteAddrHandler(t *testing.T) {
h := RemoteAddrHandler("ip")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
l := FromRequest(r)
l.Log().Msg("")
if want, got := `{"ip":"1.2.3.4"}`+"\n", out.String(); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}))
h = NewHandler(zerolog.New(out))(h)
h.ServeHTTP(nil, r)
if want, got := `{"ip":"1.2.3.4"}`+"\n", out.String(); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}

func TestRemoteAddrHandlerIPv6(t *testing.T) {
Expand All @@ -103,12 +104,12 @@ func TestRemoteAddrHandlerIPv6(t *testing.T) {
h := RemoteAddrHandler("ip")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
l := FromRequest(r)
l.Log().Msg("")
if want, got := `{"ip":"2001:db8:a0b:12f0::1"}`+"\n", out.String(); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}))
h = NewHandler(zerolog.New(out))(h)
h.ServeHTTP(nil, r)
if want, got := `{"ip":"2001:db8:a0b:12f0::1"}`+"\n", out.String(); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}

func TestUserAgentHandler(t *testing.T) {
Expand All @@ -121,12 +122,12 @@ func TestUserAgentHandler(t *testing.T) {
h := UserAgentHandler("ua")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
l := FromRequest(r)
l.Log().Msg("")
if want, got := `{"ua":"some user agent string"}`+"\n", out.String(); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}))
h = NewHandler(zerolog.New(out))(h)
h.ServeHTTP(nil, r)
if want, got := `{"ua":"some user agent string"}`+"\n", out.String(); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}

func TestRefererHandler(t *testing.T) {
Expand All @@ -139,12 +140,12 @@ func TestRefererHandler(t *testing.T) {
h := RefererHandler("referer")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
l := FromRequest(r)
l.Log().Msg("")
if want, got := `{"referer":"http://foo.com/bar"}`+"\n", out.String(); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}))
h = NewHandler(zerolog.New(out))(h)
h.ServeHTTP(nil, r)
if want, got := `{"referer":"http://foo.com/bar"}`+"\n", out.String(); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}

func TestRequestIDHandler(t *testing.T) {
Expand All @@ -171,3 +172,66 @@ func TestRequestIDHandler(t *testing.T) {
h = NewHandler(zerolog.New(out))(h)
h.ServeHTTP(httptest.NewRecorder(), r)
}

func TestCombinedHandlers(t *testing.T) {
out := &bytes.Buffer{}
r := &http.Request{
Method: "POST",
URL: &url.URL{Path: "/path", RawQuery: "foo=bar"},
}
h := MethodHandler("method")(RequestHandler("request")(URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
l := FromRequest(r)
l.Log().Msg("")
}))))
h = NewHandler(zerolog.New(out))(h)
h.ServeHTTP(nil, r)
if want, got := `{"method":"POST","request":"POST /path?foo=bar","url":"/path?foo=bar"}`+"\n", out.String(); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}

func BenchmarkHandlers(b *testing.B) {
r := &http.Request{
Method: "POST",
URL: &url.URL{Path: "/path", RawQuery: "foo=bar"},
}
h1 := URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
l := FromRequest(r)
l.Log().Msg("")
}))
h2 := MethodHandler("method")(RequestHandler("request")(h1))
handlers := map[string]http.Handler{
"Single": NewHandler(zerolog.New(ioutil.Discard))(h1),
"Combined": NewHandler(zerolog.New(ioutil.Discard))(h2),
"SingleDisabled": NewHandler(zerolog.New(ioutil.Discard).Level(zerolog.Disabled))(h1),
"CombinedDisabled": NewHandler(zerolog.New(ioutil.Discard).Level(zerolog.Disabled))(h2),
}
for name := range handlers {
h := handlers[name]
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
h.ServeHTTP(nil, r)
}
})
}
}

func BenchmarkDataRace(b *testing.B) {
log := zerolog.New(nil).With().
Str("foo", "bar").
Logger()
lh := NewHandler(log)
h := lh(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
l := FromRequest(r)
l.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str("bar", "baz")
})
l.Log().Msg("")
}))

b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
h.ServeHTTP(nil, &http.Request{})
}
})
}
Loading

0 comments on commit e26050b

Please sign in to comment.