Training set
Testing set
Validation set
Cross-validation: used to estimate the error rate of a classifier
Validation: used to choose the model with the ‘right level of complexity’
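A toy illustration of these three roles in R (my own sketch; the 60/20/20 split proportions are arbitrary):
set.seed(1)
n <- 300                                        # pretend we have 300 cases
grp <- sample(rep(c("train", "valid", "test"), times = n * c(0.6, 0.2, 0.2)))
table(grp)                                      # 180 train / 60 validation / 60 test
# train: fit the candidate models
# valid: pick the 'right level of complexity'
# test : estimate the error rate of the final model, once, at the very end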
Are these fundamentally different?
You will see sites that say that CV has nothing to do with model building:
https://stackoverflow.com/questions/2314850/help-understanding-cross-validation-and-decision-trees
But you will also see lectures (Stanford) that say the opposite:
e.g. "Find the optimal subtree by cross-validation"
https://web.stanford.edu/class/stats202/content/lec19.pdf
This is also the approach we used in class.
(Figure: Dec. 28, 1987 - Landing of the Soyuz-TM3 craft carrying the Soyuz-TM2 crew Y. Romanenko, Aleksandrov & Levchenko)
TT (test-train split), LOO (leave-one-out), LpO (leave-p-out), VS (validation set), stratified CV, k-fold CV, nested cross-VALIDCEPTION
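A minimal hand-rolled k-fold CV skeleton, to fix ideas (my own sketch; the balanced fold assignment differs slightly from cv.tree, which, as shown later in these notes, samples fold labels with replacement):
K <- 10
n <- 100                                        # pretend sample size
folds <- sample(rep(1:K, length.out = n))       # balanced random fold labels
cv_error <- numeric(K)
for (k in 1:K) {
  train_idx <- which(folds != k)                # fit on the other K-1 folds
  test_idx  <- which(folds == k)                # evaluate on the held-out fold
  # ... fit the model on train_idx, predict on test_idx,
  # and store the error rate in cv_error[k]
}
mean(cv_error)                                  # the CV estimate of the error rate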
https://www.cs.utoronto.ca/~fidler/teaching/2015/slides/CSC411/tutorial3_CrossVal-DTs.pdf
RECURSIVE PARTITIONING (the rpart package in R)
https://www.statmethods.net/advstats/cart.html
BSP (binary space partitioning):
BSP for collision detection and polygon visibility in videogames:
BSP for creating games:
inter and intra-object representation:
Recall the CV procedure:
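In outline (my summary of the steps carried out below):
# 1. grow a large classification tree with tree()
# 2. cross-validate the cost-complexity pruning sequence with cv.tree(..., FUN = prune.misclass)
# 3. pick the tree size with the smallest cross-validated deviance (misclassification count)
# 4. prune the full tree to that size with prune.misclass(the_tree, best = size)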
heart=read.csv("http://faculty.marshall.usc.edu/gareth-james/ISL/Heart.csv")
head(heart)
## X Age Sex ChestPain RestBP Chol Fbs RestECG MaxHR ExAng Oldpeak Slope Ca
## 1 1 63 1 typical 145 233 1 2 150 0 2.3 3 0
## 2 2 67 1 asymptomatic 160 286 0 2 108 1 1.5 2 3
## 3 3 67 1 asymptomatic 120 229 0 2 129 1 2.6 2 2
## 4 4 37 1 nonanginal 130 250 0 0 187 0 3.5 3 0
## 5 5 41 0 nontypical 130 204 0 2 172 0 1.4 1 0
## 6 6 56 1 nontypical 120 236 0 0 178 0 0.8 1 0
## Thal AHD
## 1 fixed No
## 2 normal Yes
## 3 reversable Yes
## 4 normal No
## 5 normal No
## 6 normal No
summary(heart)
## X Age Sex ChestPain
## Min. : 1.0 Min. :29.00 Min. :0.0000 asymptomatic:144
## 1st Qu.: 76.5 1st Qu.:48.00 1st Qu.:0.0000 nonanginal : 86
## Median :152.0 Median :56.00 Median :1.0000 nontypical : 50
## Mean :152.0 Mean :54.44 Mean :0.6799 typical : 23
## 3rd Qu.:227.5 3rd Qu.:61.00 3rd Qu.:1.0000
## Max. :303.0 Max. :77.00 Max. :1.0000
##
## RestBP Chol Fbs RestECG
## Min. : 94.0 Min. :126.0 Min. :0.0000 Min. :0.0000
## 1st Qu.:120.0 1st Qu.:211.0 1st Qu.:0.0000 1st Qu.:0.0000
## Median :130.0 Median :241.0 Median :0.0000 Median :1.0000
## Mean :131.7 Mean :246.7 Mean :0.1485 Mean :0.9901
## 3rd Qu.:140.0 3rd Qu.:275.0 3rd Qu.:0.0000 3rd Qu.:2.0000
## Max. :200.0 Max. :564.0 Max. :1.0000 Max. :2.0000
##
## MaxHR ExAng Oldpeak Slope
## Min. : 71.0 Min. :0.0000 Min. :0.00 Min. :1.000
## 1st Qu.:133.5 1st Qu.:0.0000 1st Qu.:0.00 1st Qu.:1.000
## Median :153.0 Median :0.0000 Median :0.80 Median :2.000
## Mean :149.6 Mean :0.3267 Mean :1.04 Mean :1.601
## 3rd Qu.:166.0 3rd Qu.:1.0000 3rd Qu.:1.60 3rd Qu.:2.000
## Max. :202.0 Max. :1.0000 Max. :6.20 Max. :3.000
##
## Ca Thal AHD
## Min. :0.0000 fixed : 18 No :164
## 1st Qu.:0.0000 normal :166 Yes:139
## Median :0.0000 reversable:117
## Mean :0.6722 NA's : 2
## 3rd Qu.:1.0000
## Max. :3.0000
## NA's :4
nmiss=apply(is.na(heart),1,sum) #number of missing values, by row
heartm=heart[nmiss!=0,] #cases with missing values
heart=heart[nmiss==0,] #remove cases with missing values
head(heartm)
## X Age Sex ChestPain RestBP Chol Fbs RestECG MaxHR ExAng Oldpeak Slope
## 88 88 53 0 nonanginal 128 216 0 2 115 0 0.0 1
## 167 167 52 1 nonanginal 138 223 0 0 169 0 0.0 1
## 193 193 43 1 asymptomatic 132 247 1 2 143 1 0.1 2
## 267 267 52 1 asymptomatic 128 204 1 0 156 1 1.0 2
## 288 288 58 1 nontypical 125 220 0 0 144 0 0.4 2
## 303 303 38 1 nonanginal 138 175 0 0 173 0 0.0 1
## Ca Thal AHD
## 88 0 <NA> No
## 167 NA normal No
## 193 NA reversable Yes
## 267 0 <NA> Yes
## 288 NA reversable No
## 303 NA normal No
library(tree)
heart.tree=tree(AHD~., data=heart)
plot(heart.tree)
text(heart.tree)
print(heart.tree)
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 297 409.900 No ( 0.53872 0.46128 )
## 2) Thal: normal 164 175.100 No ( 0.77439 0.22561 )
## 4) Ca < 0.5 115 81.150 No ( 0.88696 0.11304 )
## 8) Age < 57.5 80 25.590 No ( 0.96250 0.03750 )
## 16) MaxHR < 152.5 17 15.840 No ( 0.82353 0.17647 )
## 32) Chol < 226.5 8 0.000 No ( 1.00000 0.00000 ) *
## 33) Chol > 226.5 9 11.460 No ( 0.66667 0.33333 ) *
## 17) MaxHR > 152.5 63 0.000 No ( 1.00000 0.00000 ) *
## 9) Age > 57.5 35 41.880 No ( 0.71429 0.28571 )
## 18) Fbs < 0.5 29 37.360 No ( 0.65517 0.34483 ) *
## 19) Fbs > 0.5 6 0.000 No ( 1.00000 0.00000 ) *
## 5) Ca > 0.5 49 67.910 No ( 0.51020 0.48980 )
## 10) ChestPain: nonanginal,nontypical,typical 29 32.050 No ( 0.75862 0.24138 )
## 20) X < 230.5 19 7.835 No ( 0.94737 0.05263 ) *
## 21) X > 230.5 10 13.460 Yes ( 0.40000 0.60000 ) *
## 11) ChestPain: asymptomatic 20 16.910 Yes ( 0.15000 0.85000 )
## 22) Sex < 0.5 6 8.318 No ( 0.50000 0.50000 ) *
## 23) Sex > 0.5 14 0.000 Yes ( 0.00000 1.00000 ) *
## 3) Thal: fixed,reversable 133 149.000 Yes ( 0.24812 0.75188 )
## 6) Ca < 0.5 59 81.370 Yes ( 0.45763 0.54237 )
## 12) ExAng < 0.5 33 42.010 No ( 0.66667 0.33333 )
## 24) Age < 51 13 17.320 Yes ( 0.38462 0.61538 )
## 48) ChestPain: nonanginal,nontypical 5 5.004 No ( 0.80000 0.20000 ) *
## 49) ChestPain: asymptomatic,typical 8 6.028 Yes ( 0.12500 0.87500 ) *
## 25) Age > 51 20 16.910 No ( 0.85000 0.15000 ) *
## 13) ExAng > 0.5 26 25.460 Yes ( 0.19231 0.80769 )
## 26) Oldpeak < 1.55 11 15.160 Yes ( 0.45455 0.54545 )
## 52) Chol < 240.5 6 5.407 No ( 0.83333 0.16667 ) *
## 53) Chol > 240.5 5 0.000 Yes ( 0.00000 1.00000 ) *
## 27) Oldpeak > 1.55 15 0.000 Yes ( 0.00000 1.00000 ) *
## 7) Ca > 0.5 74 41.650 Yes ( 0.08108 0.91892 )
## 14) RestECG < 0.5 34 31.690 Yes ( 0.17647 0.82353 )
## 28) MaxHR < 145 20 7.941 Yes ( 0.05000 0.95000 ) *
## 29) MaxHR > 145 14 18.250 Yes ( 0.35714 0.64286 )
## 58) MaxHR < 158 5 5.004 No ( 0.80000 0.20000 ) *
## 59) MaxHR > 158 9 6.279 Yes ( 0.11111 0.88889 ) *
## 15) RestECG > 0.5 40 0.000 Yes ( 0.00000 1.00000 ) *
heartcv=cv.tree(heart.tree, FUN=prune.misclass)
print(heartcv)
## $size
## [1] 19 14 11 8 6 4 2 1
##
## $dev
## [1] 79 79 75 69 65 79 90 142
##
## $k
## [1] -Inf 0.0 1.0 2.0 3.0 5.5 7.0 67.0
##
## $method
## [1] "misclass"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
plot(heartcv$size,heartcv$dev,type='b')
bestsize=heartcv$size[heartcv$dev==min(heartcv$dev)]
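To complete the recipe one would then prune to that size; a minimal sketch (my addition, not part of the original code):
bestsize <- min(bestsize)                       # several sizes may tie at the minimum: take the smallest
heart.pruned <- prune.misclass(heart.tree, best = bestsize)
plot(heart.pruned)
text(heart.pruned)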
Looking inside cv.tree (297 complete cases in the heart data):
> rand <- sample(K, length(m[[1L]]), replace = TRUE)
> rand
  [1] 10 9 2 2 3 8 7 4 6 5 3 10 4 9 9 4 3 3 6 9 5 9 1 9
 [25] 8 3 1 8 2 1 10 9 7 2 4 6 4 7 5 8 9 8 3 8 3 2 8 1
 [49] 2 4 2 6 3 6 7 3 2 8 10 9 3 6 1 6 10 3 10 1 5 7 9 6
 [73] 7 10 4 6 2 3 1 2 9 2 7 2 2 5 7 8 6 9 10 2 8 1 2 2
 [97] 5 4 10 7 10 1 2 6 2 2 2 1 8 1 3 9 8 9 4 5 3 1 10 1
[121] 5 3 2 9 10 1 2 2 1 1 7 3 4 3 9 4 7 5 2 4 7 3 9 8
[145] 10 8 1 6 10 9 8 10 4 6 8 8 3 8 6 5 5 10 10 5 1 6 2 6
[169] 10 9 3 10 3 10 9 6 2 6 10 4 10 1 7 10 10 4 6 9 4 4 2 3
[193] 1 8 2 5 6 5 8 8 3 8 4 4 6 4 3 9 4 7 10 4 5 10 7 6
[217] 7 6 9 3 3 8 7 5 10 1 2 8 5 10 7 3 1 2 9 8 6 3 5 9
[241] 9 3 1 8 5 9 8 4 7 4 2 4 7 5 5 9 2 9 9 2 4 9 10 7
[265] 3 9 7 5 6 10 7 9 10 2 8 1 1 2 3 4 8 7 10 6 9 9 2 10
[289] 1 8 6 4 10 7 9 7 10
> init <- do.call(FUN = prune.misclass, c(list(object)))   # the pruning sequence computed once on the full data (inside cv.tree, FUN is prune.misclass)
> init$k
[1] -Inf 0.0 1.0 2.0 3.0 5.5 7.0 67.0
From the help page of prune.tree, argument k: "cost-complexity parameter defining either a specific subtree of tree (k a scalar) or the (optional) sequence of subtrees minimizing the cost-complexity measure (k a vector). If missing, k is determined algorithmically."
The result returned by cv.tree on this run:
$size
[1] 19 14 11 8 6 4 2 1
$dev
[1] 66 66 56 60 59 71 79 138
$k
[1] -Inf 0.0 1.0 2.0 3.0 5.5 7.0 67.0
$method
[1] "misclass"
attr(,"class")
[1] "prune" "tree.sequence"
HERE WE ARE with the tree homotopy process of CCP (cost-complexity pruning)
#$size
#[1] 19 14 11 8 6 4 2 1
#$dev
#[1] 66 66 56 60 59 71 79 138
#$k
#[1] -Inf 0.0 1.0 2.0 3.0 5.5 7.0 67.0
MY MAIN POINT HERE (correction of a wrong statement I made during my previous lecture on classification trees)
Do you see where the grid of alpha values is?
The grid of SIZES is used instead for plotting!
Are there any jumps in sizes? (the 'tree homotopy process', THP, of Breiman 1984)
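To answer the question, inspect the k component of the cv.tree output (that is where the grid of \(\alpha\) values lives) and plot the CV deviance against it instead of against size; a small sketch (my addition):
heartcv$k                                       # the grid of cost-complexity (alpha) values
cbind(size = heartcv$size, k = heartcv$k, dev = heartcv$dev)
plot(heartcv$k[-1], heartcv$dev[-1], type = "b", # drop k = -Inf (the unpruned tree)
     xlab = "alpha (per-leaf penalty)", ylab = "CV deviance (misclassifications)")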
José R. Almirall, Tatiana Trejos (abstract):
Advances in technology provide forensic scientists with better tools to detect, to identify, and to individualize small amounts of trace evidence that have been left at a crime scene. The analysis of glass fragments can be useful in solving cases such as hit and run, burglaries, kidnappings, and bombings. The value of glass as "evidentiary material" lies in its inherent characteristics such as: (a) it is a fragile material that is often broken and hence commonly found in various types of crime scenes, (b) it can be easily transferred from the broken source to the scene, suspect, and/or victim, (c) it is relatively persistent, (d) it is chemically stable, and (e) it has measurable physical and chemical properties that can provide significant evidence of an association between the recovered glass fragments and the source of the broken glass. Forensic scientists have dedicated considerable effort to study and improve the detection and discrimination capabilities of analytical techniques in order to enhance the quality of information obtained from glass fragments. This article serves as a review of the developments in the application of both traditional and novel methods of glass analysis. The greatest progress has been made with respect to the incorporation of automated refractive index measurements and elemental analysis to the analytical scheme. Glass examiners have applied state-of-the-art technology including elemental analysis by sensitive methods such as ICPMS and LA-ICP-MS. A review of the literature regarding transfer, persistence, and interpretation of glass is also presented.
library(tree)
data(fgl, package="MASS")
fgl.tr <- tree(type ~ ., fgl)
#plot(print(fgl.tr))
# Average the CV deviance over 5 independent cross-validation runs to smooth
# out the randomness of the fold assignment
fgl.cv <- cv.tree(fgl.tr,, prune.tree)
for(i in 2:5) fgl.cv$dev <- fgl.cv$dev +
  cv.tree(fgl.tr,, prune.tree)$dev
fgl.cv$dev <- fgl.cv$dev/5
plot(fgl.cv)
# Single runs: each call below re-draws the random folds, so the curves
# (and the printed result at the end) change from run to run
fgl.cv <- cv.tree(fgl.tr,, prune.tree)
plot(fgl.cv)
fgl.cv <- cv.tree(fgl.tr,, prune.tree)
plot(fgl.cv)
fgl.cv <- cv.tree(fgl.tr,, prune.tree)
plot(fgl.cv)
fgl.cv <- cv.tree(fgl.tr,, prune.tree)
fgl.cv
## $size
## [1] 20 19 18 17 16 15 14 13 12 11 10 9 8 5 4 3 2 1
##
## $dev
## [1] 549.3466 533.6819 533.6819 523.6940 517.9623 516.3673 515.6067 515.9651
## [9] 504.7484 502.9228 491.7872 491.8353 477.6573 481.5056 492.7666 522.9014
## [17] 501.2550 653.5372
##
## $k
## [1] -Inf 6.765927 6.771674 8.099535 8.940479 9.751469
## [7] 9.873409 9.994950 10.356555 13.077082 16.041350 16.132081
## [13] 18.672227 22.627954 38.229167 50.117941 55.081551 166.958846
##
## $method
## [1] "deviance"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
MY MAIN POINT HERE (correction of a wrong statement I made during my previous lecture on classification trees)
Do you see where the grid of alpha values is?
The grid of SIZES is used instead for plotting!
Are there any jumps in sizes? (the 'tree homotopy process', THP, of Breiman 1984)
(src:daviddalpiaz.github.io/r4sl/trees.html)
library(tree)
library(ISLR)
data(Carseats)
#?Carseats
str(Carseats)
## 'data.frame': 400 obs. of 11 variables:
## $ Sales : num 9.5 11.22 10.06 7.4 4.15 ...
## $ CompPrice : num 138 111 113 117 141 124 115 136 132 132 ...
## $ Income : num 73 48 35 100 64 113 105 81 110 113 ...
## $ Advertising: num 11 16 10 4 3 13 0 15 0 0 ...
## $ Population : num 276 260 269 466 340 501 45 425 108 131 ...
## $ Price : num 120 83 80 97 128 72 108 120 124 124 ...
## $ ShelveLoc : Factor w/ 3 levels "Bad","Good","Medium": 1 2 3 3 1 1 3 2 3 3 ...
## $ Age : num 42 65 59 55 38 78 71 67 76 76 ...
## $ Education : num 17 10 12 14 13 16 15 10 10 17 ...
## $ Urban : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 1 2 2 1 1 ...
## $ US : Factor w/ 2 levels "No","Yes": 2 2 2 2 1 2 1 2 1 2 ...
Carseats$Sales = as.factor(ifelse(Carseats$Sales <= 8, "Low", "High"))
str(Carseats)
## 'data.frame': 400 obs. of 11 variables:
## $ Sales : Factor w/ 2 levels "High","Low": 1 1 1 2 2 1 2 1 2 2 ...
## $ CompPrice : num 138 111 113 117 141 124 115 136 132 132 ...
## $ Income : num 73 48 35 100 64 113 105 81 110 113 ...
## $ Advertising: num 11 16 10 4 3 13 0 15 0 0 ...
## $ Population : num 276 260 269 466 340 501 45 425 108 131 ...
## $ Price : num 120 83 80 97 128 72 108 120 124 124 ...
## $ ShelveLoc : Factor w/ 3 levels "Bad","Good","Medium": 1 2 3 3 1 1 3 2 3 3 ...
## $ Age : num 42 65 59 55 38 78 71 67 76 76 ...
## $ Education : num 17 10 12 14 13 16 15 10 10 17 ...
## $ Urban : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 1 2 2 1 1 ...
## $ US : Factor w/ 2 levels "No","Yes": 2 2 2 2 1 2 1 2 1 2 ...
# Fit an unpruned classification tree using all of the predictors.
seat_tree = tree(Sales ~ ., data = Carseats)
# seat_tree = tree(Sales ~ ., data = Carseats,
# control = tree.control(nobs = nrow(Carseats), minsize = 10))
summary(seat_tree)
##
## Classification tree:
## tree(formula = Sales ~ ., data = Carseats)
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "US" "Income" "CompPrice"
## [6] "Population" "Advertising" "Age"
## Number of terminal nodes: 27
## Residual mean deviance: 0.4575 = 170.7 / 373
## Misclassification error rate: 0.09 = 36 / 400
plot(seat_tree)
text(seat_tree, pretty = 0)
title(main = "Sales of carseats: full classification tree: no pruning!")
# details of the splits.
seat_tree
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 400 541.500 Low ( 0.41000 0.59000 )
## 2) ShelveLoc: Good 85 90.330 High ( 0.77647 0.22353 )
## 4) Price < 135 68 49.260 High ( 0.88235 0.11765 )
## 8) US: No 17 22.070 High ( 0.64706 0.35294 )
## 16) Price < 109 8 0.000 High ( 1.00000 0.00000 ) *
## 17) Price > 109 9 11.460 Low ( 0.33333 0.66667 ) *
## 9) US: Yes 51 16.880 High ( 0.96078 0.03922 ) *
## 5) Price > 135 17 22.070 Low ( 0.35294 0.64706 )
## 10) Income < 46 6 0.000 Low ( 0.00000 1.00000 ) *
## 11) Income > 46 11 15.160 High ( 0.54545 0.45455 ) *
## 3) ShelveLoc: Bad,Medium 315 390.600 Low ( 0.31111 0.68889 )
## 6) Price < 92.5 46 56.530 High ( 0.69565 0.30435 )
## 12) Income < 57 10 12.220 Low ( 0.30000 0.70000 )
## 24) CompPrice < 110.5 5 0.000 Low ( 0.00000 1.00000 ) *
## 25) CompPrice > 110.5 5 6.730 High ( 0.60000 0.40000 ) *
## 13) Income > 57 36 35.470 High ( 0.80556 0.19444 )
## 26) Population < 207.5 16 21.170 High ( 0.62500 0.37500 ) *
## 27) Population > 207.5 20 7.941 High ( 0.95000 0.05000 ) *
## 7) Price > 92.5 269 299.800 Low ( 0.24535 0.75465 )
## 14) Advertising < 13.5 224 213.200 Low ( 0.18304 0.81696 )
## 28) CompPrice < 124.5 96 44.890 Low ( 0.06250 0.93750 )
## 56) Price < 106.5 38 33.150 Low ( 0.15789 0.84211 )
## 112) Population < 177 12 16.300 Low ( 0.41667 0.58333 )
## 224) Income < 60.5 6 0.000 Low ( 0.00000 1.00000 ) *
## 225) Income > 60.5 6 5.407 High ( 0.83333 0.16667 ) *
## 113) Population > 177 26 8.477 Low ( 0.03846 0.96154 ) *
## 57) Price > 106.5 58 0.000 Low ( 0.00000 1.00000 ) *
## 29) CompPrice > 124.5 128 150.200 Low ( 0.27344 0.72656 )
## 58) Price < 122.5 51 70.680 High ( 0.50980 0.49020 )
## 116) ShelveLoc: Bad 11 6.702 Low ( 0.09091 0.90909 ) *
## 117) ShelveLoc: Medium 40 52.930 High ( 0.62500 0.37500 )
## 234) Price < 109.5 16 7.481 High ( 0.93750 0.06250 ) *
## 235) Price > 109.5 24 32.600 Low ( 0.41667 0.58333 )
## 470) Age < 49.5 13 16.050 High ( 0.69231 0.30769 ) *
## 471) Age > 49.5 11 6.702 Low ( 0.09091 0.90909 ) *
## 59) Price > 122.5 77 55.540 Low ( 0.11688 0.88312 )
## 118) CompPrice < 147.5 58 17.400 Low ( 0.03448 0.96552 ) *
## 119) CompPrice > 147.5 19 25.010 Low ( 0.36842 0.63158 )
## 238) Price < 147 12 16.300 High ( 0.58333 0.41667 )
## 476) CompPrice < 152.5 7 5.742 High ( 0.85714 0.14286 ) *
## 477) CompPrice > 152.5 5 5.004 Low ( 0.20000 0.80000 ) *
## 239) Price > 147 7 0.000 Low ( 0.00000 1.00000 ) *
## 15) Advertising > 13.5 45 61.830 High ( 0.55556 0.44444 )
## 30) Age < 54.5 25 25.020 High ( 0.80000 0.20000 )
## 60) CompPrice < 130.5 14 18.250 High ( 0.64286 0.35714 )
## 120) Income < 100 9 12.370 Low ( 0.44444 0.55556 ) *
## 121) Income > 100 5 0.000 High ( 1.00000 0.00000 ) *
## 61) CompPrice > 130.5 11 0.000 High ( 1.00000 0.00000 ) *
## 31) Age > 54.5 20 22.490 Low ( 0.25000 0.75000 )
## 62) CompPrice < 122.5 10 0.000 Low ( 0.00000 1.00000 ) *
## 63) CompPrice > 122.5 10 13.860 Low ( 0.50000 0.50000 )
## 126) Price < 125 5 0.000 High ( 1.00000 0.00000 ) *
## 127) Price > 125 5 0.000 Low ( 0.00000 1.00000 ) *
PLAN: evaluate the overfitting for a given test-train partition
Simple: 1. predict on both the training and the test set, 2. compare each set of predictions to the truth (confusion tables), 3. compare the two tables
# try a test-train split
dim(Carseats)
## [1] 400 11
set.seed(2)
seat_idx = sample(1:nrow(Carseats), 200)
seat_trn = Carseats[seat_idx,]
seat_tst = Carseats[-seat_idx,]
seat_tree = tree(Sales ~ ., data = seat_trn)
summary(seat_tree)
##
## Classification tree:
## tree(formula = Sales ~ ., data = seat_trn)
## Variables actually used in tree construction:
## [1] "Price" "Population" "ShelveLoc" "Age" "Education"
## [6] "Income" "US" "CompPrice" "Advertising"
## Number of terminal nodes: 21
## Residual mean deviance: 0.5543 = 99.22 / 179
## Misclassification error rate: 0.115 = 23 / 200
# display the predictors used in the tree. Not all columns are used!
summary(seat_tree)$used
## [1] Price Population ShelveLoc Age Education Income
## [7] US CompPrice Advertising
## 11 Levels: <leaf> CompPrice Income Advertising Population Price ... US
# this is clearer
names(Carseats)[which(!(names(Carseats) %in% summary(seat_tree)$used))]
## [1] "Sales" "Urban"
# is the tree built on this training set the SAME as the tree built on all data?
# visual inspection tells you this is not the case
plot(seat_tree)
text(seat_tree, pretty = 0)
title(main = "Unpruned Classification Tree")
Confusion tables
seat_trn_pred = predict(seat_tree, seat_trn, type = "class")
seat_tst_pred = predict(seat_tree, seat_tst, type = "class")
#predict(seat_tree, seat_trn, type = "vector") # would return the probabilities of both classes rather than the predicted classes
#predict(seat_tree, seat_tst, type = "vector")
# confusion table of predictions of data in train
table(predicted = seat_trn_pred, actual = seat_trn$Sales)
## actual
## predicted High Low
## High 67 8
## Low 14 111
# same for data in test
table(predicted = seat_tst_pred, actual = seat_tst$Sales)
## actual
## predicted High Low
## High 51 12
## Low 32 105
Which of these was the training predictions, and which the testing predictions?
## actual
## predicted High Low
## High 57 29
## Low 27 87
## actual
## predicted High Low
## High 66 10
## Low 14 110
Now calculate accuracies
# define a function
accuracy = function(actual, predicted) {
mean(actual == predicted)
}
# run it on train
accuracy(predicted = seat_trn_pred, actual = seat_trn$Sales)
## [1] 0.89
# run it on test
accuracy(predicted = seat_tst_pred, actual = seat_tst$Sales)
## [1] 0.78
Conclusion: overfitting
Cure? Pruning to avoid overfitting
How? By penalizing large trees (too many leaves): cost-complexity pruning (CCP)
More precisely, how? Use cross-validation to choose the per-leaf penalty (the cost-complexity parameter \(\alpha\))
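For reference, the criterion behind CCP, in standard CART notation (my addition; the heuristic description above corresponds to this formula):
\[
R_\alpha(T) \;=\; R(T) \;+\; \alpha\,\lvert \tilde{T}\rvert,
\]
where \(R(T)\) is the training (resubstitution) misclassification cost of the tree \(T\), \(\lvert\tilde{T}\rvert\) is its number of leaves, and \(\alpha \ge 0\) is the per-leaf penalty chosen by cross-validation.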
(Figure: Dec. 28, 1987 - Landing of the Soyuz-TM3 craft carrying the Soyuz-TM2 crew Y. Romanenko, Aleksandrov & Levchenko)
Apply the cross-validation recipe to find the optimal penalty hyperparameter
set.seed(3)
seat_tree_cv = cv.tree(seat_tree, FUN = prune.misclass)
# index of tree with minimum error
min_idx = which.min(seat_tree_cv$dev)
min_idx
## [1] 1
# number of terminal nodes in that tree
seat_tree_cv$size[min_idx]
## [1] 21
# misclassification rate of each tree
seat_tree_cv$dev / length(seat_idx)
## [1] 0.375 0.380 0.405 0.405 0.375 0.385 0.390 0.425 0.405
par(mfrow = c(1, 2))
# default plot
plot(seat_tree_cv)
# better plot
plot(seat_tree_cv$size, seat_tree_cv$dev / nrow(seat_trn), type = "b",
xlab = "Tree Size", ylab = "CV Misclassification Rate")
Apply the estimated hyperparameter to finalize our model (prune to target size)
seat_tree_prune = prune.misclass(seat_tree, best = 9)
summary(seat_tree_prune)
##
## Classification tree:
## snip.tree(tree = seat_tree, nodes = c(13L, 15L, 29L, 2L))
## Variables actually used in tree construction:
## [1] "Price" "ShelveLoc" "Income" "Age" "CompPrice"
## [6] "Population"
## Number of terminal nodes: 9
## Residual mean deviance: 0.9135 = 174.5 / 191
## Misclassification error rate: 0.175 = 35 / 200
Plot the optimized model
plot(seat_tree_prune)
text(seat_tree_prune, pretty = 0)
title(main = "Pruned Classification Tree")
Apply the test-train comparison to evaluate the overfitting of the NEW, supposedly improved, model
seat_prune_trn_pred = predict(seat_tree_prune, seat_trn, type = "class")
table(predicted = seat_prune_trn_pred, actual = seat_trn$Sales)
## actual
## predicted High Low
## High 62 16
## Low 19 103
accuracy(predicted = seat_prune_trn_pred, actual = seat_trn$Sales)
## [1] 0.825
seat_prune_tst_pred = predict(seat_tree_prune, seat_tst, type = "class")
table(predicted = seat_prune_tst_pred, actual = seat_tst$Sales)
## actual
## predicted High Low
## High 58 20
## Low 25 97
accuracy(predicted = seat_prune_tst_pred, actual = seat_tst$Sales)
## [1] 0.775
Tree models are prone to overfitting.
OVERALL CONCLUSION: CCP helped SOMEWHAT, but we need other, better methods!
IDEAS to be introduced in the next lectures:
Bagging
Boosting
Random Forests
Why is it that cross-validation does not trap us in an infinite loop of nested validations?
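One way to see it: nested CV has exactly two levels. The inner CV chooses the tree size; the outer CV estimates the error of the whole 'grow, tune, prune' procedure, and nothing further needs validating. A rough sketch on the heart data (my own code, reusing the objects defined earlier; not part of the original notes):
set.seed(1)
K_out <- 5
folds <- sample(rep(1:K_out, length.out = nrow(heart)))
errs  <- numeric(K_out)
for (k in 1:K_out) {
  trn <- heart[folds != k, ]                      # outer training part
  tst <- heart[folds == k, ]                      # outer held-out part
  fit <- tree(AHD ~ ., data = trn)
  cvk <- cv.tree(fit, FUN = prune.misclass)       # inner CV: pick the tree size
  best <- cvk$size[which.min(cvk$dev)]
  pruned <- prune.misclass(fit, best = best)
  pred <- predict(pruned, tst, type = "class")
  errs[k] <- mean(pred != tst$AHD)                # error of the tuned procedure on the outer fold
}
mean(errs)                                        # nested-CV estimate of the error rate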
(src: Wikipedia)
Post-pruning algorithm for Decision Trees
Subtrees:
Pruning a subtree \(T_{t}\)
Pruning Algorithm:
Output:
The algorithm outputs a nested sequence of subtrees \(T^1 \supset T^2 \supset \dots \supset T^k \supset \dots\)
How to choose \(\alpha\)
Suppose we have the following tree:
Some formulas:
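The quantities used in the tables below, in standard weakest-link notation (my reconstruction, since the formulas themselves are not reproduced in the text):
\[
R(t) = r(t)\,p(t), \qquad
R(T_t) = \sum_{t' \in \widetilde{T}_t} R(t'), \qquad
g(t) = \frac{R(t) - R(T_t)}{\lvert \widetilde{T}_t \rvert - 1},
\]
where \(r(t)\) is the misclassification rate at node \(t\), \(p(t)\) is the fraction of training records that reach \(t\), \(T_t\) is the subtree rooted at \(t\), and \(\widetilde{T}_t\) is its set of leaves. At each iteration we prune the subtree \(T_t\) whose root \(t\) has the smallest \(g(t)\) (the "weakest link").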
Iteration 1:
| \(t\) | \(R(t)\) | \(R(T_t)\) | \(g(t)\) |
|---|---|---|---|
| \(t_1\) | \(\cfrac{8}{16} \cdot \cfrac{16}{16}\) | \(T_{t_1}\) is the entire tree; all leaves are pure, so \(R(T_{t_1}) = 0\) | \(\cfrac{8/16 - 0}{4 - 1} = \cfrac{1}{6}\) |
| \(t_2\) | \(\cfrac{4}{12} \cdot \cfrac{12}{16} = \cfrac{4}{16}\) (there are 12 records: 4 \(\blacksquare\) + 8 \(\bigcirc\)) | \(R(T_{t_2}) = 0\) | \(\cfrac{4/16 - 0}{3 - 1} = \cfrac{1}{8}\) |
| \(t_3\) | \(\cfrac{2}{6} \cdot \cfrac{6}{16} = \cfrac{2}{16}\) | \(R(T_{t_3}) = 0\) | \(\cfrac{2/16 - 0}{3 - 1} = \cfrac{1}{8}\) |
We want to find the minimal \(g(t)\)
Iteration 2:
| \(t\) | \(R(t)\) | \(R(T_t)\) | \(g(t)\) |
|---|---|---|---|
| \(t_1\) | \(\cfrac{8}{16} \cdot \cfrac{16}{16}\) | \(\cfrac{2}{16}\) | \(\cfrac{8/16 - 2/16}{3 - 1} = \cfrac{6}{32}\) |
| \(t_2\) | \(\cfrac{4}{12} \cdot \cfrac{12}{16}\) | \(\cfrac{2}{16}\) | \(\cfrac{4/16 - 2/16}{2 - 1} = \cfrac{1}{8}\) |
Find minimal \(g(t)\):
Iteration 3:
Selecting the best:
From the IT4BI 2013 exam: