Implementieren Sie die Funktion train() in R

Sheeraz Gul 21 Juni 2023
Implementieren Sie die Funktion train() in R

Für das Klassifikations- und Regressionstraining wird die Methode train() (aus der Caret-Bibliothek) verwendet. Es wird auch verwendet, um die Modelle abzustimmen, indem die Komplexitätsparameter ausgewählt werden.

Dieses Tutorial zeigt, wie Sie die Methode train() aus dem Caret-Paket in R verwenden.

Implementieren Sie die Funktion train() in R

Um die Methode train() zu verwenden, müssen wir zuerst das Caret- und andere erforderliche Pakete installieren. In unserem Fall verwenden wir die Pakete caret und mlbench.

install.packages("caret", dependencies = c("Depends", "Suggests"))
install.packages('mlbench')

Nachdem die Pakete installiert sind, können wir sie laden und verwenden. Wir verwenden den eingebauten Datensatz iris für die Methode train():

# Load Packages
library(caret)
library(mlbench)

# Load IRIS dataset
data(iris)

# Show the first Six lines of the iris data set
head(iris)

Der obige Code lädt die Pakete und den Datensatz und zeigt die ersten sechs Zeilen des Iris-Datensatzes:

  Sepal.Length Sepal.Width Petal.Length Petal.Width Species
1          5.1         3.5          1.4         0.2  setosa
2          4.9         3.0          1.4         0.2  setosa
3          4.7         3.2          1.3         0.2  setosa
4          4.6         3.1          1.5         0.2  setosa
5          5.0         3.6          1.4         0.2  setosa
6          5.4         3.9          1.7         0.4  setosa

Als nächstes muss aus dem Datensatz ein Objekt erstellt werden, das später in der Methode train() verwendet wird:

# create binary object
iris$binary  <- ifelse(iris$Species=="setosa",1,0)
iris$Species <- NULL

Eine trainControl()-Methode wird dann verwendet, um die Resampling-Methode zu modifizieren. Diese Methode benötigt mehrere Parameter; die Ausgabe wird in der Methode train() verwendet:

ctrl <- trainControl(method = "repeatedcv",
                        number = 4,
                        savePredictions = TRUE,
                        verboseIter = T,
                        returnResamp = "all")

Schließlich können wir die Methode train() verwenden, um das Datenmodell zu optimieren. Es wird das oben erstellte binäre Objekt verwenden.

Als Parameter haben wir die data, method, family und trControl:

# the train method
iris_train <- train(binary ~.,
                 data=iris,
                 method = "glm",
                 family="binomial",
                 trControl = ctrl)

Ausgang:

+ Fold1.Rep1: parameter=none
- Fold1.Rep1: parameter=none
+ Fold2.Rep1: parameter=none
- Fold2.Rep1: parameter=none
+ Fold3.Rep1: parameter=none
- Fold3.Rep1: parameter=none
+ Fold4.Rep1: parameter=none
- Fold4.Rep1: parameter=none
Aggregating results
Fitting final model on full training set

Das Modell ist jetzt trainiert. Wir können die Leistung auf Faltebene sehen:

# fold level performance
iris_train$resample

Resampling-Ausgang:

          RMSE  Rsquared          MAE parameter   Resample
1 3.929660e-06 1.0000000 6.420415e-07      none Fold1.Rep1
2 2.735382e-13 1.0000000 6.115382e-14      none Fold2.Rep1
3 8.221919e-12 1.0000000 1.397945e-12      none Fold3.Rep1
4 9.119130e-04 0.9999968 1.479318e-04      none Fold4.Rep1

Sobald das Modell trainiert ist, können wir mithilfe der anderen Optionen im Datensatz Zwischenmodelle erstellen.

Vollständiger Beispielcode

Hier ist der vollständige Beispielcode für Ihre Bequemlichkeit:

install.packages("caret", dependencies = c("Depends", "Suggests"))
install.packages('mlbench')

# Load Packages
library(caret)
library(mlbench)

# Load IRIS dataset
data(iris)

# Show the first Six lines of the iris data set
head(iris)

# create binary object
iris$binary  <- ifelse(iris$Species=="setosa",1,0)
iris$Species <- NULL

# use trainControl() method to modify the resampling method
ctrl <- trainControl(method = "repeatedcv",
                        number = 4,
                        savePredictions = TRUE,
                        verboseIter = T,
                        returnResamp = "all")
# the train method
iris_train <- train(binary ~.,
                 data=iris,
                 method = "glm",
                 family="binomial",
                 trControl = ctrl)

# fold level performance
iris_train$resample
Sheeraz Gul avatar Sheeraz Gul avatar

Sheeraz is a Doctorate fellow in Computer Science at Northwestern Polytechnical University, Xian, China. He has 7 years of Software Development experience in AI, Web, Database, and Desktop technologies. He writes tutorials in Java, PHP, Python, GoLang, R, etc., to help beginners learn the field of Computer Science.

LinkedIn Facebook

Verwandter Artikel - R Function