diff --git a/.kokoro/build.sh b/.kokoro/build.sh index 64fe343878..031fa2200a 100755 --- a/.kokoro/build.sh +++ b/.kokoro/build.sh @@ -33,9 +33,6 @@ export GOOGLE_APPLICATION_CREDENTIALS=${KOKORO_GFILE_DIR}/service-account.json # Setup project id. export PROJECT_ID=$(cat "${KOKORO_GFILE_DIR}/project-id.json") -# Setup staging endpoint. -export STAGING_ENDPOINT=$(cat "${KOKORO_KEYSTORE_DIR}/73713_vertexai-staging-endpoint") - # Remove old nox python3 -m pip uninstall --yes --quiet nox-automation diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 3ed68cb97c..979c42f7ec 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "1.64.0" + ".": "1.65.0" } diff --git a/CHANGELOG.md b/CHANGELOG.md index 49f8adb800..d31b3a100f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,29 @@ # Changelog +## [1.65.0](https://github.com/googleapis/python-aiplatform/compare/v1.64.0...v1.65.0) (2024-09-04) + + +### âš  BREAKING CHANGES + +* Tokenization - Deprecated `ComputeTokenResult.token_info_list` in favor of `ComputeTokenResult.tokens_info` + +### Features + +* Add support for system instruction and tools in tokenization. ([72fcc06](https://github.com/googleapis/python-aiplatform/commit/72fcc063ed4a086da0ad37ec2ac58860d4e79051)) +* Add vector search alpha to rag retrieval for hybrid search ranking ([6624ebe](https://github.com/googleapis/python-aiplatform/commit/6624ebe22726942dd70781122352f47268ee2dee)) +* Adding Weaviate Vector DB option for RAG corpuses to SDK ([9b28202](https://github.com/googleapis/python-aiplatform/commit/9b28202a1bb17f54e042301d6cdac3b6aa826797)) +* GenAI - Added `system_instruction` and `tools` support to `GenerativeModel.count_tokens` ([50fca69](https://github.com/googleapis/python-aiplatform/commit/50fca693b2f3b1a0b61867dc136be5a468fb2b2f)) +* GenAI - Added Llama3 support in GenAI batch prediction ([6166152](https://github.com/googleapis/python-aiplatform/commit/6166152844dc0078f7a5a02355ef3555cc428cfa)) +* PrivateEndpoint.stream_raw_predict ([197f333](https://github.com/googleapis/python-aiplatform/commit/197f333be5a075d41f98b762cd933cd2e89cecae)) +* Support reserved_ip_ranges for VPC network in Ray on Vertex cluster ([36a56b9](https://github.com/googleapis/python-aiplatform/commit/36a56b99f9e53d19d80c2bff3bf55c208988c518)) +* Tokenization - Deprecated `ComputeTokenResult.token_info_list` in favor of `ComputeTokenResult.tokens_info` ([efbcb54](https://github.com/googleapis/python-aiplatform/commit/efbcb54e0d5df4d65a79e60afdbc5c328538aef6)) + + +### Bug Fixes + +* Tensorboard - Fixed bug in tensorboard uploader where profile logs are not uploaded from nested run directories. ([37627de](https://github.com/googleapis/python-aiplatform/commit/37627de4ec12b8c51abf41524b9e7e3adf2dab54)) +* Tokenizers - Fixed `Tokenizer.compute_tokens` ([c29fa5d](https://github.com/googleapis/python-aiplatform/commit/c29fa5d98fc1c1550c039e84ff2f5725818c2231)) + ## [1.64.0](https://github.com/googleapis/python-aiplatform/compare/v1.63.0...v1.64.0) (2024-08-27) diff --git a/gemini_docs/vertexai/services.rst b/gemini_docs/vertexai/services.rst index 8a2cca514e..3496c44afc 100644 --- a/gemini_docs/vertexai/services.rst +++ b/gemini_docs/vertexai/services.rst @@ -45,6 +45,16 @@ Vertex AI SDK :show-inheritance: :inherited-members: +.. automodule:: vertexai.preview.tuning + :members: + :show-inheritance: + :inherited-members: + +.. automodule:: vertexai.preview.tuning.sft + :members: + :show-inheritance: + :inherited-members: + .. automodule:: vertexai.evaluation :members: :show-inheritance: diff --git a/google/cloud/aiplatform/gapic_version.py b/google/cloud/aiplatform/gapic_version.py index 8fd4500cf1..b3833ad200 100644 --- a/google/cloud/aiplatform/gapic_version.py +++ b/google/cloud/aiplatform/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.64.0" # {x-release-please-version} +__version__ = "1.65.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index d3ddc86962..dc8e105ef6 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -3666,6 +3666,95 @@ def raw_predict( headers=headers_with_token, ) + def stream_raw_predict( + self, + body: bytes, + headers: Dict[str, str], + endpoint_override: Optional[str] = None, + ) -> Iterator[bytes]: + """Make a streaming prediction request using arbitrary headers. + + Example usage: + my_endpoint = aiplatform.PrivateEndpoint(ENDPOINT_ID) + + # Prepare the request body + request_body = json.dumps({...}).encode('utf-8') + + # Define the headers + headers = { + 'Content-Type': 'application/json', + } + + # Use stream_raw_predict to send the request and process the response + for stream_response in psc_endpoint.stream_raw_predict( + body=request_body, + headers=headers, + endpoint_override="10.128.0.26" # Replace with your actual endpoint + ): + stream_response_text = stream_response.decode('utf-8') + + Args: + body (bytes): + The body of the prediction request in bytes. This must not + exceed 10 mb per request. + headers (Dict[str, str]): + The header of the request as a dictionary. There are no + restrictions on the header. + endpoint_override (Optional[str]): + The Private Service Connect endpoint's IP address or DNS that + points to the endpoint's service attachment. + + Yields: + predictions (Iterator[bytes]): + The streaming prediction results as lines of bytes. + + Raises: + ValueError: If a endpoint override is not provided for PSC based + endpoint. + ValueError: If a endpoint override is invalid for PSC based endpoint. + """ + self.wait() + if self.network or not self.private_service_connect_config: + raise ValueError( + "PSA based private endpoint does not support streaming prediction." + ) + + if self.private_service_connect_config: + if not endpoint_override: + raise ValueError( + "Cannot make a predict request because endpoint override is" + "not provided. Please ensure an endpoint override is" + "provided." + ) + if not self._validate_endpoint_override(endpoint_override): + raise ValueError( + "Invalid endpoint override provided. Please only use IP" + "address or DNS." + ) + if not self.credentials.valid: + self.credentials.refresh(google_auth_requests.Request()) + + token = self.credentials.token + headers_with_token = dict(headers) + headers_with_token["Authorization"] = f"Bearer {token}" + + if not self.authorized_session: + self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES + self.authorized_session = google_auth_requests.AuthorizedSession( + self.credentials + ) + + url = f"https://{endpoint_override}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:streamRawPredict" + with self.authorized_session.post( + url=url, + data=body, + headers=headers_with_token, + stream=True, + verify=False, + ) as resp: + for line in resp.iter_lines(): + yield line + def explain(self): raise NotImplementedError( f"{self.__class__.__name__} class does not support 'explain' as of now." diff --git a/google/cloud/aiplatform/tensorboard/logdir_loader.py b/google/cloud/aiplatform/tensorboard/logdir_loader.py index ea9b6200f7..bd55e9505a 100644 --- a/google/cloud/aiplatform/tensorboard/logdir_loader.py +++ b/google/cloud/aiplatform/tensorboard/logdir_loader.py @@ -70,7 +70,6 @@ def synchronize_runs(self): runs_seen = set() for subdir in io_wrapper.GetLogdirSubdirectories(self._logdir): run = os.path.relpath(subdir, self._logdir) - run = run.replace("/", "-").replace("_", "-") runs_seen.add(run) if run not in self._directory_loaders: logger.info("- Adding run for relative directory %s", run) diff --git a/google/cloud/aiplatform/tensorboard/plugins/tf_profiler/profile_uploader.py b/google/cloud/aiplatform/tensorboard/plugins/tf_profiler/profile_uploader.py index f1967420ba..cb8f119172 100644 --- a/google/cloud/aiplatform/tensorboard/plugins/tf_profiler/profile_uploader.py +++ b/google/cloud/aiplatform/tensorboard/plugins/tf_profiler/profile_uploader.py @@ -176,7 +176,10 @@ def send_request(self, run_name: str): self._profile_dir(run_name) ) - tb_run = self._one_platform_resource_manager.get_run_resource_name(run_name) + experiment_run_name = uploader_utils.reformat_run_name(run_name) + tb_run = self._one_platform_resource_manager.get_run_resource_name( + experiment_run_name + ) if run_name not in self._run_to_file_request_sender: self._run_to_file_request_sender[ diff --git a/google/cloud/aiplatform/tensorboard/uploader.py b/google/cloud/aiplatform/tensorboard/uploader.py index 5e928ac812..c24455ae5e 100644 --- a/google/cloud/aiplatform/tensorboard/uploader.py +++ b/google/cloud/aiplatform/tensorboard/uploader.py @@ -417,7 +417,8 @@ def _pre_create_runs_and_time_series(self): plugin_data=metadata.plugin_data.content, ) - self._one_platform_resource_manager.batch_create_runs(run_names) + experiment_runs = [uploader_utils.reformat_run_name(run) for run in run_names] + self._one_platform_resource_manager.batch_create_runs(experiment_runs) self._one_platform_resource_manager.batch_create_time_series( run_tag_name_to_time_series_proto ) @@ -451,7 +452,9 @@ def _upload_once(self): ) run_to_events[profile_run_name] = None - self._experiment_runs = run_to_events.keys() + self._experiment_runs = [ + uploader_utils.reformat_run_name(run) for run in run_to_events.keys() + ] with self._tracker.send_tracker(): self._dispatcher.dispatch_requests(run_to_events) @@ -807,6 +810,7 @@ def _add_event_internal( metadata: tf.compat.v1.SummaryMetadata, ): self._num_values += 1 + run_name = uploader_utils.reformat_run_name(run_name) time_series_data_proto = self._run_to_tag_to_time_series_data[run_name].get( value.tag ) diff --git a/google/cloud/aiplatform/tensorboard/uploader_utils.py b/google/cloud/aiplatform/tensorboard/uploader_utils.py index c9723d0f89..8c5b9ec811 100644 --- a/google/cloud/aiplatform/tensorboard/uploader_utils.py +++ b/google/cloud/aiplatform/tensorboard/uploader_utils.py @@ -47,6 +47,11 @@ logger.setLevel(logging.WARNING) +def reformat_run_name(run_name: str) -> str: + """Reformats the run name to be compatible with One Platform.""" + return run_name.replace("/", "-").replace("_", "-") + + class RateLimiter: """Helper class for rate-limiting using a fixed minimum interval.""" diff --git a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py index 8fd4500cf1..b3833ad200 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.64.0" # {x-release-please-version} +__version__ = "1.65.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py index 8fd4500cf1..b3833ad200 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.64.0" # {x-release-please-version} +__version__ = "1.65.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py index 8fd4500cf1..b3833ad200 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.64.0" # {x-release-please-version} +__version__ = "1.65.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py index 8fd4500cf1..b3833ad200 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.64.0" # {x-release-please-version} +__version__ = "1.65.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py index 8fd4500cf1..b3833ad200 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.64.0" # {x-release-please-version} +__version__ = "1.65.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py index 8fd4500cf1..b3833ad200 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.64.0" # {x-release-please-version} +__version__ = "1.65.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py index 8fd4500cf1..b3833ad200 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.64.0" # {x-release-please-version} +__version__ = "1.65.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py index 8fd4500cf1..b3833ad200 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.64.0" # {x-release-please-version} +__version__ = "1.65.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py index 8fd4500cf1..b3833ad200 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.64.0" # {x-release-please-version} +__version__ = "1.65.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py index 8fd4500cf1..b3833ad200 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.64.0" # {x-release-please-version} +__version__ = "1.65.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py index 8fd4500cf1..b3833ad200 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.64.0" # {x-release-please-version} +__version__ = "1.65.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py index 8fd4500cf1..b3833ad200 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.64.0" # {x-release-please-version} +__version__ = "1.65.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py index 8fd4500cf1..b3833ad200 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.64.0" # {x-release-please-version} +__version__ = "1.65.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py index 8fd4500cf1..b3833ad200 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.64.0" # {x-release-please-version} +__version__ = "1.65.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py index 8fd4500cf1..b3833ad200 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.64.0" # {x-release-please-version} +__version__ = "1.65.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py index 8fd4500cf1..b3833ad200 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.64.0" # {x-release-please-version} +__version__ = "1.65.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/version.py b/google/cloud/aiplatform/version.py index c50dd04e4a..8df11f50b0 100644 --- a/google/cloud/aiplatform/version.py +++ b/google/cloud/aiplatform/version.py @@ -15,4 +15,4 @@ # limitations under the License. # -__version__ = "1.64.0" +__version__ = "1.65.0" diff --git a/google/cloud/aiplatform/vertex_ray/cluster_init.py b/google/cloud/aiplatform/vertex_ray/cluster_init.py index da3da31449..bff297bd8a 100644 --- a/google/cloud/aiplatform/vertex_ray/cluster_init.py +++ b/google/cloud/aiplatform/vertex_ray/cluster_init.py @@ -61,6 +61,7 @@ def create_ray_cluster( enable_metrics_collection: Optional[bool] = True, enable_logging: Optional[bool] = True, psc_interface_config: Optional[resources.PscIConfig] = None, + reserved_ip_ranges: Optional[List[str]] = None, labels: Optional[Dict[str, str]] = None, ) -> str: """Create a ray cluster on the Vertex AI. @@ -126,6 +127,11 @@ def create_ray_cluster( enable_metrics_collection: Enable Ray metrics collection for visualization. enable_logging: Enable exporting Ray logs to Cloud Logging. psc_interface_config: PSC-I config. + reserved_ip_ranges: A list of names for the reserved IP ranges under + the VPC network that can be used for this cluster. If set, we will + deploy the cluster within the provided IP ranges. Otherwise, the + cluster is deployed to any IP ranges under the provided VPC network. + Example: ["vertex-ai-ip-range"]. labels: The labels with user-defined metadata to organize Ray cluster. @@ -325,6 +331,7 @@ def create_ray_cluster( labels=labels, resource_runtime_spec=resource_runtime_spec, psc_interface_config=gapic_psc_interface_config, + reserved_ip_ranges=reserved_ip_ranges, ) location = initializer.global_config.location diff --git a/google/cloud/aiplatform/vertex_ray/predict/xgboost/register.py b/google/cloud/aiplatform/vertex_ray/predict/xgboost/register.py index c93c6ce3e9..5aa35fa92c 100644 --- a/google/cloud/aiplatform/vertex_ray/predict/xgboost/register.py +++ b/google/cloud/aiplatform/vertex_ray/predict/xgboost/register.py @@ -43,7 +43,10 @@ import xgboost except ModuleNotFoundError as mnfe: - raise ModuleNotFoundError("XGBoost isn't installed.") from mnfe + if ray.__version__ == "2.9.3": + raise ModuleNotFoundError("XGBoost isn't installed.") from mnfe + else: + xgboost = None def register_xgboost( diff --git a/google/cloud/aiplatform/vertex_ray/util/_gapic_utils.py b/google/cloud/aiplatform/vertex_ray/util/_gapic_utils.py index 73e60c8121..8ed158e012 100644 --- a/google/cloud/aiplatform/vertex_ray/util/_gapic_utils.py +++ b/google/cloud/aiplatform/vertex_ray/util/_gapic_utils.py @@ -150,6 +150,7 @@ def persistent_resource_to_cluster( cluster = Cluster( cluster_resource_name=persistent_resource.name, network=persistent_resource.network, + reserved_ip_ranges=persistent_resource.reserved_ip_ranges, state=persistent_resource.state.name, labels=persistent_resource.labels, dashboard_address=dashboard_address, @@ -238,6 +239,11 @@ def persistent_resource_to_cluster( custom_image=head_image_uri, ) ) + if head_resource_pool.autoscaling_spec: + worker_node_types[0].autoscaling_spec = AutoscalingSpec( + min_replica_count=head_resource_pool.autoscaling_spec.min_replica_count, + max_replica_count=head_resource_pool.autoscaling_spec.max_replica_count, + ) for i in range(len(resource_pools) - 1): # Convert the second and more resource pools to vertex_ray.Resources, # and append then to worker_node_types. diff --git a/google/cloud/aiplatform/vertex_ray/util/resources.py b/google/cloud/aiplatform/vertex_ray/util/resources.py index 37fae9be8d..117cd69e1c 100644 --- a/google/cloud/aiplatform/vertex_ray/util/resources.py +++ b/google/cloud/aiplatform/vertex_ray/util/resources.py @@ -117,6 +117,11 @@ class Cluster: managed in the Vertex API service. For Ray Job API, VPC network is not required because cluster connection can be accessed through dashboard address. + reserved_ip_ranges: A list of names for the reserved IP ranges under + the VPC network that can be used for this cluster. If set, we will + deploy the cluster within the provided IP ranges. Otherwise, the + cluster is deployed to any IP ranges under the provided VPC network. + Example: ["vertex-ai-ip-range"]. service_account: Service account to be used for running Ray programs on the cluster. state: Describes the cluster state (defined in PersistentResource.State). @@ -140,6 +145,7 @@ class Cluster: cluster_resource_name: str = None network: str = None + reserved_ip_ranges: List[str] = None service_account: str = None state: PersistentResource.State = None python_version: str = None diff --git a/google/cloud/aiplatform_v1/gapic_version.py b/google/cloud/aiplatform_v1/gapic_version.py index 8fd4500cf1..b3833ad200 100644 --- a/google/cloud/aiplatform_v1/gapic_version.py +++ b/google/cloud/aiplatform_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.64.0" # {x-release-please-version} +__version__ = "1.65.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform_v1/services/schedule_service/async_client.py b/google/cloud/aiplatform_v1/services/schedule_service/async_client.py index a8b76c8eba..db7cf616f1 100644 --- a/google/cloud/aiplatform_v1/services/schedule_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/schedule_service/async_client.py @@ -46,6 +46,7 @@ from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1.services.schedule_service import pagers +from google.cloud.aiplatform_v1.types import notebook_service from google.cloud.aiplatform_v1.types import operation as gca_operation from google.cloud.aiplatform_v1.types import pipeline_service from google.cloud.aiplatform_v1.types import schedule @@ -88,6 +89,18 @@ class ScheduleServiceAsyncClient: parse_execution_path = staticmethod(ScheduleServiceClient.parse_execution_path) network_path = staticmethod(ScheduleServiceClient.network_path) parse_network_path = staticmethod(ScheduleServiceClient.parse_network_path) + notebook_execution_job_path = staticmethod( + ScheduleServiceClient.notebook_execution_job_path + ) + parse_notebook_execution_job_path = staticmethod( + ScheduleServiceClient.parse_notebook_execution_job_path + ) + notebook_runtime_template_path = staticmethod( + ScheduleServiceClient.notebook_runtime_template_path + ) + parse_notebook_runtime_template_path = staticmethod( + ScheduleServiceClient.parse_notebook_runtime_template_path + ) pipeline_job_path = staticmethod(ScheduleServiceClient.pipeline_job_path) parse_pipeline_job_path = staticmethod( ScheduleServiceClient.parse_pipeline_job_path diff --git a/google/cloud/aiplatform_v1/services/schedule_service/client.py b/google/cloud/aiplatform_v1/services/schedule_service/client.py index 33aa3f04d8..43a4c50a65 100644 --- a/google/cloud/aiplatform_v1/services/schedule_service/client.py +++ b/google/cloud/aiplatform_v1/services/schedule_service/client.py @@ -51,6 +51,7 @@ from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1.services.schedule_service import pagers +from google.cloud.aiplatform_v1.types import notebook_service from google.cloud.aiplatform_v1.types import operation as gca_operation from google.cloud.aiplatform_v1.types import pipeline_service from google.cloud.aiplatform_v1.types import schedule @@ -311,6 +312,50 @@ def parse_network_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def notebook_execution_job_path( + project: str, + location: str, + notebook_execution_job: str, + ) -> str: + """Returns a fully-qualified notebook_execution_job string.""" + return "projects/{project}/locations/{location}/notebookExecutionJobs/{notebook_execution_job}".format( + project=project, + location=location, + notebook_execution_job=notebook_execution_job, + ) + + @staticmethod + def parse_notebook_execution_job_path(path: str) -> Dict[str, str]: + """Parses a notebook_execution_job path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/notebookExecutionJobs/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def notebook_runtime_template_path( + project: str, + location: str, + notebook_runtime_template: str, + ) -> str: + """Returns a fully-qualified notebook_runtime_template string.""" + return "projects/{project}/locations/{location}/notebookRuntimeTemplates/{notebook_runtime_template}".format( + project=project, + location=location, + notebook_runtime_template=notebook_runtime_template, + ) + + @staticmethod + def parse_notebook_runtime_template_path(path: str) -> Dict[str, str]: + """Parses a notebook_runtime_template path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/notebookRuntimeTemplates/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def pipeline_job_path( project: str, diff --git a/google/cloud/aiplatform_v1/types/custom_job.py b/google/cloud/aiplatform_v1/types/custom_job.py index 709f7261f5..3068f16a91 100644 --- a/google/cloud/aiplatform_v1/types/custom_job.py +++ b/google/cloud/aiplatform_v1/types/custom_job.py @@ -558,6 +558,12 @@ class Scheduling(proto.Message): Optional. Indicates if the job should retry for internal errors after the job starts running. If true, overrides ``Scheduling.restart_job_on_worker_restart`` to false. + max_wait_duration (google.protobuf.duration_pb2.Duration): + Optional. This is the maximum duration that a job will wait + for the requested resources to be provisioned if the + scheduling strategy is set to [Strategy.DWS_FLEX_START]. If + set to 0, the job will wait indefinitely. The default is 24 + hours. """ class Strategy(proto.Enum): @@ -606,6 +612,11 @@ class Strategy(proto.Enum): proto.BOOL, number=5, ) + max_wait_duration: duration_pb2.Duration = proto.Field( + proto.MESSAGE, + number=6, + message=duration_pb2.Duration, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/schedule.py b/google/cloud/aiplatform_v1/types/schedule.py index 5d6e4277d1..8cd50a40fd 100644 --- a/google/cloud/aiplatform_v1/types/schedule.py +++ b/google/cloud/aiplatform_v1/types/schedule.py @@ -19,6 +19,7 @@ import proto # type: ignore +from google.cloud.aiplatform_v1.types import notebook_service from google.cloud.aiplatform_v1.types import pipeline_service from google.protobuf import timestamp_pb2 # type: ignore @@ -36,6 +37,10 @@ class Schedule(proto.Message): API calls based on user specified time specification and API request type. + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields @@ -56,6 +61,11 @@ class Schedule(proto.Message): CreatePipelineJobRequest.parent field is required (format: projects/{project}/locations/{location}). + This field is a member of `oneof`_ ``request``. + create_notebook_execution_job_request (google.cloud.aiplatform_v1.types.CreateNotebookExecutionJobRequest): + Request for + [NotebookService.CreateNotebookExecutionJob][google.cloud.aiplatform.v1.NotebookService.CreateNotebookExecutionJob]. + This field is a member of `oneof`_ ``request``. name (str): Immutable. The resource name of the Schedule. @@ -186,6 +196,12 @@ class RunResponse(proto.Message): message=pipeline_service.CreatePipelineJobRequest, ) ) + create_notebook_execution_job_request: notebook_service.CreateNotebookExecutionJobRequest = proto.Field( + proto.MESSAGE, + number=20, + oneof="request", + message=notebook_service.CreateNotebookExecutionJobRequest, + ) name: str = proto.Field( proto.STRING, number=1, diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py index 2fb939444e..3bd01743e9 100644 --- a/google/cloud/aiplatform_v1beta1/__init__.py +++ b/google/cloud/aiplatform_v1beta1/__init__.py @@ -1002,12 +1002,15 @@ from .types.unmanaged_container_model import UnmanagedContainerModel from .types.user_action_reference import UserActionReference from .types.value import Value +from .types.vertex_rag_data import CorpusStatus +from .types.vertex_rag_data import FileStatus from .types.vertex_rag_data import ImportRagFilesConfig from .types.vertex_rag_data import RagCorpus from .types.vertex_rag_data import RagEmbeddingModelConfig from .types.vertex_rag_data import RagFile from .types.vertex_rag_data import RagFileChunkingConfig from .types.vertex_rag_data import RagFileParsingConfig +from .types.vertex_rag_data import RagVectorDbConfig from .types.vertex_rag_data import UploadRagFileConfig from .types.vertex_rag_data_service import CreateRagCorpusOperationMetadata from .types.vertex_rag_data_service import CreateRagCorpusRequest @@ -1022,6 +1025,8 @@ from .types.vertex_rag_data_service import ListRagCorporaResponse from .types.vertex_rag_data_service import ListRagFilesRequest from .types.vertex_rag_data_service import ListRagFilesResponse +from .types.vertex_rag_data_service import UpdateRagCorpusOperationMetadata +from .types.vertex_rag_data_service import UpdateRagCorpusRequest from .types.vertex_rag_data_service import UploadRagFileRequest from .types.vertex_rag_data_service import UploadRagFileResponse from .types.vertex_rag_service import RagContexts @@ -1175,6 +1180,7 @@ "CopyModelOperationMetadata", "CopyModelRequest", "CopyModelResponse", + "CorpusStatus", "CountTokensRequest", "CountTokensResponse", "CreateArtifactRequest", @@ -1403,6 +1409,7 @@ "FetchFeatureValuesRequest", "FetchFeatureValuesResponse", "FileData", + "FileStatus", "FilterSplit", "FindNeighborsRequest", "FindNeighborsResponse", @@ -1799,6 +1806,7 @@ "RagFileChunkingConfig", "RagFileParsingConfig", "RagQuery", + "RagVectorDbConfig", "RawPredictRequest", "RayLogsSpec", "RayMetricSpec", @@ -2013,6 +2021,8 @@ "UpdateNotebookRuntimeTemplateRequest", "UpdatePersistentResourceOperationMetadata", "UpdatePersistentResourceRequest", + "UpdateRagCorpusOperationMetadata", + "UpdateRagCorpusRequest", "UpdateReasoningEngineOperationMetadata", "UpdateReasoningEngineRequest", "UpdateScheduleRequest", diff --git a/google/cloud/aiplatform_v1beta1/gapic_metadata.json b/google/cloud/aiplatform_v1beta1/gapic_metadata.json index 669d08cbc2..dfecadeb83 100644 --- a/google/cloud/aiplatform_v1beta1/gapic_metadata.json +++ b/google/cloud/aiplatform_v1beta1/gapic_metadata.json @@ -5368,6 +5368,11 @@ "list_rag_files" ] }, + "UpdateRagCorpus": { + "methods": [ + "update_rag_corpus" + ] + }, "UploadRagFile": { "methods": [ "upload_rag_file" @@ -5418,6 +5423,11 @@ "list_rag_files" ] }, + "UpdateRagCorpus": { + "methods": [ + "update_rag_corpus" + ] + }, "UploadRagFile": { "methods": [ "upload_rag_file" @@ -5468,6 +5478,11 @@ "list_rag_files" ] }, + "UpdateRagCorpus": { + "methods": [ + "update_rag_corpus" + ] + }, "UploadRagFile": { "methods": [ "upload_rag_file" diff --git a/google/cloud/aiplatform_v1beta1/gapic_version.py b/google/cloud/aiplatform_v1beta1/gapic_version.py index 8fd4500cf1..b3833ad200 100644 --- a/google/cloud/aiplatform_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.64.0" # {x-release-please-version} +__version__ = "1.65.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/async_client.py index 9ebb5e6e63..3c6d7b8f85 100644 --- a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/async_client.py @@ -417,6 +417,133 @@ async def sample_create_rag_corpus(): # Done; return the response. return response + async def update_rag_corpus( + self, + request: Optional[ + Union[vertex_rag_data_service.UpdateRagCorpusRequest, dict] + ] = None, + *, + rag_corpus: Optional[vertex_rag_data.RagCorpus] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Updates a RagCorpus. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1beta1 + + async def sample_update_rag_corpus(): + # Create a client + client = aiplatform_v1beta1.VertexRagDataServiceAsyncClient() + + # Initialize request argument(s) + rag_corpus = aiplatform_v1beta1.RagCorpus() + rag_corpus.display_name = "display_name_value" + + request = aiplatform_v1beta1.UpdateRagCorpusRequest( + rag_corpus=rag_corpus, + ) + + # Make the request + operation = client.update_rag_corpus(request=request) + + print("Waiting for operation to complete...") + + response = (await operation).result() + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.cloud.aiplatform_v1beta1.types.UpdateRagCorpusRequest, dict]]): + The request object. Request message for + [VertexRagDataService.UpdateRagCorpus][google.cloud.aiplatform.v1beta1.VertexRagDataService.UpdateRagCorpus]. + rag_corpus (:class:`google.cloud.aiplatform_v1beta1.types.RagCorpus`): + Required. The RagCorpus which + replaces the resource on the server. + + This corresponds to the ``rag_corpus`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.RagCorpus` A RagCorpus is a RagFile container and a project can have multiple + RagCorpora. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([rag_corpus]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vertex_rag_data_service.UpdateRagCorpusRequest): + request = vertex_rag_data_service.UpdateRagCorpusRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if rag_corpus is not None: + request.rag_corpus = rag_corpus + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_rag_corpus + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("rag_corpus.name", request.rag_corpus.name),) + ), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + vertex_rag_data.RagCorpus, + metadata_type=vertex_rag_data_service.UpdateRagCorpusOperationMetadata, + ) + + # Done; return the response. + return response + async def get_rag_corpus( self, request: Optional[ diff --git a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/client.py b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/client.py index 23e4fa5dce..23618833d5 100644 --- a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/client.py @@ -915,6 +915,130 @@ def sample_create_rag_corpus(): # Done; return the response. return response + def update_rag_corpus( + self, + request: Optional[ + Union[vertex_rag_data_service.UpdateRagCorpusRequest, dict] + ] = None, + *, + rag_corpus: Optional[vertex_rag_data.RagCorpus] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Updates a RagCorpus. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1beta1 + + def sample_update_rag_corpus(): + # Create a client + client = aiplatform_v1beta1.VertexRagDataServiceClient() + + # Initialize request argument(s) + rag_corpus = aiplatform_v1beta1.RagCorpus() + rag_corpus.display_name = "display_name_value" + + request = aiplatform_v1beta1.UpdateRagCorpusRequest( + rag_corpus=rag_corpus, + ) + + # Make the request + operation = client.update_rag_corpus(request=request) + + print("Waiting for operation to complete...") + + response = operation.result() + + # Handle the response + print(response) + + Args: + request (Union[google.cloud.aiplatform_v1beta1.types.UpdateRagCorpusRequest, dict]): + The request object. Request message for + [VertexRagDataService.UpdateRagCorpus][google.cloud.aiplatform.v1beta1.VertexRagDataService.UpdateRagCorpus]. + rag_corpus (google.cloud.aiplatform_v1beta1.types.RagCorpus): + Required. The RagCorpus which + replaces the resource on the server. + + This corresponds to the ``rag_corpus`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.RagCorpus` A RagCorpus is a RagFile container and a project can have multiple + RagCorpora. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([rag_corpus]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vertex_rag_data_service.UpdateRagCorpusRequest): + request = vertex_rag_data_service.UpdateRagCorpusRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if rag_corpus is not None: + request.rag_corpus = rag_corpus + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.update_rag_corpus] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("rag_corpus.name", request.rag_corpus.name),) + ), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + vertex_rag_data.RagCorpus, + metadata_type=vertex_rag_data_service.UpdateRagCorpusOperationMetadata, + ) + + # Done; return the response. + return response + def get_rag_corpus( self, request: Optional[ diff --git a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/base.py index c96aaec4ab..5023b56a66 100644 --- a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/base.py @@ -140,6 +140,11 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), + self.update_rag_corpus: gapic_v1.method.wrap_method( + self.update_rag_corpus, + default_timeout=None, + client_info=client_info, + ), self.get_rag_corpus: gapic_v1.method.wrap_method( self.get_rag_corpus, default_timeout=None, @@ -205,6 +210,15 @@ def create_rag_corpus( ]: raise NotImplementedError() + @property + def update_rag_corpus( + self, + ) -> Callable[ + [vertex_rag_data_service.UpdateRagCorpusRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + @property def get_rag_corpus( self, diff --git a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/grpc.py index adddd14811..61b5e95563 100644 --- a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/grpc.py @@ -284,6 +284,34 @@ def create_rag_corpus( ) return self._stubs["create_rag_corpus"] + @property + def update_rag_corpus( + self, + ) -> Callable[ + [vertex_rag_data_service.UpdateRagCorpusRequest], operations_pb2.Operation + ]: + r"""Return a callable for the update rag corpus method over gRPC. + + Updates a RagCorpus. + + Returns: + Callable[[~.UpdateRagCorpusRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_rag_corpus" not in self._stubs: + self._stubs["update_rag_corpus"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VertexRagDataService/UpdateRagCorpus", + request_serializer=vertex_rag_data_service.UpdateRagCorpusRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["update_rag_corpus"] + @property def get_rag_corpus( self, diff --git a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/grpc_asyncio.py index 13d74aca8a..f804a8a7fd 100644 --- a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/grpc_asyncio.py @@ -291,6 +291,35 @@ def create_rag_corpus( ) return self._stubs["create_rag_corpus"] + @property + def update_rag_corpus( + self, + ) -> Callable[ + [vertex_rag_data_service.UpdateRagCorpusRequest], + Awaitable[operations_pb2.Operation], + ]: + r"""Return a callable for the update rag corpus method over gRPC. + + Updates a RagCorpus. + + Returns: + Callable[[~.UpdateRagCorpusRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_rag_corpus" not in self._stubs: + self._stubs["update_rag_corpus"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VertexRagDataService/UpdateRagCorpus", + request_serializer=vertex_rag_data_service.UpdateRagCorpusRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["update_rag_corpus"] + @property def get_rag_corpus( self, @@ -531,6 +560,11 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), + self.update_rag_corpus: gapic_v1.method_async.wrap_method( + self.update_rag_corpus, + default_timeout=None, + client_info=client_info, + ), self.get_rag_corpus: gapic_v1.method_async.wrap_method( self.get_rag_corpus, default_timeout=None, diff --git a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/rest.py index f623649e32..856524bcce 100644 --- a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/rest.py @@ -139,6 +139,14 @@ def post_list_rag_files(self, response): logging.log(f"Received response: {response}") return response + def pre_update_rag_corpus(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_update_rag_corpus(self, response): + logging.log(f"Received response: {response}") + return response + def pre_upload_rag_file(self, request, metadata): logging.log(f"Received request: {request}") return request, metadata @@ -345,6 +353,31 @@ def post_list_rag_files( """ return response + def pre_update_rag_corpus( + self, + request: vertex_rag_data_service.UpdateRagCorpusRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[ + vertex_rag_data_service.UpdateRagCorpusRequest, Sequence[Tuple[str, str]] + ]: + """Pre-rpc interceptor for update_rag_corpus + + Override in a subclass to manipulate the request or metadata + before they are sent to the VertexRagDataService server. + """ + return request, metadata + + def post_update_rag_corpus( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for update_rag_corpus + + Override in a subclass to manipulate the response + after it is returned by the VertexRagDataService server but before + it is returned to user code. + """ + return response + def pre_upload_rag_file( self, request: vertex_rag_data_service.UploadRagFileRequest, @@ -3465,6 +3498,102 @@ def __call__( resp = self._interceptor.post_list_rag_files(resp) return resp + class _UpdateRagCorpus(VertexRagDataServiceRestStub): + def __hash__(self): + return hash("UpdateRagCorpus") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: vertex_rag_data_service.UpdateRagCorpusRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + r"""Call the update rag corpus method over HTTP. + + Args: + request (~.vertex_rag_data_service.UpdateRagCorpusRequest): + The request object. Request message for + [VertexRagDataService.UpdateRagCorpus][google.cloud.aiplatform.v1beta1.VertexRagDataService.UpdateRagCorpus]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operations_pb2.Operation: + This resource represents a + long-running operation that is the + result of a network API call. + + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "patch", + "uri": "/v1beta1/{rag_corpus.name=projects/*/locations/*/ragCorpora/*}", + "body": "rag_corpus", + }, + ] + request, metadata = self._interceptor.pre_update_rag_corpus( + request, metadata + ) + pb_request = vertex_rag_data_service.UpdateRagCorpusRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = operations_pb2.Operation() + json_format.Parse(response.content, resp, ignore_unknown_fields=True) + resp = self._interceptor.post_update_rag_corpus(resp) + return resp + class _UploadRagFile(VertexRagDataServiceRestStub): def __hash__(self): return hash("UploadRagFile") @@ -3640,6 +3769,16 @@ def list_rag_files( # In C++ this would require a dynamic_cast return self._ListRagFiles(self._session, self._host, self._interceptor) # type: ignore + @property + def update_rag_corpus( + self, + ) -> Callable[ + [vertex_rag_data_service.UpdateRagCorpusRequest], operations_pb2.Operation + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._UpdateRagCorpus(self._session, self._host, self._interceptor) # type: ignore + @property def upload_rag_file( self, diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py index c2bf840843..7ed18b6fd2 100644 --- a/google/cloud/aiplatform_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform_v1beta1/types/__init__.py @@ -1141,12 +1141,15 @@ Value, ) from .vertex_rag_data import ( + CorpusStatus, + FileStatus, ImportRagFilesConfig, RagCorpus, RagEmbeddingModelConfig, RagFile, RagFileChunkingConfig, RagFileParsingConfig, + RagVectorDbConfig, UploadRagFileConfig, ) from .vertex_rag_data_service import ( @@ -1163,6 +1166,8 @@ ListRagCorporaResponse, ListRagFilesRequest, ListRagFilesResponse, + UpdateRagCorpusOperationMetadata, + UpdateRagCorpusRequest, UploadRagFileRequest, UploadRagFileResponse, ) @@ -2079,12 +2084,15 @@ "UnmanagedContainerModel", "UserActionReference", "Value", + "CorpusStatus", + "FileStatus", "ImportRagFilesConfig", "RagCorpus", "RagEmbeddingModelConfig", "RagFile", "RagFileChunkingConfig", "RagFileParsingConfig", + "RagVectorDbConfig", "UploadRagFileConfig", "CreateRagCorpusOperationMetadata", "CreateRagCorpusRequest", @@ -2099,6 +2107,8 @@ "ListRagCorporaResponse", "ListRagFilesRequest", "ListRagFilesResponse", + "UpdateRagCorpusOperationMetadata", + "UpdateRagCorpusRequest", "UploadRagFileRequest", "UploadRagFileResponse", "RagContexts", diff --git a/google/cloud/aiplatform_v1beta1/types/api_auth.py b/google/cloud/aiplatform_v1beta1/types/api_auth.py index dc2a7ebe9a..2e7eb49e24 100644 --- a/google/cloud/aiplatform_v1beta1/types/api_auth.py +++ b/google/cloud/aiplatform_v1beta1/types/api_auth.py @@ -29,7 +29,16 @@ class ApiAuth(proto.Message): - r"""The generic reusable api auth config.""" + r"""The generic reusable api auth config. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + api_key_config (google.cloud.aiplatform_v1beta1.types.ApiAuth.ApiKeyConfig): + The API secret. + + This field is a member of `oneof`_ ``auth_config``. + """ class ApiKeyConfig(proto.Message): r"""The API secret. @@ -46,5 +55,12 @@ class ApiKeyConfig(proto.Message): number=1, ) + api_key_config: ApiKeyConfig = proto.Field( + proto.MESSAGE, + number=1, + oneof="auth_config", + message=ApiKeyConfig, + ) + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/custom_job.py b/google/cloud/aiplatform_v1beta1/types/custom_job.py index 348f722971..df3118413a 100644 --- a/google/cloud/aiplatform_v1beta1/types/custom_job.py +++ b/google/cloud/aiplatform_v1beta1/types/custom_job.py @@ -558,6 +558,12 @@ class Scheduling(proto.Message): Optional. Indicates if the job should retry for internal errors after the job starts running. If true, overrides ``Scheduling.restart_job_on_worker_restart`` to false. + max_wait_duration (google.protobuf.duration_pb2.Duration): + Optional. This is the maximum duration that a job will wait + for the requested resources to be provisioned if the + scheduling strategy is set to [Strategy.DWS_FLEX_START]. If + set to 0, the job will wait indefinitely. The default is 24 + hours. """ class Strategy(proto.Enum): @@ -606,6 +612,11 @@ class Strategy(proto.Enum): proto.BOOL, number=5, ) + max_wait_duration: duration_pb2.Duration = proto.Field( + proto.MESSAGE, + number=6, + message=duration_pb2.Duration, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/vertex_rag_data.py b/google/cloud/aiplatform_v1beta1/types/vertex_rag_data.py index fb79ed0e90..787dd96a6d 100644 --- a/google/cloud/aiplatform_v1beta1/types/vertex_rag_data.py +++ b/google/cloud/aiplatform_v1beta1/types/vertex_rag_data.py @@ -19,6 +19,7 @@ import proto # type: ignore +from google.cloud.aiplatform_v1beta1.types import api_auth as gca_api_auth from google.cloud.aiplatform_v1beta1.types import io from google.protobuf import timestamp_pb2 # type: ignore @@ -27,6 +28,9 @@ package="google.cloud.aiplatform.v1beta1", manifest={ "RagEmbeddingModelConfig", + "RagVectorDbConfig", + "FileStatus", + "CorpusStatus", "RagCorpus", "RagFile", "RagFileChunkingConfig", @@ -40,6 +44,11 @@ class RagEmbeddingModelConfig(proto.Message): r"""Config for the embedding model to use for RAG. + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields Attributes: @@ -49,6 +58,11 @@ class RagEmbeddingModelConfig(proto.Message): is hosting a 1P fine-tuned text embedding model. Endpoints hosting non-1P fine-tuned text embedding models are currently not supported. + This is used for dense vector search. + + This field is a member of `oneof`_ ``model_config``. + hybrid_search_config (google.cloud.aiplatform_v1beta1.types.RagEmbeddingModelConfig.HybridSearchConfig): + Configuration for hybrid search. This field is a member of `oneof`_ ``model_config``. """ @@ -87,12 +101,278 @@ class VertexPredictionEndpoint(proto.Message): number=3, ) + class SparseEmbeddingConfig(proto.Message): + r"""Configuration for sparse emebdding generation. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + bm25 (google.cloud.aiplatform_v1beta1.types.RagEmbeddingModelConfig.SparseEmbeddingConfig.Bm25): + Use BM25 scoring algorithm. + + This field is a member of `oneof`_ ``model``. + """ + + class Bm25(proto.Message): + r"""Message for BM25 parameters. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + multilingual (bool): + Optional. Use multilingual tokenizer if set + to true. + k1 (float): + Optional. The parameter to control term frequency + saturation. It determines the scaling between the matching + term frequency and final score. k1 is in the range of [1.2, + 3]. The default value is 1.2. + + This field is a member of `oneof`_ ``_k1``. + b (float): + Optional. The parameter to control document length + normalization. It determines how much the document length + affects the final score. b is in the range of [0, 1]. The + default value is 0.75. + + This field is a member of `oneof`_ ``_b``. + """ + + multilingual: bool = proto.Field( + proto.BOOL, + number=1, + ) + k1: float = proto.Field( + proto.FLOAT, + number=2, + optional=True, + ) + b: float = proto.Field( + proto.FLOAT, + number=3, + optional=True, + ) + + bm25: "RagEmbeddingModelConfig.SparseEmbeddingConfig.Bm25" = proto.Field( + proto.MESSAGE, + number=1, + oneof="model", + message="RagEmbeddingModelConfig.SparseEmbeddingConfig.Bm25", + ) + + class HybridSearchConfig(proto.Message): + r"""Config for hybrid search. + + Attributes: + sparse_embedding_config (google.cloud.aiplatform_v1beta1.types.RagEmbeddingModelConfig.SparseEmbeddingConfig): + Optional. The configuration for sparse + embedding generation. This field is optional the + default behavior depends on the vector database + choice on the RagCorpus. + dense_embedding_model_prediction_endpoint (google.cloud.aiplatform_v1beta1.types.RagEmbeddingModelConfig.VertexPredictionEndpoint): + Required. The Vertex AI Prediction Endpoint + that hosts the embedding model for dense + embedding generations. + """ + + sparse_embedding_config: "RagEmbeddingModelConfig.SparseEmbeddingConfig" = ( + proto.Field( + proto.MESSAGE, + number=1, + message="RagEmbeddingModelConfig.SparseEmbeddingConfig", + ) + ) + dense_embedding_model_prediction_endpoint: "RagEmbeddingModelConfig.VertexPredictionEndpoint" = proto.Field( + proto.MESSAGE, + number=2, + message="RagEmbeddingModelConfig.VertexPredictionEndpoint", + ) + vertex_prediction_endpoint: VertexPredictionEndpoint = proto.Field( proto.MESSAGE, number=1, oneof="model_config", message=VertexPredictionEndpoint, ) + hybrid_search_config: HybridSearchConfig = proto.Field( + proto.MESSAGE, + number=2, + oneof="model_config", + message=HybridSearchConfig, + ) + + +class RagVectorDbConfig(proto.Message): + r"""Config for the Vector DB to use for RAG. + + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + rag_managed_db (google.cloud.aiplatform_v1beta1.types.RagVectorDbConfig.RagManagedDb): + The config for the RAG-managed Vector DB. + + This field is a member of `oneof`_ ``vector_db``. + weaviate (google.cloud.aiplatform_v1beta1.types.RagVectorDbConfig.Weaviate): + The config for the Weaviate. + + This field is a member of `oneof`_ ``vector_db``. + vertex_feature_store (google.cloud.aiplatform_v1beta1.types.RagVectorDbConfig.VertexFeatureStore): + The config for the Vertex Feature Store. + + This field is a member of `oneof`_ ``vector_db``. + api_auth (google.cloud.aiplatform_v1beta1.types.ApiAuth): + Authentication config for the chosen Vector + DB. + """ + + class RagManagedDb(proto.Message): + r"""The config for the default RAG-managed Vector DB.""" + + class Weaviate(proto.Message): + r"""The config for the Weaviate. + + Attributes: + http_endpoint (str): + Weaviate DB instance HTTP endpoint. e.g. + 34.56.78.90:8080 Vertex RAG only supports HTTP + connection to Weaviate. This value cannot be + changed after it's set. + collection_name (str): + The corresponding collection this corpus maps + to. This value cannot be changed after it's set. + """ + + http_endpoint: str = proto.Field( + proto.STRING, + number=1, + ) + collection_name: str = proto.Field( + proto.STRING, + number=2, + ) + + class VertexFeatureStore(proto.Message): + r"""The config for the Vertex Feature Store. + + Attributes: + feature_view_resource_name (str): + The resource name of the FeatureView. Format: + ``projects/{project}/locations/{location}/featureOnlineStores/{feature_online_store}/featureViews/{feature_view}`` + """ + + feature_view_resource_name: str = proto.Field( + proto.STRING, + number=1, + ) + + rag_managed_db: RagManagedDb = proto.Field( + proto.MESSAGE, + number=1, + oneof="vector_db", + message=RagManagedDb, + ) + weaviate: Weaviate = proto.Field( + proto.MESSAGE, + number=2, + oneof="vector_db", + message=Weaviate, + ) + vertex_feature_store: VertexFeatureStore = proto.Field( + proto.MESSAGE, + number=4, + oneof="vector_db", + message=VertexFeatureStore, + ) + api_auth: gca_api_auth.ApiAuth = proto.Field( + proto.MESSAGE, + number=5, + message=gca_api_auth.ApiAuth, + ) + + +class FileStatus(proto.Message): + r"""RagFile status. + + Attributes: + state (google.cloud.aiplatform_v1beta1.types.FileStatus.State): + Output only. RagFile state. + error_status (str): + Output only. Only when the ``state`` field is ERROR. + """ + + class State(proto.Enum): + r"""RagFile state. + + Values: + STATE_UNSPECIFIED (0): + RagFile state is unspecified. + ACTIVE (1): + RagFile resource has been created and indexed + successfully. + ERROR (2): + RagFile resource is in a problematic state. See + ``error_message`` field for details. + """ + STATE_UNSPECIFIED = 0 + ACTIVE = 1 + ERROR = 2 + + state: State = proto.Field( + proto.ENUM, + number=1, + enum=State, + ) + error_status: str = proto.Field( + proto.STRING, + number=2, + ) + + +class CorpusStatus(proto.Message): + r"""RagCorpus status. + + Attributes: + state (google.cloud.aiplatform_v1beta1.types.CorpusStatus.State): + Output only. RagCorpus life state. + error_status (str): + Output only. Only when the ``state`` field is ERROR. + """ + + class State(proto.Enum): + r"""RagCorpus life state. + + Values: + UNKNOWN (0): + This state is not supposed to happen. + INITIALIZED (1): + RagCorpus resource entry is initialized, but + hasn't done validation. + ACTIVE (2): + RagCorpus is provisioned successfully and is + ready to serve. + ERROR (3): + RagCorpus is in a problematic situation. See + ``error_message`` field for details. + """ + UNKNOWN = 0 + INITIALIZED = 1 + ACTIVE = 2 + ERROR = 3 + + state: State = proto.Field( + proto.ENUM, + number=1, + enum=State, + ) + error_status: str = proto.Field( + proto.STRING, + number=2, + ) class RagCorpus(proto.Message): @@ -112,12 +392,17 @@ class RagCorpus(proto.Message): rag_embedding_model_config (google.cloud.aiplatform_v1beta1.types.RagEmbeddingModelConfig): Optional. Immutable. The embedding model config of the RagCorpus. + rag_vector_db_config (google.cloud.aiplatform_v1beta1.types.RagVectorDbConfig): + Optional. Immutable. The Vector DB config of + the RagCorpus. create_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Timestamp when this RagCorpus was created. update_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Timestamp when this RagCorpus was last updated. + corpus_status (google.cloud.aiplatform_v1beta1.types.CorpusStatus): + Output only. RagCorpus state. """ name: str = proto.Field( @@ -137,6 +422,11 @@ class RagCorpus(proto.Message): number=6, message="RagEmbeddingModelConfig", ) + rag_vector_db_config: "RagVectorDbConfig" = proto.Field( + proto.MESSAGE, + number=7, + message="RagVectorDbConfig", + ) create_time: timestamp_pb2.Timestamp = proto.Field( proto.MESSAGE, number=4, @@ -147,6 +437,11 @@ class RagCorpus(proto.Message): number=5, message=timestamp_pb2.Timestamp, ) + corpus_status: "CorpusStatus" = proto.Field( + proto.MESSAGE, + number=8, + message="CorpusStatus", + ) class RagFile(proto.Message): @@ -206,6 +501,8 @@ class RagFile(proto.Message): update_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Timestamp when this RagFile was last updated. + file_status (google.cloud.aiplatform_v1beta1.types.FileStatus): + Output only. State of the RagFile. """ class RagFileType(proto.Enum): @@ -284,6 +581,11 @@ class RagFileType(proto.Enum): number=7, message=timestamp_pb2.Timestamp, ) + file_status: "FileStatus" = proto.Field( + proto.MESSAGE, + number=13, + message="FileStatus", + ) class RagFileChunkingConfig(proto.Message): diff --git a/google/cloud/aiplatform_v1beta1/types/vertex_rag_data_service.py b/google/cloud/aiplatform_v1beta1/types/vertex_rag_data_service.py index 25e828db4a..94f0621ba4 100644 --- a/google/cloud/aiplatform_v1beta1/types/vertex_rag_data_service.py +++ b/google/cloud/aiplatform_v1beta1/types/vertex_rag_data_service.py @@ -41,6 +41,8 @@ "ListRagFilesResponse", "DeleteRagFileRequest", "CreateRagCorpusOperationMetadata", + "UpdateRagCorpusRequest", + "UpdateRagCorpusOperationMetadata", "ImportRagFilesOperationMetadata", }, ) @@ -412,6 +414,39 @@ class CreateRagCorpusOperationMetadata(proto.Message): ) +class UpdateRagCorpusRequest(proto.Message): + r"""Request message for + [VertexRagDataService.UpdateRagCorpus][google.cloud.aiplatform.v1beta1.VertexRagDataService.UpdateRagCorpus]. + + Attributes: + rag_corpus (google.cloud.aiplatform_v1beta1.types.RagCorpus): + Required. The RagCorpus which replaces the + resource on the server. + """ + + rag_corpus: vertex_rag_data.RagCorpus = proto.Field( + proto.MESSAGE, + number=1, + message=vertex_rag_data.RagCorpus, + ) + + +class UpdateRagCorpusOperationMetadata(proto.Message): + r"""Runtime operation information for + [VertexRagDataService.UpdateRagCorpus][google.cloud.aiplatform.v1beta1.VertexRagDataService.UpdateRagCorpus]. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): + The operation generic information. + """ + + generic_metadata: operation.GenericOperationMetadata = proto.Field( + proto.MESSAGE, + number=1, + message=operation.GenericOperationMetadata, + ) + + class ImportRagFilesOperationMetadata(proto.Message): r"""Runtime operation information for [VertexRagDataService.ImportRagFiles][google.cloud.aiplatform.v1beta1.VertexRagDataService.ImportRagFiles]. diff --git a/google/cloud/aiplatform_v1beta1/types/vertex_rag_service.py b/google/cloud/aiplatform_v1beta1/types/vertex_rag_service.py index e33f30e95c..2cbb95c76c 100644 --- a/google/cloud/aiplatform_v1beta1/types/vertex_rag_service.py +++ b/google/cloud/aiplatform_v1beta1/types/vertex_rag_service.py @@ -44,8 +44,33 @@ class RagQuery(proto.Message): This field is a member of `oneof`_ ``query``. similarity_top_k (int): Optional. The number of contexts to retrieve. + ranking (google.cloud.aiplatform_v1beta1.types.RagQuery.Ranking): + Optional. Configurations for hybrid search + results ranking. """ + class Ranking(proto.Message): + r"""Configurations for hybrid search results ranking. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + alpha (float): + Optional. Alpha value controls the weight between dense and + sparse vector search results. The range is [0, 1], while 0 + means sparse vector search only and 1 means dense vector + search only. The default value is 0.5 which balances sparse + and dense vector search equally. + + This field is a member of `oneof`_ ``_alpha``. + """ + + alpha: float = proto.Field( + proto.FLOAT, + number=1, + optional=True, + ) + text: str = proto.Field( proto.STRING, number=1, @@ -55,6 +80,11 @@ class RagQuery(proto.Message): proto.INT32, number=2, ) + ranking: Ranking = proto.Field( + proto.MESSAGE, + number=4, + message=Ranking, + ) class RetrieveContextsRequest(proto.Message): @@ -175,8 +205,11 @@ class Context(proto.Message): text (str): The text chunk. distance (float): - The distance between the query vector and the - context text vector. + The distance between the query dense + embedding vector and the context text vector. + sparse_distance (float): + The distance between the query sparse + embedding vector and the context text vector. """ source_uri: str = proto.Field( @@ -191,6 +224,10 @@ class Context(proto.Message): proto.DOUBLE, number=3, ) + sparse_distance: float = proto.Field( + proto.DOUBLE, + number=4, + ) contexts: MutableSequence[Context] = proto.RepeatedField( proto.MESSAGE, diff --git a/owlbot.py b/owlbot.py index ea1a2b9795..c8a1afe6ec 100644 --- a/owlbot.py +++ b/owlbot.py @@ -112,6 +112,7 @@ ".kokoro/continuous/prerelease-deps.cfg", ".kokoro/presubmit/prerelease-deps.cfg", ".kokoro/docs/docs-presubmit.cfg", + ".kokoro/build.sh", ".kokoro/release.sh", ".kokoro/release/common.cfg", ".kokoro/requirements*", diff --git a/pypi/_vertex_ai_placeholder/version.py b/pypi/_vertex_ai_placeholder/version.py index dc62baac78..7979098357 100644 --- a/pypi/_vertex_ai_placeholder/version.py +++ b/pypi/_vertex_ai_placeholder/version.py @@ -15,4 +15,4 @@ # limitations under the License. # -__version__ = "1.64.0" +__version__ = "1.65.0" diff --git a/samples/generated_samples/aiplatform_v1beta1_generated_vertex_rag_data_service_update_rag_corpus_async.py b/samples/generated_samples/aiplatform_v1beta1_generated_vertex_rag_data_service_update_rag_corpus_async.py new file mode 100644 index 0000000000..55bde0360a --- /dev/null +++ b/samples/generated_samples/aiplatform_v1beta1_generated_vertex_rag_data_service_update_rag_corpus_async.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for UpdateRagCorpus +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1beta1_generated_VertexRagDataService_UpdateRagCorpus_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1beta1 + + +async def sample_update_rag_corpus(): + # Create a client + client = aiplatform_v1beta1.VertexRagDataServiceAsyncClient() + + # Initialize request argument(s) + rag_corpus = aiplatform_v1beta1.RagCorpus() + rag_corpus.display_name = "display_name_value" + + request = aiplatform_v1beta1.UpdateRagCorpusRequest( + rag_corpus=rag_corpus, + ) + + # Make the request + operation = client.update_rag_corpus(request=request) + + print("Waiting for operation to complete...") + + response = (await operation).result() + + # Handle the response + print(response) + +# [END aiplatform_v1beta1_generated_VertexRagDataService_UpdateRagCorpus_async] diff --git a/samples/generated_samples/aiplatform_v1beta1_generated_vertex_rag_data_service_update_rag_corpus_sync.py b/samples/generated_samples/aiplatform_v1beta1_generated_vertex_rag_data_service_update_rag_corpus_sync.py new file mode 100644 index 0000000000..81f3e7fb57 --- /dev/null +++ b/samples/generated_samples/aiplatform_v1beta1_generated_vertex_rag_data_service_update_rag_corpus_sync.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for UpdateRagCorpus +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1beta1_generated_VertexRagDataService_UpdateRagCorpus_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1beta1 + + +def sample_update_rag_corpus(): + # Create a client + client = aiplatform_v1beta1.VertexRagDataServiceClient() + + # Initialize request argument(s) + rag_corpus = aiplatform_v1beta1.RagCorpus() + rag_corpus.display_name = "display_name_value" + + request = aiplatform_v1beta1.UpdateRagCorpusRequest( + rag_corpus=rag_corpus, + ) + + # Make the request + operation = client.update_rag_corpus(request=request) + + print("Waiting for operation to complete...") + + response = operation.result() + + # Handle the response + print(response) + +# [END aiplatform_v1beta1_generated_VertexRagDataService_UpdateRagCorpus_sync] diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json index c4b8e9f3fc..e824509826 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-cloud-aiplatform", - "version": "1.64.0" + "version": "1.65.0" }, "snippets": [ { diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json index 86cee3f831..c2a656078a 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-cloud-aiplatform", - "version": "1.64.0" + "version": "1.65.0" }, "snippets": [ { @@ -52847,6 +52847,167 @@ ], "title": "aiplatform_v1beta1_generated_vertex_rag_data_service_list_rag_files_sync.py" }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.cloud.aiplatform_v1beta1.VertexRagDataServiceAsyncClient", + "shortName": "VertexRagDataServiceAsyncClient" + }, + "fullName": "google.cloud.aiplatform_v1beta1.VertexRagDataServiceAsyncClient.update_rag_corpus", + "method": { + "fullName": "google.cloud.aiplatform.v1beta1.VertexRagDataService.UpdateRagCorpus", + "service": { + "fullName": "google.cloud.aiplatform.v1beta1.VertexRagDataService", + "shortName": "VertexRagDataService" + }, + "shortName": "UpdateRagCorpus" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1beta1.types.UpdateRagCorpusRequest" + }, + { + "name": "rag_corpus", + "type": "google.cloud.aiplatform_v1beta1.types.RagCorpus" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.api_core.operation_async.AsyncOperation", + "shortName": "update_rag_corpus" + }, + "description": "Sample for UpdateRagCorpus", + "file": "aiplatform_v1beta1_generated_vertex_rag_data_service_update_rag_corpus_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1beta1_generated_VertexRagDataService_UpdateRagCorpus_async", + "segments": [ + { + "end": 58, + "start": 27, + "type": "FULL" + }, + { + "end": 58, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 48, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 55, + "start": 49, + "type": "REQUEST_EXECUTION" + }, + { + "end": 59, + "start": 56, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1beta1_generated_vertex_rag_data_service_update_rag_corpus_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.cloud.aiplatform_v1beta1.VertexRagDataServiceClient", + "shortName": "VertexRagDataServiceClient" + }, + "fullName": "google.cloud.aiplatform_v1beta1.VertexRagDataServiceClient.update_rag_corpus", + "method": { + "fullName": "google.cloud.aiplatform.v1beta1.VertexRagDataService.UpdateRagCorpus", + "service": { + "fullName": "google.cloud.aiplatform.v1beta1.VertexRagDataService", + "shortName": "VertexRagDataService" + }, + "shortName": "UpdateRagCorpus" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1beta1.types.UpdateRagCorpusRequest" + }, + { + "name": "rag_corpus", + "type": "google.cloud.aiplatform_v1beta1.types.RagCorpus" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.api_core.operation.Operation", + "shortName": "update_rag_corpus" + }, + "description": "Sample for UpdateRagCorpus", + "file": "aiplatform_v1beta1_generated_vertex_rag_data_service_update_rag_corpus_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1beta1_generated_VertexRagDataService_UpdateRagCorpus_sync", + "segments": [ + { + "end": 58, + "start": 27, + "type": "FULL" + }, + { + "end": 58, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 48, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 55, + "start": 49, + "type": "REQUEST_EXECUTION" + }, + { + "end": 59, + "start": 56, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1beta1_generated_vertex_rag_data_service_update_rag_corpus_sync.py" + }, { "canonical": true, "clientMethod": { diff --git a/setup.py b/setup.py index 0786ec7215..95b1590948 100644 --- a/setup.py +++ b/setup.py @@ -98,10 +98,7 @@ autologging_extra_require = ["mlflow>=1.27.0,<=2.1.1"] -preview_extra_require = [ - "cloudpickle < 3.0", - "google-cloud-logging < 4.0", -] +preview_extra_require = [] ray_extra_require = [ # Cluster only supports 2.9.3. Keep 2.4.0 for our testing environment. @@ -152,7 +149,6 @@ evaluation_extra_require = [ "pandas >= 1.0.0, < 2.2.0", "tqdm>=4.23.0", - "immutabledict", ] langchain_extra_require = [ @@ -205,7 +201,6 @@ "grpcio-testing", "ipython", "kfp >= 2.6.0, < 3.0.0", - "pyfakefs", "pytest-asyncio", "pytest-xdist", "scikit-learn", diff --git a/tests/system/aiplatform/test_vision_models.py b/tests/system/aiplatform/test_vision_models.py index 33e5720d37..16929846f9 100644 --- a/tests/system/aiplatform/test_vision_models.py +++ b/tests/system/aiplatform/test_vision_models.py @@ -375,3 +375,63 @@ def test_image_verification_model_verify_image(self): image_with_watermark ) assert image_verification_response["decision"] == "ACCEPT" + + def test_image_upscaling_model_upscale_image(self): + """Tests the image upscaling model upscaling an image.""" + arbitrary_width = 1111 + arbitrary_height = 2000 + target_max_size = 4096 + image_generation_model = vision_models.ImageGenerationModel.from_pretrained( + "imagen-3.0-generate-001" + ) + blank_image = _create_blank_image( + width=arbitrary_width, height=arbitrary_height + ) + assert blank_image.size == (arbitrary_width, arbitrary_height) + upscaled_image = image_generation_model.upscale_image( + blank_image, new_size=target_max_size + ) + new_size = ( + int(arbitrary_width / arbitrary_height * target_max_size), + target_max_size, + ) + assert upscaled_image._size == new_size + + def test_image_upscaling_model_upscaling_factor(self): + """Tests the image upscaling model upscaling an image.""" + arbitrary_width = 1111 + arbitrary_height = 2222 + target_upscaling_factor = "x2" + image_generation_model = vision_models.ImageGenerationModel.from_pretrained( + "imagen-3.0-generate-001" + ) + blank_image = _create_blank_image( + width=arbitrary_width, height=arbitrary_height + ) + assert blank_image.size == (arbitrary_width, arbitrary_height) + upscaled_image = image_generation_model.upscale_image( + blank_image, upscale_factor=target_upscaling_factor + ) + new_size = (arbitrary_width * 2, arbitrary_height * 2) + assert upscaled_image._size == new_size + + def test_image_upscaling_model_save_as_jpeg(self): + """Tests the image upscaling model upscaling an image.""" + arbitrary_width = 1111 + arbitrary_height = 2222 + target_upscaling_factor = "x2" + image_generation_model = vision_models.ImageGenerationModel.from_pretrained( + "imagen-3.0-generate-001" + ) + blank_image = _create_blank_image( + width=arbitrary_width, height=arbitrary_height + ) + assert blank_image.size == (arbitrary_width, arbitrary_height) + upscaled_image = image_generation_model.upscale_image( + blank_image, + upscale_factor=target_upscaling_factor, + output_mime_type="image/jpeg", + output_compression_quality=90, + ) + assert upscaled_image._mime_type == "image/jpeg" + assert upscaled_image._size == (arbitrary_width * 2, arbitrary_height * 2) diff --git a/tests/system/vertexai/test_generative_models.py b/tests/system/vertexai/test_generative_models.py index 39fd0a4b99..e972225fe4 100644 --- a/tests/system/vertexai/test_generative_models.py +++ b/tests/system/vertexai/test_generative_models.py @@ -505,3 +505,36 @@ def test_compute_tokens_from_text(self, api_endpoint_env_name): assert token_info.role # Lightly validate that the tokens are not Base64 encoded assert b"=" not in token_info.tokens + + def test_count_tokens_from_text(self): + plain_model = generative_models.GenerativeModel(GEMINI_MODEL_NAME) + model = generative_models.GenerativeModel( + GEMINI_MODEL_NAME, system_instruction=["You are a chatbot."] + ) + get_current_weather_func = generative_models.FunctionDeclaration.from_func( + get_current_weather + ) + weather_tool = generative_models.Tool( + function_declarations=[get_current_weather_func], + ) + content = ["Why is sky blue?", "Explain it like I'm 5."] + + response_without_si = plain_model.count_tokens(content) + response_with_si = model.count_tokens(content) + response_with_si_and_tool = model.count_tokens( + content, + tools=[weather_tool], + ) + + # system instruction + user prompt + assert response_with_si.total_tokens > response_without_si.total_tokens + assert ( + response_with_si.total_billable_characters + > response_without_si.total_billable_characters + ) + # system instruction + user prompt + tool + assert response_with_si_and_tool.total_tokens > response_with_si.total_tokens + assert ( + response_with_si_and_tool.total_billable_characters + > response_with_si.total_billable_characters + ) diff --git a/tests/system/vertexai/test_tokenization.py b/tests/system/vertexai/test_tokenization.py index 8234aeff76..2795831a95 100644 --- a/tests/system/vertexai/test_tokenization.py +++ b/tests/system/vertexai/test_tokenization.py @@ -20,11 +20,24 @@ from nltk.corpus import udhr from google.cloud import aiplatform from vertexai.preview.tokenization import ( - get_tokenizer_for_model, + get_tokenizer_for_model as tokenizer_preview, +) +from vertexai.tokenization._tokenizers import ( + get_tokenizer_for_model as tokenizer_ga, +) +from vertexai.generative_models import ( + GenerativeModel, + Part, + Tool, ) -from vertexai.generative_models import GenerativeModel from tests.system.aiplatform import e2e_base from google import auth +from google.cloud.aiplatform_v1beta1.types import ( + content as gapic_content_types, + tool as gapic_tool_types, + openapi, +) +from google.protobuf import struct_pb2 _MODELS = ["gemini-1.0-pro", "gemini-1.5-pro", "gemini-1.5-flash"] @@ -34,11 +47,58 @@ _CORPUS_LIB = [ udhr, ] +_VERSIONED_TOKENIZER = [tokenizer_preview, tokenizer_ga] _MODEL_CORPUS_PARAMS = [ - (model_name, corpus_name, corpus_lib) + (get_tokenizer_for_model, model_name, corpus_name, corpus_lib) + for get_tokenizer_for_model in _VERSIONED_TOKENIZER for model_name in _MODELS for (corpus_name, corpus_lib) in zip(_CORPUS, _CORPUS_LIB) ] +_STRUCT = struct_pb2.Struct( + fields={ + "string_key": struct_pb2.Value(string_value="value"), + } +) +_FUNCTION_CALL = gapic_tool_types.FunctionCall(name="test_function_call", args=_STRUCT) +_FUNCTION_RESPONSE = gapic_tool_types.FunctionResponse( + name="function_response", + response=_STRUCT, +) + + +_SCHEMA_1 = openapi.Schema(format="schema1_format", description="schema1_description") +_SCHEMA_2 = openapi.Schema(format="schema2_format", description="schema2_description") +_EXAMPLE = struct_pb2.Value(string_value="value1") + +_FUNCTION_DECLARATION_1 = gapic_tool_types.FunctionDeclaration( + name="function_declaration_name", + description="function_declaration_description", + parameters=openapi.Schema( + format="schema_format", + description="schema_description", + enum=["schema_enum1", "schema_enum2"], + required=["schema_required1", "schema_required2"], + items=_SCHEMA_2, + properties={"property_key": _SCHEMA_1}, + example=_EXAMPLE, + ), +) +_FUNCTION_DECLARATION_2 = gapic_tool_types.FunctionDeclaration( + parameters=openapi.Schema( + nullable=True, + default=struct_pb2.Value(string_value="value1"), + min_items=0, + max_items=0, + min_properties=0, + max_properties=0, + minimum=0, + maximum=0, + min_length=0, + max_length=0, + pattern="pattern", + ), + response=_SCHEMA_1, +) STAGING_API_ENDPOINT = "STAGING_ENDPOINT" PROD_API_ENDPOINT = "PROD_ENDPOINT" @@ -70,11 +130,16 @@ def setup_method(self, api_endpoint_env_name): ) @pytest.mark.parametrize( - "model_name, corpus_name, corpus_lib", + "get_tokenizer_for_model, model_name, corpus_name, corpus_lib", _MODEL_CORPUS_PARAMS, ) def test_count_tokens_local( - self, model_name, corpus_name, corpus_lib, api_endpoint_env_name + self, + get_tokenizer_for_model, + model_name, + corpus_name, + corpus_lib, + api_endpoint_env_name, ): # The Gemini 1.5 flash model requires the model version # number suffix (001) in staging only @@ -90,11 +155,16 @@ def test_count_tokens_local( assert service_result.total_tokens == local_result.total_tokens @pytest.mark.parametrize( - "model_name, corpus_name, corpus_lib", + "get_tokenizer_for_model, model_name, corpus_name, corpus_lib", _MODEL_CORPUS_PARAMS, ) def test_compute_tokens( - self, model_name, corpus_name, corpus_lib, api_endpoint_env_name + self, + get_tokenizer_for_model, + model_name, + corpus_name, + corpus_lib, + api_endpoint_env_name, ): # The Gemini 1.5 flash model requires the model version # number suffix (001) in staging only @@ -107,8 +177,107 @@ def test_compute_tokens( text = corpus_lib.raw(book) response = model.compute_tokens(text) local_result = tokenizer.compute_tokens(text) - for local, service in zip( - local_result.token_info_list, response.tokens_info - ): + for local, service in zip(local_result.tokens_info, response.tokens_info): assert local.tokens == service.tokens assert local.token_ids == service.token_ids + + @pytest.mark.parametrize( + "model_name", + _MODELS, + ) + def test_count_tokens_system_instruction(self, model_name): + tokenizer = tokenizer_preview(model_name) + model = GenerativeModel(model_name, system_instruction=["You are a chatbot."]) + + assert ( + tokenizer.count_tokens( + "hello", system_instruction=["You are a chatbot."] + ).total_tokens + == model.count_tokens("hello").total_tokens + ) + + @pytest.mark.parametrize( + "model_name", + _MODELS, + ) + def test_count_tokens_system_instruction_is_function_call(self, model_name): + part = Part._from_gapic(gapic_content_types.Part(function_call=_FUNCTION_CALL)) + + tokenizer = tokenizer_preview(model_name) + model = GenerativeModel(model_name, system_instruction=[part]) + + assert ( + tokenizer.count_tokens("hello", system_instruction=[part]).total_tokens + == model.count_tokens("hello").total_tokens + ) + + @pytest.mark.parametrize( + "model_name", + _MODELS, + ) + def test_count_tokens_system_instruction_is_function_response(self, model_name): + part = Part._from_gapic( + gapic_content_types.Part(function_response=_FUNCTION_RESPONSE) + ) + tokenizer = tokenizer_preview(model_name) + model = GenerativeModel(model_name, system_instruction=[part]) + + assert tokenizer.count_tokens(part, system_instruction=[part]).total_tokens + assert ( + tokenizer.count_tokens("hello", system_instruction=[part]).total_tokens + == model.count_tokens("hello").total_tokens + ) + + @pytest.mark.parametrize( + "model_name", + _MODELS, + ) + def test_count_tokens_tool_is_function_declaration(self, model_name): + tokenizer = tokenizer_preview(model_name) + model = GenerativeModel(model_name) + tool1 = Tool._from_gapic( + gapic_tool_types.Tool(function_declarations=[_FUNCTION_DECLARATION_1]) + ) + tool2 = Tool._from_gapic( + gapic_tool_types.Tool(function_declarations=[_FUNCTION_DECLARATION_2]) + ) + + assert tokenizer.count_tokens("hello", tools=[tool1]).total_tokens + with pytest.raises(ValueError): + tokenizer.count_tokens("hello", tools=[tool2]).total_tokens + assert ( + tokenizer.count_tokens("hello", tools=[tool1]).total_tokens + == model.count_tokens("hello", tools=[tool1]).total_tokens + ) + + @pytest.mark.parametrize( + "model_name", + _MODELS, + ) + def test_count_tokens_content_is_function_call(self, model_name): + part = Part._from_gapic(gapic_content_types.Part(function_call=_FUNCTION_CALL)) + tokenizer = tokenizer_preview(model_name) + model = GenerativeModel(model_name) + + assert tokenizer.count_tokens(part).total_tokens + assert ( + tokenizer.count_tokens(part).total_tokens + == model.count_tokens(part).total_tokens + ) + + @pytest.mark.parametrize( + "model_name", + _MODELS, + ) + def test_count_tokens_content_is_function_response(self, model_name): + part = Part._from_gapic( + gapic_content_types.Part(function_response=_FUNCTION_RESPONSE) + ) + tokenizer = tokenizer_preview(model_name) + model = GenerativeModel(model_name) + + assert tokenizer.count_tokens(part).total_tokens + assert ( + tokenizer.count_tokens(part).total_tokens + == model.count_tokens(part).total_tokens + ) diff --git a/tests/unit/aiplatform/test_endpoints.py b/tests/unit/aiplatform/test_endpoints.py index a13de3158d..25553d8ad7 100644 --- a/tests/unit/aiplatform/test_endpoints.py +++ b/tests/unit/aiplatform/test_endpoints.py @@ -18,8 +18,8 @@ import copy from datetime import datetime, timedelta from importlib import reload -import requests import json +import requests from unittest import mock from google.api_core import operation as ga_operation @@ -920,6 +920,49 @@ def predict_private_endpoint_mock(): yield predict_mock +@pytest.fixture +def stream_raw_predict_private_endpoint_mock(): + with mock.patch.object( + google_auth_requests.AuthorizedSession, "post" + ) as stream_raw_predict_mock: + # Create a mock response object + mock_response = mock.Mock(spec=requests.Response) + + # Configure the mock to be used as a context manager + stream_raw_predict_mock.return_value.__enter__.return_value = mock_response + + # Set the status code to 200 (OK) + mock_response.status_code = 200 + + # Simulate streaming data with iter_lines + mock_response.iter_lines = mock.Mock( + return_value=iter( + [ + json.dumps( + { + "predictions": [1.0, 2.0, 3.0], + "metadata": {"key": "value"}, + "deployedModelId": "model-id-123", + "model": "model-name", + "modelVersionId": "1", + } + ).encode("utf-8"), + json.dumps( + { + "predictions": [4.0, 5.0, 6.0], + "metadata": {"key": "value"}, + "deployedModelId": "model-id-123", + "model": "model-name", + "modelVersionId": "1", + } + ).encode("utf-8"), + ] + ) + ) + + yield stream_raw_predict_mock + + @pytest.fixture def health_check_private_endpoint_mock(): with mock.patch.object(urllib3.PoolManager, "request") as health_check_mock: @@ -3195,6 +3238,57 @@ def test_psc_predict(self, predict_private_endpoint_mock): }, ) + @pytest.mark.usefixtures("get_psc_private_endpoint_mock") + def test_psc_stream_raw_predict(self, stream_raw_predict_private_endpoint_mock): + test_endpoint = models.PrivateEndpoint( + project=_TEST_PROJECT, location=_TEST_LOCATION, endpoint_name=_TEST_ID + ) + + test_prediction_iterator = test_endpoint.stream_raw_predict( + body='{"instances": [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]]}', + headers={ + "Content-Type": "application/json", + "Authorization": "Bearer None", + }, + endpoint_override=_TEST_ENDPOINT_OVERRIDE, + ) + + test_prediction = list(test_prediction_iterator) + + stream_raw_predict_private_endpoint_mock.assert_called_once_with( + url=f"https://{_TEST_ENDPOINT_OVERRIDE}/v1/projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}:streamRawPredict", + data='{"instances": [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]]}', + headers={ + "Content-Type": "application/json", + "Authorization": "Bearer None", + }, + stream=True, + verify=False, + ) + + # Validate the content of the returned predictions + expected_predictions = [ + json.dumps( + { + "predictions": [1.0, 2.0, 3.0], + "metadata": {"key": "value"}, + "deployedModelId": "model-id-123", + "model": "model-name", + "modelVersionId": "1", + } + ).encode("utf-8"), + json.dumps( + { + "predictions": [4.0, 5.0, 6.0], + "metadata": {"key": "value"}, + "deployedModelId": "model-id-123", + "model": "model-name", + "modelVersionId": "1", + } + ).encode("utf-8"), + ] + assert test_prediction == expected_predictions + @pytest.mark.usefixtures("get_psc_private_endpoint_mock") def test_psc_predict_without_endpoint_override(self): test_endpoint = models.PrivateEndpoint( diff --git a/tests/unit/aiplatform/test_logdir_loader.py b/tests/unit/aiplatform/test_logdir_loader.py index 90c87464a2..648f39022d 100644 --- a/tests/unit/aiplatform/test_logdir_loader.py +++ b/tests/unit/aiplatform/test_logdir_loader.py @@ -201,15 +201,15 @@ def test_multiple_writes_to_logdir(self): { "a": ["tag_a"], "b": ["tag_b"], - "b-x": ["tag_b_x"], - "b-z": ["tag_b_z"], + "b/x": ["tag_b_x"], + "b_z": ["tag_b_z"], "c": ["tag_c"], }, ) # A second load should indicate no new data. self.assertEqual( self._extract_run_to_tags(loader.get_run_events()), - {"a": [], "b": [], "b-x": [], "b-z": [], "c": []}, + {"a": [], "b": [], "b/x": [], "b_z": [], "c": []}, ) # Write some new data to both new and pre-existing event files. with FileWriter(os.path.join(logdir, "a"), filename_suffix=".other") as writer: @@ -228,8 +228,8 @@ def test_multiple_writes_to_logdir(self): { "a": ["tag_a_2", "tag_a_3", "tag_a_4"], "b": [], - "b-x": ["tag_b_x_2"], - "b-z": [], + "b/x": ["tag_b_x_2"], + "b_z": [], "c": ["tag_c_2"], }, ) diff --git a/tests/unit/aiplatform/test_vision_models.py b/tests/unit/aiplatform/test_vision_models.py index 0323d0f5bb..c2fcf578c0 100644 --- a/tests/unit/aiplatform/test_vision_models.py +++ b/tests/unit/aiplatform/test_vision_models.py @@ -22,7 +22,7 @@ import io import os import tempfile -from typing import Any, Dict +from typing import Any, Dict, Literal, Optional, Tuple import unittest from unittest import mock @@ -39,12 +39,10 @@ from google.cloud.aiplatform.compat.types import ( publisher_model as gca_publisher_model, ) - from vertexai import vision_models as ga_vision_models from vertexai.preview import ( vision_models as preview_vision_models, ) - from PIL import Image as PIL_Image import pytest @@ -131,20 +129,39 @@ def make_image_generation_response_gcs(count: int = 1) -> Dict[str, Any]: for _ in range(count): predictions.append( { - "gcsUri": ( - "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png" - ), + "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png", "mimeType": "image/png", } ) return {"predictions": predictions} -def make_image_upscale_response(upscale_size: int) -> Dict[str, Any]: - predictions = { - "bytesBase64Encoded": make_image_base64(upscale_size, upscale_size), - "mimeType": "image/png", - } +def make_image_upscale_response( + upscale_size: Optional[int] = None, + upscale_factor: Optional[Literal["x2", "x4"]] = None, + input_size: Optional[Tuple[int, int]] = None, +) -> Dict[str, Any]: + predictions = {} + if upscale_size: + predictions = { + "bytesBase64Encoded": make_image_base64(upscale_size, upscale_size), + "mimeType": "image/png", + } + else: + if upscale_factor == "x2": + predictions = { + "bytesBase64Encoded": make_image_base64( + input_size[0] * 2, input_size[1] * 2 + ), + "mimeType": "image/png", + } + else: + predictions = { + "bytesBase64Encoded": make_image_base64( + input_size[0] * 4, input_size[1] * 4 + ), + "mimeType": "image/png", + } return {"predictions": [predictions]} @@ -723,7 +740,7 @@ def test_upscale_image_on_provided_image(self): model = self._get_image_generation_model() image_generation_response = make_image_generation_response( - count=1, height=1024, width=1024 + count=1, height=1111, width=2000 ) gca_generation_response = gca_prediction_service.PredictResponse() gca_generation_response.predictions.extend( @@ -739,7 +756,7 @@ def test_upscale_image_on_provided_image(self): attribute="predict", return_value=gca_upscale_response, ) as mock_upscale: - test_image = generate_image_from_file(height=1024, width=1024) + test_image = generate_image_from_file(height=1111, width=2000) upscaled_image = model.upscale_image(image=test_image, new_size=4096) @@ -748,24 +765,107 @@ def test_upscale_image_on_provided_image(self): assert actual_instance["image"]["bytesBase64Encoded"] image_upscale_parameters = predict_kwargs["parameters"] - assert ( - image_upscale_parameters["sampleImageSize"] - == str(upscaled_image._size[0]) - == str(upscaled_image.generation_parameters["upscaled_image_size"]) + assert image_upscale_parameters["sampleImageSize"] == str( + upscaled_image.generation_parameters["upscaled_image_size"] ) assert image_upscale_parameters["mode"] == "upscale" assert upscaled_image._image_bytes assert isinstance(upscaled_image, preview_vision_models.GeneratedImage) - def test_upscale_image_raises_if_not_1024x1024(self): + def test_upscale_image_on_wrong_image_size(self): + """Tests image upscaling on generated images.""" + model = self._get_image_generation_model() + image_generation_response = make_image_generation_response( + count=1, height=1111, width=2222 + ) + gca_generation_response = gca_prediction_service.PredictResponse() + gca_generation_response.predictions.extend( + image_generation_response["predictions"] + ) + + image_upscale_response = make_image_upscale_response(upscale_size=4096) + gca_upscale_response = gca_prediction_service.PredictResponse() + gca_upscale_response.predictions.extend(image_upscale_response["predictions"]) + + with mock.patch.object( + target=prediction_service_client.PredictionServiceClient, + attribute="predict", + return_value=gca_upscale_response, + ): + test_image = generate_image_from_file(height=1111, width=2222) + with pytest.raises(ValueError): + model.upscale_image(image=test_image, new_size=4096) + + def test_upscale_image_with_upscale_factor(self): + """Tests image upscaling on generated images.""" + model = self._get_image_generation_model() + image_generation_response = make_image_generation_response( + count=1, height=1111, width=2222 + ) + gca_generation_response = gca_prediction_service.PredictResponse() + gca_generation_response.predictions.extend( + image_generation_response["predictions"] + ) + + image_upscale_response = make_image_upscale_response( + upscale_factor="x2", input_size=(1111, 2222) + ) + gca_upscale_response = gca_prediction_service.PredictResponse() + gca_upscale_response.predictions.extend(image_upscale_response["predictions"]) + with mock.patch.object( + target=prediction_service_client.PredictionServiceClient, + attribute="predict", + return_value=gca_upscale_response, + ) as mock_upscale: + test_image = generate_image_from_file(height=1111, width=2222) + + upscaled_image = model.upscale_image(image=test_image, upscale_factor="x2") + + predict_kwargs = mock_upscale.call_args[1] + actual_instance = predict_kwargs["instances"][0] + assert actual_instance["image"]["bytesBase64Encoded"] + + image_upscale_parameters = predict_kwargs["parameters"] + assert image_upscale_parameters["mode"] == "upscale" + assert "sampleImageSize" not in image_upscale_parameters + + assert upscaled_image._image_bytes + assert isinstance(upscaled_image, preview_vision_models.GeneratedImage) + + def test_upscale_image_with_jpeg_output(self): """Tests image upscaling on generated images.""" model = self._get_image_generation_model() + image_generation_response = make_image_generation_response( + count=1, height=1111, width=2222 + ) + gca_generation_response = gca_prediction_service.PredictResponse() + gca_generation_response.predictions.extend( + image_generation_response["predictions"] + ) + with mock.patch.object( + target=prediction_service_client.PredictionServiceClient, + attribute="predict", + return_value=gca_generation_response, + ) as mock_upscale: + test_image = generate_image_from_file(height=1111, width=2222) + upscaled_image = model.upscale_image( + image=test_image, + upscale_factor="x2", + output_mime_type="image/jpeg", + output_compression_quality=90, + ) - test_image = generate_image_from_file(height=100, width=100) + predict_kwargs = mock_upscale.call_args[1] + actual_instance = predict_kwargs["instances"][0] + assert actual_instance["image"]["bytesBase64Encoded"] - with pytest.raises(ValueError): - model.upscale_image(image=test_image) + image_upscale_parameters = predict_kwargs["parameters"] + assert image_upscale_parameters["mode"] == "upscale" + assert image_upscale_parameters["outputOptions"]["mimeType"] == "image/jpeg" + assert image_upscale_parameters["outputOptions"]["compressionQuality"] == 90 + assert upscaled_image._image_bytes + assert isinstance(upscaled_image, preview_vision_models.GeneratedImage) @pytest.mark.usefixtures("google_auth_mock") diff --git a/tests/unit/gapic/aiplatform_v1/test_job_service.py b/tests/unit/gapic/aiplatform_v1/test_job_service.py index 3d77876553..2deb4c5275 100644 --- a/tests/unit/gapic/aiplatform_v1/test_job_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_job_service.py @@ -16816,6 +16816,7 @@ def test_create_custom_job_rest(request_type): "restart_job_on_worker_restart": True, "strategy": 1, "disable_retries": True, + "max_wait_duration": {}, }, "service_account": "service_account_value", "network": "network_value", @@ -20425,6 +20426,7 @@ def test_create_hyperparameter_tuning_job_rest(request_type): "restart_job_on_worker_restart": True, "strategy": 1, "disable_retries": True, + "max_wait_duration": {}, }, "service_account": "service_account_value", "network": "network_value", @@ -22322,6 +22324,7 @@ def test_create_nas_job_rest(request_type): "restart_job_on_worker_restart": True, "strategy": 1, "disable_retries": True, + "max_wait_duration": {}, }, "service_account": "service_account_value", "network": "network_value", diff --git a/tests/unit/gapic/aiplatform_v1/test_schedule_service.py b/tests/unit/gapic/aiplatform_v1/test_schedule_service.py index e0e59c7aca..439956e90c 100644 --- a/tests/unit/gapic/aiplatform_v1/test_schedule_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_schedule_service.py @@ -60,6 +60,9 @@ from google.cloud.aiplatform_v1.types import context from google.cloud.aiplatform_v1.types import encryption_spec from google.cloud.aiplatform_v1.types import execution +from google.cloud.aiplatform_v1.types import job_state +from google.cloud.aiplatform_v1.types import notebook_execution_job +from google.cloud.aiplatform_v1.types import notebook_service from google.cloud.aiplatform_v1.types import operation as gca_operation from google.cloud.aiplatform_v1.types import pipeline_failure_policy from google.cloud.aiplatform_v1.types import pipeline_job @@ -76,6 +79,7 @@ from google.longrunning import operations_pb2 # type: ignore from google.oauth2 import service_account from google.protobuf import any_pb2 # type: ignore +from google.protobuf import duration_pb2 # type: ignore from google.protobuf import empty_pb2 # type: ignore from google.protobuf import field_mask_pb2 # type: ignore from google.protobuf import struct_pb2 # type: ignore @@ -4158,6 +4162,35 @@ def test_create_schedule_rest(request_type): }, "pipeline_job_id": "pipeline_job_id_value", }, + "create_notebook_execution_job_request": { + "parent": "parent_value", + "notebook_execution_job": { + "dataform_repository_source": { + "dataform_repository_resource_name": "dataform_repository_resource_name_value", + "commit_sha": "commit_sha_value", + }, + "gcs_notebook_source": { + "uri": "uri_value", + "generation": "generation_value", + }, + "direct_notebook_source": {"content": b"content_blob"}, + "notebook_runtime_template_resource_name": "notebook_runtime_template_resource_name_value", + "gcs_output_uri": "gcs_output_uri_value", + "execution_user": "execution_user_value", + "service_account": "service_account_value", + "name": "name_value", + "display_name": "display_name_value", + "execution_timeout": {"seconds": 751, "nanos": 543}, + "schedule_resource_name": "schedule_resource_name_value", + "job_state": 1, + "status": {}, + "create_time": {}, + "update_time": {}, + "labels": {}, + "encryption_spec": {}, + }, + "notebook_execution_job_id": "notebook_execution_job_id_value", + }, "name": "name_value", "display_name": "display_name_value", "start_time": {}, @@ -6263,6 +6296,35 @@ def test_update_schedule_rest(request_type): }, "pipeline_job_id": "pipeline_job_id_value", }, + "create_notebook_execution_job_request": { + "parent": "parent_value", + "notebook_execution_job": { + "dataform_repository_source": { + "dataform_repository_resource_name": "dataform_repository_resource_name_value", + "commit_sha": "commit_sha_value", + }, + "gcs_notebook_source": { + "uri": "uri_value", + "generation": "generation_value", + }, + "direct_notebook_source": {"content": b"content_blob"}, + "notebook_runtime_template_resource_name": "notebook_runtime_template_resource_name_value", + "gcs_output_uri": "gcs_output_uri_value", + "execution_user": "execution_user_value", + "service_account": "service_account_value", + "name": "name_value", + "display_name": "display_name_value", + "execution_timeout": {"seconds": 751, "nanos": 543}, + "schedule_resource_name": "schedule_resource_name_value", + "job_state": 1, + "status": {}, + "create_time": {}, + "update_time": {}, + "labels": {}, + "encryption_spec": {}, + }, + "notebook_execution_job_id": "notebook_execution_job_id_value", + }, "name": "projects/sample1/locations/sample2/schedules/sample3", "display_name": "display_name_value", "start_time": {}, @@ -7426,6 +7488,62 @@ def test_parse_network_path(): assert expected == actual +def test_notebook_execution_job_path(): + project = "scallop" + location = "abalone" + notebook_execution_job = "squid" + expected = "projects/{project}/locations/{location}/notebookExecutionJobs/{notebook_execution_job}".format( + project=project, + location=location, + notebook_execution_job=notebook_execution_job, + ) + actual = ScheduleServiceClient.notebook_execution_job_path( + project, location, notebook_execution_job + ) + assert expected == actual + + +def test_parse_notebook_execution_job_path(): + expected = { + "project": "clam", + "location": "whelk", + "notebook_execution_job": "octopus", + } + path = ScheduleServiceClient.notebook_execution_job_path(**expected) + + # Check that the path construction is reversible. + actual = ScheduleServiceClient.parse_notebook_execution_job_path(path) + assert expected == actual + + +def test_notebook_runtime_template_path(): + project = "oyster" + location = "nudibranch" + notebook_runtime_template = "cuttlefish" + expected = "projects/{project}/locations/{location}/notebookRuntimeTemplates/{notebook_runtime_template}".format( + project=project, + location=location, + notebook_runtime_template=notebook_runtime_template, + ) + actual = ScheduleServiceClient.notebook_runtime_template_path( + project, location, notebook_runtime_template + ) + assert expected == actual + + +def test_parse_notebook_runtime_template_path(): + expected = { + "project": "mussel", + "location": "winkle", + "notebook_runtime_template": "nautilus", + } + path = ScheduleServiceClient.notebook_runtime_template_path(**expected) + + # Check that the path construction is reversible. + actual = ScheduleServiceClient.parse_notebook_runtime_template_path(path) + assert expected == actual + + def test_pipeline_job_path(): project = "scallop" location = "abalone" diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py index 17d35098b2..6924ccaa6c 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py @@ -16819,6 +16819,7 @@ def test_create_custom_job_rest(request_type): "restart_job_on_worker_restart": True, "strategy": 1, "disable_retries": True, + "max_wait_duration": {}, }, "service_account": "service_account_value", "network": "network_value", @@ -20446,6 +20447,7 @@ def test_create_hyperparameter_tuning_job_rest(request_type): "restart_job_on_worker_restart": True, "strategy": 1, "disable_retries": True, + "max_wait_duration": {}, }, "service_account": "service_account_value", "network": "network_value", @@ -22343,6 +22345,7 @@ def test_create_nas_job_rest(request_type): "restart_job_on_worker_restart": True, "strategy": 1, "disable_retries": True, + "max_wait_duration": {}, }, "service_account": "service_account_value", "network": "network_value", diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py index 07cd3e052c..a891b9f491 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py @@ -1609,11 +1609,11 @@ async def test_create_rag_corpus_flattened_error_async(): @pytest.mark.parametrize( "request_type", [ - vertex_rag_data_service.GetRagCorpusRequest, + vertex_rag_data_service.UpdateRagCorpusRequest, dict, ], ) -def test_get_rag_corpus(request_type, transport: str = "grpc"): +def test_update_rag_corpus(request_type, transport: str = "grpc"): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -1624,29 +1624,24 @@ def test_get_rag_corpus(request_type, transport: str = "grpc"): request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call: + with mock.patch.object( + type(client.transport.update_rag_corpus), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = vertex_rag_data.RagCorpus( - name="name_value", - display_name="display_name_value", - description="description_value", - ) - response = client.get_rag_corpus(request) + call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.update_rag_corpus(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - request = vertex_rag_data_service.GetRagCorpusRequest() + request = vertex_rag_data_service.UpdateRagCorpusRequest() assert args[0] == request # Establish that the response is the type that we expect. - assert isinstance(response, vertex_rag_data.RagCorpus) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.description == "description_value" + assert isinstance(response, future.Future) -def test_get_rag_corpus_empty_call(): +def test_update_rag_corpus_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VertexRagDataServiceClient( @@ -1655,17 +1650,19 @@ def test_get_rag_corpus_empty_call(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call: + with mock.patch.object( + type(client.transport.update_rag_corpus), "__call__" + ) as call: call.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client.get_rag_corpus() + client.update_rag_corpus() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == vertex_rag_data_service.GetRagCorpusRequest() + assert args[0] == vertex_rag_data_service.UpdateRagCorpusRequest() -def test_get_rag_corpus_non_empty_request_with_auto_populated_field(): +def test_update_rag_corpus_non_empty_request_with_auto_populated_field(): # This test is a coverage failsafe to make sure that UUID4 fields are # automatically populated, according to AIP-4235, with non-empty requests. client = VertexRagDataServiceClient( @@ -1676,24 +1673,22 @@ def test_get_rag_corpus_non_empty_request_with_auto_populated_field(): # Populate all string fields in the request which are not UUID4 # since we want to check that UUID4 are populated automatically # if they meet the requirements of AIP 4235. - request = vertex_rag_data_service.GetRagCorpusRequest( - name="name_value", - ) + request = vertex_rag_data_service.UpdateRagCorpusRequest() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call: + with mock.patch.object( + type(client.transport.update_rag_corpus), "__call__" + ) as call: call.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client.get_rag_corpus(request=request) + client.update_rag_corpus(request=request) call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == vertex_rag_data_service.GetRagCorpusRequest( - name="name_value", - ) + assert args[0] == vertex_rag_data_service.UpdateRagCorpusRequest() -def test_get_rag_corpus_use_cached_wrapped_rpc(): +def test_update_rag_corpus_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: @@ -1707,21 +1702,28 @@ def test_get_rag_corpus_use_cached_wrapped_rpc(): wrapper_fn.reset_mock() # Ensure method has been cached - assert client._transport.get_rag_corpus in client._transport._wrapped_methods + assert client._transport.update_rag_corpus in client._transport._wrapped_methods # Replace cached wrapped function with mock mock_rpc = mock.Mock() mock_rpc.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client._transport._wrapped_methods[client._transport.get_rag_corpus] = mock_rpc + client._transport._wrapped_methods[ + client._transport.update_rag_corpus + ] = mock_rpc request = {} - client.get_rag_corpus(request) + client.update_rag_corpus(request) # Establish that the underlying gRPC stub method was called. assert mock_rpc.call_count == 1 - client.get_rag_corpus(request) + # Operation methods call wrapper_fn to build a cached + # client._transport.operations_client instance on first rpc call. + # Subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_rag_corpus(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 @@ -1729,7 +1731,7 @@ def test_get_rag_corpus_use_cached_wrapped_rpc(): @pytest.mark.asyncio -async def test_get_rag_corpus_empty_call_async(): +async def test_update_rag_corpus_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VertexRagDataServiceAsyncClient( @@ -1738,23 +1740,21 @@ async def test_get_rag_corpus_empty_call_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call: + with mock.patch.object( + type(client.transport.update_rag_corpus), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vertex_rag_data.RagCorpus( - name="name_value", - display_name="display_name_value", - description="description_value", - ) + operations_pb2.Operation(name="operations/spam") ) - response = await client.get_rag_corpus() + response = await client.update_rag_corpus() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == vertex_rag_data_service.GetRagCorpusRequest() + assert args[0] == vertex_rag_data_service.UpdateRagCorpusRequest() @pytest.mark.asyncio -async def test_get_rag_corpus_async_use_cached_wrapped_rpc( +async def test_update_rag_corpus_async_use_cached_wrapped_rpc( transport: str = "grpc_asyncio", ): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, @@ -1771,7 +1771,7 @@ async def test_get_rag_corpus_async_use_cached_wrapped_rpc( # Ensure method has been cached assert ( - client._client._transport.get_rag_corpus + client._client._transport.update_rag_corpus in client._client._transport._wrapped_methods ) @@ -1779,16 +1779,21 @@ async def test_get_rag_corpus_async_use_cached_wrapped_rpc( mock_rpc = mock.AsyncMock() mock_rpc.return_value = mock.Mock() client._client._transport._wrapped_methods[ - client._client._transport.get_rag_corpus + client._client._transport.update_rag_corpus ] = mock_rpc request = {} - await client.get_rag_corpus(request) + await client.update_rag_corpus(request) # Establish that the underlying gRPC stub method was called. assert mock_rpc.call_count == 1 - await client.get_rag_corpus(request) + # Operation methods call wrapper_fn to build a cached + # client._transport.operations_client instance on first rpc call. + # Subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.update_rag_corpus(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 @@ -1796,9 +1801,9 @@ async def test_get_rag_corpus_async_use_cached_wrapped_rpc( @pytest.mark.asyncio -async def test_get_rag_corpus_async( +async def test_update_rag_corpus_async( transport: str = "grpc_asyncio", - request_type=vertex_rag_data_service.GetRagCorpusRequest, + request_type=vertex_rag_data_service.UpdateRagCorpusRequest, ): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), @@ -1810,50 +1815,47 @@ async def test_get_rag_corpus_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call: + with mock.patch.object( + type(client.transport.update_rag_corpus), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vertex_rag_data.RagCorpus( - name="name_value", - display_name="display_name_value", - description="description_value", - ) + operations_pb2.Operation(name="operations/spam") ) - response = await client.get_rag_corpus(request) + response = await client.update_rag_corpus(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - request = vertex_rag_data_service.GetRagCorpusRequest() + request = vertex_rag_data_service.UpdateRagCorpusRequest() assert args[0] == request # Establish that the response is the type that we expect. - assert isinstance(response, vertex_rag_data.RagCorpus) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.description == "description_value" + assert isinstance(response, future.Future) @pytest.mark.asyncio -async def test_get_rag_corpus_async_from_dict(): - await test_get_rag_corpus_async(request_type=dict) +async def test_update_rag_corpus_async_from_dict(): + await test_update_rag_corpus_async(request_type=dict) -def test_get_rag_corpus_field_headers(): +def test_update_rag_corpus_field_headers(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = vertex_rag_data_service.GetRagCorpusRequest() + request = vertex_rag_data_service.UpdateRagCorpusRequest() - request.name = "name_value" + request.rag_corpus.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call: - call.return_value = vertex_rag_data.RagCorpus() - client.get_rag_corpus(request) + with mock.patch.object( + type(client.transport.update_rag_corpus), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.update_rag_corpus(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 @@ -1864,28 +1866,30 @@ def test_get_rag_corpus_field_headers(): _, _, kw = call.mock_calls[0] assert ( "x-goog-request-params", - "name=name_value", + "rag_corpus.name=name_value", ) in kw["metadata"] @pytest.mark.asyncio -async def test_get_rag_corpus_field_headers_async(): +async def test_update_rag_corpus_field_headers_async(): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = vertex_rag_data_service.GetRagCorpusRequest() + request = vertex_rag_data_service.UpdateRagCorpusRequest() - request.name = "name_value" + request.rag_corpus.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call: + with mock.patch.object( + type(client.transport.update_rag_corpus), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vertex_rag_data.RagCorpus() + operations_pb2.Operation(name="operations/op") ) - await client.get_rag_corpus(request) + await client.update_rag_corpus(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) @@ -1896,35 +1900,37 @@ async def test_get_rag_corpus_field_headers_async(): _, _, kw = call.mock_calls[0] assert ( "x-goog-request-params", - "name=name_value", + "rag_corpus.name=name_value", ) in kw["metadata"] -def test_get_rag_corpus_flattened(): +def test_update_rag_corpus_flattened(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call: + with mock.patch.object( + type(client.transport.update_rag_corpus), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = vertex_rag_data.RagCorpus() + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_rag_corpus( - name="name_value", + client.update_rag_corpus( + rag_corpus=vertex_rag_data.RagCorpus(name="name_value"), ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - arg = args[0].name - mock_val = "name_value" + arg = args[0].rag_corpus + mock_val = vertex_rag_data.RagCorpus(name="name_value") assert arg == mock_val -def test_get_rag_corpus_flattened_error(): +def test_update_rag_corpus_flattened_error(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -1932,43 +1938,45 @@ def test_get_rag_corpus_flattened_error(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.get_rag_corpus( - vertex_rag_data_service.GetRagCorpusRequest(), - name="name_value", + client.update_rag_corpus( + vertex_rag_data_service.UpdateRagCorpusRequest(), + rag_corpus=vertex_rag_data.RagCorpus(name="name_value"), ) @pytest.mark.asyncio -async def test_get_rag_corpus_flattened_async(): +async def test_update_rag_corpus_flattened_async(): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call: + with mock.patch.object( + type(client.transport.update_rag_corpus), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = vertex_rag_data.RagCorpus() + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vertex_rag_data.RagCorpus() + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_rag_corpus( - name="name_value", + response = await client.update_rag_corpus( + rag_corpus=vertex_rag_data.RagCorpus(name="name_value"), ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - arg = args[0].name - mock_val = "name_value" + arg = args[0].rag_corpus + mock_val = vertex_rag_data.RagCorpus(name="name_value") assert arg == mock_val @pytest.mark.asyncio -async def test_get_rag_corpus_flattened_error_async(): +async def test_update_rag_corpus_flattened_error_async(): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -1976,20 +1984,20 @@ async def test_get_rag_corpus_flattened_error_async(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - await client.get_rag_corpus( - vertex_rag_data_service.GetRagCorpusRequest(), - name="name_value", + await client.update_rag_corpus( + vertex_rag_data_service.UpdateRagCorpusRequest(), + rag_corpus=vertex_rag_data.RagCorpus(name="name_value"), ) @pytest.mark.parametrize( "request_type", [ - vertex_rag_data_service.ListRagCorporaRequest, + vertex_rag_data_service.GetRagCorpusRequest, dict, ], ) -def test_list_rag_corpora(request_type, transport: str = "grpc"): +def test_get_rag_corpus(request_type, transport: str = "grpc"): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -2000,25 +2008,29 @@ def test_list_rag_corpora(request_type, transport: str = "grpc"): request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call: + with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = vertex_rag_data_service.ListRagCorporaResponse( - next_page_token="next_page_token_value", + call.return_value = vertex_rag_data.RagCorpus( + name="name_value", + display_name="display_name_value", + description="description_value", ) - response = client.list_rag_corpora(request) + response = client.get_rag_corpus(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - request = vertex_rag_data_service.ListRagCorporaRequest() + request = vertex_rag_data_service.GetRagCorpusRequest() assert args[0] == request # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListRagCorporaPager) - assert response.next_page_token == "next_page_token_value" + assert isinstance(response, vertex_rag_data.RagCorpus) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" -def test_list_rag_corpora_empty_call(): +def test_get_rag_corpus_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VertexRagDataServiceClient( @@ -2027,17 +2039,17 @@ def test_list_rag_corpora_empty_call(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call: + with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call: call.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client.list_rag_corpora() + client.get_rag_corpus() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == vertex_rag_data_service.ListRagCorporaRequest() + assert args[0] == vertex_rag_data_service.GetRagCorpusRequest() -def test_list_rag_corpora_non_empty_request_with_auto_populated_field(): +def test_get_rag_corpus_non_empty_request_with_auto_populated_field(): # This test is a coverage failsafe to make sure that UUID4 fields are # automatically populated, according to AIP-4235, with non-empty requests. client = VertexRagDataServiceClient( @@ -2048,26 +2060,24 @@ def test_list_rag_corpora_non_empty_request_with_auto_populated_field(): # Populate all string fields in the request which are not UUID4 # since we want to check that UUID4 are populated automatically # if they meet the requirements of AIP 4235. - request = vertex_rag_data_service.ListRagCorporaRequest( - parent="parent_value", - page_token="page_token_value", + request = vertex_rag_data_service.GetRagCorpusRequest( + name="name_value", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call: + with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call: call.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client.list_rag_corpora(request=request) + client.get_rag_corpus(request=request) call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == vertex_rag_data_service.ListRagCorporaRequest( - parent="parent_value", - page_token="page_token_value", + assert args[0] == vertex_rag_data_service.GetRagCorpusRequest( + name="name_value", ) -def test_list_rag_corpora_use_cached_wrapped_rpc(): +def test_get_rag_corpus_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: @@ -2081,23 +2091,21 @@ def test_list_rag_corpora_use_cached_wrapped_rpc(): wrapper_fn.reset_mock() # Ensure method has been cached - assert client._transport.list_rag_corpora in client._transport._wrapped_methods + assert client._transport.get_rag_corpus in client._transport._wrapped_methods # Replace cached wrapped function with mock mock_rpc = mock.Mock() mock_rpc.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client._transport._wrapped_methods[ - client._transport.list_rag_corpora - ] = mock_rpc + client._transport._wrapped_methods[client._transport.get_rag_corpus] = mock_rpc request = {} - client.list_rag_corpora(request) + client.get_rag_corpus(request) # Establish that the underlying gRPC stub method was called. assert mock_rpc.call_count == 1 - client.list_rag_corpora(request) + client.get_rag_corpus(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 @@ -2105,7 +2113,7 @@ def test_list_rag_corpora_use_cached_wrapped_rpc(): @pytest.mark.asyncio -async def test_list_rag_corpora_empty_call_async(): +async def test_get_rag_corpus_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VertexRagDataServiceAsyncClient( @@ -2114,21 +2122,23 @@ async def test_list_rag_corpora_empty_call_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call: + with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vertex_rag_data_service.ListRagCorporaResponse( - next_page_token="next_page_token_value", + vertex_rag_data.RagCorpus( + name="name_value", + display_name="display_name_value", + description="description_value", ) ) - response = await client.list_rag_corpora() + response = await client.get_rag_corpus() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == vertex_rag_data_service.ListRagCorporaRequest() + assert args[0] == vertex_rag_data_service.GetRagCorpusRequest() @pytest.mark.asyncio -async def test_list_rag_corpora_async_use_cached_wrapped_rpc( +async def test_get_rag_corpus_async_use_cached_wrapped_rpc( transport: str = "grpc_asyncio", ): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, @@ -2145,7 +2155,7 @@ async def test_list_rag_corpora_async_use_cached_wrapped_rpc( # Ensure method has been cached assert ( - client._client._transport.list_rag_corpora + client._client._transport.get_rag_corpus in client._client._transport._wrapped_methods ) @@ -2153,16 +2163,16 @@ async def test_list_rag_corpora_async_use_cached_wrapped_rpc( mock_rpc = mock.AsyncMock() mock_rpc.return_value = mock.Mock() client._client._transport._wrapped_methods[ - client._client._transport.list_rag_corpora + client._client._transport.get_rag_corpus ] = mock_rpc request = {} - await client.list_rag_corpora(request) + await client.get_rag_corpus(request) # Establish that the underlying gRPC stub method was called. assert mock_rpc.call_count == 1 - await client.list_rag_corpora(request) + await client.get_rag_corpus(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 @@ -2170,9 +2180,9 @@ async def test_list_rag_corpora_async_use_cached_wrapped_rpc( @pytest.mark.asyncio -async def test_list_rag_corpora_async( +async def test_get_rag_corpus_async( transport: str = "grpc_asyncio", - request_type=vertex_rag_data_service.ListRagCorporaRequest, + request_type=vertex_rag_data_service.GetRagCorpusRequest, ): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), @@ -2184,46 +2194,50 @@ async def test_list_rag_corpora_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call: + with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vertex_rag_data_service.ListRagCorporaResponse( - next_page_token="next_page_token_value", + vertex_rag_data.RagCorpus( + name="name_value", + display_name="display_name_value", + description="description_value", ) ) - response = await client.list_rag_corpora(request) + response = await client.get_rag_corpus(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - request = vertex_rag_data_service.ListRagCorporaRequest() + request = vertex_rag_data_service.GetRagCorpusRequest() assert args[0] == request # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListRagCorporaAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert isinstance(response, vertex_rag_data.RagCorpus) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" @pytest.mark.asyncio -async def test_list_rag_corpora_async_from_dict(): - await test_list_rag_corpora_async(request_type=dict) +async def test_get_rag_corpus_async_from_dict(): + await test_get_rag_corpus_async(request_type=dict) -def test_list_rag_corpora_field_headers(): +def test_get_rag_corpus_field_headers(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = vertex_rag_data_service.ListRagCorporaRequest() + request = vertex_rag_data_service.GetRagCorpusRequest() - request.parent = "parent_value" + request.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call: - call.return_value = vertex_rag_data_service.ListRagCorporaResponse() - client.list_rag_corpora(request) + with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call: + call.return_value = vertex_rag_data.RagCorpus() + client.get_rag_corpus(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 @@ -2234,28 +2248,28 @@ def test_list_rag_corpora_field_headers(): _, _, kw = call.mock_calls[0] assert ( "x-goog-request-params", - "parent=parent_value", + "name=name_value", ) in kw["metadata"] @pytest.mark.asyncio -async def test_list_rag_corpora_field_headers_async(): +async def test_get_rag_corpus_field_headers_async(): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = vertex_rag_data_service.ListRagCorporaRequest() + request = vertex_rag_data_service.GetRagCorpusRequest() - request.parent = "parent_value" + request.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call: + with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vertex_rag_data_service.ListRagCorporaResponse() + vertex_rag_data.RagCorpus() ) - await client.list_rag_corpora(request) + await client.get_rag_corpus(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) @@ -2266,35 +2280,35 @@ async def test_list_rag_corpora_field_headers_async(): _, _, kw = call.mock_calls[0] assert ( "x-goog-request-params", - "parent=parent_value", + "name=name_value", ) in kw["metadata"] -def test_list_rag_corpora_flattened(): +def test_get_rag_corpus_flattened(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call: + with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = vertex_rag_data_service.ListRagCorporaResponse() + call.return_value = vertex_rag_data.RagCorpus() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_rag_corpora( - parent="parent_value", + client.get_rag_corpus( + name="name_value", ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - arg = args[0].parent - mock_val = "parent_value" + arg = args[0].name + mock_val = "name_value" assert arg == mock_val -def test_list_rag_corpora_flattened_error(): +def test_get_rag_corpus_flattened_error(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -2302,43 +2316,43 @@ def test_list_rag_corpora_flattened_error(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.list_rag_corpora( - vertex_rag_data_service.ListRagCorporaRequest(), - parent="parent_value", + client.get_rag_corpus( + vertex_rag_data_service.GetRagCorpusRequest(), + name="name_value", ) @pytest.mark.asyncio -async def test_list_rag_corpora_flattened_async(): +async def test_get_rag_corpus_flattened_async(): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call: + with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = vertex_rag_data_service.ListRagCorporaResponse() + call.return_value = vertex_rag_data.RagCorpus() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vertex_rag_data_service.ListRagCorporaResponse() + vertex_rag_data.RagCorpus() ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_rag_corpora( - parent="parent_value", + response = await client.get_rag_corpus( + name="name_value", ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - arg = args[0].parent - mock_val = "parent_value" + arg = args[0].name + mock_val = "name_value" assert arg == mock_val @pytest.mark.asyncio -async def test_list_rag_corpora_flattened_error_async(): +async def test_get_rag_corpus_flattened_error_async(): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -2346,307 +2360,112 @@ async def test_list_rag_corpora_flattened_error_async(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - await client.list_rag_corpora( - vertex_rag_data_service.ListRagCorporaRequest(), - parent="parent_value", + await client.get_rag_corpus( + vertex_rag_data_service.GetRagCorpusRequest(), + name="name_value", ) -def test_list_rag_corpora_pager(transport_name: str = "grpc"): +@pytest.mark.parametrize( + "request_type", + [ + vertex_rag_data_service.ListRagCorporaRequest, + dict, + ], +) +def test_list_rag_corpora(request_type, transport: str = "grpc"): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport=transport_name, + transport=transport, ) + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call: - # Set the response to a series of pages. - call.side_effect = ( - vertex_rag_data_service.ListRagCorporaResponse( - rag_corpora=[ - vertex_rag_data.RagCorpus(), - vertex_rag_data.RagCorpus(), - vertex_rag_data.RagCorpus(), - ], - next_page_token="abc", - ), - vertex_rag_data_service.ListRagCorporaResponse( - rag_corpora=[], - next_page_token="def", - ), - vertex_rag_data_service.ListRagCorporaResponse( - rag_corpora=[ - vertex_rag_data.RagCorpus(), - ], - next_page_token="ghi", - ), - vertex_rag_data_service.ListRagCorporaResponse( - rag_corpora=[ - vertex_rag_data.RagCorpus(), - vertex_rag_data.RagCorpus(), - ], - ), - RuntimeError, - ) - - expected_metadata = () - retry = retries.Retry() - timeout = 5 - expected_metadata = tuple(expected_metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + # Designate an appropriate return value for the call. + call.return_value = vertex_rag_data_service.ListRagCorporaResponse( + next_page_token="next_page_token_value", ) - pager = client.list_rag_corpora(request={}, retry=retry, timeout=timeout) + response = client.list_rag_corpora(request) - assert pager._metadata == expected_metadata - assert pager._retry == retry - assert pager._timeout == timeout + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = vertex_rag_data_service.ListRagCorporaRequest() + assert args[0] == request - results = list(pager) - assert len(results) == 6 - assert all(isinstance(i, vertex_rag_data.RagCorpus) for i in results) + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListRagCorporaPager) + assert response.next_page_token == "next_page_token_value" -def test_list_rag_corpora_pages(transport_name: str = "grpc"): +def test_list_rag_corpora_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport=transport_name, + transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call: - # Set the response to a series of pages. - call.side_effect = ( - vertex_rag_data_service.ListRagCorporaResponse( - rag_corpora=[ - vertex_rag_data.RagCorpus(), - vertex_rag_data.RagCorpus(), - vertex_rag_data.RagCorpus(), - ], - next_page_token="abc", - ), - vertex_rag_data_service.ListRagCorporaResponse( - rag_corpora=[], - next_page_token="def", - ), - vertex_rag_data_service.ListRagCorporaResponse( - rag_corpora=[ - vertex_rag_data.RagCorpus(), - ], - next_page_token="ghi", - ), - vertex_rag_data_service.ListRagCorporaResponse( - rag_corpora=[ - vertex_rag_data.RagCorpus(), - vertex_rag_data.RagCorpus(), - ], - ), - RuntimeError, + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. ) - pages = list(client.list_rag_corpora(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token + client.list_rag_corpora() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == vertex_rag_data_service.ListRagCorporaRequest() -@pytest.mark.asyncio -async def test_list_rag_corpora_async_pager(): - client = VertexRagDataServiceAsyncClient( +def test_list_rag_corpora_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = vertex_rag_data_service.ListRagCorporaRequest( + parent="parent_value", + page_token="page_token_value", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_rag_corpora), "__call__", new_callable=mock.AsyncMock - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - vertex_rag_data_service.ListRagCorporaResponse( - rag_corpora=[ - vertex_rag_data.RagCorpus(), - vertex_rag_data.RagCorpus(), - vertex_rag_data.RagCorpus(), - ], - next_page_token="abc", - ), - vertex_rag_data_service.ListRagCorporaResponse( - rag_corpora=[], - next_page_token="def", - ), - vertex_rag_data_service.ListRagCorporaResponse( - rag_corpora=[ - vertex_rag_data.RagCorpus(), - ], - next_page_token="ghi", - ), - vertex_rag_data_service.ListRagCorporaResponse( - rag_corpora=[ - vertex_rag_data.RagCorpus(), - vertex_rag_data.RagCorpus(), - ], - ), - RuntimeError, + with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. ) - async_pager = await client.list_rag_corpora( - request={}, + client.list_rag_corpora(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == vertex_rag_data_service.ListRagCorporaRequest( + parent="parent_value", + page_token="page_token_value", ) - assert async_pager.next_page_token == "abc" - responses = [] - async for response in async_pager: # pragma: no branch - responses.append(response) - assert len(responses) == 6 - assert all(isinstance(i, vertex_rag_data.RagCorpus) for i in responses) - -@pytest.mark.asyncio -async def test_list_rag_corpora_async_pages(): - client = VertexRagDataServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_rag_corpora), "__call__", new_callable=mock.AsyncMock - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - vertex_rag_data_service.ListRagCorporaResponse( - rag_corpora=[ - vertex_rag_data.RagCorpus(), - vertex_rag_data.RagCorpus(), - vertex_rag_data.RagCorpus(), - ], - next_page_token="abc", - ), - vertex_rag_data_service.ListRagCorporaResponse( - rag_corpora=[], - next_page_token="def", - ), - vertex_rag_data_service.ListRagCorporaResponse( - rag_corpora=[ - vertex_rag_data.RagCorpus(), - ], - next_page_token="ghi", - ), - vertex_rag_data_service.ListRagCorporaResponse( - rag_corpora=[ - vertex_rag_data.RagCorpus(), - vertex_rag_data.RagCorpus(), - ], - ), - RuntimeError, - ) - pages = [] - # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` - # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 - async for page_ in ( # pragma: no branch - await client.list_rag_corpora(request={}) - ).pages: - pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - - -@pytest.mark.parametrize( - "request_type", - [ - vertex_rag_data_service.DeleteRagCorpusRequest, - dict, - ], -) -def test_delete_rag_corpus(request_type, transport: str = "grpc"): - client = VertexRagDataServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_rag_corpus), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - response = client.delete_rag_corpus(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - request = vertex_rag_data_service.DeleteRagCorpusRequest() - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_delete_rag_corpus_empty_call(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. - client = VertexRagDataServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="grpc", - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_rag_corpus), "__call__" - ) as call: - call.return_value.name = ( - "foo" # operation_request.operation in compute client(s) expect a string. - ) - client.delete_rag_corpus() - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == vertex_rag_data_service.DeleteRagCorpusRequest() - - -def test_delete_rag_corpus_non_empty_request_with_auto_populated_field(): - # This test is a coverage failsafe to make sure that UUID4 fields are - # automatically populated, according to AIP-4235, with non-empty requests. - client = VertexRagDataServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="grpc", - ) - - # Populate all string fields in the request which are not UUID4 - # since we want to check that UUID4 are populated automatically - # if they meet the requirements of AIP 4235. - request = vertex_rag_data_service.DeleteRagCorpusRequest( - name="name_value", - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_rag_corpus), "__call__" - ) as call: - call.return_value.name = ( - "foo" # operation_request.operation in compute client(s) expect a string. - ) - client.delete_rag_corpus(request=request) - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == vertex_rag_data_service.DeleteRagCorpusRequest( - name="name_value", - ) - - -def test_delete_rag_corpus_use_cached_wrapped_rpc(): - # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, - # instead of constructing them on each call - with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: - client = VertexRagDataServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="grpc", - ) +def test_list_rag_corpora_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) # Should wrap all calls on client creation assert wrapper_fn.call_count > 0 wrapper_fn.reset_mock() # Ensure method has been cached - assert client._transport.delete_rag_corpus in client._transport._wrapped_methods + assert client._transport.list_rag_corpora in client._transport._wrapped_methods # Replace cached wrapped function with mock mock_rpc = mock.Mock() @@ -2654,20 +2473,15 @@ def test_delete_rag_corpus_use_cached_wrapped_rpc(): "foo" # operation_request.operation in compute client(s) expect a string. ) client._transport._wrapped_methods[ - client._transport.delete_rag_corpus + client._transport.list_rag_corpora ] = mock_rpc request = {} - client.delete_rag_corpus(request) + client.list_rag_corpora(request) # Establish that the underlying gRPC stub method was called. assert mock_rpc.call_count == 1 - # Operation methods call wrapper_fn to build a cached - # client._transport.operations_client instance on first rpc call. - # Subsequent calls should use the cached wrapper - wrapper_fn.reset_mock() - - client.delete_rag_corpus(request) + client.list_rag_corpora(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 @@ -2675,7 +2489,7 @@ def test_delete_rag_corpus_use_cached_wrapped_rpc(): @pytest.mark.asyncio -async def test_delete_rag_corpus_empty_call_async(): +async def test_list_rag_corpora_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VertexRagDataServiceAsyncClient( @@ -2684,21 +2498,21 @@ async def test_delete_rag_corpus_empty_call_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_rag_corpus), "__call__" - ) as call: + with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + vertex_rag_data_service.ListRagCorporaResponse( + next_page_token="next_page_token_value", + ) ) - response = await client.delete_rag_corpus() + response = await client.list_rag_corpora() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == vertex_rag_data_service.DeleteRagCorpusRequest() + assert args[0] == vertex_rag_data_service.ListRagCorporaRequest() @pytest.mark.asyncio -async def test_delete_rag_corpus_async_use_cached_wrapped_rpc( +async def test_list_rag_corpora_async_use_cached_wrapped_rpc( transport: str = "grpc_asyncio", ): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, @@ -2715,7 +2529,7 @@ async def test_delete_rag_corpus_async_use_cached_wrapped_rpc( # Ensure method has been cached assert ( - client._client._transport.delete_rag_corpus + client._client._transport.list_rag_corpora in client._client._transport._wrapped_methods ) @@ -2723,21 +2537,16 @@ async def test_delete_rag_corpus_async_use_cached_wrapped_rpc( mock_rpc = mock.AsyncMock() mock_rpc.return_value = mock.Mock() client._client._transport._wrapped_methods[ - client._client._transport.delete_rag_corpus + client._client._transport.list_rag_corpora ] = mock_rpc request = {} - await client.delete_rag_corpus(request) + await client.list_rag_corpora(request) # Establish that the underlying gRPC stub method was called. assert mock_rpc.call_count == 1 - # Operation methods call wrapper_fn to build a cached - # client._transport.operations_client instance on first rpc call. - # Subsequent calls should use the cached wrapper - wrapper_fn.reset_mock() - - await client.delete_rag_corpus(request) + await client.list_rag_corpora(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 @@ -2745,9 +2554,9 @@ async def test_delete_rag_corpus_async_use_cached_wrapped_rpc( @pytest.mark.asyncio -async def test_delete_rag_corpus_async( +async def test_list_rag_corpora_async( transport: str = "grpc_asyncio", - request_type=vertex_rag_data_service.DeleteRagCorpusRequest, + request_type=vertex_rag_data_service.ListRagCorporaRequest, ): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), @@ -2759,47 +2568,46 @@ async def test_delete_rag_corpus_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_rag_corpus), "__call__" - ) as call: + with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + vertex_rag_data_service.ListRagCorporaResponse( + next_page_token="next_page_token_value", + ) ) - response = await client.delete_rag_corpus(request) + response = await client.list_rag_corpora(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - request = vertex_rag_data_service.DeleteRagCorpusRequest() + request = vertex_rag_data_service.ListRagCorporaRequest() assert args[0] == request # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) + assert isinstance(response, pagers.ListRagCorporaAsyncPager) + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio -async def test_delete_rag_corpus_async_from_dict(): - await test_delete_rag_corpus_async(request_type=dict) +async def test_list_rag_corpora_async_from_dict(): + await test_list_rag_corpora_async(request_type=dict) -def test_delete_rag_corpus_field_headers(): +def test_list_rag_corpora_field_headers(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = vertex_rag_data_service.DeleteRagCorpusRequest() + request = vertex_rag_data_service.ListRagCorporaRequest() - request.name = "name_value" + request.parent = "parent_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_rag_corpus), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") - client.delete_rag_corpus(request) + with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call: + call.return_value = vertex_rag_data_service.ListRagCorporaResponse() + client.list_rag_corpora(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 @@ -2810,30 +2618,28 @@ def test_delete_rag_corpus_field_headers(): _, _, kw = call.mock_calls[0] assert ( "x-goog-request-params", - "name=name_value", + "parent=parent_value", ) in kw["metadata"] @pytest.mark.asyncio -async def test_delete_rag_corpus_field_headers_async(): +async def test_list_rag_corpora_field_headers_async(): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = vertex_rag_data_service.DeleteRagCorpusRequest() + request = vertex_rag_data_service.ListRagCorporaRequest() - request.name = "name_value" + request.parent = "parent_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_rag_corpus), "__call__" - ) as call: + with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") + vertex_rag_data_service.ListRagCorporaResponse() ) - await client.delete_rag_corpus(request) + await client.list_rag_corpora(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) @@ -2844,37 +2650,35 @@ async def test_delete_rag_corpus_field_headers_async(): _, _, kw = call.mock_calls[0] assert ( "x-goog-request-params", - "name=name_value", + "parent=parent_value", ) in kw["metadata"] -def test_delete_rag_corpus_flattened(): +def test_list_rag_corpora_flattened(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_rag_corpus), "__call__" - ) as call: + with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = vertex_rag_data_service.ListRagCorporaResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_rag_corpus( - name="name_value", + client.list_rag_corpora( + parent="parent_value", ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - arg = args[0].name - mock_val = "name_value" + arg = args[0].parent + mock_val = "parent_value" assert arg == mock_val -def test_delete_rag_corpus_flattened_error(): +def test_list_rag_corpora_flattened_error(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -2882,45 +2686,43 @@ def test_delete_rag_corpus_flattened_error(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.delete_rag_corpus( - vertex_rag_data_service.DeleteRagCorpusRequest(), - name="name_value", + client.list_rag_corpora( + vertex_rag_data_service.ListRagCorporaRequest(), + parent="parent_value", ) @pytest.mark.asyncio -async def test_delete_rag_corpus_flattened_async(): +async def test_list_rag_corpora_flattened_async(): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_rag_corpus), "__call__" - ) as call: + with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = vertex_rag_data_service.ListRagCorporaResponse() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + vertex_rag_data_service.ListRagCorporaResponse() ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_rag_corpus( - name="name_value", + response = await client.list_rag_corpora( + parent="parent_value", ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - arg = args[0].name - mock_val = "name_value" + arg = args[0].parent + mock_val = "parent_value" assert arg == mock_val @pytest.mark.asyncio -async def test_delete_rag_corpus_flattened_error_async(): +async def test_list_rag_corpora_flattened_error_async(): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -2928,20 +2730,214 @@ async def test_delete_rag_corpus_flattened_error_async(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - await client.delete_rag_corpus( - vertex_rag_data_service.DeleteRagCorpusRequest(), - name="name_value", + await client.list_rag_corpora( + vertex_rag_data_service.ListRagCorporaRequest(), + parent="parent_value", + ) + + +def test_list_rag_corpora_pager(transport_name: str = "grpc"): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport_name, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + vertex_rag_data_service.ListRagCorporaResponse( + rag_corpora=[ + vertex_rag_data.RagCorpus(), + vertex_rag_data.RagCorpus(), + vertex_rag_data.RagCorpus(), + ], + next_page_token="abc", + ), + vertex_rag_data_service.ListRagCorporaResponse( + rag_corpora=[], + next_page_token="def", + ), + vertex_rag_data_service.ListRagCorporaResponse( + rag_corpora=[ + vertex_rag_data.RagCorpus(), + ], + next_page_token="ghi", + ), + vertex_rag_data_service.ListRagCorporaResponse( + rag_corpora=[ + vertex_rag_data.RagCorpus(), + vertex_rag_data.RagCorpus(), + ], + ), + RuntimeError, + ) + + expected_metadata = () + retry = retries.Retry() + timeout = 5 + expected_metadata = tuple(expected_metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + ) + pager = client.list_rag_corpora(request={}, retry=retry, timeout=timeout) + + assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, vertex_rag_data.RagCorpus) for i in results) + + +def test_list_rag_corpora_pages(transport_name: str = "grpc"): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport_name, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + vertex_rag_data_service.ListRagCorporaResponse( + rag_corpora=[ + vertex_rag_data.RagCorpus(), + vertex_rag_data.RagCorpus(), + vertex_rag_data.RagCorpus(), + ], + next_page_token="abc", + ), + vertex_rag_data_service.ListRagCorporaResponse( + rag_corpora=[], + next_page_token="def", + ), + vertex_rag_data_service.ListRagCorporaResponse( + rag_corpora=[ + vertex_rag_data.RagCorpus(), + ], + next_page_token="ghi", + ), + vertex_rag_data_service.ListRagCorporaResponse( + rag_corpora=[ + vertex_rag_data.RagCorpus(), + vertex_rag_data.RagCorpus(), + ], + ), + RuntimeError, + ) + pages = list(client.list_rag_corpora(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_rag_corpora_async_pager(): + client = VertexRagDataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_rag_corpora), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + vertex_rag_data_service.ListRagCorporaResponse( + rag_corpora=[ + vertex_rag_data.RagCorpus(), + vertex_rag_data.RagCorpus(), + vertex_rag_data.RagCorpus(), + ], + next_page_token="abc", + ), + vertex_rag_data_service.ListRagCorporaResponse( + rag_corpora=[], + next_page_token="def", + ), + vertex_rag_data_service.ListRagCorporaResponse( + rag_corpora=[ + vertex_rag_data.RagCorpus(), + ], + next_page_token="ghi", + ), + vertex_rag_data_service.ListRagCorporaResponse( + rag_corpora=[ + vertex_rag_data.RagCorpus(), + vertex_rag_data.RagCorpus(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_rag_corpora( + request={}, + ) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: # pragma: no branch + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, vertex_rag_data.RagCorpus) for i in responses) + + +@pytest.mark.asyncio +async def test_list_rag_corpora_async_pages(): + client = VertexRagDataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_rag_corpora), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + vertex_rag_data_service.ListRagCorporaResponse( + rag_corpora=[ + vertex_rag_data.RagCorpus(), + vertex_rag_data.RagCorpus(), + vertex_rag_data.RagCorpus(), + ], + next_page_token="abc", + ), + vertex_rag_data_service.ListRagCorporaResponse( + rag_corpora=[], + next_page_token="def", + ), + vertex_rag_data_service.ListRagCorporaResponse( + rag_corpora=[ + vertex_rag_data.RagCorpus(), + ], + next_page_token="ghi", + ), + vertex_rag_data_service.ListRagCorporaResponse( + rag_corpora=[ + vertex_rag_data.RagCorpus(), + vertex_rag_data.RagCorpus(), + ], + ), + RuntimeError, ) + pages = [] + # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` + # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 + async for page_ in ( # pragma: no branch + await client.list_rag_corpora(request={}) + ).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token @pytest.mark.parametrize( "request_type", [ - vertex_rag_data_service.UploadRagFileRequest, + vertex_rag_data_service.DeleteRagCorpusRequest, dict, ], ) -def test_upload_rag_file(request_type, transport: str = "grpc"): +def test_delete_rag_corpus(request_type, transport: str = "grpc"): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -2952,22 +2948,24 @@ def test_upload_rag_file(request_type, transport: str = "grpc"): request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_rag_corpus), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = vertex_rag_data_service.UploadRagFileResponse() - response = client.upload_rag_file(request) + call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.delete_rag_corpus(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - request = vertex_rag_data_service.UploadRagFileRequest() + request = vertex_rag_data_service.DeleteRagCorpusRequest() assert args[0] == request # Establish that the response is the type that we expect. - assert isinstance(response, vertex_rag_data_service.UploadRagFileResponse) + assert isinstance(response, future.Future) -def test_upload_rag_file_empty_call(): +def test_delete_rag_corpus_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VertexRagDataServiceClient( @@ -2976,17 +2974,19 @@ def test_upload_rag_file_empty_call(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_rag_corpus), "__call__" + ) as call: call.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client.upload_rag_file() + client.delete_rag_corpus() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == vertex_rag_data_service.UploadRagFileRequest() + assert args[0] == vertex_rag_data_service.DeleteRagCorpusRequest() -def test_upload_rag_file_non_empty_request_with_auto_populated_field(): +def test_delete_rag_corpus_non_empty_request_with_auto_populated_field(): # This test is a coverage failsafe to make sure that UUID4 fields are # automatically populated, according to AIP-4235, with non-empty requests. client = VertexRagDataServiceClient( @@ -2997,24 +2997,26 @@ def test_upload_rag_file_non_empty_request_with_auto_populated_field(): # Populate all string fields in the request which are not UUID4 # since we want to check that UUID4 are populated automatically # if they meet the requirements of AIP 4235. - request = vertex_rag_data_service.UploadRagFileRequest( - parent="parent_value", + request = vertex_rag_data_service.DeleteRagCorpusRequest( + name="name_value", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_rag_corpus), "__call__" + ) as call: call.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client.upload_rag_file(request=request) + client.delete_rag_corpus(request=request) call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == vertex_rag_data_service.UploadRagFileRequest( - parent="parent_value", + assert args[0] == vertex_rag_data_service.DeleteRagCorpusRequest( + name="name_value", ) -def test_upload_rag_file_use_cached_wrapped_rpc(): +def test_delete_rag_corpus_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: @@ -3028,21 +3030,28 @@ def test_upload_rag_file_use_cached_wrapped_rpc(): wrapper_fn.reset_mock() # Ensure method has been cached - assert client._transport.upload_rag_file in client._transport._wrapped_methods + assert client._transport.delete_rag_corpus in client._transport._wrapped_methods # Replace cached wrapped function with mock mock_rpc = mock.Mock() mock_rpc.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client._transport._wrapped_methods[client._transport.upload_rag_file] = mock_rpc + client._transport._wrapped_methods[ + client._transport.delete_rag_corpus + ] = mock_rpc request = {} - client.upload_rag_file(request) + client.delete_rag_corpus(request) # Establish that the underlying gRPC stub method was called. assert mock_rpc.call_count == 1 - client.upload_rag_file(request) + # Operation methods call wrapper_fn to build a cached + # client._transport.operations_client instance on first rpc call. + # Subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_rag_corpus(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 @@ -3050,7 +3059,7 @@ def test_upload_rag_file_use_cached_wrapped_rpc(): @pytest.mark.asyncio -async def test_upload_rag_file_empty_call_async(): +async def test_delete_rag_corpus_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VertexRagDataServiceAsyncClient( @@ -3059,19 +3068,21 @@ async def test_upload_rag_file_empty_call_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_rag_corpus), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vertex_rag_data_service.UploadRagFileResponse() + operations_pb2.Operation(name="operations/spam") ) - response = await client.upload_rag_file() + response = await client.delete_rag_corpus() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == vertex_rag_data_service.UploadRagFileRequest() + assert args[0] == vertex_rag_data_service.DeleteRagCorpusRequest() @pytest.mark.asyncio -async def test_upload_rag_file_async_use_cached_wrapped_rpc( +async def test_delete_rag_corpus_async_use_cached_wrapped_rpc( transport: str = "grpc_asyncio", ): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, @@ -3088,7 +3099,7 @@ async def test_upload_rag_file_async_use_cached_wrapped_rpc( # Ensure method has been cached assert ( - client._client._transport.upload_rag_file + client._client._transport.delete_rag_corpus in client._client._transport._wrapped_methods ) @@ -3096,16 +3107,21 @@ async def test_upload_rag_file_async_use_cached_wrapped_rpc( mock_rpc = mock.AsyncMock() mock_rpc.return_value = mock.Mock() client._client._transport._wrapped_methods[ - client._client._transport.upload_rag_file + client._client._transport.delete_rag_corpus ] = mock_rpc request = {} - await client.upload_rag_file(request) + await client.delete_rag_corpus(request) # Establish that the underlying gRPC stub method was called. assert mock_rpc.call_count == 1 - await client.upload_rag_file(request) + # Operation methods call wrapper_fn to build a cached + # client._transport.operations_client instance on first rpc call. + # Subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_rag_corpus(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 @@ -3113,9 +3129,9 @@ async def test_upload_rag_file_async_use_cached_wrapped_rpc( @pytest.mark.asyncio -async def test_upload_rag_file_async( +async def test_delete_rag_corpus_async( transport: str = "grpc_asyncio", - request_type=vertex_rag_data_service.UploadRagFileRequest, + request_type=vertex_rag_data_service.DeleteRagCorpusRequest, ): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), @@ -3127,43 +3143,47 @@ async def test_upload_rag_file_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_rag_corpus), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vertex_rag_data_service.UploadRagFileResponse() + operations_pb2.Operation(name="operations/spam") ) - response = await client.upload_rag_file(request) + response = await client.delete_rag_corpus(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - request = vertex_rag_data_service.UploadRagFileRequest() + request = vertex_rag_data_service.DeleteRagCorpusRequest() assert args[0] == request # Establish that the response is the type that we expect. - assert isinstance(response, vertex_rag_data_service.UploadRagFileResponse) + assert isinstance(response, future.Future) @pytest.mark.asyncio -async def test_upload_rag_file_async_from_dict(): - await test_upload_rag_file_async(request_type=dict) +async def test_delete_rag_corpus_async_from_dict(): + await test_delete_rag_corpus_async(request_type=dict) -def test_upload_rag_file_field_headers(): +def test_delete_rag_corpus_field_headers(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = vertex_rag_data_service.UploadRagFileRequest() + request = vertex_rag_data_service.DeleteRagCorpusRequest() - request.parent = "parent_value" + request.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call: - call.return_value = vertex_rag_data_service.UploadRagFileResponse() - client.upload_rag_file(request) + with mock.patch.object( + type(client.transport.delete_rag_corpus), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.delete_rag_corpus(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 @@ -3174,28 +3194,30 @@ def test_upload_rag_file_field_headers(): _, _, kw = call.mock_calls[0] assert ( "x-goog-request-params", - "parent=parent_value", + "name=name_value", ) in kw["metadata"] @pytest.mark.asyncio -async def test_upload_rag_file_field_headers_async(): +async def test_delete_rag_corpus_field_headers_async(): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = vertex_rag_data_service.UploadRagFileRequest() + request = vertex_rag_data_service.DeleteRagCorpusRequest() - request.parent = "parent_value" + request.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_rag_corpus), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vertex_rag_data_service.UploadRagFileResponse() + operations_pb2.Operation(name="operations/op") ) - await client.upload_rag_file(request) + await client.delete_rag_corpus(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) @@ -3206,53 +3228,37 @@ async def test_upload_rag_file_field_headers_async(): _, _, kw = call.mock_calls[0] assert ( "x-goog-request-params", - "parent=parent_value", + "name=name_value", ) in kw["metadata"] -def test_upload_rag_file_flattened(): +def test_delete_rag_corpus_flattened(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_rag_corpus), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = vertex_rag_data_service.UploadRagFileResponse() + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.upload_rag_file( - parent="parent_value", - rag_file=vertex_rag_data.RagFile( - gcs_source=io.GcsSource(uris=["uris_value"]) - ), - upload_rag_file_config=vertex_rag_data.UploadRagFileConfig( - rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig( - chunk_size=1075 - ) - ), + client.delete_rag_corpus( + name="name_value", ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - arg = args[0].parent - mock_val = "parent_value" - assert arg == mock_val - arg = args[0].rag_file - mock_val = vertex_rag_data.RagFile(gcs_source=io.GcsSource(uris=["uris_value"])) - assert arg == mock_val - arg = args[0].upload_rag_file_config - mock_val = vertex_rag_data.UploadRagFileConfig( - rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig( - chunk_size=1075 - ) - ) + arg = args[0].name + mock_val = "name_value" assert arg == mock_val -def test_upload_rag_file_flattened_error(): +def test_delete_rag_corpus_flattened_error(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -3260,69 +3266,45 @@ def test_upload_rag_file_flattened_error(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.upload_rag_file( - vertex_rag_data_service.UploadRagFileRequest(), - parent="parent_value", - rag_file=vertex_rag_data.RagFile( - gcs_source=io.GcsSource(uris=["uris_value"]) - ), - upload_rag_file_config=vertex_rag_data.UploadRagFileConfig( - rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig( - chunk_size=1075 - ) - ), + client.delete_rag_corpus( + vertex_rag_data_service.DeleteRagCorpusRequest(), + name="name_value", ) @pytest.mark.asyncio -async def test_upload_rag_file_flattened_async(): +async def test_delete_rag_corpus_flattened_async(): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_rag_corpus), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = vertex_rag_data_service.UploadRagFileResponse() + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vertex_rag_data_service.UploadRagFileResponse() + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.upload_rag_file( - parent="parent_value", - rag_file=vertex_rag_data.RagFile( - gcs_source=io.GcsSource(uris=["uris_value"]) - ), - upload_rag_file_config=vertex_rag_data.UploadRagFileConfig( - rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig( - chunk_size=1075 - ) - ), + response = await client.delete_rag_corpus( + name="name_value", ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - arg = args[0].parent - mock_val = "parent_value" - assert arg == mock_val - arg = args[0].rag_file - mock_val = vertex_rag_data.RagFile(gcs_source=io.GcsSource(uris=["uris_value"])) - assert arg == mock_val - arg = args[0].upload_rag_file_config - mock_val = vertex_rag_data.UploadRagFileConfig( - rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig( - chunk_size=1075 - ) - ) + arg = args[0].name + mock_val = "name_value" assert arg == mock_val @pytest.mark.asyncio -async def test_upload_rag_file_flattened_error_async(): +async def test_delete_rag_corpus_flattened_error_async(): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -3330,28 +3312,20 @@ async def test_upload_rag_file_flattened_error_async(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - await client.upload_rag_file( - vertex_rag_data_service.UploadRagFileRequest(), - parent="parent_value", - rag_file=vertex_rag_data.RagFile( - gcs_source=io.GcsSource(uris=["uris_value"]) - ), - upload_rag_file_config=vertex_rag_data.UploadRagFileConfig( - rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig( - chunk_size=1075 - ) - ), + await client.delete_rag_corpus( + vertex_rag_data_service.DeleteRagCorpusRequest(), + name="name_value", ) @pytest.mark.parametrize( "request_type", [ - vertex_rag_data_service.ImportRagFilesRequest, + vertex_rag_data_service.UploadRagFileRequest, dict, ], ) -def test_import_rag_files(request_type, transport: str = "grpc"): +def test_upload_rag_file(request_type, transport: str = "grpc"): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -3362,22 +3336,22 @@ def test_import_rag_files(request_type, transport: str = "grpc"): request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call: + with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - response = client.import_rag_files(request) + call.return_value = vertex_rag_data_service.UploadRagFileResponse() + response = client.upload_rag_file(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - request = vertex_rag_data_service.ImportRagFilesRequest() + request = vertex_rag_data_service.UploadRagFileRequest() assert args[0] == request # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) + assert isinstance(response, vertex_rag_data_service.UploadRagFileResponse) -def test_import_rag_files_empty_call(): +def test_upload_rag_file_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VertexRagDataServiceClient( @@ -3386,17 +3360,17 @@ def test_import_rag_files_empty_call(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call: + with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call: call.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client.import_rag_files() + client.upload_rag_file() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == vertex_rag_data_service.ImportRagFilesRequest() + assert args[0] == vertex_rag_data_service.UploadRagFileRequest() -def test_import_rag_files_non_empty_request_with_auto_populated_field(): +def test_upload_rag_file_non_empty_request_with_auto_populated_field(): # This test is a coverage failsafe to make sure that UUID4 fields are # automatically populated, according to AIP-4235, with non-empty requests. client = VertexRagDataServiceClient( @@ -3407,24 +3381,24 @@ def test_import_rag_files_non_empty_request_with_auto_populated_field(): # Populate all string fields in the request which are not UUID4 # since we want to check that UUID4 are populated automatically # if they meet the requirements of AIP 4235. - request = vertex_rag_data_service.ImportRagFilesRequest( + request = vertex_rag_data_service.UploadRagFileRequest( parent="parent_value", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call: + with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call: call.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client.import_rag_files(request=request) + client.upload_rag_file(request=request) call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == vertex_rag_data_service.ImportRagFilesRequest( + assert args[0] == vertex_rag_data_service.UploadRagFileRequest( parent="parent_value", ) -def test_import_rag_files_use_cached_wrapped_rpc(): +def test_upload_rag_file_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: @@ -3438,28 +3412,21 @@ def test_import_rag_files_use_cached_wrapped_rpc(): wrapper_fn.reset_mock() # Ensure method has been cached - assert client._transport.import_rag_files in client._transport._wrapped_methods + assert client._transport.upload_rag_file in client._transport._wrapped_methods # Replace cached wrapped function with mock mock_rpc = mock.Mock() mock_rpc.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client._transport._wrapped_methods[ - client._transport.import_rag_files - ] = mock_rpc + client._transport._wrapped_methods[client._transport.upload_rag_file] = mock_rpc request = {} - client.import_rag_files(request) + client.upload_rag_file(request) # Establish that the underlying gRPC stub method was called. assert mock_rpc.call_count == 1 - # Operation methods call wrapper_fn to build a cached - # client._transport.operations_client instance on first rpc call. - # Subsequent calls should use the cached wrapper - wrapper_fn.reset_mock() - - client.import_rag_files(request) + client.upload_rag_file(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 @@ -3467,7 +3434,7 @@ def test_import_rag_files_use_cached_wrapped_rpc(): @pytest.mark.asyncio -async def test_import_rag_files_empty_call_async(): +async def test_upload_rag_file_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VertexRagDataServiceAsyncClient( @@ -3476,19 +3443,19 @@ async def test_import_rag_files_empty_call_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call: + with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + vertex_rag_data_service.UploadRagFileResponse() ) - response = await client.import_rag_files() + response = await client.upload_rag_file() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == vertex_rag_data_service.ImportRagFilesRequest() + assert args[0] == vertex_rag_data_service.UploadRagFileRequest() @pytest.mark.asyncio -async def test_import_rag_files_async_use_cached_wrapped_rpc( +async def test_upload_rag_file_async_use_cached_wrapped_rpc( transport: str = "grpc_asyncio", ): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, @@ -3505,7 +3472,7 @@ async def test_import_rag_files_async_use_cached_wrapped_rpc( # Ensure method has been cached assert ( - client._client._transport.import_rag_files + client._client._transport.upload_rag_file in client._client._transport._wrapped_methods ) @@ -3513,21 +3480,16 @@ async def test_import_rag_files_async_use_cached_wrapped_rpc( mock_rpc = mock.AsyncMock() mock_rpc.return_value = mock.Mock() client._client._transport._wrapped_methods[ - client._client._transport.import_rag_files + client._client._transport.upload_rag_file ] = mock_rpc request = {} - await client.import_rag_files(request) + await client.upload_rag_file(request) # Establish that the underlying gRPC stub method was called. assert mock_rpc.call_count == 1 - # Operation methods call wrapper_fn to build a cached - # client._transport.operations_client instance on first rpc call. - # Subsequent calls should use the cached wrapper - wrapper_fn.reset_mock() - - await client.import_rag_files(request) + await client.upload_rag_file(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 @@ -3535,9 +3497,9 @@ async def test_import_rag_files_async_use_cached_wrapped_rpc( @pytest.mark.asyncio -async def test_import_rag_files_async( +async def test_upload_rag_file_async( transport: str = "grpc_asyncio", - request_type=vertex_rag_data_service.ImportRagFilesRequest, + request_type=vertex_rag_data_service.UploadRagFileRequest, ): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), @@ -3549,43 +3511,43 @@ async def test_import_rag_files_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call: + with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + vertex_rag_data_service.UploadRagFileResponse() ) - response = await client.import_rag_files(request) + response = await client.upload_rag_file(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - request = vertex_rag_data_service.ImportRagFilesRequest() + request = vertex_rag_data_service.UploadRagFileRequest() assert args[0] == request # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) + assert isinstance(response, vertex_rag_data_service.UploadRagFileResponse) @pytest.mark.asyncio -async def test_import_rag_files_async_from_dict(): - await test_import_rag_files_async(request_type=dict) +async def test_upload_rag_file_async_from_dict(): + await test_upload_rag_file_async(request_type=dict) -def test_import_rag_files_field_headers(): +def test_upload_rag_file_field_headers(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = vertex_rag_data_service.ImportRagFilesRequest() + request = vertex_rag_data_service.UploadRagFileRequest() request.parent = "parent_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") - client.import_rag_files(request) + with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call: + call.return_value = vertex_rag_data_service.UploadRagFileResponse() + client.upload_rag_file(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 @@ -3601,23 +3563,23 @@ def test_import_rag_files_field_headers(): @pytest.mark.asyncio -async def test_import_rag_files_field_headers_async(): +async def test_upload_rag_file_field_headers_async(): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = vertex_rag_data_service.ImportRagFilesRequest() + request = vertex_rag_data_service.UploadRagFileRequest() request.parent = "parent_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call: + with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") + vertex_rag_data_service.UploadRagFileResponse() ) - await client.import_rag_files(request) + await client.upload_rag_file(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) @@ -3632,22 +3594,27 @@ async def test_import_rag_files_field_headers_async(): ) in kw["metadata"] -def test_import_rag_files_flattened(): +def test_upload_rag_file_flattened(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call: + with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = vertex_rag_data_service.UploadRagFileResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.import_rag_files( + client.upload_rag_file( parent="parent_value", - import_rag_files_config=vertex_rag_data.ImportRagFilesConfig( + rag_file=vertex_rag_data.RagFile( gcs_source=io.GcsSource(uris=["uris_value"]) ), + upload_rag_file_config=vertex_rag_data.UploadRagFileConfig( + rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig( + chunk_size=1075 + ) + ), ) # Establish that the underlying call was made with the expected @@ -3657,14 +3624,19 @@ def test_import_rag_files_flattened(): arg = args[0].parent mock_val = "parent_value" assert arg == mock_val - arg = args[0].import_rag_files_config - mock_val = vertex_rag_data.ImportRagFilesConfig( - gcs_source=io.GcsSource(uris=["uris_value"]) + arg = args[0].rag_file + mock_val = vertex_rag_data.RagFile(gcs_source=io.GcsSource(uris=["uris_value"])) + assert arg == mock_val + arg = args[0].upload_rag_file_config + mock_val = vertex_rag_data.UploadRagFileConfig( + rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig( + chunk_size=1075 + ) ) assert arg == mock_val -def test_import_rag_files_flattened_error(): +def test_upload_rag_file_flattened_error(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -3672,36 +3644,46 @@ def test_import_rag_files_flattened_error(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.import_rag_files( - vertex_rag_data_service.ImportRagFilesRequest(), + client.upload_rag_file( + vertex_rag_data_service.UploadRagFileRequest(), parent="parent_value", - import_rag_files_config=vertex_rag_data.ImportRagFilesConfig( + rag_file=vertex_rag_data.RagFile( gcs_source=io.GcsSource(uris=["uris_value"]) ), + upload_rag_file_config=vertex_rag_data.UploadRagFileConfig( + rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig( + chunk_size=1075 + ) + ), ) @pytest.mark.asyncio -async def test_import_rag_files_flattened_async(): +async def test_upload_rag_file_flattened_async(): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call: + with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = vertex_rag_data_service.UploadRagFileResponse() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + vertex_rag_data_service.UploadRagFileResponse() ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.import_rag_files( + response = await client.upload_rag_file( parent="parent_value", - import_rag_files_config=vertex_rag_data.ImportRagFilesConfig( + rag_file=vertex_rag_data.RagFile( gcs_source=io.GcsSource(uris=["uris_value"]) ), + upload_rag_file_config=vertex_rag_data.UploadRagFileConfig( + rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig( + chunk_size=1075 + ) + ), ) # Establish that the underlying call was made with the expected @@ -3711,15 +3693,20 @@ async def test_import_rag_files_flattened_async(): arg = args[0].parent mock_val = "parent_value" assert arg == mock_val - arg = args[0].import_rag_files_config - mock_val = vertex_rag_data.ImportRagFilesConfig( - gcs_source=io.GcsSource(uris=["uris_value"]) + arg = args[0].rag_file + mock_val = vertex_rag_data.RagFile(gcs_source=io.GcsSource(uris=["uris_value"])) + assert arg == mock_val + arg = args[0].upload_rag_file_config + mock_val = vertex_rag_data.UploadRagFileConfig( + rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig( + chunk_size=1075 + ) ) assert arg == mock_val @pytest.mark.asyncio -async def test_import_rag_files_flattened_error_async(): +async def test_upload_rag_file_flattened_error_async(): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -3727,23 +3714,28 @@ async def test_import_rag_files_flattened_error_async(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - await client.import_rag_files( - vertex_rag_data_service.ImportRagFilesRequest(), + await client.upload_rag_file( + vertex_rag_data_service.UploadRagFileRequest(), parent="parent_value", - import_rag_files_config=vertex_rag_data.ImportRagFilesConfig( + rag_file=vertex_rag_data.RagFile( gcs_source=io.GcsSource(uris=["uris_value"]) ), + upload_rag_file_config=vertex_rag_data.UploadRagFileConfig( + rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig( + chunk_size=1075 + ) + ), ) @pytest.mark.parametrize( "request_type", [ - vertex_rag_data_service.GetRagFileRequest, + vertex_rag_data_service.ImportRagFilesRequest, dict, ], ) -def test_get_rag_file(request_type, transport: str = "grpc"): +def test_import_rag_files(request_type, transport: str = "grpc"): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -3754,35 +3746,22 @@ def test_get_rag_file(request_type, transport: str = "grpc"): request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call: + with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = vertex_rag_data.RagFile( - name="name_value", - display_name="display_name_value", - description="description_value", - size_bytes=1089, - rag_file_type=vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT, - ) - response = client.get_rag_file(request) + call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.import_rag_files(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - request = vertex_rag_data_service.GetRagFileRequest() + request = vertex_rag_data_service.ImportRagFilesRequest() assert args[0] == request # Establish that the response is the type that we expect. - assert isinstance(response, vertex_rag_data.RagFile) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.description == "description_value" - assert response.size_bytes == 1089 - assert ( - response.rag_file_type == vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT - ) + assert isinstance(response, future.Future) -def test_get_rag_file_empty_call(): +def test_import_rag_files_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VertexRagDataServiceClient( @@ -3791,17 +3770,17 @@ def test_get_rag_file_empty_call(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call: + with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call: call.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client.get_rag_file() + client.import_rag_files() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == vertex_rag_data_service.GetRagFileRequest() + assert args[0] == vertex_rag_data_service.ImportRagFilesRequest() -def test_get_rag_file_non_empty_request_with_auto_populated_field(): +def test_import_rag_files_non_empty_request_with_auto_populated_field(): # This test is a coverage failsafe to make sure that UUID4 fields are # automatically populated, according to AIP-4235, with non-empty requests. client = VertexRagDataServiceClient( @@ -3812,24 +3791,24 @@ def test_get_rag_file_non_empty_request_with_auto_populated_field(): # Populate all string fields in the request which are not UUID4 # since we want to check that UUID4 are populated automatically # if they meet the requirements of AIP 4235. - request = vertex_rag_data_service.GetRagFileRequest( - name="name_value", + request = vertex_rag_data_service.ImportRagFilesRequest( + parent="parent_value", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call: + with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call: call.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client.get_rag_file(request=request) + client.import_rag_files(request=request) call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == vertex_rag_data_service.GetRagFileRequest( - name="name_value", + assert args[0] == vertex_rag_data_service.ImportRagFilesRequest( + parent="parent_value", ) -def test_get_rag_file_use_cached_wrapped_rpc(): +def test_import_rag_files_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: @@ -3843,21 +3822,28 @@ def test_get_rag_file_use_cached_wrapped_rpc(): wrapper_fn.reset_mock() # Ensure method has been cached - assert client._transport.get_rag_file in client._transport._wrapped_methods + assert client._transport.import_rag_files in client._transport._wrapped_methods # Replace cached wrapped function with mock mock_rpc = mock.Mock() mock_rpc.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client._transport._wrapped_methods[client._transport.get_rag_file] = mock_rpc + client._transport._wrapped_methods[ + client._transport.import_rag_files + ] = mock_rpc request = {} - client.get_rag_file(request) + client.import_rag_files(request) # Establish that the underlying gRPC stub method was called. assert mock_rpc.call_count == 1 - client.get_rag_file(request) + # Operation methods call wrapper_fn to build a cached + # client._transport.operations_client instance on first rpc call. + # Subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.import_rag_files(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 @@ -3865,7 +3851,7 @@ def test_get_rag_file_use_cached_wrapped_rpc(): @pytest.mark.asyncio -async def test_get_rag_file_empty_call_async(): +async def test_import_rag_files_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VertexRagDataServiceAsyncClient( @@ -3874,25 +3860,19 @@ async def test_get_rag_file_empty_call_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call: + with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vertex_rag_data.RagFile( - name="name_value", - display_name="display_name_value", - description="description_value", - size_bytes=1089, - rag_file_type=vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT, - ) + operations_pb2.Operation(name="operations/spam") ) - response = await client.get_rag_file() + response = await client.import_rag_files() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == vertex_rag_data_service.GetRagFileRequest() + assert args[0] == vertex_rag_data_service.ImportRagFilesRequest() @pytest.mark.asyncio -async def test_get_rag_file_async_use_cached_wrapped_rpc( +async def test_import_rag_files_async_use_cached_wrapped_rpc( transport: str = "grpc_asyncio", ): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, @@ -3909,7 +3889,7 @@ async def test_get_rag_file_async_use_cached_wrapped_rpc( # Ensure method has been cached assert ( - client._client._transport.get_rag_file + client._client._transport.import_rag_files in client._client._transport._wrapped_methods ) @@ -3917,16 +3897,21 @@ async def test_get_rag_file_async_use_cached_wrapped_rpc( mock_rpc = mock.AsyncMock() mock_rpc.return_value = mock.Mock() client._client._transport._wrapped_methods[ - client._client._transport.get_rag_file + client._client._transport.import_rag_files ] = mock_rpc request = {} - await client.get_rag_file(request) + await client.import_rag_files(request) # Establish that the underlying gRPC stub method was called. assert mock_rpc.call_count == 1 - await client.get_rag_file(request) + # Operation methods call wrapper_fn to build a cached + # client._transport.operations_client instance on first rpc call. + # Subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.import_rag_files(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 @@ -3934,9 +3919,9 @@ async def test_get_rag_file_async_use_cached_wrapped_rpc( @pytest.mark.asyncio -async def test_get_rag_file_async( +async def test_import_rag_files_async( transport: str = "grpc_asyncio", - request_type=vertex_rag_data_service.GetRagFileRequest, + request_type=vertex_rag_data_service.ImportRagFilesRequest, ): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), @@ -3948,56 +3933,43 @@ async def test_get_rag_file_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call: + with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vertex_rag_data.RagFile( - name="name_value", - display_name="display_name_value", - description="description_value", - size_bytes=1089, - rag_file_type=vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT, - ) + operations_pb2.Operation(name="operations/spam") ) - response = await client.get_rag_file(request) + response = await client.import_rag_files(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - request = vertex_rag_data_service.GetRagFileRequest() + request = vertex_rag_data_service.ImportRagFilesRequest() assert args[0] == request # Establish that the response is the type that we expect. - assert isinstance(response, vertex_rag_data.RagFile) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.description == "description_value" - assert response.size_bytes == 1089 - assert ( - response.rag_file_type == vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT - ) + assert isinstance(response, future.Future) @pytest.mark.asyncio -async def test_get_rag_file_async_from_dict(): - await test_get_rag_file_async(request_type=dict) +async def test_import_rag_files_async_from_dict(): + await test_import_rag_files_async(request_type=dict) -def test_get_rag_file_field_headers(): +def test_import_rag_files_field_headers(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = vertex_rag_data_service.GetRagFileRequest() + request = vertex_rag_data_service.ImportRagFilesRequest() - request.name = "name_value" + request.parent = "parent_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call: - call.return_value = vertex_rag_data.RagFile() - client.get_rag_file(request) + with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.import_rag_files(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 @@ -4008,28 +3980,28 @@ def test_get_rag_file_field_headers(): _, _, kw = call.mock_calls[0] assert ( "x-goog-request-params", - "name=name_value", + "parent=parent_value", ) in kw["metadata"] @pytest.mark.asyncio -async def test_get_rag_file_field_headers_async(): +async def test_import_rag_files_field_headers_async(): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = vertex_rag_data_service.GetRagFileRequest() + request = vertex_rag_data_service.ImportRagFilesRequest() - request.name = "name_value" + request.parent = "parent_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call: + with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vertex_rag_data.RagFile() + operations_pb2.Operation(name="operations/op") ) - await client.get_rag_file(request) + await client.import_rag_files(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) @@ -4040,35 +4012,43 @@ async def test_get_rag_file_field_headers_async(): _, _, kw = call.mock_calls[0] assert ( "x-goog-request-params", - "name=name_value", + "parent=parent_value", ) in kw["metadata"] -def test_get_rag_file_flattened(): +def test_import_rag_files_flattened(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call: + with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = vertex_rag_data.RagFile() + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_rag_file( - name="name_value", + client.import_rag_files( + parent="parent_value", + import_rag_files_config=vertex_rag_data.ImportRagFilesConfig( + gcs_source=io.GcsSource(uris=["uris_value"]) + ), ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - arg = args[0].name - mock_val = "name_value" + arg = args[0].parent + mock_val = "parent_value" + assert arg == mock_val + arg = args[0].import_rag_files_config + mock_val = vertex_rag_data.ImportRagFilesConfig( + gcs_source=io.GcsSource(uris=["uris_value"]) + ) assert arg == mock_val -def test_get_rag_file_flattened_error(): +def test_import_rag_files_flattened_error(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -4076,43 +4056,54 @@ def test_get_rag_file_flattened_error(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.get_rag_file( - vertex_rag_data_service.GetRagFileRequest(), - name="name_value", + client.import_rag_files( + vertex_rag_data_service.ImportRagFilesRequest(), + parent="parent_value", + import_rag_files_config=vertex_rag_data.ImportRagFilesConfig( + gcs_source=io.GcsSource(uris=["uris_value"]) + ), ) @pytest.mark.asyncio -async def test_get_rag_file_flattened_async(): +async def test_import_rag_files_flattened_async(): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call: + with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = vertex_rag_data.RagFile() + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vertex_rag_data.RagFile() + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_rag_file( - name="name_value", + response = await client.import_rag_files( + parent="parent_value", + import_rag_files_config=vertex_rag_data.ImportRagFilesConfig( + gcs_source=io.GcsSource(uris=["uris_value"]) + ), ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - arg = args[0].name - mock_val = "name_value" + arg = args[0].parent + mock_val = "parent_value" + assert arg == mock_val + arg = args[0].import_rag_files_config + mock_val = vertex_rag_data.ImportRagFilesConfig( + gcs_source=io.GcsSource(uris=["uris_value"]) + ) assert arg == mock_val @pytest.mark.asyncio -async def test_get_rag_file_flattened_error_async(): +async def test_import_rag_files_flattened_error_async(): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -4120,20 +4111,23 @@ async def test_get_rag_file_flattened_error_async(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - await client.get_rag_file( - vertex_rag_data_service.GetRagFileRequest(), - name="name_value", + await client.import_rag_files( + vertex_rag_data_service.ImportRagFilesRequest(), + parent="parent_value", + import_rag_files_config=vertex_rag_data.ImportRagFilesConfig( + gcs_source=io.GcsSource(uris=["uris_value"]) + ), ) @pytest.mark.parametrize( "request_type", [ - vertex_rag_data_service.ListRagFilesRequest, + vertex_rag_data_service.GetRagFileRequest, dict, ], ) -def test_list_rag_files(request_type, transport: str = "grpc"): +def test_get_rag_file(request_type, transport: str = "grpc"): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -4144,25 +4138,35 @@ def test_list_rag_files(request_type, transport: str = "grpc"): request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call: + with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = vertex_rag_data_service.ListRagFilesResponse( - next_page_token="next_page_token_value", + call.return_value = vertex_rag_data.RagFile( + name="name_value", + display_name="display_name_value", + description="description_value", + size_bytes=1089, + rag_file_type=vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT, ) - response = client.list_rag_files(request) + response = client.get_rag_file(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - request = vertex_rag_data_service.ListRagFilesRequest() + request = vertex_rag_data_service.GetRagFileRequest() assert args[0] == request # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListRagFilesPager) - assert response.next_page_token == "next_page_token_value" + assert isinstance(response, vertex_rag_data.RagFile) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" + assert response.size_bytes == 1089 + assert ( + response.rag_file_type == vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT + ) -def test_list_rag_files_empty_call(): +def test_get_rag_file_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VertexRagDataServiceClient( @@ -4171,17 +4175,17 @@ def test_list_rag_files_empty_call(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call: + with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call: call.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client.list_rag_files() + client.get_rag_file() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == vertex_rag_data_service.ListRagFilesRequest() + assert args[0] == vertex_rag_data_service.GetRagFileRequest() -def test_list_rag_files_non_empty_request_with_auto_populated_field(): +def test_get_rag_file_non_empty_request_with_auto_populated_field(): # This test is a coverage failsafe to make sure that UUID4 fields are # automatically populated, according to AIP-4235, with non-empty requests. client = VertexRagDataServiceClient( @@ -4192,26 +4196,24 @@ def test_list_rag_files_non_empty_request_with_auto_populated_field(): # Populate all string fields in the request which are not UUID4 # since we want to check that UUID4 are populated automatically # if they meet the requirements of AIP 4235. - request = vertex_rag_data_service.ListRagFilesRequest( - parent="parent_value", - page_token="page_token_value", + request = vertex_rag_data_service.GetRagFileRequest( + name="name_value", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call: + with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call: call.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client.list_rag_files(request=request) + client.get_rag_file(request=request) call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == vertex_rag_data_service.ListRagFilesRequest( - parent="parent_value", - page_token="page_token_value", + assert args[0] == vertex_rag_data_service.GetRagFileRequest( + name="name_value", ) -def test_list_rag_files_use_cached_wrapped_rpc(): +def test_get_rag_file_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: @@ -4225,21 +4227,21 @@ def test_list_rag_files_use_cached_wrapped_rpc(): wrapper_fn.reset_mock() # Ensure method has been cached - assert client._transport.list_rag_files in client._transport._wrapped_methods + assert client._transport.get_rag_file in client._transport._wrapped_methods # Replace cached wrapped function with mock mock_rpc = mock.Mock() mock_rpc.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client._transport._wrapped_methods[client._transport.list_rag_files] = mock_rpc + client._transport._wrapped_methods[client._transport.get_rag_file] = mock_rpc request = {} - client.list_rag_files(request) + client.get_rag_file(request) # Establish that the underlying gRPC stub method was called. assert mock_rpc.call_count == 1 - client.list_rag_files(request) + client.get_rag_file(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 @@ -4247,7 +4249,7 @@ def test_list_rag_files_use_cached_wrapped_rpc(): @pytest.mark.asyncio -async def test_list_rag_files_empty_call_async(): +async def test_get_rag_file_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VertexRagDataServiceAsyncClient( @@ -4256,21 +4258,25 @@ async def test_list_rag_files_empty_call_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call: + with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vertex_rag_data_service.ListRagFilesResponse( - next_page_token="next_page_token_value", + vertex_rag_data.RagFile( + name="name_value", + display_name="display_name_value", + description="description_value", + size_bytes=1089, + rag_file_type=vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT, ) ) - response = await client.list_rag_files() + response = await client.get_rag_file() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == vertex_rag_data_service.ListRagFilesRequest() + assert args[0] == vertex_rag_data_service.GetRagFileRequest() @pytest.mark.asyncio -async def test_list_rag_files_async_use_cached_wrapped_rpc( +async def test_get_rag_file_async_use_cached_wrapped_rpc( transport: str = "grpc_asyncio", ): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, @@ -4287,7 +4293,7 @@ async def test_list_rag_files_async_use_cached_wrapped_rpc( # Ensure method has been cached assert ( - client._client._transport.list_rag_files + client._client._transport.get_rag_file in client._client._transport._wrapped_methods ) @@ -4295,16 +4301,16 @@ async def test_list_rag_files_async_use_cached_wrapped_rpc( mock_rpc = mock.AsyncMock() mock_rpc.return_value = mock.Mock() client._client._transport._wrapped_methods[ - client._client._transport.list_rag_files + client._client._transport.get_rag_file ] = mock_rpc request = {} - await client.list_rag_files(request) + await client.get_rag_file(request) # Establish that the underlying gRPC stub method was called. assert mock_rpc.call_count == 1 - await client.list_rag_files(request) + await client.get_rag_file(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 @@ -4312,9 +4318,9 @@ async def test_list_rag_files_async_use_cached_wrapped_rpc( @pytest.mark.asyncio -async def test_list_rag_files_async( +async def test_get_rag_file_async( transport: str = "grpc_asyncio", - request_type=vertex_rag_data_service.ListRagFilesRequest, + request_type=vertex_rag_data_service.GetRagFileRequest, ): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), @@ -4326,46 +4332,56 @@ async def test_list_rag_files_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call: + with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vertex_rag_data_service.ListRagFilesResponse( - next_page_token="next_page_token_value", + vertex_rag_data.RagFile( + name="name_value", + display_name="display_name_value", + description="description_value", + size_bytes=1089, + rag_file_type=vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT, ) ) - response = await client.list_rag_files(request) + response = await client.get_rag_file(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - request = vertex_rag_data_service.ListRagFilesRequest() + request = vertex_rag_data_service.GetRagFileRequest() assert args[0] == request # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListRagFilesAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert isinstance(response, vertex_rag_data.RagFile) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" + assert response.size_bytes == 1089 + assert ( + response.rag_file_type == vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT + ) @pytest.mark.asyncio -async def test_list_rag_files_async_from_dict(): - await test_list_rag_files_async(request_type=dict) +async def test_get_rag_file_async_from_dict(): + await test_get_rag_file_async(request_type=dict) -def test_list_rag_files_field_headers(): +def test_get_rag_file_field_headers(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = vertex_rag_data_service.ListRagFilesRequest() + request = vertex_rag_data_service.GetRagFileRequest() - request.parent = "parent_value" + request.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call: - call.return_value = vertex_rag_data_service.ListRagFilesResponse() - client.list_rag_files(request) + with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call: + call.return_value = vertex_rag_data.RagFile() + client.get_rag_file(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 @@ -4376,28 +4392,28 @@ def test_list_rag_files_field_headers(): _, _, kw = call.mock_calls[0] assert ( "x-goog-request-params", - "parent=parent_value", + "name=name_value", ) in kw["metadata"] @pytest.mark.asyncio -async def test_list_rag_files_field_headers_async(): +async def test_get_rag_file_field_headers_async(): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = vertex_rag_data_service.ListRagFilesRequest() + request = vertex_rag_data_service.GetRagFileRequest() - request.parent = "parent_value" + request.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call: + with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vertex_rag_data_service.ListRagFilesResponse() + vertex_rag_data.RagFile() ) - await client.list_rag_files(request) + await client.get_rag_file(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) @@ -4408,35 +4424,35 @@ async def test_list_rag_files_field_headers_async(): _, _, kw = call.mock_calls[0] assert ( "x-goog-request-params", - "parent=parent_value", + "name=name_value", ) in kw["metadata"] -def test_list_rag_files_flattened(): +def test_get_rag_file_flattened(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call: + with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = vertex_rag_data_service.ListRagFilesResponse() + call.return_value = vertex_rag_data.RagFile() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_rag_files( - parent="parent_value", + client.get_rag_file( + name="name_value", ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - arg = args[0].parent - mock_val = "parent_value" + arg = args[0].name + mock_val = "name_value" assert arg == mock_val -def test_list_rag_files_flattened_error(): +def test_get_rag_file_flattened_error(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -4444,43 +4460,43 @@ def test_list_rag_files_flattened_error(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.list_rag_files( - vertex_rag_data_service.ListRagFilesRequest(), - parent="parent_value", + client.get_rag_file( + vertex_rag_data_service.GetRagFileRequest(), + name="name_value", ) @pytest.mark.asyncio -async def test_list_rag_files_flattened_async(): +async def test_get_rag_file_flattened_async(): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call: + with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = vertex_rag_data_service.ListRagFilesResponse() + call.return_value = vertex_rag_data.RagFile() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vertex_rag_data_service.ListRagFilesResponse() + vertex_rag_data.RagFile() ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_rag_files( - parent="parent_value", + response = await client.get_rag_file( + name="name_value", ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - arg = args[0].parent - mock_val = "parent_value" + arg = args[0].name + mock_val = "name_value" assert arg == mock_val @pytest.mark.asyncio -async def test_list_rag_files_flattened_error_async(): +async def test_get_rag_file_flattened_error_async(): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -4488,320 +4504,126 @@ async def test_list_rag_files_flattened_error_async(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - await client.list_rag_files( - vertex_rag_data_service.ListRagFilesRequest(), - parent="parent_value", + await client.get_rag_file( + vertex_rag_data_service.GetRagFileRequest(), + name="name_value", ) -def test_list_rag_files_pager(transport_name: str = "grpc"): +@pytest.mark.parametrize( + "request_type", + [ + vertex_rag_data_service.ListRagFilesRequest, + dict, + ], +) +def test_list_rag_files(request_type, transport: str = "grpc"): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport=transport_name, + transport=transport, ) + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call: - # Set the response to a series of pages. - call.side_effect = ( - vertex_rag_data_service.ListRagFilesResponse( - rag_files=[ - vertex_rag_data.RagFile(), - vertex_rag_data.RagFile(), - vertex_rag_data.RagFile(), - ], - next_page_token="abc", - ), - vertex_rag_data_service.ListRagFilesResponse( - rag_files=[], - next_page_token="def", - ), - vertex_rag_data_service.ListRagFilesResponse( - rag_files=[ - vertex_rag_data.RagFile(), - ], - next_page_token="ghi", - ), - vertex_rag_data_service.ListRagFilesResponse( - rag_files=[ - vertex_rag_data.RagFile(), - vertex_rag_data.RagFile(), - ], - ), - RuntimeError, - ) - - expected_metadata = () - retry = retries.Retry() - timeout = 5 - expected_metadata = tuple(expected_metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + # Designate an appropriate return value for the call. + call.return_value = vertex_rag_data_service.ListRagFilesResponse( + next_page_token="next_page_token_value", ) - pager = client.list_rag_files(request={}, retry=retry, timeout=timeout) + response = client.list_rag_files(request) - assert pager._metadata == expected_metadata - assert pager._retry == retry - assert pager._timeout == timeout + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = vertex_rag_data_service.ListRagFilesRequest() + assert args[0] == request - results = list(pager) - assert len(results) == 6 - assert all(isinstance(i, vertex_rag_data.RagFile) for i in results) + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListRagFilesPager) + assert response.next_page_token == "next_page_token_value" -def test_list_rag_files_pages(transport_name: str = "grpc"): +def test_list_rag_files_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport=transport_name, + transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call: - # Set the response to a series of pages. - call.side_effect = ( - vertex_rag_data_service.ListRagFilesResponse( - rag_files=[ - vertex_rag_data.RagFile(), - vertex_rag_data.RagFile(), - vertex_rag_data.RagFile(), - ], - next_page_token="abc", - ), - vertex_rag_data_service.ListRagFilesResponse( - rag_files=[], - next_page_token="def", - ), - vertex_rag_data_service.ListRagFilesResponse( - rag_files=[ - vertex_rag_data.RagFile(), - ], - next_page_token="ghi", - ), - vertex_rag_data_service.ListRagFilesResponse( - rag_files=[ - vertex_rag_data.RagFile(), - vertex_rag_data.RagFile(), - ], - ), - RuntimeError, + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. ) - pages = list(client.list_rag_files(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token + client.list_rag_files() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == vertex_rag_data_service.ListRagFilesRequest() -@pytest.mark.asyncio -async def test_list_rag_files_async_pager(): - client = VertexRagDataServiceAsyncClient( +def test_list_rag_files_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = vertex_rag_data_service.ListRagFilesRequest( + parent="parent_value", + page_token="page_token_value", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_rag_files), "__call__", new_callable=mock.AsyncMock - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - vertex_rag_data_service.ListRagFilesResponse( - rag_files=[ - vertex_rag_data.RagFile(), - vertex_rag_data.RagFile(), - vertex_rag_data.RagFile(), - ], - next_page_token="abc", - ), - vertex_rag_data_service.ListRagFilesResponse( - rag_files=[], - next_page_token="def", - ), - vertex_rag_data_service.ListRagFilesResponse( - rag_files=[ - vertex_rag_data.RagFile(), - ], - next_page_token="ghi", - ), - vertex_rag_data_service.ListRagFilesResponse( - rag_files=[ - vertex_rag_data.RagFile(), - vertex_rag_data.RagFile(), - ], - ), - RuntimeError, + with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. ) - async_pager = await client.list_rag_files( - request={}, + client.list_rag_files(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == vertex_rag_data_service.ListRagFilesRequest( + parent="parent_value", + page_token="page_token_value", ) - assert async_pager.next_page_token == "abc" - responses = [] - async for response in async_pager: # pragma: no branch - responses.append(response) - assert len(responses) == 6 - assert all(isinstance(i, vertex_rag_data.RagFile) for i in responses) +def test_list_rag_files_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) -@pytest.mark.asyncio -async def test_list_rag_files_async_pages(): - client = VertexRagDataServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_rag_files), "__call__", new_callable=mock.AsyncMock - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - vertex_rag_data_service.ListRagFilesResponse( - rag_files=[ - vertex_rag_data.RagFile(), - vertex_rag_data.RagFile(), - vertex_rag_data.RagFile(), - ], - next_page_token="abc", - ), - vertex_rag_data_service.ListRagFilesResponse( - rag_files=[], - next_page_token="def", - ), - vertex_rag_data_service.ListRagFilesResponse( - rag_files=[ - vertex_rag_data.RagFile(), - ], - next_page_token="ghi", - ), - vertex_rag_data_service.ListRagFilesResponse( - rag_files=[ - vertex_rag_data.RagFile(), - vertex_rag_data.RagFile(), - ], - ), - RuntimeError, - ) - pages = [] - # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` - # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 - async for page_ in ( # pragma: no branch - await client.list_rag_files(request={}) - ).pages: - pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - - -@pytest.mark.parametrize( - "request_type", - [ - vertex_rag_data_service.DeleteRagFileRequest, - dict, - ], -) -def test_delete_rag_file(request_type, transport: str = "grpc"): - client = VertexRagDataServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - response = client.delete_rag_file(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - request = vertex_rag_data_service.DeleteRagFileRequest() - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_delete_rag_file_empty_call(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. - client = VertexRagDataServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="grpc", - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call: - call.return_value.name = ( - "foo" # operation_request.operation in compute client(s) expect a string. - ) - client.delete_rag_file() - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == vertex_rag_data_service.DeleteRagFileRequest() - - -def test_delete_rag_file_non_empty_request_with_auto_populated_field(): - # This test is a coverage failsafe to make sure that UUID4 fields are - # automatically populated, according to AIP-4235, with non-empty requests. - client = VertexRagDataServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="grpc", - ) - - # Populate all string fields in the request which are not UUID4 - # since we want to check that UUID4 are populated automatically - # if they meet the requirements of AIP 4235. - request = vertex_rag_data_service.DeleteRagFileRequest( - name="name_value", - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call: - call.return_value.name = ( - "foo" # operation_request.operation in compute client(s) expect a string. - ) - client.delete_rag_file(request=request) - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == vertex_rag_data_service.DeleteRagFileRequest( - name="name_value", - ) - - -def test_delete_rag_file_use_cached_wrapped_rpc(): - # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, - # instead of constructing them on each call - with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: - client = VertexRagDataServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="grpc", - ) - - # Should wrap all calls on client creation - assert wrapper_fn.call_count > 0 - wrapper_fn.reset_mock() + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() # Ensure method has been cached - assert client._transport.delete_rag_file in client._transport._wrapped_methods + assert client._transport.list_rag_files in client._transport._wrapped_methods # Replace cached wrapped function with mock mock_rpc = mock.Mock() mock_rpc.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client._transport._wrapped_methods[client._transport.delete_rag_file] = mock_rpc + client._transport._wrapped_methods[client._transport.list_rag_files] = mock_rpc request = {} - client.delete_rag_file(request) + client.list_rag_files(request) # Establish that the underlying gRPC stub method was called. assert mock_rpc.call_count == 1 - # Operation methods call wrapper_fn to build a cached - # client._transport.operations_client instance on first rpc call. - # Subsequent calls should use the cached wrapper - wrapper_fn.reset_mock() - - client.delete_rag_file(request) + client.list_rag_files(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 @@ -4809,7 +4631,7 @@ def test_delete_rag_file_use_cached_wrapped_rpc(): @pytest.mark.asyncio -async def test_delete_rag_file_empty_call_async(): +async def test_list_rag_files_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VertexRagDataServiceAsyncClient( @@ -4818,19 +4640,21 @@ async def test_delete_rag_file_empty_call_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call: + with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + vertex_rag_data_service.ListRagFilesResponse( + next_page_token="next_page_token_value", + ) ) - response = await client.delete_rag_file() + response = await client.list_rag_files() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == vertex_rag_data_service.DeleteRagFileRequest() + assert args[0] == vertex_rag_data_service.ListRagFilesRequest() @pytest.mark.asyncio -async def test_delete_rag_file_async_use_cached_wrapped_rpc( +async def test_list_rag_files_async_use_cached_wrapped_rpc( transport: str = "grpc_asyncio", ): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, @@ -4847,7 +4671,7 @@ async def test_delete_rag_file_async_use_cached_wrapped_rpc( # Ensure method has been cached assert ( - client._client._transport.delete_rag_file + client._client._transport.list_rag_files in client._client._transport._wrapped_methods ) @@ -4855,21 +4679,16 @@ async def test_delete_rag_file_async_use_cached_wrapped_rpc( mock_rpc = mock.AsyncMock() mock_rpc.return_value = mock.Mock() client._client._transport._wrapped_methods[ - client._client._transport.delete_rag_file + client._client._transport.list_rag_files ] = mock_rpc request = {} - await client.delete_rag_file(request) + await client.list_rag_files(request) # Establish that the underlying gRPC stub method was called. assert mock_rpc.call_count == 1 - # Operation methods call wrapper_fn to build a cached - # client._transport.operations_client instance on first rpc call. - # Subsequent calls should use the cached wrapper - wrapper_fn.reset_mock() - - await client.delete_rag_file(request) + await client.list_rag_files(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 @@ -4877,9 +4696,9 @@ async def test_delete_rag_file_async_use_cached_wrapped_rpc( @pytest.mark.asyncio -async def test_delete_rag_file_async( +async def test_list_rag_files_async( transport: str = "grpc_asyncio", - request_type=vertex_rag_data_service.DeleteRagFileRequest, + request_type=vertex_rag_data_service.ListRagFilesRequest, ): client = VertexRagDataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), @@ -4891,43 +4710,46 @@ async def test_delete_rag_file_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call: + with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + vertex_rag_data_service.ListRagFilesResponse( + next_page_token="next_page_token_value", + ) ) - response = await client.delete_rag_file(request) + response = await client.list_rag_files(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - request = vertex_rag_data_service.DeleteRagFileRequest() + request = vertex_rag_data_service.ListRagFilesRequest() assert args[0] == request # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) + assert isinstance(response, pagers.ListRagFilesAsyncPager) + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio -async def test_delete_rag_file_async_from_dict(): - await test_delete_rag_file_async(request_type=dict) +async def test_list_rag_files_async_from_dict(): + await test_list_rag_files_async(request_type=dict) -def test_delete_rag_file_field_headers(): +def test_list_rag_files_field_headers(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = vertex_rag_data_service.DeleteRagFileRequest() + request = vertex_rag_data_service.ListRagFilesRequest() - request.name = "name_value" + request.parent = "parent_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") - client.delete_rag_file(request) + with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call: + call.return_value = vertex_rag_data_service.ListRagFilesResponse() + client.list_rag_files(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 @@ -4938,141 +4760,1122 @@ def test_delete_rag_file_field_headers(): _, _, kw = call.mock_calls[0] assert ( "x-goog-request-params", - "name=name_value", + "parent=parent_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_rag_files_field_headers_async(): + client = VertexRagDataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vertex_rag_data_service.ListRagFilesRequest() + + request.parent = "parent_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + vertex_rag_data_service.ListRagFilesResponse() + ) + await client.list_rag_files(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", ) in kw["metadata"] -@pytest.mark.asyncio -async def test_delete_rag_file_field_headers_async(): - client = VertexRagDataServiceAsyncClient( +def test_list_rag_files_flattened(): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = vertex_rag_data_service.ListRagFilesResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_rag_files( + parent="parent_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].parent + mock_val = "parent_value" + assert arg == mock_val + + +def test_list_rag_files_flattened_error(): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_rag_files( + vertex_rag_data_service.ListRagFilesRequest(), + parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_rag_files_flattened_async(): + client = VertexRagDataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = vertex_rag_data_service.ListRagFilesResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + vertex_rag_data_service.ListRagFilesResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_rag_files( + parent="parent_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].parent + mock_val = "parent_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_list_rag_files_flattened_error_async(): + client = VertexRagDataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_rag_files( + vertex_rag_data_service.ListRagFilesRequest(), + parent="parent_value", + ) + + +def test_list_rag_files_pager(transport_name: str = "grpc"): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport_name, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + vertex_rag_data_service.ListRagFilesResponse( + rag_files=[ + vertex_rag_data.RagFile(), + vertex_rag_data.RagFile(), + vertex_rag_data.RagFile(), + ], + next_page_token="abc", + ), + vertex_rag_data_service.ListRagFilesResponse( + rag_files=[], + next_page_token="def", + ), + vertex_rag_data_service.ListRagFilesResponse( + rag_files=[ + vertex_rag_data.RagFile(), + ], + next_page_token="ghi", + ), + vertex_rag_data_service.ListRagFilesResponse( + rag_files=[ + vertex_rag_data.RagFile(), + vertex_rag_data.RagFile(), + ], + ), + RuntimeError, + ) + + expected_metadata = () + retry = retries.Retry() + timeout = 5 + expected_metadata = tuple(expected_metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + ) + pager = client.list_rag_files(request={}, retry=retry, timeout=timeout) + + assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, vertex_rag_data.RagFile) for i in results) + + +def test_list_rag_files_pages(transport_name: str = "grpc"): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport_name, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + vertex_rag_data_service.ListRagFilesResponse( + rag_files=[ + vertex_rag_data.RagFile(), + vertex_rag_data.RagFile(), + vertex_rag_data.RagFile(), + ], + next_page_token="abc", + ), + vertex_rag_data_service.ListRagFilesResponse( + rag_files=[], + next_page_token="def", + ), + vertex_rag_data_service.ListRagFilesResponse( + rag_files=[ + vertex_rag_data.RagFile(), + ], + next_page_token="ghi", + ), + vertex_rag_data_service.ListRagFilesResponse( + rag_files=[ + vertex_rag_data.RagFile(), + vertex_rag_data.RagFile(), + ], + ), + RuntimeError, + ) + pages = list(client.list_rag_files(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_rag_files_async_pager(): + client = VertexRagDataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_rag_files), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + vertex_rag_data_service.ListRagFilesResponse( + rag_files=[ + vertex_rag_data.RagFile(), + vertex_rag_data.RagFile(), + vertex_rag_data.RagFile(), + ], + next_page_token="abc", + ), + vertex_rag_data_service.ListRagFilesResponse( + rag_files=[], + next_page_token="def", + ), + vertex_rag_data_service.ListRagFilesResponse( + rag_files=[ + vertex_rag_data.RagFile(), + ], + next_page_token="ghi", + ), + vertex_rag_data_service.ListRagFilesResponse( + rag_files=[ + vertex_rag_data.RagFile(), + vertex_rag_data.RagFile(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_rag_files( + request={}, + ) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: # pragma: no branch + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, vertex_rag_data.RagFile) for i in responses) + + +@pytest.mark.asyncio +async def test_list_rag_files_async_pages(): + client = VertexRagDataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_rag_files), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + vertex_rag_data_service.ListRagFilesResponse( + rag_files=[ + vertex_rag_data.RagFile(), + vertex_rag_data.RagFile(), + vertex_rag_data.RagFile(), + ], + next_page_token="abc", + ), + vertex_rag_data_service.ListRagFilesResponse( + rag_files=[], + next_page_token="def", + ), + vertex_rag_data_service.ListRagFilesResponse( + rag_files=[ + vertex_rag_data.RagFile(), + ], + next_page_token="ghi", + ), + vertex_rag_data_service.ListRagFilesResponse( + rag_files=[ + vertex_rag_data.RagFile(), + vertex_rag_data.RagFile(), + ], + ), + RuntimeError, + ) + pages = [] + # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` + # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 + async for page_ in ( # pragma: no branch + await client.list_rag_files(request={}) + ).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.parametrize( + "request_type", + [ + vertex_rag_data_service.DeleteRagFileRequest, + dict, + ], +) +def test_delete_rag_file(request_type, transport: str = "grpc"): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.delete_rag_file(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = vertex_rag_data_service.DeleteRagFileRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_rag_file_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.delete_rag_file() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == vertex_rag_data_service.DeleteRagFileRequest() + + +def test_delete_rag_file_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = vertex_rag_data_service.DeleteRagFileRequest( + name="name_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.delete_rag_file(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == vertex_rag_data_service.DeleteRagFileRequest( + name="name_value", + ) + + +def test_delete_rag_file_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_rag_file in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_rag_file] = mock_rpc + request = {} + client.delete_rag_file(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods call wrapper_fn to build a cached + # client._transport.operations_client instance on first rpc call. + # Subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_rag_file(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_delete_rag_file_empty_call_async(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = VertexRagDataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + response = await client.delete_rag_file() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == vertex_rag_data_service.DeleteRagFileRequest() + + +@pytest.mark.asyncio +async def test_delete_rag_file_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VertexRagDataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_rag_file + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_rag_file + ] = mock_rpc + + request = {} + await client.delete_rag_file(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods call wrapper_fn to build a cached + # client._transport.operations_client instance on first rpc call. + # Subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_rag_file(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_delete_rag_file_async( + transport: str = "grpc_asyncio", + request_type=vertex_rag_data_service.DeleteRagFileRequest, +): + client = VertexRagDataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + response = await client.delete_rag_file(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = vertex_rag_data_service.DeleteRagFileRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_rag_file_async_from_dict(): + await test_delete_rag_file_async(request_type=dict) + + +def test_delete_rag_file_field_headers(): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vertex_rag_data_service.DeleteRagFileRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.delete_rag_file(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_rag_file_field_headers_async(): + client = VertexRagDataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vertex_rag_data_service.DeleteRagFileRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + await client.delete_rag_file(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +def test_delete_rag_file_flattened(): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_rag_file( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +def test_delete_rag_file_flattened_error(): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_rag_file( + vertex_rag_data_service.DeleteRagFileRequest(), + name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_rag_file_flattened_async(): + client = VertexRagDataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_rag_file( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_delete_rag_file_flattened_error_async(): + client = VertexRagDataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_rag_file( + vertex_rag_data_service.DeleteRagFileRequest(), + name="name_value", + ) + + +@pytest.mark.parametrize( + "request_type", + [ + vertex_rag_data_service.CreateRagCorpusRequest, + dict, + ], +) +def test_create_rag_corpus_rest(request_type): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/locations/sample2"} + request_init["rag_corpus"] = { + "name": "name_value", + "display_name": "display_name_value", + "description": "description_value", + "rag_embedding_model_config": { + "vertex_prediction_endpoint": { + "endpoint": "endpoint_value", + "model": "model_value", + "model_version_id": "model_version_id_value", + }, + "hybrid_search_config": { + "sparse_embedding_config": { + "bm25": {"multilingual": True, "k1": 0.156, "b": 0.98} + }, + "dense_embedding_model_prediction_endpoint": {}, + }, + }, + "rag_vector_db_config": { + "rag_managed_db": {}, + "weaviate": { + "http_endpoint": "http_endpoint_value", + "collection_name": "collection_name_value", + }, + "vertex_feature_store": { + "feature_view_resource_name": "feature_view_resource_name_value" + }, + "api_auth": { + "api_key_config": { + "api_key_secret_version": "api_key_secret_version_value" + } + }, + }, + "create_time": {"seconds": 751, "nanos": 543}, + "update_time": {}, + "corpus_status": {"state": 1, "error_status": "error_status_value"}, + } + # The version of a generated dependency at test runtime may differ from the version used during generation. + # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = vertex_rag_data_service.CreateRagCorpusRequest.meta.fields[ + "rag_corpus" + ] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. + message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["rag_corpus"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["rag_corpus"][field])): + del request_init["rag_corpus"][field][i][subfield] + else: + del request_init["rag_corpus"][field][subfield] + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.create_rag_corpus(request) + + # Establish that the response is the type that we expect. + assert response.operation.name == "operations/spam" + + +def test_create_rag_corpus_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_rag_corpus in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_rag_corpus + ] = mock_rpc + + request = {} + client.create_rag_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_rag_corpus(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_create_rag_corpus_rest_required_fields( + request_type=vertex_rag_data_service.CreateRagCorpusRequest, +): + transport_class = transports.VertexRagDataServiceRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_rag_corpus._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_rag_corpus._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.create_rag_corpus(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_create_rag_corpus_rest_unset_required_fields(): + transport = transports.VertexRagDataServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.create_rag_corpus._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "parent", + "ragCorpus", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_create_rag_corpus_rest_interceptors(null_interceptor): + transport = transports.VertexRagDataServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.VertexRagDataServiceRestInterceptor(), ) + client = VertexRagDataServiceClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.VertexRagDataServiceRestInterceptor, "post_create_rag_corpus" + ) as post, mock.patch.object( + transports.VertexRagDataServiceRestInterceptor, "pre_create_rag_corpus" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = vertex_rag_data_service.CreateRagCorpusRequest.pb( + vertex_rag_data_service.CreateRagCorpusRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = vertex_rag_data_service.DeleteRagFileRequest() + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = json_format.MessageToJson( + operations_pb2.Operation() + ) - request.name = "name_value" + request = vertex_rag_data_service.CreateRagCorpusRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = operations_pb2.Operation() - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") + client.create_rag_corpus( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], ) - await client.delete_rag_file(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ( - "x-goog-request-params", - "name=name_value", - ) in kw["metadata"] + pre.assert_called_once() + post.assert_called_once() -def test_delete_rag_file_flattened(): +def test_create_rag_corpus_rest_bad_request( + transport: str = "rest", request_type=vertex_rag_data_service.CreateRagCorpusRequest +): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), + transport=transport, ) - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.delete_rag_file( - name="name_value", - ) + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/locations/sample2"} + request = request_type(**request_init) - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - arg = args[0].name - mock_val = "name_value" - assert arg == mock_val + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.create_rag_corpus(request) -def test_delete_rag_file_flattened_error(): +def test_create_rag_corpus_rest_flattened(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), + transport="rest", ) - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.delete_rag_file( - vertex_rag_data_service.DeleteRagFileRequest(), - name="name_value", - ) + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + # get arguments that satisfy an http rule for this method + sample_request = {"parent": "projects/sample1/locations/sample2"} -@pytest.mark.asyncio -async def test_delete_rag_file_flattened_async(): - client = VertexRagDataServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + rag_corpus=vertex_rag_data.RagCorpus(name="name_value"), + ) + mock_args.update(sample_request) - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.delete_rag_file( - name="name_value", - ) + client.create_rag_corpus(**mock_args) # Establish that the underlying call was made with the expected # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - arg = args[0].name - mock_val = "name_value" - assert arg == mock_val + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1beta1/{parent=projects/*/locations/*}/ragCorpora" + % client.transport._host, + args[1], + ) -@pytest.mark.asyncio -async def test_delete_rag_file_flattened_error_async(): - client = VertexRagDataServiceAsyncClient( +def test_create_rag_corpus_rest_flattened_error(transport: str = "rest"): + client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), + transport=transport, ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - await client.delete_rag_file( - vertex_rag_data_service.DeleteRagFileRequest(), - name="name_value", + client.create_rag_corpus( + vertex_rag_data_service.CreateRagCorpusRequest(), + parent="parent_value", + rag_corpus=vertex_rag_data.RagCorpus(name="name_value"), ) +def test_create_rag_corpus_rest_error(): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + @pytest.mark.parametrize( "request_type", [ - vertex_rag_data_service.CreateRagCorpusRequest, + vertex_rag_data_service.UpdateRagCorpusRequest, dict, ], ) -def test_create_rag_corpus_rest(request_type): +def test_update_rag_corpus_rest(request_type): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = { + "rag_corpus": {"name": "projects/sample1/locations/sample2/ragCorpora/sample3"} + } request_init["rag_corpus"] = { - "name": "name_value", + "name": "projects/sample1/locations/sample2/ragCorpora/sample3", "display_name": "display_name_value", "description": "description_value", "rag_embedding_model_config": { @@ -5080,17 +5883,39 @@ def test_create_rag_corpus_rest(request_type): "endpoint": "endpoint_value", "model": "model_value", "model_version_id": "model_version_id_value", - } + }, + "hybrid_search_config": { + "sparse_embedding_config": { + "bm25": {"multilingual": True, "k1": 0.156, "b": 0.98} + }, + "dense_embedding_model_prediction_endpoint": {}, + }, + }, + "rag_vector_db_config": { + "rag_managed_db": {}, + "weaviate": { + "http_endpoint": "http_endpoint_value", + "collection_name": "collection_name_value", + }, + "vertex_feature_store": { + "feature_view_resource_name": "feature_view_resource_name_value" + }, + "api_auth": { + "api_key_config": { + "api_key_secret_version": "api_key_secret_version_value" + } + }, }, "create_time": {"seconds": 751, "nanos": 543}, "update_time": {}, + "corpus_status": {"state": 1, "error_status": "error_status_value"}, } # The version of a generated dependency at test runtime may differ from the version used during generation. # Delete any fields which are not present in the current runtime dependency # See https://github.com/googleapis/gapic-generator-python/issues/1748 # Determine if the message type is proto-plus or protobuf - test_field = vertex_rag_data_service.CreateRagCorpusRequest.meta.fields[ + test_field = vertex_rag_data_service.UpdateRagCorpusRequest.meta.fields[ "rag_corpus" ] @@ -5168,13 +5993,13 @@ def get_message_fields(field): response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.create_rag_corpus(request) + response = client.update_rag_corpus(request) # Establish that the response is the type that we expect. assert response.operation.name == "operations/spam" -def test_create_rag_corpus_rest_use_cached_wrapped_rpc(): +def test_update_rag_corpus_rest_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: @@ -5188,7 +6013,7 @@ def test_create_rag_corpus_rest_use_cached_wrapped_rpc(): wrapper_fn.reset_mock() # Ensure method has been cached - assert client._transport.create_rag_corpus in client._transport._wrapped_methods + assert client._transport.update_rag_corpus in client._transport._wrapped_methods # Replace cached wrapped function with mock mock_rpc = mock.Mock() @@ -5196,11 +6021,11 @@ def test_create_rag_corpus_rest_use_cached_wrapped_rpc(): "foo" # operation_request.operation in compute client(s) expect a string. ) client._transport._wrapped_methods[ - client._transport.create_rag_corpus + client._transport.update_rag_corpus ] = mock_rpc request = {} - client.create_rag_corpus(request) + client.update_rag_corpus(request) # Establish that the underlying gRPC stub method was called. assert mock_rpc.call_count == 1 @@ -5209,20 +6034,19 @@ def test_create_rag_corpus_rest_use_cached_wrapped_rpc(): # subsequent calls should use the cached wrapper wrapper_fn.reset_mock() - client.create_rag_corpus(request) + client.update_rag_corpus(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 assert mock_rpc.call_count == 2 -def test_create_rag_corpus_rest_required_fields( - request_type=vertex_rag_data_service.CreateRagCorpusRequest, +def test_update_rag_corpus_rest_required_fields( + request_type=vertex_rag_data_service.UpdateRagCorpusRequest, ): transport_class = transports.VertexRagDataServiceRestTransport request_init = {} - request_init["parent"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) jsonified_request = json.loads( @@ -5233,21 +6057,17 @@ def test_create_rag_corpus_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).create_rag_corpus._get_unset_required_fields(jsonified_request) + ).update_rag_corpus._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["parent"] = "parent_value" - unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).create_rag_corpus._get_unset_required_fields(jsonified_request) + ).update_rag_corpus._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone - assert "parent" in jsonified_request - assert jsonified_request["parent"] == "parent_value" client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -5268,7 +6088,7 @@ def test_create_rag_corpus_rest_required_fields( pb_request = request_type.pb(request) transcode_result = { "uri": "v1/sample_method", - "method": "post", + "method": "patch", "query_params": pb_request, } transcode_result["body"] = pb_request @@ -5281,32 +6101,24 @@ def test_create_rag_corpus_rest_required_fields( response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.create_rag_corpus(request) + response = client.update_rag_corpus(request) expected_params = [("$alt", "json;enum-encoding=int")] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_create_rag_corpus_rest_unset_required_fields(): +def test_update_rag_corpus_rest_unset_required_fields(): transport = transports.VertexRagDataServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.create_rag_corpus._get_unset_required_fields({}) - assert set(unset_fields) == ( - set(()) - & set( - ( - "parent", - "ragCorpus", - ) - ) - ) + unset_fields = transport.update_rag_corpus._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("ragCorpus",))) @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_create_rag_corpus_rest_interceptors(null_interceptor): +def test_update_rag_corpus_rest_interceptors(null_interceptor): transport = transports.VertexRagDataServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -5321,14 +6133,14 @@ def test_create_rag_corpus_rest_interceptors(null_interceptor): ) as transcode, mock.patch.object( operation.Operation, "_set_result_from_operation" ), mock.patch.object( - transports.VertexRagDataServiceRestInterceptor, "post_create_rag_corpus" + transports.VertexRagDataServiceRestInterceptor, "post_update_rag_corpus" ) as post, mock.patch.object( - transports.VertexRagDataServiceRestInterceptor, "pre_create_rag_corpus" + transports.VertexRagDataServiceRestInterceptor, "pre_update_rag_corpus" ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = vertex_rag_data_service.CreateRagCorpusRequest.pb( - vertex_rag_data_service.CreateRagCorpusRequest() + pb_message = vertex_rag_data_service.UpdateRagCorpusRequest.pb( + vertex_rag_data_service.UpdateRagCorpusRequest() ) transcode.return_value = { "method": "post", @@ -5344,7 +6156,7 @@ def test_create_rag_corpus_rest_interceptors(null_interceptor): operations_pb2.Operation() ) - request = vertex_rag_data_service.CreateRagCorpusRequest() + request = vertex_rag_data_service.UpdateRagCorpusRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), @@ -5352,7 +6164,7 @@ def test_create_rag_corpus_rest_interceptors(null_interceptor): pre.return_value = request, metadata post.return_value = operations_pb2.Operation() - client.create_rag_corpus( + client.update_rag_corpus( request, metadata=[ ("key", "val"), @@ -5364,8 +6176,8 @@ def test_create_rag_corpus_rest_interceptors(null_interceptor): post.assert_called_once() -def test_create_rag_corpus_rest_bad_request( - transport: str = "rest", request_type=vertex_rag_data_service.CreateRagCorpusRequest +def test_update_rag_corpus_rest_bad_request( + transport: str = "rest", request_type=vertex_rag_data_service.UpdateRagCorpusRequest ): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -5373,7 +6185,9 @@ def test_create_rag_corpus_rest_bad_request( ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = { + "rag_corpus": {"name": "projects/sample1/locations/sample2/ragCorpora/sample3"} + } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. @@ -5385,10 +6199,10 @@ def test_create_rag_corpus_rest_bad_request( response_value.status_code = 400 response_value.request = Request() req.return_value = response_value - client.create_rag_corpus(request) + client.update_rag_corpus(request) -def test_create_rag_corpus_rest_flattened(): +def test_update_rag_corpus_rest_flattened(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -5400,11 +6214,14 @@ def test_create_rag_corpus_rest_flattened(): return_value = operations_pb2.Operation(name="operations/spam") # get arguments that satisfy an http rule for this method - sample_request = {"parent": "projects/sample1/locations/sample2"} + sample_request = { + "rag_corpus": { + "name": "projects/sample1/locations/sample2/ragCorpora/sample3" + } + } # get truthy value for each flattened field mock_args = dict( - parent="parent_value", rag_corpus=vertex_rag_data.RagCorpus(name="name_value"), ) mock_args.update(sample_request) @@ -5416,20 +6233,20 @@ def test_create_rag_corpus_rest_flattened(): response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - client.create_rag_corpus(**mock_args) + client.update_rag_corpus(**mock_args) # Establish that the underlying call was made with the expected # request object values. assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{parent=projects/*/locations/*}/ragCorpora" + "%s/v1beta1/{rag_corpus.name=projects/*/locations/*/ragCorpora/*}" % client.transport._host, args[1], ) -def test_create_rag_corpus_rest_flattened_error(transport: str = "rest"): +def test_update_rag_corpus_rest_flattened_error(transport: str = "rest"): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -5438,14 +6255,13 @@ def test_create_rag_corpus_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.create_rag_corpus( - vertex_rag_data_service.CreateRagCorpusRequest(), - parent="parent_value", + client.update_rag_corpus( + vertex_rag_data_service.UpdateRagCorpusRequest(), rag_corpus=vertex_rag_data.RagCorpus(name="name_value"), ) -def test_create_rag_corpus_rest_error(): +def test_update_rag_corpus_rest_error(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) @@ -8241,6 +9057,7 @@ def test_vertex_rag_data_service_base_transport(): # raise NotImplementedError. methods = ( "create_rag_corpus", + "update_rag_corpus", "get_rag_corpus", "list_rag_corpora", "delete_rag_corpus", @@ -8546,6 +9363,9 @@ def test_vertex_rag_data_service_client_transport_session_collision(transport_na session1 = client1.transport.create_rag_corpus._session session2 = client2.transport.create_rag_corpus._session assert session1 != session2 + session1 = client1.transport.update_rag_corpus._session + session2 = client2.transport.update_rag_corpus._session + assert session1 != session2 session1 = client1.transport.get_rag_corpus._session session2 = client2.transport.get_rag_corpus._session assert session1 != session2 diff --git a/tests/unit/vertex_rag/test_rag_constants.py b/tests/unit/vertex_rag/test_rag_constants.py index a525d2784f..cd2a74e30c 100644 --- a/tests/unit/vertex_rag/test_rag_constants.py +++ b/tests/unit/vertex_rag/test_rag_constants.py @@ -18,7 +18,17 @@ from google.cloud import aiplatform -from vertexai.preview import rag +from vertexai.preview.rag import ( + EmbeddingModelConfig, + RagCorpus, + RagFile, + RagResource, + SlackChannelsSource, + SlackChannel, + JiraSource, + JiraQuery, + Weaviate, +) from google.cloud.aiplatform_v1beta1 import ( GoogleDriveSource, RagFileChunkingConfig, @@ -32,6 +42,7 @@ SlackSource as GapicSlackSource, RagContexts, RetrieveContextsResponse, + RagVectorDbConfig, ) from google.cloud.aiplatform_v1beta1.types import api_auth from google.protobuf import timestamp_pb2 @@ -47,6 +58,16 @@ TEST_RAG_CORPUS_RESOURCE_NAME = f"projects/{TEST_PROJECT_NUMBER}/locations/{TEST_REGION}/ragCorpora/{TEST_RAG_CORPUS_ID}" # RagCorpus +TEST_WEAVIATE_HTTP_ENDPOINT = "test.weaviate.com" +TEST_WEAVIATE_COLLECTION_NAME = "test-collection" +TEST_WEAVIATE_API_KEY_SECRET_VERSION = ( + "projects/test-project/secrets/test-secret/versions/1" +) +TEST_WEAVIATE_CONFIG = Weaviate( + weaviate_http_endpoint=TEST_WEAVIATE_HTTP_ENDPOINT, + collection_name=TEST_WEAVIATE_COLLECTION_NAME, + api_key=TEST_WEAVIATE_API_KEY_SECRET_VERSION, +) TEST_GAPIC_RAG_CORPUS = GapicRagCorpus( name=TEST_RAG_CORPUS_RESOURCE_NAME, display_name=TEST_CORPUS_DISPLAY_NAME, @@ -57,15 +78,37 @@ TEST_PROJECT, TEST_REGION ) ) -TEST_EMBEDDING_MODEL_CONFIG = rag.EmbeddingModelConfig( +TEST_GAPIC_RAG_CORPUS_WEAVIATE = GapicRagCorpus( + name=TEST_RAG_CORPUS_RESOURCE_NAME, + display_name=TEST_CORPUS_DISPLAY_NAME, + description=TEST_CORPUS_DISCRIPTION, + rag_vector_db_config=RagVectorDbConfig( + weaviate=RagVectorDbConfig.Weaviate( + http_endpoint=TEST_WEAVIATE_HTTP_ENDPOINT, + collection_name=TEST_WEAVIATE_COLLECTION_NAME, + ), + api_auth=api_auth.ApiAuth( + api_key_config=api_auth.ApiAuth.ApiKeyConfig( + api_key_secret_version=TEST_WEAVIATE_API_KEY_SECRET_VERSION + ), + ), + ), +) +TEST_EMBEDDING_MODEL_CONFIG = EmbeddingModelConfig( publisher_model="publishers/google/models/textembedding-gecko", ) -TEST_RAG_CORPUS = rag.RagCorpus( +TEST_RAG_CORPUS = RagCorpus( name=TEST_RAG_CORPUS_RESOURCE_NAME, display_name=TEST_CORPUS_DISPLAY_NAME, description=TEST_CORPUS_DISCRIPTION, embedding_model_config=TEST_EMBEDDING_MODEL_CONFIG, ) +TEST_RAG_CORPUS_WEAVIATE = RagCorpus( + name=TEST_RAG_CORPUS_RESOURCE_NAME, + display_name=TEST_CORPUS_DISPLAY_NAME, + description=TEST_CORPUS_DISCRIPTION, + vector_db=TEST_WEAVIATE_CONFIG, +) TEST_PAGE_TOKEN = "test-page-token" # RagFiles @@ -165,7 +208,7 @@ display_name=TEST_FILE_DISPLAY_NAME, description=TEST_FILE_DESCRIPTION, ) -TEST_RAG_FILE = rag.RagFile( +TEST_RAG_FILE = RagFile( name=TEST_RAG_FILE_RESOURCE_NAME, display_name=TEST_FILE_DISPLAY_NAME, description=TEST_FILE_DESCRIPTION, @@ -183,15 +226,15 @@ TEST_SLACK_API_KEY_SECRET_VERSION_2 = ( "projects/test-project/secrets/test-secret/versions/2" ) -TEST_SLACK_SOURCE = rag.SlackChannelsSource( +TEST_SLACK_SOURCE = SlackChannelsSource( channels=[ - rag.SlackChannel( + SlackChannel( channel_id=TEST_SLACK_CHANNEL_ID, api_key=TEST_SLACK_API_KEY_SECRET_VERSION, start_time=TEST_SLACK_START_TIME, end_time=TEST_SLACK_END_TIME, ), - rag.SlackChannel( + SlackChannel( channel_id=TEST_SLACK_CHANNEL_ID_2, api_key=TEST_SLACK_API_KEY_SECRET_VERSION_2, ), @@ -241,9 +284,9 @@ TEST_JIRA_API_KEY_SECRET_VERSION = ( "projects/test-project/secrets/test-secret/versions/1" ) -TEST_JIRA_SOURCE = rag.JiraSource( +TEST_JIRA_SOURCE = JiraSource( queries=[ - rag.JiraQuery( + JiraQuery( email=TEST_JIRA_EMAIL, jira_projects=[TEST_JIRA_PROJECT], custom_queries=[TEST_JIRA_CUSTOM_QUERY], @@ -286,11 +329,11 @@ ] ) TEST_RETRIEVAL_RESPONSE = RetrieveContextsResponse(contexts=TEST_CONTEXTS) -TEST_RAG_RESOURCE = rag.RagResource( +TEST_RAG_RESOURCE = RagResource( rag_corpus=TEST_RAG_CORPUS_RESOURCE_NAME, rag_file_ids=[TEST_RAG_FILE_ID], ) -TEST_RAG_RESOURCE_INVALID_NAME = rag.RagResource( +TEST_RAG_RESOURCE_INVALID_NAME = RagResource( rag_corpus="213lkj-1/23jkl/", rag_file_ids=[TEST_RAG_FILE_ID], ) diff --git a/tests/unit/vertex_rag/test_rag_data.py b/tests/unit/vertex_rag/test_rag_data.py index 2b789d6513..fd243920b1 100644 --- a/tests/unit/vertex_rag/test_rag_data.py +++ b/tests/unit/vertex_rag/test_rag_data.py @@ -47,6 +47,21 @@ def create_rag_corpus_mock(): yield create_rag_corpus_mock +@pytest.fixture +def create_rag_corpus_mock_weaviate(): + with mock.patch.object( + VertexRagDataServiceClient, + "create_rag_corpus", + ) as create_rag_corpus_mock_weaviate: + create_rag_corpus_lro_mock = mock.Mock(ga_operation.Operation) + create_rag_corpus_lro_mock.done.return_value = True + create_rag_corpus_lro_mock.result.return_value = ( + tc.TEST_GAPIC_RAG_CORPUS_WEAVIATE + ) + create_rag_corpus_mock_weaviate.return_value = create_rag_corpus_lro_mock + yield create_rag_corpus_mock_weaviate + + @pytest.fixture def list_rag_corpora_pager_mock(): with mock.patch.object( @@ -141,6 +156,7 @@ def list_rag_files_pager_mock(): def rag_corpus_eq(returned_corpus, expected_corpus): assert returned_corpus.name == expected_corpus.name assert returned_corpus.display_name == expected_corpus.display_name + assert returned_corpus.vector_db.__eq__(expected_corpus.vector_db) def rag_file_eq(returned_file, expected_file): @@ -191,6 +207,15 @@ def test_create_corpus_success(self): rag_corpus_eq(rag_corpus, tc.TEST_RAG_CORPUS) + @pytest.mark.usefixtures("create_rag_corpus_mock_weaviate") + def test_create_corpus_weaviate_success(self): + rag_corpus = rag.create_corpus( + display_name=tc.TEST_CORPUS_DISPLAY_NAME, + vector_db=tc.TEST_WEAVIATE_CONFIG, + ) + + rag_corpus_eq(rag_corpus, tc.TEST_RAG_CORPUS_WEAVIATE) + @pytest.mark.usefixtures("rag_data_client_mock_exception") def test_create_corpus_failure(self): with pytest.raises(RuntimeError) as e: diff --git a/tests/unit/vertex_rag/test_rag_retrieval.py b/tests/unit/vertex_rag/test_rag_retrieval.py index 5aec0bd72a..07b2180422 100644 --- a/tests/unit/vertex_rag/test_rag_retrieval.py +++ b/tests/unit/vertex_rag/test_rag_retrieval.py @@ -76,6 +76,7 @@ def test_retrieval_query_rag_resources_success(self): text=tc.TEST_QUERY_TEXT, similarity_top_k=2, vector_distance_threshold=0.5, + vector_search_alpha=0.5, ) retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE) diff --git a/tests/unit/vertex_ray/test_cluster_init.py b/tests/unit/vertex_ray/test_cluster_init.py index a4ddd5f818..864b6f6f18 100644 --- a/tests/unit/vertex_ray/test_cluster_init.py +++ b/tests/unit/vertex_ray/test_cluster_init.py @@ -384,6 +384,7 @@ def test_create_ray_cluster_2_pools_custom_images_success( head_node_type=tc.ClusterConstants.TEST_HEAD_NODE_TYPE_2_POOLS_CUSTOM_IMAGE, worker_node_types=tc.ClusterConstants.TEST_WORKER_NODE_TYPES_2_POOLS_CUSTOM_IMAGE, network=tc.ProjectConstants.TEST_VPC_NETWORK, + reserved_ip_ranges=["vertex-dedicated-range"], cluster_name=tc.ClusterConstants.TEST_VERTEX_RAY_PR_ID, ) diff --git a/tests/unit/vertex_ray/test_constants.py b/tests/unit/vertex_ray/test_constants.py index 64018e930d..26537391bf 100644 --- a/tests/unit/vertex_ray/test_constants.py +++ b/tests/unit/vertex_ray/test_constants.py @@ -51,12 +51,17 @@ from google.cloud.aiplatform_v1beta1.types.service_networking import ( PscInterfaceConfig, ) +import ray import pytest rovminversion = pytest.mark.skipif( sys.version_info > (3, 10), reason="Requires python3.10 or lower" ) +# TODO(b/363340317) +xgbversion = pytest.mark.skipif( + ray.__version__ != "2.9.3", reason="Requires xgboost 1.7 or higher" +) @dataclasses.dataclass(frozen=True) @@ -347,6 +352,7 @@ class ClusterConstants: ), psc_interface_config=None, network=ProjectConstants.TEST_VPC_NETWORK, + reserved_ip_ranges=["vertex-dedicated-range"], ) # Responses TEST_RESOURCE_POOL_2.replica_count = 1 @@ -366,6 +372,7 @@ class ClusterConstants: network_attachment=TEST_PSC_NETWORK_ATTACHMENT ), network=None, + reserved_ip_ranges=None, resource_runtime=ResourceRuntime( access_uris={ "RAY_DASHBOARD_URI": TEST_VERTEX_RAY_DASHBOARD_ADDRESS, @@ -386,6 +393,7 @@ class ClusterConstants: ), ), network=ProjectConstants.TEST_VPC_NETWORK, + reserved_ip_ranges=["vertex-dedicated-range"], resource_runtime=ResourceRuntime( access_uris={ "RAY_DASHBOARD_URI": TEST_VERTEX_RAY_DASHBOARD_ADDRESS, @@ -399,6 +407,7 @@ class ClusterConstants: python_version="3.10", ray_version="2.9", network=ProjectConstants.TEST_VPC_NETWORK, + reserved_ip_ranges=None, service_account=None, state="RUNNING", head_node_type=TEST_HEAD_NODE_TYPE_1_POOL, @@ -412,6 +421,7 @@ class ClusterConstants: python_version="3.10", ray_version="2.9", network="", + reserved_ip_ranges="", service_account=None, state="RUNNING", head_node_type=TEST_HEAD_NODE_TYPE_2_POOLS, @@ -424,6 +434,7 @@ class ClusterConstants: TEST_CLUSTER_CUSTOM_IMAGE = Cluster( cluster_resource_name=TEST_VERTEX_RAY_PR_ADDRESS, network=ProjectConstants.TEST_VPC_NETWORK, + reserved_ip_ranges=["vertex-dedicated-range"], service_account=None, state="RUNNING", head_node_type=TEST_HEAD_NODE_TYPE_2_POOLS_CUSTOM_IMAGE, @@ -438,6 +449,7 @@ class ClusterConstants: python_version="3.10", ray_version="2.9", network="", + reserved_ip_ranges="", service_account=ProjectConstants.TEST_SERVICE_ACCOUNT, state="RUNNING", head_node_type=TEST_HEAD_NODE_TYPE_1_POOL, diff --git a/tests/unit/vertex_ray/test_ray_prediction.py b/tests/unit/vertex_ray/test_ray_prediction.py index fdab42c425..c5d41c3725 100644 --- a/tests/unit/vertex_ray/test_ray_prediction.py +++ b/tests/unit/vertex_ray/test_ray_prediction.py @@ -41,7 +41,6 @@ import numpy as np import pytest import ray -from ray.train import xgboost as ray_xgboost import tensorflow as tf import torch import xgboost @@ -90,9 +89,14 @@ def ray_sklearn_checkpoint(): @pytest.fixture() def ray_xgboost_checkpoint(): - model = test_prediction_utils.get_xgboost_model() - checkpoint = ray_xgboost.XGBoostCheckpoint.from_model(model.get_booster()) - return checkpoint + if ray.__version__ == "2.9.3": + from ray.train import xgboost as ray_xgboost + + model = test_prediction_utils.get_xgboost_model() + checkpoint = ray_xgboost.XGBoostCheckpoint.from_model(model.get_booster()) + return checkpoint + else: + return None @pytest.fixture() @@ -374,6 +378,7 @@ def test_register_sklearnartifact_uri_not_gcs_uri_raise_error( assert ve.match(regexp=r".*'artifact_uri' should " "start with 'gs://'.*") # XGBoost Tests + @tc.xgbversion @tc.rovminversion def test_convert_checkpoint_to_xgboost_raise_exception( self, ray_checkpoint_from_dict @@ -392,6 +397,7 @@ def test_convert_checkpoint_to_xgboost_raise_exception( "ray.train.xgboost.XGBoostCheckpoint .*" ) + @tc.xgbversion def test_convert_checkpoint_to_xgboost_model_succeed( self, ray_xgboost_checkpoint ) -> None: @@ -406,6 +412,7 @@ def test_convert_checkpoint_to_xgboost_model_succeed( y_pred = model.predict(xgboost.DMatrix(np.array([[1, 2]]))) assert y_pred[0] is not None + @tc.xgbversion def test_register_xgboost_succeed( self, ray_xgboost_checkpoint, @@ -429,6 +436,7 @@ def test_register_xgboost_succeed( pickle_dump.assert_called_once() gcs_utils_upload_to_gcs.assert_called_once() + @tc.xgbversion def test_register_xgboost_initialized_succeed( self, ray_xgboost_checkpoint, @@ -455,6 +463,7 @@ def test_register_xgboost_initialized_succeed( pickle_dump.assert_called_once() gcs_utils_upload_to_gcs.assert_called_once() + @tc.xgbversion def test_register_xgboostartifact_uri_is_none_raise_error( self, ray_xgboost_checkpoint ) -> None: @@ -467,6 +476,7 @@ def test_register_xgboostartifact_uri_is_none_raise_error( ) assert ve.match(regexp=r".*'artifact_uri' should " "start with 'gs://'.*") + @tc.xgbversion def test_register_xgboostartifact_uri_not_gcs_uri_raise_error( self, ray_xgboost_checkpoint ) -> None: diff --git a/tests/unit/vertexai/conftest.py b/tests/unit/vertexai/conftest.py index ecc50c2df7..73084cb412 100644 --- a/tests/unit/vertexai/conftest.py +++ b/tests/unit/vertexai/conftest.py @@ -65,10 +65,7 @@ _TEST_OPTIMIZED_FV2, _TEST_PSC_OPTIMIZED_FOS, ) -from google.cloud.logging import Logger -from pyfakefs import fake_filesystem_unittest import pytest -import tensorflow.saved_model as tf_saved_model _TEST_PROJECT = "test-project" _TEST_PROJECT_NUMBER = "12345678" @@ -142,14 +139,6 @@ def google_auth_mock(): yield auth_mock -@pytest.fixture -def mock_filesystem(): - with fake_filesystem_unittest.Patcher() as patcher: - patcher.setUp() - yield patcher.fs - patcher.tearDown() - - @pytest.fixture def mock_storage_blob(mock_filesystem): """Mocks the storage Blob API. @@ -292,19 +281,6 @@ def mock_uuid(): yield uuid_mock -@pytest.fixture -def mock_tf_saved_model_load(): - with mock.patch.object(tf_saved_model, "load") as load_mock: - yield load_mock - - -@pytest.fixture -def mock_cloud_logging_list_entries(): - with mock.patch.object(Logger, "list_entries") as list_entries_mock: - list_entries_mock.return_value = [] - yield list_entries_mock - - @pytest.fixture def base_logger_mock(): with patch.object( diff --git a/tests/unit/vertexai/test_batch_prediction.py b/tests/unit/vertexai/test_batch_prediction.py index 485da92ae6..2970e44177 100644 --- a/tests/unit/vertexai/test_batch_prediction.py +++ b/tests/unit/vertexai/test_batch_prediction.py @@ -50,6 +50,8 @@ _TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME = "projects/123/locations/us-central1/models/456" _TEST_PALM_MODEL_NAME = "text-bison" _TEST_PALM_MODEL_RESOURCE_NAME = f"publishers/google/models/{_TEST_PALM_MODEL_NAME}" +_TEST_LLAMA_MODEL_NAME = "llama3-405b-instruct-maas" +_TEST_LLAMA_MODEL_RESOURCE_NAME = f"publishers/meta/models/{_TEST_LLAMA_MODEL_NAME}" _TEST_GCS_INPUT_URI = "gs://test-bucket/test-input.jsonl" _TEST_GCS_INPUT_URI_2 = "gs://test-bucket/test-input-2.jsonl" @@ -127,6 +129,23 @@ def get_batch_prediction_job_with_gcs_output_mock(): yield get_job_mock +@pytest.fixture +def get_batch_prediction_job_with_llama_model_mock(): + with mock.patch.object( + job_service_client.JobServiceClient, "get_batch_prediction_job" + ) as get_job_mock: + get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + model=_TEST_LLAMA_MODEL_RESOURCE_NAME, + state=_TEST_JOB_STATE_SUCCESS, + output_info=gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo( + gcs_output_directory=_TEST_GCS_OUTPUT_PREFIX + ), + ) + yield get_job_mock + + @pytest.fixture def get_batch_prediction_job_with_tuned_gemini_model_mock(): with mock.patch.object( @@ -252,6 +271,16 @@ def test_init_batch_prediction_job( name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY ) + def test_init_batch_prediction_job_with_llama_model( + self, + get_batch_prediction_job_with_llama_model_mock, + ): + batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID) + + get_batch_prediction_job_with_llama_model_mock.assert_called_once_with( + name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY + ) + def test_init_batch_prediction_job_with_tuned_gemini_model( self, get_batch_prediction_job_with_tuned_gemini_model_mock, @@ -447,6 +476,39 @@ def test_submit_batch_prediction_job_with_bq_input_without_output_uri_prefix( timeout=None, ) + def test_submit_batch_prediction_job_with_llama_model( + self, + create_batch_prediction_job_mock, + ): + job = batch_prediction.BatchPredictionJob.submit( + source_model=_TEST_LLAMA_MODEL_RESOURCE_NAME, + input_dataset=_TEST_BQ_INPUT_URI, + ) + + assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB + + expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob( + display_name=_TEST_DISPLAY_NAME, + model=_TEST_LLAMA_MODEL_RESOURCE_NAME, + input_config=gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig( + instances_format="bigquery", + bigquery_source=gca_io_compat.BigQuerySource( + input_uri=_TEST_BQ_INPUT_URI + ), + ), + output_config=gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig( + bigquery_destination=gca_io_compat.BigQueryDestination( + output_uri=_TEST_BQ_OUTPUT_PREFIX + ), + predictions_format="bigquery", + ), + ) + create_batch_prediction_job_mock.assert_called_once_with( + parent=_TEST_PARENT, + batch_prediction_job=expected_gapic_batch_prediction_job, + timeout=None, + ) + @pytest.mark.usefixtures("create_batch_prediction_job_mock") def test_submit_batch_prediction_job_with_tuned_model( self, @@ -467,7 +529,8 @@ def test_submit_batch_prediction_job_with_invalid_source_model(self): with pytest.raises( ValueError, match=( - f"Model '{_TEST_PALM_MODEL_RESOURCE_NAME}' is not a Generative AI model." + "Abbreviated model names are only supported for Gemini models. " + "Please provide the full publisher model name." ), ): batch_prediction.BatchPredictionJob.submit( @@ -475,6 +538,19 @@ def test_submit_batch_prediction_job_with_invalid_source_model(self): input_dataset=_TEST_GCS_INPUT_URI, ) + def test_submit_batch_prediction_job_with_invalid_abbreviated_model_name(self): + with pytest.raises( + ValueError, + match=( + "Abbreviated model names are only supported for Gemini models. " + "Please provide the full publisher model name." + ), + ): + batch_prediction.BatchPredictionJob.submit( + source_model=_TEST_LLAMA_MODEL_NAME, + input_dataset=_TEST_GCS_INPUT_URI, + ) + @pytest.mark.usefixtures("get_non_gemini_model_mock") def test_submit_batch_prediction_job_with_non_gemini_tuned_model(self): with pytest.raises( diff --git a/tests/unit/vertexai/test_evaluation.py b/tests/unit/vertexai/test_evaluation.py index c386419305..f5e0879599 100644 --- a/tests/unit/vertexai/test_evaluation.py +++ b/tests/unit/vertexai/test_evaluation.py @@ -1176,15 +1176,14 @@ def test_evaluate_response_column_and_model_provided(self): ValueError, match=re.escape( ( - "The `model` parameter or `baseline_model` in pairwise" - " metric is specified, but the evaluation `dataset` contains" - " model response column or baseline model response column" - " `response` to perform bring-your-own-response(BYOR)" - " evaluation. If you would like to perform evaluation using the" - " dataset with the existing model response column or or" - " baseline model response column `response`, please remove" - " `model` parameter in `EvalTask.evaluate` function or keep" - " `baseline_model` as None in pairwise metric." + "The `model` parameter or `baseline_model` in pairwise metric is" + " specified, but the evaluation `dataset` contains model response" + " column or baseline model response column `response`" + " to perform bring-your-own-response(BYOR) evaluation. If you would" + " like to perform evaluation using the dataset with the" + " existing model response column or or baseline model response column" + " `response`, please remove `model` parameter in `EvalTask.evaluate()`" + " function or `baseline_model` in `PairwiseMetric`." ) ), ): @@ -1202,16 +1201,15 @@ def test_evaluate_baseline_response_column_and_baseline_model_provided(self): ValueError, match=re.escape( ( - "The `model` parameter or `baseline_model` in pairwise" - " metric is specified, but the evaluation `dataset` contains" - " model response column or baseline model response column" - " `baseline_model_response` to perform" - " bring-your-own-response(BYOR) evaluation. If you would" + "The `model` parameter or `baseline_model` in pairwise metric is" + " specified, but the evaluation `dataset` contains model response" + " column or baseline model response column `baseline_model_response`" + " to perform bring-your-own-response(BYOR) evaluation. If you would" " like to perform evaluation using the dataset with the" - " existing model response column or or baseline model" - " response column `baseline_model_response`, please remove" - " `model` parameter in `EvalTask.evaluate` function or keep" - " `baseline_model` as None in pairwise metric." + " existing model response column or or baseline model response column" + " `baseline_model_response`, please remove `model` parameter in" + " `EvalTask.evaluate()` function or `baseline_model` in" + " `PairwiseMetric`." ) ), ): @@ -1227,8 +1225,8 @@ def test_evaluate_response_column_and_model_not_provided(self): KeyError, match=re.escape( ( - "Required column `response` not found in the eval dataset." - " The columns in the provided dataset are ['prompt']." + "Required column `response` not found in the evaluation dataset." + " The columns in the evaluation dataset are ['prompt']." ) ), ): @@ -1248,7 +1246,7 @@ def test_evaluate_baseline_response_column_and_baseline_model_not_provided( match=re.escape( ( "Required column `baseline_model_response` not found in the" - " eval dataset. The columns in the provided dataset are" + " evaluation dataset. The columns in the evaluation dataset are" " ['prompt', 'response']." ) ), @@ -1301,7 +1299,7 @@ def test_evaluate_pairwise_metrics_with_multiple_baseline_models(self): ) with pytest.raises( ValueError, - match="Not all PairwiseMetric instances have the same baseline_model", + match="Not all `PairwiseMetric` instances have the same `baseline_model`", ): test_eval_task.evaluate() @@ -1314,15 +1312,14 @@ def test_evaluate_invalid_model_and_dataset_input(self): ValueError, match=re.escape( ( - "The `model` parameter or `baseline_model` in pairwise metric" - " is specified, but the evaluation `dataset` contains model" - " response column or baseline model response column `response`" - " to perform bring-your-own-response(BYOR) evaluation. If you" - " would like to perform evaluation using the dataset with the" - " existing model response column or or baseline model response" - " column `response`, please remove `model` parameter in" - " `EvalTask.evaluate` function or keep `baseline_model` as None" - " in pairwise metric." + "The `model` parameter or `baseline_model` in pairwise metric is" + " specified, but the evaluation `dataset` contains model response" + " column or baseline model response column `response`" + " to perform bring-your-own-response(BYOR) evaluation. If you would" + " like to perform evaluation using the dataset with the" + " existing model response column or or baseline model response column" + " `response`, please remove `model` parameter in `EvalTask.evaluate()`" + " function or `baseline_model` in `PairwiseMetric`." ) ), ): @@ -1339,15 +1336,14 @@ def test_unmatched_metric_column_mapping(self): ValueError, match=re.escape( ( - "The `model` parameter or `baseline_model` in pairwise metric" - " is specified, but the evaluation `dataset` contains model" - " response column or baseline model response column `response`" - " to perform bring-your-own-response(BYOR) evaluation. If you" - " would like to perform evaluation using the dataset with the" - " existing model response column or or baseline model response" - " column `response`, please remove `model` parameter in" - " `EvalTask.evaluate` function or keep `baseline_model` as None" - " in pairwise metric." + "The `model` parameter or `baseline_model` in pairwise metric is" + " specified, but the evaluation `dataset` contains model response" + " column or baseline model response column `response`" + " to perform bring-your-own-response(BYOR) evaluation. If you would" + " like to perform evaluation using the dataset with the" + " existing model response column or or baseline model response column" + " `response`, please remove `model` parameter in `EvalTask.evaluate()`" + " function or `baseline_model` in `PairwiseMetric`." ) ), ): diff --git a/tests/unit/vertexai/test_tokenization.py b/tests/unit/vertexai/test_tokenization.py index 861c225211..f51be60360 100644 --- a/tests/unit/vertexai/test_tokenization.py +++ b/tests/unit/vertexai/test_tokenization.py @@ -20,11 +20,19 @@ import tempfile from typing import List from unittest import mock -from vertexai.generative_models import Content, Image, Part +from vertexai.generative_models import ( + Content, + Image, + Part, + FunctionDeclaration, + Tool, +) + from vertexai.tokenization import _tokenizer_loading from vertexai.tokenization._tokenizers import ( CountTokensResult, TokensInfo, + _TextsAccumulator, get_tokenizer_for_model, ) import pytest @@ -32,7 +40,11 @@ from sentencepiece import sentencepiece_model_pb2, sentencepiece_pb2 from google.cloud.aiplatform_v1beta1.types import ( content as gapic_content_types, + tool as gapic_tool_types, + openapi, ) +from google.protobuf import struct_pb2 + _TOKENIZER_NAME = "google/gemma" _MODEL_NAME = "gemini-1.5-pro" @@ -170,14 +182,8 @@ _LIST_OF_UNSUPPORTED_CONTENTS = [ - gapic_content_types.Part( - video_metadata=gapic_content_types.VideoMetadata(start_offset="10s") - ), Part.from_uri("gs://bucket/object", mime_type="mime_type"), Part.from_data(b"inline_data_bytes", mime_type="mime_type"), - Part.from_dict({"function_call": {"name": "test_function_call"}}), - Part.from_dict({"function_response": {"name": "test_function_response"}}), - Part.from_dict({"video_metadata": {"start_offset": "10s"}}), Content( role="user", parts=[Part.from_uri("gs://bucket/object", mime_type="mime_type")], @@ -186,20 +192,57 @@ role="user", parts=[Part.from_data(b"inline_data_bytes", mime_type="mime_type")], ), - Content( - role="user", - parts=[Part.from_dict({"function_call": {"name": "test_function_call"}})], - ), - Content( - role="user", - parts=[ - Part.from_dict({"function_response": {"name": "test_function_response"}}) - ], - ), - Content( - role="user", - parts=[Part.from_dict({"video_metadata": {"start_offset": "10s"}})], - ), +] + +_NESTED_STRUCT_1 = struct_pb2.Struct( + fields={"string_key": struct_pb2.Value(string_value="value1")} +) +_NESTED_STRUCT_2 = struct_pb2.Struct( + fields={ + "list_key": struct_pb2.Value( + list_value=struct_pb2.ListValue( + values=[struct_pb2.Value(string_value="value2")] + ) + ) + } +) +_NESTED_STRUCT_3 = struct_pb2.Struct( + fields={ + "struct_key": struct_pb2.Value( + struct_value=struct_pb2.Struct( + fields={"string_key": struct_pb2.Value(string_value="value3")} + ) + ) + } +) +_STRUCT = struct_pb2.Struct( + fields={ + "string_key": struct_pb2.Value(string_value="value4"), + "list_key": struct_pb2.Value( + list_value=struct_pb2.ListValue( + values=[struct_pb2.Value(string_value="value5")] + ) + ), + "struct_key1": struct_pb2.Value(struct_value=_NESTED_STRUCT_1), + "struct_key2": struct_pb2.Value(struct_value=_NESTED_STRUCT_2), + "struct_key3": struct_pb2.Value(struct_value=_NESTED_STRUCT_3), + } +) +_STRUCT_TEXTS = [ + "struct_key3", + "struct_key1", + "list_key", + "string_key", + "struct_key2", + "struct_key", + "string_key", + "value3", + "string_key", + "value1", + "value5", + "value4", + "list_key", + "value2", ] @@ -218,7 +261,15 @@ def mock_sp_processor(): def _encode_as_ids(contents: List[str]): - return [_TOKENS_MAP[content]["ids"] for content in contents] + return [ + ( + _TOKENS_MAP[content]["ids"] + if content in _TOKENS_MAP + # Returns stable ids arrary when content is not predefined. + else [0] * len(content.split(" ")) + ) + for content in contents + ] def _build_sentencepiece_text(content: str): @@ -254,6 +305,21 @@ def mock_hashlib_sha256(): yield sha256_mock +def get_current_weather(location: str, unit: str = "centigrade"): + """Gets weather in the specified location. + Args: + location: The location for which to get the weather. + unit: Optional. Temperature unit. Can be Centigrade or Fahrenheit. Defaults to Centigrade. + Returns: + The weather information as a dict. + """ + return dict( + location=location, + unit=unit, + weather="Super nice, but maybe a bit hot.", + ) + + @pytest.mark.usefixtures("mock_requests_get", "mock_hashlib_sha256") class TestTokenizers: """Unit tests for the tokenizers.""" @@ -269,7 +335,7 @@ def test_valid_contents_type_for_bytes_token_type(self, mock_sp_processor): _tokenizer_loading.get_sentencepiece.cache_clear() assert get_tokenizer_for_model(_MODEL_NAME).compute_tokens( [_SENTENCE_4] - ).token_info_list == ( + ).tokens_info == ( [TokensInfo(token_ids=[0, 1], tokens=[b"A", b"B"], role="user")] ) assert get_tokenizer_for_model(_MODEL_NAME).count_tokens( @@ -304,10 +370,9 @@ def testcompute_tokens_valid_contents_type( self, mock_sp_processor, contents, encode_input, encode_output, roles ): _tokenizer_loading.get_sentencepiece.cache_clear() - assert ( get_tokenizer_for_model(_MODEL_NAME).compute_tokens(contents) - ).token_info_list == ( + ).tokens_info == ( [ TokensInfo(token_ids=output["ids"], tokens=output["tokens"], role=role) for role, output in zip(roles, encode_output) @@ -334,6 +399,54 @@ def test_count_tokens_unsupported_contents_type( get_tokenizer_for_model(_MODEL_NAME).count_tokens(contents) e.match("Tokenizers do not support non-text content types.") + def test_system_instruction_count_tokens(self, mock_sp_processor): + _tokenizer_loading.get_sentencepiece.cache_clear() + tokenizer = get_tokenizer_for_model(_MODEL_NAME) + result = tokenizer.count_tokens( + ["hello world"], system_instruction=["You are a chatbot."] + ) + assert result.total_tokens == 6 + + def test_function_call_count_tokens(self, mock_sp_processor): + tokenizer = get_tokenizer_for_model(_MODEL_NAME) + part = Part._from_gapic( + gapic_content_types.Part( + function_call=gapic_tool_types.FunctionCall( + name="test_function_call", + args=_STRUCT, + ), + ) + ) + + result = tokenizer.count_tokens(part) + + assert result.total_tokens + + def test_function_response_count_tokens(self, mock_sp_processor): + tokenizer = get_tokenizer_for_model(_MODEL_NAME) + part = Part._from_gapic( + gapic_content_types.Part( + function_response=gapic_tool_types.FunctionResponse( + name="test_function_response", response=_STRUCT + ), + ) + ) + + result = tokenizer.count_tokens(part) + + assert result.total_tokens + + def test_tools_count_tokens(self, mock_sp_processor): + tokenizer = get_tokenizer_for_model(_MODEL_NAME) + get_current_weather_func = FunctionDeclaration.from_func(get_current_weather) + weather_tool = Tool( + function_declarations=[get_current_weather_func], + ) + + result = tokenizer.count_tokens(contents=[], tools=[weather_tool]) + + assert result.total_tokens + def test_image_mime_types(self, mock_sp_processor): # Importing external library lazily to reduce the scope of import errors. from PIL import Image as PIL_Image # pylint: disable=g-import-not-at-top @@ -417,3 +530,107 @@ def test_load_model_proto_from_corrupted_cache(self, hash_mock, mock_requests_ge mock_requests_get.assert_called_once() with open(cache_path, "rb") as f: assert f.read() == _TOKENIZER_MODEL.SerializeToString() + + +class TestTextsAccumulator: + def setup_method(self): + self.texts_accumulator = _TextsAccumulator() + + def test_function_declaration_unsupported_field(self): + function_declaration = gapic_tool_types.FunctionDeclaration( + parameters=openapi.Schema(nullable=True) + ) + with pytest.raises(ValueError): + self.texts_accumulator.add_tool( + gapic_tool_types.Tool(function_declarations=[function_declaration]) + ) + + def test_function_call_unsupported_field(self): + function_call = gapic_tool_types.FunctionCall( + name="test_function_call", + args=struct_pb2.Struct( + fields={ + "bool_key": struct_pb2.Value(bool_value=True), + } + ), + ) + with pytest.raises(ValueError): + self.texts_accumulator.add_function_call(function_call) + + def test_function_response_unsupported_field(self): + function_call = gapic_tool_types.FunctionResponse( + name="test_function_response", + response=struct_pb2.Struct( + fields={ + "bool_key": struct_pb2.Value(bool_value=True), + } + ), + ) + with pytest.raises(ValueError): + self.texts_accumulator.add_function_response(function_call) + + def test_function_declaration(self): + schema1 = openapi.Schema( + format="schema1_format", description="schema1_description" + ) + schema2 = openapi.Schema( + format="schema2_format", description="schema2_description" + ) + example = struct_pb2.Value(string_value="value1") + function_declaration = gapic_tool_types.FunctionDeclaration( + name="function_declaration_name", + description="function_declaration_description", + parameters=openapi.Schema( + format="schema_format", + description="schema_description", + enum=["schema_enum1", "schema_enum2"], + required=["schema_required1", "schema_required2"], + items=schema1, + properties={"property_key": schema2}, + example=example, + ), + ) + + self.texts_accumulator.add_tool( + gapic_tool_types.Tool(function_declarations=[function_declaration]) + ) + assert self.texts_accumulator.get_texts() == [ + "function_declaration_name", + "function_declaration_description", + "schema_format", + "schema_description", + "schema_enum1", + "schema_enum2", + "schema_required1", + "schema_required2", + "schema1_format", + "schema1_description", + "property_key", + "schema2_format", + "schema2_description", + "value1", + ] + + def test_function_call(self): + function_call = gapic_tool_types.FunctionCall( + name="test_function_call", + args=_STRUCT, + ) + + self.texts_accumulator.add_function_call(function_call) + + assert ( + self.texts_accumulator.get_texts() == ["test_function_call"] + _STRUCT_TEXTS + ) + + def test_function_response(self): + function_response = gapic_tool_types.FunctionResponse( + name="test_function_response", response=_STRUCT + ) + + self.texts_accumulator.add_function_response(function_response) + + assert ( + self.texts_accumulator.get_texts() + == ["test_function_response"] + _STRUCT_TEXTS + ) diff --git a/vertexai/batch_prediction/_batch_prediction.py b/vertexai/batch_prediction/_batch_prediction.py index c7cafc7543..2ca7d7b85c 100644 --- a/vertexai/batch_prediction/_batch_prediction.py +++ b/vertexai/batch_prediction/_batch_prediction.py @@ -33,6 +33,7 @@ _LOGGER = aiplatform_base.Logger(__name__) _GEMINI_MODEL_PATTERN = r"publishers/google/models/gemini" +_LLAMA_MODEL_PATTERN = r"publishers/meta/models/llama" _GEMINI_TUNED_MODEL_PATTERN = r"^projects/[0-9]+?/locations/[0-9a-z-]+?/models/[0-9]+?$" @@ -272,13 +273,20 @@ def _reconcile_model_name(cls, model_name: str) -> str: if "/" not in model_name: # model name (e.g., gemini-1.0-pro) - model_name = "publishers/google/models/" + model_name + if model_name.startswith("gemini"): + model_name = "publishers/google/models/" + model_name + else: + raise ValueError( + "Abbreviated model names are only supported for Gemini models. " + "Please provide the full publisher model name." + ) elif model_name.startswith("models/"): # publisher model name (e.g., models/gemini-1.0-pro) model_name = "publishers/google/" + model_name elif ( # publisher model full name not model_name.startswith("publishers/google/models/") + and not model_name.startswith("publishers/meta/models/") # tuned model full resource name and not re.search(_GEMINI_TUNED_MODEL_PATTERN, model_name) ): @@ -302,6 +310,10 @@ def _is_genai_model(cls, model_name: str) -> bool: # Model is a tuned Gemini model. return True + if re.search(_LLAMA_MODEL_PATTERN, model_name): + # Model is a Llama3 model. + return True + return False @classmethod diff --git a/vertexai/evaluation/_base.py b/vertexai/evaluation/_base.py index 60696f8ed2..863ab0b6a1 100644 --- a/vertexai/evaluation/_base.py +++ b/vertexai/evaluation/_base.py @@ -76,9 +76,10 @@ def validate_dataset_column(self, column_name: str) -> None: not in self.dataset.columns ): raise KeyError( - f"Required column `{self.metric_column_mapping.get(column_name, column_name)}`" - " not found in the eval dataset. The columns in the provided dataset" - f" are {list(self.dataset.columns)}." + "Required column" + f" `{self.metric_column_mapping.get(column_name, column_name)}`" + " not found in the evaluation dataset. The columns in the" + f" evaluation dataset are {list(self.dataset.columns)}." ) diff --git a/vertexai/evaluation/_evaluation.py b/vertexai/evaluation/_evaluation.py index cfce0f3c07..ab878b1822 100644 --- a/vertexai/evaluation/_evaluation.py +++ b/vertexai/evaluation/_evaluation.py @@ -113,12 +113,13 @@ def _validate_metric_column_map( not in evaluation_run_config.dataset.columns ): raise ValueError( - "Cannot find appropriate column name mapping to help" - " fill metric prompt template with evaluation dataset. " - "Please provide key-value pair in `metric_column_mapping`" - " parameter from 'EvalTask.evaluate' function. Key is" - f" '{variable}', value should be found in evaluation" - f" dataset columns" + f"Cannot find the `{variable}` column in the evaluation" + " dataset to fill the metric prompt template for" + f" `{str(metric)}` metric. Please check if the column is" + " present in the evaluation dataset, or provide a" + " key-value pair in `metric_column_mapping` parameter" + " of `EvalTask` to map it to a different column name." + " The evaluation dataset columns are" f" {list(evaluation_run_config.dataset.columns)}." ) @@ -158,10 +159,11 @@ def _compute_custom_metrics( ] = metric_output[custom_metric.name] except KeyError: raise KeyError( - f"Custom metric score `{custom_metric.name}` not found in the metric" - f" output {metric_output}. Please make sure the custom metric" - " function is valid, and the output dictionary uses" - f" `{custom_metric.name}` as the key for metric value." + f"Custom metric score `{custom_metric.name}` not found in" + f" the metric output {metric_output}. Please make sure the" + " custom metric function is valid, and the output" + f" dictionary uses `{custom_metric.name}` as the key for" + " metric score." ) # Include additional metric results like explanation. for key, value in metric_output.items(): @@ -224,8 +226,7 @@ def _aggregate_summary_metrics( ].std() except (ValueError, KeyError) as e: _LOGGER.warning( - f"Failed to compute metric statistics for {metric}. This metric" - " output contains error from the Autorater.\n" + f"Failed to compute metric statistics for `{metric}` metric." f"{type(e).__name__}: {e}" ) continue @@ -414,12 +415,13 @@ def _run_model_inference( ] = response_column_name else: raise ValueError( - "Missing required input column to start model inference." + "Missing required input `prompt` column to start model inference." " Please provide a `prompt_template` parameter in" " `EvalTask.evaluate()` function if you want to assemble a" - " `prompt` with variables from the dataset, or provide a" + " `prompt` column with variables from the dataset, or provide a" " `prompt` column in dataset to directly use as input to" - " the model." + " the model. Mappings in `metric_column_mapping` do not" + " apply for model inference and are used for evaluation only." ) else: if model: @@ -432,8 +434,8 @@ def _run_model_inference( " like to perform evaluation using the dataset with the" " existing model response column or or baseline model response column" f" `{evaluation_run_config.metric_column_mapping[response_column_name]}`," - " please remove `model` parameter in `EvalTask.evaluate`" - " function or keep `baseline_model` as None in pairwise metric." + " please remove `model` parameter in `EvalTask.evaluate()`" + " function or `baseline_model` in `PairwiseMetric`." ) @@ -486,8 +488,8 @@ def _assemble_prompt_for_dataset( raise ValueError("Prompt template cannot be an empty string.") _LOGGER.info( - "Assembling prompts from the prompt_template. The `prompt` column in" - " the EvalResult.metrics_table has the assembled prompts used for model" + "Assembling prompts from the `prompt_template`. The `prompt` column in" + " the `EvalResult.metrics_table` has the assembled prompts used for model" " response generation." ) if isinstance(prompt_template, str): @@ -507,13 +509,28 @@ def _assemble_prompt_for_dataset( ), axis=1, ) + if ( + constants.Dataset.PROMPT_COLUMN + in evaluation_run_config.metric_column_mapping + and evaluation_run_config.metric_column_mapping[ + constants.Dataset.PROMPT_COLUMN + ] + != constants.Dataset.PROMPT_COLUMN + ): + _LOGGER.warning( + "The `prompt` column mapping provided in" + " `metric_column_mapping` parameter is overwritten by the" + " assembled `prompt` column because the `prompt_template`" + " parameter is provided. Please verify that you want to use" + " the assembled `prompt` column for evaluation." + ) evaluation_run_config.metric_column_mapping[ constants.Dataset.PROMPT_COLUMN ] = constants.Dataset.PROMPT_COLUMN except Exception as e: raise ValueError( f"Failed to assemble prompt template: {e}. Please make sure all" - " variables in prompt_template are present in the evaluation" + " variables in `prompt_template` are present in the evaluation" f" dataset columns: `{list(evaluation_run_config.dataset.columns)}`." ) from e @@ -634,8 +651,8 @@ def _compute_metrics( total_request_count = api_request_count + custom_metric_request_count _LOGGER.info( - f"Computing metrics with a total of {total_request_count} Vertex online" - " evaluation service requests." + f"Computing metrics with a total of {total_request_count} Vertex Gen AI" + " Evaluation Service API requests." ) instance_list = [] @@ -718,10 +735,10 @@ def _get_baseline_model(evaluation_run_config: evaluation_base.EvaluationRunConf } if len(set(baseline_models.values())) > 1: raise ValueError( - "Not all PairwiseMetric instances have the same baseline_model " + "Not all `PairwiseMetric` instances have the same `baseline_model`. " f"Here are the detected baseline models: `{baseline_models}`. " "Please separate pairwise metrics with different baseline models " - "in different EvalTask or use the same baseline model for " + "in different `EvalTask` or use the same baseline model for " "all pairwise metrics." ) return pairwise_metric_instances[0].baseline_model @@ -747,10 +764,11 @@ def _convert_metric_prompt_template_example(metrics): ) _LOGGER.info( f"Pairwise metric `{metric.metric_name}` loaded from" - " MetricPromptTemplateExamples does not have baseline_model." - " If you would like to run inference on the baseline model," - " please instantiate a PairwiseMetric and provide the" - " `baseline_model` input." + " `MetricPromptTemplateExamples` does not have `baseline_model`" + " specified and only supports Bring-Your-Own-Response(BYOR)" + " evaluation. If you would like to run inference on the baseline model," + " please instantiate a `PairwiseMetric` and provide the" + " `baseline_model` parameter." ) updated_metrics.append(metric) return updated_metrics @@ -836,7 +854,6 @@ def evaluate( constants.Dataset.REFERENCE_COLUMN ) - # Model inference if prompt_template: _assemble_prompt_for_dataset(evaluation_run_config, prompt_template) @@ -852,22 +869,12 @@ def evaluate( ) ) - # Baseline model inference pairwise_metric_exists = any( isinstance(metric, pairwise_metric.PairwiseMetric) for metric in evaluation_run_config.metrics ) if pairwise_metric_exists: baseline_model = _get_baseline_model(evaluation_run_config) - if baseline_model is None: - _LOGGER.warning( - "Note: pointwise metric with metric prompt template example and " - "pairwise metric with metric prompt template example do not " - "have baseline models. If evaluation run fails by " - "`baseline_model_response` column validation, Please sepecify" - " `baseline_model_response_column_name` in `EvalTask.evaluate` to" - " perform bring-your-own-response(BYOR) evaluation." - ) _run_model_inference( model=baseline_model, evaluation_run_config=evaluation_run_config, diff --git a/vertexai/evaluation/eval_task.py b/vertexai/evaluation/eval_task.py index 973c3d3483..3ff24e5f22 100644 --- a/vertexai/evaluation/eval_task.py +++ b/vertexai/evaluation/eval_task.py @@ -396,7 +396,7 @@ def evaluate( global_experiment_name = metadata._experiment_tracker.experiment_name if experiment_run_name and not self._experiment and not global_experiment_name: raise ValueError( - "Experiment is not set. Please initialize EvalTask with an" + "Experiment is not set. Please initialize `EvalTask` with an" " experiment, or initialize a global experiment with " "`vertexai.init(experiment='experiment_name')`for logging this" " evaluation run." diff --git a/vertexai/evaluation/metrics/metric_prompt_template.py b/vertexai/evaluation/metrics/metric_prompt_template.py index bae45d2a76..75256d3ac3 100644 --- a/vertexai/evaluation/metrics/metric_prompt_template.py +++ b/vertexai/evaluation/metrics/metric_prompt_template.py @@ -51,13 +51,7 @@ def __init__( ): """Initializes a metric prompt template.""" - if not input_variables: - _LOGGER.info( - "input_variables is empty. The evaluation dataset will not be used" - " for computing this model-based metric. Only the responses are used." - ) self._input_variables = input_variables - self._instruction = instruction self._metric_definition = metric_definition self._criteria = criteria @@ -117,6 +111,10 @@ def __init__( """ if not input_variables: input_variables = [] + _LOGGER.info( + "The `input_variables` parameter is empty. Only the `response`" + " column is used for computing this model-based metric." + ) input_variables = list(set(input_variables + ["response"])) instruction = instruction or self.get_default_pointwise_instruction() @@ -273,6 +271,11 @@ def __init__( """ if not input_variables: input_variables = [] + _LOGGER.info( + "The `input_variables` parameter is empty. Only the `response`" + " and `baseline_model_response` columns are used for computing" + " this model-based metric." + ) input_variables = list( set(input_variables + ["response", "baseline_model_response"]) ) diff --git a/vertexai/evaluation/metrics/metric_prompt_template_examples.py b/vertexai/evaluation/metrics/metric_prompt_template_examples.py index 659453e9a8..d067bf0824 100644 --- a/vertexai/evaluation/metrics/metric_prompt_template_examples.py +++ b/vertexai/evaluation/metrics/metric_prompt_template_examples.py @@ -24,74 +24,71 @@ ) from vertexai.evaluation.metrics import pairwise_metric from vertexai.evaluation.metrics import pointwise_metric -import immutabledict class MetricPromptTemplateExamples: """Examples of metric prompt templates for model-based evaluation.""" - _PROMPT_TEMPLATE_MAP = immutabledict.immutabledict( - { - constants.Metric.COHERENCE: _default_templates.COHERENCE_PROMPT_TEMPLATE, - constants.Metric.FLUENCY: _default_templates.FLUENCY_PROMPT_TEMPLATE, - constants.Metric.SAFETY: _default_templates.SAFETY_PROMPT_TEMPLATE, - constants.Metric.GROUNDEDNESS: ( - _default_templates.GROUNDEDNESS_PROMPT_TEMPLATE - ), - constants.Metric.INSTRUCTION_FOLLOWING: ( - _default_templates.INSTRUCTION_FOLLOWING_PROMPT_TEMPLATE - ), - constants.Metric.VERBOSITY: _default_templates.VERBOSITY_PROMPT_TEMPLATE, - constants.Metric.TEXT_QUALITY: ( - _default_templates.TEXT_QUALITY_PROMPT_TEMPLATE - ), - constants.Metric.SUMMARIZATION_QUALITY: ( - _default_templates.SUMMARIZATION_QUALITY_PROMPT_TEMPLATE - ), - constants.Metric.QUESTION_ANSWERING_QUALITY: ( - _default_templates.QUESTION_ANSWERING_QUALITY_PROMPT_TEMPLATE - ), - constants.Metric.MULTI_TURN_CHAT_QUALITY: ( - _default_templates.MULTI_TURN_CHAT_QUALITY_PROMPT_TEMPLATE - ), - constants.Metric.MULTI_TURN_SAFETY: ( - _default_templates.MULTI_TURN_SAFETY_PROMPT_TEMPLATE - ), - constants.Metric.PAIRWISE_COHERENCE: ( - _default_templates.PAIRWISE_COHERENCE_PROMPT_TEMPLATE - ), - constants.Metric.PAIRWISE_FLUENCY: ( - _default_templates.PAIRWISE_FLUENCY_PROMPT_TEMPLATE - ), - constants.Metric.PAIRWISE_SAFETY: ( - _default_templates.PAIRWISE_SAFETY_PROMPT_TEMPLATE - ), - constants.Metric.PAIRWISE_GROUNDEDNESS: ( - _default_templates.PAIRWISE_GROUNDEDNESS_PROMPT_TEMPLATE - ), - constants.Metric.PAIRWISE_INSTRUCTION_FOLLOWING: ( - _default_templates.PAIRWISE_INSTRUCTION_FOLLOWING_PROMPT_TEMPLATE - ), - constants.Metric.PAIRWISE_VERBOSITY: ( - _default_templates.PAIRWISE_VERBOSITY_PROMPT_TEMPLATE - ), - constants.Metric.PAIRWISE_TEXT_QUALITY: ( - _default_templates.PAIRWISE_TEXT_QUALITY_PROMPT_TEMPLATE - ), - constants.Metric.PAIRWISE_SUMMARIZATION_QUALITY: ( - _default_templates.PAIRWISE_SUMMARIZATION_QUALITY_PROMPT_TEMPLATE - ), - constants.Metric.PAIRWISE_QUESTION_ANSWERING_QUALITY: ( - _default_templates.PAIRWISE_QUESTION_ANSWERING_QUALITY_PROMPT_TEMPLATE - ), - constants.Metric.PAIRWISE_MULTI_TURN_CHAT_QUALITY: ( - _default_templates.PAIRWISE_MULTI_TURN_CHAT_QUALITY_PROMPT_TEMPLATE - ), - constants.Metric.PAIRWISE_MULTI_TURN_SAFETY: ( - _default_templates.PAIRWISE_MULTI_TURN_SAFETY_PROMPT_TEMPLATE - ), - } - ) + _PROMPT_TEMPLATE_MAP = { + constants.Metric.COHERENCE: _default_templates.COHERENCE_PROMPT_TEMPLATE, + constants.Metric.FLUENCY: _default_templates.FLUENCY_PROMPT_TEMPLATE, + constants.Metric.SAFETY: _default_templates.SAFETY_PROMPT_TEMPLATE, + constants.Metric.GROUNDEDNESS: ( + _default_templates.GROUNDEDNESS_PROMPT_TEMPLATE + ), + constants.Metric.INSTRUCTION_FOLLOWING: ( + _default_templates.INSTRUCTION_FOLLOWING_PROMPT_TEMPLATE + ), + constants.Metric.VERBOSITY: _default_templates.VERBOSITY_PROMPT_TEMPLATE, + constants.Metric.TEXT_QUALITY: ( + _default_templates.TEXT_QUALITY_PROMPT_TEMPLATE + ), + constants.Metric.SUMMARIZATION_QUALITY: ( + _default_templates.SUMMARIZATION_QUALITY_PROMPT_TEMPLATE + ), + constants.Metric.QUESTION_ANSWERING_QUALITY: ( + _default_templates.QUESTION_ANSWERING_QUALITY_PROMPT_TEMPLATE + ), + constants.Metric.MULTI_TURN_CHAT_QUALITY: ( + _default_templates.MULTI_TURN_CHAT_QUALITY_PROMPT_TEMPLATE + ), + constants.Metric.MULTI_TURN_SAFETY: ( + _default_templates.MULTI_TURN_SAFETY_PROMPT_TEMPLATE + ), + constants.Metric.PAIRWISE_COHERENCE: ( + _default_templates.PAIRWISE_COHERENCE_PROMPT_TEMPLATE + ), + constants.Metric.PAIRWISE_FLUENCY: ( + _default_templates.PAIRWISE_FLUENCY_PROMPT_TEMPLATE + ), + constants.Metric.PAIRWISE_SAFETY: ( + _default_templates.PAIRWISE_SAFETY_PROMPT_TEMPLATE + ), + constants.Metric.PAIRWISE_GROUNDEDNESS: ( + _default_templates.PAIRWISE_GROUNDEDNESS_PROMPT_TEMPLATE + ), + constants.Metric.PAIRWISE_INSTRUCTION_FOLLOWING: ( + _default_templates.PAIRWISE_INSTRUCTION_FOLLOWING_PROMPT_TEMPLATE + ), + constants.Metric.PAIRWISE_VERBOSITY: ( + _default_templates.PAIRWISE_VERBOSITY_PROMPT_TEMPLATE + ), + constants.Metric.PAIRWISE_TEXT_QUALITY: ( + _default_templates.PAIRWISE_TEXT_QUALITY_PROMPT_TEMPLATE + ), + constants.Metric.PAIRWISE_SUMMARIZATION_QUALITY: ( + _default_templates.PAIRWISE_SUMMARIZATION_QUALITY_PROMPT_TEMPLATE + ), + constants.Metric.PAIRWISE_QUESTION_ANSWERING_QUALITY: ( + _default_templates.PAIRWISE_QUESTION_ANSWERING_QUALITY_PROMPT_TEMPLATE + ), + constants.Metric.PAIRWISE_MULTI_TURN_CHAT_QUALITY: ( + _default_templates.PAIRWISE_MULTI_TURN_CHAT_QUALITY_PROMPT_TEMPLATE + ), + constants.Metric.PAIRWISE_MULTI_TURN_SAFETY: ( + _default_templates.PAIRWISE_MULTI_TURN_SAFETY_PROMPT_TEMPLATE + ), + } @classmethod def get_prompt_template(cls, metric_name: str) -> str: diff --git a/vertexai/evaluation/utils.py b/vertexai/evaluation/utils.py index 5b3d9b4498..a7240eed59 100644 --- a/vertexai/evaluation/utils.py +++ b/vertexai/evaluation/utils.py @@ -31,7 +31,6 @@ from google.cloud.aiplatform_v1.services import ( evaluation_service as gapic_evaluation_services, ) -from vertexai.evaluation import constants if TYPE_CHECKING: @@ -180,13 +179,13 @@ def load_dataset( else: raise ValueError( f"Unsupported file type: {file_type} from {source}. Please" - " provide valid GCS path with jsonl or csv suffix or valid" + " provide a valid GCS path with `jsonl` or `csv` suffix or a valid" " BigQuery table URI." ) else: raise TypeError( - "Unsupported dataset type. Must be DataFrame, dictionary, or" - " valid GCS path with jsonl or csv suffix or BigQuery table URI." + "Unsupported dataset type. Must be a `pd.DataFrame`, Python dictionary," + " valid GCS path with `jsonl` or `csv` suffix or a valid BigQuery table URI." ) @@ -254,25 +253,27 @@ def _read_gcs_file_contents(filepath: str) -> str: def upload_evaluation_results( dataset: "pd.DataFrame", destination_uri_prefix: str, file_name: str ): - """Uploads eval results to GCS csv destination.""" + """Uploads eval results to GCS CSV destination.""" + supported_file_types = ["csv"] if not destination_uri_prefix: return if destination_uri_prefix.startswith(_GCS_PREFIX): _, extension = os.path.splitext(file_name) file_type = extension.lower()[1:] - if file_type in ["csv"]: + if file_type in supported_file_types: output_path = destination_uri_prefix + "/" + file_name utils.gcs_utils._upload_pandas_df_to_gcs(dataset, output_path) else: raise ValueError( - "Unsupported file type of GCS destination uri:" - f" {file_name}, please provide valid GCS" - " path with csv suffix." + "Unsupported file type in the GCS destination URI:" + f" {file_name}, please provide a valid GCS" + f" file name with a file type in {supported_file_types}." ) else: raise ValueError( - f"Unsupported destination uri: {destination_uri_prefix}." - " Please provide valid GCS bucket path." + f"Unsupported destination URI: {destination_uri_prefix}." + f" Please provide a valid GCS bucket URI prefix starting with" + f" {_GCS_PREFIX}." ) @@ -285,23 +286,12 @@ def initialize_metric_column_mapping( initialized_metric_column_mapping[column] = column if metric_column_mapping: for key, value in metric_column_mapping.items(): - if ( - key == constants.Dataset.PROMPT_COLUMN - and value != constants.Dataset.PROMPT_COLUMN - ): - _LOGGER.warning( - f"`{key}:{value}` will be overwritten to" - f" `prompt:prompt`. Please do not set `prompt` as" - " key in metric_column_mapping." - ) - continue if key in initialized_metric_column_mapping: _LOGGER.warning( - f"`{key}:{key}` is already in metric_column_mapping. Cannot" - f" override it with `{key}:{value}` because `{key}` is an" - " evaluation dataset column. Metric_column_mapping cannot override" - " keys that are already in evaluation dataset" - " column.metric_column_mapping: {metric_column_mapping}." + f"Cannot override `{key}` column with `{key}:{value}` mapping" + f" because `{key}` column is present in the evaluation" + " dataset. `metric_column_mapping` cannot override keys" + " that are already in evaluation dataset columns." ) else: initialized_metric_column_mapping[key] = value diff --git a/vertexai/generative_models/_generative_models.py b/vertexai/generative_models/_generative_models.py index f6db0d115a..80b21b9591 100644 --- a/vertexai/generative_models/_generative_models.py +++ b/vertexai/generative_models/_generative_models.py @@ -29,9 +29,11 @@ Dict, Iterable, List, + Literal, Optional, Sequence, Union, + overload, TYPE_CHECKING, ) @@ -514,6 +516,32 @@ def _parse_response( ) -> "GenerationResponse": return GenerationResponse._from_gapic(response) + @overload + def generate_content( + self, + contents: ContentsType, + *, + generation_config: Optional[GenerationConfigType] = None, + safety_settings: Optional[SafetySettingsType] = None, + tools: Optional[List["Tool"]] = None, + tool_config: Optional["ToolConfig"] = None, + stream: Literal[False] = False, + ) -> "GenerationResponse": + ... + + @overload + def generate_content( + self, + contents: ContentsType, + *, + generation_config: Optional[GenerationConfigType] = None, + safety_settings: Optional[SafetySettingsType] = None, + tools: Optional[List["Tool"]] = None, + tool_config: Optional["ToolConfig"] = None, + stream: Literal[True], + ) -> Iterable["GenerationResponse"]: + ... + def generate_content( self, contents: ContentsType, @@ -523,7 +551,7 @@ def generate_content( tools: Optional[List["Tool"]] = None, tool_config: Optional["ToolConfig"] = None, stream: bool = False, - ) -> Union["GenerationResponse", Iterable["GenerationResponse"],]: + ) -> Union["GenerationResponse", Iterable["GenerationResponse"]]: """Generates content. Args: @@ -562,6 +590,32 @@ def generate_content( tool_config=tool_config, ) + @overload + async def generate_content_async( + self, + contents: ContentsType, + *, + generation_config: Optional[GenerationConfigType] = None, + safety_settings: Optional[SafetySettingsType] = None, + tools: Optional[List["Tool"]] = None, + tool_config: Optional["ToolConfig"] = None, + stream: Literal[False] = False, + ) -> "GenerationResponse": + ... + + @overload + async def generate_content_async( + self, + contents: ContentsType, + *, + generation_config: Optional[GenerationConfigType] = None, + safety_settings: Optional[SafetySettingsType] = None, + tools: Optional[List["Tool"]] = None, + tool_config: Optional["ToolConfig"] = None, + stream: Literal[True] = True, + ) -> AsyncIterable["GenerationResponse"]: + ... + async def generate_content_async( self, contents: ContentsType, @@ -571,7 +625,7 @@ async def generate_content_async( tools: Optional[List["Tool"]] = None, tool_config: Optional["ToolConfig"] = None, stream: bool = False, - ) -> Union["GenerationResponse", AsyncIterable["GenerationResponse"],]: + ) -> Union["GenerationResponse", AsyncIterable["GenerationResponse"]]: """Generates content asynchronously. Args: @@ -770,7 +824,7 @@ async def async_generator(): return async_generator() def count_tokens( - self, contents: ContentsType + self, contents: ContentsType, *, tools: Optional[List["Tool"]] = None ) -> gapic_prediction_service_types.CountTokensResponse: """Counts tokens. @@ -782,22 +836,32 @@ def count_tokens( * str, Image, Part, * List[Union[str, Image, Part]], * List[Content] + tools: A list of tools (functions) that the model can try calling. Returns: A CountTokensResponse object that has the following attributes: total_tokens: The total number of tokens counted across all instances from the request. total_billable_characters: The total number of billable characters counted across all instances from the request. """ + request = self._prepare_request( + contents=contents, + tools=tools, + ) return self._prediction_client.count_tokens( request=gapic_prediction_service_types.CountTokensRequest( endpoint=self._prediction_resource_name, model=self._prediction_resource_name, - contents=self._prepare_request(contents=contents).contents, + contents=request.contents, + system_instruction=request.system_instruction, + tools=request.tools, ) ) async def count_tokens_async( - self, contents: ContentsType + self, + contents: ContentsType, + *, + tools: Optional[List["Tool"]] = None, ) -> gapic_prediction_service_types.CountTokensResponse: """Counts tokens asynchronously. @@ -809,17 +873,24 @@ async def count_tokens_async( * str, Image, Part, * List[Union[str, Image, Part]], * List[Content] + tools: A list of tools (functions) that the model can try calling. Returns: And awaitable for a CountTokensResponse object that has the following attributes: total_tokens: The total number of tokens counted across all instances from the request. total_billable_characters: The total number of billable characters counted across all instances from the request. """ + request = self._prepare_request( + contents=contents, + tools=tools, + ) return await self._prediction_async_client.count_tokens( request=gapic_prediction_service_types.CountTokensRequest( endpoint=self._prediction_resource_name, model=self._prediction_resource_name, - contents=self._prepare_request(contents=contents).contents, + contents=request.contents, + system_instruction=request.system_instruction, + tools=request.tools, ) ) @@ -981,6 +1052,30 @@ def __init__( def history(self) -> List["Content"]: return self._history + @overload + def send_message( + self, + content: PartsType, + *, + generation_config: Optional[GenerationConfigType] = None, + safety_settings: Optional[SafetySettingsType] = None, + tools: Optional[List["Tool"]] = None, + stream: Literal[False] = False, + ) -> "GenerationResponse": + ... + + @overload + def send_message( + self, + content: PartsType, + *, + generation_config: Optional[GenerationConfigType] = None, + safety_settings: Optional[SafetySettingsType] = None, + tools: Optional[List["Tool"]] = None, + stream: Literal[True] = True, + ) -> Iterable["GenerationResponse"]: + ... + def send_message( self, content: PartsType, @@ -1025,6 +1120,30 @@ def send_message( tools=tools, ) + @overload + def send_message_async( + self, + content: PartsType, + *, + generation_config: Optional[GenerationConfigType] = None, + safety_settings: Optional[SafetySettingsType] = None, + tools: Optional[List["Tool"]] = None, + stream: Literal[False] = False, + ) -> Awaitable["GenerationResponse"]: + ... + + @overload + def send_message_async( + self, + content: PartsType, + *, + generation_config: Optional[GenerationConfigType] = None, + safety_settings: Optional[SafetySettingsType] = None, + tools: Optional[List["Tool"]] = None, + stream: Literal[True] = True, + ) -> Awaitable[AsyncIterable["GenerationResponse"]]: + ... + def send_message_async( self, content: PartsType, diff --git a/vertexai/preview/rag/__init__.py b/vertexai/preview/rag/__init__.py index fff380b359..56590fed0e 100644 --- a/vertexai/preview/rag/__init__.py +++ b/vertexai/preview/rag/__init__.py @@ -38,36 +38,38 @@ ) from vertexai.preview.rag.utils.resources import ( EmbeddingModelConfig, - JiraSource, JiraQuery, + JiraSource, RagCorpus, RagFile, RagResource, SlackChannel, SlackChannelsSource, + Weaviate, ) __all__ = ( + "EmbeddingModelConfig", + "JiraQuery", + "JiraSource", + "RagCorpus", + "RagFile", + "RagResource", + "Retrieval", + "SlackChannel", + "SlackChannelsSource", + "VertexRagStore", + "Weaviate", "create_corpus", - "list_corpora", - "get_corpus", "delete_corpus", - "upload_file", + "delete_file", + "get_corpus", + "get_file", "import_files", "import_files_async", - "get_file", + "list_corpora", "list_files", - "delete_file", "retrieval_query", - "EmbeddingModelConfig", - "Retrieval", - "VertexRagStore", - "RagResource", - "RagFile", - "RagCorpus", - "JiraSource", - "JiraQuery", - "SlackChannel", - "SlackChannelsSource", + "upload_file", ) diff --git a/vertexai/preview/rag/rag_data.py b/vertexai/preview/rag/rag_data.py index 32da89c0a3..0983037173 100644 --- a/vertexai/preview/rag/rag_data.py +++ b/vertexai/preview/rag/rag_data.py @@ -48,6 +48,7 @@ RagCorpus, RagFile, SlackChannelsSource, + Weaviate, ) @@ -55,6 +56,7 @@ def create_corpus( display_name: Optional[str] = None, description: Optional[str] = None, embedding_model_config: Optional[EmbeddingModelConfig] = None, + vector_db: Optional[Weaviate] = None, ) -> RagCorpus: """Creates a new RagCorpus resource. @@ -76,6 +78,8 @@ def create_corpus( consist of any UTF-8 characters. description: The description of the RagCorpus. embedding_model_config: The embedding model config. + vector_db: The vector db config of the RagCorpus. If unspecified, the + default database Spanner is used. Returns: RagCorpus. Raises: @@ -88,9 +92,14 @@ def create_corpus( rag_corpus = GapicRagCorpus(display_name=display_name, description=description) if embedding_model_config: - rag_corpus = _gapic_utils.set_embedding_model_config( - embedding_model_config, - rag_corpus, + _gapic_utils.set_embedding_model_config( + embedding_model_config=embedding_model_config, + rag_corpus=rag_corpus, + ) + if vector_db is not None: + _gapic_utils.set_vector_db( + vector_db=vector_db, + rag_corpus=rag_corpus, ) request = CreateRagCorpusRequest( diff --git a/vertexai/preview/rag/rag_retrieval.py b/vertexai/preview/rag/rag_retrieval.py index 519b15a822..1d8bbb5612 100644 --- a/vertexai/preview/rag/rag_retrieval.py +++ b/vertexai/preview/rag/rag_retrieval.py @@ -37,6 +37,7 @@ def retrieval_query( rag_corpora: Optional[List[str]] = None, similarity_top_k: Optional[int] = 10, vector_distance_threshold: Optional[float] = 0.3, + vector_search_alpha: Optional[float] = 0.5, ) -> RetrieveContextsResponse: """Retrieve top k relevant docs/chunks. @@ -54,6 +55,7 @@ def retrieval_query( )], similarity_top_k=2, vector_distance_threshold=0.5, + vector_search_alpha=0.5, ) ``` @@ -67,6 +69,10 @@ def retrieval_query( similarity_top_k: The number of contexts to retrieve. vector_distance_threshold: Optional. Only return contexts with vector distance smaller than the threshold. + vector_search_alpha: Optional. Controls the weight between dense and + sparse vector search results. The range is [0, 1], where 0 means + sparse vector search only and 1 means dense vector search only. + The default value is 0.5. Returns: RetrieveContextsResonse. @@ -111,7 +117,13 @@ def retrieval_query( ) vertex_rag_store.vector_distance_threshold = vector_distance_threshold - query = RagQuery(text=text, similarity_top_k=similarity_top_k) + query = RagQuery( + text=text, + similarity_top_k=similarity_top_k, + ranking=RagQuery.Ranking( + alpha=vector_search_alpha, + ), + ) request = RetrieveContextsRequest( vertex_rag_store=vertex_rag_store, parent=parent, diff --git a/vertexai/preview/rag/utils/_gapic_utils.py b/vertexai/preview/rag/utils/_gapic_utils.py index 6cc1c1d316..640fd8c5f0 100644 --- a/vertexai/preview/rag/utils/_gapic_utils.py +++ b/vertexai/preview/rag/utils/_gapic_utils.py @@ -28,6 +28,7 @@ RagFile as GapicRagFile, SlackSource as GapicSlackSource, JiraSource as GapicJiraSource, + RagVectorDbConfig, ) from google.cloud.aiplatform import initializer from google.cloud.aiplatform.utils import ( @@ -41,6 +42,7 @@ RagFile, SlackChannelsSource, JiraSource, + Weaviate, ) @@ -93,8 +95,22 @@ def convert_gapic_to_embedding_model_config( return embedding_model_config +def convert_gapic_to_vector_db( + gapic_vector_db: RagVectorDbConfig, +) -> Weaviate: + """Convert Gapic RagVectorDbConfig to Weaviate.""" + if gapic_vector_db.__contains__("weaviate"): + return Weaviate( + weaviate_http_endpoint=gapic_vector_db.weaviate.http_endpoint, + collection_name=gapic_vector_db.weaviate.collection_name, + api_key=gapic_vector_db.api_auth.api_key_config.api_key_secret_version, + ) + else: + return None + + def convert_gapic_to_rag_corpus(gapic_rag_corpus: GapicRagCorpus) -> RagCorpus: - """ "Convert GapicRagCorpus to RagCorpus.""" + """Convert GapicRagCorpus to RagCorpus.""" rag_corpus = RagCorpus( name=gapic_rag_corpus.name, display_name=gapic_rag_corpus.display_name, @@ -102,12 +118,13 @@ def convert_gapic_to_rag_corpus(gapic_rag_corpus: GapicRagCorpus) -> RagCorpus: embedding_model_config=convert_gapic_to_embedding_model_config( gapic_rag_corpus.rag_embedding_model_config ), + vector_db=convert_gapic_to_vector_db(gapic_rag_corpus.rag_vector_db_config), ) return rag_corpus def convert_gapic_to_rag_file(gapic_rag_file: GapicRagFile) -> RagFile: - """ "Convert GapicRagFile to RagFile.""" + """Convert GapicRagFile to RagFile.""" rag_file = RagFile( name=gapic_rag_file.name, display_name=gapic_rag_file.display_name, @@ -315,7 +332,7 @@ def get_file_name( def set_embedding_model_config( embedding_model_config: EmbeddingModelConfig, rag_corpus: GapicRagCorpus, -) -> GapicRagCorpus: +) -> None: if embedding_model_config.publisher_model and embedding_model_config.endpoint: raise ValueError("publisher_model and endpoint cannot be set at the same time.") if ( @@ -371,4 +388,27 @@ def set_embedding_model_config( "endpoint must be of the format `projects/{project}/locations/{location}/endpoints/{endpoint}` or `endpoints/{endpoint}`" ) - return rag_corpus + +def set_vector_db( + vector_db: Weaviate, + rag_corpus: GapicRagCorpus, +) -> None: + """Sets the vector db configuration for the rag corpus.""" + if isinstance(vector_db, Weaviate): + http_endpoint = vector_db.weaviate_http_endpoint + collection_name = vector_db.collection_name + api_key = vector_db.api_key + + rag_corpus.rag_vector_db_config = RagVectorDbConfig( + weaviate=RagVectorDbConfig.Weaviate( + http_endpoint=http_endpoint, + collection_name=collection_name, + ), + api_auth=api_auth.ApiAuth( + api_key_config=api_auth.ApiAuth.ApiKeyConfig( + api_key_secret_version=api_key + ), + ), + ) + else: + raise TypeError("vector_db must be a Weaviate.") diff --git a/vertexai/preview/rag/utils/resources.py b/vertexai/preview/rag/utils/resources.py index 1b5af451f6..aad7bad35d 100644 --- a/vertexai/preview/rag/utils/resources.py +++ b/vertexai/preview/rag/utils/resources.py @@ -69,6 +69,22 @@ class EmbeddingModelConfig: model_version_id: Optional[str] = None +@dataclasses.dataclass +class Weaviate: + """Weaviate. + + Attributes: + weaviate_http_endpoint: The Weaviate DB instance HTTP endpoint + collection_name: The corresponding Weaviate collection this corpus maps to + api_key: The SecretManager resource name for the Weaviate DB API token. Format: + ``projects/{project}/secrets/{secret}/versions/{version}`` + """ + + weaviate_http_endpoint: str + collection_name: str + api_key: str + + @dataclasses.dataclass class RagCorpus: """RAG corpus(output only). @@ -78,12 +94,15 @@ class RagCorpus: ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}`` display_name: Display name that was configured at client side. description: The description of the RagCorpus. + embedding_model_config: The embedding model config of the RagCorpus. + vector_db: The Vector DB of the RagCorpus. """ name: Optional[str] = None display_name: Optional[str] = None description: Optional[str] = None embedding_model_config: Optional[EmbeddingModelConfig] = None + vector_db: Optional[Weaviate] = None @dataclasses.dataclass diff --git a/vertexai/preview/tokenization.py b/vertexai/preview/tokenization.py index add6cf67cc..ad7e06203b 100644 --- a/vertexai/preview/tokenization.py +++ b/vertexai/preview/tokenization.py @@ -16,7 +16,7 @@ # We just want to re-export certain classes # pylint: disable=g-multiple-import,g-importing-member from vertexai.tokenization._tokenizers import ( - get_tokenizer_for_model, + _get_tokenizer_for_model_preview as get_tokenizer_for_model, ) diff --git a/vertexai/tokenization/_tokenizers.py b/vertexai/tokenization/_tokenizers.py index 1bad192db6..17e2b5b823 100644 --- a/vertexai/tokenization/_tokenizers.py +++ b/vertexai/tokenization/_tokenizers.py @@ -18,13 +18,18 @@ Iterable, List, Sequence, + Optional, + Union, ) from vertexai.generative_models._generative_models import ( ContentsType, Image, + Tool, + PartsType, _validate_contents_type_as_valid_sequence, _content_types_to_gapic_contents, + _to_content, ) from vertexai.tokenization._tokenizer_loading import ( @@ -35,8 +40,10 @@ from google.cloud.aiplatform_v1beta1.types import ( content as gapic_content_types, tool as gapic_tool_types, + openapi, ) from sentencepiece import sentencepiece_model_pb2 +from google.protobuf import struct_pb2 @dataclasses.dataclass(frozen=True) @@ -48,11 +55,37 @@ class TokensInfo: @dataclasses.dataclass(frozen=True) class ComputeTokensResult: - token_info_list: Sequence[TokensInfo] + """Represents token string pieces and ids output in compute_tokens function. + + Attributes: + tokens_info: Lists of tokens_info from the input. + The input `contents: ContentsType` could have + multiple string instances and each tokens_info + item represents each string instance. Each token + info consists tokens list, token_ids list and + a role. + """ + + tokens_info: Sequence[TokensInfo] + + +class PreviewComputeTokensResult(ComputeTokensResult): + def token_info_list(self) -> Sequence[TokensInfo]: + import warnings + + message = "PreviewComputeTokensResult.token_info_list is deprecated. Use ComputeTokensResult.tokens_info instead." + warnings.warn(message, DeprecationWarning, stacklevel=2) + return self.tokens_info @dataclasses.dataclass(frozen=True) class CountTokensResult: + """Represents an token numbers output in count_tokens function. + + Attributes: + total_tokens: number of total tokens. + """ + total_tokens: int @@ -129,27 +162,20 @@ def compute_tokens( role=role, ) ) - return ComputeTokensResult(token_info_list=token_infos) + return ComputeTokensResult(tokens_info=token_infos) def _to_gapic_contents( contents: ContentsType, -) -> gapic_content_types.Content: +) -> List[gapic_content_types.Content]: """Converts a GenerativeModel compatible contents type to a gapic content.""" _validate_contents_type_as_valid_sequence(contents) _assert_no_image_contents_type(contents) gapic_contents = _content_types_to_gapic_contents(contents) - _assert_text_only_content_types_sequence(gapic_contents) + # _assert_text_only_content_types_sequence(gapic_contents) return gapic_contents -def _content_types_to_string_iterator(contents: ContentsType) -> Iterable[str]: - """Converts a GenerativeModel compatible contents type to a list of strings.""" - gapic_contents = _to_gapic_contents(contents) - for content in gapic_contents: - yield from _to_string_array(content) - - def _content_types_to_role_iterator(contents: ContentsType) -> Iterable[str]: gapic_contents = _to_gapic_contents(contents) # Flattening role by content's multi parts @@ -158,13 +184,6 @@ def _content_types_to_role_iterator(contents: ContentsType) -> Iterable[str]: yield content.role -def _to_string_array(content: gapic_content_types.Content) -> Iterable[str]: - """Converts a gapic content type to a list of strings.""" - if not content: - raise TypeError("content must not be empty.") - return [part.text for part in content.parts] - - def _assert_no_image_contents_type(contents: ContentsType): """Asserts that the contents type does not contain any image content.""" if isinstance(contents, Image) or ( @@ -174,37 +193,12 @@ def _assert_no_image_contents_type(contents: ContentsType): raise ValueError("Tokenizers do not support Image content type.") -def _assert_text_only_content_types_sequence( - contents: List[gapic_content_types.Content], -): - """Asserts that the contents type does not contain any non-text content.""" - for value in contents: - for part in value.parts: - _assert_text_only_gapic_part(part) - - -def _assert_text_only_gapic_part(value: gapic_content_types.Part): - """Asserts that the gapic content part is a text content type.""" - if ( - gapic_content_types.FileData() != value.file_data - or gapic_content_types.Blob() != value.inline_data - or gapic_tool_types.FunctionCall() != value.function_call - or gapic_tool_types.FunctionResponse() != value.function_response - or gapic_content_types.VideoMetadata() != value.video_metadata - ): - raise ValueError("Tokenizers do not support non-text content types.") - - -def _to_canonical_contents_texts(contents: ContentsType) -> Iterable[str]: - """Gets the canonical contents.""" - if isinstance(contents, str): - yield contents - elif isinstance(contents, Sequence) and all( - isinstance(content, str) for content in contents - ): - yield from contents - else: - yield from _content_types_to_string_iterator(contents) +def _is_string_inputs(contents: ContentsType) -> bool: + return ( + isinstance(contents, str) + or isinstance(contents, Sequence) + and all(isinstance(content, str) for content in contents) + ) def _to_canonical_roles(contents: ContentsType) -> Iterable[str]: @@ -218,6 +212,224 @@ def _to_canonical_roles(contents: ContentsType) -> Iterable[str]: yield from _content_types_to_role_iterator(contents) +class _TextsAccumulator: + """Accumulates texts from contents and tools. + + This class is used to accumulate countable texts from contents and tools. + When user passes a unsupported fields that are added in the future, the new + fields might be only counted in remote tokenizer. In this case, the local + tokenizer should know that an unsupported new field exist in the content or + tool instances and raise error to avoid returning incorrect result to users. + + The machanism to detect unsupported fields introduced in the future: when + local tokenizer traversing the input instances, it is allowlist based text + accumulation. When a field is traversed and evaluated to be countable, the + value of this + field is copied to two places: (1) self._texts for inputs to sentencepiece + token count function. (2) a counted instance object in the recursive + function's return value. That's to say, after done current recurssion, + the instance(of same type as the input) only keeps the counted values. + If user sets unsupported future proto fields, they can be detected by + comparing the input instances equal to counted instances or not. + """ + + def __init__(self): + self._texts = [] + + def get_texts(self) -> Iterable[str]: + return self._texts + + def add_texts(self, texts: Union[Iterable[str], str]) -> None: + if isinstance(texts, str): + self._texts.append(texts) + else: + self._texts.extend(texts) + + def add_content(self, content: gapic_content_types.Content) -> None: + counted_content = gapic_content_types.Content() + for part in content.parts: + counted_part = gapic_content_types.Part() + if "file_data" in part or "inline_data" in part: + raise ValueError("Tokenizers do not support non-text content types.") + if "video_metadata" in part: + counted_part.video_metadata = part.video_metadata + if "function_call" in part: + self.add_function_call(part.function_call) + counted_part.function_call = part.function_call + if "function_response" in part: + self.add_function_response(part.function_response) + counted_part.function_response = part.function_response + if "text" in part: + counted_part.text = part.text + self._texts.append(part.text) + counted_content.parts.append(counted_part) + counted_content.role = content.role + if content._pb != counted_content._pb: + raise ValueError( + f"Content contains unsupported types for token counting. Supported fields {counted_content}. Got {content}." + ) + + def add_function_call(self, function_call: gapic_tool_types.FunctionCall) -> None: + """Processes a function call and adds relevant text to the accumulator. + + Args: + function_call: The function call to process. + """ + self._texts.append(function_call.name) + counted_function_call = gapic_tool_types.FunctionCall(name=function_call.name) + counted_struct = self._struct_traverse(function_call._pb.args) + counted_function_call.args = counted_struct + if counted_function_call._pb != function_call._pb: + raise ValueError( + f"Function call argument contains unsupported types for token counting. Supported fields {counted_function_call}. Got {function_call}." + ) + + def add_function_calls( + self, function_calls: Iterable[gapic_tool_types.FunctionCall] + ) -> None: + for function_call in function_calls: + self.add_function_call(function_call) + + def add_tool(self, tool: gapic_tool_types.Tool) -> gapic_tool_types.Tool: + counted_tool = gapic_tool_types.Tool() + for function_declaration in tool.function_declarations: + counted_function_declaration = self._function_declaration_traverse( + function_declaration + ) + counted_tool.function_declarations.append(counted_function_declaration) + if counted_tool._pb != tool._pb: + raise ValueError( + f"Tool argument contains unsupported types for token counting. Supported fields {counted_tool}. Got {tool}." + ) + + def add_tools(self, tools: Iterable[gapic_tool_types.Tool]) -> None: + for tool in tools: + self.add_tool(tool) + + def add_function_responses( + self, function_responses: Iterable[gapic_tool_types.FunctionResponse] + ) -> None: + for function_response in function_responses: + self.add_function_response(function_response) + + def add_function_response( + self, function_response: gapic_tool_types.FunctionResponse + ) -> None: + counted_function_response = gapic_tool_types.FunctionResponse() + self._texts.append(function_response.name) + counted_struct = self._struct_traverse(function_response._pb.response) + counted_function_response.name = function_response.name + counted_function_response.response = counted_struct + if counted_function_response._pb != function_response._pb: + raise ValueError( + f"Function response argument contains unsupported types for token counting. Supported fields {counted_function_response}. Got {function_response}." + ) + + def _function_declaration_traverse( + self, function_declaration: gapic_tool_types.FunctionDeclaration + ) -> gapic_tool_types.FunctionDeclaration: + counted_function_declaration = gapic_tool_types.FunctionDeclaration() + self._texts.append(function_declaration.name) + counted_function_declaration.name = function_declaration.name + if function_declaration.description: + self._texts.append(function_declaration.description) + counted_function_declaration.description = function_declaration.description + if function_declaration.parameters: + counted_parameters = self._schema_traverse(function_declaration.parameters) + counted_function_declaration.parameters = counted_parameters + if function_declaration.response: + counted_response = self._schema_traverse(function_declaration.response) + counted_function_declaration.response = counted_response + return counted_function_declaration + + def _schema_traverse(self, schema: openapi.Schema) -> openapi.Schema: + """Processes a schema and adds relevant text to the accumulator. + + Args: + schema: The schema to process. + + Returns: + The new schema object with only countable fields. + """ + counted_schema = openapi.Schema() + if "type_" in schema: + counted_schema.type = schema.type + if "title" in schema: + counted_schema.title = schema.title + if "default" in schema: + counted_schema.default = schema.default + if "format_" in schema: + self._texts.append(schema.format_) + counted_schema.format_ = schema.format_ + if "description" in schema: + self._texts.append(schema.description) + counted_schema.description = schema.description + if "enum" in schema: + self._texts.extend(schema.enum) + counted_schema.enum = schema.enum + if "required" in schema: + self._texts.extend(schema.required) + counted_schema.required = schema.required + if "items" in schema: + counted_schema_items = self._schema_traverse(schema.items) + counted_schema.items = counted_schema_items + if "properties" in schema: + d = {} + for key, value in schema.properties.items(): + self._texts.append(key) + counted_value = self._schema_traverse(value) + d[key] = counted_value + counted_schema.properties.update(d) + if "example" in schema: + counted_schema_example = self._value_traverse(schema._pb.example) + counted_schema.example = counted_schema_example + return counted_schema + + def _struct_traverse(self, struct: struct_pb2.Struct) -> struct_pb2.Struct: + """Processes a struct and adds relevant text to the accumulator. + + Args: + struct: The struct to process. + + Returns: + The new struct object with only countable fields. + """ + counted_struct = struct_pb2.Struct() + self._texts.extend(list(struct.fields.keys())) + for key, val in struct.fields.items(): + counted_struct_fields = self._value_traverse(val) + if isinstance(counted_struct_fields, str): + counted_struct.fields[key] = counted_struct_fields + else: + counted_struct.fields[key].MergeFrom(counted_struct_fields) + return counted_struct + + def _value_traverse(self, value: struct_pb2.Value) -> struct_pb2.Value: + """Processes a struct field and adds relevant text to the accumulator. + + Args: + struct: The struct field to process. + + Returns: + The new struct field object with only countable fields. + """ + kind = value.WhichOneof("kind") + counted_value = struct_pb2.Value() + if kind == "string_value": + self._texts.append(value.string_value) + counted_value.string_value = value.string_value + elif kind == "struct_value": + counted_struct = self._struct_traverse(value.struct_value) + counted_value.struct_value.MergeFrom(counted_struct) + elif kind == "list_value": + counted_list_value = struct_pb2.ListValue() + for item in value.list_value.values: + counted_value = self._value_traverse(item) + counted_list_value.values.append(counted_value) + counted_value.list_value.MergeFrom(counted_list_value) + return counted_value + + class Tokenizer: """A tokenizer that can parse text into tokens.""" @@ -228,10 +440,17 @@ def __init__(self, tokenizer_name: str): Args: name: The name of the tokenizer. + """ self._sentencepiece_adapter = _SentencePieceAdaptor(tokenizer_name) - def count_tokens(self, contents: ContentsType) -> CountTokensResult: + def count_tokens( + self, + contents: ContentsType, + *, + tools: Optional[List["Tool"]] = None, + system_instruction: Optional[PartsType] = None, + ) -> CountTokensResult: r"""Counts the number of tokens in the text-only contents. Args: @@ -244,16 +463,34 @@ def count_tokens(self, contents: ContentsType) -> CountTokensResult: * List[Union[str, Part]], * List[Content] Throws an error if the contents contain non-text content. - + tools: A list of tools (functions) that the model can try calling. + system_instruction: The provided system instructions for the model. + Note: only text should be used in parts and content in each part + will be in a separate paragraph. Returns: A CountTokensResult object containing the total number of tokens in the contents. """ - return self._sentencepiece_adapter.count_tokens( - _to_canonical_contents_texts(contents) - ) + text_accumulator = _TextsAccumulator() + if _is_string_inputs(contents): + text_accumulator.add_texts(contents) + else: + gapic_contents = _to_gapic_contents(contents) + for content in gapic_contents: + text_accumulator.add_content(content) + + if tools: + text_accumulator.add_tools((tool._raw_tool for tool in tools)) + + if system_instruction: + if _is_string_inputs(system_instruction): + text_accumulator.add_texts(system_instruction) + else: + text_accumulator.add_content(_to_content(system_instruction)) + + return self._sentencepiece_adapter.count_tokens(text_accumulator.get_texts()) def compute_tokens(self, contents: ContentsType) -> ComputeTokensResult: r"""Computes the tokens ids and string pieces in the text-only contents. @@ -272,13 +509,55 @@ def compute_tokens(self, contents: ContentsType) -> ComputeTokensResult: Returns: A ComputeTokensResult object containing the tokens ids and string pieces in the contents. + + Examples: + compute_tokens(["hello world", "what's the weather today"]) + outputs: + ComputeTokensResult(tokens_info=[TokensInfo(token_ids=[17534, 2134], tokens=[b'hello', b' world'], role='user'), TokensInfo(token_ids=[84264, 235341], tokens=[b'Goodbye', b'!'], role='user')], token_info_list=...The same as tokens_info) + """ + text_accumulator = _TextsAccumulator() + if _is_string_inputs(contents): + text_accumulator.add_texts(contents) + else: + gapic_contents = _to_gapic_contents(contents) + for content in gapic_contents: + text_accumulator.add_content(content) + return self._sentencepiece_adapter.compute_tokens( - contents=_to_canonical_contents_texts(contents), + contents=text_accumulator.get_texts(), roles=_to_canonical_roles(contents), ) +class PreviewTokenizer(Tokenizer): + def compute_tokens(self, contents: ContentsType) -> PreviewComputeTokensResult: + return PreviewComputeTokensResult( + tokens_info=super().compute_tokens(contents).tokens_info + ) + + +def _get_tokenizer_for_model_preview(model_name: str) -> PreviewTokenizer: + """Returns a tokenizer for the given tokenizer name. + + Usage: + ``` + tokenizer = get_tokenizer_for_model("gemini-1.5-pro-001") + print(tokenizer.count_tokens("Hello world!")) + ``` + + Supported models can be found at + https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models. + + Args: + model_name: Specify the tokenizer is from which model. + """ + if not model_name: + raise ValueError("model_name must not be empty.") + + return PreviewTokenizer(get_tokenizer_name(model_name)) + + def get_tokenizer_for_model(model_name: str) -> Tokenizer: """Returns a tokenizer for the given tokenizer name. @@ -288,6 +567,9 @@ def get_tokenizer_for_model(model_name: str) -> Tokenizer: print(tokenizer.count_tokens("Hello world!")) ``` + Supported models can be found at + https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models. + Args: model_name: Specify the tokenizer is from which model. """ diff --git a/vertexai/vision_models/_vision_models.py b/vertexai/vision_models/_vision_models.py index f9de7ef3ff..a80f0b2dc1 100644 --- a/vertexai/vision_models/_vision_models.py +++ b/vertexai/vision_models/_vision_models.py @@ -792,12 +792,15 @@ def upscale_image( self, image: Union["Image", "GeneratedImage"], new_size: Optional[int] = 2048, + upscale_factor: Optional[Literal["x2", "x4"]] = None, + output_mime_type: Optional[Literal["image/png", "image/jpeg"]] = "image/png", + output_compression_quality: Optional[int] = None, output_gcs_uri: Optional[str] = None, ) -> "Image": """Upscales an image. - This supports upscaling images generated through the `generate_images()` method, - or upscaling a new image that is 1024x1024. + This supports upscaling images generated through the `generate_images()` + method, or upscaling a new image. Examples:: @@ -812,27 +815,63 @@ def upscale_image( my_image = Image.load_from_file("my-image.png") model.upscale_image(image=my_image) + # Upscale a new arbitrary sized image using a x2 or x4 upscaling factor + my_image = Image.load_from_file("my-image.png") + model.upscale_image(image=my_image, upscale_factor="x2") + + # Upscale an image and get the result in JPEG format + my_image = Image.load_from_file("my-image.png") + model.upscale_image(image=my_image, output_mime_type="image/jpeg", + output_compression_quality=90) + Args: - image (Union[GeneratedImage, Image]): - Required. The generated image to upscale. - new_size (int): - The size of the biggest dimension of the upscaled image. Only 2048 and 4096 are currently - supported. Results in a 2048x2048 or 4096x4096 image. Defaults to 2048 if not provided. - output_gcs_uri: Google Cloud Storage uri to store the upscaled images. + image (Union[GeneratedImage, Image]): Required. The generated image + to upscale. + new_size (int): The size of the biggest dimension of the upscaled + image. + Only 2048 and 4096 are currently supported. Results in a + 2048x2048 or 4096x4096 image. Defaults to 2048 if not provided. + upscale_factor: The upscaling factor. Supported values are "x2" and + "x4". Defaults to None. + output_mime_type: The mime type of the output image. Supported values + are "image/png" and "image/jpeg". Defaults to "image/png". + output_compression_quality: The compression quality of the output + image + as an int (0-100). Only applicable if the output mime type is + "image/jpeg". Defaults to None. + output_gcs_uri: Google Cloud Storage uri to store the upscaled + images. Returns: An `Image` object. """ + target_image_size = new_size if new_size else None + longest_dim = max(image._size[0], image._size[1]) - # Currently this method only supports 1024x1024 images - if image._size[0] != 1024 and image._size[1] != 1024: - raise ValueError( - "Upscaling is currently only supported on images that are 1024x1024." - ) + if not new_size and not upscale_factor: + raise ValueError("Either new_size or upscale_factor must be provided.") + + if not upscale_factor: + x2_factor = 2.0 + x4_factor = 4.0 + epsilon = 0.1 + is_upscaling_x2_request = abs(new_size / longest_dim - x2_factor) < epsilon + is_upscaling_x4_request = abs(new_size / longest_dim - x4_factor) < epsilon + if not is_upscaling_x2_request and not is_upscaling_x4_request: + raise ValueError( + "Only x2 and x4 upscaling are currently supported. Requested" + f" upscaling factor: {new_size / longest_dim}" + ) + else: + if upscale_factor == "x2": + target_image_size = longest_dim * 2 + else: + target_image_size = longest_dim * 4 if new_size not in _SUPPORTED_UPSCALING_SIZES: raise ValueError( - f"Only the folowing square upscaling sizes are currently supported: {_SUPPORTED_UPSCALING_SIZES}." + "Only the folowing square upscaling sizes are currently supported:" + f" {_SUPPORTED_UPSCALING_SIZES}." ) instance = {"prompt": ""} @@ -847,14 +886,25 @@ def upscale_image( } parameters = { - "sampleImageSize": str(new_size), "sampleCount": 1, "mode": "upscale", } + if upscale_factor: + parameters["upscaleConfig"] = {"upscaleFactor": upscale_factor} + + else: + parameters["sampleImageSize"] = str(new_size) + if output_gcs_uri is not None: parameters["storageUri"] = output_gcs_uri + parameters["outputOptions"] = {"mimeType": output_mime_type} + if output_mime_type == "image/jpeg" and output_compression_quality is not None: + parameters["outputOptions"][ + "compressionQuality" + ] = output_compression_quality + response = self._endpoint.predict( instances=[instance], parameters=parameters, @@ -868,7 +918,7 @@ def upscale_image( else: generation_parameters = {} - generation_parameters["upscaled_image_size"] = new_size + generation_parameters["upscaled_image_size"] = target_image_size encoded_bytes = upscaled_image.get("bytesBase64Encoded") return GeneratedImage(