Skip to content

Commit

Permalink
Add working implementation of basic auth to server and client (#16408)
Browse files Browse the repository at this point in the history
  • Loading branch information
cicdw authored Dec 17, 2024
1 parent c4d3bb2 commit 9af5624
Show file tree
Hide file tree
Showing 9 changed files with 152 additions and 2 deletions.
24 changes: 24 additions & 0 deletions docs/v3/develop/settings-ref.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,18 @@ The URL of the Prefect API. If not set, the client will attempt to infer it.
**Supported environment variables**:
`PREFECT_API_URL`

### `auth_string`
The auth string used for basic authentication with a self-hosted Prefect API. Should be kept secret.

**Type**: `string | None`

**Default**: `None`

**TOML dotted key path**: `api.auth_string`

**Supported environment variables**:
`PREFECT_API_AUTH_STRING`

### `key`
The API key used for authentication with the Prefect API. Should be kept secret.

Expand Down Expand Up @@ -862,6 +874,18 @@ Number of seconds a runner should wait between heartbeats for flow runs.
---
## ServerAPISettings
Settings for controlling API server behavior
### `auth_string`
A string to use for basic authentication with the API; typically in the form 'user:password' but can be any string.

**Type**: `string | None`

**Default**: `None`

**TOML dotted key path**: `server.api.auth_string`

**Supported environment variables**:
`PREFECT_SERVER_API_AUTH_STRING`

### `host`
The API's host address (defaults to `127.0.0.1`).

Expand Down
36 changes: 36 additions & 0 deletions schemas/settings.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,24 @@
],
"title": "Url"
},
"auth_string": {
"anyOf": [
{
"format": "password",
"type": "string",
"writeOnly": true
},
{
"type": "null"
}
],
"default": null,
"description": "The auth string used for basic authentication with a self-hosted Prefect API. Should be kept secret.",
"supported_environment_variables": [
"PREFECT_API_AUTH_STRING"
],
"title": "Auth String"
},
"key": {
"anyOf": [
{
Expand Down Expand Up @@ -724,6 +742,24 @@
"ServerAPISettings": {
"description": "Settings for controlling API server behavior",
"properties": {
"auth_string": {
"anyOf": [
{
"format": "password",
"type": "string",
"writeOnly": true
},
{
"type": "null"
}
],
"default": null,
"description": "A string to use for basic authentication with the API; typically in the form 'user:password' but can be any string.",
"supported_environment_variables": [
"PREFECT_SERVER_API_AUTH_STRING"
],
"title": "Auth String"
},
"host": {
"default": "127.0.0.1",
"description": "The API's host address (defaults to `127.0.0.1`).",
Expand Down
18 changes: 18 additions & 0 deletions src/prefect/client/orchestration.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import base64
import datetime
import ssl
import warnings
Expand Down Expand Up @@ -114,6 +115,7 @@
from prefect.events.schemas.automations import Automation, AutomationCore
from prefect.logging import get_logger
from prefect.settings import (
PREFECT_API_AUTH_STRING,
PREFECT_API_DATABASE_CONNECTION_URL,
PREFECT_API_ENABLE_HTTP2,
PREFECT_API_KEY,
Expand Down Expand Up @@ -228,13 +230,15 @@ def get_client(
if sync_client:
return SyncPrefectClient(
api,
auth_string=PREFECT_API_AUTH_STRING.value(),
api_key=PREFECT_API_KEY.value(),
httpx_settings=httpx_settings,
server_type=server_type,
)
else:
return PrefectClient(
api,
auth_string=PREFECT_API_AUTH_STRING.value(),
api_key=PREFECT_API_KEY.value(),
httpx_settings=httpx_settings,
server_type=server_type,
Expand Down Expand Up @@ -271,6 +275,7 @@ def __init__(
self,
api: Union[str, ASGIApp],
*,
auth_string: Optional[str] = None,
api_key: Optional[str] = None,
api_version: Optional[str] = None,
httpx_settings: Optional[dict[str, Any]] = None,
Expand Down Expand Up @@ -299,6 +304,10 @@ def __init__(
if api_key:
httpx_settings["headers"].setdefault("Authorization", f"Bearer {api_key}")

if auth_string:
token = base64.b64encode(auth_string.encode("utf-8")).decode("utf-8")
httpx_settings["headers"].setdefault("Authorization", f"Basic {token}")

# Context management
self._context_stack: int = 0
self._exit_stack = AsyncExitStack()
Expand Down Expand Up @@ -3469,6 +3478,8 @@ async def raise_for_api_version_mismatch(self) -> None:
try:
api_version = await self.api_version()
except Exception as e:
if "Unauthorized" in str(e):
raise e
raise RuntimeError(f"Failed to reach API at {self.api_url}") from e

api_version = version.parse(api_version)
Expand Down Expand Up @@ -3590,6 +3601,7 @@ def __init__(
self,
api: Union[str, ASGIApp],
*,
auth_string: Optional[str] = None,
api_key: Optional[str] = None,
api_version: Optional[str] = None,
httpx_settings: Optional[dict[str, Any]] = None,
Expand Down Expand Up @@ -3618,6 +3630,10 @@ def __init__(
if api_key:
httpx_settings["headers"].setdefault("Authorization", f"Bearer {api_key}")

if auth_string:
token = base64.b64encode(auth_string.encode("utf-8")).decode("utf-8")
httpx_settings["headers"].setdefault("Authorization", f"Basic {token}")

# Context management
self._context_stack: int = 0
self._ephemeral_app: Optional[ASGIApp] = None
Expand Down Expand Up @@ -3800,6 +3816,8 @@ def raise_for_api_version_mismatch(self) -> None:
try:
api_version = self.api_version()
except Exception as e:
if "Unauthorized" in str(e):
raise e
raise RuntimeError(f"Failed to reach API at {self.api_url}") from e

api_version = version.parse(api_version)
Expand Down
25 changes: 25 additions & 0 deletions src/prefect/server/api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import asyncio
import atexit
import base64
import contextlib
import mimetypes
import os
Expand Down Expand Up @@ -315,6 +316,30 @@ async def server_version():
for router in API_ROUTERS:
api_app.include_router(router, dependencies=dependencies)

auth_string = prefect.settings.PREFECT_SERVER_API_AUTH_STRING.value()

if auth_string is not None:

@api_app.middleware("http")
async def token_validation(request: Request, call_next):
header_token = request.headers.get("Authorization")

try:
scheme, creds = header_token.split()
assert scheme == "Basic"
decoded = base64.b64decode(creds).decode("utf-8")
except Exception:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"exception_message": "Unauthorized"},
)
if decoded != auth_string:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"exception_message": "Unauthorized"},
)
return await call_next(request)

return api_app


Expand Down
1 change: 1 addition & 0 deletions src/prefect/settings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __getattr__(name: str) -> Setting:
"temporary_settings",
"DEFAULT_PROFILES_PATH",
# add public settings here for auto-completion
"PREFECT_API_AUTH_STRING", # type: ignore
"PREFECT_API_KEY", # type: ignore
"PREFECT_API_URL", # type: ignore
"PREFECT_UI_URL", # type: ignore
Expand Down
4 changes: 4 additions & 0 deletions src/prefect/settings/models/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ class APISettings(PrefectBaseSettings):
default=None,
description="The URL of the Prefect API. If not set, the client will attempt to infer it.",
)
auth_string: Optional[SecretStr] = Field(
default=None,
description="The auth string used for basic authentication with a self-hosted Prefect API. Should be kept secret.",
)
key: Optional[SecretStr] = Field(
default=None,
description="The API key used for authentication with the Prefect API. Should be kept secret.",
Expand Down
8 changes: 7 additions & 1 deletion src/prefect/settings/models/server/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import timedelta
from typing import Optional

from pydantic import AliasChoices, AliasPath, Field
from pydantic import AliasChoices, AliasPath, Field, SecretStr

from prefect.settings.base import PrefectBaseSettings, _build_settings_config

Expand All @@ -12,6 +13,11 @@ class ServerAPISettings(PrefectBaseSettings):

model_config = _build_settings_config(("server", "api"))

auth_string: Optional[SecretStr] = Field(
default=None,
description="A string to use for basic authentication with the API; typically in the form 'user:password' but can be any string.",
)

host: str = Field(
default="127.0.0.1",
description="The API's host address (defaults to `127.0.0.1`).",
Expand Down
36 changes: 35 additions & 1 deletion tests/client/test_prefect_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pytest
import respx
from fastapi import Depends, FastAPI, status
from fastapi.security import HTTPBearer
from fastapi.security import HTTPBasic, HTTPBearer

import prefect.client.schemas as client_schemas
import prefect.context
Expand Down Expand Up @@ -76,6 +76,7 @@
from prefect.server.api.server import create_app
from prefect.server.database.orm_models import WorkPool
from prefect.settings import (
PREFECT_API_AUTH_STRING,
PREFECT_API_DATABASE_MIGRATE_ON_START,
PREFECT_API_KEY,
PREFECT_API_SSL_CERT_FILE,
Expand Down Expand Up @@ -1734,6 +1735,39 @@ async def test_get_client_includes_api_key_from_context(self):
assert client._client.headers["Authorization"] == "Bearer test"


class TestClientAuthString:
@pytest.fixture
async def test_app(self):
app = FastAPI()
basic = HTTPBasic()

# Returns given credentials if an Authorization
# header is passed, otherwise raises 403
@app.get("/api/check_for_auth_header")
async def check_for_auth_header(credentials=Depends(basic)):
return {"username": credentials.username, "password": credentials.password}

return app

async def test_client_passes_auth_string_as_auth_header(self, test_app):
auth_string = "admin:admin"
async with PrefectClient(test_app, auth_string=auth_string) as client:
res = await client._client.get("/check_for_auth_header")
assert res.status_code == status.HTTP_200_OK
assert res.json() == {"username": "admin", "password": "admin"}

async def test_client_no_auth_header_without_auth_string(self, test_app):
async with PrefectClient(test_app) as client:
with pytest.raises(httpx.HTTPStatusError, match="401"):
await client._client.get("/check_for_auth_header")

async def test_get_client_includes_auth_string_from_context(self):
with temporary_settings(updates={PREFECT_API_AUTH_STRING: "admin:test"}):
client = get_client()

assert client._client.headers["Authorization"].startswith("Basic")


class TestClientWorkQueues:
@pytest.fixture
async def deployment(self, prefect_client):
Expand Down
2 changes: 2 additions & 0 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
from prefect.utilities.filesystem import tmpchdir

SUPPORTED_SETTINGS = {
"PREFECT_API_AUTH_STRING": {"test_value": "admin:admin"},
"PREFECT_API_BLOCKS_REGISTER_ON_START": {"test_value": True, "legacy": True},
"PREFECT_API_DATABASE_CONNECTION_TIMEOUT": {"test_value": 10.0, "legacy": True},
"PREFECT_API_DATABASE_CONNECTION_URL": {"test_value": "sqlite:///", "legacy": True},
Expand Down Expand Up @@ -277,6 +278,7 @@
"PREFECT_RUNNER_SERVER_MISSED_POLLS_TOLERANCE": {"test_value": 10},
"PREFECT_RUNNER_SERVER_PORT": {"test_value": 8080},
"PREFECT_SERVER_ALLOW_EPHEMERAL_MODE": {"test_value": True, "legacy": True},
"PREFECT_SERVER_API_AUTH_STRING": {"test_value": "admin:admin"},
"PREFECT_SERVER_ANALYTICS_ENABLED": {"test_value": True},
"PREFECT_SERVER_API_CORS_ALLOWED_HEADERS": {"test_value": "foo"},
"PREFECT_SERVER_API_CORS_ALLOWED_METHODS": {"test_value": "foo"},
Expand Down

0 comments on commit 9af5624

Please sign in to comment.