feat(providers/amazon): Use asset in common provider (apache#43110)
Lee-W authored Oct 17, 2024
1 parent c7b9c8d commit a2b32d6
Showing 7 changed files with 89 additions and 151 deletions.
2 changes: 1 addition & 1 deletion generated/provider_dependencies.json
@@ -26,7 +26,7 @@
"amazon": {
"deps": [
"PyAthena>=3.0.10",
"apache-airflow-providers-common-compat>=1.1.0",
"apache-airflow-providers-common-compat>=1.2.1",
"apache-airflow-providers-common-sql>=1.3.1",
"apache-airflow-providers-http",
"apache-airflow>=2.8.0",
14 changes: 7 additions & 7 deletions providers/src/airflow/providers/amazon/aws/assets/s3.py
@@ -19,16 +19,14 @@
from typing import TYPE_CHECKING

from airflow.providers.amazon.aws.hooks.s3 import S3Hook

try:
from airflow.assets import Asset
except ModuleNotFoundError:
from airflow.datasets import Dataset as Asset # type: ignore[no-redef]
from airflow.providers.common.compat.assets import Asset

if TYPE_CHECKING:
from urllib.parse import SplitResult

from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset
from airflow.providers.common.compat.openlineage.facet import (
Dataset as OpenLineageDataset,
)


def create_asset(*, bucket: str, key: str, extra=None) -> Asset:
@@ -43,7 +41,7 @@ def sanitize_uri(uri: SplitResult) -> SplitResult:

def convert_asset_to_openlineage(asset: Asset, lineage_context) -> OpenLineageDataset:
"""Translate Asset with valid AIP-60 uri to OpenLineage with assistance from the hook."""
from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset
from airflow.providers.common.compat.openlineage.facet import (
Dataset as OpenLineageDataset,
)

bucket, key = S3Hook.parse_s3_url(asset.uri)
return OpenLineageDataset(namespace=f"s3://{bucket}", name=key if key else "/")
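For context only, not part of the commit's diff: the version guard that the removed try/except above implemented, and that airflow.providers.common.compat.assets presumably centralizes (hence the common-compat>=1.2.1 pin bumped in this commit), looks roughly like this minimal sketch:

try:
    from airflow.assets import Asset  # Airflow versions that expose the asset API
except ModuleNotFoundError:
    from airflow.datasets import Dataset as Asset  # type: ignore[no-redef]  # older dataset API

Centralizing the guard means every provider imports Asset from a single place, and the dependency pin is what guarantees the shim is available.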
106 changes: 77 additions & 29 deletions providers/src/airflow/providers/amazon/aws/hooks/s3.py
@@ -41,14 +41,16 @@
from uuid import uuid4

if TYPE_CHECKING:
from mypy_boto3_s3.service_resource import Bucket as S3Bucket, Object as S3ResourceObject
from mypy_boto3_s3.service_resource import (
Bucket as S3Bucket,
Object as S3ResourceObject,
)

from airflow.utils.types import ArgNotSet

with suppress(ImportError):
from aiobotocore.client import AioBaseClient

from importlib.util import find_spec

from asgiref.sync import sync_to_async
from boto3.s3.transfer import S3Transfer, TransferConfig
@@ -58,15 +60,9 @@
from airflow.providers.amazon.aws.exceptions import S3HookUriParseFailure
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.utils.tags import format_tags
from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector
from airflow.utils.helpers import chunks

if find_spec("airflow.assets"):
from airflow.lineage.hook import get_hook_lineage_collector
else:
# TODO: import from common.compat directly after common.compat providers with
# asset_compat_lineage_collector released
from airflow.providers.amazon.aws.utils.asset_compat_lineage_collector import get_hook_lineage_collector

logger = logging.getLogger(__name__)


@@ -148,9 +144,10 @@ def wrapper(*args, **kwargs) -> Callable:

if "bucket_name" not in bound_args.arguments:
with suppress(S3HookUriParseFailure):
bound_args.arguments["bucket_name"], bound_args.arguments[key_name] = S3Hook.parse_s3_url(
bound_args.arguments[key_name]
)
(
bound_args.arguments["bucket_name"],
bound_args.arguments[key_name],
) = S3Hook.parse_s3_url(bound_args.arguments[key_name])

return func(*bound_args.args, **bound_args.kwargs)

@@ -318,7 +315,8 @@ def check_for_bucket(self, bucket_name: str | None = None) -> bool:
self.log.info('Bucket "%s" does not exist', bucket_name)
elif return_code == 403:
self.log.error(
'Access to bucket "%s" is forbidden or there was an error with the request', bucket_name
'Access to bucket "%s" is forbidden or there was an error with the request',
bucket_name,
)
self.log.error(e)
return False
@@ -359,7 +357,8 @@ def create_bucket(self, bucket_name: str | None = None, region_name: str | None
self.get_conn().create_bucket(Bucket=bucket_name)
else:
self.get_conn().create_bucket(
Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": region_name}
Bucket=bucket_name,
CreateBucketConfiguration={"LocationConstraint": region_name},
)

@provide_bucket_name
@@ -410,7 +409,10 @@ def list_prefixes(

paginator = self.get_conn().get_paginator("list_objects_v2")
response = paginator.paginate(
Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter, PaginationConfig=config
Bucket=bucket_name,
Prefix=prefix,
Delimiter=delimiter,
PaginationConfig=config,
)

prefixes: list[str] = []
@@ -471,7 +473,10 @@ async def list_prefixes_async(

paginator = client.get_paginator("list_objects_v2")
response = paginator.paginate(
Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter, PaginationConfig=config
Bucket=bucket_name,
Prefix=prefix,
Delimiter=delimiter,
PaginationConfig=config,
)

prefixes = []
@@ -569,7 +574,11 @@ async def check_key_async(
return await self._check_key_async(client, bucket, wildcard_match, bucket_keys, use_regex)

async def check_for_prefix_async(
self, client: AioBaseClient, prefix: str, delimiter: str, bucket_name: str | None = None
self,
client: AioBaseClient,
prefix: str,
delimiter: str,
bucket_name: str | None = None,
) -> bool:
"""
Check that a prefix exists in a bucket.
@@ -587,7 +596,11 @@ async def check_for_prefix_async(
return prefix in plist

async def _check_for_prefix_async(
self, client: AioBaseClient, prefix: str, delimiter: str, bucket_name: str | None = None
self,
client: AioBaseClient,
prefix: str,
delimiter: str,
bucket_name: str | None = None,
) -> bool:
return await self.check_for_prefix_async(
client, prefix=prefix, delimiter=delimiter, bucket_name=bucket_name
@@ -643,7 +656,10 @@ async def _list_keys_async(

paginator = client.get_paginator("list_objects_v2")
response = paginator.paginate(
Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter, PaginationConfig=config
Bucket=bucket_name,
Prefix=prefix,
Delimiter=delimiter,
PaginationConfig=config,
)

keys = []
@@ -655,7 +671,10 @@ async def _list_keys_async(
return keys

def _list_key_object_filter(
self, keys: list, from_datetime: datetime | None = None, to_datetime: datetime | None = None
self,
keys: list,
from_datetime: datetime | None = None,
to_datetime: datetime | None = None,
) -> list:
def _is_in_period(input_date: datetime) -> bool:
if from_datetime is not None and input_date <= from_datetime:
@@ -766,7 +785,10 @@ async def is_keys_unchanged_async(
"message": success_message,
}

self.log.error("FAILURE: Inactivity Period passed, not enough objects found in %s", path)
self.log.error(
"FAILURE: Inactivity Period passed, not enough objects found in %s",
path,
)
return {
"status": "error",
"message": f"FAILURE: Inactivity Period passed, not enough objects found in {path}",
@@ -1109,7 +1131,13 @@ def load_file(
extra_args["ACL"] = acl_policy

client = self.get_conn()
client.upload_file(filename, bucket_name, key, ExtraArgs=extra_args, Config=self.transfer_config)
client.upload_file(
filename,
bucket_name,
key,
ExtraArgs=extra_args,
Config=self.transfer_config,
)
get_hook_lineage_collector().add_input_asset(
context=self, scheme="file", asset_kwargs={"path": filename}
)
Expand Down Expand Up @@ -1308,18 +1336,32 @@ def copy_object(
)

source_bucket_name, source_bucket_key = self.get_s3_bucket_key(
source_bucket_name, source_bucket_key, "source_bucket_name", "source_bucket_key"
source_bucket_name,
source_bucket_key,
"source_bucket_name",
"source_bucket_key",
)

copy_source = {"Bucket": source_bucket_name, "Key": source_bucket_key, "VersionId": source_version_id}
copy_source = {
"Bucket": source_bucket_name,
"Key": source_bucket_key,
"VersionId": source_version_id,
}
response = self.get_conn().copy_object(
Bucket=dest_bucket_name, Key=dest_bucket_key, CopySource=copy_source, **kwargs
Bucket=dest_bucket_name,
Key=dest_bucket_key,
CopySource=copy_source,
**kwargs,
)
get_hook_lineage_collector().add_input_asset(
context=self, scheme="s3", asset_kwargs={"bucket": source_bucket_name, "key": source_bucket_key}
context=self,
scheme="s3",
asset_kwargs={"bucket": source_bucket_name, "key": source_bucket_key},
)
get_hook_lineage_collector().add_output_asset(
context=self, scheme="s3", asset_kwargs={"bucket": dest_bucket_name, "key": dest_bucket_key}
context=self,
scheme="s3",
asset_kwargs={"bucket": dest_bucket_name, "key": dest_bucket_key},
)
return response

@@ -1435,7 +1477,10 @@ def download_file(
file_path = Path(local_dir, subdir, filename_in_s3)

if file_path.is_file():
self.log.error("file '%s' already exists. Failing the task and not overwriting it", file_path)
self.log.error(
"file '%s' already exists. Failing the task and not overwriting it",
file_path,
)
raise FileExistsError

file_path.parent.mkdir(exist_ok=True, parents=True)
@@ -1484,7 +1529,10 @@ def generate_presigned_url(
s3_client = self.get_conn()
try:
return s3_client.generate_presigned_url(
ClientMethod=client_method, Params=params, ExpiresIn=expires_in, HttpMethod=http_method
ClientMethod=client_method,
Params=params,
ExpiresIn=expires_in,
HttpMethod=http_method,
)

except ClientError as e:
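The other substantive change in hooks/s3.py is dropping the find_spec("airflow.assets") guard and the provider-local asset_compat_lineage_collector fallback in favour of a single import from common.compat. A minimal sketch of the resulting usage pattern, mirroring the add_input_asset/add_output_asset calls visible in copy_object() above — the standalone S3Hook() instance and the bucket/key values are placeholders for illustration:

from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector

hook = S3Hook()  # inside the hook's own methods this is simply `self`
collector = get_hook_lineage_collector()

# The source object is reported as an input asset and the destination as an
# output asset, as the calls added to copy_object() do.
collector.add_input_asset(
    context=hook, scheme="s3", asset_kwargs={"bucket": "src-bucket", "key": "raw/data.csv"}
)
collector.add_output_asset(
    context=hook, scheme="s3", asset_kwargs={"bucket": "dst-bucket", "key": "clean/data.csv"}
)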

This file was deleted.

2 changes: 1 addition & 1 deletion providers/src/airflow/providers/amazon/provider.yaml
@@ -93,7 +93,7 @@ versions:

dependencies:
- apache-airflow>=2.8.0
- apache-airflow-providers-common-compat>=1.1.0
- apache-airflow-providers-common-compat>=1.2.1
- apache-airflow-providers-common-sql>=1.3.1
- apache-airflow-providers-http
# We should update minimum version of boto3 and here regularly to avoid `pip` backtracking with the number
