config_manager.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. import sqlite3
  2. import json
  3. import logging
  4. import threading
  5. from pathlib import Path
  6. from typing import Any, Dict, List, Optional
  7. from config.db_models import get_connection, get_db_path, init_db
  8. logger = logging.getLogger(__name__)
  9. class ConfigManager:
  10. # 统一的配置管理器,封装 SQLite 读写,对外提供与原 YAML dict 兼容的接口
  11. # 线程安全:每个线程使用独立的 SQLite 连接
  12. def __init__(self, db_path: Optional[str] = None):
  13. # db_path 为空时默认使用 config/pickup_config.db
  14. self._db_path = Path(db_path) if db_path else get_db_path()
  15. # 线程本地存储:每个线程持有独立连接,避免跨线程共享 sqlite3.Connection
  16. self._local = threading.local()
  17. # 确保表结构已创建
  18. init_db(self._db_path)
  19. logger.info(f"ConfigManager 初始化完成: {self._db_path}")
  20. @property
  21. def _conn(self):
  22. # 线程安全的连接获取:每个线程首次访问时创建独立连接
  23. if not hasattr(self._local, 'conn') or self._local.conn is None:
  24. self._local.conn = get_connection(self._db_path)
  25. return self._local.conn
  26. def close(self):
  27. # 关闭当前线程的连接
  28. if hasattr(self._local, 'conn') and self._local.conn:
  29. self._local.conn.close()
  30. self._local.conn = None
  31. # ========================================
  32. # 兼容层:返回与原 YAML dict 格式一致的配置(供现有代码无缝切换)
  33. # ========================================
  34. def get_full_config(self) -> dict:
  35. # 返回完整配置 dict,结构与原 rtsp_config.yaml 完全一致
  36. # 这是关键的兼容层方法:现有代码 self.config.get('plants', []) 等调用无需修改
  37. config = {}
  38. # 1. 组装 plants 列表
  39. config['plants'] = self._build_plants_list()
  40. # 2. 组装系统级配置(audio, prediction, push_notification, scada_api, human_detection)
  41. for section in ['audio', 'prediction', 'push_notification', 'scada_api', 'human_detection']:
  42. config[section] = self._get_section_config(section)
  43. return config
  44. def _build_plants_list(self) -> List[dict]:
  45. # 从 DB 组装 plants 列表,结构与 YAML 中的 plants 完全一致
  46. cursor = self._conn.execute(
  47. "SELECT id, name, enabled, project_id, push_url FROM plant ORDER BY id"
  48. )
  49. plants = []
  50. for row in cursor.fetchall():
  51. plant_id = row['id']
  52. plant = {
  53. 'name': row['name'],
  54. 'enabled': bool(row['enabled']),
  55. 'project_id': row['project_id'],
  56. 'push_url': row['push_url'],
  57. 'flow_plc': self._get_flow_plc(plant_id),
  58. 'pump_status_plc': self._get_pump_status_plc(plant_id),
  59. 'rtsp_streams': self._get_rtsp_streams(plant_id),
  60. }
  61. plants.append(plant)
  62. return plants
  63. def _get_flow_plc(self, plant_id: int) -> dict:
  64. # 获取流量 PLC 映射:{pump_name: plc_address}
  65. cursor = self._conn.execute(
  66. "SELECT pump_name, plc_address FROM flow_plc WHERE plant_id = ?",
  67. (plant_id,)
  68. )
  69. return {row['pump_name']: row['plc_address'] for row in cursor.fetchall()}
  70. def _get_pump_status_plc(self, plant_id: int) -> dict:
  71. # 获取泵状态 PLC 配置:{pump_name: [{point, name}, ...]}
  72. # 与 YAML 格式一致:同一 pump_name 下可能有多个点位
  73. cursor = self._conn.execute(
  74. "SELECT pump_name, point, point_name FROM pump_status_plc WHERE plant_id = ? ORDER BY id",
  75. (plant_id,)
  76. )
  77. result = {}
  78. for row in cursor.fetchall():
  79. pump_name = row['pump_name']
  80. if pump_name not in result:
  81. result[pump_name] = []
  82. result[pump_name].append({
  83. 'point': row['point'],
  84. 'name': row['point_name']
  85. })
  86. return result
  87. def _get_rtsp_streams(self, plant_id: int) -> List[dict]:
  88. # 获取 RTSP 流列表
  89. cursor = self._conn.execute(
  90. "SELECT name, url, channel, device_code, pump_name, model_subdir "
  91. "FROM rtsp_stream WHERE plant_id = ? AND enabled = 1 ORDER BY id",
  92. (plant_id,)
  93. )
  94. streams = []
  95. for row in cursor.fetchall():
  96. stream = {
  97. 'name': row['name'],
  98. 'url': row['url'],
  99. 'channel': row['channel'],
  100. 'device_code': row['device_code'],
  101. 'pump_name': row['pump_name'],
  102. }
  103. # model_subdir 仅在有值时添加(兼容旧配置中该字段可选的行为)
  104. if row['model_subdir']:
  105. stream['model_subdir'] = row['model_subdir']
  106. streams.append(stream)
  107. return streams
  108. def _get_section_config(self, section: str) -> dict:
  109. # 从 system_config 表读取指定 section 的所有 KV,还原为嵌套 dict
  110. # 例如 section=prediction, key=voting.window_size, value=5
  111. # 还原为 {'voting': {'window_size': 5}}
  112. cursor = self._conn.execute(
  113. "SELECT key, value, value_type FROM system_config WHERE section = ? ORDER BY key",
  114. (section,)
  115. )
  116. result = {}
  117. for row in cursor.fetchall():
  118. key_path = row['key']
  119. raw_value = row['value']
  120. value_type = row['value_type']
  121. # 类型转换
  122. typed_value = self._deserialize_value(raw_value, value_type)
  123. # 将点号分隔的 key 路径还原为嵌套 dict
  124. self._set_nested(result, key_path, typed_value)
  125. return result
  126. # ========================================
  127. # 系统级配置 CRUD
  128. # ========================================
  129. def get_system_config(self, section: str, key: str = None) -> Any:
  130. # 获取系统配置:指定 key 返回单值,不指定返回整个 section dict
  131. if key:
  132. cursor = self._conn.execute(
  133. "SELECT value, value_type FROM system_config WHERE section = ? AND key = ?",
  134. (section, key)
  135. )
  136. row = cursor.fetchone()
  137. if row:
  138. return self._deserialize_value(row['value'], row['value_type'])
  139. return None
  140. return self._get_section_config(section)
  141. def set_system_config(self, section: str, key: str, value: Any,
  142. value_type: str = None, description: str = ''):
  143. # 设置系统配置(upsert 语义:存在则更新,不存在则插入)
  144. if value_type is None:
  145. value_type = self._infer_type(value)
  146. serialized = self._serialize_value(value, value_type)
  147. self._conn.execute(
  148. "INSERT INTO system_config (section, key, value, value_type, description) "
  149. "VALUES (?, ?, ?, ?, ?) "
  150. "ON CONFLICT(section, key) DO UPDATE SET value=excluded.value, "
  151. "value_type=excluded.value_type, description=excluded.description",
  152. (section, key, serialized, value_type, description)
  153. )
  154. self._conn.commit()
  155. logger.debug(f"配置已更新: [{section}] {key} = {value}")
  156. def update_section_config(self, section: str, config_dict: dict):
  157. # 批量更新整个 section 的配置(将嵌套 dict 展平为 KV 对)
  158. flat_items = self._flatten_dict(config_dict)
  159. for key, value in flat_items.items():
  160. self.set_system_config(section, key, value)
  161. # ========================================
  162. # 水厂 CRUD
  163. # ========================================
  164. def get_plants(self) -> List[dict]:
  165. # 获取所有水厂(含关联数据)
  166. return self._build_plants_list()
  167. def get_plant(self, plant_id: int) -> Optional[dict]:
  168. cursor = self._conn.execute(
  169. "SELECT id, name, enabled, project_id, push_url FROM plant WHERE id = ?",
  170. (plant_id,)
  171. )
  172. row = cursor.fetchone()
  173. if not row:
  174. return None
  175. return {
  176. 'id': row['id'],
  177. 'name': row['name'],
  178. 'enabled': bool(row['enabled']),
  179. 'project_id': row['project_id'],
  180. 'push_url': row['push_url'],
  181. 'flow_plc': self._get_flow_plc(plant_id),
  182. 'pump_status_plc': self._get_pump_status_plc(plant_id),
  183. 'rtsp_streams': self._get_rtsp_streams(plant_id),
  184. }
  185. def create_plant(self, name: str, project_id: int, push_url: str = '',
  186. enabled: bool = False) -> int:
  187. cursor = self._conn.execute(
  188. "INSERT INTO plant (name, enabled, project_id, push_url) VALUES (?, ?, ?, ?)",
  189. (name, int(enabled), project_id, push_url)
  190. )
  191. self._conn.commit()
  192. plant_id = cursor.lastrowid
  193. logger.info(f"水厂已创建: id={plant_id}, name={name}")
  194. return plant_id
  195. def update_plant(self, plant_id: int, **kwargs):
  196. # 动态更新水厂字段
  197. allowed_fields = {'name', 'enabled', 'project_id', 'push_url'}
  198. updates = {k: v for k, v in kwargs.items() if k in allowed_fields}
  199. if not updates:
  200. return
  201. # enabled 字段需要转为 int(SQLite 无原生 bool)
  202. if 'enabled' in updates:
  203. updates['enabled'] = int(updates['enabled'])
  204. set_clause = ', '.join(f"{k} = ?" for k in updates)
  205. values = list(updates.values()) + [plant_id]
  206. self._conn.execute(
  207. f"UPDATE plant SET {set_clause} WHERE id = ?", values
  208. )
  209. self._conn.commit()
  210. logger.info(f"水厂已更新: id={plant_id}, fields={list(updates.keys())}")
  211. def delete_plant(self, plant_id: int):
  212. # 级联删除水厂及其关联数据(外键约束自动处理)
  213. self._conn.execute("DELETE FROM plant WHERE id = ?", (plant_id,))
  214. self._conn.commit()
  215. logger.info(f"水厂已删除: id={plant_id}")
  216. # ========================================
  217. # RTSP 流 CRUD
  218. # ========================================
  219. def get_streams(self, plant_id: int = None) -> List[dict]:
  220. if plant_id:
  221. cursor = self._conn.execute(
  222. "SELECT s.*, p.name as plant_name FROM rtsp_stream s "
  223. "JOIN plant p ON s.plant_id = p.id WHERE s.plant_id = ? ORDER BY s.id",
  224. (plant_id,)
  225. )
  226. else:
  227. cursor = self._conn.execute(
  228. "SELECT s.*, p.name as plant_name FROM rtsp_stream s "
  229. "JOIN plant p ON s.plant_id = p.id ORDER BY s.id"
  230. )
  231. return [dict(row) for row in cursor.fetchall()]
  232. def create_stream(self, plant_id: int, name: str, url: str, channel: int,
  233. device_code: str, pump_name: str = '', model_subdir: str = '',
  234. enabled: bool = True) -> int:
  235. cursor = self._conn.execute(
  236. "INSERT INTO rtsp_stream (plant_id, name, url, channel, device_code, "
  237. "pump_name, model_subdir, enabled) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
  238. (plant_id, name, url, channel, device_code, pump_name, model_subdir, int(enabled))
  239. )
  240. self._conn.commit()
  241. stream_id = cursor.lastrowid
  242. logger.info(f"RTSP流已创建: id={stream_id}, device_code={device_code}")
  243. return stream_id
  244. def update_stream(self, stream_id: int, **kwargs):
  245. allowed_fields = {'name', 'url', 'channel', 'device_code', 'pump_name',
  246. 'model_subdir', 'enabled', 'plant_id'}
  247. updates = {k: v for k, v in kwargs.items() if k in allowed_fields}
  248. if not updates:
  249. return
  250. if 'enabled' in updates:
  251. updates['enabled'] = int(updates['enabled'])
  252. set_clause = ', '.join(f"{k} = ?" for k in updates)
  253. values = list(updates.values()) + [stream_id]
  254. self._conn.execute(
  255. f"UPDATE rtsp_stream SET {set_clause} WHERE id = ?", values
  256. )
  257. self._conn.commit()
  258. logger.info(f"RTSP流已更新: id={stream_id}, fields={list(updates.keys())}")
  259. def delete_stream(self, stream_id: int):
  260. self._conn.execute("DELETE FROM rtsp_stream WHERE id = ?", (stream_id,))
  261. self._conn.commit()
  262. logger.info(f"RTSP流已删除: id={stream_id}")
  263. # ========================================
  264. # 流量 PLC CRUD
  265. # ========================================
  266. def set_flow_plc(self, plant_id: int, pump_name: str, plc_address: str):
  267. self._conn.execute(
  268. "INSERT INTO flow_plc (plant_id, pump_name, plc_address) VALUES (?, ?, ?) "
  269. "ON CONFLICT(plant_id, pump_name) DO UPDATE SET plc_address=excluded.plc_address",
  270. (plant_id, pump_name, plc_address)
  271. )
  272. self._conn.commit()
  273. def delete_flow_plc(self, plant_id: int, pump_name: str):
  274. self._conn.execute(
  275. "DELETE FROM flow_plc WHERE plant_id = ? AND pump_name = ?",
  276. (plant_id, pump_name)
  277. )
  278. self._conn.commit()
  279. # ========================================
  280. # 泵状态 PLC CRUD
  281. # ========================================
  282. def add_pump_status_plc(self, plant_id: int, pump_name: str,
  283. point: str, point_name: str = '') -> int:
  284. cursor = self._conn.execute(
  285. "INSERT INTO pump_status_plc (plant_id, pump_name, point, point_name) "
  286. "VALUES (?, ?, ?, ?)",
  287. (plant_id, pump_name, point, point_name)
  288. )
  289. self._conn.commit()
  290. return cursor.lastrowid
  291. def delete_pump_status_plc(self, plc_id: int):
  292. self._conn.execute("DELETE FROM pump_status_plc WHERE id = ?", (plc_id,))
  293. self._conn.commit()
  294. # ========================================
  295. # 工具方法:类型序列化/反序列化
  296. # ========================================
  297. @staticmethod
  298. def _serialize_value(value: Any, value_type: str) -> str:
  299. # 所有值统一序列化为字符串存储
  300. if value_type == 'json':
  301. return json.dumps(value, ensure_ascii=False)
  302. if value_type == 'bool':
  303. return '1' if value else '0'
  304. return str(value)
  305. @staticmethod
  306. def _deserialize_value(raw: str, value_type: str) -> Any:
  307. # 根据 value_type 将字符串还原为 Python 对象
  308. if value_type == 'int':
  309. return int(raw)
  310. elif value_type == 'float':
  311. return float(raw)
  312. elif value_type == 'bool':
  313. return raw in ('1', 'true', 'True')
  314. elif value_type == 'json':
  315. return json.loads(raw)
  316. return raw
  317. @staticmethod
  318. def _infer_type(value: Any) -> str:
  319. # 根据 Python 类型推断 value_type 标识
  320. if isinstance(value, bool):
  321. return 'bool'
  322. elif isinstance(value, int):
  323. return 'int'
  324. elif isinstance(value, float):
  325. return 'float'
  326. elif isinstance(value, (dict, list)):
  327. return 'json'
  328. return 'str'
  329. @staticmethod
  330. def _set_nested(d: dict, key_path: str, value: Any):
  331. # 将点号分隔的 key_path 设置到嵌套 dict 中
  332. # 例如 _set_nested({}, "voting.window_size", 5) => {"voting": {"window_size": 5}}
  333. keys = key_path.split('.')
  334. current = d
  335. for key in keys[:-1]:
  336. if key not in current:
  337. current[key] = {}
  338. current = current[key]
  339. current[keys[-1]] = value
  340. @staticmethod
  341. def _flatten_dict(d: dict, parent_key: str = '') -> dict:
  342. # 将嵌套 dict 展平为点号分隔的 KV 对
  343. # 例如 {"voting": {"window_size": 5}} => {"voting.window_size": 5}
  344. items = {}
  345. for k, v in d.items():
  346. new_key = f"{parent_key}.{k}" if parent_key else k
  347. if isinstance(v, dict):
  348. items.update(ConfigManager._flatten_dict(v, new_key))
  349. else:
  350. items[new_key] = v
  351. return items