add conditional mnist example
Orozco committed May 10, 2024
1 parent 0433143 commit 1e5016f
Showing 6 changed files with 216 additions and 40 deletions.
6 changes: 6 additions & 0 deletions CITATION.bib
@@ -0,0 +1,6 @@
@article{orozco2023invertiblenetworks,
title={InvertibleNetworks.jl: A Julia package for scalable normalizing flows},
author={Orozco, Rafael and Witte, Philipp and Louboutin, Mathias and Siahkoohi, Ali and Rizzuti, Gabrio and Peters, Bas and Herrmann, Felix J},
journal={arXiv preprint arXiv:2312.13480},
year={2023}
}
86 changes: 47 additions & 39 deletions README.md
@@ -16,45 +16,20 @@ Building blocks for invertible neural networks in the [Julia] programming language

## Installation


InvertibleNetworks is registered and can be added like any standard Julia package with the command:

```
] add InvertibleNetworks
```


## Papers

The following publications use [InvertibleNetworks.jl]:

- **["Reliable amortized variational inference with physics-based latent distribution correction"]**
- paper: [https://arxiv.org/abs/2207.11640](https://arxiv.org/abs/2207.11640)
- [presentation](https://slim.gatech.edu/Publications/Public/Submitted/2022/siahkoohi2022ravi/slides.pdf)
- code: [ReliableAVI.jl]

- **["Learning by example: fast reliability-aware seismic imaging with normalizing flows"]**
- paper: [https://arxiv.org/abs/2104.06255](https://arxiv.org/abs/2104.06255)
- [presentation](https://slim.gatech.edu/Publications/Public/Conferences/KAUST/2021/siahkoohi2021EarthMLfar/siahkoohi2021EarthMLfar.pdf)
- code: [ReliabilityAwareImaging.jl]
## Uncertainty-aware image reconstruction

- **["Enabling uncertainty quantification for seismic data pre-processing using normalizing flows (NF)—an interpolation example"]**
- [paper](https://slim.gatech.edu/Publications/Public/Conferences/SEG/2021/kumar2021SEGeuq/kumar2021SEGeuq.pdf)
- code: [WavefieldRecoveryUQ.jl]

- **["Preconditioned training of normalizing flows for variational inference in inverse problems"]**
- paper: [https://arxiv.org/abs/2101.03709](https://arxiv.org/abs/2101.03709)
- [presentation](https://slim.gatech.edu/Publications/Public/Conferences/AABI/2021/siahkoohi2021AABIpto/siahkoohi2021AABIpto_pres.pdf)
- code: [FastApproximateInference.jl]
Thanks to its favorable memory scaling, InvertibleNetworks.jl has been particularly successful at Bayesian posterior sampling with simulation-based inference. To get started with this application, see the simple example ([Conditional sampling for MNIST inpainting](https://github.com/slimgroup/InvertibleNetworks.jl/tree/master/examples/applications/application_conditional_mnist_inpainting.jl)), adapt it to your own application, and please reach out to us if you run into any trouble.
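
A minimal sketch of that workflow, condensed from the example script added in this commit (the random data, sizes, and variable names below are purely illustrative — substitute your own `X`/`Y` pairs):

```
using InvertibleNetworks

# Illustrative sizes: 16x16 single-channel images, as in the MNIST example
nx, ny, n_batch = 16, 16, 8
X = randn(Float32, nx, ny, 1, n_batch)  # targets x
Y = randn(Float32, nx, ny, 1, n_batch)  # observations y

# Conditional Glow: (channels of x, channels of y, hidden channels, levels L, flow steps K)
G = NetworkConditionalGlow(1, 1, 32, 2, 10; split_scales=true)

# Forward pass returns the latents and the log-determinant used in the likelihood objective
ZX, ZY, logdet = G.forward(X, Y)

# After training: fix the conditioning latent for one observation, then invert
# Gaussian noise to draw posterior samples x ~ p(x | y)
_, Zy_fixed, _ = G.forward(X, Y)
X_post = G.inverse(randn(Float32, nx, ny, 1, n_batch), Zy_fixed)
```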

- **["Parameterizing uncertainty by deep invertible networks, an application to reservoir characterization"]**
- paper: [https://arxiv.org/abs/2004.07871](https://arxiv.org/abs/2004.07871)
- [presentation](https://slim.gatech.edu/Publications/Public/Conferences/SEG/2020/rizzuti2020SEGuqavp/rizzuti2020SEGuqavp_pres.pdf)
- code: [https://github.com/slimgroup/Software.SEG2020](https://github.com/slimgroup/Software.SEG2020)
![mnist_sampling_cond](docs/src/figures/mnist_sampling_cond.png)

- **["Generalized Minkowski sets for the regularization of inverse problems"]**
- paper: [http://arxiv.org/abs/1903.03942](http://arxiv.org/abs/1903.03942)
- code: [SetIntersectionProjection.jl]

## Building blocks

@@ -112,37 +87,70 @@ AN = ActNorm(k; logdet=true) |> gpu
Y_, logdet = AN.forward(X)
```
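
Because every building block is invertible, the input can be recovered from the output. A minimal sketch, assuming the variables from the snippet above and the layer's `inverse` method:

```
# Recover the input from the forward output; X_ matches X up to floating-point error
X_ = AN.inverse(Y_)
```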

## Acknowledgments
## Reference

This package uses functions from [NNlib.jl](https://github.com/FluxML/NNlib.jl), [Flux.jl](https://github.com/FluxML/Flux.jl) and [Wavelets.jl](https://github.com/JuliaDSP/Wavelets.jl)
If you use InvertibleNetworks.jl in your research, we would be grateful if you cited it using the following BibTeX entry:

```
@article{orozco2023invertiblenetworks,
title={InvertibleNetworks.jl: A Julia package for scalable normalizing flows},
author={Orozco, Rafael and Witte, Philipp and Louboutin, Mathias and Siahkoohi, Ali and Rizzuti, Gabrio and Peters, Bas and Herrmann, Felix J},
journal={arXiv preprint arXiv:2312.13480},
year={2023}
}
```


## References
## Papers

The following publications use [InvertibleNetworks.jl]:

- **["Reliable amortized variational inference with physics-based latent distribution correction"]**
- paper: [https://arxiv.org/abs/2207.11640](https://arxiv.org/abs/2207.11640)
- [presentation](https://slim.gatech.edu/Publications/Public/Submitted/2022/siahkoohi2022ravi/slides.pdf)
- code: [ReliableAVI.jl]

- Yann Dauphin, Angela Fan, Michael Auli and David Grangier, "Language modeling with gated convolutional networks", Proceedings of the 34th International Conference on Machine Learning, 2017. https://arxiv.org/pdf/1612.08083.pdf
- **["Learning by example: fast reliability-aware seismic imaging with normalizing flows"]**
- paper: [https://arxiv.org/abs/2104.06255](https://arxiv.org/abs/2104.06255)
- [presentation](https://slim.gatech.edu/Publications/Public/Conferences/KAUST/2021/siahkoohi2021EarthMLfar/siahkoohi2021EarthMLfar.pdf)
- code: [ReliabilityAwareImaging.jl]

- Laurent Dinh, Jascha Sohl-Dickstein and Samy Bengio, "Density estimation using Real NVP", International Conference on Learning Representations, 2017, https://arxiv.org/abs/1605.08803
- **["Enabling uncertainty quantification for seismic data pre-processing using normalizing flows (NF)—an interpolation example"]**
- [paper](https://slim.gatech.edu/Publications/Public/Conferences/SEG/2021/kumar2021SEGeuq/kumar2021SEGeuq.pdf)
- code: [WavefieldRecoveryUQ.jl]

- Diederik P. Kingma and Prafulla Dhariwal, "Glow: Generative Flow with Invertible 1x1 Convolutions", Conference on Neural Information Processing Systems, 2018. https://arxiv.org/abs/1807.03039
- **["Preconditioned training of normalizing flows for variational inference in inverse problems"]**
- paper: [https://arxiv.org/abs/2101.03709](https://arxiv.org/abs/2101.03709)
- [presentation](https://slim.gatech.edu/Publications/Public/Conferences/AABI/2021/siahkoohi2021AABIpto/siahkoohi2021AABIpto_pres.pdf)
- code: [FastApproximateInference.jl]

- Keegan Lensink, Eldad Haber and Bas Peters, "Fully Hyperbolic Convolutional Neural Networks", arXiv Computer Vision and Pattern Recognition, 2019. https://arxiv.org/abs/1905.10484
- **["Parameterizing uncertainty by deep invertible networks, an application to reservoir characterization"]**
- paper: [https://arxiv.org/abs/2004.07871](https://arxiv.org/abs/2004.07871)
- [presentation](https://slim.gatech.edu/Publications/Public/Conferences/SEG/2020/rizzuti2020SEGuqavp/rizzuti2020SEGuqavp_pres.pdf)
- code: [https://github.com/slimgroup/Software.SEG2020](https://github.com/slimgroup/Software.SEG2020)

- Patrick Putzky and Max Welling, "Invert to learn to invert", Advances in Neural Information Processing Systems, 2019. https://arxiv.org/abs/1911.10914
- **["Generalized Minkowski sets for the regularization of inverse problems"]**
- paper: [http://arxiv.org/abs/1903.03942](http://arxiv.org/abs/1903.03942)
- code: [SetIntersectionProjection.jl]

- Jakob Kruse, Gianluca Detommaso, Robert Scheichl and Ullrich Köthe, "HINT: Hierarchical Invertible Neural Transport for Density Estimation and Bayesian Inference", arXiv Statistics and Machine Learning, 2020. https://arxiv.org/abs/1905.10687

## Authors

- Rafael Orozco, Georgia Institute of Technology [rorozco@gatech.edu]

- Philipp Witte, Georgia Institute of Technology (now Microsoft)

- Gabrio Rizzuti, Utrecht University

- Mathias Louboutin, Georgia Institute of Technology

- Ali Siahkoohi, Georgia Institute of Technology


## Acknowledgments

This package uses functions from [NNlib.jl](https://github.com/FluxML/NNlib.jl), [Flux.jl](https://github.com/FluxML/Flux.jl) and [Wavelets.jl](https://github.com/JuliaDSP/Wavelets.jl)

[Flux]:https://fluxml.ai
[Julia]:https://julialang.org
[Zygote]:https://github.com/FluxML/Zygote.jl
162 changes: 162 additions & 0 deletions examples/applications/application_conditional_mnist_inpainting.jl
@@ -0,0 +1,162 @@
using Pkg
Pkg.activate(".")

# Takes around 6 minutes to run on CPU
using InvertibleNetworks
using Flux
using LinearAlgebra
using MLDatasets
using Statistics
using PyPlot
using ProgressMeter: Progress, next!
using Images
using MLUtils

function posterior_sampler(G, y, size_x; device=gpu, num_samples=1, batch_size=16)
# Draw samples from the posterior p(x|y) for a fixed observation y
X_dummy = randn(Float32, size_x[1:end-1]..., batch_size) |> device
Y_repeat = repeat(y |> cpu, 1, 1, 1, batch_size) |> device
_, Zy_fixed, _ = G.forward(X_dummy, Y_repeat); # one forward pass to set the proper latent sizes for Zy_fixed

X_post = zeros(Float32, size_x[1:end-1]...,num_samples)
for i in 1:div(num_samples, batch_size)
Zx_noise_i = randn(Float32, size_x[1:end-1]...,batch_size)|> device
X_post[:,:,:, (i-1)*batch_size+1 : i*batch_size] = G.inverse(
Zx_noise_i,
Zy_fixed
) |> cpu;
end
X_post
end
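
# Illustrative call (assuming a trained conditional network G and an observation y
# whose spatial size matches size_x):
#   X_post = posterior_sampler(G, y, size(x); num_samples=64, batch_size=16)
# Note: the loop above assumes num_samples is a multiple of batch_size.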

# Training hyperparameters
device = cpu # GPU does not accelerate training at this small size; CPU is quicker
lr = 2f-3
epochs = 30
batch_size = 128

# Load in training data
n_total = 2048
validation_perc = 0.9
X, _ = MNIST(split=:train)[1:(n_total)];

# Resize spatial size to a power of two to make Real-NVP multiscale easier.
nx = 16; ny = 16;
N = nx*ny
Xs = zeros(Float32, nx, ny, 1, n_total);
for i in 1:n_total
Xs[:,:,:,i] = imresize(X[:,:,i]', (nx, ny));
end

# Make Forward operator A
mask_size = 3 # controls the size of the central block of pixels to zero out
mask_start = div((nx-mask_size),2)
A = ones(Float32,nx,ny)
A[mask_start:(end-mask_start),mask_start:(end-mask_start)] .= 0f0

# Make observations y
Ys = A .* Xs;

# To adapt this to your own application, load your own Ys and Xs here.
# julia> size(Ys)
# (16, 16, 1, 2048) (nx,ny,n_chan,n_batch)

# julia> size(Xs)
# (16, 16, 1, 2048) (nx,ny,n_chan,n_batch)

# Use MLUtils to split into training and validation/test sets
XY_train, XY_val = splitobs((Xs, Ys); at=validation_perc, shuffle=true);
train_loader = DataLoader(XY_train, batchsize=batch_size, shuffle=true, partial=false);

# Number of training/validation observations and training batches
n_train = numobs(XY_train)
n_val = numobs(XY_val)
batches = cld(n_train, batch_size)
progress = Progress(epochs*batches);

# Architecture parameters
chan_x = 1 # not RGB so chan=1
chan_y = 1 # not RGB so chan=1
L = 2 # Number of multiscale levels
K = 10 # Number of Real-NVP layers per multiscale level
n_hidden = 32 # Number of hidden channels in convolutional residual blocks

# Create network
G = NetworkConditionalGlow(chan_x, chan_y, n_hidden, L, K; split_scales=true ) |> device;

# Optimizer
opt = ADAM(lr)

# Training logs
loss_train = []; loss_val = [];

for e=1:epochs # epoch loop
for (X, Y) in train_loader #batch loop
ZX, ZY, logdet_i = G.forward(X|> device, Y|> device);
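# The backward pass seeds the latent gradient ZX / batch_size (the gradient of the
# Gaussian negative log-likelihood term) and backpropagates it through G to
# populate the parameter gradients used in the update below.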
G.backward(ZX / batch_size, ZX, ZY)

for p in get_params(G)
Flux.update!(opt, p.data, p.grad)
end; clear_grad!(G) # clear gradients unless you need to accumulate

#Progress meter
append!(loss_train, norm(ZX)^2 / (N*batch_size) - logdet_i / N) # normalize by image size and batch size
next!(progress; showvalues=[(:objective, loss_train[end]),(:l2norm, norm(ZX)^2 / (N*batch_size))])
end
# Evaluate network on validation set
X = getobs(XY_val[1]) |> device;
Y = getobs(XY_val[2]) |> device;

ZX, ZY,logdet_i = G.forward(X, Y);
append!(loss_val, norm(ZX)^2 / (N*n_val) - logdet_i / N) # normalize by image size and batch size
end

# Final objective values for the figure title
final_obj_train = round(loss_train[end];digits=3)
final_obj_val = round(loss_val[end];digits=3)

fig = figure()
title("Objective value: train=$(final_obj_train) validation=$(final_obj_val)")
plot(loss_train;label="Train");
plot(batches:batches:batches*(epochs), loss_val;label="Validation");
xlabel("Parameter update"); ylabel("Negative log likelihood objective") ;
legend()
savefig("log.png",dpi=300)

# Make the figure used in the README
num_plot = 2
fig = figure(figsize=(11, 5));
for (i,ind) in enumerate([1,3])
x = XY_val[1][:,:,:,ind:ind]
y = XY_val[2][:,:,:,ind:ind]
X_post = posterior_sampler(G, y, size(x); device=device, num_samples=64) |> cpu

X_post_mean = mean(X_post; dims=ndims(X_post))
X_post_var = var(X_post; dims=ndims(X_post))
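
# Conditional mean and pointwise variance are Monte-Carlo estimates over the 64
# posterior samples; the variance is plotted below as the uncertainty (UQ) map.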

ssim_val = round(assess_ssim(X_post_mean[:,:,1,1], x[:,:,1,1]) ,digits=2)

subplot(num_plot,7,1+7*(i-1)); imshow(x[:,:,1,1], vmin=0, vmax=1, cmap="gray")
axis("off"); title(L"$x$");

subplot(num_plot,7,2+7*(i-1)); imshow(y[:,:,1,1] |> cpu, cmap="gray")
axis("off"); title(L"$y$");

subplot(num_plot,7,3+7*(i-1)); imshow(X_post_mean[:,:,1,1] , vmin=0, vmax=1, cmap="gray")
axis("off"); title("SSIM="*string(ssim_val)*" \n"*"Conditional Mean") ;

subplot(num_plot,7,4+7*(i-1)); imshow(X_post_var[:,:,1,1] , cmap="magma")
axis("off"); title(L"$UQ$") ;

subplot(num_plot,7,5+7*(i-1)); imshow(X_post[:,:,1,1] |> cpu, vmin=0, vmax=1,cmap="gray")
axis("off"); title("Posterior Sample") ;

subplot(num_plot,7,6+7*(i-1)); imshow(X_post[:,:,1,2] |> cpu, vmin=0, vmax=1,cmap="gray")
axis("off"); title("Posterior Sample") ;

subplot(num_plot,7,7+7*(i-1)); imshow(X_post[:,:,1,3] |> cpu, vmin=0, vmax=1, cmap="gray")
axis("off"); title("Posterior Sample") ;
end
tight_layout()
savefig("mnist_sampling_cond.png",dpi=300,bbox_inches="tight")

2 changes: 1 addition & 1 deletion src/utils/dimensionality_operations.jl
@@ -477,7 +477,7 @@ end

# Split and reshape 1D vector Y in latent space back to states Zi
# where Zi is the split tensor at each multiscale level.
function split_states(Y::AbstractVector{T}, Z_dims; L_net=2) where {T, N}
function split_states(Y::AbstractVector{T}, Z_dims; L_net=2) where {T}
L = length(Z_dims) + 1
inds = cumsum([1, [prod(Z_dims[j]) for j=1:L-1]...])
Z_save = [reshape(Y[inds[j]:inds[j+1]-1], xy_dims(Z_dims[j], Val(j==L), Val(length(Z_dims[j])))) for j=1:L-1]
