-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
154 additions
and
304 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
.vscode/ | ||
|
||
temp/ | ||
# Files generated by invoking Julia with --code-coverage | ||
*.jl.cov | ||
*.jl.*.cov | ||
|
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,13 @@ | ||
name = "ArrayPadding" | ||
uuid = "fe5ef43c-4912-423d-9170-4bd5cb6a013a" | ||
authors = ["Paul Shen <pxshen@alumni.stanford.edu>"] | ||
version = "1.0.1" | ||
version = "1.0.2" | ||
|
||
[deps] | ||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" | ||
|
||
[compat] | ||
Statistics = "1" | ||
Zygote = "0.6,0.7,1" | ||
julia = "1" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,175 +1,14 @@ | ||
# ArrayPadding.jl | ||
|
||
Pads arrays of any dimension with various border options including constants, periodic, symmetric, mirror and smooth. Can control amount of padding applied to the left and right side of each dimension. Fully differentiable (compatible with `Zygote.jl` `Flux.jl`) | ||
Pads arrays of any dimension with various border options including constants, periodic, symmetric, mirror and smooth. Can control amount of padding applied to the left and right side of each dimension. Fully differentiable and GPU compatible (plays well with `Zygote.jl` `Flux.jl` `CUDA.jl` ) | ||
|
||
`pad` eagerly allocates new bigger array while `pad!` mutates original array in place. For `pad!,` the effective border is set back by the padding amount when evaluating various border options. `pad` is AD (automatic differentiation) compatible while `pad!` requires `Zygote.Buffer` for AD (see usage ) | ||
`pad` eagerly allocates new bigger array while `pad!` mutates original array in place. `pad!` assumes argument to be the preallocated padded array so the effective border is set back by the padding amount when evaluating various border options. `pad` is AD (automatic differentiation) compatible while `pad!` requires `Zygote.Buffer` for AD (see usage ) | ||
|
||
```julia | ||
pad(array, border, pad_amount) | ||
pad(array, border, left_pad_amount, right_pad_amount) | ||
pad(array, border, pad_amounts_for_dims) | ||
pad(array, border, left_pad_amount_for_dims, right_pad_amount_for_dims) | ||
See `test/runtests.jl` for examples of usage. | ||
|
||
pad!(array, border, pad_amount) | ||
pad!(array, border, left_pad_amount, right_pad_amount) | ||
pad!(array, border, pad_amounts_for_dims) | ||
pad!(array, border, left_pad_amount_for_dims, right_pad_amount_for_dims) | ||
``` | ||
Consider starring or sponsoring us if you found this repo helpful. Feel free to request features or contribute PRs :) | ||
|
||
## Border options | ||
- any constant value `v`: a b c | v v | ||
- `:periodic`: a b c | a b | ||
- `:replicate`: a b c | c c | ||
- `:symmetric`: a b c | c b | ||
- `:mirror`: a b c | b a | ||
- `:smooth`: a b c | 2c-b (Maintains C1 continuity) | ||
|
||
## Usage | ||
```julia | ||
a = collect(reshape(1:16, 4, 4)) | ||
|
||
@test pad(a, -1, 1) == [ | ||
-1 -1 -1 -1 -1 -1 | ||
-1 1 5 9 13 -1 | ||
-1 2 6 10 14 -1 | ||
-1 3 7 11 15 -1 | ||
-1 4 8 12 16 -1 | ||
-1 -1 -1 -1 -1 -1 | ||
] | ||
|
||
@test pad!(copy(a), -1, 1) == [ | ||
-1 -1 -1 -1 | ||
-1 6 10 -1 | ||
-1 7 11 -1 | ||
-1 -1 -1 -1 | ||
] | ||
|
||
@test pad(a, -1, 1, 0) == [ | ||
-1 -1 -1 -1 -1 | ||
-1 1 5 9 13 | ||
-1 2 6 10 14 | ||
-1 3 7 11 15 | ||
-1 4 8 12 16 | ||
] | ||
|
||
@test pad!(copy(a), -1, 1, 0) == [ | ||
-1 -1 -1 -1 | ||
-1 6 10 14 | ||
-1 7 11 15 | ||
-1 8 12 16 | ||
] | ||
|
||
@test pad(a, -1, (0, 1), (1, 0)) == [ | ||
-1 1 5 9 13 | ||
-1 2 6 10 14 | ||
-1 3 7 11 15 | ||
-1 4 8 12 16 | ||
-1 -1 -1 -1 -1 | ||
] | ||
|
||
@test pad!(copy(a), -1, (0, 1), (1, 0)) == [ | ||
-1 5 9 13 | ||
-1 6 10 14 | ||
-1 7 11 15 | ||
-1 -1 -1 -1 | ||
] | ||
|
||
@test pad(a, :periodic, (1, 1), (0, 0)) == [ | ||
16 4 8 12 16 | ||
13 1 5 9 13 | ||
14 2 6 10 14 | ||
15 3 7 11 15 | ||
16 4 8 12 16 | ||
] | ||
|
||
@test pad!(copy(a), :periodic, (1, 1), (0, 0)) == [ | ||
16 8 12 16 | ||
14 6 10 14 | ||
15 7 11 15 | ||
16 8 12 16 | ||
] | ||
|
||
@test pad(a, :symmetric, 1) == [ | ||
1 1 5 9 13 13 | ||
1 1 5 9 13 13 | ||
2 2 6 10 14 14 | ||
3 3 7 11 15 15 | ||
4 4 8 12 16 16 | ||
4 4 8 12 16 16 | ||
] | ||
|
||
@test pad!(copy(a), :symmetric, 1) == [ | ||
6 6 10 10 | ||
6 6 10 10 | ||
7 7 11 11 | ||
7 7 11 11 | ||
] | ||
|
||
@test pad(a, :mirror, 1) == [ | ||
6 2 6 10 14 10 | ||
5 1 5 9 13 9 | ||
6 2 6 10 14 10 | ||
7 3 7 11 15 11 | ||
8 4 8 12 16 12 | ||
7 3 7 11 15 11 | ||
] | ||
|
||
@test pad!(copy(a), :mirror, 1) == [ | ||
11 7 11 7 | ||
10 6 10 6 | ||
11 7 11 7 | ||
10 6 10 6 | ||
] | ||
|
||
@test pad(a, :replicate, 1) == [ | ||
1 1 5 9 13 13 | ||
1 1 5 9 13 13 | ||
2 2 6 10 14 14 | ||
3 3 7 11 15 15 | ||
4 4 8 12 16 16 | ||
4 4 8 12 16 16 | ||
] | ||
|
||
@test pad!(copy(a), :replicate, 1) == [ | ||
6 6 10 10 | ||
6 6 10 10 | ||
7 7 11 11 | ||
7 7 11 11 | ||
] | ||
|
||
@test pad(a, :smooth, 1) == [ | ||
-4 0 4 8 12 16 | ||
-3 1 5 9 13 17 | ||
-2 2 6 10 14 18 | ||
-1 3 7 11 15 19 | ||
0 4 8 12 16 20 | ||
1 5 9 13 17 21 | ||
] | ||
|
||
@test pad!(copy(a), :smooth, 1) == [ | ||
1 5 9 13 | ||
2 6 10 14 | ||
3 7 11 15 | ||
4 8 12 16 | ||
] | ||
|
||
using Zygote | ||
using Zygote: Buffer | ||
|
||
@test withgradient(a) do a | ||
a = pad(a, 0, 1) | ||
sum(a) | ||
end == (val=136, grad=([1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0],)) | ||
|
||
@test withgradient(a) do a | ||
a_ = Buffer(a) | ||
a_ .= a | ||
pad!(a_, 0, 1) | ||
a = copy(a_) | ||
sum(a) | ||
end == (val=34, grad=([0.0 0.0 0.0 0.0; 0.0 1.0 1.0 0.0; 0.0 1.0 1.0 0.0; 0.0 0.0 0.0 0.0],)) | ||
``` | ||
## Contributing | ||
Consider sponsoring this on Github if you found this repo helpful. Feel free to request features or contribute PRs :) | ||
## Contributors | ||
Paul Shen <pxshen@alumni.stanford.edu> | ||
Paul Shen | ||
<pxshen@alumni.stanford.edu> | ||
Luminescent AI |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,80 +1,87 @@ | ||
function lr(a::T, b, i, l, r, out=true, ol=0, or=ol) where {T} | ||
d = ndims(a) | ||
sel = (1:d) .== i | ||
function lr(a::S, v, i, l, r) where {S} | ||
N = ndims(a) | ||
T = eltype(a) | ||
ax = axes(a, i) | ||
sel = (1:N) .== i | ||
al = ar = nothing | ||
if l > 0 | ||
o = out ? 0 : l + ol | ||
if b == :periodic | ||
I = [j == i ? ((size(a, i)-l+1):size(a, i)) : (:) for j = 1:d] | ||
o = l | ||
if v == :periodic | ||
I = ifelse.(sel, (ax[end-l+1:end],), :) | ||
al = a[I...] | ||
elseif b == :symmetric | ||
I = [j == i ? (l+o:-1:1+o) : (:) for j = 1:d] | ||
elseif v == :symmetric | ||
I = ifelse.(sel, (ax[begin+o+l-1:-1:begin+o],), :) | ||
al = a[I...] | ||
elseif b == :mirror | ||
I = [j == i ? (l+1+o:-1:2+o) : (:) for j = 1:d] | ||
elseif v == :mirror | ||
I = ifelse.(sel, (ax[begin+o+l:-1:begin+1+o],), :) | ||
al = a[I...] | ||
elseif b == :replicate | ||
I = [j == i ? ((1+o):(1+o)) : (:) for j = 1:d] | ||
# s = [j == i ? l : 1 for j = 1:d] | ||
# al = repeat(a[I...], s...) | ||
al = cat(fill(selectdim(a, i, o+1:o+1), l)..., dims=i) | ||
# al = repeat(selectdim(a, i, o+1:o+1), (l .* sel + (1 .- sel))...) | ||
elseif b == :smooth | ||
I = [j == i ? (1+o:1+o) : (:) for j = 1:d] | ||
if size(a, i) == 1 | ||
al = a[I...] | ||
elseif v == :replicate | ||
I = ifelse.(sel, (ax[begin+o:begin+o],), :) | ||
al = repeat(a[I...], (sel .* l .+ .!(sel))...) | ||
elseif v == :smooth | ||
# @ignore_derivatives I[i]=(begin+o:begin+o) | ||
# if size(a, i) == 1 | ||
# al = a[I...] | ||
# else | ||
# I2 = [j == i ? (2+o:2+o) | ||
# al = 2a[I...] - a[I2...] | ||
# end | ||
elseif isa(v, ReplicateRamp) | ||
# # @ignore_derivatives I[i]=((begin+o):(begin+o)) | ||
# @ignore_derivatives I[i]=((begin+o):(begin+o)) : (1:size(a, j)) for j = 1:N] | ||
# v = a[I...] | ||
# Δ = (v.v .- v) / l | ||
# al = v .+ cat([c * Δ for c = l:-1:1]..., dims=i) | ||
else | ||
if isa(a, Buffer) | ||
|
||
# al = fillfunc(S)(T(v), (sel .* l .+ .!(sel) .* size(a))...) | ||
al = fill(T(v), (sel .* l .+ .!(sel) .* size(a))...) | ||
else | ||
I2 = [j == i ? (2+o:2+o) : (:) for j = 1:d] | ||
al = 2a[I...] - a[I2...] | ||
al = T(v) | ||
end | ||
elseif isa(b, ReplicateRamp) | ||
# I = [j == i ? ((1+o):(1+o)) : (:) for j = 1:d] | ||
I = [j == i ? ((1+o):(1+o)) : (1:size(a, j)) for j = 1:d] | ||
v = a[I...] | ||
Δ = (b.v .- v) / l | ||
al = v .+ cat([c * Δ for c = l:-1:1]..., dims=i) | ||
else | ||
al = fill(b, Tuple(sel .* l .+ (1 .- sel) .* size(a))) | ||
end | ||
end | ||
if r > 0 | ||
o = out ? 0 : r + or | ||
if b == :periodic | ||
I = [j == i ? (1:r) : (:) for j = 1:d] | ||
o = r | ||
if v == :periodic | ||
I = ifelse.(sel, (1:r,), :) | ||
ar = a[I...] | ||
elseif b == :replicate | ||
# I = [j == i ? axes(a, i)[end-o:end-o] : (:) for j = 1:d] | ||
# s = [j == i ? r : 1 for j = 1:d] | ||
# ar = repeat(a[I...], s...) | ||
j = size(a, i) - o | ||
ar = cat(fill(selectdim(a, i, j:j), r)..., dims=i) | ||
# ar = repeat(selectdim(a, i, j:j), (r .* sel + (1 .- sel))...) | ||
elseif b == :symmetric | ||
I = [j == i ? axes(a, i)[end-o:-1:end-r+1-o] : (:) for j = 1:d] | ||
elseif v == :replicate | ||
I = ifelse.(sel, (ax[end-o:end-o],), :) | ||
ar = repeat(a[I...], (sel .* r .+ .!(sel))...) | ||
elseif v == :symmetric | ||
I = ifelse.(sel, (ax[end-o:-1:end-r+1-o],), :) | ||
ar = a[I...] | ||
elseif b == :mirror | ||
I = [j == i ? axes(a, i)[end-1-o:-1:end-r-o] : (:) for j = 1:d] | ||
elseif v == :mirror | ||
I = ifelse.(sel, (ax[end-1-o:-1:end-r-o],), :) | ||
ar = a[I...] | ||
elseif b == :smooth | ||
I = [j == i ? axes(a, i)[end-o:end-o] : (:) for j = 1:d] | ||
if size(a, i) == 1 | ||
ar = a[I...] | ||
elseif v == :smooth | ||
# @ignore_derivatives I[i]=ax[end-o:end-o] | ||
# if size(a, i) == 1 | ||
# ar = a[I...] | ||
# else | ||
# I2 = [j == i ? ax[end-o-1:end-o-1] | ||
# ar = 2a[I...] - a[I2...] | ||
# end | ||
elseif isa(v, ReplicateRamp) | ||
# @ignore_derivatives I[i]=ax[end-o:end-o] | ||
# v = a[I...] | ||
# Δ = (v.v .- v) / r | ||
# ar = v .+ cat([c * Δ for c = 1:r]..., dims=i) | ||
else | ||
if isa(a, Buffer) | ||
# ar = fillfunc(S)(T(v), (sel .* r .+ .!(sel) .* size(a))...) | ||
ar = fill(T(v), (sel .* r .+ .!(sel) .* size(a))...) | ||
else | ||
I2 = [j == i ? axes(a, i)[end-o-1:end-o-1] : (:) for j = 1:d] | ||
ar = 2a[I...] - a[I2...] | ||
ar = T(v) | ||
end | ||
elseif isa(b, ReplicateRamp) | ||
I = [j == i ? axes(a, i)[end-o:end-o] : (:) for j = 1:d] | ||
v = a[I...] | ||
Δ = (b.v .- v) / r | ||
ar = v .+ cat([c * Δ for c = 1:r]..., dims=i) | ||
else | ||
ar = fill(b, Tuple(sel .* r .+ (1 .- sel) .* size(a))) | ||
end | ||
end | ||
return (al, ar) | ||
end | ||
|
||
Base.vec(x::Number, d) = fill(x, d) | ||
Base.vec(v::Union{AbstractVector,Tuple}, d) = v | ||
Base.vec(x::Number, N) = fill(x, N) | ||
Base.vec(v::Union{AbstractVector,Tuple}, N) = v | ||
|
||
fillfunc(S) = (args...) -> Base.fill(args...) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
using Statistics | ||
# using Zygote: Buffer, bufferfrom | ||
using Statistics, Zygote | ||
using Zygote: Buffer, @ignore_derivatives | ||
|
||
include("types.jl") | ||
include("alg.jl") | ||
include("pad.jl") |
Oops, something went wrong.