# accuracy.py (3.3 KB)
import argparse
import json

import numpy as np
  4. def argsparser():
  5. parser = argparse.ArgumentParser(prog=__file__)
  6. parser.add_argument('--input', '-i',type=str, default='./prediction_result.json', help='path of prediction json')
  7. args = parser.parse_args()
  8. return args
  9. def calculate_metrics(y_true, y_pred):
  10. """
  11. 使用numpy手动计算分类评估指标 - 修正版
  12. """
  13. y_pred = y_pred.astype(np.int32)
  14. y_true = y_true.astype(np.int32)
  15. # 获取唯一类别并排序
  16. classes = np.unique(np.concatenate([y_true, y_pred])).astype(np.int32)
  17. n_classes = len(classes)
  18. # 创建混淆矩阵
  19. # TP FN
  20. # FP TN
  21. confusion_matrix = np.zeros((n_classes, n_classes), dtype=np.float32)
  22. # 手动填充混淆矩阵
  23. for true_val, pred_val in zip(y_true, y_pred):
  24. if true_val == pred_val: # 预测正确的样本
  25. if true_val == 1: # 真实为正例
  26. confusion_matrix[0, 0] += 1
  27. else:
  28. confusion_matrix[1, 1] += 1
  29. else: # 预测错误的样本
  30. if true_val == 0: # 真实为负例
  31. confusion_matrix[1, 0] += 1
  32. else: # 真实为正例
  33. confusion_matrix[0, 1] += 1
  34. tp = confusion_matrix[0, 0]
  35. tn = confusion_matrix[1, 1]
  36. fn = confusion_matrix[0, 1]
  37. fp = confusion_matrix[1, 0]
  38. # 计算精度
  39. accuracy = (tp + tn) / (tp + fp + fn + tn + 1e-8)
  40. # 计算准确率
  41. precision = tp / (tp + fp + 1e-8)
  42. # 计算召回率
  43. recall = tp / (tp + fn + 1e-8)
  44. # 计算F1分数
  45. f1 = 2 * precision * recall / (precision + recall + 1e-8)
  46. return {
  47. 'confusion_matrix': confusion_matrix,
  48. 'classes': classes,
  49. 'tp': tp,
  50. 'tn': tn,
  51. 'fp': fp,
  52. 'fn': fn,
  53. 'accuracy': accuracy,
  54. 'precision': precision,
  55. 'recall': recall,
  56. 'f1': f1
  57. }
  58. def main(args):
  59. # 读取json文件
  60. data = None
  61. with open(args.input, 'r') as f:
  62. data = json.load(f)
  63. if data is None:
  64. return 0
  65. # 开始计算精度 格式:真值,预测值
  66. data = [[i['y'], i['x']] for i in data.values()]
  67. data = np.array(data, dtype=np.float32)
  68. # 分离真值和预测值
  69. y_true = data[:, 0] # 第一列是真值
  70. y_pred = data[:, 1] # 第二列是预测值
  71. # # 计算评估指标
  72. # metrics = calculate_metrics(y_true, y_pred)
  73. #
  74. # # 打印结果
  75. # print("=== 分类评估结果 ===")
  76. # print(f"总体精度: {metrics['accuracy']:.4f}")
  77. # print(f"混淆矩阵:TP FN")
  78. # print(f" FP TN")
  79. # print(f"{metrics['confusion_matrix']}")
  80. # print(f"查准率: {metrics['precision']:.4f}")
  81. # print(f"查全率: {metrics['recall']:.4f}")
  82. # print(f"F1分数: {metrics['f1']:.4f}")
  83. from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
  84. # 计算各项指标
  85. report = classification_report(y_true, y_pred)
  86. cm = confusion_matrix(y_true, y_pred)
  87. accuracy = accuracy_score(y_true, y_pred)
  88. print(f"=== sklearn 分类评估结果 ===")
  89. print(f"分类报告:\n{report}")
  90. print(f"混淆矩阵:\n{cm}")
  91. print(f"准确率: {accuracy:.4f}")
  92. return 1
  93. if __name__ == '__main__':
  94. args = argsparser()
  95. main(args)