Skip to content

Commit

Permalink
add support for api gateway v2 requests and responses. update to g 1.…
Browse files Browse the repository at this point in the history
…21 for slog adapter
  • Loading branch information
iamatypeofwalrus committed Apr 6, 2024
1 parent 27a1d5b commit 4525257
Show file tree
Hide file tree
Showing 13 changed files with 398 additions and 50 deletions.
4 changes: 2 additions & 2 deletions http_request.go → api_gateway_proxy_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ var (
errCouldNotCreateHTTPRequest = errors.New("encountered error while create http request")
)

// NewHTTPRequest creates an *http.Request from a context.Context and an events.APIGatewayProxyRequest
func NewHTTPRequest(ctx context.Context, event events.APIGatewayProxyRequest) (*http.Request, error) {
// NewHttpRequestFromAPIGatewayProxyRequest creates an *http.Request from a context.Context and an events.APIGatewayProxyRequest
func NewHttpRequestFromAPIGatewayProxyRequest(ctx context.Context, event events.APIGatewayProxyRequest) (*http.Request, error) {
u, err := url.Parse(event.Path)
if err != nil {
return nil, errCouldNotParsePath
Expand Down
8 changes: 4 additions & 4 deletions http_request_test.go → api_gateway_proxy_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func TestNewHTTPRequestPassesQueryStrings(t *testing.T) {
QueryStringParameters: qp,
}

req, err := NewHTTPRequest(context.TODO(), event)
req, err := NewHttpRequestFromAPIGatewayProxyRequest(context.TODO(), event)
if err != nil {
t.Errorf("expected error from New HTTPRequest to be nil")
t.Logf("err: %s\n", err)
Expand Down Expand Up @@ -58,7 +58,7 @@ func TestNewHTTPRequestDecodesBase64Bodies(t *testing.T) {
IsBase64Encoded: true,
}

req, err := NewHTTPRequest(context.TODO(), event)
req, err := NewHttpRequestFromAPIGatewayProxyRequest(context.TODO(), event)
if err != nil {
t.Fatal("execpted error from NewHTTPRequest to be nil but was ", err)
}
Expand Down Expand Up @@ -87,7 +87,7 @@ func TestNewHTTPRequestPassesHeaders(t *testing.T) {
},
}

req, err := NewHTTPRequest(context.TODO(), event)
req, err := NewHttpRequestFromAPIGatewayProxyRequest(context.TODO(), event)
if err != nil {
t.Fatal("exepected error from NewHTTPRequest to be nil but was ", err)
}
Expand All @@ -105,7 +105,7 @@ func TestNewHTTPRequestSetsContentLength(t *testing.T) {
Body: body,
}

req, err := NewHTTPRequest(context.TODO(), event)
req, err := NewHttpRequestFromAPIGatewayProxyRequest(context.TODO(), event)
if err != nil {
t.Fatal("expected error from NewHTTPRequest to be nil but was ", err)
}
Expand Down
12 changes: 3 additions & 9 deletions api_gateway_proxy_response.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
)

const (
contentType = "Content-Type"
httpHeaderContentType = "Content-Type"
multipleValueSeperator = ","
prefixText = "text/"
)
Expand All @@ -31,10 +31,10 @@ func NewAPIGatewayProxyResponse(rw *ResponseWriter) events.APIGatewayProxyRespon
}

headers := formatHeaders(rw.Headers)
setDefaultContentType(headers, rw.Body.Bytes())
headers[httpHeaderContentType] = http.DetectContentType(rw.Body.Bytes())
resp.Headers = headers

if shouldConvertToBase64(resp.Headers[contentType]) {
if shouldConvertToBase64(resp.Headers[httpHeaderContentType]) {
resp.Body = base64.StdEncoding.EncodeToString(rw.Body.Bytes())
resp.IsBase64Encoded = true
} else {
Expand Down Expand Up @@ -67,12 +67,6 @@ func formatHeaders(h http.Header) map[string]string {
return headers
}

func setDefaultContentType(lambdaHeaders map[string]string, body []byte) {
if _, ok := lambdaHeaders[contentType]; !ok {
lambdaHeaders[contentType] = http.DetectContentType(body)
}
}

func shouldConvertToBase64(ct string) bool {
mimeType, _, err := mime.ParseMediaType(ct)
if err != nil {
Expand Down
18 changes: 14 additions & 4 deletions api_gateway_proxy_response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,15 @@ func TestFormatHeadersHandlesMultipleValusForAKey(t *testing.T) {
}

func TestNewAPIGatewayProxyResponseConvertsNonTextResponsesToBase64(t *testing.T) {
body := "hello, world"
input := "hello, world"
body, err := gzipString(input)
if err != nil {
t.Fatalf("unable to gzip string: %v", err)
}

rw := NewResponseWriter()
rw.Write([]byte(body))
rw.Headers.Set(contentType, "application/octet-stream")
rw.Headers.Set(httpHeaderContentType, "application/octet-stream")
resp := NewAPIGatewayProxyResponse(rw)

if !resp.IsBase64Encoded {
Expand All @@ -68,8 +73,13 @@ func TestNewAPIGatewayProxyResponseConvertsNonTextResponsesToBase64(t *testing.T
t.Fatal("expected error from base64 decode to be nil but was", err)
}

if string(decodedBody) != body {
t.Errorf("expected decodedBody to be %v but was %v", body, string(decodedBody))
decodedGunzippedInput, err := gunzipBytes(decodedBody)
if err != nil {
t.Fatal("expected error from gunzip to be nil but was", err)
}

if input != decodedGunzippedInput {
t.Errorf("expected decodedBody to be %v but was %v", input, decodedGunzippedInput)
}
}

Expand Down
62 changes: 62 additions & 0 deletions api_gateway_v2_http_request.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package shim

import (
"context"
"encoding/base64"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"

"github.com/aws/aws-lambda-go/events"
)

// NewHttpRequestFromAPIGatewayV2HTTPRequest creates an *http.Request from the context passed from the Lambda library, and the event itself.
func NewHttpRequestFromAPIGatewayV2HTTPRequest(ctx context.Context, event events.APIGatewayV2HTTPRequest) (*http.Request, error) {
u, err := url.Parse(event.RawPath)
if err != nil {
return nil, fmt.Errorf("shim could not parse path from event: %w", err)
}

if event.RawQueryString != "" {
u.RawQuery = event.RawQueryString
}

body := event.Body
if event.IsBase64Encoded {
d, err := base64.StdEncoding.DecodeString(body)
if err != nil {
return nil, fmt.Errorf("shim encountered an error while base64 decoding request body: %w", err)
}

body = string(d)
}

req, err := http.NewRequest(
event.RequestContext.HTTP.Method,
u.String(),
strings.NewReader(body),
)

if err != nil {
return nil, fmt.Errorf("shim could not create http request from event: %w", err)
}

for h, v := range event.Headers {
req.Header.Set(h, v)
}

req.URL.Host = req.Header.Get("Host")
req.Host = req.Header.Get("Host")

req.RemoteAddr = event.RequestContext.HTTP.SourceIP

if req.Header.Get(contentLength) == "" && body != "" {
req.Header.Set(contentLength, strconv.Itoa(len(body)))
}

req = req.WithContext(ctx)

return req, nil
}
87 changes: 87 additions & 0 deletions api_gateway_v2_http_request_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package shim

import (
"context"
"encoding/base64"
"io/ioutil"
"net/http"
"strings"
"testing"

"github.com/aws/aws-lambda-go/events"
)

func TestNewHttpRequestFromAPIGatewayV2HTTPRequest(t *testing.T) {
ctx := context.Background()

// Create a sample APIGatewayV2HTTPRequest event
event := events.APIGatewayV2HTTPRequest{
Version: "2.0",
RouteKey: "GET /hello",
RawPath: "/hello",
RawQueryString: "name=John&age=30",
Headers: map[string]string{"Content-Type": "application/json"},
QueryStringParameters: map[string]string{"name": "John", "age": "30"},
}

// Call the function under test
req, err := NewHttpRequestFromAPIGatewayV2HTTPRequest(ctx, event)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

// Verify the created request
if req.Method != http.MethodGet {
t.Errorf("expected GET method, got %s", req.Method)
}

if req.URL.Path != "/hello" {
t.Errorf("expected path '/hello', got %s", req.URL.Path)
}

if req.URL.RawQuery != "name=John&age=30" {
t.Errorf("expected query string 'name=John&age=30', got %s", req.URL.RawQuery)
}

if req.Header.Get("Content-Type") != "application/json" {
t.Errorf("expected 'Content-Type' header to be 'application/json', got %s", req.Header.Get("Content-Type"))
}

if req.Context() != ctx {
t.Error("expected context to be the same")
}
}

func TestNewHttpRequestFromAPIGatewayV2HTTPRequest_Base64(t *testing.T) {
ctx := context.Background()

// Create a sample APIGatewayV2HTTPRequest event with base64 encoded body
body := "Hello, World!"
encodedBody := base64.StdEncoding.EncodeToString([]byte(body))
event := events.APIGatewayV2HTTPRequest{
Version: "2.0",
RouteKey: "GET /hello",
RawPath: "/hello",
RawQueryString: "name=John&age=30",
Headers: map[string]string{"Content-Type": "application/json"},
QueryStringParameters: map[string]string{"name": "John", "age": "30"},
Body: encodedBody,
IsBase64Encoded: true,
}

// Call the function under test
req, err := NewHttpRequestFromAPIGatewayV2HTTPRequest(ctx, event)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

// Verify the created request
decodedBody, err := ioutil.ReadAll(req.Body)
if err != nil {
t.Errorf("unexpected error when reading request body: %v", err)
}

if strings.TrimSpace(string(decodedBody)) != body {
t.Errorf("expected body '%s', got '%s'", body, string(decodedBody))
}
}
28 changes: 28 additions & 0 deletions api_gateway_v2_http_response.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package shim

import (
"encoding/base64"
"net/http"

"github.com/aws/aws-lambda-go/events"
)

func NewApiGatewayV2HttpResponse(rw *ResponseWriter) events.APIGatewayV2HTTPResponse {
resp := events.APIGatewayV2HTTPResponse{
StatusCode: rw.Code,
}

headers := rw.Headers
headers[httpHeaderContentType] = []string{http.DetectContentType(rw.Body.Bytes())}

resp.MultiValueHeaders = headers

if shouldConvertToBase64(rw.Headers.Get(httpHeaderContentType)) {
resp.Body = base64.StdEncoding.EncodeToString(rw.Body.Bytes())
resp.IsBase64Encoded = true
} else {
resp.Body = string(rw.Body.String())
}

return resp
}
89 changes: 89 additions & 0 deletions api_gateway_v2_http_response_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package shim

import (
"bytes"
"compress/gzip"
"encoding/base64"
"fmt"
"io"
"net/http"
"testing"
)

func TestNewApiGatewayV2HttpResponse_Base64(t *testing.T) {
body, err := gzipString("hello, world")
if err != nil {
t.Fatalf("unable to gzip string: %v", err)
}

rw := NewResponseWriter()
rw.Write([]byte(body))
rw.Headers.Set(httpHeaderContentType, "application/octet-stream")
resp := NewApiGatewayV2HttpResponse(rw)

if resp.StatusCode != rw.Code {
t.Errorf("expected status code %d, got %d", rw.Code, resp.StatusCode)
}

if resp.Body != base64.StdEncoding.EncodeToString(rw.Body.Bytes()) {
t.Errorf("expected body to be base64 encoded")
}

if !resp.IsBase64Encoded {
t.Errorf("expected IsBase64Encoded to be true")
}
}

func TestNewApiGatewayV2HttpResponse_NoBase64(t *testing.T) {
rw := &ResponseWriter{
Code: http.StatusOK,
Headers: http.Header{},
Body: *bytes.NewBufferString("Hello, World!"),
}
rw.Headers.Set(httpHeaderContentType, "text/plain")

resp := NewApiGatewayV2HttpResponse(rw)

if resp.StatusCode != rw.Code {
t.Errorf("expected status code %d, got %d", rw.Code, resp.StatusCode)
}

if resp.Body != rw.Body.String() {
t.Errorf("expected body to be '%s', got '%s'", rw.Body.String(), resp.Body)
}

if resp.IsBase64Encoded {
t.Errorf("expected IsBase64Encoded to be false")
}
}

func gzipString(input string) ([]byte, error) {
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)

if _, err := gz.Write([]byte(input)); err != nil {
return nil, fmt.Errorf("unable to write to gzip writer: %w", err)
}

if err := gz.Close(); err != nil {
return nil, fmt.Errorf("unable to close gzip writer: %w", err)
}

return buf.Bytes(), nil
}

func gunzipBytes(input []byte) (string, error) {
buf := bytes.NewBuffer(input)
gz, err := gzip.NewReader(buf)
if err != nil {
return "", fmt.Errorf("unable to create gzip reader: %w", err)
}
defer gz.Close()

res, err := io.ReadAll(gz)
if err != nil {
return "", fmt.Errorf("unable to read from gzip reader: %w", err)
}

return string(res), nil
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
module github.com/iamatypeofwalrus/shim

go 1.17
go 1.21

require github.com/aws/aws-lambda-go v1.46.0
Loading

0 comments on commit 4525257

Please sign in to comment.