applying pep8 style on yapf
igorborgest committed Oct 17, 2019
1 parent 7d68fc4 commit 3867a39
Showing 21 changed files with 1,658 additions and 1,568 deletions.
1 change: 1 addition & 0 deletions .flake8
@@ -1,2 +1,3 @@
[flake8]
max-line-length = 120
ignore = E501,E126,W503
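
The added ignore list turns off checks that commonly clash with yapf's output: E501 (line too long), E126 (continuation line over-indented for a hanging indent) and W503 (line break before a binary operator). With this file at the repository root, a lint pass is presumably just the following (the target path is an assumption, not part of the commit):

flake8 awswrangler/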
4 changes: 4 additions & 0 deletions .style.yapf
@@ -0,0 +1,4 @@
[style]
BASED_ON_STYLE = pep8
SPACES_BEFORE_COMMENT = 2
COLUMN_LIMIT = 120
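
yapf picks up a .style.yapf file from the project root automatically, so with the settings above a repository-wide formatting pass would look something like this (the exact command and target path are assumed, not recorded in the commit):

pip install yapf
yapf --in-place --recursive awswrangler/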
45 changes: 22 additions & 23 deletions awswrangler/athena.py
@@ -16,17 +16,18 @@ class Athena:
def __init__(self, session):
self._session = session
self._client_athena = session.boto3_session.client(
service_name="athena", config=session.botocore_config)
service_name="athena", config=session.botocore_config
)

def get_query_columns_metadata(self, query_execution_id):
response = self._client_athena.get_query_results(
QueryExecutionId=query_execution_id, MaxResults=1)
QueryExecutionId=query_execution_id, MaxResults=1
)
col_info = response["ResultSet"]["ResultSetMetadata"]["ColumnInfo"]
return {x["Name"]: x["Type"] for x in col_info}

def get_query_dtype(self, query_execution_id):
cols_metadata = self.get_query_columns_metadata(
query_execution_id=query_execution_id)
cols_metadata = self.get_query_columns_metadata(query_execution_id=query_execution_id)
logger.debug(f"cols_metadata: {cols_metadata}")
dtype = {}
parse_timestamps = []
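
For context, the metadata feeding the dict comprehension in get_query_columns_metadata comes from Athena's get_query_results response under ResultSet.ResultSetMetadata.ColumnInfo; a minimal illustration with made-up values:

col_info = [
    {"Name": "user_id", "Type": "bigint"},
    {"Name": "created_at", "Type": "timestamp"},
]
print({x["Name"]: x["Type"] for x in col_info})
# {'user_id': 'bigint', 'created_at': 'timestamp'}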
@@ -53,10 +54,11 @@ def create_athena_bucket(self):
:return: Bucket s3 path (E.g. s3://aws-athena-query-results-ACCOUNT-REGION/)
"""
account_id = (self._session.boto3_session.client(
service_name="sts",
config=self._session.botocore_config).get_caller_identity().get(
"Account"))
account_id = (
self._session.boto3_session.client(
service_name="sts", config=self._session.botocore_config
).get_caller_identity().get("Account")
)
session_region = self._session.boto3_session.region_name
s3_output = f"s3://aws-athena-query-results-{account_id}-{session_region}/"
s3_resource = self._session.boto3_session.resource("s3")
@@ -82,7 +84,8 @@ def run_query(self, query, database, s3_output=None, workgroup=None):
QueryString=query,
QueryExecutionContext={"Database": database},
ResultConfiguration={"OutputLocation": s3_output},
WorkGroup=workgroup)
WorkGroup=workgroup
)
return response["QueryExecutionId"]

def wait_query(self, query_execution_id):
@@ -93,24 +96,20 @@ def wait_query(self, query_execution_id):
:return: Query response
"""
final_states = ["FAILED", "SUCCEEDED", "CANCELLED"]
response = self._client_athena.get_query_execution(
QueryExecutionId=query_execution_id)
response = self._client_athena.get_query_execution(QueryExecutionId=query_execution_id)
state = response["QueryExecution"]["Status"]["State"]
while state not in final_states:
sleep(QUERY_WAIT_POLLING_DELAY)
response = self._client_athena.get_query_execution(
QueryExecutionId=query_execution_id)
response = self._client_athena.get_query_execution(QueryExecutionId=query_execution_id)
state = response["QueryExecution"]["Status"]["State"]
logger.debug(f"state: {state}")
logger.debug(
f"StateChangeReason: {response['QueryExecution']['Status'].get('StateChangeReason')}"
)
if state == "FAILED":
raise QueryFailed(
response["QueryExecution"]["Status"].get("StateChangeReason"))
raise QueryFailed(response["QueryExecution"]["Status"].get("StateChangeReason"))
elif state == "CANCELLED":
raise QueryCancelled(
response["QueryExecution"]["Status"].get("StateChangeReason"))
raise QueryCancelled(response["QueryExecution"]["Status"].get("StateChangeReason"))
return response
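
The wait_query loop above is a plain polling pattern over Athena's get_query_execution API; a self-contained sketch of the same idea with raw boto3 (the polling delay is a placeholder, not taken from this commit):

import time
import boto3

athena = boto3.client("athena")

def wait_for_query(query_execution_id, delay=0.25):
    # Poll until the query reaches a terminal state, then return the last response
    terminal = {"SUCCEEDED", "FAILED", "CANCELLED"}
    while True:
        response = athena.get_query_execution(QueryExecutionId=query_execution_id)
        state = response["QueryExecution"]["Status"]["State"]
        if state in terminal:
            return response
        time.sleep(delay)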

def repair_table(self, database, table, s3_output=None, workgroup=None):
@@ -130,17 +129,17 @@ def repair_table(self, database, table, s3_output=None, workgroup=None):
:return: Query execution ID
"""
query = f"MSCK REPAIR TABLE {table};"
query_id = self.run_query(query=query,
database=database,
s3_output=s3_output,
workgroup=workgroup)
query_id = self.run_query(
query=query, database=database, s3_output=s3_output, workgroup=workgroup
)
self.wait_query(query_execution_id=query_id)
return query_id

@staticmethod
def _normalize_name(name):
name = "".join(c for c in unicodedata.normalize("NFD", name)
if unicodedata.category(c) != "Mn")
name = "".join(
c for c in unicodedata.normalize("NFD", name) if unicodedata.category(c) != "Mn"
)
name = name.replace(" ", "_")
name = name.replace("-", "_")
name = name.replace(".", "_")
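
The _normalize_name helper strips accents by decomposing characters (NFD) and discarding the combining marks (Unicode category "Mn") before swapping separators for underscores; a standalone illustration of the same trick:

import unicodedata

def strip_accents(name):
    # "ç" decomposes to "c" plus a combining cedilla; dropping category "Mn" keeps only "c"
    return "".join(c for c in unicodedata.normalize("NFD", name) if unicodedata.category(c) != "Mn")

print(strip_accents("preço médio"))  # prints "preco medio"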
43 changes: 25 additions & 18 deletions awswrangler/cloudwatchlogs.py
@@ -13,14 +13,17 @@ class CloudWatchLogs:
def __init__(self, session):
self._session = session
self._client_logs = session.boto3_session.client(
service_name="logs", config=session.botocore_config)
service_name="logs", config=session.botocore_config
)

def start_query(self,
query,
log_group_names,
start_time=datetime(year=1970, month=1, day=1),
end_time=datetime.utcnow(),
limit=None):
def start_query(
self,
query,
log_group_names,
start_time=datetime(year=1970, month=1, day=1),
end_time=datetime.utcnow(),
limit=None
):
"""
Run a query against AWS CloudWatchLogs Insights and wait the results
https://docs.aws.amazon.com/AmazonCloudWatch/latest/logs/CWL_QuerySyntax.html
@@ -69,12 +72,14 @@ def wait_query(self, query_id):
raise QueryCancelled(f"query ID: {query_id}")
return response

def query(self,
query,
log_group_names,
start_time=datetime(year=1970, month=1, day=1),
end_time=datetime.utcnow(),
limit=None):
def query(
self,
query,
log_group_names,
start_time=datetime(year=1970, month=1, day=1),
end_time=datetime.utcnow(),
limit=None
):
"""
Run a query against AWS CloudWatchLogs Insights and wait the results
https://docs.aws.amazon.com/AmazonCloudWatch/latest/logs/CWL_QuerySyntax.html
@@ -86,10 +91,12 @@ def query(self,
:param limit: The maximum number of log events to return in the query.
:return: Results
"""
query_id = self.start_query(query=query,
log_group_names=log_group_names,
start_time=start_time,
end_time=end_time,
limit=limit)
query_id = self.start_query(
query=query,
log_group_names=log_group_names,
start_time=start_time,
end_time=end_time,
limit=limit
)
response = self.wait_query(query_id=query_id)
return response["results"]
28 changes: 10 additions & 18 deletions awswrangler/data_types.py
@@ -163,8 +163,7 @@ def pyarrow2athena(dtype: pa.types) -> str:
elif dtype_str.startswith("list"):
return f"array<{pyarrow2athena(dtype.value_type)}>"
elif dtype_str == "null":
raise UndetectedType(
"We can't infer the data type from an entire null object column")
raise UndetectedType("We can't infer the data type from an entire null object column")
else:
raise UnsupportedType(f"Unsupported Pyarrow type: {dtype}")

@@ -283,8 +282,7 @@ def spark2redshift(dtype: str) -> str:
raise UnsupportedType("Unsupported Spark type: " + dtype)


def convert_schema(func: Callable,
schema: List[Tuple[str, str]]) -> Dict[str, str]:
def convert_schema(func: Callable, schema: List[Tuple[str, str]]) -> Dict[str, str]:
"""
Convert schema in the format of {"col name": "bigint", "col2 name": "int"}
applying some data types conversion function (e.g. spark2redshift)
@@ -296,10 +294,9 @@ def convert_schema(func: Callable,
return {name: func(dtype) for name, dtype in schema}


def extract_pyarrow_schema_from_pandas(dataframe: pd.DataFrame,
preserve_index: bool,
indexes_position: str = "right"
) -> List[Tuple[str, str]]:
def extract_pyarrow_schema_from_pandas(
dataframe: pd.DataFrame, preserve_index: bool, indexes_position: str = "right"
) -> List[Tuple[str, str]]:
"""
Extract the related Pyarrow schema from any Pandas DataFrame
@@ -323,8 +320,7 @@ def extract_pyarrow_schema_from_pandas(dataframe: pd.DataFrame,

# Filling cols_dtypes and indexes
indexes = []
for field in pa.Schema.from_pandas(df=dataframe[cols],
preserve_index=preserve_index):
for field in pa.Schema.from_pandas(df=dataframe[cols], preserve_index=preserve_index):
name = str(field.name)
dtype = field.type
cols_dtypes[name] = dtype
@@ -333,15 +329,11 @@

# Filling schema
if indexes_position == "right":
schema = [(name, cols_dtypes[name])
for name in dataframe.columns] # adding columns
schema += [(name, cols_dtypes[name])
for name in indexes] # adding indexes
schema = [(name, cols_dtypes[name]) for name in dataframe.columns] # adding columns
schema += [(name, cols_dtypes[name]) for name in indexes] # adding indexes
elif indexes_position == "left":
schema = [(name, cols_dtypes[name])
for name in indexes] # adding indexes
schema += [(name, cols_dtypes[name])
for name in dataframe.columns] # adding columns
schema = [(name, cols_dtypes[name]) for name in indexes] # adding indexes
schema += [(name, cols_dtypes[name]) for name in dataframe.columns] # adding columns
else:
raise ValueError(f"indexes_position must be \"right\" or \"left\"")

