| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402 |
- import sqlite3
- import json
- import logging
- import threading
- from pathlib import Path
- from typing import Any, Dict, List, Optional
- from config.db_models import get_connection, get_db_path, init_db
logger = logging.getLogger(__name__)


class ConfigManager:
    """Unified configuration manager that wraps SQLite reads and writes.

    The public read API returns plain dicts/lists shaped like the former YAML
    configuration, so callers written against the YAML dict
    (e.g. ``config.get('plants', [])``) keep working unchanged.

    Each thread lazily obtains its own connection through ``threading.local``;
    a ``sqlite3.Connection`` is never shared between threads by this class.

    NOTE(review): rows are accessed by column name everywhere, so this class
    assumes ``get_connection`` configures ``row_factory = sqlite3.Row`` (or an
    equivalent mapping row) - confirm in ``config.db_models``.
    """

    # system_config sections exposed as top-level keys of get_full_config().
    _SYSTEM_SECTIONS = (
        'audio', 'prediction', 'push_notification', 'scada_api', 'human_detection',
    )
    # Column whitelists for the dynamic UPDATE statements. Only these names are
    # ever interpolated into SQL text; every value is bound as a parameter.
    _PLANT_FIELDS = frozenset({'name', 'enabled', 'project_id', 'push_url'})
    _STREAM_FIELDS = frozenset({
        'name', 'url', 'channel', 'device_code', 'pump_name',
        'model_subdir', 'enabled', 'plant_id',
    })

    def __init__(self, db_path: Optional[str] = None):
        """Open (and initialise if needed) the configuration database.

        Args:
            db_path: Path of the SQLite file. When empty/None the default
                returned by ``get_db_path()`` is used.
        """
        self._db_path = Path(db_path) if db_path else get_db_path()
        # Thread-local storage: one independent connection per thread.
        self._local = threading.local()
        # Make sure the table structure exists before any query runs.
        init_db(self._db_path)
        logger.info(f"ConfigManager 初始化完成: {self._db_path}")

    @property
    def _conn(self):
        """Return the calling thread's connection, creating it on first use."""
        conn = getattr(self._local, 'conn', None)
        if conn is None:
            conn = get_connection(self._db_path)
            self._local.conn = conn
        return conn

    def close(self):
        """Close the calling thread's connection, if it has one.

        Connections created by other threads are not touched; each thread has
        to call ``close()`` itself.
        """
        conn = getattr(self._local, 'conn', None)
        if conn is not None:
            conn.close()
            self._local.conn = None

    # ========================================
    # Compatibility layer: YAML-shaped configuration dicts
    # ========================================

    def get_full_config(self) -> dict:
        """Return the complete configuration as one nested dict.

        The result has a ``plants`` list plus one dict per system section
        (audio, prediction, push_notification, scada_api, human_detection).
        """
        config = {'plants': self._build_plants_list()}
        for section in self._SYSTEM_SECTIONS:
            config[section] = self._get_section_config(section)
        return config

    def _assemble_plant(self, row, include_id: bool = False) -> dict:
        """Build a plant dict (with its related PLC/stream data) from a row.

        ``include_id`` adds the ``id`` key as the first entry; the YAML-shaped
        list omits it, ``get_plant`` includes it.
        """
        plant_id = row['id']
        plant = {'id': plant_id} if include_id else {}
        plant.update({
            'name': row['name'],
            'enabled': bool(row['enabled']),
            'project_id': row['project_id'],
            'push_url': row['push_url'],
            'flow_plc': self._get_flow_plc(plant_id),
            'pump_status_plc': self._get_pump_status_plc(plant_id),
            'rtsp_streams': self._get_rtsp_streams(plant_id),
        })
        return plant

    def _build_plants_list(self) -> List[dict]:
        """Return every plant, ordered by id, in the YAML ``plants`` layout."""
        rows = self._conn.execute(
            "SELECT id, name, enabled, project_id, push_url FROM plant ORDER BY id"
        ).fetchall()
        return [self._assemble_plant(row) for row in rows]

    def _get_flow_plc(self, plant_id: int) -> dict:
        """Return the flow PLC mapping ``{pump_name: plc_address}`` of a plant."""
        rows = self._conn.execute(
            "SELECT pump_name, plc_address FROM flow_plc WHERE plant_id = ?",
            (plant_id,)
        ).fetchall()
        return {row['pump_name']: row['plc_address'] for row in rows}

    def _get_pump_status_plc(self, plant_id: int) -> dict:
        """Return ``{pump_name: [{'point': ..., 'name': ...}, ...]}`` for a plant.

        One pump may own several points; they are kept in insertion (id) order.
        """
        rows = self._conn.execute(
            "SELECT pump_name, point, point_name FROM pump_status_plc WHERE plant_id = ? ORDER BY id",
            (plant_id,)
        ).fetchall()
        grouped: Dict[str, List[dict]] = {}
        for row in rows:
            grouped.setdefault(row['pump_name'], []).append({
                'point': row['point'],
                'name': row['point_name'],
            })
        return grouped

    def _get_rtsp_streams(self, plant_id: int) -> List[dict]:
        """Return the enabled RTSP streams of a plant, ordered by id."""
        rows = self._conn.execute(
            "SELECT name, url, channel, device_code, pump_name, model_subdir "
            "FROM rtsp_stream WHERE plant_id = ? AND enabled = 1 ORDER BY id",
            (plant_id,)
        ).fetchall()
        streams = []
        for row in rows:
            stream = {
                'name': row['name'],
                'url': row['url'],
                'channel': row['channel'],
                'device_code': row['device_code'],
                'pump_name': row['pump_name'],
            }
            # model_subdir is optional: the key is only present when a
            # non-empty value is stored.
            if row['model_subdir']:
                stream['model_subdir'] = row['model_subdir']
            streams.append(stream)
        return streams

    def _get_section_config(self, section: str) -> dict:
        """Rebuild one section of ``system_config`` as a nested dict.

        Dotted keys become nesting levels, e.g. key ``voting.window_size`` with
        value 5 yields ``{'voting': {'window_size': 5}}``.
        """
        rows = self._conn.execute(
            "SELECT key, value, value_type FROM system_config WHERE section = ? ORDER BY key",
            (section,)
        ).fetchall()
        result: dict = {}
        for row in rows:
            typed_value = self._deserialize_value(row['value'], row['value_type'])
            self._set_nested(result, row['key'], typed_value)
        return result

    # ========================================
    # System-level configuration CRUD
    # ========================================

    def get_system_config(self, section: str, key: Optional[str] = None) -> Any:
        """Read system configuration.

        Args:
            section: Section name.
            key: Dotted key inside the section. When omitted (or empty) the
                whole section is returned as a nested dict.

        Returns:
            The typed value of ``key`` (None when the key does not exist), or
            the section dict when no key is given.
        """
        if not key:
            return self._get_section_config(section)
        row = self._conn.execute(
            "SELECT value, value_type FROM system_config WHERE section = ? AND key = ?",
            (section, key)
        ).fetchone()
        if row is None:
            return None
        return self._deserialize_value(row['value'], row['value_type'])

    def _upsert_system_config(self, conn, section: str, key: str, value: Any,
                              value_type: Optional[str] = None,
                              description: str = '') -> None:
        """Insert or update one config row on ``conn`` WITHOUT committing.

        The caller owns the transaction, which lets several rows be written
        atomically.
        """
        if value_type is None:
            value_type = self._infer_type(value)
        serialized = self._serialize_value(value, value_type)
        conn.execute(
            "INSERT INTO system_config (section, key, value, value_type, description) "
            "VALUES (?, ?, ?, ?, ?) "
            "ON CONFLICT(section, key) DO UPDATE SET value=excluded.value, "
            "value_type=excluded.value_type, description=excluded.description",
            (section, key, serialized, value_type, description)
        )

    def set_system_config(self, section: str, key: str, value: Any,
                          value_type: Optional[str] = None, description: str = ''):
        """Set one config value with upsert semantics.

        Args:
            section: Section name.
            key: Dotted key inside the section.
            value: Python value to store.
            value_type: One of 'int', 'float', 'bool', 'json', 'str'; inferred
                from ``value`` when omitted.
            description: Stored alongside the value; an existing description is
                overwritten on update, including by the default ''.
        """
        conn = self._conn
        # "with conn" commits on success and rolls back if the statement fails.
        with conn:
            self._upsert_system_config(conn, section, key, value, value_type, description)
        logger.debug(f"配置已更新: [{section}] {key} = {value}")

    def update_section_config(self, section: str, config_dict: dict):
        """Write a whole (possibly nested) dict into a section atomically.

        The dict is flattened to dotted keys and all rows are upserted inside a
        single transaction: either every key is stored or none is.
        """
        flat_items = self._flatten_dict(config_dict)
        if not flat_items:
            return
        conn = self._conn
        with conn:
            for key, value in flat_items.items():
                self._upsert_system_config(conn, section, key, value)
        for key, value in flat_items.items():
            logger.debug(f"配置已更新: [{section}] {key} = {value}")

    # ========================================
    # Plant CRUD
    # ========================================

    def get_plants(self) -> List[dict]:
        """Return all plants including their related data."""
        return self._build_plants_list()

    def get_plant(self, plant_id: int) -> Optional[dict]:
        """Return one plant (including its ``id``) or None if it does not exist."""
        row = self._conn.execute(
            "SELECT id, name, enabled, project_id, push_url FROM plant WHERE id = ?",
            (plant_id,)
        ).fetchone()
        if row is None:
            return None
        return self._assemble_plant(row, include_id=True)

    def create_plant(self, name: str, project_id: int, push_url: str = '',
                     enabled: bool = False) -> int:
        """Insert a plant and return its new row id."""
        conn = self._conn
        with conn:
            cursor = conn.execute(
                "INSERT INTO plant (name, enabled, project_id, push_url) VALUES (?, ?, ?, ?)",
                (name, int(enabled), project_id, push_url)
            )
        plant_id = cursor.lastrowid
        logger.info(f"水厂已创建: id={plant_id}, name={name}")
        return plant_id

    def _update_fields(self, table: str, row_id: int, allowed, fields: dict) -> List[str]:
        """Update the whitelisted columns of one row; return the column names.

        ``table`` and ``allowed`` come from class constants, never from user
        input. Keys outside ``allowed`` are ignored. Returns an empty list (and
        runs no SQL) when nothing remains to update.
        """
        updates = {k: v for k, v in fields.items() if k in allowed}
        if not updates:
            return []
        # SQLite has no native bool; store enabled as 0/1.
        if 'enabled' in updates:
            updates['enabled'] = int(updates['enabled'])
        set_clause = ', '.join(f"{column} = ?" for column in updates)
        conn = self._conn
        with conn:
            conn.execute(
                f"UPDATE {table} SET {set_clause} WHERE id = ?",
                [*updates.values(), row_id]
            )
        return list(updates)

    def update_plant(self, plant_id: int, **kwargs):
        """Update plant columns given as keyword arguments.

        Only name, enabled, project_id and push_url are accepted; other
        keywords are silently ignored.
        """
        changed = self._update_fields('plant', plant_id, self._PLANT_FIELDS, kwargs)
        if changed:
            logger.info(f"水厂已更新: id={plant_id}, fields={changed}")

    def delete_plant(self, plant_id: int):
        """Delete a plant.

        NOTE(review): removal of related flow_plc / pump_status_plc /
        rtsp_stream rows presumably relies on ON DELETE CASCADE in the schema
        and on foreign keys being enabled by ``get_connection`` - neither is
        visible here; confirm in ``config.db_models``.
        """
        conn = self._conn
        with conn:
            conn.execute("DELETE FROM plant WHERE id = ?", (plant_id,))
        logger.info(f"水厂已删除: id={plant_id}")

    # ========================================
    # RTSP stream CRUD
    # ========================================

    def get_streams(self, plant_id: Optional[int] = None) -> List[dict]:
        """Return streams (enabled or not) joined with their plant name.

        Args:
            plant_id: Restrict the result to this plant. ``None`` returns the
                streams of all plants.
        """
        sql = (
            "SELECT s.*, p.name as plant_name FROM rtsp_stream s "
            "JOIN plant p ON s.plant_id = p.id"
        )
        params: tuple = ()
        # Compare against None explicitly so that an id of 0 is a real filter.
        if plant_id is not None:
            sql += " WHERE s.plant_id = ?"
            params = (plant_id,)
        sql += " ORDER BY s.id"
        return [dict(row) for row in self._conn.execute(sql, params).fetchall()]

    def create_stream(self, plant_id: int, name: str, url: str, channel: int,
                      device_code: str, pump_name: str = '', model_subdir: str = '',
                      enabled: bool = True) -> int:
        """Insert an RTSP stream for a plant and return its new row id."""
        conn = self._conn
        with conn:
            cursor = conn.execute(
                "INSERT INTO rtsp_stream (plant_id, name, url, channel, device_code, "
                "pump_name, model_subdir, enabled) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
                (plant_id, name, url, channel, device_code, pump_name, model_subdir, int(enabled))
            )
        stream_id = cursor.lastrowid
        logger.info(f"RTSP流已创建: id={stream_id}, device_code={device_code}")
        return stream_id

    def update_stream(self, stream_id: int, **kwargs):
        """Update stream columns given as keyword arguments.

        Accepted: name, url, channel, device_code, pump_name, model_subdir,
        enabled, plant_id. Other keywords are silently ignored.
        """
        changed = self._update_fields('rtsp_stream', stream_id, self._STREAM_FIELDS, kwargs)
        if changed:
            logger.info(f"RTSP流已更新: id={stream_id}, fields={changed}")

    def delete_stream(self, stream_id: int):
        """Delete one RTSP stream by id."""
        conn = self._conn
        with conn:
            conn.execute("DELETE FROM rtsp_stream WHERE id = ?", (stream_id,))
        logger.info(f"RTSP流已删除: id={stream_id}")

    # ========================================
    # Flow PLC CRUD
    # ========================================

    def set_flow_plc(self, plant_id: int, pump_name: str, plc_address: str):
        """Create or replace the flow PLC address of a pump (upsert)."""
        conn = self._conn
        with conn:
            conn.execute(
                "INSERT INTO flow_plc (plant_id, pump_name, plc_address) VALUES (?, ?, ?) "
                "ON CONFLICT(plant_id, pump_name) DO UPDATE SET plc_address=excluded.plc_address",
                (plant_id, pump_name, plc_address)
            )

    def delete_flow_plc(self, plant_id: int, pump_name: str):
        """Delete the flow PLC mapping of one pump."""
        conn = self._conn
        with conn:
            conn.execute(
                "DELETE FROM flow_plc WHERE plant_id = ? AND pump_name = ?",
                (plant_id, pump_name)
            )

    # ========================================
    # Pump status PLC CRUD
    # ========================================

    def add_pump_status_plc(self, plant_id: int, pump_name: str,
                            point: str, point_name: str = '') -> int:
        """Add one status point to a pump and return the new row id."""
        conn = self._conn
        with conn:
            cursor = conn.execute(
                "INSERT INTO pump_status_plc (plant_id, pump_name, point, point_name) "
                "VALUES (?, ?, ?, ?)",
                (plant_id, pump_name, point, point_name)
            )
        return cursor.lastrowid

    def delete_pump_status_plc(self, plc_id: int):
        """Delete one pump status point by its row id."""
        conn = self._conn
        with conn:
            conn.execute("DELETE FROM pump_status_plc WHERE id = ?", (plc_id,))

    # ========================================
    # Helpers: value (de)serialisation and dict reshaping
    # ========================================

    @staticmethod
    def _serialize_value(value: Any, value_type: str) -> str:
        """Serialise ``value`` to the string stored in the ``value`` column."""
        if value_type == 'json':
            return json.dumps(value, ensure_ascii=False)
        if value_type == 'bool':
            return '1' if value else '0'
        return str(value)

    @staticmethod
    def _deserialize_value(raw: Optional[str], value_type: str) -> Any:
        """Convert a stored string back to a Python object per ``value_type``.

        A NULL column (``raw is None``) yields None for every type instead of
        raising inside ``int()`` / ``float()`` / ``json.loads()``. Unknown
        types return the raw string unchanged.
        """
        if raw is None:
            return None
        if value_type == 'int':
            return int(raw)
        if value_type == 'float':
            return float(raw)
        if value_type == 'bool':
            return raw in ('1', 'true', 'True')
        if value_type == 'json':
            return json.loads(raw)
        return raw

    @staticmethod
    def _infer_type(value: Any) -> str:
        """Infer the ``value_type`` tag from a Python value.

        ``bool`` is tested before ``int`` because bool is an int subclass.
        ``None`` maps to 'json' (stored as ``null``) so that it reads back as
        None instead of the string "None".
        """
        if value is None:
            return 'json'
        if isinstance(value, bool):
            return 'bool'
        if isinstance(value, int):
            return 'int'
        if isinstance(value, float):
            return 'float'
        if isinstance(value, (dict, list)):
            return 'json'
        return 'str'

    @staticmethod
    def _set_nested(d: dict, key_path: str, value: Any):
        """Store ``value`` in ``d`` under a dot-separated path.

        Example: ``_set_nested({}, "voting.window_size", 5)`` produces
        ``{"voting": {"window_size": 5}}``.

        If an intermediate segment already holds a non-dict value (both
        ``voting`` and ``voting.window_size`` exist), the scalar is replaced by
        a dict and a warning is logged, rather than raising TypeError.
        """
        *parents, leaf = key_path.split('.')
        current = d
        for segment in parents:
            child = current.get(segment)
            if not isinstance(child, dict):
                if segment in current:
                    logger.warning(
                        "config key conflict: %r in path %r holds a scalar; "
                        "replacing it with a nested section", segment, key_path
                    )
                child = {}
                current[segment] = child
            current = child
        current[leaf] = value

    @staticmethod
    def _flatten_dict(d: dict, parent_key: str = '') -> dict:
        """Flatten a nested dict into ``{dotted.key: value}`` pairs.

        Example: ``{"voting": {"window_size": 5}}`` becomes
        ``{"voting.window_size": 5}``. Empty nested dicts contribute no keys.
        """
        items = {}
        for k, v in d.items():
            new_key = f"{parent_key}.{k}" if parent_key else k
            if isinstance(v, dict):
                items.update(ConfigManager._flatten_dict(v, new_key))
            else:
                items[new_key] = v
        return items
|