Skip to content

Commit

Permalink
fix: imports
Browse files Browse the repository at this point in the history
  • Loading branch information
Erik committed Oct 29, 2022
1 parent 508852b commit 69d420d
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions scripts/run_tama.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
from pathlib import Path
from typing import List, Union

from entrainment import calculate_sample_correlation, calculate_time_series
from frame import Frame, MissingFrame
from tama import get_frames
from utils import print_audio_description
from speechalyze import tama
from speechalyze.utils import print_audio_description

arg_parser = argparse.ArgumentParser(
description="Generate a times series for a speaker for a task"
Expand Down Expand Up @@ -47,34 +45,38 @@ def main() -> None:

wav_a_fname: Path = Path(args.audio_file_a)
words_a_fname: Path = Path(args.words_file_a)
frames_a: List[Union[Frame, MissingFrame]] = get_frames(wav_a_fname, words_a_fname)
frames_a: List[Union[tama.Frame, tama.MissingFrame]] = tama.get_frames(
wav_a_fname, words_a_fname
)
print(f"Amount of frames of speaker A: {len(frames_a)}")

print_audio_description("A", wav_a_fname)

wav_b_fname: Path = Path(args.audio_file_b)
words_b_fname: Path = Path(args.words_file_b)
frames_b: List[Union[Frame, MissingFrame]] = get_frames(wav_b_fname, words_b_fname)
frames_b: List[Union[tama.Frame, tama.MissingFrame]] = tama.get_frames(
wav_b_fname, words_b_fname
)
print_audio_description("B", wav_b_fname)
print(f"Amount of frames of speaker B: {len(frames_b)}")

if len(frames_a) != len(frames_b):
raise ValueError("The amount of frames of each speaker is different")

time_series_a: List[float] = calculate_time_series(
time_series_a: List[float] = tama.calculate_time_series(
args.feature, frames_a, wav_a_fname, args.extractor, args.pitch_gender_a
)
print("----------------------------------------")
print(f"Time series of A: {time_series_a}")

time_series_b: List[float] = calculate_time_series(
time_series_b: List[float] = tama.calculate_time_series(
args.feature, frames_b, wav_b_fname, args.extractor, args.pitch_gender_b
)
print(f"Time series of B: {time_series_b}")
print("----------------------------------------")

print("Sample cross-correlation")
sample_cross_correlations: List[float] = calculate_sample_correlation(
sample_cross_correlations: List[float] = tama.calculate_sample_correlation(
time_series_a, time_series_b, int(args.lags)
)
print(f"Correlations with lag from 0 to {args.lags}: {sample_cross_correlations}")
Expand Down

0 comments on commit 69d420d

Please sign in to comment.