Multi-threaded sampling with reversediff backend and rdcache gives bad samples #1412

Closed
BlackWingedKing opened this issue Sep 23, 2020 · 17 comments · Fixed by #1414


@BlackWingedKing

BlackWingedKing commented Sep 23, 2020

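The following are the posterior plots for the coin-flip Turing example:
[Figure: posterior densities of p for serial, multi-threaded (default backend), Zygote, ReverseDiff and ReverseDiff with rdcache sampling]
As we can see, the posterior from multi-threaded sampling with rdcache is not at all close to the others.
Code to reproduce this: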

```julia
# Load the modules
using Turing, MCMCChains, Distributions, StatsPlots, Random
using ReverseDiff, Memoization, Zygote

println("loaded modules")

# Set the true probability of heads in a coin.
p_true = 0.5

# Simulate 100 coin flips.
num_flips = 100
Random.seed!(12)
data = rand(Bernoulli(p_true), num_flips)

# Define the model.
@model coinflip(y) = begin
    p ~ Beta(1, 1)
    N = length(y)
    for n in 1:N
        y[n] ~ Bernoulli(p)
    end
end

# Sampling parameters
num_chains = 10
iterations = 100
model_coin = coinflip(data)

# Sampling with the default (ForwardDiff) backend: serial and multi-threaded
chain_serial = mapreduce(c -> sample(model_coin, NUTS(0.65), iterations), chainscat, 1:num_chains)
chain_multithread = sample(model_coin, NUTS(0.65), MCMCThreads(), iterations, num_chains)

# Zygote backend
Turing.setadbackend(:zygote)
chain_zymultithread = sample(model_coin, NUTS(0.65), MCMCThreads(), iterations, num_chains)

# ReverseDiff backend without tape caching (rdcache)
Turing.setadbackend(:reversediff)
Turing.setrdcache(false)
chain_rdmultithread = sample(model_coin, NUTS(0.65), MCMCThreads(), iterations, num_chains)

# ReverseDiff backend with tape caching (tapes are memoized via Memoization.jl)
Turing.setrdcache(true)
chain_rdcachemultithread = sample(model_coin, NUTS(0.65), MCMCThreads(), iterations, num_chains)

# Plot the pooled density of p across all chains for each configuration.
density(chain_serial[:p][:], label="serial", legend=:topleft)
density!(chain_multithread[:p][:], label="multi-thread-normal", legend=:topleft)
density!(chain_zymultithread[:p][:], label="multi-thread-zygote", legend=:topleft)
density!(chain_rdmultithread[:p][:], label="multi-thread-reversediff", legend=:topleft)
density!(chain_rdcachemultithread[:p][:], label="multi-thread-rdcache", legend=:topleft)

savefig("mwe.png")
```

My configuration:

```
JULIA_NUM_THREADS=4
[31c24e10] Distributions v0.23.12
[ced4e74d] DistributionsAD v0.6.9
[c7f686f2] MCMCChains v4.2.1
[6fafb56a] Memoization v0.1.4
[37e2e3b7] ReverseDiff v1.4.3
[f3b207a7] StatsPlots v0.14.13
[fce5fe82] Turing v0.14.3
[e88e6eb3] Zygote v0.5.7
```

Also posted in the Slack channel.

@BlackWingedKing
Author

@mohamed82008

@devmotion
Member

Maybe related to the other ReverseDiff/Memoization issues (see, e.g., #1393). According to its README, Memoization is also not thread-safe (see https://github.com/marius311/Memoization.jl), so I'm not surprised that multithreaded sampling is problematic. Does ReverseDiff work without Memoization (you have to restart the Julia process to make sure there's no memoized stuff floating around anymore)?

@BlackWingedKing
Author

@devmotion Yes, it works fine. I will update the plot and the MWE to include it.

@devmotion
Member

Then it seems the issue here is that Memoization is not thread-safe.

@BlackWingedKing
Author

Thank you @devmotion. I was wondering if we could have a warning or error message for multi-threaded sampling when rdcache is enabled. It could be something like this:
Warning: Memoization isn't thread-safe. Please don't use rdcache with multi-threaded sampling.

@devmotion
Member

I think it may be possible to fix the problem on our side by using a Dict for memoization and adding the thread id to the list of keys, so I would like to check that before adding a warning.
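A minimal stdlib-only sketch of that idea, assuming the cached object (like a compiled tape) owns mutable scratch space: a `Dict` keyed by `(threadid(), key)` so that each thread gets its own cached copy. A `Dict` is not thread-safe by itself, so this sketch also guards it with a `ReentrantLock`, and it uses `@threads :static` (Julia 1.5+) so each task stays on one thread and `threadid()` is a stable key. `ScratchEvaluator`, `evaluate!` and `cached_evaluator` are hypothetical names standing in for a memoized compiled tape; this is not Turing's actual fix.

```julia
using Base.Threads: @threads, threadid

# Hypothetical stand-in for a memoized compiled tape: it owns a scratch
# buffer that every call overwrites, so it must not be shared across threads.
struct ScratchEvaluator
    buf::Vector{Float64}
end
ScratchEvaluator(n::Int) = ScratchEvaluator(zeros(n))

# Writes into the scratch buffer, then returns a value computed from it.
function evaluate!(ev::ScratchEvaluator, x::Vector{Float64})
    ev.buf .= 2 .* x
    return sum(ev.buf)
end

# Memoization cache with the thread id as part of the key. The Dict itself
# is not thread-safe, so every lookup or insert happens under the lock.
const CACHE = Dict{Tuple{Int,Int},ScratchEvaluator}()
const CACHE_LOCK = ReentrantLock()

function cached_evaluator(n::Int)
    key = (threadid(), n)
    return lock(CACHE_LOCK) do
        get!(() -> ScratchEvaluator(n), CACHE, key)
    end
end

# :static pins each task to one thread, so threadid() is a stable cache key.
results = zeros(100)
@threads :static for i in 1:100
    results[i] = evaluate!(cached_evaluator(3), fill(Float64(i), 3))
end

println(results == [6.0 * i for i in 1:100])  # true
```

Keying on the thread id keeps the memoization benefit (one object per thread instead of one per call) while guaranteeing that no two threads write to the same scratch buffer at the same time.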

@BlackWingedKing
Author

Thanks for the suggestion. Could you point me to an example that does this?

@mohamed82008
Member

I think this is fixable too. I am actually surprised this doesn't work at all, since even a race condition shouldn't affect the result: all the threads are writing the same compiled tape.

@devmotion
Member

> Could you point me to an example that does this?

It's something in Turing that has to be changed (if it fixes the issue).

@mohamed82008
Member

Hmm, OK, I think I know what's going on. It's not the memoization, it's ReverseDiff. The compiled tape has cache fields that are reused every time the tape is differentiated, so threads sharing one tape overwrite each other's intermediate values. Different compiled tapes for different threads should solve the problem.
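This also explains why "all threads write the same compiled tape" does not make the race harmless: the tapes are identical, but the values stored in the shared cache fields at run time are not. The hypothetical, stdlib-only `ToyTape` below (not ReverseDiff's API) mimics a compiled tape whose input cache field is overwritten on every call. Splitting a gradient call into `load!` and `sweep` lets us replay, deterministically on a single thread, an interleaving that two real threads can hit.

```julia
# Hypothetical stand-in for a compiled tape: `input` is a cache field that
# every gradient call overwrites before the reverse sweep reads it.
struct ToyTape
    input::Vector{Float64}
end
ToyTape(n::Int) = ToyTape(zeros(n))

# Phase 1 of a gradient call: copy the caller's point into the cache field.
load!(t::ToyTape, x::AbstractVector) = (t.input .= x; t)

# Phase 2: gradient of sum(x .^ 2), computed from whatever is in the cache.
sweep(t::ToyTape) = 2 .* t.input

xA = [1.0, 2.0]    # point "thread A" wants to differentiate
xB = [10.0, 20.0]  # point "thread B" wants to differentiate

# One shared tape, interleaved as A-load, B-load, A-sweep:
# A silently receives the gradient at B's point.
shared = ToyTape(2)
load!(shared, xA)
load!(shared, xB)
gA_shared = sweep(shared)

# One tape per thread: the same interleaving is harmless.
tapeA, tapeB = ToyTape(2), ToyTape(2)
load!(tapeA, xA)
load!(tapeB, xB)
gA_own = sweep(tapeA)

println(gA_shared)  # [20.0, 40.0], wrong for xA
println(gA_own)     # [2.0, 4.0], correct: 2 .* xA
```

With ReverseDiff itself, the analogous step would be compiling a separate tape per thread (each via `ReverseDiff.compile(ReverseDiff.GradientTape(f, x))`) instead of sharing one compiled tape.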

@devmotion
Member

Yep, that's what I suggested above 🙂

@mohamed82008
Member

Yes, I was agreeing with you :)

@BlackWingedKing
Author

Closing this as #1414 fixes it.

@devmotion
Member

Turing 0.14.4 is available now which should contain the fix for this problem.

@BlackWingedKing
Author

@devmotion The repository still shows 0.14.3 as the latest release. Does it take some time to update?

@devmotion
Member

Tags for the git repo are created automatically at midnight every day. For the Julia package manager, the only thing that matters is that JuliaRegistries/General#21909 was merged: as soon as the registry is updated, users can update Turing with Pkg.

@BlackWingedKing
Author

Thanks for the clarification!
