diff --git a/.kokoro/build.sh b/.kokoro/build.sh
index 64fe343878..031fa2200a 100755
--- a/.kokoro/build.sh
+++ b/.kokoro/build.sh
@@ -33,9 +33,6 @@ export GOOGLE_APPLICATION_CREDENTIALS=${KOKORO_GFILE_DIR}/service-account.json
# Setup project id.
export PROJECT_ID=$(cat "${KOKORO_GFILE_DIR}/project-id.json")
-# Setup staging endpoint.
-export STAGING_ENDPOINT=$(cat "${KOKORO_KEYSTORE_DIR}/73713_vertexai-staging-endpoint")
-
# Remove old nox
python3 -m pip uninstall --yes --quiet nox-automation
diff --git a/.release-please-manifest.json b/.release-please-manifest.json
index 3ed68cb97c..979c42f7ec 100644
--- a/.release-please-manifest.json
+++ b/.release-please-manifest.json
@@ -1,3 +1,3 @@
{
- ".": "1.64.0"
+ ".": "1.65.0"
}
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 49f8adb800..d31b3a100f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,29 @@
# Changelog
+## [1.65.0](https://github.com/googleapis/python-aiplatform/compare/v1.64.0...v1.65.0) (2024-09-04)
+
+
+### ⚠ BREAKING CHANGES
+
+* Tokenization - Deprecated `ComputeTokenResult.token_info_list` in favor of `ComputeTokenResult.tokens_info`
+
+### Features
+
+* Add support for system instruction and tools in tokenization. ([72fcc06](https://github.com/googleapis/python-aiplatform/commit/72fcc063ed4a086da0ad37ec2ac58860d4e79051))
+* Add vector search alpha to rag retrieval for hybrid search ranking ([6624ebe](https://github.com/googleapis/python-aiplatform/commit/6624ebe22726942dd70781122352f47268ee2dee))
+* Adding Weaviate Vector DB option for RAG corpuses to SDK ([9b28202](https://github.com/googleapis/python-aiplatform/commit/9b28202a1bb17f54e042301d6cdac3b6aa826797))
+* GenAI - Added `system_instruction` and `tools` support to `GenerativeModel.count_tokens` ([50fca69](https://github.com/googleapis/python-aiplatform/commit/50fca693b2f3b1a0b61867dc136be5a468fb2b2f))
+* GenAI - Added Llama3 support in GenAI batch prediction ([6166152](https://github.com/googleapis/python-aiplatform/commit/6166152844dc0078f7a5a02355ef3555cc428cfa))
+* PrivateEndpoint.stream_raw_predict ([197f333](https://github.com/googleapis/python-aiplatform/commit/197f333be5a075d41f98b762cd933cd2e89cecae))
+* Support reserved_ip_ranges for VPC network in Ray on Vertex cluster ([36a56b9](https://github.com/googleapis/python-aiplatform/commit/36a56b99f9e53d19d80c2bff3bf55c208988c518))
+* Tokenization - Deprecated `ComputeTokenResult.token_info_list` in favor of `ComputeTokenResult.tokens_info` ([efbcb54](https://github.com/googleapis/python-aiplatform/commit/efbcb54e0d5df4d65a79e60afdbc5c328538aef6))
+
+
+### Bug Fixes
+
+* Tensorboard - Fixed bug in tensorboard uploader where profile logs are not uploaded from nested run directories. ([37627de](https://github.com/googleapis/python-aiplatform/commit/37627de4ec12b8c51abf41524b9e7e3adf2dab54))
+* Tokenizers - Fixed `Tokenizer.compute_tokens` ([c29fa5d](https://github.com/googleapis/python-aiplatform/commit/c29fa5d98fc1c1550c039e84ff2f5725818c2231))
+
## [1.64.0](https://github.com/googleapis/python-aiplatform/compare/v1.63.0...v1.64.0) (2024-08-27)
diff --git a/gemini_docs/vertexai/services.rst b/gemini_docs/vertexai/services.rst
index 8a2cca514e..3496c44afc 100644
--- a/gemini_docs/vertexai/services.rst
+++ b/gemini_docs/vertexai/services.rst
@@ -45,6 +45,16 @@ Vertex AI SDK
:show-inheritance:
:inherited-members:
+.. automodule:: vertexai.preview.tuning
+ :members:
+ :show-inheritance:
+ :inherited-members:
+
+.. automodule:: vertexai.preview.tuning.sft
+ :members:
+ :show-inheritance:
+ :inherited-members:
+
.. automodule:: vertexai.evaluation
:members:
:show-inheritance:
diff --git a/google/cloud/aiplatform/gapic_version.py b/google/cloud/aiplatform/gapic_version.py
index 8fd4500cf1..b3833ad200 100644
--- a/google/cloud/aiplatform/gapic_version.py
+++ b/google/cloud/aiplatform/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.64.0" # {x-release-please-version}
+__version__ = "1.65.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py
index d3ddc86962..dc8e105ef6 100644
--- a/google/cloud/aiplatform/models.py
+++ b/google/cloud/aiplatform/models.py
@@ -3666,6 +3666,95 @@ def raw_predict(
headers=headers_with_token,
)
+ def stream_raw_predict(
+ self,
+ body: bytes,
+ headers: Dict[str, str],
+ endpoint_override: Optional[str] = None,
+ ) -> Iterator[bytes]:
+ """Make a streaming prediction request using arbitrary headers.
+
+ Example usage:
+ my_endpoint = aiplatform.PrivateEndpoint(ENDPOINT_ID)
+
+ # Prepare the request body
+ request_body = json.dumps({...}).encode('utf-8')
+
+ # Define the headers
+ headers = {
+ 'Content-Type': 'application/json',
+ }
+
+ # Use stream_raw_predict to send the request and process the response
+ for stream_response in psc_endpoint.stream_raw_predict(
+ body=request_body,
+ headers=headers,
+ endpoint_override="10.128.0.26" # Replace with your actual endpoint
+ ):
+ stream_response_text = stream_response.decode('utf-8')
+
+ Args:
+ body (bytes):
+ The body of the prediction request in bytes. This must not
+ exceed 10 mb per request.
+ headers (Dict[str, str]):
+ The header of the request as a dictionary. There are no
+ restrictions on the header.
+ endpoint_override (Optional[str]):
+ The Private Service Connect endpoint's IP address or DNS that
+ points to the endpoint's service attachment.
+
+ Yields:
+ predictions (Iterator[bytes]):
+ The streaming prediction results as lines of bytes.
+
+ Raises:
+            ValueError: If an endpoint override is not provided for PSC based
+ endpoint.
+            ValueError: If an endpoint override is invalid for PSC based endpoint.
+ """
+ self.wait()
+ if self.network or not self.private_service_connect_config:
+ raise ValueError(
+ "PSA based private endpoint does not support streaming prediction."
+ )
+
+ if self.private_service_connect_config:
+ if not endpoint_override:
+ raise ValueError(
+ "Cannot make a predict request because endpoint override is"
+ "not provided. Please ensure an endpoint override is"
+ "provided."
+ )
+ if not self._validate_endpoint_override(endpoint_override):
+ raise ValueError(
+ "Invalid endpoint override provided. Please only use IP"
+ "address or DNS."
+ )
+ if not self.credentials.valid:
+ self.credentials.refresh(google_auth_requests.Request())
+
+ token = self.credentials.token
+ headers_with_token = dict(headers)
+ headers_with_token["Authorization"] = f"Bearer {token}"
+
+ if not self.authorized_session:
+ self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES
+ self.authorized_session = google_auth_requests.AuthorizedSession(
+ self.credentials
+ )
+
+ url = f"https://{endpoint_override}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:streamRawPredict"
+ with self.authorized_session.post(
+ url=url,
+ data=body,
+ headers=headers_with_token,
+ stream=True,
+ verify=False,
+ ) as resp:
+ for line in resp.iter_lines():
+ yield line
+
def explain(self):
raise NotImplementedError(
f"{self.__class__.__name__} class does not support 'explain' as of now."
diff --git a/google/cloud/aiplatform/tensorboard/logdir_loader.py b/google/cloud/aiplatform/tensorboard/logdir_loader.py
index ea9b6200f7..bd55e9505a 100644
--- a/google/cloud/aiplatform/tensorboard/logdir_loader.py
+++ b/google/cloud/aiplatform/tensorboard/logdir_loader.py
@@ -70,7 +70,6 @@ def synchronize_runs(self):
runs_seen = set()
for subdir in io_wrapper.GetLogdirSubdirectories(self._logdir):
run = os.path.relpath(subdir, self._logdir)
- run = run.replace("/", "-").replace("_", "-")
runs_seen.add(run)
if run not in self._directory_loaders:
logger.info("- Adding run for relative directory %s", run)
diff --git a/google/cloud/aiplatform/tensorboard/plugins/tf_profiler/profile_uploader.py b/google/cloud/aiplatform/tensorboard/plugins/tf_profiler/profile_uploader.py
index f1967420ba..cb8f119172 100644
--- a/google/cloud/aiplatform/tensorboard/plugins/tf_profiler/profile_uploader.py
+++ b/google/cloud/aiplatform/tensorboard/plugins/tf_profiler/profile_uploader.py
@@ -176,7 +176,10 @@ def send_request(self, run_name: str):
self._profile_dir(run_name)
)
- tb_run = self._one_platform_resource_manager.get_run_resource_name(run_name)
+ experiment_run_name = uploader_utils.reformat_run_name(run_name)
+ tb_run = self._one_platform_resource_manager.get_run_resource_name(
+ experiment_run_name
+ )
if run_name not in self._run_to_file_request_sender:
self._run_to_file_request_sender[
diff --git a/google/cloud/aiplatform/tensorboard/uploader.py b/google/cloud/aiplatform/tensorboard/uploader.py
index 5e928ac812..c24455ae5e 100644
--- a/google/cloud/aiplatform/tensorboard/uploader.py
+++ b/google/cloud/aiplatform/tensorboard/uploader.py
@@ -417,7 +417,8 @@ def _pre_create_runs_and_time_series(self):
plugin_data=metadata.plugin_data.content,
)
- self._one_platform_resource_manager.batch_create_runs(run_names)
+ experiment_runs = [uploader_utils.reformat_run_name(run) for run in run_names]
+ self._one_platform_resource_manager.batch_create_runs(experiment_runs)
self._one_platform_resource_manager.batch_create_time_series(
run_tag_name_to_time_series_proto
)
@@ -451,7 +452,9 @@ def _upload_once(self):
)
run_to_events[profile_run_name] = None
- self._experiment_runs = run_to_events.keys()
+ self._experiment_runs = [
+ uploader_utils.reformat_run_name(run) for run in run_to_events.keys()
+ ]
with self._tracker.send_tracker():
self._dispatcher.dispatch_requests(run_to_events)
@@ -807,6 +810,7 @@ def _add_event_internal(
metadata: tf.compat.v1.SummaryMetadata,
):
self._num_values += 1
+ run_name = uploader_utils.reformat_run_name(run_name)
time_series_data_proto = self._run_to_tag_to_time_series_data[run_name].get(
value.tag
)
diff --git a/google/cloud/aiplatform/tensorboard/uploader_utils.py b/google/cloud/aiplatform/tensorboard/uploader_utils.py
index c9723d0f89..8c5b9ec811 100644
--- a/google/cloud/aiplatform/tensorboard/uploader_utils.py
+++ b/google/cloud/aiplatform/tensorboard/uploader_utils.py
@@ -47,6 +47,11 @@
logger.setLevel(logging.WARNING)
+def reformat_run_name(run_name: str) -> str:
+ """Reformats the run name to be compatible with One Platform."""
+ return run_name.replace("/", "-").replace("_", "-")
+
+
class RateLimiter:
"""Helper class for rate-limiting using a fixed minimum interval."""
diff --git a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py
index 8fd4500cf1..b3833ad200 100644
--- a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.64.0" # {x-release-please-version}
+__version__ = "1.65.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py
index 8fd4500cf1..b3833ad200 100644
--- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.64.0" # {x-release-please-version}
+__version__ = "1.65.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py
index 8fd4500cf1..b3833ad200 100644
--- a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.64.0" # {x-release-please-version}
+__version__ = "1.65.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py
index 8fd4500cf1..b3833ad200 100644
--- a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.64.0" # {x-release-please-version}
+__version__ = "1.65.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py
index 8fd4500cf1..b3833ad200 100644
--- a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.64.0" # {x-release-please-version}
+__version__ = "1.65.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py
index 8fd4500cf1..b3833ad200 100644
--- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.64.0" # {x-release-please-version}
+__version__ = "1.65.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py
index 8fd4500cf1..b3833ad200 100644
--- a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.64.0" # {x-release-please-version}
+__version__ = "1.65.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py
index 8fd4500cf1..b3833ad200 100644
--- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.64.0" # {x-release-please-version}
+__version__ = "1.65.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py
index 8fd4500cf1..b3833ad200 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.64.0" # {x-release-please-version}
+__version__ = "1.65.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py
index 8fd4500cf1..b3833ad200 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.64.0" # {x-release-please-version}
+__version__ = "1.65.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py
index 8fd4500cf1..b3833ad200 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.64.0" # {x-release-please-version}
+__version__ = "1.65.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py
index 8fd4500cf1..b3833ad200 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.64.0" # {x-release-please-version}
+__version__ = "1.65.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py
index 8fd4500cf1..b3833ad200 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.64.0" # {x-release-please-version}
+__version__ = "1.65.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py
index 8fd4500cf1..b3833ad200 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.64.0" # {x-release-please-version}
+__version__ = "1.65.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py
index 8fd4500cf1..b3833ad200 100644
--- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.64.0" # {x-release-please-version}
+__version__ = "1.65.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py
index 8fd4500cf1..b3833ad200 100644
--- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.64.0" # {x-release-please-version}
+__version__ = "1.65.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/version.py b/google/cloud/aiplatform/version.py
index c50dd04e4a..8df11f50b0 100644
--- a/google/cloud/aiplatform/version.py
+++ b/google/cloud/aiplatform/version.py
@@ -15,4 +15,4 @@
# limitations under the License.
#
-__version__ = "1.64.0"
+__version__ = "1.65.0"
diff --git a/google/cloud/aiplatform/vertex_ray/cluster_init.py b/google/cloud/aiplatform/vertex_ray/cluster_init.py
index da3da31449..bff297bd8a 100644
--- a/google/cloud/aiplatform/vertex_ray/cluster_init.py
+++ b/google/cloud/aiplatform/vertex_ray/cluster_init.py
@@ -61,6 +61,7 @@ def create_ray_cluster(
enable_metrics_collection: Optional[bool] = True,
enable_logging: Optional[bool] = True,
psc_interface_config: Optional[resources.PscIConfig] = None,
+ reserved_ip_ranges: Optional[List[str]] = None,
labels: Optional[Dict[str, str]] = None,
) -> str:
"""Create a ray cluster on the Vertex AI.
@@ -126,6 +127,11 @@ def create_ray_cluster(
enable_metrics_collection: Enable Ray metrics collection for visualization.
enable_logging: Enable exporting Ray logs to Cloud Logging.
psc_interface_config: PSC-I config.
+ reserved_ip_ranges: A list of names for the reserved IP ranges under
+ the VPC network that can be used for this cluster. If set, we will
+ deploy the cluster within the provided IP ranges. Otherwise, the
+ cluster is deployed to any IP ranges under the provided VPC network.
+ Example: ["vertex-ai-ip-range"].
labels:
The labels with user-defined metadata to organize Ray cluster.
@@ -325,6 +331,7 @@ def create_ray_cluster(
labels=labels,
resource_runtime_spec=resource_runtime_spec,
psc_interface_config=gapic_psc_interface_config,
+ reserved_ip_ranges=reserved_ip_ranges,
)
location = initializer.global_config.location
diff --git a/google/cloud/aiplatform/vertex_ray/predict/xgboost/register.py b/google/cloud/aiplatform/vertex_ray/predict/xgboost/register.py
index c93c6ce3e9..5aa35fa92c 100644
--- a/google/cloud/aiplatform/vertex_ray/predict/xgboost/register.py
+++ b/google/cloud/aiplatform/vertex_ray/predict/xgboost/register.py
@@ -43,7 +43,10 @@
import xgboost
except ModuleNotFoundError as mnfe:
- raise ModuleNotFoundError("XGBoost isn't installed.") from mnfe
+ if ray.__version__ == "2.9.3":
+ raise ModuleNotFoundError("XGBoost isn't installed.") from mnfe
+ else:
+ xgboost = None
def register_xgboost(
diff --git a/google/cloud/aiplatform/vertex_ray/util/_gapic_utils.py b/google/cloud/aiplatform/vertex_ray/util/_gapic_utils.py
index 73e60c8121..8ed158e012 100644
--- a/google/cloud/aiplatform/vertex_ray/util/_gapic_utils.py
+++ b/google/cloud/aiplatform/vertex_ray/util/_gapic_utils.py
@@ -150,6 +150,7 @@ def persistent_resource_to_cluster(
cluster = Cluster(
cluster_resource_name=persistent_resource.name,
network=persistent_resource.network,
+ reserved_ip_ranges=persistent_resource.reserved_ip_ranges,
state=persistent_resource.state.name,
labels=persistent_resource.labels,
dashboard_address=dashboard_address,
@@ -238,6 +239,11 @@ def persistent_resource_to_cluster(
custom_image=head_image_uri,
)
)
+ if head_resource_pool.autoscaling_spec:
+ worker_node_types[0].autoscaling_spec = AutoscalingSpec(
+ min_replica_count=head_resource_pool.autoscaling_spec.min_replica_count,
+ max_replica_count=head_resource_pool.autoscaling_spec.max_replica_count,
+ )
for i in range(len(resource_pools) - 1):
# Convert the second and more resource pools to vertex_ray.Resources,
# and append then to worker_node_types.
diff --git a/google/cloud/aiplatform/vertex_ray/util/resources.py b/google/cloud/aiplatform/vertex_ray/util/resources.py
index 37fae9be8d..117cd69e1c 100644
--- a/google/cloud/aiplatform/vertex_ray/util/resources.py
+++ b/google/cloud/aiplatform/vertex_ray/util/resources.py
@@ -117,6 +117,11 @@ class Cluster:
managed in the Vertex API service. For Ray Job API, VPC network is
not required because cluster connection can be accessed through
dashboard address.
+ reserved_ip_ranges: A list of names for the reserved IP ranges under
+ the VPC network that can be used for this cluster. If set, we will
+ deploy the cluster within the provided IP ranges. Otherwise, the
+ cluster is deployed to any IP ranges under the provided VPC network.
+ Example: ["vertex-ai-ip-range"].
service_account: Service account to be used for running Ray programs on
the cluster.
state: Describes the cluster state (defined in PersistentResource.State).
@@ -140,6 +145,7 @@ class Cluster:
cluster_resource_name: str = None
network: str = None
+ reserved_ip_ranges: List[str] = None
service_account: str = None
state: PersistentResource.State = None
python_version: str = None
diff --git a/google/cloud/aiplatform_v1/gapic_version.py b/google/cloud/aiplatform_v1/gapic_version.py
index 8fd4500cf1..b3833ad200 100644
--- a/google/cloud/aiplatform_v1/gapic_version.py
+++ b/google/cloud/aiplatform_v1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.64.0" # {x-release-please-version}
+__version__ = "1.65.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform_v1/services/schedule_service/async_client.py b/google/cloud/aiplatform_v1/services/schedule_service/async_client.py
index a8b76c8eba..db7cf616f1 100644
--- a/google/cloud/aiplatform_v1/services/schedule_service/async_client.py
+++ b/google/cloud/aiplatform_v1/services/schedule_service/async_client.py
@@ -46,6 +46,7 @@
from google.api_core import operation as gac_operation # type: ignore
from google.api_core import operation_async # type: ignore
from google.cloud.aiplatform_v1.services.schedule_service import pagers
+from google.cloud.aiplatform_v1.types import notebook_service
from google.cloud.aiplatform_v1.types import operation as gca_operation
from google.cloud.aiplatform_v1.types import pipeline_service
from google.cloud.aiplatform_v1.types import schedule
@@ -88,6 +89,18 @@ class ScheduleServiceAsyncClient:
parse_execution_path = staticmethod(ScheduleServiceClient.parse_execution_path)
network_path = staticmethod(ScheduleServiceClient.network_path)
parse_network_path = staticmethod(ScheduleServiceClient.parse_network_path)
+ notebook_execution_job_path = staticmethod(
+ ScheduleServiceClient.notebook_execution_job_path
+ )
+ parse_notebook_execution_job_path = staticmethod(
+ ScheduleServiceClient.parse_notebook_execution_job_path
+ )
+ notebook_runtime_template_path = staticmethod(
+ ScheduleServiceClient.notebook_runtime_template_path
+ )
+ parse_notebook_runtime_template_path = staticmethod(
+ ScheduleServiceClient.parse_notebook_runtime_template_path
+ )
pipeline_job_path = staticmethod(ScheduleServiceClient.pipeline_job_path)
parse_pipeline_job_path = staticmethod(
ScheduleServiceClient.parse_pipeline_job_path
diff --git a/google/cloud/aiplatform_v1/services/schedule_service/client.py b/google/cloud/aiplatform_v1/services/schedule_service/client.py
index 33aa3f04d8..43a4c50a65 100644
--- a/google/cloud/aiplatform_v1/services/schedule_service/client.py
+++ b/google/cloud/aiplatform_v1/services/schedule_service/client.py
@@ -51,6 +51,7 @@
from google.api_core import operation as gac_operation # type: ignore
from google.api_core import operation_async # type: ignore
from google.cloud.aiplatform_v1.services.schedule_service import pagers
+from google.cloud.aiplatform_v1.types import notebook_service
from google.cloud.aiplatform_v1.types import operation as gca_operation
from google.cloud.aiplatform_v1.types import pipeline_service
from google.cloud.aiplatform_v1.types import schedule
@@ -311,6 +312,50 @@ def parse_network_path(path: str) -> Dict[str, str]:
)
return m.groupdict() if m else {}
+ @staticmethod
+ def notebook_execution_job_path(
+ project: str,
+ location: str,
+ notebook_execution_job: str,
+ ) -> str:
+ """Returns a fully-qualified notebook_execution_job string."""
+ return "projects/{project}/locations/{location}/notebookExecutionJobs/{notebook_execution_job}".format(
+ project=project,
+ location=location,
+ notebook_execution_job=notebook_execution_job,
+ )
+
+ @staticmethod
+ def parse_notebook_execution_job_path(path: str) -> Dict[str, str]:
+ """Parses a notebook_execution_job path into its component segments."""
+ m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/notebookExecutionJobs/(?P<notebook_execution_job>.+?)$",
+ path,
+ )
+ return m.groupdict() if m else {}
+
+ @staticmethod
+ def notebook_runtime_template_path(
+ project: str,
+ location: str,
+ notebook_runtime_template: str,
+ ) -> str:
+ """Returns a fully-qualified notebook_runtime_template string."""
+ return "projects/{project}/locations/{location}/notebookRuntimeTemplates/{notebook_runtime_template}".format(
+ project=project,
+ location=location,
+ notebook_runtime_template=notebook_runtime_template,
+ )
+
+ @staticmethod
+ def parse_notebook_runtime_template_path(path: str) -> Dict[str, str]:
+ """Parses a notebook_runtime_template path into its component segments."""
+ m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/notebookRuntimeTemplates/(?P<notebook_runtime_template>.+?)$",
+ path,
+ )
+ return m.groupdict() if m else {}
+
@staticmethod
def pipeline_job_path(
project: str,
diff --git a/google/cloud/aiplatform_v1/types/custom_job.py b/google/cloud/aiplatform_v1/types/custom_job.py
index 709f7261f5..3068f16a91 100644
--- a/google/cloud/aiplatform_v1/types/custom_job.py
+++ b/google/cloud/aiplatform_v1/types/custom_job.py
@@ -558,6 +558,12 @@ class Scheduling(proto.Message):
Optional. Indicates if the job should retry for internal
errors after the job starts running. If true, overrides
``Scheduling.restart_job_on_worker_restart`` to false.
+ max_wait_duration (google.protobuf.duration_pb2.Duration):
+ Optional. This is the maximum duration that a job will wait
+ for the requested resources to be provisioned if the
+ scheduling strategy is set to [Strategy.DWS_FLEX_START]. If
+ set to 0, the job will wait indefinitely. The default is 24
+ hours.
"""
class Strategy(proto.Enum):
@@ -606,6 +612,11 @@ class Strategy(proto.Enum):
proto.BOOL,
number=5,
)
+ max_wait_duration: duration_pb2.Duration = proto.Field(
+ proto.MESSAGE,
+ number=6,
+ message=duration_pb2.Duration,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform_v1/types/schedule.py b/google/cloud/aiplatform_v1/types/schedule.py
index 5d6e4277d1..8cd50a40fd 100644
--- a/google/cloud/aiplatform_v1/types/schedule.py
+++ b/google/cloud/aiplatform_v1/types/schedule.py
@@ -19,6 +19,7 @@
import proto # type: ignore
+from google.cloud.aiplatform_v1.types import notebook_service
from google.cloud.aiplatform_v1.types import pipeline_service
from google.protobuf import timestamp_pb2 # type: ignore
@@ -36,6 +37,10 @@ class Schedule(proto.Message):
API calls based on user specified time specification and API
request type.
+ This message has `oneof`_ fields (mutually exclusive fields).
+ For each oneof, at most one member field can be set at the same time.
+ Setting any member of the oneof automatically clears all other
+ members.
.. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
@@ -56,6 +61,11 @@ class Schedule(proto.Message):
CreatePipelineJobRequest.parent field is required (format:
projects/{project}/locations/{location}).
+ This field is a member of `oneof`_ ``request``.
+ create_notebook_execution_job_request (google.cloud.aiplatform_v1.types.CreateNotebookExecutionJobRequest):
+ Request for
+ [NotebookService.CreateNotebookExecutionJob][google.cloud.aiplatform.v1.NotebookService.CreateNotebookExecutionJob].
+
This field is a member of `oneof`_ ``request``.
name (str):
Immutable. The resource name of the Schedule.
@@ -186,6 +196,12 @@ class RunResponse(proto.Message):
message=pipeline_service.CreatePipelineJobRequest,
)
)
+ create_notebook_execution_job_request: notebook_service.CreateNotebookExecutionJobRequest = proto.Field(
+ proto.MESSAGE,
+ number=20,
+ oneof="request",
+ message=notebook_service.CreateNotebookExecutionJobRequest,
+ )
name: str = proto.Field(
proto.STRING,
number=1,
diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py
index 2fb939444e..3bd01743e9 100644
--- a/google/cloud/aiplatform_v1beta1/__init__.py
+++ b/google/cloud/aiplatform_v1beta1/__init__.py
@@ -1002,12 +1002,15 @@
from .types.unmanaged_container_model import UnmanagedContainerModel
from .types.user_action_reference import UserActionReference
from .types.value import Value
+from .types.vertex_rag_data import CorpusStatus
+from .types.vertex_rag_data import FileStatus
from .types.vertex_rag_data import ImportRagFilesConfig
from .types.vertex_rag_data import RagCorpus
from .types.vertex_rag_data import RagEmbeddingModelConfig
from .types.vertex_rag_data import RagFile
from .types.vertex_rag_data import RagFileChunkingConfig
from .types.vertex_rag_data import RagFileParsingConfig
+from .types.vertex_rag_data import RagVectorDbConfig
from .types.vertex_rag_data import UploadRagFileConfig
from .types.vertex_rag_data_service import CreateRagCorpusOperationMetadata
from .types.vertex_rag_data_service import CreateRagCorpusRequest
@@ -1022,6 +1025,8 @@
from .types.vertex_rag_data_service import ListRagCorporaResponse
from .types.vertex_rag_data_service import ListRagFilesRequest
from .types.vertex_rag_data_service import ListRagFilesResponse
+from .types.vertex_rag_data_service import UpdateRagCorpusOperationMetadata
+from .types.vertex_rag_data_service import UpdateRagCorpusRequest
from .types.vertex_rag_data_service import UploadRagFileRequest
from .types.vertex_rag_data_service import UploadRagFileResponse
from .types.vertex_rag_service import RagContexts
@@ -1175,6 +1180,7 @@
"CopyModelOperationMetadata",
"CopyModelRequest",
"CopyModelResponse",
+ "CorpusStatus",
"CountTokensRequest",
"CountTokensResponse",
"CreateArtifactRequest",
@@ -1403,6 +1409,7 @@
"FetchFeatureValuesRequest",
"FetchFeatureValuesResponse",
"FileData",
+ "FileStatus",
"FilterSplit",
"FindNeighborsRequest",
"FindNeighborsResponse",
@@ -1799,6 +1806,7 @@
"RagFileChunkingConfig",
"RagFileParsingConfig",
"RagQuery",
+ "RagVectorDbConfig",
"RawPredictRequest",
"RayLogsSpec",
"RayMetricSpec",
@@ -2013,6 +2021,8 @@
"UpdateNotebookRuntimeTemplateRequest",
"UpdatePersistentResourceOperationMetadata",
"UpdatePersistentResourceRequest",
+ "UpdateRagCorpusOperationMetadata",
+ "UpdateRagCorpusRequest",
"UpdateReasoningEngineOperationMetadata",
"UpdateReasoningEngineRequest",
"UpdateScheduleRequest",
diff --git a/google/cloud/aiplatform_v1beta1/gapic_metadata.json b/google/cloud/aiplatform_v1beta1/gapic_metadata.json
index 669d08cbc2..dfecadeb83 100644
--- a/google/cloud/aiplatform_v1beta1/gapic_metadata.json
+++ b/google/cloud/aiplatform_v1beta1/gapic_metadata.json
@@ -5368,6 +5368,11 @@
"list_rag_files"
]
},
+ "UpdateRagCorpus": {
+ "methods": [
+ "update_rag_corpus"
+ ]
+ },
"UploadRagFile": {
"methods": [
"upload_rag_file"
@@ -5418,6 +5423,11 @@
"list_rag_files"
]
},
+ "UpdateRagCorpus": {
+ "methods": [
+ "update_rag_corpus"
+ ]
+ },
"UploadRagFile": {
"methods": [
"upload_rag_file"
@@ -5468,6 +5478,11 @@
"list_rag_files"
]
},
+ "UpdateRagCorpus": {
+ "methods": [
+ "update_rag_corpus"
+ ]
+ },
"UploadRagFile": {
"methods": [
"upload_rag_file"
diff --git a/google/cloud/aiplatform_v1beta1/gapic_version.py b/google/cloud/aiplatform_v1beta1/gapic_version.py
index 8fd4500cf1..b3833ad200 100644
--- a/google/cloud/aiplatform_v1beta1/gapic_version.py
+++ b/google/cloud/aiplatform_v1beta1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.64.0" # {x-release-please-version}
+__version__ = "1.65.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/async_client.py
index 9ebb5e6e63..3c6d7b8f85 100644
--- a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/async_client.py
+++ b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/async_client.py
@@ -417,6 +417,133 @@ async def sample_create_rag_corpus():
# Done; return the response.
return response
+ async def update_rag_corpus(
+ self,
+ request: Optional[
+ Union[vertex_rag_data_service.UpdateRagCorpusRequest, dict]
+ ] = None,
+ *,
+ rag_corpus: Optional[vertex_rag_data.RagCorpus] = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: Union[float, object] = gapic_v1.method.DEFAULT,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operation_async.AsyncOperation:
+ r"""Updates a RagCorpus.
+
+ .. code-block:: python
+
+ # This snippet has been automatically generated and should be regarded as a
+ # code template only.
+ # It will require modifications to work:
+ # - It may require correct/in-range values for request initialization.
+ # - It may require specifying regional endpoints when creating the service
+ # client as shown in:
+ # https://googleapis.dev/python/google-api-core/latest/client_options.html
+ from google.cloud import aiplatform_v1beta1
+
+ async def sample_update_rag_corpus():
+ # Create a client
+ client = aiplatform_v1beta1.VertexRagDataServiceAsyncClient()
+
+ # Initialize request argument(s)
+ rag_corpus = aiplatform_v1beta1.RagCorpus()
+ rag_corpus.display_name = "display_name_value"
+
+ request = aiplatform_v1beta1.UpdateRagCorpusRequest(
+ rag_corpus=rag_corpus,
+ )
+
+ # Make the request
+ operation = client.update_rag_corpus(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = (await operation).result()
+
+ # Handle the response
+ print(response)
+
+ Args:
+ request (Optional[Union[google.cloud.aiplatform_v1beta1.types.UpdateRagCorpusRequest, dict]]):
+ The request object. Request message for
+ [VertexRagDataService.UpdateRagCorpus][google.cloud.aiplatform.v1beta1.VertexRagDataService.UpdateRagCorpus].
+ rag_corpus (:class:`google.cloud.aiplatform_v1beta1.types.RagCorpus`):
+ Required. The RagCorpus which
+ replaces the resource on the server.
+
+ This corresponds to the ``rag_corpus`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.api_core.operation_async.AsyncOperation:
+ An object representing a long-running operation.
+
+ The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.RagCorpus` A RagCorpus is a RagFile container and a project can have multiple
+ RagCorpora.
+
+ """
+ # Create or coerce a protobuf request object.
+ # - Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([rag_corpus])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ # - Use the request object if provided (there's no risk of modifying the input as
+ # there are no flattened fields), or create one.
+ if not isinstance(request, vertex_rag_data_service.UpdateRagCorpusRequest):
+ request = vertex_rag_data_service.UpdateRagCorpusRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if rag_corpus is not None:
+ request.rag_corpus = rag_corpus
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = self._client._transport._wrapped_methods[
+ self._client._transport.update_rag_corpus
+ ]
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata(
+ (("rag_corpus.name", request.rag_corpus.name),)
+ ),
+ )
+
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Wrap the response in an operation future.
+ response = operation_async.from_gapic(
+ response,
+ self._client._transport.operations_client,
+ vertex_rag_data.RagCorpus,
+ metadata_type=vertex_rag_data_service.UpdateRagCorpusOperationMetadata,
+ )
+
+ # Done; return the response.
+ return response
+
async def get_rag_corpus(
self,
request: Optional[
diff --git a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/client.py b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/client.py
index 23e4fa5dce..23618833d5 100644
--- a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/client.py
+++ b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/client.py
@@ -915,6 +915,130 @@ def sample_create_rag_corpus():
# Done; return the response.
return response
+ def update_rag_corpus(
+ self,
+ request: Optional[
+ Union[vertex_rag_data_service.UpdateRagCorpusRequest, dict]
+ ] = None,
+ *,
+ rag_corpus: Optional[vertex_rag_data.RagCorpus] = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: Union[float, object] = gapic_v1.method.DEFAULT,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> gac_operation.Operation:
+ r"""Updates a RagCorpus.
+
+ .. code-block:: python
+
+ # This snippet has been automatically generated and should be regarded as a
+ # code template only.
+ # It will require modifications to work:
+ # - It may require correct/in-range values for request initialization.
+ # - It may require specifying regional endpoints when creating the service
+ # client as shown in:
+ # https://googleapis.dev/python/google-api-core/latest/client_options.html
+ from google.cloud import aiplatform_v1beta1
+
+ def sample_update_rag_corpus():
+ # Create a client
+ client = aiplatform_v1beta1.VertexRagDataServiceClient()
+
+ # Initialize request argument(s)
+ rag_corpus = aiplatform_v1beta1.RagCorpus()
+ rag_corpus.display_name = "display_name_value"
+
+ request = aiplatform_v1beta1.UpdateRagCorpusRequest(
+ rag_corpus=rag_corpus,
+ )
+
+ # Make the request
+ operation = client.update_rag_corpus(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = operation.result()
+
+ # Handle the response
+ print(response)
+
+ Args:
+ request (Union[google.cloud.aiplatform_v1beta1.types.UpdateRagCorpusRequest, dict]):
+ The request object. Request message for
+ [VertexRagDataService.UpdateRagCorpus][google.cloud.aiplatform.v1beta1.VertexRagDataService.UpdateRagCorpus].
+ rag_corpus (google.cloud.aiplatform_v1beta1.types.RagCorpus):
+ Required. The RagCorpus which
+ replaces the resource on the server.
+
+ This corresponds to the ``rag_corpus`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.api_core.operation.Operation:
+ An object representing a long-running operation.
+
+ The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.RagCorpus` A RagCorpus is a RagFile container and a project can have multiple
+ RagCorpora.
+
+ """
+ # Create or coerce a protobuf request object.
+ # - Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([rag_corpus])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ # - Use the request object if provided (there's no risk of modifying the input as
+ # there are no flattened fields), or create one.
+ if not isinstance(request, vertex_rag_data_service.UpdateRagCorpusRequest):
+ request = vertex_rag_data_service.UpdateRagCorpusRequest(request)
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if rag_corpus is not None:
+ request.rag_corpus = rag_corpus
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = self._transport._wrapped_methods[self._transport.update_rag_corpus]
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata(
+ (("rag_corpus.name", request.rag_corpus.name),)
+ ),
+ )
+
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Wrap the response in an operation future.
+ response = gac_operation.from_gapic(
+ response,
+ self._transport.operations_client,
+ vertex_rag_data.RagCorpus,
+ metadata_type=vertex_rag_data_service.UpdateRagCorpusOperationMetadata,
+ )
+
+ # Done; return the response.
+ return response
+
def get_rag_corpus(
self,
request: Optional[
diff --git a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/base.py
index c96aaec4ab..5023b56a66 100644
--- a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/base.py
+++ b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/base.py
@@ -140,6 +140,11 @@ def _prep_wrapped_messages(self, client_info):
default_timeout=None,
client_info=client_info,
),
+ self.update_rag_corpus: gapic_v1.method.wrap_method(
+ self.update_rag_corpus,
+ default_timeout=None,
+ client_info=client_info,
+ ),
self.get_rag_corpus: gapic_v1.method.wrap_method(
self.get_rag_corpus,
default_timeout=None,
@@ -205,6 +210,15 @@ def create_rag_corpus(
]:
raise NotImplementedError()
+ @property
+ def update_rag_corpus(
+ self,
+ ) -> Callable[
+ [vertex_rag_data_service.UpdateRagCorpusRequest],
+ Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]],
+ ]:
+ raise NotImplementedError()
+
@property
def get_rag_corpus(
self,
diff --git a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/grpc.py
index adddd14811..61b5e95563 100644
--- a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/grpc.py
+++ b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/grpc.py
@@ -284,6 +284,34 @@ def create_rag_corpus(
)
return self._stubs["create_rag_corpus"]
+ @property
+ def update_rag_corpus(
+ self,
+ ) -> Callable[
+ [vertex_rag_data_service.UpdateRagCorpusRequest], operations_pb2.Operation
+ ]:
+ r"""Return a callable for the update rag corpus method over gRPC.
+
+ Updates a RagCorpus.
+
+ Returns:
+ Callable[[~.UpdateRagCorpusRequest],
+ ~.Operation]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "update_rag_corpus" not in self._stubs:
+ self._stubs["update_rag_corpus"] = self.grpc_channel.unary_unary(
+ "/google.cloud.aiplatform.v1beta1.VertexRagDataService/UpdateRagCorpus",
+ request_serializer=vertex_rag_data_service.UpdateRagCorpusRequest.serialize,
+ response_deserializer=operations_pb2.Operation.FromString,
+ )
+ return self._stubs["update_rag_corpus"]
+
@property
def get_rag_corpus(
self,
diff --git a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/grpc_asyncio.py
index 13d74aca8a..f804a8a7fd 100644
--- a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/grpc_asyncio.py
+++ b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/grpc_asyncio.py
@@ -291,6 +291,35 @@ def create_rag_corpus(
)
return self._stubs["create_rag_corpus"]
+ @property
+ def update_rag_corpus(
+ self,
+ ) -> Callable[
+ [vertex_rag_data_service.UpdateRagCorpusRequest],
+ Awaitable[operations_pb2.Operation],
+ ]:
+ r"""Return a callable for the update rag corpus method over gRPC.
+
+ Updates a RagCorpus.
+
+ Returns:
+ Callable[[~.UpdateRagCorpusRequest],
+ Awaitable[~.Operation]]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "update_rag_corpus" not in self._stubs:
+ self._stubs["update_rag_corpus"] = self.grpc_channel.unary_unary(
+ "/google.cloud.aiplatform.v1beta1.VertexRagDataService/UpdateRagCorpus",
+ request_serializer=vertex_rag_data_service.UpdateRagCorpusRequest.serialize,
+ response_deserializer=operations_pb2.Operation.FromString,
+ )
+ return self._stubs["update_rag_corpus"]
+
@property
def get_rag_corpus(
self,
@@ -531,6 +560,11 @@ def _prep_wrapped_messages(self, client_info):
default_timeout=None,
client_info=client_info,
),
+ self.update_rag_corpus: gapic_v1.method_async.wrap_method(
+ self.update_rag_corpus,
+ default_timeout=None,
+ client_info=client_info,
+ ),
self.get_rag_corpus: gapic_v1.method_async.wrap_method(
self.get_rag_corpus,
default_timeout=None,
diff --git a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/rest.py
index f623649e32..856524bcce 100644
--- a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/rest.py
+++ b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/rest.py
@@ -139,6 +139,14 @@ def post_list_rag_files(self, response):
logging.log(f"Received response: {response}")
return response
+ def pre_update_rag_corpus(self, request, metadata):
+ logging.log(f"Received request: {request}")
+ return request, metadata
+
+ def post_update_rag_corpus(self, response):
+ logging.log(f"Received response: {response}")
+ return response
+
def pre_upload_rag_file(self, request, metadata):
logging.log(f"Received request: {request}")
return request, metadata
@@ -345,6 +353,31 @@ def post_list_rag_files(
"""
return response
+ def pre_update_rag_corpus(
+ self,
+ request: vertex_rag_data_service.UpdateRagCorpusRequest,
+ metadata: Sequence[Tuple[str, str]],
+ ) -> Tuple[
+ vertex_rag_data_service.UpdateRagCorpusRequest, Sequence[Tuple[str, str]]
+ ]:
+ """Pre-rpc interceptor for update_rag_corpus
+
+ Override in a subclass to manipulate the request or metadata
+ before they are sent to the VertexRagDataService server.
+ """
+ return request, metadata
+
+ def post_update_rag_corpus(
+ self, response: operations_pb2.Operation
+ ) -> operations_pb2.Operation:
+ """Post-rpc interceptor for update_rag_corpus
+
+ Override in a subclass to manipulate the response
+ after it is returned by the VertexRagDataService server but before
+ it is returned to user code.
+ """
+ return response
+
def pre_upload_rag_file(
self,
request: vertex_rag_data_service.UploadRagFileRequest,
@@ -3465,6 +3498,102 @@ def __call__(
resp = self._interceptor.post_list_rag_files(resp)
return resp
+ class _UpdateRagCorpus(VertexRagDataServiceRestStub):
+ def __hash__(self):
+ return hash("UpdateRagCorpus")
+
+ __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {}
+
+ @classmethod
+ def _get_unset_required_fields(cls, message_dict):
+ return {
+ k: v
+ for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items()
+ if k not in message_dict
+ }
+
+ def __call__(
+ self,
+ request: vertex_rag_data_service.UpdateRagCorpusRequest,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: Optional[float] = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operations_pb2.Operation:
+ r"""Call the update rag corpus method over HTTP.
+
+ Args:
+ request (~.vertex_rag_data_service.UpdateRagCorpusRequest):
+ The request object. Request message for
+ [VertexRagDataService.UpdateRagCorpus][google.cloud.aiplatform.v1beta1.VertexRagDataService.UpdateRagCorpus].
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ ~.operations_pb2.Operation:
+ This resource represents a
+ long-running operation that is the
+ result of a network API call.
+
+ """
+
+ http_options: List[Dict[str, str]] = [
+ {
+ "method": "patch",
+ "uri": "/v1beta1/{rag_corpus.name=projects/*/locations/*/ragCorpora/*}",
+ "body": "rag_corpus",
+ },
+ ]
+ request, metadata = self._interceptor.pre_update_rag_corpus(
+ request, metadata
+ )
+ pb_request = vertex_rag_data_service.UpdateRagCorpusRequest.pb(request)
+ transcoded_request = path_template.transcode(http_options, pb_request)
+
+ # Jsonify the request body
+
+ body = json_format.MessageToJson(
+ transcoded_request["body"], use_integers_for_enums=True
+ )
+ uri = transcoded_request["uri"]
+ method = transcoded_request["method"]
+
+ # Jsonify the query params
+ query_params = json.loads(
+ json_format.MessageToJson(
+ transcoded_request["query_params"],
+ use_integers_for_enums=True,
+ )
+ )
+ query_params.update(self._get_unset_required_fields(query_params))
+
+ query_params["$alt"] = "json;enum-encoding=int"
+
+ # Send the request
+ headers = dict(metadata)
+ headers["Content-Type"] = "application/json"
+ response = getattr(self._session, method)(
+ "{host}{uri}".format(host=self._host, uri=uri),
+ timeout=timeout,
+ headers=headers,
+ params=rest_helpers.flatten_query_params(query_params, strict=True),
+ data=body,
+ )
+
+ # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception
+ # subclass.
+ if response.status_code >= 400:
+ raise core_exceptions.from_http_response(response)
+
+ # Return the response
+ resp = operations_pb2.Operation()
+ json_format.Parse(response.content, resp, ignore_unknown_fields=True)
+ resp = self._interceptor.post_update_rag_corpus(resp)
+ return resp
+
class _UploadRagFile(VertexRagDataServiceRestStub):
def __hash__(self):
return hash("UploadRagFile")
@@ -3640,6 +3769,16 @@ def list_rag_files(
# In C++ this would require a dynamic_cast
return self._ListRagFiles(self._session, self._host, self._interceptor) # type: ignore
+ @property
+ def update_rag_corpus(
+ self,
+ ) -> Callable[
+ [vertex_rag_data_service.UpdateRagCorpusRequest], operations_pb2.Operation
+ ]:
+ # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here.
+ # In C++ this would require a dynamic_cast
+ return self._UpdateRagCorpus(self._session, self._host, self._interceptor) # type: ignore
+
@property
def upload_rag_file(
self,
diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py
index c2bf840843..7ed18b6fd2 100644
--- a/google/cloud/aiplatform_v1beta1/types/__init__.py
+++ b/google/cloud/aiplatform_v1beta1/types/__init__.py
@@ -1141,12 +1141,15 @@
Value,
)
from .vertex_rag_data import (
+ CorpusStatus,
+ FileStatus,
ImportRagFilesConfig,
RagCorpus,
RagEmbeddingModelConfig,
RagFile,
RagFileChunkingConfig,
RagFileParsingConfig,
+ RagVectorDbConfig,
UploadRagFileConfig,
)
from .vertex_rag_data_service import (
@@ -1163,6 +1166,8 @@
ListRagCorporaResponse,
ListRagFilesRequest,
ListRagFilesResponse,
+ UpdateRagCorpusOperationMetadata,
+ UpdateRagCorpusRequest,
UploadRagFileRequest,
UploadRagFileResponse,
)
@@ -2079,12 +2084,15 @@
"UnmanagedContainerModel",
"UserActionReference",
"Value",
+ "CorpusStatus",
+ "FileStatus",
"ImportRagFilesConfig",
"RagCorpus",
"RagEmbeddingModelConfig",
"RagFile",
"RagFileChunkingConfig",
"RagFileParsingConfig",
+ "RagVectorDbConfig",
"UploadRagFileConfig",
"CreateRagCorpusOperationMetadata",
"CreateRagCorpusRequest",
@@ -2099,6 +2107,8 @@
"ListRagCorporaResponse",
"ListRagFilesRequest",
"ListRagFilesResponse",
+ "UpdateRagCorpusOperationMetadata",
+ "UpdateRagCorpusRequest",
"UploadRagFileRequest",
"UploadRagFileResponse",
"RagContexts",
diff --git a/google/cloud/aiplatform_v1beta1/types/api_auth.py b/google/cloud/aiplatform_v1beta1/types/api_auth.py
index dc2a7ebe9a..2e7eb49e24 100644
--- a/google/cloud/aiplatform_v1beta1/types/api_auth.py
+++ b/google/cloud/aiplatform_v1beta1/types/api_auth.py
@@ -29,7 +29,16 @@
class ApiAuth(proto.Message):
- r"""The generic reusable api auth config."""
+ r"""The generic reusable api auth config.
+
+ .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
+
+ Attributes:
+ api_key_config (google.cloud.aiplatform_v1beta1.types.ApiAuth.ApiKeyConfig):
+ The API secret.
+
+ This field is a member of `oneof`_ ``auth_config``.
+ """
class ApiKeyConfig(proto.Message):
r"""The API secret.
@@ -46,5 +55,12 @@ class ApiKeyConfig(proto.Message):
number=1,
)
+ api_key_config: ApiKeyConfig = proto.Field(
+ proto.MESSAGE,
+ number=1,
+ oneof="auth_config",
+ message=ApiKeyConfig,
+ )
+
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform_v1beta1/types/custom_job.py b/google/cloud/aiplatform_v1beta1/types/custom_job.py
index 348f722971..df3118413a 100644
--- a/google/cloud/aiplatform_v1beta1/types/custom_job.py
+++ b/google/cloud/aiplatform_v1beta1/types/custom_job.py
@@ -558,6 +558,12 @@ class Scheduling(proto.Message):
Optional. Indicates if the job should retry for internal
errors after the job starts running. If true, overrides
``Scheduling.restart_job_on_worker_restart`` to false.
+ max_wait_duration (google.protobuf.duration_pb2.Duration):
+ Optional. This is the maximum duration that a job will wait
+ for the requested resources to be provisioned if the
+ scheduling strategy is set to [Strategy.DWS_FLEX_START]. If
+ set to 0, the job will wait indefinitely. The default is 24
+ hours.
"""
class Strategy(proto.Enum):
@@ -606,6 +612,11 @@ class Strategy(proto.Enum):
proto.BOOL,
number=5,
)
+ max_wait_duration: duration_pb2.Duration = proto.Field(
+ proto.MESSAGE,
+ number=6,
+ message=duration_pb2.Duration,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform_v1beta1/types/vertex_rag_data.py b/google/cloud/aiplatform_v1beta1/types/vertex_rag_data.py
index fb79ed0e90..787dd96a6d 100644
--- a/google/cloud/aiplatform_v1beta1/types/vertex_rag_data.py
+++ b/google/cloud/aiplatform_v1beta1/types/vertex_rag_data.py
@@ -19,6 +19,7 @@
import proto # type: ignore
+from google.cloud.aiplatform_v1beta1.types import api_auth as gca_api_auth
from google.cloud.aiplatform_v1beta1.types import io
from google.protobuf import timestamp_pb2 # type: ignore
@@ -27,6 +28,9 @@
package="google.cloud.aiplatform.v1beta1",
manifest={
"RagEmbeddingModelConfig",
+ "RagVectorDbConfig",
+ "FileStatus",
+ "CorpusStatus",
"RagCorpus",
"RagFile",
"RagFileChunkingConfig",
@@ -40,6 +44,11 @@
class RagEmbeddingModelConfig(proto.Message):
r"""Config for the embedding model to use for RAG.
+ This message has `oneof`_ fields (mutually exclusive fields).
+ For each oneof, at most one member field can be set at the same time.
+ Setting any member of the oneof automatically clears all other
+ members.
+
.. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
Attributes:
@@ -49,6 +58,11 @@ class RagEmbeddingModelConfig(proto.Message):
is hosting a 1P fine-tuned text embedding model.
Endpoints hosting non-1P fine-tuned text
embedding models are currently not supported.
+ This is used for dense vector search.
+
+ This field is a member of `oneof`_ ``model_config``.
+ hybrid_search_config (google.cloud.aiplatform_v1beta1.types.RagEmbeddingModelConfig.HybridSearchConfig):
+ Configuration for hybrid search.
This field is a member of `oneof`_ ``model_config``.
"""
@@ -87,12 +101,278 @@ class VertexPredictionEndpoint(proto.Message):
number=3,
)
+ class SparseEmbeddingConfig(proto.Message):
+ r"""Configuration for sparse emebdding generation.
+
+ .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
+
+ Attributes:
+ bm25 (google.cloud.aiplatform_v1beta1.types.RagEmbeddingModelConfig.SparseEmbeddingConfig.Bm25):
+ Use BM25 scoring algorithm.
+
+ This field is a member of `oneof`_ ``model``.
+ """
+
+ class Bm25(proto.Message):
+ r"""Message for BM25 parameters.
+
+ .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
+
+ Attributes:
+ multilingual (bool):
+ Optional. Use multilingual tokenizer if set
+ to true.
+ k1 (float):
+ Optional. The parameter to control term frequency
+ saturation. It determines the scaling between the matching
+ term frequency and final score. k1 is in the range of [1.2,
+ 3]. The default value is 1.2.
+
+ This field is a member of `oneof`_ ``_k1``.
+ b (float):
+ Optional. The parameter to control document length
+ normalization. It determines how much the document length
+ affects the final score. b is in the range of [0, 1]. The
+ default value is 0.75.
+
+ This field is a member of `oneof`_ ``_b``.
+ """
+
+ multilingual: bool = proto.Field(
+ proto.BOOL,
+ number=1,
+ )
+ k1: float = proto.Field(
+ proto.FLOAT,
+ number=2,
+ optional=True,
+ )
+ b: float = proto.Field(
+ proto.FLOAT,
+ number=3,
+ optional=True,
+ )
+
+ bm25: "RagEmbeddingModelConfig.SparseEmbeddingConfig.Bm25" = proto.Field(
+ proto.MESSAGE,
+ number=1,
+ oneof="model",
+ message="RagEmbeddingModelConfig.SparseEmbeddingConfig.Bm25",
+ )
+
+ class HybridSearchConfig(proto.Message):
+ r"""Config for hybrid search.
+
+ Attributes:
+ sparse_embedding_config (google.cloud.aiplatform_v1beta1.types.RagEmbeddingModelConfig.SparseEmbeddingConfig):
+ Optional. The configuration for sparse
+ embedding generation. This field is optional the
+ default behavior depends on the vector database
+ choice on the RagCorpus.
+ dense_embedding_model_prediction_endpoint (google.cloud.aiplatform_v1beta1.types.RagEmbeddingModelConfig.VertexPredictionEndpoint):
+ Required. The Vertex AI Prediction Endpoint
+ that hosts the embedding model for dense
+ embedding generations.
+ """
+
+ sparse_embedding_config: "RagEmbeddingModelConfig.SparseEmbeddingConfig" = (
+ proto.Field(
+ proto.MESSAGE,
+ number=1,
+ message="RagEmbeddingModelConfig.SparseEmbeddingConfig",
+ )
+ )
+ dense_embedding_model_prediction_endpoint: "RagEmbeddingModelConfig.VertexPredictionEndpoint" = proto.Field(
+ proto.MESSAGE,
+ number=2,
+ message="RagEmbeddingModelConfig.VertexPredictionEndpoint",
+ )
+
vertex_prediction_endpoint: VertexPredictionEndpoint = proto.Field(
proto.MESSAGE,
number=1,
oneof="model_config",
message=VertexPredictionEndpoint,
)
+ hybrid_search_config: HybridSearchConfig = proto.Field(
+ proto.MESSAGE,
+ number=2,
+ oneof="model_config",
+ message=HybridSearchConfig,
+ )
+
+
+class RagVectorDbConfig(proto.Message):
+ r"""Config for the Vector DB to use for RAG.
+
+ This message has `oneof`_ fields (mutually exclusive fields).
+ For each oneof, at most one member field can be set at the same time.
+ Setting any member of the oneof automatically clears all other
+ members.
+
+ .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
+
+ Attributes:
+ rag_managed_db (google.cloud.aiplatform_v1beta1.types.RagVectorDbConfig.RagManagedDb):
+ The config for the RAG-managed Vector DB.
+
+ This field is a member of `oneof`_ ``vector_db``.
+ weaviate (google.cloud.aiplatform_v1beta1.types.RagVectorDbConfig.Weaviate):
+ The config for the Weaviate.
+
+ This field is a member of `oneof`_ ``vector_db``.
+ vertex_feature_store (google.cloud.aiplatform_v1beta1.types.RagVectorDbConfig.VertexFeatureStore):
+ The config for the Vertex Feature Store.
+
+ This field is a member of `oneof`_ ``vector_db``.
+ api_auth (google.cloud.aiplatform_v1beta1.types.ApiAuth):
+ Authentication config for the chosen Vector
+ DB.
+ """
+
+ class RagManagedDb(proto.Message):
+ r"""The config for the default RAG-managed Vector DB."""
+
+ class Weaviate(proto.Message):
+ r"""The config for the Weaviate.
+
+ Attributes:
+ http_endpoint (str):
+ Weaviate DB instance HTTP endpoint. e.g.
+ 34.56.78.90:8080 Vertex RAG only supports HTTP
+ connection to Weaviate. This value cannot be
+ changed after it's set.
+ collection_name (str):
+ The corresponding collection this corpus maps
+ to. This value cannot be changed after it's set.
+ """
+
+ http_endpoint: str = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ collection_name: str = proto.Field(
+ proto.STRING,
+ number=2,
+ )
+
+ class VertexFeatureStore(proto.Message):
+ r"""The config for the Vertex Feature Store.
+
+ Attributes:
+ feature_view_resource_name (str):
+ The resource name of the FeatureView. Format:
+ ``projects/{project}/locations/{location}/featureOnlineStores/{feature_online_store}/featureViews/{feature_view}``
+ """
+
+ feature_view_resource_name: str = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+
+ rag_managed_db: RagManagedDb = proto.Field(
+ proto.MESSAGE,
+ number=1,
+ oneof="vector_db",
+ message=RagManagedDb,
+ )
+ weaviate: Weaviate = proto.Field(
+ proto.MESSAGE,
+ number=2,
+ oneof="vector_db",
+ message=Weaviate,
+ )
+ vertex_feature_store: VertexFeatureStore = proto.Field(
+ proto.MESSAGE,
+ number=4,
+ oneof="vector_db",
+ message=VertexFeatureStore,
+ )
+ api_auth: gca_api_auth.ApiAuth = proto.Field(
+ proto.MESSAGE,
+ number=5,
+ message=gca_api_auth.ApiAuth,
+ )
+
+
+class FileStatus(proto.Message):
+ r"""RagFile status.
+
+ Attributes:
+ state (google.cloud.aiplatform_v1beta1.types.FileStatus.State):
+ Output only. RagFile state.
+ error_status (str):
+ Output only. Only when the ``state`` field is ERROR.
+ """
+
+ class State(proto.Enum):
+ r"""RagFile state.
+
+ Values:
+ STATE_UNSPECIFIED (0):
+ RagFile state is unspecified.
+ ACTIVE (1):
+ RagFile resource has been created and indexed
+ successfully.
+ ERROR (2):
+ RagFile resource is in a problematic state. See
+ ``error_message`` field for details.
+ """
+ STATE_UNSPECIFIED = 0
+ ACTIVE = 1
+ ERROR = 2
+
+ state: State = proto.Field(
+ proto.ENUM,
+ number=1,
+ enum=State,
+ )
+ error_status: str = proto.Field(
+ proto.STRING,
+ number=2,
+ )
+
+
+class CorpusStatus(proto.Message):
+ r"""RagCorpus status.
+
+ Attributes:
+ state (google.cloud.aiplatform_v1beta1.types.CorpusStatus.State):
+ Output only. RagCorpus life state.
+ error_status (str):
+ Output only. Only when the ``state`` field is ERROR.
+ """
+
+ class State(proto.Enum):
+ r"""RagCorpus life state.
+
+ Values:
+ UNKNOWN (0):
+ This state is not supposed to happen.
+ INITIALIZED (1):
+ RagCorpus resource entry is initialized, but
+ hasn't done validation.
+ ACTIVE (2):
+ RagCorpus is provisioned successfully and is
+ ready to serve.
+ ERROR (3):
+ RagCorpus is in a problematic situation. See
+ ``error_message`` field for details.
+ """
+ UNKNOWN = 0
+ INITIALIZED = 1
+ ACTIVE = 2
+ ERROR = 3
+
+ state: State = proto.Field(
+ proto.ENUM,
+ number=1,
+ enum=State,
+ )
+ error_status: str = proto.Field(
+ proto.STRING,
+ number=2,
+ )
class RagCorpus(proto.Message):
@@ -112,12 +392,17 @@ class RagCorpus(proto.Message):
rag_embedding_model_config (google.cloud.aiplatform_v1beta1.types.RagEmbeddingModelConfig):
Optional. Immutable. The embedding model
config of the RagCorpus.
+ rag_vector_db_config (google.cloud.aiplatform_v1beta1.types.RagVectorDbConfig):
+ Optional. Immutable. The Vector DB config of
+ the RagCorpus.
create_time (google.protobuf.timestamp_pb2.Timestamp):
Output only. Timestamp when this RagCorpus
was created.
update_time (google.protobuf.timestamp_pb2.Timestamp):
Output only. Timestamp when this RagCorpus
was last updated.
+ corpus_status (google.cloud.aiplatform_v1beta1.types.CorpusStatus):
+ Output only. RagCorpus state.
"""
name: str = proto.Field(
@@ -137,6 +422,11 @@ class RagCorpus(proto.Message):
number=6,
message="RagEmbeddingModelConfig",
)
+ rag_vector_db_config: "RagVectorDbConfig" = proto.Field(
+ proto.MESSAGE,
+ number=7,
+ message="RagVectorDbConfig",
+ )
create_time: timestamp_pb2.Timestamp = proto.Field(
proto.MESSAGE,
number=4,
@@ -147,6 +437,11 @@ class RagCorpus(proto.Message):
number=5,
message=timestamp_pb2.Timestamp,
)
+ corpus_status: "CorpusStatus" = proto.Field(
+ proto.MESSAGE,
+ number=8,
+ message="CorpusStatus",
+ )
class RagFile(proto.Message):
@@ -206,6 +501,8 @@ class RagFile(proto.Message):
update_time (google.protobuf.timestamp_pb2.Timestamp):
Output only. Timestamp when this RagFile was
last updated.
+ file_status (google.cloud.aiplatform_v1beta1.types.FileStatus):
+ Output only. State of the RagFile.
"""
class RagFileType(proto.Enum):
@@ -284,6 +581,11 @@ class RagFileType(proto.Enum):
number=7,
message=timestamp_pb2.Timestamp,
)
+ file_status: "FileStatus" = proto.Field(
+ proto.MESSAGE,
+ number=13,
+ message="FileStatus",
+ )
class RagFileChunkingConfig(proto.Message):
diff --git a/google/cloud/aiplatform_v1beta1/types/vertex_rag_data_service.py b/google/cloud/aiplatform_v1beta1/types/vertex_rag_data_service.py
index 25e828db4a..94f0621ba4 100644
--- a/google/cloud/aiplatform_v1beta1/types/vertex_rag_data_service.py
+++ b/google/cloud/aiplatform_v1beta1/types/vertex_rag_data_service.py
@@ -41,6 +41,8 @@
"ListRagFilesResponse",
"DeleteRagFileRequest",
"CreateRagCorpusOperationMetadata",
+ "UpdateRagCorpusRequest",
+ "UpdateRagCorpusOperationMetadata",
"ImportRagFilesOperationMetadata",
},
)
@@ -412,6 +414,39 @@ class CreateRagCorpusOperationMetadata(proto.Message):
)
+class UpdateRagCorpusRequest(proto.Message):
+ r"""Request message for
+ [VertexRagDataService.UpdateRagCorpus][google.cloud.aiplatform.v1beta1.VertexRagDataService.UpdateRagCorpus].
+
+ Attributes:
+ rag_corpus (google.cloud.aiplatform_v1beta1.types.RagCorpus):
+ Required. The RagCorpus which replaces the
+ resource on the server.
+ """
+
+ rag_corpus: vertex_rag_data.RagCorpus = proto.Field(
+ proto.MESSAGE,
+ number=1,
+ message=vertex_rag_data.RagCorpus,
+ )
+
+
+class UpdateRagCorpusOperationMetadata(proto.Message):
+ r"""Runtime operation information for
+ [VertexRagDataService.UpdateRagCorpus][google.cloud.aiplatform.v1beta1.VertexRagDataService.UpdateRagCorpus].
+
+ Attributes:
+ generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata):
+ The operation generic information.
+ """
+
+ generic_metadata: operation.GenericOperationMetadata = proto.Field(
+ proto.MESSAGE,
+ number=1,
+ message=operation.GenericOperationMetadata,
+ )
+
+
class ImportRagFilesOperationMetadata(proto.Message):
r"""Runtime operation information for
[VertexRagDataService.ImportRagFiles][google.cloud.aiplatform.v1beta1.VertexRagDataService.ImportRagFiles].
diff --git a/google/cloud/aiplatform_v1beta1/types/vertex_rag_service.py b/google/cloud/aiplatform_v1beta1/types/vertex_rag_service.py
index e33f30e95c..2cbb95c76c 100644
--- a/google/cloud/aiplatform_v1beta1/types/vertex_rag_service.py
+++ b/google/cloud/aiplatform_v1beta1/types/vertex_rag_service.py
@@ -44,8 +44,33 @@ class RagQuery(proto.Message):
This field is a member of `oneof`_ ``query``.
similarity_top_k (int):
Optional. The number of contexts to retrieve.
+ ranking (google.cloud.aiplatform_v1beta1.types.RagQuery.Ranking):
+ Optional. Configurations for hybrid search
+ results ranking.
"""
+ class Ranking(proto.Message):
+ r"""Configurations for hybrid search results ranking.
+
+ .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
+
+ Attributes:
+ alpha (float):
+ Optional. Alpha value controls the weight between dense and
+ sparse vector search results. The range is [0, 1], while 0
+ means sparse vector search only and 1 means dense vector
+ search only. The default value is 0.5 which balances sparse
+ and dense vector search equally.
+
+ This field is a member of `oneof`_ ``_alpha``.
+ """
+
+ alpha: float = proto.Field(
+ proto.FLOAT,
+ number=1,
+ optional=True,
+ )
+
text: str = proto.Field(
proto.STRING,
number=1,
@@ -55,6 +80,11 @@ class RagQuery(proto.Message):
proto.INT32,
number=2,
)
+ ranking: Ranking = proto.Field(
+ proto.MESSAGE,
+ number=4,
+ message=Ranking,
+ )
class RetrieveContextsRequest(proto.Message):
@@ -175,8 +205,11 @@ class Context(proto.Message):
text (str):
The text chunk.
distance (float):
- The distance between the query vector and the
- context text vector.
+ The distance between the query dense
+ embedding vector and the context text vector.
+ sparse_distance (float):
+ The distance between the query sparse
+ embedding vector and the context text vector.
"""
source_uri: str = proto.Field(
@@ -191,6 +224,10 @@ class Context(proto.Message):
proto.DOUBLE,
number=3,
)
+ sparse_distance: float = proto.Field(
+ proto.DOUBLE,
+ number=4,
+ )
contexts: MutableSequence[Context] = proto.RepeatedField(
proto.MESSAGE,
diff --git a/owlbot.py b/owlbot.py
index ea1a2b9795..c8a1afe6ec 100644
--- a/owlbot.py
+++ b/owlbot.py
@@ -112,6 +112,7 @@
".kokoro/continuous/prerelease-deps.cfg",
".kokoro/presubmit/prerelease-deps.cfg",
".kokoro/docs/docs-presubmit.cfg",
+ ".kokoro/build.sh",
".kokoro/release.sh",
".kokoro/release/common.cfg",
".kokoro/requirements*",
diff --git a/pypi/_vertex_ai_placeholder/version.py b/pypi/_vertex_ai_placeholder/version.py
index dc62baac78..7979098357 100644
--- a/pypi/_vertex_ai_placeholder/version.py
+++ b/pypi/_vertex_ai_placeholder/version.py
@@ -15,4 +15,4 @@
# limitations under the License.
#
-__version__ = "1.64.0"
+__version__ = "1.65.0"
diff --git a/samples/generated_samples/aiplatform_v1beta1_generated_vertex_rag_data_service_update_rag_corpus_async.py b/samples/generated_samples/aiplatform_v1beta1_generated_vertex_rag_data_service_update_rag_corpus_async.py
new file mode 100644
index 0000000000..55bde0360a
--- /dev/null
+++ b/samples/generated_samples/aiplatform_v1beta1_generated_vertex_rag_data_service_update_rag_corpus_async.py
@@ -0,0 +1,59 @@
+# -*- coding: utf-8 -*-
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Generated code. DO NOT EDIT!
+#
+# Snippet for UpdateRagCorpus
+# NOTE: This snippet has been automatically generated for illustrative purposes only.
+# It may require modifications to work in your environment.
+
+# To install the latest published package dependency, execute the following:
+# python3 -m pip install google-cloud-aiplatform
+
+
+# [START aiplatform_v1beta1_generated_VertexRagDataService_UpdateRagCorpus_async]
+# This snippet has been automatically generated and should be regarded as a
+# code template only.
+# It will require modifications to work:
+# - It may require correct/in-range values for request initialization.
+# - It may require specifying regional endpoints when creating the service
+# client as shown in:
+# https://googleapis.dev/python/google-api-core/latest/client_options.html
+from google.cloud import aiplatform_v1beta1
+
+
+async def sample_update_rag_corpus():
+ # Create a client
+ client = aiplatform_v1beta1.VertexRagDataServiceAsyncClient()
+
+ # Initialize request argument(s)
+ rag_corpus = aiplatform_v1beta1.RagCorpus()
+ rag_corpus.display_name = "display_name_value"
+
+ request = aiplatform_v1beta1.UpdateRagCorpusRequest(
+ rag_corpus=rag_corpus,
+ )
+
+ # Make the request
+ operation = client.update_rag_corpus(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = (await operation).result()
+
+ # Handle the response
+ print(response)
+
+# [END aiplatform_v1beta1_generated_VertexRagDataService_UpdateRagCorpus_async]
diff --git a/samples/generated_samples/aiplatform_v1beta1_generated_vertex_rag_data_service_update_rag_corpus_sync.py b/samples/generated_samples/aiplatform_v1beta1_generated_vertex_rag_data_service_update_rag_corpus_sync.py
new file mode 100644
index 0000000000..81f3e7fb57
--- /dev/null
+++ b/samples/generated_samples/aiplatform_v1beta1_generated_vertex_rag_data_service_update_rag_corpus_sync.py
@@ -0,0 +1,59 @@
+# -*- coding: utf-8 -*-
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Generated code. DO NOT EDIT!
+#
+# Snippet for UpdateRagCorpus
+# NOTE: This snippet has been automatically generated for illustrative purposes only.
+# It may require modifications to work in your environment.
+
+# To install the latest published package dependency, execute the following:
+# python3 -m pip install google-cloud-aiplatform
+
+
+# [START aiplatform_v1beta1_generated_VertexRagDataService_UpdateRagCorpus_sync]
+# This snippet has been automatically generated and should be regarded as a
+# code template only.
+# It will require modifications to work:
+# - It may require correct/in-range values for request initialization.
+# - It may require specifying regional endpoints when creating the service
+# client as shown in:
+# https://googleapis.dev/python/google-api-core/latest/client_options.html
+from google.cloud import aiplatform_v1beta1
+
+
+def sample_update_rag_corpus():
+ # Create a client
+ client = aiplatform_v1beta1.VertexRagDataServiceClient()
+
+ # Initialize request argument(s)
+ rag_corpus = aiplatform_v1beta1.RagCorpus()
+ rag_corpus.display_name = "display_name_value"
+
+ request = aiplatform_v1beta1.UpdateRagCorpusRequest(
+ rag_corpus=rag_corpus,
+ )
+
+ # Make the request
+ operation = client.update_rag_corpus(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = operation.result()
+
+ # Handle the response
+ print(response)
+
+# [END aiplatform_v1beta1_generated_VertexRagDataService_UpdateRagCorpus_sync]
diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json
index c4b8e9f3fc..e824509826 100644
--- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json
+++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json
@@ -8,7 +8,7 @@
],
"language": "PYTHON",
"name": "google-cloud-aiplatform",
- "version": "1.64.0"
+ "version": "1.65.0"
},
"snippets": [
{
diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json
index 86cee3f831..c2a656078a 100644
--- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json
+++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json
@@ -8,7 +8,7 @@
],
"language": "PYTHON",
"name": "google-cloud-aiplatform",
- "version": "1.64.0"
+ "version": "1.65.0"
},
"snippets": [
{
@@ -52847,6 +52847,167 @@
],
"title": "aiplatform_v1beta1_generated_vertex_rag_data_service_list_rag_files_sync.py"
},
+ {
+ "canonical": true,
+ "clientMethod": {
+ "async": true,
+ "client": {
+ "fullName": "google.cloud.aiplatform_v1beta1.VertexRagDataServiceAsyncClient",
+ "shortName": "VertexRagDataServiceAsyncClient"
+ },
+ "fullName": "google.cloud.aiplatform_v1beta1.VertexRagDataServiceAsyncClient.update_rag_corpus",
+ "method": {
+ "fullName": "google.cloud.aiplatform.v1beta1.VertexRagDataService.UpdateRagCorpus",
+ "service": {
+ "fullName": "google.cloud.aiplatform.v1beta1.VertexRagDataService",
+ "shortName": "VertexRagDataService"
+ },
+ "shortName": "UpdateRagCorpus"
+ },
+ "parameters": [
+ {
+ "name": "request",
+ "type": "google.cloud.aiplatform_v1beta1.types.UpdateRagCorpusRequest"
+ },
+ {
+ "name": "rag_corpus",
+ "type": "google.cloud.aiplatform_v1beta1.types.RagCorpus"
+ },
+ {
+ "name": "retry",
+ "type": "google.api_core.retry.Retry"
+ },
+ {
+ "name": "timeout",
+ "type": "float"
+ },
+ {
+ "name": "metadata",
+ "type": "Sequence[Tuple[str, str]"
+ }
+ ],
+ "resultType": "google.api_core.operation_async.AsyncOperation",
+ "shortName": "update_rag_corpus"
+ },
+ "description": "Sample for UpdateRagCorpus",
+ "file": "aiplatform_v1beta1_generated_vertex_rag_data_service_update_rag_corpus_async.py",
+ "language": "PYTHON",
+ "origin": "API_DEFINITION",
+ "regionTag": "aiplatform_v1beta1_generated_VertexRagDataService_UpdateRagCorpus_async",
+ "segments": [
+ {
+ "end": 58,
+ "start": 27,
+ "type": "FULL"
+ },
+ {
+ "end": 58,
+ "start": 27,
+ "type": "SHORT"
+ },
+ {
+ "end": 40,
+ "start": 38,
+ "type": "CLIENT_INITIALIZATION"
+ },
+ {
+ "end": 48,
+ "start": 41,
+ "type": "REQUEST_INITIALIZATION"
+ },
+ {
+ "end": 55,
+ "start": 49,
+ "type": "REQUEST_EXECUTION"
+ },
+ {
+ "end": 59,
+ "start": 56,
+ "type": "RESPONSE_HANDLING"
+ }
+ ],
+ "title": "aiplatform_v1beta1_generated_vertex_rag_data_service_update_rag_corpus_async.py"
+ },
+ {
+ "canonical": true,
+ "clientMethod": {
+ "client": {
+ "fullName": "google.cloud.aiplatform_v1beta1.VertexRagDataServiceClient",
+ "shortName": "VertexRagDataServiceClient"
+ },
+ "fullName": "google.cloud.aiplatform_v1beta1.VertexRagDataServiceClient.update_rag_corpus",
+ "method": {
+ "fullName": "google.cloud.aiplatform.v1beta1.VertexRagDataService.UpdateRagCorpus",
+ "service": {
+ "fullName": "google.cloud.aiplatform.v1beta1.VertexRagDataService",
+ "shortName": "VertexRagDataService"
+ },
+ "shortName": "UpdateRagCorpus"
+ },
+ "parameters": [
+ {
+ "name": "request",
+ "type": "google.cloud.aiplatform_v1beta1.types.UpdateRagCorpusRequest"
+ },
+ {
+ "name": "rag_corpus",
+ "type": "google.cloud.aiplatform_v1beta1.types.RagCorpus"
+ },
+ {
+ "name": "retry",
+ "type": "google.api_core.retry.Retry"
+ },
+ {
+ "name": "timeout",
+ "type": "float"
+ },
+ {
+ "name": "metadata",
+ "type": "Sequence[Tuple[str, str]"
+ }
+ ],
+ "resultType": "google.api_core.operation.Operation",
+ "shortName": "update_rag_corpus"
+ },
+ "description": "Sample for UpdateRagCorpus",
+ "file": "aiplatform_v1beta1_generated_vertex_rag_data_service_update_rag_corpus_sync.py",
+ "language": "PYTHON",
+ "origin": "API_DEFINITION",
+ "regionTag": "aiplatform_v1beta1_generated_VertexRagDataService_UpdateRagCorpus_sync",
+ "segments": [
+ {
+ "end": 58,
+ "start": 27,
+ "type": "FULL"
+ },
+ {
+ "end": 58,
+ "start": 27,
+ "type": "SHORT"
+ },
+ {
+ "end": 40,
+ "start": 38,
+ "type": "CLIENT_INITIALIZATION"
+ },
+ {
+ "end": 48,
+ "start": 41,
+ "type": "REQUEST_INITIALIZATION"
+ },
+ {
+ "end": 55,
+ "start": 49,
+ "type": "REQUEST_EXECUTION"
+ },
+ {
+ "end": 59,
+ "start": 56,
+ "type": "RESPONSE_HANDLING"
+ }
+ ],
+ "title": "aiplatform_v1beta1_generated_vertex_rag_data_service_update_rag_corpus_sync.py"
+ },
{
"canonical": true,
"clientMethod": {
diff --git a/setup.py b/setup.py
index 0786ec7215..95b1590948 100644
--- a/setup.py
+++ b/setup.py
@@ -98,10 +98,7 @@
autologging_extra_require = ["mlflow>=1.27.0,<=2.1.1"]
-preview_extra_require = [
- "cloudpickle < 3.0",
- "google-cloud-logging < 4.0",
-]
+preview_extra_require = []
ray_extra_require = [
# Cluster only supports 2.9.3. Keep 2.4.0 for our testing environment.
@@ -152,7 +149,6 @@
evaluation_extra_require = [
"pandas >= 1.0.0, < 2.2.0",
"tqdm>=4.23.0",
- "immutabledict",
]
langchain_extra_require = [
@@ -205,7 +201,6 @@
"grpcio-testing",
"ipython",
"kfp >= 2.6.0, < 3.0.0",
- "pyfakefs",
"pytest-asyncio",
"pytest-xdist",
"scikit-learn",
diff --git a/tests/system/aiplatform/test_vision_models.py b/tests/system/aiplatform/test_vision_models.py
index 33e5720d37..16929846f9 100644
--- a/tests/system/aiplatform/test_vision_models.py
+++ b/tests/system/aiplatform/test_vision_models.py
@@ -375,3 +375,63 @@ def test_image_verification_model_verify_image(self):
image_with_watermark
)
assert image_verification_response["decision"] == "ACCEPT"
+
+ def test_image_upscaling_model_upscale_image(self):
+ """Tests the image upscaling model upscaling an image."""
+ arbitrary_width = 1111
+ arbitrary_height = 2000
+ target_max_size = 4096
+ image_generation_model = vision_models.ImageGenerationModel.from_pretrained(
+ "imagen-3.0-generate-001"
+ )
+ blank_image = _create_blank_image(
+ width=arbitrary_width, height=arbitrary_height
+ )
+ assert blank_image.size == (arbitrary_width, arbitrary_height)
+ upscaled_image = image_generation_model.upscale_image(
+ blank_image, new_size=target_max_size
+ )
+ new_size = (
+ int(arbitrary_width / arbitrary_height * target_max_size),
+ target_max_size,
+ )
+ assert upscaled_image._size == new_size
+
+ def test_image_upscaling_model_upscaling_factor(self):
+ """Tests the image upscaling model upscaling an image."""
+ arbitrary_width = 1111
+ arbitrary_height = 2222
+ target_upscaling_factor = "x2"
+ image_generation_model = vision_models.ImageGenerationModel.from_pretrained(
+ "imagen-3.0-generate-001"
+ )
+ blank_image = _create_blank_image(
+ width=arbitrary_width, height=arbitrary_height
+ )
+ assert blank_image.size == (arbitrary_width, arbitrary_height)
+ upscaled_image = image_generation_model.upscale_image(
+ blank_image, upscale_factor=target_upscaling_factor
+ )
+ new_size = (arbitrary_width * 2, arbitrary_height * 2)
+ assert upscaled_image._size == new_size
+
+ def test_image_upscaling_model_save_as_jpeg(self):
+ """Tests the image upscaling model upscaling an image."""
+ arbitrary_width = 1111
+ arbitrary_height = 2222
+ target_upscaling_factor = "x2"
+ image_generation_model = vision_models.ImageGenerationModel.from_pretrained(
+ "imagen-3.0-generate-001"
+ )
+ blank_image = _create_blank_image(
+ width=arbitrary_width, height=arbitrary_height
+ )
+ assert blank_image.size == (arbitrary_width, arbitrary_height)
+ upscaled_image = image_generation_model.upscale_image(
+ blank_image,
+ upscale_factor=target_upscaling_factor,
+ output_mime_type="image/jpeg",
+ output_compression_quality=90,
+ )
+ assert upscaled_image._mime_type == "image/jpeg"
+ assert upscaled_image._size == (arbitrary_width * 2, arbitrary_height * 2)
diff --git a/tests/system/vertexai/test_generative_models.py b/tests/system/vertexai/test_generative_models.py
index 39fd0a4b99..e972225fe4 100644
--- a/tests/system/vertexai/test_generative_models.py
+++ b/tests/system/vertexai/test_generative_models.py
@@ -505,3 +505,36 @@ def test_compute_tokens_from_text(self, api_endpoint_env_name):
assert token_info.role
# Lightly validate that the tokens are not Base64 encoded
assert b"=" not in token_info.tokens
+
+ def test_count_tokens_from_text(self):
+ plain_model = generative_models.GenerativeModel(GEMINI_MODEL_NAME)
+ model = generative_models.GenerativeModel(
+ GEMINI_MODEL_NAME, system_instruction=["You are a chatbot."]
+ )
+ get_current_weather_func = generative_models.FunctionDeclaration.from_func(
+ get_current_weather
+ )
+ weather_tool = generative_models.Tool(
+ function_declarations=[get_current_weather_func],
+ )
+ content = ["Why is sky blue?", "Explain it like I'm 5."]
+
+ response_without_si = plain_model.count_tokens(content)
+ response_with_si = model.count_tokens(content)
+ response_with_si_and_tool = model.count_tokens(
+ content,
+ tools=[weather_tool],
+ )
+
+ # system instruction + user prompt
+ assert response_with_si.total_tokens > response_without_si.total_tokens
+ assert (
+ response_with_si.total_billable_characters
+ > response_without_si.total_billable_characters
+ )
+ # system instruction + user prompt + tool
+ assert response_with_si_and_tool.total_tokens > response_with_si.total_tokens
+ assert (
+ response_with_si_and_tool.total_billable_characters
+ > response_with_si.total_billable_characters
+ )
diff --git a/tests/system/vertexai/test_tokenization.py b/tests/system/vertexai/test_tokenization.py
index 8234aeff76..2795831a95 100644
--- a/tests/system/vertexai/test_tokenization.py
+++ b/tests/system/vertexai/test_tokenization.py
@@ -20,11 +20,24 @@
from nltk.corpus import udhr
from google.cloud import aiplatform
from vertexai.preview.tokenization import (
- get_tokenizer_for_model,
+ get_tokenizer_for_model as tokenizer_preview,
+)
+from vertexai.tokenization._tokenizers import (
+ get_tokenizer_for_model as tokenizer_ga,
+)
+from vertexai.generative_models import (
+ GenerativeModel,
+ Part,
+ Tool,
)
-from vertexai.generative_models import GenerativeModel
from tests.system.aiplatform import e2e_base
from google import auth
+from google.cloud.aiplatform_v1beta1.types import (
+ content as gapic_content_types,
+ tool as gapic_tool_types,
+ openapi,
+)
+from google.protobuf import struct_pb2
_MODELS = ["gemini-1.0-pro", "gemini-1.5-pro", "gemini-1.5-flash"]
@@ -34,11 +47,58 @@
_CORPUS_LIB = [
udhr,
]
+_VERSIONED_TOKENIZER = [tokenizer_preview, tokenizer_ga]
_MODEL_CORPUS_PARAMS = [
- (model_name, corpus_name, corpus_lib)
+ (get_tokenizer_for_model, model_name, corpus_name, corpus_lib)
+ for get_tokenizer_for_model in _VERSIONED_TOKENIZER
for model_name in _MODELS
for (corpus_name, corpus_lib) in zip(_CORPUS, _CORPUS_LIB)
]
+_STRUCT = struct_pb2.Struct(
+ fields={
+ "string_key": struct_pb2.Value(string_value="value"),
+ }
+)
+_FUNCTION_CALL = gapic_tool_types.FunctionCall(name="test_function_call", args=_STRUCT)
+_FUNCTION_RESPONSE = gapic_tool_types.FunctionResponse(
+ name="function_response",
+ response=_STRUCT,
+)
+
+
+_SCHEMA_1 = openapi.Schema(format="schema1_format", description="schema1_description")
+_SCHEMA_2 = openapi.Schema(format="schema2_format", description="schema2_description")
+_EXAMPLE = struct_pb2.Value(string_value="value1")
+
+_FUNCTION_DECLARATION_1 = gapic_tool_types.FunctionDeclaration(
+ name="function_declaration_name",
+ description="function_declaration_description",
+ parameters=openapi.Schema(
+ format="schema_format",
+ description="schema_description",
+ enum=["schema_enum1", "schema_enum2"],
+ required=["schema_required1", "schema_required2"],
+ items=_SCHEMA_2,
+ properties={"property_key": _SCHEMA_1},
+ example=_EXAMPLE,
+ ),
+)
+_FUNCTION_DECLARATION_2 = gapic_tool_types.FunctionDeclaration(
+ parameters=openapi.Schema(
+ nullable=True,
+ default=struct_pb2.Value(string_value="value1"),
+ min_items=0,
+ max_items=0,
+ min_properties=0,
+ max_properties=0,
+ minimum=0,
+ maximum=0,
+ min_length=0,
+ max_length=0,
+ pattern="pattern",
+ ),
+ response=_SCHEMA_1,
+)
STAGING_API_ENDPOINT = "STAGING_ENDPOINT"
PROD_API_ENDPOINT = "PROD_ENDPOINT"
@@ -70,11 +130,16 @@ def setup_method(self, api_endpoint_env_name):
)
@pytest.mark.parametrize(
- "model_name, corpus_name, corpus_lib",
+ "get_tokenizer_for_model, model_name, corpus_name, corpus_lib",
_MODEL_CORPUS_PARAMS,
)
def test_count_tokens_local(
- self, model_name, corpus_name, corpus_lib, api_endpoint_env_name
+ self,
+ get_tokenizer_for_model,
+ model_name,
+ corpus_name,
+ corpus_lib,
+ api_endpoint_env_name,
):
# The Gemini 1.5 flash model requires the model version
# number suffix (001) in staging only
@@ -90,11 +155,16 @@ def test_count_tokens_local(
assert service_result.total_tokens == local_result.total_tokens
@pytest.mark.parametrize(
- "model_name, corpus_name, corpus_lib",
+ "get_tokenizer_for_model, model_name, corpus_name, corpus_lib",
_MODEL_CORPUS_PARAMS,
)
def test_compute_tokens(
- self, model_name, corpus_name, corpus_lib, api_endpoint_env_name
+ self,
+ get_tokenizer_for_model,
+ model_name,
+ corpus_name,
+ corpus_lib,
+ api_endpoint_env_name,
):
# The Gemini 1.5 flash model requires the model version
# number suffix (001) in staging only
@@ -107,8 +177,107 @@ def test_compute_tokens(
text = corpus_lib.raw(book)
response = model.compute_tokens(text)
local_result = tokenizer.compute_tokens(text)
- for local, service in zip(
- local_result.token_info_list, response.tokens_info
- ):
+ for local, service in zip(local_result.tokens_info, response.tokens_info):
assert local.tokens == service.tokens
assert local.token_ids == service.token_ids
+
+ @pytest.mark.parametrize(
+ "model_name",
+ _MODELS,
+ )
+ def test_count_tokens_system_instruction(self, model_name):
+ tokenizer = tokenizer_preview(model_name)
+ model = GenerativeModel(model_name, system_instruction=["You are a chatbot."])
+
+ assert (
+ tokenizer.count_tokens(
+ "hello", system_instruction=["You are a chatbot."]
+ ).total_tokens
+ == model.count_tokens("hello").total_tokens
+ )
+
+ @pytest.mark.parametrize(
+ "model_name",
+ _MODELS,
+ )
+ def test_count_tokens_system_instruction_is_function_call(self, model_name):
+ part = Part._from_gapic(gapic_content_types.Part(function_call=_FUNCTION_CALL))
+
+ tokenizer = tokenizer_preview(model_name)
+ model = GenerativeModel(model_name, system_instruction=[part])
+
+ assert (
+ tokenizer.count_tokens("hello", system_instruction=[part]).total_tokens
+ == model.count_tokens("hello").total_tokens
+ )
+
+ @pytest.mark.parametrize(
+ "model_name",
+ _MODELS,
+ )
+ def test_count_tokens_system_instruction_is_function_response(self, model_name):
+ part = Part._from_gapic(
+ gapic_content_types.Part(function_response=_FUNCTION_RESPONSE)
+ )
+ tokenizer = tokenizer_preview(model_name)
+ model = GenerativeModel(model_name, system_instruction=[part])
+
+ assert tokenizer.count_tokens(part, system_instruction=[part]).total_tokens
+ assert (
+ tokenizer.count_tokens("hello", system_instruction=[part]).total_tokens
+ == model.count_tokens("hello").total_tokens
+ )
+
+ @pytest.mark.parametrize(
+ "model_name",
+ _MODELS,
+ )
+ def test_count_tokens_tool_is_function_declaration(self, model_name):
+ tokenizer = tokenizer_preview(model_name)
+ model = GenerativeModel(model_name)
+ tool1 = Tool._from_gapic(
+ gapic_tool_types.Tool(function_declarations=[_FUNCTION_DECLARATION_1])
+ )
+ tool2 = Tool._from_gapic(
+ gapic_tool_types.Tool(function_declarations=[_FUNCTION_DECLARATION_2])
+ )
+
+ assert tokenizer.count_tokens("hello", tools=[tool1]).total_tokens
+ with pytest.raises(ValueError):
+ tokenizer.count_tokens("hello", tools=[tool2]).total_tokens
+ assert (
+ tokenizer.count_tokens("hello", tools=[tool1]).total_tokens
+ == model.count_tokens("hello", tools=[tool1]).total_tokens
+ )
+
+ @pytest.mark.parametrize(
+ "model_name",
+ _MODELS,
+ )
+ def test_count_tokens_content_is_function_call(self, model_name):
+ part = Part._from_gapic(gapic_content_types.Part(function_call=_FUNCTION_CALL))
+ tokenizer = tokenizer_preview(model_name)
+ model = GenerativeModel(model_name)
+
+ assert tokenizer.count_tokens(part).total_tokens
+ assert (
+ tokenizer.count_tokens(part).total_tokens
+ == model.count_tokens(part).total_tokens
+ )
+
+ @pytest.mark.parametrize(
+ "model_name",
+ _MODELS,
+ )
+ def test_count_tokens_content_is_function_response(self, model_name):
+ part = Part._from_gapic(
+ gapic_content_types.Part(function_response=_FUNCTION_RESPONSE)
+ )
+ tokenizer = tokenizer_preview(model_name)
+ model = GenerativeModel(model_name)
+
+ assert tokenizer.count_tokens(part).total_tokens
+ assert (
+ tokenizer.count_tokens(part).total_tokens
+ == model.count_tokens(part).total_tokens
+ )
diff --git a/tests/unit/aiplatform/test_endpoints.py b/tests/unit/aiplatform/test_endpoints.py
index a13de3158d..25553d8ad7 100644
--- a/tests/unit/aiplatform/test_endpoints.py
+++ b/tests/unit/aiplatform/test_endpoints.py
@@ -18,8 +18,8 @@
import copy
from datetime import datetime, timedelta
from importlib import reload
-import requests
import json
+import requests
from unittest import mock
from google.api_core import operation as ga_operation
@@ -920,6 +920,49 @@ def predict_private_endpoint_mock():
yield predict_mock
+@pytest.fixture
+def stream_raw_predict_private_endpoint_mock():
+ with mock.patch.object(
+ google_auth_requests.AuthorizedSession, "post"
+ ) as stream_raw_predict_mock:
+ # Create a mock response object
+ mock_response = mock.Mock(spec=requests.Response)
+
+ # Configure the mock to be used as a context manager
+ stream_raw_predict_mock.return_value.__enter__.return_value = mock_response
+
+ # Set the status code to 200 (OK)
+ mock_response.status_code = 200
+
+ # Simulate streaming data with iter_lines
+ mock_response.iter_lines = mock.Mock(
+ return_value=iter(
+ [
+ json.dumps(
+ {
+ "predictions": [1.0, 2.0, 3.0],
+ "metadata": {"key": "value"},
+ "deployedModelId": "model-id-123",
+ "model": "model-name",
+ "modelVersionId": "1",
+ }
+ ).encode("utf-8"),
+ json.dumps(
+ {
+ "predictions": [4.0, 5.0, 6.0],
+ "metadata": {"key": "value"},
+ "deployedModelId": "model-id-123",
+ "model": "model-name",
+ "modelVersionId": "1",
+ }
+ ).encode("utf-8"),
+ ]
+ )
+ )
+
+ yield stream_raw_predict_mock
+
+
@pytest.fixture
def health_check_private_endpoint_mock():
with mock.patch.object(urllib3.PoolManager, "request") as health_check_mock:
@@ -3195,6 +3238,57 @@ def test_psc_predict(self, predict_private_endpoint_mock):
},
)
+ @pytest.mark.usefixtures("get_psc_private_endpoint_mock")
+ def test_psc_stream_raw_predict(self, stream_raw_predict_private_endpoint_mock):
+ test_endpoint = models.PrivateEndpoint(
+ project=_TEST_PROJECT, location=_TEST_LOCATION, endpoint_name=_TEST_ID
+ )
+
+ test_prediction_iterator = test_endpoint.stream_raw_predict(
+ body='{"instances": [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]]}',
+ headers={
+ "Content-Type": "application/json",
+ "Authorization": "Bearer None",
+ },
+ endpoint_override=_TEST_ENDPOINT_OVERRIDE,
+ )
+
+ test_prediction = list(test_prediction_iterator)
+
+ stream_raw_predict_private_endpoint_mock.assert_called_once_with(
+ url=f"https://{_TEST_ENDPOINT_OVERRIDE}/v1/projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}:streamRawPredict",
+ data='{"instances": [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]]}',
+ headers={
+ "Content-Type": "application/json",
+ "Authorization": "Bearer None",
+ },
+ stream=True,
+ verify=False,
+ )
+
+ # Validate the content of the returned predictions
+ expected_predictions = [
+ json.dumps(
+ {
+ "predictions": [1.0, 2.0, 3.0],
+ "metadata": {"key": "value"},
+ "deployedModelId": "model-id-123",
+ "model": "model-name",
+ "modelVersionId": "1",
+ }
+ ).encode("utf-8"),
+ json.dumps(
+ {
+ "predictions": [4.0, 5.0, 6.0],
+ "metadata": {"key": "value"},
+ "deployedModelId": "model-id-123",
+ "model": "model-name",
+ "modelVersionId": "1",
+ }
+ ).encode("utf-8"),
+ ]
+ assert test_prediction == expected_predictions
+
@pytest.mark.usefixtures("get_psc_private_endpoint_mock")
def test_psc_predict_without_endpoint_override(self):
test_endpoint = models.PrivateEndpoint(
diff --git a/tests/unit/aiplatform/test_logdir_loader.py b/tests/unit/aiplatform/test_logdir_loader.py
index 90c87464a2..648f39022d 100644
--- a/tests/unit/aiplatform/test_logdir_loader.py
+++ b/tests/unit/aiplatform/test_logdir_loader.py
@@ -201,15 +201,15 @@ def test_multiple_writes_to_logdir(self):
{
"a": ["tag_a"],
"b": ["tag_b"],
- "b-x": ["tag_b_x"],
- "b-z": ["tag_b_z"],
+ "b/x": ["tag_b_x"],
+ "b_z": ["tag_b_z"],
"c": ["tag_c"],
},
)
# A second load should indicate no new data.
self.assertEqual(
self._extract_run_to_tags(loader.get_run_events()),
- {"a": [], "b": [], "b-x": [], "b-z": [], "c": []},
+ {"a": [], "b": [], "b/x": [], "b_z": [], "c": []},
)
# Write some new data to both new and pre-existing event files.
with FileWriter(os.path.join(logdir, "a"), filename_suffix=".other") as writer:
@@ -228,8 +228,8 @@ def test_multiple_writes_to_logdir(self):
{
"a": ["tag_a_2", "tag_a_3", "tag_a_4"],
"b": [],
- "b-x": ["tag_b_x_2"],
- "b-z": [],
+ "b/x": ["tag_b_x_2"],
+ "b_z": [],
"c": ["tag_c_2"],
},
)
diff --git a/tests/unit/aiplatform/test_vision_models.py b/tests/unit/aiplatform/test_vision_models.py
index 0323d0f5bb..c2fcf578c0 100644
--- a/tests/unit/aiplatform/test_vision_models.py
+++ b/tests/unit/aiplatform/test_vision_models.py
@@ -22,7 +22,7 @@
import io
import os
import tempfile
-from typing import Any, Dict
+from typing import Any, Dict, Literal, Optional, Tuple
import unittest
from unittest import mock
@@ -39,12 +39,10 @@
from google.cloud.aiplatform.compat.types import (
publisher_model as gca_publisher_model,
)
-
from vertexai import vision_models as ga_vision_models
from vertexai.preview import (
vision_models as preview_vision_models,
)
-
from PIL import Image as PIL_Image
import pytest
@@ -131,20 +129,39 @@ def make_image_generation_response_gcs(count: int = 1) -> Dict[str, Any]:
for _ in range(count):
predictions.append(
{
- "gcsUri": (
- "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
- ),
+ "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png",
"mimeType": "image/png",
}
)
return {"predictions": predictions}
-def make_image_upscale_response(upscale_size: int) -> Dict[str, Any]:
- predictions = {
- "bytesBase64Encoded": make_image_base64(upscale_size, upscale_size),
- "mimeType": "image/png",
- }
+def make_image_upscale_response(
+ upscale_size: Optional[int] = None,
+ upscale_factor: Optional[Literal["x2", "x4"]] = None,
+ input_size: Optional[Tuple[int, int]] = None,
+) -> Dict[str, Any]:
+ predictions = {}
+ if upscale_size:
+ predictions = {
+ "bytesBase64Encoded": make_image_base64(upscale_size, upscale_size),
+ "mimeType": "image/png",
+ }
+ else:
+ if upscale_factor == "x2":
+ predictions = {
+ "bytesBase64Encoded": make_image_base64(
+ input_size[0] * 2, input_size[1] * 2
+ ),
+ "mimeType": "image/png",
+ }
+ else:
+ predictions = {
+ "bytesBase64Encoded": make_image_base64(
+ input_size[0] * 4, input_size[1] * 4
+ ),
+ "mimeType": "image/png",
+ }
return {"predictions": [predictions]}
@@ -723,7 +740,7 @@ def test_upscale_image_on_provided_image(self):
model = self._get_image_generation_model()
image_generation_response = make_image_generation_response(
- count=1, height=1024, width=1024
+ count=1, height=1111, width=2000
)
gca_generation_response = gca_prediction_service.PredictResponse()
gca_generation_response.predictions.extend(
@@ -739,7 +756,7 @@ def test_upscale_image_on_provided_image(self):
attribute="predict",
return_value=gca_upscale_response,
) as mock_upscale:
- test_image = generate_image_from_file(height=1024, width=1024)
+ test_image = generate_image_from_file(height=1111, width=2000)
upscaled_image = model.upscale_image(image=test_image, new_size=4096)
@@ -748,24 +765,107 @@ def test_upscale_image_on_provided_image(self):
assert actual_instance["image"]["bytesBase64Encoded"]
image_upscale_parameters = predict_kwargs["parameters"]
- assert (
- image_upscale_parameters["sampleImageSize"]
- == str(upscaled_image._size[0])
- == str(upscaled_image.generation_parameters["upscaled_image_size"])
+ assert image_upscale_parameters["sampleImageSize"] == str(
+ upscaled_image.generation_parameters["upscaled_image_size"]
)
assert image_upscale_parameters["mode"] == "upscale"
assert upscaled_image._image_bytes
assert isinstance(upscaled_image, preview_vision_models.GeneratedImage)
- def test_upscale_image_raises_if_not_1024x1024(self):
+ def test_upscale_image_on_wrong_image_size(self):
+ """Tests that upscaling raises ValueError for an unsupported image size."""
+ model = self._get_image_generation_model()
+ image_generation_response = make_image_generation_response(
+ count=1, height=1111, width=2222
+ )
+ gca_generation_response = gca_prediction_service.PredictResponse()
+ gca_generation_response.predictions.extend(
+ image_generation_response["predictions"]
+ )
+
+ image_upscale_response = make_image_upscale_response(upscale_size=4096)
+ gca_upscale_response = gca_prediction_service.PredictResponse()
+ gca_upscale_response.predictions.extend(image_upscale_response["predictions"])
+
+ with mock.patch.object(
+ target=prediction_service_client.PredictionServiceClient,
+ attribute="predict",
+ return_value=gca_upscale_response,
+ ):
+ test_image = generate_image_from_file(height=1111, width=2222)
+ with pytest.raises(ValueError):
+ model.upscale_image(image=test_image, new_size=4096)
+
+ def test_upscale_image_with_upscale_factor(self):
+ """Tests image upscaling using an upscale factor instead of a target size."""
+ model = self._get_image_generation_model()
+ image_generation_response = make_image_generation_response(
+ count=1, height=1111, width=2222
+ )
+ gca_generation_response = gca_prediction_service.PredictResponse()
+ gca_generation_response.predictions.extend(
+ image_generation_response["predictions"]
+ )
+
+ image_upscale_response = make_image_upscale_response(
+ upscale_factor="x2", input_size=(1111, 2222)
+ )
+ gca_upscale_response = gca_prediction_service.PredictResponse()
+ gca_upscale_response.predictions.extend(image_upscale_response["predictions"])
+ with mock.patch.object(
+ target=prediction_service_client.PredictionServiceClient,
+ attribute="predict",
+ return_value=gca_upscale_response,
+ ) as mock_upscale:
+ test_image = generate_image_from_file(height=1111, width=2222)
+
+ upscaled_image = model.upscale_image(image=test_image, upscale_factor="x2")
+
+ predict_kwargs = mock_upscale.call_args[1]
+ actual_instance = predict_kwargs["instances"][0]
+ assert actual_instance["image"]["bytesBase64Encoded"]
+
+ image_upscale_parameters = predict_kwargs["parameters"]
+ assert image_upscale_parameters["mode"] == "upscale"
+ assert "sampleImageSize" not in image_upscale_parameters
+
+ assert upscaled_image._image_bytes
+ assert isinstance(upscaled_image, preview_vision_models.GeneratedImage)
+
+ def test_upscale_image_with_jpeg_output(self):
"""Tests image upscaling on generated images."""
model = self._get_image_generation_model()
+ image_generation_response = make_image_generation_response(
+ count=1, height=1111, width=2222
+ )
+ gca_generation_response = gca_prediction_service.PredictResponse()
+ gca_generation_response.predictions.extend(
+ image_generation_response["predictions"]
+ )
+ with mock.patch.object(
+ target=prediction_service_client.PredictionServiceClient,
+ attribute="predict",
+ return_value=gca_generation_response,
+ ) as mock_upscale:
+ test_image = generate_image_from_file(height=1111, width=2222)
+ upscaled_image = model.upscale_image(
+ image=test_image,
+ upscale_factor="x2",
+ output_mime_type="image/jpeg",
+ output_compression_quality=90,
+ )
- test_image = generate_image_from_file(height=100, width=100)
+ predict_kwargs = mock_upscale.call_args[1]
+ actual_instance = predict_kwargs["instances"][0]
+ assert actual_instance["image"]["bytesBase64Encoded"]
- with pytest.raises(ValueError):
- model.upscale_image(image=test_image)
+ image_upscale_parameters = predict_kwargs["parameters"]
+ assert image_upscale_parameters["mode"] == "upscale"
+ assert image_upscale_parameters["outputOptions"]["mimeType"] == "image/jpeg"
+ assert image_upscale_parameters["outputOptions"]["compressionQuality"] == 90
+ assert upscaled_image._image_bytes
+ assert isinstance(upscaled_image, preview_vision_models.GeneratedImage)
@pytest.mark.usefixtures("google_auth_mock")
diff --git a/tests/unit/gapic/aiplatform_v1/test_job_service.py b/tests/unit/gapic/aiplatform_v1/test_job_service.py
index 3d77876553..2deb4c5275 100644
--- a/tests/unit/gapic/aiplatform_v1/test_job_service.py
+++ b/tests/unit/gapic/aiplatform_v1/test_job_service.py
@@ -16816,6 +16816,7 @@ def test_create_custom_job_rest(request_type):
"restart_job_on_worker_restart": True,
"strategy": 1,
"disable_retries": True,
+ "max_wait_duration": {},
},
"service_account": "service_account_value",
"network": "network_value",
@@ -20425,6 +20426,7 @@ def test_create_hyperparameter_tuning_job_rest(request_type):
"restart_job_on_worker_restart": True,
"strategy": 1,
"disable_retries": True,
+ "max_wait_duration": {},
},
"service_account": "service_account_value",
"network": "network_value",
@@ -22322,6 +22324,7 @@ def test_create_nas_job_rest(request_type):
"restart_job_on_worker_restart": True,
"strategy": 1,
"disable_retries": True,
+ "max_wait_duration": {},
},
"service_account": "service_account_value",
"network": "network_value",
diff --git a/tests/unit/gapic/aiplatform_v1/test_schedule_service.py b/tests/unit/gapic/aiplatform_v1/test_schedule_service.py
index e0e59c7aca..439956e90c 100644
--- a/tests/unit/gapic/aiplatform_v1/test_schedule_service.py
+++ b/tests/unit/gapic/aiplatform_v1/test_schedule_service.py
@@ -60,6 +60,9 @@
from google.cloud.aiplatform_v1.types import context
from google.cloud.aiplatform_v1.types import encryption_spec
from google.cloud.aiplatform_v1.types import execution
+from google.cloud.aiplatform_v1.types import job_state
+from google.cloud.aiplatform_v1.types import notebook_execution_job
+from google.cloud.aiplatform_v1.types import notebook_service
from google.cloud.aiplatform_v1.types import operation as gca_operation
from google.cloud.aiplatform_v1.types import pipeline_failure_policy
from google.cloud.aiplatform_v1.types import pipeline_job
@@ -76,6 +79,7 @@
from google.longrunning import operations_pb2 # type: ignore
from google.oauth2 import service_account
from google.protobuf import any_pb2 # type: ignore
+from google.protobuf import duration_pb2 # type: ignore
from google.protobuf import empty_pb2 # type: ignore
from google.protobuf import field_mask_pb2 # type: ignore
from google.protobuf import struct_pb2 # type: ignore
@@ -4158,6 +4162,35 @@ def test_create_schedule_rest(request_type):
},
"pipeline_job_id": "pipeline_job_id_value",
},
+ "create_notebook_execution_job_request": {
+ "parent": "parent_value",
+ "notebook_execution_job": {
+ "dataform_repository_source": {
+ "dataform_repository_resource_name": "dataform_repository_resource_name_value",
+ "commit_sha": "commit_sha_value",
+ },
+ "gcs_notebook_source": {
+ "uri": "uri_value",
+ "generation": "generation_value",
+ },
+ "direct_notebook_source": {"content": b"content_blob"},
+ "notebook_runtime_template_resource_name": "notebook_runtime_template_resource_name_value",
+ "gcs_output_uri": "gcs_output_uri_value",
+ "execution_user": "execution_user_value",
+ "service_account": "service_account_value",
+ "name": "name_value",
+ "display_name": "display_name_value",
+ "execution_timeout": {"seconds": 751, "nanos": 543},
+ "schedule_resource_name": "schedule_resource_name_value",
+ "job_state": 1,
+ "status": {},
+ "create_time": {},
+ "update_time": {},
+ "labels": {},
+ "encryption_spec": {},
+ },
+ "notebook_execution_job_id": "notebook_execution_job_id_value",
+ },
"name": "name_value",
"display_name": "display_name_value",
"start_time": {},
@@ -6263,6 +6296,35 @@ def test_update_schedule_rest(request_type):
},
"pipeline_job_id": "pipeline_job_id_value",
},
+ "create_notebook_execution_job_request": {
+ "parent": "parent_value",
+ "notebook_execution_job": {
+ "dataform_repository_source": {
+ "dataform_repository_resource_name": "dataform_repository_resource_name_value",
+ "commit_sha": "commit_sha_value",
+ },
+ "gcs_notebook_source": {
+ "uri": "uri_value",
+ "generation": "generation_value",
+ },
+ "direct_notebook_source": {"content": b"content_blob"},
+ "notebook_runtime_template_resource_name": "notebook_runtime_template_resource_name_value",
+ "gcs_output_uri": "gcs_output_uri_value",
+ "execution_user": "execution_user_value",
+ "service_account": "service_account_value",
+ "name": "name_value",
+ "display_name": "display_name_value",
+ "execution_timeout": {"seconds": 751, "nanos": 543},
+ "schedule_resource_name": "schedule_resource_name_value",
+ "job_state": 1,
+ "status": {},
+ "create_time": {},
+ "update_time": {},
+ "labels": {},
+ "encryption_spec": {},
+ },
+ "notebook_execution_job_id": "notebook_execution_job_id_value",
+ },
"name": "projects/sample1/locations/sample2/schedules/sample3",
"display_name": "display_name_value",
"start_time": {},
@@ -7426,6 +7488,62 @@ def test_parse_network_path():
assert expected == actual
+def test_notebook_execution_job_path():
+ project = "scallop"
+ location = "abalone"
+ notebook_execution_job = "squid"
+ expected = "projects/{project}/locations/{location}/notebookExecutionJobs/{notebook_execution_job}".format(
+ project=project,
+ location=location,
+ notebook_execution_job=notebook_execution_job,
+ )
+ actual = ScheduleServiceClient.notebook_execution_job_path(
+ project, location, notebook_execution_job
+ )
+ assert expected == actual
+
+
+def test_parse_notebook_execution_job_path():
+ expected = {
+ "project": "clam",
+ "location": "whelk",
+ "notebook_execution_job": "octopus",
+ }
+ path = ScheduleServiceClient.notebook_execution_job_path(**expected)
+
+ # Check that the path construction is reversible.
+ actual = ScheduleServiceClient.parse_notebook_execution_job_path(path)
+ assert expected == actual
+
+
+def test_notebook_runtime_template_path():
+ project = "oyster"
+ location = "nudibranch"
+ notebook_runtime_template = "cuttlefish"
+ expected = "projects/{project}/locations/{location}/notebookRuntimeTemplates/{notebook_runtime_template}".format(
+ project=project,
+ location=location,
+ notebook_runtime_template=notebook_runtime_template,
+ )
+ actual = ScheduleServiceClient.notebook_runtime_template_path(
+ project, location, notebook_runtime_template
+ )
+ assert expected == actual
+
+
+def test_parse_notebook_runtime_template_path():
+ expected = {
+ "project": "mussel",
+ "location": "winkle",
+ "notebook_runtime_template": "nautilus",
+ }
+ path = ScheduleServiceClient.notebook_runtime_template_path(**expected)
+
+ # Check that the path construction is reversible.
+ actual = ScheduleServiceClient.parse_notebook_runtime_template_path(path)
+ assert expected == actual
+
+
def test_pipeline_job_path():
project = "scallop"
location = "abalone"
diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py
index 17d35098b2..6924ccaa6c 100644
--- a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py
+++ b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py
@@ -16819,6 +16819,7 @@ def test_create_custom_job_rest(request_type):
"restart_job_on_worker_restart": True,
"strategy": 1,
"disable_retries": True,
+ "max_wait_duration": {},
},
"service_account": "service_account_value",
"network": "network_value",
@@ -20446,6 +20447,7 @@ def test_create_hyperparameter_tuning_job_rest(request_type):
"restart_job_on_worker_restart": True,
"strategy": 1,
"disable_retries": True,
+ "max_wait_duration": {},
},
"service_account": "service_account_value",
"network": "network_value",
@@ -22343,6 +22345,7 @@ def test_create_nas_job_rest(request_type):
"restart_job_on_worker_restart": True,
"strategy": 1,
"disable_retries": True,
+ "max_wait_duration": {},
},
"service_account": "service_account_value",
"network": "network_value",
diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py
index 07cd3e052c..a891b9f491 100644
--- a/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py
+++ b/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py
@@ -1609,11 +1609,11 @@ async def test_create_rag_corpus_flattened_error_async():
@pytest.mark.parametrize(
"request_type",
[
- vertex_rag_data_service.GetRagCorpusRequest,
+ vertex_rag_data_service.UpdateRagCorpusRequest,
dict,
],
)
-def test_get_rag_corpus(request_type, transport: str = "grpc"):
+def test_update_rag_corpus(request_type, transport: str = "grpc"):
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
transport=transport,
@@ -1624,29 +1624,24 @@ def test_get_rag_corpus(request_type, transport: str = "grpc"):
request = request_type()
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call:
+ with mock.patch.object(
+ type(client.transport.update_rag_corpus), "__call__"
+ ) as call:
# Designate an appropriate return value for the call.
- call.return_value = vertex_rag_data.RagCorpus(
- name="name_value",
- display_name="display_name_value",
- description="description_value",
- )
- response = client.get_rag_corpus(request)
+ call.return_value = operations_pb2.Operation(name="operations/spam")
+ response = client.update_rag_corpus(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls) == 1
_, args, _ = call.mock_calls[0]
- request = vertex_rag_data_service.GetRagCorpusRequest()
+ request = vertex_rag_data_service.UpdateRagCorpusRequest()
assert args[0] == request
# Establish that the response is the type that we expect.
- assert isinstance(response, vertex_rag_data.RagCorpus)
- assert response.name == "name_value"
- assert response.display_name == "display_name_value"
- assert response.description == "description_value"
+ assert isinstance(response, future.Future)
-def test_get_rag_corpus_empty_call():
+def test_update_rag_corpus_empty_call():
# This test is a coverage failsafe to make sure that totally empty calls,
# i.e. request == None and no flattened fields passed, work.
client = VertexRagDataServiceClient(
@@ -1655,17 +1650,19 @@ def test_get_rag_corpus_empty_call():
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call:
+ with mock.patch.object(
+ type(client.transport.update_rag_corpus), "__call__"
+ ) as call:
call.return_value.name = (
"foo" # operation_request.operation in compute client(s) expect a string.
)
- client.get_rag_corpus()
+ client.update_rag_corpus()
call.assert_called()
_, args, _ = call.mock_calls[0]
- assert args[0] == vertex_rag_data_service.GetRagCorpusRequest()
+ assert args[0] == vertex_rag_data_service.UpdateRagCorpusRequest()
-def test_get_rag_corpus_non_empty_request_with_auto_populated_field():
+def test_update_rag_corpus_non_empty_request_with_auto_populated_field():
# This test is a coverage failsafe to make sure that UUID4 fields are
# automatically populated, according to AIP-4235, with non-empty requests.
client = VertexRagDataServiceClient(
@@ -1676,24 +1673,22 @@ def test_get_rag_corpus_non_empty_request_with_auto_populated_field():
# Populate all string fields in the request which are not UUID4
# since we want to check that UUID4 are populated automatically
# if they meet the requirements of AIP 4235.
- request = vertex_rag_data_service.GetRagCorpusRequest(
- name="name_value",
- )
+ request = vertex_rag_data_service.UpdateRagCorpusRequest()
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call:
+ with mock.patch.object(
+ type(client.transport.update_rag_corpus), "__call__"
+ ) as call:
call.return_value.name = (
"foo" # operation_request.operation in compute client(s) expect a string.
)
- client.get_rag_corpus(request=request)
+ client.update_rag_corpus(request=request)
call.assert_called()
_, args, _ = call.mock_calls[0]
- assert args[0] == vertex_rag_data_service.GetRagCorpusRequest(
- name="name_value",
- )
+ assert args[0] == vertex_rag_data_service.UpdateRagCorpusRequest()
-def test_get_rag_corpus_use_cached_wrapped_rpc():
+def test_update_rag_corpus_use_cached_wrapped_rpc():
# Clients should use _prep_wrapped_messages to create cached wrapped rpcs,
# instead of constructing them on each call
with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn:
@@ -1707,21 +1702,28 @@ def test_get_rag_corpus_use_cached_wrapped_rpc():
wrapper_fn.reset_mock()
# Ensure method has been cached
- assert client._transport.get_rag_corpus in client._transport._wrapped_methods
+ assert client._transport.update_rag_corpus in client._transport._wrapped_methods
# Replace cached wrapped function with mock
mock_rpc = mock.Mock()
mock_rpc.return_value.name = (
"foo" # operation_request.operation in compute client(s) expect a string.
)
- client._transport._wrapped_methods[client._transport.get_rag_corpus] = mock_rpc
+ client._transport._wrapped_methods[
+ client._transport.update_rag_corpus
+ ] = mock_rpc
request = {}
- client.get_rag_corpus(request)
+ client.update_rag_corpus(request)
# Establish that the underlying gRPC stub method was called.
assert mock_rpc.call_count == 1
- client.get_rag_corpus(request)
+ # Operation methods call wrapper_fn to build a cached
+ # client._transport.operations_client instance on first rpc call.
+ # Subsequent calls should use the cached wrapper
+ wrapper_fn.reset_mock()
+
+ client.update_rag_corpus(request)
# Establish that a new wrapper was not created for this call
assert wrapper_fn.call_count == 0
@@ -1729,7 +1731,7 @@ def test_get_rag_corpus_use_cached_wrapped_rpc():
@pytest.mark.asyncio
-async def test_get_rag_corpus_empty_call_async():
+async def test_update_rag_corpus_empty_call_async():
# This test is a coverage failsafe to make sure that totally empty calls,
# i.e. request == None and no flattened fields passed, work.
client = VertexRagDataServiceAsyncClient(
@@ -1738,23 +1740,21 @@ async def test_get_rag_corpus_empty_call_async():
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call:
+ with mock.patch.object(
+ type(client.transport.update_rag_corpus), "__call__"
+ ) as call:
# Designate an appropriate return value for the call.
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- vertex_rag_data.RagCorpus(
- name="name_value",
- display_name="display_name_value",
- description="description_value",
- )
+ operations_pb2.Operation(name="operations/spam")
)
- response = await client.get_rag_corpus()
+ response = await client.update_rag_corpus()
call.assert_called()
_, args, _ = call.mock_calls[0]
- assert args[0] == vertex_rag_data_service.GetRagCorpusRequest()
+ assert args[0] == vertex_rag_data_service.UpdateRagCorpusRequest()
@pytest.mark.asyncio
-async def test_get_rag_corpus_async_use_cached_wrapped_rpc(
+async def test_update_rag_corpus_async_use_cached_wrapped_rpc(
transport: str = "grpc_asyncio",
):
# Clients should use _prep_wrapped_messages to create cached wrapped rpcs,
@@ -1771,7 +1771,7 @@ async def test_get_rag_corpus_async_use_cached_wrapped_rpc(
# Ensure method has been cached
assert (
- client._client._transport.get_rag_corpus
+ client._client._transport.update_rag_corpus
in client._client._transport._wrapped_methods
)
@@ -1779,16 +1779,21 @@ async def test_get_rag_corpus_async_use_cached_wrapped_rpc(
mock_rpc = mock.AsyncMock()
mock_rpc.return_value = mock.Mock()
client._client._transport._wrapped_methods[
- client._client._transport.get_rag_corpus
+ client._client._transport.update_rag_corpus
] = mock_rpc
request = {}
- await client.get_rag_corpus(request)
+ await client.update_rag_corpus(request)
# Establish that the underlying gRPC stub method was called.
assert mock_rpc.call_count == 1
- await client.get_rag_corpus(request)
+ # Operation methods call wrapper_fn to build a cached
+ # client._transport.operations_client instance on first rpc call.
+ # Subsequent calls should use the cached wrapper
+ wrapper_fn.reset_mock()
+
+ await client.update_rag_corpus(request)
# Establish that a new wrapper was not created for this call
assert wrapper_fn.call_count == 0
@@ -1796,9 +1801,9 @@ async def test_get_rag_corpus_async_use_cached_wrapped_rpc(
@pytest.mark.asyncio
-async def test_get_rag_corpus_async(
+async def test_update_rag_corpus_async(
transport: str = "grpc_asyncio",
- request_type=vertex_rag_data_service.GetRagCorpusRequest,
+ request_type=vertex_rag_data_service.UpdateRagCorpusRequest,
):
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
@@ -1810,50 +1815,47 @@ async def test_get_rag_corpus_async(
request = request_type()
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call:
+ with mock.patch.object(
+ type(client.transport.update_rag_corpus), "__call__"
+ ) as call:
# Designate an appropriate return value for the call.
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- vertex_rag_data.RagCorpus(
- name="name_value",
- display_name="display_name_value",
- description="description_value",
- )
+ operations_pb2.Operation(name="operations/spam")
)
- response = await client.get_rag_corpus(request)
+ response = await client.update_rag_corpus(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls)
_, args, _ = call.mock_calls[0]
- request = vertex_rag_data_service.GetRagCorpusRequest()
+ request = vertex_rag_data_service.UpdateRagCorpusRequest()
assert args[0] == request
# Establish that the response is the type that we expect.
- assert isinstance(response, vertex_rag_data.RagCorpus)
- assert response.name == "name_value"
- assert response.display_name == "display_name_value"
- assert response.description == "description_value"
+ assert isinstance(response, future.Future)
@pytest.mark.asyncio
-async def test_get_rag_corpus_async_from_dict():
- await test_get_rag_corpus_async(request_type=dict)
+async def test_update_rag_corpus_async_from_dict():
+ await test_update_rag_corpus_async(request_type=dict)
-def test_get_rag_corpus_field_headers():
+def test_update_rag_corpus_field_headers():
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Any value that is part of the HTTP/1.1 URI should be sent as
# a field header. Set these to a non-empty value.
- request = vertex_rag_data_service.GetRagCorpusRequest()
+ request = vertex_rag_data_service.UpdateRagCorpusRequest()
- request.name = "name_value"
+ request.rag_corpus.name = "name_value"
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call:
- call.return_value = vertex_rag_data.RagCorpus()
- client.get_rag_corpus(request)
+ with mock.patch.object(
+ type(client.transport.update_rag_corpus), "__call__"
+ ) as call:
+ call.return_value = operations_pb2.Operation(name="operations/op")
+ client.update_rag_corpus(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls) == 1
@@ -1864,28 +1866,30 @@ def test_get_rag_corpus_field_headers():
_, _, kw = call.mock_calls[0]
assert (
"x-goog-request-params",
- "name=name_value",
+ "rag_corpus.name=name_value",
) in kw["metadata"]
@pytest.mark.asyncio
-async def test_get_rag_corpus_field_headers_async():
+async def test_update_rag_corpus_field_headers_async():
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Any value that is part of the HTTP/1.1 URI should be sent as
# a field header. Set these to a non-empty value.
- request = vertex_rag_data_service.GetRagCorpusRequest()
+ request = vertex_rag_data_service.UpdateRagCorpusRequest()
- request.name = "name_value"
+ request.rag_corpus.name = "name_value"
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call:
+ with mock.patch.object(
+ type(client.transport.update_rag_corpus), "__call__"
+ ) as call:
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- vertex_rag_data.RagCorpus()
+ operations_pb2.Operation(name="operations/op")
)
- await client.get_rag_corpus(request)
+ await client.update_rag_corpus(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls)
@@ -1896,35 +1900,37 @@ async def test_get_rag_corpus_field_headers_async():
_, _, kw = call.mock_calls[0]
assert (
"x-goog-request-params",
- "name=name_value",
+ "rag_corpus.name=name_value",
) in kw["metadata"]
-def test_get_rag_corpus_flattened():
+def test_update_rag_corpus_flattened():
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call:
+ with mock.patch.object(
+ type(client.transport.update_rag_corpus), "__call__"
+ ) as call:
# Designate an appropriate return value for the call.
- call.return_value = vertex_rag_data.RagCorpus()
+ call.return_value = operations_pb2.Operation(name="operations/op")
# Call the method with a truthy value for each flattened field,
# using the keyword arguments to the method.
- client.get_rag_corpus(
- name="name_value",
+ client.update_rag_corpus(
+ rag_corpus=vertex_rag_data.RagCorpus(name="name_value"),
)
# Establish that the underlying call was made with the expected
# request object values.
assert len(call.mock_calls) == 1
_, args, _ = call.mock_calls[0]
- arg = args[0].name
- mock_val = "name_value"
+ arg = args[0].rag_corpus
+ mock_val = vertex_rag_data.RagCorpus(name="name_value")
assert arg == mock_val
-def test_get_rag_corpus_flattened_error():
+def test_update_rag_corpus_flattened_error():
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
)
@@ -1932,43 +1938,45 @@ def test_get_rag_corpus_flattened_error():
# Attempting to call a method with both a request object and flattened
# fields is an error.
with pytest.raises(ValueError):
- client.get_rag_corpus(
- vertex_rag_data_service.GetRagCorpusRequest(),
- name="name_value",
+ client.update_rag_corpus(
+ vertex_rag_data_service.UpdateRagCorpusRequest(),
+ rag_corpus=vertex_rag_data.RagCorpus(name="name_value"),
)
@pytest.mark.asyncio
-async def test_get_rag_corpus_flattened_async():
+async def test_update_rag_corpus_flattened_async():
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call:
+ with mock.patch.object(
+ type(client.transport.update_rag_corpus), "__call__"
+ ) as call:
# Designate an appropriate return value for the call.
- call.return_value = vertex_rag_data.RagCorpus()
+ call.return_value = operations_pb2.Operation(name="operations/op")
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- vertex_rag_data.RagCorpus()
+ operations_pb2.Operation(name="operations/spam")
)
# Call the method with a truthy value for each flattened field,
# using the keyword arguments to the method.
- response = await client.get_rag_corpus(
- name="name_value",
+ response = await client.update_rag_corpus(
+ rag_corpus=vertex_rag_data.RagCorpus(name="name_value"),
)
# Establish that the underlying call was made with the expected
# request object values.
assert len(call.mock_calls)
_, args, _ = call.mock_calls[0]
- arg = args[0].name
- mock_val = "name_value"
+ arg = args[0].rag_corpus
+ mock_val = vertex_rag_data.RagCorpus(name="name_value")
assert arg == mock_val
@pytest.mark.asyncio
-async def test_get_rag_corpus_flattened_error_async():
+async def test_update_rag_corpus_flattened_error_async():
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
)
@@ -1976,20 +1984,20 @@ async def test_get_rag_corpus_flattened_error_async():
# Attempting to call a method with both a request object and flattened
# fields is an error.
with pytest.raises(ValueError):
- await client.get_rag_corpus(
- vertex_rag_data_service.GetRagCorpusRequest(),
- name="name_value",
+ await client.update_rag_corpus(
+ vertex_rag_data_service.UpdateRagCorpusRequest(),
+ rag_corpus=vertex_rag_data.RagCorpus(name="name_value"),
)
@pytest.mark.parametrize(
"request_type",
[
- vertex_rag_data_service.ListRagCorporaRequest,
+ vertex_rag_data_service.GetRagCorpusRequest,
dict,
],
)
-def test_list_rag_corpora(request_type, transport: str = "grpc"):
+def test_get_rag_corpus(request_type, transport: str = "grpc"):
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
transport=transport,
@@ -2000,25 +2008,29 @@ def test_list_rag_corpora(request_type, transport: str = "grpc"):
request = request_type()
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call:
+ with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call:
# Designate an appropriate return value for the call.
- call.return_value = vertex_rag_data_service.ListRagCorporaResponse(
- next_page_token="next_page_token_value",
+ call.return_value = vertex_rag_data.RagCorpus(
+ name="name_value",
+ display_name="display_name_value",
+ description="description_value",
)
- response = client.list_rag_corpora(request)
+ response = client.get_rag_corpus(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls) == 1
_, args, _ = call.mock_calls[0]
- request = vertex_rag_data_service.ListRagCorporaRequest()
+ request = vertex_rag_data_service.GetRagCorpusRequest()
assert args[0] == request
# Establish that the response is the type that we expect.
- assert isinstance(response, pagers.ListRagCorporaPager)
- assert response.next_page_token == "next_page_token_value"
+ assert isinstance(response, vertex_rag_data.RagCorpus)
+ assert response.name == "name_value"
+ assert response.display_name == "display_name_value"
+ assert response.description == "description_value"
-def test_list_rag_corpora_empty_call():
+def test_get_rag_corpus_empty_call():
# This test is a coverage failsafe to make sure that totally empty calls,
# i.e. request == None and no flattened fields passed, work.
client = VertexRagDataServiceClient(
@@ -2027,17 +2039,17 @@ def test_list_rag_corpora_empty_call():
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call:
+ with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call:
call.return_value.name = (
"foo" # operation_request.operation in compute client(s) expect a string.
)
- client.list_rag_corpora()
+ client.get_rag_corpus()
call.assert_called()
_, args, _ = call.mock_calls[0]
- assert args[0] == vertex_rag_data_service.ListRagCorporaRequest()
+ assert args[0] == vertex_rag_data_service.GetRagCorpusRequest()
-def test_list_rag_corpora_non_empty_request_with_auto_populated_field():
+def test_get_rag_corpus_non_empty_request_with_auto_populated_field():
# This test is a coverage failsafe to make sure that UUID4 fields are
# automatically populated, according to AIP-4235, with non-empty requests.
client = VertexRagDataServiceClient(
@@ -2048,26 +2060,24 @@ def test_list_rag_corpora_non_empty_request_with_auto_populated_field():
# Populate all string fields in the request which are not UUID4
# since we want to check that UUID4 are populated automatically
# if they meet the requirements of AIP 4235.
- request = vertex_rag_data_service.ListRagCorporaRequest(
- parent="parent_value",
- page_token="page_token_value",
+ request = vertex_rag_data_service.GetRagCorpusRequest(
+ name="name_value",
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call:
+ with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call:
call.return_value.name = (
"foo" # operation_request.operation in compute client(s) expect a string.
)
- client.list_rag_corpora(request=request)
+ client.get_rag_corpus(request=request)
call.assert_called()
_, args, _ = call.mock_calls[0]
- assert args[0] == vertex_rag_data_service.ListRagCorporaRequest(
- parent="parent_value",
- page_token="page_token_value",
+ assert args[0] == vertex_rag_data_service.GetRagCorpusRequest(
+ name="name_value",
)
-def test_list_rag_corpora_use_cached_wrapped_rpc():
+def test_get_rag_corpus_use_cached_wrapped_rpc():
# Clients should use _prep_wrapped_messages to create cached wrapped rpcs,
# instead of constructing them on each call
with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn:
@@ -2081,23 +2091,21 @@ def test_list_rag_corpora_use_cached_wrapped_rpc():
wrapper_fn.reset_mock()
# Ensure method has been cached
- assert client._transport.list_rag_corpora in client._transport._wrapped_methods
+ assert client._transport.get_rag_corpus in client._transport._wrapped_methods
# Replace cached wrapped function with mock
mock_rpc = mock.Mock()
mock_rpc.return_value.name = (
"foo" # operation_request.operation in compute client(s) expect a string.
)
- client._transport._wrapped_methods[
- client._transport.list_rag_corpora
- ] = mock_rpc
+ client._transport._wrapped_methods[client._transport.get_rag_corpus] = mock_rpc
request = {}
- client.list_rag_corpora(request)
+ client.get_rag_corpus(request)
# Establish that the underlying gRPC stub method was called.
assert mock_rpc.call_count == 1
- client.list_rag_corpora(request)
+ client.get_rag_corpus(request)
# Establish that a new wrapper was not created for this call
assert wrapper_fn.call_count == 0
@@ -2105,7 +2113,7 @@ def test_list_rag_corpora_use_cached_wrapped_rpc():
@pytest.mark.asyncio
-async def test_list_rag_corpora_empty_call_async():
+async def test_get_rag_corpus_empty_call_async():
# This test is a coverage failsafe to make sure that totally empty calls,
# i.e. request == None and no flattened fields passed, work.
client = VertexRagDataServiceAsyncClient(
@@ -2114,21 +2122,23 @@ async def test_list_rag_corpora_empty_call_async():
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call:
+ with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call:
# Designate an appropriate return value for the call.
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- vertex_rag_data_service.ListRagCorporaResponse(
- next_page_token="next_page_token_value",
+ vertex_rag_data.RagCorpus(
+ name="name_value",
+ display_name="display_name_value",
+ description="description_value",
)
)
- response = await client.list_rag_corpora()
+ response = await client.get_rag_corpus()
call.assert_called()
_, args, _ = call.mock_calls[0]
- assert args[0] == vertex_rag_data_service.ListRagCorporaRequest()
+ assert args[0] == vertex_rag_data_service.GetRagCorpusRequest()
@pytest.mark.asyncio
-async def test_list_rag_corpora_async_use_cached_wrapped_rpc(
+async def test_get_rag_corpus_async_use_cached_wrapped_rpc(
transport: str = "grpc_asyncio",
):
# Clients should use _prep_wrapped_messages to create cached wrapped rpcs,
@@ -2145,7 +2155,7 @@ async def test_list_rag_corpora_async_use_cached_wrapped_rpc(
# Ensure method has been cached
assert (
- client._client._transport.list_rag_corpora
+ client._client._transport.get_rag_corpus
in client._client._transport._wrapped_methods
)
@@ -2153,16 +2163,16 @@ async def test_list_rag_corpora_async_use_cached_wrapped_rpc(
mock_rpc = mock.AsyncMock()
mock_rpc.return_value = mock.Mock()
client._client._transport._wrapped_methods[
- client._client._transport.list_rag_corpora
+ client._client._transport.get_rag_corpus
] = mock_rpc
request = {}
- await client.list_rag_corpora(request)
+ await client.get_rag_corpus(request)
# Establish that the underlying gRPC stub method was called.
assert mock_rpc.call_count == 1
- await client.list_rag_corpora(request)
+ await client.get_rag_corpus(request)
# Establish that a new wrapper was not created for this call
assert wrapper_fn.call_count == 0
@@ -2170,9 +2180,9 @@ async def test_list_rag_corpora_async_use_cached_wrapped_rpc(
@pytest.mark.asyncio
-async def test_list_rag_corpora_async(
+async def test_get_rag_corpus_async(
transport: str = "grpc_asyncio",
- request_type=vertex_rag_data_service.ListRagCorporaRequest,
+ request_type=vertex_rag_data_service.GetRagCorpusRequest,
):
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
@@ -2184,46 +2194,50 @@ async def test_list_rag_corpora_async(
request = request_type()
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call:
+ with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call:
# Designate an appropriate return value for the call.
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- vertex_rag_data_service.ListRagCorporaResponse(
- next_page_token="next_page_token_value",
+ vertex_rag_data.RagCorpus(
+ name="name_value",
+ display_name="display_name_value",
+ description="description_value",
)
)
- response = await client.list_rag_corpora(request)
+ response = await client.get_rag_corpus(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls)
_, args, _ = call.mock_calls[0]
- request = vertex_rag_data_service.ListRagCorporaRequest()
+ request = vertex_rag_data_service.GetRagCorpusRequest()
assert args[0] == request
# Establish that the response is the type that we expect.
- assert isinstance(response, pagers.ListRagCorporaAsyncPager)
- assert response.next_page_token == "next_page_token_value"
+ assert isinstance(response, vertex_rag_data.RagCorpus)
+ assert response.name == "name_value"
+ assert response.display_name == "display_name_value"
+ assert response.description == "description_value"
@pytest.mark.asyncio
-async def test_list_rag_corpora_async_from_dict():
- await test_list_rag_corpora_async(request_type=dict)
+async def test_get_rag_corpus_async_from_dict():
+ await test_get_rag_corpus_async(request_type=dict)
-def test_list_rag_corpora_field_headers():
+def test_get_rag_corpus_field_headers():
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Any value that is part of the HTTP/1.1 URI should be sent as
# a field header. Set these to a non-empty value.
- request = vertex_rag_data_service.ListRagCorporaRequest()
+ request = vertex_rag_data_service.GetRagCorpusRequest()
- request.parent = "parent_value"
+ request.name = "name_value"
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call:
- call.return_value = vertex_rag_data_service.ListRagCorporaResponse()
- client.list_rag_corpora(request)
+ with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call:
+ call.return_value = vertex_rag_data.RagCorpus()
+ client.get_rag_corpus(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls) == 1
@@ -2234,28 +2248,28 @@ def test_list_rag_corpora_field_headers():
_, _, kw = call.mock_calls[0]
assert (
"x-goog-request-params",
- "parent=parent_value",
+ "name=name_value",
) in kw["metadata"]
@pytest.mark.asyncio
-async def test_list_rag_corpora_field_headers_async():
+async def test_get_rag_corpus_field_headers_async():
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Any value that is part of the HTTP/1.1 URI should be sent as
# a field header. Set these to a non-empty value.
- request = vertex_rag_data_service.ListRagCorporaRequest()
+ request = vertex_rag_data_service.GetRagCorpusRequest()
- request.parent = "parent_value"
+ request.name = "name_value"
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call:
+ with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call:
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- vertex_rag_data_service.ListRagCorporaResponse()
+ vertex_rag_data.RagCorpus()
)
- await client.list_rag_corpora(request)
+ await client.get_rag_corpus(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls)
@@ -2266,35 +2280,35 @@ async def test_list_rag_corpora_field_headers_async():
_, _, kw = call.mock_calls[0]
assert (
"x-goog-request-params",
- "parent=parent_value",
+ "name=name_value",
) in kw["metadata"]
-def test_list_rag_corpora_flattened():
+def test_get_rag_corpus_flattened():
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call:
+ with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call:
# Designate an appropriate return value for the call.
- call.return_value = vertex_rag_data_service.ListRagCorporaResponse()
+ call.return_value = vertex_rag_data.RagCorpus()
# Call the method with a truthy value for each flattened field,
# using the keyword arguments to the method.
- client.list_rag_corpora(
- parent="parent_value",
+ client.get_rag_corpus(
+ name="name_value",
)
# Establish that the underlying call was made with the expected
# request object values.
assert len(call.mock_calls) == 1
_, args, _ = call.mock_calls[0]
- arg = args[0].parent
- mock_val = "parent_value"
+ arg = args[0].name
+ mock_val = "name_value"
assert arg == mock_val
-def test_list_rag_corpora_flattened_error():
+def test_get_rag_corpus_flattened_error():
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
)
@@ -2302,43 +2316,43 @@ def test_list_rag_corpora_flattened_error():
# Attempting to call a method with both a request object and flattened
# fields is an error.
with pytest.raises(ValueError):
- client.list_rag_corpora(
- vertex_rag_data_service.ListRagCorporaRequest(),
- parent="parent_value",
+ client.get_rag_corpus(
+ vertex_rag_data_service.GetRagCorpusRequest(),
+ name="name_value",
)
@pytest.mark.asyncio
-async def test_list_rag_corpora_flattened_async():
+async def test_get_rag_corpus_flattened_async():
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call:
+ with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call:
# Designate an appropriate return value for the call.
- call.return_value = vertex_rag_data_service.ListRagCorporaResponse()
+ call.return_value = vertex_rag_data.RagCorpus()
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- vertex_rag_data_service.ListRagCorporaResponse()
+ vertex_rag_data.RagCorpus()
)
# Call the method with a truthy value for each flattened field,
# using the keyword arguments to the method.
- response = await client.list_rag_corpora(
- parent="parent_value",
+ response = await client.get_rag_corpus(
+ name="name_value",
)
# Establish that the underlying call was made with the expected
# request object values.
assert len(call.mock_calls)
_, args, _ = call.mock_calls[0]
- arg = args[0].parent
- mock_val = "parent_value"
+ arg = args[0].name
+ mock_val = "name_value"
assert arg == mock_val
@pytest.mark.asyncio
-async def test_list_rag_corpora_flattened_error_async():
+async def test_get_rag_corpus_flattened_error_async():
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
)
@@ -2346,307 +2360,112 @@ async def test_list_rag_corpora_flattened_error_async():
# Attempting to call a method with both a request object and flattened
# fields is an error.
with pytest.raises(ValueError):
- await client.list_rag_corpora(
- vertex_rag_data_service.ListRagCorporaRequest(),
- parent="parent_value",
+ await client.get_rag_corpus(
+ vertex_rag_data_service.GetRagCorpusRequest(),
+ name="name_value",
)
-def test_list_rag_corpora_pager(transport_name: str = "grpc"):
+@pytest.mark.parametrize(
+ "request_type",
+ [
+ vertex_rag_data_service.ListRagCorporaRequest,
+ dict,
+ ],
+)
+def test_list_rag_corpora(request_type, transport: str = "grpc"):
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
- transport=transport_name,
+ transport=transport,
)
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call:
- # Set the response to a series of pages.
- call.side_effect = (
- vertex_rag_data_service.ListRagCorporaResponse(
- rag_corpora=[
- vertex_rag_data.RagCorpus(),
- vertex_rag_data.RagCorpus(),
- vertex_rag_data.RagCorpus(),
- ],
- next_page_token="abc",
- ),
- vertex_rag_data_service.ListRagCorporaResponse(
- rag_corpora=[],
- next_page_token="def",
- ),
- vertex_rag_data_service.ListRagCorporaResponse(
- rag_corpora=[
- vertex_rag_data.RagCorpus(),
- ],
- next_page_token="ghi",
- ),
- vertex_rag_data_service.ListRagCorporaResponse(
- rag_corpora=[
- vertex_rag_data.RagCorpus(),
- vertex_rag_data.RagCorpus(),
- ],
- ),
- RuntimeError,
- )
-
- expected_metadata = ()
- retry = retries.Retry()
- timeout = 5
- expected_metadata = tuple(expected_metadata) + (
- gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+ # Designate an appropriate return value for the call.
+ call.return_value = vertex_rag_data_service.ListRagCorporaResponse(
+ next_page_token="next_page_token_value",
)
- pager = client.list_rag_corpora(request={}, retry=retry, timeout=timeout)
+ response = client.list_rag_corpora(request)
- assert pager._metadata == expected_metadata
- assert pager._retry == retry
- assert pager._timeout == timeout
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ request = vertex_rag_data_service.ListRagCorporaRequest()
+ assert args[0] == request
- results = list(pager)
- assert len(results) == 6
- assert all(isinstance(i, vertex_rag_data.RagCorpus) for i in results)
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, pagers.ListRagCorporaPager)
+ assert response.next_page_token == "next_page_token_value"
-def test_list_rag_corpora_pages(transport_name: str = "grpc"):
+def test_list_rag_corpora_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
- transport=transport_name,
+ transport="grpc",
)
# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call:
- # Set the response to a series of pages.
- call.side_effect = (
- vertex_rag_data_service.ListRagCorporaResponse(
- rag_corpora=[
- vertex_rag_data.RagCorpus(),
- vertex_rag_data.RagCorpus(),
- vertex_rag_data.RagCorpus(),
- ],
- next_page_token="abc",
- ),
- vertex_rag_data_service.ListRagCorporaResponse(
- rag_corpora=[],
- next_page_token="def",
- ),
- vertex_rag_data_service.ListRagCorporaResponse(
- rag_corpora=[
- vertex_rag_data.RagCorpus(),
- ],
- next_page_token="ghi",
- ),
- vertex_rag_data_service.ListRagCorporaResponse(
- rag_corpora=[
- vertex_rag_data.RagCorpus(),
- vertex_rag_data.RagCorpus(),
- ],
- ),
- RuntimeError,
+ call.return_value.name = (
+ "foo" # operation_request.operation in compute client(s) expect a string.
)
- pages = list(client.list_rag_corpora(request={}).pages)
- for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
- assert page_.raw_page.next_page_token == token
+ client.list_rag_corpora()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == vertex_rag_data_service.ListRagCorporaRequest()
-@pytest.mark.asyncio
-async def test_list_rag_corpora_async_pager():
- client = VertexRagDataServiceAsyncClient(
+def test_list_rag_corpora_non_empty_request_with_auto_populated_field():
+ # This test is a coverage failsafe to make sure that UUID4 fields are
+ # automatically populated, according to AIP-4235, with non-empty requests.
+ client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
+ transport="grpc",
+ )
+
+ # Populate all string fields in the request which are not UUID4
+ # since we want to check that UUID4 are populated automatically
+ # if they meet the requirements of AIP 4235.
+ request = vertex_rag_data_service.ListRagCorporaRequest(
+ parent="parent_value",
+ page_token="page_token_value",
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(
- type(client.transport.list_rag_corpora), "__call__", new_callable=mock.AsyncMock
- ) as call:
- # Set the response to a series of pages.
- call.side_effect = (
- vertex_rag_data_service.ListRagCorporaResponse(
- rag_corpora=[
- vertex_rag_data.RagCorpus(),
- vertex_rag_data.RagCorpus(),
- vertex_rag_data.RagCorpus(),
- ],
- next_page_token="abc",
- ),
- vertex_rag_data_service.ListRagCorporaResponse(
- rag_corpora=[],
- next_page_token="def",
- ),
- vertex_rag_data_service.ListRagCorporaResponse(
- rag_corpora=[
- vertex_rag_data.RagCorpus(),
- ],
- next_page_token="ghi",
- ),
- vertex_rag_data_service.ListRagCorporaResponse(
- rag_corpora=[
- vertex_rag_data.RagCorpus(),
- vertex_rag_data.RagCorpus(),
- ],
- ),
- RuntimeError,
+ with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call:
+ call.return_value.name = (
+ "foo" # operation_request.operation in compute client(s) expect a string.
)
- async_pager = await client.list_rag_corpora(
- request={},
+ client.list_rag_corpora(request=request)
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == vertex_rag_data_service.ListRagCorporaRequest(
+ parent="parent_value",
+ page_token="page_token_value",
)
- assert async_pager.next_page_token == "abc"
- responses = []
- async for response in async_pager: # pragma: no branch
- responses.append(response)
- assert len(responses) == 6
- assert all(isinstance(i, vertex_rag_data.RagCorpus) for i in responses)
-
-@pytest.mark.asyncio
-async def test_list_rag_corpora_async_pages():
- client = VertexRagDataServiceAsyncClient(
- credentials=ga_credentials.AnonymousCredentials(),
- )
-
- # Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(
- type(client.transport.list_rag_corpora), "__call__", new_callable=mock.AsyncMock
- ) as call:
- # Set the response to a series of pages.
- call.side_effect = (
- vertex_rag_data_service.ListRagCorporaResponse(
- rag_corpora=[
- vertex_rag_data.RagCorpus(),
- vertex_rag_data.RagCorpus(),
- vertex_rag_data.RagCorpus(),
- ],
- next_page_token="abc",
- ),
- vertex_rag_data_service.ListRagCorporaResponse(
- rag_corpora=[],
- next_page_token="def",
- ),
- vertex_rag_data_service.ListRagCorporaResponse(
- rag_corpora=[
- vertex_rag_data.RagCorpus(),
- ],
- next_page_token="ghi",
- ),
- vertex_rag_data_service.ListRagCorporaResponse(
- rag_corpora=[
- vertex_rag_data.RagCorpus(),
- vertex_rag_data.RagCorpus(),
- ],
- ),
- RuntimeError,
- )
- pages = []
- # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch`
- # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372
- async for page_ in ( # pragma: no branch
- await client.list_rag_corpora(request={})
- ).pages:
- pages.append(page_)
- for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
- assert page_.raw_page.next_page_token == token
-
-
-@pytest.mark.parametrize(
- "request_type",
- [
- vertex_rag_data_service.DeleteRagCorpusRequest,
- dict,
- ],
-)
-def test_delete_rag_corpus(request_type, transport: str = "grpc"):
- client = VertexRagDataServiceClient(
- credentials=ga_credentials.AnonymousCredentials(),
- transport=transport,
- )
-
- # Everything is optional in proto3 as far as the runtime is concerned,
- # and we are mocking out the actual API, so just send an empty request.
- request = request_type()
-
- # Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(
- type(client.transport.delete_rag_corpus), "__call__"
- ) as call:
- # Designate an appropriate return value for the call.
- call.return_value = operations_pb2.Operation(name="operations/spam")
- response = client.delete_rag_corpus(request)
-
- # Establish that the underlying gRPC stub method was called.
- assert len(call.mock_calls) == 1
- _, args, _ = call.mock_calls[0]
- request = vertex_rag_data_service.DeleteRagCorpusRequest()
- assert args[0] == request
-
- # Establish that the response is the type that we expect.
- assert isinstance(response, future.Future)
-
-
-def test_delete_rag_corpus_empty_call():
- # This test is a coverage failsafe to make sure that totally empty calls,
- # i.e. request == None and no flattened fields passed, work.
- client = VertexRagDataServiceClient(
- credentials=ga_credentials.AnonymousCredentials(),
- transport="grpc",
- )
-
- # Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(
- type(client.transport.delete_rag_corpus), "__call__"
- ) as call:
- call.return_value.name = (
- "foo" # operation_request.operation in compute client(s) expect a string.
- )
- client.delete_rag_corpus()
- call.assert_called()
- _, args, _ = call.mock_calls[0]
- assert args[0] == vertex_rag_data_service.DeleteRagCorpusRequest()
-
-
-def test_delete_rag_corpus_non_empty_request_with_auto_populated_field():
- # This test is a coverage failsafe to make sure that UUID4 fields are
- # automatically populated, according to AIP-4235, with non-empty requests.
- client = VertexRagDataServiceClient(
- credentials=ga_credentials.AnonymousCredentials(),
- transport="grpc",
- )
-
- # Populate all string fields in the request which are not UUID4
- # since we want to check that UUID4 are populated automatically
- # if they meet the requirements of AIP 4235.
- request = vertex_rag_data_service.DeleteRagCorpusRequest(
- name="name_value",
- )
-
- # Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(
- type(client.transport.delete_rag_corpus), "__call__"
- ) as call:
- call.return_value.name = (
- "foo" # operation_request.operation in compute client(s) expect a string.
- )
- client.delete_rag_corpus(request=request)
- call.assert_called()
- _, args, _ = call.mock_calls[0]
- assert args[0] == vertex_rag_data_service.DeleteRagCorpusRequest(
- name="name_value",
- )
-
-
-def test_delete_rag_corpus_use_cached_wrapped_rpc():
- # Clients should use _prep_wrapped_messages to create cached wrapped rpcs,
- # instead of constructing them on each call
- with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn:
- client = VertexRagDataServiceClient(
- credentials=ga_credentials.AnonymousCredentials(),
- transport="grpc",
- )
+def test_list_rag_corpora_use_cached_wrapped_rpc():
+ # Clients should use _prep_wrapped_messages to create cached wrapped rpcs,
+ # instead of constructing them on each call
+ with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn:
+ client = VertexRagDataServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport="grpc",
+ )
# Should wrap all calls on client creation
assert wrapper_fn.call_count > 0
wrapper_fn.reset_mock()
# Ensure method has been cached
- assert client._transport.delete_rag_corpus in client._transport._wrapped_methods
+ assert client._transport.list_rag_corpora in client._transport._wrapped_methods
# Replace cached wrapped function with mock
mock_rpc = mock.Mock()
@@ -2654,20 +2473,15 @@ def test_delete_rag_corpus_use_cached_wrapped_rpc():
"foo" # operation_request.operation in compute client(s) expect a string.
)
client._transport._wrapped_methods[
- client._transport.delete_rag_corpus
+ client._transport.list_rag_corpora
] = mock_rpc
request = {}
- client.delete_rag_corpus(request)
+ client.list_rag_corpora(request)
# Establish that the underlying gRPC stub method was called.
assert mock_rpc.call_count == 1
- # Operation methods call wrapper_fn to build a cached
- # client._transport.operations_client instance on first rpc call.
- # Subsequent calls should use the cached wrapper
- wrapper_fn.reset_mock()
-
- client.delete_rag_corpus(request)
+ client.list_rag_corpora(request)
# Establish that a new wrapper was not created for this call
assert wrapper_fn.call_count == 0
@@ -2675,7 +2489,7 @@ def test_delete_rag_corpus_use_cached_wrapped_rpc():
@pytest.mark.asyncio
-async def test_delete_rag_corpus_empty_call_async():
+async def test_list_rag_corpora_empty_call_async():
# This test is a coverage failsafe to make sure that totally empty calls,
# i.e. request == None and no flattened fields passed, work.
client = VertexRagDataServiceAsyncClient(
@@ -2684,21 +2498,21 @@ async def test_delete_rag_corpus_empty_call_async():
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(
- type(client.transport.delete_rag_corpus), "__call__"
- ) as call:
+ with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call:
# Designate an appropriate return value for the call.
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- operations_pb2.Operation(name="operations/spam")
+ vertex_rag_data_service.ListRagCorporaResponse(
+ next_page_token="next_page_token_value",
+ )
)
- response = await client.delete_rag_corpus()
+ response = await client.list_rag_corpora()
call.assert_called()
_, args, _ = call.mock_calls[0]
- assert args[0] == vertex_rag_data_service.DeleteRagCorpusRequest()
+ assert args[0] == vertex_rag_data_service.ListRagCorporaRequest()
@pytest.mark.asyncio
-async def test_delete_rag_corpus_async_use_cached_wrapped_rpc(
+async def test_list_rag_corpora_async_use_cached_wrapped_rpc(
transport: str = "grpc_asyncio",
):
# Clients should use _prep_wrapped_messages to create cached wrapped rpcs,
@@ -2715,7 +2529,7 @@ async def test_delete_rag_corpus_async_use_cached_wrapped_rpc(
# Ensure method has been cached
assert (
- client._client._transport.delete_rag_corpus
+ client._client._transport.list_rag_corpora
in client._client._transport._wrapped_methods
)
@@ -2723,21 +2537,16 @@ async def test_delete_rag_corpus_async_use_cached_wrapped_rpc(
mock_rpc = mock.AsyncMock()
mock_rpc.return_value = mock.Mock()
client._client._transport._wrapped_methods[
- client._client._transport.delete_rag_corpus
+ client._client._transport.list_rag_corpora
] = mock_rpc
request = {}
- await client.delete_rag_corpus(request)
+ await client.list_rag_corpora(request)
# Establish that the underlying gRPC stub method was called.
assert mock_rpc.call_count == 1
- # Operation methods call wrapper_fn to build a cached
- # client._transport.operations_client instance on first rpc call.
- # Subsequent calls should use the cached wrapper
- wrapper_fn.reset_mock()
-
- await client.delete_rag_corpus(request)
+ await client.list_rag_corpora(request)
# Establish that a new wrapper was not created for this call
assert wrapper_fn.call_count == 0
@@ -2745,9 +2554,9 @@ async def test_delete_rag_corpus_async_use_cached_wrapped_rpc(
@pytest.mark.asyncio
-async def test_delete_rag_corpus_async(
+async def test_list_rag_corpora_async(
transport: str = "grpc_asyncio",
- request_type=vertex_rag_data_service.DeleteRagCorpusRequest,
+ request_type=vertex_rag_data_service.ListRagCorporaRequest,
):
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
@@ -2759,47 +2568,46 @@ async def test_delete_rag_corpus_async(
request = request_type()
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(
- type(client.transport.delete_rag_corpus), "__call__"
- ) as call:
+ with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call:
# Designate an appropriate return value for the call.
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- operations_pb2.Operation(name="operations/spam")
+ vertex_rag_data_service.ListRagCorporaResponse(
+ next_page_token="next_page_token_value",
+ )
)
- response = await client.delete_rag_corpus(request)
+ response = await client.list_rag_corpora(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls)
_, args, _ = call.mock_calls[0]
- request = vertex_rag_data_service.DeleteRagCorpusRequest()
+ request = vertex_rag_data_service.ListRagCorporaRequest()
assert args[0] == request
# Establish that the response is the type that we expect.
- assert isinstance(response, future.Future)
+ assert isinstance(response, pagers.ListRagCorporaAsyncPager)
+ assert response.next_page_token == "next_page_token_value"
@pytest.mark.asyncio
-async def test_delete_rag_corpus_async_from_dict():
- await test_delete_rag_corpus_async(request_type=dict)
+async def test_list_rag_corpora_async_from_dict():
+ await test_list_rag_corpora_async(request_type=dict)
-def test_delete_rag_corpus_field_headers():
+def test_list_rag_corpora_field_headers():
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Any value that is part of the HTTP/1.1 URI should be sent as
# a field header. Set these to a non-empty value.
- request = vertex_rag_data_service.DeleteRagCorpusRequest()
+ request = vertex_rag_data_service.ListRagCorporaRequest()
- request.name = "name_value"
+ request.parent = "parent_value"
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(
- type(client.transport.delete_rag_corpus), "__call__"
- ) as call:
- call.return_value = operations_pb2.Operation(name="operations/op")
- client.delete_rag_corpus(request)
+ with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call:
+ call.return_value = vertex_rag_data_service.ListRagCorporaResponse()
+ client.list_rag_corpora(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls) == 1
@@ -2810,30 +2618,28 @@ def test_delete_rag_corpus_field_headers():
_, _, kw = call.mock_calls[0]
assert (
"x-goog-request-params",
- "name=name_value",
+ "parent=parent_value",
) in kw["metadata"]
@pytest.mark.asyncio
-async def test_delete_rag_corpus_field_headers_async():
+async def test_list_rag_corpora_field_headers_async():
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Any value that is part of the HTTP/1.1 URI should be sent as
# a field header. Set these to a non-empty value.
- request = vertex_rag_data_service.DeleteRagCorpusRequest()
+ request = vertex_rag_data_service.ListRagCorporaRequest()
- request.name = "name_value"
+ request.parent = "parent_value"
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(
- type(client.transport.delete_rag_corpus), "__call__"
- ) as call:
+ with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call:
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- operations_pb2.Operation(name="operations/op")
+ vertex_rag_data_service.ListRagCorporaResponse()
)
- await client.delete_rag_corpus(request)
+ await client.list_rag_corpora(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls)
@@ -2844,37 +2650,35 @@ async def test_delete_rag_corpus_field_headers_async():
_, _, kw = call.mock_calls[0]
assert (
"x-goog-request-params",
- "name=name_value",
+ "parent=parent_value",
) in kw["metadata"]
-def test_delete_rag_corpus_flattened():
+def test_list_rag_corpora_flattened():
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(
- type(client.transport.delete_rag_corpus), "__call__"
- ) as call:
+ with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call:
# Designate an appropriate return value for the call.
- call.return_value = operations_pb2.Operation(name="operations/op")
+ call.return_value = vertex_rag_data_service.ListRagCorporaResponse()
# Call the method with a truthy value for each flattened field,
# using the keyword arguments to the method.
- client.delete_rag_corpus(
- name="name_value",
+ client.list_rag_corpora(
+ parent="parent_value",
)
# Establish that the underlying call was made with the expected
# request object values.
assert len(call.mock_calls) == 1
_, args, _ = call.mock_calls[0]
- arg = args[0].name
- mock_val = "name_value"
+ arg = args[0].parent
+ mock_val = "parent_value"
assert arg == mock_val
-def test_delete_rag_corpus_flattened_error():
+def test_list_rag_corpora_flattened_error():
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
)
@@ -2882,45 +2686,43 @@ def test_delete_rag_corpus_flattened_error():
# Attempting to call a method with both a request object and flattened
# fields is an error.
with pytest.raises(ValueError):
- client.delete_rag_corpus(
- vertex_rag_data_service.DeleteRagCorpusRequest(),
- name="name_value",
+ client.list_rag_corpora(
+ vertex_rag_data_service.ListRagCorporaRequest(),
+ parent="parent_value",
)
@pytest.mark.asyncio
-async def test_delete_rag_corpus_flattened_async():
+async def test_list_rag_corpora_flattened_async():
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(
- type(client.transport.delete_rag_corpus), "__call__"
- ) as call:
+ with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call:
# Designate an appropriate return value for the call.
- call.return_value = operations_pb2.Operation(name="operations/op")
+ call.return_value = vertex_rag_data_service.ListRagCorporaResponse()
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- operations_pb2.Operation(name="operations/spam")
+ vertex_rag_data_service.ListRagCorporaResponse()
)
# Call the method with a truthy value for each flattened field,
# using the keyword arguments to the method.
- response = await client.delete_rag_corpus(
- name="name_value",
+ response = await client.list_rag_corpora(
+ parent="parent_value",
)
# Establish that the underlying call was made with the expected
# request object values.
assert len(call.mock_calls)
_, args, _ = call.mock_calls[0]
- arg = args[0].name
- mock_val = "name_value"
+ arg = args[0].parent
+ mock_val = "parent_value"
assert arg == mock_val
@pytest.mark.asyncio
-async def test_delete_rag_corpus_flattened_error_async():
+async def test_list_rag_corpora_flattened_error_async():
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
)
@@ -2928,20 +2730,214 @@ async def test_delete_rag_corpus_flattened_error_async():
# Attempting to call a method with both a request object and flattened
# fields is an error.
with pytest.raises(ValueError):
- await client.delete_rag_corpus(
- vertex_rag_data_service.DeleteRagCorpusRequest(),
- name="name_value",
+ await client.list_rag_corpora(
+ vertex_rag_data_service.ListRagCorporaRequest(),
+ parent="parent_value",
+ )
+
+
+def test_list_rag_corpora_pager(transport_name: str = "grpc"):
+ client = VertexRagDataServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport_name,
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ vertex_rag_data_service.ListRagCorporaResponse(
+ rag_corpora=[
+ vertex_rag_data.RagCorpus(),
+ vertex_rag_data.RagCorpus(),
+ vertex_rag_data.RagCorpus(),
+ ],
+ next_page_token="abc",
+ ),
+ vertex_rag_data_service.ListRagCorporaResponse(
+ rag_corpora=[],
+ next_page_token="def",
+ ),
+ vertex_rag_data_service.ListRagCorporaResponse(
+ rag_corpora=[
+ vertex_rag_data.RagCorpus(),
+ ],
+ next_page_token="ghi",
+ ),
+ vertex_rag_data_service.ListRagCorporaResponse(
+ rag_corpora=[
+ vertex_rag_data.RagCorpus(),
+ vertex_rag_data.RagCorpus(),
+ ],
+ ),
+ RuntimeError,
+ )
+
+ expected_metadata = ()
+ retry = retries.Retry()
+ timeout = 5
+ expected_metadata = tuple(expected_metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+ )
+ pager = client.list_rag_corpora(request={}, retry=retry, timeout=timeout)
+
+ assert pager._metadata == expected_metadata
+ assert pager._retry == retry
+ assert pager._timeout == timeout
+
+ results = list(pager)
+ assert len(results) == 6
+ assert all(isinstance(i, vertex_rag_data.RagCorpus) for i in results)
+
+
+def test_list_rag_corpora_pages(transport_name: str = "grpc"):
+ client = VertexRagDataServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport_name,
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ vertex_rag_data_service.ListRagCorporaResponse(
+ rag_corpora=[
+ vertex_rag_data.RagCorpus(),
+ vertex_rag_data.RagCorpus(),
+ vertex_rag_data.RagCorpus(),
+ ],
+ next_page_token="abc",
+ ),
+ vertex_rag_data_service.ListRagCorporaResponse(
+ rag_corpora=[],
+ next_page_token="def",
+ ),
+ vertex_rag_data_service.ListRagCorporaResponse(
+ rag_corpora=[
+ vertex_rag_data.RagCorpus(),
+ ],
+ next_page_token="ghi",
+ ),
+ vertex_rag_data_service.ListRagCorporaResponse(
+ rag_corpora=[
+ vertex_rag_data.RagCorpus(),
+ vertex_rag_data.RagCorpus(),
+ ],
+ ),
+ RuntimeError,
+ )
+ pages = list(client.list_rag_corpora(request={}).pages)
+ for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+ assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_list_rag_corpora_async_pager():
+ client = VertexRagDataServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_rag_corpora), "__call__", new_callable=mock.AsyncMock
+ ) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ vertex_rag_data_service.ListRagCorporaResponse(
+ rag_corpora=[
+ vertex_rag_data.RagCorpus(),
+ vertex_rag_data.RagCorpus(),
+ vertex_rag_data.RagCorpus(),
+ ],
+ next_page_token="abc",
+ ),
+ vertex_rag_data_service.ListRagCorporaResponse(
+ rag_corpora=[],
+ next_page_token="def",
+ ),
+ vertex_rag_data_service.ListRagCorporaResponse(
+ rag_corpora=[
+ vertex_rag_data.RagCorpus(),
+ ],
+ next_page_token="ghi",
+ ),
+ vertex_rag_data_service.ListRagCorporaResponse(
+ rag_corpora=[
+ vertex_rag_data.RagCorpus(),
+ vertex_rag_data.RagCorpus(),
+ ],
+ ),
+ RuntimeError,
+ )
+ async_pager = await client.list_rag_corpora(
+ request={},
+ )
+ assert async_pager.next_page_token == "abc"
+ responses = []
+ async for response in async_pager: # pragma: no branch
+ responses.append(response)
+
+ assert len(responses) == 6
+ assert all(isinstance(i, vertex_rag_data.RagCorpus) for i in responses)
+
+
+@pytest.mark.asyncio
+async def test_list_rag_corpora_async_pages():
+ client = VertexRagDataServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_rag_corpora), "__call__", new_callable=mock.AsyncMock
+ ) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ vertex_rag_data_service.ListRagCorporaResponse(
+ rag_corpora=[
+ vertex_rag_data.RagCorpus(),
+ vertex_rag_data.RagCorpus(),
+ vertex_rag_data.RagCorpus(),
+ ],
+ next_page_token="abc",
+ ),
+ vertex_rag_data_service.ListRagCorporaResponse(
+ rag_corpora=[],
+ next_page_token="def",
+ ),
+ vertex_rag_data_service.ListRagCorporaResponse(
+ rag_corpora=[
+ vertex_rag_data.RagCorpus(),
+ ],
+ next_page_token="ghi",
+ ),
+ vertex_rag_data_service.ListRagCorporaResponse(
+ rag_corpora=[
+ vertex_rag_data.RagCorpus(),
+ vertex_rag_data.RagCorpus(),
+ ],
+ ),
+ RuntimeError,
)
+ pages = []
+ # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch`
+ # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372
+ async for page_ in ( # pragma: no branch
+ await client.list_rag_corpora(request={})
+ ).pages:
+ pages.append(page_)
+ for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+ assert page_.raw_page.next_page_token == token
@pytest.mark.parametrize(
"request_type",
[
- vertex_rag_data_service.UploadRagFileRequest,
+ vertex_rag_data_service.DeleteRagCorpusRequest,
dict,
],
)
-def test_upload_rag_file(request_type, transport: str = "grpc"):
+def test_delete_rag_corpus(request_type, transport: str = "grpc"):
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
transport=transport,
@@ -2952,22 +2948,24 @@ def test_upload_rag_file(request_type, transport: str = "grpc"):
request = request_type()
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call:
+ with mock.patch.object(
+ type(client.transport.delete_rag_corpus), "__call__"
+ ) as call:
# Designate an appropriate return value for the call.
- call.return_value = vertex_rag_data_service.UploadRagFileResponse()
- response = client.upload_rag_file(request)
+ call.return_value = operations_pb2.Operation(name="operations/spam")
+ response = client.delete_rag_corpus(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls) == 1
_, args, _ = call.mock_calls[0]
- request = vertex_rag_data_service.UploadRagFileRequest()
+ request = vertex_rag_data_service.DeleteRagCorpusRequest()
assert args[0] == request
# Establish that the response is the type that we expect.
- assert isinstance(response, vertex_rag_data_service.UploadRagFileResponse)
+ assert isinstance(response, future.Future)
-def test_upload_rag_file_empty_call():
+def test_delete_rag_corpus_empty_call():
# This test is a coverage failsafe to make sure that totally empty calls,
# i.e. request == None and no flattened fields passed, work.
client = VertexRagDataServiceClient(
@@ -2976,17 +2974,19 @@ def test_upload_rag_file_empty_call():
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call:
+ with mock.patch.object(
+ type(client.transport.delete_rag_corpus), "__call__"
+ ) as call:
call.return_value.name = (
"foo" # operation_request.operation in compute client(s) expect a string.
)
- client.upload_rag_file()
+ client.delete_rag_corpus()
call.assert_called()
_, args, _ = call.mock_calls[0]
- assert args[0] == vertex_rag_data_service.UploadRagFileRequest()
+ assert args[0] == vertex_rag_data_service.DeleteRagCorpusRequest()
-def test_upload_rag_file_non_empty_request_with_auto_populated_field():
+def test_delete_rag_corpus_non_empty_request_with_auto_populated_field():
# This test is a coverage failsafe to make sure that UUID4 fields are
# automatically populated, according to AIP-4235, with non-empty requests.
client = VertexRagDataServiceClient(
@@ -2997,24 +2997,26 @@ def test_upload_rag_file_non_empty_request_with_auto_populated_field():
# Populate all string fields in the request which are not UUID4
# since we want to check that UUID4 are populated automatically
# if they meet the requirements of AIP 4235.
- request = vertex_rag_data_service.UploadRagFileRequest(
- parent="parent_value",
+ request = vertex_rag_data_service.DeleteRagCorpusRequest(
+ name="name_value",
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call:
+ with mock.patch.object(
+ type(client.transport.delete_rag_corpus), "__call__"
+ ) as call:
call.return_value.name = (
"foo" # operation_request.operation in compute client(s) expect a string.
)
- client.upload_rag_file(request=request)
+ client.delete_rag_corpus(request=request)
call.assert_called()
_, args, _ = call.mock_calls[0]
- assert args[0] == vertex_rag_data_service.UploadRagFileRequest(
- parent="parent_value",
+ assert args[0] == vertex_rag_data_service.DeleteRagCorpusRequest(
+ name="name_value",
)
-def test_upload_rag_file_use_cached_wrapped_rpc():
+def test_delete_rag_corpus_use_cached_wrapped_rpc():
# Clients should use _prep_wrapped_messages to create cached wrapped rpcs,
# instead of constructing them on each call
with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn:
@@ -3028,21 +3030,28 @@ def test_upload_rag_file_use_cached_wrapped_rpc():
wrapper_fn.reset_mock()
# Ensure method has been cached
- assert client._transport.upload_rag_file in client._transport._wrapped_methods
+ assert client._transport.delete_rag_corpus in client._transport._wrapped_methods
# Replace cached wrapped function with mock
mock_rpc = mock.Mock()
mock_rpc.return_value.name = (
"foo" # operation_request.operation in compute client(s) expect a string.
)
- client._transport._wrapped_methods[client._transport.upload_rag_file] = mock_rpc
+ client._transport._wrapped_methods[
+ client._transport.delete_rag_corpus
+ ] = mock_rpc
request = {}
- client.upload_rag_file(request)
+ client.delete_rag_corpus(request)
# Establish that the underlying gRPC stub method was called.
assert mock_rpc.call_count == 1
- client.upload_rag_file(request)
+ # Operation methods call wrapper_fn to build a cached
+ # client._transport.operations_client instance on first rpc call.
+ # Subsequent calls should use the cached wrapper
+ wrapper_fn.reset_mock()
+
+ client.delete_rag_corpus(request)
# Establish that a new wrapper was not created for this call
assert wrapper_fn.call_count == 0
@@ -3050,7 +3059,7 @@ def test_upload_rag_file_use_cached_wrapped_rpc():
@pytest.mark.asyncio
-async def test_upload_rag_file_empty_call_async():
+async def test_delete_rag_corpus_empty_call_async():
# This test is a coverage failsafe to make sure that totally empty calls,
# i.e. request == None and no flattened fields passed, work.
client = VertexRagDataServiceAsyncClient(
@@ -3059,19 +3068,21 @@ async def test_upload_rag_file_empty_call_async():
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call:
+ with mock.patch.object(
+ type(client.transport.delete_rag_corpus), "__call__"
+ ) as call:
# Designate an appropriate return value for the call.
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- vertex_rag_data_service.UploadRagFileResponse()
+ operations_pb2.Operation(name="operations/spam")
)
- response = await client.upload_rag_file()
+ response = await client.delete_rag_corpus()
call.assert_called()
_, args, _ = call.mock_calls[0]
- assert args[0] == vertex_rag_data_service.UploadRagFileRequest()
+ assert args[0] == vertex_rag_data_service.DeleteRagCorpusRequest()
@pytest.mark.asyncio
-async def test_upload_rag_file_async_use_cached_wrapped_rpc(
+async def test_delete_rag_corpus_async_use_cached_wrapped_rpc(
transport: str = "grpc_asyncio",
):
# Clients should use _prep_wrapped_messages to create cached wrapped rpcs,
@@ -3088,7 +3099,7 @@ async def test_upload_rag_file_async_use_cached_wrapped_rpc(
# Ensure method has been cached
assert (
- client._client._transport.upload_rag_file
+ client._client._transport.delete_rag_corpus
in client._client._transport._wrapped_methods
)
@@ -3096,16 +3107,21 @@ async def test_upload_rag_file_async_use_cached_wrapped_rpc(
mock_rpc = mock.AsyncMock()
mock_rpc.return_value = mock.Mock()
client._client._transport._wrapped_methods[
- client._client._transport.upload_rag_file
+ client._client._transport.delete_rag_corpus
] = mock_rpc
request = {}
- await client.upload_rag_file(request)
+ await client.delete_rag_corpus(request)
# Establish that the underlying gRPC stub method was called.
assert mock_rpc.call_count == 1
- await client.upload_rag_file(request)
+ # Operation methods call wrapper_fn to build a cached
+ # client._transport.operations_client instance on first rpc call.
+ # Subsequent calls should use the cached wrapper
+ wrapper_fn.reset_mock()
+
+ await client.delete_rag_corpus(request)
# Establish that a new wrapper was not created for this call
assert wrapper_fn.call_count == 0
@@ -3113,9 +3129,9 @@ async def test_upload_rag_file_async_use_cached_wrapped_rpc(
@pytest.mark.asyncio
-async def test_upload_rag_file_async(
+async def test_delete_rag_corpus_async(
transport: str = "grpc_asyncio",
- request_type=vertex_rag_data_service.UploadRagFileRequest,
+ request_type=vertex_rag_data_service.DeleteRagCorpusRequest,
):
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
@@ -3127,43 +3143,47 @@ async def test_upload_rag_file_async(
request = request_type()
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call:
+ with mock.patch.object(
+ type(client.transport.delete_rag_corpus), "__call__"
+ ) as call:
# Designate an appropriate return value for the call.
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- vertex_rag_data_service.UploadRagFileResponse()
+ operations_pb2.Operation(name="operations/spam")
)
- response = await client.upload_rag_file(request)
+ response = await client.delete_rag_corpus(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls)
_, args, _ = call.mock_calls[0]
- request = vertex_rag_data_service.UploadRagFileRequest()
+ request = vertex_rag_data_service.DeleteRagCorpusRequest()
assert args[0] == request
# Establish that the response is the type that we expect.
- assert isinstance(response, vertex_rag_data_service.UploadRagFileResponse)
+ assert isinstance(response, future.Future)
@pytest.mark.asyncio
-async def test_upload_rag_file_async_from_dict():
- await test_upload_rag_file_async(request_type=dict)
+async def test_delete_rag_corpus_async_from_dict():
+ await test_delete_rag_corpus_async(request_type=dict)
-def test_upload_rag_file_field_headers():
+def test_delete_rag_corpus_field_headers():
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Any value that is part of the HTTP/1.1 URI should be sent as
# a field header. Set these to a non-empty value.
- request = vertex_rag_data_service.UploadRagFileRequest()
+ request = vertex_rag_data_service.DeleteRagCorpusRequest()
- request.parent = "parent_value"
+ request.name = "name_value"
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call:
- call.return_value = vertex_rag_data_service.UploadRagFileResponse()
- client.upload_rag_file(request)
+ with mock.patch.object(
+ type(client.transport.delete_rag_corpus), "__call__"
+ ) as call:
+ call.return_value = operations_pb2.Operation(name="operations/op")
+ client.delete_rag_corpus(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls) == 1
@@ -3174,28 +3194,30 @@ def test_upload_rag_file_field_headers():
_, _, kw = call.mock_calls[0]
assert (
"x-goog-request-params",
- "parent=parent_value",
+ "name=name_value",
) in kw["metadata"]
@pytest.mark.asyncio
-async def test_upload_rag_file_field_headers_async():
+async def test_delete_rag_corpus_field_headers_async():
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Any value that is part of the HTTP/1.1 URI should be sent as
# a field header. Set these to a non-empty value.
- request = vertex_rag_data_service.UploadRagFileRequest()
+ request = vertex_rag_data_service.DeleteRagCorpusRequest()
- request.parent = "parent_value"
+ request.name = "name_value"
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call:
+ with mock.patch.object(
+ type(client.transport.delete_rag_corpus), "__call__"
+ ) as call:
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- vertex_rag_data_service.UploadRagFileResponse()
+ operations_pb2.Operation(name="operations/op")
)
- await client.upload_rag_file(request)
+ await client.delete_rag_corpus(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls)
@@ -3206,53 +3228,37 @@ async def test_upload_rag_file_field_headers_async():
_, _, kw = call.mock_calls[0]
assert (
"x-goog-request-params",
- "parent=parent_value",
+ "name=name_value",
) in kw["metadata"]
-def test_upload_rag_file_flattened():
+def test_delete_rag_corpus_flattened():
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call:
+ with mock.patch.object(
+ type(client.transport.delete_rag_corpus), "__call__"
+ ) as call:
# Designate an appropriate return value for the call.
- call.return_value = vertex_rag_data_service.UploadRagFileResponse()
+ call.return_value = operations_pb2.Operation(name="operations/op")
# Call the method with a truthy value for each flattened field,
# using the keyword arguments to the method.
- client.upload_rag_file(
- parent="parent_value",
- rag_file=vertex_rag_data.RagFile(
- gcs_source=io.GcsSource(uris=["uris_value"])
- ),
- upload_rag_file_config=vertex_rag_data.UploadRagFileConfig(
- rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig(
- chunk_size=1075
- )
- ),
+ client.delete_rag_corpus(
+ name="name_value",
)
# Establish that the underlying call was made with the expected
# request object values.
assert len(call.mock_calls) == 1
_, args, _ = call.mock_calls[0]
- arg = args[0].parent
- mock_val = "parent_value"
- assert arg == mock_val
- arg = args[0].rag_file
- mock_val = vertex_rag_data.RagFile(gcs_source=io.GcsSource(uris=["uris_value"]))
- assert arg == mock_val
- arg = args[0].upload_rag_file_config
- mock_val = vertex_rag_data.UploadRagFileConfig(
- rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig(
- chunk_size=1075
- )
- )
+ arg = args[0].name
+ mock_val = "name_value"
assert arg == mock_val
-def test_upload_rag_file_flattened_error():
+def test_delete_rag_corpus_flattened_error():
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
)
@@ -3260,69 +3266,45 @@ def test_upload_rag_file_flattened_error():
# Attempting to call a method with both a request object and flattened
# fields is an error.
with pytest.raises(ValueError):
- client.upload_rag_file(
- vertex_rag_data_service.UploadRagFileRequest(),
- parent="parent_value",
- rag_file=vertex_rag_data.RagFile(
- gcs_source=io.GcsSource(uris=["uris_value"])
- ),
- upload_rag_file_config=vertex_rag_data.UploadRagFileConfig(
- rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig(
- chunk_size=1075
- )
- ),
+ client.delete_rag_corpus(
+ vertex_rag_data_service.DeleteRagCorpusRequest(),
+ name="name_value",
)
@pytest.mark.asyncio
-async def test_upload_rag_file_flattened_async():
+async def test_delete_rag_corpus_flattened_async():
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call:
+ with mock.patch.object(
+ type(client.transport.delete_rag_corpus), "__call__"
+ ) as call:
# Designate an appropriate return value for the call.
- call.return_value = vertex_rag_data_service.UploadRagFileResponse()
+ call.return_value = operations_pb2.Operation(name="operations/op")
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- vertex_rag_data_service.UploadRagFileResponse()
+ operations_pb2.Operation(name="operations/spam")
)
# Call the method with a truthy value for each flattened field,
# using the keyword arguments to the method.
- response = await client.upload_rag_file(
- parent="parent_value",
- rag_file=vertex_rag_data.RagFile(
- gcs_source=io.GcsSource(uris=["uris_value"])
- ),
- upload_rag_file_config=vertex_rag_data.UploadRagFileConfig(
- rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig(
- chunk_size=1075
- )
- ),
+ response = await client.delete_rag_corpus(
+ name="name_value",
)
# Establish that the underlying call was made with the expected
# request object values.
assert len(call.mock_calls)
_, args, _ = call.mock_calls[0]
- arg = args[0].parent
- mock_val = "parent_value"
- assert arg == mock_val
- arg = args[0].rag_file
- mock_val = vertex_rag_data.RagFile(gcs_source=io.GcsSource(uris=["uris_value"]))
- assert arg == mock_val
- arg = args[0].upload_rag_file_config
- mock_val = vertex_rag_data.UploadRagFileConfig(
- rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig(
- chunk_size=1075
- )
- )
+ arg = args[0].name
+ mock_val = "name_value"
assert arg == mock_val
@pytest.mark.asyncio
-async def test_upload_rag_file_flattened_error_async():
+async def test_delete_rag_corpus_flattened_error_async():
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
)
@@ -3330,28 +3312,20 @@ async def test_upload_rag_file_flattened_error_async():
# Attempting to call a method with both a request object and flattened
# fields is an error.
with pytest.raises(ValueError):
- await client.upload_rag_file(
- vertex_rag_data_service.UploadRagFileRequest(),
- parent="parent_value",
- rag_file=vertex_rag_data.RagFile(
- gcs_source=io.GcsSource(uris=["uris_value"])
- ),
- upload_rag_file_config=vertex_rag_data.UploadRagFileConfig(
- rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig(
- chunk_size=1075
- )
- ),
+ await client.delete_rag_corpus(
+ vertex_rag_data_service.DeleteRagCorpusRequest(),
+ name="name_value",
)
@pytest.mark.parametrize(
"request_type",
[
- vertex_rag_data_service.ImportRagFilesRequest,
+ vertex_rag_data_service.UploadRagFileRequest,
dict,
],
)
-def test_import_rag_files(request_type, transport: str = "grpc"):
+def test_upload_rag_file(request_type, transport: str = "grpc"):
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
transport=transport,
@@ -3362,22 +3336,22 @@ def test_import_rag_files(request_type, transport: str = "grpc"):
request = request_type()
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call:
+ with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call:
# Designate an appropriate return value for the call.
- call.return_value = operations_pb2.Operation(name="operations/spam")
- response = client.import_rag_files(request)
+ call.return_value = vertex_rag_data_service.UploadRagFileResponse()
+ response = client.upload_rag_file(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls) == 1
_, args, _ = call.mock_calls[0]
- request = vertex_rag_data_service.ImportRagFilesRequest()
+ request = vertex_rag_data_service.UploadRagFileRequest()
assert args[0] == request
# Establish that the response is the type that we expect.
- assert isinstance(response, future.Future)
+ assert isinstance(response, vertex_rag_data_service.UploadRagFileResponse)
-def test_import_rag_files_empty_call():
+def test_upload_rag_file_empty_call():
# This test is a coverage failsafe to make sure that totally empty calls,
# i.e. request == None and no flattened fields passed, work.
client = VertexRagDataServiceClient(
@@ -3386,17 +3360,17 @@ def test_import_rag_files_empty_call():
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call:
+ with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call:
call.return_value.name = (
"foo" # operation_request.operation in compute client(s) expect a string.
)
- client.import_rag_files()
+ client.upload_rag_file()
call.assert_called()
_, args, _ = call.mock_calls[0]
- assert args[0] == vertex_rag_data_service.ImportRagFilesRequest()
+ assert args[0] == vertex_rag_data_service.UploadRagFileRequest()
-def test_import_rag_files_non_empty_request_with_auto_populated_field():
+def test_upload_rag_file_non_empty_request_with_auto_populated_field():
# This test is a coverage failsafe to make sure that UUID4 fields are
# automatically populated, according to AIP-4235, with non-empty requests.
client = VertexRagDataServiceClient(
@@ -3407,24 +3381,24 @@ def test_import_rag_files_non_empty_request_with_auto_populated_field():
# Populate all string fields in the request which are not UUID4
# since we want to check that UUID4 are populated automatically
# if they meet the requirements of AIP 4235.
- request = vertex_rag_data_service.ImportRagFilesRequest(
+ request = vertex_rag_data_service.UploadRagFileRequest(
parent="parent_value",
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call:
+ with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call:
call.return_value.name = (
"foo" # operation_request.operation in compute client(s) expect a string.
)
- client.import_rag_files(request=request)
+ client.upload_rag_file(request=request)
call.assert_called()
_, args, _ = call.mock_calls[0]
- assert args[0] == vertex_rag_data_service.ImportRagFilesRequest(
+ assert args[0] == vertex_rag_data_service.UploadRagFileRequest(
parent="parent_value",
)
-def test_import_rag_files_use_cached_wrapped_rpc():
+def test_upload_rag_file_use_cached_wrapped_rpc():
# Clients should use _prep_wrapped_messages to create cached wrapped rpcs,
# instead of constructing them on each call
with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn:
@@ -3438,28 +3412,21 @@ def test_import_rag_files_use_cached_wrapped_rpc():
wrapper_fn.reset_mock()
# Ensure method has been cached
- assert client._transport.import_rag_files in client._transport._wrapped_methods
+ assert client._transport.upload_rag_file in client._transport._wrapped_methods
# Replace cached wrapped function with mock
mock_rpc = mock.Mock()
mock_rpc.return_value.name = (
"foo" # operation_request.operation in compute client(s) expect a string.
)
- client._transport._wrapped_methods[
- client._transport.import_rag_files
- ] = mock_rpc
+ client._transport._wrapped_methods[client._transport.upload_rag_file] = mock_rpc
request = {}
- client.import_rag_files(request)
+ client.upload_rag_file(request)
# Establish that the underlying gRPC stub method was called.
assert mock_rpc.call_count == 1
- # Operation methods call wrapper_fn to build a cached
- # client._transport.operations_client instance on first rpc call.
- # Subsequent calls should use the cached wrapper
- wrapper_fn.reset_mock()
-
- client.import_rag_files(request)
+ client.upload_rag_file(request)
# Establish that a new wrapper was not created for this call
assert wrapper_fn.call_count == 0
@@ -3467,7 +3434,7 @@ def test_import_rag_files_use_cached_wrapped_rpc():
@pytest.mark.asyncio
-async def test_import_rag_files_empty_call_async():
+async def test_upload_rag_file_empty_call_async():
# This test is a coverage failsafe to make sure that totally empty calls,
# i.e. request == None and no flattened fields passed, work.
client = VertexRagDataServiceAsyncClient(
@@ -3476,19 +3443,19 @@ async def test_import_rag_files_empty_call_async():
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call:
+ with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call:
# Designate an appropriate return value for the call.
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- operations_pb2.Operation(name="operations/spam")
+ vertex_rag_data_service.UploadRagFileResponse()
)
- response = await client.import_rag_files()
+ response = await client.upload_rag_file()
call.assert_called()
_, args, _ = call.mock_calls[0]
- assert args[0] == vertex_rag_data_service.ImportRagFilesRequest()
+ assert args[0] == vertex_rag_data_service.UploadRagFileRequest()
@pytest.mark.asyncio
-async def test_import_rag_files_async_use_cached_wrapped_rpc(
+async def test_upload_rag_file_async_use_cached_wrapped_rpc(
transport: str = "grpc_asyncio",
):
# Clients should use _prep_wrapped_messages to create cached wrapped rpcs,
@@ -3505,7 +3472,7 @@ async def test_import_rag_files_async_use_cached_wrapped_rpc(
# Ensure method has been cached
assert (
- client._client._transport.import_rag_files
+ client._client._transport.upload_rag_file
in client._client._transport._wrapped_methods
)
@@ -3513,21 +3480,16 @@ async def test_import_rag_files_async_use_cached_wrapped_rpc(
mock_rpc = mock.AsyncMock()
mock_rpc.return_value = mock.Mock()
client._client._transport._wrapped_methods[
- client._client._transport.import_rag_files
+ client._client._transport.upload_rag_file
] = mock_rpc
request = {}
- await client.import_rag_files(request)
+ await client.upload_rag_file(request)
# Establish that the underlying gRPC stub method was called.
assert mock_rpc.call_count == 1
- # Operation methods call wrapper_fn to build a cached
- # client._transport.operations_client instance on first rpc call.
- # Subsequent calls should use the cached wrapper
- wrapper_fn.reset_mock()
-
- await client.import_rag_files(request)
+ await client.upload_rag_file(request)
# Establish that a new wrapper was not created for this call
assert wrapper_fn.call_count == 0
@@ -3535,9 +3497,9 @@ async def test_import_rag_files_async_use_cached_wrapped_rpc(
@pytest.mark.asyncio
-async def test_import_rag_files_async(
+async def test_upload_rag_file_async(
transport: str = "grpc_asyncio",
- request_type=vertex_rag_data_service.ImportRagFilesRequest,
+ request_type=vertex_rag_data_service.UploadRagFileRequest,
):
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
@@ -3549,43 +3511,43 @@ async def test_import_rag_files_async(
request = request_type()
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call:
+ with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call:
# Designate an appropriate return value for the call.
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- operations_pb2.Operation(name="operations/spam")
+ vertex_rag_data_service.UploadRagFileResponse()
)
- response = await client.import_rag_files(request)
+ response = await client.upload_rag_file(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls)
_, args, _ = call.mock_calls[0]
- request = vertex_rag_data_service.ImportRagFilesRequest()
+ request = vertex_rag_data_service.UploadRagFileRequest()
assert args[0] == request
# Establish that the response is the type that we expect.
- assert isinstance(response, future.Future)
+ assert isinstance(response, vertex_rag_data_service.UploadRagFileResponse)
@pytest.mark.asyncio
-async def test_import_rag_files_async_from_dict():
- await test_import_rag_files_async(request_type=dict)
+async def test_upload_rag_file_async_from_dict():
+ await test_upload_rag_file_async(request_type=dict)
-def test_import_rag_files_field_headers():
+def test_upload_rag_file_field_headers():
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Any value that is part of the HTTP/1.1 URI should be sent as
# a field header. Set these to a non-empty value.
- request = vertex_rag_data_service.ImportRagFilesRequest()
+ request = vertex_rag_data_service.UploadRagFileRequest()
request.parent = "parent_value"
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call:
- call.return_value = operations_pb2.Operation(name="operations/op")
- client.import_rag_files(request)
+ with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call:
+ call.return_value = vertex_rag_data_service.UploadRagFileResponse()
+ client.upload_rag_file(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls) == 1
@@ -3601,23 +3563,23 @@ def test_import_rag_files_field_headers():
@pytest.mark.asyncio
-async def test_import_rag_files_field_headers_async():
+async def test_upload_rag_file_field_headers_async():
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Any value that is part of the HTTP/1.1 URI should be sent as
# a field header. Set these to a non-empty value.
- request = vertex_rag_data_service.ImportRagFilesRequest()
+ request = vertex_rag_data_service.UploadRagFileRequest()
request.parent = "parent_value"
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call:
+ with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call:
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- operations_pb2.Operation(name="operations/op")
+ vertex_rag_data_service.UploadRagFileResponse()
)
- await client.import_rag_files(request)
+ await client.upload_rag_file(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls)
@@ -3632,22 +3594,27 @@ async def test_import_rag_files_field_headers_async():
) in kw["metadata"]
-def test_import_rag_files_flattened():
+def test_upload_rag_file_flattened():
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call:
+ with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call:
# Designate an appropriate return value for the call.
- call.return_value = operations_pb2.Operation(name="operations/op")
+ call.return_value = vertex_rag_data_service.UploadRagFileResponse()
# Call the method with a truthy value for each flattened field,
# using the keyword arguments to the method.
- client.import_rag_files(
+ client.upload_rag_file(
parent="parent_value",
- import_rag_files_config=vertex_rag_data.ImportRagFilesConfig(
+ rag_file=vertex_rag_data.RagFile(
gcs_source=io.GcsSource(uris=["uris_value"])
),
+ upload_rag_file_config=vertex_rag_data.UploadRagFileConfig(
+ rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig(
+ chunk_size=1075
+ )
+ ),
)
# Establish that the underlying call was made with the expected
@@ -3657,14 +3624,19 @@ def test_import_rag_files_flattened():
arg = args[0].parent
mock_val = "parent_value"
assert arg == mock_val
- arg = args[0].import_rag_files_config
- mock_val = vertex_rag_data.ImportRagFilesConfig(
- gcs_source=io.GcsSource(uris=["uris_value"])
+ arg = args[0].rag_file
+ mock_val = vertex_rag_data.RagFile(gcs_source=io.GcsSource(uris=["uris_value"]))
+ assert arg == mock_val
+ arg = args[0].upload_rag_file_config
+ mock_val = vertex_rag_data.UploadRagFileConfig(
+ rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig(
+ chunk_size=1075
+ )
)
assert arg == mock_val
-def test_import_rag_files_flattened_error():
+def test_upload_rag_file_flattened_error():
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
)
@@ -3672,36 +3644,46 @@ def test_import_rag_files_flattened_error():
# Attempting to call a method with both a request object and flattened
# fields is an error.
with pytest.raises(ValueError):
- client.import_rag_files(
- vertex_rag_data_service.ImportRagFilesRequest(),
+ client.upload_rag_file(
+ vertex_rag_data_service.UploadRagFileRequest(),
parent="parent_value",
- import_rag_files_config=vertex_rag_data.ImportRagFilesConfig(
+ rag_file=vertex_rag_data.RagFile(
gcs_source=io.GcsSource(uris=["uris_value"])
),
+ upload_rag_file_config=vertex_rag_data.UploadRagFileConfig(
+ rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig(
+ chunk_size=1075
+ )
+ ),
)
@pytest.mark.asyncio
-async def test_import_rag_files_flattened_async():
+async def test_upload_rag_file_flattened_async():
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call:
+ with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call:
# Designate an appropriate return value for the call.
- call.return_value = operations_pb2.Operation(name="operations/op")
+ call.return_value = vertex_rag_data_service.UploadRagFileResponse()
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- operations_pb2.Operation(name="operations/spam")
+ vertex_rag_data_service.UploadRagFileResponse()
)
# Call the method with a truthy value for each flattened field,
# using the keyword arguments to the method.
- response = await client.import_rag_files(
+ response = await client.upload_rag_file(
parent="parent_value",
- import_rag_files_config=vertex_rag_data.ImportRagFilesConfig(
+ rag_file=vertex_rag_data.RagFile(
gcs_source=io.GcsSource(uris=["uris_value"])
),
+ upload_rag_file_config=vertex_rag_data.UploadRagFileConfig(
+ rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig(
+ chunk_size=1075
+ )
+ ),
)
# Establish that the underlying call was made with the expected
@@ -3711,15 +3693,20 @@ async def test_import_rag_files_flattened_async():
arg = args[0].parent
mock_val = "parent_value"
assert arg == mock_val
- arg = args[0].import_rag_files_config
- mock_val = vertex_rag_data.ImportRagFilesConfig(
- gcs_source=io.GcsSource(uris=["uris_value"])
+ arg = args[0].rag_file
+ mock_val = vertex_rag_data.RagFile(gcs_source=io.GcsSource(uris=["uris_value"]))
+ assert arg == mock_val
+ arg = args[0].upload_rag_file_config
+ mock_val = vertex_rag_data.UploadRagFileConfig(
+ rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig(
+ chunk_size=1075
+ )
)
assert arg == mock_val
@pytest.mark.asyncio
-async def test_import_rag_files_flattened_error_async():
+async def test_upload_rag_file_flattened_error_async():
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
)
@@ -3727,23 +3714,28 @@ async def test_import_rag_files_flattened_error_async():
# Attempting to call a method with both a request object and flattened
# fields is an error.
with pytest.raises(ValueError):
- await client.import_rag_files(
- vertex_rag_data_service.ImportRagFilesRequest(),
+ await client.upload_rag_file(
+ vertex_rag_data_service.UploadRagFileRequest(),
parent="parent_value",
- import_rag_files_config=vertex_rag_data.ImportRagFilesConfig(
+ rag_file=vertex_rag_data.RagFile(
gcs_source=io.GcsSource(uris=["uris_value"])
),
+ upload_rag_file_config=vertex_rag_data.UploadRagFileConfig(
+ rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig(
+ chunk_size=1075
+ )
+ ),
)
@pytest.mark.parametrize(
"request_type",
[
- vertex_rag_data_service.GetRagFileRequest,
+ vertex_rag_data_service.ImportRagFilesRequest,
dict,
],
)
-def test_get_rag_file(request_type, transport: str = "grpc"):
+def test_import_rag_files(request_type, transport: str = "grpc"):
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
transport=transport,
@@ -3754,35 +3746,22 @@ def test_get_rag_file(request_type, transport: str = "grpc"):
request = request_type()
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call:
+ with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call:
# Designate an appropriate return value for the call.
- call.return_value = vertex_rag_data.RagFile(
- name="name_value",
- display_name="display_name_value",
- description="description_value",
- size_bytes=1089,
- rag_file_type=vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT,
- )
- response = client.get_rag_file(request)
+ call.return_value = operations_pb2.Operation(name="operations/spam")
+ response = client.import_rag_files(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls) == 1
_, args, _ = call.mock_calls[0]
- request = vertex_rag_data_service.GetRagFileRequest()
+ request = vertex_rag_data_service.ImportRagFilesRequest()
assert args[0] == request
# Establish that the response is the type that we expect.
- assert isinstance(response, vertex_rag_data.RagFile)
- assert response.name == "name_value"
- assert response.display_name == "display_name_value"
- assert response.description == "description_value"
- assert response.size_bytes == 1089
- assert (
- response.rag_file_type == vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT
- )
+ assert isinstance(response, future.Future)
-def test_get_rag_file_empty_call():
+def test_import_rag_files_empty_call():
# This test is a coverage failsafe to make sure that totally empty calls,
# i.e. request == None and no flattened fields passed, work.
client = VertexRagDataServiceClient(
@@ -3791,17 +3770,17 @@ def test_get_rag_file_empty_call():
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call:
+ with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call:
call.return_value.name = (
"foo" # operation_request.operation in compute client(s) expect a string.
)
- client.get_rag_file()
+ client.import_rag_files()
call.assert_called()
_, args, _ = call.mock_calls[0]
- assert args[0] == vertex_rag_data_service.GetRagFileRequest()
+ assert args[0] == vertex_rag_data_service.ImportRagFilesRequest()
-def test_get_rag_file_non_empty_request_with_auto_populated_field():
+def test_import_rag_files_non_empty_request_with_auto_populated_field():
# This test is a coverage failsafe to make sure that UUID4 fields are
# automatically populated, according to AIP-4235, with non-empty requests.
client = VertexRagDataServiceClient(
@@ -3812,24 +3791,24 @@ def test_get_rag_file_non_empty_request_with_auto_populated_field():
# Populate all string fields in the request which are not UUID4
# since we want to check that UUID4 are populated automatically
# if they meet the requirements of AIP 4235.
- request = vertex_rag_data_service.GetRagFileRequest(
- name="name_value",
+ request = vertex_rag_data_service.ImportRagFilesRequest(
+ parent="parent_value",
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call:
+ with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call:
call.return_value.name = (
"foo" # operation_request.operation in compute client(s) expect a string.
)
- client.get_rag_file(request=request)
+ client.import_rag_files(request=request)
call.assert_called()
_, args, _ = call.mock_calls[0]
- assert args[0] == vertex_rag_data_service.GetRagFileRequest(
- name="name_value",
+ assert args[0] == vertex_rag_data_service.ImportRagFilesRequest(
+ parent="parent_value",
)
-def test_get_rag_file_use_cached_wrapped_rpc():
+def test_import_rag_files_use_cached_wrapped_rpc():
# Clients should use _prep_wrapped_messages to create cached wrapped rpcs,
# instead of constructing them on each call
with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn:
@@ -3843,21 +3822,28 @@ def test_get_rag_file_use_cached_wrapped_rpc():
wrapper_fn.reset_mock()
# Ensure method has been cached
- assert client._transport.get_rag_file in client._transport._wrapped_methods
+ assert client._transport.import_rag_files in client._transport._wrapped_methods
# Replace cached wrapped function with mock
mock_rpc = mock.Mock()
mock_rpc.return_value.name = (
"foo" # operation_request.operation in compute client(s) expect a string.
)
- client._transport._wrapped_methods[client._transport.get_rag_file] = mock_rpc
+ client._transport._wrapped_methods[
+ client._transport.import_rag_files
+ ] = mock_rpc
request = {}
- client.get_rag_file(request)
+ client.import_rag_files(request)
# Establish that the underlying gRPC stub method was called.
assert mock_rpc.call_count == 1
- client.get_rag_file(request)
+ # Operation methods call wrapper_fn to build a cached
+ # client._transport.operations_client instance on first rpc call.
+ # Subsequent calls should use the cached wrapper
+ wrapper_fn.reset_mock()
+
+ client.import_rag_files(request)
# Establish that a new wrapper was not created for this call
assert wrapper_fn.call_count == 0
@@ -3865,7 +3851,7 @@ def test_get_rag_file_use_cached_wrapped_rpc():
@pytest.mark.asyncio
-async def test_get_rag_file_empty_call_async():
+async def test_import_rag_files_empty_call_async():
# This test is a coverage failsafe to make sure that totally empty calls,
# i.e. request == None and no flattened fields passed, work.
client = VertexRagDataServiceAsyncClient(
@@ -3874,25 +3860,19 @@ async def test_get_rag_file_empty_call_async():
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call:
+ with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call:
# Designate an appropriate return value for the call.
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- vertex_rag_data.RagFile(
- name="name_value",
- display_name="display_name_value",
- description="description_value",
- size_bytes=1089,
- rag_file_type=vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT,
- )
+ operations_pb2.Operation(name="operations/spam")
)
- response = await client.get_rag_file()
+ response = await client.import_rag_files()
call.assert_called()
_, args, _ = call.mock_calls[0]
- assert args[0] == vertex_rag_data_service.GetRagFileRequest()
+ assert args[0] == vertex_rag_data_service.ImportRagFilesRequest()
@pytest.mark.asyncio
-async def test_get_rag_file_async_use_cached_wrapped_rpc(
+async def test_import_rag_files_async_use_cached_wrapped_rpc(
transport: str = "grpc_asyncio",
):
# Clients should use _prep_wrapped_messages to create cached wrapped rpcs,
@@ -3909,7 +3889,7 @@ async def test_get_rag_file_async_use_cached_wrapped_rpc(
# Ensure method has been cached
assert (
- client._client._transport.get_rag_file
+ client._client._transport.import_rag_files
in client._client._transport._wrapped_methods
)
@@ -3917,16 +3897,21 @@ async def test_get_rag_file_async_use_cached_wrapped_rpc(
mock_rpc = mock.AsyncMock()
mock_rpc.return_value = mock.Mock()
client._client._transport._wrapped_methods[
- client._client._transport.get_rag_file
+ client._client._transport.import_rag_files
] = mock_rpc
request = {}
- await client.get_rag_file(request)
+ await client.import_rag_files(request)
# Establish that the underlying gRPC stub method was called.
assert mock_rpc.call_count == 1
- await client.get_rag_file(request)
+ # Operation methods call wrapper_fn to build a cached
+ # client._transport.operations_client instance on first rpc call.
+ # Subsequent calls should use the cached wrapper
+ wrapper_fn.reset_mock()
+
+ await client.import_rag_files(request)
# Establish that a new wrapper was not created for this call
assert wrapper_fn.call_count == 0
@@ -3934,9 +3919,9 @@ async def test_get_rag_file_async_use_cached_wrapped_rpc(
@pytest.mark.asyncio
-async def test_get_rag_file_async(
+async def test_import_rag_files_async(
transport: str = "grpc_asyncio",
- request_type=vertex_rag_data_service.GetRagFileRequest,
+ request_type=vertex_rag_data_service.ImportRagFilesRequest,
):
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
@@ -3948,56 +3933,43 @@ async def test_get_rag_file_async(
request = request_type()
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call:
+ with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call:
# Designate an appropriate return value for the call.
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- vertex_rag_data.RagFile(
- name="name_value",
- display_name="display_name_value",
- description="description_value",
- size_bytes=1089,
- rag_file_type=vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT,
- )
+ operations_pb2.Operation(name="operations/spam")
)
- response = await client.get_rag_file(request)
+ response = await client.import_rag_files(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls)
_, args, _ = call.mock_calls[0]
- request = vertex_rag_data_service.GetRagFileRequest()
+ request = vertex_rag_data_service.ImportRagFilesRequest()
assert args[0] == request
# Establish that the response is the type that we expect.
- assert isinstance(response, vertex_rag_data.RagFile)
- assert response.name == "name_value"
- assert response.display_name == "display_name_value"
- assert response.description == "description_value"
- assert response.size_bytes == 1089
- assert (
- response.rag_file_type == vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT
- )
+ assert isinstance(response, future.Future)
@pytest.mark.asyncio
-async def test_get_rag_file_async_from_dict():
- await test_get_rag_file_async(request_type=dict)
+async def test_import_rag_files_async_from_dict():
+ await test_import_rag_files_async(request_type=dict)
-def test_get_rag_file_field_headers():
+def test_import_rag_files_field_headers():
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Any value that is part of the HTTP/1.1 URI should be sent as
# a field header. Set these to a non-empty value.
- request = vertex_rag_data_service.GetRagFileRequest()
+ request = vertex_rag_data_service.ImportRagFilesRequest()
- request.name = "name_value"
+ request.parent = "parent_value"
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call:
- call.return_value = vertex_rag_data.RagFile()
- client.get_rag_file(request)
+ with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call:
+ call.return_value = operations_pb2.Operation(name="operations/op")
+ client.import_rag_files(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls) == 1
@@ -4008,28 +3980,28 @@ def test_get_rag_file_field_headers():
_, _, kw = call.mock_calls[0]
assert (
"x-goog-request-params",
- "name=name_value",
+ "parent=parent_value",
) in kw["metadata"]
@pytest.mark.asyncio
-async def test_get_rag_file_field_headers_async():
+async def test_import_rag_files_field_headers_async():
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Any value that is part of the HTTP/1.1 URI should be sent as
# a field header. Set these to a non-empty value.
- request = vertex_rag_data_service.GetRagFileRequest()
+ request = vertex_rag_data_service.ImportRagFilesRequest()
- request.name = "name_value"
+ request.parent = "parent_value"
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call:
+ with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call:
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- vertex_rag_data.RagFile()
+ operations_pb2.Operation(name="operations/op")
)
- await client.get_rag_file(request)
+ await client.import_rag_files(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls)
@@ -4040,35 +4012,43 @@ async def test_get_rag_file_field_headers_async():
_, _, kw = call.mock_calls[0]
assert (
"x-goog-request-params",
- "name=name_value",
+ "parent=parent_value",
) in kw["metadata"]
-def test_get_rag_file_flattened():
+def test_import_rag_files_flattened():
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call:
+ with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call:
# Designate an appropriate return value for the call.
- call.return_value = vertex_rag_data.RagFile()
+ call.return_value = operations_pb2.Operation(name="operations/op")
# Call the method with a truthy value for each flattened field,
# using the keyword arguments to the method.
- client.get_rag_file(
- name="name_value",
+ client.import_rag_files(
+ parent="parent_value",
+ import_rag_files_config=vertex_rag_data.ImportRagFilesConfig(
+ gcs_source=io.GcsSource(uris=["uris_value"])
+ ),
)
# Establish that the underlying call was made with the expected
# request object values.
assert len(call.mock_calls) == 1
_, args, _ = call.mock_calls[0]
- arg = args[0].name
- mock_val = "name_value"
+ arg = args[0].parent
+ mock_val = "parent_value"
+ assert arg == mock_val
+ arg = args[0].import_rag_files_config
+ mock_val = vertex_rag_data.ImportRagFilesConfig(
+ gcs_source=io.GcsSource(uris=["uris_value"])
+ )
assert arg == mock_val
-def test_get_rag_file_flattened_error():
+def test_import_rag_files_flattened_error():
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
)
@@ -4076,43 +4056,54 @@ def test_get_rag_file_flattened_error():
# Attempting to call a method with both a request object and flattened
# fields is an error.
with pytest.raises(ValueError):
- client.get_rag_file(
- vertex_rag_data_service.GetRagFileRequest(),
- name="name_value",
+ client.import_rag_files(
+ vertex_rag_data_service.ImportRagFilesRequest(),
+ parent="parent_value",
+ import_rag_files_config=vertex_rag_data.ImportRagFilesConfig(
+ gcs_source=io.GcsSource(uris=["uris_value"])
+ ),
)
@pytest.mark.asyncio
-async def test_get_rag_file_flattened_async():
+async def test_import_rag_files_flattened_async():
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call:
+ with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call:
# Designate an appropriate return value for the call.
- call.return_value = vertex_rag_data.RagFile()
+ call.return_value = operations_pb2.Operation(name="operations/op")
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- vertex_rag_data.RagFile()
+ operations_pb2.Operation(name="operations/spam")
)
# Call the method with a truthy value for each flattened field,
# using the keyword arguments to the method.
- response = await client.get_rag_file(
- name="name_value",
+ response = await client.import_rag_files(
+ parent="parent_value",
+ import_rag_files_config=vertex_rag_data.ImportRagFilesConfig(
+ gcs_source=io.GcsSource(uris=["uris_value"])
+ ),
)
# Establish that the underlying call was made with the expected
# request object values.
assert len(call.mock_calls)
_, args, _ = call.mock_calls[0]
- arg = args[0].name
- mock_val = "name_value"
+ arg = args[0].parent
+ mock_val = "parent_value"
+ assert arg == mock_val
+ arg = args[0].import_rag_files_config
+ mock_val = vertex_rag_data.ImportRagFilesConfig(
+ gcs_source=io.GcsSource(uris=["uris_value"])
+ )
assert arg == mock_val
@pytest.mark.asyncio
-async def test_get_rag_file_flattened_error_async():
+async def test_import_rag_files_flattened_error_async():
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
)
@@ -4120,20 +4111,23 @@ async def test_get_rag_file_flattened_error_async():
# Attempting to call a method with both a request object and flattened
# fields is an error.
with pytest.raises(ValueError):
- await client.get_rag_file(
- vertex_rag_data_service.GetRagFileRequest(),
- name="name_value",
+ await client.import_rag_files(
+ vertex_rag_data_service.ImportRagFilesRequest(),
+ parent="parent_value",
+ import_rag_files_config=vertex_rag_data.ImportRagFilesConfig(
+ gcs_source=io.GcsSource(uris=["uris_value"])
+ ),
)
@pytest.mark.parametrize(
"request_type",
[
- vertex_rag_data_service.ListRagFilesRequest,
+ vertex_rag_data_service.GetRagFileRequest,
dict,
],
)
-def test_list_rag_files(request_type, transport: str = "grpc"):
+def test_get_rag_file(request_type, transport: str = "grpc"):
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
transport=transport,
@@ -4144,25 +4138,35 @@ def test_list_rag_files(request_type, transport: str = "grpc"):
request = request_type()
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call:
+ with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call:
# Designate an appropriate return value for the call.
- call.return_value = vertex_rag_data_service.ListRagFilesResponse(
- next_page_token="next_page_token_value",
+ call.return_value = vertex_rag_data.RagFile(
+ name="name_value",
+ display_name="display_name_value",
+ description="description_value",
+ size_bytes=1089,
+ rag_file_type=vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT,
)
- response = client.list_rag_files(request)
+ response = client.get_rag_file(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls) == 1
_, args, _ = call.mock_calls[0]
- request = vertex_rag_data_service.ListRagFilesRequest()
+ request = vertex_rag_data_service.GetRagFileRequest()
assert args[0] == request
# Establish that the response is the type that we expect.
- assert isinstance(response, pagers.ListRagFilesPager)
- assert response.next_page_token == "next_page_token_value"
+ assert isinstance(response, vertex_rag_data.RagFile)
+ assert response.name == "name_value"
+ assert response.display_name == "display_name_value"
+ assert response.description == "description_value"
+ assert response.size_bytes == 1089
+ assert (
+ response.rag_file_type == vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT
+ )
-def test_list_rag_files_empty_call():
+def test_get_rag_file_empty_call():
# This test is a coverage failsafe to make sure that totally empty calls,
# i.e. request == None and no flattened fields passed, work.
client = VertexRagDataServiceClient(
@@ -4171,17 +4175,17 @@ def test_list_rag_files_empty_call():
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call:
+ with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call:
call.return_value.name = (
"foo" # operation_request.operation in compute client(s) expect a string.
)
- client.list_rag_files()
+ client.get_rag_file()
call.assert_called()
_, args, _ = call.mock_calls[0]
- assert args[0] == vertex_rag_data_service.ListRagFilesRequest()
+ assert args[0] == vertex_rag_data_service.GetRagFileRequest()
-def test_list_rag_files_non_empty_request_with_auto_populated_field():
+def test_get_rag_file_non_empty_request_with_auto_populated_field():
# This test is a coverage failsafe to make sure that UUID4 fields are
# automatically populated, according to AIP-4235, with non-empty requests.
client = VertexRagDataServiceClient(
@@ -4192,26 +4196,24 @@ def test_list_rag_files_non_empty_request_with_auto_populated_field():
# Populate all string fields in the request which are not UUID4
# since we want to check that UUID4 are populated automatically
# if they meet the requirements of AIP 4235.
- request = vertex_rag_data_service.ListRagFilesRequest(
- parent="parent_value",
- page_token="page_token_value",
+ request = vertex_rag_data_service.GetRagFileRequest(
+ name="name_value",
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call:
+ with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call:
call.return_value.name = (
"foo" # operation_request.operation in compute client(s) expect a string.
)
- client.list_rag_files(request=request)
+ client.get_rag_file(request=request)
call.assert_called()
_, args, _ = call.mock_calls[0]
- assert args[0] == vertex_rag_data_service.ListRagFilesRequest(
- parent="parent_value",
- page_token="page_token_value",
+ assert args[0] == vertex_rag_data_service.GetRagFileRequest(
+ name="name_value",
)
-def test_list_rag_files_use_cached_wrapped_rpc():
+def test_get_rag_file_use_cached_wrapped_rpc():
# Clients should use _prep_wrapped_messages to create cached wrapped rpcs,
# instead of constructing them on each call
with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn:
@@ -4225,21 +4227,21 @@ def test_list_rag_files_use_cached_wrapped_rpc():
wrapper_fn.reset_mock()
# Ensure method has been cached
- assert client._transport.list_rag_files in client._transport._wrapped_methods
+ assert client._transport.get_rag_file in client._transport._wrapped_methods
# Replace cached wrapped function with mock
mock_rpc = mock.Mock()
mock_rpc.return_value.name = (
"foo" # operation_request.operation in compute client(s) expect a string.
)
- client._transport._wrapped_methods[client._transport.list_rag_files] = mock_rpc
+ client._transport._wrapped_methods[client._transport.get_rag_file] = mock_rpc
request = {}
- client.list_rag_files(request)
+ client.get_rag_file(request)
# Establish that the underlying gRPC stub method was called.
assert mock_rpc.call_count == 1
- client.list_rag_files(request)
+ client.get_rag_file(request)
# Establish that a new wrapper was not created for this call
assert wrapper_fn.call_count == 0
@@ -4247,7 +4249,7 @@ def test_list_rag_files_use_cached_wrapped_rpc():
@pytest.mark.asyncio
-async def test_list_rag_files_empty_call_async():
+async def test_get_rag_file_empty_call_async():
# This test is a coverage failsafe to make sure that totally empty calls,
# i.e. request == None and no flattened fields passed, work.
client = VertexRagDataServiceAsyncClient(
@@ -4256,21 +4258,25 @@ async def test_list_rag_files_empty_call_async():
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call:
+ with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call:
# Designate an appropriate return value for the call.
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- vertex_rag_data_service.ListRagFilesResponse(
- next_page_token="next_page_token_value",
+ vertex_rag_data.RagFile(
+ name="name_value",
+ display_name="display_name_value",
+ description="description_value",
+ size_bytes=1089,
+ rag_file_type=vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT,
)
)
- response = await client.list_rag_files()
+ response = await client.get_rag_file()
call.assert_called()
_, args, _ = call.mock_calls[0]
- assert args[0] == vertex_rag_data_service.ListRagFilesRequest()
+ assert args[0] == vertex_rag_data_service.GetRagFileRequest()
@pytest.mark.asyncio
-async def test_list_rag_files_async_use_cached_wrapped_rpc(
+async def test_get_rag_file_async_use_cached_wrapped_rpc(
transport: str = "grpc_asyncio",
):
# Clients should use _prep_wrapped_messages to create cached wrapped rpcs,
@@ -4287,7 +4293,7 @@ async def test_list_rag_files_async_use_cached_wrapped_rpc(
# Ensure method has been cached
assert (
- client._client._transport.list_rag_files
+ client._client._transport.get_rag_file
in client._client._transport._wrapped_methods
)
@@ -4295,16 +4301,16 @@ async def test_list_rag_files_async_use_cached_wrapped_rpc(
mock_rpc = mock.AsyncMock()
mock_rpc.return_value = mock.Mock()
client._client._transport._wrapped_methods[
- client._client._transport.list_rag_files
+ client._client._transport.get_rag_file
] = mock_rpc
request = {}
- await client.list_rag_files(request)
+ await client.get_rag_file(request)
# Establish that the underlying gRPC stub method was called.
assert mock_rpc.call_count == 1
- await client.list_rag_files(request)
+ await client.get_rag_file(request)
# Establish that a new wrapper was not created for this call
assert wrapper_fn.call_count == 0
@@ -4312,9 +4318,9 @@ async def test_list_rag_files_async_use_cached_wrapped_rpc(
@pytest.mark.asyncio
-async def test_list_rag_files_async(
+async def test_get_rag_file_async(
transport: str = "grpc_asyncio",
- request_type=vertex_rag_data_service.ListRagFilesRequest,
+ request_type=vertex_rag_data_service.GetRagFileRequest,
):
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
@@ -4326,46 +4332,56 @@ async def test_list_rag_files_async(
request = request_type()
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call:
+ with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call:
# Designate an appropriate return value for the call.
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- vertex_rag_data_service.ListRagFilesResponse(
- next_page_token="next_page_token_value",
+ vertex_rag_data.RagFile(
+ name="name_value",
+ display_name="display_name_value",
+ description="description_value",
+ size_bytes=1089,
+ rag_file_type=vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT,
)
)
- response = await client.list_rag_files(request)
+ response = await client.get_rag_file(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls)
_, args, _ = call.mock_calls[0]
- request = vertex_rag_data_service.ListRagFilesRequest()
+ request = vertex_rag_data_service.GetRagFileRequest()
assert args[0] == request
# Establish that the response is the type that we expect.
- assert isinstance(response, pagers.ListRagFilesAsyncPager)
- assert response.next_page_token == "next_page_token_value"
+ assert isinstance(response, vertex_rag_data.RagFile)
+ assert response.name == "name_value"
+ assert response.display_name == "display_name_value"
+ assert response.description == "description_value"
+ assert response.size_bytes == 1089
+ assert (
+ response.rag_file_type == vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT
+ )
@pytest.mark.asyncio
-async def test_list_rag_files_async_from_dict():
- await test_list_rag_files_async(request_type=dict)
+async def test_get_rag_file_async_from_dict():
+ await test_get_rag_file_async(request_type=dict)
-def test_list_rag_files_field_headers():
+def test_get_rag_file_field_headers():
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Any value that is part of the HTTP/1.1 URI should be sent as
# a field header. Set these to a non-empty value.
- request = vertex_rag_data_service.ListRagFilesRequest()
+ request = vertex_rag_data_service.GetRagFileRequest()
- request.parent = "parent_value"
+ request.name = "name_value"
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call:
- call.return_value = vertex_rag_data_service.ListRagFilesResponse()
- client.list_rag_files(request)
+ with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call:
+ call.return_value = vertex_rag_data.RagFile()
+ client.get_rag_file(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls) == 1
@@ -4376,28 +4392,28 @@ def test_list_rag_files_field_headers():
_, _, kw = call.mock_calls[0]
assert (
"x-goog-request-params",
- "parent=parent_value",
+ "name=name_value",
) in kw["metadata"]
@pytest.mark.asyncio
-async def test_list_rag_files_field_headers_async():
+async def test_get_rag_file_field_headers_async():
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Any value that is part of the HTTP/1.1 URI should be sent as
# a field header. Set these to a non-empty value.
- request = vertex_rag_data_service.ListRagFilesRequest()
+ request = vertex_rag_data_service.GetRagFileRequest()
- request.parent = "parent_value"
+ request.name = "name_value"
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call:
+ with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call:
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- vertex_rag_data_service.ListRagFilesResponse()
+ vertex_rag_data.RagFile()
)
- await client.list_rag_files(request)
+ await client.get_rag_file(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls)
@@ -4408,35 +4424,35 @@ async def test_list_rag_files_field_headers_async():
_, _, kw = call.mock_calls[0]
assert (
"x-goog-request-params",
- "parent=parent_value",
+ "name=name_value",
) in kw["metadata"]
-def test_list_rag_files_flattened():
+def test_get_rag_file_flattened():
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call:
+ with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call:
# Designate an appropriate return value for the call.
- call.return_value = vertex_rag_data_service.ListRagFilesResponse()
+ call.return_value = vertex_rag_data.RagFile()
# Call the method with a truthy value for each flattened field,
# using the keyword arguments to the method.
- client.list_rag_files(
- parent="parent_value",
+ client.get_rag_file(
+ name="name_value",
)
# Establish that the underlying call was made with the expected
# request object values.
assert len(call.mock_calls) == 1
_, args, _ = call.mock_calls[0]
- arg = args[0].parent
- mock_val = "parent_value"
+ arg = args[0].name
+ mock_val = "name_value"
assert arg == mock_val
-def test_list_rag_files_flattened_error():
+def test_get_rag_file_flattened_error():
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
)
@@ -4444,43 +4460,43 @@ def test_list_rag_files_flattened_error():
# Attempting to call a method with both a request object and flattened
# fields is an error.
with pytest.raises(ValueError):
- client.list_rag_files(
- vertex_rag_data_service.ListRagFilesRequest(),
- parent="parent_value",
+ client.get_rag_file(
+ vertex_rag_data_service.GetRagFileRequest(),
+ name="name_value",
)
@pytest.mark.asyncio
-async def test_list_rag_files_flattened_async():
+async def test_get_rag_file_flattened_async():
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call:
+ with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call:
# Designate an appropriate return value for the call.
- call.return_value = vertex_rag_data_service.ListRagFilesResponse()
+ call.return_value = vertex_rag_data.RagFile()
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- vertex_rag_data_service.ListRagFilesResponse()
+ vertex_rag_data.RagFile()
)
# Call the method with a truthy value for each flattened field,
# using the keyword arguments to the method.
- response = await client.list_rag_files(
- parent="parent_value",
+ response = await client.get_rag_file(
+ name="name_value",
)
# Establish that the underlying call was made with the expected
# request object values.
assert len(call.mock_calls)
_, args, _ = call.mock_calls[0]
- arg = args[0].parent
- mock_val = "parent_value"
+ arg = args[0].name
+ mock_val = "name_value"
assert arg == mock_val
@pytest.mark.asyncio
-async def test_list_rag_files_flattened_error_async():
+async def test_get_rag_file_flattened_error_async():
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
)
@@ -4488,320 +4504,126 @@ async def test_list_rag_files_flattened_error_async():
# Attempting to call a method with both a request object and flattened
# fields is an error.
with pytest.raises(ValueError):
- await client.list_rag_files(
- vertex_rag_data_service.ListRagFilesRequest(),
- parent="parent_value",
+ await client.get_rag_file(
+ vertex_rag_data_service.GetRagFileRequest(),
+ name="name_value",
)
-def test_list_rag_files_pager(transport_name: str = "grpc"):
+@pytest.mark.parametrize(
+ "request_type",
+ [
+ vertex_rag_data_service.ListRagFilesRequest,
+ dict,
+ ],
+)
+def test_list_rag_files(request_type, transport: str = "grpc"):
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
- transport=transport_name,
+ transport=transport,
)
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call:
- # Set the response to a series of pages.
- call.side_effect = (
- vertex_rag_data_service.ListRagFilesResponse(
- rag_files=[
- vertex_rag_data.RagFile(),
- vertex_rag_data.RagFile(),
- vertex_rag_data.RagFile(),
- ],
- next_page_token="abc",
- ),
- vertex_rag_data_service.ListRagFilesResponse(
- rag_files=[],
- next_page_token="def",
- ),
- vertex_rag_data_service.ListRagFilesResponse(
- rag_files=[
- vertex_rag_data.RagFile(),
- ],
- next_page_token="ghi",
- ),
- vertex_rag_data_service.ListRagFilesResponse(
- rag_files=[
- vertex_rag_data.RagFile(),
- vertex_rag_data.RagFile(),
- ],
- ),
- RuntimeError,
- )
-
- expected_metadata = ()
- retry = retries.Retry()
- timeout = 5
- expected_metadata = tuple(expected_metadata) + (
- gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+ # Designate an appropriate return value for the call.
+ call.return_value = vertex_rag_data_service.ListRagFilesResponse(
+ next_page_token="next_page_token_value",
)
- pager = client.list_rag_files(request={}, retry=retry, timeout=timeout)
+ response = client.list_rag_files(request)
- assert pager._metadata == expected_metadata
- assert pager._retry == retry
- assert pager._timeout == timeout
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ request = vertex_rag_data_service.ListRagFilesRequest()
+ assert args[0] == request
- results = list(pager)
- assert len(results) == 6
- assert all(isinstance(i, vertex_rag_data.RagFile) for i in results)
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, pagers.ListRagFilesPager)
+ assert response.next_page_token == "next_page_token_value"
-def test_list_rag_files_pages(transport_name: str = "grpc"):
+def test_list_rag_files_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
- transport=transport_name,
+ transport="grpc",
)
# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call:
- # Set the response to a series of pages.
- call.side_effect = (
- vertex_rag_data_service.ListRagFilesResponse(
- rag_files=[
- vertex_rag_data.RagFile(),
- vertex_rag_data.RagFile(),
- vertex_rag_data.RagFile(),
- ],
- next_page_token="abc",
- ),
- vertex_rag_data_service.ListRagFilesResponse(
- rag_files=[],
- next_page_token="def",
- ),
- vertex_rag_data_service.ListRagFilesResponse(
- rag_files=[
- vertex_rag_data.RagFile(),
- ],
- next_page_token="ghi",
- ),
- vertex_rag_data_service.ListRagFilesResponse(
- rag_files=[
- vertex_rag_data.RagFile(),
- vertex_rag_data.RagFile(),
- ],
- ),
- RuntimeError,
+ call.return_value.name = (
+ "foo" # operation_request.operation in compute client(s) expect a string.
)
- pages = list(client.list_rag_files(request={}).pages)
- for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
- assert page_.raw_page.next_page_token == token
+ client.list_rag_files()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == vertex_rag_data_service.ListRagFilesRequest()
-@pytest.mark.asyncio
-async def test_list_rag_files_async_pager():
- client = VertexRagDataServiceAsyncClient(
+def test_list_rag_files_non_empty_request_with_auto_populated_field():
+ # This test is a coverage failsafe to make sure that UUID4 fields are
+ # automatically populated, according to AIP-4235, with non-empty requests.
+ client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
+ transport="grpc",
+ )
+
+ # Populate all string fields in the request which are not UUID4
+ # since we want to check that UUID4 are populated automatically
+ # if they meet the requirements of AIP 4235.
+ request = vertex_rag_data_service.ListRagFilesRequest(
+ parent="parent_value",
+ page_token="page_token_value",
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(
- type(client.transport.list_rag_files), "__call__", new_callable=mock.AsyncMock
- ) as call:
- # Set the response to a series of pages.
- call.side_effect = (
- vertex_rag_data_service.ListRagFilesResponse(
- rag_files=[
- vertex_rag_data.RagFile(),
- vertex_rag_data.RagFile(),
- vertex_rag_data.RagFile(),
- ],
- next_page_token="abc",
- ),
- vertex_rag_data_service.ListRagFilesResponse(
- rag_files=[],
- next_page_token="def",
- ),
- vertex_rag_data_service.ListRagFilesResponse(
- rag_files=[
- vertex_rag_data.RagFile(),
- ],
- next_page_token="ghi",
- ),
- vertex_rag_data_service.ListRagFilesResponse(
- rag_files=[
- vertex_rag_data.RagFile(),
- vertex_rag_data.RagFile(),
- ],
- ),
- RuntimeError,
+ with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call:
+ call.return_value.name = (
+ "foo" # operation_request.operation in compute client(s) expect a string.
)
- async_pager = await client.list_rag_files(
- request={},
+ client.list_rag_files(request=request)
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == vertex_rag_data_service.ListRagFilesRequest(
+ parent="parent_value",
+ page_token="page_token_value",
)
- assert async_pager.next_page_token == "abc"
- responses = []
- async for response in async_pager: # pragma: no branch
- responses.append(response)
- assert len(responses) == 6
- assert all(isinstance(i, vertex_rag_data.RagFile) for i in responses)
+def test_list_rag_files_use_cached_wrapped_rpc():
+ # Clients should use _prep_wrapped_messages to create cached wrapped rpcs,
+ # instead of constructing them on each call
+ with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn:
+ client = VertexRagDataServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport="grpc",
+ )
-@pytest.mark.asyncio
-async def test_list_rag_files_async_pages():
- client = VertexRagDataServiceAsyncClient(
- credentials=ga_credentials.AnonymousCredentials(),
- )
-
- # Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(
- type(client.transport.list_rag_files), "__call__", new_callable=mock.AsyncMock
- ) as call:
- # Set the response to a series of pages.
- call.side_effect = (
- vertex_rag_data_service.ListRagFilesResponse(
- rag_files=[
- vertex_rag_data.RagFile(),
- vertex_rag_data.RagFile(),
- vertex_rag_data.RagFile(),
- ],
- next_page_token="abc",
- ),
- vertex_rag_data_service.ListRagFilesResponse(
- rag_files=[],
- next_page_token="def",
- ),
- vertex_rag_data_service.ListRagFilesResponse(
- rag_files=[
- vertex_rag_data.RagFile(),
- ],
- next_page_token="ghi",
- ),
- vertex_rag_data_service.ListRagFilesResponse(
- rag_files=[
- vertex_rag_data.RagFile(),
- vertex_rag_data.RagFile(),
- ],
- ),
- RuntimeError,
- )
- pages = []
- # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch`
- # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372
- async for page_ in ( # pragma: no branch
- await client.list_rag_files(request={})
- ).pages:
- pages.append(page_)
- for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
- assert page_.raw_page.next_page_token == token
-
-
-@pytest.mark.parametrize(
- "request_type",
- [
- vertex_rag_data_service.DeleteRagFileRequest,
- dict,
- ],
-)
-def test_delete_rag_file(request_type, transport: str = "grpc"):
- client = VertexRagDataServiceClient(
- credentials=ga_credentials.AnonymousCredentials(),
- transport=transport,
- )
-
- # Everything is optional in proto3 as far as the runtime is concerned,
- # and we are mocking out the actual API, so just send an empty request.
- request = request_type()
-
- # Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call:
- # Designate an appropriate return value for the call.
- call.return_value = operations_pb2.Operation(name="operations/spam")
- response = client.delete_rag_file(request)
-
- # Establish that the underlying gRPC stub method was called.
- assert len(call.mock_calls) == 1
- _, args, _ = call.mock_calls[0]
- request = vertex_rag_data_service.DeleteRagFileRequest()
- assert args[0] == request
-
- # Establish that the response is the type that we expect.
- assert isinstance(response, future.Future)
-
-
-def test_delete_rag_file_empty_call():
- # This test is a coverage failsafe to make sure that totally empty calls,
- # i.e. request == None and no flattened fields passed, work.
- client = VertexRagDataServiceClient(
- credentials=ga_credentials.AnonymousCredentials(),
- transport="grpc",
- )
-
- # Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call:
- call.return_value.name = (
- "foo" # operation_request.operation in compute client(s) expect a string.
- )
- client.delete_rag_file()
- call.assert_called()
- _, args, _ = call.mock_calls[0]
- assert args[0] == vertex_rag_data_service.DeleteRagFileRequest()
-
-
-def test_delete_rag_file_non_empty_request_with_auto_populated_field():
- # This test is a coverage failsafe to make sure that UUID4 fields are
- # automatically populated, according to AIP-4235, with non-empty requests.
- client = VertexRagDataServiceClient(
- credentials=ga_credentials.AnonymousCredentials(),
- transport="grpc",
- )
-
- # Populate all string fields in the request which are not UUID4
- # since we want to check that UUID4 are populated automatically
- # if they meet the requirements of AIP 4235.
- request = vertex_rag_data_service.DeleteRagFileRequest(
- name="name_value",
- )
-
- # Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call:
- call.return_value.name = (
- "foo" # operation_request.operation in compute client(s) expect a string.
- )
- client.delete_rag_file(request=request)
- call.assert_called()
- _, args, _ = call.mock_calls[0]
- assert args[0] == vertex_rag_data_service.DeleteRagFileRequest(
- name="name_value",
- )
-
-
-def test_delete_rag_file_use_cached_wrapped_rpc():
- # Clients should use _prep_wrapped_messages to create cached wrapped rpcs,
- # instead of constructing them on each call
- with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn:
- client = VertexRagDataServiceClient(
- credentials=ga_credentials.AnonymousCredentials(),
- transport="grpc",
- )
-
- # Should wrap all calls on client creation
- assert wrapper_fn.call_count > 0
- wrapper_fn.reset_mock()
+ # Should wrap all calls on client creation
+ assert wrapper_fn.call_count > 0
+ wrapper_fn.reset_mock()
# Ensure method has been cached
- assert client._transport.delete_rag_file in client._transport._wrapped_methods
+ assert client._transport.list_rag_files in client._transport._wrapped_methods
# Replace cached wrapped function with mock
mock_rpc = mock.Mock()
mock_rpc.return_value.name = (
"foo" # operation_request.operation in compute client(s) expect a string.
)
- client._transport._wrapped_methods[client._transport.delete_rag_file] = mock_rpc
+ client._transport._wrapped_methods[client._transport.list_rag_files] = mock_rpc
request = {}
- client.delete_rag_file(request)
+ client.list_rag_files(request)
# Establish that the underlying gRPC stub method was called.
assert mock_rpc.call_count == 1
- # Operation methods call wrapper_fn to build a cached
- # client._transport.operations_client instance on first rpc call.
- # Subsequent calls should use the cached wrapper
- wrapper_fn.reset_mock()
-
- client.delete_rag_file(request)
+ client.list_rag_files(request)
# Establish that a new wrapper was not created for this call
assert wrapper_fn.call_count == 0
@@ -4809,7 +4631,7 @@ def test_delete_rag_file_use_cached_wrapped_rpc():
@pytest.mark.asyncio
-async def test_delete_rag_file_empty_call_async():
+async def test_list_rag_files_empty_call_async():
# This test is a coverage failsafe to make sure that totally empty calls,
# i.e. request == None and no flattened fields passed, work.
client = VertexRagDataServiceAsyncClient(
@@ -4818,19 +4640,21 @@ async def test_delete_rag_file_empty_call_async():
)
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call:
+ with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call:
# Designate an appropriate return value for the call.
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- operations_pb2.Operation(name="operations/spam")
+ vertex_rag_data_service.ListRagFilesResponse(
+ next_page_token="next_page_token_value",
+ )
)
- response = await client.delete_rag_file()
+ response = await client.list_rag_files()
call.assert_called()
_, args, _ = call.mock_calls[0]
- assert args[0] == vertex_rag_data_service.DeleteRagFileRequest()
+ assert args[0] == vertex_rag_data_service.ListRagFilesRequest()
@pytest.mark.asyncio
-async def test_delete_rag_file_async_use_cached_wrapped_rpc(
+async def test_list_rag_files_async_use_cached_wrapped_rpc(
transport: str = "grpc_asyncio",
):
# Clients should use _prep_wrapped_messages to create cached wrapped rpcs,
@@ -4847,7 +4671,7 @@ async def test_delete_rag_file_async_use_cached_wrapped_rpc(
# Ensure method has been cached
assert (
- client._client._transport.delete_rag_file
+ client._client._transport.list_rag_files
in client._client._transport._wrapped_methods
)
@@ -4855,21 +4679,16 @@ async def test_delete_rag_file_async_use_cached_wrapped_rpc(
mock_rpc = mock.AsyncMock()
mock_rpc.return_value = mock.Mock()
client._client._transport._wrapped_methods[
- client._client._transport.delete_rag_file
+ client._client._transport.list_rag_files
] = mock_rpc
request = {}
- await client.delete_rag_file(request)
+ await client.list_rag_files(request)
# Establish that the underlying gRPC stub method was called.
assert mock_rpc.call_count == 1
- # Operation methods call wrapper_fn to build a cached
- # client._transport.operations_client instance on first rpc call.
- # Subsequent calls should use the cached wrapper
- wrapper_fn.reset_mock()
-
- await client.delete_rag_file(request)
+ await client.list_rag_files(request)
# Establish that a new wrapper was not created for this call
assert wrapper_fn.call_count == 0
@@ -4877,9 +4696,9 @@ async def test_delete_rag_file_async_use_cached_wrapped_rpc(
@pytest.mark.asyncio
-async def test_delete_rag_file_async(
+async def test_list_rag_files_async(
transport: str = "grpc_asyncio",
- request_type=vertex_rag_data_service.DeleteRagFileRequest,
+ request_type=vertex_rag_data_service.ListRagFilesRequest,
):
client = VertexRagDataServiceAsyncClient(
credentials=ga_credentials.AnonymousCredentials(),
@@ -4891,43 +4710,46 @@ async def test_delete_rag_file_async(
request = request_type()
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call:
+ with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call:
# Designate an appropriate return value for the call.
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- operations_pb2.Operation(name="operations/spam")
+ vertex_rag_data_service.ListRagFilesResponse(
+ next_page_token="next_page_token_value",
+ )
)
- response = await client.delete_rag_file(request)
+ response = await client.list_rag_files(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls)
_, args, _ = call.mock_calls[0]
- request = vertex_rag_data_service.DeleteRagFileRequest()
+ request = vertex_rag_data_service.ListRagFilesRequest()
assert args[0] == request
# Establish that the response is the type that we expect.
- assert isinstance(response, future.Future)
+ assert isinstance(response, pagers.ListRagFilesAsyncPager)
+ assert response.next_page_token == "next_page_token_value"
@pytest.mark.asyncio
-async def test_delete_rag_file_async_from_dict():
- await test_delete_rag_file_async(request_type=dict)
+async def test_list_rag_files_async_from_dict():
+ await test_list_rag_files_async(request_type=dict)
-def test_delete_rag_file_field_headers():
+def test_list_rag_files_field_headers():
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
)
# Any value that is part of the HTTP/1.1 URI should be sent as
# a field header. Set these to a non-empty value.
- request = vertex_rag_data_service.DeleteRagFileRequest()
+ request = vertex_rag_data_service.ListRagFilesRequest()
- request.name = "name_value"
+ request.parent = "parent_value"
# Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call:
- call.return_value = operations_pb2.Operation(name="operations/op")
- client.delete_rag_file(request)
+ with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call:
+ call.return_value = vertex_rag_data_service.ListRagFilesResponse()
+ client.list_rag_files(request)
# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls) == 1
@@ -4938,141 +4760,1122 @@ def test_delete_rag_file_field_headers():
_, _, kw = call.mock_calls[0]
assert (
"x-goog-request-params",
- "name=name_value",
+ "parent=parent_value",
+ ) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_list_rag_files_field_headers_async():
+ client = VertexRagDataServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = vertex_rag_data_service.ListRagFilesRequest()
+
+ request.parent = "parent_value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call:
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ vertex_rag_data_service.ListRagFilesResponse()
+ )
+ await client.list_rag_files(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert (
+ "x-goog-request-params",
+ "parent=parent_value",
) in kw["metadata"]
-@pytest.mark.asyncio
-async def test_delete_rag_file_field_headers_async():
- client = VertexRagDataServiceAsyncClient(
+def test_list_rag_files_flattened():
+ client = VertexRagDataServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = vertex_rag_data_service.ListRagFilesResponse()
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ client.list_rag_files(
+ parent="parent_value",
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ arg = args[0].parent
+ mock_val = "parent_value"
+ assert arg == mock_val
+
+
+def test_list_rag_files_flattened_error():
+ client = VertexRagDataServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.list_rag_files(
+ vertex_rag_data_service.ListRagFilesRequest(),
+ parent="parent_value",
+ )
+
+
+@pytest.mark.asyncio
+async def test_list_rag_files_flattened_async():
+ client = VertexRagDataServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = vertex_rag_data_service.ListRagFilesResponse()
+
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ vertex_rag_data_service.ListRagFilesResponse()
+ )
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.list_rag_files(
+ parent="parent_value",
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ arg = args[0].parent
+ mock_val = "parent_value"
+ assert arg == mock_val
+
+
+@pytest.mark.asyncio
+async def test_list_rag_files_flattened_error_async():
+ client = VertexRagDataServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.list_rag_files(
+ vertex_rag_data_service.ListRagFilesRequest(),
+ parent="parent_value",
+ )
+
+
+def test_list_rag_files_pager(transport_name: str = "grpc"):
+ client = VertexRagDataServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport_name,
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ vertex_rag_data_service.ListRagFilesResponse(
+ rag_files=[
+ vertex_rag_data.RagFile(),
+ vertex_rag_data.RagFile(),
+ vertex_rag_data.RagFile(),
+ ],
+ next_page_token="abc",
+ ),
+ vertex_rag_data_service.ListRagFilesResponse(
+ rag_files=[],
+ next_page_token="def",
+ ),
+ vertex_rag_data_service.ListRagFilesResponse(
+ rag_files=[
+ vertex_rag_data.RagFile(),
+ ],
+ next_page_token="ghi",
+ ),
+ vertex_rag_data_service.ListRagFilesResponse(
+ rag_files=[
+ vertex_rag_data.RagFile(),
+ vertex_rag_data.RagFile(),
+ ],
+ ),
+ RuntimeError,
+ )
+
+ expected_metadata = ()
+ retry = retries.Retry()
+ timeout = 5
+ expected_metadata = tuple(expected_metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+ )
+ pager = client.list_rag_files(request={}, retry=retry, timeout=timeout)
+
+ assert pager._metadata == expected_metadata
+ assert pager._retry == retry
+ assert pager._timeout == timeout
+
+ results = list(pager)
+ assert len(results) == 6
+ assert all(isinstance(i, vertex_rag_data.RagFile) for i in results)
+
+
+def test_list_rag_files_pages(transport_name: str = "grpc"):
+ client = VertexRagDataServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport_name,
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ vertex_rag_data_service.ListRagFilesResponse(
+ rag_files=[
+ vertex_rag_data.RagFile(),
+ vertex_rag_data.RagFile(),
+ vertex_rag_data.RagFile(),
+ ],
+ next_page_token="abc",
+ ),
+ vertex_rag_data_service.ListRagFilesResponse(
+ rag_files=[],
+ next_page_token="def",
+ ),
+ vertex_rag_data_service.ListRagFilesResponse(
+ rag_files=[
+ vertex_rag_data.RagFile(),
+ ],
+ next_page_token="ghi",
+ ),
+ vertex_rag_data_service.ListRagFilesResponse(
+ rag_files=[
+ vertex_rag_data.RagFile(),
+ vertex_rag_data.RagFile(),
+ ],
+ ),
+ RuntimeError,
+ )
+ pages = list(client.list_rag_files(request={}).pages)
+ for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+ assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_list_rag_files_async_pager():
+ client = VertexRagDataServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_rag_files), "__call__", new_callable=mock.AsyncMock
+ ) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ vertex_rag_data_service.ListRagFilesResponse(
+ rag_files=[
+ vertex_rag_data.RagFile(),
+ vertex_rag_data.RagFile(),
+ vertex_rag_data.RagFile(),
+ ],
+ next_page_token="abc",
+ ),
+ vertex_rag_data_service.ListRagFilesResponse(
+ rag_files=[],
+ next_page_token="def",
+ ),
+ vertex_rag_data_service.ListRagFilesResponse(
+ rag_files=[
+ vertex_rag_data.RagFile(),
+ ],
+ next_page_token="ghi",
+ ),
+ vertex_rag_data_service.ListRagFilesResponse(
+ rag_files=[
+ vertex_rag_data.RagFile(),
+ vertex_rag_data.RagFile(),
+ ],
+ ),
+ RuntimeError,
+ )
+ async_pager = await client.list_rag_files(
+ request={},
+ )
+ assert async_pager.next_page_token == "abc"
+ responses = []
+ async for response in async_pager: # pragma: no branch
+ responses.append(response)
+
+ assert len(responses) == 6
+ assert all(isinstance(i, vertex_rag_data.RagFile) for i in responses)
+
+
+@pytest.mark.asyncio
+async def test_list_rag_files_async_pages():
+ client = VertexRagDataServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_rag_files), "__call__", new_callable=mock.AsyncMock
+ ) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ vertex_rag_data_service.ListRagFilesResponse(
+ rag_files=[
+ vertex_rag_data.RagFile(),
+ vertex_rag_data.RagFile(),
+ vertex_rag_data.RagFile(),
+ ],
+ next_page_token="abc",
+ ),
+ vertex_rag_data_service.ListRagFilesResponse(
+ rag_files=[],
+ next_page_token="def",
+ ),
+ vertex_rag_data_service.ListRagFilesResponse(
+ rag_files=[
+ vertex_rag_data.RagFile(),
+ ],
+ next_page_token="ghi",
+ ),
+ vertex_rag_data_service.ListRagFilesResponse(
+ rag_files=[
+ vertex_rag_data.RagFile(),
+ vertex_rag_data.RagFile(),
+ ],
+ ),
+ RuntimeError,
+ )
+ pages = []
+ # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch`
+ # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372
+ async for page_ in ( # pragma: no branch
+ await client.list_rag_files(request={})
+ ).pages:
+ pages.append(page_)
+ for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+ assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.parametrize(
+ "request_type",
+ [
+ vertex_rag_data_service.DeleteRagFileRequest,
+ dict,
+ ],
+)
+def test_delete_rag_file(request_type, transport: str = "grpc"):
+ client = VertexRagDataServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = operations_pb2.Operation(name="operations/spam")
+ response = client.delete_rag_file(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ request = vertex_rag_data_service.DeleteRagFileRequest()
+ assert args[0] == request
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, future.Future)
+
+
+def test_delete_rag_file_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = VertexRagDataServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call:
+ call.return_value.name = (
+ "foo" # operation_request.operation in compute client(s) expect a string.
+ )
+ client.delete_rag_file()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == vertex_rag_data_service.DeleteRagFileRequest()
+
+
+def test_delete_rag_file_non_empty_request_with_auto_populated_field():
+ # This test is a coverage failsafe to make sure that UUID4 fields are
+ # automatically populated, according to AIP-4235, with non-empty requests.
+ client = VertexRagDataServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport="grpc",
+ )
+
+ # Populate all string fields in the request which are not UUID4
+ # since we want to check that UUID4 are populated automatically
+ # if they meet the requirements of AIP 4235.
+ request = vertex_rag_data_service.DeleteRagFileRequest(
+ name="name_value",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call:
+ call.return_value.name = (
+ "foo" # operation_request.operation in compute client(s) expect a string.
+ )
+ client.delete_rag_file(request=request)
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == vertex_rag_data_service.DeleteRagFileRequest(
+ name="name_value",
+ )
+
+
+def test_delete_rag_file_use_cached_wrapped_rpc():
+ # Clients should use _prep_wrapped_messages to create cached wrapped rpcs,
+ # instead of constructing them on each call
+ with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn:
+ client = VertexRagDataServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport="grpc",
+ )
+
+ # Should wrap all calls on client creation
+ assert wrapper_fn.call_count > 0
+ wrapper_fn.reset_mock()
+
+ # Ensure method has been cached
+ assert client._transport.delete_rag_file in client._transport._wrapped_methods
+
+ # Replace cached wrapped function with mock
+ mock_rpc = mock.Mock()
+ mock_rpc.return_value.name = (
+ "foo" # operation_request.operation in compute client(s) expect a string.
+ )
+ client._transport._wrapped_methods[client._transport.delete_rag_file] = mock_rpc
+ request = {}
+ client.delete_rag_file(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert mock_rpc.call_count == 1
+
+ # Operation methods call wrapper_fn to build a cached
+ # client._transport.operations_client instance on first rpc call.
+ # Subsequent calls should use the cached wrapper
+ wrapper_fn.reset_mock()
+
+ client.delete_rag_file(request)
+
+ # Establish that a new wrapper was not created for this call
+ assert wrapper_fn.call_count == 0
+ assert mock_rpc.call_count == 2
+
+
+@pytest.mark.asyncio
+async def test_delete_rag_file_empty_call_async():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = VertexRagDataServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport="grpc_asyncio",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ operations_pb2.Operation(name="operations/spam")
+ )
+ response = await client.delete_rag_file()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == vertex_rag_data_service.DeleteRagFileRequest()
+
+
+@pytest.mark.asyncio
+async def test_delete_rag_file_async_use_cached_wrapped_rpc(
+ transport: str = "grpc_asyncio",
+):
+ # Clients should use _prep_wrapped_messages to create cached wrapped rpcs,
+ # instead of constructing them on each call
+ with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn:
+ client = VertexRagDataServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # Should wrap all calls on client creation
+ assert wrapper_fn.call_count > 0
+ wrapper_fn.reset_mock()
+
+ # Ensure method has been cached
+ assert (
+ client._client._transport.delete_rag_file
+ in client._client._transport._wrapped_methods
+ )
+
+ # Replace cached wrapped function with mock
+ mock_rpc = mock.AsyncMock()
+ mock_rpc.return_value = mock.Mock()
+ client._client._transport._wrapped_methods[
+ client._client._transport.delete_rag_file
+ ] = mock_rpc
+
+ request = {}
+ await client.delete_rag_file(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert mock_rpc.call_count == 1
+
+ # Operation methods call wrapper_fn to build a cached
+ # client._transport.operations_client instance on first rpc call.
+ # Subsequent calls should use the cached wrapper
+ wrapper_fn.reset_mock()
+
+ await client.delete_rag_file(request)
+
+ # Establish that a new wrapper was not created for this call
+ assert wrapper_fn.call_count == 0
+ assert mock_rpc.call_count == 2
+
+
+@pytest.mark.asyncio
+async def test_delete_rag_file_async(
+ transport: str = "grpc_asyncio",
+ request_type=vertex_rag_data_service.DeleteRagFileRequest,
+):
+ client = VertexRagDataServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ operations_pb2.Operation(name="operations/spam")
+ )
+ response = await client.delete_rag_file(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ request = vertex_rag_data_service.DeleteRagFileRequest()
+ assert args[0] == request
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, future.Future)
+
+
+@pytest.mark.asyncio
+async def test_delete_rag_file_async_from_dict():
+ await test_delete_rag_file_async(request_type=dict)
+
+
+def test_delete_rag_file_field_headers():
+ client = VertexRagDataServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = vertex_rag_data_service.DeleteRagFileRequest()
+
+ request.name = "name_value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call:
+ call.return_value = operations_pb2.Operation(name="operations/op")
+ client.delete_rag_file(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert (
+ "x-goog-request-params",
+ "name=name_value",
+ ) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_delete_rag_file_field_headers_async():
+ client = VertexRagDataServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = vertex_rag_data_service.DeleteRagFileRequest()
+
+ request.name = "name_value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call:
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ operations_pb2.Operation(name="operations/op")
+ )
+ await client.delete_rag_file(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert (
+ "x-goog-request-params",
+ "name=name_value",
+ ) in kw["metadata"]
+
+
+def test_delete_rag_file_flattened():
+ client = VertexRagDataServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = operations_pb2.Operation(name="operations/op")
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ client.delete_rag_file(
+ name="name_value",
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ arg = args[0].name
+ mock_val = "name_value"
+ assert arg == mock_val
+
+
+def test_delete_rag_file_flattened_error():
+ client = VertexRagDataServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.delete_rag_file(
+ vertex_rag_data_service.DeleteRagFileRequest(),
+ name="name_value",
+ )
+
+
+@pytest.mark.asyncio
+async def test_delete_rag_file_flattened_async():
+ client = VertexRagDataServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = operations_pb2.Operation(name="operations/op")
+
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ operations_pb2.Operation(name="operations/spam")
+ )
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.delete_rag_file(
+ name="name_value",
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ arg = args[0].name
+ mock_val = "name_value"
+ assert arg == mock_val
+
+
+@pytest.mark.asyncio
+async def test_delete_rag_file_flattened_error_async():
+ client = VertexRagDataServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.delete_rag_file(
+ vertex_rag_data_service.DeleteRagFileRequest(),
+ name="name_value",
+ )
+
+
+@pytest.mark.parametrize(
+ "request_type",
+ [
+ vertex_rag_data_service.CreateRagCorpusRequest,
+ dict,
+ ],
+)
+def test_create_rag_corpus_rest(request_type):
+ client = VertexRagDataServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport="rest",
+ )
+
+ # send a request that will satisfy transcoding
+ request_init = {"parent": "projects/sample1/locations/sample2"}
+ request_init["rag_corpus"] = {
+ "name": "name_value",
+ "display_name": "display_name_value",
+ "description": "description_value",
+ "rag_embedding_model_config": {
+ "vertex_prediction_endpoint": {
+ "endpoint": "endpoint_value",
+ "model": "model_value",
+ "model_version_id": "model_version_id_value",
+ },
+ "hybrid_search_config": {
+ "sparse_embedding_config": {
+ "bm25": {"multilingual": True, "k1": 0.156, "b": 0.98}
+ },
+ "dense_embedding_model_prediction_endpoint": {},
+ },
+ },
+ "rag_vector_db_config": {
+ "rag_managed_db": {},
+ "weaviate": {
+ "http_endpoint": "http_endpoint_value",
+ "collection_name": "collection_name_value",
+ },
+ "vertex_feature_store": {
+ "feature_view_resource_name": "feature_view_resource_name_value"
+ },
+ "api_auth": {
+ "api_key_config": {
+ "api_key_secret_version": "api_key_secret_version_value"
+ }
+ },
+ },
+ "create_time": {"seconds": 751, "nanos": 543},
+ "update_time": {},
+ "corpus_status": {"state": 1, "error_status": "error_status_value"},
+ }
+ # The version of a generated dependency at test runtime may differ from the version used during generation.
+ # Delete any fields which are not present in the current runtime dependency
+ # See https://github.com/googleapis/gapic-generator-python/issues/1748
+
+ # Determine if the message type is proto-plus or protobuf
+ test_field = vertex_rag_data_service.CreateRagCorpusRequest.meta.fields[
+ "rag_corpus"
+ ]
+
+ def get_message_fields(field):
+ # Given a field which is a message (composite type), return a list with
+ # all the fields of the message.
+ # If the field is not a composite type, return an empty list.
+ message_fields = []
+
+ if hasattr(field, "message") and field.message:
+ is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR")
+
+ if is_field_type_proto_plus_type:
+ message_fields = field.message.meta.fields.values()
+ # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types
+ else: # pragma: NO COVER
+ message_fields = field.message.DESCRIPTOR.fields
+ return message_fields
+
+ runtime_nested_fields = [
+ (field.name, nested_field.name)
+ for field in get_message_fields(test_field)
+ for nested_field in get_message_fields(field)
+ ]
+
+ subfields_not_in_runtime = []
+
+ # For each item in the sample request, create a list of sub fields which are not present at runtime
+ # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime
+ for field, value in request_init["rag_corpus"].items(): # pragma: NO COVER
+ result = None
+ is_repeated = False
+ # For repeated fields
+ if isinstance(value, list) and len(value):
+ is_repeated = True
+ result = value[0]
+ # For fields where the type is another message
+ if isinstance(value, dict):
+ result = value
+
+ if result and hasattr(result, "keys"):
+ for subfield in result.keys():
+ if (field, subfield) not in runtime_nested_fields:
+ subfields_not_in_runtime.append(
+ {
+ "field": field,
+ "subfield": subfield,
+ "is_repeated": is_repeated,
+ }
+ )
+
+ # Remove fields from the sample request which are not present in the runtime version of the dependency
+ # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime
+ for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER
+ field = subfield_to_delete.get("field")
+ field_repeated = subfield_to_delete.get("is_repeated")
+ subfield = subfield_to_delete.get("subfield")
+ if subfield:
+ if field_repeated:
+ for i in range(0, len(request_init["rag_corpus"][field])):
+ del request_init["rag_corpus"][field][i][subfield]
+ else:
+ del request_init["rag_corpus"][field][subfield]
+ request = request_type(**request_init)
+
+ # Mock the http request call within the method and fake a response.
+ with mock.patch.object(type(client.transport._session), "request") as req:
+ # Designate an appropriate value for the returned response.
+ return_value = operations_pb2.Operation(name="operations/spam")
+
+ # Wrap the value into a proper Response obj
+ response_value = Response()
+ response_value.status_code = 200
+ json_return_value = json_format.MessageToJson(return_value)
+
+ response_value._content = json_return_value.encode("UTF-8")
+ req.return_value = response_value
+ response = client.create_rag_corpus(request)
+
+ # Establish that the response is the type that we expect.
+ assert response.operation.name == "operations/spam"
+
+
+def test_create_rag_corpus_rest_use_cached_wrapped_rpc():
+ # Clients should use _prep_wrapped_messages to create cached wrapped rpcs,
+ # instead of constructing them on each call
+ with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn:
+ client = VertexRagDataServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport="rest",
+ )
+
+ # Should wrap all calls on client creation
+ assert wrapper_fn.call_count > 0
+ wrapper_fn.reset_mock()
+
+ # Ensure method has been cached
+ assert client._transport.create_rag_corpus in client._transport._wrapped_methods
+
+ # Replace cached wrapped function with mock
+ mock_rpc = mock.Mock()
+ mock_rpc.return_value.name = (
+ "foo" # operation_request.operation in compute client(s) expect a string.
+ )
+ client._transport._wrapped_methods[
+ client._transport.create_rag_corpus
+ ] = mock_rpc
+
+ request = {}
+ client.create_rag_corpus(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert mock_rpc.call_count == 1
+
+ # Operation methods build a cached wrapper on first rpc call
+ # subsequent calls should use the cached wrapper
+ wrapper_fn.reset_mock()
+
+ client.create_rag_corpus(request)
+
+ # Establish that a new wrapper was not created for this call
+ assert wrapper_fn.call_count == 0
+ assert mock_rpc.call_count == 2
+
+
+def test_create_rag_corpus_rest_required_fields(
+ request_type=vertex_rag_data_service.CreateRagCorpusRequest,
+):
+ transport_class = transports.VertexRagDataServiceRestTransport
+
+ request_init = {}
+ request_init["parent"] = ""
+ request = request_type(**request_init)
+ pb_request = request_type.pb(request)
+ jsonified_request = json.loads(
+ json_format.MessageToJson(pb_request, use_integers_for_enums=False)
+ )
+
+ # verify fields with default values are dropped
+
+ unset_fields = transport_class(
+ credentials=ga_credentials.AnonymousCredentials()
+ ).create_rag_corpus._get_unset_required_fields(jsonified_request)
+ jsonified_request.update(unset_fields)
+
+ # verify required fields with default values are now present
+
+ jsonified_request["parent"] = "parent_value"
+
+ unset_fields = transport_class(
+ credentials=ga_credentials.AnonymousCredentials()
+ ).create_rag_corpus._get_unset_required_fields(jsonified_request)
+ jsonified_request.update(unset_fields)
+
+ # verify required fields with non-default values are left alone
+ assert "parent" in jsonified_request
+ assert jsonified_request["parent"] == "parent_value"
+
+ client = VertexRagDataServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport="rest",
+ )
+ request = request_type(**request_init)
+
+ # Designate an appropriate value for the returned response.
+ return_value = operations_pb2.Operation(name="operations/spam")
+ # Mock the http request call within the method and fake a response.
+ with mock.patch.object(Session, "request") as req:
+ # We need to mock transcode() because providing default values
+ # for required fields will fail the real version if the http_options
+ # expect actual values for those fields.
+ with mock.patch.object(path_template, "transcode") as transcode:
+ # A uri without fields and an empty body will force all the
+ # request fields to show up in the query_params.
+ pb_request = request_type.pb(request)
+ transcode_result = {
+ "uri": "v1/sample_method",
+ "method": "post",
+ "query_params": pb_request,
+ }
+ transcode_result["body"] = pb_request
+ transcode.return_value = transcode_result
+
+ response_value = Response()
+ response_value.status_code = 200
+ json_return_value = json_format.MessageToJson(return_value)
+
+ response_value._content = json_return_value.encode("UTF-8")
+ req.return_value = response_value
+
+ response = client.create_rag_corpus(request)
+
+ expected_params = [("$alt", "json;enum-encoding=int")]
+ actual_params = req.call_args.kwargs["params"]
+ assert expected_params == actual_params
+
+
+def test_create_rag_corpus_rest_unset_required_fields():
+ transport = transports.VertexRagDataServiceRestTransport(
+ credentials=ga_credentials.AnonymousCredentials
+ )
+
+ unset_fields = transport.create_rag_corpus._get_unset_required_fields({})
+ assert set(unset_fields) == (
+ set(())
+ & set(
+ (
+ "parent",
+ "ragCorpus",
+ )
+ )
+ )
+
+
+@pytest.mark.parametrize("null_interceptor", [True, False])
+def test_create_rag_corpus_rest_interceptors(null_interceptor):
+ transport = transports.VertexRagDataServiceRestTransport(
credentials=ga_credentials.AnonymousCredentials(),
+ interceptor=None
+ if null_interceptor
+ else transports.VertexRagDataServiceRestInterceptor(),
)
+ client = VertexRagDataServiceClient(transport=transport)
+ with mock.patch.object(
+ type(client.transport._session), "request"
+ ) as req, mock.patch.object(
+ path_template, "transcode"
+ ) as transcode, mock.patch.object(
+ operation.Operation, "_set_result_from_operation"
+ ), mock.patch.object(
+ transports.VertexRagDataServiceRestInterceptor, "post_create_rag_corpus"
+ ) as post, mock.patch.object(
+ transports.VertexRagDataServiceRestInterceptor, "pre_create_rag_corpus"
+ ) as pre:
+ pre.assert_not_called()
+ post.assert_not_called()
+ pb_message = vertex_rag_data_service.CreateRagCorpusRequest.pb(
+ vertex_rag_data_service.CreateRagCorpusRequest()
+ )
+ transcode.return_value = {
+ "method": "post",
+ "uri": "my_uri",
+ "body": pb_message,
+ "query_params": pb_message,
+ }
- # Any value that is part of the HTTP/1.1 URI should be sent as
- # a field header. Set these to a non-empty value.
- request = vertex_rag_data_service.DeleteRagFileRequest()
+ req.return_value = Response()
+ req.return_value.status_code = 200
+ req.return_value.request = PreparedRequest()
+ req.return_value._content = json_format.MessageToJson(
+ operations_pb2.Operation()
+ )
- request.name = "name_value"
+ request = vertex_rag_data_service.CreateRagCorpusRequest()
+ metadata = [
+ ("key", "val"),
+ ("cephalopod", "squid"),
+ ]
+ pre.return_value = request, metadata
+ post.return_value = operations_pb2.Operation()
- # Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call:
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- operations_pb2.Operation(name="operations/op")
+ client.create_rag_corpus(
+ request,
+ metadata=[
+ ("key", "val"),
+ ("cephalopod", "squid"),
+ ],
)
- await client.delete_rag_file(request)
-
- # Establish that the underlying gRPC stub method was called.
- assert len(call.mock_calls)
- _, args, _ = call.mock_calls[0]
- assert args[0] == request
- # Establish that the field header was sent.
- _, _, kw = call.mock_calls[0]
- assert (
- "x-goog-request-params",
- "name=name_value",
- ) in kw["metadata"]
+ pre.assert_called_once()
+ post.assert_called_once()
-def test_delete_rag_file_flattened():
+def test_create_rag_corpus_rest_bad_request(
+ transport: str = "rest", request_type=vertex_rag_data_service.CreateRagCorpusRequest
+):
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
)
- # Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call:
- # Designate an appropriate return value for the call.
- call.return_value = operations_pb2.Operation(name="operations/op")
- # Call the method with a truthy value for each flattened field,
- # using the keyword arguments to the method.
- client.delete_rag_file(
- name="name_value",
- )
+ # send a request that will satisfy transcoding
+ request_init = {"parent": "projects/sample1/locations/sample2"}
+ request = request_type(**request_init)
- # Establish that the underlying call was made with the expected
- # request object values.
- assert len(call.mock_calls) == 1
- _, args, _ = call.mock_calls[0]
- arg = args[0].name
- mock_val = "name_value"
- assert arg == mock_val
+ # Mock the http request call within the method and fake a BadRequest error.
+ with mock.patch.object(Session, "request") as req, pytest.raises(
+ core_exceptions.BadRequest
+ ):
+ # Wrap the value into a proper Response obj
+ response_value = Response()
+ response_value.status_code = 400
+ response_value.request = Request()
+ req.return_value = response_value
+ client.create_rag_corpus(request)
-def test_delete_rag_file_flattened_error():
+def test_create_rag_corpus_rest_flattened():
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
+ transport="rest",
)
- # Attempting to call a method with both a request object and flattened
- # fields is an error.
- with pytest.raises(ValueError):
- client.delete_rag_file(
- vertex_rag_data_service.DeleteRagFileRequest(),
- name="name_value",
- )
+ # Mock the http request call within the method and fake a response.
+ with mock.patch.object(type(client.transport._session), "request") as req:
+ # Designate an appropriate value for the returned response.
+ return_value = operations_pb2.Operation(name="operations/spam")
+ # get arguments that satisfy an http rule for this method
+ sample_request = {"parent": "projects/sample1/locations/sample2"}
-@pytest.mark.asyncio
-async def test_delete_rag_file_flattened_async():
- client = VertexRagDataServiceAsyncClient(
- credentials=ga_credentials.AnonymousCredentials(),
- )
+ # get truthy value for each flattened field
+ mock_args = dict(
+ parent="parent_value",
+ rag_corpus=vertex_rag_data.RagCorpus(name="name_value"),
+ )
+ mock_args.update(sample_request)
- # Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call:
- # Designate an appropriate return value for the call.
- call.return_value = operations_pb2.Operation(name="operations/op")
+ # Wrap the value into a proper Response obj
+ response_value = Response()
+ response_value.status_code = 200
+ json_return_value = json_format.MessageToJson(return_value)
+ response_value._content = json_return_value.encode("UTF-8")
+ req.return_value = response_value
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
- operations_pb2.Operation(name="operations/spam")
- )
- # Call the method with a truthy value for each flattened field,
- # using the keyword arguments to the method.
- response = await client.delete_rag_file(
- name="name_value",
- )
+ client.create_rag_corpus(**mock_args)
# Establish that the underlying call was made with the expected
# request object values.
- assert len(call.mock_calls)
- _, args, _ = call.mock_calls[0]
- arg = args[0].name
- mock_val = "name_value"
- assert arg == mock_val
+ assert len(req.mock_calls) == 1
+ _, args, _ = req.mock_calls[0]
+ assert path_template.validate(
+ "%s/v1beta1/{parent=projects/*/locations/*}/ragCorpora"
+ % client.transport._host,
+ args[1],
+ )
-@pytest.mark.asyncio
-async def test_delete_rag_file_flattened_error_async():
- client = VertexRagDataServiceAsyncClient(
+def test_create_rag_corpus_rest_flattened_error(transport: str = "rest"):
+ client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
)
# Attempting to call a method with both a request object and flattened
# fields is an error.
with pytest.raises(ValueError):
- await client.delete_rag_file(
- vertex_rag_data_service.DeleteRagFileRequest(),
- name="name_value",
+ client.create_rag_corpus(
+ vertex_rag_data_service.CreateRagCorpusRequest(),
+ parent="parent_value",
+ rag_corpus=vertex_rag_data.RagCorpus(name="name_value"),
)
+def test_create_rag_corpus_rest_error():
+ client = VertexRagDataServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport="rest"
+ )
+
+
@pytest.mark.parametrize(
"request_type",
[
- vertex_rag_data_service.CreateRagCorpusRequest,
+ vertex_rag_data_service.UpdateRagCorpusRequest,
dict,
],
)
-def test_create_rag_corpus_rest(request_type):
+def test_update_rag_corpus_rest(request_type):
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
transport="rest",
)
# send a request that will satisfy transcoding
- request_init = {"parent": "projects/sample1/locations/sample2"}
+ request_init = {
+ "rag_corpus": {"name": "projects/sample1/locations/sample2/ragCorpora/sample3"}
+ }
request_init["rag_corpus"] = {
- "name": "name_value",
+ "name": "projects/sample1/locations/sample2/ragCorpora/sample3",
"display_name": "display_name_value",
"description": "description_value",
"rag_embedding_model_config": {
@@ -5080,17 +5883,39 @@ def test_create_rag_corpus_rest(request_type):
"endpoint": "endpoint_value",
"model": "model_value",
"model_version_id": "model_version_id_value",
- }
+ },
+ "hybrid_search_config": {
+ "sparse_embedding_config": {
+ "bm25": {"multilingual": True, "k1": 0.156, "b": 0.98}
+ },
+ "dense_embedding_model_prediction_endpoint": {},
+ },
+ },
+ "rag_vector_db_config": {
+ "rag_managed_db": {},
+ "weaviate": {
+ "http_endpoint": "http_endpoint_value",
+ "collection_name": "collection_name_value",
+ },
+ "vertex_feature_store": {
+ "feature_view_resource_name": "feature_view_resource_name_value"
+ },
+ "api_auth": {
+ "api_key_config": {
+ "api_key_secret_version": "api_key_secret_version_value"
+ }
+ },
},
"create_time": {"seconds": 751, "nanos": 543},
"update_time": {},
+ "corpus_status": {"state": 1, "error_status": "error_status_value"},
}
# The version of a generated dependency at test runtime may differ from the version used during generation.
# Delete any fields which are not present in the current runtime dependency
# See https://github.com/googleapis/gapic-generator-python/issues/1748
# Determine if the message type is proto-plus or protobuf
- test_field = vertex_rag_data_service.CreateRagCorpusRequest.meta.fields[
+ test_field = vertex_rag_data_service.UpdateRagCorpusRequest.meta.fields[
"rag_corpus"
]
@@ -5168,13 +5993,13 @@ def get_message_fields(field):
response_value._content = json_return_value.encode("UTF-8")
req.return_value = response_value
- response = client.create_rag_corpus(request)
+ response = client.update_rag_corpus(request)
# Establish that the response is the type that we expect.
assert response.operation.name == "operations/spam"
-def test_create_rag_corpus_rest_use_cached_wrapped_rpc():
+def test_update_rag_corpus_rest_use_cached_wrapped_rpc():
# Clients should use _prep_wrapped_messages to create cached wrapped rpcs,
# instead of constructing them on each call
with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn:
@@ -5188,7 +6013,7 @@ def test_create_rag_corpus_rest_use_cached_wrapped_rpc():
wrapper_fn.reset_mock()
# Ensure method has been cached
- assert client._transport.create_rag_corpus in client._transport._wrapped_methods
+ assert client._transport.update_rag_corpus in client._transport._wrapped_methods
# Replace cached wrapped function with mock
mock_rpc = mock.Mock()
@@ -5196,11 +6021,11 @@ def test_create_rag_corpus_rest_use_cached_wrapped_rpc():
"foo" # operation_request.operation in compute client(s) expect a string.
)
client._transport._wrapped_methods[
- client._transport.create_rag_corpus
+ client._transport.update_rag_corpus
] = mock_rpc
request = {}
- client.create_rag_corpus(request)
+ client.update_rag_corpus(request)
# Establish that the underlying gRPC stub method was called.
assert mock_rpc.call_count == 1
@@ -5209,20 +6034,19 @@ def test_create_rag_corpus_rest_use_cached_wrapped_rpc():
# subsequent calls should use the cached wrapper
wrapper_fn.reset_mock()
- client.create_rag_corpus(request)
+ client.update_rag_corpus(request)
# Establish that a new wrapper was not created for this call
assert wrapper_fn.call_count == 0
assert mock_rpc.call_count == 2
-def test_create_rag_corpus_rest_required_fields(
- request_type=vertex_rag_data_service.CreateRagCorpusRequest,
+def test_update_rag_corpus_rest_required_fields(
+ request_type=vertex_rag_data_service.UpdateRagCorpusRequest,
):
transport_class = transports.VertexRagDataServiceRestTransport
request_init = {}
- request_init["parent"] = ""
request = request_type(**request_init)
pb_request = request_type.pb(request)
jsonified_request = json.loads(
@@ -5233,21 +6057,17 @@ def test_create_rag_corpus_rest_required_fields(
unset_fields = transport_class(
credentials=ga_credentials.AnonymousCredentials()
- ).create_rag_corpus._get_unset_required_fields(jsonified_request)
+ ).update_rag_corpus._get_unset_required_fields(jsonified_request)
jsonified_request.update(unset_fields)
# verify required fields with default values are now present
- jsonified_request["parent"] = "parent_value"
-
unset_fields = transport_class(
credentials=ga_credentials.AnonymousCredentials()
- ).create_rag_corpus._get_unset_required_fields(jsonified_request)
+ ).update_rag_corpus._get_unset_required_fields(jsonified_request)
jsonified_request.update(unset_fields)
# verify required fields with non-default values are left alone
- assert "parent" in jsonified_request
- assert jsonified_request["parent"] == "parent_value"
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
@@ -5268,7 +6088,7 @@ def test_create_rag_corpus_rest_required_fields(
pb_request = request_type.pb(request)
transcode_result = {
"uri": "v1/sample_method",
- "method": "post",
+ "method": "patch",
"query_params": pb_request,
}
transcode_result["body"] = pb_request
@@ -5281,32 +6101,24 @@ def test_create_rag_corpus_rest_required_fields(
response_value._content = json_return_value.encode("UTF-8")
req.return_value = response_value
- response = client.create_rag_corpus(request)
+ response = client.update_rag_corpus(request)
expected_params = [("$alt", "json;enum-encoding=int")]
actual_params = req.call_args.kwargs["params"]
assert expected_params == actual_params
-def test_create_rag_corpus_rest_unset_required_fields():
+def test_update_rag_corpus_rest_unset_required_fields():
transport = transports.VertexRagDataServiceRestTransport(
credentials=ga_credentials.AnonymousCredentials
)
- unset_fields = transport.create_rag_corpus._get_unset_required_fields({})
- assert set(unset_fields) == (
- set(())
- & set(
- (
- "parent",
- "ragCorpus",
- )
- )
- )
+ unset_fields = transport.update_rag_corpus._get_unset_required_fields({})
+ assert set(unset_fields) == (set(()) & set(("ragCorpus",)))
@pytest.mark.parametrize("null_interceptor", [True, False])
-def test_create_rag_corpus_rest_interceptors(null_interceptor):
+def test_update_rag_corpus_rest_interceptors(null_interceptor):
transport = transports.VertexRagDataServiceRestTransport(
credentials=ga_credentials.AnonymousCredentials(),
interceptor=None
@@ -5321,14 +6133,14 @@ def test_create_rag_corpus_rest_interceptors(null_interceptor):
) as transcode, mock.patch.object(
operation.Operation, "_set_result_from_operation"
), mock.patch.object(
- transports.VertexRagDataServiceRestInterceptor, "post_create_rag_corpus"
+ transports.VertexRagDataServiceRestInterceptor, "post_update_rag_corpus"
) as post, mock.patch.object(
- transports.VertexRagDataServiceRestInterceptor, "pre_create_rag_corpus"
+ transports.VertexRagDataServiceRestInterceptor, "pre_update_rag_corpus"
) as pre:
pre.assert_not_called()
post.assert_not_called()
- pb_message = vertex_rag_data_service.CreateRagCorpusRequest.pb(
- vertex_rag_data_service.CreateRagCorpusRequest()
+ pb_message = vertex_rag_data_service.UpdateRagCorpusRequest.pb(
+ vertex_rag_data_service.UpdateRagCorpusRequest()
)
transcode.return_value = {
"method": "post",
@@ -5344,7 +6156,7 @@ def test_create_rag_corpus_rest_interceptors(null_interceptor):
operations_pb2.Operation()
)
- request = vertex_rag_data_service.CreateRagCorpusRequest()
+ request = vertex_rag_data_service.UpdateRagCorpusRequest()
metadata = [
("key", "val"),
("cephalopod", "squid"),
@@ -5352,7 +6164,7 @@ def test_create_rag_corpus_rest_interceptors(null_interceptor):
pre.return_value = request, metadata
post.return_value = operations_pb2.Operation()
- client.create_rag_corpus(
+ client.update_rag_corpus(
request,
metadata=[
("key", "val"),
@@ -5364,8 +6176,8 @@ def test_create_rag_corpus_rest_interceptors(null_interceptor):
post.assert_called_once()
-def test_create_rag_corpus_rest_bad_request(
- transport: str = "rest", request_type=vertex_rag_data_service.CreateRagCorpusRequest
+def test_update_rag_corpus_rest_bad_request(
+ transport: str = "rest", request_type=vertex_rag_data_service.UpdateRagCorpusRequest
):
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
@@ -5373,7 +6185,9 @@ def test_create_rag_corpus_rest_bad_request(
)
# send a request that will satisfy transcoding
- request_init = {"parent": "projects/sample1/locations/sample2"}
+ request_init = {
+ "rag_corpus": {"name": "projects/sample1/locations/sample2/ragCorpora/sample3"}
+ }
request = request_type(**request_init)
# Mock the http request call within the method and fake a BadRequest error.
@@ -5385,10 +6199,10 @@ def test_create_rag_corpus_rest_bad_request(
response_value.status_code = 400
response_value.request = Request()
req.return_value = response_value
- client.create_rag_corpus(request)
+ client.update_rag_corpus(request)
-def test_create_rag_corpus_rest_flattened():
+def test_update_rag_corpus_rest_flattened():
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
transport="rest",
@@ -5400,11 +6214,14 @@ def test_create_rag_corpus_rest_flattened():
return_value = operations_pb2.Operation(name="operations/spam")
# get arguments that satisfy an http rule for this method
- sample_request = {"parent": "projects/sample1/locations/sample2"}
+ sample_request = {
+ "rag_corpus": {
+ "name": "projects/sample1/locations/sample2/ragCorpora/sample3"
+ }
+ }
# get truthy value for each flattened field
mock_args = dict(
- parent="parent_value",
rag_corpus=vertex_rag_data.RagCorpus(name="name_value"),
)
mock_args.update(sample_request)
@@ -5416,20 +6233,20 @@ def test_create_rag_corpus_rest_flattened():
response_value._content = json_return_value.encode("UTF-8")
req.return_value = response_value
- client.create_rag_corpus(**mock_args)
+ client.update_rag_corpus(**mock_args)
# Establish that the underlying call was made with the expected
# request object values.
assert len(req.mock_calls) == 1
_, args, _ = req.mock_calls[0]
assert path_template.validate(
- "%s/v1beta1/{parent=projects/*/locations/*}/ragCorpora"
+ "%s/v1beta1/{rag_corpus.name=projects/*/locations/*/ragCorpora/*}"
% client.transport._host,
args[1],
)
-def test_create_rag_corpus_rest_flattened_error(transport: str = "rest"):
+def test_update_rag_corpus_rest_flattened_error(transport: str = "rest"):
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
transport=transport,
@@ -5438,14 +6255,13 @@ def test_create_rag_corpus_rest_flattened_error(transport: str = "rest"):
# Attempting to call a method with both a request object and flattened
# fields is an error.
with pytest.raises(ValueError):
- client.create_rag_corpus(
- vertex_rag_data_service.CreateRagCorpusRequest(),
- parent="parent_value",
+ client.update_rag_corpus(
+ vertex_rag_data_service.UpdateRagCorpusRequest(),
rag_corpus=vertex_rag_data.RagCorpus(name="name_value"),
)
-def test_create_rag_corpus_rest_error():
+def test_update_rag_corpus_rest_error():
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(), transport="rest"
)
@@ -8241,6 +9057,7 @@ def test_vertex_rag_data_service_base_transport():
# raise NotImplementedError.
methods = (
"create_rag_corpus",
+ "update_rag_corpus",
"get_rag_corpus",
"list_rag_corpora",
"delete_rag_corpus",
@@ -8546,6 +9363,9 @@ def test_vertex_rag_data_service_client_transport_session_collision(transport_na
session1 = client1.transport.create_rag_corpus._session
session2 = client2.transport.create_rag_corpus._session
assert session1 != session2
+ session1 = client1.transport.update_rag_corpus._session
+ session2 = client2.transport.update_rag_corpus._session
+ assert session1 != session2
session1 = client1.transport.get_rag_corpus._session
session2 = client2.transport.get_rag_corpus._session
assert session1 != session2
diff --git a/tests/unit/vertex_rag/test_rag_constants.py b/tests/unit/vertex_rag/test_rag_constants.py
index a525d2784f..cd2a74e30c 100644
--- a/tests/unit/vertex_rag/test_rag_constants.py
+++ b/tests/unit/vertex_rag/test_rag_constants.py
@@ -18,7 +18,17 @@
from google.cloud import aiplatform
-from vertexai.preview import rag
+from vertexai.preview.rag import (
+ EmbeddingModelConfig,
+ RagCorpus,
+ RagFile,
+ RagResource,
+ SlackChannelsSource,
+ SlackChannel,
+ JiraSource,
+ JiraQuery,
+ Weaviate,
+)
from google.cloud.aiplatform_v1beta1 import (
GoogleDriveSource,
RagFileChunkingConfig,
@@ -32,6 +42,7 @@
SlackSource as GapicSlackSource,
RagContexts,
RetrieveContextsResponse,
+ RagVectorDbConfig,
)
from google.cloud.aiplatform_v1beta1.types import api_auth
from google.protobuf import timestamp_pb2
@@ -47,6 +58,16 @@
TEST_RAG_CORPUS_RESOURCE_NAME = f"projects/{TEST_PROJECT_NUMBER}/locations/{TEST_REGION}/ragCorpora/{TEST_RAG_CORPUS_ID}"
# RagCorpus
+TEST_WEAVIATE_HTTP_ENDPOINT = "test.weaviate.com"
+TEST_WEAVIATE_COLLECTION_NAME = "test-collection"
+TEST_WEAVIATE_API_KEY_SECRET_VERSION = (
+ "projects/test-project/secrets/test-secret/versions/1"
+)
+TEST_WEAVIATE_CONFIG = Weaviate(
+ weaviate_http_endpoint=TEST_WEAVIATE_HTTP_ENDPOINT,
+ collection_name=TEST_WEAVIATE_COLLECTION_NAME,
+ api_key=TEST_WEAVIATE_API_KEY_SECRET_VERSION,
+)
TEST_GAPIC_RAG_CORPUS = GapicRagCorpus(
name=TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=TEST_CORPUS_DISPLAY_NAME,
@@ -57,15 +78,37 @@
TEST_PROJECT, TEST_REGION
)
)
-TEST_EMBEDDING_MODEL_CONFIG = rag.EmbeddingModelConfig(
+TEST_GAPIC_RAG_CORPUS_WEAVIATE = GapicRagCorpus(
+ name=TEST_RAG_CORPUS_RESOURCE_NAME,
+ display_name=TEST_CORPUS_DISPLAY_NAME,
+ description=TEST_CORPUS_DISCRIPTION,
+ rag_vector_db_config=RagVectorDbConfig(
+ weaviate=RagVectorDbConfig.Weaviate(
+ http_endpoint=TEST_WEAVIATE_HTTP_ENDPOINT,
+ collection_name=TEST_WEAVIATE_COLLECTION_NAME,
+ ),
+ api_auth=api_auth.ApiAuth(
+ api_key_config=api_auth.ApiAuth.ApiKeyConfig(
+ api_key_secret_version=TEST_WEAVIATE_API_KEY_SECRET_VERSION
+ ),
+ ),
+ ),
+)
+TEST_EMBEDDING_MODEL_CONFIG = EmbeddingModelConfig(
publisher_model="publishers/google/models/textembedding-gecko",
)
-TEST_RAG_CORPUS = rag.RagCorpus(
+TEST_RAG_CORPUS = RagCorpus(
name=TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=TEST_CORPUS_DISPLAY_NAME,
description=TEST_CORPUS_DISCRIPTION,
embedding_model_config=TEST_EMBEDDING_MODEL_CONFIG,
)
+TEST_RAG_CORPUS_WEAVIATE = RagCorpus(
+ name=TEST_RAG_CORPUS_RESOURCE_NAME,
+ display_name=TEST_CORPUS_DISPLAY_NAME,
+ description=TEST_CORPUS_DISCRIPTION,
+ vector_db=TEST_WEAVIATE_CONFIG,
+)
TEST_PAGE_TOKEN = "test-page-token"
# RagFiles
@@ -165,7 +208,7 @@
display_name=TEST_FILE_DISPLAY_NAME,
description=TEST_FILE_DESCRIPTION,
)
-TEST_RAG_FILE = rag.RagFile(
+TEST_RAG_FILE = RagFile(
name=TEST_RAG_FILE_RESOURCE_NAME,
display_name=TEST_FILE_DISPLAY_NAME,
description=TEST_FILE_DESCRIPTION,
@@ -183,15 +226,15 @@
TEST_SLACK_API_KEY_SECRET_VERSION_2 = (
"projects/test-project/secrets/test-secret/versions/2"
)
-TEST_SLACK_SOURCE = rag.SlackChannelsSource(
+TEST_SLACK_SOURCE = SlackChannelsSource(
channels=[
- rag.SlackChannel(
+ SlackChannel(
channel_id=TEST_SLACK_CHANNEL_ID,
api_key=TEST_SLACK_API_KEY_SECRET_VERSION,
start_time=TEST_SLACK_START_TIME,
end_time=TEST_SLACK_END_TIME,
),
- rag.SlackChannel(
+ SlackChannel(
channel_id=TEST_SLACK_CHANNEL_ID_2,
api_key=TEST_SLACK_API_KEY_SECRET_VERSION_2,
),
@@ -241,9 +284,9 @@
TEST_JIRA_API_KEY_SECRET_VERSION = (
"projects/test-project/secrets/test-secret/versions/1"
)
-TEST_JIRA_SOURCE = rag.JiraSource(
+TEST_JIRA_SOURCE = JiraSource(
queries=[
- rag.JiraQuery(
+ JiraQuery(
email=TEST_JIRA_EMAIL,
jira_projects=[TEST_JIRA_PROJECT],
custom_queries=[TEST_JIRA_CUSTOM_QUERY],
@@ -286,11 +329,11 @@
]
)
TEST_RETRIEVAL_RESPONSE = RetrieveContextsResponse(contexts=TEST_CONTEXTS)
-TEST_RAG_RESOURCE = rag.RagResource(
+TEST_RAG_RESOURCE = RagResource(
rag_corpus=TEST_RAG_CORPUS_RESOURCE_NAME,
rag_file_ids=[TEST_RAG_FILE_ID],
)
-TEST_RAG_RESOURCE_INVALID_NAME = rag.RagResource(
+TEST_RAG_RESOURCE_INVALID_NAME = RagResource(
rag_corpus="213lkj-1/23jkl/",
rag_file_ids=[TEST_RAG_FILE_ID],
)
diff --git a/tests/unit/vertex_rag/test_rag_data.py b/tests/unit/vertex_rag/test_rag_data.py
index 2b789d6513..fd243920b1 100644
--- a/tests/unit/vertex_rag/test_rag_data.py
+++ b/tests/unit/vertex_rag/test_rag_data.py
@@ -47,6 +47,21 @@ def create_rag_corpus_mock():
yield create_rag_corpus_mock
+@pytest.fixture
+def create_rag_corpus_mock_weaviate():
+ with mock.patch.object(
+ VertexRagDataServiceClient,
+ "create_rag_corpus",
+ ) as create_rag_corpus_mock_weaviate:
+ create_rag_corpus_lro_mock = mock.Mock(ga_operation.Operation)
+ create_rag_corpus_lro_mock.done.return_value = True
+ create_rag_corpus_lro_mock.result.return_value = (
+ tc.TEST_GAPIC_RAG_CORPUS_WEAVIATE
+ )
+ create_rag_corpus_mock_weaviate.return_value = create_rag_corpus_lro_mock
+ yield create_rag_corpus_mock_weaviate
+
+
@pytest.fixture
def list_rag_corpora_pager_mock():
with mock.patch.object(
@@ -141,6 +156,7 @@ def list_rag_files_pager_mock():
def rag_corpus_eq(returned_corpus, expected_corpus):
assert returned_corpus.name == expected_corpus.name
assert returned_corpus.display_name == expected_corpus.display_name
+ assert returned_corpus.vector_db.__eq__(expected_corpus.vector_db)
def rag_file_eq(returned_file, expected_file):
@@ -191,6 +207,15 @@ def test_create_corpus_success(self):
rag_corpus_eq(rag_corpus, tc.TEST_RAG_CORPUS)
+ @pytest.mark.usefixtures("create_rag_corpus_mock_weaviate")
+ def test_create_corpus_weaviate_success(self):
+ rag_corpus = rag.create_corpus(
+ display_name=tc.TEST_CORPUS_DISPLAY_NAME,
+ vector_db=tc.TEST_WEAVIATE_CONFIG,
+ )
+
+ rag_corpus_eq(rag_corpus, tc.TEST_RAG_CORPUS_WEAVIATE)
+
@pytest.mark.usefixtures("rag_data_client_mock_exception")
def test_create_corpus_failure(self):
with pytest.raises(RuntimeError) as e:
diff --git a/tests/unit/vertex_rag/test_rag_retrieval.py b/tests/unit/vertex_rag/test_rag_retrieval.py
index 5aec0bd72a..07b2180422 100644
--- a/tests/unit/vertex_rag/test_rag_retrieval.py
+++ b/tests/unit/vertex_rag/test_rag_retrieval.py
@@ -76,6 +76,7 @@ def test_retrieval_query_rag_resources_success(self):
text=tc.TEST_QUERY_TEXT,
similarity_top_k=2,
vector_distance_threshold=0.5,
+ vector_search_alpha=0.5,
)
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
diff --git a/tests/unit/vertex_ray/test_cluster_init.py b/tests/unit/vertex_ray/test_cluster_init.py
index a4ddd5f818..864b6f6f18 100644
--- a/tests/unit/vertex_ray/test_cluster_init.py
+++ b/tests/unit/vertex_ray/test_cluster_init.py
@@ -384,6 +384,7 @@ def test_create_ray_cluster_2_pools_custom_images_success(
head_node_type=tc.ClusterConstants.TEST_HEAD_NODE_TYPE_2_POOLS_CUSTOM_IMAGE,
worker_node_types=tc.ClusterConstants.TEST_WORKER_NODE_TYPES_2_POOLS_CUSTOM_IMAGE,
network=tc.ProjectConstants.TEST_VPC_NETWORK,
+ reserved_ip_ranges=["vertex-dedicated-range"],
cluster_name=tc.ClusterConstants.TEST_VERTEX_RAY_PR_ID,
)
diff --git a/tests/unit/vertex_ray/test_constants.py b/tests/unit/vertex_ray/test_constants.py
index 64018e930d..26537391bf 100644
--- a/tests/unit/vertex_ray/test_constants.py
+++ b/tests/unit/vertex_ray/test_constants.py
@@ -51,12 +51,17 @@
from google.cloud.aiplatform_v1beta1.types.service_networking import (
PscInterfaceConfig,
)
+import ray
import pytest
rovminversion = pytest.mark.skipif(
sys.version_info > (3, 10), reason="Requires python3.10 or lower"
)
+# TODO(b/363340317)
+xgbversion = pytest.mark.skipif(
+ ray.__version__ != "2.9.3", reason="Requires xgboost 1.7 or higher"
+)
@dataclasses.dataclass(frozen=True)
@@ -347,6 +352,7 @@ class ClusterConstants:
),
psc_interface_config=None,
network=ProjectConstants.TEST_VPC_NETWORK,
+ reserved_ip_ranges=["vertex-dedicated-range"],
)
# Responses
TEST_RESOURCE_POOL_2.replica_count = 1
@@ -366,6 +372,7 @@ class ClusterConstants:
network_attachment=TEST_PSC_NETWORK_ATTACHMENT
),
network=None,
+ reserved_ip_ranges=None,
resource_runtime=ResourceRuntime(
access_uris={
"RAY_DASHBOARD_URI": TEST_VERTEX_RAY_DASHBOARD_ADDRESS,
@@ -386,6 +393,7 @@ class ClusterConstants:
),
),
network=ProjectConstants.TEST_VPC_NETWORK,
+ reserved_ip_ranges=["vertex-dedicated-range"],
resource_runtime=ResourceRuntime(
access_uris={
"RAY_DASHBOARD_URI": TEST_VERTEX_RAY_DASHBOARD_ADDRESS,
@@ -399,6 +407,7 @@ class ClusterConstants:
python_version="3.10",
ray_version="2.9",
network=ProjectConstants.TEST_VPC_NETWORK,
+ reserved_ip_ranges=None,
service_account=None,
state="RUNNING",
head_node_type=TEST_HEAD_NODE_TYPE_1_POOL,
@@ -412,6 +421,7 @@ class ClusterConstants:
python_version="3.10",
ray_version="2.9",
network="",
+ reserved_ip_ranges="",
service_account=None,
state="RUNNING",
head_node_type=TEST_HEAD_NODE_TYPE_2_POOLS,
@@ -424,6 +434,7 @@ class ClusterConstants:
TEST_CLUSTER_CUSTOM_IMAGE = Cluster(
cluster_resource_name=TEST_VERTEX_RAY_PR_ADDRESS,
network=ProjectConstants.TEST_VPC_NETWORK,
+ reserved_ip_ranges=["vertex-dedicated-range"],
service_account=None,
state="RUNNING",
head_node_type=TEST_HEAD_NODE_TYPE_2_POOLS_CUSTOM_IMAGE,
@@ -438,6 +449,7 @@ class ClusterConstants:
python_version="3.10",
ray_version="2.9",
network="",
+ reserved_ip_ranges="",
service_account=ProjectConstants.TEST_SERVICE_ACCOUNT,
state="RUNNING",
head_node_type=TEST_HEAD_NODE_TYPE_1_POOL,
diff --git a/tests/unit/vertex_ray/test_ray_prediction.py b/tests/unit/vertex_ray/test_ray_prediction.py
index fdab42c425..c5d41c3725 100644
--- a/tests/unit/vertex_ray/test_ray_prediction.py
+++ b/tests/unit/vertex_ray/test_ray_prediction.py
@@ -41,7 +41,6 @@
import numpy as np
import pytest
import ray
-from ray.train import xgboost as ray_xgboost
import tensorflow as tf
import torch
import xgboost
@@ -90,9 +89,14 @@ def ray_sklearn_checkpoint():
@pytest.fixture()
def ray_xgboost_checkpoint():
- model = test_prediction_utils.get_xgboost_model()
- checkpoint = ray_xgboost.XGBoostCheckpoint.from_model(model.get_booster())
- return checkpoint
+ if ray.__version__ == "2.9.3":
+ from ray.train import xgboost as ray_xgboost
+
+ model = test_prediction_utils.get_xgboost_model()
+ checkpoint = ray_xgboost.XGBoostCheckpoint.from_model(model.get_booster())
+ return checkpoint
+ else:
+ return None
@pytest.fixture()
@@ -374,6 +378,7 @@ def test_register_sklearnartifact_uri_not_gcs_uri_raise_error(
assert ve.match(regexp=r".*'artifact_uri' should " "start with 'gs://'.*")
# XGBoost Tests
+ @tc.xgbversion
@tc.rovminversion
def test_convert_checkpoint_to_xgboost_raise_exception(
self, ray_checkpoint_from_dict
@@ -392,6 +397,7 @@ def test_convert_checkpoint_to_xgboost_raise_exception(
"ray.train.xgboost.XGBoostCheckpoint .*"
)
+ @tc.xgbversion
def test_convert_checkpoint_to_xgboost_model_succeed(
self, ray_xgboost_checkpoint
) -> None:
@@ -406,6 +412,7 @@ def test_convert_checkpoint_to_xgboost_model_succeed(
y_pred = model.predict(xgboost.DMatrix(np.array([[1, 2]])))
assert y_pred[0] is not None
+ @tc.xgbversion
def test_register_xgboost_succeed(
self,
ray_xgboost_checkpoint,
@@ -429,6 +436,7 @@ def test_register_xgboost_succeed(
pickle_dump.assert_called_once()
gcs_utils_upload_to_gcs.assert_called_once()
+ @tc.xgbversion
def test_register_xgboost_initialized_succeed(
self,
ray_xgboost_checkpoint,
@@ -455,6 +463,7 @@ def test_register_xgboost_initialized_succeed(
pickle_dump.assert_called_once()
gcs_utils_upload_to_gcs.assert_called_once()
+ @tc.xgbversion
def test_register_xgboostartifact_uri_is_none_raise_error(
self, ray_xgboost_checkpoint
) -> None:
@@ -467,6 +476,7 @@ def test_register_xgboostartifact_uri_is_none_raise_error(
)
assert ve.match(regexp=r".*'artifact_uri' should " "start with 'gs://'.*")
+ @tc.xgbversion
def test_register_xgboostartifact_uri_not_gcs_uri_raise_error(
self, ray_xgboost_checkpoint
) -> None:
diff --git a/tests/unit/vertexai/conftest.py b/tests/unit/vertexai/conftest.py
index ecc50c2df7..73084cb412 100644
--- a/tests/unit/vertexai/conftest.py
+++ b/tests/unit/vertexai/conftest.py
@@ -65,10 +65,7 @@
_TEST_OPTIMIZED_FV2,
_TEST_PSC_OPTIMIZED_FOS,
)
-from google.cloud.logging import Logger
-from pyfakefs import fake_filesystem_unittest
import pytest
-import tensorflow.saved_model as tf_saved_model
_TEST_PROJECT = "test-project"
_TEST_PROJECT_NUMBER = "12345678"
@@ -142,14 +139,6 @@ def google_auth_mock():
yield auth_mock
-@pytest.fixture
-def mock_filesystem():
- with fake_filesystem_unittest.Patcher() as patcher:
- patcher.setUp()
- yield patcher.fs
- patcher.tearDown()
-
-
@pytest.fixture
def mock_storage_blob(mock_filesystem):
"""Mocks the storage Blob API.
@@ -292,19 +281,6 @@ def mock_uuid():
yield uuid_mock
-@pytest.fixture
-def mock_tf_saved_model_load():
- with mock.patch.object(tf_saved_model, "load") as load_mock:
- yield load_mock
-
-
-@pytest.fixture
-def mock_cloud_logging_list_entries():
- with mock.patch.object(Logger, "list_entries") as list_entries_mock:
- list_entries_mock.return_value = []
- yield list_entries_mock
-
-
@pytest.fixture
def base_logger_mock():
with patch.object(
diff --git a/tests/unit/vertexai/test_batch_prediction.py b/tests/unit/vertexai/test_batch_prediction.py
index 485da92ae6..2970e44177 100644
--- a/tests/unit/vertexai/test_batch_prediction.py
+++ b/tests/unit/vertexai/test_batch_prediction.py
@@ -50,6 +50,8 @@
_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME = "projects/123/locations/us-central1/models/456"
_TEST_PALM_MODEL_NAME = "text-bison"
_TEST_PALM_MODEL_RESOURCE_NAME = f"publishers/google/models/{_TEST_PALM_MODEL_NAME}"
+_TEST_LLAMA_MODEL_NAME = "llama3-405b-instruct-maas"
+_TEST_LLAMA_MODEL_RESOURCE_NAME = f"publishers/meta/models/{_TEST_LLAMA_MODEL_NAME}"
_TEST_GCS_INPUT_URI = "gs://test-bucket/test-input.jsonl"
_TEST_GCS_INPUT_URI_2 = "gs://test-bucket/test-input-2.jsonl"
@@ -127,6 +129,23 @@ def get_batch_prediction_job_with_gcs_output_mock():
yield get_job_mock
+@pytest.fixture
+def get_batch_prediction_job_with_llama_model_mock():
+ with mock.patch.object(
+ job_service_client.JobServiceClient, "get_batch_prediction_job"
+ ) as get_job_mock:
+ get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob(
+ name=_TEST_BATCH_PREDICTION_JOB_NAME,
+ display_name=_TEST_DISPLAY_NAME,
+ model=_TEST_LLAMA_MODEL_RESOURCE_NAME,
+ state=_TEST_JOB_STATE_SUCCESS,
+ output_info=gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo(
+ gcs_output_directory=_TEST_GCS_OUTPUT_PREFIX
+ ),
+ )
+ yield get_job_mock
+
+
@pytest.fixture
def get_batch_prediction_job_with_tuned_gemini_model_mock():
with mock.patch.object(
@@ -252,6 +271,16 @@ def test_init_batch_prediction_job(
name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY
)
+ def test_init_batch_prediction_job_with_llama_model(
+ self,
+ get_batch_prediction_job_with_llama_model_mock,
+ ):
+ batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID)
+
+ get_batch_prediction_job_with_llama_model_mock.assert_called_once_with(
+ name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY
+ )
+
def test_init_batch_prediction_job_with_tuned_gemini_model(
self,
get_batch_prediction_job_with_tuned_gemini_model_mock,
@@ -447,6 +476,39 @@ def test_submit_batch_prediction_job_with_bq_input_without_output_uri_prefix(
timeout=None,
)
+ def test_submit_batch_prediction_job_with_llama_model(
+ self,
+ create_batch_prediction_job_mock,
+ ):
+ job = batch_prediction.BatchPredictionJob.submit(
+ source_model=_TEST_LLAMA_MODEL_RESOURCE_NAME,
+ input_dataset=_TEST_BQ_INPUT_URI,
+ )
+
+ assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB
+
+ expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob(
+ display_name=_TEST_DISPLAY_NAME,
+ model=_TEST_LLAMA_MODEL_RESOURCE_NAME,
+ input_config=gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig(
+ instances_format="bigquery",
+ bigquery_source=gca_io_compat.BigQuerySource(
+ input_uri=_TEST_BQ_INPUT_URI
+ ),
+ ),
+ output_config=gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig(
+ bigquery_destination=gca_io_compat.BigQueryDestination(
+ output_uri=_TEST_BQ_OUTPUT_PREFIX
+ ),
+ predictions_format="bigquery",
+ ),
+ )
+ create_batch_prediction_job_mock.assert_called_once_with(
+ parent=_TEST_PARENT,
+ batch_prediction_job=expected_gapic_batch_prediction_job,
+ timeout=None,
+ )
+
@pytest.mark.usefixtures("create_batch_prediction_job_mock")
def test_submit_batch_prediction_job_with_tuned_model(
self,
@@ -467,7 +529,8 @@ def test_submit_batch_prediction_job_with_invalid_source_model(self):
with pytest.raises(
ValueError,
match=(
- f"Model '{_TEST_PALM_MODEL_RESOURCE_NAME}' is not a Generative AI model."
+ "Abbreviated model names are only supported for Gemini models. "
+ "Please provide the full publisher model name."
),
):
batch_prediction.BatchPredictionJob.submit(
@@ -475,6 +538,19 @@ def test_submit_batch_prediction_job_with_invalid_source_model(self):
input_dataset=_TEST_GCS_INPUT_URI,
)
+ def test_submit_batch_prediction_job_with_invalid_abbreviated_model_name(self):
+ with pytest.raises(
+ ValueError,
+ match=(
+ "Abbreviated model names are only supported for Gemini models. "
+ "Please provide the full publisher model name."
+ ),
+ ):
+ batch_prediction.BatchPredictionJob.submit(
+ source_model=_TEST_LLAMA_MODEL_NAME,
+ input_dataset=_TEST_GCS_INPUT_URI,
+ )
+
@pytest.mark.usefixtures("get_non_gemini_model_mock")
def test_submit_batch_prediction_job_with_non_gemini_tuned_model(self):
with pytest.raises(
diff --git a/tests/unit/vertexai/test_evaluation.py b/tests/unit/vertexai/test_evaluation.py
index c386419305..f5e0879599 100644
--- a/tests/unit/vertexai/test_evaluation.py
+++ b/tests/unit/vertexai/test_evaluation.py
@@ -1176,15 +1176,14 @@ def test_evaluate_response_column_and_model_provided(self):
ValueError,
match=re.escape(
(
- "The `model` parameter or `baseline_model` in pairwise"
- " metric is specified, but the evaluation `dataset` contains"
- " model response column or baseline model response column"
- " `response` to perform bring-your-own-response(BYOR)"
- " evaluation. If you would like to perform evaluation using the"
- " dataset with the existing model response column or or"
- " baseline model response column `response`, please remove"
- " `model` parameter in `EvalTask.evaluate` function or keep"
- " `baseline_model` as None in pairwise metric."
+ "The `model` parameter or `baseline_model` in pairwise metric is"
+ " specified, but the evaluation `dataset` contains model response"
+ " column or baseline model response column `response`"
+ " to perform bring-your-own-response(BYOR) evaluation. If you would"
+ " like to perform evaluation using the dataset with the"
+ " existing model response column or or baseline model response column"
+ " `response`, please remove `model` parameter in `EvalTask.evaluate()`"
+ " function or `baseline_model` in `PairwiseMetric`."
)
),
):
@@ -1202,16 +1201,15 @@ def test_evaluate_baseline_response_column_and_baseline_model_provided(self):
ValueError,
match=re.escape(
(
- "The `model` parameter or `baseline_model` in pairwise"
- " metric is specified, but the evaluation `dataset` contains"
- " model response column or baseline model response column"
- " `baseline_model_response` to perform"
- " bring-your-own-response(BYOR) evaluation. If you would"
+ "The `model` parameter or `baseline_model` in pairwise metric is"
+ " specified, but the evaluation `dataset` contains model response"
+ " column or baseline model response column `baseline_model_response`"
+ " to perform bring-your-own-response(BYOR) evaluation. If you would"
" like to perform evaluation using the dataset with the"
- " existing model response column or or baseline model"
- " response column `baseline_model_response`, please remove"
- " `model` parameter in `EvalTask.evaluate` function or keep"
- " `baseline_model` as None in pairwise metric."
+ " existing model response column or or baseline model response column"
+ " `baseline_model_response`, please remove `model` parameter in"
+ " `EvalTask.evaluate()` function or `baseline_model` in"
+ " `PairwiseMetric`."
)
),
):
@@ -1227,8 +1225,8 @@ def test_evaluate_response_column_and_model_not_provided(self):
KeyError,
match=re.escape(
(
- "Required column `response` not found in the eval dataset."
- " The columns in the provided dataset are ['prompt']."
+ "Required column `response` not found in the evaluation dataset."
+ " The columns in the evaluation dataset are ['prompt']."
)
),
):
@@ -1248,7 +1246,7 @@ def test_evaluate_baseline_response_column_and_baseline_model_not_provided(
match=re.escape(
(
"Required column `baseline_model_response` not found in the"
- " eval dataset. The columns in the provided dataset are"
+ " evaluation dataset. The columns in the evaluation dataset are"
" ['prompt', 'response']."
)
),
@@ -1301,7 +1299,7 @@ def test_evaluate_pairwise_metrics_with_multiple_baseline_models(self):
)
with pytest.raises(
ValueError,
- match="Not all PairwiseMetric instances have the same baseline_model",
+ match="Not all `PairwiseMetric` instances have the same `baseline_model`",
):
test_eval_task.evaluate()
@@ -1314,15 +1312,14 @@ def test_evaluate_invalid_model_and_dataset_input(self):
ValueError,
match=re.escape(
(
- "The `model` parameter or `baseline_model` in pairwise metric"
- " is specified, but the evaluation `dataset` contains model"
- " response column or baseline model response column `response`"
- " to perform bring-your-own-response(BYOR) evaluation. If you"
- " would like to perform evaluation using the dataset with the"
- " existing model response column or or baseline model response"
- " column `response`, please remove `model` parameter in"
- " `EvalTask.evaluate` function or keep `baseline_model` as None"
- " in pairwise metric."
+ "The `model` parameter or `baseline_model` in pairwise metric is"
+ " specified, but the evaluation `dataset` contains model response"
+ " column or baseline model response column `response`"
+ " to perform bring-your-own-response(BYOR) evaluation. If you would"
+ " like to perform evaluation using the dataset with the"
+ " existing model response column or or baseline model response column"
+ " `response`, please remove `model` parameter in `EvalTask.evaluate()`"
+ " function or `baseline_model` in `PairwiseMetric`."
)
),
):
@@ -1339,15 +1336,14 @@ def test_unmatched_metric_column_mapping(self):
ValueError,
match=re.escape(
(
- "The `model` parameter or `baseline_model` in pairwise metric"
- " is specified, but the evaluation `dataset` contains model"
- " response column or baseline model response column `response`"
- " to perform bring-your-own-response(BYOR) evaluation. If you"
- " would like to perform evaluation using the dataset with the"
- " existing model response column or or baseline model response"
- " column `response`, please remove `model` parameter in"
- " `EvalTask.evaluate` function or keep `baseline_model` as None"
- " in pairwise metric."
+ "The `model` parameter or `baseline_model` in pairwise metric is"
+ " specified, but the evaluation `dataset` contains model response"
+ " column or baseline model response column `response`"
+ " to perform bring-your-own-response(BYOR) evaluation. If you would"
+ " like to perform evaluation using the dataset with the"
+ " existing model response column or or baseline model response column"
+ " `response`, please remove `model` parameter in `EvalTask.evaluate()`"
+ " function or `baseline_model` in `PairwiseMetric`."
)
),
):
diff --git a/tests/unit/vertexai/test_tokenization.py b/tests/unit/vertexai/test_tokenization.py
index 861c225211..f51be60360 100644
--- a/tests/unit/vertexai/test_tokenization.py
+++ b/tests/unit/vertexai/test_tokenization.py
@@ -20,11 +20,19 @@
import tempfile
from typing import List
from unittest import mock
-from vertexai.generative_models import Content, Image, Part
+from vertexai.generative_models import (
+ Content,
+ Image,
+ Part,
+ FunctionDeclaration,
+ Tool,
+)
+
from vertexai.tokenization import _tokenizer_loading
from vertexai.tokenization._tokenizers import (
CountTokensResult,
TokensInfo,
+ _TextsAccumulator,
get_tokenizer_for_model,
)
import pytest
@@ -32,7 +40,11 @@
from sentencepiece import sentencepiece_model_pb2, sentencepiece_pb2
from google.cloud.aiplatform_v1beta1.types import (
content as gapic_content_types,
+ tool as gapic_tool_types,
+ openapi,
)
+from google.protobuf import struct_pb2
+
_TOKENIZER_NAME = "google/gemma"
_MODEL_NAME = "gemini-1.5-pro"
@@ -170,14 +182,8 @@
_LIST_OF_UNSUPPORTED_CONTENTS = [
- gapic_content_types.Part(
- video_metadata=gapic_content_types.VideoMetadata(start_offset="10s")
- ),
Part.from_uri("gs://bucket/object", mime_type="mime_type"),
Part.from_data(b"inline_data_bytes", mime_type="mime_type"),
- Part.from_dict({"function_call": {"name": "test_function_call"}}),
- Part.from_dict({"function_response": {"name": "test_function_response"}}),
- Part.from_dict({"video_metadata": {"start_offset": "10s"}}),
Content(
role="user",
parts=[Part.from_uri("gs://bucket/object", mime_type="mime_type")],
@@ -186,20 +192,57 @@
role="user",
parts=[Part.from_data(b"inline_data_bytes", mime_type="mime_type")],
),
- Content(
- role="user",
- parts=[Part.from_dict({"function_call": {"name": "test_function_call"}})],
- ),
- Content(
- role="user",
- parts=[
- Part.from_dict({"function_response": {"name": "test_function_response"}})
- ],
- ),
- Content(
- role="user",
- parts=[Part.from_dict({"video_metadata": {"start_offset": "10s"}})],
- ),
+]
+
+_NESTED_STRUCT_1 = struct_pb2.Struct(
+ fields={"string_key": struct_pb2.Value(string_value="value1")}
+)
+_NESTED_STRUCT_2 = struct_pb2.Struct(
+ fields={
+ "list_key": struct_pb2.Value(
+ list_value=struct_pb2.ListValue(
+ values=[struct_pb2.Value(string_value="value2")]
+ )
+ )
+ }
+)
+_NESTED_STRUCT_3 = struct_pb2.Struct(
+ fields={
+ "struct_key": struct_pb2.Value(
+ struct_value=struct_pb2.Struct(
+ fields={"string_key": struct_pb2.Value(string_value="value3")}
+ )
+ )
+ }
+)
+_STRUCT = struct_pb2.Struct(
+ fields={
+ "string_key": struct_pb2.Value(string_value="value4"),
+ "list_key": struct_pb2.Value(
+ list_value=struct_pb2.ListValue(
+ values=[struct_pb2.Value(string_value="value5")]
+ )
+ ),
+ "struct_key1": struct_pb2.Value(struct_value=_NESTED_STRUCT_1),
+ "struct_key2": struct_pb2.Value(struct_value=_NESTED_STRUCT_2),
+ "struct_key3": struct_pb2.Value(struct_value=_NESTED_STRUCT_3),
+ }
+)
+_STRUCT_TEXTS = [
+ "struct_key3",
+ "struct_key1",
+ "list_key",
+ "string_key",
+ "struct_key2",
+ "struct_key",
+ "string_key",
+ "value3",
+ "string_key",
+ "value1",
+ "value5",
+ "value4",
+ "list_key",
+ "value2",
]
@@ -218,7 +261,15 @@ def mock_sp_processor():
def _encode_as_ids(contents: List[str]):
- return [_TOKENS_MAP[content]["ids"] for content in contents]
+ return [
+ (
+ _TOKENS_MAP[content]["ids"]
+ if content in _TOKENS_MAP
+ # Returns stable ids arrary when content is not predefined.
+ else [0] * len(content.split(" "))
+ )
+ for content in contents
+ ]
def _build_sentencepiece_text(content: str):
@@ -254,6 +305,21 @@ def mock_hashlib_sha256():
yield sha256_mock
+def get_current_weather(location: str, unit: str = "centigrade"):
+ """Gets weather in the specified location.
+ Args:
+ location: The location for which to get the weather.
+ unit: Optional. Temperature unit. Can be Centigrade or Fahrenheit. Defaults to Centigrade.
+ Returns:
+ The weather information as a dict.
+ """
+ return dict(
+ location=location,
+ unit=unit,
+ weather="Super nice, but maybe a bit hot.",
+ )
+
+
@pytest.mark.usefixtures("mock_requests_get", "mock_hashlib_sha256")
class TestTokenizers:
"""Unit tests for the tokenizers."""
@@ -269,7 +335,7 @@ def test_valid_contents_type_for_bytes_token_type(self, mock_sp_processor):
_tokenizer_loading.get_sentencepiece.cache_clear()
assert get_tokenizer_for_model(_MODEL_NAME).compute_tokens(
[_SENTENCE_4]
- ).token_info_list == (
+ ).tokens_info == (
[TokensInfo(token_ids=[0, 1], tokens=[b"A", b"B"], role="user")]
)
assert get_tokenizer_for_model(_MODEL_NAME).count_tokens(
@@ -304,10 +370,9 @@ def testcompute_tokens_valid_contents_type(
self, mock_sp_processor, contents, encode_input, encode_output, roles
):
_tokenizer_loading.get_sentencepiece.cache_clear()
-
assert (
get_tokenizer_for_model(_MODEL_NAME).compute_tokens(contents)
- ).token_info_list == (
+ ).tokens_info == (
[
TokensInfo(token_ids=output["ids"], tokens=output["tokens"], role=role)
for role, output in zip(roles, encode_output)
@@ -334,6 +399,54 @@ def test_count_tokens_unsupported_contents_type(
get_tokenizer_for_model(_MODEL_NAME).count_tokens(contents)
e.match("Tokenizers do not support non-text content types.")
+ def test_system_instruction_count_tokens(self, mock_sp_processor):
+ _tokenizer_loading.get_sentencepiece.cache_clear()
+ tokenizer = get_tokenizer_for_model(_MODEL_NAME)
+ result = tokenizer.count_tokens(
+ ["hello world"], system_instruction=["You are a chatbot."]
+ )
+ assert result.total_tokens == 6
+
+ def test_function_call_count_tokens(self, mock_sp_processor):
+ tokenizer = get_tokenizer_for_model(_MODEL_NAME)
+ part = Part._from_gapic(
+ gapic_content_types.Part(
+ function_call=gapic_tool_types.FunctionCall(
+ name="test_function_call",
+ args=_STRUCT,
+ ),
+ )
+ )
+
+ result = tokenizer.count_tokens(part)
+
+ assert result.total_tokens
+
+ def test_function_response_count_tokens(self, mock_sp_processor):
+ tokenizer = get_tokenizer_for_model(_MODEL_NAME)
+ part = Part._from_gapic(
+ gapic_content_types.Part(
+ function_response=gapic_tool_types.FunctionResponse(
+ name="test_function_response", response=_STRUCT
+ ),
+ )
+ )
+
+ result = tokenizer.count_tokens(part)
+
+ assert result.total_tokens
+
+ def test_tools_count_tokens(self, mock_sp_processor):
+ tokenizer = get_tokenizer_for_model(_MODEL_NAME)
+ get_current_weather_func = FunctionDeclaration.from_func(get_current_weather)
+ weather_tool = Tool(
+ function_declarations=[get_current_weather_func],
+ )
+
+ result = tokenizer.count_tokens(contents=[], tools=[weather_tool])
+
+ assert result.total_tokens
+
def test_image_mime_types(self, mock_sp_processor):
# Importing external library lazily to reduce the scope of import errors.
from PIL import Image as PIL_Image # pylint: disable=g-import-not-at-top
@@ -417,3 +530,107 @@ def test_load_model_proto_from_corrupted_cache(self, hash_mock, mock_requests_ge
mock_requests_get.assert_called_once()
with open(cache_path, "rb") as f:
assert f.read() == _TOKENIZER_MODEL.SerializeToString()
+
+
+class TestTextsAccumulator:
+ def setup_method(self):
+ self.texts_accumulator = _TextsAccumulator()
+
+ def test_function_declaration_unsupported_field(self):
+ function_declaration = gapic_tool_types.FunctionDeclaration(
+ parameters=openapi.Schema(nullable=True)
+ )
+ with pytest.raises(ValueError):
+ self.texts_accumulator.add_tool(
+ gapic_tool_types.Tool(function_declarations=[function_declaration])
+ )
+
+ def test_function_call_unsupported_field(self):
+ function_call = gapic_tool_types.FunctionCall(
+ name="test_function_call",
+ args=struct_pb2.Struct(
+ fields={
+ "bool_key": struct_pb2.Value(bool_value=True),
+ }
+ ),
+ )
+ with pytest.raises(ValueError):
+ self.texts_accumulator.add_function_call(function_call)
+
+ def test_function_response_unsupported_field(self):
+ function_call = gapic_tool_types.FunctionResponse(
+ name="test_function_response",
+ response=struct_pb2.Struct(
+ fields={
+ "bool_key": struct_pb2.Value(bool_value=True),
+ }
+ ),
+ )
+ with pytest.raises(ValueError):
+ self.texts_accumulator.add_function_response(function_call)
+
+ def test_function_declaration(self):
+ schema1 = openapi.Schema(
+ format="schema1_format", description="schema1_description"
+ )
+ schema2 = openapi.Schema(
+ format="schema2_format", description="schema2_description"
+ )
+ example = struct_pb2.Value(string_value="value1")
+ function_declaration = gapic_tool_types.FunctionDeclaration(
+ name="function_declaration_name",
+ description="function_declaration_description",
+ parameters=openapi.Schema(
+ format="schema_format",
+ description="schema_description",
+ enum=["schema_enum1", "schema_enum2"],
+ required=["schema_required1", "schema_required2"],
+ items=schema1,
+ properties={"property_key": schema2},
+ example=example,
+ ),
+ )
+
+ self.texts_accumulator.add_tool(
+ gapic_tool_types.Tool(function_declarations=[function_declaration])
+ )
+ assert self.texts_accumulator.get_texts() == [
+ "function_declaration_name",
+ "function_declaration_description",
+ "schema_format",
+ "schema_description",
+ "schema_enum1",
+ "schema_enum2",
+ "schema_required1",
+ "schema_required2",
+ "schema1_format",
+ "schema1_description",
+ "property_key",
+ "schema2_format",
+ "schema2_description",
+ "value1",
+ ]
+
+ def test_function_call(self):
+ function_call = gapic_tool_types.FunctionCall(
+ name="test_function_call",
+ args=_STRUCT,
+ )
+
+ self.texts_accumulator.add_function_call(function_call)
+
+ assert (
+ self.texts_accumulator.get_texts() == ["test_function_call"] + _STRUCT_TEXTS
+ )
+
+ def test_function_response(self):
+ function_response = gapic_tool_types.FunctionResponse(
+ name="test_function_response", response=_STRUCT
+ )
+
+ self.texts_accumulator.add_function_response(function_response)
+
+ assert (
+ self.texts_accumulator.get_texts()
+ == ["test_function_response"] + _STRUCT_TEXTS
+ )
diff --git a/vertexai/batch_prediction/_batch_prediction.py b/vertexai/batch_prediction/_batch_prediction.py
index c7cafc7543..2ca7d7b85c 100644
--- a/vertexai/batch_prediction/_batch_prediction.py
+++ b/vertexai/batch_prediction/_batch_prediction.py
@@ -33,6 +33,7 @@
_LOGGER = aiplatform_base.Logger(__name__)
_GEMINI_MODEL_PATTERN = r"publishers/google/models/gemini"
+_LLAMA_MODEL_PATTERN = r"publishers/meta/models/llama"
_GEMINI_TUNED_MODEL_PATTERN = r"^projects/[0-9]+?/locations/[0-9a-z-]+?/models/[0-9]+?$"
@@ -272,13 +273,20 @@ def _reconcile_model_name(cls, model_name: str) -> str:
if "/" not in model_name:
# model name (e.g., gemini-1.0-pro)
- model_name = "publishers/google/models/" + model_name
+ if model_name.startswith("gemini"):
+ model_name = "publishers/google/models/" + model_name
+ else:
+ raise ValueError(
+ "Abbreviated model names are only supported for Gemini models. "
+ "Please provide the full publisher model name."
+ )
elif model_name.startswith("models/"):
# publisher model name (e.g., models/gemini-1.0-pro)
model_name = "publishers/google/" + model_name
elif (
# publisher model full name
not model_name.startswith("publishers/google/models/")
+ and not model_name.startswith("publishers/meta/models/")
# tuned model full resource name
and not re.search(_GEMINI_TUNED_MODEL_PATTERN, model_name)
):
@@ -302,6 +310,10 @@ def _is_genai_model(cls, model_name: str) -> bool:
# Model is a tuned Gemini model.
return True
+ if re.search(_LLAMA_MODEL_PATTERN, model_name):
+ # Model is a Llama3 model.
+ return True
+
return False
@classmethod
diff --git a/vertexai/evaluation/_base.py b/vertexai/evaluation/_base.py
index 60696f8ed2..863ab0b6a1 100644
--- a/vertexai/evaluation/_base.py
+++ b/vertexai/evaluation/_base.py
@@ -76,9 +76,10 @@ def validate_dataset_column(self, column_name: str) -> None:
not in self.dataset.columns
):
raise KeyError(
- f"Required column `{self.metric_column_mapping.get(column_name, column_name)}`"
- " not found in the eval dataset. The columns in the provided dataset"
- f" are {list(self.dataset.columns)}."
+ "Required column"
+ f" `{self.metric_column_mapping.get(column_name, column_name)}`"
+ " not found in the evaluation dataset. The columns in the"
+ f" evaluation dataset are {list(self.dataset.columns)}."
)
diff --git a/vertexai/evaluation/_evaluation.py b/vertexai/evaluation/_evaluation.py
index cfce0f3c07..ab878b1822 100644
--- a/vertexai/evaluation/_evaluation.py
+++ b/vertexai/evaluation/_evaluation.py
@@ -113,12 +113,13 @@ def _validate_metric_column_map(
not in evaluation_run_config.dataset.columns
):
raise ValueError(
- "Cannot find appropriate column name mapping to help"
- " fill metric prompt template with evaluation dataset. "
- "Please provide key-value pair in `metric_column_mapping`"
- " parameter from 'EvalTask.evaluate' function. Key is"
- f" '{variable}', value should be found in evaluation"
- f" dataset columns"
+ f"Cannot find the `{variable}` column in the evaluation"
+ " dataset to fill the metric prompt template for"
+ f" `{str(metric)}` metric. Please check if the column is"
+ " present in the evaluation dataset, or provide a"
+ " key-value pair in `metric_column_mapping` parameter"
+ " of `EvalTask` to map it to a different column name."
+ " The evaluation dataset columns are"
f" {list(evaluation_run_config.dataset.columns)}."
)
@@ -158,10 +159,11 @@ def _compute_custom_metrics(
] = metric_output[custom_metric.name]
except KeyError:
raise KeyError(
- f"Custom metric score `{custom_metric.name}` not found in the metric"
- f" output {metric_output}. Please make sure the custom metric"
- " function is valid, and the output dictionary uses"
- f" `{custom_metric.name}` as the key for metric value."
+ f"Custom metric score `{custom_metric.name}` not found in"
+ f" the metric output {metric_output}. Please make sure the"
+ " custom metric function is valid, and the output"
+ f" dictionary uses `{custom_metric.name}` as the key for"
+ " metric score."
)
# Include additional metric results like explanation.
for key, value in metric_output.items():
@@ -224,8 +226,7 @@ def _aggregate_summary_metrics(
].std()
except (ValueError, KeyError) as e:
_LOGGER.warning(
- f"Failed to compute metric statistics for {metric}. This metric"
- " output contains error from the Autorater.\n"
+ f"Failed to compute metric statistics for `{metric}` metric."
f"{type(e).__name__}: {e}"
)
continue
@@ -414,12 +415,13 @@ def _run_model_inference(
] = response_column_name
else:
raise ValueError(
- "Missing required input column to start model inference."
+ "Missing required input `prompt` column to start model inference."
" Please provide a `prompt_template` parameter in"
" `EvalTask.evaluate()` function if you want to assemble a"
- " `prompt` with variables from the dataset, or provide a"
+ " `prompt` column with variables from the dataset, or provide a"
" `prompt` column in dataset to directly use as input to"
- " the model."
+ " the model. Mappings in `metric_column_mapping` do not"
+ " apply for model inference and are used for evaluation only."
)
else:
if model:
@@ -432,8 +434,8 @@ def _run_model_inference(
" like to perform evaluation using the dataset with the"
" existing model response column or or baseline model response column"
f" `{evaluation_run_config.metric_column_mapping[response_column_name]}`,"
- " please remove `model` parameter in `EvalTask.evaluate`"
- " function or keep `baseline_model` as None in pairwise metric."
+ " please remove `model` parameter in `EvalTask.evaluate()`"
+ " function or `baseline_model` in `PairwiseMetric`."
)
@@ -486,8 +488,8 @@ def _assemble_prompt_for_dataset(
raise ValueError("Prompt template cannot be an empty string.")
_LOGGER.info(
- "Assembling prompts from the prompt_template. The `prompt` column in"
- " the EvalResult.metrics_table has the assembled prompts used for model"
+ "Assembling prompts from the `prompt_template`. The `prompt` column in"
+ " the `EvalResult.metrics_table` has the assembled prompts used for model"
" response generation."
)
if isinstance(prompt_template, str):
@@ -507,13 +509,28 @@ def _assemble_prompt_for_dataset(
),
axis=1,
)
+ if (
+ constants.Dataset.PROMPT_COLUMN
+ in evaluation_run_config.metric_column_mapping
+ and evaluation_run_config.metric_column_mapping[
+ constants.Dataset.PROMPT_COLUMN
+ ]
+ != constants.Dataset.PROMPT_COLUMN
+ ):
+ _LOGGER.warning(
+ "The `prompt` column mapping provided in"
+ " `metric_column_mapping` parameter is overwritten by the"
+ " assembled `prompt` column because the `prompt_template`"
+ " parameter is provided. Please verify that you want to use"
+ " the assembled `prompt` column for evaluation."
+ )
evaluation_run_config.metric_column_mapping[
constants.Dataset.PROMPT_COLUMN
] = constants.Dataset.PROMPT_COLUMN
except Exception as e:
raise ValueError(
f"Failed to assemble prompt template: {e}. Please make sure all"
- " variables in prompt_template are present in the evaluation"
+ " variables in `prompt_template` are present in the evaluation"
f" dataset columns: `{list(evaluation_run_config.dataset.columns)}`."
) from e
@@ -634,8 +651,8 @@ def _compute_metrics(
total_request_count = api_request_count + custom_metric_request_count
_LOGGER.info(
- f"Computing metrics with a total of {total_request_count} Vertex online"
- " evaluation service requests."
+ f"Computing metrics with a total of {total_request_count} Vertex Gen AI"
+ " Evaluation Service API requests."
)
instance_list = []
@@ -718,10 +735,10 @@ def _get_baseline_model(evaluation_run_config: evaluation_base.EvaluationRunConf
}
if len(set(baseline_models.values())) > 1:
raise ValueError(
- "Not all PairwiseMetric instances have the same baseline_model "
+ "Not all `PairwiseMetric` instances have the same `baseline_model`. "
f"Here are the detected baseline models: `{baseline_models}`. "
"Please separate pairwise metrics with different baseline models "
- "in different EvalTask or use the same baseline model for "
+ "in different `EvalTask` or use the same baseline model for "
"all pairwise metrics."
)
return pairwise_metric_instances[0].baseline_model
@@ -747,10 +764,11 @@ def _convert_metric_prompt_template_example(metrics):
)
_LOGGER.info(
f"Pairwise metric `{metric.metric_name}` loaded from"
- " MetricPromptTemplateExamples does not have baseline_model."
- " If you would like to run inference on the baseline model,"
- " please instantiate a PairwiseMetric and provide the"
- " `baseline_model` input."
+ " `MetricPromptTemplateExamples` does not have `baseline_model`"
+              " specified and only supports Bring-Your-Own-Response (BYOR)"
+ " evaluation. If you would like to run inference on the baseline model,"
+ " please instantiate a `PairwiseMetric` and provide the"
+ " `baseline_model` parameter."
)
updated_metrics.append(metric)
return updated_metrics
@@ -836,7 +854,6 @@ def evaluate(
constants.Dataset.REFERENCE_COLUMN
)
- # Model inference
if prompt_template:
_assemble_prompt_for_dataset(evaluation_run_config, prompt_template)
@@ -852,22 +869,12 @@ def evaluate(
)
)
- # Baseline model inference
pairwise_metric_exists = any(
isinstance(metric, pairwise_metric.PairwiseMetric)
for metric in evaluation_run_config.metrics
)
if pairwise_metric_exists:
baseline_model = _get_baseline_model(evaluation_run_config)
- if baseline_model is None:
- _LOGGER.warning(
- "Note: pointwise metric with metric prompt template example and "
- "pairwise metric with metric prompt template example do not "
- "have baseline models. If evaluation run fails by "
- "`baseline_model_response` column validation, Please sepecify"
- " `baseline_model_response_column_name` in `EvalTask.evaluate` to"
- " perform bring-your-own-response(BYOR) evaluation."
- )
_run_model_inference(
model=baseline_model,
evaluation_run_config=evaluation_run_config,
diff --git a/vertexai/evaluation/eval_task.py b/vertexai/evaluation/eval_task.py
index 973c3d3483..3ff24e5f22 100644
--- a/vertexai/evaluation/eval_task.py
+++ b/vertexai/evaluation/eval_task.py
@@ -396,7 +396,7 @@ def evaluate(
global_experiment_name = metadata._experiment_tracker.experiment_name
if experiment_run_name and not self._experiment and not global_experiment_name:
raise ValueError(
- "Experiment is not set. Please initialize EvalTask with an"
+ "Experiment is not set. Please initialize `EvalTask` with an"
" experiment, or initialize a global experiment with "
"`vertexai.init(experiment='experiment_name')`for logging this"
" evaluation run."
diff --git a/vertexai/evaluation/metrics/metric_prompt_template.py b/vertexai/evaluation/metrics/metric_prompt_template.py
index bae45d2a76..75256d3ac3 100644
--- a/vertexai/evaluation/metrics/metric_prompt_template.py
+++ b/vertexai/evaluation/metrics/metric_prompt_template.py
@@ -51,13 +51,7 @@ def __init__(
):
"""Initializes a metric prompt template."""
- if not input_variables:
- _LOGGER.info(
- "input_variables is empty. The evaluation dataset will not be used"
- " for computing this model-based metric. Only the responses are used."
- )
self._input_variables = input_variables
-
self._instruction = instruction
self._metric_definition = metric_definition
self._criteria = criteria
@@ -117,6 +111,10 @@ def __init__(
"""
if not input_variables:
input_variables = []
+ _LOGGER.info(
+ "The `input_variables` parameter is empty. Only the `response`"
+ " column is used for computing this model-based metric."
+ )
input_variables = list(set(input_variables + ["response"]))
instruction = instruction or self.get_default_pointwise_instruction()
@@ -273,6 +271,11 @@ def __init__(
"""
if not input_variables:
input_variables = []
+ _LOGGER.info(
+ "The `input_variables` parameter is empty. Only the `response`"
+ " and `baseline_model_response` columns are used for computing"
+ " this model-based metric."
+ )
input_variables = list(
set(input_variables + ["response", "baseline_model_response"])
)
diff --git a/vertexai/evaluation/metrics/metric_prompt_template_examples.py b/vertexai/evaluation/metrics/metric_prompt_template_examples.py
index 659453e9a8..d067bf0824 100644
--- a/vertexai/evaluation/metrics/metric_prompt_template_examples.py
+++ b/vertexai/evaluation/metrics/metric_prompt_template_examples.py
@@ -24,74 +24,71 @@
)
from vertexai.evaluation.metrics import pairwise_metric
from vertexai.evaluation.metrics import pointwise_metric
-import immutabledict
class MetricPromptTemplateExamples:
"""Examples of metric prompt templates for model-based evaluation."""
- _PROMPT_TEMPLATE_MAP = immutabledict.immutabledict(
- {
- constants.Metric.COHERENCE: _default_templates.COHERENCE_PROMPT_TEMPLATE,
- constants.Metric.FLUENCY: _default_templates.FLUENCY_PROMPT_TEMPLATE,
- constants.Metric.SAFETY: _default_templates.SAFETY_PROMPT_TEMPLATE,
- constants.Metric.GROUNDEDNESS: (
- _default_templates.GROUNDEDNESS_PROMPT_TEMPLATE
- ),
- constants.Metric.INSTRUCTION_FOLLOWING: (
- _default_templates.INSTRUCTION_FOLLOWING_PROMPT_TEMPLATE
- ),
- constants.Metric.VERBOSITY: _default_templates.VERBOSITY_PROMPT_TEMPLATE,
- constants.Metric.TEXT_QUALITY: (
- _default_templates.TEXT_QUALITY_PROMPT_TEMPLATE
- ),
- constants.Metric.SUMMARIZATION_QUALITY: (
- _default_templates.SUMMARIZATION_QUALITY_PROMPT_TEMPLATE
- ),
- constants.Metric.QUESTION_ANSWERING_QUALITY: (
- _default_templates.QUESTION_ANSWERING_QUALITY_PROMPT_TEMPLATE
- ),
- constants.Metric.MULTI_TURN_CHAT_QUALITY: (
- _default_templates.MULTI_TURN_CHAT_QUALITY_PROMPT_TEMPLATE
- ),
- constants.Metric.MULTI_TURN_SAFETY: (
- _default_templates.MULTI_TURN_SAFETY_PROMPT_TEMPLATE
- ),
- constants.Metric.PAIRWISE_COHERENCE: (
- _default_templates.PAIRWISE_COHERENCE_PROMPT_TEMPLATE
- ),
- constants.Metric.PAIRWISE_FLUENCY: (
- _default_templates.PAIRWISE_FLUENCY_PROMPT_TEMPLATE
- ),
- constants.Metric.PAIRWISE_SAFETY: (
- _default_templates.PAIRWISE_SAFETY_PROMPT_TEMPLATE
- ),
- constants.Metric.PAIRWISE_GROUNDEDNESS: (
- _default_templates.PAIRWISE_GROUNDEDNESS_PROMPT_TEMPLATE
- ),
- constants.Metric.PAIRWISE_INSTRUCTION_FOLLOWING: (
- _default_templates.PAIRWISE_INSTRUCTION_FOLLOWING_PROMPT_TEMPLATE
- ),
- constants.Metric.PAIRWISE_VERBOSITY: (
- _default_templates.PAIRWISE_VERBOSITY_PROMPT_TEMPLATE
- ),
- constants.Metric.PAIRWISE_TEXT_QUALITY: (
- _default_templates.PAIRWISE_TEXT_QUALITY_PROMPT_TEMPLATE
- ),
- constants.Metric.PAIRWISE_SUMMARIZATION_QUALITY: (
- _default_templates.PAIRWISE_SUMMARIZATION_QUALITY_PROMPT_TEMPLATE
- ),
- constants.Metric.PAIRWISE_QUESTION_ANSWERING_QUALITY: (
- _default_templates.PAIRWISE_QUESTION_ANSWERING_QUALITY_PROMPT_TEMPLATE
- ),
- constants.Metric.PAIRWISE_MULTI_TURN_CHAT_QUALITY: (
- _default_templates.PAIRWISE_MULTI_TURN_CHAT_QUALITY_PROMPT_TEMPLATE
- ),
- constants.Metric.PAIRWISE_MULTI_TURN_SAFETY: (
- _default_templates.PAIRWISE_MULTI_TURN_SAFETY_PROMPT_TEMPLATE
- ),
- }
- )
+ _PROMPT_TEMPLATE_MAP = {
+ constants.Metric.COHERENCE: _default_templates.COHERENCE_PROMPT_TEMPLATE,
+ constants.Metric.FLUENCY: _default_templates.FLUENCY_PROMPT_TEMPLATE,
+ constants.Metric.SAFETY: _default_templates.SAFETY_PROMPT_TEMPLATE,
+ constants.Metric.GROUNDEDNESS: (
+ _default_templates.GROUNDEDNESS_PROMPT_TEMPLATE
+ ),
+ constants.Metric.INSTRUCTION_FOLLOWING: (
+ _default_templates.INSTRUCTION_FOLLOWING_PROMPT_TEMPLATE
+ ),
+ constants.Metric.VERBOSITY: _default_templates.VERBOSITY_PROMPT_TEMPLATE,
+ constants.Metric.TEXT_QUALITY: (
+ _default_templates.TEXT_QUALITY_PROMPT_TEMPLATE
+ ),
+ constants.Metric.SUMMARIZATION_QUALITY: (
+ _default_templates.SUMMARIZATION_QUALITY_PROMPT_TEMPLATE
+ ),
+ constants.Metric.QUESTION_ANSWERING_QUALITY: (
+ _default_templates.QUESTION_ANSWERING_QUALITY_PROMPT_TEMPLATE
+ ),
+ constants.Metric.MULTI_TURN_CHAT_QUALITY: (
+ _default_templates.MULTI_TURN_CHAT_QUALITY_PROMPT_TEMPLATE
+ ),
+ constants.Metric.MULTI_TURN_SAFETY: (
+ _default_templates.MULTI_TURN_SAFETY_PROMPT_TEMPLATE
+ ),
+ constants.Metric.PAIRWISE_COHERENCE: (
+ _default_templates.PAIRWISE_COHERENCE_PROMPT_TEMPLATE
+ ),
+ constants.Metric.PAIRWISE_FLUENCY: (
+ _default_templates.PAIRWISE_FLUENCY_PROMPT_TEMPLATE
+ ),
+ constants.Metric.PAIRWISE_SAFETY: (
+ _default_templates.PAIRWISE_SAFETY_PROMPT_TEMPLATE
+ ),
+ constants.Metric.PAIRWISE_GROUNDEDNESS: (
+ _default_templates.PAIRWISE_GROUNDEDNESS_PROMPT_TEMPLATE
+ ),
+ constants.Metric.PAIRWISE_INSTRUCTION_FOLLOWING: (
+ _default_templates.PAIRWISE_INSTRUCTION_FOLLOWING_PROMPT_TEMPLATE
+ ),
+ constants.Metric.PAIRWISE_VERBOSITY: (
+ _default_templates.PAIRWISE_VERBOSITY_PROMPT_TEMPLATE
+ ),
+ constants.Metric.PAIRWISE_TEXT_QUALITY: (
+ _default_templates.PAIRWISE_TEXT_QUALITY_PROMPT_TEMPLATE
+ ),
+ constants.Metric.PAIRWISE_SUMMARIZATION_QUALITY: (
+ _default_templates.PAIRWISE_SUMMARIZATION_QUALITY_PROMPT_TEMPLATE
+ ),
+ constants.Metric.PAIRWISE_QUESTION_ANSWERING_QUALITY: (
+ _default_templates.PAIRWISE_QUESTION_ANSWERING_QUALITY_PROMPT_TEMPLATE
+ ),
+ constants.Metric.PAIRWISE_MULTI_TURN_CHAT_QUALITY: (
+ _default_templates.PAIRWISE_MULTI_TURN_CHAT_QUALITY_PROMPT_TEMPLATE
+ ),
+ constants.Metric.PAIRWISE_MULTI_TURN_SAFETY: (
+ _default_templates.PAIRWISE_MULTI_TURN_SAFETY_PROMPT_TEMPLATE
+ ),
+ }
@classmethod
def get_prompt_template(cls, metric_name: str) -> str:
diff --git a/vertexai/evaluation/utils.py b/vertexai/evaluation/utils.py
index 5b3d9b4498..a7240eed59 100644
--- a/vertexai/evaluation/utils.py
+++ b/vertexai/evaluation/utils.py
@@ -31,7 +31,6 @@
from google.cloud.aiplatform_v1.services import (
evaluation_service as gapic_evaluation_services,
)
-from vertexai.evaluation import constants
if TYPE_CHECKING:
@@ -180,13 +179,13 @@ def load_dataset(
else:
raise ValueError(
f"Unsupported file type: {file_type} from {source}. Please"
- " provide valid GCS path with jsonl or csv suffix or valid"
+ " provide a valid GCS path with `jsonl` or `csv` suffix or a valid"
" BigQuery table URI."
)
else:
raise TypeError(
- "Unsupported dataset type. Must be DataFrame, dictionary, or"
- " valid GCS path with jsonl or csv suffix or BigQuery table URI."
+ "Unsupported dataset type. Must be a `pd.DataFrame`, Python dictionary,"
+ " valid GCS path with `jsonl` or `csv` suffix or a valid BigQuery table URI."
)
@@ -254,25 +253,27 @@ def _read_gcs_file_contents(filepath: str) -> str:
def upload_evaluation_results(
dataset: "pd.DataFrame", destination_uri_prefix: str, file_name: str
):
- """Uploads eval results to GCS csv destination."""
+ """Uploads eval results to GCS CSV destination."""
+ supported_file_types = ["csv"]
if not destination_uri_prefix:
return
if destination_uri_prefix.startswith(_GCS_PREFIX):
_, extension = os.path.splitext(file_name)
file_type = extension.lower()[1:]
- if file_type in ["csv"]:
+ if file_type in supported_file_types:
output_path = destination_uri_prefix + "/" + file_name
utils.gcs_utils._upload_pandas_df_to_gcs(dataset, output_path)
else:
raise ValueError(
- "Unsupported file type of GCS destination uri:"
- f" {file_name}, please provide valid GCS"
- " path with csv suffix."
+ "Unsupported file type in the GCS destination URI:"
+ f" {file_name}, please provide a valid GCS"
+ f" file name with a file type in {supported_file_types}."
)
else:
raise ValueError(
- f"Unsupported destination uri: {destination_uri_prefix}."
- " Please provide valid GCS bucket path."
+ f"Unsupported destination URI: {destination_uri_prefix}."
+ f" Please provide a valid GCS bucket URI prefix starting with"
+ f" {_GCS_PREFIX}."
)
@@ -285,23 +286,12 @@ def initialize_metric_column_mapping(
initialized_metric_column_mapping[column] = column
if metric_column_mapping:
for key, value in metric_column_mapping.items():
- if (
- key == constants.Dataset.PROMPT_COLUMN
- and value != constants.Dataset.PROMPT_COLUMN
- ):
- _LOGGER.warning(
- f"`{key}:{value}` will be overwritten to"
- f" `prompt:prompt`. Please do not set `prompt` as"
- " key in metric_column_mapping."
- )
- continue
if key in initialized_metric_column_mapping:
_LOGGER.warning(
- f"`{key}:{key}` is already in metric_column_mapping. Cannot"
- f" override it with `{key}:{value}` because `{key}` is an"
- " evaluation dataset column. Metric_column_mapping cannot override"
- " keys that are already in evaluation dataset"
- " column.metric_column_mapping: {metric_column_mapping}."
+ f"Cannot override `{key}` column with `{key}:{value}` mapping"
+ f" because `{key}` column is present in the evaluation"
+ " dataset. `metric_column_mapping` cannot override keys"
+ " that are already in evaluation dataset columns."
)
else:
initialized_metric_column_mapping[key] = value
diff --git a/vertexai/generative_models/_generative_models.py b/vertexai/generative_models/_generative_models.py
index f6db0d115a..80b21b9591 100644
--- a/vertexai/generative_models/_generative_models.py
+++ b/vertexai/generative_models/_generative_models.py
@@ -29,9 +29,11 @@
Dict,
Iterable,
List,
+ Literal,
Optional,
Sequence,
Union,
+ overload,
TYPE_CHECKING,
)
@@ -514,6 +516,32 @@ def _parse_response(
) -> "GenerationResponse":
return GenerationResponse._from_gapic(response)
+ @overload
+ def generate_content(
+ self,
+ contents: ContentsType,
+ *,
+ generation_config: Optional[GenerationConfigType] = None,
+ safety_settings: Optional[SafetySettingsType] = None,
+ tools: Optional[List["Tool"]] = None,
+ tool_config: Optional["ToolConfig"] = None,
+ stream: Literal[False] = False,
+ ) -> "GenerationResponse":
+ ...
+
+ @overload
+ def generate_content(
+ self,
+ contents: ContentsType,
+ *,
+ generation_config: Optional[GenerationConfigType] = None,
+ safety_settings: Optional[SafetySettingsType] = None,
+ tools: Optional[List["Tool"]] = None,
+ tool_config: Optional["ToolConfig"] = None,
+ stream: Literal[True],
+ ) -> Iterable["GenerationResponse"]:
+ ...
+
def generate_content(
self,
contents: ContentsType,
@@ -523,7 +551,7 @@ def generate_content(
tools: Optional[List["Tool"]] = None,
tool_config: Optional["ToolConfig"] = None,
stream: bool = False,
- ) -> Union["GenerationResponse", Iterable["GenerationResponse"],]:
+ ) -> Union["GenerationResponse", Iterable["GenerationResponse"]]:
"""Generates content.
Args:
@@ -562,6 +590,32 @@ def generate_content(
tool_config=tool_config,
)
+ @overload
+ async def generate_content_async(
+ self,
+ contents: ContentsType,
+ *,
+ generation_config: Optional[GenerationConfigType] = None,
+ safety_settings: Optional[SafetySettingsType] = None,
+ tools: Optional[List["Tool"]] = None,
+ tool_config: Optional["ToolConfig"] = None,
+ stream: Literal[False] = False,
+ ) -> "GenerationResponse":
+ ...
+
+ @overload
+ async def generate_content_async(
+ self,
+ contents: ContentsType,
+ *,
+ generation_config: Optional[GenerationConfigType] = None,
+ safety_settings: Optional[SafetySettingsType] = None,
+ tools: Optional[List["Tool"]] = None,
+ tool_config: Optional["ToolConfig"] = None,
+ stream: Literal[True] = True,
+ ) -> AsyncIterable["GenerationResponse"]:
+ ...
+
async def generate_content_async(
self,
contents: ContentsType,
@@ -571,7 +625,7 @@ async def generate_content_async(
tools: Optional[List["Tool"]] = None,
tool_config: Optional["ToolConfig"] = None,
stream: bool = False,
- ) -> Union["GenerationResponse", AsyncIterable["GenerationResponse"],]:
+ ) -> Union["GenerationResponse", AsyncIterable["GenerationResponse"]]:
"""Generates content asynchronously.
Args:
@@ -770,7 +824,7 @@ async def async_generator():
return async_generator()
def count_tokens(
- self, contents: ContentsType
+ self, contents: ContentsType, *, tools: Optional[List["Tool"]] = None
) -> gapic_prediction_service_types.CountTokensResponse:
"""Counts tokens.
@@ -782,22 +836,32 @@ def count_tokens(
* str, Image, Part,
* List[Union[str, Image, Part]],
* List[Content]
+ tools: A list of tools (functions) that the model can try calling.
Returns:
A CountTokensResponse object that has the following attributes:
total_tokens: The total number of tokens counted across all instances from the request.
total_billable_characters: The total number of billable characters counted across all instances from the request.
"""
+ request = self._prepare_request(
+ contents=contents,
+ tools=tools,
+ )
return self._prediction_client.count_tokens(
request=gapic_prediction_service_types.CountTokensRequest(
endpoint=self._prediction_resource_name,
model=self._prediction_resource_name,
- contents=self._prepare_request(contents=contents).contents,
+ contents=request.contents,
+ system_instruction=request.system_instruction,
+ tools=request.tools,
)
)
async def count_tokens_async(
- self, contents: ContentsType
+ self,
+ contents: ContentsType,
+ *,
+ tools: Optional[List["Tool"]] = None,
) -> gapic_prediction_service_types.CountTokensResponse:
"""Counts tokens asynchronously.
@@ -809,17 +873,24 @@ async def count_tokens_async(
* str, Image, Part,
* List[Union[str, Image, Part]],
* List[Content]
+ tools: A list of tools (functions) that the model can try calling.
Returns:
And awaitable for a CountTokensResponse object that has the following attributes:
total_tokens: The total number of tokens counted across all instances from the request.
total_billable_characters: The total number of billable characters counted across all instances from the request.
"""
+ request = self._prepare_request(
+ contents=contents,
+ tools=tools,
+ )
return await self._prediction_async_client.count_tokens(
request=gapic_prediction_service_types.CountTokensRequest(
endpoint=self._prediction_resource_name,
model=self._prediction_resource_name,
- contents=self._prepare_request(contents=contents).contents,
+ contents=request.contents,
+ system_instruction=request.system_instruction,
+ tools=request.tools,
)
)
@@ -981,6 +1052,30 @@ def __init__(
def history(self) -> List["Content"]:
return self._history
+ @overload
+ def send_message(
+ self,
+ content: PartsType,
+ *,
+ generation_config: Optional[GenerationConfigType] = None,
+ safety_settings: Optional[SafetySettingsType] = None,
+ tools: Optional[List["Tool"]] = None,
+ stream: Literal[False] = False,
+ ) -> "GenerationResponse":
+ ...
+
+ @overload
+ def send_message(
+ self,
+ content: PartsType,
+ *,
+ generation_config: Optional[GenerationConfigType] = None,
+ safety_settings: Optional[SafetySettingsType] = None,
+ tools: Optional[List["Tool"]] = None,
+ stream: Literal[True] = True,
+ ) -> Iterable["GenerationResponse"]:
+ ...
+
def send_message(
self,
content: PartsType,
@@ -1025,6 +1120,30 @@ def send_message(
tools=tools,
)
+ @overload
+ def send_message_async(
+ self,
+ content: PartsType,
+ *,
+ generation_config: Optional[GenerationConfigType] = None,
+ safety_settings: Optional[SafetySettingsType] = None,
+ tools: Optional[List["Tool"]] = None,
+ stream: Literal[False] = False,
+ ) -> Awaitable["GenerationResponse"]:
+ ...
+
+ @overload
+ def send_message_async(
+ self,
+ content: PartsType,
+ *,
+ generation_config: Optional[GenerationConfigType] = None,
+ safety_settings: Optional[SafetySettingsType] = None,
+ tools: Optional[List["Tool"]] = None,
+ stream: Literal[True] = True,
+ ) -> Awaitable[AsyncIterable["GenerationResponse"]]:
+ ...
+
def send_message_async(
self,
content: PartsType,
diff --git a/vertexai/preview/rag/__init__.py b/vertexai/preview/rag/__init__.py
index fff380b359..56590fed0e 100644
--- a/vertexai/preview/rag/__init__.py
+++ b/vertexai/preview/rag/__init__.py
@@ -38,36 +38,38 @@
)
from vertexai.preview.rag.utils.resources import (
EmbeddingModelConfig,
- JiraSource,
JiraQuery,
+ JiraSource,
RagCorpus,
RagFile,
RagResource,
SlackChannel,
SlackChannelsSource,
+ Weaviate,
)
__all__ = (
+ "EmbeddingModelConfig",
+ "JiraQuery",
+ "JiraSource",
+ "RagCorpus",
+ "RagFile",
+ "RagResource",
+ "Retrieval",
+ "SlackChannel",
+ "SlackChannelsSource",
+ "VertexRagStore",
+ "Weaviate",
"create_corpus",
- "list_corpora",
- "get_corpus",
"delete_corpus",
- "upload_file",
+ "delete_file",
+ "get_corpus",
+ "get_file",
"import_files",
"import_files_async",
- "get_file",
+ "list_corpora",
"list_files",
- "delete_file",
"retrieval_query",
- "EmbeddingModelConfig",
- "Retrieval",
- "VertexRagStore",
- "RagResource",
- "RagFile",
- "RagCorpus",
- "JiraSource",
- "JiraQuery",
- "SlackChannel",
- "SlackChannelsSource",
+ "upload_file",
)
diff --git a/vertexai/preview/rag/rag_data.py b/vertexai/preview/rag/rag_data.py
index 32da89c0a3..0983037173 100644
--- a/vertexai/preview/rag/rag_data.py
+++ b/vertexai/preview/rag/rag_data.py
@@ -48,6 +48,7 @@
RagCorpus,
RagFile,
SlackChannelsSource,
+ Weaviate,
)
@@ -55,6 +56,7 @@ def create_corpus(
display_name: Optional[str] = None,
description: Optional[str] = None,
embedding_model_config: Optional[EmbeddingModelConfig] = None,
+ vector_db: Optional[Weaviate] = None,
) -> RagCorpus:
"""Creates a new RagCorpus resource.
@@ -76,6 +78,8 @@ def create_corpus(
consist of any UTF-8 characters.
description: The description of the RagCorpus.
embedding_model_config: The embedding model config.
+ vector_db: The vector db config of the RagCorpus. If unspecified, the
+ default database Spanner is used.
Returns:
RagCorpus.
Raises:
@@ -88,9 +92,14 @@ def create_corpus(
rag_corpus = GapicRagCorpus(display_name=display_name, description=description)
if embedding_model_config:
- rag_corpus = _gapic_utils.set_embedding_model_config(
- embedding_model_config,
- rag_corpus,
+ _gapic_utils.set_embedding_model_config(
+ embedding_model_config=embedding_model_config,
+ rag_corpus=rag_corpus,
+ )
+ if vector_db is not None:
+ _gapic_utils.set_vector_db(
+ vector_db=vector_db,
+ rag_corpus=rag_corpus,
)
request = CreateRagCorpusRequest(
diff --git a/vertexai/preview/rag/rag_retrieval.py b/vertexai/preview/rag/rag_retrieval.py
index 519b15a822..1d8bbb5612 100644
--- a/vertexai/preview/rag/rag_retrieval.py
+++ b/vertexai/preview/rag/rag_retrieval.py
@@ -37,6 +37,7 @@ def retrieval_query(
rag_corpora: Optional[List[str]] = None,
similarity_top_k: Optional[int] = 10,
vector_distance_threshold: Optional[float] = 0.3,
+ vector_search_alpha: Optional[float] = 0.5,
) -> RetrieveContextsResponse:
"""Retrieve top k relevant docs/chunks.
@@ -54,6 +55,7 @@ def retrieval_query(
)],
similarity_top_k=2,
vector_distance_threshold=0.5,
+ vector_search_alpha=0.5,
)
```
@@ -67,6 +69,10 @@ def retrieval_query(
similarity_top_k: The number of contexts to retrieve.
vector_distance_threshold: Optional. Only return contexts with vector
distance smaller than the threshold.
+ vector_search_alpha: Optional. Controls the weight between dense and
+ sparse vector search results. The range is [0, 1], where 0 means
+ sparse vector search only and 1 means dense vector search only.
+ The default value is 0.5.
Returns:
RetrieveContextsResonse.
@@ -111,7 +117,13 @@ def retrieval_query(
)
vertex_rag_store.vector_distance_threshold = vector_distance_threshold
- query = RagQuery(text=text, similarity_top_k=similarity_top_k)
+ query = RagQuery(
+ text=text,
+ similarity_top_k=similarity_top_k,
+ ranking=RagQuery.Ranking(
+ alpha=vector_search_alpha,
+ ),
+ )
request = RetrieveContextsRequest(
vertex_rag_store=vertex_rag_store,
parent=parent,
diff --git a/vertexai/preview/rag/utils/_gapic_utils.py b/vertexai/preview/rag/utils/_gapic_utils.py
index 6cc1c1d316..640fd8c5f0 100644
--- a/vertexai/preview/rag/utils/_gapic_utils.py
+++ b/vertexai/preview/rag/utils/_gapic_utils.py
@@ -28,6 +28,7 @@
RagFile as GapicRagFile,
SlackSource as GapicSlackSource,
JiraSource as GapicJiraSource,
+ RagVectorDbConfig,
)
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform.utils import (
@@ -41,6 +42,7 @@
RagFile,
SlackChannelsSource,
JiraSource,
+ Weaviate,
)
@@ -93,8 +95,22 @@ def convert_gapic_to_embedding_model_config(
return embedding_model_config
+def convert_gapic_to_vector_db(
+ gapic_vector_db: RagVectorDbConfig,
+) -> Weaviate:
+ """Convert Gapic RagVectorDbConfig to Weaviate."""
+ if gapic_vector_db.__contains__("weaviate"):
+ return Weaviate(
+ weaviate_http_endpoint=gapic_vector_db.weaviate.http_endpoint,
+ collection_name=gapic_vector_db.weaviate.collection_name,
+ api_key=gapic_vector_db.api_auth.api_key_config.api_key_secret_version,
+ )
+ else:
+ return None
+
+
def convert_gapic_to_rag_corpus(gapic_rag_corpus: GapicRagCorpus) -> RagCorpus:
- """ "Convert GapicRagCorpus to RagCorpus."""
+ """Convert GapicRagCorpus to RagCorpus."""
rag_corpus = RagCorpus(
name=gapic_rag_corpus.name,
display_name=gapic_rag_corpus.display_name,
@@ -102,12 +118,13 @@ def convert_gapic_to_rag_corpus(gapic_rag_corpus: GapicRagCorpus) -> RagCorpus:
embedding_model_config=convert_gapic_to_embedding_model_config(
gapic_rag_corpus.rag_embedding_model_config
),
+ vector_db=convert_gapic_to_vector_db(gapic_rag_corpus.rag_vector_db_config),
)
return rag_corpus
def convert_gapic_to_rag_file(gapic_rag_file: GapicRagFile) -> RagFile:
- """ "Convert GapicRagFile to RagFile."""
+ """Convert GapicRagFile to RagFile."""
rag_file = RagFile(
name=gapic_rag_file.name,
display_name=gapic_rag_file.display_name,
@@ -315,7 +332,7 @@ def get_file_name(
def set_embedding_model_config(
embedding_model_config: EmbeddingModelConfig,
rag_corpus: GapicRagCorpus,
-) -> GapicRagCorpus:
+) -> None:
if embedding_model_config.publisher_model and embedding_model_config.endpoint:
raise ValueError("publisher_model and endpoint cannot be set at the same time.")
if (
@@ -371,4 +388,27 @@ def set_embedding_model_config(
"endpoint must be of the format `projects/{project}/locations/{location}/endpoints/{endpoint}` or `endpoints/{endpoint}`"
)
- return rag_corpus
+
+def set_vector_db(
+ vector_db: Weaviate,
+ rag_corpus: GapicRagCorpus,
+) -> None:
+ """Sets the vector db configuration for the rag corpus."""
+ if isinstance(vector_db, Weaviate):
+ http_endpoint = vector_db.weaviate_http_endpoint
+ collection_name = vector_db.collection_name
+ api_key = vector_db.api_key
+
+ rag_corpus.rag_vector_db_config = RagVectorDbConfig(
+ weaviate=RagVectorDbConfig.Weaviate(
+ http_endpoint=http_endpoint,
+ collection_name=collection_name,
+ ),
+ api_auth=api_auth.ApiAuth(
+ api_key_config=api_auth.ApiAuth.ApiKeyConfig(
+ api_key_secret_version=api_key
+ ),
+ ),
+ )
+ else:
+ raise TypeError("vector_db must be a Weaviate.")
diff --git a/vertexai/preview/rag/utils/resources.py b/vertexai/preview/rag/utils/resources.py
index 1b5af451f6..aad7bad35d 100644
--- a/vertexai/preview/rag/utils/resources.py
+++ b/vertexai/preview/rag/utils/resources.py
@@ -69,6 +69,22 @@ class EmbeddingModelConfig:
model_version_id: Optional[str] = None
+@dataclasses.dataclass
+class Weaviate:
+ """Weaviate.
+
+ Attributes:
+ weaviate_http_endpoint: The Weaviate DB instance HTTP endpoint
+ collection_name: The corresponding Weaviate collection this corpus maps to
+ api_key: The SecretManager resource name for the Weaviate DB API token. Format:
+ ``projects/{project}/secrets/{secret}/versions/{version}``
+ """
+
+ weaviate_http_endpoint: str
+ collection_name: str
+ api_key: str
+
+
@dataclasses.dataclass
class RagCorpus:
"""RAG corpus(output only).
@@ -78,12 +94,15 @@ class RagCorpus:
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}``
display_name: Display name that was configured at client side.
description: The description of the RagCorpus.
+ embedding_model_config: The embedding model config of the RagCorpus.
+ vector_db: The Vector DB of the RagCorpus.
"""
name: Optional[str] = None
display_name: Optional[str] = None
description: Optional[str] = None
embedding_model_config: Optional[EmbeddingModelConfig] = None
+ vector_db: Optional[Weaviate] = None
@dataclasses.dataclass
diff --git a/vertexai/preview/tokenization.py b/vertexai/preview/tokenization.py
index add6cf67cc..ad7e06203b 100644
--- a/vertexai/preview/tokenization.py
+++ b/vertexai/preview/tokenization.py
@@ -16,7 +16,7 @@
# We just want to re-export certain classes
# pylint: disable=g-multiple-import,g-importing-member
from vertexai.tokenization._tokenizers import (
- get_tokenizer_for_model,
+ _get_tokenizer_for_model_preview as get_tokenizer_for_model,
)
diff --git a/vertexai/tokenization/_tokenizers.py b/vertexai/tokenization/_tokenizers.py
index 1bad192db6..17e2b5b823 100644
--- a/vertexai/tokenization/_tokenizers.py
+++ b/vertexai/tokenization/_tokenizers.py
@@ -18,13 +18,18 @@
Iterable,
List,
Sequence,
+ Optional,
+ Union,
)
from vertexai.generative_models._generative_models import (
ContentsType,
Image,
+ Tool,
+ PartsType,
_validate_contents_type_as_valid_sequence,
_content_types_to_gapic_contents,
+ _to_content,
)
from vertexai.tokenization._tokenizer_loading import (
@@ -35,8 +40,10 @@
from google.cloud.aiplatform_v1beta1.types import (
content as gapic_content_types,
tool as gapic_tool_types,
+ openapi,
)
from sentencepiece import sentencepiece_model_pb2
+from google.protobuf import struct_pb2
@dataclasses.dataclass(frozen=True)
@@ -48,11 +55,37 @@ class TokensInfo:
@dataclasses.dataclass(frozen=True)
class ComputeTokensResult:
- token_info_list: Sequence[TokensInfo]
+    """Represents the token string pieces and ids output by the compute_tokens function.
+
+ Attributes:
+ tokens_info: Lists of tokens_info from the input.
+ The input `contents: ContentsType` could have
+ multiple string instances and each tokens_info
+ item represents each string instance. Each token
+            info consists of a tokens list, a token_ids list, and
+ a role.
+ """
+
+ tokens_info: Sequence[TokensInfo]
+
+
+class PreviewComputeTokensResult(ComputeTokensResult):
+ def token_info_list(self) -> Sequence[TokensInfo]:
+ import warnings
+
+ message = "PreviewComputeTokensResult.token_info_list is deprecated. Use ComputeTokensResult.tokens_info instead."
+ warnings.warn(message, DeprecationWarning, stacklevel=2)
+ return self.tokens_info
@dataclasses.dataclass(frozen=True)
class CountTokensResult:
+    """Represents the token count output of the count_tokens function.
+
+ Attributes:
+        total_tokens: The total number of tokens.
+ """
+
total_tokens: int
@@ -129,27 +162,20 @@ def compute_tokens(
role=role,
)
)
- return ComputeTokensResult(token_info_list=token_infos)
+ return ComputeTokensResult(tokens_info=token_infos)
def _to_gapic_contents(
contents: ContentsType,
-) -> gapic_content_types.Content:
+) -> List[gapic_content_types.Content]:
"""Converts a GenerativeModel compatible contents type to a gapic content."""
_validate_contents_type_as_valid_sequence(contents)
_assert_no_image_contents_type(contents)
gapic_contents = _content_types_to_gapic_contents(contents)
- _assert_text_only_content_types_sequence(gapic_contents)
+ # _assert_text_only_content_types_sequence(gapic_contents)
return gapic_contents
-def _content_types_to_string_iterator(contents: ContentsType) -> Iterable[str]:
- """Converts a GenerativeModel compatible contents type to a list of strings."""
- gapic_contents = _to_gapic_contents(contents)
- for content in gapic_contents:
- yield from _to_string_array(content)
-
-
def _content_types_to_role_iterator(contents: ContentsType) -> Iterable[str]:
gapic_contents = _to_gapic_contents(contents)
# Flattening role by content's multi parts
@@ -158,13 +184,6 @@ def _content_types_to_role_iterator(contents: ContentsType) -> Iterable[str]:
yield content.role
-def _to_string_array(content: gapic_content_types.Content) -> Iterable[str]:
- """Converts a gapic content type to a list of strings."""
- if not content:
- raise TypeError("content must not be empty.")
- return [part.text for part in content.parts]
-
-
def _assert_no_image_contents_type(contents: ContentsType):
"""Asserts that the contents type does not contain any image content."""
if isinstance(contents, Image) or (
@@ -174,37 +193,12 @@ def _assert_no_image_contents_type(contents: ContentsType):
raise ValueError("Tokenizers do not support Image content type.")
-def _assert_text_only_content_types_sequence(
- contents: List[gapic_content_types.Content],
-):
- """Asserts that the contents type does not contain any non-text content."""
- for value in contents:
- for part in value.parts:
- _assert_text_only_gapic_part(part)
-
-
-def _assert_text_only_gapic_part(value: gapic_content_types.Part):
- """Asserts that the gapic content part is a text content type."""
- if (
- gapic_content_types.FileData() != value.file_data
- or gapic_content_types.Blob() != value.inline_data
- or gapic_tool_types.FunctionCall() != value.function_call
- or gapic_tool_types.FunctionResponse() != value.function_response
- or gapic_content_types.VideoMetadata() != value.video_metadata
- ):
- raise ValueError("Tokenizers do not support non-text content types.")
-
-
-def _to_canonical_contents_texts(contents: ContentsType) -> Iterable[str]:
- """Gets the canonical contents."""
- if isinstance(contents, str):
- yield contents
- elif isinstance(contents, Sequence) and all(
- isinstance(content, str) for content in contents
- ):
- yield from contents
- else:
- yield from _content_types_to_string_iterator(contents)
+def _is_string_inputs(contents: ContentsType) -> bool:
+ return (
+ isinstance(contents, str)
+ or isinstance(contents, Sequence)
+ and all(isinstance(content, str) for content in contents)
+ )
def _to_canonical_roles(contents: ContentsType) -> Iterable[str]:
@@ -218,6 +212,224 @@ def _to_canonical_roles(contents: ContentsType) -> Iterable[str]:
yield from _content_types_to_role_iterator(contents)
+class _TextsAccumulator:
+ """Accumulates texts from contents and tools.
+
+ This class is used to accumulate countable texts from contents and tools.
+ When user passes a unsupported fields that are added in the future, the new
+ fields might be only counted in remote tokenizer. In this case, the local
+ tokenizer should know that an unsupported new field exist in the content or
+ tool instances and raise error to avoid returning incorrect result to users.
+
+ The machanism to detect unsupported fields introduced in the future: when
+ local tokenizer traversing the input instances, it is allowlist based text
+ accumulation. When a field is traversed and evaluated to be countable, the
+ value of this
+ field is copied to two places: (1) self._texts for inputs to sentencepiece
+ token count function. (2) a counted instance object in the recursive
+ function's return value. That's to say, after done current recurssion,
+ the instance(of same type as the input) only keeps the counted values.
+ If user sets unsupported future proto fields, they can be detected by
+ comparing the input instances equal to counted instances or not.
+ """
+
+ def __init__(self):
+ self._texts = []
+
+ def get_texts(self) -> Iterable[str]:
+ return self._texts
+
+ def add_texts(self, texts: Union[Iterable[str], str]) -> None:
+ if isinstance(texts, str):
+ self._texts.append(texts)
+ else:
+ self._texts.extend(texts)
+
+ def add_content(self, content: gapic_content_types.Content) -> None:
+ counted_content = gapic_content_types.Content()
+ for part in content.parts:
+ counted_part = gapic_content_types.Part()
+ if "file_data" in part or "inline_data" in part:
+ raise ValueError("Tokenizers do not support non-text content types.")
+ if "video_metadata" in part:
+ counted_part.video_metadata = part.video_metadata
+ if "function_call" in part:
+ self.add_function_call(part.function_call)
+ counted_part.function_call = part.function_call
+ if "function_response" in part:
+ self.add_function_response(part.function_response)
+ counted_part.function_response = part.function_response
+ if "text" in part:
+ counted_part.text = part.text
+ self._texts.append(part.text)
+ counted_content.parts.append(counted_part)
+ counted_content.role = content.role
+ if content._pb != counted_content._pb:
+ raise ValueError(
+ f"Content contains unsupported types for token counting. Supported fields {counted_content}. Got {content}."
+ )
+
+ def add_function_call(self, function_call: gapic_tool_types.FunctionCall) -> None:
+ """Processes a function call and adds relevant text to the accumulator.
+
+ Args:
+ function_call: The function call to process.
+ """
+ self._texts.append(function_call.name)
+ counted_function_call = gapic_tool_types.FunctionCall(name=function_call.name)
+ counted_struct = self._struct_traverse(function_call._pb.args)
+ counted_function_call.args = counted_struct
+ if counted_function_call._pb != function_call._pb:
+ raise ValueError(
+ f"Function call argument contains unsupported types for token counting. Supported fields {counted_function_call}. Got {function_call}."
+ )
+
+ def add_function_calls(
+ self, function_calls: Iterable[gapic_tool_types.FunctionCall]
+ ) -> None:
+ for function_call in function_calls:
+ self.add_function_call(function_call)
+
+ def add_tool(self, tool: gapic_tool_types.Tool) -> gapic_tool_types.Tool:
+ counted_tool = gapic_tool_types.Tool()
+ for function_declaration in tool.function_declarations:
+ counted_function_declaration = self._function_declaration_traverse(
+ function_declaration
+ )
+ counted_tool.function_declarations.append(counted_function_declaration)
+ if counted_tool._pb != tool._pb:
+ raise ValueError(
+ f"Tool argument contains unsupported types for token counting. Supported fields {counted_tool}. Got {tool}."
+ )
+
+ def add_tools(self, tools: Iterable[gapic_tool_types.Tool]) -> None:
+ for tool in tools:
+ self.add_tool(tool)
+
+ def add_function_responses(
+ self, function_responses: Iterable[gapic_tool_types.FunctionResponse]
+ ) -> None:
+ for function_response in function_responses:
+ self.add_function_response(function_response)
+
+ def add_function_response(
+ self, function_response: gapic_tool_types.FunctionResponse
+ ) -> None:
+ counted_function_response = gapic_tool_types.FunctionResponse()
+ self._texts.append(function_response.name)
+ counted_struct = self._struct_traverse(function_response._pb.response)
+ counted_function_response.name = function_response.name
+ counted_function_response.response = counted_struct
+ if counted_function_response._pb != function_response._pb:
+ raise ValueError(
+ f"Function response argument contains unsupported types for token counting. Supported fields {counted_function_response}. Got {function_response}."
+ )
+
+ def _function_declaration_traverse(
+ self, function_declaration: gapic_tool_types.FunctionDeclaration
+ ) -> gapic_tool_types.FunctionDeclaration:
+ counted_function_declaration = gapic_tool_types.FunctionDeclaration()
+ self._texts.append(function_declaration.name)
+ counted_function_declaration.name = function_declaration.name
+ if function_declaration.description:
+ self._texts.append(function_declaration.description)
+ counted_function_declaration.description = function_declaration.description
+ if function_declaration.parameters:
+ counted_parameters = self._schema_traverse(function_declaration.parameters)
+ counted_function_declaration.parameters = counted_parameters
+ if function_declaration.response:
+ counted_response = self._schema_traverse(function_declaration.response)
+ counted_function_declaration.response = counted_response
+ return counted_function_declaration
+
+ def _schema_traverse(self, schema: openapi.Schema) -> openapi.Schema:
+ """Processes a schema and adds relevant text to the accumulator.
+
+ Args:
+ schema: The schema to process.
+
+ Returns:
+ The new schema object with only countable fields.
+ """
+ counted_schema = openapi.Schema()
+ if "type_" in schema:
+ counted_schema.type = schema.type
+ if "title" in schema:
+ counted_schema.title = schema.title
+ if "default" in schema:
+ counted_schema.default = schema.default
+ if "format_" in schema:
+ self._texts.append(schema.format_)
+ counted_schema.format_ = schema.format_
+ if "description" in schema:
+ self._texts.append(schema.description)
+ counted_schema.description = schema.description
+ if "enum" in schema:
+ self._texts.extend(schema.enum)
+ counted_schema.enum = schema.enum
+ if "required" in schema:
+ self._texts.extend(schema.required)
+ counted_schema.required = schema.required
+ if "items" in schema:
+ counted_schema_items = self._schema_traverse(schema.items)
+ counted_schema.items = counted_schema_items
+ if "properties" in schema:
+ d = {}
+ for key, value in schema.properties.items():
+ self._texts.append(key)
+ counted_value = self._schema_traverse(value)
+ d[key] = counted_value
+ counted_schema.properties.update(d)
+ if "example" in schema:
+ counted_schema_example = self._value_traverse(schema._pb.example)
+ counted_schema.example = counted_schema_example
+ return counted_schema
+
+ def _struct_traverse(self, struct: struct_pb2.Struct) -> struct_pb2.Struct:
+ """Processes a struct and adds relevant text to the accumulator.
+
+ Args:
+ struct: The struct to process.
+
+ Returns:
+ The new struct object with only countable fields.
+ """
+ counted_struct = struct_pb2.Struct()
+ self._texts.extend(list(struct.fields.keys()))
+ for key, val in struct.fields.items():
+ counted_struct_fields = self._value_traverse(val)
+ if isinstance(counted_struct_fields, str):
+ counted_struct.fields[key] = counted_struct_fields
+ else:
+ counted_struct.fields[key].MergeFrom(counted_struct_fields)
+ return counted_struct
+
+ def _value_traverse(self, value: struct_pb2.Value) -> struct_pb2.Value:
+ """Processes a struct field and adds relevant text to the accumulator.
+
+ Args:
+ struct: The struct field to process.
+
+ Returns:
+ The new struct field object with only countable fields.
+ """
+ kind = value.WhichOneof("kind")
+ counted_value = struct_pb2.Value()
+ if kind == "string_value":
+ self._texts.append(value.string_value)
+ counted_value.string_value = value.string_value
+ elif kind == "struct_value":
+ counted_struct = self._struct_traverse(value.struct_value)
+ counted_value.struct_value.MergeFrom(counted_struct)
+ elif kind == "list_value":
+ counted_list_value = struct_pb2.ListValue()
+ for item in value.list_value.values:
+ counted_value = self._value_traverse(item)
+ counted_list_value.values.append(counted_value)
+ counted_value.list_value.MergeFrom(counted_list_value)
+ return counted_value
+
+
class Tokenizer:
"""A tokenizer that can parse text into tokens."""
@@ -228,10 +440,17 @@ def __init__(self, tokenizer_name: str):
Args:
name: The name of the tokenizer.
+
"""
self._sentencepiece_adapter = _SentencePieceAdaptor(tokenizer_name)
- def count_tokens(self, contents: ContentsType) -> CountTokensResult:
+ def count_tokens(
+ self,
+ contents: ContentsType,
+ *,
+ tools: Optional[List["Tool"]] = None,
+ system_instruction: Optional[PartsType] = None,
+ ) -> CountTokensResult:
r"""Counts the number of tokens in the text-only contents.
Args:
@@ -244,16 +463,34 @@ def count_tokens(self, contents: ContentsType) -> CountTokensResult:
* List[Union[str, Part]],
* List[Content]
Throws an error if the contents contain non-text content.
-
+ tools: A list of tools (functions) that the model can try calling.
+ system_instruction: The provided system instructions for the model.
+ Note: only text should be used in parts and content in each part
+ will be in a separate paragraph.
Returns:
A CountTokensResult object containing the total number of tokens in
the contents.
"""
- return self._sentencepiece_adapter.count_tokens(
- _to_canonical_contents_texts(contents)
- )
+ text_accumulator = _TextsAccumulator()
+ if _is_string_inputs(contents):
+ text_accumulator.add_texts(contents)
+ else:
+ gapic_contents = _to_gapic_contents(contents)
+ for content in gapic_contents:
+ text_accumulator.add_content(content)
+
+ if tools:
+ text_accumulator.add_tools((tool._raw_tool for tool in tools))
+
+ if system_instruction:
+ if _is_string_inputs(system_instruction):
+ text_accumulator.add_texts(system_instruction)
+ else:
+ text_accumulator.add_content(_to_content(system_instruction))
+
+ return self._sentencepiece_adapter.count_tokens(text_accumulator.get_texts())
def compute_tokens(self, contents: ContentsType) -> ComputeTokensResult:
r"""Computes the tokens ids and string pieces in the text-only contents.
@@ -272,13 +509,55 @@ def compute_tokens(self, contents: ContentsType) -> ComputeTokensResult:
Returns:
A ComputeTokensResult object containing the tokens ids and string
pieces in the contents.
+
+ Examples:
+ compute_tokens(["hello world", "what's the weather today"])
+ outputs:
+ ComputeTokensResult(tokens_info=[TokensInfo(token_ids=[17534, 2134], tokens=[b'hello', b' world'], role='user'), TokensInfo(token_ids=[84264, 235341], tokens=[b'Goodbye', b'!'], role='user')], token_info_list=...The same as tokens_info)
+
"""
+ text_accumulator = _TextsAccumulator()
+ if _is_string_inputs(contents):
+ text_accumulator.add_texts(contents)
+ else:
+ gapic_contents = _to_gapic_contents(contents)
+ for content in gapic_contents:
+ text_accumulator.add_content(content)
+
return self._sentencepiece_adapter.compute_tokens(
- contents=_to_canonical_contents_texts(contents),
+ contents=text_accumulator.get_texts(),
roles=_to_canonical_roles(contents),
)
+class PreviewTokenizer(Tokenizer):
+ def compute_tokens(self, contents: ContentsType) -> PreviewComputeTokensResult:
+ return PreviewComputeTokensResult(
+ tokens_info=super().compute_tokens(contents).tokens_info
+ )
+
+
+def _get_tokenizer_for_model_preview(model_name: str) -> PreviewTokenizer:
+ """Returns a tokenizer for the given tokenizer name.
+
+ Usage:
+ ```
+ tokenizer = get_tokenizer_for_model("gemini-1.5-pro-001")
+ print(tokenizer.count_tokens("Hello world!"))
+ ```
+
+ Supported models can be found at
+ https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models.
+
+ Args:
+ model_name: Specify the tokenizer is from which model.
+ """
+ if not model_name:
+ raise ValueError("model_name must not be empty.")
+
+ return PreviewTokenizer(get_tokenizer_name(model_name))
+
+
def get_tokenizer_for_model(model_name: str) -> Tokenizer:
"""Returns a tokenizer for the given tokenizer name.
@@ -288,6 +567,9 @@ def get_tokenizer_for_model(model_name: str) -> Tokenizer:
print(tokenizer.count_tokens("Hello world!"))
```
+ Supported models can be found at
+ https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models.
+
Args:
model_name: Specify the tokenizer is from which model.
"""
diff --git a/vertexai/vision_models/_vision_models.py b/vertexai/vision_models/_vision_models.py
index f9de7ef3ff..a80f0b2dc1 100644
--- a/vertexai/vision_models/_vision_models.py
+++ b/vertexai/vision_models/_vision_models.py
@@ -792,12 +792,15 @@ def upscale_image(
self,
image: Union["Image", "GeneratedImage"],
new_size: Optional[int] = 2048,
+ upscale_factor: Optional[Literal["x2", "x4"]] = None,
+ output_mime_type: Optional[Literal["image/png", "image/jpeg"]] = "image/png",
+ output_compression_quality: Optional[int] = None,
output_gcs_uri: Optional[str] = None,
) -> "Image":
"""Upscales an image.
- This supports upscaling images generated through the `generate_images()` method,
- or upscaling a new image that is 1024x1024.
+ This supports upscaling images generated through the `generate_images()`
+ method, or upscaling a new image.
Examples::
@@ -812,27 +815,63 @@ def upscale_image(
my_image = Image.load_from_file("my-image.png")
model.upscale_image(image=my_image)
+ # Upscale a new arbitrary sized image using a x2 or x4 upscaling factor
+ my_image = Image.load_from_file("my-image.png")
+ model.upscale_image(image=my_image, upscale_factor="x2")
+
+ # Upscale an image and get the result in JPEG format
+ my_image = Image.load_from_file("my-image.png")
+ model.upscale_image(image=my_image, output_mime_type="image/jpeg",
+ output_compression_quality=90)
+
Args:
- image (Union[GeneratedImage, Image]):
- Required. The generated image to upscale.
- new_size (int):
- The size of the biggest dimension of the upscaled image. Only 2048 and 4096 are currently
- supported. Results in a 2048x2048 or 4096x4096 image. Defaults to 2048 if not provided.
- output_gcs_uri: Google Cloud Storage uri to store the upscaled images.
+ image (Union[GeneratedImage, Image]): Required. The generated image
+ to upscale.
+ new_size (int): The size of the biggest dimension of the upscaled
+ image.
+ Only 2048 and 4096 are currently supported. Results in a
+ 2048x2048 or 4096x4096 image. Defaults to 2048 if not provided.
+ upscale_factor: The upscaling factor. Supported values are "x2" and
+ "x4". Defaults to None.
+ output_mime_type: The mime type of the output image. Supported values
+ are "image/png" and "image/jpeg". Defaults to "image/png".
+ output_compression_quality: The compression quality of the output
+ image
+ as an int (0-100). Only applicable if the output mime type is
+ "image/jpeg". Defaults to None.
+ output_gcs_uri: Google Cloud Storage uri to store the upscaled
+ images.
Returns:
An `Image` object.
"""
+ target_image_size = new_size if new_size else None
+ longest_dim = max(image._size[0], image._size[1])
- # Currently this method only supports 1024x1024 images
- if image._size[0] != 1024 and image._size[1] != 1024:
- raise ValueError(
- "Upscaling is currently only supported on images that are 1024x1024."
- )
+ if not new_size and not upscale_factor:
+ raise ValueError("Either new_size or upscale_factor must be provided.")
+
+ if not upscale_factor:
+ x2_factor = 2.0
+ x4_factor = 4.0
+ epsilon = 0.1
+ is_upscaling_x2_request = abs(new_size / longest_dim - x2_factor) < epsilon
+ is_upscaling_x4_request = abs(new_size / longest_dim - x4_factor) < epsilon
+ if not is_upscaling_x2_request and not is_upscaling_x4_request:
+ raise ValueError(
+ "Only x2 and x4 upscaling are currently supported. Requested"
+ f" upscaling factor: {new_size / longest_dim}"
+ )
+ else:
+ if upscale_factor == "x2":
+ target_image_size = longest_dim * 2
+ else:
+ target_image_size = longest_dim * 4
if new_size not in _SUPPORTED_UPSCALING_SIZES:
raise ValueError(
- f"Only the folowing square upscaling sizes are currently supported: {_SUPPORTED_UPSCALING_SIZES}."
+ "Only the folowing square upscaling sizes are currently supported:"
+ f" {_SUPPORTED_UPSCALING_SIZES}."
)
instance = {"prompt": ""}
@@ -847,14 +886,25 @@ def upscale_image(
}
parameters = {
- "sampleImageSize": str(new_size),
"sampleCount": 1,
"mode": "upscale",
}
+ if upscale_factor:
+ parameters["upscaleConfig"] = {"upscaleFactor": upscale_factor}
+
+ else:
+ parameters["sampleImageSize"] = str(new_size)
+
if output_gcs_uri is not None:
parameters["storageUri"] = output_gcs_uri
+ parameters["outputOptions"] = {"mimeType": output_mime_type}
+ if output_mime_type == "image/jpeg" and output_compression_quality is not None:
+ parameters["outputOptions"][
+ "compressionQuality"
+ ] = output_compression_quality
+
response = self._endpoint.predict(
instances=[instance],
parameters=parameters,
@@ -868,7 +918,7 @@ def upscale_image(
else:
generation_parameters = {}
- generation_parameters["upscaled_image_size"] = new_size
+ generation_parameters["upscaled_image_size"] = target_image_size
encoded_bytes = upscaled_image.get("bytesBase64Encoded")
return GeneratedImage(