콜백 함수, callbacks는 딥러닝의 최적화 과정에서 이용되는 함수입니다. 정확하게는 Tensorflow의 keras 라이브러리 내의 callbacks 함수입니다.
일정 이상의 에포크, epoch 가 진행되었음에도 학습이 원하는 대로 이루어지지 않는 경우, 혹은 학습은 되었지만, 과적합으로 인해 검증 데이터의 loss 값이 증가하는 경우 프로그램을 멈출 수 있는 함수입니다.
이번 포스트에서는 콜백 함수의 종류와 코드, 그리고 사용법에 대해서 알아보도록 하겠습니다.
콜백 함수의 종류와 코드를 보기 전에 미리 알아두실 점은 저는 python 3.9를 언어로 사용하며, jupyter notebook을 이용해 코드를 작성했습니다.
📌 콜백함수, callbacks
콜백 함수의 종류는 크게 두 가지로 나누어볼 수 있습니다.
일정 에포크, Epoch가 진행되어도 원하는 결과를 얻지 못할 때, 멈추는 함수 'EarlyStopping', 멈추는 것이 아니라 학습률을 조정하는 'ReduceLROnPlateau'가 있습니다.
'원하는 결과를 얻지 못한다' 는 것은 학습의 정확도가 향상되지 않는 경우, 학습 정확도는 향상되나 검증 데이터의 loss는 증가하는 과적합에 빠지는 경우 등 다양합니다.
콜백 함수는 이러한 기준점에 대해서도 사용자가 선택할 수 있는 매우 유용한 함수입니다.
📌 EarlyStopping
에포크의 진행을 멈추는 EalryStopping 함수의 코드는 다음과 같습니다.
from tensorflow.keras.callbacks import EarlyStopping
early_stopping = EarlyStopping(monitor='acc', patience=2)
저는 early_stopping 이라는 변수에 제가 원하는 EarlyStopping 조건을 저장했습니다.
monitor = 'acc' 는 제가 딥러닝 모델에서 설정한 정확도를 뜻합니다.
patience=2로 설정했으므로 만약 3회 연속 정확도가 향상되지 않는다면 즉시 에포크가 멈출 것입니다.
📌 ReduceLROnPlateau
학습률을 조정해주는 ReduceLROnPlateau 함수의 코드는 다음과 같습니다.
from tensorflow.keras.callbacks import ReduceLROnPlateau
reduce_learning_rate = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=20)
reduce_learning_rate라는 변수에 제가 원하는 ReduceLROnPlateau 조건을 저장하였습니다.
monitor='val_loss'는 딥러닝 모델에서 정한 검증 데이터의 손실 값을 의미합니다.
factor=0.5는 학습률을 절반 수준으로 줄이라고 값을 지정해준 것입니다.
patience=20으로 설정했으므로 만약 21회 연속 검증 데이터의 손실 값이 감소하지 않는다면 학습률을 절반으로 줄일 것입니다.
📌 콜백 함수 사용법
콜백 함수의 종류와 조건 저장방법에 대해서 알아보았으니, 이제 사용하는 방법에 대해서 알아볼 차례입니다.
콜백 함수는 딥러닝 모델을 컴파일한 이후 학습할 때, 사용합니다. 즉, model.compile()을 마친 이후 model.fit()을 할 때, 콜백 함수를 넣어주면 됩니다.
콜백 함수 사용 예제를 코드로 보시겠습니다.
model.fit( partial_x_train, partial_y_train epochs=300,
validation_data=( x_val, y_val ),
callbacks=[early_stopping])
저는 전체 train 데이터를 학습을 위한 partial 데이터와 검증을 위한 val 데이터로 나누었습니다. 따라서 순서대로 학습 데이터, 검증 데이터를 입력하고 마지막에 callbacks=[early_stopping] 을 입력함으로써 제가 미리 지정해둔 조건대로 EarlyStopping이 적용되도록 하였습니다.
만약 ReduceLORnPlateau를 사용하려면 마지막에 callbacks=[reduce_learning_rate]를 입력해주어야 합니다.
해당 변수들은 제가 앞의 코드에서 설정한대로 입력한 것으로 본인이 무엇으로 설정하느냐에 따라서 []안에 들어갈 변수가 달라집니다.
와인 품질 분류 딥러닝 1. 데이터 셋 다운로드 및 전처리
머신러닝과 딥러닝을 이제 막 배우기 시작한 초심자들이 반드시 거쳐가는 관문 중 하나인 와인 품질 분류 딥러닝 모델 만들기에 도전해보기로 했습니다. 언어는 Python 3.9, 사용하는 라이브러리
gamdonge.tistory.com
'Tensorflow 기초' 카테고리의 다른 글
AIFB Associate 자격증 취득 후기 (0) | 2022.08.02 |
---|