Skip to content

Commit

Permalink
[ENH] decouple registry from base modules, scitype specific data reco…
Browse files Browse the repository at this point in the history
…rds for documentation of estimator types (#6998)

Fixes #6970 by a redesign of the
object scitype register that decouples the `registry` module from the
rest of `sktime`, by removing all outside imports on module level
(except from `sktime.base` and `utils`).

As a side effect, this also adds one record class per object type, which
can be used as a tagged metadata record and later as a basis or
documenting the individual object types in `sktime`.

The refactor proceeds as follows:

* the base class register is replaced by data record classes similar to
the `_tags` module
* the imports of base classes are isolated in class methods of those
records, `get_base_class`, which returns the base class corresponding to
the scitype
* exports of objects involving tags, in particular the classes, are
replaced by functions that produce the object, further isolating the
import to places where it is needed, e.g., `get_base_class_list`

To make the changes deprecation safe, imports of coupled objects are
intercepted, and replaced by calls on demand, i.e., whenever an external
call carries out an import.

Further, imports from outside `registry` but inside `sktime` (all from
the test framework) are also replaced with the new functions. This lead
to some dead functions and objects, which were removed, which further
lead to unused imports, which also were removed.
  • Loading branch information
fkiraly authored Aug 30, 2024
1 parent 58615b6 commit f3ee417
Show file tree
Hide file tree
Showing 12 changed files with 638 additions and 236 deletions.
4 changes: 2 additions & 2 deletions examples/02_classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1448,10 +1448,10 @@
}
],
"source": [
"from sktime.registry import BASE_CLASS_REGISTER\n",
"from sktime.registry import get_obj_scitype_list\n",
"\n",
"# get only fist table column, the list of types\n",
"list(zip(*BASE_CLASS_REGISTER))[0]"
"get_obj_scitype_list()"
]
},
{
Expand Down
36 changes: 28 additions & 8 deletions sktime/registry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,10 @@
from sktime.registry._alias import resolve_alias
from sktime.registry._alias_str import ALIAS_DICT
from sktime.registry._base_classes import (
BASE_CLASS_LIST,
BASE_CLASS_LOOKUP,
BASE_CLASS_REGISTER,
BASE_CLASS_SCITYPE_LIST,
TRANSFORMER_MIXIN_LIST,
TRANSFORMER_MIXIN_LOOKUP,
TRANSFORMER_MIXIN_REGISTER,
TRANSFORMER_MIXIN_SCITYPE_LIST,
get_base_class_list,
get_base_class_lookup,
get_base_class_register,
get_obj_scitype_list,
)
from sktime.registry._craft import craft, deps, imports
from sktime.registry._lookup import all_estimators, all_tags
Expand Down Expand Up @@ -42,3 +38,27 @@
"TRANSFORMER_MIXIN_LOOKUP",
"TRANSFORMER_MIXIN_SCITYPE_LIST",
]


def __getattr__(name):
getter_dict = {
"BASE_CLASS_LOOKUP": get_base_class_lookup,
"BASE_CLASS_REGISTER": get_base_class_register,
"BASE_CLASS_LIST": get_base_class_list,
"BASE_CLASS_SCITYPE_LIST": get_obj_scitype_list,
}
if name in getter_dict:
return getter_dict[name]()

# legacy transformer mixins,
# handled for downward compatibility
legacy_trafo_mixin_dict = {
"TRANSFORMER_MIXIN_LOOKUP": get_base_class_lookup,
"TRANSFORMER_MIXIN_REGISTER": get_base_class_register,
"TRANSFORMER_MIXIN_LIST": get_base_class_list,
"TRANSFORMER_MIXIN_SCITYPE_LIST": get_obj_scitype_list,
}
if name in legacy_trafo_mixin_dict:
return legacy_trafo_mixin_dict[name](mixin=True)

raise AttributeError(f"module {__name__} has no attribute {name}")
Loading

0 comments on commit f3ee417

Please sign in to comment.