ABOUT ME

notou10@yonsei.ac.kr

Today
Yesterday
Total
  • PR curve, threshold
    카테고리 없음 2022. 7. 23. 17:11

    https://hmkim312.github.io/posts/%EC%A0%95%EB%B0%80%EB%8F%84(Precision)%EC%99%80_%EC%9E%AC%ED%98%84%EC%9C%A8(recall)/ 

     

    정밀도(Precision)와 재현율(Recall)

    1. 정밀도와 재현율의 트레이드 오프

    hmkim312.github.io

    import matplotlib.pyplot as plt
    from sklearn.metrics import precision_recall_curve
    import numpy as np
    import os
    def PR_curve(target, pred, iter, cfg):
        """Save a sorted-score plot and a precision/recall-vs-threshold plot.

        Two PNG files (dpi=600) are written to
        ``./PR/<last '/'-separated component of cfg.TRAIN.exp_name>``:

        * ``element_{iter}.png`` -- the scores ``pred[:, 1]`` sorted ascending,
          plotted against their rank (y-axis limited to [0, 1.01]).
        * ``PR_cruve_{iter}.png`` -- precision and recall plotted against the
          decision threshold returned by ``precision_recall_curve``.

        Args:
            target: Ground-truth labels passed to
                ``sklearn.metrics.precision_recall_curve`` (presumably binary
                0/1 -- confirm against the caller).
            pred: Array indexed as ``pred[:, 1]``; column 1 is used as the
                positive-class score (the axis limits suggest probabilities in
                [0, 1] -- TODO confirm).
            iter: Tag appended to the output file names (e.g. a training
                iteration).
            cfg: Config object; only ``cfg.TRAIN.exp_name`` is read.
        """
        out_dir = os.path.join("./PR", cfg.TRAIN.exp_name.split('/')[-1])
        # makedirs also creates the "./PR" parent and tolerates an existing
        # directory; a plain os.mkdir fails when "./PR" does not exist yet.
        os.makedirs(out_dir, exist_ok=True)

        scores = pred[:, 1]

        # Discard any figure left open by the caller before drawing.
        plt.close()

        # --- Plot 1: sorted score distribution over all images ---
        plt.figure(figsize=(8, 6))
        plt.ylim(0, 1.01)
        plt.xlabel('num of images')
        plt.ylabel('prediction score')
        plt.plot(np.sort(scores))
        plt.savefig(f"{out_dir}/element_{iter}.png", dpi=600)
        plt.close()

        # --- Plot 2: precision / recall as a function of the threshold ---
        precision, recalls, thresholds = precision_recall_curve(target, scores)

        plt.figure()
        plt.ylim(0, 1.01)
        plt.xlim(0, 1.01)
        # precision and recall have one more entry than thresholds (the final
        # point has no threshold), so trim them to align with the x-axis.
        plt.plot(thresholds, precision[:len(thresholds)], label='precision')
        plt.plot(thresholds, recalls[:len(thresholds)], label='recall')
        plt.xlabel('threshold')
        plt.ylabel('P/R score')
        plt.grid()
        plt.legend()

        # Save BEFORE show(): with an interactive backend show() blocks and the
        # figure may be destroyed once its window closes, which previously
        # produced an empty image on disk.
        # NOTE: "cruve" is a typo kept as-is so existing consumers of the
        # output file name keep working.
        plt.savefig(f"{out_dir}/PR_cruve_{iter}.png", dpi=600)
        plt.show()
        plt.close()
Designed by Tistory.