# classification.py — classification evaluation metrics
  1. """
  2. 分类任务评估指标
  3. """
  4. import numpy as np
  5. from sklearn.metrics import (
  6. accuracy_score, precision_score, recall_score, f1_score,
  7. confusion_matrix, classification_report, roc_auc_score
  8. )
  9. from typing import List, Dict, Any, Optional, Union
  10. import logging
  11. logger = logging.getLogger(__name__)
  12. class ClassificationMetrics:
  13. """分类任务评估指标"""
  14. def __init__(self):
  15. """初始化分类指标"""
  16. pass
  17. def accuracy(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
  18. """准确率"""
  19. return accuracy_score(y_true, y_pred)
  20. def precision(self, y_true: np.ndarray, y_pred: np.ndarray,
  21. average: str = 'weighted') -> float:
  22. """精确率"""
  23. return precision_score(y_true, y_pred, average=average, zero_division=0)
  24. def recall(self, y_true: np.ndarray, y_pred: np.ndarray,
  25. average: str = 'weighted') -> float:
  26. """召回率"""
  27. return recall_score(y_true, y_pred, average=average, zero_division=0)
  28. def f1_score(self, y_true: np.ndarray, y_pred: np.ndarray,
  29. average: str = 'weighted') -> float:
  30. """F1分数"""
  31. return f1_score(y_true, y_pred, average=average, zero_division=0)
  32. def confusion_matrix(self, y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
  33. """混淆矩阵"""
  34. return confusion_matrix(y_true, y_pred)
  35. def classification_report(self, y_true: np.ndarray, y_pred: np.ndarray,
  36. target_names: Optional[List[str]] = None) -> str:
  37. """分类报告"""
  38. return classification_report(y_true, y_pred, target_names=target_names)
  39. def roc_auc(self, y_true: np.ndarray, y_pred_proba: np.ndarray,
  40. average: str = 'weighted') -> float:
  41. """ROC AUC分数"""
  42. try:
  43. return roc_auc_score(y_true, y_pred_proba, average=average)
  44. except ValueError as e:
  45. logger.warning(f"无法计算ROC AUC: {e}")
  46. return 0.0
  47. def compute_all_metrics(self, y_true: np.ndarray, y_pred: np.ndarray,
  48. y_pred_proba: Optional[np.ndarray] = None) -> Dict[str, float]:
  49. """计算所有指标"""
  50. metrics = {
  51. 'accuracy': self.accuracy(y_true, y_pred),
  52. 'precision': self.precision(y_true, y_pred),
  53. 'recall': self.recall(y_true, y_pred),
  54. 'f1_score': self.f1_score(y_true, y_pred)
  55. }
  56. if y_pred_proba is not None:
  57. metrics['roc_auc'] = self.roc_auc(y_true, y_pred_proba)
  58. return metrics
  59. class MultiClassMetrics:
  60. """多分类任务评估指标"""
  61. def __init__(self):
  62. """初始化多分类指标"""
  63. pass
  64. def macro_precision(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
  65. """宏平均精确率"""
  66. return precision_score(y_true, y_pred, average='macro', zero_division=0)
  67. def macro_recall(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
  68. """宏平均召回率"""
  69. return recall_score(y_true, y_pred, average='macro', zero_division=0)
  70. def macro_f1(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
  71. """宏平均F1分数"""
  72. return f1_score(y_true, y_pred, average='macro', zero_division=0)
  73. def micro_precision(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
  74. """微平均精确率"""
  75. return precision_score(y_true, y_pred, average='micro', zero_division=0)
  76. def micro_recall(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
  77. """微平均召回率"""
  78. return recall_score(y_true, y_pred, average='micro', zero_division=0)
  79. def micro_f1(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
  80. """微平均F1分数"""
  81. return f1_score(y_true, y_pred, average='micro', zero_division=0)
  82. def per_class_metrics(self, y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, Dict[str, float]]:
  83. """每个类别的指标"""
  84. unique_labels = np.unique(np.concatenate([y_true, y_pred]))
  85. per_class = {}
  86. for label in unique_labels:
  87. # 二分类指标
  88. y_true_binary = (y_true == label).astype(int)
  89. y_pred_binary = (y_pred == label).astype(int)
  90. tp = np.sum((y_true_binary == 1) & (y_pred_binary == 1))
  91. fp = np.sum((y_true_binary == 0) & (y_pred_binary == 1))
  92. fn = np.sum((y_true_binary == 1) & (y_pred_binary == 0))
  93. precision = tp / (tp + fp) if (tp + fp) > 0 else 0
  94. recall = tp / (tp + fn) if (tp + fn) > 0 else 0
  95. f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
  96. per_class[str(label)] = {
  97. 'precision': precision,
  98. 'recall': recall,
  99. 'f1_score': f1
  100. }
  101. return per_class
  102. class BinaryClassificationMetrics:
  103. """二分类任务评估指标"""
  104. def __init__(self):
  105. """初始化二分类指标"""
  106. pass
  107. def true_positive(self, y_true: np.ndarray, y_pred: np.ndarray) -> int:
  108. """真正例"""
  109. return np.sum((y_true == 1) & (y_pred == 1))
  110. def false_positive(self, y_true: np.ndarray, y_pred: np.ndarray) -> int:
  111. """假正例"""
  112. return np.sum((y_true == 0) & (y_pred == 1))
  113. def true_negative(self, y_true: np.ndarray, y_pred: np.ndarray) -> int:
  114. """真负例"""
  115. return np.sum((y_true == 0) & (y_pred == 0))
  116. def false_negative(self, y_true: np.ndarray, y_pred: np.ndarray) -> int:
  117. """假负例"""
  118. return np.sum((y_true == 1) & (y_pred == 0))
  119. def sensitivity(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
  120. """敏感性 (召回率)"""
  121. tp = self.true_positive(y_true, y_pred)
  122. fn = self.false_negative(y_true, y_pred)
  123. return tp / (tp + fn) if (tp + fn) > 0 else 0
  124. def specificity(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
  125. """特异性"""
  126. tn = self.true_negative(y_true, y_pred)
  127. fp = self.false_positive(y_true, y_pred)
  128. return tn / (tn + fp) if (tn + fp) > 0 else 0
  129. def positive_predictive_value(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
  130. """阳性预测值 (精确率)"""
  131. tp = self.true_positive(y_true, y_pred)
  132. fp = self.false_positive(y_true, y_pred)
  133. return tp / (tp + fp) if (tp + fp) > 0 else 0
  134. def negative_predictive_value(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
  135. """阴性预测值"""
  136. tn = self.true_negative(y_true, y_pred)
  137. fn = self.false_negative(y_true, y_pred)
  138. return tn / (tn + fn) if (tn + fn) > 0 else 0