| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209 |
- """
- GAT-LSTM TMP预测模型 - FastAPI服务
- 版本:2.0.0
- 最后更新:2025-10-30
- 20分钟短期TMP预测的API服务
- """
- import os
- import sys
- import logging
- from logging.handlers import RotatingFileHandler
- import datetime
- import pandas as pd
- import uvicorn
- from fastapi import FastAPI, HTTPException
- from fastapi.middleware.cors import CORSMiddleware
- from pydantic import BaseModel
- from typing import List, Dict, Any
# --- Logging and data directory configuration ---
base_dir = os.path.dirname(__file__)

# Directory that holds the rotating API log files.
log_dir = os.path.join(base_dir, 'logs')
os.makedirs(log_dir, exist_ok=True)

# Directory where received request payloads are saved.
data_save_dir = os.path.join(base_dir, 'received_data')
os.makedirs(data_save_dir, exist_ok=True)

# Configure the root logger only when it has no handlers yet, so an
# embedding application that already set up logging is left alone.
_root_logger = logging.getLogger()
if not _root_logger.handlers:
    # Rotate at 2 MiB per file and keep five backups.
    _api_log_handler = RotatingFileHandler(
        os.path.join(log_dir, "api.log"),
        maxBytes=2 * 1024 * 1024,
        backupCount=5,
        encoding='utf-8',
    )
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[_api_log_handler, logging.StreamHandler()],
    )

logger = logging.getLogger(__name__)
# --- Put this file's directory on the Python import path ---
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

# --- Import the 20-minute prediction model ---
# The module is loaded directly from its file path under the "20min"
# directory; that name starts with a digit, so it cannot appear in a regular
# ``import`` statement.
import importlib.util

predict_module_path = os.path.join(current_dir, '20min', 'predict.py')
spec = importlib.util.spec_from_file_location(
    "predict_20min",
    predict_module_path,
)
predict_module = importlib.util.module_from_spec(spec)
# Executing the module is what actually runs predict.py and fills in its names.
spec.loader.exec_module(predict_module)

Predictor = predict_module.Predictor
logger.info("成功加载20分钟TMP预测模型")
# --- FastAPI application initialisation ---
app = FastAPI(
    title="GAT-LSTM TMP预测 API",
    description="20分钟短期TMP预测服务",
    version="2.0.0",
)

# CORS is fully open: any origin, method and header, with credentials allowed.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive — confirm this is intended outside of development.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# --- Global model instance ---
# Created once at import time and shared by every request handler.
predictor = Predictor()
logger.info("预测器初始化完成")
# --- Data models ---
# One sample of the input time series. Only ``datetime`` is declared; the
# ``extra = "allow"`` setting lets each item carry additional, undeclared
# fields, which are kept on the model.
# (Documented with comments rather than a docstring so the generated schema
# description stays unchanged.)
class TimeSeriesDataPoint(BaseModel):
    # Timestamp of the sample, received as a string.
    datetime: str

    class Config:
        extra = "allow"
# Request body of the prediction endpoint: a list of time-series samples.
# (Documented with comments rather than a docstring so the generated schema
# description stays unchanged.)
class PredictRequest(BaseModel):
    # Historical samples, one entry per timestamp.
    data: List[TimeSeriesDataPoint]
# Response body shared by the prediction endpoints.
# (Documented with comments rather than a docstring so the generated schema
# description stays unchanged.)
class PredictResponse(BaseModel):
    # Success flag; defaults to True.
    success: bool = True
    # Prediction rows, each a free-form mapping of column name to value.
    predict_result: List[Dict[str, Any]]
# --- API endpoints ---
@app.post(
    "/api/v1/process_model/double_membrance",
    response_model=PredictResponse,
    summary="20分钟TMP预测",
    # Given explicitly so the published OpenAPI description keeps its original
    # wording; FastAPI would otherwise derive it from the function docstring.
    description="接收历史时序数据,返回未来20分钟TMP预测结果",
    tags=["预测"],
)
def get_double_membrance_model(request: PredictRequest):
    """Run the 20-minute TMP forecast on posted historical time-series data.

    The payload is converted to a DataFrame, archived to ``data_save_dir`` on
    a best-effort basis, sorted by time and handed to the shared ``predictor``.

    Args:
        request: Payload whose ``data`` list holds one item per sample; each
            item has a ``datetime`` string plus any extra fields.

    Returns:
        A dict ``{"success": True, "predict_result": [...]}``. Each record
        has a ``datetime`` string formatted as ``%Y-%m-%d %H:%M:%S`` next to
        the other columns returned by ``predictor.save_predictions``.

    Raises:
        HTTPException: 400 when ``data`` is empty; 500, with the error text
            as detail, for any other failure.
    """
    try:
        logger.info(f"收到预测请求 - 数据点数: {len(request.data)}")

        if not request.data:
            raise HTTPException(status_code=400, detail="输入数据不能为空")

        # Serialise each item once; the records feed both the DataFrame and
        # the archived copy of the request.
        records = [item.dict() for item in request.data]
        df = pd.DataFrame(records)
        logger.info(f"数据转换完成 - shape: {df.shape}, 列数: {len(df.columns)}")

        # Archive the received data. This is best effort: a failure here is
        # logged and must not fail the prediction.
        try:
            import json

            received_at = datetime.datetime.now()
            timestamp = received_at.strftime("%Y%m%d_%H%M%S_%f")
            save_filename = f"request_{timestamp}.json"
            save_path = os.path.join(data_save_dir, save_filename)

            with open(save_path, 'w', encoding='utf-8') as f:
                json.dump({
                    "timestamp": received_at.isoformat(),
                    "data_shape": {"rows": len(df), "columns": len(df.columns)},
                    "data": records,
                }, f, ensure_ascii=False, indent=2)
            logger.info(f"请求数据已保存: {save_filename}")
        except Exception as save_error:
            logger.warning(f"保存请求数据失败: {save_error}")

        # The predictor is given the time column under the name "index".
        df["datetime"] = pd.to_datetime(df["datetime"])
        df = df.sort_values(by="datetime").rename(columns={"datetime": "index"})
        logger.info(f"数据预处理完成 - 时间范围: {df['index'].min()} 到 {df['index'].max()}")

        # Run the prediction, then let the predictor build and save the
        # result frame.
        # NOTE(review): save_predictions() appears to rely on state that
        # predict() leaves on the shared module-level predictor; sync
        # endpoints run in a thread pool, so concurrent requests could
        # interleave — confirm whether a lock is needed.
        predictions = predictor.predict(df)
        result_df = predictor.save_predictions(predictions)

        # Format the output: timestamps become strings and the column is
        # renamed back to "datetime".
        result_df["index"] = result_df["index"].apply(lambda x: x.strftime("%Y-%m-%d %H:%M:%S"))
        result_df = result_df.rename(columns={"index": "datetime"})
        predict_result = result_df.to_dict(orient="records")

        if predict_result:
            logger.info(f"预测完成 - 预测点数: {len(predict_result)}, 时间范围: {predict_result[0]['datetime']} 到 {predict_result[-1]['datetime']}")
        else:
            # An empty result must not crash the log line and turn into a 500.
            logger.info(f"预测完成 - 预测点数: {len(predict_result)}")

        return {"success": True, "predict_result": predict_result}

    except HTTPException:
        # Deliberate HTTP errors (such as the 400 above) keep their status
        # code instead of being rewrapped as a 500 below.
        raise
    except Exception as e:
        logger.error(f"预测失败: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get(
    "/api/v1/process_model/test_double_membrance_from_file",
    response_model=PredictResponse,
    summary="测试预测(从本地文件)",
    # Given explicitly so the published OpenAPI description keeps its original
    # wording; FastAPI would otherwise derive it from the function docstring.
    description="从本地JSON文件加载测试数据进行预测",
    tags=["测试"],
)
def test_double_membrance_from_file():
    """Run a prediction on the bundled fixture ``test_files/pp.json``.

    The fixture must be a JSON object with a non-empty ``data`` list whose
    items include a ``datetime`` field.

    Returns:
        A dict ``{"success": True, "predict_result": [...]}``. Each record
        has a ``datetime`` string formatted as ``%Y-%m-%d %H:%M:%S`` next to
        the other columns returned by ``predictor.save_predictions``.

    Raises:
        HTTPException: 404 when the fixture file is missing; 400 when it has
            no usable ``data``; 500, with the error text as detail, for any
            other failure.
    """
    import json

    file_path = os.path.join(base_dir, "test_files", "pp.json")
    try:
        logger.info(f"开始测试 - 文件: {file_path}")

        # Only a missing fixture maps to 404. A FileNotFoundError raised
        # later, for example inside the predictor, is reported as a 500 with
        # its own message.
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                request_data = json.load(f)
        except FileNotFoundError:
            raise HTTPException(status_code=404, detail=f"测试文件未找到: {file_path}")

        if "data" not in request_data or not request_data["data"]:
            raise HTTPException(status_code=400, detail="测试文件数据无效")

        # Convert to a DataFrame.
        df = pd.DataFrame(request_data["data"])
        logger.info(f"测试数据加载完成 - shape: {df.shape}")

        # The predictor is given the time column under the name "index".
        df["datetime"] = pd.to_datetime(df["datetime"])
        df = df.sort_values(by="datetime").rename(columns={"datetime": "index"})
        logger.info(f"测试数据预处理完成 - 时间范围: {df['index'].min()} 到 {df['index'].max()}")

        # Run the prediction.
        predictions = predictor.predict(df)
        result_df = predictor.save_predictions(predictions)

        # Format the output: timestamps become strings and the column is
        # renamed back to "datetime".
        result_df["index"] = result_df["index"].apply(lambda x: x.strftime("%Y-%m-%d %H:%M:%S"))
        result_df = result_df.rename(columns={"index": "datetime"})
        predict_result = result_df.to_dict(orient="records")

        logger.info(f"测试完成 - 预测点数: {len(predict_result)}")
        return {"success": True, "predict_result": predict_result}

    except HTTPException:
        # Deliberate HTTP errors (the 404 and 400 above) keep their status
        # code instead of being rewrapped as a 500 below.
        raise
    except Exception as e:
        logger.error(f"测试失败: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/", include_in_schema=False)
def root():
    """Return a short status message that points callers to ``/docs``."""
    status_message = "GAT-LSTM TMP预测服务运行中,访问 /docs 查看API文档"
    return {"message": status_message}
if __name__ == "__main__":
    # Serve on all interfaces, port 7980, with auto-reload off. The app is
    # given as an import string, so the module must be importable as
    # ``api_main``.
    uvicorn.run(
        "api_main:app",
        host="0.0.0.0",
        port=7980,
        reload=False,
    )
|