Skip to content

Commit

Permalink
samples: bit reverse sequence (#937)
Browse files Browse the repository at this point in the history
* samples: bit reverse sequence

* fix: sequence

* fix: review comments

* fix: lint E721

* fix: new database for sequence test

* fix: lint and blacken

* fix: doc

* fix: database name
  • Loading branch information
surbhigarg92 authored Aug 2, 2023
1 parent ba5ff0b commit 79e1398
Show file tree
Hide file tree
Showing 19 changed files with 839 additions and 534 deletions.
4 changes: 2 additions & 2 deletions google/cloud/spanner_dbapi/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __repr__(self):
return self.__str__()

def __eq__(self, other):
if type(self) != type(other):
if type(self) is not type(other):
return False
if self.name != other.name:
return False
Expand Down Expand Up @@ -95,7 +95,7 @@ def __len__(self):
return len(self.argv)

def __eq__(self, other):
if type(self) != type(other):
if type(self) is not type(other):
return False

if len(self) != len(other):
Expand Down
4 changes: 2 additions & 2 deletions google/cloud/spanner_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,15 @@ def _merge_query_options(base, merge):
If the resultant object only has empty fields, returns None.
"""
combined = base or ExecuteSqlRequest.QueryOptions()
if type(combined) == dict:
if type(combined) is dict:
combined = ExecuteSqlRequest.QueryOptions(
optimizer_version=combined.get("optimizer_version", ""),
optimizer_statistics_package=combined.get(
"optimizer_statistics_package", ""
),
)
merge = merge or ExecuteSqlRequest.QueryOptions()
if type(merge) == dict:
if type(merge) is dict:
merge = ExecuteSqlRequest.QueryOptions(
optimizer_version=merge.get("optimizer_version", ""),
optimizer_statistics_package=merge.get("optimizer_statistics_package", ""),
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/spanner_v1/backup.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(
self._max_expire_time = None
self._referencing_backups = None
self._database_dialect = None
if type(encryption_config) == dict:
if type(encryption_config) is dict:
if source_backup:
self._encryption_config = CopyBackupEncryptionConfig(
**encryption_config
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def commit(self, return_commit_stats=False, request_options=None):

if request_options is None:
request_options = RequestOptions()
elif type(request_options) == dict:
elif type(request_options) is dict:
request_options = RequestOptions(request_options)
request_options.transaction_tag = self.transaction_tag

Expand Down
2 changes: 1 addition & 1 deletion google/cloud/spanner_v1/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def __init__(
):
self._emulator_host = _get_spanner_emulator_host()

if client_options and type(client_options) == dict:
if client_options and type(client_options) is dict:
self._client_options = google.api_core.client_options.from_dict(
client_options
)
Expand Down
8 changes: 4 additions & 4 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def create(self):
db_name = f'"{db_name}"'
else:
db_name = f"`{db_name}`"
if type(self._encryption_config) == dict:
if type(self._encryption_config) is dict:
self._encryption_config = EncryptionConfig(**self._encryption_config)

request = CreateDatabaseRequest(
Expand Down Expand Up @@ -621,7 +621,7 @@ def execute_partitioned_dml(
)
if request_options is None:
request_options = RequestOptions()
elif type(request_options) == dict:
elif type(request_options) is dict:
request_options = RequestOptions(request_options)
request_options.transaction_tag = None

Expand Down Expand Up @@ -806,7 +806,7 @@ def restore(self, source):
"""
if source is None:
raise ValueError("Restore source not specified")
if type(self._encryption_config) == dict:
if type(self._encryption_config) is dict:
self._encryption_config = RestoreDatabaseEncryptionConfig(
**self._encryption_config
)
Expand Down Expand Up @@ -1011,7 +1011,7 @@ def __init__(self, database, request_options=None):
self._session = self._batch = None
if request_options is None:
self._request_options = RequestOptions()
elif type(request_options) == dict:
elif type(request_options) is dict:
self._request_options = RequestOptions(request_options)
else:
self._request_options = request_options
Expand Down
4 changes: 2 additions & 2 deletions google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def read(

if request_options is None:
request_options = RequestOptions()
elif type(request_options) == dict:
elif type(request_options) is dict:
request_options = RequestOptions(request_options)

if self._read_only:
Expand Down Expand Up @@ -414,7 +414,7 @@ def execute_sql(

if request_options is None:
request_options = RequestOptions()
elif type(request_options) == dict:
elif type(request_options) is dict:
request_options = RequestOptions(request_options)
if self._read_only:
# Transaction tags are not supported for read only transactions.
Expand Down
6 changes: 3 additions & 3 deletions google/cloud/spanner_v1/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def commit(self, return_commit_stats=False, request_options=None):

if request_options is None:
request_options = RequestOptions()
elif type(request_options) == dict:
elif type(request_options) is dict:
request_options = RequestOptions(request_options)
if self.transaction_tag is not None:
request_options.transaction_tag = self.transaction_tag
Expand Down Expand Up @@ -352,7 +352,7 @@ def execute_update(

if request_options is None:
request_options = RequestOptions()
elif type(request_options) == dict:
elif type(request_options) is dict:
request_options = RequestOptions(request_options)
request_options.transaction_tag = self.transaction_tag

Expand Down Expand Up @@ -463,7 +463,7 @@ def batch_update(self, statements, request_options=None):

if request_options is None:
request_options = RequestOptions()
elif type(request_options) == dict:
elif type(request_options) is dict:
request_options = RequestOptions(request_options)
request_options.transaction_tag = self.transaction_tag

Expand Down
144 changes: 90 additions & 54 deletions samples/samples/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,19 @@
def sample_name():
"""Sample testcase modules must define this fixture.
The name is used to label the instance created by the sample, to
aid in debugging leaked instances.
"""
raise NotImplementedError(
"Define 'sample_name' fixture in sample test driver")
The name is used to label the instance created by the sample, to
aid in debugging leaked instances.
"""
raise NotImplementedError("Define 'sample_name' fixture in sample test driver")


@pytest.fixture(scope="module")
def database_dialect():
"""Database dialect to be used for this sample.
The dialect is used to initialize the dialect for the database.
It can either be GoogleStandardSql or PostgreSql.
"""
The dialect is used to initialize the dialect for the database.
It can either be GoogleStandardSql or PostgreSql.
"""
# By default, we consider GOOGLE_STANDARD_SQL dialect. Other specific tests
# can override this if required.
return DatabaseDialect.GOOGLE_STANDARD_SQL
Expand Down Expand Up @@ -105,7 +104,7 @@ def multi_region_instance_id():
@pytest.fixture(scope="module")
def instance_config(spanner_client):
return "{}/instanceConfigs/{}".format(
spanner_client.project_name, "regional-us-central1"
spanner_client.project_name, "regional-us-central1"
)


Expand All @@ -116,20 +115,20 @@ def multi_region_instance_config(spanner_client):

@pytest.fixture(scope="module")
def sample_instance(
spanner_client,
cleanup_old_instances,
instance_id,
instance_config,
sample_name,
spanner_client,
cleanup_old_instances,
instance_id,
instance_config,
sample_name,
):
sample_instance = spanner_client.instance(
instance_id,
instance_config,
labels={
"cloud_spanner_samples": "true",
"sample_name": sample_name,
"created": str(int(time.time())),
},
instance_id,
instance_config,
labels={
"cloud_spanner_samples": "true",
"sample_name": sample_name,
"created": str(int(time.time())),
},
)
op = retry_429(sample_instance.create)()
op.result(INSTANCE_CREATION_TIMEOUT) # block until completion
Expand All @@ -151,20 +150,20 @@ def sample_instance(

@pytest.fixture(scope="module")
def multi_region_instance(
spanner_client,
cleanup_old_instances,
multi_region_instance_id,
multi_region_instance_config,
sample_name,
spanner_client,
cleanup_old_instances,
multi_region_instance_id,
multi_region_instance_config,
sample_name,
):
multi_region_instance = spanner_client.instance(
multi_region_instance_id,
multi_region_instance_config,
labels={
"cloud_spanner_samples": "true",
"sample_name": sample_name,
"created": str(int(time.time())),
},
multi_region_instance_id,
multi_region_instance_config,
labels={
"cloud_spanner_samples": "true",
"sample_name": sample_name,
"created": str(int(time.time())),
},
)
op = retry_429(multi_region_instance.create)()
op.result(INSTANCE_CREATION_TIMEOUT) # block until completion
Expand All @@ -188,44 +187,49 @@ def multi_region_instance(
def database_id():
"""Id for the database used in samples.
Sample testcase modules can override as needed.
"""
Sample testcase modules can override as needed.
"""
return "my-database-id"


@pytest.fixture(scope="module")
def bit_reverse_sequence_database_id():
"""Id for the database used in bit reverse sequence samples.
Sample testcase modules can override as needed.
"""
return "sequence-database-id"


@pytest.fixture(scope="module")
def database_ddl():
"""Sequence of DDL statements used to set up the database.
Sample testcase modules can override as needed.
"""
Sample testcase modules can override as needed.
"""
return []


@pytest.fixture(scope="module")
def sample_database(
spanner_client,
sample_instance,
database_id,
database_ddl,
database_dialect):
spanner_client, sample_instance, database_id, database_ddl, database_dialect
):
if database_dialect == DatabaseDialect.POSTGRESQL:
sample_database = sample_instance.database(
database_id,
database_dialect=DatabaseDialect.POSTGRESQL,
database_id,
database_dialect=DatabaseDialect.POSTGRESQL,
)

if not sample_database.exists():
operation = sample_database.create()
operation.result(OPERATION_TIMEOUT_SECONDS)

request = spanner_admin_database_v1.UpdateDatabaseDdlRequest(
database=sample_database.name,
statements=database_ddl,
database=sample_database.name,
statements=database_ddl,
)

operation =\
spanner_client.database_admin_api.update_database_ddl(request)
operation = spanner_client.database_admin_api.update_database_ddl(request)
operation.result(OPERATION_TIMEOUT_SECONDS)

yield sample_database
Expand All @@ -234,8 +238,8 @@ def sample_database(
return

sample_database = sample_instance.database(
database_id,
ddl_statements=database_ddl,
database_id,
ddl_statements=database_ddl,
)

if not sample_database.exists():
Expand All @@ -247,11 +251,43 @@ def sample_database(
sample_database.drop()


@pytest.fixture(scope="module")
def bit_reverse_sequence_database(
spanner_client, sample_instance, bit_reverse_sequence_database_id, database_dialect
):
if database_dialect == DatabaseDialect.POSTGRESQL:
bit_reverse_sequence_database = sample_instance.database(
bit_reverse_sequence_database_id,
database_dialect=DatabaseDialect.POSTGRESQL,
)

if not bit_reverse_sequence_database.exists():
operation = bit_reverse_sequence_database.create()
operation.result(OPERATION_TIMEOUT_SECONDS)

yield bit_reverse_sequence_database

bit_reverse_sequence_database.drop()
return

bit_reverse_sequence_database = sample_instance.database(
bit_reverse_sequence_database_id
)

if not bit_reverse_sequence_database.exists():
operation = bit_reverse_sequence_database.create()
operation.result(OPERATION_TIMEOUT_SECONDS)

yield bit_reverse_sequence_database

bit_reverse_sequence_database.drop()


@pytest.fixture(scope="module")
def kms_key_name(spanner_client):
return "projects/{}/locations/{}/keyRings/{}/cryptoKeys/{}".format(
spanner_client.project,
"us-central1",
"spanner-test-keyring",
"spanner-test-cmek",
spanner_client.project,
"us-central1",
"spanner-test-keyring",
"spanner-test-cmek",
)
Loading

0 comments on commit 79e1398

Please sign in to comment.