| 1234567891011121314151617181920212223242526272829303132 |
- # -*- coding: utf-8 -*-
- """main.py: 主运行文件"""
- from data_processing import DataAnomalyProcessor
- from causal_structure import CausalStructureBuilder
- from rl_tracing import RLTrainer
- def main():
-
- # 1. 数据层 (返回切分好的数据)
- processor = DataAnomalyProcessor()
- train_scores, test_scores, threshold_df = processor.process()
-
- # 2. 因果层
- builder = CausalStructureBuilder(threshold_df)
- causal_graph = builder.build()
-
- # 3. 强化学习层
- # 初始化传入训练集
- trainer = RLTrainer(causal_graph, train_scores, threshold_df)
-
- # 3.1 训练阶段
- trainer.pretrain_bc() # 学习已有的
- trainer.train_ppo() # 探索未知的
- trainer.save_model()
-
- # 3.2 评估阶段 (使用测试集)
- trainer.evaluate(test_scores)
-
- print("\n[Success] 所有任务执行完毕!")
- if __name__ == "__main__":
- main()
|