Skip to content

Commit

Permalink
add cifar to datasets.py (tinygrad#6210)
Browse files: browse the repository at this point in the history
  • Loading branch information
geohot authored Aug 20, 2024
1 parent a5d7968 commit d9c62a3
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
11 changes: 5 additions & 6 deletions examples/hlb_cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
import random, time
import numpy as np
from typing import Optional
from extra.datasets import fetch_cifar, cifar_mean, cifar_std
from extra.lr_scheduler import OneCycleLR
from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit
from tinygrad.nn.state import get_state_dict, get_parameters
from tinygrad.nn import optim
from tinygrad.helpers import Context, BEAM, WINO, getenv, colored, prod
from tinygrad.multi import MultiLazyBuffer

# Three-element normalization constants: the input pipeline scales pixels by
# 1/255, reshapes to (-1,3,32,32), then subtracts cifar_mean and divides by
# cifar_std, each broadcast as shape (1,3,1,1) (one value per channel).
cifar_mean = [0.4913997551666284, 0.48215855929893703, 0.4465309133731618]
cifar_std = [0.24703225141799082, 0.24348516474564, 0.26158783926049628]

# Run configuration, each read through getenv with the default shown.
BS, STEPS = getenv("BS", 512), getenv("STEPS", 1000)
# Evaluation batch size falls back to the training batch size.
EVAL_BS = getenv("EVAL_BS", BS)
# Device strings of the form "<default device>:<index>", one per requested GPU.
GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 1))]
Expand Down Expand Up @@ -252,7 +254,7 @@ def fetch_batches(X_in:Tensor, Y_in:Tensor, BS:int, is_train:bool):
if not is_train: break

transform = [
lambda x: x / 255.0,
lambda x: x.float() / 255.0,
lambda x: x.reshape((-1,3,32,32)) - Tensor(cifar_mean, device=x.device, dtype=x.dtype).reshape((1,3,1,1)),
lambda x: x / Tensor(cifar_std, device=x.device, dtype=x.dtype).reshape((1,3,1,1)),
]
Expand All @@ -277,10 +279,7 @@ def update(self, net, decay):

set_seed(getenv('SEED', hyp['seed']))

X_train, Y_train, X_test, Y_test = fetch_cifar()
# load data and label into GPU and convert to dtype accordingly
X_train, X_test = X_train.to(device=Device.DEFAULT).float(), X_test.to(device=Device.DEFAULT).float()
Y_train, Y_test = Y_train.to(device=Device.DEFAULT), Y_test.to(device=Device.DEFAULT)
X_train, Y_train, X_test, Y_test = nn.datasets.cifar()
# one-hot encode labels
Y_train, Y_test = Y_train.one_hot(10), Y_test.one_hot(10)
# preprocess data
Expand Down
16 changes: 11 additions & 5 deletions tinygrad/nn/datasets.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import gzip
from tinygrad.tensor import Tensor
from tinygrad.helpers import fetch
from tinygrad.nn.state import tar_extract

def _fetch_mnist(file, offset):
  """Fetch one gzipped MNIST file and return its decompressed bytes as a Tensor.

  Args:
    file: File name appended to the MNIST storage base URL.
    offset: Number of leading bytes to drop from the decompressed data
      (presumably the file header -- confirm against the IDX format).

  Returns:
    A Tensor built from the decompressed bytes starting at `offset`.
  """
  # Open through a context manager so the gzip handle is closed promptly
  # instead of being left for the garbage collector.
  with gzip.open(fetch("https://storage.googleapis.com/cvdf-datasets/mnist/"+file)) as gz:
    payload = gz.read()
  return Tensor(payload[offset:])
def mnist():
  """Return MNIST as (train images, train labels, test images, test labels).

  Image tensors are reshaped to (-1, 1, 28, 28); label tensors are left flat.
  Image files skip their first 0x10 bytes and label files their first 8 bytes.
  """
  train_images = _fetch_mnist("train-images-idx3-ubyte.gz", 0x10).reshape(-1, 1, 28, 28)
  train_labels = _fetch_mnist("train-labels-idx1-ubyte.gz", 8)
  test_images = _fetch_mnist("t10k-images-idx3-ubyte.gz", 0x10).reshape(-1, 1, 28, 28)
  test_labels = _fetch_mnist("t10k-labels-idx1-ubyte.gz", 8)
  return train_images, train_labels, test_images, test_labels
def _mnist(file):
  """Fetch one MNIST file (with gunzip=True) and wrap the result in a Tensor.

  `file` is the file name appended to the MNIST storage base URL.
  """
  url = "https://storage.googleapis.com/cvdf-datasets/mnist/"+file
  fetched = fetch(url, gunzip=True)
  return Tensor(fetched)
def mnist(device=None):
  """Return MNIST as (train images, train labels, test images, test labels).

  Every tensor is moved with `.to(device)`. Image tensors drop their first
  0x10 bytes and are reshaped to (-1, 1, 28, 28); label tensors drop their
  first 8 bytes and stay flat.
  """
  train_images = _mnist("train-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device)
  train_labels = _mnist("train-labels-idx1-ubyte.gz")[8:].to(device)
  test_images = _mnist("t10k-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device)
  test_labels = _mnist("t10k-labels-idx1-ubyte.gz")[8:].to(device)
  return train_images, train_labels, test_images, test_labels

def cifar(device=None):
  """Return CIFAR-10 as (train images, train labels, test images, test labels).

  Each batch file is viewed as rows of 3073 values: column 0 is returned as
  the label and the remaining 3072 columns are reshaped to (3, 32, 32).
  The five training batches are concatenated in order 1..5. Every batch is
  moved with `.to(device)` before concatenation and slicing.
  """
  archive = tar_extract(fetch('https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz', gunzip=True))

  train_batches = []
  for batch_no in range(1, 6):
    rows = archive[f"cifar-10-batches-bin/data_batch_{batch_no}.bin"].reshape(-1, 3073)
    train_batches.append(rows.to(device))
  train = Tensor.cat(*train_batches)
  test = archive["cifar-10-batches-bin/test_batch.bin"].reshape(-1, 3073).to(device)

  train_images, train_labels = train[:, 1:].reshape(-1,3,32,32), train[:, 0]
  test_images, test_labels = test[:, 1:].reshape(-1,3,32,32), test[:, 0]
  return train_images, train_labels, test_images, test_labels

0 comments on commit d9c62a3

Please sign in to comment.