Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom operator serialize #120

Merged
merged 18 commits into from
Jun 28, 2023
20 changes: 14 additions & 6 deletions golem/serializers/coders/graph_serialization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Type, Sequence
from typing import Any, Dict, Type, Sequence, Union

from golem.core.dag.graph import Graph
from golem.core.dag.graph_delegate import GraphDelegate
Expand All @@ -20,13 +20,21 @@ def graph_from_json(cls: Type[Graph], json_obj: Dict[str, Any]) -> Graph:
return obj


def _reassign_edges_by_node_ids(nodes: Sequence[LinkedGraphNode]):
def _reassign_edges_by_node_ids(nodes: Sequence[Union[LinkedGraphNode, dict]]):
"""
Assigns each <inner_node> from "nodes_from" to equal <outer_node> from "nodes"
(cause each node from "nodes_from" in fact should point to the same node from "nodes")
"""
lookup_dict = {node.uid: node for node in nodes}
lookup_dict = {}
for node in nodes:
if node.nodes_from:
for parent_node_idx, parent_node_uid in enumerate(node.nodes_from):
node.nodes_from[parent_node_idx] = lookup_dict.get(parent_node_uid, None)
if isinstance(node, dict):
lookup_dict[node['uid']] = node
else:
lookup_dict[node.uid] = node

for node in nodes:
nodes_from = node['_nodes_from'] if isinstance(node, dict) else node.nodes_from
if not nodes_from:
continue
for parent_node_idx, parent_node_uid in enumerate(nodes_from):
nodes_from[parent_node_idx] = lookup_dict.get(parent_node_uid, None)
42 changes: 36 additions & 6 deletions golem/serializers/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
from json import JSONDecoder, JSONEncoder
from typing import Any, Callable, Dict, Optional, Type, TypeVar, Union

from golem.core.dag.linked_graph import LinkedGraph
from golem.core.dag.linked_graph_node import LinkedGraphNode
from golem.core.log import default_log

# NB: the default class coders are registered at the end of this module's initialization

# Type variable constrained to either a plain object or a Callable
INSTANCE_OR_CALLABLE = TypeVar('INSTANCE_OR_CALLABLE', object, Callable)
# Signature of an encoder: takes an instance (or callable) and returns a dict
EncodeCallable = Callable[[INSTANCE_OR_CALLABLE], Dict[str, Any]]
# Signature of a decoder: takes the target type and a dict, returns an instance (or callable)
DecodeCallable = Callable[[Type[INSTANCE_OR_CALLABLE], Dict[str, Any]], INSTANCE_OR_CALLABLE]


# Separator between the module part and the name part of a class path
MODULE_X_NAME_DELIMITER = '/'
# Key of a json dict under which the class path is stored
CLASS_PATH_KEY = '_class_path'

Expand Down Expand Up @@ -56,7 +59,6 @@


class Serializer(JSONEncoder, JSONDecoder):

_to_json = 'to_json'
_from_json = 'from_json'

Expand Down Expand Up @@ -222,19 +224,33 @@ def default(self, obj: INSTANCE_OR_CALLABLE) -> Dict[str, Any]:
return JSONEncoder.default(self, obj)

@staticmethod
def _get_class(json_obj: dict) -> Optional[Type[INSTANCE_OR_CALLABLE]]:
    """
    Gets the object type from the class path stored in json_obj

    :param json_obj: dict holding the full path (module + name) of the class
        under the CLASS_PATH_KEY key

    :return: class, function or method type; if the module cannot be imported,
        whatever Serializer._import_as_base_class returns for json_obj
        (a fallback base class or None)
    """
    class_path = json_obj[CLASS_PATH_KEY]
    class_path = LEGACY_CLASS_PATHS.get(class_path, class_path)
    module_name, class_name = class_path.split(MODULE_X_NAME_DELIMITER)
    module_name = Serializer._legacy_module_map(module_name)

    try:
        obj_cls = import_module(module_name)
    except ImportError as ex:
        # The original module is unavailable: fall back to a base class
        # guessed from the keys of json_obj, if there is a matching one.
        obj_cls = Serializer._import_as_base_class(json_obj)
        if obj_cls is not None:
            default_log('Serializer').info(
                f'Object was decoded as {obj_cls} and not as an original class '
                f'because of an ImportError: {ex}.')
        else:
            default_log('Serializer').info(
                f'Object was not decoded and will be stored as a dict '
                f'because of an ImportError: {ex}.')
        return obj_cls

    # class_name may be a dotted path (e.g. a nested class or a method of a class),
    # so descend from the module one attribute at a time.
    for sub in class_name.split('.'):
        obj_cls = getattr(obj_cls, sub)
    return obj_cls
Expand All @@ -250,6 +266,18 @@ def _legacy_module_map(module_path: str) -> str:
def _is_bound_method(method: Callable) -> bool:
return hasattr(method, '__self__')

@staticmethod
def _import_as_base_class(json_obj: dict) \
        -> Optional[Union[Type[LinkedGraph], Type[LinkedGraphNode]]]:
    """
    Picks a base class to stand in for an object, judging only by the keys of json_obj

    :param json_obj: dict whose keys are inspected

    :return: LinkedGraph if both '_nodes' and '_postprocess_nodes' are present,
        otherwise LinkedGraphNode if 'content', '_nodes_from' and 'uid' are all present,
        otherwise None
    """
    present_keys = json_obj.keys()
    if present_keys >= {'_nodes', '_postprocess_nodes'}:
        return LinkedGraph
    if present_keys >= {'content', '_nodes_from', 'uid'}:
        return LinkedGraphNode
    return None

@staticmethod
def object_hook(json_obj: Dict[str, Any]) -> Union[INSTANCE_OR_CALLABLE, dict]:
"""
Expand All @@ -261,7 +289,7 @@ def object_hook(json_obj: Dict[str, Any]) -> Union[INSTANCE_OR_CALLABLE, dict]:
:return: Python class, function or method object OR input if it's just a regular dict
"""
if CLASS_PATH_KEY in json_obj:
obj_cls = Serializer._get_class(json_obj[CLASS_PATH_KEY])
obj_cls = Serializer._get_class(json_obj)
del json_obj[CLASS_PATH_KEY]
base_type = Serializer._get_base_type(obj_cls)
if isclass(obj_cls) and base_type is not None:
Expand All @@ -273,7 +301,8 @@ def object_hook(json_obj: Dict[str, Any]) -> Union[INSTANCE_OR_CALLABLE, dict]:
return coder(obj_cls, json_obj)
elif isfunction(obj_cls) or ismethod(obj_cls):
return obj_cls
raise TypeError(f'Parsed obj_cls={obj_cls} is not serializable, but should be')
else:
return json_obj
return json_obj


Expand All @@ -288,6 +317,7 @@ def default_save(obj: Any, json_file_path: Optional[Union[str, os.PathLike]] = N

def default_load(json_str_or_file_path: Union[str, os.PathLike]) -> Any:
""" Default load from json using Serializer """

def load_as_file_path():
with open(json_str_or_file_path, mode='r') as json_file:
return json.load(json_file, cls=Serializer)
Expand Down
Loading