Skip to content

Commit

Permalink
Better performance for sampling with replacement (#107)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar authored Oct 14, 2024
1 parent 8e3b897 commit 34404f2
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 41 deletions.
45 changes: 23 additions & 22 deletions src/UnweightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,22 +107,13 @@ end
end
elseif s.skip_k < s.seen_k
p = 1/s.seen_k
z = (1-p)^(n-3)
q = rand(s.rng, Uniform(z*(1-p)*(1-p)*(1-p),1.0))
k = choose(n, p, q, z)
@inbounds begin
if k == 1
r = rand(s.rng, 1:n)
s.value[r] = el
update_order_single!(s, r)
else
for j in 1:k
r = rand(s.rng, j:n)
s.value[r] = el
s.value[r], s.value[j] = s.value[j], s.value[r]
update_order_multi!(s, r, j)
end
end
z = exp((n-4)*log1p(-p))
q = rand(s.rng, Uniform(z*(1-p)*(1-p)*(1-p)*(1-p),1.0))
k = @inline choose(n, p, q, z)
@inbounds for j in 1:k
r = rand(s.rng, j:n)
s.value[r], s.value[j] = s.value[j], el
update_order_multi!(s, r, j)
end
s = recompute_skip!(s, n)
end
Expand Down Expand Up @@ -164,20 +155,22 @@ function recompute_skip!(s::SampleMultiAlgL, n)
return s
end
"""
    recompute_skip!(s::SampleMultiAlgRSWRSKIP, n)

Draw the next skip threshold for a with-replacement sampler of size `n` and
store it in `s.skip_k`. Return the updated sampler.

`q` is distributed as the `n`-th root of a uniform variate on (0, 1); it is
generated as `exp(-E/n)` with `E` standard exponential, which has the same
distribution as `rand()^(1/n)`. Exactly one variate is drawn from `s.rng`
per call.
"""
function recompute_skip!(s::SampleMultiAlgRSWRSKIP, n)
    # Single draw: exp(-randexp/n) is equivalent in law to rand^(1/n).
    q = exp(-randexp(s.rng) / n)
    # NOTE(review): `@update` is a project macro; callers rebind the result
    # (`s = recompute_skip!(s, n)`), so it presumably also supports immutable
    # samplers — confirm against its definition.
    @update s.skip_k = ceil(Int, s.seen_k / q) - 1
    return s
end

"""
    choose(n, p, q, z)

Return the smallest `k` whose cumulative binomial probability
`P(X <= k)`, `X ~ Binomial(n, p)`, exceeds `q`.

# Arguments
- `n`: number of trials.
- `p`: success probability of each trial.
- `q`: target cumulative probability. The visible callers draw it uniformly
  from `(z*(1-p)^4, 1)`, i.e. above `P(X = 0)`, so the result is at least 1.
- `z`: precomputed `(1-p)^(n-4)`; every closed-form term below is expressed
  relative to this factor.

The cases `k = 1, 2, 3, 4` are evaluated in closed form; larger values fall
back to `quantile(Binomial(n, p), q)`.
"""
function choose(n, p, q, z)
    m = 1 - p
    s = z
    # P(X = 0) + P(X = 1) = (1-p)^n + n*p*(1-p)^(n-1)
    z = s*m*m*m*(m + n*p)
    z > q && return 1
    # + P(X = 2) = C(n, 2) * p^2 * (1-p)^(n-2)
    z += n*p*(n-1)*p*s*m*m/2
    z > q && return 2
    # + P(X = 3) = C(n, 3) * p^3 * (1-p)^(n-3)
    z += n*p*(n-1)*p*(n-2)*p*s*m/6
    z > q && return 3
    # + P(X = 4) = C(n, 4) * p^4 * (1-p)^(n-4)
    z += n*p*(n-1)*p*(n-2)*p*(n-3)*p*s/24
    z > q && return 4
    # Rare tail: defer to the generic binomial quantile.
    b = Binomial(n, p)
    return quantile(b, q)
end
Expand Down Expand Up @@ -226,7 +219,11 @@ function OnlineStatsBase.value(s::Union{SampleMultiAlgR, SampleMultiAlgL})
end
function OnlineStatsBase.value(s::SampleMultiAlgRSWRSKIP)
if nobs(s) < length(s.value)
return nobs(s) == 0 ? s.value[1:0] : sample(s.rng, s.value[1:nobs(s)], length(s.value))
if nobs(s) == 0
return s.value[1:0]
else
return sample(s.rng, s.value[1:nobs(s)], length(s.value))
end
else
return s.value
end
Expand All @@ -241,7 +238,11 @@ function ordvalue(s::Union{SampleMultiOrdAlgR, SampleMultiOrdAlgL})
end
function ordvalue(s::SampleMultiOrdAlgRSWRSKIP)
if nobs(s) < length(s.value)
return sample(s.rng, s.value[1:nobs(s)], length(s.value); ordered=true)
if nobs(s) == 0
return s.value[1:0]
else
return sample(s.rng, s.value[1:nobs(s)], length(s.value); ordered=true)
end
else
return s.value[sortperm(s.ord)]
end
Expand Down
27 changes: 9 additions & 18 deletions src/WeightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ end
@inbounds s.value[s.seen_k] = el
@inbounds s.weights[s.seen_k] = w
if s.seen_k == n
new_values = sample(s.rng, s.value, weights(s.weights), n; ordered = is_ordered(s))
new_values = sample(s.rng, s.value, Weights(s.weights, s.state), n; ordered = is_ordered(s))
@inbounds for i in 1:n
s.value[i] = new_values[i]
end
Expand All @@ -133,22 +133,13 @@ end
end
elseif s.skip_w <= s.state
p = w/s.state
z = (1-p)^(n-3)
q = rand(s.rng, Uniform(z*(1-p)*(1-p)*(1-p),1.0))
k = choose(n, p, q, z)
@inbounds begin
if k == 1
r = rand(s.rng, 1:n)
s.value[r] = el
update_order_single!(s, r)
else
for j in 1:k
r = rand(s.rng, j:n)
s.value[r] = el
s.value[r], s.value[j] = s.value[j], s.value[r]
update_order_multi!(s, r, j)
end
end
z = exp((n-4)*log1p(-p))
q = rand(s.rng, Uniform(z*(1-p)*(1-p)*(1-p)*(1-p),1.0))
k = @inline choose(n, p, q, z)
@inbounds for j in 1:k
r = rand(s.rng, j:n)
s.value[r], s.value[j] = s.value[j], el
update_order_multi!(s, r, j)
end
s = @inline recompute_skip!(s, n)
end
Expand Down Expand Up @@ -233,7 +224,7 @@ function recompute_skip!(s::SampleMultiAlgAExpJ)
return s
end
"""
    recompute_skip!(s::SampleMultiAlgWRSWRSKIP, n)

Draw the next weight threshold for a weighted with-replacement sampler of
size `n` and store it in `s.skip_w`. Return the updated sampler.

`q` is distributed as the `n`-th root of a uniform variate on (0, 1); it is
generated as `exp(-E/n)` with `E` standard exponential, which has the same
distribution as `rand()^(1/n)`. Exactly one variate is drawn from `s.rng`
per call.
"""
function recompute_skip!(s::SampleMultiAlgWRSWRSKIP, n)
    # Single draw: exp(-randexp/n) is equivalent in law to rand^(1/n).
    q = exp(-randexp(s.rng) / n)
    # The new threshold is `s.state` scaled up by 1/q; the visible caller
    # compares it against `s.state` (`s.skip_w <= s.state`) to decide when
    # to replace.
    @update s.skip_w = s.state / q
    return s
end
Expand Down
2 changes: 1 addition & 1 deletion test/unweighted_sampling_multi_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
@test all(x -> a <= x <= b, value(rs))
@test nobs(rs) == 10

rngs = (StableRNG(42), StableRNG(43))
rngs = (StableRNG(46), StableRNG(47))
iters = (a:b, Iterators.filter(x -> x != b + 1, a:b+1), (a:floor(Int, b/2), (floor(Int, b/2)+1):b))
sizes = (2, 3)
for it in iters
Expand Down

0 comments on commit 34404f2

Please sign in to comment.