Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change find() to return the same index type as pairs() #24774

Merged
merged 4 commits into from
Jan 10, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Change find() to return the same index type as pairs()
This does not change anything for AbstractVectors and general iterables,
which continue to use linear indices. For other AbstractArrays, return
CartesianIndexes (rather than linear indices). For Dicts, return keys
(previously not supported at all).

Relying on collect() to choose the return element type allows supporting
any definition of pairs(), including that for Dict, which creates a standard
Generator for which eltype() returns Any.
  • Loading branch information
nalimilan committed Jan 7, 2018
commit a16323ec9258ddc54f6965ac640bee9f9ee4860c
85 changes: 42 additions & 43 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1720,48 +1720,59 @@ findlast(testf::Function, A) = findprev(testf, A, endof(A))
"""
find(f::Function, A)

Return a vector `I` of the linear indices of `A` where `f(A[I])` returns `true`.
Return a vector `I` of the indices or keys of `A` where `f(A[I])` returns `true`.
If there are no such elements of `A`, return an empty array.

Indices or keys are of the same type as those returned by [`keys(A)`](@ref)
and [`pairs(A)`](@ref) for `AbstractArray` and `Associative` objects,
and are linear indices of type `Int` for other iterables.

# Examples
```jldoctest
julia> x = [1, 3, 4]
3-element Array{Int64,1}:
1
3
4

julia> find(isodd, x)
2-element Array{Int64,1}:
1
2

julia> A = [1 2 0; 3 4 0]
2×3 Array{Int64,2}:
1 2 0
3 4 0

julia> find(isodd, A)
2-element Array{Int64,1}:
1
2
2-element Array{CartesianIndex{2},1}:
CartesianIndex(1, 1)
CartesianIndex(2, 1)

julia> find(!iszero, A)
4-element Array{Int64,1}:
1
2
3
4
4-element Array{CartesianIndex{2},1}:
CartesianIndex(1, 1)
CartesianIndex(2, 1)
CartesianIndex(1, 2)
CartesianIndex(2, 2)

julia> d = Dict(:A => 10, :B => -1, :C => 0)
Dict{Symbol,Int64} with 3 entries:
:A => 10
:B => -1
:C => 0

julia> find(x -> x >= 0, d)
2-element Array{Symbol,1}:
:A
:C

julia> find(isodd, [2, 4])
0-element Array{Int64,1}
```
"""
function find(testf::Function, A)
# use a dynamic-length array to store the indices, then copy to a non-padded
# array for the return
tmpI = Vector{Int}()
inds = _index_remapper(A)
for (i,a) = enumerate(A)
if testf(a)
push!(tmpI, inds[i])
end
end
I = Vector{Int}(uninitialized, length(tmpI))
copyto!(I, tmpI)
return I
end
_index_remapper(A::AbstractArray) = linearindices(A)
_index_remapper(iter) = OneTo(typemax(Int)) # safe for objects that don't implement length
find(testf::Function, A) = collect(first(p) for p in _pairs(A) if testf(last(p)))

_pairs(A::Union{AbstractArray, AbstractDict}) = pairs(A)
_pairs(iter) = zip(OneTo(typemax(Int)), iter) # safe for objects that don't implement length
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't you check the iterator trait here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean? We don't guarantee that iterators implement pairs currently.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably use countfrom(1).


"""
find(A)
Expand All @@ -1786,22 +1797,10 @@ julia> find(falses(3))
```
"""
function find(A)
nnzA = count(t -> t != 0, A)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Due to issues with the inference of the returned eltype (see commit message), I've taken a radical approach simply using collect instead of the custom loops. This means we no longer compute the length of the result before filling it. Benchmarks will be needed to check what's the best approach, but at first sight it doesn't sound obvious to me that doing two passes over the data is a good tradeoff, does it?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Historically it was worth it, but might be worth benchmarking again. See also https://discourse.julialang.org/t/push-and-interfacing-to-the-runtime-library/7461, which suggests that two passes might still be faster (growing an array with push! is 3x slower than growing it with setindex!).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See also https://discourse.julialang.org/t/half-vectorization/7399/3, which benchmarks some conditional comprehensions with somewhat alarming results. I think you really need to benchmark this change before we can decide.

I = Vector{Int}(uninitialized, nnzA)
cnt = 1
inds = _index_remapper(A)
warned = false
for (i,a) in enumerate(A)
if !warned && !(a isa Bool)
depwarn("In the future `find(A)` will only work on boolean collections. Use `find(x->x!=0, A)` instead.", :find)
warned = true
end
if a != 0
I[cnt] = inds[i]
cnt += 1
end
if !(eltype(A) === Bool) && !all(x -> x isa Bool, A)
depwarn("In the future `find(A)` will only work on boolean collections. Use `find(x->x!=0, A)` instead.", :find)
end
return I
collect(first(p) for p in _pairs(A) if last(p) != 0)
end

find(x::Bool) = x ? [1] : Vector{Int}()
Expand Down
2 changes: 1 addition & 1 deletion base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1274,7 +1274,7 @@ function find(p::Function, S::SparseMatrixCSC)
end
sz = size(S)
I, J = _findn(p, S)
return Base._sub2ind(sz, I, J)
return CartesianIndex.(I, J)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the new behavior of find, findn could probably be removed as it gives almost the same information (struct of arrays vs. array of structs).

This method could also be optimized a bit by not allocating I and J first (but it already does less work than before).

end
find(p::Base.OccursIn, x::SparseMatrixCSC) =
invoke(find, Tuple{Base.OccursIn, AbstractArray}, p, x)
Expand Down
10 changes: 10 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,16 @@ end
@test findnext(equalto(0x00), [0x00, 0x01, 0x00], 2) == 3
@test findprev(equalto(0x00), [0x00, 0x01, 0x00], 2) == 1
end
@testset "find with Matrix" begin
A = [1 2 0; 3 4 0]
@test find(isodd, A) == [CartesianIndex(1, 1), CartesianIndex(2, 1)]
@test find(!iszero, A) == [CartesianIndex(1, 1), CartesianIndex(2, 1),
CartesianIndex(1, 2), CartesianIndex(2, 2)]
end
@testset "find with Dict" begin
d = Dict(:A => 10, :B => -1, :C => 0)
@test sort(find(x -> x >= 0, d)) == [:A, :C]
end
@testset "find with general iterables" begin
s = "julia"
@test find(c -> c == 'l', s) == [3]
Expand Down