데이터/Machine Learning
sklearn classification_report를 이용한 모델 검증
성장하기
2023. 5. 21. 11:50
Classification Report
sklearn에서는 분류 모델의 검증을 위한 classificrion_report api를 제공한다.
text로 된 classification metrics에 대한 기본적인 리포트를 만들어 주며 다음과 같이 사용한다.
sklearn.metrics.classification_report(y_true, y_pred, *, labels=None, target_names=None, sample_weight=None, digits=2, output_dict=False, zero_division='warn')
파라미터
- y_true : 1차원 배열, 레이블 배열, 희소행렬
- 레이블의 실제 값을 입력
- y_pred : 1차원 배열, 레이블 배열, 희소행렬
- classifier가 예측한 값을 입력
- labels : 배열형태 (default=None)
- classification report에 포함시킬 label index를 입력
- target_names : 리스트 or 문자열 형태
- report 내에서 보여줄 타겟의 이름을 입력
- sample_weight : 배열형태 (default=None)
- 다중 분류에서만 의미가 있으며, 클래스 별로 weight값을 다르게 주어 average score를 계산함
- output_dict : boolean
- True의 경우 dictionary 형태로 반환
사용예제
from sklearn.metrics import classification_report
##
y_true = [0, 1, 2, 2, 2]
y_pred = [0, 0, 2, 2, 1]
target_names = ['class 0', 'class 1', 'class 2']
print(classification_report(y_true, y_pred, target_names=target_names))
Confusion Matrix
- 지표들을 이해하기 위해서는 우선 Confusion Matrix를 이해해야한다.
- True Positive(TP) : 실제 True인 정답을 True라고 예측 (정답)
- False Positive(FP) : 실제 False인 정답을 True라고 예측 (오답)
- False Negative(FN) : 실제 True인 정답을 False라고 예측 (오답)
- True Negative(TN) : 실제 False인 정답을 False라고 예측 (정답)
precision
- 정밀도라고도 표현하며, 모델이 True라고 분류한 것 중에서 실제 True의 비율이다.
$$
Precision = \frac{TP}{TP+FP}
$$
recall
- 재현율, Sensitivity라도고 표현하며, 실제 True 중, 모델이 True로 제대로 분류한 비율이다
$$
Recall = \frac{TP}{TP+FN}
$$
- 실제 → Recall, Sensitivity
- 예측 → Precision 으로 외우자! (실Sen, 예Pre)
f1-score
- Recall과 Precision의 조화 평균으로 불균형한 레이블 분포의 데이터를 검증하기 위해 많이 사용한다
불균형한 데이터에서는 성능 검증이 어렵다.
가령 100개의 데이터 중 10개의 데이터만 레이블이 0이고, 90개가 1이라면, 모델이 전부 1이라고 예측해도 정확도는 90%가 되고, Precision 또한 100%가 된다. 반면 Recall은 0%가 된다.
Precision과 Recall은 한가지가 증가하면 다른 한가지가 감소하는 Trade-off현상을 가지고 있으며, 적절한 균형을 가지는 Threshold를 찾거나, 모델을 구현하기 위해 f1-score나 roc 커브 등을 이용한다.
자세한 내용을 알고싶은 분들은 https://sumniya.tistory.com/26 블로그를 참고하시길 바랍니다.
$$
F1 Score = 2 \times \frac{1}{Precision^-1 + Recall^-1} = \frac{2 \times Precision \times Recall}{Precision + Recall}
$$
support
- 각 레이블의 실제 개수를 의미한다.
- 위 예제에서 레이블 0은 1개, 1은 2개, 2는 3개이다.
macro avg
- 가중치를 주지 않은 단순 평균을 의미한다. 이는 레이블 불균형을 고려하지 않아, 불균형한 레이블을 가진 데이터에 사용하기에는 적절하지 않다.
weighted avg
- 각 레이블에 대한 실제 객체 수에 의해 가중치를 부여하고, 가중치가 부여된 평균을 의미한다. 레이블 불균형을 고려한 평균이다.
참고
- https://scikit-learn.org/stable/modules/model_evaluation.html#classification-report
- https://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html#sklearn.metrics.classification_report
- https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_fscore_support.html#sklearn.metrics.precision_recall_fscore_support
- https://sumniya.tistory.com/26