Skip to content

Commit

Permalink
No commit message
Browse files Browse the repository at this point in the history
  • Loading branch information
paulxshen committed Jan 5, 2025
1 parent 1e00506 commit 454bca6
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
3 changes: 1 addition & 2 deletions src/ArrayPadding.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
module ArrayPadding
include("main.jl")
export pad, pad!, Ramp, ReplicateRamp, PaddedArray, left, right
export place, place!, array
export pad, pad!, TanhRamp
end # module Pad
10 changes: 7 additions & 3 deletions src/alg.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
function lr(a::S, v, i, l, r, ol=0, or=0, dofill=false) where {S}
function lr(a::S, v0, i, l, r, ol=0, or=0, dofill=false) where {S}
N = ndims(a)
sz = size(a)
T = eltype(a)
ax = axes(a, i)
sel = (1:N) .== i
al = ar = nothing
if l > 0
v = _fv(v0, l)
if v == :periodic
I = ifelse.(sel, (ax[end-l+1:end],), :)
al = a[I...]
Expand All @@ -27,7 +28,8 @@ function lr(a::S, v, i, l, r, ol=0, or=0, dofill=false) where {S}
# al = 2a[I...] - a[I2...]
# end
elseif isa(v, Function)
al = repeat(v.(constructor(S)(1:l)), (sel .+ .!(sel) .* sz)...)
b = reshape(v.(constructor(S)(1:l)), (l .* sel .+ .!(sel))...)
al = repeat(b, (sel .+ .!(sel) .* sz)...)
else
if dofill
f = fillfunc(S)
Expand All @@ -39,6 +41,7 @@ function lr(a::S, v, i, l, r, ol=0, or=0, dofill=false) where {S}
end
end
if r > 0
v = _fv(v0, r)
if v == :periodic
I = ifelse.(sel, (1:r,), :)
ar = a[I...]
Expand All @@ -60,7 +63,8 @@ function lr(a::S, v, i, l, r, ol=0, or=0, dofill=false) where {S}
# ar = 2a[I...] - a[I2...]
# end
elseif isa(v, Function)
ar = repeat(v.(constructor(S)(1:r)), (sel .+ .!(sel) .* sz)...)
b = reshape(v.(constructor(S)(1:r)), (r .* sel .+ .!(sel))...)
ar = repeat(b, (sel .+ .!(sel) .* sz)...)
else
if dofill
f = fillfunc(S)
Expand Down
8 changes: 7 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,10 @@ constructor(::Type{<:Array}) = Array

# (f::Function)(::Type{Buffer{T,S}}) where {T,S} = f(S)
fillfunc(::Type{<:Buffer{T,S}}) where {T,S} = fillfunc(S)
constructor(::Type{<:Buffer{T,S}}) where {T,S} = constructor(S)
constructor(::Type{<:Buffer{T,S}}) where {T,S} = constructor(S)

struct TanhRamp
v
end
_fv(v::TanhRamp, n) = i -> v.v * tanh(2i / n)
_fv(v, n) = v

0 comments on commit 454bca6

Please sign in to comment.