Skip to content

Commit

Permalink
chore: GenAI - Improved the output of the to_dict methods
Browse files Browse the repository at this point in the history
Improvements:

* Enum values are not represented as string names
* Default proto values are no longer included.

PiperOrigin-RevId: 628458271
  • Loading branch information
Ark-kun authored and copybara-github committed Apr 26, 2024
1 parent 819a44e commit 9809a3a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 10 deletions.
3 changes: 3 additions & 0 deletions tests/unit/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,9 @@ def test_conversion_methods(self, generative_models: generative_models):
part_new = generative_models.Part.from_dict(part.to_dict())
assert repr(part_new) == repr(part)

# Checking that the enums are serialized as strings, not integers.
assert response.to_dict()["candidates"][0]["finish_reason"] == "STOP"

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
Expand Down
27 changes: 17 additions & 10 deletions vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,7 +1261,7 @@ def from_dict(cls, generation_config_dict: Dict[str, Any]) -> "GenerationConfig"
return cls._from_gapic(raw_generation_config=raw_generation_config)

def to_dict(self) -> Dict[str, Any]:
return type(self._raw_generation_config).to_dict(self._raw_generation_config)
return _proto_to_dict(self._raw_generation_config)

def __repr__(self) -> str:
return self._raw_generation_config.__repr__()
Expand Down Expand Up @@ -1373,7 +1373,7 @@ def from_dict(cls, tool_dict: Dict[str, Any]) -> "Tool":
return cls._from_gapic(raw_tool=raw_tool)

def to_dict(self) -> Dict[str, Any]:
return type(self._raw_tool).to_dict(self._raw_tool)
return _proto_to_dict(self._raw_tool)

def __repr__(self) -> str:
return self._raw_tool.__repr__()
Expand Down Expand Up @@ -1538,9 +1538,7 @@ def from_func(cls, func: Callable[..., Any]) -> "CallableFunctionDeclaration":
return CallableFunctionDeclaration.from_func(func)

def to_dict(self) -> Dict[str, Any]:
return type(self._raw_function_declaration).to_dict(
self._raw_function_declaration
)
return _proto_to_dict(self._raw_function_declaration)

def __repr__(self) -> str:
return self._raw_function_declaration.__repr__()
Expand Down Expand Up @@ -1638,7 +1636,7 @@ def from_dict(cls, response_dict: Dict[str, Any]) -> "GenerationResponse":
return cls._from_gapic(raw_response=raw_response)

def to_dict(self) -> Dict[str, Any]:
return type(self._raw_response).to_dict(self._raw_response)
return _proto_to_dict(self._raw_response)

def __repr__(self) -> str:
return self._raw_response.__repr__()
Expand Down Expand Up @@ -1710,7 +1708,7 @@ def from_dict(cls, candidate_dict: Dict[str, Any]) -> "Candidate":
return cls._from_gapic(raw_candidate=raw_candidate)

def to_dict(self) -> Dict[str, Any]:
return type(self._raw_candidate).to_dict(self._raw_candidate)
return _proto_to_dict(self._raw_candidate)

def __repr__(self) -> str:
return self._raw_candidate.__repr__()
Expand Down Expand Up @@ -1800,7 +1798,7 @@ def from_dict(cls, content_dict: Dict[str, Any]) -> "Content":
return cls._from_gapic(raw_content=raw_content)

def to_dict(self) -> Dict[str, Any]:
return type(self._raw_content).to_dict(self._raw_content)
return _proto_to_dict(self._raw_content)

def __repr__(self) -> str:
return self._raw_content.__repr__()
Expand Down Expand Up @@ -1912,7 +1910,7 @@ def from_function_response(name: str, response: Dict[str, Any]) -> "Part":
)

def to_dict(self) -> Dict[str, Any]:
return type(self._raw_part).to_dict(self._raw_part)
return _proto_to_dict(self._raw_part)

@property
def text(self) -> str:
Expand Down Expand Up @@ -2001,7 +1999,7 @@ def from_dict(cls, safety_setting_dict: Dict[str, Any]) -> "SafetySetting":
return cls._from_gapic(raw_safety_setting=raw_safety_setting)

def to_dict(self) -> Dict[str, Any]:
return type(self._raw_safety_setting).to_dict(self._raw_safety_setting)
return _proto_to_dict(self._raw_safety_setting)

def __repr__(self):
return self._raw_safety_setting.__repr__()
Expand Down Expand Up @@ -2228,6 +2226,15 @@ def _append_gapic_part(
base_part._pb = copy.deepcopy(new_part._pb)


def _proto_to_dict(message) -> Dict[str, Any]:
"""Converts a proto-plus protobuf message to a dictionary."""
return type(message).to_dict(
message,
including_default_value_fields=False,
use_integers_for_enums=False,
)


def _dict_to_pretty_string(d: dict) -> str:
"""Format dict as a pretty-printed JSON string."""
return json.dumps(d, indent=2)
Expand Down

0 comments on commit 9809a3a

Please sign in to comment.