Verwenden Sie die Kreuzvalidierung für ein KNN-Klassifizierungsmodell in R

Jesse John 21 Juni 2023
  1. Verschiedene Cross-Validation-Ansätze
  2. Wiederholte K-fache Kreuzvalidierung für ein K-Nächster-Nachbar-Klassifizierungsmodell
Verwenden Sie die Kreuzvalidierung für ein KNN-Klassifizierungsmodell in R

Die Kreuzvalidierung ermöglicht es uns, die Leistung eines Modells anhand neuer Daten zu bewerten, obwohl wir nur über den Trainingsdatensatz verfügen. Es ist eine allgemeine Technik, die auf Regressions- und Klassifizierungsmodelle angewendet werden kann.

In diesem Artikel wird erläutert, wie eine k-fach wiederholte Kreuzvalidierung für ein K-Nearest Neighbor (KNN)-Klassifizierungsmodell durchgeführt wird. Dazu verwenden wir das Paket caret.

Das K in KNN bezieht sich auf die Anzahl der Beobachtungsnachbarn. Andererseits ist k k-fach die Anzahl der Teilmengen der Trainingsdaten.

Verschiedene Cross-Validation-Ansätze

Es gibt verschiedene Ansätze zur Kreuzvalidierung.

Die einfachste Version verwendet eine Teilmenge der Trainingsdaten, um das Modell zu validieren, was als Validierungsset-Ansatz bezeichnet wird. Das Modell wird nur einmal angepasst und dann an der Teilmenge getestet.

Die andere besteht darin, so viele Modelle anzupassen, wie es Beobachtungen gibt, und die durchschnittliche Fehlerrate zu nehmen. Jedes Modell wird mit einer ausgelassenen Beobachtung angepasst; Das Modell wird dann anhand dieser einen Beobachtung getestet.

Dies wird als Leave One Out Cross Validation (LOOCV) bezeichnet.

Der hilfreichste Ansatz beinhaltet:

  1. Aufteilen des Trainingsdatensatzes in k Falten (Gruppen),
  2. Anpassen des Modells k Mal,
  3. Weglassen einer Falte, und
  4. Testen des Modells darauf.

Dies wird als k-fache Kreuzvalidierung bezeichnet. Normalerweise führt ein k-Wert von 5 oder 10 zu guten Ergebnissen.

Eine Verbesserung der k-fachen Kreuzvalidierung beinhaltet das mehrmalige Anpassen des k-fachen Kreuzvalidierungsmodells mit unterschiedlichen Aufteilungen der Falten. Dies wird als wiederholte k-fache Kreuzvalidierung bezeichnet, die wir verwenden werden.

Wiederholte K-fache Kreuzvalidierung für ein K-Nächster-Nachbar-Klassifizierungsmodell

Wir werden nun sehen, wie ein K-Nearest Neighbor (KNN)-Klassifikationsmodell unter Verwendung wiederholter k-facher Kreuzvalidierung angepasst wird. Wir werden das Caret-Paket verwenden.

Das Caret-Paket ist sehr vielseitig und kann zum Erstellen verschiedener Arten von Modellen verwendet werden. Weitere Einzelheiten finden Sie in der Dokumentation zu CRAN.

Normalerweise sagt uns die k-fache Kreuzvalidierung nur, wie genau unser Modell bei neuen Daten erwartet wird.

Nehmen wir zum Beispiel an, wir passen ein K = 5 KNN-Modell an, indem wir eine wiederholte k-fache Kreuzvalidierung mit 10 k-fachen, dreimal wiederholt, verwenden. Das Modell wird 10 Mal für jeden der 3 verschiedenen Datensplits angepasst, und wir erhalten Leistungsmetriken nur für ein Modell: das mit K = 5 Nachbarn.

Darüber hinaus ermöglicht uns das Caret-Paket, KNN-Modelle für unterschiedliche Werte von K anzupassen. Die Funktion meldet dann den Wert von K, der zum besten Modell führt, und erstellt dieses Modell für uns.

Die Funktion createDataPartition() erzeugt eine geschichtete zufällige Aufteilung eines Faktorvektors. Wir werden dies verwenden, um unsere Daten in Trainings- und Testteilmengen zu unterteilen, um die Genauigkeit des Modells zu überprüfen.

Die Funktion train() ist die Hauptfunktion zum Erstellen eines Modells, wobei:

  1. x ist der Datenrahmen mit den Prädiktoren.
  2. y ist der Ergebnisdatenrahmen oder -vektor.
  3. Das Argument Methode nimmt den Modelltyp an, den wir erstellen möchten. Wir geben knn an.
  4. Für preprocess geben wir scale und center an.
  5. Mit dem Argument trControl können wir die Besonderheiten des Kreuzvalidierungsverfahrens spezifizieren.
  6. Das Argument tuneGrid hilft beim Erstellen und Vergleichen mehrerer Modelle. Es benötigt einen Datenrahmen mit dem Namen des abzustimmenden Parameters.

Da wir ein KNN-Modell erstellen, geben wir k in Kleinbuchstaben als Tuning-Parameter für tuneGrid an. Wir stellen einen Vektor von K-Werten von 1 bis 12 bereit, für den die Funktion Modelle erstellen und testen soll.

Die Besonderheiten der Kreuzvalidierung werden mit der Funktion trainControl() an das Argument trControl übergeben.

  1. Für das Argument method geben wir repeatedcv an, weil wir eine wiederholte Kreuzvalidierung wünschen.
  2. Wenn die Methode cv oder repeatedcv ist, gibt das Argument number die Faltungen k an. Wir verwenden 10.
  3. Das Argument repeats gibt an, wie oft der k-fache Split wiederholt werden muss.

Beispielcode:

# Create the data vectors for the demonstration.
# We will create two numeric vectors as predictors.
# Each vector will have two distinct groups to suit our model.
# We will create a factor with two levels.
# The factor levels correspond to the groups in the predictor vectors.

set.seed(564)
vX1a = round(rnorm(100, 2,2))+4
set.seed(574)
vX2a = round(rnorm(100, 15,4))

set.seed(584)
vX1b = round(rnorm(100, 10,3))+5
set.seed(594)
vX2b = round(rnorm(100, 5,4))

vYa = rep("Blue", 100)
vYb = rep("Red", 100)

vX1 = c(vX1a, vX1b)
vX2 = c(vX2a, vX2b)
vY = c(vYa, vYb)

# Dummy column for ordering rows.
set.seed(528)
R = sample(1:200,200)

# Temporary data frame.
temp_df = data.frame(X1 = vX1, X2 = vX2, Y = as.factor(vY), R)

# Packages that we will use.
library(ggplot2)
library(dplyr)

# See the sample data.
temp_df %>% ggplot(aes(x=X1, y = X2, colour = Y)) + geom_point()

# Re-order the rows, just to see that the KNN model works with the rows jumbled up.
# Final data frame.
# Notice that the outputs are a factor vector.
fin_df = temp_df %>% arrange(R) %>% select(X1, X2, Y)
head(fin_df)
str(fin_df)

# Install the caret package if it is not already installed.
# To install, uncomment the next line and run it.
# install.packages("caret")

# Load the caret package.
library(caret)

# Split the data frame into a training set and test set.
# Create a list of row numbers in the training set.
# This function creates a stratified random sample of all the outcome classes.
set.seed(365)
training_row_index = createDataPartition(fin_df[,3], p=0.75, list=FALSE)

# Create training sets of the predictors and the corresponding outcomes.
trg_data = fin_df[training_row_index,1:2]
trg_class = fin_df[training_row_index,3]

# Create the test set of predictors and the outcomes that we will later use.
tst_data = fin_df[-training_row_index,1:2]
tst_class = fin_df[-training_row_index,3]

# Let us check if the sample is stratified:
table(tst_class)
# Obviously, the training sample will complement these numbers of the totals.

# We will build a K-Nearest neighbors model using repeated k-fold cross-validation.
# The arguments are described in the article.
mod_knn = train(x = trg_data,
                y = trg_class,
                method = "knn",
                preProcess = c("center", "scale"),
                tuneGrid = data.frame(k = c(1:12)),
                trControl = trainControl(method = "repeatedcv",
                                         number = 10,
                                         repeats = 3)
                )

# View the fitted model.
mod_knn

Ausgang:

> head(fin_df)
  X1 X2    Y
1 15  6  Red
2 15  2  Red
3 14  3  Red
4  4 22 Blue
5 20 -3  Red
6  4 22 Blue

> str(fin_df)
'data.frame':	200 obs. of  3 variables:
 $ X1: num  15 15 14 4 20 4 15 2 20 13 ...
 $ X2: num  6 2 3 22 -3 22 7 16 9 -6 ...
 $ Y : Factor w/ 2 levels "Blue","Red": 2 2 2 1 2 1 2 1 2 2 ...

> # Let us check if the sample is really stratified:
> table(tst_class)
tst_class
Blue  Red
  25   25

> # View the fitted model.
> mod_knn
k-Nearest Neighbors

150 samples
  2 predictor
  2 classes: 'Blue', 'Red'

Pre-processing: centered (2), scaled (2)
Resampling: Cross-Validated (10 fold, repeated 3 times)
Summary of sample sizes: 135, 136, 135, 135, 135, 134, ...
Resampling results across tuning parameters:

  k   Accuracy   Kappa
   1  0.9710317  0.9420024
   2  0.9736905  0.9473602
   3  0.9753373  0.9505141
   4  0.9842460  0.9683719
   5  0.9864683  0.9728764
   6  0.9843849  0.9687098
   7  0.9843849  0.9687098
   8  0.9800794  0.9600386
   9  0.9800794  0.9600386
  10  0.9800794  0.9600386
  11  0.9800794  0.9600386
  12  0.9800794  0.9600386

Accuracy was used to select the optimal model using the largest value.
The final value used for the model was k = 5.

Wir finden, dass das beste Modell K = 5 verwendet.

Lassen Sie uns nun die Basis-R-Funktion predict() und eine Konfusionsmatrix verwenden, die mit der Funktion table() erstellt wurde, um die Genauigkeit des Modells anhand der Testdaten zu überprüfen.

Beispielcode:

# Use model to predict classes for the test set.
pred_cls = predict(mod_knn, tst_data)

# Check the accuracy of the predictions by computing the confusion matrix.
table(Actl = tst_class, Pred = pred_cls)

Ausgang:

> table(Actl = tst_class, Pred = pred_cls)
      Pred
Actl   Blue Red
  Blue   25   0
  Red     0  25

Wir stellen fest, dass das Modell die Testdatenklasse mit voller Genauigkeit vorhergesagt hat. Dies war möglich, weil die Daten im Beispieldatenrahmen gut getrennt waren.

In der Praxis wird die Genauigkeit geringer sein. Für jedes Modell gibt uns das wiederholte k-fache Kreuzvalidierungsverfahren jedoch eine gute Vorstellung von der Genauigkeit, die wir bei neuen Daten erwarten können, die den Trainingsdaten ähneln.

Autor: Jesse John
Jesse John avatar Jesse John avatar

Jesse is passionate about data analysis and visualization. He uses the R statistical programming language for all aspects of his work.