Skip to content

Commit

Permalink
Add Unit test for SDS demo model
Browse files Browse the repository at this point in the history
  • Loading branch information
Siddhant committed Jan 2, 2025
1 parent c82d743 commit 1d21ffd
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 0 deletions.
2 changes: 2 additions & 0 deletions espnet2/sds/espnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def handle_type_selection(
)
if option == "Cascaded":
self.client = None
self.type_option = "Cascaded"
for _ in self.handle_TTS_selection(TTS_radio):
continue
for _ in self.handle_ASR_selection(ASR_radio):
Expand All @@ -257,6 +258,7 @@ def handle_type_selection(
gr.Radio(visible=False),
)
else:
self.type_option = "E2E"
self.text2speech = None
self.s2t = None
self.LM_pipe = None
Expand Down
88 changes: 88 additions & 0 deletions test/espnet2/sds/test_espnet_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import pytest
import soundfile
import torch

from espnet2.sds.espnet_model import ESPnetSDSModelInterface

pytest.importorskip("gradio")


def test_forward():
pytest.importorskip("webrtcvad")
if not torch.cuda.is_available():
return # Only GPU supported
dialogue_model = ESPnetSDSModelInterface(
ASR_option="librispeech_asr",
LLM_option="HuggingFaceTB/SmolLM2-1.7B-Instruct",
TTS_option="kan-bayashi/ljspeech_vits",
type_option="Cascaded",
access_token="",
)
x, rate = soundfile.read("test_utils/ctc_align_test.wav", dtype="int16")
gen = dialogue_model.handle_type_selection(
option="Cascaded",
TTS_radio="kan-bayashi/ljspeech_vits",
ASR_radio="librispeech_asr",
LLM_radio="HuggingFaceTB/SmolLM2-1.7B-Instruct",
)
for _ in gen:
continue
dialogue_model.forward(
x,
rate,
x,
asr_output_str=None,
text_str=None,
audio_output=None,
audio_output1=None,
latency_ASR=0.0,
latency_LM=0.0,
latency_TTS=0.0,
)


def test_handle_E2E_selection():
pytest.importorskip("pydub")
pytest.importorskip("espnet2.sds.end_to_end.mini_omni.inference")
pytest.importorskip("huggingface_hub")
if not torch.cuda.is_available():
return # Only GPU supported
dialogue_model = ESPnetSDSModelInterface(
ASR_option="librispeech_asr",
LLM_option="HuggingFaceTB/SmolLM2-1.7B-Instruct",
TTS_option="kan-bayashi/ljspeech_vits",
type_option="Cascaded",
access_token="",
)
x, rate = soundfile.read("test_utils/ctc_align_test.wav", dtype="int16")
dialogue_model.handle_type_selection(
option="E2E",
TTS_radio="kan-bayashi/ljspeech_vits",
ASR_radio="librispeech_asr",
LLM_radio="HuggingFaceTB/SmolLM2-1.7B-Instruct",
)
assert dialogue_model.text2speech is None
assert dialogue_model.s2t is None
assert dialogue_model.LM_pipe is None
assert dialogue_model.ASR_curr_name is None
assert dialogue_model.LLM_curr_name is None
assert dialogue_model.TTS_curr_name is None
dialogue_model.forward(
x,
rate,
x,
asr_output_str=None,
text_str=None,
audio_output=None,
audio_output1=None,
latency_ASR=0.0,
latency_LM=0.0,
latency_TTS=0.0,
)
dialogue_model.handle_type_selection(
option="Cascaded",
TTS_radio="kan-bayashi/ljspeech_vits",
ASR_radio="librispeech_asr",
LLM_radio="HuggingFaceTB/SmolLM2-1.7B-Instruct",
)
assert dialogue_model.client is None

0 comments on commit 1d21ffd

Please sign in to comment.