Skip to content

Commit

Permalink
feat: Add serializer.register_custom_command()
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 570206347
  • Loading branch information
matthew29tang authored and copybara-github committed Oct 2, 2023
1 parent a36daa7 commit 639cf10
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 11 deletions.
24 changes: 16 additions & 8 deletions tests/unit/vertexai/test_remote_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,25 +492,29 @@ def mock_any_serializer_serialize_sklearn():
{
serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [
f"scikit-learn=={sklearn.__version__}"
]
],
serializers_base.SERIALIZATION_METADATA_CUSTOM_COMMANDS_KEY: [],
},
{
serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [
f"numpy=={np.__version__}",
f"cloudpickle=={cloudpickle.__version__}",
]
],
serializers_base.SERIALIZATION_METADATA_CUSTOM_COMMANDS_KEY: [],
},
{
serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [
f"numpy=={np.__version__}",
f"cloudpickle=={cloudpickle.__version__}",
]
],
serializers_base.SERIALIZATION_METADATA_CUSTOM_COMMANDS_KEY: [],
},
{
serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [
f"numpy=={np.__version__}",
f"cloudpickle=={cloudpickle.__version__}",
]
],
serializers_base.SERIALIZATION_METADATA_CUSTOM_COMMANDS_KEY: [],
},
],
) as mock_any_serializer_serialize:
Expand Down Expand Up @@ -575,25 +579,29 @@ def mock_any_serializer_serialize_keras():
{
serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [
f"tensorflow=={tf.__version__}"
]
],
serializers_base.SERIALIZATION_METADATA_CUSTOM_COMMANDS_KEY: [],
},
{
serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [
f"numpy=={np.__version__}",
f"cloudpickle=={cloudpickle.__version__}",
]
],
serializers_base.SERIALIZATION_METADATA_CUSTOM_COMMANDS_KEY: [],
},
{
serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [
f"numpy=={np.__version__}",
f"cloudpickle=={cloudpickle.__version__}",
]
],
serializers_base.SERIALIZATION_METADATA_CUSTOM_COMMANDS_KEY: [],
},
{
serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [
f"numpy=={np.__version__}",
f"cloudpickle=={cloudpickle.__version__}",
]
],
serializers_base.SERIALIZATION_METADATA_CUSTOM_COMMANDS_KEY: [],
},
],
) as mock_any_serializer_serialize:
Expand Down
13 changes: 11 additions & 2 deletions vertexai/preview/_workflow/executor/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,7 @@ def remote_training(invokable: shared._Invokable, rewrapper: Any):
]

requirements = []
custom_commands = []

enable_cuda = config.enable_cuda

Expand Down Expand Up @@ -641,8 +642,16 @@ def remote_training(invokable: shared._Invokable, rewrapper: Any):

requirements = _add_indirect_dependency_versions(requirements)
command = ["export PIP_ROOT_USER_ACTION=ignore &&"]
if config.custom_commands:
custom_commands = [f"{command} &&" for command in config.custom_commands]

# Combine user custom_commands and serializer custom_commands
custom_commands += serialization_metadata[
serializers_base.SERIALIZATION_METADATA_CUSTOM_COMMANDS_KEY
]
custom_commands += config.custom_commands
custom_commands = list(dict.fromkeys(custom_commands))

if custom_commands:
custom_commands = [f"{command} &&" for command in custom_commands]
command.extend(custom_commands)
if requirements:
command.append("pip install --upgrade pip &&")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1119,7 +1119,11 @@ def serialize(
) -> str:
# All bigframe serializers will be identical (bigframes.dataframe.DataFrame --> parquet)
# Record the framework in metadata for deserialization
BigframeSerializer._metadata.framework = kwargs.get("framework")
detected_framework = kwargs.get("framework")
BigframeSerializer._metadata.framework = detected_framework
if detected_framework == "torch":
self.register_custom_command("pip install torchdata")
self.register_custom_command("pip install torcharrow")
if not _is_valid_gcs_path(gcs_path):
raise ValueError(f"Invalid gcs path: {gcs_path}")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
SERIALIZATION_METADATA_FILENAME = "serialization_metadata"
SERIALIZATION_METADATA_SERIALIZER_KEY = "serializer"
SERIALIZATION_METADATA_DEPENDENCIES_KEY = "dependencies"
SERIALIZATION_METADATA_CUSTOM_COMMANDS_KEY = "custom_commands"


@dataclasses.dataclass
Expand Down Expand Up @@ -133,11 +134,13 @@ def deserialize(self, gcs_path):

serializer: Optional[str] = None
dependencies: List[str] = dataclasses.field(default_factory=list)
custom_commands: List[str] = dataclasses.field(default_factory=list)

def to_dict(self):
return {
SERIALIZATION_METADATA_SERIALIZER_KEY: self.serializer,
SERIALIZATION_METADATA_DEPENDENCIES_KEY: self.dependencies,
SERIALIZATION_METADATA_CUSTOM_COMMANDS_KEY: self.custom_commands,
}


Expand Down Expand Up @@ -322,6 +325,12 @@ def _dedupe_deps(cls):
# the version if version is not specified.
cls._metadata.dependencies = list(dict.fromkeys(cls._metadata.dependencies))

@classmethod
def _dedupe_custom_commands(cls):
cls._metadata.custom_commands = list(
dict.fromkeys(cls._metadata.custom_commands)
)

@classmethod
def register_requirement(cls, required_package: str):
# TODO(b/280648121) Consider allowing the user to register the
Expand All @@ -334,3 +343,8 @@ def register_requirement(cls, required_package: str):
def register_requirements(cls, requirements: List[str]):
cls._metadata.dependencies.extend(requirements)
cls._dedupe_deps()

@classmethod
def register_custom_command(cls, custom_command: str):
cls._metadata.custom_commands.append(custom_command)
cls._dedupe_custom_commands()

0 comments on commit 639cf10

Please sign in to comment.