Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor KNN graph handling and outlier detection in issue managers #1155

Merged
merged 14 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
address mypy and flake8 issues
  • Loading branch information
elisno committed Jun 20, 2024
commit fb775e3e4374ec29f151d296683498e6d193303f
9 changes: 5 additions & 4 deletions cleanlab/datalab/internal/issue_manager/knn_graph_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from scipy.sparse import csr_matrix


from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, cast

from cleanlab.internal.neighbor.knn_graph import create_knn_graph_and_index
from cleanlab.typing import Metric


def num_neighbors_in_knn_graph(knn_graph: csr_matrix) -> int:
Expand Down Expand Up @@ -44,10 +45,10 @@ def knn_exists(kwargs: Dict[str, Any], statistics: Dict[str, Any], k_needed: int
def set_knn_graph(
features: Optional[npt.NDArray],
find_issues_kwargs: Dict[str, Any],
metric: Optional[str],
metric: Optional[Metric],
k: int,
statistics: Dict[str, Any],
) -> Tuple[csr_matrix, str]:
) -> Tuple[csr_matrix, Metric]:
# This only fetches graph (optionally)
knn_graph = _process_knn_graph_from_inputs(
find_issues_kwargs, statistics, k_for_recomputation=k
Expand All @@ -60,4 +61,4 @@ def set_knn_graph(
assert features is not None, "Features must be provided to compute the knn graph."
knn_graph, knn = create_knn_graph_and_index(features, n_neighbors=k, metric=metric)
metric = knn.metric
return knn_graph, metric
return cast(csr_matrix, knn_graph), cast(Metric, metric)
1 change: 0 additions & 1 deletion tests/datalab/datalab/test_datalab.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
from cleanlab.datalab.datalab import Datalab
from cleanlab.datalab.internal.report import Reporter
from cleanlab.datalab.internal.task import Task
from cleanlab.internal.neighbor.knn_graph import create_knn_graph_and_index


SEED = 42
Expand Down
Loading