IO refactor: improve multi-machine multi-GPU training efficiency + code reuse #2152
Comments
Where is the main IO bottleneck right now? |
On point 3 that Xingchen mentioned, we could add multi-task support. Here is an example:
we need to be able to conveniently initialize dataset1, dataset2, dataset3 and then do
dataset = interleave(dataset1, dataset2, dataset3). We need this kind of functionality, and
not only for multi-task: for speech tasks we also need to combine data from different domains
within a single batch for tuning (see the interleave sketch below). |
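A minimal, hypothetical sketch of such an interleave over plain Python iterables
(round-robin until all sources are exhausted; not wenet code, names are illustrative only):

def interleave(*datasets):
    iterators = [iter(d) for d in datasets]
    while iterators:
        alive = []
        for it in iterators:
            try:
                yield next(it)   # round-robin: one element from each live source
                alive.append(it)
            except StopIteration:
                pass             # drop exhausted sources
        iterators = alive

dataset1 = [{'task': 'asr', 'id': i} for i in range(2)]
dataset2 = [{'task': 'tts', 'id': i} for i in range(3)]
for sample in interleave(dataset1, dataset2):
    print(sample)   # asr/tts samples alternate until both run out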
One more note: the dataset that recognize.py currently depends on requires labels, so the
refactored dataset should also cover a pure inference dataset (a sketch follows below):
dataset(split='train')  # speech and label
dataset(split='cv')     # speech and label
dataset(split='infer')  # only speech
For integration with large models:
textdataset    # only text
speechdataset  # only speech
... |
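A tiny, hypothetical sketch of the split idea (illustrative names only, not wenet's actual API):
the split decides whether labels are attached, so an infer dataset carries speech only.

def build_dataset(samples, split='train'):
    for sample in samples:
        item = {'speech': sample['wav']}
        if split in ('train', 'cv'):
            item['label'] = sample['txt']   # labels only for train/cv
        yield item

print(next(build_dataset([{'wav': 'a.wav', 'txt': 'hi'}], split='train')))  # speech and label
print(next(build_dataset([{'wav': 'a.wav', 'txt': 'hi'}], split='infer')))  # only speech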
The current bottleneck: 1. synchronous data fetching (even with prefetch).
Option 1: integrate huggingface datasets.
1. Pros: with these native functions already implemented, the dataset code would be drastically
simplified, and integration with multimodal data would also be very convenient.
2. Cons:
Option 2: build our own abstraction over the existing process functions, fill in what is missing,
and add new functions.
Cons:
Pros:
ref: https://huggingface.co/docs/datasets/stream
One more note: in the era of large models, many toolkits have dropped dynamic batching
(huggingface does not support it either, while tf.data does). For example Whisper simply feeds a
fixed 30s to the encoder, and LLMs often pad directly to a preset maximum length rather than per
batch. A rough illustration of Option 1 is sketched below. |
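A rough, hedged illustration of Option 1 (the dataset name and fields below are only examples):
huggingface datasets in streaming mode already provides chainable map / filter / shuffle / interleave.

from datasets import load_dataset, interleave_datasets

asr = load_dataset("librispeech_asr", "clean", split="train.100", streaming=True)
asr = asr.map(lambda ex: {"num_chars": len(ex["text"])})   # lazy, element-wise
asr = asr.filter(lambda ex: ex["num_chars"] > 0)
asr = asr.shuffle(seed=42, buffer_size=1000)               # approximate buffered shuffle
# multi-task / multi-domain mixing would look like:
# mixed = interleave_datasets([asr, other_stream], probabilities=[0.5, 0.5], seed=42)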
That should work; huggingface also uses the arrow python package. |
OK, then I support Option 2. |
The new IO also needs to consider determinism (reproducibility guaranteed at the data level);
a small sketch of the idea follows. |
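A minimal sketch of what data-level determinism could mean here (an assumption, not wenet code):
derive the shuffle RNG from a fixed seed plus the epoch and rank, so every run reproduces the
same order.

import random

def deterministic_shuffle(samples, seed, epoch, rank=0):
    # one independent, reproducible RNG per (seed, epoch, rank)
    rng = random.Random(seed + 10007 * epoch + 7 * rank)
    samples = list(samples)
    rng.shuffle(samples)
    return samples

print(deterministic_shuffle(range(5), seed=2023, epoch=0))
print(deterministic_shuffle(range(5), seed=2023, epoch=0))  # identical order
print(deterministic_shuffle(range(5), seed=2023, epoch=1))  # different, but reproducible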
Regarding code reuse, here is a code snippet:

# here we could inherit from IterableDataset
class Dataset():

    def __init__(self, dataset, func=lambda elem: elem, *args, **kwargs):
        self._dataset = dataset
        self.args = args
        self.kwargs = kwargs
        self.func = func

    @staticmethod
    def from_source(source):
        return Dataset(source)

    def __iter__(self):
        return self

    def __next__(self):
        if not self._dataset:
            raise StopIteration
        data = next(self._dataset)
        return self.func(data)

    def map(self, func, *args, **kwargs):
        return MapperDataset(self, func, *args, **kwargs)

    def filter(self, func, *args, **kwargs):
        return FilterDataset(self, func, *args, **kwargs)


class MapperDataset(Dataset):

    def __init__(self, dataset, func=None, *args, **kwargs):
        self._dataset = dataset
        self.args = args
        self.kwargs = kwargs
        self.func = func

    def __iter__(self):
        return self

    def __next__(self):
        if not self._dataset:
            raise StopIteration
        data = next(self._dataset)
        return self.func(data)


class FilterDataset(Dataset):

    def __init__(self, dataset, func=None, *args, **kwargs):
        self._dataset = dataset
        self.args = args
        self.kwargs = kwargs
        self.func = func

    def __iter__(self):
        return self

    def __next__(self):
        if not self._dataset:
            raise StopIteration
        data = next(self._dataset)
        while not self.func(data):
            data = next(self._dataset)
        return data


source = iter([1, 2, 3, 4])
dataset = Dataset(source, lambda elem: elem)
dataset = dataset.map(lambda elem: {"speech": elem * 2})
dataset = dataset.filter(lambda elem_dict: elem_dict['speech'] > 2)
for d in dataset:
    print(d)
# output:
{'speech': 4}
{'speech': 6}
{'speech': 8} |
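A hedged sketch of how a chain like the one above could be fed to a torch DataLoader by wrapping
it in an IterableDataset (as the comment about inheriting IterableDataset hints); the factory
pattern below is only one possible way to give each worker its own iterator.

import torch
from torch.utils.data import DataLoader, IterableDataset

class TorchIterable(IterableDataset):
    def __init__(self, chain_factory):
        self.chain_factory = chain_factory   # build the chain lazily, per worker

    def __iter__(self):
        return iter(self.chain_factory())

def build_chain():
    chain = Dataset(iter([1, 2, 3, 4]), lambda elem: elem)
    return chain.map(lambda elem: {"speech": torch.tensor([elem * 2])})

loader = DataLoader(TorchIterable(build_chain), batch_size=2, num_workers=0)
for batch in loader:
    print(batch)   # e.g. {'speech': tensor([[2], [4]])}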
When evaluating, the wenet training script can only run the whole eval dataset on a single GPU,
which also looks like something that could be optimized. Could we use DDP to speed this up by
wrapping torch's sampler? For example:

from torch.utils.data.distributed import DistributedSampler
from catalyst.data.sampler import DistributedSamplerWrapper

dataset = ...
shuffle = ...
sampler = ...
# If DDP on
if torch.distributed.is_initialized():
    # If using a custom sampler, make it distributed
    # (communication here is the author's helper around torch.distributed
    #  get_world_size() / get_rank())
    if sampler is not None:
        sampler = DistributedSamplerWrapper(sampler,
                                            shuffle=shuffle,
                                            num_replicas=communication.get_world_size(),
                                            rank=communication.get_rank())
    # If no custom sampler then just use the DistributedSampler
    else:
        sampler = DistributedSampler(dataset,
                                     shuffle=shuffle,
                                     num_replicas=communication.get_world_size(),
                                     rank=communication.get_rank())
# shuffle shouldn't be specified in DataLoader when using a sampler
shuffle = shuffle if sampler is None else None
dataloader = DataLoader(dataset, sampler=sampler, shuffle=shuffle, ...) |
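A hedged companion sketch (an assumption, not wenet code): once each rank evaluates only its own
shard, the per-rank sums can be combined with all_reduce so every rank sees the same global metric.

import torch
import torch.distributed as dist

def reduce_eval_stats(total_loss, num_batches, device):
    stats = torch.tensor([total_loss, float(num_batches)], device=device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(stats, op=dist.ReduceOp.SUM)   # sum over all ranks
    global_loss, global_batches = stats.tolist()
    return global_loss / max(global_batches, 1.0)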
Yes, at the moment evaluation runs everything on a single GPU, so the multi-GPU setup is not
exploited there. That is a good optimization point. |
Looking forward to the new wenet scripts, haha |
modified code:

class Dataset:

    def __init__(self, source, f=lambda elem: elem, *args, **kw):
        assert callable(f)
        self._dataset = source
        self.f = f
        self.args = args
        self.kw = kw

    def set_epoch(self, epoch):
        self._dataset.set_epoch(epoch)

    def __iter__(self):
        """ Return an iterator over the source dataset processed by the
            given processor.
        """
        assert self._dataset is not None
        assert callable(self.f)
        for data in self._dataset:
            yield self.f(data, *self.args, **self.kw)

    def apply(self, f):
        assert callable(f)
        return Dataset(self, f, *self.args, **self.kw)

    def map(self, func, *args, **kwargs):
        return MapperDataset(self, func, *args, **kwargs)

    def filter(self, func, *args, **kwargs):
        return FilterDataset(self, func, *args, **kwargs)

    def sort(self, func, *args, **kwargs):
        return SortDataset(self, func, *args, **kwargs)

    def zip(self, *datasets):
        return ZipDataset(self, *datasets)


class MapperDataset(Dataset):

    def __init__(self, dataset, func=None, *args, **kwargs):
        self._dataset = dataset
        self.args = args
        self.kwargs = kwargs
        self.func = func

    def __iter__(self):
        return self._generator()

    def _generator(self):
        for data in self._dataset:
            yield self.func(data, *self.args, **self.kwargs)


class FilterDataset(Dataset):

    def __init__(self, dataset, func=None, *args, **kwargs):
        self._dataset = dataset
        self.args = args
        self.kwargs = kwargs
        self.func = func

    def __iter__(self):
        return self._generator()

    def _generator(self):
        for data in self._dataset:
            if self.func(data, *self.args, **self.kwargs):
                yield data


class SortDataset(Dataset):

    def __init__(self, dataset, key=None, reverse=False, buffer_size=None):
        self._dataset = dataset
        self.key = key
        self.reverse = reverse
        self.buffer_size = buffer_size

    def __iter__(self):
        return self._generator()

    def _generator(self):
        buffer = []
        for data in self._dataset:
            buffer.append(data)
            if self.buffer_size is not None and len(buffer) >= self.buffer_size:
                sorted_buffer = sorted(buffer, key=self.key, reverse=self.reverse)
                for sorted_data in sorted_buffer:
                    yield sorted_data
                buffer.clear()
        if buffer:
            sorted_buffer = sorted(buffer, key=self.key, reverse=self.reverse)
            for sorted_data in sorted_buffer:
                yield sorted_data


class ZipDataset(Dataset):

    def __init__(self, *datasets):
        self.datasets = datasets

    def __iter__(self):
        return self._generator()

    def _generator(self):
        iterators = [iter(dataset) for dataset in self.datasets]
        while True:
            try:
                data = [next(iterator) for iterator in iterators]
                yield tuple(data)
            except StopIteration:
                return


class PaddingBatchDataset(Dataset):

    def __init__(self, dataset, batch_size, padding_fn, max_length_fn):
        self.dataset = dataset
        self.batch_size = batch_size
        self.padding_fn = padding_fn
        self.max_length_fn = max_length_fn

    def __iter__(self):
        return self._generator()

    def _generator(self):
        batch = []
        max_length = 0
        for data in self.dataset:
            batch.append(data)
            # max_length_fn(data) returns the length of a single element
            max_length = max(max_length, self.max_length_fn(data))
            if len(batch) == self.batch_size:
                yield self._pad_batch(batch, max_length)
                batch = []
                max_length = 0
        if batch:
            yield self._pad_batch(batch, max_length)

    def _pad_batch(self, batch, max_length):
        padded_batch = []
        for data in batch:
            padding_length = max_length - self.max_length_fn(data)
            padded_batch.append(self.padding_fn(data, padding_length))
        return padded_batch


# create the data source
def generator(data):
    for d in data:
        yield d

source = generator([1, 2, 3, 4, 1])
# create a Dataset instance
speech_dataset = Dataset(source)
# preprocess
speech_dataset = speech_dataset.map(lambda elem: {"speech": elem * 2})
speech_dataset = speech_dataset.filter(lambda elem_dict: elem_dict['speech'] >= 2)
speech_dataset = speech_dataset.sort(lambda elem_dict: elem_dict['speech'], buffer_size=2)
# fbank
speech_dataset = speech_dataset.map(lambda elem_dict: {'fbank': elem_dict['speech'] + 1,
                                                       'speech': elem_dict['speech']})

llm_dataset = Dataset(generator([10, 20, 30, 40, 50, 60]))
# e.g. tokenize
llm_dataset = llm_dataset.map(lambda elem: {"tokens": elem + 1, "text": elem})

task_dataset = speech_dataset.zip(llm_dataset)
task_dataset = task_dataset.sort(lambda elem: elem[1]['tokens'])

# iterate and print the results
for data in task_dataset:
    print(data)
# output:
({'fbank': 3, 'speech': 2}, {'tokens': 11, 'text': 10})
({'fbank': 5, 'speech': 4}, {'tokens': 21, 'text': 20})
({'fbank': 7, 'speech': 6}, {'tokens': 31, 'text': 30})
({'fbank': 9, 'speech': 8}, {'tokens': 41, 'text': 40})
({'fbank': 3, 'speech': 2}, {'tokens': 51, 'text': 50}) |
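A small usage sketch for the PaddingBatchDataset above, using toy list "features" and the
one-argument max_length_fn convention from the code above (the callables are only illustrative):

toy = Dataset(generator([[1], [2, 3], [4, 5, 6]]))
batched = PaddingBatchDataset(
    toy,
    batch_size=2,
    padding_fn=lambda elem, pad_len: elem + [0] * pad_len,   # right-pad with zeros
    max_length_fn=lambda elem: len(elem),
)
for b in batched:
    print(b)
# [[1, 0], [2, 3]]
# [[4, 5, 6]]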
Could we just use lhotse directly? |
The first step is to simplify the logic for parsing the tar packages, so that |
lhotse is essentially no different from what wenet has now. |
Torch now officially provides a chain-style data API: https://github.com/pytorch/pytorch/tree/main/torch/utils/data/datapipes (not torchdata), and wenet has already been upgraded to torch 2.x, so I lean towards using the official torch one. Below is preliminary code @xingchensong @robin1001

import io
import json
import logging
import tarfile

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, IterDataPipe
from torch.utils.data import datapipes
import torchaudio
from torchaudio.compliance.kaldi import fbank

AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])


class WenetSourceDataPipe(IterDataPipe):

    def __init__(self, dp, data_type='raw', **fmtparams):
        self.dp = datapipes.iter.ShardingFilter(
            datapipes.iter.FileOpener(dp, mode='r'))
        self.data_type = data_type
        self.fmtparams = fmtparams

    def __iter__(self):
        for _, stream in self.dp:
            for line in stream:
                line = line.strip('\n')
                if self.data_type == 'raw':
                    json_obj = json.loads(line)
                    with open(json_obj['wav'], 'rb') as f:
                        json_obj['wav'] = f.read()
                    yield json_obj
                else:
                    yield {'stream': open(line, 'rb')}


class TarFileGroupSourceDataPipe(IterDataPipe):

    def __init__(self, dp) -> None:
        super().__init__()
        self.dp = dp

    def __iter__(self):
        for sample in self.dp:
            try:
                with tarfile.open(fileobj=sample['stream'],
                                  mode="r:*") as stream:
                    prev_prefix = None
                    example = {}
                    valid = True
                    for tarinfo in stream:
                        name = tarinfo.name
                        pos = name.rfind('.')
                        assert pos > 0
                        prefix, postfix = name[:pos], name[pos + 1:]
                        if prev_prefix is not None and prefix != prev_prefix:
                            example['key'] = prev_prefix
                            if valid:
                                yield example
                            example = {}
                            valid = True
                        with stream.extractfile(tarinfo) as file_obj:
                            try:
                                if postfix == 'txt':
                                    example['txt'] = file_obj.read().decode(
                                        'utf8').strip()
                                elif postfix in AUDIO_FORMAT_SETS:
                                    example['wav'] = file_obj.read()
                                else:
                                    example[postfix] = file_obj.read()
                            except Exception as ex:
                                valid = False
                                logging.warning(
                                    'error to parse {}'.format(name))
                        prev_prefix = prefix
                    if prev_prefix is not None:
                        example['key'] = prev_prefix
                        yield example
            except Exception as ex:
                logging.warning(
                    'In tar_file_and_group: {} when processing '.format(ex))
            finally:
                if 'process' in sample:
                    sample['process'].communicate()
                sample['stream'].close()


def decode_wav(elem):
    wav = elem['wav']
    key = elem['key']
    txt = elem['txt']
    with io.BytesIO(wav) as file_obj:
        waveform, sr = torchaudio.load(file_obj)
    return {"key": key, "txt": txt, 'waveform': waveform, "sample_rate": sr}


def compute_fbank(data,
                  num_mel_bins=23,
                  frame_length=25,
                  frame_shift=10,
                  dither=0.0):
    sample_rate = data['sample_rate']
    waveform = data['waveform']
    waveform = waveform * (1 << 15)
    mat = fbank(waveform,
                num_mel_bins=num_mel_bins,
                frame_length=frame_length,
                frame_shift=frame_shift,
                dither=dither,
                energy_floor=0.0,
                sample_frequency=sample_rate)
    data['feat'] = mat
    return data


def padding(data):
    assert isinstance(data, list)
    sample = data
    feats_length = torch.tensor([x['feat'].size(0) for x in sample],
                                dtype=torch.int32)
    order = torch.argsort(feats_length, descending=True)
    feats_lengths = torch.tensor([sample[i]['feat'].size(0) for i in order],
                                 dtype=torch.int32)
    sorted_feats = [sample[i]['feat'] for i in order]
    sorted_keys = [sample[i]['key'] for i in order]
    padded_feats = pad_sequence(sorted_feats,
                                batch_first=True,
                                padding_value=0)
    batch = {
        "keys": sorted_keys,
        "feats": padded_feats,
        "feats_lengths": feats_lengths,
    }
    return batch


def get_dataloader(data_type, files):
    # shard by files
    dataset = WenetSourceDataPipe(files, data_type=data_type)
    if data_type == 'shard':
        dataset = TarFileGroupSourceDataPipe(dataset)
    dataset = dataset.map(decode_wav)
    dataset = dataset.map(compute_fbank)
    dataset = dataset.batch(wrapper_class=padding, batch_size=2)
    dataloader = DataLoader(dataset,
                            batch_size=None,
                            num_workers=4,
                            persistent_workers=True)
    return dataloader


if __name__ == '__main__':
    raw_dataloader = get_dataloader('raw',
                                    ['test/resources/dataset/data.list'])
    tar_dataloader = get_dataloader(
        'shard', ['test/resources/dataset/data.shards.list'])
    print("--------" + "wenet raw data type" + '---------\n')
    for raw_batch in raw_dataloader:
        print(raw_batch)
    print("\n--------" + "wenet shard data type" + '---------\n')
    for shard_batch in tar_dataloader:
        print(shard_batch)

Refactoring ideas after this:
Advantages:
|
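A hedged side note (an assumption, not part of the proposal above): custom pipes can also be
registered for the same chain-call style via functional_datapipe, so a hypothetical
wenet_resample step could be chained just like the built-in .map / .batch.

from torch.utils.data import IterDataPipe, functional_datapipe

@functional_datapipe('wenet_resample')
class ResampleDataPipe(IterDataPipe):
    def __init__(self, dp, resample_rate=16000):
        self.dp = dp
        self.resample_rate = resample_rate

    def __iter__(self):
        for sample in self.dp:
            sample['sample_rate'] = self.resample_rate   # placeholder transform only
            yield sample

# usage: dataset = dataset.wenet_resample(resample_rate=16000)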
@xingchensong Sorry to bother you: roughly how much does the current wenet IO improve
multi-machine multi-GPU efficiency compared with before the refactor? Is there any quantitative
comparison? Thanks. |
#2333 (comment) |