diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml
index f53d35aa6..53fe596fb 100644
--- a/.github/workflows/integration.yaml
+++ b/.github/workflows/integration.yaml
@@ -91,6 +91,7 @@ jobs:
with:
python-version: '3.9'
- run: pip install -r ./src/test/python/requirements.txt
+ - run: pip install -r ./src/test/python/sqlalchemy/requirements.txt
- run: pip install -r ./src/test/python/pg8000/requirements.txt
- uses: actions/setup-node@v3
with:
diff --git a/.github/workflows/units.yaml b/.github/workflows/units.yaml
index 7f598a0fe..81d8a7d11 100644
--- a/.github/workflows/units.yaml
+++ b/.github/workflows/units.yaml
@@ -26,6 +26,7 @@ jobs:
python-version: '3.9'
- run: python --version
- run: pip install -r ./src/test/python/requirements.txt
+ - run: pip install -r ./src/test/python/sqlalchemy/requirements.txt
- run: pip install -r ./src/test/python/pg8000/requirements.txt
- uses: actions/setup-node@v3
with:
@@ -51,6 +52,7 @@ jobs:
python-version: '3.9'
- run: python --version
- run: pip install -r ./src/test/python/requirements.txt
+ - run: pip install -r ./src/test/python/sqlalchemy/requirements.txt
- run: pip install -r ./src/test/python/pg8000/requirements.txt
- run: mvn -B test
macos:
@@ -77,6 +79,7 @@ jobs:
python-version: '3.9'
- run: python --version
- run: pip install -r ./src/test/python/requirements.txt
+ - run: pip install -r ./src/test/python/sqlalchemy/requirements.txt
- run: pip install -r ./src/test/python/pg8000/requirements.txt
- uses: actions/setup-dotnet@v3
with:
diff --git a/CHANGELOG.md b/CHANGELOG.md
index f90c874a2..e43b7f759 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,36 @@
# Changelog
+## [0.16.0](https://github.com/GoogleCloudPlatform/pgadapter/compare/v0.15.0...v0.16.0) (2023-02-05)
+
+
+### Features
+
+* allow unsupported OIDs as param types ([#604](https://github.com/GoogleCloudPlatform/pgadapter/issues/604)) ([5e9f95a](https://github.com/GoogleCloudPlatform/pgadapter/commit/5e9f95a720f1648236b39167b227cc70bd40e323))
+* make table and function replacements client-aware ([#605](https://github.com/GoogleCloudPlatform/pgadapter/issues/605)) ([ad49e99](https://github.com/GoogleCloudPlatform/pgadapter/commit/ad49e990298d0e91736d4f5afe581d2f1411b5ca))
+
+
+### Bug Fixes
+
+* binary copy header should be included in first data message ([#609](https://github.com/GoogleCloudPlatform/pgadapter/issues/609)) ([2fbf89e](https://github.com/GoogleCloudPlatform/pgadapter/commit/2fbf89e6a6b3ba0b66f126abf019e386e9276d4c))
+* copy to for a query would fail with a column list ([#616](https://github.com/GoogleCloudPlatform/pgadapter/issues/616)) ([16f030e](https://github.com/GoogleCloudPlatform/pgadapter/commit/16f030e3f6b93ae0a243b6c495b0c906403c5e16))
+* CopyResponse did not return correct column format ([#633](https://github.com/GoogleCloudPlatform/pgadapter/issues/633)) ([dc0d482](https://github.com/GoogleCloudPlatform/pgadapter/commit/dc0d482ffb61d1857a3f49fc424a07d72886b460))
+* csv copy header was repeated for each row ([#619](https://github.com/GoogleCloudPlatform/pgadapter/issues/619)) ([622c49a](https://github.com/GoogleCloudPlatform/pgadapter/commit/622c49a02cf2a865874764f44a77b96539382be0))
+* empty copy from stdin statements could be unresponsive ([#617](https://github.com/GoogleCloudPlatform/pgadapter/issues/617)) ([c576124](https://github.com/GoogleCloudPlatform/pgadapter/commit/c576124e40ad7f07ee0d1e2f3090886896c70dc3))
+* empty partitions could skip binary copy header ([#615](https://github.com/GoogleCloudPlatform/pgadapter/issues/615)) ([e7dd650](https://github.com/GoogleCloudPlatform/pgadapter/commit/e7dd6508015ed45147af59c25f95e18628461d85))
+* show statements failed in pgx ([#629](https://github.com/GoogleCloudPlatform/pgadapter/issues/629)) ([734f521](https://github.com/GoogleCloudPlatform/pgadapter/commit/734f52176f75e4ccb0b8bddc96eae49ace9ab19e))
+* support end-of-data record in COPY ([#602](https://github.com/GoogleCloudPlatform/pgadapter/issues/602)) ([8b705e8](https://github.com/GoogleCloudPlatform/pgadapter/commit/8b705e8f917035cbabe9e6751008e93692355158))
+
+
+### Dependencies
+
+* update Spanner client to 6.35.1 ([#607](https://github.com/GoogleCloudPlatform/pgadapter/issues/607)) ([0c607c7](https://github.com/GoogleCloudPlatform/pgadapter/commit/0c607c7c1bce48139f28688a5d7f1e202d839860))
+
+
+### Documentation
+
+* document pgbench usage ([#603](https://github.com/GoogleCloudPlatform/pgadapter/issues/603)) ([5a62bf6](https://github.com/GoogleCloudPlatform/pgadapter/commit/5a62bf64c56a976625e2c707b6d049e593cddc96))
+* document unix domain sockets with Docker ([#622](https://github.com/GoogleCloudPlatform/pgadapter/issues/622)) ([e4e41f7](https://github.com/GoogleCloudPlatform/pgadapter/commit/e4e41f70e5ad23d8e7d6f2a1bc1851458466bbb6))
+
## [0.15.0](https://github.com/GoogleCloudPlatform/pgadapter/compare/v0.14.1...v0.15.0) (2023-01-18)
diff --git a/README.md b/README.md
index dca5be2f2..6756f34c6 100644
--- a/README.md
+++ b/README.md
@@ -11,12 +11,17 @@ PGAdapter can be used with the following drivers and clients:
4. `psycopg2`: Version 2.9.3 and higher (but not `psycopg3`) are supported. See [psycopg2](docs/psycopg2.md) for more details.
5. `node-postgres`: Version 8.8.0 and higher are supported. See [node-postgres support](docs/node-postgres.md) for more details.
-## Frameworks
-PGAdapter can be used with the following frameworks:
+## Frameworks and Tools
+PGAdapter can be used with the following frameworks and tools:
1. `Liquibase`: Version 4.12.0 and higher are supported. See [Liquibase support](docs/liquibase.md)
for more details. See also [this directory](samples/java/liquibase) for a sample application using `Liquibase`.
2. `gorm`: Version 1.23.8 and higher are supported. See [gorm support](docs/gorm.md) for more details.
See also [this directory](samples/golang/gorm) for a sample application using `gorm`.
+3. `SQLAlchemy`: Version 1.4.45 has _experimental support_. See [SQLAlchemy support](docs/sqlalchemy.md)
+ for more details. See also [this directory](samples/python/sqlalchemy-sample) for a sample
+ application using `SQLAlchemy`.
+4. `pgbench`: Can be used with PGAdapter, but with some limitations. See [pgbench.md](docs/pgbench.md)
+ for more details.
## FAQ
See [Frequently Asked Questions](docs/faq.md) for answers to frequently asked questions.
@@ -63,9 +68,9 @@ Use the `-s` option to specify a different local port than the default 5432 if y
PostgreSQL running on your local system.
-You can also download a specific version of the jar. Example (replace `v0.15.0` with the version you want to download):
+You can also download a specific version of the jar. Example (replace `v0.16.0` with the version you want to download):
```shell
-VERSION=v0.15.0
+VERSION=v0.16.0
wget https://storage.googleapis.com/pgadapter-jar-releases/pgadapter-${VERSION}.tar.gz \
&& tar -xzvf pgadapter-${VERSION}.tar.gz
java -jar pgadapter.jar -p my-project -i my-instance -d my-database
@@ -100,7 +105,7 @@ This option is only available for Java/JVM-based applications.
<groupId>com.google.cloud</groupId>
<artifactId>google-cloud-spanner-pgadapter</artifactId>
- <version>0.15.0</version>
+ <version>0.16.0</version>
```
diff --git a/benchmarks/ycsb/run.sh b/benchmarks/ycsb/run.sh
index dfeb48e09..947687499 100644
--- a/benchmarks/ycsb/run.sh
+++ b/benchmarks/ycsb/run.sh
@@ -31,16 +31,19 @@ psql -h localhost -c "CREATE TABLE IF NOT EXISTS usertable (
read_min float,
read_max float,
read_avg float,
+ read_p50 float,
read_p95 float,
read_p99 float,
update_min float,
update_max float,
update_avg float,
+ update_p50 float,
update_p95 float,
update_p99 float,
insert_min float,
insert_max float,
insert_avg float,
+ insert_p50 float,
insert_p95 float,
insert_p99 float,
primary key (executed_at, deployment, workload, threads, batch_size, operation_count)
@@ -98,6 +101,7 @@ do
fi
./bin/ycsb run jdbc -P workloads/workload$WORKLOAD \
-threads $THREADS \
+ -p hdrhistogram.percentiles=50,95,99 \
-p operationcount=$OPERATION_COUNT \
-p recordcount=100000 \
-p db.batchsize=$BATCH_SIZE \
@@ -112,30 +116,33 @@ do
READ_AVG=$(grep '\[READ\], AverageLatency(us), ' ycsb.log | sed 's/^.*, //' | sed "s/NaN/'NaN'/" || echo null)
READ_MIN=$(grep '\[READ\], MinLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null)
READ_MAX=$(grep '\[READ\], MaxLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null)
+ READ_P50=$(grep '\[READ\], 50thPercentileLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null)
READ_P95=$(grep '\[READ\], 95thPercentileLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null)
READ_P99=$(grep '\[READ\], 99thPercentileLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null)
UPDATE_AVG=$(grep '\[UPDATE\], AverageLatency(us), ' ycsb.log | sed 's/^.*, //' | sed "s/NaN/'NaN'/" || echo null)
UPDATE_MIN=$(grep '\[UPDATE\], MinLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null)
UPDATE_MAX=$(grep '\[UPDATE\], MaxLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null)
+ UPDATE_P50=$(grep '\[UPDATE\], 50thPercentileLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null)
UPDATE_P95=$(grep '\[UPDATE\], 95thPercentileLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null)
UPDATE_P99=$(grep '\[UPDATE\], 99thPercentileLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null)
INSERT_AVG=$(grep '\[INSERT\], AverageLatency(us), ' ycsb.log | sed 's/^.*, //' | sed "s/NaN/'NaN'/" || echo null)
INSERT_MIN=$(grep '\[INSERT\], MinLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null)
INSERT_MAX=$(grep '\[INSERT\], MaxLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null)
+ INSERT_P50=$(grep '\[INSERT\], 50thPercentileLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null)
INSERT_P95=$(grep '\[INSERT\], 95thPercentileLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null)
INSERT_P99=$(grep '\[INSERT\], 99thPercentileLatency(us), ' ycsb.log | sed 's/^.*, //' || echo null)
psql -h localhost \
-c "insert into run (executed_at, deployment, workload, threads, batch_size, operation_count, run_time, throughput,
- read_min, read_max, read_avg, read_p95, read_p99,
- update_min, update_max, update_avg, update_p95, update_p99,
- insert_min, insert_max, insert_avg, insert_p95, insert_p99) values
+ read_min, read_max, read_avg, read_p50, read_p95, read_p99,
+ update_min, update_max, update_avg, update_p50, update_p95, update_p99,
+ insert_min, insert_max, insert_avg, insert_p50, insert_p95, insert_p99) values
('$EXECUTED_AT', '$DEPLOYMENT', '$WORKLOAD', $THREADS, $BATCH_SIZE, $OPERATION_COUNT, $OVERALL_RUNTIME, $OVERALL_THROUGHPUT,
- $READ_MIN, $READ_MAX, $READ_AVG, $READ_P95, $READ_P99,
- $UPDATE_MIN, $UPDATE_MAX, $UPDATE_AVG, $UPDATE_P95, $UPDATE_P99,
- $INSERT_MIN, $INSERT_MAX, $INSERT_AVG, $INSERT_P95, $INSERT_P99)"
+ $READ_MIN, $READ_MAX, $READ_AVG, $READ_P50, $READ_P95, $READ_P99,
+ $UPDATE_MIN, $UPDATE_MAX, $UPDATE_AVG, $UPDATE_P50, $UPDATE_P95, $UPDATE_P99,
+ $INSERT_MIN, $INSERT_MAX, $INSERT_AVG, $INSERT_P50, $INSERT_P95, $INSERT_P99)"
done
done
done
diff --git a/docs/docker.md b/docs/docker.md
index ccc605ac5..9ab4b046b 100644
--- a/docs/docker.md
+++ b/docs/docker.md
@@ -12,6 +12,7 @@ docker run \
-d -p 5432:5432 \
-v ${GOOGLE_APPLICATION_CREDENTIALS}:${GOOGLE_APPLICATION_CREDENTIALS}:ro \
-e GOOGLE_APPLICATION_CREDENTIALS \
+ -v /tmp:/tmp:rw \
gcr.io/cloud-spanner-pg-adapter/pgadapter \
-p my-project -i my-instance -d my-database \
-x
@@ -26,6 +27,8 @@ The Docker options in the `docker run` command that are used in the above exampl
The local file should contain the credentials that should be used by PGAdapter.
* `-e`: Copy the value of the environment variable `GOOGLE_APPLICATION_CREDENTIALS` to the container.
This will make the virtual file `/path/to/credentials.json` the default credentials in the container.
+* `-v /tmp:/tmp:rw`: Map the `/tmp` host directory to the `/tmp` directory in the container. PGAdapter by
+ default uses `/tmp` as the directory where it creates a Unix Domain Socket. See the connection example below this list.
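+
+Once the `/tmp` directory is shared between the host and the container, a client on the host can
+connect to PGAdapter through the Unix Domain Socket. The following is a minimal sketch using
+`psycopg2`; the database name is illustrative, and any PostgreSQL driver that supports Unix Domain
+Sockets can be used instead:
+
+```python
+import psycopg2
+
+# Connect through the Unix Domain Socket that PGAdapter creates in /tmp.
+conn = psycopg2.connect(host="/tmp", port=5432, dbname="my-database")
+with conn.cursor() as cur:
+    cur.execute("select 'Hello from PGAdapter'")
+    print(cur.fetchone()[0])
+```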
The PGAdapter options in the `docker run` command that are used in the above example are:
* `-p`: The Google Cloud project name where the Cloud Spanner database is located.
@@ -54,3 +57,24 @@ psql -h localhost -p 5433
The above example starts PGAdapter in a Docker container on the default port 5432. That port is
mapped to port 5433 on the host machine. When connecting to PGAdapter in the Docker container, you
must specify port 5433 for the connection.
+
+## Using a different directory for Unix Domain Sockets
+
+PGAdapter by default uses `/tmp` as the directory for Unix Domain Sockets. You can change this with
+the `-dir` command line argument for PGAdapter. You must then also change the volume mapping to your
+host to ensure that you can connect to the Unix Domain Socket.
+
+```shell
+docker run \
+ -d -p 5432:5432 \
+ -v ${GOOGLE_APPLICATION_CREDENTIALS}:${GOOGLE_APPLICATION_CREDENTIALS}:ro \
+ -e GOOGLE_APPLICATION_CREDENTIALS \
+ -v /var/pgadapter:/var/pgadapter:rw \
+ gcr.io/cloud-spanner-pg-adapter/pgadapter \
+ -p my-project -i my-instance -d my-database \
+ -dir /var/pgadapter \
+ -x
+psql -h /var/pgadapter
+```
+
+The above example uses `/var/pgadapter` as the directory for Unix Domain Sockets.
diff --git a/docs/pgbench.md b/docs/pgbench.md
new file mode 100644
index 000000000..eb393d906
--- /dev/null
+++ b/docs/pgbench.md
@@ -0,0 +1,100 @@
+# PGAdapter - pgbench
+
+[pgbench](https://www.postgresql.org/docs/current/pgbench.html) can be used with PGAdapter, but with
+some limitations.
+
+Follow these steps to initialize and run `pgbench` benchmarks with PGAdapter:
+
+## Create Data Model
+The default data model that is generated by `pgbench` does not include primary keys for the tables.
+Cloud Spanner requires all tables to have primary keys. Execute the following command to manually
+create the data model for `pgbench`:
+
+```shell
+psql -h /tmp -p 5432 -d my-database \
+ -c "START BATCH DDL;
+ CREATE TABLE pgbench_accounts (
+ aid integer primary key NOT NULL,
+ bid integer NULL,
+ abalance integer NULL,
+ filler varchar(84) NULL
+ );
+ CREATE TABLE pgbench_branches (
+ bid integer primary key NOT NULL,
+ bbalance integer NULL,
+ filler varchar(88) NULL
+ );
+ CREATE TABLE pgbench_history (
+ tid integer NOT NULL DEFAULT -1,
+ bid integer NOT NULL DEFAULT -1,
+ aid integer NOT NULL DEFAULT -1,
+ delta integer NULL,
+ mtime timestamptz NULL,
+ filler varchar(22) NULL,
+ primary key (tid, bid, aid)
+ );
+ CREATE TABLE pgbench_tellers (
+ tid integer primary key NOT NULL,
+ bid integer NULL,
+ tbalance integer NULL,
+ filler varchar(84) NULL
+ );
+ RUN BATCH;"
+```
+
+## Initialize Data
+`pgbench` deletes and inserts data into PostgreSQL using a combination of `truncate`, `insert` and
+`copy` statements. These statements all run in a single transaction. The amount of data that is
+modified during this transaction will exceed the transaction mutation limits of Cloud Spanner. This
+can be worked around by adding the following options to the `pgbench` initialization command:
+
+```shell
+pgbench "host=/tmp port=5432 dbname=my-database \
+ options='-c spanner.force_autocommit=on -c spanner.autocommit_dml_mode=\'partitioned_non_atomic\''" \
+ -i -Ig \
+ --scale=100
+```
+
+These additional options do the following:
+1. `spanner.force_autocommit=on`: This instructs PGAdapter to ignore any transaction statements and
+ execute all statements in autocommit mode. This prevents the initialization from being executed as
+ a single, large transaction.
+2. `spanner.autocommit_dml_mode='partitioned_non_atomic'`: This instructs PGAdapter to use Partitioned
+ DML for (large) update statements. This ensures that a single statement succeeds even if it would
+ exceed the transaction limits of Cloud Spanner, including large `copy` operations.
+3. `-i` activates initialization mode of `pgbench`.
+4. `-Ig` instructs `pgbench` to generate the test data client-side. PGAdapter does not support
+ generating test data server-side.
+
+## Running Benchmarks
+You can run different benchmarks after finishing the steps above.
+
+### Default Benchmark
+Run a default benchmark to verify that everything works as expected.
+
+```shell
+pgbench "host=/tmp port=5432 dbname=my-database"
+```
+
+### Number of Clients
+Increase the number of clients and threads to increase the number of parallel transactions.
+
+```shell
+pgbench "host=/tmp port=5432 dbname=my-database" \
+ --client=100 --jobs=100 \
+ --progress=10
+```
+
+## Dropping Tables
+Execute the following command to remove the `pgbench` tables from your database if you no longer
+need them.
+
+```shell
+psql -h /tmp -p 5432 -d my-database \
+ -c "START BATCH DDL;
+ DROP TABLE pgbench_history;
+ DROP TABLE pgbench_tellers;
+ DROP TABLE pgbench_branches;
+ DROP TABLE pgbench_accounts;
+ RUN BATCH;"
+```
diff --git a/docs/sqlalchemy.md b/docs/sqlalchemy.md
new file mode 100644
index 000000000..f4e577fea
--- /dev/null
+++ b/docs/sqlalchemy.md
@@ -0,0 +1,47 @@
+# PGAdapter - SQLAlchemy Connection Options
+
+## Limitations
+PGAdapter has experimental support for SQLAlchemy 1.4 with Cloud Spanner PostgreSQL databases. It
+has been tested with SQLAlchemy 1.4.45 and psycopg2 2.9.3. Developing new applications using
+SQLAlchemy is possible as long as the listed limitations are taken into account.
+Porting an existing application from PostgreSQL to Cloud Spanner is likely to require code changes.
+
+See [Limitations](../samples/python/sqlalchemy-sample/README.md#limitations) in the `sqlalchemy-sample`
+directory for a full list of limitations.
+
+## Usage
+
+First start PGAdapter:
+
+```shell
+export GOOGLE_APPLICATION_CREDENTIALS=/path/to/credentials.json
+docker pull gcr.io/cloud-spanner-pg-adapter/pgadapter
+docker run \
+ -d -p 5432:5432 \
+ -v ${GOOGLE_APPLICATION_CREDENTIALS}:${GOOGLE_APPLICATION_CREDENTIALS}:ro \
+ -e GOOGLE_APPLICATION_CREDENTIALS \
+ gcr.io/cloud-spanner-pg-adapter/pgadapter \
+ -p my-project -i my-instance \
+ -x
+```
+
+Then connect to PGAdapter and use `SQLAlchemy` like this:
+
+```python
+conn_string = "postgresql+psycopg2://user:password@localhost:5432/my-database"
+engine = create_engine(conn_string)
+with Session(engine) as session:
+ user1 = User(
+ name="user1",
+ fullname="Full name of User 1",
+ addresses=[Address(email_address="user1@sqlalchemy.org")]
+ )
+ session.add(user1)
+ session.commit()
+```
+
+## Full Sample and Limitations
+[This directory](../samples/python/sqlalchemy-sample) contains a full sample of how to work with
+`SQLAlchemy` with Cloud Spanner and PGAdapter. The sample readme file also lists the
+[current limitations](../samples/python/sqlalchemy-sample/README.md#limitations) when working with
+`SQLAlchemy`.
diff --git a/pom.xml b/pom.xml
index 76c891e9c..c4cb90a62 100644
--- a/pom.xml
+++ b/pom.xml
@@ -34,13 +34,13 @@
<excludedTestGroups>com.google.cloud.spanner.pgadapter.nodejs.NodeJSTest</excludedTestGroups>
- <spanner.version>6.33.0</spanner.version>
+ <spanner.version>6.35.1</spanner.version>
<junixsocket.version>2.6.1</junixsocket.version>
</properties>
<modelVersion>4.0.0</modelVersion>
<artifactId>google-cloud-spanner-pgadapter</artifactId>
- <version>0.15.0</version>
+ <version>0.16.0</version>
<name>Google Cloud Spanner PostgreSQL Adapter</name>
<packaging>jar</packaging>
@@ -111,7 +111,7 @@
<groupId>org.apache.commons</groupId>
<artifactId>commons-csv</artifactId>
- <version>1.9.0</version>
+ <version>1.10.0</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
@@ -147,7 +147,7 @@
<groupId>net.java.dev.jna</groupId>
<artifactId>jna</artifactId>
- <version>5.12.1</version>
+ <version>5.13.0</version>
<scope>test</scope>
diff --git a/samples/python/sqlalchemy-sample/README.md b/samples/python/sqlalchemy-sample/README.md
new file mode 100644
index 000000000..e7abad7ce
--- /dev/null
+++ b/samples/python/sqlalchemy-sample/README.md
@@ -0,0 +1,174 @@
+# PGAdapter and SQLAlchemy
+
+PGAdapter has experimental support for [SQLAlchemy 1.4](https://docs.sqlalchemy.org/en/14/index.html)
+with the `psycopg2` driver. This document shows how to use this sample application, and lists the
+limitations when using `SQLAlchemy` with PGAdapter.
+
+The [sample.py](sample.py) file contains a sample application using `SQLAlchemy` with PGAdapter. Use
+this as a reference for features of `SQLAlchemy` that are supported with PGAdapter. This sample
+assumes that the reader is familiar with `SQLAlchemy`, and it is not intended as a tutorial for how
+to use `SQLAlchemy` in general.
+
+See [Limitations](#limitations) for a full list of known limitations when working with `SQLAlchemy`.
+
+## Start PGAdapter
+You must start PGAdapter before you can run the sample. The following command shows how to start PGAdapter using the
+pre-built Docker image. See [Running PGAdapter](../../../README.md#usage) for more information on
+other ways to run PGAdapter.
+
+```shell
+export GOOGLE_APPLICATION_CREDENTIALS=/path/to/credentials.json
+docker pull gcr.io/cloud-spanner-pg-adapter/pgadapter
+docker run \
+ -d -p 5432:5432 \
+ -v ${GOOGLE_APPLICATION_CREDENTIALS}:${GOOGLE_APPLICATION_CREDENTIALS}:ro \
+ -e GOOGLE_APPLICATION_CREDENTIALS \
+ -v /tmp:/tmp \
+ gcr.io/cloud-spanner-pg-adapter/pgadapter \
+ -p my-project -i my-instance \
+ -x
+```
+
+## Creating the Sample Data Model
+The sample data model contains example tables that cover all supported data types of the Cloud
+Spanner PostgreSQL dialect. It also includes an example of how [interleaved tables](https://cloud.google.com/spanner/docs/reference/postgresql/data-definition-language#extensions_to)
+can be used with SQLAlchemy. Interleaved tables are a Cloud Spanner extension of the standard
+PostgreSQL dialect.
+
+The corresponding SQLAlchemy model is defined in [model.py](model.py).
+
+Run the following command in this directory. Replace the host, port and database name with the actual
+host, port and database name for your PGAdapter and database setup.
+
+```shell
+psql -h localhost -p 5432 -d my-database -f create_data_model.sql
+```
+
+You can also drop an existing data model using the `drop_data_model.sql` script:
+
+```shell
+psql -h localhost -p 5432 -d my-database -f drop_data_model.sql
+```
+
+## Data Types
+Cloud Spanner supports the following data types in combination with `SQLAlchemy`.
+
+| PostgreSQL Type | SQLAlchemy type |
+|----------------------------------------|-------------------------|
+| boolean | Boolean |
+| bigint / int8 | Integer, BigInteger |
+| varchar | String |
+| text | String |
+| float8 / double precision | Float |
+| numeric | Numeric |
+| timestamptz / timestamp with time zone | DateTime(timezone=True) |
+| date | Date |
+| bytea | LargeBinary |
+| jsonb | JSONB |
+
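+As a minimal sketch of how these mappings look in practice (the table and column names below are
+illustrative and not part of the sample model):
+
+```python
+from sqlalchemy import Column, String, BigInteger, Boolean, Numeric, DateTime, Date, LargeBinary
+from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.orm import registry
+
+mapper_registry = registry()
+Base = mapper_registry.generate_base()
+
+
+class TypesExample(Base):
+    __tablename__ = "types_example"
+
+    id = Column(String, primary_key=True)         # varchar
+    flag = Column(Boolean)                        # boolean
+    counter = Column(BigInteger)                  # bigint / int8
+    price = Column(Numeric)                       # numeric
+    created_at = Column(DateTime(timezone=True))  # timestamptz
+    birthdate = Column(Date)                      # date
+    picture = Column(LargeBinary)                 # bytea
+    attributes = Column(JSONB)                    # jsonb
+```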
+
+## Limitations
+The following limitations are currently known:
+
+| Limitation | Workaround |
+|------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| Creating and Dropping Tables | Cloud Spanner does not support the full PostgreSQL DDL dialect. Automated creation of tables using `SQLAlchemy` is therefore not supported. |
+| metadata.reflect() | Cloud Spanner does not support all PostgreSQL `pg_catalog` tables. Using `metadata.reflect()` to get the current objects in the database is therefore not supported. |
+| DDL Transactions | Cloud Spanner does not support DDL statements in a transaction. Add `?options=-c spanner.ddl_transaction_mode=AutocommitExplicitTransaction` to your connection string to automatically convert DDL transactions to [non-atomic DDL batches](../../../docs/ddl.md). See the connection string example below this table. |
+| Generated primary keys | Manually assign a value to the primary key column in your code. The recommended primary key type is a random UUID. Sequences / SERIAL / IDENTITY columns are currently not supported. |
+| INSERT ... ON CONFLICT | `INSERT ... ON CONFLICT` is not supported. |
+| SAVEPOINT | Nested transactions and savepoints are not supported. |
+| SELECT ... FOR UPDATE | `SELECT ... FOR UPDATE` is not supported. |
+| Server side cursors | Server side cursors are currently not supported. |
+| Transaction isolation level | Only SERIALIZABLE and AUTOCOMMIT are supported. `postgresql_readonly=True` is also supported. It is recommended to use either autocommit or read-only for workloads that only read data and/or that do not need to be atomic to get the best possible performance. |
+| Stored procedures | Cloud Spanner does not support Stored Procedures. |
+| User defined functions | Cloud Spanner does not support User Defined Functions. |
+| Other drivers than psycopg2 | PGAdapter does not support using SQLAlchemy with any driver other than `psycopg2`. |
+
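+As a sketch of the DDL transaction mode workaround from the table above (the database name is
+illustrative):
+
+```python
+from sqlalchemy import create_engine
+
+conn_string = ("postgresql+psycopg2://user:password@localhost:5432/my-database"
+               "?options=-c spanner.ddl_transaction_mode=AutocommitExplicitTransaction")
+engine = create_engine(conn_string)
+```
+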
+### Generated Primary Keys
+Generated primary keys are not supported and should be replaced with primary key values that are
+manually assigned. See https://cloud.google.com/spanner/docs/schema-design#primary-key-prevent-hotspots
+for more information on choosing a good primary key. This sample uses random UUIDs that are generated
+by the client and stored as strings for primary keys.
+
+```python
+from uuid import uuid4
+
+from sqlalchemy import Column, String
+
+# `Base` is the declarative base class of the sample model.
+class Singer(Base):
+ id = Column(String, primary_key=True)
+ name = Column(String(100))
+
+singer = Singer(
+ id="{}".format(uuid4()),
+ name="Alice")
+```
+
+### ON CONFLICT Clauses
+`INSERT ... ON CONFLICT ...` is not supported by Cloud Spanner and should not be used. Trying to
+use https://docs.sqlalchemy.org/en/14/dialects/postgresql.html#sqlalchemy.dialects.postgresql.Insert.on_conflict_do_update
+or https://docs.sqlalchemy.org/en/14/dialects/postgresql.html#sqlalchemy.dialects.postgresql.Insert.on_conflict_do_nothing
+will fail.
+
+### SAVEPOINT - Nested transactions
+`SAVEPOINT`s are not supported by Cloud Spanner. Nested transactions in SQLAlchemy are translated to
+savepoints and are therefore not supported. Trying to use `Session.begin_nested()`
+(https://docs.sqlalchemy.org/en/14/orm/session_api.html#sqlalchemy.orm.Session.begin_nested) will fail.
+
+### Locking - SELECT ... FOR UPDATE
+Locking clauses, like `SELECT ... FOR UPDATE`, are not supported (see also https://docs.sqlalchemy.org/en/20/orm/queryguide/query.html#sqlalchemy.orm.Query.with_for_update).
+These are normally not required either, as Cloud Spanner uses isolation level `serializable` for
+read/write transactions.
+
+## Performance Considerations
+
+### Parameterized Queries
+`psycopg2` does not use [parameterized queries](https://cloud.google.com/spanner/docs/sql-best-practices#postgresql_1).
+Instead, `psycopg2` replaces query parameters with literals in the SQL string before sending the
+statement to PGAdapter. This means that each query must be parsed and planned separately, which
+adds latency to queries that are executed multiple times, compared to executing the same queries
+with actual query parameters.
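+
+The effect can be seen with the `mogrify` method of a `psycopg2` cursor, which returns the
+statement exactly as it would be sent to the server. This is a sketch; the connection parameters
+and table name are illustrative:
+
+```python
+import psycopg2
+
+conn = psycopg2.connect(host="localhost", port=5432, dbname="my-database")
+with conn.cursor() as cur:
+    # psycopg2 binds the parameter client-side, so the server only receives a literal.
+    print(cur.mogrify("SELECT * FROM singers WHERE id = %s", ("123-456-789",)))
+    # b"SELECT * FROM singers WHERE id = '123-456-789'"
+```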
+
+### Read-only Transactions
+SQLAlchemy will by default use read/write transactions for all database operations, including for
+workloads that only read data. This will cause Cloud Spanner to take locks for all data that is read
+during the transaction. It is recommended to use either autocommit or [read-only transactions](https://cloud.google.com/spanner/docs/transactions#read-only_transactions)
+for workloads that are known to only execute read operations. Read-only transactions do not take any
+locks. You can derive a separate database engine for read-only transactions from your default
+database engine by adding the `postgresql_readonly=True` execution option.
+
+```python
+read_only_engine = engine.execution_options(postgresql_readonly=True)
+```
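+
+A session that is created from this engine then runs its queries in a read-only transaction. This
+is a minimal sketch, assuming the `Singer` model from this sample:
+
+```python
+from sqlalchemy.orm import Session
+
+from model import Singer
+
+with Session(read_only_engine) as session:
+    # These queries run in a Cloud Spanner read-only transaction that does
+    # not take any locks.
+    singers = session.query(Singer).all()
+```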
+
+### Autocommit
+Using isolation level `AUTOCOMMIT` will suppress the use of (read/write) transactions for each
+database operation in SQLAlchemy. Using autocommit is more efficient than read/write transactions
+for workloads that only read and/or that do not need the atomicity that is offered by transactions.
+
+You can create a separate database engine that can be used for workloads that do not need
+transactions by adding the `isolation_level="AUTOCOMMIT"` execution option to your default database
+engine.
+
+```python
+autocommit_engine = engine.execution_options(isolation_level="AUTOCOMMIT")
+```
+
+### Stale reads
+Read-only transactions and database engines using `AUTOCOMMIT` will by default use strong reads for
+queries. Cloud Spanner also supports stale reads.
+
+* A strong read is a read at a current timestamp and is guaranteed to see all data that has been
+ committed up until the start of this read. Spanner defaults to using strong reads to serve read requests.
+* A stale read is a read at a timestamp in the past. If your application is latency sensitive but
+ tolerant of stale data, then stale reads can provide performance benefits.
+
+See also https://cloud.google.com/spanner/docs/reads#read_types
+
+You can create a database engine that will use stale reads in autocommit mode by adding the following
+to the connection string and execution options of the engine:
+
+```python
+conn_string = "postgresql+psycopg2://user:password@localhost:5432/my-database" \
+ "?options=-c spanner.read_only_staleness='MAX_STALENESS 10s'"
+engine = create_engine(conn_string).execution_options(isolation_level="AUTOCOMMIT")
+```
diff --git a/samples/python/sqlalchemy-sample/connect.py b/samples/python/sqlalchemy-sample/connect.py
new file mode 100644
index 000000000..0d6980c99
--- /dev/null
+++ b/samples/python/sqlalchemy-sample/connect.py
@@ -0,0 +1,60 @@
+""" Copyright 2022 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+import argparse
+
+from sqlalchemy import create_engine, event
+from model import *
+import sys
+
+
+# Creates a database engine that can be used with the sample.py script in this
+# directory. This function assumes that the host, port and database name were
+# given as command line arguments.
+def create_test_engine(autocommit: bool=False, options: str=""):
+ parser = argparse.ArgumentParser(description='Run SQLAlchemy sample.')
+ parser.add_argument('host', type=str, help='host to connect to')
+ parser.add_argument('port', type=int, help='port number to connect to')
+ parser.add_argument('database', type=str, help='database to connect to')
+ args = parser.parse_args()
+
+ conn_string = "postgresql+psycopg2://user:password@{host}:{port}/" \
+ "{database}{options}".format(host=args.host,
+ port=args.port,
+ database=args.database,
+ options=options)
+ if args.host == "":
+ if options == "":
+ conn_string = conn_string + "?host=/tmp"
+ else:
+ conn_string = conn_string + "&host=/tmp"
+ engine = create_engine(conn_string, future=True)
+ if autocommit:
+ engine = engine.execution_options(isolation_level="AUTOCOMMIT")
+ return engine
+
+
+def register_event_listener_for_prepared_statements(engine):
+ # Register an event listener for this engine that creates prepared statements
+ # for each connection that is created. The prepared statement definitions can
+ # be added to the model classes.
+ @event.listens_for(engine, "connect")
+ def connect(dbapi_connection, connection_record):
+ cursor_obj = dbapi_connection.cursor()
+ for model in BaseMixin.__subclasses__():
+ if model.__prepare_statements__ is not None:
+ for prepare_statement in model.__prepare_statements__:
+ cursor_obj.execute(prepare_statement)
+
+ cursor_obj.close()
diff --git a/samples/python/sqlalchemy-sample/create_data_model.sql b/samples/python/sqlalchemy-sample/create_data_model.sql
new file mode 100644
index 000000000..c8e173b32
--- /dev/null
+++ b/samples/python/sqlalchemy-sample/create_data_model.sql
@@ -0,0 +1,64 @@
+
+-- Executing the schema creation in a batch will improve execution speed.
+start batch ddl;
+
+create table if not exists singers (
+ id varchar not null primary key,
+ version_id int not null,
+ first_name varchar,
+ last_name varchar not null,
+ full_name varchar generated always as (coalesce(concat(first_name, ' '::varchar, last_name), last_name)) stored,
+ active boolean,
+ created_at timestamptz,
+ updated_at timestamptz
+);
+
+create table if not exists albums (
+ id varchar not null primary key,
+ version_id int not null,
+ title varchar not null,
+ marketing_budget numeric,
+ release_date date,
+ cover_picture bytea,
+ singer_id varchar not null,
+ created_at timestamptz,
+ updated_at timestamptz,
+ constraint fk_albums_singers foreign key (singer_id) references singers (id)
+);
+
+create table if not exists tracks (
+ id varchar not null,
+ track_number bigint not null,
+ version_id int not null,
+ title varchar not null,
+ sample_rate float8 not null,
+ created_at timestamptz,
+ updated_at timestamptz,
+ primary key (id, track_number)
+) interleave in parent albums on delete cascade;
+
+create table if not exists venues (
+ id varchar not null primary key,
+ version_id int not null,
+ name varchar not null,
+ description jsonb not null,
+ created_at timestamptz,
+ updated_at timestamptz
+);
+
+create table if not exists concerts (
+ id varchar not null primary key,
+ version_id int not null,
+ venue_id varchar not null,
+ singer_id varchar not null,
+ name varchar not null,
+ start_time timestamptz not null,
+ end_time timestamptz not null,
+ created_at timestamptz,
+ updated_at timestamptz,
+ constraint fk_concerts_venues foreign key (venue_id) references venues (id),
+ constraint fk_concerts_singers foreign key (singer_id) references singers (id),
+ constraint chk_end_time_after_start_time check (end_time > start_time)
+);
+
+run batch;
diff --git a/samples/python/sqlalchemy-sample/drop_data_model.sql b/samples/python/sqlalchemy-sample/drop_data_model.sql
new file mode 100644
index 000000000..b57ea2ad8
--- /dev/null
+++ b/samples/python/sqlalchemy-sample/drop_data_model.sql
@@ -0,0 +1,10 @@
+-- Executing the schema drop in a batch will improve execution speed.
+start batch ddl;
+
+drop table if exists concerts;
+drop table if exists venues;
+drop table if exists tracks;
+drop table if exists albums;
+drop table if exists singers;
+
+run batch;
diff --git a/samples/python/sqlalchemy-sample/model.py b/samples/python/sqlalchemy-sample/model.py
new file mode 100644
index 000000000..b920cac03
--- /dev/null
+++ b/samples/python/sqlalchemy-sample/model.py
@@ -0,0 +1,178 @@
+""" Copyright 2022 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from sqlalchemy import Column, Integer, String, Boolean, LargeBinary, Float, \
+ Numeric, DateTime, Date, FetchedValue, ForeignKey, ColumnDefault
+from sqlalchemy.orm import registry, relationship
+from sqlalchemy.dialects.postgresql import JSONB
+from datetime import timezone, datetime
+
+mapper_registry = registry()
+Base = mapper_registry.generate_base()
+
+
+"""
+ BaseMixin contains properties that are common to all models in this sample. All
+ models use a string column that contains a client-side generated UUID as the
+ primary key.
+ The created_at and updated_at properties are automatically filled with the
+ current client system time when a model is created or updated.
+"""
+
+
+def format_timestamp(timestamp: datetime) -> str:
+ return timestamp.astimezone(timezone.utc).isoformat() if timestamp else None
+
+
+class BaseMixin(object):
+ __prepare_statements__ = None
+
+ id = Column(String, primary_key=True)
+ version_id = Column(Integer, nullable=False)
+ created_at = Column(DateTime(timezone=True),
+ # Use a callable default with an explicit timezone to
+ # ensure that SQLAlchemy uses a timestamptz instead of
+ # just a timestamp, and that the value is generated at
+ # insert time instead of once at class definition time.
+ ColumnDefault(lambda: datetime.now(timezone.utc)))
+ updated_at = Column(DateTime(timezone=True),
+ ColumnDefault(
+ lambda: datetime.now(timezone.utc),
+ for_update=True))
+ __mapper_args__ = {"version_id_col": version_id}
+
+
+class Singer(BaseMixin, Base):
+ __tablename__ = "singers"
+
+ first_name = Column(String(100))
+ last_name = Column(String(200))
+ full_name = Column(String,
+ server_default=FetchedValue(),
+ server_onupdate=FetchedValue())
+ active = Column(Boolean)
+ albums = relationship("Album", back_populates="singer")
+
+ __mapper_args__ = {
+ "eager_defaults": True,
+ "version_id_col": BaseMixin.version_id
+ }
+
+ def __repr__(self):
+ return (
+ f"singers("
+ f"id={self.id!r},"
+ f"first_name={self.first_name!r},"
+ f"last_name={self.last_name!r},"
+ f"active={self.active!r},"
+ f"created_at={format_timestamp(self.created_at)!r},"
+ f"updated_at={format_timestamp(self.updated_at)!r}"
+ f")"
+ )
+
+
+class Album(BaseMixin, Base):
+ __tablename__ = "albums"
+
+ title = Column(String(200))
+ marketing_budget = Column(Numeric)
+ release_date = Column(Date)
+ cover_picture = Column(LargeBinary)
+ singer_id = Column(String, ForeignKey("singers.id"))
+ singer = relationship("Singer", back_populates="albums")
+ # The `tracks` relationship uses passive_deletes=True, because `tracks` is
+ # interleaved in `albums` with `ON DELETE CASCADE`. This prevents SQLAlchemy
+ # from deleting the related tracks when an album is deleted, and lets the
+ # database handle it.
+ tracks = relationship("Track", back_populates="album", passive_deletes=True)
+
+ def __repr__(self):
+ return (
+ f"albums("
+ f"id={self.id!r},"
+ f"title={self.title!r},"
+ f"marketing_budget={self.marketing_budget!r},"
+ f"release_date={self.release_date!r},"
+ f"cover_picture={self.cover_picture!r},"
+ f"singer={self.singer_id!r},"
+ f"created_at={format_timestamp(self.created_at)!r},"
+ f"updated_at={format_timestamp(self.updated_at)!r}"
+ f")"
+ )
+
+
+class Track(BaseMixin, Base):
+ __tablename__ = "tracks"
+
+ id = Column(String, ForeignKey("albums.id"), primary_key=True)
+ track_number = Column(Integer, primary_key=True)
+ title = Column(String)
+ sample_rate = Column(Float)
+ album = relationship("Album", back_populates="tracks")
+
+ def __repr__(self):
+ return (
+ f"tracks("
+ f"id={self.id!r},"
+ f"track_number={self.track_number!r},"
+ f"title={self.title!r},"
+ f"sample_rate={self.sample_rate!r},"
+ f"created_at={format_timestamp(self.created_at)!r},"
+ f"updated_at={format_timestamp(self.updated_at)!r}"
+ f")"
+ )
+
+
+class Venue(BaseMixin, Base):
+ __tablename__ = "venues"
+
+ name = Column(String(200))
+ description = Column(JSONB)
+
+ def __repr__(self):
+ return (
+ f"venues("
+ f"id={self.id!r},"
+ f"name={self.name!r},"
+ f"description={self.description!r},"
+ f"created_at={format_timestamp(self.created_at)!r},"
+ f"updated_at={format_timestamp(self.updated_at)!r}"
+ f")"
+ )
+
+
+class Concert(BaseMixin, Base):
+ __tablename__ = "concerts"
+
+ name = Column(String(200))
+ venue_id = Column(String, ForeignKey("venues.id"))
+ venue = relationship("Venue")
+ singer_id = Column(String, ForeignKey("singers.id"))
+ singer = relationship("Singer")
+ start_time = Column(DateTime(timezone=True))
+ end_time = Column(DateTime(timezone=True))
+
+ def __repr__(self):
+ return (
+ f"concerts("
+ f"id={self.id!r},"
+ f"name={self.name!r},"
+ f"venue={self.venue!r},"
+ f"singer={self.singer!r},"
+ f"start_time={self.start_time!r},"
+ f"end_time={self.end_time!r},"
+ f"created_at={format_timestamp(self.created_at)!r},"
+ f"updated_at={format_timestamp(self.updated_at)!r}"
+ f")"
+ )
diff --git a/samples/python/sqlalchemy-sample/run_sample.py b/samples/python/sqlalchemy-sample/run_sample.py
new file mode 100644
index 000000000..8e0a353e2
--- /dev/null
+++ b/samples/python/sqlalchemy-sample/run_sample.py
@@ -0,0 +1,27 @@
+""" Copyright 2022 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+import argparse
+
+
+# Parse and validate the command line arguments before the sample is imported;
+# connect.py parses the same arguments when the database engine is created.
+parser = argparse.ArgumentParser(description='Run SQLAlchemy sample.')
+parser.add_argument('host', type=str, help='host to connect to')
+parser.add_argument('port', type=int, help='port number to connect to')
+parser.add_argument('database', type=str, help='database to connect to')
+args = parser.parse_args()
+
+
+from sample import run_sample
+
+run_sample()
diff --git a/samples/python/sqlalchemy-sample/sample.py b/samples/python/sqlalchemy-sample/sample.py
new file mode 100644
index 000000000..37749a946
--- /dev/null
+++ b/samples/python/sqlalchemy-sample/sample.py
@@ -0,0 +1,339 @@
+""" Copyright 2022 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from connect import create_test_engine
+from sqlalchemy.orm import Session, joinedload
+from sqlalchemy import text, func, or_
+from model import Singer, Album, Track, Venue, Concert
+from util_random_names import random_first_name, random_last_name, \
+ random_album_title, random_release_date, random_marketing_budget, \
+ random_cover_picture
+from uuid import uuid4
+from datetime import datetime, date
+
+
+# This is the default engine that is connected to PostgreSQL (PGAdapter).
+# This engine will by default use read/write transactions.
+engine = create_test_engine()
+
+# This engine uses read-only transactions instead of read/write transactions.
+# It is recommended to use a read-only transaction instead of a read/write
+# transaction for all workloads that only read data, as read-only transactions
+# do not take any locks.
+read_only_engine = engine.execution_options(postgresql_readonly=True)
+
+# This engine uses autocommit instead of transactions, and will execute all
+# read operations with a max staleness of 10 seconds. This will result in more
+# efficient read operations, but the data that is returned can have a staleness
+# of up to 10 seconds.
+stale_read_engine = create_test_engine(
+ autocommit=True,
+ options="?options=-c spanner.read_only_staleness='MAX_STALENESS 10s'")
+
+
+def run_sample():
+ # Create the data model that is used for this sample.
+ create_tables_if_not_exists()
+ # Delete any existing data before running the sample.
+ delete_all_data()
+
+ # First create some random singers and albums.
+ create_random_singers_and_albums()
+ print_singers_and_albums()
+
+ # Create a venue and a concert.
+ create_venue_and_concert_in_transaction()
+ print_concerts()
+
+ # Execute a query to get all albums released before 1980.
+ print_albums_released_before_1980()
+ # Print a subset of all singers using LIMIT and OFFSET.
+ print_singers_with_limit_and_offset()
+ # Execute a query to get all albums that start with the same letter as the
+ # first name or last name of the singer.
+ print_albums_first_character_of_title_equal_to_first_or_last_name()
+
+ # Delete an album. This will automatically also delete all related tracks, as
+ # Tracks is interleaved in Albums with the option ON DELETE CASCADE.
+ with Session(engine) as session:
+ album_id = session.query(Album).first().id
+ delete_album(album_id)
+
+ # Read all albums using a connection that *could* return stale data.
+ with Session(stale_read_engine) as session:
+ album = session.get(Album, album_id)
+ if album is None:
+ print("No album found using a stale read.")
+ else:
+ print("Album was found using a stale read, even though it has already been deleted.")
+
+ print()
+ print("Finished running sample")
+
+
+def create_tables_if_not_exists():
+ # Cloud Spanner does not support DDL in transactions. We therefore turn on
+ # autocommit for this connection.
+ with engine.execution_options(
+ isolation_level="AUTOCOMMIT").connect() as connection:
+ with open("create_data_model.sql") as file:
+ print("Reading sample data model from file")
+ ddl_script = text(file.read())
+ print("Executing table creation script")
+ connection.execute(ddl_script)
+ print("Finished executing table creation script")
+
+
+# Create five random singers, each with a number of albums.
+def create_random_singers_and_albums():
+ with Session(engine) as session:
+ session.add_all([
+ create_random_singer(5),
+ create_random_singer(10),
+ create_random_singer(7),
+ create_random_singer(3),
+ create_random_singer(12),
+ ])
+ session.commit()
+ print("Created 5 singers")
+
+
+# Print all singers and albums currently in the database.
+def print_singers_and_albums():
+ with Session(read_only_engine) as session:
+ print()
+ for singer in session.query(Singer).order_by("last_name").all():
+ print("{} has {} albums:".format(singer.full_name, len(singer.albums)))
+ for album in singer.albums:
+ print(" '{}'".format(album.title))
+
+
+# Create a Venue and Concert in one read/write transaction.
+def create_venue_and_concert_in_transaction():
+ with Session(engine) as session:
+ singer = session.query(Singer).first()
+ venue = Venue(
+ id=str(uuid4()),
+ name="Avenue Park",
+ description={
+ "Capacity": 5000,
+ "Location": "New York",
+ "Country": "US"
+ }
+ )
+ concert = Concert(
+ id=str(uuid4()),
+ name="Avenue Park Open",
+ venue=venue,
+ singer=singer,
+ start_time=datetime.fromisoformat("2023-02-01T20:00:00-05:00"),
+ end_time=datetime.fromisoformat("2023-02-02T02:00:00-05:00"),
+ )
+ session.add_all([venue, concert])
+ session.commit()
+ print()
+ print("Created Venue and Concert")
+
+
+# Prints the concerts currently in the database.
+def print_concerts():
+ with Session(read_only_engine) as session:
+ # Query all concerts and eagerly join both Singer and Venue, so that we
+ # can directly access their properties without having to execute
+ # additional queries.
+ concerts = (
+ session.query(Concert)
+ .options(joinedload(Concert.venue))
+ .options(joinedload(Concert.singer))
+ .order_by("start_time")
+ .all()
+ )
+ print()
+ for concert in concerts:
+ print("Concert '{}' starting at {} with {} will be held at {}"
+ .format(concert.name,
+ concert.start_time,
+ concert.singer.full_name,
+ concert.venue.name))
+
+
+# Prints all albums with a release date before 1980-01-01.
+def print_albums_released_before_1980():
+ with Session(read_only_engine) as session:
+ print()
+ print("Searching for albums released before 1980")
+ albums = (
+ session
+ .query(Album)
+ .filter(Album.release_date < date.fromisoformat("1980-01-01"))
+ .all()
+ )
+ for album in albums:
+ print(
+ " Album {} was released at {}".format(album.title, album.release_date))
+
+
+# Uses a limit and offset to select a subset of all singers in the database.
+def print_singers_with_limit_and_offset():
+ with Session(read_only_engine) as session:
+ print()
+ print("Printing at most 5 singers ordered by last name")
+ singers = (
+ session
+ .query(Singer)
+ .order_by(Singer.last_name)
+ .limit(5)
+ .offset(3)
+ .all()
+ )
+ num_singers = 0
+ for singer in singers:
+ num_singers = num_singers + 1
+ print(" {}. {}".format(num_singers, singer.full_name))
+ print("Found {} singers".format(num_singers))
+
+
+# Searches for all albums that have a title that starts with the same character
+# as the first character of either the first name or last name of the singer.
+def print_albums_first_character_of_title_equal_to_first_or_last_name():
+ print()
+ print("Searching for albums that have a title that starts with the same "
+ "character as the first or last name of the singer")
+ with Session(read_only_engine) as session:
+ albums = (
+ session
+ .query(Album)
+ .join(Singer)
+ .filter(or_(func.lower(func.substring(Album.title, 1, 1)) ==
+ func.lower(func.substring(Singer.first_name, 1, 1)),
+ func.lower(func.substring(Album.title, 1, 1)) ==
+ func.lower(func.substring(Singer.last_name, 1, 1))))
+ .all()
+ )
+ for album in albums:
+ print(" '{}' by {}".format(album.title, album.singer.full_name))
+
+
+# Creates a random singer row with `num_albums` random albums.
+def create_random_singer(num_albums):
+ return Singer(
+ id=str(uuid4()),
+ first_name=random_first_name(),
+ last_name=random_last_name(),
+ active=True,
+ albums=create_random_albums(num_albums)
+ )
+
+
+# Creates `num_albums` random album rows.
+def create_random_albums(num_albums):
+ if num_albums == 0:
+ return []
+ albums = []
+ for i in range(num_albums):
+ albums.append(
+ Album(
+ id=str(uuid4()),
+ title=random_album_title(),
+ release_date=random_release_date(),
+ marketing_budget=random_marketing_budget(),
+ cover_picture=random_cover_picture()
+ )
+ )
+ return albums
+
+
+# Loads and prints the singer with the given id.
+def load_singer(singer_id):
+ with Session(engine) as session:
+ singer = session.get(Singer, singer_id)
+ print(singer)
+ print("Albums:")
+ print(singer.albums)
+
+
+# Adds a new singer row to the database. Shows how flushing the session will
+# automatically return the generated `full_name` column of the Singer.
+def add_singer(singer):
+ with Session(engine) as session:
+ session.add(singer)
+ # We flush the session here to show that the generated column full_name is
+ # returned after the insert. Otherwise, we could just execute a commit
+ # directly.
+ session.flush()
+ print(
+ "Added singer {} with full name {}".format(singer.id, singer.full_name))
+ session.commit()
+
+
+# Updates an existing singer in the database. This will also automatically
+# update the full_name of the singer. This is returned by the database and is
+# visible in the properties of the singer.
+def update_singer(singer_id, first_name, last_name):
+ with Session(engine) as session:
+ singer = session.get(Singer, singer_id)
+ singer.first_name = first_name
+ singer.last_name = last_name
+ # We flush the session here to show that the generated column full_name is
+ # returned after the update. Otherwise, we could just execute a commit
+ # directly.
+ session.flush()
+ print("Updated singer {} with full name {}"
+ .format(singer.id, singer.full_name))
+ session.commit()
+
+
+# Loads the given album from the database and prints its properties, including
+# the tracks related to this album. The table `tracks` is interleaved in
+# `albums`. This is handled as a normal relationship in SQLAlchemy.
+def load_album(album_id):
+ with Session(engine) as session:
+ album = session.get(Album, album_id)
+ print(album)
+ print("Tracks:")
+ print(album.tracks)
+
+
+# Loads a single track from the database and prints its properties. Track has a
+# composite primary key, as it must include both the primary key column(s) of
+# the parent table and its own primary key column(s).
+def load_track(album_id, track_number):
+ with Session(engine) as session:
+ # The "tracks" table has a composite primary key, as it is an interleaved
+ # child table of "albums".
+ track = session.get(Track, [album_id, track_number])
+ print(track)
+
+
+# Deletes all current sample data.
+def delete_all_data():
+ with Session(engine) as session:
+ session.query(Concert).delete()
+ session.query(Venue).delete()
+ session.query(Album).delete()
+ session.query(Singer).delete()
+ session.commit()
+
+
+# Deletes an album from the database. This will also delete all related tracks.
+# The deletion of the tracks is done by the database, and not by SQLAlchemy, as
+# passive_deletes=True has been set on the relationship.
+def delete_album(album_id):
+ with Session(engine) as session:
+ album = session.get(Album, album_id)
+ session.delete(album)
+ session.commit()
+ print()
+ print("Deleted album with id {}".format(album_id))
diff --git a/samples/python/sqlalchemy-sample/test_add_singer.py b/samples/python/sqlalchemy-sample/test_add_singer.py
new file mode 100644
index 000000000..cb09abcae
--- /dev/null
+++ b/samples/python/sqlalchemy-sample/test_add_singer.py
@@ -0,0 +1,28 @@
+""" Copyright 2022 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from model import Singer
+from sample import add_singer
+from datetime import datetime
+
+singer = Singer()
+singer.id = "123-456-789"
+singer.first_name = "Myfirstname"
+singer.last_name = "Mylastname"
+singer.active = True
+# Manually set a created_at value, as we otherwise do not know which value to
+# add to the mock server.
+singer.created_at = datetime.fromisoformat("2011-11-04T00:05:23.123456+00:00")
+add_singer(singer)
diff --git a/samples/python/sqlalchemy-sample/test_create_model.py b/samples/python/sqlalchemy-sample/test_create_model.py
new file mode 100644
index 000000000..33644d89a
--- /dev/null
+++ b/samples/python/sqlalchemy-sample/test_create_model.py
@@ -0,0 +1,23 @@
+""" Copyright 2022 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from connect import create_test_engine
+from model import mapper_registry
+
+
+engine = create_test_engine(options="?options=-c spanner.ddl_transaction_mode"
+ "=AutocommitExplicitTransaction")
+mapper_registry.metadata.create_all(engine)
+print("Created data model")
diff --git a/samples/python/sqlalchemy-sample/test_create_random_singers_and_albums.py b/samples/python/sqlalchemy-sample/test_create_random_singers_and_albums.py
new file mode 100644
index 000000000..46ddcb59d
--- /dev/null
+++ b/samples/python/sqlalchemy-sample/test_create_random_singers_and_albums.py
@@ -0,0 +1,19 @@
+""" Copyright 2022 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from sample import create_random_singers_and_albums
+
+
+create_random_singers_and_albums()
diff --git a/samples/python/sqlalchemy-sample/test_create_venue_and_concert_in_transaction.py b/samples/python/sqlalchemy-sample/test_create_venue_and_concert_in_transaction.py
new file mode 100644
index 000000000..ef82c5104
--- /dev/null
+++ b/samples/python/sqlalchemy-sample/test_create_venue_and_concert_in_transaction.py
@@ -0,0 +1,19 @@
+""" Copyright 2022 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from sample import create_venue_and_concert_in_transaction
+
+
+create_venue_and_concert_in_transaction()
diff --git a/samples/python/sqlalchemy-sample/test_delete_album.py b/samples/python/sqlalchemy-sample/test_delete_album.py
new file mode 100644
index 000000000..fc794c1e1
--- /dev/null
+++ b/samples/python/sqlalchemy-sample/test_delete_album.py
@@ -0,0 +1,18 @@
+""" Copyright 2022 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from sample import delete_album
+
+delete_album("123-456-789")
diff --git a/samples/python/sqlalchemy-sample/test_get_album.py b/samples/python/sqlalchemy-sample/test_get_album.py
new file mode 100644
index 000000000..6abe3752e
--- /dev/null
+++ b/samples/python/sqlalchemy-sample/test_get_album.py
@@ -0,0 +1,19 @@
+""" Copyright 2022 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from sample import load_album
+
+
+load_album("987-654-321")
diff --git a/samples/python/sqlalchemy-sample/test_get_album_with_stale_engine.py b/samples/python/sqlalchemy-sample/test_get_album_with_stale_engine.py
new file mode 100644
index 000000000..7952bf07d
--- /dev/null
+++ b/samples/python/sqlalchemy-sample/test_get_album_with_stale_engine.py
@@ -0,0 +1,22 @@
+""" Copyright 2022 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from model import Album
+from sample import stale_read_engine
+from sqlalchemy.orm import Session
+
+
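+# The stale_read_engine executes all queries using stale reads. Stale reads
+# return data from a timestamp in the past and can be more efficient than
+# strong reads, at the expense of potentially returning stale data.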
+with Session(stale_read_engine) as session:
+ album = session.get(Album, "987-654-321")
diff --git a/samples/python/sqlalchemy-sample/test_get_singer.py b/samples/python/sqlalchemy-sample/test_get_singer.py
new file mode 100644
index 000000000..e60e1bc97
--- /dev/null
+++ b/samples/python/sqlalchemy-sample/test_get_singer.py
@@ -0,0 +1,19 @@
+""" Copyright 2022 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from sample import load_singer
+
+
+load_singer("123-456-789")
diff --git a/samples/python/sqlalchemy-sample/test_get_track.py b/samples/python/sqlalchemy-sample/test_get_track.py
new file mode 100644
index 000000000..6c0be66cc
--- /dev/null
+++ b/samples/python/sqlalchemy-sample/test_get_track.py
@@ -0,0 +1,19 @@
+""" Copyright 2022 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from sample import load_track
+
+
+load_track("987-654-321", 1)
diff --git a/samples/python/sqlalchemy-sample/test_metadata_reflect.py b/samples/python/sqlalchemy-sample/test_metadata_reflect.py
new file mode 100644
index 000000000..aad9e0bfe
--- /dev/null
+++ b/samples/python/sqlalchemy-sample/test_metadata_reflect.py
@@ -0,0 +1,22 @@
+""" Copyright 2022 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from connect import create_test_engine
+from model import mapper_registry
+
+
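+# Reflection loads the definitions of the tables that currently exist in the
+# database into the metadata, instead of creating them from the model classes.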
+engine = create_test_engine()
+mapper_registry.metadata.reflect(engine)
+print("Reflected current data model")
diff --git a/samples/python/sqlalchemy-sample/test_print_albums_first_character_of_title_equal_to_first_or_last_name.py b/samples/python/sqlalchemy-sample/test_print_albums_first_character_of_title_equal_to_first_or_last_name.py
new file mode 100644
index 000000000..ad59f9ec5
--- /dev/null
+++ b/samples/python/sqlalchemy-sample/test_print_albums_first_character_of_title_equal_to_first_or_last_name.py
@@ -0,0 +1,20 @@
+""" Copyright 2022 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from sample import \
+ print_albums_first_character_of_title_equal_to_first_or_last_name
+
+
+print_albums_first_character_of_title_equal_to_first_or_last_name()
diff --git a/samples/python/sqlalchemy-sample/test_print_albums_released_before_1980.py b/samples/python/sqlalchemy-sample/test_print_albums_released_before_1980.py
new file mode 100644
index 000000000..f5da8854e
--- /dev/null
+++ b/samples/python/sqlalchemy-sample/test_print_albums_released_before_1980.py
@@ -0,0 +1,19 @@
+""" Copyright 2022 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from sample import print_albums_released_before_1980
+
+
+print_albums_released_before_1980()
diff --git a/samples/python/sqlalchemy-sample/test_print_concerts.py b/samples/python/sqlalchemy-sample/test_print_concerts.py
new file mode 100644
index 000000000..c5c879c96
--- /dev/null
+++ b/samples/python/sqlalchemy-sample/test_print_concerts.py
@@ -0,0 +1,19 @@
+""" Copyright 2022 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from sample import print_concerts
+
+
+print_concerts()
diff --git a/samples/python/sqlalchemy-sample/test_print_singers_and_albums.py b/samples/python/sqlalchemy-sample/test_print_singers_and_albums.py
new file mode 100644
index 000000000..568bc0a5b
--- /dev/null
+++ b/samples/python/sqlalchemy-sample/test_print_singers_and_albums.py
@@ -0,0 +1,19 @@
+""" Copyright 2022 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from sample import print_singers_and_albums
+
+
+print_singers_and_albums()
diff --git a/samples/python/sqlalchemy-sample/test_print_singers_with_limit_and_offset.py b/samples/python/sqlalchemy-sample/test_print_singers_with_limit_and_offset.py
new file mode 100644
index 000000000..4783bad11
--- /dev/null
+++ b/samples/python/sqlalchemy-sample/test_print_singers_with_limit_and_offset.py
@@ -0,0 +1,19 @@
+""" Copyright 2022 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from sample import print_singers_with_limit_and_offset
+
+
+print_singers_with_limit_and_offset()
diff --git a/samples/python/sqlalchemy-sample/test_update_singer.py b/samples/python/sqlalchemy-sample/test_update_singer.py
new file mode 100644
index 000000000..e5d844c79
--- /dev/null
+++ b/samples/python/sqlalchemy-sample/test_update_singer.py
@@ -0,0 +1,18 @@
+""" Copyright 2022 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from sample import update_singer
+
+update_singer("123-456-789", "Newfirstname", "Newlastname")
diff --git a/samples/python/sqlalchemy-sample/util_random_names.py b/samples/python/sqlalchemy-sample/util_random_names.py
new file mode 100644
index 000000000..e5769408f
--- /dev/null
+++ b/samples/python/sqlalchemy-sample/util_random_names.py
@@ -0,0 +1,150 @@
+""" Copyright 2022 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import secrets
+from random import seed, randrange, random
+from datetime import date
+
+seed()
+
+
+"""
+ Helper functions for generating random names and titles.
+"""
+
+def random_first_name():
+ return first_names[randrange(len(first_names))]
+
+
+def random_last_name():
+ return last_names[randrange(len(last_names))]
+
+
+def random_album_title():
+ return "{} {}".format(
+ adjectives[randrange(len(adjectives))], nouns[randrange(len(nouns))])
+
+
+def random_release_date():
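+ # Limit the day to at most 28 so the generated date is valid in any month.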
+ return date.fromisoformat(
+ "{}-{:02d}-{:02d}".format(randrange(1900, 2023),
+ randrange(1, 13),
+ randrange(1, 29)))
+
+
+def random_marketing_budget():
+ return random() * 1000000
+
+
+def random_cover_picture():
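+ # Generate between 1 and 9 random bytes as a placeholder cover picture.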
+ return secrets.token_bytes(randrange(1, 10))
+
+
+first_names = [
+ "Saffron", "Eleanor", "Ann", "Salma", "Kiera", "Mariam", "Georgie", "Eden", "Carmen", "Darcie",
+ "Antony", "Benjamin", "Donald", "Keaton", "Jared", "Simon", "Tanya", "Julian", "Eugene", "Laurence"
+]
+last_names = [
+ "Terry", "Ford", "Mills", "Connolly", "Newton", "Rodgers", "Austin", "Floyd", "Doherty", "Nguyen",
+ "Chavez", "Crossley", "Silva", "George", "Baldwin", "Burns", "Russell", "Ramirez", "Hunter", "Fuller"
+]
+adjectives = [
+ "ultra",
+ "happy",
+ "emotional",
+ "filthy",
+ "charming",
+ "alleged",
+ "talented",
+ "exotic",
+ "lamentable",
+ "lewd",
+ "old-fashioned",
+ "savory",
+ "delicate",
+ "willing",
+ "habitual",
+ "upset",
+ "gainful",
+ "nonchalant",
+ "kind",
+ "unruly"
+]
+nouns = [
+ "improvement",
+ "control",
+ "tennis",
+ "gene",
+ "department",
+ "person",
+ "awareness",
+ "health",
+ "development",
+ "platform",
+ "garbage",
+ "suggestion",
+ "agreement",
+ "knowledge",
+ "introduction",
+ "recommendation",
+ "driver",
+ "elevator",
+ "industry",
+ "extent"
+]
+verbs = [
+ "instruct",
+ "rescue",
+ "disappear",
+ "import",
+ "inhibit",
+ "accommodate",
+ "dress",
+ "describe",
+ "mind",
+ "strip",
+ "crawl",
+ "lower",
+ "influence",
+ "alter",
+ "prove",
+ "race",
+ "label",
+ "exhaust",
+ "reach",
+ "remove"
+]
+adverbs = [
+ "cautiously",
+ "offensively",
+ "immediately",
+ "soon",
+ "judgementally",
+ "actually",
+ "honestly",
+ "slightly",
+ "limply",
+ "rigidly",
+ "fast",
+ "normally",
+ "unnecessarily",
+ "wildly",
+ "unimpressively",
+ "helplessly",
+ "rightfully",
+ "kiddingly",
+ "early",
+ "queasily"
+]
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/ConnectionHandler.java b/src/main/java/com/google/cloud/spanner/pgadapter/ConnectionHandler.java
index 176a2f486..5f3fc7d57 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/ConnectionHandler.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/ConnectionHandler.java
@@ -52,9 +52,11 @@
import com.google.cloud.spanner.pgadapter.wireoutput.ReadyResponse;
import com.google.cloud.spanner.pgadapter.wireoutput.TerminateResponse;
import com.google.cloud.spanner.pgadapter.wireprotocol.BootstrapMessage;
+import com.google.cloud.spanner.pgadapter.wireprotocol.ParseMessage;
import com.google.cloud.spanner.pgadapter.wireprotocol.SSLMessage;
import com.google.cloud.spanner.pgadapter.wireprotocol.WireMessage;
import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Strings;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ImmutableList;
@@ -69,6 +71,7 @@
import java.time.Duration;
import java.util.HashMap;
import java.util.HashSet;
+import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Properties;
@@ -120,7 +123,7 @@ public class ConnectionHandler extends Thread {
private int invalidMessagesCount;
private Connection spannerConnection;
private DatabaseId databaseId;
- private WellKnownClient wellKnownClient;
+ private WellKnownClient wellKnownClient = WellKnownClient.UNSPECIFIED;
private boolean hasDeterminedClientUsingQuery;
private ExtendedQueryProtocolHandler extendedQueryProtocolHandler;
private CopyStatement activeCopyStatement;
@@ -476,12 +479,12 @@ void handleError(PGException exception) throws Exception {
DataOutputStream output = getConnectionMetadata().getOutputStream();
if (this.status == ConnectionStatus.TERMINATED
|| this.status == ConnectionStatus.UNAUTHENTICATED) {
- new ErrorResponse(output, exception).send();
+ new ErrorResponse(this, exception).send();
new TerminateResponse(output).send();
} else if (this.status == ConnectionStatus.COPY_IN) {
- new ErrorResponse(output, exception).send();
+ new ErrorResponse(this, exception).send();
} else {
- new ErrorResponse(output, exception).send();
+ new ErrorResponse(this, exception).send();
new ReadyResponse(output, ReadyResponse.Status.IDLE).send();
}
}
@@ -704,6 +707,14 @@ public WellKnownClient getWellKnownClient() {
public void setWellKnownClient(WellKnownClient wellKnownClient) {
this.wellKnownClient = wellKnownClient;
+ if (this.wellKnownClient != WellKnownClient.UNSPECIFIED) {
+ logger.log(
+ Level.INFO,
+ () ->
+ String.format(
+ "Well-known client %s detected for connection %d.",
+ this.wellKnownClient, getConnectionId()));
+ }
}
/**
@@ -713,23 +724,58 @@ public void setWellKnownClient(WellKnownClient wellKnownClient) {
* executed.
*/
public void maybeDetermineWellKnownClient(Statement statement) {
- if (!this.hasDeterminedClientUsingQuery
- && this.wellKnownClient == WellKnownClient.UNSPECIFIED
- && getServer().getOptions().shouldAutoDetectClient()) {
- this.wellKnownClient = ClientAutoDetector.detectClient(ImmutableList.of(statement));
- if (this.wellKnownClient != WellKnownClient.UNSPECIFIED) {
- logger.log(
- Level.INFO,
- () ->
- String.format(
- "Well-known client %s detected for connection %d.",
- this.wellKnownClient, getConnectionId()));
+ if (!this.hasDeterminedClientUsingQuery) {
+ if (this.wellKnownClient == WellKnownClient.UNSPECIFIED
+ && getServer().getOptions().shouldAutoDetectClient()) {
+ setWellKnownClient(ClientAutoDetector.detectClient(ImmutableList.of(statement)));
}
+ maybeSetApplicationName();
}
// Make sure that we only try to detect the client once.
this.hasDeterminedClientUsingQuery = true;
}
+ /**
+ * This is called by the extended query protocol {@link
+ * com.google.cloud.spanner.pgadapter.wireprotocol.ParseMessage} to give the connection the
+ * opportunity to determine the client that is connected based on the data in the (first)
+ * Parse message.
+ */
+ public void maybeDetermineWellKnownClient(ParseMessage parseMessage) {
+ if (!this.hasDeterminedClientUsingQuery) {
+ if (this.wellKnownClient == WellKnownClient.UNSPECIFIED
+ && getServer().getOptions().shouldAutoDetectClient()) {
+ setWellKnownClient(ClientAutoDetector.detectClient(parseMessage));
+ }
+ maybeSetApplicationName();
+ }
+ // Make sure that we only try to detect the client once.
+ this.hasDeterminedClientUsingQuery = true;
+ }
+
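+ /**
+ * Sets the application_name session setting to the name of the detected client, but only if
+ * the application itself has not already set a value for application_name.
+ */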
+ private void maybeSetApplicationName() {
+ try {
+ if (this.wellKnownClient != WellKnownClient.UNSPECIFIED
+ && getExtendedQueryProtocolHandler() != null
+ && Strings.isNullOrEmpty(
+ getExtendedQueryProtocolHandler()
+ .getBackendConnection()
+ .getSessionState()
+ .get(null, "application_name")
+ .getSetting())) {
+ getExtendedQueryProtocolHandler()
+ .getBackendConnection()
+ .getSessionState()
+ .set(null, "application_name", wellKnownClient.name().toLowerCase(Locale.ENGLISH));
+ getExtendedQueryProtocolHandler().getBackendConnection().getSessionState().commit();
+ }
+ } catch (Throwable ignore) {
+ // Safeguard against the theoretical situation where 'application_name' has been removed
+ // from the list of settings. Just ignore it, as the only consequence is that the
+ // 'application_name' setting is not set.
+ }
+ }
+
/** Status of a {@link ConnectionHandler} */
public enum ConnectionStatus {
UNAUTHENTICATED,
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/ProxyServer.java b/src/main/java/com/google/cloud/spanner/pgadapter/ProxyServer.java
index e620ff36f..f38c4dde4 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/ProxyServer.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/ProxyServer.java
@@ -358,7 +358,7 @@ public String toString() {
return String.format("ProxyServer[port: %d]", getLocalPort());
}
- ConcurrentLinkedQueue<WireMessage> getDebugMessages() {
+ public ConcurrentLinkedQueue<WireMessage> getDebugMessages() {
return debugMessages;
}
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/metadata/OptionsMetadata.java b/src/main/java/com/google/cloud/spanner/pgadapter/metadata/OptionsMetadata.java
index 6abac46f3..f3b595cce 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/metadata/OptionsMetadata.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/metadata/OptionsMetadata.java
@@ -88,7 +88,7 @@ public enum DdlTransactionMode {
}
private static final Logger logger = Logger.getLogger(OptionsMetadata.class.getName());
- private static final String DEFAULT_SERVER_VERSION = "14.1";
+ public static final String DEFAULT_SERVER_VERSION = "14.1";
private static final String DEFAULT_USER_AGENT = "pg-adapter";
private static final String OPTION_SERVER_PORT = "s";
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/parsers/Parser.java b/src/main/java/com/google/cloud/spanner/pgadapter/parsers/Parser.java
index 16141cec1..4bbb1b67d 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/parsers/Parser.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/parsers/Parser.java
@@ -97,9 +97,12 @@ public static Parser<?> create(
case Oid.JSONB:
return new JsonbParser(item, formatCode);
case Oid.UNSPECIFIED:
- return new UnspecifiedParser(item, formatCode);
default:
- throw new IllegalArgumentException("Unsupported parameter type: " + oidType);
+ // Use the UnspecifiedParser for unknown types. This encodes the parameter value as a
+ // string and sends it to Spanner without any type information. This ensures that
+ // clients that, for example, send char instead of varchar as the type code for a
+ // parameter still work.
+ return new UnspecifiedParser(item, formatCode);
}
}
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/session/SessionState.java b/src/main/java/com/google/cloud/spanner/pgadapter/session/SessionState.java
index 4692e590c..ca37be802 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/session/SessionState.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/session/SessionState.java
@@ -374,6 +374,11 @@ public void rollback() {
this.transactionSettings = null;
}
+ /** Returns the PostgreSQL version. */
+ public String getServerVersion() {
+ return getStringSetting(null, "server_version", OptionsMetadata.DEFAULT_SERVER_VERSION);
+ }
+
/**
* Returns whether transaction statements should be ignored and all statements should be executed
* in autocommit mode.
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/BackendConnection.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/BackendConnection.java
index d3ba9aa36..ef1d09f76 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/BackendConnection.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/BackendConnection.java
@@ -30,7 +30,6 @@
import com.google.cloud.spanner.PartitionOptions;
import com.google.cloud.spanner.ReadContext.QueryAnalyzeMode;
import com.google.cloud.spanner.ResultSet;
-import com.google.cloud.spanner.ResultSets;
import com.google.cloud.spanner.Spanner;
import com.google.cloud.spanner.SpannerBatchUpdateException;
import com.google.cloud.spanner.SpannerException;
@@ -59,12 +58,14 @@
import com.google.cloud.spanner.pgadapter.statements.SessionStatementParser.SessionStatement;
import com.google.cloud.spanner.pgadapter.statements.SimpleParser.TableOrIndexName;
import com.google.cloud.spanner.pgadapter.statements.local.LocalStatement;
+import com.google.cloud.spanner.pgadapter.utils.ClientAutoDetector.WellKnownClient;
import com.google.cloud.spanner.pgadapter.utils.CopyDataReceiver;
import com.google.cloud.spanner.pgadapter.utils.MutationWriter;
import com.google.cloud.spanner.pgadapter.wireoutput.ReadyResponse;
import com.google.cloud.spanner.pgadapter.wireoutput.ReadyResponse.Status;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
+import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMap.Builder;
@@ -87,6 +88,7 @@
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.function.Function;
+import java.util.function.Supplier;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
@@ -208,9 +210,10 @@ void execute() {
// block always ends with a ROLLBACK, PGAdapter should skip the entire execution of that
// block.
SessionStatement sessionStatement = getSessionManagementStatement(parsedStatement);
- if (!localStatements.isEmpty() && localStatements.containsKey(statement.getSql())) {
+ if (!localStatements.get().isEmpty()
+ && localStatements.get().containsKey(statement.getSql())) {
result.set(
- Objects.requireNonNull(localStatements.get(statement.getSql()))
+ Objects.requireNonNull(localStatements.get().get(statement.getSql()))
.execute(BackendConnection.this));
} else if (sessionStatement != null) {
result.set(sessionStatement.execute(sessionState));
@@ -253,7 +256,7 @@ void execute() {
// Potentially replace pg_catalog table references with common table expressions.
updatedStatement =
sessionState.isReplacePgCatalogTables()
- ? pgCatalog.replacePgCatalogTables(statement)
+ ? pgCatalog.get().replacePgCatalogTables(statement)
: statement;
updatedStatement = bindStatement(updatedStatement);
result.set(analyzeOrExecute(updatedStatement));
@@ -372,7 +375,7 @@ SessionStatement getSessionManagementStatement(ParsedStatement parsedStatement)
}
}
- private static final int MAX_PARTITIONS =
+ public static final int MAX_PARTITIONS =
Math.max(16, 2 * Runtime.getRuntime().availableProcessors());
private final class CopyOut extends BufferedStatement {
@@ -401,9 +404,14 @@ void execute() {
// No need for the extra complexity of a partitioned query.
result.set(spannerConnection.execute(statement));
} else {
+ // Get the metadata of the query, so we can include that in the result.
+ ResultSet metadataResultSet =
+ spannerConnection.analyzeQuery(statement, QueryAnalyzeMode.PLAN);
result.set(
new PartitionQueryResult(
- batchReadOnlyTransaction.getBatchTransactionId(), partitions));
+ batchReadOnlyTransaction.getBatchTransactionId(),
+ partitions,
+ metadataResultSet));
}
} catch (SpannerException spannerException) {
// The query might not be suitable for partitioning. Just try with a normal query.
@@ -562,8 +570,8 @@ void execute() {
private static final Statement ROLLBACK = Statement.of("ROLLBACK");
private final SessionState sessionState;
- private final PgCatalog pgCatalog;
- private final ImmutableMap<String, LocalStatement> localStatements;
+ private final Supplier<PgCatalog> pgCatalog;
+ private final Supplier<ImmutableMap<String, LocalStatement>> localStatements;
private ConnectionState connectionState = ConnectionState.IDLE;
private TransactionMode transactionMode = TransactionMode.IMPLICIT;
private final String currentSchema = "public";
@@ -576,24 +584,31 @@ void execute() {
BackendConnection(
DatabaseId databaseId,
Connection spannerConnection,
+ Supplier<WellKnownClient> wellKnownClient,
OptionsMetadata optionsMetadata,
- ImmutableList<LocalStatement> localStatements) {
+ Supplier<ImmutableList<LocalStatement>> localStatements) {
this.sessionState = new SessionState(optionsMetadata);
- this.pgCatalog = new PgCatalog(this.sessionState);
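+ // Create the PgCatalog and the map of local statements lazily, as the well-known client
+ // may not have been detected yet when the BackendConnection is created.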
+ this.pgCatalog =
+ Suppliers.memoize(
+ () -> new PgCatalog(BackendConnection.this.sessionState, wellKnownClient.get()));
this.spannerConnection = spannerConnection;
this.databaseId = databaseId;
this.ddlExecutor = new DdlExecutor(databaseId, this);
- if (localStatements.isEmpty()) {
- this.localStatements = EMPTY_LOCAL_STATEMENTS;
- } else {
- Builder builder = ImmutableMap.builder();
- for (LocalStatement localStatement : localStatements) {
- for (String sql : localStatement.getSql()) {
- builder.put(new SimpleImmutableEntry<>(sql, localStatement));
- }
- }
- this.localStatements = builder.build();
- }
+ this.localStatements =
+ Suppliers.memoize(
+ () -> {
+ if (localStatements.get().isEmpty()) {
+ return EMPTY_LOCAL_STATEMENTS;
+ } else {
+ Builder<String, LocalStatement> builder = ImmutableMap.builder();
+ for (LocalStatement localStatement : localStatements.get()) {
+ for (String sql : localStatement.getSql()) {
+ builder.put(new SimpleImmutableEntry<>(sql, localStatement));
+ }
+ }
+ return builder.build();
+ }
+ });
}
/** Returns the current connection state. */
@@ -1262,10 +1277,15 @@ public Long getUpdateCount() {
public static final class PartitionQueryResult implements StatementResult {
private final BatchTransactionId batchTransactionId;
private final List<Partition> partitions;
+ private final ResultSet metadataResultSet;
- public PartitionQueryResult(BatchTransactionId batchTransactionId, List<Partition> partitions) {
+ public PartitionQueryResult(
+ BatchTransactionId batchTransactionId,
+ List<Partition> partitions,
+ ResultSet metadataResultSet) {
this.batchTransactionId = batchTransactionId;
this.partitions = partitions;
+ this.metadataResultSet = metadataResultSet;
}
public BatchTransactionId getBatchTransactionId() {
@@ -1276,6 +1296,10 @@ public List<Partition> getPartitions() {
return partitions;
}
+ public ResultSet getMetadataResultSet() {
+ return metadataResultSet;
+ }
+
@Override
public ResultType getResultType() {
return ResultType.RESULT_SET;
@@ -1288,7 +1312,7 @@ public ClientSideStatementType getClientSideStatementType() {
@Override
public ResultSet getResultSet() {
- return ResultSets.forRows(
+ return ClientSideResultSet.forRows(
Type.struct(StructField.of("partition", Type.bytes())),
partitions.stream()
.map(
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/ClientSideResultSet.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/ClientSideResultSet.java
new file mode 100644
index 000000000..91c92bcb7
--- /dev/null
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/ClientSideResultSet.java
@@ -0,0 +1,54 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.cloud.spanner.pgadapter.statements;
+
+import com.google.api.core.InternalApi;
+import com.google.cloud.spanner.ForwardingResultSet;
+import com.google.cloud.spanner.ResultSet;
+import com.google.cloud.spanner.ResultSets;
+import com.google.cloud.spanner.Struct;
+import com.google.cloud.spanner.Type;
+import com.google.spanner.v1.ResultSetMetadata;
+import com.google.spanner.v1.ResultSetStats;
+
+/** Wrapper class for query results that are handled directly in PGAdapter. */
+@InternalApi
+public class ClientSideResultSet extends ForwardingResultSet {
+ public static ResultSet forRows(Type type, Iterable rows) {
+ return new ClientSideResultSet(ResultSets.forRows(type, rows), type);
+ }
+
+ private final Type type;
+
+ private ClientSideResultSet(ResultSet delegate, Type type) {
+ super(delegate);
+ this.type = type;
+ }
+
+ @Override
+ public Type getType() {
+ return type;
+ }
+
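+ // A ClientSideResultSet is generated by PGAdapter itself and is not backed by a query on
+ // Cloud Spanner. It therefore returns empty statistics and metadata.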
+ @Override
+ public ResultSetStats getStats() {
+ return ResultSetStats.getDefaultInstance();
+ }
+
+ @Override
+ public ResultSetMetadata getMetadata() {
+ return ResultSetMetadata.getDefaultInstance();
+ }
+}
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyStatement.java
index 7c49c4bd1..370cd620c 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyStatement.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyStatement.java
@@ -24,6 +24,7 @@
import com.google.cloud.spanner.connection.AbstractStatementParser.StatementType;
import com.google.cloud.spanner.connection.AutocommitDmlMode;
import com.google.cloud.spanner.pgadapter.ConnectionHandler;
+import com.google.cloud.spanner.pgadapter.ProxyServer.DataFormat;
import com.google.cloud.spanner.pgadapter.error.PGException;
import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory;
import com.google.cloud.spanner.pgadapter.error.SQLState;
@@ -80,7 +81,17 @@ public static IntermediatePortalStatement create(
public enum Format {
CSV,
TEXT,
- BINARY,
+ BINARY {
+ @Override
+ public DataFormat getDataFormat() {
+ return DataFormat.POSTGRESQL_BINARY;
+ }
+ };
+
+ /** Returns the (default) data format that should be used for this copy format. */
+ public DataFormat getDataFormat() {
+ return DataFormat.POSTGRESQL_TEXT;
+ }
}
private static final String COLUMN_NAME = "column_name";
@@ -230,8 +241,8 @@ public MutationWriter getMutationWriter() {
}
/** @return 0 for text/csv formatting and 1 for binary */
- public int getFormatCode() {
- return (parsedCopyStatement.format == Format.BINARY) ? 1 : 0;
+ public byte getFormatCode() {
+ return (parsedCopyStatement.format == Format.BINARY) ? (byte) 1 : (byte) 0;
}
private void verifyCopyColumns() {
@@ -552,7 +563,7 @@ static ParsedCopyStatement parse(String sql) {
}
ParsedCopyStatement.Builder builder = new ParsedCopyStatement.Builder();
if (parser.eatToken("(")) {
- builder.query = parser.parseExpressionUntilKeyword(ImmutableList.of(), true, true);
+ builder.query = parser.parseExpressionUntilKeyword(ImmutableList.of(), true, true, false);
if (!parser.eatToken(")")) {
throw PGExceptionFactory.newPGException(
"missing closing parentheses after query", SQLState.SyntaxError);
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyToStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyToStatement.java
index 4c4c0f959..de2c852e1 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyToStatement.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyToStatement.java
@@ -23,7 +23,6 @@
import com.google.cloud.spanner.connection.Connection;
import com.google.cloud.spanner.connection.StatementResult;
import com.google.cloud.spanner.pgadapter.ConnectionHandler;
-import com.google.cloud.spanner.pgadapter.ProxyServer.DataFormat;
import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory;
import com.google.cloud.spanner.pgadapter.error.SQLState;
import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata;
@@ -42,6 +41,7 @@
import com.google.common.util.concurrent.Futures;
import java.util.List;
import java.util.concurrent.Future;
+import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.QuoteMode;
@@ -57,7 +57,8 @@ public class CopyToStatement extends IntermediatePortalStatement {
new byte[] {'P', 'G', 'C', 'O', 'P', 'Y', '\n', -1, '\r', '\n', '\0'};
private final ParsedCopyStatement parsedCopyStatement;
- private final CSVFormat csvFormat;
+ private CSVFormat csvFormat;
+ private final AtomicBoolean hasReturnedData = new AtomicBoolean(false);
public CopyToStatement(
ConnectionHandler connectionHandler,
@@ -80,24 +81,28 @@ public CopyToStatement(
if (parsedCopyStatement.format == CopyStatement.Format.BINARY) {
this.csvFormat = null;
} else {
+ CSVFormat baseFormat =
+ parsedCopyStatement.format == Format.TEXT
+ ? CSVFormat.POSTGRESQL_TEXT
+ : CSVFormat.POSTGRESQL_CSV;
CSVFormat.Builder formatBuilder =
- CSVFormat.Builder.create(CSVFormat.POSTGRESQL_TEXT)
+ CSVFormat.Builder.create(baseFormat)
.setNullString(
parsedCopyStatement.nullString == null
- ? CSVFormat.POSTGRESQL_TEXT.getNullString()
+ ? baseFormat.getNullString()
: parsedCopyStatement.nullString)
.setRecordSeparator('\n')
.setDelimiter(
parsedCopyStatement.delimiter == null
- ? CSVFormat.POSTGRESQL_TEXT.getDelimiterString().charAt(0)
+ ? baseFormat.getDelimiterString().charAt(0)
: parsedCopyStatement.delimiter)
.setQuote(
parsedCopyStatement.quote == null
- ? CSVFormat.POSTGRESQL_TEXT.getQuoteCharacter()
+ ? baseFormat.getQuoteCharacter()
: parsedCopyStatement.quote)
.setEscape(
parsedCopyStatement.escape == null
- ? CSVFormat.POSTGRESQL_TEXT.getEscapeCharacter()
+ ? baseFormat.getEscapeCharacter()
: parsedCopyStatement.escape);
if (parsedCopyStatement.format == Format.TEXT) {
formatBuilder.setQuoteMode(QuoteMode.NONE);
@@ -220,22 +225,25 @@ public IntermediatePortalStatement createPortal(
@Override
public WireOutput[] createResultPrefix(ResultSet resultSet) {
- return this.parsedCopyStatement.format == CopyStatement.Format.BINARY
- ? new WireOutput[] {
- new CopyOutResponse(
- this.outputStream,
- resultSet.getColumnCount(),
- DataFormat.POSTGRESQL_BINARY.getCode()),
- CopyDataResponse.createBinaryHeader(this.outputStream)
- }
- : new WireOutput[] {
- new CopyOutResponse(
- this.outputStream, resultSet.getColumnCount(), DataFormat.POSTGRESQL_TEXT.getCode())
- };
+ return new WireOutput[] {
+ new CopyOutResponse(
+ this.outputStream,
+ resultSet.getColumnCount(),
+ this.parsedCopyStatement.format.getDataFormat().getCode())
+ };
}
@Override
public CopyDataResponse createDataRowResponse(Converter converter) {
+ // Keep track of whether this COPY statement has returned at least one row. This is necessary to
+ // know whether we need to include the header in the current row and/or in the trailer.
+ // PostgreSQL includes the header in either the first data row or in the trailer if there are no
+ // rows. This is not specifically mentioned in the protocol description, but some clients assume
+ // this behavior. See
+ // https://github.com/npgsql/npgsql/blob/7f97dbad28c71b2202dd7bcccd05fc42a7de23c8/src/Npgsql/NpgsqlBinaryExporter.cs#L156
+ if (parsedCopyStatement.format == Format.BINARY && !hasReturnedData.getAndSet(true)) {
+ converter = converter.includeBinaryCopyHeader();
+ }
return parsedCopyStatement.format == CopyStatement.Format.BINARY
? createBinaryDataResponse(converter)
: createDataResponse(converter.getResultSet());
@@ -245,7 +253,7 @@ public CopyDataResponse createDataRowResponse(Converter converter) {
public WireOutput[] createResultSuffix() {
return this.parsedCopyStatement.format == Format.BINARY
? new WireOutput[] {
- CopyDataResponse.createBinaryTrailer(this.outputStream),
+ CopyDataResponse.createBinaryTrailer(this.outputStream, !hasReturnedData.get()),
new CopyDoneResponse(this.outputStream)
}
: new WireOutput[] {new CopyDoneResponse(this.outputStream)};
@@ -268,6 +276,10 @@ CopyDataResponse createDataResponse(ResultSet resultSet) {
}
}
String row = csvFormat.format((Object[]) data);
+ // Only include the header with the first row.
+ if (!csvFormat.getSkipHeaderRecord()) {
+ csvFormat = csvFormat.builder().setSkipHeaderRecord(true).build();
+ }
return new CopyDataResponse(this.outputStream, row, csvFormat.getRecordSeparator().charAt(0));
}
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/ExtendedQueryProtocolHandler.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/ExtendedQueryProtocolHandler.java
index ae61125d0..9d3c5bf7c 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/ExtendedQueryProtocolHandler.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/ExtendedQueryProtocolHandler.java
@@ -40,8 +40,9 @@ public ExtendedQueryProtocolHandler(ConnectionHandler connectionHandler) {
new BackendConnection(
connectionHandler.getDatabaseId(),
connectionHandler.getSpannerConnection(),
+ connectionHandler::getWellKnownClient,
connectionHandler.getServer().getOptions(),
- connectionHandler.getWellKnownClient().getLocalStatements(connectionHandler));
+ () -> connectionHandler.getWellKnownClient().getLocalStatements(connectionHandler));
}
/** Constructor only intended for testing. */
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/PgCatalog.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/PgCatalog.java
index a1d4aac4f..d806f576c 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/PgCatalog.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/PgCatalog.java
@@ -20,6 +20,9 @@
import com.google.cloud.spanner.Value;
import com.google.cloud.spanner.pgadapter.session.SessionState;
import com.google.cloud.spanner.pgadapter.statements.SimpleParser.TableOrIndexName;
+import com.google.cloud.spanner.pgadapter.utils.ClientAutoDetector.WellKnownClient;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
@@ -27,11 +30,13 @@
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
+import java.util.function.Supplier;
import java.util.regex.Pattern;
+import javax.annotation.Nonnull;
@InternalApi
public class PgCatalog {
- private static final ImmutableMap<TableOrIndexName, TableOrIndexName> TABLE_REPLACEMENTS =
+ private static final ImmutableMap<TableOrIndexName, TableOrIndexName> DEFAULT_TABLE_REPLACEMENTS =
ImmutableMap.<TableOrIndexName, TableOrIndexName>builder()
.put(
new TableOrIndexName("pg_catalog", "pg_namespace"),
@@ -57,30 +62,66 @@ public class PgCatalog {
.put(new TableOrIndexName(null, "pg_settings"), new TableOrIndexName(null, "pg_settings"))
.build();
- private static final ImmutableMap<Pattern, String> FUNCTION_REPLACEMENTS =
+ private static final ImmutableMap<Pattern, Supplier<String>> DEFAULT_FUNCTION_REPLACEMENTS =
ImmutableMap.of(
- Pattern.compile("pg_catalog.pg_table_is_visible\\(.+\\)"), "true",
- Pattern.compile("pg_table_is_visible\\(.+\\)"), "true",
- Pattern.compile("ANY\\(current_schemas\\(true\\)\\)"), "'public'");
+ Pattern.compile("pg_catalog.pg_table_is_visible\\(.+\\)"), Suppliers.ofInstance("true"),
+ Pattern.compile("pg_table_is_visible\\(.+\\)"), Suppliers.ofInstance("true"),
+ Pattern.compile("ANY\\(current_schemas\\(true\\)\\)"), Suppliers.ofInstance("'public'"));
- private final Map<TableOrIndexName, PgCatalogTable> pgCatalogTables =
+ private final ImmutableSet<String> checkPrefixes;
+
+ private final ImmutableMap<TableOrIndexName, TableOrIndexName> tableReplacements;
+ private final ImmutableMap<TableOrIndexName, PgCatalogTable> pgCatalogTables;
+
+ private final ImmutableMap<Pattern, Supplier<String>> functionReplacements;
+
+ private static final Map<TableOrIndexName, PgCatalogTable> DEFAULT_PG_CATALOG_TABLES =
ImmutableMap.of(
new TableOrIndexName(null, "pg_namespace"), new PgNamespace(),
new TableOrIndexName(null, "pg_class"), new PgClass(),
new TableOrIndexName(null, "pg_proc"), new PgProc(),
new TableOrIndexName(null, "pg_range"), new PgRange(),
- new TableOrIndexName(null, "pg_type"), new PgType(),
- new TableOrIndexName(null, "pg_settings"), new PgSettings());
+ new TableOrIndexName(null, "pg_type"), new PgType());
private final SessionState sessionState;
- public PgCatalog(SessionState sessionState) {
- this.sessionState = sessionState;
+ public PgCatalog(@Nonnull SessionState sessionState, @Nonnull WellKnownClient wellKnownClient) {
+ this.sessionState = Preconditions.checkNotNull(sessionState);
+ this.checkPrefixes = wellKnownClient.getPgCatalogCheckPrefixes();
+ ImmutableMap.Builder<TableOrIndexName, TableOrIndexName> builder =
+ ImmutableMap.<TableOrIndexName, TableOrIndexName>builder()
+ .putAll(DEFAULT_TABLE_REPLACEMENTS);
+ wellKnownClient
+ .getTableReplacements()
+ .forEach((k, v) -> builder.put(TableOrIndexName.parse(k), TableOrIndexName.parse(v)));
+ this.tableReplacements = builder.build();
+
+ ImmutableMap.Builder<TableOrIndexName, PgCatalogTable> pgCatalogTablesBuilder =
+ ImmutableMap.<TableOrIndexName, PgCatalogTable>builder()
+ .putAll(DEFAULT_PG_CATALOG_TABLES)
+ .put(new TableOrIndexName(null, "pg_settings"), new PgSettings());
+ wellKnownClient
+ .getPgCatalogTables()
+ .forEach((k, v) -> pgCatalogTablesBuilder.put(TableOrIndexName.parse(k), v));
+ this.pgCatalogTables = pgCatalogTablesBuilder.build();
+
+ this.functionReplacements =
+ ImmutableMap.<Pattern, Supplier<String>>builder()
+ .putAll(DEFAULT_FUNCTION_REPLACEMENTS)
+ .put(
+ Pattern.compile("version\\(\\)"), () -> "'" + sessionState.getServerVersion() + "'")
+ .putAll(wellKnownClient.getFunctionReplacements())
+ .build();
}
/** Replace supported pg_catalog tables with Common Table Expressions. */
public Statement replacePgCatalogTables(Statement statement) {
+ // Only replace tables if the statement contains at least one of the known prefixes.
+ if (checkPrefixes.stream().noneMatch(prefix -> statement.getSql().contains(prefix))) {
+ return statement;
+ }
+
Tuple, Statement> replacedTablesStatement =
- new TableParser(statement).detectAndReplaceTables(TABLE_REPLACEMENTS);
+ new TableParser(statement).detectAndReplaceTables(tableReplacements);
if (replacedTablesStatement.x().isEmpty()) {
return replacedTablesStatement.y();
}
@@ -95,16 +136,19 @@ public Statement replacePgCatalogTables(Statement statement) {
return addCommonTableExpressions(replacedTablesStatement.y(), cteBuilder.build());
}
- static String replaceKnownUnsupportedFunctions(Statement statement) {
+ String replaceKnownUnsupportedFunctions(Statement statement) {
String sql = statement.getSql();
- for (Entry<Pattern, String> functionReplacement : FUNCTION_REPLACEMENTS.entrySet()) {
- sql = functionReplacement.getKey().matcher(sql).replaceAll(functionReplacement.getValue());
+ for (Entry<Pattern, Supplier<String>> functionReplacement : functionReplacements.entrySet()) {
+ sql =
+ functionReplacement
+ .getKey()
+ .matcher(sql)
+ .replaceAll(functionReplacement.getValue().get());
}
return sql;
}
- static Statement addCommonTableExpressions(
- Statement statement, ImmutableList<String> tableExpressions) {
+ Statement addCommonTableExpressions(Statement statement, ImmutableList<String> tableExpressions) {
String sql = replaceKnownUnsupportedFunctions(statement);
SimpleParser parser = new SimpleParser(sql);
boolean hadCommonTableExpressions = parser.eatKeyword("with");
@@ -153,7 +197,8 @@ PgCatalogTable getPgCatalogTable(TableOrIndexName tableOrIndexName) {
return null;
}
- private interface PgCatalogTable {
+ @InternalApi
+ public interface PgCatalogTable {
String getTableExpression();
default ImmutableSet<TableOrIndexName> getDependencies() {
@@ -378,4 +423,37 @@ public String getTableExpression() {
return PG_RANGE_CTE;
}
}
+
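+ /**
+ * A pg_attribute replacement that always returns an empty result. This can be used for
+ * clients that query pg_attribute, which has no equivalent in Cloud Spanner.
+ */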
+ @InternalApi
+ public static class EmptyPgAttribute implements PgCatalogTable {
+ private static final String PG_ATTRIBUTE_CTE =
+ "pg_attribute as (\n"
+ + "select * from ("
+ + "select 0::bigint as attrelid, '' as attname, 0::bigint as atttypid, 0::bigint as attstattarget, "
+ + "0::bigint as attlen, 0::bigint as attnum, 0::bigint as attndims, -1::bigint as attcacheoff, "
+ + "0::bigint as atttypmod, true as attbyval, '' as attalign, '' as attstorage, '' as attcompression, "
+ + "false as attnotnull, true as atthasdef, false as atthasmissing, '' as attidentity, '' as attgenerated, "
+ + "false as attisdropped, true as attislocal, 0 as attinhcount, 0 as attcollation, '{}'::bigint[] as attacl, "
+ + "'{}'::text[] as attoptions, '{}'::text[] as attfdwoptions, null as attmissingval\n"
+ + ") a where false)";
+
+ @Override
+ public String getTableExpression() {
+ return PG_ATTRIBUTE_CTE;
+ }
+ }
+
+ @InternalApi
+ public static class EmptyPgEnum implements PgCatalogTable {
+ private static final String PG_ENUM_CTE =
+ "pg_enum as (\n"
+ + "select * from ("
+ + "select 0::bigint as oid, 0::bigint as enumtypid, 0.0::float8 as enumsortorder, ''::varchar as enumlabel\n"
+ + ") e where false)";
+
+ @Override
+ public String getTableExpression() {
+ return PG_ENUM_CTE;
+ }
+ }
}
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/SessionStatementParser.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/SessionStatementParser.java
index 5b0608684..2bfc7a0a3 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/SessionStatementParser.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/SessionStatementParser.java
@@ -17,7 +17,6 @@
import com.google.api.client.util.Strings;
import com.google.api.core.InternalApi;
import com.google.cloud.spanner.ErrorCode;
-import com.google.cloud.spanner.ResultSets;
import com.google.cloud.spanner.SpannerExceptionFactory;
import com.google.cloud.spanner.Struct;
import com.google.cloud.spanner.Type;
@@ -186,7 +185,7 @@ static ShowStatement createShowAll() {
public StatementResult execute(SessionState sessionState) {
if (name != null) {
return new QueryResult(
- ResultSets.forRows(
+ ClientSideResultSet.forRows(
Type.struct(StructField.of(getKey(), Type.string())),
ImmutableList.of(
Struct.newBuilder()
@@ -195,7 +194,7 @@ public StatementResult execute(SessionState sessionState) {
.build())));
}
return new QueryResult(
- ResultSets.forRows(
+ ClientSideResultSet.forRows(
Type.struct(
StructField.of("name", Type.string()),
StructField.of("setting", Type.string()),
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/SimpleParser.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/SimpleParser.java
index ccbd8eb9c..3e5f389d2 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/SimpleParser.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/SimpleParser.java
@@ -49,6 +49,19 @@ static class TableOrIndexName {
/** Name is the actual object name. */
final String name;
+ /**
+ * Parses an unquoted, optionally qualified identifier. Use this to quickly parse strings like
+ * 'foo' and 'foo.bar' as identifiers.
+ */
+ static TableOrIndexName parse(String qualifiedName) {
+ SimpleParser parser = new SimpleParser(qualifiedName);
+ TableOrIndexName result = parser.readTableOrIndexName();
+ if (result == null || parser.hasMoreTokens()) {
+ throw new IllegalArgumentException("Invalid identifier: " + qualifiedName);
+ }
+ return result;
+ }
+
static TableOrIndexName of(String name) {
return new TableOrIndexName(name);
}
@@ -354,6 +367,15 @@ String parseExpressionUntilKeyword(
ImmutableList keywords,
boolean sameParensLevelAsStart,
boolean stopAtEndOfExpression) {
+ return parseExpressionUntilKeyword(
+ keywords, sameParensLevelAsStart, stopAtEndOfExpression, true);
+ }
+
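+ /**
+ * Parses and returns the expression up to the first occurrence of one of the given keywords.
+ * If stopAtComma is true (and stopAtEndOfExpression is also true), parsing also stops at the
+ * first comma at the top parentheses level.
+ */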
+ String parseExpressionUntilKeyword(
+ ImmutableList keywords,
+ boolean sameParensLevelAsStart,
+ boolean stopAtEndOfExpression,
+ boolean stopAtComma) {
skipWhitespaces();
int start = pos;
boolean valid;
@@ -366,7 +388,7 @@ String parseExpressionUntilKeyword(
if (stopAtEndOfExpression && parens < 0) {
break;
}
- } else if (stopAtEndOfExpression && parens == 0 && sql.charAt(pos) == ',') {
+ } else if (stopAtEndOfExpression && parens == 0 && stopAtComma && sql.charAt(pos) == ',') {
break;
}
if ((!sameParensLevelAsStart || parens == 0)
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/DjangoGetTableNamesStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/DjangoGetTableNamesStatement.java
index f92005013..127cabde2 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/DjangoGetTableNamesStatement.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/DjangoGetTableNamesStatement.java
@@ -16,12 +16,12 @@
import com.google.api.core.InternalApi;
import com.google.cloud.spanner.ResultSet;
-import com.google.cloud.spanner.ResultSets;
import com.google.cloud.spanner.Type;
import com.google.cloud.spanner.Type.StructField;
import com.google.cloud.spanner.connection.StatementResult;
import com.google.cloud.spanner.pgadapter.statements.BackendConnection;
import com.google.cloud.spanner.pgadapter.statements.BackendConnection.QueryResult;
+import com.google.cloud.spanner.pgadapter.statements.ClientSideResultSet;
import com.google.common.collect.ImmutableList;
/*
@@ -60,7 +60,7 @@ public String[] getSql() {
@Override
public StatementResult execute(BackendConnection backendConnection) {
ResultSet resultSet =
- ResultSets.forRows(
+ ClientSideResultSet.forRows(
Type.struct(
StructField.of("relname", Type.string()), StructField.of("case", Type.string())),
ImmutableList.of());
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/ListDatabasesStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/ListDatabasesStatement.java
index d9f189602..ed4a22147 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/ListDatabasesStatement.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/ListDatabasesStatement.java
@@ -20,7 +20,6 @@
import com.google.cloud.spanner.Dialect;
import com.google.cloud.spanner.InstanceId;
import com.google.cloud.spanner.ResultSet;
-import com.google.cloud.spanner.ResultSets;
import com.google.cloud.spanner.Spanner;
import com.google.cloud.spanner.Struct;
import com.google.cloud.spanner.Type;
@@ -31,6 +30,7 @@
import com.google.cloud.spanner.pgadapter.ConnectionHandler;
import com.google.cloud.spanner.pgadapter.statements.BackendConnection;
import com.google.cloud.spanner.pgadapter.statements.BackendConnection.QueryResult;
+import com.google.cloud.spanner.pgadapter.statements.ClientSideResultSet;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.Comparator;
@@ -74,7 +74,7 @@ public StatementResult execute(BackendConnection backendConnection) {
InstanceId defaultInstanceId =
connectionHandler.getServer().getOptions().getDefaultInstanceId();
ResultSet resultSet =
- ResultSets.forRows(
+ ClientSideResultSet.forRows(
Type.struct(
StructField.of("Name", Type.string()),
StructField.of("Owner", Type.string()),
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectCurrentCatalogStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectCurrentCatalogStatement.java
index aabbfdfe9..e0fcfea9a 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectCurrentCatalogStatement.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectCurrentCatalogStatement.java
@@ -16,13 +16,13 @@
import com.google.api.core.InternalApi;
import com.google.cloud.spanner.ResultSet;
-import com.google.cloud.spanner.ResultSets;
import com.google.cloud.spanner.Struct;
import com.google.cloud.spanner.Type;
import com.google.cloud.spanner.Type.StructField;
import com.google.cloud.spanner.connection.StatementResult;
import com.google.cloud.spanner.pgadapter.statements.BackendConnection;
import com.google.cloud.spanner.pgadapter.statements.BackendConnection.QueryResult;
+import com.google.cloud.spanner.pgadapter.statements.ClientSideResultSet;
import com.google.common.collect.ImmutableList;
@InternalApi
@@ -52,7 +52,7 @@ public String[] getSql() {
@Override
public StatementResult execute(BackendConnection backendConnection) {
ResultSet resultSet =
- ResultSets.forRows(
+ ClientSideResultSet.forRows(
Type.struct(StructField.of("current_catalog", Type.string())),
ImmutableList.of(
Struct.newBuilder()
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectCurrentDatabaseStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectCurrentDatabaseStatement.java
index 2eefe9c3e..1e173b971 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectCurrentDatabaseStatement.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectCurrentDatabaseStatement.java
@@ -16,13 +16,13 @@
import com.google.api.core.InternalApi;
import com.google.cloud.spanner.ResultSet;
-import com.google.cloud.spanner.ResultSets;
import com.google.cloud.spanner.Struct;
import com.google.cloud.spanner.Type;
import com.google.cloud.spanner.Type.StructField;
import com.google.cloud.spanner.connection.StatementResult;
import com.google.cloud.spanner.pgadapter.statements.BackendConnection;
import com.google.cloud.spanner.pgadapter.statements.BackendConnection.QueryResult;
+import com.google.cloud.spanner.pgadapter.statements.ClientSideResultSet;
import com.google.common.collect.ImmutableList;
@InternalApi
@@ -53,7 +53,7 @@ public String[] getSql() {
@Override
public StatementResult execute(BackendConnection backendConnection) {
ResultSet resultSet =
- ResultSets.forRows(
+ ClientSideResultSet.forRows(
Type.struct(StructField.of("current_database", Type.string())),
ImmutableList.of(
Struct.newBuilder()
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectCurrentSchemaStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectCurrentSchemaStatement.java
index bba360e03..6afd994ef 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectCurrentSchemaStatement.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectCurrentSchemaStatement.java
@@ -16,13 +16,13 @@
import com.google.api.core.InternalApi;
import com.google.cloud.spanner.ResultSet;
-import com.google.cloud.spanner.ResultSets;
import com.google.cloud.spanner.Struct;
import com.google.cloud.spanner.Type;
import com.google.cloud.spanner.Type.StructField;
import com.google.cloud.spanner.connection.StatementResult;
import com.google.cloud.spanner.pgadapter.statements.BackendConnection;
import com.google.cloud.spanner.pgadapter.statements.BackendConnection.QueryResult;
+import com.google.cloud.spanner.pgadapter.statements.ClientSideResultSet;
import com.google.common.collect.ImmutableList;
@InternalApi
@@ -50,7 +50,7 @@ public String[] getSql() {
@Override
public StatementResult execute(BackendConnection backendConnection) {
ResultSet resultSet =
- ResultSets.forRows(
+ ClientSideResultSet.forRows(
Type.struct(StructField.of("current_schema", Type.string())),
ImmutableList.of(
Struct.newBuilder()
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectVersionStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectVersionStatement.java
index 59980448b..3fbfe1fe3 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectVersionStatement.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectVersionStatement.java
@@ -16,13 +16,13 @@
import com.google.api.core.InternalApi;
import com.google.cloud.spanner.ResultSet;
-import com.google.cloud.spanner.ResultSets;
import com.google.cloud.spanner.Struct;
import com.google.cloud.spanner.Type;
import com.google.cloud.spanner.Type.StructField;
import com.google.cloud.spanner.connection.StatementResult;
import com.google.cloud.spanner.pgadapter.statements.BackendConnection;
import com.google.cloud.spanner.pgadapter.statements.BackendConnection.QueryResult;
+import com.google.cloud.spanner.pgadapter.statements.ClientSideResultSet;
import com.google.common.collect.ImmutableList;
@InternalApi
@@ -64,7 +64,7 @@ public String[] getSql() {
@Override
public StatementResult execute(BackendConnection backendConnection) {
ResultSet resultSet =
- ResultSets.forRows(
+ ClientSideResultSet.forRows(
Type.struct(StructField.of("version", Type.string())),
ImmutableList.of(
Struct.newBuilder()
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/utils/BinaryCopyParser.java b/src/main/java/com/google/cloud/spanner/pgadapter/utils/BinaryCopyParser.java
index d39e1bd71..d90bb8955 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/utils/BinaryCopyParser.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/utils/BinaryCopyParser.java
@@ -230,11 +230,24 @@ public int numColumns() {
return fields.length;
}
+ @Override
+ public boolean isEndRecord() {
+ return false;
+ }
+
@Override
public boolean hasColumnNames() {
return false;
}
+ @Override
+ public boolean isNull(int columnIndex) {
+ Preconditions.checkArgument(
+ columnIndex >= 0 && columnIndex < numColumns(),
+ "columnIndex must be >= 0 && < numColumns");
+ return fields[columnIndex].data == null;
+ }
+
@Override
public Value getValue(Type type, String columnName) {
// The binary copy format does not include any column name headers or any type information.
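The new isNull implementation works because the parser stores a null data array for any field whose length was -1 on the wire: in the PostgreSQL binary COPY format, a field length of -1 means NULL and no data bytes follow. A minimal standalone sketch of that convention (not PGAdapter's actual parser):

```java
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;

public class BinaryFieldSketch {
  /** Reads one field; a length of -1 marks a NULL value (no data bytes follow). */
  static byte[] readField(DataInputStream in) throws IOException {
    int length = in.readInt();
    if (length == -1) {
      return null; // NULL column value
    }
    byte[] data = new byte[length];
    in.readFully(data);
    return data;
  }

  public static void main(String[] args) throws IOException {
    // Two fields on the wire: a NULL (length -1) and a 2-byte value.
    byte[] wire = {
      (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, // length -1 => NULL
      0, 0, 0, 2, 'o', 'k' // length 2, then the data bytes
    };
    DataInputStream in = new DataInputStream(new ByteArrayInputStream(wire));
    System.out.println(readField(in) == null); // true
    System.out.println(new String(readField(in))); // ok
  }
}
```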
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/utils/ClientAutoDetector.java b/src/main/java/com/google/cloud/spanner/pgadapter/utils/ClientAutoDetector.java
index 44db4631c..5f9985a39 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/utils/ClientAutoDetector.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/utils/ClientAutoDetector.java
@@ -17,6 +17,11 @@
import com.google.api.core.InternalApi;
import com.google.cloud.spanner.Statement;
import com.google.cloud.spanner.pgadapter.ConnectionHandler;
+import com.google.cloud.spanner.pgadapter.error.PGException;
+import com.google.cloud.spanner.pgadapter.error.SQLState;
+import com.google.cloud.spanner.pgadapter.statements.PgCatalog.EmptyPgAttribute;
+import com.google.cloud.spanner.pgadapter.statements.PgCatalog.EmptyPgEnum;
+import com.google.cloud.spanner.pgadapter.statements.PgCatalog.PgCatalogTable;
import com.google.cloud.spanner.pgadapter.statements.local.DjangoGetTableNamesStatement;
import com.google.cloud.spanner.pgadapter.statements.local.ListDatabasesStatement;
import com.google.cloud.spanner.pgadapter.statements.local.LocalStatement;
@@ -24,10 +29,19 @@
import com.google.cloud.spanner.pgadapter.statements.local.SelectCurrentDatabaseStatement;
import com.google.cloud.spanner.pgadapter.statements.local.SelectCurrentSchemaStatement;
import com.google.cloud.spanner.pgadapter.statements.local.SelectVersionStatement;
+import com.google.cloud.spanner.pgadapter.wireoutput.NoticeResponse;
+import com.google.cloud.spanner.pgadapter.wireoutput.NoticeResponse.NoticeSeverity;
+import com.google.cloud.spanner.pgadapter.wireprotocol.ParseMessage;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import java.time.Duration;
import java.util.List;
import java.util.Map;
+import java.util.function.Supplier;
+import java.util.regex.Pattern;
import javax.annotation.Nonnull;
import org.postgresql.core.Oid;
@@ -45,6 +59,10 @@ public class ClientAutoDetector {
SelectCurrentCatalogStatement.INSTANCE,
SelectVersionStatement.INSTANCE,
DjangoGetTableNamesStatement.INSTANCE);
+ private static final ImmutableSet<String> DEFAULT_CHECK_PG_CATALOG_PREFIXES =
+ ImmutableSet.of("pg_");
+ public static final String PGBENCH_USAGE_HINT =
+ "See https://github.com/GoogleCloudPlatform/pgadapter/blob/-/docs/pgbench.md for how to use pgbench with PGAdapter";
public enum WellKnownClient {
PSQL {
@@ -67,6 +85,48 @@ public ImmutableList<LocalStatement> getLocalStatements(ConnectionHandler connec
return ImmutableList.of(new ListDatabasesStatement(connectionHandler));
}
},
+ PGBENCH {
+ final ImmutableList<String> errorHints = ImmutableList.of(PGBENCH_USAGE_HINT);
+ volatile long lastHintTimestampMillis = 0L;
+
+ @Override
+ public void reset() {
+ lastHintTimestampMillis = 0L;
+ }
+
+ @Override
+ boolean isClient(List<String> orderedParameterKeys, Map<String, String> parameters) {
+ // PGBENCH makes it easy for us, as it sends its own name in the application_name parameter.
+ return parameters.containsKey("application_name")
+ && parameters.get("application_name").equals("pgbench");
+ }
+
+ @Override
+ public ImmutableList<NoticeResponse> createStartupNoticeResponses(
+ ConnectionHandler connection) {
+ synchronized (PGBENCH) {
+ // Only send the hint at most once every 30 seconds, to prevent benchmark runs that open
+ // multiple connections from showing the hint every time.
+ if (Duration.ofMillis(System.currentTimeMillis() - lastHintTimestampMillis).getSeconds()
+ > 30L) {
+ lastHintTimestampMillis = System.currentTimeMillis();
+ return ImmutableList.of(
+ new NoticeResponse(
+ connection.getConnectionMetadata().getOutputStream(),
+ SQLState.Success,
+ NoticeSeverity.INFO,
+ "Detected connection from pgbench",
+ PGBENCH_USAGE_HINT + "\n"));
+ }
+ }
+ return super.createStartupNoticeResponses(connection);
+ }
+
+ @Override
+ public ImmutableList<String> getErrorHints(PGException exception) {
+ return errorHints;
+ }
+ },
JDBC {
@Override
boolean isClient(List<String> orderedParameterKeys, Map<String, String> parameters) {
@@ -113,8 +173,34 @@ boolean isClient(List orderedParameterKeys, Map paramete
// pgx does not send enough unique parameters for it to be auto-detected.
return false;
}
+
+ @Override
+ boolean isClient(ParseMessage parseMessage) {
+ // pgx uses a relatively unique naming scheme for prepared statements (and uses prepared
+ // statements for everything by default).
+ return parseMessage.getName() != null && parseMessage.getName().startsWith("lrupsc_");
+ }
},
NPGSQL {
+ final ImmutableMap<String, String> tableReplacements =
+ ImmutableMap.of(
+ "pg_catalog.pg_attribute",
+ "pg_attribute",
+ "pg_attribute",
+ "pg_attribute",
+ "pg_catalog.pg_enum",
+ "pg_enum",
+ "pg_enum",
+ "pg_enum");
+ final ImmutableMap<String, PgCatalogTable> pgCatalogTables =
+ ImmutableMap.of("pg_attribute", new EmptyPgAttribute(), "pg_enum", new EmptyPgEnum());
+
+ final ImmutableMap<Pattern, Supplier<String>> functionReplacements =
+ ImmutableMap.of(
+ Pattern.compile("elemproc\\.oid = elemtyp\\.typreceive"),
+ Suppliers.ofInstance("false"),
+ Pattern.compile("proc\\.oid = typ\\.typreceive"), Suppliers.ofInstance("false"));
+
@Override
boolean isClient(List<String> orderedParameterKeys, Map<String, String> parameters) {
// npgsql does not send enough unique parameters for it to be auto-detected.
@@ -134,6 +220,21 @@ boolean isClient(List<Statement> statements) {
+ "\n"
+ "SELECT ns.nspname, t.oid, t.typname, t.typtype, t.typnotnull, t.elemtypoid\n");
}
+
+ @Override
+ public ImmutableMap<String, String> getTableReplacements() {
+ return tableReplacements;
+ }
+
+ @Override
+ public ImmutableMap<String, PgCatalogTable> getPgCatalogTables() {
+ return pgCatalogTables;
+ }
+
+ @Override
+ public ImmutableMap<Pattern, Supplier<String>> getFunctionReplacements() {
+ return functionReplacements;
+ }
},
UNSPECIFIED {
@Override
@@ -149,14 +250,29 @@ boolean isClient(List<Statement> statements) {
// defaults defined in this enum.
return true;
}
+
+ @Override
+ boolean isClient(ParseMessage parseMessage) {
+ // Use UNSPECIFIED as default to prevent null checks everywhere and to ease the use of any
+ // defaults defined in this enum.
+ return true;
+ }
};
abstract boolean isClient(List<String> orderedParameterKeys, Map<String, String> parameters);
+ /** Resets any cached or temporary settings for the client. */
+ @VisibleForTesting
+ public void reset() {}
+
boolean isClient(List<Statement> statements) {
return false;
}
+ boolean isClient(ParseMessage parseMessage) {
+ return false;
+ }
+
public ImmutableList<LocalStatement> getLocalStatements(ConnectionHandler connectionHandler) {
if (connectionHandler.getServer().getOptions().useDefaultLocalStatements()) {
return DEFAULT_LOCAL_STATEMENTS;
@@ -164,6 +280,33 @@ public ImmutableList<LocalStatement> getLocalStatements(ConnectionHandler connec
return EMPTY_LOCAL_STATEMENTS;
}
+ public ImmutableSet<String> getPgCatalogCheckPrefixes() {
+ return DEFAULT_CHECK_PG_CATALOG_PREFIXES;
+ }
+
+ public ImmutableMap<String, String> getTableReplacements() {
+ return ImmutableMap.of();
+ }
+
+ public ImmutableMap<String, PgCatalogTable> getPgCatalogTables() {
+ return ImmutableMap.of();
+ }
+
+ public ImmutableMap<Pattern, Supplier<String>> getFunctionReplacements() {
+ return ImmutableMap.of();
+ }
+
+ /** Creates specific notice messages for a client after startup. */
+ public ImmutableList<NoticeResponse> createStartupNoticeResponses(
+ ConnectionHandler connection) {
+ return ImmutableList.of();
+ }
+
+ /** Returns the client-specific hint(s) that should be included with the given exception. */
+ public ImmutableList<String> getErrorHints(PGException exception) {
+ return ImmutableList.of();
+ }
+
public ImmutableMap<String, String> getDefaultParameters() {
return ImmutableMap.of();
}
@@ -198,4 +341,18 @@ public ImmutableMap<String, String> getDefaultParameters() {
// The following line should never be reached.
throw new IllegalStateException("UNSPECIFIED.isClient() should have returned true");
}
+
+ /**
+ * Returns the {@link WellKnownClient} that the detector thinks is connected to PGAdapter based on
+ * the Parse message that has been received.
+ */
+ public static @Nonnull WellKnownClient detectClient(ParseMessage parseMessage) {
+ for (WellKnownClient client : WellKnownClient.values()) {
+ if (client.isClient(parseMessage)) {
+ return client;
+ }
+ }
+ // The following line should never be reached.
+ throw new IllegalStateException("UNSPECIFIED.isClient() should have returned true");
+ }
}
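Taken together, client detection now runs in three stages: startup parameters (pgbench announces itself in application_name), the first Parse message (pgx's `lrupsc_` statement-name prefix), and UNSPECIFIED as the guaranteed fallback so callers never see null. A condensed sketch of the same cascade, with illustrative names rather than the actual enum:

```java
import java.util.Map;

public class DetectClientSketch {
  enum Client { PGBENCH, PGX, UNSPECIFIED }

  // Mirrors the detection order: each known client gets a chance to match,
  // and UNSPECIFIED always matches as the catch-all.
  static Client detect(Map<String, String> startupParameters, String preparedStatementName) {
    if ("pgbench".equals(startupParameters.get("application_name"))) {
      return Client.PGBENCH; // pgbench sends its own name in application_name
    }
    if (preparedStatementName != null && preparedStatementName.startsWith("lrupsc_")) {
      return Client.PGX; // pgx's prepared-statement naming scheme
    }
    return Client.UNSPECIFIED; // default, so no null checks are needed downstream
  }

  public static void main(String[] args) {
    System.out.println(detect(Map.of("application_name", "pgbench"), null)); // PGBENCH
    System.out.println(detect(Map.of(), "lrupsc_1_0")); // PGX
    System.out.println(detect(Map.of(), "s1")); // UNSPECIFIED
  }
}
```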
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/utils/Converter.java b/src/main/java/com/google/cloud/spanner/pgadapter/utils/Converter.java
index 48211567b..37857bd9d 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/utils/Converter.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/utils/Converter.java
@@ -14,6 +14,9 @@
package com.google.cloud.spanner.pgadapter.utils;
+import static com.google.cloud.spanner.pgadapter.statements.CopyToStatement.COPY_BINARY_HEADER;
+
+import com.google.api.core.InternalApi;
import com.google.cloud.spanner.ResultSet;
import com.google.cloud.spanner.Type;
import com.google.cloud.spanner.pgadapter.ConnectionHandler.QueryMode;
@@ -38,6 +41,7 @@
import java.io.IOException;
/** Utility class for converting between generic PostgreSQL conversions. */
+@InternalApi
public class Converter implements AutoCloseable {
private final ByteArrayOutputStream buffer = new ByteArrayOutputStream(256);
private final DataOutputStream outputStream = new DataOutputStream(buffer);
@@ -46,12 +50,15 @@ public class Converter implements AutoCloseable {
private final OptionsMetadata options;
private final ResultSet resultSet;
private final SessionState sessionState;
+ private boolean includeBinaryCopyHeaderInFirstRow;
+ private boolean firstRow = true;
public Converter(
IntermediateStatement statement,
QueryMode mode,
OptionsMetadata options,
- ResultSet resultSet) {
+ ResultSet resultSet,
+ boolean includeBinaryCopyHeaderInFirstRow) {
this.statement = statement;
this.mode = mode;
this.options = options;
@@ -62,6 +69,16 @@ public Converter(
.getExtendedQueryProtocolHandler()
.getBackendConnection()
.getSessionState();
+ this.includeBinaryCopyHeaderInFirstRow = includeBinaryCopyHeaderInFirstRow;
+ }
+
+ public Converter includeBinaryCopyHeader() {
+ this.includeBinaryCopyHeaderInFirstRow = true;
+ return this;
+ }
+
+ public boolean isIncludeBinaryCopyHeaderInFirstRow() {
+ return this.includeBinaryCopyHeaderInFirstRow;
}
@Override
@@ -83,6 +100,12 @@ public int convertResultSetRowToDataRowResponse() throws IOException {
: DataFormat.POSTGRESQL_TEXT;
}
buffer.reset();
+ if (includeBinaryCopyHeaderInFirstRow && firstRow) {
+ outputStream.write(COPY_BINARY_HEADER);
+ outputStream.writeInt(0); // flags
+ outputStream.writeInt(0); // header extension area length
+ }
+ firstRow = false;
outputStream.writeShort(resultSet.getColumnCount());
for (int column_index = 0; /* column indices start at 0 */
column_index < resultSet.getColumnCount();
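This change means the first CopyData message of a binary COPY now begins with the fixed 19-byte header (the 11-byte `PGCOPY\n\377\r\n\0` signature, an int32 flags field, and an int32 extension-area length) followed immediately by the first row, instead of shipping the header as a separate message. A standalone sketch of that layout:

```java
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;

public class BinaryCopyHeaderSketch {
  // The fixed 11-byte signature defined by the PostgreSQL binary COPY format:
  // "PGCOPY\n" followed by 0xFF, '\r', '\n', 0x00.
  static final byte[] SIGNATURE = {
    'P', 'G', 'C', 'O', 'P', 'Y', '\n', (byte) 0xFF, '\r', '\n', 0
  };

  static byte[] headerWithFirstRow(byte[] firstRow) throws IOException {
    ByteArrayOutputStream buffer = new ByteArrayOutputStream();
    DataOutputStream out = new DataOutputStream(buffer);
    out.write(SIGNATURE);
    out.writeInt(0); // flags field (always 0 here; bit 16 would indicate OIDs)
    out.writeInt(0); // header extension area length
    out.write(firstRow); // the first data row follows in the same message
    return buffer.toByteArray();
  }

  public static void main(String[] args) throws IOException {
    System.out.println(headerWithFirstRow(new byte[0]).length); // 19 header bytes
  }
}
```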
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/utils/CopyDataReceiver.java b/src/main/java/com/google/cloud/spanner/pgadapter/utils/CopyDataReceiver.java
index cba2c259c..79d750b28 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/utils/CopyDataReceiver.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/utils/CopyDataReceiver.java
@@ -65,7 +65,7 @@ void handleCopy() throws Exception {
this.connectionHandler.setActiveCopyStatement(copyStatement);
new CopyInResponse(
this.connectionHandler.getConnectionMetadata().getOutputStream(),
- copyStatement.getTableColumns().size(),
+ (short) copyStatement.getTableColumns().size(),
copyStatement.getFormatCode())
.send();
ConnectionStatus initialConnectionStatus = this.connectionHandler.getStatus();
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/utils/CopyRecord.java b/src/main/java/com/google/cloud/spanner/pgadapter/utils/CopyRecord.java
index e6c252aa4..192c46a8b 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/utils/CopyRecord.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/utils/CopyRecord.java
@@ -28,12 +28,18 @@ public interface CopyRecord {
/** Returns the number of columns in the record. */
int numColumns();
+ /** Returns true if this record is the PG end record (\.). */
+ boolean isEndRecord();
+
/**
* Returns true if the copy record has column names. The {@link #getValue(Type, String)} method
* can only be used for records that have column names.
*/
boolean hasColumnNames();
+ /** Returns true if the value of the given column is null. */
+ boolean isNull(int columnIndex);
+
/**
* Returns the value of the given column as a Cloud Spanner {@link Value} of the given type. This
* method is used by a COPY ... FROM ... operation to convert a value to the type of the column
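A minimal sketch of how a record source might satisfy the two new interface methods, and how a consumer is expected to use them (illustrative types, not PGAdapter's implementations):

```java
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

public class CopyRecordLoopSketch {
  // Stand-in for a parsed record: null entries are NULL columns, and a
  // one-element row containing "\." is the end record.
  record Row(List<String> columns) {
    boolean isEndRecord() { return columns.size() == 1 && "\\.".equals(columns.get(0)); }
    boolean isNull(int i) { return columns.get(i) == null; }
  }

  public static void main(String[] args) {
    Iterator<Row> rows =
        List.of(
                new Row(Arrays.asList("1", null)),
                new Row(List.of("\\.")), // end record: stop consuming here
                new Row(Arrays.asList("2", "x")))
            .iterator();
    while (rows.hasNext()) {
      Row row = rows.next();
      if (row.isEndRecord()) {
        break; // everything after the end record is ignored
      }
      System.out.println("null second column: " + row.isNull(1));
    }
  }
}
```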
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/utils/CsvCopyParser.java b/src/main/java/com/google/cloud/spanner/pgadapter/utils/CsvCopyParser.java
index 9265faf6f..f6cde0d01 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/utils/CsvCopyParser.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/utils/CsvCopyParser.java
@@ -33,6 +33,7 @@
import java.nio.charset.StandardCharsets;
import java.time.format.DateTimeParseException;
import java.util.Iterator;
+import java.util.Objects;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.commons.codec.binary.Hex;
@@ -97,11 +98,25 @@ public int numColumns() {
return record.size();
}
+ @Override
+ public boolean isEndRecord() {
+ // End of data can be represented by a single line containing just backslash-period (\.). An
+ // end-of-data marker is not necessary when reading from a file, since the end of the file
+ // serves perfectly well; it is needed only when copying data to or from client applications
+ // that use the pre-3.0 client protocol.
+ return record.size() == 1 && Objects.equals("\\.", record.get(0));
+ }
+
@Override
public boolean hasColumnNames() {
return this.hasHeader;
}
+ @Override
+ public boolean isNull(int columnIndex) {
+ return record.get(columnIndex) == null;
+ }
+
@Override
public Value getValue(Type type, String columnName) throws SpannerException {
String recordValue = record.get(columnName);
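The isNull implementation relies on the CSV parser mapping the configured null string to a Java null. A standalone sketch of that behavior with Apache Commons CSV, which PGAdapter already depends on (the format options here are illustrative; PGAdapter configures its own CSVFormat):

```java
import java.io.StringReader;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVRecord;

public class CsvNullSketch {
  public static void main(String[] args) throws Exception {
    // With a null string configured, matching fields come back as Java null,
    // so record.get(i) == null works directly as a NULL check.
    CSVFormat format =
        CSVFormat.DEFAULT.builder().setDelimiter('\t').setNullString("").build();
    for (CSVRecord record : format.parse(new StringReader("1\t\tx\n\\.\n"))) {
      if (record.size() == 1 && "\\.".equals(record.get(0))) {
        break; // end-of-data marker
      }
      System.out.println("column 1 null? " + (record.get(1) == null)); // true
    }
  }
}
```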
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/utils/MutationWriter.java b/src/main/java/com/google/cloud/spanner/pgadapter/utils/MutationWriter.java
index 55eeda095..bdb6b6d1a 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/utils/MutationWriter.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/utils/MutationWriter.java
@@ -66,6 +66,7 @@
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicLong;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nonnull;
@@ -109,6 +110,8 @@ public enum CopyTransactionMode {
private final CSVFormat csvFormat;
private final boolean hasHeader;
private final CountDownLatch pipeCreatedLatch = new CountDownLatch(1);
+ private final CountDownLatch dataReceivedLatch = new CountDownLatch(1);
+ private final AtomicLong bytesReceived = new AtomicLong();
private final PipedOutputStream payload = new PipedOutputStream();
private final AtomicBoolean commit = new AtomicBoolean(false);
private final AtomicBoolean rollback = new AtomicBoolean(false);
@@ -164,6 +167,8 @@ public void addCopyData(byte[] payload) {
}
try {
pipeCreatedLatch.await();
+ bytesReceived.addAndGet(payload.length);
+ dataReceivedLatch.countDown();
this.payload.write(payload);
} catch (InterruptedException | InterruptedIOException interruptedIOException) {
// The IO operation was interrupted. This indicates that the user wants to cancel the COPY
@@ -204,6 +209,7 @@ public void rollback() {
public void close() throws IOException {
this.payload.close();
this.closedLatch.countDown();
+ this.dataReceivedLatch.countDown();
}
@Override
@@ -228,14 +234,21 @@ public StatementResult call() throws Exception {
// before finishing, to ensure that all data has been written before we signal that we are done.
List<ApiFuture<Void>> allCommitFutures = new ArrayList<>();
try {
+ // Wait until we know whether we actually will receive any data. It could be that it is an
+ // empty copy operation, and we should then end early.
+ dataReceivedLatch.await();
+
Iterator<CopyRecord> iterator = parser.iterator();
List mutations = new ArrayList<>();
long currentBufferByteSize = 0L;
// Note: iterator.hasNext() blocks if there is not enough data in the pipeline to construct a
// complete record. It returns false if the stream has been closed and all records have been
// returned.
- while (!rollback.get() && iterator.hasNext()) {
+ while (bytesReceived.get() > 0L && !rollback.get() && iterator.hasNext()) {
CopyRecord record = iterator.next();
+ if (record.isEndRecord()) {
+ break;
+ }
if (record.numColumns() != this.tableColumns.keySet().size()) {
throw PGExceptionFactory.newPGException(
"Invalid COPY data: Row length mismatched. Expected "
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/CopyDataResponse.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/CopyDataResponse.java
index c91250ba1..871721d8e 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/CopyDataResponse.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/CopyDataResponse.java
@@ -28,9 +28,9 @@
public class CopyDataResponse extends WireOutput {
@InternalApi
public enum ResponseType {
- HEADER,
ROW,
TRAILER,
+ TRAILER_WITH_HEADER,
}
private final ResponseType responseType;
@@ -39,16 +39,14 @@ public enum ResponseType {
private final char rowTerminator;
private final Converter converter;
- /** Creates a {@link CopyDataResponse} message containing the fixed binary COPY header. */
- @InternalApi
- public static CopyDataResponse createBinaryHeader(DataOutputStream output) {
- return new CopyDataResponse(output, COPY_BINARY_HEADER.length + 8, ResponseType.HEADER);
- }
-
/** Creates a {@link CopyDataResponse} message containing the fixed binary COPY trailer. */
@InternalApi
- public static CopyDataResponse createBinaryTrailer(DataOutputStream output) {
- return new CopyDataResponse(output, 2, ResponseType.TRAILER);
+ public static CopyDataResponse createBinaryTrailer(
+ DataOutputStream output, boolean includeHeader) {
+ return new CopyDataResponse(
+ output,
+ includeHeader ? 21 : 2,
+ includeHeader ? ResponseType.TRAILER_WITH_HEADER : ResponseType.TRAILER);
}
private CopyDataResponse(DataOutputStream output, int length, ResponseType responseType) {
@@ -78,6 +76,7 @@ public CopyDataResponse(DataOutputStream output, Converter converter) {
this.converter = converter;
}
+ @Override
public void send(boolean flush) throws Exception {
if (converter != null) {
this.length = 4 + converter.convertResultSetRowToDataRowResponse();
@@ -91,10 +90,11 @@ protected void sendPayload() throws Exception {
this.outputStream.write(this.stringData.getBytes(StandardCharsets.UTF_8));
this.outputStream.write(this.rowTerminator);
} else if (this.format == DataFormat.POSTGRESQL_BINARY) {
- if (this.responseType == ResponseType.HEADER) {
+ if (this.responseType == ResponseType.TRAILER_WITH_HEADER) {
this.outputStream.write(COPY_BINARY_HEADER);
this.outputStream.writeInt(0); // flags
this.outputStream.writeInt(0); // header extension area length
+ this.outputStream.writeShort(-1);
} else if (this.responseType == ResponseType.TRAILER) {
this.outputStream.writeShort(-1);
} else if (this.converter != null) {
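The number 21 falls out of the wire format: when an empty binary COPY produced no first row to carry the header, the trailer message must contain the 11-byte signature, the int32 flags, the int32 extension length, and the int16 -1 trailer, i.e. 11 + 4 + 4 + 2 = 21 bytes; a normal trailer is just the 2-byte -1. A sketch verifying the arithmetic:

```java
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;

public class BinaryTrailerSketch {
  public static void main(String[] args) throws IOException {
    ByteArrayOutputStream buffer = new ByteArrayOutputStream();
    DataOutputStream out = new DataOutputStream(buffer);
    // TRAILER_WITH_HEADER: an empty binary COPY never sent a first row, so the
    // fixed header has to go out together with the trailer.
    out.write(new byte[] {'P', 'G', 'C', 'O', 'P', 'Y', '\n', (byte) 0xFF, '\r', '\n', 0});
    out.writeInt(0); // flags
    out.writeInt(0); // header extension area length
    out.writeShort(-1); // trailer: a field count of -1 ends the binary stream
    System.out.println(buffer.size()); // 21 == the length used for TRAILER_WITH_HEADER
  }
}
```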
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/CopyInResponse.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/CopyInResponse.java
index 58eee8e67..13c064eeb 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/CopyInResponse.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/CopyInResponse.java
@@ -35,23 +35,25 @@ public class CopyInResponse extends WireOutput {
protected static final char IDENTIFIER = 'G';
- private final int numColumns;
- private final int formatCode;
- private final byte[] columnFormat;
+ private final short numColumns;
+ private final byte formatCode;
+ private final short[] columnFormat;
- public CopyInResponse(DataOutputStream output, int numColumns, int formatCode) {
+ public CopyInResponse(DataOutputStream output, short numColumns, byte formatCode) {
super(output, calculateLength(numColumns));
this.numColumns = numColumns;
this.formatCode = formatCode;
- columnFormat = new byte[COLUMN_NUM_LENGTH * this.numColumns];
- Arrays.fill(columnFormat, (byte) formatCode);
+ this.columnFormat = new short[this.numColumns];
+ Arrays.fill(this.columnFormat, formatCode);
}
@Override
protected void sendPayload() throws IOException {
this.outputStream.writeByte(this.formatCode);
this.outputStream.writeShort(this.numColumns);
- this.outputStream.write(this.columnFormat);
+ for (int i = 0; i < this.numColumns; i++) {
+ this.outputStream.writeShort(this.columnFormat[i]);
+ }
}
@Override
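Narrowing the fields to short and byte matches the declared widths of the CopyInResponse ('G') payload: an int8 overall format, an int16 column count, and one int16 format code per column; the old code wrote the per-column codes as single bytes. A standalone sketch of the payload layout:

```java
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;

public class CopyInResponseSketch {
  /** Serializes the payload of a CopyInResponse ('G') message. */
  static byte[] payload(byte overallFormat, short numColumns) throws IOException {
    ByteArrayOutputStream buffer = new ByteArrayOutputStream();
    DataOutputStream out = new DataOutputStream(buffer);
    out.writeByte(overallFormat); // int8: 0 = text, 1 = binary
    out.writeShort(numColumns); // int16 column count
    for (int i = 0; i < numColumns; i++) {
      out.writeShort(overallFormat); // one int16 format code per column
    }
    return buffer.toByteArray();
  }

  public static void main(String[] args) throws IOException {
    // 1 format byte + 2 count bytes + 3 columns * 2 bytes = 9 payload bytes.
    System.out.println(payload((byte) 0, (short) 3).length); // 9
  }
}
```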
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/ErrorResponse.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/ErrorResponse.java
index e37964fd7..35502d2d0 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/ErrorResponse.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/ErrorResponse.java
@@ -15,9 +15,11 @@
package com.google.cloud.spanner.pgadapter.wireoutput;
import com.google.api.core.InternalApi;
+import com.google.cloud.spanner.pgadapter.ConnectionHandler;
import com.google.cloud.spanner.pgadapter.error.PGException;
+import com.google.cloud.spanner.pgadapter.utils.ClientAutoDetector.WellKnownClient;
import com.google.common.base.Strings;
-import java.io.DataOutputStream;
+import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.text.MessageFormat;
@@ -25,6 +27,7 @@
/** Sends error information back to client. */
@InternalApi
public class ErrorResponse extends WireOutput {
+ private static final byte[] EMPTY_BYTE_ARRAY = new byte[0];
private static final int HEADER_LENGTH = 4;
private static final int FIELD_IDENTIFIER_LENGTH = 1;
@@ -36,58 +39,83 @@ public class ErrorResponse extends WireOutput {
private static final byte HINT_FLAG = 'H';
private static final byte NULL_TERMINATOR = 0;
- private final byte[] severity;
- private final byte[] errorMessage;
- private final byte[] errorState;
- private final byte[] hints;
-
- public ErrorResponse(DataOutputStream output, PGException pgException) {
- super(output, calculateLength(pgException));
- this.errorMessage = pgException.getMessage().getBytes(StandardCharsets.UTF_8);
- this.errorState = pgException.getSQLState().getBytes();
- this.severity = pgException.getSeverity().name().getBytes(StandardCharsets.UTF_8);
- this.hints =
- Strings.isNullOrEmpty(pgException.getHints())
- ? null
- : pgException.getHints().getBytes(StandardCharsets.UTF_8);
+ private final PGException pgException;
+ private final WellKnownClient client;
+
+ public ErrorResponse(ConnectionHandler connection, PGException pgException) {
+ super(
+ connection.getConnectionMetadata().getOutputStream(),
+ calculateLength(pgException, connection.getWellKnownClient()));
+ this.pgException = pgException;
+ this.client = connection.getWellKnownClient();
}
- static int calculateLength(PGException pgException) {
+ static int calculateLength(PGException pgException, WellKnownClient client) {
int length =
HEADER_LENGTH
+ FIELD_IDENTIFIER_LENGTH
- + pgException.getSeverity().name().getBytes(StandardCharsets.UTF_8).length
+ + convertSeverityToWireProtocol(pgException).length
+ NULL_TERMINATOR_LENGTH
+ FIELD_IDENTIFIER_LENGTH
- + pgException.getSQLState().getBytes().length
+ + convertSQLStateToWireProtocol(pgException).length
+ NULL_TERMINATOR_LENGTH
+ FIELD_IDENTIFIER_LENGTH
- + pgException.getMessage().getBytes(StandardCharsets.UTF_8).length
+ + convertMessageToWireProtocol(pgException).length
+ NULL_TERMINATOR_LENGTH
+ NULL_TERMINATOR_LENGTH;
- if (!Strings.isNullOrEmpty(pgException.getHints())) {
- length +=
- FIELD_IDENTIFIER_LENGTH
- + pgException.getHints().getBytes(StandardCharsets.UTF_8).length
- + NULL_TERMINATOR_LENGTH;
+ byte[] hints = convertHintsToWireProtocol(pgException, client);
+ if (hints.length > 0) {
+ length += FIELD_IDENTIFIER_LENGTH + hints.length + NULL_TERMINATOR_LENGTH;
}
return length;
}
+ static byte[] convertSeverityToWireProtocol(PGException pgException) {
+ return pgException.getSeverity().name().getBytes(StandardCharsets.UTF_8);
+ }
+
+ static byte[] convertSQLStateToWireProtocol(PGException pgException) {
+ return pgException.getSQLState().getBytes();
+ }
+
+ static byte[] convertMessageToWireProtocol(PGException pgException) {
+ return pgException.getMessage().getBytes(StandardCharsets.UTF_8);
+ }
+
+ static byte[] convertHintsToWireProtocol(PGException pgException, WellKnownClient client) {
+ if (Strings.isNullOrEmpty(pgException.getHints())
+ && client.getErrorHints(pgException).isEmpty()) {
+ return EMPTY_BYTE_ARRAY;
+ }
+ String hints = "";
+ if (!Strings.isNullOrEmpty(pgException.getHints())) {
+ hints += pgException.getHints();
+ }
+ ImmutableList<String> clientHints = client.getErrorHints(pgException);
+ if (!clientHints.isEmpty()) {
+ if (hints.length() > 0) {
+ hints += "\n";
+ }
+ hints += String.join("\n", clientHints);
+ }
+ return hints.getBytes(StandardCharsets.UTF_8);
+ }
+
@Override
protected void sendPayload() throws IOException {
this.outputStream.writeByte(SEVERITY_FLAG);
- this.outputStream.write(severity);
+ this.outputStream.write(convertSeverityToWireProtocol(pgException));
this.outputStream.writeByte(NULL_TERMINATOR);
this.outputStream.writeByte(CODE_FLAG);
- this.outputStream.write(this.errorState);
+ this.outputStream.write(convertSQLStateToWireProtocol(pgException));
this.outputStream.writeByte(NULL_TERMINATOR);
this.outputStream.writeByte(MESSAGE_FLAG);
- this.outputStream.write(this.errorMessage);
+ this.outputStream.write(convertMessageToWireProtocol(pgException));
this.outputStream.writeByte(NULL_TERMINATOR);
- if (this.hints != null) {
+ byte[] hints = convertHintsToWireProtocol(pgException, client);
+ if (hints.length > 0) {
this.outputStream.writeByte(HINT_FLAG);
- this.outputStream.write(this.hints);
+ this.outputStream.write(hints);
this.outputStream.writeByte(NULL_TERMINATOR);
}
this.outputStream.writeByte(NULL_TERMINATOR);
@@ -109,8 +137,10 @@ protected String getPayloadString() {
.format(
new Object[] {
this.length,
- new String(this.errorMessage, UTF8),
- this.hints == null ? "" : new String(this.hints, UTF8)
+ new String(convertMessageToWireProtocol(pgException), UTF8),
+ convertHintsToWireProtocol(pgException, client).length == 0
+ ? ""
+ : new String(convertHintsToWireProtocol(pgException, client), UTF8)
});
}
}
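Hints are now computed on the fly rather than captured in the constructor: hints attached to the PGException come first, client-specific hints (such as the pgbench usage hint) are appended, all newline-separated. A minimal sketch of the merge logic:

```java
import java.util.List;

public class HintMergeSketch {
  // Mirrors convertHintsToWireProtocol: exception-provided hints come first,
  // client-specific hints are appended, all newline-separated.
  static String mergeHints(String exceptionHints, List<String> clientHints) {
    StringBuilder hints = new StringBuilder();
    if (exceptionHints != null && !exceptionHints.isEmpty()) {
      hints.append(exceptionHints);
    }
    if (!clientHints.isEmpty()) {
      if (hints.length() > 0) {
        hints.append('\n');
      }
      hints.append(String.join("\n", clientHints));
    }
    return hints.toString();
  }

  public static void main(String[] args) {
    System.out.println(mergeHints("try X", List.of("see docs/pgbench.md")));
    // try X
    // see docs/pgbench.md
  }
}
```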
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/NoticeResponse.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/NoticeResponse.java
new file mode 100644
index 000000000..bcf3d7485
--- /dev/null
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireoutput/NoticeResponse.java
@@ -0,0 +1,129 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.cloud.spanner.pgadapter.wireoutput;
+
+import com.google.api.core.InternalApi;
+import com.google.cloud.spanner.pgadapter.error.SQLState;
+import com.google.common.base.Preconditions;
+import java.io.DataOutputStream;
+import java.nio.charset.StandardCharsets;
+import java.text.MessageFormat;
+
+/**
+ * Notices can be sent as asynchronous messages and can include warnings, informational messages,
+ * debug information, etc.
+ */
+@InternalApi
+public class NoticeResponse extends WireOutput {
+ public enum NoticeSeverity {
+ WARNING,
+ NOTICE,
+ DEBUG,
+ INFO,
+ LOG,
+ TIP,
+ }
+
+ private static final int HEADER_LENGTH = 4;
+ private static final int FIELD_IDENTIFIER_LENGTH = 1;
+ private static final int NULL_TERMINATOR_LENGTH = 1;
+
+ private static final byte MESSAGE_FLAG = 'M';
+ private static final byte CODE_FLAG = 'C';
+ private static final byte SEVERITY_FLAG = 'S';
+ private static final byte HINT_FLAG = 'H';
+ private static final byte NULL_TERMINATOR = 0;
+
+ private SQLState sqlState;
+ private final NoticeSeverity severity;
+ private final String message;
+ private final String hint;
+
+ public NoticeResponse(
+ DataOutputStream output,
+ SQLState sqlState,
+ NoticeSeverity severity,
+ String message,
+ String hint) {
+ super(
+ output,
+ HEADER_LENGTH
+ + FIELD_IDENTIFIER_LENGTH
+ + Preconditions.checkNotNull(severity).name().getBytes(StandardCharsets.UTF_8).length
+ + NULL_TERMINATOR_LENGTH
+ + FIELD_IDENTIFIER_LENGTH
+ + sqlState.getBytes().length
+ + NULL_TERMINATOR_LENGTH
+ + FIELD_IDENTIFIER_LENGTH
+ + Preconditions.checkNotNull(message).getBytes(StandardCharsets.UTF_8).length
+ + NULL_TERMINATOR_LENGTH
+ + (hint == null
+ ? 0
+ : (FIELD_IDENTIFIER_LENGTH
+ + hint.getBytes(StandardCharsets.UTF_8).length
+ + NULL_TERMINATOR_LENGTH))
+ + NULL_TERMINATOR_LENGTH);
+ this.sqlState = sqlState;
+ this.severity = severity;
+ this.message = message;
+ this.hint = hint;
+ }
+
+ @Override
+ protected void sendPayload() throws Exception {
+ this.outputStream.writeByte(SEVERITY_FLAG);
+ this.outputStream.write(severity.name().getBytes(StandardCharsets.UTF_8));
+ this.outputStream.writeByte(NULL_TERMINATOR);
+ this.outputStream.writeByte(CODE_FLAG);
+ this.outputStream.write(sqlState.getBytes());
+ this.outputStream.writeByte(NULL_TERMINATOR);
+ this.outputStream.writeByte(MESSAGE_FLAG);
+ this.outputStream.write(message.getBytes(StandardCharsets.UTF_8));
+ this.outputStream.writeByte(NULL_TERMINATOR);
+ if (this.hint != null) {
+ this.outputStream.writeByte(HINT_FLAG);
+ this.outputStream.write(hint.getBytes(StandardCharsets.UTF_8));
+ this.outputStream.writeByte(NULL_TERMINATOR);
+ }
+ this.outputStream.writeByte(NULL_TERMINATOR);
+ }
+
+ @Override
+ public byte getIdentifier() {
+ return 'N';
+ }
+
+ @Override
+ protected String getMessageName() {
+ return "Notice";
+ }
+
+ public String getMessage() {
+ return this.message;
+ }
+
+ public String getHint() {
+ return this.hint;
+ }
+
+ @Override
+ protected String getPayloadString() {
+ return new MessageFormat("Length: {0}, Severity: {1}, Notice Message: {2}, Hint: {3}")
+ .format(
+ new Object[] {
+ this.length, this.severity, this.message, this.hint == null ? "" : this.hint
+ });
+ }
+}
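The new 'N' message reuses the tagged-field layout of ErrorResponse: each field is a one-byte tag ('S', 'C', 'M', and optionally 'H') followed by a NUL-terminated string, with one final NUL closing the message body. A standalone sketch of that encoding:

```java
import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;

public class NoticeWireSketch {
  static byte[] noticeFields(String severity, String code, String message) {
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    writeField(out, (byte) 'S', severity);
    writeField(out, (byte) 'C', code);
    writeField(out, (byte) 'M', message);
    out.write(0); // terminating NUL after the last field
    return out.toByteArray();
  }

  static void writeField(ByteArrayOutputStream out, byte tag, String value) {
    out.write(tag);
    byte[] bytes = value.getBytes(StandardCharsets.UTF_8);
    out.write(bytes, 0, bytes.length);
    out.write(0); // each field value is NUL-terminated
  }

  public static void main(String[] args) {
    // 1+4+1 (severity) + 1+5+1 (code) + 1+2+1 (message) + 1 = 18 bytes
    System.out.println(noticeFields("INFO", "00000", "hi").length); // 18
  }
}
```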
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/BootstrapMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/BootstrapMessage.java
index 8cc11a6c9..e4ac643e0 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/BootstrapMessage.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/BootstrapMessage.java
@@ -19,6 +19,7 @@
import com.google.cloud.spanner.pgadapter.session.SessionState;
import com.google.cloud.spanner.pgadapter.wireoutput.AuthenticationOkResponse;
import com.google.cloud.spanner.pgadapter.wireoutput.KeyDataResponse;
+import com.google.cloud.spanner.pgadapter.wireoutput.NoticeResponse;
import com.google.cloud.spanner.pgadapter.wireoutput.ParameterStatusResponse;
import com.google.cloud.spanner.pgadapter.wireoutput.ReadyResponse;
import com.google.cloud.spanner.pgadapter.wireoutput.ReadyResponse.Status;
@@ -107,7 +108,11 @@ protected List<String> parseParameterKeys(String rawParameters) {
* @throws Exception
*/
public static void sendStartupMessage(
- DataOutputStream output, int connectionId, int secret, SessionState sessionState)
+ DataOutputStream output,
+ int connectionId,
+ int secret,
+ SessionState sessionState,
+ Iterable<NoticeResponse> startupNotices)
throws Exception {
new AuthenticationOkResponse(output).send(false);
new KeyDataResponse(output, connectionId, secret).send(false);
@@ -166,6 +171,9 @@ public static void sendStartupMessage(
"TimeZone".getBytes(StandardCharsets.UTF_8),
ZoneId.systemDefault().getId().getBytes(StandardCharsets.UTF_8))
.send(false);
+ for (NoticeResponse noticeResponse : startupNotices) {
+ noticeResponse.send(false);
+ }
new ReadyResponse(output, Status.IDLE).send();
}
}
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ControlMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ControlMessage.java
index e233eb411..67ea96f8f 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ControlMessage.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ControlMessage.java
@@ -38,6 +38,7 @@
import com.google.cloud.spanner.pgadapter.error.Severity;
import com.google.cloud.spanner.pgadapter.metadata.SendResultSetState;
import com.google.cloud.spanner.pgadapter.statements.BackendConnection.PartitionQueryResult;
+import com.google.cloud.spanner.pgadapter.statements.CopyToStatement;
import com.google.cloud.spanner.pgadapter.statements.IntermediateStatement;
import com.google.cloud.spanner.pgadapter.utils.Converter;
import com.google.cloud.spanner.pgadapter.wireoutput.CommandCompleteResponse;
@@ -51,7 +52,6 @@
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
-import com.google.common.util.concurrent.SettableFuture;
import io.grpc.Context;
import io.grpc.MethodDescriptor;
import java.io.DataInputStream;
@@ -59,6 +59,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
+import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.logging.Level;
@@ -178,7 +179,7 @@ public static ControlMessage create(ConnectionHandler connection) throws Excepti
connection.increaseInvalidMessageCount();
if (connection.getInvalidMessageCount() > MAX_INVALID_MESSAGE_COUNT) {
new ErrorResponse(
- connection.getConnectionMetadata().getOutputStream(),
+ connection,
PGException.newBuilder(
String.format(
"Received %d invalid/unexpected messages. Last received message: '%c'",
@@ -232,7 +233,7 @@ static PreparedType prepareType(char type) {
* @throws Exception if there is some issue in the sending of the error messages.
*/
protected void handleError(Exception exception) throws Exception {
- new ErrorResponse(this.outputStream, PGExceptionFactory.toPGException(exception)).send(false);
+ new ErrorResponse(this.connection, PGExceptionFactory.toPGException(exception)).send(false);
}
/**
@@ -312,6 +313,7 @@ SendResultSetState sendResultSet(
if (statementResult instanceof PartitionQueryResult) {
hasData = false;
PartitionQueryResult partitionQueryResult = (PartitionQueryResult) statementResult;
+ sendPrefix(describedResult, ((PartitionQueryResult) statementResult).getMetadataResultSet());
rows =
sendPartitionedQuery(
describedResult,
@@ -321,17 +323,28 @@ SendResultSetState sendResultSet(
} else {
hasData = describedResult.isHasMoreData();
ResultSet resultSet = describedResult.getStatementResult().getResultSet();
+ sendPrefix(describedResult, resultSet);
SendResultSetRunnable runnable =
- SendResultSetRunnable.forResultSet(
- describedResult, resultSet, maxRows, mode, SettableFuture.create(), hasData);
+ SendResultSetRunnable.forResultSet(describedResult, resultSet, maxRows, mode, hasData);
rows = runnable.call();
hasData = runnable.hasData;
}
+ sendSuffix(describedResult);
+ return new SendResultSetState(describedResult.getCommandTag(), rows, hasData);
+ }
+
+ private void sendPrefix(IntermediateStatement describedResult, ResultSet resultSet)
+ throws Exception {
+ for (WireOutput prefix : describedResult.createResultPrefix(resultSet)) {
+ prefix.send(false);
+ }
+ }
+
+ private void sendSuffix(IntermediateStatement describedResult) throws Exception {
for (WireOutput suffix : describedResult.createResultSuffix()) {
suffix.send(false);
}
- return new SendResultSetState(describedResult.getCommandTag(), rows, hasData);
}
long sendPartitionedQuery(
@@ -344,7 +357,6 @@ long sendPartitionedQuery(
Executors.newFixedThreadPool(
Math.min(8 * Runtime.getRuntime().availableProcessors(), partitions.size())));
List<Future<Long>> futures = new ArrayList<>(partitions.size());
- SettableFuture<Boolean> prefixSent = SettableFuture.create();
Connection spannerConnection = connection.getSpannerConnection();
Spanner spanner = ConnectionOptionsHelper.getSpanner(spannerConnection);
BatchClient batchClient = spanner.getBatchClient(connection.getDatabaseId());
@@ -361,6 +373,10 @@ public ApiCallContext configure(
return GrpcCallContext.createDefault().withTimeout(Duration.ofHours(24L));
}
});
+ CountDownLatch binaryCopyHeaderSentLatch =
+ describedResult instanceof CopyToStatement && ((CopyToStatement) describedResult).isBinary()
+ ? new CountDownLatch(1)
+ : new CountDownLatch(0);
for (int i = 0; i < partitions.size(); i++) {
futures.add(
executorService.submit(
@@ -370,8 +386,7 @@ public ApiCallContext configure(
batchReadOnlyTransaction,
partitions.get(i),
mode,
- i == 0,
- prefixSent))));
+ binaryCopyHeaderSentLatch))));
}
executorService.shutdown();
try {
@@ -403,8 +418,7 @@ static final class SendResultSetRunnable implements Callable<Long> {
private final Partition partition;
private final long maxRows;
private final QueryMode mode;
- private final boolean includePrefix;
- private final SettableFuture<Boolean> prefixSent;
+ private final CountDownLatch binaryCopyHeaderSentLatch;
private boolean hasData;
static SendResultSetRunnable forResultSet(
@@ -412,10 +426,8 @@ static SendResultSetRunnable forResultSet(
ResultSet resultSet,
long maxRows,
QueryMode mode,
- SettableFuture<Boolean> prefixSent,
boolean hasData) {
- return new SendResultSetRunnable(
- describedResult, resultSet, maxRows, mode, true, prefixSent, hasData);
+ return new SendResultSetRunnable(describedResult, resultSet, maxRows, mode, true, hasData);
}
static SendResultSetRunnable forPartition(
@@ -423,10 +435,9 @@ static SendResultSetRunnable forPartition(
BatchReadOnlyTransaction batchReadOnlyTransaction,
Partition partition,
QueryMode mode,
- boolean includePrefix,
- SettableFuture<Boolean> prefixSent) {
+ CountDownLatch binaryCopyHeaderSentLatch) {
return new SendResultSetRunnable(
- describedResult, batchReadOnlyTransaction, partition, mode, includePrefix, prefixSent);
+ describedResult, batchReadOnlyTransaction, partition, mode, binaryCopyHeaderSentLatch);
}
private SendResultSetRunnable(
@@ -435,7 +446,6 @@ private SendResultSetRunnable(
long maxRows,
QueryMode mode,
boolean includePrefix,
- SettableFuture<Boolean> prefixSent,
boolean hasData) {
this.describedResult = describedResult;
this.resultSet = resultSet;
@@ -444,13 +454,15 @@ private SendResultSetRunnable(
describedResult,
mode,
describedResult.getConnectionHandler().getServer().getOptions(),
- resultSet);
+ resultSet,
+ includePrefix
+ && describedResult instanceof CopyToStatement
+ && ((CopyToStatement) describedResult).isBinary());
this.batchReadOnlyTransaction = null;
this.partition = null;
this.maxRows = maxRows;
this.mode = mode;
- this.includePrefix = includePrefix;
- this.prefixSent = prefixSent;
+ this.binaryCopyHeaderSentLatch = new CountDownLatch(0);
this.hasData = hasData;
}
@@ -459,16 +471,14 @@ private SendResultSetRunnable(
BatchReadOnlyTransaction batchReadOnlyTransaction,
Partition partition,
QueryMode mode,
- boolean includePrefix,
- SettableFuture<Boolean> prefixSent) {
+ CountDownLatch binaryCopyHeaderSentLatch) {
this.describedResult = describedResult;
this.resultSet = null;
this.batchReadOnlyTransaction = batchReadOnlyTransaction;
this.partition = partition;
this.maxRows = 0L;
this.mode = mode;
- this.includePrefix = includePrefix;
- this.prefixSent = prefixSent;
+ this.binaryCopyHeaderSentLatch = binaryCopyHeaderSentLatch;
this.hasData = false;
}
@@ -478,46 +488,30 @@ public Long call() throws Exception {
if (resultSet == null && batchReadOnlyTransaction != null && partition != null) {
// Note: It is OK not to close this result set, as the underlying transaction and session
// will be cleaned up at a later moment.
- try {
- resultSet = batchReadOnlyTransaction.execute(partition);
- converter =
- new Converter(
- describedResult,
- mode,
- describedResult.getConnectionHandler().getServer().getOptions(),
- resultSet);
- hasData = resultSet.next();
- } catch (Throwable t) {
- if (includePrefix) {
- synchronized (describedResult) {
- prefixSent.setException(t);
- }
- }
- throw t;
- }
- }
- if (includePrefix) {
- try {
- for (WireOutput prefix : describedResult.createResultPrefix(resultSet)) {
- prefix.send(false);
- }
- prefixSent.set(true);
- } catch (Throwable t) {
- prefixSent.setException(t);
- throw t;
- }
+ resultSet = batchReadOnlyTransaction.execute(partition);
+ hasData = resultSet.next();
+ converter =
+ new Converter(
+ describedResult,
+ mode,
+ describedResult.getConnectionHandler().getServer().getOptions(),
+ resultSet,
+ false);
}
- // Wait until the prefix (if any) has been sent.
- prefixSent.get();
long rows = 0L;
while (hasData) {
- if (Thread.interrupted()) {
- throw PGExceptionFactory.newQueryCancelledException();
- }
WireOutput wireOutput = describedResult.createDataRowResponse(converter);
+ if (!converter.isIncludeBinaryCopyHeaderInFirstRow()) {
+ binaryCopyHeaderSentLatch.await();
+ }
synchronized (describedResult) {
wireOutput.send(false);
}
+ binaryCopyHeaderSentLatch.countDown();
+ if (Thread.interrupted()) {
+ throw PGExceptionFactory.newQueryCancelledException();
+ }
+
rows++;
hasData = resultSet.next();
if (rows % 1000 == 0) {
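Replacing the SettableFuture with a CountDownLatch simplifies the partitioned binary COPY path: the latch starts at 1 only when a binary header must precede all rows, every partition whose converter does not embed the header awaits it, and the first row actually sent counts it down. A standalone sketch of the gating:

```java
import java.util.concurrent.CountDownLatch;

public class HeaderLatchSketch {
  public static void main(String[] args) throws InterruptedException {
    // Latch of 1 when a binary COPY header must be sent first; a latch of 0
    // would make await() return immediately, letting partitions stream freely.
    CountDownLatch headerSent = new CountDownLatch(1);

    Thread otherPartition =
        new Thread(
            () -> {
              try {
                headerSent.await(); // non-header partitions wait for the first row
                System.out.println("partition row sent");
              } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
              }
            });
    otherPartition.start();

    System.out.println("header + first row sent"); // the header-carrying row goes out first
    headerSent.countDown(); // now every other partition may send its rows
    otherPartition.join();
  }
}
```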
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ParseMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ParseMessage.java
index 611c878d2..0eddd36ce 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ParseMessage.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ParseMessage.java
@@ -64,6 +64,7 @@ public ParseMessage(ConnectionHandler connection) throws Exception {
}
this.statement =
createStatement(connection, name, parsedStatement, originalStatement, parameterDataTypes);
+ connection.maybeDetermineWellKnownClient(this);
}
/**
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/PasswordMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/PasswordMessage.java
index a289b44a4..45d29d627 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/PasswordMessage.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/PasswordMessage.java
@@ -70,7 +70,7 @@ public PasswordMessage(ConnectionHandler connection, Map<String, String> paramet
protected void sendPayload() throws Exception {
if (!useAuthentication()) {
new ErrorResponse(
- this.outputStream,
+ this.connection,
PGException.newBuilder("Received PasswordMessage while authentication is disabled.")
.setSQLState(SQLState.ProtocolViolation)
.setSeverity(Severity.ERROR)
@@ -83,7 +83,7 @@ protected void sendPayload() throws Exception {
Credentials credentials = checkCredentials(this.username, this.password);
if (credentials == null) {
new ErrorResponse(
- this.outputStream,
+ this.connection,
PGException.newBuilder("Invalid credentials received.")
.setHints(
"PGAdapter expects credentials to be one of the following:\n"
diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/StartupMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/StartupMessage.java
index c2c7f0016..be2832f4f 100644
--- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/StartupMessage.java
+++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/StartupMessage.java
@@ -26,7 +26,6 @@
import java.text.MessageFormat;
import java.util.Map;
import java.util.Map.Entry;
-import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable;
@@ -53,14 +52,6 @@ public StartupMessage(ConnectionHandler connection, int length) throws Exception
WellKnownClient wellKnownClient =
ClientAutoDetector.detectClient(this.parseParameterKeys(rawParameters), this.parameters);
connection.setWellKnownClient(wellKnownClient);
- if (wellKnownClient != WellKnownClient.UNSPECIFIED) {
- logger.log(
- Level.INFO,
- () ->
- String.format(
- "Well-known client %s detected for connection %d.",
- wellKnownClient, connection.getConnectionId()));
- }
} else {
connection.setWellKnownClient(WellKnownClient.UNSPECIFIED);
}
@@ -96,7 +87,8 @@ static void createConnectionAndSendStartupMessage(
connection.getConnectionMetadata().getOutputStream(),
connection.getConnectionId(),
connection.getSecret(),
- connection.getExtendedQueryProtocolHandler().getBackendConnection().getSessionState());
+ connection.getExtendedQueryProtocolHandler().getBackendConnection().getSessionState(),
+ connection.getWellKnownClient().createStartupNoticeResponses(connection));
connection.setStatus(ConnectionStatus.AUTHENTICATED);
}
diff --git a/src/test/csharp/pgadapter_npgsql_tests/npgsql_tests/NpgsqlTest.cs b/src/test/csharp/pgadapter_npgsql_tests/npgsql_tests/NpgsqlTest.cs
index 0d1665900..d499edb4c 100644
--- a/src/test/csharp/pgadapter_npgsql_tests/npgsql_tests/NpgsqlTest.cs
+++ b/src/test/csharp/pgadapter_npgsql_tests/npgsql_tests/NpgsqlTest.cs
@@ -60,6 +60,22 @@ public void TestShowServerVersion()
}
}
+ public void TestShowApplicationName()
+ {
+ using var connection = new NpgsqlConnection(ConnectionString);
+ connection.Open();
+
+ using var cmd = new NpgsqlCommand("show application_name", connection);
+ using (var reader = cmd.ExecuteReader())
+ {
+ while (reader.Read())
+ {
+ var applicationName = reader.GetString(0);
+ Console.WriteLine($"{applicationName}");
+ }
+ }
+ }
+
public void TestSelect1()
{
using var connection = new NpgsqlConnection(ConnectionString);
diff --git a/src/test/golang/pgadapter_pgx_tests/pgx.go b/src/test/golang/pgadapter_pgx_tests/pgx.go
index 4afd39e03..637fc0310 100644
--- a/src/test/golang/pgadapter_pgx_tests/pgx.go
+++ b/src/test/golang/pgadapter_pgx_tests/pgx.go
@@ -80,6 +80,27 @@ func TestSelect1(connString string) *C.char {
return nil
}
+//export TestShowApplicationName
+func TestShowApplicationName(connString string) *C.char {
+ ctx := context.Background()
+ conn, err := pgx.Connect(ctx, connString)
+ if err != nil {
+ return C.CString(err.Error())
+ }
+ defer conn.Close(ctx)
+
+ var value string
+ err = conn.QueryRow(ctx, "show application_name").Scan(&value)
+ if err != nil {
+ return C.CString(err.Error())
+ }
+ if g, w := value, "pgx"; g != w {
+ return C.CString(fmt.Sprintf("value mismatch\n Got: %v\nWant: %v", g, w))
+ }
+
+ return nil
+}
+
//export TestQueryWithParameter
func TestQueryWithParameter(connString string) *C.char {
ctx := context.Background()
diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/AbstractMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/AbstractMockServerTest.java
index e941b18e2..ef65e0a0d 100644
--- a/src/test/java/com/google/cloud/spanner/pgadapter/AbstractMockServerTest.java
+++ b/src/test/java/com/google/cloud/spanner/pgadapter/AbstractMockServerTest.java
@@ -366,11 +366,20 @@ public abstract class AbstractMockServerTest {
Status.INVALID_ARGUMENT.withDescription("Statement is invalid.").asRuntimeException();
protected static ResultSet createAllTypesResultSet(String columnPrefix) {
+ return createAllTypesResultSet(columnPrefix, false);
+ }
+
+ protected static ResultSet createAllTypesResultSet(String columnPrefix, boolean microsTimestamp) {
+ return createAllTypesResultSet("1", columnPrefix, microsTimestamp);
+ }
+
+ protected static ResultSet createAllTypesResultSet(
+ String id, String columnPrefix, boolean microsTimestamp) {
return com.google.spanner.v1.ResultSet.newBuilder()
.setMetadata(createAllTypesResultSetMetadata(columnPrefix))
.addRows(
ListValue.newBuilder()
- .addValues(Value.newBuilder().setStringValue("1").build())
+ .addValues(Value.newBuilder().setStringValue(id).build())
.addValues(Value.newBuilder().setBoolValue(true).build())
.addValues(
Value.newBuilder()
@@ -382,7 +391,12 @@ protected static ResultSet createAllTypesResultSet(String columnPrefix) {
.addValues(Value.newBuilder().setStringValue("100").build())
.addValues(Value.newBuilder().setStringValue("6.626").build())
.addValues(
- Value.newBuilder().setStringValue("2022-02-16T13:18:02.123456789Z").build())
+ Value.newBuilder()
+ .setStringValue(
+ microsTimestamp
+ ? "2022-02-16T13:18:02.123456Z"
+ : "2022-02-16T13:18:02.123456789Z")
+ .build())
.addValues(Value.newBuilder().setStringValue("2022-03-29").build())
.addValues(Value.newBuilder().setStringValue("test").build())
.addValues(Value.newBuilder().setStringValue("{\"key\": \"value\"}").build())
@@ -391,11 +405,18 @@ protected static ResultSet createAllTypesResultSet(String columnPrefix) {
}
protected static ResultSet createAllTypesNullResultSet(String columnPrefix) {
+ return createAllTypesNullResultSet(columnPrefix, null);
+ }
+
+ protected static ResultSet createAllTypesNullResultSet(String columnPrefix, Long colBigInt) {
return com.google.spanner.v1.ResultSet.newBuilder()
.setMetadata(createAllTypesResultSetMetadata(columnPrefix))
.addRows(
ListValue.newBuilder()
- .addValues(Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build())
+ .addValues(
+ colBigInt == null
+ ? Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build()
+ : Value.newBuilder().setStringValue(String.valueOf(colBigInt)).build())
.addValues(Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build())
.addValues(Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build())
.addValues(Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build())
diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/CopyInMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/CopyInMockServerTest.java
index 03f7a7b20..b370979ad 100644
--- a/src/test/java/com/google/cloud/spanner/pgadapter/CopyInMockServerTest.java
+++ b/src/test/java/com/google/cloud/spanner/pgadapter/CopyInMockServerTest.java
@@ -50,6 +50,7 @@
import com.google.spanner.v1.TypeCode;
import io.grpc.Status;
import java.io.BufferedReader;
+import java.io.ByteArrayInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
@@ -148,6 +149,27 @@ public void testCopyIn() throws SQLException, IOException {
assertEquals(3, mutation.getInsert().getColumnsCount());
}
+ @Test
+ public void testEndRecord() throws SQLException, IOException {
+ setupCopyInformationSchemaResults();
+
+ try (Connection connection = DriverManager.getConnection(createUrl())) {
+ CopyManager copyManager = new CopyManager(connection.unwrap(BaseConnection.class));
+ copyManager.copyIn("COPY users FROM STDIN;", new StringReader("5\t5\t5\n\\.\n7\t7\t7\n"));
+ }
+
+ List<CommitRequest> commitRequests = mockSpanner.getRequestsOfType(CommitRequest.class);
+ assertEquals(1, commitRequests.size());
+ CommitRequest commitRequest = commitRequests.get(0);
+ assertEquals(1, commitRequest.getMutationsCount());
+
+ Mutation mutation = commitRequest.getMutations(0);
+ assertEquals(OperationCase.INSERT, mutation.getOperationCase());
+ // We should only receive 1 row, as there is an end record in the middle of the stream.
+ assertEquals(1, mutation.getInsert().getValuesCount());
+ assertEquals(3, mutation.getInsert().getColumnsCount());
+ }
+
@Test
public void testCopyUpsert() throws SQLException, IOException {
setupCopyInformationSchemaResults();
@@ -386,6 +408,19 @@ public void testCopyIn_Small() throws SQLException, IOException {
}
}
+ @Test
+ public void testCopyIn_Empty() throws SQLException, IOException {
+ setupCopyInformationSchemaResults();
+
+ try (Connection connection = DriverManager.getConnection(createUrl())) {
+ PGConnection pgConnection = connection.unwrap(PGConnection.class);
+ CopyManager copyManager = pgConnection.getCopyAPI();
+ long copyCount =
+ copyManager.copyIn("copy all_types from stdin;", new ByteArrayInputStream(new byte[0]));
+ assertEquals(0L, copyCount);
+ }
+ }
+
@Test
public void testCopyIn_Nulls() throws SQLException, IOException {
setupCopyInformationSchemaResults();
diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/CopyOutMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/CopyOutMockServerTest.java
index fe7da284f..70fb1ce66 100644
--- a/src/test/java/com/google/cloud/spanner/pgadapter/CopyOutMockServerTest.java
+++ b/src/test/java/com/google/cloud/spanner/pgadapter/CopyOutMockServerTest.java
@@ -35,6 +35,7 @@
import com.google.cloud.spanner.connection.RandomResultSetGenerator;
import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata;
import com.google.cloud.spanner.pgadapter.session.SessionState;
+import com.google.cloud.spanner.pgadapter.statements.BackendConnection;
import com.google.cloud.spanner.pgadapter.statements.CopyStatement.Format;
import com.google.cloud.spanner.pgadapter.utils.CopyInParser;
import com.google.cloud.spanner.pgadapter.utils.CopyRecord;
@@ -290,7 +291,13 @@ public void testCopyOutCsv() throws SQLException, IOException {
@Test
public void testCopyOutCsvWithHeader() throws SQLException, IOException {
mockSpanner.putStatementResult(
- StatementResult.query(Statement.of("select * from all_types"), ALL_TYPES_RESULTSET));
+ StatementResult.query(
+ Statement.of("select * from all_types"),
+ ALL_TYPES_RESULTSET
+ .toBuilder()
+ .addRows(ALL_TYPES_RESULTSET.getRows(0))
+ .addRows(ALL_TYPES_NULLS_RESULTSET.getRows(0))
+ .build()));
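+ // The result set now contains the regular test row twice plus an all-nulls row; with '-' as
+ // the delimiter, the all-nulls row renders as a line of delimiters only.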
try (Connection connection = DriverManager.getConnection(createUrl())) {
connection.createStatement().execute("set time zone 'UTC'");
@@ -301,7 +308,9 @@ public void testCopyOutCsvWithHeader() throws SQLException, IOException {
assertEquals(
"col_bigint-col_bool-col_bytea-col_float8-col_int-col_numeric-col_timestamptz-col_date-col_varchar-col_jsonb\n"
- + "1-t-\\x74657374-3.14-100-6.626-\"2022-02-16 13:18:02.123456+00\"-\"2022-03-29\"-test-\"{~\"key~\": ~\"value~\"}\"\n",
+ + "1-t-\\x74657374-3.14-100-6.626-\"2022-02-16 13:18:02.123456+00\"-\"2022-03-29\"-test-\"{~\"key~\": ~\"value~\"}\"\n"
+ + "1-t-\\x74657374-3.14-100-6.626-\"2022-02-16 13:18:02.123456+00\"-\"2022-03-29\"-test-\"{~\"key~\": ~\"value~\"}\"\n"
+ + "---------\n",
writer.toString());
}
}
@@ -487,6 +496,47 @@ public void testCopyOutBinaryPsql() throws Exception {
assertEquals(Value.string("test"), record.getValue(Type.string(), 8));
}
+ @Test
+ public void testCopyOutBinaryEmptyPsql() throws Exception {
+ assumeTrue("This test requires psql", isPsqlAvailable());
+ mockSpanner.putStatementResult(
+ StatementResult.query(
+ Statement.of("select * from all_types"),
+ ALL_TYPES_RESULTSET.toBuilder().clearRows().build()));
+
+ ProcessBuilder builder = new ProcessBuilder();
+ builder.command(
+ "psql",
+ "-h",
+ (useDomainSocket ? "/tmp" : "localhost"),
+ "-p",
+ String.valueOf(pgServer.getLocalPort()),
+ "-c",
+ "copy all_types to stdout binary");
+ Process process = builder.start();
+ StringBuilder errorBuilder = new StringBuilder();
+ try (Scanner scanner = new Scanner(new InputStreamReader(process.getErrorStream()))) {
+ while (scanner.hasNextLine()) {
+ errorBuilder.append(scanner.nextLine()).append('\n');
+ }
+ }
+ PipedOutputStream pipedOutputStream = new PipedOutputStream();
+ PipedInputStream inputStream = new PipedInputStream(pipedOutputStream, 1 << 16);
+ SessionState sessionState = mock(SessionState.class);
+ CopyInParser copyParser =
+ CopyInParser.create(sessionState, Format.BINARY, null, inputStream, false);
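+ // Feed psql's stdout back through PGAdapter's binary COPY parser; an empty result set should
+ // produce only the binary header and trailer, and no records.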
+ int b;
+ while ((b = process.getInputStream().read()) != -1) {
+ pipedOutputStream.write(b);
+ }
+ int res = process.waitFor();
+ assertEquals("", errorBuilder.toString());
+ assertEquals(0, res);
+
+ Iterator<CopyRecord> iterator = copyParser.iterator();
+ assertFalse(iterator.hasNext());
+ }
+
@Test
public void testCopyOutNullsBinaryPsql() throws Exception {
assumeTrue("This test requires psql", isPsqlAvailable());
@@ -536,30 +586,64 @@ public void testCopyOutNullsBinaryPsql() throws Exception {
@Test
public void testCopyOutPartitioned() throws SQLException, IOException {
- final int expectedRowCount = 100;
- RandomResultSetGenerator randomResultSetGenerator =
- new RandomResultSetGenerator(expectedRowCount, Dialect.POSTGRESQL);
- com.google.spanner.v1.ResultSet resultSet = randomResultSetGenerator.generate();
- mockSpanner.putStatementResult(
- StatementResult.query(Statement.of("select * from random"), resultSet));
+ for (int expectedRowCount : new int[] {0, 1, 2, 3, 5, BackendConnection.MAX_PARTITIONS, 100}) {
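+ // The row counts cover the boundary cases for partitioned copy: an empty result, fewer rows
+ // than the maximum number of partitions, exactly MAX_PARTITIONS rows, and more.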
+ RandomResultSetGenerator randomResultSetGenerator =
+ new RandomResultSetGenerator(expectedRowCount, Dialect.POSTGRESQL);
+ com.google.spanner.v1.ResultSet resultSet = randomResultSetGenerator.generate();
+ mockSpanner.putStatementResult(
+ StatementResult.query(Statement.of("select * from random"), resultSet));
+
+ try (Connection connection = DriverManager.getConnection(createUrl())) {
+ CopyManager copyManager = new CopyManager(connection.unwrap(BaseConnection.class));
+ StringWriter writer = new StringWriter();
+ long rows = copyManager.copyOut("COPY random TO STDOUT", writer);
+
+ assertEquals(expectedRowCount, rows);
+
+ try (Scanner scanner = new Scanner(writer.toString())) {
+ int lineCount = 0;
+ while (scanner.hasNextLine()) {
+ lineCount++;
+ String line = scanner.nextLine();
+ String[] columns = line.split("\t");
+ int index = findIndex(resultSet, columns);
+ assertNotEquals(String.format("Row %d not found: %s", lineCount, line), -1, index);
+ }
+ assertEquals(expectedRowCount, lineCount);
+ }
+ }
+ }
+ }
- try (Connection connection = DriverManager.getConnection(createUrl())) {
- CopyManager copyManager = new CopyManager(connection.unwrap(BaseConnection.class));
- StringWriter writer = new StringWriter();
- long rows = copyManager.copyOut("COPY random TO STDOUT", writer);
-
- assertEquals(expectedRowCount, rows);
-
- try (Scanner scanner = new Scanner(writer.toString())) {
- int lineCount = 0;
- while (scanner.hasNextLine()) {
- lineCount++;
- String line = scanner.nextLine();
- String[] columns = line.split("\t");
- int index = findIndex(resultSet, columns);
- assertNotEquals(String.format("Row %d not found: %s", lineCount, line), -1, index);
+ @Test
+ public void testCopyOutPartitionedBinary() throws SQLException, IOException {
+ for (int expectedRowCount : new int[] {0, 1, 2, 3, 5, BackendConnection.MAX_PARTITIONS, 100}) {
+ RandomResultSetGenerator randomResultSetGenerator =
+ new RandomResultSetGenerator(expectedRowCount, Dialect.POSTGRESQL);
+ com.google.spanner.v1.ResultSet resultSet = randomResultSetGenerator.generate();
+ mockSpanner.putStatementResult(
+ StatementResult.query(Statement.of("select * from random"), resultSet));
+
+ try (Connection connection = DriverManager.getConnection(createUrl())) {
+ CopyManager copyManager = new CopyManager(connection.unwrap(BaseConnection.class));
+ PipedOutputStream pipedOutputStream = new PipedOutputStream();
+ PipedInputStream inputStream = new PipedInputStream(pipedOutputStream, 1 << 20);
+ SessionState sessionState = mock(SessionState.class);
+ CopyInParser copyParser =
+ CopyInParser.create(sessionState, Format.BINARY, null, inputStream, false);
+ long rows = copyManager.copyOut("COPY random TO STDOUT (format binary)", pipedOutputStream);
+
+ assertEquals(expectedRowCount, rows);
+
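+ // Parse the binary output with PGAdapter's own CopyInParser and verify that every record
+ // corresponds to a row in the randomly generated result set.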
+ Iterator<CopyRecord> iterator = copyParser.iterator();
+ int recordCount = 0;
+ while (iterator.hasNext()) {
+ recordCount++;
+ CopyRecord record = iterator.next();
+ int index = findIndex(resultSet, record);
+ assertNotEquals(String.format("Row %d not found: %s", recordCount, record), -1, index);
}
- assertEquals(expectedRowCount, lineCount);
+ assertEquals(expectedRowCount, recordCount);
}
}
}
@@ -627,4 +711,45 @@ static int findIndex(com.google.spanner.v1.ResultSet resultSet, String[] cols) {
}
return -1;
}
+
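+ // Returns the index of the row in the result set that matches the given record, or -1 if there
+ // is no match. Only null patterns, string and bool values are compared, which is enough to tell
+ // the randomly generated rows apart.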
+ static int findIndex(com.google.spanner.v1.ResultSet resultSet, CopyRecord record) {
+ for (int index = 0; index < resultSet.getRowsCount(); index++) {
+ boolean nullValuesEqual = true;
+ for (int colIndex = 0; colIndex < record.numColumns(); colIndex++) {
+ if (record.isNull(colIndex)
+ != resultSet.getRows(index).getValues(colIndex).hasNullValue()) {
+ nullValuesEqual = false;
+ break;
+ }
+ }
+ if (!nullValuesEqual) {
+ continue;
+ }
+
+ boolean valuesEqual = true;
+ for (int colIndex = 0; colIndex < record.numColumns(); colIndex++) {
+ if (!resultSet.getRows(index).getValues(colIndex).hasNullValue()) {
+ if (resultSet.getMetadata().getRowType().getFields(colIndex).getType().getCode()
+ == TypeCode.STRING
+ && !record
+ .getValue(Type.string(), colIndex)
+ .getString()
+ .equals(resultSet.getRows(index).getValues(colIndex).getStringValue())) {
+ valuesEqual = false;
+ break;
+ }
+ if (resultSet.getRows(index).getValues(colIndex).hasBoolValue()
+ && record.getValue(Type.bool(), colIndex).getBool()
+ != resultSet.getRows(index).getValues(colIndex).getBoolValue()) {
+ valuesEqual = false;
+ break;
+ }
+ }
+ }
+ if (valuesEqual) {
+ return index;
+ }
+ }
+ return -1;
+ }
}
diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/EmulatedPgbenchMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/EmulatedPgbenchMockServerTest.java
new file mode 100644
index 000000000..9e477b2d8
--- /dev/null
+++ b/src/test/java/com/google/cloud/spanner/pgadapter/EmulatedPgbenchMockServerTest.java
@@ -0,0 +1,114 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.cloud.spanner.pgadapter;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThrows;
+
+import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult;
+import com.google.cloud.spanner.Statement;
+import com.google.cloud.spanner.pgadapter.error.SQLState;
+import com.google.cloud.spanner.pgadapter.utils.ClientAutoDetector;
+import com.google.cloud.spanner.pgadapter.utils.ClientAutoDetector.WellKnownClient;
+import io.grpc.Status;
+import java.nio.charset.StandardCharsets;
+import java.sql.Connection;
+import java.sql.DriverManager;
+import java.sql.SQLException;
+import java.sql.SQLWarning;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.postgresql.PGProperty;
+import org.postgresql.util.PSQLException;
+import org.postgresql.util.PSQLWarning;
+
+@RunWith(JUnit4.class)
+public class EmulatedPgbenchMockServerTest extends AbstractMockServerTest {
+
+ @BeforeClass
+ public static void loadPgJdbcDriver() throws Exception {
+ // Make sure the PG JDBC driver is loaded.
+ Class.forName("org.postgresql.Driver");
+ }
+
+ /**
+ * Creates a JDBC connection string that instructs the PG JDBC driver to use the simple query
+ * mode. It also adds 'pgbench' as the application name, which makes PGAdapter automatically
+ * recognize the connection as a pgbench connection. Assuming a minimum server version of 9.0
+ * makes the driver send the application name as a startup parameter.
+ */
+ private String createUrl() {
+ return String.format(
+ "jdbc:postgresql://localhost:%d/db?preferQueryMode=simple&%s=pgbench&%s=090000",
+ pgServer.getLocalPort(),
+ PGProperty.APPLICATION_NAME.getName(),
+ PGProperty.ASSUME_MIN_SERVER_VERSION.getName());
+ }
+
+ @Test
+ public void testClientDetectionAndHint() throws Exception {
+ WellKnownClient.PGBENCH.reset();
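+ // Resetting the client state ensures that the usage hint is emitted again for the first
+ // connection in this test.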
+ // Verify that we get the notice response that indicates that the client was automatically
+ // detected.
+ try (Connection connection = DriverManager.getConnection(createUrl())) {
+ SQLWarning warning = connection.getWarnings();
+ assertNotNull(warning);
+ PSQLWarning psqlWarning = (PSQLWarning) warning;
+ assertNotNull(psqlWarning.getServerErrorMessage());
+ assertNotNull(psqlWarning.getServerErrorMessage().getSQLState());
+ assertArrayEquals(
+ SQLState.Success.getBytes(),
+ psqlWarning.getServerErrorMessage().getSQLState().getBytes(StandardCharsets.UTF_8));
+ assertEquals(
+ "Detected connection from pgbench", psqlWarning.getServerErrorMessage().getMessage());
+ assertEquals(
+ ClientAutoDetector.PGBENCH_USAGE_HINT + "\n",
+ psqlWarning.getServerErrorMessage().getHint());
+
+ assertNull(warning.getNextWarning());
+ }
+
+ // Verify that creating a second connection directly afterwards with pgbench does not repeat the
+ // hint.
+ try (Connection connection = DriverManager.getConnection(createUrl())) {
+ SQLWarning warning = connection.getWarnings();
+ assertNull(warning);
+ }
+ }
+
+ @Test
+ public void testErrorHint() throws SQLException {
+ // Verify that any error message includes the pgbench usage hint.
+ String sql = "select * from foo where bar=1";
+ mockSpanner.putStatementResult(
+ StatementResult.exception(
+ Statement.of(sql),
+ Status.INVALID_ARGUMENT.withDescription("test error").asRuntimeException()));
+ try (Connection connection = DriverManager.getConnection(createUrl())) {
+ PSQLException exception =
+ assertThrows(PSQLException.class, () -> connection.createStatement().execute(sql));
+ assertNotNull(exception.getServerErrorMessage());
+ assertEquals(
+ String.format("test error - Statement: '%s'", sql),
+ exception.getServerErrorMessage().getMessage());
+ assertEquals(
+ ClientAutoDetector.PGBENCH_USAGE_HINT, exception.getServerErrorMessage().getHint());
+ }
+ }
+}
diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/EmulatedPsqlMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/EmulatedPsqlMockServerTest.java
index 70660518e..86d53ee77 100644
--- a/src/test/java/com/google/cloud/spanner/pgadapter/EmulatedPsqlMockServerTest.java
+++ b/src/test/java/com/google/cloud/spanner/pgadapter/EmulatedPsqlMockServerTest.java
@@ -195,6 +195,18 @@ public void testConnectToNonExistingDatabase() {
}
}
+ @Test
+ public void testShowApplicationName() throws SQLException {
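+ // The connection is detected as psql, so PGAdapter automatically fills in 'psql' as the
+ // application name.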
+ try (Connection connection = DriverManager.getConnection(createUrl("my-db"))) {
+ try (ResultSet resultSet =
+ connection.createStatement().executeQuery("show application_name")) {
+ assertTrue(resultSet.next());
+ assertEquals("psql", resultSet.getString(1));
+ assertFalse(resultSet.next());
+ }
+ }
+ }
+
@Test
public void testConnectToNonExistingInstance() {
try {
diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/ITPsqlTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/ITPsqlTest.java
index fc4adcb25..5984abb8e 100644
--- a/src/test/java/com/google/cloud/spanner/pgadapter/ITPsqlTest.java
+++ b/src/test/java/com/google/cloud/spanner/pgadapter/ITPsqlTest.java
@@ -26,6 +26,7 @@
import com.google.cloud.spanner.Database;
import com.google.cloud.spanner.KeySet;
import com.google.cloud.spanner.Mutation;
+import com.google.cloud.spanner.pgadapter.statements.CopyStatement.Format;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
@@ -47,10 +48,12 @@
import java.time.OffsetDateTime;
import java.time.ZoneId;
import java.time.ZoneOffset;
+import java.time.zone.ZoneRulesException;
import java.util.Arrays;
import java.util.Collections;
import java.util.Random;
import java.util.Scanner;
+import java.util.logging.Logger;
import java.util.stream.Collectors;
import org.junit.AfterClass;
import org.junit.BeforeClass;
@@ -68,6 +71,8 @@
@Category(IntegrationTest.class)
@RunWith(JUnit4.class)
public class ITPsqlTest implements IntegrationTest {
+ private static final Logger logger = Logger.getLogger(ITPsqlTest.class.getName());
+
private static final PgAdapterTestEnv testEnv = new PgAdapterTestEnv();
private static Database database;
private static boolean allAssumptionsPassed = false;
@@ -395,9 +400,28 @@ public void testPrepareExecuteDeallocate() throws IOException, InterruptedExcept
}
@Test
- public void testCopyFromPostgreSQLToCloudSpanner() throws Exception {
+ public void testSetOperationWithOrderBy() throws IOException, InterruptedException {
+ // TODO: Remove this assumption once set operations with ORDER BY are supported in production.
+ assumeTrue(
+ testEnv.getSpannerUrl().equals("https://staging-wrenchworks.sandbox.googleapis.com"));
+
+ Tuple<String, String> result =
+ runUsingPsql(
+ "select * from (select 1) one union all select * from (select 2) two order by 1");
+ String output = result.x(), errors = result.y();
+ assertEquals("", errors);
+ assertEquals(" ?column? \n----------\n 1\n 2\n(2 rows)\n", output);
+ }
+
+ /**
+ * This test copies data back and forth between PostgreSQL and Cloud Spanner and verifies that the
+ * contents are equal after the COPY operation in both directions.
+ */
+ @Test
+ public void testCopyBetweenPostgreSQLAndCloudSpanner() throws Exception {
int numRows = 100;
+ logger.info("Copying initial data to PG");
// Generate 99 random rows.
copyRandomRowsToPostgreSQL(numRows - 1);
// Also add one row with all nulls to ensure that nulls are also copied correctly.
@@ -410,6 +434,7 @@ public void testCopyFromPostgreSQLToCloudSpanner() throws Exception {
}
}
+ logger.info("Verifying initial data in PG");
// Verify that we have 100 rows in PostgreSQL.
try (Connection connection =
DriverManager.getConnection(createJdbcUrlForLocalPg(), POSTGRES_USER, POSTGRES_PASSWORD)) {
@@ -421,12 +446,14 @@ public void testCopyFromPostgreSQLToCloudSpanner() throws Exception {
}
}
- // Execute the COPY tests in both binary and text mode.
- for (boolean binary : new boolean[] {false, true}) {
+ // Execute the COPY tests in binary, csv and text mode.
+ for (Format format : Format.values()) {
+ logger.info("Testing format: " + format);
// Make sure the all_types table on Cloud Spanner is empty.
String databaseId = database.getId().getDatabase();
testEnv.write(databaseId, Collections.singleton(Mutation.delete("all_types", KeySet.all())));
+ logger.info("Copying rows to Spanner");
// COPY the rows to Cloud Spanner.
ProcessBuilder builder = new ProcessBuilder();
builder.command(
@@ -441,8 +468,9 @@ public void testCopyFromPostgreSQLToCloudSpanner() throws Exception {
+ POSTGRES_USER
+ " -d "
+ POSTGRES_DATABASE
- + " -c \"copy all_types to stdout"
- + (binary ? " binary \" " : "\" ")
+ + " -c \"copy all_types to stdout (format "
+ + format
+ + ") \" "
+ " | psql "
+ " -h "
+ (POSTGRES_HOST.startsWith("/") ? "/tmp" : testEnv.getPGAdapterHost())
@@ -450,14 +478,16 @@ public void testCopyFromPostgreSQLToCloudSpanner() throws Exception {
+ testEnv.getPGAdapterPort()
+ " -d "
+ database.getId().getDatabase()
- + " -c \"copy all_types from stdin "
- + (binary ? "binary" : "")
+ + " -c \"copy all_types from stdin (format "
+ + format
+ + ")"
+ ";\"\n");
setPgPassword(builder);
Process process = builder.start();
int res = process.waitFor();
assertEquals(0, res);
+ logger.info("Verifying data in Spanner");
// Verify that we now also have 100 rows in Spanner.
try (Connection connection = DriverManager.getConnection(createJdbcUrlForPGAdapter())) {
try (ResultSet resultSet =
@@ -468,9 +498,11 @@ public void testCopyFromPostgreSQLToCloudSpanner() throws Exception {
}
}
+ logger.info("Comparing table contents");
// Verify that the rows in both databases are equal.
compareTableContents();
+ logger.info("Deleting all data in PG");
// Remove all rows in the table in the local PostgreSQL database and then copy everything from
// Cloud Spanner to PostgreSQL.
try (Connection connection =
@@ -479,14 +511,16 @@ public void testCopyFromPostgreSQLToCloudSpanner() throws Exception {
assertEquals(numRows, connection.createStatement().executeUpdate("delete from all_types"));
}
+ logger.info("Copying rows to PG");
// COPY the rows from Cloud Spanner to PostgreSQL.
ProcessBuilder copyToPostgresBuilder = new ProcessBuilder();
copyToPostgresBuilder.command(
"bash",
"-c",
"psql"
- + " -c \"copy all_types to stdout"
- + (binary ? " binary \" " : "\" ")
+ + " -c \"copy all_types to stdout (format "
+ + format
+ + ") \" "
+ " -h "
+ (POSTGRES_HOST.startsWith("/") ? "/tmp" : testEnv.getPGAdapterHost())
+ " -p "
@@ -502,8 +536,9 @@ public void testCopyFromPostgreSQLToCloudSpanner() throws Exception {
+ POSTGRES_USER
+ " -d "
+ POSTGRES_DATABASE
- + " -c \"copy all_types from stdin "
- + (binary ? "binary" : "")
+ + " -c \"copy all_types from stdin (format "
+ + format
+ + ")"
+ ";\"\n");
setPgPassword(copyToPostgresBuilder);
Process copyToPostgresProcess = copyToPostgresBuilder.start();
@@ -518,11 +553,118 @@ public void testCopyFromPostgreSQLToCloudSpanner() throws Exception {
assertEquals("", errors.toString());
assertEquals(0, copyToPostgresResult);
+ logger.info("Compare table contents");
// Compare table contents again.
compareTableContents();
}
}
+ /** This test verifies that we can copy an empty table between PostgreSQL and Cloud Spanner. */
+ @Test
+ public void testCopyEmptyTableBetweenCloudSpannerAndPostgreSQL() throws Exception {
+ logger.info("Deleting all data in PG");
+ // Remove all rows in the table in the local PostgreSQL database.
+ try (Connection connection =
+ DriverManager.getConnection(createJdbcUrlForLocalPg(), POSTGRES_USER, POSTGRES_PASSWORD)) {
+ connection.createStatement().executeUpdate("delete from all_types");
+ }
+
+ // Execute the COPY tests in binary, csv and text mode.
+ for (Format format : Format.values()) {
+ logger.info("Testing format: " + format);
+ // Make sure the all_types table on Cloud Spanner is empty.
+ String databaseId = database.getId().getDatabase();
+ testEnv.write(databaseId, Collections.singleton(Mutation.delete("all_types", KeySet.all())));
+
+ logger.info("Copy empty table to CS");
+ // COPY the empty table to Cloud Spanner.
+ ProcessBuilder builder = new ProcessBuilder();
+ builder.command(
+ "bash",
+ "-c",
+ "psql"
+ + " -h "
+ + POSTGRES_HOST
+ + " -p "
+ + POSTGRES_PORT
+ + " -U "
+ + POSTGRES_USER
+ + " -d "
+ + POSTGRES_DATABASE
+ + " -c \"copy all_types to stdout (format "
+ + format
+ + ") \" "
+ + " | psql "
+ + " -h "
+ + (POSTGRES_HOST.startsWith("/") ? "/tmp" : testEnv.getPGAdapterHost())
+ + " -p "
+ + testEnv.getPGAdapterPort()
+ + " -d "
+ + database.getId().getDatabase()
+ + " -c \"copy all_types from stdin (format "
+ + format
+ + ")"
+ + ";\"\n");
+ setPgPassword(builder);
+ Process process = builder.start();
+ int res = process.waitFor();
+ assertEquals(0, res);
+
+ logger.info("Verify that CS is empty");
+ // Verify that we still have 0 rows in Cloud Spanner.
+ try (Connection connection = DriverManager.getConnection(createJdbcUrlForPGAdapter())) {
+ try (ResultSet resultSet =
+ connection.createStatement().executeQuery("select count(*) from all_types")) {
+ assertTrue(resultSet.next());
+ assertEquals(0, resultSet.getLong(1));
+ assertFalse(resultSet.next());
+ }
+ }
+
+ logger.info("Copy empty table to PG");
+ // COPY the empty table from Cloud Spanner to PostgreSQL.
+ ProcessBuilder copyToPostgresBuilder = new ProcessBuilder();
+ copyToPostgresBuilder.command(
+ "bash",
+ "-c",
+ "psql"
+ + " -c \"copy all_types to stdout (format "
+ + format
+ + ") \""
+ + " -h "
+ + (POSTGRES_HOST.startsWith("/") ? "/tmp" : testEnv.getPGAdapterHost())
+ + " -p "
+ + testEnv.getPGAdapterPort()
+ + " -d "
+ + database.getId().getDatabase()
+ + " | psql "
+ + " -h "
+ + POSTGRES_HOST
+ + " -p "
+ + POSTGRES_PORT
+ + " -U "
+ + POSTGRES_USER
+ + " -d "
+ + POSTGRES_DATABASE
+ + " -c \"copy all_types from stdin (format "
+ + format
+ + ")"
+ + ";\"\n");
+ setPgPassword(copyToPostgresBuilder);
+ Process copyToPostgresProcess = copyToPostgresBuilder.start();
+ InputStream errorStream = copyToPostgresProcess.getErrorStream();
+ int copyToPostgresResult = copyToPostgresProcess.waitFor();
+ StringBuilder errors = new StringBuilder();
+ try (Scanner scanner = new Scanner(new InputStreamReader(errorStream))) {
+ while (scanner.hasNextLine()) {
+ errors.append(scanner.nextLine()).append("\n");
+ }
+ }
+ assertEquals("", errors.toString());
+ assertEquals(0, copyToPostgresResult);
+ }
+ }
+
@Test
public void testTimestamptzParsing() throws Exception {
final int numTests = 10;
@@ -550,9 +692,13 @@ public void testTimestamptzParsing() throws Exception {
"select name from pg_timezone_names where not name like '%%posix%%' and not name like 'Factory' offset %d limit 1",
random.nextInt(numTimezones)))) {
assertTrue(resultSet.next());
- timezone = resultSet.getString(1);
- if (!PROBLEMATIC_ZONE_IDS.contains(ZoneId.of(timezone))) {
- break;
+ try {
+ timezone = resultSet.getString(1);
+ if (!PROBLEMATIC_ZONE_IDS.contains(ZoneId.of(timezone))) {
+ break;
+ }
+ } catch (ZoneRulesException ignore) {
+ // Skip and try a different one if it is not a supported timezone on this system.
}
}
}
@@ -634,7 +780,15 @@ public void testTimestamptzParsing() throws Exception {
// Mexico abolished DST in 2022, but not all databases contain this information.
ZoneId.of("America/Chihuahua"),
// Jordan abolished DST in 2022, but not all databases contain this information.
- ZoneId.of("Asia/Amman"));
+ ZoneId.of("Asia/Amman"),
+ // Iran observed DST in 1978. Not all databases agree on this.
+ ZoneId.of("Asia/Tehran"),
+ // Rankin_Inlet did not observe DST in 1970-1979, but not all databases agree.
+ ZoneId.of("America/Rankin_Inlet"),
+ // Pangnirtung did not observe DST in 1970-1979, but not all databases agree.
+ ZoneId.of("America/Pangnirtung"),
+ // Niue switched from -11:30 to -11 in 1978. Not all JDKs know that.
+ ZoneId.of("Pacific/Niue"));
private LocalDate generateRandomLocalDate() {
return LocalDate.ofEpochDay(random.nextInt(365 * 100));
diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/JdbcMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/JdbcMockServerTest.java
index 9fab59cfa..65ff4a7ca 100644
--- a/src/test/java/com/google/cloud/spanner/pgadapter/JdbcMockServerTest.java
+++ b/src/test/java/com/google/cloud/spanner/pgadapter/JdbcMockServerTest.java
@@ -93,6 +93,7 @@
import org.postgresql.PGStatement;
import org.postgresql.core.Oid;
import org.postgresql.jdbc.PgStatement;
+import org.postgresql.util.PGobject;
import org.postgresql.util.PSQLException;
@RunWith(Parameterized.class)
@@ -135,7 +136,7 @@ private String createUrl() {
}
private String getExpectedInitialApplicationName() {
- return pgVersion.equals("1.0") ? null : "PostgreSQL JDBC Driver";
+ return pgVersion.equals("1.0") ? "jdbc" : "PostgreSQL JDBC Driver";
}
@Test
@@ -165,6 +166,23 @@ public void testQuery() throws SQLException {
}
}
+ @Test
+ public void testShowApplicationName() throws SQLException {
+ try (Connection connection = DriverManager.getConnection(createUrl())) {
+ try (ResultSet resultSet =
+ connection.createStatement().executeQuery("show application_name")) {
+ assertTrue(resultSet.next());
+ // If the PG version is 1.0, the JDBC driver thinks that the server does not support the
+ // application_name property and does not send any value. That means that PGAdapter fills it
+ // in automatically based on the client that is detected.
+ // Otherwise, the JDBC driver includes its own name, and that is not overwritten by
+ // PGAdapter.
+ assertEquals(getExpectedInitialApplicationName(), resultSet.getString(1));
+ assertFalse(resultSet.next());
+ }
+ }
+ }
+
@Test
public void testPreparedStatementParameterMetadata() throws SQLException {
String sql = "SELECT * FROM foo WHERE id=? or value=?";
@@ -637,6 +655,41 @@ public void testQueryWithLegacyDateParameter() throws SQLException {
}
}
+ @Test
+ public void testCharParam() throws SQLException {
+ String sql = "insert into foo values ($1)";
+ mockSpanner.putStatementResult(StatementResult.update(Statement.of(sql), 1L));
+ String jdbcSql = "insert into foo values (?)";
+ try (Connection connection = DriverManager.getConnection(createUrl())) {
+ try (PreparedStatement preparedStatement = connection.prepareStatement(jdbcSql)) {
+ PGobject pgObject = new PGobject();
+ pgObject.setType("char");
+ pgObject.setValue("a");
+ preparedStatement.setObject(1, pgObject);
+ assertEquals(1, preparedStatement.executeUpdate());
+ }
+ }
+
+ List<ParseMessage> parseMessages =
+ pgServer.getDebugMessages().stream()
+ .filter(message -> message instanceof ParseMessage)
+ .map(message -> (ParseMessage) message)
+ .collect(Collectors.toList());
+ assertFalse(parseMessages.isEmpty());
+ ParseMessage parseMessage = parseMessages.get(parseMessages.size() - 1);
+ assertEquals(1, parseMessage.getStatement().getGivenParameterDataTypes().length);
+ assertEquals(Oid.CHAR, parseMessage.getStatement().getGivenParameterDataTypes()[0]);
+
+ List<ExecuteSqlRequest> executeSqlRequests =
+ mockSpanner.getRequestsOfType(ExecuteSqlRequest.class);
+ assertEquals(1, executeSqlRequests.size());
+ ExecuteSqlRequest request = executeSqlRequests.get(0);
+ // Oid.CHAR is not a recognized type in PGAdapter, so no parameter type is included in the
+ // request and the value is sent as an untyped string parameter.
+ assertEquals(0, request.getParamTypesCount());
+ assertEquals(1, request.getParams().getFieldsCount());
+ assertEquals("a", request.getParams().getFieldsMap().get("p1").getStringValue());
+ }
+
@Test
public void testAutoDescribedStatementsAreReused() throws SQLException {
String jdbcSql = "select col_date from all_types where col_date=?";
@@ -1794,9 +1847,7 @@ public void testPreparedStatementReturning() throws SQLException {
.bind("p9")
.to("test")
.bind("p10")
- // TODO: Change to jsonb when https://github.com/googleapis/java-spanner/pull/2182
- // has been merged.
- .to(com.google.cloud.spanner.Value.json("{\"key\": \"value\"}"))
+ .to(com.google.cloud.spanner.Value.pgJsonb("{\"key\": \"value\"}"))
.build(),
com.google.spanner.v1.ResultSet.newBuilder()
.setMetadata(ALL_TYPES_METADATA)
@@ -2880,6 +2931,21 @@ public void testSetTimeZone() throws SQLException {
}
}
+ @Test
+ public void testSetTimeZoneEST() throws SQLException {
+ try (Connection connection = DriverManager.getConnection(createUrl())) {
+ // Java considers 'EST' to always be '-05:00'; that is, it never observes DST.
+ connection.createStatement().execute("set time zone 'EST'");
+ verifySettingValue(connection, "timezone", "-05:00");
+ // 'EST5EDT' is the ID of the timezone that observes DST.
+ connection.createStatement().execute("set time zone 'EST5EDT'");
+ verifySettingValue(connection, "timezone", "EST5EDT");
+ // 'America/New_York' is the full name of the geographical timezone.
+ connection.createStatement().execute("set time zone 'America/New_York'");
+ verifySettingValue(connection, "timezone", "America/New_York");
+ }
+ }
+
@Test
public void testSetTimeZoneToServerDefault() throws SQLException {
try (Connection connection = DriverManager.getConnection(createUrl())) {
diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/PsqlMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/PsqlMockServerTest.java
index 8f0e4e935..8d177e938 100644
--- a/src/test/java/com/google/cloud/spanner/pgadapter/PsqlMockServerTest.java
+++ b/src/test/java/com/google/cloud/spanner/pgadapter/PsqlMockServerTest.java
@@ -14,16 +14,23 @@
package com.google.cloud.spanner.pgadapter;
+import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assume.assumeTrue;
import com.google.cloud.spanner.MockSpannerServiceImpl;
+import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult;
+import com.google.cloud.spanner.Statement;
+import com.google.cloud.spanner.pgadapter.statements.CopyToStatement;
import com.google.cloud.spanner.pgadapter.wireprotocol.SSLMessage;
import com.google.common.collect.ImmutableList;
+import com.google.common.io.ByteStreams;
import com.google.spanner.admin.database.v1.UpdateDatabaseDdlRequest;
import java.io.BufferedReader;
+import java.io.ByteArrayOutputStream;
+import java.io.DataOutputStream;
import java.io.File;
import java.io.InputStreamReader;
import java.util.List;
@@ -141,4 +148,91 @@ public void testSSLRequire() throws Exception {
assertEquals(
1L, pgServer.getDebugMessages().stream().filter(m -> m instanceof SSLMessage).count());
}
+
+ @Test
+ public void testCopyToBinaryPsql() throws Exception {
+ assumeTrue("This test requires psql to be installed", isPsqlAvailable());
+
+ ProcessBuilder builder = new ProcessBuilder();
+ String[] psqlCommand =
+ new String[] {
+ "psql",
+ "-h",
+ "localhost",
+ "-p",
+ String.valueOf(pgServer.getLocalPort()),
+ "-c",
+ "COPY (SELECT 1) TO STDOUT (FORMAT BINARY)"
+ };
+ builder.command(psqlCommand);
+ Process process = builder.start();
+ String errors;
+
+ try (BufferedReader errorReader =
+ new BufferedReader(new InputStreamReader(process.getErrorStream()))) {
+ errors = errorReader.lines().collect(Collectors.joining("\n"));
+ }
+ byte[] copyOutput = ByteStreams.toByteArray(process.getInputStream());
+
+ assertEquals("", errors);
+ int res = process.waitFor();
+ assertEquals(0, res);
+
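+ // Reconstruct the expected bytes according to PostgreSQL's binary COPY format: a fixed
+ // 11-byte signature, a 4-byte flags field and a 4-byte header extension length, followed by
+ // one tuple per row (an int16 field count, then an int32 length and the bytes of each field),
+ // and an int16 -1 trailer that marks the end of the data.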
+ ByteArrayOutputStream byteArrayOutputStream =
+ new ByteArrayOutputStream(CopyToStatement.COPY_BINARY_HEADER.length);
+ DataOutputStream expectedOutputStream = new DataOutputStream(byteArrayOutputStream);
+ expectedOutputStream.write(CopyToStatement.COPY_BINARY_HEADER);
+ expectedOutputStream.writeInt(0); // flags
+ expectedOutputStream.writeInt(0); // header extension area length
+ expectedOutputStream.writeShort(1); // Column count
+ expectedOutputStream.writeInt(8); // Value length
+ expectedOutputStream.writeLong(1L); // Column value
+ expectedOutputStream.writeShort(-1); // Column count == -1 means end of data.
+
+ assertArrayEquals(byteArrayOutputStream.toByteArray(), copyOutput);
+ }
+
+ @Test
+ public void testCopyToBinaryPsqlEmptyTable() throws Exception {
+ assumeTrue("This test requires psql to be installed", isPsqlAvailable());
+ mockSpanner.putStatementResult(
+ StatementResult.query(
+ Statement.of("select * from all_types"),
+ ALL_TYPES_RESULTSET.toBuilder().clearRows().build()));
+
+ ProcessBuilder builder = new ProcessBuilder();
+ String[] psqlCommand =
+ new String[] {
+ "psql",
+ "-h",
+ "localhost",
+ "-p",
+ String.valueOf(pgServer.getLocalPort()),
+ "-c",
+ "COPY all_types TO STDOUT (FORMAT BINARY)"
+ };
+ builder.command(psqlCommand);
+ Process process = builder.start();
+ String errors;
+
+ try (BufferedReader errorReader =
+ new BufferedReader(new InputStreamReader(process.getErrorStream()))) {
+ errors = errorReader.lines().collect(Collectors.joining("\n"));
+ }
+ byte[] copyOutput = ByteStreams.toByteArray(process.getInputStream());
+
+ assertEquals("", errors);
+ int res = process.waitFor();
+ assertEquals(0, res);
+
+ ByteArrayOutputStream byteArrayOutputStream =
+ new ByteArrayOutputStream(CopyToStatement.COPY_BINARY_HEADER.length);
+ DataOutputStream expectedOutputStream = new DataOutputStream(byteArrayOutputStream);
+ expectedOutputStream.write(CopyToStatement.COPY_BINARY_HEADER);
+ expectedOutputStream.writeInt(0); // flags
+ expectedOutputStream.writeInt(0); // header extension area length
+ expectedOutputStream.writeShort(-1); // Column count == -1 means end of data.
+
+ assertArrayEquals(byteArrayOutputStream.toByteArray(), copyOutput);
+ }
}
diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/csharp/AbstractNpgsqlMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/csharp/AbstractNpgsqlMockServerTest.java
index c117a0a98..f36e83033 100644
--- a/src/test/java/com/google/cloud/spanner/pgadapter/csharp/AbstractNpgsqlMockServerTest.java
+++ b/src/test/java/com/google/cloud/spanner/pgadapter/csharp/AbstractNpgsqlMockServerTest.java
@@ -182,12 +182,12 @@ public abstract class AbstractNpgsqlMockServerTest extends AbstractMockServerTes
+ " END AS elemtypoid\n"
+ " FROM pg_type AS typ\n"
+ " LEFT JOIN pg_class AS cls ON (cls.oid = typ.typrelid)\n"
- + " LEFT JOIN pg_proc AS proc ON proc.oid = typ.typreceive\n"
+ + " LEFT JOIN pg_proc AS proc ON false\n"
+ " LEFT JOIN pg_range ON (pg_range.rngtypid = typ.oid)\n"
+ " ) AS typ\n"
+ " LEFT JOIN pg_type AS elemtyp ON elemtyp.oid = elemtypoid\n"
+ " LEFT JOIN pg_class AS elemcls ON (elemcls.oid = elemtyp.typrelid)\n"
- + " LEFT JOIN pg_proc AS elemproc ON elemproc.oid = elemtyp.typreceive\n"
+ + " LEFT JOIN pg_proc AS elemproc ON false\n"
+ ") AS t\n"
+ "JOIN pg_namespace AS ns ON (ns.oid = typnamespace)\n"
+ "WHERE\n"
@@ -481,7 +481,10 @@ public abstract class AbstractNpgsqlMockServerTest extends AbstractMockServerTes
+ "from information_schema.indexes i\n"
+ "inner join information_schema.index_columns using (table_catalog, table_schema, table_name)\n"
+ "group by i.index_name, i.table_schema\n"
- + ")\n"
+ + "),\n"
+ + "pg_attribute as (\n"
+ + "select * from (select 0::bigint as attrelid, '' as attname, 0::bigint as atttypid, 0::bigint as attstattarget, 0::bigint as attlen, 0::bigint as attnum, 0::bigint as attndims, -1::bigint as attcacheoff, 0::bigint as atttypmod, true as attbyval, '' as attalign, '' as attstorage, '' as attcompression, false as attnotnull, true as atthasdef, false as atthasmissing, '' as attidentity, '' as attgenerated, false as attisdropped, true as attislocal, 0 as attinhcount, 0 as attcollation, '{}'::bigint[] as attacl, '{}'::text[] as attoptions, '{}'::text[] as attfdwoptions, null as attmissingval\n"
+ + ") a where false)\n"
+ "-- Load field definitions for (free-standing) composite types\n"
+ "SELECT typ.oid, att.attname, att.atttypid\n"
+ "FROM pg_type AS typ\n"
@@ -519,7 +522,10 @@ public abstract class AbstractNpgsqlMockServerTest extends AbstractMockServerTes
private static final Statement SELECT_ENUM_LABELS_STATEMENT =
Statement.of(
- "with pg_namespace as (\n"
+ "with pg_enum as (\n"
+ + "select * from (select 0::bigint as oid, 0::bigint as enumtypid, 0.0::float8 as enumsortorder, ''::varchar as enumlabel\n"
+ + ") e where false),\n"
+ + "pg_namespace as (\n"
+ " select case schema_name when 'pg_catalog' then 11 when 'public' then 2200 else 0 end as oid,\n"
+ " schema_name as nspname, null as nspowner, null as nspacl\n"
+ " from information_schema.schemata\n"
diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/csharp/NpgsqlMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/csharp/NpgsqlMockServerTest.java
index 4928f8fe5..d4c51f998 100644
--- a/src/test/java/com/google/cloud/spanner/pgadapter/csharp/NpgsqlMockServerTest.java
+++ b/src/test/java/com/google/cloud/spanner/pgadapter/csharp/NpgsqlMockServerTest.java
@@ -71,6 +71,12 @@ public void testShowServerVersion() throws IOException, InterruptedException {
assertEquals("14.1\n", result);
}
+ @Test
+ public void testShowApplicationName() throws IOException, InterruptedException {
+ String result = execute("TestShowApplicationName", createConnectionString());
+ assertEquals("npgsql\n", result);
+ }
+
@Test
public void testSelect1() throws IOException, InterruptedException {
String result = execute("TestSelect1", createConnectionString());
diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxMockServerTest.java
index a2b62975c..43a33c04f 100644
--- a/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxMockServerTest.java
+++ b/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxMockServerTest.java
@@ -158,6 +158,16 @@ public void testSelect1() {
assertEquals(QueryMode.NORMAL, executeRequest.getQueryMode());
}
+ @Test
+ public void testShowApplicationName() {
+ String res = pgxTest.TestShowApplicationName(createConnString());
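+ // The Go test returns an error message on failure and null on success.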
+
+ assertNull(res);
+
+ // This should all be handled in PGAdapter.
+ assertEquals(0, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
+ }
+
@Test
public void testQueryWithParameter() {
String sql = "SELECT * FROM FOO WHERE BAR=$1";
diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxTest.java
index 3521d6cdb..bfdc362d0 100644
--- a/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxTest.java
+++ b/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxTest.java
@@ -23,6 +23,8 @@ public interface PgxTest extends Library {
String TestSelect1(GoString connString);
+ String TestShowApplicationName(GoString connString);
+
String TestQueryWithParameter(GoString connString);
String TestQueryAllDataTypes(GoString connString, int oid, int format);
diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/NodePostgresMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/NodePostgresMockServerTest.java
index 0478148ef..0e363f5df 100644
--- a/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/NodePostgresMockServerTest.java
+++ b/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/NodePostgresMockServerTest.java
@@ -259,9 +259,7 @@ public void testInsertAllTypes() throws IOException, InterruptedException {
.bind("p9")
.to("some-random-string")
.bind("p10")
- // TODO: Change to jsonb when https://github.com/googleapis/java-spanner/pull/2182
- // has been merged.
- .to(Value.json("{\"my_key\":\"my-value\"}"))
+ .to(Value.pgJsonb("{\"my_key\":\"my-value\"}"))
.build(),
1L);
mockSpanner.putStatementResult(updateResult);
@@ -386,9 +384,7 @@ public void testInsertAllTypesPreparedStatement() throws IOException, Interrupte
.bind("p9")
.to("some-random-string")
.bind("p10")
- // TODO: Change to jsonb when https://github.com/googleapis/java-spanner/pull/2182
- // has been merged.
- .to(Value.json("{\"my_key\":\"my-value\"}"))
+ .to(Value.pgJsonb("{\"my_key\":\"my-value\"}"))
.build(),
1L);
mockSpanner.putStatementResult(updateResult);
@@ -416,9 +412,7 @@ public void testInsertAllTypesPreparedStatement() throws IOException, Interrupte
.bind("p9")
.to((String) null)
.bind("p10")
- // TODO: Change to jsonb when https://github.com/googleapis/java-spanner/pull/2182
- // has been merged.
- .to(Value.json(null))
+ .to(Value.pgJsonb(null))
.build(),
1L));
diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/TypeORMMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/TypeORMMockServerTest.java
index a0f57da1a..80e7f90df 100644
--- a/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/TypeORMMockServerTest.java
+++ b/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/TypeORMMockServerTest.java
@@ -728,9 +728,7 @@ public void testCreateAllTypes() throws IOException, InterruptedException {
.bind("p9")
.to("some random string")
.bind("p10")
- // TODO: Change to jsonb when https://github.com/googleapis/java-spanner/pull/2182
- // has been merged.
- .to(com.google.cloud.spanner.Value.json("{\"key\":\"value\"}"))
+ .to(com.google.cloud.spanner.Value.pgJsonb("{\"key\":\"value\"}"))
.build(),
1L));
diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/python/sqlalchemy/ITSQLAlchemySampleTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/python/sqlalchemy/ITSQLAlchemySampleTest.java
new file mode 100644
index 000000000..3a2bad963
--- /dev/null
+++ b/src/test/java/com/google/cloud/spanner/pgadapter/python/sqlalchemy/ITSQLAlchemySampleTest.java
@@ -0,0 +1,69 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.cloud.spanner.pgadapter.python.sqlalchemy;
+
+import static com.google.cloud.spanner.pgadapter.python.sqlalchemy.SqlAlchemyBasicsTest.execute;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+import com.google.cloud.spanner.Database;
+import com.google.cloud.spanner.pgadapter.IntegrationTest;
+import com.google.cloud.spanner.pgadapter.PgAdapterTestEnv;
+import com.google.common.collect.ImmutableList;
+import java.io.IOException;
+import java.util.Collections;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@Category(IntegrationTest.class)
+@RunWith(JUnit4.class)
+public class ITSQLAlchemySampleTest implements IntegrationTest {
+ private static final PgAdapterTestEnv testEnv = new PgAdapterTestEnv();
+ private static final String SAMPLE_DIR = "./samples/python/sqlalchemy-sample";
+
+ @BeforeClass
+ public static void setup() throws Exception {
+ testEnv.setUp();
+ Database database = testEnv.createDatabase(ImmutableList.of());
+ testEnv.startPGAdapterServerWithDefaultDatabase(database.getId(), Collections.emptyList());
+ }
+
+ @AfterClass
+ public static void teardown() {
+ testEnv.stopPGAdapterServer();
+ testEnv.cleanUp();
+ }
+
+ @Test
+ public void testSQLAlchemySample() throws IOException, InterruptedException {
+ String output =
+ execute(
+ SAMPLE_DIR,
+ "run_sample.py",
+ "localhost",
+ testEnv.getServer().getLocalPort(),
+ testEnv.getDatabaseId());
+ assertNotNull(output);
+ assertTrue(
+ output,
+ output.contains("No album found using a stale read.")
+ || output.contains(
+ "Album was found using a stale read, even though it has already been deleted."));
+ }
+}
diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/python/sqlalchemy/SqlAlchemyBasicsTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/python/sqlalchemy/SqlAlchemyBasicsTest.java
new file mode 100644
index 000000000..0086ab217
--- /dev/null
+++ b/src/test/java/com/google/cloud/spanner/pgadapter/python/sqlalchemy/SqlAlchemyBasicsTest.java
@@ -0,0 +1,551 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.cloud.spanner.pgadapter.python.sqlalchemy;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult;
+import com.google.cloud.spanner.RandomResultSetGenerator;
+import com.google.cloud.spanner.Statement;
+import com.google.cloud.spanner.pgadapter.AbstractMockServerTest;
+import com.google.cloud.spanner.pgadapter.python.PythonTest;
+import com.google.common.collect.ImmutableList;
+import com.google.protobuf.ListValue;
+import com.google.protobuf.Value;
+import com.google.spanner.admin.database.v1.UpdateDatabaseDdlRequest;
+import com.google.spanner.v1.CommitRequest;
+import com.google.spanner.v1.ExecuteSqlRequest;
+import com.google.spanner.v1.ResultSet;
+import com.google.spanner.v1.ResultSetStats;
+import com.google.spanner.v1.RollbackRequest;
+import com.google.spanner.v1.TypeCode;
+import java.io.File;
+import java.io.IOException;
+import java.util.List;
+import java.util.Scanner;
+import java.util.stream.Collectors;
+import org.junit.BeforeClass;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+@Category(PythonTest.class)
+public class SqlAlchemyBasicsTest extends AbstractMockServerTest {
+
+ @Parameter public String host;
+
+ @Parameters(name = "host = {0}")
+ public static List