Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored workload API token management for better security and implemented generic API token dispenser #3154

Merged
merged 15 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Implement proper authorization checks for workload API tokens
  • Loading branch information
stefannica committed Nov 11, 2024
commit 041be8884e5ce9ebb91951278ac0fa7ed1ad1197
4 changes: 2 additions & 2 deletions src/zenml/config/server_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Any, Dict, List, Optional, Union
from uuid import UUID

from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field, PositiveInt, model_validator

from zenml.constants import (
DEFAULT_ZENML_JWT_TOKEN_ALGORITHM,
Expand Down Expand Up @@ -260,7 +260,7 @@ class ServerConfiguration(BaseModel):
device_expiration_minutes: Optional[int] = None
trusted_device_expiration_minutes: Optional[int] = None

generic_api_token_lifetime: int = (
generic_api_token_lifetime: PositiveInt = (
DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_LIFETIME
)

Expand Down
7 changes: 0 additions & 7 deletions src/zenml/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,6 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
ENV_ZENML_ENFORCE_TYPE_ANNOTATIONS = "ZENML_ENFORCE_TYPE_ANNOTATIONS"
ENV_ZENML_ENABLE_IMPLICIT_AUTH_METHODS = "ZENML_ENABLE_IMPLICIT_AUTH_METHODS"
ENV_ZENML_DISABLE_STEP_LOGS_STORAGE = "ZENML_DISABLE_STEP_LOGS_STORAGE"
ENV_ZENML_PIPELINE_API_TOKEN_EXPIRES_MINUTES = (
"ZENML_PIPELINE_API_TOKEN_EXPIRES_MINUTES"
)
ENV_ZENML_IGNORE_FAILURE_HOOK = "ZENML_IGNORE_FAILURE_HOOK"
ENV_ZENML_CUSTOM_SOURCE_ROOT = "ZENML_CUSTOM_SOURCE_ROOT"
ENV_ZENML_WHEEL_PACKAGE_NAME = "ZENML_WHEEL_PACKAGE_NAME"
Expand Down Expand Up @@ -411,10 +408,6 @@ def handle_int_env_var(var: str, default: int = 0) -> int:

# orchestrator constants
ORCHESTRATOR_DOCKER_IMAGE_KEY = "orchestrator"
PIPELINE_API_TOKEN_EXPIRES_MINUTES = handle_int_env_var(
ENV_ZENML_PIPELINE_API_TOKEN_EXPIRES_MINUTES,
default=60 * 24, # 24 hours
)

# Secret constants
SECRET_VALUES = "values"
Expand Down
7 changes: 7 additions & 0 deletions src/zenml/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,13 @@ class OAuthDeviceStatus(StrEnum):
LOCKED = "locked"


class APITokenType(StrEnum):
    """Kinds of API tokens that can be requested."""

    # A general-purpose API token.
    GENERIC = "generic"
    # A token issued for a workload; presumably the schedule-, pipeline-run-
    # or step-run-scoped tokens introduced alongside this enum (verify
    # against the token endpoint).
    WORKLOAD = "workload"


class GenericFilterOps(StrEnum):
"""Ops for all filters for string values on list methods."""

Expand Down
13 changes: 12 additions & 1 deletion src/zenml/orchestrators/base_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, Type, cast
from uuid import UUID

from pydantic import model_validator

Expand Down Expand Up @@ -186,7 +187,17 @@ def run(
"""
self._prepare_run(deployment=deployment)

environment = get_config_environment_vars(deployment=deployment)
pipeline_run_id: Optional[UUID] = None
schedule_id: Optional[UUID] = None
if deployment.schedule:
schedule_id = deployment.schedule.id
if placeholder_run:
pipeline_run_id = placeholder_run.id

environment = get_config_environment_vars(
schedule_id=schedule_id,
pipeline_run_id=pipeline_run_id,
)

prevent_client_side_caching = handle_bool_env_var(
ENV_ZENML_PREVENT_CLIENT_SIDE_CACHING, default=False
Expand Down
3 changes: 2 additions & 1 deletion src/zenml/orchestrators/step_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,8 @@ def _run_step_with_step_operator(
)
)
environment = orchestrator_utils.get_config_environment_vars(
deployment=self._deployment
pipeline_run_id=step_run_info.run_id,
step_run_id=step_run_info.step_run_id,
)
if last_retry:
environment[ENV_ZENML_IGNORE_FAILURE_HOOK] = str(False)
Expand Down
76 changes: 44 additions & 32 deletions src/zenml/orchestrators/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
ENV_ZENML_DISABLE_CREDENTIALS_DISK_CACHING,
ENV_ZENML_SERVER,
ENV_ZENML_STORE_PREFIX,
PIPELINE_API_TOKEN_EXPIRES_MINUTES,
)
from zenml.enums import AuthScheme, StackComponentType, StoreType
from zenml.logger import get_logger
Expand All @@ -39,7 +38,6 @@

if TYPE_CHECKING:
from zenml.artifact_stores.base_artifact_store import BaseArtifactStore
from zenml.models import PipelineDeploymentResponse


def get_orchestrator_run_name(pipeline_name: str) -> str:
Expand Down Expand Up @@ -82,16 +80,23 @@ def is_setting_enabled(


def get_config_environment_vars(
deployment: Optional["PipelineDeploymentResponse"] = None,
schedule_id: Optional[UUID] = None,
pipeline_run_id: Optional[UUID] = None,
step_run_id: Optional[UUID] = None,
) -> Dict[str, str]:
"""Gets environment variables to set for mirroring the active config.

If a pipeline deployment is given, the environment variables will be set to
include a newly generated API token valid for the duration of the pipeline
run instead of the API token from the global config.
If a schedule ID, pipeline run ID or step run ID is given, and the current
client is not authenticated to a server with an API key, the environment
variables will be updated to include a newly generated workload API token
that will be valid for the duration of the schedule, pipeline run, or step
run instead of the current API token used to authenticate the client.

Args:
deployment: Optional deployment to use for the environment variables.
schedule_id: Optional schedule ID to use to generate a new API token.
pipeline_run_id: Optional pipeline run ID to use to generate a new API
token.
step_run_id: Optional step run ID to use to generate a new API token.

Returns:
Environment variable dict.
Expand All @@ -112,42 +117,49 @@ def get_config_environment_vars(
api_key = credentials_store.get_api_key(url)
api_token = credentials_store.get_token(url, allow_expired=False)
if api_key:
# If an API key is available, it is used to authenticate the
# pipeline run environment.
environment_vars[ENV_ZENML_STORE_PREFIX + "API_KEY"] = api_key
elif deployment:
# When connected to an authenticated ZenML server, if a pipeline
# deployment is supplied, we need to fetch an API token that will be
# valid for the duration of the pipeline run.
elif schedule_id or pipeline_run_id or step_run_id:
# When connected to an authenticated ZenML server, if a schedule ID,
# pipeline run ID or step run ID is supplied, we need to fetch a new
# workload API token scoped to the schedule, pipeline run or step
# run.
assert isinstance(global_config.zen_store, RestZenStore)
pipeline_id: Optional[UUID] = None
if deployment.pipeline:
pipeline_id = deployment.pipeline.id
schedule_id: Optional[UUID] = None
expires_minutes: Optional[int] = PIPELINE_API_TOKEN_EXPIRES_MINUTES
if deployment.schedule:
schedule_id = deployment.schedule.id
# If a schedule is given, this is a long running pipeline that
# should not have an API token that expires.
expires_minutes = None

if pipeline_run_id or step_run_id:
# If a pipeline run or step run is given, the pipeline run
# or step run credentials are scoped to the pipeline run or step
# run and will only be valid for the duration of the run/step.
new_api_token = global_config.zen_store.get_api_token(
schedule_id=schedule_id,
pipeline_run_id=pipeline_run_id,
step_run_id=step_run_id,
)
else:
# If a schedule is given, the pipeline run credentials is
# configured with a token that is scoped to the given schedule.
logger.warning(
"An API token without an expiration time will be generated "
"and used to run this pipeline on a schedule. This is very "
"insecure because the API token cannot be revoked in case "
"of potential theft without disabling the entire user "
"account. When deploying a pipeline on a schedule, it is "
"strongly advised to use a service account API key to "
"authenticate to the ZenML server instead of your regular "
"user account. For more information, see "
"insecure because the API token will be valid for the "
"entire lifetime of the schedule and can be used to access "
"your account if leaked. When deploying a pipeline on a "
"schedule, it is strongly advised to use a service account "
"API key to authenticate to the ZenML server instead of "
"your regular user account. For more information, see "
"https://docs.zenml.io/how-to/connecting-to-zenml/connect-with-a-service-account"
)
new_api_token = global_config.zen_store.get_api_token(
pipeline_id=pipeline_id,
schedule_id=schedule_id,
expires_minutes=expires_minutes,
)
new_api_token = global_config.zen_store.get_api_token(
schedule_id=schedule_id,
)

environment_vars[ENV_ZENML_STORE_PREFIX + "API_TOKEN"] = (
new_api_token
)
elif api_token:
# For all other cases, the pipeline run environment is configured
# with the current access token.
environment_vars[ENV_ZENML_STORE_PREFIX + "API_TOKEN"] = (
api_token.access_token
)
Expand Down
147 changes: 144 additions & 3 deletions src/zenml/zen_server/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
"""Authentication module for ZenML server."""

from contextvars import ContextVar
from datetime import datetime
from datetime import datetime, timedelta
from typing import Callable, Optional, Union
from urllib.parse import urlencode
from uuid import UUID

import requests
from fastapi import Depends
from fastapi import Depends, Response
from fastapi.security import (
HTTPBasic,
HTTPBasicCredentials,
Expand All @@ -37,7 +37,7 @@
LOGIN,
VERSION_1,
)
from zenml.enums import AuthScheme, OAuthDeviceStatus
from zenml.enums import AuthScheme, ExecutionStatus, OAuthDeviceStatus
from zenml.exceptions import (
AuthorizationException,
CredentialsNotValid,
Expand All @@ -51,6 +51,7 @@
ExternalUserModel,
OAuthDeviceInternalResponse,
OAuthDeviceInternalUpdate,
OAuthTokenResponse,
UserAuthModel,
UserRequest,
UserResponse,
Expand Down Expand Up @@ -329,6 +330,72 @@ def authenticate_credentials(
),
)

if decoded_token.schedule_id:
    # A schedule-scoped token is only valid while the schedule still exists
    # in the database.
    try:
        zen_store().get_schedule(decoded_token.schedule_id, hydrate=False)
    except KeyError:
        error = (
            f"Authentication error: error retrieving token schedule "
            f"{decoded_token.schedule_id}"
        )
        logger.error(error)
        raise CredentialsNotValid(error)

if decoded_token.pipeline_run_id:
    # A pipeline-run-scoped token is only valid while the pipeline run
    # exists in the database and has not concluded.
    try:
        pipeline_run = zen_store().get_run(
            decoded_token.pipeline_run_id, hydrate=False
        )
    except KeyError:
        error = (
            f"Authentication error: error retrieving token pipeline run "
            f"{decoded_token.pipeline_run_id}"
        )
        logger.error(error)
        raise CredentialsNotValid(error)

    if pipeline_run.status in [
        ExecutionStatus.FAILED,
        ExecutionStatus.COMPLETED,
    ]:
        error = (
            f"The execution of pipeline run "
            f"{decoded_token.pipeline_run_id} has already concluded and "
            "API tokens scoped to it are no longer valid."
        )
        # Reject the token: without this raise the message was built but
        # discarded and the token for a concluded run was still accepted.
        logger.error(error)
        raise CredentialsNotValid(error)

if decoded_token.step_run_id:
    # A step-run-scoped token is only valid while the step run exists in
    # the database and has not concluded.
    try:
        step_run = zen_store().get_run_step(
            decoded_token.step_run_id, hydrate=False
        )
    except KeyError:
        error = (
            f"Authentication error: error retrieving token step run "
            f"{decoded_token.step_run_id}"
        )
        logger.error(error)
        raise CredentialsNotValid(error)

    if step_run.status in [
        ExecutionStatus.FAILED,
        ExecutionStatus.COMPLETED,
    ]:
        error = (
            f"The execution of step run "
            f"{decoded_token.step_run_id} has already concluded and "
            "API tokens scoped to it are no longer valid."
        )
        # Reject the token: same missing-raise defect as for pipeline runs.
        logger.error(error)
        raise CredentialsNotValid(error)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I get this right, every time such a workload token is used to send a request to the server, this code is running, correct? Do you think this will cause any trouble at scale as this is running up to 3 additional database queries per request?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess I could implement a local time-based cache that would delay re-fetching the same object from the DB for up to 30 seconds. We should keep this contained to a limited number of entries though, otherwise the memory requirements will explode during high loads.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potentially, yeah, but hopefully not necessary for now. We're considerably reducing the number of server requests during pipeline execution in this and the next release, which means this hopefully doesn't matter as much.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@schustmi I added a simple in-memory cache for objects that we want to cache in the ZenML server

auth_context = AuthContext(
user=user_model,
access_token=decoded_token,
Expand Down Expand Up @@ -660,6 +727,80 @@ def authenticate_api_key(
return AuthContext(user=user_model, api_key=internal_api_key)


def generate_access_token(
    user_id: UUID,
    response: Optional[Response] = None,
    device: Optional[OAuthDeviceInternalResponse] = None,
    api_key: Optional[APIKeyInternalResponse] = None,
    expires_in: Optional[int] = None,
    schedule_id: Optional[UUID] = None,
    pipeline_run_id: Optional[UUID] = None,
    step_run_id: Optional[UUID] = None,
) -> OAuthTokenResponse:
    """Generates an access token for the given user.

    The token lifetime is resolved in this order: an explicit `expires_in`,
    then the expiration of the device used for authentication, then the
    `jwt_token_expire_minutes` server configuration value. If none of them
    yields a value, the token is encoded without an expiration time.

    Args:
        user_id: The ID of the user.
        response: The FastAPI response object. If supplied and no device is
            used, the token is also set as an HTTP-only cookie on it.
        device: The device used for authentication.
        api_key: The service account API key used for authentication.
        expires_in: The number of seconds until the token expires.
        schedule_id: The ID of the schedule to scope the token to.
        pipeline_run_id: The ID of the pipeline run to scope the token to.
        step_run_id: The ID of the step run to scope the token to.

    Returns:
        An authentication response with an access token.
    """
    config = server_config()

    # If the expiration time is not supplied, the JWT tokens are set to expire
    # according to the values configured in the server config. Device tokens
    # are handled separately from regular user tokens.
    expires: Optional[datetime] = None
    if expires_in:
        expires = datetime.utcnow() + timedelta(seconds=expires_in)
    elif device:
        # If a device was used for authentication, the token will expire
        # at the same time as the device.
        expires = device.expires
        if expires:
            # Clamp at zero so an already expired device never yields a
            # negative remaining lifetime.
            expires_in = max(
                int(expires.timestamp() - datetime.utcnow().timestamp()), 0
            )
    elif config.jwt_token_expire_minutes:
        expires = datetime.utcnow() + timedelta(
            minutes=config.jwt_token_expire_minutes
        )
        expires_in = config.jwt_token_expire_minutes * 60

    access_token = JWTToken(
        user_id=user_id,
        device_id=device.id if device else None,
        api_key_id=api_key.id if api_key else None,
        schedule_id=schedule_id,
        pipeline_run_id=pipeline_run_id,
        step_run_id=step_run_id,
    ).encode(expires=expires)

    if not device and response:
        # Also set the access token as an HTTP only cookie in the response.
        # The cookie lifetime follows the effective token lifetime, so an
        # explicitly requested `expires_in` is honored instead of always
        # falling back to the server-wide default. A missing or zero
        # lifetime results in a cookie without a max-age, as before.
        response.set_cookie(
            key=config.get_auth_cookie_name(),
            value=access_token,
            httponly=True,
            samesite="lax",
            max_age=expires_in or None,
            domain=config.auth_cookie_domain,
        )

    return OAuthTokenResponse(
        access_token=access_token, expires_in=expires_in, token_type="bearer"
    )


def http_authentication(
credentials: HTTPBasicCredentials = Depends(HTTPBasic()),
) -> AuthContext:
Expand Down
Loading
Loading