R에서 train() 함수 구현

Sheeraz Gul 2023년6월21일
R에서 train() 함수 구현

train() 방법(캐럿 라이브러리에서)은 분류 및 회귀 훈련에 사용됩니다. 또한 복잡성 매개 변수를 선택하여 모델을 조정하는 데 사용됩니다.

이 튜토리얼은 R의 캐럿 패키지에서 train() 메서드를 사용하는 방법을 보여줍니다.

R에서 train() 함수 구현

train() 메서드를 사용하려면 먼저 캐럿 및 기타 필수 패키지를 설치해야 합니다. 이 경우에는 캐럿 및 mlbench 패키지를 사용합니다.

install.packages("caret", dependencies = c("Depends", "Suggests"))
install.packages('mlbench')

패키지가 설치된 후 로드하여 사용할 수 있습니다. 우리는 train() 메서드에 내장 데이터 세트 iris를 사용합니다.

# Load Packages
library(caret)
library(mlbench)

# Load IRIS dataset
data(iris)

# Show the first Six lines of the iris data set
head(iris)

위의 코드는 패키지와 데이터 세트를 로드하고 iris 데이터 세트의 처음 6개 라인을 표시합니다.

  Sepal.Length Sepal.Width Petal.Length Petal.Width Species
1          5.1         3.5          1.4         0.2  setosa
2          4.9         3.0          1.4         0.2  setosa
3          4.7         3.2          1.3         0.2  setosa
4          4.6         3.1          1.5         0.2  setosa
5          5.0         3.6          1.4         0.2  setosa
6          5.4         3.9          1.7         0.4  setosa

다음으로 해야 할 일은 데이터 세트에서 개체를 만드는 것입니다. 이 개체는 나중에 train() 메서드에서 사용됩니다.

# create binary object
iris$binary  <- ifelse(iris$Species=="setosa",1,0)
iris$Species <- NULL

그런 다음 trainControl() 메서드를 사용하여 리샘플링 방법을 수정합니다. 이 메서드는 여러 매개변수를 사용합니다. 출력은 train() 메서드에서 사용됩니다.

ctrl <- trainControl(method = "repeatedcv",
                        number = 4,
                        savePredictions = TRUE,
                        verboseIter = T,
                        returnResamp = "all")

마지막으로 train() 메서드를 사용하여 데이터 모델을 조정할 수 있습니다. 위에서 만든 이진 개체를 사용합니다.

매개변수로 data, method, familytrControl이 있습니다.

# the train method
iris_train <- train(binary ~.,
                 data=iris,
                 method = "glm",
                 family="binomial",
                 trControl = ctrl)

출력:

+ Fold1.Rep1: parameter=none
- Fold1.Rep1: parameter=none
+ Fold2.Rep1: parameter=none
- Fold2.Rep1: parameter=none
+ Fold3.Rep1: parameter=none
- Fold3.Rep1: parameter=none
+ Fold4.Rep1: parameter=none
- Fold4.Rep1: parameter=none
Aggregating results
Fitting final model on full training set

이제 모델이 훈련되었습니다. 폴드 레벨 성능을 볼 수 있습니다.

# fold level performance
iris_train$resample

리샘플링 출력:

          RMSE  Rsquared          MAE parameter   Resample
1 3.929660e-06 1.0000000 6.420415e-07      none Fold1.Rep1
2 2.735382e-13 1.0000000 6.115382e-14      none Fold2.Rep1
3 8.221919e-12 1.0000000 1.397945e-12      none Fold3.Rep1
4 9.119130e-04 0.9999968 1.479318e-04      none Fold4.Rep1

모델이 훈련되면 데이터 세트의 다른 옵션을 사용하여 중간 모델을 만들 수 있습니다.

완전한 예제 코드

편의를 위한 전체 예제 코드는 다음과 같습니다.

install.packages("caret", dependencies = c("Depends", "Suggests"))
install.packages('mlbench')

# Load Packages
library(caret)
library(mlbench)

# Load IRIS dataset
data(iris)

# Show the first Six lines of the iris data set
head(iris)

# create binary object
iris$binary  <- ifelse(iris$Species=="setosa",1,0)
iris$Species <- NULL

# use trainControl() method to modify the resampling method
ctrl <- trainControl(method = "repeatedcv",
                        number = 4,
                        savePredictions = TRUE,
                        verboseIter = T,
                        returnResamp = "all")
# the train method
iris_train <- train(binary ~.,
                 data=iris,
                 method = "glm",
                 family="binomial",
                 trControl = ctrl)

# fold level performance
iris_train$resample
작가: Sheeraz Gul
Sheeraz Gul avatar Sheeraz Gul avatar

Sheeraz is a Doctorate fellow in Computer Science at Northwestern Polytechnical University, Xian, China. He has 7 years of Software Development experience in AI, Web, Database, and Desktop technologies. He writes tutorials in Java, PHP, Python, GoLang, R, etc., to help beginners learn the field of Computer Science.

LinkedIn Facebook

관련 문장 - R Function