Weights of Restored model after checkpointing do not give same loss as model before saving #1402
Description
Hello,
I have a class that contains an nnx.Module
and trains it. I save and restore the model by accessing this attribute, but, as the title says, when I restore the model its loss is as bad as that of a randomly initialized model.
I cannot describe the problem as anything other than what the title says: I train a model until its loss is half of what it was at initialization, save it following the instructions in the tutorial on saving and loading models (or the instructions given in google/flax#4383, or those on the Orbax website), then restore it in another file and re-run the training loop. At that point the loss is the same as the loss I got at initialization. Note that the restored parameters are not the ones I had at initialization but completely different ones that are equally poor when evaluated on my objective function.
I have attached the code for my model, my training file, and my loading function.
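For reference, here is the round trip I believe I am following, distilled from the Flax checkpointing guide. This is a minimal sketch, not my actual code: MyModule, my_model, and ckpt_dir are placeholders.

import orbax.checkpoint as ocp
from flax import nnx

checkpointer = ocp.StandardCheckpointer()

# Save: split the live module into graphdef and state, checkpoint the state.
_, state = nnx.split(my_model)
checkpointer.save(ckpt_dir / 'state', state)

# Restore: build an abstract module with the *same* hyperparameters,
# restore the checkpoint into its shape structure, then merge.
abstract_model = nnx.eval_shape(lambda: MyModule(din=2, dout=3, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
restored_state = checkpointer.restore(ckpt_dir / 'state', abstract_state)
restored_model = nnx.merge(graphdef, restored_state)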
Model file:
import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from flax import nnx
from tqdm import tqdm

@nnx.jit
def training_step(model, optimizer, key, X, Y, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, maxT):
    batch_size = X.shape[0]
    t_key, key = random.split(key)
    # Antithetic timestep sampling: draw half the batch, mirror it for the other half.
    t = random.randint(t_key, (batch_size // 2 + 1,), 0, maxT)
    t = jnp.concatenate([t, maxT - t - 1], axis=-1)[:batch_size]
    ab_s = jnp.take(alphas_bar_sqrt, t)[..., None]
    am1 = jnp.take(one_minus_alphas_bar_sqrt, t)[..., None]
    noise_key, key = random.split(key)
    noise = random.normal(noise_key, X.shape)
    x_ts = ab_s * X + am1 * noise

    def conditional_model_estimation_loss(model):
        preds = model(x_ts, Y, t[..., None])
        loss = jnp.mean(jnp.square(preds - noise))
        return loss

    loss, grads = nnx.value_and_grad(conditional_model_estimation_loss)(model)
    optimizer.update(grads)
    return loss
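A quick way to confirm that the weights being checkpointed are the trained ones is to check that nnx.jit propagates state updates back to the live model object. A sketch of such a check (model, optimizer, and the schedule arrays are the same objects used in the training loop below):

w_before = model.last_layer.kernel.value.copy()
training_step(model, optimizer, step_key, X, Y, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, maxT)
# The jitted step mutates the model in place, so the kernel should have changed.
assert not jnp.allclose(w_before, model.last_layer.kernel.value)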
def make_beta_schedule(schedule='linear', T=100, start=1e-5, end=0.5e-2):
    if schedule == 'linear':
        betas = jnp.linspace(start, end, T)
    elif schedule == 'cosine':
        # jnp arrays have no .pow() method (that is PyTorch); use ** 2 instead.
        fn = lambda x: jnp.cos(x / T + np.pi / 2) ** 2
        betas = fn(jnp.linspace(0, 1, T))
    elif schedule == "sigmoid":
        betas = jnp.linspace(-6, 6, T)
        betas = jax.nn.sigmoid(betas) * (end - start) + start
    else:
        raise ValueError("schedule not implemented yet")
    return betas

class ConditionalLinear(nnx.Module):
    def __init__(self, num_in, num_out, n_steps, rngs):
        self.lin = nnx.Linear(num_in, num_out, rngs=rngs)
        self.embed = nnx.Embed(n_steps, num_out, rngs=rngs)

    def __call__(self, x, t):
        xout = self.lin(x)
        em = jnp.reshape(self.embed(t), xout.shape)
        # em = jnp.squeeze(self.embed(t), -2)
        out = xout * em
        # print(xout.shape, out.shape, em.shape, x.shape)
        return out

class ConditionalDiffusionModel(nnx.Module):
    def __init__(self, dim, conditioning_dim, hidden_dim, T, rngs):
        self.dim = dim
        self.cond_emb1 = ConditionalLinear(conditioning_dim, hidden_dim, T, rngs)
        self.cond_emb2 = ConditionalLinear(conditioning_dim, hidden_dim, T, rngs)
        self.cond_emb3 = ConditionalLinear(conditioning_dim, hidden_dim, T, rngs)
        self.l1 = ConditionalLinear(dim + hidden_dim, hidden_dim, T, rngs)
        self.l2 = ConditionalLinear(hidden_dim, hidden_dim, T, rngs)
        self.l3 = ConditionalLinear(hidden_dim, hidden_dim, T, rngs)
        self.last_layer = nnx.Linear(hidden_dim, dim, rngs=rngs)

    def __call__(self, x, y, ts):
        yemb = nnx.softplus(self.cond_emb1(y, ts))
        yemb = nnx.softplus(self.cond_emb2(y, ts))
        yemb = nnx.softplus(self.cond_emb3(y, ts))
        xus = jnp.concatenate([x, yemb], axis=-1)
        xus = nnx.softplus(self.l1(xus, ts))
        xus = nnx.softplus(self.l2(xus, ts))
        xus = nnx.softplus(self.l3(xus, ts))
        preds = self.last_layer(xus)
        return preds
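For concreteness, a forward-pass smoke test with dummy shapes (hypothetical dimensions, matching the batch layout used in training_step):

m = ConditionalDiffusionModel(dim=1, conditioning_dim=4, hidden_dim=300, T=252,
                              rngs=nnx.Rngs(0, noise=1))
x = jnp.zeros((8, 1))                    # actions
y = jnp.zeros((8, 4))                    # observation + reward-to-go conditioner
ts = jnp.zeros((8, 1), dtype=jnp.int32)  # per-sample timestep, as t[..., None]
print(m(x, y, ts).shape)                 # expected: (8, 1)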
class ConditionalDiffuser:
    def __init__(self, dim, conditioning_dim, hidden_dim, beta_schedule, T, sigma, rngs):
        self.model = ConditionalDiffusionModel(dim, conditioning_dim, hidden_dim, T, rngs)
        self.betas = jnp.array(make_beta_schedule(beta_schedule, T))
        self.alphas = 1 - self.betas
        self.alpha_bars = jnp.cumprod(self.alphas)
        self.alphas_bar_sqrt = jnp.sqrt(self.alpha_bars)
        self.one_minus_alphas_bar_sqrt = jnp.sqrt(1 - self.alpha_bars)
        self.alpha_bars_p = jnp.concatenate([jnp.array([1]), self.alpha_bars[:-1]])
        self.T = T
        self.sigma = sigma

    def __call__(self, x):
        # Note: unused in practice; self.model expects (x, y, ts).
        return self.model(x)

    def get_model(self):
        return self.model

    def train(self, key, dataset, opt, batch_size, epochs):
        epoch_loss = [0]
        ema_loss = None
        ema = 0.9
        for e in range(epochs):
            if e % 10 == 0:
                print(f"Epoch: {e}\t: epoch loss {epoch_loss[-1]}, ema loss {ema_loss}")
            el = []
            key, perm_key = random.split(key)
            permutation = random.permutation(perm_key, dataset['samples'].shape[0])
            with tqdm(range(0, dataset['samples'].shape[0], batch_size)) as tp:
                for i in tp:
                    indices = permutation[i:i + batch_size]
                    X = dataset['samples'][indices]
                    Y = dataset['conditioners'][indices]
                    step_key, key = random.split(key)
                    loss = training_step(self.model, opt, step_key, X, Y,
                                         self.alphas_bar_sqrt, self.one_minus_alphas_bar_sqrt, self.T)
                    el.append(loss)
                    if ema_loss is None:
                        ema_loss = loss
                    else:
                        ema_loss = ema * ema_loss + (1 - ema) * loss
                    tp.set_postfix(loss=np.mean(el[-40:]))
            el = np.mean(el)
            epoch_loss.append(el)

    def backward(self, x, y, ts, key):
        if type(self.sigma) == list:
            sigmas = np.choose(self.sigma, ts)  # (unused below; sigma_ts is recomputed from betas)
        else:
            sigmas = self.sigma
        a_roots = jnp.sqrt(jnp.take(self.alphas, (ts - 1).astype(jnp.int32))[..., None])
        betas = jnp.take(self.betas, ts.astype(jnp.int32))[..., None]
        one_minus_abar_roots = jnp.sqrt(1 - jnp.take(self.alpha_bars, ts.astype(jnp.int32))[..., None])
        sigma_ts = jnp.sqrt(betas)
        pred_noise = self.model(x, y, ts.astype(jnp.int32))
        n_key, key = random.split(key)
        x_t = (1 / a_roots) * (x - (betas / one_minus_abar_roots) * pred_noise) + sigma_ts * random.normal(n_key, x.shape)
        return x_t

    def complete_backward(self, x, y, T, key):
        for t in range(1, T):
            # xus = torch.concatenate([x, torch.ones(x.shape[0], 1)*(T-t)], dim=-1)
            diff_key, key = random.split(key)
            ts = jnp.ones((x.shape[0],)) * (T - t)
            x = self.backward(x, y, ts, diff_key)
        return x
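Since ConditionalDiffuser is a plain Python class rather than an nnx.Module, I only ever split and checkpoint the inner module, self.model. For clarity, the in-process round trip I would expect to work looks like this (a sketch; diffuser and ckpt_dir are placeholders, and dim, conditioning_dim, hidden_dim, T stand for the training hyperparameters), using nnx.update to write the restored arrays back into the live module:

checkpointer = ocp.StandardCheckpointer()

# Save only the inner nnx.Module's state.
_, state = nnx.split(diffuser.model)
checkpointer.save(ckpt_dir / 'state', state)

# Restore against an abstract copy of the *inner* module,
# then push the arrays back into the live object.
abstract_state = nnx.state(nnx.eval_shape(lambda: ConditionalDiffusionModel(
    dim, conditioning_dim, hidden_dim, T, rngs=nnx.Rngs(0, noise=1))))
restored_state = checkpointer.restore(ckpt_dir / 'state', abstract_state)
nnx.update(diffuser.model, restored_state)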
Training file:
# Imports inferred from usage in this file.
# (ConditionalDiffuser and load_model come from my own modules.)
from pathlib import Path

import gym
import jax
import jax.numpy as jnp
import numpy as np
import optax
import orbax.checkpoint as ocp
import torch
from flax import nnx

def main(args):
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    env = gym.make(args.env)
    action_dim = env.action_space.shape[0]
    state_dim = env.observation_space.shape[0]
    dataset = np.load(args.dataset_path)
    obs = dataset['observations']
    actions = dataset['actions']
    rew_to_go = dataset['reward_to_go']
    print(rew_to_go.shape)
    Y = jnp.array(np.concatenate([obs, rew_to_go], axis=-1))
    X = jnp.array(actions)
    data = {'samples': X, 'conditioners': Y}
    if args.restore:
        print("Loading from Checkpoint")
        ckpt_path = Path('saved_models/my_checkpoints/').resolve()
        diff_model = load_model(ckpt_path)
        model = ConditionalDiffuser(dim=action_dim, conditioning_dim=state_dim + 1, hidden_dim=300,
                                    beta_schedule='linear', T=252, sigma=0.1, rngs=nnx.Rngs(0, noise=1))
        model.model = diff_model
    else:
        print("Building Fresh Model")
        model = ConditionalDiffuser(dim=action_dim, conditioning_dim=state_dim + 1, hidden_dim=300,
                                    beta_schedule='linear', T=252, sigma=0.1, rngs=nnx.Rngs(0, noise=1))
        # model = ConditionalDiffusionModel(action_dim, state_dim+1, 1000, 8, 'linear', 52, 0.05).to(device)
        # model.load_state_dict(torch.load(args.model_out))
    opt = nnx.Optimizer(model.model, optax.adam(args.lr))
    Y = jnp.array(np.concatenate([obs, rew_to_go], axis=-1))
    X = jnp.array(actions)
    data = {'samples': X, 'conditioners': Y}
    train_key = jax.random.key(1)
    model.train(train_key, data, opt, args.batch_size, args.epochs)
    # Prepare state to save
    _, state = nnx.split(model.model)
    print("Checkpointing...")
    # Checkpointing with StandardCheckpointer
    ckpt_path = Path('saved_models/my_checkpoints/').resolve()
    # ckpt_path.mkdir(parents=True, exist_ok=True)  # Ensure directory exists
    checkpointer = ocp.StandardCheckpointer()
    checkpointer.save(ckpt_path / 'attempt_8', state)
    sus_model = load_model(ckpt_path)
    _, restored_state = nnx.split(sus_model)
    # assert_array_equal raises on any mismatch, so this passes iff all leaves match.
    assert jax.tree.map(np.testing.assert_array_equal, restored_state, state)
    # print(other_state)
    print("Done checkpointing")
    return
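To isolate where things go wrong, a useful check is to evaluate a fixed batch with a fixed key before saving and right after restoring: if the restore is faithful, the two losses should be identical. A sketch of such a check (eval_loss is a hypothetical helper mirroring the training objective without the update; restored_module stands for whatever the loader returns as the inner nnx.Module):

def eval_loss(module, key, X, Y, abs_sqrt, om_abs_sqrt, maxT):
    # Same objective as training_step, minus the gradient and optimizer update.
    batch_size = X.shape[0]
    t_key, noise_key = random.split(key)
    t = random.randint(t_key, (batch_size,), 0, maxT)
    noise = random.normal(noise_key, X.shape)
    x_ts = jnp.take(abs_sqrt, t)[..., None] * X + jnp.take(om_abs_sqrt, t)[..., None] * noise
    return jnp.mean(jnp.square(module(x_ts, Y, t[..., None]) - noise))

eval_key = jax.random.key(0)
before = eval_loss(model.model, eval_key, X[:128], Y[:128],
                   model.alphas_bar_sqrt, model.one_minus_alphas_bar_sqrt, model.T)
after = eval_loss(restored_module, eval_key, X[:128], Y[:128],
                  model.alphas_bar_sqrt, model.one_minus_alphas_bar_sqrt, model.T)
print(before, after)  # these should match exactly if the restored weights equal the saved ones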
Load function:
def load_model(checkpoint_dir, env_name='Pendulum-v1'):
    # Initialize environment to get dimensions
    env = gym.make(env_name)
    action_dim = env.action_space.shape[0]
    state_dim = env.observation_space.shape[0]
    # Initialize model with same parameters as during training
    # NOTE: this instance is never used; it is overwritten by nnx.merge below,
    # and its hyperparameters (hidden_dim=1000, T=52, sigma=0.05) differ from
    # the ones passed to nnx.eval_shape.
    model = ConditionalDiffuser(
        dim=action_dim,
        conditioning_dim=state_dim + 1,
        hidden_dim=1000,
        beta_schedule='linear',
        T=52,
        sigma=0.05,
        rngs=nnx.Rngs(0, noise=1)
    )
    abstract_model = nnx.eval_shape(lambda: ConditionalDiffuser(
        dim=action_dim, conditioning_dim=state_dim + 1, hidden_dim=300,
        beta_schedule='linear', T=252, sigma=0.1, rngs=nnx.Rngs(0, noise=1)))
    graphdef, abstract_state = nnx.split(abstract_model)
    checkpointer = ocp.StandardCheckpointer()
    loaded_state = checkpointer.restore(checkpoint_dir / 'attempt_8', abstract_state)
    # jax.tree.map(np.testing.assert_array_equal, abstract_state, loaded_state)
    # print(loaded_state)
    # print(nnx.display(loaded_state))
    model = nnx.merge(graphdef, loaded_state)
    return model
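One asymmetry worth flagging when reading the code above: the checkpoint is saved from nnx.split(model.model), i.e. from the inner ConditionalDiffusionModel, while load_model builds its abstract state from the outer ConditionalDiffuser. A restore that mirrors the save path would target the inner module directly, with hyperparameters identical to training. A sketch of that variant (assuming hidden_dim=300 and T=252, as in the training file):

def load_inner_module(checkpoint_dir, action_dim, state_dim):
    # Abstract copy of the *inner* module, matching the training-time hyperparameters.
    abstract_module = nnx.eval_shape(lambda: ConditionalDiffusionModel(
        dim=action_dim, conditioning_dim=state_dim + 1, hidden_dim=300,
        T=252, rngs=nnx.Rngs(0, noise=1)))
    graphdef, abstract_state = nnx.split(abstract_module)
    checkpointer = ocp.StandardCheckpointer()
    loaded_state = checkpointer.restore(checkpoint_dir / 'attempt_8', abstract_state)
    return nnx.merge(graphdef, loaded_state)

The returned module would then be assigned to the wrapper's .model attribute rather than used as a ConditionalDiffuser itself.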