diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index f53d35aa6..53fe596fb 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -91,6 +91,7 @@ jobs: with: python-version: '3.9' - run: pip install -r ./src/test/python/requirements.txt + - run: pip install -r ./src/test/python/sqlalchemy/requirements.txt - run: pip install -r ./src/test/python/pg8000/requirements.txt - uses: actions/setup-node@v3 with: diff --git a/.github/workflows/units.yaml b/.github/workflows/units.yaml index 7f598a0fe..81d8a7d11 100644 --- a/.github/workflows/units.yaml +++ b/.github/workflows/units.yaml @@ -26,6 +26,7 @@ jobs: python-version: '3.9' - run: python --version - run: pip install -r ./src/test/python/requirements.txt + - run: pip install -r ./src/test/python/sqlalchemy/requirements.txt - run: pip install -r ./src/test/python/pg8000/requirements.txt - uses: actions/setup-node@v3 with: @@ -51,6 +52,7 @@ jobs: python-version: '3.9' - run: python --version - run: pip install -r ./src/test/python/requirements.txt + - run: pip install -r ./src/test/python/sqlalchemy/requirements.txt - run: pip install -r ./src/test/python/pg8000/requirements.txt - run: mvn -B test macos: @@ -77,6 +79,7 @@ jobs: python-version: '3.9' - run: python --version - run: pip install -r ./src/test/python/requirements.txt + - run: pip install -r ./src/test/python/sqlalchemy/requirements.txt - run: pip install -r ./src/test/python/pg8000/requirements.txt - uses: actions/setup-dotnet@v3 with: diff --git a/CHANGELOG.md b/CHANGELOG.md index f90c874a2..e43b7f759 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,36 @@ # Changelog +## [0.16.0](https://github.com/GoogleCloudPlatform/pgadapter/compare/v0.15.0...v0.16.0) (2023-02-05) + + +### Features + +* allow unsupported OIDs as param types ([#604](https://github.com/GoogleCloudPlatform/pgadapter/issues/604)) ([5e9f95a](https://github.com/GoogleCloudPlatform/pgadapter/commit/5e9f95a720f1648236b39167b227cc70bd40e323)) +* make table and function replacements client-aware ([#605](https://github.com/GoogleCloudPlatform/pgadapter/issues/605)) ([ad49e99](https://github.com/GoogleCloudPlatform/pgadapter/commit/ad49e990298d0e91736d4f5afe581d2f1411b5ca)) + + +### Bug Fixes + +* binary copy header should be included in first data message ([#609](https://github.com/GoogleCloudPlatform/pgadapter/issues/609)) ([2fbf89e](https://github.com/GoogleCloudPlatform/pgadapter/commit/2fbf89e6a6b3ba0b66f126abf019e386e9276d4c)) +* copy to for a query would fail with a column list ([#616](https://github.com/GoogleCloudPlatform/pgadapter/issues/616)) ([16f030e](https://github.com/GoogleCloudPlatform/pgadapter/commit/16f030e3f6b93ae0a243b6c495b0c906403c5e16)) +* CopyResponse did not return correct column format ([#633](https://github.com/GoogleCloudPlatform/pgadapter/issues/633)) ([dc0d482](https://github.com/GoogleCloudPlatform/pgadapter/commit/dc0d482ffb61d1857a3f49fc424a07d72886b460)) +* csv copy header was repeated for each row ([#619](https://github.com/GoogleCloudPlatform/pgadapter/issues/619)) ([622c49a](https://github.com/GoogleCloudPlatform/pgadapter/commit/622c49a02cf2a865874764f44a77b96539382be0)) +* empty copy from stdin statements could be unresponsive ([#617](https://github.com/GoogleCloudPlatform/pgadapter/issues/617)) ([c576124](https://github.com/GoogleCloudPlatform/pgadapter/commit/c576124e40ad7f07ee0d1e2f3090886896c70dc3)) +* empty partitions could skip binary copy header ([#615](https://github.com/GoogleCloudPlatform/pgadapter/issues/615)) ([e7dd650](https://github.com/GoogleCloudPlatform/pgadapter/commit/e7dd6508015ed45147af59c25f95e18628461d85)) +* show statements failed in pgx ([#629](https://github.com/GoogleCloudPlatform/pgadapter/issues/629)) ([734f521](https://github.com/GoogleCloudPlatform/pgadapter/commit/734f52176f75e4ccb0b8bddc96eae49ace9ab19e)) +* support end-of-data record in COPY ([#602](https://github.com/GoogleCloudPlatform/pgadapter/issues/602)) ([8b705e8](https://github.com/GoogleCloudPlatform/pgadapter/commit/8b705e8f917035cbabe9e6751008e93692355158)) + + +### Dependencies + +* update Spanner client to 6.35.1 ([#607](https://github.com/GoogleCloudPlatform/pgadapter/issues/607)) ([0c607c7](https://github.com/GoogleCloudPlatform/pgadapter/commit/0c607c7c1bce48139f28688a5d7f1e202d839860)) + + +### Documentation + +* document pgbench usage ([#603](https://github.com/GoogleCloudPlatform/pgadapter/issues/603)) ([5a62bf6](https://github.com/GoogleCloudPlatform/pgadapter/commit/5a62bf64c56a976625e2c707b6d049e593cddc96)) +* document unix domain sockets with Docker ([#622](https://github.com/GoogleCloudPlatform/pgadapter/issues/622)) ([e4e41f7](https://github.com/GoogleCloudPlatform/pgadapter/commit/e4e41f70e5ad23d8e7d6f2a1bc1851458466bbb6)) + ## [0.15.0](https://github.com/GoogleCloudPlatform/pgadapter/compare/v0.14.1...v0.15.0) (2023-01-18) diff --git a/README.md b/README.md index dca5be2f2..6756f34c6 100644 --- a/README.md +++ b/README.md @@ -11,12 +11,17 @@ PGAdapter can be used with the following drivers and clients: 4. `psycopg2`: Version 2.9.3 and higher (but not `psycopg3`) are supported. See [psycopg2](docs/psycopg2.md) for more details. 5. `node-postgres`: Version 8.8.0 and higher are supported. See [node-postgres support](docs/node-postgres.md) for more details. -## Frameworks -PGAdapter can be used with the following frameworks: +## Frameworks and Tools +PGAdapter can be used with the following frameworks and tools: 1. `Liquibase`: Version 4.12.0 and higher are supported. See [Liquibase support](docs/liquibase.md) for more details. See also [this directory](samples/java/liquibase) for a sample application using `Liquibase`. 2. `gorm`: Version 1.23.8 and higher are supported. See [gorm support](docs/gorm.md) for more details. See also [this directory](samples/golang/gorm) for a sample application using `gorm`. +3. `SQLAlchemy`: Version 1.4.45 has _experimental support_. See [SQLAlchemy support](docs/sqlalchemy.md) + for more details. See also [this directory](samples/python/sqlalchemy-sample) for a sample + application using `SQLAlchemy`. +4. `pgbench` can be used with PGAdapter, but with some limitations. See [pgbench.md](docs/pgbench.md) + for more details. ## FAQ See [Frequently Asked Questions](docs/faq.md) for answers to frequently asked questions. @@ -63,9 +68,9 @@ Use the `-s` option to specify a different local port than the default 5432 if y PostgreSQL running on your local system. -You can also download a specific version of the jar. Example (replace `v0.15.0` with the version you want to download): +You can also download a specific version of the jar. Example (replace `v0.16.0` with the version you want to download): ```shell -VERSION=v0.15.0 +VERSION=v0.16.0 wget https://storage.googleapis.com/pgadapter-jar-releases/pgadapter-${VERSION}.tar.gz \ && tar -xzvf pgadapter-${VERSION}.tar.gz java -jar pgadapter.jar -p my-project -i my-instance -d my-database @@ -100,7 +105,7 @@ This option is only available for Java/JVM-based applications. com.google.cloud google-cloud-spanner-pgadapter - 0.15.0 + 0.16.0 ``` diff --git a/benchmarks/ycsb/run.sh b/benchmarks/ycsb/run.sh index dfeb48e09..947687499 100644 --- a/benchmarks/ycsb/run.sh +++ b/benchmarks/ycsb/run.sh @@ -31,16 +31,19 @@ psql -h localhost -c "CREATE TABLE IF NOT EXISTS usertable ( read_min float, read_max float, read_avg float, + read_p50 float, read_p95 float, read_p99 float, update_min float, update_max float, update_avg float, + update_p50 float, update_p95 float, update_p99 float, insert_min float, insert_max float, insert_avg float, + insert_p50 float, insert_p95 float, insert_p99 float, primary key (executed_at, deployment, workload, threads, batch_size, operation_count) @@ -98,6 +101,7 @@ do fi ./bin/ycsb run jdbc -P workloads/workload$WORKLOAD \ -threads $THREADS \ + -p hdrhistogram.percentiles=50,95,99 \ -p operationcount=$OPERATION_COUNT \ -p recordcount=100000 \ -p db.batchsize=$BATCH_SIZE \ @@ -112,30 +116,33 @@ do READ_AVG=$(grep '\[READ\], AverageLatency(us), ' ycsb.log | sed 's/^.*, //' | sed "s/NaN/'NaN'/" || echo null) READ_MIN=$(grep '\[READ\], MinLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null) READ_MAX=$(grep '\[READ\], MaxLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null) + READ_P50=$(grep '\[READ\], 50thPercentileLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null) READ_P95=$(grep '\[READ\], 95thPercentileLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null) READ_P99=$(grep '\[READ\], 99thPercentileLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null) UPDATE_AVG=$(grep '\[UPDATE\], AverageLatency(us), ' ycsb.log | sed 's/^.*, //' | sed "s/NaN/'NaN'/" || echo null) UPDATE_MIN=$(grep '\[UPDATE\], MinLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null) UPDATE_MAX=$(grep '\[UPDATE\], MaxLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null) + UPDATE_P50=$(grep '\[UPDATE\], 50thPercentileLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null) UPDATE_P95=$(grep '\[UPDATE\], 95thPercentileLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null) UPDATE_P99=$(grep '\[UPDATE\], 99thPercentileLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null) INSERT_AVG=$(grep '\[INSERT\], AverageLatency(us), ' ycsb.log | sed 's/^.*, //' | sed "s/NaN/'NaN'/" || echo null) INSERT_MIN=$(grep '\[INSERT\], MinLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null) INSERT_MAX=$(grep '\[INSERT\], MaxLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null) + INSERT_P50=$(grep '\[INSERT\], 50thPercentileLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null) INSERT_P95=$(grep '\[INSERT\], 95thPercentileLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null) INSERT_P99=$(grep '\[INSERT\], 99thPercentileLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null) psql -h localhost \ -c "insert into run (executed_at, deployment, workload, threads, batch_size, operation_count, run_time, throughput, - read_min, read_max, read_avg, read_p95, read_p99, - update_min, update_max, update_avg, update_p95, update_p99, - insert_min, insert_max, insert_avg, insert_p95, insert_p99) values + read_min, read_max, read_avg, read_p50, read_p95, read_p99, + update_min, update_max, update_avg, update_p50, update_p95, update_p99, + insert_min, insert_max, insert_avg, insert_p50, insert_p95, insert_p99) values ('$EXECUTED_AT', '$DEPLOYMENT', '$WORKLOAD', $THREADS, $BATCH_SIZE, $OPERATION_COUNT, $OVERALL_RUNTIME, $OVERALL_THROUGHPUT, - $READ_MIN, $READ_MAX, $READ_AVG, $READ_P95, $READ_P99, - $UPDATE_MIN, $UPDATE_MAX, $UPDATE_AVG, $UPDATE_P95, $UPDATE_P99, - $INSERT_MIN, $INSERT_MAX, $INSERT_AVG, $INSERT_P95, $INSERT_P99)" + $READ_MIN, $READ_MAX, $READ_AVG, $READ_P50, $READ_P95, $READ_P99, + $UPDATE_MIN, $UPDATE_MAX, $UPDATE_AVG, $UPDATE_P50, $UPDATE_P95, $UPDATE_P99, + $INSERT_MIN, $INSERT_MAX, $INSERT_AVG, $INSERT_P50, $INSERT_P95, $INSERT_P99)" done done done diff --git a/docs/docker.md b/docs/docker.md index ccc605ac5..9ab4b046b 100644 --- a/docs/docker.md +++ b/docs/docker.md @@ -12,6 +12,7 @@ docker run \ -d -p 5432:5432 \ -v ${GOOGLE_APPLICATION_CREDENTIALS}:${GOOGLE_APPLICATION_CREDENTIALS}:ro \ -e GOOGLE_APPLICATION_CREDENTIALS \ + -v /tmp:/tmp:rw \ gcr.io/cloud-spanner-pg-adapter/pgadapter \ -p my-project -i my-instance -d my-database \ -x @@ -26,6 +27,8 @@ The Docker options in the `docker run` command that are used in the above exampl The local file should contain the credentials that should be used by PGAdapter. * `-e`: Copy the value of the environment variable `GOOGLE_APPLICATION_CREDENTIALS` to the container. This will make the virtual file `/path/to/credentials.json` the default credentials in the container. +* `-v /tmp:/tmp:rw`: Map the `/tmp` host directory to the `/tmp` directory in the container. PGAdapter by + default uses `/tmp` as the directory where it creates a Unix Domain Socket. The PGAdapter options in the `docker run` command that are used in the above example are: * `-p`: The Google Cloud project name where the Cloud Spanner database is located. @@ -54,3 +57,24 @@ psql -h localhost -p 5433 The above example starts PGAdapter in a Docker container on the default port 5432. That port is mapped to port 5433 on the host machine. When connecting to PGAdapter in the Docker container, you must specify port 5433 for the connection. + +## Using a different directory for Unix Domain Sockets + +PGAdapter by default uses `/tmp` as the directory for Unix Domain Sockets. You can change this with +the `-dir` command line argument for PGAdapter. You must then also change the volume mapping to your +host to ensure that you can connect to the Unix Domain Socket. + +```shell +docker run \ + -d -p 5432:5432 \ + -v ${GOOGLE_APPLICATION_CREDENTIALS}:${GOOGLE_APPLICATION_CREDENTIALS}:ro \ + -e GOOGLE_APPLICATION_CREDENTIALS \ + -v /var/pgadapter:/var/pgadapter:rw \ + gcr.io/cloud-spanner-pg-adapter/pgadapter \ + -p my-project -i my-instance -d my-database \ + -dir /var/pgadapter \ + -x +psql -h /var/pgadapter +``` + +The above example uses `/var/pgadapter` as the directory for Unix Domain Sockets. diff --git a/docs/pgbench.md b/docs/pgbench.md new file mode 100644 index 000000000..eb393d906 --- /dev/null +++ b/docs/pgbench.md @@ -0,0 +1,100 @@ +# PGAdapter - pgbench + +[pgbench](https://www.postgresql.org/docs/current/pgbench.html) can be used with PGAdapter, but with +some limitations. + +Follow these steps to initialize and run benchmarks with `pgbench` with PGAdapter: + +## Create Data Model +The default data model that is generated by `pgbench` does not include primary keys for the tables. +Cloud Spanner requires all tables to have primary keys. Execute the following command to manually +create the data model for `pgbench`: + +```shell +psql -h /tmp -p 5432 -d my-database \ + -c "START BATCH DDL; + CREATE TABLE pgbench_accounts ( + aid integer primary key NOT NULL, + bid integer NULL, + abalance integer NULL, + filler varchar(84) NULL + ); + CREATE TABLE pgbench_branches ( + bid integer primary key NOT NULL, + bbalance integer NULL, + filler varchar(88) NULL + ); + CREATE TABLE pgbench_history ( + tid integer NOT NULL DEFAULT -1, + bid integer NOT NULL DEFAULT -1, + aid integer NOT NULL DEFAULT -1, + delta integer NULL, + mtime timestamptz NULL, + filler varchar(22) NULL, + primary key (tid, bid, aid) + ); + CREATE TABLE pgbench_tellers ( + tid integer primary key NOT NULL, + bid integer NULL, + tbalance integer NULL, + filler varchar(84) NULL + ); + RUN BATCH;" +``` + +## Initialize Data +`pgbench` deletes and inserts data into PostgreSQL using a combination of `truncate`, `insert` and +`copy` statements. These statements all run in a single transaction. The amount of data that is +modified during this transaction will exceed the transaction mutation limits of Cloud Spanner. This +can be worked around by adding the following options to the `pgbench` initialization command: + +```shell +pgbench "host=/tmp port=5432 dbname=my-database \ + options='-c spanner.force_autocommit=on -c spanner.autocommit_dml_mode=\'partitioned_non_atomic\''" \ + -i -Ig \ + --scale=100 +``` + +These additional options do the following: +1. `spanner.force_autocommit=true`: This instructs PGAdapter to ignore any transaction statements and + execute all statements in autocommit mode. This prevents the initialization from being executed as + a single, large transaction. +2. `spanner.autocommit_dml_mode='partitioned_non_atomic'`: This instructs PGAdapter to use Partitioned + DML for (large) update statements. This ensures that a single statement succeeds even if it would + exceed the transaction limits of Cloud Spanner, including large `copy` operations. +3. `-i` activates initialization mode of `pgbench`. +4. `-Ig` instructs `pgbench` to generate test data client side. PGAdapter does not support generating + test data server side. + +## Running Benchmarks +You can run different benchmarks after finishing the steps above. + +### Default Benchmark +Run a default benchmark to verify that everything works as expected. + +```shell +pgbench "host=/tmp port=5432 dbname=my-database" +``` + +### Number of Clients +Increase the number of clients and threads to increase the number of parallel transactions. + +```shell +pgbench "host=/tmp port=5432 dbname=my-database" \ + --client=100 --jobs=100 \ + --progress=10 +``` + +## Dropping Tables +Execute the following command to remove the `pgbench` tables from your database if you no longer +need them. + +```shell +psql -h /tmp -p 5432 -d my-database \ + -c "START BATCH DDL; + DROP TABLE pgbench_history; + DROP TABLE pgbench_tellers; + DROP TABLE pgbench_branches; + DROP TABLE pgbench_accounts; + RUN BATCH;" +``` diff --git a/docs/sqlalchemy.md b/docs/sqlalchemy.md new file mode 100644 index 000000000..f4e577fea --- /dev/null +++ b/docs/sqlalchemy.md @@ -0,0 +1,47 @@ +# PGAdapter - SQLAlchemy Connection Options + +## Limitations +PGAdapter has experimental support for SQLAlchemy 1.4 with Cloud Spanner PostgreSQL databases. It +has been tested with SQLAlchemy 1.4.45 and psycopg2 2.9.3. Developing new applications using +SQLAlchemy is possible as long as the listed limitations are taken into account. +Porting an existing application from PostgreSQL to Cloud Spanner is likely to require code changes. + +See [Limitations](../samples/python/sqlalchemy-sample/README.md#limitations) in the `sqlalchemy-sample` +directory for a full list of limitations. + +## Usage + +First start PGAdapter: + +```shell +export GOOGLE_APPLICATION_CREDENTIALS=/path/to/credentials.json +docker pull gcr.io/cloud-spanner-pg-adapter/pgadapter +docker run \ + -d -p 5432:5432 \ + -v ${GOOGLE_APPLICATION_CREDENTIALS}:${GOOGLE_APPLICATION_CREDENTIALS}:ro \ + -e GOOGLE_APPLICATION_CREDENTIALS \ + gcr.io/cloud-spanner-pg-adapter/pgadapter \ + -p my-project -i my-instance \ + -x +``` + +Then connect to PGAdapter and use `SQLAlchemy` like this: + +```python +conn_string = "postgresql+psycopg2://user:password@localhost:5432/my-database" +engine = create_engine(conn_string) +with Session(engine) as session: + user1 = User( + name="user1", + fullname="Full name of User 1", + addresses=[Address(email_address="user1@sqlalchemy.org")] + ) + session.add(user1) + session.commit() +``` + +## Full Sample and Limitations +[This directory](../samples/python/sqlalchemy-sample) contains a full sample of how to work with +`SQLAlchemy` with Cloud Spanner and PGAdapter. The sample readme file also lists the +[current limitations](../samples/python/sqlalchemy-sample/README.md#limitations) when working with +`SQLAlchemy`. diff --git a/pom.xml b/pom.xml index 76c891e9c..c4cb90a62 100644 --- a/pom.xml +++ b/pom.xml @@ -34,13 +34,13 @@ com.google.cloud.spanner.pgadapter.nodejs.NodeJSTest - 6.33.0 + 6.35.1 2.6.1 4.0.0 google-cloud-spanner-pgadapter - 0.15.0 + 0.16.0 Google Cloud Spanner PostgreSQL Adapter jar @@ -111,7 +111,7 @@ org.apache.commons commons-csv - 1.9.0 + 1.10.0 org.apache.commons @@ -147,7 +147,7 @@ net.java.dev.jna jna - 5.12.1 + 5.13.0 test diff --git a/samples/python/sqlalchemy-sample/README.md b/samples/python/sqlalchemy-sample/README.md new file mode 100644 index 000000000..e7abad7ce --- /dev/null +++ b/samples/python/sqlalchemy-sample/README.md @@ -0,0 +1,174 @@ +# PGAdapter and SQLAlchemy + +PGAdapter has experimental support for [SQLAlchemy 1.4](https://docs.sqlalchemy.org/en/14/index.html) +with the `psycopg2` driver. This document shows how to use this sample application, and lists the +limitations when working with `SQLAlchemy` with PGAdapter. + +The [sample.py](sample.py) file contains a sample application using `SQLAlchemy` with PGAdapter. Use +this as a reference for features of `SQLAlchemy` that are supported with PGAdapter. This sample +assumes that the reader is familiar with `SQLAlchemy`, and it is not intended as a tutorial for how +to use `SQLAlchemy` in general. + +See [Limitations](#limitations) for a full list of known limitations when working with `SQLAlchemy`. + +## Start PGAdapter +You must start PGAdapter before you can run the sample. The following command shows how to start PGAdapter using the +pre-built Docker image. See [Running PGAdapter](../../../README.md#usage) for more information on other options for how +to run PGAdapter. + +```shell +export GOOGLE_APPLICATION_CREDENTIALS=/path/to/credentials.json +docker pull gcr.io/cloud-spanner-pg-adapter/pgadapter +docker run \ + -d -p 5432:5432 \ + -v ${GOOGLE_APPLICATION_CREDENTIALS}:${GOOGLE_APPLICATION_CREDENTIALS}:ro \ + -e GOOGLE_APPLICATION_CREDENTIALS \ + -v /tmp:/tmp \ + gcr.io/cloud-spanner-pg-adapter/pgadapter \ + -p my-project -i my-instance \ + -x +``` + +## Creating the Sample Data Model +The sample data model contains example tables that cover all supported data types the Cloud Spanner +PostgreSQL dialect. It also includes an example for how [interleaved tables](https://cloud.google.com/spanner/docs/reference/postgresql/data-definition-language#extensions_to) +can be used with SQLAlchemy. Interleaved tables is a Cloud Spanner extension of the standard +PostgreSQL dialect. + +The corresponding SQLAlchemy model is defined in [model.py](model.py). + +Run the following command in this directory. Replace the host, port and database name with the actual +host, port and database name for your PGAdapter and database setup. + +```shell +psql -h localhost -p 5432 -d my-database -f create_data_model.sql +``` + +You can also drop an existing data model using the `drop_data_model.sql` script: + +```shell +psql -h localhost -p 5432 -d my-database -f drop_data_model.sql +``` + +## Data Types +Cloud Spanner supports the following data types in combination with `SQLAlchemy`. + +| PostgreSQL Type | SQLAlchemy type | +|----------------------------------------|-------------------------| +| boolean | Boolean | +| bigint / int8 | Integer, BigInteger | +| varchar | String | +| text | String | +| float8 / double precision | Float | +| numeric | Numeric | +| timestamptz / timestamp with time zone | DateTime(timezone=True) | +| date | Date | +| bytea | LargeBinary | +| jsonb | JSONB | + + +## Limitations +The following limitations are currently known: + +| Limitation | Workaround | +|------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| Creating and Dropping Tables | Cloud Spanner does not support the full PostgreSQL DDL dialect. Automated creation of tables using `SQLAlchemy` is therefore not supported. | +| metadata.reflect() | Cloud Spanner does not support all PostgreSQL `pg_catalog` tables. Using `metadata.reflect()` to get the current objects in the database is therefore not supported. | +| DDL Transactions | Cloud Spanner does not support DDL statements in a transaction. Add `?options=-c spanner.ddl_transaction_mode=AutocommitExplicitTransaction` to your connection string to automatically convert DDL transactions to [non-atomic DDL batches](../../../docs/ddl.md). | +| Generated primary keys | Manually assign a value to the primary key column in your code. The recommended primary key type is a random UUID. Sequences / SERIAL / IDENTITY columns are currently not supported. | +| INSERT ... ON CONFLICT | `INSERT ... ON CONFLICT` is not supported. | +| SAVEPOINT | Nested transactions and savepoints are not supported. | +| SELECT ... FOR UPDATE | `SELECT ... FOR UPDATE` is not supported. | +| Server side cursors | Server side cursors are currently not supported. | +| Transaction isolation level | Only SERIALIZABLE and AUTOCOMMIT are supported. `postgresql_readonly=True` is also supported. It is recommended to use either autocommit or read-only for workloads that only read data and/or that do not need to be atomic to get the best possible performance. | +| Stored procedures | Cloud Spanner does not support Stored Procedures. | +| User defined functions | Cloud Spanner does not support User Defined Functions. | +| Other drivers than psycopg2 | PGAdapter does not support using SQLAlchemy with any other drivers than `psycopg2`. | + +### Generated Primary Keys +Generated primary keys are not supported and should be replaced with primary key definitions that +are manually assigned. See https://cloud.google.com/spanner/docs/schema-design#primary-key-prevent-hotspots +for more information on choosing a good primary key. This sample uses random UUIDs that are generated +by the client and stored as strings for primary keys. + +```python +from uuid import uuid4 + +class Singer(Base): + id = Column(String, primary_key=True) + name = Column(String(100)) + +singer = Singer( + id="{}".format(uuid4()), + name="Alice") +``` + +### ON CONFLICT Clauses +`INSERT ... ON CONFLICT ...` are not supported by Cloud Spanner and should not be used. Trying to +use https://docs.sqlalchemy.org/en/14/dialects/postgresql.html#sqlalchemy.dialects.postgresql.Insert.on_conflict_do_update +or https://docs.sqlalchemy.org/en/14/dialects/postgresql.html#sqlalchemy.dialects.postgresql.Insert.on_conflict_do_nothing +will fail. + +### SAVEPOINT - Nested transactions +`SAVEPOINT`s are not supported by Cloud Spanner. Nested transactions in SQLAlchemy are translated to +savepoints and are therefore not supported. Trying to use `Session.begin_nested()` +(https://docs.sqlalchemy.org/en/14/orm/session_api.html#sqlalchemy.orm.Session.begin_nested) will fail. + +### Locking - SELECT ... FOR UPDATE +Locking clauses, like `SELECT ... FOR UPDATE`, are not supported (see also https://docs.sqlalchemy.org/en/20/orm/queryguide/query.html#sqlalchemy.orm.Query.with_for_update). +These are normally also not required, as Cloud Spanner uses isolation level `serializable` for +read/write transactions. + +## Performance Considerations + +### Parameterized Queries +`psycopg2` does not use [parameterized queries](https://cloud.google.com/spanner/docs/sql-best-practices#postgresql_1). +Instead, `psycopg2` will replace query parameters with literals in the SQL string before sending +these to PGAdapter. This means that each query must be parsed and planned separately. This will add +latency to queries that are executed multiple times compared to if the queries were executed using +actual query parameters. + +### Read-only Transactions +SQLAlchemy will by default use read/write transactions for all database operations, including for +workloads that only read data. This will cause Cloud Spanner to take locks for all data that is read +during the transaction. It is recommended to use either autocommit or [read-only transactions](https://cloud.google.com/spanner/docs/transactions#read-only_transactions) +for workloads that are known to only execute read operations. Read-only transactions do not take any +locks. You can create a separate database engine that can be used for read-only transactions from +your default database engine by adding the `postgresql_readonly=True` execution option. + +```python +read_only_engine = engine.execution_options(postgresql_readonly=True) +``` + +### Autocommit +Using isolation level `AUTOCOMMIT` will suppress the use of (read/write) transactions for each +database operation in SQLAlchemy. Using autocommit is more efficient than read/write transactions +for workloads that only read and/or that do not need the atomicity that is offered by transactions. + +You can create a separate database engine that can be used for workloads that do not need +transactions by adding the `isolation_level="AUTOCOMMIT"` execution option to your default database +engine. + +```python +autocommit_engine = engine.execution_options(isolation_level="AUTOCOMMIT") +``` + +### Stale reads +Read-only transactions and database engines using `AUTOCOMMIT` will by default use strong reads for +queries. Cloud Spanner also supports stale reads. + +* A strong read is a read at a current timestamp and is guaranteed to see all data that has been + committed up until the start of this read. Spanner defaults to using strong reads to serve read requests. +* A stale read is read at a timestamp in the past. If your application is latency sensitive but + tolerant of stale data, then stale reads can provide performance benefits. + +See also https://cloud.google.com/spanner/docs/reads#read_types + +You can create a database engine that will use stale reads in autocommit mode by adding the following +to the connection string and execution options of the engine: + +```python +conn_string = "postgresql+psycopg2://user:password@localhost:5432/my-database" \ + "?options=-c spanner.read_only_staleness='MAX_STALENESS 10s'" +engine = create_engine(conn_string).execution_options(isolation_level="AUTOCOMMIT") +``` diff --git a/samples/python/sqlalchemy-sample/connect.py b/samples/python/sqlalchemy-sample/connect.py new file mode 100644 index 000000000..0d6980c99 --- /dev/null +++ b/samples/python/sqlalchemy-sample/connect.py @@ -0,0 +1,60 @@ +""" Copyright 2022 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import argparse + +from sqlalchemy import create_engine, event +from model import * +import sys + + +# Creates a database engine that can be used with the sample.py script in this +# directory. This function assumes that the host, port and database name was +# given as command line arguments. +def create_test_engine(autocommit: bool=False, options: str=""): + parser = argparse.ArgumentParser(description='Run SQLAlchemy sample.') + parser.add_argument('host', type=str, help='host to connect to') + parser.add_argument('port', type=int, help='port number to connect to') + parser.add_argument('database', type=str, help='database to connect to', default='d') + args = parser.parse_args() + + conn_string = "postgresql+psycopg2://user:password@{host}:{port}/" \ + "{database}{options}".format(host=args.host, + port=args.port, + database=args.database, + options=options) + if args.host == "": + if options == "": + conn_string = conn_string + "?host=/tmp" + else: + conn_string = conn_string + "&host=/tmp" + engine = create_engine(conn_string, future=True) + if autocommit: + engine = engine.execution_options(isolation_level="AUTOCOMMIT") + return engine + + +def register_event_listener_for_prepared_statements(engine): + # Register an event listener for this engine that creates prepared statements + # for each connection that is created. The prepared statement definitions can + # be added to the model classes. + @event.listens_for(engine, "connect") + def connect(dbapi_connection, connection_record): + cursor_obj = dbapi_connection.cursor() + for model in BaseMixin.__subclasses__(): + if model.__prepare_statements__ is not None: + for prepare_statement in model.__prepare_statements__: + cursor_obj.execute(prepare_statement) + + cursor_obj.close() diff --git a/samples/python/sqlalchemy-sample/create_data_model.sql b/samples/python/sqlalchemy-sample/create_data_model.sql new file mode 100644 index 000000000..c8e173b32 --- /dev/null +++ b/samples/python/sqlalchemy-sample/create_data_model.sql @@ -0,0 +1,64 @@ + +-- Executing the schema creation in a batch will improve execution speed. +start batch ddl; + +create table if not exists singers ( + id varchar not null primary key, + version_id int not null, + first_name varchar, + last_name varchar not null, + full_name varchar generated always as (coalesce(concat(first_name, ' '::varchar, last_name), last_name)) stored, + active boolean, + created_at timestamptz, + updated_at timestamptz +); + +create table if not exists albums ( + id varchar not null primary key, + version_id int not null, + title varchar not null, + marketing_budget numeric, + release_date date, + cover_picture bytea, + singer_id varchar not null, + created_at timestamptz, + updated_at timestamptz, + constraint fk_albums_singers foreign key (singer_id) references singers (id) +); + +create table if not exists tracks ( + id varchar not null, + track_number bigint not null, + version_id int not null, + title varchar not null, + sample_rate float8 not null, + created_at timestamptz, + updated_at timestamptz, + primary key (id, track_number) +) interleave in parent albums on delete cascade; + +create table if not exists venues ( + id varchar not null primary key, + version_id int not null, + name varchar not null, + description jsonb not null, + created_at timestamptz, + updated_at timestamptz +); + +create table if not exists concerts ( + id varchar not null primary key, + version_id int not null, + venue_id varchar not null, + singer_id varchar not null, + name varchar not null, + start_time timestamptz not null, + end_time timestamptz not null, + created_at timestamptz, + updated_at timestamptz, + constraint fk_concerts_venues foreign key (venue_id) references venues (id), + constraint fk_concerts_singers foreign key (singer_id) references singers (id), + constraint chk_end_time_after_start_time check (end_time > start_time) +); + +run batch; diff --git a/samples/python/sqlalchemy-sample/drop_data_model.sql b/samples/python/sqlalchemy-sample/drop_data_model.sql new file mode 100644 index 000000000..b57ea2ad8 --- /dev/null +++ b/samples/python/sqlalchemy-sample/drop_data_model.sql @@ -0,0 +1,10 @@ +-- Executing the schema drop in a batch will improve execution speed. +start batch ddl; + +drop table if exists concerts; +drop table if exists venues; +drop table if exists tracks; +drop table if exists albums; +drop table if exists singers; + +run batch; diff --git a/samples/python/sqlalchemy-sample/model.py b/samples/python/sqlalchemy-sample/model.py new file mode 100644 index 000000000..b920cac03 --- /dev/null +++ b/samples/python/sqlalchemy-sample/model.py @@ -0,0 +1,178 @@ +""" Copyright 2022 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from sqlalchemy import Column, Integer, String, Boolean, LargeBinary, Float, \ + Numeric, DateTime, Date, FetchedValue, ForeignKey, ColumnDefault +from sqlalchemy.orm import registry, relationship +from sqlalchemy.dialects.postgresql import JSONB +from datetime import timezone, datetime + +mapper_registry = registry() +Base = mapper_registry.generate_base() + + +""" + BaseMixin contains properties that are common to all models in this sample. All + models use a string column that contains a client-side generated UUID as the + primary key. + The created_at and updated_at properties are automatically filled with the + current client system time when a model is created or updated. +""" + + +def format_timestamp(timestamp: datetime) -> str: + return timestamp.astimezone(timezone.utc).isoformat() if timestamp else None + + +class BaseMixin(object): + __prepare_statements__ = None + + id = Column(String, primary_key=True) + version_id = Column(Integer, nullable=False) + created_at = Column(DateTime(timezone=True), + # We need to explicitly format the timestamps with a + # timezone to ensure that SQLAlchemy uses a + # timestamptz instead of just timestamp. + ColumnDefault(datetime.utcnow().astimezone(timezone.utc))) + updated_at = Column(DateTime(timezone=True), + ColumnDefault( + datetime.utcnow().astimezone(timezone.utc), + for_update=True)) + __mapper_args__ = {"version_id_col": version_id} + + +class Singer(BaseMixin, Base): + __tablename__ = "singers" + + first_name = Column(String(100)) + last_name = Column(String(200)) + full_name = Column(String, + server_default=FetchedValue(), + server_onupdate=FetchedValue()) + active = Column(Boolean) + albums = relationship("Album", back_populates="singer") + + __mapper_args__ = { + "eager_defaults": True, + "version_id_col": BaseMixin.version_id + } + + def __repr__(self): + return ( + f"singers(" + f"id={self.id!r}," + f"first_name={self.first_name!r}," + f"last_name={self.last_name!r}," + f"active={self.active!r}," + f"created_at={format_timestamp(self.created_at)!r}," + f"updated_at={format_timestamp(self.updated_at)!r}" + f")" + ) + + +class Album(BaseMixin, Base): + __tablename__ = "albums" + + title = Column(String(200)) + marketing_budget = Column(Numeric) + release_date = Column(Date) + cover_picture = Column(LargeBinary) + singer_id = Column(String, ForeignKey("singers.id")) + singer = relationship("Singer", back_populates="albums") + # The `tracks` relationship uses passive_deletes=True, because `tracks` is + # interleaved in `albums` with `ON DELETE CASCADE`. This prevents SQLAlchemy + # from deleting the related tracks when an album is deleted, and lets the + # database handle it. + tracks = relationship("Track", back_populates="album", passive_deletes=True) + + def __repr__(self): + return ( + f"albums(" + f"id={self.id!r}," + f"title={self.title!r}," + f"marketing_budget={self.marketing_budget!r}," + f"release_date={self.release_date!r}," + f"cover_picture={self.cover_picture!r}," + f"singer={self.singer_id!r}," + f"created_at={format_timestamp(self.created_at)!r}," + f"updated_at={format_timestamp(self.updated_at)!r}" + f")" + ) + + +class Track(BaseMixin, Base): + __tablename__ = "tracks" + + id = Column(String, ForeignKey("albums.id"), primary_key=True) + track_number = Column(Integer, primary_key=True) + title = Column(String) + sample_rate = Column(Float) + album = relationship("Album", back_populates="tracks") + + def __repr__(self): + return ( + f"tracks(" + f"id={self.id!r}," + f"track_number={self.track_number!r}," + f"title={self.title!r}," + f"sample_rate={self.sample_rate!r}," + f"created_at={format_timestamp(self.created_at)!r}," + f"updated_at={format_timestamp(self.updated_at)!r}" + f")" + ) + + +class Venue(BaseMixin, Base): + __tablename__ = "venues" + + name = Column(String(200)) + description = Column(JSONB) + + def __repr__(self): + return ( + f"venues(" + f"id={self.id!r}," + f"name={self.name!r}," + f"description={self.description!r}," + f"created_at={format_timestamp(self.created_at)!r}," + f"updated_at={format_timestamp(self.updated_at)!r}" + f")" + ) + + +class Concert(BaseMixin, Base): + __tablename__ = "concerts" + + name = Column(String(200)) + venue_id = Column(String, ForeignKey("venues.id")) + venue = relationship("Venue") + singer_id = Column(String, ForeignKey("singers.id")) + singer = relationship("Singer") + start_time = Column(DateTime(timezone=True)) + end_time = Column(DateTime(timezone=True)) + + def __repr__(self): + return ( + f"concerts(" + f"id={self.id!r}," + f"name={self.name!r}," + f"venue={self.venue!r}," + f"singer={self.singer!r}," + f"start_time={self.start_time!r}," + f"end_time={self.end_time!r}," + f"created_at={format_timestamp(self.created_at)!r}," + f"updated_at={format_timestamp(self.updated_at)!r}" + f")" + ) diff --git a/samples/python/sqlalchemy-sample/run_sample.py b/samples/python/sqlalchemy-sample/run_sample.py new file mode 100644 index 000000000..8e0a353e2 --- /dev/null +++ b/samples/python/sqlalchemy-sample/run_sample.py @@ -0,0 +1,27 @@ +""" Copyright 2022 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import argparse + + +parser = argparse.ArgumentParser(description='Run SQLAlchemy sample.') +parser.add_argument('host', type=str, help='host to connect to') +parser.add_argument('port', type=int, help='port number to connect to') +parser.add_argument('database', type=str, help='database to connect to') +args = parser.parse_args() + + +from sample import run_sample + +run_sample() diff --git a/samples/python/sqlalchemy-sample/sample.py b/samples/python/sqlalchemy-sample/sample.py new file mode 100644 index 000000000..37749a946 --- /dev/null +++ b/samples/python/sqlalchemy-sample/sample.py @@ -0,0 +1,339 @@ +""" Copyright 2022 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from connect import create_test_engine +from sqlalchemy.orm import Session, joinedload +from sqlalchemy import text, func, or_ +from model import Singer, Album, Track, Venue, Concert +from util_random_names import random_first_name, random_last_name, \ + random_album_title, random_release_date, random_marketing_budget, \ + random_cover_picture +from uuid import uuid4 +from datetime import datetime, date + + +# This is the default engine that is connected to PostgreSQL (PGAdapter). +# This engine will by default use read/write transactions. +engine = create_test_engine() + +# This engine uses read-only transactions instead of read/write transactions. +# It is recommended to use a read-only transaction instead of a read/write +# transaction for all workloads that only read data, as read-only transactions +# do not take any locks. +read_only_engine = engine.execution_options(postgresql_readonly=True) + +# This engine uses auto commit instead of transactions, and will execute all +# read operations with a max staleness of 10 seconds. This will result in more +# efficient read operations, but the data that is returned can have a staleness +# of up to 10 seconds. +stale_read_engine = create_test_engine( + autocommit=True, + options="?options=-c spanner.read_only_staleness='MAX_STALENESS 10s'") + + +def run_sample(): + # Create the data model that is used for this sample. + create_tables_if_not_exists() + # Delete any existing data before running the sample. + delete_all_data() + + # First create some random singers and albums. + create_random_singers_and_albums() + print_singers_and_albums() + + # Create a venue and a concert. + create_venue_and_concert_in_transaction() + print_concerts() + + # Execute a query to get all albums released before 1980. + print_albums_released_before_1980() + # Print a subset of all singers using LIMIT and OFFSET. + print_singers_with_limit_and_offset() + # Execute a query to get all albums that start with the same letter as the + # first name or last name of the singer. + print_albums_first_character_of_title_equal_to_first_or_last_name() + + # Delete an album. This will automatically also delete all related tracks, as + # Tracks is interleaved in Albums with the option ON DELETE CASCADE. + with Session(engine) as session: + album_id = session.query(Album).first().id + delete_album(album_id) + + # Read all albums using a connection that *could* return stale data. + with Session(stale_read_engine) as session: + album = session.get(Album, album_id) + if album is None: + print("No album found using a stale read.") + else: + print("Album was found using a stale read, even though it has already been deleted.") + + print() + print("Finished running sample") + + +def create_tables_if_not_exists(): + # Cloud Spanner does not support DDL in transactions. We therefore turn on + # autocommit for this connection. + with engine.execution_options( + isolation_level="AUTOCOMMIT").connect() as connection: + with open("create_data_model.sql") as file: + print("Reading sample data model from file") + ddl_script = text(file.read()) + print("Executing table creation script") + connection.execute(ddl_script) + print("Finished executing table creation script") + + +# Create five random singers, each with a number of albums. +def create_random_singers_and_albums(): + with Session(engine) as session: + session.add_all([ + create_random_singer(5), + create_random_singer(10), + create_random_singer(7), + create_random_singer(3), + create_random_singer(12), + ]) + session.commit() + print("Created 5 singers") + + +# Print all singers and albums in currently in the database. +def print_singers_and_albums(): + with Session(read_only_engine) as session: + print() + for singer in session.query(Singer).order_by("last_name").all(): + print("{} has {} albums:".format(singer.full_name, len(singer.albums))) + for album in singer.albums: + print(" '{}'".format(album.title)) + + +# Create a Venue and Concert in one read/write transaction. +def create_venue_and_concert_in_transaction(): + with Session(engine) as session: + singer = session.query(Singer).first() + venue = Venue( + id=str(uuid4()), + name="Avenue Park", + description={ + "Capacity": 5000, + "Location": "New York", + "Country": "US" + } + ) + concert = Concert( + id=str(uuid4()), + name="Avenue Park Open", + venue=venue, + singer=singer, + start_time=datetime.fromisoformat("2023-02-01T20:00:00-05:00"), + end_time=datetime.fromisoformat("2023-02-02T02:00:00-05:00"), + ) + session.add_all([venue, concert]) + session.commit() + print() + print("Created Venue and Concert") + + +# Prints the concerts currently in the database. +def print_concerts(): + with Session(read_only_engine) as session: + # Query all concerts and join both Singer and Venue, so we can directly + # access the properties of these as well without having to execute + # additional queries. + concerts = ( + session.query(Concert) + .options(joinedload(Concert.venue)) + .options(joinedload(Concert.singer)) + .order_by("start_time") + .all() + ) + print() + for concert in concerts: + print("Concert '{}' starting at {} with {} will be held at {}" + .format(concert.name, + concert.start_time, + concert.singer.full_name, + concert.venue.name)) + + +# Prints all albums with a release date before 1980-01-01. +def print_albums_released_before_1980(): + with Session(read_only_engine) as session: + print() + print("Searching for albums released before 1980") + albums = ( + session + .query(Album) + .filter(Album.release_date < date.fromisoformat("1980-01-01")) + .all() + ) + for album in albums: + print( + " Album {} was released at {}".format(album.title, album.release_date)) + + +# Uses a limit and offset to select a subset of all singers in the database. +def print_singers_with_limit_and_offset(): + with Session(read_only_engine) as session: + print() + print("Printing at most 5 singers ordered by last name") + singers = ( + session + .query(Singer) + .order_by(Singer.last_name) + .limit(5) + .offset(3) + .all() + ) + num_singers = 0 + for singer in singers: + num_singers = num_singers + 1 + print(" {}. {}".format(num_singers, singer.full_name)) + print("Found {} singers".format(num_singers)) + + +# Searches for all albums that have a title that starts with the same character +# as the first character of either the first name or last name of the singer. +def print_albums_first_character_of_title_equal_to_first_or_last_name(): + print() + print("Searching for albums that have a title that starts with the same " + "character as the first or last name of the singer") + with Session(read_only_engine) as session: + albums = ( + session + .query(Album) + .join(Singer) + .filter(or_(func.lower(func.substring(Album.title, 1, 1)) == + func.lower(func.substring(Singer.first_name, 1, 1)), + func.lower(func.substring(Album.title, 1, 1)) == + func.lower(func.substring(Singer.last_name, 1, 1)))) + .all() + ) + for album in albums: + print(" '{}' by {}".format(album.title, album.singer.full_name)) + + +# Creates a random singer row with `num_albums` random albums. +def create_random_singer(num_albums): + return Singer( + id=str(uuid4()), + first_name=random_first_name(), + last_name=random_last_name(), + active=True, + albums=create_random_albums(num_albums) + ) + + +# Creates `num_albums` random album rows. +def create_random_albums(num_albums): + if num_albums == 0: + return [] + albums = [] + for i in range(num_albums): + albums.append( + Album( + id=str(uuid4()), + title=random_album_title(), + release_date=random_release_date(), + marketing_budget=random_marketing_budget(), + cover_picture=random_cover_picture() + ) + ) + return albums + + +# Loads and prints the singer with the given id. +def load_singer(singer_id): + with Session(engine) as session: + singer = session.get(Singer, singer_id) + print(singer) + print("Albums:") + print(singer.albums) + + +# Adds a new singer row to the database. Shows how flushing the session will +# automatically return the generated `full_name` column of the Singer. +def add_singer(singer): + with Session(engine) as session: + session.add(singer) + # We flush the session here to show that the generated column full_name is + # returned after the insert. Otherwise, we could just execute a commit + # directly. + session.flush() + print( + "Added singer {} with full name {}".format(singer.id, singer.full_name)) + session.commit() + + +# Updates an existing singer in the database. This will also automatically +# update the full_name of the singer. This is returned by the database and is +# visible in the properties of the singer. +def update_singer(singer_id, first_name, last_name): + with Session(engine) as session: + singer = session.get(Singer, singer_id) + singer.first_name = first_name + singer.last_name = last_name + # We flush the session here to show that the generated column full_name is + # returned after the update. Otherwise, we could just execute a commit + # directly. + session.flush() + print("Updated singer {} with full name {}" + .format(singer.id, singer.full_name)) + session.commit() + + +# Loads the given album from the database and prints its properties, including +# the tracks related to this album. The table `tracks` is interleaved in +# `albums`. This is handled as a normal relationship in SQLAlchemy. +def load_album(album_id): + with Session(engine) as session: + album = session.get(Album, album_id) + print(album) + print("Tracks:") + print(album.tracks) + + +# Loads a single track from the database and prints its properties. Track has a +# composite primary key, as it must include both the primary key column(s) of +# the parent table, as well as its own primary key column(s). +def load_track(album_id, track_number): + with Session(engine) as session: + # The "tracks" table has a composite primary key, as it is an interleaved + # child table of "albums". + track = session.get(Track, [album_id, track_number]) + print(track) + + +# Deletes all current sample data. +def delete_all_data(): + with Session(engine) as session: + session.query(Concert).delete() + session.query(Venue).delete() + session.query(Album).delete() + session.query(Singer).delete() + session.commit() + + +# Deletes an album from the database. This will also delete all related tracks. +# The deletion of the tracks is done by the database, and not by SQLAlchemy, as +# passive_deletes=True has been set on the relationship. +def delete_album(album_id): + with Session(engine) as session: + album = session.get(Album, album_id) + session.delete(album) + session.commit() + print() + print("Deleted album with id {}".format(album_id)) diff --git a/samples/python/sqlalchemy-sample/test_add_singer.py b/samples/python/sqlalchemy-sample/test_add_singer.py new file mode 100644 index 000000000..cb09abcae --- /dev/null +++ b/samples/python/sqlalchemy-sample/test_add_singer.py @@ -0,0 +1,28 @@ +""" Copyright 2022 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from model import Singer +from sample import add_singer +from datetime import datetime + +singer = Singer() +singer.id = "123-456-789" +singer.first_name = "Myfirstname" +singer.last_name = "Mylastname" +singer.active = True +# Manually set a created_at value, as we otherwise do not know which value to +# add to the mock server. +singer.created_at = datetime.fromisoformat("2011-11-04T00:05:23.123456+00:00"), +add_singer(singer) diff --git a/samples/python/sqlalchemy-sample/test_create_model.py b/samples/python/sqlalchemy-sample/test_create_model.py new file mode 100644 index 000000000..33644d89a --- /dev/null +++ b/samples/python/sqlalchemy-sample/test_create_model.py @@ -0,0 +1,23 @@ +""" Copyright 2022 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from connect import create_test_engine +from model import mapper_registry + + +engine = create_test_engine(options="?options=-c spanner.ddl_transaction_mode" + "=AutocommitExplicitTransaction") +mapper_registry.metadata.create_all(engine) +print("Created data model") diff --git a/samples/python/sqlalchemy-sample/test_create_random_singers_and_albums.py b/samples/python/sqlalchemy-sample/test_create_random_singers_and_albums.py new file mode 100644 index 000000000..46ddcb59d --- /dev/null +++ b/samples/python/sqlalchemy-sample/test_create_random_singers_and_albums.py @@ -0,0 +1,19 @@ +""" Copyright 2022 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from sample import create_random_singers_and_albums + + +create_random_singers_and_albums() diff --git a/samples/python/sqlalchemy-sample/test_create_venue_and_concert_in_transaction.py b/samples/python/sqlalchemy-sample/test_create_venue_and_concert_in_transaction.py new file mode 100644 index 000000000..ef82c5104 --- /dev/null +++ b/samples/python/sqlalchemy-sample/test_create_venue_and_concert_in_transaction.py @@ -0,0 +1,19 @@ +""" Copyright 2022 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from sample import create_venue_and_concert_in_transaction + + +create_venue_and_concert_in_transaction() diff --git a/samples/python/sqlalchemy-sample/test_delete_album.py b/samples/python/sqlalchemy-sample/test_delete_album.py new file mode 100644 index 000000000..fc794c1e1 --- /dev/null +++ b/samples/python/sqlalchemy-sample/test_delete_album.py @@ -0,0 +1,18 @@ +""" Copyright 2022 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from sample import delete_album + +delete_album("123-456-789") diff --git a/samples/python/sqlalchemy-sample/test_get_album.py b/samples/python/sqlalchemy-sample/test_get_album.py new file mode 100644 index 000000000..6abe3752e --- /dev/null +++ b/samples/python/sqlalchemy-sample/test_get_album.py @@ -0,0 +1,19 @@ +""" Copyright 2022 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from sample import load_album + + +load_album("987-654-321") diff --git a/samples/python/sqlalchemy-sample/test_get_album_with_stale_engine.py b/samples/python/sqlalchemy-sample/test_get_album_with_stale_engine.py new file mode 100644 index 000000000..7952bf07d --- /dev/null +++ b/samples/python/sqlalchemy-sample/test_get_album_with_stale_engine.py @@ -0,0 +1,22 @@ +""" Copyright 2022 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from model import Album +from sample import stale_read_engine +from sqlalchemy.orm import Session + + +with Session(stale_read_engine) as session: + album = session.get(Album, "987-654-321") diff --git a/samples/python/sqlalchemy-sample/test_get_singer.py b/samples/python/sqlalchemy-sample/test_get_singer.py new file mode 100644 index 000000000..e60e1bc97 --- /dev/null +++ b/samples/python/sqlalchemy-sample/test_get_singer.py @@ -0,0 +1,19 @@ +""" Copyright 2022 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from sample import load_singer + + +load_singer("123-456-789") diff --git a/samples/python/sqlalchemy-sample/test_get_track.py b/samples/python/sqlalchemy-sample/test_get_track.py new file mode 100644 index 000000000..6c0be66cc --- /dev/null +++ b/samples/python/sqlalchemy-sample/test_get_track.py @@ -0,0 +1,19 @@ +""" Copyright 2022 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from sample import load_track + + +load_track("987-654-321", 1) diff --git a/samples/python/sqlalchemy-sample/test_metadata_reflect.py b/samples/python/sqlalchemy-sample/test_metadata_reflect.py new file mode 100644 index 000000000..aad9e0bfe --- /dev/null +++ b/samples/python/sqlalchemy-sample/test_metadata_reflect.py @@ -0,0 +1,22 @@ +""" Copyright 2022 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from connect import create_test_engine +from model import mapper_registry + + +engine = create_test_engine() +mapper_registry.metadata.reflect(engine) +print("Reflected current data model") diff --git a/samples/python/sqlalchemy-sample/test_print_albums_first_character_of_title_equal_to_first_or_last_name.py b/samples/python/sqlalchemy-sample/test_print_albums_first_character_of_title_equal_to_first_or_last_name.py new file mode 100644 index 000000000..ad59f9ec5 --- /dev/null +++ b/samples/python/sqlalchemy-sample/test_print_albums_first_character_of_title_equal_to_first_or_last_name.py @@ -0,0 +1,20 @@ +""" Copyright 2022 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from sample import \ + print_albums_first_character_of_title_equal_to_first_or_last_name + + +print_albums_first_character_of_title_equal_to_first_or_last_name() diff --git a/samples/python/sqlalchemy-sample/test_print_albums_released_before_1980.py b/samples/python/sqlalchemy-sample/test_print_albums_released_before_1980.py new file mode 100644 index 000000000..f5da8854e --- /dev/null +++ b/samples/python/sqlalchemy-sample/test_print_albums_released_before_1980.py @@ -0,0 +1,19 @@ +""" Copyright 2022 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from sample import print_albums_released_before_1980 + + +print_albums_released_before_1980() diff --git a/samples/python/sqlalchemy-sample/test_print_concerts.py b/samples/python/sqlalchemy-sample/test_print_concerts.py new file mode 100644 index 000000000..c5c879c96 --- /dev/null +++ b/samples/python/sqlalchemy-sample/test_print_concerts.py @@ -0,0 +1,19 @@ +""" Copyright 2022 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from sample import print_concerts + + +print_concerts() diff --git a/samples/python/sqlalchemy-sample/test_print_singers_and_albums.py b/samples/python/sqlalchemy-sample/test_print_singers_and_albums.py new file mode 100644 index 000000000..568bc0a5b --- /dev/null +++ b/samples/python/sqlalchemy-sample/test_print_singers_and_albums.py @@ -0,0 +1,19 @@ +""" Copyright 2022 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from sample import print_singers_and_albums + + +print_singers_and_albums() diff --git a/samples/python/sqlalchemy-sample/test_print_singers_with_limit_and_offset.py b/samples/python/sqlalchemy-sample/test_print_singers_with_limit_and_offset.py new file mode 100644 index 000000000..4783bad11 --- /dev/null +++ b/samples/python/sqlalchemy-sample/test_print_singers_with_limit_and_offset.py @@ -0,0 +1,19 @@ +""" Copyright 2022 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from sample import print_singers_with_limit_and_offset + + +print_singers_with_limit_and_offset() diff --git a/samples/python/sqlalchemy-sample/test_update_singer.py b/samples/python/sqlalchemy-sample/test_update_singer.py new file mode 100644 index 000000000..e5d844c79 --- /dev/null +++ b/samples/python/sqlalchemy-sample/test_update_singer.py @@ -0,0 +1,18 @@ +""" Copyright 2022 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from sample import update_singer + +update_singer("123-456-789", "Newfirstname", "Newlastname") diff --git a/samples/python/sqlalchemy-sample/util_random_names.py b/samples/python/sqlalchemy-sample/util_random_names.py new file mode 100644 index 000000000..e5769408f --- /dev/null +++ b/samples/python/sqlalchemy-sample/util_random_names.py @@ -0,0 +1,150 @@ +""" Copyright 2022 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import secrets +from random import seed, randrange, random +from datetime import date + +seed() + + +""" + Helper functions for generating random names and titles. +""" + +def random_first_name(): + return first_names[randrange(len(first_names))] + + +def random_last_name(): + return last_names[randrange(len(last_names))] + + +def random_album_title(): + return "{} {}".format( + adjectives[randrange(len(adjectives))], nouns[randrange(len(nouns))]) + + +def random_release_date(): + return date.fromisoformat( + "{}-{:02d}-{:02d}".format(randrange(1900, 2023), + randrange(1, 13), + randrange(1, 29))) + + +def random_marketing_budget(): + return random() * 1000000 + + +def random_cover_picture(): + return secrets.token_bytes(randrange(1, 10)) + + +first_names = [ + "Saffron", "Eleanor", "Ann", "Salma", "Kiera", "Mariam", "Georgie", "Eden", "Carmen", "Darcie", + "Antony", "Benjamin", "Donald", "Keaton", "Jared", "Simon", "Tanya", "Julian", "Eugene", "Laurence" +] +last_names = [ + "Terry", "Ford", "Mills", "Connolly", "Newton", "Rodgers", "Austin", "Floyd", "Doherty", "Nguyen", + "Chavez", "Crossley", "Silva", "George", "Baldwin", "Burns", "Russell", "Ramirez", "Hunter", "Fuller" +] +adjectives = [ + "ultra", + "happy", + "emotional", + "filthy", + "charming", + "alleged", + "talented", + "exotic", + "lamentable", + "lewd", + "old-fashioned", + "savory", + "delicate", + "willing", + "habitual", + "upset", + "gainful", + "nonchalant", + "kind", + "unruly" +] +nouns = [ + "improvement", + "control", + "tennis", + "gene", + "department", + "person", + "awareness", + "health", + "development", + "platform", + "garbage", + "suggestion", + "agreement", + "knowledge", + "introduction", + "recommendation", + "driver", + "elevator", + "industry", + "extent" +] +verbs = [ + "instruct", + "rescue", + "disappear", + "import", + "inhibit", + "accommodate", + "dress", + "describe", + "mind", + "strip", + "crawl", + "lower", + "influence", + "alter", + "prove", + "race", + "label", + "exhaust", + "reach", + "remove" +] +adverbs = [ + "cautiously", + "offensively", + "immediately", + "soon", + "judgementally", + "actually", + "honestly", + "slightly", + "limply", + "rigidly", + "fast", + "normally", + "unnecessarily", + "wildly", + "unimpressively", + "helplessly", + "rightfully", + "kiddingly", + "early", + "queasily" +] diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/ConnectionHandler.java b/src/main/java/com/google/cloud/spanner/pgadapter/ConnectionHandler.java index 176a2f486..5f3fc7d57 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/ConnectionHandler.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/ConnectionHandler.java @@ -52,9 +52,11 @@ import com.google.cloud.spanner.pgadapter.wireoutput.ReadyResponse; import com.google.cloud.spanner.pgadapter.wireoutput.TerminateResponse; import com.google.cloud.spanner.pgadapter.wireprotocol.BootstrapMessage; +import com.google.cloud.spanner.pgadapter.wireprotocol.ParseMessage; import com.google.cloud.spanner.pgadapter.wireprotocol.SSLMessage; import com.google.cloud.spanner.pgadapter.wireprotocol.WireMessage; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Strings; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; import com.google.common.collect.ImmutableList; @@ -69,6 +71,7 @@ import java.time.Duration; import java.util.HashMap; import java.util.HashSet; +import java.util.Locale; import java.util.Map; import java.util.Map.Entry; import java.util.Properties; @@ -120,7 +123,7 @@ public class ConnectionHandler extends Thread { private int invalidMessagesCount; private Connection spannerConnection; private DatabaseId databaseId; - private WellKnownClient wellKnownClient; + private WellKnownClient wellKnownClient = WellKnownClient.UNSPECIFIED; private boolean hasDeterminedClientUsingQuery; private ExtendedQueryProtocolHandler extendedQueryProtocolHandler; private CopyStatement activeCopyStatement; @@ -476,12 +479,12 @@ void handleError(PGException exception) throws Exception { DataOutputStream output = getConnectionMetadata().getOutputStream(); if (this.status == ConnectionStatus.TERMINATED || this.status == ConnectionStatus.UNAUTHENTICATED) { - new ErrorResponse(output, exception).send(); + new ErrorResponse(this, exception).send(); new TerminateResponse(output).send(); } else if (this.status == ConnectionStatus.COPY_IN) { - new ErrorResponse(output, exception).send(); + new ErrorResponse(this, exception).send(); } else { - new ErrorResponse(output, exception).send(); + new ErrorResponse(this, exception).send(); new ReadyResponse(output, ReadyResponse.Status.IDLE).send(); } } @@ -704,6 +707,14 @@ public WellKnownClient getWellKnownClient() { public void setWellKnownClient(WellKnownClient wellKnownClient) { this.wellKnownClient = wellKnownClient; + if (this.wellKnownClient != WellKnownClient.UNSPECIFIED) { + logger.log( + Level.INFO, + () -> + String.format( + "Well-known client %s detected for connection %d.", + this.wellKnownClient, getConnectionId())); + } } /** @@ -713,23 +724,58 @@ public void setWellKnownClient(WellKnownClient wellKnownClient) { * executed. */ public void maybeDetermineWellKnownClient(Statement statement) { - if (!this.hasDeterminedClientUsingQuery - && this.wellKnownClient == WellKnownClient.UNSPECIFIED - && getServer().getOptions().shouldAutoDetectClient()) { - this.wellKnownClient = ClientAutoDetector.detectClient(ImmutableList.of(statement)); - if (this.wellKnownClient != WellKnownClient.UNSPECIFIED) { - logger.log( - Level.INFO, - () -> - String.format( - "Well-known client %s detected for connection %d.", - this.wellKnownClient, getConnectionId())); + if (!this.hasDeterminedClientUsingQuery) { + if (this.wellKnownClient == WellKnownClient.UNSPECIFIED + && getServer().getOptions().shouldAutoDetectClient()) { + setWellKnownClient(ClientAutoDetector.detectClient(ImmutableList.of(statement))); } + maybeSetApplicationName(); } // Make sure that we only try to detect the client once. this.hasDeterminedClientUsingQuery = true; } + /** + * This is called by the extended query protocol {@link + * com.google.cloud.spanner.pgadapter.wireprotocol.ParseMessage} to give the connection the + * opportunity to determine the client that is connected based on the data in the (first) parse + * messages. + */ + public void maybeDetermineWellKnownClient(ParseMessage parseMessage) { + if (!this.hasDeterminedClientUsingQuery) { + if (this.wellKnownClient == WellKnownClient.UNSPECIFIED + && getServer().getOptions().shouldAutoDetectClient()) { + setWellKnownClient(ClientAutoDetector.detectClient(parseMessage)); + } + maybeSetApplicationName(); + } + // Make sure that we only try to detect the client once. + this.hasDeterminedClientUsingQuery = true; + } + + private void maybeSetApplicationName() { + try { + if (this.wellKnownClient != WellKnownClient.UNSPECIFIED + && getExtendedQueryProtocolHandler() != null + && Strings.isNullOrEmpty( + getExtendedQueryProtocolHandler() + .getBackendConnection() + .getSessionState() + .get(null, "application_name") + .getSetting())) { + getExtendedQueryProtocolHandler() + .getBackendConnection() + .getSessionState() + .set(null, "application_name", wellKnownClient.name().toLowerCase(Locale.ENGLISH)); + getExtendedQueryProtocolHandler().getBackendConnection().getSessionState().commit(); + } + } catch (Throwable ignore) { + // Safeguard against a theoretical situation that 'application_name' has been removed from + // the list of settings. Just ignore this situation, as the only consequence is that the + // 'application_name' setting has not been set. + } + } + /** Status of a {@link ConnectionHandler} */ public enum ConnectionStatus { UNAUTHENTICATED, diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/ProxyServer.java b/src/main/java/com/google/cloud/spanner/pgadapter/ProxyServer.java index e620ff36f..f38c4dde4 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/ProxyServer.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/ProxyServer.java @@ -358,7 +358,7 @@ public String toString() { return String.format("ProxyServer[port: %d]", getLocalPort()); } - ConcurrentLinkedQueue getDebugMessages() { + public ConcurrentLinkedQueue getDebugMessages() { return debugMessages; } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/metadata/OptionsMetadata.java b/src/main/java/com/google/cloud/spanner/pgadapter/metadata/OptionsMetadata.java index 6abac46f3..f3b595cce 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/metadata/OptionsMetadata.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/metadata/OptionsMetadata.java @@ -88,7 +88,7 @@ public enum DdlTransactionMode { } private static final Logger logger = Logger.getLogger(OptionsMetadata.class.getName()); - private static final String DEFAULT_SERVER_VERSION = "14.1"; + public static final String DEFAULT_SERVER_VERSION = "14.1"; private static final String DEFAULT_USER_AGENT = "pg-adapter"; private static final String OPTION_SERVER_PORT = "s"; diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/parsers/Parser.java b/src/main/java/com/google/cloud/spanner/pgadapter/parsers/Parser.java index 16141cec1..4bbb1b67d 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/parsers/Parser.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/parsers/Parser.java @@ -97,9 +97,12 @@ public static Parser create( case Oid.JSONB: return new JsonbParser(item, formatCode); case Oid.UNSPECIFIED: - return new UnspecifiedParser(item, formatCode); default: - throw new IllegalArgumentException("Unsupported parameter type: " + oidType); + // Use the UnspecifiedParser for unknown types. This will encode the parameter value as a + // string and send it to Spanner without any type information. This will ensure that clients + // that for example send char instead of varchar as the type code for a parameter would + // still work. + return new UnspecifiedParser(item, formatCode); } } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/session/SessionState.java b/src/main/java/com/google/cloud/spanner/pgadapter/session/SessionState.java index 4692e590c..ca37be802 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/session/SessionState.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/session/SessionState.java @@ -374,6 +374,11 @@ public void rollback() { this.transactionSettings = null; } + /** Returns the PostgreSQL version. */ + public String getServerVersion() { + return getStringSetting(null, "server_version", OptionsMetadata.DEFAULT_SERVER_VERSION); + } + /** * Returns whether transaction statements should be ignored and all statements should be executed * in autocommit mode. diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/BackendConnection.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/BackendConnection.java index d3ba9aa36..ef1d09f76 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/BackendConnection.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/BackendConnection.java @@ -30,7 +30,6 @@ import com.google.cloud.spanner.PartitionOptions; import com.google.cloud.spanner.ReadContext.QueryAnalyzeMode; import com.google.cloud.spanner.ResultSet; -import com.google.cloud.spanner.ResultSets; import com.google.cloud.spanner.Spanner; import com.google.cloud.spanner.SpannerBatchUpdateException; import com.google.cloud.spanner.SpannerException; @@ -59,12 +58,14 @@ import com.google.cloud.spanner.pgadapter.statements.SessionStatementParser.SessionStatement; import com.google.cloud.spanner.pgadapter.statements.SimpleParser.TableOrIndexName; import com.google.cloud.spanner.pgadapter.statements.local.LocalStatement; +import com.google.cloud.spanner.pgadapter.utils.ClientAutoDetector.WellKnownClient; import com.google.cloud.spanner.pgadapter.utils.CopyDataReceiver; import com.google.cloud.spanner.pgadapter.utils.MutationWriter; import com.google.cloud.spanner.pgadapter.wireoutput.ReadyResponse; import com.google.cloud.spanner.pgadapter.wireoutput.ReadyResponse.Status; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap.Builder; @@ -87,6 +88,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import java.util.function.Function; +import java.util.function.Supplier; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -208,9 +210,10 @@ void execute() { // block always ends with a ROLLBACK, PGAdapter should skip the entire execution of that // block. SessionStatement sessionStatement = getSessionManagementStatement(parsedStatement); - if (!localStatements.isEmpty() && localStatements.containsKey(statement.getSql())) { + if (!localStatements.get().isEmpty() + && localStatements.get().containsKey(statement.getSql())) { result.set( - Objects.requireNonNull(localStatements.get(statement.getSql())) + Objects.requireNonNull(localStatements.get().get(statement.getSql())) .execute(BackendConnection.this)); } else if (sessionStatement != null) { result.set(sessionStatement.execute(sessionState)); @@ -253,7 +256,7 @@ void execute() { // Potentially replace pg_catalog table references with common table expressions. updatedStatement = sessionState.isReplacePgCatalogTables() - ? pgCatalog.replacePgCatalogTables(statement) + ? pgCatalog.get().replacePgCatalogTables(statement) : statement; updatedStatement = bindStatement(updatedStatement); result.set(analyzeOrExecute(updatedStatement)); @@ -372,7 +375,7 @@ SessionStatement getSessionManagementStatement(ParsedStatement parsedStatement) } } - private static final int MAX_PARTITIONS = + public static final int MAX_PARTITIONS = Math.max(16, 2 * Runtime.getRuntime().availableProcessors()); private final class CopyOut extends BufferedStatement { @@ -401,9 +404,14 @@ void execute() { // No need for the extra complexity of a partitioned query. result.set(spannerConnection.execute(statement)); } else { + // Get the metadata of the query, so we can include that in the result. + ResultSet metadataResultSet = + spannerConnection.analyzeQuery(statement, QueryAnalyzeMode.PLAN); result.set( new PartitionQueryResult( - batchReadOnlyTransaction.getBatchTransactionId(), partitions)); + batchReadOnlyTransaction.getBatchTransactionId(), + partitions, + metadataResultSet)); } } catch (SpannerException spannerException) { // The query might not be suitable for partitioning. Just try with a normal query. @@ -562,8 +570,8 @@ void execute() { private static final Statement ROLLBACK = Statement.of("ROLLBACK"); private final SessionState sessionState; - private final PgCatalog pgCatalog; - private final ImmutableMap localStatements; + private final Supplier pgCatalog; + private final Supplier> localStatements; private ConnectionState connectionState = ConnectionState.IDLE; private TransactionMode transactionMode = TransactionMode.IMPLICIT; private final String currentSchema = "public"; @@ -576,24 +584,31 @@ void execute() { BackendConnection( DatabaseId databaseId, Connection spannerConnection, + Supplier wellKnownClient, OptionsMetadata optionsMetadata, - ImmutableList localStatements) { + Supplier> localStatements) { this.sessionState = new SessionState(optionsMetadata); - this.pgCatalog = new PgCatalog(this.sessionState); + this.pgCatalog = + Suppliers.memoize( + () -> new PgCatalog(BackendConnection.this.sessionState, wellKnownClient.get())); this.spannerConnection = spannerConnection; this.databaseId = databaseId; this.ddlExecutor = new DdlExecutor(databaseId, this); - if (localStatements.isEmpty()) { - this.localStatements = EMPTY_LOCAL_STATEMENTS; - } else { - Builder builder = ImmutableMap.builder(); - for (LocalStatement localStatement : localStatements) { - for (String sql : localStatement.getSql()) { - builder.put(new SimpleImmutableEntry<>(sql, localStatement)); - } - } - this.localStatements = builder.build(); - } + this.localStatements = + Suppliers.memoize( + () -> { + if (localStatements.get().isEmpty()) { + return EMPTY_LOCAL_STATEMENTS; + } else { + Builder builder = ImmutableMap.builder(); + for (LocalStatement localStatement : localStatements.get()) { + for (String sql : localStatement.getSql()) { + builder.put(new SimpleImmutableEntry<>(sql, localStatement)); + } + } + return builder.build(); + } + }); } /** Returns the current connection state. */ @@ -1262,10 +1277,15 @@ public Long getUpdateCount() { public static final class PartitionQueryResult implements StatementResult { private final BatchTransactionId batchTransactionId; private final List partitions; + private final ResultSet metadataResultSet; - public PartitionQueryResult(BatchTransactionId batchTransactionId, List partitions) { + public PartitionQueryResult( + BatchTransactionId batchTransactionId, + List partitions, + ResultSet metadataResultSet) { this.batchTransactionId = batchTransactionId; this.partitions = partitions; + this.metadataResultSet = metadataResultSet; } public BatchTransactionId getBatchTransactionId() { @@ -1276,6 +1296,10 @@ public List getPartitions() { return partitions; } + public ResultSet getMetadataResultSet() { + return metadataResultSet; + } + @Override public ResultType getResultType() { return ResultType.RESULT_SET; @@ -1288,7 +1312,7 @@ public ClientSideStatementType getClientSideStatementType() { @Override public ResultSet getResultSet() { - return ResultSets.forRows( + return ClientSideResultSet.forRows( Type.struct(StructField.of("partition", Type.bytes())), partitions.stream() .map( diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/ClientSideResultSet.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/ClientSideResultSet.java new file mode 100644 index 000000000..91c92bcb7 --- /dev/null +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/ClientSideResultSet.java @@ -0,0 +1,54 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.spanner.pgadapter.statements; + +import com.google.api.core.InternalApi; +import com.google.cloud.spanner.ForwardingResultSet; +import com.google.cloud.spanner.ResultSet; +import com.google.cloud.spanner.ResultSets; +import com.google.cloud.spanner.Struct; +import com.google.cloud.spanner.Type; +import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.ResultSetStats; + +/** Wrapper class for query results that are handled directly in PGAdapter. */ +@InternalApi +public class ClientSideResultSet extends ForwardingResultSet { + public static ResultSet forRows(Type type, Iterable rows) { + return new ClientSideResultSet(ResultSets.forRows(type, rows), type); + } + + private final Type type; + + private ClientSideResultSet(ResultSet delegate, Type type) { + super(delegate); + this.type = type; + } + + @Override + public Type getType() { + return type; + } + + @Override + public ResultSetStats getStats() { + return ResultSetStats.getDefaultInstance(); + } + + @Override + public ResultSetMetadata getMetadata() { + return ResultSetMetadata.getDefaultInstance(); + } +} diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyStatement.java index 7c49c4bd1..370cd620c 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyStatement.java @@ -24,6 +24,7 @@ import com.google.cloud.spanner.connection.AbstractStatementParser.StatementType; import com.google.cloud.spanner.connection.AutocommitDmlMode; import com.google.cloud.spanner.pgadapter.ConnectionHandler; +import com.google.cloud.spanner.pgadapter.ProxyServer.DataFormat; import com.google.cloud.spanner.pgadapter.error.PGException; import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory; import com.google.cloud.spanner.pgadapter.error.SQLState; @@ -80,7 +81,17 @@ public static IntermediatePortalStatement create( public enum Format { CSV, TEXT, - BINARY, + BINARY { + @Override + public DataFormat getDataFormat() { + return DataFormat.POSTGRESQL_BINARY; + } + }; + + /** Returns the (default) data format that should be used for this copy format. */ + public DataFormat getDataFormat() { + return DataFormat.POSTGRESQL_TEXT; + } } private static final String COLUMN_NAME = "column_name"; @@ -230,8 +241,8 @@ public MutationWriter getMutationWriter() { } /** @return 0 for text/csv formatting and 1 for binary */ - public int getFormatCode() { - return (parsedCopyStatement.format == Format.BINARY) ? 1 : 0; + public byte getFormatCode() { + return (parsedCopyStatement.format == Format.BINARY) ? (byte) 1 : (byte) 0; } private void verifyCopyColumns() { @@ -552,7 +563,7 @@ static ParsedCopyStatement parse(String sql) { } ParsedCopyStatement.Builder builder = new ParsedCopyStatement.Builder(); if (parser.eatToken("(")) { - builder.query = parser.parseExpressionUntilKeyword(ImmutableList.of(), true, true); + builder.query = parser.parseExpressionUntilKeyword(ImmutableList.of(), true, true, false); if (!parser.eatToken(")")) { throw PGExceptionFactory.newPGException( "missing closing parentheses after query", SQLState.SyntaxError); diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyToStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyToStatement.java index 4c4c0f959..de2c852e1 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyToStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyToStatement.java @@ -23,7 +23,6 @@ import com.google.cloud.spanner.connection.Connection; import com.google.cloud.spanner.connection.StatementResult; import com.google.cloud.spanner.pgadapter.ConnectionHandler; -import com.google.cloud.spanner.pgadapter.ProxyServer.DataFormat; import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory; import com.google.cloud.spanner.pgadapter.error.SQLState; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; @@ -42,6 +41,7 @@ import com.google.common.util.concurrent.Futures; import java.util.List; import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; import org.apache.commons.csv.CSVFormat; import org.apache.commons.csv.QuoteMode; @@ -57,7 +57,8 @@ public class CopyToStatement extends IntermediatePortalStatement { new byte[] {'P', 'G', 'C', 'O', 'P', 'Y', '\n', -1, '\r', '\n', '\0'}; private final ParsedCopyStatement parsedCopyStatement; - private final CSVFormat csvFormat; + private CSVFormat csvFormat; + private final AtomicBoolean hasReturnedData = new AtomicBoolean(false); public CopyToStatement( ConnectionHandler connectionHandler, @@ -80,24 +81,28 @@ public CopyToStatement( if (parsedCopyStatement.format == CopyStatement.Format.BINARY) { this.csvFormat = null; } else { + CSVFormat baseFormat = + parsedCopyStatement.format == Format.TEXT + ? CSVFormat.POSTGRESQL_TEXT + : CSVFormat.POSTGRESQL_CSV; CSVFormat.Builder formatBuilder = - CSVFormat.Builder.create(CSVFormat.POSTGRESQL_TEXT) + CSVFormat.Builder.create(baseFormat) .setNullString( parsedCopyStatement.nullString == null - ? CSVFormat.POSTGRESQL_TEXT.getNullString() + ? baseFormat.getNullString() : parsedCopyStatement.nullString) .setRecordSeparator('\n') .setDelimiter( parsedCopyStatement.delimiter == null - ? CSVFormat.POSTGRESQL_TEXT.getDelimiterString().charAt(0) + ? baseFormat.getDelimiterString().charAt(0) : parsedCopyStatement.delimiter) .setQuote( parsedCopyStatement.quote == null - ? CSVFormat.POSTGRESQL_TEXT.getQuoteCharacter() + ? baseFormat.getQuoteCharacter() : parsedCopyStatement.quote) .setEscape( parsedCopyStatement.escape == null - ? CSVFormat.POSTGRESQL_TEXT.getEscapeCharacter() + ? baseFormat.getEscapeCharacter() : parsedCopyStatement.escape); if (parsedCopyStatement.format == Format.TEXT) { formatBuilder.setQuoteMode(QuoteMode.NONE); @@ -220,22 +225,25 @@ public IntermediatePortalStatement createPortal( @Override public WireOutput[] createResultPrefix(ResultSet resultSet) { - return this.parsedCopyStatement.format == CopyStatement.Format.BINARY - ? new WireOutput[] { - new CopyOutResponse( - this.outputStream, - resultSet.getColumnCount(), - DataFormat.POSTGRESQL_BINARY.getCode()), - CopyDataResponse.createBinaryHeader(this.outputStream) - } - : new WireOutput[] { - new CopyOutResponse( - this.outputStream, resultSet.getColumnCount(), DataFormat.POSTGRESQL_TEXT.getCode()) - }; + return new WireOutput[] { + new CopyOutResponse( + this.outputStream, + resultSet.getColumnCount(), + this.parsedCopyStatement.format.getDataFormat().getCode()) + }; } @Override public CopyDataResponse createDataRowResponse(Converter converter) { + // Keep track of whether this COPY statement has returned at least one row. This is necessary to + // know whether we need to include the header in the current row and/or in the trailer. + // PostgreSQL includes the header in either the first data row or in the trailer if there are no + // rows. This is not specifically mentioned in the protocol description, but some clients assume + // this behavior. See + // https://github.com/npgsql/npgsql/blob/7f97dbad28c71b2202dd7bcccd05fc42a7de23c8/src/Npgsql/NpgsqlBinaryExporter.cs#L156 + if (parsedCopyStatement.format == Format.BINARY && !hasReturnedData.getAndSet(true)) { + converter = converter.includeBinaryCopyHeader(); + } return parsedCopyStatement.format == CopyStatement.Format.BINARY ? createBinaryDataResponse(converter) : createDataResponse(converter.getResultSet()); @@ -245,7 +253,7 @@ public CopyDataResponse createDataRowResponse(Converter converter) { public WireOutput[] createResultSuffix() { return this.parsedCopyStatement.format == Format.BINARY ? new WireOutput[] { - CopyDataResponse.createBinaryTrailer(this.outputStream), + CopyDataResponse.createBinaryTrailer(this.outputStream, !hasReturnedData.get()), new CopyDoneResponse(this.outputStream) } : new WireOutput[] {new CopyDoneResponse(this.outputStream)}; @@ -268,6 +276,10 @@ CopyDataResponse createDataResponse(ResultSet resultSet) { } } String row = csvFormat.format((Object[]) data); + // Only include the header with the first row. + if (!csvFormat.getSkipHeaderRecord()) { + csvFormat = csvFormat.builder().setSkipHeaderRecord(true).build(); + } return new CopyDataResponse(this.outputStream, row, csvFormat.getRecordSeparator().charAt(0)); } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/ExtendedQueryProtocolHandler.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/ExtendedQueryProtocolHandler.java index ae61125d0..9d3c5bf7c 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/ExtendedQueryProtocolHandler.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/ExtendedQueryProtocolHandler.java @@ -40,8 +40,9 @@ public ExtendedQueryProtocolHandler(ConnectionHandler connectionHandler) { new BackendConnection( connectionHandler.getDatabaseId(), connectionHandler.getSpannerConnection(), + connectionHandler::getWellKnownClient, connectionHandler.getServer().getOptions(), - connectionHandler.getWellKnownClient().getLocalStatements(connectionHandler)); + () -> connectionHandler.getWellKnownClient().getLocalStatements(connectionHandler)); } /** Constructor only intended for testing. */ diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/PgCatalog.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/PgCatalog.java index a1d4aac4f..d806f576c 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/PgCatalog.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/PgCatalog.java @@ -20,6 +20,9 @@ import com.google.cloud.spanner.Value; import com.google.cloud.spanner.pgadapter.session.SessionState; import com.google.cloud.spanner.pgadapter.statements.SimpleParser.TableOrIndexName; +import com.google.cloud.spanner.pgadapter.utils.ClientAutoDetector.WellKnownClient; +import com.google.common.base.Preconditions; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -27,11 +30,13 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Set; +import java.util.function.Supplier; import java.util.regex.Pattern; +import javax.annotation.Nonnull; @InternalApi public class PgCatalog { - private static final ImmutableMap TABLE_REPLACEMENTS = + private static final ImmutableMap DEFAULT_TABLE_REPLACEMENTS = ImmutableMap.builder() .put( new TableOrIndexName("pg_catalog", "pg_namespace"), @@ -57,30 +62,66 @@ public class PgCatalog { .put(new TableOrIndexName(null, "pg_settings"), new TableOrIndexName(null, "pg_settings")) .build(); - private static final ImmutableMap FUNCTION_REPLACEMENTS = + private static final ImmutableMap> DEFAULT_FUNCTION_REPLACEMENTS = ImmutableMap.of( - Pattern.compile("pg_catalog.pg_table_is_visible\\(.+\\)"), "true", - Pattern.compile("pg_table_is_visible\\(.+\\)"), "true", - Pattern.compile("ANY\\(current_schemas\\(true\\)\\)"), "'public'"); + Pattern.compile("pg_catalog.pg_table_is_visible\\(.+\\)"), Suppliers.ofInstance("true"), + Pattern.compile("pg_table_is_visible\\(.+\\)"), Suppliers.ofInstance("true"), + Pattern.compile("ANY\\(current_schemas\\(true\\)\\)"), Suppliers.ofInstance("'public'")); - private final Map pgCatalogTables = + private final ImmutableSet checkPrefixes; + + private final ImmutableMap tableReplacements; + private final ImmutableMap pgCatalogTables; + + private final ImmutableMap> functionReplacements; + + private static final Map DEFAULT_PG_CATALOG_TABLES = ImmutableMap.of( new TableOrIndexName(null, "pg_namespace"), new PgNamespace(), new TableOrIndexName(null, "pg_class"), new PgClass(), new TableOrIndexName(null, "pg_proc"), new PgProc(), new TableOrIndexName(null, "pg_range"), new PgRange(), - new TableOrIndexName(null, "pg_type"), new PgType(), - new TableOrIndexName(null, "pg_settings"), new PgSettings()); + new TableOrIndexName(null, "pg_type"), new PgType()); private final SessionState sessionState; - public PgCatalog(SessionState sessionState) { - this.sessionState = sessionState; + public PgCatalog(@Nonnull SessionState sessionState, @Nonnull WellKnownClient wellKnownClient) { + this.sessionState = Preconditions.checkNotNull(sessionState); + this.checkPrefixes = wellKnownClient.getPgCatalogCheckPrefixes(); + ImmutableMap.Builder builder = + ImmutableMap.builder() + .putAll(DEFAULT_TABLE_REPLACEMENTS); + wellKnownClient + .getTableReplacements() + .forEach((k, v) -> builder.put(TableOrIndexName.parse(k), TableOrIndexName.parse(v))); + this.tableReplacements = builder.build(); + + ImmutableMap.Builder pgCatalogTablesBuilder = + ImmutableMap.builder() + .putAll(DEFAULT_PG_CATALOG_TABLES) + .put(new TableOrIndexName(null, "pg_settings"), new PgSettings()); + wellKnownClient + .getPgCatalogTables() + .forEach((k, v) -> pgCatalogTablesBuilder.put(TableOrIndexName.parse(k), v)); + this.pgCatalogTables = pgCatalogTablesBuilder.build(); + + this.functionReplacements = + ImmutableMap.>builder() + .putAll(DEFAULT_FUNCTION_REPLACEMENTS) + .put( + Pattern.compile("version\\(\\)"), () -> "'" + sessionState.getServerVersion() + "'") + .putAll(wellKnownClient.getFunctionReplacements()) + .build(); } /** Replace supported pg_catalog tables with Common Table Expressions. */ public Statement replacePgCatalogTables(Statement statement) { + // Only replace tables if the statement contains at least one of the known prefixes. + if (checkPrefixes.stream().noneMatch(prefix -> statement.getSql().contains(prefix))) { + return statement; + } + Tuple, Statement> replacedTablesStatement = - new TableParser(statement).detectAndReplaceTables(TABLE_REPLACEMENTS); + new TableParser(statement).detectAndReplaceTables(tableReplacements); if (replacedTablesStatement.x().isEmpty()) { return replacedTablesStatement.y(); } @@ -95,16 +136,19 @@ public Statement replacePgCatalogTables(Statement statement) { return addCommonTableExpressions(replacedTablesStatement.y(), cteBuilder.build()); } - static String replaceKnownUnsupportedFunctions(Statement statement) { + String replaceKnownUnsupportedFunctions(Statement statement) { String sql = statement.getSql(); - for (Entry functionReplacement : FUNCTION_REPLACEMENTS.entrySet()) { - sql = functionReplacement.getKey().matcher(sql).replaceAll(functionReplacement.getValue()); + for (Entry> functionReplacement : functionReplacements.entrySet()) { + sql = + functionReplacement + .getKey() + .matcher(sql) + .replaceAll(functionReplacement.getValue().get()); } return sql; } - static Statement addCommonTableExpressions( - Statement statement, ImmutableList tableExpressions) { + Statement addCommonTableExpressions(Statement statement, ImmutableList tableExpressions) { String sql = replaceKnownUnsupportedFunctions(statement); SimpleParser parser = new SimpleParser(sql); boolean hadCommonTableExpressions = parser.eatKeyword("with"); @@ -153,7 +197,8 @@ PgCatalogTable getPgCatalogTable(TableOrIndexName tableOrIndexName) { return null; } - private interface PgCatalogTable { + @InternalApi + public interface PgCatalogTable { String getTableExpression(); default ImmutableSet getDependencies() { @@ -378,4 +423,37 @@ public String getTableExpression() { return PG_RANGE_CTE; } } + + @InternalApi + public static class EmptyPgAttribute implements PgCatalogTable { + private static final String PG_ATTRIBUTE_CTE = + "pg_attribute as (\n" + + "select * from (" + + "select 0::bigint as attrelid, '' as attname, 0::bigint as atttypid, 0::bigint as attstattarget, " + + "0::bigint as attlen, 0::bigint as attnum, 0::bigint as attndims, -1::bigint as attcacheoff, " + + "0::bigint as atttypmod, true as attbyval, '' as attalign, '' as attstorage, '' as attcompression, " + + "false as attnotnull, true as atthasdef, false as atthasmissing, '' as attidentity, '' as attgenerated, " + + "false as attisdropped, true as attislocal, 0 as attinhcount, 0 as attcollation, '{}'::bigint[] as attacl, " + + "'{}'::text[] as attoptions, '{}'::text[] as attfdwoptions, null as attmissingval\n" + + ") a where false)"; + + @Override + public String getTableExpression() { + return PG_ATTRIBUTE_CTE; + } + } + + @InternalApi + public static class EmptyPgEnum implements PgCatalogTable { + private static final String PG_ENUM_CTE = + "pg_enum as (\n" + + "select * from (" + + "select 0::bigint as oid, 0::bigint as enumtypid, 0.0::float8 as enumsortorder, ''::varchar as enumlabel\n" + + ") e where false)"; + + @Override + public String getTableExpression() { + return PG_ENUM_CTE; + } + } } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/SessionStatementParser.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/SessionStatementParser.java index 5b0608684..2bfc7a0a3 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/SessionStatementParser.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/SessionStatementParser.java @@ -17,7 +17,6 @@ import com.google.api.client.util.Strings; import com.google.api.core.InternalApi; import com.google.cloud.spanner.ErrorCode; -import com.google.cloud.spanner.ResultSets; import com.google.cloud.spanner.SpannerExceptionFactory; import com.google.cloud.spanner.Struct; import com.google.cloud.spanner.Type; @@ -186,7 +185,7 @@ static ShowStatement createShowAll() { public StatementResult execute(SessionState sessionState) { if (name != null) { return new QueryResult( - ResultSets.forRows( + ClientSideResultSet.forRows( Type.struct(StructField.of(getKey(), Type.string())), ImmutableList.of( Struct.newBuilder() @@ -195,7 +194,7 @@ public StatementResult execute(SessionState sessionState) { .build()))); } return new QueryResult( - ResultSets.forRows( + ClientSideResultSet.forRows( Type.struct( StructField.of("name", Type.string()), StructField.of("setting", Type.string()), diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/SimpleParser.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/SimpleParser.java index ccbd8eb9c..3e5f389d2 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/SimpleParser.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/SimpleParser.java @@ -49,6 +49,19 @@ static class TableOrIndexName { /** Name is the actual object name. */ final String name; + /** + * Parses an unquoted, qualified identifier. Use this to quickly parse strings like 'foo' and + * 'foo.bar' as an identifier. + */ + static TableOrIndexName parse(String qualifiedName) { + SimpleParser parser = new SimpleParser(qualifiedName); + TableOrIndexName result = parser.readTableOrIndexName(); + if (result == null || parser.hasMoreTokens()) { + throw new IllegalArgumentException("Invalid identifier: " + qualifiedName); + } + return result; + } + static TableOrIndexName of(String name) { return new TableOrIndexName(name); } @@ -354,6 +367,15 @@ String parseExpressionUntilKeyword( ImmutableList keywords, boolean sameParensLevelAsStart, boolean stopAtEndOfExpression) { + return parseExpressionUntilKeyword( + keywords, sameParensLevelAsStart, stopAtEndOfExpression, true); + } + + String parseExpressionUntilKeyword( + ImmutableList keywords, + boolean sameParensLevelAsStart, + boolean stopAtEndOfExpression, + boolean stopAtComma) { skipWhitespaces(); int start = pos; boolean valid; @@ -366,7 +388,7 @@ String parseExpressionUntilKeyword( if (stopAtEndOfExpression && parens < 0) { break; } - } else if (stopAtEndOfExpression && parens == 0 && sql.charAt(pos) == ',') { + } else if (stopAtEndOfExpression && parens == 0 && stopAtComma && sql.charAt(pos) == ',') { break; } if ((!sameParensLevelAsStart || parens == 0) diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/DjangoGetTableNamesStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/DjangoGetTableNamesStatement.java index f92005013..127cabde2 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/DjangoGetTableNamesStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/DjangoGetTableNamesStatement.java @@ -16,12 +16,12 @@ import com.google.api.core.InternalApi; import com.google.cloud.spanner.ResultSet; -import com.google.cloud.spanner.ResultSets; import com.google.cloud.spanner.Type; import com.google.cloud.spanner.Type.StructField; import com.google.cloud.spanner.connection.StatementResult; import com.google.cloud.spanner.pgadapter.statements.BackendConnection; import com.google.cloud.spanner.pgadapter.statements.BackendConnection.QueryResult; +import com.google.cloud.spanner.pgadapter.statements.ClientSideResultSet; import com.google.common.collect.ImmutableList; /* @@ -60,7 +60,7 @@ public String[] getSql() { @Override public StatementResult execute(BackendConnection backendConnection) { ResultSet resultSet = - ResultSets.forRows( + ClientSideResultSet.forRows( Type.struct( StructField.of("relname", Type.string()), StructField.of("case", Type.string())), ImmutableList.of()); diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/ListDatabasesStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/ListDatabasesStatement.java index d9f189602..ed4a22147 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/ListDatabasesStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/ListDatabasesStatement.java @@ -20,7 +20,6 @@ import com.google.cloud.spanner.Dialect; import com.google.cloud.spanner.InstanceId; import com.google.cloud.spanner.ResultSet; -import com.google.cloud.spanner.ResultSets; import com.google.cloud.spanner.Spanner; import com.google.cloud.spanner.Struct; import com.google.cloud.spanner.Type; @@ -31,6 +30,7 @@ import com.google.cloud.spanner.pgadapter.ConnectionHandler; import com.google.cloud.spanner.pgadapter.statements.BackendConnection; import com.google.cloud.spanner.pgadapter.statements.BackendConnection.QueryResult; +import com.google.cloud.spanner.pgadapter.statements.ClientSideResultSet; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import java.util.Comparator; @@ -74,7 +74,7 @@ public StatementResult execute(BackendConnection backendConnection) { InstanceId defaultInstanceId = connectionHandler.getServer().getOptions().getDefaultInstanceId(); ResultSet resultSet = - ResultSets.forRows( + ClientSideResultSet.forRows( Type.struct( StructField.of("Name", Type.string()), StructField.of("Owner", Type.string()), diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectCurrentCatalogStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectCurrentCatalogStatement.java index aabbfdfe9..e0fcfea9a 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectCurrentCatalogStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectCurrentCatalogStatement.java @@ -16,13 +16,13 @@ import com.google.api.core.InternalApi; import com.google.cloud.spanner.ResultSet; -import com.google.cloud.spanner.ResultSets; import com.google.cloud.spanner.Struct; import com.google.cloud.spanner.Type; import com.google.cloud.spanner.Type.StructField; import com.google.cloud.spanner.connection.StatementResult; import com.google.cloud.spanner.pgadapter.statements.BackendConnection; import com.google.cloud.spanner.pgadapter.statements.BackendConnection.QueryResult; +import com.google.cloud.spanner.pgadapter.statements.ClientSideResultSet; import com.google.common.collect.ImmutableList; @InternalApi @@ -52,7 +52,7 @@ public String[] getSql() { @Override public StatementResult execute(BackendConnection backendConnection) { ResultSet resultSet = - ResultSets.forRows( + ClientSideResultSet.forRows( Type.struct(StructField.of("current_catalog", Type.string())), ImmutableList.of( Struct.newBuilder() diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectCurrentDatabaseStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectCurrentDatabaseStatement.java index 2eefe9c3e..1e173b971 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectCurrentDatabaseStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectCurrentDatabaseStatement.java @@ -16,13 +16,13 @@ import com.google.api.core.InternalApi; import com.google.cloud.spanner.ResultSet; -import com.google.cloud.spanner.ResultSets; import com.google.cloud.spanner.Struct; import com.google.cloud.spanner.Type; import com.google.cloud.spanner.Type.StructField; import com.google.cloud.spanner.connection.StatementResult; import com.google.cloud.spanner.pgadapter.statements.BackendConnection; import com.google.cloud.spanner.pgadapter.statements.BackendConnection.QueryResult; +import com.google.cloud.spanner.pgadapter.statements.ClientSideResultSet; import com.google.common.collect.ImmutableList; @InternalApi @@ -53,7 +53,7 @@ public String[] getSql() { @Override public StatementResult execute(BackendConnection backendConnection) { ResultSet resultSet = - ResultSets.forRows( + ClientSideResultSet.forRows( Type.struct(StructField.of("current_database", Type.string())), ImmutableList.of( Struct.newBuilder() diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectCurrentSchemaStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectCurrentSchemaStatement.java index bba360e03..6afd994ef 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectCurrentSchemaStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectCurrentSchemaStatement.java @@ -16,13 +16,13 @@ import com.google.api.core.InternalApi; import com.google.cloud.spanner.ResultSet; -import com.google.cloud.spanner.ResultSets; import com.google.cloud.spanner.Struct; import com.google.cloud.spanner.Type; import com.google.cloud.spanner.Type.StructField; import com.google.cloud.spanner.connection.StatementResult; import com.google.cloud.spanner.pgadapter.statements.BackendConnection; import com.google.cloud.spanner.pgadapter.statements.BackendConnection.QueryResult; +import com.google.cloud.spanner.pgadapter.statements.ClientSideResultSet; import com.google.common.collect.ImmutableList; @InternalApi @@ -50,7 +50,7 @@ public String[] getSql() { @Override public StatementResult execute(BackendConnection backendConnection) { ResultSet resultSet = - ResultSets.forRows( + ClientSideResultSet.forRows( Type.struct(StructField.of("current_schema", Type.string())), ImmutableList.of( Struct.newBuilder() diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectVersionStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectVersionStatement.java index 59980448b..3fbfe1fe3 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectVersionStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectVersionStatement.java @@ -16,13 +16,13 @@ import com.google.api.core.InternalApi; import com.google.cloud.spanner.ResultSet; -import com.google.cloud.spanner.ResultSets; import com.google.cloud.spanner.Struct; import com.google.cloud.spanner.Type; import com.google.cloud.spanner.Type.StructField; import com.google.cloud.spanner.connection.StatementResult; import com.google.cloud.spanner.pgadapter.statements.BackendConnection; import com.google.cloud.spanner.pgadapter.statements.BackendConnection.QueryResult; +import com.google.cloud.spanner.pgadapter.statements.ClientSideResultSet; import com.google.common.collect.ImmutableList; @InternalApi @@ -64,7 +64,7 @@ public String[] getSql() { @Override public StatementResult execute(BackendConnection backendConnection) { ResultSet resultSet = - ResultSets.forRows( + ClientSideResultSet.forRows( Type.struct(StructField.of("version", Type.string())), ImmutableList.of( Struct.newBuilder() diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/utils/BinaryCopyParser.java b/src/main/java/com/google/cloud/spanner/pgadapter/utils/BinaryCopyParser.java index d39e1bd71..d90bb8955 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/utils/BinaryCopyParser.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/utils/BinaryCopyParser.java @@ -230,11 +230,24 @@ public int numColumns() { return fields.length; } + @Override + public boolean isEndRecord() { + return false; + } + @Override public boolean hasColumnNames() { return false; } + @Override + public boolean isNull(int columnIndex) { + Preconditions.checkArgument( + columnIndex >= 0 && columnIndex < numColumns(), + "columnIndex must be >= 0 && < numColumns"); + return fields[columnIndex].data == null; + } + @Override public Value getValue(Type type, String columnName) { // The binary copy format does not include any column name headers or any type information. diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/utils/ClientAutoDetector.java b/src/main/java/com/google/cloud/spanner/pgadapter/utils/ClientAutoDetector.java index 44db4631c..5f9985a39 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/utils/ClientAutoDetector.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/utils/ClientAutoDetector.java @@ -17,6 +17,11 @@ import com.google.api.core.InternalApi; import com.google.cloud.spanner.Statement; import com.google.cloud.spanner.pgadapter.ConnectionHandler; +import com.google.cloud.spanner.pgadapter.error.PGException; +import com.google.cloud.spanner.pgadapter.error.SQLState; +import com.google.cloud.spanner.pgadapter.statements.PgCatalog.EmptyPgAttribute; +import com.google.cloud.spanner.pgadapter.statements.PgCatalog.EmptyPgEnum; +import com.google.cloud.spanner.pgadapter.statements.PgCatalog.PgCatalogTable; import com.google.cloud.spanner.pgadapter.statements.local.DjangoGetTableNamesStatement; import com.google.cloud.spanner.pgadapter.statements.local.ListDatabasesStatement; import com.google.cloud.spanner.pgadapter.statements.local.LocalStatement; @@ -24,10 +29,19 @@ import com.google.cloud.spanner.pgadapter.statements.local.SelectCurrentDatabaseStatement; import com.google.cloud.spanner.pgadapter.statements.local.SelectCurrentSchemaStatement; import com.google.cloud.spanner.pgadapter.statements.local.SelectVersionStatement; +import com.google.cloud.spanner.pgadapter.wireoutput.NoticeResponse; +import com.google.cloud.spanner.pgadapter.wireoutput.NoticeResponse.NoticeSeverity; +import com.google.cloud.spanner.pgadapter.wireprotocol.ParseMessage; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import java.time.Duration; import java.util.List; import java.util.Map; +import java.util.function.Supplier; +import java.util.regex.Pattern; import javax.annotation.Nonnull; import org.postgresql.core.Oid; @@ -45,6 +59,10 @@ public class ClientAutoDetector { SelectCurrentCatalogStatement.INSTANCE, SelectVersionStatement.INSTANCE, DjangoGetTableNamesStatement.INSTANCE); + private static final ImmutableSet DEFAULT_CHECK_PG_CATALOG_PREFIXES = + ImmutableSet.of("pg_"); + public static final String PGBENCH_USAGE_HINT = + "See https://github.com/GoogleCloudPlatform/pgadapter/blob/-/docs/pgbench.md for how to use pgbench with PGAdapter"; public enum WellKnownClient { PSQL { @@ -67,6 +85,48 @@ public ImmutableList getLocalStatements(ConnectionHandler connec return ImmutableList.of(new ListDatabasesStatement(connectionHandler)); } }, + PGBENCH { + final ImmutableList errorHints = ImmutableList.of(PGBENCH_USAGE_HINT); + volatile long lastHintTimestampMillis = 0L; + + @Override + public void reset() { + lastHintTimestampMillis = 0L; + } + + @Override + boolean isClient(List orderedParameterKeys, Map parameters) { + // PGBENCH makes it easy for us, as it sends its own name in the application_name parameter. + return parameters.containsKey("application_name") + && parameters.get("application_name").equals("pgbench"); + } + + @Override + public ImmutableList createStartupNoticeResponses( + ConnectionHandler connection) { + synchronized (PGBENCH) { + // Only send the hint at most once every 30 seconds, to prevent benchmark runs that open + // multiple connections from showing the hint every time. + if (Duration.ofMillis(System.currentTimeMillis() - lastHintTimestampMillis).getSeconds() + > 30L) { + lastHintTimestampMillis = System.currentTimeMillis(); + return ImmutableList.of( + new NoticeResponse( + connection.getConnectionMetadata().getOutputStream(), + SQLState.Success, + NoticeSeverity.INFO, + "Detected connection from pgbench", + PGBENCH_USAGE_HINT + "\n")); + } + } + return super.createStartupNoticeResponses(connection); + } + + @Override + public ImmutableList getErrorHints(PGException exception) { + return errorHints; + } + }, JDBC { @Override boolean isClient(List orderedParameterKeys, Map parameters) { @@ -113,8 +173,34 @@ boolean isClient(List orderedParameterKeys, Map paramete // pgx does not send enough unique parameters for it to be auto-detected. return false; } + + @Override + boolean isClient(ParseMessage parseMessage) { + // pgx uses a relatively unique naming scheme for prepared statements (and uses prepared + // statements for everything by default). + return parseMessage.getName() != null && parseMessage.getName().startsWith("lrupsc_"); + } }, NPGSQL { + final ImmutableMap tableReplacements = + ImmutableMap.of( + "pg_catalog.pg_attribute", + "pg_attribute", + "pg_attribute", + "pg_attribute", + "pg_catalog.pg_enum", + "pg_enum", + "pg_enum", + "pg_enum"); + final ImmutableMap pgCatalogTables = + ImmutableMap.of("pg_attribute", new EmptyPgAttribute(), "pg_enum", new EmptyPgEnum()); + + final ImmutableMap> functionReplacements = + ImmutableMap.of( + Pattern.compile("elemproc\\.oid = elemtyp\\.typreceive"), + Suppliers.ofInstance("false"), + Pattern.compile("proc\\.oid = typ\\.typreceive"), Suppliers.ofInstance("false")); + @Override boolean isClient(List orderedParameterKeys, Map parameters) { // npgsql does not send enough unique parameters for it to be auto-detected. @@ -134,6 +220,21 @@ boolean isClient(List statements) { + "\n" + "SELECT ns.nspname, t.oid, t.typname, t.typtype, t.typnotnull, t.elemtypoid\n"); } + + @Override + public ImmutableMap getTableReplacements() { + return tableReplacements; + } + + @Override + public ImmutableMap getPgCatalogTables() { + return pgCatalogTables; + } + + @Override + public ImmutableMap> getFunctionReplacements() { + return functionReplacements; + } }, UNSPECIFIED { @Override @@ -149,14 +250,29 @@ boolean isClient(List statements) { // defaults defined in this enum. return true; } + + @Override + boolean isClient(ParseMessage parseMessage) { + // Use UNSPECIFIED as default to prevent null checks everywhere and to ease the use of any + // defaults defined in this enum. + return true; + } }; abstract boolean isClient(List orderedParameterKeys, Map parameters); + /** Resets any cached or temporary settings for the client. */ + @VisibleForTesting + public void reset() {} + boolean isClient(List statements) { return false; } + boolean isClient(ParseMessage parseMessage) { + return false; + } + public ImmutableList getLocalStatements(ConnectionHandler connectionHandler) { if (connectionHandler.getServer().getOptions().useDefaultLocalStatements()) { return DEFAULT_LOCAL_STATEMENTS; @@ -164,6 +280,33 @@ public ImmutableList getLocalStatements(ConnectionHandler connec return EMPTY_LOCAL_STATEMENTS; } + public ImmutableSet getPgCatalogCheckPrefixes() { + return DEFAULT_CHECK_PG_CATALOG_PREFIXES; + } + + public ImmutableMap getTableReplacements() { + return ImmutableMap.of(); + } + + public ImmutableMap getPgCatalogTables() { + return ImmutableMap.of(); + } + + public ImmutableMap> getFunctionReplacements() { + return ImmutableMap.of(); + } + + /** Creates specific notice messages for a client after startup. */ + public ImmutableList createStartupNoticeResponses( + ConnectionHandler connection) { + return ImmutableList.of(); + } + + /** Returns the client-specific hint(s) that should be included with the given exception. */ + public ImmutableList getErrorHints(PGException exception) { + return ImmutableList.of(); + } + public ImmutableMap getDefaultParameters() { return ImmutableMap.of(); } @@ -198,4 +341,18 @@ public ImmutableMap getDefaultParameters() { // The following line should never be reached. throw new IllegalStateException("UNSPECIFIED.isClient() should have returned true"); } + + /** + * Returns the {@link WellKnownClient} that the detector thinks is connected to PGAdapter based on + * the Parse message that has been received. + */ + public static @Nonnull WellKnownClient detectClient(ParseMessage parseMessage) { + for (WellKnownClient client : WellKnownClient.values()) { + if (client.isClient(parseMessage)) { + return client; + } + } + // The following line should never be reached. + throw new IllegalStateException("UNSPECIFIED.isClient() should have returned true"); + } } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/utils/Converter.java b/src/main/java/com/google/cloud/spanner/pgadapter/utils/Converter.java index 48211567b..37857bd9d 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/utils/Converter.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/utils/Converter.java @@ -14,6 +14,9 @@ package com.google.cloud.spanner.pgadapter.utils; +import static com.google.cloud.spanner.pgadapter.statements.CopyToStatement.COPY_BINARY_HEADER; + +import com.google.api.core.InternalApi; import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.Type; import com.google.cloud.spanner.pgadapter.ConnectionHandler.QueryMode; @@ -38,6 +41,7 @@ import java.io.IOException; /** Utility class for converting between generic PostgreSQL conversions. */ +@InternalApi public class Converter implements AutoCloseable { private final ByteArrayOutputStream buffer = new ByteArrayOutputStream(256); private final DataOutputStream outputStream = new DataOutputStream(buffer); @@ -46,12 +50,15 @@ public class Converter implements AutoCloseable { private final OptionsMetadata options; private final ResultSet resultSet; private final SessionState sessionState; + private boolean includeBinaryCopyHeaderInFirstRow; + private boolean firstRow = true; public Converter( IntermediateStatement statement, QueryMode mode, OptionsMetadata options, - ResultSet resultSet) { + ResultSet resultSet, + boolean includeBinaryCopyHeaderInFirstRow) { this.statement = statement; this.mode = mode; this.options = options; @@ -62,6 +69,16 @@ public Converter( .getExtendedQueryProtocolHandler() .getBackendConnection() .getSessionState(); + this.includeBinaryCopyHeaderInFirstRow = includeBinaryCopyHeaderInFirstRow; + } + + public Converter includeBinaryCopyHeader() { + this.includeBinaryCopyHeaderInFirstRow = true; + return this; + } + + public boolean isIncludeBinaryCopyHeaderInFirstRow() { + return this.includeBinaryCopyHeaderInFirstRow; } @Override @@ -83,6 +100,12 @@ public int convertResultSetRowToDataRowResponse() throws IOException { : DataFormat.POSTGRESQL_TEXT; } buffer.reset(); + if (includeBinaryCopyHeaderInFirstRow && firstRow) { + outputStream.write(COPY_BINARY_HEADER); + outputStream.writeInt(0); // flags + outputStream.writeInt(0); // header extension area length + } + firstRow = false; outputStream.writeShort(resultSet.getColumnCount()); for (int column_index = 0; /* column indices start at 0 */ column_index < resultSet.getColumnCount(); diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/utils/CopyDataReceiver.java b/src/main/java/com/google/cloud/spanner/pgadapter/utils/CopyDataReceiver.java index cba2c259c..79d750b28 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/utils/CopyDataReceiver.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/utils/CopyDataReceiver.java @@ -65,7 +65,7 @@ void handleCopy() throws Exception { this.connectionHandler.setActiveCopyStatement(copyStatement); new CopyInResponse( this.connectionHandler.getConnectionMetadata().getOutputStream(), - copyStatement.getTableColumns().size(), + (short) copyStatement.getTableColumns().size(), copyStatement.getFormatCode()) .send(); ConnectionStatus initialConnectionStatus = this.connectionHandler.getStatus(); diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/utils/CopyRecord.java b/src/main/java/com/google/cloud/spanner/pgadapter/utils/CopyRecord.java index e6c252aa4..192c46a8b 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/utils/CopyRecord.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/utils/CopyRecord.java @@ -28,12 +28,18 @@ public interface CopyRecord { /** Returns the number of columns in the record. */ int numColumns(); + /** Returns true if this record is the PG end record (\.). */ + boolean isEndRecord(); + /** * Returns true if the copy record has column names. The {@link #getValue(Type, String)} method * can only be used for records that have column names. */ boolean hasColumnNames(); + /** Returns true if the value of the given column is null. */ + boolean isNull(int columnIndex); + /** * Returns the value of the given column as a Cloud Spanner {@link Value} of the given type. This * method is used by a COPY ... FROM ... operation to convert a value to the type of the column diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/utils/CsvCopyParser.java b/src/main/java/com/google/cloud/spanner/pgadapter/utils/CsvCopyParser.java index 9265faf6f..f6cde0d01 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/utils/CsvCopyParser.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/utils/CsvCopyParser.java @@ -33,6 +33,7 @@ import java.nio.charset.StandardCharsets; import java.time.format.DateTimeParseException; import java.util.Iterator; +import java.util.Objects; import java.util.logging.Level; import java.util.logging.Logger; import org.apache.commons.codec.binary.Hex; @@ -97,11 +98,25 @@ public int numColumns() { return record.size(); } + @Override + public boolean isEndRecord() { + // End of data can be represented by a single line containing just backslash-period (\.). An + // end-of-data marker is not necessary when reading from a file, since the end of file serves + // perfectly well; it is needed only when copying data to or from client applications using + // pre-3.0 client protocol. + return record.size() == 1 && Objects.equals("\\.", record.get(0)); + } + @Override public boolean hasColumnNames() { return this.hasHeader; } + @Override + public boolean isNull(int columnIndex) { + return record.get(columnIndex) == null; + } + @Override public Value getValue(Type type, String columnName) throws SpannerException { String recordValue = record.get(columnName); diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/utils/MutationWriter.java b/src/main/java/com/google/cloud/spanner/pgadapter/utils/MutationWriter.java index 55eeda095..bdb6b6d1a 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/utils/MutationWriter.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/utils/MutationWriter.java @@ -66,6 +66,7 @@ import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nonnull; @@ -109,6 +110,8 @@ public enum CopyTransactionMode { private final CSVFormat csvFormat; private final boolean hasHeader; private final CountDownLatch pipeCreatedLatch = new CountDownLatch(1); + private final CountDownLatch dataReceivedLatch = new CountDownLatch(1); + private final AtomicLong bytesReceived = new AtomicLong(); private final PipedOutputStream payload = new PipedOutputStream(); private final AtomicBoolean commit = new AtomicBoolean(false); private final AtomicBoolean rollback = new AtomicBoolean(false); @@ -164,6 +167,8 @@ public void addCopyData(byte[] payload) { } try { pipeCreatedLatch.await(); + bytesReceived.addAndGet(payload.length); + dataReceivedLatch.countDown(); this.payload.write(payload); } catch (InterruptedException | InterruptedIOException interruptedIOException) { // The IO operation was interrupted. This indicates that the user wants to cancel the COPY @@ -204,6 +209,7 @@ public void rollback() { public void close() throws IOException { this.payload.close(); this.closedLatch.countDown(); + this.dataReceivedLatch.countDown(); } @Override @@ -228,14 +234,21 @@ public StatementResult call() throws Exception { // before finishing, to ensure that all data has been written before we signal that we are done. List> allCommitFutures = new ArrayList<>(); try { + // Wait until we know whether we actually will receive any data. It could be that it is an + // empty copy operation, and we should then end early. + dataReceivedLatch.await(); + Iterator iterator = parser.iterator(); List mutations = new ArrayList<>(); long currentBufferByteSize = 0L; // Note: iterator.hasNext() blocks if there is not enough data in the pipeline to construct a // complete record. It returns false if the stream has been closed and all records have been // returned. - while (!rollback.get() && iterator.hasNext()) { + while (bytesReceived.get() > 0L && !rollback.get() && iterator.hasNext()) { CopyRecord record = iterator.next(); + if (record.isEndRecord()) { + break; + } if (record.numColumns() != this.tableColumns.keySet().size()) { throw PGExceptionFactory.newPGException( "Invalid COPY data: Row length mismatched. Expected " diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/CopyDataResponse.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/CopyDataResponse.java index c91250ba1..871721d8e 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/CopyDataResponse.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/CopyDataResponse.java @@ -28,9 +28,9 @@ public class CopyDataResponse extends WireOutput { @InternalApi public enum ResponseType { - HEADER, ROW, TRAILER, + TRAILER_WITH_HEADER, } private final ResponseType responseType; @@ -39,16 +39,14 @@ public enum ResponseType { private final char rowTerminator; private final Converter converter; - /** Creates a {@link CopyDataResponse} message containing the fixed binary COPY header. */ - @InternalApi - public static CopyDataResponse createBinaryHeader(DataOutputStream output) { - return new CopyDataResponse(output, COPY_BINARY_HEADER.length + 8, ResponseType.HEADER); - } - /** Creates a {@link CopyDataResponse} message containing the fixed binary COPY trailer. */ @InternalApi - public static CopyDataResponse createBinaryTrailer(DataOutputStream output) { - return new CopyDataResponse(output, 2, ResponseType.TRAILER); + public static CopyDataResponse createBinaryTrailer( + DataOutputStream output, boolean includeHeader) { + return new CopyDataResponse( + output, + includeHeader ? 21 : 2, + includeHeader ? ResponseType.TRAILER_WITH_HEADER : ResponseType.TRAILER); } private CopyDataResponse(DataOutputStream output, int length, ResponseType responseType) { @@ -78,6 +76,7 @@ public CopyDataResponse(DataOutputStream output, Converter converter) { this.converter = converter; } + @Override public void send(boolean flush) throws Exception { if (converter != null) { this.length = 4 + converter.convertResultSetRowToDataRowResponse(); @@ -91,10 +90,11 @@ protected void sendPayload() throws Exception { this.outputStream.write(this.stringData.getBytes(StandardCharsets.UTF_8)); this.outputStream.write(this.rowTerminator); } else if (this.format == DataFormat.POSTGRESQL_BINARY) { - if (this.responseType == ResponseType.HEADER) { + if (this.responseType == ResponseType.TRAILER_WITH_HEADER) { this.outputStream.write(COPY_BINARY_HEADER); this.outputStream.writeInt(0); // flags this.outputStream.writeInt(0); // header extension area length + this.outputStream.writeShort(-1); } else if (this.responseType == ResponseType.TRAILER) { this.outputStream.writeShort(-1); } else if (this.converter != null) { diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/CopyInResponse.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/CopyInResponse.java index 58eee8e67..13c064eeb 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/CopyInResponse.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/CopyInResponse.java @@ -35,23 +35,25 @@ public class CopyInResponse extends WireOutput { protected static final char IDENTIFIER = 'G'; - private final int numColumns; - private final int formatCode; - private final byte[] columnFormat; + private final short numColumns; + private final byte formatCode; + private final short[] columnFormat; - public CopyInResponse(DataOutputStream output, int numColumns, int formatCode) { + public CopyInResponse(DataOutputStream output, short numColumns, byte formatCode) { super(output, calculateLength(numColumns)); this.numColumns = numColumns; this.formatCode = formatCode; - columnFormat = new byte[COLUMN_NUM_LENGTH * this.numColumns]; - Arrays.fill(columnFormat, (byte) formatCode); + this.columnFormat = new short[this.numColumns]; + Arrays.fill(this.columnFormat, formatCode); } @Override protected void sendPayload() throws IOException { this.outputStream.writeByte(this.formatCode); this.outputStream.writeShort(this.numColumns); - this.outputStream.write(this.columnFormat); + for (int i = 0; i < this.numColumns; i++) { + this.outputStream.writeShort(this.columnFormat[i]); + } } @Override diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/ErrorResponse.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/ErrorResponse.java index e37964fd7..35502d2d0 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/ErrorResponse.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/ErrorResponse.java @@ -15,9 +15,11 @@ package com.google.cloud.spanner.pgadapter.wireoutput; import com.google.api.core.InternalApi; +import com.google.cloud.spanner.pgadapter.ConnectionHandler; import com.google.cloud.spanner.pgadapter.error.PGException; +import com.google.cloud.spanner.pgadapter.utils.ClientAutoDetector.WellKnownClient; import com.google.common.base.Strings; -import java.io.DataOutputStream; +import com.google.common.collect.ImmutableList; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.text.MessageFormat; @@ -25,6 +27,7 @@ /** Sends error information back to client. */ @InternalApi public class ErrorResponse extends WireOutput { + private static final byte[] EMPTY_BYTE_ARRAY = new byte[0]; private static final int HEADER_LENGTH = 4; private static final int FIELD_IDENTIFIER_LENGTH = 1; @@ -36,58 +39,83 @@ public class ErrorResponse extends WireOutput { private static final byte HINT_FLAG = 'H'; private static final byte NULL_TERMINATOR = 0; - private final byte[] severity; - private final byte[] errorMessage; - private final byte[] errorState; - private final byte[] hints; - - public ErrorResponse(DataOutputStream output, PGException pgException) { - super(output, calculateLength(pgException)); - this.errorMessage = pgException.getMessage().getBytes(StandardCharsets.UTF_8); - this.errorState = pgException.getSQLState().getBytes(); - this.severity = pgException.getSeverity().name().getBytes(StandardCharsets.UTF_8); - this.hints = - Strings.isNullOrEmpty(pgException.getHints()) - ? null - : pgException.getHints().getBytes(StandardCharsets.UTF_8); + private final PGException pgException; + private final WellKnownClient client; + + public ErrorResponse(ConnectionHandler connection, PGException pgException) { + super( + connection.getConnectionMetadata().getOutputStream(), + calculateLength(pgException, connection.getWellKnownClient())); + this.pgException = pgException; + this.client = connection.getWellKnownClient(); } - static int calculateLength(PGException pgException) { + static int calculateLength(PGException pgException, WellKnownClient client) { int length = HEADER_LENGTH + FIELD_IDENTIFIER_LENGTH - + pgException.getSeverity().name().getBytes(StandardCharsets.UTF_8).length + + convertSeverityToWireProtocol(pgException).length + NULL_TERMINATOR_LENGTH + FIELD_IDENTIFIER_LENGTH - + pgException.getSQLState().getBytes().length + + convertSQLStateToWireProtocol(pgException).length + NULL_TERMINATOR_LENGTH + FIELD_IDENTIFIER_LENGTH - + pgException.getMessage().getBytes(StandardCharsets.UTF_8).length + + convertMessageToWireProtocol(pgException).length + NULL_TERMINATOR_LENGTH + NULL_TERMINATOR_LENGTH; - if (!Strings.isNullOrEmpty(pgException.getHints())) { - length += - FIELD_IDENTIFIER_LENGTH - + pgException.getHints().getBytes(StandardCharsets.UTF_8).length - + NULL_TERMINATOR_LENGTH; + byte[] hints = convertHintsToWireProtocol(pgException, client); + if (hints.length > 0) { + length += FIELD_IDENTIFIER_LENGTH + hints.length + NULL_TERMINATOR_LENGTH; } return length; } + static byte[] convertSeverityToWireProtocol(PGException pgException) { + return pgException.getSeverity().name().getBytes(StandardCharsets.UTF_8); + } + + static byte[] convertSQLStateToWireProtocol(PGException pgException) { + return pgException.getSQLState().getBytes(); + } + + static byte[] convertMessageToWireProtocol(PGException pgException) { + return pgException.getMessage().getBytes(StandardCharsets.UTF_8); + } + + static byte[] convertHintsToWireProtocol(PGException pgException, WellKnownClient client) { + if (Strings.isNullOrEmpty(pgException.getHints()) + && client.getErrorHints(pgException).isEmpty()) { + return EMPTY_BYTE_ARRAY; + } + String hints = ""; + if (!Strings.isNullOrEmpty(pgException.getHints())) { + hints += pgException.getHints(); + } + ImmutableList clientHints = client.getErrorHints(pgException); + if (!clientHints.isEmpty()) { + if (hints.length() > 0) { + hints += "\n"; + } + hints += String.join("\n", clientHints); + } + return hints.getBytes(StandardCharsets.UTF_8); + } + @Override protected void sendPayload() throws IOException { this.outputStream.writeByte(SEVERITY_FLAG); - this.outputStream.write(severity); + this.outputStream.write(convertSeverityToWireProtocol(pgException)); this.outputStream.writeByte(NULL_TERMINATOR); this.outputStream.writeByte(CODE_FLAG); - this.outputStream.write(this.errorState); + this.outputStream.write(convertSQLStateToWireProtocol(pgException)); this.outputStream.writeByte(NULL_TERMINATOR); this.outputStream.writeByte(MESSAGE_FLAG); - this.outputStream.write(this.errorMessage); + this.outputStream.write(convertMessageToWireProtocol(pgException)); this.outputStream.writeByte(NULL_TERMINATOR); - if (this.hints != null) { + byte[] hints = convertHintsToWireProtocol(pgException, client); + if (hints.length > 0) { this.outputStream.writeByte(HINT_FLAG); - this.outputStream.write(this.hints); + this.outputStream.write(hints); this.outputStream.writeByte(NULL_TERMINATOR); } this.outputStream.writeByte(NULL_TERMINATOR); @@ -109,8 +137,10 @@ protected String getPayloadString() { .format( new Object[] { this.length, - new String(this.errorMessage, UTF8), - this.hints == null ? "" : new String(this.hints, UTF8) + new String(convertMessageToWireProtocol(pgException), UTF8), + convertHintsToWireProtocol(pgException, client).length == 0 + ? "" + : new String(convertHintsToWireProtocol(pgException, client), UTF8) }); } } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/NoticeResponse.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/NoticeResponse.java new file mode 100644 index 000000000..bcf3d7485 --- /dev/null +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/NoticeResponse.java @@ -0,0 +1,129 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.spanner.pgadapter.wireoutput; + +import com.google.api.core.InternalApi; +import com.google.cloud.spanner.pgadapter.error.SQLState; +import com.google.common.base.Preconditions; +import java.io.DataOutputStream; +import java.nio.charset.StandardCharsets; +import java.text.MessageFormat; + +/** + * Notices can be sent as asynchronous messages and can include warnings, informational messages, + * debug information, etc. + */ +@InternalApi +public class NoticeResponse extends WireOutput { + public enum NoticeSeverity { + WARNING, + NOTICE, + DEBUG, + INFO, + LOG, + TIP, + } + + private static final int HEADER_LENGTH = 4; + private static final int FIELD_IDENTIFIER_LENGTH = 1; + private static final int NULL_TERMINATOR_LENGTH = 1; + + private static final byte MESSAGE_FLAG = 'M'; + private static final byte CODE_FLAG = 'C'; + private static final byte SEVERITY_FLAG = 'S'; + private static final byte HINT_FLAG = 'H'; + private static final byte NULL_TERMINATOR = 0; + + private SQLState sqlState; + private final NoticeSeverity severity; + private final String message; + private final String hint; + + public NoticeResponse( + DataOutputStream output, + SQLState sqlState, + NoticeSeverity severity, + String message, + String hint) { + super( + output, + HEADER_LENGTH + + FIELD_IDENTIFIER_LENGTH + + Preconditions.checkNotNull(severity).name().getBytes(StandardCharsets.UTF_8).length + + NULL_TERMINATOR_LENGTH + + FIELD_IDENTIFIER_LENGTH + + sqlState.getBytes().length + + NULL_TERMINATOR_LENGTH + + FIELD_IDENTIFIER_LENGTH + + Preconditions.checkNotNull(message).getBytes(StandardCharsets.UTF_8).length + + NULL_TERMINATOR_LENGTH + + (hint == null + ? 0 + : (FIELD_IDENTIFIER_LENGTH + + hint.getBytes(StandardCharsets.UTF_8).length + + NULL_TERMINATOR_LENGTH)) + + NULL_TERMINATOR_LENGTH); + this.sqlState = sqlState; + this.severity = severity; + this.message = message; + this.hint = hint; + } + + @Override + protected void sendPayload() throws Exception { + this.outputStream.writeByte(SEVERITY_FLAG); + this.outputStream.write(severity.name().getBytes(StandardCharsets.UTF_8)); + this.outputStream.writeByte(NULL_TERMINATOR); + this.outputStream.writeByte(CODE_FLAG); + this.outputStream.write(sqlState.getBytes()); + this.outputStream.writeByte(NULL_TERMINATOR); + this.outputStream.writeByte(MESSAGE_FLAG); + this.outputStream.write(message.getBytes(StandardCharsets.UTF_8)); + this.outputStream.writeByte(NULL_TERMINATOR); + if (this.hint != null) { + this.outputStream.writeByte(HINT_FLAG); + this.outputStream.write(hint.getBytes(StandardCharsets.UTF_8)); + this.outputStream.writeByte(NULL_TERMINATOR); + } + this.outputStream.writeByte(NULL_TERMINATOR); + } + + @Override + public byte getIdentifier() { + return 'N'; + } + + @Override + protected String getMessageName() { + return "Notice"; + } + + public String getMessage() { + return this.message; + } + + public String getHint() { + return this.hint; + } + + @Override + protected String getPayloadString() { + return new MessageFormat("Length: {0}, Severity: {1}, Notice Message: {2}, Hint: {3}") + .format( + new Object[] { + this.length, this.severity, this.message, this.hint == null ? "" : this.hint + }); + } +} diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/BootstrapMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/BootstrapMessage.java index 8cc11a6c9..e4ac643e0 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/BootstrapMessage.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/BootstrapMessage.java @@ -19,6 +19,7 @@ import com.google.cloud.spanner.pgadapter.session.SessionState; import com.google.cloud.spanner.pgadapter.wireoutput.AuthenticationOkResponse; import com.google.cloud.spanner.pgadapter.wireoutput.KeyDataResponse; +import com.google.cloud.spanner.pgadapter.wireoutput.NoticeResponse; import com.google.cloud.spanner.pgadapter.wireoutput.ParameterStatusResponse; import com.google.cloud.spanner.pgadapter.wireoutput.ReadyResponse; import com.google.cloud.spanner.pgadapter.wireoutput.ReadyResponse.Status; @@ -107,7 +108,11 @@ protected List parseParameterKeys(String rawParameters) { * @throws Exception */ public static void sendStartupMessage( - DataOutputStream output, int connectionId, int secret, SessionState sessionState) + DataOutputStream output, + int connectionId, + int secret, + SessionState sessionState, + Iterable startupNotices) throws Exception { new AuthenticationOkResponse(output).send(false); new KeyDataResponse(output, connectionId, secret).send(false); @@ -166,6 +171,9 @@ public static void sendStartupMessage( "TimeZone".getBytes(StandardCharsets.UTF_8), ZoneId.systemDefault().getId().getBytes(StandardCharsets.UTF_8)) .send(false); + for (NoticeResponse noticeResponse : startupNotices) { + noticeResponse.send(false); + } new ReadyResponse(output, Status.IDLE).send(); } } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ControlMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ControlMessage.java index e233eb411..67ea96f8f 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ControlMessage.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ControlMessage.java @@ -38,6 +38,7 @@ import com.google.cloud.spanner.pgadapter.error.Severity; import com.google.cloud.spanner.pgadapter.metadata.SendResultSetState; import com.google.cloud.spanner.pgadapter.statements.BackendConnection.PartitionQueryResult; +import com.google.cloud.spanner.pgadapter.statements.CopyToStatement; import com.google.cloud.spanner.pgadapter.statements.IntermediateStatement; import com.google.cloud.spanner.pgadapter.utils.Converter; import com.google.cloud.spanner.pgadapter.wireoutput.CommandCompleteResponse; @@ -51,7 +52,6 @@ import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.MoreExecutors; -import com.google.common.util.concurrent.SettableFuture; import io.grpc.Context; import io.grpc.MethodDescriptor; import java.io.DataInputStream; @@ -59,6 +59,7 @@ import java.util.ArrayList; import java.util.List; import java.util.concurrent.Callable; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; import java.util.logging.Level; @@ -178,7 +179,7 @@ public static ControlMessage create(ConnectionHandler connection) throws Excepti connection.increaseInvalidMessageCount(); if (connection.getInvalidMessageCount() > MAX_INVALID_MESSAGE_COUNT) { new ErrorResponse( - connection.getConnectionMetadata().getOutputStream(), + connection, PGException.newBuilder( String.format( "Received %d invalid/unexpected messages. Last received message: '%c'", @@ -232,7 +233,7 @@ static PreparedType prepareType(char type) { * @throws Exception if there is some issue in the sending of the error messages. */ protected void handleError(Exception exception) throws Exception { - new ErrorResponse(this.outputStream, PGExceptionFactory.toPGException(exception)).send(false); + new ErrorResponse(this.connection, PGExceptionFactory.toPGException(exception)).send(false); } /** @@ -312,6 +313,7 @@ SendResultSetState sendResultSet( if (statementResult instanceof PartitionQueryResult) { hasData = false; PartitionQueryResult partitionQueryResult = (PartitionQueryResult) statementResult; + sendPrefix(describedResult, ((PartitionQueryResult) statementResult).getMetadataResultSet()); rows = sendPartitionedQuery( describedResult, @@ -321,17 +323,28 @@ SendResultSetState sendResultSet( } else { hasData = describedResult.isHasMoreData(); ResultSet resultSet = describedResult.getStatementResult().getResultSet(); + sendPrefix(describedResult, resultSet); SendResultSetRunnable runnable = - SendResultSetRunnable.forResultSet( - describedResult, resultSet, maxRows, mode, SettableFuture.create(), hasData); + SendResultSetRunnable.forResultSet(describedResult, resultSet, maxRows, mode, hasData); rows = runnable.call(); hasData = runnable.hasData; } + sendSuffix(describedResult); + return new SendResultSetState(describedResult.getCommandTag(), rows, hasData); + } + + private void sendPrefix(IntermediateStatement describedResult, ResultSet resultSet) + throws Exception { + for (WireOutput prefix : describedResult.createResultPrefix(resultSet)) { + prefix.send(false); + } + } + + private void sendSuffix(IntermediateStatement describedResult) throws Exception { for (WireOutput suffix : describedResult.createResultSuffix()) { suffix.send(false); } - return new SendResultSetState(describedResult.getCommandTag(), rows, hasData); } long sendPartitionedQuery( @@ -344,7 +357,6 @@ long sendPartitionedQuery( Executors.newFixedThreadPool( Math.min(8 * Runtime.getRuntime().availableProcessors(), partitions.size()))); List> futures = new ArrayList<>(partitions.size()); - SettableFuture prefixSent = SettableFuture.create(); Connection spannerConnection = connection.getSpannerConnection(); Spanner spanner = ConnectionOptionsHelper.getSpanner(spannerConnection); BatchClient batchClient = spanner.getBatchClient(connection.getDatabaseId()); @@ -361,6 +373,10 @@ public ApiCallContext configure( return GrpcCallContext.createDefault().withTimeout(Duration.ofHours(24L)); } }); + CountDownLatch binaryCopyHeaderSentLatch = + describedResult instanceof CopyToStatement && ((CopyToStatement) describedResult).isBinary() + ? new CountDownLatch(1) + : new CountDownLatch(0); for (int i = 0; i < partitions.size(); i++) { futures.add( executorService.submit( @@ -370,8 +386,7 @@ public ApiCallContext configure( batchReadOnlyTransaction, partitions.get(i), mode, - i == 0, - prefixSent)))); + binaryCopyHeaderSentLatch)))); } executorService.shutdown(); try { @@ -403,8 +418,7 @@ static final class SendResultSetRunnable implements Callable { private final Partition partition; private final long maxRows; private final QueryMode mode; - private final boolean includePrefix; - private final SettableFuture prefixSent; + private final CountDownLatch binaryCopyHeaderSentLatch; private boolean hasData; static SendResultSetRunnable forResultSet( @@ -412,10 +426,8 @@ static SendResultSetRunnable forResultSet( ResultSet resultSet, long maxRows, QueryMode mode, - SettableFuture prefixSent, boolean hasData) { - return new SendResultSetRunnable( - describedResult, resultSet, maxRows, mode, true, prefixSent, hasData); + return new SendResultSetRunnable(describedResult, resultSet, maxRows, mode, true, hasData); } static SendResultSetRunnable forPartition( @@ -423,10 +435,9 @@ static SendResultSetRunnable forPartition( BatchReadOnlyTransaction batchReadOnlyTransaction, Partition partition, QueryMode mode, - boolean includePrefix, - SettableFuture prefixSent) { + CountDownLatch binaryCopyHeaderSentLatch) { return new SendResultSetRunnable( - describedResult, batchReadOnlyTransaction, partition, mode, includePrefix, prefixSent); + describedResult, batchReadOnlyTransaction, partition, mode, binaryCopyHeaderSentLatch); } private SendResultSetRunnable( @@ -435,7 +446,6 @@ private SendResultSetRunnable( long maxRows, QueryMode mode, boolean includePrefix, - SettableFuture prefixSent, boolean hasData) { this.describedResult = describedResult; this.resultSet = resultSet; @@ -444,13 +454,15 @@ private SendResultSetRunnable( describedResult, mode, describedResult.getConnectionHandler().getServer().getOptions(), - resultSet); + resultSet, + includePrefix + && describedResult instanceof CopyToStatement + && ((CopyToStatement) describedResult).isBinary()); this.batchReadOnlyTransaction = null; this.partition = null; this.maxRows = maxRows; this.mode = mode; - this.includePrefix = includePrefix; - this.prefixSent = prefixSent; + this.binaryCopyHeaderSentLatch = new CountDownLatch(0); this.hasData = hasData; } @@ -459,16 +471,14 @@ private SendResultSetRunnable( BatchReadOnlyTransaction batchReadOnlyTransaction, Partition partition, QueryMode mode, - boolean includePrefix, - SettableFuture prefixSent) { + CountDownLatch binaryCopyHeaderSentLatch) { this.describedResult = describedResult; this.resultSet = null; this.batchReadOnlyTransaction = batchReadOnlyTransaction; this.partition = partition; this.maxRows = 0L; this.mode = mode; - this.includePrefix = includePrefix; - this.prefixSent = prefixSent; + this.binaryCopyHeaderSentLatch = binaryCopyHeaderSentLatch; this.hasData = false; } @@ -478,46 +488,30 @@ public Long call() throws Exception { if (resultSet == null && batchReadOnlyTransaction != null && partition != null) { // Note: It is OK not to close this result set, as the underlying transaction and session // will be cleaned up at a later moment. - try { - resultSet = batchReadOnlyTransaction.execute(partition); - converter = - new Converter( - describedResult, - mode, - describedResult.getConnectionHandler().getServer().getOptions(), - resultSet); - hasData = resultSet.next(); - } catch (Throwable t) { - if (includePrefix) { - synchronized (describedResult) { - prefixSent.setException(t); - } - } - throw t; - } - } - if (includePrefix) { - try { - for (WireOutput prefix : describedResult.createResultPrefix(resultSet)) { - prefix.send(false); - } - prefixSent.set(true); - } catch (Throwable t) { - prefixSent.setException(t); - throw t; - } + resultSet = batchReadOnlyTransaction.execute(partition); + hasData = resultSet.next(); + converter = + new Converter( + describedResult, + mode, + describedResult.getConnectionHandler().getServer().getOptions(), + resultSet, + false); } - // Wait until the prefix (if any) has been sent. - prefixSent.get(); long rows = 0L; while (hasData) { - if (Thread.interrupted()) { - throw PGExceptionFactory.newQueryCancelledException(); - } WireOutput wireOutput = describedResult.createDataRowResponse(converter); + if (!converter.isIncludeBinaryCopyHeaderInFirstRow()) { + binaryCopyHeaderSentLatch.await(); + } synchronized (describedResult) { wireOutput.send(false); } + binaryCopyHeaderSentLatch.countDown(); + if (Thread.interrupted()) { + throw PGExceptionFactory.newQueryCancelledException(); + } + rows++; hasData = resultSet.next(); if (rows % 1000 == 0) { diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ParseMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ParseMessage.java index 611c878d2..0eddd36ce 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ParseMessage.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ParseMessage.java @@ -64,6 +64,7 @@ public ParseMessage(ConnectionHandler connection) throws Exception { } this.statement = createStatement(connection, name, parsedStatement, originalStatement, parameterDataTypes); + connection.maybeDetermineWellKnownClient(this); } /** diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/PasswordMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/PasswordMessage.java index a289b44a4..45d29d627 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/PasswordMessage.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/PasswordMessage.java @@ -70,7 +70,7 @@ public PasswordMessage(ConnectionHandler connection, Map paramet protected void sendPayload() throws Exception { if (!useAuthentication()) { new ErrorResponse( - this.outputStream, + this.connection, PGException.newBuilder("Received PasswordMessage while authentication is disabled.") .setSQLState(SQLState.ProtocolViolation) .setSeverity(Severity.ERROR) @@ -83,7 +83,7 @@ protected void sendPayload() throws Exception { Credentials credentials = checkCredentials(this.username, this.password); if (credentials == null) { new ErrorResponse( - this.outputStream, + this.connection, PGException.newBuilder("Invalid credentials received.") .setHints( "PGAdapter expects credentials to be one of the following:\n" diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/StartupMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/StartupMessage.java index c2c7f0016..be2832f4f 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/StartupMessage.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/StartupMessage.java @@ -26,7 +26,6 @@ import java.text.MessageFormat; import java.util.Map; import java.util.Map.Entry; -import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -53,14 +52,6 @@ public StartupMessage(ConnectionHandler connection, int length) throws Exception WellKnownClient wellKnownClient = ClientAutoDetector.detectClient(this.parseParameterKeys(rawParameters), this.parameters); connection.setWellKnownClient(wellKnownClient); - if (wellKnownClient != WellKnownClient.UNSPECIFIED) { - logger.log( - Level.INFO, - () -> - String.format( - "Well-known client %s detected for connection %d.", - wellKnownClient, connection.getConnectionId())); - } } else { connection.setWellKnownClient(WellKnownClient.UNSPECIFIED); } @@ -96,7 +87,8 @@ static void createConnectionAndSendStartupMessage( connection.getConnectionMetadata().getOutputStream(), connection.getConnectionId(), connection.getSecret(), - connection.getExtendedQueryProtocolHandler().getBackendConnection().getSessionState()); + connection.getExtendedQueryProtocolHandler().getBackendConnection().getSessionState(), + connection.getWellKnownClient().createStartupNoticeResponses(connection)); connection.setStatus(ConnectionStatus.AUTHENTICATED); } diff --git a/src/test/csharp/pgadapter_npgsql_tests/npgsql_tests/NpgsqlTest.cs b/src/test/csharp/pgadapter_npgsql_tests/npgsql_tests/NpgsqlTest.cs index 0d1665900..d499edb4c 100644 --- a/src/test/csharp/pgadapter_npgsql_tests/npgsql_tests/NpgsqlTest.cs +++ b/src/test/csharp/pgadapter_npgsql_tests/npgsql_tests/NpgsqlTest.cs @@ -60,6 +60,22 @@ public void TestShowServerVersion() } } + public void TestShowApplicationName() + { + using var connection = new NpgsqlConnection(ConnectionString); + connection.Open(); + + using var cmd = new NpgsqlCommand("show application_name", connection); + using (var reader = cmd.ExecuteReader()) + { + while (reader.Read()) + { + var applicationName = reader.GetString(0); + Console.WriteLine($"{applicationName}"); + } + } + } + public void TestSelect1() { using var connection = new NpgsqlConnection(ConnectionString); diff --git a/src/test/golang/pgadapter_pgx_tests/pgx.go b/src/test/golang/pgadapter_pgx_tests/pgx.go index 4afd39e03..637fc0310 100644 --- a/src/test/golang/pgadapter_pgx_tests/pgx.go +++ b/src/test/golang/pgadapter_pgx_tests/pgx.go @@ -80,6 +80,27 @@ func TestSelect1(connString string) *C.char { return nil } +//export TestShowApplicationName +func TestShowApplicationName(connString string) *C.char { + ctx := context.Background() + conn, err := pgx.Connect(ctx, connString) + if err != nil { + return C.CString(err.Error()) + } + defer conn.Close(ctx) + + var value string + err = conn.QueryRow(ctx, "show application_name").Scan(&value) + if err != nil { + return C.CString(err.Error()) + } + if g, w := value, "pgx"; g != w { + return C.CString(fmt.Sprintf("value mismatch\n Got: %v\nWant: %v", g, w)) + } + + return nil +} + //export TestQueryWithParameter func TestQueryWithParameter(connString string) *C.char { ctx := context.Background() diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/AbstractMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/AbstractMockServerTest.java index e941b18e2..ef65e0a0d 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/AbstractMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/AbstractMockServerTest.java @@ -366,11 +366,20 @@ public abstract class AbstractMockServerTest { Status.INVALID_ARGUMENT.withDescription("Statement is invalid.").asRuntimeException(); protected static ResultSet createAllTypesResultSet(String columnPrefix) { + return createAllTypesResultSet(columnPrefix, false); + } + + protected static ResultSet createAllTypesResultSet(String columnPrefix, boolean microsTimestamp) { + return createAllTypesResultSet("1", columnPrefix, microsTimestamp); + } + + protected static ResultSet createAllTypesResultSet( + String id, String columnPrefix, boolean microsTimestamp) { return com.google.spanner.v1.ResultSet.newBuilder() .setMetadata(createAllTypesResultSetMetadata(columnPrefix)) .addRows( ListValue.newBuilder() - .addValues(Value.newBuilder().setStringValue("1").build()) + .addValues(Value.newBuilder().setStringValue(id).build()) .addValues(Value.newBuilder().setBoolValue(true).build()) .addValues( Value.newBuilder() @@ -382,7 +391,12 @@ protected static ResultSet createAllTypesResultSet(String columnPrefix) { .addValues(Value.newBuilder().setStringValue("100").build()) .addValues(Value.newBuilder().setStringValue("6.626").build()) .addValues( - Value.newBuilder().setStringValue("2022-02-16T13:18:02.123456789Z").build()) + Value.newBuilder() + .setStringValue( + microsTimestamp + ? "2022-02-16T13:18:02.123456Z" + : "2022-02-16T13:18:02.123456789Z") + .build()) .addValues(Value.newBuilder().setStringValue("2022-03-29").build()) .addValues(Value.newBuilder().setStringValue("test").build()) .addValues(Value.newBuilder().setStringValue("{\"key\": \"value\"}").build()) @@ -391,11 +405,18 @@ protected static ResultSet createAllTypesResultSet(String columnPrefix) { } protected static ResultSet createAllTypesNullResultSet(String columnPrefix) { + return createAllTypesNullResultSet(columnPrefix, null); + } + + protected static ResultSet createAllTypesNullResultSet(String columnPrefix, Long colBigInt) { return com.google.spanner.v1.ResultSet.newBuilder() .setMetadata(createAllTypesResultSetMetadata(columnPrefix)) .addRows( ListValue.newBuilder() - .addValues(Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build()) + .addValues( + colBigInt == null + ? Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build() + : Value.newBuilder().setStringValue(String.valueOf(colBigInt)).build()) .addValues(Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build()) .addValues(Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build()) .addValues(Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build()) diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/CopyInMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/CopyInMockServerTest.java index 03f7a7b20..b370979ad 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/CopyInMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/CopyInMockServerTest.java @@ -50,6 +50,7 @@ import com.google.spanner.v1.TypeCode; import io.grpc.Status; import java.io.BufferedReader; +import java.io.ByteArrayInputStream; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStreamReader; @@ -148,6 +149,27 @@ public void testCopyIn() throws SQLException, IOException { assertEquals(3, mutation.getInsert().getColumnsCount()); } + @Test + public void testEndRecord() throws SQLException, IOException { + setupCopyInformationSchemaResults(); + + try (Connection connection = DriverManager.getConnection(createUrl())) { + CopyManager copyManager = new CopyManager(connection.unwrap(BaseConnection.class)); + copyManager.copyIn("COPY users FROM STDIN;", new StringReader("5\t5\t5\n\\.\n7\t7\t7\n")); + } + + List commitRequests = mockSpanner.getRequestsOfType(CommitRequest.class); + assertEquals(1, commitRequests.size()); + CommitRequest commitRequest = commitRequests.get(0); + assertEquals(1, commitRequest.getMutationsCount()); + + Mutation mutation = commitRequest.getMutations(0); + assertEquals(OperationCase.INSERT, mutation.getOperationCase()); + // We should only receive 1 row, as there is an end record in the middle of the stream. + assertEquals(1, mutation.getInsert().getValuesCount()); + assertEquals(3, mutation.getInsert().getColumnsCount()); + } + @Test public void testCopyUpsert() throws SQLException, IOException { setupCopyInformationSchemaResults(); @@ -386,6 +408,19 @@ public void testCopyIn_Small() throws SQLException, IOException { } } + @Test + public void testCopyIn_Empty() throws SQLException, IOException { + setupCopyInformationSchemaResults(); + + try (Connection connection = DriverManager.getConnection(createUrl())) { + PGConnection pgConnection = connection.unwrap(PGConnection.class); + CopyManager copyManager = pgConnection.getCopyAPI(); + long copyCount = + copyManager.copyIn("copy all_types from stdin;", new ByteArrayInputStream(new byte[0])); + assertEquals(0L, copyCount); + } + } + @Test public void testCopyIn_Nulls() throws SQLException, IOException { setupCopyInformationSchemaResults(); diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/CopyOutMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/CopyOutMockServerTest.java index fe7da284f..70fb1ce66 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/CopyOutMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/CopyOutMockServerTest.java @@ -35,6 +35,7 @@ import com.google.cloud.spanner.connection.RandomResultSetGenerator; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; import com.google.cloud.spanner.pgadapter.session.SessionState; +import com.google.cloud.spanner.pgadapter.statements.BackendConnection; import com.google.cloud.spanner.pgadapter.statements.CopyStatement.Format; import com.google.cloud.spanner.pgadapter.utils.CopyInParser; import com.google.cloud.spanner.pgadapter.utils.CopyRecord; @@ -290,7 +291,13 @@ public void testCopyOutCsv() throws SQLException, IOException { @Test public void testCopyOutCsvWithHeader() throws SQLException, IOException { mockSpanner.putStatementResult( - StatementResult.query(Statement.of("select * from all_types"), ALL_TYPES_RESULTSET)); + StatementResult.query( + Statement.of("select * from all_types"), + ALL_TYPES_RESULTSET + .toBuilder() + .addRows(ALL_TYPES_RESULTSET.getRows(0)) + .addRows(ALL_TYPES_NULLS_RESULTSET.getRows(0)) + .build())); try (Connection connection = DriverManager.getConnection(createUrl())) { connection.createStatement().execute("set time zone 'UTC'"); @@ -301,7 +308,9 @@ public void testCopyOutCsvWithHeader() throws SQLException, IOException { assertEquals( "col_bigint-col_bool-col_bytea-col_float8-col_int-col_numeric-col_timestamptz-col_date-col_varchar-col_jsonb\n" - + "1-t-\\x74657374-3.14-100-6.626-\"2022-02-16 13:18:02.123456+00\"-\"2022-03-29\"-test-\"{~\"key~\": ~\"value~\"}\"\n", + + "1-t-\\x74657374-3.14-100-6.626-\"2022-02-16 13:18:02.123456+00\"-\"2022-03-29\"-test-\"{~\"key~\": ~\"value~\"}\"\n" + + "1-t-\\x74657374-3.14-100-6.626-\"2022-02-16 13:18:02.123456+00\"-\"2022-03-29\"-test-\"{~\"key~\": ~\"value~\"}\"\n" + + "---------\n", writer.toString()); } } @@ -487,6 +496,47 @@ public void testCopyOutBinaryPsql() throws Exception { assertEquals(Value.string("test"), record.getValue(Type.string(), 8)); } + @Test + public void testCopyOutBinaryEmptyPsql() throws Exception { + assumeTrue("This test requires psql", isPsqlAvailable()); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of("select * from all_types"), + ALL_TYPES_RESULTSET.toBuilder().clearRows().build())); + + ProcessBuilder builder = new ProcessBuilder(); + builder.command( + "psql", + "-h", + (useDomainSocket ? "/tmp" : "localhost"), + "-p", + String.valueOf(pgServer.getLocalPort()), + "-c", + "copy all_types to stdout binary"); + Process process = builder.start(); + StringBuilder errorBuilder = new StringBuilder(); + try (Scanner scanner = new Scanner(new InputStreamReader(process.getErrorStream()))) { + while (scanner.hasNextLine()) { + errorBuilder.append(scanner.nextLine()).append('\n'); + } + } + PipedOutputStream pipedOutputStream = new PipedOutputStream(); + PipedInputStream inputStream = new PipedInputStream(pipedOutputStream, 1 << 16); + SessionState sessionState = mock(SessionState.class); + CopyInParser copyParser = + CopyInParser.create(sessionState, Format.BINARY, null, inputStream, false); + int b; + while ((b = process.getInputStream().read()) != -1) { + pipedOutputStream.write(b); + } + int res = process.waitFor(); + assertEquals("", errorBuilder.toString()); + assertEquals(0, res); + + Iterator iterator = copyParser.iterator(); + assertFalse(iterator.hasNext()); + } + @Test public void testCopyOutNullsBinaryPsql() throws Exception { assumeTrue("This test requires psql", isPsqlAvailable()); @@ -536,30 +586,64 @@ public void testCopyOutNullsBinaryPsql() throws Exception { @Test public void testCopyOutPartitioned() throws SQLException, IOException { - final int expectedRowCount = 100; - RandomResultSetGenerator randomResultSetGenerator = - new RandomResultSetGenerator(expectedRowCount, Dialect.POSTGRESQL); - com.google.spanner.v1.ResultSet resultSet = randomResultSetGenerator.generate(); - mockSpanner.putStatementResult( - StatementResult.query(Statement.of("select * from random"), resultSet)); + for (int expectedRowCount : new int[] {0, 1, 2, 3, 5, BackendConnection.MAX_PARTITIONS, 100}) { + RandomResultSetGenerator randomResultSetGenerator = + new RandomResultSetGenerator(expectedRowCount, Dialect.POSTGRESQL); + com.google.spanner.v1.ResultSet resultSet = randomResultSetGenerator.generate(); + mockSpanner.putStatementResult( + StatementResult.query(Statement.of("select * from random"), resultSet)); + + try (Connection connection = DriverManager.getConnection(createUrl())) { + CopyManager copyManager = new CopyManager(connection.unwrap(BaseConnection.class)); + StringWriter writer = new StringWriter(); + long rows = copyManager.copyOut("COPY random TO STDOUT", writer); + + assertEquals(expectedRowCount, rows); + + try (Scanner scanner = new Scanner(writer.toString())) { + int lineCount = 0; + while (scanner.hasNextLine()) { + lineCount++; + String line = scanner.nextLine(); + String[] columns = line.split("\t"); + int index = findIndex(resultSet, columns); + assertNotEquals(String.format("Row %d not found: %s", lineCount, line), -1, index); + } + assertEquals(expectedRowCount, lineCount); + } + } + } + } - try (Connection connection = DriverManager.getConnection(createUrl())) { - CopyManager copyManager = new CopyManager(connection.unwrap(BaseConnection.class)); - StringWriter writer = new StringWriter(); - long rows = copyManager.copyOut("COPY random TO STDOUT", writer); - - assertEquals(expectedRowCount, rows); - - try (Scanner scanner = new Scanner(writer.toString())) { - int lineCount = 0; - while (scanner.hasNextLine()) { - lineCount++; - String line = scanner.nextLine(); - String[] columns = line.split("\t"); - int index = findIndex(resultSet, columns); - assertNotEquals(String.format("Row %d not found: %s", lineCount, line), -1, index); + @Test + public void testCopyOutPartitionedBinary() throws SQLException, IOException { + for (int expectedRowCount : new int[] {0, 1, 2, 3, 5, BackendConnection.MAX_PARTITIONS, 100}) { + RandomResultSetGenerator randomResultSetGenerator = + new RandomResultSetGenerator(expectedRowCount, Dialect.POSTGRESQL); + com.google.spanner.v1.ResultSet resultSet = randomResultSetGenerator.generate(); + mockSpanner.putStatementResult( + StatementResult.query(Statement.of("select * from random"), resultSet)); + + try (Connection connection = DriverManager.getConnection(createUrl())) { + CopyManager copyManager = new CopyManager(connection.unwrap(BaseConnection.class)); + PipedOutputStream pipedOutputStream = new PipedOutputStream(); + PipedInputStream inputStream = new PipedInputStream(pipedOutputStream, 1 << 20); + SessionState sessionState = mock(SessionState.class); + CopyInParser copyParser = + CopyInParser.create(sessionState, Format.BINARY, null, inputStream, false); + long rows = copyManager.copyOut("COPY random TO STDOUT (format binary)", pipedOutputStream); + + assertEquals(expectedRowCount, rows); + + Iterator iterator = copyParser.iterator(); + int recordCount = 0; + while (iterator.hasNext()) { + recordCount++; + CopyRecord record = iterator.next(); + int index = findIndex(resultSet, record); + assertNotEquals(String.format("Row %d not found: %s", recordCount, record), -1, index); } - assertEquals(expectedRowCount, lineCount); + assertEquals(expectedRowCount, recordCount); } } } @@ -627,4 +711,45 @@ static int findIndex(com.google.spanner.v1.ResultSet resultSet, String[] cols) { } return -1; } + + static int findIndex(com.google.spanner.v1.ResultSet resultSet, CopyRecord record) { + for (int index = 0; index < resultSet.getRowsCount(); index++) { + boolean nullValuesEqual = true; + for (int colIndex = 0; colIndex < record.numColumns(); colIndex++) { + if (record.isNull(colIndex) + != resultSet.getRows(index).getValues(colIndex).hasNullValue()) { + nullValuesEqual = false; + break; + } + } + if (!nullValuesEqual) { + continue; + } + + boolean valuesEqual = true; + for (int colIndex = 0; colIndex < record.numColumns(); colIndex++) { + if (!resultSet.getRows(index).getValues(colIndex).hasNullValue()) { + if (resultSet.getMetadata().getRowType().getFields(colIndex).getType().getCode() + == TypeCode.STRING + && !record + .getValue(Type.string(), colIndex) + .getString() + .equals(resultSet.getRows(index).getValues(colIndex).getStringValue())) { + valuesEqual = false; + break; + } + if (resultSet.getRows(index).getValues(colIndex).hasBoolValue() + && record.getValue(Type.bool(), colIndex).getBool() + != resultSet.getRows(index).getValues(colIndex).getBoolValue()) { + valuesEqual = false; + break; + } + } + } + if (valuesEqual) { + return index; + } + } + return -1; + } } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/EmulatedPgbenchMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/EmulatedPgbenchMockServerTest.java new file mode 100644 index 000000000..9e477b2d8 --- /dev/null +++ b/src/test/java/com/google/cloud/spanner/pgadapter/EmulatedPgbenchMockServerTest.java @@ -0,0 +1,114 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.spanner.pgadapter; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; + +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.pgadapter.error.SQLState; +import com.google.cloud.spanner.pgadapter.utils.ClientAutoDetector; +import com.google.cloud.spanner.pgadapter.utils.ClientAutoDetector.WellKnownClient; +import io.grpc.Status; +import java.nio.charset.StandardCharsets; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.sql.SQLWarning; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.postgresql.PGProperty; +import org.postgresql.util.PSQLException; +import org.postgresql.util.PSQLWarning; + +@RunWith(JUnit4.class) +public class EmulatedPgbenchMockServerTest extends AbstractMockServerTest { + + @BeforeClass + public static void loadPgJdbcDriver() throws Exception { + // Make sure the PG JDBC driver is loaded. + Class.forName("org.postgresql.Driver"); + } + + /** + * Creates a JDBC connection string that instructs the PG JDBC driver to use the default simple + * mode. It also adds 'pgbench' as the application name, which will make PGAdapter automatically + * recognize the connection as a pgbench connection. + */ + private String createUrl() { + return String.format( + "jdbc:postgresql://localhost:%d/db?preferQueryMode=simple&%s=pgbench&%s=090000", + pgServer.getLocalPort(), + PGProperty.APPLICATION_NAME.getName(), + PGProperty.ASSUME_MIN_SERVER_VERSION.getName()); + } + + @Test + public void testClientDetectionAndHint() throws Exception { + WellKnownClient.PGBENCH.reset(); + // Verify that we get the notice response that indicates that the client was automatically + // detected. + try (Connection connection = DriverManager.getConnection(createUrl())) { + SQLWarning warning = connection.getWarnings(); + assertNotNull(warning); + PSQLWarning psqlWarning = (PSQLWarning) warning; + assertNotNull(psqlWarning.getServerErrorMessage()); + assertNotNull(psqlWarning.getServerErrorMessage().getSQLState()); + assertArrayEquals( + SQLState.Success.getBytes(), + psqlWarning.getServerErrorMessage().getSQLState().getBytes(StandardCharsets.UTF_8)); + assertEquals( + "Detected connection from pgbench", psqlWarning.getServerErrorMessage().getMessage()); + assertEquals( + ClientAutoDetector.PGBENCH_USAGE_HINT + "\n", + psqlWarning.getServerErrorMessage().getHint()); + + assertNull(warning.getNextWarning()); + } + + // Verify that creating a second connection directly afterwards with pgbench does not repeat the + // hint. + try (Connection connection = DriverManager.getConnection(createUrl())) { + SQLWarning warning = connection.getWarnings(); + assertNull(warning); + } + } + + @Test + public void testErrorHint() throws SQLException { + // Verify that any error message includes the pgbench usage hint. + String sql = "select * from foo where bar=1"; + mockSpanner.putStatementResult( + StatementResult.exception( + Statement.of(sql), + Status.INVALID_ARGUMENT.withDescription("test error").asRuntimeException())); + try (Connection connection = DriverManager.getConnection(createUrl())) { + PSQLException exception = + assertThrows(PSQLException.class, () -> connection.createStatement().execute(sql)); + assertNotNull(exception.getServerErrorMessage()); + assertEquals( + String.format("test error - Statement: '%s'", sql), + exception.getServerErrorMessage().getMessage()); + assertEquals( + ClientAutoDetector.PGBENCH_USAGE_HINT, exception.getServerErrorMessage().getHint()); + } + } +} diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/EmulatedPsqlMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/EmulatedPsqlMockServerTest.java index 70660518e..86d53ee77 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/EmulatedPsqlMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/EmulatedPsqlMockServerTest.java @@ -195,6 +195,18 @@ public void testConnectToNonExistingDatabase() { } } + @Test + public void testShowApplicationName() throws SQLException { + try (Connection connection = DriverManager.getConnection(createUrl("my-db"))) { + try (ResultSet resultSet = + connection.createStatement().executeQuery("show application_name")) { + assertTrue(resultSet.next()); + assertEquals("psql", resultSet.getString(1)); + assertFalse(resultSet.next()); + } + } + } + @Test public void testConnectToNonExistingInstance() { try { diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/ITPsqlTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/ITPsqlTest.java index fc4adcb25..5984abb8e 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/ITPsqlTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/ITPsqlTest.java @@ -26,6 +26,7 @@ import com.google.cloud.spanner.Database; import com.google.cloud.spanner.KeySet; import com.google.cloud.spanner.Mutation; +import com.google.cloud.spanner.pgadapter.statements.CopyStatement.Format; import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -47,10 +48,12 @@ import java.time.OffsetDateTime; import java.time.ZoneId; import java.time.ZoneOffset; +import java.time.zone.ZoneRulesException; import java.util.Arrays; import java.util.Collections; import java.util.Random; import java.util.Scanner; +import java.util.logging.Logger; import java.util.stream.Collectors; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -68,6 +71,8 @@ @Category(IntegrationTest.class) @RunWith(JUnit4.class) public class ITPsqlTest implements IntegrationTest { + private static final Logger logger = Logger.getLogger(ITPsqlTest.class.getName()); + private static final PgAdapterTestEnv testEnv = new PgAdapterTestEnv(); private static Database database; private static boolean allAssumptionsPassed = false; @@ -395,9 +400,28 @@ public void testPrepareExecuteDeallocate() throws IOException, InterruptedExcept } @Test - public void testCopyFromPostgreSQLToCloudSpanner() throws Exception { + public void testSetOperationWithOrderBy() throws IOException, InterruptedException { + // TODO: Remove + assumeTrue( + testEnv.getSpannerUrl().equals("https://staging-wrenchworks.sandbox.googleapis.com")); + + Tuple result = + runUsingPsql( + "select * from (select 1) one union all select * from (select 2) two order by 1"); + String output = result.x(), errors = result.y(); + assertEquals("", errors); + assertEquals(" ?column? \n----------\n 1\n 2\n(2 rows)\n", output); + } + + /** + * This test copies data back and forth between PostgreSQL and Cloud Spanner and verifies that the + * contents are equal after the COPY operation in both directions. + */ + @Test + public void testCopyBetweenPostgreSQLAndCloudSpanner() throws Exception { int numRows = 100; + logger.info("Copying initial data to PG"); // Generate 99 random rows. copyRandomRowsToPostgreSQL(numRows - 1); // Also add one row with all nulls to ensure that nulls are also copied correctly. @@ -410,6 +434,7 @@ public void testCopyFromPostgreSQLToCloudSpanner() throws Exception { } } + logger.info("Verifying initial data in PG"); // Verify that we have 100 rows in PostgreSQL. try (Connection connection = DriverManager.getConnection(createJdbcUrlForLocalPg(), POSTGRES_USER, POSTGRES_PASSWORD)) { @@ -421,12 +446,14 @@ public void testCopyFromPostgreSQLToCloudSpanner() throws Exception { } } - // Execute the COPY tests in both binary and text mode. - for (boolean binary : new boolean[] {false, true}) { + // Execute the COPY tests in both binary, csv and text mode. + for (Format format : Format.values()) { + logger.info("Testing format: " + format); // Make sure the all_types table on Cloud Spanner is empty. String databaseId = database.getId().getDatabase(); testEnv.write(databaseId, Collections.singleton(Mutation.delete("all_types", KeySet.all()))); + logger.info("Copying rows to Spanner"); // COPY the rows to Cloud Spanner. ProcessBuilder builder = new ProcessBuilder(); builder.command( @@ -441,8 +468,9 @@ public void testCopyFromPostgreSQLToCloudSpanner() throws Exception { + POSTGRES_USER + " -d " + POSTGRES_DATABASE - + " -c \"copy all_types to stdout" - + (binary ? " binary \" " : "\" ") + + " -c \"copy all_types to stdout (format " + + format + + ") \" " + " | psql " + " -h " + (POSTGRES_HOST.startsWith("/") ? "/tmp" : testEnv.getPGAdapterHost()) @@ -450,14 +478,16 @@ public void testCopyFromPostgreSQLToCloudSpanner() throws Exception { + testEnv.getPGAdapterPort() + " -d " + database.getId().getDatabase() - + " -c \"copy all_types from stdin " - + (binary ? "binary" : "") + + " -c \"copy all_types from stdin (format " + + format + + ")" + ";\"\n"); setPgPassword(builder); Process process = builder.start(); int res = process.waitFor(); assertEquals(0, res); + logger.info("Verifying data in Spanner"); // Verify that we now also have 100 rows in Spanner. try (Connection connection = DriverManager.getConnection(createJdbcUrlForPGAdapter())) { try (ResultSet resultSet = @@ -468,9 +498,11 @@ public void testCopyFromPostgreSQLToCloudSpanner() throws Exception { } } + logger.info("Comparing table contents"); // Verify that the rows in both databases are equal. compareTableContents(); + logger.info("Deleting all data in PG"); // Remove all rows in the table in the local PostgreSQL database and then copy everything from // Cloud Spanner to PostgreSQL. try (Connection connection = @@ -479,14 +511,16 @@ public void testCopyFromPostgreSQLToCloudSpanner() throws Exception { assertEquals(numRows, connection.createStatement().executeUpdate("delete from all_types")); } + logger.info("Copying rows to PG"); // COPY the rows from Cloud Spanner to PostgreSQL. ProcessBuilder copyToPostgresBuilder = new ProcessBuilder(); copyToPostgresBuilder.command( "bash", "-c", "psql" - + " -c \"copy all_types to stdout" - + (binary ? " binary \" " : "\" ") + + " -c \"copy all_types to stdout (format " + + format + + ") \" " + " -h " + (POSTGRES_HOST.startsWith("/") ? "/tmp" : testEnv.getPGAdapterHost()) + " -p " @@ -502,8 +536,9 @@ public void testCopyFromPostgreSQLToCloudSpanner() throws Exception { + POSTGRES_USER + " -d " + POSTGRES_DATABASE - + " -c \"copy all_types from stdin " - + (binary ? "binary" : "") + + " -c \"copy all_types from stdin (format " + + format + + ")" + ";\"\n"); setPgPassword(copyToPostgresBuilder); Process copyToPostgresProcess = copyToPostgresBuilder.start(); @@ -518,11 +553,118 @@ public void testCopyFromPostgreSQLToCloudSpanner() throws Exception { assertEquals("", errors.toString()); assertEquals(0, copyToPostgresResult); + logger.info("Compare table contents"); // Compare table contents again. compareTableContents(); } } + /** This test verifies that we can copy an empty table between PostgreSQL and Cloud Spanner. */ + @Test + public void testCopyEmptyTableBetweenCloudSpannerAndPostgreSQL() throws Exception { + logger.info("Deleting all data in PG"); + // Remove all rows in the table in the local PostgreSQL database. + try (Connection connection = + DriverManager.getConnection(createJdbcUrlForLocalPg(), POSTGRES_USER, POSTGRES_PASSWORD)) { + connection.createStatement().executeUpdate("delete from all_types"); + } + + // Execute the COPY tests in both binary, csv and text mode. + for (Format format : Format.values()) { + logger.info("Testing format: " + format); + // Make sure the all_types table on Cloud Spanner is empty. + String databaseId = database.getId().getDatabase(); + testEnv.write(databaseId, Collections.singleton(Mutation.delete("all_types", KeySet.all()))); + + logger.info("Copy empty table to CS"); + // COPY the empty table to Cloud Spanner. + ProcessBuilder builder = new ProcessBuilder(); + builder.command( + "bash", + "-c", + "psql" + + " -h " + + POSTGRES_HOST + + " -p " + + POSTGRES_PORT + + " -U " + + POSTGRES_USER + + " -d " + + POSTGRES_DATABASE + + " -c \"copy all_types to stdout (format " + + format + + ") \" " + + " | psql " + + " -h " + + (POSTGRES_HOST.startsWith("/") ? "/tmp" : testEnv.getPGAdapterHost()) + + " -p " + + testEnv.getPGAdapterPort() + + " -d " + + database.getId().getDatabase() + + " -c \"copy all_types from stdin (format " + + format + + ")" + + ";\"\n"); + setPgPassword(builder); + Process process = builder.start(); + int res = process.waitFor(); + assertEquals(0, res); + + logger.info("Verify that CS is empty"); + // Verify that still have 0 rows in Cloud Spanner. + try (Connection connection = DriverManager.getConnection(createJdbcUrlForPGAdapter())) { + try (ResultSet resultSet = + connection.createStatement().executeQuery("select count(*) from all_types")) { + assertTrue(resultSet.next()); + assertEquals(0, resultSet.getLong(1)); + assertFalse(resultSet.next()); + } + } + + logger.info("Copy empty table to PG"); + // COPY the empty table from Cloud Spanner to PostgreSQL. + ProcessBuilder copyToPostgresBuilder = new ProcessBuilder(); + copyToPostgresBuilder.command( + "bash", + "-c", + "psql" + + " -c \"copy all_types to stdout (format " + + format + + ") \"" + + " -h " + + (POSTGRES_HOST.startsWith("/") ? "/tmp" : testEnv.getPGAdapterHost()) + + " -p " + + testEnv.getPGAdapterPort() + + " -d " + + database.getId().getDatabase() + + " | psql " + + " -h " + + POSTGRES_HOST + + " -p " + + POSTGRES_PORT + + " -U " + + POSTGRES_USER + + " -d " + + POSTGRES_DATABASE + + " -c \"copy all_types from stdin (format " + + format + + ")" + + ";\"\n"); + setPgPassword(copyToPostgresBuilder); + Process copyToPostgresProcess = copyToPostgresBuilder.start(); + InputStream errorStream = copyToPostgresProcess.getErrorStream(); + int copyToPostgresResult = copyToPostgresProcess.waitFor(); + StringBuilder errors = new StringBuilder(); + try (Scanner scanner = new Scanner(new InputStreamReader(errorStream))) { + while (scanner.hasNextLine()) { + errors.append(errors).append(scanner.nextLine()).append("\n"); + } + } + assertEquals("", errors.toString()); + assertEquals(0, copyToPostgresResult); + } + } + @Test public void testTimestamptzParsing() throws Exception { final int numTests = 10; @@ -550,9 +692,13 @@ public void testTimestamptzParsing() throws Exception { "select name from pg_timezone_names where not name like '%%posix%%' and not name like 'Factory' offset %d limit 1", random.nextInt(numTimezones)))) { assertTrue(resultSet.next()); - timezone = resultSet.getString(1); - if (!PROBLEMATIC_ZONE_IDS.contains(ZoneId.of(timezone))) { - break; + try { + timezone = resultSet.getString(1); + if (!PROBLEMATIC_ZONE_IDS.contains(ZoneId.of(timezone))) { + break; + } + } catch (ZoneRulesException ignore) { + // Skip and try a different one if it is not a supported timezone on this system. } } } @@ -634,7 +780,15 @@ public void testTimestamptzParsing() throws Exception { // Mexico abolished DST in 2022, but not all databases contain this information. ZoneId.of("America/Chihuahua"), // Jordan abolished DST in 2022, but not all databases contain this information. - ZoneId.of("Asia/Amman")); + ZoneId.of("Asia/Amman"), + // Iran observed DST in 1978. Not all databases agree on this. + ZoneId.of("Asia/Tehran"), + // Rankin_Inlet did not observer DST in 1970-1979, but not all databases agree. + ZoneId.of("America/Rankin_Inlet"), + // Pangnirtung did not observer DST in 1970-1979, but not all databases agree. + ZoneId.of("America/Pangnirtung"), + // Niue switched from -11:30 to -11 in 1978. Not all JDKs know that. + ZoneId.of("Pacific/Niue")); private LocalDate generateRandomLocalDate() { return LocalDate.ofEpochDay(random.nextInt(365 * 100)); diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/JdbcMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/JdbcMockServerTest.java index 9fab59cfa..65ff4a7ca 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/JdbcMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/JdbcMockServerTest.java @@ -93,6 +93,7 @@ import org.postgresql.PGStatement; import org.postgresql.core.Oid; import org.postgresql.jdbc.PgStatement; +import org.postgresql.util.PGobject; import org.postgresql.util.PSQLException; @RunWith(Parameterized.class) @@ -135,7 +136,7 @@ private String createUrl() { } private String getExpectedInitialApplicationName() { - return pgVersion.equals("1.0") ? null : "PostgreSQL JDBC Driver"; + return pgVersion.equals("1.0") ? "jdbc" : "PostgreSQL JDBC Driver"; } @Test @@ -165,6 +166,23 @@ public void testQuery() throws SQLException { } } + @Test + public void testShowApplicationName() throws SQLException { + try (Connection connection = DriverManager.getConnection(createUrl())) { + try (ResultSet resultSet = + connection.createStatement().executeQuery("show application_name")) { + assertTrue(resultSet.next()); + // If the PG version is 1.0, the JDBC driver thinks that the server does not support the + // application_name property and does not send any value. That means that PGAdapter fills it + // in automatically based on the client that is detected. + // Otherwise, the JDBC driver includes its own name, and that is not overwritten by + // PGAdapter. + assertEquals(getExpectedInitialApplicationName(), resultSet.getString(1)); + assertFalse(resultSet.next()); + } + } + } + @Test public void testPreparedStatementParameterMetadata() throws SQLException { String sql = "SELECT * FROM foo WHERE id=? or value=?"; @@ -637,6 +655,41 @@ public void testQueryWithLegacyDateParameter() throws SQLException { } } + @Test + public void testCharParam() throws SQLException { + String sql = "insert into foo values ($1)"; + mockSpanner.putStatementResult(StatementResult.update(Statement.of(sql), 1L)); + String jdbcSql = "insert into foo values (?)"; + try (Connection connection = DriverManager.getConnection(createUrl())) { + try (PreparedStatement preparedStatement = connection.prepareStatement(jdbcSql)) { + PGobject pgObject = new PGobject(); + pgObject.setType("char"); + pgObject.setValue("a"); + preparedStatement.setObject(1, pgObject); + assertEquals(1, preparedStatement.executeUpdate()); + } + } + + List parseMessages = + pgServer.getDebugMessages().stream() + .filter(message -> message instanceof ParseMessage) + .map(message -> (ParseMessage) message) + .collect(Collectors.toList()); + assertFalse(parseMessages.isEmpty()); + ParseMessage parseMessage = parseMessages.get(parseMessages.size() - 1); + assertEquals(1, parseMessage.getStatement().getGivenParameterDataTypes().length); + assertEquals(Oid.CHAR, parseMessage.getStatement().getGivenParameterDataTypes()[0]); + + List executeSqlRequests = + mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); + assertEquals(1, executeSqlRequests.size()); + ExecuteSqlRequest request = executeSqlRequests.get(0); + // Oid.CHAR is not a recognized type in PGAdapter. + assertEquals(0, request.getParamTypesCount()); + assertEquals(1, request.getParams().getFieldsCount()); + assertEquals("a", request.getParams().getFieldsMap().get("p1").getStringValue()); + } + @Test public void testAutoDescribedStatementsAreReused() throws SQLException { String jdbcSql = "select col_date from all_types where col_date=?"; @@ -1794,9 +1847,7 @@ public void testPreparedStatementReturning() throws SQLException { .bind("p9") .to("test") .bind("p10") - // TODO: Change to jsonb when https://github.com/googleapis/java-spanner/pull/2182 - // has been merged. - .to(com.google.cloud.spanner.Value.json("{\"key\": \"value\"}")) + .to(com.google.cloud.spanner.Value.pgJsonb("{\"key\": \"value\"}")) .build(), com.google.spanner.v1.ResultSet.newBuilder() .setMetadata(ALL_TYPES_METADATA) @@ -2880,6 +2931,21 @@ public void testSetTimeZone() throws SQLException { } } + @Test + public void testSetTimeZoneEST() throws SQLException { + try (Connection connection = DriverManager.getConnection(createUrl())) { + // Java considers 'EST' to always be '-05:00'. That is; it is never DST. + connection.createStatement().execute("set time zone 'EST'"); + verifySettingValue(connection, "timezone", "-05:00"); + // 'EST5EDT' is the ID for the timezone that will change with DST. + connection.createStatement().execute("set time zone 'EST5EDT'"); + verifySettingValue(connection, "timezone", "EST5EDT"); + // 'America/New_York' is the full name of the geographical timezone. + connection.createStatement().execute("set time zone 'America/New_York'"); + verifySettingValue(connection, "timezone", "America/New_York"); + } + } + @Test public void testSetTimeZoneToServerDefault() throws SQLException { try (Connection connection = DriverManager.getConnection(createUrl())) { diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/PsqlMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/PsqlMockServerTest.java index 8f0e4e935..8d177e938 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/PsqlMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/PsqlMockServerTest.java @@ -14,16 +14,23 @@ package com.google.cloud.spanner.pgadapter; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assume.assumeTrue; import com.google.cloud.spanner.MockSpannerServiceImpl; +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.pgadapter.statements.CopyToStatement; import com.google.cloud.spanner.pgadapter.wireprotocol.SSLMessage; import com.google.common.collect.ImmutableList; +import com.google.common.io.ByteStreams; import com.google.spanner.admin.database.v1.UpdateDatabaseDdlRequest; import java.io.BufferedReader; +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; import java.io.File; import java.io.InputStreamReader; import java.util.List; @@ -141,4 +148,91 @@ public void testSSLRequire() throws Exception { assertEquals( 1L, pgServer.getDebugMessages().stream().filter(m -> m instanceof SSLMessage).count()); } + + @Test + public void testCopyToBinaryPsql() throws Exception { + assumeTrue("This test requires psql to be installed", isPsqlAvailable()); + + ProcessBuilder builder = new ProcessBuilder(); + String[] psqlCommand = + new String[] { + "psql", + "-h", + "localhost", + "-p", + String.valueOf(pgServer.getLocalPort()), + "-c", + "COPY (SELECT 1) TO STDOUT (FORMAT BINARY)" + }; + builder.command(psqlCommand); + Process process = builder.start(); + String errors; + + try (BufferedReader errorReader = + new BufferedReader(new InputStreamReader(process.getErrorStream()))) { + errors = errorReader.lines().collect(Collectors.joining("\n")); + } + byte[] copyOutput = ByteStreams.toByteArray(process.getInputStream()); + + assertEquals("", errors); + int res = process.waitFor(); + assertEquals(0, res); + + ByteArrayOutputStream byteArrayOutputStream = + new ByteArrayOutputStream(CopyToStatement.COPY_BINARY_HEADER.length); + DataOutputStream expectedOutputStream = new DataOutputStream(byteArrayOutputStream); + expectedOutputStream.write(CopyToStatement.COPY_BINARY_HEADER); + expectedOutputStream.writeInt(0); // flags + expectedOutputStream.writeInt(0); // header extension area length + expectedOutputStream.writeShort(1); // Column count + expectedOutputStream.writeInt(8); // Value length + expectedOutputStream.writeLong(1L); // Column value + expectedOutputStream.writeShort(-1); // Column count == -1 means end of data. + + assertArrayEquals(byteArrayOutputStream.toByteArray(), copyOutput); + } + + @Test + public void testCopyToBinaryPsqlEmptyTable() throws Exception { + assumeTrue("This test requires psql to be installed", isPsqlAvailable()); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of("select * from all_types"), + ALL_TYPES_RESULTSET.toBuilder().clearRows().build())); + + ProcessBuilder builder = new ProcessBuilder(); + String[] psqlCommand = + new String[] { + "psql", + "-h", + "localhost", + "-p", + String.valueOf(pgServer.getLocalPort()), + "-c", + "COPY all_types TO STDOUT (FORMAT BINARY)" + }; + builder.command(psqlCommand); + Process process = builder.start(); + String errors; + + try (BufferedReader errorReader = + new BufferedReader(new InputStreamReader(process.getErrorStream()))) { + errors = errorReader.lines().collect(Collectors.joining("\n")); + } + byte[] copyOutput = ByteStreams.toByteArray(process.getInputStream()); + + assertEquals("", errors); + int res = process.waitFor(); + assertEquals(0, res); + + ByteArrayOutputStream byteArrayOutputStream = + new ByteArrayOutputStream(CopyToStatement.COPY_BINARY_HEADER.length); + DataOutputStream expectedOutputStream = new DataOutputStream(byteArrayOutputStream); + expectedOutputStream.write(CopyToStatement.COPY_BINARY_HEADER); + expectedOutputStream.writeInt(0); // flags + expectedOutputStream.writeInt(0); // header extension area length + expectedOutputStream.writeShort(-1); // Column count == -1 means end of data. + + assertArrayEquals(byteArrayOutputStream.toByteArray(), copyOutput); + } } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/csharp/AbstractNpgsqlMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/csharp/AbstractNpgsqlMockServerTest.java index c117a0a98..f36e83033 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/csharp/AbstractNpgsqlMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/csharp/AbstractNpgsqlMockServerTest.java @@ -182,12 +182,12 @@ public abstract class AbstractNpgsqlMockServerTest extends AbstractMockServerTes + " END AS elemtypoid\n" + " FROM pg_type AS typ\n" + " LEFT JOIN pg_class AS cls ON (cls.oid = typ.typrelid)\n" - + " LEFT JOIN pg_proc AS proc ON proc.oid = typ.typreceive\n" + + " LEFT JOIN pg_proc AS proc ON false\n" + " LEFT JOIN pg_range ON (pg_range.rngtypid = typ.oid)\n" + " ) AS typ\n" + " LEFT JOIN pg_type AS elemtyp ON elemtyp.oid = elemtypoid\n" + " LEFT JOIN pg_class AS elemcls ON (elemcls.oid = elemtyp.typrelid)\n" - + " LEFT JOIN pg_proc AS elemproc ON elemproc.oid = elemtyp.typreceive\n" + + " LEFT JOIN pg_proc AS elemproc ON false\n" + ") AS t\n" + "JOIN pg_namespace AS ns ON (ns.oid = typnamespace)\n" + "WHERE\n" @@ -481,7 +481,10 @@ public abstract class AbstractNpgsqlMockServerTest extends AbstractMockServerTes + "from information_schema.indexes i\n" + "inner join information_schema.index_columns using (table_catalog, table_schema, table_name)\n" + "group by i.index_name, i.table_schema\n" - + ")\n" + + "),\n" + + "pg_attribute as (\n" + + "select * from (select 0::bigint as attrelid, '' as attname, 0::bigint as atttypid, 0::bigint as attstattarget, 0::bigint as attlen, 0::bigint as attnum, 0::bigint as attndims, -1::bigint as attcacheoff, 0::bigint as atttypmod, true as attbyval, '' as attalign, '' as attstorage, '' as attcompression, false as attnotnull, true as atthasdef, false as atthasmissing, '' as attidentity, '' as attgenerated, false as attisdropped, true as attislocal, 0 as attinhcount, 0 as attcollation, '{}'::bigint[] as attacl, '{}'::text[] as attoptions, '{}'::text[] as attfdwoptions, null as attmissingval\n" + + ") a where false)\n" + "-- Load field definitions for (free-standing) composite types\n" + "SELECT typ.oid, att.attname, att.atttypid\n" + "FROM pg_type AS typ\n" @@ -519,7 +522,10 @@ public abstract class AbstractNpgsqlMockServerTest extends AbstractMockServerTes private static final Statement SELECT_ENUM_LABELS_STATEMENT = Statement.of( - "with pg_namespace as (\n" + "with pg_enum as (\n" + + "select * from (select 0::bigint as oid, 0::bigint as enumtypid, 0.0::float8 as enumsortorder, ''::varchar as enumlabel\n" + + ") e where false),\n" + + "pg_namespace as (\n" + " select case schema_name when 'pg_catalog' then 11 when 'public' then 2200 else 0 end as oid,\n" + " schema_name as nspname, null as nspowner, null as nspacl\n" + " from information_schema.schemata\n" diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/csharp/NpgsqlMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/csharp/NpgsqlMockServerTest.java index 4928f8fe5..d4c51f998 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/csharp/NpgsqlMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/csharp/NpgsqlMockServerTest.java @@ -71,6 +71,12 @@ public void testShowServerVersion() throws IOException, InterruptedException { assertEquals("14.1\n", result); } + @Test + public void testShowApplicationName() throws IOException, InterruptedException { + String result = execute("TestShowApplicationName", createConnectionString()); + assertEquals("npgsql\n", result); + } + @Test public void testSelect1() throws IOException, InterruptedException { String result = execute("TestSelect1", createConnectionString()); diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxMockServerTest.java index a2b62975c..43a33c04f 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxMockServerTest.java @@ -158,6 +158,16 @@ public void testSelect1() { assertEquals(QueryMode.NORMAL, executeRequest.getQueryMode()); } + @Test + public void testShowApplicationName() { + String res = pgxTest.TestShowApplicationName(createConnString()); + + assertNull(res); + + // This should all be handled in PGAdapter. + assertEquals(0, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + } + @Test public void testQueryWithParameter() { String sql = "SELECT * FROM FOO WHERE BAR=$1"; diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxTest.java index 3521d6cdb..bfdc362d0 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxTest.java @@ -23,6 +23,8 @@ public interface PgxTest extends Library { String TestSelect1(GoString connString); + String TestShowApplicationName(GoString connString); + String TestQueryWithParameter(GoString connString); String TestQueryAllDataTypes(GoString connString, int oid, int format); diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/NodePostgresMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/NodePostgresMockServerTest.java index 0478148ef..0e363f5df 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/NodePostgresMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/NodePostgresMockServerTest.java @@ -259,9 +259,7 @@ public void testInsertAllTypes() throws IOException, InterruptedException { .bind("p9") .to("some-random-string") .bind("p10") - // TODO: Change to jsonb when https://github.com/googleapis/java-spanner/pull/2182 - // has been merged. - .to(Value.json("{\"my_key\":\"my-value\"}")) + .to(Value.pgJsonb("{\"my_key\":\"my-value\"}")) .build(), 1L); mockSpanner.putStatementResult(updateResult); @@ -386,9 +384,7 @@ public void testInsertAllTypesPreparedStatement() throws IOException, Interrupte .bind("p9") .to("some-random-string") .bind("p10") - // TODO: Change to jsonb when https://github.com/googleapis/java-spanner/pull/2182 - // has been merged. - .to(Value.json("{\"my_key\":\"my-value\"}")) + .to(Value.pgJsonb("{\"my_key\":\"my-value\"}")) .build(), 1L); mockSpanner.putStatementResult(updateResult); @@ -416,9 +412,7 @@ public void testInsertAllTypesPreparedStatement() throws IOException, Interrupte .bind("p9") .to((String) null) .bind("p10") - // TODO: Change to jsonb when https://github.com/googleapis/java-spanner/pull/2182 - // has been merged. - .to(Value.json(null)) + .to(Value.pgJsonb(null)) .build(), 1L)); diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/TypeORMMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/TypeORMMockServerTest.java index a0f57da1a..80e7f90df 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/TypeORMMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/TypeORMMockServerTest.java @@ -728,9 +728,7 @@ public void testCreateAllTypes() throws IOException, InterruptedException { .bind("p9") .to("some random string") .bind("p10") - // TODO: Change to jsonb when https://github.com/googleapis/java-spanner/pull/2182 - // has been merged. - .to(com.google.cloud.spanner.Value.json("{\"key\":\"value\"}")) + .to(com.google.cloud.spanner.Value.pgJsonb("{\"key\":\"value\"}")) .build(), 1L)); diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/python/sqlalchemy/ITSQLAlchemySampleTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/python/sqlalchemy/ITSQLAlchemySampleTest.java new file mode 100644 index 000000000..3a2bad963 --- /dev/null +++ b/src/test/java/com/google/cloud/spanner/pgadapter/python/sqlalchemy/ITSQLAlchemySampleTest.java @@ -0,0 +1,69 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.spanner.pgadapter.python.sqlalchemy; + +import static com.google.cloud.spanner.pgadapter.python.sqlalchemy.SqlAlchemyBasicsTest.execute; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.spanner.Database; +import com.google.cloud.spanner.pgadapter.IntegrationTest; +import com.google.cloud.spanner.pgadapter.PgAdapterTestEnv; +import com.google.common.collect.ImmutableList; +import java.io.IOException; +import java.util.Collections; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@Category(IntegrationTest.class) +@RunWith(JUnit4.class) +public class ITSQLAlchemySampleTest implements IntegrationTest { + private static final PgAdapterTestEnv testEnv = new PgAdapterTestEnv(); + private static final String SAMPLE_DIR = "./samples/python/sqlalchemy-sample"; + + @BeforeClass + public static void setup() throws Exception { + testEnv.setUp(); + Database database = testEnv.createDatabase(ImmutableList.of()); + testEnv.startPGAdapterServerWithDefaultDatabase(database.getId(), Collections.emptyList()); + } + + @AfterClass + public static void teardown() { + testEnv.stopPGAdapterServer(); + testEnv.cleanUp(); + } + + @Test + public void testSQLAlchemySample() throws IOException, InterruptedException { + String output = + execute( + SAMPLE_DIR, + "run_sample.py", + "localhost", + testEnv.getServer().getLocalPort(), + testEnv.getDatabaseId()); + assertNotNull(output); + assertTrue( + output, + output.contains("No album found using a stale read.") + || output.contains( + "Album was found using a stale read, even though it has already been deleted.")); + } +} diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/python/sqlalchemy/SqlAlchemyBasicsTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/python/sqlalchemy/SqlAlchemyBasicsTest.java new file mode 100644 index 000000000..0086ab217 --- /dev/null +++ b/src/test/java/com/google/cloud/spanner/pgadapter/python/sqlalchemy/SqlAlchemyBasicsTest.java @@ -0,0 +1,551 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.spanner.pgadapter.python.sqlalchemy; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.cloud.spanner.RandomResultSetGenerator; +import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.pgadapter.AbstractMockServerTest; +import com.google.cloud.spanner.pgadapter.python.PythonTest; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ListValue; +import com.google.protobuf.Value; +import com.google.spanner.admin.database.v1.UpdateDatabaseDdlRequest; +import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.ResultSet; +import com.google.spanner.v1.ResultSetStats; +import com.google.spanner.v1.RollbackRequest; +import com.google.spanner.v1.TypeCode; +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.Scanner; +import java.util.stream.Collectors; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(Parameterized.class) +@Category(PythonTest.class) +public class SqlAlchemyBasicsTest extends AbstractMockServerTest { + + @Parameter public String host; + + @Parameters(name = "host = {0}") + public static List data() { + return ImmutableList.of(new Object[] {"localhost"}, new Object[] {""}); + } + + static String execute(String script, String host, int port) + throws IOException, InterruptedException { + return execute("./src/test/python/sqlalchemy", script, host, port); + } + + static String execute(String directory, String script, String host, int port) + throws IOException, InterruptedException { + return execute(directory, script, host, port, null); + } + + static String execute(String directory, String script, String host, int port, String database) + throws IOException, InterruptedException { + String[] runCommand = + new String[] { + "python3", script, host, Integer.toString(port), database == null ? "d" : database + }; + ProcessBuilder builder = new ProcessBuilder(); + builder.command(runCommand); + builder.directory(new File(directory)); + Process process = builder.start(); + Scanner scanner = new Scanner(process.getInputStream()); + Scanner errorScanner = new Scanner(process.getErrorStream()); + + StringBuilder output = new StringBuilder(); + while (scanner.hasNextLine()) { + output.append(scanner.nextLine()).append("\n"); + } + StringBuilder error = new StringBuilder(); + while (errorScanner.hasNextLine()) { + error.append(errorScanner.nextLine()).append("\n"); + } + int result = process.waitFor(); + assertEquals(error.toString(), 0, result); + + return output.toString(); + } + + @BeforeClass + public static void setupBaseResults() { + String selectHstoreType = + "with pg_namespace as (\n" + + " select case schema_name when 'pg_catalog' then 11 when 'public' then 2200 else 0 end as oid,\n" + + " schema_name as nspname, null as nspowner, null as nspacl\n" + + " from information_schema.schemata\n" + + "),\n" + + "pg_type as (\n" + + " select 16 as oid, 'bool' as typname, (select oid from pg_namespace where nspname='pg_catalog') as typnamespace, null as typowner, 1 as typlen, true as typbyval, 'b' as typtype, 'B' as typcategory, true as typispreferred, true as typisdefined, ',' as typdelim, 0 as typrelid, 0 as typelem, 1000 as typarray, 'boolin' as typinput, 'boolout' as typoutput, 'boolrecv' as typreceive, 'boolsend' as typsend, '-' as typmodin, '-' as typmodout, '-' as typanalyze, 'c' as typalign, 'p' as typstorage, false as typnotnull, 0 as typbasetype, -1 as typtypmod, 0 as typndims, 0 as typcollation, null as typdefaultbin, null as typdefault, null as typacl union all\n" + + " select 17 as oid, 'bytea' as typname, (select oid from pg_namespace where nspname='pg_catalog') as typnamespace, null as typowner, -1 as typlen, false as typbyval, 'b' as typtype, 'U' as typcategory, false as typispreferred, true as typisdefined, ',' as typdelim, 0 as typrelid, 0 as typelem, 1001 as typarray, 'byteain' as typinput, 'byteaout' as typoutput, 'bytearecv' as typreceive, 'byteasend' as typsend, '-' as typmodin, '-' as typmodout, '-' as typanalyze, 'i' as typalign, 'x' as typstorage, false as typnotnull, 0 as typbasetype, -1 as typtypmod, 0 as typndims, 0 as typcollation, null as typdefaultbin, null as typdefault, null as typacl union all\n" + + " select 20 as oid, 'int8' as typname, (select oid from pg_namespace where nspname='pg_catalog') as typnamespace, null as typowner, 8 as typlen, true as typbyval, 'b' as typtype, 'N' as typcategory, false as typispreferred, true as typisdefined, ',' as typdelim, 0 as typrelid, 0 as typelem, 1016 as typarray, 'int8in' as typinput, 'int8out' as typoutput, 'int8recv' as typreceive, 'int8send' as typsend, '-' as typmodin, '-' as typmodout, '-' as typanalyze, 'd' as typalign, 'p' as typstorage, false as typnotnull, 0 as typbasetype, -1 as typtypmod, 0 as typndims, 0 as typcollation, null as typdefaultbin, null as typdefault, null as typacl union all\n" + + " select 21 as oid, 'int2' as typname, (select oid from pg_namespace where nspname='pg_catalog') as typnamespace, null as typowner, 2 as typlen, true as typbyval, 'b' as typtype, 'N' as typcategory, false as typispreferred, false as typisdefined, ',' as typdelim, 0 as typrelid, 0 as typelem, 1005 as typarray, 'int2in' as typinput, 'int2out' as typoutput, 'int2recv' as typreceive, 'int2send' as typsend, '-' as typmodin, '-' as typmodout, '-' as typanalyze, 's' as typalign, 'p' as typstorage, false as typnotnull, 0 as typbasetype, -1 as typtypmod, 0 as typndims, 0 as typcollation, null as typdefaultbin, null as typdefault, null as typacl union all\n" + + " select 23 as oid, 'int4' as typname, (select oid from pg_namespace where nspname='pg_catalog') as typnamespace, null as typowner, 4 as typlen, true as typbyval, 'b' as typtype, 'N' as typcategory, false as typispreferred, false as typisdefined, ',' as typdelim, 0 as typrelid, 0 as typelem, 1007 as typarray, 'int4in' as typinput, 'int4out' as typoutput, 'int4recv' as typreceive, 'int4send' as typsend, '-' as typmodin, '-' as typmodout, '-' as typanalyze, 'i' as typalign, 'p' as typstorage, false as typnotnull, 0 as typbasetype, -1 as typtypmod, 0 as typndims, 0 as typcollation, null as typdefaultbin, null as typdefault, null as typacl union all\n" + + " select 25 as oid, 'text' as typname, (select oid from pg_namespace where nspname='pg_catalog') as typnamespace, null as typowner, -1 as typlen, false as typbyval, 'b' as typtype, 'S' as typcategory, true as typispreferred, true as typisdefined, ',' as typdelim, 0 as typrelid, 0 as typelem, 1009 as typarray, 'textin' as typinput, 'textout' as typoutput, 'textrecv' as typreceive, 'textsend' as typsend, '-' as typmodin, '-' as typmodout, '-' as typanalyze, 'i' as typalign, 'x' as typstorage, false as typnotnull, 0 as typbasetype, -1 as typtypmod, 0 as typndims, 100 as typcollation, null as typdefaultbin, null as typdefault, null as typacl union all\n" + + " select 700 as oid, 'float4' as typname, (select oid from pg_namespace where nspname='pg_catalog') as typnamespace, null as typowner, 4 as typlen, true as typbyval, 'b' as typtype, 'N' as typcategory, false as typispreferred, false as typisdefined, ',' as typdelim, 0 as typrelid, 0 as typelem, 1021 as typarray, 'float4in' as typinput, 'float4out' as typoutput, 'float4recv' as typreceive, 'float4send' as typsend, '-' as typmodin, '-' as typmodout, '-' as typanalyze, 'i' as typalign, 'p' as typstorage, false as typnotnull, 0 as typbasetype, -1 as typtypmod, 0 as typndims, 0 as typcollation, null as typdefaultbin, null as typdefault, null as typacl union all\n" + + " select 701 as oid, 'float8' as typname, (select oid from pg_namespace where nspname='pg_catalog') as typnamespace, null as typowner, 8 as typlen, true as typbyval, 'b' as typtype, 'N' as typcategory, true as typispreferred, true as typisdefined, ',' as typdelim, 0 as typrelid, 0 as typelem, 1022 as typarray, 'float8in' as typinput, 'float8out' as typoutput, 'float8recv' as typreceive, 'float8send' as typsend, '-' as typmodin, '-' as typmodout, '-' as typanalyze, 'd' as typalign, 'p' as typstorage, false as typnotnull, 0 as typbasetype, -1 as typtypmod, 0 as typndims, 0 as typcollation, null as typdefaultbin, null as typdefault, null as typacl union all\n" + + " select 1043 as oid, 'varchar' as typname, (select oid from pg_namespace where nspname='pg_catalog') as typnamespace, null as typowner, -1 as typlen, false as typbyval, 'b' as typtype, 'S' as typcategory, false as typispreferred, true as typisdefined, ',' as typdelim, 0 as typrelid, 0 as typelem, 1015 as typarray, 'varcharin' as typinput, 'varcharout' as typoutput, 'varcharrecv' as typreceive, 'varcharsend' as typsend, 'varchartypmodin' as typmodin, 'varchartypmodout' as typmodout, '-' as typanalyze, 'i' as typalign, 'x' as typstorage, false as typnotnull, 0 as typbasetype, -1 as typtypmod, 0 as typndims, 100 as typcollation, null as typdefaultbin, null as typdefault, null as typacl union all\n" + + " select 1082 as oid, 'date' as typname, (select oid from pg_namespace where nspname='pg_catalog') as typnamespace, null as typowner, 4 as typlen, true as typbyval, 'b' as typtype, 'D' as typcategory, false as typispreferred, true as typisdefined, ',' as typdelim, 0 as typrelid, 0 as typelem, 1182 as typarray, 'date_in' as typinput, 'date_out' as typoutput, 'date_recv' as typreceive, 'date_send' as typsend, '-' as typmodin, '-' as typmodout, '-' as typanalyze, 'i' as typalign, 'p' as typstorage, false as typnotnull, 0 as typbasetype, -1 as typtypmod, 0 as typndims, 0 as typcollation, null as typdefaultbin, null as typdefault, null as typacl union all\n" + + " select 1114 as oid, 'timestamp' as typname, (select oid from pg_namespace where nspname='pg_catalog') as typnamespace, null as typowner, 8 as typlen, true as typbyval, 'b' as typtype, 'D' as typcategory, false as typispreferred, false as typisdefined, ',' as typdelim, 0 as typrelid, 0 as typelem, 1115 as typarray, 'timestamp_in' as typinput, 'timestamp_out' as typoutput, 'timestamp_recv' as typreceive, 'timestamp_send' as typsend, 'timestamptypmodin' as typmodin, 'timestamptypmodout' as typmodout, '-' as typanalyze, 'd' as typalign, 'p' as typstorage, false as typnotnull, 0 as typbasetype, -1 as typtypmod, 0 as typndims, 0 as typcollation, null as typdefaultbin, null as typdefault, null as typacl union all\n" + + " select 1184 as oid, 'timestamptz' as typname, (select oid from pg_namespace where nspname='pg_catalog') as typnamespace, null as typowner, 8 as typlen, true as typbyval, 'b' as typtype, 'D' as typcategory, true as typispreferred, true as typisdefined, ',' as typdelim, 0 as typrelid, 0 as typelem, 1185 as typarray, 'timestamptz_in' as typinput, 'timestamptz_out' as typoutput, 'timestamptz_recv' as typreceive, 'timestamptz_send' as typsend, 'timestamptztypmodin' as typmodin, 'timestamptztypmodout' as typmodout, '-' as typanalyze, 'd' as typalign, 'p' as typstorage, false as typnotnull, 0 as typbasetype, -1 as typtypmod, 0 as typndims, 0 as typcollation, null as typdefaultbin, null as typdefault, null as typacl union all\n" + + " select 1700 as oid, 'numeric' as typname, (select oid from pg_namespace where nspname='pg_catalog') as typnamespace, null as typowner, -1 as typlen, false as typbyval, 'b' as typtype, 'N' as typcategory, false as typispreferred, true as typisdefined, ',' as typdelim, 0 as typrelid, 0 as typelem, 1231 as typarray, 'numeric_in' as typinput, 'numeric_out' as typoutput, 'numeric_recv' as typreceive, 'numeric_send' as typsend, 'numerictypmodin' as typmodin, 'numerictypmodout' as typmodout, '-' as typanalyze, 'i' as typalign, 'm' as typstorage, false as typnotnull, 0 as typbasetype, -1 as typtypmod, 0 as typndims, 0 as typcollation, null as typdefaultbin, null as typdefault, null as typacl union all\n" + + " select 3802 as oid, 'jsonb' as typname, (select oid from pg_namespace where nspname='pg_catalog') as typnamespace, null as typowner, -1 as typlen, false as typbyval, 'b' as typtype, 'U' as typcategory, false as typispreferred, true as typisdefined, ',' as typdelim, 0 as typrelid, 0 as typelem, 3807 as typarray, 'jsonb_in' as typinput, 'jsonb_out' as typoutput, 'jsonb_recv' as typreceive, 'jsonb_send' as typsend, '-' as typmodin, '-' as typmodout, '-' as typanalyze, 'i' as typalign, 'x' as typstorage, false as typnotnull, 0 as typbasetype, -1 as typtypmod, 0 as typndims, 0 as typcollation, null as typdefaultbin, null as typdefault, null as typacl\n" + + ")\n" + + "SELECT t.oid, typarray\n" + + "FROM pg_type t JOIN pg_namespace ns\n" + + " ON typnamespace = ns.oid\n" + + "WHERE typname = 'hstore'"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(selectHstoreType), + ResultSet.newBuilder() + .setMetadata(createMetadata(ImmutableList.of(TypeCode.INT64, TypeCode.INT64))) + .build())); + } + + @Test + public void testHelloWorld() throws IOException, InterruptedException { + String sql = "select 'hello world'"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(sql), + ResultSet.newBuilder() + .setMetadata(createMetadata(ImmutableList.of(TypeCode.STRING))) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("hello world").build()) + .build()) + .build())); + + String actualOutput = execute("hello_world.py", host, pgServer.getLocalPort()); + String expectedOutput = "[('hello world',)]\n"; + assertEquals(expectedOutput, actualOutput); + + assertEquals(2, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(1); + assertEquals(sql, request.getSql()); + assertTrue(request.getTransaction().hasSingleUse()); + assertTrue(request.getTransaction().getSingleUse().hasReadOnly()); + } + + @Test + public void testSimpleInsert() throws IOException, InterruptedException { + String sql1 = "INSERT INTO test VALUES (1, 'One')"; + mockSpanner.putStatementResult(StatementResult.update(Statement.of(sql1), 1L)); + String sql2 = "INSERT INTO test VALUES (2, 'Two')"; + mockSpanner.putStatementResult(StatementResult.update(Statement.of(sql2), 1L)); + + String actualOutput = execute("simple_insert.py", host, pgServer.getLocalPort()); + String expectedOutput = "Row count: 2\n"; + assertEquals(expectedOutput, actualOutput); + + assertEquals(3, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest request1 = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(1); + assertEquals(sql1, request1.getSql()); + assertTrue(request1.getTransaction().hasBegin()); + assertTrue(request1.getTransaction().getBegin().hasReadWrite()); + ExecuteSqlRequest request2 = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(2); + assertEquals(sql2, request2.getSql()); + assertTrue(request2.getTransaction().hasId()); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + } + + @Test + public void testEngineBegin() throws IOException, InterruptedException { + String sql1 = "INSERT INTO test VALUES (3, 'Three')"; + mockSpanner.putStatementResult(StatementResult.update(Statement.of(sql1), 1L)); + String sql2 = "INSERT INTO test VALUES (4, 'Four')"; + mockSpanner.putStatementResult(StatementResult.update(Statement.of(sql2), 1L)); + + String actualOutput = execute("engine_begin.py", host, pgServer.getLocalPort()); + String expectedOutput = "Row count: 2\n"; + assertEquals(expectedOutput, actualOutput); + + assertEquals(3, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest request1 = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(1); + assertEquals(sql1, request1.getSql()); + assertTrue(request1.getTransaction().hasBegin()); + assertTrue(request1.getTransaction().getBegin().hasReadWrite()); + ExecuteSqlRequest request2 = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(2); + assertEquals(sql2, request2.getSql()); + assertTrue(request2.getTransaction().hasId()); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + } + + @Test + public void testSessionExecute() throws IOException, InterruptedException { + String sql1 = "UPDATE test SET value='one' WHERE id=1"; + mockSpanner.putStatementResult(StatementResult.update(Statement.of(sql1), 1L)); + String sql2 = "UPDATE test SET value='two' WHERE id=2"; + mockSpanner.putStatementResult(StatementResult.update(Statement.of(sql2), 1L)); + + String actualOutput = execute("session_execute.py", host, pgServer.getLocalPort()); + String expectedOutput = "Row count: 2\n"; + assertEquals(expectedOutput, actualOutput); + + assertEquals(3, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest request1 = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(1); + assertEquals(sql1, request1.getSql()); + assertTrue(request1.getTransaction().hasBegin()); + assertTrue(request1.getTransaction().getBegin().hasReadWrite()); + ExecuteSqlRequest request2 = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(2); + assertEquals(sql2, request2.getSql()); + assertTrue(request2.getTransaction().hasId()); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + } + + @Test + public void testSimpleMetadata() throws Exception { + String checkTableExistsSql = + "with pg_class as (\n" + + " select\n" + + " -1 as oid,\n" + + " table_name as relname,\n" + + " case table_schema when 'pg_catalog' then 11 when 'public' then 2200 else 0 end as relnamespace,\n" + + " 0 as reltype,\n" + + " 0 as reloftype,\n" + + " 0 as relowner,\n" + + " 1 as relam,\n" + + " 0 as relfilenode,\n" + + " 0 as reltablespace,\n" + + " 0 as relpages,\n" + + " 0.0::float8 as reltuples,\n" + + " 0 as relallvisible,\n" + + " 0 as reltoastrelid,\n" + + " false as relhasindex,\n" + + " false as relisshared,\n" + + " 'p' as relpersistence,\n" + + " 'r' as relkind,\n" + + " count(*) as relnatts,\n" + + " 0 as relchecks,\n" + + " false as relhasrules,\n" + + " false as relhastriggers,\n" + + " false as relhassubclass,\n" + + " false as relrowsecurity,\n" + + " false as relforcerowsecurity,\n" + + " true as relispopulated,\n" + + " 'n' as relreplident,\n" + + " false as relispartition,\n" + + " 0 as relrewrite,\n" + + " 0 as relfrozenxid,\n" + + " 0 as relminmxid,\n" + + " '{}'::bigint[] as relacl,\n" + + " '{}'::text[] as reloptions,\n" + + " 0 as relpartbound\n" + + "from information_schema.tables t\n" + + "inner join information_schema.columns using (table_catalog, table_schema, table_name)\n" + + "group by t.table_name, t.table_schema\n" + + "union all\n" + + "select\n" + + " -1 as oid,\n" + + " i.index_name as relname,\n" + + " case table_schema when 'pg_catalog' then 11 when 'public' then 2200 else 0 end as relnamespace,\n" + + " 0 as reltype,\n" + + " 0 as reloftype,\n" + + " 0 as relowner,\n" + + " 1 as relam,\n" + + " 0 as relfilenode,\n" + + " 0 as reltablespace,\n" + + " 0 as relpages,\n" + + " 0.0::float8 as reltuples,\n" + + " 0 as relallvisible,\n" + + " 0 as reltoastrelid,\n" + + " false as relhasindex,\n" + + " false as relisshared,\n" + + " 'p' as relpersistence,\n" + + " 'r' as relkind,\n" + + " count(*) as relnatts,\n" + + " 0 as relchecks,\n" + + " false as relhasrules,\n" + + " false as relhastriggers,\n" + + " false as relhassubclass,\n" + + " false as relrowsecurity,\n" + + " false as relforcerowsecurity,\n" + + " true as relispopulated,\n" + + " 'n' as relreplident,\n" + + " false as relispartition,\n" + + " 0 as relrewrite,\n" + + " 0 as relfrozenxid,\n" + + " 0 as relminmxid,\n" + + " '{}'::bigint[] as relacl,\n" + + " '{}'::text[] as reloptions,\n" + + " 0 as relpartbound\n" + + "from information_schema.indexes i\n" + + "inner join information_schema.index_columns using (table_catalog, table_schema, table_name)\n" + + "group by i.index_name, i.table_schema\n" + + "),\n" + + "pg_namespace as (\n" + + " select case schema_name when 'pg_catalog' then 11 when 'public' then 2200 else 0 end as oid,\n" + + " schema_name as nspname, null as nspowner, null as nspacl\n" + + " from information_schema.schemata\n" + + ")\n" + + "select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where true and relname='%s'"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(String.format(checkTableExistsSql, "user_account")), + ResultSet.newBuilder().setMetadata(SELECT1_RESULTSET.getMetadata()).build())); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(String.format(checkTableExistsSql, "address")), + ResultSet.newBuilder().setMetadata(SELECT1_RESULTSET.getMetadata()).build())); + addDdlResponseToSpannerAdmin(); + + String actualOutput = execute("simple_metadata.py", host, pgServer.getLocalPort()); + String expectedOutput = + "user_account.name\n" + + "['id', 'name', 'fullname']\n" + + "PrimaryKeyConstraint(Column('id', Integer(), table=, primary_key=True, nullable=False))\n" + + "user_account\n" + + "address\n"; + assertEquals(expectedOutput, actualOutput); + + List requests = + mockDatabaseAdmin.getRequests().stream() + .filter(req -> req instanceof UpdateDatabaseDdlRequest) + .map(req -> (UpdateDatabaseDdlRequest) req) + .collect(Collectors.toList()); + assertEquals(1, requests.size()); + assertEquals(2, requests.get(0).getStatementsCount()); + assertEquals( + "CREATE TABLE user_account (\n" + + "\tid SERIAL NOT NULL, \n" + + "\tname VARCHAR(30), \n" + + "\tfullname VARCHAR, \n" + + "\tPRIMARY KEY (id)\n" + + ")", + requests.get(0).getStatements(0)); + assertEquals( + "CREATE TABLE address (\n" + + "\tid SERIAL NOT NULL, \n" + + "\temail_address VARCHAR NOT NULL, \n" + + "\tuser_id INTEGER, \n" + + "\tPRIMARY KEY (id), \n" + + "\tFOREIGN KEY(user_id) REFERENCES user_account (id)\n" + + ")", + requests.get(0).getStatements(1)); + } + + @Test + public void testCoreInsert() throws IOException, InterruptedException { + String sql = + "INSERT INTO user_account (name, fullname) VALUES " + + "('spongebob', 'Spongebob Squarepants') RETURNING user_account.id"; + mockSpanner.putStatementResult(StatementResult.query(Statement.of(sql), SELECT1_RESULTSET)); + String sqlMultiple = + "INSERT INTO user_account (name, fullname) VALUES " + + "('sandy', 'Sandy Cheeks'),('patrick', 'Patrick Star')"; + mockSpanner.putStatementResult(StatementResult.update(Statement.of(sqlMultiple), 2L)); + + String actualOutput = execute("core_insert.py", host, pgServer.getLocalPort()); + String expectedOutput = + "INSERT INTO user_account (name, fullname) VALUES (:name, :fullname)\n" + + "{'name': 'spongebob', 'fullname': 'Spongebob Squarepants'}\n" + + "Result: []\n" + + "Row count: 1\n" + + "Inserted primary key: (1,)\n" + + "Row count: 2\n"; + assertEquals(expectedOutput, actualOutput); + + assertEquals(3, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(1); + assertEquals(sql, request.getSql()); + assertTrue(request.getTransaction().hasBegin()); + assertTrue(request.getTransaction().getBegin().hasReadWrite()); + ExecuteSqlRequest multipleRequest = + mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(2); + assertTrue(multipleRequest.getTransaction().hasId()); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + } + + @Test + public void testCoreInsertFromSelect() throws IOException, InterruptedException { + String sql = + "INSERT INTO address (user_id, email_address) " + + "SELECT user_account.id, user_account.name || '@aol.com' AS anon_1 \n" + + "FROM user_account RETURNING address.id, address.email_address"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(sql), + ResultSet.newBuilder() + .setMetadata(createMetadata(ImmutableList.of(TypeCode.INT64, TypeCode.STRING))) + .setStats(ResultSetStats.newBuilder().setRowCountExact(2L).build()) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("1").build()) + .addValues(Value.newBuilder().setStringValue("test1@aol.com")) + .build()) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("2").build()) + .addValues(Value.newBuilder().setStringValue("test2@aol.com")) + .build()) + .build())); + + String actualOutput = execute("core_insert_from_select.py", host, pgServer.getLocalPort()); + String expectedOutput = + "INSERT INTO address (user_id, email_address) " + + "SELECT user_account.id, user_account.name || :name_1 AS anon_1 \n" + + "FROM user_account\n" + + "Inserted rows: 2\n" + + "Returned rows: [(1, 'test1@aol.com'), (2, 'test2@aol.com')]\n"; + assertEquals(expectedOutput, actualOutput); + } + + @Test + public void testCoreSelect() throws IOException, InterruptedException { + String sql = + "SELECT user_account.id, user_account.name, user_account.fullname \n" + + "FROM user_account \n" + + "WHERE user_account.name = 'spongebob'"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(sql), + ResultSet.newBuilder() + .setMetadata( + createMetadata( + ImmutableList.of(TypeCode.INT64, TypeCode.STRING, TypeCode.STRING))) + .setStats(ResultSetStats.newBuilder().setRowCountExact(2L).build()) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("1").build()) + .addValues(Value.newBuilder().setStringValue("test1@aol.com")) + .addValues(Value.newBuilder().setStringValue("Bob Test1")) + .build()) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("2").build()) + .addValues(Value.newBuilder().setStringValue("test2@aol.com")) + .addValues(Value.newBuilder().setStringValue("Bob Test2")) + .build()) + .build())); + + String actualOutput = execute("core_select.py", host, pgServer.getLocalPort()); + String expectedOutput = + "SELECT user_account.id, user_account.name, user_account.fullname \n" + + "FROM user_account \n" + + "WHERE user_account.name = :name_1\n" + + "(1, 'test1@aol.com', 'Bob Test1')\n" + + "(2, 'test2@aol.com', 'Bob Test2')\n"; + assertEquals(expectedOutput, actualOutput); + + assertEquals(2, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(1); + assertTrue(request.getTransaction().hasBegin()); + assertTrue(request.getTransaction().getBegin().hasReadWrite()); + assertEquals(2, mockSpanner.countRequestsOfType(RollbackRequest.class)); + } + + @Test + public void testAutoCommit() throws IOException, InterruptedException { + String sql = + "SELECT user_account.id, user_account.name, user_account.fullname \n" + + "FROM user_account \n" + + "WHERE user_account.name = 'spongebob'"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(sql), + ResultSet.newBuilder() + .setMetadata( + createMetadata( + ImmutableList.of(TypeCode.INT64, TypeCode.STRING, TypeCode.STRING))) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("1").build()) + .addValues(Value.newBuilder().setStringValue("test1@aol.com")) + .addValues(Value.newBuilder().setStringValue("Bob Test1")) + .build()) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("2").build()) + .addValues(Value.newBuilder().setStringValue("test2@aol.com")) + .addValues(Value.newBuilder().setStringValue("Bob Test2")) + .build()) + .build())); + String insertSql1 = + "INSERT INTO user_account (name, fullname) VALUES " + + "('sandy', 'Sandy Cheeks') RETURNING user_account.id"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(insertSql1), + ResultSet.newBuilder() + .setMetadata(SELECT1_RESULTSET.getMetadata()) + .setStats(ResultSetStats.newBuilder().setRowCountExact(1L).build()) + .addRows(SELECT1_RESULTSET.getRows(0)) + .build())); + String insertSql2 = + "INSERT INTO user_account (name, fullname) VALUES " + + "('patrick', 'Patrick Star') RETURNING user_account.id"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(insertSql2), + ResultSet.newBuilder() + .setMetadata(SELECT2_RESULTSET.getMetadata()) + .setStats(ResultSetStats.newBuilder().setRowCountExact(1L).build()) + .addRows(SELECT2_RESULTSET.getRows(0)) + .build())); + + String actualOutput = execute("autocommit.py", host, pgServer.getLocalPort()); + String expectedOutput = + "SERIALIZABLE\n" + + "(1, 'test1@aol.com', 'Bob Test1')\n" + + "(2, 'test2@aol.com', 'Bob Test2')\n" + + "Row count: 1\n" + + "Row count: 1\n"; + assertEquals(expectedOutput, actualOutput); + + assertEquals(4, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest selectRequest = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(1); + assertTrue(selectRequest.getTransaction().hasSingleUse()); + assertTrue(selectRequest.getTransaction().getSingleUse().hasReadOnly()); + + ExecuteSqlRequest insertRequest1 = + mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(2); + assertTrue(insertRequest1.getTransaction().hasBegin()); + assertTrue(insertRequest1.getTransaction().getBegin().hasReadWrite()); + ExecuteSqlRequest insertRequest2 = + mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(2); + assertTrue(insertRequest2.getTransaction().hasBegin()); + assertTrue(insertRequest2.getTransaction().getBegin().hasReadWrite()); + + assertEquals(2, mockSpanner.countRequestsOfType(CommitRequest.class)); + } + + @Ignore("requires DECLARE support, https://github.com/GoogleCloudPlatform/pgadapter/issues/510") + @Test + public void testServerSideCursors() throws IOException, InterruptedException { + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of("select * from random"), new RandomResultSetGenerator(100).generate())); + + String actualOutput = execute("server_side_cursor.py", host, pgServer.getLocalPort()); + String expectedOutput = ""; + assertEquals(expectedOutput, actualOutput); + } +} diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/python/sqlalchemy/SqlAlchemyOrmTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/python/sqlalchemy/SqlAlchemyOrmTest.java new file mode 100644 index 000000000..4391a9dba --- /dev/null +++ b/src/test/java/com/google/cloud/spanner/pgadapter/python/sqlalchemy/SqlAlchemyOrmTest.java @@ -0,0 +1,645 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.spanner.pgadapter.python.sqlalchemy; + +import static com.google.cloud.spanner.pgadapter.python.sqlalchemy.SqlAlchemyBasicsTest.execute; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.pgadapter.AbstractMockServerTest; +import com.google.cloud.spanner.pgadapter.python.PythonTest; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.Duration; +import com.google.protobuf.ListValue; +import com.google.protobuf.Value; +import com.google.spanner.v1.BeginTransactionRequest; +import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.ExecuteSqlRequest.QueryMode; +import com.google.spanner.v1.ResultSet; +import com.google.spanner.v1.ResultSetStats; +import com.google.spanner.v1.RollbackRequest; +import com.google.spanner.v1.TypeCode; +import io.grpc.Status; +import java.io.IOException; +import java.util.List; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(Parameterized.class) +@Category(PythonTest.class) +public class SqlAlchemyOrmTest extends AbstractMockServerTest { + + @Parameter public String host; + + @Parameters(name = "host = {0}") + public static List data() { + return ImmutableList.of(new Object[] {"localhost"}, new Object[] {""}); + } + + @BeforeClass + public static void setupBaseResults() { + SqlAlchemyBasicsTest.setupBaseResults(); + } + + @Test + public void testInsertAllTypes() throws IOException, InterruptedException { + String sql = + "INSERT INTO all_types (col_bigint, col_bool, col_bytea, col_float8, col_int, col_numeric, col_timestamptz, col_date, col_varchar, col_jsonb) " + + "VALUES (1, true, '\\x74657374206279746573'::bytea, 3.14, 100, 6.626, '2011-11-04T00:05:23.123456+00:00'::timestamptz, '2011-11-04'::date, 'test string', '{\"key1\": \"value1\", \"key2\": \"value2\"}')"; + mockSpanner.putStatementResult(StatementResult.update(Statement.of(sql), 1L)); + + String actualOutput = execute("orm_insert.py", host, pgServer.getLocalPort()); + String expectedOutput = "Inserted 1 row(s)\n"; + assertEquals(expectedOutput, actualOutput); + + assertEquals(2, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(1); + assertEquals(sql, request.getSql()); + assertTrue(request.getTransaction().hasBegin()); + assertTrue(request.getTransaction().getBegin().hasReadWrite()); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + } + + @Test + public void testInsertAllTypes_NullValues() throws IOException, InterruptedException { + // Note that the JSONB column is 'null' instead of NULL. That means that SQLAlchemy is inserting + // a JSON null values instead of a SQL NULL value into the column. This can be changed by + // creating the columns as Column(JSONB(none_as_null=True)) in the SQLAlchemy model. + String sql = + "INSERT INTO all_types (col_bigint, col_bool, col_bytea, col_float8, col_int, col_numeric, col_timestamptz, col_date, col_varchar, col_jsonb) " + + "VALUES (1, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, 'null')"; + mockSpanner.putStatementResult(StatementResult.update(Statement.of(sql), 1L)); + + String actualOutput = execute("orm_insert_null_values.py", host, pgServer.getLocalPort()); + String expectedOutput = "Inserted 1 row(s)\n"; + assertEquals(expectedOutput, actualOutput); + + assertEquals(2, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(1); + assertEquals(sql, request.getSql()); + assertTrue(request.getTransaction().hasBegin()); + assertTrue(request.getTransaction().getBegin().hasReadWrite()); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + } + + @Ignore("requires support for literals like '2022-12-16 10:11:12+01:00'::timestamptz") + @Test + public void testInsertAllTypesWithPreparedStatement() throws IOException, InterruptedException { + String sql = + "INSERT INTO all_types (col_bigint, col_bool, col_bytea, col_float8, col_int, col_numeric, col_timestamptz, col_date, col_varchar, col_jsonb) " + + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(sql), + ResultSet.newBuilder() + .setMetadata( + createParameterTypesMetadata( + ImmutableList.of( + TypeCode.INT64, + TypeCode.BOOL, + TypeCode.BYTES, + TypeCode.FLOAT64, + TypeCode.INT64, + TypeCode.NUMERIC, + TypeCode.TIMESTAMP, + TypeCode.DATE, + TypeCode.STRING, + TypeCode.JSON))) + .setStats(ResultSetStats.newBuilder().build()) + .build())); + + String actualOutput = + execute("orm_insert_with_prepared_statement.py", host, pgServer.getLocalPort()); + String expectedOutput = "Inserted 1 row(s)\n"; + assertEquals(expectedOutput, actualOutput); + + assertEquals(2, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(1); + assertEquals(sql, request.getSql()); + assertTrue(request.getTransaction().hasBegin()); + assertTrue(request.getTransaction().getBegin().hasReadWrite()); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + } + + @Test + public void testSelectAllTypes() throws IOException, InterruptedException { + String sql = + "SELECT all_types.col_bigint, all_types.col_bool, all_types.col_bytea, all_types.col_float8, all_types.col_int, all_types.col_numeric, all_types.col_timestamptz, all_types.col_date, all_types.col_varchar, all_types.col_jsonb \n" + + "FROM all_types"; + mockSpanner.putStatementResult( + StatementResult.query(Statement.of(sql), createAllTypesResultSet("", true))); + + String actualOutput = execute("orm_select_first.py", host, pgServer.getLocalPort()); + String expectedOutput = + "AllTypes(col_bigint= 1,col_bool= True,col_bytea= b'test'col_float8= 3.14col_int= 100col_numeric= Decimal('6.626')col_timestamptz=datetime.datetime(2022, 2, 16, 13, 18, 2, 123456, tzinfo=datetime.timezone.utc)col_date= datetime.date(2022, 3, 29)col_varchar= 'test'col_jsonb= {'key': 'value'})\n"; + assertEquals(expectedOutput, actualOutput); + + assertEquals(2, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(1); + assertEquals(sql, request.getSql()); + assertTrue(request.getTransaction().hasBegin()); + assertTrue(request.getTransaction().getBegin().hasReadWrite()); + assertEquals(0, mockSpanner.countRequestsOfType(CommitRequest.class)); + assertEquals(2, mockSpanner.countRequestsOfType(RollbackRequest.class)); + } + + @Test + public void testGetAllTypes() throws IOException, InterruptedException { + String sql = + "SELECT all_types.col_bigint AS all_types_col_bigint, all_types.col_bool AS all_types_col_bool, all_types.col_bytea AS all_types_col_bytea, all_types.col_float8 AS all_types_col_float8, all_types.col_int AS all_types_col_int, all_types.col_numeric AS all_types_col_numeric, all_types.col_timestamptz AS all_types_col_timestamptz, all_types.col_date AS all_types_col_date, all_types.col_varchar AS all_types_col_varchar, all_types.col_jsonb AS all_types_col_jsonb \n" + + "FROM all_types \n" + + "WHERE all_types.col_bigint = 1"; + mockSpanner.putStatementResult( + StatementResult.query(Statement.of(sql), createAllTypesResultSet("", true))); + + String actualOutput = execute("orm_get.py", host, pgServer.getLocalPort()); + String expectedOutput = + "AllTypes(col_bigint= 1,col_bool= True,col_bytea= b'test'col_float8= 3.14col_int= 100col_numeric= Decimal('6.626')col_timestamptz=datetime.datetime(2022, 2, 16, 13, 18, 2, 123456, tzinfo=datetime.timezone.utc)col_date= datetime.date(2022, 3, 29)col_varchar= 'test'col_jsonb= {'key': 'value'})\n"; + assertEquals(expectedOutput, actualOutput); + + assertEquals(2, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(1); + assertEquals(sql, request.getSql()); + assertTrue(request.getTransaction().hasBegin()); + assertTrue(request.getTransaction().getBegin().hasReadWrite()); + assertEquals(0, mockSpanner.countRequestsOfType(CommitRequest.class)); + assertEquals(2, mockSpanner.countRequestsOfType(RollbackRequest.class)); + } + + @Test + public void testGetAllTypes_NullValues() throws IOException, InterruptedException { + String sql = + "SELECT all_types.col_bigint AS all_types_col_bigint, all_types.col_bool AS all_types_col_bool, all_types.col_bytea AS all_types_col_bytea, all_types.col_float8 AS all_types_col_float8, all_types.col_int AS all_types_col_int, all_types.col_numeric AS all_types_col_numeric, all_types.col_timestamptz AS all_types_col_timestamptz, all_types.col_date AS all_types_col_date, all_types.col_varchar AS all_types_col_varchar, all_types.col_jsonb AS all_types_col_jsonb \n" + + "FROM all_types \n" + + "WHERE all_types.col_bigint = 1"; + mockSpanner.putStatementResult( + StatementResult.query(Statement.of(sql), createAllTypesNullResultSet("", 1L))); + + String actualOutput = execute("orm_get.py", host, pgServer.getLocalPort()); + String expectedOutput = + "AllTypes(col_bigint= 1,col_bool= None,col_bytea= Nonecol_float8= Nonecol_int= Nonecol_numeric= Nonecol_timestamptz=Nonecol_date= Nonecol_varchar= Nonecol_jsonb= None)\n"; + assertEquals(expectedOutput, actualOutput); + + assertEquals(2, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(1); + assertEquals(sql, request.getSql()); + assertTrue(request.getTransaction().hasBegin()); + assertTrue(request.getTransaction().getBegin().hasReadWrite()); + assertEquals(0, mockSpanner.countRequestsOfType(CommitRequest.class)); + assertEquals(2, mockSpanner.countRequestsOfType(RollbackRequest.class)); + } + + @Test + public void testGetAllTypesWithPreparedStatement() throws IOException, InterruptedException { + String sql = "select * from all_types where col_bigint=$1"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(sql), + ResultSet.newBuilder() + .setMetadata( + createAllTypesResultSetMetadata("") + .toBuilder() + .setUndeclaredParameters( + createParameterTypesMetadata(ImmutableList.of(TypeCode.INT64)) + .getUndeclaredParameters())) + .build())); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.newBuilder(sql).bind("p1").to(1L).build(), + createAllTypesResultSet("", true))); + + String actualOutput = + execute("orm_get_with_prepared_statement.py", host, pgServer.getLocalPort()); + String expectedOutput = + "AllTypes(col_bigint= 1,col_bool= True,col_bytea= b'test'col_float8= 3.14col_int= 100col_numeric= Decimal('6.626')col_timestamptz=datetime.datetime(2022, 2, 16, 13, 18, 2, 123456, tzinfo=datetime.timezone.utc)col_date= datetime.date(2022, 3, 29)col_varchar= 'test'col_jsonb= {'key': 'value'})\n"; + assertEquals(expectedOutput, actualOutput); + + // We receive 3 ExecuteSqlRequests: + // 1. Internal metadata query from SQLAlchemy (ignored in this test). + // 2. The SQL statement in PLAN mode to prepare the statement. + // 3. The SQL statement in NORMAL mode with 1 as the parameter value. + assertEquals(3, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest planRequest = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(1); + assertEquals(QueryMode.PLAN, planRequest.getQueryMode()); + assertEquals(sql, planRequest.getSql()); + assertTrue(planRequest.getTransaction().hasBegin()); + assertTrue(planRequest.getTransaction().getBegin().hasReadWrite()); + + ExecuteSqlRequest executeRequest = + mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(2); + assertEquals(QueryMode.NORMAL, executeRequest.getQueryMode()); + assertEquals(sql, executeRequest.getSql()); + assertTrue(executeRequest.getTransaction().hasId()); + + assertEquals(0, mockSpanner.countRequestsOfType(CommitRequest.class)); + assertEquals(2, mockSpanner.countRequestsOfType(RollbackRequest.class)); + } + + @Test + public void testOrmReadOnlyTransaction() throws IOException, InterruptedException { + String sql = + "SELECT all_types.col_bigint AS all_types_col_bigint, all_types.col_bool AS all_types_col_bool, all_types.col_bytea AS all_types_col_bytea, all_types.col_float8 AS all_types_col_float8, all_types.col_int AS all_types_col_int, all_types.col_numeric AS all_types_col_numeric, all_types.col_timestamptz AS all_types_col_timestamptz, all_types.col_date AS all_types_col_date, all_types.col_varchar AS all_types_col_varchar, all_types.col_jsonb AS all_types_col_jsonb \n" + + "FROM all_types \n" + + "WHERE all_types.col_bigint = 1"; + mockSpanner.putStatementResult( + StatementResult.query(Statement.of(sql), createAllTypesResultSet("", true))); + + String actualOutput = execute("orm_read_only_transaction.py", host, pgServer.getLocalPort()); + String expectedOutput = + "AllTypes(col_bigint= 1,col_bool= True,col_bytea= b'test'col_float8= 3.14col_int= 100col_numeric= Decimal('6.626')col_timestamptz=datetime.datetime(2022, 2, 16, 13, 18, 2, 123456, tzinfo=datetime.timezone.utc)col_date= datetime.date(2022, 3, 29)col_varchar= 'test'col_jsonb= {'key': 'value'})\n"; + assertEquals(expectedOutput, actualOutput); + + assertEquals(2, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(1); + assertEquals(sql, request.getSql()); + assertFalse(request.getTransaction().hasBegin()); + assertTrue(request.getTransaction().hasId()); + assertEquals(1, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); + assertEquals( + 1, + mockSpanner.getRequestsOfType(BeginTransactionRequest.class).stream() + .filter(req -> req.getOptions().hasReadOnly()) + .count()); + assertEquals(0, mockSpanner.countRequestsOfType(CommitRequest.class)); + // This rollback request comes from a system query before the actual data query. + // The read-only transaction is not committed or rolled back. + assertEquals(1, mockSpanner.countRequestsOfType(RollbackRequest.class)); + } + + @Test + public void testStaleRead() throws IOException, InterruptedException { + String sql = + "SELECT all_types.col_bigint AS all_types_col_bigint, all_types.col_bool AS all_types_col_bool, all_types.col_bytea AS all_types_col_bytea, all_types.col_float8 AS all_types_col_float8, all_types.col_int AS all_types_col_int, all_types.col_numeric AS all_types_col_numeric, all_types.col_timestamptz AS all_types_col_timestamptz, all_types.col_date AS all_types_col_date, all_types.col_varchar AS all_types_col_varchar, all_types.col_jsonb AS all_types_col_jsonb \n" + + "FROM all_types \n" + + "WHERE all_types.col_bigint = %d"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(String.format(sql, 1)), createAllTypesResultSet("1", "", true))); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(String.format(sql, 2)), createAllTypesResultSet("2", "", true))); + + String actualOutput = execute("orm_stale_read.py", host, pgServer.getLocalPort()); + assertTrue(actualOutput, actualOutput.contains("AllTypes(col_bigint= 1")); + assertTrue(actualOutput, actualOutput.contains("AllTypes(col_bigint= 2")); + + assertEquals(3, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + for (int index = 1; index < 3; index++) { + ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(index); + assertTrue(request.getTransaction().hasSingleUse()); + assertTrue(request.getTransaction().getSingleUse().hasReadOnly()); + assertTrue(request.getTransaction().getSingleUse().getReadOnly().hasMaxStaleness()); + assertEquals( + Duration.newBuilder().setSeconds(10L).build(), + request.getTransaction().getSingleUse().getReadOnly().getMaxStaleness()); + } + // This rollback request comes from a system query before the actual data query. + // The read-only transaction is not committed or rolled back. + assertEquals(1, mockSpanner.countRequestsOfType(RollbackRequest.class)); + } + + @Test + public void testUpdateAllTypes() throws IOException, InterruptedException { + String sql = + "SELECT all_types.col_bigint AS all_types_col_bigint, all_types.col_bool AS all_types_col_bool, all_types.col_bytea AS all_types_col_bytea, all_types.col_float8 AS all_types_col_float8, all_types.col_int AS all_types_col_int, all_types.col_numeric AS all_types_col_numeric, all_types.col_timestamptz AS all_types_col_timestamptz, all_types.col_date AS all_types_col_date, all_types.col_varchar AS all_types_col_varchar, all_types.col_jsonb AS all_types_col_jsonb \n" + + "FROM all_types \n" + + "WHERE all_types.col_bigint = 1"; + mockSpanner.putStatementResult( + StatementResult.query(Statement.of(sql), createAllTypesResultSet("", true))); + String updateSql = + "UPDATE all_types SET col_varchar='updated string' WHERE all_types.col_bigint = 1"; + mockSpanner.putStatementResult(StatementResult.update(Statement.of(updateSql), 1L)); + + String actualOutput = execute("orm_update.py", host, pgServer.getLocalPort()); + String expectedOutput = + "AllTypes(col_bigint= 1,col_bool= True,col_bytea= b'test'col_float8= 3.14col_int= 100col_numeric= Decimal('6.626')col_timestamptz=datetime.datetime(2022, 2, 16, 13, 18, 2, 123456, tzinfo=datetime.timezone.utc)col_date= datetime.date(2022, 3, 29)col_varchar= 'updated string'col_jsonb= {'key': 'value'})\n"; + assertEquals(expectedOutput, actualOutput); + + assertEquals(3, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest selectRequest = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(1); + assertEquals(sql, selectRequest.getSql()); + assertTrue(selectRequest.getTransaction().hasBegin()); + assertTrue(selectRequest.getTransaction().getBegin().hasReadWrite()); + + ExecuteSqlRequest updateRequest = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(2); + assertEquals(updateSql, updateRequest.getSql()); + assertTrue(updateRequest.getTransaction().hasId()); + + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(RollbackRequest.class)); + } + + @Test + public void testDeleteAllTypes() throws IOException, InterruptedException { + String sql = + "SELECT all_types.col_bigint AS all_types_col_bigint, all_types.col_bool AS all_types_col_bool, all_types.col_bytea AS all_types_col_bytea, all_types.col_float8 AS all_types_col_float8, all_types.col_int AS all_types_col_int, all_types.col_numeric AS all_types_col_numeric, all_types.col_timestamptz AS all_types_col_timestamptz, all_types.col_date AS all_types_col_date, all_types.col_varchar AS all_types_col_varchar, all_types.col_jsonb AS all_types_col_jsonb \n" + + "FROM all_types \n" + + "WHERE all_types.col_bigint = 1"; + mockSpanner.putStatementResult( + StatementResult.query(Statement.of(sql), createAllTypesResultSet("", true))); + String deleteSql = "DELETE FROM all_types WHERE all_types.col_bigint = 1"; + mockSpanner.putStatementResult(StatementResult.update(Statement.of(deleteSql), 1L)); + + String actualOutput = execute("orm_delete.py", host, pgServer.getLocalPort()); + String expectedOutput = "deleted row\n"; + assertEquals(expectedOutput, actualOutput); + + assertEquals(3, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest selectRequest = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(1); + assertEquals(sql, selectRequest.getSql()); + assertTrue(selectRequest.getTransaction().hasBegin()); + assertTrue(selectRequest.getTransaction().getBegin().hasReadWrite()); + + ExecuteSqlRequest deleteRequest = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(2); + assertEquals(deleteSql, deleteRequest.getSql()); + assertTrue(deleteRequest.getTransaction().hasId()); + + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(RollbackRequest.class)); + } + + @Test + public void testRollback() throws IOException, InterruptedException { + String sql = + "SELECT all_types.col_bigint AS all_types_col_bigint, all_types.col_bool AS all_types_col_bool, all_types.col_bytea AS all_types_col_bytea, all_types.col_float8 AS all_types_col_float8, all_types.col_int AS all_types_col_int, all_types.col_numeric AS all_types_col_numeric, all_types.col_timestamptz AS all_types_col_timestamptz, all_types.col_date AS all_types_col_date, all_types.col_varchar AS all_types_col_varchar, all_types.col_jsonb AS all_types_col_jsonb \n" + + "FROM all_types \n" + + "WHERE all_types.col_bigint = 1"; + mockSpanner.putStatementResult( + StatementResult.query(Statement.of(sql), createAllTypesResultSet("", true))); + String updateSql = + "UPDATE all_types SET col_varchar='updated string' WHERE all_types.col_bigint = 1"; + mockSpanner.putStatementResult(StatementResult.update(Statement.of(updateSql), 1L)); + + String actualOutput = execute("orm_rollback.py", host, pgServer.getLocalPort()); + String expectedOutput = + "Before rollback: AllTypes(col_bigint= 1,col_bool= True,col_bytea= b'test'col_float8= 3.14col_int= 100col_numeric= Decimal('6.626')col_timestamptz=datetime.datetime(2022, 2, 16, 13, 18, 2, 123456, tzinfo=datetime.timezone.utc)col_date= datetime.date(2022, 3, 29)col_varchar= 'updated string'col_jsonb= {'key': 'value'})\n" + + "After rollback: AllTypes(col_bigint= 1,col_bool= True,col_bytea= b'test'col_float8= 3.14col_int= 100col_numeric= Decimal('6.626')col_timestamptz=datetime.datetime(2022, 2, 16, 13, 18, 2, 123456, tzinfo=datetime.timezone.utc)col_date= datetime.date(2022, 3, 29)col_varchar= 'test'col_jsonb= {'key': 'value'})\n"; + assertEquals(expectedOutput, actualOutput); + + assertEquals(4, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest selectRequest = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(1); + assertEquals(sql, selectRequest.getSql()); + assertTrue(selectRequest.getTransaction().hasBegin()); + assertTrue(selectRequest.getTransaction().getBegin().hasReadWrite()); + + ExecuteSqlRequest updateRequest = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(2); + assertEquals(updateSql, updateRequest.getSql()); + assertTrue(updateRequest.getTransaction().hasId()); + + ExecuteSqlRequest refreshRequest = + mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(3); + assertEquals(sql, refreshRequest.getSql()); + assertTrue(refreshRequest.getTransaction().hasBegin()); + assertTrue(refreshRequest.getTransaction().getBegin().hasReadWrite()); + + assertEquals(0, mockSpanner.countRequestsOfType(CommitRequest.class)); + assertEquals(3, mockSpanner.countRequestsOfType(RollbackRequest.class)); + } + + @Test + public void testCreateRelationships() throws IOException, InterruptedException { + String insertUserSql = + "INSERT INTO user_account (name, fullname) " + + "VALUES ('pkrabs', 'Pearl Krabs') " + + "RETURNING user_account.id"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(insertUserSql), + SELECT1_RESULTSET + .toBuilder() + .setStats(ResultSetStats.newBuilder().setRowCountExact(1L).build()) + .build())); + String insertAddressesSql = + "INSERT INTO address (email_address, user_id) " + + "VALUES ('pearl.krabs@gmail.com', 1),('pearl@aol.com', 1) " + + "RETURNING address.id"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(insertAddressesSql), + ResultSet.newBuilder() + .setMetadata(SELECT1_RESULTSET.getMetadata()) + .setStats(ResultSetStats.newBuilder().setRowCountExact(2L).build()) + .addRows(SELECT1_RESULTSET.getRows(0)) + .addRows(SELECT2_RESULTSET.getRows(0)) + .build())); + + String actualOutput = execute("orm_create_relationships.py", host, pgServer.getLocalPort()); + String expectedOutput = + "[]\n" + + "[Address(id=None, email_address='pearl.krabs@gmail.com')]\n" + + "User(id=None, name='pkrabs', fullname='Pearl Krabs')\n" + + "[Address(id=None, email_address='pearl.krabs@gmail.com'), Address(id=None, email_address='pearl@aol.com')]\n" + + "True\n" + + "True\n" + + "True\n" + + "None\n" + + "None\n" + + "None\n"; + assertEquals(expectedOutput, actualOutput); + + assertEquals(3, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest insertUserRequest = + mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(1); + assertEquals(insertUserSql, insertUserRequest.getSql()); + assertTrue(insertUserRequest.getTransaction().hasBegin()); + assertTrue(insertUserRequest.getTransaction().getBegin().hasReadWrite()); + ExecuteSqlRequest insertAddressesRequest = + mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(2); + assertEquals(insertAddressesSql, insertAddressesRequest.getSql()); + assertTrue(insertAddressesRequest.getTransaction().hasId()); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + } + + @Test + public void testLoadRelationships() throws IOException, InterruptedException { + String selectUsersSql = + "SELECT user_account.id, user_account.name, user_account.fullname \n" + + "FROM user_account ORDER BY user_account.id"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(selectUsersSql), + ResultSet.newBuilder() + .setMetadata( + createMetadata( + ImmutableList.of(TypeCode.INT64, TypeCode.STRING, TypeCode.STRING))) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("1").build()) + .addValues(Value.newBuilder().setStringValue("spongebob")) + .addValues(Value.newBuilder().setStringValue("spongebob squarepants")) + .build()) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("2").build()) + .addValues(Value.newBuilder().setStringValue("sandy")) + .addValues(Value.newBuilder().setStringValue("sandy oyster")) + .build()) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("3").build()) + .addValues(Value.newBuilder().setStringValue("patrick")) + .addValues(Value.newBuilder().setStringValue("patrick sea")) + .build()) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("4").build()) + .addValues(Value.newBuilder().setStringValue("squidward")) + .addValues(Value.newBuilder().setStringValue("squidward manyarms")) + .build()) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("5").build()) + .addValues(Value.newBuilder().setStringValue("ehkrabs")) + .addValues(Value.newBuilder().setStringValue("ehkrabs hibernate")) + .build()) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("6").build()) + .addValues(Value.newBuilder().setStringValue("pkrabs")) + .addValues(Value.newBuilder().setStringValue("pkrabs primary")) + .build()) + .build())); + String selectAddressesSql = + "SELECT address.user_id AS address_user_id, address.id AS address_id, address.email_address AS address_email_address \n" + + "FROM address \n" + + "WHERE address.user_id IN (1, 2, 3, 4, 5, 6)"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(selectAddressesSql), + ResultSet.newBuilder() + .setMetadata( + createMetadata( + ImmutableList.of(TypeCode.INT64, TypeCode.INT64, TypeCode.STRING))) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("1").build()) + .addValues(Value.newBuilder().setStringValue("1")) + .addValues(Value.newBuilder().setStringValue("spongebob@sqlalchemy.org")) + .build()) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("2").build()) + .addValues(Value.newBuilder().setStringValue("2")) + .addValues(Value.newBuilder().setStringValue("sandy@sqlalchemy.org")) + .build()) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("2").build()) + .addValues(Value.newBuilder().setStringValue("3")) + .addValues(Value.newBuilder().setStringValue("sandy@squirrelpower.org")) + .build()) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("6").build()) + .addValues(Value.newBuilder().setStringValue("4")) + .addValues(Value.newBuilder().setStringValue("pearl.krabs@gmail.com")) + .build()) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("6").build()) + .addValues(Value.newBuilder().setStringValue("5")) + .addValues(Value.newBuilder().setStringValue("pearl@aol.com")) + .build()) + .build())); + + String actualOutput = execute("orm_load_relationships.py", host, pgServer.getLocalPort()); + String expectedOutput = + "spongebob (spongebob@sqlalchemy.org) \n" + + "sandy (sandy@sqlalchemy.org, sandy@squirrelpower.org) \n" + + "patrick () \n" + + "squidward () \n" + + "ehkrabs () \n" + + "pkrabs (pearl.krabs@gmail.com, pearl@aol.com) \n"; + assertEquals(expectedOutput, actualOutput); + + assertEquals(3, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest selectUsersRequest = + mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(1); + assertEquals(selectUsersSql, selectUsersRequest.getSql()); + assertTrue(selectUsersRequest.getTransaction().hasBegin()); + assertTrue(selectUsersRequest.getTransaction().getBegin().hasReadWrite()); + ExecuteSqlRequest selectAddressesRequest = + mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(2); + assertEquals(selectAddressesSql, selectAddressesRequest.getSql()); + assertTrue(selectAddressesRequest.getTransaction().hasId()); + assertEquals(0, mockSpanner.countRequestsOfType(CommitRequest.class)); + assertEquals(2, mockSpanner.countRequestsOfType(RollbackRequest.class)); + } + + @Test + public void testErrorInReadWriteTransaction() throws IOException, InterruptedException { + String sql = + "INSERT INTO all_types (col_bigint, col_bool, col_bytea, col_float8, col_int, col_numeric, col_timestamptz, col_date, col_varchar, col_jsonb) " + + "VALUES (1, true, '\\x74657374206279746573'::bytea, 3.14, 100, 6.626, '2011-11-04T00:05:23.123456+00:00'::timestamptz, '2011-11-04'::date, 'test string', '{\"key1\": \"value1\", \"key2\": \"value2\"}')"; + mockSpanner.putStatementResult( + StatementResult.exception( + Statement.of(sql), + Status.ALREADY_EXISTS + .withDescription("Row with id 1 already exists") + .asRuntimeException())); + + String actualOutput = + execute("orm_error_in_read_write_transaction.py", host, pgServer.getLocalPort()); + String expectedOutput = + "Insert failed: (psycopg2.errors.RaiseException) com.google.api.gax.rpc.AlreadyExistsException: io.grpc.StatusRuntimeException: ALREADY_EXISTS: Row with id 1 already exists"; + assertTrue(actualOutput, actualOutput.startsWith(expectedOutput)); + assertFalse(actualOutput, actualOutput.contains("Getting the row after an error succeeded")); + + assertEquals(2, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(1); + assertEquals(sql, request.getSql()); + assertTrue(request.getTransaction().hasBegin()); + assertTrue(request.getTransaction().getBegin().hasReadWrite()); + assertEquals(0, mockSpanner.countRequestsOfType(CommitRequest.class)); + } + + @Test + public void testErrorInReadWriteTransactionContinue() throws IOException, InterruptedException { + String sql = + "INSERT INTO all_types (col_bigint, col_bool, col_bytea, col_float8, col_int, col_numeric, col_timestamptz, col_date, col_varchar, col_jsonb) " + + "VALUES (1, true, '\\x74657374206279746573'::bytea, 3.14, 100, 6.626, '2011-11-04T00:05:23.123456+00:00'::timestamptz, '2011-11-04'::date, 'test string', '{\"key1\": \"value1\", \"key2\": \"value2\"}')"; + mockSpanner.putStatementResult( + StatementResult.exception( + Statement.of(sql), + Status.ALREADY_EXISTS + .withDescription("Row with id 1 already exists") + .asRuntimeException())); + + String actualOutput = + execute("orm_error_in_read_write_transaction_continue.py", host, pgServer.getLocalPort()); + String expectedOutput = + "Getting the row failed: This Session's transaction has been rolled back due to a previous exception during flush. " + + "To begin a new transaction with this Session, first issue Session.rollback(). " + + "Original exception was: (psycopg2.errors.RaiseException) com.google.api.gax.rpc.AlreadyExistsException: io.grpc.StatusRuntimeException: ALREADY_EXISTS: Row with id 1 already exists"; + assertTrue(actualOutput, actualOutput.startsWith(expectedOutput)); + + assertEquals(2, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(1); + assertEquals(sql, request.getSql()); + assertTrue(request.getTransaction().hasBegin()); + assertTrue(request.getTransaction().getBegin().hasReadWrite()); + assertEquals(0, mockSpanner.countRequestsOfType(CommitRequest.class)); + } +} diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/python/sqlalchemy/SqlAlchemySampleTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/python/sqlalchemy/SqlAlchemySampleTest.java new file mode 100644 index 000000000..cd8c409b4 --- /dev/null +++ b/src/test/java/com/google/cloud/spanner/pgadapter/python/sqlalchemy/SqlAlchemySampleTest.java @@ -0,0 +1,1368 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.spanner.pgadapter.python.sqlalchemy; + +import static com.google.cloud.spanner.pgadapter.python.sqlalchemy.SqlAlchemyBasicsTest.execute; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.ByteArray; +import com.google.cloud.Date; +import com.google.cloud.Timestamp; +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.pgadapter.AbstractMockServerTest; +import com.google.cloud.spanner.pgadapter.python.PythonTest; +import com.google.protobuf.Duration; +import com.google.protobuf.ListValue; +import com.google.protobuf.Value; +import com.google.spanner.admin.database.v1.UpdateDatabaseDdlRequest; +import com.google.spanner.v1.BeginTransactionRequest; +import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.ResultSet; +import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.ResultSetStats; +import com.google.spanner.v1.StructType; +import com.google.spanner.v1.StructType.Field; +import com.google.spanner.v1.Type; +import com.google.spanner.v1.TypeAnnotationCode; +import com.google.spanner.v1.TypeCode; +import java.io.IOException; +import java.util.List; +import java.util.stream.Collectors; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +@Category(PythonTest.class) +public class SqlAlchemySampleTest extends AbstractMockServerTest { + private static final String SAMPLE_DIR = "./samples/python/sqlalchemy-sample"; + + @BeforeClass + public static void setupBaseResults() { + SqlAlchemyBasicsTest.setupBaseResults(); + } + + @Test + public void testDeleteAlbum() throws IOException, InterruptedException { + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of( + "SELECT albums.id AS albums_id, albums.version_id AS albums_version_id, albums.created_at AS albums_created_at, albums.updated_at AS albums_updated_at, albums.title AS albums_title, albums.marketing_budget AS albums_marketing_budget, albums.release_date AS albums_release_date, albums.cover_picture AS albums_cover_picture, albums.singer_id AS albums_singer_id \n" + + "FROM albums \n" + + "WHERE albums.id = '123-456-789'"), + ResultSet.newBuilder() + .setMetadata(createAlbumsMetadata("albums_")) + .addRows( + createAlbumRow( + "123-456-789", + "My album", + "5000", + Date.parseDate("2000-01-01"), + ByteArray.copyFrom("My album cover picture"), + "321", + Timestamp.parseTimestamp("2022-12-02T17:30:00Z"), + Timestamp.parseTimestamp("2022-12-02T17:30:00Z"))) + .build())); + mockSpanner.putStatementResult( + StatementResult.update( + Statement.of( + "DELETE FROM albums WHERE albums.id = '123-456-789' AND albums.version_id = 1"), + 1L)); + + String output = + execute(SAMPLE_DIR, "test_delete_album.py", "localhost", pgServer.getLocalPort()); + assertEquals("\n" + "Deleted album with id 123-456-789\n", output); + } + + @Test + public void testAlbumsWithTitleFirstCharEqualToSingerName() + throws IOException, InterruptedException { + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of( + "SELECT albums.id AS albums_id, albums.version_id AS albums_version_id, albums.created_at AS albums_created_at, albums.updated_at AS albums_updated_at, albums.title AS albums_title, albums.marketing_budget AS albums_marketing_budget, albums.release_date AS albums_release_date, albums.cover_picture AS albums_cover_picture, albums.singer_id AS albums_singer_id \n" + + "FROM albums JOIN singers ON singers.id = albums.singer_id \n" + + "WHERE lower(SUBSTRING(albums.title FROM 1 FOR 1)) = lower(SUBSTRING(singers.first_name FROM 1 FOR 1)) OR " + + "lower(SUBSTRING(albums.title FROM 1 FOR 1)) = lower(SUBSTRING(singers.last_name FROM 1 FOR 1))"), + ResultSet.newBuilder().setMetadata(createAlbumsMetadata("albums_")).build())); + + String output = + execute( + SAMPLE_DIR, + "test_print_albums_first_character_of_title_equal_to_first_or_last_name.py", + "localhost", + pgServer.getLocalPort()); + assertEquals( + "\n" + + "Searching for albums that have a title that starts with the same character as the first or last name of the singer\n", + output); + } + + @Test + public void testSingersWithLimitAndOffset() throws IOException, InterruptedException { + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of( + "SELECT singers.id AS singers_id, singers.version_id AS singers_version_id, singers.created_at AS singers_created_at, singers.updated_at AS singers_updated_at, singers.first_name AS singers_first_name, singers.last_name AS singers_last_name, singers.full_name AS singers_full_name, singers.active AS singers_active \n" + + "FROM singers ORDER BY singers.last_name \n" + + " LIMIT 5 OFFSET 3"), + ResultSet.newBuilder() + .setMetadata(createSingersMetadata("singers_")) + .addRows( + createSingerRow( + "123", + "Pete", + "Allison", + true, + Timestamp.parseTimestamp("2022-12-02T17:30:00Z"), + Timestamp.parseTimestamp("2022-12-02T17:30:00Z"))) + .addRows( + createSingerRow( + "321", + "Alice", + "Henderson", + true, + Timestamp.parseTimestamp("2022-12-02T17:30:00Z"), + Timestamp.parseTimestamp("2022-12-02T17:30:00Z"))) + .build())); + String output = + execute( + SAMPLE_DIR, + "test_print_singers_with_limit_and_offset.py", + "localhost", + pgServer.getLocalPort()); + assertEquals( + "\n" + + "Printing at most 5 singers ordered by last name\n" + + " 1. Pete Allison\n" + + " 2. Alice Henderson\n" + + "Found 2 singers\n", + output); + } + + @Test + public void testPrintAlbumsReleasedBefore1980() throws IOException, InterruptedException { + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of( + "SELECT albums.id AS albums_id, albums.version_id AS albums_version_id, albums.created_at AS albums_created_at, albums.updated_at AS albums_updated_at, albums.title AS albums_title, albums.marketing_budget AS albums_marketing_budget, albums.release_date AS albums_release_date, albums.cover_picture AS albums_cover_picture, albums.singer_id AS albums_singer_id \n" + + "FROM albums \n" + + "WHERE albums.release_date < '1980-01-01'::date"), + ResultSet.newBuilder() + .setMetadata(createAlbumsMetadata("albums_")) + .addRows( + createAlbumRow( + "a1", + "Album 1", + "123.456", + Date.parseDate("1979-10-16"), + ByteArray.copyFrom("some cover picture"), + "123", + Timestamp.parseTimestamp("2022-12-02T17:30:00Z"), + Timestamp.parseTimestamp("2022-12-02T17:30:00Z"))) + .build())); + String output = + execute( + SAMPLE_DIR, + "test_print_albums_released_before_1980.py", + "localhost", + pgServer.getLocalPort()); + assertEquals( + "\n" + + "Searching for albums released before 1980\n" + + " Album Album 1 was released at 1979-10-16\n", + output); + } + + @Test + public void testPrintConcerts() throws IOException, InterruptedException { + List concertsFields = createConcertsMetadata("concerts_").getRowType().getFieldsList(); + List venuesFields = createVenuesMetadata("venues_1_").getRowType().getFieldsList(); + List singersFields = createSingersMetadata("singers_1_").getRowType().getFieldsList(); + List concertValues = + createConcertRow( + "c1", + "Avenue Park Open", + "v1", + "123", + Timestamp.parseTimestamp("2023-02-01T20:00:00-05:00"), + Timestamp.parseTimestamp("2023-02-02T02:00:00-05:00"), + Timestamp.parseTimestamp("2022-12-02T17:30:00Z"), + Timestamp.parseTimestamp("2022-12-02T17:30:00Z")) + .getValuesList(); + List venueValues = + createVenueRow( + "v1", + "Avenue Park", + "{\n" + + " \"Capacity\": 5000,\n" + + " \"Location\": \"New York\",\n" + + " \"Country\": \"US\"\n" + + "}", + Timestamp.parseTimestamp("2022-12-02T17:30:00Z"), + Timestamp.parseTimestamp("2022-12-02T17:30:00Z")) + .getValuesList(); + List singerValues = + createSingerRow( + "123", + "Pete", + "Allison", + true, + Timestamp.parseTimestamp("2022-12-02T17:30:00Z"), + Timestamp.parseTimestamp("2022-12-02T17:30:00Z")) + .getValuesList(); + + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of( + "SELECT concerts.id AS concerts_id, concerts.version_id AS concerts_version_id, concerts.created_at AS concerts_created_at, concerts.updated_at AS concerts_updated_at, concerts.name AS concerts_name, concerts.venue_id AS concerts_venue_id, concerts.singer_id AS concerts_singer_id, concerts.start_time AS concerts_start_time, concerts.end_time AS concerts_end_time, " + + "venues_1.id AS venues_1_id, venues_1.version_id AS venues_1_version_id, venues_1.created_at AS venues_1_created_at, venues_1.updated_at AS venues_1_updated_at, venues_1.name AS venues_1_name, venues_1.description AS venues_1_description, " + + "singers_1.id AS singers_1_id, singers_1.version_id AS singers_1_version_id, singers_1.created_at AS singers_1_created_at, singers_1.updated_at AS singers_1_updated_at, singers_1.first_name AS singers_1_first_name, singers_1.last_name AS singers_1_last_name, singers_1.full_name AS singers_1_full_name, singers_1.active AS singers_1_active \n" + + "FROM concerts LEFT OUTER JOIN venues AS venues_1 ON venues_1.id = concerts.venue_id LEFT OUTER JOIN singers AS singers_1 ON singers_1.id = concerts.singer_id ORDER BY concerts.start_time"), + ResultSet.newBuilder() + .setMetadata( + ResultSetMetadata.newBuilder() + .setRowType( + StructType.newBuilder() + .addAllFields(concertsFields) + .addAllFields(venuesFields) + .addAllFields(singersFields)) + .build()) + .addRows( + ListValue.newBuilder() + .addAllValues(concertValues) + .addAllValues(venueValues) + .addAllValues(singerValues) + .build()) + .build())); + String output = + execute(SAMPLE_DIR, "test_print_concerts.py", "localhost", pgServer.getLocalPort()); + assertEquals( + "\nConcert 'Avenue Park Open' starting at 2023-02-02 02:00:00+01:00 with Pete Allison will be held at Avenue Park\n", + output); + } + + @Test + public void testCreateVenueAndConcertInTransaction() throws IOException, InterruptedException { + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of( + "SELECT singers.id AS singers_id, singers.version_id AS singers_version_id, singers.created_at AS singers_created_at, singers.updated_at AS singers_updated_at, singers.first_name AS singers_first_name, singers.last_name AS singers_last_name, singers.full_name AS singers_full_name, singers.active AS singers_active \n" + + "FROM singers \n" + + " LIMIT 1"), + ResultSet.newBuilder() + .setMetadata(createSingersMetadata("singers_")) + .addRows( + createSingerRow( + "123", + "Pete", + "Allison", + true, + Timestamp.parseTimestamp("2001-02-28T00:00:00Z"), + Timestamp.parseTimestamp("2001-02-28T00:00:00Z"))) + .build())); + mockSpanner.putPartialStatementResult( + StatementResult.update( + Statement.of( + "INSERT INTO venues (id, version_id, created_at, updated_at, name, description) VALUES "), + 1L)); + mockSpanner.putPartialStatementResult( + StatementResult.update( + Statement.of( + "INSERT INTO concerts (id, version_id, created_at, updated_at, name, venue_id, singer_id, start_time, end_time) VALUES "), + 1L)); + + String output = + execute( + SAMPLE_DIR, + "test_create_venue_and_concert_in_transaction.py", + "localhost", + pgServer.getLocalPort()); + assertEquals("\nCreated Venue and Concert\n", output); + } + + @Test + public void testCreateRandomSingersAndAlbums() throws IOException, InterruptedException { + mockSpanner.putPartialStatementResult( + StatementResult.query( + Statement.of( + "INSERT INTO singers (id, version_id, created_at, updated_at, first_name, last_name, active) " + + "VALUES "), + ResultSet.newBuilder() + .setMetadata( + ResultSetMetadata.newBuilder() + .setRowType( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("full_name") + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .build()) + .build())) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("(unknown)").build()) + .build()) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("(unknown)").build()) + .build()) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("(unknown)").build()) + .build()) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("(unknown)").build()) + .build()) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("(unknown)").build()) + .build()) + .setStats(ResultSetStats.newBuilder().setRowCountExact(5L).build()) + .build())); + mockSpanner.putPartialStatementResult( + StatementResult.update( + Statement.of( + "INSERT INTO albums (id, version_id, created_at, updated_at, title, marketing_budget, release_date, cover_picture, singer_id) VALUES "), + 37L)); + + String output = + execute( + SAMPLE_DIR, + "test_create_random_singers_and_albums.py", + "localhost", + pgServer.getLocalPort()); + assertEquals("Created 5 singers\n", output); + } + + @Test + public void testPrintSingersAndAlbums() throws IOException, InterruptedException { + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of( + "SELECT singers.id AS singers_id, singers.version_id AS singers_version_id, singers.created_at AS singers_created_at, singers.updated_at AS singers_updated_at, singers.first_name AS singers_first_name, singers.last_name AS singers_last_name, singers.full_name AS singers_full_name, singers.active AS singers_active \n" + + "FROM singers ORDER BY singers.last_name"), + ResultSet.newBuilder() + .setMetadata(createSingersMetadata("singers_")) + .addRows( + createSingerRow( + "b2", + "Pete", + "Allison", + true, + Timestamp.parseTimestamp("2022-12-01T15:12:00Z"), + Timestamp.parseTimestamp("2022-12-01T15:12:00Z"))) + .addRows( + createSingerRow( + "a1", + "Alice", + "Henderson", + true, + Timestamp.parseTimestamp("2022-12-01T15:12:00Z"), + Timestamp.parseTimestamp("2022-12-01T15:12:00Z"))) + .addRows( + createSingerRow( + "c3", + "Renate", + "Unna", + true, + Timestamp.parseTimestamp("2022-12-01T15:12:00Z"), + Timestamp.parseTimestamp("2022-12-01T15:12:00Z"))) + .build())); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of( + "SELECT albums.id AS albums_id, albums.version_id AS albums_version_id, albums.created_at AS albums_created_at, albums.updated_at AS albums_updated_at, albums.title AS albums_title, albums.marketing_budget AS albums_marketing_budget, albums.release_date AS albums_release_date, albums.cover_picture AS albums_cover_picture, albums.singer_id AS albums_singer_id \n" + + "FROM albums \n" + + "WHERE 'b2' = albums.singer_id"), + ResultSet.newBuilder() + .setMetadata(createAlbumsMetadata("albums_")) + .addRows( + createAlbumRow( + "a1", + "Title 1", + "100.90", + Date.parseDate("2000-01-01"), + ByteArray.copyFrom("cover pic 1"), + "b2", + Timestamp.parseTimestamp("2022-12-01T15:12:00Z"), + Timestamp.parseTimestamp("2022-12-01T15:12:00Z"))) + .build())); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of( + "SELECT albums.id AS albums_id, albums.version_id AS albums_version_id, albums.created_at AS albums_created_at, albums.updated_at AS albums_updated_at, albums.title AS albums_title, albums.marketing_budget AS albums_marketing_budget, albums.release_date AS albums_release_date, albums.cover_picture AS albums_cover_picture, albums.singer_id AS albums_singer_id \n" + + "FROM albums \n" + + "WHERE 'a1' = albums.singer_id"), + ResultSet.newBuilder() + .setMetadata(createAlbumsMetadata("albums_")) + .addRows( + createAlbumRow( + "a2", + "Title 2", + "100.90", + Date.parseDate("2000-01-01"), + ByteArray.copyFrom("cover pic 1"), + "a1", + Timestamp.parseTimestamp("2022-12-01T15:12:00Z"), + Timestamp.parseTimestamp("2022-12-01T15:12:00Z"))) + .addRows( + createAlbumRow( + "a3", + "Title 3", + "100.90", + Date.parseDate("2000-01-01"), + ByteArray.copyFrom("cover pic 2"), + "a1", + Timestamp.parseTimestamp("2022-12-01T15:12:00Z"), + Timestamp.parseTimestamp("2022-12-01T15:12:00Z"))) + .build())); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of( + "SELECT albums.id AS albums_id, albums.version_id AS albums_version_id, albums.created_at AS albums_created_at, albums.updated_at AS albums_updated_at, albums.title AS albums_title, albums.marketing_budget AS albums_marketing_budget, albums.release_date AS albums_release_date, albums.cover_picture AS albums_cover_picture, albums.singer_id AS albums_singer_id \n" + + "FROM albums \n" + + "WHERE 'c3' = albums.singer_id"), + ResultSet.newBuilder().setMetadata(createAlbumsMetadata("albums_")).build())); + + String output = + execute( + SAMPLE_DIR, "test_print_singers_and_albums.py", "localhost", pgServer.getLocalPort()); + assertEquals( + "\n" + + "Pete Allison has 1 albums:\n" + + " 'Title 1'\n" + + "Alice Henderson has 2 albums:\n" + + " 'Title 2'\n" + + " 'Title 3'\n" + + "Renate Unna has 0 albums:\n", + output); + List beginRequests = + mockSpanner.getRequestsOfType(BeginTransactionRequest.class); + assertEquals(1, beginRequests.size()); + assertTrue(beginRequests.get(0).getOptions().hasReadOnly()); + List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); + assertEquals(5, requests.size()); + } + + @Test + public void testGetSinger() throws IOException, InterruptedException { + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of( + "SELECT singers.id AS singers_id, singers.version_id AS singers_version_id, singers.created_at AS singers_created_at, singers.updated_at AS singers_updated_at, singers.first_name AS singers_first_name, singers.last_name AS singers_last_name, singers.full_name AS singers_full_name, singers.active AS singers_active \n" + + "FROM singers \n" + + "WHERE singers.id = '123-456-789'"), + ResultSet.newBuilder() + .setMetadata(createSingersMetadata("singers_")) + .addRows( + createSingerRow( + "123-456-789", + "Myfirstname", + "Mylastname", + true, + Timestamp.parseTimestamp("2022-02-21T10:19:18Z"), + Timestamp.parseTimestamp("2022-02-21T10:19:18Z"))) + .build())); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of( + "SELECT albums.id AS albums_id, albums.version_id AS albums_version_id, albums.created_at AS albums_created_at, albums.updated_at AS albums_updated_at, albums.title AS albums_title, albums.marketing_budget AS albums_marketing_budget, albums.release_date AS albums_release_date, albums.cover_picture AS albums_cover_picture, albums.singer_id AS albums_singer_id \n" + + "FROM albums \n" + + "WHERE '123-456-789' = albums.singer_id"), + ResultSet.newBuilder() + .setMetadata(createAlbumsMetadata("albums_")) + .addRows( + createAlbumRow( + "987-654-321", + "My title", + "9423.13", + Date.parseDate("2002-10-17"), + ByteArray.copyFrom("cover picture"), + "123-456-789", + Timestamp.parseTimestamp("2022-02-21T10:19:18Z"), + Timestamp.parseTimestamp("2022-02-21T10:19:18Z"))) + .build())); + + String output = execute(SAMPLE_DIR, "test_get_singer.py", "localhost", pgServer.getLocalPort()); + assertEquals( + "singers(id='123-456-789',first_name='Myfirstname',last_name='Mylastname',active=True,created_at='2022-02-21T10:19:18+00:00',updated_at='2022-02-21T10:19:18+00:00')\n" + + "Albums:\n" + + "[albums(id='987-654-321',title='My title',marketing_budget=Decimal('9423.13'),release_date=datetime.date(2002, 10, 17),cover_picture=b'cover picture',singer='123-456-789',created_at='2022-02-21T10:19:18+00:00',updated_at='2022-02-21T10:19:18+00:00')]\n", + output); + } + + @Test + public void testAddSinger() throws IOException, InterruptedException { + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of( + "INSERT INTO singers (id, version_id, created_at, updated_at, first_name, last_name, active) " + + "VALUES ('123-456-789', 1, ('2011-11-04T00:05:23.123456+00:00'::timestamptz), NULL, 'Myfirstname', 'Mylastname', true) " + + "RETURNING singers.full_name"), + ResultSet.newBuilder() + .setMetadata( + ResultSetMetadata.newBuilder() + .setRowType( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("full_name") + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .build()) + .build())) + .addRows( + ListValue.newBuilder() + .addValues( + Value.newBuilder().setStringValue("Myfirstname Mylastname").build()) + .build()) + .build())); + + String output = execute(SAMPLE_DIR, "test_add_singer.py", "localhost", pgServer.getLocalPort()); + assertEquals("Added singer 123-456-789 with full name Myfirstname Mylastname\n", output); + } + + @Test + public void testUpdateSinger() throws IOException, InterruptedException { + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of( + "SELECT singers.id AS singers_id, singers.version_id AS singers_version_id, singers.created_at AS singers_created_at, singers.updated_at AS singers_updated_at, singers.first_name AS singers_first_name, singers.last_name AS singers_last_name, singers.full_name AS singers_full_name, singers.active AS singers_active \n" + + "FROM singers \n" + + "WHERE singers.id = '123-456-789'"), + ResultSet.newBuilder() + .setMetadata(createSingersMetadata("singers_")) + .addRows( + createSingerRow( + "123-456-789", + "Myfirstname", + "Mylastname", + true, + Timestamp.parseTimestamp("2022-12-01T10:00:00Z"), + Timestamp.parseTimestamp("2022-12-01T10:00:00Z"))) + .build())); + // We have to use a partial SQL string here, as we don't know exactly what updated_at timestamp + // will be used by SQLAlchemy. + mockSpanner.putPartialStatementResult( + StatementResult.query( + Statement.of("UPDATE singers SET version_id=2, updated_at='"), + ResultSet.newBuilder() + .setMetadata( + ResultSetMetadata.newBuilder() + .setRowType( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("full_name") + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .build()) + .build())) + .addRows( + ListValue.newBuilder() + .addValues( + Value.newBuilder().setStringValue("Newfirstname Newlastname").build()) + .build()) + .build())); + + String output = + execute(SAMPLE_DIR, "test_update_singer.py", "localhost", pgServer.getLocalPort()); + assertEquals("Updated singer 123-456-789 with full name Newfirstname Newlastname\n", output); + } + + @Test + public void testGetAlbum() throws IOException, InterruptedException { + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of( + "SELECT albums.id AS albums_id, albums.version_id AS albums_version_id, albums.created_at AS albums_created_at, albums.updated_at AS albums_updated_at, albums.title AS albums_title, albums.marketing_budget AS albums_marketing_budget, albums.release_date AS albums_release_date, albums.cover_picture AS albums_cover_picture, albums.singer_id AS albums_singer_id \n" + + "FROM albums \n" + + "WHERE albums.id = '987-654-321'"), + ResultSet.newBuilder() + .setMetadata(createAlbumsMetadata("albums_")) + .addRows( + createAlbumRow( + "123-456-789", + "My title", + "9423.13", + Date.parseDate("2002-10-17"), + ByteArray.copyFrom("cover picture"), + "123-456-789", + Timestamp.parseTimestamp("2022-02-21T10:19:18Z"), + Timestamp.parseTimestamp("2022-02-21T10:19:18Z"))) + .build())); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of( + "SELECT tracks.version_id AS tracks_version_id, tracks.created_at AS tracks_created_at, tracks.updated_at AS tracks_updated_at, tracks.id AS tracks_id, tracks.track_number AS tracks_track_number, tracks.title AS tracks_title, tracks.sample_rate AS tracks_sample_rate \n" + + "FROM tracks \n" + + "WHERE '123-456-789' = tracks.id"), + ResultSet.newBuilder() + .setMetadata(createTracksMetadata("tracks_")) + .addRows( + createTrackRow( + "123-456-789", + 1L, + "Track 1", + 6.34324, + Timestamp.parseTimestamp("2018-02-28T17:00:00Z"), + Timestamp.parseTimestamp("2018-02-01T09:00:00Z"))) + .addRows( + createTrackRow( + "123-456-789", + 2L, + "Track 2", + 6.34324, + Timestamp.parseTimestamp("2018-02-28T17:00:00Z"), + Timestamp.parseTimestamp("2018-02-01T09:00:00Z"))) + .build())); + + String output = execute(SAMPLE_DIR, "test_get_album.py", "localhost", pgServer.getLocalPort()); + assertEquals( + "albums(id='123-456-789',title='My title',marketing_budget=Decimal('9423.13'),release_date=datetime.date(2002, 10, 17),cover_picture=b'cover picture',singer='123-456-789',created_at='2022-02-21T10:19:18+00:00',updated_at='2022-02-21T10:19:18+00:00')\n" + + "Tracks:\n" + + "[tracks(id='123-456-789',track_number=1,title='Track 1',sample_rate=6.34324,created_at='2018-02-28T17:00:00+00:00',updated_at='2018-02-01T09:00:00+00:00'), " + + "tracks(id='123-456-789',track_number=2,title='Track 2',sample_rate=6.34324,created_at='2018-02-28T17:00:00+00:00',updated_at='2018-02-01T09:00:00+00:00')]\n", + output); + } + + @Test + public void testGetAlbumWithStaleEngine() throws IOException, InterruptedException { + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of( + "SELECT albums.id AS albums_id, albums.version_id AS albums_version_id, albums.created_at AS albums_created_at, albums.updated_at AS albums_updated_at, albums.title AS albums_title, albums.marketing_budget AS albums_marketing_budget, albums.release_date AS albums_release_date, albums.cover_picture AS albums_cover_picture, albums.singer_id AS albums_singer_id \n" + + "FROM albums \n" + + "WHERE albums.id = '987-654-321'"), + ResultSet.newBuilder() + .setMetadata(createAlbumsMetadata("albums_")) + .addRows( + createAlbumRow( + "123-456-789", + "My title", + "9423.13", + Date.parseDate("2002-10-17"), + ByteArray.copyFrom("cover picture"), + "123-456-789", + Timestamp.parseTimestamp("2022-02-21T10:19:18Z"), + Timestamp.parseTimestamp("2022-02-21T10:19:18Z"))) + .build())); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of( + "SELECT tracks.version_id AS tracks_version_id, tracks.created_at AS tracks_created_at, tracks.updated_at AS tracks_updated_at, tracks.id AS tracks_id, tracks.track_number AS tracks_track_number, tracks.title AS tracks_title, tracks.sample_rate AS tracks_sample_rate \n" + + "FROM tracks \n" + + "WHERE '123-456-789' = tracks.id"), + ResultSet.newBuilder() + .setMetadata(createTracksMetadata("tracks_")) + .addRows( + createTrackRow( + "123-456-789", + 1L, + "Track 1", + 6.34324, + Timestamp.parseTimestamp("2018-02-28T17:00:00Z"), + Timestamp.parseTimestamp("2018-02-01T09:00:00Z"))) + .addRows( + createTrackRow( + "123-456-789", + 2L, + "Track 2", + 6.34324, + Timestamp.parseTimestamp("2018-02-28T17:00:00Z"), + Timestamp.parseTimestamp("2018-02-01T09:00:00Z"))) + .build())); + + String output = + execute( + SAMPLE_DIR, + "test_get_album_with_stale_engine.py", + "localhost", + pgServer.getLocalPort()); + assertEquals("", output); + List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); + assertEquals(2, requests.size()); + ExecuteSqlRequest request = requests.get(1); + assertTrue(request.getTransaction().hasSingleUse()); + assertTrue(request.getTransaction().getSingleUse().hasReadOnly()); + assertTrue(request.getTransaction().getSingleUse().getReadOnly().hasMaxStaleness()); + assertEquals( + Duration.newBuilder().setSeconds(10L).setNanos(0).build(), + request.getTransaction().getSingleUse().getReadOnly().getMaxStaleness()); + } + + @Test + public void testGetTrack() throws IOException, InterruptedException { + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of( + "SELECT tracks.version_id AS tracks_version_id, tracks.created_at AS tracks_created_at, tracks.updated_at AS tracks_updated_at, tracks.id AS tracks_id, tracks.track_number AS tracks_track_number, tracks.title AS tracks_title, tracks.sample_rate AS tracks_sample_rate \n" + + "FROM tracks \n" + + "WHERE tracks.id = '987-654-321' AND tracks.track_number = 1"), + ResultSet.newBuilder() + .setMetadata(createTracksMetadata("tracks_")) + .addRows( + createTrackRow( + "987-654-321", + 1L, + "Track 1", + 6.34324, + Timestamp.parseTimestamp("2018-02-28T17:00:00Z"), + Timestamp.parseTimestamp("2018-02-01T09:00:00Z"))) + .build())); + + String output = execute(SAMPLE_DIR, "test_get_track.py", "localhost", pgServer.getLocalPort()); + assertEquals( + "tracks(id='987-654-321',track_number=1,title='Track 1',sample_rate=6.34324,created_at='2018-02-28T17:00:00+00:00',updated_at='2018-02-01T09:00:00+00:00')\n", + output); + } + + @Test + public void testCreateDataModel() throws Exception { + String checkTableExistsSql = + "with pg_class as (\n" + + " select\n" + + " -1 as oid,\n" + + " table_name as relname,\n" + + " case table_schema when 'pg_catalog' then 11 when 'public' then 2200 else 0 end as relnamespace,\n" + + " 0 as reltype,\n" + + " 0 as reloftype,\n" + + " 0 as relowner,\n" + + " 1 as relam,\n" + + " 0 as relfilenode,\n" + + " 0 as reltablespace,\n" + + " 0 as relpages,\n" + + " 0.0::float8 as reltuples,\n" + + " 0 as relallvisible,\n" + + " 0 as reltoastrelid,\n" + + " false as relhasindex,\n" + + " false as relisshared,\n" + + " 'p' as relpersistence,\n" + + " 'r' as relkind,\n" + + " count(*) as relnatts,\n" + + " 0 as relchecks,\n" + + " false as relhasrules,\n" + + " false as relhastriggers,\n" + + " false as relhassubclass,\n" + + " false as relrowsecurity,\n" + + " false as relforcerowsecurity,\n" + + " true as relispopulated,\n" + + " 'n' as relreplident,\n" + + " false as relispartition,\n" + + " 0 as relrewrite,\n" + + " 0 as relfrozenxid,\n" + + " 0 as relminmxid,\n" + + " '{}'::bigint[] as relacl,\n" + + " '{}'::text[] as reloptions,\n" + + " 0 as relpartbound\n" + + "from information_schema.tables t\n" + + "inner join information_schema.columns using (table_catalog, table_schema, table_name)\n" + + "group by t.table_name, t.table_schema\n" + + "union all\n" + + "select\n" + + " -1 as oid,\n" + + " i.index_name as relname,\n" + + " case table_schema when 'pg_catalog' then 11 when 'public' then 2200 else 0 end as relnamespace,\n" + + " 0 as reltype,\n" + + " 0 as reloftype,\n" + + " 0 as relowner,\n" + + " 1 as relam,\n" + + " 0 as relfilenode,\n" + + " 0 as reltablespace,\n" + + " 0 as relpages,\n" + + " 0.0::float8 as reltuples,\n" + + " 0 as relallvisible,\n" + + " 0 as reltoastrelid,\n" + + " false as relhasindex,\n" + + " false as relisshared,\n" + + " 'p' as relpersistence,\n" + + " 'r' as relkind,\n" + + " count(*) as relnatts,\n" + + " 0 as relchecks,\n" + + " false as relhasrules,\n" + + " false as relhastriggers,\n" + + " false as relhassubclass,\n" + + " false as relrowsecurity,\n" + + " false as relforcerowsecurity,\n" + + " true as relispopulated,\n" + + " 'n' as relreplident,\n" + + " false as relispartition,\n" + + " 0 as relrewrite,\n" + + " 0 as relfrozenxid,\n" + + " 0 as relminmxid,\n" + + " '{}'::bigint[] as relacl,\n" + + " '{}'::text[] as reloptions,\n" + + " 0 as relpartbound\n" + + "from information_schema.indexes i\n" + + "inner join information_schema.index_columns using (table_catalog, table_schema, table_name)\n" + + "group by i.index_name, i.table_schema\n" + + "),\n" + + "pg_namespace as (\n" + + " select case schema_name when 'pg_catalog' then 11 when 'public' then 2200 else 0 end as oid,\n" + + " schema_name as nspname, null as nspowner, null as nspacl\n" + + " from information_schema.schemata\n" + + ")\n" + + "select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where true and relname='%s'"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(String.format(checkTableExistsSql, "singers")), + ResultSet.newBuilder().setMetadata(SELECT1_RESULTSET.getMetadata()).build())); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(String.format(checkTableExistsSql, "albums")), + ResultSet.newBuilder().setMetadata(SELECT1_RESULTSET.getMetadata()).build())); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(String.format(checkTableExistsSql, "tracks")), + ResultSet.newBuilder().setMetadata(SELECT1_RESULTSET.getMetadata()).build())); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(String.format(checkTableExistsSql, "venues")), + ResultSet.newBuilder().setMetadata(SELECT1_RESULTSET.getMetadata()).build())); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(String.format(checkTableExistsSql, "concerts")), + ResultSet.newBuilder().setMetadata(SELECT1_RESULTSET.getMetadata()).build())); + addDdlResponseToSpannerAdmin(); + + String actualOutput = + execute(SAMPLE_DIR, "test_create_model.py", "localhost", pgServer.getLocalPort()); + String expectedOutput = "Created data model\n"; + assertEquals(expectedOutput, actualOutput); + + List requests = + mockDatabaseAdmin.getRequests().stream() + .filter(req -> req instanceof UpdateDatabaseDdlRequest) + .map(req -> (UpdateDatabaseDdlRequest) req) + .collect(Collectors.toList()); + assertEquals(1, requests.size()); + assertEquals(5, requests.get(0).getStatementsCount()); + assertEquals( + "CREATE TABLE singers (\n" + + "\tid VARCHAR NOT NULL, \n" + + "\tversion_id INTEGER NOT NULL, \n" + + "\tcreated_at TIMESTAMP WITH TIME ZONE, \n" + + "\tupdated_at TIMESTAMP WITH TIME ZONE, \n" + + "\tfirst_name VARCHAR(100), \n" + + "\tlast_name VARCHAR(200), \n" + + "\tfull_name VARCHAR, \n" + + "\tactive BOOLEAN, \n" + + "\tPRIMARY KEY (id)\n" + + ")", + requests.get(0).getStatements(0)); + assertEquals( + "CREATE TABLE venues (\n" + + "\tid VARCHAR NOT NULL, \n" + + "\tversion_id INTEGER NOT NULL, \n" + + "\tcreated_at TIMESTAMP WITH TIME ZONE, \n" + + "\tupdated_at TIMESTAMP WITH TIME ZONE, \n" + + "\tname VARCHAR(200), \n" + + "\tdescription JSONB, \n" + + "\tPRIMARY KEY (id)\n" + + ")", + requests.get(0).getStatements(1)); + assertEquals( + "CREATE TABLE albums (\n" + + "\tid VARCHAR NOT NULL, \n" + + "\tversion_id INTEGER NOT NULL, \n" + + "\tcreated_at TIMESTAMP WITH TIME ZONE, \n" + + "\tupdated_at TIMESTAMP WITH TIME ZONE, \n" + + "\ttitle VARCHAR(200), \n" + + "\tmarketing_budget NUMERIC, \n" + + "\trelease_date DATE, \n" + + "\tcover_picture BYTEA, \n" + + "\tsinger_id VARCHAR, \n" + + "\tPRIMARY KEY (id), \n" + + "\tFOREIGN KEY(singer_id) REFERENCES singers (id)\n" + + ")", + requests.get(0).getStatements(2)); + assertEquals( + "CREATE TABLE concerts (\n" + + "\tid VARCHAR NOT NULL, \n" + + "\tversion_id INTEGER NOT NULL, \n" + + "\tcreated_at TIMESTAMP WITH TIME ZONE, \n" + + "\tupdated_at TIMESTAMP WITH TIME ZONE, \n" + + "\tname VARCHAR(200), \n" + + "\tvenue_id VARCHAR, \n" + + "\tsinger_id VARCHAR, \n" + + "\tstart_time TIMESTAMP WITH TIME ZONE, \n" + + "\tend_time TIMESTAMP WITH TIME ZONE, \n" + + "\tPRIMARY KEY (id), \n" + + "\tFOREIGN KEY(venue_id) REFERENCES venues (id), \n" + + "\tFOREIGN KEY(singer_id) REFERENCES singers (id)\n" + + ")", + requests.get(0).getStatements(3)); + // The 'tracks' table is not generated 100% according to what we would want, but that is because + // the PostgreSQL SQLAlchemy provider does not understand interleaved tables. + assertEquals( + "CREATE TABLE tracks (\n" + + "\tversion_id INTEGER NOT NULL, \n" + + "\tcreated_at TIMESTAMP WITH TIME ZONE, \n" + + "\tupdated_at TIMESTAMP WITH TIME ZONE, \n" + + "\tid VARCHAR NOT NULL, \n" + + "\ttrack_number INTEGER NOT NULL, \n" + + "\ttitle VARCHAR, \n" + + "\tsample_rate FLOAT, \n" + + "\tPRIMARY KEY (id, track_number), \n" + + "\tFOREIGN KEY(id) REFERENCES albums (id)\n" + + ")", + requests.get(0).getStatements(4)); + } + + @Test + public void testMetadataReflect() throws IOException, InterruptedException { + String sql = + "with pg_class as (\n" + + " select\n" + + " -1 as oid,\n" + + " table_name as relname,\n" + + " case table_schema when 'pg_catalog' then 11 when 'public' then 2200 else 0 end as relnamespace,\n" + + " 0 as reltype,\n" + + " 0 as reloftype,\n" + + " 0 as relowner,\n" + + " 1 as relam,\n" + + " 0 as relfilenode,\n" + + " 0 as reltablespace,\n" + + " 0 as relpages,\n" + + " 0.0::float8 as reltuples,\n" + + " 0 as relallvisible,\n" + + " 0 as reltoastrelid,\n" + + " false as relhasindex,\n" + + " false as relisshared,\n" + + " 'p' as relpersistence,\n" + + " 'r' as relkind,\n" + + " count(*) as relnatts,\n" + + " 0 as relchecks,\n" + + " false as relhasrules,\n" + + " false as relhastriggers,\n" + + " false as relhassubclass,\n" + + " false as relrowsecurity,\n" + + " false as relforcerowsecurity,\n" + + " true as relispopulated,\n" + + " 'n' as relreplident,\n" + + " false as relispartition,\n" + + " 0 as relrewrite,\n" + + " 0 as relfrozenxid,\n" + + " 0 as relminmxid,\n" + + " '{}'::bigint[] as relacl,\n" + + " '{}'::text[] as reloptions,\n" + + " 0 as relpartbound\n" + + "from information_schema.tables t\n" + + "inner join information_schema.columns using (table_catalog, table_schema, table_name)\n" + + "group by t.table_name, t.table_schema\n" + + "union all\n" + + "select\n" + + " -1 as oid,\n" + + " i.index_name as relname,\n" + + " case table_schema when 'pg_catalog' then 11 when 'public' then 2200 else 0 end as relnamespace,\n" + + " 0 as reltype,\n" + + " 0 as reloftype,\n" + + " 0 as relowner,\n" + + " 1 as relam,\n" + + " 0 as relfilenode,\n" + + " 0 as reltablespace,\n" + + " 0 as relpages,\n" + + " 0.0::float8 as reltuples,\n" + + " 0 as relallvisible,\n" + + " 0 as reltoastrelid,\n" + + " false as relhasindex,\n" + + " false as relisshared,\n" + + " 'p' as relpersistence,\n" + + " 'r' as relkind,\n" + + " count(*) as relnatts,\n" + + " 0 as relchecks,\n" + + " false as relhasrules,\n" + + " false as relhastriggers,\n" + + " false as relhassubclass,\n" + + " false as relrowsecurity,\n" + + " false as relforcerowsecurity,\n" + + " true as relispopulated,\n" + + " 'n' as relreplident,\n" + + " false as relispartition,\n" + + " 0 as relrewrite,\n" + + " 0 as relfrozenxid,\n" + + " 0 as relminmxid,\n" + + " '{}'::bigint[] as relacl,\n" + + " '{}'::text[] as reloptions,\n" + + " 0 as relpartbound\n" + + "from information_schema.indexes i\n" + + "inner join information_schema.index_columns using (table_catalog, table_schema, table_name)\n" + + "group by i.index_name, i.table_schema\n" + + "),\n" + + "pg_namespace as (\n" + + " select case schema_name when 'pg_catalog' then 11 when 'public' then 2200 else 0 end as oid,\n" + + " schema_name as nspname, null as nspowner, null as nspacl\n" + + " from information_schema.schemata\n" + + ")\n" + + "SELECT c.relname FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace WHERE n.nspname = 'public' AND c.relkind in ('r', 'p')"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(sql), + ResultSet.newBuilder() + .setMetadata( + ResultSetMetadata.newBuilder() + .setRowType( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("relname") + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .build()) + .build()) + .build()) + .build())); + + String actualOutput = + execute(SAMPLE_DIR, "test_metadata_reflect.py", "localhost", pgServer.getLocalPort()); + String expectedOutput = "Reflected current data model\n"; + assertEquals(expectedOutput, actualOutput); + } + + static ResultSetMetadata createSingersMetadata(String prefix) { + return ResultSetMetadata.newBuilder() + .setRowType( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .setName(prefix + "id") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.INT64).build()) + .setName(prefix + "version_id") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.TIMESTAMP).build()) + .setName(prefix + "created_at") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.TIMESTAMP).build()) + .setName(prefix + "updated_at") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .setName(prefix + "first_name") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .setName(prefix + "last_name") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .setName(prefix + "full_name") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.BOOL).build()) + .setName(prefix + "active") + .build()) + .build()) + .build(); + } + + static ListValue createSingerRow( + String id, + String firstName, + String lastName, + boolean active, + Timestamp createdAt, + Timestamp updatedAt) { + return ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue(id).build()) + .addValues(Value.newBuilder().setStringValue("1").build()) + .addValues(Value.newBuilder().setStringValue(createdAt.toString()).build()) + .addValues(Value.newBuilder().setStringValue(updatedAt.toString()).build()) + .addValues(Value.newBuilder().setStringValue(firstName).build()) + .addValues(Value.newBuilder().setStringValue(lastName).build()) + .addValues(Value.newBuilder().setStringValue(firstName + " " + lastName).build()) + .addValues(Value.newBuilder().setBoolValue(active).build()) + .build(); + } + + static ResultSetMetadata createAlbumsMetadata(String prefix) { + return ResultSetMetadata.newBuilder() + .setRowType( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .setName(prefix + "id") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.INT64).build()) + .setName(prefix + "version_id") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.TIMESTAMP).build()) + .setName(prefix + "created_at") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.TIMESTAMP).build()) + .setName(prefix + "updated_at") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .setName(prefix + "title") + .build()) + .addFields( + Field.newBuilder() + .setType( + Type.newBuilder() + .setCode(TypeCode.NUMERIC) + .setTypeAnnotation(TypeAnnotationCode.PG_NUMERIC) + .build()) + .setName(prefix + "marketing_budget") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.DATE).build()) + .setName(prefix + "release_date") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.BYTES).build()) + .setName(prefix + "cover_picture") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .setName(prefix + "singer_id") + .build()) + .build()) + .build(); + } + + static ListValue createAlbumRow( + String id, + String title, + String marketingBudget, + Date releaseDate, + ByteArray coverPicture, + String singerId, + Timestamp createdAt, + Timestamp updatedAt) { + return ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue(id).build()) + .addValues(Value.newBuilder().setStringValue("1").build()) + .addValues(Value.newBuilder().setStringValue(createdAt.toString()).build()) + .addValues(Value.newBuilder().setStringValue(updatedAt.toString()).build()) + .addValues(Value.newBuilder().setStringValue(title).build()) + .addValues(Value.newBuilder().setStringValue(marketingBudget).build()) + .addValues(Value.newBuilder().setStringValue(releaseDate.toString()).build()) + .addValues(Value.newBuilder().setStringValue(coverPicture.toBase64()).build()) + .addValues(Value.newBuilder().setStringValue(singerId).build()) + .build(); + } + + static ResultSetMetadata createTracksMetadata(String prefix) { + return ResultSetMetadata.newBuilder() + .setRowType( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.INT64).build()) + .setName(prefix + "version_id") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.TIMESTAMP).build()) + .setName(prefix + "created_at") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.TIMESTAMP).build()) + .setName(prefix + "updated_at") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .setName(prefix + "id") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.INT64).build()) + .setName(prefix + "track_number") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .setName(prefix + "title") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.FLOAT64).build()) + .setName(prefix + "sample_rate") + .build()) + .build()) + .build(); + } + + static ListValue createTrackRow( + String id, + long trackNumber, + String title, + double sampleRate, + Timestamp createdAt, + Timestamp updatedAt) { + return ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("1").build()) + .addValues(Value.newBuilder().setStringValue(createdAt.toString()).build()) + .addValues(Value.newBuilder().setStringValue(updatedAt.toString()).build()) + .addValues(Value.newBuilder().setStringValue(id).build()) + .addValues(Value.newBuilder().setStringValue(String.valueOf(trackNumber)).build()) + .addValues(Value.newBuilder().setStringValue(title).build()) + .addValues(Value.newBuilder().setNumberValue(sampleRate).build()) + .build(); + } + + static ResultSetMetadata createConcertsMetadata(String prefix) { + return ResultSetMetadata.newBuilder() + .setRowType( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .setName(prefix + "id") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.INT64).build()) + .setName(prefix + "version_id") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.TIMESTAMP).build()) + .setName(prefix + "created_at") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.TIMESTAMP).build()) + .setName(prefix + "updated_at") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .setName(prefix + "name") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .setName(prefix + "venue_id") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .setName(prefix + "singer_id") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.TIMESTAMP).build()) + .setName(prefix + "start_time") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.TIMESTAMP).build()) + .setName(prefix + "end_time") + .build()) + .build()) + .build(); + } + + static ListValue createConcertRow( + String id, + String name, + String venueId, + String singerId, + Timestamp startTime, + Timestamp endTime, + Timestamp createdAt, + Timestamp updatedAt) { + return ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue(id).build()) + .addValues(Value.newBuilder().setStringValue("1").build()) + .addValues(Value.newBuilder().setStringValue(createdAt.toString()).build()) + .addValues(Value.newBuilder().setStringValue(updatedAt.toString()).build()) + .addValues(Value.newBuilder().setStringValue(name).build()) + .addValues(Value.newBuilder().setStringValue(venueId).build()) + .addValues(Value.newBuilder().setStringValue(singerId).build()) + .addValues(Value.newBuilder().setStringValue(startTime.toString()).build()) + .addValues(Value.newBuilder().setStringValue(endTime.toString()).build()) + .build(); + } + + static ResultSetMetadata createVenuesMetadata(String prefix) { + return ResultSetMetadata.newBuilder() + .setRowType( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .setName(prefix + "id") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.INT64).build()) + .setName(prefix + "version_id") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.TIMESTAMP).build()) + .setName(prefix + "created_at") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.TIMESTAMP).build()) + .setName(prefix + "updated_at") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .setName(prefix + "name") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .setName(prefix + "description") + .build()) + .build()) + .build(); + } + + static ListValue createVenueRow( + String id, String name, String description, Timestamp createdAt, Timestamp updatedAt) { + return ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue(id).build()) + .addValues(Value.newBuilder().setStringValue("1").build()) + .addValues(Value.newBuilder().setStringValue(createdAt.toString()).build()) + .addValues(Value.newBuilder().setStringValue(updatedAt.toString()).build()) + .addValues(Value.newBuilder().setStringValue(name).build()) + .addValues(Value.newBuilder().setStringValue(description).build()) + .build(); + } +} diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/session/SessionStateTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/session/SessionStateTest.java index dac991ace..d67eb144a 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/session/SessionStateTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/session/SessionStateTest.java @@ -32,6 +32,7 @@ import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata.DdlTransactionMode; import com.google.cloud.spanner.pgadapter.statements.PgCatalog; +import com.google.cloud.spanner.pgadapter.utils.ClientAutoDetector.WellKnownClient; import com.google.common.collect.ImmutableMap; import java.util.List; import java.util.Map; @@ -463,7 +464,7 @@ public void testGeneratePGSettingsCte() { @Test public void testAddSessionState() { SessionState state = new SessionState(mock(OptionsMetadata.class)); - PgCatalog pgCatalog = new PgCatalog(state); + PgCatalog pgCatalog = new PgCatalog(state, WellKnownClient.UNSPECIFIED); Statement statement = Statement.of("select * from pg_settings"); Statement withSessionState = pgCatalog.replacePgCatalogTables(statement); @@ -476,7 +477,7 @@ public void testAddSessionState() { @Test public void testAddSessionStateWithParameters() { SessionState state = new SessionState(mock(OptionsMetadata.class)); - PgCatalog pgCatalog = new PgCatalog(state); + PgCatalog pgCatalog = new PgCatalog(state, WellKnownClient.UNSPECIFIED); Statement statement = Statement.newBuilder("select * from pg_settings where name=$1") .bind("p1") @@ -497,7 +498,7 @@ public void testAddSessionStateWithParameters() { @Test public void testAddSessionStateWithoutPgSettings() { SessionState state = new SessionState(mock(OptionsMetadata.class)); - PgCatalog pgCatalog = new PgCatalog(state); + PgCatalog pgCatalog = new PgCatalog(state, WellKnownClient.UNSPECIFIED); Statement statement = Statement.of("select * from some_table"); Statement withSessionState = pgCatalog.replacePgCatalogTables(statement); @@ -508,7 +509,7 @@ public void testAddSessionStateWithoutPgSettings() { @Test public void testAddSessionStateWithComments() { SessionState state = new SessionState(mock(OptionsMetadata.class)); - PgCatalog pgCatalog = new PgCatalog(state); + PgCatalog pgCatalog = new PgCatalog(state, WellKnownClient.UNSPECIFIED); Statement statement = Statement.of("/* This comment is preserved */ select * from pg_settings"); Statement withSessionState = pgCatalog.replacePgCatalogTables(statement); @@ -521,7 +522,7 @@ public void testAddSessionStateWithComments() { @Test public void testAddSessionStateWithExistingCTE() { SessionState state = new SessionState(mock(OptionsMetadata.class)); - PgCatalog pgCatalog = new PgCatalog(state); + PgCatalog pgCatalog = new PgCatalog(state, WellKnownClient.UNSPECIFIED); Statement statement = Statement.of( "with my_cte as (select col1, col2 from foo) select * from pg_settings inner join my_cte on my_cte.col1=pg_settings.name"); @@ -538,7 +539,7 @@ public void testAddSessionStateWithExistingCTE() { @Test public void testAddSessionStateWithCommentsAndExistingCTE() { SessionState state = new SessionState(mock(OptionsMetadata.class)); - PgCatalog pgCatalog = new PgCatalog(state); + PgCatalog pgCatalog = new PgCatalog(state, WellKnownClient.UNSPECIFIED); Statement statement = Statement.of( "/* This comment is preserved */ with foo as (select * from bar)\nselect * from pg_settings"); diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/statements/BackendConnectionTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/statements/BackendConnectionTest.java index f149bcf6a..55ca7b761 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/statements/BackendConnectionTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/statements/BackendConnectionTest.java @@ -52,6 +52,7 @@ import com.google.cloud.spanner.pgadapter.statements.DdlExecutor.NotExecuted; import com.google.cloud.spanner.pgadapter.statements.local.ListDatabasesStatement; import com.google.cloud.spanner.pgadapter.statements.local.LocalStatement; +import com.google.cloud.spanner.pgadapter.utils.ClientAutoDetector.WellKnownClient; import com.google.cloud.spanner.pgadapter.utils.CopyDataReceiver; import com.google.cloud.spanner.pgadapter.utils.MutationWriter; import com.google.common.collect.ImmutableList; @@ -134,8 +135,9 @@ public void testExecuteStatementsInBatch() { new BackendConnection( DatabaseId.of("p", "i", "d"), spannerConnection, + () -> WellKnownClient.UNSPECIFIED, mock(OptionsMetadata.class), - ImmutableList.of()); + ImmutableList::of); backendConnection.execute( PARSER.parse(Statement.of("CREATE TABLE \"Foo\" (id bigint primary key)")), @@ -168,8 +170,9 @@ public void testCopyPropagatesNonSpannerException() { new BackendConnection( DatabaseId.of("p", "i", "d"), spannerConnection, + () -> WellKnownClient.UNSPECIFIED, mock(OptionsMetadata.class), - ImmutableList.of()); + ImmutableList::of); Future result = backendConnection.executeCopy(parsedStatement, statement, receiver, writer, executor); backendConnection.flush(); @@ -205,8 +208,9 @@ public void testHasDmlOrCopyStatementsAfter() { new BackendConnection( DatabaseId.of("p", "i", "d"), spannerConnection, + () -> WellKnownClient.UNSPECIFIED, mock(OptionsMetadata.class), - ImmutableList.of()); + ImmutableList::of); onlyDmlStatements.execute(parsedUpdateStatement, updateStatement, Function.identity()); onlyDmlStatements.execute(parsedUpdateStatement, updateStatement, Function.identity()); assertTrue(onlyDmlStatements.hasDmlOrCopyStatementsAfter(0)); @@ -216,8 +220,9 @@ public void testHasDmlOrCopyStatementsAfter() { new BackendConnection( DatabaseId.of("p", "i", "d"), spannerConnection, + () -> WellKnownClient.UNSPECIFIED, mock(OptionsMetadata.class), - ImmutableList.of()); + ImmutableList::of); onlyCopyStatements.executeCopy(parsedCopyStatement, copyStatement, receiver, writer, executor); onlyCopyStatements.executeCopy(parsedCopyStatement, copyStatement, receiver, writer, executor); assertTrue(onlyCopyStatements.hasDmlOrCopyStatementsAfter(0)); @@ -227,8 +232,9 @@ public void testHasDmlOrCopyStatementsAfter() { new BackendConnection( DatabaseId.of("p", "i", "d"), spannerConnection, + () -> WellKnownClient.UNSPECIFIED, mock(OptionsMetadata.class), - ImmutableList.of()); + ImmutableList::of); dmlAndCopyStatements.execute(parsedUpdateStatement, updateStatement, Function.identity()); dmlAndCopyStatements.executeCopy( parsedCopyStatement, copyStatement, receiver, writer, executor); @@ -239,8 +245,9 @@ public void testHasDmlOrCopyStatementsAfter() { new BackendConnection( DatabaseId.of("p", "i", "d"), spannerConnection, + () -> WellKnownClient.UNSPECIFIED, mock(OptionsMetadata.class), - ImmutableList.of()); + ImmutableList::of); onlySelectStatements.execute(parsedSelectStatement, selectStatement, Function.identity()); onlySelectStatements.execute(parsedSelectStatement, selectStatement, Function.identity()); assertFalse(onlySelectStatements.hasDmlOrCopyStatementsAfter(0)); @@ -250,8 +257,9 @@ public void testHasDmlOrCopyStatementsAfter() { new BackendConnection( DatabaseId.of("p", "i", "d"), spannerConnection, + () -> WellKnownClient.UNSPECIFIED, mock(OptionsMetadata.class), - ImmutableList.of()); + ImmutableList::of); onlyClientSideStatements.execute( parsedClientSideStatement, clientSideStatement, Function.identity()); onlyClientSideStatements.execute( @@ -263,8 +271,9 @@ public void testHasDmlOrCopyStatementsAfter() { new BackendConnection( DatabaseId.of("p", "i", "d"), spannerConnection, + () -> WellKnownClient.UNSPECIFIED, mock(OptionsMetadata.class), - ImmutableList.of()); + ImmutableList::of); onlyUnknownStatements.execute(parsedUnknownStatement, unknownStatement, Function.identity()); onlyUnknownStatements.execute(parsedUnknownStatement, unknownStatement, Function.identity()); assertFalse(onlyUnknownStatements.hasDmlOrCopyStatementsAfter(0)); @@ -274,8 +283,9 @@ public void testHasDmlOrCopyStatementsAfter() { new BackendConnection( DatabaseId.of("p", "i", "d"), spannerConnection, + () -> WellKnownClient.UNSPECIFIED, mock(OptionsMetadata.class), - ImmutableList.of()); + ImmutableList::of); dmlAndSelectStatements.execute(parsedUpdateStatement, updateStatement, Function.identity()); dmlAndSelectStatements.execute(parsedSelectStatement, selectStatement, Function.identity()); assertTrue(dmlAndSelectStatements.hasDmlOrCopyStatementsAfter(0)); @@ -285,8 +295,9 @@ public void testHasDmlOrCopyStatementsAfter() { new BackendConnection( DatabaseId.of("p", "i", "d"), spannerConnection, + () -> WellKnownClient.UNSPECIFIED, mock(OptionsMetadata.class), - ImmutableList.of()); + ImmutableList::of); copyAndSelectStatements.executeCopy( parsedCopyStatement, copyStatement, receiver, writer, executor); copyAndSelectStatements.execute(parsedSelectStatement, selectStatement, Function.identity()); @@ -297,8 +308,9 @@ public void testHasDmlOrCopyStatementsAfter() { new BackendConnection( DatabaseId.of("p", "i", "d"), spannerConnection, + () -> WellKnownClient.UNSPECIFIED, mock(OptionsMetadata.class), - ImmutableList.of()); + ImmutableList::of); copyAndUnknownStatements.executeCopy( parsedCopyStatement, copyStatement, receiver, writer, executor); copyAndUnknownStatements.execute(parsedUnknownStatement, unknownStatement, Function.identity()); @@ -323,7 +335,11 @@ public void testExecuteLocalStatement() throws ExecutionException, InterruptedEx BackendConnection backendConnection = new BackendConnection( - DatabaseId.of("p", "i", "d"), connection, mock(OptionsMetadata.class), localStatements); + DatabaseId.of("p", "i", "d"), + connection, + () -> WellKnownClient.UNSPECIFIED, + mock(OptionsMetadata.class), + () -> localStatements); Future resultFuture = backendConnection.execute( parsedListDatabasesStatement, @@ -355,7 +371,11 @@ public void testExecuteOtherStatementWithLocalStatements() BackendConnection backendConnection = new BackendConnection( - DatabaseId.of("p", "i", "d"), connection, mock(OptionsMetadata.class), localStatements); + DatabaseId.of("p", "i", "d"), + connection, + () -> WellKnownClient.UNSPECIFIED, + mock(OptionsMetadata.class), + () -> localStatements); Future resultFuture = backendConnection.execute(parsedStatement, statement, Function.identity()); backendConnection.flush(); @@ -388,8 +408,9 @@ public void testGeneralException() { new BackendConnection( DatabaseId.of("p", "i", "d"), connection, + () -> WellKnownClient.UNSPECIFIED, mock(OptionsMetadata.class), - EMPTY_LOCAL_STATEMENTS); + () -> EMPTY_LOCAL_STATEMENTS); Future resultFuture = backendConnection.execute(parsedStatement, statement, Function.identity()); backendConnection.flush(); @@ -412,8 +433,9 @@ public void testCancelledException() { new BackendConnection( DatabaseId.of("p", "i", "d"), connection, + () -> WellKnownClient.UNSPECIFIED, mock(OptionsMetadata.class), - EMPTY_LOCAL_STATEMENTS); + () -> EMPTY_LOCAL_STATEMENTS); Future resultFuture = backendConnection.execute(parsedStatement, statement, Function.identity()); backendConnection.flush(); @@ -449,8 +471,9 @@ public void testDdlExceptionInBatch() { new BackendConnection( DatabaseId.of("p", "i", "d"), connection, + () -> WellKnownClient.UNSPECIFIED, mock(OptionsMetadata.class), - EMPTY_LOCAL_STATEMENTS); + () -> EMPTY_LOCAL_STATEMENTS); Future resultFuture1 = backendConnection.execute(parsedStatement1, statement1, Function.identity()); backendConnection.execute(parsedStatement2, statement2, Function.identity()); @@ -474,7 +497,11 @@ public void testReplacePgCatalogTables() { BackendConnection backendConnection = new BackendConnection( - DatabaseId.of("p", "i", "d"), connection, options, EMPTY_LOCAL_STATEMENTS); + DatabaseId.of("p", "i", "d"), + connection, + () -> WellKnownClient.UNSPECIFIED, + options, + () -> EMPTY_LOCAL_STATEMENTS); backendConnection.execute(parsedStatement, statement, Function.identity()); backendConnection.flush(); @@ -517,7 +544,11 @@ public void testDisableReplacePgCatalogTables() { BackendConnection backendConnection = new BackendConnection( - DatabaseId.of("p", "i", "d"), connection, options, EMPTY_LOCAL_STATEMENTS); + DatabaseId.of("p", "i", "d"), + connection, + () -> WellKnownClient.UNSPECIFIED, + options, + () -> EMPTY_LOCAL_STATEMENTS); backendConnection.execute(parsedStatement, statement, Function.identity()); backendConnection.flush(); @@ -537,8 +568,9 @@ public void testDoNotStartTransactionInBatch() { new BackendConnection( DatabaseId.of("p", "i", "d"), connection, + () -> WellKnownClient.UNSPECIFIED, mock(OptionsMetadata.class), - EMPTY_LOCAL_STATEMENTS); + () -> EMPTY_LOCAL_STATEMENTS); backendConnection.execute(parsedStatement, statement, Function.identity()); backendConnection.flush(); diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/statements/CopyToStatementTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/statements/CopyToStatementTest.java index 3c32d8592..f118d104a 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/statements/CopyToStatementTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/statements/CopyToStatementTest.java @@ -15,6 +15,7 @@ package com.google.cloud.spanner.pgadapter.statements; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.mockito.Mockito.when; import com.google.cloud.spanner.pgadapter.ConnectionHandler; @@ -91,4 +92,14 @@ public void testRecordSeparator() { new CopyToStatement(connectionHandler, options, "", parsedCopyStatement); assertEquals("\n", statement.getCsvFormat().getRecordSeparator()); } + + @Test + public void testCopyQuery() { + String query = + "select executed_at, workload, threads, batch_size, operation_count, round(read_avg/1000) as read_avg, round(read_p95/1000) as read_p95 from run where true and workload='a' and threads=20 and deployment='java_uds' order by executed_at desc limit 100"; + ParsedCopyStatement parsedCopyStatement = + CopyStatement.parse(String.format("copy (%s) to stdout\n", query)); + assertNotNull(parsedCopyStatement); + assertEquals(parsedCopyStatement.query, query); + } } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/statements/SimpleParserTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/statements/SimpleParserTest.java index cd04d756e..838082470 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/statements/SimpleParserTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/statements/SimpleParserTest.java @@ -461,4 +461,15 @@ public void testReadDoubleQuotedString() { public void testUnescapeQuotedStringValue() { assertEquals("'", unescapeQuotedStringValue("'\\''", '\'')); } + + @Test + public void testParserTableOrIndexName() { + assertEquals(TableOrIndexName.of("foo"), TableOrIndexName.parse("foo")); + assertEquals( + TableOrIndexName.of(/* schema= */ "foo", /* name= */ "bar"), + TableOrIndexName.parse("foo.bar")); + + assertThrows(IllegalArgumentException.class, () -> TableOrIndexName.parse("")); + assertThrows(IllegalArgumentException.class, () -> TableOrIndexName.parse("foo.bar baz")); + } } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/statements/StatementTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/statements/StatementTest.java index 19b87e5bd..9e0a67fd6 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/statements/StatementTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/statements/StatementTest.java @@ -39,6 +39,7 @@ import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.SpannerExceptionFactory; import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.Value; import com.google.cloud.spanner.connection.AbstractStatementParser; import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement; import com.google.cloud.spanner.connection.AbstractStatementParser.StatementType; @@ -52,6 +53,7 @@ import com.google.cloud.spanner.pgadapter.metadata.ConnectionMetadata; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; import com.google.cloud.spanner.pgadapter.session.SessionState; +import com.google.cloud.spanner.pgadapter.utils.ClientAutoDetector.WellKnownClient; import com.google.cloud.spanner.pgadapter.utils.MutationWriter; import com.google.cloud.spanner.pgadapter.wireprotocol.ControlMessage; import com.google.cloud.spanner.pgadapter.wireprotocol.QueryMessage; @@ -194,7 +196,11 @@ public void testBasicZeroUpdateCountResultStatement() throws Exception { ImmutableList.of()); BackendConnection backendConnection = new BackendConnection( - connectionHandler.getDatabaseId(), connection, options, EMPTY_LOCAL_STATEMENTS); + connectionHandler.getDatabaseId(), + connection, + () -> WellKnownClient.UNSPECIFIED, + options, + () -> EMPTY_LOCAL_STATEMENTS); assertFalse(intermediateStatement.isExecuted()); assertEquals("UPDATE", intermediateStatement.getCommand()); @@ -280,7 +286,11 @@ public void testBasicStatementExceptionGetsSetOnExceptedExecution() { ImmutableList.of()); BackendConnection backendConnection = new BackendConnection( - connectionHandler.getDatabaseId(), connection, options, EMPTY_LOCAL_STATEMENTS); + connectionHandler.getDatabaseId(), + connection, + () -> WellKnownClient.UNSPECIFIED, + options, + () -> EMPTY_LOCAL_STATEMENTS); intermediateStatement.executeAsync(backendConnection); backendConnection.flush(); @@ -315,7 +325,11 @@ public void testPreparedStatement() { .build(); BackendConnection backendConnection = new BackendConnection( - connectionHandler.getDatabaseId(), connection, options, EMPTY_LOCAL_STATEMENTS); + connectionHandler.getDatabaseId(), + connection, + () -> WellKnownClient.UNSPECIFIED, + options, + () -> EMPTY_LOCAL_STATEMENTS); IntermediatePreparedStatement intermediateStatement = new IntermediatePreparedStatement( @@ -371,8 +385,10 @@ public void testPreparedStatementIllegalTypeThrowsException() { IntermediatePortalStatement portalStatement = intermediateStatement.createPortal("", parameters, new ArrayList<>(), new ArrayList<>()); - assertThrows( - IllegalArgumentException.class, () -> portalStatement.bind(Statement.of(sqlStatement))); + Statement boundStatement = portalStatement.bind(Statement.of(sqlStatement)); + assertEquals( + Value.untyped(com.google.protobuf.Value.newBuilder().setStringValue("{}").build()), + boundStatement.getParameters().get("p1")); } @Test @@ -416,7 +432,11 @@ public void testPortalStatementDescribePropagatesFailure() { ImmutableList.of()); BackendConnection backendConnection = new BackendConnection( - connectionHandler.getDatabaseId(), connection, options, EMPTY_LOCAL_STATEMENTS); + connectionHandler.getDatabaseId(), + connection, + () -> WellKnownClient.UNSPECIFIED, + options, + () -> EMPTY_LOCAL_STATEMENTS); when(connection.execute(Statement.of(sqlStatement))) .thenThrow( @@ -478,7 +498,11 @@ public void testCopyInvalidBuildMutation() throws Exception { BackendConnection backendConnection = new BackendConnection( - DatabaseId.of("p", "i", "d"), connection, options, ImmutableList.of()); + DatabaseId.of("p", "i", "d"), + connection, + () -> WellKnownClient.UNSPECIFIED, + options, + ImmutableList::of); statement.executeAsync(backendConnection); ExecutorService executor = Executors.newSingleThreadExecutor(); @@ -546,7 +570,11 @@ public void testGetStatementResultBeforeFlushFails() { ImmutableList.of()); BackendConnection backendConnection = new BackendConnection( - connectionHandler.getDatabaseId(), connection, options, EMPTY_LOCAL_STATEMENTS); + connectionHandler.getDatabaseId(), + connection, + () -> WellKnownClient.UNSPECIFIED, + options, + () -> EMPTY_LOCAL_STATEMENTS); intermediateStatement.executeAsync(backendConnection); @@ -567,7 +595,11 @@ public void testCopyBatchSizeLimit() throws Exception { setupQueryInformationSchemaResults(); BackendConnection backendConnection = new BackendConnection( - DatabaseId.of("p", "i", "d"), connection, options, ImmutableList.of()); + DatabaseId.of("p", "i", "d"), + connection, + () -> WellKnownClient.UNSPECIFIED, + options, + ImmutableList::of); byte[] payload = Files.readAllBytes(Paths.get("./src/test/resources/batch-size-test.txt")); @@ -616,7 +648,11 @@ public void testCopyDataRowLengthMismatchLimit() throws Exception { setupQueryInformationSchemaResults(); BackendConnection backendConnection = new BackendConnection( - DatabaseId.of("p", "i", "d"), connection, options, ImmutableList.of()); + DatabaseId.of("p", "i", "d"), + connection, + () -> WellKnownClient.UNSPECIFIED, + options, + ImmutableList::of); byte[] payload = "1\t'one'\n2".getBytes(); diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/utils/BinaryCopyParserTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/utils/BinaryCopyParserTest.java index 3bb7ff7d5..be8ba3849 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/utils/BinaryCopyParserTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/utils/BinaryCopyParserTest.java @@ -361,4 +361,67 @@ public void testBinaryRecord() { assertThrows(IllegalArgumentException.class, () -> record.getValue(Type.string(), 10)); assertThrows(SpannerException.class, () -> record.getValue(Type.numeric(), 0)); } + + @Test + public void testIsNull() throws IOException { + PipedOutputStream pipedOutputStream = new PipedOutputStream(); + BinaryCopyParser parser = new BinaryCopyParser(new PipedInputStream(pipedOutputStream, 256)); + + DataOutputStream data = new DataOutputStream(pipedOutputStream); + data.write(COPY_BINARY_HEADER); + data.writeInt(0); + data.writeInt(0); + + // Write a tuple with two non-null fields. + data.writeShort(2); + data.writeInt(8); + data.writeLong(1L); + data.writeInt(8); + data.writeLong(2L); + + // Write a tuple with one non-null and one null field. + data.writeShort(2); + data.writeInt(8); + data.writeLong(1L); + data.writeInt(-1); + + // Write a tuple with one null and one non-null field. + data.writeShort(2); + data.writeInt(-1); + data.writeInt(8); + data.writeLong(2L); + + // Write a tuple with two non-null fields. + data.writeShort(2); + data.writeInt(-1); + data.writeInt(-1); + + // Trailer. + data.writeShort(-1); + + Iterator iterator = parser.iterator(); + assertTrue(iterator.hasNext()); + CopyRecord record1 = iterator.next(); + assertFalse(record1.isNull(0)); + assertFalse(record1.isNull(1)); + assertThrows(IllegalArgumentException.class, () -> record1.isNull(-1)); + assertThrows(IllegalArgumentException.class, () -> record1.isNull(2)); + + assertTrue(iterator.hasNext()); + CopyRecord record2 = iterator.next(); + assertFalse(record2.isNull(0)); + assertTrue(record2.isNull(1)); + + assertTrue(iterator.hasNext()); + CopyRecord record3 = iterator.next(); + assertTrue(record3.isNull(0)); + assertFalse(record3.isNull(1)); + + assertTrue(iterator.hasNext()); + CopyRecord record4 = iterator.next(); + assertTrue(record4.isNull(0)); + assertTrue(record4.isNull(1)); + + assertFalse(iterator.hasNext()); + } } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/utils/ClientAutoDetectorTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/utils/ClientAutoDetectorTest.java index 8b6fd843a..790258ed7 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/utils/ClientAutoDetectorTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/utils/ClientAutoDetectorTest.java @@ -24,9 +24,12 @@ import com.google.cloud.spanner.Statement; import com.google.cloud.spanner.pgadapter.ConnectionHandler; import com.google.cloud.spanner.pgadapter.ProxyServer; +import com.google.cloud.spanner.pgadapter.error.PGException; +import com.google.cloud.spanner.pgadapter.metadata.ConnectionMetadata; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; import com.google.cloud.spanner.pgadapter.statements.local.ListDatabasesStatement; import com.google.cloud.spanner.pgadapter.utils.ClientAutoDetector.WellKnownClient; +import com.google.cloud.spanner.pgadapter.wireoutput.NoticeResponse; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.ArrayList; @@ -37,6 +40,39 @@ @RunWith(JUnit4.class) public class ClientAutoDetectorTest { + + @Test + public void testUnspecified() { + WellKnownClient.UNSPECIFIED.reset(); + + assertEquals( + WellKnownClient.UNSPECIFIED, + ClientAutoDetector.detectClient(ImmutableList.of(), ImmutableMap.of())); + assertEquals( + WellKnownClient.UNSPECIFIED, + ClientAutoDetector.detectClient( + ImmutableList.of("some-param"), ImmutableMap.of("some-param", "some-value"))); + + ConnectionHandler connection = mock(ConnectionHandler.class); + ProxyServer server = mock(ProxyServer.class); + OptionsMetadata options = mock(OptionsMetadata.class); + when(server.getOptions()).thenReturn(options); + when(connection.getServer()).thenReturn(server); + + when(options.useDefaultLocalStatements()).thenReturn(true); + assertEquals( + DEFAULT_LOCAL_STATEMENTS, WellKnownClient.UNSPECIFIED.getLocalStatements(connection)); + + when(options.useDefaultLocalStatements()).thenReturn(false); + assertEquals(ImmutableList.of(), WellKnownClient.UNSPECIFIED.getLocalStatements(connection)); + + assertEquals( + ImmutableList.of(), WellKnownClient.UNSPECIFIED.createStartupNoticeResponses(connection)); + assertEquals( + ImmutableList.of(), WellKnownClient.UNSPECIFIED.getErrorHints(mock(PGException.class))); + assertEquals(ImmutableMap.of(), WellKnownClient.UNSPECIFIED.getDefaultParameters()); + } + @Test public void testJdbc() { // The JDBC driver will **always** include these startup parameters in exactly this order. @@ -291,4 +327,52 @@ public void testNpgsql() { + "SELECT ns.nspname, t.oid, t.typname, t.typtype, t.typnotnull, t.elemtypoid\n"), Statement.of("SELECT version();")))); } + + @Test + public void testPgbench() { + // pgbench always includes itself as the application_name. + assertEquals( + WellKnownClient.PGBENCH, + ClientAutoDetector.detectClient( + ImmutableList.of("application_name"), ImmutableMap.of("application_name", "pgbench"))); + assertEquals( + WellKnownClient.PGBENCH, + ClientAutoDetector.detectClient( + ImmutableList.of("application_name", "some-param"), + ImmutableMap.of( + "application_name", "pgbench", + "some-param", "some-value"))); + assertEquals( + WellKnownClient.PGBENCH, + ClientAutoDetector.detectClient( + ImmutableList.of("some-param1", "application_name", "some-param2"), + ImmutableMap.of( + "some-param1", "some-value1", + "application_name", "pgbench", + "some-param2", "some-value2"))); + assertNotEquals( + WellKnownClient.PGBENCH, + ClientAutoDetector.detectClient(ImmutableList.of(), ImmutableMap.of())); + assertNotEquals( + WellKnownClient.PGBENCH, + ClientAutoDetector.detectClient( + ImmutableList.of("some-param1"), ImmutableMap.of("some-param1", "some-value1"))); + assertNotEquals( + WellKnownClient.PGBENCH, + ClientAutoDetector.detectClient( + ImmutableList.of("application_name"), ImmutableMap.of("application_name", "JDBC"))); + assertNotEquals( + WellKnownClient.PGBENCH, + ClientAutoDetector.detectClient( + ImmutableList.of("application_name"), ImmutableMap.of("application_name", "PGBENCH"))); + + WellKnownClient.PGBENCH.reset(); + ConnectionHandler connectionHandler = mock(ConnectionHandler.class); + when(connectionHandler.getConnectionMetadata()).thenReturn(mock(ConnectionMetadata.class)); + ImmutableList startupNotices = + WellKnownClient.PGBENCH.createStartupNoticeResponses(connectionHandler); + assertEquals(1, startupNotices.size()); + assertEquals("Detected connection from pgbench", startupNotices.get(0).getMessage()); + assertEquals(ClientAutoDetector.PGBENCH_USAGE_HINT + "\n", startupNotices.get(0).getHint()); + } } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/utils/CsvCopyParserTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/utils/CsvCopyParserTest.java index 15614b9f4..3138ced56 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/utils/CsvCopyParserTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/utils/CsvCopyParserTest.java @@ -15,8 +15,10 @@ package com.google.cloud.spanner.pgadapter.utils; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; import com.google.cloud.ByteArray; @@ -33,6 +35,7 @@ import java.io.PipedInputStream; import java.io.PipedOutputStream; import java.nio.charset.StandardCharsets; +import java.util.Iterator; import org.apache.commons.csv.CSVFormat; import org.junit.Test; import org.junit.runner.RunWith; @@ -43,7 +46,6 @@ public class CsvCopyParserTest { @Test public void testCanCreateIterator() throws IOException { - CsvCopyParser parser = new CsvCopyParser( mock(SessionState.class), @@ -293,4 +295,60 @@ public void testGetSpannerValue_UnsupportedType() { CsvCopyRecord.getSpannerValue( sessionState, Type.struct(StructField.of("f1", Type.string())), "value")); } + + @Test + public void testIsNull() throws IOException { + PipedOutputStream outputStream = new PipedOutputStream(); + DataOutputStream data = new DataOutputStream(outputStream); + new Thread( + () -> { + while (true) { + try { + data.write("\"value1\"\t\"value2\"\n".getBytes(StandardCharsets.UTF_8)); + data.write("\"value1\"\t\\N\n".getBytes(StandardCharsets.UTF_8)); + data.write("\\N\t\"value2\"\n".getBytes(StandardCharsets.UTF_8)); + data.write("\\N\t\\N\n".getBytes(StandardCharsets.UTF_8)); + data.close(); + break; + } catch (IOException e) { + if (e.getMessage().contains("Pipe not connected")) { + Thread.yield(); + } else { + throw new RuntimeException(e); + } + } + } + }) + .start(); + CsvCopyParser parser = + new CsvCopyParser( + mock(SessionState.class), + CSVFormat.POSTGRESQL_TEXT, + new PipedInputStream(outputStream, 256), + false); + Iterator iterator = parser.iterator(); + + assertTrue(iterator.hasNext()); + CopyRecord record = iterator.next(); + assertFalse(record.isNull(0)); + assertFalse(record.isNull(1)); + + assertTrue(iterator.hasNext()); + record = iterator.next(); + assertFalse(record.isNull(0)); + assertTrue(record.isNull(1)); + + assertTrue(iterator.hasNext()); + record = iterator.next(); + assertTrue(record.isNull(0)); + assertFalse(record.isNull(1)); + + assertTrue(iterator.hasNext()); + record = iterator.next(); + assertTrue(record.isNull(0)); + assertTrue(record.isNull(1)); + + assertFalse(iterator.hasNext()); + parser.close(); + } } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/utils/MutationWriterTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/utils/MutationWriterTest.java index 3f8aca1d9..815c8d80e 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/utils/MutationWriterTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/utils/MutationWriterTest.java @@ -91,7 +91,7 @@ public void testParsePartialPayload() throws IOException { CSVParser parser = CSVParser.parse(reader, format); // Pass in 2 complete and one incomplete record. It should be possible to parse the two first // records without problems. - String records = "1\t\"One\"\n2\t\"Two\"\n3\t"; + String records = "1\tOne\n2\tTwo\n3\t"; writer.write(records); writer.flush(); @@ -110,7 +110,7 @@ record = iterator.next(); // Calling iterator.hasNext() or iterator.next() would now block, as there is not enough data // to build another record. // Add the missing pieces for the last record and parse that as well. - writer.write("\"Three\"\n"); + writer.write("Three\n"); writer.close(); assertTrue(iterator.hasNext()); @@ -150,8 +150,7 @@ public void testWriteMutations() throws Exception { executor.submit( () -> { try { - mutationWriter.addCopyData( - "1\t\"One\"\n2\t\"Two\"\n".getBytes(StandardCharsets.UTF_8)); + mutationWriter.addCopyData("1\tOne\n2\tTwo\n".getBytes(StandardCharsets.UTF_8)); mutationWriter.commit(); mutationWriter.close(); } catch (IOException ignore) { @@ -309,8 +308,7 @@ public void testWriteMutations_FailsForLargeCommit() throws Exception { executor.submit( () -> { try { - mutationWriter.addCopyData( - "1\t\"One\"\n2\t\"Two\"\n".getBytes(StandardCharsets.UTF_8)); + mutationWriter.addCopyData("1\tOne\n2\tTwo\n".getBytes(StandardCharsets.UTF_8)); mutationWriter.close(); } catch (IOException ignore) { } @@ -357,8 +355,7 @@ public void testWriteMutations_NonAtomic_SucceedsForLargeCommit() throws Excepti () -> { try { mutationWriter.addCopyData( - "1\t\"One\"\n2\t\"Two\"\n3\t\"Three\"\n4\t\"Four\"\n5\t\"Five\"\n" - .getBytes(StandardCharsets.UTF_8)); + "1\tOne\n2\tTwo\n3\tThree\n4\tFour\n5\tFive\n".getBytes(StandardCharsets.UTF_8)); mutationWriter.commit(); mutationWriter.close(); } catch (IOException ignore) { @@ -409,13 +406,12 @@ public void testWritePartials() throws Exception { ExecutorService executor = Executors.newFixedThreadPool(2); executor.submit( () -> { - mutationWriter.addCopyData("1\t\"One\"\n".getBytes(StandardCharsets.UTF_8)); - mutationWriter.addCopyData("2\t\"Two".getBytes(StandardCharsets.UTF_8)); - mutationWriter.addCopyData("\"".getBytes(StandardCharsets.UTF_8)); + mutationWriter.addCopyData("1\tOne\n".getBytes(StandardCharsets.UTF_8)); + mutationWriter.addCopyData("2\tTwo".getBytes(StandardCharsets.UTF_8)); + mutationWriter.addCopyData("".getBytes(StandardCharsets.UTF_8)); mutationWriter.addCopyData("\n3\t".getBytes(StandardCharsets.UTF_8)); - mutationWriter.addCopyData( - "\"Three\"\n4\t\"Four\"\n5\t".getBytes(StandardCharsets.UTF_8)); - mutationWriter.addCopyData("\"Five\"\n".getBytes(StandardCharsets.UTF_8)); + mutationWriter.addCopyData("Three\n4\tFour\n5\t".getBytes(StandardCharsets.UTF_8)); + mutationWriter.addCopyData("Five\n".getBytes(StandardCharsets.UTF_8)); mutationWriter.commit(); mutationWriter.close(); return null; diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/wireoutput/CopyInResponseTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/wireoutput/CopyInResponseTest.java new file mode 100644 index 000000000..1f095daf6 --- /dev/null +++ b/src/test/java/com/google/cloud/spanner/pgadapter/wireoutput/CopyInResponseTest.java @@ -0,0 +1,70 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.spanner.pgadapter.wireoutput; + +import static org.junit.Assert.assertArrayEquals; + +import com.google.cloud.spanner.pgadapter.ProxyServer.DataFormat; +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class CopyInResponseTest { + + @Test + public void testSendPayloadText() throws Exception { + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + DataOutputStream output = new DataOutputStream(bytes); + short numColumns = 3; + + CopyInResponse response = + new CopyInResponse(output, numColumns, (byte) DataFormat.POSTGRESQL_TEXT.getCode()); + response.sendPayload(); + + assertArrayEquals( + new byte[] { + 0, // format + 0, 3, // numColumns + 0, 0, // format column 1 + 0, 0, // format column 2 + 0, 0, // format column 3 + }, + bytes.toByteArray()); + } + + @Test + public void testSendPayloadBinary() throws Exception { + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + DataOutputStream output = new DataOutputStream(bytes); + short numColumns = 3; + + CopyInResponse response = + new CopyInResponse(output, numColumns, (byte) DataFormat.POSTGRESQL_BINARY.getCode()); + response.sendPayload(); + + assertArrayEquals( + new byte[] { + 1, // format + 0, 3, // numColumns + 0, 1, // format column 1 + 0, 1, // format column 2 + 0, 1, // format column 3 + }, + bytes.toByteArray()); + } +} diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/wireoutput/ErrorResponseTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/wireoutput/ErrorResponseTest.java index 443c1ee65..9dc392444 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/wireoutput/ErrorResponseTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/wireoutput/ErrorResponseTest.java @@ -16,10 +16,16 @@ import static com.google.cloud.spanner.pgadapter.wireoutput.ErrorResponse.calculateLength; import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import com.google.cloud.spanner.pgadapter.ConnectionHandler; import com.google.cloud.spanner.pgadapter.error.PGException; import com.google.cloud.spanner.pgadapter.error.SQLState; import com.google.cloud.spanner.pgadapter.error.Severity; +import com.google.cloud.spanner.pgadapter.metadata.ConnectionMetadata; +import com.google.cloud.spanner.pgadapter.utils.ClientAutoDetector; +import com.google.cloud.spanner.pgadapter.utils.ClientAutoDetector.WellKnownClient; import java.io.ByteArrayOutputStream; import java.io.DataOutputStream; import org.junit.Test; @@ -44,7 +50,8 @@ public void testCalculateLength() { PGException.newBuilder("test message") .setSeverity(Severity.ERROR) .setSQLState(SQLState.RaiseException) - .build())); + .build(), + WellKnownClient.UNSPECIFIED)); assertEquals( 4 + "test message".length() @@ -61,16 +68,59 @@ public void testCalculateLength() { .setSeverity(Severity.ERROR) .setSQLState(SQLState.RaiseException) .setHints("test hint") - .build())); + .build(), + WellKnownClient.UNSPECIFIED)); + + assertEquals( + 4 + + "test message".length() + + /* Field header + null terminator */ 2 + + "ERROR".length() + + 2 + + "P0001".length() + + 2 + + ClientAutoDetector.PGBENCH_USAGE_HINT.length() + + 2 + + 1, + calculateLength( + PGException.newBuilder("test message") + .setSeverity(Severity.ERROR) + .setSQLState(SQLState.RaiseException) + .build(), + WellKnownClient.PGBENCH)); + assertEquals( + 4 + + "test message".length() + + /* Field header + null terminator */ 2 + + "ERROR".length() + + 2 + + "P0001".length() + + 2 + + "test hint\n".length() + + ClientAutoDetector.PGBENCH_USAGE_HINT.length() + + 2 + + 1, + calculateLength( + PGException.newBuilder("test message") + .setSeverity(Severity.ERROR) + .setSQLState(SQLState.RaiseException) + .setHints("test hint") + .build(), + WellKnownClient.PGBENCH)); } @Test public void testSendPayload() throws Exception { ByteArrayOutputStream out = new ByteArrayOutputStream(); DataOutputStream output = new DataOutputStream(out); + ConnectionMetadata metadata = mock(ConnectionMetadata.class); + when(metadata.getOutputStream()).thenReturn(output); + ConnectionHandler connectionHandler = mock(ConnectionHandler.class); + when(connectionHandler.getWellKnownClient()).thenReturn(WellKnownClient.UNSPECIFIED); + when(connectionHandler.getConnectionMetadata()).thenReturn(metadata); ErrorResponse response = new ErrorResponse( - output, + connectionHandler, PGException.newBuilder("test message") .setSeverity(Severity.ERROR) .setSQLState(SQLState.RaiseException) @@ -86,9 +136,14 @@ public void testSendPayload() throws Exception { public void testSendPayloadWithHint() throws Exception { ByteArrayOutputStream out = new ByteArrayOutputStream(); DataOutputStream output = new DataOutputStream(out); + ConnectionMetadata metadata = mock(ConnectionMetadata.class); + when(metadata.getOutputStream()).thenReturn(output); + ConnectionHandler connectionHandler = mock(ConnectionHandler.class); + when(connectionHandler.getWellKnownClient()).thenReturn(WellKnownClient.UNSPECIFIED); + when(connectionHandler.getConnectionMetadata()).thenReturn(metadata); ErrorResponse response = new ErrorResponse( - output, + connectionHandler, PGException.newBuilder("test message") .setSeverity(Severity.ERROR) .setSQLState(SQLState.RaiseException) @@ -102,4 +157,63 @@ public void testSendPayloadWithHint() throws Exception { "Length: 51, Error Message: test message, Hints: test hint\n" + "line 2", response.getPayloadString()); } + + @Test + public void testSendPayloadWithClientHint() throws Exception { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + DataOutputStream output = new DataOutputStream(out); + ConnectionMetadata metadata = mock(ConnectionMetadata.class); + when(metadata.getOutputStream()).thenReturn(output); + ConnectionHandler connectionHandler = mock(ConnectionHandler.class); + when(connectionHandler.getWellKnownClient()).thenReturn(WellKnownClient.PGBENCH); + when(connectionHandler.getConnectionMetadata()).thenReturn(metadata); + ErrorResponse response = + new ErrorResponse( + connectionHandler, + PGException.newBuilder("test message") + .setSeverity(Severity.ERROR) + .setSQLState(SQLState.RaiseException) + .build()); + + response.sendPayload(); + + assertEquals( + String.format( + "SERROR\0CP0001\0Mtest message\0H%s\0\0", ClientAutoDetector.PGBENCH_USAGE_HINT), + out.toString()); + assertEquals( + "Length: 148, Error Message: test message, Hints: " + ClientAutoDetector.PGBENCH_USAGE_HINT, + response.getPayloadString()); + } + + @Test + public void testSendPayloadWithErrorAndClientHint() throws Exception { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + DataOutputStream output = new DataOutputStream(out); + ConnectionMetadata metadata = mock(ConnectionMetadata.class); + when(metadata.getOutputStream()).thenReturn(output); + ConnectionHandler connectionHandler = mock(ConnectionHandler.class); + when(connectionHandler.getWellKnownClient()).thenReturn(WellKnownClient.PGBENCH); + when(connectionHandler.getConnectionMetadata()).thenReturn(metadata); + ErrorResponse response = + new ErrorResponse( + connectionHandler, + PGException.newBuilder("test message") + .setSeverity(Severity.ERROR) + .setSQLState(SQLState.RaiseException) + .setHints("test hint\nline 2") + .build()); + + response.sendPayload(); + + assertEquals( + String.format( + "SERROR\0CP0001\0Mtest message\0Htest hint\nline 2\n%s\0\0", + ClientAutoDetector.PGBENCH_USAGE_HINT), + out.toString()); + assertEquals( + "Length: 165, Error Message: test message, Hints: test hint\nline 2\n" + + ClientAutoDetector.PGBENCH_USAGE_HINT, + response.getPayloadString()); + } } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/wireoutput/NoticeResponseTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/wireoutput/NoticeResponseTest.java new file mode 100644 index 000000000..3a4b628be --- /dev/null +++ b/src/test/java/com/google/cloud/spanner/pgadapter/wireoutput/NoticeResponseTest.java @@ -0,0 +1,69 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.spanner.pgadapter.wireoutput; + +import static org.junit.Assert.assertEquals; + +import com.google.cloud.spanner.pgadapter.error.SQLState; +import com.google.cloud.spanner.pgadapter.wireoutput.NoticeResponse.NoticeSeverity; +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class NoticeResponseTest { + + @Test + public void testBasics() { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + DataOutputStream output = new DataOutputStream(out); + NoticeResponse response = + new NoticeResponse(output, SQLState.Success, NoticeSeverity.NOTICE, "test notice", null); + + assertEquals('N', response.getIdentifier()); + assertEquals("Notice", response.getMessageName()); + } + + @Test + public void testSendPayload() throws Exception { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + DataOutputStream output = new DataOutputStream(out); + NoticeResponse response = + new NoticeResponse(output, SQLState.Success, NoticeSeverity.NOTICE, "test notice", null); + response.sendPayload(); + + assertEquals("SNOTICE\0C00000\0Mtest notice\0\0", out.toString()); + assertEquals( + "Length: 33, Severity: NOTICE, Notice Message: test notice, Hint: ", + response.getPayloadString()); + } + + @Test + public void testSendPayloadWithHint() throws Exception { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + DataOutputStream output = new DataOutputStream(out); + NoticeResponse response = + new NoticeResponse( + output, SQLState.Success, NoticeSeverity.NOTICE, "test notice", "some hint"); + response.sendPayload(); + + assertEquals("SNOTICE\0C00000\0Mtest notice\0Hsome hint\0\0", out.toString()); + assertEquals( + "Length: 44, Severity: NOTICE, Notice Message: test notice, Hint: some hint", + response.getPayloadString()); + } +} diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/wireprotocol/ProtocolTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/wireprotocol/ProtocolTest.java index 43cc6e714..8e6f4df6f 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/wireprotocol/ProtocolTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/wireprotocol/ProtocolTest.java @@ -986,6 +986,7 @@ public void testExecuteMessageWithException() throws Exception { PGExceptionFactory.newPGException("test error", SQLState.SyntaxError); when(intermediatePortalStatement.hasException()).thenReturn(true); when(intermediatePortalStatement.getException()).thenReturn(testException); + when(connectionHandler.getWellKnownClient()).thenReturn(WellKnownClient.UNSPECIFIED); when(connectionHandler.getPortal(anyString())).thenReturn(intermediatePortalStatement); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionMetadata.getInputStream()).thenReturn(inputStream); @@ -1890,6 +1891,7 @@ public void testRepeatedCopyDataInNormalMode_TerminatesConnectionAndReturnsError when(connectionMetadata.getInputStream()).thenReturn(inputStream); when(connectionMetadata.getOutputStream()).thenReturn(outputStream); + when(connectionHandler.getWellKnownClient()).thenReturn(WellKnownClient.UNSPECIFIED); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionHandler.getStatus()).thenReturn(ConnectionStatus.AUTHENTICATED); doCallRealMethod().when(connectionHandler).increaseInvalidMessageCount(); diff --git a/src/test/python/sqlalchemy/all_types.py b/src/test/python/sqlalchemy/all_types.py new file mode 100644 index 000000000..132df9945 --- /dev/null +++ b/src/test/python/sqlalchemy/all_types.py @@ -0,0 +1,36 @@ +from sqlalchemy import Column, Integer, String, Boolean, LargeBinary, Float,\ + Numeric, DateTime, Date +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import registry + +mapper_registry = registry() +Base = mapper_registry.generate_base() + + +class AllTypes(Base): + __tablename__ = "all_types" + + col_bigint = Column(Integer, primary_key=True) + col_bool = Column(Boolean) + col_bytea = Column(LargeBinary) + col_float8 = Column(Float) + col_int = Column(Integer) + col_numeric = Column(Numeric) + col_timestamptz = Column(DateTime(timezone=True)) + col_date = Column(Date) + col_varchar = Column(String) + col_jsonb = Column(JSONB) + + def __repr__(self): + return f"AllTypes(" \ + f"col_bigint= {self.col_bigint!r}," \ + f"col_bool= {self.col_bool!r}," \ + f"col_bytea= {self.col_bytea!r}" \ + f"col_float8= {self.col_float8!r}" \ + f"col_int= {self.col_int!r}" \ + f"col_numeric= {self.col_numeric!r}" \ + f"col_timestamptz={self.col_timestamptz!r}" \ + f"col_date= {self.col_date!r}" \ + f"col_varchar= {self.col_varchar!r}" \ + f"col_jsonb= {self.col_jsonb!r}" \ + f")" diff --git a/src/test/python/sqlalchemy/autocommit.py b/src/test/python/sqlalchemy/autocommit.py new file mode 100644 index 000000000..1b411ebce --- /dev/null +++ b/src/test/python/sqlalchemy/autocommit.py @@ -0,0 +1,26 @@ +from connect import create_test_engine +from user_metadata import user_table +from sqlalchemy import select, insert + +stmt = select(user_table).where(user_table.c.name == "spongebob") +engine = create_test_engine(autocommit=True) +with engine.connect() as conn: + print(conn.get_isolation_level()) + for row in conn.execute(stmt): + print(row) + + result = conn.execute( + insert(user_table), + [ + {"name": "sandy", "fullname": "Sandy Cheeks"}, + ], + ) + print("Row count: {}".format(result.rowcount)) + + result = conn.execute( + insert(user_table), + [ + {"name": "patrick", "fullname": "Patrick Star"}, + ], + ) + print("Row count: {}".format(result.rowcount)) diff --git a/src/test/python/sqlalchemy/connect.py b/src/test/python/sqlalchemy/connect.py new file mode 100644 index 000000000..a7e1af32f --- /dev/null +++ b/src/test/python/sqlalchemy/connect.py @@ -0,0 +1,72 @@ +""" Copyright 2022 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import argparse + +from psycopg2 import sql +from psycopg2.extensions import register_adapter, AsIs +from psycopg2.extras import Json +from sqlalchemy import create_engine, text, bindparam +import sys + + +def create_test_engine(autocommit=False, options=""): + parser = argparse.ArgumentParser(description='Run SQLAlchemy tests.') + parser.add_argument('host', type=str, help='host to connect to') + parser.add_argument('port', type=int, help='port number to connect to') + parser.add_argument('database', type=str, help='database to connect to') + args = parser.parse_args() + + conn_string = "postgresql+psycopg2://user:password@{host}:{port}/d{options}".format( + host=args.host, port=args.port, options=options) + if args.host == "": + if options == "": + conn_string = conn_string + "?host=/tmp" + else: + conn_string = conn_string + "&host=/tmp" + conn = create_engine(conn_string, future=True) + if autocommit: + return conn.execution_options(isolation_level="AUTOCOMMIT") + return conn + +# Adapts identifiers as-is. This is used to insert PostgreSQL style parameters +# into psycopg2-formatted strings. +def adapt_identifier(identifier): + return AsIs(identifier.string) + +def adapt_dict(value): + return Json(value) + + +# Generates a `prepare as ` statement. +def generate_prepare_statement(name, stmt): + prepare_stmt = text("prepare {} as {}".format(name, stmt)) + params = {} + for index, param in enumerate(stmt.params): + params[param] = sql.Identifier("${}".format(index+1)) + return prepare_stmt, params + +def generate_execute_statement(name, stmt): + param_names = [""] * len(stmt.params) + params = {} + for index, param in enumerate(stmt.params): + param_names[index] = ":{}".format(str(param)) + params[param] = bindparam(str(param)) + execute_stmt = text("execute {} ({})".format(name, ",".join(param_names))) + # execute_stmt.bindparams(params) + return execute_stmt + + +register_adapter(sql.Identifier, adapt_identifier) +register_adapter(dict, adapt_dict) diff --git a/src/test/python/sqlalchemy/core_insert.py b/src/test/python/sqlalchemy/core_insert.py new file mode 100644 index 000000000..2eedc59e3 --- /dev/null +++ b/src/test/python/sqlalchemy/core_insert.py @@ -0,0 +1,25 @@ +from sqlalchemy import insert +from connect import create_test_engine +from user_metadata import user_table + +stmt = insert(user_table).values(name="spongebob", fullname="Spongebob " + "Squarepants") +print(stmt) +compiled = stmt.compile() +print(compiled.params) + +engine = create_test_engine() +with engine.connect() as conn: + result = conn.execute(stmt) + print("Result: {}".format(result.all())) + print("Row count: {}".format(result.rowcount)) + print("Inserted primary key: {}".format(result.inserted_primary_key)) + result = conn.execute( + insert(user_table), + [ + {"name": "sandy", "fullname": "Sandy Cheeks"}, + {"name": "patrick", "fullname": "Patrick Star"}, + ], + ) + conn.commit() + print("Row count: {}".format(result.rowcount)) diff --git a/src/test/python/sqlalchemy/core_insert_from_select.py b/src/test/python/sqlalchemy/core_insert_from_select.py new file mode 100644 index 000000000..5973d65d0 --- /dev/null +++ b/src/test/python/sqlalchemy/core_insert_from_select.py @@ -0,0 +1,21 @@ +from sqlalchemy import insert, select +from user_metadata import user_table, address_table +from connect import create_test_engine + +select_stmt = select(user_table.c.id, user_table.c.name + "@aol.com") +insert_stmt = insert(address_table).from_select( + ["user_id", "email_address"], select_stmt +) +print(insert_stmt) + +select_stmt = select(user_table.c.id, user_table.c.name + "@aol.com") +insert_stmt = insert(address_table).from_select( + ["user_id", "email_address"], select_stmt +) + +engine = create_test_engine() +with engine.connect() as conn: + res = conn.execute(insert_stmt.returning(address_table.c.id, + address_table.c.email_address)) + print("Inserted rows: {}".format(res.rowcount)) + print("Returned rows: {}".format(res.all())) diff --git a/src/test/python/sqlalchemy/core_select.py b/src/test/python/sqlalchemy/core_select.py new file mode 100644 index 000000000..856f4ada9 --- /dev/null +++ b/src/test/python/sqlalchemy/core_select.py @@ -0,0 +1,12 @@ +from connect import create_test_engine +from user_metadata import user_table +from sqlalchemy import select + + +stmt = select(user_table).where(user_table.c.name == "spongebob") +print(stmt) + +engine = create_test_engine() +with engine.connect() as conn: + for row in conn.execute(stmt): + print(row) diff --git a/src/test/python/sqlalchemy/engine_begin.py b/src/test/python/sqlalchemy/engine_begin.py new file mode 100644 index 000000000..a9144fbba --- /dev/null +++ b/src/test/python/sqlalchemy/engine_begin.py @@ -0,0 +1,24 @@ +""" Copyright 2022 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from connect import create_test_engine +from sqlalchemy import text + +engine = create_test_engine() +with engine.begin() as tx: + result = tx.execute(text("INSERT INTO test VALUES (:x, :y)"), + [{"x": 3, "y": 'Three'}, {"x": 4, "y": 'Four'}]) + # This prints out the total row count, so in this case 2. + print("Row count: {}".format(result.rowcount)) diff --git a/src/test/python/sqlalchemy/hello_world.py b/src/test/python/sqlalchemy/hello_world.py new file mode 100644 index 000000000..f04110d48 --- /dev/null +++ b/src/test/python/sqlalchemy/hello_world.py @@ -0,0 +1,23 @@ +""" Copyright 2022 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from connect import create_test_engine +from sqlalchemy import text + + +engine = create_test_engine(autocommit=True) +with engine.connect() as conn: + result = conn.execute(text("select 'hello world'")) + print(result.all()) diff --git a/src/test/python/sqlalchemy/orm_create_relationships.py b/src/test/python/sqlalchemy/orm_create_relationships.py new file mode 100644 index 000000000..6f4ce2e25 --- /dev/null +++ b/src/test/python/sqlalchemy/orm_create_relationships.py @@ -0,0 +1,27 @@ +from sqlalchemy.orm import Session +from connect import create_test_engine +from user_model import User, Address + +engine = create_test_engine() +session = Session(engine) +u1 = User(name="pkrabs", fullname="Pearl Krabs") +print(u1.addresses) + +a1 = Address(email_address="pearl.krabs@gmail.com") +u1.addresses.append(a1) +print(u1.addresses) +print(a1.user) + +a2 = Address(email_address="pearl@aol.com", user=u1) +print(u1.addresses) + +session.add(u1) +print(u1 in session) +print(a1 in session) +print(a2 in session) + +print(u1.id) +print(a1.user_id) +print(a2.user_id) + +session.commit() diff --git a/src/test/python/sqlalchemy/orm_delete.py b/src/test/python/sqlalchemy/orm_delete.py new file mode 100644 index 000000000..ed291d04c --- /dev/null +++ b/src/test/python/sqlalchemy/orm_delete.py @@ -0,0 +1,10 @@ +from sqlalchemy.orm import Session +from connect import create_test_engine +from all_types import AllTypes + +engine = create_test_engine(options="?options=-c timezone=UTC") +session = Session(engine) +row = session.get(AllTypes, 1) +session.delete(row) +session.commit() +print("deleted row") diff --git a/src/test/python/sqlalchemy/orm_error_in_read_write_transaction.py b/src/test/python/sqlalchemy/orm_error_in_read_write_transaction.py new file mode 100644 index 000000000..b51e6494e --- /dev/null +++ b/src/test/python/sqlalchemy/orm_error_in_read_write_transaction.py @@ -0,0 +1,32 @@ +from sqlalchemy.orm import Session +from sqlalchemy import exc +from connect import create_test_engine +from all_types import AllTypes +from datetime import datetime, date + + +engine = create_test_engine() +with Session(engine) as session: + row = AllTypes( + col_bigint=1, + col_bool=True, + col_bytea=bytes("test bytes", "utf-8"), + col_float8=3.14, + col_int=100, + col_numeric=6.626, + col_timestamptz=datetime.fromisoformat("2011-11-04T00:05:23.123456+00:00"), + col_date=date.fromisoformat("2011-11-04"), + col_varchar="test string", + col_jsonb={"key1": "value1", "key2": "value2"} + ) + session.add(row) + try: + session.flush() + print("Inserted 1 row(s)") + session.commit() + except exc.SQLAlchemyError as e: + print("Insert failed: {}".format(e)) + # Rolling back a transaction that has failed is allowed. Any other + # database operation would fail. + session.rollback() + pass diff --git a/src/test/python/sqlalchemy/orm_error_in_read_write_transaction_continue.py b/src/test/python/sqlalchemy/orm_error_in_read_write_transaction_continue.py new file mode 100644 index 000000000..e8bf09945 --- /dev/null +++ b/src/test/python/sqlalchemy/orm_error_in_read_write_transaction_continue.py @@ -0,0 +1,35 @@ +from sqlalchemy.orm import Session +from sqlalchemy import exc +from connect import create_test_engine +from all_types import AllTypes +from datetime import datetime, date + + +engine = create_test_engine() +with Session(engine) as session: + row = AllTypes( + col_bigint=1, + col_bool=True, + col_bytea=bytes("test bytes", "utf-8"), + col_float8=3.14, + col_int=100, + col_numeric=6.626, + col_timestamptz=datetime.fromisoformat("2011-11-04T00:05:23.123456+00:00"), + col_date=date.fromisoformat("2011-11-04"), + col_varchar="test string", + col_jsonb={"key1": "value1", "key2": "value2"} + ) + session.add(row) + try: + session.commit() + except exc.SQLAlchemyError as e: + try: + # Trying to use the same transaction after an error is not possible. + session.get(AllTypes, 1) + # The following line should never be reached. + print("Getting the row after an error succeeded") + pass + except exc.SQLAlchemyError as e2: + print("Getting the row failed: {}".format(e2)) + session.rollback() + pass diff --git a/src/test/python/sqlalchemy/orm_get.py b/src/test/python/sqlalchemy/orm_get.py new file mode 100644 index 000000000..2a4fc6103 --- /dev/null +++ b/src/test/python/sqlalchemy/orm_get.py @@ -0,0 +1,9 @@ +from sqlalchemy.orm import Session +from connect import create_test_engine +from all_types import AllTypes + + +engine = create_test_engine(options="?options=-c timezone=UTC") +session = Session(engine) +row = session.get(AllTypes, 1) +print(row) diff --git a/src/test/python/sqlalchemy/orm_get_with_prepared_statement.py b/src/test/python/sqlalchemy/orm_get_with_prepared_statement.py new file mode 100644 index 000000000..7b68c5676 --- /dev/null +++ b/src/test/python/sqlalchemy/orm_get_with_prepared_statement.py @@ -0,0 +1,28 @@ +from sqlalchemy import event +from sqlalchemy.orm import Session +from sqlalchemy.sql import text +from connect import create_test_engine +from all_types import AllTypes + + +engine = create_test_engine(options="?options=-c timezone=UTC") + +# Register an event listener for this engine that creates a prepared statement +# for each connection that is created. +@event.listens_for(engine, "connect") +def connect(dbapi_connection, connection_record): + cursor_obj = dbapi_connection.cursor() + cursor_obj.execute("prepare get_all_types as select * from all_types where col_bigint=$1") + cursor_obj.close() + + +def get_all_types(col_bigint): + return session.query(AllTypes) \ + .from_statement(text("execute get_all_types (:col_bigint)")) \ + .params(col_bigint=col_bigint) \ + .first() + + +with Session(engine) as session: + row = get_all_types(1) + print(row) diff --git a/src/test/python/sqlalchemy/orm_insert.py b/src/test/python/sqlalchemy/orm_insert.py new file mode 100644 index 000000000..91451784c --- /dev/null +++ b/src/test/python/sqlalchemy/orm_insert.py @@ -0,0 +1,23 @@ +from sqlalchemy.orm import Session +from connect import create_test_engine +from all_types import AllTypes +from datetime import datetime, date + + +engine = create_test_engine() +session = Session(engine) +row = AllTypes( + col_bigint=1, + col_bool=True, + col_bytea=bytes("test bytes", "utf-8"), + col_float8=3.14, + col_int=100, + col_numeric=6.626, + col_timestamptz=datetime.fromisoformat("2011-11-04T00:05:23.123456+00:00"), + col_date=date.fromisoformat("2011-11-04"), + col_varchar="test string", + col_jsonb={"key1": "value1", "key2": "value2"} +) +session.add(row) +session.commit() +print("Inserted 1 row(s)") diff --git a/src/test/python/sqlalchemy/orm_insert_null_values.py b/src/test/python/sqlalchemy/orm_insert_null_values.py new file mode 100644 index 000000000..8bc8abcb4 --- /dev/null +++ b/src/test/python/sqlalchemy/orm_insert_null_values.py @@ -0,0 +1,22 @@ +from sqlalchemy.orm import Session +from connect import create_test_engine +from all_types import AllTypes + + +engine = create_test_engine() +session = Session(engine) +row = AllTypes( + col_bigint=1, + col_bool=None, + col_bytea=None, + col_float8=None, + col_int=None, + col_numeric=None, + col_timestamptz=None, + col_date=None, + col_varchar=None, + col_jsonb=None, +) +session.add(row) +session.commit() +print("Inserted 1 row(s)") diff --git a/src/test/python/sqlalchemy/orm_insert_with_prepared_statement.py b/src/test/python/sqlalchemy/orm_insert_with_prepared_statement.py new file mode 100644 index 000000000..d597e8252 --- /dev/null +++ b/src/test/python/sqlalchemy/orm_insert_with_prepared_statement.py @@ -0,0 +1,36 @@ +from sqlalchemy import insert +from sqlalchemy.orm import Session +from datetime import datetime, date +from connect import create_test_engine, generate_prepare_statement,\ + generate_execute_statement +from all_types import AllTypes + +engine = create_test_engine() +with engine.connect() as conn: + session = Session(engine) + insert_stmt = insert(AllTypes).compile() + prepare_stmt, params = generate_prepare_statement("insert_all_types", + insert_stmt) + execute_stmt = generate_execute_statement("insert_all_types", insert_stmt) + + # Execute the `prepare as ` statement to create a prepared + # statement. + session.execute(prepare_stmt, params) + + row = AllTypes( + col_bigint=1, + col_bool=True, + col_bytea=bytes("test bytes", "utf-8"), + col_float8=3.14, + col_int=100, + col_numeric=6.626, + col_timestamptz=datetime.fromisoformat("2011-11-04T00:05:23.123456+00:00"), + col_date=date.fromisoformat("2011-11-04"), + col_varchar="test string", + col_jsonb={"key1": "value1", "key2": "value2"} + ) + print(vars(row)) + # conn.execute(execute_stmt, vars(row)) + session.execute(execute_stmt, vars(row)) + # execute_stmt.bindparams(row) + # conn.execute(execute_stmt, {"col_bigint": 1}) diff --git a/src/test/python/sqlalchemy/orm_load_relationships.py b/src/test/python/sqlalchemy/orm_load_relationships.py new file mode 100644 index 000000000..2ca797666 --- /dev/null +++ b/src/test/python/sqlalchemy/orm_load_relationships.py @@ -0,0 +1,14 @@ +from sqlalchemy import select +from sqlalchemy.orm import selectinload +from sqlalchemy.orm import Session +from connect import create_test_engine +from user_model import User + + +engine = create_test_engine() +session = Session(engine) + +stmt = select(User).options(selectinload(User.addresses)).order_by(User.id) +for row in session.execute(stmt): + print(f"{row.User.name} " + f"({', '.join(a.email_address for a in row.User.addresses)}) ") diff --git a/src/test/python/sqlalchemy/orm_read_only_transaction.py b/src/test/python/sqlalchemy/orm_read_only_transaction.py new file mode 100644 index 000000000..9614f258e --- /dev/null +++ b/src/test/python/sqlalchemy/orm_read_only_transaction.py @@ -0,0 +1,12 @@ +from sqlalchemy.orm import Session +from connect import create_test_engine +from all_types import AllTypes + + +engine = create_test_engine(options="?options=-c timezone=UTC") +readonly_engine = engine.execution_options(postgresql_readonly=True) + +session = Session(readonly_engine) +session.begin() +row = session.get(AllTypes, 1) +print(row) diff --git a/src/test/python/sqlalchemy/orm_rollback.py b/src/test/python/sqlalchemy/orm_rollback.py new file mode 100644 index 000000000..7a0d23a94 --- /dev/null +++ b/src/test/python/sqlalchemy/orm_rollback.py @@ -0,0 +1,12 @@ +from sqlalchemy.orm import Session +from connect import create_test_engine +from all_types import AllTypes + +engine = create_test_engine(options="?options=-c timezone=UTC") +session = Session(engine) +row = session.get(AllTypes, 1) +row.col_varchar = "updated string" +session.flush() +print("Before rollback: {}".format(row)) +session.rollback() +print("After rollback: {}".format(row)) diff --git a/src/test/python/sqlalchemy/orm_select_first.py b/src/test/python/sqlalchemy/orm_select_first.py new file mode 100644 index 000000000..0fb2ee4af --- /dev/null +++ b/src/test/python/sqlalchemy/orm_select_first.py @@ -0,0 +1,10 @@ +from sqlalchemy import select +from sqlalchemy.orm import Session +from connect import create_test_engine +from all_types import AllTypes + + +engine = create_test_engine(options="?options=-c timezone=UTC") +session = Session(engine) +row = session.scalars(select(AllTypes)).first() +print(row) diff --git a/src/test/python/sqlalchemy/orm_stale_read.py b/src/test/python/sqlalchemy/orm_stale_read.py new file mode 100644 index 000000000..19226a68e --- /dev/null +++ b/src/test/python/sqlalchemy/orm_stale_read.py @@ -0,0 +1,16 @@ +from sqlalchemy.orm import Session +from connect import create_test_engine +from all_types import AllTypes + + +stale_read_engine = create_test_engine( + autocommit=True, + options="?options=-c spanner.read_only_staleness='MAX_STALENESS 10s'") + +session = Session(stale_read_engine) + +row1 = session.get(AllTypes, 1) +print(row1) + +row2 = session.get(AllTypes, 2) +print(row2) diff --git a/src/test/python/sqlalchemy/orm_update.py b/src/test/python/sqlalchemy/orm_update.py new file mode 100644 index 000000000..7c0d2d4a3 --- /dev/null +++ b/src/test/python/sqlalchemy/orm_update.py @@ -0,0 +1,10 @@ +from sqlalchemy.orm import Session +from connect import create_test_engine +from all_types import AllTypes + +engine = create_test_engine(options="?options=-c timezone=UTC") +session = Session(engine) +row = session.get(AllTypes, 1) +row.col_varchar = "updated string" +print(row) +session.commit() diff --git a/src/test/python/sqlalchemy/requirements.txt b/src/test/python/sqlalchemy/requirements.txt new file mode 100644 index 000000000..2bfae662d --- /dev/null +++ b/src/test/python/sqlalchemy/requirements.txt @@ -0,0 +1,3 @@ +psycopg2~=2.9.3 +pytz~=2022.1 +sqlalchemy==1.4.45 diff --git a/src/test/python/sqlalchemy/server_side_cursor.py b/src/test/python/sqlalchemy/server_side_cursor.py new file mode 100644 index 000000000..8d21897f7 --- /dev/null +++ b/src/test/python/sqlalchemy/server_side_cursor.py @@ -0,0 +1,13 @@ +from connect import create_test_engine +from sqlalchemy import text + + +engine = create_test_engine() +with engine.connect() as conn: + with conn.execution_options(yield_per=10).execute( + text("select * from random") + ) as result: + for partition in result.partitions(): + # partition is an iterable that will be at most 10 items + for row in partition: + print(f"{row}") diff --git a/src/test/python/sqlalchemy/session_execute.py b/src/test/python/sqlalchemy/session_execute.py new file mode 100644 index 000000000..9b6ad41d0 --- /dev/null +++ b/src/test/python/sqlalchemy/session_execute.py @@ -0,0 +1,26 @@ +""" Copyright 2022 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from connect import create_test_engine +from sqlalchemy import text +from sqlalchemy.orm import Session + +engine = create_test_engine() +with Session(engine) as session: + result = session.execute( + text("UPDATE test SET value=:y WHERE id=:x"), + [{"x": 1, "y": 'one'}, {"x": 2, "y": 'two'}]) + print("Row count: {}".format(result.rowcount)) + session.commit() diff --git a/src/test/python/sqlalchemy/simple_insert.py b/src/test/python/sqlalchemy/simple_insert.py new file mode 100644 index 000000000..e93b5cca9 --- /dev/null +++ b/src/test/python/sqlalchemy/simple_insert.py @@ -0,0 +1,25 @@ +""" Copyright 2022 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from connect import create_test_engine +from sqlalchemy import text + +engine = create_test_engine() +with engine.connect() as conn: + result = conn.execute(text("INSERT INTO test VALUES (:x, :y)"), + [{"x": 1, "y": 'One'}, {"x": 2, "y": 'Two'}]) + # This prints out the total row count, so in this case 2. + print("Row count: {}".format(result.rowcount)) + conn.commit() diff --git a/src/test/python/sqlalchemy/simple_metadata.py b/src/test/python/sqlalchemy/simple_metadata.py new file mode 100644 index 000000000..822704108 --- /dev/null +++ b/src/test/python/sqlalchemy/simple_metadata.py @@ -0,0 +1,14 @@ +from connect import create_test_engine +from user_metadata import user_table +from user_model import User, Address, mapper_registry + +print(user_table.c.name) +print(user_table.c.keys()) +print(user_table.primary_key) + +print(User.__table__) +print(Address.__table__) + +engine = create_test_engine(options="?options=-c spanner.ddl_transaction_mode" + "=AutocommitExplicitTransaction") +mapper_registry.metadata.create_all(engine) \ No newline at end of file diff --git a/src/test/python/sqlalchemy/user_metadata.py b/src/test/python/sqlalchemy/user_metadata.py new file mode 100644 index 000000000..a9d2169ec --- /dev/null +++ b/src/test/python/sqlalchemy/user_metadata.py @@ -0,0 +1,19 @@ +from sqlalchemy import MetaData, Table, Column, Integer, String, ForeignKey + + +metadata_obj = MetaData() + +user_table = Table( + "user_account", + metadata_obj, + Column("id", Integer, primary_key=True), + Column("name", String(30)), + Column("fullname", String), +) +address_table = Table( + "address", + metadata_obj, + Column("id", Integer, primary_key=True), + Column("user_id", ForeignKey("user_account.id"), nullable=False), + Column("email_address", String, nullable=False), +) diff --git a/src/test/python/sqlalchemy/user_model.py b/src/test/python/sqlalchemy/user_model.py new file mode 100644 index 000000000..4b008a36f --- /dev/null +++ b/src/test/python/sqlalchemy/user_model.py @@ -0,0 +1,33 @@ +from sqlalchemy import Column, Integer, String, ForeignKey +from sqlalchemy.orm import registry +from sqlalchemy.orm import relationship + + +mapper_registry = registry() +Base = mapper_registry.generate_base() + + +class User(Base): + __tablename__ = "user_account" + + id = Column(Integer, primary_key=True) + name = Column(String(30)) + fullname = Column(String) + + addresses = relationship("Address", back_populates="user") + + def __repr__(self): + return f"User(id={self.id!r}, name={self.name!r}, fullname={self.fullname!r})" + + +class Address(Base): + __tablename__ = "address" + + id = Column(Integer, primary_key=True) + email_address = Column(String, nullable=False) + user_id = Column(Integer, ForeignKey("user_account.id")) + + user = relationship("User", back_populates="addresses") + + def __repr__(self): + return f"Address(id={self.id!r}, email_address={self.email_address!r})" diff --git a/versions.txt b/versions.txt index 5dc05cf56..a8293857f 100644 --- a/versions.txt +++ b/versions.txt @@ -1,4 +1,4 @@ # Format: # module:released-version:current-version -google-cloud-spanner-pgadapter:0.15.0:0.15.0 +google-cloud-spanner-pgadapter:0.16.0:0.16.0