Skip to content

Commit

Permalink
added new theme (theme_drwhy) and new lookout
Browse files Browse the repository at this point in the history
  • Loading branch information
pbiecek committed Mar 9, 2019
1 parent b61bb67 commit a63d889
Show file tree
Hide file tree
Showing 25 changed files with 483 additions and 260 deletions.
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ S3method(print,individual_variable_effect)
export(individual_variable_effect)
export(install_shap)
export(shap)
export(theme_drwhy)
export(theme_drwhy_colors)
export(theme_drwhy_vertical)
import(ggplot2)
importFrom(reticulate,import)
importFrom(reticulate,py_install)
Expand Down
84 changes: 46 additions & 38 deletions R/plot_individual_variable_effect.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#' @param cols A vector of characters defining faceting groups on columns dimension. Possible values: 'label', 'id', 'ylevel'.
#' @param rows A vector of characters defining faceting groups on rows dimension. Possible values: 'label', 'id', 'ylevel'.
#' @param selected A vector of characters. If specified, then only selected classes are presented
#' @param vcolors named vector with colors
#' @param bar_width width of bars. By default 8
#'
#' @import ggplot2
#'
Expand All @@ -27,13 +29,18 @@
#' Y_train <- HR$status
#' x_train <- HR[ , -6]
#' set.seed(123)
#' model_rf <- randomForest(x = x_train, y = Y_train, ntree= 50)
#' model_rf <- randomForest(x = x_train, y = Y_train, ntree = 50)
#' p_function <- function(model, data) predict(model, newdata = data, type = "prob")
#'
#' ive_rf <- individual_variable_effect(model_rf, data = x_train, predict_function = p_function,
#' new_observation = x_train[1:2,], nsamples = 50)
#' plot(ive_rf)
#' }else{
#' pl1 <- plot(ive_rf, bar_width = 4)
#' pl2 <- plot(ive_rf, bar_width = 4, show_predcited = FALSE)
#' pl3 <- plot(ive_rf, bar_width = 4, show_predcited = FALSE, cols = c("id","ylevel"), rows = "label")
#' print(pl1)
#' print(pl2)
#' print(pl3)
#' } else {
#' print('Python testing environment is required.')
#' }
#' @method plot individual_variable_effect
Expand All @@ -43,82 +50,83 @@
#' @export
plot.individual_variable_effect <- function(x, ..., id = 1, digits = 2, rounding_function = round,
show_predcited = TRUE, show_attributions = TRUE,
cols = c("label", "id"), rows = "ylevel", selected = NULL) {
cols = c("label", "id"), rows = "ylevel", selected = NULL, bar_width=8,
vcolors = c(`-` = "#f05a71", `0` = "#371ea3", `+` = "#8bdcbe", X = "#371ea3", pred = "#371ea3")) {

`_id_` <- `_attribution_` <- `_sign_` <- `_vname_` <- `_varvalue_` <- NULL
`_yhat_mean_` <- `_yhat_` <- `_ext_vname_`<- NULL

`_yhat_mean_` <- `_yhat_` <- `_ext_vname_`<- `pretty_text` <- NULL

dfl <- c(list(x), list(...))
x <- do.call(rbind, dfl)
class(x) <- "data.frame"

# if selected is specified then select only these classess
if (!is.null(selected)) {
x <- x[x$`_ylevel_` %in% selected, ]
x <- x[x$`_ylevel_` %in% selected,]
}

# if id is specified then select only these observations
x <- x[x$`_id_` %in% id, ]
x <- x[x$`_id_` %in% id,]
values <- as.vector(x[1 , x$`_vname_`[1:(length(unique(x$`_vname_`)) * length(id))]])
names(values) <- unique(paste(x$`_vname_`, x$`_id_`))

for (i in 1:length(values)){
for (i in 1:length(values)) {
variable_i <- sub(" .*", "", names(values)[i])
id_i <- sub(".* ", "", names(values)[i])
values[i] <- x[x$`_vname_` == variable_i & x$`_id_` == id_i, ][1, variable_i]
values[i] <- x[x$`_vname_` == variable_i & x$`_id_` == id_i,][1, variable_i]
}
variable_values <- values[paste(x$`_vname_`, x$`_id_`)]
numeric_values <- sapply(variable_values, is.numeric)
variable_values[numeric_values] <- rounding_function(variable_values[numeric_values], digits)
x$`_varvalue_` <- t(variable_values)
x$`_vname_` <- reorder(x$`_vname_`, x$`_attribution_`, function(z) sum(abs(z)))
x$`_vname_` <- reorder(x$`_vname_`, x$`_attribution_`, function(z) sum(abs(z)))
x$`_ext_vname_` <- paste(x$`_vname_`, "=", x$`_varvalue_`)
x$`_ext_vname_` <- reorder(x$`_ext_vname_`, as.numeric(x$`_vname_`) * 0.001 + x$`_id_`, function(z) sum(z))
x$`_vname_id_` <- paste(x$`_id_`, x$`_vname_`)


if(show_predcited == TRUE) {
x$pretty_text <- paste0(" ", rounding_function(x$`_attribution_`, digits), " ")
if (show_predcited == TRUE) {
levels(x$`_ext_vname_`) <- c(levels(x$`_ext_vname_`), "_predicted_")
for(i in 1:length(id)){
x_pred <- x[id == i, ]
for (i in 1:length(id)) {
x_pred <- x[id == i,]
x_pred$`_ext_vname_` <- factor("_predicted_", levels = levels(x$`_ext_vname_`))
x_pred$`_attribution_` <- x_pred$`_yhat_`
x_pred$`_attribution_` <- x_pred$`_yhat_` - x_pred$`_yhat_mean_`
x_pred$pretty_text <- paste0(" ", rounding_function(x_pred$`_yhat_`, digits), " ")
x_pred$`_sign_` <- "pred"
x <- rbind(x, x_pred)
}
}

maybe_attributions <- if(show_attributions == TRUE){
geom_text(aes(label = rounding_function(`_attribution_`, digits)), nudge_x = 0.45, hjust="inward")
} else {
NULL
}

rows <- paste0("`_", rows, "_`")
rows <- paste(paste0("`_", rows, "_`"), collapse = "+")
cols <- paste(paste0("`_", cols, "_`"), collapse = "+")
grid_formula <- as.formula(paste(rows, "~", cols))

id_labeller <- function(value) paste0("id = ", value)
label_labeller <- function(value) {
if(length(unique(x$`_label_`)) > 1) return(value)
if (length(unique(x$`_label_`)) > 1) return(value)
""
}
}

ggplot(x, aes(x= `_ext_vname_`, xend=`_ext_vname_`,
yend = `_yhat_mean_`, y = `_yhat_mean_` + `_attribution_`,
color=`_sign_`)) +
geom_segment(arrow = arrow(length=unit(0.20,"cm"), ends="first", type = "closed")) +
scale_y_discrete(drop=FALSE) +
ylim(c(min(x$`_yhat_mean_` + x$`_attribution_`), max(x$`_yhat_mean_` + x$`_attribution_`))) +
geom_hline(aes(yintercept = `_yhat_mean_`)) +
maybe_attributions +
pl <- ggplot(x, aes(x = `_ext_vname_`,
y = `_yhat_mean_` + pmax(`_attribution_`, 0),
ymin = `_yhat_mean_`,
ymax = `_yhat_mean_` + `_attribution_`,
color = `_sign_`) ) +
geom_linerange(size = bar_width) +
geom_hline(aes(yintercept = `_yhat_mean_`), color = "#371ea3") +
facet_grid(grid_formula,
labeller = labeller(`_id_` = as_labeller(id_labeller), `_label_` = as_labeller(label_labeller))) +
scale_color_manual(values = c(`-` = "#d8b365", `0` = "#f5f5f5", `+` = "#5ab4ac",
X = "darkgrey", pred = "black")) +
coord_flip() + theme_minimal() + theme(legend.position="none") +
xlab("") + ylab("Shapley values") + ggtitle("Shapley values")
labeller = labeller(
`_id_` = as_labeller(id_labeller),
`_label_` = as_labeller(label_labeller)
)) +
scale_color_manual(values = vcolors) +
coord_flip() + theme_drwhy_vertical() + theme(legend.position = "none") +
xlab("") + ylab("Shapley values") + ggtitle("")

if (show_attributions) {
pl <- pl + geom_text(aes(label = pretty_text), hjust = 0)
}

pl

}
57 changes: 57 additions & 0 deletions R/theme_drwhy.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#' DrWhy Theme for ggplot objects
#'
#' @param n number of colors for color palette
#'
#' @return theme for ggplot2 objects
#' @export
#' @rdname theme_drwhy
theme_drwhy <- function() {
theme_bw(base_line_size = 0) %+replace%
theme(axis.ticks = element_blank(), legend.background = element_blank(),
legend.key = element_blank(), panel.background = element_blank(),
panel.border = element_blank(), strip.background = element_blank(),
plot.background = element_blank(), complete = TRUE,
legend.direction = "horizontal", legend.position = "top",
axis.line.y = element_line(color = "white"),
axis.ticks.y = element_line(color = "white"),
#axis.line = element_line(color = "#371ea3", size = 0.5, linetype = 1),
axis.title = element_text(color = "#371ea3"),
axis.text = element_text(color = "#371ea3", size = 10),
strip.text = element_text(color = "#371ea3", size = 12, hjust = 0),
panel.grid.major.y = element_line(color = "grey90", size = 0.5, linetype = 1),
panel.grid.minor.y = element_line(color = "grey90", size = 0.5, linetype = 1))

}

#' @export
#' @rdname theme_drwhy
theme_drwhy_vertical <- function() {
theme_bw(base_line_size = 0) %+replace%
theme(axis.ticks = element_blank(), legend.background = element_blank(),
legend.key = element_blank(), panel.background = element_blank(),
panel.border = element_blank(), strip.background = element_blank(),
plot.background = element_blank(), complete = TRUE,
legend.direction = "horizontal", legend.position = "top",
axis.line.x = element_line(color = "white"),
axis.ticks.x = element_line(color = "white"),
#axis.line = element_line(color = "#371ea3", size = 0.5, linetype = 1),
axis.title = element_text(color = "#371ea3"),
axis.text = element_text(color = "#371ea3", size = 10),
strip.text = element_text(color = "#371ea3", size = 12, hjust = 0),
panel.grid.major.x = element_line(color = "grey90", size = 0.5, linetype = 1),
panel.grid.minor.x = element_line(color = "grey90", size = 0.5, linetype = 1))

}


#' @export
#' @rdname theme_drwhy
theme_drwhy_colors <- function(n = 2) {
if (n == 1) return("#4378bf")
if (n == 2) return(c( "#4378bf", "#8bdcbe"))
if (n == 3) return(c( "#4378bf", "#f05a71", "#8bdcbe"))
if (n == 4) return(c( "#4378bf", "#f05a71", "#8bdcbe", "#ffa58c"))
if (n == 5) return(c( "#4378bf", "#f05a71", "#8bdcbe", "#ae2c87", "#ffa58c"))
if (n == 6) return(c( "#4378bf", "#46bac2", "#8bdcbe", "#ae2c87", "#ffa58c", "#f05a71"))
c( "#4378bf", "#46bac2", "#371ea3", "#8bdcbe", "#ae2c87", "#ffa58c", "#f05a71")[((0:(n-1)) %% 7) + 1]
}
Loading

0 comments on commit a63d889

Please sign in to comment.