IO refactor: improve multi-node multi-GPU training efficiency + code reuse #2152

Closed
robin1001 opened this issue Nov 18, 2023 · 22 comments
Labels: enhancement (New feature or request), Stale
@robin1001 (Collaborator)

No description provided.

@robin1001 (Collaborator, Author)

Where is the main IO bottleneck right now?

@xingchensong (Member) commented Nov 18, 2023

  1. After one shuffle-size worth of data finishes, there is a fairly long wait before the next shuffle-size starts  #2095
  2. GPU utilization isn't saturated; there are lots of spikes
  3. Once multimodality is introduced later, IO also has to account for other modalities

@Mddct (Collaborator) commented Nov 20, 2023

On Xingchen's point 3, we could add multi-task support; here is an example.
The following is an answer from ChatGPT:

In multi-task training, the organization of batches and the definition of the loss need to consider the relationships and trade-offs between tasks. Some common approaches:

Batch organization: when building batches, you can:

Draw a fixed number of samples from each task's dataset to form one batch. This suits tasks with similar data volumes or that need balanced treatment.
Give each task its own batch size, adjusted by task importance or data distribution. More important tasks can get larger batches so the model parameters are updated more thoroughly.
Cap the batch size within an acceptable range given hardware and memory limits.
......

In this example we need to be able to conveniently initialize dataset1, dataset2, dataset3:

dataset = interleave(dataset1, dataset2, dataset3)

We need this kind of functionality, and not only for multi-task: for speech tasks we also need to combine data from different domains within a single batch for tuning (a sketch of such an interleave follows below).
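A minimal sketch of what such an interleave could look like, assuming simple round-robin sampling over iterable datasets (interleave here is an illustrative helper, not an existing WeNet API; weighted sampling like HF's interleave_datasets could be layered on the same idea):

import itertools

def interleave(*datasets):
    # Round-robin over several iterable datasets, stopping when the
    # shortest one is exhausted.
    iterators = [iter(d) for d in datasets]
    for it in itertools.cycle(iterators):
        try:
            yield next(it)
        except StopIteration:
            return

# Toy usage: mix an ASR stream with a text-only stream
asr = ({'task': 'asr', 'id': i} for i in range(3))
txt = ({'task': 'text', 'id': i} for i in range(3))
for elem in interleave(asr, txt):
    print(elem)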

@Mddct (Collaborator) commented Nov 25, 2023

One more note: the dataset that recognize.py depends on currently requires labels; the refactored dataset needs to support pure inference as well (see the sketch after this snippet).
For example:

dataset(split='train') # speech and label
dataset(split='cv') # speech and label
dataset(split='infer') # only speech

For integration with large models:

textdataset # only text
speechdataset # only speech

...
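A minimal sketch of that split switch, assuming dict-shaped samples (build_dataset and the field names are illustrative):

def build_dataset(samples, split='train'):
    # train/cv yield speech and label; infer yields speech only,
    # so unlabeled data can go through the same pipeline.
    assert split in ('train', 'cv', 'infer')
    for sample in samples:
        if split == 'infer':
            yield {'speech': sample['speech']}
        else:
            yield {'speech': sample['speech'], 'label': sample['label']}

# Inference-time samples may omit 'label' entirely:
print(list(build_dataset([{'speech': b'wav-bytes'}], split='infer')))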

@xingchensong changed the title from "IO refactor: improve multi-node multi-GPU training efficiency" to "@Mddct @xingchensong IO refactor: improve multi-node multi-GPU training efficiency" Nov 27, 2023
@Mddct (Collaborator) commented Nov 28, 2023

The current bottlenecks are:

1 Data is fetched synchronously (even with prefetch)
2 Some places could be parallelized; e.g. shards/wavs already downloaded into memory could have fbank etc. decoded in parallel. Today this depends entirely on num_worker; too large a num_worker hits bus errors, and with many workers over many tars what we actually need is sub-tar concurrency. Parallelized properly, the concurrency could be num_worker * parallel_num.

Plan 1: integrate huggingface datasets:

1 Pros:

  • native functions:
    • map (parallel_map)
    • filter
    • sort (this one has pitfalls: no buffer size; haven't read the implementation yet)
    • shuffle
    • interleave
    • etc

With these native functions the dataset code would be drastically simplified, and combining with multimodality would also be convenient.
GPU spikes should be reduced, though not as much as with tf.data (production and consumption are not decoupled):

2 Cons:

  • prefetch relies on the PyTorch DataLoader's prefetch
  • tar is not supported; arrow is supported natively. Supporting tar means writing our own _generate_examples, but the yield is Python; if that is slow, the parallelizable parts above get capped as well (native arrow is backed by a C++ implementation)

Plan 2: build our own

Abstract the existing process functions, fill the gaps, and add new ones:

  • parallel_map (multi_processing)
  • filter
  • sort
  • interleave
  • etc

Cons:

  • new features have to keep being added

Pros:

  • simple
  • supports all existing IO formats
  • other modalities can be supported too

ref: https://huggingface.co/docs/datasets/stream

One more note: in the LLM era, many projects have dropped dynamic batching, and HF doesn't support it either (tf.data does). E.g. whisper encodes a fixed 30 s directly, and LLMs often just pad to a preset maximum (not per-batch).
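For reference, a minimal sketch of what plan 1 could look like with the HF streaming API (data1.jsonl / data2.jsonl are placeholder files; the chained calls are the library's documented functions):

from datasets import load_dataset, interleave_datasets

# streaming=True returns IterableDatasets without materializing the corpus
ds1 = load_dataset('json', data_files='data1.jsonl', split='train', streaming=True)
ds2 = load_dataset('json', data_files='data2.jsonl', split='train', streaming=True)

mixed = interleave_datasets([ds1, ds2], probabilities=[0.7, 0.3], seed=42)
mixed = mixed.map(lambda elem: {**elem, 'len': len(elem['text'])})
mixed = mixed.filter(lambda elem: elem['len'] > 0)
mixed = mixed.shuffle(seed=42, buffer_size=1000)  # buffered shuffle, tf.data-style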

@Mddct added the enhancement (New feature or request) label Nov 28, 2023
@xingchensong (Member)

Can plan 2 quickly support the arrow data format?

@Mddct (Collaborator) commented Nov 28, 2023

> Can plan 2 quickly support the arrow data format?

It should work; HF also uses Arrow's Python package.

ref: https://arrow.apache.org/docs/python/

@xingchensong (Member)

OK then, let's go with plan 2.

@Mddct (Collaborator) commented Nov 29, 2023

The new IO needs to account for determinism (reproducibility guaranteed at the data level); shuffle needs to take a seed (a sketch follows below).
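A minimal sketch of what a deterministic, epoch-aware buffered shuffle could look like (mixing seed and epoch this way is an assumption, not a settled design):

import random

def buffered_shuffle(iterable, buffer_size=1000, seed=2023, epoch=0):
    # Same (seed, epoch) pair => same order, so training is reproducible
    # at the data level while still reshuffling every epoch.
    rng = random.Random(seed + epoch)
    buffer = []
    for elem in iterable:
        buffer.append(elem)
        if len(buffer) >= buffer_size:
            rng.shuffle(buffer)
            yield from buffer
            buffer.clear()
    rng.shuffle(buffer)
    yield from buffer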

@Mddct (Collaborator) commented Nov 29, 2023

Testing multiprocess vs multithread:

100 utterances, 9 s each, computing fbank "concurrently":

  • a single utterance takes 0.02 s
  • 100 utterances sequentially: about 2 s
  • multithreaded (nthread=100): also about 2 s (CPU-bound work degrades to serial under the GIL)
  • multiprocess (nproc=100): 0.1 s
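A rough sketch of that benchmark (synthetic 9 s waveforms stand in for the real audio; the timings above are from the run reported here, not from this script):

import time
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor

import torch
from torchaudio.compliance.kaldi import fbank

def extract(_):
    wav = torch.randn(1, 16000 * 9)  # fake 9 s utterance at 16 kHz
    return fbank(wav, num_mel_bins=80, sample_frequency=16000).shape

def bench(executor_cls, workers=32):
    start = time.time()
    with executor_cls(max_workers=workers) as ex:
        list(ex.map(extract, range(100)))
    return time.time() - start

if __name__ == '__main__':
    print('threads  :', bench(ThreadPoolExecutor))   # roughly serial in the reported run
    print('processes:', bench(ProcessPoolExecutor))  # true CPU parallelism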

@Mddct (Collaborator) commented Nov 30, 2023

On code reusability, here is a code snippet:

# Here we could inherit from IterableDataset
class Dataset():
  def __init__(self, dataset, func=None, *args, **kwargs):
    self._dataset = dataset
    self.args = args
    self.kwargs = kwargs
    # default to identity so `from_source` (func=None) still works
    self.func = func if func is not None else (lambda elem: elem)

  @staticmethod
  def from_source(source):
    return Dataset(source)

  def __iter__(self):
    return self

  def __next__(self):
    if not self._dataset:
      raise StopIteration
    data = next(self._dataset)
    return self.func(data)

  def map(self, func, *args, **kwargs):
    return MapperDataset(self, func, *args, **kwargs)

  def filter(self, func, *args, **kwargs):
    return FilterDataset(self, func, *args, **kwargs)

class MapperDataset(Dataset):
    def __init__(self, dataset, func=None, *args, **kwargs):
      self._dataset = dataset
      self.args = args
      self.kwargs = kwargs
      self.func = func

    def __iter__(self):
      return self

    def __next__(self):
      if not self._dataset:
        raise StopIteration
      data = next(self._dataset)
      return self.func(data)

class FilterDataset(Dataset):
    def __init__(self, dataset, func=None, *args, **kwargs):
      self._dataset = dataset
      self.args = args
      self.kwargs = kwargs
      self.func = func

    def __iter__(self):
      return self

    def __next__(self):
      if not self._dataset:
        raise StopIteration
      data = next(self._dataset)
      while not self.func(data):
        data = next(self._dataset)
      return data

source = iter([1,2,3,4])
dataset = Dataset(source, lambda elem: elem)
dataset = dataset.map(lambda elem: {"speech": elem*2})
dataset = dataset.filter(lambda elem_dict: elem_dict['speech'] > 2)
for d in dataset:
  print(d)

# output:
{'speech': 4}
{'speech': 6}
{'speech': 8}

@Mddct changed the title from "IO refactor: improve multi-node multi-GPU training efficiency" to "IO refactor: improve multi-node multi-GPU training efficiency + code reuse" Nov 30, 2023
@echocatzh

During evaluation, WeNet's training script can only run the whole eval dataset on a single GPU; that also looks like an optimization opportunity. Could we leverage DDP for speedup by inheriting from torch's sampler? For example:

import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from catalyst.data.sampler import DistributedSamplerWrapper

dataset = ...
shuffle = ...
sampler = ...

# If DDP is on (`communication` below is assumed to wrap
# torch.distributed.get_world_size() / get_rank())
if torch.distributed.is_initialized():
    # If using a custom sampler make it distributed
    if sampler is not None:
        sampler = DistributedSamplerWrapper(sampler,
                                            shuffle=shuffle,
                                            num_replicas=communication.get_world_size(),
                                            rank=communication.get_rank())
    # If no custom sampler then just use the DistributedSampler
    else:
        sampler = DistributedSampler(dataset,
                                     shuffle=shuffle,
                                     num_replicas=communication.get_world_size(),
                                     rank=communication.get_rank())

# shuffle shouldn't be specified in DataLoader when using a sampler
shuffle = shuffle if sampler is None else None
dataloader = DataLoader(dataset, sampler=sampler, shuffle=shuffle, ...)
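Splitting eval across ranks only pays off if the per-rank outputs are merged afterwards; a minimal sketch of that step, assuming results are picklable Python objects (gather_eval_results is illustrative):

import torch.distributed as dist

def gather_eval_results(local_results):
    # Collect every rank's eval results onto all ranks after a DDP eval pass.
    if not dist.is_initialized():
        return local_results
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, local_results)
    # Flatten the per-rank lists back into a single result list
    return [item for rank_results in gathered for item in rank_results]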

@robin1001 (Collaborator, Author)

Yes, right now eval runs everything on a single GPU and doesn't exploit multiple GPUs. This is a good optimization point.

@echocatzh commented Dec 3, 2023

> Yes, right now eval runs everything on a single GPU and doesn't exploit multiple GPUs. This is a good optimization point.

Looking forward to WeNet's new scripts, haha.

@Mddct (Collaborator) commented Dec 6, 2023

modified code:

class Dataset:

    def __init__(self, source, f=lambda elem: elem, *args, **kw):
        assert callable(f)
        self._dataset = source
        self.f = f
        self.args = args
        self.kw = kw

    def set_epoch(self, epoch):
        # fixed: the source is stored as self._dataset, not self.source
        self._dataset.set_epoch(epoch)

    def __iter__(self):
        """ Return an iterator over the source dataset processed by the
            given processor.
        """
        assert self._dataset is not None
        assert callable(self.f)
        for data in self._dataset:
            yield self.f(data)  # apply f so that `apply` actually takes effect

    def apply(self, f):
        assert callable(f)
        return Dataset(self, f, *self.args, **self.kw)

    def map(self, func, *args, **kwargs):
        return MapperDataset(self, func, *args, **kwargs)

    def filter(self, func, *args, **kwargs):
        return FilterDataset(self, func, *args, **kwargs)

    def sort(self, func, *args, **kwargs):
        return SortDataset(self, func, *args, **kwargs)

    def zip(self, *datasets):
        return ZipDataset(self, *datasets)



class MapperDataset(Dataset):
    def __init__(self, dataset, func=None, *args, **kwargs):
        self._dataset = dataset
        self.args = args
        self.kwargs = kwargs
        self.func = func

    def __iter__(self):
        return self._generator()

    def _generator(self):
        for data in self._dataset:
            yield self.func(data, *self.args, **self.kwargs)

class FilterDataset(Dataset):
    def __init__(self, dataset, func=None, *args, **kwargs):
        self._dataset = dataset
        self.args = args
        self.kwargs = kwargs
        self.func = func

    def __iter__(self):
        return self._generator()

    def _generator(self):
        for data in self._dataset:
            if self.func(data, *self.args, **self.kwargs):
                yield data

class SortDataset(Dataset):
    def __init__(self, dataset, key=None, reverse=False, buffer_size=None):
        self._dataset = dataset
        self.key = key
        self.reverse = reverse
        self.buffer_size = buffer_size

    def __iter__(self):
        return self._generator()

    def _generator(self):
        buffer = []
        for data in self._dataset:
            buffer.append(data)
            if self.buffer_size is not None and len(buffer) >= self.buffer_size:
                sorted_buffer = sorted(buffer, key=self.key, reverse=self.reverse)
                for sorted_data in sorted_buffer:
                    yield sorted_data
                buffer.clear()
        if buffer:
            sorted_buffer = sorted(buffer, key=self.key, reverse=self.reverse)
            for sorted_data in sorted_buffer:
                yield sorted_data

class ZipDataset(Dataset):
    def __init__(self, *datasets):
        self.datasets = datasets

    def __iter__(self):
        return self._generator()

    def _generator(self):
        iterators = [iter(dataset) for dataset in self.datasets]
        while True:
            try:
                data = [next(iterator) for iterator in iterators]
                yield tuple(data)
            except StopIteration:
                return

class PaddingBatchDataset(Dataset):
    def __init__(self, dataset, batch_size, padding_fn, max_length_fn):
        self.dataset = dataset
        self.batch_size = batch_size
        self.padding_fn = padding_fn
        self.max_length_fn = max_length_fn

    def __iter__(self):
        return self._generator()

    def _generator(self):
        batch = []
        max_length = 0
        for data in self.dataset:
            batch.append(data)
            # max_length_fn returns the length of a single element
            max_length = max(max_length, self.max_length_fn(data))
            if len(batch) == self.batch_size:
                yield self._pad_batch(batch, max_length)
                batch = []
                max_length = 0
        if batch:
            yield self._pad_batch(batch, max_length)

    def _pad_batch(self, batch, max_length):
        padded_batch = []
        for data in batch:
            padding_length = max_length - self.max_length_fn(data)
            padded_batch.append(self.padding_fn(data, padding_length))
        return padded_batch
  
# Create the data source
def generator(data):
  for d in data:
    yield d

source = generator([1,2,3,4,1])

# Create a Dataset instance
speech_dataset = Dataset(source)

# preprocess
speech_dataset = speech_dataset.map(lambda elem: {"speech": elem * 2})
speech_dataset = speech_dataset.filter(lambda elem_dict: elem_dict['speech'] >= 2)
speech_dataset = speech_dataset.sort(lambda elem_dict: elem_dict['speech'], buffer_size=2)
# fbank
speech_dataset = speech_dataset.map(lambda elem_dict: {'fbank': elem_dict['speech'] + 1, 'speech': elem_dict['speech']})

llm_dataset = Dataset(generator([10,20,30,40,50,60]))
# eg tokenize
llm_dataset = llm_dataset.map(lambda elem: {"tokens": elem + 1 , "text": elem})

task_dataset = speech_dataset.zip(llm_dataset)
task_dataset = task_dataset.sort(lambda elem: elem[1]['tokens'])

# Iterate and print the results
for data in task_dataset:
    print(data)
# output:
({'fbank': 3, 'speech': 2}, {'tokens': 11, 'text': 10})
({'fbank': 5, 'speech': 4}, {'tokens': 21, 'text': 20})
({'fbank': 7, 'speech': 6}, {'tokens': 31, 'text': 30})
({'fbank': 9, 'speech': 8}, {'tokens': 41, 'text': 40})
({'fbank': 3, 'speech': 2}, {'tokens': 51, 'text': 50})
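The example above never exercises PaddingBatchDataset; a minimal usage sketch, assuming list-valued "features" and len as the length function:

feat_dataset = Dataset(generator([[1], [2, 2], [3, 3, 3], [4]]))
batched = PaddingBatchDataset(feat_dataset,
                              batch_size=2,
                              padding_fn=lambda d, n: d + [0] * n,  # right-pad with zeros
                              max_length_fn=len)
for batch in batched:
    print(batch)
# output:
# [[1, 0], [2, 2]]
# [[3, 3, 3], [4, 0, 0]]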

@kobenaxie (Contributor)

Couldn't we just use lhotse directly? NeMo has also integrated a lhotse shard mode.

@kobenaxie (Contributor)

> Couldn't we just use lhotse directly? NeMo has also integrated a lhotse shard mode.

import logging
import tarfile
from pathlib import Path

import torchaudio
from lhotse.serialization import open_best

AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])

def iterate_tarfile_pairwise(tar_file: tarfile.TarFile):
    result = []
    for tarinfo in tar_file:
        if len(result) == 2:
            yield tuple(result)
            result = []
        result.append(parse_tarinfo(tarinfo, tar_file))

    if len(result) == 2:
        yield tuple(result)

    if len(result) == 1:
        raise RuntimeError(
            "Uneven number of files in the tarfile (expected to iterate pairs of text and binary data)."
        )

def parse_tarinfo(tarinfo: tarfile.TarInfo, tar_file: tarfile.TarFile):
    """
    Parse a tarinfo object and return the data it points to as well as the internal path.
    """
    path = Path(tarinfo.name)
    suffix = path.suffix.strip(".")

    raw_data = tar_file.extractfile(tarinfo)
    if suffix == "txt":
        txt = raw_data.read().decode("utf-8").strip()
        return (path.name, txt)
    elif suffix in AUDIO_FORMAT_SETS:
        waveform, sample_rate = torchaudio.load(raw_data)
        return (waveform, sample_rate)
    else:
        raise RuntimeError(f"Not a supported file format: {suffix}")

def parse_tar(data):
    for sample in data:
        assert "src" in sample, sample.keys()
        url = sample["src"]
        try:
            with tarfile.open(fileobj=open_best(url, mode="rb"), mode="r|*") as tar:
                for (key, txt), (waveform, sample_rate) in iterate_tarfile_pairwise(tar):
                    yield {
                        "key": key,
                        "wav": waveform,
                        "sample_rate": sample_rate,
                        "txt": txt,
                    }
        except Exception as ex:
            logging.warning(f"Failed to open {url}: {ex}")

As a first step, simplify the tar-parsing logic, so that `url_opener` and `tar_file_and_group` can be replaced by `parse_tar`.

@Mddct (Collaborator) commented Jan 17, 2024

lhotse and what WeNet has now are essentially the same:

  • To speed up CPU-bound transforms over multiple items, multithreading won't do; you need multiprocessing (e.g. once a tar is already in memory, today's concurrency is at the shard level, not inside the shard; that's where multiple wav features could be extracted with multiprocessing at once). Both have this problem.

  • For wrappers adapting to various tasks, a simple abstraction on top of the current WeNet is enough.

@Mddct (Collaborator) commented Jan 20, 2024

Torch itself now provides a chained data API: https://github.com/pytorch/pytorch/tree/main/torch/utils/data/datapipes (not torchdata), and WeNet has already upgraded to torch 2.x, so I lean towards using the official torch one. Preliminary code below @xingchensong @robin1001

import io
import json
import logging
import tarfile

import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, IterDataPipe
from torch.utils.data import datapipes
from torchaudio.compliance.kaldi import fbank

AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])


class WenetSourceDataPipe(IterDataPipe):

    def __init__(self, dp, data_type='raw', **fmtparams):
        self.dp = datapipes.iter.ShardingFilter(
            datapipes.iter.FileOpener(dp, mode='r'))
        self.data_type = data_type
        self.fmtparams = fmtparams

    def __iter__(self):
        for _, stream in self.dp:
            for line in stream:
                line = line.strip('\n')
                if self.data_type == 'raw':
                    json_obj = json.loads(line)
                    with open(json_obj['wav'], 'rb') as f:
                        json_obj['wav'] = f.read()
                    yield json_obj
                else:
                    yield {'stream': open(line, 'rb')}


class TarFileGroupSourceDataPipe(IterDataPipe):

    def __init__(self, dp) -> None:
        super().__init__()
        self.dp = dp

    def __iter__(self):
        for sample in self.dp:
            try:
                with tarfile.open(fileobj=sample['stream'],
                                  mode="r:*") as stream:
                    prev_prefix = None
                    example = {}
                    valid = True
                    for tarinfo in stream:
                        name = tarinfo.name
                        pos = name.rfind('.')
                        assert pos > 0
                        prefix, postfix = name[:pos], name[pos + 1:]
                        if prev_prefix is not None and prefix != prev_prefix:
                            example['key'] = prev_prefix
                            if valid:
                                yield example
                            example = {}
                            valid = True
                        with stream.extractfile(tarinfo) as file_obj:
                            try:
                                if postfix == 'txt':
                                    example['txt'] = file_obj.read().decode(
                                        'utf8').strip()
                                elif postfix in AUDIO_FORMAT_SETS:
                                    example['wav'] = file_obj.read()
                                else:
                                    example[postfix] = file_obj.read()
                            except Exception as ex:
                                valid = False
                                logging.warning(
                                    'error to parse {}'.format(name))
                            prev_prefix = prefix
                    if prev_prefix is not None:
                        example['key'] = prev_prefix
                        yield example
            except Exception as ex:
                logging.warning(
                    'In tar_file_and_group: {} when processing'.format(ex))
            finally:
                # the `with` block already closed the tar stream; only release
                # the underlying resources here
                if 'process' in sample:
                    sample['process'].communicate()
                sample['stream'].close()


def decode_wav(elem):
    wav = elem['wav']
    key = elem['key']
    txt = elem['txt']
    with io.BytesIO(wav) as file_obj:
        waveform, sr = torchaudio.load(file_obj)
    return {"key": key, "txt": txt, 'waveform': waveform, "sample_rate": sr}


def compute_fbank(data,
                  num_mel_bins=23,
                  frame_length=25,
                  frame_shift=10,
                  dither=0.0):

    sample_rate = data['sample_rate']
    waveform = data['waveform']
    waveform = waveform * (1 << 15)
    mat = fbank(waveform,
                num_mel_bins=num_mel_bins,
                frame_length=frame_length,
                frame_shift=frame_shift,
                dither=dither,
                energy_floor=0.0,
                sample_frequency=sample_rate)
    data['feat'] = mat
    return data


def padding(data):
    assert isinstance(data, list)
    sample = data
    feats_length = torch.tensor([x['feat'].size(0) for x in sample],
                                dtype=torch.int32)
    order = torch.argsort(feats_length, descending=True)
    feats_lengths = torch.tensor([sample[i]['feat'].size(0) for i in order],
                                 dtype=torch.int32)
    sorted_feats = [sample[i]['feat'] for i in order]
    sorted_keys = [sample[i]['key'] for i in order]
    padded_feats = pad_sequence(sorted_feats,
                                batch_first=True,
                                padding_value=0)
    batch = {
        "keys": sorted_keys,
        "feats": padded_feats,
        "feats_lengths": feats_lengths,
    }
    return batch


def get_dataloader(data_type, files):
    dataset = WenetSourceDataPipe(files, data_type)
    # shard by files
    if data_type == 'shard':
        dataset = WenetSourceDataPipe(files, data_type=data_type)

        dataset = TarFileGroupSourceDataPipe(dataset)

    dataset = dataset.map(decode_wav)
    dataset = dataset.map(compute_fbank)
    dataset = dataset.batch(wrapper_class=padding, batch_size=2)
    dataloader = DataLoader(dataset,
                            batch_size=None,
                            num_workers=4,
                            persistent_workers=True)

    return dataloader


if __name__ == '__main__':
    raw_dataloader = get_dataloader('raw',
                                    ['test/resources/dataset/data.list'])
    tar_dataloader = get_dataloader(
        'shard', ['test/resources/dataset/data.shards.list'])

    print("--------" + "wenet raw data type" + '---------\n')
    for raw_batch in raw_dataloader:
        print(raw_batch)

    print("\n--------" + "wenet shard data type" + '---------\n')
    for shard_batch in tar_dataloader:
        print(shard_batch)

The refactoring plan from here:
1 datasetsource (supports auto shard, shard by line/files)
2 processor keeps the original per-"elem" logic from inside the for-loops, invoked through map / filter etc.

Advantages:

  • e.g. a whisper hybrid tokenizer can assemble its own task (see the sketch after this list):
    1 datasetsource
    2 feats
    3 hybrid tokenizer (just write your own 'elem' function, then map it)
    4 batch
    while still reusing the current yaml, e.g.
dataset: whisper_dataset  # previously ASRDataset
  • same for tts/llm
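A minimal sketch of that composition, reusing the datapipes prototype above (hybrid_tokenize and CHAR_VOCAB are illustrative stand-ins for a real whisper-style tokenizer):

CHAR_VOCAB = {'a': 1, 'b': 2}  # toy vocab for illustration

def hybrid_tokenize(elem):
    # Per-'elem' function: a real hybrid tokenizer would plug in here unchanged
    elem['tokens'] = [CHAR_VOCAB.get(ch, 0) for ch in elem['txt']]
    return elem

def get_whisper_dataloader(files):
    # Same chain as get_dataloader above, plus one extra map for the tokenizer
    dataset = WenetSourceDataPipe(files, data_type='raw')
    dataset = dataset.map(decode_wav)
    dataset = dataset.map(compute_fbank)
    dataset = dataset.map(hybrid_tokenize)
    dataset = dataset.batch(wrapper_class=padding, batch_size=2)
    return DataLoader(dataset, batch_size=None, num_workers=4)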

@kobenaxie (Contributor)

> [quotes the datapipes prototype and refactoring plan above]

The file handle `stream` in WenetSourceDataPipe can be wrapped with torch's StreamWrapper, avoiding manual stream.close() calls elsewhere:

from torch.utils.data.datapipes.utils.common import StreamWrapper
stream = StreamWrapper(open(line, 'rb'))

@bqzhu922
@xingchensong Sorry to bother you. Compared to before the refactor, roughly how much did this IO rework improve multi-node multi-GPU efficiency? Are there any quantitative comparisons? Thanks.

@Mddct (Collaborator) commented Dec 18, 2024 via email
