Skip to content

Commit

Permalink
FIX Update pairwise distance function argument names (#26351)
Browse files Browse the repository at this point in the history
  • Loading branch information
Micky774 authored Jul 25, 2023
1 parent 07f6586 commit 59048f9
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ cdef class ArgKminClassMode{{name_suffix}}(ArgKmin{{name_suffix}}):
{{name_suffix}}bit implementation of ArgKminClassMode.
"""
cdef:
const intp_t[:] class_membership,
const intp_t[:] unique_labels
const intp_t[:] Y_labels,
const intp_t[:] unique_Y_labels
float64_t[:, :] class_scores
cpp_map[intp_t, intp_t] labels_to_index
WeightingStrategy weight_type
Expand All @@ -38,14 +38,14 @@ cdef class ArgKminClassMode{{name_suffix}}(ArgKmin{{name_suffix}}):
Y,
intp_t k,
weights,
class_membership,
unique_labels,
Y_labels,
unique_Y_labels,
str metric="euclidean",
chunk_size=None,
dict metric_kwargs=None,
str strategy=None,
):
"""Compute the argkmin reduction with class_membership.
"""Compute the argkmin reduction with Y_labels.

This classmethod is responsible for introspecting the arguments
values to dispatch to the most appropriate implementation of
Expand All @@ -66,8 +66,8 @@ cdef class ArgKminClassMode{{name_suffix}}(ArgKmin{{name_suffix}}):
chunk_size=chunk_size,
strategy=strategy,
weights=weights,
class_membership=class_membership,
unique_labels=unique_labels,
Y_labels=Y_labels,
unique_Y_labels=unique_Y_labels,
)

# Limit the number of threads in second level of nested parallelism for BLAS
Expand All @@ -83,8 +83,8 @@ cdef class ArgKminClassMode{{name_suffix}}(ArgKmin{{name_suffix}}):
def __init__(
self,
DatasetsPair{{name_suffix}} datasets_pair,
const intp_t[:] class_membership,
const intp_t[:] unique_labels,
const intp_t[:] Y_labels,
const intp_t[:] unique_Y_labels,
chunk_size=None,
strategy=None,
intp_t k=1,
Expand All @@ -103,15 +103,15 @@ cdef class ArgKminClassMode{{name_suffix}}(ArgKmin{{name_suffix}}):
self.weight_type = WeightingStrategy.distance
else:
self.weight_type = WeightingStrategy.callable
self.class_membership = class_membership
self.Y_labels = Y_labels

self.unique_labels = unique_labels
self.unique_Y_labels = unique_Y_labels

cdef intp_t idx, neighbor_class_idx
# Map from set of unique labels to their indices in `class_scores`
# Buffer used in building a histogram for one-pass weighted mode
self.class_scores = np.zeros(
(self.n_samples_X, unique_labels.shape[0]), dtype=np.float64,
(self.n_samples_X, unique_Y_labels.shape[0]), dtype=np.float64,
)

def _finalize_results(self):
Expand Down Expand Up @@ -142,7 +142,7 @@ cdef class ArgKminClassMode{{name_suffix}}(ArgKmin{{name_suffix}}):
if use_distance_weighting:
score_incr = 1 / distances[neighbor_rank]
neighbor_idx = indices[neighbor_rank]
neighbor_class_idx = self.class_membership[neighbor_idx]
neighbor_class_idx = self.Y_labels[neighbor_idx]
self.class_scores[sample_index][neighbor_class_idx] += score_incr
return

Expand Down
26 changes: 13 additions & 13 deletions sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ def is_usable_for(cls, X, Y, metric) -> bool:
The input array to be labelled.
Y : ndarray of shape (n_samples_Y, n_features)
The input array whose labels are provided through the `labels`
The input array whose labels are provided through the `Y_labels`
parameter.
metric : str, default='euclidean'
Expand Down Expand Up @@ -484,8 +484,8 @@ def compute(
Y,
k,
weights,
labels,
unique_labels,
Y_labels,
unique_Y_labels,
metric="euclidean",
chunk_size=None,
metric_kwargs=None,
Expand All @@ -499,23 +499,23 @@ def compute(
The input array to be labelled.
Y : ndarray of shape (n_samples_Y, n_features)
The input array whose labels are provided through the `labels`
parameter.
The input array whose class memberships are provided through the
`Y_labels` parameter.
k : int
The number of nearest neighbors to consider.
weights : ndarray
The weights applied over the `labels` of `Y` when computing the
The weights applied over the `Y_labels` of `Y` when computing the
weighted mode of the labels.
class_membership : ndarray
Y_labels : ndarray
An array containing the index of the class membership of the
associated samples in `Y`. This is used in labeling `X`.
unique_classes : ndarray
unique_Y_labels : ndarray
An array containing all unique indices contained in the
corresponding `class_membership` array.
corresponding `Y_labels` array.
metric : str, default='euclidean'
The distance metric to use. For a list of available metrics, see
Expand Down Expand Up @@ -587,8 +587,8 @@ def compute(
Y=Y,
k=k,
weights=weights,
class_membership=np.array(labels, dtype=np.intp),
unique_labels=np.array(unique_labels, dtype=np.intp),
Y_labels=np.array(Y_labels, dtype=np.intp),
unique_Y_labels=np.array(unique_Y_labels, dtype=np.intp),
metric=metric,
chunk_size=chunk_size,
metric_kwargs=metric_kwargs,
Expand All @@ -601,8 +601,8 @@ def compute(
Y=Y,
k=k,
weights=weights,
class_membership=np.array(labels, dtype=np.intp),
unique_labels=np.array(unique_labels, dtype=np.intp),
Y_labels=np.array(Y_labels, dtype=np.intp),
unique_Y_labels=np.array(unique_Y_labels, dtype=np.intp),
metric=metric,
chunk_size=chunk_size,
metric_kwargs=metric_kwargs,
Expand Down
48 changes: 24 additions & 24 deletions sklearn/metrics/tests/test_pairwise_distances_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,8 +649,8 @@ def test_argkmin_classmode_factory_method_wrong_usages():
metric = "manhattan"

weights = "uniform"
labels = rng.randint(low=0, high=10, size=100)
unique_labels = np.unique(labels)
Y_labels = rng.randint(low=0, high=10, size=100)
unique_Y_labels = np.unique(Y_labels)

msg = (
"Only float64 or float32 datasets pairs are supported at this time, "
Expand All @@ -663,8 +663,8 @@ def test_argkmin_classmode_factory_method_wrong_usages():
k=k,
metric=metric,
weights=weights,
labels=labels,
unique_labels=unique_labels,
Y_labels=Y_labels,
unique_Y_labels=unique_Y_labels,
)

msg = (
Expand All @@ -678,8 +678,8 @@ def test_argkmin_classmode_factory_method_wrong_usages():
k=k,
metric=metric,
weights=weights,
labels=labels,
unique_labels=unique_labels,
Y_labels=Y_labels,
unique_Y_labels=unique_Y_labels,
)

with pytest.raises(ValueError, match="k == -1, must be >= 1."):
Expand All @@ -689,8 +689,8 @@ def test_argkmin_classmode_factory_method_wrong_usages():
k=-1,
metric=metric,
weights=weights,
labels=labels,
unique_labels=unique_labels,
Y_labels=Y_labels,
unique_Y_labels=unique_Y_labels,
)

with pytest.raises(ValueError, match="k == 0, must be >= 1."):
Expand All @@ -700,8 +700,8 @@ def test_argkmin_classmode_factory_method_wrong_usages():
k=0,
metric=metric,
weights=weights,
labels=labels,
unique_labels=unique_labels,
Y_labels=Y_labels,
unique_Y_labels=unique_Y_labels,
)

with pytest.raises(ValueError, match="Unrecognized metric"):
Expand All @@ -711,8 +711,8 @@ def test_argkmin_classmode_factory_method_wrong_usages():
k=k,
metric="wrong metric",
weights=weights,
labels=labels,
unique_labels=unique_labels,
Y_labels=Y_labels,
unique_Y_labels=unique_Y_labels,
)

with pytest.raises(
Expand All @@ -724,8 +724,8 @@ def test_argkmin_classmode_factory_method_wrong_usages():
k=k,
metric=metric,
weights=weights,
labels=labels,
unique_labels=unique_labels,
Y_labels=Y_labels,
unique_Y_labels=unique_Y_labels,
)

with pytest.raises(ValueError, match="ndarray is not C-contiguous"):
Expand All @@ -735,8 +735,8 @@ def test_argkmin_classmode_factory_method_wrong_usages():
k=k,
metric=metric,
weights=weights,
labels=labels,
unique_labels=unique_labels,
Y_labels=Y_labels,
unique_Y_labels=unique_Y_labels,
)

non_existent_weights_strategy = "non_existent_weights_strategy"
Expand All @@ -751,8 +751,8 @@ def test_argkmin_classmode_factory_method_wrong_usages():
k=k,
metric=metric,
weights=non_existent_weights_strategy,
labels=labels,
unique_labels=unique_labels,
Y_labels=Y_labels,
unique_Y_labels=unique_Y_labels,
)

# TODO: introduce assertions on UserWarnings once the Euclidean specialisation
Expand Down Expand Up @@ -1332,16 +1332,16 @@ def test_argkmin_classmode_strategy_consistent():
metric = "manhattan"

weights = "uniform"
labels = rng.randint(low=0, high=10, size=100)
unique_labels = np.unique(labels)
Y_labels = rng.randint(low=0, high=10, size=100)
unique_Y_labels = np.unique(Y_labels)
results_X = ArgKminClassMode.compute(
X=X,
Y=Y,
k=k,
metric=metric,
weights=weights,
labels=labels,
unique_labels=unique_labels,
Y_labels=Y_labels,
unique_Y_labels=unique_Y_labels,
strategy="parallel_on_X",
)
results_Y = ArgKminClassMode.compute(
Expand All @@ -1350,8 +1350,8 @@ def test_argkmin_classmode_strategy_consistent():
k=k,
metric=metric,
weights=weights,
labels=labels,
unique_labels=unique_labels,
Y_labels=Y_labels,
unique_Y_labels=unique_Y_labels,
strategy="parallel_on_Y",
)
assert_array_equal(results_X, results_Y)
4 changes: 2 additions & 2 deletions sklearn/neighbors/_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,8 @@ def predict_proba(self, X):
self._fit_X,
k=self.n_neighbors,
weights=self.weights,
labels=self._y,
unique_labels=self.classes_,
Y_labels=self._y,
unique_Y_labels=self.classes_,
metric=metric,
metric_kwargs=metric_kwargs,
# `strategy="parallel_on_X"` has in practice been shown
Expand Down

0 comments on commit 59048f9

Please sign in to comment.