Terminology

Are cross-validation (CV) and model building fundamentally different?

You will see sites that say that CV has nothing to do with model building:

https://stackoverflow.com/questions/2314850/help-understanding-cross-validation-and-decision-trees

But you will also see lectures (Stanford) that say the opposite:

e.g., "Find the optimal subtree by cross-validation"

https://web.stanford.edu/class/stats202/content/lec19.pdf

This is also the approach we used in class.

(Figure: Dec. 28, 1987 - Landing of the Soyuz-TM3 craft carrying the Soyuz-TM2 crew Y. Romanenko, Aleksandrov & Levchenko)

TT (train/test), L1O (leave-one-out), LpO (leave-p-out), VS (validation set), stratified CV, kF-CV (k-fold CV), nested cross-VALIDCEPTION

https://www.cs.utoronto.ca/~fidler/teaching/2015/slides/CSC411/tutorial3_CrossVal-DTs.pdf

RECURSIVE PARTITIONING (package rpart in R)

https://www.statmethods.net/advstats/cart.html

BSP (binary space partitioning):

BSP for collision detection and polygon visibility in videogames:

BSP for creating games:

https://gamedevelopment.tutsplus.com/tutorials/how-to-use-bsp-trees-to-generate-game-maps--gamedev-12268

Inter- and intra-object representation:

DATASET 1: Heart disease

Cross-validation

Recall the CV procedure:

heart=read.csv("http://faculty.marshall.usc.edu/gareth-james/ISL/Heart.csv")
head(heart)
##   X Age Sex    ChestPain RestBP Chol Fbs RestECG MaxHR ExAng Oldpeak Slope Ca
## 1 1  63   1      typical    145  233   1       2   150     0     2.3     3  0
## 2 2  67   1 asymptomatic    160  286   0       2   108     1     1.5     2  3
## 3 3  67   1 asymptomatic    120  229   0       2   129     1     2.6     2  2
## 4 4  37   1   nonanginal    130  250   0       0   187     0     3.5     3  0
## 5 5  41   0   nontypical    130  204   0       2   172     0     1.4     1  0
## 6 6  56   1   nontypical    120  236   0       0   178     0     0.8     1  0
##         Thal AHD
## 1      fixed  No
## 2     normal Yes
## 3 reversable Yes
## 4     normal  No
## 5     normal  No
## 6     normal  No
summary(heart)
##        X              Age             Sex                ChestPain  
##  Min.   :  1.0   Min.   :29.00   Min.   :0.0000   asymptomatic:144  
##  1st Qu.: 76.5   1st Qu.:48.00   1st Qu.:0.0000   nonanginal  : 86  
##  Median :152.0   Median :56.00   Median :1.0000   nontypical  : 50  
##  Mean   :152.0   Mean   :54.44   Mean   :0.6799   typical     : 23  
##  3rd Qu.:227.5   3rd Qu.:61.00   3rd Qu.:1.0000                     
##  Max.   :303.0   Max.   :77.00   Max.   :1.0000                     
##                                                                     
##      RestBP           Chol            Fbs            RestECG      
##  Min.   : 94.0   Min.   :126.0   Min.   :0.0000   Min.   :0.0000  
##  1st Qu.:120.0   1st Qu.:211.0   1st Qu.:0.0000   1st Qu.:0.0000  
##  Median :130.0   Median :241.0   Median :0.0000   Median :1.0000  
##  Mean   :131.7   Mean   :246.7   Mean   :0.1485   Mean   :0.9901  
##  3rd Qu.:140.0   3rd Qu.:275.0   3rd Qu.:0.0000   3rd Qu.:2.0000  
##  Max.   :200.0   Max.   :564.0   Max.   :1.0000   Max.   :2.0000  
##                                                                   
##      MaxHR           ExAng           Oldpeak         Slope      
##  Min.   : 71.0   Min.   :0.0000   Min.   :0.00   Min.   :1.000  
##  1st Qu.:133.5   1st Qu.:0.0000   1st Qu.:0.00   1st Qu.:1.000  
##  Median :153.0   Median :0.0000   Median :0.80   Median :2.000  
##  Mean   :149.6   Mean   :0.3267   Mean   :1.04   Mean   :1.601  
##  3rd Qu.:166.0   3rd Qu.:1.0000   3rd Qu.:1.60   3rd Qu.:2.000  
##  Max.   :202.0   Max.   :1.0000   Max.   :6.20   Max.   :3.000  
##                                                                 
##        Ca                 Thal      AHD     
##  Min.   :0.0000   fixed     : 18   No :164  
##  1st Qu.:0.0000   normal    :166   Yes:139  
##  Median :0.0000   reversable:117            
##  Mean   :0.6722   NA's      :  2            
##  3rd Qu.:1.0000                             
##  Max.   :3.0000                             
##  NA's   :4
nmiss=apply(is.na(heart),1,sum) #number of missing values, by row

heartm=heart[nmiss!=0,]  #cases with missing values
heart=heart[nmiss==0,]   #remove cases with missing values
head(heartm)
##       X Age Sex    ChestPain RestBP Chol Fbs RestECG MaxHR ExAng Oldpeak Slope
## 88   88  53   0   nonanginal    128  216   0       2   115     0     0.0     1
## 167 167  52   1   nonanginal    138  223   0       0   169     0     0.0     1
## 193 193  43   1 asymptomatic    132  247   1       2   143     1     0.1     2
## 267 267  52   1 asymptomatic    128  204   1       0   156     1     1.0     2
## 288 288  58   1   nontypical    125  220   0       0   144     0     0.4     2
## 303 303  38   1   nonanginal    138  175   0       0   173     0     0.0     1
##     Ca       Thal AHD
## 88   0       <NA>  No
## 167 NA     normal  No
## 193 NA reversable Yes
## 267  0       <NA> Yes
## 288 NA reversable  No
## 303 NA     normal  No
library(tree)
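heart.tree=tree(AHD~., data=heart)   # caution: "." also includes X, which is only a row index (see the split on X below); AHD ~ . - X would drop it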
plot(heart.tree)
text(heart.tree)

print(heart.tree)
## node), split, n, deviance, yval, (yprob)
##       * denotes terminal node
## 
##  1) root 297 409.900 No ( 0.53872 0.46128 )  
##    2) Thal: normal 164 175.100 No ( 0.77439 0.22561 )  
##      4) Ca < 0.5 115  81.150 No ( 0.88696 0.11304 )  
##        8) Age < 57.5 80  25.590 No ( 0.96250 0.03750 )  
##         16) MaxHR < 152.5 17  15.840 No ( 0.82353 0.17647 )  
##           32) Chol < 226.5 8   0.000 No ( 1.00000 0.00000 ) *
##           33) Chol > 226.5 9  11.460 No ( 0.66667 0.33333 ) *
##         17) MaxHR > 152.5 63   0.000 No ( 1.00000 0.00000 ) *
##        9) Age > 57.5 35  41.880 No ( 0.71429 0.28571 )  
##         18) Fbs < 0.5 29  37.360 No ( 0.65517 0.34483 ) *
##         19) Fbs > 0.5 6   0.000 No ( 1.00000 0.00000 ) *
##      5) Ca > 0.5 49  67.910 No ( 0.51020 0.48980 )  
##       10) ChestPain: nonanginal,nontypical,typical 29  32.050 No ( 0.75862 0.24138 )  
##         20) X < 230.5 19   7.835 No ( 0.94737 0.05263 ) *
##         21) X > 230.5 10  13.460 Yes ( 0.40000 0.60000 ) *
##       11) ChestPain: asymptomatic 20  16.910 Yes ( 0.15000 0.85000 )  
##         22) Sex < 0.5 6   8.318 No ( 0.50000 0.50000 ) *
##         23) Sex > 0.5 14   0.000 Yes ( 0.00000 1.00000 ) *
##    3) Thal: fixed,reversable 133 149.000 Yes ( 0.24812 0.75188 )  
##      6) Ca < 0.5 59  81.370 Yes ( 0.45763 0.54237 )  
##       12) ExAng < 0.5 33  42.010 No ( 0.66667 0.33333 )  
##         24) Age < 51 13  17.320 Yes ( 0.38462 0.61538 )  
##           48) ChestPain: nonanginal,nontypical 5   5.004 No ( 0.80000 0.20000 ) *
##           49) ChestPain: asymptomatic,typical 8   6.028 Yes ( 0.12500 0.87500 ) *
##         25) Age > 51 20  16.910 No ( 0.85000 0.15000 ) *
##       13) ExAng > 0.5 26  25.460 Yes ( 0.19231 0.80769 )  
##         26) Oldpeak < 1.55 11  15.160 Yes ( 0.45455 0.54545 )  
##           52) Chol < 240.5 6   5.407 No ( 0.83333 0.16667 ) *
##           53) Chol > 240.5 5   0.000 Yes ( 0.00000 1.00000 ) *
##         27) Oldpeak > 1.55 15   0.000 Yes ( 0.00000 1.00000 ) *
##      7) Ca > 0.5 74  41.650 Yes ( 0.08108 0.91892 )  
##       14) RestECG < 0.5 34  31.690 Yes ( 0.17647 0.82353 )  
##         28) MaxHR < 145 20   7.941 Yes ( 0.05000 0.95000 ) *
##         29) MaxHR > 145 14  18.250 Yes ( 0.35714 0.64286 )  
##           58) MaxHR < 158 5   5.004 No ( 0.80000 0.20000 ) *
##           59) MaxHR > 158 9   6.279 Yes ( 0.11111 0.88889 ) *
##       15) RestECG > 0.5 40   0.000 Yes ( 0.00000 1.00000 ) *
heartcv=cv.tree(heart.tree, FUN=prune.misclass)
print(heartcv)
## $size
## [1] 19 14 11  8  6  4  2  1
## 
## $dev
## [1]  79  79  75  69  65  79  90 142
## 
## $k
## [1] -Inf  0.0  1.0  2.0  3.0  5.5  7.0 67.0
## 
## $method
## [1] "misclass"
## 
## attr(,"class")
## [1] "prune"         "tree.sequence"
plot(heartcv$size,heartcv$dev,type='b')

bestsize=heartcv$size[heartcv$dev==min(heartcv$dev)]

Under the hood of CV

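n = 297 rows (303 minus the 6 rows with missing values). cv.tree() draws a random fold label (1 to K, default K = 10) for each row: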

> rand <- sample(K, length(m[[1L]]), replace = TRUE)
> rand
  [1] 10  9  2  2  3  8  7  4  6  5  3 10  4  9  9  4  3  3  6  9  5  9  1  9
 [25]  8  3  1  8  2  1 10  9  7  2  4  6  4  7  5  8  9  8  3  8  3  2  8  1
 [49]  2  4  2  6  3  6  7  3  2  8 10  9  3  6  1  6 10  3 10  1  5  7  9  6
 [73]  7 10  4  6  2  3  1  2  9  2  7  2  2  5  7  8  6  9 10  2  8  1  2  2
 [97]  5  4 10  7 10  1  2  6  2  2  2  1  8  1  3  9  8  9  4  5  3  1 10  1
[121]  5  3  2  9 10  1  2  2  1  1  7  3  4  3  9  4  7  5  2  4  7  3  9  8
[145] 10  8  1  6 10  9  8 10  4  6  8  8  3  8  6  5  5 10 10  5  1  6  2  6
[169] 10  9  3 10  3 10  9  6  2  6 10  4 10  1  7 10 10  4  6  9  4  4  2  3
[193]  1  8  2  5  6  5  8  8  3  8  4  4  6  4  3  9  4  7 10  4  5 10  7  6
[217]  7  6  9  3  3  8  7  5 10  1  2  8  5 10  7  3  1  2  9  8  6  3  5  9
[241]  9  3  1  8  5  9  8  4  7  4  2  4  7  5  5  9  2  9  9  2  4  9 10  7
[265]  3  9  7  5  6 10  7  9 10  2  8  1  1  2  3  4  8  7 10  6  9  9  2 10
[289]  1  8  6  4 10  7  9  7 10
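A minimal sketch of that fold assignment in base R (n = 297 and K = 10 mirror the heart data and cv.tree's default; the seed and the balanced alternative are our own additions, not part of cv.tree). Because each row draws its label independently with replace = TRUE, the folds do not all have the same size:

```r
set.seed(1)
n <- 297   # rows of the heart data after dropping incomplete cases
K <- 10    # cv.tree's default number of folds

# cv.tree-style: each row independently draws a fold label -> fold sizes vary
rand_iid <- sample(K, n, replace = TRUE)

# balanced alternative: permute 1..K repeated to length n -> sizes differ by at most 1
rand_bal <- sample(rep_len(1:K, n))

table(rand_iid)   # uneven fold sizes
table(rand_bal)   # 29 or 30 rows per fold

# In fold i, the tree is grown on rows with label != i and scored on rows with label == i
i <- 1
train_rows <- which(rand_bal != i)
test_rows  <- which(rand_bal == i)
```

If you want balanced or reproducible folds, cv.tree() accepts fold labels through its second argument, rand (the slot left empty in calls such as cv.tree(fgl.tr,, prune.tree)).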



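Inside cv.tree(), the full tree is pruned first (here FUN = prune.misclass); the resulting k values are the grid reused when pruning the tree grown in each fold:
init <- do.call(FUN, c(list(object)))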


> init <- do.call(prune.misclass, c(list(object)))
> init$k
[1] -Inf  0.0  1.0  2.0  3.0  5.5  7.0 67.0

k (from ?prune.tree): cost-complexity parameter defining either a specific subtree of tree (k a scalar) or the (optional) sequence of subtrees minimizing the cost-complexity measure (k a vector). If missing, k is determined algorithmically.
 


$size
[1] 19 14 11  8  6  4  2  1

$dev
[1]  66  66  56  60  59  71  79 138

$k
[1] -Inf  0.0  1.0  2.0  3.0  5.5  7.0 67.0

$method
[1] "misclass"

attr(,"class")
[1] "prune"         "tree.sequence"

HERE WE ARE with the Tree Homotopy Process of CCP


MY MAIN POINT HERE (correction of a wrong statement I made during my previous lecture on classification trees)

Do you see where the grid of alpha values is?

The grid of SIZES is used instead for plotting!

Are there any jumps in sizes? (tree homotopy process, THP, of Breiman 1984)
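One way to answer both questions from the numbers printed by print(heartcv) above (values copied by hand; in tree's output, size[i], dev[i] and k[i] refer to the same pruned subtree):

```r
# Values copied by hand from print(heartcv) above (one random 10-fold CV run)
size <- c(19, 14, 11, 8, 6, 4, 2, 1)        # number of leaves of each pruned tree
dev  <- c(79, 79, 75, 69, 65, 79, 90, 142)  # CV misclassifications, summed over folds
k    <- c(-Inf, 0, 1, 2, 3, 5.5, 7, 67)     # the grid of alpha (cost-complexity) values

best      <- which.min(dev)  # position of the lowest CV error
best_size <- size[best]      # 6 leaves
best_k    <- k[best]         # alpha = 3 at the same position

# Leaf counts that the pruning sequence never produces: the "jumps"
skipped <- setdiff(1:max(size), size)
skipped   # 3 5 7 9 10 12 13 15 16 17 18
```

So the alpha grid is hiding in $k, while the plot's x-axis uses $size; and yes, sizes jump (19 to 14, 14 to 11, ...), because a single pruning step can remove several leaves at once.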

DATASET 2: Forensic Glass Fragments

(src:R ?prune.misclass)


José R. Almirall, Tatiana Trejos
Advances in technology provide forensic scientists with better tools to detect, to identify, and to individualize small amounts of trace evidence that have been left at a crime scene. The analysis of glass fragments can be useful in solving cases such as hit and run, burglaries, kidnappings, and bombings. The value of glass as "evidentiary material" lies in its inherent characteristics such as:
(a) it is a fragile material that is often broken and hence commonly found in various types of crime scenes,
(b) it can be easily transferred from the broken source to the scene, suspect, and/or victim,
(c) it is relatively persistent,
(d) it is chemically stable, and
(e) it has measurable physical and chemical properties that can provide significant evidence of an association between the recovered glass fragments and the source of the broken glass.
Forensic scientists have dedicated considerable effort to study and improve the detection and discrimination capabilities of analytical techniques in order to enhance the quality of information obtained from glass fragments. This article serves as a review of the developments in the application of both traditional and novel methods of glass analysis. The greatest progress has been made with respect to the incorporation of automated refractive index measurements and elemental analysis to the analytical scheme. Glass examiners have applied state-of-the-art technology including elemental analysis by sensitive methods such as ICPMS and LA-ICP-MS. A review of the literature regarding transfer, persistence, and interpretation of glass is also presented.

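library(tree)
data(fgl, package="MASS")           # forensic glass fragments data from MASS
fgl.tr <- tree(type ~ ., fgl)        # full (unpruned) classification tree
#plot(print(fgl.tr))
# CV is random, so average the CV deviance over 5 runs (each 10-fold by default);
# every run prunes the same full tree, so the size/k grid is identical across runs
fgl.cv <- cv.tree(fgl.tr,, prune.tree)   # empty 2nd argument (rand): cv.tree draws the folds
for(i in 2:5)  fgl.cv$dev <- fgl.cv$dev +
   cv.tree(fgl.tr,, prune.tree)$dev
fgl.cv$dev <- fgl.cv$dev/5
plot(fgl.cv)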

# single CV runs: each call draws new folds, so the curve changes from run to run
fgl.cv <- cv.tree(fgl.tr,, prune.tree)
plot(fgl.cv)

fgl.cv <- cv.tree(fgl.tr,, prune.tree)
plot(fgl.cv)

fgl.cv <- cv.tree(fgl.tr,, prune.tree)
plot(fgl.cv)

fgl.cv <- cv.tree(fgl.tr,, prune.tree)
fgl.cv
## $size
##  [1] 20 19 18 17 16 15 14 13 12 11 10  9  8  5  4  3  2  1
## 
## $dev
##  [1] 549.3466 533.6819 533.6819 523.6940 517.9623 516.3673 515.6067 515.9651
##  [9] 504.7484 502.9228 491.7872 491.8353 477.6573 481.5056 492.7666 522.9014
## [17] 501.2550 653.5372
## 
## $k
##  [1]       -Inf   6.765927   6.771674   8.099535   8.940479   9.751469
##  [7]   9.873409   9.994950  10.356555  13.077082  16.041350  16.132081
## [13]  18.672227  22.627954  38.229167  50.117941  55.081551 166.958846
## 
## $method
## [1] "deviance"
## 
## attr(,"class")
## [1] "prune"         "tree.sequence"

MY MAIN POINT HERE (correction of a wrong statement I made during my previous lecture on classification trees)

Do you see where the grid of alpha values is?

The grid of SIZES is used instead for plotting!

Are there any jumps in sizes? (tree homotopy process, THP, of Breiman 1984)
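Same check on the fgl output above (values copied by hand). A jump of more than one leaf means a single pruning step removed several leaves at once, either because the weakest link was a branch with more than two leaves or because several branches tied at the same k:

```r
# Values copied by hand from the fgl.cv printout above (one CV run)
size <- c(20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 5, 4, 3, 2, 1)
dev  <- c(549.3466, 533.6819, 533.6819, 523.6940, 517.9623, 516.3673,
          515.6067, 515.9651, 504.7484, 502.9228, 491.7872, 491.8353,
          477.6573, 481.5056, 492.7666, 522.9014, 501.2550, 653.5372)
k    <- c(-Inf, 6.765927, 6.771674, 8.099535, 8.940479, 9.751469,
          9.873409, 9.994950, 10.356555, 13.077082, 16.041350, 16.132081,
          18.672227, 22.627954, 38.229167, 50.117941, 55.081551, 166.958846)

skipped <- setdiff(1:max(size), size)  # leaf counts never produced: 6 7
removed <- -diff(size)                 # leaves removed at each pruning step
jump_at <- which(removed > 1)          # 13: the step from 8 to 5 leaves
k[jump_at + 1]                         # k paired with the 5-leaf tree
size[which.min(dev)]                   # CV-best size in this run: 8
```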

DATASET 3: Carseats sales

(src:daviddalpiaz.github.io/r4sl/trees.html)

Full tree

library(tree)
library(ISLR)

data(Carseats)
#?Carseats
str(Carseats)
## 'data.frame':    400 obs. of  11 variables:
##  $ Sales      : num  9.5 11.22 10.06 7.4 4.15 ...
##  $ CompPrice  : num  138 111 113 117 141 124 115 136 132 132 ...
##  $ Income     : num  73 48 35 100 64 113 105 81 110 113 ...
##  $ Advertising: num  11 16 10 4 3 13 0 15 0 0 ...
##  $ Population : num  276 260 269 466 340 501 45 425 108 131 ...
##  $ Price      : num  120 83 80 97 128 72 108 120 124 124 ...
##  $ ShelveLoc  : Factor w/ 3 levels "Bad","Good","Medium": 1 2 3 3 1 1 3 2 3 3 ...
##  $ Age        : num  42 65 59 55 38 78 71 67 76 76 ...
##  $ Education  : num  17 10 12 14 13 16 15 10 10 17 ...
##  $ Urban      : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 1 2 2 1 1 ...
##  $ US         : Factor w/ 2 levels "No","Yes": 2 2 2 2 1 2 1 2 1 2 ...
Carseats$Sales = as.factor(ifelse(Carseats$Sales <= 8, "Low", "High"))
str(Carseats)
## 'data.frame':    400 obs. of  11 variables:
##  $ Sales      : Factor w/ 2 levels "High","Low": 1 1 1 2 2 1 2 1 2 2 ...
##  $ CompPrice  : num  138 111 113 117 141 124 115 136 132 132 ...
##  $ Income     : num  73 48 35 100 64 113 105 81 110 113 ...
##  $ Advertising: num  11 16 10 4 3 13 0 15 0 0 ...
##  $ Population : num  276 260 269 466 340 501 45 425 108 131 ...
##  $ Price      : num  120 83 80 97 128 72 108 120 124 124 ...
##  $ ShelveLoc  : Factor w/ 3 levels "Bad","Good","Medium": 1 2 3 3 1 1 3 2 3 3 ...
##  $ Age        : num  42 65 59 55 38 78 71 67 76 76 ...
##  $ Education  : num  17 10 12 14 13 16 15 10 10 17 ...
##  $ Urban      : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 1 2 2 1 1 ...
##  $ US         : Factor w/ 2 levels "No","Yes": 2 2 2 2 1 2 1 2 1 2 ...
# Fit an unpruned classification tree using all of the predictors.

seat_tree = tree(Sales ~ ., data = Carseats)
# seat_tree = tree(Sales ~ ., data = Carseats, 
#                  control = tree.control(nobs = nrow(Carseats), minsize = 10))
summary(seat_tree)
## 
## Classification tree:
## tree(formula = Sales ~ ., data = Carseats)
## Variables actually used in tree construction:
## [1] "ShelveLoc"   "Price"       "US"          "Income"      "CompPrice"  
## [6] "Population"  "Advertising" "Age"        
## Number of terminal nodes:  27 
## Residual mean deviance:  0.4575 = 170.7 / 373 
## Misclassification error rate: 0.09 = 36 / 400
plot(seat_tree)
text(seat_tree, pretty = 0)
title(main = "Sales of carseats: full classification tree: no pruning!")

#  details of the splits.
seat_tree
## node), split, n, deviance, yval, (yprob)
##       * denotes terminal node
## 
##   1) root 400 541.500 Low ( 0.41000 0.59000 )  
##     2) ShelveLoc: Good 85  90.330 High ( 0.77647 0.22353 )  
##       4) Price < 135 68  49.260 High ( 0.88235 0.11765 )  
##         8) US: No 17  22.070 High ( 0.64706 0.35294 )  
##          16) Price < 109 8   0.000 High ( 1.00000 0.00000 ) *
##          17) Price > 109 9  11.460 Low ( 0.33333 0.66667 ) *
##         9) US: Yes 51  16.880 High ( 0.96078 0.03922 ) *
##       5) Price > 135 17  22.070 Low ( 0.35294 0.64706 )  
##        10) Income < 46 6   0.000 Low ( 0.00000 1.00000 ) *
##        11) Income > 46 11  15.160 High ( 0.54545 0.45455 ) *
##     3) ShelveLoc: Bad,Medium 315 390.600 Low ( 0.31111 0.68889 )  
##       6) Price < 92.5 46  56.530 High ( 0.69565 0.30435 )  
##        12) Income < 57 10  12.220 Low ( 0.30000 0.70000 )  
##          24) CompPrice < 110.5 5   0.000 Low ( 0.00000 1.00000 ) *
##          25) CompPrice > 110.5 5   6.730 High ( 0.60000 0.40000 ) *
##        13) Income > 57 36  35.470 High ( 0.80556 0.19444 )  
##          26) Population < 207.5 16  21.170 High ( 0.62500 0.37500 ) *
##          27) Population > 207.5 20   7.941 High ( 0.95000 0.05000 ) *
##       7) Price > 92.5 269 299.800 Low ( 0.24535 0.75465 )  
##        14) Advertising < 13.5 224 213.200 Low ( 0.18304 0.81696 )  
##          28) CompPrice < 124.5 96  44.890 Low ( 0.06250 0.93750 )  
##            56) Price < 106.5 38  33.150 Low ( 0.15789 0.84211 )  
##             112) Population < 177 12  16.300 Low ( 0.41667 0.58333 )  
##               224) Income < 60.5 6   0.000 Low ( 0.00000 1.00000 ) *
##               225) Income > 60.5 6   5.407 High ( 0.83333 0.16667 ) *
##             113) Population > 177 26   8.477 Low ( 0.03846 0.96154 ) *
##            57) Price > 106.5 58   0.000 Low ( 0.00000 1.00000 ) *
##          29) CompPrice > 124.5 128 150.200 Low ( 0.27344 0.72656 )  
##            58) Price < 122.5 51  70.680 High ( 0.50980 0.49020 )  
##             116) ShelveLoc: Bad 11   6.702 Low ( 0.09091 0.90909 ) *
##             117) ShelveLoc: Medium 40  52.930 High ( 0.62500 0.37500 )  
##               234) Price < 109.5 16   7.481 High ( 0.93750 0.06250 ) *
##               235) Price > 109.5 24  32.600 Low ( 0.41667 0.58333 )  
##                 470) Age < 49.5 13  16.050 High ( 0.69231 0.30769 ) *
##                 471) Age > 49.5 11   6.702 Low ( 0.09091 0.90909 ) *
##            59) Price > 122.5 77  55.540 Low ( 0.11688 0.88312 )  
##             118) CompPrice < 147.5 58  17.400 Low ( 0.03448 0.96552 ) *
##             119) CompPrice > 147.5 19  25.010 Low ( 0.36842 0.63158 )  
##               238) Price < 147 12  16.300 High ( 0.58333 0.41667 )  
##                 476) CompPrice < 152.5 7   5.742 High ( 0.85714 0.14286 ) *
##                 477) CompPrice > 152.5 5   5.004 Low ( 0.20000 0.80000 ) *
##               239) Price > 147 7   0.000 Low ( 0.00000 1.00000 ) *
##        15) Advertising > 13.5 45  61.830 High ( 0.55556 0.44444 )  
##          30) Age < 54.5 25  25.020 High ( 0.80000 0.20000 )  
##            60) CompPrice < 130.5 14  18.250 High ( 0.64286 0.35714 )  
##             120) Income < 100 9  12.370 Low ( 0.44444 0.55556 ) *
##             121) Income > 100 5   0.000 High ( 1.00000 0.00000 ) *
##            61) CompPrice > 130.5 11   0.000 High ( 1.00000 0.00000 ) *
##          31) Age > 54.5 20  22.490 Low ( 0.25000 0.75000 )  
##            62) CompPrice < 122.5 10   0.000 Low ( 0.00000 1.00000 ) *
##            63) CompPrice > 122.5 10  13.860 Low ( 0.50000 0.50000 )  
##             126) Price < 125 5   0.000 High ( 1.00000 0.00000 ) *
##             127) Price > 125 5   0.000 Low ( 0.00000 1.00000 ) *

Test-Train

PLAN: evaluate overfitting on a given train-test partition

Simple recipe: (1) predict the classes of the test and training cases, (2) compare each set of predictions with the truth (confusion table), and (3) compare the two tables.

# try a test-train split
dim(Carseats)
## [1] 400  11
set.seed(2)
seat_idx = sample(1:nrow(Carseats), 200)
seat_trn = Carseats[seat_idx,]
seat_tst = Carseats[-seat_idx,]
seat_tree = tree(Sales ~ ., data = seat_trn)
summary(seat_tree)
## 
## Classification tree:
## tree(formula = Sales ~ ., data = seat_trn)
## Variables actually used in tree construction:
## [1] "Price"       "Population"  "ShelveLoc"   "Age"         "Education"  
## [6] "Income"      "US"          "CompPrice"   "Advertising"
## Number of terminal nodes:  21 
## Residual mean deviance:  0.5543 = 99.22 / 179 
## Misclassification error rate: 0.115 = 23 / 200
# display the predictors used in the tree: not all columns are used!
summary(seat_tree)$used
## [1] Price       Population  ShelveLoc   Age         Education   Income     
## [7] US          CompPrice   Advertising
## 11 Levels: <leaf> CompPrice Income Advertising Population Price ... US
# this is clearer
names(Carseats)[which(!(names(Carseats) %in% summary(seat_tree)$used))]
## [1] "Sales" "Urban"
# is the tree built on this training set the SAME as the tree built on all data?
# visual inspection tells you this is not the case

plot(seat_tree)
text(seat_tree, pretty = 0)
title(main = "Unpruned Classification Tree")

Confusion tables

seat_trn_pred = predict(seat_tree, seat_trn, type = "class")
seat_tst_pred = predict(seat_tree, seat_tst, type = "class")
#predict(seat_tree, seat_trn, type = "vector")  this would return not the predicted classes but the probabilities of both classes
#predict(seat_tree, seat_tst, type = "vector")

# confusion table of predictions of data in train
table(predicted = seat_trn_pred, actual = seat_trn$Sales)
##          actual
## predicted High Low
##      High   67   8
##      Low    14 111
# same for data in test
table(predicted = seat_tst_pred, actual = seat_tst$Sales)
##          actual
## predicted High Low
##      High   51  12
##      Low    32 105

Quiz: which of the two tables below comes from the training predictions, and which from the test predictions?

##          actual
## predicted High Low
##      High   57  29
##      Low    27  87


##          actual
## predicted High Low
##      High   66  10
##      Low    14 110
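Accuracy can be read straight off a confusion table, since the correct predictions sit on its diagonal. A sketch for the two quiz tables above (A is the first, B the second; these labels are ours):

```r
# The two quiz tables, entered by hand (rows = predicted, columns = actual)
tab_A <- matrix(c(57, 27, 29, 87), nrow = 2,
                dimnames = list(predicted = c("High", "Low"), actual = c("High", "Low")))
tab_B <- matrix(c(66, 14, 10, 110), nrow = 2,
                dimnames = list(predicted = c("High", "Low"), actual = c("High", "Low")))

# accuracy = correctly classified cases / all cases
accuracy_from_table <- function(tab) sum(diag(tab)) / sum(tab)

acc_A <- accuracy_from_table(tab_A)   # 0.72
acc_B <- accuracy_from_table(tab_B)   # 0.88
```

Since a tree usually fits its own training data better than new data, B (0.88) is most likely the training table and A (0.72) the test table.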

Now calculate the accuracies

# define a function
accuracy = function(actual, predicted) {
  mean(actual == predicted)
}

# run it on train
accuracy(predicted = seat_trn_pred, actual = seat_trn$Sales)
## [1] 0.89
# run it on test
accuracy(predicted = seat_tst_pred, actual = seat_tst$Sales)
## [1] 0.78

Conclusion: overfitting (accuracy 0.89 on train vs. 0.78 on test)

Cure? Pruning, to reduce overfitting

How? By penalizing large trees (too many leaves): cost-complexity pruning (CCP)

More precisely, how? Use cross-validation to choose the cost (penalty) charged per leaf
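To see what the per-leaf cost does, here is a sketch with the pruning sequence of the small worked example at the end of these notes: trees with 4, 3, 2 and 1 leaves and training misclassification rates 0, 2/16, 4/16 and 8/16 (the 2-leaf value is our own arithmetic from that example). For a fixed alpha we keep the tree minimizing error + alpha x (number of leaves); cross-validation is then what picks alpha:

```r
# Pruning sequence of the small worked example (illustration only)
leaves    <- c(4, 3, 2, 1)
train_err <- c(0, 2, 4, 8) / 16

# cost-complexity of each tree: training error plus alpha per leaf
cost <- function(alpha) train_err + alpha * leaves

# number of leaves of the cheapest tree for a given alpha
best_leaves <- function(alpha) leaves[which.min(cost(alpha))]

sapply(c(0, 0.1, 0.2, 0.3), best_leaves)   # larger alpha -> smaller tree
```

Note that the 3-leaf tree never wins outright: at alpha = 1/8 the 4-, 3- and 2-leaf trees all cost exactly 1/2, which is one way the "jumps in sizes" arise.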

CCP

(Figure: Dec. 28, 1987 - Landing of the Soyuz-TM3 craft carrying the Soyuz-TM2 crew Y. Romanenko, Aleksandrov & Levchenko)

Cross-validation tuning of leaf-cost

Apply the cross-validation recipe to find the optimal penalty hyperparameter

set.seed(3)
seat_tree_cv = cv.tree(seat_tree, FUN = prune.misclass)
# index of tree with minimum error
min_idx = which.min(seat_tree_cv$dev)
min_idx
## [1] 1
# number of terminal nodes in that tree
seat_tree_cv$size[min_idx]
## [1] 21
# misclassification rate of each tree
seat_tree_cv$dev / length(seat_idx)
## [1] 0.375 0.380 0.405 0.405 0.375 0.385 0.390 0.425 0.405
par(mfrow = c(1, 2))
# default plot
plot(seat_tree_cv)
# better plot
plot(seat_tree_cv$size, seat_tree_cv$dev / nrow(seat_trn), type = "b",
     xlab = "Tree Size", ylab = "CV Misclassification Rate")
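In the output above, which.min() picks the first minimum, and cv.tree() lists sizes from largest to smallest, so ties are resolved in favour of the biggest tree (positions 1 and 5 share the lowest rate, 0.375). A common alternative is to keep the smallest tree among the tied minima. A sketch of that rule (the helper name and the numbers are hypothetical):

```r
# Smallest tree among those tied at the minimum CV error (hypothetical helper)
smallest_best_size <- function(size, dev, tol = 1e-8) {
  tied <- which(abs(dev - min(dev)) <= tol)  # all positions at the minimum
  min(size[tied])                            # keep the simplest of them
}

# Hypothetical sizes and CV error counts, for illustration only
size <- c(21, 15, 12, 9, 5)
dev  <- c(75, 76, 81, 75, 80)

which.min(dev)                 # 1: the first minimum, i.e. the 21-leaf tree
smallest_best_size(size, dev)  # 9
```

With a real cv.tree() result the call would be smallest_best_size(seat_tree_cv$size, seat_tree_cv$dev).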

Apply the estimated hyperparameter to finalize our model (prune to target size)

seat_tree_prune = prune.misclass(seat_tree, best = 9)
summary(seat_tree_prune)
## 
## Classification tree:
## snip.tree(tree = seat_tree, nodes = c(13L, 15L, 29L, 2L))
## Variables actually used in tree construction:
## [1] "Price"      "ShelveLoc"  "Income"     "Age"        "CompPrice" 
## [6] "Population"
## Number of terminal nodes:  9 
## Residual mean deviance:  0.9135 = 174.5 / 191 
## Misclassification error rate: 0.175 = 35 / 200

Plot the optimized model

plot(seat_tree_prune)
text(seat_tree_prune, pretty = 0)
title(main = "Pruned Classification Tree")

Apply the train-test procedure to evaluate overfitting of the NEW, supposedly improved, model

seat_prune_trn_pred = predict(seat_tree_prune, seat_trn, type = "class")
table(predicted = seat_prune_trn_pred, actual = seat_trn$Sales)
##          actual
## predicted High Low
##      High   62  16
##      Low    19 103
accuracy(predicted = seat_prune_trn_pred, actual = seat_trn$Sales)
## [1] 0.825
seat_prune_tst_pred = predict(seat_tree_prune, seat_tst, type = "class")
table(predicted = seat_prune_tst_pred, actual = seat_tst$Sales)
##          actual
## predicted High Low
##      High   58  20
##      Low    25  97
accuracy(predicted = seat_prune_tst_pred, actual = seat_tst$Sales)
## [1] 0.775
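  • Training set: accuracy dropped from 0.89 to 0.825 (the pruned tree fits the training data less closely)
  • Test set: accuracy essentially unchanged (0.78 before, 0.775 now: one more error out of 200), so the train-test gap shrank from 0.11 to 0.05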

Tree models are prone to overfitting.

OVERALL CONCLUSION: CCP worked SOMEWHAT, but we need better methods!

IDEAS (to be introduced in the next lectures):

  • Bagging

  • Boosting

  • Random Forests

Question for the curious

Why doesn't cross-validation drag us into an infinite-loop hell?

Using a test set to assess the misclassification rate

(src: Wikipedia)

Cost-Complexity Pruning

Post-pruning algorithm for Decision Trees


Cost-Complexity Function
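The definition itself did not survive in these notes. A standard way to write it (the notation is our assumption, chosen to match the worked example below: \(R(T)\) is the training misclassification rate of tree \(T\), \(f(T)\) its set of leaves, \(\alpha\) the cost charged per leaf):

```latex
R_\alpha(T) = R(T) + \alpha\,\lvert f(T) \rvert, \qquad \alpha \ge 0
```

With \(\alpha = 0\) the full tree is a minimizer; as \(\alpha\) grows, every leaf costs more and smaller subtrees win, until only the root is left.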


Pruning Subtrees

Subtrees:

Pruning a subtree \(T_{t}\)
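The comparison behind pruning fits in one line. Treat node \(t\) either as a leaf or keep its branch \(T_t\), score both with \(R_\alpha(T) = R(T) + \alpha\,|f(T)|\) (assumed notation: \(f(T)\) is the set of leaves), and ask for which \(\alpha\) they cost the same:

```latex
\begin{aligned}
R_\alpha(t)   &= R(t) + \alpha \\
R_\alpha(T_t) &= R(T_t) + \alpha\,\lvert f(T_t) \rvert \\
R_\alpha(t) = R_\alpha(T_t) &\iff \alpha = g(t) = \frac{R(t) - R(T_t)}{\lvert f(T_t) \rvert - 1}
\end{aligned}
```

For \(\alpha < g(t)\) the branch is cheaper and is kept; from \(\alpha = g(t)\) on, pruning \(T_t\) is at least as good. This \(g(t)\) is the quantity tabulated in the worked example below.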



Algorithm

Pruning Algorithm:
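The algorithm's steps did not survive extraction. A minimal base-R sketch of weakest-link pruning (our own code, not the original figure): repeatedly compute \(g(t)\) for every internal node, collapse the node with the smallest \(g(t)\), and record that value as the next \(\alpha\). The tree is a hand-coded version of the one in the worked example below (16 records; node errors read off its iteration tables); the node names a to d and the data-frame representation are hypothetical. Ties are broken toward the smaller branch so the result follows the example; another common convention is to collapse all tied nodes in one step.

```r
# One row per node; R = error of the node if it were a leaf, as a fraction of N = 16
nodes <- data.frame(
  id     = c("t1", "t2", "t3", "a", "b", "c", "d"),
  parent = c(NA, "t1", "t2", "t3", "t3", "t2", "t1"),
  R      = c(8, 4, 2, 0, 0, 0, 0) / 16,
  stringsAsFactors = FALSE
)

# all node ids in the subtree rooted at `node` (including `node` itself)
subtree_ids <- function(nodes, node) {
  kids <- nodes$id[which(nodes$parent == node)]
  c(node, unlist(lapply(kids, function(k) subtree_ids(nodes, k))))
}
is_leaf <- function(nodes, node) !any(nodes$parent == node, na.rm = TRUE)
branch_leaves <- function(nodes, node) {
  ids <- subtree_ids(nodes, node)
  ids[vapply(ids, function(s) is_leaf(nodes, s), logical(1))]
}

# g(t) = (R(t) - R(T_t)) / (number of leaves of T_t - 1)
g_value <- function(nodes, node) {
  lv <- branch_leaves(nodes, node)
  (nodes$R[nodes$id == node] - sum(nodes$R[nodes$id %in% lv])) / (length(lv) - 1)
}

weakest_link_pruning <- function(nodes) {
  pruned <- character(0); alpha <- numeric(0); left <- integer(0)
  repeat {
    internal <- nodes$id[!vapply(nodes$id, function(s) is_leaf(nodes, s), logical(1))]
    if (length(internal) == 0) break  # only the root is left
    g  <- vapply(internal, function(s) g_value(nodes, s), numeric(1))
    nl <- vapply(internal, function(s) length(branch_leaves(nodes, s)), numeric(1))
    # weakest link: smallest g(t); ties broken toward the branch with fewer leaves
    weakest <- internal[order(round(g, 12), nl)[1]]
    # collapse the branch: drop every descendant, keep `weakest` as a leaf
    nodes <- nodes[!nodes$id %in% setdiff(subtree_ids(nodes, weakest), weakest), ]
    pruned <- c(pruned, weakest)
    alpha  <- c(alpha, min(g))
    left   <- c(left, sum(vapply(nodes$id, function(s) is_leaf(nodes, s), logical(1))))
  }
  data.frame(pruned = pruned, alpha = alpha, leaves_left = left, stringsAsFactors = FALSE)
}

path <- weakest_link_pruning(nodes)
path
```

The printed path should read t3, t2, t1 with alpha = 1/8, 1/8, 1/4 and 3, 2, 1 leaves left.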


Output:


Choosing \(\alpha\)

The algorithm outputs \(\alpha^1 \leq \alpha^2 \leq \ldots \leq \alpha^k \leq \ldots\)

How to choose \(\alpha\)
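A standard property of this sequence (Breiman 1984), stated here without proof: write \(T^k\) for the tree pruned at \(\alpha^k\), \(f(T)\) for the set of leaves of \(T\), and \(T(\alpha)\) for the smallest pruned subtree of the full tree \(T^1\) that minimizes \(R(T) + \alpha\,|f(T)|\) (this notation is ours). The optimal tree then only changes at the grid points:

```latex
T(\alpha) = T^k \qquad \text{for } \alpha^k \le \alpha < \alpha^{k+1}
```

So choosing \(\alpha\) amounts to choosing among the finitely many trees \(T^1, T^2, \ldots\), for example by cross-validation or on a test set. In the R outputs above, this grid is the $k component, and $size gives the number of leaves of the matching tree.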


Example

Example 1

Suppose we have the following tree:

Some formulas:


Iteration 1:

| \(t\) | \(R(t)\) | \(R(T_t)\) | \(g(t)\) |
|---|---|---|---|
| \(t_1\) | \(\cfrac{8}{16} \cdot \cfrac{16}{16} = \cfrac{8}{16}\) | \(R(T_{t_1}) = 0\) (\(T_{t_1}\) is the entire tree; all leaves are pure) | \(\cfrac{8/16 - 0}{4 - 1} = \cfrac{1}{6}\) |
| \(t_2\) | \(\cfrac{4}{12} \cdot \cfrac{12}{16} = \cfrac{4}{16}\) (12 records: 4 \(\blacksquare\) + 8 \(\bigcirc\)) | \(R(T_{t_2}) = 0\) | \(\cfrac{4/16 - 0}{3 - 1} = \cfrac{1}{8}\) |
| \(t_3\) | \(\cfrac{2}{6} \cdot \cfrac{6}{16} = \cfrac{2}{16}\) | \(R(T_{t_3}) = 0\) | \(\cfrac{2/16 - 0}{2 - 1} = \cfrac{1}{8}\) |

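We want to find the minimal \(g(t)\): here \(t_2\) and \(t_3\) tie at \(\cfrac{1}{8}\), and the example prunes \(t_3\) (the smaller branch), as the Iteration 2 table shows.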

(Figure: decision-tree-pruning-ex1-2.png)


Iteration 2:


| \(t\) | \(R(t)\) | \(R(T_t)\) | \(g(t)\) |
|---|---|---|---|
| \(t_1\) | \(\cfrac{8}{16} \cdot \cfrac{16}{16}\) | \(\cfrac{2}{16}\) | \(\cfrac{8/16 - 2/16}{3 - 1} = \cfrac{6}{32}\) |
| \(t_2\) | \(\cfrac{4}{12} \cdot \cfrac{12}{16}\) | \(\cfrac{2}{16}\) | \(\cfrac{4/16 - 2/16}{2 - 1} = \cfrac{1}{8}\) |


Find minimal \(g(t)\):
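Comparing the two values of the Iteration 2 table (\(6/32 = 3/16\) versus \(1/8 = 2/16\)):

```latex
\min\{g(t_1),\, g(t_2)\} = \min\left\{\cfrac{6}{32},\, \cfrac{1}{8}\right\} = \cfrac{1}{8} = g(t_2)
\;\Longrightarrow\; \text{prune } t_2 \text{ at } \alpha = \cfrac{1}{8}
```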


Iteration 3:
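The figure for this step is missing; continuing with the same formulas (our own arithmetic): once \(t_2\) is a leaf, only \(t_1\) is still internal, and its two leaves are \(t_2\) (error \(4/16\)) and the other, pure, child of \(t_1\):

```latex
\begin{aligned}
R(T_{t_1}) &= \cfrac{4}{16} + 0 = \cfrac{4}{16} \\
g(t_1) &= \cfrac{8/16 - 4/16}{2 - 1} = \cfrac{1}{4}
\end{aligned}
```

Pruning \(t_1\) leaves only the root, so the full sequence of \(\alpha\) values is \(0, \tfrac{1}{8}, \tfrac{1}{8}, \tfrac{1}{4}\).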


Selecting the best:


Example 2

From the IT4BI 2013 exam: