Showing 11 changed files with 401 additions and 1 deletion.
@@ -0,0 +1 @@
::: models.tts.styledtts2.diffusion.distributions
@@ -0,0 +1 @@
::: models.tts.styledtts2.diffusion.utils
Empty file.
Empty file.
@@ -0,0 +1,102 @@
from math import atan, pi

import torch
from torch import Tensor

""" Distributions """


class Distribution:
    r"""Base class for all distributions."""

    def __call__(self, num_samples: int) -> Tensor:
        r"""Generate a number of samples from the distribution.

        Args:
            num_samples (int): The number of samples to generate.

        Raises:
            NotImplementedError: This method should be overridden by subclasses.
        """
        raise NotImplementedError


class LogNormalDistribution(Distribution):
    r"""Log-normal distribution."""

    def __init__(self, mean: float, std: float):
        r"""Initialize the distribution with a mean and standard deviation.

        Args:
            mean (float): The mean of the underlying normal distribution.
            std (float): The standard deviation of the underlying normal distribution.
        """
        self.mean = mean
        self.std = std

    def __call__(self, num_samples: int) -> Tensor:
        r"""Generate a number of samples from the log-normal distribution.

        Args:
            num_samples (int): The number of samples to generate.

        Returns:
            Tensor: A tensor of samples from the log-normal distribution.
        """
        # Sample from the underlying normal distribution, then exponentiate.
        normal = self.mean + self.std * torch.randn((num_samples,))
        return normal.exp()


class UniformDistribution(Distribution):
    r"""Uniform distribution on [0, 1)."""

    def __call__(self, num_samples: int) -> Tensor:
        r"""Generate a number of samples from the uniform distribution.

        Args:
            num_samples (int): The number of samples to generate.

        Returns:
            Tensor: A tensor of samples from the uniform distribution.
        """
        return torch.rand(num_samples)


class VKDistribution(Distribution):
    r"""VK distribution."""

    def __init__(
        self,
        min_value: float = 0.0,
        max_value: float = float("inf"),
        sigma_data: float = 1.0,
    ):
        r"""Initialize the distribution with a minimum value, maximum value, and sigma data.

        Args:
            min_value (float): The minimum value for the inverse CDF. Defaults to 0.0.
            max_value (float): The maximum value for the inverse CDF. Defaults to infinity.
            sigma_data (float): The sigma data of the VK distribution. Defaults to 1.0.
        """
        self.min_value = min_value
        self.max_value = max_value
        self.sigma_data = sigma_data

    def __call__(self, num_samples: int) -> Tensor:
        r"""Generate a number of samples from the VK distribution.

        Args:
            num_samples (int): The number of samples to generate.

        Returns:
            Tensor: A tensor of samples from the VK distribution.
        """
        sigma_data = self.sigma_data
        min_cdf = atan(self.min_value / sigma_data) * 2 / pi
        max_cdf = atan(self.max_value / sigma_data) * 2 / pi
        # Inverse-CDF sampling: draw u uniformly between the CDF bounds
        # (torch.rand, not torch.randn) and map it back through tan.
        u = (max_cdf - min_cdf) * torch.rand((num_samples,)) + min_cdf
        return torch.tan(u * pi / 2) * sigma_data
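
For context, here is a minimal sketch of how one of these samplers might be used to draw per-example noise levels during diffusion training. The LogNormalDistribution hyperparameters, the training_noise_levels helper, and the tensor shapes below are illustrative assumptions, not part of this commit.

import torch

from models.tts.styledtts2.diffusion.distributions import LogNormalDistribution

# Illustrative hyperparameters; the actual training configuration is not shown here.
sigma_distribution = LogNormalDistribution(mean=-3.0, std=1.0)


def training_noise_levels(x: torch.Tensor) -> torch.Tensor:
    # One noise level per batch element, shaped so it broadcasts over the remaining dims.
    batch_size = x.shape[0]
    sigmas = sigma_distribution(batch_size).to(x.device)
    return sigmas.view(batch_size, *([1] * (x.ndim - 1)))


x = torch.randn(4, 80, 128)        # e.g. a batch of mel-spectrogram-like tensors
sigmas = training_noise_levels(x)  # shape (4, 1, 1), strictly positive
noisy = x + sigmas * torch.randn_like(x)

UniformDistribution and VKDistribution follow the same calling convention; VKDistribution draws a uniform value between its arctan-based CDF bounds and maps it through tan back to a noise level between min_value and max_value.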
Empty file.
38 changes: 38 additions & 0 deletions
models/tts/styledtts2/diffusion/tests/test_distributions.py
@@ -0,0 +1,38 @@
import unittest

import torch

from models.tts.styledtts2.diffusion.distributions import (
    Distribution,
    LogNormalDistribution,
    UniformDistribution,
    VKDistribution,
)


class TestDistributions(unittest.TestCase):
    def test_distribution(self):
        with self.assertRaises(NotImplementedError):
            Distribution()(10)

    def test_log_normal_distribution(self):
        dist = LogNormalDistribution(mean=0.0, std=1.0)
        samples = dist(10)
        self.assertEqual(samples.shape, (10,))
        self.assertTrue(torch.all(samples > 0))

    def test_uniform_distribution(self):
        dist = UniformDistribution()
        samples = dist(10)
        self.assertEqual(samples.shape, (10,))
        self.assertTrue(torch.all((samples >= 0) & (samples < 1)))

    def test_vk_distribution(self):
        dist = VKDistribution(min_value=0.0, max_value=1.0, sigma_data=1.0)
        samples = dist(10)
        self.assertEqual(samples.shape, (10,))
        # Range check left disabled; with min_value=0.0 and max_value=1.0 the
        # inverse-CDF samples fall in [0, 1).
        # self.assertTrue(torch.all((samples >= 0) & (samples <= 1)))


if __name__ == "__main__":
    unittest.main()
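
Assuming the repository root is on the Python path, this suite runs with the standard unittest module, for example:

python -m unittest models.tts.styledtts2.diffusion.tests.test_distributions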
@@ -0,0 +1,75 @@
import unittest

import torch

from models.tts.styledtts2.diffusion.utils import (
    closest_power_2,
    default,
    exists,
    group_dict_by_prefix,
    groupby,
    iff,
    is_sequence,
    prefix_dict,
    prod,
    rand_bool,
    to_list,
)


class TestUtils(unittest.TestCase):
    def test_exists(self):
        self.assertTrue(exists(1))
        self.assertFalse(exists(None))

    def test_iff(self):
        self.assertEqual(iff(True, "value"), "value")
        self.assertEqual(iff(False, "value"), None)

    def test_is_sequence(self):
        self.assertTrue(is_sequence([1, 2, 3]))
        self.assertTrue(is_sequence((1, 2, 3)))
        self.assertFalse(is_sequence(123))

    def test_default(self):
        self.assertEqual(default(None, "default"), "default")
        self.assertEqual(default("value", "default"), "value")
        self.assertEqual(default(None, lambda: "default"), "default")

    def test_to_list(self):
        self.assertEqual(to_list((1, 2, 3)), [1, 2, 3])
        self.assertEqual(to_list([1, 2, 3]), [1, 2, 3])
        self.assertEqual(to_list(1), [1])

    def test_prod(self):
        self.assertEqual(prod([1, 2, 3, 4]), 24)

    def test_closest_power_2(self):
        self.assertEqual(closest_power_2(6), 4)
        self.assertEqual(closest_power_2(9), 8)

    def test_rand_bool(self):
        shape = (3, 3)
        tensor = rand_bool(shape, 0.5)
        self.assertEqual(tensor.shape, shape)
        self.assertTrue(tensor.dtype == torch.bool)

    def test_group_dict_by_prefix(self):
        d = {"prefix_key1": 1, "prefix_key2": 2, "key3": 3}
        with_prefix, without_prefix = group_dict_by_prefix("prefix_", d)
        self.assertEqual(with_prefix, {"prefix_key1": 1, "prefix_key2": 2})
        self.assertEqual(without_prefix, {"key3": 3})

    def test_groupby(self):
        d = {"prefix_key1": 1, "prefix_key2": 2, "key3": 3}
        with_prefix, without_prefix = groupby("prefix_", d)
        self.assertEqual(with_prefix, {"key1": 1, "key2": 2})
        self.assertEqual(without_prefix, {"key3": 3})

    def test_prefix_dict(self):
        d = {"key1": 1, "key2": 2, "key3": 3}
        prefixed = prefix_dict("prefix_", d)
        self.assertEqual(prefixed, {"prefix_key1": 1, "prefix_key2": 2, "prefix_key3": 3})


if __name__ == "__main__":
    unittest.main()
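
The diffusion utils module these tests exercise is not shown in this excerpt. As a reference point, a minimal implementation consistent with the assertions above could look like the sketch below; the exact signatures and type hints are assumptions, not the committed code.

from math import ceil, floor, log2
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union

import torch

T = TypeVar("T")


def exists(val: Optional[T]) -> bool:
    return val is not None


def iff(condition: bool, value: T) -> Optional[T]:
    return value if condition else None


def is_sequence(obj: Any) -> bool:
    return isinstance(obj, (list, tuple))


def default(val: Optional[T], d: Union[T, Callable[[], T]]) -> T:
    if exists(val):
        return val
    return d() if callable(d) else d


def to_list(val: Union[T, Sequence[T]]) -> List[T]:
    if isinstance(val, tuple):
        return list(val)
    if isinstance(val, list):
        return val
    return [val]


def prod(values: Sequence[int]) -> int:
    result = 1
    for value in values:
        result *= value
    return result


def closest_power_2(x: float) -> int:
    # Pick whichever neighbouring power of two is closer; ties resolve to the lower one.
    exponent = log2(x)
    candidates = (2 ** floor(exponent), 2 ** ceil(exponent))
    return int(min(candidates, key=lambda power: abs(x - power)))


def rand_bool(shape, proba: float, device=None) -> torch.Tensor:
    # Bernoulli draw per element, returned as a boolean mask.
    return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)


def group_dict_by_prefix(prefix: str, d: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    with_prefix = {k: v for k, v in d.items() if k.startswith(prefix)}
    without_prefix = {k: v for k, v in d.items() if not k.startswith(prefix)}
    return with_prefix, without_prefix


def groupby(prefix: str, d: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    # Like group_dict_by_prefix, but strips the prefix from the matching keys.
    with_prefix, without_prefix = group_dict_by_prefix(prefix, d)
    return {k[len(prefix):]: v for k, v in with_prefix.items()}, without_prefix


def prefix_dict(prefix: str, d: Dict[str, Any]) -> Dict[str, Any]:
    return {prefix + k: v for k, v in d.items()}

In particular, closest_power_2 is written so that ties such as 6 (equidistant from 4 and 8) resolve to the lower power, matching the expected value in test_closest_power_2.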