Nous disposons de données recueillies à Hewlett-Packards Labs, qui classe 4601 e-mails comme spam ou non spam. Ces données sont disponnibles sous la librarie kernlab du logiciel R. Il s’agit d’identifier les messages électroniques frauduleux à partir de leurs caractéristiques. Le donneur des e-mails c’est George Forman de HP Labs.

library(tidyverse)
library(kernlab)
library(caret)
data(spam)

La variable cible est la variabletype. Il y a 57 variables explicatives indiquant la fréquence de certains mots et de caractères dans l’e-mail. Regarder l’aide avec la commande help(spam) si vous avez besoin de plus d’information.

Calculons la proportion des classes :

#proportion estimée des classes 
prop.table(table(spam$type))
## 
##   nonspam      spam 
## 0.6059552 0.3940448

39.4% des messages sont frauduleux et 60.6% ne le sont pas.

Motivation : utilisation du package caret

Régression logistique avec glmt

fit_glm <- glm(type ~.,data=spam,family=binomial)
fit_glm
## 
## Call:  glm(formula = type ~ ., family = binomial, data = spam)
## 
## Coefficients:
##       (Intercept)               make            address  
##        -1.569e+00         -3.895e-01         -1.458e-01  
##               all              num3d                our  
##         1.141e-01          2.252e+00          5.624e-01  
##              over             remove           internet  
##         8.830e-01          2.279e+00          5.696e-01  
##             order               mail            receive  
##         7.343e-01          1.275e-01         -2.557e-01  
##              will             people             report  
##        -1.383e-01         -7.961e-02          1.447e-01  
##         addresses               free           business  
##         1.236e+00          1.039e+00          9.599e-01  
##             email                you             credit  
##         1.203e-01          8.131e-02          1.047e+00  
##              your               font             num000  
##         2.419e-01          2.013e-01          2.245e+00  
##             money                 hp                hpl  
##         4.264e-01         -1.920e+00         -1.040e+00  
##            george             num650                lab  
##        -1.177e+01          4.454e-01         -2.486e+00  
##              labs             telnet             num857  
##        -3.299e-01         -1.702e-01          2.549e+00  
##              data             num415              num85  
##        -7.383e-01          6.679e-01         -2.055e+00  
##        technology            num1999              parts  
##         9.237e-01          4.651e-02         -5.968e-01  
##                pm             direct                 cs  
##        -8.650e-01         -3.046e-01         -4.505e+01  
##           meeting           original            project  
##        -2.689e+00         -1.247e+00         -1.573e+00  
##                re                edu              table  
##        -7.923e-01         -1.459e+00         -2.326e+00  
##        conference      charSemicolon   charRoundbracket  
##        -4.016e+00         -1.291e+00         -1.881e-01  
## charSquarebracket    charExclamation         charDollar  
##        -6.574e-01          3.472e-01          5.336e+00  
##          charHash         capitalAve        capitalLong  
##         2.403e+00          1.199e-02          9.119e-03  
##      capitalTotal  
##         8.437e-04  
## 
## Degrees of Freedom: 4600 Total (i.e. Null);  4543 Residual
## Null Deviance:       6170 
## Residual Deviance: 1816  AIC: 1932

La fonction trainControl() de caret permet de fixer les paramètres du processus d’apprentissage. Commencons doucement !

param_train <- trainControl(method="none") 

Modèle logistique avec caret

tmp <- train(type~ ., data = spam, method="glm",trControl=param_train)
tmp
## Generalized Linear Model 
## 
## 4601 samples
##   57 predictor
##    2 classes: 'nonspam', 'spam' 
## 
## No pre-processing
## Resampling: None
tmp$finalModel
## 
## Call:  NULL
## 
## Coefficients:
##       (Intercept)               make            address  
##        -1.569e+00         -3.895e-01         -1.458e-01  
##               all              num3d                our  
##         1.141e-01          2.252e+00          5.624e-01  
##              over             remove           internet  
##         8.830e-01          2.279e+00          5.696e-01  
##             order               mail            receive  
##         7.343e-01          1.275e-01         -2.557e-01  
##              will             people             report  
##        -1.383e-01         -7.961e-02          1.447e-01  
##         addresses               free           business  
##         1.236e+00          1.039e+00          9.599e-01  
##             email                you             credit  
##         1.203e-01          8.131e-02          1.047e+00  
##              your               font             num000  
##         2.419e-01          2.013e-01          2.245e+00  
##             money                 hp                hpl  
##         4.264e-01         -1.920e+00         -1.040e+00  
##            george             num650                lab  
##        -1.177e+01          4.454e-01         -2.486e+00  
##              labs             telnet             num857  
##        -3.299e-01         -1.702e-01          2.549e+00  
##              data             num415              num85  
##        -7.383e-01          6.679e-01         -2.055e+00  
##        technology            num1999              parts  
##         9.237e-01          4.651e-02         -5.968e-01  
##                pm             direct                 cs  
##        -8.650e-01         -3.046e-01         -4.505e+01  
##           meeting           original            project  
##        -2.689e+00         -1.247e+00         -1.573e+00  
##                re                edu              table  
##        -7.923e-01         -1.459e+00         -2.326e+00  
##        conference      charSemicolon   charRoundbracket  
##        -4.016e+00         -1.291e+00         -1.881e-01  
## charSquarebracket    charExclamation         charDollar  
##        -6.574e-01          3.472e-01          5.336e+00  
##          charHash         capitalAve        capitalLong  
##         2.403e+00          1.199e-02          9.119e-03  
##      capitalTotal  
##         8.437e-04  
## 
## Degrees of Freedom: 4600 Total (i.e. Null);  4543 Residual
## Null Deviance:       6170 
## Residual Deviance: 1816  AIC: 1932

Dans la suite on va utiliser le package caret et plusieurs méthodes pour predire la variable type en fonction d’autres variables.

Hold-out (découpage en 2 sous-echantillons Dtrain et Dtest)

Imaginons que nous considereons 80% des lignes de notre tableau pour l’apprentissage.

set.seed(100) 
ind_train <- createDataPartition(spam$type,p = 0.8,list = FALSE) 

Nous utilisons ind_train pour partitionner les données.

#sous-echantillon d'apprentissage et de test
Dtrain <- spam[ind_train,]; dim(Dtrain)
## [1] 3682   58
Dtest <- spam[-ind_train,]; dim(Dtest)
## [1] 919  58
table(Dtrain$type); prop.table(table(Dtrain$type))
## 
## nonspam    spam 
##    2231    1451
## 
##   nonspam      spam 
## 0.6059207 0.3940793
table(Dtest$type); prop.table(table(Dtest$type))
## 
## nonspam    spam 
##     557     362
## 
##   nonspam      spam 
## 0.6060936 0.3939064

Avant de tester des méthodes plus compliquées, c’est toujours une bonne idée d’estimer les performances de la méthode prédisant toujours la classe majoritaire (nonspam ici). Quelle est l’erreur de cette prédiction sur les données de tests?

Revenons avec le modèle logistique avec le package caret. On va apprendre un modèle avec les données Dtrain

fit_glm <- train(type~ ., data = Dtrain, method="glm",trControl=param_train)

Dans la suite, on va estimer l’erreur de prediction sur les données de test. Pour ce faire utilisons l’échantillon test et la fonction predict(). Visualisons les commandes suivantes. Que peut-on-remarquer ?

head(predict(fit_glm, newdata = Dtest,type="prob"))
##         nonspam      spam
## 1  3.697804e-01 0.6302196
## 2  1.278400e-02 0.9872160
## 3  8.768356e-07 0.9999991
## 5  2.197250e-01 0.7802750
## 14 1.941361e-01 0.8058639
## 21 8.574695e-01 0.1425305
head(predict(fit_glm, newdata = Dtest))
## [1] spam    spam    spam    spam    spam    nonspam
## Levels: nonspam spam
score_glm <- predict(fit_glm, newdata = Dtest,
                     type="prob")
head(score_glm)
##         nonspam      spam
## 1  3.697804e-01 0.6302196
## 2  1.278400e-02 0.9872160
## 3  8.768356e-07 0.9999991
## 5  2.197250e-01 0.7802750
## 14 1.941361e-01 0.8058639
## 21 8.574695e-01 0.1425305
#valeurs predites sur les données Dtest (obtenues pour un seuil fixé à 0.5)
pred_glm <- predict(fit_glm, newdata = Dtest)
#Matrice de confusion
table(predite=pred_glm,observee=Dtest$type)
##          observee
## predite   nonspam spam
##   nonspam     530   46
##   spam         27  316
#erreur estimé sur les données Dtest
mean(pred_glm!=Dtest$type)
## [1] 0.07943417

Nous avons les classes prédites en ligne, les observées en colonne. L’estimation de l’erreur sur les données de test vaut 0.079 (7.9%).

Autre metrique : Courbe ROC

matrixConf <- confusionMatrix(data=pred_glm,
                              reference=Dtest$type,positive="spam")

matrixConf$overall["Accuracy"]
##  Accuracy 
## 0.9205658
matrixConf
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction nonspam spam
##    nonspam     530   46
##    spam         27  316
##                                           
##                Accuracy : 0.9206          
##                  95% CI : (0.9012, 0.9372)
##     No Information Rate : 0.6061          
##     P-Value [Acc > NIR] : < 2e-16         
##                                           
##                   Kappa : 0.8321          
##                                           
##  Mcnemar's Test P-Value : 0.03514         
##                                           
##             Sensitivity : 0.8729          
##             Specificity : 0.9515          
##          Pos Pred Value : 0.9213          
##          Neg Pred Value : 0.9201          
##              Prevalence : 0.3939          
##          Detection Rate : 0.3439          
##    Detection Prevalence : 0.3732          
##       Balanced Accuracy : 0.9122          
##                                           
##        'Positive' Class : spam            
## 
roc_obj <- pROC::roc(Dtest$type=="spam",score_glm[,"spam"])
plot(1-roc_obj$specificities,roc_obj$sensitivities,type="l") 
abline(0,1)

roc_obj$auc
## Area under the curve: 0.9712

On va dans la suite estimer l’erreur de prédiction à l’aide de la méthode V-Fold cross validation (V = numbre de blocs, voir les slides du cours).

Validation croisée avec V blocs (K-Fold Cross-validation)

Sous caret, il suffit de modifier le trainControl() puis de relancer le processus de modélisation. Nous demandons une validation croisée (method=cv) avec (number=5) blocs (folds).

param_train <- trainControl(method="cv",number=5) 

Dans la suite on va réutiliser les données complètes (data = spam) et utiliser 5-Fold validation croisée pour estimer l’erreur de prediction.

set.seed(100)
fit_glm <- train(type ~ ., data = spam,method="glm",trControl=param_train)

Nous disposons du détail des résultats, l’accuracy pour chaque fold, avec le champ $resample

fit_glm$resample
##    Accuracy     Kappa Resample
## 1 0.9304348 0.8532578    Fold1
## 2 0.9282609 0.8492612    Fold2
## 3 0.9293478 0.8514788    Fold3
## 4 0.9239130 0.8387331    Fold4
## 5 0.9315961 0.8560091    Fold5

Note que l’accuracy par VC (validation croisée) vaut

mean(fit_glm$resample$Accuracy)
## [1] 0.9287105

On peut avoir le résultat directement en tappant directement la commande suivante

fit_glm
## Generalized Linear Model 
## 
## 4601 samples
##   57 predictor
##    2 classes: 'nonspam', 'spam' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 3681, 3681, 3681, 3681, 3680 
## Resampling results:
## 
##   Accuracy   Kappa   
##   0.9287105  0.849748

Méthode k plus proches voisins avec caret (méthode knn)

#k plus proches voisins
set.seed(100)
k <- data.frame(k=c(7))
# methode knn (7 plus proches voisins)
fit_knn<- train(type ~ ., data = spam, method = "knn", trControl = param_train , tuneGrid = k)
fit_knn
## k-Nearest Neighbors 
## 
## 4601 samples
##   57 predictor
##    2 classes: 'nonspam', 'spam' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 3681, 3681, 3681, 3681, 3680 
## Resampling results:
## 
##   Accuracy   Kappa  
##   0.7954827  0.56915
## 
## Tuning parameter 'k' was held constant at a value of 7

Pourquoi la méthode de knn marche moins bien ici? Avez-vous une idée ?

#k plus proches voisins
set.seed(100)
k <- data.frame(k=c(3,5,7,9))
# methode knn 
fit_knn<- train(type ~ ., data = spam, method = "knn", trControl = param_train , tuneGrid = k)
fit_knn
## k-Nearest Neighbors 
## 
## 4601 samples
##   57 predictor
##    2 classes: 'nonspam', 'spam' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 3681, 3681, 3681, 3681, 3680 
## Resampling results across tuning parameters:
## 
##   k  Accuracy   Kappa    
##   3  0.8011292  0.5823256
##   5  0.7993964  0.5781790
##   7  0.7957006  0.5696367
##   9  0.7898310  0.5582916
## 
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was k = 3.
plot(fit_knn)

N’hesitez pas à toujours appliquer une méthode logistique (avec glm) et d’autres méthodes sur vos données. Chaque experience est differente est glm marche bien pour beaucoup des données. Après tout ca depend de votre problème métier. De la question et complexité du problème.

Autres méthodes

La liste des méthodes utilisables est enorme. Voir la site https://topepo.github.io/caret/train-models-by-tag.html

A vous de jouer ! Utiliser d’autres méthodes…Bon courage pour la suite