Skip to content

Commit

Permalink
Add retries for 502, 504 HTTP statuses
Browse files Browse the repository at this point in the history
Add 502 Bad Gateway and 504 Gateway Timeout HTTP status to client's retry logic

Add tests for retries on 502,504 and test that no retry is done on another status (404)
  • Loading branch information
springjd authored and nineinchnick committed Jan 7, 2025
1 parent 70bd4d7 commit 3360780
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 23 deletions.
2 changes: 1 addition & 1 deletion trino/trino.go
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ func (c *Conn) roundTrip(ctx context.Context, req *http.Request) (*http.Response
}
}
return resp, nil
case http.StatusServiceUnavailable:
case http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
resp.Body.Close()
timer.Reset(delay)
delay = time.Duration(math.Min(
Expand Down
74 changes: 52 additions & 22 deletions trino/trino_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,32 +266,62 @@ func TestRegisterCustomClientReserved(t *testing.T) {
}

func TestRoundTripRetryQueryError(t *testing.T) {
count := 0
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if count == 0 {
count++
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(&stmtResponse{
Error: ErrTrino{
ErrorName: "TEST",
},
})
}))
testcases := []struct {
Name string
HttpStatus int
ExpectedErrorStatus string
}{
{
Name: "Test retry 502 Bad Gateway",
HttpStatus: http.StatusBadGateway,
ExpectedErrorStatus: "200 OK",
},
{
Name: "Test retry 503 Service Unavailable",
HttpStatus: http.StatusServiceUnavailable,
ExpectedErrorStatus: "200 OK",
},
{
Name: "Test retry 504 Gateway Timeout",
HttpStatus: http.StatusGatewayTimeout,
ExpectedErrorStatus: "200 OK",
},
{
Name: "Test no retry 404 Not Found",
HttpStatus: http.StatusNotFound,
ExpectedErrorStatus: "404 Not Found",
},
}
for _, tc := range testcases {
t.Run(tc.Name, func(t *testing.T) {
count := 0
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if count == 0 {
count++
w.WriteHeader(tc.HttpStatus)
return
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(&stmtResponse{
Error: ErrTrino{
ErrorName: "TEST",
},
})
}))

t.Cleanup(ts.Close)
t.Cleanup(ts.Close)

db, err := sql.Open("trino", ts.URL)
require.NoError(t, err)
db, err := sql.Open("trino", ts.URL)
require.NoError(t, err)

t.Cleanup(func() {
assert.NoError(t, db.Close())
})
t.Cleanup(func() {
assert.NoError(t, db.Close())
})

_, err = db.Query("SELECT 1")
assert.IsTypef(t, new(ErrQueryFailed), err, "unexpected error: %w", err)
_, err = db.Query("SELECT 1")
assert.ErrorContains(t, err, tc.ExpectedErrorStatus, "unexpected error: %w", err)
})
}
}

func TestRoundTripBogusData(t *testing.T) {
Expand Down

0 comments on commit 3360780

Please sign in to comment.