
NTK/NNGP behavior in the infinite regime when weights are drawn from Gaussians with high standard deviation #197

Open
@tengandreaxu

Description

Hi everyone, thank you so much for your exceptional work!

I'm encountering some numerical issues when weights are drawn from Gaussians with a high standard deviation. Please see the snippet below:

import numpy as np
from neural_tangents import stax
from jax import jit

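# Only the first 15 entries are used: hidden W_std = 1, ..., 15 (16, ..., 2 if reversed)
W_stds = list(range(1, 17))
# W_stds.reverse()  # uncomment to run the reversed ordering
layer_fn = []
for i in range(len(W_stds) - 1):
    layer_fn.append(stax.Dense(1, W_std=W_stds[i]))
    layer_fn.append(stax.Relu())

# Linear readout with W_std=1.0 and b_std=0.0
layer_fn.append(stax.Dense(1, W_std=1.0, b_std=0.0))
_, _, kernel_fn = stax.serial(*layer_fn)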

kernel_fn = jit(kernel_fn, static_argnames="get")

x = np.random.rand(100, 100)

print(kernel_fn(x, x, get="ntk"))

The result is:

[[nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 ...
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]]

With float64 precision enabled, the entries are finite but blow up to the order of 1e18:

[[2.2293401e+18 9.3420067e+17 9.2034030e+17 ... 8.9008971e+17
  9.6801663e+17 9.6436509e+17]
 [9.3420067e+17 2.3730658e+18 9.4658846e+17 ... 9.6854199e+17
  9.6182735e+17 9.9944418e+17]
 [9.2034030e+17 9.4658846e+17 2.3106050e+18 ... 9.1702287e+17
  9.5415269e+17 9.9692925e+17]
 ...
 [8.9008971e+17 9.6854199e+17 9.1702300e+17 ... 2.2127619e+18
  9.2056034e+17 1.0147568e+18]
 [9.6801663e+17 9.6182728e+17 9.5415269e+17 ... 9.2056034e+17
  2.3979914e+18 9.9505658e+17]
 [9.6436488e+17 9.9944418e+17 9.9692925e+17 ... 1.0147568e+18
  9.9505658e+17 2.4954969e+18]]

What's interesting is that the behavior appears to depend more on the depth than on the large values of the weights' standard deviations. If the order of the standard deviations is reversed (by uncommenting W_stds.reverse()), so that the first layer uses W_std=16, i.e. $w_{ij} \sim \mathcal{N}(0, 16^2/n)$ for fan-in $n$, the second uses W_std=15, and so on, the results remain unchanged.
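The order-independence above is consistent with a standard property of bias-free ReLU networks: ReLU is positively homogeneous, so each ReLU + Dense(W_std=σ) step multiplies the kernel's scale by σ²/2 without depending on that scale, and the overall magnitude is a product of per-layer factors, which is unchanged when the same stds are permuted. The pure-Python sketch below is not Neural Tangents code; it re-implements the closed-form ReLU NNGP/NTK recursion for a pair of inputs, assuming all biases are zero and summarizing the inputs by hypothetical statistics (⟨x,x⟩/n = ⟨y,y⟩/n = 1/3, ⟨x,y⟩/n = 1/4). The helper `relu_nngp_ntk` is a hypothetical name.

```python
import math


def relu_nngp_ntk(sigmas, xx, yy, xy):
    """Hypothetical helper: closed-form NNGP and NTK of the bias-free network
    Dense(sigmas[0]) -> ReLU -> Dense(sigmas[1]) -> ... -> ReLU -> Dense(sigmas[-1]).

    xx, yy, xy are the input dot products <x, x>, <y, y>, <x, y>, each divided
    by the input dimension. Returns (nngp, ntk) as 2x2 nested lists.
    """
    # First Dense layer: K = sigma^2 * <x, y> / n0, and the NTK equals K.
    s2 = sigmas[0] ** 2
    kxx, kyy, kxy = s2 * xx, s2 * yy, s2 * xy
    txx, tyy, txy = kxx, kyy, kxy
    for sigma in sigmas[1:]:
        s2 = sigma ** 2
        # Angle between the two pre-activations; note kxx * kyy squares the
        # kernel's magnitude. The clamp guards against rounding past +-1.
        norm = math.sqrt(kxx * kyy)
        cos_t = max(-1.0, min(1.0, kxy / norm))
        theta = math.acos(cos_t)
        # ReLU: E[relu(u) relu(v)] is half the order-1 arc-cosine kernel
        # (K / 2 on the diagonal); E[step(u) step(v)] = (pi - theta) / (2 pi).
        relu_xy = norm * (math.sin(theta) + (math.pi - theta) * cos_t) / (2 * math.pi)
        step_xy = (math.pi - theta) / (2 * math.pi)
        # Next Dense layer (no bias): multiply the ReLU covariances by sigma^2.
        kxx, kyy, kxy = s2 * kxx / 2, s2 * kyy / 2, s2 * relu_xy
        # NTK recursion: new NNGP term plus sigma^2 * step kernel * old NTK.
        txx = kxx + s2 * 0.5 * txx
        tyy = kyy + s2 * 0.5 * tyy
        txy = kxy + s2 * step_xy * txy
    return [[kxx, kxy], [kxy, kyy]], [[txx, txy], [txy, tyy]]


# Hidden W_std = 1, ..., 15 as in the issue, followed by a Dense(1, 1.0) readout.
hidden = [float(s) for s in range(1, 16)]
forward = hidden + [1.0]
permuted = hidden[::-1] + [1.0]  # same stds, opposite order

# Hypothetical input statistics, roughly the expected values for rows of
# np.random.rand: <x, x> / n = <y, y> / n = 1/3 and <x, y> / n = 1/4.
nngp_f, ntk_f = relu_nngp_ntk(forward, 1 / 3, 1 / 3, 0.25)
nngp_p, ntk_p = relu_nngp_ntk(permuted, 1 / 3, 1 / 3, 0.25)

print("NTK, forward order: ", ntk_f)
print("NTK, permuted order:", ntk_p)
```

Both orders give the same matrices up to rounding, since only the product of the σ²/2 factors sets the scale. In this sketch the angle step forms kxx * kyy, so once entries pass about 1.8e19 (the square root of float32's maximum of about 3.4e38) that product would overflow in float32; whether Neural Tangents' ReLU kernel forms the same product internally is an assumption.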

Thank you in advance, and happy new year!
