Skip to content

Commit

Permalink
Additional tests, improve consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
brenhinkeller committed Jun 13, 2023
1 parent c568c8a commit c579b37
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 24 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NaNStatistics"
uuid = "b946abbf-3ea7-4610-9019-9858bfdeaf2d"
authors = ["C. Brenhin Keller"]
version = "0.6.29"
version = "0.6.30"

[deps]
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
Expand Down
10 changes: 5 additions & 5 deletions src/ArrayStats/ArrayStats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,19 @@
```
Fill a Boolean mask of dimensions `size(A)` that is false wherever `A` is `NaN`
"""
function nanmask!(mask, A)
function nanmask!(mask::StridedArray, A::StridedArray{T}) where T<:PrimitiveNumber
@turbo for i eachindex(A)
mask[i] = A[i]==A[i]
end
return mask
end
function nanmask!(mask::StridedArray, A::StridedArray{T}) where T<:PrimitiveNumber
function nanmask!(mask, A)
@inbounds for i eachindex(A)
mask[i] = A[i]==A[i]
end
return mask
end

# Special methods for arrays that cannot contain NaNs
nanmask!(mask, A::AbstractArray{<:Integer}) = fill!(mask, true)
nanmask!(mask, A::AbstractArray{<:Rational}) = fill!(mask, true)
Expand Down Expand Up @@ -404,10 +405,9 @@
end
return sqrt(s / w * n / (n-1))
end
function _nanstd(A::StridedArray, W::StridedArray, ::Colon)
function _nanstd(A::StridedArray{Ta}, W::StridedArray{Tw}, ::Colon) where {Ta<:PrimitiveNumber, Tw<:PrimitiveNumber}
n = 0
Tw = eltype(W)
Tm = promote_type(eltype(W), eltype(A))
Tm = promote_type(Tw, Ta)
w = zero(Tw)
m = zero(Tm)
@turbo for i eachindex(A)
Expand Down
8 changes: 4 additions & 4 deletions src/ArrayStats/nanmean.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ function _nanmean(A::AbstractArray{T,N}, dims::Tuple) where {T,N}
end

# Reduce all the dims!
function _nanmean(A::StridedArray, ::Colon)
Tₒ = Base.promote_op(/, eltype(A), Int)
function _nanmean(A::StridedArray{T}, ::Colon) where T<:PrimitiveFloat
Tₒ = Base.promote_op(/, T, Int)
n = 0
Σ == zero(Tₒ)
@turbo check_empty=true for i eachindex(A)
Expand All @@ -62,8 +62,8 @@ function _nanmean(A::StridedArray, ::Colon)
end
return Σ / n
end
function _nanmean(A::StridedArray{<:Integer}, ::Colon)
Tₒ = Base.promote_op(/, eltype(A), Int)
function _nanmean(A::StridedArray{T}, ::Colon) where T<:PrimitiveInteger
Tₒ = Base.promote_op(/, T, Int)
Σ = zero(Tₒ)
@turbo check_empty=true for i eachindex(A)
Σ += A[i]
Expand Down
2 changes: 1 addition & 1 deletion src/ArrayStats/nanstd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ nanstd(A; dims=:, dim=:, mean=nothing, corrected=true) = sqrt!(__nanvar(mean, co
export nanstd

sqrt!(x::Number) = sqrt(x)
function sqrt!(A::StridedArray{T}) where T<:PrimitiveNumber
function sqrt!(A::StridedArray{<:PrimitiveNumber})
@turbo check_empty=true for i eachindex(A)
A[i] = sqrt(A[i])
end
Expand Down
8 changes: 4 additions & 4 deletions src/ArrayStats/nansum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ function _nansum(A::AbstractArray{T,N}, dims::Tuple) where {T,N}
end

# Reduce all the dims!
function _nansum(A::StridedArray, ::Colon)
Tₒ = Base.promote_op(+, eltype(A), Int)
function _nansum(A::StridedArray{T}, ::Colon) where T<:PrimitiveFloat
Tₒ = Base.promote_op(+, T, Int)
Σ == zero(Tₒ)
@turbo check_empty=true for i eachindex(A)
Aᵢ = A[i]
Expand All @@ -60,8 +60,8 @@ function _nansum(A::StridedArray, ::Colon)
end
return Σ
end
function _nansum(A::StridedArray{<:Integer}, ::Colon)
Tₒ = Base.promote_op(+, eltype(A), Int)
function _nansum(A::StridedArray{T}, ::Colon) where T<:PrimitiveInteger
Tₒ = Base.promote_op(+, T, Int)
Σ = zero(Tₒ)
@turbo check_empty=true for i eachindex(A)
Σ += A[i]
Expand Down
14 changes: 7 additions & 7 deletions src/ArrayStats/nanvar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ _nanvar(μ, corrected::Bool, A, dims::Int) = _nanvar(μ, corrected, A, (dims,))
# If the mean isn't known, compute it
_nanvar(::Nothing, corrected::Bool, A, dims::Tuple) = _nanvar!(_nanmean(A, dims), corrected, A, dims)
# Reduce all the dims!
function _nanvar(::Nothing, corrected::Bool, A::StridedArray, ::Colon)
Tₒ = Base.promote_op(/, eltype(A), Int)
function _nanvar(::Nothing, corrected::Bool, A::StridedArray{T}, ::Colon) where T<:PrimitiveFloat
Tₒ = Base.promote_op(/, T, Int)
n = 0
Σ == zero(Tₒ)
@turbo check_empty=true for i eachindex(A)
Expand All @@ -64,8 +64,8 @@ function _nanvar(::Nothing, corrected::Bool, A::StridedArray, ::Colon)
end
return σ² / max(n-corrected,0)
end
function _nanvar(::Nothing, corrected::Bool, A::StridedArray{<:Integer}, ::Colon)
Tₒ = Base.promote_op(/, eltype(A), Int)
function _nanvar(::Nothing, corrected::Bool, A::StridedArray{T}, ::Colon) where T<:PrimitiveInteger
Tₒ = Base.promote_op(/, T, Int)
n = length(A)
Σ = zero(Tₒ)
@turbo check_empty=true for i eachindex(A)
Expand Down Expand Up @@ -106,7 +106,7 @@ _nanvar(μ, corrected::Bool, A, dims::Tuple) = _nanvar!(collect(μ), corrected,
_nanvar::Array, corrected::Bool, A, dims::Tuple) = _nanvar!(copy(μ), corrected, A, dims)
_nanvar::Number, corrected::Bool, A, dims::Tuple) = _nanvar!([μ], corrected, A, dims)
# Reduce all the dims!
function _nanvar::Number, corrected::Bool, A::AbstractArray, ::Colon)
function _nanvar::Number, corrected::Bool, A::StridedArray{T}, ::Colon) where T<:PrimitiveFloat
n = 0
σ² == zero(typeof(μ))
@turbo check_empty=true for i eachindex(A)
Expand All @@ -117,7 +117,7 @@ function _nanvar(μ::Number, corrected::Bool, A::AbstractArray, ::Colon)
end
return σ² / max(n-corrected, 0)
end
function _nanvar::Number, corrected::Bool, A::AbstractArray{<:Integer}, ::Colon)
function _nanvar::Number, corrected::Bool, A::StridedArray{T}, ::Colon) where T<:PrimitiveInteger
σ² = zero(typeof(μ))
if μ==μ
@turbo check_empty=true for i eachindex(A)
Expand All @@ -130,7 +130,7 @@ function _nanvar(μ::Number, corrected::Bool, A::AbstractArray{<:Integer}, ::Col
end
return σ² / max(n-corrected, 0)
end
# Fallback method for non-arrays
# Fallback method for non-strided-arrays
function _nanvar::Number, corrected::Bool, A, ::Colon)
n = 0
σ² == zero(typeof(μ))
Expand Down
4 changes: 3 additions & 1 deletion src/NaNStatistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ module NaNStatistics

using Static
const IntOrStaticInt = Union{Integer, StaticInt}
const PrimitiveNumber = Union{Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64, UInt128, Float16, Float32, Float64}
const PrimitiveFloat = Union{Float16, Float32, Float64}
const PrimitiveInteger = Union{Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64, UInt128}
const PrimitiveNumber = Union{PrimitiveFloat, PrimitiveInteger}
_dim(::Type{StaticInt{N}}) where {N} = N::Int

include("ArrayStats/ArrayStats.jl")
Expand Down
49 changes: 48 additions & 1 deletion test/testArrayStats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@
@test nanaad(A, dims=1) == [25.0 25.0 25.0]
@test nanaad(A, dims=2) fill(200/3, 100, 1)

## --- Summary statistics: dimensional tests, Float64
## --- Summary statistics: dimensional tests, Float64, dim

A = collect(reshape(1:300.,100,3))
@test nansum(A, dim=1) == vec(sum(A, dims=1))
Expand Down Expand Up @@ -283,6 +283,53 @@
@test nanaad(A, dim=1) == [25.0, 25.0, 25.0]
@test nanaad(A, dim=2) fill(200/3, 100)

## --- Summary statistics: dimensional tests, Float64, nonstandard array type

A = reshape(1:300.,100,3)
@test nansum(A, dims=1) == sum(A, dims=1)
@test nansum(A, dims=2) == sum(A, dims=2)
@test nancumsum(A, dims=1) == cumsum(A, dims=1)
@test nancumsum(A, dims=2) == cumsum(A, dims=2)
@test nanminimum(A, dims=1) == minimum(A, dims=1)
@test nanminimum(A, dims=2) == minimum(A, dims=2)
@test nanmaximum(A, dims=1) == maximum(A, dims=1)
@test nanmaximum(A, dims=2) == maximum(A, dims=2)
@test nanextrema(A, dims=1) == [(1.0, 100.0) (101.0, 200.0) (201.0, 300.0)]
@test nanmean(A, dims=1) == mean(A, dims=1)
@test nanmean(A, dims=2) == mean(A, dims=2)
@test nanmean(A, ones(size(A)), dims=1) == mean(A, dims=1) # weighted
@test nanmean(A, ones(size(A)), dims=2) == mean(A, dims=2) # weighted
@test nanvar(A, dims=1) var(A, dims=1)
@test nanvar(A, dims=2) var(A, dims=2)
@test nanstd(A, dims=1) std(A, dims=1)
@test nanstd(A, dims=2) std(A, dims=2)
@test nanstd(A, dims=1, mean=nanmean(A,dims=1)) std(A, dims=1)
@test nanstd(A, dims=2, mean=nanmean(A,dims=2)) std(A, dims=2)
@test nanstd(A, ones(size(A)), dims=1) std(A, dims=1) # weighted
@test nanstd(A, ones(size(A)), dims=2) std(A, dims=2) # weighted
@test nanmedian(A, dims=1) == median(A, dims=1)
@test nanmedian(A, dims=2) == median(A, dims=2)
@test nanminimum(A, dims=1) == [1 101 201]
@test dropdims(nanminimum(A, dims=2), dims=2) == 1:100
@test nanmaximum(A, dims=1) == [100 200 300]
@test dropdims(nanmaximum(A, dims=2), dims=2) == 201:300
@test nanmean(A, dims=1) == [50.5 150.5 250.5]
@test dropdims(nanmean(A, dims=2), dims=2) == 101:200
@test nanmean(A, ones(size(A)), dims=1) == [50.5 150.5 250.5] # weighted
@test dropdims(nanmean(A, ones(size(A)), dims=2), dims=2) == 101:200 # weighted
@test nanstd(A, dims=1) fill(29.011491975882016, 1, 3)
@test nanstd(A, dims=2) fill(100, 100, 1)
@test nanstd(A, ones(size(A)), dims=1) fill(29.011491975882016, 1, 3) # weighted
@test nanstd(A, ones(size(A)), dims=2) fill(100, 100, 1) # weighted
@test nanmedian(A, dims=1) == [50.5 150.5 250.5]
@test dropdims(nanmedian(A, dims=2), dims=2) == 101:200
@test nanpctile(A, 10, dims=1) [10.9 110.9 210.9]
@test nanpctile(A, 10, dims=2) 21:120
@test nanmad(A, dims=1) == [25.0 25.0 25.0]
@test nanmad(A, dims=2) == fill(100.0, 100, 1)
@test nanaad(A, dims=1) == [25.0 25.0 25.0]
@test nanaad(A, dims=2) fill(200/3, 100, 1)

## --- Test fallbacks for complex reductions

A = randn((2 .+ (1:6))...);
Expand Down

2 comments on commit c579b37

@brenhinkeller
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/85519

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.30 -m "<description of version>" c579b37157bc362632857e4cee5e4c770976553f
git push origin v0.6.30

Please sign in to comment.