main.py 916 B

1234567891011121314151617181920212223242526272829303132
  1. # -*- coding: utf-8 -*-
  2. """main.py: 主运行文件"""
  3. from data_processing import DataAnomalyProcessor
  4. from causal_structure import CausalStructureBuilder
  5. from rl_tracing import RLTrainer
  6. def main():
  7. # 1. 数据层 (返回切分好的数据)
  8. processor = DataAnomalyProcessor()
  9. train_scores, test_scores, threshold_df = processor.process()
  10. # 2. 因果层
  11. builder = CausalStructureBuilder(threshold_df)
  12. causal_graph = builder.build()
  13. # 3. 强化学习层
  14. # 初始化传入训练集
  15. trainer = RLTrainer(causal_graph, train_scores, threshold_df)
  16. # 3.1 训练阶段
  17. trainer.pretrain_bc() # 学习已有的
  18. trainer.train_ppo() # 探索未知的
  19. trainer.save_model()
  20. # 3.2 评估阶段 (使用测试集)
  21. trainer.evaluate(test_scores)
  22. print("\n[Success] 所有任务执行完毕!")
  23. if __name__ == "__main__":
  24. main()