# rl_optimizer.py
import torch
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback
import torch.optim as optim
from gat import GAT
from data_trainer import DataTrainer


class GATEnv(gym.Env):
    """Gymnasium environment whose actions are GAT hyperparameters and whose
    reward is derived from a short validation run."""

    metadata = {'render_modes': ['human'], 'render_fps': 4}

    def __init__(self, preprocessor, train_loader, val_loader, adj, args, logger):
        super().__init__()
        self.preprocessor = preprocessor
        self.train_loader = train_loader
        self.val_loader = val_loader
        # Use the specified device (GPU supported).
        self.eval_device = torch.device(args.device)
        self.adj = adj.to(self.eval_device)
        self.args = args
        self.logger = logger
        # Action: [lr, hidden_dim, num_heads, dropout]
        self.action_space = spaces.Box(
            low=np.array([1e-5, 32, 2, 0.1], dtype=np.float32),
            high=np.array([1e-2, 128, 8, 0.5], dtype=np.float32),
            shape=(4,),
            dtype=np.float32
        )
        # Observation: the action plus the resulting validation loss.
        self.observation_space = spaces.Box(
            low=np.array([1e-5, 32, 2, 0.1, 0], dtype=np.float32),
            high=np.array([1e-2, 128, 8, 0.5, 100], dtype=np.float32),
            shape=(5,),
            dtype=np.float32
        )
        self.best_val_loss = float('inf')
        self.current_step = 0
        self.max_steps = args.rl_max_steps
        self.render_mode = None

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.current_step = 0
        self.best_val_loss = float('inf')
        # Start from the hyperparameters in args, with a placeholder val loss.
        self.current_state = np.array([
            float(self.args.lr),
            float(self.args.hidden_dim),
            float(self.args.num_heads),
            float(self.args.dropout),
            10.0
        ], dtype=np.float32)
        return self.current_state, {}

    def step(self, action):
        self.current_step += 1
        # Decode and clamp the action into valid hyperparameter ranges.
        lr = float(np.clip(action[0], 1e-5, 1e-2))
        hidden_dim = int(round(float(action[1])))
        num_heads = int(round(float(action[2])))
        dropout = float(action[3])
        hidden_dim = max(32, min(128, hidden_dim))
        num_heads = max(2, min(8, num_heads))
        dropout = max(0.1, min(0.5, dropout))
        # Build and evaluate the candidate model on the specified device (GPU supported).
        model = GAT(
            nfeat=1,
            nhid=hidden_dim,
            noutput=self.args.num_targets,
            dropout=dropout,
            nheads=num_heads,
            alpha=0.2
        ).to(self.eval_device)
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=self.args.weight_decay)
        trainer = DataTrainer(model, self.args, self.preprocessor, optimizer, logger=self.logger)
        val_loss = self._short_evaluate(trainer)
        reward = 1.0 / (1.0 + val_loss)
        if val_loss < self.best_val_loss:
            reward += 0.5  # bonus for improving on the best loss seen so far
            self.best_val_loss = val_loss
        # Clip the loss so the observation stays inside observation_space.
        obs_loss = min(val_loss, 100.0)
        self.current_state = np.array([lr, hidden_dim, num_heads, dropout, obs_loss], dtype=np.float32)
        # Hitting the step budget is a time limit, so report it as truncation
        # rather than termination, per the Gymnasium API.
        terminated = False
        truncated = self.current_step >= self.max_steps
        return self.current_state, float(reward), terminated, truncated, {}

    def _short_evaluate(self, trainer):
        """
        Key speed-up: approximate the objective with only a handful of batches,
        so a single env.step() finishes in milliseconds to seconds.
        """
        # Train on a single batch, twice, to get a usable gradient signal.
        for _ in range(2):
            trainer.train_epoch(self.train_loader, self.adj, max_batches=1)
        # Validate on 2 batches to reduce variance.
        val_loss, _ = trainer.validate(self.val_loader, self.adj, max_batches=2)
        return float(val_loss)

    def render(self):
        if self.render_mode == 'human':
            print(f"[RL] Step: {self.current_step}, Best Val Loss: {self.best_val_loss:.6f}")

    def close(self):
        pass
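

# A minimal sanity-check sketch (not part of the original file): before
# spending time on PPO, the custom environment can be validated against the
# Gymnasium API with SB3's built-in checker. All constructor arguments are
# assumed to come from the project's own data pipeline.
def _check_gatenv_sketch(preprocessor, train_loader, val_loader, adj, args, logger):
    from stable_baselines3.common.env_checker import check_env
    env = GATEnv(preprocessor, train_loader, val_loader, adj, args, logger)
    check_env(env, warn=True)  # verifies spaces and the reset()/step() contract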


class TrainingCallback(BaseCallback):
    def __init__(self, verbose=0, print_every=100):
        super().__init__(verbose)
        self.print_every = print_every

    def _on_step(self) -> bool:
        # BaseCallback.logger is not a logging.Logger; use print() or self.logger.record().
        if self.n_calls % self.print_every == 0:
            # Some SB3 versions do not put a 'rewards' key in self.locals; guard defensively.
            rew = None
            try:
                r = self.locals.get('rewards', None)
                if r is not None:
                    rew = float(r[0])
            except Exception:
                pass
            print(f"[RL] timesteps={self.num_timesteps} calls={self.n_calls} reward={rew}")
        return True
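

# A hedged alternative sketch (not in the original file): instead of print(),
# the reward can be pushed through SB3's own logger so it shows up in
# TensorBoard; the metric name "rollout/last_reward" is illustrative, not an
# SB3 built-in.
class TensorboardRewardCallback(BaseCallback):
    def _on_step(self) -> bool:
        r = self.locals.get('rewards', None)
        if r is not None:
            # self.logger here is SB3's Logger, which supports record().
            self.logger.record('rollout/last_reward', float(r[0]))
        return True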


class RLOptimizer:
    def __init__(self, args, preprocessor, train_loader, val_loader, adj, logger):
        self.args = args
        self.preprocessor = preprocessor
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.adj = adj
        self.logger = logger

    def optimize(self):
        env = GATEnv(
            self.preprocessor, self.train_loader, self.val_loader,
            self.adj, self.args, self.logger
        )
        # Key: shrink the PPO rollout and update sizes so a single rollout
        # does not take too long.
        model = PPO(
            "MlpPolicy",
            env,
            verbose=1,
            learning_rate=3e-4,
            n_steps=32,        # was 2048 -> 32
            batch_size=32,     # was 64 -> 32
            n_epochs=1,        # was 10 -> 1
            gamma=0.99,
            gae_lambda=0.95,
            clip_range=0.2,
            ent_coef=0.01,
            device=self.args.device  # use the specified device (GPU supported)
        )
        self.logger.info("Training the RL agent...")
        callback = TrainingCallback(verbose=1, print_every=100)
        model.learn(total_timesteps=self.args.rl_timesteps, callback=callback)
        model.save("gat_ppo_agent")
        # Evaluate the trained policy and pick the best action.
        self.logger.info("Searching for the best hyperparameter combination...")
        best_reward = -1.0
        best_action = None
        eval_env = GATEnv(
            self.preprocessor, self.train_loader, self.val_loader,
            self.adj, self.args, self.logger
        )
        for _ in range(self.args.rl_eval_episodes):
            obs, _ = eval_env.reset()
            action, _ = model.predict(obs, deterministic=True)
            _, reward, _, _, _ = eval_env.step(action)
            if reward > best_reward:
                best_reward = reward
                best_action = action
        best_hparams = {
            'lr': float(best_action[0]),
            'hidden_dim': int(round(float(best_action[1]))),
            'num_heads': int(round(float(best_action[2]))),
            'dropout': float(best_action[3])
        }
        self.logger.info(f"\nBest hyperparameters: {best_hparams}")
        return best_hparams
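

# Minimal end-to-end usage sketch (not part of the original file). It assumes
# `args` carries the fields read above (device, lr, hidden_dim, num_heads,
# dropout, weight_decay, num_targets, rl_max_steps, rl_timesteps,
# rl_eval_episodes) and that the data objects come from the project's pipeline.
def _run_rl_search_sketch(args, preprocessor, train_loader, val_loader, adj, logger):
    optimizer = RLOptimizer(args, preprocessor, train_loader, val_loader, adj, logger)
    best_hparams = optimizer.optimize()
    # Feed the winning hyperparameters back into args for the final training run.
    args.lr = best_hparams['lr']
    args.hidden_dim = best_hparams['hidden_dim']
    args.num_heads = best_hparams['num_heads']
    args.dropout = best_hparams['dropout']
    return best_hparams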