Skip to content

Commit

Permalink
feat(context): add ContextWithFallback feature flag (gin-gonic#3166) (g…
Browse files Browse the repository at this point in the history
…in-gonic#3172)

Enable fallback Context.Deadline(), Context.Done(), Context.Err() and Context.Value()
  • Loading branch information
wei840222 authored Jun 6, 2022
1 parent 92ba8e1 commit f197a8b
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 16 deletions.
8 changes: 4 additions & 4 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -1158,23 +1158,23 @@ func (c *Context) SetAccepted(formats ...string) {

// Deadline returns that there is no deadline (ok==false) when c.Request has no Context.
func (c *Context) Deadline() (deadline time.Time, ok bool) {
if c.Request == nil || c.Request.Context() == nil {
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
return
}
return c.Request.Context().Deadline()
}

// Done returns nil (chan which will wait forever) when c.Request has no Context.
func (c *Context) Done() <-chan struct{} {
if c.Request == nil || c.Request.Context() == nil {
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
return nil
}
return c.Request.Context().Done()
}

// Err returns nil when c.Request has no Context.
func (c *Context) Err() error {
if c.Request == nil || c.Request.Context() == nil {
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
return nil
}
return c.Request.Context().Err()
Expand All @@ -1195,7 +1195,7 @@ func (c *Context) Value(key any) any {
return val
}
}
if c.Request == nil || c.Request.Context() == nil {
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
return nil
}
return c.Request.Context().Value(key)
Expand Down
115 changes: 103 additions & 12 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2097,12 +2097,18 @@ func TestRemoteIPFail(t *testing.T) {
}

func TestContextWithFallbackDeadlineFromRequestContext(t *testing.T) {
c := &Context{}
c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true

deadline, ok := c.Deadline()
assert.Zero(t, deadline)
assert.False(t, ok)

c2 := &Context{}
c2, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c2.engine.ContextWithFallback = true

c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
d := time.Now().Add(time.Second)
ctx, cancel := context.WithDeadline(context.Background(), d)
Expand All @@ -2114,10 +2120,16 @@ func TestContextWithFallbackDeadlineFromRequestContext(t *testing.T) {
}

func TestContextWithFallbackDoneFromRequestContext(t *testing.T) {
c := &Context{}
c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true

assert.Nil(t, c.Done())

c2 := &Context{}
c2, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c2.engine.ContextWithFallback = true

c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
ctx, cancel := context.WithCancel(context.Background())
c2.Request = c2.Request.WithContext(ctx)
Expand All @@ -2126,10 +2138,16 @@ func TestContextWithFallbackDoneFromRequestContext(t *testing.T) {
}

func TestContextWithFallbackErrFromRequestContext(t *testing.T) {
c := &Context{}
c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true

assert.Nil(t, c.Err())

c2 := &Context{}
c2, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c2.engine.ContextWithFallback = true

c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
ctx, cancel := context.WithCancel(context.Background())
c2.Request = c2.Request.WithContext(ctx)
Expand All @@ -2138,9 +2156,9 @@ func TestContextWithFallbackErrFromRequestContext(t *testing.T) {
assert.EqualError(t, c2.Err(), context.Canceled.Error())
}

type contextKey string

func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
type contextKey string

tests := []struct {
name string
getContextAndKey func() (*Context, any)
Expand All @@ -2150,7 +2168,9 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
name: "c with struct context key",
getContextAndKey: func() (*Context, any) {
var key struct{}
c := &Context{}
c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true
c.Request, _ = http.NewRequest("POST", "/", nil)
c.Request = c.Request.WithContext(context.WithValue(context.TODO(), key, "value"))
return c, key
Expand All @@ -2160,7 +2180,9 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
{
name: "c with string context key",
getContextAndKey: func() (*Context, any) {
c := &Context{}
c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true
c.Request, _ = http.NewRequest("POST", "/", nil)
c.Request = c.Request.WithContext(context.WithValue(context.TODO(), contextKey("key"), "value"))
return c, contextKey("key")
Expand All @@ -2170,15 +2192,20 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
{
name: "c with nil http.Request",
getContextAndKey: func() (*Context, any) {
c := &Context{}
c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true
c.Request = nil
return c, "key"
},
value: nil,
},
{
name: "c with nil http.Request.Context()",
getContextAndKey: func() (*Context, any) {
c := &Context{}
c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true
c.Request, _ = http.NewRequest("POST", "/", nil)
return c, "key"
},
Expand All @@ -2193,6 +2220,70 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
}
}

func TestContextCopyShouldNotCancel(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()

ensureRequestIsOver := make(chan struct{})

wg := &sync.WaitGroup{}

r := New()
r.GET("/", func(ginctx *Context) {
wg.Add(1)

ginctx = ginctx.Copy()

// start async goroutine for calling srv
go func() {
defer wg.Done()

<-ensureRequestIsOver // ensure request is done

req, err := http.NewRequestWithContext(ginctx, http.MethodGet, srv.URL, nil)
must(err)

res, err := http.DefaultClient.Do(req)
if err != nil {
t.Error(fmt.Errorf("request error: %w", err))
return
}

if res.StatusCode != http.StatusOK {
t.Error(fmt.Errorf("unexpected status code: %s", res.Status))
}
}()
})

l, err := net.Listen("tcp", ":0")
must(err)
go func() {
s := &http.Server{
Handler: r,
}

must(s.Serve(l))
}()

addr := strings.Split(l.Addr().String(), ":")
res, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/", addr[len(addr)-1]))
if err != nil {
t.Error(fmt.Errorf("request error: %w", err))
return
}

close(ensureRequestIsOver)

if res.StatusCode != http.StatusOK {
t.Error(fmt.Errorf("unexpected status code: %s", res.Status))
return
}

wg.Wait()
}

func TestContextAddParam(t *testing.T) {
c := &Context{}
id := "id"
Expand Down
3 changes: 3 additions & 0 deletions gin.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ type Engine struct {
// UseH2C enable h2c support.
UseH2C bool

// ContextWithFallback enable fallback Context.Deadline(), Context.Done(), Context.Err() and Context.Value() when Context.Request.Context() is not nil.
ContextWithFallback bool

delims render.Delims
secureJSONPrefix string
HTMLRender render.HTMLRender
Expand Down

0 comments on commit f197a8b

Please sign in to comment.