NTK/NNGP behavior in the infinite-width regime when weights are drawn from Gaussians with high standard deviation #197
Open
Description
opened on Dec 29, 2023
Hi everyone, thank you so much for your exceptional work!
I'm encountering some numerical issues when weights are drawn from Gaussians with a high standard deviation. Please see the snippet below:
import numpy as np
from neural_tangents import stax
from jax import jit

# Hidden-layer weight standard deviations: 1, 2, ..., 15 (the last entry, 16, is unused).
W_stds = list(range(1, 17))
# W_stds.reverse()

layer_fn = []
for i in range(len(W_stds) - 1):
    layer_fn.append(stax.Dense(1, W_std=W_stds[i]))
    layer_fn.append(stax.Relu())
# Readout layer: W_std = 1, b_std = 0.
layer_fn.append(stax.Dense(1, 1.0, 0.0))

_, _, kernel_fn = stax.serial(*layer_fn)
kernel_fn = jit(kernel_fn, static_argnames="get")

x = np.random.rand(100, 100)  # 100 inputs of dimension 100
print(kernel_fn(x, x, "ntk"))
The result is:
[[nan nan nan ... nan nan nan]
[nan nan nan ... nan nan nan]
[nan nan nan ... nan nan nan]
...
[nan nan nan ... nan nan nan]
[nan nan nan ... nan nan nan]
[nan nan nan ... nan nan nan]]
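For reference, double precision can be turned on in JAX via the x64 flag before any arrays are created; a minimal sketch is below (the exact mechanism used for the run that follows is an assumption):

import jax

# Hedged sketch: enabling float64 in JAX (assumed to be how the double-precision
# run below was produced). Must be set before any JAX computation runs.
jax.config.update("jax_enable_x64", True)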
With float64 precision enabled, the result instead shows the values blowing up:
[[2.2293401e+18 9.3420067e+17 9.2034030e+17 ... 8.9008971e+17
9.6801663e+17 9.6436509e+17]
[9.3420067e+17 2.3730658e+18 9.4658846e+17 ... 9.6854199e+17
9.6182735e+17 9.9944418e+17]
[9.2034030e+17 9.4658846e+17 2.3106050e+18 ... 9.1702287e+17
9.5415269e+17 9.9692925e+17]
...
[8.9008971e+17 9.6854199e+17 9.1702300e+17 ... 2.2127619e+18
9.2056034e+17 1.0147568e+18]
[9.6801663e+17 9.6182728e+17 9.5415269e+17 ... 9.2056034e+17
2.3979914e+18 9.9505658e+17]
[9.6436488e+17 9.9944418e+17 9.9692925e+17 ... 1.0147568e+18
9.9505658e+17 2.4954969e+18]]
What's interesting is that the behavior appears to depend more on the depth of the network than on how large the weights' standard deviations are. If the order of the standard deviations were reversed (by uncommenting the line above), layer 1 would have W_std = 16 and later layers progressively smaller values.
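For what it's worth, a rough back-of-the-envelope estimate supports the depth dependence: assuming the standard recursion in which a bias-free Dense layer scales the NNGP diagonal by W_std**2 and Relu halves it, the kernel scale compounds multiplicatively with depth. The sketch below is only an order-of-magnitude estimate, not the exact neural_tangents computation:

# Hedged sketch: rough growth of the NNGP diagonal under the recursion
# K -> W_std**2 * K (bias-free Dense) and K -> K / 2 (Relu, on the diagonal).
K = 1.0 / 3.0  # approximate diagonal of x @ x.T / d for x ~ Uniform[0, 1]^100
for s in range(1, 16):  # hidden-layer W_stds 1..15 from the snippet above
    K = (s ** 2) * K  # Dense(W_std=s)
    K = K / 2.0       # Relu
    print(f"after hidden layer with W_std={s}: K ~ {K:.3e}")
# K ends up around 1e19, so squared intermediate quantities (which appear in the
# arc-cosine kernel computation for Relu) approach or exceed float32's maximum
# (~3.4e38), which would be consistent with the nan's seen in single precision.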
Thank you in advance, and happy new year!