train.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. # 简化版本 - 使用标准的Hugging Face方法
  2. import torch
  3. import pandas as pd
  4. from torch.utils.data import random_split
  5. from torch.utils.data import Dataset
  6. from torch.utils.data import DataLoader
  7. from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_linear_schedule_with_warmup
  8. from torch.optim.lr_scheduler import ReduceLROnPlateau
  9. from tqdm import tqdm
  10. import os
  11. from torch.utils.data import WeightedRandomSampler
  12. import numpy as np
  13. from collections import Counter
  14. import time
  15. import gc
  16. class MyDataset(Dataset):
  17. def __init__(self, file_path):
  18. self.file_path = file_path
  19. self.data = pd.read_csv(file_path)
  20. self.data = self.data.dropna()
  21. # 打印数据集基本信息
  22. print(f"\n数据集信息:")
  23. print(f" 总样本数: {len(self.data)}")
  24. print(f" 列名: {self.data.columns.tolist()}")
  25. # 统计文本长度分布
  26. if 'record' in self.data.columns:
  27. text_lengths = self.data['record'].astype(str).apply(len)
  28. print(f" 文本长度统计:")
  29. print(f" 最小: {text_lengths.min()}")
  30. print(f" 最大: {text_lengths.max()}")
  31. print(f" 平均: {text_lengths.mean():.1f}")
  32. print(f" 中位数: {text_lengths.median():.1f}")
  33. print(f" 95%分位数: {text_lengths.quantile(0.95):.1f}")
  34. # 统计标签分布
  35. if 'label' in self.data.columns:
  36. label_counts = Counter(self.data['label'])
  37. print(f" 标签分布:")
  38. for label, count in sorted(label_counts.items()):
  39. print(f" 类别 {label}: {count} ({count / len(self.data) * 100:.1f}%)")
  40. def __getitem__(self, item):
  41. return self.data.iloc[item]['record'], self.data.iloc[item]['label']
  42. def __len__(self):
  43. return len(self.data)
  44. class EarlyStopping:
  45. """Early stopping机制"""
  46. def __init__(self, patience:int=3, min_delta:float=0, mode:str='min'):
  47. self.patience = patience
  48. self.min_delta = min_delta
  49. self.mode = mode
  50. self.counter = 0
  51. self.best_score = None
  52. self.early_stop = False
  53. def __call__(self, score):
  54. if self.best_score is None:
  55. self.best_score = score
  56. return False
  57. if self.mode == 'min':
  58. if score < self.best_score - self.min_delta:
  59. self.best_score = score
  60. self.counter = 0
  61. else:
  62. self.counter += 1
  63. else: # mode == 'max'
  64. if score > self.best_score + self.min_delta:
  65. self.best_score = score
  66. self.counter = 0
  67. else:
  68. self.counter += 1
  69. if self.counter >= self.patience:
  70. self.early_stop = True
  71. return True
  72. return False
class SimpleTrainer:
    """Fine-tunes a pretrained BERT sequence classifier on a CSV dataset.

    Pipeline: weighted sampling against class imbalance, linear LR schedule
    with warmup, gradient clipping, per-epoch validation, best-accuracy and
    best-loss checkpointing, and early stopping on validation loss.

    Usage: construct, then call train_and_validate(). train_step() must not
    be called on its own — it uses the scheduler that train_and_validate()
    creates.
    """

    def __init__(self, train_csv_file,
                 batch_size: int = 32,
                 val_csv_file: str = None,
                 is_english: bool = False,
                 num_classes: int = 2,
                 max_length: int = 256):
        """Build tokenizer, model, datasets, loaders, and optimizer.

        Args:
            train_csv_file: CSV with 'record' (text) and 'label' columns;
                labels are assumed to be small non-negative ints (np.bincount
                below indexes by label value) — TODO confirm with the data.
            batch_size: mini-batch size for both loaders.
            val_csv_file: optional separate validation CSV; when None, the
                training data is randomly split 80/20.
            is_english: picks bert-base-uncased vs bert-base-chinese.
            num_classes: number of output classes for the classifier head.
            max_length: tokenizer truncation/padding length.
        """
        # Standard pretrained BERT checkpoint, chosen by language.
        if is_english:
            self.model_name = "bert-base-uncased"
        else:
            self.model_name = "bert-base-chinese"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"使用设备: {self.device}")
        print(f"使用模型: {self.model_name}")
        self.batch_size = batch_size
        self.max_length = max_length
        # NOTE(review): never incremented or read anywhere in this class — dead state?
        self.__global_step = 0
        self.best_val_acc = 0.0
        self.best_val_loss = float('inf')
        # Create the checkpoint directory if it does not exist yet.
        self.checkpoint_root_path = './checkpoints'
        if not os.path.exists(self.checkpoint_root_path):
            os.makedirs(self.checkpoint_root_path)
        self.best_acc_model_path = os.path.join(self.checkpoint_root_path , f'{self.model_name}_best_acc.pth')
        self.best_loss_model_path = os.path.join(self.checkpoint_root_path , f'{self.model_name}_best_loss.pth')
        # Resume-from-checkpoint path.
        # NOTE(review): defined but never read or written — resume logic is not implemented.
        self.latest_checkpoint_path = os.path.join(self.checkpoint_root_path , f'{self.model_name}_latest_checkpoint.pth')
        # Create the model and tokenizer (downloads weights on first use).
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=num_classes
        ).to(self.device)
        # Report parameter counts.
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f"模型总参数量: {total_params:,}")
        print(f"可训练参数量: {trainable_params:,}")
        # Build the datasets.
        print("\n加载数据集...")
        self.train_dataset = MyDataset(train_csv_file)
        # Per-sample weights to counter class imbalance: each sample is
        # weighted by the inverse frequency of its class.
        labels = self.train_dataset.data['label'].values
        class_counts = np.bincount(labels)
        sample_weights = 1.0 / class_counts[labels]  # one sampling weight per sample
        sampler = WeightedRandomSampler(
            weights=sample_weights,
            num_samples=len(sample_weights),
            replacement=True
        )
        if val_csv_file:
            self.val_dataset = MyDataset(val_csv_file)
        else:
            # No separate validation file: split the training data 80/20.
            self.train_dataset, self.val_dataset = random_split(self.train_dataset, [0.8, 0.2])
            # Rebuild the sampler from the weights of the training subset only.
            train_indices = self.train_dataset.indices
            train_sample_weights = sample_weights[train_indices]
            sampler = WeightedRandomSampler(
                weights=train_sample_weights,
                num_samples=len(train_sample_weights),
                replacement=True
            )
        # Data loaders; the weighted sampler replaces shuffle on the train side.
        self.train_loader = DataLoader(self.train_dataset,
                                       batch_size=self.batch_size,
                                       sampler=sampler,
                                       collate_fn=self.collate_func)
        self.val_loader = DataLoader(self.val_dataset,
                                     batch_size=self.batch_size,
                                     shuffle=False,
                                     collate_fn=self.collate_func)
        # Optimizer; the LR scheduler is created later in train_and_validate()
        # because it needs the total step count (epochs * batches).
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=3e-5, weight_decay=0.01)
        self.scheduler = None
        self.warmup_ratio = 0.1  # fraction of total steps used for warmup
        # Early stopping on validation loss.
        self.early_stopping = EarlyStopping(patience=5, min_delta=0.001, mode='min')
        # Gradient clipping threshold.
        self.max_grad_norm = 1.0

    def collate_func(self, batch_data):
        """Tokenize a batch of (text, label) pairs into model-ready tensors.

        Returns the tokenizer's encoding dict with an added 'labels' tensor,
        padded/truncated to self.max_length.
        """
        texts, labels = [], []
        for item in batch_data:
            texts.append(item[0])
            labels.append(item[1])
        inputs = self.tokenizer(
            texts,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        inputs['labels'] = torch.tensor(labels, dtype=torch.long)
        return inputs

    def train_step(self):
        """Run one training epoch; return (average_loss, accuracy).

        Precondition: self.scheduler has been created by train_and_validate();
        calling this method directly would fail on scheduler.step().
        """
        self.model.train()
        total_loss = 0.0
        correct_predictions = 0.0
        total_samples = 0.0
        for batch_data in tqdm(self.train_loader, desc="Training"):
            batch_data = {k: v.to(self.device) for k, v in batch_data.items()}
            self.optimizer.zero_grad()
            # The model computes the loss itself because 'labels' is in the batch.
            outputs = self.model(**batch_data)
            loss = outputs.loss
            loss.backward()
            # Clip gradients to stabilize training.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
            self.optimizer.step()
            self.scheduler.step()  # advance the LR schedule once per batch
            total_loss += loss.item()
            predictions = torch.argmax(outputs.logits, dim=-1)
            correct_predictions += (predictions == batch_data['labels']).float().sum().item()
            total_samples += len(batch_data['labels'])
        avg_loss = total_loss / len(self.train_loader)
        accuracy = correct_predictions / total_samples
        return avg_loss, accuracy

    def val_step(self):
        """Evaluate on the validation loader; return (average_loss, accuracy)."""
        self.model.eval()
        total_loss = 0.0
        correct_predictions = 0.0
        total_samples = 0.0
        with torch.no_grad():
            for batch_data in tqdm(self.val_loader, desc="Validating"):
                batch_data = {k: v.to(self.device) for k, v in batch_data.items()}
                outputs = self.model(**batch_data)
                total_loss += outputs.loss.item()
                predictions = torch.argmax(outputs.logits, dim=-1)
                correct_predictions += (predictions == batch_data['labels']).float().sum().item()
                total_samples += len(batch_data['labels'])
        avg_loss = total_loss / len(self.val_loader)
        accuracy = correct_predictions / total_samples
        return avg_loss, accuracy

    def train_and_validate(self, num_epoch):
        """Full training loop: schedule LR, train/validate each epoch,
        checkpoint the best models, and stop early when validation loss stalls.

        Args:
            num_epoch: maximum number of epochs to run.
        """
        # Linear LR decay with warmup, sized from the total step count.
        total_steps = len(self.train_loader) * num_epoch
        warmup_steps = int(total_steps * self.warmup_ratio)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps
        )
        print(f"\n训练配置:")
        print(f" 总训练步数: {total_steps}")
        print(f" Warmup步数: {warmup_steps}")
        print(f" 初始学习率: {self.optimizer.param_groups[0]['lr']}")
        print(f" Batch size: {self.batch_size}")
        print(f" Max length: {self.max_length}")
        print(f" 梯度裁剪: {self.max_grad_norm}")
        print("开始训练...")
        for epoch in range(num_epoch):
            print(f"\n{'='*60}")
            print(f"Epoch {epoch+1}/{num_epoch}")
            print(f"{'='*60}")
            train_loss, train_acc = self.train_step()
            val_loss, val_acc = self.val_step()
            # Current learning rate (after this epoch's schedule steps).
            current_lr = self.optimizer.param_groups[0]['lr']
            print(f'Epoch {epoch+1}/{num_epoch}:')
            print(f' Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}')
            print(f' Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}')
            print(f' Learning Rate: {current_lr:.2e}')
            # Save the best model by validation accuracy.
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                torch.save(self.model.state_dict(), self.best_acc_model_path)
                print(f" ✓ 保存了新的最佳准确率模型,验证准确率: {self.best_val_acc:.4f}")
            # Save the best model by validation loss.
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                torch.save(self.model.state_dict(), self.best_loss_model_path)
                print(f" ✓ 保存了新的最低损失模型,验证损失: {self.best_val_loss:.4f}")
            # Early-stopping check on validation loss.
            if self.early_stopping(val_loss):
                print(f"\n⚠ Early stopping触发! 在epoch {epoch+1}停止训练")
                print(f" 验证损失已连续{self.early_stopping.patience}个epoch没有改善")
                break
        print(f"\n{'='*60}")
        print("训练完成!")
        print(f"{'='*60}")
        print(f"最佳验证准确率: {self.best_val_acc:.4f}")
        print(f"最低验证损失: {self.best_val_loss:.4f}")
        print(f"模型保存路径:")
        print(f" 最佳准确率模型: {self.best_acc_model_path}")
        print(f" 最低损失模型: {self.best_loss_model_path}")
if __name__ == '__main__':
    # Train the English model first.
    trainer_en = SimpleTrainer(train_csv_file='data_en.csv',
                               batch_size=32,
                               is_english=True,
                               max_length=256)  # increased max_length
    trainer_en.train_and_validate(num_epoch=20)
    # Release GPU memory before starting the second training run.
    del trainer_en
    gc.collect()  # force Python garbage collection
    torch.cuda.empty_cache()  # clear the CUDA cache
    time.sleep(2)  # a short pause is enough, no need to wait longer
    # Then train the Chinese model with the same settings.
    trainer = SimpleTrainer(train_csv_file='data_zh.csv',
                            batch_size=32,
                            is_english=False,
                            max_length=256)  # increased max_length
    trainer.train_and_validate(num_epoch=20)