Skip to content
This repository has been archived by the owner on Feb 15, 2023. It is now read-only.

Commit

Permalink
Dataset label visualisation (facebookresearch#183)
Browse files Browse the repository at this point in the history
Summary:
Utilities to use in notebooks to show the statistics on the labels of a dataset + histogram visualisation

Pull Request resolved: fairinternal/ssl_scaling#183

Reviewed By: prigoyal

Differential Revision: D30433880

Pulled By: QuentinDuval

fbshipit-source-id: e899f485687a6b0ef4797778334636896fdb4212
  • Loading branch information
QuentinDuval authored and facebook-github-bot committed Sep 30, 2021
1 parent a3aa3e4 commit 90204a5
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
4 changes: 4 additions & 0 deletions vissl/utils/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
63 changes: 63 additions & 0 deletions vissl/utils/visualization/dataset_statistics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List, Tuple

import numpy as np


class LabelStatistics:
"""
Useful statistics and visualisations to explore a dataset
"""

@classmethod
def label_statistics(cls, labels: List[int]) -> Dict[str, int]:
counter = {}
for label in labels:
counter[label] = counter.get(label, 0) + 1
counts = list(counter.values())
return {
"min": int(np.min(counts)),
"max": int(np.max(counts)),
"mean": int(np.mean(counts)),
"median": int(np.median(counts)),
"std": int(np.std(counts)),
"percentile_5": int(np.percentile(counts, 5)),
"percentile_95": int(np.percentile(counts, 95)),
}

@classmethod
def label_histogram(
cls, labels: List[int], figsize: Tuple[int, int] = (20, 8)
) -> None:
"""
Compute and display some statistics about labels:
- number of samples associated to each label
- histogram of the number of samples by label
"""
import matplotlib.pyplot as plt

histogram = cls.compute_histogram(labels)
histogram = sorted(histogram.items())
xs = [x for x, _ in histogram]
ys = [y for _, y in histogram]
plt.figure(figsize=figsize)
plt.bar(xs, ys)
plt.show()

@staticmethod
def compute_histogram(labels: List[int]) -> Dict[int, int]:
# How many samples assigned to each label
counter = {}
for label in labels:
counter[label] = counter.get(label, 0) + 1
counts = list(counter.values())

# Histogram of number of samples by centroids
histogram = {}
for count in counts:
histogram[count] = histogram.get(count, 0) + 1
return histogram

0 comments on commit 90204a5

Please sign in to comment.