Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve efficiency of the inv_link_<ordinal_family>() functions #1155

Merged
merged 9 commits into from
May 6, 2021

Conversation

fweber144
Copy link
Contributor

@fweber144 fweber144 commented May 6, 2021

This PR is based on PR #1154, so it's better to merge #1154 first.

As suspected here (enumeration point 2), the current unit tests for the inv_link_<ordinal_family>() functions do indeed provide a more efficient implementation of these inv_link_<ordinal_family>() functions. Here is my requested speed comparison:

library(brms)
options(mc.cores = parallel::detectCores(logical = FALSE))
data(inhaler, package = "brms")
bfit_cumul <- brm(
  formula = rating ~ period + carry + treat + (1 | subject),
  data = inhaler,
  family = cumulative(),
  seed = 475064792
)
library(microbenchmark)
microbenchmark(epred_cumul <- posterior_epred(bfit_cumul),
               times = 25)
### Old:
# Unit: seconds
#                                       expr      min       lq     mean   median      uq      max neval
# epred_cumul <- posterior_epred(bfit_cumul) 1.206015 1.415697 1.525062 1.429557 1.68592 2.037138    25
### 
### New:
# Unit: seconds
#                                       expr      min       lq    mean   median       uq      max neval
# epred_cumul <- posterior_epred(bfit_cumul) 1.322469 1.329213 1.35928 1.336847 1.343451 1.615867    25
### 

bfit_sratio <- update(bfit_cumul, family = sratio())
microbenchmark(epred_sratio <- posterior_epred(bfit_sratio),
               times = 25)
### Old:
# Unit: seconds
#                                         expr      min       lq     mean   median       uq      max neval
# epred_sratio <- posterior_epred(bfit_sratio) 13.98803 14.29517 14.40808 14.38126 14.58076 14.65054    25
### 
### New:
# Unit: seconds
#                                         expr      min       lq     mean   median       uq      max neval
# epred_sratio <- posterior_epred(bfit_sratio) 8.475701 8.635784 8.723748 8.690884 8.795831 9.093068    25
### 

bfit_cratio <- update(bfit_cumul, family = cratio())
microbenchmark(epred_cratio <- posterior_epred(bfit_cratio),
               times = 25)
### Old:
# Unit: seconds
#                                         expr      min       lq     mean   median       uq      max neval
# epred_cratio <- posterior_epred(bfit_cratio) 13.97217 14.32164 14.39963 14.34807 14.47658 14.70686    25
### 
### New:
# Unit: seconds
#                                         expr      min       lq    mean   median       uq      max neval
# epred_cratio <- posterior_epred(bfit_cratio) 8.169052 8.572888 8.62483 8.597578 8.634792 8.878164    25
### 

bfit_acat <- update(bfit_cumul, family = acat())
microbenchmark(epred_acat <- posterior_epred(bfit_acat),
               times = 25)
### Old:
# Unit: seconds
#                                     expr      min       lq     mean   median       uq      max neval
# epred_acat <- posterior_epred(bfit_acat) 13.73694 13.88184 13.94854 13.92991 13.95867 14.17641    25
### 
### New:
# Unit: seconds
#                                     expr      min       lq    mean   median       uq      max neval
# epred_acat <- posterior_epred(bfit_acat) 12.32044 12.65225 12.7692 12.72448 12.90858 13.42917    25
### 

bfit_acat_probit <- update(bfit_cumul, family = acat(link = "probit"))
microbenchmark(epred_acat_probit <- posterior_epred(bfit_acat_probit),
               times = 25)
### Old:
# Unit: seconds
#                                                   expr      min      lq     mean   median       uq      max neval
# epred_acat_probit <- posterior_epred(bfit_acat_probit) 31.17294 31.3034 31.45372 31.49983 31.55278 31.70364    25
### 
### New:
# Unit: seconds
#                                                   expr      min       lq     mean   median       uq      max neval
# epred_acat_probit <- posterior_epred(bfit_acat_probit) 21.54279 22.21362 22.39333 22.48148 22.65841 23.15681    25
### 

Because of this speed improvement (which is sometimes smaller, sometimes larger, but always present), this PR swaps the two implementations of the inv_link_<ordinal_family>() functions (the original one and the one from the unit tests).

Of course, one could go one step further and achieve another speed improvement by using arrays when calling d<ordinal_family>() in posterior_epred_ordinal(), i.e. by not iterating over the observations, but instead including them as an additional array margin. But that probably requires larger changes.

…ctions: In `distributions.R`, use the more efficient implementation from the unit tests and in the unit tests, use the original implementation.
@paul-buerkner
Copy link
Owner

Very elegant implementations. Thank you! Will merge once the checks pass.

@paul-buerkner
Copy link
Owner

When I change the number of categories in inv_link_ordinal_sim from 3 to 2 some tests fail. Can you take a look at fix the implementations to work with 2 categories as well?

@fweber144
Copy link
Contributor Author

Yes, I'll take a look at it. Thanks for the hint.

@fweber144
Copy link
Contributor Author

Should be fixed now. And I hope I have included all special cases in the unit tests now. Thanks again for pointing this out and sorry for not being aware of this.

@paul-buerkner
Copy link
Owner

Thanks for fixing this! And no worries, 1/3 of all brms bugs are caused by R dropping dimensions somewhere in edge cases :-D

@paul-buerkner paul-buerkner merged commit 98d0cc8 into paul-buerkner:master May 6, 2021
@fweber144
Copy link
Contributor Author

fweber144 commented May 7, 2021

Yeah, that dropping of margins is not really developer-friendly, especially when there's no way to turn it off, as for apply(). Thanks for merging!

@fweber144 fweber144 deleted the ordinal_speed branch May 7, 2021 06:17
@jgabry
Copy link
Contributor

jgabry commented May 7, 2021

And no worries, 1/3 of all brms bugs are caused by R dropping dimensions somewhere in edge cases :-D

1/3 of all nightmares I have are caused by R dropping dimensions ;)

@wds15
Copy link
Contributor

wds15 commented May 8, 2021

And then filling up containers by repeating things is a real nightmare.

@fweber144
Copy link
Contributor Author

You mean because I repeatedly added those array(..., dim = c(dim_thres, dim_noncat)) calls in commit 8b1704a? That's true, that could perhaps have been solved more elegantly by adding a custom wrapper around apply() which does not drop margins.

@paul-buerkner
Copy link
Owner

paul-buerkner commented May 8, 2021 via email

@fweber144
Copy link
Contributor Author

Ah I see :D

Concerning the apply() wrapper: I'm currently lacking the time to implement this, but I'll try to keep it in mind.

@paul-buerkner
Copy link
Owner

paul-buerkner commented May 8, 2021 via email

@fweber144
Copy link
Contributor Author

As if the R Core Team had heard us: For R 4.1.0, the NEWS file says:

apply() gains a simplify argument to allow disabling of simplification of results.

This new argument doesn't offer exactly what I would have desired, but it should be a good starting point for writing a custom apply() wrapper. @paul-buerkner, do you want me to write such a wrapper? But it would make brms depend on R >= 4.1.0.

@paul-buerkner
Copy link
Owner

paul-buerkner commented May 20, 2021 via email

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants