I'm looking to use the library to compute the after kernel (the empirical NTK evaluated at the trained parameters) for a model trained with FLAX. I followed this Colab: https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/empirical_ntk_resnet.ipynb.
Instead of these lines:
```python
params = model.init(random.PRNGKey(0), x1)
return params, (jacobian_contraction, ntvp, str_derivatives, auto)
params, (ntk_fn_jacobian_contraction, ntk_fn_ntvp, ntk_fn_str_derivatives, ntk_fn_auto) = get_ntk_fns(O=O)
k_1 = ntk_fn_jacobian_contraction(x1, x2, params)
```
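For reference, `k_1` is the empirical NTK, Θ(x1, x2) = J(x1) J(x2)ᵀ, where J is the Jacobian of the network outputs with respect to the parameters. The sketch below computes it explicitly with `jax.jacobian` for a hypothetical two-layer MLP; `mlp`, `empirical_ntk` and all shapes are illustrative assumptions, not code from the Colab. It keeps the output axes, giving shape (N1, N2, O, O), which matches what `nt.empirical_ntk_fn` returns with `trace_axes=()`; the library's default `trace_axes=(-1,)` instead takes the trace over the output axis and returns shape (N1, N2).

```python
import jax
import jax.numpy as jnp


def mlp(params, x):
    """Hypothetical two-layer MLP, used only to illustrate the kernel."""
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]


def empirical_ntk(f, params, x1, x2):
    """Empirical NTK via explicit Jacobian contraction.

    Returns K with shape (N1, N2, O, O), where
    K[i, j, a, b] = sum_p df(x1)[i, a]/dp * df(x2)[j, b]/dp.
    """
    # Jacobians w.r.t. params: each leaf has shape (N, O, *leaf.shape).
    j1 = jax.jacobian(f)(params, x1)
    j2 = jax.jacobian(f)(params, x2)

    def contract(a, b):
        # Flatten the parameter dimensions of one leaf and contract over them.
        a = a.reshape(a.shape[0], a.shape[1], -1)
        b = b.reshape(b.shape[0], b.shape[1], -1)
        return jnp.einsum("iap,jbp->ijab", a, b)

    # Sum the contributions of all parameter leaves.
    per_leaf = jax.tree_util.tree_map(contract, j1, j2)
    return sum(jax.tree_util.tree_leaves(per_leaf))


# Usage with made-up sizes: D inputs, H hidden units, O outputs.
key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
D, H, O = 3, 8, 2
params = {
    "w1": jax.random.normal(key1, (D, H)) / jnp.sqrt(D),
    "b1": jnp.zeros(H),
    "w2": jax.random.normal(key2, (H, O)) / jnp.sqrt(H),
    "b2": jnp.zeros(O),
}
x = jax.random.normal(key3, (4, D))
k = empirical_ntk(mlp, params, x, x)
print(k.shape)  # (4, 4, 2, 2)
```

Explicit Jacobians are fine for a toy model; for a ResNet-sized network, the `NTK_VECTOR_PRODUCTS` and `STRUCTURED_DERIVATIVES` implementations that the Colab benchmarks can be cheaper in time or memory.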
I used the params from the following TrainState of the FLAX model:
```python
state = TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    batch_stats=variables['batch_stats'],
    tx=tx)
```
I was wondering whether this is the correct way to do it. Thanks!