Skip to content

Commit

Permalink
update compat
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez committed Dec 6, 2024
1 parent befa507 commit 6901257
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 18 deletions.
25 changes: 16 additions & 9 deletions mlforecast/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,26 @@ def __init__(self, *args, **kwargs): # noqa: ARG002
raise ImportError("Please install lightgbm to use this model.")


try:
from window_ops.shift import shift_array
except ImportError:
import numpy as np

def shift_array(x, offset): # noqa: ARG002
return np.hstack([np.full(offset, np.nan), x[:-offset]])


try:
from xgboost import XGBRegressor
except ImportError:

class XGBRegressor:
def __init__(self, *args, **kwargs): # noqa: ARG002
raise ImportError("Please install xgboost to use this model.")


try:
from window_ops.shift import shift_array
except ImportError:
import numpy as np
from utilsforecast.compat import njit

@njit
def shift_array(x, offset):
if offset >= x.size or offset < 0:
return np.full_like(x, np.nan)
out = np.empty_like(x)
out[:offset] = np.nan
out[offset:] = x[:-offset]
return out
25 changes: 16 additions & 9 deletions nbs/compat.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,28 @@
" )\n",
"\n",
"try:\n",
" from window_ops.shift import shift_array\n",
"except ImportError:\n",
" import numpy as np\n",
"\n",
" def shift_array(x, offset): # noqa: ARG002\n",
" return np.hstack([np.full(offset, np.nan), x[:-offset]])\n",
"\n",
"try:\n",
" from xgboost import XGBRegressor\n",
"except ImportError:\n",
" class XGBRegressor:\n",
" def __init__(self, *args, **kwargs): # noqa: ARG002\n",
" raise ImportError(\n",
" \"Please install xgboost to use this model.\"\n",
" )"
" )\n",
"\n",
"try:\n",
" from window_ops.shift import shift_array\n",
"except ImportError:\n",
" import numpy as np\n",
" from utilsforecast.compat import njit\n",
"\n",
" @njit\n",
" def shift_array(x, offset):\n",
" if offset >= x.size or offset < 0:\n",
" return np.full_like(x, np.nan)\n",
" out = np.empty_like(x)\n",
" out[:offset] = np.nan\n",
" out[offset:] = x[:-offset]\n",
" return out"
]
}
],
Expand Down

0 comments on commit 6901257

Please sign in to comment.