From 324c3ddf051aac94dcc2a7932c977bd0646ca2c2 Mon Sep 17 00:00:00 2001 From: DANGerous Date: Mon, 23 Jan 2023 20:58:00 -0800 Subject: [PATCH] [td] Support MySQL in dbt (#1809) * [td] 0.7.75 * [td] Add dbt mysql support * [td] undo version change * done --- Dockerfile | 1 + dev.Dockerfile | 1 + .../data_preparation/models/block/__init__.py | 2 +- .../models/block/dbt/utils/__init__.py | 85 ++++++- .../models/block/sql/__init__.py | 33 +++ .../models/block/sql/mysql.py | 65 +++++ mage_ai/data_preparation/models/pipeline.py | 11 +- .../templates/repo/io_config.yaml | 5 + mage_ai/io/base.py | 1 + mage_ai/io/config.py | 6 + mage_ai/io/export_utils.py | 8 +- mage_ai/io/mysql.py | 118 +++++++++ mage_ai/io/postgres.py | 237 ++---------------- mage_ai/io/sql/base.py | 226 +++++++++++++++++ .../destinations/mysql/utils.py | 2 +- requirements.txt | 5 +- setup.py | 5 +- 17 files changed, 580 insertions(+), 231 deletions(-) create mode 100644 mage_ai/data_preparation/models/block/sql/mysql.py create mode 100644 mage_ai/io/mysql.py create mode 100644 mage_ai/io/sql/base.py diff --git a/Dockerfile b/Dockerfile index f1697acd3200..31e6f694533b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,6 +6,7 @@ USER root # Install Mage RUN ${PIP} install --upgrade pip RUN ${PIP} install --no-cache "git+https://github.com/mage-ai/mage-ai.git#egg=mage-integrations&subdirectory=mage_integrations" +RUN ${PIP} install "git+https://github.com/mage-ai/dbt-mysql.git#egg=dbt-mysql" RUN ${PIP} install "git+https://github.com/mage-ai/singer-python.git#egg=singer-python" COPY ./mage_ai/server/constants.py constants.py RUN tag=$(tail -n 1 constants.py) && VERSION=$(echo $tag | tr -d "'") && ${PIP} install --no-cache "mage-ai[all]"==$VERSION diff --git a/dev.Dockerfile b/dev.Dockerfile index 0897e91ed00e..581e4986d838 100644 --- a/dev.Dockerfile +++ b/dev.Dockerfile @@ -17,6 +17,7 @@ COPY requirements.txt requirements.txt RUN ${PIP} install --upgrade pip COPY mage_integrations mage_integrations RUN ${PIP} install mage_integrations/ +RUN ${PIP} install "git+https://github.com/mage-ai/dbt-mysql.git#egg=dbt-mysql" RUN ${PIP} install "git+https://github.com/mage-ai/singer-python.git#egg=singer-python" RUN ${PIP} install -r requirements.txt RUN ${PIP} install jupyterlab diff --git a/mage_ai/data_preparation/models/block/__init__.py b/mage_ai/data_preparation/models/block/__init__.py index 6c90631554a7..fedaf1991bd0 100644 --- a/mage_ai/data_preparation/models/block/__init__.py +++ b/mage_ai/data_preparation/models/block/__init__.py @@ -306,7 +306,7 @@ def file(self): @property def table_name(self): - table_name = f'{self.pipeline.uuid}_{self.uuid}_{self.pipeline.version_name}' + table_name = f'{self.pipeline.uuid}_{clean_name_orig(self.uuid)}_{self.pipeline.version_name}' env = (self.global_vars or dict()).get('env') if env == ENV_DEV: diff --git a/mage_ai/data_preparation/models/block/dbt/utils/__init__.py b/mage_ai/data_preparation/models/block/dbt/utils/__init__.py index 623d235a2287..675c118e1687 100644 --- a/mage_ai/data_preparation/models/block/dbt/utils/__init__.py +++ b/mage_ai/data_preparation/models/block/dbt/utils/__init__.py @@ -4,6 +4,7 @@ from mage_ai.data_preparation.models.block.sql import ( bigquery, execute_sql_code as execute_sql_code_orig, + mysql, postgres, snowflake, ) @@ -16,7 +17,7 @@ from mage_ai.shared.array import find from mage_ai.shared.hash import merge_dict from mage_ai.shared.parsers import encode_complex -from mage_ai.shared.utils import files_in_path +from mage_ai.shared.utils import clean_name, 
files_in_path from pandas import DataFrame from typing import Callable, Dict, List, Tuple import os @@ -44,7 +45,14 @@ def parse_attributes(block) -> Dict: full_path = f'{get_repo_path()}/dbt/{file_path}' sources_full_path = re.sub(filename, 'mage_sources.yml', full_path) - source_name = f'mage_{project_name}' + profiles_full_path = f'{project_full_path}/profiles.yml' + profile_target = block.configuration.get('dbt_profile_target') + profile = load_profile(project_name, profiles_full_path, profile_target) + + if profile and 'mysql' == profile.get('type'): + source_name = profile['schema'] + else: + source_name = f'mage_{project_name}' return dict( file_extension=file_extension, @@ -52,7 +60,7 @@ def parse_attributes(block) -> Dict: filename=filename, full_path=full_path, model_name=model_name, - profiles_full_path=f'{project_full_path}/profiles.yml', + profiles_full_path=profiles_full_path, project_full_path=project_full_path, project_name=project_name, source_name=source_name, @@ -118,10 +126,15 @@ def load_sources(block) -> Dict: def source_table_name_for_block(block) -> str: - return f'{block.pipeline.uuid}_{block.uuid}' + return f'{clean_name(block.pipeline.uuid)}_{clean_name(block.uuid)}' -def update_model_settings(block, upstream_blocks, upstream_blocks_previous): +def update_model_settings( + block: 'Block', + upstream_blocks: List['Block'], + upstream_blocks_previous: List['Block'], + force_update: bool = False, +): attributes_dict = parse_attributes(block) filename = attributes_dict['filename'] @@ -130,7 +143,7 @@ def update_model_settings(block, upstream_blocks, upstream_blocks_previous): sources_full_path = attributes_dict['sources_full_path'] source_name = attributes_dict['source_name'] - if len(upstream_blocks_previous) > len(upstream_blocks): + if not force_update and len(upstream_blocks_previous) > len(upstream_blocks): # TODO (tommy dangerous): should we remove sources? # How do we know no other model is using a source? 
@@ -175,6 +188,8 @@ def update_model_settings(block, upstream_blocks, upstream_blocks_previous): ) settings = load_sources(block) + + print('WTFFFFFFFFFFFFFF', new_source) if settings: source = find(lambda x: x['name'] == source_name, settings.get('sources', [])) if source: @@ -201,7 +216,9 @@ def get_profile(block, profile_target: str = None) -> Dict: attr = parse_attributes(block) project_name = attr['project_name'] profiles_full_path = attr['profiles_full_path'] + return load_profile(project_name, profiles_full_path, profile_target) +def load_profile(project_name: str, profiles_full_path: str, profile_target: str = None) -> Dict: with open(profiles_full_path, 'r') as f: try: text = Template(f.read()).render( @@ -261,6 +278,27 @@ def config_file_loader_and_configuration(block, profile_target: str) -> Dict: data_provider_schema=schema, export_write_policy=ExportWritePolicy.REPLACE, ) + elif DataSource.MYSQL == profile_type: + host = profile.get('server') + password = profile.get('password') + port = profile.get('port') + schema = profile.get('schema') + ssl_disabled = profile.get('ssl_disabled') + username = profile.get('username') + + config_file_loader = ConfigFileLoader(config=dict( + MYSQL_CONNECTION_METHOD='ssh_tunnel' if not ssl_disabled else None, + MYSQL_DATABASE=schema, + MYSQL_HOST=host, + MYSQL_PASSWORD=password, + MYSQL_PORT=port, + MYSQL_USER=username, + )) + configuration = dict( + data_provider=profile_type, + data_provider_database=schema, + export_write_policy=ExportWritePolicy.REPLACE, + ) elif DataSource.SNOWFLAKE == profile_type: database = profile.get('database') schema = profile.get('schema') @@ -339,6 +377,18 @@ def create_upstream_tables( cache_upstream_dbt_models=cache_upstream_dbt_models, **kwargs, ) + elif DataSource.MYSQL == data_provider: + from mage_ai.io.mysql import MySQL + + with MySQL.with_config(config_file_loader) as loader: + mysql.create_upstream_block_tables( + loader, + block, + cascade_on_drop=True, + configuration=configuration, + cache_upstream_dbt_models=cache_upstream_dbt_models, + **kwargs, + ) elif DataSource.BIGQUERY == data_provider: from mage_ai.io.bigquery import BigQuery @@ -379,6 +429,9 @@ def __replace_func(db, schema, tn): if replace_func: return replace_func(db, schema, tn) + if db and not schema: + return f'{__quoted(db)}.{__quoted(tn)}' + return f'{__quoted(schema)}.{__quoted(tn)}' for idx, upstream_block in enumerate(block.upstream_blocks): @@ -387,8 +440,15 @@ def __replace_func(db, schema, tn): attrs = parse_attributes(upstream_block) model_name = attrs['model_name'] - matcher1 = f'{__quoted(profile_database)}.{__quoted(profile_schema)}.'\ - f'{__quoted(model_name)}' + + arr = [] + if profile_database: + arr.append(__quoted(profile_database)) + if profile_schema: + arr.append(__quoted(profile_schema)) + if model_name: + arr.append(__quoted(model_name)) + matcher1 = '.'.join(arr) database = configuration.get('data_provider_database') schema = configuration.get('data_provider_schema') @@ -426,6 +486,10 @@ def query_from_compiled_sql(block, profile_target: str) -> DataFrame: database = profile['dbname'] schema = profile['schema'] quote_str = '"' + elif DataSource.MYSQL == profile_type: + database = configuration['data_provider_database'] + schema = None + quote_str = '`' elif DataSource.BIGQUERY == profile_type: database = profile['project'] schema = profile['dataset'] @@ -448,6 +512,11 @@ def query_from_compiled_sql(block, profile_target: str) -> DataFrame: with Postgres.with_config(config_file_loader) as loader: return 
loader.load(query_string) + elif DataSource.MYSQL == data_provider: + from mage_ai.io.mysql import MySQL + + with MySQL.with_config(config_file_loader) as loader: + return loader.load(query_string) elif DataSource.BIGQUERY == data_provider: from mage_ai.io.bigquery import BigQuery diff --git a/mage_ai/data_preparation/models/block/sql/__init__.py b/mage_ai/data_preparation/models/block/sql/__init__.py index 514aa1680e60..f96d90600828 100644 --- a/mage_ai/data_preparation/models/block/sql/__init__.py +++ b/mage_ai/data_preparation/models/block/sql/__init__.py @@ -1,6 +1,7 @@ from mage_ai.data_preparation.models.block import Block from mage_ai.data_preparation.models.block.sql import ( bigquery, + mysql, postgres, redshift, snowflake, @@ -88,6 +89,38 @@ def execute_sql_code( except Exception as err: if '404' not in str(err): raise err + elif DataSource.MYSQL.value == data_provider: + from mage_ai.io.mysql import MySQL + + with MySQL.with_config(config_file_loader) as loader: + mysql.create_upstream_block_tables( + loader, + block, + configuration=configuration, + execution_partition=execution_partition, + ) + + query_string = mysql.interpolate_input_data(block, query) + query_string = interpolate_vars(query_string, global_vars=global_vars) + + loader.export( + None, + None, + table_name, + drop_table_on_replace=True, + if_exists=export_write_policy, + index=False, + query_string=query_string, + verbose=BlockType.DATA_EXPORTER == block.type, + ) + + if should_query: + return [ + loader.load( + f'SELECT * FROM {table_name}', + verbose=False, + ), + ] elif DataSource.POSTGRES.value == data_provider: from mage_ai.io.postgres import Postgres diff --git a/mage_ai/data_preparation/models/block/sql/mysql.py b/mage_ai/data_preparation/models/block/sql/mysql.py new file mode 100644 index 000000000000..169740bdca01 --- /dev/null +++ b/mage_ai/data_preparation/models/block/sql/mysql.py @@ -0,0 +1,65 @@ +from mage_ai.data_preparation.models.constants import BlockType +from mage_ai.data_preparation.variable_manager import get_variable +from mage_ai.data_preparation.models.block.sql.utils.shared import ( + interpolate_input, + should_cache_data_from_upstream, +) +from mage_ai.io.config import ConfigKey +from typing import Dict + + +def create_upstream_block_tables( + loader, + block, + cascade_on_drop: bool = False, + configuration: Dict = None, + execution_partition: str = None, + cache_upstream_dbt_models: bool = False, +): + from mage_ai.data_preparation.models.block.dbt.utils import ( + parse_attributes, + source_table_name_for_block, + ) + configuration = configuration if configuration else block.configuration + + for idx, upstream_block in enumerate(block.upstream_blocks): + if should_cache_data_from_upstream(block, upstream_block, [ + 'data_provider', + ], [ + ConfigKey.MYSQL_DATABASE, + ConfigKey.MYSQL_HOST, + ConfigKey.MYSQL_PORT, + ]): + if BlockType.DBT == upstream_block.type and not cache_upstream_dbt_models: + continue + + table_name = upstream_block.table_name + + df = get_variable( + upstream_block.pipeline.uuid, + upstream_block.uuid, + 'output_0', + partition=execution_partition, + ) + + if BlockType.DBT == block.type and BlockType.DBT != upstream_block.type: + table_name = source_table_name_for_block(upstream_block) + + loader.export( + df, + None, + table_name, + cascade_on_drop=cascade_on_drop, + drop_table_on_replace=True, + if_exists='replace', + index=False, + verbose=False, + ) + + +def interpolate_input_data(block, query): + return interpolate_input( + block, + query, + lambda db, 
schema, tn: tn, + ) diff --git a/mage_ai/data_preparation/models/pipeline.py b/mage_ai/data_preparation/models/pipeline.py index f2eff31fb4e8..7e6f776e93ae 100644 --- a/mage_ai/data_preparation/models/pipeline.py +++ b/mage_ai/data_preparation/models/pipeline.py @@ -1,7 +1,9 @@ from mage_ai.data_preparation.models.block import Block, run_blocks, run_blocks_sync +from mage_ai.data_preparation.models.block.dbt.utils import update_model_settings from mage_ai.data_preparation.models.block.errors import NoMultipleDynamicUpstreamBlocks from mage_ai.data_preparation.models.block.utils import is_dynamic_block from mage_ai.data_preparation.models.constants import ( + BlockLanguage, BlockType, ExecutorType, PipelineType, @@ -183,7 +185,11 @@ async def get_async(self, uuid, repo_path: str = None): if not os.path.exists(config_path): raise Exception(f'Pipeline {uuid} does not exist.') async with aiofiles.open(config_path, mode='r') as f: - config = yaml.safe_load(await f.read()) + try: + config = yaml.safe_load(await f.read()) or {} + except Exception as err: + config = {} + print(err) if PipelineType.INTEGRATION == config.get('type'): pipeline = IntegrationPipeline(uuid, repo_path=repo_path, config=config) @@ -503,6 +509,9 @@ async def update(self, data, update_content=False): block.configuration = configuration should_save = True + if BlockType.DBT == block.type and BlockLanguage.SQL == block.language: + update_model_settings(block, block.upstream_blocks, [], force_update=True) + if widget: keys_to_update = [] diff --git a/mage_ai/data_preparation/templates/repo/io_config.yaml b/mage_ai/data_preparation/templates/repo/io_config.yaml index 5ca4083f6624..16fff8dedf39 100644 --- a/mage_ai/data_preparation/templates/repo/io_config.yaml +++ b/mage_ai/data_preparation/templates/repo/io_config.yaml @@ -21,6 +21,11 @@ default: auth_provider_x509_cert_url: "https://www.googleapis.com/oauth2/v1/certs" client_x509_cert_url: "https://www.googleapis.com/robot/v1/metadata/x509/your_service_account_email" GOOGLE_SERVICE_ACC_KEY_FILEPATH: "/path/to/your/service/account/key.json" + MYSQL_DATABASE: database + MYSQL_HOST: host + MYSQL_PASSWORD: password + MYSQL_PORT: 3306 + MYSQL_USER: root POSTGRES_DBNAME: postgres POSTGRES_USER: username POSTGRES_PASSWORD: password diff --git a/mage_ai/io/base.py b/mage_ai/io/base.py index 784430c317c0..c8a08d2fc4a4 100644 --- a/mage_ai/io/base.py +++ b/mage_ai/io/base.py @@ -16,6 +16,7 @@ class DataSource(str, Enum): FILE = 'file' GOOGLE_CLOUD_STORAGE = 'google_cloud_storage' KAFKA = 'kafka' + MYSQL = 'mysql' OPENSEARCH = 'opensearch' POSTGRES = 'postgres' REDSHIFT = 'redshift' diff --git a/mage_ai/io/config.py b/mage_ai/io/config.py index 57cae2043b32..cd5c4cf88c1d 100644 --- a/mage_ai/io/config.py +++ b/mage_ai/io/config.py @@ -23,6 +23,12 @@ class ConfigKey(str, Enum): AZURE_TENANT_ID = 'AZURE_TENANT_ID' GOOGLE_SERVICE_ACC_KEY = 'GOOGLE_SERVICE_ACC_KEY' GOOGLE_SERVICE_ACC_KEY_FILEPATH = 'GOOGLE_SERVICE_ACC_KEY_FILEPATH' + MYSQL_CONNECTION_METHOD = 'MYSQL_CONNECTION_METHOD' + MYSQL_DATABASE = 'MYSQL_DATABASE' + MYSQL_HOST = 'MYSQL_HOST' + MYSQL_PASSWORD = 'MYSQL_PASSWORD' + MYSQL_PORT = 'MYSQL_PORT' + MYSQL_USER = 'MYSQL_USER' POSTGRES_DBNAME = 'POSTGRES_DBNAME' POSTGRES_USER = 'POSTGRES_USER' POSTGRES_PASSWORD = 'POSTGRES_PASSWORD' diff --git a/mage_ai/io/export_utils.py b/mage_ai/io/export_utils.py index 02914c72d10a..fcd4d61d62aa 100644 --- a/mage_ai/io/export_utils.py +++ b/mage_ai/io/export_utils.py @@ -98,4 +98,10 @@ def gen_table_creation_query( query = [] for cname in 
dtypes: query.append(f'"{clean_name(cname)}" {dtypes[cname]}') - return f'CREATE TABLE {schema_name}.{table_name} (' + ','.join(query) + ');' + + if schema_name: + full_table_name = f'{schema_name}.{table_name}' + else: + full_table_name = table_name + + return f'CREATE TABLE {full_table_name} (' + ','.join(query) + ');' diff --git a/mage_ai/io/mysql.py b/mage_ai/io/mysql.py new file mode 100644 index 000000000000..43d3f983b207 --- /dev/null +++ b/mage_ai/io/mysql.py @@ -0,0 +1,118 @@ +from mage_ai.io.config import BaseConfigLoader, ConfigKey +from mage_ai.io.export_utils import BadConversionError, PandasTypes +from mage_ai.io.sql.base import BaseSQL +from mage_ai.shared.utils import clean_name +from mysql.connector import connect +from pandas import DataFrame, Series +import numpy as np + + +class MySQL(BaseSQL): + def __init__( + self, + database: str, + host: str, + password: str, + user: str, + port: int = 3306, + verbose: bool = True, + **kwargs, + ) -> None: + super().__init__( + database=database, + host=host, + password=password, + port=port or 3306, + user=user, + verbose=verbose, + **kwargs, + ) + + @classmethod + def with_config(cls, config: BaseConfigLoader) -> 'MySQL': + return cls( + database=config[ConfigKey.MYSQL_DATABASE], + host=config[ConfigKey.MYSQL_HOST], + password=config[ConfigKey.MYSQL_PASSWORD], + port=config[ConfigKey.MYSQL_PORT], + user=config[ConfigKey.MYSQL_USER], + ) + + def build_create_table_command(self, dtypes, schema_name: str, table_name: str) -> str: + query = [] + for cname in dtypes: + query.append(f'`{clean_name(cname)}` {dtypes[cname]}') + + return f'CREATE TABLE {table_name} (' + ','.join(query) + ');' + + def open(self) -> None: + with self.printer.print_msg('Opening connection to MySQL database'): + self._ctx = connect(**self.settings) + + def table_exists(self, schema_name: str, table_name: str) -> bool: + with self.conn.cursor() as cur: + database_name = self.settings['database'] + cur.execute('\n'.join([ + 'SELECT * FROM information_schema.tables ', + f'WHERE table_schema = \'{database_name}\' AND table_name = \'{table_name}\'', + 'LIMIT 1', + ])) + return len(cur.fetchall()) >= 1 + + def upload_dataframe(self, cursor, df: DataFrame, full_table_name: str, buffer = None) -> None: + values_placeholder = ', '.join(["%s" for i in range(len(df.columns))]) + values = [] + for i, row in df.iterrows(): + values.append(tuple(row)) + + sql = f'INSERT INTO {full_table_name} VALUES ({values_placeholder})' + cursor.executemany(sql, values) + + def get_type(self, column: Series, dtype: str) -> str: + if dtype in ( + PandasTypes.MIXED, + PandasTypes.UNKNOWN_ARRAY, + PandasTypes.COMPLEX, + ): + raise BadConversionError( + f'Cannot convert column \'{column.name}\' with data type \'{dtype}\' to a MySQL datatype.' 
+ ) + elif dtype in (PandasTypes.DATETIME, PandasTypes.DATETIME64): + try: + if column.dt.tz: + return 'TIMESTAMP' + except AttributeError: + pass + return 'TIMESTAMP' + elif dtype == PandasTypes.TIME: + try: + if column.dt.tz: + return 'TIME' + except AttributeError: + pass + return 'TIME' + elif dtype == PandasTypes.DATE: + return 'DATE' + elif dtype == PandasTypes.STRING: + return 'CHAR(255)' + elif dtype == PandasTypes.CATEGORICAL: + return 'TEXT' + elif dtype == PandasTypes.BYTES: + return 'VARBINARY(255)' + elif dtype in (PandasTypes.FLOATING, PandasTypes.DECIMAL, PandasTypes.MIXED_INTEGER_FLOAT): + return 'DECIMAL' + elif dtype == PandasTypes.INTEGER: + max_int, min_int = column.max(), column.min() + if np.int16(max_int) == max_int and np.int16(min_int) == min_int: + return 'BIGINT' + elif np.int32(max_int) == max_int and np.int32(min_int) == min_int: + return 'BIGINT' + else: + return 'BIGINT' + elif dtype == PandasTypes.BOOLEAN: + return 'CHAR(52)' + elif dtype in (PandasTypes.TIMEDELTA, PandasTypes.TIMEDELTA64, PandasTypes.PERIOD): + return 'BIGINT' + else: + raise ValueError(f'Invalid datatype provided: {dtype}') + diff --git a/mage_ai/io/postgres.py b/mage_ai/io/postgres.py index 0d6f255d7d82..d096e3c19761 100644 --- a/mage_ai/io/postgres.py +++ b/mage_ai/io/postgres.py @@ -1,23 +1,15 @@ -from io import StringIO -from mage_ai.io.base import BaseSQLConnection, ExportWritePolicy, QUERY_ROW_LIMIT from mage_ai.io.config import BaseConfigLoader, ConfigKey -from mage_ai.io.export_utils import ( - BadConversionError, - clean_df_for_export, - gen_table_creation_query, - infer_dtypes, - PandasTypes, -) -from pandas import DataFrame, read_sql, Series +from mage_ai.io.export_utils import BadConversionError, PandasTypes +from mage_ai.io.sql.base import BaseSQL +from pandas import DataFrame, Series from psycopg2 import connect import numpy as np -class Postgres(BaseSQLConnection): +class Postgres(BaseSQL): """ Handles data transfer between a PostgreSQL database and the Mage app. """ - def __init__( self, dbname: str, @@ -49,210 +41,31 @@ def __init__( **kwargs, ) + @classmethod + def with_config(cls, config: BaseConfigLoader) -> 'Postgres': + return cls( + dbname=config[ConfigKey.POSTGRES_DBNAME], + user=config[ConfigKey.POSTGRES_USER], + password=config[ConfigKey.POSTGRES_PASSWORD], + host=config[ConfigKey.POSTGRES_HOST], + port=config[ConfigKey.POSTGRES_PORT], + ) + def open(self) -> None: - """ - Opens a connection to the PostgreSQL database specified by the parameters. - """ with self.printer.print_msg('Opening connection to PostgreSQL database'): self._ctx = connect(**self.settings, keepalives=1, keepalives_idle=300, ) - def execute(self, query_string: str, **query_vars) -> None: - """ - Sends query to the connected database. - - Args: - query_string (str): SQL query string to apply on the connected database. - query_vars: Variable values to fill in when using format strings in query. - """ - with self.printer.print_msg(f'Executing query \'{query_string}\''): - query_string = self._clean_query(query_string) - with self.conn.cursor() as cur: - cur.execute(query_string, **query_vars) - - def load( - self, - query_string: str, - limit: int = QUERY_ROW_LIMIT, - display_query: str = None, - verbose: bool = True, - **kwargs, - ) -> DataFrame: - """ - Loads data from the connected database into a Pandas data frame based on the query given. - This will fail if the query returns no data from the database. This function will load at - maximum 10,000,000 rows of data. 
To operate on more data, consider performing data - transformations in warehouse. - - Args: - query_string (str): Query to execute on the database. - limit (int, Optional): The number of rows to limit the loaded dataframe to. Defaults to 10,000,000. - **kwargs: Additional query parameters. - - Returns: - DataFrame: The data frame corresponding to the data returned by the given query. - """ - print_message = 'Loading data' - if verbose: - print_message += ' with query' - - if display_query: - for line in display_query.split('\n'): - print_message += f'\n{line}' - else: - print_message += f'\n{query_string}' - - query_string = self._clean_query(query_string) - - with self.printer.print_msg(print_message): - return read_sql(self._enforce_limit(query_string, limit), self.conn, **kwargs) - - def export( - self, - df: DataFrame, - schema_name: str, - table_name: str, - if_exists: ExportWritePolicy = ExportWritePolicy.REPLACE, - index: bool = False, - verbose: bool = True, - query_string: str = None, - drop_table_on_replace: bool = False, - cascade_on_drop: bool = False, - ) -> None: - """ - Exports dataframe to the connected database from a Pandas data frame. If table doesn't - exist, the table is automatically created. If the schema doesn't exist, the schema is also created. - - Args: - schema_name (str): Name of the schema of the table to export data to. - table_name (str): Name of the table to insert rows from this data frame into. - if_exists (ExportWritePolicy): Specifies export policy if table exists. Either - - `'fail'`: throw an error. - - `'replace'`: drops existing table and creates new table of same name. - - `'append'`: appends data frame to existing table. In this case the schema must match the original table. - Defaults to `'replace'`. - index (bool): If true, the data frame index is also exported alongside the table. Defaults to False. - **kwargs: Additional query parameters. - """ - - full_table_name = f'{schema_name}.{table_name}' - - if not query_string: - if index: - df = df.reset_index() - - dtypes = infer_dtypes(df) - df = clean_df_for_export(df, self.clean, dtypes) - - def __process(): - buffer = StringIO() - table_exists = self.__table_exists(schema_name, table_name) - - with self.conn.cursor() as cur: - cur.execute(f'CREATE SCHEMA IF NOT EXISTS {schema_name};') - - should_create_table = not table_exists - - if table_exists: - if ExportWritePolicy.FAIL == if_exists: - raise ValueError( - f'Table \'{full_table_name}\' already exists in database.' 
- ) - elif ExportWritePolicy.REPLACE == if_exists: - if drop_table_on_replace: - cmd = f'DROP TABLE {full_table_name}' - if cascade_on_drop: - cmd = f'{cmd} CASCADE' - cur.execute(cmd) - should_create_table = True - else: - cur.execute(f'DELETE FROM {full_table_name}') - - if query_string: - query = 'CREATE TABLE {} AS\n{}'.format( - full_table_name, - query_string, - ) - - if ExportWritePolicy.APPEND == if_exists and table_exists: - query = 'INSERT INTO {}\n{}'.format( - full_table_name, - query_string, - ) - cur.execute(query) - else: - if should_create_table: - db_dtypes = {col: self.get_type(df[col], dtypes[col]) for col in dtypes} - query = gen_table_creation_query(db_dtypes, schema_name, table_name) - cur.execute(query) - - df.to_csv(buffer, index=False, header=False) - buffer.seek(0) - cur.copy_expert( - f'COPY {full_table_name} FROM STDIN (FORMAT csv, DELIMITER \',\', NULL \'\');', - buffer, - ) - self.conn.commit() - - if verbose: - with self.printer.print_msg( - f'Exporting data to \'{full_table_name}\'' - ): - __process() - else: - __process() - - def __table_exists(self, schema_name: str, table_name: str) -> bool: - """ - Returns whether the specified table exists. - - Args: - schema_name (str): Name of the schema the table belongs to. - table_name (str): Name of the table to check existence of. - - Returns: - bool: True if the table exists, else False. - """ + def table_exists(self, schema_name: str, table_name: str) -> bool: with self.conn.cursor() as cur: cur.execute( f'SELECT * FROM pg_tables WHERE schemaname = \'{schema_name}\' AND tablename = \'{table_name}\'' ) return bool(cur.rowcount) - def clean(self, column: Series, dtype: str) -> Series: - """ - Cleans column in order to write data frame to PostgreSQL database - - Args: - column (Series): Column to clean - dtype (str): The pandas data types of this column - - Returns: - Series: Cleaned column - """ - if dtype == PandasTypes.CATEGORICAL: - return column.astype(str) - elif dtype in (PandasTypes.TIMEDELTA, PandasTypes.TIMEDELTA64, PandasTypes.PERIOD): - return column.view(int) - else: - return column - def get_type(self, column: Series, dtype: str) -> str: - """ - Maps pandas Data Frame column to PostgreSQL type - - Args: - series (Series): Column to map - dtype (str): Pandas data type of this column - - Raises: - ConversionError: Returned if this type cannot be converted to a PostgreSQL data type - - Returns: - str: PostgreSQL data type for this column - """ if dtype in ( PandasTypes.MIXED, PandasTypes.UNKNOWN_ARRAY, @@ -300,18 +113,10 @@ def get_type(self, column: Series, dtype: str) -> str: else: raise ValueError(f'Invalid datatype provided: {dtype}') - @classmethod - def with_config(cls, config: BaseConfigLoader) -> 'Postgres': - """ - Initializes PostgreSQL loader from configuration loader - - Args: - config (BaseConfigLoader): Configuration loader object - """ - return cls( - dbname=config[ConfigKey.POSTGRES_DBNAME], - user=config[ConfigKey.POSTGRES_USER], - password=config[ConfigKey.POSTGRES_PASSWORD], - host=config[ConfigKey.POSTGRES_HOST], - port=config[ConfigKey.POSTGRES_PORT], + def upload_dataframe(self, cursor, df: DataFrame, full_table_name: str, buffer = None) -> None: + df.to_csv(buffer, index=False, header=False) + buffer.seek(0) + cursor.copy_expert( + f'COPY {full_table_name} FROM STDIN (FORMAT csv, DELIMITER \',\', NULL \'\');', + buffer, ) diff --git a/mage_ai/io/sql/base.py b/mage_ai/io/sql/base.py new file mode 100644 index 000000000000..00cbbd947973 --- /dev/null +++ b/mage_ai/io/sql/base.py @@ 
-0,0 +1,226 @@ +from io import StringIO +from mage_ai.io.base import BaseSQLConnection, ExportWritePolicy, QUERY_ROW_LIMIT +from mage_ai.io.config import BaseConfigLoader +from mage_ai.io.export_utils import ( + BadConversionError, + clean_df_for_export, + gen_table_creation_query, + infer_dtypes, + PandasTypes, +) +from pandas import DataFrame, read_sql, Series + + +class BaseSQL(BaseSQLConnection): + @classmethod + def with_config(cls, config: BaseConfigLoader) -> 'Postgres': + """ + Initializes SQL loader from configuration loader + + Args: + config (BaseConfigLoader): Configuration loader object + """ + raise Exception('Subclasses must override this method.') + + def get_type(self, column: Series, dtype: str) -> str: + """ + Maps pandas Data Frame column to SQL type + + Args: + series (Series): Column to map + dtype (str): Pandas data type of this column + + Raises: + ConversionError: Returned if this type cannot be converted to a SQL data type + + Returns: + str: SQL data type for this column + """ + raise Exception('Subclasses must override this method.') + + def build_create_table_command(self, dtypes, schema_name: str, table_name: str) -> str: + return gen_table_creation_query(dtypes, schema_name, table_name) + + def open(self) -> None: + """ + Opens a connection to the SQL database specified by the parameters. + """ + raise Exception('Subclasses must override this method.') + + def table_exists(self, schema_name: str, table_name: str) -> bool: + """ + Returns whether the specified table exists. + + Args: + schema_name (str): Name of the schema the table belongs to. + table_name (str): Name of the table to check existence of. + + Returns: + bool: True if the table exists, else False. + """ + raise Exception('Subclasses must override this method.') + + def upload_dataframe(self, cursor, df: DataFrame, full_table_name: str, buffer = None) -> None: + raise Exception('Subclasses must override this method.') + + def execute(self, query_string: str, **query_vars) -> None: + """ + Sends query to the connected database. + + Args: + query_string (str): SQL query string to apply on the connected database. + query_vars: Variable values to fill in when using format strings in query. + """ + with self.printer.print_msg(f'Executing query \'{query_string}\''): + query_string = self._clean_query(query_string) + with self.conn.cursor() as cur: + cur.execute(query_string, **query_vars) + + def load( + self, + query_string: str, + limit: int = QUERY_ROW_LIMIT, + display_query: str = None, + verbose: bool = True, + **kwargs, + ) -> DataFrame: + """ + Loads data from the connected database into a Pandas data frame based on the query given. + This will fail if the query returns no data from the database. This function will load at + maximum 10,000,000 rows of data. To operate on more data, consider performing data + transformations in warehouse. + + Args: + query_string (str): Query to execute on the database. + limit (int, Optional): The number of rows to limit the loaded dataframe to. Defaults to 10,000,000. + **kwargs: Additional query parameters. + + Returns: + DataFrame: The data frame corresponding to the data returned by the given query. 
+ """ + print_message = 'Loading data' + if verbose: + print_message += ' with query' + + if display_query: + for line in display_query.split('\n'): + print_message += f'\n{line}' + else: + print_message += f'\n\n{query_string}\n\n' + + query_string = self._clean_query(query_string) + + with self.printer.print_msg(print_message): + return read_sql(self._enforce_limit(query_string, limit), self.conn, **kwargs) + + def export( + self, + df: DataFrame, + schema_name: str, + table_name: str, + if_exists: ExportWritePolicy = ExportWritePolicy.REPLACE, + index: bool = False, + verbose: bool = True, + query_string: str = None, + drop_table_on_replace: bool = False, + cascade_on_drop: bool = False, + ) -> None: + """ + Exports dataframe to the connected database from a Pandas data frame. If table doesn't + exist, the table is automatically created. If the schema doesn't exist, the schema is also created. + + Args: + schema_name (str): Name of the schema of the table to export data to. + table_name (str): Name of the table to insert rows from this data frame into. + if_exists (ExportWritePolicy): Specifies export policy if table exists. Either + - `'fail'`: throw an error. + - `'replace'`: drops existing table and creates new table of same name. + - `'append'`: appends data frame to existing table. In this case the schema must match the original table. + Defaults to `'replace'`. + index (bool): If true, the data frame index is also exported alongside the table. Defaults to False. + **kwargs: Additional query parameters. + """ + + if schema_name: + full_table_name = f'{schema_name}.{table_name}' + else: + full_table_name = table_name + + if not query_string: + if index: + df = df.reset_index() + + dtypes = infer_dtypes(df) + df = clean_df_for_export(df, self.clean, dtypes) + + def __process(): + buffer = StringIO() + table_exists = self.table_exists(schema_name, table_name) + + with self.conn.cursor() as cur: + if schema_name: + cur.execute(f'CREATE SCHEMA IF NOT EXISTS {schema_name};') + + should_create_table = not table_exists + + if table_exists: + if ExportWritePolicy.FAIL == if_exists: + raise ValueError( + f'Table \'{full_table_name}\' already exists in database.' 
+ ) + elif ExportWritePolicy.REPLACE == if_exists: + if drop_table_on_replace: + cmd = f'DROP TABLE {full_table_name}' + if cascade_on_drop: + cmd = f'{cmd} CASCADE' + cur.execute(cmd) + should_create_table = True + else: + cur.execute(f'DELETE FROM {full_table_name}') + + if query_string: + query = 'CREATE TABLE {} AS\n{}'.format( + full_table_name, + query_string, + ) + + if ExportWritePolicy.APPEND == if_exists and table_exists: + query = 'INSERT INTO {}\n{}'.format( + full_table_name, + query_string, + ) + cur.execute(query) + else: + if should_create_table: + db_dtypes = {col: self.get_type(df[col], dtypes[col]) for col in dtypes} + query = self.build_create_table_command(db_dtypes, schema_name, table_name) + cur.execute(query) + + self.upload_dataframe(cur, df, full_table_name, buffer) + self.conn.commit() + + if verbose: + with self.printer.print_msg( + f'Exporting data to \'{full_table_name}\'' + ): + __process() + else: + __process() + + def clean(self, column: Series, dtype: str) -> Series: + """ + Cleans column in order to write data frame to PostgreSQL database + + Args: + column (Series): Column to clean + dtype (str): The pandas data types of this column + + Returns: + Series: Cleaned column + """ + if dtype == PandasTypes.CATEGORICAL: + return column.astype(str) + elif dtype in (PandasTypes.TIMEDELTA, PandasTypes.TIMEDELTA64, PandasTypes.PERIOD): + return column.view(int) + else: + return column diff --git a/mage_integrations/mage_integrations/destinations/mysql/utils.py b/mage_integrations/mage_integrations/destinations/mysql/utils.py index b4707a4bd50c..73fc017ef772 100644 --- a/mage_integrations/mage_integrations/destinations/mysql/utils.py +++ b/mage_integrations/mage_integrations/destinations/mysql/utils.py @@ -65,7 +65,7 @@ def convert_column_type(column_type: str, column_settings: Dict, **kwargs) -> st if COLUMN_TYPE_BOOLEAN == column_type: return 'CHAR(52)' elif COLUMN_TYPE_INTEGER == column_type: - return 'UNSIGNED' + return 'BIGINT' elif COLUMN_TYPE_NUMBER == column_type: return 'DECIMAL' elif COLUMN_TYPE_OBJECT == column_type: diff --git a/requirements.txt b/requirements.txt index bb3f60179281..01c0b0778803 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,8 +26,8 @@ aiofiles==22.1.0 # extras aws-secretsmanager-caching==1.1.1.5 -botocore==1.27.19 boto3==1.24.19 +botocore==1.27.19 db-dtypes==1.0.5 google-api-python-client==2.70.0 google-cloud-bigquery~=2.34.4 @@ -35,6 +35,7 @@ google-cloud-iam==2.9.0 google-cloud-run==0.5.0 google-cloud-storage==2.5.0 kubernetes==25.3.0 +mysql-connector-python==8.0.31 psycopg2-binary==2.9.3 redshift-connector==2.0.909 snowflake-connector-python==2.7.9 @@ -53,7 +54,7 @@ opensearch-py==2.0.0 requests_aws4auth==1.1.2 # DBT +dbt-bigquery==1.3.0 dbt-core==1.3.0 dbt-postgres==1.3.0 -dbt-bigquery==1.3.0 dbt-snowflake==1.3.0 diff --git a/setup.py b/setup.py index d6ac66ef89fe..343a71edbe41 100644 --- a/setup.py +++ b/setup.py @@ -51,13 +51,16 @@ def readme(): 'db-dtypes==1.0.2', ], 'dbt': [ + 'dbt-bigquery==1.3.0', 'dbt-core==1.3.0', 'dbt-postgres==1.3.0', - 'dbt-bigquery==1.3.0', 'dbt-snowflake==1.3.0', ], 'google-cloud-storage': ['google-cloud-storage==2.5.0'], 'hdf5': ['tables==3.7.0'], + 'mysql': [ + 'mysql-connector-python==8.0.31', + ], 'postgres': ['psycopg2-binary==2.9.3'], 'redshift': ['boto3==1.24.19', 'redshift-connector==2.0.909'], 's3': ['botocore==1.27.19', 'boto3==1.24.19'],
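
For context: the changes above factor the shared load/export logic out of mage_ai/io/postgres.py into a new BaseSQL class (mage_ai/io/sql/base.py), so the MySQL loader in mage_ai/io/mysql.py only overrides open, table_exists, upload_dataframe, build_create_table_command, and get_type. On the dbt side, config_file_loader_and_configuration reads the server, port, schema, username, password, and ssl_disabled fields from the project's profiles.yml and maps them onto the new MYSQL_* config keys. A mysql profile target might look like the sketch below; the field names are the ones the new code reads, while the project name and credential values are placeholders:

my_dbt_project:
  target: dev
  outputs:
    dev:
      type: mysql
      server: 127.0.0.1
      port: 3306
      schema: analytics        # mapped to MYSQL_DATABASE by the code above
      username: root
      password: password
      ssl_disabled: true

And a minimal sketch of exercising the new MySQL IO loader directly, assuming the MYSQL_* keys added to the io_config.yaml template are filled in under the default profile; the config path, table name, and sample data here are hypothetical:

from mage_ai.io.config import ConfigFileLoader
from mage_ai.io.mysql import MySQL
from pandas import DataFrame

# Load the MYSQL_* keys from the default profile of io_config.yaml
# (the path shown here is hypothetical).
config = ConfigFileLoader('default_repo/io_config.yaml', 'default')

df = DataFrame({'id': [1, 2], 'name': ['a', 'b']})

with MySQL.with_config(config) as loader:
    # schema_name is None: the loader writes into the configured MYSQL_DATABASE
    # rather than a separate schema, mirroring the block/sql/mysql.py code above.
    loader.export(df, None, 'example_table', if_exists='replace', index=False)
    print(loader.load('SELECT * FROM example_table'))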