Skip to content

Commit

Permalink
fix issue #1716
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-buerkner committed Dec 5, 2024
1 parent 2e5177a commit a0602ad
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 36 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ Package: brms
Encoding: UTF-8
Type: Package
Title: Bayesian Regression Models using 'Stan'
Version: 2.22.7
Date: 2024-11-26
Version: 2.22.8
Date: 2024-12-05
Authors@R:
c(person("Paul-Christian", "Bürkner", email = "paul.buerkner@gmail.com",
role = c("aut", "cre")),
Expand Down
32 changes: 9 additions & 23 deletions R/stan-likelihood.R
Original file line number Diff line number Diff line change
Expand Up @@ -354,13 +354,8 @@ stan_log_lik_add_se <- function(sigma, bterms, reqn = stan_log_lik_adj(bterms),

# multiply 'dpar' by the 'rate' denominator within the Stan likelihood
# @param log add the rate denominator on the log scale if sensible?
# @param req_dot_multiply Censoring may turn non-vectorized into vectorized
# statements later on (see stan_log_lik_cens) which then makes the * operator
# invalid and requires .* instead. Accordingly, req_dot_multiply should be
# FALSE if [n] is required only because of censoring.
stan_log_lik_multiply_rate_denom <- function(
dpar, bterms, reqn = stan_log_lik_adj(bterms),
req_dot_multiply = stan_log_lik_adj(bterms, c("trunc", "weights")),
log = FALSE, transform = NULL, threads = NULL, ...) {

dpar_transform <- dpar
Expand All @@ -381,7 +376,7 @@ stan_log_lik_multiply_rate_denom <- function(
# dpar without resp name or index
dpar_clean <- sub("(_|\\[).*", "", dpar)
is_pred <- dpar_clean %in% c("mu", names(bterms$dpars))
operator <- str_if(req_dot_multiply || !is_pred, "*", ".*")
operator <- ".*"
}
glue("{dpar_transform} {operator} {denom}")
}
Expand Down Expand Up @@ -631,13 +626,11 @@ stan_log_lik_binomial <- function(bterms, ...) {
stan_log_lik_beta_binomial <- function(bterms, ...) {
p <- stan_log_lik_dpars(bterms)
p$trials <- stan_log_lik_advars(bterms, "trials", ...)$trials
req_dot_multiply <- !stan_log_lik_adj(bterms) && is_pred_dpar(bterms, "phi")
multiply <- str_if(req_dot_multiply, " .* ", " * ")
sdist(
"beta_binomial",
p$trials,
paste0(p$mu, multiply, p$phi),
paste0("(1 - ", p$mu, ")", multiply, p$phi)
paste0(p$mu, " .* ", p$phi),
paste0("(1 - ", p$mu, ") .* ", p$phi)
)
}

Expand Down Expand Up @@ -665,11 +658,10 @@ stan_log_lik_com_poisson <- function(bterms, ...) {
}

stan_log_lik_gamma <- function(bterms, ...) {
reqn <- stan_log_lik_adj(bterms) || is_pred_dpar(bterms, "shape")
reqn <- stan_log_lik_adj(bterms)
p <- stan_log_lik_dpars(bterms, reqn = reqn)
# Stan uses shape-rate parameterization with rate = shape / mean
div_op <- str_if(reqn, " / ", " ./ ")
sdist("gamma", p$shape, paste0(p$shape, div_op, p$mu))
sdist("gamma", p$shape, paste0(p$shape, " ./ ", p$mu))
}

stan_log_lik_exponential <- function(bterms, ...) {
Expand All @@ -681,18 +673,14 @@ stan_log_lik_exponential <- function(bterms, ...) {
stan_log_lik_weibull <- function(bterms, ...) {
p <- stan_log_lik_dpars(bterms)
# Stan uses shape-scale parameterization for weibull
need_dot_div <- !stan_log_lik_adj(bterms) && is_pred_dpar(bterms, "shape")
div_op <- str_if(need_dot_div, " ./ ", " / ")
p$scale <- paste0(p$mu, div_op, "tgamma(1 + 1", div_op, p$shape, ")")
p$scale <- paste0(p$mu, " ./ tgamma(1 + 1 ./ ", p$shape, ")")
sdist("weibull", p$shape, p$scale)
}

stan_log_lik_frechet <- function(bterms, ...) {
p <- stan_log_lik_dpars(bterms)
# Stan uses shape-scale parameterization for frechet
need_dot_div <- !stan_log_lik_adj(bterms) && is_pred_dpar(bterms, "nu")
div_op <- str_if(need_dot_div, " ./ ", " / ")
p$scale <- paste0(p$mu, div_op, "tgamma(1 - 1", div_op, p$nu, ")")
p$scale <- paste0(p$mu, " ./ tgamma(1 - 1 ./ ", p$nu, ")")
sdist("frechet", p$nu, p$scale)
}

Expand Down Expand Up @@ -724,11 +712,9 @@ stan_log_lik_wiener <- function(bterms, ...) {

stan_log_lik_beta <- function(bterms, ...) {
p <- stan_log_lik_dpars(bterms)
req_dot_multiply <- !stan_log_lik_adj(bterms) && is_pred_dpar(bterms, "phi")
multiply <- str_if(req_dot_multiply, " .* ", " * ")
sdist("beta",
paste0(p$mu, multiply, p$phi),
paste0("(1 - ", p$mu, ")", multiply, p$phi)
paste0(p$mu, " .* ", p$phi),
paste0("(1 - ", p$mu, ") .* ", p$phi)
)
}

Expand Down
25 changes: 14 additions & 11 deletions tests/testthat/tests.stancode.R
Original file line number Diff line number Diff line change
Expand Up @@ -410,8 +410,8 @@ test_that("customized covariances appear in the Stan code", {
test_that("truncation appears in the Stan code", {
scode <- stancode(time | trunc(0) ~ age + sex + disease,
data = kidney, family = "gamma")
expect_match2(scode, "target += gamma_lpdf(Y[n] | shape, shape / mu[n]) -")
expect_match2(scode, "gamma_lccdf(lb[n] | shape, shape / mu[n]);")
expect_match2(scode, "target += gamma_lpdf(Y[n] | shape, shape ./ mu[n]) -")
expect_match2(scode, "gamma_lccdf(lb[n] | shape, shape ./ mu[n]);")

scode <- stancode(time | trunc(ub = 100) ~ age + sex + disease,
data = kidney, family = student("log"))
Expand Down Expand Up @@ -620,7 +620,7 @@ test_that("Stan code for multivariate models is correct", {
# multivariate weibull models
bform <- bform + weibull()
scode <- stancode(bform, dat)
expect_match2(scode, "weibull_lpdf(Y_g | shape_g, mu_g / tgamma(1 + 1 / shape_g));")
expect_match2(scode, "weibull_lpdf(Y_g | shape_g, mu_g ./ tgamma(1 + 1 ./ shape_g));")
})

test_that("Stan code for categorical models is correct", {
Expand Down Expand Up @@ -1419,6 +1419,9 @@ test_that("weighted, censored, and truncated likelihoods are correct", {
scode <- stancode(y | cens(x) ~ 1, dat, family = asym_laplace())
expect_match2(scode, "target += asym_laplace_lccdf(Y[n] | mu[n], sigma, quantile);")

scode <- stancode(bf(y | cens(x) ~ 1, shape ~ 1), dat, family = Gamma())
expect_match2(scode, "target += gamma_lpdf(Y[Jevent[1:Nevent]] | shape[Jevent[1:Nevent]], shape[Jevent[1:Nevent]] ./ mu[Jevent[1:Nevent]]);")

dat$x[1] <- 2
scode <- stancode(y | cens(x, y2) ~ 1, dat, family = asym_laplace())
expect_match2(scode, "target += log_diff_exp(\n")
Expand All @@ -1442,15 +1445,15 @@ test_that("weighted, censored, and truncated likelihoods are correct", {

expect_match2(
stancode(y | trials(y2) + weights(y2) ~ 1, dat, beta_binomial()),
"target += weights[n] * (beta_binomial_lpmf(Y[n] | trials[n], mu[n] * phi,"
"target += weights[n] * (beta_binomial_lpmf(Y[n] | trials[n], mu[n] .* phi,"
)
expect_match2(
stancode(y | trials(y2) + trunc(0, 30) ~ 1, dat, beta_binomial()),
"log_diff_exp(beta_binomial_lcdf(ub[n] | trials[n], mu[n] * phi,"
"log_diff_exp(beta_binomial_lcdf(ub[n] | trials[n], mu[n] .* phi,"
)
expect_match2(
stancode(y | trials(y2) + cens(x, y2) ~ 1, dat, beta_binomial()),
"beta_binomial_lcdf(rcens[n] | trials[n], mu[n] * phi,"
"beta_binomial_lcdf(rcens[n] | trials[n], mu[n] .* phi,"
)
})

Expand Down Expand Up @@ -1684,19 +1687,19 @@ test_that("Stan code of addition term 'rate' is correct", {
expect_match2(scode, "target += poisson_lpmf(Y | mu .* denom);")

scode <- stancode(y | rate(time) ~ x, data, negbinomial())
expect_match2(scode, "target += neg_binomial_2_log_lpmf(Y | mu + log_denom, shape * denom);")
expect_match2(scode, "target += neg_binomial_2_log_lpmf(Y | mu + log_denom, shape .* denom);")

bform <- bf(y | rate(time) ~ mi(x), shape ~ mi(x), family = negbinomial()) +
bf(x | mi() ~ 1, family = gaussian())
scode <- stancode(bform, data)
expect_match2(scode, "target += neg_binomial_2_log_lpmf(Y_y | mu_y + log_denom_y, shape_y .* denom_y);")

scode <- stancode(y | rate(time) ~ x, data, brmsfamily("negbinomial2"))
expect_match2(scode, "target += neg_binomial_2_log_lpmf(Y | mu + log_denom, inv(sigma) * denom);")
expect_match2(scode, "target += neg_binomial_2_log_lpmf(Y | mu + log_denom, inv(sigma) .* denom);")

scode <- stancode(y | rate(time) + cens(1) ~ x, data, geometric())
expect_match2(scode,
"target += neg_binomial_2_lpmf(Y[Jevent[1:Nevent]] | mu[Jevent[1:Nevent]] .* denom[Jevent[1:Nevent]], 1 * denom[Jevent[1:Nevent]]);"
"target += neg_binomial_2_lpmf(Y[Jevent[1:Nevent]] | mu[Jevent[1:Nevent]] .* denom[Jevent[1:Nevent]], 1 .* denom[Jevent[1:Nevent]]);"
)
})

Expand Down Expand Up @@ -1783,7 +1786,7 @@ test_that("Stan code of mixture model is correct", {
data = data, mixture(Gamma("log"), weibull))
expect_match(scode, "data \\{[^\\}]*real<lower=0,upper=1> theta1;")
expect_match(scode, "data \\{[^\\}]*real<lower=0,upper=1> theta2;")
expect_match2(scode, "ps[1] = log(theta1) + gamma_lpdf(Y[n] | shape1[n], shape1[n] / mu1[n]);")
expect_match2(scode, "ps[1] = log(theta1) + gamma_lpdf(Y[n] | shape1[n], shape1[n] ./ mu1[n]);")
expect_match2(scode, "target += weights[n] * log_sum_exp(ps);")

scode <- stancode(bf(abs(y) | se(c) ~ x), data = data,
Expand Down Expand Up @@ -2146,7 +2149,7 @@ test_that("Stan code for missing value terms works correctly", {
scode <- stancode(bform, dat)
expect_match2(scode, "vector<lower=0,upper=1>[Nmi_x] Ymi_x;")
expect_match2(scode,
"target += beta_lpdf(Y_x[Jevent_x[1:Nevent_x]] | mu_x[Jevent_x[1:Nevent_x]] * phi_x, (1 - mu_x[Jevent_x[1:Nevent_x]]) * phi_x);"
"target += beta_lpdf(Y_x[Jevent_x[1:Nevent_x]] | mu_x[Jevent_x[1:Nevent_x]] .* phi_x, (1 - mu_x[Jevent_x[1:Nevent_x]]) .* phi_x);"
)

# tests #1608
Expand Down

0 comments on commit a0602ad

Please sign in to comment.