Learn how to fit a random forest model
We will hindcast who survived the Titanic disaster
Understand OOB and variable importance in practice
library(Hmisc)
## Loading required package: lattice
## Loading required package: survival
## Loading required package: Formula
## Loading required package: ggplot2
##
## Attaching package: 'Hmisc'
## The following objects are masked from 'package:base':
##
## format.pval, units
library("caret")
## Warning: package 'caret' was built under R version 3.6.3
##
## Attaching package: 'caret'
## The following object is masked from 'package:survival':
##
## cluster
library("rpart")
library("tree")
library("e1071")
##
## Attaching package: 'e1071'
## The following object is masked from 'package:Hmisc':
##
## impute
library(ggplot2) # Data visualization
library(readr) # CSV file I/O, e.g. the read_csv function
library(randomForest)
## randomForest 4.6-14
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
##
## margin
set.seed(1)
This is a considerable shortcut: the data have already been cleaned and the derived variables (such as Title and FamilySize) constructed. This specific dataset was prepared by Thilaksha Silva. Please see THIS LINK for an exploratory analysis of the Titanic data.
# Read the prepared training and validation sets
mytrain = read.csv("titanictrain.csv")
mytest = read.csv("titanictest.csv")
# Pool the two sets; later experiments will resample train/validation splits from this
mytitanic = rbind(mytest, mytrain)
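Before fitting anything, a quick sanity check on the pooled data is cheap insurance. This sketch uses only base R and the columns already present in the prepared files:
dim(mytitanic)
table(mytitanic$Survived)        # class balance
colSums(is.na(mytitanic))        # remaining missing values, if any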
# Fitting Random Forest Classification to the Training set
set.seed(432)
myclassifier = randomForest(as.factor(Survived) ~ ., data = mytrain, importance=TRUE)
# Choosing the number of trees
plot(myclassifier,main="Prediction errors: OOB (black), Death (red), Survival (green) ")
The figure shows how the errors settle into a statistical equilibrium once the forest is large enough. The OOB error rate converges to around 17%. The prediction error for death is lower than the prediction error for survival. The default value of 500 trees in the randomForest function seems to be fine here, both for the OOB error rate and for the class-specific misclassification errors.
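The converged values can also be read directly from the fitted object rather than off the plot; err.rate is a standard component of a classification randomForest fit, with one row per tree and columns for the OOB error and each class. A minimal sketch:
# OOB and per-class error rates after all 500 trees
tail(myclassifier$err.rate, 1)
# Tree count at which the running OOB error is lowest
which.min(myclassifier$err.rate[, "OOB"])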
# Predicting the Validation set results
y_pred = predict(myclassifier, newdata = mytest[,-which(names(mytest)=="Survived")])
# Checking the prediction accuracy
table(mytest$Survived, y_pred) # Confusion matrix
## y_pred
## 0 1
## 0 103 7
## 1 18 50
error <- mean(mytest$Survived != y_pred) # Misclassification error
paste('Accuracy',round(1-error,4))
## [1] "Accuracy 0.8596"
Two possible measures of variable importance are returned by randomForest:
Measure 1 is computed by permuting the OOB data: for each tree, the prediction error on the OOB portion of the data is recorded (error rate for classification, MSE for regression). The same quantity is then computed after permuting each predictor variable. The differences between the two are averaged over all trees and normalized by the standard deviation of the differences. (If the standard deviation of the differences is 0 for a variable, the division is not done, but the average is almost always 0 in that case.)
Measure 2 is the total decrease in node impurity from splitting on the variable, averaged over all trees. For classification, node impurity is measured by the Gini index; for regression, by the residual sum of squares.
importance(myclassifier)
## 0 1 MeanDecreaseAccuracy MeanDecreaseGini
## X -1.497090 3.055049 0.7060918 44.417401
## Pclass 20.234098 33.052913 36.9043130 27.634134
## Sex 21.871535 20.195071 25.3327735 40.686634
## Age 11.190499 19.630063 22.6116289 39.278575
## SibSp 9.403132 2.029170 9.8102733 10.657857
## Parch 2.105338 4.159154 4.3428452 6.421137
## Fare 18.410177 22.180937 30.0149496 51.814841
## Embarked 3.778376 13.656844 12.6509713 9.133448
## Title 25.378065 23.487058 29.0391057 57.846637
## FamilySize 8.147329 12.278848 15.0706673 13.153078
In the table above, MeanDecreaseAccuracy corresponds to measure 1 and MeanDecreaseGini to measure 2.
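To rank the predictors rather than scanning the table, the importance matrix can be sorted; the type and scale arguments of importance() also give access to the raw (unscaled) permutation importance. A minimal sketch:
imp <- importance(myclassifier)
# Predictors ordered by permutation importance (measure 1)
imp[order(imp[, "MeanDecreaseAccuracy"], decreasing = TRUE), ]
# Unscaled permutation importance (not divided by its standard deviation)
importance(myclassifier, type = 1, scale = FALSE)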
Plot of variable importance
varImpPlot(myclassifier)
Why is Title important?
table(mytest$Title,mytest$Survived)
##
## 0 1
## Master 1 4
## Miss 5 23
## Mr 93 17
## Mrs 5 24
## Other 6 0
Oh well… Title is essentially a proxy for Sex (with Master additionally singling out young boys):
table(mytest$Title,mytest$Sex)
##
## female male
## Master 0 5
## Miss 28 0
## Mr 0 110
## Mrs 29 0
## Other 0 6
table(mytest$Sex,mytest$Survived)
##
## 0 1
## female 10 47
## male 100 21
hist(mytest$Fare)
median(mytest$Fare)
## [1] 15
table(mytest$Fare < 15,mytest$Survived)
##
## 0 1
## FALSE 43 46
## TRUE 67 22
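The counts suggest that passengers who paid more than the median fare survived at a noticeably higher rate. The raw counts become easier to compare when turned into row proportions; a small sketch using prop.table on the same validation-set cross-tabulations:
# Survival rates by sex, by title, and by fare group (rows sum to 1)
prop.table(table(mytest$Sex, mytest$Survived), margin = 1)
prop.table(table(mytest$Title, mytest$Survived), margin = 1)
prop.table(table(mytest$Fare < 15, mytest$Survived), margin = 1)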
# Remove the fixed split; dotitan() below resamples train/validation splits from the pooled data
rm(mytrain, mytest)
dotitan <- function(ntree, mtry, iopt){
  set.seed(432)
  acc = NULL
  for(i in 1:100){
    # Draw a fresh 50/50 training/validation split from the pooled data
    nrec = nrow(mytitanic)
    ntrain = ifelse(nrec %% 2 == 0, nrec/2, (nrec+1)/2)
    train = sample(1:nrec, ntrain, replace = FALSE)
    mytrain = mytitanic[train, ]
    mytest = mytitanic[-train, ]
    # Fitting Random Forest Classification to the Training set
    if(iopt == 0){
      myclassifier = randomForest(as.factor(Survived) ~ ., data = mytrain, ntree = 500)
    }
    if(iopt == 1){
      myclassifier = randomForest(as.factor(Survived) ~ ., ntree = ntree, data = mytrain)
    }
    if(iopt == 2){
      myclassifier = randomForest(as.factor(Survived) ~ ., ntree = ntree, mtry = mtry, data = mytrain)
    }
    # Predict response on test covariates
    y_pred = predict(myclassifier, newdata = mytest[, -which(names(mytest) == "Survived")])
    # Record the prediction accuracy
    error <- mean(mytest$Survived != y_pred)  # Misclassification error
    acc = c(acc, round(1 - error, 4))
  }
  return(acc)
}
h <- dotitan(100,4,0)
hist(h)
mean(h)
## [1] 0.827749
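The mean alone hides how much the accuracy moves from one random split to the next; a quick descriptive summary of the replicate accuracies in h:
sd(h)                            # spread across the 100 random splits
quantile(h, c(0.05, 0.5, 0.95))  # rough 90% range and median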
Now a different exercise: freeze the training/validation split, so that the only remaining randomness comes from fitting the forest itself.
iopt = 1
ntree=500
mtry=4
set.seed(432)
acc = NULL
# Draw one fixed training/validation split from the pooled data
nrec = nrow(mytitanic)
ntrain = ifelse(nrec %% 2 == 0, nrec/2, (nrec+1)/2)
train = sample(1:nrec, ntrain, replace = FALSE)
mytrain = mytitanic[train, ]
mytest = mytitanic[-train, ]
Run 50 replicates (for example)
for(i in 1:50){
  # Refit the forest on the fixed training set (new bootstrap samples and split variables each time)
  myclassifier = randomForest(as.factor(Survived) ~ ., ntree = ntree, data = mytrain)
  # Predict response on the fixed test covariates
  y_pred = predict(myclassifier, newdata = mytest[, -which(names(mytest) == "Survived")])
  # Record the prediction accuracy
  error <- mean(mytest$Survived != y_pred)  # Misclassification error
  acc = c(acc, round(1 - error, 4))
}
hist(acc, main = "Yep, RF is not very variable")
mean(acc)
## [1] 0.8365
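To make the contrast with the earlier experiment explicit, the two accuracy vectors can be compared directly: h varies both the split and the forest, while acc varies only the forest. A small sketch, assuming both vectors are still in the workspace:
sd(h)    # variability from the random split plus the forest
sd(acc)  # variability from the forest alone
boxplot(list(resampled_split = h, fixed_split = acc), ylab = "Validation accuracy")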