[R-package] add a tree plotting function #6729

Open
wants to merge 9 commits into
base: master
Added review suggestions.
DiagrammeR in CI.
Error messages.
Default parameters.
Changed tests.
fboudry committed Dec 24, 2024
commit 757dc847288886287a76eb7f2d076e01d3104ac7
2 changes: 1 addition & 1 deletion .ci/test-r-package-windows.ps1
@@ -177,7 +177,7 @@ Write-Output "Done installing CMake"

Write-Output "Installing dependencies"
$packages = -join @(
"c('data.table', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'processx', 'R6', 'RhpcBLASctl', 'testthat'), ",
"c('data.table', 'DiagrammeR', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'processx', 'R6', 'RhpcBLASctl', 'testthat'), ",
"dependencies = c('Imports', 'Depends', 'LinkingTo')"
)
$params = -join @(
2 changes: 1 addition & 1 deletion R-package/DESCRIPTION
@@ -45,12 +45,12 @@ NeedsCompilation: yes
Biarch: true
VignetteBuilder: knitr
Suggests:
DiagrammeR,
knitr,
markdown,
processx,
RhpcBLASctl,
testthat,
jameslamb marked this conversation as resolved.
DiagrammeR
Depends:
R (>= 3.5)
Imports:
199 changes: 103 additions & 96 deletions R-package/R/lgb.plot.tree.R
@@ -1,15 +1,15 @@
#' @name lgb.plot.tree
#' @title Plot a single LightGBM tree using DiagrammeR.
#' @title Plot a single LightGBM tree.
#' @description The \code{lgb.plot.tree} function creates a DiagrammeR plot of a single LightGBM tree.
#' @param model a \code{lgb.Booster} object.
#' @param tree an integer specifying the tree to plot.
#' @param tree an integer specifying the tree to plot. This is 1-based, so e.g. a value of '7' means 'the 7th tree' (tree_index=6 in LightGBM's underlying representation).
#' @param rules a list of rules to replace the split values with feature levels.
Collaborator

I'm not totally convinced about this idea... it should be possible to recover the feature names from the model directly.

But before you remove this... can you please expand this doc and add examples and tests showing what this would look like? Right now, it's hard for me to understand what the content of rules is supposed to be.
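
To make the question above concrete, here is a minimal sketch of what `rules` might contain, inferred only from how the PR's helper indexes into it (`names(sort(rules[[feature_name]]))[as.numeric(x)]`). The feature name `colour`, its levels, and the standalone function `levels_to_names` are assumptions made for illustration, not part of the PR.

```r
# Hypothetical `rules`: one named integer vector per categorical feature,
# where the names are level labels and the values are their integer codes.
rules <- list(
    colour = c(red = 1L, green = 2L, blue = 3L)
)

# Mirrors the lookup done by the PR's helper: a split string such as "1||3"
# is broken on "||" and each index is replaced by the matching level name.
levels_to_names <- function(x, feature_name, rules) {
    lvls <- sort(rules[[feature_name]])
    idx <- lapply(strsplit(x, "||", fixed = TRUE), as.numeric)
    labels <- lapply(idx, function(i) names(lvls)[i])
    vapply(labels, paste, character(1L), collapse = "\n")
}

split_labels <- levels_to_names(c("1||3", "2"), "colour", rules)
```

Because the lookup is positional, a code of `0` would select nothing under this scheme, so whether the codes are 0-based or 1-based is something the documentation for `rules` would need to state.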

#'
#' @return
#' The \code{lgb.plot.tree} function creates a DiagrammeR plot.
#'
#' @details
#' The \code{lgb.plot.tree} function creates a DiagrammeR plot of a single LightGBM tree. The tree is extracted from the model and displayed as a directed graph. The nodes are labelled with the feature, split value, gain, cover and value. The edges are labelled with the decision type and split value. The nodes are styled with a rectangle shape and filled with a beige colour. Leaf nodes are styled with an oval shape and filled with a khaki colour. The graph is rendered using the dot layout with a left-to-right rank direction. The nodes are coloured dim gray with a filled style and a Helvetica font. The edges are coloured dim gray with a solid style, a 1.5 arrow size, a vee arrowhead and a Helvetica font.
#' The \code{lgb.plot.tree} function creates a DiagrammeR plot of a single LightGBM tree. The tree is extracted from the model and displayed as a directed graph. The nodes are labelled with the feature, split value, gain, cover and value. The edges are labelled with the decision type and split value.
#'
#' @examples
#' \donttest{
@@ -23,9 +23,7 @@
#' # define model parameters and build a single tree
#' params <- list(
#' objective = "regression",
#' metric = "l2",
#' min_data = 1L,
#' learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
@@ -43,142 +41,151 @@
#'
#' @export

# function to plot a single LightGBM tree using DiagrammeR
lgb.plot.tree <- function(model = NULL, tree = NULL, rules = NULL) {
Collaborator

Suggested change
lgb.plot.tree <- function(model = NULL, tree = NULL, rules = NULL) {
lgb.plot.tree <- function(model, tree, rules = NULL) {

I can't think of any situation where it would be ok for model or tree to be NULL, can you?

If not, let's please require callers to provide values explicitly.
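
A small sketch of the difference, using two hypothetical stand-in functions (`with_null_defaults` and `with_required_args` are illustrative names, not from the PR): with `NULL` defaults a forgotten argument slips through to later checks, whereas an argument with no default makes R raise an error as soon as it is used.

```r
# Stand-in with NULL defaults: omitting `model` is silently accepted.
with_null_defaults <- function(model = NULL, tree = NULL, rules = NULL) {
    list(model = model, tree = tree, rules = rules)
}

# Stand-in with required arguments: omitting `model` is an error.
with_required_args <- function(model, tree, rules = NULL) {
    list(model = model, tree = tree, rules = rules)
}

silent <- with_null_defaults(tree = 1L)
err <- tryCatch(with_required_args(tree = 1L), error = conditionMessage)
```

Here `silent$model` is `NULL`, while `err` holds R's own "argument is missing" message.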

# check model is lgb.Booster
if (!inherits(model, "lgb.Booster")) {
stop("model: Has to be an object of class lgb.Booster")
if (!.is_Booster(x = model)) {
stop("lgb.plot.tree: model should be an ", sQuote("lgb.Booster"))
}
# check DiagrammeR is available
if (!requireNamespace("DiagrammeR", quietly = TRUE)) {
stop("DiagrammeR package is required for lgb.plot.tree",
stop("lgb.plot.tree: DiagrammeR package is required",
call. = FALSE
)
}
# tree must be numeric
if (!inherits(tree, "numeric")) {
stop("tree: Has to be an integer numeric")
stop("lgb.plot.tree: Has to be an integer numeric")
}
# tree must be integer
if (tree %% 1 != 0) {
stop("tree: Has to be an integer numeric")
stop("lgb.plot.tree: Has to be an integer numeric")
}
# extract data.table model structure
dt <- lgb.model.dt.tree(model)
modelDT <- lgb.model.dt.tree(model)
# check that tree is less than or equal to the maximum tree index in the model
if (tree > max(dt$tree_index)) {
stop("tree: has to be less than the number of trees in the model")
if (tree > max(modelDT$tree_index)) {
stop("lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (", max(modelDT$tree_index), "). Got: ", tree, ".")
}
Collaborator

Please modify this error message so that it has enough information for someone to quickly debug the issue, like the provided value of tree and the number of trees in the model. And please combine it with the other check that the value is `>= 1`.

Something like this:

lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (125). Got: 181.
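
One way the combined check could look, as a sketch only: the helper name `check_tree_index` and its `num_trees` argument are hypothetical, and the number of trees is passed in directly rather than derived from the model. If `tree` is 1-based as the docs describe while `tree_index` is 0-based, the count would presumably be `max(tree_index) + 1`.

```r
# Hypothetical helper: a single check covering "is a whole number",
# "is at least 1" and "is at most the number of trees".
check_tree_index <- function(tree, num_trees) {
    # is.numeric() accepts both integer and double input
    is_whole <- is.numeric(tree) && length(tree) == 1L && !is.na(tree) && tree %% 1 == 0
    if (!is_whole || tree < 1 || tree > num_trees) {
        stop(
            "lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model ("
            , num_trees
            , "). Got: "
            , tree
            , "."
            , call. = FALSE
        )
    }
    return(invisible(TRUE))
}

msg <- tryCatch(check_tree_index(181, 125), error = conditionMessage)
```

With these inputs `msg` reproduces the example message above, and a valid value such as `check_tree_index(7L, 125)` passes silently.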

# filter dt to just the rows for the selected tree
dt <- dt[tree_index == tree, ]
# filter modelDT to just the rows for the selected tree
modelDT <- modelDT[tree_index == tree, ]
# change the column names to shorter more diagram friendly versions
data.table::setnames(dt, old = c("tree_index", "split_feature", "threshold", "split_gain"), new = c("Tree", "Feature", "Split", "Gain"))
dt[, Value := 0.0]
dt[, Value := leaf_value]
dt[is.na(Value), Value := internal_value]
dt[is.na(Gain), Gain := leaf_value]
dt[is.na(Feature), Feature := "Leaf"]
dt[, Cover := internal_count][Feature == "Leaf", Cover := leaf_count]
dt[, c("leaf_count", "internal_count", "leaf_value", "internal_value") := NULL]
dt[, Node := split_index]
max_node <- max(dt[["Node"]], na.rm = TRUE)
dt[is.na(Node), Node := max_node + leaf_index + 1]
dt[, ID := paste(Tree, Node, sep = "-")]
dt[, c("depth", "leaf_index") := NULL]
dt[, parent := node_parent][is.na(parent), parent := leaf_parent]
dt[, c("node_parent", "leaf_parent", "split_index") := NULL]
dt[, Yes := dt$ID[match(dt$Node, dt$parent)]]
dt <- dt[nrow(dt):1, ]
dt[, No := dt$ID[match(dt$Node, dt$parent)]]
data.table::setnames(modelDT, old = c("tree_index", "split_feature", "threshold", "split_gain"), new = c("Tree", "Feature", "Split", "Gain"))
modelDT[, Value := 0.0]
modelDT[, Value := leaf_value]
modelDT[is.na(Value), Value := internal_value]
modelDT[is.na(Gain), Gain := leaf_value]
modelDT[is.na(Feature), Feature := "Leaf"]
modelDT[, Cover := internal_count][Feature == "Leaf", Cover := leaf_count]
modelDT[, c("leaf_count", "internal_count", "leaf_value", "internal_value") := NULL]
modelDT[, Node := split_index]
max_node <- max(modelDT[["Node"]], na.rm = TRUE)
modelDT[is.na(Node), Node := max_node + leaf_index + 1]
modelDT[, ID := paste(Tree, Node, sep = "-")]
modelDT[, c("depth", "leaf_index") := NULL]
modelDT[, parent := node_parent][is.na(parent), parent := leaf_parent]
modelDT[, c("node_parent", "leaf_parent", "split_index") := NULL]
modelDT[, Yes := modelDT$ID[match(modelDT$Node, modelDT$parent)]]
modelDT <- modelDT[nrow(modelDT):1, ]
modelDT[, No := modelDT$ID[match(modelDT$Node, modelDT$parent)]]
# which way do the NA's go (this path will get a thicker arrow)
# for categorical features, NA gets put into the zero group
dt[default_left == TRUE, Missing := Yes]
dt[default_left == FALSE, Missing := No]
zero_present <- function(x) {
sapply(strsplit(as.character(x), "||", fixed = TRUE), function(el) {
any(el == "0")
})
}
dt[zero_present(Split), Missing := Yes]
# dt[, c('parent', 'default_left') := NULL]
# data.table::setcolorder(dt, c('Tree','Node','ID','Feature','decision_type','Split','Yes','No','Missing','Gain','Cover','Value'))
modelDT[default_left == TRUE, Missing := Yes]
modelDT[default_left == FALSE, Missing := No]
modelDT[.zero_present(Split), Missing := Yes]
# modelDT[, c('parent', 'default_left') := NULL]
# data.table::setcolorder(modelDT, c('Tree','Node','ID','Feature','decision_type','Split','Yes','No','Missing','Gain','Cover','Value'))
# create the label text
dt[, label := paste0(
modelDT[, label := paste0(
Feature,
"\nCover: ", Cover,
ifelse(Feature == "Leaf", "", "\nGain: "), ifelse(Feature == "Leaf", "", round(Gain, 4)),
"\nValue: ", round(Value, 4)
)]
# style the nodes - same format as xgboost
dt[Node == 0, label := paste0("Tree ", Tree, "\n", label)]
dt[, shape := "rectangle"][Feature == "Leaf", shape := "oval"]
dt[, filledcolor := "Beige"][Feature == "Leaf", filledcolor := "Khaki"]
modelDT[Node == 0, label := paste0("Tree ", Tree, "\n", label)]
modelDT[, shape := "rectangle"][Feature == "Leaf", shape := "oval"]
modelDT[, filledcolor := "Beige"][Feature == "Leaf", filledcolor := "Khaki"]
# in order to draw the first tree on top:
dt <- dt[order(-Tree)]
modelDT <- modelDT[order(-Tree)]
nodes <- DiagrammeR::create_node_df(
n = nrow(dt),
ID = dt$ID,
label = dt$label,
fillcolor = dt$filledcolor,
shape = dt$shape,
data = dt$Feature,
fontcolor = "black"
n = nrow(modelDT)
, ID = modelDT$ID
, label = modelDT$label
, fillcolor = modelDT$filledcolor
, shape = modelDT$shape
, data = modelDT$Feature
, fontcolor = "black"
)
# round the edge labels to 4 s.f. if they are numeric
# as otherwise get too many decimal places and the diagram looks bad
# would rather not use suppressWarnings
numeric_idx <- suppressWarnings(!is.na(as.numeric(dt[["Split"]])))
dt[numeric_idx, Split := round(as.numeric(Split), 4)]
numeric_idx <- suppressWarnings(!is.na(as.numeric(modelDT[["Split"]])))
modelDT[numeric_idx, Split := round(as.numeric(Split), 4)]
# replace indices with feature levels if rules supplied
levels.to.names <- function(x, feature_name, rules) {
lvls <- sort(rules[[feature_name]])
result <- strsplit(x, "||", fixed = TRUE)
result <- lapply(result, as.numeric)
levels_to_names <- function(x) {
names(lvls)[as.numeric(x)]
}
result <- lapply(result, levels_to_names)
result <- lapply(result, paste, collapse = "\n")
result <- as.character(result)
}

if (!is.null(rules)) {
for (f in names(rules)) {
dt[Feature == f & decision_type == "==", Split := levels.to.names(Split, f, rules)]
modelDT[Feature == f & decision_type == "==", Split := .levels.to.names(Split, f, rules)]
}
}
# replace long split names with a message
dt[nchar(Split) > 500, Split := "Split too long to render"]
modelDT[nchar(Split) > 500, Split := "Split too long to render"]
# create the edge labels
edges <- DiagrammeR::create_edge_df(
from = match(dt[Feature != "Leaf", c(ID)] %>% rep(2), dt$ID),
to = match(dt[Feature != "Leaf", c(Yes, No)], dt$ID),
label = dt[Feature != "Leaf", paste(decision_type, Split)] %>%
c(rep("", nrow(dt[Feature != "Leaf"]))),
style = dt[Feature != "Leaf", ifelse(Missing == Yes, "bold", "solid")] %>%
c(dt[Feature != "Leaf", ifelse(Missing == No, "bold", "solid")]),
from = match(modelDT[Feature != "Leaf", c(ID)] %>% rep(2), modelDT$ID),
to = match(modelDT[Feature != "Leaf", c(Yes, No)], modelDT$ID),
label = modelDT[Feature != "Leaf", paste(decision_type, Split)] %>%
c(rep("", nrow(modelDT[Feature != "Leaf"]))),
style = modelDT[Feature != "Leaf", ifelse(Missing == Yes, "bold", "solid")] %>%
c(modelDT[Feature != "Leaf", ifelse(Missing == No, "bold", "solid")]),
rel = "leading_to"
)
# create the graph
graph <- DiagrammeR::create_graph(
nodes_df = nodes,
edges_df = edges,
attr_theme = NULL
) %>%
DiagrammeR::add_global_graph_attrs(
attr_type = "graph",
attr = c("layout", "rankdir"),
value = c("dot", "LR")
) %>%
DiagrammeR::add_global_graph_attrs(
attr_type = "node",
attr = c("color", "style", "fontname"),
value = c("DimGray", "filled", "Helvetica")
) %>%
DiagrammeR::add_global_graph_attrs(
attr_type = "edge",
attr = c("color", "arrowsize", "arrowhead", "fontname"),
value = c("DimGray", "1.5", "vee", "Helvetica")
nodes_df = nodes
, edges_df = edges
, attr_theme = NULL
)
graph <- DiagrammeR::add_global_graph_attrs(
graph = graph
, attr_type = "graph"
, attr = c("layout", "rankdir")
, value = c("dot", "LR")
)
graph <- DiagrammeR::add_global_graph_attrs(
graph = graph
, attr_type = "node"
, attr = c("color", "style", "fontname")
, value = c("DimGray", "filled", "Helvetica")
)
graph <- DiagrammeR::add_global_graph_attrs(
graph = graph
, attr_type = "edge"
, attr = c("color", "arrowsize", "arrowhead", "fontname")
, value = c("DimGray", "1.5", "vee", "Helvetica")
)
# render the graph
DiagrammeR::render_graph(graph)
return(invisible(NULL))
}

.zero_present <- function(x) {
# TRUE for each split whose "||"-separated category list contains "0"
has_zero <- sapply(strsplit(as.character(x), "||", fixed = TRUE), function(el) {
any(el == "0")
})
return(has_zero)
}

.levels.to.names <- function(x, feature_name, rules) {
Collaborator

Suggested change
.levels.to.names <- function(x, feature_name, rules) {
.levels_to_names <- function(x, feature_name, rules) {

Please avoid using . in any of these private functions' names.

lvls <- sort(rules[[feature_name]])
result <- strsplit(x, "||", fixed = TRUE)
result <- lapply(result, as.numeric)
result <- lapply(result, .levels_to_names, lvls = lvls)
result <- lapply(result, paste, collapse = "\n")
return(as.character(result))
}

# map numeric level indices to the names of the sorted levels
.levels_to_names <- function(x, lvls) {
return(names(lvls)[as.numeric(x)])
}
44 changes: 14 additions & 30 deletions R-package/tests/testthat/test_lgb.plot.tree.R
@@ -2,58 +2,42 @@ test_that("lgb.plot.tree works as expected"){
data(agaricus.train, package = "lightgbm")
train <- agaricus.train
dtrain <- lgb.Dataset(train$data, label = train$label)
data(agaricus.test, package = "lightgbm")
test <- agaricus.test
dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
# define model parameters and build a single tree
params <- list(
objective = "regression"
, metric = "l2"
, min_data = 1L
, learning_rate = 1.0
)
valids <- list(test = dtest)
model <- lgb.train(
params = params
params = list(
objective = "regression"
, num_threads = .LGB_MAX_THREADS
)
, data = dtrain
, nrounds = 1L
, valids = valids
, early_stopping_rounds = 1L
, verbose = .LGB_VERBOSITY
)
# plot the tree and compare to the tree table
# trees start from 0 in lgb.model.dt.tree
tree_table <- lgb.model.dt.tree(model)
expect_true({
lgb.plot.tree(model, 0)TRUE
})
lgb.plot.tree(model, 0)
}, regexp = "lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model")
}

test_that("lgb.plot.tree fails when a non existing tree is selected"){
data(agaricus.train, package = "lightgbm")
train <- agaricus.train
dtrain <- lgb.Dataset(train$data, label = train$label)
data(agaricus.test, package = "lightgbm")
test <- agaricus.test
dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
# define model parameters and build a single tree
params <- list(
objective = "regression"
, metric = "l2"
, min_data = 1L
, learning_rate = 1.0
)
valids <- list(test = dtest)
model <- lgb.train(
params = params
params = list(
objective = "regression"
, num_threads = .LGB_MAX_THREADS
)
, data = dtrain
, nrounds = 1L
, valids = valids
, early_stopping_rounds = 1L
, verbose = .LGB_VERBOSITY
)
# plot the tree and compare to the tree table
# trees start from 0 in lgb.model.dt.tree
tree_table <- lgb.model.dt.tree(model)
expect_error({
lgb.plot.tree(model, 999)TRUE
})
lgb.plot.tree(model, 999)
}, regexp = "lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model")
}
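
For reference, `test_that()` takes the test body as its second argument, so a well-formed version of the tests above would follow the shape sketched here. This is a sketch under assumptions: `plot_tree_stub` is a hypothetical stand-in so the example runs without lightgbm or DiagrammeR, and it only imitates the range check.

```r
library(testthat)

# Hypothetical stand-in for the function under test.
plot_tree_stub <- function(tree, num_trees) {
    if (tree < 1 || tree > num_trees) {
        stop("Value of 'tree' should be between 1 and the total number of trees in the model")
    }
    return(invisible(NULL))
}

test_that("plotting fails when a non-existent tree is selected", {
    expect_error(
        plot_tree_stub(tree = 999, num_trees = 1)
        , regexp = "should be between 1 and the total number of trees"
    )
})

test_that("plotting succeeds for a valid tree", {
    expect_silent(plot_tree_stub(tree = 1, num_trees = 1))
})
```

Using `expect_error()` only for the failing case and a separate expectation for the success case keeps the two tests from asserting the same thing.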