| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051 |
- # main.py
- import os
- import torch
- import numpy as np
- import random
- from gat_lstm import GAT_LSTM
- from data_trainer import Trainer
- from args import lstm_args_parser
- from torch.nn import MSELoss
- from data_preprocessor import DataPreprocessor
def set_seed(seed):
    """Seed all random number generators used by the pipeline and make cuDNN deterministic.

    Args:
        seed: Seed value passed to Python's ``random``, NumPy's global RNG,
            ``torch.manual_seed`` and ``torch.cuda.manual_seed``. Its ``str()``
            form is also exported as the ``PYTHONHASHSEED`` environment
            variable.

    Note:
        Writing ``PYTHONHASHSEED`` from inside a running interpreter does not
        change that interpreter's string-hash randomization. It is only read
        by Python processes started afterwards that read the environment at
        startup.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    # Seed the CPU generator, then the current CUDA device's generator.
    for torch_seeder in (torch.manual_seed, torch.cuda.manual_seed):
        torch_seeder(seed)
    # Give up cuDNN autotuning speed in exchange for run-to-run reproducibility.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
def main():
    """Run the full GAT_LSTM pipeline: preprocess, train, save, then evaluate.

    Steps:
        1. Parse command-line arguments and seed all RNGs.
        2. Resolve the compute device (CUDA index from ``args.device``, or CPU
           when CUDA is unavailable).
        3. Load the combined data via ``DataPreprocessor`` and build the
           train/val/test loaders.
        4. Train ``GAT_LSTM`` with Adam, a StepLR schedule and an MSE loss.
        5. Save the model, then evaluate it on the test loader using the same
           loss.
    """
    args = lstm_args_parser()
    set_seed(args.random_seed)

    # Falls back to CPU when CUDA is unavailable. args.device is overwritten
    # with the resolved torch.device because args is handed to the
    # preprocessor, the model and the trainer below.
    device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")
    args.device = device
    # Prints "current config: sequence length=..., feature count=..., target count=...".
    print(f"当前配置: 序列长度={args.seq_len}, 特征数={args.feature_num}, 目标数={args.labels_num}")

    # Data preprocessing.
    data = DataPreprocessor.read_and_combine_csv_files(args)
    train_loader, val_loader, test_loader, _ = DataPreprocessor.load_and_process_data(args, data)

    # Model and trainer.
    model = GAT_LSTM(args).to(device)
    trainer = Trainer(model, args, data)

    # One loss object shared by training and evaluation, so the two cannot
    # silently drift apart if the loss is changed later.
    criterion = MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=args.scheduler_step_size,
        gamma=args.scheduler_gamma,
    )

    print("=== 开始训练 ===")  # "=== start training ==="
    trainer.train_full_model(train_loader, val_loader, optimizer, criterion, scheduler)
    trainer.save_model()

    print("=== 开始评估 ===")  # "=== start evaluation ==="
    # NOTE(review): this evaluates whatever weights the model holds after
    # train_full_model returns. Whether those are the best-validation or the
    # last-epoch weights depends on Trainer — confirm.
    trainer.evaluate_model(test_loader, criterion)
if __name__ == "__main__":
    # Run the pipeline only when this file is executed as a script, not when
    # it is imported as a module.
    main()
|