
Commit

Minor updates in GEM code
Adds a cross-validation script for GEM.
Arslan Chaudhry committed Jul 12, 2018
1 parent a6f5fdf commit 410d772
Showing 5 changed files with 62 additions and 17 deletions.
10 changes: 5 additions & 5 deletions conv_split_cifar.py
@@ -28,14 +28,14 @@
NUM_RUNS = 5 # Number of experiments to average over
TRAIN_ITERS = 2000 # Number of training iterations per task
BATCH_SIZE = 16
-LEARNING_RATE = 1e-3
+LEARNING_RATE = 0.1
RANDOM_SEED = 1234
VALID_OPTIMS = ['SGD', 'MOMENTUM', 'ADAM']
-OPTIM = 'ADAM'
+OPTIM = 'SGD'
OPT_MOMENTUM = 0.9
OPT_POWER = 0.9
VALID_ARCHS = ['CNN', 'RESNET']
-ARCH = 'CNN'
+ARCH = 'RESNET'

## Model options
MODELS = ['VAN', 'PI', 'EWC', 'MAS', 'GEM', 'RWALK'] #List of valid models
@@ -266,7 +266,7 @@ def train_task_sequence(model, sess, datasets, task_labels, cross_validate_mode,
# Note that the model.train_phase flag is false to avoid updating the batch norm params while doing forward pass on prev tasks
logit_mask[:] = 0
logit_mask[task_labels[prev_task]] = 1.0
-sess.run([model.compute_task_gradients, model.store_task_gradients], feed_dict={model.x: task_based_memory[prev_task]['images'],
+sess.run([model.task_grads, model.store_task_gradients], feed_dict={model.x: task_based_memory[prev_task]['images'],
model.y_: task_based_memory[prev_task]['labels'], model.task_id: prev_task, model.keep_prob: 1.0,
model.output_mask: logit_mask, model.train_phase: False})

@@ -275,7 +275,7 @@ def train_task_sequence(model, sess, datasets, task_labels, cross_validate_mode,
logit_mask[task_labels[task]] = 1.0
feed_dict[model.output_mask] = logit_mask
feed_dict[model.task_id] = task
-_, _,loss = sess.run([model.compute_task_gradients, model.store_task_gradients, model.reg_loss], feed_dict=feed_dict)
+_, _,loss = sess.run([model.task_grads, model.store_task_gradients, model.reg_loss], feed_dict=feed_dict)
# Store the gradients
sess.run([model.gem_gradient_update, model.store_grads], feed_dict={model.task_id: task})
# Apply the gradients
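
The renamed ops above are easiest to read as one per-batch GEM sequence: build a row of G from each previous task's episodic memory, add the current mini-batch gradient, project it if it conflicts with any stored gradient, then apply it. The following self-contained NumPy toy illustrates that sequence under stated simplifications (a linear least-squares model, random stand-in data, and a single-constraint correction instead of the full project2cone2 QP); it is a sketch for intuition, not this repository's code.

import numpy as np

rng = np.random.default_rng(0)
dim, lr = 5, 0.1
W = rng.normal(size=dim)  # toy "trainable variables"

def grad_on(x, y, W):
    # Mean-squared-error gradient of a linear predictor; stands in for
    # tf.gradients(reg_loss, trainable_vars).
    return 2.0 * x.T @ (x @ W - y) / len(y)

# Stand-in for task_based_memory: episodic samples of two earlier tasks.
memories = [(rng.normal(size=(10, dim)), rng.normal(size=10)) for _ in range(2)]
cur_x, cur_y = rng.normal(size=(16, dim)), rng.normal(size=16)

# 1) task_grads + store_task_gradients on each previous task's memory:
#    one flattened gradient per row of G.
G = np.stack([grad_on(mx, my, W) for mx, my in memories])

# 2) The same two ops on the current mini-batch.
g = grad_on(cur_x, cur_y, W)

# 3) gem_gradient_update: keep g if it does not conflict with any row of G;
#    otherwise project. The repo solves a QP over all constraints
#    (project2cone2); this toy corrects only the single worst violation.
dots = G @ g
if np.any(dots < 0):
    k = int(np.argmin(dots))
    g = g - (G[k] @ g) / (G[k] @ G[k]) * G[k]

# 4) store_grads + optimizer step: apply the (possibly projected) gradient.
W -= lr * g
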
43 changes: 43 additions & 0 deletions cross_validation_scripts/cross_validate_lr_gem.sh
@@ -0,0 +1,43 @@
#! /bin/bash
# Usage: ./cross_validate_lr_gem.sh <EXPERIMENT_NAME>, where <EXPERIMENT_NAME> could be 'SPLIT_MNIST', 'PERMUTE_MNIST' or 'SPLIT_CIFAR'
set -e

EXP=$1
IMP_METHOD='GEM'
SYNAP_STGTH=(0)
BATCH_SIZE=10
LOG_DIR='../cross_validation_results_20tasks'
OPTIM='SGD'
ARCH='RESNET'
if [ $EXP = "SPLIT_MNIST" ]; then
LR=(0.0001 0.001 0.003 0.01 0.03 0.1)
for lamda in ${SYNAP_STGTH[@]}
do
for lr in ${LR[@]}
do
python ../fc_split_mnist.py --cross-validate-mode --train-single-epoch --num-runs 5 --batch-size $BATCH_SIZE --learning-rate $lr --imp-method $IMP_METHOD --synap-stgth $lamda --log-dir $LOG_DIR
done
done
elif [ $EXP = "PERMUTE_MNIST" ]; then
LR=(0.0001 0.0003 0.001 0.003 0.01 0.03 0.1 0.3 1.0)
for lamda in ${SYNAP_STGTH[@]}
do
for lr in ${LR[@]}
do
python ../fc_permute_mnist.py --cross-validate-mode --train-single-epoch --num-runs 5 --batch-size $BATCH_SIZE --learning-rate $lr --imp-method $IMP_METHOD --synap-stgth $lamda --log-dir $LOG_DIR
done
done
elif [ $EXP = "SPLIT_CIFAR" ]; then
LR=(0.0001 0.0003 0.001 0.003 0.01 0.03 0.1 0.3 1.0)
for lamda in ${SYNAP_STGTH[@]}
do
for lr in ${LR[@]}
do
python ../conv_split_cifar.py --cross-validate-mode --train-single-epoch --arch $ARCH --num-runs 3 --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method $IMP_METHOD --synap-stgth $lamda --log-dir $LOG_DIR
done
done
else
echo "ERROR! Wrong Experiment Name!!"
exit 1
fi

9 changes: 5 additions & 4 deletions fc_permute_mnist.py
@@ -40,7 +40,7 @@
SYNAP_STGTH = 75000
FISHER_EMA_DECAY = 0.9 # Exponential moving average decay factor for Fisher computation (online Fisher)
FISHER_UPDATE_AFTER = 10 # Number of training iterations for which the F_{\theta}^t is computed (see Eq. 10 in RWalk paper)
-MEMORY_SIZE_PER_TASK = 10 # Number of samples per task
+MEMORY_SIZE_PER_TASK = 25 # Number of samples per task
INPUT_FEATURE_SIZE = 784
TOTAL_CLASSES = 10 # Total number of classes in the dataset

@@ -163,7 +163,8 @@ def train_task_sequence(model, sess, datasets, cross_validate_mode, train_single

# Train a task observing sequence of data
if train_single_epoch:
-num_iters = num_train_examples // batch_size
+num_iters = 20
+#num_iters = num_train_examples // batch_size
else:
num_iters = train_iters

@@ -223,13 +224,13 @@ def train_task_sequence(model, sess, datasets, cross_validate_mode, train_single
for prev_task in range(task):
# T-th task gradients.
# Note that the model.train_phase flag is false to avoid updating the batch norm params while doing forward pass on prev tasks
-sess.run([model.compute_task_gradients, model.store_task_gradients], feed_dict={model.x: task_based_memory[prev_task]['images'],
+sess.run([model.task_grads, model.store_task_gradients], feed_dict={model.x: task_based_memory[prev_task]['images'],
model.y_: task_based_memory[prev_task]['labels'], model.task_id: prev_task, model.keep_prob: 1.0,
model.output_mask: logit_mask, model.train_phase: False})

# Compute the gradient on the mini-batch of the current task
feed_dict[model.task_id] = task
-_, _,loss = sess.run([model.compute_task_gradients, model.store_task_gradients, model.reg_loss], feed_dict=feed_dict)
+_, _,loss = sess.run([model.task_grads, model.store_task_gradients, model.reg_loss], feed_dict=feed_dict)
# Store the gradients
sess.run([model.gem_gradient_update, model.store_grads], feed_dict={model.task_id: task})
# Apply the gradients
4 changes: 2 additions & 2 deletions fc_split_mnist.py
@@ -252,7 +252,7 @@ def train_task_sequence(model, sess, datasets, task_labels, cross_validate_mode,
# Note that the model.train_phase flag is false to avoid updating the batch norm params while doing forward pass on prev tasks
logit_mask[:] = 0
logit_mask[task_labels[prev_task]] = 1.0
-sess.run([model.compute_task_gradients, model.store_task_gradients], feed_dict={model.x: task_based_memory[prev_task]['images'],
+sess.run([model.task_grads, model.store_task_gradients], feed_dict={model.x: task_based_memory[prev_task]['images'],
model.y_: task_based_memory[prev_task]['labels'], model.task_id: prev_task, model.keep_prob: 1.0,
model.output_mask: logit_mask, model.train_phase: False})

@@ -261,7 +261,7 @@ def train_task_sequence(model, sess, datasets, task_labels, cross_validate_mode,
logit_mask[task_labels[task]] = 1.0
feed_dict[model.output_mask] = logit_mask
feed_dict[model.task_id] = task
-_, _,loss = sess.run([model.compute_task_gradients, model.store_task_gradients, model.reg_loss], feed_dict=feed_dict)
+_, _,loss = sess.run([model.task_grads, model.store_task_gradients, model.reg_loss], feed_dict=feed_dict)
# Store the gradients
sess.run([model.gem_gradient_update, model.store_grads], feed_dict={model.task_id: task})
# Apply the gradients
13 changes: 7 additions & 6 deletions model/model.py
@@ -620,17 +620,18 @@ def create_gem_ops(self):
"""
Define operations for Gradients Episodic Memory (GEM) method
"""
-self.compute_task_gradients = tf.group([tf.assign(v, grad) for v, grad in zip(self.gem_reg_grads, tf.gradients(self.reg_loss, self.trainable_vars))])
-with tf.control_dependencies([self.compute_task_gradients]):
-flattened_grads = tf.concat([tf.reshape(v, [-1]) for v in self.gem_reg_grads], 0)
-self.store_task_gradients = tf.assign(self.G[self.task_id], flattened_grads)
+# Compute the gradients for a given task id
+self.task_grads = tf.gradients(self.reg_loss, self.trainable_vars)
+
+with tf.control_dependencies(self.task_grads):
+self.store_task_gradients = tf.assign(self.G[self.task_id], tf.concat([tf.reshape(grad, [-1]) for grad in self.task_grads], 0))

def projectgradients_tfn():
return tf.py_func(project2cone2, [self.G[self.task_id], self.G[:self.task_id]], [tf.float32])

# Check if any of the constraints in GEM is violated. If yes, then solve the QP
-self.gem_gradient_update = tf.cond(tf.cast(tf.reduce_sum(tf.cast(tf.less(tf.matmul(tf.expand_dims(self.G[self.task_id], axis=0),
-tf.transpose(self.G)), 0), tf.int32)) != 0, tf.bool), projectgradients_tfn, lambda: tf.identity(self.G[self.task_id]))
+self.gem_gradient_update = tf.cond(tf.equal(tf.reduce_sum(tf.cast(tf.less(tf.matmul(tf.expand_dims(self.G[self.task_id], axis=0),
+tf.transpose(self.G)), 0), tf.int32)), 0), lambda: tf.identity(self.G[self.task_id]), projectgradients_tfn)

# Define ops to store the gradients
with tf.control_dependencies([self.gem_gradient_update]):
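
For the projection branch itself, here is a minimal sketch of the quadratic program that project2cone2 is expected to solve, following the dual formulation from the GEM paper. The function name project_to_cone, the margin and eps arguments, and the dependency on the quadprog package are assumptions made for illustration; the repository's own project2cone2 (called through tf.py_func above) may differ in details.

import numpy as np
import quadprog  # assumed dependency, as in the GEM reference formulation


def project_to_cone(cur_grad, prev_grads, margin=0.0, eps=1e-3):
    """Return a gradient close to cur_grad with prev_grads @ result >= 0.

    cur_grad:   flattened current-task gradient, shape [p]      (G[task_id])
    prev_grads: flattened previous-task gradients, shape [t, p] (G[:task_id])
    """
    dots = prev_grads @ cur_grad
    if np.all(dots >= 0):
        # No constraint violated: the tf.cond above keeps G[task_id] unchanged.
        return cur_grad

    # Dual QP: min_v 1/2 v^T P v + (prev_grads @ cur_grad)^T v subject to
    # v >= margin, with P = prev_grads prev_grads^T; the primal solution is
    # cur_grad + prev_grads^T v*.
    t = prev_grads.shape[0]
    P = prev_grads @ prev_grads.T
    P = 0.5 * (P + P.T) + eps * np.eye(t)  # symmetrize and regularize
    q = -(prev_grads @ cur_grad)           # quadprog minimizes 1/2 v^T P v - q^T v
    C = np.eye(t)                          # constraints C^T v >= b, i.e. v >= margin
    b = np.zeros(t) + margin
    v = quadprog.solve_qp(P, q, C, b)[0]
    return cur_grad + prev_grads.T @ v

With margin=0.0 the returned gradient satisfies the GEM constraints exactly; a small positive margin pushes it strictly into the feasible cone, which some GEM implementations use for numerical robustness.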
