jiyuhang 3 weeks ago
Commit
ada369d2b5
5 changed files with 934 additions and 0 deletions
  1. .gitignore (+6 -0)
  2. export_records.py (+236 -0)
  3. models_download.py (+68 -0)
  4. test.py (+304 -0)
  5. train.py (+320 -0)

+ 6 - 0
.gitignore

@@ -0,0 +1,6 @@
+models/
+*.csv
+*.pth
+*.db
+*.png
+.idea/

+ 236 - 0
export_records.py

@@ -0,0 +1,236 @@
+"""
+从 chat_history.db 数据库读取 intent_recognition_records 表,
+按语言(中文/英文)分别导出为 CSV 文件
+支持翻译功能:将中文记录翻译为英文,英文记录翻译为中文
+"""
+
+import sqlite3
+import csv
+import re
+import translators as ts
+import pandas as pd
+
+def translate(text: str, from_language: str, to_language: str, translator: str = 'youdao', timeout: int = 10) -> str:
+    """
+    Translate text from one language to another.
+
+    Args:
+    text: the text to translate
+    from_language: source language code (e.g. 'zh', 'en')
+    to_language: target language code
+    translator: translation engine ('google', 'bing', 'baidu', 'youdao', etc.)
+    timeout: timeout in seconds
+
+    Returns:
+    The translated text; the original text is returned on failure.
+    """
+    if not text or not isinstance(text, str):
+        return text
+
+    try:
+        # Clean up the text
+        text = text.strip()
+        if not text:
+            return ""
+
+        # Translate using the translators library
+        translated_text = ts.translate_text(
+            text,
+            translator=translator,
+            from_language=from_language,
+            to_language=to_language,
+            timeout=timeout
+        )
+
+        # If None comes back, fall back to the original text
+        if translated_text is None:
+            return text
+
+        # Handle a possible dict return value (some engines return a dict)
+        if isinstance(translated_text, dict):
+            # Try to extract the translated text from the dict
+            if 'text' in translated_text:
+                translated_text = translated_text['text']
+            elif 'translated' in translated_text:
+                translated_text = translated_text['translated']
+            elif 'result' in translated_text:
+                translated_text = translated_text['result']
+            else:
+                print(f"  [警告] 未知的字典格式: {translated_text}")
+                return text
+
+        # Make sure we return a string
+        if isinstance(translated_text, str):
+            return translated_text.strip()
+        else:
+            print(f"  [警告] 非字符串返回: type={type(translated_text)}")
+            return text
+
+    except Exception as e:
+        print(f"翻译错误: {text} -> {str(e)}")
+        # 出错时返回原文本
+        return text
+
+def is_chinese(text):
+    """判断文本是否包含中文字符"""
+    if not text:
+        return False
+    # Unicode range for CJK unified ideographs
+    chinese_pattern = re.compile(r'[\u4e00-\u9fff]')
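+    # e.g. is_chinese("hello") -> False; is_chinese("你好 hello") -> True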
+    return bool(chinese_pattern.search(text))
+
+
+def translate_to_chinese(text):
+    """将文本翻译为中文"""
+    if not text:
+        return text
+    try:
+        result = translate(text, from_language='en', to_language='zh')
+        return result if result else text
+    except Exception as e:
+        print(f"翻译失败 (->中文): {text[:30]}... 错误: {e}")
+        return text
+
+
+def translate_to_english(text):
+    """将文本翻译为英文"""
+    if not text:
+        return text
+    try:
+        result = translate(text, from_language='zh', to_language='en')
+        return result if result else text
+    except Exception as e:
+        print(f"翻译失败 (->英文): {text[:30]}... 错误: {e}")
+        return text
+
+
+def export_records():
+    # Connect to the database
+    conn = sqlite3.connect('chat_history.db')
+    cursor = conn.cursor()
+    
+    # Read all rows
+    cursor.execute("SELECT id, record, time, label FROM intent_recognition_records")
+    all_records = cursor.fetchall()
+    
+    # Split the records by language
+    chinese_records = []
+    english_records = []
+    
+    for record in all_records:
+        record_text = record[1]  # the record field
+        if is_chinese(record_text):
+            chinese_records.append(record)
+        else:
+            english_records.append(record)
+    
+    # Translate the Chinese records into English
+    print("Translating Chinese records into English...")
+    chinese_records_en = []
+    for i, record in enumerate(chinese_records):
+        id_, text, time, label = record
+        translated_text = translate_to_english(text)  # Chinese -> English
+        chinese_records_en.append((id_, translated_text, time, label))
+        if (i + 1) % 100 == 0:
+            print(f"  已翻译 {i + 1}/{len(chinese_records)} 条中文记录")
+    
+    # Translate the English records into Chinese
+    print("Translating English records into Chinese...")
+    english_records_zh = []
+    for i, record in enumerate(english_records):
+        id_, text, time, label = record
+        translated_text = translate_to_chinese(text)
+        english_records_zh.append((id_, translated_text, time, label))
+        if (i + 1) % 100 == 0:
+            print(f"  已翻译 {i + 1}/{len(english_records)} 条英文记录")
+    
+    # Merge the records
+    # Chinese CSV: chinese_records + english_records_zh (all content in Chinese)
+    all_chinese_records = chinese_records + english_records_zh
+    print(f'{len(all_chinese_records)} Chinese records')
+    # English CSV: english_records + chinese_records_en (all content in English)
+    all_english_records = english_records + chinese_records_en
+    print(f'{len(all_english_records)} English records')
+    
+    # Export the Chinese CSV file
+    with open('intent_records_chinese.csv', 'w', encoding='utf-8-sig', newline='') as f:
+        writer = csv.writer(f)
+        writer.writerow(['id', 'record', 'time', 'label'])
+        writer.writerows(all_chinese_records)
+    
+    # Export the English CSV file
+    with open('intent_records_english.csv', 'w', encoding='utf-8-sig', newline='') as f:
+        writer = csv.writer(f)
+        writer.writerow(['id', 'record', 'time', 'label'])
+        writer.writerows(all_english_records)
+    
+    conn.close()
+    
+    print(f"\n导出完成!")
+    print(f"中文CSV文件 (intent_records_chinese.csv):")
+    print(f"  - 原中文记录: {len(chinese_records)} 条")
+    print(f"  - 英文翻译为中文: {len(english_records_zh)} 条")
+    print(f"  - 总计: {len(all_chinese_records)} 条")
+    print(f"英文CSV文件 (intent_records_english.csv):")
+    print(f"  - 原英文记录: {len(english_records)} 条")
+    print(f"  - 中文翻译为英文: {len(chinese_records_en)} 条")
+    print(f"  - 总计: {len(all_english_records)} 条")
+
+def clean_csv(file_path, output_path):
+    """
+    将label列全部减去1,抛弃负数仅保留0和1的行,仅保留record列和label列
+    :param file_path: 输入的CSV文件路径
+    :return: 处理后的DataFrame
+    """
+    print(f"正在整理CSV文件中的数据: {file_path}")
+    # 使用pandas读取
+    df = pd.read_csv(file_path, encoding='utf-8')
+
+    # Shift every label value down by 1
+    df['label'] = df['label'] - 1
+
+    # Drop all rows with label < 0 (invalid annotations), keeping only labels 0 and 1
+    df = df[df['label'] >= 0]
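+    # Worked example: raw labels {0, 1, 2} map to {-1, 0, 1}; the rows that were
+    # labeled 0 (now -1) are the invalid annotations discarded above.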
+
+    # Keep only the record and label columns
+    df = df[['record', 'label']]
+
+    # Save to the output path
+    df.to_csv(output_path, index=False, encoding='utf-8-sig')
+
+    print(f"处理完成,共保留 {len(df)} 条记录,保存至{output_path}")
+    return df
+
+def print_statistics(file_path):
+    """
+    打印CSV文件中label的统计信息
+    :param file_path: CSV文件路径
+    """
+    print(f"正在打印CSV文件统计数据: {file_path}")
+    df = pd.read_csv(file_path, encoding='utf-8')
+
+    # Total record count
+    total_count = len(df)
+    print(f"Total records: {total_count}")
+
+    # Label distribution
+    label_counts = df['label'].value_counts().sort_index()
+    print("\nLabel distribution:")
+    for label, count in label_counts.items():
+        percentage = count / total_count * 100
+        print(f"  Label {label}: {count} 条 ({percentage:.2f}%)")
+
+    # Missing values
+    missing_count = df['label'].isna().sum()
+    if missing_count > 0:
+        print(f"\nMissing values: {missing_count}")
+
+    print("-" * 40)
+if __name__ == '__main__':
+    # 第一步,从数据库导出所有记录
+    # export_records()
+    # 第二步,整理数据
+    # clean_csv('intent_records_chinese.csv','data_zh.csv')
+    # clean_csv('intent_records_english.csv','data_en.csv')
+    # 打印统计数据
+    print_statistics('data_en.csv')
+    print_statistics('data_zh.csv')
+

+ 68 - 0
models_download.py

@@ -0,0 +1,68 @@
+# models_download.py
+from huggingface_hub import snapshot_download
+import os
+import shutil  # used to delete non-empty directories
+import time
+
+
+def download_model(model_name, local_dir):
+    """
+    下载Hugging Face模型到本地目录
+
+    Args:
+        model_name: Hugging Face模型名称
+        local_dir: 本地保存目录
+    """
+    print(f"开始下载模型: {model_name}")
+    print(f"保存到: {local_dir}")
+
+    try:
+        # Start from a clean directory
+        if os.path.exists(local_dir):
+            # shutil.rmtree removes the directory and everything inside it
+            shutil.rmtree(local_dir)
+            print(f"Directory {local_dir} removed")
+            time.sleep(5)
+        os.makedirs(local_dir, exist_ok=True)
+
+        # Download the model
+        snapshot_download(
+            repo_id=model_name,
+            local_dir=local_dir,
+            local_dir_use_symlinks=False,  # do not use symlinks
+            ignore_patterns=["*.safetensors"],  # optional: skip safetensors files
+        )
+
+        print(f"✓ 模型 {model_name} 下载完成")
+        print(f"保存路径: {os.path.abspath(local_dir)}")
+
+    except Exception as e:
+        print(f"✗ 下载失败: {e}")
+
+
+if __name__ == "__main__":
+    # Models to download
+    models_to_download = [
+        ("bert-base-uncased", "./models/bert-base-uncased"),
+        ("bert-base-chinese", "./models/bert-base-chinese"),
+    ]
+
+    print("开始下载BERT模型...")
+    print("=" * 50)
+
+    for model_name, local_dir in models_to_download:
+        download_model(model_name, local_dir)
+        print()
+
+    print("=" * 50)
+    print("所有模型下载完成!")
+    print("\n模型文件结构:")
+    for model_name, local_dir in models_to_download:
+        if os.path.exists(local_dir):
+            file_count = len([f for f in os.listdir(local_dir) if os.path.isfile(os.path.join(local_dir, f))])
+            print(f"  {local_dir}: {file_count} 个文件")
+
+    print("\n使用示例:")
+    print("本地模型路径:")
+    print("  - 英文: './models/bert-base-uncased'")
+    print("  - 中文: './models/bert-base-chinese'")

+ 304 - 0
test.py

@@ -0,0 +1,304 @@
+# BERT text classification test script
+import torch
+import pandas as pd
+import numpy as np
+from torch.utils.data import Dataset, DataLoader
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+from tqdm import tqdm
+from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
+import os
+import seaborn as sns
+import matplotlib.pyplot as plt
+import gc
+import time
+
+
+class TextClassificationTester:
+    """BERT文本分类测试器"""
+    
+    def __init__(self, model_path, is_english=False, num_classes=2, max_length=256):
+        """
+        初始化测试器
+        
+        Args:
+            model_path: 模型权重文件路径
+            is_english: 是否为英文模型
+            num_classes: 分类类别数
+            max_length: 最大序列长度
+        """
+        # Choose the base model
+        if is_english:
+            self.model_name = "bert-base-uncased"
+        else:
+            self.model_name = "bert-base-chinese"
+        self.is_english = is_english
+        
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.max_length = max_length
+        self.num_classes = num_classes
+        
+        print(f"\n{'='*60}")
+        print("初始化测试器")
+        print(f"{'='*60}")
+        print(f"使用设备: {self.device}")
+        print(f"模型名称: {self.model_name}")
+        print(f"模型路径: {model_path}")
+        print(f"类别数: {num_classes}")
+        print(f"最大序列长度: {max_length}")
+        
+        # Load the model and tokenizer
+        print("\nLoading model and tokenizer...")
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        self.model = AutoModelForSequenceClassification.from_pretrained(
+            self.model_name, 
+            num_labels=num_classes
+        ).to(self.device)
+        
+        # Load the trained weights
+        if os.path.exists(model_path):
+            print(f"Loading model weights: {model_path}")
+            self.model.load_state_dict(torch.load(model_path, map_location=self.device))
+            print("✓ Model weights loaded")
+        else:
+            raise FileNotFoundError(f"模型文件不存在: {model_path}")
+        
+        # Switch to evaluation mode
+        self.model.eval()
+        
+    def load_test_data(self, test_csv_file):
+        """
+        加载测试数据
+        
+        Args:
+            test_csv_file: 测试数据CSV文件路径
+        """
+        print(f"\n加载测试数据: {test_csv_file}")
+        self.test_data = pd.read_csv(test_csv_file)
+        self.test_data = self.test_data.dropna()
+        
+        print(f"测试样本数: {len(self.test_data)}")
+        print(f"列名: {self.test_data.columns.tolist()}")
+        
+        # Check required columns
+        if 'record' not in self.test_data.columns:
+            raise ValueError("Data file is missing the 'record' column")
+        if 'label' not in self.test_data.columns:
+            raise ValueError("Data file is missing the 'label' column")
+        
+        # Label distribution
+        label_counts = self.test_data['label'].value_counts().sort_index()
+        print(f"\nLabel distribution:")
+        for label, count in label_counts.items():
+            print(f"  class {label}: {count} ({count/len(self.test_data)*100:.1f}%)")
+    
+    def predict(self, batch_size=32):
+        """
+        对测试数据进行预测
+        
+        Args:
+            batch_size: 批次大小
+        """
+        print(f"\n开始预测 (batch_size={batch_size})...")
+        
+        self.predictions = []
+        self.labels = []
+        self.texts = []
+        
+        with torch.no_grad():
+            for i in tqdm(range(0, len(self.test_data), batch_size), desc="预测中"):
+                batch_texts = self.test_data['record'].iloc[i:i+batch_size].tolist()
+                batch_labels = self.test_data['label'].iloc[i:i+batch_size].tolist()
+                
+                # Tokenize
+                inputs = self.tokenizer(
+                    batch_texts,
+                    max_length=self.max_length,
+                    padding='max_length',
+                    truncation=True,
+                    return_tensors='pt'
+                )
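+                # inputs is a dict of tensors (input_ids, attention_mask and, for
+                # BERT, token_type_ids), each of shape (batch_size, max_length)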
+                
+                # Move to the device
+                inputs = {k: v.to(self.device) for k, v in inputs.items()}
+                
+                # Predict
+                outputs = self.model(**inputs)
+                preds = torch.argmax(outputs.logits, dim=-1).cpu().numpy()
+                
+                self.predictions.extend(preds)
+                self.labels.extend(batch_labels)
+                self.texts.extend(batch_texts)
+        
+        self.predictions = np.array(self.predictions)
+        self.labels = np.array(self.labels)
+        
+        print(f"✓ 预测完成,共 {len(self.predictions)} 个样本")
+    
+    def evaluate(self):
+        """
+        评估模型性能并输出详细指标
+        """
+        print(f"\n{'='*60}")
+        print("评估结果")
+        print(f"{'='*60}")
+        
+        # 计算各项指标
+        accuracy = accuracy_score(self.labels, self.predictions)
+        precision_macro = precision_score(self.labels, self.predictions, average='macro')
+        recall_macro = recall_score(self.labels, self.predictions, average='macro')
+        f1_macro = f1_score(self.labels, self.predictions, average='macro')
+        
+        precision_weighted = precision_score(self.labels, self.predictions, average='weighted')
+        recall_weighted = recall_score(self.labels, self.predictions, average='weighted')
+        f1_weighted = f1_score(self.labels, self.predictions, average='weighted')
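+        # Note: 'macro' averages each class's score with equal weight, while
+        # 'weighted' weights each class by its support (sample count), so the
+        # two diverge when the test set is imbalanced.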
+        
+        # Overall metrics
+        print("\n[Overall metrics]")
+        print(f"  Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
+        print(f"\n  Macro average:")
+        print(f"    Precision: {precision_macro:.4f}")
+        print(f"    Recall: {recall_macro:.4f}")
+        print(f"    F1 score: {f1_macro:.4f}")
+        print(f"\n  Weighted average:")
+        print(f"    Precision: {precision_weighted:.4f}")
+        print(f"    Recall: {recall_weighted:.4f}")
+        print(f"    F1 score: {f1_weighted:.4f}")
+        
+        # Per-class metrics
+        print("\n[Per-class metrics]")
+        print(classification_report(self.labels, self.predictions, digits=4))
+        
+        # Confusion matrix
+        print("\n[Confusion matrix]")
+        cm = confusion_matrix(self.labels, self.predictions)
+        print(cm)
+        
+        # Plot the confusion matrix heatmap
+        self._plot_confusion_matrix(cm)
+        
+        return {
+            'accuracy': accuracy,
+            'precision_macro': precision_macro,
+            'recall_macro': recall_macro,
+            'f1_macro': f1_macro,
+            'precision_weighted': precision_weighted,
+            'recall_weighted': recall_weighted,
+            'f1_weighted': f1_weighted
+        }
+    
+    def _plot_confusion_matrix(self, cm):
+        """
+        绘制混淆矩阵热力图
+        """
+        plt.figure(figsize=(8, 6))
+        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=True)
+        plt.title('Confusion Matrix')
+        plt.ylabel('True Label')
+        plt.xlabel('Predicted Label')
+        
+        # Save the figure
+        if self.is_english:
+            output_path = './confusion_matrix_en.png'
+        else:
+            output_path = './confusion_matrix_zh.png'
+        plt.savefig(output_path, dpi=300, bbox_inches='tight')
+        print(f"\n✓ 混淆矩阵已保存至: {output_path}")
+        plt.close()
+    
+    def save_errors(self, output_csv_file='errors.csv'):
+        """
+        保存预测错误的样本到CSV文件
+        
+        Args:
+            output_csv_file: 输出CSV文件路径
+        """
+        print(f"\n保存错误预测样本...")
+        
+        # Find the misclassified samples
+        error_indices = np.where(self.predictions != self.labels)[0]
+        
+        if len(error_indices) == 0:
+            print("✓ 没有预测错误的样本")
+            return
+        
+        # Build a DataFrame of the errors
+        errors_df = pd.DataFrame({
+            'record': [self.texts[i] for i in error_indices],
+            'predict': self.predictions[error_indices],
+            'label': self.labels[error_indices]
+        })
+        
+        # Save to CSV
+        errors_df.to_csv(output_csv_file, index=False, encoding='utf-8-sig')
+        
+        print(f"✓ 错误样本已保存至: {output_csv_file}")
+        print(f"  错误样本数: {len(errors_df)} / {len(self.labels)} ({len(errors_df)/len(self.labels)*100:.2f}%)")
+        
+        # Error type breakdown
+        print(f"\n[Error type breakdown]")
+        error_types = errors_df.groupby(['label', 'predict']).size().reset_index(name='count')
+        for _, row in error_types.iterrows():
+            print(f"  true label {row['label']} → predicted {row['predict']}: {row['count']}")
+    
+    def run_test(self, test_csv_file, output_csv_file='errors.csv', batch_size=32):
+        """
+        运行完整的测试流程
+        
+        Args:
+            test_csv_file: 测试数据文件路径
+            output_csv_file: 错误样本输出文件路径
+            batch_size: 批次大小
+        """
+        # 加载测试数据
+        self.load_test_data(test_csv_file)
+        
+        # Predict
+        self.predict(batch_size)
+        
+        # Evaluate
+        metrics = self.evaluate()
+        
+        # Save the misclassified samples
+        self.save_errors(output_csv_file)
+        
+        print(f"\n{'='*60}")
+        print("测试完成!")
+        print(f"{'='*60}")
+        
+        return metrics
+
+
+if __name__ == '__main__':
+    # Example usage
+    tester = TextClassificationTester(
+        model_path='./checkpoints/bert-base-chinese_best_acc.pth',  # model path
+        is_english=False,  # Chinese model
+        num_classes=2,     # binary classification
+        max_length=256     # max sequence length
+    )
+    
+    # Run the test
+    metrics = tester.run_test(
+        test_csv_file='data_zh.csv',  # test data file
+        output_csv_file='errors_zh.csv',  # output file for misclassified samples
+        batch_size=32                  # batch size
+    )
+
+    del tester, metrics
+    gc.collect()  # force Python garbage collection
+    torch.cuda.empty_cache()  # release cached CUDA memory
+    time.sleep(2)  # a short pause is enough
+
+    tester = TextClassificationTester(
+        model_path='./checkpoints/bert-base-uncased_best_acc.pth',  # model path
+        is_english=True,  # English model
+        num_classes=2,  # binary classification
+        max_length=256  # max sequence length
+    )
+
+    # Run the test
+    metrics = tester.run_test(
+        test_csv_file='data_en.csv',  # test data file
+        output_csv_file='errors_en.csv',  # output file for misclassified samples
+        batch_size=32  # batch size
+    )

+ 320 - 0
train.py

@@ -0,0 +1,320 @@
+# Simplified version using standard Hugging Face methods
+import torch
+import pandas as pd
+from torch.utils.data import random_split
+from torch.utils.data import Dataset
+from torch.utils.data import DataLoader
+from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_linear_schedule_with_warmup
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from tqdm import tqdm
+import os
+from torch.utils.data import WeightedRandomSampler
+import numpy as np
+from collections import Counter
+import time
+import gc
+
+class MyDataset(Dataset):
+    def __init__(self, file_path):
+        self.file_path = file_path
+        self.data = pd.read_csv(file_path)
+        self.data = self.data.dropna()
+        # Print basic dataset info
+        print(f"\nDataset info:")
+        print(f"  total samples: {len(self.data)}")
+        print(f"  columns: {self.data.columns.tolist()}")
+
+        # Text length distribution
+        if 'record' in self.data.columns:
+            text_lengths = self.data['record'].astype(str).apply(len)
+            print(f"  text length stats:")
+            print(f"    min: {text_lengths.min()}")
+            print(f"    max: {text_lengths.max()}")
+            print(f"    mean: {text_lengths.mean():.1f}")
+            print(f"    median: {text_lengths.median():.1f}")
+            print(f"    95th percentile: {text_lengths.quantile(0.95):.1f}")
+
+        # Label distribution
+        if 'label' in self.data.columns:
+            label_counts = Counter(self.data['label'])
+            print(f"  label distribution:")
+            for label, count in sorted(label_counts.items()):
+                print(f"    class {label}: {count} ({count / len(self.data) * 100:.1f}%)")
+
+    def __getitem__(self, item):
+        return self.data.iloc[item]['record'], self.data.iloc[item]['label']
+
+    def __len__(self):
+        return len(self.data)
+
+
+class EarlyStopping:
+    """Early stopping机制"""
+
+    def __init__(self, patience:int=3, min_delta:float=0, mode:str='min'):
+        self.patience = patience
+        self.min_delta = min_delta
+        self.mode = mode
+        self.counter = 0
+        self.best_score = None
+        self.early_stop = False
+
+    def __call__(self, score):
+        if self.best_score is None:
+            self.best_score = score
+            return False
+
+        if self.mode == 'min':
+            if score < self.best_score - self.min_delta:
+                self.best_score = score
+                self.counter = 0
+            else:
+                self.counter += 1
+        else:  # mode == 'max'
+            if score > self.best_score + self.min_delta:
+                self.best_score = score
+                self.counter = 0
+            else:
+                self.counter += 1
+
+        if self.counter >= self.patience:
+            self.early_stop = True
+            return True
+        return False
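+# Usage sketch for EarlyStopping: create stopper = EarlyStopping(patience=3, mode='min')
+# and call stopper(val_loss) once per epoch; it returns True (and sets early_stop)
+# once the score has failed to improve by more than min_delta for `patience`
+# consecutive epochs.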
+
+
+class SimpleTrainer:
+    def __init__(self, train_csv_file,
+                 batch_size:int=32,
+                 val_csv_file:str=None,
+                 is_english:bool=False,
+                 num_classes:int=2,
+                 max_length:int=256):
+        # Use a standard BERT model
+        if is_english:
+            self.model_name = "bert-base-uncased"
+        else:
+            self.model_name = "bert-base-chinese"
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        print(f"使用设备: {self.device}")
+        print(f"使用模型: {self.model_name}")
+
+        self.batch_size = batch_size
+        self.max_length = max_length
+        self.__global_step = 0
+        self.best_val_acc = 0.0
+        self.best_val_loss = float('inf')
+
+        # Create the checkpoint directory
+        self.checkpoint_root_path = './checkpoints'
+        if not os.path.exists(self.checkpoint_root_path):
+            os.makedirs(self.checkpoint_root_path)
+
+        self.best_acc_model_path = os.path.join(self.checkpoint_root_path , f'{self.model_name}_best_acc.pth')
+        self.best_loss_model_path = os.path.join(self.checkpoint_root_path , f'{self.model_name}_best_loss.pth')
+        # For resuming training from a checkpoint
+        self.latest_checkpoint_path = os.path.join(self.checkpoint_root_path , f'{self.model_name}_latest_checkpoint.pth')
+
+        # Create the model and tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        self.model = AutoModelForSequenceClassification.from_pretrained(
+            self.model_name,
+            num_labels=num_classes
+        ).to(self.device)
+
+        # Print parameter counts
+        total_params = sum(p.numel() for p in self.model.parameters())
+        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+        print(f"模型总参数量: {total_params:,}")
+        print(f"可训练参数量: {trainable_params:,}")
+
+        # Create the dataset
+        print("\nLoading dataset...")
+        self.train_dataset = MyDataset(train_csv_file)
+        # Compute per-sample weights (to handle class imbalance)
+        labels = self.train_dataset.data['label'].values
+        class_counts = np.bincount(labels)
+        sample_weights = 1.0 / class_counts[labels]  # each sample's sampling weight is the inverse of its class frequency
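+        # Worked example: with class_counts = [900, 100], a class-0 sample gets
+        # weight 1/900 and a class-1 sample 1/100, so each minority sample is
+        # drawn roughly 9x as often by the sampler below.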
+        sampler = WeightedRandomSampler(
+            weights=sample_weights,
+            num_samples=len(sample_weights),
+            replacement=True
+        )
+        if val_csv_file:
+            self.val_dataset = MyDataset(val_csv_file)
+        else:
+            # No separate validation set, so split the training data 80/20
+            self.train_dataset, self.val_dataset = random_split(self.train_dataset, [0.8, 0.2])
+            # Pick out the weights for the training subset
+            train_indices = self.train_dataset.indices
+            train_sample_weights = sample_weights[train_indices]
+            sampler = WeightedRandomSampler(
+                weights=train_sample_weights,
+                num_samples=len(train_sample_weights),
+                replacement=True
+            )
+        # Create the data loaders
+        self.train_loader = DataLoader(self.train_dataset,
+                                       batch_size=self.batch_size,
+                                       sampler=sampler,
+                                       collate_fn=self.collate_func)
+        self.val_loader = DataLoader(self.val_dataset,
+                                     batch_size=self.batch_size,
+                                     shuffle=False,
+                                     collate_fn=self.collate_func)
+
+        # Create the optimizer and scheduler
+        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=3e-5, weight_decay=0.01)
+        self.scheduler = None
+        self.warmup_ratio = 0.1  # fraction of total steps used for warmup
+
+        # Early stopping
+        self.early_stopping = EarlyStopping(patience=5, min_delta=0.001, mode='min')
+        # Gradient clipping threshold
+        self.max_grad_norm = 1.0
+
+    def collate_func(self, batch_data):
+        texts, labels = [], []
+        for item in batch_data:
+            texts.append(item[0])
+            labels.append(item[1])
+        
+        inputs = self.tokenizer(
+            texts, 
+            max_length=self.max_length,
+            padding='max_length', 
+            truncation=True, 
+            return_tensors='pt'
+        )
+        inputs['labels'] = torch.tensor(labels, dtype=torch.long)
+        return inputs
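+    # The dict returned by collate_func is fed to the model as model(**batch_data);
+    # because it contains a 'labels' key, the Hugging Face model computes the
+    # cross-entropy loss itself (available as outputs.loss in train_step below).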
+
+    def train_step(self):
+        self.model.train()
+        total_loss = 0.0
+        correct_predictions = 0.0
+        total_samples = 0.0
+        
+        for batch_data in tqdm(self.train_loader, desc="Training"):
+            batch_data = {k: v.to(self.device) for k, v in batch_data.items()}
+            
+            self.optimizer.zero_grad()
+            outputs = self.model(**batch_data)
+            loss = outputs.loss
+            loss.backward()
+            # Apply gradient clipping
+            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
+
+            self.optimizer.step()
+            self.scheduler.step()  # update the learning rate
+            
+            total_loss += loss.item()
+            predictions = torch.argmax(outputs.logits, dim=-1)
+            correct_predictions += (predictions == batch_data['labels']).float().sum().item()
+            total_samples += len(batch_data['labels'])
+            
+        avg_loss = total_loss / len(self.train_loader)
+        accuracy = correct_predictions / total_samples
+        return avg_loss, accuracy
+
+    def val_step(self):
+        self.model.eval()
+        total_loss = 0.0
+        correct_predictions = 0.0
+        total_samples = 0.0
+        
+        with torch.no_grad():
+            for batch_data in tqdm(self.val_loader, desc="Validating"):
+                batch_data = {k: v.to(self.device) for k, v in batch_data.items()}
+                
+                outputs = self.model(**batch_data)
+                total_loss += outputs.loss.item()
+                
+                predictions = torch.argmax(outputs.logits, dim=-1)
+                correct_predictions += (predictions == batch_data['labels']).float().sum().item()
+                total_samples += len(batch_data['labels'])
+                
+        avg_loss = total_loss / len(self.val_loader)
+        accuracy = correct_predictions / total_samples
+        return avg_loss, accuracy
+
+    def train_and_validate(self, num_epoch):
+        # Create the LR scheduler (linear decay with warmup)
+        total_steps = len(self.train_loader) * num_epoch
+        warmup_steps = int(total_steps * self.warmup_ratio)
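+        # Worked example (illustrative numbers): 20 epochs x 100 batches = 2000
+        # total steps, so with warmup_ratio = 0.1 the LR ramps up over the first
+        # 200 steps and then decays linearly to 0 at step 2000.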
+
+        self.scheduler = get_linear_schedule_with_warmup(
+            self.optimizer,
+            num_warmup_steps=warmup_steps,
+            num_training_steps=total_steps
+        )
+
+        print(f"\n训练配置:")
+        print(f"  总训练步数: {total_steps}")
+        print(f"  Warmup步数: {warmup_steps}")
+        print(f"  初始学习率: {self.optimizer.param_groups[0]['lr']}")
+        print(f"  Batch size: {self.batch_size}")
+        print(f"  Max length: {self.max_length}")
+        print(f"  梯度裁剪: {self.max_grad_norm}")
+
+        print("开始训练...")
+        for epoch in range(num_epoch):
+            print(f"\n{'='*60}")
+            print(f"Epoch {epoch+1}/{num_epoch}")
+            print(f"{'='*60}")
+
+            train_loss, train_acc = self.train_step()
+            val_loss, val_acc = self.val_step()
+
+            # Current learning rate
+            current_lr = self.optimizer.param_groups[0]['lr']
+
+            print(f'Epoch {epoch+1}/{num_epoch}:')
+            print(f'  Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}')
+            print(f'  Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}')
+            print(f'  Learning Rate: {current_lr:.2e}')
+
+            # Save the best model (by validation accuracy)
+            if val_acc > self.best_val_acc:
+                self.best_val_acc = val_acc
+                torch.save(self.model.state_dict(), self.best_acc_model_path)
+                print(f"  ✓ 保存了新的最佳准确率模型,验证准确率: {self.best_val_acc:.4f}")
+
+            # Save the lowest-validation-loss model
+            if val_loss < self.best_val_loss:
+                self.best_val_loss = val_loss
+                torch.save(self.model.state_dict(), self.best_loss_model_path)
+                print(f"  ✓ 保存了新的最低损失模型,验证损失: {self.best_val_loss:.4f}")
+
+            # Early stopping check
+            if self.early_stopping(val_loss):
+                print(f"\n⚠ Early stopping触发! 在epoch {epoch+1}停止训练")
+                print(f"  验证损失已连续{self.early_stopping.patience}个epoch没有改善")
+                break
+
+        print(f"\n{'='*60}")
+        print("训练完成!")
+        print(f"{'='*60}")
+        print(f"最佳验证准确率: {self.best_val_acc:.4f}")
+        print(f"最低验证损失: {self.best_val_loss:.4f}")
+        print(f"模型保存路径:")
+        print(f"  最佳准确率模型: {self.best_acc_model_path}")
+        print(f"  最低损失模型: {self.best_loss_model_path}")
+
+if __name__ == '__main__':
+    trainer_en = SimpleTrainer(train_csv_file='data_en.csv',
+                            batch_size=32,
+                            is_english=True,
+                            max_length=256)  # increased max_length
+    trainer_en.train_and_validate(num_epoch=20)
+
+    del trainer_en
+    gc.collect()  # force Python garbage collection
+    torch.cuda.empty_cache()  # release cached CUDA memory
+    time.sleep(2)  # a short pause is enough
+
+    trainer = SimpleTrainer(train_csv_file='data_zh.csv',
+                            batch_size=32,
+                            is_english=False,
+                            max_length=256)  # increased max_length
+    trainer.train_and_validate(num_epoch=20)