K-Fold Cross-Validation in Python

Manav Narula Oct 10, 2023
  1. Need for Cross-Validation in Python
  2. K-Fold Cross-Validation in Python
  3. Use the sklearn.model_selection.KFold Class to Implement K-Fold in Python

Python offers many machine learning algorithms, covering both supervised and unsupervised learning. Before a model is used on real-world data, it is trained and tested on a dataset.

Need for Cross-Validation in Python

The traditional approach to training and testing a model splits the dataset into a training set and a test set. A common train-to-test ratio is 70:30.
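As a minimal sketch of such a split, the example below uses scikit-learn's train_test_split function on a small made-up array. The toy data, the variable names, and the random_state value are assumptions chosen only for illustration.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy data (made up): 10 samples with 2 features each, plus binary labels.
X = np.arange(20).reshape(10, 2)
y = np.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1])

# test_size=0.3 holds out 30% of the samples; random_state fixes the shuffle.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0
)

print(X_train.shape, X_test.shape)  # (7, 2) (3, 2)
```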

Splitting the dataset directly has some drawbacks.

The main drawback is overfitting. To get the best performance from a model, we tune its hyperparameters to improve its score on the test data.

However, each round of tuning leaks knowledge of the test data into the model, increasing the risk of overfitting to the test set.

To counter this, people started splitting the data into three parts: training, validation, and test datasets.

We tune the hyperparameters using the training and validation datasets and keep the test dataset for the final evaluation. However, this significantly reduces the number of samples available for training the model.
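To make the cost of a three-way split concrete, here is a rough sketch that calls train_test_split twice. The 60/20/20 proportions and the toy data are assumptions for illustration and do not come from a particular workflow.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy data (made up): 100 samples with one feature and alternating labels.
X = np.arange(100).reshape(100, 1)
y = np.arange(100) % 2

# First hold out 20% of the data as the test set.
X_rest, X_test, y_rest, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0
)
# Then hold out 25% of the remaining 80 samples as the validation set.
X_train, X_val, y_train, y_val = train_test_split(
    X_rest, y_rest, test_size=0.25, random_state=0
)

print(len(X_train), len(X_val), len(X_test))  # 60 20 20
```

Only 60 of the 100 samples are left for fitting the model here.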

More recently, K-Fold cross-validation has become a common way to address these problems. It is an essential technique because it helps us tune a model and choose the configuration with the best performance.

K-Fold Cross-Validation in Python

By default, the data is split without shuffling into K consecutive folds. Every fold is used once for validation, while the remaining K - 1 folds form the training set.

In short, in each round one part of the training set is held out for validation.
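The splitting rule can be sketched by hand with NumPy. The helper name k_fold_indices is hypothetical and exists only for this illustration; it is not a scikit-learn function.

```python
import numpy as np


def k_fold_indices(n_samples, k):
    """Yield (train_idx, val_idx) pairs for k consecutive, unshuffled folds."""
    if k < 2:
        raise ValueError("k must be at least 2")
    indices = np.arange(n_samples)
    folds = np.array_split(indices, k)  # k nearly equal consecutive chunks
    for i in range(k):
        val_idx = folds[i]  # fold i is held out for validation
        train_idx = np.concatenate(folds[:i] + folds[i + 1:])  # other K - 1 folds
        yield train_idx, val_idx


splits = list(k_fold_indices(6, 3))
for train_idx, val_idx in splits:
    print(train_idx, val_idx)
```

With 6 samples and K = 3, the validation folds are the index pairs 0-1, 2-3, and 4-5, and each sample is used for validation exactly once.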

Let us discuss this in detail.

We first divide the dataset into two parts, training and testing. The training dataset is then divided further into K folds. One fold is used for validation, while the rest are used for training.

The hyperparameters and the performance of the model are noted. These steps are repeated until each fold has been used once for validation.

The model's performance is noted for each fold, and the mean and standard deviation across the folds are computed. This is repeated for different hyperparameter values, and the best-performing model is selected.
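The procedure described above could look roughly like the following sketch. The iris dataset, the KNeighborsClassifier model, and the candidate n_neighbors values are all assumptions picked for illustration; any estimator and hyperparameter grid could take their place.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import KFold, cross_val_score, train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)

# Hold out a test set first; cross-validation only sees the training part.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0
)

# shuffle=True with a fixed random_state makes the fold assignment
# random but reproducible.
k_fold = KFold(n_splits=5, shuffle=True, random_state=0)

results = {}
for n_neighbors in (1, 5, 15):  # candidate hyperparameter values (arbitrary)
    model = KNeighborsClassifier(n_neighbors=n_neighbors)
    scores = cross_val_score(model, X_train, y_train, cv=k_fold)  # one score per fold
    results[n_neighbors] = (scores.mean(), scores.std())
    print(n_neighbors, round(scores.mean(), 3), round(scores.std(), 3))

# Pick the value with the highest mean validation accuracy.
best = max(results, key=lambda k: results[k][0])

# Refit on the whole training set and evaluate once on the held-out test set.
final_model = KNeighborsClassifier(n_neighbors=best).fit(X_train, y_train)
test_accuracy = final_model.score(X_test, y_test)
print("best n_neighbors:", best, "test accuracy:", round(test_accuracy, 3))
```

The test set is touched only once, after the hyperparameter has been chosen, so it still gives an unbiased estimate of performance.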

Use the sklearn.model_selection.KFold Class to Implement K-Fold in Python

We can use the sklearn module to implement different machine learning algorithms and techniques in Python. The model_selection.KFold class can implement the K-Fold cross-validation technique in Python.

In the KFold class, we specify the number of folds with the n_splits parameter, which is 5 by default.

We can also provide the shuffle parameter, which determines whether to shuffle the data before splitting. It is False by default.

The random_state parameter controls the randomness of the shuffling, and therefore which samples land in each fold. It only has an effect when shuffle is set to True.

We’ll use an instance of this class with a simple numpy array.

We pass the array to the split() method, which yields a pair of index arrays (training indices and test indices) for each fold.

Example:

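from sklearn.model_selection import KFold
import numpy as np

# Sample data: 12 values in a NumPy array
x = np.array([10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120])
k_fold = KFold(n_splits=3)  # 3 consecutive folds, no shuffling
# split() yields (train indices, test indices) for each fold
for indices_train, indices_test in k_fold.split(x):
    print(indices_train, indices_test)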

Output:

[ 4  5  6  7  8  9 10 11] [0 1 2 3]
[ 0  1  2  3  8  9 10 11] [4 5 6 7]
[0 1 2 3 4 5 6 7] [ 8  9 10 11]

In the above example, we set the number of folds to 3.
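Note that split() returns positions, not values. To obtain the actual samples, we index the data with these positions. The sketch below assumes the data is held in a NumPy array, because a plain Python list cannot be indexed with an index array.

```python
import numpy as np
from sklearn.model_selection import KFold

x = np.array([10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120])
k_fold = KFold(n_splits=3)

test_values = []
for indices_train, indices_test in k_fold.split(x):
    # Index arrays select the corresponding samples from the data.
    x_train, x_test = x[indices_train], x[indices_test]
    test_values.append(x_test.tolist())
    print(x_train, x_test)
```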

With large datasets, K is usually set to 5. As the dataset gets smaller, larger values of K tend to be used so that each training set keeps more of the samples.

Note that the KFold class has been part of sklearn.model_selection since version 0.18. In older releases, it lived in the sklearn.cross_validation module, which was removed in version 0.20.

Other cross-validation techniques are also available in Python.

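For small datasets, we tend to use leave-one-out cross-validation (LOOCV). StratifiedKFold and GroupKFold are variants of K-Fold: the former preserves the class proportions in each fold, and the latter ensures that samples from the same group never appear in both the training and the test set.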
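These alternatives can be tried with the corresponding scikit-learn classes, LeaveOneOut, StratifiedKFold, and GroupKFold. The labels and group assignments below are made up for illustration.

```python
import numpy as np
from sklearn.model_selection import GroupKFold, LeaveOneOut, StratifiedKFold

X = np.arange(8).reshape(-1, 1)
y = np.array([0, 0, 0, 0, 0, 0, 1, 1])       # imbalanced labels: six 0s, two 1s
groups = np.array([0, 0, 1, 1, 2, 2, 3, 3])  # hypothetical group ids

# LOOCV: one split per sample, each with a single test sample.
n_loo_splits = LeaveOneOut().get_n_splits(X)
print(n_loo_splits)  # 8

# StratifiedKFold keeps the 3:1 class ratio in every test fold.
class_counts = [
    np.bincount(y[test_idx]).tolist()
    for _, test_idx in StratifiedKFold(n_splits=2).split(X, y)
]
print(class_counts)  # [[3, 1], [3, 1]]

# GroupKFold never places the same group in both train and test.
shared_groups = [
    set(groups[train_idx]) & set(groups[test_idx])
    for train_idx, test_idx in GroupKFold(n_splits=2).split(X, y, groups)
]
print(shared_groups)  # [set(), set()]
```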

ShuffleSplit is another common cross-validation technique, and it has the StratifiedShuffleSplit and GroupShuffleSplit variants.
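As a brief sketch, ShuffleSplit draws a fresh random train/test partition on every iteration. The numbers of splits and samples, and the seed, are arbitrary choices for this example.

```python
import numpy as np
from sklearn.model_selection import ShuffleSplit

x = np.arange(10)

# Three independent random splits, each holding out 20% of the samples.
shuffle_split = ShuffleSplit(n_splits=3, test_size=0.2, random_state=0)

sizes = []
for train_idx, test_idx in shuffle_split.split(x):
    sizes.append((len(train_idx), len(test_idx)))
    print(train_idx, test_idx)

print(sizes)  # [(8, 2), (8, 2), (8, 2)]
```

Unlike K-Fold, the test sets of different iterations may overlap, and a given sample is not guaranteed to appear in any test set.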
