diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index ab1ccad2e..1383a2a56 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -21,31 +21,6 @@ jobs: GCP_PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }} if: "${{ env.GCP_PROJECT_ID != '' }}" run: echo "::set-output name=defined::true" - units: - runs-on: ubuntu-latest - strategy: - matrix: - java: [8, 11] - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-java@v3 - with: - java-version: ${{matrix.java}} - distribution: 'zulu' - - run: java -version - - uses: actions/setup-go@v3 - with: - go-version: '^1.17.7' - - run: go version - - uses: actions/setup-python@v4 - with: - python-version: '3.9' - - run: python --version - - run: pip install -r ./src/test/python/requirements.txt - - uses: actions/setup-node@v3 - with: - node-version: 16 - - run: .ci/run-with-credentials.sh units lint: runs-on: ubuntu-latest @@ -57,16 +32,6 @@ jobs: distribution: 'zulu' - run: java -version - run: .ci/run-with-credentials.sh lint - clirr: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-java@v3 - with: - java-version: 8 - distribution: 'zulu' - - run: java -version - - run: .ci/run-with-credentials.sh clirr e2e-psql-v11-v1: needs: [check-env] if: needs.check-env.outputs.has-key == 'true' diff --git a/.github/workflows/cloud-storage-build-and-push-assembly.yaml b/.github/workflows/cloud-storage-build-and-push-assembly.yaml index c6e11256a..a6dc75a17 100644 --- a/.github/workflows/cloud-storage-build-and-push-assembly.yaml +++ b/.github/workflows/cloud-storage-build-and-push-assembly.yaml @@ -22,7 +22,7 @@ jobs: uses: actions/checkout@v3 - id: 'auth' - uses: 'google-github-actions/auth@v0' + uses: 'google-github-actions/auth@v1' with: credentials_json: '${{ secrets.CLOUD_SPANNER_PG_ADAPTER_SERVICE_ACCOUNT }}' @@ -43,13 +43,13 @@ jobs: # Upload the assembly to Google Cloud Storage - id: 'upload-versioned-file' - uses: 'google-github-actions/upload-cloud-storage@v0' + uses: 'google-github-actions/upload-cloud-storage@v1' with: path: ${{ env.assembly }} destination: 'pgadapter-jar-releases' parent: false - id: 'overwrite-current-file' - uses: 'google-github-actions/upload-cloud-storage@v0' + uses: 'google-github-actions/upload-cloud-storage@v1' with: path: ${{ env.assembly_current }} destination: 'pgadapter-jar-releases' diff --git a/.github/workflows/cloud-storage-build-and-push.yaml b/.github/workflows/cloud-storage-build-and-push.yaml index cbb355b15..cd62d8fd7 100644 --- a/.github/workflows/cloud-storage-build-and-push.yaml +++ b/.github/workflows/cloud-storage-build-and-push.yaml @@ -22,7 +22,7 @@ jobs: uses: actions/checkout@v3 - id: 'auth' - uses: 'google-github-actions/auth@v0' + uses: 'google-github-actions/auth@v1' with: credentials_json: '${{ secrets.CLOUD_SPANNER_PG_ADAPTER_SERVICE_ACCOUNT }}' @@ -40,13 +40,13 @@ jobs: # Upload the jar to Google Cloud Storage - id: 'upload-versioned-file' - uses: 'google-github-actions/upload-cloud-storage@v0' + uses: 'google-github-actions/upload-cloud-storage@v1' with: path: ${{ env.uber_jar }} destination: 'pgadapter-jar-releases' parent: false - id: 'overwrite-current-file' - uses: 'google-github-actions/upload-cloud-storage@v0' + uses: 'google-github-actions/upload-cloud-storage@v1' with: path: ${{ env.uber_jar_current }} destination: 'pgadapter-jar-releases' diff --git a/.github/workflows/docker-build-and-push.yaml b/.github/workflows/docker-build-and-push.yaml index 6709b59bb..a91d06805 100644 --- a/.github/workflows/docker-build-and-push.yaml +++ b/.github/workflows/docker-build-and-push.yaml @@ -24,7 +24,7 @@ jobs: uses: actions/checkout@v3 - id: 'auth' - uses: 'google-github-actions/auth@v0' + uses: 'google-github-actions/auth@v1' with: credentials_json: '${{ secrets.CLOUD_SPANNER_PG_ADAPTER_SERVICE_ACCOUNT }}' diff --git a/.github/workflows/integration-tests-against-docker.yaml b/.github/workflows/integration-tests-against-docker.yaml index 444ab0079..b014328f2 100644 --- a/.github/workflows/integration-tests-against-docker.yaml +++ b/.github/workflows/integration-tests-against-docker.yaml @@ -19,7 +19,7 @@ jobs: java-version: 8 distribution: 'zulu' - id: 'auth' - uses: 'google-github-actions/auth@v0' + uses: 'google-github-actions/auth@v1' with: credentials_json: '${{ secrets.JSON_SERVICE_ACCOUNT_CREDENTIALS }}' export_environment_variables: true diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index 5d25b1a13..f53d35aa6 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -75,7 +75,7 @@ jobs: uses: actions/setup-java@v3 with: distribution: zulu - java-version: 8 + java-version: 11 - run: java -version - name: Setup Go uses: actions/setup-go@v3 @@ -91,6 +91,7 @@ jobs: with: python-version: '3.9' - run: pip install -r ./src/test/python/requirements.txt + - run: pip install -r ./src/test/python/pg8000/requirements.txt - uses: actions/setup-node@v3 with: node-version: 16 diff --git a/.github/workflows/units.yaml b/.github/workflows/units.yaml index f43d59086..7f598a0fe 100644 --- a/.github/workflows/units.yaml +++ b/.github/workflows/units.yaml @@ -26,6 +26,7 @@ jobs: python-version: '3.9' - run: python --version - run: pip install -r ./src/test/python/requirements.txt + - run: pip install -r ./src/test/python/pg8000/requirements.txt - uses: actions/setup-node@v3 with: node-version: 16 @@ -50,6 +51,7 @@ jobs: python-version: '3.9' - run: python --version - run: pip install -r ./src/test/python/requirements.txt + - run: pip install -r ./src/test/python/pg8000/requirements.txt - run: mvn -B test macos: runs-on: macos-latest @@ -75,6 +77,7 @@ jobs: python-version: '3.9' - run: python --version - run: pip install -r ./src/test/python/requirements.txt + - run: pip install -r ./src/test/python/pg8000/requirements.txt - uses: actions/setup-dotnet@v3 with: dotnet-version: '6.0.x' diff --git a/.gitignore b/.gitignore index 2cc0e3c45..f15099496 100644 --- a/.gitignore +++ b/.gitignore @@ -1,11 +1,14 @@ target/ /target +.DS_STORE .idea/ *.iml *.class *.lst output.txt __pycache__ +venv +.DS_Store src/test/golang/**/*.h src/test/golang/**/*.so diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b8fbcef7..d7e1aa6f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,31 @@ # Changelog +## [0.13.0](https://github.com/GoogleCloudPlatform/pgadapter/compare/v0.12.0...v0.13.0) (2022-12-07) + + +### Features + +* accept UUID as a parameter value ([#518](https://github.com/GoogleCloudPlatform/pgadapter/issues/518)) ([46941ab](https://github.com/GoogleCloudPlatform/pgadapter/commit/46941ab318e4269061336e2ecb95a4402cf2a5e5)) +* support 'select version()' and similar ([#495](https://github.com/GoogleCloudPlatform/pgadapter/issues/495)) ([fbd16ec](https://github.com/GoogleCloudPlatform/pgadapter/commit/fbd16ecd44d12ffb65b85555d2ddef0cc533b4be)) +* Support Describe message for DDL statements and other no-result statements ([#501](https://github.com/GoogleCloudPlatform/pgadapter/issues/501)) ([cb616d8](https://github.com/GoogleCloudPlatform/pgadapter/commit/cb616d8f64c6aabe0422020d7ce2bc90734ff837)) +* support DML RETURNING clause ([#498](https://github.com/GoogleCloudPlatform/pgadapter/issues/498)) ([c1d7e4e](https://github.com/GoogleCloudPlatform/pgadapter/commit/c1d7e4eff240449245f223bc17793f393cafea2f)) +* support more than 50 query parameters ([#527](https://github.com/GoogleCloudPlatform/pgadapter/issues/527)) ([9fca9ba](https://github.com/GoogleCloudPlatform/pgadapter/commit/9fca9ba487515d63b586bb4ed6329f2d84d98996)) +* use session timezone to format timestamps ([#470](https://github.com/GoogleCloudPlatform/pgadapter/issues/470)) ([d84564d](https://github.com/GoogleCloudPlatform/pgadapter/commit/d84564dc45a4259c3b8246d05c66a2645cb92f2d)) + + +### Bug Fixes + +* client side results were not returned ([#493](https://github.com/GoogleCloudPlatform/pgadapter/issues/493)) ([5e9e85e](https://github.com/GoogleCloudPlatform/pgadapter/commit/5e9e85e72b7d51bb6426ad963521fb3e24fa36bb)) +* pg_catalog tables were not replaced for information_schema queries ([#494](https://github.com/GoogleCloudPlatform/pgadapter/issues/494)) ([e1f02fe](https://github.com/GoogleCloudPlatform/pgadapter/commit/e1f02fed232c09c96adb426b9f8ce91d61c6659d)) + + +### Documentation + +* [WIP] Hibernate sample ([#373](https://github.com/GoogleCloudPlatform/pgadapter/issues/373)) ([7125c91](https://github.com/GoogleCloudPlatform/pgadapter/commit/7125c9110eab429ea311676445c71308c1018aac)) +* document Liquibase Pilot Support ([#485](https://github.com/GoogleCloudPlatform/pgadapter/issues/485)) ([745089f](https://github.com/GoogleCloudPlatform/pgadapter/commit/745089f8d7f6df2401eb0fb15cca80c85dc18437)) +* document Support for gorm ([#469](https://github.com/GoogleCloudPlatform/pgadapter/issues/469)) ([0b962af](https://github.com/GoogleCloudPlatform/pgadapter/commit/0b962af9f0037b7fb86225ed0b3f89c072bf7bcf)) +* remove limitation for RETURNING and generated columns for gorm ([#526](https://github.com/GoogleCloudPlatform/pgadapter/issues/526)) ([0420e99](https://github.com/GoogleCloudPlatform/pgadapter/commit/0420e997fb1c334bd08ee2507ca73ad11426e370)) + ## [0.12.0](https://github.com/GoogleCloudPlatform/pgadapter/compare/v0.11.0...v0.12.0) (2022-11-02) diff --git a/README.md b/README.md index 0d433c91c..3d1f86e98 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,8 @@ PGAdapter is a proxy which translates the PostgreSQL wire-protocol into the equivalent for Spanner databases [that use the PostgreSQL interface](https://cloud.google.com/spanner/docs/postgresql-interface). -PGAdapter can be used with the following clients: +## Drivers and Clients +PGAdapter can be used with the following drivers and clients: 1. `psql`: Versions 11, 12, 13 and 14 are supported. See [psql support](docs/psql.md) for more details. 2. `JDBC`: Versions 42.x and higher are supported. See [JDBC support](docs/jdbc.md) for more details. 3. `pgx`: Version 4.15 and higher are supported. See [pgx support](docs/pgx.md) for more details. @@ -11,6 +12,13 @@ PGAdapter can be used with the following clients: 5. `node-postgres`: Version 8.8.0 and higher have __experimental support__. See [node-postgres support](docs/node-postgres.md) for more details. +## Frameworks +PGAdapter can be used with the following frameworks: +1. `Liquibase`: Version 4.12.0 and higher are supported. See [Liquibase support](docs/liquibase.md) + for more details. See also [this directory](samples/java/liquibase) for a sample application using `Liquibase`. +2. `gorm`: Version 1.23.8 and higher are supported. See [gorm support](docs/gorm.md) for more details. + See also [this directory](samples/golang/gorm) for a sample application using `gorm`. + ## FAQ See [Frequently Asked Questions](docs/faq.md) for answers to frequently asked questions. @@ -56,9 +64,9 @@ Use the `-s` option to specify a different local port than the default 5432 if y PostgreSQL running on your local system. -You can also download a specific version of the jar. Example (replace `v0.12.0` with the version you want to download): +You can also download a specific version of the jar. Example (replace `v0.13.0` with the version you want to download): ```shell -VERSION=v0.12.0 +VERSION=v0.13.0 wget https://storage.googleapis.com/pgadapter-jar-releases/pgadapter-${VERSION}.tar.gz \ && tar -xzvf pgadapter-${VERSION}.tar.gz java -jar pgadapter.jar -p my-project -i my-instance -d my-database @@ -93,7 +101,7 @@ This option is only available for Java/JVM-based applications. com.google.cloud google-cloud-spanner-pgadapter - 0.12.0 + 0.13.0 ``` @@ -264,8 +272,6 @@ Other `psql` meta-commands are __not__ supported. ## Limitations PGAdapter has the following known limitations at this moment: -- Server side [prepared statements](https://www.postgresql.org/docs/current/sql-prepare.html) are limited to at most 50 parameters. -- SSL connections are not supported. - Only [password authentication](https://www.postgresql.org/docs/current/auth-password.html) using the `password` method is supported. All other authentication methods are not supported. - The COPY protocol only supports COPY TO|FROM STDOUT|STDIN [BINARY]. COPY TO|FROM is not supported. diff --git a/benchmarks/nodejs/index.js b/benchmarks/nodejs/index.js new file mode 100644 index 000000000..36e623835 --- /dev/null +++ b/benchmarks/nodejs/index.js @@ -0,0 +1,338 @@ +const {Spanner} = require('@google-cloud/spanner'); +const { Pool } = require('pg'); + +const projectId = 'spanner-pg-preview-internal'; +const instanceId = 'europe-north1'; +const databaseId = 'knut-test-db'; + +function getRandomInt(max) { + return Math.floor(Math.random() * max); +} + +async function test() { + + // Creates a Spanner client + const spanner = new Spanner({projectId}); + // Creates a PG client pool. + const pool = new Pool({ + user: 'user', + host: '/tmp', + database: 'knut-test-db', + password: 'password', + port: 5432, + max: 400, + }); + + // Gets a reference to a Cloud Spanner instance and database + const instance = spanner.instance(instanceId); + const database = instance.database(databaseId); + + // Make sure the session pools have been initialized. + await database.run('select 1'); + await pool.query('select 1'); + + await spannerSelectRowsSequentially(database, 100); + await spannerSelectMultipleRows(database, 20, 500); + await spannerSelectAndUpdateRows(database, 20, 5); + + await spannerSelectRowsInParallel(database, 1000); + await spannerSelectMultipleRowsInParallel(database, 200, 500); + await spannerSelectAndUpdateRowsInParallel(database, 200, 5); + + await pgSelectRowsSequentially(pool, 100); + await pgSelectMultipleRows(pool, 20, 500); + await pgSelectAndUpdateRows(pool, 20, 5); + + await pgSelectRowsInParallel(pool, 1000); + await pgSelectMultipleRowsInParallel(pool, 200, 500); + await pgSelectAndUpdateRowsInParallel(pool, 200, 5); + + await database.close(); + await pool.end(); +} + +async function spannerSelectRowsSequentially(database, numQueries) { + console.log(`Selecting ${numQueries} rows sequentially`); + const start = new Date(); + for (let i = 0; i < numQueries; i++) { + const query = { + sql: 'SELECT * FROM all_types WHERE col_bigint=$1', + params: { + p1: getRandomInt(5000000), + }, + }; + + await database.run(query); + process.stdout.write('.'); + } + process.stdout.write('\n'); + const end = new Date(); + const elapsed = end - start; + console.log(`Execution time for selecting ${numQueries} rows sequentially: ${elapsed}ms`); + console.log(`Avg execution time: ${elapsed/numQueries}ms`); +} + +async function spannerSelectRowsInParallel(database, numQueries) { + console.log(`Selecting ${numQueries} rows in parallel`); + const start = new Date(); + const promises = []; + for (let i = 0; i < numQueries; i++) { + const query = { + sql: 'SELECT * FROM all_types WHERE col_bigint=$1', + params: { + p1: getRandomInt(5000000), + }, + }; + + promises.push(database.run(query)); + process.stdout.write('.'); + } + process.stdout.write('\n'); + console.log("Waiting for queries to finish"); + const allRows = await Promise.all(promises); + allRows.forEach(rows => { + if (rows[0].length < 0 || rows[0].length > 1) { + console.log(`Unexpected row count: ${rows[0].length}`); + } + }); + const end = new Date(); + const elapsed = end - start; + console.log(`Execution time for selecting ${numQueries} rows in parallel: ${elapsed}ms`); + console.log(`Avg execution time: ${elapsed/numQueries}ms`); +} + +async function spannerSelectMultipleRows(database, numQueries, numRows) { + console.log(`Selecting ${numQueries} with each ${numRows} rows sequentially`); + const start = new Date(); + for (let i = 0; i < numQueries; i++) { + const query = { + sql: `SELECT * FROM all_types WHERE col_bigint>$1 LIMIT ${numRows}`, + params: { + p1: getRandomInt(5000000), + }, + json: true, + }; + await database.run(query); + process.stdout.write('.'); + } + process.stdout.write('\n'); + const end = new Date(); + const elapsed = end - start; + console.log(`Execution time for executing ${numQueries} with ${numRows} rows each: ${elapsed}ms`); + console.log(`Avg execution time: ${elapsed/numQueries}ms`); +} + +async function spannerSelectMultipleRowsInParallel(database, numQueries, numRows) { + console.log(`Selecting ${numQueries} with each ${numRows} rows in parallel`); + const start = new Date(); + const promises = []; + for (let i = 0; i < numQueries; i++) { + const query = { + sql: `SELECT * FROM all_types WHERE col_bigint>$1 LIMIT ${numRows}`, + params: { + p1: getRandomInt(5000000), + }, + json: true, + }; + promises.push(database.run(query)); + process.stdout.write('.'); + } + process.stdout.write('\n'); + console.log("Waiting for queries to finish"); + const allRows = await Promise.all(promises); + allRows.forEach(rows => { + if (rows[0].length < 0 || rows[0].length > numRows) { + console.log(`Unexpected row count: ${rows[0].length}`); + } + }); + const end = new Date(); + const elapsed = end - start; + console.log(`Execution time for executing ${numQueries} with ${numRows} rows each in parallel: ${elapsed}ms`); + console.log(`Avg execution time: ${elapsed/numQueries}ms`); +} + +async function spannerSelectAndUpdateRows(database, numTransactions, numRowsPerTx) { + console.log(`Executing ${numTransactions} with each ${numRowsPerTx} rows per transaction`); + const start = new Date(); + for (let i = 0; i < numTransactions; i++) { + await database.runTransactionAsync(selectAndUpdate); + process.stdout.write('.'); + } + process.stdout.write('\n'); + const end = new Date(); + const elapsed = end - start; + console.log(`Execution time executing ${numTransactions} with ${numRowsPerTx} rows each: ${elapsed}ms`); + console.log(`Avg execution time: ${elapsed/numTransactions}ms`); +} + +async function spannerSelectAndUpdateRowsInParallel(database, numTransactions, numRowsPerTx) { + console.log(`Executing ${numTransactions} with each ${numRowsPerTx} rows per transaction in parallel`); + const start = new Date(); + const promises = []; + for (let i = 0; i < numTransactions; i++) { + promises.push(database.runTransactionAsync(selectAndUpdate)); + process.stdout.write('.'); + } + process.stdout.write('\n'); + console.log("Waiting for transactions to finish"); + await Promise.all(promises); + const end = new Date(); + const elapsed = end - start; + console.log(`Execution time executing ${numTransactions} with ${numRowsPerTx} rows each: ${elapsed}ms`); + console.log(`Avg execution time: ${elapsed/numTransactions}ms`); +} + +async function selectAndUpdate(tx) { + const query = { + sql: 'SELECT * FROM all_types WHERE col_bigint=$1', + params: { + p1: getRandomInt(5000000), + }, + json: true, + }; + const [rows] = await tx.run(query); + if (rows.length === 1) { + rows[0].col_float8 = Math.random(); + const update = { + sql: 'UPDATE all_types SET col_float8=$1 WHERE col_bigint=$2', + params: { + p1: rows[0].col_float8, + p2: rows[0].col_bigint, + }, + } + const [rowCount] = await tx.runUpdate(update); + if (rowCount !== 1) { + console.error(`Unexpected update count: ${rowCount}`); + } + } + await tx.commit(); +} + +async function pgSelectRowsSequentially(pool, numQueries) { + console.log(`PG: Selecting ${numQueries} rows sequentially`); + const start = new Date(); + for (let i = 0; i < numQueries; i++) { + await pool.query('SELECT * FROM all_types WHERE col_bigint=$1', [getRandomInt(5000000)]); + process.stdout.write('.'); + } + process.stdout.write('\n'); + const end = new Date(); + const elapsed = end - start; + console.log(`PG: Execution time for selecting ${numQueries} rows sequentially: ${elapsed}ms`); + console.log(`PG: Avg execution time: ${elapsed/numQueries}ms`); +} + +async function pgSelectRowsInParallel(pool, numQueries) { + console.log(`PG: Selecting ${numQueries} rows in parallel`); + const start = new Date(); + const promises = []; + for (let i = 0; i < numQueries; i++) { + promises.push(pool.query('SELECT * FROM all_types WHERE col_bigint=$1', [getRandomInt(5000000)])); + process.stdout.write('.'); + } + process.stdout.write('\n'); + console.log("Waiting for queries to finish"); + const allRows = await Promise.all(promises); + allRows.forEach(result => { + if (result.rows.length < 0 || result.rows.length > 1) { + console.log(`Unexpected row count: ${result.rows.length}`); + } + }); + const end = new Date(); + const elapsed = end - start; + console.log(`PG: Execution time for selecting ${numQueries} rows in parallel: ${elapsed}ms`); + console.log(`PG: Avg execution time: ${elapsed/numQueries}ms`); +} + +async function pgSelectMultipleRows(pool, numQueries, numRows) { + console.log(`PG: Selecting ${numQueries} with each ${numRows} rows sequentially`); + const start = new Date(); + for (let i = 0; i < numQueries; i++) { + const sql = `SELECT * FROM all_types WHERE col_bigint>$1 LIMIT ${numRows}`; + await pool.query(sql, [getRandomInt(5000000)]); + process.stdout.write('.'); + } + process.stdout.write('\n'); + const end = new Date(); + const elapsed = end - start; + console.log(`PG: Execution time for executing ${numQueries} with ${numRows} rows each: ${elapsed}ms`); + console.log(`PG: Avg execution time: ${elapsed/numQueries}ms`); +} + +async function pgSelectMultipleRowsInParallel(pool, numQueries, numRows) { + console.log(`PG: Selecting ${numQueries} with each ${numRows} rows in parallel`); + const start = new Date(); + const promises = []; + for (let i = 0; i < numQueries; i++) { + const sql = `SELECT * FROM all_types WHERE col_bigint>$1 LIMIT ${numRows}`; + promises.push(pool.query(sql, [getRandomInt(5000000)])); + process.stdout.write('.'); + } + console.log("Waiting for queries to finish"); + const allResults = await Promise.all(promises); + allResults.forEach(result => { + if (result.rows.length < 0 || result.rows.length > numRows) { + console.log(`Unexpected row count: ${result.rows.length}`); + } + }); + const end = new Date(); + const elapsed = end - start; + console.log(`PG: Execution time for executing ${numQueries} with ${numRows} rows each: ${elapsed}ms`); + console.log(`PG: Avg execution time: ${elapsed/numQueries}ms`); +} + +async function pgSelectAndUpdateRows(pool, numTransactions, numRowsPerTx) { + console.log(`PG: Executing ${numTransactions} with each ${numRowsPerTx} rows per transaction`); + const start = new Date(); + for (let i = 0; i < numTransactions; i++) { + await pgRunTransaction(pool); + process.stdout.write('.'); + } + process.stdout.write('\n'); + const end = new Date(); + const elapsed = end - start; + console.log(`PG: Execution time executing ${numTransactions} with ${numRowsPerTx} rows each: ${elapsed}ms`); + console.log(`PG: Avg execution time: ${elapsed/numTransactions}ms`); +} + +async function pgSelectAndUpdateRowsInParallel(pool, numTransactions, numRowsPerTx) { + console.log(`PG: Executing ${numTransactions} with each ${numRowsPerTx} rows per transaction`); + const start = new Date(); + const promises = []; + for (let i = 0; i < numTransactions; i++) { + promises.push(pgRunTransaction(pool)); + process.stdout.write('.'); + } + process.stdout.write('\n'); + console.log("Waiting for transactions to finish"); + await Promise.all(promises); + const end = new Date(); + const elapsed = end - start; + console.log(`PG: Execution time executing ${numTransactions} with ${numRowsPerTx} rows each: ${elapsed}ms`); + console.log(`PG: Avg execution time: ${elapsed/numTransactions}ms`); +} + +async function pgRunTransaction(pool) { + const client = await pool.connect() + try { + await client.query('BEGIN'); + const selectResult = await client.query('SELECT * FROM all_types WHERE col_bigint=$1', [getRandomInt(5000000)]); + if (selectResult.rows.length === 1) { + const updateResult = await client.query('UPDATE all_types SET col_float8=$1 WHERE col_bigint=$2', [ + Math.random(), + selectResult.rows[0].col_bigint, + ]); + if (updateResult.rowCount !== 1) { + console.error(`Unexpected update count: ${updateResult.rowCount}`); + } + } + await client.query('COMMIT') + } catch (e) { + await client.query('ROLLBACK') + throw e + } finally { + client.release() + } +} + +test().then(() => console.log('Finished')); diff --git a/benchmarks/nodejs/package.json b/benchmarks/nodejs/package.json new file mode 100644 index 000000000..cab9ca702 --- /dev/null +++ b/benchmarks/nodejs/package.json @@ -0,0 +1,16 @@ +{ + "name": "spanner-benchmark", + "version": "1.0.0", + "description": "", + "main": "index.js", + "scripts": { + "test": "echo \"Error: no test specified\" && exit 1" + }, + "keywords": [], + "author": "", + "license": "ISC", + "dependencies": { + "@google-cloud/spanner": "^6.4.0", + "pg": "^8.8.0" + } +} diff --git a/clirr-ignored-differences.xml b/clirr-ignored-differences.xml index 14173412f..04467af96 100644 --- a/clirr-ignored-differences.xml +++ b/clirr-ignored-differences.xml @@ -90,6 +90,20 @@ void connectToSpanner(java.lang.String, com.google.auth.oauth2.GoogleCredentials) void connectToSpanner(java.lang.String, com.google.auth.Credentials) + + + 7005 + com/google/cloud/spanner/pgadapter/ConnectionHandler + void registerAutoDescribedStatement(java.lang.String, int[]) + void registerAutoDescribedStatement(java.lang.String, java.util.concurrent.Future) + + + 7006 + com/google/cloud/spanner/pgadapter/ConnectionHandler + int[] getAutoDescribedStatement(java.lang.String) + int[] + java.util.concurrent.Future + diff --git a/docs/gorm.md b/docs/gorm.md new file mode 100644 index 000000000..180fe9db5 --- /dev/null +++ b/docs/gorm.md @@ -0,0 +1,52 @@ +# PGAdapter - gorm Connection Options + +PGAdapter has Pilot Support for [gorm](https://gorm.io) version v1.23.8 and higher. + +## Limitations +Pilot Support means that it is possible to use `gorm` with Cloud Spanner PostgreSQL databases, but +with limitations. This means that porting an existing application from PostgreSQL to Cloud Spanner +will probably require code changes. See [Limitations](../samples/golang/gorm/README.md#limitations) +in the `gorm` sample directory for a full list of limitations. + +## Usage + +First start PGAdapter: + +```shell +export GOOGLE_APPLICATION_CREDENTIALS=/path/to/credentials.json +docker pull gcr.io/cloud-spanner-pg-adapter/pgadapter +docker run \ + -d -p 5432:5432 \ + -v ${GOOGLE_APPLICATION_CREDENTIALS}:${GOOGLE_APPLICATION_CREDENTIALS}:ro \ + -e GOOGLE_APPLICATION_CREDENTIALS \ + gcr.io/cloud-spanner-pg-adapter/pgadapter \ + -p my-project -i my-instance \ + -x +``` + +Then connect to PGAdapter and use `gorm` like this: + +```go +db, err := gorm.Open(postgres.Open("host=localhost port=5432 database=gorm-sample"), &gorm.Config{ + // DisableNestedTransaction will turn off the use of Savepoints if gorm + // detects a nested transaction. Cloud Spanner does not support Savepoints, + // so it is recommended to set this configuration option to true. + DisableNestedTransaction: true, + Logger: logger.Default.LogMode(logger.Error), +}) +if err != nil { + fmt.Printf("Failed to open gorm connection: %v\n", err) +} +tx := db.Begin() +user := User{ + ID: 1, + Name: "User Name", + Age: 20, +} +res := tx.Create(&user) +``` + +## Full Sample and Limitations +[This directory](../samples/golang/gorm) contains a full sample of how to work with `gorm` with +Cloud Spanner and PGAdapter. The sample readme file also lists the [current limitations](../samples/golang/gorm) +when working with `gorm`. diff --git a/docs/jdbc.md b/docs/jdbc.md index 7db1f37d0..09212d849 100644 --- a/docs/jdbc.md +++ b/docs/jdbc.md @@ -106,7 +106,3 @@ try (java.sql.Statement statement = connection.createStatement()) { ``` ## Limitations -- Server side [prepared statements](https://www.postgresql.org/docs/current/sql-prepare.html) are limited to at most 50 parameters. - Note: This is not the same as `java.sql.PreparedStatement`. A `java.sql.PreparedStatement` will only use - a server side prepared statement if it has been executed more at least `prepareThreshold` times. - See https://jdbc.postgresql.org/documentation/server-prepare/#activation for more information. diff --git a/docs/liquibase.md b/docs/liquibase.md new file mode 100644 index 000000000..5501114c8 --- /dev/null +++ b/docs/liquibase.md @@ -0,0 +1,56 @@ +# PGAdapter - Liquibase Connection Options + +PGAdapter has Pilot Support for [Liquibase](https://www.liquibase.org/) version v4.12.0 and higher. + +## Limitations +Pilot Support means that it is possible to use `Liquibase` with Cloud Spanner PostgreSQL databases, +but with limitations. This means that porting an existing application from PostgreSQL to Cloud Spanner +will probably require code changes. See [Limitations](../samples/java/liquibase/README.md#limitations) +in the `Liquibase` sample directory for a full list of limitations. + +## Usage + + +### Start PGAdapter +First start PGAdapter: + +```shell +export GOOGLE_APPLICATION_CREDENTIALS=/path/to/credentials.json +docker pull gcr.io/cloud-spanner-pg-adapter/pgadapter +docker run \ + -d -p 5432:5432 \ + -v ${GOOGLE_APPLICATION_CREDENTIALS}:${GOOGLE_APPLICATION_CREDENTIALS}:ro \ + -e GOOGLE_APPLICATION_CREDENTIALS \ + gcr.io/cloud-spanner-pg-adapter/pgadapter \ + -p my-project -i my-instance \ + -x +``` + +### Create databasechangelog and databasechangeloglock Manually +The `databasechangeloglock` and `databasechangelog` tables **must** be created manually, as the +DDL script that is automatically generated by Liquibase will try to use the data type +`timestamp without time zone`, which is not supported by Cloud Spanner. The DDL script to create +these tables manually can be found in [create_database_change_log.sql](../samples/java/liquibase/create_database_change_log.sql). + +### Connect Liquibase to PGAdapter +Liquibase will by default use DDL transactions when connecting to PostgreSQL databases. This is not +supported by Cloud Spanner. Instead, PGAdapter can automatically convert DDL transactions into DDL +batches. This requires the following option to be set in the JDBC connection URL: + +```properties +url: jdbc:postgresql://localhost:5432/liquibase-test?options=-c%20spanner.ddl_transaction_mode=AutocommitExplicitTransaction +``` + +See [liquibase.properties](../samples/java/liquibase/liquibase.properties) for an example connection URL. + +### Run Liquibase +Run Liquibase using for example the Maven plugin: + +```shell +mvn liquibase:validate +``` + +## Full Sample and Limitations +[This directory](../samples/java/liquibase) contains a full sample of how to work with `Liquibase` with +Cloud Spanner and PGAdapter. The sample readme file also lists the [current limitations](../samples/java/liquibase) +when working with `Liquibase`. diff --git a/docs/node-postgres.md b/docs/node-postgres.md index 505e46a28..c2fbcf5d6 100644 --- a/docs/node-postgres.md +++ b/docs/node-postgres.md @@ -116,4 +116,3 @@ console.log(res); ``` ## Limitations -- [Prepared statements](https://www.postgresql.org/docs/current/sql-prepare.html) are limited to at most 50 parameters. diff --git a/docs/pgx.md b/docs/pgx.md index 5858ed6cb..63be7b501 100644 --- a/docs/pgx.md +++ b/docs/pgx.md @@ -104,6 +104,3 @@ res := conn.SendBatch(context.Background(), batch) ``` ## Limitations -- Server side [prepared statements](https://www.postgresql.org/docs/current/sql-prepare.html) are limited to at most 50 parameters. - `pgx` uses server side prepared statements for all parameterized statements in extended query mode. - You can use the [simple query protocol](https://pkg.go.dev/github.com/jackc/pgx/v4#QuerySimpleProtocol) to work around this limitation. diff --git a/pom.xml b/pom.xml index 86c58a1a7..4842496a4 100644 --- a/pom.xml +++ b/pom.xml @@ -34,13 +34,13 @@ com.google.cloud.spanner.pgadapter.nodejs.NodeJSTest - 6.31.2 + 6.33.0 2.6.1 4.0.0 google-cloud-spanner-pgadapter - 0.12.0 + 0.13.0 Google Cloud Spanner PostgreSQL Adapter jar @@ -72,6 +72,11 @@ + + com.google.auto.value + auto-value-annotations + 1.10.1 + com.google.cloud google-cloud-spanner @@ -80,7 +85,7 @@ org.postgresql postgresql - 42.5.0 + 42.5.1 com.kohlschutter.junixsocket @@ -154,7 +159,7 @@ org.mockito mockito-core - 4.8.1 + 4.9.0 test diff --git a/samples/golang/gorm/README.md b/samples/golang/gorm/README.md index d71e59161..5ade1a462 100644 --- a/samples/golang/gorm/README.md +++ b/samples/golang/gorm/README.md @@ -1,12 +1,21 @@ # PGAdapter and gorm -PGAdapter can be used with [gorm](https://gorm.io/) and the `pgx` driver. This document shows how to use this sample -application, and lists the limitations when working with `gorm` with PGAdapter. +PGAdapter has Pilot Support for [gorm](https://gorm.io/) with the `pgx` driver. This document shows +how to use this sample application, and lists the limitations when working with `gorm` with PGAdapter. The [sample.go](sample.go) file contains a sample application using `gorm` with PGAdapter. Use this as a reference for features of `gorm` that are supported with PGAdapter. This sample assumes that the reader is familiar with `gorm`, and it is not intended as a tutorial for how to use `gorm` in general. +## Pilot Support +Pilot Support means that `gorm` can be used with Cloud Spanner PostgreSQL databases, but with limitations. +Applications that have been developed with `gorm` for PostgreSQL will probably require modifications +before they can be used with Cloud Spanner PostgreSQL databases. It is possible to develop new +applications using `gorm` with Cloud Spanner PostgreSQL databases. These applications will also work +with PostgreSQL without modifications. + +See [Limitations](#limitations) for a full list of limitations when working with `gorm`. + ## Start PGAdapter You must start PGAdapter before you can run the sample. The following command shows how to start PGAdapter using the pre-built Docker image. See [Running PGAdapter](../../../README.md#usage) for more information on other options for how @@ -43,7 +52,7 @@ psql -h localhost -p 5432 -d my-database -f drop_data_model.sql Cloud Spanner supports the following data types in combination with `gorm`. | PostgreSQL Type | gorm / go type | -|------------------------------------------------------------------------| +|-----------------------------------------|------------------------------| | boolean | bool, sql.NullBool | | bigint / int8 | int64, sql.NullInt64 | | varchar | string, sql.NullString | @@ -62,7 +71,6 @@ The following limitations are currently known: |------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | Migrations | Cloud Spanner does not support the full PostgreSQL DDL dialect. Automated migrations using `gorm` are therefore not supported. | | Generated primary keys | Disable auto increment for primary key columns by adding the annotation `gorm:"primaryKey;autoIncrement:false"` to the primary key property. | -| Generated columns | Generated columns require support for `RETURNING` clauses. That is currently not supported by Cloud Spanner. | | OnConflict | OnConflict clauses are not supported | | Nested transactions | Nested transactions and savepoints are not supported. It is therefore recommended to set the configuration option `DisableNestedTransaction: true,` | | Locking | Lock clauses (e.g. `clause.Locking{Strength: "UPDATE"}`) are not supported. These are generally speaking also not required, as the default isolation level that is used by Cloud Spanner is serializable. | @@ -162,14 +170,3 @@ db, err := gorm.Open(postgres.Open(connectionString), &gorm.Config{ ### Locking Locking clauses, like `clause.Locking{Strength: "UPDATE"}`, are not supported. These are generally speaking also not required, as Cloud Spanner uses isolation level `serializable` for read/write transactions. - -### Large CreateInBatches -The `CreateInBatches` function will generate an insert statement in the following form: - -```sql -INSERT INTO my_table (col1, col2, col3) -VALUES ($1, $2, $3), ($4, $5, $6), ($7, $8, $9), ..., ($x, $y, $z) -``` - -PGAdapter currently does not support prepared statements with more than 50 parameters. Either reduce the number of rows -that are inserted in one batch or use the `SimpleQueryMode` to work around this limitation. diff --git a/samples/golang/gorm/sample.go b/samples/golang/gorm/sample.go index 48b10b825..0052e0262 100644 --- a/samples/golang/gorm/sample.go +++ b/samples/golang/gorm/sample.go @@ -58,7 +58,7 @@ type Singer struct { // FullName is generated by the database. The '->' marks this a read-only field. Preferably this field should also // include a `default:(-)` annotation, as that would make gorm read the value back using a RETURNING clause. That is // however currently not supported. - FullName string `gorm:"->;type:GENERATED ALWAYS AS (coalesce(concat(first_name,' '::varchar,last_name))) STORED;"` + FullName string `gorm:"->;type:GENERATED ALWAYS AS (coalesce(concat(first_name,' '::varchar,last_name))) STORED;default:(-);"` Active bool Albums []Album } @@ -115,7 +115,7 @@ func RunSample(connString string) error { // detects a nested transaction. Cloud Spanner does not support Savepoints, // so it is recommended to set this configuration option to true. DisableNestedTransaction: true, - Logger: logger.Default.LogMode(logger.Error), + Logger: logger.Default.LogMode(logger.Error), }) if err != nil { fmt.Printf("Failed to open gorm connection: %v\n", err) @@ -537,6 +537,11 @@ func CreateSinger(db *gorm.DB, firstName, lastName string) (string, error) { LastName: lastName, } res := db.Create(&singer) + // FullName is automatically generated by the database and should be returned to the client by + // the insert statement. + if singer.FullName != firstName+" "+lastName { + return "", fmt.Errorf("unexpected full name for singer: %v", singer.FullName) + } return singer.ID, res.Error } diff --git a/samples/java/hibernate/README.md b/samples/java/hibernate/README.md new file mode 100644 index 000000000..1807768b2 --- /dev/null +++ b/samples/java/hibernate/README.md @@ -0,0 +1,82 @@ +# PGAdapter and Hibernate + +PGAdapter can be used in combination with Hibernate, but with a number of limitations. This sample +shows the command line arguments and configuration that is needed in order to use Hibernate with +PGAdapter. + +## Start PGAdapter +You must start PGAdapter before you can run the sample. The following command shows how to start PGAdapter using the +pre-built Docker image. See [Running PGAdapter](../../../README.md#usage) for more information on other options for how +to run PGAdapter. + +```shell +export GOOGLE_APPLICATION_CREDENTIALS=/path/to/credentials.json +docker pull gcr.io/cloud-spanner-pg-adapter/pgadapter +docker run \ + -d -p 5432:5432 \ + -v ${GOOGLE_APPLICATION_CREDENTIALS}:${GOOGLE_APPLICATION_CREDENTIALS}:ro \ + -e GOOGLE_APPLICATION_CREDENTIALS \ + -v /tmp:/tmp \ + gcr.io/cloud-spanner-pg-adapter/pgadapter \ + -p my-project -i my-instance \ + -x +``` + +## Creating the Sample Data Model +Run the following command in this directory. Replace the host, port and database name with the actual host, port and +database name for your PGAdapter and database setup. + +```shell +psql -h localhost -p 5432 -d my-database -f sample-schema.sql +``` + +You can also drop an existing data model using the `drop_data_model.sql` script: + +```shell +psql -h localhost -p 5432 -d my-database -f drop-data-model.sql +``` + +## Data Types +Cloud Spanner supports the following data types in combination with `Hibernate`. + +| PostgreSQL Type | Hibernate or Java | +|----------------------------------------|-------------------| +| boolean | boolean | +| bigint / int8 | long | +| varchar | String | +| text | String | +| float8 / double precision | double | +| numeric | BigDecimal | +| timestamptz / timestamp with time zone | LocalDateTime | +| bytea | byte[] | +| date | LocalDate | + +## Limitations +The following limitations are currently known: + +| Limitation | Workaround | +|--------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------| +| Schema updates | Cloud Spanner does not support the full PostgreSQL DDL dialect. Automated schema updates using `hibernate` are therefore not supported. It is recommended to set the option `hibernate.hbm2ddl.auto=none` (or `spring.jpa.hibernate.ddl-auto=none` if you are using Spring). | +| Generated primary keys | Cloud Spanner does not support `sequences`. Auto-increment primary key is not supported. Remove auto increment annotation for primary key columns. The recommended type of primary key is a client side generated `UUID` stored as a string. | + + +### Schema Updates +Schema updates are not supported as Cloud Spanner does not support the full PostgreSQL DDL dialect. It is recommended to +create the schema manually. Note that PGAdapter does support `create table if not exists` / `drop table if exists`. +See [sample-schema.sql](src/main/resources/sample-schema-sql) for the data model for this example. + +### Generated Primary Keys +`Sequences` are not supported. Hence, auto increment primary key is not supported and should be replaced with primary key definitions that +are manually assigned. See https://cloud.google.com/spanner/docs/schema-design#primary-key-prevent-hotspots +for more information on choosing a good primary key. This sample uses UUIDs that are generated by the client for primary +keys. + +```java +public class User { + // This uses auto generated UUID. + @Id + @Column(columnDefinition = "varchar", nullable = false) + @GeneratedValue + private UUID id; +} +``` \ No newline at end of file diff --git a/samples/java/hibernate/pom.xml b/samples/java/hibernate/pom.xml new file mode 100644 index 000000000..31e8b50bd --- /dev/null +++ b/samples/java/hibernate/pom.xml @@ -0,0 +1,67 @@ + + + 4.0.0 + + org.example + hibernate + 1.0-SNAPSHOT + + + 1.8 + 1.8 + 1.8 + 5.3.20.Final + false + ${skipTests} + true + 8 + true + UTF-8 + UTF-8 + 6.31.1 + 2.5.1 + + + + + + org.hibernate + hibernate-entitymanager + 5.3.20.Final + + + + + org.hibernate + hibernate-core + 5.3.20.Final + + + + + org.postgresql + postgresql + 42.4.3 + + + + + javax.xml.bind + jaxb-api + 2.3.1 + + + com.sun.xml.bind + jaxb-impl + 2.3.4 + + + org.javassist + javassist + 3.25.0-GA + + + + \ No newline at end of file diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/metadata/DescribeMetadata.java b/samples/java/hibernate/src/main/java/com/google/cloud/postgres/CurrentLocalDateTimeGenerator.java similarity index 60% rename from src/main/java/com/google/cloud/spanner/pgadapter/metadata/DescribeMetadata.java rename to samples/java/hibernate/src/main/java/com/google/cloud/postgres/CurrentLocalDateTimeGenerator.java index 4f68b421c..24d8a7c47 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/metadata/DescribeMetadata.java +++ b/samples/java/hibernate/src/main/java/com/google/cloud/postgres/CurrentLocalDateTimeGenerator.java @@ -1,4 +1,4 @@ -// Copyright 2020 Google LLC +// Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,17 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -package com.google.cloud.spanner.pgadapter.metadata; +package com.google.cloud.postgres; -import com.google.api.core.InternalApi; +import java.time.LocalDateTime; +import org.hibernate.Session; +import org.hibernate.tuple.ValueGenerator; -/** Simple POJO superclass to hold results from a describe statement. */ -@InternalApi -public abstract class DescribeMetadata { +public class CurrentLocalDateTimeGenerator implements ValueGenerator { - protected T metadata; - - public T getMetadata() { - return this.metadata; + @Override + public LocalDateTime generateValue(Session session, Object entity) { + return LocalDateTime.now(); } } diff --git a/samples/java/hibernate/src/main/java/com/google/cloud/postgres/HibernateSampleTest.java b/samples/java/hibernate/src/main/java/com/google/cloud/postgres/HibernateSampleTest.java new file mode 100644 index 000000000..92d0a5fa1 --- /dev/null +++ b/samples/java/hibernate/src/main/java/com/google/cloud/postgres/HibernateSampleTest.java @@ -0,0 +1,314 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.postgres; + +import com.google.cloud.postgres.models.Albums; +import com.google.cloud.postgres.models.Concerts; +import com.google.cloud.postgres.models.HibernateConfiguration; +import com.google.cloud.postgres.models.Singers; +import com.google.cloud.postgres.models.Tracks; +import com.google.cloud.postgres.models.Venues; +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.persistence.criteria.CriteriaBuilder; +import javax.persistence.criteria.CriteriaDelete; +import javax.persistence.criteria.CriteriaQuery; +import javax.persistence.criteria.CriteriaUpdate; +import javax.persistence.criteria.Root; +import org.hibernate.Session; +import org.hibernate.Transaction; +import org.hibernate.query.Query; + +public class HibernateSampleTest { + + private static final Logger logger = Logger.getLogger(HibernateSampleTest.class.getName()); + + private HibernateConfiguration hibernateConfiguration; + + private List singersId = new ArrayList<>(); + private List albumsId = new ArrayList<>(); + private List tracksId = new ArrayList<>(); + private List venuesId = new ArrayList<>(); + private List concertsId = new ArrayList<>(); + + public HibernateSampleTest(HibernateConfiguration hibernateConfiguration) { + this.hibernateConfiguration = hibernateConfiguration; + } + + public void testJPACriteriaDelete() { + try (Session s = hibernateConfiguration.openSession()) { + final Singers singers = Utils.createSingers(); + final Albums albums = Utils.createAlbums(singers); + s.getTransaction().begin(); + s.saveOrUpdate(singers); + s.saveOrUpdate(albums); + final Tracks tracks1 = Utils.createTracks(albums.getId()); + s.saveOrUpdate(tracks1); + final Tracks tracks2 = Utils.createTracks(albums.getId()); + s.saveOrUpdate(tracks2); + s.getTransaction().commit(); + s.clear(); + + CriteriaBuilder cb = s.getCriteriaBuilder(); + CriteriaDelete albumsCriteriaDelete = cb.createCriteriaDelete(Albums.class); + Root albumsRoot = albumsCriteriaDelete.from(Albums.class); + albumsCriteriaDelete.where(cb.equal(albumsRoot.get("id"), albums.getId())); + Transaction transaction = s.beginTransaction(); + s.createQuery(albumsCriteriaDelete).executeUpdate(); + transaction.commit(); + } + } + + public void testJPACriteria() { + try (Session s = hibernateConfiguration.openSession()) { + CriteriaBuilder cb = s.getCriteriaBuilder(); + CriteriaQuery singersCriteriaQuery = cb.createQuery(Singers.class); + Root singersRoot = singersCriteriaQuery.from(Singers.class); + singersCriteriaQuery + .select(singersRoot) + .where( + cb.and( + cb.equal(singersRoot.get("firstName"), "David"), + cb.equal(singersRoot.get("lastName"), "Lee"))); + + Query singersQuery = s.createQuery(singersCriteriaQuery); + List singers = singersQuery.getResultList(); + + System.out.println("Listed singer: " + singers.size()); + + CriteriaUpdate albumsCriteriaUpdate = cb.createCriteriaUpdate(Albums.class); + Root albumsRoot = albumsCriteriaUpdate.from(Albums.class); + albumsCriteriaUpdate.set("marketingBudget", new BigDecimal("5.0")); + albumsCriteriaUpdate.where(cb.equal(albumsRoot.get("id"), UUID.fromString(albumsId.get(0)))); + Transaction transaction = s.beginTransaction(); + s.createQuery(albumsCriteriaUpdate).executeUpdate(); + transaction.commit(); + } + } + + public void testHqlUpdate() { + try (Session s = hibernateConfiguration.openSession()) { + Singers singers = Utils.createSingers(); + singers.setLastName("Cord"); + s.getTransaction().begin(); + s.saveOrUpdate(singers); + s.getTransaction().commit(); + + s.getTransaction().begin(); + Query query = + s.createQuery( + "update Singers set active=:active " + + "where lastName=:lastName and firstName=:firstName"); + query.setParameter("active", false); + query.setParameter("lastName", "Cord"); + query.setParameter("firstName", "David"); + query.executeUpdate(); + s.getTransaction().commit(); + + System.out.println("Updated singer: " + s.get(Singers.class, singers.getId())); + } + } + + public void testHqlList() { + try (Session s = hibernateConfiguration.openSession()) { + Query query = s.createQuery("from Singers"); + List list = query.list(); + System.out.println("Singers list size: " + list.size()); + + query = s.createQuery("from Singers order by fullName"); + query.setFirstResult(2); + list = query.list(); + System.out.println("Singers list size with first result: " + list.size()); + + /* Current Limit is not supported. */ + // query = s.createQuery("from Singers"); + // query.setMaxResults(2); + // list = query.list(); + // System.out.println("Singers list size with first result: " + list.size()); + + query = s.createQuery("select sum(sampleRate) from Tracks"); + list = query.list(); + System.out.println("Sample rate sum: " + list); + } + } + + public void testOneToManyData() { + try (Session s = hibernateConfiguration.openSession()) { + Venues venues = s.get(Venues.class, UUID.fromString(venuesId.get(0))); + if (venues == null) { + logger.log(Level.SEVERE, "Previously Added Venues Not Found."); + } + if (venues.getConcerts().size() <= 1) { + logger.log(Level.SEVERE, "Previously Added Concerts Not Found."); + } + + System.out.println("Venues fetched: " + venues); + } + } + + public void testDeletingData() { + try (Session s = hibernateConfiguration.openSession()) { + Singers singers = Utils.createSingers(); + s.getTransaction().begin(); + s.saveOrUpdate(singers); + s.getTransaction().commit(); + + singers = s.get(Singers.class, singers.getId()); + if (singers == null) { + logger.log(Level.SEVERE, "Added singers not found."); + } + + s.getTransaction().begin(); + s.delete(singers); + s.getTransaction().commit(); + + singers = s.get(Singers.class, singers.getId()); + if (singers != null) { + logger.log(Level.SEVERE, "Deleted singers found."); + } + } + } + + public void testAddingData() { + try (Session s = hibernateConfiguration.openSession()) { + final Singers singers = Utils.createSingers(); + final Albums albums = Utils.createAlbums(singers); + final Venues venues = Utils.createVenue(); + final Concerts concerts1 = Utils.createConcerts(singers, venues); + final Concerts concerts2 = Utils.createConcerts(singers, venues); + final Concerts concerts3 = Utils.createConcerts(singers, venues); + s.getTransaction().begin(); + s.saveOrUpdate(singers); + s.saveOrUpdate(albums); + s.saveOrUpdate(venues); + s.persist(concerts1); + s.persist(concerts2); + final Tracks tracks1 = Utils.createTracks(albums.getId()); + s.saveOrUpdate(tracks1); + final Tracks tracks2 = Utils.createTracks(albums.getId()); + s.saveOrUpdate(tracks2); + s.persist(concerts3); + s.getTransaction().commit(); + + singersId.add(singers.getId().toString()); + albumsId.add(albums.getId().toString()); + venuesId.add(venues.getId().toString()); + concertsId.add(concerts1.getId().toString()); + concertsId.add(concerts2.getId().toString()); + concertsId.add(concerts3.getId().toString()); + tracksId.add(tracks1.getId().getTrackNumber()); + tracksId.add(tracks2.getId().getTrackNumber()); + + System.out.println("Created Singer: " + singers.getId()); + System.out.println("Created Albums: " + albums.getId()); + System.out.println("Created Venues: " + venues.getId()); + System.out.println("Created Concerts: " + concerts1.getId()); + System.out.println("Created Concerts: " + concerts2.getId()); + System.out.println("Created Concerts: " + concerts3.getId()); + System.out.println("Created Tracks: " + tracks1.getId()); + System.out.println("Created Tracks: " + tracks2.getId()); + } + } + + public void testSessionRollback() { + try (Session s = hibernateConfiguration.openSession()) { + final Singers singers = Utils.createSingers(); + s.getTransaction().begin(); + s.saveOrUpdate(singers); + s.getTransaction().rollback(); + + System.out.println("Singers that was saved: " + singers.getId()); + Singers singersFromDb = s.get(Singers.class, singers.getId()); + if (singersFromDb == null) { + System.out.println("Singers not found as expected."); + } else { + logger.log(Level.SEVERE, "Singers found. Lookout for the error."); + } + } + } + + public void testForeignKey() { + try (Session s = hibernateConfiguration.openSession()) { + final Singers singers = Utils.createSingers(); + final Albums albums = Utils.createAlbums(singers); + + s.getTransaction().begin(); + s.saveOrUpdate(singers); + s.persist(albums); + s.getTransaction().commit(); + + singersId.add(singers.getId().toString()); + albumsId.add(albums.getId().toString()); + System.out.println("Created Singer: " + singers.getId()); + System.out.println("Created Albums: " + albums.getId()); + } + } + + public void executeTest() { + try { + System.out.println("Testing Foreign Key"); + testForeignKey(); + System.out.println("Foreign Key Test Completed"); + + System.out.println("Testing Session Rollback"); + testSessionRollback(); + System.out.println("Session Rollback Test Completed"); + + System.out.println("Testing Data Insert"); + testAddingData(); + System.out.println("Data Insert Test Completed"); + + System.out.println("Testing Data Delete"); + testDeletingData(); + System.out.println("Data Delete Test Completed"); + + System.out.println("Testing One to Many Fetch"); + testOneToManyData(); + System.out.println("One To Many Fetch Test Completed"); + + System.out.println("Testing HQL List"); + testHqlList(); + System.out.println("HQL List Test Completed"); + + System.out.println("Testing HQL Update"); + testHqlUpdate(); + System.out.println("HQL Update Test Completed"); + + System.out.println("Testing JPA List and Update"); + testJPACriteria(); + System.out.println("JPA List and Update Test Completed"); + + System.out.println("Testing JPA Delete"); + testJPACriteriaDelete(); + System.out.println("JPA Delete Test Completed"); + } finally { + // Make sure we always close the session factory when the test is done. Otherwise, the sample + // application might keep non-daemon threads alive and not stop. + hibernateConfiguration.closeSessionFactory(); + } + } + + public static void main(String[] args) { + System.out.println("Starting Hibernate Test"); + HibernateSampleTest hibernateSampleTest = + new HibernateSampleTest(HibernateConfiguration.createHibernateConfiguration()); + hibernateSampleTest.executeTest(); + System.out.println("Hibernate Test Ended Successfully"); + } +} diff --git a/samples/java/hibernate/src/main/java/com/google/cloud/postgres/Utils.java b/samples/java/hibernate/src/main/java/com/google/cloud/postgres/Utils.java new file mode 100644 index 000000000..f243a3cee --- /dev/null +++ b/samples/java/hibernate/src/main/java/com/google/cloud/postgres/Utils.java @@ -0,0 +1,83 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.postgres; + +import com.google.cloud.postgres.models.Albums; +import com.google.cloud.postgres.models.Concerts; +import com.google.cloud.postgres.models.Singers; +import com.google.cloud.postgres.models.Tracks; +import com.google.cloud.postgres.models.TracksId; +import com.google.cloud.postgres.models.Venues; +import java.math.BigDecimal; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.util.Random; +import java.util.UUID; + +public class Utils { + + private static Random random = new Random(); + + public static Singers createSingers() { + final Singers singers = new Singers(); + singers.setActive(true); + singers.setFirstName("David"); + singers.setLastName("Lee"); + singers.setCreatedAt(LocalDateTime.now()); + return singers; + } + + public static Albums createAlbums(Singers singers) { + final Albums albums = new Albums(); + albums.setTitle("Perfect"); + albums.setMarketingBudget(new BigDecimal("1.00")); + albums.setReleaseDate(LocalDate.now()); + albums.setCreatedAt(LocalDateTime.now()); + albums.setSingers(singers); + return albums; + } + + public static Concerts createConcerts(Singers singers, Venues venues) { + final Concerts concerts = new Concerts(); + concerts.setCreatedAt(LocalDateTime.now()); + concerts.setEndTime(LocalDateTime.now().plusHours(1)); + concerts.setStartTime(LocalDateTime.now()); + concerts.setName("Sunburn"); + concerts.setSingers(singers); + concerts.setVenues(venues); + return concerts; + } + + public static Tracks createTracks(UUID albumId) { + final Tracks tracks = new Tracks(); + tracks.setCreatedAt(LocalDateTime.now()); + tracks.setTitle("Perfect"); + tracks.setSampleRate(random.nextInt()); + TracksId tracksId = new TracksId(); + tracksId.setTrackNumber(random.nextInt()); + tracksId.setId(albumId); + tracks.setId(tracksId); + return tracks; + } + + public static Venues createVenue() { + final Venues venues = new Venues(); + venues.setCreatedAt(LocalDateTime.now()); + venues.setName("Hall"); + venues.setDescription("Theater"); + + return venues; + } +} diff --git a/samples/java/hibernate/src/main/java/com/google/cloud/postgres/models/Albums.java b/samples/java/hibernate/src/main/java/com/google/cloud/postgres/models/Albums.java new file mode 100644 index 000000000..33ca7eb01 --- /dev/null +++ b/samples/java/hibernate/src/main/java/com/google/cloud/postgres/models/Albums.java @@ -0,0 +1,139 @@ +package com.google.cloud.postgres.models; + +import com.google.cloud.postgres.CurrentLocalDateTimeGenerator; +import java.math.BigDecimal; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.util.Arrays; +import java.util.UUID; +import javax.persistence.Column; +import javax.persistence.Entity; +import javax.persistence.GeneratedValue; +import javax.persistence.Id; +import javax.persistence.JoinColumn; +import javax.persistence.Lob; +import javax.persistence.ManyToOne; +import org.hibernate.annotations.GenerationTime; +import org.hibernate.annotations.GeneratorType; +import org.hibernate.annotations.Type; + +@Entity +public class Albums { + + @Id + @Column(columnDefinition = "varchar", nullable = false) + @GeneratedValue + private UUID id; + + private String title; + + @Column(name = "marketing_budget") + private BigDecimal marketingBudget; + + @Column(name = "release_date", columnDefinition = "date") + private LocalDate releaseDate; + + @Lob + @Type(type = "org.hibernate.type.BinaryType") + @Column(name = "cover_picture") + private byte[] coverPicture; + + @ManyToOne + @JoinColumn(name = "singer_id") + private Singers singers; + + @Column(name = "created_at", columnDefinition = "timestamptz") + private LocalDateTime createdAt; + + @GeneratorType(type = CurrentLocalDateTimeGenerator.class, when = GenerationTime.ALWAYS) + @Column(name = "updated_at", columnDefinition = "timestamptz") + private LocalDateTime updatedAt; + + public UUID getId() { + return id; + } + + public void setId(UUID id) { + this.id = id; + } + + public String getTitle() { + return title; + } + + public void setTitle(String title) { + this.title = title; + } + + public BigDecimal getMarketingBudget() { + return marketingBudget; + } + + public void setMarketingBudget(BigDecimal marketingBudget) { + this.marketingBudget = marketingBudget; + } + + public byte[] getCoverPicture() { + return coverPicture; + } + + public void setCoverPicture(byte[] coverPicture) { + this.coverPicture = coverPicture; + } + + public Singers getSingers() { + return singers; + } + + public void setSingers(Singers singers) { + this.singers = singers; + } + + public LocalDate getReleaseDate() { + return releaseDate; + } + + public void setReleaseDate(LocalDate releaseDate) { + this.releaseDate = releaseDate; + } + + public LocalDateTime getCreatedAt() { + return createdAt; + } + + public void setCreatedAt(LocalDateTime createdAt) { + this.createdAt = createdAt; + } + + public LocalDateTime getUpdatedAt() { + return updatedAt; + } + + public void setUpdatedAt(LocalDateTime updatedAt) { + this.updatedAt = updatedAt; + } + + @Override + public String toString() { + return "Albums{" + + "id='" + + id + + '\'' + + ", title='" + + title + + '\'' + + ", marketingBudget=" + + marketingBudget + + ", releaseDate=" + + releaseDate + + ", coverPicture=" + + Arrays.toString(coverPicture) + + ", singers=" + + singers + + ", createdAt=" + + createdAt + + ", updatedAt=" + + updatedAt + + '}'; + } +} diff --git a/samples/java/hibernate/src/main/java/com/google/cloud/postgres/models/Concerts.java b/samples/java/hibernate/src/main/java/com/google/cloud/postgres/models/Concerts.java new file mode 100644 index 000000000..10bef124a --- /dev/null +++ b/samples/java/hibernate/src/main/java/com/google/cloud/postgres/models/Concerts.java @@ -0,0 +1,145 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.postgres.models; + +import com.google.cloud.postgres.CurrentLocalDateTimeGenerator; +import java.time.LocalDateTime; +import java.util.UUID; +import javax.persistence.Column; +import javax.persistence.Entity; +import javax.persistence.GeneratedValue; +import javax.persistence.Id; +import javax.persistence.JoinColumn; +import javax.persistence.ManyToOne; +import org.hibernate.annotations.GenerationTime; +import org.hibernate.annotations.GeneratorType; + +@Entity +public class Concerts { + + @Id + @Column(columnDefinition = "varchar", nullable = false) + @GeneratedValue + private UUID id; + + @Column(name = "name", nullable = false) + private String name; + + @Column(name = "start_time", columnDefinition = "timestamptz") + private LocalDateTime startTime; + + @Column(name = "end_time", columnDefinition = "timestamptz") + private LocalDateTime endTime; + + @ManyToOne + @JoinColumn(name = "singer_id") + private Singers singers; + + @ManyToOne + @JoinColumn(name = "venue_id") + private Venues venues; + + @Column(name = "created_at", columnDefinition = "timestamptz") + private LocalDateTime createdAt; + + @GeneratorType(type = CurrentLocalDateTimeGenerator.class, when = GenerationTime.ALWAYS) + @Column(name = "updated_at", columnDefinition = "timestamptz") + private LocalDateTime updatedAt; + + public UUID getId() { + return id; + } + + public void setId(UUID id) { + this.id = id; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public LocalDateTime getStartTime() { + return startTime; + } + + public void setStartTime(LocalDateTime startTime) { + this.startTime = startTime; + } + + public LocalDateTime getEndTime() { + return endTime; + } + + public void setEndTime(LocalDateTime endTime) { + this.endTime = endTime; + } + + public Singers getSingers() { + return singers; + } + + public void setSingers(Singers singers) { + this.singers = singers; + } + + public Venues getVenues() { + return venues; + } + + public void setVenues(Venues venues) { + this.venues = venues; + } + + public LocalDateTime getCreatedAt() { + return createdAt; + } + + public void setCreatedAt(LocalDateTime createdAt) { + this.createdAt = createdAt; + } + + public LocalDateTime getUpdatedAt() { + return updatedAt; + } + + public void setUpdatedAt(LocalDateTime updatedAt) { + this.updatedAt = updatedAt; + } + + @Override + public String toString() { + return "Concerts{" + + "id=" + + id + + ", name='" + + name + + '\'' + + ", startTime=" + + startTime + + ", endTime=" + + endTime + + ", singers=" + + singers + + ", createdAt=" + + createdAt + + ", updatedAt=" + + updatedAt + + '}'; + } +} diff --git a/samples/java/hibernate/src/main/java/com/google/cloud/postgres/models/HibernateConfiguration.java b/samples/java/hibernate/src/main/java/com/google/cloud/postgres/models/HibernateConfiguration.java new file mode 100644 index 000000000..324998295 --- /dev/null +++ b/samples/java/hibernate/src/main/java/com/google/cloud/postgres/models/HibernateConfiguration.java @@ -0,0 +1,54 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.postgres.models; + +import org.hibernate.Session; +import org.hibernate.SessionFactory; +import org.hibernate.boot.registry.StandardServiceRegistryBuilder; +import org.hibernate.cfg.Configuration; + +public class HibernateConfiguration { + + private Configuration configuration; + private SessionFactory sessionFactory; + + private HibernateConfiguration(Configuration configuration, SessionFactory sessionFactory) { + this.configuration = configuration; + this.sessionFactory = sessionFactory; + } + + public Session openSession() { + return sessionFactory.openSession(); + } + + public void closeSessionFactory() { + sessionFactory.close(); + } + + public static HibernateConfiguration createHibernateConfiguration() { + final Configuration configuration = new Configuration(); + configuration.addAnnotatedClass(Albums.class); + configuration.addAnnotatedClass(Concerts.class); + configuration.addAnnotatedClass(Singers.class); + configuration.addAnnotatedClass(Venues.class); + configuration.addAnnotatedClass(Tracks.class); + configuration.addAnnotatedClass(TracksId.class); + + final SessionFactory sessionFactory = + configuration.buildSessionFactory(new StandardServiceRegistryBuilder().build()); + + return new HibernateConfiguration(configuration, sessionFactory); + } +} diff --git a/samples/java/hibernate/src/main/java/com/google/cloud/postgres/models/Singers.java b/samples/java/hibernate/src/main/java/com/google/cloud/postgres/models/Singers.java new file mode 100644 index 000000000..7158c9bec --- /dev/null +++ b/samples/java/hibernate/src/main/java/com/google/cloud/postgres/models/Singers.java @@ -0,0 +1,130 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.postgres.models; + +import com.google.cloud.postgres.CurrentLocalDateTimeGenerator; +import java.time.LocalDateTime; +import java.util.UUID; +import javax.persistence.Column; +import javax.persistence.Entity; +import javax.persistence.GeneratedValue; +import javax.persistence.Id; +import org.hibernate.annotations.Generated; +import org.hibernate.annotations.GenerationTime; +import org.hibernate.annotations.GeneratorType; + +@Entity +public class Singers { + + @Id + @Column(columnDefinition = "varchar", nullable = false) + @GeneratedValue + private UUID id; + + @Column(name = "first_name") + private String firstName; + + @Column(name = "last_name", nullable = false) + private String lastName; + + @Column(name = "full_name", insertable = false) + @Generated(GenerationTime.ALWAYS) + private String fullName; + + private boolean active; + + @Column(name = "created_at", columnDefinition = "timestamptz") + private LocalDateTime createdAt; + + @GeneratorType(type = CurrentLocalDateTimeGenerator.class, when = GenerationTime.ALWAYS) + @Column(name = "updated_at", columnDefinition = "timestamptz") + private LocalDateTime updatedAt; + + public UUID getId() { + return id; + } + + public String getFirstName() { + return firstName; + } + + public String getLastName() { + return lastName; + } + + public String getFullName() { + return fullName; + } + + public boolean isActive() { + return active; + } + + public void setId(UUID id) { + this.id = id; + } + + public void setFirstName(String firstName) { + this.firstName = firstName; + } + + public void setLastName(String lastName) { + this.lastName = lastName; + } + + public void setActive(boolean active) { + this.active = active; + } + + public LocalDateTime getCreatedAt() { + return createdAt; + } + + public void setCreatedAt(LocalDateTime createdAt) { + this.createdAt = createdAt; + } + + public LocalDateTime getUpdatedAt() { + return updatedAt; + } + + public void setUpdatedAt(LocalDateTime updatedAt) { + this.updatedAt = updatedAt; + } + + @Override + public String toString() { + return "Singers{" + + "id='" + + id + + '\'' + + ", firstName='" + + firstName + + '\'' + + ", lastName='" + + lastName + + '\'' + + ", fullName='" + + fullName + + '\'' + + ", active=" + + active + + ", createdAt=" + + createdAt + + ", updatedAt=" + + updatedAt + + '}'; + } +} diff --git a/samples/java/hibernate/src/main/java/com/google/cloud/postgres/models/Tracks.java b/samples/java/hibernate/src/main/java/com/google/cloud/postgres/models/Tracks.java new file mode 100644 index 000000000..f11886ea6 --- /dev/null +++ b/samples/java/hibernate/src/main/java/com/google/cloud/postgres/models/Tracks.java @@ -0,0 +1,100 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.postgres.models; + +import com.google.cloud.postgres.CurrentLocalDateTimeGenerator; +import java.time.LocalDateTime; +import javax.persistence.Column; +import javax.persistence.EmbeddedId; +import javax.persistence.Entity; +import org.hibernate.annotations.GenerationTime; +import org.hibernate.annotations.GeneratorType; + +@Entity +public class Tracks { + + // For composite primary keys, @EmbeddedId will have to be used. + @EmbeddedId private TracksId id; + + @Column(name = "title", nullable = false) + private String title; + + @Column(name = "sample_rate") + private double sampleRate; + + @Column(name = "created_at", columnDefinition = "timestamptz") + private LocalDateTime createdAt; + + @GeneratorType(type = CurrentLocalDateTimeGenerator.class, when = GenerationTime.ALWAYS) + @Column(name = "updated_at", columnDefinition = "timestamptz") + private LocalDateTime updatedAt; + + public TracksId getId() { + return id; + } + + public void setId(TracksId id) { + this.id = id; + } + + public String getTitle() { + return title; + } + + public void setTitle(String title) { + this.title = title; + } + + public double getSampleRate() { + return sampleRate; + } + + public void setSampleRate(double sampleRate) { + this.sampleRate = sampleRate; + } + + public LocalDateTime getCreatedAt() { + return createdAt; + } + + public void setCreatedAt(LocalDateTime createdAt) { + this.createdAt = createdAt; + } + + public LocalDateTime getUpdatedAt() { + return updatedAt; + } + + public void setUpdatedAt(LocalDateTime updatedAt) { + this.updatedAt = updatedAt; + } + + @Override + public String toString() { + return "Tracks{" + + "id=" + + id + + ", title='" + + title + + '\'' + + ", sampleRate=" + + sampleRate + + ", createdAt=" + + createdAt + + ", updatedAt=" + + updatedAt + + '}'; + } +} diff --git a/samples/java/hibernate/src/main/java/com/google/cloud/postgres/models/TracksId.java b/samples/java/hibernate/src/main/java/com/google/cloud/postgres/models/TracksId.java new file mode 100644 index 000000000..87c079d58 --- /dev/null +++ b/samples/java/hibernate/src/main/java/com/google/cloud/postgres/models/TracksId.java @@ -0,0 +1,61 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.postgres.models; + +import java.io.Serializable; +import java.util.UUID; +import javax.persistence.Column; +import javax.persistence.Embeddable; + +/** + * @Embeddable is to be used for composite primary key. + */ +@Embeddable +public class TracksId implements Serializable { + + @Column(columnDefinition = "varchar", nullable = false) + private UUID id; + + @Column(name = "track_number", nullable = false) + private long trackNumber; + + public TracksId() {} + + public TracksId(UUID id, long trackNumber) { + this.id = id; + this.trackNumber = trackNumber; + } + + public UUID getId() { + return id; + } + + public void setId(UUID id) { + this.id = id; + } + + public long getTrackNumber() { + return trackNumber; + } + + public void setTrackNumber(long trackNumber) { + this.trackNumber = trackNumber; + } + + @Override + public String toString() { + return "TracksId{" + "id=" + id + ", trackNumber=" + trackNumber + '}'; + } +} diff --git a/samples/java/hibernate/src/main/java/com/google/cloud/postgres/models/Venues.java b/samples/java/hibernate/src/main/java/com/google/cloud/postgres/models/Venues.java new file mode 100644 index 000000000..22b16a61e --- /dev/null +++ b/samples/java/hibernate/src/main/java/com/google/cloud/postgres/models/Venues.java @@ -0,0 +1,123 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.postgres.models; + +import com.google.cloud.postgres.CurrentLocalDateTimeGenerator; +import java.time.LocalDateTime; +import java.util.List; +import java.util.UUID; +import javax.persistence.CascadeType; +import javax.persistence.Column; +import javax.persistence.Entity; +import javax.persistence.GeneratedValue; +import javax.persistence.Id; +import javax.persistence.JoinColumn; +import javax.persistence.OneToMany; +import org.hibernate.annotations.GenerationTime; +import org.hibernate.annotations.GeneratorType; + +@Entity +public class Venues { + + @Id + @Column(columnDefinition = "varchar", nullable = false) + @GeneratedValue + private UUID id; + + @Column(name = "name", nullable = false) + private String name; + + @Column(name = "description", nullable = false) + private String description; + + @Column(name = "created_at", columnDefinition = "timestamptz") + private LocalDateTime createdAt; + + @GeneratorType(type = CurrentLocalDateTimeGenerator.class, when = GenerationTime.ALWAYS) + @Column(name = "updated_at", columnDefinition = "timestamptz") + private LocalDateTime updatedAt; + + @OneToMany(cascade = CascadeType.ALL) + @JoinColumn(name = "venue_id") + private List concerts; + + public UUID getId() { + return id; + } + + public void setId(UUID id) { + this.id = id; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public String getDescription() { + return description; + } + + public void setDescription(String description) { + this.description = description; + } + + public LocalDateTime getCreatedAt() { + return createdAt; + } + + public void setCreatedAt(LocalDateTime createdAt) { + this.createdAt = createdAt; + } + + public LocalDateTime getUpdatedAt() { + return updatedAt; + } + + public void setUpdatedAt(LocalDateTime updatedAt) { + this.updatedAt = updatedAt; + } + + public List getConcerts() { + return concerts; + } + + public void setConcerts(List concerts) { + this.concerts = concerts; + } + + @Override + public String toString() { + return "Venues{" + + "id=" + + id + + ", name='" + + name + + '\'' + + ", description='" + + description + + '\'' + + ", createdAt=" + + createdAt + + ", updatedAt=" + + updatedAt + + ", concerts=" + + concerts + + '}'; + } +} diff --git a/samples/java/hibernate/src/main/resources/drop-data-model.sql b/samples/java/hibernate/src/main/resources/drop-data-model.sql new file mode 100644 index 000000000..977cebd9b --- /dev/null +++ b/samples/java/hibernate/src/main/resources/drop-data-model.sql @@ -0,0 +1,10 @@ +-- Executing the schema drop in a batch will improve execution speed. +start batch ddl; + +drop table if exists concerts; +drop table if exists venues; +drop table if exists tracks; +drop table if exists albums; +drop table if exists singers; + +run batch; \ No newline at end of file diff --git a/samples/java/hibernate/src/main/resources/hibernate.properties b/samples/java/hibernate/src/main/resources/hibernate.properties new file mode 100644 index 000000000..be825c6ac --- /dev/null +++ b/samples/java/hibernate/src/main/resources/hibernate.properties @@ -0,0 +1,14 @@ +# [START spanner_hibernate_config] +hibernate.dialect org.hibernate.dialect.PostgreSQLDialect +hibernate.connection.driver_class org.postgresql.Driver +# [END spanner_hibernate_config] + +hibernate.connection.url jdbc:postgresql://localhost:5432/test-database +hibernate.connection.username pratick + +hibernate.connection.pool_size 5 + +hibernate.show_sql true +hibernate.format_sql true + +hibernate.hbm2ddl.auto validate diff --git a/samples/java/hibernate/src/main/resources/log4f.properties b/samples/java/hibernate/src/main/resources/log4f.properties new file mode 100644 index 000000000..a4e96acc5 --- /dev/null +++ b/samples/java/hibernate/src/main/resources/log4f.properties @@ -0,0 +1,14 @@ +# Direct log messages to stdout +log4j.appender.stdout=org.apache.log4j.ConsoleAppender +log4j.appender.stdout.Target=System.out +log4j.appender.stdout.layout=org.apache.log4j.PatternLayout +log4j.appender.stdout.layout.ConversionPattern=%d{ABSOLUTE} %5p %c{1}:%L - %m%n + +# Root logger option +log4j.rootLogger=INFO, stdout + +# Log everything. Good for troubleshooting +log4j.logger.org.hibernate=INFO + +# Log all JDBC parameters +log4j.logger.org.hibernate.type=ALL diff --git a/samples/java/hibernate/src/main/resources/sample-schema-sql b/samples/java/hibernate/src/main/resources/sample-schema-sql new file mode 100644 index 000000000..1ae32e684 --- /dev/null +++ b/samples/java/hibernate/src/main/resources/sample-schema-sql @@ -0,0 +1,58 @@ +-- Executing the schema creation in a batch will improve execution speed. +start batch ddl; + +create table if not exists singers ( + id varchar not null primary key, + first_name varchar, + last_name varchar not null, + full_name varchar generated always as (coalesce(concat(first_name, ' '::varchar, last_name), last_name)) stored, + active boolean, + created_at timestamptz, + updated_at timestamptz +); + +create table if not exists albums ( + id varchar not null primary key, + title varchar not null, + marketing_budget numeric, + release_date date, + cover_picture bytea, + singer_id varchar not null, + created_at timestamptz, + updated_at timestamptz, + constraint fk_albums_singers foreign key (singer_id) references singers (id) +); + +create table if not exists tracks ( + id varchar not null, + track_number bigint not null, + title varchar not null, + sample_rate float8 not null, + created_at timestamptz, + updated_at timestamptz, + primary key (id, track_number) +) interleave in parent albums on delete cascade; + +create table if not exists venues ( + id varchar not null primary key, + name varchar not null, + description varchar not null, + created_at timestamptz, + updated_at timestamptz +); + +create table if not exists concerts ( + id varchar not null primary key, + venue_id varchar not null, + singer_id varchar not null, + name varchar not null, + start_time timestamptz not null, + end_time timestamptz not null, + created_at timestamptz, + updated_at timestamptz, + constraint fk_concerts_venues foreign key (venue_id) references venues (id), + constraint fk_concerts_singers foreign key (singer_id) references singers (id), + constraint chk_end_time_after_start_time check (end_time > start_time) +); + +run batch; \ No newline at end of file diff --git a/samples/java/liquibase/README.md b/samples/java/liquibase/README.md index 0c23d0331..a727e60d4 100644 --- a/samples/java/liquibase/README.md +++ b/samples/java/liquibase/README.md @@ -1,6 +1,6 @@ # PGAdapter and Liquibase -PGAdapter can be used in combination with Liquibase, but with a number of limitations. This sample +PGAdapter has Pilot Support for [Liquibase](https://www.liquibase.org/). This sample shows the command line arguments and configuration that is needed in order to use Liquibase with PGAdapter. diff --git a/samples/java/liquibase/liquibase.properties b/samples/java/liquibase/liquibase.properties index 000b194af..bb0aae572 100644 --- a/samples/java/liquibase/liquibase.properties +++ b/samples/java/liquibase/liquibase.properties @@ -7,4 +7,4 @@ changeLogFile: dbchangelog.xml # DDL transactions into DDL batches. # See https://github.com/GoogleCloudPlatform/pgadapter/blob/postgresql-dialect/docs/ddl.md for more # information. -url: jdbc:postgresql://localhost:5432/liquibase-test?options=-c%20spanner.ddl_transaction_mode=AutocommitExplicitTransaction%20-c%20server_version=14.1 +url: jdbc:postgresql://localhost:5432/liquibase-test?options=-c%20spanner.ddl_transaction_mode=AutocommitExplicitTransaction diff --git a/samples/java/liquibase/pom.xml b/samples/java/liquibase/pom.xml index 8a8ed516a..0da1e29a9 100644 --- a/samples/java/liquibase/pom.xml +++ b/samples/java/liquibase/pom.xml @@ -71,7 +71,7 @@ org.postgresql postgresql - 42.4.1 + 42.4.3 diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/ConnectionHandler.java b/src/main/java/com/google/cloud/spanner/pgadapter/ConnectionHandler.java index ef42c8d72..5492e44cf 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/ConnectionHandler.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/ConnectionHandler.java @@ -37,6 +37,7 @@ import com.google.cloud.spanner.pgadapter.error.SQLState; import com.google.cloud.spanner.pgadapter.error.Severity; import com.google.cloud.spanner.pgadapter.metadata.ConnectionMetadata; +import com.google.cloud.spanner.pgadapter.metadata.DescribeResult; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata.SslMode; import com.google.cloud.spanner.pgadapter.statements.CopyStatement; @@ -70,6 +71,7 @@ import java.util.Properties; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Future; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Level; @@ -95,7 +97,7 @@ public class ConnectionHandler extends Thread { private final ProxyServer server; private Socket socket; private final Map statementsMap = new HashMap<>(); - private final Cache autoDescribedStatementsCache = + private final Cache> autoDescribedStatementsCache = CacheBuilder.newBuilder() .expireAfterWrite(Duration.ofMinutes(30L)) .maximumSize(5000L) @@ -587,13 +589,13 @@ public void registerStatement(String statementName, IntermediatePreparedStatemen * Returns the parameter types of a cached auto-described statement, or null if none is available * in the cache. */ - public int[] getAutoDescribedStatement(String sql) { + public Future getAutoDescribedStatement(String sql) { return this.autoDescribedStatementsCache.getIfPresent(sql); } /** Stores the parameter types of an auto-described statement in the cache. */ - public void registerAutoDescribedStatement(String sql, int[] parameterTypes) { - this.autoDescribedStatementsCache.put(sql, parameterTypes); + public void registerAutoDescribedStatement(String sql, Future describeResult) { + this.autoDescribedStatementsCache.put(sql, describeResult); } public void closeStatement(String statementName) { diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/error/PGException.java b/src/main/java/com/google/cloud/spanner/pgadapter/error/PGException.java index 5b6650f95..ce6f99b94 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/error/PGException.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/error/PGException.java @@ -26,9 +26,10 @@ public class PGException extends RuntimeException { public static class Builder { private final String message; - private Severity severity; + private Severity severity = Severity.ERROR; private SQLState sqlState; private String hints; + private Throwable cause; private Builder(String message) { this.message = message; @@ -49,15 +50,20 @@ public Builder setHints(String hints) { return this; } + public Builder setCause(Throwable cause) { + this.cause = cause; + return this; + } + public PGException build() { - return new PGException(severity, sqlState, message, hints); + return new PGException(cause, severity, sqlState, message, hints); } } public static Builder newBuilder(Exception cause) { Preconditions.checkNotNull(cause); - return new Builder( - cause.getMessage() == null ? cause.getClass().getName() : cause.getMessage()); + return new Builder(cause.getMessage() == null ? cause.getClass().getName() : cause.getMessage()) + .setCause(cause); } public static Builder newBuilder(String message) { @@ -68,8 +74,9 @@ public static Builder newBuilder(String message) { private final SQLState sqlState; private final String hints; - private PGException(Severity severity, SQLState sqlState, String message, String hints) { - super(message == null ? "" : message); + private PGException( + Throwable cause, Severity severity, SQLState sqlState, String message, String hints) { + super(message == null ? "" : message, cause); this.severity = severity; this.sqlState = sqlState; this.hints = hints; diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/error/PGExceptionFactory.java b/src/main/java/com/google/cloud/spanner/pgadapter/error/PGExceptionFactory.java index 207cea96d..72753eaaf 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/error/PGExceptionFactory.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/error/PGExceptionFactory.java @@ -14,6 +14,8 @@ package com.google.cloud.spanner.pgadapter.error; +import static com.google.cloud.spanner.pgadapter.statements.BackendConnection.TRANSACTION_ABORTED_ERROR; + import com.google.api.core.InternalApi; import com.google.cloud.spanner.SpannerException; import io.grpc.StatusRuntimeException; @@ -58,17 +60,28 @@ public static PGException newQueryCancelledException() { return newPGException("Query cancelled", SQLState.QueryCanceled); } + /** + * Creates a new exception that indicates that the current transaction is in the aborted state. + */ + public static PGException newTransactionAbortedException() { + return newPGException(TRANSACTION_ABORTED_ERROR, SQLState.InFailedSqlTransaction); + } + /** Converts the given {@link SpannerException} to a {@link PGException}. */ public static PGException toPGException(SpannerException spannerException) { return newPGException(extractMessage(spannerException)); } /** Converts the given {@link Exception} to a {@link PGException}. */ - public static PGException toPGException(Exception exception) { - if (exception instanceof SpannerException) { - return toPGException((SpannerException) exception); + public static PGException toPGException(Throwable throwable) { + if (throwable instanceof SpannerException) { + return toPGException((SpannerException) throwable); + } + if (throwable instanceof PGException) { + return (PGException) throwable; } - return newPGException(exception.getMessage()); + return newPGException( + throwable.getMessage() == null ? throwable.getClass().getName() : throwable.getMessage()); } private static final String NOT_FOUND_PREFIX = diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/error/SQLState.java b/src/main/java/com/google/cloud/spanner/pgadapter/error/SQLState.java index 1d47d0b6a..89be8e33d 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/error/SQLState.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/error/SQLState.java @@ -135,6 +135,20 @@ public enum SQLState { SqlJsonScalarRequired("2203F"), SqlJsonItemCannotBeCastToTargetType("2203G"), + // Class 25 - Invalid Transaction State + InvalidTransactionState("25000"), + ActiveSqlTransaction("25001"), + BranchTransactionAlreadyActive("25002"), + HeldCursorRequiresSameIsolationLevel("25008"), + InappropriateAccessModeForBranchTransaction("25003"), + InappropriateIsolationLevelForBranchTransaction("25004"), + NoActiveSqlTransactionForBranchTransaction("25005"), + ReadOnlySqlTransaction("25006"), + SchemaAndDataStatementMixingNotSupported("25007"), + NoActiveSqlTransaction("25P01"), + InFailedSqlTransaction("25P02"), + IdleInTransactionSessionTimeout("25P03"), + // Class 42 — Syntax Error or Access Rule Violation SyntaxErrorOrAccessRuleViolation("42000"), SyntaxError("42601"), diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/metadata/DescribePortalMetadata.java b/src/main/java/com/google/cloud/spanner/pgadapter/metadata/DescribePortalMetadata.java deleted file mode 100644 index b8611a985..000000000 --- a/src/main/java/com/google/cloud/spanner/pgadapter/metadata/DescribePortalMetadata.java +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.google.cloud.spanner.pgadapter.metadata; - -import com.google.api.core.InternalApi; -import com.google.cloud.spanner.ResultSet; - -/** Simple POJO to hold describe metadata, specific to portal describes. */ -@InternalApi -public class DescribePortalMetadata extends DescribeMetadata { - - public DescribePortalMetadata(ResultSet metadata) { - this.metadata = metadata; - } -} diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/metadata/DescribeResult.java b/src/main/java/com/google/cloud/spanner/pgadapter/metadata/DescribeResult.java new file mode 100644 index 000000000..8e43e0507 --- /dev/null +++ b/src/main/java/com/google/cloud/spanner/pgadapter/metadata/DescribeResult.java @@ -0,0 +1,86 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.spanner.pgadapter.metadata; + +import com.google.api.core.InternalApi; +import com.google.cloud.spanner.ResultSet; +import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory; +import com.google.cloud.spanner.pgadapter.error.SQLState; +import com.google.cloud.spanner.pgadapter.parsers.Parser; +import com.google.spanner.v1.StructType; +import java.util.Arrays; +import javax.annotation.Nullable; + +@InternalApi +public class DescribeResult { + @Nullable private final ResultSet resultSet; + private final int[] parameters; + + public DescribeResult(int[] givenParameterTypes, @Nullable ResultSet resultMetadata) { + this.resultSet = resultMetadata; + this.parameters = extractParameters(givenParameterTypes, resultMetadata); + } + + static int[] extractParameters(int[] givenParameterTypes, @Nullable ResultSet resultSet) { + if (resultSet == null || !resultSet.getMetadata().hasUndeclaredParameters()) { + return givenParameterTypes; + } + return extractParameterTypes( + givenParameterTypes, resultSet.getMetadata().getUndeclaredParameters()); + } + + static int[] extractParameterTypes(int[] givenParameterTypes, StructType parameters) { + int[] result; + int maxParamIndex = maxParamNumber(parameters); + if (maxParamIndex == givenParameterTypes.length) { + result = givenParameterTypes; + } else { + result = + Arrays.copyOf(givenParameterTypes, Math.max(givenParameterTypes.length, maxParamIndex)); + } + for (int i = 0; i < parameters.getFieldsCount(); i++) { + // Only override parameter types that were not specified by the frontend. + int paramIndex = Integer.parseInt(parameters.getFields(i).getName().substring(1)) - 1; + if (paramIndex >= givenParameterTypes.length || givenParameterTypes[paramIndex] == 0) { + result[paramIndex] = Parser.toOid(parameters.getFields(i).getType()); + } + } + return result; + } + + static int maxParamNumber(StructType parameters) { + int max = 0; + for (int i = 0; i < parameters.getFieldsCount(); i++) { + try { + int paramIndex = Integer.parseInt(parameters.getFields(i).getName().substring(1)); + if (paramIndex > max) { + max = paramIndex; + } + } catch (NumberFormatException numberFormatException) { + throw PGExceptionFactory.newPGException( + "Invalid parameter name: " + parameters.getFields(i).getName(), SQLState.InternalError); + } + } + return max; + } + + public int[] getParameters() { + return parameters; + } + + public @Nullable ResultSet getResultSet() { + return resultSet; + } +} diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/metadata/DescribeStatementMetadata.java b/src/main/java/com/google/cloud/spanner/pgadapter/metadata/DescribeStatementMetadata.java deleted file mode 100644 index badb818f3..000000000 --- a/src/main/java/com/google/cloud/spanner/pgadapter/metadata/DescribeStatementMetadata.java +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.google.cloud.spanner.pgadapter.metadata; - -import com.google.api.core.InternalApi; -import com.google.cloud.Tuple; -import com.google.cloud.spanner.ResultSet; - -/** Simple POJO to hold describe metadata specific to prepared statements. */ -@InternalApi -public class DescribeStatementMetadata extends DescribeMetadata> - implements AutoCloseable { - - public DescribeStatementMetadata(int[] parameters, ResultSet resultMetaData) { - this.metadata = Tuple.of(parameters, resultMetaData); - } - - public int[] getParameters() { - return metadata.x(); - } - - public ResultSet getResultSet() { - return metadata.y(); - } - - @Override - public void close() { - if (getResultSet() != null) { - getResultSet().close(); - } - } -} diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/metadata/OptionsMetadata.java b/src/main/java/com/google/cloud/spanner/pgadapter/metadata/OptionsMetadata.java index 549945327..6abac46f3 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/metadata/OptionsMetadata.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/metadata/OptionsMetadata.java @@ -42,6 +42,11 @@ /** Metadata extractor for CLI. */ public class OptionsMetadata { + /** Returns true if the current JVM is Java 8. */ + public static boolean isJava8() { + return System.getProperty("java.version").startsWith("1.8"); + } + public enum SslMode { /** Disables SSL connections. This is the default. */ Disable { @@ -83,7 +88,7 @@ public enum DdlTransactionMode { } private static final Logger logger = Logger.getLogger(OptionsMetadata.class.getName()); - private static final String DEFAULT_SERVER_VERSION = "1.0"; + private static final String DEFAULT_SERVER_VERSION = "14.1"; private static final String DEFAULT_USER_AGENT = "pg-adapter"; private static final String OPTION_SERVER_PORT = "s"; diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/metadata/SendResultSetState.java b/src/main/java/com/google/cloud/spanner/pgadapter/metadata/SendResultSetState.java index b5ef45708..ad796c518 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/metadata/SendResultSetState.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/metadata/SendResultSetState.java @@ -31,7 +31,9 @@ public SendResultSetState(String commandTag, long numRowsSent, boolean hasMoreRo } public String getCommandAndNumRows() { - return getCommandTag() + " " + getNumberOfRowsSent(); + String command = getCommandTag(); + command += ("INSERT".equals(command) ? " 0 " : " ") + getNumberOfRowsSent(); + return command; } public String getCommandTag() { diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/parsers/ArrayParser.java b/src/main/java/com/google/cloud/spanner/pgadapter/parsers/ArrayParser.java index 4b5003ef4..50c417dea 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/parsers/ArrayParser.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/parsers/ArrayParser.java @@ -21,6 +21,7 @@ import com.google.cloud.spanner.Type; import com.google.cloud.spanner.Type.Code; import com.google.cloud.spanner.Value; +import com.google.cloud.spanner.pgadapter.session.SessionState; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.util.LinkedList; @@ -39,8 +40,10 @@ public class ArrayParser extends Parser> { private final Type arrayElementType; private final boolean isStringEquivalent; + private final SessionState sessionState; - public ArrayParser(ResultSet item, int position) { + public ArrayParser(ResultSet item, int position, SessionState sessionState) { + this.sessionState = sessionState; if (item != null) { this.arrayElementType = item.getColumnType(position).getArrayElementType(); if (this.arrayElementType.getCode() == Code.ARRAY) { @@ -133,7 +136,9 @@ public String stringParse() { List results = new LinkedList<>(); for (Object currentItem : this.item) { results.add( - stringify(Parser.create(currentItem, this.arrayElementType.getCode()).stringParse())); + stringify( + Parser.create(currentItem, this.arrayElementType.getCode(), sessionState) + .stringParse())); } return results.stream() .collect(Collectors.joining(ARRAY_DELIMITER, PG_ARRAY_OPEN, PG_ARRAY_CLOSE)); @@ -147,7 +152,9 @@ protected String spannerParse() { List results = new LinkedList<>(); for (Object currentItem : this.item) { results.add( - stringify(Parser.create(currentItem, this.arrayElementType.getCode()).spannerParse())); + stringify( + Parser.create(currentItem, this.arrayElementType.getCode(), sessionState) + .spannerParse())); } return results.stream() .collect(Collectors.joining(ARRAY_DELIMITER, SPANNER_ARRAY_OPEN, SPANNER_ARRAY_CLOSE)); @@ -169,7 +176,9 @@ protected byte[] binaryParse() { if (currentItem == null) { arrayStream.write(IntegerParser.binaryParse(-1)); } else { - byte[] data = Parser.create(currentItem, this.arrayElementType.getCode()).binaryParse(); + byte[] data = + Parser.create(currentItem, this.arrayElementType.getCode(), sessionState) + .binaryParse(); arrayStream.write(IntegerParser.binaryParse(data.length)); arrayStream.write(data); } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/parsers/Parser.java b/src/main/java/com/google/cloud/spanner/pgadapter/parsers/Parser.java index 0b7964f17..16141cec1 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/parsers/Parser.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/parsers/Parser.java @@ -22,9 +22,11 @@ import com.google.cloud.spanner.Type; import com.google.cloud.spanner.Type.Code; import com.google.cloud.spanner.pgadapter.ProxyServer.DataFormat; +import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory; +import com.google.cloud.spanner.pgadapter.error.SQLState; +import com.google.cloud.spanner.pgadapter.session.SessionState; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; -import java.util.Set; import org.postgresql.core.Oid; /** @@ -52,38 +54,17 @@ public static FormatCode of(short code) { protected static final Charset UTF8 = StandardCharsets.UTF_8; protected T item; - /** - * Guess the type of a parameter with unspecified type. - * - * @param item The value to guess the type for - * @param formatCode The encoding that is used for the value - * @return The {@link Oid} type code that is guessed for the value or {@link Oid#UNSPECIFIED} if - * no type could be guessed. - */ - private static int guessType(Set guessTypes, byte[] item, FormatCode formatCode) { - if (formatCode == FormatCode.TEXT && item != null) { - String value = new String(item, StandardCharsets.UTF_8); - if (guessTypes.contains(Oid.TIMESTAMPTZ) && TimestampParser.isTimestamp(value)) { - return Oid.TIMESTAMPTZ; - } - if (guessTypes.contains(Oid.DATE) && DateParser.isDate(value)) { - return Oid.DATE; - } - } - return Oid.UNSPECIFIED; - } - /** * Factory method to create a Parser subtype with a designated type from a byte array. * * @param item The data to be parsed * @param oidType The type of the designated data * @param formatCode The format of the data to be parsed - * @param guessTypes The OIDs of the types that may be 'guessed' based on the input value + * @param sessionState The session state to use when parsing and converting * @return The parser object for the designated data type. */ public static Parser create( - Set guessTypes, byte[] item, int oidType, FormatCode formatCode) { + SessionState sessionState, byte[] item, int oidType, FormatCode formatCode) { switch (oidType) { case Oid.BOOL: case Oid.BIT: @@ -108,19 +89,15 @@ public static Parser create( case Oid.TEXT: case Oid.VARCHAR: return new StringParser(item, formatCode); + case Oid.UUID: + return new UuidParser(item, formatCode); case Oid.TIMESTAMP: case Oid.TIMESTAMPTZ: - return new TimestampParser(item, formatCode); + return new TimestampParser(item, formatCode, sessionState); case Oid.JSONB: return new JsonbParser(item, formatCode); case Oid.UNSPECIFIED: - // Try to guess the type based on the value. Use an unspecified parser if no type could be - // determined. - int type = guessType(guessTypes, item, formatCode); - if (type == Oid.UNSPECIFIED) { - return new UnspecifiedParser(item, formatCode); - } - return create(guessTypes, item, type, formatCode); + return new UnspecifiedParser(item, formatCode); default: throw new IllegalArgumentException("Unsupported parameter type: " + oidType); } @@ -134,7 +111,8 @@ public static Parser create( * @param columnarPosition Column from the result to be parsed. * @return The parser object for the designated data type. */ - public static Parser create(ResultSet result, Type type, int columnarPosition) { + public static Parser create( + ResultSet result, Type type, int columnarPosition, SessionState sessionState) { switch (type.getCode()) { case BOOL: return new BooleanParser(result, columnarPosition); @@ -151,11 +129,11 @@ public static Parser create(ResultSet result, Type type, int columnarPosition case STRING: return new StringParser(result, columnarPosition); case TIMESTAMP: - return new TimestampParser(result, columnarPosition); + return new TimestampParser(result, columnarPosition, sessionState); case PG_JSONB: return new JsonbParser(result, columnarPosition); case ARRAY: - return new ArrayParser(result, columnarPosition); + return new ArrayParser(result, columnarPosition, sessionState); case NUMERIC: case JSON: case STRUCT: @@ -171,7 +149,7 @@ public static Parser create(ResultSet result, Type type, int columnarPosition * @param typeCode The type of the object to be parsed. * @return The parser object for the designated data type. */ - protected static Parser create(Object result, Code typeCode) { + protected static Parser create(Object result, Code typeCode, SessionState sessionState) { switch (typeCode) { case BOOL: return new BooleanParser(result); @@ -188,7 +166,7 @@ protected static Parser create(Object result, Code typeCode) { case STRING: return new StringParser(result); case TIMESTAMP: - return new TimestampParser(result); + return new TimestampParser(result, sessionState); case PG_JSONB: return new JsonbParser(result); case NUMERIC: @@ -261,6 +239,64 @@ public static int toOid(Type type) { ErrorCode.INVALID_ARGUMENT, "Unsupported or unknown type: " + type); } } + /** + * Translates the given Cloud Spanner {@link Type} to a PostgreSQL OID constant. + * + * @param type the type to translate + * @return The OID constant value for the type + */ + public static int toOid(com.google.spanner.v1.Type type) { + switch (type.getCode()) { + case BOOL: + return Oid.BOOL; + case INT64: + return Oid.INT8; + case NUMERIC: + return Oid.NUMERIC; + case FLOAT64: + return Oid.FLOAT8; + case STRING: + return Oid.VARCHAR; + case JSON: + return Oid.JSONB; + case BYTES: + return Oid.BYTEA; + case TIMESTAMP: + return Oid.TIMESTAMPTZ; + case DATE: + return Oid.DATE; + case ARRAY: + switch (type.getArrayElementType().getCode()) { + case BOOL: + return Oid.BOOL_ARRAY; + case INT64: + return Oid.INT8_ARRAY; + case NUMERIC: + return Oid.NUMERIC_ARRAY; + case FLOAT64: + return Oid.FLOAT8_ARRAY; + case STRING: + return Oid.VARCHAR_ARRAY; + case JSON: + return Oid.JSONB_ARRAY; + case BYTES: + return Oid.BYTEA_ARRAY; + case TIMESTAMP: + return Oid.TIMESTAMPTZ_ARRAY; + case DATE: + return Oid.DATE_ARRAY; + case ARRAY: + case STRUCT: + default: + throw PGExceptionFactory.newPGException( + "Unsupported or unknown array type: " + type, SQLState.InternalError); + } + case STRUCT: + default: + throw PGExceptionFactory.newPGException( + "Unsupported or unknown type: " + type, SQLState.InternalError); + } + } /** Returns the item helder by this parser. */ public T getItem() { diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/parsers/TimestampParser.java b/src/main/java/com/google/cloud/spanner/pgadapter/parsers/TimestampParser.java index 802122fef..7b5961c5a 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/parsers/TimestampParser.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/parsers/TimestampParser.java @@ -22,8 +22,13 @@ import com.google.cloud.spanner.Statement; import com.google.cloud.spanner.pgadapter.ProxyServer.DataFormat; import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory; +import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; +import com.google.cloud.spanner.pgadapter.session.SessionState; import com.google.common.base.Preconditions; import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.time.OffsetDateTime; +import java.time.ZoneId; import java.time.format.DateTimeFormatter; import java.time.format.DateTimeFormatterBuilder; import java.time.temporal.ChronoField; @@ -54,24 +59,40 @@ public class TimestampParser extends Parser { private static final Pattern TIMESTAMP_PATTERN = Pattern.compile(TIMESTAMP_REGEX); - private static final DateTimeFormatter TIMESTAMP_FORMATTER = + private static final DateTimeFormatter TIMESTAMP_OUTPUT_FORMATTER = new DateTimeFormatterBuilder() .parseLenient() .parseCaseInsensitive() .appendPattern("yyyy-MM-dd HH:mm:ss") .appendFraction(ChronoField.NANO_OF_SECOND, 0, 9, true) - .appendOffset("+HH:mm", "Z") + // Java 8 does not support seconds in timezone offset. + .appendOffset(OptionsMetadata.isJava8() ? "+HH:mm" : "+HH:mm:ss", "+00") .toFormatter(); - TimestampParser(ResultSet item, int position) { + private static final DateTimeFormatter TIMESTAMP_INPUT_FORMATTER = + new DateTimeFormatterBuilder() + .parseLenient() + .parseCaseInsensitive() + .appendPattern("yyyy-MM-dd HH:mm:ss") + .appendFraction(ChronoField.NANO_OF_SECOND, 0, 9, true) + // Java 8 does not support seconds in timezone offset. + .appendOffset(OptionsMetadata.isJava8() ? "+HH:mm" : "+HH:mm:ss", "+00:00:00") + .toFormatter(); + + private final SessionState sessionState; + + TimestampParser(ResultSet item, int position, SessionState sessionState) { this.item = item.getTimestamp(position); + this.sessionState = sessionState; } - TimestampParser(Object item) { + TimestampParser(Object item, SessionState sessionState) { this.item = (Timestamp) item; + this.sessionState = sessionState; } - TimestampParser(byte[] item, FormatCode formatCode) { + TimestampParser(byte[] item, FormatCode formatCode, SessionState sessionState) { + this.sessionState = sessionState; if (item != null) { switch (formatCode) { case TEXT: @@ -103,7 +124,7 @@ public static Timestamp toTimestamp(@Nonnull byte[] data) { public static Timestamp toTimestamp(String value) { try { String stringValue = toPGString(value); - TemporalAccessor temporalAccessor = TIMESTAMP_FORMATTER.parse(stringValue); + TemporalAccessor temporalAccessor = TIMESTAMP_INPUT_FORMATTER.parse(stringValue); return Timestamp.ofTimeSecondsAndNanos( temporalAccessor.getLong(ChronoField.INSTANT_SECONDS), temporalAccessor.get(ChronoField.NANO_OF_SECOND)); @@ -125,7 +146,7 @@ static boolean isTimestamp(String value) { @Override public String stringParse() { - return this.item == null ? null : toPGString(this.item.toString()); + return this.item == null ? null : toPGString(this.item, sessionState.getTimezone()); } @Override @@ -150,12 +171,13 @@ static byte[] convertToPG(Timestamp value) { return result; } - public static byte[] convertToPG(ResultSet resultSet, int position, DataFormat format) { + public static byte[] convertToPG( + ResultSet resultSet, int position, DataFormat format, ZoneId zoneId) { switch (format) { case SPANNER: return resultSet.getTimestamp(position).toString().getBytes(StandardCharsets.UTF_8); case POSTGRESQL_TEXT: - return toPGString(resultSet.getTimestamp(position).toString()) + return toPGString(resultSet.getTimestamp(position), zoneId) .getBytes(StandardCharsets.UTF_8); case POSTGRESQL_BINARY: return convertToPG(resultSet.getTimestamp(position)); @@ -174,6 +196,13 @@ private static String toPGString(String value) { return value.replace(TIMESTAMP_SEPARATOR, EMPTY_SPACE).replace(ZERO_TIMEZONE, PG_ZERO_TIMEZONE); } + private static String toPGString(Timestamp value, ZoneId zoneId) { + OffsetDateTime offsetDateTime = + OffsetDateTime.ofInstant( + Instant.ofEpochSecond(value.getSeconds(), value.getNanos()), zoneId); + return TIMESTAMP_OUTPUT_FORMATTER.format(offsetDateTime); + } + @Override public void bind(Statement.Builder statementBuilder, String name) { statementBuilder.bind(name).to(this.item); diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/parsers/UuidParser.java b/src/main/java/com/google/cloud/spanner/pgadapter/parsers/UuidParser.java new file mode 100644 index 000000000..01bebe591 --- /dev/null +++ b/src/main/java/com/google/cloud/spanner/pgadapter/parsers/UuidParser.java @@ -0,0 +1,115 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.spanner.pgadapter.parsers; + +import com.google.api.core.InternalApi; +import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.pgadapter.error.PGException; +import com.google.cloud.spanner.pgadapter.error.SQLState; +import com.google.cloud.spanner.pgadapter.error.Severity; +import java.nio.charset.StandardCharsets; +import java.util.UUID; +import javax.annotation.Nonnull; +import org.postgresql.util.ByteConverter; + +/** + * Translate from wire protocol to UUID. This is currently a one-way conversion, as we only accept + * UUID as a parameter type. UUIDs are converted to strings. + */ +@InternalApi +public class UuidParser extends Parser { + + UuidParser(byte[] item, FormatCode formatCode) { + switch (formatCode) { + case TEXT: + this.item = + item == null ? null : verifyStringValue(new String(item, StandardCharsets.UTF_8)); + break; + case BINARY: + this.item = verifyBinaryValue(item); + break; + default: + handleInvalidFormat(formatCode); + } + } + + static void handleInvalidFormat(FormatCode formatCode) { + throw PGException.newBuilder("Unsupported format: " + formatCode.name()) + .setSQLState(SQLState.InternalError) + .setSeverity(Severity.ERROR) + .build(); + } + + @Override + public String stringParse() { + return this.item; + } + + @Override + protected byte[] binaryParse() { + if (this.item == null) { + return null; + } + return binaryEncode(this.item); + } + + static String verifyStringValue(@Nonnull String value) { + try { + //noinspection ResultOfMethodCallIgnored + UUID.fromString(value); + return value; + } catch (Exception exception) { + throw createInvalidUuidValueException(value, exception); + } + } + + static String verifyBinaryValue(byte[] value) { + if (value == null) { + return null; + } + if (value.length != 16) { + throw PGException.newBuilder("Invalid UUID binary length: " + value.length) + .setSeverity(Severity.ERROR) + .setSQLState(SQLState.InvalidParameterValue) + .build(); + } + return new UUID(ByteConverter.int8(value, 0), ByteConverter.int8(value, 8)).toString(); + } + + static byte[] binaryEncode(String value) { + try { + UUID uuid = UUID.fromString(value); + byte[] val = new byte[16]; + ByteConverter.int8(val, 0, uuid.getMostSignificantBits()); + ByteConverter.int8(val, 8, uuid.getLeastSignificantBits()); + return val; + } catch (Exception exception) { + throw createInvalidUuidValueException(value, exception); + } + } + + static PGException createInvalidUuidValueException(String value, Exception cause) { + return PGException.newBuilder("Invalid UUID: " + value) + .setSeverity(Severity.ERROR) + .setSQLState(SQLState.InvalidParameterValue) + .setCause(cause) + .build(); + } + + @Override + public void bind(Statement.Builder statementBuilder, String name) { + statementBuilder.bind(name).to(this.item); + } +} diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/session/PGSetting.java b/src/main/java/com/google/cloud/spanner/pgadapter/session/PGSetting.java index 63300cbfe..cc39550d7 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/session/PGSetting.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/session/PGSetting.java @@ -25,10 +25,8 @@ import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; -import java.util.Arrays; -import java.util.Locale; -import java.util.Objects; -import java.util.Scanner; +import java.time.ZoneId; +import java.util.*; import java.util.stream.Collectors; import javax.annotation.Nonnull; @@ -380,7 +378,7 @@ void setSetting(Context context, String value) { } if (this.vartype != null) { // Check validity of the value. - checkValidValue(value); + value = checkValidValue(value); } this.setting = value; } @@ -420,7 +418,7 @@ static SpannerException invalidContextError(String key, Context context) { } } - private void checkValidValue(String value) { + private String checkValidValue(String value) { if ("bool".equals(this.vartype)) { // Just verify that it is a valid boolean. This will throw an IllegalArgumentException if // setting is not a valid boolean value. @@ -445,7 +443,32 @@ private void checkValidValue(String value) { && !upperCaseEnumVals.contains( MoreObjects.firstNonNull(value, "").toUpperCase(Locale.ENGLISH))) { throw invalidEnumError(getCasePreservingKey(), value, enumVals); + } else if ("TimeZone".equals(this.name)) { + try { + value = convertToValidZoneId(value); + ZoneId.of(value); + } catch (Exception ignore) { + throw invalidValueError(this.name, value); + } + } + return value; + } + + static String convertToValidZoneId(String value) { + if ("utc".equalsIgnoreCase(value)) { + return "UTC"; } + for (String zoneId : ZoneId.getAvailableZoneIds()) { + if (zoneId.equalsIgnoreCase(value)) { + return zoneId; + } + } + for (Map.Entry shortId : ZoneId.SHORT_IDS.entrySet()) { + if (shortId.getKey().equalsIgnoreCase(value)) { + return shortId.getValue(); + } + } + return value; } static SpannerException invalidBoolError(String key) { diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/session/SessionState.java b/src/main/java/com/google/cloud/spanner/pgadapter/session/SessionState.java index 060f564c2..1c05bb33d 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/session/SessionState.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/session/SessionState.java @@ -16,7 +16,6 @@ import static com.google.cloud.spanner.pgadapter.session.CopySettings.initCopySettings; -import com.google.api.client.util.Strings; import com.google.api.core.InternalApi; import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.SpannerException; @@ -30,8 +29,8 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.ImmutableSet.Builder; import com.google.common.collect.Sets; +import java.time.ZoneId; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; @@ -44,7 +43,6 @@ import java.util.Set; import java.util.concurrent.Callable; import java.util.stream.Collectors; -import javax.annotation.Nonnull; /** {@link SessionState} contains all session variables for a connection. */ @InternalApi @@ -166,6 +164,10 @@ private static String generatePgSettingsColumnExpressions() { } private static String toKey(String extension, String name) { + if (extension == null && "timezone".equalsIgnoreCase(name)) { + // TimeZone is the only special setting that uses CamelCase. + return "TimeZone"; + } return extension == null ? name.toLowerCase(Locale.ROOT) : extension.toLowerCase(Locale.ROOT) + "." + name.toLowerCase(Locale.ROOT); @@ -393,41 +395,28 @@ public DdlTransactionMode getDdlTransactionMode() { () -> DdlTransactionMode.valueOf(setting.getBootVal())); } - /** - * Returns a set of OIDs that PGAdapter should try to guess if it receives an untyped parameter - * value. This is needed because some clients (JDBC) deliberately send parameters without a type - * code to force the server to infer the type. This specifically applies to date/timestamp - * parameters. - */ - public Set getGuessTypes() { - PGSetting setting = internalGet(toKey("spanner", "guess_types"), false); - if (setting == null || Strings.isNullOrEmpty(setting.getSetting())) { - return ImmutableSet.of(); + /** Returns the {@link ZoneId} of the current timezone for this session. */ + public ZoneId getTimezone() { + PGSetting setting = internalGet(toKey(null, "timezone"), false); + if (setting == null) { + return ZoneId.systemDefault(); } - return convertOidListToSet(setting.getSetting()); + String id = + tryGetFirstNonNull( + ZoneId.systemDefault().getId(), + setting::getSetting, + setting::getResetVal, + setting::getBootVal); + + return zoneIdFromString(id); } - /** Keep a cache of 1 element ready as the setting is not likely to change often. */ - private final Map> cachedGuessTypes = new HashMap<>(1); - - Set convertOidListToSet(@Nonnull String value) { - if (cachedGuessTypes.containsKey(value)) { - return cachedGuessTypes.get(value); - } - - Builder builder = ImmutableSet.builder(); - String[] oids = value.split(","); - for (String oid : oids) { - try { - builder.add(Integer.valueOf(oid)); - } catch (Exception ignore) { - // ignore invalid oids. - } + private ZoneId zoneIdFromString(String value) { + try { + return ZoneId.of(value); + } catch (Throwable ignore) { + return ZoneId.systemDefault(); } - cachedGuessTypes.clear(); - cachedGuessTypes.put(value, builder.build()); - - return cachedGuessTypes.get(value); } @SafeVarargs diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/BackendConnection.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/BackendConnection.java index 69bdc6098..b94a77757 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/BackendConnection.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/BackendConnection.java @@ -26,6 +26,7 @@ import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.Partition; import com.google.cloud.spanner.PartitionOptions; +import com.google.cloud.spanner.ReadContext.QueryAnalyzeMode; import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.ResultSets; import com.google.cloud.spanner.Spanner; @@ -44,7 +45,9 @@ import com.google.cloud.spanner.connection.ResultSetHelper; import com.google.cloud.spanner.connection.StatementResult; import com.google.cloud.spanner.connection.StatementResult.ClientSideStatementType; +import com.google.cloud.spanner.pgadapter.error.PGException; import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory; +import com.google.cloud.spanner.pgadapter.error.SQLState; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata.DdlTransactionMode; import com.google.cloud.spanner.pgadapter.session.SessionState; @@ -78,6 +81,7 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; +import java.util.function.Function; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -126,6 +130,12 @@ enum TransactionMode { DDL_BATCH, } + static PGException setAndReturn(SettableFuture future, Throwable throwable) { + PGException pgException = PGExceptionFactory.toPGException(throwable); + future.setException(pgException); + return pgException; + } + /** * Buffered statements are kept in memory until a flush or sync message is received. This makes it * possible to batch multiple statements together when sending them to Cloud Spanner. @@ -141,25 +151,50 @@ abstract class BufferedStatement { this.result = SettableFuture.create(); } + boolean isBatchingPossible() { + return false; + } + abstract void execute(); void checkConnectionState() { // Only COMMIT or ROLLBACK is allowed if we are in an ABORTED transaction. if (connectionState == ConnectionState.ABORTED && !(isCommit(parsedStatement) || isRollback(parsedStatement))) { - throw SpannerExceptionFactory.newSpannerException( - ErrorCode.INVALID_ARGUMENT, TRANSACTION_ABORTED_ERROR); + throw PGExceptionFactory.newTransactionAbortedException(); } } } private final class Execute extends BufferedStatement { - Execute(ParsedStatement parsedStatement, Statement statement) { + private final Function statementBinder; + private final boolean analyze; + + Execute( + ParsedStatement parsedStatement, + Statement statement, + Function statementBinder) { + this(parsedStatement, statement, statementBinder, false); + } + + Execute( + ParsedStatement parsedStatement, + Statement statement, + Function statementBinder, + boolean analyze) { super(parsedStatement, statement); + this.statementBinder = statementBinder; + this.analyze = analyze; + } + + @Override + boolean isBatchingPossible() { + return !analyze; } @Override void execute() { + Statement updatedStatement = statement; try { checkConnectionState(); // TODO(b/235719478): If the statement is a BEGIN statement and there is a COMMIT statement @@ -202,14 +237,18 @@ void execute() { } else if (statement.getSql().isEmpty()) { result.set(NO_RESULT); } else if (parsedStatement.isDdl()) { - result.set(ddlExecutor.execute(parsedStatement, statement)); + if (analyze) { + result.set(NO_RESULT); + } else { + result.set(ddlExecutor.execute(parsedStatement, statement)); + } } else { // Potentially replace pg_catalog table references with common table expressions. - Statement updatedStatement = + updatedStatement = sessionState.isReplacePgCatalogTables() ? pgCatalog.replacePgCatalogTables(statement) : statement; - result.set(spannerConnection.execute(updatedStatement)); + result.set(bindAndExecute(updatedStatement)); } } catch (SpannerException spannerException) { // Executing queries against the information schema in a transaction is unsupported. @@ -222,25 +261,61 @@ void execute() { spannerConnection .getDatabaseClient() .singleUse() - .executeQuery(statement)))); + .executeQuery(updatedStatement)))); return; } catch (Exception exception) { - result.setException(exception); - throw exception; + throw setAndReturn(result, exception); } } if (spannerException.getErrorCode() == ErrorCode.CANCELLED || Thread.interrupted()) { - result.setException(PGExceptionFactory.newQueryCancelledException()); + throw setAndReturn(result, PGExceptionFactory.newQueryCancelledException()); } else { - result.setException(spannerException); + throw setAndReturn(result, spannerException); } - throw spannerException; } catch (Throwable exception) { - result.setException(exception); - throw exception; + throw setAndReturn(result, exception); } } + StatementResult bindAndExecute(Statement statement) { + statement = statementBinder.apply(statement); + if (analyze) { + ResultSet resultSet; + if (parsedStatement.isUpdate() && !parsedStatement.hasReturningClause()) { + // TODO(#477): Single analyzeUpdate statements that are executed in an implicit + // transaction could use a single-use read/write transaction. Replays are not + // dangerous for those. + + // We handle one very specific use case here to prevent unnecessary problems: If the user + // has started a DML batch and is then analyzing an update statement (probably a prepared + // statement), then we use a separate transaction for that. + if (spannerConnection.isDmlBatchActive()) { + final Statement statementToAnalyze = statement; + resultSet = + spannerConnection + .getDatabaseClient() + .readWriteTransaction() + .run( + transaction -> { + ResultSet updateStatementMetadata = + transaction.analyzeUpdateStatement( + statementToAnalyze, QueryAnalyzeMode.PLAN); + updateStatementMetadata.next(); + return updateStatementMetadata; + }); + } else { + resultSet = spannerConnection.analyzeUpdateStatement(statement, QueryAnalyzeMode.PLAN); + } + } else if (parsedStatement.isQuery() || parsedStatement.hasReturningClause()) { + resultSet = spannerConnection.analyzeQuery(statement, QueryAnalyzeMode.PLAN); + } else { + return NO_RESULT; + } + return new QueryResult(resultSet); + } + return spannerConnection.execute(statement); + } + /** * Returns true if the given exception is the error that is returned by Cloud Spanner when an * INFORMATION_SCHEMA query is not executed in a single-use read-only transaction. @@ -307,8 +382,7 @@ void execute() { } } } catch (Exception exception) { - result.setException(exception); - throw exception; + throw setAndReturn(result, exception); } } } @@ -425,8 +499,17 @@ public ConnectionState getConnectionState() { * message is received. The returned future will contain the result of the statement when * execution has finished. */ - public Future execute(ParsedStatement parsedStatement, Statement statement) { - Execute execute = new Execute(parsedStatement, statement); + public Future execute( + ParsedStatement parsedStatement, + Statement statement, + Function statementBinder) { + Execute execute = new Execute(parsedStatement, statement, statementBinder); + bufferedStatements.add(execute); + return execute.result; + } + + public Future analyze(ParsedStatement parsedStatement, Statement statement) { + Execute execute = new Execute(parsedStatement, statement, Function.identity(), true); bufferedStatements.add(execute); return execute.result; } @@ -548,7 +631,7 @@ private void flush(boolean isSync) { spannerConnection.beginTransaction(); } boolean canUseBatch = false; - if (index < (getStatementCount() - 1)) { + if (bufferedStatement.isBatchingPossible() && index < (getStatementCount() - 1)) { StatementType statementType = getStatementType(index); StatementType nextStatementType = getStatementType(index + 1); canUseBatch = canBeBatchedTogether(statementType, nextStatementType); @@ -674,9 +757,9 @@ private void prepareExecuteDdl(BufferedStatement bufferedStatement) { // Single DDL statements outside explicit transactions are always allowed. For a single // statement, there can also not be an implicit transaction that needs to be committed. if (transactionMode == TransactionMode.EXPLICIT) { - throw SpannerExceptionFactory.newSpannerException( - ErrorCode.FAILED_PRECONDITION, - "DDL statements are only allowed outside explicit transactions."); + throw PGExceptionFactory.newPGException( + "DDL statements are only allowed outside explicit transactions.", + SQLState.InvalidTransactionState); } // Fall-through to commit the transaction if necessary. case AutocommitExplicitTransaction: @@ -700,23 +783,23 @@ private void prepareExecuteDdl(BufferedStatement bufferedStatement) { // We are in a batch of statements. switch (ddlTransactionMode) { case Single: - throw SpannerExceptionFactory.newSpannerException( - ErrorCode.FAILED_PRECONDITION, - "DDL statements are only allowed outside batches and transactions."); + throw PGExceptionFactory.newPGException( + "DDL statements are only allowed outside batches and transactions.", + SQLState.InvalidTransactionState); case Batch: if (spannerConnection.isInTransaction() || bufferedStatements.stream() .anyMatch(statement -> !statement.parsedStatement.isDdl())) { - throw SpannerExceptionFactory.newSpannerException( - ErrorCode.FAILED_PRECONDITION, - "DDL statements are not allowed in mixed batches or transactions."); + throw PGExceptionFactory.newPGException( + "DDL statements are not allowed in mixed batches or transactions.", + SQLState.InvalidTransactionState); } break; case AutocommitImplicitTransaction: if (spannerConnection.isInTransaction() && transactionMode != TransactionMode.IMPLICIT) { - throw SpannerExceptionFactory.newSpannerException( - ErrorCode.FAILED_PRECONDITION, - "DDL statements are only allowed outside explicit transactions."); + throw PGExceptionFactory.newPGException( + "DDL statements are only allowed outside explicit transactions.", + SQLState.InvalidTransactionState); } // Fallthrough to commit the transaction if necessary. case AutocommitExplicitTransaction: @@ -734,9 +817,8 @@ private void prepareExecuteDdl(BufferedStatement bufferedStatement) { } } } - } catch (SpannerException exception) { - bufferedStatement.result.setException(exception); - throw exception; + } catch (Throwable throwable) { + throw setAndReturn(bufferedStatement.result, throwable); } } @@ -819,16 +901,10 @@ int executeStatementsInBatch(int fromIndex) { Preconditions.checkArgument( canBeBatchedTogether(getStatementType(fromIndex), getStatementType(fromIndex + 1))); StatementType batchType = getStatementType(fromIndex); - switch (batchType) { - case UPDATE: - spannerConnection.startBatchDml(); - break; - case DDL: - spannerConnection.startBatchDdl(); - break; - default: - throw SpannerExceptionFactory.newSpannerException( - ErrorCode.INVALID_ARGUMENT, "Statement type is not supported for batching"); + if (batchType == StatementType.UPDATE) { + spannerConnection.startBatchDml(); + } else if (batchType == StatementType.DDL) { + spannerConnection.startBatchDdl(); } List statementResults = new ArrayList<>(getStatementCount()); int index = fromIndex; @@ -845,7 +921,8 @@ int executeStatementsInBatch(int fromIndex) { bufferedStatements.get(index).parsedStatement, bufferedStatements.get(index).statement)); } else { - spannerConnection.execute(bufferedStatements.get(index).statement); + Execute execute = (Execute) bufferedStatements.get(index); + execute.bindAndExecute(execute.statement); } index++; } else { diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyStatement.java index f411f0caf..7c49c4bd1 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyStatement.java @@ -17,7 +17,6 @@ import com.google.api.core.InternalApi; import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.ResultSet; -import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.SpannerExceptionFactory; import com.google.cloud.spanner.Statement; import com.google.cloud.spanner.Type; @@ -109,7 +108,18 @@ public CopyStatement( ParsedStatement parsedStatement, Statement originalStatement, ParsedCopyStatement parsedCopyStatement) { - super(connectionHandler, options, name, parsedStatement, originalStatement); + super( + name, + new IntermediatePreparedStatement( + connectionHandler, + options, + name, + NO_PARAMETER_TYPES, + parsedStatement, + originalStatement), + NO_PARAMS, + ImmutableList.of(), + ImmutableList.of()); this.parsedCopyStatement = parsedCopyStatement; } @@ -118,12 +128,6 @@ public boolean hasException() { return this.exception != null; } - @Override - public SpannerException getException() { - // Do not clear exceptions on a CopyStatement. - return this.exception; - } - @Override public long getUpdateCount() { // COPY statements continue to execute while the server continues to receive a stream of @@ -356,7 +360,7 @@ private int queryIndexedColumnsCount( } @Override - public IntermediatePortalStatement bind( + public IntermediatePortalStatement createPortal( String name, byte[][] parameters, List parameterFormatCodes, @@ -393,8 +397,7 @@ public void executeAsync(BackendConnection backendConnection) { mutationWriter, executor)); } catch (Exception e) { - SpannerException spannerException = SpannerExceptionFactory.asSpannerException(e); - handleExecutionException(spannerException); + handleExecutionException(PGExceptionFactory.toPGException(e)); } } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyToStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyToStatement.java index 06361ece6..4c4c0f959 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyToStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyToStatement.java @@ -21,13 +21,14 @@ import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement; import com.google.cloud.spanner.connection.AbstractStatementParser.StatementType; import com.google.cloud.spanner.connection.Connection; +import com.google.cloud.spanner.connection.StatementResult; import com.google.cloud.spanner.pgadapter.ConnectionHandler; import com.google.cloud.spanner.pgadapter.ProxyServer.DataFormat; import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory; import com.google.cloud.spanner.pgadapter.error.SQLState; -import com.google.cloud.spanner.pgadapter.metadata.DescribePortalMetadata; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; import com.google.cloud.spanner.pgadapter.parsers.Parser; +import com.google.cloud.spanner.pgadapter.session.SessionState; import com.google.cloud.spanner.pgadapter.statements.CopyStatement.Format; import com.google.cloud.spanner.pgadapter.statements.CopyStatement.ParsedCopyStatement; import com.google.cloud.spanner.pgadapter.statements.SimpleParser.TableOrIndexName; @@ -37,6 +38,7 @@ import com.google.cloud.spanner.pgadapter.wireoutput.CopyOutResponse; import com.google.cloud.spanner.pgadapter.wireoutput.WireOutput; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.Futures; import java.util.List; import java.util.concurrent.Future; @@ -63,11 +65,17 @@ public CopyToStatement( String name, ParsedCopyStatement parsedCopyStatement) { super( - connectionHandler, - options, name, - createParsedStatement(parsedCopyStatement), - createSelectStatement(parsedCopyStatement)); + new IntermediatePreparedStatement( + connectionHandler, + options, + name, + NO_PARAMETER_TYPES, + createParsedStatement(parsedCopyStatement), + createSelectStatement(parsedCopyStatement)), + NO_PARAMS, + ImmutableList.of(), + ImmutableList.of()); this.parsedCopyStatement = parsedCopyStatement; if (parsedCopyStatement.format == CopyStatement.Format.BINARY) { this.csvFormat = null; @@ -194,14 +202,14 @@ public void executeAsync(BackendConnection backendConnection) { } @Override - public Future describeAsync(BackendConnection backendConnection) { + public Future describeAsync(BackendConnection backendConnection) { // Return null to indicate that this COPY TO STDOUT statement does not return any // RowDescriptionResponse. return Futures.immediateFuture(null); } @Override - public IntermediatePortalStatement bind( + public IntermediatePortalStatement createPortal( String name, byte[][] parameters, List parameterFormatCodes, @@ -245,11 +253,17 @@ public WireOutput[] createResultSuffix() { CopyDataResponse createDataResponse(ResultSet resultSet) { String[] data = new String[resultSet.getColumnCount()]; + SessionState sessionState = + getConnectionHandler() + .getExtendedQueryProtocolHandler() + .getBackendConnection() + .getSessionState(); for (int col = 0; col < resultSet.getColumnCount(); col++) { if (resultSet.isNull(col)) { data[col] = null; } else { - Parser parser = Parser.create(resultSet, resultSet.getColumnType(col), col); + Parser parser = + Parser.create(resultSet, resultSet.getColumnType(col), col, sessionState); data[col] = parser.stringParse(); } } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/DeallocateStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/DeallocateStatement.java index b0f9774c9..43041d182 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/DeallocateStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/DeallocateStatement.java @@ -19,13 +19,14 @@ import com.google.cloud.spanner.Statement; import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement; import com.google.cloud.spanner.connection.AbstractStatementParser.StatementType; +import com.google.cloud.spanner.connection.StatementResult; import com.google.cloud.spanner.pgadapter.ConnectionHandler; import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory; -import com.google.cloud.spanner.pgadapter.metadata.DescribePortalMetadata; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; import com.google.cloud.spanner.pgadapter.statements.BackendConnection.NoResult; import com.google.cloud.spanner.pgadapter.statements.SimpleParser.TableOrIndexName; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.Futures; import java.util.List; import java.util.concurrent.Future; @@ -47,7 +48,18 @@ public DeallocateStatement( String name, ParsedStatement parsedStatement, Statement originalStatement) { - super(connectionHandler, options, name, parsedStatement, originalStatement); + super( + name, + new IntermediatePreparedStatement( + connectionHandler, + options, + name, + NO_PARAMETER_TYPES, + parsedStatement, + originalStatement), + NO_PARAMS, + ImmutableList.of(), + ImmutableList.of()); this.deallocateStatement = parse(originalStatement.getSql()); } @@ -78,14 +90,14 @@ public void executeAsync(BackendConnection backendConnection) { } @Override - public Future describeAsync(BackendConnection backendConnection) { + public Future describeAsync(BackendConnection backendConnection) { // Return null to indicate that this EXECUTE statement does not return any // RowDescriptionResponse. return Futures.immediateFuture(null); } @Override - public IntermediatePortalStatement bind( + public IntermediatePortalStatement createPortal( String name, byte[][] parameters, List parameterFormatCodes, diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/ExecuteStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/ExecuteStatement.java index dd3c9c506..f8fa71d90 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/ExecuteStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/ExecuteStatement.java @@ -20,9 +20,9 @@ import com.google.cloud.spanner.Statement; import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement; import com.google.cloud.spanner.connection.AbstractStatementParser.StatementType; +import com.google.cloud.spanner.connection.StatementResult; import com.google.cloud.spanner.pgadapter.ConnectionHandler; import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory; -import com.google.cloud.spanner.pgadapter.metadata.DescribePortalMetadata; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; import com.google.cloud.spanner.pgadapter.statements.SimpleParser.TableOrIndexName; import com.google.cloud.spanner.pgadapter.wireprotocol.BindMessage; @@ -31,6 +31,7 @@ import com.google.cloud.spanner.pgadapter.wireprotocol.DescribeMessage; import com.google.cloud.spanner.pgadapter.wireprotocol.ExecuteMessage; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.Futures; import java.nio.charset.StandardCharsets; import java.util.Collections; @@ -58,7 +59,18 @@ public ExecuteStatement( String name, ParsedStatement parsedStatement, Statement originalStatement) { - super(connectionHandler, options, name, parsedStatement, originalStatement); + super( + name, + new IntermediatePreparedStatement( + connectionHandler, + options, + name, + NO_PARAMETER_TYPES, + parsedStatement, + originalStatement), + NO_PARAMS, + ImmutableList.of(), + ImmutableList.of()); this.executeStatement = parse(originalStatement.getSql()); } @@ -98,14 +110,14 @@ public void executeAsync(BackendConnection backendConnection) { } @Override - public Future describeAsync(BackendConnection backendConnection) { + public Future describeAsync(BackendConnection backendConnection) { // Return null to indicate that this EXECUTE statement does not return any // RowDescriptionResponse. return Futures.immediateFuture(null); } @Override - public IntermediatePortalStatement bind( + public IntermediatePortalStatement createPortal( String name, byte[][] parameters, List parameterFormatCodes, diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediatePortalStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediatePortalStatement.java index 0080cebe8..80e3edd65 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediatePortalStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediatePortalStatement.java @@ -16,13 +16,9 @@ import com.google.api.core.InternalApi; import com.google.cloud.spanner.Statement; -import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement; import com.google.cloud.spanner.connection.StatementResult; -import com.google.cloud.spanner.pgadapter.ConnectionHandler; -import com.google.cloud.spanner.pgadapter.metadata.DescribePortalMetadata; -import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; -import com.google.common.util.concurrent.Futures; -import java.util.ArrayList; +import com.google.cloud.spanner.pgadapter.parsers.Parser; +import com.google.cloud.spanner.pgadapter.parsers.Parser.FormatCode; import java.util.List; import java.util.concurrent.Future; @@ -32,23 +28,34 @@ */ @InternalApi public class IntermediatePortalStatement extends IntermediatePreparedStatement { - protected List parameterFormatCodes; - protected List resultFormatCodes; + static final byte[][] NO_PARAMS = new byte[0][]; + + private final IntermediatePreparedStatement preparedStatement; + private final byte[][] parameters; + protected final List parameterFormatCodes; + protected final List resultFormatCodes; public IntermediatePortalStatement( - ConnectionHandler connectionHandler, - OptionsMetadata options, String name, - ParsedStatement parsedStatement, - Statement originalStatement) { - super(connectionHandler, options, name, parsedStatement, originalStatement); - this.statement = originalStatement; - this.parameterFormatCodes = new ArrayList<>(); - this.resultFormatCodes = new ArrayList<>(); + IntermediatePreparedStatement preparedStatement, + byte[][] parameters, + List parameterFormatCodes, + List resultFormatCodes) { + super( + preparedStatement.connectionHandler, + preparedStatement.options, + name, + preparedStatement.givenParameterDataTypes, + preparedStatement.parsedStatement, + preparedStatement.originalStatement); + this.preparedStatement = preparedStatement; + this.parameters = parameters; + this.parameterFormatCodes = parameterFormatCodes; + this.resultFormatCodes = resultFormatCodes; } - void setBoundStatement(Statement statement) { - this.statement = statement; + public IntermediatePreparedStatement getPreparedStatement() { + return this.preparedStatement; } public short getParameterFormatCode(int index) { @@ -72,26 +79,50 @@ public short getResultFormatCode(int index) { } } - public void setParameterFormatCodes(List parameterFormatCodes) { - this.parameterFormatCodes = parameterFormatCodes; + @Override + public void executeAsync(BackendConnection backendConnection) { + // If the portal has already been described, the statement has already been executed, and we + // don't need to do that once more. + if (futureStatementResult == null && getStatementResult() == null) { + this.executed = true; + setFutureStatementResult(backendConnection.execute(parsedStatement, statement, this::bind)); + } } - public void setResultFormatCodes(List resultFormatCodes) { - this.resultFormatCodes = resultFormatCodes; + /** Binds this portal to a set of parameter values. */ + public Statement bind(Statement statement) { + // Make sure the results from any Describe message are propagated to the prepared statement + // before using it to bind the parameter values. + preparedStatement.describe(); + Statement.Builder builder = statement.toBuilder(); + for (int index = 0; index < parameters.length; index++) { + short formatCode = getParameterFormatCode(index); + int type = preparedStatement.getParameterDataType(index); + Parser parser = + Parser.create( + connectionHandler + .getExtendedQueryProtocolHandler() + .getBackendConnection() + .getSessionState(), + parameters[index], + type, + FormatCode.of(formatCode)); + parser.bind(builder, "p" + (index + 1)); + } + return builder.build(); } @Override - public Future describeAsync(BackendConnection backendConnection) { + public Future describeAsync(BackendConnection backendConnection) { // Pre-emptively execute the statement, even though it is only asked to be described. This is // a lot more efficient than taking two round trips to the server, and getting a // DescribePortal message without a following Execute message is extremely rare, as that would // only happen if the client is ill-behaved, or if the client crashes between the // DescribePortal and Execute. Future statementResultFuture = - backendConnection.execute(this.parsedStatement, this.statement); + backendConnection.execute(this.parsedStatement, this.statement, this::bind); setFutureStatementResult(statementResultFuture); - return Futures.lazyTransform( - statementResultFuture, - statementResult -> new DescribePortalMetadata(statementResult.getResultSet())); + this.executed = true; + return statementResultFuture; } } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediatePreparedStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediatePreparedStatement.java index 36fdff297..4d8314f65 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediatePreparedStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediatePreparedStatement.java @@ -15,31 +15,19 @@ package com.google.cloud.spanner.pgadapter.statements; import com.google.api.core.InternalApi; -import com.google.cloud.spanner.ReadContext.QueryAnalyzeMode; -import com.google.cloud.spanner.ResultSet; -import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.Statement; -import com.google.cloud.spanner.Type.StructField; import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement; -import com.google.cloud.spanner.connection.Connection; +import com.google.cloud.spanner.connection.StatementResult; +import com.google.cloud.spanner.connection.StatementResult.ResultType; import com.google.cloud.spanner.pgadapter.ConnectionHandler; import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory; -import com.google.cloud.spanner.pgadapter.metadata.DescribeMetadata; -import com.google.cloud.spanner.pgadapter.metadata.DescribeStatementMetadata; +import com.google.cloud.spanner.pgadapter.metadata.DescribeResult; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; -import com.google.cloud.spanner.pgadapter.parsers.Parser; -import com.google.cloud.spanner.pgadapter.parsers.Parser.FormatCode; -import com.google.cloud.spanner.pgadapter.statements.SimpleParser.TableOrIndexName; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Strings; -import com.google.common.collect.ImmutableSortedSet; -import java.util.ArrayList; +import com.google.common.util.concurrent.Futures; import java.util.Arrays; -import java.util.Comparator; import java.util.List; -import java.util.Set; -import java.util.stream.Collectors; -import javax.annotation.Nullable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; import org.postgresql.core.Oid; /** @@ -47,518 +35,146 @@ */ @InternalApi public class IntermediatePreparedStatement extends IntermediateStatement { + static final int[] NO_PARAMETER_TYPES = new int[0]; + private final String name; - protected int[] parameterDataTypes; + protected final int[] givenParameterDataTypes; protected Statement statement; - private boolean described; + private Future describeResult; public IntermediatePreparedStatement( ConnectionHandler connectionHandler, OptionsMetadata options, String name, + int[] givenParameterDataTypes, ParsedStatement parsedStatement, Statement originalStatement) { super(connectionHandler, options, parsedStatement, originalStatement); this.name = name; - this.parameterDataTypes = null; + this.givenParameterDataTypes = givenParameterDataTypes; + this.statement = originalStatement; + } + + public int[] getGivenParameterDataTypes() { + return this.givenParameterDataTypes; } /** - * Given a set of parameters in byte format, return the designated type if stored by the user, - * otherwise guess that type. + * Returns the parameter type of the given parameter. This will only use the given parameter types + * if the statement has not been described, and the described parameter types if the statement has + * been described. * - * @param parameters Array of all parameters in byte format. - * @param index Index of the desired item. - * @return The type of the item specified. + * @param index Index of the desired parameter. + * @return The type of the parameter specified. */ - private int parseType(byte[][] parameters, int index) throws IllegalArgumentException { - if (this.parameterDataTypes.length > index) { - return this.parameterDataTypes[index]; + int getParameterDataType(int index) throws IllegalArgumentException { + int[] parameterDataTypes = this.givenParameterDataTypes; + if (this.described) { + parameterDataTypes = describe().getParameters(); + } + if (parameterDataTypes.length > index) { + return parameterDataTypes[index]; } else { return Oid.UNSPECIFIED; } } - public boolean isDescribed() { - return this.described; - } - - public void setDescribed() { - this.described = true; - } - - public int[] getParameterDataTypes() { - return this.parameterDataTypes; - } - - public void setParameterDataTypes(int[] parameterDataTypes) { - this.parameterDataTypes = parameterDataTypes; - } - - @Override - public void executeAsync(BackendConnection backendConnection) { - // If the portal has already been described, the statement has already been executed, and we - // don't need to do that once more. - if (futureStatementResult == null && getStatementResult() == null) { - this.executed = true; - setFutureStatementResult(backendConnection.execute(parsedStatement, statement)); - } - } - /** - * Bind this statement (that is to say, transform it into a portal by giving it the data items to - * complete the statement. + * Creates a portal from this statement. * - * @param parameters The array of parameters to be bound in byte format. * @param parameterFormatCodes A list of the format of each parameter. * @param resultFormatCodes A list of the desired format of each result. * @return An Intermediate Portal Statement (or rather a bound version of this statement) */ - public IntermediatePortalStatement bind( + public IntermediatePortalStatement createPortal( String name, byte[][] parameters, List parameterFormatCodes, List resultFormatCodes) { - IntermediatePortalStatement portal = - new IntermediatePortalStatement( - this.connectionHandler, - this.options, - name, - this.parsedStatement, - this.originalStatement); - portal.setParameterFormatCodes(parameterFormatCodes); - portal.setResultFormatCodes(resultFormatCodes); - Statement.Builder builder = this.originalStatement.toBuilder(); - for (int index = 0; index < parameters.length; index++) { - short formatCode = portal.getParameterFormatCode(index); - int type = this.parseType(parameters, index); - Parser parser = - Parser.create( - connectionHandler - .getExtendedQueryProtocolHandler() - .getBackendConnection() - .getSessionState() - .getGuessTypes(), - parameters[index], - type, - FormatCode.of(formatCode)); - parser.bind(builder, "p" + (index + 1)); - } - this.statement = builder.build(); - portal.setBoundStatement(statement); + return new IntermediatePortalStatement( + name, this, parameters, parameterFormatCodes, resultFormatCodes); + } - return portal; + @Override + public Future describeAsync(BackendConnection backendConnection) { + Future statementResultFuture = + backendConnection.analyze(this.parsedStatement, this.statement); + setFutureStatementResult(statementResultFuture); + this.describeResult = + Futures.lazyTransform( + statementResultFuture, + result -> + new DescribeResult( + this.givenParameterDataTypes, + result.getResultType() == ResultType.RESULT_SET + ? result.getResultSet() + : null)); + this.described = true; + return statementResultFuture; } @Override - public DescribeMetadata describe() { - ResultSet columnsResultSet = null; - if (this.parsedStatement.isQuery()) { - Statement statement = Statement.of(this.parsedStatement.getSqlWithoutComments()); - columnsResultSet = connection.analyzeQuery(statement, QueryAnalyzeMode.PLAN); - } - boolean describeSucceeded = describeParameters(null, false); - if (columnsResultSet != null) { - return new DescribeStatementMetadata(this.parameterDataTypes, columnsResultSet); + public DescribeResult describe() { + if (this.describeResult == null) { + return null; } - - if (this.parsedStatement.isUpdate() - && (!describeSucceeded || !Strings.isNullOrEmpty(this.name))) { - // Let the backend analyze the statement if it is a named prepared statement or if the query - // that was used to determine the parameter types failed, so we can return a reasonable - // error message if the statement is invalid. If it is the unnamed statement or getting the - // param types succeeded, we will let the following EXECUTE message handle that, instead of - // sending the statement twice to the backend. - connection.analyzeUpdate( - Statement.of(this.parsedStatement.getSqlWithoutComments()), QueryAnalyzeMode.PLAN); + try { + return this.describeResult.get(); + } catch (ExecutionException exception) { + throw PGExceptionFactory.toPGException(exception.getCause()); + } catch (InterruptedException interruptedException) { + throw PGExceptionFactory.newQueryCancelledException(); } - return new DescribeStatementMetadata(this.parameterDataTypes, null); } /** Describe the parameters of this statement. */ - public boolean describeParameters(byte[][] parameterValues, boolean isAutoDescribe) { - Set parameters = extractParameters(this.parsedStatement.getSqlWithoutComments()); - boolean describeSucceeded = true; - if (parameters.isEmpty()) { - ensureParameterLength(0); - } else if (parameters.size() != this.parameterDataTypes.length - || Arrays.stream(this.parameterDataTypes).anyMatch(p -> p == 0)) { + public void autoDescribeParameters( + byte[][] parameterValues, BackendConnection backendConnection) { + // Don't bother to auto-describe statements without any parameter values or with only + // null-valued parameters. + if (parameterValues == null + || parameterValues.length == 0 + || hasOnlyNullValues(parameterValues)) { + return; + } + if (parameterValues.length != this.givenParameterDataTypes.length + || Arrays.stream(this.givenParameterDataTypes).anyMatch(p -> p == 0)) { // Note: We are only asking the backend to parse the types if there is at least one // parameter with unspecified type. Otherwise, we will rely on the types given in PARSE. - - // If this describe-request is an auto-describe request, we can safely try to look it up in a - // cache. Also, we do not need to describe the parameter types if they are all null values. - if (isAutoDescribe) { - int[] cachedParameterTypes = - getConnectionHandler().getAutoDescribedStatement(this.originalStatement.getSql()); - if (cachedParameterTypes != null) { - this.parameterDataTypes = cachedParameterTypes; - return true; - } - if (hasOnlyNullValues(parameters.size(), parameterValues)) { - // Don't bother to describe null-valued parameter types. - return true; + // There is also no need to auto-describe if only null values are untyped. + boolean mustAutoDescribe = false; + for (int index = 0; index < parameterValues.length; index++) { + if (parameterValues[index] != null && getParameterDataType(index) == Oid.UNSPECIFIED) { + mustAutoDescribe = true; + break; } } - - // We cannot describe statements with more than 50 parameters, as Cloud Spanner does not allow - // queries that select from a sub-query to contain more than 50 columns in the select list. - if (parameters.size() > 50) { - throw PGExceptionFactory.newPGException( - "Cannot describe statements with more than 50 parameters"); + if (!mustAutoDescribe) { + return; } - // Transform the statement into a select statement that selects the parameters, and then - // extract the types from the result set metadata. - Statement selectParamsStatement = transformToSelectParams(parameters); - if (selectParamsStatement == null) { - // The transformation failed. Just rely on the types given in the PARSE message. If the - // transformation failed because the statement was malformed, the backend will catch that - // at a later stage. - describeSucceeded = false; - ensureParameterLength(parameters.size()); - } else { - try (ResultSet paramsResultSet = - isAutoDescribe - ? connection - .getDatabaseClient() - .singleUse() - .analyzeQuery(selectParamsStatement, QueryAnalyzeMode.PLAN) - : connection.analyzeQuery(selectParamsStatement, QueryAnalyzeMode.PLAN)) { - extractParameterTypes(paramsResultSet); - if (isAutoDescribe) { - getConnectionHandler() - .registerAutoDescribedStatement( - this.originalStatement.getSql(), this.parameterDataTypes); - } - } catch (SpannerException exception) { - // Ignore here and rely on the types given in PARSE. - describeSucceeded = false; - ensureParameterLength(parameters.size()); - } + // As this describe-request is an auto-describe request, we can safely try to look it up in a + // cache. + Future cachedDescribeResult = + getConnectionHandler().getAutoDescribedStatement(this.originalStatement.getSql()); + if (cachedDescribeResult != null) { + this.described = true; + this.describeResult = cachedDescribeResult; + return; } + // No cached result found. Add a describe-statement message to the queue. + describeAsync(backendConnection); + getConnectionHandler() + .registerAutoDescribedStatement(this.originalStatement.getSql(), this.describeResult); } - return describeSucceeded; } - boolean hasOnlyNullValues(int numParameters, byte[][] parameterValues) { - for (int paramIndex = 0; paramIndex < numParameters; paramIndex++) { - if (parseType(null, paramIndex) == Oid.UNSPECIFIED - && parameterValues != null - && parameterValues[paramIndex] != null) { + boolean hasOnlyNullValues(byte[][] parameterValues) { + for (byte[] parameterValue : parameterValues) { + if (parameterValue != null) { return false; } } return true; } - - /** - * Extracts the statement parameters from the given sql string and returns these as a sorted set. - * The parameters are ordered by their index and not by the textual value (i.e. "$9" comes before - * "$10"). - */ - @VisibleForTesting - static ImmutableSortedSet extractParameters(String sql) { - return ImmutableSortedSet.orderedBy( - Comparator.comparing(o -> Integer.valueOf(o.substring(1)))) - .addAll(PARSER.getQueryParameters(sql)) - .build(); - } - - /** - * Transforms a query or DML statement into a SELECT statement that selects the parameters in the - * statements. Examples: - * - *
    - *
  • select * from foo where id=$1 is transformed to - * select $1 from (select * from foo where id=$1) p - *
  • insert into foo (id, value) values ($1, $2) is transformed to - * select $1, $2 from (select id=$1, value=$2 from foo) p - *
- */ - @VisibleForTesting - Statement transformToSelectParams(Set parameters) { - switch (this.parsedStatement.getType()) { - case QUERY: - return transformSelectToSelectParams( - this.parsedStatement.getSqlWithoutComments(), parameters); - case UPDATE: - return transformDmlToSelectParams(parameters); - case CLIENT_SIDE: - case DDL: - case UNKNOWN: - default: - return Statement.of(this.parsedStatement.getSqlWithoutComments()); - } - } - - /** - * Transforms a query into one that selects the parameters in the query. - * - *

Example: select id, value from foo where value like $1 is transformed to - * select $1, $2 from (select id, value from foo where value like $1) p - */ - private static Statement transformSelectToSelectParams(String sql, Set parameters) { - return Statement.of(String.format("select %s from (%s) p", String.join(", ", parameters), sql)); - } - - /** - * Transforms a DML statement into a SELECT statement that selects the parameters in the DML - * statement. - */ - private Statement transformDmlToSelectParams(Set parameters) { - switch (getCommand()) { - case "INSERT": - return transformInsertToSelectParams( - this.connection, this.parsedStatement.getSqlWithoutComments(), parameters); - case "UPDATE": - return transformUpdateToSelectParams( - this.parsedStatement.getSqlWithoutComments(), parameters); - case "DELETE": - return transformDeleteToSelectParams( - this.parsedStatement.getSqlWithoutComments(), parameters); - default: - return null; - } - } - - /** - * Transforms an INSERT statement into a SELECT statement that selects the parameters in the - * insert statement. The way this is done depends on whether the INSERT statement uses a VALUES - * clause or a SELECT statement. If the INSERT statement uses a SELECT clause, the same strategy - * is used as for normal SELECT statements. For INSERT statements with a VALUES clause, a SELECT - * statement is created that selects a comparison between the column where a value is inserted and - * the expression that is used to insert a value in the column. - * - *

Examples: - * - *

    - *
  • insert into foo (id, value) values ($1, $2) is transformed to - * select $1, $2 from (select id=$1, value=$2 from foo) p - *
  • - * insert into bar (id, value, created_at) values (1, $1 + sqrt($2), current_timestamp()) - * is transformed to - * select $1, $2 from (select value=$1 + sqrt($2) from bar) p - *
  • insert into foo values ($1, $2) is transformed to - * select $1, $2 from (select id=$1, value=$2 from foo) p - *
- */ - @VisibleForTesting - static @Nullable Statement transformInsertToSelectParams( - Connection connection, String sql, Set parameters) { - SimpleParser parser = new SimpleParser(sql); - if (!parser.eatKeyword("insert")) { - return null; - } - parser.eatKeyword("into"); - TableOrIndexName table = parser.readTableOrIndexName(); - if (table == null) { - return null; - } - parser.skipWhitespaces(); - - List columnsList = null; - int posBeforeToken = parser.getPos(); - if (parser.eatToken("(")) { - if (parser.peekKeyword("select") || parser.peekToken("(")) { - // Revert and assume that the insert uses a select statement. - parser.setPos(posBeforeToken); - } else { - columnsList = parser.parseExpressionList(); - if (!parser.eatToken(")")) { - return null; - } - } - } - - parser.skipWhitespaces(); - int potentialSelectStart = parser.getPos(); - if (!parser.eatKeyword("values")) { - while (parser.eatToken("(")) { - // ignore - } - if (parser.eatKeyword("select")) { - // This is an `insert into [(...)] select ...` statement. Then we can just use the - // select statement as the result. - return transformSelectToSelectParams( - parser.getSql().substring(potentialSelectStart), parameters); - } - return null; - } - - if (columnsList == null || columnsList.isEmpty()) { - columnsList = getAllColumns(connection, table); - } - List> rows = new ArrayList<>(); - while (parser.eatToken("(")) { - List row = parser.parseExpressionList(); - if (row == null - || row.isEmpty() - || !parser.eatToken(")") - || row.size() != columnsList.size()) { - return null; - } - rows.add(row); - if (!parser.eatToken(",")) { - break; - } - } - if (rows.isEmpty()) { - return null; - } - StringBuilder select = new StringBuilder("select "); - select.append(String.join(", ", parameters)).append(" from (select "); - - int columnIndex = 0; - int colCount = rows.size() * columnsList.size(); - for (List row : rows) { - for (int index = 0; index < row.size(); index++) { - select.append(columnsList.get(index)).append("=").append(row.get(index)); - columnIndex++; - if (columnIndex < colCount) { - select.append(", "); - } - } - } - select.append(" from ").append(table).append(") p"); - - return Statement.of(select.toString()); - } - - /** - * Returns a list of all columns in the given table. This is used to transform insert statements - * without a column list. The query that is used does not use the INFORMATION_SCHEMA, but queries - * the table directly, so it can use the same transaction as the actual insert statement. - */ - static List getAllColumns(Connection connection, TableOrIndexName table) { - try (ResultSet resultSet = - connection.analyzeQuery( - Statement.of("SELECT * FROM " + table + " LIMIT 1"), QueryAnalyzeMode.PLAN)) { - return resultSet.getType().getStructFields().stream() - .map(StructField::getName) - .collect(Collectors.toList()); - } - } - - /** - * Transforms an UPDATE statement into a SELECT statement that selects the parameters in the - * update statement. This is done by creating a SELECT statement that selects the assignment - * expressions in the UPDATE statement, followed by the WHERE clause of the UPDATE statement. - * - *

Examples: - * - *

    - *
  • update foo set value=$1 where id=$2 is transformed to - * select $1, $2 from (select value=$1 from foo where id=$2) p - *
  • update bar set value=$1+sqrt($2), updated_at=current_timestamp() is - * transformed to select $1, $2 from (select value=$1+sqrt($2) from foo) p - *
- */ - @VisibleForTesting - static Statement transformUpdateToSelectParams(String sql, Set parameters) { - SimpleParser parser = new SimpleParser(sql); - if (!parser.eatKeyword("update")) { - return null; - } - parser.eatKeyword("only"); - TableOrIndexName table = parser.readTableOrIndexName(); - if (table == null) { - return null; - } - if (!parser.eatKeyword("set")) { - return null; - } - List assignmentsList = parser.parseExpressionListUntilKeyword("where", true); - if (assignmentsList == null || assignmentsList.isEmpty()) { - return null; - } - int whereStart = parser.getPos(); - if (!parser.eatKeyword("where")) { - whereStart = -1; - } - - StringBuilder select = new StringBuilder("select "); - select - .append(String.join(", ", parameters)) - .append(" from (select ") - .append(String.join(", ", assignmentsList)) - .append(" from ") - .append(table); - if (whereStart > -1) { - select.append(" ").append(sql.substring(whereStart)); - } - select.append(") p"); - - return Statement.of(select.toString()); - } - - /** - * Transforms a DELETE statement into a SELECT statement that selects the parameters of the DELETE - * statement. This is done by creating a SELECT 1 FROM table_name WHERE ... statement from the - * DELETE statement. - * - *

Example: - * - *

    - *
  • DELETE FROM foo WHERE id=$1 is transformed to - * SELECT $1 FROM (SELECT 1 FROM foo WHERE id=$1) p - *
- */ - @VisibleForTesting - static Statement transformDeleteToSelectParams(String sql, Set parameters) { - SimpleParser parser = new SimpleParser(sql); - if (!parser.eatKeyword("delete")) { - return null; - } - parser.eatKeyword("from"); - TableOrIndexName table = parser.readTableOrIndexName(); - if (table == null) { - return null; - } - parser.skipWhitespaces(); - int whereStart = parser.getPos(); - if (!parser.eatKeyword("where")) { - // Deletes must have a where clause, otherwise there cannot be any parameters. - return null; - } - - StringBuilder select = - new StringBuilder("select ") - .append(String.join(", ", parameters)) - .append(" from (select 1 from ") - .append(table) - .append(" ") - .append(sql.substring(whereStart)) - .append(") p"); - - return Statement.of(select.toString()); - } - - /** - * Returns the parameter types in the SQL string of this statement. The current implementation - * always returns any parameters that may have been specified in the PARSE message, and - * OID.Unspecified for all other parameters. - */ - private void extractParameterTypes(ResultSet paramsResultSet) { - paramsResultSet.next(); - ensureParameterLength(paramsResultSet.getColumnCount()); - for (int i = 0; i < paramsResultSet.getColumnCount(); i++) { - // Only override parameter types that were not specified by the frontend. - if (this.parameterDataTypes[i] == 0) { - this.parameterDataTypes[i] = Parser.toOid(paramsResultSet.getColumnType(i)); - } - } - } - - /** - * Enlarges the size of the parameter types of this statement to match the given count. Existing - * parameter types are preserved. New parameters are set to OID.Unspecified. - */ - private void ensureParameterLength(int parameterCount) { - if (this.parameterDataTypes == null) { - this.parameterDataTypes = new int[parameterCount]; - } else if (this.parameterDataTypes.length != parameterCount) { - this.parameterDataTypes = Arrays.copyOf(this.parameterDataTypes, parameterCount); - } - } } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediateStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediateStatement.java index 604e4ab41..c870414fc 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediateStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediateStatement.java @@ -19,8 +19,6 @@ import com.google.api.core.InternalApi; import com.google.cloud.spanner.Dialect; import com.google.cloud.spanner.ResultSet; -import com.google.cloud.spanner.SpannerException; -import com.google.cloud.spanner.SpannerExceptionFactory; import com.google.cloud.spanner.Statement; import com.google.cloud.spanner.connection.AbstractStatementParser; import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement; @@ -28,15 +26,18 @@ import com.google.cloud.spanner.connection.Connection; import com.google.cloud.spanner.connection.PostgreSQLStatementParser; import com.google.cloud.spanner.connection.StatementResult; +import com.google.cloud.spanner.connection.StatementResult.ClientSideStatementType; import com.google.cloud.spanner.connection.StatementResult.ResultType; import com.google.cloud.spanner.pgadapter.ConnectionHandler; +import com.google.cloud.spanner.pgadapter.error.PGException; import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory; -import com.google.cloud.spanner.pgadapter.metadata.DescribeMetadata; +import com.google.cloud.spanner.pgadapter.metadata.DescribeResult; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; import com.google.cloud.spanner.pgadapter.statements.BackendConnection.NoResult; import com.google.cloud.spanner.pgadapter.utils.Converter; import com.google.cloud.spanner.pgadapter.wireoutput.DataRowResponse; import com.google.cloud.spanner.pgadapter.wireoutput.WireOutput; +import com.google.common.annotations.VisibleForTesting; import java.io.DataOutputStream; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; @@ -70,11 +71,12 @@ public enum ResultNotReadyBehavior { protected StatementResult statementResult; protected boolean hasMoreData; protected Future futureStatementResult; - protected SpannerException exception; + protected PGException exception; protected final ParsedStatement parsedStatement; protected final Statement originalStatement; protected final String command; protected String commandTag; + protected boolean described; protected boolean executed; protected final Connection connection; protected final ConnectionHandler connectionHandler; @@ -136,7 +138,11 @@ public void close() throws Exception { /** @return True if this is a select statement, false otherwise. */ public boolean containsResultSet() { - return this.parsedStatement.isQuery(); + return this.parsedStatement.isQuery() + || (this.parsedStatement.getType() == StatementType.CLIENT_SIDE + && this.parsedStatement.getClientSideStatementType() + == ClientSideStatementType.RUN_BATCH) + || (this.parsedStatement.isUpdate() && this.parsedStatement.hasReturningClause()); } /** @return True if this statement was executed, False otherwise. */ @@ -167,7 +173,8 @@ public long getUpdateCount(ResultNotReadyBehavior resultNotReadyBehavior) { case QUERY: return -1L; case UPDATE: - return this.statementResult.getUpdateCount(); + long res = this.statementResult.getUpdateCount(); + return Math.max(res, 0L); case CLIENT_SIDE: case DDL: case UNKNOWN: @@ -214,7 +221,8 @@ public String getStatement() { return this.parsedStatement.getSqlWithoutComments(); } - private void initFutureResult(ResultNotReadyBehavior resultNotReadyBehavior) { + @VisibleForTesting + void initFutureResult(ResultNotReadyBehavior resultNotReadyBehavior) { if (this.futureStatementResult != null) { if (resultNotReadyBehavior == ResultNotReadyBehavior.FAIL && !this.futureStatementResult.isDone()) { @@ -223,12 +231,9 @@ private void initFutureResult(ResultNotReadyBehavior resultNotReadyBehavior) { try { setStatementResult(this.futureStatementResult.get()); } catch (ExecutionException executionException) { - setException(SpannerExceptionFactory.asSpannerException(executionException.getCause())); + setException(PGExceptionFactory.toPGException(executionException.getCause())); } catch (InterruptedException interruptedException) { - // TODO(b/246193644): Switch to PGException - setException( - SpannerExceptionFactory.asSpannerException( - PGExceptionFactory.newQueryCancelledException())); + setException(PGExceptionFactory.newQueryCancelledException()); } finally { this.futureStatementResult = null; } @@ -269,11 +274,11 @@ public String getSql() { } /** Returns any execution exception registered for this statement. */ - public SpannerException getException() { + public PGException getException() { return this.exception; } - void setException(SpannerException exception) { + void setException(PGException exception) { // Do not override any exception that has already been registered. COPY statements can receive // multiple errors as they execute asynchronously while receiving a stream of data from the // client. We always return the first exception that we encounter. @@ -287,11 +292,15 @@ void setException(SpannerException exception) { * * @param exception The exception to store. */ - public void handleExecutionException(SpannerException exception) { + public void handleExecutionException(PGException exception) { setException(exception); this.hasMoreData = false; } + public boolean isDescribed() { + return this.described; + } + public void executeAsync(BackendConnection backendConnection) { throw new UnsupportedOperationException(); } @@ -300,12 +309,12 @@ public void executeAsync(BackendConnection backendConnection) { * Moreso meant for inherited classes, allows one to call describe on a statement. Since raw * statements cannot be described, throw an error. */ - public DescribeMetadata describe() { + public DescribeResult describe() { throw new IllegalStateException( "Cannot describe a simple statement " + "(only prepared statements and portals)"); } - public Future> describeAsync(BackendConnection backendConnection) { + public Future describeAsync(BackendConnection backendConnection) { throw new UnsupportedOperationException(); } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/InvalidStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/InvalidStatement.java index afa387be9..c3b12977a 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/InvalidStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/InvalidStatement.java @@ -14,11 +14,12 @@ package com.google.cloud.spanner.pgadapter.statements; -import com.google.cloud.spanner.SpannerExceptionFactory; import com.google.cloud.spanner.Statement; import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement; import com.google.cloud.spanner.pgadapter.ConnectionHandler; +import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; +import com.google.common.collect.ImmutableList; public class InvalidStatement extends IntermediatePortalStatement { @@ -29,7 +30,18 @@ public InvalidStatement( ParsedStatement parsedStatement, Statement originalStatement, Exception exception) { - super(connectionHandler, options, name, parsedStatement, originalStatement); - setException(SpannerExceptionFactory.asSpannerException(exception)); + super( + name, + new IntermediatePreparedStatement( + connectionHandler, + options, + name, + NO_PARAMETER_TYPES, + parsedStatement, + originalStatement), + NO_PARAMS, + ImmutableList.of(), + ImmutableList.of()); + setException(PGExceptionFactory.toPGException(exception)); } } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/PgCatalog.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/PgCatalog.java index 2a96ef030..7aa28a3d2 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/PgCatalog.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/PgCatalog.java @@ -27,6 +27,7 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Set; +import java.util.regex.Pattern; @InternalApi public class PgCatalog { @@ -35,15 +36,22 @@ public class PgCatalog { new TableOrIndexName("pg_catalog", "pg_namespace"), new TableOrIndexName(null, "pg_namespace"), new TableOrIndexName(null, "pg_namespace"), new TableOrIndexName(null, "pg_namespace"), + new TableOrIndexName("pg_catalog", "pg_class"), new TableOrIndexName(null, "pg_class"), + new TableOrIndexName(null, "pg_class"), new TableOrIndexName(null, "pg_class"), new TableOrIndexName("pg_catalog", "pg_type"), new TableOrIndexName(null, "pg_type"), new TableOrIndexName(null, "pg_type"), new TableOrIndexName(null, "pg_type"), new TableOrIndexName("pg_catalog", "pg_settings"), new TableOrIndexName(null, "pg_settings"), new TableOrIndexName(null, "pg_settings"), new TableOrIndexName(null, "pg_settings")); + private static final ImmutableMap FUNCTION_REPLACEMENTS = + ImmutableMap.of( + Pattern.compile("pg_catalog.pg_table_is_visible\\(.+\\)"), "true", + Pattern.compile("pg_table_is_visible\\(.+\\)"), "true"); private final Map pgCatalogTables = ImmutableMap.of( new TableOrIndexName(null, "pg_namespace"), new PgNamespace(), + new TableOrIndexName(null, "pg_class"), new PgClass(), new TableOrIndexName(null, "pg_type"), new PgType(), new TableOrIndexName(null, "pg_settings"), new PgSettings()); private final SessionState sessionState; @@ -70,9 +78,18 @@ public Statement replacePgCatalogTables(Statement statement) { return addCommonTableExpressions(replacedTablesStatement.y(), cteBuilder.build()); } + static String replaceKnownUnsupportedFunctions(Statement statement) { + String sql = statement.getSql(); + for (Entry functionReplacement : FUNCTION_REPLACEMENTS.entrySet()) { + sql = functionReplacement.getKey().matcher(sql).replaceAll(functionReplacement.getValue()); + } + return sql; + } + static Statement addCommonTableExpressions( Statement statement, ImmutableList tableExpressions) { - SimpleParser parser = new SimpleParser(statement.getSql()); + String sql = replaceKnownUnsupportedFunctions(statement); + SimpleParser parser = new SimpleParser(sql); boolean hadCommonTableExpressions = parser.eatKeyword("with"); String tableExpressionsSql = String.join(",\n", tableExpressions); Statement.Builder builder = @@ -224,4 +241,90 @@ public String getTableExpression() { return sessionState.generatePGSettingsCte(); } } + + private static class PgClass implements PgCatalogTable { + private static final String PG_CLASS_CTE = + "pg_class as (\n" + + " select\n" + + " -1 as oid,\n" + + " table_name as relname,\n" + + " case table_schema when 'pg_catalog' then 11 when 'public' then 2200 else 0 end as relnamespace,\n" + + " 0 as reltype,\n" + + " 0 as reloftype,\n" + + " 0 as relowner,\n" + + " 1 as relam,\n" + + " 0 as relfilenode,\n" + + " 0 as reltablespace,\n" + + " 0 as relpages,\n" + + " 0.0::float8 as reltuples,\n" + + " 0 as relallvisible,\n" + + " 0 as reltoastrelid,\n" + + " false as relhasindex,\n" + + " false as relisshared,\n" + + " 'p' as relpersistence,\n" + + " 'r' as relkind,\n" + + " count(*) as relnatts,\n" + + " 0 as relchecks,\n" + + " false as relhasrules,\n" + + " false as relhastriggers,\n" + + " false as relhassubclass,\n" + + " false as relrowsecurity,\n" + + " false as relforcerowsecurity,\n" + + " true as relispopulated,\n" + + " 'n' as relreplident,\n" + + " false as relispartition,\n" + + " 0 as relrewrite,\n" + + " 0 as relfrozenxid,\n" + + " 0 as relminmxid,\n" + + " '{}'::bigint[] as relacl,\n" + + " '{}'::text[] as reloptions,\n" + + " 0 as relpartbound\n" + + "from information_schema.tables t\n" + + "inner join information_schema.columns using (table_catalog, table_schema, table_name)\n" + + "group by t.table_name, t.table_schema\n" + + "union all\n" + + "select\n" + + " -1 as oid,\n" + + " i.index_name as relname,\n" + + " case table_schema when 'pg_catalog' then 11 when 'public' then 2200 else 0 end as relnamespace,\n" + + " 0 as reltype,\n" + + " 0 as reloftype,\n" + + " 0 as relowner,\n" + + " 1 as relam,\n" + + " 0 as relfilenode,\n" + + " 0 as reltablespace,\n" + + " 0 as relpages,\n" + + " 0.0::float8 as reltuples,\n" + + " 0 as relallvisible,\n" + + " 0 as reltoastrelid,\n" + + " false as relhasindex,\n" + + " false as relisshared,\n" + + " 'p' as relpersistence,\n" + + " 'r' as relkind,\n" + + " count(*) as relnatts,\n" + + " 0 as relchecks,\n" + + " false as relhasrules,\n" + + " false as relhastriggers,\n" + + " false as relhassubclass,\n" + + " false as relrowsecurity,\n" + + " false as relforcerowsecurity,\n" + + " true as relispopulated,\n" + + " 'n' as relreplident,\n" + + " false as relispartition,\n" + + " 0 as relrewrite,\n" + + " 0 as relfrozenxid,\n" + + " 0 as relminmxid,\n" + + " '{}'::bigint[] as relacl,\n" + + " '{}'::text[] as reloptions,\n" + + " 0 as relpartbound\n" + + "from information_schema.indexes i\n" + + "inner join information_schema.index_columns using (table_catalog, table_schema, table_name)\n" + + "group by i.index_name, i.table_schema\n" + + ")"; + + @Override + public String getTableExpression() { + return PG_CLASS_CTE; + } + } } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/PrepareStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/PrepareStatement.java index d8f6d20bb..bd69f4617 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/PrepareStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/PrepareStatement.java @@ -20,9 +20,9 @@ import com.google.cloud.spanner.connection.AbstractStatementParser; import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement; import com.google.cloud.spanner.connection.AbstractStatementParser.StatementType; +import com.google.cloud.spanner.connection.StatementResult; import com.google.cloud.spanner.pgadapter.ConnectionHandler; import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory; -import com.google.cloud.spanner.pgadapter.metadata.DescribePortalMetadata; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; import com.google.cloud.spanner.pgadapter.statements.BackendConnection.NoResult; import com.google.cloud.spanner.pgadapter.statements.SimpleParser.TableOrIndexName; @@ -108,7 +108,18 @@ public PrepareStatement( String name, ParsedStatement parsedStatement, Statement originalStatement) { - super(connectionHandler, options, name, parsedStatement, originalStatement); + super( + name, + new IntermediatePreparedStatement( + connectionHandler, + options, + name, + NO_PARAMETER_TYPES, + parsedStatement, + originalStatement), + NO_PARAMS, + ImmutableList.of(), + ImmutableList.of()); this.preparedStatement = parse(originalStatement.getSql()); } @@ -149,14 +160,14 @@ public void executeAsync(BackendConnection backendConnection) { } @Override - public Future describeAsync(BackendConnection backendConnection) { + public Future describeAsync(BackendConnection backendConnection) { // Return null to indicate that this PREPARE statement does not return any // RowDescriptionResponse. return Futures.immediateFuture(null); } @Override - public IntermediatePortalStatement bind( + public IntermediatePortalStatement createPortal( String name, byte[][] parameters, List parameterFormatCodes, diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectVersionStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectVersionStatement.java new file mode 100644 index 000000000..59980448b --- /dev/null +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/local/SelectVersionStatement.java @@ -0,0 +1,81 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.spanner.pgadapter.statements.local; + +import com.google.api.core.InternalApi; +import com.google.cloud.spanner.ResultSet; +import com.google.cloud.spanner.ResultSets; +import com.google.cloud.spanner.Struct; +import com.google.cloud.spanner.Type; +import com.google.cloud.spanner.Type.StructField; +import com.google.cloud.spanner.connection.StatementResult; +import com.google.cloud.spanner.pgadapter.statements.BackendConnection; +import com.google.cloud.spanner.pgadapter.statements.BackendConnection.QueryResult; +import com.google.common.collect.ImmutableList; + +@InternalApi +public class SelectVersionStatement implements LocalStatement { + public static final SelectVersionStatement INSTANCE = new SelectVersionStatement(); + + private SelectVersionStatement() {} + + @Override + public String[] getSql() { + return new String[] { + "select version()", + "SELECT version()", + "Select version()", + "select VERSION()", + "SELECT VERSION()", + "Select VERSION()", + "select * from version()", + "SELECT * FROM version()", + "Select * from version()", + "select * from VERSION()", + "SELECT * FROM VERSION()", + "Select * from VERSION()", + "select pg_catalog.version()", + "SELECT pg_catalog.version()", + "Select pg_catalog.version()", + "select PG_CATALOG.VERSION()", + "SELECT PG_CATALOG.VERSION()", + "Select PG_CATALOG.VERSION()", + "select * from pg_catalog.version()", + "SELECT * FROM pg_catalog.version()", + "Select * from pg_catalog.version()", + "select * from PG_CATALOG.VERSION()", + "SELECT * FROM PG_CATALOG.VERSION()", + "Select * from PG_CATALOG.VERSION()" + }; + } + + @Override + public StatementResult execute(BackendConnection backendConnection) { + ResultSet resultSet = + ResultSets.forRows( + Type.struct(StructField.of("version", Type.string())), + ImmutableList.of( + Struct.newBuilder() + .set("version") + .to( + "PostgreSQL " + + backendConnection + .getSessionState() + .get(null, "server_version") + .getSetting()) + .build())); + return new QueryResult(resultSet); + } +} diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/utils/ClientAutoDetector.java b/src/main/java/com/google/cloud/spanner/pgadapter/utils/ClientAutoDetector.java index e53fa52d5..4d5ecec22 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/utils/ClientAutoDetector.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/utils/ClientAutoDetector.java @@ -22,6 +22,7 @@ import com.google.cloud.spanner.pgadapter.statements.local.SelectCurrentCatalogStatement; import com.google.cloud.spanner.pgadapter.statements.local.SelectCurrentDatabaseStatement; import com.google.cloud.spanner.pgadapter.statements.local.SelectCurrentSchemaStatement; +import com.google.cloud.spanner.pgadapter.statements.local.SelectVersionStatement; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.List; @@ -41,6 +42,7 @@ public class ClientAutoDetector { SelectCurrentSchemaStatement.INSTANCE, SelectCurrentDatabaseStatement.INSTANCE, SelectCurrentCatalogStatement.INSTANCE, + SelectVersionStatement.INSTANCE, DjangoGetTableNamesStatement.INSTANCE); public enum WellKnownClient { diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/utils/Converter.java b/src/main/java/com/google/cloud/spanner/pgadapter/utils/Converter.java index 7e7aa8daa..48211567b 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/utils/Converter.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/utils/Converter.java @@ -29,6 +29,7 @@ import com.google.cloud.spanner.pgadapter.parsers.NumericParser; import com.google.cloud.spanner.pgadapter.parsers.StringParser; import com.google.cloud.spanner.pgadapter.parsers.TimestampParser; +import com.google.cloud.spanner.pgadapter.session.SessionState; import com.google.cloud.spanner.pgadapter.statements.CopyToStatement; import com.google.cloud.spanner.pgadapter.statements.IntermediateStatement; import com.google.common.base.Preconditions; @@ -44,6 +45,7 @@ public class Converter implements AutoCloseable { private final QueryMode mode; private final OptionsMetadata options; private final ResultSet resultSet; + private final SessionState sessionState; public Converter( IntermediateStatement statement, @@ -54,6 +56,12 @@ public Converter( this.mode = mode; this.options = options; this.resultSet = resultSet; + this.sessionState = + statement + .getConnectionHandler() + .getExtendedQueryProtocolHandler() + .getBackendConnection() + .getSessionState(); } @Override @@ -86,7 +94,7 @@ public int convertResultSetRowToDataRowResponse() throws IOException { fixedFormat == null ? DataFormat.getDataFormat(column_index, statement, mode, options) : fixedFormat; - byte[] column = Converter.convertToPG(this.resultSet, column_index, format); + byte[] column = Converter.convertToPG(this.resultSet, column_index, format, sessionState); outputStream.writeInt(column.length); outputStream.write(column); } @@ -107,7 +115,8 @@ public void writeBuffer(DataOutputStream outputStream) throws IOException { * @param format The {@link DataFormat} format to use to encode the data. * @return a byte array containing the data in the specified format. */ - public static byte[] convertToPG(ResultSet result, int position, DataFormat format) { + public static byte[] convertToPG( + ResultSet result, int position, DataFormat format, SessionState sessionState) { Preconditions.checkArgument(!result.isNull(position), "Column may not contain a null value"); Type type = result.getColumnType(position); switch (type.getCode()) { @@ -126,11 +135,11 @@ public static byte[] convertToPG(ResultSet result, int position, DataFormat form case STRING: return StringParser.convertToPG(result, position); case TIMESTAMP: - return TimestampParser.convertToPG(result, position, format); + return TimestampParser.convertToPG(result, position, format, sessionState.getTimezone()); case PG_JSONB: return JsonbParser.convertToPG(result, position, format); case ARRAY: - ArrayParser arrayParser = new ArrayParser(result, position); + ArrayParser arrayParser = new ArrayParser(result, position, sessionState); return arrayParser.parse(format); case NUMERIC: case JSON: diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/utils/MutationWriter.java b/src/main/java/com/google/cloud/spanner/pgadapter/utils/MutationWriter.java index 50a895db4..2dc290fe3 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/utils/MutationWriter.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/utils/MutationWriter.java @@ -34,7 +34,9 @@ import com.google.cloud.spanner.Value; import com.google.cloud.spanner.connection.Connection; import com.google.cloud.spanner.connection.StatementResult; +import com.google.cloud.spanner.pgadapter.error.PGException; import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory; +import com.google.cloud.spanner.pgadapter.error.SQLState; import com.google.cloud.spanner.pgadapter.session.CopySettings; import com.google.cloud.spanner.pgadapter.session.SessionState; import com.google.cloud.spanner.pgadapter.statements.BackendConnection.UpdateCount; @@ -106,7 +108,7 @@ public enum CopyTransactionMode { private final Format copyFormat; private final CSVFormat csvFormat; private final boolean hasHeader; - private final CountDownLatch parserCreatedLatch = new CountDownLatch(1); + private final CountDownLatch pipeCreatedLatch = new CountDownLatch(1); private final PipedOutputStream payload = new PipedOutputStream(); private final AtomicBoolean commit = new AtomicBoolean(false); private final AtomicBoolean rollback = new AtomicBoolean(false); @@ -117,7 +119,7 @@ public enum CopyTransactionMode { private final Object lock = new Object(); @GuardedBy("lock") - private SpannerException exception; + private PGException exception; public MutationWriter( SessionState sessionState, @@ -161,7 +163,7 @@ public void addCopyData(byte[] payload) { } } try { - parserCreatedLatch.await(); + pipeCreatedLatch.await(); this.payload.write(payload); } catch (InterruptedException | InterruptedIOException interruptedIOException) { // The IO operation was interrupted. This indicates that the user wants to cancel the COPY @@ -173,11 +175,13 @@ public void addCopyData(byte[] payload) { // Ignore the exception if the executor has already been shutdown. That means that an error // occurred that ended the COPY operation while we were writing data to the buffer. if (!executorService.isShutdown()) { - SpannerException spannerException = - SpannerExceptionFactory.newSpannerException( - ErrorCode.INTERNAL, "Could not write copy data to buffer", e); - logger.log(Level.SEVERE, spannerException.getMessage(), spannerException); - throw spannerException; + PGException pgException = + PGException.newBuilder("Could not write copy data to buffer") + .setSQLState(SQLState.InternalError) + .setCause(e) + .build(); + logger.log(Level.SEVERE, pgException.getMessage(), pgException); + throw pgException; } } } @@ -205,7 +209,7 @@ public void close() throws IOException { @Override public StatementResult call() throws Exception { PipedInputStream inputStream = new PipedInputStream(payload, copySettings.getPipeBufferSize()); - parserCreatedLatch.countDown(); + pipeCreatedLatch.countDown(); final CopyInParser parser = CopyInParser.create(copyFormat, csvFormat, inputStream, hasHeader); // This LinkedBlockingDeque holds a reference to all transactions that are currently active. The // max capacity of this deque is what ensures that we never have more than maxParallelism @@ -231,12 +235,12 @@ public StatementResult call() throws Exception { while (!rollback.get() && iterator.hasNext()) { CopyRecord record = iterator.next(); if (record.numColumns() != this.tableColumns.keySet().size()) { - throw SpannerExceptionFactory.newSpannerException( - ErrorCode.INVALID_ARGUMENT, + throw PGExceptionFactory.newPGException( "Invalid COPY data: Row length mismatched. Expected " + this.tableColumns.keySet().size() + " columns, but only found " - + record.numColumns()); + + record.numColumns(), + SQLState.DataException); } Mutation mutation = buildMutation(record); @@ -308,17 +312,17 @@ public StatementResult call() throws Exception { ApiFutures.allAsList(allCommitFutures).get(); } catch (SpannerException e) { synchronized (lock) { - this.exception = e; + this.exception = PGExceptionFactory.toPGException(e); throw this.exception; } } catch (ExecutionException e) { synchronized (lock) { - this.exception = SpannerExceptionFactory.asSpannerException(e.getCause()); + this.exception = PGExceptionFactory.toPGException(e.getCause()); throw this.exception; } } catch (Exception e) { synchronized (lock) { - this.exception = SpannerExceptionFactory.asSpannerException(e); + this.exception = PGExceptionFactory.toPGException(e); throw this.exception; } } finally { diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/BindMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/BindMessage.java index 84af15a68..9f8526145 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/BindMessage.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/BindMessage.java @@ -17,8 +17,10 @@ import com.google.api.core.InternalApi; import com.google.cloud.spanner.pgadapter.ConnectionHandler; import com.google.cloud.spanner.pgadapter.statements.BackendConnection; +import com.google.cloud.spanner.pgadapter.statements.IntermediatePortalStatement; import com.google.cloud.spanner.pgadapter.statements.IntermediatePreparedStatement; import com.google.cloud.spanner.pgadapter.wireoutput.BindCompleteResponse; +import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import java.text.MessageFormat; import java.util.Arrays; @@ -38,7 +40,7 @@ public class BindMessage extends AbstractQueryProtocolMessage { private final List formatCodes; private final List resultFormatCodes; private final byte[][] parameters; - private final IntermediatePreparedStatement statement; + private final IntermediatePortalStatement statement; /** Constructor for Bind messages that are received from the front-end. */ public BindMessage(ConnectionHandler connection) throws Exception { @@ -48,7 +50,11 @@ public BindMessage(ConnectionHandler connection) throws Exception { this.formatCodes = getFormatCodes(this.inputStream); this.parameters = getParameters(this.inputStream); this.resultFormatCodes = getFormatCodes(this.inputStream); - this.statement = connection.getStatement(this.statementName); + IntermediatePreparedStatement statement = connection.getStatement(statementName); + this.statement = + statement.createPortal( + this.portalName, this.parameters, this.formatCodes, this.resultFormatCodes); + this.connection.registerPortal(this.portalName, this.statement); } /** Constructor for Bind messages that are constructed to execute a Query message. */ @@ -67,27 +73,33 @@ public BindMessage( this.statementName = statementName; this.formatCodes = ImmutableList.of(); this.resultFormatCodes = ImmutableList.of(); - this.parameters = parameters; - this.statement = connection.getStatement(statementName); + this.parameters = Preconditions.checkNotNull(parameters); + IntermediatePreparedStatement statement = connection.getStatement(statementName); + this.statement = + statement.createPortal( + this.portalName, this.parameters, this.formatCodes, this.resultFormatCodes); + this.connection.registerPortal(this.portalName, this.statement); + } + + boolean hasParameterValues() { + return this.parameters.length > 0; } @Override void buffer(BackendConnection backendConnection) { - if (isExtendedProtocol() && !this.statement.isDescribed()) { + if (isExtendedProtocol() && !this.statement.getPreparedStatement().isDescribed()) { try { // Make sure all parameters have been described, so we always send typed parameters to Cloud // Spanner. - this.statement.describeParameters(this.parameters, true); + this.statement + .getPreparedStatement() + .autoDescribeParameters(this.parameters, backendConnection); } catch (Throwable ignore) { // Ignore any error messages while describing the parameters, and let the following // DescribePortal or execute message handle any errors that are caused by invalid // statements. } } - this.connection.registerPortal( - this.portalName, - this.statement.bind( - this.portalName, this.parameters, this.formatCodes, this.resultFormatCodes)); } @Override diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/BootstrapMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/BootstrapMessage.java index 8a5b20b7f..8cc11a6c9 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/BootstrapMessage.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/BootstrapMessage.java @@ -23,11 +23,12 @@ import com.google.cloud.spanner.pgadapter.wireoutput.ReadyResponse; import com.google.cloud.spanner.pgadapter.wireoutput.ReadyResponse.Status; import java.io.DataOutputStream; +import java.nio.charset.StandardCharsets; +import java.time.ZoneId; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.TimeZone; /** * This represents all messages which occur before {@link ControlMessage} type messages. Those @@ -108,30 +109,63 @@ protected List parseParameterKeys(String rawParameters) { public static void sendStartupMessage( DataOutputStream output, int connectionId, int secret, SessionState sessionState) throws Exception { - new AuthenticationOkResponse(output).send(); - new KeyDataResponse(output, connectionId, secret).send(); + new AuthenticationOkResponse(output).send(false); + new KeyDataResponse(output, connectionId, secret).send(false); new ParameterStatusResponse( output, - "server_version".getBytes(), - sessionState.get(null, "server_version").getSetting().getBytes()) - .send(); - new ParameterStatusResponse(output, "application_name".getBytes(), "PGAdapter".getBytes()) - .send(); - new ParameterStatusResponse(output, "is_superuser".getBytes(), "false".getBytes()).send(); - new ParameterStatusResponse(output, "session_authorization".getBytes(), "PGAdapter".getBytes()) - .send(); - new ParameterStatusResponse(output, "integer_datetimes".getBytes(), "on".getBytes()).send(); - new ParameterStatusResponse(output, "server_encoding".getBytes(), "UTF8".getBytes()).send(); - new ParameterStatusResponse(output, "client_encoding".getBytes(), "UTF8".getBytes()).send(); - new ParameterStatusResponse(output, "DateStyle".getBytes(), "ISO,YMD".getBytes()).send(); - new ParameterStatusResponse(output, "IntervalStyle".getBytes(), "iso_8601".getBytes()).send(); - new ParameterStatusResponse(output, "standard_conforming_strings".getBytes(), "on".getBytes()) - .send(); + "server_version".getBytes(StandardCharsets.UTF_8), + sessionState.get(null, "server_version").getSetting().getBytes(StandardCharsets.UTF_8)) + .send(false); new ParameterStatusResponse( output, - "TimeZone".getBytes(), - TimeZone.getDefault().getDisplayName(false, TimeZone.SHORT).getBytes()) - .send(); + "application_name".getBytes(StandardCharsets.UTF_8), + "PGAdapter".getBytes(StandardCharsets.UTF_8)) + .send(false); + new ParameterStatusResponse( + output, + "is_superuser".getBytes(StandardCharsets.UTF_8), + "false".getBytes(StandardCharsets.UTF_8)) + .send(false); + new ParameterStatusResponse( + output, + "session_authorization".getBytes(StandardCharsets.UTF_8), + "PGAdapter".getBytes(StandardCharsets.UTF_8)) + .send(false); + new ParameterStatusResponse( + output, + "integer_datetimes".getBytes(StandardCharsets.UTF_8), + "on".getBytes(StandardCharsets.UTF_8)) + .send(false); + new ParameterStatusResponse( + output, + "server_encoding".getBytes(StandardCharsets.UTF_8), + "UTF8".getBytes(StandardCharsets.UTF_8)) + .send(false); + new ParameterStatusResponse( + output, + "client_encoding".getBytes(StandardCharsets.UTF_8), + "UTF8".getBytes(StandardCharsets.UTF_8)) + .send(false); + new ParameterStatusResponse( + output, + "DateStyle".getBytes(StandardCharsets.UTF_8), + "ISO,YMD".getBytes(StandardCharsets.UTF_8)) + .send(false); + new ParameterStatusResponse( + output, + "IntervalStyle".getBytes(StandardCharsets.UTF_8), + "iso_8601".getBytes(StandardCharsets.UTF_8)) + .send(false); + new ParameterStatusResponse( + output, + "standard_conforming_strings".getBytes(StandardCharsets.UTF_8), + "on".getBytes(StandardCharsets.UTF_8)) + .send(false); + new ParameterStatusResponse( + output, + "TimeZone".getBytes(StandardCharsets.UTF_8), + ZoneId.systemDefault().getId().getBytes(StandardCharsets.UTF_8)) + .send(false); new ReadyResponse(output, Status.IDLE).send(); } } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ControlMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ControlMessage.java index 0a3b07a08..e233eb411 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ControlMessage.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ControlMessage.java @@ -28,6 +28,7 @@ import com.google.cloud.spanner.connection.Connection; import com.google.cloud.spanner.connection.ConnectionOptionsHelper; import com.google.cloud.spanner.connection.StatementResult; +import com.google.cloud.spanner.connection.StatementResult.ResultType; import com.google.cloud.spanner.pgadapter.ConnectionHandler; import com.google.cloud.spanner.pgadapter.ConnectionHandler.ConnectionStatus; import com.google.cloud.spanner.pgadapter.ConnectionHandler.QueryMode; @@ -243,7 +244,7 @@ protected void handleError(Exception exception) throws Exception { * *

NOTE: This method does not flush the output stream. */ - public void sendSpannerResult(IntermediateStatement statement, QueryMode mode, long maxRows) + void sendSpannerResult(IntermediateStatement statement, QueryMode mode, long maxRows) throws Exception { String command = statement.getCommandTag(); if (Strings.isNullOrEmpty(command)) { @@ -253,31 +254,38 @@ public void sendSpannerResult(IntermediateStatement statement, QueryMode mode, l if (statement.getStatementResult() == null) { return; } - switch (statement.getStatementType()) { case DDL: - case CLIENT_SIDE: case UNKNOWN: new CommandCompleteResponse(this.outputStream, command).send(false); break; + case CLIENT_SIDE: + if (statement.getStatementResult().getResultType() != ResultType.RESULT_SET) { + new CommandCompleteResponse(this.outputStream, command).send(false); + break; + } + // fallthrough to QUERY case QUERY: - SendResultSetState state = sendResultSet(statement, mode, maxRows); - statement.setHasMoreData(state.hasMoreRows()); - if (state.hasMoreRows()) { - new PortalSuspendedResponse(this.outputStream).send(false); + case UPDATE: + if (statement.getStatementResult().getResultType() == ResultType.RESULT_SET) { + SendResultSetState state = sendResultSet(statement, mode, maxRows); + statement.setHasMoreData(state.hasMoreRows()); + if (state.hasMoreRows()) { + new PortalSuspendedResponse(this.outputStream).send(false); + } else { + statement.close(); + new CommandCompleteResponse(this.outputStream, state.getCommandAndNumRows()) + .send(false); + } } else { - statement.close(); - new CommandCompleteResponse(this.outputStream, state.getCommandAndNumRows()).send(false); + // For an INSERT command, the tag is INSERT oid rows, where rows is the number of rows + // inserted. oid used to be the object ID of the inserted row if rows was 1 and the target + // table had OIDs, but OIDs system columns are not supported anymore; therefore oid is + // always 0. + command += ("INSERT".equals(command) ? " 0 " : " ") + statement.getUpdateCount(); + new CommandCompleteResponse(this.outputStream, command).send(false); } break; - case UPDATE: - // For an INSERT command, the tag is INSERT oid rows, where rows is the number of rows - // inserted. oid used to be the object ID of the inserted row if rows was 1 and the target - // table had OIDs, but OIDs system columns are not supported anymore; therefore oid is - // always 0. - command += ("INSERT".equals(command) ? " 0 " : " ") + statement.getUpdateCount(); - new CommandCompleteResponse(this.outputStream, command).send(false); - break; default: throw new IllegalStateException("Unknown statement type: " + statement.getStatement()); } @@ -293,12 +301,13 @@ public void sendSpannerResult(IntermediateStatement statement, QueryMode mode, l * @return An adapted representation with specific metadata which PG wire requires. * @throws com.google.cloud.spanner.SpannerException if traversing the {@link ResultSet} fails. */ - public SendResultSetState sendResultSet( + SendResultSetState sendResultSet( IntermediateStatement describedResult, QueryMode mode, long maxRows) throws Exception { + StatementResult statementResult = describedResult.getStatementResult(); Preconditions.checkArgument( - describedResult.containsResultSet(), "The statement result must be a result set"); + statementResult.getResultType() == ResultType.RESULT_SET, + "The statement result must be a result set"); long rows; - StatementResult statementResult = describedResult.getStatementResult(); boolean hasData; if (statementResult instanceof PartitionQueryResult) { hasData = false; diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/CopyDataMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/CopyDataMessage.java index d10b73e96..51bfb0b8d 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/CopyDataMessage.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/CopyDataMessage.java @@ -15,8 +15,8 @@ package com.google.cloud.spanner.pgadapter.wireprotocol; import com.google.api.core.InternalApi; -import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.pgadapter.ConnectionHandler; +import com.google.cloud.spanner.pgadapter.error.PGException; import com.google.cloud.spanner.pgadapter.statements.CopyStatement; import com.google.cloud.spanner.pgadapter.utils.MutationWriter; import java.text.MessageFormat; @@ -57,7 +57,7 @@ protected void sendPayload() throws Exception { if (!statement.hasException()) { try { mutationWriter.addCopyData(this.payload); - } catch (SpannerException exception) { + } catch (PGException exception) { statement.handleExecutionException(exception); throw exception; } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/CopyFailMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/CopyFailMessage.java index b2dbfe865..84b190c05 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/CopyFailMessage.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/CopyFailMessage.java @@ -15,10 +15,10 @@ package com.google.cloud.spanner.pgadapter.wireprotocol; import com.google.api.core.InternalApi; -import com.google.cloud.spanner.ErrorCode; -import com.google.cloud.spanner.SpannerExceptionFactory; import com.google.cloud.spanner.pgadapter.ConnectionHandler; import com.google.cloud.spanner.pgadapter.ConnectionHandler.ConnectionStatus; +import com.google.cloud.spanner.pgadapter.error.PGException; +import com.google.cloud.spanner.pgadapter.error.SQLState; import com.google.cloud.spanner.pgadapter.statements.CopyStatement; import com.google.cloud.spanner.pgadapter.utils.MutationWriter; import java.text.MessageFormat; @@ -49,7 +49,7 @@ protected void sendPayload() throws Exception { mutationWriter.rollback(); statement.close(); this.statement.handleExecutionException( - SpannerExceptionFactory.newSpannerException(ErrorCode.CANCELLED, this.errorMessage)); + PGException.newBuilder(this.errorMessage).setSQLState(SQLState.QueryCanceled).build()); } // Clear the COPY_IN status to indicate that we finished unsuccessfully. This will cause the // inline handling of incoming (copy) messages to stop and the server to return an error message diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/DescribeMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/DescribeMessage.java index f11b95442..4edc9ba5b 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/DescribeMessage.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/DescribeMessage.java @@ -15,14 +15,13 @@ package com.google.cloud.spanner.pgadapter.wireprotocol; import com.google.api.core.InternalApi; -import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.SpannerExceptionFactory; +import com.google.cloud.spanner.connection.StatementResult; import com.google.cloud.spanner.pgadapter.ConnectionHandler; import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory; -import com.google.cloud.spanner.pgadapter.metadata.DescribePortalMetadata; -import com.google.cloud.spanner.pgadapter.metadata.DescribeStatementMetadata; +import com.google.cloud.spanner.pgadapter.metadata.DescribeResult; import com.google.cloud.spanner.pgadapter.statements.BackendConnection; -import com.google.cloud.spanner.pgadapter.statements.IntermediatePreparedStatement; +import com.google.cloud.spanner.pgadapter.statements.IntermediateStatement; import com.google.cloud.spanner.pgadapter.wireoutput.NoDataResponse; import com.google.cloud.spanner.pgadapter.wireoutput.ParameterDescriptionResponse; import com.google.cloud.spanner.pgadapter.wireoutput.RowDescriptionResponse; @@ -39,8 +38,8 @@ public class DescribeMessage extends AbstractQueryProtocolMessage { private final PreparedType type; private final String name; - private final IntermediatePreparedStatement statement; - private Future describePortalMetadata; + private final IntermediateStatement statement; + private Future describePortalMetadata; public DescribeMessage(ConnectionHandler connection) throws Exception { super(connection); @@ -78,10 +77,9 @@ public DescribeMessage( @Override void buffer(BackendConnection backendConnection) { if (this.type == PreparedType.Portal && this.statement.containsResultSet()) { - describePortalMetadata = - (Future) this.statement.describeAsync(backendConnection); + describePortalMetadata = this.statement.describeAsync(backendConnection); } else if (this.type == PreparedType.Statement) { - this.statement.setDescribed(); + describePortalMetadata = this.statement.describeAsync(backendConnection); } } @@ -93,7 +91,7 @@ public void flush() throws Exception { } else { this.handleDescribeStatement(); } - } catch (SpannerException e) { + } catch (Exception e) { handleError(e); } } @@ -149,11 +147,11 @@ void handleDescribePortal() throws Exception { new RowDescriptionResponse( this.outputStream, this.statement, - getPortalMetadata().getMetadata(), + getPortalMetadata().getResultSet(), this.connection.getServer().getOptions(), this.queryMode) .send(false); - } catch (SpannerException exception) { + } catch (Exception exception) { this.handleError(exception); } } @@ -168,7 +166,7 @@ void handleDescribePortal() throws Exception { } @VisibleForTesting - DescribePortalMetadata getPortalMetadata() { + StatementResult getPortalMetadata() { if (!this.describePortalMetadata.isDone()) { throw new IllegalStateException("Trying to get Portal Metadata before it has been described"); } @@ -187,11 +185,13 @@ DescribePortalMetadata getPortalMetadata() { * @throws Exception if sending the message back to the client causes an error. */ public void handleDescribeStatement() throws Exception { - try (DescribeStatementMetadata metadata = - (DescribeStatementMetadata) this.statement.describe()) { + if (this.statement.hasException()) { + throw this.statement.getException(); + } else { if (isExtendedProtocol()) { + DescribeResult metadata = this.statement.describe(); new ParameterDescriptionResponse(this.outputStream, metadata.getParameters()).send(false); - if (metadata.getResultSet() != null) { + if (metadata.getResultSet() != null && metadata.getResultSet().getColumnCount() > 0) { new RowDescriptionResponse( this.outputStream, this.statement, @@ -203,8 +203,6 @@ public void handleDescribeStatement() throws Exception { new NoDataResponse(this.outputStream).send(false); } } - } catch (SpannerException exception) { - this.handleError(exception); } } } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ParseMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ParseMessage.java index f986ab41d..48957a87b 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ParseMessage.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ParseMessage.java @@ -123,15 +123,13 @@ static IntermediatePreparedStatement createStatement( parsedStatement, originalStatement); } else { - IntermediatePreparedStatement statement = - new IntermediatePreparedStatement( - connectionHandler, - connectionHandler.getServer().getOptions(), - name, - parsedStatement, - originalStatement); - statement.setParameterDataTypes(parameterDataTypes); - return statement; + return new IntermediatePreparedStatement( + connectionHandler, + connectionHandler.getServer().getOptions(), + name, + parameterDataTypes, + parsedStatement, + originalStatement); } } catch (Exception exception) { return new InvalidStatement( diff --git a/src/test/golang/pgadapter_gorm_tests/go.mod b/src/test/golang/pgadapter_gorm_tests/go.mod index ddbf27cc5..d553f4658 100644 --- a/src/test/golang/pgadapter_gorm_tests/go.mod +++ b/src/test/golang/pgadapter_gorm_tests/go.mod @@ -3,25 +3,25 @@ module cloud.google.com/pgadapter_gorm_tests go 1.17 require ( - github.com/shopspring/decimal v1.2.0 - gorm.io/datatypes v1.0.6 - gorm.io/driver/postgres v1.3.7 - gorm.io/gorm v1.23.6 + github.com/shopspring/decimal v1.3.1 + gorm.io/datatypes v1.0.7 + gorm.io/driver/postgres v1.4.5 + gorm.io/gorm v1.24.2 ) require ( github.com/go-sql-driver/mysql v1.6.0 // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect - github.com/jackc/pgconn v1.12.1 // indirect + github.com/jackc/pgconn v1.13.0 // indirect github.com/jackc/pgio v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect - github.com/jackc/pgproto3/v2 v2.3.0 // indirect + github.com/jackc/pgproto3/v2 v2.3.1 // indirect github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect - github.com/jackc/pgtype v1.11.0 // indirect - github.com/jackc/pgx/v4 v4.16.1 // indirect + github.com/jackc/pgtype v1.12.0 // indirect + github.com/jackc/pgx/v4 v4.17.2 // indirect github.com/jinzhu/inflection v1.0.0 // indirect - github.com/jinzhu/now v1.1.4 // indirect - golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect + github.com/jinzhu/now v1.1.5 // indirect + golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect golang.org/x/text v0.3.7 // indirect gorm.io/driver/mysql v1.3.2 // indirect ) diff --git a/src/test/golang/pgadapter_gorm_tests/go.sum b/src/test/golang/pgadapter_gorm_tests/go.sum index f81df1c1d..e9f369446 100644 --- a/src/test/golang/pgadapter_gorm_tests/go.sum +++ b/src/test/golang/pgadapter_gorm_tests/go.sum @@ -43,6 +43,8 @@ github.com/jackc/pgconn v1.10.1/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoy github.com/jackc/pgconn v1.11.0/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= github.com/jackc/pgconn v1.12.1 h1:rsDFzIpRk7xT4B8FufgpCCeyjdNpKyghZeSefViE5W8= github.com/jackc/pgconn v1.12.1/go.mod h1:ZkhRC59Llhrq3oSfrikvwQ5NaxYExr6twkdkMLaKono= +github.com/jackc/pgconn v1.13.0 h1:3L1XMNV2Zvca/8BYhzcRFS70Lr0WlDg16Di6SFGAbys= +github.com/jackc/pgconn v1.13.0/go.mod h1:AnowpAqO4CMIIJNZl2VJp+KrkAZciAkhEl0W0JIobpI= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= @@ -62,6 +64,8 @@ github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwX github.com/jackc/pgproto3/v2 v2.2.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.3.0 h1:brH0pCGBDkBW07HWlN/oSBXrmo3WB0UvZd1pIuDcL8Y= github.com/jackc/pgproto3/v2 v2.3.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.3.1 h1:nwj7qwf0S+Q7ISFfBndqeLwSwxs+4DPsbRFjECT1Y4Y= +github.com/jackc/pgproto3/v2 v2.3.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= @@ -72,6 +76,8 @@ github.com/jackc/pgtype v1.9.1/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76I github.com/jackc/pgtype v1.10.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= github.com/jackc/pgtype v1.11.0 h1:u4uiGPz/1hryuXzyaBhSk6dnIyyG2683olG2OV+UUgs= github.com/jackc/pgtype v1.11.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= +github.com/jackc/pgtype v1.12.0 h1:Dlq8Qvcch7kiehm8wPGIW0W3KsCCHJnRacKW0UM8n5w= +github.com/jackc/pgtype v1.12.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= @@ -80,15 +86,20 @@ github.com/jackc/pgx/v4 v4.14.1/go.mod h1:RgDuE4Z34o7XE92RpLsvFiOEfrAUT0Xt2KxvX7 github.com/jackc/pgx/v4 v4.15.0/go.mod h1:D/zyOyXiaM1TmVWnOM18p0xdDtdakRBa0RsVGI3U3bw= github.com/jackc/pgx/v4 v4.16.1 h1:JzTglcal01DrghUqt+PmzWsZx/Yh7SC/CTQmSBMTd0Y= github.com/jackc/pgx/v4 v4.16.1/go.mod h1:SIhx0D5hoADaiXZVyv+3gSm3LCIIINTVO0PficsvWGQ= +github.com/jackc/pgx/v4 v4.17.2 h1:0Ut0rpeKwvIVbMQ1KbMBU4h6wxehBI535LK6Flheh8E= +github.com/jackc/pgx/v4 v4.17.2/go.mod h1:lcxIZN44yMIrWI78a5CpucdD14hX0SBDbNRvjDBItsw= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.2.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.2.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas= github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= @@ -123,17 +134,23 @@ github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdh github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= +github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= @@ -159,6 +176,8 @@ golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220214200702-86341886e292 h1:f+lwQ+GtmgoY+A2YaQxlSOnDjXcQ7ZRLWOHbC6HtRqE= golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= @@ -215,13 +234,20 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/datatypes v1.0.6 h1:3cqbakp1DIgC+P7wyODb5k+lSjW8g3mjkg/BIsmhjlE= gorm.io/datatypes v1.0.6/go.mod h1:Gh/Xd/iUWWybMEk8CzYCK/swqlni2r+ROeM1HGIM0ck= +gorm.io/datatypes v1.0.7 h1:8NhJN4+annFjwV1WufDhFiPjdUvV1lSGUdg1UCjQIWY= +gorm.io/datatypes v1.0.7/go.mod h1:l9qkCuy0CdzDEop9HKUdcnC9gHC2sRlaFtHkTzsZRqg= gorm.io/driver/mysql v1.3.2 h1:QJryWiqQ91EvZ0jZL48NOpdlPdMjdip1hQ8bTgo4H7I= gorm.io/driver/mysql v1.3.2/go.mod h1:ChK6AHbHgDCFZyJp0F+BmVGb06PSIoh9uVYKAlRbb2U= gorm.io/driver/postgres v1.3.1/go.mod h1:WwvWOuR9unCLpGWCL6Y3JOeBWvbKi6JLhayiVclSZZU= +gorm.io/driver/postgres v1.3.4/go.mod h1:y0vEuInFKJtijuSGu9e5bs5hzzSzPK+LancpKpvbRBw= gorm.io/driver/postgres v1.3.7 h1:FKF6sIMDHDEvvMF/XJvbnCl0nu6KSKUaPXevJ4r+VYQ= gorm.io/driver/postgres v1.3.7/go.mod h1:f02ympjIcgtHEGFMZvdgTxODZ9snAHDb4hXfigBVuNI= +gorm.io/driver/postgres v1.4.5 h1:mTeXTTtHAgnS9PgmhN2YeUbazYpLhUI1doLnw42XUZc= +gorm.io/driver/postgres v1.4.5/go.mod h1:GKNQYSJ14qvWkvPwXljMGehpKrhlDNsqYRr5HnYGncg= gorm.io/driver/sqlite v1.3.1 h1:bwfE+zTEWklBYoEodIOIBwuWHpnx52Z9zJFW5F33WLk= gorm.io/driver/sqlite v1.3.1/go.mod h1:wJx0hJspfycZ6myN38x1O/AqLtNS6c5o9TndewFbELg= gorm.io/driver/sqlserver v1.3.1 h1:F5t6ScMzOgy1zukRTIZgLZwKahgt3q1woAILVolKpOI= @@ -231,4 +257,8 @@ gorm.io/gorm v1.23.2/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= gorm.io/gorm v1.23.4/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= gorm.io/gorm v1.23.6 h1:KFLdNgri4ExFFGTRGGFWON2P1ZN28+9SJRN8voOoYe0= gorm.io/gorm v1.23.6/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= +gorm.io/gorm v1.24.1-0.20221019064659-5dd2bb482755 h1:7AdrbfcvKnzejfqP5g37fdSZOXH/JvaPIzBIHTOqXKk= +gorm.io/gorm v1.24.1-0.20221019064659-5dd2bb482755/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= +gorm.io/gorm v1.24.2 h1:9wR6CFD+G8nOusLdvkZelOEhpJVwwHzpQOUM+REd6U0= +gorm.io/gorm v1.24.2/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= diff --git a/src/test/golang/pgadapter_gorm_tests/gorm.go b/src/test/golang/pgadapter_gorm_tests/gorm.go index 1ea0fb57c..eceebbcb1 100644 --- a/src/test/golang/pgadapter_gorm_tests/gorm.go +++ b/src/test/golang/pgadapter_gorm_tests/gorm.go @@ -40,15 +40,16 @@ func main() { type User struct { // Prevent gorm from using an auto-generated key. - ID int64 `gorm:"primaryKey;autoIncrement:false"` - Name string - Email *string - Age int64 - Birthday *time.Time - MemberNumber sql.NullString - ActivatedAt sql.NullTime - CreatedAt time.Time - UpdatedAt time.Time + ID int64 `gorm:"primaryKey;autoIncrement:false"` + Name string + Email *string + Age int64 + Birthday *time.Time + MemberNumber sql.NullString + NameAndNumber string `gorm:"->;type:GENERATED ALWAYS AS (coalesce(concat(name,' '::varchar,member_number))) STORED;default:(-);"` + ActivatedAt sql.NullTime + CreatedAt time.Time + UpdatedAt time.Time } type Blog struct { @@ -97,6 +98,9 @@ func TestCreateBlogAndUser(connString string) *C.char { if res.Error != nil { return C.CString(fmt.Sprintf("failed to create User: %v", res.Error)) } + if g, w := user.NameAndNumber, "User Name null"; g != w { + return C.CString(fmt.Sprintf("Name and number mismatch for User\nGot: %v\nWant: %v", g, w)) + } if g, w := res.RowsAffected, int64(1); g != w { return C.CString(fmt.Sprintf("affected row count mismatch for User\nGot: %v\nWant: %v", g, w)) } diff --git a/src/test/golang/pgadapter_pgx_tests/pgx.go b/src/test/golang/pgadapter_pgx_tests/pgx.go index 48e56c609..cee4fcd73 100644 --- a/src/test/golang/pgadapter_pgx_tests/pgx.go +++ b/src/test/golang/pgadapter_pgx_tests/pgx.go @@ -254,6 +254,95 @@ func TestInsertNullsAllDataTypes(connString string) *C.char { return nil } +//export TestInsertAllDataTypesReturning +func TestInsertAllDataTypesReturning(connString string) *C.char { + ctx := context.Background() + conn, err := pgx.Connect(ctx, connString) + if err != nil { + return C.CString(err.Error()) + } + defer conn.Close(ctx) + + sql := "INSERT INTO all_types (col_bigint, col_bool, col_bytea, col_float8, col_int, col_numeric, col_timestamptz, col_date, col_varchar, col_jsonb) " + + "values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) returning *" + numeric := pgtype.Numeric{} + _ = numeric.Set("6.626") + timestamptz, _ := time.Parse(time.RFC3339Nano, "2022-03-24T07:39:10.123456789+01:00") + date := pgtype.Date{} + _ = date.Set("2022-04-02") + var row pgx.Row + if strings.Contains(connString, "prefer_simple_protocol=true") { + // Simple mode will format the date as '2022-04-02 00:00:00Z', which is not supported by the + // backend yet. + row = conn.QueryRow(ctx, sql, 100, true, []byte("test_bytes"), 3.14, 1, numeric, timestamptz, "2022-04-02", "test_string", "{\"key\": \"value\"}") + } else { + row = conn.QueryRow(ctx, sql, 100, true, []byte("test_bytes"), 3.14, 1, numeric, timestamptz, date, "test_string", "{\"key\": \"value\"}") + } + var bigintValue int64 + var boolValue bool + var byteaValue []byte + var float8Value float64 + var intValue int + var numericValue pgtype.Numeric // pgx by default maps numeric to string + var timestamptzValue time.Time + var dateValue time.Time + var varcharValue string + var jsonbValue string + + err = row.Scan( + &bigintValue, + &boolValue, + &byteaValue, + &float8Value, + &intValue, + &numericValue, + ×tamptzValue, + &dateValue, + &varcharValue, + &jsonbValue, + ) + if err != nil { + return C.CString(fmt.Sprintf("Failed to execute insert: %v", err.Error())) + } + if g, w := bigintValue, int64(1); g != w { + return C.CString(fmt.Sprintf("value mismatch\n Got: %v\nWant: %v", g, w)) + } + if g, w := boolValue, true; g != w { + return C.CString(fmt.Sprintf("value mismatch\n Got: %v\nWant: %v", g, w)) + } + if g, w := byteaValue, []byte("test"); !reflect.DeepEqual(g, w) { + return C.CString(fmt.Sprintf("value mismatch\n Got: %v\nWant: %v", g, w)) + } + if g, w := float8Value, 3.14; g != w { + return C.CString(fmt.Sprintf("value mismatch\n Got: %v\nWant: %v", g, w)) + } + if g, w := intValue, 100; g != w { + return C.CString(fmt.Sprintf("value mismatch\n Got: %v\nWant: %v", g, w)) + } + var wantNumericValue pgtype.Numeric + _ = wantNumericValue.Scan("6.626") + if g, w := numericValue, wantNumericValue; !reflect.DeepEqual(g, w) { + return C.CString(fmt.Sprintf("value mismatch\n Got: %v\nWant: %v", g, w)) + } + wantDateValue, _ := time.Parse("2006-01-02", "2022-03-29") + if g, w := dateValue, wantDateValue; !reflect.DeepEqual(g, w) { + return C.CString(fmt.Sprintf("value mismatch\n Got: %v\nWant: %v", g, w)) + } + // Encoding the timestamp values as a parameter will truncate it to microsecond precision. + wantTimestamptzValue, _ := time.Parse(time.RFC3339Nano, "2022-02-16T13:18:02.123456+00:00") + if g, w := timestamptzValue.UTC().String(), wantTimestamptzValue.UTC().String(); g != w { + return C.CString(fmt.Sprintf("value mismatch\n Got: %v\nWant: %v", g, w)) + } + if g, w := varcharValue, "test"; g != w { + return C.CString(fmt.Sprintf("value mismatch\n Got: %v\nWant: %v", g, w)) + } + if g, w := jsonbValue, "{\"key\": \"value\"}"; g != w { + return C.CString(fmt.Sprintf("value mismatch\n Got: %v\nWant: %v", g, w)) + } + + return nil +} + //export TestUpdateAllDataTypes func TestUpdateAllDataTypes(connString string) *C.char { ctx := context.Background() diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/AbortedMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/AbortedMockServerTest.java new file mode 100644 index 000000000..8d18de140 --- /dev/null +++ b/src/test/java/com/google/cloud/spanner/pgadapter/AbortedMockServerTest.java @@ -0,0 +1,1500 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.spanner.pgadapter; + +import static com.google.cloud.spanner.pgadapter.statements.BackendConnection.TRANSACTION_ABORTED_ERROR; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.cloud.ByteArray; +import com.google.cloud.Date; +import com.google.cloud.NoCredentials; +import com.google.cloud.Timestamp; +import com.google.cloud.spanner.DatabaseClient; +import com.google.cloud.spanner.DatabaseId; +import com.google.cloud.spanner.Dialect; +import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime; +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.cloud.spanner.Spanner; +import com.google.cloud.spanner.SpannerOptions; +import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.connection.RandomResultSetGenerator; +import com.google.common.collect.ImmutableList; +import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.ResultSetStats; +import com.google.spanner.v1.TypeCode; +import io.grpc.ManagedChannelBuilder; +import io.grpc.Status; +import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ParameterMetaData; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Types; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; +import java.time.ZoneOffset; +import java.time.temporal.ChronoUnit; +import java.util.Calendar; +import java.util.stream.Collectors; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.postgresql.PGConnection; +import org.postgresql.PGStatement; +import org.postgresql.core.Oid; +import org.postgresql.jdbc.PgStatement; + +@RunWith(JUnit4.class) +public class AbortedMockServerTest extends AbstractMockServerTest { + private static final int RANDOM_RESULTS_ROW_COUNT = 100; + private static final Statement SELECT_RANDOM = Statement.of("select * from random_table"); + private static final ImmutableList JDBC_STARTUP_STATEMENTS = + ImmutableList.of( + "SET extra_float_digits = 3", "SET application_name = 'PostgreSQL JDBC Driver'"); + + @BeforeClass + public static void loadPgJdbcDriver() throws Exception { + // Make sure the PG JDBC driver is loaded. + Class.forName("org.postgresql.Driver"); + + addRandomResultResults(); + mockSpanner.setAbortProbability(0.2); + } + + private static void addRandomResultResults() { + RandomResultSetGenerator generator = + new RandomResultSetGenerator(RANDOM_RESULTS_ROW_COUNT, Dialect.POSTGRESQL); + mockSpanner.putStatementResult(StatementResult.query(SELECT_RANDOM, generator.generate())); + } + + private String getExpectedInitialApplicationName() { + return "PostgreSQL JDBC Driver"; + } + + /** + * Creates a JDBC connection string that instructs the PG JDBC driver to use the default extended + * mode for queries and DML statements. + */ + private String createUrl(String extraOptions) { + return String.format("jdbc:postgresql://localhost:%d/" + extraOptions, pgServer.getLocalPort()); + } + + private Connection createConnection() throws SQLException { + return createConnection(""); + } + + private Connection createConnection(String extraOptions) throws SQLException { + Connection connection = DriverManager.getConnection(createUrl(extraOptions)); + connection.setAutoCommit(false); + return connection; + } + + @Test + public void testQuery() throws SQLException { + String sql = "SELECT 1"; + + try (Connection connection = createConnection()) { + try (ResultSet resultSet = connection.createStatement().executeQuery(sql)) { + assertTrue(resultSet.next()); + assertEquals(1L, resultSet.getLong(1)); + assertFalse(resultSet.next()); + } + } + assertFalse(mockSpanner.getTransactionsStarted().isEmpty()); + } + + @Test + public void testSelectCurrentSchema() throws SQLException { + String sql = "SELECT current_schema"; + + try (Connection connection = createConnection()) { + try (ResultSet resultSet = connection.createStatement().executeQuery(sql)) { + assertTrue(resultSet.next()); + assertEquals("public", resultSet.getString("current_schema")); + assertFalse(resultSet.next()); + } + } + + // The statement is handled locally and not sent to Cloud Spanner. + assertEquals(0, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + } + + @Test + public void testSelectCurrentDatabase() throws SQLException { + for (String sql : + new String[] { + "SELECT current_database()", + "select current_database()", + "select * from CURRENT_DATABASE()" + }) { + + try (Connection connection = createConnection()) { + try (ResultSet resultSet = connection.createStatement().executeQuery(sql)) { + assertTrue(resultSet.next()); + assertEquals("d", resultSet.getString("current_database")); + assertFalse(resultSet.next()); + } + } + } + + // The statement is handled locally and not sent to Cloud Spanner. + assertEquals(0, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + } + + @Test + public void testSelectCurrentCatalog() throws SQLException { + for (String sql : + new String[] { + "SELECT current_catalog", "select current_catalog", "select * from CURRENT_CATALOG" + }) { + + try (Connection connection = createConnection()) { + try (ResultSet resultSet = connection.createStatement().executeQuery(sql)) { + assertTrue(resultSet.next()); + assertEquals("d", resultSet.getString("current_catalog")); + assertFalse(resultSet.next()); + } + } + } + + // The statement is handled locally and not sent to Cloud Spanner. + assertEquals(0, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + } + + @Test + public void testShowSearchPath() throws SQLException { + String sql = "show search_path"; + + try (Connection connection = createConnection()) { + try (ResultSet resultSet = connection.createStatement().executeQuery(sql)) { + assertTrue(resultSet.next()); + assertEquals("public", resultSet.getString("search_path")); + assertFalse(resultSet.next()); + } + } + + // The statement is handled locally and not sent to Cloud Spanner. + assertEquals(0, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + } + + @Test + public void testSetSearchPath() throws SQLException { + String sql = "set search_path to public"; + + try (Connection connection = createConnection()) { + try (java.sql.Statement statement = connection.createStatement()) { + assertFalse(statement.execute(sql)); + assertEquals(0, statement.getUpdateCount()); + assertFalse(statement.getMoreResults()); + } + } + + // The statement is handled locally and not sent to Cloud Spanner. + assertEquals(0, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + } + + @Test + public void testQueryHint() throws SQLException { + String sql = "/* @OPTIMIZER_VERSION=1 */ SELECT 1"; + mockSpanner.putStatementResult(StatementResult.query(Statement.of(sql), SELECT1_RESULTSET)); + + try (Connection connection = createConnection()) { + try (ResultSet resultSet = connection.createStatement().executeQuery(sql)) { + assertTrue(resultSet.next()); + assertEquals(1L, resultSet.getLong(1)); + assertFalse(resultSet.next()); + } + } + } + + @Test + public void testQueryWithParameters() throws SQLException { + String jdbcSql = + "select col_bigint, col_bool, col_bytea, col_float8, col_int, col_numeric, col_timestamptz, col_date, col_varchar, col_jsonb " + + "from all_types " + + "where col_bigint=? " + + "and col_bool=? " + + "and col_bytea=? " + + "and col_int=? " + + "and col_float8=? " + + "and col_numeric=? " + + "and col_timestamptz=? " + + "and col_date=? " + + "and col_varchar=? " + + "and col_jsonb=?"; + String pgSql = + "select col_bigint, col_bool, col_bytea, col_float8, col_int, col_numeric, col_timestamptz, col_date, col_varchar, col_jsonb " + + "from all_types " + + "where col_bigint=$1 " + + "and col_bool=$2 " + + "and col_bytea=$3 " + + "and col_int=$4 " + + "and col_float8=$5 " + + "and col_numeric=$6 " + + "and col_timestamptz=$7 " + + "and col_date=$8 " + + "and col_varchar=$9 " + + "and col_jsonb=$10"; + mockSpanner.putStatementResult(StatementResult.query(Statement.of(pgSql), ALL_TYPES_RESULTSET)); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.newBuilder(pgSql) + .bind("p1") + .to(1L) + .bind("p2") + .to(true) + .bind("p3") + .to(ByteArray.copyFrom("test")) + .bind("p4") + .to(100) + .bind("p5") + .to(3.14d) + .bind("p6") + .to(com.google.cloud.spanner.Value.pgNumeric("6.626")) + .bind("p7") + .to(Timestamp.parseTimestamp("2022-02-16T13:18:02.123457000Z")) + .bind("p8") + .to(Date.parseDate("2022-03-29")) + .bind("p9") + .to("test") + .bind("p10") + .to("{\"key\": \"value\"}") + .build(), + ALL_TYPES_RESULTSET)); + + OffsetDateTime offsetDateTime = + LocalDateTime.of(2022, 2, 16, 13, 18, 2, 123456789).atOffset(ZoneOffset.UTC); + OffsetDateTime truncatedOffsetDateTime = offsetDateTime.truncatedTo(ChronoUnit.MICROS); + + // Threshold 5 is the default. Use a named prepared statement if it is executed 5 times or more. + // Threshold 1 means always use a named prepared statement. + // Threshold 0 means never use a named prepared statement. + // Threshold -1 means use binary transfer of values and use DESCRIBE statement. + // (10 points to you if you guessed the last one up front!). + for (int preparedThreshold : new int[] {5, 1, 0, -1}) { + try (Connection connection = createConnection()) { + try (PreparedStatement preparedStatement = connection.prepareStatement(jdbcSql)) { + preparedStatement.unwrap(PgStatement.class).setPrepareThreshold(preparedThreshold); + int index = 0; + preparedStatement.setLong(++index, 1L); + preparedStatement.setBoolean(++index, true); + preparedStatement.setBytes(++index, "test".getBytes(StandardCharsets.UTF_8)); + preparedStatement.setInt(++index, 100); + preparedStatement.setDouble(++index, 3.14d); + preparedStatement.setBigDecimal(++index, new BigDecimal("6.626")); + preparedStatement.setObject(++index, offsetDateTime); + preparedStatement.setObject(++index, LocalDate.of(2022, 3, 29)); + preparedStatement.setString(++index, "test"); + preparedStatement.setString(++index, "{\"key\": \"value\"}"); + try (ResultSet resultSet = preparedStatement.executeQuery()) { + assertTrue(resultSet.next()); + index = 0; + assertEquals(1L, resultSet.getLong(++index)); + assertTrue(resultSet.getBoolean(++index)); + assertArrayEquals("test".getBytes(StandardCharsets.UTF_8), resultSet.getBytes(++index)); + assertEquals(3.14d, resultSet.getDouble(++index), 0.0d); + assertEquals(100, resultSet.getInt(++index)); + assertEquals(new BigDecimal("6.626"), resultSet.getBigDecimal(++index)); + if (preparedThreshold < 0) { + // The binary format will truncate the timestamp value to microseconds. + assertEquals( + truncatedOffsetDateTime, resultSet.getObject(++index, OffsetDateTime.class)); + } else { + assertEquals(offsetDateTime, resultSet.getObject(++index, OffsetDateTime.class)); + } + assertEquals(LocalDate.of(2022, 3, 29), resultSet.getObject(++index, LocalDate.class)); + assertEquals("test", resultSet.getString(++index)); + assertEquals("{\"key\": \"value\"}", resultSet.getString(++index)); + assertFalse(resultSet.next()); + } + } + } + } + } + + @Test + public void testQueryWithLegacyDateParameter() throws SQLException { + String jdbcSql = "select col_date from all_types where col_date=?"; + String pgSql = "select col_date from all_types where col_date=$1"; + mockSpanner.putStatementResult(StatementResult.query(Statement.of(pgSql), ALL_TYPES_RESULTSET)); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.newBuilder(pgSql).bind("p1").to(Date.parseDate("2022-03-29")).build(), + ALL_TYPES_RESULTSET)); + + // Threshold 5 is the default. Use a named prepared statement if it is executed 5 times or more. + // Threshold 1 means always use a named prepared statement. + // Threshold 0 means never use a named prepared statement. + // Threshold -1 means use binary transfer of values and use DESCRIBE statement. + // (10 points to you if you guessed the last one up front!). + for (int preparedThreshold : new int[] {5, 1, 0, -1}) { + try (Connection connection = createConnection()) { + try (PreparedStatement preparedStatement = connection.prepareStatement(jdbcSql)) { + preparedStatement.unwrap(PgStatement.class).setPrepareThreshold(preparedThreshold); + int index = 0; + preparedStatement.setDate(++index, new java.sql.Date(2022 - 1900, Calendar.MARCH, 29)); + try (ResultSet resultSet = preparedStatement.executeQuery()) { + assertTrue(resultSet.next()); + assertEquals( + new java.sql.Date(2022 - 1900, Calendar.MARCH, 29), resultSet.getDate("col_date")); + assertFalse(resultSet.next()); + } + connection.commit(); + } + } + } + } + + @Test + public void testMultipleQueriesInTransaction() throws SQLException { + String sql = "SELECT 1"; + + try (Connection connection = createConnection()) { + // Force the use of prepared statements. + connection.unwrap(PGConnection.class).setPrepareThreshold(-1); + for (int i = 0; i < 2; i++) { + // https://github.com/GoogleCloudPlatform/pgadapter/issues/278 + // This would return `ERROR: FAILED_PRECONDITION: This ResultSet is closed` + try (ResultSet resultSet = connection.createStatement().executeQuery(sql)) { + assertTrue(resultSet.next()); + assertEquals(1L, resultSet.getLong(1)); + assertFalse(resultSet.next()); + } + } + } + } + + @Test + public void testQueryWithNonExistingTable() throws SQLException { + String sql = "select * from non_existing_table where id=?"; + String pgSql = "select * from non_existing_table where id=$1"; + mockSpanner.putStatementResult( + StatementResult.exception( + Statement.newBuilder(pgSql).bind("p1").to(1L).build(), + Status.NOT_FOUND + .withDescription("Table non_existing_table not found") + .asRuntimeException())); + try (Connection connection = createConnection()) { + try (PreparedStatement preparedStatement = connection.prepareStatement(sql)) { + preparedStatement.setLong(1, 1L); + SQLException exception = assertThrows(SQLException.class, preparedStatement::executeQuery); + assertEquals( + "ERROR: Table non_existing_table not found - Statement: 'select * from non_existing_table where id=$1'", + exception.getMessage()); + } + } + } + + @Test + public void testDmlWithNonExistingTable() throws SQLException { + String sql = "update non_existing_table set value=? where id=?"; + String pgSql = "update non_existing_table set value=$1 where id=$2"; + mockSpanner.putStatementResult( + StatementResult.exception( + Statement.newBuilder(pgSql).bind("p1").to("foo").bind("p2").to(1L).build(), + Status.NOT_FOUND + .withDescription("Table non_existing_table not found") + .asRuntimeException())); + try (Connection connection = createConnection()) { + try (PreparedStatement preparedStatement = connection.prepareStatement(sql)) { + preparedStatement.setString(1, "foo"); + preparedStatement.setLong(2, 1L); + SQLException exception = assertThrows(SQLException.class, preparedStatement::executeUpdate); + assertEquals("ERROR: Table non_existing_table not found", exception.getMessage()); + } + } + } + + @Test + public void testNullValues() throws SQLException { + mockSpanner.putStatementResult( + StatementResult.update( + Statement.newBuilder( + "insert into all_types " + + "(col_bigint, col_bool, col_bytea, col_float8, col_int, col_numeric, col_timestamptz, col_date, col_varchar) " + + "values ($1, $2, $3, $4, $5, $6, $7, $8, $9)") + .bind("p1") + .to(2L) + .bind("p2") + .to((Boolean) null) + .bind("p3") + .to((ByteArray) null) + .bind("p4") + .to((Double) null) + .bind("p5") + .to((Long) null) + .bind("p6") + .to(com.google.cloud.spanner.Value.pgNumeric(null)) + .bind("p7") + .to((com.google.cloud.spanner.Value) null) + .bind("p8") + .to((Date) null) + .bind("p9") + .to((String) null) + .build(), + 1L)); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of("select * from all_types where col_bigint is null"), + ALL_TYPES_NULLS_RESULTSET)); + + try (Connection connection = createConnection()) { + try (PreparedStatement statement = + connection.prepareStatement( + "insert into all_types " + + "(col_bigint, col_bool, col_bytea, col_float8, col_int, col_numeric, col_timestamptz, col_date, col_varchar) " + + "values (?, ?, ?, ?, ?, ?, ?, ?, ?)")) { + int index = 0; + statement.setLong(++index, 2); + statement.setNull(++index, Types.BOOLEAN); + statement.setNull(++index, Types.BINARY); + statement.setNull(++index, Types.DOUBLE); + statement.setNull(++index, Types.INTEGER); + statement.setNull(++index, Types.NUMERIC); + statement.setNull(++index, Types.TIMESTAMP_WITH_TIMEZONE); + statement.setNull(++index, Types.DATE); + statement.setNull(++index, Types.VARCHAR); + + assertEquals(1, statement.executeUpdate()); + } + + try (ResultSet resultSet = + connection + .createStatement() + .executeQuery("select * from all_types where col_bigint is null")) { + assertTrue(resultSet.next()); + + int index = 0; + // Note: JDBC returns the zero-value for primitive types if the value is NULL, and you have + // to call wasNull() to determine whether the value was NULL or zero. + assertEquals(0L, resultSet.getLong(++index)); + assertTrue(resultSet.wasNull()); + assertFalse(resultSet.getBoolean(++index)); + assertTrue(resultSet.wasNull()); + assertNull(resultSet.getBytes(++index)); + assertTrue(resultSet.wasNull()); + assertEquals(0d, resultSet.getDouble(++index), 0.0d); + assertTrue(resultSet.wasNull()); + assertNull(resultSet.getBigDecimal(++index)); + assertTrue(resultSet.wasNull()); + assertNull(resultSet.getTimestamp(++index)); + assertTrue(resultSet.wasNull()); + assertNull(resultSet.getDate(++index)); + assertTrue(resultSet.wasNull()); + assertNull(resultSet.getString(++index)); + assertTrue(resultSet.wasNull()); + + assertFalse(resultSet.next()); + } + } + } + + @Test + public void testDescribeQueryWithNonExistingTable() throws SQLException { + String sql = "select * from non_existing_table where id=$1"; + mockSpanner.putStatementResult( + StatementResult.exception( + Statement.of(sql), + Status.NOT_FOUND + .withDescription("Table non_existing_table not found") + .asRuntimeException())); + try (Connection connection = createConnection()) { + try (PreparedStatement preparedStatement = connection.prepareStatement(sql)) { + SQLException exception = + assertThrows(SQLException.class, preparedStatement::getParameterMetaData); + assertEquals( + "ERROR: Table non_existing_table not found - Statement: 'select * from non_existing_table where id=$1'", + exception.getMessage()); + } + } + } + + @Test + public void testDescribeDmlWithNonExistingTable() throws SQLException { + String sql = "update non_existing_table set value=$2 where id=$1"; + mockSpanner.putStatementResult( + StatementResult.exception( + Statement.of(sql), + Status.NOT_FOUND + .withDescription("Table non_existing_table not found") + .asRuntimeException())); + String describeSql = + "select $1, $2 from (select value=$2 from non_existing_table where id=$1) p"; + mockSpanner.putStatementResult( + StatementResult.exception( + Statement.of(describeSql), + Status.NOT_FOUND + .withDescription("Table non_existing_table not found") + .asRuntimeException())); + try (Connection connection = createConnection()) { + try (PreparedStatement preparedStatement = connection.prepareStatement(sql)) { + SQLException exception = + assertThrows(SQLException.class, preparedStatement::getParameterMetaData); + assertEquals("ERROR: Table non_existing_table not found", exception.getMessage()); + } + } + } + + @Test + public void testDescribeDmlWithSchemaPrefix() throws SQLException { + String sql = "update public.my_table set value=? where id=?"; + String describeSql = "update public.my_table set value=$1 where id=$2"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(describeSql), + com.google.spanner.v1.ResultSet.newBuilder() + .setMetadata( + createParameterTypesMetadata(ImmutableList.of(TypeCode.STRING, TypeCode.INT64))) + .setStats(ResultSetStats.newBuilder().build()) + .build())); + try (Connection connection = createConnection()) { + try (PreparedStatement preparedStatement = connection.prepareStatement(sql)) { + ParameterMetaData metadata = preparedStatement.getParameterMetaData(); + assertEquals(Types.VARCHAR, metadata.getParameterType(1)); + assertEquals(Types.BIGINT, metadata.getParameterType(2)); + } + } + } + + @Test + public void testDescribeDmlWithQuotedSchemaPrefix() throws SQLException { + String sql = "update \"public\".\"my_table\" set value=? where id=?"; + String describeSql = "update \"public\".\"my_table\" set value=$1 where id=$2"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(describeSql), + com.google.spanner.v1.ResultSet.newBuilder() + .setMetadata( + createParameterTypesMetadata(ImmutableList.of(TypeCode.STRING, TypeCode.INT64))) + .setStats(ResultSetStats.newBuilder().build()) + .build())); + try (Connection connection = createConnection()) { + try (PreparedStatement preparedStatement = connection.prepareStatement(sql)) { + ParameterMetaData metadata = preparedStatement.getParameterMetaData(); + assertEquals(Types.VARCHAR, metadata.getParameterType(1)); + assertEquals(Types.BIGINT, metadata.getParameterType(2)); + } + } + } + + @Test + public void testTwoDmlStatements() throws SQLException { + try (Connection connection = createConnection()) { + try (java.sql.Statement statement = connection.createStatement()) { + // The PG JDBC driver will internally split the following SQL string into two statements and + // execute these sequentially. We still get the results back as if they were executed as one + // batch on the same statement. + assertFalse(statement.execute(String.format("%s;%s;", INSERT_STATEMENT, UPDATE_STATEMENT))); + + // Note that we have sent two DML statements to the database in one string. These should be + // treated as separate statements, and there should therefore be two results coming back + // from the server. That is; The first update count should be 1 (the INSERT), and the second + // should be 2 (the UPDATE). + assertEquals(1, statement.getUpdateCount()); + + // The following is a prime example of how not to design an API, but this is how JDBC works. + // getMoreResults() returns true if the next result is a ResultSet. However, if the next + // result is an update count, it returns false, and we have to check getUpdateCount() to + // verify whether there were any more results. + assertFalse(statement.getMoreResults()); + assertEquals(2, statement.getUpdateCount()); + + // There are no more results. This is indicated by getMoreResults returning false AND + // getUpdateCount returning -1. + assertFalse(statement.getMoreResults()); + assertEquals(-1, statement.getUpdateCount()); + } + } + } + + @Test + public void testTwoDmlStatements_withError() throws SQLException { + try (Connection connection = createConnection()) { + try (java.sql.Statement statement = connection.createStatement()) { + SQLException exception = + assertThrows( + SQLException.class, + () -> statement.execute(String.format("%s;%s;", INSERT_STATEMENT, INVALID_DML))); + assertEquals("ERROR: Statement is invalid.", exception.getMessage()); + } + } + } + + @Test + public void testJdbcBatch() throws SQLException { + try (Connection connection = createConnection()) { + try (java.sql.Statement statement = connection.createStatement()) { + statement.addBatch(INSERT_STATEMENT.getSql()); + statement.addBatch(UPDATE_STATEMENT.getSql()); + int[] updateCounts = statement.executeBatch(); + + assertEquals(2, updateCounts.length); + assertEquals(1, updateCounts[0]); + assertEquals(2, updateCounts[1]); + } + } + } + + @Test + public void testTwoQueries() throws SQLException { + try (Connection connection = createConnection()) { + try (java.sql.Statement statement = connection.createStatement()) { + // Statement#execute(String) returns true if the result is a result set. + assertTrue(statement.execute("SELECT 1;SELECT 2;")); + + try (ResultSet resultSet = statement.getResultSet()) { + assertTrue(resultSet.next()); + assertEquals(1L, resultSet.getLong(1)); + assertFalse(resultSet.next()); + } + + // getMoreResults() returns true if the next result is a ResultSet. + assertTrue(statement.getMoreResults()); + try (ResultSet resultSet = statement.getResultSet()) { + assertTrue(resultSet.next()); + assertEquals(2L, resultSet.getLong(1)); + assertFalse(resultSet.next()); + } + + // getMoreResults() should now return false. We should also check getUpdateCount() as that + // method should return -1 to indicate that there is also no update count available. + assertFalse(statement.getMoreResults()); + assertEquals(-1, statement.getUpdateCount()); + } + } + } + + @Test + public void testPreparedStatement() throws SQLException { + mockSpanner.putStatementResult( + StatementResult.update( + Statement.newBuilder( + "insert into all_types " + + "(col_bigint, col_bool, col_bytea, col_float8, col_int, col_numeric, col_timestamptz, col_date, col_varchar) " + + "values ($1, $2, $3, $4, $5, $6, $7, $8, $9)") + .bind("p1") + .to(2L) + .bind("p2") + .to((Boolean) null) + .bind("p3") + .to((ByteArray) null) + .bind("p4") + .to((Double) null) + .bind("p5") + .to((Long) null) + .bind("p6") + .to(com.google.cloud.spanner.Value.pgNumeric(null)) + .bind("p7") + .to((Timestamp) null) + .bind("p8") + .to((Date) null) + .bind("p9") + .to((String) null) + .build(), + 1L)); + mockSpanner.putStatementResult( + StatementResult.update( + Statement.newBuilder( + "insert into all_types " + + "(col_bigint, col_bool, col_bytea, col_float8, col_int, col_numeric, col_timestamptz, col_date, col_varchar) " + + "values ($1, $2, $3, $4, $5, $6, $7, $8, $9)") + .bind("p1") + .to(1L) + .bind("p2") + .to(true) + .bind("p3") + .to(ByteArray.copyFrom("test")) + .bind("p4") + .to(3.14d) + .bind("p5") + .to(100L) + .bind("p6") + .to(com.google.cloud.spanner.Value.pgNumeric("6.626")) + .bind("p7") + .to(Timestamp.parseTimestamp("2022-02-16T13:18:02.123457000")) + .bind("p8") + .to(Date.parseDate("2022-03-29")) + .bind("p9") + .to("test") + .build(), + 1L)); + + OffsetDateTime zonedDateTime = + LocalDateTime.of(2022, 2, 16, 13, 18, 2, 123456789).atOffset(ZoneOffset.UTC); + try (Connection connection = createConnection()) { + try (PreparedStatement statement = + connection.prepareStatement( + "insert into all_types " + + "(col_bigint, col_bool, col_bytea, col_float8, col_int, col_numeric, col_timestamptz, col_date, col_varchar) " + + "values (?, ?, ?, ?, ?, ?, ?, ?, ?)")) { + PGStatement pgStatement = statement.unwrap(PGStatement.class); + pgStatement.setPrepareThreshold(1); + + int index = 0; + statement.setLong(++index, 1L); + statement.setBoolean(++index, true); + statement.setBytes(++index, "test".getBytes(StandardCharsets.UTF_8)); + statement.setDouble(++index, 3.14d); + statement.setInt(++index, 100); + statement.setBigDecimal(++index, new BigDecimal("6.626")); + statement.setObject(++index, zonedDateTime); + statement.setObject(++index, LocalDate.of(2022, 3, 29)); + statement.setString(++index, "test"); + + assertEquals(1, statement.executeUpdate()); + + index = 0; + statement.setLong(++index, 2); + statement.setNull(++index, Types.BOOLEAN); + statement.setNull(++index, Types.BINARY); + statement.setNull(++index, Types.DOUBLE); + statement.setNull(++index, Types.INTEGER); + statement.setNull(++index, Types.NUMERIC); + statement.setNull(++index, Types.TIMESTAMP_WITH_TIMEZONE); + statement.setNull(++index, Types.DATE); + statement.setNull(++index, Types.VARCHAR); + + assertEquals(1, statement.executeUpdate()); + } + } + } + + @Test + public void testCursorSuccess() throws SQLException { + try (Connection connection = createConnection()) { + connection.setAutoCommit(false); + try (PreparedStatement statement = connection.prepareStatement(SELECT_FIVE_ROWS.getSql())) { + // Fetch two rows at a time from the PG server. + statement.setFetchSize(2); + try (ResultSet resultSet = statement.executeQuery()) { + int index = 0; + while (resultSet.next()) { + assertEquals(++index, resultSet.getInt(1)); + } + assertEquals(5, index); + } + } + connection.commit(); + } + } + + @Test + public void testCursorFailsHalfway() throws SQLException { + mockSpanner.setExecuteStreamingSqlExecutionTime( + SimulatedExecutionTime.ofStreamException(Status.DATA_LOSS.asRuntimeException(), 2)); + try (Connection connection = createConnection()) { + connection.setAutoCommit(false); + try (PreparedStatement statement = connection.prepareStatement(SELECT_FIVE_ROWS.getSql())) { + // Fetch one row at a time from the PG server. + statement.setFetchSize(1); + try (ResultSet resultSet = statement.executeQuery()) { + // The first row should succeed. + assertTrue(resultSet.next()); + // The second row should fail. + assertThrows(SQLException.class, resultSet::next); + } + } + connection.rollback(); + } + } + + @Test + public void testRandomResults() throws SQLException { + // TODO: Enable binary transfer for this test when binary transfer of date values prior + // to the Julian/Gregorian switch has been fixed. + for (boolean binary : new boolean[] {false}) { + // Also get the random results using the normal Spanner client to compare the results with + // what is returned by PGAdapter. + Spanner spanner = + SpannerOptions.newBuilder() + .setProjectId("p") + .setHost(String.format("http://localhost:%d", spannerServer.getPort())) + .setCredentials(NoCredentials.getInstance()) + .setChannelConfigurator(ManagedChannelBuilder::usePlaintext) + .setClientLibToken("pg-adapter") + .build() + .getService(); + DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of("p", "i", "d")); + com.google.cloud.spanner.ResultSet spannerResult = + client.singleUse().executeQuery(SELECT_RANDOM); + + String binaryTransferEnable = + "?binaryTransferEnable=" + + ImmutableList.of( + Oid.BOOL, + Oid.BYTEA, + Oid.VARCHAR, + Oid.NUMERIC, + Oid.FLOAT8, + Oid.INT8, + Oid.DATE, + Oid.TIMESTAMPTZ) + .stream() + .map(String::valueOf) + .collect(Collectors.joining(",")); + + final int fetchSize = 3; + try (Connection connection = createConnection(binaryTransferEnable)) { + connection.createStatement().execute("set time zone utc"); + connection.setAutoCommit(false); + connection.unwrap(PGConnection.class).setPrepareThreshold(binary ? -1 : 5); + try (PreparedStatement statement = connection.prepareStatement(SELECT_RANDOM.getSql())) { + statement.setFetchSize(fetchSize); + try (ResultSet resultSet = statement.executeQuery()) { + int rowCount = 0; + while (resultSet.next()) { + assertTrue(spannerResult.next()); + for (int col = 0; col < resultSet.getMetaData().getColumnCount(); col++) { + // TODO: Remove once we have a replacement for pg_type, as the JDBC driver will try + // to read type information from the backend when it hits an 'unknown' type (jsonb + // is not one of the types that the JDBC driver will load automatically). + if (col == 5 || col == 14) { + resultSet.getString(col + 1); + } else { + resultSet.getObject(col + 1); + } + } + assertEqual(spannerResult, resultSet, binary); + rowCount++; + } + assertEquals(RANDOM_RESULTS_ROW_COUNT, rowCount); + } + } + connection.commit(); + } + + // Close the resources used by the normal Spanner client. + spannerResult.close(); + spanner.close(); + } + } + + private void assertEqual( + com.google.cloud.spanner.ResultSet spannerResult, ResultSet pgResult, boolean binary) + throws SQLException { + assertEquals(spannerResult.getColumnCount(), pgResult.getMetaData().getColumnCount()); + for (int col = 0; col < spannerResult.getColumnCount(); col++) { + if (spannerResult.isNull(col)) { + assertNull(pgResult.getObject(col + 1)); + assertTrue(pgResult.wasNull()); + continue; + } + + switch (spannerResult.getColumnType(col).getCode()) { + case BOOL: + if (!binary) { + // Skip for binary for now, as there is a bug in the PG JDBC driver for decoding binary + // bool values. + assertEquals(spannerResult.getBoolean(col), pgResult.getBoolean(col + 1)); + } + break; + case INT64: + assertEquals(spannerResult.getLong(col), pgResult.getLong(col + 1)); + break; + case FLOAT64: + assertEquals(spannerResult.getDouble(col), pgResult.getDouble(col + 1), 0.0d); + break; + case PG_NUMERIC: + case STRING: + assertEquals(spannerResult.getString(col), pgResult.getString(col + 1)); + break; + case BYTES: + assertArrayEquals(spannerResult.getBytes(col).toByteArray(), pgResult.getBytes(col + 1)); + break; + case TIMESTAMP: + // Compare milliseconds, as PostgreSQL does not natively support nanosecond precision, and + // this is lost when using binary encoding. + assertEquals( + spannerResult.getTimestamp(col).toSqlTimestamp().getTime(), + pgResult.getTimestamp(col + 1).getTime()); + break; + case DATE: + assertEquals( + LocalDate.of( + spannerResult.getDate(col).getYear(), + spannerResult.getDate(col).getMonth(), + spannerResult.getDate(col).getDayOfMonth()), + pgResult.getDate(col + 1).toLocalDate()); + break; + case PG_JSONB: + assertEquals(spannerResult.getPgJsonb(col), pgResult.getString(col + 1)); + break; + case ARRAY: + break; + case NUMERIC: + case JSON: + case STRUCT: + fail("unsupported PG type: " + spannerResult.getColumnType(col)); + } + } + } + + @Test + public void testShowValidSetting() throws SQLException { + try (Connection connection = createConnection()) { + try (ResultSet resultSet = + connection.createStatement().executeQuery("show application_name")) { + assertTrue(resultSet.next()); + assertEquals(getExpectedInitialApplicationName(), resultSet.getString(1)); + assertFalse(resultSet.next()); + } + } + } + + @Test + public void testShowSettingWithStartupValue() throws SQLException { + try (Connection connection = createConnection()) { + // DATESTYLE is set to 'ISO' by the JDBC driver at startup. + try (ResultSet resultSet = connection.createStatement().executeQuery("show DATESTYLE")) { + assertTrue(resultSet.next()); + assertEquals("ISO", resultSet.getString(1)); + assertFalse(resultSet.next()); + } + } + } + + @Test + public void testShowInvalidSetting() throws SQLException { + try (Connection connection = createConnection()) { + SQLException exception = + assertThrows( + SQLException.class, + () -> connection.createStatement().executeQuery("show random_setting")); + assertEquals( + "ERROR: unrecognized configuration parameter \"random_setting\"", exception.getMessage()); + } + } + + @Test + public void testSetValidSetting() throws SQLException { + try (Connection connection = createConnection()) { + connection.createStatement().execute("set application_name to 'my-application'"); + + try (ResultSet resultSet = + connection.createStatement().executeQuery("show application_name ")) { + assertTrue(resultSet.next()); + assertEquals("my-application", resultSet.getString(1)); + assertFalse(resultSet.next()); + } + } + } + + @Test + public void testSetCaseInsensitiveSetting() throws SQLException { + try (Connection connection = createConnection()) { + // The setting is called 'DateStyle' in the pg_settings table. + connection.createStatement().execute("set datestyle to 'iso'"); + + try (ResultSet resultSet = connection.createStatement().executeQuery("show DATESTYLE")) { + assertTrue(resultSet.next()); + assertEquals("iso", resultSet.getString(1)); + assertFalse(resultSet.next()); + } + } + } + + @Test + public void testSetInvalidSetting() throws SQLException { + try (Connection connection = createConnection()) { + SQLException exception = + assertThrows( + SQLException.class, + () -> + connection.createStatement().executeQuery("set random_setting to 'some-value'")); + assertEquals( + "ERROR: unrecognized configuration parameter \"random_setting\"", exception.getMessage()); + } + } + + @Test + public void testResetValidSetting() throws SQLException { + try (Connection connection = createConnection()) { + connection.createStatement().execute("set application_name to 'my-application'"); + + try (ResultSet resultSet = + connection.createStatement().executeQuery("show application_name ")) { + assertTrue(resultSet.next()); + assertEquals("my-application", resultSet.getString(1)); + assertFalse(resultSet.next()); + } + + connection.createStatement().execute("reset application_name"); + + try (ResultSet resultSet = + connection.createStatement().executeQuery("show application_name ")) { + assertTrue(resultSet.next()); + assertNull(resultSet.getString(1)); + assertFalse(resultSet.next()); + } + } + } + + @Test + public void testResetSettingWithStartupValue() throws SQLException { + try (Connection connection = createConnection()) { + try (ResultSet resultSet = connection.createStatement().executeQuery("show datestyle")) { + assertTrue(resultSet.next()); + assertEquals("ISO", resultSet.getString(1)); + assertFalse(resultSet.next()); + } + + connection.createStatement().execute("set datestyle to 'iso, ymd'"); + + try (ResultSet resultSet = connection.createStatement().executeQuery("show datestyle")) { + assertTrue(resultSet.next()); + assertEquals("iso, ymd", resultSet.getString(1)); + assertFalse(resultSet.next()); + } + + connection.createStatement().execute("reset datestyle"); + + try (ResultSet resultSet = connection.createStatement().executeQuery("show datestyle")) { + assertTrue(resultSet.next()); + assertEquals("ISO", resultSet.getString(1)); + assertFalse(resultSet.next()); + } + } + } + + @Test + public void testResetInvalidSetting() throws SQLException { + try (Connection connection = createConnection()) { + SQLException exception = + assertThrows( + SQLException.class, + () -> connection.createStatement().executeQuery("reset random_setting")); + assertEquals( + "ERROR: unrecognized configuration parameter \"random_setting\"", exception.getMessage()); + } + } + + @Test + public void testShowUndefinedExtensionSetting() throws SQLException { + try (Connection connection = createConnection()) { + SQLException exception = + assertThrows( + SQLException.class, + () -> connection.createStatement().executeQuery("show spanner.some_setting")); + assertEquals( + "ERROR: unrecognized configuration parameter \"spanner.some_setting\"", + exception.getMessage()); + } + } + + @Test + public void testSetExtensionSetting() throws SQLException { + try (Connection connection = createConnection()) { + connection.createStatement().execute("set spanner.some_setting to 'some-value'"); + + try (ResultSet resultSet = + connection.createStatement().executeQuery("show spanner.some_setting ")) { + assertTrue(resultSet.next()); + assertEquals("some-value", resultSet.getString(1)); + assertFalse(resultSet.next()); + } + } + } + + @Test + public void testResetValidExtensionSetting() throws SQLException { + try (Connection connection = createConnection()) { + connection.createStatement().execute("set spanner.some_setting to 'some-value'"); + + try (ResultSet resultSet = + connection.createStatement().executeQuery("show spanner.some_setting")) { + assertTrue(resultSet.next()); + assertEquals("some-value", resultSet.getString(1)); + assertFalse(resultSet.next()); + } + + connection.createStatement().execute("reset spanner.some_setting"); + + try (ResultSet resultSet = + connection.createStatement().executeQuery("show spanner.some_setting")) { + assertTrue(resultSet.next()); + assertNull(resultSet.getString(1)); + assertFalse(resultSet.next()); + } + } + } + + @Test + public void testResetUndefinedExtensionSetting() throws SQLException { + try (Connection connection = createConnection()) { + // Resetting an undefined extension setting is allowed by PostgreSQL, and will effectively set + // the extension setting to null. + connection.createStatement().execute("reset spanner.some_setting"); + + verifySettingIsNull(connection, "spanner.some_setting"); + } + } + + @Test + public void testCommitSet() throws SQLException { + try (Connection connection = createConnection()) { + connection.setAutoCommit(false); + + // Verify that the initial value is 'PostgreSQL JDBC Driver'. + verifySettingValue(connection, "application_name", getExpectedInitialApplicationName()); + connection.createStatement().execute("set application_name to \"my-application\""); + verifySettingValue(connection, "application_name", "my-application"); + // Committing the transaction should persist the value. + connection.commit(); + verifySettingValue(connection, "application_name", "my-application"); + } + } + + @Test + public void testRollbackSet() throws SQLException { + try (Connection connection = createConnection()) { + connection.setAutoCommit(false); + + // Verify that the initial value is null. + verifySettingValue(connection, "application_name", getExpectedInitialApplicationName()); + connection.createStatement().execute("set application_name to \"my-application\""); + verifySettingValue(connection, "application_name", "my-application"); + // Rolling back the transaction should reset the value to what it was before the transaction. + connection.rollback(); + verifySettingValue(connection, "application_name", getExpectedInitialApplicationName()); + } + } + + @Test + public void testCommitSetExtension() throws SQLException { + try (Connection connection = createConnection()) { + connection.setAutoCommit(false); + + connection.createStatement().execute("set spanner.random_setting to \"42\""); + verifySettingValue(connection, "spanner.random_setting", "42"); + // Committing the transaction should persist the value. + connection.commit(); + verifySettingValue(connection, "spanner.random_setting", "42"); + } + } + + @Test + public void testRollbackSetExtension() throws SQLException { + try (Connection connection = createConnection()) { + connection.setAutoCommit(false); + + connection.createStatement().execute("set spanner.random_setting to \"42\""); + verifySettingValue(connection, "spanner.random_setting", "42"); + // Rolling back the transaction should reset the value to what it was before the transaction. + // In this case, that means that it should be undefined. + connection.rollback(); + verifySettingIsUnrecognized(connection, "spanner.random_setting"); + } + } + + @Test + public void testRollbackDefinedExtension() throws SQLException { + try (Connection connection = createConnection()) { + // First define the extension setting. + connection.createStatement().execute("set spanner.random_setting to '100'"); + connection.commit(); + + connection.createStatement().execute("set spanner.random_setting to \"42\""); + verifySettingValue(connection, "spanner.random_setting", "42"); + // Rolling back the transaction should reset the value to what it was before the transaction. + // In this case, that means back to '100'. + connection.rollback(); + verifySettingValue(connection, "spanner.random_setting", "100"); + } + } + + @Test + public void testCommitSetLocal() throws SQLException { + try (Connection connection = createConnection()) { + connection.setAutoCommit(false); + + // Verify that the initial value is 'PostgreSQL JDBC Driver'. + verifySettingValue(connection, "application_name", getExpectedInitialApplicationName()); + connection.createStatement().execute("set local application_name to \"my-application\""); + verifySettingValue(connection, "application_name", "my-application"); + // Committing the transaction should not persist the value as it was only set for the current + // transaction. + connection.commit(); + verifySettingValue(connection, "application_name", getExpectedInitialApplicationName()); + } + } + + @Test + public void testCommitSetLocalAndSession() throws SQLException { + try (Connection connection = createConnection()) { + connection.setAutoCommit(false); + + // Verify that the initial value is 'PostgreSQL JDBC Driver'. + verifySettingValue(connection, "application_name", getExpectedInitialApplicationName()); + // Set both a session and a local value. The session value will be 'hidden' by the local + // value, but the session value will be committed. + connection + .createStatement() + .execute("set session application_name to \"my-session-application\""); + verifySettingValue(connection, "application_name", "my-session-application"); + connection + .createStatement() + .execute("set local application_name to \"my-local-application\""); + verifySettingValue(connection, "application_name", "my-local-application"); + // Committing the transaction should persist the session value. + connection.commit(); + verifySettingValue(connection, "application_name", "my-session-application"); + } + } + + @Test + public void testCommitSetLocalAndSessionExtension() throws SQLException { + try (Connection connection = createConnection()) { + // Verify that the initial value is undefined. + verifySettingIsUnrecognized(connection, "spanner.custom_setting"); + connection.rollback(); + + // Set both a session and a local value. The session value will be 'hidden' by the local + // value, but the session value will be committed. + connection.createStatement().execute("set spanner.custom_setting to 'session-value'"); + verifySettingValue(connection, "spanner.custom_setting", "session-value"); + connection.createStatement().execute("set local spanner.custom_setting to 'local-value'"); + verifySettingValue(connection, "spanner.custom_setting", "local-value"); + // Committing the transaction should persist the session value. + connection.commit(); + verifySettingValue(connection, "spanner.custom_setting", "session-value"); + } + } + + @Test + public void testInvalidShowAbortsTransaction() throws SQLException { + try (Connection connection = createConnection()) { + connection.setAutoCommit(false); + + // Verify that executing an invalid show statement will abort the transaction. + assertThrows( + SQLException.class, + () -> connection.createStatement().execute("show spanner.non_existing_param")); + SQLException exception = + assertThrows( + SQLException.class, + () -> connection.createStatement().execute("show application_name ")); + assertEquals("ERROR: " + TRANSACTION_ABORTED_ERROR, exception.getMessage()); + + connection.rollback(); + + // Verify that the connection is usable again. + verifySettingValue(connection, "application_name", getExpectedInitialApplicationName()); + } + } + + @Test + public void testShowAll() throws SQLException { + try (Connection connection = createConnection()) { + try (ResultSet resultSet = connection.createStatement().executeQuery("show all")) { + assertEquals(3, resultSet.getMetaData().getColumnCount()); + assertEquals("name", resultSet.getMetaData().getColumnName(1)); + assertEquals("setting", resultSet.getMetaData().getColumnName(2)); + assertEquals("description", resultSet.getMetaData().getColumnName(3)); + int count = 0; + while (resultSet.next()) { + if ("client_encoding".equals(resultSet.getString("name"))) { + assertEquals("UTF8", resultSet.getString("setting")); + } + count++; + } + assertEquals(358, count); + } + } + } + + @Test + public void testResetAll() throws SQLException { + try (Connection connection = createConnection()) { + connection.createStatement().execute("set application_name to 'my-app'"); + connection.createStatement().execute("set search_path to 'my_schema'"); + verifySettingValue(connection, "application_name", "my-app"); + verifySettingValue(connection, "search_path", "my_schema"); + + connection.createStatement().execute("reset all"); + + verifySettingIsNull(connection, "application_name"); + verifySettingValue(connection, "search_path", "public"); + } + } + + @Test + public void testSetToDefault() throws SQLException { + try (Connection connection = createConnection()) { + connection.createStatement().execute("set application_name to 'my-app'"); + connection.createStatement().execute("set search_path to 'my_schema'"); + verifySettingValue(connection, "application_name", "my-app"); + verifySettingValue(connection, "search_path", "my_schema"); + + connection.createStatement().execute("set application_name to default"); + connection.createStatement().execute("set search_path to default"); + + verifySettingIsNull(connection, "application_name"); + verifySettingValue(connection, "search_path", "public"); + } + } + + @Test + public void testSetToEmpty() throws SQLException { + try (Connection connection = createConnection()) { + connection.createStatement().execute("set application_name to ''"); + verifySettingValue(connection, "application_name", ""); + } + } + + @Test + public void testSetTimeZone() throws SQLException { + try (Connection connection = createConnection()) { + connection.createStatement().execute("set time zone 'IST'"); + verifySettingValue(connection, "timezone", "Asia/Kolkata"); + } + } + + @Test + public void testSetTimeZoneToDefault() throws SQLException { + try (Connection connection = createConnection("?options=-c%%20timezone=IST")) { + connection.createStatement().execute("set time zone default"); + verifySettingValue(connection, "timezone", "Asia/Kolkata"); + } + } + + @Test + public void testSetTimeZoneToLocal() throws SQLException { + try (Connection connection = createConnection("?options=-c%%20timezone=IST")) { + connection.createStatement().execute("set time zone local"); + verifySettingValue(connection, "timezone", "Asia/Kolkata"); + } + } + + @Test + public void testSetTimeZoneWithTransactionCommit() throws SQLException { + try (Connection connection = createConnection()) { + connection.setAutoCommit(false); + connection.createStatement().execute("set time zone 'UTC'"); + verifySettingValue(connection, "timezone", "UTC"); + connection.commit(); + verifySettingValue(connection, "timezone", "UTC"); + } + } + + @Test + public void testSetTimeZoneWithTransactionRollback() throws SQLException { + try (Connection connection = createConnection()) { + String originalTimeZone = null; + try (ResultSet rs = connection.createStatement().executeQuery("SHOW TIMEZONE")) { + assertTrue(rs.next()); + originalTimeZone = rs.getString(1); + assertFalse(rs.next()); + } + assertNotNull(originalTimeZone); + connection.setAutoCommit(false); + + connection.createStatement().execute("set time zone 'UTC'"); + verifySettingValue(connection, "time zone", "UTC"); + connection.rollback(); + verifySettingValue(connection, "time zone", originalTimeZone); + } + } + + @Test + public void testSettingsAreUniqueToConnections() throws SQLException { + // Verify that each new connection gets a separate set of settings. + for (int connectionNum = 0; connectionNum < 5; connectionNum++) { + try (Connection connection = createConnection()) { + // Verify that the initial value is 'PostgreSQL JDBC Driver'. + verifySettingValue(connection, "application_name", getExpectedInitialApplicationName()); + connection.createStatement().execute("set application_name to \"my-application\""); + verifySettingValue(connection, "application_name", "my-application"); + } + } + } + + @Test + public void testSettingInConnectionOptions() throws SQLException { + try (Connection connection = + createConnection( + "?options=-c%%20spanner.ddl_transaction_mode=AutocommitExplicitTransaction")) { + verifySettingValue( + connection, "spanner.ddl_transaction_mode", "AutocommitExplicitTransaction"); + } + } + + @Test + public void testMultipleSettingsInConnectionOptions() throws SQLException { + try (Connection connection = + createConnection( + "?options=-c%%20spanner.setting1=value1%%20-c%%20spanner.setting2=value2")) { + verifySettingValue(connection, "spanner.setting1", "value1"); + verifySettingValue(connection, "spanner.setting2", "value2"); + } + } + + @Test + public void testServerVersionInConnectionOptions() throws SQLException { + try (Connection connection = createConnection("?options=-c%%20server_version=4.1")) { + verifySettingValue(connection, "server_version", "4.1"); + verifySettingValue(connection, "server_version_num", "40001"); + } + } + + @Test + public void testCustomServerVersionInConnectionOptions() throws SQLException { + try (Connection connection = + createConnection("?options=-c%%20server_version=5.2 custom version")) { + verifySettingValue(connection, "server_version", "5.2 custom version"); + verifySettingValue(connection, "server_version_num", "50002"); + } + } + + private void verifySettingIsNull(Connection connection, String setting) throws SQLException { + try (ResultSet resultSet = + connection.createStatement().executeQuery(String.format("show %s", setting))) { + assertTrue(resultSet.next()); + assertNull(resultSet.getString(1)); + assertFalse(resultSet.next()); + } + } + + private void verifySettingValue(Connection connection, String setting, String value) + throws SQLException { + try (ResultSet resultSet = + connection.createStatement().executeQuery(String.format("show %s", setting))) { + assertTrue(resultSet.next()); + assertEquals(value, resultSet.getString(1)); + assertFalse(resultSet.next()); + } + } + + private void verifySettingIsUnrecognized(Connection connection, String setting) { + SQLException exception = + assertThrows( + SQLException.class, + () -> connection.createStatement().execute(String.format("show %s", setting))); + assertEquals( + String.format("ERROR: unrecognized configuration parameter \"%s\"", setting), + exception.getMessage()); + } +} diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/AbstractMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/AbstractMockServerTest.java index 785d00389..7b6180323 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/AbstractMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/AbstractMockServerTest.java @@ -312,6 +312,22 @@ protected static ResultSetMetadata createMetadata(ImmutableList types) return ResultSetMetadata.newBuilder().setRowType(builder.build()).build(); } + protected static ResultSetMetadata createParameterTypesMetadata(ImmutableList types) { + StructType.Builder builder = StructType.newBuilder(); + for (int index = 0; index < types.size(); index++) { + builder.addFields( + Field.newBuilder() + .setType( + Type.newBuilder() + .setCode(types.get(index)) + .setTypeAnnotation(getTypeAnnotationCode(types.get(index))) + .build()) + .setName("p" + (index + 1)) + .build()); + } + return ResultSetMetadata.newBuilder().setUndeclaredParameters(builder.build()).build(); + } + protected static ResultSetMetadata createMetadata( ImmutableList types, ImmutableList names) { Preconditions.checkArgument( diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/CopyInMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/CopyInMockServerTest.java index 8b704d56e..433aa7e71 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/CopyInMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/CopyInMockServerTest.java @@ -30,6 +30,7 @@ import com.google.cloud.spanner.MockSpannerServiceImpl; import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime; import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.cloud.spanner.pgadapter.error.SQLState; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; import com.google.common.base.Strings; import com.google.common.hash.Hashing; @@ -765,7 +766,8 @@ public void testCopyIn_QueryDuringCopy() + "XX000\n" + "Expected CopyData ('d'), CopyDone ('c') or CopyFail ('f') messages, got: 'Q'\n" + "ERROR\n" - + "P0001\n" + + SQLState.QueryCanceled + + "\n" + "Error\n", errorMessage.toString()); @@ -1152,7 +1154,8 @@ public void testCopyInBatchWithCopyFail() throws Exception { } } assertTrue(receivedErrorMessage); - assertEquals("ERROR\n" + "P0001\n" + "Changed my mind\n", errorMessage.toString()); + assertEquals( + "ERROR\n" + SQLState.QueryCanceled + "\n" + "Changed my mind\n", errorMessage.toString()); } assertEquals(0, mockSpanner.countRequestsOfType(CommitRequest.class)); } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/CopyOutMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/CopyOutMockServerTest.java index d965cb805..73954c357 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/CopyOutMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/CopyOutMockServerTest.java @@ -26,6 +26,7 @@ import com.google.cloud.Timestamp; import com.google.cloud.spanner.Dialect; import com.google.cloud.spanner.MockSpannerServiceImpl; +import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime; import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; import com.google.cloud.spanner.Statement; import com.google.cloud.spanner.Type; @@ -207,6 +208,7 @@ public void testCopyOut() throws SQLException, IOException { StatementResult.query(Statement.of("select * from all_types"), ALL_TYPES_RESULTSET)); try (Connection connection = DriverManager.getConnection(createUrl())) { + connection.createStatement().execute("set time zone 'UTC'"); CopyManager copyManager = new CopyManager(connection.unwrap(BaseConnection.class)); StringWriter writer = new StringWriter(); copyManager.copyOut("COPY all_types TO STDOUT", writer); @@ -239,6 +241,7 @@ public void testCopyOutWithColumns() throws SQLException, IOException { Statement.of("select col_bigint, col_varchar from all_types"), ALL_TYPES_RESULTSET)); try (Connection connection = DriverManager.getConnection(createUrl())) { + connection.createStatement().execute("set time zone 'UTC'"); CopyManager copyManager = new CopyManager(connection.unwrap(BaseConnection.class)); StringWriter writer = new StringWriter(); copyManager.copyOut("COPY all_types (col_bigint, col_varchar) TO STDOUT", writer); @@ -270,6 +273,7 @@ public void testCopyOutCsv() throws SQLException, IOException { StatementResult.query(Statement.of("select * from all_types"), ALL_TYPES_RESULTSET)); try (Connection connection = DriverManager.getConnection(createUrl())) { + connection.createStatement().execute("set time zone 'UTC'"); CopyManager copyManager = new CopyManager(connection.unwrap(BaseConnection.class)); StringWriter writer = new StringWriter(); copyManager.copyOut( @@ -287,6 +291,7 @@ public void testCopyOutCsvWithHeader() throws SQLException, IOException { StatementResult.query(Statement.of("select * from all_types"), ALL_TYPES_RESULTSET)); try (Connection connection = DriverManager.getConnection(createUrl())) { + connection.createStatement().execute("set time zone 'UTC'"); CopyManager copyManager = new CopyManager(connection.unwrap(BaseConnection.class)); StringWriter writer = new StringWriter(); copyManager.copyOut( @@ -306,6 +311,7 @@ public void testCopyOutCsvWithColumnsAndHeader() throws SQLException, IOExceptio Statement.of("select col_bigint from all_types"), ALL_TYPES_RESULTSET)); try (Connection connection = DriverManager.getConnection(createUrl())) { + connection.createStatement().execute("set time zone 'UTC'"); CopyManager copyManager = new CopyManager(connection.unwrap(BaseConnection.class)); StringWriter writer = new StringWriter(); copyManager.copyOut( @@ -326,6 +332,7 @@ public void testCopyOutCsvWithQueryAndHeader() throws SQLException, IOException Statement.of("select * from all_types order by col_bigint"), ALL_TYPES_RESULTSET)); try (Connection connection = DriverManager.getConnection(createUrl())) { + connection.createStatement().execute("set time zone 'UTC'"); CopyManager copyManager = new CopyManager(connection.unwrap(BaseConnection.class)); StringWriter writer = new StringWriter(); copyManager.copyOut( @@ -345,6 +352,7 @@ public void testCopyOutCsvWithQuote() throws SQLException, IOException { StatementResult.query(Statement.of("select * from all_types"), ALL_TYPES_RESULTSET)); try (Connection connection = DriverManager.getConnection(createUrl())) { + connection.createStatement().execute("set time zone 'UTC'"); CopyManager copyManager = new CopyManager(connection.unwrap(BaseConnection.class)); StringWriter writer = new StringWriter(); copyManager.copyOut( @@ -362,6 +370,7 @@ public void testCopyOutCsvWithForceQuoteAll() throws SQLException, IOException { StatementResult.query(Statement.of("select * from all_types"), ALL_TYPES_RESULTSET)); try (Connection connection = DriverManager.getConnection(createUrl())) { + connection.createStatement().execute("set time zone 'UTC'"); CopyManager copyManager = new CopyManager(connection.unwrap(BaseConnection.class)); StringWriter writer = new StringWriter(); copyManager.copyOut( @@ -549,6 +558,26 @@ public void testCopyOutPartitioned() throws SQLException, IOException { } } + @Test + public void testCopyOutPartitioned_NonExistingTable() throws SQLException { + StatusRuntimeException exception = + Status.NOT_FOUND.withDescription("Table my_table not found").asRuntimeException(); + mockSpanner.putStatementResult( + StatementResult.exception(Statement.of("select * from my_table"), exception)); + mockSpanner.setPartitionQueryExecutionTime(SimulatedExecutionTime.ofException(exception)); + + try (Connection connection = DriverManager.getConnection(createUrl())) { + CopyManager copyManager = new CopyManager(connection.unwrap(BaseConnection.class)); + StringWriter writer = new StringWriter(); + SQLException sqlException = + assertThrows( + SQLException.class, () -> copyManager.copyOut("COPY my_table TO STDOUT", writer)); + assertEquals( + "ERROR: Table my_table not found - Statement: 'select * from my_table'", + sqlException.getMessage()); + } + } + static int findIndex(com.google.spanner.v1.ResultSet resultSet, String[] cols) { for (int index = 0; index < resultSet.getRowsCount(); index++) { boolean nullValuesEqual = true; diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/ErrorHandlingTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/ErrorHandlingTest.java index 0bde76bbc..3a00c9834 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/ErrorHandlingTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/ErrorHandlingTest.java @@ -26,6 +26,7 @@ import io.grpc.Status; import java.sql.Connection; import java.sql.DriverManager; +import java.sql.PreparedStatement; import java.sql.SQLException; import org.junit.BeforeClass; import org.junit.Test; @@ -33,6 +34,7 @@ import org.junit.runners.Parameterized; import org.junit.runners.Parameterized.Parameter; import org.junit.runners.Parameterized.Parameters; +import org.postgresql.PGConnection; @RunWith(Parameterized.class) public class ErrorHandlingTest extends AbstractMockServerTest { @@ -55,7 +57,10 @@ public static void loadPgJdbcDriver() throws Exception { public static void setupErrorResults() { mockSpanner.putStatementResult( StatementResult.exception( - Statement.of(INVALID_SELECT), Status.NOT_FOUND.asRuntimeException())); + Statement.of(INVALID_SELECT), + Status.NOT_FOUND + .withDescription("Table unknown_table not found") + .asRuntimeException())); } private String createUrl() { @@ -70,7 +75,9 @@ public void testInvalidQueryNoTransaction() throws SQLException { SQLException exception = assertThrows( SQLException.class, () -> connection.createStatement().executeQuery(INVALID_SELECT)); - assertTrue(exception.getMessage(), exception.getMessage().contains("NOT_FOUND")); + assertEquals( + "ERROR: Table unknown_table not found - Statement: 'SELECT * FROM unknown_table'", + exception.getMessage()); // The connection should be usable, as there was no transaction. assertTrue(connection.createStatement().execute("SELECT 1")); @@ -85,7 +92,9 @@ public void testInvalidQueryInTransaction() throws SQLException { SQLException exception = assertThrows( SQLException.class, () -> connection.createStatement().executeQuery(INVALID_SELECT)); - assertTrue(exception.getMessage(), exception.getMessage().contains("NOT_FOUND")); + assertEquals( + "ERROR: Table unknown_table not found - Statement: 'SELECT * FROM unknown_table'", + exception.getMessage()); // The connection should be in the aborted state. exception = @@ -108,7 +117,9 @@ public void testCommitAbortedTransaction() throws SQLException { SQLException exception = assertThrows( SQLException.class, () -> connection.createStatement().executeQuery(INVALID_SELECT)); - assertTrue(exception.getMessage(), exception.getMessage().contains("NOT_FOUND")); + assertEquals( + "ERROR: Table unknown_table not found - Statement: 'SELECT * FROM unknown_table'", + exception.getMessage()); // The connection should be in the aborted state. exception = @@ -126,4 +137,22 @@ public void testCommitAbortedTransaction() throws SQLException { assertTrue(mockSpanner.countRequestsOfType(RollbackRequest.class) > 0); assertEquals(0, mockSpanner.countRequestsOfType(CommitRequest.class)); } + + @Test + public void testInvalidPreparedQuery() throws SQLException { + for (int prepareThreshold : new int[] {-1, 0, 1, 5}) { + try (Connection connection = DriverManager.getConnection(createUrl())) { + PGConnection pgConnection = connection.unwrap(PGConnection.class); + pgConnection.setPrepareThreshold(prepareThreshold); + try (PreparedStatement preparedStatement = connection.prepareStatement(INVALID_SELECT)) { + SQLException exception = + assertThrows(SQLException.class, preparedStatement::executeQuery); + assertEquals( + "Prepare threshold: " + prepareThreshold, + "ERROR: Table unknown_table not found - Statement: 'SELECT * FROM unknown_table'", + exception.getMessage()); + } + } + } + } } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/ITJdbcDescribeStatementTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/ITJdbcDescribeStatementTest.java index 5745f7440..a347bf26a 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/ITJdbcDescribeStatementTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/ITJdbcDescribeStatementTest.java @@ -39,6 +39,8 @@ import java.sql.SQLException; import java.sql.Types; import java.util.Collections; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; @@ -189,7 +191,7 @@ public void testParameterMetaData() throws SQLException { + "and col_date=? " + "and col_varchar=? " + "and col_jsonb=?", - "update all_types set col_bigint=?, " + "update all_types set " + "col_bool=?, " + "col_bytea=?, " + "col_float8=?, " @@ -199,7 +201,7 @@ public void testParameterMetaData() throws SQLException { + "col_date=?, " + "col_varchar=?, " + "col_jsonb=?", - "update all_types set col_bigint=null, " + "update all_types set " + "col_bool=null, " + "col_bytea=null, " + "col_float8=null, " @@ -235,22 +237,28 @@ public void testParameterMetaData() throws SQLException { try (Connection connection = DriverManager.getConnection(getConnectionUrl())) { try (PreparedStatement statement = connection.prepareStatement(sql)) { ParameterMetaData metadata = statement.getParameterMetaData(); - assertEquals(10, metadata.getParameterCount()); + if (sql.startsWith("update all_types set col_bool=?,")) { + assertEquals(sql, 9, metadata.getParameterCount()); + } else { + assertEquals(sql, 10, metadata.getParameterCount()); + } for (int index = 1; index <= metadata.getParameterCount(); index++) { assertEquals(ParameterMetaData.parameterModeIn, metadata.getParameterMode(index)); assertEquals(ParameterMetaData.parameterNullableUnknown, metadata.isNullable(index)); } int index = 0; + if (metadata.getParameterCount() == 10) { + assertEquals(sql, Types.BIGINT, metadata.getParameterType(++index)); + } + assertEquals(sql, Types.BIT, metadata.getParameterType(++index)); + assertEquals(sql, Types.BINARY, metadata.getParameterType(++index)); + assertEquals(sql, Types.DOUBLE, metadata.getParameterType(++index)); assertEquals(sql, Types.BIGINT, metadata.getParameterType(++index)); - assertEquals(Types.BIT, metadata.getParameterType(++index)); - assertEquals(Types.BINARY, metadata.getParameterType(++index)); - assertEquals(Types.DOUBLE, metadata.getParameterType(++index)); - assertEquals(Types.BIGINT, metadata.getParameterType(++index)); - assertEquals(Types.NUMERIC, metadata.getParameterType(++index)); - assertEquals(Types.TIMESTAMP, metadata.getParameterType(++index)); - assertEquals(Types.DATE, metadata.getParameterType(++index)); - assertEquals(Types.VARCHAR, metadata.getParameterType(++index)); - assertEquals(Types.VARCHAR, metadata.getParameterType(++index)); + assertEquals(sql, Types.NUMERIC, metadata.getParameterType(++index)); + assertEquals(sql, Types.TIMESTAMP, metadata.getParameterType(++index)); + assertEquals(sql, Types.DATE, metadata.getParameterType(++index)); + assertEquals(sql, Types.VARCHAR, metadata.getParameterType(++index)); + assertEquals(sql, Types.VARCHAR, metadata.getParameterType(++index)); } } } @@ -288,6 +296,24 @@ public void testParameterMetaDataInLimit() throws SQLException { } } + @Test + public void testMoreThan50Parameters() throws SQLException { + String sql = + "select * from all_types where " + + IntStream.range(0, 51) + .mapToObj(i -> "col_varchar=?") + .collect(Collectors.joining(" or ")); + try (Connection connection = DriverManager.getConnection(getConnectionUrl())) { + try (PreparedStatement statement = connection.prepareStatement(sql)) { + ParameterMetaData metadata = statement.getParameterMetaData(); + assertEquals(51, metadata.getParameterCount()); + for (int i = 1; i < metadata.getParameterCount(); i++) { + assertEquals(Types.VARCHAR, metadata.getParameterType(i)); + } + } + } + } + @Test public void testDescribeInvalidStatements() throws SQLException { for (String sql : diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/ITPsqlTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/ITPsqlTest.java index 36da2afd6..56f25f525 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/ITPsqlTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/ITPsqlTest.java @@ -340,6 +340,7 @@ public void testPrepareExecuteDeallocate() throws IOException, InterruptedExcept Tuple result = runUsingPsql( ImmutableList.of( + "set time zone 'UTC';\n", "prepare insert_row as " + "insert into all_types values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10);\n", "prepare find_row as " @@ -365,7 +366,8 @@ public void testPrepareExecuteDeallocate() throws IOException, InterruptedExcept String output = result.x(), errors = result.y(); assertEquals("", errors); assertEquals( - "PREPARE\n" + "SET\n" + + "PREPARE\n" + "PREPARE\n" + " col_bigint | col_bool | col_bytea | col_float8 | col_int | col_numeric | col_timestamptz | col_date | col_varchar | col_jsonb \n" + "------------+----------+-----------+------------+---------+-------------+-----------------+----------+-------------+-----------\n" diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/JdbcMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/JdbcMockServerTest.java index 37c20e545..5bb930788 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/JdbcMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/JdbcMockServerTest.java @@ -23,6 +23,7 @@ import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.junit.Assume.assumeTrue; import com.google.cloud.ByteArray; import com.google.cloud.Date; @@ -42,6 +43,7 @@ import com.google.cloud.spanner.pgadapter.wireprotocol.ExecuteMessage; import com.google.cloud.spanner.pgadapter.wireprotocol.ParseMessage; import com.google.common.collect.ImmutableList; +import com.google.protobuf.ListValue; import com.google.protobuf.Value; import com.google.spanner.admin.database.v1.UpdateDatabaseDdlRequest; import com.google.spanner.v1.BeginTransactionRequest; @@ -49,6 +51,11 @@ import com.google.spanner.v1.ExecuteBatchDmlRequest; import com.google.spanner.v1.ExecuteSqlRequest; import com.google.spanner.v1.ExecuteSqlRequest.QueryMode; +import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.ResultSetStats; +import com.google.spanner.v1.RollbackRequest; +import com.google.spanner.v1.StructType; +import com.google.spanner.v1.StructType.Field; import com.google.spanner.v1.Type; import com.google.spanner.v1.TypeAnnotationCode; import com.google.spanner.v1.TypeCode; @@ -61,6 +68,7 @@ import java.sql.ParameterMetaData; import java.sql.PreparedStatement; import java.sql.ResultSet; +import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Types; import java.time.LocalDate; @@ -68,10 +76,7 @@ import java.time.OffsetDateTime; import java.time.ZoneOffset; import java.time.temporal.ChronoUnit; -import java.util.Base64; -import java.util.Calendar; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.stream.Collectors; import java.util.stream.IntStream; import org.junit.BeforeClass; @@ -156,6 +161,55 @@ public void testQuery() throws SQLException { } } + @Test + public void testPreparedStatementParameterMetadata() throws SQLException { + String sql = "SELECT * FROM foo WHERE id=? or value=?"; + String pgSql = "SELECT * FROM foo WHERE id=$1 or value=$2"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(pgSql), + com.google.spanner.v1.ResultSet.newBuilder() + .setMetadata( + ResultSetMetadata.newBuilder() + .setRowType( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("col1") + .setType(Type.newBuilder().setCode(TypeCode.INT64).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("col2") + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .build()) + .build()) + .setUndeclaredParameters( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("p1") + .setType(Type.newBuilder().setCode(TypeCode.INT64).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p2") + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .build()) + .build()) + .build()) + .build())); + + try (Connection connection = DriverManager.getConnection(createUrl())) { + try (PreparedStatement preparedStatement = connection.prepareStatement(sql)) { + ParameterMetaData parameters = preparedStatement.getParameterMetaData(); + assertEquals(2, parameters.getParameterCount()); + assertEquals(Types.BIGINT, parameters.getParameterType(1)); + assertEquals(Types.VARCHAR, parameters.getParameterType(2)); + } + } + } + @Test public void testInvalidQuery() throws SQLException { String sql = "/ not a valid comment / SELECT 1"; @@ -171,6 +225,49 @@ public void testInvalidQuery() throws SQLException { assertEquals(0, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); } + @Test + public void testClientSideStatementWithResultSet() throws SQLException { + String sql = "show statement_timeout"; + + try (Connection connection = DriverManager.getConnection(createUrl())) { + try (ResultSet resultSet = connection.createStatement().executeQuery(sql)) { + assertTrue(resultSet.next()); + assertEquals("0", resultSet.getString("statement_timeout")); + assertFalse(resultSet.next()); + } + connection.createStatement().execute("set statement_timeout=6000"); + try (ResultSet resultSet = connection.createStatement().executeQuery(sql)) { + assertTrue(resultSet.next()); + assertEquals("6s", resultSet.getString("statement_timeout")); + assertFalse(resultSet.next()); + } + } + + // The statement is handled locally and not sent to Cloud Spanner. + assertEquals(0, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + } + + @Test + public void testClientSideStatementWithoutResultSet() throws SQLException { + try (Connection connection = DriverManager.getConnection(createUrl())) { + try (java.sql.Statement statement = connection.createStatement()) { + statement.execute("start batch dml"); + statement.execute(INSERT_STATEMENT.getSql()); + statement.execute(UPDATE_STATEMENT.getSql()); + statement.execute("run batch"); + } + } + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class)); + ExecuteBatchDmlRequest request = + mockSpanner.getRequestsOfType(ExecuteBatchDmlRequest.class).get(0); + assertEquals(2, request.getStatementsCount()); + assertEquals(INSERT_STATEMENT.getSql(), request.getStatements(0).getSql()); + assertEquals(UPDATE_STATEMENT.getSql(), request.getStatements(1).getSql()); + assertTrue(request.getTransaction().hasBegin()); + assertTrue(request.getTransaction().getBegin().hasReadWrite()); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + } + @Test public void testSelectCurrentSchema() throws SQLException { String sql = "SELECT current_schema"; @@ -229,6 +326,31 @@ public void testSelectCurrentCatalog() throws SQLException { assertEquals(0, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); } + @Test + public void testSelectVersion() throws SQLException { + for (String sql : + new String[] {"SELECT version()", "select version()", "select * from version()"}) { + + try (Connection connection = DriverManager.getConnection(createUrl())) { + String version = null; + try (ResultSet resultSet = + connection.createStatement().executeQuery("show server_version")) { + assertTrue(resultSet.next()); + version = resultSet.getString(1); + assertFalse(resultSet.next()); + } + try (ResultSet resultSet = connection.createStatement().executeQuery(sql)) { + assertTrue(resultSet.next()); + assertEquals("PostgreSQL " + version, resultSet.getString("version")); + assertFalse(resultSet.next()); + } + } + } + + // The statement is handled locally and not sent to Cloud Spanner. + assertEquals(0, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + } + @Test public void testShowSearchPath() throws SQLException { String sql = "show search_path"; @@ -442,7 +564,21 @@ public void testQueryWithParameters() throws SQLException { public void testQueryWithLegacyDateParameter() throws SQLException { String jdbcSql = "select col_date from all_types where col_date=?"; String pgSql = "select col_date from all_types where col_date=$1"; - mockSpanner.putStatementResult(StatementResult.query(Statement.of(pgSql), ALL_TYPES_RESULTSET)); + ResultSetMetadata metadata = + ALL_TYPES_METADATA + .toBuilder() + .setUndeclaredParameters( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("p1") + .setType(Type.newBuilder().setCode(TypeCode.DATE).build()) + .build()) + .build()) + .build(); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(pgSql), ALL_TYPES_RESULTSET.toBuilder().setMetadata(metadata).build())); mockSpanner.putStatementResult( StatementResult.query( Statement.newBuilder(pgSql).bind("p1").to(Date.parseDate("2022-03-29")).build(), @@ -470,42 +606,77 @@ public void testQueryWithLegacyDateParameter() throws SQLException { List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); // Prepare threshold less than 0 means use binary transfer + DESCRIBE statement. // However, the legacy date type will never use BINARY transfer and will always be sent with - // unspecified type by the JDBC driver the first time. This means that we need 3 round trips - // for a query that uses a prepared statement the first time. - int expectedRequestCount; - switch (preparedThreshold) { - case -1: - case 1: - expectedRequestCount = 3; - break; - default: - expectedRequestCount = 2; - break; - } + // unspecified type by the JDBC driver the first time. This means that we need 2 round trips + // in all cases, as the statement will either use an explicit DESCRIBE message, or it will + // be auto-described by PGAdapter. + int expectedRequestCount = 2; assertEquals( "Prepare threshold: " + preparedThreshold, expectedRequestCount, requests.size()); - ExecuteSqlRequest executeRequest; - if (preparedThreshold == 1) { - // The order of statements here is a little strange. The execution of the statement is - // executed first, and the describe statements are then executed afterwards. The reason - // for this is that JDBC does the following when it encounters a statement parameter that - // is 'unknown' (it considers the legacy date type as unknown, as it does not know if the - // user means date, timestamp or timestamptz): - // 1. It sends a DescribeStatement message, but without a flush or a sync, as it is not - // planning on using the information for this request. - // 2. It then sends the Execute message followed by a sync. This causes PGAdapter to sync - // the backend connection and execute everything in the actual execution pipeline. - // 3. PGAdapter then executes anything left in the message queue. The DescribeMessage is - // still there, and is therefore executed after the Execute message. - // All the above still works as intended, as the responses are sent in the expected order. - executeRequest = requests.get(0); - for (int i = 1; i < requests.size(); i++) { - assertEquals(QueryMode.PLAN, requests.get(i).getQueryMode()); + ExecuteSqlRequest executeRequest = requests.get(requests.size() - 1); + assertEquals(QueryMode.NORMAL, executeRequest.getQueryMode()); + assertEquals(pgSql, executeRequest.getSql()); + + Map params = executeRequest.getParams().getFieldsMap(); + Map types = executeRequest.getParamTypesMap(); + + assertEquals(TypeCode.DATE, types.get("p1").getCode()); + assertEquals("2022-03-29", params.get("p1").getStringValue()); + + mockSpanner.clearRequests(); + } + } + } + + @Test + public void testAutoDescribedStatementsAreReused() throws SQLException { + String jdbcSql = "select col_date from all_types where col_date=?"; + String pgSql = "select col_date from all_types where col_date=$1"; + ResultSetMetadata metadata = + ALL_TYPES_METADATA + .toBuilder() + .setUndeclaredParameters( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("p1") + .setType(Type.newBuilder().setCode(TypeCode.DATE).build()) + .build()) + .build()) + .build(); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(pgSql), ALL_TYPES_RESULTSET.toBuilder().setMetadata(metadata).build())); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.newBuilder(pgSql).bind("p1").to(Date.parseDate("2022-03-29")).build(), + ALL_TYPES_RESULTSET)); + + try (Connection connection = DriverManager.getConnection(createUrl())) { + for (int attempt : new int[] {1, 2}) { + try (PreparedStatement preparedStatement = connection.prepareStatement(jdbcSql)) { + // Threshold 0 means never use a named prepared statement. + preparedStatement.unwrap(PgStatement.class).setPrepareThreshold(0); + preparedStatement.setDate(1, new java.sql.Date(2022 - 1900, Calendar.MARCH, 29)); + try (ResultSet resultSet = preparedStatement.executeQuery()) { + assertTrue(resultSet.next()); + assertEquals( + new java.sql.Date(2022 - 1900, Calendar.MARCH, 29), resultSet.getDate("col_date")); + assertFalse(resultSet.next()); } + } + + // The first time we execute this statement the number of requests should be 2, as the + // statement is auto-described by the backend. The second time we execute the statement the + // backend should reuse the result from the first auto-describe roundtrip. + List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); + if (attempt == 1) { + assertEquals(2, requests.size()); } else { - executeRequest = requests.get(requests.size() - 1); + assertEquals(1, requests.size()); } + + ExecuteSqlRequest executeRequest = requests.get(requests.size() - 1); assertEquals(QueryMode.NORMAL, executeRequest.getQueryMode()); assertEquals(pgSql, executeRequest.getSql()); @@ -520,6 +691,43 @@ public void testQueryWithLegacyDateParameter() throws SQLException { } } + @Test + public void testDescribeDdlStatement() throws SQLException { + try (Connection connection = DriverManager.getConnection(createUrl())) { + try (PreparedStatement preparedStatement = + connection.prepareStatement("create table foo (id bigint primary key, value varchar)")) { + ParameterMetaData parameterMetaData = preparedStatement.getParameterMetaData(); + assertEquals(0, parameterMetaData.getParameterCount()); + assertNull(preparedStatement.getMetaData()); + } + } + } + + @Test + public void testDescribeClientSideNoResultStatement() throws SQLException { + try (Connection connection = DriverManager.getConnection(createUrl())) { + try (PreparedStatement preparedStatement = connection.prepareStatement("start batch dml")) { + ParameterMetaData parameterMetaData = preparedStatement.getParameterMetaData(); + assertEquals(0, parameterMetaData.getParameterCount()); + assertNull(preparedStatement.getMetaData()); + } + } + } + + @Test + public void testDescribeClientSideResultSetStatement() throws SQLException { + try (Connection connection = DriverManager.getConnection(createUrl())) { + try (PreparedStatement preparedStatement = + connection.prepareStatement("show statement_timeout")) { + SQLException exception = + assertThrows(SQLException.class, preparedStatement::getParameterMetaData); + assertEquals( + "ERROR: ResultSetMetadata are available only for results that were returned from Cloud Spanner", + exception.getMessage()); + } + } + } + @Test public void testMultipleQueriesInTransaction() throws SQLException { String sql = "SELECT 1"; @@ -656,12 +864,13 @@ public void testDmlWithNonExistingTable() throws SQLException { @Test public void testNullValues() throws SQLException { + String pgSql = + "insert into all_types " + + "(col_bigint, col_bool, col_bytea, col_float8, col_int, col_numeric, col_timestamptz, col_date, col_varchar) " + + "values ($1, $2, $3, $4, $5, $6, $7, $8, $9)"; mockSpanner.putStatementResult( StatementResult.update( - Statement.newBuilder( - "insert into all_types " - + "(col_bigint, col_bool, col_bytea, col_float8, col_int, col_numeric, col_timestamptz, col_date, col_varchar) " - + "values ($1, $2, $3, $4, $5, $6, $7, $8, $9)") + Statement.newBuilder(pgSql) .bind("p1") .to(2L) .bind("p2") @@ -776,14 +985,6 @@ public void testDescribeDmlWithNonExistingTable() throws SQLException { Status.NOT_FOUND .withDescription("Table non_existing_table not found") .asRuntimeException())); - String describeSql = - "select $1, $2 from (select value=$2 from non_existing_table where id=$1) p"; - mockSpanner.putStatementResult( - StatementResult.exception( - Statement.of(describeSql), - Status.NOT_FOUND - .withDescription("Table non_existing_table not found") - .asRuntimeException())); try (Connection connection = DriverManager.getConnection(createUrl())) { try (PreparedStatement preparedStatement = connection.prepareStatement(sql)) { SQLException exception = @@ -792,27 +993,25 @@ public void testDescribeDmlWithNonExistingTable() throws SQLException { } } - // We receive two ExecuteSql requests: + // We receive one ExecuteSql requests: // 1. DescribeStatement (parameters). This statement fails as the table does not exist. - // 2. Because the DescribeStatement step fails, PGAdapter executes the DML statement in analyze - // mode to force a 'correct' error message. List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); - assertEquals(2, requests.size()); - assertEquals(describeSql, requests.get(0).getSql()); + assertEquals(1, requests.size()); + assertEquals(sql, requests.get(0).getSql()); assertEquals(QueryMode.PLAN, requests.get(0).getQueryMode()); - assertEquals(sql, requests.get(1).getSql()); - assertEquals(QueryMode.PLAN, requests.get(1).getQueryMode()); } @Test public void testDescribeDmlWithSchemaPrefix() throws SQLException { String sql = "update public.my_table set value=? where id=?"; - String describeSql = "select $1, $2 from (select value=$1 from public.my_table where id=$2) p"; + String pgSql = "update public.my_table set value=$1 where id=$2"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeSql), + Statement.of(pgSql), com.google.spanner.v1.ResultSet.newBuilder() - .setMetadata(createMetadata(ImmutableList.of(TypeCode.STRING, TypeCode.INT64))) + .setMetadata( + createParameterTypesMetadata(ImmutableList.of(TypeCode.STRING, TypeCode.INT64))) + .setStats(ResultSetStats.newBuilder().build()) .build())); try (Connection connection = DriverManager.getConnection(createUrl())) { try (PreparedStatement preparedStatement = connection.prepareStatement(sql)) { @@ -824,20 +1023,21 @@ public void testDescribeDmlWithSchemaPrefix() throws SQLException { List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); assertEquals(1, requests.size()); - assertEquals(describeSql, requests.get(0).getSql()); + assertEquals(pgSql, requests.get(0).getSql()); assertEquals(QueryMode.PLAN, requests.get(0).getQueryMode()); } @Test public void testDescribeDmlWithQuotedSchemaPrefix() throws SQLException { String sql = "update \"public\".\"my_table\" set value=? where id=?"; - String describeSql = - "select $1, $2 from (select value=$1 from \"public\".\"my_table\" where id=$2) p"; + String pgSql = "update \"public\".\"my_table\" set value=$1 where id=$2"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeSql), + Statement.of(pgSql), com.google.spanner.v1.ResultSet.newBuilder() - .setMetadata(createMetadata(ImmutableList.of(TypeCode.STRING, TypeCode.INT64))) + .setMetadata( + createParameterTypesMetadata(ImmutableList.of(TypeCode.STRING, TypeCode.INT64))) + .setStats(ResultSetStats.newBuilder().build()) .build())); try (Connection connection = DriverManager.getConnection(createUrl())) { try (PreparedStatement preparedStatement = connection.prepareStatement(sql)) { @@ -849,7 +1049,7 @@ public void testDescribeDmlWithQuotedSchemaPrefix() throws SQLException { List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); assertEquals(1, requests.size()); - assertEquals(describeSql, requests.get(0).getSql()); + assertEquals(pgSql, requests.get(0).getSql()); assertEquals(QueryMode.PLAN, requests.get(0).getQueryMode()); } @@ -893,6 +1093,48 @@ public void testTwoDmlStatements() throws SQLException { assertEquals(UPDATE_STATEMENT.getSql(), request.getStatements(1).getSql()); } + @Test + public void testTwoDmlStatements_withError() throws SQLException { + try (Connection connection = DriverManager.getConnection(createUrl())) { + try (java.sql.Statement statement = connection.createStatement()) { + SQLException exception = + assertThrows( + SQLException.class, + () -> statement.execute(String.format("%s;%s;", INSERT_STATEMENT, INVALID_DML))); + assertEquals("ERROR: Statement is invalid.", exception.getMessage()); + } + } + + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class)); + ExecuteBatchDmlRequest request = + mockSpanner.getRequestsOfType(ExecuteBatchDmlRequest.class).get(0); + assertEquals(2, request.getStatementsCount()); + assertEquals(INSERT_STATEMENT.getSql(), request.getStatements(0).getSql()); + assertEquals(INVALID_DML.getSql(), request.getStatements(1).getSql()); + assertEquals(0, mockSpanner.countRequestsOfType(CommitRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(RollbackRequest.class)); + } + + @Test + public void testTwoDmlStatements_randomlyAborted() throws SQLException { + mockSpanner.setAbortProbability(0.5); + for (int run = 0; run < 50; run++) { + try (Connection connection = DriverManager.getConnection(createUrl())) { + try (java.sql.Statement statement = connection.createStatement()) { + assertFalse( + statement.execute(String.format("%s;%s;", INSERT_STATEMENT, UPDATE_STATEMENT))); + assertEquals(1, statement.getUpdateCount()); + assertFalse(statement.getMoreResults()); + assertEquals(2, statement.getUpdateCount()); + assertFalse(statement.getMoreResults()); + assertEquals(-1, statement.getUpdateCount()); + } + } finally { + mockSpanner.setAbortProbability(0.0); + } + } + } + @Test public void testJdbcBatch() throws SQLException { try (Connection connection = DriverManager.getConnection(createUrl())) { @@ -1461,6 +1703,123 @@ public void testPreparedStatement() throws SQLException { } } + @Test + public void testPreparedStatementReturning() throws SQLException { + String pgSql = + "insert into all_types " + + "(col_bigint, col_bool, col_bytea, col_float8, col_int, col_numeric, col_timestamptz, col_date, col_varchar, col_jsonb) " + + "values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) " + + "returning *"; + String sql = + "insert into all_types " + + "(col_bigint, col_bool, col_bytea, col_float8, col_int, col_numeric, col_timestamptz, col_date, col_varchar, col_jsonb) " + + "values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) " + + "returning *"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(pgSql), + com.google.spanner.v1.ResultSet.newBuilder() + .setMetadata( + ALL_TYPES_METADATA + .toBuilder() + .setUndeclaredParameters( + createParameterTypesMetadata( + ImmutableList.of( + TypeCode.INT64, + TypeCode.BOOL, + TypeCode.BYTES, + TypeCode.FLOAT64, + TypeCode.INT64, + TypeCode.NUMERIC, + TypeCode.TIMESTAMP, + TypeCode.DATE, + TypeCode.STRING, + TypeCode.JSON)) + .getUndeclaredParameters()) + .build()) + .setStats(ResultSetStats.newBuilder().build()) + .build())); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.newBuilder(pgSql) + .bind("p1") + .to(1L) + .bind("p2") + .to(true) + .bind("p3") + .to(ByteArray.copyFrom("test")) + .bind("p4") + .to(3.14d) + .bind("p5") + .to(100L) + .bind("p6") + .to(com.google.cloud.spanner.Value.pgNumeric("6.626")) + .bind("p7") + .to(Timestamp.parseTimestamp("2022-02-16T13:18:02.123457000Z")) + .bind("p8") + .to(Date.parseDate("2022-03-29")) + .bind("p9") + .to("test") + .bind("p10") + .to("{\"key\": \"value\"}") + .build(), + com.google.spanner.v1.ResultSet.newBuilder() + .setMetadata(ALL_TYPES_METADATA) + .setStats(ResultSetStats.newBuilder().setRowCountExact(1L).build()) + .addRows(ALL_TYPES_RESULTSET.getRows(0)) + .build())); + + OffsetDateTime zonedDateTime = + LocalDateTime.of(2022, 2, 16, 13, 18, 2, 123456789).atOffset(ZoneOffset.UTC); + try (Connection connection = DriverManager.getConnection(createUrl())) { + try (PreparedStatement statement = connection.prepareStatement(sql)) { + ParameterMetaData parameterMetaData = statement.getParameterMetaData(); + assertEquals(10, parameterMetaData.getParameterCount()); + assertEquals(Types.BIGINT, parameterMetaData.getParameterType(1)); + assertEquals(Types.BIT, parameterMetaData.getParameterType(2)); + assertEquals(Types.BINARY, parameterMetaData.getParameterType(3)); + assertEquals(Types.DOUBLE, parameterMetaData.getParameterType(4)); + assertEquals(Types.BIGINT, parameterMetaData.getParameterType(5)); + assertEquals(Types.NUMERIC, parameterMetaData.getParameterType(6)); + assertEquals(Types.TIMESTAMP, parameterMetaData.getParameterType(7)); + assertEquals(Types.DATE, parameterMetaData.getParameterType(8)); + assertEquals(Types.VARCHAR, parameterMetaData.getParameterType(9)); + // TODO: Enable when support for JSONB has been enabled. + // assertEquals(Types.OTHER, parameterMetaData.getParameterType(10)); + ResultSetMetaData metadata = statement.getMetaData(); + assertEquals(10, metadata.getColumnCount()); + assertEquals(Types.BIGINT, metadata.getColumnType(1)); + assertEquals(Types.BIT, metadata.getColumnType(2)); + assertEquals(Types.BINARY, metadata.getColumnType(3)); + assertEquals(Types.DOUBLE, metadata.getColumnType(4)); + assertEquals(Types.BIGINT, metadata.getColumnType(5)); + assertEquals(Types.NUMERIC, metadata.getColumnType(6)); + assertEquals(Types.TIMESTAMP, metadata.getColumnType(7)); + assertEquals(Types.DATE, metadata.getColumnType(8)); + assertEquals(Types.VARCHAR, metadata.getColumnType(9)); + // TODO: Enable when support for JSONB has been enabled. + // assertEquals(Types.OTHER, metadata.getColumnType(10)); + + int index = 0; + statement.setLong(++index, 1L); + statement.setBoolean(++index, true); + statement.setBytes(++index, "test".getBytes(StandardCharsets.UTF_8)); + statement.setDouble(++index, 3.14d); + statement.setInt(++index, 100); + statement.setBigDecimal(++index, new BigDecimal("6.626")); + statement.setObject(++index, zonedDateTime); + statement.setObject(++index, LocalDate.of(2022, 3, 29)); + statement.setString(++index, "test"); + statement.setString(++index, "{\"key\": \"value\"}"); + + try (ResultSet resultSet = statement.executeQuery()) { + assertTrue(resultSet.next()); + assertFalse(resultSet.next()); + } + } + } + } + @Test public void testCursorSuccess() throws SQLException { try (Connection connection = DriverManager.getConnection(createUrl())) { @@ -1964,6 +2323,82 @@ public void testInformationSchemaQueryInTransactionWithErrorDuringRetry() throws assertEquals(0, mockSpanner.countRequestsOfType(CommitRequest.class)); } + @Test + public void testInformationSchemaQueryInTransactionWithReplacedPgCatalogTables() + throws SQLException { + String sql = "SELECT 1 FROM pg_namespace"; + String replacedSql = + "with pg_namespace as (\n" + + " select case schema_name when 'pg_catalog' then 11 when 'public' then 2200 else 0 end as oid,\n" + + " schema_name as nspname, null as nspowner, null as nspacl\n" + + " from information_schema.schemata\n" + + ")\n" + + "SELECT 1 FROM pg_namespace"; + // Register a result for the query. Note that we don't really care what the result is, just that + // there is a result. + mockSpanner.putStatementResult( + StatementResult.query(Statement.of(replacedSql), SELECT1_RESULTSET)); + + try (Connection connection = DriverManager.getConnection(createUrl())) { + // Make sure that we start a transaction. + connection.setAutoCommit(false); + + // Execute a query to start the transaction. + try (ResultSet resultSet = connection.createStatement().executeQuery(SELECT1.getSql())) { + assertTrue(resultSet.next()); + assertEquals(1L, resultSet.getLong(1)); + assertFalse(resultSet.next()); + } + + // This ensures that the following query returns an error the first time it is executed, and + // then succeeds the second time. This happens because the exception is 'popped' from the + // response queue when it is returned. The next time the query is executed, it will return the + // actual result that we set. + mockSpanner.setExecuteStreamingSqlExecutionTime( + SimulatedExecutionTime.ofException( + Status.INVALID_ARGUMENT + .withDescription( + "Unsupported concurrency mode in query using INFORMATION_SCHEMA.") + .asRuntimeException())); + try (ResultSet resultSet = connection.createStatement().executeQuery(sql)) { + assertTrue(resultSet.next()); + assertEquals(1L, resultSet.getLong(1)); + assertFalse(resultSet.next()); + } + + // Make sure that the connection is still usable. + try (ResultSet resultSet = connection.createStatement().executeQuery(SELECT2.getSql())) { + assertTrue(resultSet.next()); + assertEquals(2L, resultSet.getLong(1)); + assertFalse(resultSet.next()); + } + connection.commit(); + } + + // We should receive the INFORMATION_SCHEMA statement twice on Cloud Spanner: + // 1. The first time it returns an error because it is using the wrong concurrency mode. + // 2. The specific error will cause the connection to retry the statement using a single-use + // read-only transaction. + assertEquals(4, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); + // The first statement should start a transaction + assertTrue(requests.get(0).getTransaction().hasBegin()); + // The second statement (the initial attempt of the INFORMATION_SCHEMA query) should try to use + // the transaction. + assertTrue(requests.get(1).getTransaction().hasId()); + assertEquals(replacedSql, requests.get(1).getSql()); + // The INFORMATION_SCHEMA query is then retried using a single-use read-only transaction. + assertFalse(requests.get(2).hasTransaction()); + assertEquals(replacedSql, requests.get(2).getSql()); + // The last statement should use the transaction. + assertTrue(requests.get(3).getTransaction().hasId()); + + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + CommitRequest commitRequest = mockSpanner.getRequestsOfType(CommitRequest.class).get(0); + assertEquals(commitRequest.getTransactionId(), requests.get(1).getTransaction().getId()); + assertEquals(commitRequest.getTransactionId(), requests.get(3).getTransaction().getId()); + } + @Test public void testShowGuessTypes() throws SQLException { try (Connection connection = DriverManager.getConnection(createUrl())) { @@ -2407,7 +2842,17 @@ public void testSetToEmpty() throws SQLException { public void testSetTimeZone() throws SQLException { try (Connection connection = DriverManager.getConnection(createUrl())) { connection.createStatement().execute("set time zone 'IST'"); - verifySettingValue(connection, "timezone", "IST"); + verifySettingValue(connection, "timezone", "Asia/Kolkata"); + } + } + + @Test + public void testSetTimeZoneToServerDefault() throws SQLException { + try (Connection connection = DriverManager.getConnection(createUrl())) { + connection.createStatement().execute("set time zone 'atlantic/faeroe'"); + verifySettingValue(connection, "timezone", "Atlantic/Faeroe"); + connection.createStatement().execute("set time zone default"); + verifySettingValue(connection, "timezone", TimeZone.getDefault().getID()); } } @@ -2416,7 +2861,7 @@ public void testSetTimeZoneToDefault() throws SQLException { try (Connection connection = DriverManager.getConnection(createUrl() + "?options=-c%20timezone=IST")) { connection.createStatement().execute("set time zone default"); - verifySettingValue(connection, "timezone", "IST"); + verifySettingValue(connection, "timezone", "Asia/Kolkata"); } } @@ -2425,7 +2870,7 @@ public void testSetTimeZoneToLocal() throws SQLException { try (Connection connection = DriverManager.getConnection(createUrl() + "?options=-c%20timezone=IST")) { connection.createStatement().execute("set time zone local"); - verifySettingValue(connection, "timezone", "IST"); + verifySettingValue(connection, "timezone", "Asia/Kolkata"); } } @@ -2655,32 +3100,150 @@ public void testReplacePgCatalogTablesOff() throws SQLException { public void testDescribeStatementWithMoreThan50Parameters() throws SQLException { try (Connection connection = DriverManager.getConnection(createUrl())) { // Force binary transfer + usage of server-side prepared statements. - connection.unwrap(PGConnection.class).setPrepareThreshold(-1); + connection.unwrap(PGConnection.class).setPrepareThreshold(1); String sql = String.format( "insert into foo values (%s)", IntStream.rangeClosed(1, 51).mapToObj(i -> "?").collect(Collectors.joining(","))); + String pgSql = + String.format( + "insert into foo values (%s)", + IntStream.rangeClosed(1, 51).mapToObj(i -> "$" + i).collect(Collectors.joining(","))); + ImmutableList typeCodes = + ImmutableList.copyOf( + IntStream.rangeClosed(1, 51) + .mapToObj(i -> TypeCode.STRING) + .collect(Collectors.toList())); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(pgSql), + com.google.spanner.v1.ResultSet.newBuilder() + .setMetadata(createParameterTypesMetadata(typeCodes)) + .setStats(ResultSetStats.newBuilder().build()) + .build())); try (PreparedStatement preparedStatement = connection.prepareStatement(sql)) { - SQLException sqlException = - assertThrows(SQLException.class, preparedStatement::getParameterMetaData); - assertEquals( - "ERROR: Cannot describe statements with more than 50 parameters", - sqlException.getMessage()); + ParameterMetaData metadata = preparedStatement.getParameterMetaData(); + assertEquals(51, metadata.getParameterCount()); } + Statement.Builder builder = Statement.newBuilder(pgSql); + IntStream.rangeClosed(1, 51).forEach(i -> builder.bind("p" + i).to((String) null)); + Statement statement = builder.build(); + mockSpanner.putStatementResult(StatementResult.update(statement, 1L)); try (PreparedStatement preparedStatement = connection.prepareStatement(sql)) { for (int i = 0; i < 51; i++) { preparedStatement.setNull(i + 1, Types.NULL); } - SQLException sqlException = - assertThrows(SQLException.class, preparedStatement::executeUpdate); - assertEquals( - "ERROR: Cannot describe statements with more than 50 parameters", - sqlException.getMessage()); + assertEquals(1, preparedStatement.executeUpdate()); + } + } + } + + @Test + public void testDmlReturning() throws SQLException { + String sql = "INSERT INTO test (value) values ('test') RETURNING id"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(sql), + com.google.spanner.v1.ResultSet.newBuilder() + .setMetadata( + createMetadata(ImmutableList.of(TypeCode.INT64), ImmutableList.of("id"))) + .setStats(ResultSetStats.newBuilder().setRowCountExact(1L).build()) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("9999").build()) + .build()) + .build())); + + try (Connection connection = DriverManager.getConnection(createUrl())) { + try (ResultSet resultSet = connection.createStatement().executeQuery(sql)) { + assertTrue(resultSet.next()); + assertEquals(9999L, resultSet.getLong(1)); + assertFalse(resultSet.next()); + } + try (java.sql.Statement statement = connection.createStatement()) { + assertTrue(statement.execute(sql)); + try (ResultSet resultSet = statement.getResultSet()) { + assertTrue(resultSet.next()); + assertEquals(9999L, resultSet.getLong(1)); + assertFalse(resultSet.next()); + } + assertFalse(statement.getMoreResults()); + assertEquals(-1, statement.getUpdateCount()); } } } + @Test + public void testDmlReturningMultipleRows() throws SQLException { + String sql = "UPDATE test SET value='new_value' WHERE value='old_value' RETURNING id"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(sql), + com.google.spanner.v1.ResultSet.newBuilder() + .setMetadata( + createMetadata(ImmutableList.of(TypeCode.INT64), ImmutableList.of("id"))) + .setStats(ResultSetStats.newBuilder().setRowCountExact(3L).build()) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("1").build()) + .build()) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("2").build()) + .build()) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("3").build()) + .build()) + .build())); + + try (Connection connection = DriverManager.getConnection(createUrl())) { + try (ResultSet resultSet = connection.createStatement().executeQuery(sql)) { + assertTrue(resultSet.next()); + assertEquals(1L, resultSet.getLong(1)); + assertTrue(resultSet.next()); + assertEquals(2L, resultSet.getLong(1)); + assertTrue(resultSet.next()); + assertEquals(3L, resultSet.getLong(1)); + assertFalse(resultSet.next()); + } + } + + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest executeRequest = + mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(0); + assertEquals(QueryMode.NORMAL, executeRequest.getQueryMode()); + assertTrue(executeRequest.getTransaction().hasBegin()); + assertTrue(executeRequest.getTransaction().getBegin().hasReadWrite()); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + } + + @Test + public void testUUIDParameter() throws SQLException { + assumeTrue(pgVersion.equals("14.1")); + + String jdbcSql = "SELECT * FROM all_types WHERE col_uuid=?"; + String pgSql = "SELECT * FROM all_types WHERE col_uuid=$1"; + UUID uuid = UUID.randomUUID(); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.newBuilder(pgSql).bind("p1").to(uuid.toString()).build(), + ALL_TYPES_RESULTSET)); + + try (Connection connection = DriverManager.getConnection(createUrl())) { + try (PreparedStatement statement = connection.prepareStatement(jdbcSql)) { + statement.setObject(1, uuid); + try (ResultSet resultSet = statement.executeQuery()) { + assertTrue(resultSet.next()); + assertFalse(resultSet.next()); + } + } + } + + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + } + private void verifySettingIsNull(Connection connection, String setting) throws SQLException { try (ResultSet resultSet = connection.createStatement().executeQuery(String.format("show %s", setting))) { diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/JdbcSimpleModeMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/JdbcSimpleModeMockServerTest.java index 6e967d5cb..ecfe42a41 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/JdbcSimpleModeMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/JdbcSimpleModeMockServerTest.java @@ -23,10 +23,14 @@ import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ListValue; +import com.google.protobuf.Value; import com.google.spanner.v1.CommitRequest; import com.google.spanner.v1.ExecuteSqlRequest; import com.google.spanner.v1.ExecuteSqlRequest.QueryMode; import com.google.spanner.v1.RollbackRequest; +import com.google.spanner.v1.TypeCode; import io.grpc.Status; import java.math.BigDecimal; import java.sql.Connection; @@ -530,4 +534,96 @@ public void testExecuteUnknownStatement() throws SQLException { } } } + + @Test + public void testGetTimezoneStringUtc() throws SQLException { + String sql = "select '2022-01-01 10:00:00+01'::timestamptz"; + mockSpanner.putStatementResult( + StatementResult.query( + com.google.cloud.spanner.Statement.of(sql), + com.google.spanner.v1.ResultSet.newBuilder() + .setMetadata(createMetadata(ImmutableList.of(TypeCode.TIMESTAMP))) + .addRows( + ListValue.newBuilder() + .addValues( + Value.newBuilder().setStringValue("2022-01-01T09:00:00Z").build()) + .build()) + .build())); + + try (Connection connection = DriverManager.getConnection(createUrl())) { + connection.createStatement().execute("set time zone utc"); + try (ResultSet resultSet = connection.createStatement().executeQuery(sql)) { + assertTrue(resultSet.next()); + assertEquals("2022-01-01 09:00:00+00", resultSet.getString(1)); + assertFalse(resultSet.next()); + } + } + } + + @Test + public void testGetTimezoneStringEuropeAmsterdam() throws SQLException { + String sql = "select '2022-01-01 10:00:00Z'::timestamptz"; + mockSpanner.putStatementResult( + StatementResult.query( + com.google.cloud.spanner.Statement.of(sql), + com.google.spanner.v1.ResultSet.newBuilder() + .setMetadata(createMetadata(ImmutableList.of(TypeCode.TIMESTAMP))) + .addRows( + ListValue.newBuilder() + .addValues( + Value.newBuilder().setStringValue("2022-01-01T10:00:00Z").build()) + .build()) + .build())); + + try (Connection connection = DriverManager.getConnection(createUrl())) { + connection.createStatement().execute("set time zone 'Europe/Amsterdam'"); + try (ResultSet resultSet = connection.createStatement().executeQuery(sql)) { + assertTrue(resultSet.next()); + assertEquals("2022-01-01 11:00:00+01", resultSet.getString(1)); + assertFalse(resultSet.next()); + } + } + } + + @Test + public void testGetTimezoneStringAmericaLosAngeles() throws SQLException { + String sql = "select '1883-11-18 00:00:00Z'::timestamptz"; + mockSpanner.putStatementResult( + StatementResult.query( + com.google.cloud.spanner.Statement.of(sql), + com.google.spanner.v1.ResultSet.newBuilder() + .setMetadata(createMetadata(ImmutableList.of(TypeCode.TIMESTAMP))) + .addRows( + ListValue.newBuilder() + .addValues( + Value.newBuilder().setStringValue("1883-11-18T00:00:00Z").build()) + .build()) + .build())); + + try (Connection connection = DriverManager.getConnection(createUrl())) { + connection.createStatement().execute("set time zone 'America/Los_Angeles'"); + try (ResultSet resultSet = connection.createStatement().executeQuery(sql)) { + assertTrue(resultSet.next()); + if (OptionsMetadata.isJava8()) { + // Java8 does not support timezone offsets with second precision. + assertEquals("1883-11-17 16:07:02-07:52", resultSet.getString(1)); + } else { + assertEquals("1883-11-17 16:07:02-07:52:58", resultSet.getString(1)); + } + assertFalse(resultSet.next()); + } + } + } + + @Test + public void testSetInvalidTimezone() throws SQLException { + try (Connection connection = DriverManager.getConnection(createUrl())) { + SQLException exception = + assertThrows( + SQLException.class, + () -> connection.createStatement().execute("set time zone 'foo'")); + assertEquals( + "ERROR: invalid value for parameter \"TimeZone\": \"foo\"", exception.getMessage()); + } + } } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/PGExceptionTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/PGExceptionTest.java deleted file mode 100644 index 3a044cbf6..000000000 --- a/src/test/java/com/google/cloud/spanner/pgadapter/PGExceptionTest.java +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.google.cloud.spanner.pgadapter; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNull; - -import com.google.cloud.spanner.pgadapter.error.PGException; -import com.google.cloud.spanner.pgadapter.error.SQLState; -import com.google.cloud.spanner.pgadapter.error.Severity; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -@RunWith(JUnit4.class) -public class PGExceptionTest { - - @Test - public void testHints() { - assertNull( - PGException.newBuilder("test message") - .setSQLState(SQLState.InternalError) - .setSeverity(Severity.ERROR) - .build() - .getHints()); - assertEquals( - "test hint\nsecond line", - PGException.newBuilder("test message") - .setSQLState(SQLState.InternalError) - .setSeverity(Severity.ERROR) - .setHints("test hint\nsecond line") - .build() - .getHints()); - } -} diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/ServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/ServerTest.java index 584ac56a2..9604872ca 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/ServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/ServerTest.java @@ -94,4 +94,38 @@ public void testMainWithInvalidParam() { System.setErr(originalErr); } } + + @Test + public void testInvalidKeyStore() { + ByteArrayOutputStream outArrayStream = new ByteArrayOutputStream(); + PrintStream out = new PrintStream(outArrayStream); + ByteArrayOutputStream errArrayStream = new ByteArrayOutputStream(); + PrintStream err = new PrintStream(errArrayStream); + + PrintStream originalOut = System.out; + PrintStream originalErr = System.err; + String originalKeyStore = System.getProperty("javax.net.ssl.keyStore"); + System.setOut(out); + System.setErr(err); + System.setProperty("javax.net.ssl.keyStore", "/path/to/non/existing/file.pfx"); + + try { + Server.main(new String[] {}); + assertEquals( + "The server could not be started because an error occurred: Key store /path/to/non/existing/file.pfx does not exist\n", + errArrayStream.toString()); + assertTrue( + outArrayStream.toString(), + outArrayStream + .toString() + .startsWith( + String.format("-- Starting PGAdapter version %s --", Server.getVersion()))); + } finally { + System.setOut(originalOut); + System.setErr(originalErr); + if (originalKeyStore != null) { + System.setProperty("javax.net.ssl.keyStore", originalKeyStore); + } + } + } } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/csharp/AbstractNpgsqlMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/csharp/AbstractNpgsqlMockServerTest.java index e303a9f98..1f40ffd61 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/csharp/AbstractNpgsqlMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/csharp/AbstractNpgsqlMockServerTest.java @@ -81,6 +81,83 @@ public abstract class AbstractNpgsqlMockServerTest extends AbstractMockServerTes + " select 1184 as oid, 'timestamptz' as typname, (select oid from pg_namespace where nspname='pg_catalog') as typnamespace, null as typowner, 8 as typlen, true as typbyval, 'b' as typtype, 'D' as typcategory, true as typispreferred, true as typisdefined, ',' as typdelim, 0 as typrelid, 0 as typelem, 1185 as typarray, 'timestamptz_in' as typinput, 'timestamptz_out' as typoutput, 'timestamptz_recv' as typreceive, 'timestamptz_send' as typsend, 'timestamptztypmodin' as typmodin, 'timestamptztypmodout' as typmodout, '-' as typanalyze, 'd' as typalign, 'p' as typstorage, false as typnotnull, 0 as typbasetype, -1 as typtypmod, 0 as typndims, 0 as typcollation, null as typdefaultbin, null as typdefault, null as typacl union all\n" + " select 1700 as oid, 'numeric' as typname, (select oid from pg_namespace where nspname='pg_catalog') as typnamespace, null as typowner, -1 as typlen, false as typbyval, 'b' as typtype, 'N' as typcategory, false as typispreferred, true as typisdefined, ',' as typdelim, 0 as typrelid, 0 as typelem, 1231 as typarray, 'numeric_in' as typinput, 'numeric_out' as typoutput, 'numeric_recv' as typreceive, 'numeric_send' as typsend, 'numerictypmodin' as typmodin, 'numerictypmodout' as typmodout, '-' as typanalyze, 'i' as typalign, 'm' as typstorage, false as typnotnull, 0 as typbasetype, -1 as typtypmod, 0 as typndims, 0 as typcollation, null as typdefaultbin, null as typdefault, null as typacl union all\n" + " select 3802 as oid, 'jsonb' as typname, (select oid from pg_namespace where nspname='pg_catalog') as typnamespace, null as typowner, -1 as typlen, false as typbyval, 'b' as typtype, 'U' as typcategory, false as typispreferred, true as typisdefined, ',' as typdelim, 0 as typrelid, 0 as typelem, 3807 as typarray, 'jsonb_in' as typinput, 'jsonb_out' as typoutput, 'jsonb_recv' as typreceive, 'jsonb_send' as typsend, '-' as typmodin, '-' as typmodout, '-' as typanalyze, 'i' as typalign, 'x' as typstorage, false as typnotnull, 0 as typbasetype, -1 as typtypmod, 0 as typndims, 0 as typcollation, null as typdefaultbin, null as typdefault, null as typacl\n" + + "),\n" + + "pg_class as (\n" + + " select\n" + + " -1 as oid,\n" + + " table_name as relname,\n" + + " case table_schema when 'pg_catalog' then 11 when 'public' then 2200 else 0 end as relnamespace,\n" + + " 0 as reltype,\n" + + " 0 as reloftype,\n" + + " 0 as relowner,\n" + + " 1 as relam,\n" + + " 0 as relfilenode,\n" + + " 0 as reltablespace,\n" + + " 0 as relpages,\n" + + " 0.0::float8 as reltuples,\n" + + " 0 as relallvisible,\n" + + " 0 as reltoastrelid,\n" + + " false as relhasindex,\n" + + " false as relisshared,\n" + + " 'p' as relpersistence,\n" + + " 'r' as relkind,\n" + + " count(*) as relnatts,\n" + + " 0 as relchecks,\n" + + " false as relhasrules,\n" + + " false as relhastriggers,\n" + + " false as relhassubclass,\n" + + " false as relrowsecurity,\n" + + " false as relforcerowsecurity,\n" + + " true as relispopulated,\n" + + " 'n' as relreplident,\n" + + " false as relispartition,\n" + + " 0 as relrewrite,\n" + + " 0 as relfrozenxid,\n" + + " 0 as relminmxid,\n" + + " '{}'::bigint[] as relacl,\n" + + " '{}'::text[] as reloptions,\n" + + " 0 as relpartbound\n" + + "from information_schema.tables t\n" + + "inner join information_schema.columns using (table_catalog, table_schema, table_name)\n" + + "group by t.table_name, t.table_schema\n" + + "union all\n" + + "select\n" + + " -1 as oid,\n" + + " i.index_name as relname,\n" + + " case table_schema when 'pg_catalog' then 11 when 'public' then 2200 else 0 end as relnamespace,\n" + + " 0 as reltype,\n" + + " 0 as reloftype,\n" + + " 0 as relowner,\n" + + " 1 as relam,\n" + + " 0 as relfilenode,\n" + + " 0 as reltablespace,\n" + + " 0 as relpages,\n" + + " 0.0::float8 as reltuples,\n" + + " 0 as relallvisible,\n" + + " 0 as reltoastrelid,\n" + + " false as relhasindex,\n" + + " false as relisshared,\n" + + " 'p' as relpersistence,\n" + + " 'r' as relkind,\n" + + " count(*) as relnatts,\n" + + " 0 as relchecks,\n" + + " false as relhasrules,\n" + + " false as relhastriggers,\n" + + " false as relhassubclass,\n" + + " false as relrowsecurity,\n" + + " false as relforcerowsecurity,\n" + + " true as relispopulated,\n" + + " 'n' as relreplident,\n" + + " false as relispartition,\n" + + " 0 as relrewrite,\n" + + " 0 as relfrozenxid,\n" + + " 0 as relminmxid,\n" + + " '{}'::bigint[] as relacl,\n" + + " '{}'::text[] as reloptions,\n" + + " 0 as relpartbound\n" + + "from information_schema.indexes i\n" + + "inner join information_schema.index_columns using (table_catalog, table_schema, table_name)\n" + + "group by i.index_name, i.table_schema\n" + ")\n" + "SELECT ns.nspname, t.oid, t.typname, t.typtype, t.typnotnull, t.elemtypoid\n" + "FROM (\n" @@ -150,6 +227,83 @@ public abstract class AbstractNpgsqlMockServerTest extends AbstractMockServerTes + " select 1184 as oid, 'timestamptz' as typname, (select oid from pg_namespace where nspname='pg_catalog') as typnamespace, null as typowner, 8 as typlen, true as typbyval, 'b' as typtype, 'D' as typcategory, true as typispreferred, true as typisdefined, ',' as typdelim, 0 as typrelid, 0 as typelem, 1185 as typarray, 'timestamptz_in' as typinput, 'timestamptz_out' as typoutput, 'timestamptz_recv' as typreceive, 'timestamptz_send' as typsend, 'timestamptztypmodin' as typmodin, 'timestamptztypmodout' as typmodout, '-' as typanalyze, 'd' as typalign, 'p' as typstorage, false as typnotnull, 0 as typbasetype, -1 as typtypmod, 0 as typndims, 0 as typcollation, null as typdefaultbin, null as typdefault, null as typacl union all\n" + " select 1700 as oid, 'numeric' as typname, (select oid from pg_namespace where nspname='pg_catalog') as typnamespace, null as typowner, -1 as typlen, false as typbyval, 'b' as typtype, 'N' as typcategory, false as typispreferred, true as typisdefined, ',' as typdelim, 0 as typrelid, 0 as typelem, 1231 as typarray, 'numeric_in' as typinput, 'numeric_out' as typoutput, 'numeric_recv' as typreceive, 'numeric_send' as typsend, 'numerictypmodin' as typmodin, 'numerictypmodout' as typmodout, '-' as typanalyze, 'i' as typalign, 'm' as typstorage, false as typnotnull, 0 as typbasetype, -1 as typtypmod, 0 as typndims, 0 as typcollation, null as typdefaultbin, null as typdefault, null as typacl union all\n" + " select 3802 as oid, 'jsonb' as typname, (select oid from pg_namespace where nspname='pg_catalog') as typnamespace, null as typowner, -1 as typlen, false as typbyval, 'b' as typtype, 'U' as typcategory, false as typispreferred, true as typisdefined, ',' as typdelim, 0 as typrelid, 0 as typelem, 3807 as typarray, 'jsonb_in' as typinput, 'jsonb_out' as typoutput, 'jsonb_recv' as typreceive, 'jsonb_send' as typsend, '-' as typmodin, '-' as typmodout, '-' as typanalyze, 'i' as typalign, 'x' as typstorage, false as typnotnull, 0 as typbasetype, -1 as typtypmod, 0 as typndims, 0 as typcollation, null as typdefaultbin, null as typdefault, null as typacl\n" + + "),\n" + + "pg_class as (\n" + + " select\n" + + " -1 as oid,\n" + + " table_name as relname,\n" + + " case table_schema when 'pg_catalog' then 11 when 'public' then 2200 else 0 end as relnamespace,\n" + + " 0 as reltype,\n" + + " 0 as reloftype,\n" + + " 0 as relowner,\n" + + " 1 as relam,\n" + + " 0 as relfilenode,\n" + + " 0 as reltablespace,\n" + + " 0 as relpages,\n" + + " 0.0::float8 as reltuples,\n" + + " 0 as relallvisible,\n" + + " 0 as reltoastrelid,\n" + + " false as relhasindex,\n" + + " false as relisshared,\n" + + " 'p' as relpersistence,\n" + + " 'r' as relkind,\n" + + " count(*) as relnatts,\n" + + " 0 as relchecks,\n" + + " false as relhasrules,\n" + + " false as relhastriggers,\n" + + " false as relhassubclass,\n" + + " false as relrowsecurity,\n" + + " false as relforcerowsecurity,\n" + + " true as relispopulated,\n" + + " 'n' as relreplident,\n" + + " false as relispartition,\n" + + " 0 as relrewrite,\n" + + " 0 as relfrozenxid,\n" + + " 0 as relminmxid,\n" + + " '{}'::bigint[] as relacl,\n" + + " '{}'::text[] as reloptions,\n" + + " 0 as relpartbound\n" + + "from information_schema.tables t\n" + + "inner join information_schema.columns using (table_catalog, table_schema, table_name)\n" + + "group by t.table_name, t.table_schema\n" + + "union all\n" + + "select\n" + + " -1 as oid,\n" + + " i.index_name as relname,\n" + + " case table_schema when 'pg_catalog' then 11 when 'public' then 2200 else 0 end as relnamespace,\n" + + " 0 as reltype,\n" + + " 0 as reloftype,\n" + + " 0 as relowner,\n" + + " 1 as relam,\n" + + " 0 as relfilenode,\n" + + " 0 as reltablespace,\n" + + " 0 as relpages,\n" + + " 0.0::float8 as reltuples,\n" + + " 0 as relallvisible,\n" + + " 0 as reltoastrelid,\n" + + " false as relhasindex,\n" + + " false as relisshared,\n" + + " 'p' as relpersistence,\n" + + " 'r' as relkind,\n" + + " count(*) as relnatts,\n" + + " 0 as relchecks,\n" + + " false as relhasrules,\n" + + " false as relhastriggers,\n" + + " false as relhassubclass,\n" + + " false as relrowsecurity,\n" + + " false as relforcerowsecurity,\n" + + " true as relispopulated,\n" + + " 'n' as relreplident,\n" + + " false as relispartition,\n" + + " 0 as relrewrite,\n" + + " 0 as relfrozenxid,\n" + + " 0 as relminmxid,\n" + + " '{}'::bigint[] as relacl,\n" + + " '{}'::text[] as reloptions,\n" + + " 0 as relpartbound\n" + + "from information_schema.indexes i\n" + + "inner join information_schema.index_columns using (table_catalog, table_schema, table_name)\n" + + "group by i.index_name, i.table_schema\n" + ")\n" + "SELECT ns.nspname, t.oid, t.typname, t.typtype, t.typnotnull, t.elemtypoid\n" + "FROM (\n" @@ -393,6 +547,83 @@ public abstract class AbstractNpgsqlMockServerTest extends AbstractMockServerTes + " select 1184 as oid, 'timestamptz' as typname, (select oid from pg_namespace where nspname='pg_catalog') as typnamespace, null as typowner, 8 as typlen, true as typbyval, 'b' as typtype, 'D' as typcategory, true as typispreferred, true as typisdefined, ',' as typdelim, 0 as typrelid, 0 as typelem, 1185 as typarray, 'timestamptz_in' as typinput, 'timestamptz_out' as typoutput, 'timestamptz_recv' as typreceive, 'timestamptz_send' as typsend, 'timestamptztypmodin' as typmodin, 'timestamptztypmodout' as typmodout, '-' as typanalyze, 'd' as typalign, 'p' as typstorage, false as typnotnull, 0 as typbasetype, -1 as typtypmod, 0 as typndims, 0 as typcollation, null as typdefaultbin, null as typdefault, null as typacl union all\n" + " select 1700 as oid, 'numeric' as typname, (select oid from pg_namespace where nspname='pg_catalog') as typnamespace, null as typowner, -1 as typlen, false as typbyval, 'b' as typtype, 'N' as typcategory, false as typispreferred, true as typisdefined, ',' as typdelim, 0 as typrelid, 0 as typelem, 1231 as typarray, 'numeric_in' as typinput, 'numeric_out' as typoutput, 'numeric_recv' as typreceive, 'numeric_send' as typsend, 'numerictypmodin' as typmodin, 'numerictypmodout' as typmodout, '-' as typanalyze, 'i' as typalign, 'm' as typstorage, false as typnotnull, 0 as typbasetype, -1 as typtypmod, 0 as typndims, 0 as typcollation, null as typdefaultbin, null as typdefault, null as typacl union all\n" + " select 3802 as oid, 'jsonb' as typname, (select oid from pg_namespace where nspname='pg_catalog') as typnamespace, null as typowner, -1 as typlen, false as typbyval, 'b' as typtype, 'U' as typcategory, false as typispreferred, true as typisdefined, ',' as typdelim, 0 as typrelid, 0 as typelem, 3807 as typarray, 'jsonb_in' as typinput, 'jsonb_out' as typoutput, 'jsonb_recv' as typreceive, 'jsonb_send' as typsend, '-' as typmodin, '-' as typmodout, '-' as typanalyze, 'i' as typalign, 'x' as typstorage, false as typnotnull, 0 as typbasetype, -1 as typtypmod, 0 as typndims, 0 as typcollation, null as typdefaultbin, null as typdefault, null as typacl\n" + + "),\n" + + "pg_class as (\n" + + " select\n" + + " -1 as oid,\n" + + " table_name as relname,\n" + + " case table_schema when 'pg_catalog' then 11 when 'public' then 2200 else 0 end as relnamespace,\n" + + " 0 as reltype,\n" + + " 0 as reloftype,\n" + + " 0 as relowner,\n" + + " 1 as relam,\n" + + " 0 as relfilenode,\n" + + " 0 as reltablespace,\n" + + " 0 as relpages,\n" + + " 0.0::float8 as reltuples,\n" + + " 0 as relallvisible,\n" + + " 0 as reltoastrelid,\n" + + " false as relhasindex,\n" + + " false as relisshared,\n" + + " 'p' as relpersistence,\n" + + " 'r' as relkind,\n" + + " count(*) as relnatts,\n" + + " 0 as relchecks,\n" + + " false as relhasrules,\n" + + " false as relhastriggers,\n" + + " false as relhassubclass,\n" + + " false as relrowsecurity,\n" + + " false as relforcerowsecurity,\n" + + " true as relispopulated,\n" + + " 'n' as relreplident,\n" + + " false as relispartition,\n" + + " 0 as relrewrite,\n" + + " 0 as relfrozenxid,\n" + + " 0 as relminmxid,\n" + + " '{}'::bigint[] as relacl,\n" + + " '{}'::text[] as reloptions,\n" + + " 0 as relpartbound\n" + + "from information_schema.tables t\n" + + "inner join information_schema.columns using (table_catalog, table_schema, table_name)\n" + + "group by t.table_name, t.table_schema\n" + + "union all\n" + + "select\n" + + " -1 as oid,\n" + + " i.index_name as relname,\n" + + " case table_schema when 'pg_catalog' then 11 when 'public' then 2200 else 0 end as relnamespace,\n" + + " 0 as reltype,\n" + + " 0 as reloftype,\n" + + " 0 as relowner,\n" + + " 1 as relam,\n" + + " 0 as relfilenode,\n" + + " 0 as reltablespace,\n" + + " 0 as relpages,\n" + + " 0.0::float8 as reltuples,\n" + + " 0 as relallvisible,\n" + + " 0 as reltoastrelid,\n" + + " false as relhasindex,\n" + + " false as relisshared,\n" + + " 'p' as relpersistence,\n" + + " 'r' as relkind,\n" + + " count(*) as relnatts,\n" + + " 0 as relchecks,\n" + + " false as relhasrules,\n" + + " false as relhastriggers,\n" + + " false as relhassubclass,\n" + + " false as relrowsecurity,\n" + + " false as relforcerowsecurity,\n" + + " true as relispopulated,\n" + + " 'n' as relreplident,\n" + + " false as relispartition,\n" + + " 0 as relrewrite,\n" + + " 0 as relfrozenxid,\n" + + " 0 as relminmxid,\n" + + " '{}'::bigint[] as relacl,\n" + + " '{}'::text[] as reloptions,\n" + + " 0 as relpartbound\n" + + "from information_schema.indexes i\n" + + "inner join information_schema.index_columns using (table_catalog, table_schema, table_name)\n" + + "group by i.index_name, i.table_schema\n" + ")\n" + "-- Load field definitions for (free-standing) composite types\n" + "SELECT typ.oid, att.attname, att.atttypid\n" diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/error/PGExceptionTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/error/PGExceptionTest.java index 734cceb9e..efa25264e 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/error/PGExceptionTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/error/PGExceptionTest.java @@ -16,6 +16,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNull; import org.junit.Test; import org.junit.runner.RunWith; @@ -24,6 +25,24 @@ @RunWith(JUnit4.class) public class PGExceptionTest { + @Test + public void testHints() { + assertNull( + PGException.newBuilder("test message") + .setSQLState(SQLState.InternalError) + .setSeverity(Severity.ERROR) + .build() + .getHints()); + assertEquals( + "test hint\nsecond line", + PGException.newBuilder("test message") + .setSQLState(SQLState.InternalError) + .setSeverity(Severity.ERROR) + .setHints("test hint\nsecond line") + .build() + .getHints()); + } + @Test public void testEquals() { assertEquals( diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/golang/GormMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/golang/GormMockServerTest.java index d4de29fbf..6cb47f7d2 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/golang/GormMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/golang/GormMockServerTest.java @@ -34,6 +34,7 @@ import com.google.spanner.v1.ExecuteSqlRequest.QueryMode; import com.google.spanner.v1.ResultSet; import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.ResultSetStats; import com.google.spanner.v1.RollbackRequest; import com.google.spanner.v1.StructType; import com.google.spanner.v1.StructType.Field; @@ -185,16 +186,11 @@ public void testFirst() { @Test public void testCreateBlogAndUser() { String insertUserSql = - "INSERT INTO \"users\" (\"id\",\"name\",\"email\",\"age\",\"birthday\",\"member_number\",\"activated_at\",\"created_at\",\"updated_at\") VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9)"; - String describeInsertUserSql = - "select $1, $2, $3, $4, $5, $6, $7, $8, $9 from (select \"id\"=$1, \"name\"=$2, \"email\"=$3, \"age\"=$4, \"birthday\"=$5, \"member_number\"=$6, \"activated_at\"=$7, \"created_at\"=$8, \"updated_at\"=$9 from \"users\") p"; + "INSERT INTO \"users\" (\"id\",\"name\",\"email\",\"age\",\"birthday\",\"member_number\",\"activated_at\",\"created_at\",\"updated_at\") VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9) RETURNING \"name_and_number\""; String insertBlogSql = "INSERT INTO \"blogs\" (\"id\",\"name\",\"description\",\"user_id\",\"created_at\",\"updated_at\") VALUES ($1,$2,$3,$4,$5,$6)"; - String describeInsertBlogSql = - "select $1, $2, $3, $4, $5, $6 from (select \"id\"=$1, \"name\"=$2, \"description\"=$3, \"user_id\"=$4, \"created_at\"=$5, \"updated_at\"=$6 from \"blogs\") p"; - mockSpanner.putStatementResult(StatementResult.update(Statement.of(insertUserSql), 0L)); mockSpanner.putStatementResult( - StatementResult.update( + StatementResult.query( Statement.newBuilder(insertUserSql) .bind("p1") .to(1L) @@ -215,25 +211,51 @@ public void testCreateBlogAndUser() { .bind("p9") .to(Timestamp.parseTimestamp("2022-09-09T12:00:00+01:00")) .build(), - 1L)); + ResultSet.newBuilder() + .setMetadata( + ResultSetMetadata.newBuilder() + .setRowType( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .setName("name_and_number") + .build()) + .build()) + .build()) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("User Name null").build()) + .build()) + .setStats(ResultSetStats.newBuilder().setRowCountExact(1L).build()) + .build())); mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeInsertUserSql), + Statement.of(insertUserSql), ResultSet.newBuilder() .setMetadata( - createMetadata( - ImmutableList.of( - TypeCode.INT64, - TypeCode.STRING, - TypeCode.STRING, - TypeCode.INT64, - TypeCode.DATE, - TypeCode.STRING, - TypeCode.TIMESTAMP, - TypeCode.TIMESTAMP, - TypeCode.TIMESTAMP))) + createParameterTypesMetadata( + ImmutableList.of( + TypeCode.INT64, + TypeCode.STRING, + TypeCode.STRING, + TypeCode.INT64, + TypeCode.DATE, + TypeCode.STRING, + TypeCode.TIMESTAMP, + TypeCode.TIMESTAMP, + TypeCode.TIMESTAMP)) + .toBuilder() + .setRowType( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .setName("name_and_number") + .build()) + .build())) + .setStats(ResultSetStats.newBuilder().build()) .build())); - mockSpanner.putStatementResult(StatementResult.update(Statement.of(insertBlogSql), 0L)); mockSpanner.putStatementResult( StatementResult.update( Statement.newBuilder(insertBlogSql) @@ -253,10 +275,10 @@ public void testCreateBlogAndUser() { 1L)); mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeInsertBlogSql), + Statement.of(insertBlogSql), ResultSet.newBuilder() .setMetadata( - createMetadata( + createParameterTypesMetadata( ImmutableList.of( TypeCode.INT64, TypeCode.STRING, @@ -264,26 +286,25 @@ public void testCreateBlogAndUser() { TypeCode.INT64, TypeCode.TIMESTAMP, TypeCode.TIMESTAMP))) + .setStats(ResultSetStats.newBuilder().build()) .build())); String res = gormTest.TestCreateBlogAndUser(createConnString()); assertNull(res); assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); - assertEquals(6, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + assertEquals(4, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); - assertEquals(describeInsertUserSql, requests.get(0).getSql()); + assertEquals(insertUserSql, requests.get(0).getSql()); + assertEquals(QueryMode.PLAN, requests.get(0).getQueryMode()); assertEquals(insertUserSql, requests.get(1).getSql()); - assertEquals(QueryMode.PLAN, requests.get(1).getQueryMode()); - assertEquals(insertUserSql, requests.get(2).getSql()); - assertEquals(QueryMode.NORMAL, requests.get(2).getQueryMode()); - - assertEquals(describeInsertBlogSql, requests.get(3).getSql()); - assertEquals(insertBlogSql, requests.get(4).getSql()); - assertEquals(QueryMode.PLAN, requests.get(4).getQueryMode()); - assertEquals(insertBlogSql, requests.get(5).getSql()); - assertEquals(QueryMode.NORMAL, requests.get(5).getQueryMode()); + assertEquals(QueryMode.NORMAL, requests.get(1).getQueryMode()); + + assertEquals(insertBlogSql, requests.get(2).getSql()); + assertEquals(QueryMode.PLAN, requests.get(2).getQueryMode()); + assertEquals(insertBlogSql, requests.get(3).getSql()); + assertEquals(QueryMode.NORMAL, requests.get(3).getQueryMode()); } @Test @@ -295,9 +316,7 @@ public void testQueryAllDataTypes() { assertNull(res); List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); - // pgx by default always uses prepared statements. As this statement does not contain any - // parameters, we don't need to describe the parameter types, so it is 'only' sent twice to the - // backend. + // pgx by default always uses prepared statements. assertEquals(2, requests.size()); ExecuteSqlRequest describeRequest = requests.get(0); assertEquals(sql, describeRequest.getSql()); @@ -320,9 +339,7 @@ public void testQueryNullsAllDataTypes() { assertNull(res); List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); - // pgx by default always uses prepared statements. As this statement does not contain any - // parameters, we don't need to describe the parameter types, so it is 'only' sent twice to the - // backend. + // pgx by default always uses prepared statements. assertEquals(2, requests.size()); ExecuteSqlRequest describeRequest = requests.get(0); assertEquals(sql, describeRequest.getSql()); @@ -341,15 +358,12 @@ public void testInsertAllDataTypes() { "INSERT INTO \"all_types\" " + "(\"col_bigint\",\"col_bool\",\"col_bytea\",\"col_float8\",\"col_int\",\"col_numeric\",\"col_timestamptz\",\"col_date\",\"col_varchar\") " + "VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9)"; - String describeSql = - "select $1, $2, $3, $4, $5, $6, $7, $8, $9 " - + "from (select \"col_bigint\"=$1, \"col_bool\"=$2, \"col_bytea\"=$3, \"col_float8\"=$4, \"col_int\"=$5, \"col_numeric\"=$6, \"col_timestamptz\"=$7, \"col_date\"=$8, \"col_varchar\"=$9 from \"all_types\") p"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeSql), + Statement.of(sql), ResultSet.newBuilder() .setMetadata( - createMetadata( + createParameterTypesMetadata( ImmutableList.of( TypeCode.INT64, TypeCode.BOOL, @@ -360,8 +374,8 @@ public void testInsertAllDataTypes() { TypeCode.TIMESTAMP, TypeCode.DATE, TypeCode.STRING))) + .setStats(ResultSetStats.newBuilder().build()) .build())); - mockSpanner.putStatementResult(StatementResult.update(Statement.of(sql), 0L)); mockSpanner.putStatementResult( StatementResult.update( Statement.newBuilder(sql) @@ -390,31 +404,23 @@ public void testInsertAllDataTypes() { assertNull(res); List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); - // pgx by default always uses prepared statements. That means that each request is sent three + // pgx by default always uses prepared statements. That means that each request is sent two // times to the backend the first time it is executed: - // 1. DescribeStatement (parameters) - // 2. DescribeStatement (verify validity / PARSE) -- This step could be skipped. - // 3. Execute - assertEquals(3, requests.size()); - ExecuteSqlRequest describeParamsRequest = requests.get(0); - assertEquals(describeSql, describeParamsRequest.getSql()); - assertEquals(QueryMode.PLAN, describeParamsRequest.getQueryMode()); - assertTrue(describeParamsRequest.hasTransaction()); - assertTrue(describeParamsRequest.getTransaction().hasBegin()); - assertTrue(describeParamsRequest.getTransaction().getBegin().hasReadWrite()); - - ExecuteSqlRequest describeRequest = requests.get(1); + // 1. DescribeStatement + // 2. Execute + assertEquals(2, requests.size()); + ExecuteSqlRequest describeRequest = requests.get(0); assertEquals(sql, describeRequest.getSql()); assertEquals(QueryMode.PLAN, describeRequest.getQueryMode()); assertTrue(describeRequest.hasTransaction()); - assertTrue(describeRequest.getTransaction().hasId()); + assertTrue(describeRequest.getTransaction().hasBegin()); + assertTrue(describeRequest.getTransaction().getBegin().hasReadWrite()); - ExecuteSqlRequest executeRequest = requests.get(2); + ExecuteSqlRequest executeRequest = requests.get(1); assertEquals(sql, executeRequest.getSql()); assertEquals(QueryMode.NORMAL, executeRequest.getQueryMode()); assertTrue(executeRequest.hasTransaction()); assertTrue(executeRequest.getTransaction().hasId()); - assertEquals(describeRequest.getTransaction().getId(), executeRequest.getTransaction().getId()); assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); CommitRequest commitRequest = mockSpanner.getRequestsOfType(CommitRequest.class).get(0); @@ -427,15 +433,12 @@ public void testInsertNullsAllDataTypes() { "INSERT INTO \"all_types\" " + "(\"col_bigint\",\"col_bool\",\"col_bytea\",\"col_float8\",\"col_int\",\"col_numeric\",\"col_timestamptz\",\"col_date\",\"col_varchar\") " + "VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9)"; - String describeSql = - "select $1, $2, $3, $4, $5, $6, $7, $8, $9 " - + "from (select \"col_bigint\"=$1, \"col_bool\"=$2, \"col_bytea\"=$3, \"col_float8\"=$4, \"col_int\"=$5, \"col_numeric\"=$6, \"col_timestamptz\"=$7, \"col_date\"=$8, \"col_varchar\"=$9 from \"all_types\") p"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeSql), + Statement.of(sql), ResultSet.newBuilder() .setMetadata( - createMetadata( + createParameterTypesMetadata( ImmutableList.of( TypeCode.INT64, TypeCode.BOOL, @@ -446,8 +449,8 @@ public void testInsertNullsAllDataTypes() { TypeCode.TIMESTAMP, TypeCode.DATE, TypeCode.STRING))) + .setStats(ResultSetStats.newBuilder().build()) .build())); - mockSpanner.putStatementResult(StatementResult.update(Statement.of(sql), 0L)); mockSpanner.putStatementResult( StatementResult.update( Statement.newBuilder(sql) @@ -476,19 +479,15 @@ public void testInsertNullsAllDataTypes() { assertNull(res); List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); - // pgx by default always uses prepared statements. That means that each request is sent three + // pgx by default always uses prepared statements. That means that each request is sent two // times to the backend the first time it is executed: - // 1. DescribeStatement (parameters) - // 2. DescribeStatement (verify validity / PARSE) -- This step could be skipped. - // 3. Execute - assertEquals(3, requests.size()); - ExecuteSqlRequest describeParamsRequest = requests.get(0); - assertEquals(describeSql, describeParamsRequest.getSql()); - assertEquals(QueryMode.PLAN, describeParamsRequest.getQueryMode()); - ExecuteSqlRequest describeRequest = requests.get(1); + // 1. DescribeStatement + // 2. Execute + assertEquals(2, requests.size()); + ExecuteSqlRequest describeRequest = requests.get(0); assertEquals(sql, describeRequest.getSql()); assertEquals(QueryMode.PLAN, describeRequest.getQueryMode()); - ExecuteSqlRequest executeRequest = requests.get(2); + ExecuteSqlRequest executeRequest = requests.get(1); assertEquals(sql, executeRequest.getSql()); assertEquals(QueryMode.NORMAL, executeRequest.getQueryMode()); } @@ -497,15 +496,12 @@ public void testInsertNullsAllDataTypes() { public void testUpdateAllDataTypes() { String sql = "UPDATE \"all_types\" SET \"col_bigint\"=$1,\"col_bool\"=$2,\"col_bytea\"=$3,\"col_float8\"=$4,\"col_int\"=$5,\"col_numeric\"=$6,\"col_timestamptz\"=$7,\"col_date\"=$8,\"col_varchar\"=$9 WHERE \"col_varchar\" = $10"; - String describeSql = - "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10 from " - + "(select \"col_bigint\"=$1, \"col_bool\"=$2, \"col_bytea\"=$3, \"col_float8\"=$4, \"col_int\"=$5, \"col_numeric\"=$6, \"col_timestamptz\"=$7, \"col_date\"=$8, \"col_varchar\"=$9 from \"all_types\" WHERE \"col_varchar\" = $10) p"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeSql), + Statement.of(sql), ResultSet.newBuilder() .setMetadata( - createMetadata( + createParameterTypesMetadata( ImmutableList.of( TypeCode.INT64, TypeCode.BOOL, @@ -517,8 +513,8 @@ public void testUpdateAllDataTypes() { TypeCode.DATE, TypeCode.STRING, TypeCode.STRING))) + .setStats(ResultSetStats.newBuilder().build()) .build())); - mockSpanner.putStatementResult(StatementResult.update(Statement.of(sql), 0L)); mockSpanner.putStatementResult( StatementResult.update( Statement.newBuilder(sql) @@ -549,19 +545,15 @@ public void testUpdateAllDataTypes() { assertNull(res); List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); - // pgx by default always uses prepared statements. That means that each request is sent three + // pgx by default always uses prepared statements. That means that each request is sent two // times to the backend the first time it is executed: - // 1. DescribeStatement (parameters) - // 2. DescribeStatement (verify validity / PARSE) -- This step could be skipped. - // 3. Execute - assertEquals(3, requests.size()); - ExecuteSqlRequest describeParamsRequest = requests.get(0); - assertEquals(describeSql, describeParamsRequest.getSql()); - assertEquals(QueryMode.PLAN, describeParamsRequest.getQueryMode()); - ExecuteSqlRequest describeRequest = requests.get(1); + // 1. DescribeStatement + // 2. Execute + assertEquals(2, requests.size()); + ExecuteSqlRequest describeRequest = requests.get(0); assertEquals(sql, describeRequest.getSql()); assertEquals(QueryMode.PLAN, describeRequest.getQueryMode()); - ExecuteSqlRequest executeRequest = requests.get(2); + ExecuteSqlRequest executeRequest = requests.get(1); assertEquals(sql, executeRequest.getSql()); assertEquals(QueryMode.NORMAL, executeRequest.getQueryMode()); } @@ -569,15 +561,13 @@ public void testUpdateAllDataTypes() { @Test public void testDelete() { String sql = "DELETE FROM \"all_types\" WHERE \"all_types\".\"col_varchar\" = $1"; - String describeSql = - "select $1 from (select 1 from \"all_types\" WHERE \"all_types\".\"col_varchar\" = $1) p"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeSql), + Statement.of(sql), ResultSet.newBuilder() - .setMetadata(createMetadata(ImmutableList.of(TypeCode.STRING))) + .setMetadata(createParameterTypesMetadata(ImmutableList.of(TypeCode.STRING))) + .setStats(ResultSetStats.newBuilder().build()) .build())); - mockSpanner.putStatementResult(StatementResult.update(Statement.of(sql), 0L)); mockSpanner.putStatementResult( StatementResult.update(Statement.newBuilder(sql).bind("p1").to("test_string").build(), 1L)); @@ -587,17 +577,13 @@ public void testDelete() { List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); // pgx by default always uses prepared statements. That means that each request is sent three // times to the backend the first time it is executed: - // 1. DescribeStatement (parameters) - // 2. DescribeStatement (verify validity / PARSE) -- This step could be skipped. - // 3. Execute - assertEquals(3, requests.size()); - ExecuteSqlRequest describeParamsRequest = requests.get(0); - assertEquals(describeSql, describeParamsRequest.getSql()); - assertEquals(QueryMode.PLAN, describeParamsRequest.getQueryMode()); - ExecuteSqlRequest describeRequest = requests.get(1); + // 1. DescribeStatement + // 2. Execute + assertEquals(2, requests.size()); + ExecuteSqlRequest describeRequest = requests.get(0); assertEquals(sql, describeRequest.getSql()); assertEquals(QueryMode.PLAN, describeRequest.getQueryMode()); - ExecuteSqlRequest executeRequest = requests.get(2); + ExecuteSqlRequest executeRequest = requests.get(1); assertEquals(sql, executeRequest.getSql()); assertEquals(QueryMode.NORMAL, executeRequest.getQueryMode()); } @@ -610,15 +596,12 @@ public void testCreateInBatches() { + "($1,$2,$3,$4,$5,$6,$7,$8,$9)," + "($10,$11,$12,$13,$14,$15,$16,$17,$18)," + "($19,$20,$21,$22,$23,$24,$25,$26,$27)"; - String describeSql = - "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27 from " - + "(select \"col_bigint\"=$1, \"col_bool\"=$2, \"col_bytea\"=$3, \"col_float8\"=$4, \"col_int\"=$5, \"col_numeric\"=$6, \"col_timestamptz\"=$7, \"col_date\"=$8, \"col_varchar\"=$9, \"col_bigint\"=$10, \"col_bool\"=$11, \"col_bytea\"=$12, \"col_float8\"=$13, \"col_int\"=$14, \"col_numeric\"=$15, \"col_timestamptz\"=$16, \"col_date\"=$17, \"col_varchar\"=$18, \"col_bigint\"=$19, \"col_bool\"=$20, \"col_bytea\"=$21, \"col_float8\"=$22, \"col_int\"=$23, \"col_numeric\"=$24, \"col_timestamptz\"=$25, \"col_date\"=$26, \"col_varchar\"=$27 from \"all_types\") p"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeSql), + Statement.of(sql), ResultSet.newBuilder() .setMetadata( - createMetadata( + createParameterTypesMetadata( ImmutableList.of( TypeCode.INT64, TypeCode.BOOL, @@ -647,8 +630,8 @@ public void testCreateInBatches() { TypeCode.TIMESTAMP, TypeCode.DATE, TypeCode.STRING))) + .setStats(ResultSetStats.newBuilder().build()) .build())); - mockSpanner.putStatementResult(StatementResult.update(Statement.of(sql), 0L)); mockSpanner.putStatementResult( StatementResult.update( Statement.newBuilder(sql) @@ -713,19 +696,15 @@ public void testCreateInBatches() { assertNull(res); List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); - // pgx by default always uses prepared statements. That means that each request is sent three + // pgx by default always uses prepared statements. That means that each request is sent two // times to the backend the first time it is executed: - // 1. DescribeStatement (parameters) - // 2. DescribeStatement (verify validity / PARSE) -- This step could be skipped. - // 3. Execute - assertEquals(3, requests.size()); - ExecuteSqlRequest describeParamsRequest = requests.get(0); - assertEquals(describeSql, describeParamsRequest.getSql()); - assertEquals(QueryMode.PLAN, describeParamsRequest.getQueryMode()); - ExecuteSqlRequest describeRequest = requests.get(1); + // 1. DescribeStatement + // 2. Execute + assertEquals(2, requests.size()); + ExecuteSqlRequest describeRequest = requests.get(0); assertEquals(sql, describeRequest.getSql()); assertEquals(QueryMode.PLAN, describeRequest.getQueryMode()); - ExecuteSqlRequest executeRequest = requests.get(2); + ExecuteSqlRequest executeRequest = requests.get(1); assertEquals(sql, executeRequest.getSql()); assertEquals(QueryMode.NORMAL, executeRequest.getQueryMode()); } @@ -733,14 +712,13 @@ public void testCreateInBatches() { @Test public void testTransaction() { String sql = "INSERT INTO \"all_types\" (\"col_varchar\") VALUES ($1)"; - String describeSql = "select $1 from (select \"col_varchar\"=$1 from \"all_types\") p"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeSql), + Statement.of(sql), ResultSet.newBuilder() - .setMetadata(createMetadata(ImmutableList.of(TypeCode.STRING))) + .setMetadata(createParameterTypesMetadata(ImmutableList.of(TypeCode.STRING))) + .setStats(ResultSetStats.newBuilder().build()) .build())); - mockSpanner.putStatementResult(StatementResult.update(Statement.of(sql), 0L)); mockSpanner.putStatementResult( StatementResult.update(Statement.newBuilder(sql).bind("p1").to("1").build(), 1L)); mockSpanner.putStatementResult( @@ -750,35 +728,25 @@ public void testTransaction() { assertNull(res); List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); - assertEquals(4, requests.size()); - ExecuteSqlRequest describeParamsRequest = requests.get(0); - assertEquals(describeSql, describeParamsRequest.getSql()); - assertEquals(QueryMode.PLAN, describeParamsRequest.getQueryMode()); - assertTrue(describeParamsRequest.hasTransaction()); - assertTrue(describeParamsRequest.getTransaction().hasBegin()); - assertTrue(describeParamsRequest.getTransaction().getBegin().hasReadWrite()); - - ExecuteSqlRequest describeRequest = requests.get(1); + assertEquals(3, requests.size()); + ExecuteSqlRequest describeRequest = requests.get(0); assertEquals(sql, describeRequest.getSql()); assertEquals(QueryMode.PLAN, describeRequest.getQueryMode()); assertTrue(describeRequest.hasTransaction()); - assertTrue(describeRequest.getTransaction().hasId()); + assertTrue(describeRequest.getTransaction().hasBegin()); + assertTrue(describeRequest.getTransaction().getBegin().hasReadWrite()); - ExecuteSqlRequest executeRequest1 = requests.get(2); + ExecuteSqlRequest executeRequest1 = requests.get(1); assertEquals(sql, executeRequest1.getSql()); assertEquals(QueryMode.NORMAL, executeRequest1.getQueryMode()); assertTrue(executeRequest1.hasTransaction()); assertTrue(executeRequest1.getTransaction().hasId()); - assertEquals( - describeRequest.getTransaction().getId(), executeRequest1.getTransaction().getId()); ExecuteSqlRequest executeRequest2 = requests.get(2); assertEquals(sql, executeRequest2.getSql()); assertEquals(QueryMode.NORMAL, executeRequest2.getQueryMode()); assertTrue(executeRequest2.hasTransaction()); assertTrue(executeRequest2.getTransaction().hasId()); - assertEquals( - describeRequest.getTransaction().getId(), executeRequest2.getTransaction().getId()); assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); CommitRequest commitRequest = mockSpanner.getRequestsOfType(CommitRequest.class).get(0); @@ -788,14 +756,13 @@ public void testTransaction() { @Test public void testNestedTransaction() { String sql = "INSERT INTO \"all_types\" (\"col_varchar\") VALUES ($1)"; - String describeSql = "select $1 from (select \"col_varchar\"=$1 from \"all_types\") p"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeSql), + Statement.of(sql), ResultSet.newBuilder() - .setMetadata(createMetadata(ImmutableList.of(TypeCode.STRING))) + .setMetadata(createParameterTypesMetadata(ImmutableList.of(TypeCode.STRING))) + .setStats(ResultSetStats.newBuilder().build()) .build())); - mockSpanner.putStatementResult(StatementResult.update(Statement.of(sql), 0L)); mockSpanner.putStatementResult( StatementResult.update(Statement.newBuilder(sql).bind("p1").to("1").build(), 1L)); mockSpanner.putStatementResult( @@ -804,7 +771,7 @@ public void testNestedTransaction() { // Nested transactions are not supported, as we don't support savepoints. String res = gormTest.TestNestedTransaction(createConnString()); assertEquals( - "failed to execute nested transaction: ERROR: current transaction is aborted, commands ignored until end of transaction block (SQLSTATE P0001)", + "failed to execute nested transaction: ERROR: current transaction is aborted, commands ignored until end of transaction block (SQLSTATE 25P02)", res); assertEquals(0, mockSpanner.countRequestsOfType(CommitRequest.class)); } @@ -812,58 +779,59 @@ public void testNestedTransaction() { @Test public void testErrorInTransaction() { String insertSql = "INSERT INTO \"all_types\" (\"col_varchar\") VALUES ($1)"; - String describeInsertSql = "select $1 from (select \"col_varchar\"=$1 from \"all_types\") p"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeInsertSql), + Statement.of(insertSql), ResultSet.newBuilder() - .setMetadata(createMetadata(ImmutableList.of(TypeCode.STRING))) + .setMetadata(createParameterTypesMetadata(ImmutableList.of(TypeCode.STRING))) + .setStats(ResultSetStats.newBuilder().build()) .build())); - mockSpanner.putStatementResult(StatementResult.update(Statement.of(insertSql), 0L)); mockSpanner.putStatementResult( StatementResult.exception( Statement.newBuilder(insertSql).bind("p1").to("1").build(), Status.ALREADY_EXISTS.withDescription("Row [1] already exists").asRuntimeException())); String updateSql = "UPDATE \"all_types\" SET \"col_int\"=$1 WHERE \"col_varchar\" = $2"; - String describeUpdateSql = - "select $1, $2 from (select \"col_int\"=$1 from \"all_types\" WHERE \"col_varchar\" = $2) p"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeUpdateSql), + Statement.of(updateSql), ResultSet.newBuilder() - .setMetadata(createMetadata(ImmutableList.of(TypeCode.INT64, TypeCode.STRING))) + .setMetadata( + createParameterTypesMetadata(ImmutableList.of(TypeCode.INT64, TypeCode.STRING))) + .setStats(ResultSetStats.newBuilder().build()) .build())); - mockSpanner.putStatementResult(StatementResult.update(Statement.of(updateSql), 0L)); mockSpanner.putStatementResult( StatementResult.update( Statement.newBuilder(updateSql).bind("p1").to(100L).bind("p2").to("1").build(), 1L)); String res = gormTest.TestErrorInTransaction(createConnString()); assertEquals( - "failed to execute transaction: ERROR: current transaction is aborted, commands ignored until end of transaction block (SQLSTATE P0001)", + "failed to execute transaction: ERROR: current transaction is aborted, commands ignored until end of transaction block (SQLSTATE 25P02)", res); assertEquals(1, mockSpanner.countRequestsOfType(RollbackRequest.class)); - // This test also leads to 1 commit request. The reason for this is that the update statement is - // also described when gorm tries to execute it. At that point, there is no read/write - // transaction anymore on the underlying Spanner connection, as that transaction was rolled back - // when the insert statement failed. It is therefore executed using auto-commit, which again - // automatically leads to a commit. This is not a problem, as it is just an analyze of an update - // statement. - assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + assertEquals(0, mockSpanner.countRequestsOfType(CommitRequest.class)); } @Test public void testReadOnlyTransaction() { String sql = "SELECT * FROM \"all_types\" WHERE \"all_types\".\"col_varchar\" = $1"; - String describeSql = - "select $1 from (SELECT * FROM \"all_types\" WHERE \"all_types\".\"col_varchar\" = $1) p"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeSql), + Statement.of(sql), ResultSet.newBuilder() - .setMetadata(createMetadata(ImmutableList.of(TypeCode.STRING))) + .setMetadata( + ALL_TYPES_METADATA + .toBuilder() + .setUndeclaredParameters( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("p1") + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .build()) + .build()) + .build()) + .setStats(ResultSetStats.newBuilder().build()) .build())); - mockSpanner.putStatementResult(StatementResult.query(Statement.of(sql), ALL_TYPES_RESULTSET)); mockSpanner.putStatementResult( StatementResult.query( Statement.newBuilder(sql).bind("p1").to("1").build(), ALL_TYPES_RESULTSET)); @@ -880,24 +848,17 @@ public void testReadOnlyTransaction() { assertTrue(beginRequest.getOptions().hasReadOnly()); List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); - assertEquals(4, requests.size()); + assertEquals(3, requests.size()); ExecuteSqlRequest describeRequest = requests.get(0); assertEquals(sql, describeRequest.getSql()); assertEquals(QueryMode.PLAN, describeRequest.getQueryMode()); assertTrue(describeRequest.getTransaction().hasId()); - ExecuteSqlRequest describeParamsRequest = requests.get(1); - assertEquals(describeSql, describeParamsRequest.getSql()); - assertEquals(QueryMode.PLAN, describeParamsRequest.getQueryMode()); - assertEquals( - describeParamsRequest.getTransaction().getId(), describeRequest.getTransaction().getId()); - - ExecuteSqlRequest executeRequest = requests.get(2); + ExecuteSqlRequest executeRequest = requests.get(1); assertEquals(sql, executeRequest.getSql()); assertEquals(QueryMode.NORMAL, executeRequest.getQueryMode()); - assertEquals( - describeParamsRequest.getTransaction().getId(), executeRequest.getTransaction().getId()); + assertEquals(describeRequest.getTransaction().getId(), executeRequest.getTransaction().getId()); // The read-only transaction is 'committed', but that does not cause a CommitRequest to be sent // to Cloud Spanner. diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxMockServerTest.java index 236fdfc7b..410fd2075 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxMockServerTest.java @@ -45,6 +45,7 @@ import com.google.spanner.v1.Mutation.OperationCase; import com.google.spanner.v1.ResultSet; import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.ResultSetStats; import com.google.spanner.v1.RollbackRequest; import com.google.spanner.v1.StructType; import com.google.spanner.v1.StructType.Field; @@ -56,7 +57,6 @@ import java.util.List; import java.util.stream.Collectors; import org.junit.BeforeClass; -import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.experimental.categories.Category; @@ -76,7 +76,7 @@ public class PgxMockServerTest extends AbstractMockServerTest { private static PgxTest pgxTest; - @Rule public Timeout globalTimeout = Timeout.seconds(30L); + @Rule public Timeout globalTimeout = Timeout.seconds(300L); @Parameter public boolean useDomainSocket; @@ -176,19 +176,23 @@ public void testQueryWithParameter() { .build(); // Add a query result for the statement parameter types. - String selectParamsSql = "select $1 from (SELECT * FROM FOO WHERE BAR=$1) p"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(selectParamsSql), + Statement.of(sql), ResultSet.newBuilder() - .setMetadata(createMetadata(ImmutableList.of(TypeCode.STRING))) + .setMetadata( + metadata + .toBuilder() + .setUndeclaredParameters( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("p1") + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .build()) + .build())) .build())); - - // Add a query result with only the metadata for the query without parameter values. - mockSpanner.putStatementResult( - StatementResult.query( - Statement.of(sql), ResultSet.newBuilder().setMetadata(metadata).build())); - // Also add a query result with both metadata and rows for the statement with parameter values. + // Add a query result with both metadata and rows for the statement with parameter values. mockSpanner.putStatementResult( StatementResult.query( statement, @@ -204,19 +208,15 @@ public void testQueryWithParameter() { assertNull(res); List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); - // pgx by default always uses prepared statements. That means that each request is sent three + // pgx by default always uses prepared statements. That means that each request is sent two // times to the backend: - // 1. DescribeStatement (results) - // 2. DescribeStatement (parameters) - // 3. Execute (including DescribePortal) - assertEquals(3, requests.size()); + // 1. DescribeStatement (parameters + results) + // 2. Execute (including DescribePortal) + assertEquals(2, requests.size()); ExecuteSqlRequest describeStatementRequest = requests.get(0); assertEquals(sql, describeStatementRequest.getSql()); assertEquals(QueryMode.PLAN, describeStatementRequest.getQueryMode()); - ExecuteSqlRequest describeParametersRequest = requests.get(1); - assertEquals(selectParamsSql, describeParametersRequest.getSql()); - assertEquals(QueryMode.PLAN, describeParametersRequest.getQueryMode()); - ExecuteSqlRequest executeRequest = requests.get(2); + ExecuteSqlRequest executeRequest = requests.get(1); assertEquals(sql, executeRequest.getSql()); assertEquals(QueryMode.NORMAL, executeRequest.getQueryMode()); } @@ -268,15 +268,12 @@ public void testQueryAllDataTypes() { public void testUpdateAllDataTypes() { String sql = "UPDATE \"all_types\" SET \"col_bigint\"=$1,\"col_bool\"=$2,\"col_bytea\"=$3,\"col_float8\"=$4,\"col_int\"=$5,\"col_numeric\"=$6,\"col_timestamptz\"=$7,\"col_date\"=$8,\"col_varchar\"=$9,\"col_jsonb\"=$10 WHERE \"col_varchar\" = $11"; - String describeSql = - "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11 from " - + "(select \"col_bigint\"=$1, \"col_bool\"=$2, \"col_bytea\"=$3, \"col_float8\"=$4, \"col_int\"=$5, \"col_numeric\"=$6, \"col_timestamptz\"=$7, \"col_date\"=$8, \"col_varchar\"=$9, \"col_jsonb\"=$10 from \"all_types\" WHERE \"col_varchar\" = $11) p"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeSql), + Statement.of(sql), ResultSet.newBuilder() .setMetadata( - createMetadata( + createParameterTypesMetadata( ImmutableList.of( TypeCode.INT64, TypeCode.BOOL, @@ -289,8 +286,8 @@ public void testUpdateAllDataTypes() { TypeCode.STRING, TypeCode.STRING, TypeCode.STRING))) + .setStats(ResultSetStats.newBuilder().build()) .build())); - mockSpanner.putStatementResult(StatementResult.update(Statement.of(sql), 0L)); mockSpanner.putStatementResult( StatementResult.update( Statement.newBuilder(sql) @@ -323,19 +320,15 @@ public void testUpdateAllDataTypes() { assertNull(res); List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); - // pgx by default always uses prepared statements. That means that each request is sent three + // pgx by default always uses prepared statements. That means that each request is sent two // times to the backend the first time it is executed: // 1. DescribeStatement (parameters) - // 2. DescribeStatement (verify validity / PARSE) -- This step could be skipped. - // 3. Execute - assertEquals(3, requests.size()); + // 2. Execute + assertEquals(2, requests.size()); ExecuteSqlRequest describeParamsRequest = requests.get(0); - assertEquals(describeSql, describeParamsRequest.getSql()); + assertEquals(sql, describeParamsRequest.getSql()); assertEquals(QueryMode.PLAN, describeParamsRequest.getQueryMode()); - ExecuteSqlRequest describeRequest = requests.get(1); - assertEquals(sql, describeRequest.getSql()); - assertEquals(QueryMode.PLAN, describeRequest.getQueryMode()); - ExecuteSqlRequest executeRequest = requests.get(2); + ExecuteSqlRequest executeRequest = requests.get(1); assertEquals(sql, executeRequest.getSql()); assertEquals(QueryMode.NORMAL, executeRequest.getQueryMode()); } @@ -346,14 +339,12 @@ public void testInsertAllDataTypes() { "INSERT INTO all_types " + "(col_bigint, col_bool, col_bytea, col_float8, col_int, col_numeric, col_timestamptz, col_date, col_varchar, col_jsonb) " + "values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"; - String describeSql = - "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10 from (select col_bigint=$1, col_bool=$2, col_bytea=$3, col_float8=$4, col_int=$5, col_numeric=$6, col_timestamptz=$7, col_date=$8, col_varchar=$9, col_jsonb=$10 from all_types) p"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeSql), + Statement.of(sql), ResultSet.newBuilder() .setMetadata( - createMetadata( + createParameterTypesMetadata( ImmutableList.of( TypeCode.INT64, TypeCode.BOOL, @@ -365,8 +356,8 @@ public void testInsertAllDataTypes() { TypeCode.DATE, TypeCode.STRING, TypeCode.STRING))) + .setStats(ResultSetStats.newBuilder().build()) .build())); - mockSpanner.putStatementResult(StatementResult.update(Statement.of(sql), 0L)); mockSpanner.putStatementResult( StatementResult.update( Statement.newBuilder(sql) @@ -397,19 +388,107 @@ public void testInsertAllDataTypes() { assertNull(res); List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); - // pgx by default always uses prepared statements. That means that each request is sent three + // pgx by default always uses prepared statements. That means that each request is sent two // times to the backend the first time it is executed: // 1. DescribeStatement (parameters) - // 2. DescribeStatement (verify validity / PARSE) -- This step could be skipped. - // 3. Execute - assertEquals(3, requests.size()); + // 2. Execute + assertEquals(2, requests.size()); ExecuteSqlRequest describeParamsRequest = requests.get(0); - assertEquals(describeSql, describeParamsRequest.getSql()); + assertEquals(sql, describeParamsRequest.getSql()); assertEquals(QueryMode.PLAN, describeParamsRequest.getQueryMode()); - ExecuteSqlRequest describeRequest = requests.get(1); - assertEquals(sql, describeRequest.getSql()); - assertEquals(QueryMode.PLAN, describeRequest.getQueryMode()); - ExecuteSqlRequest executeRequest = requests.get(2); + ExecuteSqlRequest executeRequest = requests.get(1); + assertEquals(sql, executeRequest.getSql()); + assertEquals(QueryMode.NORMAL, executeRequest.getQueryMode()); + } + + @Test + public void testInsertAllDataTypesReturning() { + String sql = + "INSERT INTO all_types " + + "(col_bigint, col_bool, col_bytea, col_float8, col_int, col_numeric, col_timestamptz, col_date, col_varchar, col_jsonb) " + + "values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) returning *"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(sql), + ResultSet.newBuilder() + .setMetadata( + ALL_TYPES_METADATA + .toBuilder() + .setUndeclaredParameters( + createParameterTypesMetadata( + ImmutableList.of( + TypeCode.INT64, + TypeCode.BOOL, + TypeCode.BYTES, + TypeCode.FLOAT64, + TypeCode.INT64, + TypeCode.NUMERIC, + TypeCode.TIMESTAMP, + TypeCode.DATE, + TypeCode.STRING, + TypeCode.STRING)) + .getUndeclaredParameters())) + .setStats(ResultSetStats.newBuilder().build()) + .build())); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.newBuilder(sql) + .bind("p1") + .to(100L) + .bind("p2") + .to(true) + .bind("p3") + .to(ByteArray.copyFrom("test_bytes")) + .bind("p4") + .to(3.14d) + .bind("p5") + .to(1L) + .bind("p6") + .to(com.google.cloud.spanner.Value.pgNumeric("6.626")) + .bind("p7") + .to(Timestamp.parseTimestamp("2022-03-24T06:39:10.123456000Z")) + .bind("p8") + .to(Date.parseDate("2022-04-02")) + .bind("p9") + .to("test_string") + .bind("p10") + .to("{\"key\": \"value\"}") + .build(), + ResultSet.newBuilder() + .setMetadata( + ALL_TYPES_METADATA + .toBuilder() + .setUndeclaredParameters( + createParameterTypesMetadata( + ImmutableList.of( + TypeCode.INT64, + TypeCode.BOOL, + TypeCode.BYTES, + TypeCode.FLOAT64, + TypeCode.INT64, + TypeCode.NUMERIC, + TypeCode.TIMESTAMP, + TypeCode.DATE, + TypeCode.STRING, + TypeCode.STRING)) + .getUndeclaredParameters())) + .setStats(ResultSetStats.newBuilder().setRowCountExact(1L).build()) + .addRows(ALL_TYPES_RESULTSET.getRows(0)) + .build())); + + String res = pgxTest.TestInsertAllDataTypesReturning(createConnString()); + + assertNull(res); + List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); + // pgx by default always uses prepared statements. That means that each request is sent two + // times to the backend the first time it is executed: + // 1. DescribeStatement (parameters) + // 2. Execute + assertEquals(2, requests.size()); + ExecuteSqlRequest describeParamsRequest = requests.get(0); + assertEquals(sql, describeParamsRequest.getSql()); + assertEquals(QueryMode.PLAN, describeParamsRequest.getQueryMode()); + ExecuteSqlRequest executeRequest = requests.get(1); assertEquals(sql, executeRequest.getSql()); assertEquals(QueryMode.NORMAL, executeRequest.getQueryMode()); } @@ -420,14 +499,12 @@ public void testInsertBatch() { "INSERT INTO all_types " + "(col_bigint, col_bool, col_bytea, col_float8, col_int, col_numeric, col_timestamptz, col_date, col_varchar, col_jsonb) " + "values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"; - String describeSql = - "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10 from (select col_bigint=$1, col_bool=$2, col_bytea=$3, col_float8=$4, col_int=$5, col_numeric=$6, col_timestamptz=$7, col_date=$8, col_varchar=$9, col_jsonb=$10 from all_types) p"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeSql), + Statement.of(sql), ResultSet.newBuilder() .setMetadata( - createMetadata( + createParameterTypesMetadata( ImmutableList.of( TypeCode.INT64, TypeCode.BOOL, @@ -439,8 +516,8 @@ public void testInsertBatch() { TypeCode.DATE, TypeCode.STRING, TypeCode.STRING))) + .setStats(ResultSetStats.newBuilder().build()) .build())); - mockSpanner.putStatementResult(StatementResult.update(Statement.of(sql), 0L)); int batchSize = 10; for (int i = 0; i < batchSize; i++) { mockSpanner.putStatementResult( @@ -476,24 +553,17 @@ public void testInsertBatch() { assertNull(res); List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); - // pgx by default always uses prepared statements. That means that each request is sent 2 times - // to the backend to be described, and then `batchSize` times to be executed. + // pgx by default always uses prepared statements. That means that each request is sent once + // to the backend to be described, and then `batchSize` times to be executed (which is sent as + // one BatchDML request). // 1. DescribeStatement (parameters) - // 2. DescribeStatement (verify validity / PARSE) -- This step could be skipped. - // 3. Execute 10 times. - assertEquals(2, requests.size()); + // 2. Execute 10 times. + assertEquals(1, requests.size()); ExecuteSqlRequest describeParamsRequest = requests.get(0); - assertEquals(describeSql, describeParamsRequest.getSql()); + assertEquals(sql, describeParamsRequest.getSql()); assertEquals(QueryMode.PLAN, describeParamsRequest.getQueryMode()); - // The 'describe' query for the parameters will be executed as a single use transaction. - assertTrue(describeParamsRequest.getTransaction().hasSingleUse()); - - // The analyzeUpdate that is executed to verify the validity of the DML statement is executed as - // a separate transaction. - ExecuteSqlRequest describeRequest = requests.get(1); - assertEquals(sql, describeRequest.getSql()); - assertEquals(QueryMode.PLAN, describeRequest.getQueryMode()); - assertTrue(describeRequest.getTransaction().hasBegin()); + assertTrue(describeParamsRequest.getTransaction().hasBegin()); + assertTrue(describeParamsRequest.getTransaction().getBegin().hasReadWrite()); assertEquals(1, mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class)); ExecuteBatchDmlRequest batchDmlRequest = @@ -501,10 +571,10 @@ public void testInsertBatch() { assertEquals(batchSize, batchDmlRequest.getStatementsCount()); assertTrue(batchDmlRequest.getTransaction().hasBegin()); - // There are two commit requests, as the 'Describe statement' message is executed as a separate - // transaction. - List commitRequests = mockSpanner.getRequestsOfType(CommitRequest.class); - assertEquals(2, commitRequests.size()); + // There are two commit requests: + // 1. One for the analyzeUpdate to describe the insert statement. + // 2. One for the actual batch. + assertEquals(2, mockSpanner.countRequestsOfType(CommitRequest.class)); } @Test @@ -513,14 +583,12 @@ public void testMixedBatch() { "INSERT INTO all_types " + "(col_bigint, col_bool, col_bytea, col_float8, col_int, col_numeric, col_timestamptz, col_date, col_varchar, col_jsonb) " + "values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"; - String describeInsertSql = - "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10 from (select col_bigint=$1, col_bool=$2, col_bytea=$3, col_float8=$4, col_int=$5, col_numeric=$6, col_timestamptz=$7, col_date=$8, col_varchar=$9, col_jsonb=$10 from all_types) p"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeInsertSql), + Statement.of(insertSql), ResultSet.newBuilder() .setMetadata( - createMetadata( + createParameterTypesMetadata( ImmutableList.of( TypeCode.INT64, TypeCode.BOOL, @@ -532,54 +600,60 @@ public void testMixedBatch() { TypeCode.DATE, TypeCode.STRING, TypeCode.STRING))) + .setStats(ResultSetStats.newBuilder().build()) .build())); - mockSpanner.putStatementResult(StatementResult.update(Statement.of(insertSql), 0L)); String selectSql = "select count(*) from all_types where col_bool=$1"; - ResultSet resultSet = - ResultSet.newBuilder() - .setMetadata( - ResultSetMetadata.newBuilder() - .setRowType( - StructType.newBuilder() - .addFields( - Field.newBuilder() - .setName("c") - .setType(Type.newBuilder().setCode(TypeCode.INT64).build()) - .build()) + ResultSetMetadata metadata = + ResultSetMetadata.newBuilder() + .setRowType( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("c") + .setType(Type.newBuilder().setCode(TypeCode.INT64).build()) .build()) .build()) + .build(); + ResultSet resultSet = + ResultSet.newBuilder() + .setMetadata(metadata) .addRows( ListValue.newBuilder() .addValues(Value.newBuilder().setStringValue("3").build()) .build()) .build(); - mockSpanner.putStatementResult(StatementResult.query(Statement.of(selectSql), resultSet)); mockSpanner.putStatementResult( StatementResult.query( - Statement.newBuilder(selectSql).bind("p1").to(true).build(), resultSet)); - - String describeParamsSelectSql = - "select $1 from (select count(*) from all_types where col_bool=$1) p"; - mockSpanner.putStatementResult( - StatementResult.query( - Statement.of(describeParamsSelectSql), + Statement.of(selectSql), ResultSet.newBuilder() - .setMetadata(createMetadata(ImmutableList.of(TypeCode.BOOL))) + .setMetadata( + metadata + .toBuilder() + .setUndeclaredParameters( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("p1") + .setType(Type.newBuilder().setCode(TypeCode.BOOL).build()) + .build()) + .build()) + .build()) .build())); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.newBuilder(selectSql).bind("p1").to(true).build(), resultSet)); String updateSql = "update all_types set col_bool=false where col_bool=$1"; - mockSpanner.putStatementResult(StatementResult.update(Statement.of(updateSql), 0L)); - mockSpanner.putStatementResult( - StatementResult.update(Statement.newBuilder(updateSql).bind("p1").to(true).build(), 3L)); - String describeUpdateSql = - "select $1 from (select col_bool=false from all_types where col_bool=$1) p"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeUpdateSql), + Statement.of(updateSql), ResultSet.newBuilder() - .setMetadata(createMetadata(ImmutableList.of(TypeCode.BOOL))) + .setMetadata(createParameterTypesMetadata(ImmutableList.of(TypeCode.BOOL))) + .setStats(ResultSetStats.newBuilder().build()) .build())); + mockSpanner.putStatementResult( + StatementResult.update(Statement.newBuilder(updateSql).bind("p1").to(true).build(), 3L)); int batchSize = 5; for (int i = 0; i < batchSize; i++) { @@ -616,91 +690,62 @@ public void testMixedBatch() { assertNull(res); List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); - // pgx by default always uses prepared statements. That means that we get the following list of - // statements: + // pgx by default always uses prepared statements. In addition, pgx first describes all + // statements in a batch before it executes the batch. The describe-messages that it sends do + // not use the current transaction. + // That means that we get the following list of statements: // 1. Describe parameters of insert statement in PLAN mode. - // 2. Parse insert statement in PLAN mode. - // 3. Describe columns of select statement in PLAN mode. - // 4. Describe parameters of select statement in PLAN mode. - // 5. Describe parameters of update statement in PLAN mode. - // 6. Parse update statement in PLAN mode. - // 7. Execute select statement. - // 8. Execute update statement. - assertEquals(8, requests.size()); + // 2. Describe parameters and columns of select statement in PLAN mode. + // 3. Describe parameters of update statement in PLAN mode. + // 4. Execute select statement. + // 5. Execute update statement. + // In addition, we get one ExecuteBatchDml request for the insert statements. + assertEquals(5, requests.size()); // NOTE: pgx will first create prepared statements for sql strings that it does not yet know. // All those describe statement messages will be executed in separate (single-use) transactions. // The order in which the describe statements are executed is random. - ExecuteSqlRequest describeInsertParamsRequest = - requests.stream() - .filter(request -> request.getSql().equals(describeInsertSql)) - .findFirst() - .orElse(ExecuteSqlRequest.getDefaultInstance()); - assertEquals(describeInsertSql, describeInsertParamsRequest.getSql()); - assertEquals(QueryMode.PLAN, describeInsertParamsRequest.getQueryMode()); - // The 'describe' query for the parameters will be executed as a single use transaction. - assertTrue(describeInsertParamsRequest.getTransaction().hasSingleUse()); - - ExecuteSqlRequest parseInsertRequest = + List describeInsertRequests = requests.stream() .filter( request -> - request.getSql().equals(insertSql) && request.getQueryMode() == QueryMode.PLAN) - .findFirst() - .orElse(ExecuteSqlRequest.getDefaultInstance()); - assertEquals(insertSql, parseInsertRequest.getSql()); - assertEquals(QueryMode.PLAN, parseInsertRequest.getQueryMode()); - assertTrue(parseInsertRequest.getTransaction().hasBegin()); - - ExecuteSqlRequest describeSelectColumnsRequest = + request.getSql().equals(insertSql) + && request.getQueryMode().equals(QueryMode.PLAN)) + .collect(Collectors.toList()); + assertEquals(1, describeInsertRequests.size()); + // TODO(#477): These Describe-message flows could use single-use read/write transactions. + assertTrue(describeInsertRequests.get(0).getTransaction().hasBegin()); + assertTrue(describeInsertRequests.get(0).getTransaction().getBegin().hasReadWrite()); + + List describeSelectRequests = requests.stream() .filter( request -> request.getSql().equals(selectSql) && request.getQueryMode() == QueryMode.PLAN) - .findFirst() - .orElse(ExecuteSqlRequest.getDefaultInstance()); - assertEquals(selectSql, describeSelectColumnsRequest.getSql()); - assertEquals(QueryMode.PLAN, describeSelectColumnsRequest.getQueryMode()); - assertTrue(describeSelectColumnsRequest.getTransaction().hasSingleUse()); + .collect(Collectors.toList()); + assertEquals(1, describeSelectRequests.size()); + assertTrue(describeSelectRequests.get(0).getTransaction().hasSingleUse()); + assertTrue(describeSelectRequests.get(0).getTransaction().getSingleUse().hasReadOnly()); - ExecuteSqlRequest describeSelectParamsRequest = - requests.stream() - .filter(request -> request.getSql().equals(describeParamsSelectSql)) - .findFirst() - .orElse(ExecuteSqlRequest.getDefaultInstance()); - assertEquals(describeParamsSelectSql, describeSelectParamsRequest.getSql()); - assertEquals(QueryMode.PLAN, describeSelectParamsRequest.getQueryMode()); - assertTrue(describeSelectParamsRequest.getTransaction().hasSingleUse()); - - ExecuteSqlRequest describeUpdateParamsRequest = - requests.stream() - .filter(request -> request.getSql().equals(describeUpdateSql)) - .findFirst() - .orElse(ExecuteSqlRequest.getDefaultInstance()); - assertEquals(describeUpdateSql, describeUpdateParamsRequest.getSql()); - assertEquals(QueryMode.PLAN, describeUpdateParamsRequest.getQueryMode()); - assertTrue(describeUpdateParamsRequest.getTransaction().hasSingleUse()); - - ExecuteSqlRequest parseUpdateRequest = + List describeUpdateRequests = requests.stream() .filter( request -> request.getSql().equals(updateSql) && request.getQueryMode() == QueryMode.PLAN) - .findFirst() - .orElse(ExecuteSqlRequest.getDefaultInstance()); - assertEquals(updateSql, parseUpdateRequest.getSql()); - assertEquals(QueryMode.PLAN, parseUpdateRequest.getQueryMode()); - assertTrue(parseUpdateRequest.getTransaction().hasBegin()); + .collect(Collectors.toList()); + assertEquals(1, describeUpdateRequests.size()); + assertTrue(describeUpdateRequests.get(0).getTransaction().hasBegin()); + assertTrue(describeUpdateRequests.get(0).getTransaction().getBegin().hasReadWrite()); // From here we start with the actual statement execution. - ExecuteSqlRequest executeSelectRequest = requests.get(6); + ExecuteSqlRequest executeSelectRequest = requests.get(3); assertEquals(selectSql, executeSelectRequest.getSql()); assertEquals(QueryMode.NORMAL, executeSelectRequest.getQueryMode()); // The SELECT statement should use the transaction that was started by the BatchDml request. assertTrue(executeSelectRequest.getTransaction().hasId()); - ExecuteSqlRequest executeUpdateRequest = requests.get(7); + ExecuteSqlRequest executeUpdateRequest = requests.get(4); assertEquals(updateSql, executeUpdateRequest.getSql()); assertEquals(QueryMode.NORMAL, executeUpdateRequest.getQueryMode()); assertTrue(executeUpdateRequest.getTransaction().hasId()); @@ -710,6 +755,9 @@ public void testMixedBatch() { mockSpanner.getRequestsOfType(ExecuteBatchDmlRequest.class).get(0); assertEquals(batchSize, batchDmlRequest.getStatementsCount()); assertTrue(batchDmlRequest.getTransaction().hasBegin()); + for (int i = 0; i < batchSize; i++) { + assertEquals(insertSql, batchDmlRequest.getStatements(i).getSql()); + } // There are three commit requests: // 1. Describe insert statement. @@ -728,48 +776,45 @@ public void testMixedBatch() { || request instanceof ExecuteBatchDmlRequest || request instanceof CommitRequest) .collect(Collectors.toList()); - // 12 == 3 Commit + 1 Batch DML + 8 ExecuteSql. - assertEquals(12, allRequests.size()); + // 9 == 3 Commit + 1 Batch DML + 5 ExecuteSql. + assertEquals(9, allRequests.size()); // We don't know the exact order of the DESCRIBE requests. // The order of EXECUTE requests is known and fixed. // The (theoretical) order of DESCRIBE requests is: // 1. Describe parameters of insert statement in PLAN mode. - // 2. Parse insert statement in PLAN mode. - // 3. Commit. - // 4. Describe columns of select statement in PLAN mode. - // 5. Describe parameters of select statement in PLAN mode. - // 6. Describe parameters of update statement in PLAN mode. - // 7. Parse update statement in PLAN mode. - // 8. Commit. + // 2. Commit. + // 3. Describe parameters and columns of select statement in PLAN mode. + // 4. Describe parameters of update statement in PLAN mode. + // 5. Commit. // The fixed order of EXECUTE requests is: - // 9. Execute insert batch (ExecuteBatchDml). - // 10. Execute select statement. - // 11. Execute update statement. - // 12. Commit transaction. + // 6. Execute insert batch (ExecuteBatchDml). + // 7. Execute select statement. + // 8. Execute update statement. + // 9. Commit transaction. assertEquals( 2, - allRequests.subList(0, 8).stream() + allRequests.subList(0, 5).stream() .filter(request -> request instanceof CommitRequest) .count()); assertEquals( - 6, - allRequests.subList(0, 8).stream() + 3, + allRequests.subList(0, 5).stream() .filter( request -> request instanceof ExecuteSqlRequest && ((ExecuteSqlRequest) request).getQueryMode() == QueryMode.PLAN) .count()); - assertEquals(ExecuteBatchDmlRequest.class, allRequests.get(8).getClass()); - assertEquals(ExecuteSqlRequest.class, allRequests.get(9).getClass()); - assertEquals(ExecuteSqlRequest.class, allRequests.get(10).getClass()); - assertEquals(CommitRequest.class, allRequests.get(11).getClass()); - - ByteString transactionId = ((CommitRequest) allRequests.get(11)).getTransactionId(); - assertEquals(transactionId, ((ExecuteSqlRequest) allRequests.get(9)).getTransaction().getId()); - assertEquals(transactionId, ((ExecuteSqlRequest) allRequests.get(10)).getTransaction().getId()); + assertEquals(ExecuteBatchDmlRequest.class, allRequests.get(5).getClass()); + assertEquals(ExecuteSqlRequest.class, allRequests.get(6).getClass()); + assertEquals(ExecuteSqlRequest.class, allRequests.get(7).getClass()); + assertEquals(CommitRequest.class, allRequests.get(8).getClass()); + + ByteString transactionId = ((CommitRequest) allRequests.get(8)).getTransactionId(); + assertEquals(transactionId, ((ExecuteSqlRequest) allRequests.get(6)).getTransaction().getId()); + assertEquals(transactionId, ((ExecuteSqlRequest) allRequests.get(7)).getTransaction().getId()); } @Test @@ -778,15 +823,12 @@ public void testBatchPrepareError() { "INSERT INTO all_types " + "(col_bigint, col_bool, col_bytea, col_float8, col_int, col_numeric, col_timestamptz, col_date, col_varchar, col_jsonb) " + "values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"; - mockSpanner.putStatementResult(StatementResult.update(Statement.of(insertSql), 0L)); - String describeInsertSql = - "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10 from (select col_bigint=$1, col_bool=$2, col_bytea=$3, col_float8=$4, col_int=$5, col_numeric=$6, col_timestamptz=$7, col_date=$8, col_varchar=$9, col_jsonb=$10 from all_types) p"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeInsertSql), + Statement.of(insertSql), ResultSet.newBuilder() .setMetadata( - createMetadata( + createParameterTypesMetadata( ImmutableList.of( TypeCode.INT64, TypeCode.BOOL, @@ -798,6 +840,7 @@ public void testBatchPrepareError() { TypeCode.DATE, TypeCode.STRING, TypeCode.STRING))) + .setStats(ResultSetStats.newBuilder().build()) .build())); // This select statement will fail during the PREPARE phase that pgx executes for all statements // before actually executing the batch. @@ -840,14 +883,12 @@ public void testBatchExecutionError() { "INSERT INTO all_types " + "(col_bigint, col_bool, col_bytea, col_float8, col_int, col_numeric, col_timestamptz, col_date, col_varchar, col_jsonb) " + "values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"; - String describeInsertSql = - "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10 from (select col_bigint=$1, col_bool=$2, col_bytea=$3, col_float8=$4, col_int=$5, col_numeric=$6, col_timestamptz=$7, col_date=$8, col_varchar=$9, col_jsonb=$10 from all_types) p"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeInsertSql), + Statement.of(insertSql), ResultSet.newBuilder() .setMetadata( - createMetadata( + createParameterTypesMetadata( ImmutableList.of( TypeCode.INT64, TypeCode.BOOL, @@ -859,8 +900,8 @@ public void testBatchExecutionError() { TypeCode.DATE, TypeCode.STRING, TypeCode.STRING))) + .setStats(ResultSetStats.newBuilder().build()) .build())); - mockSpanner.putStatementResult(StatementResult.update(Statement.of(insertSql), 0L)); int batchSize = 3; for (int i = 0; i < batchSize; i++) { Statement statement = @@ -920,14 +961,12 @@ public void testInsertNullsAllDataTypes() { "INSERT INTO all_types " + "(col_bigint, col_bool, col_bytea, col_float8, col_int, col_numeric, col_timestamptz, col_date, col_varchar, col_jsonb) " + "values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"; - String describeSql = - "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10 from (select col_bigint=$1, col_bool=$2, col_bytea=$3, col_float8=$4, col_int=$5, col_numeric=$6, col_timestamptz=$7, col_date=$8, col_varchar=$9, col_jsonb=$10 from all_types) p"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeSql), + Statement.of(sql), ResultSet.newBuilder() .setMetadata( - createMetadata( + createParameterTypesMetadata( ImmutableList.of( TypeCode.INT64, TypeCode.BOOL, @@ -939,8 +978,8 @@ public void testInsertNullsAllDataTypes() { TypeCode.DATE, TypeCode.STRING, TypeCode.STRING))) + .setStats(ResultSetStats.newBuilder().build()) .build())); - mockSpanner.putStatementResult(StatementResult.update(Statement.of(sql), 0L)); mockSpanner.putStatementResult( StatementResult.update( Statement.newBuilder(sql) @@ -971,19 +1010,15 @@ public void testInsertNullsAllDataTypes() { assertNull(res); List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); - // pgx by default always uses prepared statements. That means that each request is sent three + // pgx by default always uses prepared statements. That means that each request is sent two // times to the backend the first time it is executed: // 1. DescribeStatement (parameters) - // 2. DescribeStatement (verify validity / PARSE) -- This step could be skipped. - // 3. Execute - assertEquals(3, requests.size()); - ExecuteSqlRequest describeParamsRequest = requests.get(0); - assertEquals(describeSql, describeParamsRequest.getSql()); - assertEquals(QueryMode.PLAN, describeParamsRequest.getQueryMode()); - ExecuteSqlRequest describeRequest = requests.get(1); + // 2. Execute + assertEquals(2, requests.size()); + ExecuteSqlRequest describeRequest = requests.get(0); assertEquals(sql, describeRequest.getSql()); assertEquals(QueryMode.PLAN, describeRequest.getQueryMode()); - ExecuteSqlRequest executeRequest = requests.get(2); + ExecuteSqlRequest executeRequest = requests.get(1); assertEquals(sql, executeRequest.getSql()); assertEquals(QueryMode.NORMAL, executeRequest.getQueryMode()); } @@ -1073,14 +1108,12 @@ public void testReadWriteTransaction() { "INSERT INTO all_types " + "(col_bigint, col_bool, col_bytea, col_float8, col_int, col_numeric, col_timestamptz, col_date, col_varchar, col_jsonb) " + "values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"; - String describeSql = - "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10 from (select col_bigint=$1, col_bool=$2, col_bytea=$3, col_float8=$4, col_int=$5, col_numeric=$6, col_timestamptz=$7, col_date=$8, col_varchar=$9, col_jsonb=$10 from all_types) p"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeSql), + Statement.of(sql), ResultSet.newBuilder() .setMetadata( - createMetadata( + createParameterTypesMetadata( ImmutableList.of( TypeCode.INT64, TypeCode.BOOL, @@ -1092,8 +1125,8 @@ public void testReadWriteTransaction() { TypeCode.DATE, TypeCode.STRING, TypeCode.STRING))) + .setStats(ResultSetStats.newBuilder().build()) .build())); - mockSpanner.putStatementResult(StatementResult.update(Statement.of(sql), 0L)); for (long id : new Long[] {10L, 20L}) { mockSpanner.putStatementResult( StatementResult.update( @@ -1127,14 +1160,12 @@ public void testReadWriteTransaction() { assertNull(res); List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); // pgx by default always uses prepared statements. That means that the first time a SQL - // statement is executed, it will be sent three times to the backend (twice for statements - // without any query parameters): - // 1. DescribeStatement (parameters) - // 2. DescribeStatement (verify validity / PARSE) -- This step could be skipped. - // 3. Execute + // statement is executed, it will be sent two times to the backend: + // 1. DescribeStatement + // 2. Execute // The second time the same statement is executed, it is only sent once. - assertEquals(6, requests.size()); + assertEquals(5, requests.size()); ExecuteSqlRequest describeSelect1Request = requests.get(0); // The first statement should begin the transaction. assertTrue(describeSelect1Request.getTransaction().hasBegin()); @@ -1144,21 +1175,17 @@ public void testReadWriteTransaction() { assertTrue(executeSelect1Request.getTransaction().hasId()); assertEquals(QueryMode.NORMAL, executeSelect1Request.getQueryMode()); - ExecuteSqlRequest describeParamsRequest = requests.get(2); - assertEquals(describeSql, describeParamsRequest.getSql()); - assertEquals(QueryMode.PLAN, describeParamsRequest.getQueryMode()); - assertTrue(describeParamsRequest.getTransaction().hasId()); - - ExecuteSqlRequest describeRequest = requests.get(3); + ExecuteSqlRequest describeRequest = requests.get(2); assertEquals(sql, describeRequest.getSql()); assertEquals(QueryMode.PLAN, describeRequest.getQueryMode()); assertTrue(describeRequest.getTransaction().hasId()); - ExecuteSqlRequest executeRequest = requests.get(4); - assertEquals(sql, executeRequest.getSql()); - assertEquals(QueryMode.NORMAL, executeRequest.getQueryMode()); - assertTrue(executeRequest.getTransaction().hasId()); - assertTrue(requests.get(3).getTransaction().hasId()); + for (int i = 3; i < 5; i++) { + ExecuteSqlRequest executeRequest = requests.get(i); + assertEquals(sql, executeRequest.getSql()); + assertEquals(QueryMode.NORMAL, executeRequest.getQueryMode()); + assertTrue(executeRequest.getTransaction().hasId()); + } assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); CommitRequest commitRequest = mockSpanner.getRequestsOfType(CommitRequest.class).get(0); @@ -1222,7 +1249,6 @@ public void testReadWriteTransactionIsolationLevelRepeatableRead() { assertEquals(0, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); } - @Ignore("Requires Spanner client library 6.26.0") @Test public void testReadOnlySerializableTransaction() { String res = pgxTest.TestReadOnlySerializableTransaction(createConnString()); diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxSimpleModeMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxSimpleModeMockServerTest.java index 17a11a611..2012da7c5 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxSimpleModeMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxSimpleModeMockServerTest.java @@ -37,7 +37,6 @@ import java.nio.charset.StandardCharsets; import java.util.List; import org.junit.BeforeClass; -import org.junit.Ignore; import org.junit.Test; import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; @@ -209,7 +208,6 @@ public void testInsertNullsAllDataTypes() { } @Test - @Ignore("Skip until https://github.com/googleapis/java-spanner/pull/1877 has been released") public void testWrongDialect() { // Let the mock server respond with the Google SQL dialect instead of PostgreSQL. The // connection should be gracefully rejected. Close all open pooled Spanner objects so we know diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxTest.java index efd9f97ac..3521d6cdb 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxTest.java @@ -31,6 +31,8 @@ public interface PgxTest extends Library { String TestInsertNullsAllDataTypes(GoString connString); + String TestInsertAllDataTypesReturning(GoString connString); + String TestUpdateAllDataTypes(GoString connString); String TestPrepareStatement(GoString connString); diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/hibernate/ITHibernateTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/hibernate/ITHibernateTest.java new file mode 100644 index 000000000..50e532f5d --- /dev/null +++ b/src/test/java/com/google/cloud/spanner/pgadapter/hibernate/ITHibernateTest.java @@ -0,0 +1,157 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.spanner.pgadapter.hibernate; + +import static org.junit.Assert.assertEquals; + +import com.google.cloud.spanner.Database; +import com.google.cloud.spanner.pgadapter.PgAdapterTestEnv; +import com.google.common.collect.ImmutableList; +import java.io.BufferedReader; +import java.io.File; +import java.io.FileReader; +import java.io.FileWriter; +import java.io.IOException; +import java.io.InputStreamReader; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Scanner; +import java.util.logging.Logger; +import java.util.stream.Collectors; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ITHibernateTest { + private static final Logger LOGGER = Logger.getLogger(ITHibernateTest.class.getName()); + private static final String HIBERNATE_SAMPLE_DIRECTORY = "samples/java/hibernate"; + private static final String HIBERNATE_PROPERTIES_FILE = + HIBERNATE_SAMPLE_DIRECTORY + "/src/main/resources/hibernate.properties"; + private static final String HIBERNATE_SAMPLE_SCHEMA_FILE = + HIBERNATE_SAMPLE_DIRECTORY + "/src/main/resources/sample-schema-sql"; + private static final String HIBERNATE_DEFAULT_URL = + "jdbc:postgresql://localhost:5432/test-database"; + private static final PgAdapterTestEnv testEnv = new PgAdapterTestEnv(); + private static Database database; + private static String originalHibernateProperties; + + @BeforeClass + public static void setup() + throws ClassNotFoundException, IOException, SQLException, InterruptedException { + // Make sure the PG JDBC driver is loaded. + Class.forName("org.postgresql.Driver"); + + testEnv.setUp(); + database = testEnv.createDatabase(ImmutableList.of()); + testEnv.startPGAdapterServer(ImmutableList.of()); + // Create the sample schema. + StringBuilder builder = new StringBuilder(); + try (Scanner scanner = new Scanner(new FileReader(HIBERNATE_SAMPLE_SCHEMA_FILE))) { + while (scanner.hasNextLine()) { + builder.append(scanner.nextLine()).append("\n"); + } + } + // Note: We know that all semicolons in this file are outside of literals etc. + String[] ddl = builder.toString().split(";"); + String url = + String.format( + "jdbc:postgresql://localhost:%d/%s", + testEnv.getPGAdapterPort(), database.getId().getDatabase()); + try (Connection connection = DriverManager.getConnection(url)) { + try (Statement statement = connection.createStatement()) { + for (String sql : ddl) { + LOGGER.info("Executing " + sql); + statement.execute(sql); + } + } + } + + // Write hibernate.properties + StringBuilder original = new StringBuilder(); + try (Scanner scanner = new Scanner(new FileReader(HIBERNATE_PROPERTIES_FILE))) { + while (scanner.hasNextLine()) { + original.append(scanner.nextLine()).append("\n"); + } + } + originalHibernateProperties = original.toString(); + String updatesHibernateProperties = original.toString().replace(HIBERNATE_DEFAULT_URL, url); + try (FileWriter writer = new FileWriter(HIBERNATE_PROPERTIES_FILE)) { + LOGGER.info("Using Hibernate properties:\n" + updatesHibernateProperties); + writer.write(updatesHibernateProperties); + writer.flush(); + } + buildHibernateSample(); + } + + @AfterClass + public static void teardown() throws IOException { + try (FileWriter writer = new FileWriter(HIBERNATE_PROPERTIES_FILE)) { + writer.write(originalHibernateProperties); + } + testEnv.stopPGAdapterServer(); + testEnv.cleanUp(); + } + + @Test + public void testHibernateUpdate() throws IOException, InterruptedException { + System.out.println("Running hibernate test"); + ImmutableList hibernateCommand = + ImmutableList.builder() + .add( + "mvn", + "exec:java", + "-Dexec.mainClass=com.google.cloud.postgres.HibernateSampleTest") + .build(); + runCommand(hibernateCommand); + System.out.println("Hibernate Test Ended"); + } + + static void buildHibernateSample() throws IOException, InterruptedException { + System.out.println("Building Hibernate Sample."); + ImmutableList hibernateCommand = + ImmutableList.builder().add("mvn", "clean", "compile").build(); + runCommand(hibernateCommand); + System.out.println("Hibernate Sample build complete."); + } + + static void runCommand(ImmutableList commands) throws IOException, InterruptedException { + System.out.println( + "Executing commands: " + commands + ". Sample Directory: " + HIBERNATE_SAMPLE_DIRECTORY); + ProcessBuilder builder = new ProcessBuilder(); + builder.command(commands); + builder.directory(new File(HIBERNATE_SAMPLE_DIRECTORY)); + Process process = builder.start(); + + String errors; + String output; + try (BufferedReader reader = + new BufferedReader(new InputStreamReader(process.getInputStream())); + BufferedReader errorReader = + new BufferedReader(new InputStreamReader(process.getErrorStream()))) { + System.out.println("Printing hibernate loadings"); + output = reader.lines().collect(Collectors.joining("\n")); + errors = errorReader.lines().collect(Collectors.joining("\n")); + System.out.println(output); + } + + // Verify that there were no errors, and print the error output if there was an error. + assertEquals(errors, 0, process.waitFor()); + } +} diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/liquibase/ITLiquibaseTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/liquibase/ITLiquibaseTest.java index 311e5c46e..346785f39 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/liquibase/ITLiquibaseTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/liquibase/ITLiquibaseTest.java @@ -102,6 +102,7 @@ public static void setup() throws ClassNotFoundException, IOException, SQLExcept testEnv.getPGAdapterPort(), database.getId().getDatabase()); LOGGER.info("Using Liquibase properties:\n" + properties); writer.write(properties); + writer.flush(); } } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/metadata/DescribeResultTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/metadata/DescribeResultTest.java new file mode 100644 index 000000000..24b707952 --- /dev/null +++ b/src/test/java/com/google/cloud/spanner/pgadapter/metadata/DescribeResultTest.java @@ -0,0 +1,113 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.spanner.pgadapter.metadata; + +import static com.google.cloud.spanner.pgadapter.metadata.DescribeResult.extractParameterTypes; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import com.google.cloud.spanner.pgadapter.error.PGException; +import com.google.spanner.v1.StructType; +import com.google.spanner.v1.StructType.Field; +import com.google.spanner.v1.Type; +import com.google.spanner.v1.TypeCode; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.postgresql.core.Oid; + +@RunWith(JUnit4.class) +public class DescribeResultTest { + + @Test + public void testExtractParameterTypes() { + assertArrayEquals( + new int[] {}, extractParameterTypes(new int[] {}, StructType.newBuilder().build())); + assertArrayEquals( + new int[] {Oid.INT8}, + extractParameterTypes( + new int[] {}, + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("p1") + .setType(Type.newBuilder().setCode(TypeCode.INT64).build()) + .build()) + .build())); + assertArrayEquals( + new int[] {Oid.BOOL}, + extractParameterTypes( + new int[] {}, + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("p1") + .setType(Type.newBuilder().setCode(TypeCode.BOOL).build()) + .build()) + .build())); + assertArrayEquals( + new int[] {Oid.INT8}, + extractParameterTypes( + new int[] {Oid.INT8}, + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("p1") + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .build()) + .build())); + assertArrayEquals( + new int[] {Oid.INT8}, + extractParameterTypes(new int[] {Oid.INT8}, StructType.newBuilder().build())); + assertArrayEquals( + new int[] {Oid.INT8, Oid.VARCHAR}, + extractParameterTypes( + new int[] {Oid.INT8}, + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("p2") + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .build()) + .build())); + assertArrayEquals( + new int[] {Oid.INT8, Oid.UNSPECIFIED, Oid.VARCHAR}, + extractParameterTypes( + new int[] {Oid.INT8}, + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("p3") + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .build()) + .build())); + + PGException exception = + assertThrows( + PGException.class, + () -> + extractParameterTypes( + new int[] {}, + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("foo") + .setType(Type.newBuilder().setCode(TypeCode.INT64).build()) + .build()) + .build())); + assertEquals("Invalid parameter name: foo", exception.getMessage()); + } +} diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/ITNodePostgresTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/ITNodePostgresTest.java index 20ef01727..e3463206c 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/ITNodePostgresTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/ITNodePostgresTest.java @@ -166,13 +166,13 @@ private String getHost() { @Test public void testSelect1() throws Exception { String output = runTest("testSelect1", getHost(), testEnv.getServer().getLocalPort()); - assertEquals("\n\nSELECT 1 returned: 1\n", output); + assertEquals("SELECT 1 returned: 1\n", output); } @Test public void testInsert() throws Exception { String output = runTest("testInsert", getHost(), testEnv.getServer().getLocalPort()); - assertEquals("\n\nInserted 1 row(s)\n", output); + assertEquals("Inserted 1 row(s)\n", output); DatabaseClient client = testEnv.getSpanner().getDatabaseClient(database.getId()); try (ResultSet resultSet = @@ -186,7 +186,7 @@ public void testInsert() throws Exception { @Test public void testInsertExecutedTwice() throws Exception { String output = runTest("testInsertTwice", getHost(), testEnv.getServer().getLocalPort()); - assertEquals("\n\nInserted 1 row(s)\nInserted 1 row(s)\n", output); + assertEquals("Inserted 1 row(s)\nInserted 1 row(s)\n", output); DatabaseClient client = testEnv.getSpanner().getDatabaseClient(database.getId()); try (ResultSet resultSet = @@ -202,7 +202,7 @@ public void testInsertExecutedTwice() throws Exception { @Test public void testInsertAutoCommit() throws IOException, InterruptedException { String output = runTest("testInsertAutoCommit", getHost(), testEnv.getServer().getLocalPort()); - assertEquals("\n\nInserted 1 row(s)\n", output); + assertEquals("Inserted 1 row(s)\n", output); DatabaseClient client = testEnv.getSpanner().getDatabaseClient(database.getId()); try (ResultSet resultSet = @@ -216,7 +216,7 @@ public void testInsertAutoCommit() throws IOException, InterruptedException { @Test public void testInsertAllTypes() throws IOException, InterruptedException { String output = runTest("testInsertAllTypes", getHost(), testEnv.getServer().getLocalPort()); - assertEquals("\n\nInserted 1 row(s)\n", output); + assertEquals("Inserted 1 row(s)\n", output); DatabaseClient client = testEnv.getSpanner().getDatabaseClient(database.getId()); try (ResultSet resultSet = @@ -244,7 +244,7 @@ public void testInsertAllTypes() throws IOException, InterruptedException { public void testInsertAllTypesNull() throws IOException, InterruptedException { String output = runTest("testInsertAllTypesNull", getHost(), testEnv.getServer().getLocalPort()); - assertEquals("\n\nInserted 1 row(s)\n", output); + assertEquals("Inserted 1 row(s)\n", output); DatabaseClient client = testEnv.getSpanner().getDatabaseClient(database.getId()); try (ResultSet resultSet = @@ -269,7 +269,7 @@ public void testInsertAllTypesPreparedStatement() throws IOException, Interrupte String output = runTest( "testInsertAllTypesPreparedStatement", getHost(), testEnv.getServer().getLocalPort()); - assertEquals("\n\nInserted 1 row(s)\nInserted 1 row(s)\n", output); + assertEquals("Inserted 1 row(s)\nInserted 1 row(s)\n", output); DatabaseClient client = testEnv.getSpanner().getDatabaseClient(database.getId()); try (ResultSet resultSet = @@ -312,7 +312,7 @@ public void testSelectAllTypes() throws IOException, InterruptedException { String output = runTest("testSelectAllTypes", getHost(), testEnv.getServer().getLocalPort()); assertEquals( - "\n\nSelected {" + "Selected {" + "\"col_bigint\":\"1\"," + "\"col_bool\":true," + "\"col_bytea\":{\"type\":\"Buffer\",\"data\":[116,101,115,116]}," @@ -333,7 +333,7 @@ public void testSelectAllTypesNull() throws IOException, InterruptedException { String output = runTest("testSelectAllTypes", getHost(), testEnv.getServer().getLocalPort()); assertEquals( - "\n\nSelected {" + "Selected {" + "\"col_bigint\":\"1\"," + "\"col_bool\":null," + "\"col_bytea\":null," @@ -359,7 +359,7 @@ public void testErrorInReadWriteTransaction() throws IOException, InterruptedExc String output = runTest("testErrorInReadWriteTransaction", getHost(), testEnv.getServer().getLocalPort()); assertEquals( - "\n\nInsert error: error: com.google.api.gax.rpc.AlreadyExistsException: io.grpc.StatusRuntimeException: ALREADY_EXISTS: Row [foo] in table users already exists\n" + "Insert error: error: com.google.api.gax.rpc.AlreadyExistsException: io.grpc.StatusRuntimeException: ALREADY_EXISTS: Row [foo] in table users already exists\n" + "Second insert failed with error: error: current transaction is aborted, commands ignored until end of transaction block\n" + "SELECT 1 returned: 1\n", output); @@ -370,7 +370,7 @@ public void testReadOnlyTransaction() throws Exception { String output = runTest("testReadOnlyTransaction", getHost(), testEnv.getServer().getLocalPort()); - assertEquals("\n\nexecuted read-only transaction\n", output); + assertEquals("executed read-only transaction\n", output); } @Test @@ -378,7 +378,7 @@ public void testReadOnlyTransactionWithError() throws Exception { String output = runTest("testReadOnlyTransactionWithError", getHost(), testEnv.getServer().getLocalPort()); assertEquals( - "\n\ncurrent transaction is aborted, commands ignored until end of transaction block\n" + "current transaction is aborted, commands ignored until end of transaction block\n" + "[ { '?column?': '2' } ]\n", output); } @@ -389,14 +389,14 @@ public void testCopyTo() throws Exception { String output = runTest("testCopyTo", getHost(), testEnv.getServer().getLocalPort()); assertEquals( - "\n\n1\tt\t\\\\x74657374\t3.14\t100\t6.626\t2022-02-16 13:18:02.123456789+00\t2022-03-29\ttest\t{\"key\": \"value\"}\n", + "1\tt\t\\\\x74657374\t3.14\t100\t6.626\t2022-02-16 13:18:02.123456789+00\t2022-03-29\ttest\t{\"key\": \"value\"}\n", output); } @Test public void testCopyFrom() throws Exception { String output = runTest("testCopyFrom", getHost(), testEnv.getServer().getLocalPort()); - assertEquals("\n\nFinished copy operation\n", output); + assertEquals("Finished copy operation\n", output); DatabaseClient client = testEnv.getSpanner().getDatabaseClient(database.getId()); try (ResultSet resultSet = @@ -410,7 +410,7 @@ public void testCopyFrom() throws Exception { @Test public void testDmlBatch() throws Exception { String output = runTest("testDmlBatch", getHost(), testEnv.getServer().getLocalPort()); - assertEquals("\n\nexecuted dml batch\n", output); + assertEquals("executed dml batch\n", output); DatabaseClient client = testEnv.getSpanner().getDatabaseClient(database.getId()); try (ResultSet resultSet = diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/NodeJSTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/NodeJSTest.java index 215af367c..6d14cf2ce 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/NodeJSTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/NodeJSTest.java @@ -44,7 +44,8 @@ static String runTest(String directory, String testName, String host, int port, String currentPath = new java.io.File(".").getCanonicalPath(); String testFilePath = String.format("%s/src/test/nodejs/%s", currentPath, directory); ProcessBuilder builder = new ProcessBuilder(); - builder.command("npm", "start", testName, host, String.format("%d", port), database); + builder.command( + "npm", "--silent", "start", testName, host, String.format("%d", port), database); builder.directory(new File(testFilePath)); Process process = builder.start(); diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/NodePostgresMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/NodePostgresMockServerTest.java index 05d9a6281..04014463b 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/NodePostgresMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/NodePostgresMockServerTest.java @@ -15,7 +15,6 @@ package com.google.cloud.spanner.pgadapter.nodejs; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import com.google.cloud.ByteArray; @@ -36,6 +35,8 @@ import com.google.spanner.v1.CommitRequest; import com.google.spanner.v1.ExecuteBatchDmlRequest; import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.ResultSet; +import com.google.spanner.v1.ResultSetStats; import com.google.spanner.v1.RollbackRequest; import com.google.spanner.v1.TypeCode; import io.grpc.Status; @@ -79,7 +80,7 @@ public void testSelect1() throws Exception { String output = runTest("testSelect1", getHost(), pgServer.getLocalPort()); - assertEquals("\n\nSELECT 1 returned: 1\n", output); + assertEquals("SELECT 1 returned: 1\n", output); List executeSqlRequests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).stream() @@ -100,47 +101,47 @@ public void testInsert() throws Exception { // The result of the describe statement call is cached for that connection, so executing the // same statement once more will not cause another describe-statement round-trip. String sql = "INSERT INTO users(name) VALUES($1)"; - String describeParamsSql = "select $1 from (select name=$1 from users) p"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeParamsSql), - com.google.spanner.v1.ResultSet.newBuilder() - .setMetadata(createMetadata(ImmutableList.of(TypeCode.STRING))) + Statement.of(sql), + ResultSet.newBuilder() + .setMetadata(createParameterTypesMetadata(ImmutableList.of(TypeCode.STRING))) + .setStats(ResultSetStats.newBuilder().build()) .build())); mockSpanner.putStatementResult( StatementResult.update(Statement.newBuilder(sql).bind("p1").to("foo").build(), 1L)); String output = runTest("testInsert", getHost(), pgServer.getLocalPort()); - assertEquals("\n\nInserted 1 row(s)\n", output); + assertEquals("Inserted 1 row(s)\n", output); List executeSqlRequests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).stream() - .filter( - request -> - request.getSql().equals(sql) || request.getSql().equals(describeParamsSql)) + .filter(request -> request.getSql().equals(sql)) .collect(Collectors.toList()); assertEquals(2, executeSqlRequests.size()); ExecuteSqlRequest describeRequest = executeSqlRequests.get(0); - assertEquals(describeParamsSql, describeRequest.getSql()); - assertFalse(describeRequest.hasTransaction()); + assertEquals(sql, describeRequest.getSql()); + assertTrue(describeRequest.hasTransaction()); + assertTrue(describeRequest.getTransaction().hasBegin()); + assertTrue(describeRequest.getTransaction().getBegin().hasReadWrite()); + ExecuteSqlRequest executeRequest = executeSqlRequests.get(1); assertEquals(sql, executeRequest.getSql()); assertEquals(1, executeRequest.getParamTypesCount()); - assertTrue(executeRequest.getTransaction().hasBegin()); - assertTrue(executeRequest.getTransaction().getBegin().hasReadWrite()); + assertTrue(executeRequest.getTransaction().hasId()); assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); } @Test public void testInsertExecutedTwice() throws Exception { String sql = "INSERT INTO users(name) VALUES($1)"; - String describeParamsSql = "select $1 from (select name=$1 from users) p"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeParamsSql), - com.google.spanner.v1.ResultSet.newBuilder() - .setMetadata(createMetadata(ImmutableList.of(TypeCode.STRING))) + Statement.of(sql), + ResultSet.newBuilder() + .setMetadata(createParameterTypesMetadata(ImmutableList.of(TypeCode.STRING))) + .setStats(ResultSetStats.newBuilder().build()) .build())); mockSpanner.putStatementResult( StatementResult.update(Statement.newBuilder(sql).bind("p1").to("foo").build(), 1L)); @@ -149,24 +150,23 @@ public void testInsertExecutedTwice() throws Exception { String output = runTest("testInsertTwice", getHost(), pgServer.getLocalPort()); - assertEquals("\n\nInserted 1 row(s)\nInserted 2 row(s)\n", output); + assertEquals("Inserted 1 row(s)\nInserted 2 row(s)\n", output); List executeSqlRequests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).stream() - .filter( - request -> - request.getSql().equals(sql) || request.getSql().equals(describeParamsSql)) + .filter(request -> request.getSql().equals(sql)) .collect(Collectors.toList()); assertEquals(3, executeSqlRequests.size()); ExecuteSqlRequest describeRequest = executeSqlRequests.get(0); - assertEquals(describeParamsSql, describeRequest.getSql()); - assertFalse(describeRequest.hasTransaction()); + assertEquals(sql, describeRequest.getSql()); + assertTrue(describeRequest.hasTransaction()); + assertTrue(describeRequest.getTransaction().hasBegin()); + assertTrue(describeRequest.getTransaction().getBegin().hasReadWrite()); ExecuteSqlRequest executeRequest = executeSqlRequests.get(1); assertEquals(sql, executeRequest.getSql()); assertEquals(1, executeRequest.getParamTypesCount()); - assertTrue(executeRequest.getTransaction().hasBegin()); - assertTrue(executeRequest.getTransaction().getBegin().hasReadWrite()); + assertTrue(executeRequest.getTransaction().hasId()); executeRequest = executeSqlRequests.get(2); assertEquals(sql, executeRequest.getSql()); @@ -178,36 +178,38 @@ public void testInsertExecutedTwice() throws Exception { @Test public void testInsertAutoCommit() throws IOException, InterruptedException { String sql = "INSERT INTO users(name) VALUES($1)"; - String describeParamsSql = "select $1 from (select name=$1 from users) p"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeParamsSql), - com.google.spanner.v1.ResultSet.newBuilder() - .setMetadata(createMetadata(ImmutableList.of(TypeCode.STRING))) + Statement.of(sql), + ResultSet.newBuilder() + .setMetadata(createParameterTypesMetadata(ImmutableList.of(TypeCode.STRING))) + .setStats(ResultSetStats.newBuilder().build()) .build())); mockSpanner.putStatementResult( StatementResult.update(Statement.newBuilder(sql).bind("p1").to("foo").build(), 1L)); String output = runTest("testInsertAutoCommit", getHost(), pgServer.getLocalPort()); - assertEquals("\n\nInserted 1 row(s)\n", output); + assertEquals("Inserted 1 row(s)\n", output); List executeSqlRequests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).stream() - .filter( - request -> - request.getSql().equals(sql) || request.getSql().equals(describeParamsSql)) + .filter(request -> request.getSql().equals(sql)) .collect(Collectors.toList()); assertEquals(2, executeSqlRequests.size()); ExecuteSqlRequest describeRequest = executeSqlRequests.get(0); - assertEquals(describeParamsSql, describeRequest.getSql()); - assertFalse(describeRequest.hasTransaction()); + assertEquals(sql, describeRequest.getSql()); + assertTrue(describeRequest.hasTransaction()); + assertTrue(describeRequest.getTransaction().hasBegin()); + assertTrue(describeRequest.getTransaction().getBegin().hasReadWrite()); + ExecuteSqlRequest executeRequest = executeSqlRequests.get(1); assertEquals(sql, executeRequest.getSql()); assertEquals(1, executeRequest.getParamTypesCount()); - assertTrue(executeRequest.getTransaction().hasBegin()); - assertTrue(executeRequest.getTransaction().getBegin().hasReadWrite()); - assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + // TODO: Enable when node-postgres 8.9 has been released. + // assertTrue(executeRequest.getTransaction().hasBegin()); + // assertTrue(executeRequest.getTransaction().getBegin().hasReadWrite()); + // assertEquals(2, mockSpanner.countRequestsOfType(CommitRequest.class)); } @Test @@ -216,15 +218,12 @@ public void testInsertAllTypes() throws IOException, InterruptedException { "INSERT INTO AllTypes " + "(col_bigint, col_bool, col_bytea, col_float8, col_int, col_numeric, col_timestamptz, col_date, col_varchar, col_jsonb) " + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"; - String describeParamsSql = - "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10 from " - + "(select col_bigint=$1, col_bool=$2, col_bytea=$3, col_float8=$4, col_int=$5, col_numeric=$6, col_timestamptz=$7, col_date=$8, col_varchar=$9, col_jsonb=$10 from AllTypes) p"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeParamsSql), - com.google.spanner.v1.ResultSet.newBuilder() + Statement.of(sql), + ResultSet.newBuilder() .setMetadata( - createMetadata( + createParameterTypesMetadata( ImmutableList.of( TypeCode.INT64, TypeCode.BOOL, @@ -236,6 +235,7 @@ public void testInsertAllTypes() throws IOException, InterruptedException { TypeCode.DATE, TypeCode.STRING, TypeCode.JSON))) + .setStats(ResultSetStats.newBuilder().build()) .build())); StatementResult updateResult = StatementResult.update( @@ -266,24 +266,26 @@ public void testInsertAllTypes() throws IOException, InterruptedException { String output = runTest("testInsertAllTypes", getHost(), pgServer.getLocalPort()); - assertEquals("\n\nInserted 1 row(s)\n", output); + assertEquals("Inserted 1 row(s)\n", output); List executeSqlRequests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).stream() - .filter( - request -> - request.getSql().equals(sql) || request.getSql().equals(describeParamsSql)) + .filter(request -> request.getSql().equals(sql)) .collect(Collectors.toList()); assertEquals(2, executeSqlRequests.size()); ExecuteSqlRequest describeRequest = executeSqlRequests.get(0); - assertEquals(describeParamsSql, describeRequest.getSql()); - assertFalse(describeRequest.hasTransaction()); + assertEquals(sql, describeRequest.getSql()); + assertTrue(describeRequest.hasTransaction()); + assertTrue(describeRequest.getTransaction().hasBegin()); + assertTrue(describeRequest.getTransaction().getBegin().hasReadWrite()); + ExecuteSqlRequest executeRequest = executeSqlRequests.get(1); assertEquals(sql, executeRequest.getSql()); assertEquals(10, executeRequest.getParamTypesCount()); - assertTrue(executeRequest.getTransaction().hasBegin()); - assertTrue(executeRequest.getTransaction().getBegin().hasReadWrite()); - assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + // TODO: Enable once node-postgres 8.9 is released. + // assertTrue(executeRequest.getTransaction().hasBegin()); + // assertTrue(executeRequest.getTransaction().getBegin().hasReadWrite()); + // assertEquals(2, mockSpanner.countRequestsOfType(CommitRequest.class)); } @Test @@ -320,7 +322,7 @@ public void testInsertAllTypesNull() throws IOException, InterruptedException { String output = runTest("testInsertAllTypesAllNull", getHost(), pgServer.getLocalPort()); - assertEquals("\n\nInserted 1 row(s)\n", output); + assertEquals("Inserted 1 row(s)\n", output); List executeSqlRequests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).stream() @@ -341,15 +343,12 @@ public void testInsertAllTypesPreparedStatement() throws IOException, Interrupte "INSERT INTO AllTypes " + "(col_bigint, col_bool, col_bytea, col_float8, col_int, col_numeric, col_timestamptz, col_date, col_varchar, col_jsonb) " + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"; - String describeParamsSql = - "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10 from " - + "(select col_bigint=$1, col_bool=$2, col_bytea=$3, col_float8=$4, col_int=$5, col_numeric=$6, col_timestamptz=$7, col_date=$8, col_varchar=$9, col_jsonb=$10 from AllTypes) p"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeParamsSql), - com.google.spanner.v1.ResultSet.newBuilder() + Statement.of(sql), + ResultSet.newBuilder() .setMetadata( - createMetadata( + createParameterTypesMetadata( ImmutableList.of( TypeCode.INT64, TypeCode.BOOL, @@ -361,6 +360,7 @@ public void testInsertAllTypesPreparedStatement() throws IOException, Interrupte TypeCode.DATE, TypeCode.STRING, TypeCode.JSON))) + .setStats(ResultSetStats.newBuilder().build()) .build())); StatementResult updateResult = StatementResult.update( @@ -419,7 +419,7 @@ public void testInsertAllTypesPreparedStatement() throws IOException, Interrupte String output = runTest("testInsertAllTypesPreparedStatement", getHost(), pgServer.getLocalPort()); - assertEquals("\n\nInserted 1 row(s)\nInserted 1 row(s)\n", output); + assertEquals("Inserted 1 row(s)\nInserted 1 row(s)\n", output); // node-postgres will only send one parse message when using prepared statements. It never uses // DescribeStatement. It will send a new DescribePortal for each time the prepared statement is @@ -436,20 +436,22 @@ public void testInsertAllTypesPreparedStatement() throws IOException, Interrupte List executeSqlRequests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).stream() - .filter( - request -> - request.getSql().equals(sql) || request.getSql().equals(describeParamsSql)) + .filter(request -> request.getSql().equals(sql)) .collect(Collectors.toList()); assertEquals(3, executeSqlRequests.size()); ExecuteSqlRequest describeRequest = executeSqlRequests.get(0); - assertEquals(describeParamsSql, describeRequest.getSql()); - assertFalse(describeRequest.hasTransaction()); + assertEquals(sql, describeRequest.getSql()); + assertTrue(describeRequest.hasTransaction()); + assertTrue(describeRequest.getTransaction().hasBegin()); + assertTrue(describeRequest.getTransaction().getBegin().hasReadWrite()); + ExecuteSqlRequest executeRequest = executeSqlRequests.get(1); assertEquals(sql, executeRequest.getSql()); assertEquals(10, executeRequest.getParamTypesCount()); - assertTrue(executeRequest.getTransaction().hasBegin()); - assertTrue(executeRequest.getTransaction().getBegin().hasReadWrite()); - assertEquals(2, mockSpanner.countRequestsOfType(CommitRequest.class)); + // TODO: Enable once node-postgres 8.9 is released. + // assertTrue(executeRequest.getTransaction().hasBegin()); + // assertTrue(executeRequest.getTransaction().getBegin().hasReadWrite()); + // assertEquals(3, mockSpanner.countRequestsOfType(CommitRequest.class)); } @Test @@ -463,7 +465,7 @@ public void testSelectAllTypes() throws IOException, InterruptedException { String output = runTest("testSelectAllTypes", getHost(), pgServer.getLocalPort()); assertEquals( - "\n\nSelected {" + "Selected {" + "\"col_bigint\":\"1\"," + "\"col_bool\":true," + "\"col_bytea\":{\"type\":\"Buffer\",\"data\":[116,101,115,116]}," @@ -494,7 +496,7 @@ public void testSelectAllTypesNull() throws IOException, InterruptedException { String output = runTest("testSelectAllTypes", getHost(), pgServer.getLocalPort()); assertEquals( - "\n\nSelected {" + "Selected {" + "\"col_bigint\":null," + "\"col_bool\":null," + "\"col_bytea\":null," @@ -534,7 +536,7 @@ public void testErrorInReadWriteTransaction() throws IOException, InterruptedExc String output = runTest("testErrorInReadWriteTransaction", getHost(), pgServer.getLocalPort()); assertEquals( - "\n\nInsert error: error: com.google.api.gax.rpc.AlreadyExistsException: io.grpc.StatusRuntimeException: ALREADY_EXISTS: Row with \"name\" 'foo' already exists\n" + "Insert error: error: com.google.api.gax.rpc.AlreadyExistsException: io.grpc.StatusRuntimeException: ALREADY_EXISTS: Row with \"name\" 'foo' already exists\n" + "Second insert failed with error: error: current transaction is aborted, commands ignored until end of transaction block\n" + "SELECT 1 returned: 1\n", output); @@ -547,7 +549,7 @@ public void testErrorInReadWriteTransaction() throws IOException, InterruptedExc public void testReadOnlyTransaction() throws Exception { String output = runTest("testReadOnlyTransaction", getHost(), pgServer.getLocalPort()); - assertEquals("\n\nexecuted read-only transaction\n", output); + assertEquals("executed read-only transaction\n", output); List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).stream() @@ -574,7 +576,7 @@ public void testReadOnlyTransactionWithError() throws Exception { String output = runTest("testReadOnlyTransactionWithError", getHost(), pgServer.getLocalPort()); assertEquals( - "\n\ncurrent transaction is aborted, commands ignored until end of transaction block\n" + "current transaction is aborted, commands ignored until end of transaction block\n" + "[ { C: '2' } ]\n", output); } @@ -587,9 +589,7 @@ public void testCopyTo() throws Exception { String output = runTest("testCopyTo", getHost(), pgServer.getLocalPort()); assertEquals( - "\n" - + "\n" - + "1\tt\t\\\\x74657374\t3.14\t100\t6.626\t2022-02-16 13:18:02.123456789+00\t2022-03-29\ttest\t{\"key\": \"value\"}\n", + "1\tt\t\\\\x74657374\t3.14\t100\t6.626\t2022-02-16 13:18:02.123456789+00\t2022-03-29\ttest\t{\"key\": \"value\"}\n", output); } @@ -599,18 +599,18 @@ public void testCopyFrom() throws Exception { String output = runTest("testCopyFrom", getHost(), pgServer.getLocalPort()); - assertEquals("\n\nFinished copy operation\n", output); + assertEquals("Finished copy operation\n", output); } @Test public void testDmlBatch() throws Exception { String sql = "INSERT INTO users(name) VALUES($1)"; - String describeParamsSql = "select $1 from (select name=$1 from users) p"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeParamsSql), - com.google.spanner.v1.ResultSet.newBuilder() - .setMetadata(createMetadata(ImmutableList.of(TypeCode.STRING))) + Statement.of(sql), + ResultSet.newBuilder() + .setMetadata(createParameterTypesMetadata(ImmutableList.of(TypeCode.STRING))) + .setStats(ResultSetStats.newBuilder().build()) .build())); mockSpanner.putStatementResult( StatementResult.update(Statement.newBuilder(sql).bind("p1").to("foo").build(), 1L)); @@ -619,18 +619,18 @@ public void testDmlBatch() throws Exception { String output = runTest("testDmlBatch", getHost(), pgServer.getLocalPort()); - assertEquals("\n\nexecuted dml batch\n", output); + assertEquals("executed dml batch\n", output); List executeSqlRequests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).stream() - .filter( - request -> - request.getSql().equals(sql) || request.getSql().equals(describeParamsSql)) + .filter(request -> request.getSql().equals(sql)) .collect(Collectors.toList()); assertEquals(1, executeSqlRequests.size()); ExecuteSqlRequest describeRequest = executeSqlRequests.get(0); - assertEquals(describeParamsSql, describeRequest.getSql()); - assertFalse(describeRequest.hasTransaction()); + assertEquals(sql, describeRequest.getSql()); + assertTrue(describeRequest.hasTransaction()); + assertTrue(describeRequest.getTransaction().hasBegin()); + assertTrue(describeRequest.getTransaction().getBegin().hasReadWrite()); List batchDmlRequests = mockSpanner.getRequestsOfType(ExecuteBatchDmlRequest.class); @@ -647,7 +647,9 @@ public void testDmlBatch() throws Exception { expectedValues[i], request.getStatements(i).getParams().getFieldsMap().get("p1").getStringValue()); } - assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + // We get two commits, because PGAdapter auto-describes the DML statement in a separate + // transaction if the auto-describe happens during a DML batch. + assertEquals(2, mockSpanner.countRequestsOfType(CommitRequest.class)); } @Test @@ -656,7 +658,7 @@ public void testDdlBatch() throws Exception { String output = runTest("testDdlBatch", getHost(), pgServer.getLocalPort()); - assertEquals("\n\nexecuted ddl batch\n", output); + assertEquals("executed ddl batch\n", output); assertEquals(1, mockDatabaseAdmin.getRequests().size()); assertEquals(UpdateDatabaseDdlRequest.class, mockDatabaseAdmin.getRequests().get(0).getClass()); UpdateDatabaseDdlRequest request = diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/TypeORMMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/TypeORMMockServerTest.java index e82ddedaa..93e241406 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/TypeORMMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/TypeORMMockServerTest.java @@ -18,16 +18,22 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import com.google.cloud.ByteArray; +import com.google.cloud.Date; import com.google.cloud.Timestamp; import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; import com.google.cloud.spanner.Statement; import com.google.cloud.spanner.pgadapter.AbstractMockServerTest; +import com.google.common.collect.ImmutableList; import com.google.protobuf.ListValue; import com.google.protobuf.Value; +import com.google.spanner.v1.BeginTransactionRequest; import com.google.spanner.v1.CommitRequest; import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.ExecuteSqlRequest.QueryMode; import com.google.spanner.v1.ResultSet; import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.ResultSetStats; import com.google.spanner.v1.StructType; import com.google.spanner.v1.StructType.Field; import com.google.spanner.v1.Type; @@ -88,6 +94,22 @@ public void testFindOneUser() throws IOException, InterruptedException { mockSpanner.putStatementResult( StatementResult.query( Statement.of(sql), + ResultSet.newBuilder() + .setMetadata( + USERS_METADATA + .toBuilder() + .setUndeclaredParameters( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("p1") + .setType(Type.newBuilder().setCode(TypeCode.INT64).build()) + .build()) + .build())) + .build())); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.newBuilder(sql).bind("p1").to(1L).build(), ResultSet.newBuilder() .setMetadata(USERS_METADATA) .addRows( @@ -101,27 +123,40 @@ public void testFindOneUser() throws IOException, InterruptedException { String output = runTest("findOneUser", pgServer.getLocalPort()); - assertEquals("\n\nFound user 1 with name Timber Saw\n", output); + assertEquals("Found user 1 with name Timber Saw\n", output); List executeSqlRequests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).stream() .filter(request -> request.getSql().equals(sql)) .collect(Collectors.toList()); - assertEquals(1, executeSqlRequests.size()); - ExecuteSqlRequest request = executeSqlRequests.get(0); + assertEquals(2, executeSqlRequests.size()); + ExecuteSqlRequest describeRequest = executeSqlRequests.get(0); + ExecuteSqlRequest executeRequest = executeSqlRequests.get(1); // The TypeORM PostgreSQL driver sends both a Flush and a Sync message. The Flush message does // a look-ahead to determine if the next message is a Sync, and if it is, executes a Sync on the - // backend connection. This is a lot more efficient, as it means that we can use single-use - // read-only transactions for single queries. + // backend connection. This is a lot more efficient, as it means that we can use a read-only + // transaction for transactions that only contains queries. // There is however no guarantee that the server will see the Sync message in time to do this - // optimization, so in some cases the single query will be using a read/write transaction. + // optimization, so in some cases the single query will be using a read/write transaction, as we + // don't know what might be following the current query. + // This behavior in node-postgres has been fixed in + // https://github.com/brianc/node-postgres/pull/2842, + // but has not yet been released. int commitRequestCount = mockSpanner.countRequestsOfType(CommitRequest.class); if (commitRequestCount == 0) { - assertTrue(request.getTransaction().hasSingleUse()); - assertTrue(request.getTransaction().getSingleUse().hasReadOnly()); + assertEquals(1, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); + assertTrue( + mockSpanner + .getRequestsOfType(BeginTransactionRequest.class) + .get(0) + .getOptions() + .hasReadOnly()); + assertTrue(describeRequest.getTransaction().hasId()); + assertTrue(executeRequest.getTransaction().hasId()); } else if (commitRequestCount == 1) { - assertTrue(request.getTransaction().hasBegin()); - assertTrue(request.getTransaction().getBegin().hasReadWrite()); + assertTrue(describeRequest.getTransaction().hasBegin()); + assertTrue(describeRequest.getTransaction().getBegin().hasReadWrite()); + assertTrue(executeRequest.getTransaction().hasId()); } else { fail("Invalid commit count: " + commitRequestCount); } @@ -135,22 +170,79 @@ public void testCreateUser() throws IOException, InterruptedException { + "FROM \"user\" \"User\" WHERE \"User\".\"id\" IN ($1)"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(existsSql), ResultSet.newBuilder().setMetadata(USERS_METADATA).build())); + Statement.of(existsSql), + ResultSet.newBuilder() + .setMetadata( + USERS_METADATA + .toBuilder() + .setUndeclaredParameters( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("p1") + .setType(Type.newBuilder().setCode(TypeCode.INT64).build()) + .build()) + .build())) + .build())); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.newBuilder(existsSql).bind("p1").to(1L).build(), + ResultSet.newBuilder().setMetadata(USERS_METADATA).build())); String insertSql = "INSERT INTO \"user\"(\"id\", \"firstName\", \"lastName\", \"age\") VALUES ($1, $2, $3, $4)"; - mockSpanner.putStatementResult(StatementResult.update(Statement.of(insertSql), 1L)); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(insertSql), + ResultSet.newBuilder() + .setMetadata( + createParameterTypesMetadata( + ImmutableList.of( + TypeCode.INT64, TypeCode.STRING, TypeCode.STRING, TypeCode.INT64))) + .setStats(ResultSetStats.newBuilder().build()) + .build())); + mockSpanner.putStatementResult( + StatementResult.update( + Statement.newBuilder(insertSql) + .bind("p1") + .to(1L) + .bind("p2") + .to("Timber") + .bind("p3") + .to("Saw") + .bind("p4") + .to(25L) + .build(), + 1L)); String sql = "SELECT \"User\".\"id\" AS \"User_id\", \"User\".\"firstName\" AS \"User_firstName\", " + "\"User\".\"lastName\" AS \"User_lastName\", \"User\".\"age\" AS \"User_age\" " + "FROM \"user\" \"User\" WHERE (\"User\".\"firstName\" = $1 AND \"User\".\"lastName\" = $2) " + "LIMIT 1"; - // The parameter is sent as an untyped parameter, and therefore not included in the statement - // lookup on the mock server, hence the Statement.of(sql) instead of building a statement that - // does include the parameter value. mockSpanner.putStatementResult( StatementResult.query( Statement.of(sql), + ResultSet.newBuilder() + .setMetadata( + USERS_METADATA + .toBuilder() + .setUndeclaredParameters( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("p1") + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p2") + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .build()) + .build())) + .build())); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.newBuilder(sql).bind("p1").to("Timber").bind("p2").to("Saw").build(), ResultSet.newBuilder() .setMetadata(USERS_METADATA) .addRows( @@ -164,7 +256,7 @@ public void testCreateUser() throws IOException, InterruptedException { String output = runTest("createUser", pgServer.getLocalPort()); - assertEquals("\n\nFound user 1 with name Timber Saw\n", output); + assertEquals("Found user 1 with name Timber Saw\n", output); // Creating the user will use a read/write transaction. The query that checks whether the record // already exists will however not use that transaction, as each statement is executed in @@ -174,25 +266,27 @@ public void testCreateUser() throws IOException, InterruptedException { mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).stream() .filter(request -> request.getSql().equals(existsSql)) .collect(Collectors.toList()); - assertEquals(1, checkExistsRequests.size()); - ExecuteSqlRequest checkExistsRequest = checkExistsRequests.get(0); - if (checkExistsRequest.getTransaction().hasSingleUse()) { - assertTrue(checkExistsRequest.getTransaction().getSingleUse().hasReadOnly()); - } else if (checkExistsRequest.getTransaction().hasBegin()) { - assertTrue(checkExistsRequest.getTransaction().getBegin().hasReadWrite()); + assertEquals(2, checkExistsRequests.size()); + ExecuteSqlRequest describeCheckExistsRequest = checkExistsRequests.get(0); + ExecuteSqlRequest executeCheckExistsRequest = checkExistsRequests.get(1); + assertEquals(QueryMode.PLAN, describeCheckExistsRequest.getQueryMode()); + assertEquals(QueryMode.NORMAL, executeCheckExistsRequest.getQueryMode()); + if (describeCheckExistsRequest.getTransaction().hasBegin()) { expectedCommitCount++; - } else { - fail("missing or invalid transaction option: " + checkExistsRequest.getTransaction()); } List insertRequests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).stream() .filter(request -> request.getSql().equals(insertSql)) .collect(Collectors.toList()); - assertEquals(1, insertRequests.size()); - ExecuteSqlRequest insertRequest = insertRequests.get(0); - assertTrue(insertRequest.getTransaction().hasBegin()); - assertTrue(insertRequest.getTransaction().getBegin().hasReadWrite()); + assertEquals(2, insertRequests.size()); + ExecuteSqlRequest describeInsertRequest = insertRequests.get(0); + assertEquals(QueryMode.PLAN, describeInsertRequest.getQueryMode()); + assertTrue(describeInsertRequest.getTransaction().hasBegin()); + assertTrue(describeInsertRequest.getTransaction().getBegin().hasReadWrite()); + ExecuteSqlRequest executeInsertRequest = insertRequests.get(1); + assertEquals(QueryMode.NORMAL, executeInsertRequest.getQueryMode()); + assertTrue(executeInsertRequest.getTransaction().hasId()); expectedCommitCount++; // Loading the user after having saved it will be done in a single-use read-only transaction. @@ -200,15 +294,13 @@ public void testCreateUser() throws IOException, InterruptedException { mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).stream() .filter(request -> request.getSql().equals(sql)) .collect(Collectors.toList()); - assertEquals(1, loadRequests.size()); - ExecuteSqlRequest loadRequest = loadRequests.get(0); - if (loadRequest.getTransaction().hasSingleUse()) { - assertTrue(loadRequest.getTransaction().getSingleUse().hasReadOnly()); - } else if (loadRequest.getTransaction().hasBegin()) { - assertTrue(loadRequest.getTransaction().getBegin().hasReadWrite()); + assertEquals(2, loadRequests.size()); + ExecuteSqlRequest describeLoadRequest = loadRequests.get(0); + assertEquals(QueryMode.PLAN, describeLoadRequest.getQueryMode()); + ExecuteSqlRequest executeLoadRequest = loadRequests.get(1); + assertEquals(QueryMode.NORMAL, executeLoadRequest.getQueryMode()); + if (describeLoadRequest.getTransaction().hasBegin()) { expectedCommitCount++; - } else { - fail("missing or invalid transaction option: " + loadRequest.getTransaction()); } assertEquals(expectedCommitCount, mockSpanner.countRequestsOfType(CommitRequest.class)); } @@ -222,6 +314,23 @@ public void testUpdateUser() throws IOException, InterruptedException { mockSpanner.putStatementResult( StatementResult.query( Statement.of(loadSql), + ResultSet.newBuilder() + .setMetadata( + USERS_METADATA + .toBuilder() + .setUndeclaredParameters( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("p1") + .setType(Type.newBuilder().setCode(TypeCode.INT64).build()) + .build()) + .build()) + .build()) + .build())); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.newBuilder(loadSql).bind("p1").to(1L).build(), ResultSet.newBuilder() .setMetadata(USERS_METADATA) .addRows( @@ -232,6 +341,7 @@ public void testUpdateUser() throws IOException, InterruptedException { .addValues(Value.newBuilder().setStringValue("25").build()) .build()) .build())); + String existsSql = "SELECT \"User\".\"id\" AS \"User_id\", \"User\".\"firstName\" AS \"User_firstName\", " + "\"User\".\"lastName\" AS \"User_lastName\", \"User\".\"age\" AS \"User_age\" " @@ -239,6 +349,23 @@ public void testUpdateUser() throws IOException, InterruptedException { mockSpanner.putStatementResult( StatementResult.query( Statement.of(existsSql), + ResultSet.newBuilder() + .setMetadata( + USERS_METADATA + .toBuilder() + .setUndeclaredParameters( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("p1") + .setType(Type.newBuilder().setCode(TypeCode.INT64).build()) + .build()) + .build()) + .build()) + .build())); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.newBuilder(existsSql).bind("p1").to(1L).build(), ResultSet.newBuilder() .setMetadata(USERS_METADATA) .addRows( @@ -252,11 +379,33 @@ public void testUpdateUser() throws IOException, InterruptedException { String updateSql = "UPDATE \"user\" SET \"firstName\" = $1, \"lastName\" = $2, \"age\" = $3 WHERE \"id\" IN ($4)"; - mockSpanner.putStatementResult(StatementResult.update(Statement.of(updateSql), 1L)); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(updateSql), + ResultSet.newBuilder() + .setMetadata( + createParameterTypesMetadata( + ImmutableList.of( + TypeCode.STRING, TypeCode.STRING, TypeCode.INT64, TypeCode.INT64))) + .setStats(ResultSetStats.newBuilder().build()) + .build())); + mockSpanner.putStatementResult( + StatementResult.update( + Statement.newBuilder(updateSql) + .bind("p1") + .to("Lumber") + .bind("p2") + .to("Jack") + .bind("p3") + .to(45L) + .bind("p4") + .to(1L) + .build(), + 1L)); String output = runTest("updateUser", pgServer.getLocalPort()); - assertEquals("\n\nUpdated user 1\n", output); + assertEquals("Updated user 1\n", output); // Updating the user will use a read/write transaction. The query that checks whether the record // already exists will however not use that transaction, as each statement is executed in @@ -266,40 +415,38 @@ public void testUpdateUser() throws IOException, InterruptedException { mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).stream() .filter(request -> request.getSql().equals(loadSql)) .collect(Collectors.toList()); - assertEquals(1, loadRequests.size()); - ExecuteSqlRequest loadRequest = loadRequests.get(0); - if (loadRequest.getTransaction().hasSingleUse()) { - assertTrue(loadRequest.getTransaction().getSingleUse().hasReadOnly()); - } else if (loadRequest.getTransaction().hasBegin()) { - assertTrue(loadRequest.getTransaction().getBegin().hasReadWrite()); + assertEquals(2, loadRequests.size()); + ExecuteSqlRequest describeLoadRequest = loadRequests.get(0); + assertEquals(QueryMode.PLAN, describeLoadRequest.getQueryMode()); + ExecuteSqlRequest executeLoadRequest = loadRequests.get(1); + assertEquals(QueryMode.NORMAL, executeLoadRequest.getQueryMode()); + if (describeLoadRequest.getTransaction().hasBegin()) { expectedCommitCount++; - } else { - fail("missing or invalid transaction option: " + loadRequest.getTransaction()); } List checkExistsRequests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).stream() .filter(request -> request.getSql().equals(existsSql)) .collect(Collectors.toList()); - assertEquals(1, checkExistsRequests.size()); - ExecuteSqlRequest checkExistsRequest = checkExistsRequests.get(0); - if (checkExistsRequest.getTransaction().hasSingleUse()) { - assertTrue(checkExistsRequest.getTransaction().getSingleUse().hasReadOnly()); - } else if (checkExistsRequest.getTransaction().hasBegin()) { - assertTrue(checkExistsRequest.getTransaction().getBegin().hasReadWrite()); + assertEquals(2, checkExistsRequests.size()); + ExecuteSqlRequest describeCheckExistsRequest = checkExistsRequests.get(0); + assertEquals(QueryMode.PLAN, describeCheckExistsRequest.getQueryMode()); + ExecuteSqlRequest executeCheckExistsRequest = checkExistsRequests.get(1); + assertEquals(QueryMode.NORMAL, executeCheckExistsRequest.getQueryMode()); + if (describeCheckExistsRequest.getTransaction().hasBegin()) { expectedCommitCount++; - } else { - fail("missing or invalid transaction option: " + checkExistsRequest.getTransaction()); } List updateRequests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).stream() .filter(request -> request.getSql().equals(updateSql)) .collect(Collectors.toList()); - assertEquals(1, updateRequests.size()); - ExecuteSqlRequest updateRequest = updateRequests.get(0); - assertTrue(updateRequest.getTransaction().hasBegin()); - assertTrue(updateRequest.getTransaction().getBegin().hasReadWrite()); + assertEquals(2, updateRequests.size()); + ExecuteSqlRequest describeUpdateRequest = updateRequests.get(0); + ExecuteSqlRequest executeUpdateRequest = updateRequests.get(1); + assertTrue(describeUpdateRequest.getTransaction().hasBegin()); + assertTrue(describeUpdateRequest.getTransaction().getBegin().hasReadWrite()); + assertTrue(executeUpdateRequest.getTransaction().hasId()); expectedCommitCount++; assertEquals(expectedCommitCount, mockSpanner.countRequestsOfType(CommitRequest.class)); @@ -314,6 +461,22 @@ public void testDeleteUser() throws IOException, InterruptedException { mockSpanner.putStatementResult( StatementResult.query( Statement.of(loadSql), + ResultSet.newBuilder() + .setMetadata( + USERS_METADATA + .toBuilder() + .setUndeclaredParameters( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("p1") + .setType(Type.newBuilder().setCode(TypeCode.INT64).build()) + .build()) + .build())) + .build())); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.newBuilder(loadSql).bind("p1").to(1L).build(), ResultSet.newBuilder() .setMetadata(USERS_METADATA) .addRows( @@ -324,6 +487,7 @@ public void testDeleteUser() throws IOException, InterruptedException { .addValues(Value.newBuilder().setStringValue("25").build()) .build()) .build())); + String existsSql = "SELECT \"User\".\"id\" AS \"User_id\", \"User\".\"firstName\" AS \"User_firstName\", " + "\"User\".\"lastName\" AS \"User_lastName\", \"User\".\"age\" AS \"User_age\" " @@ -331,6 +495,22 @@ public void testDeleteUser() throws IOException, InterruptedException { mockSpanner.putStatementResult( StatementResult.query( Statement.of(existsSql), + ResultSet.newBuilder() + .setMetadata( + USERS_METADATA + .toBuilder() + .setUndeclaredParameters( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("p1") + .setType(Type.newBuilder().setCode(TypeCode.INT64).build()) + .build()) + .build())) + .build())); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.newBuilder(existsSql).bind("p1").to(1L).build(), ResultSet.newBuilder() .setMetadata(USERS_METADATA) .addRows( @@ -343,11 +523,19 @@ public void testDeleteUser() throws IOException, InterruptedException { .build())); String deleteSql = "DELETE FROM \"user\" WHERE \"id\" = $1"; - mockSpanner.putStatementResult(StatementResult.update(Statement.of(deleteSql), 1L)); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(deleteSql), + ResultSet.newBuilder() + .setMetadata(createParameterTypesMetadata(ImmutableList.of(TypeCode.INT64))) + .setStats(ResultSetStats.newBuilder().build()) + .build())); + mockSpanner.putStatementResult( + StatementResult.update(Statement.newBuilder(deleteSql).bind("p1").to(1L).build(), 1L)); String output = runTest("deleteUser", pgServer.getLocalPort()); - assertEquals("\n\nDeleted user 1\n", output); + assertEquals("Deleted user 1\n", output); // Deleting the user will use a read/write transaction. The query that checks whether the record // already exists will however not use that transaction, as each statement is executed in @@ -357,40 +545,39 @@ public void testDeleteUser() throws IOException, InterruptedException { mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).stream() .filter(request -> request.getSql().equals(loadSql)) .collect(Collectors.toList()); - assertEquals(1, loadRequests.size()); - ExecuteSqlRequest loadRequest = loadRequests.get(0); - if (loadRequest.getTransaction().hasSingleUse()) { - assertTrue(loadRequest.getTransaction().getSingleUse().hasReadOnly()); - } else if (loadRequest.getTransaction().hasBegin()) { - assertTrue(loadRequest.getTransaction().getBegin().hasReadWrite()); + assertEquals(2, loadRequests.size()); + ExecuteSqlRequest describeLoadRequest = loadRequests.get(0); + assertEquals(QueryMode.PLAN, describeLoadRequest.getQueryMode()); + ExecuteSqlRequest executeLoadRequest = loadRequests.get(1); + assertEquals(QueryMode.NORMAL, executeLoadRequest.getQueryMode()); + if (describeLoadRequest.getTransaction().hasBegin()) { expectedCommitCount++; - } else { - fail("missing or invalid transaction option: " + loadRequest.getTransaction()); } List checkExistsRequests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).stream() .filter(request -> request.getSql().equals(existsSql)) .collect(Collectors.toList()); - assertEquals(1, checkExistsRequests.size()); - ExecuteSqlRequest checkExistsRequest = checkExistsRequests.get(0); - if (checkExistsRequest.getTransaction().hasSingleUse()) { - assertTrue(checkExistsRequest.getTransaction().getSingleUse().hasReadOnly()); - } else if (checkExistsRequest.getTransaction().hasBegin()) { - assertTrue(checkExistsRequest.getTransaction().getBegin().hasReadWrite()); + assertEquals(2, checkExistsRequests.size()); + ExecuteSqlRequest describeCheckExistsRequest = checkExistsRequests.get(0); + assertEquals(QueryMode.PLAN, describeCheckExistsRequest.getQueryMode()); + ExecuteSqlRequest executeCheckExistsRequest = checkExistsRequests.get(1); + assertEquals(QueryMode.NORMAL, executeCheckExistsRequest.getQueryMode()); + if (describeCheckExistsRequest.getTransaction().hasBegin()) { expectedCommitCount++; - } else { - fail("missing or invalid transaction option: " + checkExistsRequest.getTransaction()); } List deleteRequests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).stream() .filter(request -> request.getSql().equals(deleteSql)) .collect(Collectors.toList()); - assertEquals(1, deleteRequests.size()); - ExecuteSqlRequest deleteRequest = deleteRequests.get(0); - assertTrue(deleteRequest.getTransaction().hasBegin()); - assertTrue(deleteRequest.getTransaction().getBegin().hasReadWrite()); + assertEquals(2, deleteRequests.size()); + ExecuteSqlRequest describeDeleteRequest = deleteRequests.get(0); + assertEquals(QueryMode.PLAN, describeDeleteRequest.getQueryMode()); + ExecuteSqlRequest executeDeleteRequest = deleteRequests.get(1); + assertEquals(QueryMode.NORMAL, executeDeleteRequest.getQueryMode()); + assertTrue(describeDeleteRequest.getTransaction().hasBegin()); + assertTrue(describeDeleteRequest.getTransaction().getBegin().hasReadWrite()); expectedCommitCount++; assertEquals(expectedCommitCount, mockSpanner.countRequestsOfType(CommitRequest.class)); @@ -407,12 +594,31 @@ public void testFindOneAllTypes() throws IOException, InterruptedException { + "FROM \"all_types\" \"AllTypes\" " + "WHERE (\"AllTypes\".\"col_bigint\" = $1) LIMIT 1"; mockSpanner.putStatementResult( - StatementResult.query(Statement.of(sql), createAllTypesResultSet("AllTypes_"))); + StatementResult.query( + Statement.of(sql), + ResultSet.newBuilder() + .setMetadata( + createAllTypesResultSetMetadata("AllTypes_") + .toBuilder() + .setUndeclaredParameters( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("p1") + .setType(Type.newBuilder().setCode(TypeCode.INT64).build()) + .build()) + .build()) + .build()) + .build())); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.newBuilder(sql).bind("p1").to(1L).build(), + createAllTypesResultSet("AllTypes_"))); String output = runTest("findOneAllTypes", pgServer.getLocalPort()); assertEquals( - "\n\nFound row 1\n" + "Found row 1\n" + "AllTypes {\n" + " col_bigint: '1',\n" + " col_bool: true,\n" @@ -432,15 +638,13 @@ public void testFindOneAllTypes() throws IOException, InterruptedException { mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).stream() .filter(request -> request.getSql().equals(sql)) .collect(Collectors.toList()); - assertEquals(1, executeSqlRequests.size()); - ExecuteSqlRequest request = executeSqlRequests.get(0); - if (request.getTransaction().hasSingleUse()) { - assertTrue(request.getTransaction().getSingleUse().hasReadOnly()); - } else if (request.getTransaction().hasBegin()) { - assertTrue(request.getTransaction().getBegin().hasReadWrite()); + assertEquals(2, executeSqlRequests.size()); + ExecuteSqlRequest describeRequest = executeSqlRequests.get(0); + assertEquals(QueryMode.PLAN, describeRequest.getQueryMode()); + ExecuteSqlRequest executeRequest = executeSqlRequests.get(1); + assertEquals(QueryMode.NORMAL, executeRequest.getQueryMode()); + if (describeRequest.getTransaction().hasBegin()) { expectedCommitCount++; - } else { - fail("missing or invalid transaction option: " + request.getTransaction()); } assertEquals(expectedCommitCount, mockSpanner.countRequestsOfType(CommitRequest.class)); } @@ -458,6 +662,23 @@ public void testCreateAllTypes() throws IOException, InterruptedException { mockSpanner.putStatementResult( StatementResult.query( Statement.of(existsSql), + ResultSet.newBuilder() + .setMetadata( + createAllTypesResultSetMetadata("AllTypes_") + .toBuilder() + .setUndeclaredParameters( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("p1") + .setType(Type.newBuilder().setCode(TypeCode.INT64).build()) + .build()) + .build()) + .build()) + .build())); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.newBuilder(existsSql).bind("p1").to(2L).build(), ResultSet.newBuilder() .setMetadata(createAllTypesResultSetMetadata("AllTypes_")) .build())); @@ -466,44 +687,171 @@ public void testCreateAllTypes() throws IOException, InterruptedException { "INSERT INTO \"all_types\"" + "(\"col_bigint\", \"col_bool\", \"col_bytea\", \"col_float8\", \"col_int\", \"col_numeric\", \"col_timestamptz\", \"col_date\", \"col_varchar\", \"col_jsonb\") " + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"; - mockSpanner.putStatementResult(StatementResult.update(Statement.of(insertSql), 1L)); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(insertSql), + ResultSet.newBuilder() + .setMetadata( + createParameterTypesMetadata( + ImmutableList.of( + TypeCode.INT64, + TypeCode.BOOL, + TypeCode.BYTES, + TypeCode.FLOAT64, + TypeCode.INT64, + TypeCode.NUMERIC, + TypeCode.TIMESTAMP, + TypeCode.DATE, + TypeCode.STRING, + TypeCode.JSON))) + .setStats(ResultSetStats.newBuilder().build()) + .build())); + mockSpanner.putStatementResult( + StatementResult.update( + Statement.newBuilder(insertSql) + .bind("p1") + .to(2L) + .bind("p2") + .to(true) + .bind("p3") + .to(ByteArray.copyFrom("some random string")) + .bind("p4") + .to(0.123456789d) + .bind("p5") + .to(123456789L) + .bind("p6") + .to(com.google.cloud.spanner.Value.pgNumeric("234.54235")) + .bind("p7") + .to(Timestamp.parseTimestamp("2022-07-22T18:15:42.011Z")) + .bind("p8") + .to(Date.parseDate("2022-07-22")) + .bind("p9") + .to("some random string") + // TODO: Change to JSONB + .bind("p10") + .to("{\"key\":\"value\"}") + .build(), + 1L)); String output = runTest("createAllTypes", pgServer.getLocalPort()); - assertEquals("\n\nCreated one record\n", output); + assertEquals("Created one record\n", output); List insertRequests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).stream() .filter(request -> request.getSql().equals(insertSql)) .collect(Collectors.toList()); - assertEquals(1, insertRequests.size()); - ExecuteSqlRequest insertRequest = insertRequests.get(0); - assertTrue(insertRequest.getTransaction().hasBegin()); - assertTrue(insertRequest.getTransaction().getBegin().hasReadWrite()); - - // The NodeJS PostgreSQL driver sends parameters without any type information to the backend. - // This means that all parameters are sent as untyped string values. - assertEquals(0, insertRequest.getParamTypesMap().size()); - assertEquals(10, insertRequest.getParams().getFieldsCount()); - assertEquals("2", insertRequest.getParams().getFieldsMap().get("p1").getStringValue()); - assertEquals("true", insertRequest.getParams().getFieldsMap().get("p2").getStringValue()); + assertEquals(2, insertRequests.size()); + ExecuteSqlRequest describeInsertRequest = insertRequests.get(0); + assertEquals(QueryMode.PLAN, describeInsertRequest.getQueryMode()); + ExecuteSqlRequest executeInsertRequest = insertRequests.get(1); + assertEquals(QueryMode.NORMAL, executeInsertRequest.getQueryMode()); + assertTrue(describeInsertRequest.getTransaction().hasBegin()); + assertTrue(describeInsertRequest.getTransaction().getBegin().hasReadWrite()); + assertTrue(executeInsertRequest.getTransaction().hasId()); + + assertEquals(10, executeInsertRequest.getParamTypesMap().size()); + assertEquals(10, executeInsertRequest.getParams().getFieldsCount()); + assertEquals("2", executeInsertRequest.getParams().getFieldsMap().get("p1").getStringValue()); + assertTrue(executeInsertRequest.getParams().getFieldsMap().get("p2").getBoolValue()); assertEquals( "c29tZSByYW5kb20gc3RyaW5n", - insertRequest.getParams().getFieldsMap().get("p3").getStringValue()); + executeInsertRequest.getParams().getFieldsMap().get("p3").getStringValue()); + assertEquals( + 0.123456789d, + executeInsertRequest.getParams().getFieldsMap().get("p4").getNumberValue(), + 0.0d); assertEquals( - "0.123456789", insertRequest.getParams().getFieldsMap().get("p4").getStringValue()); - assertEquals("123456789", insertRequest.getParams().getFieldsMap().get("p5").getStringValue()); - assertEquals("234.54235", insertRequest.getParams().getFieldsMap().get("p6").getStringValue()); + "123456789", executeInsertRequest.getParams().getFieldsMap().get("p5").getStringValue()); + assertEquals( + "234.54235", executeInsertRequest.getParams().getFieldsMap().get("p6").getStringValue()); assertEquals( Timestamp.parseTimestamp("2022-07-22T20:15:42.011+02:00"), Timestamp.parseTimestamp( - insertRequest.getParams().getFieldsMap().get("p7").getStringValue())); - assertEquals("2022-07-22", insertRequest.getParams().getFieldsMap().get("p8").getStringValue()); + executeInsertRequest.getParams().getFieldsMap().get("p7").getStringValue())); + assertEquals( + "2022-07-22", executeInsertRequest.getParams().getFieldsMap().get("p8").getStringValue()); assertEquals( - "some random string", insertRequest.getParams().getFieldsMap().get("p9").getStringValue()); + "some random string", + executeInsertRequest.getParams().getFieldsMap().get("p9").getStringValue()); assertEquals( "{\"key\":\"value\"}", - insertRequest.getParams().getFieldsMap().get("p10").getStringValue()); + executeInsertRequest.getParams().getFieldsMap().get("p10").getStringValue()); + } + + @Test + public void testUpdateAllTypes() throws Exception { + String sql = + "SELECT \"AllTypes\".\"col_bigint\" AS \"AllTypes_col_bigint\", \"AllTypes\".\"col_bool\" AS \"AllTypes_col_bool\", " + + "\"AllTypes\".\"col_bytea\" AS \"AllTypes_col_bytea\", \"AllTypes\".\"col_float8\" AS \"AllTypes_col_float8\", " + + "\"AllTypes\".\"col_int\" AS \"AllTypes_col_int\", \"AllTypes\".\"col_numeric\" AS \"AllTypes_col_numeric\", " + + "\"AllTypes\".\"col_timestamptz\" AS \"AllTypes_col_timestamptz\", \"AllTypes\".\"col_date\" AS \"AllTypes_col_date\", " + + "\"AllTypes\".\"col_varchar\" AS \"AllTypes_col_varchar\", \"AllTypes\".\"col_jsonb\" AS \"AllTypes_col_jsonb\" " + + "FROM \"all_types\" \"AllTypes\" " + + "WHERE (\"AllTypes\".\"col_bigint\" = $1) LIMIT 1"; + mockSpanner.putStatementResult( + StatementResult.query(Statement.of(sql), createAllTypesResultSet("AllTypes_"))); + String updateSql = + "UPDATE \"all_types\" SET \"col_bigint\" = $1, \"col_bool\" = $2, \"col_bytea\" = $3, \"col_float8\" = $4, \"col_int\" = $5, \"col_numeric\" = $6, \"col_timestamptz\" = $7, \"col_date\" = $8, \"col_varchar\" = $9, \"col_jsonb\" = $10 WHERE \"col_bigint\" IN ($11)"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(updateSql), + ResultSet.newBuilder() + .setMetadata( + createParameterTypesMetadata( + ImmutableList.of( + TypeCode.INT64, + TypeCode.BOOL, + TypeCode.BYTES, + TypeCode.FLOAT64, + TypeCode.INT64, + TypeCode.NUMERIC, + TypeCode.TIMESTAMP, + TypeCode.DATE, + TypeCode.STRING, + TypeCode.JSON))) + .setStats(ResultSetStats.newBuilder().build()) + .build())); + mockSpanner.putStatementResult( + StatementResult.update( + Statement.newBuilder(updateSql) + .bind("p1") + .to(1L) + .bind("p2") + .to(false) + .bind("p3") + .to(ByteArray.copyFrom("updated string")) + .bind("p4") + .to(1.23456789) + .bind("p5") + .to(987654321L) + .bind("p6") + .to(com.google.cloud.spanner.Value.pgNumeric("6.626")) + .bind("p7") + .to(Timestamp.parseTimestamp("2022-11-16T10:03:42.999Z")) + .bind("p8") + .to(Date.parseDate("2022-11-16")) + .bind("p9") + .to("some updated string") + // TODO: Change to JSONB + .bind("p10") + .to("{\"key\":\"updated-value\"}") + .build(), + 1L)); + + mockSpanner.putStatementResult(StatementResult.update(Statement.of(updateSql), 1L)); + + String output = runTest("updateAllTypes", pgServer.getLocalPort()); + + assertEquals("Updated one record\n", output); + + // We get two commit requests, because the statement is auto-described the first time the update + // is executed. The auto-describe also runs in autocommit mode. + // TODO: Enable when node-postgres 8.9 has been released. + // assertEquals(2, mockSpanner.countRequestsOfType(CommitRequest.class)); + assertEquals(4, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest updateRequest = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(3); + assertEquals(updateSql, updateRequest.getSql()); } static String runTest(String testName, int port) throws IOException, InterruptedException { diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/parsers/ArrayParserTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/parsers/ArrayParserTest.java index d8fb1d9ed..c2ae5b688 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/parsers/ArrayParserTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/parsers/ArrayParserTest.java @@ -15,6 +15,8 @@ package com.google.cloud.spanner.pgadapter.parsers; import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import com.google.cloud.ByteArray; import com.google.cloud.Date; @@ -25,7 +27,9 @@ import com.google.cloud.spanner.Type; import com.google.cloud.spanner.Type.StructField; import com.google.cloud.spanner.Value; +import com.google.cloud.spanner.pgadapter.session.SessionState; import com.google.common.collect.ImmutableList; +import java.time.ZoneId; import java.util.Arrays; import org.junit.Ignore; import org.junit.Test; @@ -39,7 +43,9 @@ public class ArrayParserTest { public void testInt64StringParse() { ArrayParser parser = new ArrayParser( - createArrayResultSet(Type.int64(), Value.int64Array(Arrays.asList(1L, null, 2L))), 0); + createArrayResultSet(Type.int64(), Value.int64Array(Arrays.asList(1L, null, 2L))), + 0, + mock(SessionState.class)); assertEquals("{1,NULL,2}", parser.stringParse()); } @@ -49,7 +55,8 @@ public void testBoolStringParse() { ArrayParser parser = new ArrayParser( createArrayResultSet(Type.bool(), Value.boolArray(Arrays.asList(true, null, false))), - 0); + 0, + mock(SessionState.class)); assertEquals("{t,NULL,f}", parser.stringParse()); } @@ -62,7 +69,8 @@ public void testBytesStringParse() { Type.bytes(), Value.bytesArray( Arrays.asList(ByteArray.copyFrom("test1"), null, ByteArray.copyFrom("test2")))), - 0); + 0, + mock(SessionState.class)); assertEquals("{\"\\\\x7465737431\",NULL,\"\\\\x7465737432\"}", parser.stringParse()); } @@ -73,7 +81,8 @@ public void testFloat64StringParse() { new ArrayParser( createArrayResultSet( Type.float64(), Value.float64Array(Arrays.asList(3.14, null, 6.626))), - 0); + 0, + mock(SessionState.class)); assertEquals("{3.14,NULL,6.626}", parser.stringParse()); } @@ -84,7 +93,8 @@ public void testNumericStringParse() { new ArrayParser( createArrayResultSet( Type.pgNumeric(), Value.pgNumericArray(Arrays.asList("3.14", null, "6.626"))), - 0); + 0, + mock(SessionState.class)); assertEquals("{3.14,NULL,6.626}", parser.stringParse()); } @@ -98,13 +108,17 @@ public void testDateStringParse() { Value.dateArray( Arrays.asList( Date.parseDate("2022-07-08"), null, Date.parseDate("2000-01-01")))), - 0); + 0, + mock(SessionState.class)); assertEquals("{\"2022-07-08\",NULL,\"2000-01-01\"}", parser.stringParse()); } @Test public void testTimestampStringParse() { + SessionState sessionState = mock(SessionState.class); + when(sessionState.getTimezone()).thenReturn(ZoneId.of("UTC")); + ArrayParser parser = new ArrayParser( createArrayResultSet( @@ -114,7 +128,8 @@ public void testTimestampStringParse() { Timestamp.parseTimestamp("2022-07-08T07:00:02.123456789Z"), null, Timestamp.parseTimestamp("2000-01-01T00:00:00Z")))), - 0); + 0, + sessionState); assertEquals( "{\"2022-07-08 07:00:02.123456789+00\",NULL,\"2000-01-01 00:00:00+00\"}", @@ -127,7 +142,8 @@ public void testStringStringParse() { new ArrayParser( createArrayResultSet( Type.string(), Value.stringArray(Arrays.asList("test1", null, "test2"))), - 0); + 0, + mock(SessionState.class)); assertEquals("{\"test1\",NULL,\"test2\"}", parser.stringParse()); } @@ -141,7 +157,8 @@ public void testJsonStringParse() { Type.json(), Value.jsonArray( Arrays.asList("{\"key\": \"value1\"}}", null, "{\"key\": \"value2\"}"))), - 0); + 0, + mock(SessionState.class)); assertEquals("{\"test1\",NULL,\"test2\"}", parser.stringParse()); } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/parsers/ParserTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/parsers/ParserTest.java index 5633d4d77..a11e0be31 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/parsers/ParserTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/parsers/ParserTest.java @@ -14,12 +14,14 @@ package com.google.cloud.spanner.pgadapter.parsers; +import static com.google.cloud.spanner.pgadapter.parsers.Parser.toOid; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import com.google.cloud.ByteArray; @@ -31,14 +33,15 @@ import com.google.cloud.spanner.Type; import com.google.cloud.spanner.Value; import com.google.cloud.spanner.pgadapter.ProxyServer.DataFormat; +import com.google.cloud.spanner.pgadapter.error.PGException; import com.google.cloud.spanner.pgadapter.parsers.Parser.FormatCode; -import com.google.common.collect.ImmutableSet; +import com.google.cloud.spanner.pgadapter.session.SessionState; +import com.google.spanner.v1.TypeCode; import java.math.BigDecimal; import java.util.Arrays; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import org.mockito.Mockito; import org.postgresql.core.Oid; import org.postgresql.util.ByteConverter; @@ -61,13 +64,13 @@ private void validate( } private void validateCreateBinary(byte[] item, int oid, Object value) { - Parser binary = Parser.create(ImmutableSet.of(), item, oid, FormatCode.BINARY); + Parser binary = Parser.create(mock(SessionState.class), item, oid, FormatCode.BINARY); assertParserValueEqual(binary, value); } private void validateCreateText(byte[] item, int oid, Object value) { - Parser text = Parser.create(ImmutableSet.of(), item, oid, FormatCode.TEXT); + Parser text = Parser.create(mock(SessionState.class), item, oid, FormatCode.TEXT); assertParserValueEqual(text, value); } @@ -245,7 +248,7 @@ public void testTimestampParsingBytePart() { byte[] byteResult = {-1, -1, -38, 1, -93, -70, 48, 0}; - TimestampParser parsedValue = new TimestampParser(value); + TimestampParser parsedValue = new TimestampParser(value, mock(SessionState.class)); assertArrayEquals(byteResult, parsedValue.parse(DataFormat.POSTGRESQL_BINARY)); validateCreateBinary(byteResult, Oid.TIMESTAMP, value); @@ -313,11 +316,11 @@ public void testStringArrayParsing() { '[', '"', 'a', 'b', 'c', '"', ',', '"', 'd', 'e', 'f', '"', ',', '"', 'j', 'h', 'i', '"', ']' }; - ResultSet resultSet = Mockito.mock(ResultSet.class); + ResultSet resultSet = mock(ResultSet.class); when(resultSet.getColumnType(0)).thenReturn(Type.array(Type.string())); when(resultSet.getValue(0)).thenReturn(Value.stringArray(Arrays.asList(value))); - ArrayParser parser = new ArrayParser(resultSet, 0); + ArrayParser parser = new ArrayParser(resultSet, 0, mock(SessionState.class)); validate(parser, byteResult, stringResult, spannerResult); } @@ -339,21 +342,21 @@ public void testLongArrayParsing() { byte[] stringResult = {'{', '1', ',', '2', ',', '3', '}'}; byte[] spannerResult = {'[', '1', ',', '2', ',', '3', ']'}; - ResultSet resultSet = Mockito.mock(ResultSet.class); + ResultSet resultSet = mock(ResultSet.class); when(resultSet.getColumnType(0)).thenReturn(Type.array(Type.int64())); when(resultSet.getValue(0)).thenReturn(Value.int64Array(Arrays.asList(value))); - ArrayParser parser = new ArrayParser(resultSet, 0); + ArrayParser parser = new ArrayParser(resultSet, 0, mock(SessionState.class)); validate(parser, byteResult, stringResult, spannerResult); } @Test(expected = IllegalArgumentException.class) public void testArrayArrayParsingFails() { - ResultSet resultSet = Mockito.mock(ResultSet.class); + ResultSet resultSet = mock(ResultSet.class); when(resultSet.getColumnType(0)).thenReturn(Type.array(Type.array(Type.int64()))); - new ArrayParser(resultSet, 0); + new ArrayParser(resultSet, 0, mock(SessionState.class)); } @Test @@ -409,4 +412,42 @@ public void testNumericParsingNaN() { assertEquals(value, parser.getItem()); validateCreateText(stringResult, Oid.NUMERIC, value); } + + @Test + public void testTypeToOid() { + assertEquals(Oid.INT8, toOid(createType(TypeCode.INT64))); + assertEquals(Oid.BOOL, toOid(createType(TypeCode.BOOL))); + assertEquals(Oid.VARCHAR, toOid(createType(TypeCode.STRING))); + assertEquals(Oid.JSONB, toOid(createType(TypeCode.JSON))); + assertEquals(Oid.FLOAT8, toOid(createType(TypeCode.FLOAT64))); + assertEquals(Oid.TIMESTAMPTZ, toOid(createType(TypeCode.TIMESTAMP))); + assertEquals(Oid.DATE, toOid(createType(TypeCode.DATE))); + assertEquals(Oid.NUMERIC, toOid(createType(TypeCode.NUMERIC))); + assertEquals(Oid.BYTEA, toOid(createType(TypeCode.BYTES))); + + assertEquals(Oid.INT8_ARRAY, toOid(createArrayType(TypeCode.INT64))); + assertEquals(Oid.BOOL_ARRAY, toOid(createArrayType(TypeCode.BOOL))); + assertEquals(Oid.VARCHAR_ARRAY, toOid(createArrayType(TypeCode.STRING))); + assertEquals(Oid.JSONB_ARRAY, toOid(createArrayType(TypeCode.JSON))); + assertEquals(Oid.FLOAT8_ARRAY, toOid(createArrayType(TypeCode.FLOAT64))); + assertEquals(Oid.TIMESTAMPTZ_ARRAY, toOid(createArrayType(TypeCode.TIMESTAMP))); + assertEquals(Oid.DATE_ARRAY, toOid(createArrayType(TypeCode.DATE))); + assertEquals(Oid.NUMERIC_ARRAY, toOid(createArrayType(TypeCode.NUMERIC))); + assertEquals(Oid.BYTEA_ARRAY, toOid(createArrayType(TypeCode.BYTES))); + + assertThrows(PGException.class, () -> toOid(createType(TypeCode.STRUCT))); + assertThrows(PGException.class, () -> toOid(createArrayType(TypeCode.ARRAY))); + assertThrows(PGException.class, () -> toOid(createArrayType(TypeCode.STRUCT))); + } + + static com.google.spanner.v1.Type createType(TypeCode code) { + return com.google.spanner.v1.Type.newBuilder().setCode(code).build(); + } + + static com.google.spanner.v1.Type createArrayType(TypeCode code) { + return com.google.spanner.v1.Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType(com.google.spanner.v1.Type.newBuilder().setCode(code).build()) + .build(); + } } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/parsers/TimestampParserTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/parsers/TimestampParserTest.java index f169aa643..932fb5d1a 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/parsers/TimestampParserTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/parsers/TimestampParserTest.java @@ -19,6 +19,8 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import com.google.cloud.Timestamp; import com.google.cloud.spanner.ErrorCode; @@ -31,8 +33,10 @@ import com.google.cloud.spanner.pgadapter.ProxyServer.DataFormat; import com.google.cloud.spanner.pgadapter.error.PGException; import com.google.cloud.spanner.pgadapter.parsers.Parser.FormatCode; +import com.google.cloud.spanner.pgadapter.session.SessionState; import com.google.common.collect.ImmutableList; import java.nio.charset.StandardCharsets; +import java.time.ZoneId; import java.util.Random; import org.junit.Test; import org.junit.runner.RunWith; @@ -58,8 +62,11 @@ public void testToTimestamp() { assertThrows(SpannerException.class, () -> TimestampParser.toTimestamp(new byte[4])); assertEquals(ErrorCode.INVALID_ARGUMENT, spannerException.getErrorCode()); - assertArrayEquals(data, new TimestampParser(TimestampParser.toTimestamp(data)).binaryParse()); - assertNull(new TimestampParser(null).binaryParse()); + assertArrayEquals( + data, + new TimestampParser(TimestampParser.toTimestamp(data), mock(SessionState.class)) + .binaryParse()); + assertNull(new TimestampParser(null, mock(SessionState.class)).binaryParse()); } @Test @@ -68,9 +75,10 @@ public void testSpannerParse() { "2022-07-08T07:22:59.123456789Z", new TimestampParser( "2022-07-08 07:22:59.123456789+00".getBytes(StandardCharsets.UTF_8), - FormatCode.TEXT) + FormatCode.TEXT, + mock(SessionState.class)) .spannerParse()); - assertNull(new TimestampParser(null).spannerParse()); + assertNull(new TimestampParser(null, mock(SessionState.class)).spannerParse()); ResultSet resultSet = ResultSets.forRows( @@ -83,18 +91,23 @@ public void testSpannerParse() { resultSet.next(); assertArrayEquals( "2022-07-08T07:22:59.123456789Z".getBytes(StandardCharsets.UTF_8), - TimestampParser.convertToPG(resultSet, 0, DataFormat.SPANNER)); + TimestampParser.convertToPG(resultSet, 0, DataFormat.SPANNER, ZoneId.of("UTC"))); } @Test public void testStringParse() { + SessionState sessionState = mock(SessionState.class); + when(sessionState.getTimezone()).thenReturn(ZoneId.of("+00")); assertEquals( "2022-07-08 07:22:59.123456789+00", - new TimestampParser(Timestamp.parseTimestamp("2022-07-08T07:22:59.123456789Z")) + new TimestampParser( + Timestamp.parseTimestamp("2022-07-08T07:22:59.123456789Z"), sessionState) .stringParse()); - assertNull(new TimestampParser(null).stringParse()); + assertNull(new TimestampParser(null, sessionState).stringParse()); assertThrows( PGException.class, - () -> new TimestampParser("foo".getBytes(StandardCharsets.UTF_8), FormatCode.TEXT)); + () -> + new TimestampParser( + "foo".getBytes(StandardCharsets.UTF_8), FormatCode.TEXT, sessionState)); } } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/parsers/UnspecifiedParserTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/parsers/UnspecifiedParserTest.java index 221bf548f..93cbae291 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/parsers/UnspecifiedParserTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/parsers/UnspecifiedParserTest.java @@ -18,7 +18,9 @@ import static org.junit.Assert.assertNull; import com.google.cloud.spanner.Value; +import com.google.cloud.spanner.pgadapter.parsers.Parser.FormatCode; import com.google.protobuf.NullValue; +import java.nio.charset.StandardCharsets; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -42,5 +44,9 @@ public void testStringParse() { .build())) .stringParse()); assertNull(new UnspecifiedParser(null).stringParse()); + assertEquals( + "test", + new UnspecifiedParser("test".getBytes(StandardCharsets.UTF_8), FormatCode.TEXT) + .stringParse()); } } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/parsers/UuidParserTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/parsers/UuidParserTest.java new file mode 100644 index 000000000..e406b6863 --- /dev/null +++ b/src/test/java/com/google/cloud/spanner/pgadapter/parsers/UuidParserTest.java @@ -0,0 +1,149 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.spanner.pgadapter.parsers; + +import static com.google.cloud.spanner.pgadapter.parsers.UuidParser.binaryEncode; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; + +import com.google.cloud.spanner.pgadapter.error.PGException; +import com.google.cloud.spanner.pgadapter.error.SQLState; +import com.google.cloud.spanner.pgadapter.error.Severity; +import com.google.cloud.spanner.pgadapter.parsers.Parser.FormatCode; +import com.google.cloud.spanner.pgadapter.session.SessionState; +import java.nio.charset.StandardCharsets; +import java.util.UUID; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.postgresql.core.Oid; +import org.postgresql.util.ByteConverter; + +@RunWith(JUnit4.class) +public class UuidParserTest { + + @Test + public void testCreate() { + assertEquals( + UuidParser.class, + Parser.create( + mock(SessionState.class), + UUID.randomUUID().toString().getBytes(StandardCharsets.UTF_8), + Oid.UUID, + FormatCode.TEXT) + .getClass()); + assertEquals( + UuidParser.class, + Parser.create( + mock(SessionState.class), + binaryEncode(UUID.randomUUID().toString()), + Oid.UUID, + FormatCode.BINARY) + .getClass()); + } + + @Test + public void testTextToText() { + String uuidStringValue = "c852ee2a-7521-4a70-a02f-2b9d0dd9c19a"; + UuidParser parser = + new UuidParser(uuidStringValue.getBytes(StandardCharsets.UTF_8), FormatCode.TEXT); + assertEquals(uuidStringValue, parser.stringParse()); + + parser = new UuidParser(null, FormatCode.TEXT); + assertNull(parser.stringParse()); + } + + @Test + public void testTextToBinary() { + String uuidStringValue = "c852ee2a-7521-4a70-a02f-2b9d0dd9c19a"; + UuidParser parser = + new UuidParser(uuidStringValue.getBytes(StandardCharsets.UTF_8), FormatCode.TEXT); + UUID uuid = UUID.fromString(uuidStringValue); + byte[] bytes = new byte[16]; + ByteConverter.int8(bytes, 0, uuid.getMostSignificantBits()); + ByteConverter.int8(bytes, 8, uuid.getLeastSignificantBits()); + assertArrayEquals(bytes, parser.binaryParse()); + + parser = new UuidParser(null, FormatCode.TEXT); + assertNull(parser.binaryParse()); + } + + @Test + public void testBinaryToText() { + String uuidStringValue = "c852ee2a-7521-4a70-a02f-2b9d0dd9c19a"; + UUID uuid = UUID.fromString(uuidStringValue); + byte[] bytes = new byte[16]; + ByteConverter.int8(bytes, 0, uuid.getMostSignificantBits()); + ByteConverter.int8(bytes, 8, uuid.getLeastSignificantBits()); + + UuidParser parser = new UuidParser(bytes, FormatCode.BINARY); + assertEquals(uuidStringValue, parser.stringParse()); + + parser = new UuidParser(null, FormatCode.BINARY); + assertNull(parser.stringParse()); + } + + @Test + public void testBinaryToBinary() { + String uuidStringValue = "c852ee2a-7521-4a70-a02f-2b9d0dd9c19a"; + UUID uuid = UUID.fromString(uuidStringValue); + byte[] bytes = new byte[16]; + ByteConverter.int8(bytes, 0, uuid.getMostSignificantBits()); + ByteConverter.int8(bytes, 8, uuid.getLeastSignificantBits()); + + UuidParser parser = new UuidParser(bytes, FormatCode.BINARY); + assertArrayEquals(bytes, parser.binaryParse()); + + parser = new UuidParser(null, FormatCode.BINARY); + assertNull(parser.binaryParse()); + } + + @Test + public void testInvalidBinaryInput() { + PGException exception = + assertThrows(PGException.class, () -> new UuidParser(new byte[8], FormatCode.BINARY)); + assertEquals(SQLState.InvalidParameterValue, exception.getSQLState()); + assertEquals(Severity.ERROR, exception.getSeverity()); + } + + @Test + public void testInvalidTextInput() { + PGException exception = + assertThrows( + PGException.class, + () -> new UuidParser("foo".getBytes(StandardCharsets.UTF_8), FormatCode.TEXT)); + assertEquals(SQLState.InvalidParameterValue, exception.getSQLState()); + assertEquals(Severity.ERROR, exception.getSeverity()); + } + + @Test + public void testInvalidTextValueForBinaryEncode() { + PGException exception = assertThrows(PGException.class, () -> UuidParser.binaryEncode("bar")); + assertEquals(SQLState.InvalidParameterValue, exception.getSQLState()); + assertEquals(Severity.ERROR, exception.getSeverity()); + } + + @Test + public void testHandleInvalidFormatCode() { + PGException exception = + assertThrows(PGException.class, () -> UuidParser.handleInvalidFormat(FormatCode.TEXT)); + assertEquals(SQLState.InternalError, exception.getSQLState()); + assertEquals(Severity.ERROR, exception.getSeverity()); + assertEquals("Unsupported format: TEXT", exception.getMessage()); + } +} diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/python/django/DjangoTestSetup.java b/src/test/java/com/google/cloud/spanner/pgadapter/python/django/DjangoTestSetup.java index 4cdbe80a8..96b0f770c 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/python/django/DjangoTestSetup.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/python/django/DjangoTestSetup.java @@ -41,10 +41,10 @@ static boolean isPythonAvailable() { private static String DJANGO_PATH = "./src/test/python/django"; - public String executeBasicTests(int port, String host, List options) + private static String execute(int port, String host, List options, String testFileName) throws IOException, InterruptedException { List runCommand = - new ArrayList<>(Arrays.asList("python3", "basic_test.py", host, Integer.toString(port))); + new ArrayList<>(Arrays.asList("python3", testFileName, host, Integer.toString(port))); runCommand.addAll(options); ProcessBuilder builder = new ProcessBuilder(); builder.command(runCommand); @@ -61,4 +61,14 @@ public String executeBasicTests(int port, String host, List options) return output.toString(); } + + public String executeBasicTests(int port, String host, List options) + throws IOException, InterruptedException { + return execute(port, host, options, "basic_test.py"); + } + + public String executeTransactionTests(int port, String host, List options) + throws IOException, InterruptedException { + return execute(port, host, options, "transaction_test.py"); + } } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/python/django/DjangoTransactionsTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/python/django/DjangoTransactionsTest.java new file mode 100644 index 000000000..96b0218ca --- /dev/null +++ b/src/test/java/com/google/cloud/spanner/pgadapter/python/django/DjangoTransactionsTest.java @@ -0,0 +1,198 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.spanner.pgadapter.python.django; + +import static org.junit.Assert.assertEquals; + +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.pgadapter.python.PythonTest; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ListValue; +import com.google.protobuf.Value; +import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.ResultSet; +import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.RollbackRequest; +import com.google.spanner.v1.StructType; +import com.google.spanner.v1.StructType.Field; +import com.google.spanner.v1.Type; +import com.google.spanner.v1.TypeCode; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(Parameterized.class) +@Category(PythonTest.class) +public class DjangoTransactionsTest extends DjangoTestSetup { + + @Parameter public String host; + + @Parameters(name = "host = {0}") + public static List data() { + return ImmutableList.of(new Object[] {"localhost"}, new Object[] {"/tmp"}); + } + + private ResultSet createResultSet(List rows) { + ResultSet.Builder resultSetBuilder = ResultSet.newBuilder(); + + resultSetBuilder.setMetadata( + ResultSetMetadata.newBuilder() + .setRowType( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.INT64).build()) + .setName("singerid") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .setName("firstname") + .build()) + .addFields( + Field.newBuilder() + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .setName("lastname") + .build()) + .build()) + .build()); + for (int i = 0; i < rows.size(); i += 3) { + String singerid = rows.get(i), firstname = rows.get(i + 1), lastname = rows.get(i + 2); + resultSetBuilder.addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue(singerid).build()) + .addValues(Value.newBuilder().setStringValue(firstname).build()) + .addValues(Value.newBuilder().setStringValue(lastname).build()) + .build()); + } + return resultSetBuilder.build(); + } + + @Test + public void transactionCommitTest() throws IOException, InterruptedException { + + String updateSQL1 = + "UPDATE \"singers\" SET \"firstname\" = 'hello', \"lastname\" = 'world' WHERE \"singers\".\"singerid\" = 1"; + String insertSQL1 = + "INSERT INTO \"singers\" (\"singerid\", \"firstname\", \"lastname\") VALUES (1, 'hello', 'world')"; + mockSpanner.putStatementResult(StatementResult.update(Statement.of(updateSQL1), 0)); + mockSpanner.putStatementResult(StatementResult.update(Statement.of(insertSQL1), 1)); + + String updateSQL2 = + "UPDATE \"singers\" SET \"firstname\" = 'hello', \"lastname\" = 'python' WHERE \"singers\".\"singerid\" = 2"; + String insertSQL2 = + "INSERT INTO \"singers\" (\"singerid\", \"firstname\", \"lastname\") VALUES (2, 'hello', 'python')"; + mockSpanner.putStatementResult(StatementResult.update(Statement.of(updateSQL2), 0)); + mockSpanner.putStatementResult(StatementResult.update(Statement.of(insertSQL2), 1)); + + List options = new ArrayList(); + options.add("commit"); + + String actualOutput = executeTransactionTests(pgServer.getLocalPort(), host, options); + String expectedOutput = "Transaction Committed\n"; + + assertEquals(expectedOutput, actualOutput); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + } + + @Test + public void transactionRollbackTest() throws IOException, InterruptedException { + + String updateSQL1 = + "UPDATE \"singers\" SET \"firstname\" = 'hello', \"lastname\" = 'world' WHERE \"singers\".\"singerid\" = 1"; + String insertSQL1 = + "INSERT INTO \"singers\" (\"singerid\", \"firstname\", \"lastname\") VALUES (1, 'hello', 'world')"; + mockSpanner.putStatementResult(StatementResult.update(Statement.of(updateSQL1), 0)); + mockSpanner.putStatementResult(StatementResult.update(Statement.of(insertSQL1), 1)); + + String updateSQL2 = + "UPDATE \"singers\" SET \"firstname\" = 'hello', \"lastname\" = 'python' WHERE \"singers\".\"singerid\" = 2"; + String insertSQL2 = + "INSERT INTO \"singers\" (\"singerid\", \"firstname\", \"lastname\") VALUES (2, 'hello', 'python')"; + mockSpanner.putStatementResult(StatementResult.update(Statement.of(updateSQL2), 0)); + mockSpanner.putStatementResult(StatementResult.update(Statement.of(insertSQL2), 1)); + + List options = new ArrayList(); + options.add("rollback"); + + String actualOutput = executeTransactionTests(pgServer.getLocalPort(), host, options); + String expectedOutput = "Transaction Rollbacked\n"; + + assertEquals(expectedOutput, actualOutput); + assertEquals(1, mockSpanner.countRequestsOfType(RollbackRequest.class)); + assertEquals(0, mockSpanner.countRequestsOfType(CommitRequest.class)); + } + + @Test + public void transactionAtomicTest() throws IOException, InterruptedException { + + String updateSQL1 = + "UPDATE \"singers\" SET \"firstname\" = 'hello', \"lastname\" = 'world' WHERE \"singers\".\"singerid\" = 1"; + String insertSQL1 = + "INSERT INTO \"singers\" (\"singerid\", \"firstname\", \"lastname\") VALUES (1, 'hello', 'world')"; + mockSpanner.putStatementResult(StatementResult.update(Statement.of(updateSQL1), 0)); + mockSpanner.putStatementResult(StatementResult.update(Statement.of(insertSQL1), 1)); + + String updateSQL2 = + "UPDATE \"singers\" SET \"firstname\" = 'hello', \"lastname\" = 'python' WHERE \"singers\".\"singerid\" = 2"; + String insertSQL2 = + "INSERT INTO \"singers\" (\"singerid\", \"firstname\", \"lastname\") VALUES (2, 'hello', 'python')"; + mockSpanner.putStatementResult(StatementResult.update(Statement.of(updateSQL2), 0)); + mockSpanner.putStatementResult(StatementResult.update(Statement.of(insertSQL2), 1)); + + List options = new ArrayList(); + options.add("atomic"); + + String actualOutput = executeTransactionTests(pgServer.getLocalPort(), host, options); + String expectedOutput = "Atomic Successful\n"; + + assertEquals(expectedOutput, actualOutput); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + } + + @Test + public void transactionNestedAtomicTest() throws IOException, InterruptedException { + + String updateSQL1 = + "UPDATE \"singers\" SET \"firstname\" = 'hello', \"lastname\" = 'world' WHERE \"singers\".\"singerid\" = 1"; + String insertSQL1 = + "INSERT INTO \"singers\" (\"singerid\", \"firstname\", \"lastname\") VALUES (1, 'hello', 'world')"; + mockSpanner.putStatementResult(StatementResult.update(Statement.of(updateSQL1), 0)); + mockSpanner.putStatementResult(StatementResult.update(Statement.of(insertSQL1), 1)); + + String updateSQL2 = + "UPDATE \"singers\" SET \"firstname\" = 'hello', \"lastname\" = 'python' WHERE \"singers\".\"singerid\" = 2"; + String insertSQL2 = + "INSERT INTO \"singers\" (\"singerid\", \"firstname\", \"lastname\") VALUES (2, 'hello', 'python')"; + mockSpanner.putStatementResult(StatementResult.update(Statement.of(updateSQL2), 0)); + mockSpanner.putStatementResult(StatementResult.update(Statement.of(insertSQL2), 1)); + + List options = new ArrayList(); + options.add("nested_atomic"); + + String actualOutput = executeTransactionTests(pgServer.getLocalPort(), host, options); + String expectedOutput = "Atomic Successful\n"; + + assertEquals(expectedOutput, actualOutput); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + } +} diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/python/pg8000/Pg8000BasicsTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/python/pg8000/Pg8000BasicsTest.java new file mode 100644 index 000000000..7fe46b85a --- /dev/null +++ b/src/test/java/com/google/cloud/spanner/pgadapter/python/pg8000/Pg8000BasicsTest.java @@ -0,0 +1,169 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.spanner.pgadapter.python.pg8000; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.pgadapter.AbstractMockServerTest; +import com.google.cloud.spanner.pgadapter.python.PythonTest; +import com.google.cloud.spanner.pgadapter.wireprotocol.BindMessage; +import com.google.cloud.spanner.pgadapter.wireprotocol.DescribeMessage; +import com.google.cloud.spanner.pgadapter.wireprotocol.ExecuteMessage; +import com.google.cloud.spanner.pgadapter.wireprotocol.FlushMessage; +import com.google.cloud.spanner.pgadapter.wireprotocol.ParseMessage; +import com.google.cloud.spanner.pgadapter.wireprotocol.QueryMessage; +import com.google.cloud.spanner.pgadapter.wireprotocol.StartupMessage; +import com.google.cloud.spanner.pgadapter.wireprotocol.SyncMessage; +import com.google.cloud.spanner.pgadapter.wireprotocol.TerminateMessage; +import com.google.cloud.spanner.pgadapter.wireprotocol.WireMessage; +import com.google.common.collect.ImmutableList; +import com.google.spanner.v1.ExecuteSqlRequest; +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.Scanner; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(Parameterized.class) +@Category(PythonTest.class) +public class Pg8000BasicsTest extends AbstractMockServerTest { + + @Parameter public String host; + + @Parameters(name = "host = {0}") + public static List data() { + return ImmutableList.of(new Object[] {"localhost"}, new Object[] {"/tmp"}); + } + + static String execute(String script, String host, int port) + throws IOException, InterruptedException { + String[] runCommand = new String[] {"python3", script, host, Integer.toString(port)}; + ProcessBuilder builder = new ProcessBuilder(); + builder.command(runCommand); + builder.directory(new File("./src/test/python/pg8000")); + Process process = builder.start(); + Scanner scanner = new Scanner(process.getInputStream()); + Scanner errorScanner = new Scanner(process.getErrorStream()); + + StringBuilder output = new StringBuilder(); + while (scanner.hasNextLine()) { + output.append(scanner.nextLine()).append("\n"); + } + StringBuilder error = new StringBuilder(); + while (errorScanner.hasNextLine()) { + error.append(errorScanner.nextLine()).append("\n"); + } + int result = process.waitFor(); + assertEquals(error.toString(), 0, result); + + return output.toString(); + } + + @Test + public void testBasicSelect() throws IOException, InterruptedException { + String sql = "SELECT 1"; + + String actualOutput = execute("select1.py", host, pgServer.getLocalPort()); + String expectedOutput = "SELECT 1: [1]\n"; + assertEquals(expectedOutput, actualOutput); + + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(0); + assertEquals(sql, request.getSql()); + assertTrue(request.getTransaction().hasSingleUse()); + assertTrue(request.getTransaction().getSingleUse().hasReadOnly()); + + List messages = getWireMessages(); + assertEquals(4, messages.size()); + assertEquals(StartupMessage.class, messages.get(0).getClass()); + assertEquals(QueryMessage.class, messages.get(1).getClass()); + assertEquals(QueryMessage.class, messages.get(2).getClass()); + assertEquals(TerminateMessage.class, messages.get(3).getClass()); + } + + @Test + public void testSelectAllTypes() throws IOException, InterruptedException { + String sql = "SELECT * FROM all_types"; + mockSpanner.putStatementResult(StatementResult.query(Statement.of(sql), ALL_TYPES_RESULTSET)); + + String actualOutput = execute("select_all_types.py", host, pgServer.getLocalPort()); + String expectedOutput = + "row: [1, True, b'test', 3.14, 100, Decimal('6.626'), " + + "datetime.datetime(2022, 2, 16, 13, 18, 2, 123456, tzinfo=tzutc()), datetime.date(2022, 3, 29), " + + "'test', {'key': 'value'}]\n"; + assertEquals(expectedOutput, actualOutput.replace("tzlocal()", "tzutc()")); + } + + @Test + public void testSelectParameterized() throws IOException, InterruptedException { + String sql = "SELECT * FROM all_types WHERE col_bigint=$1"; + mockSpanner.putStatementResult(StatementResult.query(Statement.of(sql), ALL_TYPES_RESULTSET)); + + String actualOutput = execute("select_parameterized.py", host, pgServer.getLocalPort()); + String expectedOutput = + "first execution: [1, True, b'test', 3.14, 100, Decimal('6.626'), " + + "datetime.datetime(2022, 2, 16, 13, 18, 2, 123456, tzinfo=tzutc()), datetime.date(2022, 3, 29), " + + "'test', {'key': 'value'}]\n" + + "second execution: [1, True, b'test', 3.14, 100, Decimal('6.626'), " + + "datetime.datetime(2022, 2, 16, 13, 18, 2, 123456, tzinfo=tzutc()), datetime.date(2022, 3, 29), " + + "'test', {'key': 'value'}]\n"; + assertEquals(expectedOutput, actualOutput.replace("tzlocal()", "tzutc()")); + + List messages = getWireMessages(); + assertEquals(25, messages.size()); + // Yes, you read that right. 25 messages to execute the same query twice. + // 3 of these messages are not related to executing the queries: + // 1. Startup + // 2. Execute `set time zone 'utc'` + // 3. Terminate + int index = 0; + assertEquals(StartupMessage.class, messages.get(index++).getClass()); + assertEquals(QueryMessage.class, messages.get(index++).getClass()); + + assertEquals(ParseMessage.class, messages.get(index++).getClass()); + assertEquals(FlushMessage.class, messages.get(index++).getClass()); + assertEquals(SyncMessage.class, messages.get(index++).getClass()); + assertEquals(DescribeMessage.class, messages.get(index++).getClass()); + assertEquals(FlushMessage.class, messages.get(index++).getClass()); + assertEquals(SyncMessage.class, messages.get(index++).getClass()); + assertEquals(BindMessage.class, messages.get(index++).getClass()); + assertEquals(FlushMessage.class, messages.get(index++).getClass()); + assertEquals(ExecuteMessage.class, messages.get(index++).getClass()); + assertEquals(FlushMessage.class, messages.get(index++).getClass()); + assertEquals(SyncMessage.class, messages.get(index++).getClass()); + + assertEquals(ParseMessage.class, messages.get(index++).getClass()); + assertEquals(FlushMessage.class, messages.get(index++).getClass()); + assertEquals(SyncMessage.class, messages.get(index++).getClass()); + assertEquals(DescribeMessage.class, messages.get(index++).getClass()); + assertEquals(FlushMessage.class, messages.get(index++).getClass()); + assertEquals(SyncMessage.class, messages.get(index++).getClass()); + assertEquals(BindMessage.class, messages.get(index++).getClass()); + assertEquals(FlushMessage.class, messages.get(index++).getClass()); + assertEquals(ExecuteMessage.class, messages.get(index++).getClass()); + assertEquals(FlushMessage.class, messages.get(index++).getClass()); + assertEquals(SyncMessage.class, messages.get(index++).getClass()); + + assertEquals(TerminateMessage.class, messages.get(index++).getClass()); + } +} diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/python/psycopg2/PythonBasicTests.java b/src/test/java/com/google/cloud/spanner/pgadapter/python/psycopg2/PythonBasicTests.java index 2062ee94c..21f01e819 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/python/psycopg2/PythonBasicTests.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/python/psycopg2/PythonBasicTests.java @@ -29,6 +29,7 @@ import com.google.spanner.v1.ExecuteSqlRequest.QueryMode; import com.google.spanner.v1.ResultSet; import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.ResultSetStats; import com.google.spanner.v1.StructType; import com.google.spanner.v1.StructType.Field; import com.google.spanner.v1.Type; @@ -218,12 +219,13 @@ public void testPreparedInsertWithParameters() throws IOException, InterruptedEx parameters.add("VALUE"); String sql = "INSERT INTO SOME_TABLE(COLUMN_NAME) VALUES ($1)"; - String describeParametersSql = "select $1 from (select COLUMN_NAME=$1 from SOME_TABLE) p"; mockSpanner.putStatementResult( StatementResult.query( - Statement.of(describeParametersSql), - createResultSetWithOnlyMetadata(ImmutableList.of(TypeCode.STRING)))); - mockSpanner.putStatementResult(StatementResult.update(Statement.of(sql), 0)); + Statement.of(sql), + ResultSet.newBuilder() + .setMetadata(createParameterTypesMetadata(ImmutableList.of(TypeCode.STRING))) + .setStats(ResultSetStats.newBuilder().build()) + .build())); mockSpanner.putStatementResult( StatementResult.update(Statement.newBuilder(sql).bind("p1").to("VALUE").build(), 1)); @@ -233,20 +235,17 @@ public void testPreparedInsertWithParameters() throws IOException, InterruptedEx String expectedOutput = "1\n1\n"; assertEquals(expectedOutput, actualOutput); - // We receive 4 ExecuteSqlRequests: - // 1. Describe parameters. - // 2. Analyze the update statement. - // 3. Execute the update statement twice. + // We receive 3 ExecuteSqlRequests: + // 1. Analyze the update statement. + // 2. Execute the update statement twice. List requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); - assertEquals(4, requests.size()); - assertEquals(describeParametersSql, requests.get(0).getSql()); + assertEquals(3, requests.size()); + assertEquals(sql, requests.get(0).getSql()); assertEquals(QueryMode.PLAN, requests.get(0).getQueryMode()); assertEquals(sql, requests.get(1).getSql()); - assertEquals(QueryMode.PLAN, requests.get(1).getQueryMode()); + assertEquals(QueryMode.NORMAL, requests.get(1).getQueryMode()); assertEquals(sql, requests.get(2).getSql()); assertEquals(QueryMode.NORMAL, requests.get(2).getQueryMode()); - assertEquals(sql, requests.get(3).getSql()); - assertEquals(QueryMode.NORMAL, requests.get(3).getQueryMode()); // This is all executed in auto commit mode. That means that the analyzeUpdate call is also // executed in auto commit mode, and is automatically committed. assertEquals(3, mockSpanner.countRequestsOfType(CommitRequest.class)); diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/session/SessionStateTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/session/SessionStateTest.java index c0cd0ca90..29bdd3060 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/session/SessionStateTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/session/SessionStateTest.java @@ -33,12 +33,12 @@ import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata.DdlTransactionMode; import com.google.cloud.spanner.pgadapter.statements.PgCatalog; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; import java.util.List; +import java.util.Map; +import java.util.TimeZone; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import org.postgresql.core.Oid; @RunWith(JUnit4.class) public class SessionStateTest { @@ -827,25 +827,114 @@ public void testDdlTransactionMode_bootVal() { } @Test - public void testGuessTypes_defaultNonJdbc() { - OptionsMetadata optionsMetadata = mock(OptionsMetadata.class); - SessionState state = new SessionState(ImmutableMap.of(), optionsMetadata); - assertEquals(ImmutableSet.of(), state.getGuessTypes()); + public void testGetDefaultTimeZone() { + Map originalSettings = ImmutableMap.copyOf(SessionState.SERVER_SETTINGS); + SessionState.SERVER_SETTINGS.remove("TimeZone"); + try { + OptionsMetadata optionsMetadata = mock(OptionsMetadata.class); + SessionState state = new SessionState(ImmutableMap.of(), optionsMetadata); + assertEquals(TimeZone.getDefault().toZoneId(), state.getTimezone()); + } finally { + SessionState.SERVER_SETTINGS.putAll(originalSettings); + } } @Test - public void testGuessTypes_defaultJdbc() { - OptionsMetadata optionsMetadata = mock(OptionsMetadata.class); - SessionState state = new SessionState(ImmutableMap.of(), optionsMetadata); - state.set("spanner", "guess_types", String.format("%d,%d", Oid.TIMESTAMPTZ, Oid.DATE)); - assertEquals(ImmutableSet.of(Oid.TIMESTAMPTZ, Oid.DATE), state.getGuessTypes()); + public void testTimeZoneResetVal() { + Map originalSettings = ImmutableMap.copyOf(SessionState.SERVER_SETTINGS); + SessionState.SERVER_SETTINGS.put( + "TimeZone", + new PGSetting( + null, + "TimeZone", + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + "Europe/Oslo", + "America/New_York", + null, + null, + false)); + try { + OptionsMetadata optionsMetadata = mock(OptionsMetadata.class); + SessionState state = new SessionState(ImmutableMap.of(), optionsMetadata); + assertEquals("America/New_York", state.getTimezone().getId()); + } finally { + SessionState.SERVER_SETTINGS.putAll(originalSettings); + } } @Test - public void testGuessTypes_invalidOids() { - OptionsMetadata optionsMetadata = mock(OptionsMetadata.class); - SessionState state = new SessionState(ImmutableMap.of(), optionsMetadata); - state.set("spanner", "guess_types", String.format("%d,%d,foo", Oid.TIMESTAMPTZ, Oid.DATE)); - assertEquals(ImmutableSet.of(Oid.TIMESTAMPTZ, Oid.DATE), state.getGuessTypes()); + public void testTimeZoneBootVal() { + Map originalSettings = ImmutableMap.copyOf(SessionState.SERVER_SETTINGS); + SessionState.SERVER_SETTINGS.put( + "TimeZone", + new PGSetting( + null, + "TimeZone", + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + "Europe/Oslo", + null, + null, + null, + false)); + try { + OptionsMetadata optionsMetadata = mock(OptionsMetadata.class); + SessionState state = new SessionState(ImmutableMap.of(), optionsMetadata); + assertEquals("Europe/Oslo", state.getTimezone().getId()); + } finally { + SessionState.SERVER_SETTINGS.putAll(originalSettings); + } + } + + @Test + public void testGetInvalidTimeZone() { + Map originalSettings = ImmutableMap.copyOf(SessionState.SERVER_SETTINGS); + SessionState.SERVER_SETTINGS.put( + "TimeZone", + new PGSetting( + null, + "TimeZone", + null, + null, + null, + null, + null, + null, + null, + null, + "foo/bar", + null, + null, + null, + null, + null, + null, + false)); + try { + OptionsMetadata optionsMetadata = mock(OptionsMetadata.class); + SessionState state = new SessionState(ImmutableMap.of(), optionsMetadata); + assertEquals(TimeZone.getDefault().toZoneId(), state.getTimezone()); + } finally { + SessionState.SERVER_SETTINGS.putAll(originalSettings); + } } } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/statements/BackendConnectionTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/statements/BackendConnectionTest.java index 0ec73a20e..ca6958beb 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/statements/BackendConnectionTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/statements/BackendConnectionTest.java @@ -44,6 +44,7 @@ import com.google.cloud.spanner.connection.StatementResult; import com.google.cloud.spanner.connection.StatementResult.ResultType; import com.google.cloud.spanner.pgadapter.error.PGException; +import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory; import com.google.cloud.spanner.pgadapter.error.SQLState; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; import com.google.cloud.spanner.pgadapter.statements.BackendConnection.NoResult; @@ -58,6 +59,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import java.util.concurrent.RejectedExecutionException; +import java.util.function.Function; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -137,10 +139,12 @@ public void testExecuteStatementsInBatch() { backendConnection.execute( PARSER.parse(Statement.of("CREATE TABLE \"Foo\" (id bigint primary key)")), - Statement.of("CREATE TABLE \"Foo\" (id bigint primary key)")); + Statement.of("CREATE TABLE \"Foo\" (id bigint primary key)"), + Function.identity()); backendConnection.execute( PARSER.parse(Statement.of("CREATE TABLE bar (id bigint primary key, value text)")), - Statement.of("CREATE TABLE bar (id bigint primary key, value text)")); + Statement.of("CREATE TABLE bar (id bigint primary key, value text)"), + Function.identity()); SpannerBatchUpdateException batchUpdateException = assertThrows( @@ -203,8 +207,8 @@ public void testHasDmlOrCopyStatementsAfter() { spannerConnection, mock(OptionsMetadata.class), ImmutableList.of()); - onlyDmlStatements.execute(parsedUpdateStatement, updateStatement); - onlyDmlStatements.execute(parsedUpdateStatement, updateStatement); + onlyDmlStatements.execute(parsedUpdateStatement, updateStatement, Function.identity()); + onlyDmlStatements.execute(parsedUpdateStatement, updateStatement, Function.identity()); assertTrue(onlyDmlStatements.hasDmlOrCopyStatementsAfter(0)); assertTrue(onlyDmlStatements.hasDmlOrCopyStatementsAfter(1)); @@ -225,7 +229,7 @@ public void testHasDmlOrCopyStatementsAfter() { spannerConnection, mock(OptionsMetadata.class), ImmutableList.of()); - dmlAndCopyStatements.execute(parsedUpdateStatement, updateStatement); + dmlAndCopyStatements.execute(parsedUpdateStatement, updateStatement, Function.identity()); dmlAndCopyStatements.executeCopy( parsedCopyStatement, copyStatement, receiver, writer, executor); assertTrue(dmlAndCopyStatements.hasDmlOrCopyStatementsAfter(0)); @@ -237,8 +241,8 @@ public void testHasDmlOrCopyStatementsAfter() { spannerConnection, mock(OptionsMetadata.class), ImmutableList.of()); - onlySelectStatements.execute(parsedSelectStatement, selectStatement); - onlySelectStatements.execute(parsedSelectStatement, selectStatement); + onlySelectStatements.execute(parsedSelectStatement, selectStatement, Function.identity()); + onlySelectStatements.execute(parsedSelectStatement, selectStatement, Function.identity()); assertFalse(onlySelectStatements.hasDmlOrCopyStatementsAfter(0)); assertFalse(onlySelectStatements.hasDmlOrCopyStatementsAfter(1)); @@ -248,8 +252,10 @@ public void testHasDmlOrCopyStatementsAfter() { spannerConnection, mock(OptionsMetadata.class), ImmutableList.of()); - onlyClientSideStatements.execute(parsedClientSideStatement, clientSideStatement); - onlyClientSideStatements.execute(parsedClientSideStatement, clientSideStatement); + onlyClientSideStatements.execute( + parsedClientSideStatement, clientSideStatement, Function.identity()); + onlyClientSideStatements.execute( + parsedClientSideStatement, clientSideStatement, Function.identity()); assertFalse(onlyClientSideStatements.hasDmlOrCopyStatementsAfter(0)); assertFalse(onlyClientSideStatements.hasDmlOrCopyStatementsAfter(1)); @@ -259,8 +265,8 @@ public void testHasDmlOrCopyStatementsAfter() { spannerConnection, mock(OptionsMetadata.class), ImmutableList.of()); - onlyUnknownStatements.execute(parsedUnknownStatement, unknownStatement); - onlyUnknownStatements.execute(parsedUnknownStatement, unknownStatement); + onlyUnknownStatements.execute(parsedUnknownStatement, unknownStatement, Function.identity()); + onlyUnknownStatements.execute(parsedUnknownStatement, unknownStatement, Function.identity()); assertFalse(onlyUnknownStatements.hasDmlOrCopyStatementsAfter(0)); assertFalse(onlyUnknownStatements.hasDmlOrCopyStatementsAfter(1)); @@ -270,8 +276,8 @@ public void testHasDmlOrCopyStatementsAfter() { spannerConnection, mock(OptionsMetadata.class), ImmutableList.of()); - dmlAndSelectStatements.execute(parsedUpdateStatement, updateStatement); - dmlAndSelectStatements.execute(parsedSelectStatement, selectStatement); + dmlAndSelectStatements.execute(parsedUpdateStatement, updateStatement, Function.identity()); + dmlAndSelectStatements.execute(parsedSelectStatement, selectStatement, Function.identity()); assertTrue(dmlAndSelectStatements.hasDmlOrCopyStatementsAfter(0)); assertFalse(dmlAndSelectStatements.hasDmlOrCopyStatementsAfter(1)); @@ -283,7 +289,7 @@ public void testHasDmlOrCopyStatementsAfter() { ImmutableList.of()); copyAndSelectStatements.executeCopy( parsedCopyStatement, copyStatement, receiver, writer, executor); - copyAndSelectStatements.execute(parsedSelectStatement, selectStatement); + copyAndSelectStatements.execute(parsedSelectStatement, selectStatement, Function.identity()); assertTrue(copyAndSelectStatements.hasDmlOrCopyStatementsAfter(0)); assertFalse(copyAndSelectStatements.hasDmlOrCopyStatementsAfter(1)); @@ -295,7 +301,7 @@ public void testHasDmlOrCopyStatementsAfter() { ImmutableList.of()); copyAndUnknownStatements.executeCopy( parsedCopyStatement, copyStatement, receiver, writer, executor); - copyAndUnknownStatements.execute(parsedUnknownStatement, unknownStatement); + copyAndUnknownStatements.execute(parsedUnknownStatement, unknownStatement, Function.identity()); assertTrue(copyAndUnknownStatements.hasDmlOrCopyStatementsAfter(0)); assertFalse(copyAndUnknownStatements.hasDmlOrCopyStatementsAfter(1)); } @@ -320,7 +326,9 @@ public void testExecuteLocalStatement() throws ExecutionException, InterruptedEx DatabaseId.of("p", "i", "d"), connection, mock(OptionsMetadata.class), localStatements); Future resultFuture = backendConnection.execute( - parsedListDatabasesStatement, Statement.of(ListDatabasesStatement.LIST_DATABASES_SQL)); + parsedListDatabasesStatement, + Statement.of(ListDatabasesStatement.LIST_DATABASES_SQL), + Function.identity()); backendConnection.flush(); verify(listDatabasesStatement).execute(backendConnection); @@ -348,7 +356,8 @@ public void testExecuteOtherStatementWithLocalStatements() BackendConnection backendConnection = new BackendConnection( DatabaseId.of("p", "i", "d"), connection, mock(OptionsMetadata.class), localStatements); - Future resultFuture = backendConnection.execute(parsedStatement, statement); + Future resultFuture = + backendConnection.execute(parsedStatement, statement, Function.identity()); backendConnection.flush(); verify(listDatabasesStatement, never()).execute(backendConnection); @@ -381,12 +390,13 @@ public void testGeneralException() { connection, mock(OptionsMetadata.class), EMPTY_LOCAL_STATEMENTS); - Future resultFuture = backendConnection.execute(parsedStatement, statement); + Future resultFuture = + backendConnection.execute(parsedStatement, statement, Function.identity()); backendConnection.flush(); ExecutionException executionException = assertThrows(ExecutionException.class, resultFuture::get); - assertSame(executionException.getCause(), error); + assertEquals(executionException.getCause(), PGExceptionFactory.toPGException(error)); } @Test @@ -404,7 +414,8 @@ public void testCancelledException() { connection, mock(OptionsMetadata.class), EMPTY_LOCAL_STATEMENTS); - Future resultFuture = backendConnection.execute(parsedStatement, statement); + Future resultFuture = + backendConnection.execute(parsedStatement, statement, Function.identity()); backendConnection.flush(); ExecutionException executionException = @@ -441,8 +452,9 @@ public void testDdlExceptionInBatch() { connection, mock(OptionsMetadata.class), EMPTY_LOCAL_STATEMENTS); - Future resultFuture1 = backendConnection.execute(parsedStatement1, statement1); - backendConnection.execute(parsedStatement2, statement2); + Future resultFuture1 = + backendConnection.execute(parsedStatement1, statement1, Function.identity()); + backendConnection.execute(parsedStatement2, statement2, Function.identity()); backendConnection.flush(); // The error will be set on the first statement in the batch, as the error occurs before @@ -465,7 +477,7 @@ public void testReplacePgCatalogTables() { new BackendConnection( DatabaseId.of("p", "i", "d"), connection, options, EMPTY_LOCAL_STATEMENTS); - backendConnection.execute(parsedStatement, statement); + backendConnection.execute(parsedStatement, statement, Function.identity()); backendConnection.flush(); verify(connection) @@ -508,7 +520,7 @@ public void testDisableReplacePgCatalogTables() { new BackendConnection( DatabaseId.of("p", "i", "d"), connection, options, EMPTY_LOCAL_STATEMENTS); - backendConnection.execute(parsedStatement, statement); + backendConnection.execute(parsedStatement, statement, Function.identity()); backendConnection.flush(); verify(connection).execute(statement); @@ -529,7 +541,7 @@ public void testDoNotStartTransactionInBatch() { mock(OptionsMetadata.class), EMPTY_LOCAL_STATEMENTS); - backendConnection.execute(parsedStatement, statement); + backendConnection.execute(parsedStatement, statement, Function.identity()); backendConnection.flush(); verify(connection).execute(statement); diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/statements/IntermediateStatementTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/statements/IntermediateStatementTest.java index 5a523b41b..3cf44c766 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/statements/IntermediateStatementTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/statements/IntermediateStatementTest.java @@ -14,15 +14,10 @@ package com.google.cloud.spanner.pgadapter.statements; -import static com.google.cloud.spanner.pgadapter.statements.IntermediatePreparedStatement.extractParameters; -import static com.google.cloud.spanner.pgadapter.statements.IntermediatePreparedStatement.transformDeleteToSelectParams; -import static com.google.cloud.spanner.pgadapter.statements.IntermediatePreparedStatement.transformInsertToSelectParams; -import static com.google.cloud.spanner.pgadapter.statements.IntermediatePreparedStatement.transformUpdateToSelectParams; import static com.google.cloud.spanner.pgadapter.statements.SimpleParserTest.splitStatements; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; @@ -38,9 +33,13 @@ import com.google.cloud.spanner.connection.StatementResult; import com.google.cloud.spanner.connection.StatementResult.ResultType; import com.google.cloud.spanner.pgadapter.ConnectionHandler; +import com.google.cloud.spanner.pgadapter.error.PGException; +import com.google.cloud.spanner.pgadapter.error.SQLState; import com.google.cloud.spanner.pgadapter.metadata.ConnectionMetadata; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; +import com.google.cloud.spanner.pgadapter.statements.IntermediateStatement.ResultNotReadyBehavior; import com.google.common.collect.ImmutableList; +import java.util.concurrent.Future; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -157,174 +156,20 @@ public void testUpdateResultCount_NoResult() { } @Test - public void testTransformInsertValuesToSelectParams() { - assertEquals( - "select $1, $2 from (select col1=$1, col2=$2 from foo) p", - transformInsert("insert into foo (col1, col2) values ($1, $2)").getSql()); - assertEquals( - "select $1, $2 from (select col1=$1, col2=$2 from foo) p", - transformInsert("insert into foo(col1, col2) values ($1, $2)").getSql()); - assertEquals( - "select $1, $2 from (select col1=$1, col2=$2 from foo) p", - transformInsert("insert into foo (col1, col2) values($1, $2)").getSql()); - assertEquals( - "select $1, $2 from (select col1=$1, col2=$2 from foo) p", - transformInsert("insert into foo(col1, col2) values($1, $2)").getSql()); - assertEquals( - "select $1, $2 from (select col1=$1, col2=$2 from foo) p", - transformInsert("insert foo(col1, col2) values($1, $2)").getSql()); - assertEquals( - "select $1, $2 from (select col1=$1, col2=$2 from foo) p", - transformInsert("insert foo (col1, col2) values ($1, $2)").getSql()); - assertEquals( - "select $1, $2, $3, $4 from (select col1=$1, col2=$2, col1=$3, col2=$4 from foo) p", - transformInsert("insert into foo (col1, col2) values ($1, $2), ($3, $4)").getSql()); - assertEquals( - "select $1, $2 from (select col1=$1::varchar, col2=$2::bigint from foo) p", - transformInsert("insert into foo (col1, col2) values ($1::varchar, $2::bigint)").getSql()); - assertEquals( - "select $1, $2, $3, $4 from (select col1=($1 + $2), col2=$3 || to_char($4) from foo) p", - transformInsert("insert into foo (col1, col2) values (($1 + $2), $3 || to_char($4))") - .getSql()); - assertEquals( - "select $1, $2, $3, $4 from (select col1=($1 + $2), col2=$3 || to_char($4) from foo) p", - transformInsert("insert into foo (col1, col2) values (($1 + $2), $3 || to_char($4))") - .getSql()); - assertEquals( - "select $1, $2, $3, $4, $5 from (select col1=$1 + $2 + 5, col2=$3 || to_char($4) || coalesce($5, '') from foo) p", - transformInsert( - "insert\ninto\nfoo\n(col1,\ncol2 ) values ($1 + $2 + 5, $3 || to_char($4) || coalesce($5, ''))") - .getSql()); - assertEquals( - "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10 from " - + "(select col1=$1, col2=$2, col3=$3, col4=$4, col5=$5, col6=$6, col7=$7, col8=$8, col9=$9, col10=$10 from foo) p", - transformInsert( - "insert into foo (col1, col2, col3, col4, col5, col6, col7, col8, col9, col10) " - + "values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)") - .getSql()); - assertEquals( - "select $1, $2 from (select col1=$1, col2=$2 from \"foo\") p", - transformInsert("insert\"foo\"(col1, col2)values($1, $2)").getSql()); - } - - @Test - public void testTransformInsertSelectToSelectParams() { - assertEquals( - "select $1 from (select * from bar where some_col=$1) p", - transformInsert("insert into foo select * from bar where some_col=$1").getSql()); - assertEquals( - "select $1 from ((select * from bar where some_col=$1)) p", - transformInsert("insert into foo (select * from bar where some_col=$1)").getSql()); - assertEquals( - "select $1 from ((select * from(select col1, col2 from bar) where col2=$1)) p", - transformInsert("insert into foo (select * from(select col1, col2 from bar) where col2=$1)") - .getSql()); - assertEquals( - "select $1 from (select * from bar where some_col=$1) p", - transformInsert("insert foo select * from bar where some_col=$1").getSql()); - assertEquals( - "select $1 from (select * from bar where some_col=$1) p", - transformInsert("insert into foo (col1, col2) select * from bar where some_col=$1") - .getSql()); - assertEquals( - "select $1 from (select * from bar where some_col=$1) p", - transformInsert("insert foo (col1, col2, col3) select * from bar where some_col=$1") - .getSql()); - assertEquals( - "select $1, $2 from (select * from bar where some_col=$1 limit $2) p", - transformInsert( - "insert foo (col1, col2, col3) select * from bar where some_col=$1 limit $2") - .getSql()); - assertNull(transformInsert("insert into foo (col1 values ('test')")); - } - - @Test - public void testTransformUpdateToSelectParams() { - assertEquals( - "select $1, $2, $3 from (select col1=$1, col2=$2 from foo where id=$3) p", - transformUpdate("update foo set col1=$1, col2=$2 where id=$3").getSql()); - assertEquals( - "select $1, $2, $3 from (select col1=col2 + $1, " - + "col2=coalesce($1, $2, $3, to_char(current_timestamp())), " - + "col3 = 15 " - + "from foo where id=$3 and value>100) p", - transformUpdate( - "update foo set col1=col2 + $1 , " - + "col2=coalesce($1, $2, $3, to_char(current_timestamp())), " - + "col3 = 15 " - + "where id=$3 and value>100") - .getSql()); - assertEquals( - "select $1 from (select col1=$1 from foo) p", - transformUpdate("update foo set col1=$1").getSql()); - - assertNull(transformUpdate("update foo col1=1")); - assertNull(transformUpdate("update foo col1=1 hwere col1=2")); - assertNull(transformUpdate("udpate foo col1=1 where col1=2")); - - assertEquals( - "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10 from (select col1=$1, col2=$2, col3=$3, col4=$4, col5=$5, col6=$6, col7=$7, col8=$8, col9=$9 from foo where id=$10) p", - transformUpdate( - "update foo set col1=$1, col2=$2, col3=$3, col4=$4, col5=$5, col6=$6, col7=$7, col8=$8, col9=$9 where id=$10") - .getSql()); - assertEquals( - "select $1, $2, $3 from (select col1=(select col2 from bar where col3=$1), col2=$2 from foo where id=$3) p", - transformUpdate( - "update foo set col1=(select col2 from bar where col3=$1), col2=$2 where id=$3") - .getSql()); - } - - @Test - public void testTransformDeleteToSelectParams() { - assertEquals( - "select $1 from (select 1 from foo where id=$1) p", - transformDelete("delete from foo where id=$1").getSql()); - assertEquals( - "select $1, $2 from (select 1 from foo where id=$1 and bar > $2) p", - transformDelete("delete foo\nwhere id=$1 and bar > $2").getSql()); - assertEquals( - "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10 " - + "from (select 1 from all_types " - + "where col_bigint=$1 " - + "and col_bool=$2 " - + "and col_bytea=$3 " - + "and col_float8=$4 " - + "and col_int=$5 " - + "and col_numeric=$6 " - + "and col_timestamptz=$7 " - + "and col_date=$8 " - + "and col_varchar=$9 " - + "and col_jsonb=$10" - + ") p", - transformDelete( - "delete " - + "from all_types " - + "where col_bigint=$1 " - + "and col_bool=$2 " - + "and col_bytea=$3 " - + "and col_float8=$4 " - + "and col_int=$5 " - + "and col_numeric=$6 " - + "and col_timestamptz=$7 " - + "and col_date=$8 " - + "and col_varchar=$9 " - + "and col_jsonb=$10") - .getSql()); - - assertNull(transformDelete("delete from foo")); - assertNull(transformDelete("dlete from foo where id=$1")); - assertNull(transformDelete("delete from foo hwere col1=2")); - } - - private static Statement transformInsert(String sql) { - return transformInsertToSelectParams(mock(Connection.class), sql, extractParameters(sql)); - } + public void testInterruptedWhileWaitingForResult() throws Exception { + when(connectionHandler.getSpannerConnection()).thenReturn(connection); + when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); + Future future = mock(Future.class); + when(future.get()).thenThrow(new InterruptedException()); - private static Statement transformUpdate(String sql) { - return transformUpdateToSelectParams(sql, extractParameters(sql)); - } + String sql = "update bar set foo=1"; + IntermediateStatement statement = + new IntermediateStatement( + mock(OptionsMetadata.class), parse(sql), Statement.of(sql), connectionHandler); + statement.setFutureStatementResult(future); + statement.initFutureResult(ResultNotReadyBehavior.BLOCK); - private static Statement transformDelete(String sql) { - return transformDeleteToSelectParams(sql, extractParameters(sql)); + PGException pgException = statement.getException(); + assertEquals(SQLState.QueryCanceled, pgException.getSQLState()); } } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/statements/StatementTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/statements/StatementTest.java index bd381dbab..19b87e5bd 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/statements/StatementTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/statements/StatementTest.java @@ -14,6 +14,8 @@ package com.google.cloud.spanner.pgadapter.statements; +import static com.google.cloud.spanner.pgadapter.statements.IntermediatePortalStatement.NO_PARAMS; +import static com.google.cloud.spanner.pgadapter.statements.IntermediatePreparedStatement.NO_PARAMETER_TYPES; import static com.google.cloud.spanner.pgadapter.utils.ClientAutoDetector.EMPTY_LOCAL_STATEMENTS; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -21,6 +23,8 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; @@ -31,7 +35,6 @@ import com.google.cloud.spanner.Dialect; import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.ReadContext; -import com.google.cloud.spanner.ReadContext.QueryAnalyzeMode; import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.SpannerExceptionFactory; @@ -43,6 +46,9 @@ import com.google.cloud.spanner.connection.StatementResult; import com.google.cloud.spanner.pgadapter.ConnectionHandler; import com.google.cloud.spanner.pgadapter.ProxyServer; +import com.google.cloud.spanner.pgadapter.error.PGException; +import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory; +import com.google.cloud.spanner.pgadapter.error.SQLState; import com.google.cloud.spanner.pgadapter.metadata.ConnectionMetadata; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; import com.google.cloud.spanner.pgadapter.session.SessionState; @@ -51,7 +57,6 @@ import com.google.cloud.spanner.pgadapter.wireprotocol.QueryMessage; import com.google.cloud.spanner.pgadapter.wireprotocol.WireMessage; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; import com.google.common.primitives.Bytes; import java.io.ByteArrayInputStream; import java.io.DataInputStream; @@ -62,7 +67,6 @@ import java.nio.file.Paths; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; @@ -118,14 +122,19 @@ public void testBasicSelectStatement() throws Exception { when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); IntermediatePortalStatement intermediateStatement = new IntermediatePortalStatement( - connectionHandler, options, "", parse(sql), Statement.of(sql)); + "", + new IntermediatePreparedStatement( + connectionHandler, options, "", NO_PARAMETER_TYPES, parse(sql), Statement.of(sql)), + NO_PARAMS, + ImmutableList.of(), + ImmutableList.of()); assertFalse(intermediateStatement.isExecuted()); assertEquals("SELECT", intermediateStatement.getCommand()); intermediateStatement.executeAsync(backendConnection); - verify(backendConnection).execute(parse(sql), Statement.of(sql)); + verify(backendConnection).execute(eq(parse(sql)), eq(Statement.of(sql)), any()); assertTrue(intermediateStatement.containsResultSet()); assertTrue(intermediateStatement.isExecuted()); assertEquals(StatementType.QUERY, intermediateStatement.getStatementType()); @@ -142,14 +151,19 @@ public void testBasicUpdateStatement() throws Exception { when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); IntermediatePortalStatement intermediateStatement = new IntermediatePortalStatement( - connectionHandler, options, "", parse(sql), Statement.of(sql)); + "", + new IntermediatePreparedStatement( + connectionHandler, options, "", NO_PARAMETER_TYPES, parse(sql), Statement.of(sql)), + NO_PARAMS, + ImmutableList.of(), + ImmutableList.of()); assertFalse(intermediateStatement.isExecuted()); assertEquals("UPDATE", intermediateStatement.getCommand()); intermediateStatement.executeAsync(backendConnection); - verify(backendConnection).execute(parse(sql), Statement.of(sql)); + verify(backendConnection).execute(eq(parse(sql)), eq(Statement.of(sql)), any()); assertFalse(intermediateStatement.containsResultSet()); assertTrue(intermediateStatement.isExecuted()); assertEquals(StatementType.UPDATE, intermediateStatement.getStatementType()); @@ -172,7 +186,12 @@ public void testBasicZeroUpdateCountResultStatement() throws Exception { when(connection.execute(Statement.of(sql))).thenReturn(statementResult); IntermediatePortalStatement intermediateStatement = new IntermediatePortalStatement( - connectionHandler, options, "", parse(sql), Statement.of(sql)); + "", + new IntermediatePreparedStatement( + connectionHandler, options, "", NO_PARAMETER_TYPES, parse(sql), Statement.of(sql)), + NO_PARAMS, + ImmutableList.of(), + ImmutableList.of()); BackendConnection backendConnection = new BackendConnection( connectionHandler.getDatabaseId(), connection, options, EMPTY_LOCAL_STATEMENTS); @@ -205,14 +224,19 @@ public void testBasicNoResultStatement() throws Exception { when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); IntermediatePortalStatement intermediateStatement = new IntermediatePortalStatement( - connectionHandler, options, "", parse(sql), Statement.of(sql)); + "", + new IntermediatePreparedStatement( + connectionHandler, options, "", NO_PARAMETER_TYPES, parse(sql), Statement.of(sql)), + NO_PARAMS, + ImmutableList.of(), + ImmutableList.of()); assertFalse(intermediateStatement.isExecuted()); assertEquals("CREATE", intermediateStatement.getCommand()); intermediateStatement.executeAsync(backendConnection); - verify(backendConnection).execute(parse(sql), Statement.of(sql)); + verify(backendConnection).execute(eq(parse(sql)), eq(Statement.of(sql)), any()); assertFalse(intermediateStatement.containsResultSet()); assertEquals(0, intermediateStatement.getUpdateCount()); assertTrue(intermediateStatement.isExecuted()); @@ -248,7 +272,12 @@ public void testBasicStatementExceptionGetsSetOnExceptedExecution() { when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); IntermediatePortalStatement intermediateStatement = new IntermediatePortalStatement( - connectionHandler, options, "", parse(sql), Statement.of(sql)); + "", + new IntermediatePreparedStatement( + connectionHandler, options, "", NO_PARAMETER_TYPES, parse(sql), Statement.of(sql)), + NO_PARAMS, + ImmutableList.of(), + ImmutableList.of()); BackendConnection backendConnection = new BackendConnection( connectionHandler.getDatabaseId(), connection, options, EMPTY_LOCAL_STATEMENTS); @@ -257,7 +286,8 @@ public void testBasicStatementExceptionGetsSetOnExceptedExecution() { backendConnection.flush(); assertTrue(intermediateStatement.hasException()); - assertEquals(thrownException, intermediateStatement.getException()); + assertEquals( + PGExceptionFactory.newPGException("test error"), intermediateStatement.getException()); } @Test @@ -271,7 +301,6 @@ public void testPreparedStatement() { when(extendedQueryProtocolHandler.getBackendConnection()).thenReturn(backendConnection); SessionState sessionState = mock(SessionState.class); when(backendConnection.getSessionState()).thenReturn(sessionState); - when(sessionState.getGuessTypes()).thenReturn(ImmutableSet.of()); String sqlStatement = "SELECT * FROM users WHERE age > $2 AND age < $3 AND name = $1"; int[] parameterDataTypes = new int[] {Oid.VARCHAR, Oid.INT8, Oid.INT4}; @@ -290,23 +319,28 @@ public void testPreparedStatement() { IntermediatePreparedStatement intermediateStatement = new IntermediatePreparedStatement( - connectionHandler, options, "", parse(sqlStatement), Statement.of(sqlStatement)); - intermediateStatement.setParameterDataTypes(parameterDataTypes); + connectionHandler, + options, + "", + parameterDataTypes, + parse(sqlStatement), + Statement.of(sqlStatement)); assertEquals(sqlStatement, intermediateStatement.getSql()); byte[][] parameters = {"userName".getBytes(), "20".getBytes(), "30".getBytes()}; IntermediatePortalStatement intermediatePortalStatement = - intermediateStatement.bind( + intermediateStatement.createPortal( "", parameters, Arrays.asList((short) 0, (short) 0, (short) 0), new ArrayList<>()); - intermediateStatement.executeAsync(backendConnection); + intermediatePortalStatement.bind(Statement.of(sqlStatement)); + intermediatePortalStatement.executeAsync(backendConnection); backendConnection.flush(); verify(connection).execute(statement); assertEquals(sqlStatement, intermediatePortalStatement.getSql()); assertEquals("SELECT", intermediatePortalStatement.getCommand()); - assertFalse(intermediatePortalStatement.isExecuted()); + assertTrue(intermediatePortalStatement.isExecuted()); assertTrue(intermediateStatement.isBound()); } @@ -320,21 +354,25 @@ public void testPreparedStatementIllegalTypeThrowsException() { when(extendedQueryProtocolHandler.getBackendConnection()).thenReturn(backendConnection); SessionState sessionState = mock(SessionState.class); when(backendConnection.getSessionState()).thenReturn(sessionState); - when(sessionState.getGuessTypes()).thenReturn(ImmutableSet.of()); String sqlStatement = "SELECT * FROM users WHERE metadata = $1"; int[] parameterDataTypes = new int[] {Oid.JSON}; IntermediatePreparedStatement intermediateStatement = new IntermediatePreparedStatement( - connectionHandler, options, "", parse(sqlStatement), Statement.of(sqlStatement)); - intermediateStatement.setParameterDataTypes(parameterDataTypes); + connectionHandler, + options, + "", + parameterDataTypes, + parse(sqlStatement), + Statement.of(sqlStatement)); byte[][] parameters = {"{}".getBytes()}; + IntermediatePortalStatement portalStatement = + intermediateStatement.createPortal("", parameters, new ArrayList<>(), new ArrayList<>()); assertThrows( - IllegalArgumentException.class, - () -> intermediateStatement.bind("", parameters, new ArrayList<>(), new ArrayList<>())); + IllegalArgumentException.class, () -> portalStatement.bind(Statement.of(sqlStatement))); } @Test @@ -342,65 +380,21 @@ public void testPreparedStatementDescribeDoesNotThrowException() { when(connectionHandler.getSpannerConnection()).thenReturn(connection); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); String sqlStatement = "SELECT * FROM users WHERE name = $1 AND age > $2 AND age < $3"; - when(connection.analyzeQuery(Statement.of(sqlStatement), QueryAnalyzeMode.PLAN)) - .thenReturn(resultSet); - IntermediatePreparedStatement intermediateStatement = - new IntermediatePreparedStatement( - connectionHandler, options, "", parse(sqlStatement), Statement.of(sqlStatement)); int[] parameters = new int[3]; Arrays.fill(parameters, Oid.INT8); - intermediateStatement.setParameterDataTypes(parameters); + IntermediatePreparedStatement intermediateStatement = + new IntermediatePreparedStatement( + connectionHandler, + options, + "", + parameters, + parse(sqlStatement), + Statement.of(sqlStatement)); intermediateStatement.describe(); } - @Test - public void testPortalStatement() { - when(connectionHandler.getSpannerConnection()).thenReturn(connection); - when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); - String sqlStatement = "SELECT * FROM users WHERE age > $1 AND age < $2 AND name = $3"; - - IntermediatePortalStatement intermediateStatement = - new IntermediatePortalStatement( - connectionHandler, options, "", parse(sqlStatement), Statement.of(sqlStatement)); - BackendConnection backendConnection = - new BackendConnection( - connectionHandler.getDatabaseId(), connection, options, EMPTY_LOCAL_STATEMENTS); - - intermediateStatement.describeAsync(backendConnection); - backendConnection.flush(); - - verify(connection).execute(Statement.of(sqlStatement)); - - assertEquals(0, intermediateStatement.getParameterFormatCode(0)); - assertEquals(0, intermediateStatement.getParameterFormatCode(1)); - assertEquals(0, intermediateStatement.getParameterFormatCode(2)); - assertEquals(0, intermediateStatement.getResultFormatCode(0)); - assertEquals(0, intermediateStatement.getResultFormatCode(1)); - assertEquals(0, intermediateStatement.getResultFormatCode(2)); - - intermediateStatement.setParameterFormatCodes(Collections.singletonList((short) 1)); - intermediateStatement.setResultFormatCodes(Collections.singletonList((short) 1)); - - assertEquals(1, intermediateStatement.getParameterFormatCode(0)); - assertEquals(1, intermediateStatement.getParameterFormatCode(1)); - assertEquals(1, intermediateStatement.getParameterFormatCode(2)); - assertEquals(1, intermediateStatement.getResultFormatCode(0)); - assertEquals(1, intermediateStatement.getResultFormatCode(1)); - assertEquals(1, intermediateStatement.getResultFormatCode(2)); - - intermediateStatement.setParameterFormatCodes(Arrays.asList((short) 0, (short) 1, (short) 0)); - intermediateStatement.setResultFormatCodes(Arrays.asList((short) 0, (short) 1, (short) 0)); - - assertEquals(0, intermediateStatement.getParameterFormatCode(0)); - assertEquals(1, intermediateStatement.getParameterFormatCode(1)); - assertEquals(0, intermediateStatement.getParameterFormatCode(2)); - assertEquals(0, intermediateStatement.getResultFormatCode(0)); - assertEquals(1, intermediateStatement.getResultFormatCode(1)); - assertEquals(0, intermediateStatement.getResultFormatCode(2)); - } - @Test public void testPortalStatementDescribePropagatesFailure() { when(connectionHandler.getSpannerConnection()).thenReturn(connection); @@ -409,7 +403,17 @@ public void testPortalStatementDescribePropagatesFailure() { IntermediatePortalStatement intermediateStatement = new IntermediatePortalStatement( - connectionHandler, options, "", parse(sqlStatement), Statement.of(sqlStatement)); + "", + new IntermediatePreparedStatement( + connectionHandler, + options, + "", + NO_PARAMETER_TYPES, + parse(sqlStatement), + Statement.of(sqlStatement)), + NO_PARAMS, + ImmutableList.of(), + ImmutableList.of()); BackendConnection backendConnection = new BackendConnection( connectionHandler.getDatabaseId(), connection, options, EMPTY_LOCAL_STATEMENTS); @@ -422,8 +426,9 @@ public void testPortalStatementDescribePropagatesFailure() { backendConnection.flush(); assertTrue(intermediateStatement.hasException()); - SpannerException exception = intermediateStatement.getException(); - assertEquals(ErrorCode.INVALID_ARGUMENT, exception.getErrorCode()); + PGException exception = intermediateStatement.getException(); + assertEquals(SQLState.RaiseException, exception.getSQLState()); + assertEquals("test error", exception.getMessage()); } @Test @@ -487,10 +492,10 @@ public void testCopyInvalidBuildMutation() throws Exception { backendConnection.flush(); - SpannerException thrown = assertThrows(SpannerException.class, statement::getUpdateCount); - assertEquals(ErrorCode.INVALID_ARGUMENT, thrown.getErrorCode()); + PGException thrown = assertThrows(PGException.class, statement::getUpdateCount); + assertEquals(SQLState.DataException, thrown.getSQLState()); assertEquals( - "INVALID_ARGUMENT: Invalid COPY data: Row length mismatched. Expected 2 columns, but only found 1", + "Invalid COPY data: Row length mismatched. Expected 2 columns, but only found 1", thrown.getMessage()); statement.close(); @@ -533,7 +538,12 @@ public void testGetStatementResultBeforeFlushFails() { when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); IntermediatePortalStatement intermediateStatement = new IntermediatePortalStatement( - connectionHandler, options, "", parse(sql), Statement.of(sql)); + "", + new IntermediatePreparedStatement( + connectionHandler, options, "", NO_PARAMETER_TYPES, parse(sql), Statement.of(sql)), + NO_PARAMS, + ImmutableList.of(), + ImmutableList.of()); BackendConnection backendConnection = new BackendConnection( connectionHandler.getDatabaseId(), connection, options, EMPTY_LOCAL_STATEMENTS); @@ -634,10 +644,10 @@ public void testCopyDataRowLengthMismatchLimit() throws Exception { backendConnection.flush(); - SpannerException thrown = assertThrows(SpannerException.class, copyStatement::getUpdateCount); - assertEquals(ErrorCode.INVALID_ARGUMENT, thrown.getErrorCode()); + PGException thrown = assertThrows(PGException.class, copyStatement::getUpdateCount); + assertEquals(SQLState.DataException, thrown.getSQLState()); assertEquals( - "INVALID_ARGUMENT: Invalid COPY data: Row length mismatched. Expected 2 columns, but only found 1", + "Invalid COPY data: Row length mismatched. Expected 2 columns, but only found 1", thrown.getMessage()); copyStatement.close(); diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/statements/local/SelectVersionStatementTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/statements/local/SelectVersionStatementTest.java new file mode 100644 index 000000000..6f414bc07 --- /dev/null +++ b/src/test/java/com/google/cloud/spanner/pgadapter/statements/local/SelectVersionStatementTest.java @@ -0,0 +1,52 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.spanner.pgadapter.statements.local; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.cloud.spanner.ResultSet; +import com.google.cloud.spanner.pgadapter.session.PGSetting; +import com.google.cloud.spanner.pgadapter.session.SessionState; +import com.google.cloud.spanner.pgadapter.statements.BackendConnection; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class SelectVersionStatementTest { + @Test + public void testExecute() { + for (String version : new String[] {"14.1", "10.0"}) { + BackendConnection backendConnection = mock(BackendConnection.class); + SessionState sessionState = mock(SessionState.class); + PGSetting pgSetting = mock(PGSetting.class); + when(backendConnection.getSessionState()).thenReturn(sessionState); + when(sessionState.get(null, "server_version")).thenReturn(pgSetting); + when(pgSetting.getSetting()).thenReturn(version); + + try (ResultSet resultSet = + SelectVersionStatement.INSTANCE.execute(backendConnection).getResultSet()) { + assertTrue(resultSet.next()); + assertEquals(1, resultSet.getColumnCount()); + assertEquals("PostgreSQL " + version, resultSet.getString("version")); + assertFalse(resultSet.next()); + } + } + } +} diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/utils/CopyDataReceiverTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/utils/CopyDataReceiverTest.java index a0658916f..d938e5661 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/utils/CopyDataReceiverTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/utils/CopyDataReceiverTest.java @@ -22,9 +22,11 @@ import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.SpannerException; -import com.google.cloud.spanner.SpannerExceptionFactory; import com.google.cloud.spanner.pgadapter.ConnectionHandler; import com.google.cloud.spanner.pgadapter.ConnectionHandler.ConnectionStatus; +import com.google.cloud.spanner.pgadapter.error.PGException; +import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory; +import com.google.cloud.spanner.pgadapter.error.SQLState; import com.google.cloud.spanner.pgadapter.metadata.ConnectionMetadata; import com.google.cloud.spanner.pgadapter.statements.CopyStatement; import java.io.ByteArrayOutputStream; @@ -38,18 +40,17 @@ public class CopyDataReceiverTest { @Test public void testCopyStatementWithException() { - SpannerException exception = - SpannerExceptionFactory.newSpannerException( - ErrorCode.INVALID_ARGUMENT, "Invalid copy statement"); + PGException exception = + PGExceptionFactory.newPGException("Invalid copy statement", SQLState.SyntaxError); CopyStatement statement = mock(CopyStatement.class); when(statement.hasException()).thenReturn(true); when(statement.getException()).thenReturn(exception); ConnectionHandler connectionHandler = mock(ConnectionHandler.class); CopyDataReceiver receiver = new CopyDataReceiver(statement, connectionHandler); - SpannerException spannerException = assertThrows(SpannerException.class, receiver::handleCopy); + PGException pgException = assertThrows(PGException.class, receiver::handleCopy); - assertSame(exception, spannerException); + assertSame(exception, pgException); } @Test diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/utils/MutationWriterTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/utils/MutationWriterTest.java index ac776c0b6..b6881a67b 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/utils/MutationWriterTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/utils/MutationWriterTest.java @@ -17,6 +17,7 @@ import static com.google.cloud.spanner.pgadapter.utils.MutationWriter.calculateSize; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; @@ -33,11 +34,11 @@ import com.google.cloud.Timestamp; import com.google.cloud.spanner.DatabaseClient; import com.google.cloud.spanner.Mutation; -import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.Type; import com.google.cloud.spanner.Value; import com.google.cloud.spanner.connection.Connection; import com.google.cloud.spanner.connection.StatementResult; +import com.google.cloud.spanner.pgadapter.error.PGException; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; import com.google.cloud.spanner.pgadapter.session.SessionState; import com.google.cloud.spanner.pgadapter.statements.CopyStatement.Format; @@ -56,6 +57,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; @@ -73,7 +75,7 @@ public class MutationWriterTest { @AfterClass public static void shutdownExecutor() { - executor.shutdown(); + // executor.shutdown(); } @Test @@ -207,9 +209,9 @@ public void testWriteMutations_FailsForLargeBatch() throws Exception { } }); - SpannerException exception = assertThrows(SpannerException.class, mutationWriter::call); + PGException exception = assertThrows(PGException.class, mutationWriter::call); assertEquals( - "FAILED_PRECONDITION: Record count: 2 has exceeded the limit: 1.\n" + "Record count: 2 has exceeded the limit: 1.\n" + "\n" + "The number of mutations per record is equal to the number of columns in the record plus the number of indexed columns in the record. The maximum number of mutations in one transaction is 1.\n" + "\n" @@ -314,7 +316,7 @@ public void testWriteMutations_FailsForLargeCommit() throws Exception { } }); - SpannerException exception = assertThrows(SpannerException.class, mutationWriter::call); + PGException exception = assertThrows(PGException.class, mutationWriter::call); assertTrue(exception.getMessage().contains("Commit size: 40 has exceeded the limit: 30")); verify(connection, never()).write(anyIterable()); @@ -605,4 +607,46 @@ public void testBuildMutationNulls() throws IOException { assertEquals(String.format("Type: %s", type), 1, mutation.asMap().size()); } } + + @Test + public void testWriteAfterClose() throws Exception { + Map tableColumns = ImmutableMap.of("number", Type.int64(), "name", Type.string()); + CSVFormat format = + CSVFormat.POSTGRESQL_TEXT + .builder() + .setHeader(tableColumns.keySet().toArray(new String[0])) + .build(); + SessionState sessionState = new SessionState(mock(OptionsMetadata.class)); + Connection connection = mock(Connection.class); + DatabaseClient databaseClient = mock(DatabaseClient.class); + when(connection.getDatabaseClient()).thenReturn(databaseClient); + MutationWriter mutationWriter = + new MutationWriter( + sessionState, + CopyTransactionMode.ImplicitAtomic, + connection, + "numbers", + tableColumns, + /* indexedColumnsCount = */ 1, + Format.TEXT, + format, + false); + + CountDownLatch latch = new CountDownLatch(1); + Future fut = + executor.submit( + () -> { + mutationWriter.addCopyData( + "1\t\"One\"\n2\t\"Two\"\n".getBytes(StandardCharsets.UTF_8)); + mutationWriter.close(); + latch.await(); + mutationWriter.addCopyData( + "1\t\"One\"\n2\t\"Two\"\n".getBytes(StandardCharsets.UTF_8)); + return null; + }); + + mutationWriter.call(); + latch.countDown(); + assertNull(fut.get()); + } } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/ControlMessageTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/wireprotocol/ControlMessageTest.java similarity index 85% rename from src/test/java/com/google/cloud/spanner/pgadapter/ControlMessageTest.java rename to src/test/java/com/google/cloud/spanner/pgadapter/wireprotocol/ControlMessageTest.java index 976b3d8b3..a5221b2c3 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/ControlMessageTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/wireprotocol/ControlMessageTest.java @@ -12,24 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. -package com.google.cloud.spanner.pgadapter; +package com.google.cloud.spanner.pgadapter.wireprotocol; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import com.google.cloud.spanner.connection.AbstractStatementParser.StatementType; import com.google.cloud.spanner.connection.Connection; +import com.google.cloud.spanner.connection.StatementResult; +import com.google.cloud.spanner.connection.StatementResult.ResultType; +import com.google.cloud.spanner.pgadapter.ConnectionHandler; import com.google.cloud.spanner.pgadapter.ConnectionHandler.QueryMode; +import com.google.cloud.spanner.pgadapter.ProxyServer; import com.google.cloud.spanner.pgadapter.metadata.ConnectionMetadata; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata.TextFormat; import com.google.cloud.spanner.pgadapter.statements.BackendConnection.NoResult; import com.google.cloud.spanner.pgadapter.statements.BackendConnection.UpdateCount; import com.google.cloud.spanner.pgadapter.statements.IntermediateStatement; -import com.google.cloud.spanner.pgadapter.wireprotocol.ControlMessage; import com.google.cloud.spanner.pgadapter.wireprotocol.ControlMessage.ManuallyCreatedToken; -import com.google.cloud.spanner.pgadapter.wireprotocol.ExecuteMessage; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.DataInputStream; @@ -138,4 +141,19 @@ public void testUnknownStatementTypeDoesNotThrowError() throws Exception { assertEquals(numOfBytes, outputReader.read(bytes, 0, numOfBytes)); assertEquals(resultMessage, new String(bytes, UTF8)); } + + @Test + public void testSendNoRowsAsResultSetFails() { + when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); + IntermediateStatement describedResult = mock(IntermediateStatement.class); + StatementResult statementResult = mock(StatementResult.class); + when(statementResult.getResultType()).thenReturn(ResultType.NO_RESULT); + when(describedResult.getStatementResult()).thenReturn(statementResult); + + ControlMessage message = + new DescribeMessage(connectionHandler, ManuallyCreatedToken.MANUALLY_CREATED_TOKEN); + assertThrows( + IllegalArgumentException.class, + () -> message.sendResultSet(describedResult, QueryMode.SIMPLE, 0L)); + } } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/wireprotocol/ProtocolTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/wireprotocol/ProtocolTest.java index e5c302d17..43cc6e714 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/wireprotocol/ProtocolTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/wireprotocol/ProtocolTest.java @@ -19,6 +19,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; @@ -32,11 +33,8 @@ import com.google.cloud.spanner.DatabaseClient; import com.google.cloud.spanner.Dialect; -import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.ReadContext; import com.google.cloud.spanner.ResultSet; -import com.google.cloud.spanner.SpannerException; -import com.google.cloud.spanner.SpannerExceptionFactory; import com.google.cloud.spanner.Statement; import com.google.cloud.spanner.connection.AbstractStatementParser; import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement; @@ -46,6 +44,9 @@ import com.google.cloud.spanner.pgadapter.ConnectionHandler.ConnectionStatus; import com.google.cloud.spanner.pgadapter.ConnectionHandler.QueryMode; import com.google.cloud.spanner.pgadapter.ProxyServer; +import com.google.cloud.spanner.pgadapter.error.PGException; +import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory; +import com.google.cloud.spanner.pgadapter.error.SQLState; import com.google.cloud.spanner.pgadapter.metadata.ConnectionMetadata; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata.SslMode; @@ -70,6 +71,7 @@ import java.io.DataOutputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.time.ZoneId; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -256,7 +258,7 @@ public void testParseMessageException() throws Exception { assertEquals(expectedSQL, ((ParseMessage) message).getStatement().getSql()); assertArrayEquals( expectedParameterDataTypes, - ((ParseMessage) message).getStatement().getParameterDataTypes()); + ((ParseMessage) message).getStatement().getGivenParameterDataTypes()); when(connectionHandler.hasStatement(anyString())).thenReturn(false); message.send(); @@ -320,7 +322,7 @@ public void testParseMessage() throws Exception { assertEquals(expectedSQL, ((ParseMessage) message).getStatement().getSql()); assertArrayEquals( expectedParameterDataTypes, - ((ParseMessage) message).getStatement().getParameterDataTypes()); + ((ParseMessage) message).getStatement().getGivenParameterDataTypes()); when(connectionHandler.hasStatement(anyString())).thenReturn(false); message.send(); @@ -385,7 +387,7 @@ public void testParseMessageAcceptsUntypedParameter() throws Exception { assertEquals(expectedSQL, ((ParseMessage) message).getStatement().getSql()); assertArrayEquals( expectedParameterDataTypes, - ((ParseMessage) message).getStatement().getParameterDataTypes()); + ((ParseMessage) message).getStatement().getGivenParameterDataTypes()); when(connectionHandler.hasStatement(anyString())).thenReturn(false); message.send(); @@ -435,7 +437,7 @@ public void testParseMessageWithNonMatchingParameterTypeCount() throws Exception assertEquals(expectedSQL, ((ParseMessage) message).getStatement().getSql()); assertArrayEquals( expectedParameterDataTypes, - ((ParseMessage) message).getStatement().getParameterDataTypes()); + ((ParseMessage) message).getStatement().getGivenParameterDataTypes()); when(connectionHandler.hasStatement(anyString())).thenReturn(false); message.send(); @@ -623,7 +625,11 @@ public void testBindMessage() throws Exception { resultCodesCount); when(connectionHandler.getStatement(anyString())).thenReturn(intermediatePreparedStatement); - when(intermediatePreparedStatement.getSql()).thenReturn("select * from foo"); + when(intermediatePreparedStatement.createPortal(anyString(), any(), any(), any())) + .thenReturn(intermediatePortalStatement); + when(intermediatePortalStatement.getSql()).thenReturn("select * from foo"); + when(intermediatePortalStatement.getPreparedStatement()) + .thenReturn(intermediatePreparedStatement); byte[][] expectedParameters = {parameter}; List expectedFormatCodes = new ArrayList<>(); @@ -648,13 +654,7 @@ public void testBindMessage() throws Exception { assertEquals(expectedFormatCodes, ((BindMessage) message).getFormatCodes()); assertEquals(expectedFormatCodes, ((BindMessage) message).getResultFormatCodes()); assertEquals("select * from foo", ((BindMessage) message).getSql()); - - when(intermediatePreparedStatement.bind( - ArgumentMatchers.anyString(), - ArgumentMatchers.any(), - ArgumentMatchers.any(), - ArgumentMatchers.any())) - .thenReturn(intermediatePortalStatement); + assertTrue(((BindMessage) message).hasParameterValues()); message.send(); ((BindMessage) message).flush(); @@ -726,6 +726,9 @@ public void testBindMessageOneNonTextParam() throws Exception { when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionMetadata.getInputStream()).thenReturn(inputStream); when(connectionMetadata.getOutputStream()).thenReturn(outputStream); + when(connectionHandler.getStatement(anyString())).thenReturn(intermediatePreparedStatement); + when(intermediatePreparedStatement.createPortal(anyString(), any(), any(), any())) + .thenReturn(intermediatePortalStatement); WireMessage message = ControlMessage.create(connectionHandler); assertEquals(BindMessage.class, message.getClass()); @@ -796,6 +799,9 @@ public void testBindMessageAllNonTextParam() throws Exception { when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionMetadata.getInputStream()).thenReturn(inputStream); when(connectionMetadata.getOutputStream()).thenReturn(outputStream); + when(connectionHandler.getStatement(anyString())).thenReturn(intermediatePreparedStatement); + when(intermediatePreparedStatement.createPortal(anyString(), any(), any(), any())) + .thenReturn(intermediatePortalStatement); WireMessage message = ControlMessage.create(connectionHandler); assertEquals(BindMessage.class, message.getClass()); @@ -882,6 +888,39 @@ public void testDescribeStatementMessage() throws Exception { verify(messageSpy).handleDescribeStatement(); } + @Test + public void testDescribeMessageWithException() throws Exception { + byte[] messageMetadata = {'D'}; + byte[] statementType = {'S'}; + String statementName = "some statement\0"; + + byte[] length = intToBytes(4 + 1 + statementName.length()); + + byte[] value = Bytes.concat(messageMetadata, length, statementType, statementName.getBytes()); + + DataInputStream inputStream = new DataInputStream(new ByteArrayInputStream(value)); + ByteArrayOutputStream result = new ByteArrayOutputStream(); + DataOutputStream outputStream = new DataOutputStream(result); + + when(connectionHandler.getStatement(anyString())).thenReturn(intermediatePreparedStatement); + when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); + when(connectionMetadata.getInputStream()).thenReturn(inputStream); + when(connectionMetadata.getOutputStream()).thenReturn(outputStream); + when(connectionHandler.getExtendedQueryProtocolHandler()) + .thenReturn(extendedQueryProtocolHandler); + when(intermediatePreparedStatement.hasException()).thenReturn(true); + when(intermediatePreparedStatement.getException()) + .thenReturn(PGExceptionFactory.newPGException("test error", SQLState.InternalError)); + + WireMessage message = ControlMessage.create(connectionHandler); + assertEquals(DescribeMessage.class, message.getClass()); + DescribeMessage describeMessage = (DescribeMessage) message; + + PGException exception = + assertThrows(PGException.class, describeMessage::handleDescribeStatement); + assertEquals("test error", exception.getMessage()); + } + @Test public void testExecuteMessage() throws Exception { byte[] messageMetadata = {'E'}; @@ -943,8 +982,8 @@ public void testExecuteMessageWithException() throws Exception { ByteArrayOutputStream result = new ByteArrayOutputStream(); DataOutputStream outputStream = new DataOutputStream(result); - SpannerException testException = - SpannerExceptionFactory.newSpannerException(ErrorCode.INVALID_ARGUMENT, "test error"); + PGException testException = + PGExceptionFactory.newPGException("test error", SQLState.SyntaxError); when(intermediatePortalStatement.hasException()).thenReturn(true); when(intermediatePortalStatement.getException()).thenReturn(testException); when(connectionHandler.getPortal(anyString())).thenReturn(intermediatePortalStatement); @@ -1584,9 +1623,12 @@ public void testStartUpMessage() throws Exception { readUntil(outputResult, "standard_conforming_strings\0".length())); assertEquals("on\0", readUntil(outputResult, "on\0".length())); assertEquals('S', outputResult.readByte()); - assertEquals(17, outputResult.readInt()); - assertEquals("TimeZone\0", readUntil(outputResult, "TimeZone\0".length())); + // Timezone will vary depending on the default location of the JVM that is running. + String timezoneIdentifier = ZoneId.systemDefault().getId(); + int expectedLength = timezoneIdentifier.getBytes(StandardCharsets.UTF_8).length + 10 + 4; + assertEquals(expectedLength, outputResult.readInt()); + assertEquals("TimeZone\0", readUntil(outputResult, "TimeZone\0".length())); readUntilNullTerminator(outputResult); // ReadyResponse diff --git a/src/test/nodejs/node-postgres/src/index.ts b/src/test/nodejs/node-postgres/src/index.ts index cc7563102..cc833175e 100644 --- a/src/test/nodejs/node-postgres/src/index.ts +++ b/src/test/nodejs/node-postgres/src/index.ts @@ -231,6 +231,7 @@ async function testReadOnlyTransactionWithError(client) { async function testCopyTo(client) { try { + await client.query("set time zone 'UTC'"); const copyTo = require('pg-copy-streams').to; const stream = client.query(copyTo('COPY AllTypes TO STDOUT')); stream.pipe(process.stdout); diff --git a/src/test/nodejs/typeorm/data-test/src/index.ts b/src/test/nodejs/typeorm/data-test/src/index.ts index 5909ee474..b5c0b21cb 100644 --- a/src/test/nodejs/typeorm/data-test/src/index.ts +++ b/src/test/nodejs/typeorm/data-test/src/index.ts @@ -103,7 +103,7 @@ async function testCreateAllTypes(dataSource: DataSource) { const allTypes = { col_bigint: 2, col_bool: true, - col_bytea: Buffer.from(Buffer.from('some random string').toString('base64')), + col_bytea: Buffer.from('some random string'), col_float8: 0.123456789, col_int: 123456789, col_numeric: 234.54235, @@ -118,6 +118,24 @@ async function testCreateAllTypes(dataSource: DataSource) { console.log('Created one record') } +async function testUpdateAllTypes(dataSource: DataSource) { + const repository = dataSource.getRepository(AllTypes); + const row = await repository.findOneBy({col_bigint: 1}); + + row.col_bool = false; + row.col_bytea = Buffer.from(Buffer.from('updated string').toString('base64')); + row.col_float8 = 1.23456789; + row.col_int = 987654321; + row.col_numeric = 6.626; + row.col_timestamptz = new Date(Date.UTC(2022, 10, 16, 10, 3, 42, 999)); + row.col_date = '2022-11-16'; + row.col_varchar = 'some updated string'; + row.col_jsonb = {key: 'updated-value'}; + + await repository.update(row.col_bigint, row); + console.log('Updated one record') +} + require('yargs') .demand(4) .command( @@ -162,6 +180,13 @@ require('yargs') opts => runTest(opts.host, opts.port, opts.database, testCreateAllTypes) ) .example('node $0 createAllTypes 5432') + .command( + 'updateAllTypes ', + 'Updates one row with all types', + {}, + opts => runTest(opts.host, opts.port, opts.database, testUpdateAllTypes) + ) + .example('node $0 updateAllTypes 5432') .wrap(120) .recommendCommands() .strict() diff --git a/src/test/python/django/basic_test.py b/src/test/python/django/basic_test.py index 48e1b710d..9800bed83 100644 --- a/src/test/python/django/basic_test.py +++ b/src/test/python/django/basic_test.py @@ -14,7 +14,6 @@ ''' import sys -from django.core.cache import cache def create_django_setup(host, port): from django.conf import settings @@ -101,7 +100,3 @@ def execute(option): execute(sys.argv[3:]) except Exception as e: print(e) - - - - diff --git a/src/test/python/django/transaction_test.py b/src/test/python/django/transaction_test.py new file mode 100644 index 000000000..d00e76a0c --- /dev/null +++ b/src/test/python/django/transaction_test.py @@ -0,0 +1,114 @@ +''' Copyright 2022 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +''' +from django.db import transaction +import sys + +def create_django_setup(host, port): + from django.conf import settings + from django.apps import apps + conf = { + 'INSTALLED_APPS': [ + 'django.contrib.auth', + 'django.contrib.contenttypes', + 'data' + ], + 'DATABASES': { + 'default': { + 'ENGINE': 'django.db.backends.postgresql_psycopg2', + 'NAME': 'postgres', + 'USER': 'postgres', + 'PASSWORD': 'postgres', + } + }, + } + conf['DATABASES']['default']['PORT'] = port + conf['DATABASES']['default']['HOST'] = host + settings.configure(**conf) + apps.populate(settings.INSTALLED_APPS) + + +def test_commit_transaction(): + transaction.set_autocommit(False) + singer = Singer(singerid=1, firstname='hello', lastname='world') + singer2 = Singer(singerid=2, firstname='hello', lastname='python') + + singer.save() + singer2.save() + + transaction.commit() + print('Transaction Committed') + +def test_rollback_transaction(): + transaction.set_autocommit(False) + + singer = Singer(singerid=1, firstname='hello', lastname='world') + singer2 = Singer(singerid=2, firstname='hello', lastname='python') + + singer.save() + singer2.save() + + transaction.rollback() + print('Transaction Rollbacked') + +def test_atomic(): + with transaction.atomic(): + singer = Singer(singerid=1, firstname='hello', lastname='world') + singer2 = Singer(singerid=2, firstname='hello', lastname='python') + + singer.save() + singer2.save() + print('Atomic Successful') + +def test_nested_atomic(): + with transaction.atomic(savepoint=False): + singer2 = Singer(singerid=2, firstname='hello', lastname='python') + with transaction.atomic(savepoint=False): + singer = Singer(singerid=1, firstname='hello', lastname='world') + singer.save() + singer2.save() + print('Atomic Successful') + + +if __name__ == '__main__': + if len(sys.argv) < 4: + print('Invalid command line arguments') + sys.exit() + host = sys.argv[1] + port = sys.argv[2] + + try: + create_django_setup(host, port) + from data.models import Singer + except Exception as e: + print(e) + sys.exit() + + try: + option = sys.argv[3] + + if option == 'commit': + test_commit_transaction() + elif option == 'rollback': + test_rollback_transaction() + elif option == 'atomic': + test_atomic() + elif option == 'nested_atomic': + test_nested_atomic() + else: + print('Invalid Option') + except Exception as e: + print(e) + + diff --git a/src/test/python/pg8000/basics.py b/src/test/python/pg8000/basics.py new file mode 100644 index 000000000..2c6990997 --- /dev/null +++ b/src/test/python/pg8000/basics.py @@ -0,0 +1,37 @@ +import pg8000.dbapi + +conn = pg8000.dbapi.connect(host="localhost", + port=5433, + database="knut-test-db", + user="test", + password="test") +cursor = conn.cursor() +cursor.execute("DELETE FROM test") +print("delete rowcount = %s" % cursor.rowcount) + +cursor = conn.cursor() +cursor.execute("INSERT INTO test (id, value) VALUES (%s, %s), (%s, %s) RETURNING id, value", + (1, "Ender's Game", 2, "Speaker for the Dead")) +results = cursor.fetchall() +for row in results: + id, value = row + print("id = %s, value = %s" % (id, value)) + +conn.commit() + +cursor.execute("SELECT TIMESTAMPTZ '2021-10-10'") +row = cursor.fetchone() +print("Timestamp = %s" % row) + +conn.close() + +pg8000.dbapi.paramstyle = "numeric" +conn = pg8000.dbapi.connect(host="localhost", + port=5433, + database="knut-test-db", + user="test", + password="test") +cursor = conn.cursor() +cursor.execute("SELECT * FROM test WHERE id=:1", (1,)) +row = cursor.fetchone() +print("Row: %s" % row) diff --git a/src/test/python/pg8000/connect.py b/src/test/python/pg8000/connect.py new file mode 100644 index 000000000..9f86e23a4 --- /dev/null +++ b/src/test/python/pg8000/connect.py @@ -0,0 +1,21 @@ +import pg8000.dbapi +import sys + + +def create_conn(): + host = sys.argv[1] + port = sys.argv[2] + if host.startswith("/"): + conn = pg8000.dbapi.connect(unix_sock=host + "/.s.PGSQL." + port, + database="d", + user="test", + password="test") + else: + conn = pg8000.dbapi.connect(host=host, + port=port, + database="d", + user="test", + password="test") + conn.execute_simple("set time zone 'utc'") + return conn + diff --git a/src/test/python/pg8000/requirements.txt b/src/test/python/pg8000/requirements.txt new file mode 100644 index 000000000..c9557c8cd --- /dev/null +++ b/src/test/python/pg8000/requirements.txt @@ -0,0 +1 @@ +pg8000~=1.29.3 diff --git a/src/test/python/pg8000/select1.py b/src/test/python/pg8000/select1.py new file mode 100644 index 000000000..40352fedf --- /dev/null +++ b/src/test/python/pg8000/select1.py @@ -0,0 +1,10 @@ +import connect + + +with connect.create_conn() as conn: + conn.autocommit = True + + cursor = conn.cursor() + cursor.execute("SELECT 1") + row = cursor.fetchone() + print("SELECT 1: %s" % row) diff --git a/src/test/python/pg8000/select_all_types.py b/src/test/python/pg8000/select_all_types.py new file mode 100644 index 000000000..bda89b04e --- /dev/null +++ b/src/test/python/pg8000/select_all_types.py @@ -0,0 +1,11 @@ +import connect + + +with connect.create_conn() as conn: + conn.autocommit = True + + cursor = conn.cursor() + cursor.execute("SELECT * FROM all_types") + results = cursor.fetchall() + for row in results: + print("row: %s" % row) diff --git a/src/test/python/pg8000/select_parameterized.py b/src/test/python/pg8000/select_parameterized.py new file mode 100644 index 000000000..00a20c973 --- /dev/null +++ b/src/test/python/pg8000/select_parameterized.py @@ -0,0 +1,19 @@ +import connect +import pg8000.dbapi + +pg8000.dbapi.paramstyle = "named" + +with connect.create_conn() as conn: + conn.autocommit = True + + cursor = conn.cursor() + cursor.execute("SELECT * FROM all_types WHERE col_bigint=:id", {"id": 1}) + results = cursor.fetchall() + for row in results: + print("first execution: %s" % row) + + cursor = conn.cursor() + cursor.execute("SELECT * FROM all_types WHERE col_bigint=:id", {"id": 1}) + results = cursor.fetchall() + for row in results: + print("second execution: %s" % row) diff --git a/src/test/python/psycopg2/Batching.py b/src/test/python/psycopg2/Batching.py index 634bbe274..b5b0226ee 100644 --- a/src/test/python/psycopg2/Batching.py +++ b/src/test/python/psycopg2/Batching.py @@ -23,7 +23,7 @@ def create_connection(version, host, port): connection = pg.connect(database="my-database", host=host, port=port, - options="-c server_version=" + version) + options="-c timezone=UTC -c server_version=" + version) connection.autocommit = True return connection except Exception as e: diff --git a/src/test/python/psycopg2/StatementsWithCopy.py b/src/test/python/psycopg2/StatementsWithCopy.py index 76a2e7659..d8693cb24 100644 --- a/src/test/python/psycopg2/StatementsWithCopy.py +++ b/src/test/python/psycopg2/StatementsWithCopy.py @@ -22,7 +22,8 @@ def create_connection(port): connection = pg.connect(user="postgres", database="postgres", host="localhost", - port=port) + port=port, + options="-c timezone=UTC") connection.autocommit = True return connection except Exception as e: diff --git a/src/test/python/psycopg2/StatementsWithNamedParameters.py b/src/test/python/psycopg2/StatementsWithNamedParameters.py index d361c4973..b1c5a69b2 100644 --- a/src/test/python/psycopg2/StatementsWithNamedParameters.py +++ b/src/test/python/psycopg2/StatementsWithNamedParameters.py @@ -24,7 +24,7 @@ def create_connection(version, host, port): connection = pg.connect(database="my-database", host=host, port=port, - options="-c server_version=" + version) + options="-c timezone=UTC -c server_version=" + version) connection.autocommit = True return connection except Exception as e: diff --git a/src/test/python/psycopg2/StatementsWithParameters.py b/src/test/python/psycopg2/StatementsWithParameters.py index e1f6dc7ba..bd77029fd 100644 --- a/src/test/python/psycopg2/StatementsWithParameters.py +++ b/src/test/python/psycopg2/StatementsWithParameters.py @@ -21,7 +21,7 @@ def create_connection(version, host, port): connection = pg.connect(database = "my-database", host = host, port = port, - options="-c server_version=" + version) + options="-c timezone=UTC -c server_version=" + version) connection.autocommit = True return connection except Exception as e: diff --git a/src/test/python/psycopg2/StatementsWithTransactions.py b/src/test/python/psycopg2/StatementsWithTransactions.py index 85499a058..1a68a1c41 100644 --- a/src/test/python/psycopg2/StatementsWithTransactions.py +++ b/src/test/python/psycopg2/StatementsWithTransactions.py @@ -21,7 +21,7 @@ def create_connection(version, host, port): connection = pg.connect(database="my-database", host=host, port=port, - options="-c server_version=" + version) + options="-c timezone=UTC -c server_version=" + version) return connection except Exception as e: print(e) diff --git a/src/test/python/psycopg2/StatementsWithoutParameters.py b/src/test/python/psycopg2/StatementsWithoutParameters.py index b63ff163b..ce19341bc 100644 --- a/src/test/python/psycopg2/StatementsWithoutParameters.py +++ b/src/test/python/psycopg2/StatementsWithoutParameters.py @@ -22,7 +22,7 @@ def create_connection(version, host, port): connection = pg.connect(database="my-database", host=host, port=port, - options="-c server_version=" + version) + options="-c timezone=UTC -c server_version=" + version) connection.autocommit = True return connection except Exception as e: diff --git a/versions.txt b/versions.txt index ac347e64a..1a53ddeb4 100644 --- a/versions.txt +++ b/versions.txt @@ -1,4 +1,4 @@ # Format: # module:released-version:current-version -google-cloud-spanner-pgadapter:0.12.0:0.12.0 +google-cloud-spanner-pgadapter:0.13.0:0.13.0