Skip to content

Commit

Permalink
[dataset ] support interleave dataset (wenet-e2e#2610)
Browse files Browse the repository at this point in the history
* addd interleave dataset

* add ut
  • Loading branch information
Mddct authored Aug 14, 2024
1 parent 0dac0b5 commit 8859bf2
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 3 deletions.
26 changes: 24 additions & 2 deletions test/wenet/dataset/test_datapipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from torch.utils.data.datapipes.iter import IterableWrapper
from functools import partial

from wenet.dataset.datapipes import (RepeatDatapipe, SortDataPipe,
WenetRawDatasetSource,
from wenet.dataset.datapipes import (InterlaveDataPipe, RepeatDatapipe,
SortDataPipe, WenetRawDatasetSource,
WenetTarShardDatasetSource)
from wenet.dataset.processor import (DynamicBatchWindow, decode_wav, padding,
parse_json, compute_fbank,
Expand Down Expand Up @@ -224,3 +224,25 @@ def test_repeat():

assert len(result) == len(expected)
all(h == r for h, r in zip(result, expected))


def test_interleave():
dataset_1 = IterableWrapper(range(10))
dataset_2 = IterableWrapper(range(20, 30, 2))

dataset = InterlaveDataPipe([dataset_1, dataset_2])
dataset = dataset.batch(4)
generator = torch.Generator()
generator.manual_seed(100)
dataloader = torch.utils.data.DataLoader(dataset,
batch_size=None,
num_workers=0,
generator=generator,
persistent_workers=False)
expected = [[0, 1, 2, 3], [4, 20, 5, 22], [24, 6, 7, 8], [26, 9, 28]]

result = []
for batch in dataloader:
result.append(batch)

assert expected == result
44 changes: 43 additions & 1 deletion wenet/dataset/datapipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import sys
import tarfile
import logging
from typing import List
from typing import List, Optional
import numpy as np
import torch
from torch.utils.data import IterDataPipe, functional_datapipe
from torch.utils.data import datapipes
Expand Down Expand Up @@ -302,6 +303,47 @@ def apply_sharding(self, num_of_instances: int, instance_id: int,
self.instance_id = info.id


@functional_datapipe("interleave")
class InterlaveDataPipe(IterDataPipe):

def __init__(
self,
source_datapipes: List[IterDataPipe],
weights: Optional[List[float]] = None,
seed=2027,
):
super().__init__()
self.rng = np.random.default_rng(seed)
self.source_datapipes = source_datapipes
self.weights = weights
if weights is None:
self.weights = [1 / len(self.source_datapipes)] * len(
self.source_datapipes)
else:
self.weights = [weight / sum(weights) for weight in weights]
self.iters = None

def __iter__(self):
weights = copy.deepcopy(self.weights)
exhausted = len(self.source_datapipes) * [False]
if self.iters is None:
self.iters = [(i, iter(d))
for i, d in enumerate(self.source_datapipes)]
while True:
# TODO(Mddct): rng
index_iter = self.rng.choice(self.iters, p=weights)
i, ite = index_iter
try:
elem = next(ite)
yield elem
except StopIteration:
weights[i] = 0.
exhausted[i] = True
if all(exhausted):
return
weights = [weight / sum(weights) for weight in weights]


class TextLineDataPipe(IterDataPipe):
""" Streamming Text line
"""
Expand Down

0 comments on commit 8859bf2

Please sign in to comment.