def __len__(self):
    """Return the extent of ``self.data`` along its leading dimension."""
    leading_dim = self.data.shape[0]
    return leading_dim
def __len__(self):
    """Return how many whole ``seq_len``-sized chunks fit in ``self.data``.

    Any trailing partial chunk is not counted.
    """
    # NOTE(review): if the matching __getitem__ also reads a target shifted by
    # one position, the last counted chunk may come back short — confirm.
    total_rows = self.data.shape[0]
    return total_rows // self.seq_len
# Both streams are built with ``torch.Tensor`` (float32); MusicSamplerDataset
# casts each window to ``long`` when it is fetched. The validation stream is
# the first 1024 * 128 + 1 tokens — presumably the extra token is there so the
# last window still has its one-step-shifted target.
# NOTE(review): ``data_val`` is a prefix of ``train_data1``, i.e. it is drawn
# from the training tokens themselves, so validation loss will not reflect
# held-out performance. Confirm whether a separate split was intended.
data_train = torch.Tensor(train_data1)
data_val = torch.Tensor(train_data1[:1024 * 128 + 1])
class MusicSamplerDataset(Dataset):
    """Non-overlapping next-token ``(input, target)`` windows over a token stream.

    Sample ``i`` is ``data[i*seq_len : (i+1)*seq_len]`` paired with the same
    window shifted one position forward, so ``trg[k]`` is the element that
    follows ``x[k]`` in ``data``. Both are returned as ``long`` tensors.

    Args:
        data: tensor sliced along dim 0 (presumably a 1-D stream of token
            ids — TODO confirm against callers).
        seq_len: window length; must be positive.

    Raises:
        ValueError: if ``seq_len`` is not positive.
    """

    def __init__(self, data, seq_len):
        super().__init__()
        if seq_len <= 0:
            raise ValueError(f"seq_len must be positive, got {seq_len}")
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        """Return ``(x, trg)`` for window ``index``.

        Negative indices count from the end, as for a Python sequence.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
                Raising here (rather than returning short/empty slices) is
                what lets plain ``for sample in dataset`` iteration stop.
        """
        n = len(self)
        if index < 0:
            index += n
        if not 0 <= index < n:
            raise IndexError(f"index out of range for dataset of length {n}")
        start = index * self.seq_len
        x = self.data[start: start + self.seq_len].long()
        # Target is the input window shifted forward by one element.
        trg = self.data[start + 1: start + 1 + self.seq_len].long()
        return x, trg

    def __len__(self):
        """Return the number of complete ``(x, trg)`` windows, never negative."""
        # Window i reads up to position (i + 1) * seq_len (its last target
        # element), so it fits iff (i + 1) * seq_len <= N - 1, giving
        # (N - 1) // seq_len windows. ``N // seq_len - 1`` would drop the final
        # valid window whenever N is not a multiple of seq_len (e.g. a stream
        # of k * seq_len + 1 elements), and goes negative when N < seq_len.
        return max(0, (self.data.size(0) - 1) // self.seq_len)
# Wrap both token streams in the same windowing dataset and batch them
# identically. No ``shuffle`` argument is passed, so DataLoader's default
# (shuffle=False) applies and batches come out in index order.
# NOTE(review): the training loader is not shuffled — confirm that a fixed,
# sequential batch order is intended for training.
train_dataset, val_dataset = (
    MusicSamplerDataset(stream, SEQ_LEN) for stream in (data_train, data_val)
)
train_loader, val_loader = (
    DataLoader(dataset, batch_size=BATCH_SIZE)
    for dataset in (train_dataset, val_dataset)
)