Theme of activity

MNIST digit prediction

Why is the MNIST digit prediction problem popular?

Answer: because of the simplicity and importance of the task, and because of the role it played in the emergence of deep learning models.

LeCun 1989 (7000 citations): goal: decipher handwritten ZIP codes

Fundamental idea: CONSTRAIN THE LEARNING NETWORK to tune its complexity

Deep learning rests on the recognition of the role of symmetries in data (my opinion anyway).

We are NOT going to use neural networks here. Instead, we will use (unconstrained) classical prediction methods, including random forests. This will serve as a basis for comparing the methods learned in this course with the more advanced methods (deep networks) that you will see or use in the future.

Our goal is to learn how to load the data, and fit and compare various models of prediction in R.

The data

Images (original 28x28, here downsampled to 16x16) of handwritten digits (0-9)

Load the data:

set.seed(1)
# The files are pipe-separated (.psv) with no header row
train <- read.csv("../Data/mnist_train.psv", sep="|", as.is=TRUE, header=FALSE)
test <- read.csv("../Data/mnist_test.psv", sep="|", as.is=TRUE, header=FALSE)

257 columns:

column 1: digit class (the response)

column i+1: intensity, in (-1,1), of pixel number i (i = 1, ..., 256)

dim(train)
## [1] 7291  257
train[1:10,1:10]
##    V1 V2 V3 V4     V5     V6     V7     V8     V9    V10
## 1   6 -1 -1 -1 -1.000 -1.000 -1.000 -1.000 -0.631  0.862
## 2   5 -1 -1 -1 -0.813 -0.671 -0.809 -0.887 -0.671 -0.853
## 3   4 -1 -1 -1 -1.000 -1.000 -1.000 -1.000 -1.000 -1.000
## 4   7 -1 -1 -1 -1.000 -1.000 -0.273  0.684  0.960  0.450
## 5   3 -1 -1 -1 -1.000 -1.000 -0.928 -0.204  0.751  0.466
## 6   6 -1 -1 -1 -1.000 -1.000 -0.397  0.983 -0.535 -1.000
## 7   3 -1 -1 -1 -0.830  0.442  1.000  1.000  0.479 -0.328
## 8   1 -1 -1 -1 -1.000 -1.000 -1.000 -1.000  0.510 -0.213
## 9   0 -1 -1 -1 -1.000 -1.000 -0.454  0.879 -0.745 -1.000
## 10  1 -1 -1 -1 -1.000 -1.000 -1.000 -1.000 -0.909  0.801
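
A quick sanity check of the class balance (this check is an addition, not part of the original code):

# How many training images of each digit?
table(train[,1])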

Use rasterImage to plot images

y <- matrix(as.matrix(train[3400,-1]),16,16,byrow=TRUE)  # pixels of observation 3400
y <- 1 - (y + 1)*0.5  # map intensities from (-1,1) to grayscale in (0,1): ink dark, background white

plot(0,0)                  # open a plot region
rasterImage(y,-1,-1,1,1)   # draw the 16x16 image over it

Plot a grid of 48 observations.

ngr <- 6   # grid rows
ngc <- 8   # grid columns
iset <- sample(1:nrow(train), ngr*ngc)   # pick 48 training images at random
par(mar=c(0,0,0,0))
par(mfrow=c(ngr,ngc))
for (j in iset) {
  y <- matrix(as.matrix(train[j,-1]),16,16,byrow=TRUE)
  y <- 1 - (y + 1)*0.5

  plot(0,0,xlab="",ylab="",axes=FALSE)
  rasterImage(y,-1,-1,1,1)
  box()
  text(-0.8,-0.7, train[j,1], cex=3, col="red")   # overlay the true label
}

Separate the predictors from the response:

Xtrain <- as.matrix(train[,-1])
Xtest <- as.matrix(test[,-1])
ytrain <- train[,1]
ytest <- test[,1]

We are ready to build prediction models.

Prediction by random forest

We start by fitting a random forest model. It will run significantly faster if I restrict the maximum number of terminal nodes per tree.

library(randomForest)
## randomForest 4.6-14
## Type rfNews() to see new features/changes/bug fixes.
outRf1 <- randomForest(Xtrain, factor(ytrain), maxnodes=10)  # at most 10 terminal nodes per tree
predRf1 <- predict(outRf1, Xtest)

Let us calculate the misclassification rate on the test set:

mean(predRf1 != ytest)
## [1] 0.2690583
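
To see which digits get confused, we can cross-tabulate predictions against the truth (a quick diagnostic, not part of the original analysis):

# Confusion matrix: rows = predicted digit, columns = true digit
table(predicted = predRf1, actual = ytest)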

Is this good or is this bad? Maybe we can do better by making a more appropriate choice of the hyperparameters. Here I will only play with ‘maxnodes’.

outRf2 <- randomForest(Xtrain,  factor(ytrain), maxnodes=20)
predRf2 <- predict(outRf2, Xtest)
outRf3 <- randomForest(Xtrain,  factor(ytrain), maxnodes=30)
predRf3 <- predict(outRf3, Xtest)

How well did we do? The test misclassification rates are roughly 0.26 for maxnodes=10, 0.14 for maxnodes=20, and 0.128 for maxnodes=30.
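
To explore the hyperparameter more systematically, one could loop over a grid of values, as in the following sketch (the grid, including the extra value 40, is my choice, and the errors will vary slightly between runs because the forest is random):

# Fit a forest for each candidate maxnodes value and report its test error
for (m in c(10, 20, 30, 40)) {
  fit <- randomForest(Xtrain, factor(ytrain), maxnodes=m)
  cat("maxnodes =", m, " test error =", round(mean(predict(fit, Xtest) != ytest), 3), "\n")
}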

The accuracy seems to saturate. Assuming that this model is the one we want to keep, the question remains: can we do better?

We need to run another model to answer this question.

Let us try a simple alternative: the k-nearest neighbors (K-nn) model

Fit a K-nearest neighbors model (alternative model 1)

As a simple model, we can use k-nearest neighbors. I set \(k\) equal to three.

\(k=3\) means that we select the class of the closest point as the predicted response, except when the next two closest points give the same class label. In the latter case, we vote for their common label, not for the label of the closest point. (If all three labels differ, there is no majority and the tie must be broken, e.g. at random.)

library(FNN)
predKnn <- knn(Xtrain, Xtest, ytrain, k=3)  # classify each test image by its 3 nearest training images

Let’s calculate the misclassification rate of K-nn on the test set (the same set as before!):

mean(predKnn != ytest)
## [1] 0.05530643

Oh boy!
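
Before drawing conclusions, we might check that \(k=3\) was not just a lucky choice (a small sweep; the values of k are chosen here for illustration):

# Test error of K-nn for a few values of k
for (k in c(1, 3, 5, 7)) {
  cat("k =", k, " test error =", round(mean(knn(Xtrain, Xtest, ytrain, k=k) != ytest), 3), "\n")
}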

Lesson learned?

  • Random forest will not necessarily do the best job!

  • To consolidate the comparison between the two models, we should of course do a more extensive validation check. I leave this to you, e.g. by repeating this exercise on other train-test random splits (a sketch is given below).
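
Here is one way such a check could look (a minimal sketch, assuming randomForest and FNN are already loaded; the number of repetitions and the hyperparameter values are kept small and fixed to save time):

# Compare the two methods over several random train-test splits
dat <- rbind(train, test)
nrep <- 5
errs <- matrix(NA, nrep, 2, dimnames=list(NULL, c("rf", "knn")))
for (r in 1:nrep) {
  idx <- sample(nrow(dat), nrow(train))            # keep the original training-set size
  Xtr <- as.matrix(dat[idx, -1]);  ytr <- dat[idx, 1]
  Xte <- as.matrix(dat[-idx, -1]); yte <- dat[-idx, 1]
  rf <- randomForest(Xtr, factor(ytr), maxnodes=30)
  errs[r, "rf"]  <- mean(predict(rf, Xte) != yte)
  errs[r, "knn"] <- mean(knn(Xtr, Xte, ytr, k=3) != yte)
}
colMeans(errs)   # average test error of each method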

Summary