# BERT text classification test script
import gc
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
)
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification
class TextClassificationTester:
    """BERT text-classification tester.

    Loads a fine-tuned BERT sequence-classification checkpoint, runs batched
    inference over a CSV test set, prints evaluation metrics, renders a
    confusion-matrix heatmap, and exports the misclassified samples.
    """

    def __init__(self, model_path, is_english=False, num_classes=2, max_length=256):
        """Initialize the tester.

        Args:
            model_path: Path to the trained weights (a ``state_dict`` .pth file).
            is_english: If True use ``bert-base-uncased``; otherwise
                ``bert-base-chinese``.
            num_classes: Number of classification labels.
            max_length: Maximum tokenized sequence length.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        # Fail fast: verify the checkpoint exists BEFORE downloading the
        # (potentially large) pretrained backbone from the hub.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型文件不存在: {model_path}")

        # Pick the backbone that matches the language of the data.
        self.model_name = "bert-base-uncased" if is_english else "bert-base-chinese"
        self.is_english = is_english

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.max_length = max_length
        self.num_classes = num_classes

        print(f"\n{'='*60}")
        print("初始化测试器")
        print(f"{'='*60}")
        print(f"使用设备: {self.device}")
        print(f"模型名称: {self.model_name}")
        print(f"模型路径: {model_path}")
        print(f"类别数: {num_classes}")
        print(f"最大序列长度: {max_length}")

        # Load the tokenizer and a classification head on top of the backbone.
        print("\n加载模型和分词器...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=num_classes
        ).to(self.device)

        # Restore the fine-tuned weights.
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources (consider weights_only=True on
        # torch >= 2.0).
        print(f"加载模型权重: {model_path}")
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        print("✓ 模型权重加载成功")

        # Inference only: disable dropout etc.
        self.model.eval()

    def load_test_data(self, test_csv_file):
        """Load the test set from a CSV file.

        The file must contain a ``record`` (text) column and a ``label``
        (class id) column. Rows containing NaN are dropped.

        Args:
            test_csv_file: Path to the test CSV file.

        Raises:
            ValueError: If a required column is missing.
        """
        print(f"\n加载测试数据: {test_csv_file}")
        self.test_data = pd.read_csv(test_csv_file)
        self.test_data = self.test_data.dropna()

        print(f"测试样本数: {len(self.test_data)}")
        print(f"列名: {self.test_data.columns.tolist()}")

        # Validate the required columns up front.
        if 'record' not in self.test_data.columns:
            raise ValueError("数据文件中缺少 'record' 列")
        if 'label' not in self.test_data.columns:
            raise ValueError("数据文件中缺少 'label' 列")

        # Report the ground-truth class distribution.
        label_counts = self.test_data['label'].value_counts().sort_index()
        print(f"\n标签分布:")
        for label, count in label_counts.items():
            print(f" 类别 {label}: {count} ({count/len(self.test_data)*100:.1f}%)")

    def predict(self, batch_size=32):
        """Run batched inference over the loaded test data.

        Populates ``self.predictions`` / ``self.labels`` (numpy arrays) and
        ``self.texts`` (list of the input strings, batch order).

        Args:
            batch_size: Number of samples per forward pass.
        """
        print(f"\n开始预测 (batch_size={batch_size})...")

        self.predictions = []
        self.labels = []
        self.texts = []

        with torch.no_grad():
            for i in tqdm(range(0, len(self.test_data), batch_size), desc="预测中"):
                # Coerce to str so numeric-looking records don't break the tokenizer.
                batch_texts = self.test_data['record'].iloc[i:i+batch_size].astype(str).tolist()
                batch_labels = self.test_data['label'].iloc[i:i+batch_size].tolist()

                # Tokenize; pad dynamically to the longest sequence in the
                # batch — identical logits to 'max_length' padding (the
                # attention mask ignores padding) but far less compute.
                inputs = self.tokenizer(
                    batch_texts,
                    max_length=self.max_length,
                    padding=True,
                    truncation=True,
                    return_tensors='pt'
                )

                # Move tensors to the model's device.
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

                # Forward pass; hard prediction = argmax over class logits.
                outputs = self.model(**inputs)
                preds = torch.argmax(outputs.logits, dim=-1).cpu().numpy()

                self.predictions.extend(preds)
                self.labels.extend(batch_labels)
                self.texts.extend(batch_texts)

        self.predictions = np.array(self.predictions)
        self.labels = np.array(self.labels)

        print(f"✓ 预测完成,共 {len(self.predictions)} 个样本")

    def evaluate(self):
        """Evaluate predictions and print detailed metrics.

        Returns:
            dict: accuracy plus macro- and weighted-averaged precision,
            recall and F1.
        """
        print(f"\n{'='*60}")
        print("评估结果")
        print(f"{'='*60}")

        # zero_division=0 yields the same scores while silencing the warning
        # sklearn emits for classes that received no predictions.
        accuracy = accuracy_score(self.labels, self.predictions)
        precision_macro = precision_score(self.labels, self.predictions, average='macro', zero_division=0)
        recall_macro = recall_score(self.labels, self.predictions, average='macro', zero_division=0)
        f1_macro = f1_score(self.labels, self.predictions, average='macro', zero_division=0)

        precision_weighted = precision_score(self.labels, self.predictions, average='weighted', zero_division=0)
        recall_weighted = recall_score(self.labels, self.predictions, average='weighted', zero_division=0)
        f1_weighted = f1_score(self.labels, self.predictions, average='weighted', zero_division=0)

        # Overall metrics.
        print("\n【总体指标】")
        print(f" 准确率 (Accuracy): {accuracy:.4f} ({accuracy*100:.2f}%)")
        print(f"\n Macro平均:")
        print(f" 精确率 (Precision): {precision_macro:.4f}")
        print(f" 召回率 (Recall): {recall_macro:.4f}")
        print(f" F1分数: {f1_macro:.4f}")
        print(f"\n Weighted平均:")
        print(f" 精确率 (Precision): {precision_weighted:.4f}")
        print(f" 召回率 (Recall): {recall_weighted:.4f}")
        print(f" F1分数: {f1_weighted:.4f}")

        # Per-class breakdown.
        print("\n【各类别详细指标】")
        print(classification_report(self.labels, self.predictions, digits=4))

        # Confusion matrix (rows = true labels, columns = predictions).
        print("\n【混淆矩阵】")
        cm = confusion_matrix(self.labels, self.predictions)
        print(cm)

        # Save a heatmap rendering alongside the textual output.
        self._plot_confusion_matrix(cm)

        return {
            'accuracy': accuracy,
            'precision_macro': precision_macro,
            'recall_macro': recall_macro,
            'f1_macro': f1_macro,
            'precision_weighted': precision_weighted,
            'recall_weighted': recall_weighted,
            'f1_weighted': f1_weighted
        }

    def _plot_confusion_matrix(self, cm):
        """Render the confusion matrix as a heatmap and save it as PNG."""
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=True)
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')

        # File name encodes the model language so runs don't overwrite each other.
        if self.is_english:
            output_path = './confusion_matrix_en.png'
        else:
            output_path = './confusion_matrix_zh.png'
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"\n✓ 混淆矩阵已保存至: {output_path}")
        plt.close()

    def save_errors(self, output_csv_file='errors.csv'):
        """Write the misclassified samples to a CSV file.

        Columns: ``record`` (text), ``predict`` (model output), ``label``
        (ground truth). No file is written if every prediction is correct.

        Args:
            output_csv_file: Output CSV path.
        """
        print(f"\n保存错误预测样本...")

        # Indices where prediction and ground truth disagree.
        error_indices = np.where(self.predictions != self.labels)[0]

        if len(error_indices) == 0:
            print("✓ 没有预测错误的样本")
            return

        errors_df = pd.DataFrame({
            'record': [self.texts[i] for i in error_indices],
            'predict': self.predictions[error_indices],
            'label': self.labels[error_indices]
        })

        # utf-8-sig BOM keeps Chinese text readable when opened in Excel.
        errors_df.to_csv(output_csv_file, index=False, encoding='utf-8-sig')

        print(f"✓ 错误样本已保存至: {output_csv_file}")
        print(f" 错误样本数: {len(errors_df)} / {len(self.labels)} ({len(errors_df)/len(self.labels)*100:.2f}%)")

        # Summarize which (true → predicted) confusions occurred.
        print(f"\n【错误类型统计】")
        error_types = errors_df.groupby(['label', 'predict']).size().reset_index(name='count')
        for _, row in error_types.iterrows():
            print(f" 真实标签 {row['label']} → 预测为 {row['predict']}: {row['count']} 个")

    def run_test(self, test_csv_file, output_csv_file='errors.csv', batch_size=32):
        """Run the full test pipeline: load → predict → evaluate → dump errors.

        Args:
            test_csv_file: Path to the test CSV file.
            output_csv_file: Path for the misclassified-samples CSV.
            batch_size: Inference batch size.

        Returns:
            dict: The metrics dictionary produced by :meth:`evaluate`.
        """
        self.load_test_data(test_csv_file)
        self.predict(batch_size)
        metrics = self.evaluate()
        self.save_errors(output_csv_file)

        print(f"\n{'='*60}")
        print("测试完成!")
        print(f"{'='*60}")

        return metrics
if __name__ == '__main__':
    # One entry per evaluation run:
    # (checkpoint path, is_english, test CSV, errors-output CSV).
    RUNS = [
        ('./checkpoints/bert-base-chinese_best_acc.pth', False, 'data_zh.csv', 'errors_zh.csv'),
        ('./checkpoints/bert-base-uncased_best_acc.pth', True, 'data_en.csv', 'errors_en.csv'),
    ]

    for run_idx, (model_path, is_english, test_csv, errors_csv) in enumerate(RUNS):
        tester = TextClassificationTester(
            model_path=model_path,
            is_english=is_english,
            num_classes=2,       # binary classification
            max_length=256,      # max tokenized sequence length
        )
        metrics = tester.run_test(
            test_csv_file=test_csv,
            output_csv_file=errors_csv,
            batch_size=32,
        )

        # Between runs, release the previous model so the next one fits in
        # GPU memory (matches the original script, which only cleaned up
        # between the Chinese and English runs).
        if run_idx < len(RUNS) - 1:
            del tester, metrics
            gc.collect()                  # force Python garbage collection
            torch.cuda.empty_cache()      # release cached CUDA allocations
            time.sleep(2)                 # brief pause is enough