# accuracy.py (3.3 KB)
import argparse
import json

import numpy as np
  4. def argsparser():
  5. parser = argparse.ArgumentParser(prog=__file__)
  6. parser.add_argument('--input', type=str, default='./prediction_result.json', help='path of prediction')
  7. args = parser.parse_args()
  8. return args
  9. def calculate_metrics(y_true, y_pred):
  10. """
  11. 使用numpy手动计算分类评估指标 - 修正版
  12. """
  13. y_pred = y_pred.astype(np.int32)
  14. y_true = y_true.astype(np.int32)
  15. # 获取唯一类别并排序
  16. classes = np.unique(np.concatenate([y_true, y_pred])).astype(np.int32)
  17. n_classes = len(classes)
  18. # 创建混淆矩阵
  19. # TP FN
  20. # FP TN
  21. confusion_matrix = np.zeros((n_classes, n_classes), dtype=np.float32)
  22. # 手动填充混淆矩阵
  23. for true_val, pred_val in zip(y_true, y_pred):
  24. if true_val == pred_val: # 预测正确的样本
  25. if true_val == 1: # 真实为正例
  26. confusion_matrix[0, 0] += 1
  27. else:
  28. confusion_matrix[1, 1] += 1
  29. else: # 预测错误的样本
  30. if true_val == 0: # 真实为负例
  31. confusion_matrix[1, 0] += 1
  32. else: # 真实为正例
  33. confusion_matrix[0, 1] += 1
  34. tp = confusion_matrix[0, 0]
  35. tn = confusion_matrix[1, 1]
  36. fn = confusion_matrix[0, 1]
  37. fp = confusion_matrix[1, 0]
  38. # 计算精度
  39. accuracy = (tp + tn) / (tp + fp + fn + tn + 1e-8)
  40. # 计算准确率
  41. precision = tp / (tp + fp + 1e-8)
  42. # 计算召回率
  43. recall = tp / (tp + fn + 1e-8)
  44. # 计算F1分数
  45. f1 = 2 * precision * recall / (precision + recall + 1e-8)
  46. return {
  47. 'confusion_matrix': confusion_matrix,
  48. 'classes': classes,
  49. 'tp': tp,
  50. 'tn': tn,
  51. 'fp': fp,
  52. 'fn': fn,
  53. 'accuracy': accuracy,
  54. 'precision': precision,
  55. 'recall': recall,
  56. 'f1': f1
  57. }
  58. def main(args):
  59. # 读取json文件
  60. data = None
  61. with open(args.input, 'r') as f:
  62. data = json.load(f)
  63. if data is None:
  64. return 0
  65. # 开始计算精度 格式:真值,预测值
  66. data = [[i['y'], i['x']] for i in data.values()]
  67. data = np.array(data, dtype=np.float32)
  68. # 分离真值和预测值
  69. y_true = data[:, 0] # 第一列是真值
  70. y_pred = data[:, 1] # 第二列是预测值
  71. # 计算评估指标
  72. metrics = calculate_metrics(y_true, y_pred)
  73. # 打印结果
  74. print("=== 分类评估结果 ===")
  75. print(f"总体精度: {metrics['accuracy']:.4f}")
  76. print(f"混淆矩阵:TP FN")
  77. print(f" FP TN")
  78. print(f"{metrics['confusion_matrix']}")
  79. print(f"查准率: {metrics['precision']:.4f}")
  80. print(f"查全率: {metrics['recall']:.4f}")
  81. print(f"F1分数: {metrics['f1']:.4f}")
  82. from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
  83. # 计算各项指标
  84. report = classification_report(y_true, y_pred)
  85. cm = confusion_matrix(y_true, y_pred)
  86. accuracy = accuracy_score(y_true, y_pred)
  87. print(f"=== sklearn 分类评估结果 ===")
  88. print(f"分类报告:\n{report}")
  89. print(f"混淆矩阵:\n{cm}")
  90. print(f"准确率: {accuracy:.4f}")
  91. return 1
  92. if __name__ == '__main__':
  93. args = argsparser()
  94. main(args)