Skip to content

Commit

Permalink
Launch webbrowser for oauth2 authentication
Browse files Browse the repository at this point in the history
  • Loading branch information
mdesmet authored and hashhar committed Apr 20, 2022
1 parent 1714795 commit 5681fb7
Show file tree
Hide file tree
Showing 5 changed files with 452 additions and 138 deletions.
11 changes: 7 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,14 @@ the [`JWT` authentication type](https://trino.io/docs/current/security/jwt.html)

### OAuth2 Authentication

- `OAuth2Authentication` class can be used to connect to a Trino cluster configured with
The `OAuth2Authentication` class can be used to connect to a Trino cluster configured with
the [OAuth2 authentication type](https://trino.io/docs/current/security/oauth2.html).
- A callback to handle the redirect url can be provided via param `redirect_auth_url_handler`, by default it just outputs the redirect url to stdout.

* DBAPI
A callback to handle the redirect url can be provided via param `redirect_auth_url_handler` of the `trino.auth.OAuth2Authentication` class. By default, it will try to launch a web browser (`trino.auth.WebBrowserRedirectHandler`) to go through the authentication flow and output the redirect url to stdout (`trino.auth.ConsoleRedirectHandler`). Multiple redirect handlers are combined using the `trino.auth.CompositeRedirectHandler` class.

The OAuth2 token will be cached either per `trino.auth.OAuth2Authentication` instance.

- DBAPI

```python
from trino.dbapi import connect
Expand All @@ -185,7 +188,7 @@ the [OAuth2 authentication type](https://trino.io/docs/current/security/oauth2.h
)
```

* SQLAlchemy
- SQLAlchemy

```python
from sqlalchemy import create_engine
Expand Down
128 changes: 128 additions & 0 deletions tests/unit/oauth_test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import re
import uuid
from collections import namedtuple

import httpretty

from trino import constants

SERVER_ADDRESS = "https://coordinator"
REDIRECT_PATH = "oauth2/initiate"
TOKEN_PATH = "oauth2/token"
REDIRECT_RESOURCE = f"{SERVER_ADDRESS}/{REDIRECT_PATH}"
TOKEN_RESOURCE = f"{SERVER_ADDRESS}/{TOKEN_PATH}"


class RedirectHandler:
def __init__(self):
self.redirect_server = ""

def __call__(self, url):
self.redirect_server += url


class PostStatementCallback:
def __init__(self, redirect_server, token_server, tokens, sample_post_response_data):
self.redirect_server = redirect_server
self.token_server = token_server
self.tokens = tokens
self.sample_post_response_data = sample_post_response_data

def __call__(self, request, uri, response_headers):
authorization = request.headers.get("Authorization")
if authorization and authorization.replace("Bearer ", "") in self.tokens:
return [200, response_headers, json.dumps(self.sample_post_response_data)]
return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{self.redirect_server}", '
f'x_token_server="{self.token_server}"',
'Basic realm': '"Trino"'}, ""]


class GetTokenCallback:
def __init__(self, token_server, token, attempts=1):
self.token_server = token_server
self.token = token
self.attempts = attempts

def __call__(self, request, uri, response_headers):
self.attempts -= 1
if self.attempts < 0:
return [404, response_headers, "{}"]
if self.attempts == 0:
return [200, response_headers, f'{{"token": "{self.token}"}}']
return [200, response_headers, f'{{"nextUri": "{self.token_server}"}}']


def _get_token_requests(challenge_id):
return list(filter(
lambda r: r.method == "GET" and r.path == f"/{TOKEN_PATH}/{challenge_id}",
httpretty.latest_requests()))


def _post_statement_requests():
return list(filter(
lambda r: r.method == "POST" and r.path == constants.URL_STATEMENT_PATH,
httpretty.latest_requests()))


class MultithreadedTokenServer:
Challenge = namedtuple('Challenge', ['token', 'attempts'])

def __init__(self, sample_post_response_data, attempts=1):
self.tokens = set()
self.challenges = {}
self.sample_post_response_data = sample_post_response_data
self.attempts = attempts

# bind post statement
httpretty.register_uri(
method=httpretty.POST,
uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
body=self.post_statement_callback)

# bind get token
httpretty.register_uri(
method=httpretty.GET,
uri=re.compile(rf"{TOKEN_RESOURCE}/.*"),
body=self.get_token_callback)

# noinspection PyUnusedLocal
def post_statement_callback(self, request, uri, response_headers):
authorization = request.headers.get("Authorization")

if authorization and authorization.replace("Bearer ", "") in self.tokens:
return [200, response_headers, json.dumps(self.sample_post_response_data)]

challenge_id = str(uuid.uuid4())
token = str(uuid.uuid4())
self.tokens.add(token)
self.challenges[challenge_id] = MultithreadedTokenServer.Challenge(token, self.attempts)
redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
token_server = f"{TOKEN_RESOURCE}/{challenge_id}"
return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{redirect_server}", '
f'x_token_server="{token_server}"',
'Basic realm': '"Trino"'}, ""]

# noinspection PyUnusedLocal
def get_token_callback(self, request, uri, response_headers):
challenge_id = uri.replace(f"{TOKEN_RESOURCE}/", "")
challenge = self.challenges[challenge_id]
challenge = challenge._replace(attempts=challenge.attempts - 1)
self.challenges[challenge_id] = challenge
if challenge.attempts < 0:
return [404, response_headers, "{}"]
if challenge.attempts == 0:
return [200, response_headers, f'{{"token": "{challenge.token}"}}']
return [200, response_headers, f'{{"nextUri": "{uri}"}}']
124 changes: 8 additions & 116 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
import threading
import time
import uuid
from collections import namedtuple
from unittest import mock
from urllib.parse import urlparse

Expand All @@ -25,6 +23,9 @@
from requests_kerberos.exceptions import KerberosExchangeError

import trino.exceptions
from tests.unit.oauth_test_utils import RedirectHandler, GetTokenCallback, PostStatementCallback, \
MultithreadedTokenServer, _post_statement_requests, _get_token_requests, REDIRECT_RESOURCE, TOKEN_RESOURCE, \
SERVER_ADDRESS
from trino import constants
from trino.auth import KerberosAuthentication, _OAuth2TokenBearer
from trino.client import TrinoQuery, TrinoRequest, TrinoResult
Expand Down Expand Up @@ -259,52 +260,6 @@ def long_call(request, uri, headers):
httpretty.reset()


SERVER_ADDRESS = "https://coordinator"
REDIRECT_PATH = "oauth2/initiate"
TOKEN_PATH = "oauth2/token"
REDIRECT_RESOURCE = f"{SERVER_ADDRESS}/{REDIRECT_PATH}"
TOKEN_RESOURCE = f"{SERVER_ADDRESS}/{TOKEN_PATH}"


class RedirectHandler:
def __init__(self):
self.redirect_server = ""

def __call__(self, url):
self.redirect_server += url


class PostStatementCallback:
def __init__(self, redirect_server, token_server, tokens, sample_post_response_data):
self.redirect_server = redirect_server
self.token_server = token_server
self.tokens = tokens
self.sample_post_response_data = sample_post_response_data

def __call__(self, request, uri, response_headers):
authorization = request.headers.get("Authorization")
if authorization and authorization.replace("Bearer ", "") in self.tokens:
return [200, response_headers, json.dumps(self.sample_post_response_data)]
return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{self.redirect_server}", '
f'x_token_server="{self.token_server}"',
'Basic realm': '"Trino"'}, ""]


class GetTokenCallback:
def __init__(self, token_server, token, attempts=1):
self.token_server = token_server
self.token = token
self.attempts = attempts

def __call__(self, request, uri, response_headers):
self.attempts -= 1
if self.attempts < 0:
return [404, response_headers, "{}"]
if self.attempts == 0:
return [200, response_headers, f'{{"token": "{self.token}"}}']
return [200, response_headers, f'{{"nextUri": "{self.token_server}"}}']


@pytest.mark.parametrize("attempts", [1, 3, 5])
@httprettified
def test_oauth2_authentication_flow(attempts, sample_post_response_data):
Expand Down Expand Up @@ -511,57 +466,6 @@ def test_oauth2_authentication_fail_token_server(http_status, sample_post_respon
assert len(_get_token_requests(challenge_id)) == 1


class MultithreadedTokenServer:
Challenge = namedtuple('Challenge', ['token', 'attempts'])

def __init__(self, sample_post_response_data, attempts=1):
self.tokens = set()
self.challenges = {}
self.sample_post_response_data = sample_post_response_data
self.attempts = attempts

# bind post statement
httpretty.register_uri(
method=httpretty.POST,
uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
body=self.post_statement_callback)

# bind get token
httpretty.register_uri(
method=httpretty.GET,
uri=re.compile(rf"{TOKEN_RESOURCE}/.*"),
body=self.get_token_callback)

# noinspection PyUnusedLocal
def post_statement_callback(self, request, uri, response_headers):
authorization = request.headers.get("Authorization")

if authorization and authorization.replace("Bearer ", "") in self.tokens:
return [200, response_headers, json.dumps(self.sample_post_response_data)]

challenge_id = str(uuid.uuid4())
token = str(uuid.uuid4())
self.tokens.add(token)
self.challenges[challenge_id] = MultithreadedTokenServer.Challenge(token, self.attempts)
redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
token_server = f"{TOKEN_RESOURCE}/{challenge_id}"
return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{redirect_server}", '
f'x_token_server="{token_server}"',
'Basic realm': '"Trino"'}, ""]

# noinspection PyUnusedLocal
def get_token_callback(self, request, uri, response_headers):
challenge_id = uri.replace(f"{TOKEN_RESOURCE}/", "")
challenge = self.challenges[challenge_id]
challenge = challenge._replace(attempts=challenge.attempts - 1)
self.challenges[challenge_id] = challenge
if challenge.attempts < 0:
return [404, response_headers, "{}"]
if challenge.attempts == 0:
return [200, response_headers, f'{{"token": "{challenge.token}"}}']
return [200, response_headers, f'{{"nextUri": "{uri}"}}']


@httprettified
def test_multithreaded_oauth2_authentication_flow(sample_post_response_data):
redirect_handler = RedirectHandler()
Expand Down Expand Up @@ -598,31 +502,19 @@ def run(self) -> None:
for thread in threads:
thread.join()

# should issue only 3 tokens and each thread should get one
assert len(token_server.tokens) == 3
# should issue only 1 token and each thread should reuse it
assert len(token_server.tokens) == 1
for thread in threads:
assert thread.token in token_server.tokens

# should start only 3 challenges and every token should be obtained
assert len(token_server.challenges.keys()) == 3
# should start only 1 challenge
assert len(token_server.challenges.keys()) == 1
for challenge_id, challenge in token_server.challenges.items():
assert f"{REDIRECT_RESOURCE}/{challenge_id}" in redirect_handler.redirect_server
assert challenge.attempts == 0
assert len(_get_token_requests(challenge_id)) == 1
# 3 threads * (10 POST /statement each + 1 replied request by authentication)
assert len(_post_statement_requests()) == 33


def _get_token_requests(challenge_id):
return list(filter(
lambda r: r.method == "GET" and r.path == f"/{TOKEN_PATH}/{challenge_id}",
httpretty.latest_requests()))


def _post_statement_requests():
return list(filter(
lambda r: r.method == "POST" and r.path == constants.URL_STATEMENT_PATH,
httpretty.latest_requests()))
assert len(_post_statement_requests()) == 31


@mock.patch("trino.client.TrinoRequest.http")
Expand Down
Loading

0 comments on commit 5681fb7

Please sign in to comment.