Skip to content

Commit

Permalink
Generic script to convert kaldi datasets to list file format
Browse files Browse the repository at this point in the history
Summary: As per title

Reviewed By: avidov

Differential Revision: D27444664

fbshipit-source-id: ff964962cec567fa26cf072175b5734a1adc7be4
  • Loading branch information
vineelpratap authored and facebook-github-bot committed Mar 31, 2021
1 parent 668ec5e commit 37aa855
Showing 1 changed file with 139 additions and 0 deletions.
139 changes: 139 additions & 0 deletions data/utils/kaldi_to_listfile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
"""
Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
----------
Script to package kaldi data directory into a form readable in
wav2letter++ pipelines
Command : python3 prepare.py --src [...] --dst [...]
Replace [...] with appropriate path
`src` directory is the path to kaldi data directory typically
prepared with `prepare_data.sh` script.
`dst` directory is the path to store (segmented) audio files and the
list file that is used by wav2letter++ pipelines to load data.
"""

from __future__ import (
absolute_import,
division,
print_function,
unicode_literals,
)

import argparse
import os
import re
from multiprocessing import Pool

import sox
from tqdm import tqdm


def run_segment(item):
uid, val = item
infile, start_sec, end_sec, outfile = val
sox_tfm = sox.Transformer()
sox_tfm.set_output_format(
file_type="flac", encoding="signed-integer", bits=16
)
sox_tfm.trim(start_sec, end_sec)
sox_tfm.build(infile, outfile)


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Script to generate list file from Kaldi data dir"
)
parser.add_argument(
"--src",
help="input kaldi data directory. Must contain "
"'text', 'segments' and 'wav.scp' files",
)
parser.add_argument(
"--dst", help="destination directory where to store data",
)
parser.add_argument(
"--name", help="name of the output list file", default="data.lst"
)
parser.add_argument(
"-p",
"--process",
help="number of process for multiprocessing",
default=8,
type=int,
)

args = parser.parse_args()

wav_files = {}
cache = {}
cmds = []
with open(f"{args.src}/wav.scp") as f:
for line in f:
# handles two possible cases
# Case 1: ID followed by wav file
# Ex: S03_U01.CH1 /path/S03_U01.CH1.wav
# Case 2: ID followed by sox script
# Ex: P09_S03.L sox /path/S03_P09.wav -t wav - remix 1 |
wid, wav_handle = line.strip().split(" ", 1)
if wav_handle in cache:
wav_file = cache[wav_handle]
elif wav_handle.startswith("sox"):
hsh = re.sub("[^0-9a-zA-Z]+", "", wav_handle)
wav_file = "/tmp/{}.wav".format(hsh)
cmds.append(
wav_handle.replace(" - ", " " + wav_file + " ").replace(
"|", ""
)
)
else:
wav_file = wav_handle
wav_files[wid] = wav_file
print("Found {} wav files".format(len(wav_files)))

print("Running {} wav commands ...".format(len(cmds)))

def run_command(cmd):
os.system(cmd)

p = Pool(args.process)
list(tqdm(p.imap(run_command, cmds), total=len(cmds),))

transcripts = {}
with open(f"{args.src}/text") as f:
for line in f:
line_split = line.strip().split()
transcripts[line_split[0]] = " ".join(line_split[1:])
print("Found {} transcripts".format(len(transcripts)))

segments = {}
with open(f"{args.src}/segments") as f:
for line in f:
uid, wid, start_sec, end_sec = line.strip().split(" ", 3)
start_sec = float(start_sec)
end_sec = float(end_sec)
outfile = f"{args.dst}/audio/{uid}.flac"
segments[uid] = (wav_files[wid], start_sec, end_sec, outfile)
print("Found {} segments".format(len(segments)))

os.makedirs(f"{args.dst}", exist_ok=True)
os.makedirs(f"{args.dst}/audio", exist_ok=True)

print("Creating segmented audio files ...")
list(tqdm(p.imap(run_segment, segments.items()), total=len(segments),))

print("Writing to list file ...")
with open(f"{args.dst}/{args.name}", "w") as fo:
for uid, val in segments.items():
_, start_sec, end_sec, outfile = val
duration = "{:.2f}".format((end_sec - start_sec) * 1000)
fo.write(
"\t".join([uid, outfile, duration, transcripts[uid]]) + "\n"
)

print("Done!")

0 comments on commit 37aa855

Please sign in to comment.