Why is the MNIST digit prediction problem popular?
Answer: because of the simplicity and practical importance of the task, and because of the role it played in the emergence of deep learning models.
LeCun (1989) (7000 citations): goal: decipher handwritten ZIP codes.
Fundamental idea: CONSTRAIN THE LEARNING NETWORK to tune its complexity
Deep learning rests on the recognition of the role of symmetries in data (my opinion anyway).
We are NOT going to use neural networks here. Instead we will use (unconstrained) classical prediction methods (including Random Forest), and this will serve as a basis for a comparison between the methods learned in this course and more advanced methods (deep networks) that you will see or use in the future.
Our goal is to learn how to load the data, and fit and compare various models of prediction in R.
Images (original 28x28, here downsampled to 16x16) of handwritten digits (0-9)
Load the data:
set.seed(1)
train <- read.csv("../Data/mnist_train.psv", sep="|", as.is=TRUE, header=FALSE)
test <- read.csv("../Data/mnist_test.psv", sep="|", as.is=TRUE, header=FALSE)
257 columns:
column 1: digit class (response)
column i+1: pixel intensity in [-1, 1] of pixel number i
dim(train)
## [1] 7291 257
train[1:10,1:10]
## V1 V2 V3 V4 V5 V6 V7 V8 V9 V10
## 1 6 -1 -1 -1 -1.000 -1.000 -1.000 -1.000 -0.631 0.862
## 2 5 -1 -1 -1 -0.813 -0.671 -0.809 -0.887 -0.671 -0.853
## 3 4 -1 -1 -1 -1.000 -1.000 -1.000 -1.000 -1.000 -1.000
## 4 7 -1 -1 -1 -1.000 -1.000 -0.273 0.684 0.960 0.450
## 5 3 -1 -1 -1 -1.000 -1.000 -0.928 -0.204 0.751 0.466
## 6 6 -1 -1 -1 -1.000 -1.000 -0.397 0.983 -0.535 -1.000
## 7 3 -1 -1 -1 -0.830 0.442 1.000 1.000 0.479 -0.328
## 8 1 -1 -1 -1 -1.000 -1.000 -1.000 -1.000 0.510 -0.213
## 9 0 -1 -1 -1 -1.000 -1.000 -0.454 0.879 -0.745 -1.000
## 10 1 -1 -1 -1 -1.000 -1.000 -1.000 -1.000 -0.909 0.801
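Before going further, it is worth a quick look at the class balance in the training set:
table(train[,1])   # number of training images per digit class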
Use rasterImage to plot the images:
# Reshape one row (dropping the label in column 1) into a 16x16 image matrix
y <- matrix(as.matrix(train[3400,-1]),16,16,byrow=TRUE)
# Map intensities from [-1,1] to grayscale in [0,1], so that ink appears dark
y <- 1 - (y + 1)*0.5
plot(0,0)
rasterImage(y,-1,-1,1,1)
Plot a grid of 48 observations.
ngr <- 6   # number of grid rows
ngc <- 8   # number of grid columns
iset <- sample(1:nrow(train), ngr*ngc)   # draw 48 random training images
par(mar=c(0,0,0,0))
par(mfrow=c(ngr,ngc))
for (j in iset) {
  y <- matrix(as.matrix(train[j,-1]), 16, 16, byrow=TRUE)
  y <- 1 - (y + 1)*0.5
  plot(0, 0, xlab="", ylab="", axes=FALSE)
  rasterImage(y, -1, -1, 1, 1)
  box()
  text(-0.8, -0.7, train[j,1], cex=3, col="red")   # overlay the true label in red
}
Separate the predictors from the response:
Xtrain <- as.matrix(train[,-1])
Xtest <- as.matrix(test[,-1])
ytrain <- train[,1]
ytest <- test[,1]
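A quick sanity check on the split (a minimal sketch using base R):
dim(Xtrain); dim(Xtest)           # one row per image, 256 pixel columns
length(ytrain); length(ytest)     # one digit label per image
range(Xtrain)                     # pixel intensities should lie in [-1, 1]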
We are ready to build prediction models.
We can also run a random forest model. It will run significantly faster if I restrict the maximum number of terminal nodes per tree (maxnodes).
library(randomForest)
## randomForest 4.6-14
## Type rfNews() to see new features/changes/bug fixes.
outRf1 <- randomForest(Xtrain, factor(ytrain), maxnodes=10)
predRf1 <- predict(outRf1, Xtest)
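Before touching the test set, note that the fitted object already carries an out-of-bag (OOB) error estimate, which comes for free with bagging:
outRf1   # printing the fit shows the OOB estimate of the error rate and a confusion matrix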
Let us calculate the misclassification rate on the test set:
mean(predRf1 != ytest)
## [1] 0.2690583
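For scale, here is a quick sketch of the trivial baseline that always predicts the most frequent training digit (the variable name major is mine):
# Baseline: constant prediction of the most common training label
major <- names(which.max(table(ytrain)))
mean(ytest != major)   # error rate of the constant prediction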
Is this good or is this bad? Maybe we can do better by making a more appropriate choice of the hyperparameters. Here I will only play with ‘maxnodes’.
outRf2 <- randomForest(Xtrain, factor(ytrain), maxnodes=20)
predRf2 <- predict(outRf2, Xtest)
outRf3 <- randomForest(Xtrain, factor(ytrain), maxnodes=30)
predRf3 <- predict(outRf3, Xtest)
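We compute the test error of each fit as before (output omitted; the numbers are summarized below):
mean(predRf2 != ytest)   # maxnodes = 20
mean(predRf3 != ytest)   # maxnodes = 30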
How well did we do? Test error: 0.269 with maxnodes=10, 0.14 with maxnodes=20, 0.128 with maxnodes=30.
The improvement seems to saturate as maxnodes grows. Assuming this is the model we want to keep, the question remains: can we do better?
We need to run another model to answer this question.
Let us try a simple alternative: the k-nearest neighbors (K-nn) model.
I set \(k=3\): the predicted class is the majority vote among the three closest training points. Equivalently, we predict the label of the closest point except when the next two closest points agree on a different label, in which case their common label wins the vote. For example, if the three nearest neighbors are labeled 3, 8, 8, we predict 8 even though the closest point is a 3.
library(FNN)
predKnn <- knn(Xtrain,Xtest,ytrain,k=3)
Let’s calculate the misclassification rate of K-nn on the test set (the same set as before!):
mean(predKnn != ytest)
## [1] 0.05530643
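To see which digits get confused, we can tabulate predictions against the truth:
# Confusion matrix: rows = predicted digit, columns = true digit
table(predicted=predKnn, truth=ytest)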
Oh boy!
Lesson learned?
Random forest will not necessarily do the best job!
To consolidate the comparison between the two models, we should of course do a more extensive validation check. I leave this to you, e.g. by repeating the exercise on other random train-test splits, as sketched below.
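A minimal sketch of such a check, assuming we are allowed to pool the original train and test sets and re-split them at random (B = 5 repetitions and the object names are arbitrary choices of mine):
# Repeat the RF vs K-nn comparison over B random train/test splits
full <- rbind(train, test)
B <- 5
errs <- matrix(NA, B, 2, dimnames=list(NULL, c("rf","knn")))
for (b in 1:B) {
  idx <- sample(1:nrow(full), nrow(train))          # keep the original training size
  Xtr <- as.matrix(full[idx,-1]);  ytr <- full[idx,1]
  Xte <- as.matrix(full[-idx,-1]); yte <- full[-idx,1]
  fit <- randomForest(Xtr, factor(ytr), maxnodes=30)
  errs[b,"rf"]  <- mean(predict(fit, Xte) != yte)
  errs[b,"knn"] <- mean(knn(Xtr, Xte, ytr, k=3) != yte)
}
colMeans(errs)   # average test error of each method over the B splits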
We have learned how to load, plot, and classify images in R.
Can you do the same with handwritten letters? (Search for the EMNIST dataset.)