main.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. # -*- coding: utf-8 -*-
  2. """main.py: 主运行文件"""
  3. import argparse
  4. from config import config
  5. def main():
  6. parser = argparse.ArgumentParser(description="水厂诊断模型训练")
  7. parser.add_argument('-p', '--plant', type=str, required=True, help="水厂名称(对应文件夹名),例如: longting")
  8. args = parser.parse_args()
  9. print(f"[*] 正在初始化工作空间: {args.plant}")
  10. config.load(args.plant)
  11. # 在 config 初始化完成后,再导入后面的通用逻辑
  12. from data_processing import DataAnomalyProcessor
  13. from causal_structure import CausalStructureBuilder
  14. from rl_tracing import RLTrainer
  15. # 1. 数据层
  16. processor = DataAnomalyProcessor()
  17. train_scores, test_scores, threshold_df = processor.process()
  18. # 2. 因果层
  19. builder = CausalStructureBuilder(threshold_df)
  20. causal_graph = builder.build()
  21. # 3. 强化学习层
  22. trainer = RLTrainer(causal_graph, train_scores, threshold_df)
  23. trainer.pretrain_bc()
  24. trainer.train_ppo()
  25. trainer.save_model()
  26. # 4. 评估阶段
  27. trainer.evaluate(test_scores)
  28. print(f"\n[Success] {args.plant} 水厂训练与评估完毕!模型保存在: {config.MODEL_FILE_PATH}")
  29. if __name__ == "__main__":
  30. main()