# -*- coding: utf-8 -*- """main.py: 主运行文件""" import argparse from config import config def main(): parser = argparse.ArgumentParser(description="水厂诊断模型训练") parser.add_argument('-p', '--plant', type=str, required=True, help="水厂名称(对应文件夹名),例如: longting") args = parser.parse_args() print(f"[*] 正在初始化工作空间: {args.plant}") config.load(args.plant) # 在 config 初始化完成后,再导入后面的通用逻辑 from data_processing import DataAnomalyProcessor from causal_structure import CausalStructureBuilder from rl_tracing import RLTrainer # 1. 数据层 processor = DataAnomalyProcessor() train_scores, test_scores, threshold_df = processor.process() # 2. 因果层 builder = CausalStructureBuilder(threshold_df) causal_graph = builder.build() # 3. 强化学习层 trainer = RLTrainer(causal_graph, train_scores, threshold_df) trainer.pretrain_bc() trainer.train_ppo() trainer.save_model() # 4. 评估阶段 trainer.evaluate(test_scores) print(f"\n[Success] {args.plant} 水厂训练与评估完毕!模型保存在: {config.MODEL_FILE_PATH}") if __name__ == "__main__": main()