Skip to content

Commit

Permalink
add BLAS.get_num_threads (#36360)
Browse files Browse the repository at this point in the history
Co-authored-by: Okon Samuel <39421418+OkonSamuel@users.noreply.github.com>
Co-authored-by: Takafumi Arakaki <aka.tkf@gmail.com>
  • Loading branch information
3 people authored Jun 30, 2020
1 parent 39c278b commit b8110f8
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 13 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ Standard library changes
#### LinearAlgebra
* New method `LinearAlgebra.issuccess(::CholeskyPivoted)` for checking whether pivoted Cholesky factorization was successful ([#36002]).
* `UniformScaling` can now be indexed into using ranges to return dense matrices and vectors ([#24359]).
* New function `LinearAlgebra.BLAS.get_num_threads()` for getting the number of BLAS threads. ([#36360])

#### Markdown

Expand Down
100 changes: 87 additions & 13 deletions stdlib/LinearAlgebra/src/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,27 +106,101 @@ end

openblas_get_config() = strip(unsafe_string(ccall((@blasfunc(openblas_get_config), libblas), Ptr{UInt8}, () )))

function guess_vendor()
# like determine_vendor, but guesses blas in some cases
# where determine_vendor returns :unknown
ret = vendor()
if Sys.isapple() && (ret == :unknown)
ret = :osxblas
end
ret
end


"""
set_num_threads(n)
set_num_threads(n::Integer)
set_num_threads(::Nothing)
Set the number of threads the BLAS library should use.
Set the number of threads the BLAS library should use equal to `n::Integer`.
Also accepts `nothing`, in which case julia tries to guess the default number of threads.
Passing `nothing` is discouraged and mainly exists for the following reason:
On exotic variants of BLAS, `nothing` may be returned by `get_num_threads()`.
Thus on exotic variants of BLAS, the following pattern may fail to set the number of threads:
```julia
old = get_num_threads()
set_num_threads(1)
@threads for i in 1:10
# single-threaded BLAS calls
end
set_num_threads(old)
```
Because `set_num_threads` accepts `nothing`, this code can still run
on exotic variants of BLAS without error. Warnings will be raised instead.
!!! compat "Julia 1.6"
`set_num_threads(::Nothing)` requires at least Julia 1.6.
"""
function set_num_threads(n::Integer)
blas = vendor()
if blas === :openblas
return ccall((:openblas_set_num_threads, libblas), Cvoid, (Int32,), n)
elseif blas === :openblas64
return ccall((:openblas_set_num_threads64_, libblas), Cvoid, (Int32,), n)
elseif blas === :mkl
set_num_threads(n)::Nothing = _set_num_threads(n)

function _set_num_threads(n::Integer; _blas = guess_vendor())
if _blas === :openblas || _blas == :openblas64
return ccall((@blasfunc(openblas_set_num_threads), libblas), Cvoid, (Cint,), n)
elseif _blas === :mkl
# MKL may let us set the number of threads in several ways
return ccall((:MKL_Set_Num_Threads, libblas), Cvoid, (Cint,), n)
end

# OSX BLAS looks at an environment variable
@static if Sys.isapple()
elseif _blas === :osxblas
# OSX BLAS looks at an environment variable
ENV["VECLIB_MAXIMUM_THREADS"] = n
else
@assert _blas === :unknown
@warn "Failed to set number of BLAS threads." maxlog=1
end
return nothing
end

_tryparse_env_int(key) = tryparse(Int, get(ENV, key, ""))

function _set_num_threads(::Nothing; _blas = guess_vendor())
n = something(
_tryparse_env_int("OPENBLAS_NUM_THREADS"),
_tryparse_env_int("OMP_NUM_THREADS"),
max(1, Sys.CPU_THREADS ÷ 2),
)
_set_num_threads(n; _blas)
end

"""
get_num_threads()
Get the number of threads the BLAS library is using.
On exotic variants of `BLAS` this function can fail, which is indicated by returning `nothing`.
!!! compat "Julia 1.6"
`get_num_threads` requires at least Julia 1.6.
"""
get_num_threads(;_blas=guess_vendor())::Union{Int, Nothing} = _get_num_threads()

function _get_num_threads(; _blas = guess_vendor())::Union{Int, Nothing}
if _blas === :openblas || _blas === :openblas64
return Int(ccall((@blasfunc(openblas_get_num_threads), libblas), Cint, ()))
elseif _blas === :mkl
return Int(ccall((:mkl_get_max_threads, libblas), Cint, ()))
elseif _blas === :osxblas
key = "VECLIB_MAXIMUM_THREADS"
nt = _tryparse_env_int(key)
if nt === nothing
@warn "Failed to read environment variable $key" maxlog=1
else
return nt
end
else
@assert _blas === :unknown
end
@warn "Could not get number of BLAS threads. Returning `nothing` instead." maxlog=1
return nothing
end

Expand Down
27 changes: 27 additions & 0 deletions stdlib/LinearAlgebra/test/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -553,4 +553,31 @@ Base.stride(A::WrappedArray, i::Int) = stride(A.A, i)
end
end

@testset "get_set_num_threads" begin
default = BLAS.get_num_threads()
@test default isa Int
@test default > 0
BLAS.set_num_threads(1)
@test BLAS.get_num_threads() === 1
BLAS.set_num_threads(default)
@test BLAS.get_num_threads() === default

@test_logs (:warn,) match_mode=:any BLAS._set_num_threads(1, _blas=:unknown)
if BLAS.guess_vendor() !== :osxblas
# test osxblas which is not covered by CI
withenv("VECLIB_MAXIMUM_THREADS" => nothing) do
@test @test_logs(
(:warn,),
(:warn,),
match_mode=:any,
BLAS._get_num_threads(_blas=:osxblas),
) === nothing
@test_logs BLAS._set_num_threads(1, _blas=:osxblas)
@test @test_logs(BLAS._get_num_threads(_blas=:osxblas)) === 1
@test_logs BLAS._set_num_threads(2, _blas=:osxblas)
@test @test_logs(BLAS._get_num_threads(_blas=:osxblas)) === 2
end
end
end

end # module TestBLAS

0 comments on commit b8110f8

Please sign in to comment.