Skip to content

Commit

Permalink
Fix #542 Add additional context values for FastAPI apps (#544)
Browse files Browse the repository at this point in the history
  • Loading branch information
seratch authored Dec 14, 2021
1 parent d5289c9 commit bc094f0
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 14 deletions.
41 changes: 41 additions & 0 deletions examples/fastapi/async_app_custom_props.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import logging

logging.basicConfig(level=logging.DEBUG)

from slack_bolt.async_app import AsyncApp
from slack_bolt.adapter.fastapi.async_handler import AsyncSlackRequestHandler

app = AsyncApp()
app_handler = AsyncSlackRequestHandler(app)


@app.event("app_mention")
async def handle_app_mentions(context, say, logger):
logger.info(context)
assert context.get("foo") == "FOO"
await say("What's up?")


@app.event("message")
async def handle_message():
pass


from fastapi import FastAPI, Request, Depends

api = FastAPI()


def get_foo():
yield "FOO"


@api.post("/slack/events")
async def endpoint(req: Request, foo: str = Depends(get_foo)):
return await app_handler.handle(req, {"foo": foo})


# pip install -r requirements.txt
# export SLACK_SIGNING_SECRET=***
# export SLACK_BOT_TOKEN=xoxb-***
# uvicorn async_app_custom_props:api --reload --port 3000 --log-level warning
26 changes: 20 additions & 6 deletions slack_bolt/adapter/starlette/async_handler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Dict, Any, Optional

from starlette.requests import Request
from starlette.responses import Response

Expand All @@ -6,12 +8,20 @@
from slack_bolt.oauth.async_oauth_flow import AsyncOAuthFlow


def to_async_bolt_request(req: Request, body: bytes) -> AsyncBoltRequest:
return AsyncBoltRequest(
def to_async_bolt_request(
req: Request,
body: bytes,
addition_context_properties: Optional[Dict[str, Any]] = None,
) -> AsyncBoltRequest:
request = AsyncBoltRequest(
body=body.decode("utf-8"),
query=req.query_params,
headers=req.headers,
)
if addition_context_properties is not None:
for k, v in addition_context_properties.items():
request.context[k] = v
return request


def to_starlette_response(bolt_resp: BoltResponse) -> Response:
Expand Down Expand Up @@ -39,23 +49,27 @@ class AsyncSlackRequestHandler:
def __init__(self, app: AsyncApp): # type: ignore
self.app = app

async def handle(self, req: Request) -> Response:
async def handle(
self, req: Request, addition_context_properties: Optional[Dict[str, Any]] = None
) -> Response:
body = await req.body()
if req.method == "GET":
if self.app.oauth_flow is not None:
oauth_flow: AsyncOAuthFlow = self.app.oauth_flow
if req.url.path == oauth_flow.install_path:
bolt_resp = await oauth_flow.handle_installation(
to_async_bolt_request(req, body)
to_async_bolt_request(req, body, addition_context_properties)
)
return to_starlette_response(bolt_resp)
elif req.url.path == oauth_flow.redirect_uri_path:
bolt_resp = await oauth_flow.handle_callback(
to_async_bolt_request(req, body)
to_async_bolt_request(req, body, addition_context_properties)
)
return to_starlette_response(bolt_resp)
elif req.method == "POST":
bolt_resp = await self.app.async_dispatch(to_async_bolt_request(req, body))
bolt_resp = await self.app.async_dispatch(
to_async_bolt_request(req, body, addition_context_properties)
)
return to_starlette_response(bolt_resp)

return Response(
Expand Down
28 changes: 22 additions & 6 deletions slack_bolt/adapter/starlette/handler.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
from typing import Dict, Any, Optional

from starlette.requests import Request
from starlette.responses import Response

from slack_bolt import BoltRequest, App, BoltResponse
from slack_bolt.oauth import OAuthFlow


def to_bolt_request(req: Request, body: bytes) -> BoltRequest:
return BoltRequest(
def to_bolt_request(
req: Request,
body: bytes,
addition_context_properties: Optional[Dict[str, Any]] = None,
) -> BoltRequest:
request = BoltRequest(
body=body.decode("utf-8"),
query=req.query_params,
headers=req.headers,
)
if addition_context_properties is not None:
for k, v in addition_context_properties.items():
request.context[k] = v
return request


def to_starlette_response(bolt_resp: BoltResponse) -> Response:
Expand Down Expand Up @@ -38,21 +48,27 @@ class SlackRequestHandler:
def __init__(self, app: App): # type: ignore
self.app = app

async def handle(self, req: Request) -> Response:
async def handle(
self, req: Request, addition_context_properties: Optional[Dict[str, Any]] = None
) -> Response:
body = await req.body()
if req.method == "GET":
if self.app.oauth_flow is not None:
oauth_flow: OAuthFlow = self.app.oauth_flow
if req.url.path == oauth_flow.install_path:
bolt_resp = oauth_flow.handle_installation(
to_bolt_request(req, body)
to_bolt_request(req, body, addition_context_properties)
)
return to_starlette_response(bolt_resp)
elif req.url.path == oauth_flow.redirect_uri_path:
bolt_resp = oauth_flow.handle_callback(to_bolt_request(req, body))
bolt_resp = oauth_flow.handle_callback(
to_bolt_request(req, body, addition_context_properties)
)
return to_starlette_response(bolt_resp)
elif req.method == "POST":
bolt_resp = self.app.dispatch(to_bolt_request(req, body))
bolt_resp = self.app.dispatch(
to_bolt_request(req, body, addition_context_properties)
)
return to_starlette_response(bolt_resp)

return Response(
Expand Down
50 changes: 49 additions & 1 deletion tests/adapter_tests/starlette/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from time import time
from urllib.parse import quote

from fastapi import FastAPI
from fastapi import FastAPI, Depends
from slack_sdk.signature import SignatureVerifier
from slack_sdk.web import WebClient
from starlette.requests import Request
Expand Down Expand Up @@ -214,3 +214,51 @@ async def endpoint(req: Request):
assert response.status_code == 200
assert response.headers.get("content-type") == "text/html; charset=utf-8"
assert "https://slack.com/oauth/v2/authorize?state=" in response.text

def test_custom_props(self):
app = App(
client=self.web_client,
signing_secret=self.signing_secret,
)

def shortcut_handler(ack, context):
assert context.get("foo") == "FOO"
ack()

app.shortcut("test-shortcut")(shortcut_handler)

input = {
"type": "shortcut",
"token": "verification_token",
"action_ts": "111.111",
"team": {
"id": "T111",
"domain": "workspace-domain",
"enterprise_id": "E111",
"enterprise_name": "Org Name",
},
"user": {"id": "W111", "username": "primary-owner", "team_id": "T111"},
"callback_id": "test-shortcut",
"trigger_id": "111.111.xxxxxx",
}

timestamp, body = str(int(time())), f"payload={quote(json.dumps(input))}"

api = FastAPI()
app_handler = SlackRequestHandler(app)

def get_foo():
yield "FOO"

@api.post("/slack/events")
async def endpoint(req: Request, foo: str = Depends(get_foo)):
return await app_handler.handle(req, {"foo": foo})

client = TestClient(api)
response = client.post(
"/slack/events",
data=body,
headers=self.build_headers(timestamp, body),
)
assert response.status_code == 200
assert_auth_test_count(self, 1)
50 changes: 49 additions & 1 deletion tests/adapter_tests_async/test_async_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from time import time
from urllib.parse import quote

from fastapi import FastAPI
from fastapi import FastAPI, Depends
from slack_sdk.signature import SignatureVerifier
from slack_sdk.web.async_client import AsyncWebClient
from starlette.requests import Request
Expand Down Expand Up @@ -215,3 +215,51 @@ async def endpoint(req: Request):
assert response.headers.get("content-type") == "text/html; charset=utf-8"
assert response.headers.get("content-length") == "597"
assert "https://slack.com/oauth/v2/authorize?state=" in response.text

def test_custom_props(self):
app = AsyncApp(
client=self.web_client,
signing_secret=self.signing_secret,
)

async def shortcut_handler(ack, context):
assert context.get("foo") == "FOO"
await ack()

app.shortcut("test-shortcut")(shortcut_handler)

input = {
"type": "shortcut",
"token": "verification_token",
"action_ts": "111.111",
"team": {
"id": "T111",
"domain": "workspace-domain",
"enterprise_id": "E111",
"enterprise_name": "Org Name",
},
"user": {"id": "W111", "username": "primary-owner", "team_id": "T111"},
"callback_id": "test-shortcut",
"trigger_id": "111.111.xxxxxx",
}

timestamp, body = str(int(time())), f"payload={quote(json.dumps(input))}"

api = FastAPI()
app_handler = AsyncSlackRequestHandler(app)

def get_foo():
yield "FOO"

@api.post("/slack/events")
async def endpoint(req: Request, foo: str = Depends(get_foo)):
return await app_handler.handle(req, {"foo": foo})

client = TestClient(api)
response = client.post(
"/slack/events",
data=body,
headers=self.build_headers(timestamp, body),
)
assert response.status_code == 200
assert_auth_test_count(self, 1)

0 comments on commit bc094f0

Please sign in to comment.