update to Lux 1.0
vpuri3 committed Oct 3, 2024
1 parent 0a11fb5 commit 4a9184a
Showing 21 changed files with 52 additions and 49 deletions.
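
Most of the diff below is mechanical: Lux 1.0 renames the layer interface, so `Lux.AbstractExplicitLayer` becomes `AbstractLuxLayer` and `Lux.AbstractExplicitContainerLayer` becomes `AbstractLuxContainerLayer` (both now imported from LuxCore), and `LuxDeviceUtils` is replaced by `MLDataDevices`. A minimal sketch of a custom layer under the new interface, assuming standard Lux 1.0 semantics; the `Scale` layer is hypothetical, illustrates the pattern only, and is not part of this commit:

```julia
using Lux, LuxCore, Random

# Hypothetical layer, for illustration only: was `<: Lux.AbstractExplicitLayer` before Lux 1.0.
struct Scale <: LuxCore.AbstractLuxLayer
    dim::Int
end

# Trainable parameters and (empty) state, returned as NamedTuples as Lux expects.
LuxCore.initialparameters(rng::AbstractRNG, l::Scale) = (; w = randn(rng, Float32, l.dim))
LuxCore.initialstates(::AbstractRNG, ::Scale) = (;)

# Stateless elementwise scaling; returns (output, new_state).
(l::Scale)(x, ps, st) = ps.w .* x, st

rng = Random.default_rng()
layer = Scale(4)
ps, st = Lux.setup(rng, layer)
y, _ = layer(rand(Float32, 4, 16), ps, st)
```
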
3 changes: 2 additions & 1 deletion Project.toml
@@ -32,7 +32,8 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
-LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
+LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
+MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
2 changes: 1 addition & 1 deletion benchmarks/misc_scripts/hess.jl
@@ -7,7 +7,7 @@ CUDA.allowscalar(false)

#==========================#
function testhessian(
-NN::Lux.AbstractExplicitLayer,
+NN::AbstractLuxLayer,
data::Tuple;
device = cpu_device(),
)
2 changes: 1 addition & 1 deletion experiments_SDF/SDF.jl
@@ -83,7 +83,7 @@ end
#===========================================================#

function train_SDF(
-NN::Lux.AbstractExplicitLayer,
+NN::AbstractLuxLayer,
casename::String,
modeldir::String,
E::Int; # num epochs
2 changes: 1 addition & 1 deletion experiments_SNFROM/convAE.jl
@@ -88,7 +88,7 @@ end
function train_CAE(
datafile::String,
modeldir::String,
-NN::Lux.AbstractExplicitLayer,
+NN::AbstractLuxLayer,
E::Int; # num epochs
rng::Random.AbstractRNG = Random.default_rng(),
warmup::Bool = false,
2 changes: 1 addition & 1 deletion experiments_SNFROM/convINR.jl
@@ -103,7 +103,7 @@ end
function train_CINR(
datafile::String,
modeldir::String,
-NN::Lux.AbstractExplicitLayer,
+NN::AbstractLuxLayer,
E::Int; # num epochs
rng::Random.AbstractRNG = Random.default_rng(),
warmup::Bool = true,
2 changes: 1 addition & 1 deletion experiments_SNFROM/smoothNF.jl
@@ -678,7 +678,7 @@ function compute_residual(
prob, model, timealg, scheme, Δt,
adaptive, autodiff_xyz, ϵ_xyz, learn_ic;
verbose::Bool = false,
-device = Lux.gpu_deivce(),
+device = Lux.gpu_device(),
)
res = zeros(Float32, size(Xdata, 2))

6 changes: 3 additions & 3 deletions misc/train_demo.jl
@@ -12,8 +12,8 @@ y = @. sin(1x)

data = (x, y)
NN = Chain(Dense(1, W, tanh), Dense(W, W, tanh), Dense(W, 1))
-device = Lux.cpu_device()
-device = Lux.gpu_device()
+device = cpu_device()
+device = gpu_device()

# # MIXED TEST
# @time (NN, p, st), ST = train_model(
@@ -23,7 +23,7 @@ device = Lux.gpu_device()
# )

# @time (NN, p, st), ST = train_model(NN, data)
-trainer = Trainer(NN, data; device, verbose = true, patience_frac = .8)
+trainer = Trainer(NN, data; device, verbose = true, patience_frac = .1)
@time model, ST = train!(trainer)

# @time train_model(NN, data; opts = (Optim.LBFGS(),), device)
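
With Lux 1.0, `cpu_device`/`gpu_device` come from MLDataDevices and are re-exported by Lux, which is why the `Lux.` prefix is dropped above. A hedged usage sketch (the `Dense` model and sizes are illustrative assumptions, not the `train_demo.jl` network):

```julia
using Lux, MLDataDevices, Random

dev = gpu_device()                 # falls back to the CPU device if no functional GPU backend is loaded
NN  = Dense(1 => 8, tanh)          # illustrative model only
ps, st = Lux.setup(Random.default_rng(), NN) .|> dev   # move parameters/states to the device
x = dev(rand(Float32, 1, 16))                          # device objects are callable and move data
y, _ = NN(x, ps, st)
```
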
6 changes: 4 additions & 2 deletions src/NeuralROMs.jl
@@ -9,6 +9,7 @@ using CalculustCore

# ML Stack
using Lux
+import LuxCore: AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer
using MLUtils
using Optimisers
using Optimization
@@ -41,7 +42,8 @@ using Adapt
using CUDA
using CUDA: AbstractGPUArray
using KernelAbstractions
-import LuxDeviceUtils
+using MLDataDevices
+using MLDataDevices: AbstractGPUDevice

# numerical
using FFTW
@@ -53,7 +55,7 @@ using ComponentArrays
using Setfield: @set, @set!
using UnPack
using ConcreteStructs
-using IterTools
+using IterTools # train.jl BFGS

# linear/nonlinear solvers
using LinearSolve
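
The `MLDataDevices: AbstractGPUDevice` import replaces the device types previously taken from `LuxDeviceUtils`. A sketch of the dispatch pattern this enables, assuming standard MLDataDevices types; the `on_gpu` helper is hypothetical and not part of this commit:

```julia
using MLDataDevices
using MLDataDevices: AbstractDevice, AbstractGPUDevice

# Hypothetical helper: branch on the kind of device returned by gpu_device()/cpu_device().
on_gpu(::AbstractGPUDevice) = true
on_gpu(::AbstractDevice)    = false

dev = gpu_device()
on_gpu(dev)   # true only when a functional GPU backend (e.g. LuxCUDA) is loaded
```
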
10 changes: 5 additions & 5 deletions src/layers/basic.jl
@@ -3,16 +3,16 @@
# Hyper Network
#======================================================#

-struct HyperNet{W <: Lux.AbstractExplicitLayer, C <: Lux.AbstractExplicitLayer, A} <:
-Lux.AbstractExplicitContainerLayer{(:weight_generator, :evaluator)}
+struct HyperNet{W <: AbstractLuxLayer, C <: AbstractLuxLayer, A} <:
+Lux.AbstractLuxContainerLayer{(:weight_generator, :evaluator)}
weight_generator::W
evaluator::C
ca_axes::A
end

function HyperNet(
-weight_gen::Lux.AbstractExplicitLayer,
-evaluator::Lux.AbstractExplicitLayer
+weight_gen::AbstractLuxLayer,
+evaluator::AbstractLuxLayer
)
rng = Random.default_rng()
ca_axes = Lux.initialparameters(rng, evaluator) |> ComponentArray |> getaxes
@@ -77,7 +77,7 @@ SplitRows
Split rows of ND array, into `Tuple` of ND arrays.
"""
-struct SplitRows{T} <: Lux.AbstractExplicitLayer
+struct SplitRows{T} <: AbstractLuxLayer
splits::T
channel_dim::Int
end
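
The container supertype is renamed the same way: `Lux.AbstractExplicitContainerLayer{(...)}` becomes `AbstractLuxContainerLayer{(...)}`, with parameters and states still nested under the declared field names. A minimal sketch under that assumption; the `Pair2` layer is hypothetical and stands in for containers like `HyperNet` above:

```julia
using Lux, LuxCore, Random

# Hypothetical two-stage container, for illustration only.
struct Pair2{A, B} <: LuxCore.AbstractLuxContainerLayer{(:first, :second)}
    first::A
    second::B
end

# Forward through both sublayers, threading their parameters and states by name.
function (l::Pair2)(x, ps, st)
    y1, st1 = l.first(x, ps.first, st.first)
    y2, st2 = l.second(y1, ps.second, st.second)
    return y2, (; first = st1, second = st2)
end

rng = Random.default_rng()
model = Pair2(Dense(2 => 8, tanh), Dense(8 => 1))
ps, st = Lux.setup(rng, model)   # ps/st are NamedTuples keyed by (:first, :second)
y, _ = model(rand(Float32, 2, 4), ps, st)
```
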
24 changes: 12 additions & 12 deletions src/layers/encoder_decoder.jl
@@ -38,8 +38,8 @@ trajectory.
"""
function ImplicitEncoderDecoder(
-encoder::Lux.AbstractExplicitLayer,
-decoder::Lux.AbstractExplicitLayer,
+encoder::AbstractLuxLayer,
+decoder::AbstractLuxLayer,
Npoints::NTuple{D, Integer},
out_dim::Integer,
) where{D}
@@ -67,7 +67,7 @@ function ImplicitEncoderDecoder(
)
end

-function get_encoder_decoder(NN::Lux.AbstractExplicitLayer, p, st)
+function get_encoder_decoder(NN::AbstractLuxLayer, p, st)
encoder = (NN.layers.encode.layers.encoder, p.encode.encoder, st.encode.encoder)
decoder = (NN.layers.decoder, p.decoder, st.decoder)

@@ -83,7 +83,7 @@ end
Assumes input is `(xyz, idx)` of sizes `[in_dim, K]`, `[1, K]` respectively
"""
function AutoDecoder(
-decoder::Lux.AbstractExplicitLayer,
+decoder::AbstractLuxLayer,
num_batches::Int,
code_len::Int;
init_weight = randn32, # scale_init(randn32, 1f-1, 0f0) # N(μ = 0, σ2 = 0.1^2)
@@ -104,7 +104,7 @@ end
end

function get_autodecoder(
-NN::Lux.AbstractExplicitLayer,
+NN::AbstractLuxLayer,
p::Union{NamedTuple, AbstractArray},
st::NamedTuple,
)
@@ -124,8 +124,8 @@ Input: `(x, param)` of sizes `[x_dim, K]`, and `[p_dim, K]` respectively.
Output: solution field `u` of size `[out_dim, K]`.
"""
function FlatDecoder(
-hyper::Lux.AbstractExplicitLayer,
-decoder::Lux.AbstractExplicitLayer,
+hyper::AbstractLuxLayer,
+decoder::AbstractLuxLayer,
)
noop = NoOpLayer()

@@ -136,7 +136,7 @@ function FlatDecoder(
end

function get_flatdecoder(
-NN::Lux.AbstractExplicitLayer,
+NN::AbstractLuxLayer,
p::Union{NamedTuple, AbstractArray},
st::NamedTuple,
)
@@ -150,7 +150,7 @@ end
# OneEmbedding, freeze_decoder
#======================================================#

-struct OneEmbedding{F} <: Lux.AbstractExplicitLayer
+struct OneEmbedding{F} <: AbstractLuxLayer
len::Int
init::F
end
@@ -206,8 +206,8 @@ end
Assumes input is `(xyz, idx)` of sizes `[D, K]`, `[1, K]` respectively
"""
function HyperDecoder(
-weight_gen::Lux.AbstractExplicitLayer,
-evaluator::Lux.AbstractExplicitLayer,
+weight_gen::AbstractLuxLayer,
+evaluator::AbstractLuxLayer,
num_batches::Int,
code_len::Int;
init_weight = randn32,
@@ -228,7 +228,7 @@ function HyperDecoder(
HyperNet(code_gen, evaluator)
end

-function get_hyperdecoder(NN::Lux.AbstractExplicitLayer, p, st)
+function get_hyperdecoder(NN::AbstractLuxLayer, p, st)
NN.weight_generator.code, p.weight_generator.code, st.weight_generator.code
end
#======================================================#
4 changes: 2 additions & 2 deletions src/layers/experimental.jl
@@ -10,7 +10,7 @@ x -> sin(π⋅x/L)
Works when input is symmetric around 0, i.e., x ∈ [-1, 1).
If working with something like [0, 1], use cosines instead.
"""
-@concrete struct PeriodicLayer <: Lux.AbstractExplicitLayer
+@concrete struct PeriodicLayer <: AbstractLuxLayer
idxs
periods
end
@@ -49,7 +49,7 @@ end

export TanhKernel1D

-@concrete struct TanhKernel1D{I<:Integer} <: Lux.AbstractExplicitLayer
+@concrete struct TanhKernel1D{I<:Integer} <: AbstractLuxLayer
in_dim::I
out_dim::I
num_kernels::I
2 changes: 1 addition & 1 deletion src/layers/gaussian.jl
@@ -39,7 +39,7 @@ export Gaussian1D
# of the Gaussian. E.g. https://en.wikipedia.org/wiki/Gabor_filter
# [Gabor Splatting for High-Quality Gigapixel Image Representations]

-@concrete struct Gaussian1D{I<:Integer} <: Lux.AbstractExplicitLayer
+@concrete struct Gaussian1D{I<:Integer} <: AbstractLuxLayer
in_dim::I
out_dim::I
num_gauss::I
6 changes: 3 additions & 3 deletions src/layers/mfn.jl
@@ -6,7 +6,7 @@

export FourierMFN, GaborMFN

-@concrete struct MFN{I} <: Lux.AbstractExplicitContainerLayer{(:filters, :linears)}
+@concrete struct MFN{I} <: AbstractLuxContainerLayer{(:filters, :linears)}
in_dim::I
hd_dim::I
out_dim::I
@@ -117,8 +117,8 @@ end

export GaborLayer

-# @concrete struct GaborLayer{I} <: Lux.AbstractExplicitContainerLayer{(:dense,)}
-@concrete struct GaborLayer{I} <: Lux.AbstractExplicitLayer
+# @concrete struct GaborLayer{I} <: AbstractLuxContainerLayer{(:dense,)}
+@concrete struct GaborLayer{I} <: AbstractLuxLayer
in_dim::I
out_dim::I

10 changes: 5 additions & 5 deletions src/layers/sdf.jl
@@ -5,10 +5,10 @@ clamp_vanilla(x::AbstractArray , δ) = @. clamp(x, -δ, δ)
clamp_sigmoid(x::AbstractArray , δ) = @. δ * (2 * sigmoid_fast(x) - 1)
clamp_softsign(x::AbstractArray, δ) = @. δ * softsign(x)

-struct ClampTanh{T <: Real} <: Lux.AbstractExplicitLayer; δ::T; end
-struct ClampVanilla{T <: Real} <: Lux.AbstractExplicitLayer; δ::T; end
-struct ClampSigmoid{T <: Real} <: Lux.AbstractExplicitLayer; δ::T; end
-struct ClampSoftsign{T <: Real} <: Lux.AbstractExplicitLayer; δ::T; end
+struct ClampTanh{T <: Real} <: AbstractLuxLayer; δ::T; end
+struct ClampVanilla{T <: Real} <: AbstractLuxLayer; δ::T; end
+struct ClampSigmoid{T <: Real} <: AbstractLuxLayer; δ::T; end
+struct ClampSoftsign{T <: Real} <: AbstractLuxLayer; δ::T; end

Lux.initialstates(::ClampTanh ) = (;)
Lux.initialstates(::ClampVanilla ) = (;)
@@ -173,7 +173,7 @@ end

export SpatialHash, SpatialGrid

-@concrete struct FeatureGrid{D} <: Lux.AbstractExplicitContainerLayer{(:embedding,)}
+@concrete struct FeatureGrid{D} <: AbstractLuxContainerLayer{(:embedding,)}
shape
indexfun
interpfun
4 changes: 2 additions & 2 deletions src/neuralgridmodel.jl
@@ -232,7 +232,7 @@ export INRModel
end

function INRModel(
-NN::Lux.AbstractExplicitLayer,
+NN::AbstractLuxLayer,
st::NamedTuple,
x::AbstractArray{T},
grid::NTuple{D, Integer}, # (Nx, Ny)
@@ -290,7 +290,7 @@ export CAEModel
end

function CAEModel(
-NN::Lux.AbstractExplicitLayer,
+NN::AbstractLuxLayer,
p::Union{NamedTuple, ComponentVector},
st::NamedTuple,
x::AbstractArray{T},
2 changes: 1 addition & 1 deletion src/neuralmodel.jl
@@ -12,7 +12,7 @@
end

function NeuralModel(
-NN::Lux.AbstractExplicitLayer,
+NN::AbstractLuxLayer,
st::NamedTuple,
metadata::NamedTuple,
)
4 changes: 2 additions & 2 deletions src/nonlinleastsq.jl
@@ -54,7 +54,7 @@ function nonlinleastsq(
end
#======================================================#
function nonlinleastsq(
-NN::Lux.AbstractExplicitLayer,
+NN::AbstractLuxLayer,
p0::Union{NamedTuple, AbstractVector},
st::NamedTuple,
data::NTuple{2, Any},
@@ -171,7 +171,7 @@ end
#======================================================#

function nonlinleastsq(
-NN::Lux.AbstractExplicitLayer,
+NN::AbstractLuxLayer,
p0::Union{NamedTuple, AbstractVector},
st::NamedTuple,
data::NTuple{2, Any},
4 changes: 2 additions & 2 deletions src/operator/oplayers.jl
@@ -92,7 +92,7 @@ Neural Operator convolution layer
so that eltype(params) is always real
"""
-struct OpConv{D, F, I} <: Lux.AbstractExplicitLayer
+struct OpConv{D, F, I} <: AbstractLuxLayer
ch_in::Int
ch_out::Int
modes::NTuple{D, Int}
@@ -163,7 +163,7 @@ end
Neural Operator bilinear convolution layer
"""
-struct OpConvBilinear{D, F, I} <: Lux.AbstractExplicitLayer
+struct OpConvBilinear{D, F, I} <: AbstractLuxLayer
ch_in1::Int
ch_in2::Int
ch_out::Int
2 changes: 1 addition & 1 deletion src/operator/transform.jl
@@ -1,5 +1,5 @@
#
-# TODO: subtype AbstractTransform <: Lux.AbstractExplicitLayer and make
+# TODO: subtype AbstractTransform <: AbstractLuxLayer and make
# TODO: OpConv(Bilinear) a hyper network/ Lux Container layer.
# TODO: then we can think about trainable transform types
abstract type AbstractTransform{D} end
2 changes: 1 addition & 1 deletion src/utils.jl
@@ -91,7 +91,7 @@ end
#===========================================================#

function remake_ca_in_model(
-NN::Lux.AbstractExplicitLayer,
+NN::AbstractLuxLayer,
p::Union{NamedTuple, AbstractArray},
st::NamedTuple,
)
2 changes: 1 addition & 1 deletion src/vis.jl
@@ -193,7 +193,7 @@ $SIGNATURES
function plot_1D_surrogate_steady(V::Spaces.AbstractSpace{<:Any, 1},
_data::NTuple{2, Any},
data_::NTuple{2, Any},
-NN::Lux.AbstractExplicitLayer,
+NN::AbstractLuxLayer,
p,
st;
nsamples = 5,
