Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: Add logging for relevant Amazon headers useful for debugging #2390

Merged
merged 4 commits into from
May 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions integration/combination/test_api_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
except ImportError:
from pathlib2 import Path

import requests
from parameterized import parameterized

from integration.helpers.base_test import BaseTest
Expand Down Expand Up @@ -158,7 +157,7 @@ def test_implicit_api_settings(self):

def verify_binary_media_request(self, url, expected_status_code):
headers = {"accept": "image/png"}
response = requests.get(url, headers=headers)
response = BaseTest.do_get_request_with_logging(url, headers)

status = response.status_code
expected_file_path = str(Path(self.code_dir, "AWS_logo_RGB.png"))
Expand Down
10 changes: 2 additions & 8 deletions integration/combination/test_api_with_authorizer_apikey.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
from unittest.case import skipIf

import requests

from integration.helpers.base_test import BaseTest
from integration.helpers.deployer.utils.retry import retry
from integration.helpers.exception import StatusCodeError
from integration.helpers.resource import current_region_does_not_support
from integration.config.service_names import COGNITO


class TestApiWithAuthorizerApiKey(BaseTest):
Expand Down Expand Up @@ -82,10 +76,10 @@ def verify_authorized_request(
header_value=None,
):
if not header_key or not header_value:
response = requests.get(url)
response = BaseTest.do_get_request_with_logging(url)
else:
headers = {header_key: header_value}
response = requests.get(url, headers=headers)
response = BaseTest.do_get_request_with_logging(url, headers)
status = response.status_code
if status != expected_status_code:
raise StatusCodeError(
Expand Down
7 changes: 3 additions & 4 deletions integration/combination/test_api_with_authorizers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from unittest.case import skipIf

import requests

from integration.helpers.base_test import BaseTest
from integration.helpers.deployer.utils.retry import retry
from integration.helpers.exception import StatusCodeError
Expand Down Expand Up @@ -436,11 +434,12 @@ def verify_authorized_request(
header_value=None,
):
if not header_key or not header_value:
response = requests.get(url)
response = BaseTest.do_get_request_with_logging(url)
else:
headers = {header_key: header_value}
response = requests.get(url, headers=headers)
response = BaseTest.do_get_request_with_logging(url, headers)
status = response.status_code

if status != expected_status_code:
raise StatusCodeError(
"Request to {} failed with status: {}, expected status: {}".format(url, status, expected_status_code)
Expand Down
14 changes: 14 additions & 0 deletions integration/config/logger_configurations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Logger configurations
import logging


class LoggingConfiguration:
@staticmethod
def configure_request_logger(logger):
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(
logging.Formatter("%(asctime)s %(message)s | Status: %(status)s | Headers: %(headers)s ")
)
logger.addHandler(console_handler)
logger.propagate = False
31 changes: 27 additions & 4 deletions integration/helpers/base_test.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import logging
import os
import requests
import shutil

import botocore
import pytest
import requests

from integration.config.logger_configurations import LoggingConfiguration
from integration.helpers.client_provider import ClientProvider
from integration.helpers.deployer.exceptions.exceptions import ThrottlingError
from integration.helpers.deployer.utils.retry import retry_with_exponential_backoff_and_jitter
from integration.helpers.request_utils import RequestUtils
from integration.helpers.resource import generate_suffix, create_bucket, verify_stack_resources
from integration.helpers.s3_uploader import S3Uploader
from integration.helpers.yaml_utils import dump_yaml, load_yaml
Expand All @@ -28,6 +30,10 @@
from integration.helpers.file_resources import FILE_TO_S3_URI_MAP, CODE_KEY_TO_FILE_MAP

LOG = logging.getLogger(__name__)

REQUEST_LOGGER = logging.getLogger(f"{__name__}.requests")
LoggingConfiguration.configure_request_logger(REQUEST_LOGGER)

STACK_NAME_PREFIX = "sam-integ-stack-"
S3_BUCKET_PREFIX = "sam-integ-bucket-"

Expand Down Expand Up @@ -496,7 +502,7 @@ def verify_stack(self, end_state="CREATE_COMPLETE"):
if error:
self.fail(error)

def verify_get_request_response(self, url, expected_status_code):
def verify_get_request_response(self, url, expected_status_code, headers=None):
"""
Verify if the get request to a certain url return the expected status code

Expand All @@ -506,9 +512,10 @@ def verify_get_request_response(self, url, expected_status_code):
the url for the get request
expected_status_code : string
the expected status code
headers : dict
headers to use in request
"""
print("Making request to " + url)
response = requests.get(url)
response = BaseTest.do_get_request_with_logging(url, headers)
self.assertEqual(response.status_code, expected_status_code, " must return HTTP " + str(expected_status_code))
return response

Expand Down Expand Up @@ -547,3 +554,19 @@ def generate_parameter(key, value, previous_value=False, resolved_value="string"
"ResolvedValue": resolved_value,
}
return parameter

@staticmethod
def do_get_request_with_logging(url, headers=None):
"""
Perform a get request to an APIGW endpoint and log relevant info
Parameters
----------
url : string
the url for the get request
headers : dict
headers to use in request
"""
response = requests.get(url, headers=headers) if headers else requests.get(url)
amazon_headers = RequestUtils(response).get_amazon_headers()
REQUEST_LOGGER.info("Request made to " + url, extra={"status": response.status_code, "headers": amazon_headers})
return response
37 changes: 37 additions & 0 deletions integration/helpers/request_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Utils for requests

# Relevant headers that should be captured for debugging
AMAZON_HEADERS = [
"x-amzn-requestid",
"x-amz-apigw-id",
"x-amz-cf-id",
"x-amzn-errortype",
"apigw-requestid",
]


class RequestUtils:
def __init__(self, response):
self.response = response
self.headers = self._normalize_response_headers()

def get_amazon_headers(self):
"""
Get a list of relevant amazon headers that could be useful for debugging
"""
amazon_headers = {}
for header, header_val in self.headers.items():
if header in AMAZON_HEADERS:
amazon_headers[header] = header_val
return amazon_headers

def _normalize_response_headers(self):
"""
API gateway can return headers with letters in different cases i.e. x-amzn-requestid or x-amzn-requestId
We make them all lowercase here to more easily match them up
"""
if self.response is None or not self.response.headers:
# Need to check for response is None here since the __bool__ method checks 200 <= status < 400
return {}

return dict((k.lower(), v) for k, v in self.response.headers.items())
7 changes: 3 additions & 4 deletions integration/single/test_basic_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from unittest.case import skipIf

from integration.helpers.base_test import BaseTest
import requests

from integration.helpers.resource import current_region_does_not_support
from integration.config.service_names import MODE
Expand Down Expand Up @@ -40,17 +39,17 @@ def test_basic_api_with_mode(self):

stack_output = self.get_stack_outputs()
api_endpoint = stack_output.get("ApiEndpoint")
response = requests.get(f"{api_endpoint}/get")
response = BaseTest.do_get_request_with_logging(f"{api_endpoint}/get")
self.assertEqual(response.status_code, 200)

# Removes get from the API
self.update_and_verify_stack(file_path="single/basic_api_with_mode_update")
response = requests.get(f"{api_endpoint}/get")
response = BaseTest.do_get_request_with_logging(f"{api_endpoint}/get")
# API Gateway by default returns 403 if a path do not exist
retries = 20
while retries > 0:
retries -= 1
response = requests.get(f"{api_endpoint}/get")
response = BaseTest.do_get_request_with_logging(f"{api_endpoint}/get")
if response.status_code != 500:
break
time.sleep(5)
Expand Down
4 changes: 1 addition & 3 deletions integration/single/test_basic_function.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from unittest.case import skipIf

import requests

from integration.config.service_names import KMS, XRAY, ARM
from integration.helpers.resource import current_region_does_not_support
from parameterized import parameterized
Expand Down Expand Up @@ -42,7 +40,7 @@ def test_function_with_http_api_events(self, file_name):

endpoint = self.get_api_v2_endpoint("MyHttpApi")

self.assertEqual(requests.get(endpoint).text, self.FUNCTION_OUTPUT)
self.assertEqual(BaseTest.do_get_request_with_logging(endpoint).text, self.FUNCTION_OUTPUT)

@parameterized.expand(
[
Expand Down
24 changes: 15 additions & 9 deletions integration/single/test_function_with_http_api_and_auth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import requests
from parameterized import parameterized
from integration.helpers.base_test import BaseTest


Expand All @@ -16,14 +14,22 @@ def test_function_with_http_api_and_auth(self):
self.create_and_verify_stack("function_with_http_api_events_and_auth")

implicitEndpoint = self.get_api_v2_endpoint("ServerlessHttpApi")
self.assertEqual(requests.get(implicitEndpoint + "/default-auth").text, self.FUNCTION_OUTPUT)
self.assertEqual(requests.get(implicitEndpoint + "/iam-auth").text, IAM_AUTH_OUTPUT)
self.assertEqual(
BaseTest.do_get_request_with_logging(implicitEndpoint + "/default-auth").text, self.FUNCTION_OUTPUT
)
self.assertEqual(BaseTest.do_get_request_with_logging(implicitEndpoint + "/iam-auth").text, IAM_AUTH_OUTPUT)

defaultIamEndpoint = self.get_api_v2_endpoint("MyDefaultIamAuthHttpApi")
self.assertEqual(requests.get(defaultIamEndpoint + "/no-auth").text, self.FUNCTION_OUTPUT)
self.assertEqual(requests.get(defaultIamEndpoint + "/default-auth").text, IAM_AUTH_OUTPUT)
self.assertEqual(requests.get(defaultIamEndpoint + "/iam-auth").text, IAM_AUTH_OUTPUT)
self.assertEqual(
BaseTest.do_get_request_with_logging(defaultIamEndpoint + "/no-auth").text, self.FUNCTION_OUTPUT
)
self.assertEqual(
BaseTest.do_get_request_with_logging(defaultIamEndpoint + "/default-auth").text, IAM_AUTH_OUTPUT
)
self.assertEqual(BaseTest.do_get_request_with_logging(defaultIamEndpoint + "/iam-auth").text, IAM_AUTH_OUTPUT)

iamEnabledEndpoint = self.get_api_v2_endpoint("MyIamAuthEnabledHttpApi")
self.assertEqual(requests.get(iamEnabledEndpoint + "/default-auth").text, self.FUNCTION_OUTPUT)
self.assertEqual(requests.get(iamEnabledEndpoint + "/iam-auth").text, IAM_AUTH_OUTPUT)
self.assertEqual(
BaseTest.do_get_request_with_logging(iamEnabledEndpoint + "/default-auth").text, self.FUNCTION_OUTPUT
)
self.assertEqual(BaseTest.do_get_request_with_logging(iamEnabledEndpoint + "/iam-auth").text, IAM_AUTH_OUTPUT)
3 changes: 2 additions & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ env =
AWS_DEFAULT_REGION = ap-southeast-1
markers =
slow: marks tests as slow (deselect with '-m "not slow"')

log_cli = 1
log_cli_level = INFO