Skip to content

Commit

Permalink
feat: GenAI - Tuning - Added support for tuned model rebasing. Added …
Browse files Browse the repository at this point in the history
…`rebase_tuned_model` to `vertexai.preview.tuning.sft`.

PiperOrigin-RevId: 688795387
  • Loading branch information
Ark-kun authored and copybara-github committed Oct 23, 2024
1 parent da76253 commit 2cef97f
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 0 deletions.
4 changes: 4 additions & 0 deletions vertexai/preview/tuning/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,12 @@
train,
SupervisedTuningJob,
)
from vertexai.tuning._tuning import (
rebase_tuned_model,
)

__all__ = [
"rebase_tuned_model",
"train",
"SupervisedTuningJob",
]
82 changes: 82 additions & 0 deletions vertexai/tuning/_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,85 @@ def _dashboard_url(self) -> str:
job = list(fields.values())[0]
url = f"https://console.cloud.google.com/vertex-ai/generative/language/locations/{location}/tuning/tuningJob/{job}?project={project}"
return url


def rebase_tuned_model(
tuned_model_ref: str,
*,
# TODO(b/372291558): Add support for overriding tuning job config
# tuning_job_config: Optional["TuningJob"] = None,
artifact_destination: Optional[str] = None,
deploy_to_same_endpoint: Optional[bool] = False,
):
"""Re-runs fine tuning on top of a new foundational model.
Takes a legacy Tuned GenAI model Reference and creates a TuningJob based
on a new model.
Args:
tuned_model_ref: Required. TunedModel reference to retrieve
the legacy model information.
tuning_job_config: The TuningJob to be updated. Users
can use this TuningJob field to overwrite tuning
configs.
artifact_destination: The Google Cloud Storage location to write the artifacts.
deploy_to_same_endpoint:
Optional. By default, bison to gemini
migration will always create new model/endpoint,
but for gemini-1.0 to gemini-1.5 migration, we
default deploy to the same endpoint. See details
in this Section.
Returns:
The new TuningJob.
"""
parent = aiplatform_initializer.global_config.common_location_path(
project=aiplatform_initializer.global_config.project,
location=aiplatform_initializer.global_config.location,
)

if "/tuningJobs/" in tuned_model_ref:
gapic_tuned_model_ref = gca_types.TunedModelRef(
tuning_job=tuned_model_ref,
)
elif "/pipelineJobs/" in tuned_model_ref:
gapic_tuned_model_ref = gca_types.TunedModelRef(
pipeline_job=tuned_model_ref,
)
elif "/models/" in tuned_model_ref:
gapic_tuned_model_ref = gca_types.TunedModelRef(
tuned_model=tuned_model_ref,
)
else:
raise ValueError(f"Unsupported tuned_model_ref: {tuned_model_ref}.")

# gapic_tuning_job_config = tuning_job._gca_resource if tuning_job else None
gapic_tuning_job_config = None

gapic_artifact_destination = (
gca_types.GcsDestination(output_uri_prefix=artifact_destination)
if artifact_destination
else None
)

api_client: gen_ai_tuning_service_v1beta1.GenAiTuningServiceClient = (
TuningJob._instantiate_client(
location=aiplatform_initializer.global_config.location,
credentials=aiplatform_initializer.global_config.credentials,
)
)
rebase_operation = api_client.rebase_tuned_model(
gca_types.RebaseTunedModelRequest(
parent=parent,
tuned_model_ref=gapic_tuned_model_ref,
tuning_job=gapic_tuning_job_config,
artifact_destination=gapic_artifact_destination,
deploy_to_same_endpoint=deploy_to_same_endpoint,
)
)
_LOGGER.log_create_with_lro(TuningJob, lro=rebase_operation)
gapic_rebase_tuning_job: gca_types.TuningJob = rebase_operation.result()
rebase_tuning_job = TuningJob._construct_sdk_resource_from_gapic(
gapic_resource=gapic_rebase_tuning_job,
)
return rebase_tuning_job

0 comments on commit 2cef97f

Please sign in to comment.