config_api.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
  1. import logging
  2. import shutil
  3. from pathlib import Path
  4. from typing import Optional
  5. from fastapi import FastAPI, HTTPException, Query, UploadFile, File
  6. from pydantic import BaseModel, Field
  7. from typing import List, Dict, Any
  8. from config.config_manager import ConfigManager
logger = logging.getLogger(__name__)
# FastAPI instance (mounted by the main program or run standalone)
app = FastAPI(title="拾音器配置管理 API", version="1.0.0")
# Global ConfigManager instance (injected by init_config_api)
_config_mgr: Optional[ConfigManager] = None
# Global MultiModelPredictor instance (injected by init_config_api; used by the model hot-reload API)
_multi_predictor = None
  16. def init_config_api(config_manager: ConfigManager, multi_predictor=None):
  17. # 在主程序启动时调用,注入 ConfigManager 和 MultiModelPredictor 实例
  18. global _config_mgr, _multi_predictor
  19. _config_mgr = config_manager
  20. _multi_predictor = multi_predictor
  21. logger.info("配置管理 API 已初始化")
  22. def get_mgr() -> ConfigManager:
  23. if _config_mgr is None:
  24. raise HTTPException(status_code=500, detail="ConfigManager 未初始化")
  25. return _config_mgr
  26. # ========================================
  27. # Pydantic 模型(请求/响应结构)
  28. # ========================================
class PlantCreate(BaseModel):
    # Request body accepted by create_plant (POST /api/config/plants).
    # name and project_id are required; push_url defaults to an empty
    # string and enabled defaults to False.
    name: str = Field(..., description="水厂名称")  # plant name
    project_id: int = Field(..., description="项目ID")  # project ID
    push_url: str = Field('', description="推送URL")  # push URL
    enabled: bool = Field(False, description="是否启用")  # enabled flag
class PlantUpdate(BaseModel):
    # Request body accepted by update_plant. Every field is optional and
    # defaults to None; update_plant forwards only the non-None fields.
    name: Optional[str] = None
    project_id: Optional[int] = None
    push_url: Optional[str] = None
    enabled: Optional[bool] = None
class StreamCreate(BaseModel):
    # Request body accepted by create_stream (POST /api/config/streams).
    # pump_name and model_subdir default to an empty string; enabled
    # defaults to True; all other fields are required.
    plant_id: int = Field(..., description="所属水厂ID")  # owning plant ID
    name: str = Field(..., description="设备名称")  # device name
    url: str = Field(..., description="RTSP URL")
    channel: int = Field(..., description="通道号")  # channel number
    device_code: str = Field(..., description="设备编码")  # device code
    pump_name: str = Field('', description="泵名称")  # pump name
    model_subdir: str = Field('', description="模型子目录")  # model sub-directory
    enabled: bool = Field(True, description="是否启用")  # enabled flag
class StreamUpdate(BaseModel):
    # Request body accepted by update_stream. Every field is optional and
    # defaults to None; update_stream forwards only the non-None fields.
    name: Optional[str] = None
    url: Optional[str] = None
    channel: Optional[int] = None
    device_code: Optional[str] = None
    pump_name: Optional[str] = None
    model_subdir: Optional[str] = None
    enabled: Optional[bool] = None
    plant_id: Optional[int] = None
class FlowPlcItem(BaseModel):
    # Request body accepted by set_flow_plc: one pump-name -> PLC-address
    # mapping. Both fields are required.
    pump_name: str = Field(..., description="泵名称")  # pump name
    plc_address: str = Field(..., description="PLC地址")  # PLC address
class PumpStatusPlcItem(BaseModel):
    # Request body accepted by add_pump_status_plc. pump_name and point
    # are required; point_name defaults to an empty string.
    pump_name: str = Field(..., description="泵名称")  # pump name
    point: str = Field(..., description="PLC点位")  # PLC point
    point_name: str = Field('', description="点位名称")  # point display name
class ConfigUpdate(BaseModel):
    # Generic config-update model: accepts a nested dict or flat key/value pairs.
    config: Dict[str, Any] = Field(..., description="配置字典")  # configuration dict
class ApiResponse(BaseModel):
    # Response envelope used as response_model by every route in this file.
    # Defaults: code=200, msg="success", data=None.
    code: int = 200
    msg: str = "success"
    data: Any = None
  71. # ========================================
  72. # 全量配置接口
  73. # ========================================
  74. @app.get("/api/config", response_model=ApiResponse, summary="获取全量配置")
  75. def get_full_config():
  76. # 返回与原 rtsp_config.yaml 结构一致的完整配置
  77. mgr = get_mgr()
  78. return ApiResponse(data=mgr.get_full_config())
  79. # ========================================
  80. # 水厂 CRUD
  81. # ========================================
  82. @app.get("/api/config/plants", response_model=ApiResponse, summary="获取水厂列表")
  83. def list_plants():
  84. mgr = get_mgr()
  85. return ApiResponse(data=mgr.get_plants())
  86. @app.get("/api/config/plants/{plant_id}", response_model=ApiResponse, summary="获取单个水厂")
  87. def get_plant(plant_id: int):
  88. mgr = get_mgr()
  89. plant = mgr.get_plant(plant_id)
  90. if not plant:
  91. raise HTTPException(status_code=404, detail=f"水厂不存在: id={plant_id}")
  92. return ApiResponse(data=plant)
  93. @app.post("/api/config/plants", response_model=ApiResponse, summary="创建水厂")
  94. def create_plant(body: PlantCreate):
  95. mgr = get_mgr()
  96. try:
  97. plant_id = mgr.create_plant(
  98. name=body.name,
  99. project_id=body.project_id,
  100. push_url=body.push_url,
  101. enabled=body.enabled
  102. )
  103. return ApiResponse(data={"id": plant_id})
  104. except Exception as e:
  105. raise HTTPException(status_code=400, detail=str(e))
  106. @app.put("/api/config/plants/{plant_id}", response_model=ApiResponse, summary="更新水厂")
  107. def update_plant(plant_id: int, body: PlantUpdate):
  108. mgr = get_mgr()
  109. # 只传递非 None 的字段
  110. updates = body.dict(exclude_none=True)
  111. if not updates:
  112. raise HTTPException(status_code=400, detail="无有效更新字段")
  113. mgr.update_plant(plant_id, **updates)
  114. return ApiResponse(msg="更新成功")
  115. @app.delete("/api/config/plants/{plant_id}", response_model=ApiResponse, summary="删除水厂")
  116. def delete_plant(plant_id: int):
  117. mgr = get_mgr()
  118. mgr.delete_plant(plant_id)
  119. return ApiResponse(msg="删除成功")
  120. # ========================================
  121. # RTSP 流 CRUD
  122. # ========================================
  123. @app.get("/api/config/streams", response_model=ApiResponse, summary="获取RTSP流列表")
  124. def list_streams(plant_id: Optional[int] = Query(None, description="按水厂ID过滤")):
  125. mgr = get_mgr()
  126. return ApiResponse(data=mgr.get_streams(plant_id))
  127. @app.post("/api/config/streams", response_model=ApiResponse, summary="创建RTSP流")
  128. def create_stream(body: StreamCreate):
  129. mgr = get_mgr()
  130. try:
  131. stream_id = mgr.create_stream(
  132. plant_id=body.plant_id,
  133. name=body.name,
  134. url=body.url,
  135. channel=body.channel,
  136. device_code=body.device_code,
  137. pump_name=body.pump_name,
  138. model_subdir=body.model_subdir,
  139. enabled=body.enabled
  140. )
  141. return ApiResponse(data={"id": stream_id})
  142. except Exception as e:
  143. raise HTTPException(status_code=400, detail=str(e))
  144. @app.put("/api/config/streams/{stream_id}", response_model=ApiResponse, summary="更新RTSP流")
  145. def update_stream(stream_id: int, body: StreamUpdate):
  146. mgr = get_mgr()
  147. updates = body.dict(exclude_none=True)
  148. if not updates:
  149. raise HTTPException(status_code=400, detail="无有效更新字段")
  150. mgr.update_stream(stream_id, **updates)
  151. return ApiResponse(msg="更新成功")
  152. @app.delete("/api/config/streams/{stream_id}", response_model=ApiResponse, summary="删除RTSP流")
  153. def delete_stream(stream_id: int):
  154. mgr = get_mgr()
  155. mgr.delete_stream(stream_id)
  156. return ApiResponse(msg="删除成功")
  157. # ========================================
  158. # 流量 PLC 配置
  159. # ========================================
  160. @app.post("/api/config/plants/{plant_id}/flow-plc", response_model=ApiResponse,
  161. summary="设置流量PLC映射")
  162. def set_flow_plc(plant_id: int, body: FlowPlcItem):
  163. mgr = get_mgr()
  164. mgr.set_flow_plc(plant_id, body.pump_name, body.plc_address)
  165. return ApiResponse(msg="设置成功")
  166. @app.delete("/api/config/plants/{plant_id}/flow-plc/{pump_name}", response_model=ApiResponse,
  167. summary="删除流量PLC映射")
  168. def delete_flow_plc(plant_id: int, pump_name: str):
  169. mgr = get_mgr()
  170. mgr.delete_flow_plc(plant_id, pump_name)
  171. return ApiResponse(msg="删除成功")
  172. # ========================================
  173. # 泵状态 PLC 点位
  174. # ========================================
  175. @app.post("/api/config/plants/{plant_id}/pump-status-plc", response_model=ApiResponse,
  176. summary="添加泵状态PLC点位")
  177. def add_pump_status_plc(plant_id: int, body: PumpStatusPlcItem):
  178. mgr = get_mgr()
  179. plc_id = mgr.add_pump_status_plc(plant_id, body.pump_name, body.point, body.point_name)
  180. return ApiResponse(data={"id": plc_id})
  181. @app.delete("/api/config/pump-status-plc/{plc_id}", response_model=ApiResponse,
  182. summary="删除泵状态PLC点位")
  183. def delete_pump_status_plc(plc_id: int):
  184. mgr = get_mgr()
  185. mgr.delete_pump_status_plc(plc_id)
  186. return ApiResponse(msg="删除成功")
  187. # ========================================
  188. # 系统级配置(audio, prediction, push_notification 等)
  189. # ========================================
  190. @app.get("/api/config/{section}", response_model=ApiResponse,
  191. summary="获取指定section的系统配置")
  192. def get_section_config(section: str):
  193. # 限制只允许合法的 section 名
  194. allowed_sections = {'audio', 'prediction', 'push_notification', 'scada_api', 'human_detection'}
  195. if section not in allowed_sections:
  196. raise HTTPException(status_code=400, detail=f"不支持的配置区域: {section}")
  197. mgr = get_mgr()
  198. return ApiResponse(data=mgr.get_system_config(section))
  199. @app.put("/api/config/{section}", response_model=ApiResponse,
  200. summary="更新指定section的系统配置")
  201. def update_section_config(section: str, body: ConfigUpdate):
  202. allowed_sections = {'audio', 'prediction', 'push_notification', 'scada_api', 'human_detection'}
  203. if section not in allowed_sections:
  204. raise HTTPException(status_code=400, detail=f"不支持的配置区域: {section}")
  205. mgr = get_mgr()
  206. mgr.update_section_config(section, body.config)
  207. return ApiResponse(msg="更新成功")
  208. # ========================================
  209. # 模型管理 API
  210. # ========================================
  211. @app.get("/api/model/status", response_model=ApiResponse, summary="获取模型加载状态")
  212. def get_model_status():
  213. # 返回各设备的模型加载状态
  214. if _multi_predictor is None:
  215. raise HTTPException(status_code=503, detail="模型预测器未初始化")
  216. return ApiResponse(data={
  217. "registered": _multi_predictor.registered_devices,
  218. "loaded": _multi_predictor.loaded_devices,
  219. "failed": list(_multi_predictor._failed_devices.keys())
  220. })
  221. @app.post("/api/model/reload/{device_code}", response_model=ApiResponse,
  222. summary="重载指定设备的模型")
  223. def reload_device_model(device_code: str):
  224. # 触发指定设备的模型热加载(替换 models/{device_code}/ 下的文件后调用此接口)
  225. if _multi_predictor is None:
  226. raise HTTPException(status_code=503, detail="模型预测器未初始化")
  227. success = _multi_predictor.reload_device(device_code)
  228. if success:
  229. return ApiResponse(msg=f"设备 {device_code} 模型重载成功")
  230. else:
  231. raise HTTPException(status_code=500, detail=f"设备 {device_code} 模型重载失败")
  232. @app.post("/api/model/reload-all", response_model=ApiResponse,
  233. summary="重载所有已注册设备的模型")
  234. def reload_all_models():
  235. # 触发所有已注册设备的模型重载
  236. if _multi_predictor is None:
  237. raise HTTPException(status_code=503, detail="模型预测器未初始化")
  238. results = {}
  239. for device_code in _multi_predictor.registered_devices:
  240. results[device_code] = _multi_predictor.reload_device(device_code)
  241. return ApiResponse(data=results)
  242. @app.post("/api/model/upload/{device_code}", response_model=ApiResponse,
  243. summary="上传模型文件并重载")
  244. async def upload_model(
  245. device_code: str,
  246. model_file: UploadFile = File(None, description="ae_model.pth 模型权重"),
  247. scale_file: UploadFile = File(None, description="global_scale.npy 标准化参数"),
  248. threshold_file: UploadFile = File(None, description="threshold.npy 阈值文件")
  249. ):
  250. """
  251. 上传模型文件到 models/{device_code}/ 目录
  252. 支持三种文件(可单独或组合上传):
  253. - model_file: ae_model.pth
  254. - scale_file: global_scale.npy
  255. - threshold_file: threshold_{device_code}.npy
  256. 上传完成后自动触发模型重载。
  257. """
  258. if _multi_predictor is None:
  259. raise HTTPException(status_code=503, detail="模型预测器未初始化")
  260. # 至少上传一个文件
  261. if not any([model_file, scale_file, threshold_file]):
  262. raise HTTPException(status_code=400, detail="至少需要上传一个文件")
  263. # 确定模型目录
  264. from predictor.config import CFG
  265. device_dir = CFG.MODEL_ROOT / device_code
  266. device_dir.mkdir(parents=True, exist_ok=True)
  267. saved_files = []
  268. try:
  269. # 保存模型权重
  270. if model_file and model_file.filename:
  271. dest = device_dir / "ae_model.pth"
  272. content = await model_file.read()
  273. dest.write_bytes(content)
  274. saved_files.append(str(dest))
  275. # 保存标准化参数
  276. if scale_file and scale_file.filename:
  277. dest = device_dir / "global_scale.npy"
  278. content = await scale_file.read()
  279. dest.write_bytes(content)
  280. saved_files.append(str(dest))
  281. # 保存阈值
  282. if threshold_file and threshold_file.filename:
  283. thr_dir = device_dir / "thresholds"
  284. thr_dir.mkdir(parents=True, exist_ok=True)
  285. dest = thr_dir / f"threshold_{device_code}.npy"
  286. content = await threshold_file.read()
  287. dest.write_bytes(content)
  288. saved_files.append(str(dest))
  289. except Exception as e:
  290. raise HTTPException(status_code=500, detail=f"文件保存失败: {e}")
  291. # 上传完成后自动重载
  292. reload_ok = _multi_predictor.reload_device(device_code)
  293. return ApiResponse(
  294. msg=f"模型上传成功,重载{'成功' if reload_ok else '失败'}",
  295. data={
  296. "device_code": device_code,
  297. "saved_files": saved_files,
  298. "reloaded": reload_ok
  299. }
  300. )