Fix MLE/MAP with Zygote and ReverseDiff #1408
For anyone stumbling on this issue, any MLE/MAP estimates run with Zygote or ReverseDiff as the AD backends in Turing 14.0, 14.1, or 14.2 are incorrect and should be re-run.
Codecov Report
Coverage diff (master vs. #1408):
            master   #1408
Coverage    67.01%   67.01%
Files           25       25
Lines         1616     1616
Hits          1083     1083
Misses         533      533
test/modes/ModeEstimation.jl
Outdated
@testset "AD backends" begin
    Random.seed!(222)
    true_value = [0.0625, 1.75]

    Turing.setadbackend(:forwarddiff)
    m1 = optimize(gdemo_default, MLE())

    Turing.setadbackend(:reversediff)
    m2 = optimize(gdemo_default, MLE())

    Turing.setadbackend(:tracker)
    m3 = optimize(gdemo_default, MLE())

    Turing.setadbackend(:zygote)
    m4 = optimize(gdemo_default, MLE())

    # Go back to normal forwarddiff for the rest of the tests
    Turing.setadbackend(:forwarddiff)

    @test all(isapprox.(m1.values.array - true_value, 0.0, atol=0.01))
    @test all(isapprox.(m2.values.array - true_value, 0.0, atol=0.01))
    @test all(isapprox.(m3.values.array - true_value, 0.0, atol=0.01))
    @test all(isapprox.(m4.values.array - true_value, 0.0, atol=0.01))
end
Maybe it would be better to move the mode tests to the AD block in runtests.jl, so that tests of individual AD backends are easier to comment out?
(In general, IMO, in a different PR we might want to restructure the tests along the lines of DistributionsAD or Bijectors, so that specific AD backends can be tested more easily using environment variables. This would also allow us to split the CI tests by backend.)
I'd prefer to keep it here because it falls under the optional dependency Optim -- there's really no great place to put these, so I think it's better to keep all of these "feature" tests under one banner.
Oh, maybe I didn't explain clearly what I meant. I just thought that instead of switching between AD backends here, you could move the include(...) line for the mode tests in runtests.jl inside the block that cycles through the AD backends.
Then you would not have to add any new tests at all; the mode tests would run with every AD backend automatically.
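For illustration, a minimal sketch of that layout. It is not the actual runtests.jl: `set_backend!` and `run_mode_tests` are hypothetical stand-ins (for `Turing.setadbackend` and the `include(...)` of the mode tests) so the snippet runs without Turing installed, and the backend list is an assumption.

```julia
using Test

# Hypothetical stand-ins so the sketch runs on its own. In the real test
# suite these would be Turing.setadbackend(...) and the include(...) line
# for the mode tests.
const current_backend = Ref(:forwarddiff)
set_backend!(backend::Symbol) = (current_backend[] = backend)

const backends_seen = Symbol[]
run_mode_tests() = push!(backends_seen, current_backend[])

# One testset per backend: the same test file runs once under each backend,
# so no backend-specific copies of the tests are needed.
@testset "AD backend: $backend" for backend in (:forwarddiff, :tracker, :reversediff)
    set_backend!(backend)
    run_mode_tests()
    @test current_backend[] == backend
end
```

With this shape, adding a backend to the tuple is enough for every included test file to run under it.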
Oh, yeah, that's way better. Done.
LGTM
Cool, I'll merge this and flag a release once the tests pass.
The loop of AD backends only includes Forward, Tracker and Reverse. There's no Zygote, I think.
Oh, good catch. I think we should split the CI tests and use environment variables to determine which AD backend to run in runtests.jl (a bit similar to DistributionsAD).
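A rough sketch of how environment-variable selection could look. The variable name `AD_BACKENDS`, the helper `selected_backends`, and the default list are all hypothetical; this is not how DistributionsAD or Turing actually implement it.

```julia
# Hypothetical helper: read a comma-separated backend list from an
# environment variable (e.g. AD_BACKENDS="zygote,reversediff") and fall
# back to a default list when the variable is unset or empty.
function selected_backends(env=ENV;
                           default=[:forwarddiff, :tracker, :reversediff, :zygote])
    raw = get(env, "AD_BACKENDS", "")
    isempty(strip(raw)) && return default
    # Normalise each entry: trim whitespace, lowercase, convert to Symbol.
    return [Symbol(lowercase(strip(s))) for s in split(raw, ',') if !isempty(strip(s))]
end
```

Each CI job could then set the variable to a single backend, and runtests.jl would loop over `selected_backends()` instead of a hard-coded list.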
@devmotion Yeah, I think maybe you guys want to do a similar exercise as in DistributionsAD... But I'm convinced that the previous issue on MLE/MAP with Zygote is resolved. Thanks!
MLE/MAP was not being tested with ReverseDiff or Zygote, so I missed the fact that the sampling context was not being passed to gradient_logp for those two backends. Thanks to @wupeifan for flagging this. I added tests for this as well to make sure this doesn't happen in the future.