Created July 24, 2023 10:57
-
-
Save OmarManzoor/84c8695cd6f5165251b42e3df7e1a5ca to your computer and use it in GitHub Desktop.
RadiusNeighborsClassMode Updated Benchmark
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from scipy.spatial.distance import cdist | |
from scipy.sparse import rand as sparse_rand | |
from .common import Benchmark | |
from sklearn.neighbors import RadiusNeighborsClassifier | |
# scipy's ``cdist`` spells some metrics differently from scikit-learn. Map the
# scikit-learn names that may appear in ``params`` to their scipy equivalents,
# so the radius is calibrated with the same metric the classifier uses.
_SCIPY_METRIC_NAMES = {
    "manhattan": "cityblock",
    "l1": "cityblock",
    "l2": "euclidean",
}


def _make_data(n_samples, n_features, fmt, dtype, rng):
    """Return an ``(n_samples, n_features)`` random matrix in format ``fmt``.

    ``fmt == "dense"`` yields a NumPy array of ``rng.rand`` values cast to
    ``dtype``. Any other value yields a CSR sparse matrix with density 0.05,
    drawn with the same ``rng``.
    """
    if fmt == "dense":
        return rng.rand(n_samples, n_features).astype(dtype)
    return sparse_rand(
        n_samples,
        n_features,
        density=0.05,
        format="csr",
        dtype=dtype,
        random_state=rng,
    )


def _leading_rows_dense(X, fmt, n_rows):
    """Return the first ``n_rows`` rows of ``X`` as a dense array.

    Rows are sliced before densifying, so only the needed rows of a sparse
    input are materialized.
    """
    head = X[:n_rows]
    return head if fmt == "dense" else head.toarray()


class BruteRadiusNeighborsClass(Benchmark):
    """Time ``RadiusNeighborsClassifier.predict_proba`` with ``algorithm="brute"``.

    Benchmark parameters, in ``param_names`` order:

    - ``n_train``, ``n_test``: number of training / test samples.
    - ``n_features``: number of features per sample.
    - ``metric``: distance metric name passed to the classifier.
    - ``dtype``: dtype of the generated feature matrices.
    - ``X_train``, ``X_test``: ``"dense"`` for NumPy arrays; any other value
      produces CSR sparse matrices.
    """

    param_names = [
        "n_train",
        "n_test",
        "n_features",
        "metric",
        "dtype",
        "X_train",
        "X_test",
    ]
    params = [
        [1000, 10_000],
        [1000, 10_000],
        [100],
        ["manhattan"],
        [np.float64],
        ["dense"],
        ["dense"],
    ]

    def setup(self, *params):
        """Generate data, pick a radius and fit the classifier being timed."""
        n_train, n_test, n_features, metric, dtype, X_train, X_test = params
        rng = np.random.RandomState(0)

        # Train data is drawn before test data from the same seeded ``rng``,
        # so the order of these calls determines the generated values.
        self.X_train = _make_data(n_train, n_features, X_train, dtype, rng)
        self.X_test = _make_data(n_test, n_features, X_test, dtype, rng)
        # ``high`` is exclusive for randint, so labels are drawn from {-1, 0}.
        self.y_train = rng.randint(low=-1, high=1, size=(n_train,))
        self.metric = metric

        # Estimate the radius from a small train/test subset, measured with
        # the same metric the classifier will use.
        dist_mat = cdist(
            _leading_rows_dense(self.X_train, X_train, 1000),
            _leading_rows_dense(self.X_test, X_test, 10),
            metric=_SCIPY_METRIC_NAMES.get(metric, metric),
        )
        # The 1% distance quantile is scaled up by 20, presumably so that
        # every test point has neighbours (RadiusNeighborsClassifier errors on
        # neighbour-less points when no outlier_label is set).
        # NOTE(review): confirm the intended neighbourhood size.
        self.radius = np.quantile(a=dist_mat.ravel(), q=0.01) * 20

        self.rc = RadiusNeighborsClassifier(
            radius=self.radius,
            algorithm="brute",
            metric=self.metric,
        )
        self.rc.fit(X=self.X_train, y=self.y_train)

    def time_predict_proba(self, *params):
        """Time class-probability prediction on the test set."""
        self.rc.predict_proba(self.X_test)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.