Nous disposons de données recueillies à Hewlett-Packards Labs, qui classe 4601 e-mails comme spam ou non spam. Ces données sont disponnibles sous la librarie kernlab du logiciel R. Il s’agit d’identifier les messages électroniques frauduleux à partir de leurs caractéristiques. Le donneur des e-mails c’est George Forman de HP Labs.
library(tidyverse)
library(kernlab)
library(caret)
data(spam)
La variable cible est la variabletype
. Il y a 57 variables explicatives indiquant la fréquence de certains mots et de caractères dans l’e-mail. Regarder l’aide avec la commande help(spam)
si vous avez besoin de plus d’information.
Calculons la proportion des classes :
#proportion estimée des classes
prop.table(table(spam$type))
##
## nonspam spam
## 0.6059552 0.3940448
39.4% des messages sont frauduleux et 60.6% ne le sont pas.
Régression logistique avec glmt
fit_glm <- glm(type ~.,data=spam,family=binomial)
fit_glm
##
## Call: glm(formula = type ~ ., family = binomial, data = spam)
##
## Coefficients:
## (Intercept) make address
## -1.569e+00 -3.895e-01 -1.458e-01
## all num3d our
## 1.141e-01 2.252e+00 5.624e-01
## over remove internet
## 8.830e-01 2.279e+00 5.696e-01
## order mail receive
## 7.343e-01 1.275e-01 -2.557e-01
## will people report
## -1.383e-01 -7.961e-02 1.447e-01
## addresses free business
## 1.236e+00 1.039e+00 9.599e-01
## email you credit
## 1.203e-01 8.131e-02 1.047e+00
## your font num000
## 2.419e-01 2.013e-01 2.245e+00
## money hp hpl
## 4.264e-01 -1.920e+00 -1.040e+00
## george num650 lab
## -1.177e+01 4.454e-01 -2.486e+00
## labs telnet num857
## -3.299e-01 -1.702e-01 2.549e+00
## data num415 num85
## -7.383e-01 6.679e-01 -2.055e+00
## technology num1999 parts
## 9.237e-01 4.651e-02 -5.968e-01
## pm direct cs
## -8.650e-01 -3.046e-01 -4.505e+01
## meeting original project
## -2.689e+00 -1.247e+00 -1.573e+00
## re edu table
## -7.923e-01 -1.459e+00 -2.326e+00
## conference charSemicolon charRoundbracket
## -4.016e+00 -1.291e+00 -1.881e-01
## charSquarebracket charExclamation charDollar
## -6.574e-01 3.472e-01 5.336e+00
## charHash capitalAve capitalLong
## 2.403e+00 1.199e-02 9.119e-03
## capitalTotal
## 8.437e-04
##
## Degrees of Freedom: 4600 Total (i.e. Null); 4543 Residual
## Null Deviance: 6170
## Residual Deviance: 1816 AIC: 1932
La fonction trainControl()
de caret permet de fixer les paramètres du processus d’apprentissage. Commencons doucement !
param_train <- trainControl(method="none")
Modèle logistique avec caret
tmp <- train(type~ ., data = spam, method="glm",trControl=param_train)
tmp
## Generalized Linear Model
##
## 4601 samples
## 57 predictor
## 2 classes: 'nonspam', 'spam'
##
## No pre-processing
## Resampling: None
tmp$finalModel
##
## Call: NULL
##
## Coefficients:
## (Intercept) make address
## -1.569e+00 -3.895e-01 -1.458e-01
## all num3d our
## 1.141e-01 2.252e+00 5.624e-01
## over remove internet
## 8.830e-01 2.279e+00 5.696e-01
## order mail receive
## 7.343e-01 1.275e-01 -2.557e-01
## will people report
## -1.383e-01 -7.961e-02 1.447e-01
## addresses free business
## 1.236e+00 1.039e+00 9.599e-01
## email you credit
## 1.203e-01 8.131e-02 1.047e+00
## your font num000
## 2.419e-01 2.013e-01 2.245e+00
## money hp hpl
## 4.264e-01 -1.920e+00 -1.040e+00
## george num650 lab
## -1.177e+01 4.454e-01 -2.486e+00
## labs telnet num857
## -3.299e-01 -1.702e-01 2.549e+00
## data num415 num85
## -7.383e-01 6.679e-01 -2.055e+00
## technology num1999 parts
## 9.237e-01 4.651e-02 -5.968e-01
## pm direct cs
## -8.650e-01 -3.046e-01 -4.505e+01
## meeting original project
## -2.689e+00 -1.247e+00 -1.573e+00
## re edu table
## -7.923e-01 -1.459e+00 -2.326e+00
## conference charSemicolon charRoundbracket
## -4.016e+00 -1.291e+00 -1.881e-01
## charSquarebracket charExclamation charDollar
## -6.574e-01 3.472e-01 5.336e+00
## charHash capitalAve capitalLong
## 2.403e+00 1.199e-02 9.119e-03
## capitalTotal
## 8.437e-04
##
## Degrees of Freedom: 4600 Total (i.e. Null); 4543 Residual
## Null Deviance: 6170
## Residual Deviance: 1816 AIC: 1932
Dans la suite on va utiliser le package caret et plusieurs méthodes pour predire la variable type
en fonction d’autres variables.
Hold-out (découpage en 2 sous-echantillons Dtrain et Dtest)
Imaginons que nous considereons 80% des lignes de notre tableau pour l’apprentissage.
set.seed(100)
ind_train <- createDataPartition(spam$type,p = 0.8,list = FALSE)
Nous utilisons ind_train
pour partitionner les données.
#sous-echantillon d'apprentissage et de test
Dtrain <- spam[ind_train,]; dim(Dtrain)
## [1] 3682 58
Dtest <- spam[-ind_train,]; dim(Dtest)
## [1] 919 58
table(Dtrain$type); prop.table(table(Dtrain$type))
##
## nonspam spam
## 2231 1451
##
## nonspam spam
## 0.6059207 0.3940793
table(Dtest$type); prop.table(table(Dtest$type))
##
## nonspam spam
## 557 362
##
## nonspam spam
## 0.6060936 0.3939064
Avant de tester des méthodes plus compliquées, c’est toujours une bonne idée d’estimer les performances de la méthode prédisant toujours la classe majoritaire (nonspam ici). Quelle est l’erreur de cette prédiction sur les données de tests?
Revenons avec le modèle logistique avec le package caret. On va apprendre un modèle avec les données Dtrain
fit_glm <- train(type~ ., data = Dtrain, method="glm",trControl=param_train)
Dans la suite, on va estimer l’erreur de prediction sur les données de test. Pour ce faire utilisons l’échantillon test et la fonction predict()
. Visualisons les commandes suivantes. Que peut-on-remarquer ?
head(predict(fit_glm, newdata = Dtest,type="prob"))
## nonspam spam
## 1 3.697804e-01 0.6302196
## 2 1.278400e-02 0.9872160
## 3 8.768356e-07 0.9999991
## 5 2.197250e-01 0.7802750
## 14 1.941361e-01 0.8058639
## 21 8.574695e-01 0.1425305
head(predict(fit_glm, newdata = Dtest))
## [1] spam spam spam spam spam nonspam
## Levels: nonspam spam
score_glm <- predict(fit_glm, newdata = Dtest,
type="prob")
head(score_glm)
## nonspam spam
## 1 3.697804e-01 0.6302196
## 2 1.278400e-02 0.9872160
## 3 8.768356e-07 0.9999991
## 5 2.197250e-01 0.7802750
## 14 1.941361e-01 0.8058639
## 21 8.574695e-01 0.1425305
#valeurs predites sur les données Dtest (obtenues pour un seuil fixé à 0.5)
pred_glm <- predict(fit_glm, newdata = Dtest)
#Matrice de confusion
table(predite=pred_glm,observee=Dtest$type)
## observee
## predite nonspam spam
## nonspam 530 46
## spam 27 316
#erreur estimé sur les données Dtest
mean(pred_glm!=Dtest$type)
## [1] 0.07943417
Nous avons les classes prédites en ligne, les observées en colonne. L’estimation de l’erreur sur les données de test vaut 0.079 (7.9%).
Autre metrique : Courbe ROC
matrixConf <- confusionMatrix(data=pred_glm,
reference=Dtest$type,positive="spam")
matrixConf$overall["Accuracy"]
## Accuracy
## 0.9205658
matrixConf
## Confusion Matrix and Statistics
##
## Reference
## Prediction nonspam spam
## nonspam 530 46
## spam 27 316
##
## Accuracy : 0.9206
## 95% CI : (0.9012, 0.9372)
## No Information Rate : 0.6061
## P-Value [Acc > NIR] : < 2e-16
##
## Kappa : 0.8321
##
## Mcnemar's Test P-Value : 0.03514
##
## Sensitivity : 0.8729
## Specificity : 0.9515
## Pos Pred Value : 0.9213
## Neg Pred Value : 0.9201
## Prevalence : 0.3939
## Detection Rate : 0.3439
## Detection Prevalence : 0.3732
## Balanced Accuracy : 0.9122
##
## 'Positive' Class : spam
##
roc_obj <- pROC::roc(Dtest$type=="spam",score_glm[,"spam"])
plot(1-roc_obj$specificities,roc_obj$sensitivities,type="l")
abline(0,1)
roc_obj$auc
## Area under the curve: 0.9712
On va dans la suite estimer l’erreur de prédiction à l’aide de la méthode V-Fold cross validation (V = numbre de blocs, voir les slides du cours).
Validation croisée avec V blocs (K-Fold Cross-validation)
Sous caret, il suffit de modifier le trainControl()
puis de relancer le processus de modélisation. Nous demandons une validation croisée (method=cv) avec (number=5) blocs (folds).
param_train <- trainControl(method="cv",number=5)
Dans la suite on va réutiliser les données complètes (data = spam
) et utiliser 5-Fold validation croisée pour estimer l’erreur de prediction.
set.seed(100)
fit_glm <- train(type ~ ., data = spam,method="glm",trControl=param_train)
Nous disposons du détail des résultats, l’accuracy pour chaque fold, avec le champ $resample
fit_glm$resample
## Accuracy Kappa Resample
## 1 0.9304348 0.8532578 Fold1
## 2 0.9282609 0.8492612 Fold2
## 3 0.9293478 0.8514788 Fold3
## 4 0.9239130 0.8387331 Fold4
## 5 0.9315961 0.8560091 Fold5
Note que l’accuracy par VC (validation croisée) vaut
mean(fit_glm$resample$Accuracy)
## [1] 0.9287105
On peut avoir le résultat directement en tappant directement la commande suivante
fit_glm
## Generalized Linear Model
##
## 4601 samples
## 57 predictor
## 2 classes: 'nonspam', 'spam'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 3681, 3681, 3681, 3681, 3680
## Resampling results:
##
## Accuracy Kappa
## 0.9287105 0.849748
Méthode k plus proches voisins avec caret (méthode knn)
#k plus proches voisins
set.seed(100)
k <- data.frame(k=c(7))
# methode knn (7 plus proches voisins)
fit_knn<- train(type ~ ., data = spam, method = "knn", trControl = param_train , tuneGrid = k)
fit_knn
## k-Nearest Neighbors
##
## 4601 samples
## 57 predictor
## 2 classes: 'nonspam', 'spam'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 3681, 3681, 3681, 3681, 3680
## Resampling results:
##
## Accuracy Kappa
## 0.7954827 0.56915
##
## Tuning parameter 'k' was held constant at a value of 7
Pourquoi la méthode de knn marche moins bien ici? Avez-vous une idée ?
#k plus proches voisins
set.seed(100)
k <- data.frame(k=c(3,5,7,9))
# methode knn
fit_knn<- train(type ~ ., data = spam, method = "knn", trControl = param_train , tuneGrid = k)
fit_knn
## k-Nearest Neighbors
##
## 4601 samples
## 57 predictor
## 2 classes: 'nonspam', 'spam'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 3681, 3681, 3681, 3681, 3680
## Resampling results across tuning parameters:
##
## k Accuracy Kappa
## 3 0.8011292 0.5823256
## 5 0.7993964 0.5781790
## 7 0.7957006 0.5696367
## 9 0.7898310 0.5582916
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was k = 3.
plot(fit_knn)
N’hesitez pas à toujours appliquer une méthode logistique (avec glm) et d’autres méthodes sur vos données. Chaque experience est differente est glm marche bien pour beaucoup des données. Après tout ca depend de votre problème métier. De la question et complexité du problème.
Autres méthodes
La liste des méthodes utilisables est enorme. Voir la site https://topepo.github.io/caret/train-models-by-tag.html
A vous de jouer ! Utiliser d’autres méthodes…Bon courage pour la suite