Skip to content

Commit

Permalink
Completed the basic functionality of the iprobit package.
Browse files Browse the repository at this point in the history
  • Loading branch information
Haziq Jamil committed Dec 9, 2017
1 parent c678609 commit b8f0c14
Show file tree
Hide file tree
Showing 26 changed files with 1,465 additions and 813 deletions.
15 changes: 4 additions & 11 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ S3method(as.data.frame,iprobitData)
S3method(fitted,iprobitMod)
S3method(iprobit,default)
S3method(iprobit,formula)
S3method(iprobit,ipriorKernel)
S3method(iprobit,iprobitMod)
S3method(logLik,iprobitMod)
S3method(plot,iprobitData)
Expand All @@ -12,39 +13,31 @@ S3method(predict,iprobitMod)
S3method(print,iprobitLowerBound)
S3method(print,iprobitMod)
S3method(print,iprobitMod_summary)
S3method(print,iprobitPredict)
S3method(print,iprobitPredict_quant)
S3method(summary,iprobitMod)
S3method(update,iprobitMod)
export(convert_prob)
export(expit)
export(gen_circle)
export(gen_mixture)
export(gen_spiral)
export(get_brier_score)
export(get_brier_scores)
export(get_coef_se_mult)
export(get_error_rate)
export(get_error_rates)
export(get_kernel)
export(get_hyperparam)
export(get_intercept)
export(get_lbs)
export(get_one.lam)
export(iplot_dec_bound)
export(iplot_error)
export(iplot_fitted)
export(iplot_lb)
export(iplot_predict)
export(iprobit)
export(iprobit_bin)
export(iprobit_mult)
export(iprobit_parallel)
export(is.iprobitData)
export(is.iprobitMod)
export(is.iprobitMod_bin)
export(is.iprobitMod_mult)
export(logit)
export(predict_quant)
export(quantile_prob)
export(sample_prob_mult)
import(ggplot2)
import(iprior)
importFrom(stats,coef)
Expand Down
141 changes: 141 additions & 0 deletions R/Accessor_functions.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
################################################################################
#
# iprobit: Binary and Multinomial Probit Regression with I-priors
# Copyright (C) 2017 Haziq Jamil
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
################################################################################

#' Accessor functions for \code{ipriorMod} objects.
#'
#' @param object An \code{ipriorMod} object.
#'
#' @name Accessors
NULL

#' @describeIn Accessors Obtain all of the hyperparameters.
#' @export
get_hyperparam <- function(object) {
check_and_get_iprobitMod(object)
object$param.full
}

#' @describeIn Accessors Obtain the intercept.
#' @export
get_intercept <- function(object, by.class = FALSE) {
res <- get_hyperparam(object)
res <- res[grep("Intercept", rownames(res)), , drop = FALSE]
if (!isTRUE(by.class)) {
res <- c(res)
if (is.common.intercept(object)) {
res <- res[1]
names(res) <- "Intercept"
} else {
names(res) <- get_names(object, "intercept")
}
}
res
}

get_alpha <- get_intercept

get_lambda <- function(object, by.class = FALSE) {
res <- get_hyperparam(object)
res <- res[grep("lambda", rownames(res)), , drop = FALSE]
if (!isTRUE(by.class)) {
res <- c(res)
if (is.common.intercept(object)) {
res <- res[1]
names(res) <- get_names(object, "lambda", FALSE)
} else {
names(res) <- get_names(object, "lambda", TRUE)
}
}
res
}

get_sd <- function(object) {
setNames(object$param.summ$S.D., rownames(object$param.summ))
}

get_sd_alpha <- function(object) {
res <- get_sd(object)
res[grep("Intercept", names(res))]
}

get_sd_lambda <- function(object) {
res <- get_sd(object)
res[grep("lambda", names(res))]
}



#' @export
get_error_rate <- function(x) x$fitted.values$train.error

#' @export
get_error_rates <- function(x) {
res <- x$error
names(res) <- seq_along(res)
res
}

#' @export
get_brier_score <- function(x) x$fitted.values$brier.score

#' @export
get_brier_scores <- function(x) {
res <- x$brier
names(res) <- seq_along(res)
res
}

get_m <- function(object) {
if (is.iprobitMod(object)) object <- object$ipriorKernel
length(object$y.levels)
}


#' Extract the variational lower bound
#'
#' @param object An object of class \code{ipriorProbit}.
#' @param ... This is not used here.
#'
#' @return The variational lower bound.
#' @export
logLik.iprobitMod <- function(object, ...) {
lb <- object$lower.bound[!is.na(object$lower.bound)]
lb <- lb[length(lb)]
class(lb) <- "iprobitLowerBound"
lb
}

#' @export
get_lbs <- function(x) x$lower.bound

#' @export
print.iprobitLowerBound <- function(x, ...) {
cat("Lower bound =", x)
}

get_theta <- function(object) object$theta

get_w <- function(object) object$w

get_y <- function(object) {
res <- factor(object$ipriorKernel$y)
levels(res) <- object$ipriorKernel$y.levels
res
}
66 changes: 49 additions & 17 deletions R/Plots.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@
#
################################################################################

# iplot_fitted()
# iplot_dec_bound()
# iplot_predict()
# iplot_lb()
# iplot_error()
# iplot_lb_and_error()

#' @export
plot.iprobitMod <- function(x, niter.plot = NULL, levels = NULL, ...) {
iplot_fitted(x)
Expand All @@ -27,10 +34,9 @@ plot.iprobitMod <- function(x, niter.plot = NULL, levels = NULL, ...) {
iplot_fitted <- function(object) {
list2env(object, environment())
list2env(ipriorKernel, environment())
list2env(model, environment())

probs <- fitted(object)$prob
if (isNystrom(ipriorKernel)) probs <- probs[order(Nystrom$Nys.samp), ]
# if (isNystrom(ipriorKernel)) probs <- probs[order(Nystrom$Nys.samp), ]
df.plot <- data.frame(probs, i = 1:n)
colnames(df.plot) <- c(y.levels, "i")
df.plot <- reshape2::melt(df.plot, id.vars = "i")
Expand All @@ -44,7 +50,7 @@ iplot_fitted <- function(object) {
}

#' @export
iplot_lb <- function(x, niter.plot = NULL, lab.pos = c("up", "down"), ...) {
iplot_lb <- function(x, niter.plot, lab.pos = c("up", "down"), ...) {
lae.check <- FALSE
extra.opt <- list(...)
if (isTRUE(extra.opt$lb.and.error)) {
Expand All @@ -61,7 +67,7 @@ iplot_lb <- function(x, niter.plot = NULL, lab.pos = c("up", "down"), ...) {

lb.original <- x$lower.bound
if (missing(niter.plot)) niter.plot <- seq_along(lb.original)
else if (length(niter.plot) == 1) niter.plot <- c(1, niter.plot)
else if (length(niter.plot) == 1) niter.plot <- seq_len(niter.plot)
lb <- lb.original[niter.plot]
plot.df <- data.frame(Iteration = niter.plot, lb = lb)
time.per.iter <- x$time$time / x$niter
Expand Down Expand Up @@ -191,7 +197,7 @@ iplot_dec_bound <- function(object, X.var = c(1, 2), col = "grey35", size = 0.8,
geom_point(data = points.df, aes(X1, X2, col = Class)) +
geom_contour(data = plot.df, aes(X1, X2, z = value, group = variable,
size = "Decision\nboundary"),
bins = 2, col = col, ...) +
binwidth = 0.5 + 1e-12, col = col, ...) +
coord_cartesian(xlim = mm[1, ], ylim = mm[2, ]) +
scale_colour_manual(values = c(iprior::gg_col_hue(m), "grey30")) +
scale_size_manual(values = size, name = NULL) +
Expand Down Expand Up @@ -230,18 +236,44 @@ iplot_predict <- function(object, X.var = c(1, 2), grid.len = 50,
}

iplot_predict_bin <- function(plot.df, points.df, x, y, m, dec.bound) {
ggplot() +
geom_raster(data = plot.df, aes(X1, X2, fill = class2), alpha = 0.5) +
p <- ggplot() +
geom_raster(data = plot.df, aes(X1, X2, fill = `2`), alpha = 0.5) +
scale_fill_gradient(low = "#F8766D", high = "#00BFC4", limits = c(0, 1)) +
# annotate(geom = "raster", x = plot.df[, 1], y = plot.df[, 2],
# alpha = 0.6 * plot.df[, 3], fill = "#F8766D") +
# annotate(geom = "raster", x = plot.df[, 1], y = plot.df[, 2],
# alpha = 0.6 * plot.df[, 4], fill = "#00BFC4") +
geom_point(data = points.df, aes(X1, X2, col = Class)) +
# col = "black", shape = 21, stroke = 0.8) +
coord_cartesian(xlim = x, ylim = y) +
guides(fill = FALSE) +
theme_bw()

# Add decision boundary ------------------------------------------------------
if (isTRUE(dec.bound)) {
p <- p +
geom_contour(data = plot.df, aes(X1, X2, z = `2`,
size = "Decision\nboundary"),
binwidth = 0.5 + 1e-12, col = "grey35",
linetype = "dashed") +
scale_size_manual(values = 0.8, name = NULL) +
scale_color_discrete(name = " Class") +
guides(fill = FALSE,
size = guide_legend(override.aes = list(linetype = 2, col = "grey35")),
col = guide_legend(order = 1, override.aes = list(linetype = 0))) +
theme(legend.key.width = unit(2, "line"))
} else {
p <- p + guides(fill = FALSE)
}

# Add points -----------------------------------------------------------------
p <- p +
ggrepel::geom_label_repel(data = points.df, segment.colour = "grey25",
box.padding = 0.9, show.legend = FALSE,
aes(X1, X2, col = Class, label = prob)) +
geom_point(data = points.df, aes(X1, X2, col = Class)) +
geom_point(data = subset(points.df, points.df$prob != ""), aes(X1, X2),
shape = 1, col = "grey25")

# Touch up plot and return ---------------------------------------------------
p + coord_cartesian(xlim = x, ylim = y)

}

iplot_predict_mult <- function(plot.df, points.df, x, y, m, dec.bound) {
Expand All @@ -262,15 +294,15 @@ iplot_predict_mult <- function(plot.df, points.df, x, y, m, dec.bound) {
alpha = alpha * plot.df[, class.ind[j]], fill = fill.col[j])
}

# Add points and decision boundary, and touch up remaining plot --------------
# Add decision boundary ------------------------------------------------------
if (isTRUE(dec.bound)) {
decbound.df <- reshape2::melt(plot.df, id.vars = c("X1", "X2"))
p <- p +
geom_contour(data = decbound.df, aes(X1, X2, z = value, group = variable,
col = variable,
size = "Decision\nboundary"),
# binwidth = 0.5 + 1e-12,
bins = 2,
binwidth = 0.5 + 1e-12,
# bins = 2,
linetype = "dashed") +
scale_size_manual(values = 0.8, name = NULL) +
scale_color_discrete(name = " Class") +
Expand Down Expand Up @@ -303,7 +335,7 @@ iplot_predict_mult <- function(plot.df, points.df, x, y, m, dec.bound) {
iplot_error <- function(x, niter.plot, plot.test = TRUE) {
if (x$niter < 2) stop("Nothing to plot.")
if (missing(niter.plot)) niter.plot <- seq_along(x$train.error)
else if (length(niter.plot) == 1) niter.plot <- c(1, niter.plot)
else if (length(niter.plot) == 1) niter.plot <- seq_len(niter.plot)

# Prepare plotting data frame ------------------------------------------------
plot.df <- data.frame(Iteration = niter.plot,
Expand Down Expand Up @@ -360,9 +392,9 @@ iplot_error <- function(x, niter.plot, plot.test = TRUE) {
theme(legend.position = "none")
}

iplot_lb_and_error <- function(x, niter.plot, lab.pos) {
iplot_lb_and_error <- function(x, niter.plot, lab.pos, plot.test = TRUE) {
suppressMessages(
p2 <- iplot_error(x, niter.plot) +
p2 <- iplot_error(x, niter.plot, plot.test) +
scale_x_continuous(
breaks = scales::pretty_breaks(n = min(5, ifelse(x$niter == 2, 1, x$niter)))
)
Expand Down
Loading

1 comment on commit b8f0c14

@haziqj
Copy link
Owner

@haziqj haziqj commented on b8f0c14 Apr 9, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue #3

Please sign in to comment.