Hello, I am trying to train your VAE for a project of my own, and I noticed that the training script has no validation step. Is there an easy way to add validation to the training? Could you show me how to do that using your DataLoader? Thanks in advance.
If you import `train_test_split` from `sklearn.model_selection`, you can do something like this:
```python
import os

import numpy as np
from sklearn.model_selection import train_test_split

# CHANGED JPG TO PNG (use whatever extension your images have)
images = [im for im in os.listdir(args.folder) if im.endswith('.png')]
images = np.array(images)
n_samples = len(images)

if args.n_samples > 0:
    n_samples = min(n_samples, args.n_samples)

# indices for all time steps where the episode continues
indices = np.arange(n_samples, dtype='int64')
np.random.shuffle(indices)

# NEW SECTION: split the indices into a train and a validation set BEFORE batching.
# train_test_split accepts a NumPy array directly, so no pandas round-trip is needed.
train, val = train_test_split(indices, train_size=0.8)

print("{} images in total".format(n_samples))
print("{} images in training set".format(len(train)))
print("{} images in validation set".format(len(val)))

# split indices into minibatches. minibatchlist is a list of lists; each
# list is the id of the observation preserved through the training.
# Any trailing partial batch is dropped, so keep len(val) >= args.batch_size.
train_minibatchlist = [np.array(sorted(train[start_idx:start_idx + args.batch_size]))
                       for start_idx in range(0, len(train) - args.batch_size + 1, args.batch_size)]
val_minibatchlist = [np.array(sorted(val[start_idx:start_idx + args.batch_size]))
                     for start_idx in range(0, len(val) - args.batch_size + 1, args.batch_size)]

# DataLoader is the repo's own loader class, already imported in train.py
train_data_loader = DataLoader(train_minibatchlist, images, n_workers=2, folder=args.folder)
val_data_loader = DataLoader(val_minibatchlist, images, n_workers=2, folder=args.folder)
```
This goes in train.py; there is no need to edit the DataLoader class.
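Building `val_data_loader` is only half of the job; you also need to evaluate on it each epoch. The sketch below is plain NumPy so it can be run on its own. `split_indices`, `make_minibatches` and `validation_loss` are hypothetical helpers (not part of the repo), and `loss_fn` is a placeholder for whatever reconstruction + KL loss your VAE computes. It assumes, like the code above, that each minibatch is a sorted array of image indices.

```python
import numpy as np


def split_indices(n_samples, train_size=0.8, seed=None):
    """Shuffle 0..n_samples-1 and split into train/val index arrays (no sklearn needed)."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(n_samples)
    n_train = int(round(train_size * n_samples))
    return indices[:n_train], indices[n_train:]


def make_minibatches(indices, batch_size):
    """Group indices into sorted minibatches, dropping a trailing partial batch."""
    return [np.sort(indices[start:start + batch_size])
            for start in range(0, len(indices) - batch_size + 1, batch_size)]


def validation_loss(minibatches, loss_fn):
    """Average loss_fn over all validation minibatches, with no gradient update.

    loss_fn is a placeholder: it takes one minibatch of indices and returns a float.
    """
    if not minibatches:
        raise ValueError("validation set is smaller than one batch; "
                         "lower batch_size or enlarge the validation split")
    return float(np.mean([loss_fn(batch) for batch in minibatches]))


def dummy_loss(batch):
    """Stand-in for the VAE loss: just the mean index in the batch."""
    return float(batch.mean())


if __name__ == "__main__":
    train, val = split_indices(100, train_size=0.8, seed=0)
    train_batches = make_minibatches(train, batch_size=16)
    val_batches = make_minibatches(val, batch_size=16)
    print(len(train), len(val))                 # 80 20
    print(len(train_batches), len(val_batches))  # 5 1

    for epoch in range(2):
        # ... the existing training step over train_batches would go here ...
        print("epoch", epoch, "val loss", validation_loss(val_batches, dummy_loss))
```

In train.py the equivalent would be: after each training epoch, pull `len(val_minibatchlist)` batches from `val_data_loader` the same way the script already pulls batches from the training loader, compute the loss for each batch without running the optimizer step, and log the mean next to the training loss. Check how your DataLoader behaves when iterated over several epochs before relying on this.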