Skip to content

Commit

Permalink
Add utils
Browse files Browse the repository at this point in the history
Add distributions
  • Loading branch information
nickovchinnikov committed Feb 12, 2024
1 parent 0e634e7 commit 4e5daf8
Show file tree
Hide file tree
Showing 11 changed files with 401 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/models/tts/styledtts2/diffusion/distributions.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: models.tts.styledtts2.diffusion.distributions
1 change: 1 addition & 0 deletions docs/models/tts/styledtts2/diffusion/utils.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: models.tts.styledtts2.diffusion.utils
4 changes: 4 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ nav:
- Reference Encoder: models/tts/delightful_tts/reference_encoder/reference_encoder.md
- Utterance Level Prosody Encoder: models/tts/delightful_tts/reference_encoder/utterance_level_prosody_encoder.md
- Phoneme Level Prosody Encoder: models/tts/delightful_tts/reference_encoder/phoneme_level_prosody_encoder.md
- StyledTTS 2:
- Diffusion:
- Distributions: models/tts/styledtts2/diffusion/distributions.md
- Utils: models/tts/styledtts2/diffusion/utils.md
- Vocoder:
- Univnet:
- References: models/vocoder/univnet/readme.md
Expand Down
Empty file.
Empty file.
102 changes: 102 additions & 0 deletions models/tts/styledtts2/diffusion/distributions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from math import atan, pi

import torch
from torch import Tensor

""" Distributions """


class Distribution:
    r"""Abstract base class for sampling distributions."""

    def __call__(self, num_samples: int) -> Tensor:
        r"""Draw ``num_samples`` samples from the distribution.

        Args:
            num_samples (int): How many samples to draw.

        Raises:
            NotImplementedError: Subclasses must override this method.
        """
        raise NotImplementedError


class LogNormalDistribution(Distribution):
    r"""Distribution whose samples have a normally distributed logarithm."""

    def __init__(self, mean: float, std: float):
        r"""Store the parameters of the underlying normal distribution.

        Args:
            mean (float): Mean of the underlying normal distribution.
            std (float): Standard deviation of the underlying normal distribution.
        """
        self.mean = mean
        self.std = std

    def __call__(
        self, num_samples: int,
    ) -> Tensor:
        r"""Draw ``num_samples`` log-normal samples.

        Args:
            num_samples (int): How many samples to draw.

        Returns:
            Tensor: Strictly positive samples of shape ``(num_samples,)``.
        """
        gaussian = torch.randn((num_samples,)) * self.std + self.mean
        return torch.exp(gaussian)


class UniformDistribution(Distribution):
    r"""Distribution sampling uniformly from the half-open interval ``[0, 1)``."""

    def __call__(self, num_samples: int):
        r"""Draw ``num_samples`` uniform samples.

        Args:
            num_samples (int): How many samples to draw.

        Returns:
            Tensor: Samples in ``[0, 1)`` of shape ``(num_samples,)``.
        """
        samples = torch.rand(num_samples)
        return samples


class VKDistribution(Distribution):
    r"""VK distribution sampled by inverse-CDF of an arctangent CDF.

    Samples are produced by drawing uniformly between the CDF values of
    ``min_value`` and ``max_value`` and mapping back through the inverse CDF
    ``sigma = tan(u * pi / 2) * sigma_data``, so every sample lies within
    ``[min_value, max_value)``.
    """

    def __init__(
        self,
        min_value: float = 0.0,
        max_value: float = float("inf"),
        sigma_data: float = 1.0,
    ):
        r"""Initialize the distribution with a minimum value, maximum value, and sigma data.

        Args:
            min_value (float): The minimum value for the inverse CDF. Defaults to 0.0.
            max_value (float): The maximum value for the inverse CDF. Defaults to infinity.
            sigma_data (float): The sigma data of the VK distribution. Defaults to 1.0.
        """
        self.min_value = min_value
        self.max_value = max_value
        self.sigma_data = sigma_data

    def __call__(
        self, num_samples: int,
    ) -> Tensor:
        r"""Generate a number of samples from the VK distribution.

        Args:
            num_samples (int): The number of samples to generate.

        Returns:
            Tensor: Samples of shape ``(num_samples,)`` lying in
            ``[min_value, max_value)``.
        """
        sigma_data = self.sigma_data
        min_cdf = atan(self.min_value / sigma_data) * 2 / pi
        max_cdf = atan(self.max_value / sigma_data) * 2 / pi
        # BUG FIX: inverse-CDF sampling requires a *uniform* draw in
        # [min_cdf, max_cdf). The original used torch.randn (standard normal),
        # which is unbounded, so u could land outside [min_cdf, max_cdf) and
        # the returned samples violated the min_value/max_value bounds.
        u = (max_cdf - min_cdf) * torch.rand((num_samples,)) + min_cdf
        return torch.tan(u * pi / 2) * sigma_data
Empty file.
38 changes: 38 additions & 0 deletions models/tts/styledtts2/diffusion/tests/test_distributions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import unittest

import torch

from models.tts.styledtts2.diffusion.distributions import (
Distribution,
LogNormalDistribution,
UniformDistribution,
VKDistribution,
)


class TestDistributions(unittest.TestCase):
    def test_distribution(self):
        # The abstract base class must refuse direct sampling.
        base = Distribution()
        with self.assertRaises(NotImplementedError):
            base(10)

    def test_log_normal_distribution(self):
        samples = LogNormalDistribution(mean=0.0, std=1.0)(10)
        self.assertEqual(samples.shape, (10,))
        # exp() of a real-valued draw is always strictly positive.
        self.assertTrue(torch.all(samples > 0))

    def test_uniform_distribution(self):
        samples = UniformDistribution()(10)
        self.assertEqual(samples.shape, (10,))
        in_unit_interval = (samples >= 0) & (samples < 1)
        self.assertTrue(torch.all(in_unit_interval))

    def test_vk_distribution(self):
        samples = VKDistribution(min_value=0.0, max_value=1.0, sigma_data=1.0)(10)
        self.assertEqual(samples.shape, (10,))
        # Output range intentionally unchecked here: VKDistribution does not
        # guarantee a specific range of output values.

if __name__ == "__main__":
    unittest.main()
75 changes: 75 additions & 0 deletions models/tts/styledtts2/diffusion/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import unittest

import torch

from models.tts.styledtts2.diffusion.utils import (
closest_power_2,
default,
exists,
group_dict_by_prefix,
groupby,
iff,
is_sequence,
prefix_dict,
prod,
rand_bool,
to_list,
)


class TestUtils(unittest.TestCase):
    def test_exists(self):
        self.assertTrue(exists(1))
        self.assertFalse(exists(None))

    def test_iff(self):
        self.assertEqual(iff(True, "value"), "value")
        self.assertIsNone(iff(False, "value"))

    def test_is_sequence(self):
        # Both lists and tuples count as sequences; scalars do not.
        for candidate in ([1, 2, 3], (1, 2, 3)):
            self.assertTrue(is_sequence(candidate))
        self.assertFalse(is_sequence(123))

    def test_default(self):
        self.assertEqual(default(None, "default"), "default")
        self.assertEqual(default("value", "default"), "value")
        # The fallback may also be supplied as a zero-argument callable.
        self.assertEqual(default(None, lambda: "default"), "default")

    def test_to_list(self):
        self.assertEqual(to_list((1, 2, 3)), [1, 2, 3])
        self.assertEqual(to_list([1, 2, 3]), [1, 2, 3])
        self.assertEqual(to_list(1), [1])

    def test_prod(self):
        self.assertEqual(prod([1, 2, 3, 4]), 24)

    def test_closest_power_2(self):
        self.assertEqual(closest_power_2(6), 4)
        self.assertEqual(closest_power_2(9), 8)

    def test_rand_bool(self):
        expected_shape = (3, 3)
        mask = rand_bool(expected_shape, 0.5)
        self.assertEqual(mask.shape, expected_shape)
        self.assertTrue(mask.dtype == torch.bool)

    def test_group_dict_by_prefix(self):
        source = {"prefix_key1": 1, "prefix_key2": 2, "key3": 3}
        matched, unmatched = group_dict_by_prefix("prefix_", source)
        self.assertEqual(matched, {"prefix_key1": 1, "prefix_key2": 2})
        self.assertEqual(unmatched, {"key3": 3})

    def test_groupby(self):
        source = {"prefix_key1": 1, "prefix_key2": 2, "key3": 3}
        # Unlike group_dict_by_prefix, groupby strips the prefix from the keys.
        matched, unmatched = groupby("prefix_", source)
        self.assertEqual(matched, {"key1": 1, "key2": 2})
        self.assertEqual(unmatched, {"key3": 3})

    def test_prefix_dict(self):
        source = {"key1": 1, "key2": 2, "key3": 3}
        self.assertEqual(
            prefix_dict("prefix_", source),
            {"prefix_key1": 1, "prefix_key2": 2, "prefix_key3": 3},
        )

if __name__ == "__main__":
    unittest.main()
Loading

0 comments on commit 4e5daf8

Please sign in to comment.