Hi, when I try to implement a two-layer coupling layer like in the image below, I get an error. Do you have any insight into how to fix it?
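(In case the image does not show up: the layer splits the input into two halves `(x1, x2)` and outputs `(x1, x2 + m(x1))`, where `m` is the `ReluNetwork` defined in the code below.)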
The error looks like:
Some directions I have tried:
- Since a ResNet works, the problem should not be `FanInSum` (a minimal ResNet-style check is sketched right after this list).
- I also set `is_gaussian` to `True`.
- I think the problem may occur when I try to kernelize my architecture, so I used `optimizers.sgd` to train my network instead. It works, but I still need to kernelize the model (a sketch of this workaround is after the repro code below).
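For reference, this is the kind of ResNet-style block I mean (I patterned it after the WideResnet example in the neural_tangents README; the widths and input shape here are just placeholders), and its kernel computes without error:

```python
from jax import random
import jax.numpy as np
from neural_tangents import stax

# Residual block: a Dense branch summed with an identity shortcut via FanInSum.
res_block = stax.serial(
    stax.Dense(8),
    stax.FanOut(2),
    stax.parallel(
        stax.serial(stax.Relu(), stax.Dense(8)),  # main branch
        stax.Identity()                           # shortcut branch
    ),
    stax.FanInSum(),
    stax.Dense(1)
)
_, _, res_kernel_fn = res_block

x_check = random.normal(random.PRNGKey(0), (3, 8))
print(res_kernel_fn(x_check, x_check, 'ntk').shape)  # (3, 3), no error here
```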
Here is some code that can reproduce the error:
```python
from jax import random
from neural_tangents import stax
import jax.numpy as np
import neural_tangents as nt


def DenseBlock(neurons):
    return stax.serial(
        stax.Dense(neurons), stax.Relu()
    )


def ReluNetwork(latent_dim, hidden_dim, num_layers):
    """Create the network that is embedded in the flow-based model.

    Args:
      latent_dim: input and output dim
      hidden_dim: width of the ReluNetwork
      num_layers: depth of the ReluNetwork

    Returns:
      stax.serial(ReluNetwork)
    """
    blocks = [DenseBlock(hidden_dim)]
    for _ in range(num_layers):
        blocks += [DenseBlock(hidden_dim)]
    blocks += [stax.Dense(latent_dim)]
    return stax.serial(*blocks)


def lower_path(input_dim):
    half_dim = input_dim // 2

    # pre_half's rhs: selects the first half of the features (output width half_dim)
    rhs1 = np.identity(half_dim)
    rhs1 = np.pad(rhs1, ((0, 0), (0, half_dim)))
    rhs1 = np.reshape(rhs1, (*rhs1.shape, 1))

    # post_half's rhs: keeps the second half of the features in place, zeros the first half
    rhs2 = np.identity(half_dim)
    rhs2 = np.pad(rhs2, ((half_dim, 0), (half_dim, 0)))
    rhs2 = np.reshape(rhs2, (*rhs2.shape, 1))

    # scatters the ReluNetwork output back into the second half of an input_dim-wide vector
    rhs4 = np.identity(half_dim)
    rhs4 = np.pad(rhs4, ((half_dim, 0), (0, 0)))
    rhs4 = np.reshape(rhs4, (*rhs4.shape, 1))

    # each mask is applied as a pair of DotGenerals: the first contracts the feature
    # axis with the mask, the second removes the trailing singleton dimension
    blocks = [
        stax.DotGeneral(
            rhs=rhs1,
            dimension_numbers=(((2,), (1,)), ((), ())),
            channel_axis=1
        ),
        stax.DotGeneral(
            rhs=np.array([1.]),
            dimension_numbers=(((3,), (0,)), ((), ())),
            channel_axis=1
        )
    ]
    blocks += [ReluNetwork(latent_dim=half_dim, hidden_dim=half_dim // 2, num_layers=4)]
    blocks += [
        stax.DotGeneral(
            rhs=rhs4,
            dimension_numbers=(((2,), (1,)), ((), ())),
            channel_axis=1
        ),
        stax.DotGeneral(
            rhs=np.array([1.]),
            dimension_numbers=(((3,), (0,)), ((), ())),
            channel_axis=1
        )
    ]
    pre_half = stax.serial(*blocks)

    post_half = stax.serial(
        stax.DotGeneral(
            rhs=rhs2,
            dimension_numbers=(((2,), (1,)), ((), ())),
            channel_axis=1
        ),
        stax.DotGeneral(
            rhs=np.array([1.]),
            dimension_numbers=(((3,), (0,)), ((), ())),
            channel_axis=1
        )
    )
    return stax.serial(
        stax.FanOut(2),
        stax.parallel(pre_half, post_half),
        stax.FanInSum()
    )


def AdditiveCouplingLayer(input_dim, order):
    """The additive coupling layer from the paper.

    Args:
      input_dim: input and output dim
      order: not used in this snippet

    Returns:
      stax.serial(AdditiveCouplingLayer)
    """
    half_dim = input_dim // 2
    # keeps the first half of the features in place, zeros the second half
    rhs_matrix = np.identity(half_dim)
    rhs_matrix = np.pad(rhs_matrix, ((0, half_dim), (0, half_dim)))
    rhs_matrix = np.reshape(rhs_matrix, (*rhs_matrix.shape, 1))
    upper_path = stax.serial(
        stax.DotGeneral(
            rhs=rhs_matrix,
            dimension_numbers=(((2,), (1,)), ((), ())),
            channel_axis=1
        ),
        stax.DotGeneral(
            rhs=np.array([1.]),
            dimension_numbers=(((3,), (0,)), ((), ())),
            channel_axis=1
        )
    )
    return stax.serial(
        stax.FanOut(2),
        stax.parallel(upper_path, lower_path(input_dim)),
        stax.FanInSum()
    )


def LogisticPriorLoss(fx, y):
    return np.mean(0.5 * np.sum(np.power(fx, 2), axis=1)
                   + fx.shape[1] * 0.5 * np.log(2 * np.pi))


# test
x = np.array([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12], [13, 14, 15, 16, 17, 18]])
x = np.reshape(x, (x.shape[0], 1, *x.shape[1:]))  # (B, 1, 6), B is the batch size
input_dim = x.shape[2]
half_dim = input_dim // 2
init_fn, apply_fn, kernel_fn = lower_path(input_dim=input_dim)
key = random.PRNGKey(1)
_, params = init_fn(key, input_shape=x.shape)
# z_train has the same shape as x_train
z_train = random.normal(key, x.shape)
x_test = np.array([[1, 2, 3, 4, 5, 6]])
x_test = np.reshape(x_test, (x_test.shape[0], 1, *x_test.shape[1:]))
# the error seems to be raised when kernelizing the architecture
ntk_train_train = kernel_fn(x, x, 'ntk', channel_axis=1, is_gaussian=True)
ntk_test_train = kernel_fn(x_test, x, 'ntk')
predictor = nt.predict.gradient_descent(LogisticPriorLoss, ntk_train_train, z_train)
```
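For completeness, this is roughly the `optimizers.sgd` workaround I mentioned above. It reuses `lower_path`, `LogisticPriorLoss`, `x` and `key` from the snippet, assumes `jax.example_libraries.optimizers` (`jax.experimental.optimizers` in older JAX), and the step size and number of steps are just placeholders:

```python
from jax import jit, grad
from jax.example_libraries import optimizers  # jax.experimental.optimizers in older JAX

# Parameter-space SGD training: this runs without error; only kernel_fn above fails.
x_f = x.astype('float32')  # float inputs for apply_fn
init_fn, apply_fn, _ = lower_path(input_dim=x_f.shape[2])
_, params = init_fn(key, input_shape=x_f.shape)

opt_init, opt_update, get_params = optimizers.sgd(step_size=1e-3)
opt_state = opt_init(params)

@jit
def train_step(i, opt_state):
    p = get_params(opt_state)
    loss = lambda p: LogisticPriorLoss(apply_fn(p, x_f), None)  # y is unused in the loss
    return opt_update(i, grad(loss)(p), opt_state)

for i in range(100):
    opt_state = train_step(i, opt_state)
```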
Many thanks for your kind reply.