Introduction

The regression (and later, classification) tree model

What should we do when we have no idea of the function that generates the data \(y=f(x)\)?

We can use polynomial regression (Monday’s lecture)

Or we can use another idea which is based on the following ‘decision’ model.

Here is the magic formula. We will construct an estimate of God’s function in the following manner:

\[ \hat{y} = \sum_i B_i(x) v_i \] where each \(B_i(x)\) is a boolean expression (counted as 1 when true and 0 when false), each \(v_i\) is a constant, and exactly one \(B_i\) is true for any given value \(x\) of the predictor variables.

We will construct each boolean expression by combining, with ‘and’, conditions of the form \(x < c\) or \(x > c\) on the individual predictors.

For example, if I have two predictors \(x,z\), I may construct the boolean expression

(x < 1.5) and (z > 3.)

This is a rectangle in the space \((x,z)\) of predictors.
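A minimal sketch of the formula above in R, for two predictors x and z. The three rectangles, their boolean expressions and the constants \(v_i\) are hypothetical, chosen so that exactly one \(B_i\) is true at every point. R counts TRUE as 1 and FALSE as 0 when multiplying, which is exactly what the sum needs.

```r
# Hypothetical partition of the (x, z) plane into 3 rectangles.
# Each B_i returns TRUE/FALSE; exactly one is TRUE for any (x, z).
B <- list(
  function(x, z) (x < 1.5) & (z > 3),    # B_1
  function(x, z) (x < 1.5) & !(z > 3),   # B_2: x < 1.5 and z <= 3
  function(x, z) !(x < 1.5)              # B_3: x >= 1.5
)
v <- c(10, 20, 30)  # hypothetical constant predicted in each rectangle

# y_hat = sum_i B_i(x, z) * v_i  (TRUE counts as 1, FALSE as 0)
yhat <- function(x, z) {
  total <- 0
  for (i in seq_along(B)) total <- total + B[[i]](x, z) * v[i]
  total
}

yhat(c(1, 1, 2), c(4, 2, 5))  # 10 20 30
```

Each point falls in exactly one rectangle, so the prediction is simply the constant attached to that rectangle.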

The sweet spot

As explained in the last lecture, we have to find the right model complexity

For regression trees, we can prune trees (cut off branches) to decrease complexity

We will use regularization to find the best complexity

It will be given by an optimal value of a hyperparameter (penalizing complexity)

We will use cross-validation for finding the sweet spot

Fitting a Regression Tree

Figure 8.1 of “Introduction to Statistical Learning” shows the initial part of a regression tree used to model the logarithm of a baseball player’s salary, using the number of years played and the number of hits in the previous year as predictors.

The Hitters data set is part of the ISLR package. The function tree, used to fit a regression tree, is part of the tree package. After loading the data, records with missing salaries are removed and only the variables Salary, Years and Hits are retained. Salary is then transformed to log(Salary).

library(ISLR)  #contains the Hitters data set
library(tree)  #contains functions for tree construction
baseball=Hitters[!is.na(Hitters$Salary),] #remove records with missing salary
baseb=baseball[,c("Salary","Years","Hits")] #keep only 3 variables
baseb$Salary=log(baseb$Salary) #model log(salary)

The function tree is used to construct the regression tree, which is then printed and plotted, and text is added to label the plot.

basebtree=tree(baseb) #construct the tree
basebtree #print the tree
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 263 207.200 5.927  
##    2) Years < 4.5 90  42.350 5.107  
##      4) Years < 3.5 62  23.010 4.892  
##        8) Hits < 114 43  17.150 4.727  
##         16) Hits < 40.5 5  10.400 5.511 *
##         17) Hits > 40.5 38   3.280 4.624 *
##        9) Hits > 114 19   2.069 5.264 *
##      5) Years > 3.5 28  10.130 5.583 *
##    3) Years > 4.5 173  72.710 6.354  
##      6) Hits < 117.5 90  28.090 5.998  
##       12) Years < 6.5 26   7.238 5.689 *
##       13) Years > 6.5 64  17.350 6.124  
##         26) Hits < 50.5 12   2.689 5.730 *
##         27) Hits > 50.5 52  12.370 6.215 *
##      7) Hits > 117.5 83  20.880 6.740 *
plot(basebtree) #plot the tree
text(basebtree) #label the tree

The following functions can be used to construct parts of the tree by hand.

n=dim(baseb)[1] #number of cases
M=dim(baseb)[2] #variables

SSE=function(data,pred,cut){
  #split the response data into two sets, according to pred<=cut and pred>cut
  data1=data[pred<=cut]
  data2=data[pred>cut]
  #mean of each set
  M1=mean(data1)
  M2=mean(data2)
  #error sum of squares for each set
  SSE1=sum((data1-M1)^2)
  SSE2=sum((data2-M2)^2)
  #return the sums of squares and the means for data1 and data2
  return(c(SSE1,SSE2,M1,M2))
}

gencuts=function(x){
  #calculates all possible cutpoints for x:
  #midpoints between consecutive sorted unique values of x
  xuo=sort(unique(x))
  if (length(xuo)<2) return(numeric(0)) #no cut point is possible
  cuts=(xuo[-1]+xuo[-length(xuo)])/2
  return(cuts)
}

cut1=function(y,x){
  cuts=gencuts(x) #candidate cut points for variable x
  ncuts=length(cuts) #number of cut points
  #columns: total SSE, SSE1, SSE2, mean1, mean2
  SSEs=matrix(0,nrow=ncuts,ncol=5)
  for (i in seq_len(ncuts)){
    SSEs[i,2:5]=SSE(y,x,cuts[i])
    SSEs[i,1]=sum(SSEs[i,2:3])
  }
  minp=which.min(SSEs[,1]) #first cut point with the smallest total SSE
  #returns best cut point, total SSE, SSEs for the 2 branches, means for the 2 branches
  return(c(cuts[minp],SSEs[minp,]))
}

Here’s an example of what gencuts does.

x <- baseb$Hits
xuo=sort(unique(x))  #sorted unique values of x
plot(xuo)

table(baseb$Hits)
## 
##   1   4  27  32  37  39  40  41  42  43  44  46  47  49  51  52  53  54  55  56 
##   1   1   1   2   1   3   1   3   1   3   1   2   2   2   1   3   4   3   1   4 
##  57  58  60  61  63  64  65  66  68  69  70  71  72  73  74  75  76  77  78  80 
##   2   2   3   1   1   1   1   2   5   1   6   1   1   3   1   1   5   4   3   1 
##  81  82  83  84  85  86  87  90  91  92  93  94  95  96  97  99 101 102 103 104 
##   4   2   3   2   2   2   1   1   2   4   2   3   1   4   1   1   5   2   5   1 
## 106 108 109 110 112 113 114 115 116 117 118 119 120 122 123 124 125 126 127 128 
##   1   2   1   3   2   4   1   1   1   2   2   3   3   2   2   1   1   2   3   2 
## 129 130 131 132 133 135 136 137 138 139 140 141 142 144 145 146 147 148 149 150 
##   1   1   3   2   1   2   3   3   2   3   1   3   2   3   2   1   3   2   2   1 
## 151 152 154 157 158 159 160 161 163 167 168 169 170 171 172 174 177 178 179 183 
##   1   4   2   2   2   3   1   1   5   1   3   3   2   2   1   2   1   1   1   1 
## 184 186 198 200 207 210 211 213 223 238 
##   1   1   1   2   1   1   1   1   1   1
gencuts(baseb$Hits)
##   [1]   2.5  15.5  29.5  34.5  38.0  39.5  40.5  41.5  42.5  43.5  45.0  46.5
##  [13]  48.0  50.0  51.5  52.5  53.5  54.5  55.5  56.5  57.5  59.0  60.5  62.0
##  [25]  63.5  64.5  65.5  67.0  68.5  69.5  70.5  71.5  72.5  73.5  74.5  75.5
##  [37]  76.5  77.5  79.0  80.5  81.5  82.5  83.5  84.5  85.5  86.5  88.5  90.5
##  [49]  91.5  92.5  93.5  94.5  95.5  96.5  98.0 100.0 101.5 102.5 103.5 105.0
##  [61] 107.0 108.5 109.5 111.0 112.5 113.5 114.5 115.5 116.5 117.5 118.5 119.5
##  [73] 121.0 122.5 123.5 124.5 125.5 126.5 127.5 128.5 129.5 130.5 131.5 132.5
##  [85] 134.0 135.5 136.5 137.5 138.5 139.5 140.5 141.5 143.0 144.5 145.5 146.5
##  [97] 147.5 148.5 149.5 150.5 151.5 153.0 155.5 157.5 158.5 159.5 160.5 162.0
## [109] 165.0 167.5 168.5 169.5 170.5 171.5 173.0 175.5 177.5 178.5 181.0 183.5
## [121] 185.0 192.0 199.0 203.5 208.5 210.5 212.0 218.0 230.5

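The function cut1 tries every cut point \(c\) returned by gencuts and computes the total error sum of squares \(SSE_1 + SSE_2\), where \(SSE_1\) and \(SSE_2\) are the error sums of squares for the first subset (\(x \le c\)) and the second subset (\(x > c\)). It returns the cut point with the smallest total, followed by that total, \(SSE_1\), \(SSE_2\) and the means of the two subsets.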

# find the best variable for the first partition
 for (j in 2:3)print(cut1(baseb[,1],baseb[,j]))
## [1]   4.500000 115.058475  42.353165  72.705310   5.106790   6.354036
## [1] 117.500000 160.971530  96.510476  64.461054   5.566327   6.413784

From the output, the best choice is to partition using \(Years<4.5\) and \(Years>4.5\): its total error sum of squares (115.06) is smaller than that of the best split on Hits (160.97).

#split the data at Years = 4.5, then search for the best split within the branch Years > 4.5
baseb1=baseb[baseb[,2]<4.5,]
baseb2=baseb[baseb[,2]>4.5,]
for (j in 2:3)print(cut1(baseb2[,1],baseb2[,j]))
## [1]  6.500000 69.877019 27.582434 42.294585  6.164227  6.440167
## [1] 117.500000  48.976782  28.093708  20.883074   5.998380   6.739687
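In this second search the best split is Hits < 117.5 (total SSE 48.98, versus 69.88 for the best split on Years), which matches nodes 6) and 7) of the tree printed above. Repeating the search inside each new branch grows the whole tree. Below is a minimal self-contained sketch of this greedy recursive splitting in R. It is not how the tree package is implemented, and the function names best_split and grow, as well as the stopping rules minsplit (nodes with fewer cases are not split) and maxdepth, are hypothetical choices for illustration.

```r
# Best split of response y over all predictors in data frame X:
# returns the variable, the cut point and the total SSE (SSE1 + SSE2).
best_split <- function(y, X) {
  best <- list(var = NA, cut = NA, sse = Inf)
  for (v in names(X)) {
    xu <- sort(unique(X[[v]]))
    if (length(xu) < 2) next
    cuts <- (xu[-1] + xu[-length(xu)]) / 2  # midpoints, as in gencuts
    for (ct in cuts) {
      left <- y[X[[v]] <= ct]
      right <- y[X[[v]] > ct]
      sse <- sum((left - mean(left))^2) + sum((right - mean(right))^2)
      if (sse < best$sse) best <- list(var = v, cut = ct, sse = sse)
    }
  }
  best
}

# Grow the tree recursively; each node is a list with n, yval and,
# for internal nodes, the split and the left/right subtrees.
grow <- function(y, X, depth = 0, minsplit = 10, maxdepth = 3) {
  node <- list(n = length(y), yval = mean(y))
  if (length(y) < minsplit || depth >= maxdepth) return(node)
  s <- best_split(y, X)
  if (is.na(s$var)) return(node)  # no predictor can be split
  goleft <- X[[s$var]] <= s$cut
  node$split <- s
  node$left <- grow(y[goleft], X[goleft, , drop = FALSE],
                    depth + 1, minsplit, maxdepth)
  node$right <- grow(y[!goleft], X[!goleft, , drop = FALSE],
                     depth + 1, minsplit, maxdepth)
  node
}
```

Since best_split performs the same search as cut1, calling grow(baseb$Salary, baseb[, c("Years", "Hits")], maxdepth = 2) should reproduce the root split on Years at 4.5 and the split on Hits at 117.5 in the Years > 4.5 branch.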

Example: use all predictors to grow the tree, working with Salary on the original scale.

baseball=Hitters[!is.na(Hitters$Salary),] 
basebtree=tree(Salary~.,data=baseball) #construct the tree
basebtree #print the tree
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 263 53320000  535.9  
##    2) CHits < 450 117  5931000  227.9  
##      4) AtBat < 147 5  2940000  709.5 *
##      5) AtBat > 147 112  1779000  206.4  
##       10) CRBI < 114.5 74   302100  141.8 *
##       11) CRBI > 114.5 38   567200  332.1 *
##    3) CHits > 450 146 27390000  782.8  
##      6) Walks < 61 104  9470000  649.6  
##       12) AtBat < 395.5 53  2859000  510.0 *
##       13) AtBat > 395.5 51  4504000  794.7  
##         26) PutOuts < 771 45  2358000  746.4 *
##         27) PutOuts > 771 6  1255000 1157.0 *
##      7) Walks > 61 42 11500000 1113.0  
##       14) RBI < 73.5 22  3148000  885.3  
##         28) PutOuts < 239.5 7  1739000 1156.0 *
##         29) PutOuts > 239.5 15   656300  758.9 *
##       15) RBI > 73.5 20  5967000 1363.0  
##         30) Years < 13.5 14  3767000 1521.0  
##           60) CAtBat < 3814.5 8   529600 1141.0 *
##           61) CAtBat > 3814.5 6   541500 2028.0 *
##         31) Years > 13.5 6  1026000  992.5 *
plot(basebtree) #plot the tree
text(basebtree) #label the tree

Pruning

Remember polynomial regression? Here the number of leaves (terminal nodes) plays the role of the polynomial degree: it measures the complexity of the model!

Let us vary the complexity of our model!

mysize = 2
bp2=prune.tree(basebtree,best=mysize)
plot(bp2)
text(bp2)

mysize = 5
bp5=prune.tree(basebtree,best=mysize)
plot(bp5)
text(bp5)

mysize = 8
bp8=prune.tree(basebtree,best=mysize)
plot(bp8)
text(bp8)

Cross-validated choice of tree size

  • Algorithm 8.1 in ISLR

  • For each \(\alpha\), evaluate the penalized sum of squares

\[PSSE = \sum_{m=1}^{|T|} \sum_{i: x_i \in R_m} (y_i - \hat y_{R_m})^2 + \alpha |T|\]

  • where \(|T|\) is the number of terminal nodes in the tree

  • \(R_m\) is the subset of the predictor space corresponding to the \(m\)’th terminal node.

  • For each \(\alpha\) there is a subtree \(T\) of the full tree for which PSSE is as small as possible.

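  • Use \(k\)-fold cross-validation to choose the value of \(\alpha\) that minimizes the cross-validated prediction error (the sum of squared errors on the held-out folds). Equivalently, choose the tree size with the smallest cross-validated error. The default value of \(k\) in cv.tree is 10.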
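To see how \(\alpha\) trades fit against size, here is a small sketch in R. The sizes and residual sums of squares of the candidate subtrees are hypothetical numbers, not taken from the Hitters fit, and best_size is a hypothetical helper name. For each \(\alpha\) it returns the size whose PSSE is smallest.

```r
# Hypothetical nested subtrees: number of terminal nodes |T| and the
# residual sum of squares (first term of PSSE) of each subtree.
size <- c(1, 2, 3, 5, 8)
rss  <- c(200, 115, 90, 81, 76)

# PSSE = RSS + alpha * |T| for every subtree; return the size of the minimizer
best_size <- function(alpha) {
  psse <- rss + alpha * size
  size[which.min(psse)]
}

sapply(c(0, 5, 50, 100), best_size)  # 8 3 2 1
```

With \(\alpha = 0\) the largest tree wins; as \(\alpha\) grows, each extra leaf costs more and the chosen tree shrinks, down to the single root node.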

Carry out the cross-validation and plot the cross-validated deviance (for a regression tree, the sum of squared prediction errors on the held-out folds) against tree size.

basebcv=cv.tree(basebtree)
plot(basebcv$size,basebcv$dev,type='b')

bestsize=basebcv$size[which.min(basebcv$dev)] #size with the smallest cross-validated deviance
  • In the run reported here, the cross-validated choice of tree size was a tree with 7 terminal nodes.

  • Use prune.tree to find the best tree of this size.

baseball.prune=prune.tree(basebtree,best=bestsize)
plot(baseball.prune)
text(baseball.prune)

  • Note: there is sampling variability in the procedure due to the random partition into \(k\) folds. Each time you run the procedure, you are likely to get a different “best” tree. Hopefully they are not too different.

  • Random forests and/or bagging tend to give good predictions, with reduced variability.
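A minimal sketch of the idea behind bagging and random forests in base R. To keep it short, the base learner is a one-split tree (a stump) instead of a full tree. Each stump is fitted to a bootstrap sample of the cases, and, as in a random forest, only mtry randomly chosen predictors are considered for its split (a stump has a single split, so sampling once per stump is the same as sampling at each split). Predictions are averaged over the B stumps, and setting mtry to the number of predictors gives bagging. The function names are hypothetical, and this is not how randomForest works internally.

```r
# Fit a one-split regression tree (stump) using only the predictors in `vars`.
fit_stump <- function(y, X, vars) {
  best <- list(sse = Inf)
  for (v in vars) {
    xu <- sort(unique(X[[v]]))
    if (length(xu) < 2) next
    for (ct in (xu[-1] + xu[-length(xu)]) / 2) {
      l <- X[[v]] <= ct
      sse <- sum((y[l] - mean(y[l]))^2) + sum((y[!l] - mean(y[!l]))^2)
      if (sse < best$sse) {
        best <- list(sse = sse, var = v, cut = ct,
                     left = mean(y[l]), right = mean(y[!l]))
      }
    }
  }
  # no usable split: predict the mean everywhere
  if (is.null(best$var)) {
    best <- list(var = vars[1], cut = Inf, left = mean(y), right = mean(y))
  }
  best
}

predict_stump <- function(s, newdata) {
  ifelse(newdata[[s$var]] <= s$cut, s$left, s$right)
}

# B stumps, each fitted to a bootstrap sample using mtry random predictors.
# mtry = ncol(X) (all predictors) corresponds to bagging.
forest <- function(y, X, B = 100, mtry = ncol(X)) {
  lapply(seq_len(B), function(b) {
    idx <- sample(length(y), replace = TRUE)  # bootstrap sample of cases
    vars <- sample(names(X), mtry)            # random subset of predictors
    fit_stump(y[idx], X[idx, , drop = FALSE], vars)
  })
}

# Average the stump predictions (one column per stump).
predict_forest <- function(f, newdata) {
  preds <- matrix(vapply(f, predict_stump, numeric(nrow(newdata)),
                         newdata = newdata),
                  nrow = nrow(newdata))
  rowMeans(preds)
}
```

Averaging many models fitted to different bootstrap samples is what reduces the variability of a single tree. Restricting each split to a random subset of predictors also makes the individual trees less alike, so their average varies even less.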

Fitting a random forest

Example: fitting a random forest that considers 3 randomly chosen predictors at each split (mtry=3).

library(randomForest)
## randomForest 4.6-14
## Type rfNews() to see new features/changes/bug fixes.
n=nrow(baseball)
n2=ceiling(n/2) #size of the training set: half the data, rounded up
index=sample(1:n,n2,replace=FALSE) #random training indices
bbtrain=baseball[index,]
bbtest=baseball[-index,]
saltrain=baseball$Salary[index]
saltest=baseball$Salary[-index]
salary.rf=randomForest(Salary~.,data=bbtrain,mtry=3,importance=TRUE)
importance(salary.rf)
##              %IncMSE IncNodePurity
## AtBat      6.9069895    1120062.25
## Hits       8.4886165    1609603.87
## HmRun      1.8832088     555266.63
## Runs       4.6433060    1129279.29
## RBI        5.0863375    1097901.53
## Walks      4.2615930    1390003.92
## Years      7.1000876     783558.78
## CAtBat    10.7754552    2132581.63
## CHits     11.4841486    2223612.77
## CHmRun     6.5103673    1435517.01
## CRuns     11.0300561    2225031.99
## CRBI       7.7508920    1707745.82
## CWalks     5.4700093    1606160.86
## League     1.2934688      57653.97
## Division   0.3003970      42201.65
## PutOuts    0.6949388     502469.37
## Assists    1.6269734     379456.18
## Errors     0.9425467     351362.91
## NewLeague -0.8328295      92767.02
varImpPlot(salary.rf)

shattrain=predict(salary.rf,newdata=bbtrain)
shattest=predict(salary.rf,newdata=bbtest)
paste("training MSE = ", mean((saltrain-shattrain)^2))
## [1] "training MSE =  16096.392440201"
paste("test MSE = ", mean((saltest-shattest)^2))
## [1] "test MSE =  91014.6007220454"

Example: bagging, i.e. a random forest with mtry=19, so that all 19 predictors are considered at each split.

salary.rf=randomForest(Salary~.,data=bbtrain,mtry=19,importance=T)
importance(salary.rf)
##              %IncMSE IncNodePurity
## AtBat      7.1476444    1262635.25
## Hits      11.4508217    1906264.77
## HmRun      0.2306228     401727.64
## Runs       1.8570363    1315685.82
## RBI        2.4153136    1289695.19
## Walks      0.8076082    1817572.99
## Years      4.5232623     214635.30
## CAtBat     8.5666766     585003.54
## CHits     15.5288606    3561683.94
## CHmRun     3.6901168    1185423.19
## CRuns     19.0464540    4553862.28
## CRBI       2.9481110     750290.57
## CWalks     3.1612868     554650.19
## League    -1.3322827      25947.45
## Division  -2.3439124      21054.09
## PutOuts    0.4539780     646626.09
## Assists    4.8420270     395780.12
## Errors    -2.8856907     179432.51
## NewLeague -1.0584685      46851.10
varImpPlot(salary.rf)

shattrain=predict(salary.rf,newdata=bbtrain)
shattest=predict(salary.rf,newdata=bbtest)
paste("training MSE = ", mean((saltrain-shattrain)^2))
## [1] "training MSE =  13443.5089903326"
paste("test MSE = ", mean((saltest-shattest)^2))
## [1] "test MSE =  88464.6705531581"

Note on cross-validation

The general procedure is as follows:

  1. Shuffle the dataset randomly.
  2. Split the dataset into k groups
  3. For each unique group:
    1. Take the group as a hold out or test data set
    2. Take the remaining groups as a training data set
    3. Fit a model on the training set and evaluate it on the test set
    4. Retain the evaluation score and discard the model
  4. Summarize the skill of the model using the sample of model evaluation scores

Importantly, each observation is assigned to exactly one group and stays in that group for the duration of the procedure.

This means that each observation is used in the hold-out set exactly once and is used to train the model k-1 times.

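# example 3-fold cv
d <- c(0.1, 0.2, 0.3, 0.4, 0.5, 0.6) # data

# partition the data into 3 folds
d1 <- c(0.2, 0.5)
d2 <- c(0.1, 0.3)
d3 <- c(0.4, 0.6)
folds <- list(d1, d2, d3)

# we fit 3 models: model k is fitted on the data outside fold k and
# used to predict fold k (e.g. model 1 is fitted on d2+d3 and predicts d1)
for (k in seq_along(folds)) {
  train <- setdiff(d, folds[[k]])
  cat("model", k, ": train on", train, "; predict", folds[[k]], "\n")
}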
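The general procedure above can be sketched in base R as follows. As an assumption for illustration, the model is a straight-line fit with lm and the evaluation score is the mean squared error on the held-out fold. The function name kfold_mse is hypothetical.

```r
# k-fold cross-validation estimate of the prediction MSE of lm(y ~ x)
kfold_mse <- function(x, y, k = 5) {
  n <- length(y)
  # steps 1-2: shuffle, then assign each observation to exactly one of k folds
  fold <- sample(rep(1:k, length.out = n))
  scores <- numeric(k)
  for (j in 1:k) {
    test <- fold == j                                # hold-out group
    train <- data.frame(x = x[!test], y = y[!test])  # remaining groups
    fit <- lm(y ~ x, data = train)                   # fit on the training set
    pred <- predict(fit, newdata = data.frame(x = x[test]))
    scores[j] <- mean((y[test] - pred)^2)  # keep the score, discard the model
  }
  mean(scores)  # summarize the k evaluation scores
}
```

For example, kfold_mse(x, y, k = 10) gives a 10-fold estimate. As noted for cv.tree, the random shuffle means repeated calls give slightly different values.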