主要思路是：先计算每个类别的 AP，再对所有类别的 AP 取平均，得到 mAP。
AP 是 P-R（精确率-召回率）曲线下的面积；下面的代码用 11 点插值法（在召回率 0, 0.1, …, 1.0 处取不低于该召回率时的最大精确率，再求平均）来近似这一面积。

def getAplist(pred, label, aplistSavePath, accuracy_th=0.5):
    """Compute the AP of every class, print the mAP, and save the AP list.

    Args:
        pred: prediction scores (e.g. sigmoid outputs), shape [batch, C].
        label: binary ground truth, shape [batch, C].
        aplistSavePath: destination passed to ``np.save``.
        accuracy_th: score threshold forwarded to ``certainClassAP``
            (0.5 is the value this function always used).

    Returns:
        numpy.ndarray: the per-class AP values, shape [C].
    """
    # Derive the sizes from the data instead of hard-coding 5000. The old constant
    # crashed on fewer than 5000 classes and silently dropped extra classes/samples.
    num_samples, num_classes = pred.shape[0], pred.shape[1]
    ap_list = np.array([
        certainClassAP(pred[:, cls_index], label[:, cls_index],
                       num_samples, accuracy_th)
        for cls_index in range(num_classes)
    ])
    print('map is', ap_list.mean())
    np.save(aplistSavePath, ap_list)
    return ap_list

def certainClassAP(model_pred, labels, N, accuracy_th):
    """Compute the 11-point interpolated average precision (AP) of one class.

    Samples are ranked by descending score, and the precision/recall curve is
    traced down that ranking. After the top k samples, precision is TP_k / k
    and recall is TP_k / (number of positive labels). AP is the mean, over the
    11 recall levels 0.0, 0.1, ..., 1.0, of the highest precision reached at a
    recall at or above that level (0 if the level is never reached).

    Only samples whose score is strictly greater than ``accuracy_th`` count as
    predicted positives, so the curve stops at that score. Pass a threshold
    below every score (e.g. ``-1`` for sigmoid outputs) for the usual,
    untruncated AP. Tied scores keep their input order.

    Args:
        model_pred: scores of this class, shape [batch] (tensor or array-like).
        labels: binary ground truth of this class, shape [batch] (1 = positive).
        N: only the first ``N`` samples are evaluated; a value larger than the
            batch size simply uses the whole batch.
        accuracy_th: score above which a sample is a positive prediction.

    Returns:
        float: AP in [0, 1]; 0.0 when there are no positive labels or no score
        exceeds ``accuracy_th``.
    """
    # Work in float64 NumPy on the CPU, so CUDA / float32 / bool tensors all work.
    scores = torch.as_tensor(model_pred).detach().cpu().double().numpy().reshape(-1)[:N]
    targets = torch.as_tensor(labels).detach().cpu().double().numpy().reshape(-1)[:N]

    total_positives = targets.sum()
    average_precision = 0.0
    if total_positives > 0:  # no positives -> AP 0 (avoids 0/0 recall)
        # Rank by score; a stable sort makes ties deterministic.
        order = np.argsort(-scores, kind="stable")
        # After sorting, the samples with score > threshold form a prefix.
        ranked_targets = targets[order][scores[order] > accuracy_th]
        if ranked_targets.size > 0:
            true_positives = np.cumsum(ranked_targets)
            precisions = true_positives / np.arange(1, ranked_targets.size + 1)
            recalls = true_positives / total_positives
            # k / 10 is exact for every level. np.arange(0, 1.1, 0.1) yields
            # 0.30000000000000004 etc., which would miss a recall of exactly 0.3.
            for recall_level in np.arange(11) / 10:
                reached = precisions[recalls >= recall_level]
                if reached.size > 0:
                    average_precision += reached.max() / 11
    print('cur class ap:', average_precision)
    return float(average_precision)
# Script entry: compute per-class AP, print the mAP and save the AP list to disk.
# NOTE(review): `pred` and `label` are not defined in the visible code; they must
# already exist as module-level names when this line runs.
getAplist(pred, label, aplistSavePath='./aplist@0.5.npy')

本文地址:https://blog.csdn.net/weixin_42544131/article/details/110925975