Skip to content

Commit

Permalink
Support properties in command component (Azure#27000)
Browse files Browse the repository at this point in the history
* support properties

* remove properties if not specified

* remove properties in post dump
  • Loading branch information
D-W- authored Oct 25, 2022
1 parent 9bfdf2e commit e4ab9f3
Show file tree
Hide file tree
Showing 6 changed files with 486 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from copy import deepcopy

import yaml
from marshmallow import INCLUDE, fields, post_load
from marshmallow import INCLUDE, fields, post_load, post_dump

from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
from azure.ai.ml._schema.component.component import ComponentSchema
Expand Down Expand Up @@ -48,6 +48,14 @@ class Meta:
]
),
)
properties = fields.Dict(keys=fields.Str(), values=fields.Raw())

@post_dump
def remove_unnecessary_fields(self, component_schema_dict, **kwargs):
# remove empty properties to keep the component spec unchanged
if not component_schema_dict.get("properties"):
component_schema_dict.pop("properties", None)
return component_schema_dict


class RestCommandComponentSchema(CommandComponentSchema):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ class CommandComponent(Component, ParameterizedCommand):
:type instance_count: int
:param is_deterministic: Whether the command component is deterministic.
:type is_deterministic: bool
:param properties: Properties of the component. Contents inside will pass through to backend as a dictionary.
:type properties: dict
:raises ~azure.ai.ml.exceptions.ValidationException: Raised if CommandComponent cannot be successfully validated.
Details will be provided in the error message.
"""
Expand All @@ -75,6 +78,7 @@ def __init__(
outputs: Dict = None,
instance_count: int = None, # promoted property from resources.instance_count
is_deterministic: bool = True,
properties: Dict = None,
**kwargs,
):
# validate init params are valid type
Expand All @@ -98,6 +102,7 @@ def __init__(
inputs=inputs,
outputs=outputs,
is_deterministic=is_deterministic,
properties=properties,
**kwargs,
)

Expand Down
41 changes: 40 additions & 1 deletion sdk/ml/azure-ai-ml/tests/component/e2etests/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import pydash
import pytest
from test_utilities.utils import _PYTEST_TIMEOUT_METHOD

from azure.ai.ml import MLClient, MpiDistribution, load_component, load_environment
from azure.ai.ml._restclient.v2022_05_01.models import ComponentContainerData, ListViewType
Expand Down Expand Up @@ -876,3 +875,43 @@ def test_component_with_default_label(

node = default_component()
assert node._to_rest_object()["componentId"] == default_component.id

def test_command_component_with_properties_e2e_flow(self, client: MLClient, randstr: Callable[[str], str]) -> None:
command_component = load_component(
source="./tests/test_configs/components/helloworld_component_with_properties.yml",
)
expected_dict = {
'$schema': 'https://azuremlschemas.azureedge.net/development/commandComponent.schema.json',
'_source': 'YAML.COMPONENT',
'command': 'echo Hello World & echo $[[${{inputs.component_in_number}}]] & '
'echo ${{inputs.component_in_path}} & echo '
'${{outputs.component_out_path}} > '
'${{outputs.component_out_path}}/component_in_number',
'description': 'This is the basic command component',
'display_name': 'CommandComponentBasic',
'inputs': {'component_in_number': {'default': '10.99',
'description': 'A number',
'optional': True,
'type': 'number'},
'component_in_path': {'description': 'A path',
'type': 'uri_folder'}},
'is_deterministic': True,
'outputs': {'component_out_path': {'type': 'uri_folder'}},
'properties': {'azureml.pipelines.dynamic': 'true'},
'tags': {'owner': 'sdkteam', 'tag': 'tagvalue'},
'type': 'command',
}
omit_fields = ["name", "creation_context", "id", "code", "environment", "version"]
rest_component = pydash.omit(
command_component._to_rest_object().as_dict()["properties"]["component_spec"],
omit_fields,
)

assert rest_component == expected_dict

from_rest_component = client.components.create_or_update(command_component, is_anonymous=True)

previous_dict = pydash.omit(command_component._to_dict(), omit_fields)
current_dict = pydash.omit(from_rest_component._to_dict(), omit_fields)
# TODO(2037030): verify when backend ready
# assert previous_dict == current_dict
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,14 @@ def test_anonymous_component_same_name(self, mock_machinelearning_client: MLClie
component_hash2 = component_entity2._get_anonymous_hash()
assert component_hash1 != component_hash2

def test_command_component_with_properties(self):
test_path = "./tests/test_configs/components/helloworld_component_with_properties.yml"
component_entity = load_component(source=test_path)
assert component_entity.properties == {"azureml.pipelines.dynamic": "true"}

validation_result = component_entity._validate()
assert validation_result.passed is True


@pytest.mark.timeout(_COMPONENT_TIMEOUT_SECOND)
@pytest.mark.unittest
Expand Down
Loading

0 comments on commit e4ab9f3

Please sign in to comment.