From 6c21e952d37a8fcf7102f095b85fc193255c04d3 Mon Sep 17 00:00:00 2001 From: Avik Basu <3485425+ab93@users.noreply.github.com> Date: Tue, 9 May 2023 09:25:48 -0700 Subject: [PATCH] fix: stale check; conf lazy imports (#178) Signed-off-by: Avik Basu --- .pre-commit-config.yaml | 2 -- numalogic/config/factory.py | 40 ++++++++++++++------------ numalogic/registry/artifact.py | 16 +++++++++-- numalogic/registry/mlflow_registry.py | 19 ++++++++++-- numalogic/registry/redis_registry.py | 33 +++++++++++++++------ numalogic/tools/types.py | 1 + poetry.lock | 32 +++++++++++++++++++-- pyproject.toml | 5 ++-- tests/registry/test_mlflow_registry.py | 32 +++++++++++++++++++++ tests/registry/test_redis_registry.py | 21 ++++++++++++++ 10 files changed, 165 insertions(+), 36 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8f7ec31a..457be12f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,11 +8,9 @@ repos: language_version: python3.9 args: [--config=pyproject.toml, --diff, --color ] - repo: https://github.com/charliermarsh/ruff-pre-commit - # Ruff version. rev: 'v0.0.264' hooks: - id: ruff - args: [ --fix, --exit-non-zero-on-fix ] - repo: https://github.com/adamchainz/blacken-docs rev: "1.13.0" hooks: diff --git a/numalogic/config/factory.py b/numalogic/config/factory.py index 69370ace..115c33d6 100644 --- a/numalogic/config/factory.py +++ b/numalogic/config/factory.py @@ -10,23 +10,7 @@ # limitations under the License. 
from typing import Union -from sklearn.preprocessing import StandardScaler, MinMaxScaler, MaxAbsScaler, RobustScaler - from numalogic.config._config import ModelInfo, RegistryInfo -from numalogic.models.autoencoder.variants import ( - VanillaAE, - SparseVanillaAE, - Conv1dAE, - SparseConv1dAE, - LSTMAE, - SparseLSTMAE, - TransformerAE, - SparseTransformerAE, -) -from numalogic.models.threshold import StdDevThreshold, StaticThreshold, SigmoidThreshold -from numalogic.postprocess import TanhNorm, ExpMovingAverage -from numalogic.preprocess import LogTransformer, StaticPowerTransformer, TanhScaler -from numalogic.registry import MLflowRegistry, RedisRegistry from numalogic.tools.exceptions import UnknownConfigArgsError @@ -52,6 +36,9 @@ def get_cls(self, object_info: Union[ModelInfo, RegistryInfo]): class PreprocessFactory(_ObjectFactory): + from sklearn.preprocessing import StandardScaler, MinMaxScaler, MaxAbsScaler, RobustScaler + from numalogic.preprocess import LogTransformer, StaticPowerTransformer, TanhScaler + _CLS_MAP = { "StandardScaler": StandardScaler, "MinMaxScaler": MinMaxScaler, @@ -64,10 +51,14 @@ class PreprocessFactory(_ObjectFactory): class PostprocessFactory(_ObjectFactory): + from numalogic.postprocess import TanhNorm, ExpMovingAverage + _CLS_MAP = {"TanhNorm": TanhNorm, "ExpMovingAverage": ExpMovingAverage} class ThresholdFactory(_ObjectFactory): + from numalogic.models.threshold import StdDevThreshold, StaticThreshold, SigmoidThreshold + _CLS_MAP = { "StdDevThreshold": StdDevThreshold, "StaticThreshold": StaticThreshold, @@ -76,6 +67,17 @@ class ThresholdFactory(_ObjectFactory): class ModelFactory(_ObjectFactory): + from numalogic.models.autoencoder.variants import ( + VanillaAE, + SparseVanillaAE, + Conv1dAE, + SparseConv1dAE, + LSTMAE, + SparseLSTMAE, + TransformerAE, + SparseTransformerAE, + ) + _CLS_MAP = { "VanillaAE": VanillaAE, "SparseVanillaAE": SparseVanillaAE, @@ -89,7 +91,9 @@ class ModelFactory(_ObjectFactory): class 
RegistryFactory(_ObjectFactory): + import numalogic.registry as reg + _CLS_MAP = { - "RedisRegistry": RedisRegistry, - "MLflowRegistry": MLflowRegistry, + "RedisRegistry": getattr(reg, "RedisRegistry"), + "MLflowRegistry": getattr(reg, "MLflowRegistry"), } diff --git a/numalogic/registry/artifact.py b/numalogic/registry/artifact.py index c07d8d49..05bbd6de 100644 --- a/numalogic/registry/artifact.py +++ b/numalogic/registry/artifact.py @@ -13,7 +13,7 @@ from dataclasses import dataclass from typing import Any, Generic, TypeVar -from numalogic.tools.types import artifact_t, KEYS, META_T, EXTRA_T +from numalogic.tools.types import artifact_t, KEYS, META_T, META_VT, EXTRA_T @dataclass @@ -54,7 +54,7 @@ def load( """ raise NotImplementedError("Please implement this method!") - def save(self, skeys: KEYS, dkeys: KEYS, artifact: artifact_t, **metadata: META_T) -> Any: + def save(self, skeys: KEYS, dkeys: KEYS, artifact: artifact_t, **metadata: META_VT) -> Any: r""" Saves the artifact into mlflow registry and updates version. Args: @@ -75,6 +75,18 @@ def delete(self, skeys: KEYS, dkeys: KEYS, version: str) -> None: """ raise NotImplementedError("Please implement this method!") + @staticmethod + def is_artifact_stale(artifact_data: ArtifactData, freq_hr: int) -> bool: + """ + Returns whether the given artifact is stale or not, i.e. if + more than freq_hr hours have elapsed since it was last retrained. 
+ Args: + artifact_data: ArtifactData object to look into + freq_hr: Frequency of retraining in hours + + """ + raise NotImplementedError("Please implement this method!") + @staticmethod def construct_key(skeys: KEYS, dkeys: KEYS) -> str: """ diff --git a/numalogic/registry/mlflow_registry.py b/numalogic/registry/mlflow_registry.py index 3b90f14c..12cc6e8c 100644 --- a/numalogic/registry/mlflow_registry.py +++ b/numalogic/registry/mlflow_registry.py @@ -11,6 +11,7 @@ import logging +from datetime import datetime, timedelta from enum import Enum from typing import Optional, Any @@ -24,7 +25,7 @@ from numalogic.registry import ArtifactManager, ArtifactData from numalogic.registry.artifact import ArtifactCache from numalogic.tools.exceptions import ModelVersionError -from numalogic.tools.types import artifact_t, KEYS, META_T +from numalogic.tools.types import artifact_t, KEYS, META_VT _LOGGER = logging.getLogger(__name__) @@ -183,7 +184,7 @@ def save( dkeys: KEYS, artifact: artifact_t, run_id: str = None, - **metadata: META_T, + **metadata: META_VT, ) -> Optional[ModelVersion]: """ Saves the artifact into mlflow registry and updates version. @@ -213,6 +214,20 @@ def save( finally: mlflow.end_run() + @staticmethod + def is_artifact_stale(artifact_data: ArtifactData, freq_hr: int) -> bool: + """ + Returns whether the given artifact is stale or not, i.e. if + more than freq_hr hours have elapsed since it was last retrained. + Args: + artifact_data: ArtifactData object to look into + freq_hr: Frequency of retraining in hours + + """ + date_updated = artifact_data.extras["last_updated_timestamp"] / 1000 + stale_date = (datetime.now() - timedelta(hours=freq_hr)).timestamp() + return date_updated < stale_date + def delete(self, skeys: KEYS, dkeys: KEYS, version: str) -> None: """ Deletes the artifact with a specified version from mlflow registry. 
diff --git a/numalogic/registry/redis_registry.py b/numalogic/registry/redis_registry.py index 3c85c4ff..851807be 100644 --- a/numalogic/registry/redis_registry.py +++ b/numalogic/registry/redis_registry.py @@ -1,5 +1,6 @@ import logging import time +from datetime import datetime, timedelta from typing import Optional from redis.exceptions import RedisError @@ -7,7 +8,7 @@ from numalogic.registry import ArtifactManager, ArtifactData from numalogic.registry._serialize import loads, dumps from numalogic.tools.exceptions import ModelKeyNotFound, RedisRegistryError -from numalogic.tools.types import artifact_t, redis_client_t, KEYS, META_T +from numalogic.tools.types import artifact_t, redis_client_t, KEYS, META_T, META_VT _LOGGER = logging.getLogger() @@ -24,14 +25,13 @@ class RedisRegistry(ArtifactManager): >>> import redis >>> from numalogic.models.autoencoder.variants import VanillaAE >>> from numalogic.registry.redis_registry import RedisRegistry + >>> ... >>> r = redis.StrictRedis(host='127.0.0.1', port=6379) - >>> cli = r.client() - >>> registry = RedisRegistry(client=cli) - >>> skeys = ['c', 'a'] - >>> dkeys = ['d', 'a'] - >>> model = VanillaAE(10) + >>> registry = RedisRegistry(client=r) + >>> skeys, dkeys = ("mymetric", "ae"), ("vanilla", "seq10") + >>> model = VanillaAE(seq_len=10) >>> registry.save(skeys, dkeys, artifact=model, **{'lr': 0.01}) - >>> registry.load(skeys, dkeys) + >>> loaded_artifact = registry.load(skeys, dkeys) """ __slots__ = ("client", "ttl") @@ -177,7 +177,7 @@ def save( skeys: KEYS, dkeys: KEYS, artifact: artifact_t, - **metadata: META_T, + **metadata: META_VT, ) -> Optional[str]: """ Saves the artifact into redis registry and updates version. 
@@ -229,3 +229,20 @@ def delete(self, skeys: KEYS, dkeys: KEYS, version: str) -> None: ) except RedisError as err: raise RedisRegistryError(f"{err.__class__.__name__} raised") from err + + @staticmethod + def is_artifact_stale(artifact_data: ArtifactData, freq_hr: int) -> bool: + """ + Returns whether the given artifact is stale or not, i.e. if + more than freq_hr hours have elapsed since it was last retrained. + Args: + artifact_data: ArtifactData object to look into + freq_hr: Frequency of retraining in hours + + """ + try: + artifact_ts = float(artifact_data.extras["timestamp"]) + except (KeyError, TypeError) as err: + raise RedisRegistryError("Error fetching timestamp information") from err + stale_ts = (datetime.now() - timedelta(hours=freq_hr)).timestamp() + return stale_ts > artifact_ts diff --git a/numalogic/tools/types.py b/numalogic/tools/types.py index cdc3a3c1..8f801619 100644 --- a/numalogic/tools/types.py +++ b/numalogic/tools/types.py @@ -19,6 +19,7 @@ artifact_t = TypeVar("artifact_t", bound=Union[nn.Module, BaseEstimator]) META_T = TypeVar("META_T", bound=dict[str, Union[str, list, dict]]) +META_VT = TypeVar("META_VT", bound=Union[str, list, dict]) EXTRA_T = TypeVar("EXTRA_T", bound=dict[str, Union[str, list, dict]]) redis_client_t = TypeVar("redis_client_t", bound=AbstractRedis, covariant=True) KEYS = TypeVar("KEYS", bound=Sequence[str], covariant=True) diff --git a/poetry.lock b/poetry.lock index b6873961..6c6f7123 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry and should not be changed by hand. +# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. 
[[package]] name = "aiohttp" @@ -1046,6 +1046,21 @@ files = [ {file = "fqdn-1.5.1.tar.gz", hash = "sha256:105ed3677e767fb5ca086a0c1f4bb66ebc3c100be518f0e0d755d9eae164d89f"}, ] +[[package]] +name = "freezegun" +version = "1.2.2" +description = "Let your Python tests travel through time" +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "freezegun-1.2.2-py3-none-any.whl", hash = "sha256:ea1b963b993cb9ea195adbd893a48d573fda951b0da64f60883d7e988b606c9f"}, + {file = "freezegun-1.2.2.tar.gz", hash = "sha256:cd22d1ba06941384410cd967d8a99d5ae2442f57dfafeff2fda5de8dc5c05446"}, +] + +[package.dependencies] +python-dateutil = ">=2.7" + [[package]] name = "frozenlist" version = "1.3.3" @@ -3942,6 +3957,10 @@ category = "dev" optional = false python-versions = ">=3.8.0" files = [ + {file = "torch-2.0.0-1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:c9090bda7d2eeeecd74f51b721420dbeb44f838d4536cc1b284e879417e3064a"}, + {file = "torch-2.0.0-1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:bd42db2a48a20574d2c33489e120e9f32789c4dc13c514b0c44272972d14a2d7"}, + {file = "torch-2.0.0-1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:8969aa8375bcbc0c2993e7ede0a7f889df9515f18b9b548433f412affed478d9"}, + {file = "torch-2.0.0-1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:ab2da16567cb55b67ae39e32d520d68ec736191d88ac79526ca5874754c32203"}, {file = "torch-2.0.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:7a9319a67294ef02459a19738bbfa8727bb5307b822dadd708bc2ccf6c901aca"}, {file = "torch-2.0.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:9f01fe1f6263f31bd04e1757946fd63ad531ae37f28bb2dbf66f5c826ee089f4"}, {file = "torch-2.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:527f4ae68df7b8301ee6b1158ca56350282ea633686537b30dbb5d7b4a52622a"}, @@ -4090,6 +4109,15 @@ category = "dev" optional = false python-versions = "*" files = [ + {file = "triton-2.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = 
"sha256:38806ee9663f4b0f7cd64790e96c579374089e58f49aac4a6608121aa55e2505"}, + {file = "triton-2.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:226941c7b8595219ddef59a1fdb821e8c744289a132415ddd584facedeb475b1"}, + {file = "triton-2.0.0-1-cp36-cp36m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4c9fc8c89874bc48eb7e7b2107a9b8d2c0bf139778637be5bfccb09191685cfd"}, + {file = "triton-2.0.0-1-cp37-cp37m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d2684b6a60b9f174f447f36f933e9a45f31db96cb723723ecd2dcfd1c57b778b"}, + {file = "triton-2.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9d4978298b74fcf59a75fe71e535c092b023088933b2f1df933ec32615e4beef"}, + {file = "triton-2.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:74f118c12b437fb2ca25e1a04759173b517582fcf4c7be11913316c764213656"}, + {file = "triton-2.0.0-1-pp37-pypy37_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9618815a8da1d9157514f08f855d9e9ff92e329cd81c0305003eb9ec25cc5add"}, + {file = "triton-2.0.0-1-pp38-pypy38_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1aca3303629cd3136375b82cb9921727f804e47ebee27b2677fef23005c3851a"}, + {file = "triton-2.0.0-1-pp39-pypy39_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e3e13aa8b527c9b642e3a9defcc0fbd8ffbe1c80d8ac8c15a01692478dc64d8a"}, {file = "triton-2.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f05a7e64e4ca0565535e3d5d3405d7e49f9d308505bb7773d21fb26a4c008c2"}, {file = "triton-2.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb4b99ca3c6844066e516658541d876c28a5f6e3a852286bbc97ad57134827fd"}, {file = "triton-2.0.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47b4d70dc92fb40af553b4460492c31dc7d3a114a979ffb7a5cdedb7eb546c08"}, @@ -4383,4 +4411,4 @@ redis = ["redis"] [metadata] 
lock-version = "2.0" python-versions = ">=3.9, <3.11" -content-hash = "0c290bd9f95ba0368172e2ea78ba600c333712176d8ed36c5d281fe8db1fd1d3" +content-hash = "80f11647af0a24dfe123591c8c673c0caf9b569de42d79eb2152ea060f6d13ca" diff --git a/pyproject.toml b/pyproject.toml index 68de0710..ae92fb66 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "numalogic" -version = "0.4.dev3" +version = "0.4.dev4" description = "Collection of operational Machine Learning models and tools." authors = ["Numalogic Developers"] packages = [{ include = "numalogic" }] @@ -26,7 +26,7 @@ homepage = "https://numalogic.numaproj.io/" python = ">=3.9, <3.11" numpy = "^1.23" pandas = "^2.0" -scikit-learn = "^1.0" +scikit-learn = "^1.2" mlflow-skinny = { version = ">2.0, <2.3", optional = true } protobuf = "~3.20" # needed to support pytorch-lightning omegaconf = "^2.3.0" @@ -55,6 +55,7 @@ torchinfo = "^1.7.2" ruff = "^0.0.264" pre-commit = "^3.3.1" fakeredis = "^2.11.2" +freezegun = "^1.2.2" [tool.poetry.group.jupyter] optional = true diff --git a/tests/registry/test_mlflow_registry.py b/tests/registry/test_mlflow_registry.py index 08f07c8e..be286694 100644 --- a/tests/registry/test_mlflow_registry.py +++ b/tests/registry/test_mlflow_registry.py @@ -2,6 +2,7 @@ from contextlib import contextmanager from unittest.mock import patch, Mock +from freezegun import freeze_time from mlflow import ActiveRun from mlflow.exceptions import RestException from mlflow.protos.databricks_pb2 import RESOURCE_DOES_NOT_EXIST, ErrorCode, RESOURCE_LIMIT_EXCEEDED @@ -290,6 +291,37 @@ def test_load_other_mlflow_err(self): dkeys = self.dkeys self.assertIsNone(ml.load(skeys=skeys, dkeys=dkeys)) + @patch("mlflow.pytorch.log_model", mock_log_model_pytorch()) + @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) + @patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict())) + @patch("mlflow.log_params", {"lr": 0.01}) + 
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) + @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) + @patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10))) + @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict())) + def test_is_model_stale_true(self): + model = self.model + ml = MLflowRegistry(TRACKING_URI, artifact_type="pytorch") + ml.save(skeys=self.skeys, dkeys=self.dkeys, artifact=model, **{"lr": 0.01}) + data = ml.load(skeys=self.skeys, dkeys=self.dkeys) + self.assertTrue(ml.is_artifact_stale(data, 12)) + + @patch("mlflow.pytorch.log_model", mock_log_model_pytorch()) + @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) + @patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict())) + @patch("mlflow.log_params", {"lr": 0.01}) + @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) + @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) + @patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10))) + @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict())) + def test_is_model_stale_false(self): + model = self.model + ml = MLflowRegistry(TRACKING_URI, artifact_type="pytorch") + ml.save(skeys=self.skeys, dkeys=self.dkeys, artifact=model, **{"lr": 0.01}) + data = ml.load(skeys=self.skeys, dkeys=self.dkeys) + with freeze_time("2022-05-24 10:30:00"): + self.assertFalse(ml.is_artifact_stale(data, 12)) + if __name__ == "__main__": unittest.main() diff --git a/tests/registry/test_redis_registry.py b/tests/registry/test_redis_registry.py index 25290e3d..61904bed 100644 --- a/tests/registry/test_redis_registry.py +++ b/tests/registry/test_redis_registry.py @@ -2,6 +2,7 @@ from unittest.mock import Mock, patch import fakeredis +from freezegun import freeze_time from redis 
import ConnectionError, InvalidResponse, TimeoutError from sklearn.ensemble import RandomForestRegressor from sklearn.preprocessing import StandardScaler @@ -69,6 +70,26 @@ def test_load_model_with_version(self): self.assertIsNone(data.metadata) self.assertEqual(data.extras["version"], version) + def test_check_if_model_stale_true(self): + with freeze_time("2023-05-08 12:30:00"): + self.registry.save(skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model) + data = self.registry.load(skeys=self.skeys, dkeys=self.dkeys) + self.assertTrue(self.registry.is_artifact_stale(data, 12)) + + def test_check_if_model_stale_false(self): + with freeze_time("2023-05-08 12:30:00"): + self.registry.save(skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model) + with freeze_time("2023-05-08 19:30:00"): + data = self.registry.load(skeys=self.skeys, dkeys=self.dkeys) + self.assertFalse(self.registry.is_artifact_stale(data, 8)) + + def test_check_if_model_stale_err(self): + self.registry.save(skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model) + data = self.registry.load(skeys=self.skeys, dkeys=self.dkeys) + data.extras = None + with self.assertRaises(RedisRegistryError): + self.registry.is_artifact_stale(data, 8) + def test_both_version_latest_model_with_version(self): with self.assertRaises(ValueError): self.registry.load(skeys=self.skeys, dkeys=self.dkeys, latest=False)