Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make labels optional in Datalab #730

Merged
merged 4 commits into from
Jun 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 72 additions & 17 deletions cleanlab/datalab/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""Classes and methods for datasets that are loaded into Datalab."""

import os
from typing import Any, Callable, Dict, List, Mapping, Tuple, Union, cast, TYPE_CHECKING
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union, cast, TYPE_CHECKING

try:
import datasets
Expand Down Expand Up @@ -126,13 +126,11 @@ class Data:
:py:class:`Datalab <cleanlab.datalab.datalab.Datalab>` to work.
"""

def __init__(self, data: "DatasetLike", label_name: str) -> None:
def __init__(self, data: "DatasetLike", label_name: Optional[str] = None) -> None:
self._validate_data(data)
self._label_name = label_name
self._data = self._load_data(data)
self._validate_data_and_labels(self._data, self._data[label_name])
self._data_hash = hash(self._data)
self._labels, self._label_map = _extract_labels(self._data, label_name)
self.labels = Label(data=self._data, label_name=label_name)

def _load_data(self, data: "DatasetLike") -> Dataset:
"""Checks the type of dataset and uses the correct loader method and
Expand All @@ -154,19 +152,22 @@ def __len__(self) -> int:
def __eq__(self, other) -> bool:
if isinstance(other, Data):
# Equality checks
hashes = self._data_hash == other._data_hash
labels = np.array_equal(self._labels, other._labels)
label_names = self._label_name == other._label_name
label_maps = self._label_map == other._label_map
return all([hashes, labels, label_names, label_maps])
hashes_are_equal = self._data_hash == other._data_hash
labels_are_equal = self.labels == other.labels
return all([hashes_are_equal, labels_are_equal])
return False

def __hash__(self) -> int:
return self._data_hash

@property
def class_names(self) -> list:
return list(self._label_map.values())
def class_names(self) -> List[str]:
return self.labels.class_names

@property
def has_labels(self) -> bool:
"""Check if labels are available."""
return self.labels.is_available

@staticmethod
def _validate_data(data) -> None:
Expand All @@ -175,11 +176,6 @@ def _validate_data(data) -> None:
if not isinstance(data, (Dataset, pd.DataFrame, dict, list, str)):
raise DataFormatError(data)

@staticmethod
def _validate_data_and_labels(data, labels) -> None:
assert isinstance(labels, (np.ndarray, list))
assert len(labels) == len(data)

@staticmethod
def _load_dataset_from_dict(data_dict: Dict[str, Any]) -> Dataset:
try:
Expand Down Expand Up @@ -218,6 +214,65 @@ def _load_dataset_from_string(data_string: str) -> Dataset:
return dataset_cast


class Label:
    """
    Class to represent labels in a dataset.

    Parameters
    ----------
    data :
        Dataset that may contain a column of labels.

    label_name :
        Name of the column in ``data`` that holds the labels.
        If ``None`` (the default), no labels are extracted: ``labels`` is the
        array built from an empty list and ``label_map`` is an empty mapping,
        so ``is_available`` is ``False``.

    Raises
    ------
    ValueError
        If ``label_name`` is given but is not one of ``data.column_names``.
    """

    def __init__(self, *, data: "Dataset", label_name: Optional[str] = None) -> None:
        self._data = data
        self.label_name = label_name
        self.labels: np.ndarray
        self.label_map: Mapping[str, Any]
        if label_name is None:
            # No label column requested: fall back to empty placeholders.
            self.labels = labels_to_array([])
            self.label_map = {}
        else:
            # Validate *before* extracting, so that a missing column is reported
            # through the descriptive ValueError below rather than through
            # whatever error the extraction step raises for an unknown column.
            self._validate_labels()
            self.labels, self.label_map = _extract_labels(data, label_name)

    def __len__(self) -> int:
        """Number of stored labels (0 when no labels are stored)."""
        if self.labels is None:
            return 0
        return len(self.labels)

    def __eq__(self, __value: object) -> bool:
        """Two ``Label`` objects are equal when their label arrays, label
        names and label maps are all equal. Any other object compares unequal.
        """
        if isinstance(__value, Label):
            labels_are_equal = np.array_equal(self.labels, __value.labels)
            names_are_equal = self.label_name == __value.label_name
            maps_are_equal = self.label_map == __value.label_map
            return all([labels_are_equal, names_are_equal, maps_are_equal])
        return False

    def __getitem__(self, __index: Union[int, slice, np.ndarray]) -> np.ndarray:
        """Index directly into the underlying label array."""
        return self.labels[__index]

    def __bool__(self) -> bool:
        """Truthiness mirrors :py:attr:`is_available`."""
        return self.is_available

    @property
    def class_names(self) -> List[str]:
        """A list of class names that are present in the dataset.

        Without labels, this will return an empty list.
        """
        return list(self.label_map.values())

    @property
    def is_available(self) -> bool:
        """Check if labels are available.

        Labels count as available only when both the label array and the
        label map are non-empty.
        """
        # ``len(...) == 0`` is used on purpose: the truth value of a
        # multi-element numpy array is ambiguous.
        empty_labels = self.labels is None or len(self.labels) == 0
        empty_label_map = self.label_map is None or len(self.label_map) == 0
        return not (empty_labels or empty_label_map)

    def _validate_labels(self) -> None:
        """Check that the label column exists and has one entry per example.

        Raises
        ------
        ValueError
            If ``self.label_name`` is not a column of the dataset.
        """
        if self.label_name not in self._data.column_names:
            raise ValueError(f"Label column '{self.label_name}' not found in dataset.")
        labels = self._data[self.label_name]
        # Internal sanity checks on the column that was read back.
        assert isinstance(labels, (np.ndarray, list))
        assert len(labels) == len(self._data)


def _extract_labels(data: Dataset, label_name: str) -> Tuple[np.ndarray, Mapping]:
"""
Picks out labels from the dataset and formats them to be [0, 1, ..., K-1]
Expand Down
38 changes: 29 additions & 9 deletions cleanlab/datalab/data_issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,10 @@ def __init__(self, data: Data) -> None:
self.issue_summary: pd.DataFrame = pd.DataFrame(
columns=["issue_type", "score", "num_issues"]
).astype({"score": np.float64, "num_issues": np.int64})
class_names = data.class_names
self.info: Dict[str, Dict[str, Any]] = {
"statistics": {
"num_examples": len(data),
"class_names": class_names,
"num_classes": len(class_names),
"multi_label": False,
"health_score": None,
},
"statistics": get_data_statistics(data),
}
self._label_map = data._label_map
self._label_map = data.labels.label_map

@property
def statistics(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -179,6 +172,11 @@ def get_info(self, issue_name: Optional[str] = None) -> Dict[str, Any]:
)
info = info.copy()
if issue_name == "label":
if self._label_map is None:
raise ValueError(
"The label map is not available. "
"Most likely, no label column was provided when creating the Data object."
)
# Labels that are stored as integers may need to be converted to strings.
for key in ["given_label", "predicted_label"]:
labels = info.get(key, None)
Expand Down Expand Up @@ -269,3 +267,25 @@ def set_health_score(self) -> None:
Currently, the health score is the mean of the scores for each issue type.
"""
self.info["statistics"]["health_score"] = self.issue_summary["score"].mean()


def get_data_statistics(data: "Data") -> Dict[str, Any]:
    """Build the initial "statistics" dictionary for a dataset.

    This function is called to initialize the "statistics" info in all `Datalab` objects.

    Parameters
    ----------
    data : Data
        Data object containing the dataset.

    Returns
    -------
    statistics :
        Always contains "num_examples", "multi_label" and "health_score".
        "class_names" and "num_classes" are added only when the data's labels
        are available.
    """
    base_stats: Dict[str, Any] = {
        "num_examples": len(data),
        "multi_label": False,
        "health_score": None,
    }
    # Without labels there is nothing class-related to report.
    if not data.labels.is_available:
        return base_stats

    names = data.class_names
    return {**base_stats, "class_names": names, "num_classes": len(names)}
21 changes: 15 additions & 6 deletions cleanlab/datalab/datalab.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,15 @@ class Datalab:
def __init__(
self,
data: "DatasetLike",
label_name: str,
label_name: Optional[str] = None,
verbosity: int = 1,
) -> None:
self._data = Data(data, label_name)
self.data = self._data._data
self._labels, self._label_map = self._data._labels, self._data._label_map
self._labels = self._data.labels
self._label_map = self._labels.label_map
self.label_name = self._labels.label_name
self._data_hash = self._data._data_hash
self.label_name = self._data._label_name
self.data_issues = DataIssues(self._data)
self.cleanlab_version = cleanlab.version.__version__
self.verbosity = verbosity
Expand All @@ -109,12 +110,20 @@ def __str__(self) -> str:
@property
def labels(self) -> np.ndarray:
"""Labels of the dataset, in a [0, 1, ..., K-1] format."""
return self._labels
return self._labels.labels

@property
def has_labels(self) -> bool:
"""Whether the dataset has labels."""
return self._labels.is_available

@property
def class_names(self) -> List[str]:
"""Names of the classes in the dataset."""
return self._data.class_names
"""Names of the classes in the dataset.

If the dataset has no labels, returns an empty list.
"""
return self._labels.class_names

def find_issues(
self,
Expand Down
18 changes: 18 additions & 0 deletions cleanlab/datalab/issue_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,4 +349,22 @@ def get_available_issue_types(self, **kwargs):
if issue in issue_types_copy
}

drop_label_check = "label" in issue_types_copy and not self.datalab.has_labels
if drop_label_check:
warnings.warn("No labels were provided. " "The 'label' issue type will not be run.")
issue_types_copy.pop("label")

outlier_check_needs_features = "outlier" in issue_types_copy and not self.datalab.has_labels
if outlier_check_needs_features:
no_features = features is None
no_knn_graph = knn_graph is None
pred_probs_given = issue_types_copy["outlier"].get("pred_probs", None) is not None

only_pred_probs_given = pred_probs_given and no_features and no_knn_graph
if only_pred_probs_given:
warnings.warn(
"No labels were provided. " "The 'outlier' issue type will not be run."
)
issue_types_copy.pop("outlier")

return issue_types_copy
8 changes: 4 additions & 4 deletions cleanlab/datalab/issue_manager/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _reset(self) -> None:
if not self.health_summary_parameters:
statistics_dict = self.datalab.get_info("statistics")
self.health_summary_parameters = {
"labels": self.datalab._labels,
"labels": self.datalab.labels,
"class_names": list(self.datalab._label_map.values()),
"num_examples": statistics_dict.get("num_examples"),
"joint": statistics_dict.get("joint", None),
Expand All @@ -124,7 +124,7 @@ def find_issues(
self.health_summary_parameters.update({"pred_probs": pred_probs})
# Find examples with label issues
self.issues = self.cl.find_label_issues(
labels=self.datalab._labels,
labels=self.datalab.labels,
pred_probs=pred_probs,
**self._process_find_label_issues_kwargs(kwargs),
)
Expand Down Expand Up @@ -180,7 +180,7 @@ def _get_summary_parameters(self, pred_probs) -> Dict["str", Any]:
else:
summary_parameters = {
"pred_probs": pred_probs,
"labels": self.datalab._labels,
"labels": self.datalab.labels,
}

summary_parameters["class_names"] = self.health_summary_parameters["class_names"]
Expand Down Expand Up @@ -223,4 +223,4 @@ def collect_info(self, issues: pd.DataFrame, summary_dict: dict) -> dict:
return info_dict

def _validate_pred_probs(self, pred_probs) -> None:
assert_valid_inputs(X=None, y=self.datalab._labels, pred_probs=pred_probs)
assert_valid_inputs(X=None, y=self.datalab.labels, pred_probs=pred_probs)
2 changes: 1 addition & 1 deletion cleanlab/datalab/issue_manager/outlier.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def _build_statistics_dictionary(
def _score_with_pred_probs(self, pred_probs: np.ndarray, **kwargs) -> np.ndarray:
# Remove "threshold" from kwargs if it exists
kwargs.pop("threshold", None)
scores = self.ood.fit_score(pred_probs=pred_probs, labels=self.datalab._labels, **kwargs)
scores = self.ood.fit_score(pred_probs=pred_probs, labels=self.datalab.labels, **kwargs)
return scores

def _score_with_features(self, features: npt.NDArray, **kwargs) -> npt.NDArray:
Expand Down
6 changes: 3 additions & 3 deletions tests/datalab/issue_manager/test_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,6 @@ def test_get_summary_parameters(self, issue_manager, monkeypatch):
"confident_joint": [1 / 3, 1 / 3, 1 / 3],
"multi_label": False,
}
monkeypatch.setattr(
issue_manager.datalab, "_labels", mock_health_summary_parameters["labels"]
)
pred_probs = np.random.rand(3, 3)
monkeypatch.setattr(
issue_manager, "health_summary_parameters", mock_health_summary_parameters
Expand Down Expand Up @@ -89,6 +86,9 @@ def test_get_summary_parameters(self, issue_manager, monkeypatch):

# Test missing "joint" key
mock_health_summary_parameters.pop("joint")
monkeypatch.setattr(
issue_manager.datalab._labels, "labels", mock_health_summary_parameters["labels"]
)
monkeypatch.setattr(
issue_manager, "health_summary_parameters", mock_health_summary_parameters
)
Expand Down
10 changes: 5 additions & 5 deletions tests/datalab/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ def test_init_data_properties(self, dataset):
assert data._data == dataset

# All elements in the _labels attribute are integers in the range [0, num_classes - 1]
num_classes = len(set(data._label_map))
all_labels_are_ints = np.issubdtype(data._labels.dtype, np.integer)
assert all_labels_are_ints, f"{data._labels} should be a list of integers"
assert all(0 <= label < num_classes for label in data._labels)
num_classes = len(set(data.labels.label_map))
all_labels_are_ints = np.issubdtype(data.labels.labels.dtype, np.integer)
assert all_labels_are_ints, f"{data.labels.labels} should be a list of integers"
assert all(0 <= label < num_classes for label in data.labels.labels)

assert all(isinstance(label, int) for label in data._label_map.keys())
assert all(isinstance(label, int) for label in data.labels.label_map.keys())

def test_init_data(self, dataset_and_label_name):
dataset, label_name = dataset_and_label_name
Expand Down
Loading