Skip to content

Commit

Permalink
Ensemble fixed for new SVM's
Browse files Browse the repository at this point in the history
  • Loading branch information
orionw committed Jan 8, 2019
1 parent 7c4fa4c commit 876ad80
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 17 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ export(GetModelComparisons)
export(GetModelWeights)
export(GetPredType)
export(GetPredictionsForStacking)
export(GetSVMScale)
export(GetTrainingInfo)
export(GetWeightsFromTestingSet)
export(MajorityVote)
Expand Down
23 changes: 16 additions & 7 deletions R/Ensembling.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#' @keywords internal
#' @export
StripPredictions <- function(pred) {
if (class(pred) == "list") {
if (class(pred)[[1]] == "list") {
realPredictions <- list(length = length(pred))
i = 0
for (model.pred in pred) {
Expand Down Expand Up @@ -84,7 +84,7 @@ predict.Ensemble <- function(object, newdata, voting.type="default", ...) {
i = 0
for (model in object$models) {
i = i + 1
preds[[i]] <- StripPredictions(predict(model, newdata = newdata, type="prob"))
preds[[i]] <- StripPredictions(GetPredType(model, newdata))
}

# input overrides set voting type
Expand Down Expand Up @@ -288,16 +288,25 @@ GetFactorEqual <- function(pred) {
GetWeightsFromTestingSet <- function(ensemble, df.train, test.set, train.type) {
i = 0
preds = list(length = length(ensemble$models))
for (model in ensemble$models) {
i = i + 1
preds[[i]] <- predict(model, df.train, type="prob")
}
fakeModelComp <- list()
class(fakeModelComp) <- "ModelComparison"
fakeModelComp$model.list <- ensemble$models
fakeModelComp$.multi.class <- FALSE
preds <- predict.ModelComparison(fakeModelComp, df.train, type="prob")
# names(preds) <- names(ensemble$models)

weights <- list(length = length(ensemble$models))
test.set <- GetFactorEqual(test.set)
i = 0
for (ind.pred in preds) {
i = i + 1
conf.matrix = caret::confusionMatrix(test.set, as.factor(round(ind.pred[, 1])))
# those tricky SVM's
if (names(preds)[[i]] == "svm.formula") {
pred <- GetSVMScale(ind.pred)
conf.matrix = caret::confusionMatrix(test.set, as.factor(pred))
} else {
conf.matrix = caret::confusionMatrix(test.set, as.factor(round(ind.pred[, 1])))
}
weights[[i]] = conf.matrix$byClass[train.type][[1]]
}
return(weights)
Expand Down
2 changes: 2 additions & 0 deletions R/ModelGeneration.R
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ predict.ModelComparison <- function(object, newdata, ...) {
i = i + 1
pred.basic[[i]] = GetPredType(model, newdata)
}
.getClass <- function(model) { return(class(model)[[1]]) }
names(pred.basic) <- sapply(object$model.list, .getClass)
return(pred.basic)
}
}
Expand Down
2 changes: 0 additions & 2 deletions R/Visualization.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ plot.ModelComparison <- function(object, labels, training.data = "none", predict
if (predictions == "empty") {
# Predictions not given - create them here from training data
pred.basic <- predict(object, newdata=training.data, type="prob")
.getClass <- function(model) { return(class(model)[[1]]) }
names(pred.basic) <- sapply(object$model.list, .getClass)
} else {
# use the given predictions
pred.basic = predictions
Expand Down
24 changes: 24 additions & 0 deletions man/GetSVMScale.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 9 additions & 1 deletion man/plot.ModelComparison.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 1 addition & 5 deletions tests/testthat/test_ensemble_integration.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@ context("Ensemble - Model Comparison Integration")
# load the libraries
library(BestModel)

# test_that("Ensemble plot", {
# # TODO: decide what to do with this
# })

test_that("Ensemble used in a ModelComparison", {
# prepare the dataset
iris <- PrepareIris()
Expand All @@ -21,5 +17,5 @@ test_that("Ensemble used in a ModelComparison", {
comp <- ModelComparison(mlist, F)
expect_equal(class(comp), "ModelComparison")
# make sure plot works
print(plot(comp, iris[, 5], iris[,1:4]))
plot(comp, iris[, 5], iris[,1:4])
})
7 changes: 5 additions & 2 deletions tests/testthat/test_package_integration.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,11 @@ test_that("Other packages used for ModelComparison", {
# multiple metrics, two word metrics, uncapitalized
plot(comp, titanic[, 1], titanic[, -1], plot.type=list("precision", "accuracy",
"recall", "detection rate"))
# verify we can make an Ensemble
ensem <- Ensemble(comp$model.list, "majorityWeight", titanic[, -1], titanic[, 1])
expect_equal(length(ensem$weight.list), length(models))


# ensem1 <- Ensemble(comp$model.list, "majorityWeight", iris[,1:4], iris[,5])
pred.ensem <- predict(ensem, titanic[, -1])
expect_equal(length(titanic[, 1]), length(pred.ensem))

})

0 comments on commit 876ad80

Please sign in to comment.