import json
import argparse

import numpy as np


def argsparser():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace with one attribute:
            input (str): path to the prediction-result JSON file.
    """
    parser = argparse.ArgumentParser(prog=__file__)
    parser.add_argument('--input', '-i', type=str,
                        default='./prediction_result.json',
                        help='path of prediction json')
    args = parser.parse_args()
    return args


def calculate_metrics(y_true, y_pred):
    """Compute binary-classification metrics with plain numpy.

    Labels are assumed to be 0 (negative) and 1 (positive).

    Args:
        y_true: 1-D array-like of ground-truth labels.
        y_pred: 1-D array-like of predicted labels.

    Returns:
        dict with keys:
            confusion_matrix: float32 array, layout [[TP, FN], [FP, TN]]
            classes: sorted unique int32 labels seen in either input
            tp, tn, fp, fn: counts (floats)
            accuracy, precision, recall, f1: scalar metrics
    """
    y_pred = np.asarray(y_pred).astype(np.int32)
    y_true = np.asarray(y_true).astype(np.int32)

    # Unique sorted class labels actually present in either array.
    classes = np.unique(np.concatenate([y_true, y_pred])).astype(np.int32)

    # Vectorized counts replace the original per-sample Python loop;
    # results are identical for 0/1 labels.
    tp = float(np.sum((y_true == 1) & (y_pred == 1)))
    tn = float(np.sum((y_true == 0) & (y_pred == 0)))
    fp = float(np.sum((y_true == 0) & (y_pred == 1)))
    fn = float(np.sum((y_true == 1) & (y_pred == 0)))

    # Layout kept from the original:
    #   [[TP, FN],
    #    [FP, TN]]
    confusion_matrix = np.array([[tp, fn], [fp, tn]], dtype=np.float32)

    # 1e-8 guards every denominator against division by zero on
    # degenerate inputs (e.g. no positives at all).
    accuracy = (tp + tn) / (tp + fp + fn + tn + 1e-8)
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    f1 = 2 * precision * recall / (precision + recall + 1e-8)

    return {
        'confusion_matrix': confusion_matrix,
        'classes': classes,
        'tp': tp,
        'tn': tn,
        'fp': fp,
        'fn': fn,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }


def main(args):
    """Load the prediction JSON and print sklearn evaluation metrics.

    The JSON is expected to map ids to records shaped like
    {'y': ground_truth, 'x': prediction} — TODO confirm against the
    producer of prediction_result.json.

    Returns:
        0 when the file holds no usable data, 1 on success
        (original convention, kept for backward compatibility).
    """
    with open(args.input, 'r', encoding='utf-8') as f:
        data = json.load(f)
    # Covers both JSON `null` (original check) and an empty object,
    # which would otherwise crash on the column slices below.
    if not data:
        return 0

    # Column 0: ground truth, column 1: prediction.
    pairs = np.array([[rec['y'], rec['x']] for rec in data.values()],
                     dtype=np.float32)
    y_true = pairs[:, 0]
    y_pred = pairs[:, 1]

    # Local import keeps sklearn optional for callers that only need
    # the numpy-based calculate_metrics() above.
    from sklearn.metrics import (classification_report, confusion_matrix,
                                 accuracy_score)

    report = classification_report(y_true, y_pred)
    cm = confusion_matrix(y_true, y_pred)
    accuracy = accuracy_score(y_true, y_pred)

    print(f"=== sklearn 分类评估结果 ===")
    print(f"分类报告:\n{report}")
    print(f"混淆矩阵:\n{cm}")
    print(f"准确率: {accuracy:.4f}")
    return 1


if __name__ == '__main__':
    args = argsparser()
    main(args)