How to Use Cross-Validation for a KNN Classification Model in R

Jesse John Feb 02, 2024
  1. Different Cross-Validation Approaches
  2. Repeated K-Fold Cross Validation for a K-Nearest Neighbor Classification Model

Cross-validation allows us to assess a model’s performance on new data even though we only have the training data set. It is a general technique that can apply to regression and classification models.

This article will discuss how to perform repeated k-fold cross-validation for a K-Nearest Neighbor (KNN) classification model. We will employ the caret package for this purpose.

The K in KNN refers to the number of nearest neighbors considered for each observation. On the other hand, the k in k-fold is the number of subsets (folds) of the training data.

Different Cross-Validation Approaches

There are different approaches to cross-validation.

The most basic version, called the validation set approach, uses a single subset of the training data to validate the model. The model is fit only once and then tested on that subset.
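As a minimal sketch of this approach in base R (the data frame df and the 25% split are hypothetical placeholders, not part of the article's example):

# Sketch of the validation set approach; df is a hypothetical data frame.
set.seed(1)
val_rows = sample(nrow(df), size = floor(0.25 * nrow(df))) # hold out 25% for validation
train_df = df[-val_rows, ] # fit the model once on this subset
val_df   = df[val_rows, ]  # then test it on this held-out subset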

Another approach involves fitting as many models as there are observations and taking the average error rate. Each model is fit with one observation left out; the model is then tested on that one observation.

This is called Leave-One-Out Cross-Validation (LOOCV).
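In caret, which we use later in this article, LOOCV can be requested through the trainControl() function; a minimal sketch (predictors and outcome are hypothetical placeholders for your data):

# Sketch: LOOCV with caret; predictors and outcome are hypothetical placeholders.
library(caret)
ctrl_loocv = trainControl(method = "LOOCV")
# mod = train(x = predictors, y = outcome, method = "knn", trControl = ctrl_loocv)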

The most widely used approach involves:

  1. Splitting the training data set into k folds (groups),
  2. Fitting the model k times, each time leaving out a different fold, and
  3. Testing each fitted model on the fold that was left out.

This is called k-fold cross-validation. Usually, a k value of 5 or 10 gives good results.
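A minimal sketch of how plain k-fold cross-validation is specified in caret (the full train() call appears later in the article):

# Sketch: 10-fold cross-validation with caret.
ctrl_cv = trainControl(method = "cv", number = 10)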

An enhancement to the k-fold cross-validation involves fitting the k-fold cross-validation model several times with different splits of the folds. This is called the repeated k-fold cross-validation, which we will use.
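The corresponding caret sketch only changes the method and adds a repeats argument; we will build this control object in full in the example below:

# Sketch: 10-fold cross-validation repeated 3 times.
ctrl_rcv = trainControl(method = "repeatedcv", number = 10, repeats = 3)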

Repeated K-Fold Cross Validation for a K-Nearest Neighbor Classification Model

We will now see how to fit a K-Nearest Neighbor (KNN) classification model using repeated k-fold cross-validation. We will be using the caret package.

The caret package is highly versatile and can be used to build several types of models. Refer to its documentation on CRAN for more details.

Usually, k-fold cross-validation only tells us how accurate our model is expected to be on new data.

For example, suppose we fit a KNN model with K = 5 neighbors using k-fold cross-validation with 10 folds, repeated 3 times. The model will be fit 10 times for each of the 3 different data splits (30 fits in all), and we will get performance metrics for only one model: the one with K = 5 neighbors.

In addition to the above, the caret package allows us to fit KNN models for different values of K. It then reports the value of K that results in the best model and fits that model for us.

The createDataPartition() function creates a stratified random split of a factor vector. We will use this to separate our data into training and testing subsets to verify the model’s accuracy.

The train() function is the main function to create a model, where:

  1. x is the data frame with the predictors.
  2. y is the vector of outcomes (a factor for classification).
  3. The method argument takes the type of model we want to build. We will specify knn.
  4. For the preProcess argument, we will specify center and scale.
  5. The trControl argument allows us to specify the specifics of the cross-validation procedure.
  6. The tuneGrid argument helps create and compare multiple models. It takes a data frame whose column is named after the tuning parameter and whose rows hold the candidate values.

Because we are building a KNN model, we will give k in lowercase as the tuning parameter to tuneGrid. We will provide a vector of K values from 1 to 12, for which we want the function to create and test models.

The specifics of the cross-validation are passed to the trControl argument using the trainControl() function.

  1. For the method argument, we will specify repeatedcv because we want repeated cross-validation.
  2. When the method is cv or repeatedcv, the number argument specifies the number of folds, k. We will use 10.
  3. The repeats argument specifies how many times the k-fold split must be repeated.

Example Code:

# Create the data vectors for the demonstration.
# We will create two numeric vectors as predictors.
# Each vector will have two distinct groups to suit our model.
# We will create a factor with two levels.
# The factor levels correspond to the groups in the predictor vectors.

set.seed(564)
vX1a = round(rnorm(100, 2, 2)) + 4
set.seed(574)
vX2a = round(rnorm(100, 15, 4))

set.seed(584)
vX1b = round(rnorm(100, 10, 3)) + 5
set.seed(594)
vX2b = round(rnorm(100, 5, 4))

vYa = rep("Blue", 100)
vYb = rep("Red", 100)

vX1 = c(vX1a, vX1b)
vX2 = c(vX2a, vX2b)
vY = c(vYa, vYb)

# Dummy column for ordering rows.
set.seed(528)
R = sample(1:200,200)

# Temporary data frame.
temp_df = data.frame(X1 = vX1, X2 = vX2, Y = as.factor(vY), R)

# Packages that we will use.
library(ggplot2)
library(dplyr)

# See the sample data.
temp_df %>% ggplot(aes(x = X1, y = X2, colour = Y)) + geom_point()

# Re-order the rows, just to see that the KNN model works with the rows jumbled up.
# Final data frame.
# Notice that the outcome column, Y, is a factor.
fin_df = temp_df %>% arrange(R) %>% select(X1, X2, Y)
head(fin_df)
str(fin_df)

# Install the caret package if it is not already installed.
# To install, uncomment the next line and run it.
# install.packages("caret")

# Load the caret package.
library(caret)

# Split the data frame into a training set and test set.
# Create a list of row numbers in the training set.
# This function creates a stratified random sample of all the outcome classes.
set.seed(365)
training_row_index = createDataPartition(fin_df[,3], p=0.75, list=FALSE)

# Create training sets of the predictors and the corresponding outcomes.
trg_data = fin_df[training_row_index,1:2]
trg_class = fin_df[training_row_index,3]

# Create the test set of predictors and the outcomes that we will later use.
tst_data = fin_df[-training_row_index,1:2]
tst_class = fin_df[-training_row_index,3]

# Let us check if the sample is really stratified:
table(tst_class)
# The training set will contain the remaining observations of each class.

# We will build a K-Nearest neighbors model using repeated k-fold cross-validation.
# The arguments are described in the article.
mod_knn = train(x = trg_data,
                y = trg_class,
                method = "knn",
                preProcess = c("center", "scale"),
                tuneGrid = data.frame(k = c(1:12)),
                trControl = trainControl(method = "repeatedcv",
                                         number = 10,
                                         repeats = 3)
                )

# View the fitted model.
mod_knn

Output:

> head(fin_df)
  X1 X2    Y
1 15  6  Red
2 15  2  Red
3 14  3  Red
4  4 22 Blue
5 20 -3  Red
6  4 22 Blue

> str(fin_df)
'data.frame':	200 obs. of  3 variables:
 $ X1: num  15 15 14 4 20 4 15 2 20 13 ...
 $ X2: num  6 2 3 22 -3 22 7 16 9 -6 ...
 $ Y : Factor w/ 2 levels "Blue","Red": 2 2 2 1 2 1 2 1 2 2 ...

> # Let us check if the sample is really stratified:
> table(tst_class)
tst_class
Blue  Red
  25   25

> # View the fitted model.
> mod_knn
k-Nearest Neighbors

150 samples
  2 predictor
  2 classes: 'Blue', 'Red'

Pre-processing: centered (2), scaled (2)
Resampling: Cross-Validated (10 fold, repeated 3 times)
Summary of sample sizes: 135, 136, 135, 135, 135, 134, ...
Resampling results across tuning parameters:

  k   Accuracy   Kappa
   1  0.9710317  0.9420024
   2  0.9736905  0.9473602
   3  0.9753373  0.9505141
   4  0.9842460  0.9683719
   5  0.9864683  0.9728764
   6  0.9843849  0.9687098
   7  0.9843849  0.9687098
   8  0.9800794  0.9600386
   9  0.9800794  0.9600386
  10  0.9800794  0.9600386
  11  0.9800794  0.9600386
  12  0.9800794  0.9600386

Accuracy was used to select the optimal model using the largest value.
The final value used for the model was k = 5.

We find that the best model uses K = 5.
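The tuning results can also be inspected directly from the fitted train object, for example:

# Plot accuracy against the number of neighbors, K.
plot(mod_knn)

# The chosen tuning parameter is stored in the model object.
mod_knn$bestTune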

Let us now use the base R predict() function and a confusion matrix created using the table() function to check the model’s accuracy on the test data.

Example Code:

# Use model to predict classes for the test set.
pred_cls = predict(mod_knn, tst_data)

# Check the accuracy of the predictions by computing the confusion matrix.
table(Actl = tst_class, Pred = pred_cls)

Output:

> table(Actl = tst_class, Pred = pred_cls)
      Pred
Actl   Blue Red
  Blue   25   0
  Red     0  25

We find that the model predicted the test data classes with full accuracy. This was possible because the two classes were well-separated in our simulated data.

In practice, the accuracy will be lower. However, for each model, the repeated k-fold cross-validation procedure gives us a good idea of the accuracy we can expect on new data similar to the training data.
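If a single accuracy figure is preferred over the full confusion matrix, it can be computed directly from the predictions (a small addition to the code above):

# Proportion of test observations classified correctly.
mean(pred_cls == tst_class)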

Author: Jesse John

Jesse is passionate about data analysis and visualization. He uses the R statistical programming language for all aspects of his work.