
NotImplementedError: When I use stax.DotGeneral #121

Open
@kkeevin123456

Description

Hi, when I try to implement a two-layer coupling layer like the one in the image below, I get this error. Do you have any insight into how to fix it?

[image: diagram of the two-layer additive coupling layer]
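
For reference, in case the image does not render: what I am trying to build is (I believe) the standard additive coupling layer from NICE, where the input is split into halves x_1 and x_2 and m is the ReluNetwork defined below:

    y_1 = x_1
    y_2 = x_2 + m(x_1)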

The error looks like:

[image: NotImplementedError traceback]

Some directions I have tried:

  • Since the ResNet example works, the problem should not be FanInSum
  • I also tried setting is_gaussian to True
  • I think the problem occurs when I try to kernelize my architecture: training the network with optimizers.sgd works (see the sketch after this list), but I still need the kernelized version
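
For concreteness, here is roughly how I train the finite-width network with plain SGD (a minimal sketch, assuming jax.example_libraries.optimizers and the lower_path definition from the reproduction code below; apply_fn works fine, only kernel_fn fails):

    from jax import grad, random
    from jax.example_libraries import optimizers
    import jax.numpy as np

    # lower_path is defined in the reproduction code below
    init_fn, apply_fn, _ = lower_path(input_dim=6)

    key = random.PRNGKey(0)
    x = random.normal(key, (8, 1, 6))   # (batch, 1, features)
    _, params = init_fn(key, input_shape=x.shape)

    def prior_nll(params, x):
        # Gaussian-prior negative log-likelihood of f(x)
        # (same spirit as LogisticPriorLoss in the reproduction code)
        fx = apply_fn(params, x)
        return np.mean(0.5 * np.sum(fx ** 2, axis=-1))

    opt_init, opt_update, get_params = optimizers.sgd(step_size=1e-3)
    opt_state = opt_init(params)

    for step in range(100):
        grads = grad(prior_nll)(get_params(opt_state), x)
        opt_state = opt_update(step, grads, opt_state)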

Here is some code that can reproduce the error:

    from jax import random
    from neural_tangents import stax
    import jax.numpy as np
    import neural_tangents as nt

    def DenseBlock(neurons):
        return stax.serial(
            stax.Dense(neurons), stax.Relu()
        )
    
    def ReluNetwork(latent_dim, hidden_dim, num_layers):
        """Create the network which is embedd in flow_base model
        
        Args:
            latent_dim: input and output dim
            hidden_dim: the width dim of the ReluNetwork
            num_layers: depth of the ReluNetwork
        
        Returns:
            stax.serial(ReluNetwork)
        """
        blocks = [DenseBlock(hidden_dim)]
        for _ in range(num_layers):
            blocks += [DenseBlock(hidden_dim)]
        blocks += [stax.Dense(latent_dim)]
        
        return stax.serial(*blocks)
    
    def lower_path(input_dim):
        """Feed the first half of the input through the ReluNetwork and
        place the result in the second-half feature positions."""
        helf_dim = input_dim//2
        # pre_half's rhs: selects the first half of the input features
        rhs1 = np.identity(helf_dim)
        rhs1 = np.pad(rhs1, ((0, 0), (0, helf_dim)))
        rhs1 = np.reshape(rhs1, (*rhs1.shape, 1))
        
        # post_half's rhs: passes the second half of the input through unchanged
        rhs2 = np.identity(helf_dim)
        rhs2 = np.pad(rhs2, ((helf_dim, 0), (helf_dim, 0)))
        rhs2 = np.reshape(rhs2, (*rhs2.shape, 1))
        
        
        # rhs that scatters the ReluNetwork output into the second-half positions
        rhs4 = np.identity(helf_dim)
        rhs4 = np.pad(rhs4, ((helf_dim, 0), (0, 0)))
        rhs4 = np.reshape(rhs4, (*rhs4.shape, 1))
        
        blocks = [
            # select the first half of the features; the second DotGeneral
            # contracts away the trailing singleton dimension
            stax.DotGeneral(
                    rhs = rhs1,
                    dimension_numbers = (((2,), (1,)), ((), ())),
                    channel_axis = 1
                ), 
            stax.DotGeneral(
                    rhs = np.array([1]),
                    dimension_numbers = (((3,), (0,)), ((), ())),
                    channel_axis = 1
                )]
        blocks += [ReluNetwork(latent_dim=helf_dim, hidden_dim=helf_dim//2, num_layers=4)]
        blocks += [
            stax.DotGeneral(
                    rhs = rhs4,
                    dimension_numbers = (((2,), (1,)), ((), ())),
                    channel_axis = 1
                ), 
            stax.DotGeneral(
                    rhs = np.array([1]),
                    dimension_numbers = (((3,), (0,)), ((), ())),
                    channel_axis = 1
                )]
        
        
        pre_half = stax.serial(
            *blocks
        )
        post_half = stax.serial(
            stax.DotGeneral(
                    rhs = rhs2,
                    dimension_numbers = (((2,), (1,)), ((), ())),
                    channel_axis = 1
                ), 
            stax.DotGeneral(
                    rhs = np.array([1]),
                    dimension_numbers = (((3,), (0,)), ((), ())),
                    channel_axis = 1
                )
        )
        return stax.serial(stax.FanOut(2),
                           stax.parallel(pre_half, post_half),
                           stax.FanInSum()
                          )
    
    def AdditiveCouplingLayer(input_dim, order):
        """The additive coupling layer from the paper.
        
        Args:
            input_dim: input and output dim
            order: unused in this reproduction
        
        Returns:
            stax.serial(AdditiveCouplingLayer)
        """
        helf_dim = input_dim//2
        
        # rhs that passes the first half of the input through unchanged
        rhs_matrix = np.identity(helf_dim)
        rhs_matrix = np.pad(rhs_matrix, ((0, helf_dim), (0, helf_dim)))
        rhs_matrix = np.reshape(rhs_matrix, (*rhs_matrix.shape, 1))
    
        upper_path = stax.serial(
            stax.DotGeneral(
                    rhs = rhs_matrix,
                    dimension_numbers = (((2,), (1,)), ((), ())),
                    channel_axis = 1
                ), 
            stax.DotGeneral(
                    rhs = np.array([1]),
                    dimension_numbers = (((3,), (0,)), ((), ())),
                    channel_axis = 1
                )
        )
        
        return stax.serial(stax.FanOut(2),
                           stax.parallel(upper_path, lower_path(input_dim)),
                           stax.FanInSum()
                          )

    def LogisticPriorLoss(fx, y):
        # negative log-likelihood of fx under a standard Gaussian prior
        # (y is unused here; nt.predict.gradient_descent expects a loss of the form loss(fx, y_hat))
        return np.mean((0.5*np.sum(np.power(fx, 2), axis=1) + fx.shape[1]*0.5*np.log(2*np.pi)))

    # test
    x = np.array([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12], [13, 14, 15, 16, 17, 18]])
    x = np.reshape(x, (x.shape[0], 1, *x.shape[1:]))
    input_dim = x.shape[2]  # x has shape (B, 1, 6): B is the batch size
    helf_dim = input_dim//2
    
    init_fn, apply_fn, kernel_fn = lower_path(input_dim=input_dim)
    
    key = random.PRNGKey(1)
    _, params = init_fn(key, input_shape=x.shape)
    
    # z_train.dim = x_train.dim
    z_train = random.normal(key, x.shape)
    x_test = np.array([[1, 2, 3, 4, 5, 6]])
    x_test = np.reshape(x_test, (x_test.shape[0], 1, *x_test.shape[1:]))
    
    ntk_train_train = kernel_fn(x, x, 'ntk', channel_axis=1, is_gaussian=True)  # this is the call where I hit the NotImplementedError
    ntk_test_train = kernel_fn(x_test, x, 'ntk')
    predictor = nt.predict.gradient_descent(LogisticPriorLoss, ntk_train_train, z_train)
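
In case it helps to isolate the issue, I think (though I have not fully verified this) the same error path can be hit with just the two chained DotGeneral layers, without the surrounding coupling-layer architecture:

    from jax import random
    from neural_tangents import stax
    import jax.numpy as np

    half_dim = 3
    rhs = np.identity(half_dim)
    rhs = np.pad(rhs, ((0, 0), (0, half_dim)))   # (3, 6): selects the first half
    rhs = np.reshape(rhs, (*rhs.shape, 1))       # (3, 6, 1)

    _, _, kernel_fn = stax.serial(
        stax.DotGeneral(rhs=rhs,
                        dimension_numbers=(((2,), (1,)), ((), ())),
                        channel_axis=1),
        stax.DotGeneral(rhs=np.array([1]),
                        dimension_numbers=(((3,), (0,)), ((), ())),
                        channel_axis=1),
    )

    x = random.normal(random.PRNGKey(0), (3, 1, 6))
    kernel_fn(x, x, 'ntk')  # I expect this to raise the same NotImplementedError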

Many thanks in advance for your reply.
