# test.py
# BERT text-classification test script
import gc
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
)
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer
  14. class TextClassificationTester:
  15. """BERT文本分类测试器"""
  16. def __init__(self, model_path, is_english=False, num_classes=2, max_length=256):
  17. """
  18. 初始化测试器
  19. Args:
  20. model_path: 模型权重文件路径
  21. is_english: 是否为英文模型
  22. num_classes: 分类类别数
  23. max_length: 最大序列长度
  24. """
  25. # 选择模型
  26. if is_english:
  27. self.model_name = "bert-base-uncased"
  28. else:
  29. self.model_name = "bert-base-chinese"
  30. self.is_english = is_english
  31. self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  32. self.max_length = max_length
  33. self.num_classes = num_classes
  34. print(f"\n{'='*60}")
  35. print("初始化测试器")
  36. print(f"{'='*60}")
  37. print(f"使用设备: {self.device}")
  38. print(f"模型名称: {self.model_name}")
  39. print(f"模型路径: {model_path}")
  40. print(f"类别数: {num_classes}")
  41. print(f"最大序列长度: {max_length}")
  42. # 加载模型和分词器
  43. print("\n加载模型和分词器...")
  44. self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
  45. self.model = AutoModelForSequenceClassification.from_pretrained(
  46. self.model_name,
  47. num_labels=num_classes
  48. ).to(self.device)
  49. # 加载训练好的权重
  50. if os.path.exists(model_path):
  51. print(f"加载模型权重: {model_path}")
  52. self.model.load_state_dict(torch.load(model_path, map_location=self.device))
  53. print("✓ 模型权重加载成功")
  54. else:
  55. raise FileNotFoundError(f"模型文件不存在: {model_path}")
  56. # 设置为评估模式
  57. self.model.eval()
  58. def load_test_data(self, test_csv_file):
  59. """
  60. 加载测试数据
  61. Args:
  62. test_csv_file: 测试数据CSV文件路径
  63. """
  64. print(f"\n加载测试数据: {test_csv_file}")
  65. self.test_data = pd.read_csv(test_csv_file)
  66. self.test_data = self.test_data.dropna()
  67. print(f"测试样本数: {len(self.test_data)}")
  68. print(f"列名: {self.test_data.columns.tolist()}")
  69. # 检查必要的列
  70. if 'record' not in self.test_data.columns:
  71. raise ValueError("数据文件中缺少 'record' 列")
  72. if 'label' not in self.test_data.columns:
  73. raise ValueError("数据文件中缺少 'label' 列")
  74. # 统计标签分布
  75. label_counts = self.test_data['label'].value_counts().sort_index()
  76. print(f"\n标签分布:")
  77. for label, count in label_counts.items():
  78. print(f" 类别 {label}: {count} ({count/len(self.test_data)*100:.1f}%)")
  79. def predict(self, batch_size=32):
  80. """
  81. 对测试数据进行预测
  82. Args:
  83. batch_size: 批次大小
  84. """
  85. print(f"\n开始预测 (batch_size={batch_size})...")
  86. self.predictions = []
  87. self.labels = []
  88. self.texts = []
  89. with torch.no_grad():
  90. for i in tqdm(range(0, len(self.test_data), batch_size), desc="预测中"):
  91. batch_texts = self.test_data['record'].iloc[i:i+batch_size].tolist()
  92. batch_labels = self.test_data['label'].iloc[i:i+batch_size].tolist()
  93. # 分词
  94. inputs = self.tokenizer(
  95. batch_texts,
  96. max_length=self.max_length,
  97. padding='max_length',
  98. truncation=True,
  99. return_tensors='pt'
  100. )
  101. # 移动到设备
  102. inputs = {k: v.to(self.device) for k, v in inputs.items()}
  103. # 预测
  104. outputs = self.model(**inputs)
  105. preds = torch.argmax(outputs.logits, dim=-1).cpu().numpy()
  106. self.predictions.extend(preds)
  107. self.labels.extend(batch_labels)
  108. self.texts.extend(batch_texts)
  109. self.predictions = np.array(self.predictions)
  110. self.labels = np.array(self.labels)
  111. print(f"✓ 预测完成,共 {len(self.predictions)} 个样本")
  112. def evaluate(self):
  113. """
  114. 评估模型性能并输出详细指标
  115. """
  116. print(f"\n{'='*60}")
  117. print("评估结果")
  118. print(f"{'='*60}")
  119. # 计算各项指标
  120. accuracy = accuracy_score(self.labels, self.predictions)
  121. precision_macro = precision_score(self.labels, self.predictions, average='macro')
  122. recall_macro = recall_score(self.labels, self.predictions, average='macro')
  123. f1_macro = f1_score(self.labels, self.predictions, average='macro')
  124. precision_weighted = precision_score(self.labels, self.predictions, average='weighted')
  125. recall_weighted = recall_score(self.labels, self.predictions, average='weighted')
  126. f1_weighted = f1_score(self.labels, self.predictions, average='weighted')
  127. # 输出总体指标
  128. print("\n【总体指标】")
  129. print(f" 准确率 (Accuracy): {accuracy:.4f} ({accuracy*100:.2f}%)")
  130. print(f"\n Macro平均:")
  131. print(f" 精确率 (Precision): {precision_macro:.4f}")
  132. print(f" 召回率 (Recall): {recall_macro:.4f}")
  133. print(f" F1分数: {f1_macro:.4f}")
  134. print(f"\n Weighted平均:")
  135. print(f" 精确率 (Precision): {precision_weighted:.4f}")
  136. print(f" 召回率 (Recall): {recall_weighted:.4f}")
  137. print(f" F1分数: {f1_weighted:.4f}")
  138. # 输出每个类别的详细指标
  139. print("\n【各类别详细指标】")
  140. print(classification_report(self.labels, self.predictions, digits=4))
  141. # 混淆矩阵
  142. print("\n【混淆矩阵】")
  143. cm = confusion_matrix(self.labels, self.predictions)
  144. print(cm)
  145. # 绘制混淆矩阵热力图
  146. self._plot_confusion_matrix(cm)
  147. return {
  148. 'accuracy': accuracy,
  149. 'precision_macro': precision_macro,
  150. 'recall_macro': recall_macro,
  151. 'f1_macro': f1_macro,
  152. 'precision_weighted': precision_weighted,
  153. 'recall_weighted': recall_weighted,
  154. 'f1_weighted': f1_weighted
  155. }
  156. def _plot_confusion_matrix(self, cm):
  157. """
  158. 绘制混淆矩阵热力图
  159. """
  160. plt.figure(figsize=(8, 6))
  161. sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=True)
  162. plt.title('Confusion Matrix')
  163. plt.ylabel('True Label')
  164. plt.xlabel('Predicted Label')
  165. # 保存图片
  166. if self.is_english:
  167. output_path = './confusion_matrix_en.png'
  168. else:
  169. output_path = './confusion_matrix_zh.png'
  170. plt.savefig(output_path, dpi=300, bbox_inches='tight')
  171. print(f"\n✓ 混淆矩阵已保存至: {output_path}")
  172. plt.close()
  173. def save_errors(self, output_csv_file='errors.csv'):
  174. """
  175. 保存预测错误的样本到CSV文件
  176. Args:
  177. output_csv_file: 输出CSV文件路径
  178. """
  179. print(f"\n保存错误预测样本...")
  180. # 找出预测错误的样本
  181. error_indices = np.where(self.predictions != self.labels)[0]
  182. if len(error_indices) == 0:
  183. print("✓ 没有预测错误的样本")
  184. return
  185. # 创建错误样本DataFrame
  186. errors_df = pd.DataFrame({
  187. 'record': [self.texts[i] for i in error_indices],
  188. 'predict': self.predictions[error_indices],
  189. 'label': self.labels[error_indices]
  190. })
  191. # 保存到CSV
  192. errors_df.to_csv(output_csv_file, index=False, encoding='utf-8-sig')
  193. print(f"✓ 错误样本已保存至: {output_csv_file}")
  194. print(f" 错误样本数: {len(errors_df)} / {len(self.labels)} ({len(errors_df)/len(self.labels)*100:.2f}%)")
  195. # 统计错误类型
  196. print(f"\n【错误类型统计】")
  197. error_types = errors_df.groupby(['label', 'predict']).size().reset_index(name='count')
  198. for _, row in error_types.iterrows():
  199. print(f" 真实标签 {row['label']} → 预测为 {row['predict']}: {row['count']} 个")
  200. def run_test(self, test_csv_file, output_csv_file='errors.csv', batch_size=32):
  201. """
  202. 运行完整的测试流程
  203. Args:
  204. test_csv_file: 测试数据文件路径
  205. output_csv_file: 错误样本输出文件路径
  206. batch_size: 批次大小
  207. """
  208. # 加载测试数据
  209. self.load_test_data(test_csv_file)
  210. # 预测
  211. self.predict(batch_size)
  212. # 评估
  213. metrics = self.evaluate()
  214. # 保存错误样本
  215. self.save_errors(output_csv_file)
  216. print(f"\n{'='*60}")
  217. print("测试完成!")
  218. print(f"{'='*60}")
  219. return metrics
  220. if __name__ == '__main__':
  221. # 示例用法
  222. tester = TextClassificationTester(
  223. model_path='./checkpoints/bert-base-chinese_best_acc.pth', # 模型路径
  224. is_english=False, # 中文模型
  225. num_classes=2, # 二分类
  226. max_length=256 # 最大序列长度
  227. )
  228. # 运行测试
  229. metrics = tester.run_test(
  230. test_csv_file='data_zh.csv', # 测试数据文件
  231. output_csv_file='errors_zh.csv', # 错误样本输出文件
  232. batch_size=32 # 批次大小
  233. )
  234. del tester, metrics
  235. gc.collect() # 强制 Python 垃圾回收
  236. torch.cuda.empty_cache() # 清空 CUDA 缓存
  237. time.sleep(2) # 短暂等待即可,不需要太久
  238. tester = TextClassificationTester(
  239. model_path='./checkpoints/bert-base-uncased_best_acc.pth', # 模型路径
  240. is_english=True, # 英文模型
  241. num_classes=2, # 二分类
  242. max_length=256 # 最大序列长度
  243. )
  244. # 运行测试
  245. metrics = tester.run_test(
  246. test_csv_file='data_en.csv', # 测试数据文件
  247. output_csv_file='errors_en.csv', # 错误样本输出文件
  248. batch_size=32 # 批次大小
  249. )