-
PR curve, threshold — 카테고리 없음 · 2022. 7. 23. 17:11
정밀도(Precision)와 재현율(Recall)
1. 정밀도와 재현율의 트레이드 오프
hmkim312.github.io
import os

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import precision_recall_curve


def _plot_sorted_scores(scores, out_path):
    """Save a line plot of *scores* sorted ascending (x: sample rank, y: score)."""
    fig = plt.figure(figsize=(8, 6))
    plt.ylim(0, 1.01)
    plt.xlabel('num of images')
    plt.ylabel('prediction score')
    plt.plot(np.sort(scores))
    fig.savefig(out_path, dpi=600)
    plt.close(fig)


def _plot_pr_vs_threshold(target, scores, out_path):
    """Save, then show, precision and recall plotted against the decision threshold."""
    precision, recall, thresholds = precision_recall_curve(target, scores)

    fig = plt.figure()
    # precision/recall have one more entry than thresholds (the final
    # precision=1, recall=0 point has no threshold), so drop the last one.
    plt.plot(thresholds, precision[:-1], label='precision')
    plt.plot(thresholds, recall[:-1], label='recall')
    plt.xlim(0, 1.01)
    plt.ylim(0, 1.01)
    plt.xlabel('threshold')
    plt.ylabel('P/R score')
    plt.grid()
    plt.legend()
    # Save BEFORE show(): on interactive backends show() may close the figure,
    # and saving afterwards would write a blank image.
    fig.savefig(out_path, dpi=600)
    plt.show()
    plt.close(fig)


def PR_curve(target, pred, iter, cfg):
    """Save diagnostic plots of binary-classifier scores for one evaluation step.

    Writes two 600-dpi PNGs into
    ``./PR/<last path component of cfg.TRAIN.exp_name>``:

    * ``element_{iter}.png`` -- the positive-class scores sorted ascending;
    * ``PR_cruve_{iter}.png`` -- precision and recall versus threshold
      (this figure is also passed to ``plt.show()``). The misspelled file
      name is kept so existing consumers still find the file.

    Args:
        target: ground-truth binary labels, passed to
            ``sklearn.metrics.precision_recall_curve`` as ``y_true``.
        pred: array indexed as ``pred[:, 1]``. Presumably (N, 2) class
            probabilities with column 1 the positive class -- confirm with caller.
        iter: iteration tag used only in the output file names.
        cfg: config object; only ``cfg.TRAIN.exp_name`` is read.
    """
    out_dir = os.path.join("./PR", cfg.TRAIN.exp_name.split('/')[-1])
    # makedirs also creates "./PR" on first use and tolerates an existing dir.
    os.makedirs(out_dir, exist_ok=True)

    # Close whatever figure is current so it is not drawn into or re-shown.
    plt.close()

    scores = pred[:, 1]
    _plot_sorted_scores(scores, f"{out_dir}/element_{iter}.png")
    _plot_pr_vs_threshold(target, scores, f"{out_dir}/PR_cruve_{iter}.png")