Skip to content

Commit

Permalink
Loading Mnist and cifar
Browse files Browse the repository at this point in the history
  • Loading branch information
snknitin committed Apr 16, 2018
1 parent 12ddeb8 commit 0c1fbea
Showing 1 changed file with 22 additions and 13 deletions.
35 changes: 22 additions & 13 deletions Code/data.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,33 @@
import os
from time import time
import pandas as pd
import numpy as np
from keras.datasets import mnist, cifar10

class Solution(object):
def __init__(self):

self.X_train,self.X_test,self.y_train,self.y_test,self.dataframe_all=None,None,None,None
def load_dataset(hps):
if hps.module == 'mnist':
width, height, channels=hps.train_mnist_dimensions

def visualize(self):
pass
# load MNIST data
(X_train, y_train), (X_test, y_test) = mnist.load_data()

def train_generator(self):
pass
# rescale -1 to 1
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=3)

def train_discriminator(self):
pass
if hps.module == 'cifar10':
# load CIFAR10 data
width, height, channels = hps.train_mnist_dimensions
(X_train, y_train), (X_test, y_test) = cifar10.load_data()



if __name__=="__main__":
s=Solution()
# rescale -1 to 1
X_train = (X_train.astype(np.float32) - 127.5) / 127.5

# defining input dims
img_rows = width
img_cols = height
channels = channels
img_shape = [img_rows, img_cols, channels]

return X_train, img_shape

0 comments on commit 0c1fbea

Please sign in to comment.