# main.py
"""Command-line entry point: train, save, and evaluate the model for one plant."""
import argparse
import os
import random

import torch
import numpy as np
from torch.nn import MSELoss

from config import config


def set_seed(seed):
    """Seed every random number generator used during training.

    Seeds Python's ``random``, NumPy's global RNG, and PyTorch (via
    ``torch.manual_seed`` and ``torch.cuda.manual_seed``), then switches cuDNN
    to deterministic, non-benchmarking mode so repeated runs match.

    Args:
        seed: Seed value applied to every generator.
    """
    # NOTE: PYTHONHASHSEED is read at interpreter startup, so assigning it here
    # only influences child processes launched afterwards, not this process.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def _parse_args():
    """Parse the command line; the required ``--plant`` selects the plant config."""
    parser = argparse.ArgumentParser(description="水厂预测模型训练")
    parser.add_argument('-p', '--plant', type=str, required=True, help="水厂名称,例如: lankao")
    return parser.parse_args()


def main():
    """Train, save, and evaluate the model for the plant given on the command line."""
    args = _parse_args()

    # Load the configuration for the requested plant before anything else reads it.
    config.load(args.plant)

    # Deferred imports: done only after config.load() so that config is already
    # populated when these project modules are imported.
    from gat_lstm import GAT_LSTM
    from data_trainer import Trainer
    from data_preprocessor import DataPreprocessor

    set_seed(config.RANDOM_SEED)

    use_cuda = torch.cuda.is_available()
    device = torch.device(f"cuda:{config.DEVICE_ID}" if use_cuda else "cpu")
    print(f"[*] 工作空间: {args.plant} | 序列={config.SEQ_LEN}, 特征={config.FEATURE_NUM}, 目标={config.LABELS_NUM}")

    raw_data = DataPreprocessor.read_and_combine_csv_files()
    # The fourth value returned by load_and_process_data is not used here.
    train_loader, val_loader, test_loader, _ = DataPreprocessor.load_and_process_data(raw_data)

    model = GAT_LSTM().to(device)
    # Optionally restore a precomputed edge_index; weights_only=True restricts
    # torch.load to tensors and plain containers instead of arbitrary pickles.
    if os.path.exists(config.EDGE_INDEX_PATH):
        edge_index = torch.load(config.EDGE_INDEX_PATH, map_location=device, weights_only=True)
        model.set_edge_index(edge_index)
        print("已加载 edge_index.pt")

    trainer = Trainer(model, raw_data)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.LR)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=config.SCHEDULER_STEP_SIZE,
        gamma=config.SCHEDULER_GAMMA,
    )

    print("=== 开始训练 ===")
    loss_fn = MSELoss()
    trainer.train_full_model(train_loader, val_loader, optimizer, loss_fn, scheduler)
    trainer.save_model()

    print("=== 开始评估 ===")
    trainer.evaluate_model(test_loader)


if __name__ == "__main__":
    main()