Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

posterior_linpred() for ordinal families: argument for taking the intercept into account #1137

Merged
merged 34 commits into from
May 5, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
a609e11
Introduce an argument to posterior_linpred() for taking the intercept…
fweber144 Apr 12, 2021
fcd2c21
Add a NEWS entry.
fweber144 Apr 12, 2021
0a68dae
Insert GitHub PR number.
fweber144 Apr 12, 2021
40e43ac
Merge branch 'master' into projpred_augdat
fweber144 Apr 13, 2021
ec4126c
Merge branch 'master' into projpred_augdat
fweber144 Apr 14, 2021
448e813
Merge branch 'master' into projpred_augdat
fweber144 Apr 14, 2021
7b696c2
Merge branch 'master' into projpred_augdat
fweber144 Apr 29, 2021
de47fdd
add 'slice' function
paul-buerkner Apr 29, 2021
837a8e5
refactor 'dcumulative'
paul-buerkner Apr 29, 2021
3a431d8
update implementation of 'incl_thres'
paul-buerkner Apr 29, 2021
63ba13e
fix typo
paul-buerkner Apr 29, 2021
827591f
Re-indent tests/local/tests.models_new.R
fweber144 Apr 30, 2021
8e7f4fe
Add (preliminary) tests for argument `incl_thres` of posterior_linpre…
fweber144 Apr 30, 2021
4adac19
Fix a test for argument `incl_thres` of posterior_linpred() (the one …
fweber144 Apr 30, 2021
089e506
Remove an unnecessary check.
fweber144 Apr 30, 2021
d689fd7
Fix a typo.
fweber144 Apr 30, 2021
b690b75
posterior_epred_ordinal() in case of grouped thresholds: Fill missing…
fweber144 Apr 30, 2021
71eb8eb
Merge branch 'master' into projpred_augdat
fweber144 May 1, 2021
8eb1157
posterior_epred_ordinal() in case of grouped thresholds: For the "ide…
fweber144 May 1, 2021
650b732
Replace remaining extract_col() occurrences by slice_col().
fweber144 May 1, 2021
e089a4d
minor cleaning
paul-buerkner May 2, 2021
5c1c2e3
Internally document dcumulative() and inv_link_cumulative().
fweber144 May 3, 2021
00148e1
In inv_link_cumulative(): Overwrite `x`.
fweber144 May 3, 2021
64080d3
Create and use inv_link_sratio().
fweber144 May 3, 2021
1da7f7e
Create and use inv_link_cratio().
fweber144 May 3, 2021
b0bde2f
Create and use inv_link_acat().
fweber144 May 3, 2021
5b4b3f9
Test that d<ordinal_family>() works correctly.
fweber144 May 4, 2021
5fa503e
Add argument `drop` to slice().
fweber144 May 4, 2021
aadc683
inv_link_sratio(), inv_link_cratio(), and inv_link_acat(): Allow for …
fweber144 May 4, 2021
d7e50bf
Test that inv_link_<ordinal_family>() works correctly for arrays.
fweber144 May 4, 2021
5c2cdb4
minor cleaning
paul-buerkner May 5, 2021
2438629
add frank as contributor
paul-buerkner May 5, 2021
68b4fe7
some more minor cleaning
paul-buerkner May 5, 2021
c841e36
more cleaning
paul-buerkner May 5, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
update implementation of 'incl_thres'
  • Loading branch information
paul-buerkner committed Apr 29, 2021
commit 3a431d8a4f1dc9e942806a251191b76d17d88f14
68 changes: 31 additions & 37 deletions R/posterior_epred.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ posterior_epred.brmsfit <- function(object, newdata = NULL, re_formula = NULL,
nsamples = nsamples, subset = subset, check_response = FALSE, ...
)
posterior_epred(
prep, scale = "response", dpar = dpar,
nlpar = nlpar, sort = sort, summary = FALSE
prep, dpar = dpar, nlpar = nlpar, sort = sort,
scale = "response", summary = FALSE
)
}

Expand All @@ -72,18 +72,10 @@ posterior_epred.mvbrmsprep <- function(object, ...) {
}

#' @export
posterior_epred.brmsprep <- function(object, scale, dpar, nlpar, sort,
summary, robust, probs,
incl_thres = FALSE, ...) {
incl_thres <- as_one_logical(incl_thres)
if (incl_thres && !is_ordinal(object$family)) {
warning2("Argument 'incl_thres' is set to FALSE for non-ordinal families.")
incl_thres <- FALSE
}
if (incl_thres && scale == "response") {
warning2("Argument 'incl_thres' is set to FALSE if scale == \"response\".")
incl_thres <- FALSE
}
posterior_epred.brmsprep <- function(object, dpar, nlpar, sort,
scale = "response", incl_thres = NULL,
summary = FALSE, robust = FALSE,
probs = c(0.025, 0.975), ...) {
dpars <- names(object$dpars)
nlpars <- names(object$nlpars)
if (length(dpar)) {
Expand Down Expand Up @@ -129,8 +121,18 @@ posterior_epred.brmsprep <- function(object, scale, dpar, nlpar, sort,
}
out <- get_nlpar(object, nlpar = nlpar)
} else {
# predict the mean of the response distribution
# no dpar or nlpar specified
incl_thres <- as_one_logical(incl_thres %||% FALSE)
if (scale == "linear" && incl_thres && is_ordinal(object$family)) {
# extract linear predictor array with thresholds etc. included
if (is.mixfamily(object$family)) {
stop2("'incl_thres' is not supported for mixture models.")
}
object$family$link <- "identity"
scale <- "response"
}
if (scale == "response") {
# predict the mean of the response distribution
for (nlp in nlpars) {
object$nlpars[[nlp]] <- get_nlpar(object, nlpar = nlp)
}
Expand All @@ -146,16 +148,17 @@ posterior_epred.brmsprep <- function(object, scale, dpar, nlpar, sort,
}
} else {
# return results on the linear scale
# extract all 'mu' parameters
if (conv_cats_dpars(object$family)) {
mus <- dpars[grepl("^mu", dpars)]
out <- dpars[grepl("^mu", dpars)]
} else {
mus <- dpars[dpar_class(dpars) %in% "mu"]
out <- dpars[dpar_class(dpars) %in% "mu"]
}
if (length(mus) == 1L) {
out <- get_dpar(object, dpar = mus, ilink = FALSE)
if (length(out) == 1L) {
out <- get_dpar(object, dpar = out, ilink = FALSE)
} else {
# multiple mu parameters in categorical or mixture models
out <- lapply(mus, get_dpar, prep = object, ilink = FALSE)
out <- lapply(out, get_dpar, prep = object, ilink = FALSE)
out <- abind::abind(out, along = 3)
}
}
Expand All @@ -165,13 +168,6 @@ posterior_epred.brmsprep <- function(object, scale, dpar, nlpar, sort,
}
colnames(out) <- NULL
out <- reorder_obs(out, object$old_order, sort = sort)
if (incl_thres) {
out <- sapply(seq_len(ncol(out)), function(i) {
sweep(object$thres$thres, 1, as.array(out[, i])) # Shorter (but lacks sweep()'s recycling checks): `object$thres$thres - out[, i]`
}, simplify = "array")
out <- aperm(out, perm = c(1, 3, 2))
dimnames(out) <- NULL
}
if (summary) {
# only for compatibility with the 'fitted' method
out <- posterior_summary(out, probs = probs, robust = robust)
Expand Down Expand Up @@ -261,7 +257,7 @@ fitted.brmsfit <- function(object, newdata = NULL, re_formula = NULL,
nsamples = nsamples, subset = subset, check_response = FALSE, ...
)
posterior_epred(
prep, scale = scale, dpar = dpar, nlpar = nlpar, sort = sort,
prep, dpar = dpar, nlpar = nlpar, sort = sort, scale = scale,
summary = summary, robust = robust, probs = probs
)
}
Expand All @@ -281,12 +277,10 @@ fitted.brmsfit <- function(object, newdata = NULL, re_formula = NULL,
#' @param dpar Name of a predicted distributional parameter
#' for which samples are to be returned. By default, samples
#' of the main distributional parameter(s) \code{"mu"} are returned.
#' @param incl_thres For ordinal families only: A single logical value
#' indicating whether to take the thresholds (intercepts) into account
#' (\code{TRUE}) or not (\code{FALSE}). Thereby, "taking the thresholds into
#' account" means to substract the threshold-excluding linear predictor from
#' the thresholds instead of simply returning the threshold-excluding linear
#' predictor.
#' @param incl_thres Logical; only relevant for ordinal models when
#' \code{transform} is \code{FALSE}, and ignored otherwise. Shall the
#' thresholds and category-specific effects be included in the linear
#' predictor? For backwards compatibility, the default is to not include them.
#'
#' @seealso \code{\link{posterior_epred.brmsfit}}
#'
Expand All @@ -309,7 +303,7 @@ fitted.brmsfit <- function(object, newdata = NULL, re_formula = NULL,
posterior_linpred.brmsfit <- function(
object, transform = FALSE, newdata = NULL, re_formula = NULL,
re.form = NULL, resp = NULL, dpar = NULL, nlpar = NULL,
nsamples = NULL, subset = NULL, sort = FALSE, incl_thres = FALSE, ...
incl_thres = NULL, nsamples = NULL, subset = NULL, sort = FALSE, ...
) {
cl <- match.call()
if ("re.form" %in% names(cl)) {
Expand All @@ -333,8 +327,8 @@ posterior_linpred.brmsfit <- function(
nsamples = nsamples, subset = subset, check_response = FALSE, ...
)
posterior_epred(
prep, scale = scale, dpar = dpar,
nlpar = nlpar, sort = sort, summary = FALSE, incl_thres = incl_thres
prep, dpar = dpar, nlpar = nlpar, sort = sort,
scale = scale, incl_thres = incl_thres, summary = FALSE
)
}

Expand Down
14 changes: 6 additions & 8 deletions man/posterior_linpred.brmsfit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.