| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273 |
- """
- GAT-LSTM TMP预测模型 - FastAPI服务
- 版本:1.1.0
- 最后更新:2025-10-29
- 提供20分钟短期TMP预测的API服务
- """
- import os
- import sys
- import logging
- from logging.handlers import RotatingFileHandler
- import datetime
- import json
- import pandas as pd
- import uvicorn
- from fastapi import FastAPI, HTTPException
- from fastapi.middleware.cors import CORSMiddleware
- from pydantic import BaseModel
- from typing import List, Dict, Any
# --- Logging configuration ---
# Log files are written to the ./logs directory next to this file.
log_dir = os.path.join(os.path.dirname(__file__), 'logs')
os.makedirs(log_dir, exist_ok=True)
log_handler = RotatingFileHandler(
    os.path.join(log_dir, "api.log"),
    maxBytes=2 * 1024 * 1024,  # 2 MB per file before rotation
    backupCount=5,
    encoding='utf-8'
)
# Log verbosity is controlled via environment variables.
LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO').upper()
DETAILED_LOGS = os.getenv('DETAILED_LOGS', 'false').lower() == 'true'
# Configure the root logger only once to avoid duplicate handlers
# (e.g. when this module is imported more than once by a reloader).
if not logging.getLogger().handlers:
    logging.basicConfig(
        # Fall back to INFO when LOG_LEVEL is not a valid level name,
        # instead of crashing at import time with AttributeError.
        level=getattr(logging, LOG_LEVEL, logging.INFO),
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            log_handler,
            logging.StreamHandler()
        ]
    )
logger = logging.getLogger(__name__)

# --- Make the directory of this file importable ---
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)
# --- Model import with mock fallback ---
# Prefer the real model module; if it cannot be loaded, substitute a mock
# class so the API can still start (it will return random predictions).
try:
    # Dynamic import via importlib because the module directory name
    # starts with a digit ('20min') and cannot be imported normally.
    import importlib.util
    predict_module_path = os.path.join(current_dir, '20min', 'predict.py')
    spec = importlib.util.spec_from_file_location("predict_20min", predict_module_path)
    predict_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(predict_module)
    Predictor = predict_module.Predictor
    logger.info("成功加载20分钟TMP预测模型模块。")
except Exception as e:
    logger.warning(f"未能找到20分钟模型模块: {e},将使用模拟类进行替代。")
    logger.warning("请确保模型模块路径正确。")

    class Predictor:
        """Mock predictor used when the real model module is unavailable."""

        def predict(self, df: pd.DataFrame):
            """Return a random (5 time steps, 16 features) prediction array."""
            logger.info("正在使用模拟的 Predictor.predict 方法...")
            import numpy as np
            return np.random.rand(5, 16) * 100

        def save_predictions(self, res, start_date: str, output_path: str = None) -> pd.DataFrame:
            """Wrap raw predictions in a DataFrame with a 4-minute-interval time index.

            `start_date` must be formatted as '%Y-%m-%d %H:%M:%S'; `output_path`
            is accepted for interface parity with the real predictor but unused.
            """
            logger.info("正在使用模拟的 Predictor.save_predictions 方法...")
            start_dt = datetime.datetime.strptime(start_date, "%Y-%m-%d %H:%M:%S")
            time_index = [start_dt + datetime.timedelta(minutes=4 * i) for i in range(len(res))]
            # The 16 prediction columns expected by downstream consumers.
            columns = [
                'C.M.UF1_DB@press_PV_Predicted', 'C.M.UF2_DB@press_PV_Predicted',
                'C.M.UF3_DB@press_PV_Predicted', 'C.M.UF4_DB@press_PV_Predicted',
                'UF1Per_Predicted', 'UF2Per_Predicted', 'UF3Per_Predicted', 'UF4Per_Predicted',
                'C.M.RO1_DB@DPT_1_Predicted', 'C.M.RO2_DB@DPT_1_Predicted',
                'C.M.RO3_DB@DPT_1_Predicted', 'C.M.RO4_DB@DPT_1_Predicted',
                'C.M.RO1_DB@DPT_2_Predicted', 'C.M.RO2_DB@DPT_2_Predicted',
                'C.M.RO3_DB@DPT_2_Predicted', 'C.M.RO4_DB@DPT_2_Predicted'
            ]
            result_df = pd.DataFrame(res, columns=columns)
            result_df.insert(0, 'index', time_index)
            return result_df
# --- FastAPI application initialization ---
app = FastAPI(
    title="智能决策与预测 API",
    description="一个集成了GAT-LSTM TMP预测模型的 FastAPI 服务。",
    version="1.1.0"
)
# CORS middleware: fully open (all origins, methods and headers).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Global model instance (shared by all requests) ---
xishan_predict = Predictor()
logger.info("预测模型实例初始化完成")
# --- Pydantic request/response validation models ---
class TimeSeriesDataPoint(BaseModel):
    """A single time-series data point.

    Only `datetime` is mandatory; any number of additional feature
    fields (sensor readings) are accepted via `extra = "allow"`.
    """
    datetime: str

    class Config:
        extra = "allow"
class DoubleMembranceRequest(BaseModel):
    """Request body for the double-membrane model prediction endpoint."""
    data: List[TimeSeriesDataPoint]
class SuccessResponse(BaseModel):
    """Standard success response: a flag plus the list of prediction records."""
    success: bool = True
    predict_result: List[Dict[str, Any]]
# --- API endpoint definitions ---
@app.post(
    "/api/v1/process_model/double_membrance",
    response_model=SuccessResponse,
    summary="双膜环境体模型预测",
    tags=["模型处理"]
)
def get_double_membrance_model(request: DoubleMembranceRequest):
    """接收历史时序数据,使用GAT-LSTM模型进行未来趋势预测。"""
    # NOTE: the docstring above is surfaced as the OpenAPI description.
    try:
        # Concise request-start log; details only when DETAILED_LOGS is on.
        logger.info(f"开始双膜环境体模型预测 - 数据点: {len(request.data)}")

        if not request.data:
            raise HTTPException(status_code=400, detail="输入数据 'data' 不能为空")

        if DETAILED_LOGS:
            logger.info("输入数据结构分析:")
            logger.info(f" - 数据点数量: {len(request.data)}")
            logger.info(f" - 时间范围: {request.data[0].datetime} 到 {request.data[-1].datetime}")

            sample_data = request.data[0].dict()
            feature_count = len([k for k in sample_data.keys() if k != 'datetime'])
            logger.info(f" - 特征数量: {feature_count}")

        # Convert the payload into a chronologically sorted DataFrame; the
        # model expects the time column to be named 'index'.
        df = pd.DataFrame([item.dict() for item in request.data])
        df["datetime"] = pd.to_datetime(df["datetime"])
        df = df.sort_values(by="datetime").rename(columns={"datetime": "index"})

        if DETAILED_LOGS:
            logger.info(f"数据预处理完成 - 形状: {df.shape}")

        # Run the model.
        logger.info("开始模型预测...")
        res = xishan_predict.predict(df)
        logger.info(f"模型预测完成 - 结果形状: {res.shape}")

        # Predictions start one 4-minute step after the last observed timestamp.
        predict_start_time = (df['index'].max() + datetime.timedelta(minutes=4)).strftime("%Y-%m-%d %H:%M:%S")
        predict_result_df = xishan_predict.save_predictions(res, start_date=predict_start_time)

        # Format for the API response: stringify timestamps and rename the
        # time column back to 'datetime'.
        predict_result_df["index"] = predict_result_df["index"].apply(lambda x: x.strftime("%Y-%m-%d %H:%M:%S"))
        predict_result_df = predict_result_df.rename(columns={"index": "datetime"})
        predict_result = predict_result_df.to_dict(orient="records")

        logger.info(
            f"预测完成 - 预测点: {len(predict_result)}, 时间范围: {predict_result[0]['datetime']} 到 {predict_result[-1]['datetime']}")

        return {"success": True, "predict_result": predict_result}
    except HTTPException:
        # Bug fix: re-raise deliberate HTTP errors (e.g. the 400 above)
        # unchanged instead of letting the generic handler below turn
        # them into 500 responses.
        raise
    except Exception as e:
        logger.error("处理 'double_membrance' 请求时发生错误:", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@app.get(
    "/api/v1/process_model/test_double_membrance_from_file",
    response_model=SuccessResponse,
    summary="从本地文件测试双膜环境体模型预测",
    tags=["模型处理-测试"]
)
def test_double_membrance_from_file():
    """
    从本地JSON文件加载模拟数据,用于测试环境体模型预测,无需调用接口传递数据。
    """
    # Loads test_files/pp.json and runs the same pipeline as the POST endpoint.
    try:
        base_dir = os.path.dirname(os.path.abspath(__file__))
        file_path = os.path.join(base_dir, "test_files", "pp.json")

        logger.info(f"开始本地文件测试 - 文件: {file_path}")

        with open(file_path, 'r', encoding='utf-8') as f:
            request_data = json.load(f)

        if "data" not in request_data:
            raise HTTPException(status_code=400, detail=f"JSON文件 {file_path} 中缺少 'data' 键")

        input_data = request_data["data"]
        if not input_data:
            raise HTTPException(status_code=400, detail="JSON文件中的 'data' 列表不能为空")

        # Detailed structure analysis only in debug mode.
        if DETAILED_LOGS:
            logger.info("测试数据结构分析:")
            logger.info(f" - 数据点数量: {len(input_data)}")
            logger.info(f" - 时间范围: {input_data[0]['datetime']} 到 {input_data[-1]['datetime']}")

            sample_data = input_data[0]
            feature_count = len([k for k in sample_data.keys() if k != 'datetime'])
            logger.info(f" - 特征数量: {feature_count}")

        # Same preprocessing as get_double_membrance_model: sort by time and
        # rename the time column to 'index' as the model expects.
        df = pd.DataFrame(input_data)
        df["datetime"] = pd.to_datetime(df["datetime"])
        df = df.sort_values(by="datetime").rename(columns={"datetime": "index"})

        if DETAILED_LOGS:
            logger.info(f"测试数据预处理完成 - 形状: {df.shape}")

        logger.info("开始测试模型预测...")
        res = xishan_predict.predict(df)
        logger.info(f"测试预测完成 - 结果形状: {res.shape}")

        # Predictions start one 4-minute step after the last observed timestamp.
        predict_start_time = (df['index'].max() + datetime.timedelta(minutes=4)).strftime("%Y-%m-%d %H:%M:%S")
        predict_result_df = xishan_predict.save_predictions(res, start_date=predict_start_time)

        predict_result_df["index"] = predict_result_df["index"].apply(lambda x: x.strftime("%Y-%m-%d %H:%M:%S"))
        predict_result_df = predict_result_df.rename(columns={"index": "datetime"})
        predict_result = predict_result_df.to_dict(orient="records")

        logger.info(
            f"测试完成 - 预测点: {len(predict_result)}, 时间范围: {predict_result[0]['datetime']} 到 {predict_result[-1]['datetime']}")
        return {"success": True, "predict_result": predict_result}

    except HTTPException:
        # Bug fix: propagate the deliberate 400s above unchanged instead of
        # letting the generic handler re-wrap them as 500 responses.
        raise
    except FileNotFoundError:
        logger.error(f"测试文件未找到: {file_path}")
        raise HTTPException(status_code=404, detail=f"测试文件未找到: {file_path}")
    except json.JSONDecodeError:
        logger.error(f"无法解析JSON文件: {file_path}")
        raise HTTPException(status_code=400, detail=f"无法解析JSON文件,请检查格式: {file_path}")
    except Exception as e:
        logger.error("处理本地文件测试请求时发生未知错误:", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
- @app.get("/", include_in_schema=False)
- def root():
- """根路径,提供API文档链接。"""
- return {"message": "欢迎使用GAT-LSTM TMP预测 API. 请访问 /docs 查看 API 文档."}
# --- Service entry point ---
if __name__ == "__main__":
    # reload=False: the model instance must not be re-created by a reloader.
    uvicorn.run("api_main:app", host="0.0.0.0", port=7980, reload=False)
|