Skip to content

Commit

Permalink
Adjust ZenBytes to zenml version 0.13.0.
Browse files Browse the repository at this point in the history
  • Loading branch information
fa9r committed Aug 23, 2022
1 parent 984899d commit bcb60c9
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 9 deletions.
25 changes: 18 additions & 7 deletions 1-1_Pipelines.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,17 @@
"import numpy as np\n",
"from sklearn.base import ClassifierMixin\n",
"from sklearn.svm import SVC\n",
"from zenml.integrations.sklearn.helpers.digits import get_digits\n",
"from sklearn.datasets import load_digits\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"\n",
"def train_test() -> None:\n",
" \"\"\"Train and test a Scikit-learn SVC classifier on digits\"\"\"\n",
" X_train, X_test, y_train, y_test = get_digits()\n",
" digits = load_digits()\n",
" data = digits.images.reshape((len(digits.images), -1))\n",
" X_train, X_test, y_train, y_test = train_test_split(\n",
" data, digits.target, test_size=0.2, shuffle=False\n",
" )\n",
" model = SVC(gamma=0.001)\n",
" model.fit(X_train, y_train)\n",
" test_acc = model.score(X_test, y_test)\n",
Expand Down Expand Up @@ -136,7 +141,11 @@
" y_test=np.ndarray,\n",
"):\n",
" \"\"\"Load the digits dataset as numpy arrays.\"\"\"\n",
" X_train, X_test, y_train, y_test = get_digits()\n",
" digits = load_digits()\n",
" data = digits.images.reshape((len(digits.images), -1))\n",
" X_train, X_test, y_train, y_test = train_test_split(\n",
" data, digits.target, test_size=0.2, shuffle=False\n",
" )\n",
" return X_train, X_test, y_train, y_test\n",
"\n",
"\n",
Expand Down Expand Up @@ -220,11 +229,8 @@
}
],
"metadata": {
"interpreter": {
"hash": "a35bb4b4bceaf970a493ff7351e9d97180ab3fe9951c21e9e29c55a687242182"
},
"kernelspec": {
"display_name": "Python 3.8.13 64-bit ('zenbytes-latest')",
"display_name": "Python 3.8.13 64-bit ('zenbytes')",
"language": "python",
"name": "python3"
},
Expand All @@ -239,6 +245,11 @@
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"vscode": {
"interpreter": {
"hash": "9f70ec6e6bd16014ded89c8222361cbe53cd9507d51ebdcdf3ab6e494d45cf74"
}
}
},
"nbformat": 4,
Expand Down
9 changes: 7 additions & 2 deletions steps/importer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pandas as pd
from zenml.integrations.sklearn.helpers.digits import get_digits
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from zenml.steps import Output, step


Expand All @@ -9,7 +10,11 @@ def importer() -> Output(
X_train=np.ndarray, X_test=np.ndarray, y_train=np.ndarray, y_test=np.ndarray
):
"""Loads the digits array as normal numpy arrays."""
X_train, X_test, y_train, y_test = get_digits()
digits = load_digits()
data = digits.images.reshape((len(digits.images), -1))
X_train, X_test, y_train, y_test = train_test_split(
data, digits.target, test_size=0.2, shuffle=False
)
return X_train, X_test, y_train, y_test


Expand Down

0 comments on commit bcb60c9

Please sign in to comment.