[sklearn] 교차 검증
23 Mar 2020 | Scikit--Learn머신러닝은 데이터에 기반한다. 그리고 데이터는 이상치, 분포도, 다양한 속성값, 피처 중요도 등 여러 가지 머신러닝에 영향을 미치는 요소를 가지고 있다. 특정 머신러닝 알고리즘에서 최적으로 동작할 수 있도록 데이터를 선별해 학습한다면 실제 데이터 양식과는 많은 차이가 있을 것이고 결국 성능 저하로 이어질 수 있다.
교차 검증은 이런 데이터 편중을 막기 위해 별도의 여러 세트로 구성된 학습 데이터 세트와 검증 데이터 세트에서 학습과 평가를 수행하는 것이다. 데이터를 학습 데이터와 테스트 데이터 세트로 구분하고 학습 데이터는 다시 학습데이터 세트와 검증 데이터 세트로 나눈다. 학습 데이터를 학습한 모델의 성능을 1차적으로 평가하는 검증 데이터가 있고 모든 학습/검증 과정이 완료된 후 최종적으로 성능을 평가하기 위해 테스트 데이터 세트가 존재한다.
K-Fold Cross Validation
가장 보편적인 교차 검증 기법. K개의 데이터 폴드 세트를 만들어서 K번만큼 각 폴드 세트에 학습과 검증 평가를 반복적으로 수행하는 방법이다. K번의 학습과 검증 평가를 반복 수행하여 나온 K개의 예측 평가들의 평균을 구해서 K폴드 평가 결과로 반영하면 된다.
사이킷런에서는 K 폴드 교차 검증 프로세스를 구현하기 위해 KFold와 StratifiedKFold 클래스를 제공한다.
K Fold
Stratified K Fold
Stratified K 폴드는 불균형한(imbalanced) 분포도를 가진 레이블 데이터 집합을 위한 K 폴드 방식임. KFold로 분할된 레이블 데이터 세트가 전체 레이블 값의 분포도를 반영하지 못하는 문제를 해결해준다. 이를 위해 Stratified K 폴드는 원본 데이터의 레이블 분포를 먼저 고려한 뒤 이 분포와 동일하게 학습과 검증 데이터 세트를 분배한다.
StratifiedKFold 사용을 위해서는 레이블 데이터 분포도에 따라 학습/검증 데이터를 나누기 때문에 split( ) 메서드에 인자로 피처 데이터 세트뿐만 아니라 레이블 데이터 세트도 반드시 넣어줘야한다.
iris데이터를 3개의 폴드로 나누면 아래와 같은 결과를 보인다.
StratifiedKFold로 붓꽃 데이터 예측을 하면 아래와 같다.
분류에는 웬만하면 StratifiedKFold 쓰는게 좋다고 한다. 회귀에서는 Stratified K 폴드가 지원되지 않는대ㅔ 그 이유는 회귀의 값은 이산값 형태의 레이블이 아닌 연속형 숫자이기 때문에 결정값별로 분포를 정하는게 의미가 없어서 그렇다.
cross_val_score( )
사이킷런에서는 교차 검증을 보다 편리하게 할 수 있는 API를 제공한다. 대표적인 것이 cross_val_score( )이다. 기존 KFold는 순서는 다음과 같다.
- 폴드 세트를 설정하고
- for 루프에서 반복적으로 학습 및 테스트 데이터의 인덱스를 추출한 뒤
- 반복적으로 학습과 예측을 수행하고 예측 성능을 반환한다.
cross_val_score( )는 이 과정을 한번에 수행한다.
선언 형태는 위와 같이 생겼는데 이때 estimator, X, y, scoring, cv가 주요 파라미터이다.
- estimator : classifier 또는 regressor
- X, y : feature dataset, label dataset
- scoring : 예측 성능 평가 지표
- cv : 교차 검증 폴드 수
cross_val_scroe( ) 의 리턴값은 scoring 파라미터로 지정된 성능 지표 측정값을 배열 형태로 반환한다 (이 때, 배열 길이는 cv와 동일). classifierRㅏ 입력되면 Stratified K fold방식으로 레이블값의 분포에 따라 학습/테스트 세트를 분할하며 회귀인 경우는 Stratified가 불가능하므로 그냥 K Fold로 분할한다.
cross_val_score( )은 내부에서 학습(fit), 예측(fit), 평가(evaluation) 해주므로 간단하게 교차 검증을 수행할 수 있다. 비슷한 API로 cross_validate( )가 있다. 얘는 여러개의 평가 지표를 반환할 수 있다. 또한 학습 데이터에 대한 성능 평가 지표와 수행 시간도 같이 제공된다.
Comments