# api_main.py
"""
GAT-LSTM TMP prediction model - FastAPI service.

Version: 1.1.0
Last updated: 2025-10-29
Provides the API service for 20-minute short-term TMP prediction.
"""
import os
import sys
import logging
from logging.handlers import RotatingFileHandler
import datetime
import json
import pandas as pd
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Dict, Any
  19. # --- 日志配置 ---
  20. # 日志保存在logs目录下
  21. log_dir = os.path.join(os.path.dirname(__file__), 'logs')
  22. os.makedirs(log_dir, exist_ok=True)
  23. log_handler = RotatingFileHandler(
  24. os.path.join(log_dir, "api.log"),
  25. maxBytes=2 * 1024 * 1024, # 2 MB
  26. backupCount=5,
  27. encoding='utf-8'
  28. )
  29. # 支持环境变量控制日志详细程度
  30. LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO').upper()
  31. DETAILED_LOGS = os.getenv('DETAILED_LOGS', 'false').lower() == 'true'
  32. # 避免重复配置日志处理器
  33. if not logging.getLogger().handlers:
  34. logging.basicConfig(
  35. level=getattr(logging, LOG_LEVEL),
  36. format='%(asctime)s - %(levelname)s - %(message)s',
  37. handlers=[
  38. log_handler,
  39. logging.StreamHandler()
  40. ]
  41. )
  42. logger = logging.getLogger(__name__)
  43. # --- 添加当前目录到Python路径 ---
  44. current_dir = os.path.dirname(os.path.abspath(__file__))
  45. if current_dir not in sys.path:
  46. sys.path.insert(0, current_dir)
  47. # --- 模型导入与模拟 ---
  48. # 优先尝试导入真实模型,如果失败则使用模拟类(Mock Class)替代
  49. try:
  50. # 使用importlib动态导入(因为模块名以数字开头)
  51. import importlib.util
  52. predict_module_path = os.path.join(current_dir, '20min', 'predict.py')
  53. spec = importlib.util.spec_from_file_location("predict_20min", predict_module_path)
  54. predict_module = importlib.util.module_from_spec(spec)
  55. spec.loader.exec_module(predict_module)
  56. Predictor = predict_module.Predictor
  57. logger.info("成功加载20分钟TMP预测模型模块。")
  58. except Exception as e:
  59. logger.warning(f"未能找到20分钟模型模块: {e},将使用模拟类进行替代。")
  60. logger.warning("请确保模型模块路径正确。")
  61. class Predictor:
  62. """模拟预测器"""
  63. def predict(self, df: pd.DataFrame):
  64. logger.info("正在使用模拟的 Predictor.predict 方法...")
  65. import numpy as np
  66. # 模拟返回5个时间步,16个特征的预测结果
  67. return np.random.rand(5, 16) * 100
  68. def save_predictions(self, res, start_date: str, output_path: str = None) -> pd.DataFrame:
  69. logger.info("正在使用模拟的 Predictor.save_predictions 方法...")
  70. start_dt = datetime.datetime.strptime(start_date, "%Y-%m-%d %H:%M:%S")
  71. time_index = [start_dt + datetime.timedelta(minutes=4 * i) for i in range(len(res))]
  72. # 创建包含16个预测列的DataFrame
  73. columns = [
  74. 'C.M.UF1_DB@press_PV_Predicted', 'C.M.UF2_DB@press_PV_Predicted',
  75. 'C.M.UF3_DB@press_PV_Predicted', 'C.M.UF4_DB@press_PV_Predicted',
  76. 'UF1Per_Predicted', 'UF2Per_Predicted', 'UF3Per_Predicted', 'UF4Per_Predicted',
  77. 'C.M.RO1_DB@DPT_1_Predicted', 'C.M.RO2_DB@DPT_1_Predicted',
  78. 'C.M.RO3_DB@DPT_1_Predicted', 'C.M.RO4_DB@DPT_1_Predicted',
  79. 'C.M.RO1_DB@DPT_2_Predicted', 'C.M.RO2_DB@DPT_2_Predicted',
  80. 'C.M.RO3_DB@DPT_2_Predicted', 'C.M.RO4_DB@DPT_2_Predicted'
  81. ]
  82. result_df = pd.DataFrame(res, columns=columns)
  83. result_df.insert(0, 'index', time_index)
  84. return result_df
# --- FastAPI application initialization ---
app = FastAPI(
    title="智能决策与预测 API",
    description="一个集成了GAT-LSTM TMP预测模型的 FastAPI 服务。",
    version="1.1.0"
)
# CORS middleware: allow any origin, method, and header (fully open API).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- Global model instance (shared by all request handlers) ---
xishan_predict = Predictor()
logger.info("预测模型实例初始化完成")
# --- Pydantic request/response validation models ---
class TimeSeriesDataPoint(BaseModel):
    """One time-series sample: a 'datetime' string plus arbitrary extra feature fields."""
    datetime: str

    class Config:
        # Accept any additional (feature) fields besides "datetime".
        extra = "allow"
class DoubleMembranceRequest(BaseModel):
    """Request body for the double-membrane model prediction endpoint."""
    data: List[TimeSeriesDataPoint]
class SuccessResponse(BaseModel):
    """Standard success response envelope."""
    success: bool = True
    predict_result: List[Dict[str, Any]]
  115. # --- API 端点定义 ---
  116. @app.post(
  117. "/api/v1/process_model/double_membrance",
  118. response_model=SuccessResponse,
  119. summary="双膜环境体模型预测",
  120. tags=["模型处理"]
  121. )
  122. def get_double_membrance_model(request: DoubleMembranceRequest):
  123. """接收历史时序数据,使用GAT-LSTM模型进行未来趋势预测。"""
  124. try:
  125. # 精简的请求开始日志
  126. logger.info(f"开始双膜环境体模型预测 - 数据点: {len(request.data)}")
  127. if not request.data:
  128. raise HTTPException(status_code=400, detail="输入数据 'data' 不能为空")
  129. # 详细日志仅在调试模式下记录
  130. if DETAILED_LOGS:
  131. logger.info("输入数据结构分析:")
  132. logger.info(f" - 数据点数量: {len(request.data)}")
  133. logger.info(f" - 时间范围: {request.data[0].datetime} 到 {request.data[-1].datetime}")
  134. sample_data = request.data[0].dict()
  135. feature_count = len([k for k in sample_data.keys() if k != 'datetime'])
  136. logger.info(f" - 特征数量: {feature_count}")
  137. # 将输入数据转换为DataFrame并进行预处理
  138. df = pd.DataFrame([item.dict() for item in request.data])
  139. df["datetime"] = pd.to_datetime(df["datetime"])
  140. df = df.sort_values(by="datetime").rename(columns={"datetime": "index"})
  141. if DETAILED_LOGS:
  142. logger.info(f"数据预处理完成 - 形状: {df.shape}")
  143. # 调用模型进行预测
  144. logger.info("开始模型预测...")
  145. res = xishan_predict.predict(df)
  146. logger.info(f"模型预测完成 - 结果形状: {res.shape}")
  147. predict_start_time = (df['index'].max() + datetime.timedelta(minutes=4)).strftime("%Y-%m-%d %H:%M:%S")
  148. predict_result_df = xishan_predict.save_predictions(res, start_date=predict_start_time)
  149. # 格式化预测结果以符合API输出
  150. predict_result_df["index"] = predict_result_df["index"].apply(lambda x: x.strftime("%Y-%m-%d %H:%M:%S"))
  151. predict_result_df = predict_result_df.rename(columns={"index": "datetime"})
  152. predict_result = predict_result_df.to_dict(orient="records")
  153. # 精简的输出日志
  154. logger.info(
  155. f"预测完成 - 预测点: {len(predict_result)}, 时间范围: {predict_result[0]['datetime']} 到 {predict_result[-1]['datetime']}")
  156. return {"success": True, "predict_result": predict_result}
  157. except Exception as e:
  158. logger.error("处理 'double_membrance' 请求时发生错误:", exc_info=True)
  159. raise HTTPException(status_code=500, detail=str(e))
  160. @app.get(
  161. "/api/v1/process_model/test_double_membrance_from_file",
  162. response_model=SuccessResponse,
  163. summary="从本地文件测试双膜环境体模型预测",
  164. tags=["模型处理-测试"]
  165. )
  166. def test_double_membrance_from_file():
  167. """
  168. 从本地JSON文件加载模拟数据,用于测试环境体模型预测,无需调用接口传递数据。
  169. """
  170. try:
  171. base_dir = os.path.dirname(os.path.abspath(__file__))
  172. file_path = os.path.join(base_dir, "test_files", "pp.json")
  173. logger.info(f"开始本地文件测试 - 文件: {file_path}")
  174. with open(file_path, 'r', encoding='utf-8') as f:
  175. request_data = json.load(f)
  176. if "data" not in request_data:
  177. raise HTTPException(status_code=400, detail=f"JSON文件 {file_path} 中缺少 'data' 键")
  178. input_data = request_data["data"]
  179. if not input_data:
  180. raise HTTPException(status_code=400, detail="JSON文件中的 'data' 列表不能为空")
  181. # 详细日志仅在调试模式下记录
  182. if DETAILED_LOGS:
  183. logger.info("测试数据结构分析:")
  184. logger.info(f" - 数据点数量: {len(input_data)}")
  185. logger.info(f" - 时间范围: {input_data[0]['datetime']} 到 {input_data[-1]['datetime']}")
  186. sample_data = input_data[0]
  187. feature_count = len([k for k in sample_data.keys() if k != 'datetime'])
  188. logger.info(f" - 特征数量: {feature_count}")
  189. # 后续逻辑与 get_double_membrance_model 相同
  190. df = pd.DataFrame(input_data)
  191. df["datetime"] = pd.to_datetime(df["datetime"])
  192. df = df.sort_values(by="datetime").rename(columns={"datetime": "index"})
  193. if DETAILED_LOGS:
  194. logger.info(f"测试数据预处理完成 - 形状: {df.shape}")
  195. logger.info("开始测试模型预测...")
  196. res = xishan_predict.predict(df)
  197. logger.info(f"测试预测完成 - 结果形状: {res.shape}")
  198. predict_start_time = (df['index'].max() + datetime.timedelta(minutes=4)).strftime("%Y-%m-%d %H:%M:%S")
  199. predict_result_df = xishan_predict.save_predictions(res, start_date=predict_start_time)
  200. predict_result_df["index"] = predict_result_df["index"].apply(lambda x: x.strftime("%Y-%m-%d %H:%M:%S"))
  201. predict_result_df = predict_result_df.rename(columns={"index": "datetime"})
  202. predict_result = predict_result_df.to_dict(orient="records")
  203. logger.info(
  204. f"测试完成 - 预测点: {len(predict_result)}, 时间范围: {predict_result[0]['datetime']} 到 {predict_result[-1]['datetime']}")
  205. return {"success": True, "predict_result": predict_result}
  206. except FileNotFoundError:
  207. logger.error(f"测试文件未找到: {file_path}")
  208. raise HTTPException(status_code=404, detail=f"测试文件未找到: {file_path}")
  209. except json.JSONDecodeError:
  210. logger.error(f"无法解析JSON文件: {file_path}")
  211. raise HTTPException(status_code=400, detail=f"无法解析JSON文件,请检查格式: {file_path}")
  212. except Exception as e:
  213. logger.error("处理本地文件测试请求时发生未知错误:", exc_info=True)
  214. raise HTTPException(status_code=500, detail=str(e))
  215. @app.get("/", include_in_schema=False)
  216. def root():
  217. """根路径,提供API文档链接。"""
  218. return {"message": "欢迎使用GAT-LSTM TMP预测 API. 请访问 /docs 查看 API 文档."}
# --- Service startup entry point ---
if __name__ == "__main__":
    # Listen on all interfaces on port 7980; auto-reload disabled.
    uvicorn.run("api_main:app", host="0.0.0.0", port=7980, reload=False)