# BERT text-classification evaluation script: loads a fine-tuned checkpoint,
# runs batched inference over a labelled CSV, reports metrics, and dumps
# misclassified samples for error analysis.
import torch
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from tqdm import tqdm
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
import os
import seaborn as sns
import matplotlib.pyplot as plt
import gc
import time


class TextClassificationTester:
    """BERT text-classification tester.

    Wraps model loading, batched prediction, metric computation,
    confusion-matrix plotting and error-sample export.
    """

    def __init__(self, model_path, is_english=False, num_classes=2, max_length=256):
        """Initialize the tester.

        Args:
            model_path: path to the trained model state-dict (.pth) file
            is_english: True -> bert-base-uncased, False -> bert-base-chinese
            num_classes: number of classification labels
            max_length: maximum token sequence length (longer inputs truncated)

        Raises:
            FileNotFoundError: if model_path does not exist.
        """
        # Pick the pretrained backbone matching the training language.
        if is_english:
            self.model_name = "bert-base-uncased"
        else:
            self.model_name = "bert-base-chinese"

        self.is_english = is_english
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.max_length = max_length
        self.num_classes = num_classes

        print(f"\n{'='*60}")
        print("初始化测试器")
        print(f"{'='*60}")
        print(f"使用设备: {self.device}")
        print(f"模型名称: {self.model_name}")
        print(f"模型路径: {model_path}")
        print(f"类别数: {num_classes}")
        print(f"最大序列长度: {max_length}")

        # Load tokenizer and model architecture (downloads/caches the backbone).
        print("\n加载模型和分词器...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=num_classes
        ).to(self.device)

        # Load fine-tuned weights. Fail fast (before any inference) if missing.
        if os.path.exists(model_path):
            print(f"加载模型权重: {model_path}")
            # SECURITY NOTE(review): torch.load deserializes via pickle; only
            # load checkpoints from trusted sources (or pass weights_only=True
            # on torch >= 1.13).
            self.model.load_state_dict(torch.load(model_path, map_location=self.device))
            print("✓ 模型权重加载成功")
        else:
            raise FileNotFoundError(f"模型文件不存在: {model_path}")

        # Inference only: disable dropout / use running batch-norm stats.
        self.model.eval()

    def load_test_data(self, test_csv_file):
        """Load the test set from a CSV file with 'record' and 'label' columns.

        Args:
            test_csv_file: path to the test CSV file

        Raises:
            ValueError: if a required column is missing.
        """
        print(f"\n加载测试数据: {test_csv_file}")
        self.test_data = pd.read_csv(test_csv_file)
        # Drop rows with any NaN so the tokenizer never sees missing text.
        self.test_data = self.test_data.dropna()
        print(f"测试样本数: {len(self.test_data)}")
        print(f"列名: {self.test_data.columns.tolist()}")

        # Validate the schema before spending time on inference.
        if 'record' not in self.test_data.columns:
            raise ValueError("数据文件中缺少 'record' 列")
        if 'label' not in self.test_data.columns:
            raise ValueError("数据文件中缺少 'label' 列")

        # Report class balance so skewed metrics can be interpreted.
        label_counts = self.test_data['label'].value_counts().sort_index()
        print(f"\n标签分布:")
        for label, count in label_counts.items():
            print(f" 类别 {label}: {count} ({count/len(self.test_data)*100:.1f}%)")

    def predict(self, batch_size=32):
        """Run batched inference over the loaded test set.

        Populates self.predictions, self.labels (numpy arrays) and
        self.texts (list of the raw input strings).

        Args:
            batch_size: number of samples per forward pass
        """
        print(f"\n开始预测 (batch_size={batch_size})...")
        self.predictions = []
        self.labels = []
        self.texts = []

        with torch.no_grad():
            for i in tqdm(range(0, len(self.test_data), batch_size), desc="预测中"):
                batch_texts = self.test_data['record'].iloc[i:i+batch_size].tolist()
                batch_labels = self.test_data['label'].iloc[i:i+batch_size].tolist()

                # Tokenize. padding=True pads only to the longest sequence in
                # the batch (instead of always max_length) — faster, and the
                # logits are identical because pad tokens are attention-masked.
                inputs = self.tokenizer(
                    batch_texts,
                    max_length=self.max_length,
                    padding=True,
                    truncation=True,
                    return_tensors='pt'
                )

                # Move every tensor (input_ids, attention_mask, ...) to the device.
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

                # Forward pass; argmax over logits gives the predicted class ids.
                outputs = self.model(**inputs)
                preds = torch.argmax(outputs.logits, dim=-1).cpu().numpy()

                self.predictions.extend(preds)
                self.labels.extend(batch_labels)
                self.texts.extend(batch_texts)

        self.predictions = np.array(self.predictions)
        self.labels = np.array(self.labels)
        print(f"✓ 预测完成,共 {len(self.predictions)} 个样本")

    def evaluate(self):
        """Compute and print evaluation metrics.

        Returns:
            dict with accuracy plus macro/weighted precision, recall and F1.
        """
        print(f"\n{'='*60}")
        print("评估结果")
        print(f"{'='*60}")

        # Overall metrics (macro = unweighted class mean; weighted = by support).
        accuracy = accuracy_score(self.labels, self.predictions)
        precision_macro = precision_score(self.labels, self.predictions, average='macro')
        recall_macro = recall_score(self.labels, self.predictions, average='macro')
        f1_macro = f1_score(self.labels, self.predictions, average='macro')
        precision_weighted = precision_score(self.labels, self.predictions, average='weighted')
        recall_weighted = recall_score(self.labels, self.predictions, average='weighted')
        f1_weighted = f1_score(self.labels, self.predictions, average='weighted')

        print("\n【总体指标】")
        print(f" 准确率 (Accuracy): {accuracy:.4f} ({accuracy*100:.2f}%)")
        print(f"\n Macro平均:")
        print(f" 精确率 (Precision): {precision_macro:.4f}")
        print(f" 召回率 (Recall): {recall_macro:.4f}")
        print(f" F1分数: {f1_macro:.4f}")
        print(f"\n Weighted平均:")
        print(f" 精确率 (Precision): {precision_weighted:.4f}")
        print(f" 召回率 (Recall): {recall_weighted:.4f}")
        print(f" F1分数: {f1_weighted:.4f}")

        # Per-class breakdown.
        print("\n【各类别详细指标】")
        print(classification_report(self.labels, self.predictions, digits=4))

        # Confusion matrix (rows = true labels, columns = predictions).
        print("\n【混淆矩阵】")
        cm = confusion_matrix(self.labels, self.predictions)
        print(cm)

        self._plot_confusion_matrix(cm)

        return {
            'accuracy': accuracy,
            'precision_macro': precision_macro,
            'recall_macro': recall_macro,
            'f1_macro': f1_macro,
            'precision_weighted': precision_weighted,
            'recall_weighted': recall_weighted,
            'f1_weighted': f1_weighted
        }

    def _plot_confusion_matrix(self, cm):
        """Render the confusion matrix as a heatmap and save it to a PNG."""
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=True)
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')

        # Language-specific filename so en/zh runs don't overwrite each other.
        if self.is_english:
            output_path = './confusion_matrix_en.png'
        else:
            output_path = './confusion_matrix_zh.png'
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"\n✓ 混淆矩阵已保存至: {output_path}")
        # Close the figure to free memory across repeated runs.
        plt.close()

    def save_errors(self, output_csv_file='errors.csv'):
        """Write the misclassified samples to a CSV file.

        Args:
            output_csv_file: destination CSV path
        """
        print(f"\n保存错误预测样本...")

        # Indices where prediction and ground truth disagree.
        error_indices = np.where(self.predictions != self.labels)[0]
        if len(error_indices) == 0:
            print("✓ 没有预测错误的样本")
            return

        errors_df = pd.DataFrame({
            'record': [self.texts[i] for i in error_indices],
            'predict': self.predictions[error_indices],
            'label': self.labels[error_indices]
        })

        # utf-8-sig BOM so Excel opens the CJK text correctly.
        errors_df.to_csv(output_csv_file, index=False, encoding='utf-8-sig')
        print(f"✓ 错误样本已保存至: {output_csv_file}")
        print(f" 错误样本数: {len(errors_df)} / {len(self.labels)} ({len(errors_df)/len(self.labels)*100:.2f}%)")

        # Break errors down by (true label -> predicted label) pairs.
        print(f"\n【错误类型统计】")
        error_types = errors_df.groupby(['label', 'predict']).size().reset_index(name='count')
        for _, row in error_types.iterrows():
            print(f" 真实标签 {row['label']} → 预测为 {row['predict']}: {row['count']} 个")

    def run_test(self, test_csv_file, output_csv_file='errors.csv', batch_size=32):
        """Run the full pipeline: load data, predict, evaluate, save errors.

        Args:
            test_csv_file: path to the test CSV
            output_csv_file: where to save misclassified samples
            batch_size: inference batch size

        Returns:
            dict of metrics from evaluate().
        """
        self.load_test_data(test_csv_file)
        self.predict(batch_size)
        metrics = self.evaluate()
        self.save_errors(output_csv_file)

        print(f"\n{'='*60}")
        print("测试完成!")
        print(f"{'='*60}")
        return metrics


if __name__ == '__main__':
    # Chinese model first.
    tester = TextClassificationTester(
        model_path='./checkpoints/bert-base-chinese_best_acc.pth',
        is_english=False,
        num_classes=2,
        max_length=256
    )
    metrics = tester.run_test(
        test_csv_file='data_zh.csv',
        output_csv_file='errors_zh.csv',
        batch_size=32
    )

    # Release the first model before loading the second to avoid holding
    # two BERTs in GPU memory at once.
    del tester, metrics
    gc.collect()
    torch.cuda.empty_cache()
    time.sleep(2)  # brief pause to let the driver settle

    # English model second.
    tester = TextClassificationTester(
        model_path='./checkpoints/bert-base-uncased_best_acc.pth',
        is_english=True,
        num_classes=2,
        max_length=256
    )
    metrics = tester.run_test(
        test_csv_file='data_en.csv',
        output_csv_file='errors_en.csv',
        batch_size=32
    )