Skip to content

Commit

Permalink
Use where syntax in sparsematrix.jl (#22112)
Browse files Browse the repository at this point in the history
* Use where syntax in sparsematrix.jl

* Update sparsematrix.jl

* Update sparsematrix.jl
  • Loading branch information
musm authored and tkelman committed May 30, 2017
1 parent 526ed78 commit fdf97c1
Showing 1 changed file with 41 additions and 41 deletions.
82 changes: 41 additions & 41 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -830,9 +830,9 @@ information.
See also: `unchecked_aliasing_permute!`
"""
function unchecked_noalias_permute!{Tv,Ti}(X::SparseMatrixCSC{Tv,Ti},
function unchecked_noalias_permute!(X::SparseMatrixCSC{Tv,Ti},
A::SparseMatrixCSC{Tv,Ti}, p::AbstractVector{<:Integer},
q::AbstractVector{<:Integer}, C::SparseMatrixCSC{Tv,Ti})
q::AbstractVector{<:Integer}, C::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
halfperm!(C, A, q)
_computecolptrs_permute!(X, A, q, X.colptr)
_distributevals_halfperm!(X, C, p, identity)
Expand All @@ -850,9 +850,9 @@ for additional information; these methods are identical but for this method's re
the additional `workcolptr`, `length(workcolptr) >= A.n + 1`, which enables efficient
handling of the source-destination aliasing.
"""
function unchecked_aliasing_permute!{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti},
function unchecked_aliasing_permute!(A::SparseMatrixCSC{Tv,Ti},
p::AbstractVector{<:Integer}, q::AbstractVector{<:Integer},
C::SparseMatrixCSC{Tv,Ti}, workcolptr::Vector{Ti})
C::SparseMatrixCSC{Tv,Ti}, workcolptr::Vector{Ti}) where {Tv,Ti}
halfperm!(C, A, q)
_computecolptrs_permute!(A, A, q, workcolptr)
_distributevals_halfperm!(A, C, p, identity)
Expand All @@ -864,8 +864,8 @@ Computes `PAQ`'s column pointers, storing them shifted one position forward in `
`_distributevals_halfperm!` fixes this shift. Saves some work relative to
`_computecolptrs_halfperm!` as described in `uncheckednoalias_permute!`'s documentation.
"""
function _computecolptrs_permute!{Tv,Ti}(X::SparseMatrixCSC{Tv,Ti},
A::SparseMatrixCSC{Tv,Ti}, q::AbstractVector{<:Integer}, workcolptr::Vector{Ti})
function _computecolptrs_permute!(X::SparseMatrixCSC{Tv,Ti},
A::SparseMatrixCSC{Tv,Ti}, q::AbstractVector{<:Integer}, workcolptr::Vector{Ti}) where {Tv,Ti}
# Compute `A[p,q]`'s column counts. Store shifted forward one position in workcolptr.
@inbounds for k in 1:A.n
workcolptr[k+1] = A.colptr[q[k] + 1] - A.colptr[q[k]]
Expand Down Expand Up @@ -901,9 +901,9 @@ end
Helper method for `permute` and `permute!` methods operating on `SparseMatrixCSC`s.
Checks whether row- and column- permutation arguments `p` and `q` are valid permutations.
"""
function _checkargs_permutationsvalid_permute!{Ti<:Integer}(
function _checkargs_permutationsvalid_permute!(
p::AbstractVector{<:Integer}, pcheckspace::Vector{Ti},
q::AbstractVector{<:Integer}, qcheckspace::Vector{Ti})
q::AbstractVector{<:Integer}, qcheckspace::Vector{Ti}) where Ti<:Integer
if !_ispermutationvalid_permute!(p, pcheckspace)
throw(ArgumentError("row-permutation argument `p` must be a valid permutation"))
elseif !_ispermutationvalid_permute!(q, qcheckspace)
Expand Down Expand Up @@ -1003,42 +1003,42 @@ and `unchecked_aliasing_permute!`.
See also: [`permute`](@ref)
"""
function permute!{Tv,Ti}(X::SparseMatrixCSC{Tv,Ti}, A::SparseMatrixCSC{Tv,Ti},
p::AbstractVector{<:Integer}, q::AbstractVector{<:Integer})
function permute!(X::SparseMatrixCSC{Tv,Ti}, A::SparseMatrixCSC{Tv,Ti},
p::AbstractVector{<:Integer}, q::AbstractVector{<:Integer}) where {Tv,Ti}
_checkargs_sourcecompatdest_permute!(A, X)
_checkargs_sourcecompatperms_permute!(A, p, q)
C = SparseMatrixCSC(A.n, A.m, Vector{Ti}(A.m + 1), Vector{Ti}(nnz(A)), Vector{Tv}(nnz(A)))
_checkargs_permutationsvalid_permute!(p, C.colptr, q, X.colptr)
unchecked_noalias_permute!(X, A, p, q, C)
end
function permute!{Tv,Ti}(X::SparseMatrixCSC{Tv,Ti}, A::SparseMatrixCSC{Tv,Ti},
function permute!(X::SparseMatrixCSC{Tv,Ti}, A::SparseMatrixCSC{Tv,Ti},
p::AbstractVector{<:Integer}, q::AbstractVector{<:Integer},
C::SparseMatrixCSC{Tv,Ti})
C::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
_checkargs_sourcecompatdest_permute!(A, X)
_checkargs_sourcecompatperms_permute!(A, p, q)
_checkargs_sourcecompatworkmat_permute!(A, C)
_checkargs_permutationsvalid_permute!(p, C.colptr, q, X.colptr)
unchecked_noalias_permute!(X, A, p, q, C)
end
function permute!{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, p::AbstractVector{<:Integer},
q::AbstractVector{<:Integer})
function permute!(A::SparseMatrixCSC{Tv,Ti}, p::AbstractVector{<:Integer},
q::AbstractVector{<:Integer}) where {Tv,Ti}
_checkargs_sourcecompatperms_permute!(A, p, q)
C = SparseMatrixCSC(A.n, A.m, Vector{Ti}(A.m + 1), Vector{Ti}(nnz(A)), Vector{Tv}(nnz(A)))
workcolptr = Vector{Ti}(A.n + 1)
_checkargs_permutationsvalid_permute!(p, C.colptr, q, workcolptr)
unchecked_aliasing_permute!(A, p, q, C, workcolptr)
end
function permute!{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, p::AbstractVector{<:Integer},
q::AbstractVector{<:Integer}, C::SparseMatrixCSC{Tv,Ti})
function permute!(A::SparseMatrixCSC{Tv,Ti}, p::AbstractVector{<:Integer},
q::AbstractVector{<:Integer}, C::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
_checkargs_sourcecompatperms_permute!(A, p, q)
_checkargs_sourcecompatworkmat_permute!(A, C)
workcolptr = Vector{Ti}(A.n + 1)
_checkargs_permutationsvalid_permute!(p, C.colptr, q, workcolptr)
unchecked_aliasing_permute!(A, p, q, C, workcolptr)
end
function permute!{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, p::AbstractVector{<:Integer},
function permute!(A::SparseMatrixCSC{Tv,Ti}, p::AbstractVector{<:Integer},
q::AbstractVector{<:Integer}, C::SparseMatrixCSC{Tv,Ti},
workcolptr::Vector{Ti})
workcolptr::Vector{Ti}) where {Tv,Ti}
_checkargs_sourcecompatperms_permute!(A, p, q)
_checkargs_sourcecompatworkmat_permute!(A, C)
_checkargs_sourcecompatworkcolptr_permute!(A, workcolptr)
Expand All @@ -1055,8 +1055,8 @@ row count (`length(p) == A.m`).
For expert drivers and additional information, see [`permute!`](@ref).
"""
function permute{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, p::AbstractVector{<:Integer},
q::AbstractVector{<:Integer})
function permute(A::SparseMatrixCSC{Tv,Ti}, p::AbstractVector{<:Integer},
q::AbstractVector{<:Integer}) where {Tv,Ti}
_checkargs_sourcecompatperms_permute!(A, p, q)
X = SparseMatrixCSC(A.m, A.n, Vector{Ti}(A.n + 1), Vector{Ti}(nnz(A)), Vector{Tv}(nnz(A)))
C = SparseMatrixCSC(A.n, A.m, Vector{Ti}(A.m + 1), Vector{Ti}(nnz(A)), Vector{Tv}(nnz(A)))
Expand Down Expand Up @@ -1276,8 +1276,8 @@ julia> sprand(rng, Float64, 3, 0.75)
[3] = 0.298614
```
"""
function sprand{T}(r::AbstractRNG, m::Integer, n::Integer, density::AbstractFloat,
rfn::Function, ::Type{T}=eltype(rfn(r,1)))
function sprand(r::AbstractRNG, m::Integer, n::Integer, density::AbstractFloat,
rfn::Function, ::Type{T}=eltype(rfn(r,1))) where T
N = m*n
N == 0 && return spzeros(T,m,n)
N == 1 && return rand(r) <= density ? sparse([1], [1], rfn(r,1)) : spzeros(T,1,1)
Expand All @@ -1286,8 +1286,8 @@ function sprand{T}(r::AbstractRNG, m::Integer, n::Integer, density::AbstractFloa
sparse_IJ_sorted!(I, J, rfn(r,length(I)), m, n, +) # it will never need to combine
end

function sprand{T}(m::Integer, n::Integer, density::AbstractFloat,
rfn::Function, ::Type{T}=eltype(rfn(1)))
function sprand(m::Integer, n::Integer, density::AbstractFloat,
rfn::Function, ::Type{T}=eltype(rfn(1))) where T
N = m*n
N == 0 && return spzeros(T,m,n)
N == 1 && return rand() <= density ? sparse([1], [1], rfn(1)) : spzeros(T,1,1)
Expand Down Expand Up @@ -1468,15 +1468,15 @@ imag(A::SparseMatrixCSC{Tv,Ti}) where {Tv<:Real,Ti} = spzeros(Tv, Ti, A.m, A.n)

## full equality
function ==(A1::SparseMatrixCSC, A2::SparseMatrixCSC)
size(A1)!=size(A2) && return false
size(A1) != size(A2) && return false
vals1, vals2 = nonzeros(A1), nonzeros(A2)
rows1, rows2 = rowvals(A1), rowvals(A2)
m, n = size(A1)
@inbounds for i = 1:n
nz1,nz2 = nzrange(A1,i), nzrange(A2,i)
j1,j2 = first(nz1), first(nz2)
# step through the rows of both matrices at once:
while j1<=last(nz1) && j2<=last(nz2)
while j1 <= last(nz1) && j2 <= last(nz2)
r1,r2 = rows1[j1], rows2[j2]
if r1==r2
vals1[j1]!=vals2[j2] && return false
Expand Down Expand Up @@ -1508,9 +1508,9 @@ end
# In general, output of sparse matrix reductions will not be sparse,
# and computing reductions along columns into SparseMatrixCSC is
# non-trivial, so use Arrays for output
Base.reducedim_initarray{R}(A::SparseMatrixCSC, region, v0, ::Type{R}) =
Base.reducedim_initarray(A::SparseMatrixCSC, region, v0, ::Type{R}) where {R} =
fill!(similar(dims->Array{R}(dims), Base.reduced_indices(A,region)), v0)
Base.reducedim_initarray0{R}(A::SparseMatrixCSC, region, v0, ::Type{R}) =
Base.reducedim_initarray0(A::SparseMatrixCSC, region, v0, ::Type{R}) where {R} =
fill!(similar(dims->Array{R}(dims), Base.reduced_indices0(A,region)), v0)

# General mapreduce
Expand All @@ -1533,7 +1533,7 @@ function _mapreducezeros(f, op, ::Type{T}, nzeros::Int, v0) where T
v
end

function Base._mapreduce{T}(f, op, ::Base.IndexCartesian, A::SparseMatrixCSC{T})
function Base._mapreduce(f, op, ::Base.IndexCartesian, A::SparseMatrixCSC{T}) where T
z = nnz(A)
n = length(A)
if z == 0
Expand All @@ -1553,7 +1553,7 @@ _mapreducezeros(f, ::typeof(+), ::Type{T}, nzeros::Int, v0) where {T} =
_mapreducezeros(f, ::typeof(*), ::Type{T}, nzeros::Int, v0) where {T} =
nzeros == 0 ? v0 : f(zero(T))^nzeros * v0

function Base._mapreduce{T}(f, op::typeof(*), A::SparseMatrixCSC{T})
function Base._mapreduce(f, op::typeof(*), A::SparseMatrixCSC{T}) where T
nzeros = length(A)-nnz(A)
if nzeros == 0
# No zeros, so don't compute f(0) since it might throw
Expand All @@ -1566,7 +1566,7 @@ function Base._mapreduce{T}(f, op::typeof(*), A::SparseMatrixCSC{T})
end

# General mapreducedim
function _mapreducerows!{T}(f, op, R::AbstractArray, A::SparseMatrixCSC{T})
function _mapreducerows!(f, op, R::AbstractArray, A::SparseMatrixCSC{T}) where T
colptr = A.colptr
rowval = A.rowval
nzval = A.nzval
Expand All @@ -1581,7 +1581,7 @@ function _mapreducerows!{T}(f, op, R::AbstractArray, A::SparseMatrixCSC{T})
R
end

function _mapreducecols!{Tv,Ti}(f, op, R::AbstractArray, A::SparseMatrixCSC{Tv,Ti})
function _mapreducecols!(f, op, R::AbstractArray, A::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
colptr = A.colptr
rowval = A.rowval
nzval = A.nzval
Expand All @@ -1600,7 +1600,7 @@ function _mapreducecols!{Tv,Ti}(f, op, R::AbstractArray, A::SparseMatrixCSC{Tv,T
R
end

function Base._mapreducedim!{T}(f, op, R::AbstractArray, A::SparseMatrixCSC{T})
function Base._mapreducedim!(f, op, R::AbstractArray, A::SparseMatrixCSC{T}) where T
lsiz = Base.check_reducedims(R,A)
isempty(A) && return R

Expand Down Expand Up @@ -1650,7 +1650,7 @@ end

# Specialized mapreducedim for + cols to avoid allocating a
# temporary array when f(0) == 0
function _mapreducecols!{Tv,Ti}(f, op::typeof(+), R::AbstractArray, A::SparseMatrixCSC{Tv,Ti})
function _mapreducecols!(f, op::typeof(+), R::AbstractArray, A::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
nzval = A.nzval
m, n = size(A)
if length(nzval) == m*n
Expand Down Expand Up @@ -1690,7 +1690,7 @@ function _mapreducecols!{Tv,Ti}(f, op::typeof(+), R::AbstractArray, A::SparseMat
end

# findmax/min and indmax/min methods
function _findz{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, rows=1:A.m, cols=1:A.n)
function _findz(A::SparseMatrixCSC{Tv,Ti}, rows=1:A.m, cols=1:A.n) where {Tv,Ti}
colptr = A.colptr; rowval = A.rowval; nzval = A.nzval
zval = zero(Tv)
col = cols[1]; row = 0
Expand Down Expand Up @@ -1895,7 +1895,7 @@ function getindex(A::SparseMatrixCSC{Tv,Ti}, I::Range, J::AbstractVector) where
return SparseMatrixCSC(nI, nJ, colptrS, rowvalS, nzvalS)
end

function getindex_I_sorted{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, I::AbstractVector, J::AbstractVector)
function getindex_I_sorted(A::SparseMatrixCSC{Tv,Ti}, I::AbstractVector, J::AbstractVector) where {Tv,Ti}
# Sorted vectors for indexing rows.
# Similar to getindex_general but without the transpose trick.
(m, n) = size(A)
Expand All @@ -1915,7 +1915,7 @@ function getindex_I_sorted{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, I::AbstractVector,
getindex_I_sorted_linear(A, I, J)
end

function getindex_I_sorted_bsearch_A{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, I::AbstractVector, J::AbstractVector)
function getindex_I_sorted_bsearch_A(A::SparseMatrixCSC{Tv,Ti}, I::AbstractVector, J::AbstractVector) where {Tv,Ti}
const nI = length(I)
const nJ = length(J)

Expand Down Expand Up @@ -3258,7 +3258,7 @@ function next(d::SpDiagIterator{Tv}, j) where Tv
(((r1 > r2) || (A.rowval[r1] != j)) ? zero(Tv) : A.nzval[r1], j+1)
end

function trace{Tv}(A::SparseMatrixCSC{Tv})
function trace(A::SparseMatrixCSC{Tv}) where Tv
if size(A,1) != size(A,2)
throw(DimensionMismatch("expected square matrix"))
end
Expand All @@ -3269,10 +3269,10 @@ function trace{Tv}(A::SparseMatrixCSC{Tv})
s
end

diag{Tv}(A::SparseMatrixCSC{Tv}) = Tv[d for d in SpDiagIterator(A)]
diag(A::SparseMatrixCSC{Tv}) where {Tv} = Tv[d for d in SpDiagIterator(A)]

function diagm{Tv,Ti}(v::SparseMatrixCSC{Tv,Ti})
if (size(v,1) != 1 && size(v,2) != 1)
function diagm(v::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
if size(v,1) != 1 && size(v,2) != 1
throw(DimensionMismatch("input should be nx1 or 1xn"))
end

Expand Down

0 comments on commit fdf97c1

Please sign in to comment.