多label实现准确率和召回率

孤人 提交于 2020-01-13 01:56:24
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Per-label accuracy and recall for multi-label ground truth.

Each sample has exactly one predicted label and one or more true labels.
For every label the problem is reduced to a binary one (label present /
absent) and accuracy and recall are reported as ``[numerator, denominator]``
pairs, which ``fraction2decimal`` turns into rounded decimals.

    @File : recall.py
    @Author : 郭凯锋
    @Time : 2020/1/12 17:57
    @Software : PyCharm
    @Git-Hub : daguonice
    @Blog (cnblogs): https://www.cnblogs.com/daguonice/
"""
import copy

import numpy as np
import pandas as pd

try:
    # Not used by this module; kept so the name stays importable from here,
    # but the script no longer fails when scikit-learn is not installed.
    from sklearn.metrics import recall_score
except ImportError:
    recall_score = None

# Labels scored by default, in the order used for the score table rows.
DEFAULT_LABELS = ('相互宝', '健康险', '保险线', '花呗')


def _as_label_list(labels):
    """Return *labels* as a list, treating a bare string as a single label."""
    if isinstance(labels, str):
        return [labels]
    return list(labels)


def _ratio(fraction):
    """Convert a ``[numerator, denominator]`` pair to a 2-decimal float."""
    numerator, denominator = fraction[0], fraction[1]
    if denominator == 0:
        # An empty denominator is reported as 0.0 instead of dividing by zero.
        return 0.0
    return round(numerator / denominator, 2)


def right_or_wrong(ypred_ele, ytrue_ele):
    """Return True when label *ypred_ele* is one of the true labels.

    *ytrue_ele* may be a collection of labels or a single label given as a
    bare string. A bare string is compared as a whole label, not searched
    as a substring (``'保险'`` does not match ``'保险线'``).
    """
    return ypred_ele in _as_label_list(ytrue_ele)


def get_single_score(y_true, y_pred):
    """Compute accuracy and recall for one binary (0/1) label.

    Args:
        y_true: sequence of 0/1 ground-truth indicators.
        y_pred: sequence of 0/1 predicted indicators, same length.

    Returns:
        ``(accuracy, recall)`` where each is a ``[numerator, denominator]``
        list of Python ints: accuracy is ``[TP + TN, len(y_pred)]`` and
        recall is ``[TP, TP + FN]``.
    """
    true_pos = np.equal(y_true, 1)
    true_neg = np.equal(y_true, 0)
    pred_pos = np.equal(y_pred, 1)
    pred_neg = np.equal(y_pred, 0)

    tp = int(np.sum(np.logical_and(true_pos, pred_pos)))
    # True negative: absent in the truth and not predicted.
    tn = int(np.sum(np.logical_and(true_neg, pred_neg)))
    # False negative: present in the truth but not predicted.
    fn = int(np.sum(np.logical_and(true_pos, pred_neg)))

    recall = [tp, tp + fn]
    accuracy = [tp + tn, len(y_pred)]
    return accuracy, recall


def get_multi_score(y_true, y_pred, labels=None):
    """Score every label as an independent binary problem.

    Args:
        y_true: per-sample true labels; each entry is a collection of labels
            or a single label string.
        y_pred: per-sample predicted label (one label per sample).
        labels: labels to score, in output order. Defaults to
            ``DEFAULT_LABELS``.

    Returns:
        A three-level list with one ``[accuracy, recall]`` entry per label,
        where accuracy and recall are ``[numerator, denominator]`` pairs as
        returned by ``get_single_score``.

    Raises:
        ValueError: if *y_true* and *y_pred* differ in length.
    """
    if len(y_true) != len(y_pred):
        raise ValueError("The two input lengths are inconsistent")
    if labels is None:
        labels = DEFAULT_LABELS

    # Normalise once so the per-label loop does not repeat the work.
    true_label_lists = [_as_label_list(item) for item in y_true]

    list_score = []
    for label in labels:
        true_vec = np.array(
            [1.0 if label in truth else 0.0 for truth in true_label_lists]
        )
        pred_vec = np.array(
            [1.0 if predicted == label else 0.0 for predicted in y_pred]
        )
        accuracy, recall = get_single_score(true_vec, pred_vec)
        list_score.append([accuracy, recall])
    return list_score


def fraction2decimal(result):
    """Convert every ``[numerator, denominator]`` pair to a rounded decimal.

    Args:
        result: list (per sample) of lists (per label) of lists of
            ``[numerator, denominator]`` pairs.

    Returns:
        The same nesting with each pair replaced by its quotient rounded to
        two decimals, or ``0.0`` when the denominator is zero.
    """
    res = []
    for per_sample in result:
        res.append(
            [[_ratio(score) for score in per_label] for per_label in per_sample]
        )
    return res


def func(y_true, y_pred):
    """Build one adjusted per-label score table for every sample.

    Args:
        y_true: per-sample true labels; each entry is a collection of labels
            (or a single label string, treated as one label).
        y_pred: flat list with one predicted label per sample.

    Returns:
        A list with one entry per sample. Each entry holds, for every label
        in ``DEFAULT_LABELS``, ``[accuracy, recall]`` as decimals rounded to
        two places.

    Raises:
        ValueError: if the two inputs differ in length.

    Each sample starts from an independent copy of the base scores from
    ``get_multi_score``, so an adjustment made for one sample never leaks
    into another sample's row.

    NOTE(review): the index-based adjustments below are hard-coded against
    the first three default labels and are preserved exactly as written;
    their rationale is not evident from the code -- confirm with the author.
    """
    if len(y_true) != len(y_pred):
        raise ValueError("The two input lengths are inconsistent")

    base_scores = get_multi_score(y_true, y_pred)
    result = []
    for true_item, predicted in zip(y_true, y_pred):
        true_labels = _as_label_list(true_item)
        # Deep copy: the nested [numerator, denominator] lists are mutated.
        scores = copy.deepcopy(base_scores)
        hit = predicted in true_labels
        label_count = len(true_labels)

        if label_count == 1 and not hit:
            scores[1][1][1] += 1
            scores[0][0][1] += 1
        elif label_count == 2 and hit:
            scores[0][0][0] += 1
            scores[0][0][1] += 1
            scores[0][1][0] += 1
            scores[0][1][1] += 1
        elif label_count == 2 and not hit:
            scores[1][1][1] += 1
            scores[2][1][1] += 1
            scores[0][0][1] += 1
        result.append(scores)

    return fraction2decimal(result)


if __name__ == '__main__':
    alist = ['相互宝', '相互宝', '相互宝', '健康险', '保险线', '花呗']
    blist = ['健康险', ['健康险', '相互宝'], ['健康险', '保险线'], ['相互宝'], ['保险线'], ['花呗']]
    res = func(blist, alist)
    # Optional tabular display with pandas (left disabled by the author):
    # res = pd.DataFrame(res, columns=['相互宝', '健康险', '保险线', '花呗'])
    # pd.set_option('display.max_columns', None)
    # pd.set_option('display.max_rows', None)
    # pd.set_option('display.width', 100000)
    # pd.set_option('display.unicode.east_asian_width', True)
    print()
    print(res)
标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!