This repository has been archived by the owner on Feb 15, 2023. It is now read-only.
forked from facebookresearch/vissl
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Dataset label visualisation (facebookresearch#183)
Summary: Utilities to use in notebooks to show the statistics on the labels of a dataset + histogram visualisation Pull Request resolved: fairinternal/ssl_scaling#183 Reviewed By: prigoyal Differential Revision: D30433880 Pulled By: QuentinDuval fbshipit-source-id: e899f485687a6b0ef4797778334636896fdb4212
- Loading branch information
1 parent
a3aa3e4
commit 90204a5
Showing
2 changed files
with
67 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
|
||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
|
||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from typing import Dict, List, Tuple | ||
|
||
import numpy as np | ||
|
||
|
||
class LabelStatistics: | ||
""" | ||
Useful statistics and visualisations to explore a dataset | ||
""" | ||
|
||
@classmethod | ||
def label_statistics(cls, labels: List[int]) -> Dict[str, int]: | ||
counter = {} | ||
for label in labels: | ||
counter[label] = counter.get(label, 0) + 1 | ||
counts = list(counter.values()) | ||
return { | ||
"min": int(np.min(counts)), | ||
"max": int(np.max(counts)), | ||
"mean": int(np.mean(counts)), | ||
"median": int(np.median(counts)), | ||
"std": int(np.std(counts)), | ||
"percentile_5": int(np.percentile(counts, 5)), | ||
"percentile_95": int(np.percentile(counts, 95)), | ||
} | ||
|
||
@classmethod | ||
def label_histogram( | ||
cls, labels: List[int], figsize: Tuple[int, int] = (20, 8) | ||
) -> None: | ||
""" | ||
Compute and display some statistics about labels: | ||
- number of samples associated to each label | ||
- histogram of the number of samples by label | ||
""" | ||
import matplotlib.pyplot as plt | ||
|
||
histogram = cls.compute_histogram(labels) | ||
histogram = sorted(histogram.items()) | ||
xs = [x for x, _ in histogram] | ||
ys = [y for _, y in histogram] | ||
plt.figure(figsize=figsize) | ||
plt.bar(xs, ys) | ||
plt.show() | ||
|
||
@staticmethod | ||
def compute_histogram(labels: List[int]) -> Dict[int, int]: | ||
# How many samples assigned to each label | ||
counter = {} | ||
for label in labels: | ||
counter[label] = counter.get(label, 0) + 1 | ||
counts = list(counter.values()) | ||
|
||
# Histogram of number of samples by centroids | ||
histogram = {} | ||
for count in counts: | ||
histogram[count] = histogram.get(count, 0) + 1 | ||
return histogram |