What should we do when we have no idea of the function that generates the data \(y=f(x)\)?
We can use polynomial regression (Monday’s lecture)
Or we can use another idea which is based on the following ‘decision’ model.
Here is the magic formula. We will construct an estimate of God’s function in the following manner:
\[ \hat{y} = \sum_i B_i(x)\, v_i \] where each \(B_i(x)\) is a boolean expression, only one \(B_i(x)\) is true for any given value \(x\) of the predictor variables, and each \(v_i\) is a constant value (these are the parameters of the model).
We will construct each boolean expression by combining conditions on the predictors of the form \(x < c\) or \(x > c\).
For example, if I have two predictors \(x,z\), I may construct the boolean expression
(x < 1.5) and (z > 3.)
This defines a rectangle in the space \((x,z)\) of predictors, as in the sketch below.
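As a minimal sketch (the cutpoints 1.5 and 3 and the leaf values in v are made up purely for illustration), the magic formula can be written directly in R:
yhat <- function(x, z) {
  B1 <- (x < 1.5) & (z > 3)   #rectangle 1
  B2 <- (x < 1.5) & (z <= 3)  #rectangle 2
  B3 <- (x >= 1.5)            #the rest of the (x,z) plane
  v <- c(5.1, 4.6, 6.4)       #one constant prediction per region
  B1*v[1] + B2*v[2] + B3*v[3] #exactly one B is TRUE, so exactly one v survives
}
yhat(1.0, 4.0) #falls in rectangle 1, so the prediction is v[1] = 5.1
A regression tree is just a systematic way of choosing the rectangles and the constants \(v_i\) from the data.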
As explained in the last lecture, we have to find the right model complexity.
For regression trees, we can prune (cut back) branches to decrease complexity.
We will use the idea of regularization to find the sweet spot of complexity.
It will be given by the optimal value of a hyperparameter that penalizes complexity.
We will use cross-validation to find that sweet spot.
Figure 8.1 of “Introduction to Statistical Learning” shows the initial part of a regression tree used to model the logarithm of a baseball player’s salary, using the number of years played and the number of hits in the previous year as predictors.
The Hitters data set is part of the ISLR library. The function tree, used to fit a regression tree, is part of the tree library. After loading the data, records with missing salaries are removed and only the variables Salary, Years and Hits are retained. Salary is then transformed to log(Salary).
library(ISLR) #contains the Hitters data set
library(tree) #contains functions for tree construction
baseball=Hitters[!is.na(Hitters$Salary),] #remove records with missing salary
baseb=baseball[,c("Salary","Years","Hits")] #keep only 3 variables
baseb$Salary=log(baseb$Salary) #model log(salary)
The function tree is used to construct the regression tree, which is then printed and plotted, with text labels added to the plot.
basebtree=tree(baseb) #construct the tree
basebtree #print the tree
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 263 207.200 5.927
## 2) Years < 4.5 90 42.350 5.107
## 4) Years < 3.5 62 23.010 4.892
## 8) Hits < 114 43 17.150 4.727
## 16) Hits < 40.5 5 10.400 5.511 *
## 17) Hits > 40.5 38 3.280 4.624 *
## 9) Hits > 114 19 2.069 5.264 *
## 5) Years > 3.5 28 10.130 5.583 *
## 3) Years > 4.5 173 72.710 6.354
## 6) Hits < 117.5 90 28.090 5.998
## 12) Years < 6.5 26 7.238 5.689 *
## 13) Years > 6.5 64 17.350 6.124
## 26) Hits < 50.5 12 2.689 5.730 *
## 27) Hits > 50.5 52 12.370 6.215 *
## 7) Hits > 117.5 83 20.880 6.740 *
plot(basebtree) #plot the tree
text(basebtree) #label the tree
The following functions can be used to construct pieces of the tree by hand.
n=dim(baseb)[1] #number of cases
M=dim(baseb)[2] #variables
SSE=function(data,pred,cut){
#split data into two sets, data1 and data2
#returns sums of squares and means for each of data1,data2
data1=data[pred<=cut]
data2=data[pred>cut]
#calculate the error sum of squares for each set
M1=mean(data1)
M2=mean(data2)
SSE1=sum((data1-M1)^2)
SSE2=sum((data2-M2)^2)
#return the error sums of squares and means for the two sets
return(c(SSE1,SSE2,M1,M2))
}
gencuts=function(x){
#calculates all possible cutpoints for x
xuo=sort(unique(x)) #sorted unique values of x
cuts=rep(0,(length(xuo)-1))
nc=length(cuts)
for (i in 1:nc)cuts[i]=mean(xuo[i:(i+1)]) #midpoints
return(cuts)}
cut1=function(y,x){
cuts=gencuts(x) #get cutpoints for variable x
ncuts=length(cuts) #number of cut points
SSEs=matrix(rep(0,5*ncuts),byrow=T,ncol=5) #columns: total SSE, SSE1, SSE2, M1, M2
for (i in 1:length(cuts)){
SSEs[i,2:5]=SSE(y,x,cuts[i]) #SSE1, SSE2, M1, M2 for cutpoint i
SSEs[i,1]=sum(SSEs[i,2:3])} #total SSE = SSE1 + SSE2
minp=(1:length(cuts))[SSEs[,1]==min(SSEs[,1])] #index of the cutpoint minimizing total SSE
#returns best cut point, sums of squares for 2 branches, means for 2 branches
return(c(cuts[minp],SSEs[minp,]))
}
Here’s an example of what gencuts does.
x <- baseb$Hits
xuo=sort(unique(x)) #sorted unique values of x
plot(xuo)
table(baseb$Hits)
##
## 1 4 27 32 37 39 40 41 42 43 44 46 47 49 51 52 53 54 55 56
## 1 1 1 2 1 3 1 3 1 3 1 2 2 2 1 3 4 3 1 4
## 57 58 60 61 63 64 65 66 68 69 70 71 72 73 74 75 76 77 78 80
## 2 2 3 1 1 1 1 2 5 1 6 1 1 3 1 1 5 4 3 1
## 81 82 83 84 85 86 87 90 91 92 93 94 95 96 97 99 101 102 103 104
## 4 2 3 2 2 2 1 1 2 4 2 3 1 4 1 1 5 2 5 1
## 106 108 109 110 112 113 114 115 116 117 118 119 120 122 123 124 125 126 127 128
## 1 2 1 3 2 4 1 1 1 2 2 3 3 2 2 1 1 2 3 2
## 129 130 131 132 133 135 136 137 138 139 140 141 142 144 145 146 147 148 149 150
## 1 1 3 2 1 2 3 3 2 3 1 3 2 3 2 1 3 2 2 1
## 151 152 154 157 158 159 160 161 163 167 168 169 170 171 172 174 177 178 179 183
## 1 4 2 2 2 3 1 1 5 1 3 3 2 2 1 2 1 1 1 1
## 184 186 198 200 207 210 211 213 223 238
## 1 1 1 2 1 1 1 1 1 1
gencuts(baseb$Hits)
## [1] 2.5 15.5 29.5 34.5 38.0 39.5 40.5 41.5 42.5 43.5 45.0 46.5
## [13] 48.0 50.0 51.5 52.5 53.5 54.5 55.5 56.5 57.5 59.0 60.5 62.0
## [25] 63.5 64.5 65.5 67.0 68.5 69.5 70.5 71.5 72.5 73.5 74.5 75.5
## [37] 76.5 77.5 79.0 80.5 81.5 82.5 83.5 84.5 85.5 86.5 88.5 90.5
## [49] 91.5 92.5 93.5 94.5 95.5 96.5 98.0 100.0 101.5 102.5 103.5 105.0
## [61] 107.0 108.5 109.5 111.0 112.5 113.5 114.5 115.5 116.5 117.5 118.5 119.5
## [73] 121.0 122.5 123.5 124.5 125.5 126.5 127.5 128.5 129.5 130.5 131.5 132.5
## [85] 134.0 135.5 136.5 137.5 138.5 139.5 140.5 141.5 143.0 144.5 145.5 146.5
## [97] 147.5 148.5 149.5 150.5 151.5 153.0 155.5 157.5 158.5 159.5 160.5 162.0
## [109] 165.0 167.5 168.5 169.5 170.5 171.5 173.0 175.5 177.5 178.5 181.0 183.5
## [121] 185.0 192.0 199.0 203.5 208.5 210.5 212.0 218.0 230.5
Try all possible cutpoints \(C\), and partition the dataset according to whether \(Hits<C\) or \(Hits>C\).
For each subset, calculate the mean of the dependent variable \(Y\) and the error sum of squares about that mean. Here \(Y\) is \(\log(Salary)\).
Calculate the total sum of squares \(SSE = SSE_1 + SSE_2\),
where \(SSE_1\) and \(SSE_2\) are the error sums of squares for the first and second subsets.
Pick the cutpoint \(C\) which minimizes \(SSE\). This gives us the minimum sum of squares when we partition the dataset using Hits (see the worked example after this list).
Do the same thing, but using the variable Years.
Choose the \(best\) predictor variable at this stage as the one which minimizes \(SSE\).
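For a single candidate cutpoint, the calculation described in these steps can be carried out with the SSE function defined above; the cutpoint 117.5 used here is just one of the candidates produced by gencuts.
out=SSE(baseb$Salary,baseb$Hits,117.5) #returns c(SSE1, SSE2, M1, M2) for this split on Hits
out[1]+out[2] #total SSE for partitioning at C = 117.5; cut1 repeats this for every candidate C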
# find the best variable for the first partition
for (j in 2:3)print(cut1(baseb[,1],baseb[,j]))
## [1] 4.500000 115.058475 42.353165 72.705310 5.106790 6.354036
## [1] 117.500000 160.971530 96.510476 64.461054 5.566327 6.413784
From the output, we see that the best choice is to partition using \(Years<4.5\) and \(Years>4.5\).
baseb1=baseb[baseb[,2]<4.5,]
baseb2=baseb[baseb[,2]>4.5,]
for (j in 2:3)print(cut1(baseb2[,1],baseb2[,j]))
## [1] 6.500000 69.877019 27.582434 42.294585 6.164227 6.440167
## [1] 117.500000 48.976782 28.093708 20.883074 5.998380 6.739687
The next split on this branch uses \(Hits<117.5\) vs \(Hits>117.5\).
Continue recursively down all branches until some stopping criterion is satisfied, as sketched below.
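Here is a minimal sketch of that recursion, reusing the cut1 function defined above; the stopping rule (do not split nodes with fewer than minsize cases) and the name growtree are assumptions for illustration only, and the tree function uses more elaborate defaults.
growtree=function(data,minsize=20,depth=0){
  y=data[,1]
  indent=strrep("  ",depth)
  best=NULL
  if (nrow(data)>=minsize){
    for (j in 2:ncol(data)){
      if (length(unique(data[,j]))<2) next #cannot split on a constant predictor
      res=cut1(y,data[,j]) #c(cut, total SSE, SSE1, SSE2, M1, M2), assuming a unique best cut
      if (is.null(best) || res[2]<best$sse) best=list(j=j,cut=res[1],sse=res[2])
    }
  }
  if (is.null(best)){ #stopping criterion met: this node becomes a leaf
    cat(indent,"leaf: n =",nrow(data)," mean =",round(mean(y),3),"\n")
    return(invisible(NULL))
  }
  cat(indent,"split:",names(data)[best$j],"<",best$cut,"\n")
  growtree(data[data[,best$j]<=best$cut,],minsize,depth+1) #left branch
  growtree(data[data[,best$j]>best$cut,],minsize,depth+1) #right branch
}
growtree(baseb) #prints a small text tree, similar in spirit to the printed basebtree above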
baseball=Hitters[!is.na(Hitters$Salary),]
basebtree=tree(Salary~.,data=baseball) #construct the tree
basebtree #print the tree
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 263 53320000 535.9
## 2) CHits < 450 117 5931000 227.9
## 4) AtBat < 147 5 2940000 709.5 *
## 5) AtBat > 147 112 1779000 206.4
## 10) CRBI < 114.5 74 302100 141.8 *
## 11) CRBI > 114.5 38 567200 332.1 *
## 3) CHits > 450 146 27390000 782.8
## 6) Walks < 61 104 9470000 649.6
## 12) AtBat < 395.5 53 2859000 510.0 *
## 13) AtBat > 395.5 51 4504000 794.7
## 26) PutOuts < 771 45 2358000 746.4 *
## 27) PutOuts > 771 6 1255000 1157.0 *
## 7) Walks > 61 42 11500000 1113.0
## 14) RBI < 73.5 22 3148000 885.3
## 28) PutOuts < 239.5 7 1739000 1156.0 *
## 29) PutOuts > 239.5 15 656300 758.9 *
## 15) RBI > 73.5 20 5967000 1363.0
## 30) Years < 13.5 14 3767000 1521.0
## 60) CAtBat < 3814.5 8 529600 1141.0 *
## 61) CAtBat > 3814.5 6 541500 2028.0 *
## 31) Years > 13.5 6 1026000 992.5 *
plot(basebtree) #plot the tree
text(basebtree) #label the tree
Remember polynomial regression? Here the number of leaves is the complexity measure of the model!
Let us vary the complexity of our model!
mysize = 2
bp2=prune.tree(basebtree,best=mysize) #best subtree with mysize terminal nodes
plot(bp2)
text(bp2)
mysize = 5
bp5=prune.tree(basebtree,best=mysize)
plot(bp5)
text(bp5)
mysize = 8
bp8=prune.tree(basebtree,best=mysize)
plot(bp8)
text(bp8)
Algorithm 8.1 in ISLR
For each \(\alpha\), evaluate the penalized sum of squares
\[PSSE = \sum_{m=1}^{|T|} \sum_{i: x_i \in R_m} (y_i - \hat y_{R_m})^2 + \alpha |T|\]
where \(|T|\) is the number of terminal nodes in the tree,
\(R_m\) is the subset of the predictor space corresponding to the \(m\)th terminal node, and \(\hat y_{R_m}\) is the mean response of the training observations in \(R_m\).
For each \(\alpha\) there is a subtree \(T\) for which PSSE is as small as possible (see the sketch after this list).
Use \(k\)-fold cross-validation to choose the value of \(\alpha\), or equivalently the size of tree, which gives the smallest cross-validated error. The default value of \(k\) is 10.
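As a hedged sketch of the correspondence between \(\alpha\) and a subtree (the \(\alpha\) values below are made up for illustration), prune.tree can return the subtree for a given \(\alpha\) through its k argument (the cost-complexity penalty), and we can watch the tree shrink as the penalty grows:
for (alpha in c(0,50000,200000,1000000)){
  Talpha=prune.tree(basebtree,k=alpha) #subtree minimizing PSSE for this alpha
  cat("alpha =",alpha," terminal nodes =",sum(Talpha$frame$var=="<leaf>"),"\n")
}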
basebcv=cv.tree(basebtree) #10-fold cross-validation over the pruning sequence
plot(basebcv$size,basebcv$dev,type='b')
bestsize=basebcv$size[basebcv$dev==min(basebcv$dev)] #size with the smallest cross-validated deviance
The cross-validated choice of tree size is a tree with 7 terminal nodes.
Use prune.tree to find the best tree of this size.
baseball.prune=prune.tree(basebtree,best=bestsize)
plot(baseball.prune)
text(baseball.prune)
**Random forests** and/or bagging tend to give good predictions, with reduced variability.
library(randomForest)
## randomForest 4.6-14
## Type rfNews() to see new features/changes/bug fixes.
n=nrow(baseball) #number of cases
attach(baseball)
n2=ifelse(n%%2==0,n/2,(n+1)/2) #half of the cases (rounded up)
index=sample(1:n,n2,replace=F) #random indices for the training set
bbtrain=baseball[index,] #training set
bbtest=baseball[-index,] #test set
saltrain=Salary[index] #training salaries
saltest=Salary[-index] #test salaries
salary.rf=randomForest(Salary~.,data=bbtrain,mtry=3,importance=T) #mtry=3: consider 3 randomly chosen predictors at each split
importance(salary.rf)
## %IncMSE IncNodePurity
## AtBat 6.9069895 1120062.25
## Hits 8.4886165 1609603.87
## HmRun 1.8832088 555266.63
## Runs 4.6433060 1129279.29
## RBI 5.0863375 1097901.53
## Walks 4.2615930 1390003.92
## Years 7.1000876 783558.78
## CAtBat 10.7754552 2132581.63
## CHits 11.4841486 2223612.77
## CHmRun 6.5103673 1435517.01
## CRuns 11.0300561 2225031.99
## CRBI 7.7508920 1707745.82
## CWalks 5.4700093 1606160.86
## League 1.2934688 57653.97
## Division 0.3003970 42201.65
## PutOuts 0.6949388 502469.37
## Assists 1.6269734 379456.18
## Errors 0.9425467 351362.91
## NewLeague -0.8328295 92767.02
varImpPlot(salary.rf)
shattrain=predict(salary.rf,newdata=bbtrain)
shattest=predict(salary.rf,newdata=bbtest)
paste("training MSE = ", mean((saltrain-shattrain)^2))
## [1] "training MSE = 16096.392440201"
paste("test MSE = ", mean((saltest-shattest)^2))
## [1] "test MSE = 91014.6007220454"
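Setting mtry equal to the number of predictors gives bagging, which the next fit does via randomForest. As a minimal sketch of what bagging does under the hood (the number of bootstrap trees, B = 25, is made up for illustration), we could also bag regression trees by hand:
B=25
preds=matrix(0,nrow(bbtest),B)
for (b in 1:B){
  boot=sample(1:nrow(bbtrain),replace=T) #bootstrap sample of the training cases
  tb=tree(Salary~.,data=bbtrain[boot,]) #fit a tree to the bootstrap sample
  preds[,b]=predict(tb,newdata=bbtest) #predict the test set with this tree
}
bagpred=rowMeans(preds) #average the B trees' predictions
mean((saltest-bagpred)^2) #test MSE of the hand-rolled bagged ensemble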
salary.rf=randomForest(Salary~.,data=bbtrain,mtry=19,importance=T) #mtry=19: all predictors considered at each split, i.e. bagging
importance(salary.rf)
## %IncMSE IncNodePurity
## AtBat 7.1476444 1262635.25
## Hits 11.4508217 1906264.77
## HmRun 0.2306228 401727.64
## Runs 1.8570363 1315685.82
## RBI 2.4153136 1289695.19
## Walks 0.8076082 1817572.99
## Years 4.5232623 214635.30
## CAtBat 8.5666766 585003.54
## CHits 15.5288606 3561683.94
## CHmRun 3.6901168 1185423.19
## CRuns 19.0464540 4553862.28
## CRBI 2.9481110 750290.57
## CWalks 3.1612868 554650.19
## League -1.3322827 25947.45
## Division -2.3439124 21054.09
## PutOuts 0.4539780 646626.09
## Assists 4.8420270 395780.12
## Errors -2.8856907 179432.51
## NewLeague -1.0584685 46851.10
varImpPlot(salary.rf)
shattrain=predict(salary.rf,newdata=bbtrain)
shattest=predict(salary.rf,newdata=bbtest)
paste("training MSE = ", mean((saltrain-shattrain)^2))
## [1] "training MSE = 13443.5089903326"
paste("test MSE = ", mean((saltest-shattest)^2))
## [1] "test MSE = 88464.6705531581"
The general procedure for \(k\)-fold cross-validation is as follows: shuffle the dataset, split it into \(k\) groups (folds), and for each fold fit the model on the remaining \(k-1\) folds and evaluate it on the held-out fold.
Importantly, each observation in the data sample is assigned to an individual group and stays in that group for the duration of the procedure.
This means that each observation is used in the hold-out set exactly once and used to train the model \(k-1\) times.
# example: 3-fold cv
d <- c(0.1, 0.2, 0.3, 0.4, 0.5, 0.6) # data
# partition the data into 3 folds
d1 <- c(0.2, 0.5)
d2 <- c(0.1, 0.3)
d3 <- c(0.4, 0.6)
# we fit 3 models
# model 1 is fitted on d-d1 = d2+d3 and we use d1 for predictions
# model 2 is fitted on d-d2 = d1+d3 and we use d2 for predictions, etc.
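Here is a minimal sketch of the full \(k\)-fold loop on the toy data above; the ‘model’ is just the mean of the training folds, purely for illustration.
set.seed(1) #for a reproducible fold assignment
k=3
folds=sample(rep(1:k,length.out=length(d))) #randomly assign each observation to a fold
cverr=rep(0,k)
for (i in 1:k){
  train=d[folds!=i] #the k-1 folds used for fitting
  test=d[folds==i] #the single held-out fold
  pred=mean(train) #the 'fitted model': just the training mean
  cverr[i]=mean((test-pred)^2) #hold-out error for fold i
}
mean(cverr) #cross-validated estimate of prediction error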