[td] Support MySQL in dbt (mage-ai#1809)
* [td] 0.7.75

* [td] Add dbt mysql support

* [td] undo version change

* done
tommydangerous authored Jan 24, 2023
1 parent 7ec589b commit 324c3dd
Showing 17 changed files with 580 additions and 231 deletions.
1 change: 1 addition & 0 deletions Dockerfile
@@ -6,6 +6,7 @@ USER root
# Install Mage
RUN ${PIP} install --upgrade pip
RUN ${PIP} install --no-cache "git+https://github.com/mage-ai/mage-ai.git#egg=mage-integrations&subdirectory=mage_integrations"
RUN ${PIP} install "git+https://github.com/mage-ai/dbt-mysql.git#egg=dbt-mysql"
RUN ${PIP} install "git+https://github.com/mage-ai/singer-python.git#egg=singer-python"
COPY ./mage_ai/server/constants.py constants.py
RUN tag=$(tail -n 1 constants.py) && VERSION=$(echo $tag | tr -d "'") && ${PIP} install --no-cache "mage-ai[all]"==$VERSION
1 change: 1 addition & 0 deletions dev.Dockerfile
@@ -17,6 +17,7 @@ COPY requirements.txt requirements.txt
RUN ${PIP} install --upgrade pip
COPY mage_integrations mage_integrations
RUN ${PIP} install mage_integrations/
RUN ${PIP} install "git+https://github.com/mage-ai/dbt-mysql.git#egg=dbt-mysql"
RUN ${PIP} install "git+https://github.com/mage-ai/singer-python.git#egg=singer-python"
RUN ${PIP} install -r requirements.txt
RUN ${PIP} install jupyterlab
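
Both images now install the mage-ai fork of dbt-mysql directly from GitHub, next to the existing singer-python fork. A minimal sanity check for a built image, assuming the fork keeps the standard dbt adapter layout (the module path dbt.adapters.mysql is an assumption, not something this diff confirms):

import importlib

# Raises ModuleNotFoundError if the adapter did not install into the image.
for module_name in ('dbt', 'dbt.adapters.mysql'):
    importlib.import_module(module_name)
print('dbt-mysql adapter is importable')
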
2 changes: 1 addition & 1 deletion mage_ai/data_preparation/models/block/__init__.py
@@ -306,7 +306,7 @@ def file(self):

@property
def table_name(self):
- table_name = f'{self.pipeline.uuid}_{self.uuid}_{self.pipeline.version_name}'
+ table_name = f'{self.pipeline.uuid}_{clean_name_orig(self.uuid)}_{self.pipeline.version_name}'

env = (self.global_vars or dict()).get('env')
if env == ENV_DEV:
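
The block UUID now passes through clean_name_orig (presumably clean_name from mage_ai.shared.utils imported under another name) so block names containing spaces, dashes, or mixed case still produce valid SQL table identifiers. A rough, hedged approximation of that helper's behavior:

import re

def clean_name(name: str) -> str:
    # Lowercase and collapse anything that is not a word character into underscores.
    return re.sub(r'\W+', '_', name).lower()

print(clean_name('My Block-2'))  # my_block_2
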
85 changes: 77 additions & 8 deletions mage_ai/data_preparation/models/block/dbt/utils/__init__.py
@@ -4,6 +4,7 @@
from mage_ai.data_preparation.models.block.sql import (
bigquery,
execute_sql_code as execute_sql_code_orig,
mysql,
postgres,
snowflake,
)
@@ -16,7 +17,7 @@
from mage_ai.shared.array import find
from mage_ai.shared.hash import merge_dict
from mage_ai.shared.parsers import encode_complex
- from mage_ai.shared.utils import files_in_path
+ from mage_ai.shared.utils import clean_name, files_in_path
from pandas import DataFrame
from typing import Callable, Dict, List, Tuple
import os
@@ -44,15 +45,22 @@ def parse_attributes(block) -> Dict:
full_path = f'{get_repo_path()}/dbt/{file_path}'
sources_full_path = re.sub(filename, 'mage_sources.yml', full_path)

- source_name = f'mage_{project_name}'
+ profiles_full_path = f'{project_full_path}/profiles.yml'
+ profile_target = block.configuration.get('dbt_profile_target')
+ profile = load_profile(project_name, profiles_full_path, profile_target)
+
+ if profile and 'mysql' == profile.get('type'):
+     source_name = profile['schema']
+ else:
+     source_name = f'mage_{project_name}'

return dict(
file_extension=file_extension,
file_path=file_path,
filename=filename,
full_path=full_path,
model_name=model_name,
- profiles_full_path=f'{project_full_path}/profiles.yml',
+ profiles_full_path=profiles_full_path,
project_full_path=project_full_path,
project_name=project_name,
source_name=source_name,
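
For MySQL profiles, the dbt source name now comes from the profile's schema instead of the usual mage_<project_name> convention, because dbt-mysql treats schema as the database and MySQL has no separate schema namespace. A condensed sketch of the selection logic, with an illustrative profile dict:

project_name = 'demo_project'
profile = {'type': 'mysql', 'schema': 'analytics'}  # as parsed from profiles.yml

if profile and 'mysql' == profile.get('type'):
    source_name = profile['schema']       # sources land in the MySQL database itself
else:
    source_name = f'mage_{project_name}'  # every other adapter keeps the old convention

print(source_name)  # analytics
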
@@ -118,10 +126,15 @@ def load_sources(block) -> Dict:


def source_table_name_for_block(block) -> str:
- return f'{block.pipeline.uuid}_{block.uuid}'
+ return f'{clean_name(block.pipeline.uuid)}_{clean_name(block.uuid)}'


- def update_model_settings(block, upstream_blocks, upstream_blocks_previous):
+ def update_model_settings(
+     block: 'Block',
+     upstream_blocks: List['Block'],
+     upstream_blocks_previous: List['Block'],
+     force_update: bool = False,
+ ):
attributes_dict = parse_attributes(block)

filename = attributes_dict['filename']
@@ -130,7 +143,7 @@ def update_model_settings(block, upstream_blocks, upstream_blocks_previous):
sources_full_path = attributes_dict['sources_full_path']
source_name = attributes_dict['source_name']

- if len(upstream_blocks_previous) > len(upstream_blocks):
+ if not force_update and len(upstream_blocks_previous) > len(upstream_blocks):
# TODO (tommy dangerous): should we remove sources?
# How do we know no other model is using a source?

@@ -175,6 +188,8 @@ def update_model_settings(block, upstream_blocks, upstream_blocks_previous):
)

settings = load_sources(block)

if settings:
source = find(lambda x: x['name'] == source_name, settings.get('sources', []))
if source:
@@ -201,7 +216,9 @@ def get_profile(block, profile_target: str = None) -> Dict:
attr = parse_attributes(block)
project_name = attr['project_name']
profiles_full_path = attr['profiles_full_path']
+ return load_profile(project_name, profiles_full_path, profile_target)
+
+ def load_profile(project_name: str, profiles_full_path: str, profile_target: str = None) -> Dict:
with open(profiles_full_path, 'r') as f:
try:
text = Template(f.read()).render(
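
get_profile is now a thin wrapper around the extracted load_profile, so a rendered profiles.yml can be read by path without a block object (parse_attributes relies on this above to pick the source name). A hedged usage sketch; the project name and path are illustrative:

from mage_ai.data_preparation.models.block.dbt.utils import load_profile

profile = load_profile(
    'demo_project',
    '/home/src/default_repo/dbt/demo_project/profiles.yml',
    profile_target='dev',
)
if profile and 'mysql' == profile.get('type'):
    print(profile['schema'])
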
@@ -261,6 +278,27 @@ def config_file_loader_and_configuration(block, profile_target: str) -> Dict:
data_provider_schema=schema,
export_write_policy=ExportWritePolicy.REPLACE,
)
elif DataSource.MYSQL == profile_type:
host = profile.get('server')
password = profile.get('password')
port = profile.get('port')
schema = profile.get('schema')
ssl_disabled = profile.get('ssl_disabled')
username = profile.get('username')

config_file_loader = ConfigFileLoader(config=dict(
MYSQL_CONNECTION_METHOD='ssh_tunnel' if not ssl_disabled else None,
MYSQL_DATABASE=schema,
MYSQL_HOST=host,
MYSQL_PASSWORD=password,
MYSQL_PORT=port,
MYSQL_USER=username,
))
configuration = dict(
data_provider=profile_type,
data_provider_database=schema,
export_write_policy=ExportWritePolicy.REPLACE,
)
elif DataSource.SNOWFLAKE == profile_type:
database = profile.get('database')
schema = profile.get('schema')
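
The new MySQL branch translates dbt profile keys into Mage's MYSQL_* config keys: dbt-mysql's server becomes MYSQL_HOST and its schema doubles as MYSQL_DATABASE. A worked example of the mapping, with an illustrative profile target:

profile = dict(  # one target from profiles.yml, parsed to a dict
    type='mysql',
    server='127.0.0.1',
    port=3306,
    schema='analytics',
    username='root',
    password='password',
    ssl_disabled=True,
)

config = dict(
    MYSQL_CONNECTION_METHOD='ssh_tunnel' if not profile['ssl_disabled'] else None,
    MYSQL_DATABASE=profile['schema'],  # schema doubles as the database
    MYSQL_HOST=profile['server'],
    MYSQL_PASSWORD=profile['password'],
    MYSQL_PORT=profile['port'],
    MYSQL_USER=profile['username'],
)
print(config['MYSQL_DATABASE'])  # analytics
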
@@ -339,6 +377,18 @@ def create_upstream_tables(
cache_upstream_dbt_models=cache_upstream_dbt_models,
**kwargs,
)
elif DataSource.MYSQL == data_provider:
from mage_ai.io.mysql import MySQL

with MySQL.with_config(config_file_loader) as loader:
mysql.create_upstream_block_tables(
loader,
block,
cascade_on_drop=True,
configuration=configuration,
cache_upstream_dbt_models=cache_upstream_dbt_models,
**kwargs,
)
elif DataSource.BIGQUERY == data_provider:
from mage_ai.io.bigquery import BigQuery

@@ -379,6 +429,9 @@ def __replace_func(db, schema, tn):
if replace_func:
return replace_func(db, schema, tn)

if db and not schema:
return f'{__quoted(db)}.{__quoted(tn)}'

return f'{__quoted(schema)}.{__quoted(tn)}'

for idx, upstream_block in enumerate(block.upstream_blocks):
@@ -387,8 +440,15 @@ def __replace_func(db, schema, tn):

attrs = parse_attributes(upstream_block)
model_name = attrs['model_name']
- matcher1 = f'{__quoted(profile_database)}.{__quoted(profile_schema)}.'\
-     f'{__quoted(model_name)}'
+
+ arr = []
+ if profile_database:
+     arr.append(__quoted(profile_database))
+ if profile_schema:
+     arr.append(__quoted(profile_schema))
+ if model_name:
+     arr.append(__quoted(model_name))
+ matcher1 = '.'.join(arr)

database = configuration.get('data_provider_database')
schema = configuration.get('data_provider_schema')
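
Both of these hunks make the reference rewriting tolerate two-part naming: MySQL qualifies tables as database.table with no schema level, so any missing component is skipped when the matcher and replacement strings are built. A self-contained sketch of the idea; quoted mimics the diff's __quoted helper, and the names are illustrative:

def quoted(name: str, quote_str: str = '`') -> str:
    return f'{quote_str}{name}{quote_str}'

def qualify(db, schema, tn):
    if db and not schema:                    # MySQL: `database`.`table`
        return f'{quoted(db)}.{quoted(tn)}'
    return f'{quoted(schema)}.{quoted(tn)}'  # three-level providers keep schema.table

matcher = '.'.join(quoted(part) for part in ('analytics', None, 'my_model') if part)
print(qualify('analytics', None, 'orders'))  # `analytics`.`orders`
print(matcher)                               # `analytics`.`my_model`
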
@@ -426,6 +486,10 @@ def query_from_compiled_sql(block, profile_target: str) -> DataFrame:
database = profile['dbname']
schema = profile['schema']
quote_str = '"'
elif DataSource.MYSQL == profile_type:
database = configuration['data_provider_database']
schema = None
quote_str = '`'
elif DataSource.BIGQUERY == profile_type:
database = profile['project']
schema = profile['dataset']
@@ -448,6 +512,11 @@ def query_from_compiled_sql(block, profile_target: str) -> DataFrame:

with Postgres.with_config(config_file_loader) as loader:
return loader.load(query_string)
elif DataSource.MYSQL == data_provider:
from mage_ai.io.mysql import MySQL

with MySQL.with_config(config_file_loader) as loader:
return loader.load(query_string)
elif DataSource.BIGQUERY == data_provider:
from mage_ai.io.bigquery import BigQuery

33 changes: 33 additions & 0 deletions mage_ai/data_preparation/models/block/sql/__init__.py
@@ -1,6 +1,7 @@
from mage_ai.data_preparation.models.block import Block
from mage_ai.data_preparation.models.block.sql import (
bigquery,
mysql,
postgres,
redshift,
snowflake,
@@ -88,6 +89,38 @@ def execute_sql_code(
except Exception as err:
if '404' not in str(err):
raise err
elif DataSource.MYSQL.value == data_provider:
from mage_ai.io.mysql import MySQL

with MySQL.with_config(config_file_loader) as loader:
mysql.create_upstream_block_tables(
loader,
block,
configuration=configuration,
execution_partition=execution_partition,
)

query_string = mysql.interpolate_input_data(block, query)
query_string = interpolate_vars(query_string, global_vars=global_vars)

loader.export(
None,
None,
table_name,
drop_table_on_replace=True,
if_exists=export_write_policy,
index=False,
query_string=query_string,
verbose=BlockType.DATA_EXPORTER == block.type,
)

if should_query:
return [
loader.load(
f'SELECT * FROM {table_name}',
verbose=False,
),
]
elif DataSource.POSTGRES.value == data_provider:
from mage_ai.io.postgres import Postgres

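The MySQL branch mirrors the other providers: upstream outputs are materialized as tables, variables are interpolated into the query, the result is written with loader.export, and, when should_query is set, read back for display. The export call passes None for both the DataFrame and the schema because the query runs server-side and MySQL has no schema level. A condensed, hedged sketch; connection settings are the same illustrative values as in the preview sketch above:

from mage_ai.io.config import ConfigFileLoader
from mage_ai.io.mysql import MySQL

config_file_loader = ConfigFileLoader(config=dict(
    MYSQL_DATABASE='analytics',
    MYSQL_HOST='127.0.0.1',
    MYSQL_PASSWORD='password',
    MYSQL_PORT=3306,
    MYSQL_USER='root',
))

with MySQL.with_config(config_file_loader) as loader:
    loader.export(
        None,                      # no DataFrame: query_string drives the write
        None,                      # no schema level on MySQL
        'demo_pipeline_block_v1',  # block.table_name (illustrative)
        drop_table_on_replace=True,
        if_exists='replace',
        index=False,
        query_string='SELECT 1 AS value',  # the interpolated block SQL
    )
    preview = loader.load('SELECT * FROM demo_pipeline_block_v1', verbose=False)
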
65 changes: 65 additions & 0 deletions mage_ai/data_preparation/models/block/sql/mysql.py
@@ -0,0 +1,65 @@
from mage_ai.data_preparation.models.constants import BlockType
from mage_ai.data_preparation.variable_manager import get_variable
from mage_ai.data_preparation.models.block.sql.utils.shared import (
interpolate_input,
should_cache_data_from_upstream,
)
from mage_ai.io.config import ConfigKey
from typing import Dict


def create_upstream_block_tables(
loader,
block,
cascade_on_drop: bool = False,
configuration: Dict = None,
execution_partition: str = None,
cache_upstream_dbt_models: bool = False,
):
from mage_ai.data_preparation.models.block.dbt.utils import (
parse_attributes,
source_table_name_for_block,
)
configuration = configuration if configuration else block.configuration

for idx, upstream_block in enumerate(block.upstream_blocks):
if should_cache_data_from_upstream(block, upstream_block, [
'data_provider',
], [
ConfigKey.MYSQL_DATABASE,
ConfigKey.MYSQL_HOST,
ConfigKey.MYSQL_PORT,
]):
if BlockType.DBT == upstream_block.type and not cache_upstream_dbt_models:
continue

table_name = upstream_block.table_name

df = get_variable(
upstream_block.pipeline.uuid,
upstream_block.uuid,
'output_0',
partition=execution_partition,
)

if BlockType.DBT == block.type and BlockType.DBT != upstream_block.type:
table_name = source_table_name_for_block(upstream_block)

loader.export(
df,
None,
table_name,
cascade_on_drop=cascade_on_drop,
drop_table_on_replace=True,
if_exists='replace',
index=False,
verbose=False,
)


def interpolate_input_data(block, query):
return interpolate_input(
block,
query,
lambda db, schema, tn: tn,
)
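
This new module backs both plain SQL blocks and the dbt integration: create_upstream_block_tables loads each upstream block's output_0 from the variable manager and writes it to MySQL, naming the table with source_table_name_for_block when the downstream block is a dbt model, and skipping dbt upstreams unless cache_upstream_dbt_models is set. interpolate_input_data resolves upstream references to bare table names, since MySQL needs no database or schema qualifier. A tiny, hedged demonstration of that naming rule (inputs illustrative):

# The callback handed to interpolate_input keeps only the table name.
def replace_func(db, schema, tn):
    return tn

print(replace_func('analytics', None, 'demo_pipeline_upstream_block'))
# demo_pipeline_upstream_block
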
11 changes: 10 additions & 1 deletion mage_ai/data_preparation/models/pipeline.py
@@ -1,7 +1,9 @@
from mage_ai.data_preparation.models.block import Block, run_blocks, run_blocks_sync
from mage_ai.data_preparation.models.block.dbt.utils import update_model_settings
from mage_ai.data_preparation.models.block.errors import NoMultipleDynamicUpstreamBlocks
from mage_ai.data_preparation.models.block.utils import is_dynamic_block
from mage_ai.data_preparation.models.constants import (
BlockLanguage,
BlockType,
ExecutorType,
PipelineType,
@@ -183,7 +185,11 @@ async def get_async(self, uuid, repo_path: str = None):
if not os.path.exists(config_path):
raise Exception(f'Pipeline {uuid} does not exist.')
async with aiofiles.open(config_path, mode='r') as f:
- config = yaml.safe_load(await f.read())
+ try:
+     config = yaml.safe_load(await f.read()) or {}
+ except Exception as err:
+     config = {}
+     print(err)

if PipelineType.INTEGRATION == config.get('type'):
pipeline = IntegrationPipeline(uuid, repo_path=repo_path, config=config)
@@ -503,6 +509,9 @@ async def update(self, data, update_content=False):
block.configuration = configuration
should_save = True

if BlockType.DBT == block.type and BlockLanguage.SQL == block.language:
update_model_settings(block, block.upstream_blocks, [], force_update=True)

if widget:
keys_to_update = []

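Two behavior changes in pipeline.py: a corrupt or empty metadata.yaml no longer crashes get_async (the config falls back to an empty dict), and saving a pipeline now force-refreshes model settings for every SQL-language dbt block so profiles.yml and mage_sources.yml stay in sync with the latest configuration. A standalone sketch of the defensive load pattern, with an illustrative path:

import yaml

try:
    with open('pipelines/demo/metadata.yaml') as f:  # illustrative path
        config = yaml.safe_load(f.read()) or {}      # empty file -> None -> {}
except Exception as err:
    config = {}
    print(err)  # log and continue instead of raising
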
5 changes: 5 additions & 0 deletions mage_ai/data_preparation/templates/repo/io_config.yaml
@@ -21,6 +21,11 @@ default:
auth_provider_x509_cert_url: "https://www.googleapis.com/oauth2/v1/certs"
client_x509_cert_url: "https://www.googleapis.com/robot/v1/metadata/x509/your_service_account_email"
GOOGLE_SERVICE_ACC_KEY_FILEPATH: "/path/to/your/service/account/key.json"
MYSQL_DATABASE: database
MYSQL_HOST: host
MYSQL_PASSWORD: password
MYSQL_PORT: 3306
MYSQL_USER: root
POSTGRES_DBNAME: postgres
POSTGRES_USER: username
POSTGRES_PASSWORD: password
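
The repo template now ships MySQL connection defaults under the default profile. A hedged sketch of consuming them from a block, following the usual io_config pattern (the repo path is illustrative):

from mage_ai.io.config import ConfigFileLoader
from mage_ai.io.mysql import MySQL
from os import path

config_path = path.join('default_repo', 'io_config.yaml')  # illustrative repo path

with MySQL.with_config(ConfigFileLoader(config_path, 'default')) as loader:
    print(loader.load('SELECT 1'))
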
1 change: 1 addition & 0 deletions mage_ai/io/base.py
@@ -16,6 +16,7 @@ class DataSource(str, Enum):
FILE = 'file'
GOOGLE_CLOUD_STORAGE = 'google_cloud_storage'
KAFKA = 'kafka'
MYSQL = 'mysql'
OPENSEARCH = 'opensearch'
POSTGRES = 'postgres'
REDSHIFT = 'redshift'
6 changes: 6 additions & 0 deletions mage_ai/io/config.py
@@ -23,6 +23,12 @@ class ConfigKey(str, Enum):
AZURE_TENANT_ID = 'AZURE_TENANT_ID'
GOOGLE_SERVICE_ACC_KEY = 'GOOGLE_SERVICE_ACC_KEY'
GOOGLE_SERVICE_ACC_KEY_FILEPATH = 'GOOGLE_SERVICE_ACC_KEY_FILEPATH'
MYSQL_CONNECTION_METHOD = 'MYSQL_CONNECTION_METHOD'
MYSQL_DATABASE = 'MYSQL_DATABASE'
MYSQL_HOST = 'MYSQL_HOST'
MYSQL_PASSWORD = 'MYSQL_PASSWORD'
MYSQL_PORT = 'MYSQL_PORT'
MYSQL_USER = 'MYSQL_USER'
POSTGRES_DBNAME = 'POSTGRES_DBNAME'
POSTGRES_USER = 'POSTGRES_USER'
POSTGRES_PASSWORD = 'POSTGRES_PASSWORD'