From b8110f8d1ec6349bee77efb5022621fdf50bd4a5 Mon Sep 17 00:00:00 2001 From: Jan Weidner Date: Tue, 30 Jun 2020 09:02:08 +0200 Subject: [PATCH] add BLAS.get_num_threads (#36360) Co-authored-by: Okon Samuel <39421418+OkonSamuel@users.noreply.github.com> Co-authored-by: Takafumi Arakaki --- NEWS.md | 1 + stdlib/LinearAlgebra/src/blas.jl | 100 ++++++++++++++++++++++++++---- stdlib/LinearAlgebra/test/blas.jl | 27 ++++++++ 3 files changed, 115 insertions(+), 13 deletions(-) diff --git a/NEWS.md b/NEWS.md index 2c77e9e470921..acab9f17137ab 100644 --- a/NEWS.md +++ b/NEWS.md @@ -61,6 +61,7 @@ Standard library changes #### LinearAlgebra * New method `LinearAlgebra.issuccess(::CholeskyPivoted)` for checking whether pivoted Cholesky factorization was successful ([#36002]). * `UniformScaling` can now be indexed into using ranges to return dense matrices and vectors ([#24359]). +* New function `LinearAlgebra.BLAS.get_num_threads()` for getting the number of BLAS threads. ([#36360]) #### Markdown diff --git a/stdlib/LinearAlgebra/src/blas.jl b/stdlib/LinearAlgebra/src/blas.jl index a2367d7ee4b90..fca99ceee3f78 100644 --- a/stdlib/LinearAlgebra/src/blas.jl +++ b/stdlib/LinearAlgebra/src/blas.jl @@ -106,27 +106,101 @@ end openblas_get_config() = strip(unsafe_string(ccall((@blasfunc(openblas_get_config), libblas), Ptr{UInt8}, () ))) +function guess_vendor() + # like determine_vendor, but guesses blas in some cases + # where determine_vendor returns :unknown + ret = vendor() + if Sys.isapple() && (ret == :unknown) + ret = :osxblas + end + ret +end + + """ - set_num_threads(n) + set_num_threads(n::Integer) + set_num_threads(::Nothing) -Set the number of threads the BLAS library should use. +Set the number of threads the BLAS library should use equal to `n::Integer`. + +Also accepts `nothing`, in which case julia tries to guess the default number of threads. +Passing `nothing` is discouraged and mainly exists for the following reason: + +On exotic variants of BLAS, `nothing` may be returned by `get_num_threads()`. +Thus on exotic variants of BLAS, the following pattern may fail to set the number of threads: + +```julia +old = get_num_threads() +set_num_threads(1) +@threads for i in 1:10 + # single-threaded BLAS calls +end +set_num_threads(old) +``` +Because `set_num_threads` accepts `nothing`, this code can still run +on exotic variants of BLAS without error. Warnings will be raised instead. + +!!! compat "Julia 1.6" + `set_num_threads(::Nothing)` requires at least Julia 1.6. """ -function set_num_threads(n::Integer) - blas = vendor() - if blas === :openblas - return ccall((:openblas_set_num_threads, libblas), Cvoid, (Int32,), n) - elseif blas === :openblas64 - return ccall((:openblas_set_num_threads64_, libblas), Cvoid, (Int32,), n) - elseif blas === :mkl +set_num_threads(n)::Nothing = _set_num_threads(n) + +function _set_num_threads(n::Integer; _blas = guess_vendor()) + if _blas === :openblas || _blas == :openblas64 + return ccall((@blasfunc(openblas_set_num_threads), libblas), Cvoid, (Cint,), n) + elseif _blas === :mkl # MKL may let us set the number of threads in several ways return ccall((:MKL_Set_Num_Threads, libblas), Cvoid, (Cint,), n) - end - - # OSX BLAS looks at an environment variable - @static if Sys.isapple() + elseif _blas === :osxblas + # OSX BLAS looks at an environment variable ENV["VECLIB_MAXIMUM_THREADS"] = n + else + @assert _blas === :unknown + @warn "Failed to set number of BLAS threads." maxlog=1 end + return nothing +end + +_tryparse_env_int(key) = tryparse(Int, get(ENV, key, "")) + +function _set_num_threads(::Nothing; _blas = guess_vendor()) + n = something( + _tryparse_env_int("OPENBLAS_NUM_THREADS"), + _tryparse_env_int("OMP_NUM_THREADS"), + max(1, Sys.CPU_THREADS รท 2), + ) + _set_num_threads(n; _blas) +end + +""" + get_num_threads() +Get the number of threads the BLAS library is using. + +On exotic variants of `BLAS` this function can fail, which is indicated by returning `nothing`. + +!!! compat "Julia 1.6" + `get_num_threads` requires at least Julia 1.6. +""" +get_num_threads(;_blas=guess_vendor())::Union{Int, Nothing} = _get_num_threads() + +function _get_num_threads(; _blas = guess_vendor())::Union{Int, Nothing} + if _blas === :openblas || _blas === :openblas64 + return Int(ccall((@blasfunc(openblas_get_num_threads), libblas), Cint, ())) + elseif _blas === :mkl + return Int(ccall((:mkl_get_max_threads, libblas), Cint, ())) + elseif _blas === :osxblas + key = "VECLIB_MAXIMUM_THREADS" + nt = _tryparse_env_int(key) + if nt === nothing + @warn "Failed to read environment variable $key" maxlog=1 + else + return nt + end + else + @assert _blas === :unknown + end + @warn "Could not get number of BLAS threads. Returning `nothing` instead." maxlog=1 return nothing end diff --git a/stdlib/LinearAlgebra/test/blas.jl b/stdlib/LinearAlgebra/test/blas.jl index 23c6d68cdc997..cefb0625d00d5 100644 --- a/stdlib/LinearAlgebra/test/blas.jl +++ b/stdlib/LinearAlgebra/test/blas.jl @@ -553,4 +553,31 @@ Base.stride(A::WrappedArray, i::Int) = stride(A.A, i) end end +@testset "get_set_num_threads" begin + default = BLAS.get_num_threads() + @test default isa Int + @test default > 0 + BLAS.set_num_threads(1) + @test BLAS.get_num_threads() === 1 + BLAS.set_num_threads(default) + @test BLAS.get_num_threads() === default + + @test_logs (:warn,) match_mode=:any BLAS._set_num_threads(1, _blas=:unknown) + if BLAS.guess_vendor() !== :osxblas + # test osxblas which is not covered by CI + withenv("VECLIB_MAXIMUM_THREADS" => nothing) do + @test @test_logs( + (:warn,), + (:warn,), + match_mode=:any, + BLAS._get_num_threads(_blas=:osxblas), + ) === nothing + @test_logs BLAS._set_num_threads(1, _blas=:osxblas) + @test @test_logs(BLAS._get_num_threads(_blas=:osxblas)) === 1 + @test_logs BLAS._set_num_threads(2, _blas=:osxblas) + @test @test_logs(BLAS._get_num_threads(_blas=:osxblas)) === 2 + end + end +end + end # module TestBLAS