| 123456789101112131415161718192021222324252627282930313233343536373839 |
- # -*- coding: utf-8 -*-
- """main.py: 主运行文件"""
- import argparse
- from config import config
- def main():
- parser = argparse.ArgumentParser(description="水厂诊断模型训练")
- parser.add_argument('-p', '--plant', type=str, required=True, help="水厂名称(对应文件夹名),例如: longting")
- args = parser.parse_args()
-
- print(f"[*] 正在初始化工作空间: {args.plant}")
- config.load(args.plant)
- # 在 config 初始化完成后,再导入后面的通用逻辑
- from data_processing import DataAnomalyProcessor
- from causal_structure import CausalStructureBuilder
- from rl_tracing import RLTrainer
-
- # 1. 数据层
- processor = DataAnomalyProcessor()
- train_scores, test_scores, threshold_df = processor.process()
-
- # 2. 因果层
- builder = CausalStructureBuilder(threshold_df)
- causal_graph = builder.build()
-
- # 3. 强化学习层
- trainer = RLTrainer(causal_graph, train_scores, threshold_df)
- trainer.pretrain_bc()
- trainer.train_ppo()
- trainer.save_model()
-
- # 4. 评估阶段
- trainer.evaluate(test_scores)
-
- print(f"\n[Success] {args.plant} 水厂训练与评估完毕!模型保存在: {config.MODEL_FILE_PATH}")
- if __name__ == "__main__":
- main()
|