diff --git a/model/modules.py b/model/modules.py index db86e2a..dae77df 100644 --- a/model/modules.py +++ b/model/modules.py @@ -186,7 +186,7 @@ def forward(self, encoder_output, duration, mask): range_param_prediction = self.range_param_proj(range_param_prediction) range_param_prediction = range_param_prediction.squeeze(-1) # [B, L] if mask is not None: - range_param_prediction = range_param_prediction.masked_fill(mask, 0.0) + range_param_prediction = range_param_prediction.masked_fill(mask, 1e-8) return range_param_prediction @@ -220,6 +220,6 @@ def forward(self, encoder_output, audio, seq_starts=None, full_len=False): return encoder_output, audio if encoder_output.shape[1] > self.segment_length: encoder_segment = self.get_hidden_segment(encoder_output, seq_starts) - encoder_segment = self.pad_seq(encoder_output, self.segment_length) + encoder_segment = self.pad_seq(encoder_segment, self.segment_length) audio_segment = self.pad_seq(audio, self.segment_length_up) return encoder_segment, audio_segment