Skip to content

Commit

Permalink
[src & tests & egs] Refactor utils files into folder (asteroid-team#120)
Browse files Browse the repository at this point in the history
  • Loading branch information
mpariente authored May 25, 2020
1 parent 238acca commit a0b1d5e
Show file tree
Hide file tree
Showing 12 changed files with 296 additions and 285 deletions.
2 changes: 2 additions & 0 deletions asteroid/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .utils import deprecation_utils, torch_utils

__version__ = '0.2.1'
2 changes: 1 addition & 1 deletion asteroid/losses/mse.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ..deprecation_utils import DeprecationMixin
from ..utils.deprecation_utils import DeprecationMixin
from torch.nn.modules.loss import _Loss


Expand Down
2 changes: 1 addition & 1 deletion asteroid/losses/sdr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
from torch.nn.modules.loss import _Loss
from ..deprecation_utils import DeprecationMixin
from ..utils.deprecation_utils import DeprecationMixin

EPS = 1e-8

Expand Down
275 changes: 0 additions & 275 deletions asteroid/utils.py

This file was deleted.

10 changes: 10 additions & 0 deletions asteroid/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from .parser_utils import (
prepare_parser_from_dict, parse_args_as_dict, str_int_float, str2bool,
str2bool_arg, isfloat, isint
)
from .torch_utils import tensors_to_device, to_cuda
from .generic_utils import (
has_arg, flatten_dict, average_arrays_in_dic, get_wav_random_start_stop
)
# The functions above were all in asteroid/utils.py before refactoring into
# asteroid/utils/*_utils.py files. They are imported for backward compatibility.
File renamed without changes.
84 changes: 84 additions & 0 deletions asteroid/utils/generic_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import inspect
import collections
import numpy as np


def has_arg(fn, name):
""" Checks if a callable accepts a given keyword argument.
Args:
fn (callable): Callable to inspect.
name (str): Check if `fn` can be called with `name` as a keyword
argument.
Returns:
bool: whether `fn` accepts a `name` keyword argument.
"""
signature = inspect.signature(fn)
parameter = signature.parameters.get(name)
if parameter is None:
return False
return (parameter.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY))


def flatten_dict(d, parent_key='', sep='_'):
""" Flattens a dictionary into a single-level dictionary while preserving
parent keys. Taken from https://stackoverflow.com/questions/6027558/
flatten-nested-dictionaries-compressing-keys?answertab=votes#tab-top
Args:
d (collections.MutableMapping): Dictionary to be flattened.
parent_key (str): String to use as a prefix to all subsequent keys.
sep (str): String to use as a separator between two key levels.
Returns:
dict: Single-level dictionary, flattened.
"""
items = []
for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k
if isinstance(v, collections.MutableMapping):
items.extend(flatten_dict(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)


def average_arrays_in_dic(dic):
""" Take average of numpy arrays in a dictionary.
Args:
dic (dict): Input dictionary to take average from
Returns:
dict: New dictionary with array averaged.
"""
# Copy dic first
dic = dict(dic)
for k, v in dic.items():
if isinstance(v, np.ndarray):
dic[k] = float(v.mean())
return dic


def get_wav_random_start_stop(signal_len, desired_len=4*8000):
""" Get indexes for a chunk of signal of a given length.
Args:
signal_len (int): length of the signal to trim.
desired_len (int): the length of [start:stop]
Returns:
tuple: random start integer, stop integer.
"""
if signal_len == desired_len or desired_len is None:
rand_start = 0
else:
rand_start = np.random.randint(0, signal_len - desired_len)
if desired_len is None:
stop = None
else:
stop = rand_start + desired_len
return rand_start, stop
Loading

0 comments on commit a0b1d5e

Please sign in to comment.