# main.py
import argparse
import os
import random

import numpy as np
import torch
from torch.nn import MSELoss

from config import config
def set_seed(seed):
    """Seed every random number generator used by the training run.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices), exports ``PYTHONHASHSEED`` and switches cuDNN
    into deterministic, non-autotuning mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # NOTE: hash randomization is fixed when the interpreter starts, so this
    # assignment only influences child processes, not the running one.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU rather than only the current one: the training
    # device is not necessarily the current CUDA device. The call is a no-op
    # when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # Give up cuDNN autotuning in exchange for reproducible kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def main():
    """Train, save and evaluate the forecasting model for one water plant.

    The plant name comes from the required ``-p/--plant`` command-line
    option. Its configuration is loaded first, then the data loaders and the
    ``GAT_LSTM`` model are built, the model is trained and saved, and finally
    evaluated on the test loader.
    """
    cli = argparse.ArgumentParser(description="水厂预测模型训练")
    cli.add_argument(
        '-p',
        '--plant',
        type=str,
        required=True,
        help="水厂名称,例如: lankao",
    )
    options = cli.parse_args()

    # Load the configuration that belongs to the selected plant.
    config.load(options.plant)

    # Deferred imports: make sure config has been loaded before these
    # modules are imported.
    from gat_lstm import GAT_LSTM
    from data_trainer import Trainer
    from data_preprocessor import DataPreprocessor

    set_seed(config.RANDOM_SEED)

    # Use the configured GPU when CUDA is available, otherwise the CPU.
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{config.DEVICE_ID}")
    else:
        device = torch.device("cpu")
    print(f"[*] 工作空间: {options.plant} | 序列={config.SEQ_LEN}, 特征={config.FEATURE_NUM}, 目标={config.LABELS_NUM}")

    data = DataPreprocessor.read_and_combine_csv_files()
    # The fourth returned value is not needed here.
    loaders = DataPreprocessor.load_and_process_data(data)
    train_loader, val_loader, test_loader, _ = loaders

    model = GAT_LSTM().to(device)
    # The stored edge index is optional; attach it only when the file exists.
    if os.path.exists(config.EDGE_INDEX_PATH):
        edge_index = torch.load(
            config.EDGE_INDEX_PATH,
            map_location=device,
            weights_only=True,
        )
        model.set_edge_index(edge_index)
        print("已加载 edge_index.pt")

    trainer = Trainer(model, data)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.LR)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=config.SCHEDULER_STEP_SIZE,
        gamma=config.SCHEDULER_GAMMA,
    )

    print("=== 开始训练 ===")
    criterion = MSELoss()
    trainer.train_full_model(train_loader, val_loader, optimizer, criterion, scheduler)
    trainer.save_model()

    print("=== 开始评估 ===")
    trainer.evaluate_model(test_loader)
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()