# Copyright 2020 The DDSP Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Library of loss functions."""
import functools
import crepe
from ddsp import spectral_ops
from ddsp.core import hz_to_midi
from ddsp.core import safe_divide
from ddsp.core import tf_float32
import gin
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfkl = tf.keras.layers


# ---------------------- Losses ------------------------------------------------


def mean_difference(target, value, loss_type='L1', weights=None):
  """Common loss functions.

  Args:
    target: Target tensor.
    value: Value tensor.
    loss_type: One of 'L1', 'L2', or 'COSINE'.
    weights: A weighting mask for the per-element differences.

  Returns:
    The average loss.

  Raises:
    ValueError: If loss_type is not an allowed value.
  """
  difference = target - value
  weights = 1.0 if weights is None else weights
  loss_type = loss_type.upper()
  if loss_type == 'L1':
    return tf.reduce_mean(tf.abs(difference * weights))
  elif loss_type == 'L2':
    return tf.reduce_mean(difference**2 * weights)
  elif loss_type == 'COSINE':
    # cosine_distance lives in the v1 losses namespace; it is not available
    # under tf.compat.v2 (imported as tf above).
    return tf.compat.v1.losses.cosine_distance(
        target, value, weights=weights, axis=-1)
  else:
    raise ValueError('Loss type ({}), must be '
                     '"L1", "L2", or "COSINE"'.format(loss_type))


@gin.register
class SpectralLoss(tfkl.Layer):
  """Multi-scale spectrogram loss.

  This loss is the bread-and-butter of comparing two audio signals. It offers
  a range of options to compare spectrograms, many of which are redundant, but
  emphasize different aspects of the signal. By far, the most common
  comparisons are magnitudes (mag_weight) and log magnitudes (logmag_weight).
  """

  def __init__(self,
               fft_sizes=(2048, 1024, 512, 256, 128, 64),
               loss_type='L1',
               mag_weight=1.0,
               delta_time_weight=0.0,
               delta_freq_weight=0.0,
               cumsum_freq_weight=0.0,
               logmag_weight=0.0,
               loudness_weight=0.0,
               name='spectral_loss'):
    """Constructor, set loss weights of various components.

    Args:
      fft_sizes: Compare spectrograms at each of this list of fft sizes. Each
        spectrogram has a time-frequency resolution trade-off based on fft
        size, so comparing multiple scales allows multiple resolutions.
      loss_type: One of 'L1', 'L2', or 'COSINE'.
      mag_weight: Weight to compare linear magnitudes of spectrograms. Core
        audio similarity loss. More sensitive to peak magnitudes than log
        magnitudes.
      delta_time_weight: Weight to compare the first finite difference of
        spectrograms in time. Emphasizes changes of magnitude in time, such
        as at transients.
      delta_freq_weight: Weight to compare the first finite difference of
        spectrograms in frequency. Emphasizes changes of magnitude in
        frequency, such as at the boundaries of a stack of harmonics.
      cumsum_freq_weight: Weight to compare the cumulative sum of spectrograms
        across frequency for each slice in time. Similar to a 1-D Wasserstein
        loss, this hopefully provides a non-vanishing gradient to push two
        non-overlapping sinusoids towards each other.
      logmag_weight: Weight to compare log magnitudes of spectrograms. Core
        audio similarity loss. More sensitive to quiet magnitudes than linear
        magnitudes.
      loudness_weight: Weight to compare the overall perceptual loudness of
        two signals. Very high-level loss signal that is a subset of mag and
        logmag losses.
      name: Name of the module.
    """
    super().__init__(name=name)
    self.fft_sizes = fft_sizes
    self.loss_type = loss_type
    self.mag_weight = mag_weight
    self.delta_time_weight = delta_time_weight
    self.delta_freq_weight = delta_freq_weight
    self.cumsum_freq_weight = cumsum_freq_weight
    self.logmag_weight = logmag_weight
    self.loudness_weight = loudness_weight

    self.spectrogram_ops = []
    for size in self.fft_sizes:
      spectrogram_op = functools.partial(spectral_ops.compute_mag, size=size)
      self.spectrogram_ops.append(spectrogram_op)

  def call(self, target_audio, audio, weights=None):
    loss = 0.0

    diff = spectral_ops.diff
    cumsum = tf.math.cumsum

    # Compute loss for each fft size.
    for loss_op in self.spectrogram_ops:
      target_mag = loss_op(target_audio)
      value_mag = loss_op(audio)

      # Add magnitude loss.
      if self.mag_weight > 0:
        loss += self.mag_weight * mean_difference(
            target_mag, value_mag, self.loss_type, weights=weights)

      if self.delta_time_weight > 0:
        target = diff(target_mag, axis=1)
        value = diff(value_mag, axis=1)
        loss += self.delta_time_weight * mean_difference(
            target, value, self.loss_type, weights=weights)

      if self.delta_freq_weight > 0:
        target = diff(target_mag, axis=2)
        value = diff(value_mag, axis=2)
        loss += self.delta_freq_weight * mean_difference(
            target, value, self.loss_type, weights=weights)

      # TODO(kyriacos): Normalize the cumulative spectrogram.
      if self.cumsum_freq_weight > 0:
        target = cumsum(target_mag, axis=2)
        value = cumsum(value_mag, axis=2)
        loss += self.cumsum_freq_weight * mean_difference(
            target, value, self.loss_type, weights=weights)

      # Add logmagnitude loss, reusing spectrogram.
      if self.logmag_weight > 0:
        target = spectral_ops.safe_log(target_mag)
        value = spectral_ops.safe_log(value_mag)
        loss += self.logmag_weight * mean_difference(
            target, value, self.loss_type, weights=weights)

    if self.loudness_weight > 0:
      target = spectral_ops.compute_loudness(target_audio, n_fft=2048,
                                             use_tf=True)
      value = spectral_ops.compute_loudness(audio, n_fft=2048, use_tf=True)
      loss += self.loudness_weight * mean_difference(
          target, value, self.loss_type, weights=weights)

    return loss
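

# Illustrative usage of SpectralLoss (editor's sketch; the sample rate,
# frequency, and noise level are arbitrary assumptions): compare a reference
# sine wave against a slightly noisy copy of itself.
def _example_spectral_loss():
  sample_rate = 16000
  t = tf.linspace(0.0, 1.0, sample_rate)
  target_audio = tf.sin(2.0 * np.pi * 440.0 * t)[tf.newaxis, :]  # [1, 16000]
  audio = target_audio + 0.01 * tf.random.normal(target_audio.shape)
  loss_obj = SpectralLoss(mag_weight=1.0, logmag_weight=1.0)
  return loss_obj(target_audio, audio)  # Scalar, smaller for closer audio.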


# ------------------------------------------------------------------------------
# Perceptual Losses
# ------------------------------------------------------------------------------
@gin.register
class EmbeddingLoss(tfkl.Layer):
  """Embedding loss for a given pretrained model.

  Using these "perceptual" loss functions will help encourage better matching
  of higher-order structure than just a spectral loss. In image models, these
  perceptual losses help overcome the tendency of straightforward L1 and L2
  losses to lead to blurry images. For ddsp, a "blurry" image is often a
  filtered noise synthesizer or reverb smearing power density in the right
  areas of a spectrogram but without the appropriate harmonic structure. This
  perceptual loss encourages output from harmonic and sinusoidal models
  because the pretrained model compares the two audio signals with features
  that are trained to detect the harmonic structure of natural sounds.
  """

  def __init__(self,
               weight=1.0,
               loss_type='L1',
               pretrained_model=None,
               name='embedding_loss'):
    super().__init__(name=name)
    self.weight = weight
    self.loss_type = loss_type
    self.pretrained_model = pretrained_model

  def call(self, target_audio, audio):
    loss = 0.0
    if self.weight > 0.0:
      audio, target_audio = tf_float32(audio), tf_float32(target_audio)
      target_emb = self.pretrained_model(target_audio)
      synth_emb = self.pretrained_model(audio)
      loss = self.weight * mean_difference(
          target_emb, synth_emb, self.loss_type)
    return loss


@gin.register
class PretrainedCREPEEmbeddingLoss(EmbeddingLoss):
  """Embedding loss of a pretrained CREPE model."""

  def __init__(self,
               weight=1.0,
               loss_type='L1',
               model_capacity='tiny',
               activation_layer='classifier',
               name='pretrained_crepe_embedding_loss'):
    # Scale each layer activation loss to comparable scales.
    scale = {
        'conv1-BN': 1.3,
        'conv1-maxpool': 1.0,
        'conv2-BN': 1.4,
        'conv2-maxpool': 1.1,
        'conv3-BN': 1.9,
        'conv3-maxpool': 1.6,
        'conv4-BN': 1.5,
        'conv4-maxpool': 1.4,
        'conv5-BN': 1.9,
        'conv5-maxpool': 1.7,
        'conv6-BN': 30,
        'conv6-maxpool': 25,
        'classifier': 130,
    }[activation_layer]
    super().__init__(
        weight=20.0 * scale * weight,
        loss_type=loss_type,
        name=name,
        pretrained_model=PretrainedCREPE(model_capacity=model_capacity,
                                         activation_layer=activation_layer))


class PretrainedCREPE(tfkl.Layer):
  """Pretrained CREPE model with frozen weights."""

  def __init__(self,
               model_capacity='tiny',
               activation_layer='conv5-maxpool',
               name='pretrained_crepe',
               trainable=False):
    super().__init__(name=name, trainable=trainable)
    self._model_capacity = model_capacity
    self._activation_layer = activation_layer
    spectral_ops.reset_crepe()
    self._model = crepe.core.build_and_load_model(self._model_capacity)
    self.frame_length = 1024

  def build(self, unused_x_shape):
    self.layer_names = [l.name for l in self._model.layers]
    if self._activation_layer not in self.layer_names:
      raise ValueError(
          'activation layer {} not found, valid names are {}'.format(
              self._activation_layer, self.layer_names))
    self._activation_model = tf.keras.Model(
        inputs=self._model.input,
        outputs=self._model.get_layer(self._activation_layer).output)
    # Variables are not to be trained.
    self._model.trainable = self.trainable
    self._activation_model.trainable = self.trainable

  def frame_audio(self, audio, hop_length=1024, center=True):
    """Slice audio into frames for crepe."""
    # Pad so that frames are centered around their timestamps
    # (i.e. the first frame is zero centered).
    pad = int(self.frame_length / 2)
    audio = tf.pad(audio, ((0, 0), (pad, pad))) if center else audio
    frames = tf.signal.frame(audio,
                             frame_length=self.frame_length,
                             frame_step=hop_length)

    # Normalize each frame -- this is expected by the model.
    mean, var = tf.nn.moments(frames, [-1], keepdims=True)
    frames -= mean
    frames /= (var**0.5 + 1e-5)
    return frames

  def call(self, audio):
    """Returns the embeddings.

    Args:
      audio: Tensors of shape [batch, length]. Length must be divisible by
        1024.

    Returns:
      Activations of shape [batch, n_frames, depth].
    """
    frames = self.frame_audio(audio)
    batch_size = int(frames.shape[0])
    n_frames = int(frames.shape[1])
    # Get model predictions.
    frames = tf.reshape(frames, [-1, self.frame_length])
    outputs = self._activation_model(frames)
    outputs = tf.reshape(outputs, [batch_size, n_frames, -1])
    return outputs
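

# Illustrative usage of PretrainedCREPEEmbeddingLoss (editor's sketch; this
# assumes the `crepe` package and its pretrained weights are available, and
# uses random audio whose length is a multiple of 1024 samples).
def _example_crepe_embedding_loss():
  target_audio = tf.random.normal([1, 16384])  # 16 * 1024 samples.
  audio = tf.random.normal([1, 16384])
  loss_obj = PretrainedCREPEEmbeddingLoss(model_capacity='tiny')
  return loss_obj(target_audio, audio)  # Scalar embedding loss.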


# ------------------------------------------------------------------------------
# Consistency Losses
# ------------------------------------------------------------------------------
def amp_loss(amp,
             amp_target,
             loss_type='L1',
             weights=None,
             log=False,
             amin=1e-5):
  """Loss comparing two amplitudes (scaled logarithmically if log=True)."""
  if log:
    # Put in a log scale (psychophysically appropriate for audio).
    log10 = lambda x: tf.math.log(x) / tf.math.log(10.0)
    amp = log10(tf.maximum(amin, amp))
    amp_target = log10(tf.maximum(amin, amp_target))
  # Take the difference.
  return mean_difference(amp, amp_target, loss_type, weights)


def freq_loss(f_hz, f_hz_target, loss_type='L1', weights=None):
  """Loss comparing two frequencies."""
  # Convert to MIDI.
  f_midi = hz_to_midi(f_hz)
  f_midi_target = hz_to_midi(f_hz_target)
  # Take the difference.
  return mean_difference(f_midi, f_midi_target, loss_type, weights)
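

# A worked example of freq_loss (editor's sketch): because the comparison is
# in MIDI space, a one-semitone error costs ~1.0 regardless of register.
# A4 = 440 Hz and A#4 ~= 466.16 Hz are one semitone apart.
def _example_freq_loss():
  f_hz = 440.0 * tf.ones([1, 10, 1])
  f_hz_target = 466.16 * tf.ones([1, 10, 1])
  return freq_loss(f_hz, f_hz_target)  # -> ~1.0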


@gin.register
class FilteredNoiseConsistencyLoss(tfkl.Layer):
  """Consistency loss for synthesizer controls.

  EXPERIMENTAL
  """

  def __init__(self, weight=1.0, **kwargs):
    super().__init__(**kwargs)
    self.weight = weight

  def call(self, noise_magnitudes, noise_magnitudes_target):
    """Add losses to the model."""
    self.built = True
    return self.weight * amp_loss(noise_magnitudes, noise_magnitudes_target)


@gin.register
class HarmonicConsistencyLoss(tfkl.Layer):
  """Consistency loss for synthesizer controls.

  EXPERIMENTAL
  """

  def __init__(self,
               amp_weight=1.0,
               dist_weight=1.0,
               f0_weight=1.0,
               amp_threshold=1e-4,
               **kwargs):
    super().__init__(**kwargs)
    self.amp_weight = amp_weight
    self.dist_weight = dist_weight
    self.f0_weight = f0_weight
    self.amp_threshold = amp_threshold

  def call(self,
           harm_amp,
           harm_amp_target,
           harm_dist,
           harm_dist_target,
           f0_hz,
           f0_hz_target):
    """Add losses to the model."""
    self.built = True
    losses_dict = {}

    # Mask loss where target audio is below threshold amplitude.
    weights = tf.cast(harm_amp_target >= self.amp_threshold, tf.float32)

    # Harmonic amplitude.
    harm_amp_loss = amp_loss(harm_amp, harm_amp_target)
    losses_dict['harm_amp_loss'] = self.amp_weight * harm_amp_loss

    # Harmonic distribution.
    harm_dist_loss = amp_loss(harm_dist, harm_dist_target, weights=weights)
    losses_dict['harm_dist_loss'] = self.dist_weight * harm_dist_loss

    # Fundamental frequency.
    f0_hz_loss = freq_loss(f0_hz, f0_hz_target, weights=weights)
    losses_dict['f0_hz_loss'] = self.f0_weight * f0_hz_loss

    return losses_dict
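

# Illustrative usage of HarmonicConsistencyLoss (editor's sketch; all shapes
# and values are assumptions): amplitudes and f0 have shape [batch, time, 1]
# and the harmonic distribution [batch, time, n_harmonics].
def _example_harmonic_consistency_loss():
  harm_amp = 0.5 * tf.ones([1, 10, 1])
  harm_amp_target = 0.4 * tf.ones([1, 10, 1])
  harm_dist = tf.ones([1, 10, 8]) / 8.0
  harm_dist_target = tf.ones([1, 10, 8]) / 8.0
  f0_hz = 440.0 * tf.ones([1, 10, 1])
  f0_hz_target = 442.0 * tf.ones([1, 10, 1])
  loss_obj = HarmonicConsistencyLoss()
  return loss_obj(harm_amp, harm_amp_target,
                  harm_dist, harm_dist_target,
                  f0_hz, f0_hz_target)  # Dict of scalar losses.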


# ------------------------------------------------------------------------------
# Sinusoidal Consistency Losses
# ------------------------------------------------------------------------------
@gin.register
class WassersteinConsistencyLoss(tfkl.Layer):
  """Compare similarity of two traces of sinusoids with Wasserstein distance.

  EXPERIMENTAL
  """

  def __init__(self,
               weight=1.0,
               midi=True,
               **kwargs):
    """Constructor.

    Args:
      weight: Loss weight.
      midi: Scale frequencies logarithmically (MIDI).
      **kwargs: Extra args for tfkl.Layer.
    """
    super().__init__(**kwargs)
    self.weight = weight
    self.midi = midi

  def call(self, amps_a, freqs_a, amps_b, freqs_b):
    """Returns the sinusoidal consistency loss scalar.

    Args:
      amps_a: Amplitudes of first sinusoids, greater than 0.
        Shape [batch, time, freq].
      freqs_a: Frequencies of first sinusoids in hertz.
        Shape [batch, time, freq].
      amps_b: Amplitudes of second sinusoids, greater than 0.
        Shape [batch, time, freq].
      freqs_b: Frequencies of second sinusoids in hertz.
        Shape [batch, time, freq].

    Returns:
      Scalar, weighted Wasserstein distance.
    """
    loss = 0.0
    if self.weight > 0.0:
      if self.midi:
        freqs_a = hz_to_midi(freqs_a)
        freqs_b = hz_to_midi(freqs_b)
      loss = wasserstein_distance(freqs_a, freqs_b, amps_a, amps_b, p=1.0)
      loss = tf.reduce_mean(self.weight * loss)
    return loss


def wasserstein_distance(u_values, v_values, u_weights, v_weights, p=1.0):
  """Differentiable 1-D Wasserstein distance.

  Adapted from the scipy.stats implementation.

  Args:
    u_values: Samples from distribution `u`. Shape [batch_shape, n_samples].
    v_values: Samples from distribution `v`. Shape [batch_shape, n_samples].
    u_weights: Sample weights. Shape [batch_shape, n_samples].
    v_weights: Sample weights. Shape [batch_shape, n_samples].
    p: Degree of the distance norm. Wasserstein=1, Energy=2.

  Returns:
    The Wasserstein distance between samples. Shape [batch_shape].
  """
  u_sorter = tf.argsort(u_values, axis=-1)
  v_sorter = tf.argsort(v_values, axis=-1)

  all_values = tf.concat([u_values, v_values], axis=-1)
  all_values = tf.sort(all_values, axis=-1)

  # Compute the differences between pairs of successive values of u and v.
  deltas = spectral_ops.diff(all_values, axis=-1)

  # Get the respective positions of the values of u and v among the values of
  # both distributions.
  batch_dims = len(u_values.shape) - 1
  gather = lambda x, i: tf.gather(x, i, axis=-1, batch_dims=batch_dims)
  u_cdf_indices = tf.searchsorted(
      gather(u_values, u_sorter), all_values[..., :-1], side='right')
  v_cdf_indices = tf.searchsorted(
      gather(v_values, v_sorter), all_values[..., :-1], side='right')

  # Calculate the CDFs of u and v using their weights, if specified.
  if u_weights is None:
    u_cdf = u_cdf_indices / float(u_values.shape[-1])
  else:
    u_sorted_cumweights = tf.concat(
        [tf.zeros_like(u_weights)[..., 0:1],
         tf.cumsum(gather(u_weights, u_sorter), axis=-1)],
        axis=-1)
    u_cdf = gather(u_sorted_cumweights, u_cdf_indices)
    # Normalize by the total weight to form a proper CDF.
    u_cdf = safe_divide(u_cdf, u_sorted_cumweights[..., -1:])

  if v_weights is None:
    v_cdf = v_cdf_indices / float(v_values.shape[-1])
  else:
    v_sorted_cumweights = tf.concat(
        [tf.zeros_like(v_weights)[..., 0:1],
         tf.cumsum(gather(v_weights, v_sorter), axis=-1)],
        axis=-1)
    v_cdf = gather(v_sorted_cumweights, v_cdf_indices)
    # Normalize by the total weight to form a proper CDF.
    v_cdf = safe_divide(v_cdf, v_sorted_cumweights[..., -1:])

  # Compute the value of the integral based on the CDFs.
  return tf.reduce_sum(deltas * tf.abs(u_cdf - v_cdf)**p, axis=-1)**(1.0 / p)
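

# Quick numeric check of wasserstein_distance (editor's sketch): two equally
# weighted point sets that differ by a shift of 1.0 should have W1 distance
# 1.0, matching scipy.stats.wasserstein_distance.
def _example_wasserstein_distance():
  u_values = tf.constant([[1.0, 2.0, 3.0]])
  v_values = u_values + 1.0  # The same distribution, shifted right by 1.
  weights = tf.ones([1, 3])
  return wasserstein_distance(u_values, v_values, weights, weights, p=1.0)
  # -> [1.0]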


@gin.register
class KDEConsistencyLoss(tfkl.Layer):
  """Compare similarity of two traces of sinusoids using kernels.

  EXPERIMENTAL

  Adapted from the differentiable two-way mismatch loss. Very similar to the
  Jeffreys divergence. Uses a gaussian kernel density estimate in both
  directions to compare the likelihood of each set of sinusoids relative to
  the other. Also enforces the mean amplitudes to be the same, as the
  divergence by itself is insensitive to the absolute scale of the amplitudes.
  """

  def __init__(self,
               weight_a=1.0,
               weight_b=1.0,
               weight_mean_amp=1.0,
               scale_a=0.1,
               scale_b=0.1,
               **kwargs):
    """Constructor.

    Args:
      weight_a: Weight for -log p(a|b) term.
      weight_b: Weight for -log p(b|a) term.
      weight_mean_amp: Weight to match the mean amplitudes between a and b.
      scale_a: Scale of the gaussians around each sinusoid in MIDI.
      scale_b: Scale of the gaussians around each sinusoid in MIDI.
      **kwargs: Extra args for tfkl.Layer.
    """
    super().__init__(**kwargs)
    # Loss weights.
    self.weight_a = weight_a
    self.weight_b = weight_b
    self.weight_mean_amp = weight_mean_amp
    # Gaussian widths in MIDI.
    self.scale_a = scale_a
    self.scale_b = scale_b

  def call(self, amps_a, freqs_a, amps_b, freqs_b):
    """Returns the sinusoidal consistency loss scalar.

    Args:
      amps_a: Amplitudes of first sinusoids, greater than 0.
        Shape [batch, time, freq].
      freqs_a: Frequencies of first sinusoids in hertz.
        Shape [batch, time, freq].
      amps_b: Amplitudes of second sinusoids, greater than 0.
        Shape [batch, time, freq].
      freqs_b: Frequencies of second sinusoids in hertz.
        Shape [batch, time, freq].

    Returns:
      Scalar, weighted -log p(a|b) - log p(b|a).
    """
    loss = 0.0
    if self.weight_a > 0.0:
      loss_a = self.nll(amps_a, freqs_a, amps_b, freqs_b, self.scale_b)
      loss += tf.reduce_mean(self.weight_a * loss_a)
    if self.weight_b > 0.0:
      loss_b = self.nll(amps_b, freqs_b, amps_a, freqs_a, self.scale_a)
      loss += tf.reduce_mean(self.weight_b * loss_b)
    if self.weight_mean_amp > 0.0:
      mean_amp_a = tf.reduce_mean(amps_a, axis=-1)
      mean_amp_b = tf.reduce_mean(amps_b, axis=-1)
      loss_mean_amp = tf.reduce_mean(tf.abs(mean_amp_a - mean_amp_b))
      loss += self.weight_mean_amp * loss_mean_amp
    return loss

  def nll(self, amps, freqs, amps_target, freqs_target, scale_target):
    """Returns negative log-likelihood of source sinusoids given targets.

    Args:
      amps: Amplitudes of source sinusoids, greater than 0.
        Shape [batch, time, freq].
      freqs: Frequencies of source sinusoids in hertz.
        Shape [batch, time, freq].
      amps_target: Amplitudes of target sinusoids, greater than 0.
        Shape [batch, time, freq].
      freqs_target: Frequencies of target sinusoids in hertz.
        Shape [batch, time, freq].
      scale_target: Scale of gaussian kernel in MIDI.

    Returns:
      - log(p(source|target)). Shape [batch, time].
    """
    p_source_given_target = self.kernel_density_estimate(
        amps_target, freqs_target, scale_target)

    # KDE is on a logarithmic scale (MIDI).
    freqs_midi = hz_to_midi(freqs)

    # Need to rearrange shape as tfp expects, [sample_sh, batch_sh, event_sh].
    freqs_transpose = tf.transpose(freqs_midi, [2, 0, 1])  # [freq, batch, time]
    nll_transpose = - p_source_given_target.log_prob(freqs_transpose)
    nll = tf.transpose(nll_transpose, [1, 2, 0])  # [batch, time, freq]

    # Weighted sum over sinusoids -> [batch, time].
    amps_norm = safe_divide(amps, tf.reduce_sum(amps, axis=-1, keepdims=True))
    return tf.reduce_mean(nll * amps_norm, axis=-1)

  def kernel_density_estimate(self, amps, freqs, scale):
    """Gets a gaussian kernel density estimate of the sinusoids.

    Performs a gaussian kernel density estimate on the sinusoid points, with
    the height of each gaussian component given by the sinusoidal amplitude.

    Args:
      amps: Amplitudes of sinusoids, must be greater than 0.
      freqs: Frequencies of sinusoids in hertz.
      scale: Scale of gaussian kernel in MIDI.

    Returns:
      MixtureSameFamily, Gaussian distribution.
    """
    # Gaussian KDE around each partial, height=amplitude, center=frequency.
    freqs_midi = hz_to_midi(freqs)

    # NLL can be a nan if sinusoid amps are all zero, add a small offset.
    amps = tf.where(amps == 0.0, 1e-7 * tf.ones_like(amps), amps)
    amps_norm = safe_divide(amps, tf.reduce_sum(amps, axis=-1, keepdims=True))

    # P(candidate_harmonics | freqs).
    return tfd.MixtureSameFamily(tfd.Categorical(probs=amps_norm),
                                 tfd.Normal(loc=freqs_midi, scale=scale))
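

# Illustrative usage of KDEConsistencyLoss (editor's sketch; the shapes and
# the 5% detuning are arbitrary assumptions): identical sinusoid tracks
# should score lower than detuned ones.
def _example_kde_consistency_loss():
  amps = tf.ones([1, 10, 4])
  harmonic_numbers = tf.constant([1.0, 2.0, 3.0, 4.0])
  freqs = 440.0 * harmonic_numbers * tf.ones([1, 10, 4])  # Hz.
  loss_obj = KDEConsistencyLoss()
  loss_same = loss_obj(amps, freqs, amps, freqs)
  loss_detuned = loss_obj(amps, 1.05 * freqs, amps, freqs)
  return loss_same, loss_detuned  # loss_detuned > loss_same.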


# ------------------------------------------------------------------------------
# Differentiable Two-way Mismatch Loss
# ------------------------------------------------------------------------------
@gin.register
class TWMLoss(tfkl.Layer):
  """Two-way mismatch: encourages sinusoids to be harmonics of the best f0.

  EXPERIMENTAL

  This loss function is a differentiable / smooth probabilistic adaptation of
  the heuristic Two-way Mismatch (TWM) algorithm used to extract a fundamental
  frequency from a group of sinusoids.

  Sketch of the algorithm:

  * Given f0 candidates, produce a harmonic series off each candidate.
  * Place a gaussian at each harmonic peak and evaluate the likelihood of the
    sinusoids given that harmonic distribution, p(sinusoids|harmonics).
  * Place a gaussian at each sinusoid peak and evaluate the likelihood of the
    harmonics given that sinusoidal distribution, p(harmonics|sinusoids).
  * The total loss is the combined negative log-likelihood of each term,
    Loss = - log p(sinusoids|harmonics) - log p(harmonics|sinusoids), with a
    softmin over f0 candidates to only consider the best candidates.

  The two terms work against each other, with p(sinusoids|harmonics) rewarding
  lower f0 candidates that have more densely packed coverage, and
  p(harmonics|sinusoids) rewarding higher f0 candidates that have less chance
  of falling outside the sinusoidal coverage. The global minimum in most cases
  is an intermediate f0 where the harmonics and sinusoids align.

  This implementation can also be used (non-differentiably) for selecting the
  best f0 candidate using an argmin over the loss.

  Original paper:
  Maher, Beauchamp, "Fundamental Frequency Estimation of Musical Signals using
  a Two-way Mismatch Procedure," 1994.
  http://www.montana.edu/rmaher/publications/maher_jasa_0494_2254-2263.pdf

  Good tutorial video:
  https://www.coursera.org/lecture/audio-signal-processing/f0-detection-c7o1c
  """

  def __init__(self,
               sinusoids_weight=1.0,
               harmonics_weight=1.0,
               sinusoids_scale=0.5,
               harmonics_scale=0.2,
               n_harmonic_points=10,
               n_harmonic_gaussians=30,
               softmin_temperature=1.0,
               sample_rate=16000,
               **kwargs):
    """Constructor.

    Args:
      sinusoids_weight: Weight for -log p(sinusoids|harmonics) term.
      harmonics_weight: Weight for -log p(harmonics|sinusoids) term.
      sinusoids_scale: Scale of the gaussians around each sinusoid.
      harmonics_scale: Scale of the gaussians around each harmonic.
      n_harmonic_points: Number of points to consider for each candidate f0 in
        the p(harmonics|sinusoids) calculation.
      n_harmonic_gaussians: Number of gaussians for each candidate f0 in the
        p(sinusoids|harmonics) calculation.
      softmin_temperature: Greater than 0, lower temperatures focus more on
        just the best (lowest cost) f0 candidate for the total loss.
      sample_rate: Audio sample rate in hertz. Used for ignoring harmonics
        above nyquist.
      **kwargs: Extra args for tfkl.Layer.
    """
    super().__init__(**kwargs)
    self.softmin_temperature = softmin_temperature
    self.sample_rate = sample_rate

    # Loss weights.
    self.sinusoids_weight = sinusoids_weight
    self.harmonics_weight = harmonics_weight

    # p(harmonics|sinusoids).
    self.sinusoids_scale = sinusoids_scale
    self.n_harmonic_points = n_harmonic_points

    # p(sinusoids|harmonics).
    self.harmonics_scale = harmonics_scale
    self.n_harmonic_gaussians = n_harmonic_gaussians

  def call(self, f0_candidates, freqs, amps):
    """Returns the TWM loss scalar.

    Args:
      f0_candidates: Frequencies of candidates in hertz. [batch, time, freq].
      freqs: Frequencies of sinusoids in hertz. [batch, time, freq].
      amps: Amplitudes of sinusoids, must be greater than 0.
        [batch, time, freq].

    Returns:
      Scalar, weighted -log p(sinusoids|harmonics) - log
      p(harmonics|sinusoids), with a softmin over f0 candidates to just
      consider the best candidates.
    """
    sinusoids_loss, harmonics_loss = self.get_loss_tensors(
        f0_candidates, freqs, amps)

    # Loss is the minimum loss (loss for the best f0 candidate).
    # Use softmin to make it differentiable.
    combined_loss = (self.sinusoids_weight * sinusoids_loss +
                     self.harmonics_weight * harmonics_loss)
    softmin_loss = combined_loss * tf.nn.softmax(
        -combined_loss / self.softmin_temperature, axis=-1)
    return tf.reduce_mean(softmin_loss)

  def predict_f0(self, f0_candidates, freqs, amps):
    """Get the most likely f0 from the series of candidates.

    Args:
      f0_candidates: Frequencies of candidates in hertz. [batch, time, freq].
      freqs: Frequencies of sinusoids in hertz. [batch, time, freq].
      amps: Amplitudes of sinusoids, greater than 0. [batch, time, freq].

    Returns:
      f0_hz: Numpy array, most likely f0 among the candidates at each
        timestep. Shape [batch, time, 1].
    """
    sinusoids_loss, harmonics_loss = self.get_loss_tensors(
        f0_candidates, freqs, amps)
    loss = (self.sinusoids_weight * sinusoids_loss +
            self.harmonics_weight * harmonics_loss)

    # Argmin is not differentiable, switching to numpy, ignore nans for min.
    f0_idx = np.nanargmin(np.array(loss), axis=-1)[..., np.newaxis]
    f0_hz = np.take_along_axis(np.array(f0_candidates), f0_idx, axis=-1)
    return f0_hz

  def get_loss_tensors(self, f0_candidates, freqs, amps):
    """Get traces of loss to estimate fundamental frequency.

    Args:
      f0_candidates: Frequencies of candidates in hertz. [batch, time, freq].
      freqs: Frequencies of sinusoids in hertz. [batch, time, freq].
      amps: Amplitudes of sinusoids, greater than 0. [batch, time, freq].

    Returns:
      sinusoids_loss: -log p(sinusoids|harmonics),
        [batch, time, f0_candidate].
      harmonics_loss: -log p(harmonics|sinusoids),
        [batch, time, f0_candidate].
    """
    # ==========================================================================
    # P(sinusoids | candidate_harmonics).
    # ==========================================================================
    p_sinusoids_given_harmonics = self.get_p_sinusoids_given_harmonics()

    # Treat each partial as a candidate.
    # Get the ratio of each partial to each candidate.
    # -> [batch, time, candidate, partial]
    freq_ratios = safe_divide(freqs[:, :, tf.newaxis, :],
                              f0_candidates[:, :, :, tf.newaxis])
    nll_sinusoids = - p_sinusoids_given_harmonics.log_prob(freq_ratios)

    a = tf.convert_to_tensor(amps[:, :, tf.newaxis, :])

    # # Don't count sinusoids that are less than 1 std > mean.
    # a_mean, a_var = tf.nn.moments(a, axes=-1, keepdims=True)
    # a = tf.where(a > a_mean + 0.5 * a_var**0.5, a, tf.zeros_like(a))

    # Weighted sum by sinusoid amplitude.
    # -> [batch, time, candidate]
    sinusoids_loss = safe_divide(tf.reduce_sum(nll_sinusoids * a, axis=-1),
                                 tf.reduce_sum(a, axis=-1))

    # ==========================================================================
    # P(candidate_harmonics | sinusoids).
    # ==========================================================================
    p_harm_given_sin = self.get_p_harmonics_given_sinusoids(freqs, amps)
    harmonics = self.get_candidate_harmonics(f0_candidates, as_midi=True)

    # Need to rearrange shape as tfp expects, [sample_sh, batch_sh, event_sh].
    # -> [candidate, harmonic, batch, time]
    harmonics_transpose = tf.transpose(harmonics, [2, 3, 0, 1])
    nll_harmonics_transpose = - p_harm_given_sin.log_prob(harmonics_transpose)
    # -> [batch, time, candidate, harm]
    nll_harmonics = tf.transpose(nll_harmonics_transpose, [2, 3, 0, 1])

    # Prior decreasing importance of upper harmonics.
    amps_prior = tf.linspace(
        1.0, 1.0 / self.n_harmonic_points, self.n_harmonic_points)
    harmonics_loss = (nll_harmonics *
                      amps_prior[tf.newaxis, tf.newaxis, tf.newaxis, :])

    # Don't count loss for harmonics above nyquist.
    # Reweight by the number of harmonics below nyquist
    # (so it doesn't just pick the highest frequency possible).
    nyquist_midi = hz_to_midi(self.sample_rate / 2.0)
    nyquist_mask = tf.where(harmonics < nyquist_midi,
                            tf.ones_like(harmonics_loss),
                            tf.zeros_like(harmonics_loss))
    harmonics_loss *= safe_divide(
        nyquist_mask, tf.reduce_mean(nyquist_mask, axis=-1, keepdims=True))

    # Sum over harmonics.
    harmonics_loss = tf.reduce_mean(harmonics_loss, axis=-1)

    return sinusoids_loss, harmonics_loss

  def get_p_sinusoids_given_harmonics(self):
    """Gets distribution of sinusoids given harmonics from candidate f0s.

    Returns:
      MixtureSameFamily, Gaussian distribution.
    """
    # Normalized frequency (harmonic number), create equally spaced gaussians.
    harmonics_probs = (tf.ones(self.n_harmonic_gaussians) /
                       self.n_harmonic_gaussians)
    harmonics_loc = tf.range(1, self.n_harmonic_gaussians + 1,
                             dtype=tf.float32)

    # P(sinusoids | candidate_harmonics).
    return tfd.MixtureSameFamily(
        tfd.Categorical(probs=harmonics_probs),
        tfd.Normal(loc=harmonics_loc, scale=self.harmonics_scale))

  def get_p_harmonics_given_sinusoids(self, freqs, amps):
    """Gets distribution of harmonics from candidate f0s given sinusoids.

    Performs a gaussian kernel density estimate on the sinusoid points, with
    the height of each gaussian component given by the sinusoidal amplitude.

    Args:
      freqs: Frequencies of sinusoids in hertz.
      amps: Amplitudes of sinusoids, must be greater than 0.

    Returns:
      MixtureSameFamily, Gaussian distribution.
    """
    # Gaussian KDE around each partial, height=amplitude, center=frequency.
    sinusoids_midi = hz_to_midi(freqs)

    # NLL can be a nan if sinusoid amps are all zero, add a small offset.
    amps = tf.where(amps == 0.0, 1e-7 * tf.ones_like(amps), amps)
    amps_norm = safe_divide(amps, tf.reduce_sum(amps, axis=-1, keepdims=True))

    # P(candidate_harmonics | sinusoids).
    return tfd.MixtureSameFamily(
        tfd.Categorical(probs=amps_norm),
        tfd.Normal(loc=sinusoids_midi, scale=self.sinusoids_scale))

  def get_candidate_harmonics(self, f0_candidates, as_midi=True):
    """Build a harmonic series off of each candidate partial."""
    n = tf.range(1, self.n_harmonic_points + 1, dtype=tf.float32)
    # -> [batch, time, candidate, harmonic]
    harmonics = (f0_candidates[:, :, :, tf.newaxis] *
                 n[tf.newaxis, tf.newaxis, tf.newaxis, :])
    if as_midi:
      harmonics = hz_to_midi(harmonics)
    return harmonics
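

# Illustrative use of TWMLoss.predict_f0 (editor's sketch; all values are
# assumptions): given a perfect harmonic stack on 220 Hz and a few f0
# candidates, the two-way mismatch should select 220 Hz. Lower candidates
# (110 Hz) are penalized by p(harmonics|sinusoids), higher ones (330, 440 Hz)
# by p(sinusoids|harmonics).
def _example_twm_predict_f0():
  harmonic_numbers = tf.range(1, 11, dtype=tf.float32)
  freqs = 220.0 * harmonic_numbers[tf.newaxis, tf.newaxis, :]  # [1, 1, 10]
  amps = tf.ones([1, 1, 10])
  f0_candidates = tf.constant([[[110.0, 220.0, 330.0, 440.0]]])  # [1, 1, 4]
  twm = TWMLoss()
  return twm.predict_f0(f0_candidates, freqs, amps)  # -> [[[220.0]]]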