library(tidyverse)
library(caret)
library(gridExtra)
library(grid)
We will use the twoClass dataset from Applied Predictive Modeling, the book by M. Kuhn and K. Johnson, to illustrate the most classical supervised classification algorithms. We will use some advanced R packages: the ggplot2 package for the figures and the caret package for the learning part. caret provides a unified interface to many other packages.
We first read the dataset and use ggplot2 to display it.
library(AppliedPredictiveModeling)
library(RColorBrewer)
data(twoClassData)

# Gather the two predictors and the class label into a single data frame.
twoClass <- cbind(as.data.frame(predictors), classes)

# One colour per class; the names let scale_*_manual() match factor levels.
twoClassColor <- setNames(brewer.pal(3, "Set1")[1:2], c("Class1", "Class2"))

nrow(twoClass)
## [1] 208
summary(twoClass)
## PredictorA PredictorB classes
## Min. :0.0236 Min. :0.0289 Class1:111
## 1st Qu.:0.1335 1st Qu.:0.1293 Class2: 97
## Median :0.2490 Median :0.2248
## Mean :0.2502 Mean :0.2360
## 3rd Qu.:0.3312 3rd Qu.:0.3016
## Max. :0.7060 Max. :0.7342
prop.table(table(twoClass$classes))
##
## Class1 Class2
## 0.5336538 0.4663462

# Scatter plot of the raw data, coloured by class.
ggplot(twoClass, aes(x = PredictorA, y = PredictorB)) +
  geom_point(aes(color = classes), size = 6, alpha = .5) +
  scale_colour_manual(name = "classes", values = twoClassColor) +
  scale_x_continuous(expand = c(0, 0)) +
  scale_y_continuous(expand = c(0, 0))
We create a few functions that will be useful to display our classifiers.
# Regular nbp x nbp grid spanning the observed range of both predictors;
# classifiers are evaluated on it to draw decision regions and boundaries.
nbp <- 250
# `length.out` spelled out: `length =` only worked through partial matching.
PredA <- seq(min(twoClass$PredictorA), max(twoClass$PredictorA), length.out = nbp)
PredB <- seq(min(twoClass$PredictorB), max(twoClass$PredictorB), length.out = nbp)
Grid <- expand.grid(PredictorA = PredA, PredictorB = PredB)
# Draw two side-by-side panels for a classifier evaluated on `Grid`:
#   left  - the decision region (grid tiles filled by predicted class),
#   right - the decision boundary (contour between classes) over the data.
# pred:  predicted classes, one per row of the global `Grid`.
# title: overall title placed above both panels.
# Reads the globals `twoClass`, `Grid` and `twoClassColor`.
PlotGrid <- function(pred, title) {
  gridPred <- cbind(Grid, classes = pred)
  base <- ggplot(data = twoClass,
                 aes(x = PredictorA, y = PredictorB, color = classes))
  # Layers appended at the end of both panels, in the original order.
  sharedTail <- list(
    theme(legend.text = element_text(size = 10)),
    scale_colour_manual(name = 'classes', values = twoClassColor),
    scale_x_continuous(expand = c(0, 0)),
    scale_y_continuous(expand = c(0, 0))
  )
  regionPanel <- base +
    geom_tile(data = gridPred, aes(fill = classes)) +
    scale_fill_manual(name = 'classes', values = twoClassColor) +
    ggtitle("Decision region") +
    sharedTail
  # Classes are coded 1/2 as numbers, so the 1.5 level line is the boundary.
  boundaryPanel <- base +
    geom_contour(data = gridPred, aes(z = as.numeric(classes)),
                 color = "red", breaks = c(1.5)) +
    geom_point(size = 4, alpha = .5) +
    ggtitle("Decision boundary") +
    sharedTail
  grid.arrange(regionPanel, boundaryPanel,
               top = textGrob(title, gp = gpar(fontsize = 20)), ncol = 2)
}
As explained in the introduction, we will use caret for the learning part. This package provides a unified interface to a huge number of classifiers available in R. It is a very powerful tool when exploring different models. In particular, it computes a resampling estimate of the accuracy and lets the user choose the specific resampling methodology. We will use a repeated V-fold strategy with \(10\) folds and \(4\) repetitions. We will reuse the same seed for each model to make sure that the same folds are used.
library("caret")
# Resampling scheme shared by every model: V-fold cross-validation repeated T times.
# NOTE(review): `T` masks R's shorthand for TRUE; the name is kept because
# other code in this document (the accuracy-bound computation) may read it.
V <- 10 # number of folds
T <- 4 # number of repetitions
TrControl <- trainControl(method = "repeatedcv", number = V, repeats = T)
# Seed reset before each train() call so that every model sees the same folds.
Seed <- 345
Finally, we provide a function that stores the accuracies for every resample and every model. We also compute an accuracy on the same data as the one used for training, in order to show the over-fitting phenomenon.
# Collect the performance rows of a fitted caret model.
# Model: result of caret::train(); Name: label stored in the `model` column.
# Returns one row of training-set (resubstitution) performance, tagged with
# Resample = "None", followed by the per-resample rows from Model$resample.
ErrsCaret <- function(Model, Name) {
  fittedClasses <- predict(Model, newdata = twoClass)
  trainPerf <- postResample(fittedClasses, twoClass[["classes"]])
  trainRow <- data.frame(t(trainPerf), Resample = "None", model = Name)
  resampleRows <- data.frame(Model$resample, model = Name)
  rbind(trainRow, resampleRows)
}
# Empty accumulator; rows produced by ErrsCaret() are rbind-ed onto it later.
Errs <- data.frame()
We are now ready to define a function that takes as input the current collection of accuracies, the name of a model, the corresponding formula and method, as well as additional parameters used to specify the model. It trains the model, displays its predictions in a figure and adds its errors to the collection.
# Train a caret model, plot its decision region/boundary and append its errors.
# Errs:    accumulated error data frame to extend.
# Name:    model label (plot title and `model` column).
# Formula: model formula as a string, e.g. "classes ~ .".
# Method:  caret method name, e.g. "knn".
# ...:     forwarded to caret::train() (e.g. tuneGrid).
# Returns (invisibly) Errs with this model's rows appended.
CaretLearnAndDisplay <- function (Errs, Name, Formula, Method, ...) {
  # Same seed before every train() call -> identical folds across models.
  set.seed(Seed)
  fit <- train(as.formula(Formula), data = twoClass, method = Method,
               trControl = TrControl, ...)
  gridPred <- predict(fit, newdata = Grid)
  PlotGrid(gridPred, Name)
  invisible(rbind(Errs, ErrsCaret(fit, Name)))
}
We can apply this function to any model available in caret. We will pick a few models and sort them according to the heuristic used to define them. We will distinguish models coming from a statistical point of view, in which one tries to estimate the conditional law and plug it into the Bayes classifier, from models coming from an optimization point of view, in which one tries to enforce a small training error by minimizing a relaxed criterion.
In the nearest neighbor method, we need to supply a parameter: \(k\), the number of neighbors used to define the kernel. We will visually compare the solutions obtained with a few values of \(k\).
# Fit and display one k-NN model per value of k, then bind all error rows once
# (instead of growing ErrsKNN with rbind() inside a loop).
KNNKS <- c(seq(1, 33, by = 4), seq(37, 85, by = 8), seq(101, 200, by = 8))
ErrsKNN <- do.call(rbind, lapply(KNNKS, function(k) {
  CaretLearnAndDisplay(data.frame(), sprintf("k-NN with k=%i", k),
                       "classes ~ .", "knn", tuneGrid = data.frame(k = k))
}))
Errs <- rbind(Errs, ErrsKNN)
We will now compare our models according to 4 criteria:
- Accuracy: the naive empirical accuracy computed on the same dataset as the one used to learn the classifier,
- AccuracyCV: the mean of the accuracies obtained by the resampling scheme,
- AccuracyCVInf: a lower bound on the mean accuracy obtained by subtracting from the mean twice the standard deviation divided by the square root of the number of resamples,
- AccuracyCVPAC: a high-probability bound on the accuracy obtained by subtracting the standard deviation from the mean.
# Per-model mean and standard deviation of Accuracy and Kappa over the rows
# of Errs (NA values ignored). Returns a plain data.frame with columns
# model, mAccuracy, mKappa, sdAccuracy, sdKappa.
ErrCaretAccuracy <- function(Errs) {
  perModel <- Errs %>%
    group_by(model) %>%
    dplyr::summarize(mAccuracy = mean(Accuracy, na.rm = TRUE),
                     mKappa = mean(Kappa, na.rm = TRUE),
                     sdAccuracy = sd(Accuracy, na.rm = TRUE),
                     sdKappa = sd(Kappa, na.rm = TRUE))
  # Drop the tibble class so that `x[, name]` indexing downstream yields vectors.
  as.data.frame(perModel)
}
# Summarise and plot the accuracy criteria of every model in Errs.
# Errs:   rows from ErrsCaret(); Resample == "None" marks the training-set row,
#         all other rows are resampling (cross-validation) results.
# simple: if TRUE, plot only Accuracy and AccuracyCV; otherwise plot all
#         criteria plus a +/- 2 standard-error bar around the CV mean.
# Returns a data.frame with columns AccuracyCV, AccuracyCVInf, AccuracyCVPAC,
# model and Accuracy (one row per model).
ErrAndPlotErrs <- function(Errs, simple = FALSE) {
  ErrsResampled <- dplyr::filter(Errs, Resample != "None")
  ErrCV <- ErrCaretAccuracy(ErrsResampled)
  # Standard error of the CV mean uses the actual number of non-missing
  # resamples per model, rather than the globals T * V (`T` masks TRUE and
  # silently equals 1 when not redefined).
  NbResamples <- dplyr::summarize(dplyr::group_by(ErrsResampled, model),
                                  nResamples = sum(!is.na(Accuracy)))
  ErrCV <- dplyr::left_join(ErrCV, NbResamples, by = "model")
  ErrCV <- dplyr::mutate(ErrCV, seAccuracy = sdAccuracy / sqrt(nResamples))
  ErrCVtmp <- transmute(ErrCV,
                        AccuracyCV = mAccuracy,
                        AccuracyCVInf = mAccuracy - 2 * seAccuracy,
                        AccuracyCVPAC = mAccuracy - sdAccuracy,
                        model = model)
  ErrEmp <- dplyr::select(dplyr::filter(Errs, Resample == "None"), Accuracy, model)
  Err <- dplyr::left_join(ErrCVtmp, ErrEmp, by = "model")
  if (simple) {
    print(ggplot(data = reshape2::melt(dplyr::select(Err, model, Accuracy, AccuracyCV),
                                       "model"),
                 aes(x = model, y = value, color = variable)) +
            geom_point(size = 5) +
            theme(axis.text.x = element_text(angle = 45, hjust = 1),
                  plot.margin = grid::unit(c(1, 1, .5, 1.5), "lines")))
  } else {
    print(ggplot(data = reshape2::melt(Err, "model"),
                 aes(x = model, y = value, color = variable)) +
            geom_errorbar(data = ErrCV,
                          aes(x = model,
                              ymin = mAccuracy - 2 * seAccuracy,
                              ymax = mAccuracy + 2 * seAccuracy,
                              y = mAccuracy),
                          color = "black") +
            geom_point(size = 5) +
            theme(axis.text.x = element_text(angle = 45, hjust = 1),
                  plot.margin = grid::unit(c(1, 1, .5, 1.5), "lines")))
  }
  Err
}
# Summarise and plot the accuracy criteria of all models collected so far.
Err <- ErrAndPlotErrs(Errs)
# For every criterion column of Err (all columns except `model`), print the
# model with the largest value ("Best method according to <crit> : <model>").
# Uses `[[` so it works for both data.frames and tibbles (with a tibble,
# `Err[, name]` is a one-column tibble and which.max() fails on it).
# Returns NULL invisibly; called for its printed output.
FindBestErr <- function(Err) {
  for (name in setdiff(names(Err), "model")) {
    ind <- which.max(Err[[name]])
    writeLines(strwrap(paste("Best method according to", name, ":",
                             Err[["model"]][ind])))
  }
  invisible(NULL)
}
# For each criterion, report the model that maximizes it (output shown below).
FindBestErr(Err)
## Best method according to AccuracyCV : k-NN with k=77
## Best method according to AccuracyCVInf : k-NN with k=77
## Best method according to AccuracyCVPAC : k-NN with k=77
## Best method according to Accuracy : k-NN with k=1
We now focus on the choice of the neighborhood size \(k\) in the \(k\) nearest neighbor method.
# k-NN models only; simple = TRUE plots just Accuracy and AccuracyCV.
ErrKNN <- ErrAndPlotErrs(ErrsKNN, simple = TRUE)
FindBestErr(ErrKNN)
## Best method according to AccuracyCV : k-NN with k=77
## Best method according to AccuracyCVInf : k-NN with k=77
## Best method according to AccuracyCVPAC : k-NN with k=77
## Best method according to Accuracy : k-NN with k=1
# Same models, full plot: all criteria plus error bars around the CV mean.
ErrKNN <- ErrAndPlotErrs(ErrsKNN)
FindBestErr(ErrKNN)
## Best method according to AccuracyCV : k-NN with k=77
## Best method according to AccuracyCVInf : k-NN with k=77
## Best method according to AccuracyCVPAC : k-NN with k=77
## Best method according to Accuracy : k-NN with k=1