Skip to content

Commit

Permalink
fix: stale check; conf lazy imports (#178)
Browse files Browse the repository at this point in the history
Signed-off-by: Avik Basu <ab93@users.noreply.github.com>
  • Loading branch information
ab93 authored May 9, 2023
1 parent cffe380 commit 6c21e95
Show file tree
Hide file tree
Showing 10 changed files with 165 additions and 36 deletions.
2 changes: 0 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@ repos:
language_version: python3.9
args: [--config=pyproject.toml, --diff, --color ]
- repo: https://github.com/charliermarsh/ruff-pre-commit
# Ruff version.
rev: 'v0.0.264'
hooks:
- id: ruff
args: [ --fix, --exit-non-zero-on-fix ]
- repo: https://github.com/adamchainz/blacken-docs
rev: "1.13.0"
hooks:
Expand Down
40 changes: 22 additions & 18 deletions numalogic/config/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,7 @@
# limitations under the License.
from typing import Union

from sklearn.preprocessing import StandardScaler, MinMaxScaler, MaxAbsScaler, RobustScaler

from numalogic.config._config import ModelInfo, RegistryInfo
from numalogic.models.autoencoder.variants import (
VanillaAE,
SparseVanillaAE,
Conv1dAE,
SparseConv1dAE,
LSTMAE,
SparseLSTMAE,
TransformerAE,
SparseTransformerAE,
)
from numalogic.models.threshold import StdDevThreshold, StaticThreshold, SigmoidThreshold
from numalogic.postprocess import TanhNorm, ExpMovingAverage
from numalogic.preprocess import LogTransformer, StaticPowerTransformer, TanhScaler
from numalogic.registry import MLflowRegistry, RedisRegistry
from numalogic.tools.exceptions import UnknownConfigArgsError


Expand All @@ -52,6 +36,9 @@ def get_cls(self, object_info: Union[ModelInfo, RegistryInfo]):


class PreprocessFactory(_ObjectFactory):
from sklearn.preprocessing import StandardScaler, MinMaxScaler, MaxAbsScaler, RobustScaler
from numalogic.preprocess import LogTransformer, StaticPowerTransformer, TanhScaler

_CLS_MAP = {
"StandardScaler": StandardScaler,
"MinMaxScaler": MinMaxScaler,
Expand All @@ -64,10 +51,14 @@ class PreprocessFactory(_ObjectFactory):


class PostprocessFactory(_ObjectFactory):
from numalogic.postprocess import TanhNorm, ExpMovingAverage

_CLS_MAP = {"TanhNorm": TanhNorm, "ExpMovingAverage": ExpMovingAverage}


class ThresholdFactory(_ObjectFactory):
from numalogic.models.threshold import StdDevThreshold, StaticThreshold, SigmoidThreshold

_CLS_MAP = {
"StdDevThreshold": StdDevThreshold,
"StaticThreshold": StaticThreshold,
Expand All @@ -76,6 +67,17 @@ class ThresholdFactory(_ObjectFactory):


class ModelFactory(_ObjectFactory):
from numalogic.models.autoencoder.variants import (
VanillaAE,
SparseVanillaAE,
Conv1dAE,
SparseConv1dAE,
LSTMAE,
SparseLSTMAE,
TransformerAE,
SparseTransformerAE,
)

_CLS_MAP = {
"VanillaAE": VanillaAE,
"SparseVanillaAE": SparseVanillaAE,
Expand All @@ -89,7 +91,9 @@ class ModelFactory(_ObjectFactory):


class RegistryFactory(_ObjectFactory):
import numalogic.registry as reg

_CLS_MAP = {
"RedisRegistry": RedisRegistry,
"MLflowRegistry": MLflowRegistry,
"RedisRegistry": getattr(reg, "RedisRegistry"),
"MLflowRegistry": getattr(reg, "MLflowRegistry"),
}
16 changes: 14 additions & 2 deletions numalogic/registry/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from dataclasses import dataclass
from typing import Any, Generic, TypeVar

from numalogic.tools.types import artifact_t, KEYS, META_T, EXTRA_T
from numalogic.tools.types import artifact_t, KEYS, META_T, META_VT, EXTRA_T


@dataclass
Expand Down Expand Up @@ -54,7 +54,7 @@ def load(
"""
raise NotImplementedError("Please implement this method!")

def save(self, skeys: KEYS, dkeys: KEYS, artifact: artifact_t, **metadata: META_T) -> Any:
def save(self, skeys: KEYS, dkeys: KEYS, artifact: artifact_t, **metadata: META_VT) -> Any:
r"""
Saves the artifact into mlflow registry and updates version.
Args:
Expand All @@ -75,6 +75,18 @@ def delete(self, skeys: KEYS, dkeys: KEYS, version: str) -> None:
"""
raise NotImplementedError("Please implement this method!")

@staticmethod
def is_artifact_stale(artifact_data: ArtifactData, freq_hr: int) -> bool:
"""
Returns whether the given artifact is stale or not, i.e. if
more time has elasped since it was last retrained.
Args:
artifact_data: ArtifactData object to look into
freq_hr: Frequency of retraining in hours
"""
raise NotImplementedError("Please implement this method!")

@staticmethod
def construct_key(skeys: KEYS, dkeys: KEYS) -> str:
"""
Expand Down
19 changes: 17 additions & 2 deletions numalogic/registry/mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@


import logging
from datetime import datetime, timedelta
from enum import Enum
from typing import Optional, Any

Expand All @@ -24,7 +25,7 @@
from numalogic.registry import ArtifactManager, ArtifactData
from numalogic.registry.artifact import ArtifactCache
from numalogic.tools.exceptions import ModelVersionError
from numalogic.tools.types import artifact_t, KEYS, META_T
from numalogic.tools.types import artifact_t, KEYS, META_VT

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -183,7 +184,7 @@ def save(
dkeys: KEYS,
artifact: artifact_t,
run_id: str = None,
**metadata: META_T,
**metadata: META_VT,
) -> Optional[ModelVersion]:
"""
Saves the artifact into mlflow registry and updates version.
Expand Down Expand Up @@ -213,6 +214,20 @@ def save(
finally:
mlflow.end_run()

@staticmethod
def is_artifact_stale(artifact_data: ArtifactData, freq_hr: int) -> bool:
"""
Returns whether the given artifact is stale or not, i.e. if
more time has elasped since it was last retrained.
Args:
artifact_data: ArtifactData object to look into
freq_hr: Frequency of retraining in hours
"""
date_updated = artifact_data.extras["last_updated_timestamp"] / 1000
stale_date = (datetime.now() - timedelta(hours=freq_hr)).timestamp()
return date_updated < stale_date

def delete(self, skeys: KEYS, dkeys: KEYS, version: str) -> None:
"""
Deletes the artifact with a specified version from mlflow registry.
Expand Down
33 changes: 25 additions & 8 deletions numalogic/registry/redis_registry.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import logging
import time
from datetime import datetime, timedelta
from typing import Optional

from redis.exceptions import RedisError

from numalogic.registry import ArtifactManager, ArtifactData
from numalogic.registry._serialize import loads, dumps
from numalogic.tools.exceptions import ModelKeyNotFound, RedisRegistryError
from numalogic.tools.types import artifact_t, redis_client_t, KEYS, META_T
from numalogic.tools.types import artifact_t, redis_client_t, KEYS, META_T, META_VT

_LOGGER = logging.getLogger()

Expand All @@ -24,14 +25,13 @@ class RedisRegistry(ArtifactManager):
>>> import redis
>>> from numalogic.models.autoencoder.variants import VanillaAE
>>> from numalogic.registry.redis_registry import RedisRegistry
>>> ...
>>> r = redis.StrictRedis(host='127.0.0.1', port=6379)
>>> cli = r.client()
>>> registry = RedisRegistry(client=cli)
>>> skeys = ['c', 'a']
>>> dkeys = ['d', 'a']
>>> model = VanillaAE(10)
>>> registry = RedisRegistry(client=r)
>>> skeys, dkeys = ("mymetric", "ae"), ("vanilla", "seq10")
>>> model = VanillaAE(seq_len=10)
>>> registry.save(skeys, dkeys, artifact=model, **{'lr': 0.01})
>>> registry.load(skeys, dkeys)
>>> loaded_artifact = registry.load(skeys, dkeys)
"""

__slots__ = ("client", "ttl")
Expand Down Expand Up @@ -177,7 +177,7 @@ def save(
skeys: KEYS,
dkeys: KEYS,
artifact: artifact_t,
**metadata: META_T,
**metadata: META_VT,
) -> Optional[str]:
"""
Saves the artifact into redis registry and updates version.
Expand Down Expand Up @@ -229,3 +229,20 @@ def delete(self, skeys: KEYS, dkeys: KEYS, version: str) -> None:
)
except RedisError as err:
raise RedisRegistryError(f"{err.__class__.__name__} raised") from err

@staticmethod
def is_artifact_stale(artifact_data: ArtifactData, freq_hr: int) -> bool:
"""
Returns whether the given artifact is stale or not, i.e. if
more time has elasped since it was last retrained.
Args:
artifact_data: ArtifactData object to look into
freq_hr: Frequency of retraining in hours
"""
try:
artifact_ts = float(artifact_data.extras["timestamp"])
except (KeyError, TypeError) as err:
raise RedisRegistryError("Error fetching timestamp information") from err
stale_ts = (datetime.now() - timedelta(hours=freq_hr)).timestamp()
return stale_ts > artifact_ts
1 change: 1 addition & 0 deletions numalogic/tools/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

artifact_t = TypeVar("artifact_t", bound=Union[nn.Module, BaseEstimator])
META_T = TypeVar("META_T", bound=dict[str, Union[str, list, dict]])
META_VT = TypeVar("META_VT", bound=Union[str, list, dict])
EXTRA_T = TypeVar("EXTRA_T", bound=dict[str, Union[str, list, dict]])
redis_client_t = TypeVar("redis_client_t", bound=AbstractRedis, covariant=True)
KEYS = TypeVar("KEYS", bound=Sequence[str], covariant=True)
Expand Down
32 changes: 30 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic"
version = "0.4.dev3"
version = "0.4.dev4"
description = "Collection of operational Machine Learning models and tools."
authors = ["Numalogic Developers"]
packages = [{ include = "numalogic" }]
Expand All @@ -26,7 +26,7 @@ homepage = "https://numalogic.numaproj.io/"
python = ">=3.9, <3.11"
numpy = "^1.23"
pandas = "^2.0"
scikit-learn = "^1.0"
scikit-learn = "^1.2"
mlflow-skinny = { version = ">2.0, <2.3", optional = true }
protobuf = "~3.20" # needed to support pytorch-lightning
omegaconf = "^2.3.0"
Expand Down Expand Up @@ -55,6 +55,7 @@ torchinfo = "^1.7.2"
ruff = "^0.0.264"
pre-commit = "^3.3.1"
fakeredis = "^2.11.2"
freezegun = "^1.2.2"

[tool.poetry.group.jupyter]
optional = true
Expand Down
32 changes: 32 additions & 0 deletions tests/registry/test_mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from contextlib import contextmanager
from unittest.mock import patch, Mock

from freezegun import freeze_time
from mlflow import ActiveRun
from mlflow.exceptions import RestException
from mlflow.protos.databricks_pb2 import RESOURCE_DOES_NOT_EXIST, ErrorCode, RESOURCE_LIMIT_EXCEEDED
Expand Down Expand Up @@ -290,6 +291,37 @@ def test_load_other_mlflow_err(self):
dkeys = self.dkeys
self.assertIsNone(ml.load(skeys=skeys, dkeys=dkeys))

@patch("mlflow.pytorch.log_model", mock_log_model_pytorch())
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch("mlflow.log_params", {"lr": 0.01})
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10)))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict()))
def test_is_model_stale_true(self):
model = self.model
ml = MLflowRegistry(TRACKING_URI, artifact_type="pytorch")
ml.save(skeys=self.skeys, dkeys=self.dkeys, artifact=model, **{"lr": 0.01})
data = ml.load(skeys=self.skeys, dkeys=self.dkeys)
self.assertTrue(ml.is_artifact_stale(data, 12))

@patch("mlflow.pytorch.log_model", mock_log_model_pytorch())
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch("mlflow.log_params", {"lr": 0.01})
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10)))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict()))
def test_is_model_stale_false(self):
model = self.model
ml = MLflowRegistry(TRACKING_URI, artifact_type="pytorch")
ml.save(skeys=self.skeys, dkeys=self.dkeys, artifact=model, **{"lr": 0.01})
data = ml.load(skeys=self.skeys, dkeys=self.dkeys)
with freeze_time("2022-05-24 10:30:00"):
self.assertFalse(ml.is_artifact_stale(data, 12))


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 6c21e95

Please sign in to comment.