| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387 |
- import logging
- import shutil
- from pathlib import Path
- from typing import Optional
- from fastapi import FastAPI, HTTPException, Query, UploadFile, File
- from pydantic import BaseModel, Field
- from typing import List, Dict, Any
- from config.config_manager import ConfigManager
logger = logging.getLogger(__name__)

# FastAPI application; either mounted by the main program or run standalone.
app = FastAPI(version="1.0.0", title="拾音器配置管理 API")

# Shared ConfigManager, injected at startup by init_config_api(); None until then.
_config_mgr: Optional[ConfigManager] = None

# Shared MultiModelPredictor used by the model hot-reload endpoints, injected by
# init_config_api(); stays None when the caller does not supply one.
_multi_predictor = None
def init_config_api(config_manager: ConfigManager, multi_predictor=None):
    """Inject the shared ConfigManager and optional MultiModelPredictor.

    Meant to be called by the main program at startup; it replaces the
    module-level ``_config_mgr`` and ``_multi_predictor`` the endpoints use.

    Args:
        config_manager: Manager that every /api/config endpoint delegates to.
        multi_predictor: Predictor used by the /api/model endpoints; when left
            as None those endpoints respond with HTTP 503.
    """
    global _multi_predictor, _config_mgr
    _config_mgr, _multi_predictor = config_manager, multi_predictor
    logger.info("配置管理 API 已初始化")
def get_mgr() -> ConfigManager:
    """Return the ConfigManager injected by init_config_api().

    Raises:
        HTTPException: 500 if init_config_api() has not been called yet.
    """
    mgr = _config_mgr
    if mgr is not None:
        return mgr
    raise HTTPException(status_code=500, detail="ConfigManager 未初始化")
- # ========================================
- # Pydantic 模型(请求/响应结构)
- # ========================================
class PlantCreate(BaseModel):
    # Request body for POST /api/config/plants.
    # Required: plant name and owning project id.
    name: str = Field(default=..., description="水厂名称")
    project_id: int = Field(default=..., description="项目ID")
    # Optional: push URL (empty by default) and enabled flag (off by default).
    push_url: str = Field(default='', description="推送URL")
    enabled: bool = Field(default=False, description="是否启用")
class PlantUpdate(BaseModel):
    # Partial-update body for PUT /api/config/plants/{plant_id}.
    # Every field is optional; fields left as None are not forwarded.
    name: Optional[str] = Field(default=None)
    project_id: Optional[int] = Field(default=None)
    push_url: Optional[str] = Field(default=None)
    enabled: Optional[bool] = Field(default=None)
class StreamCreate(BaseModel):
    # Request body for POST /api/config/streams.
    # Required identity / connection fields.
    plant_id: int = Field(default=..., description="所属水厂ID")
    name: str = Field(default=..., description="设备名称")
    url: str = Field(default=..., description="RTSP URL")
    channel: int = Field(default=..., description="通道号")
    device_code: str = Field(default=..., description="设备编码")
    # Optional fields: pump name and model subdirectory default to empty strings,
    # and a new stream is enabled unless stated otherwise.
    pump_name: str = Field(default='', description="泵名称")
    model_subdir: str = Field(default='', description="模型子目录")
    enabled: bool = Field(default=True, description="是否启用")
class StreamUpdate(BaseModel):
    # Partial-update body for PUT /api/config/streams/{stream_id}.
    # Every field is optional; fields left as None are not forwarded.
    name: Optional[str] = Field(default=None)
    url: Optional[str] = Field(default=None)
    channel: Optional[int] = Field(default=None)
    device_code: Optional[str] = Field(default=None)
    pump_name: Optional[str] = Field(default=None)
    model_subdir: Optional[str] = Field(default=None)
    enabled: Optional[bool] = Field(default=None)
    # Allows moving a stream to a different plant.
    plant_id: Optional[int] = Field(default=None)
class FlowPlcItem(BaseModel):
    # Body for POST /api/config/plants/{plant_id}/flow-plc: pump name -> PLC address.
    pump_name: str = Field(default=..., description="泵名称")
    plc_address: str = Field(default=..., description="PLC地址")
class PumpStatusPlcItem(BaseModel):
    # Body for POST /api/config/plants/{plant_id}/pump-status-plc.
    pump_name: str = Field(default=..., description="泵名称")
    point: str = Field(default=..., description="PLC点位")
    # Optional human-readable label for the point; empty by default.
    point_name: str = Field(default='', description="点位名称")
class ConfigUpdate(BaseModel):
    # Generic body for PUT /api/config/{section}; the author notes it may carry
    # either a nested dict or flat key/value pairs.
    config: Dict[str, Any] = Field(default=..., description="配置字典")
class ApiResponse(BaseModel):
    # Common response envelope returned by every endpoint in this module.
    code: int = Field(default=200)
    msg: str = Field(default="success")
    # Endpoint-specific payload; None when the endpoint only reports a message.
    data: Any = Field(default=None)
- # ========================================
- # 全量配置接口
- # ========================================
@app.get("/api/config", response_model=ApiResponse, summary="获取全量配置")
def get_full_config():
    # Return whatever ConfigManager.get_full_config() assembles; per the original
    # author this mirrors the structure of the former rtsp_config.yaml.
    full_config = get_mgr().get_full_config()
    return ApiResponse(data=full_config)
- # ========================================
- # 水厂 CRUD
- # ========================================
@app.get("/api/config/plants", response_model=ApiResponse, summary="获取水厂列表")
def list_plants():
    # Return the plant list exactly as produced by ConfigManager.get_plants().
    plants = get_mgr().get_plants()
    return ApiResponse(data=plants)
@app.get("/api/config/plants/{plant_id}", response_model=ApiResponse, summary="获取单个水厂")
def get_plant(plant_id: int):
    # Fetch one plant; any falsy result from ConfigManager.get_plant() is a 404.
    found = get_mgr().get_plant(plant_id)
    if found:
        return ApiResponse(data=found)
    raise HTTPException(status_code=404, detail=f"水厂不存在: id={plant_id}")
@app.post("/api/config/plants", response_model=ApiResponse, summary="创建水厂")
def create_plant(body: PlantCreate):
    # Create a plant via ConfigManager.create_plant() and return {"id": <new id>}.
    # Any exception raised by the manager is reported to the client as HTTP 400
    # (detail = the exception text). The error is logged with its traceback and
    # chained so the server-side cause is not lost.
    mgr = get_mgr()
    try:
        plant_id = mgr.create_plant(
            name=body.name,
            project_id=body.project_id,
            push_url=body.push_url,
            enabled=body.enabled,
        )
    except Exception as e:
        logger.warning("创建水厂失败: %s", e, exc_info=True)
        raise HTTPException(status_code=400, detail=str(e)) from e
    return ApiResponse(data={"id": plant_id})
@app.put("/api/config/plants/{plant_id}", response_model=ApiResponse, summary="更新水厂")
def update_plant(plant_id: int, body: PlantUpdate):
    # Partially update a plant: only fields sent with a non-None value are
    # forwarded to ConfigManager.update_plant(); an empty update is a 400.
    mgr = get_mgr()
    changes = body.dict(exclude_none=True)
    if changes:
        mgr.update_plant(plant_id, **changes)
        return ApiResponse(msg="更新成功")
    raise HTTPException(status_code=400, detail="无有效更新字段")
@app.delete("/api/config/plants/{plant_id}", response_model=ApiResponse, summary="删除水厂")
def delete_plant(plant_id: int):
    # Delegate deletion to ConfigManager.delete_plant(); no existence check here.
    get_mgr().delete_plant(plant_id)
    return ApiResponse(msg="删除成功")
- # ========================================
- # RTSP 流 CRUD
- # ========================================
@app.get("/api/config/streams", response_model=ApiResponse, summary="获取RTSP流列表")
def list_streams(plant_id: Optional[int] = Query(None, description="按水厂ID过滤")):
    # Return RTSP streams; the optional plant_id query parameter (None when
    # absent) is passed straight through to ConfigManager.get_streams().
    streams = get_mgr().get_streams(plant_id)
    return ApiResponse(data=streams)
@app.post("/api/config/streams", response_model=ApiResponse, summary="创建RTSP流")
def create_stream(body: StreamCreate):
    # Create an RTSP stream via ConfigManager.create_stream() and return
    # {"id": <new id>}. As in create_plant, any exception from the manager is
    # reported as HTTP 400 with the exception text, logged with its traceback,
    # and chained so the original cause is preserved.
    mgr = get_mgr()
    try:
        stream_id = mgr.create_stream(
            plant_id=body.plant_id,
            name=body.name,
            url=body.url,
            channel=body.channel,
            device_code=body.device_code,
            pump_name=body.pump_name,
            model_subdir=body.model_subdir,
            enabled=body.enabled,
        )
    except Exception as e:
        logger.warning("创建RTSP流失败: %s", e, exc_info=True)
        raise HTTPException(status_code=400, detail=str(e)) from e
    return ApiResponse(data={"id": stream_id})
@app.put("/api/config/streams/{stream_id}", response_model=ApiResponse, summary="更新RTSP流")
def update_stream(stream_id: int, body: StreamUpdate):
    # Partially update a stream: only non-None fields are forwarded to
    # ConfigManager.update_stream(); an empty update is rejected with 400.
    mgr = get_mgr()
    changes = body.dict(exclude_none=True)
    if changes:
        mgr.update_stream(stream_id, **changes)
        return ApiResponse(msg="更新成功")
    raise HTTPException(status_code=400, detail="无有效更新字段")
@app.delete("/api/config/streams/{stream_id}", response_model=ApiResponse, summary="删除RTSP流")
def delete_stream(stream_id: int):
    # Delegate deletion to ConfigManager.delete_stream(); no existence check here.
    get_mgr().delete_stream(stream_id)
    return ApiResponse(msg="删除成功")
- # ========================================
- # 流量 PLC 配置
- # ========================================
@app.post("/api/config/plants/{plant_id}/flow-plc", response_model=ApiResponse,
          summary="设置流量PLC映射")
def set_flow_plc(plant_id: int, body: FlowPlcItem):
    # Store the flow PLC address for body.pump_name under the given plant.
    get_mgr().set_flow_plc(plant_id, body.pump_name, body.plc_address)
    return ApiResponse(msg="设置成功")
@app.delete("/api/config/plants/{plant_id}/flow-plc/{pump_name}", response_model=ApiResponse,
            summary="删除流量PLC映射")
def delete_flow_plc(plant_id: int, pump_name: str):
    # Remove the flow PLC mapping of pump_name from the given plant.
    get_mgr().delete_flow_plc(plant_id, pump_name)
    return ApiResponse(msg="删除成功")
- # ========================================
- # 泵状态 PLC 点位
- # ========================================
@app.post("/api/config/plants/{plant_id}/pump-status-plc", response_model=ApiResponse,
          summary="添加泵状态PLC点位")
def add_pump_status_plc(plant_id: int, body: PumpStatusPlcItem):
    # Add a pump-status PLC point to the plant and return {"id": <new point id>}.
    new_id = get_mgr().add_pump_status_plc(
        plant_id, body.pump_name, body.point, body.point_name
    )
    return ApiResponse(data={"id": new_id})
@app.delete("/api/config/pump-status-plc/{plc_id}", response_model=ApiResponse,
            summary="删除泵状态PLC点位")
def delete_pump_status_plc(plc_id: int):
    # Delete a pump-status PLC point by its own id (not scoped to a plant).
    get_mgr().delete_pump_status_plc(plc_id)
    return ApiResponse(msg="删除成功")
- # ========================================
- # 系统级配置(audio, prediction, push_notification 等)
- # ========================================
# System-level config sections that may be read or written through the generic
# /api/config/{section} endpoints. Shared by GET and PUT so the two lists
# cannot drift apart.
_ALLOWED_SECTIONS = frozenset(
    {'audio', 'prediction', 'push_notification', 'scada_api', 'human_detection'}
)


def _check_section(section: str) -> None:
    """Raise HTTP 400 unless *section* is one of ``_ALLOWED_SECTIONS``."""
    if section not in _ALLOWED_SECTIONS:
        raise HTTPException(status_code=400, detail=f"不支持的配置区域: {section}")


# NOTE: these catch-all routes are registered after the literal
# /api/config/plants and /api/config/streams routes. Starlette matches routes in
# registration order, so those literal paths are not shadowed.
@app.get("/api/config/{section}", response_model=ApiResponse,
         summary="获取指定section的系统配置")
def get_section_config(section: str):
    # Return one whitelisted system-config section from ConfigManager.
    _check_section(section)
    mgr = get_mgr()
    return ApiResponse(data=mgr.get_system_config(section))


@app.put("/api/config/{section}", response_model=ApiResponse,
         summary="更新指定section的系统配置")
def update_section_config(section: str, body: ConfigUpdate):
    # Forward body.config unchanged to ConfigManager.update_section_config()
    # for a whitelisted section.
    _check_section(section)
    mgr = get_mgr()
    mgr.update_section_config(section, body.config)
    return ApiResponse(msg="更新成功")
- # ========================================
- # 模型管理 API
- # ========================================
@app.get("/api/model/status", response_model=ApiResponse, summary="获取模型加载状态")
def get_model_status():
    # Report registered / loaded / failed devices of the injected predictor;
    # 503 when no predictor was supplied to init_config_api().
    predictor = _multi_predictor
    if predictor is None:
        raise HTTPException(status_code=503, detail="模型预测器未初始化")
    status = {
        "registered": predictor.registered_devices,
        "loaded": predictor.loaded_devices,
        # NOTE(review): reads the predictor's private _failed_devices mapping;
        # a public accessor on MultiModelPredictor would be less fragile.
        "failed": list(predictor._failed_devices.keys()),
    }
    return ApiResponse(data=status)
@app.post("/api/model/reload/{device_code}", response_model=ApiResponse,
          summary="重载指定设备的模型")
def reload_device_model(device_code: str):
    # Hot-reload one device's model (intended to be called after replacing the
    # files under models/{device_code}/). A falsy reload_device() result is a 500.
    predictor = _multi_predictor
    if predictor is None:
        raise HTTPException(status_code=503, detail="模型预测器未初始化")
    if not predictor.reload_device(device_code):
        raise HTTPException(status_code=500, detail=f"设备 {device_code} 模型重载失败")
    return ApiResponse(msg=f"设备 {device_code} 模型重载成功")
@app.post("/api/model/reload-all", response_model=ApiResponse,
          summary="重载所有已注册设备的模型")
def reload_all_models():
    # Reload every registered device and return {device_code: reload_device() result}.
    predictor = _multi_predictor
    if predictor is None:
        raise HTTPException(status_code=503, detail="模型预测器未初始化")
    outcome = {
        code: predictor.reload_device(code)
        for code in predictor.registered_devices
    }
    return ApiResponse(data=outcome)
def _resolve_device_dir(model_root, device_code: str) -> Path:
    """Return ``Path(model_root) / device_code`` if *device_code* is a plain name.

    *device_code* comes straight from the URL path, so it is only used as a
    directory name when it is a single path component. ``""``, ``.``, ``..`` and
    anything the current OS parses as containing a separator or drive are
    rejected, which keeps uploads inside *model_root*.

    Raises:
        HTTPException: 400 if *device_code* is not a plain single component.
    """
    if device_code in ("", ".", "..") or Path(device_code).name != device_code:
        raise HTTPException(status_code=400, detail=f"非法的设备编码: {device_code}")
    return Path(model_root) / device_code


@app.post("/api/model/upload/{device_code}", response_model=ApiResponse,
          summary="上传模型文件并重载")
async def upload_model(
    device_code: str,
    model_file: UploadFile = File(None, description="ae_model.pth 模型权重"),
    scale_file: UploadFile = File(None, description="global_scale.npy 标准化参数"),
    threshold_file: UploadFile = File(None, description="threshold.npy 阈值文件")
):
    """Upload model files for one device and hot-reload its model.

    Any combination of the three files may be sent (at least one is required):

    - ``model_file``     -> ``<MODEL_ROOT>/<device_code>/ae_model.pth``
    - ``scale_file``     -> ``<MODEL_ROOT>/<device_code>/global_scale.npy``
    - ``threshold_file`` -> ``<MODEL_ROOT>/<device_code>/thresholds/threshold_<device_code>.npy``

    Each upload is first written to a ``*.uploading`` file beside its
    destination and then renamed over it. A failed request therefore never
    leaves a half-written model file in place. The swap is not all-or-nothing
    across several files if a rename itself fails. After saving, the device's
    model is reloaded and the outcome is reported in ``data["reloaded"]``.

    Raises:
        HTTPException: 503 if no predictor was injected; 400 if no usable file
            was sent or *device_code* is not a plain name; 500 if saving fails.
    """
    predictor = _multi_predictor
    if predictor is None:
        raise HTTPException(status_code=503, detail="模型预测器未初始化")
    # At least one file must be uploaded.
    if not any([model_file, scale_file, threshold_file]):
        raise HTTPException(status_code=400, detail="至少需要上传一个文件")

    # Local imports, following the original's local import of CFG.
    from fastapi.concurrency import run_in_threadpool
    from predictor.config import CFG

    device_dir = _resolve_device_dir(CFG.MODEL_ROOT, device_code)
    candidates = (
        (model_file, device_dir / "ae_model.pth"),
        (scale_file, device_dir / "global_scale.npy"),
        (threshold_file, device_dir / "thresholds" / f"threshold_{device_code}.npy"),
    )
    # Only parts that actually carry a file (non-empty filename) are saved.
    targets = [(upload, dest) for upload, dest in candidates if upload and upload.filename]
    if not targets:
        raise HTTPException(status_code=400, detail="至少需要上传一个文件")

    staged = []       # (temp_path, final_path) for every temp file that may exist
    saved_files = []
    try:
        # Phase 1: write every upload to a temp file next to its destination.
        for upload, dest in targets:
            dest.parent.mkdir(parents=True, exist_ok=True)
            tmp = dest.with_name(dest.name + ".uploading")
            staged.append((tmp, dest))  # record first so a partially written temp is cleaned up
            content = await upload.read()
            # Disk write off the event-loop thread.
            await run_in_threadpool(tmp.write_bytes, content)
        # Phase 2: move the temps into place. Path.replace overwrites the target,
        # and a rename within one directory is atomic on POSIX.
        for tmp, dest in staged:
            tmp.replace(dest)
            saved_files.append(str(dest))
    except Exception as e:
        for tmp, _ in staged:
            try:
                tmp.unlink()
            except OSError:
                pass  # best effort: temp already moved into place or never created
        logger.exception("设备 %s 模型文件保存失败", device_code)
        raise HTTPException(status_code=500, detail=f"文件保存失败: {e}") from e

    # reload_device() is a plain synchronous call. Invoking it directly inside
    # this async handler would run it on the event-loop thread and stall every
    # other request, so run it in the thread pool (where FastAPI already runs the
    # sync reload endpoints).
    reload_ok = await run_in_threadpool(predictor.reload_device, device_code)
    return ApiResponse(
        msg=f"模型上传成功,重载{'成功' if reload_ok else '失败'}",
        data={
            "device_code": device_code,
            "saved_files": saved_files,
            "reloaded": reload_ok
        }
    )
|