Skip to content

Commit

Permalink
Change H5Py dataset implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
wilson1yan committed Jun 24, 2021
1 parent 426009a commit a693f20
Showing 1 changed file with 10 additions and 16 deletions.
26 changes: 10 additions & 16 deletions videogpt/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import glob
import h5py
import numpy as np

import torch
import torch.utils.data as data
Expand Down Expand Up @@ -132,24 +133,12 @@ def __init__(self, data_file, sequence_length, train=True, resolution=64):
self.prefix = 'train' if train else 'test'
self._images = self.data[f'{self.prefix}_data']
self._idx = self.data[f'{self.prefix}_idx']

# compute splits for all possible sequences
self._splits = self._compute_seq_splits()
self.size = len(self._idx)

@property
def n_classes(self):
    """Always raise: class conditioning is unavailable for HDF5Dataset.

    Raises:
        NotImplementedError: on every access. It is a subclass of
            ``Exception``, so callers that catch ``Exception`` still work.
    """
    raise NotImplementedError('class conditioning not support for HDF5Dataset')

def _compute_seq_splits(self):
    """Enumerate every fixed-length frame window that fits inside one video.

    ``self._idx[i]`` is treated as the position of video ``i``'s first frame
    in ``self._images``. A video ends where the next one starts; the last
    video ends at the end of ``self._images``.

    Returns:
        list[tuple[int, int]]: half-open ``(start, end)`` frame ranges, each
        exactly ``self.sequence_length`` long and contained in one video.
        Videos shorter than ``self.sequence_length`` contribute nothing.
    """
    seq_len = self.sequence_length
    n_videos = len(self._idx)
    n_frames = len(self._images)

    splits = []
    for video in range(n_videos):
        first = int(self._idx[video])
        # The last video has no successor entry in _idx, so it is bounded by
        # the total frame count (not by the number of videos).
        if video < n_videos - 1:
            last = int(self._idx[video + 1])
        else:
            last = n_frames
        # +1 keeps the window that ends exactly on the video's final frame;
        # a negative count simply yields an empty range.
        splits.extend(
            (first + offset, first + offset + seq_len)
            for offset in range(last - first - seq_len + 1)
        )
    return splits

def __getstate__(self):
state = self.__dict__
state['data'] = None
Expand All @@ -164,11 +153,16 @@ def __setstate__(self, state):
self._idx = self.data[f'{self.prefix}_idx']

def __len__(self):
    """Return the number of videos (one dataset item per video)."""
    return self.size

def __getitem__(self, idx):
    """Return a randomly positioned clip from video ``idx``.

    The video's frame range is read from ``self._idx``: it starts at
    ``self._idx[idx]`` and ends at the next entry, or at the end of
    ``self._images`` for the last video. A window of
    ``self.sequence_length`` frames is drawn uniformly from that range.

    Args:
        idx: index of the video to sample from.

    Returns:
        dict with a single key ``'video'`` holding the result of
        ``preprocess(clip, self.resolution)``.

    Raises:
        ValueError: if the video has fewer than ``self.sequence_length``
            frames.
    """
    start = int(self._idx[idx])
    if idx < len(self._idx) - 1:
        end = int(self._idx[idx + 1])
    else:
        end = len(self._images)

    n_frames = end - start
    if n_frames < self.sequence_length:
        # Explicit raise instead of assert: asserts vanish under `python -O`.
        raise ValueError(
            f'video {idx} has {n_frames} frames, fewer than '
            f'sequence_length={self.sequence_length}'
        )

    # randint's `high` is exclusive, so +1 makes the last window reachable
    # and lets a video of exactly sequence_length frames be sampled.
    start += int(np.random.randint(low=0, high=n_frames - self.sequence_length + 1))
    video = torch.tensor(self._images[start:start + self.sequence_length])
    return dict(video=preprocess(video, self.resolution))


Expand Down

0 comments on commit a693f20

Please sign in to comment.