Weights of Restored model after checkpointing do not give same loss as model before saving #1402

Open
@delara38

Description

Hello,

I have a class that contains an nnx.Module and trains it. I save and restore the model by accessing this attribute, but as the title says, when I restore the model its loss is as bad as that of a randomly initialized model.

I have no way to describe the problem as anything other than what the title says: I train a model, halve the loss from its initialization, save the model using the instructions in the tutorial on saving and loading models (or the instructions given here google/flax#4383, or the instructions on the Orbax website), and then restore it in another file and re-run the training loop. However, at that point my loss is the same as the loss I got at initialization. Note that the parameters are not the ones I had at initialization but completely different ones that are equally poor when evaluated on my objective function.

I have attached the code for my model, my training file, and my loading function. For reference, the minimal save/restore round trip I believe I am following is sketched below.
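A minimal sketch of that round trip, assuming a toy nnx.Linear stand-in (Model, the /tmp path, and the input x are placeholders, not my real code):

import numpy as np
import orbax.checkpoint as ocp
from flax import nnx

class Model(nnx.Module):
    def __init__(self, rngs):
        self.lin = nnx.Linear(4, 4, rngs=rngs)
    def __call__(self, x):
        return self.lin(x)

model = Model(nnx.Rngs(0))
_, state = nnx.split(model)

# Save the state to a fresh directory.
ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/ckpt')
checkpointer = ocp.StandardCheckpointer()
checkpointer.save(ckpt_dir / 'state', state)

# "Restore" as if in a fresh process: rebuild an abstract model with the
# same structure, restore into its abstract state, then merge.
abs_model = nnx.eval_shape(lambda: Model(nnx.Rngs(0)))
graphdef, abs_state = nnx.split(abs_model)
restored_state = checkpointer.restore(ckpt_dir / 'state', abs_state)
restored = nnx.merge(graphdef, restored_state)

# The restored model should be functionally identical to the saved one.
x = np.ones((1, 4), dtype=np.float32)
np.testing.assert_allclose(model(x), restored(x), rtol=1e-6)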

Model file:

import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx
from jax import random
from tqdm import tqdm

@nnx.jit
def training_step(model, optimizer, key, X, Y, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, maxT):
    batch_size = X.shape[0]
    t_key, key = random.split(key)
    t = random.randint(t_key, (batch_size//2+1,), 0, maxT)  # shape must be a tuple

    t = jnp.concatenate([t, maxT-t-1], axis=-1)[:batch_size]

    ab_s = jnp.take(alphas_bar_sqrt, t)[..., None]
    am1 = jnp.take(one_minus_alphas_bar_sqrt, t)[..., None]

    noise_key, key = random.split(key)
    noise = random.normal(noise_key, X.shape)

    x_ts = ab_s*X + am1*noise

    def conditional_model_estimation_loss(model):
        preds = model(x_ts, Y, t[..., None])
        loss = jnp.mean(jnp.square(preds - noise))
        return loss

    loss, grads = nnx.value_and_grad(conditional_model_estimation_loss)(model)

    optimizer.update(grads)
    return loss

def make_beta_schedule(schedule='linear', T=100, start=1e-5, end=0.5e-2):
    if schedule == 'linear':
        betas = jnp.linspace(start, end, T)
    elif schedule == 'cosine':
        # jnp arrays have no .pow() method; use ** instead
        fn = lambda x: jnp.cos(x/T + np.pi/2)**2
        betas = fn(jnp.linspace(0, 1, T))
    elif schedule == 'sigmoid':
        betas = jnp.linspace(-6, 6, T)
        betas = jax.nn.sigmoid(betas) * (end - start) + start
    else:
        raise ValueError("schedule not implemented yet")
    return betas
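As a quick sanity check on the schedule (a sketch; the derived quantities mirror what ConditionalDiffuser computes in its constructor), alpha-bar should start near 1 and decay, so x_t is nearly clean at t=0 and close to pure noise at t=T-1:

# Sanity check: alpha_bar decays from ~1 toward 0 over the T steps.
betas = make_beta_schedule('linear', T=252)
alpha_bars = jnp.cumprod(1 - betas)
assert betas.shape == (252,)
assert float(alpha_bars[0]) > 0.99
assert float(alpha_bars[-1]) < float(alpha_bars[0])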

class ConditionalLinear(nnx.Module):
    def __init__(self, num_in, num_out, n_steps, rngs):
        self.lin = nnx.Linear(num_in, num_out, rngs = rngs)
        self.embed = nnx.Embed(n_steps, num_out, rngs = rngs)

    def __call__(self, x, t):
        xout = self.lin(x)
        # broadcast the timestep embedding to match the linear output
        em = jnp.reshape(self.embed(t), xout.shape)
        out = xout*em
        return out


class ConditionalDiffusionModel(nnx.Module):
    def __init__(self, dim, conditioning_dim, hidden_dim, T, rngs):
        self.dim = dim 

        self.cond_emb1 = ConditionalLinear(conditioning_dim, hidden_dim, T, rngs)
        self.cond_emb2 = ConditionalLinear(conditioning_dim, hidden_dim, T, rngs)
        self.cond_emb3 = ConditionalLinear(conditioning_dim, hidden_dim, T, rngs) 
        
        self.l1 = ConditionalLinear(dim+hidden_dim, hidden_dim, T, rngs)
        self.l2 = ConditionalLinear(hidden_dim, hidden_dim, T, rngs)
        self.l3 = ConditionalLinear(hidden_dim, hidden_dim, T, rngs)
        self.last_layer = nnx.Linear(hidden_dim, dim, rngs=rngs)

    def __call__(self, x, y, ts):
        # Note: each call below takes `y` as input, so the outputs of
        # cond_emb1 and cond_emb2 are discarded; only cond_emb3 survives.
        yemb = nnx.softplus(self.cond_emb1(y, ts))
        yemb = nnx.softplus(self.cond_emb2(y, ts))
        yemb = nnx.softplus(self.cond_emb3(y, ts))

        xus = jnp.concatenate([x, yemb], axis=-1)
        xus = nnx.softplus(self.l1(xus, ts))
        xus = nnx.softplus(self.l2(xus, ts))
        xus = nnx.softplus(self.l3(xus, ts))

        preds = self.last_layer(xus)

        return preds
    
class ConditionalDiffuser:
    def __init__(self, dim, conditioning_dim, hidden_dim, beta_schedule, T, sigma, rngs):
        self.model = ConditionalDiffusionModel(dim, conditioning_dim, hidden_dim, T, rngs)
        self.betas = jnp.array(make_beta_schedule(beta_schedule, T))
        self.alphas = 1-self.betas 
        self.alpha_bars = jnp.cumprod(self.alphas)
        self.alphas_bar_sqrt = jnp.sqrt(self.alpha_bars)
        self.one_minus_alphas_bar_sqrt = jnp.sqrt(1-self.alpha_bars)
        self.alpha_bars_p = jnp.concatenate([jnp.array([1]), self.alpha_bars[:-1]])
        self.T = T
        self.sigma = sigma 

    def __call__(self, x, y, ts):
        # the underlying model needs the conditioner and timesteps as well
        return self.model(x, y, ts)

    def get_model(self):
        return self.model
    def train(self, key, dataset, opt, batch_size, epochs):
        epoch_loss = [0]
        ema_loss = None 
        ema = 0.9

        for e in range(epochs):
            if e % 10 == 0:
                print(f"Epoch: {e}\\t: epoch loss {epoch_loss[-1]}, ema loss {ema_loss}")
            
            el = []
            key, perm_key = random.split(key)
            permutation = random.permutation(perm_key, dataset['samples'].shape[0])

            with tqdm(range(0, dataset['samples'].shape[0], batch_size)) as tp:
                for i in tp:

                    indices = permutation[i:i+batch_size]
                    X = dataset['samples'][indices]
                    Y = dataset['conditioners'][indices]

                    step_key, key = random.split(key)
                    loss = training_step(self.model, opt, step_key, X, Y, self.alphas_bar_sqrt, self.one_minus_alphas_bar_sqrt, self.T)

                    el.append(loss)
                    if ema_loss is None:
                        ema_loss = loss 
                    else:
                        ema_loss = ema*ema_loss + (1-ema)*loss 
                    
                    tp.set_postfix(loss = np.mean(el[-40:]))
            el = np.mean(el)
            epoch_loss.append(el)
    

    def backward(self, x, y, ts, key):
        if isinstance(self.sigma, list):
            # np.choose expects indices first, then choices
            sigmas = np.choose(ts.astype(jnp.int32), self.sigma)
        else:
            sigmas = self.sigma

        a_roots = jnp.sqrt(jnp.take(self.alphas, (ts-1).astype(jnp.int32))[..., None])
        betas = jnp.take(self.betas, ts.astype(jnp.int32))[..., None]
        one_minus_abar_roots = jnp.sqrt(1 - jnp.take(self.alpha_bars, ts.astype(jnp.int32))[..., None])

        sigma_ts = jnp.sqrt(betas)  # note: `sigmas` above is currently unused

        pred_noise = self.model(x, y, ts.astype(jnp.int32))

        n_key, key = random.split(key)
        x_t = (1/a_roots) * (x - (betas / one_minus_abar_roots)*pred_noise) + sigma_ts * random.normal(n_key, x.shape)

        return x_t
    
    
    def complete_backward(self, x, y, T, key):
        for t in range(1, T):
            diff_key, key = random.split(key)
            ts = jnp.ones((x.shape[0],)) * (T - t)
            x = self.backward(x, y, ts, diff_key)
        return x
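For completeness, sampling then looks like this (a sketch; diffuser, batch, action_dim, and y are stand-ins for a trained ConditionalDiffuser and my data):

# Start from pure noise and denoise for T steps.
key = jax.random.key(42)
noise_key, sample_key = jax.random.split(key)
x_T = jax.random.normal(noise_key, (batch, action_dim))
samples = diffuser.complete_backward(x_T, y, diffuser.T, sample_key)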

Training file:

def main(args):
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')  # note: unused below

    env = gym.make(args.env)
    action_dim = env.action_space.shape[0]
    state_dim = env.observation_space.shape[0]

    dataset = np.load(args.dataset_path)
    obs = dataset['observations']
    actions = dataset['actions']
    rew_to_go = dataset['reward_to_go']
    print(rew_to_go.shape)

    Y = jnp.array(np.concatenate([obs, rew_to_go], axis=-1))
    X = jnp.array(actions)
    data = {'samples': X, 'conditioners': Y}

    if args.restore:
        print("Loading from Checkpoint")
        ckpt_path = Path('saved_models/my_checkpoints/').resolve()
        diff_model = load_model(ckpt_path)
        model = ConditionalDiffuser(dim=action_dim, conditioning_dim=state_dim+1, hidden_dim=300, beta_schedule='linear', T=252, sigma=0.1, rngs=nnx.Rngs(0, noise=1))
        model.model = diff_model
    else:
        print("Building Fresh Model")
        model = ConditionalDiffuser(dim=action_dim, conditioning_dim=state_dim+1, hidden_dim=300, beta_schedule='linear', T=252, sigma=0.1, rngs=nnx.Rngs(0, noise=1))

    opt = nnx.Optimizer(model.model, optax.adam(args.lr))

    train_key = jax.random.key(1)
    model.train(train_key, data, opt, args.batch_size, args.epochs)

    # Prepare state to save
    _, state = nnx.split(model.model)

    print("Checkpointing...")
    ckpt_path = Path('saved_models/my_checkpoints/').resolve()
    checkpointer = ocp.StandardCheckpointer()
    # Orbax refuses to overwrite an existing checkpoint directory,
    # so this path must not already exist (or pass force=True).
    checkpointer.save(ckpt_path/'attempt_8', state)

    # Sanity check: reload immediately and compare parameters element-wise;
    # the tree map raises on any mismatch.
    sus_model = load_model(ckpt_path)
    _, restored_state = nnx.split(sus_model)
    jax.tree.map(np.testing.assert_array_equal, restored_state, state)

    print("Done checkpointing")
    return
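One thing I also tried to narrow this down (a sketch; it assumes restoring without an abstract target returns the raw saved pytree, which StandardCheckpointer supports) is inspecting the checkpoint contents directly, independent of any model object:

# Print summary statistics of what is actually on disk.
raw = ocp.StandardCheckpointer().restore(ckpt_path/'attempt_8')
print(jax.tree.map(lambda a: (a.shape, float(jnp.abs(a).mean())), raw))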

Load function:

def load_model(checkpoint_dir, env_name='Pendulum-v1'):
    # Initialize environment to get dimensions
    env = gym.make(env_name)
    action_dim = env.action_space.shape[0]
    state_dim = env.observation_space.shape[0]

    # The checkpoint was saved from nnx.split(model.model), i.e. from the inner
    # ConditionalDiffusionModel, so the abstract target must have that same
    # structure, built with the same hyperparameters used at training time
    # (hidden_dim=300, T=252), not a differently-sized model.
    abstract_model = nnx.eval_shape(
        lambda: ConditionalDiffuser(
            dim=action_dim, conditioning_dim=state_dim+1, hidden_dim=300,
            beta_schedule='linear', T=252, sigma=0.1,
            rngs=nnx.Rngs(0, noise=1)).model)
    graphdef, abstract_state = nnx.split(abstract_model)

    checkpointer = ocp.StandardCheckpointer()
    loaded_state = checkpointer.restore(checkpoint_dir/'attempt_8', abstract_state)

    model = nnx.merge(graphdef, loaded_state)
    return model
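With the abstract target matching the saved structure, a functional round-trip check after restoring looks like this (a sketch; model, state, action_dim, and state_dim are assumed to still be in scope from the training file):

# The restored model should give identical outputs, not just
# parameters with the right shapes.
restored = load_model(Path('saved_models/my_checkpoints/').resolve())
_, restored_state = nnx.split(restored)
jax.tree.map(np.testing.assert_array_equal, restored_state, state)

x = jnp.ones((2, action_dim))
y = jnp.ones((2, state_dim + 1))
ts = jnp.zeros((2, 1), dtype=jnp.int32)
np.testing.assert_allclose(model.model(x, y, ts), restored(x, y, ts), rtol=1e-6)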
