Skip to content

Commit

Permalink
bair data loader and pretrained bair model.
Browse files Browse the repository at this point in the history
  • Loading branch information
edenton committed Mar 3, 2018
1 parent 4bb7869 commit 74c4be1
Show file tree
Hide file tree
Showing 8 changed files with 242 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.pyc
data/
logs
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,18 @@ To train the SVG-LP model on the 2 digit SM-MNIST dataset run:
python train_svg_lp.py --dataset smmnist --num_digits 2 --g_dim 128 --z_dim 10 --beta 0.0001 --data_root /path/to/data/ --log_dir /logs/will/be/saved/here/
```
If the MNIST dataset doesn't exist, it will be downloaded to the specified path.

## BAIR robot push dataset
To download the BAIR robot push dataset run:
```
sh data/download_bair.sh /path/to/data/
```
This will download the dataset in tfrecord format into the specified directory. To train the pytorch models, we first need to convert the tfrecord data into .png images by running:
```
python data/convert_bair.py --data_dir /path/to/data/
```
This may take some time. Images will be saved in ```/path/to/data/processed_data```.
Now we can train the SVG-LP model by running:
```
python train_svg_lp.py --dataset bair --g_dim 128 --z_dim 64 --beta 0.0001 --n_past 2 --n_future 10 --channels 3 --data_root /path/to/data/ --log_dir /logs/will/be/saved/here/
```
61 changes: 61 additions & 0 deletions data/bair.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os
import io
from scipy.misc import imresize
import numpy as np
from PIL import Image
from scipy.misc import imresize
from scipy.misc import imread


class RobotPush(object):

    """Data Handler that loads robot pushing data."""

    def __init__(self, data_root, train=True, seq_len=20, image_size=64):
        """Index the sequence directories of one split.

        Args:
            data_root: directory that contains ``processed_data/``.
            train: if True use ``processed_data/train`` and sample sequences
                at random; otherwise use ``processed_data/test`` and walk
                the sequences in a fixed order.
            seq_len: number of frames read per sequence.
            image_size: stored on the instance; ``get_seq`` itself assumes
                64x64 RGB frames and does not resize.
        """
        self.root_dir = data_root
        if train:
            self.data_dir = '%s/processed_data/train' % self.root_dir
            self.ordered = False
        else:
            self.data_dir = '%s/processed_data/test' % self.root_dir
            self.ordered = True
        # os.listdir() returns entries in arbitrary order, so sort both
        # levels to make the ordered (test) traversal reproducible.
        # The sort is lexicographic on the directory names.
        self.dirs = []
        for d1 in sorted(os.listdir(self.data_dir)):
            for d2 in sorted(os.listdir('%s/%s' % (self.data_dir, d1))):
                self.dirs.append('%s/%s/%s' % (self.data_dir, d1, d2))
        self.seq_len = seq_len
        self.image_size = image_size
        self.seed_is_set = False  # multi threaded loading
        self.d = 0  # index of the next sequence when iterating in order

    def set_seed(self, seed):
        """Seed numpy's global RNG, but only the first time this is called."""
        if not self.seed_is_set:
            self.seed_is_set = True
            np.random.seed(seed)

    def __len__(self):
        """Return the fixed nominal dataset length (10000)."""
        return 10000

    def get_seq(self):
        """Load one sequence of frames.

        Returns:
            Array of shape ``(seq_len, 64, 64, 3)`` with pixel values
            divided by 255.
        """
        if self.ordered:
            d = self.dirs[self.d]
            # Advance and wrap around to the first sequence after the last.
            self.d = (self.d + 1) % len(self.dirs)
        else:
            d = self.dirs[np.random.randint(len(self.dirs))]
        image_seq = []
        for i in range(self.seq_len):
            fname = '%s/%d.png' % (d, i)
            # Frames are expected to be 64x64 RGB; the reshape adds the
            # leading time axis used for concatenation below.
            im = imread(fname).reshape(1, 64, 64, 3)
            image_seq.append(im / 255.)
        return np.concatenate(image_seq, axis=0)

    def __getitem__(self, index):
        """Return a sequence; ``index`` is only used to seed the RNG once."""
        self.set_seed(index)
        return self.get_seq()


60 changes: 60 additions & 0 deletions data/convert_bair.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import os
import io

import numpy as np
from PIL import Image
import tensorflow as tf

from tensorflow.python.platform import flags
from tensorflow.python.platform import gfile

from scipy.misc import imresize
from scipy.misc import imsave

import argparse
# Command-line options. `opt` is read as a module-level global by get_seq()
# and convert_data(); --data_dir is the directory that holds
# softmotion30_44k/ and receives processed_data/.
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', default='', help='base directory to save processed data')
opt = parser.parse_args()

def get_seq(dname):
    """Yield every sequence stored in the tfrecord files of one split.

    Args:
        dname: split sub-directory under ``<data_dir>/softmotion30_44k``.

    Yields:
        ``(f, k, image_seq)`` where ``f`` is the record file path, ``k`` is
        the 1-based position of the example inside that file, and
        ``image_seq`` is an array of shape ``(30, 64, 64, 3)`` with pixel
        values divided by 255.

    Raises:
        RuntimeError: if the split directory contains no files.
    """
    split_dir = '%s/softmotion30_44k/%s' % (opt.data_dir, dname)

    record_files = gfile.Glob(os.path.join(split_dir, '*'))
    if not record_files:
        raise RuntimeError('No data files found.')

    for f in record_files:
        records = tf.python_io.tf_record_iterator(f)
        for k, serialized_example in enumerate(records, 1):
            example = tf.train.Example()
            example.ParseFromString(serialized_example)
            frames = []
            for i in range(30):
                key = str(i) + '/image_aux1/encoded'
                raw = example.features.feature[key].bytes_list.value[0]
                # The feature is decoded as raw 64x64 RGB bytes, not as an
                # encoded image file.
                img = Image.frombytes('RGB', (64, 64), raw)
                width, height = img.size
                pixels = np.array(img.getdata()).reshape(height, width, 3)
                frames.append(pixels.reshape(1, 64, 64, 3) / 255.)
            yield f, k, np.concatenate(frames, axis=0)

def convert_data(dname):
    """Write every sequence of one split to disk as numbered PNG frames.

    Frames are saved as
    ``<data_dir>/processed_data/<dname>/<record name>/<k>/<i>.png`` where
    ``k`` is the example's position in its record file and ``i`` the frame
    index. A progress line is printed after each sequence.

    Args:
        dname: split name passed through to ``get_seq``.
    """
    for n, (f, k, seq) in enumerate(get_seq(dname), 1):
        f = f.split('/')[-1]
        # f[:-10] drops the last 10 characters of the file name
        # (presumably the '.tfrecords' extension).
        # Build the directory once so the frames are written to exactly the
        # directory that was created; previously the save path carried an
        # extra leading '/' and broke for a relative --data_dir.
        out_dir = '%s/processed_data/%s/%s/%d' % (opt.data_dir, dname, f[:-10], k)
        os.makedirs(out_dir, exist_ok=True)
        for i, frame in enumerate(seq):
            imsave('%s/%d.png' % (out_dir, i), frame)

        print('%s data: %s (%d) (%d)' % (dname, f, k, n))

# Convert both splits; each call writes PNG frames under
# <data_dir>/processed_data/<split>/.
convert_data('test')
convert_data('train')
10 changes: 10 additions & 0 deletions data/download_bair.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Download the BAIR robot pushing dataset archive into the directory given
# as the first argument and unpack it there.
# Usage: sh download_bair.sh /path/to/data/
TARGET_DIR=$1
if [ -z "$TARGET_DIR" ]
then
  echo "Must specify target directory"
  exit 1
else
  # -p: do not fail when the directory already exists.
  mkdir -p "$TARGET_DIR"
  URL=http://rail.eecs.berkeley.edu/datasets/bair_robot_pushing_dataset_v0.tar
  # Only unpack when the download succeeded.
  wget "$URL" -P "$TARGET_DIR" &&
    tar -xvf "$TARGET_DIR/bair_robot_pushing_dataset_v0.tar" -C "$TARGET_DIR"
fi
93 changes: 93 additions & 0 deletions data/moving_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import socket
import numpy as np
from torchvision import datasets, transforms

class MovingMNIST(object):

    """Data Handler that creates Bouncing MNIST dataset on the fly."""

    def __init__(self, train, data_root, seq_len=20, num_digits=2, image_size=64, deterministic=True):
        """Load MNIST and store the sequence-generation settings.

        Args:
            train: selects the MNIST train or test split.
            data_root: directory handed to ``datasets.MNIST`` (created with
                ``download=True``).
            seq_len: number of frames per generated sequence.
            num_digits: number of digits drawn into each sequence.
            image_size: side length of the square output frames.
            deterministic: if True a digit's velocity is mirrored when it
                hits a border; otherwise a new velocity is sampled.
        """
        path = data_root
        self.seq_len = seq_len
        self.num_digits = num_digits
        self.image_size = image_size
        self.step_length = 0.1
        self.digit_size = 32
        self.deterministic = deterministic
        self.seed_is_set = False  # multi threaded loading
        self.channels = 1

        # `transforms.Scale` was renamed to `transforms.Resize` and later
        # removed; prefer the new name and fall back for old torchvision.
        resize = getattr(transforms, 'Resize', None) or transforms.Scale
        self.data = datasets.MNIST(
            path,
            train=train,
            download=True,
            transform=transforms.Compose(
                [resize(self.digit_size),
                 transforms.ToTensor()]))

        self.N = len(self.data)

    def set_seed(self, seed):
        """Seed numpy's global RNG, but only the first time this is called."""
        if not self.seed_is_set:
            self.seed_is_set = True
            np.random.seed(seed)

    def __len__(self):
        """Return the number of items in the underlying dataset."""
        return self.N

    def __getitem__(self, index):
        """Generate one sequence of moving digits.

        ``index`` is only used to seed the RNG on the first call; the digits
        themselves are picked at random.

        Returns:
            float32 array of shape
            ``(seq_len, image_size, image_size, channels)`` with values
            clipped to at most 1.
        """
        self.set_seed(index)
        image_size = self.image_size
        digit_size = self.digit_size
        # Exclusive upper bound for a digit's top-left corner.
        limit = image_size - digit_size
        x = np.zeros((self.seq_len,
                      image_size,
                      image_size,
                      self.channels),
                     dtype=np.float32)
        for n in range(self.num_digits):
            idx = np.random.randint(self.N)
            digit, _ = self.data[idx]
            # Convert once per digit instead of once per frame.
            patch = digit.numpy().squeeze()

            # Random start position and per-frame velocity in [-4, 4].
            sx = np.random.randint(limit)
            sy = np.random.randint(limit)
            dx = np.random.randint(-4, 5)
            dy = np.random.randint(-4, 5)
            for t in range(self.seq_len):
                # Vertical borders: clamp, then mirror or resample velocity.
                if sy < 0:
                    sy = 0
                    if self.deterministic:
                        dy = -dy
                    else:
                        dy = np.random.randint(1, 5)
                        dx = np.random.randint(-4, 5)
                elif sy >= limit:
                    sy = limit - 1
                    if self.deterministic:
                        dy = -dy
                    else:
                        dy = np.random.randint(-4, 0)
                        dx = np.random.randint(-4, 5)

                # Horizontal borders, handled the same way.
                if sx < 0:
                    sx = 0
                    if self.deterministic:
                        dx = -dx
                    else:
                        dx = np.random.randint(1, 5)
                        dy = np.random.randint(-4, 5)
                elif sx >= limit:
                    sx = limit - 1
                    if self.deterministic:
                        dx = -dx
                    else:
                        dx = np.random.randint(-4, 0)
                        dy = np.random.randint(-4, 5)

                x[t, sy:sy + digit_size, sx:sx + digit_size, 0] += patch
                sy += dy
                sx += dx

        # Overlapping digits can sum above 1; clip back into range.
        x[x > 1] = 1.
        return x


Binary file added pretrain_models/svglp_bair.pth
Binary file not shown.
2 changes: 2 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,12 @@ def load_dataset(opt):
elif opt.dataset == 'bair':
from data.bair import RobotPush
train_data = RobotPush(
data_root=opt.data_root,
train=True,
seq_len=opt.max_step,
image_size=opt.image_width)
test_data = RobotPush(
data_root=opt.data_root,
train=False,
seq_len=opt.n_eval,
image_size=opt.image_width)
Expand Down

0 comments on commit 74c4be1

Please sign in to comment.