|
@@ -14,7 +14,7 @@ class DQNTrainer:
|
|
|
- 在测试集环境上评估策略
|
|
- 在测试集环境上评估策略
|
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
- def __init__(self, env, params, callback=None,PROJECT_ROOT=None):
|
|
|
|
|
|
|
+ def __init__(self, env, params, callback=None,PROJECT_ROOT=None,DIR_NAME=None):
|
|
|
"""
|
|
"""
|
|
|
初始化训练器
|
|
初始化训练器
|
|
|
|
|
|
|
@@ -27,9 +27,11 @@ class DQNTrainer:
|
|
|
self.params = params
|
|
self.params = params
|
|
|
self.callback = callback
|
|
self.callback = callback
|
|
|
self.PROJECT_ROOT = PROJECT_ROOT
|
|
self.PROJECT_ROOT = PROJECT_ROOT
|
|
|
|
|
+ self.dir_name = DIR_NAME
|
|
|
self.log_dir = self._create_log_dir() # 创建TensorBoard日志目录
|
|
self.log_dir = self._create_log_dir() # 创建TensorBoard日志目录
|
|
|
self.model = self._create_model() # 创建DQN模型
|
|
self.model = self._create_model() # 创建DQN模型
|
|
|
|
|
|
|
|
|
|
+
|
|
|
# ------------------- 私有方法 -------------------
|
|
# ------------------- 私有方法 -------------------
|
|
|
def _create_log_dir(self):
|
|
def _create_log_dir(self):
|
|
|
"""
|
|
"""
|
|
@@ -54,7 +56,7 @@ class DQNTrainer:
|
|
|
|
|
|
|
|
# 4️⃣ 固定日志存放位置:PROJECT_ROOT/model_result/uf_dqn_tensorboard
|
|
# 4️⃣ 固定日志存放位置:PROJECT_ROOT/model_result/uf_dqn_tensorboard
|
|
|
# 假设在 run_dqn_train.py 中定义了 PROJECT_ROOT = "models/uf-rl"
|
|
# 假设在 run_dqn_train.py 中定义了 PROJECT_ROOT = "models/uf-rl"
|
|
|
- base_dir = os.path.join(self.PROJECT_ROOT, "model_result", "uf_dqn_tensorboard","anzhen48h")
|
|
|
|
|
|
|
+ base_dir = os.path.join(self.PROJECT_ROOT, "model_result", "uf_dqn_tensorboard",self.dir_name)
|
|
|
os.makedirs(base_dir, exist_ok=True)
|
|
os.makedirs(base_dir, exist_ok=True)
|
|
|
|
|
|
|
|
# 5️⃣ 完整日志目录路径
|
|
# 5️⃣ 完整日志目录路径
|