Skip to content

Commit

Permalink
Vectorise random vectors of Float16
Browse files Browse the repository at this point in the history
  • Loading branch information
giordano committed Oct 5, 2024
1 parent 99cc59c commit f045462
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions stdlib/Random/src/XoshiroSimd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ simdThreshold(::Type{Bool}) = 640
l = Float32(li >>> 8) * Float32(0x1.0p-24)
(UInt64(reinterpret(UInt32, u)) << 32) | UInt64(reinterpret(UInt32, l))
end
# Convert the four 16-bit lanes of `x` into four `Float16` values in [0, 1) and return
# their bit patterns packed into the same lane positions of a `UInt64`.
#
# Each lane keeps only its top 11 bits (`>>> 5`): `Float16` has an 11-bit significand, so
# the integers 0:2047 convert exactly, and scaling by 2^-11 gives at most 2047/2048 < 1.
# Keeping 12 bits instead would let 4095 round up to 4096 and produce 2.0.
@inline function _bits2float(x::UInt64, ::Type{Float16})
    # Split the 64-bit word into its four 16-bit lanes, most significant first.
    i1 = (x >>> 48) % UInt16
    i2 = (x >>> 32) % UInt16
    i3 = (x >>> 16) % UInt16
    i4 = x % UInt16
    # Map every lane to a multiple of 2^-11 in [0, 1).
    f1 = Float16(i1 >>> 5) * Float16(0x1.0p-11)
    f2 = Float16(i2 >>> 5) * Float16(0x1.0p-11)
    f3 = Float16(i3 >>> 5) * Float16(0x1.0p-11)
    f4 = Float16(i4 >>> 5) * Float16(0x1.0p-11)
    # Re-pack the four bit patterns so that all 64 output bits carry random data.
    return (UInt64(reinterpret(UInt16, f1)) << 48) |
           (UInt64(reinterpret(UInt16, f2)) << 32) |
           (UInt64(reinterpret(UInt16, f3)) << 16) |
           UInt64(reinterpret(UInt16, f4))
end

# Required operations. These could be written more concisely with `ntuple`, but the compiler
# sometimes refuses to vectorize that form properly.
Expand Down Expand Up @@ -118,6 +125,18 @@ for N in [4,8,16]
ret <$N x i64> %i
"""
@eval @inline _bits2float(x::$VT, ::Type{Float32}) = llvmcall($code, $VT, Tuple{$VT}, x)

# Vector kernel for Float16: view the <N x i64> input as 4N 16-bit lanes, keep the top
# 11 bits of every lane, convert to `half`, and scale by 2^-11 (the `half` constant below
# is written in LLVM's double-format hex and equals 2^-11).
# The shift must be 5, not 4: `half` has an 11-bit significand, so only 0:2047 convert
# exactly; a 12-bit value such as 4095 would round to 4096 and scale to 2.0, outside [0, 1).
code = """
%as16 = bitcast <$N x i64> %0 to <$(4N) x i16>
%shiftamt = shufflevector <1 x i16> <i16 5>, <1 x i16> undef, <$(4N) x i32> zeroinitializer
%sh = lshr <$(4N) x i16> %as16, %shiftamt
%f = uitofp <$(4N) x i16> %sh to <$(4N) x half>
%scale = shufflevector <1 x half> <half 0x3f40000000000000>, <1 x half> undef, <$(4N) x i32> zeroinitializer
%m = fmul <$(4N) x half> %f, %scale
%i = bitcast <$(4N) x half> %m to <$N x i64>
ret <$N x i64> %i
"""
@eval @inline _bits2float(x::$VT, ::Type{Float16}) = llvmcall($code, $VT, Tuple{$VT}, x)
end
end

Expand All @@ -137,7 +156,7 @@ end

# Pass-through transform: hands back `x` untouched and ignores the type argument `T`.
# It is the default value of the `f` argument of `xoshiro_bulk`.
function _id(x, T)
    return x
end

@inline function xoshiro_bulk(rng::Union{TaskLocalRNG, Xoshiro}, dst::Ptr{UInt8}, len::Int, T::Union{Type{UInt8}, Type{Bool}, Type{Float32}, Type{Float64}}, ::Val{N}, f::F = _id) where {N, F}
@inline function xoshiro_bulk(rng::Union{TaskLocalRNG, Xoshiro}, dst::Ptr{UInt8}, len::Int, T::Union{Type{UInt8}, Type{Bool}, Type{Float16}, Type{Float32}, Type{Float64}}, ::Val{N}, f::F = _id) where {N, F}
if len >= simdThreshold(T)
written = xoshiro_bulk_simd(rng, dst, len, T, Val(N), f)
len -= written
Expand Down Expand Up @@ -265,13 +284,8 @@ end
end


# Fill the `Float32` array `dst` in place for the `CloseOpen01{Float32}` sampler by handing
# its storage, viewed as raw bytes, to `xoshiro_bulk` with `_bits2float` as the conversion
# step. Returns `dst`.
function rand!(rng::Union{TaskLocalRNG, Xoshiro}, dst::Array{Float32}, ::SamplerTrivial{CloseOpen01{Float32}})
    nbytes = length(dst) * 4  # four bytes per Float32 element
    # Keep `dst` rooted while its raw pointer is in use.
    GC.@preserve dst begin
        rawptr = convert(Ptr{UInt8}, pointer(dst))
        xoshiro_bulk(rng, rawptr, nbytes, Float32, xoshiroWidth(), _bits2float)
    end
    return dst
end

function rand!(rng::Union{TaskLocalRNG, Xoshiro}, dst::Array{Float64}, ::SamplerTrivial{CloseOpen01{Float64}})
GC.@preserve dst xoshiro_bulk(rng, convert(Ptr{UInt8}, pointer(dst)), length(dst)*8, Float64, xoshiroWidth(), _bits2float)
# Fill `dst` in place for the `CloseOpen01{T}` sampler, for `T` one of `Float16`,
# `Float32` or `Float64`, by handing the array's storage, viewed as raw bytes, to
# `xoshiro_bulk` with `_bits2float` as the conversion step. Returns `dst`.
function rand!(rng::Union{TaskLocalRNG, Xoshiro}, dst::Array{T}, ::SamplerTrivial{CloseOpen01{T}}) where {T<:Union{Float16,Float32,Float64}}
    # The byte count must use the element size of `T`; calling `sizeof()` with no
    # argument is a `MethodError`.
    nbytes = length(dst) * sizeof(T)
    # Keep `dst` rooted while its raw pointer is in use.
    GC.@preserve dst xoshiro_bulk(rng, convert(Ptr{UInt8}, pointer(dst)), nbytes, T, xoshiroWidth(), _bits2float)
    return dst
end

Expand Down

0 comments on commit f045462

Please sign in to comment.