jiyuhang 2 месяцев назад
Родитель
Commit
8ae25d3fc1

+ 1 - 1
.env

@@ -11,7 +11,7 @@ WORKERS=0
 # CUDA设备
 # CUDA设备
 CUDA_VISIBLE_DEVICES=0
 CUDA_VISIBLE_DEVICES=0
 # batch size
 # batch size
-BATCH_SIZE=128
+BATCH_SIZE=32
 # pretrained
 # pretrained
 PRETRAINED=True
 PRETRAINED=True
 # 是否使用bias
 # 是否使用bias

+ 2 - 2
labelme/crop_patch.py

@@ -15,9 +15,9 @@ patch_h = int(os.getenv('PATCH_HEIGHT', 256))
 
 
 def main():
 def main():
     # TODO:需要修改为标注好的图片路径
     # TODO:需要修改为标注好的图片路径
-    input_path = r'D:\code\water_turbidity_det\frame_data\test\20251230\4video_20251229133103'
+    input_path = r'D:\code\water_turbidity_det\frame_data\train\20260107\4_ch26_20260105141754'
     # TODO: 需要修改为保存patch的根目录
     # TODO: 需要修改为保存patch的根目录
-    output_path_root = r'D:\code\water_turbidity_det\label_data_tem\test'
+    output_path_root = r'D:\code\water_turbidity_det\label_data_tem\train'
 
 
     # 读取标注文件
     # 读取标注文件
     label_path = os.path.join(input_path, 'label.txt')
     label_path = os.path.join(input_path, 'label.txt')

+ 1 - 1
labelme/fixed_label.py

@@ -164,7 +164,7 @@ def main():
     global patch_h
     global patch_h
     global scale
     global scale
     # TODO: 需要更改为准备标注的图像路径,使用当前目录下的000000.jpg,结果保存在当前目录下label.txt
     # TODO: 需要更改为准备标注的图像路径,使用当前目录下的000000.jpg,结果保存在当前目录下label.txt
-    img_path = r'D:\code\water_turbidity_det\frame_data\train\20251230\4video_20251229160514\000040.jpg'
+    img_path = r'D:\code\water_turbidity_det\frame_data\4_ch26_20260105141754\000000.jpg'
     play_video(img_path)
     play_video(img_path)
     img = cv2.imread(img_path)
     img = cv2.imread(img_path)
     draw_tool.set_path(img_path)
     draw_tool.set_path(img_path)

+ 19 - 15
labelme/random_del.py

@@ -4,22 +4,26 @@ import random
 
 
 def main():
 def main():
     # TODO:需要修改图像路径
     # TODO:需要修改图像路径
-    path = r'D:\code\water_turbidity_det\label_data_tem\train\0'
+    path = r'D:\code\water_turbidity_det\frame_data\1_ch25_20260105084338'
     del_rate = 0.3
     del_rate = 0.3
     img_path = [i for i in os.listdir(path) if i.split('.')[-1] in ['jpg', 'png'] ]
     img_path = [i for i in os.listdir(path) if i.split('.')[-1] in ['jpg', 'png'] ]
-    random.shuffle(img_path)
-
-    del_list = img_path[:int(len(img_path)*del_rate)]
-
-    for i in del_list:
-        target_path = os.path.join(path, i)
-        if os.path.isfile(target_path):  # 或者使用 os.path.exists(file_path)
-            os.remove(target_path)
-            print("文件删除成功。",target_path)
-
-    print(f"文件数量: {len(img_path)}")
-    print(f"删除比例: {del_rate}")
-    print(f"删除数量: {len(del_list)}")
-    print(f'剩余数量: {len(img_path)-len(del_list)}')
+    is_del = input(f'{path}路径内共有{len(img_path)}张图片,是否删除{del_rate}比例的图片?(y/n): ')
+    is_change_rate = input('是否修改删除比例?(y/n): ')
+    if is_change_rate.lower() == 'y':
+        del_rate = float(input('请输入删除比例:'))
+    if is_del.lower() == 'y':
+        random.shuffle(img_path)
+        del_list = img_path[:int(len(img_path)*del_rate)]
+        for i in del_list:
+            target_path = os.path.join(path, i)
+            if os.path.isfile(target_path):  # 或者使用 os.path.exists(file_path)
+                os.remove(target_path)
+                print("文件删除成功。",target_path)
+        print(f"文件数量: {len(img_path)}")
+        print(f"删除比例: {del_rate}")
+        print(f"删除数量: {len(del_list)}")
+        print(f'剩余数量: {len(img_path)-len(del_list)}')
+    else:
+        exit()
 if __name__ == '__main__':
 if __name__ == '__main__':
     main()
     main()

+ 1 - 1
labelme/statistic.py

@@ -50,7 +50,7 @@ def format_statistics(stats: dict) -> str:
 
 
 def main():
 def main():
     # TODO:修改数据集路径
     # TODO:修改数据集路径
-    train_data_path = r'D:\code\water_turbidity_det\label_data_tem'
+    train_data_path = r'D:\code\water_turbidity_det\label_data'
     dirs = os.listdir(train_data_path)
     dirs = os.listdir(train_data_path)
     
     
     # 检查数据集目录是否存在
     # 检查数据集目录是否存在

+ 1 - 1
labelme/video_depart.py

@@ -5,7 +5,7 @@ import shutil
 def main():
 def main():
     # 视频路径
     # 视频路径
     # TODO: 修改视频路径为自己的视频路径,每次指定一个视频
     # TODO: 修改视频路径为自己的视频路径,每次指定一个视频
-    path = r'D:\code\water_turbidity_det\video\20251230day\4video_20251229160514.mp4'
+    path = r'D:\code\water_turbidity_det\video\20260107\4_ch26_20260105141754.mp4'
     output_rootpath = r'D:\code\water_turbidity_det\frame_data'  # 输出路径的根目录
     output_rootpath = r'D:\code\water_turbidity_det\frame_data'  # 输出路径的根目录
     # 抽帧间隔
     # 抽帧间隔
     interval = 20
     interval = 20

+ 55 - 4
model/model_zoon.py

@@ -1,8 +1,59 @@
+import torch
+from torchvision.models import resnet18,resnet50, squeezenet1_0,shufflenet_v2_x1_0,shufflenet_v2_x2_0
+import torch.nn as nn
 
 
 
 
 
 
+def load_model(name:str, num_classes:int, device:torch.device,imagenet:bool=None, weights_path:str=None):
+    """加载模型结构"""
 
 
-
-def load_model(name:str):
-
-    pass
+    # 加载模型
+    pretrained = True if imagenet else False
+    if name == 'resnet50':
+        model = resnet50(pretrained=pretrained)
+    elif name == 'squeezenet':
+        model = squeezenet1_0(pretrained=pretrained)
+    elif name == 'shufflenet' or name == 'shufflenet-x1':
+        model = shufflenet_v2_x1_0(pretrained=pretrained)
+    elif name == 'shufflenet-x2':
+        model = shufflenet_v2_x2_0(pretrained=False)
+        imagenet = False
+        print('shufflenet-x2无预训练权重,重新训练所有权重')
+    else:
+        raise ValueError(f"Invalid model name: {name}")
+    # 如果采用预训练的神经网络,就需要冻结特征提取层,只训练最后几层
+    if imagenet:
+        for param in model.parameters():
+            param.requires_grad = False
+    # 替换最后的分类层以适应新的分类任务
+    print(model)
+    print(f"正在将模型{name}的分类层替换为{num_classes}个类别")
+    if hasattr(model, 'fc'):
+        # ResNet系列模型
+        model.fc = nn.Linear(int(model.fc.in_features), num_classes, bias=True)
+    elif hasattr(model, 'classifier'):
+        # SqueezeNet、ShuffleNet系列模型
+        if name == 'squeezenet':
+            # 获取SqueezeNet的最后一个卷积层的输入通道数
+            final_conv_in_channels = model.classifier[1].in_channels
+            # 替换classifier为新的Sequential,将输出改为2类
+            model.classifier = nn.Sequential(
+                nn.Dropout(p=0.5),
+                nn.Conv2d(final_conv_in_channels, num_classes, kernel_size=(1, 1)),
+                nn.ReLU(inplace=True),
+                nn.AdaptiveAvgPool2d((1, 1))
+            )
+        else:
+            # Swin Transformer等模型
+            model.classifier = nn.Linear(int(model.classifier.in_features), num_classes, bias=True)
+    elif hasattr(model, 'head'):
+        # Swin Transformer使用head层
+        model.head = nn.Linear(int(model.head.in_features), num_classes, bias=True)
+    else:
+        raise ValueError(f"Model {name} does not have recognizable classifier layer")
+    print(f'模型{name}结构已经加载,移动到设备{device}')
+    if weights_path:
+        model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
+    # 将模型移动到GPU/cpu
+    model = model.to(device)
+    return model

+ 519 - 0
rtsp_video_extractor.py

@@ -0,0 +1,519 @@
+# -*- coding: utf-8 -*-
+"""
+video_frame_extractor.py
+Python 3.8+
+
+修复 NPU(BM) 抽帧“花屏/块状错位/灰块”的常见原因:pipe 实际输出像素格式不是 bgr24
+或存在 stride/对齐问题。
+
+改动要点(NPU 路径):
+1) ffmpeg 输出改为 rawvideo + yuv420p(I420),避免 bgr24 在 BM 链路中不稳定
+2) Python 按 yuv420p 每帧读取 w*h*3/2 字节,再用 OpenCV 转成 BGR ndarray
+3) CPU/CUDA 仍然直接输出 bgr24(保持原行为)
+
+同时保留:
+- 自动重连
+- 分辨率/codec 探测(按你给的特殊 ffprobe 解析)
+- NPU/CUDA 检测与 prefer 策略
+"""
+
+import subprocess as sp
+import threading
+import queue
+import time
+import os
+import signal
+import logging
+from logging.handlers import RotatingFileHandler
+from typing import Optional, Tuple
+
+import numpy as np
+import cv2
+
+
+class RTSPFrameExtractor:
+    """
+    高并发、自动重连、可选 CUDA / NPU(BM) 加速的 RTSP → Numpy 帧提取器
+
+    prefer:
+      - "npu": 优先 NPU,其次 CUDA,否则 CPU
+      - "cuda": 优先 CUDA,其次 NPU,否则 CPU
+      - "cpu": 强制 CPU
+    """
+
+    def __init__(self,
+                 rtsp_url: str,
+                 fps: int = 1,
+                 width: Optional[int] = 1280,
+                 height: Optional[int] = 720,
+                 queue_size: int = 200,
+                 no_frame_timeout: int = 300,
+                 log_path: str = "./logs/frame_extractor.log",
+                 use_cuda: bool = True,
+                 use_npu: bool = True,
+                 prefer: str = "npu",
+                 ):
+        # ----------------- 基本参数 -----------------
+        self.rtsp_url = rtsp_url
+        self.fps = int(fps)
+        self.width = width
+        self.height = height
+
+        self.codec_name: str = ""  # h264/hevc/...
+        self.output_pix_fmt: str = "bgr24"  # NPU 会切到 yuv420p
+        self.frame_size: int = 0  # bytes per frame (pipe)
+
+        self.queue = queue.Queue(maxsize=queue_size)
+        self.no_frame_timeout = int(no_frame_timeout)
+
+        self.proc: Optional[sp.Popen] = None
+        self.read_thread: Optional[threading.Thread] = None
+        self.running = threading.Event()
+        self.last_frame_ts = 0.0
+        self.last_probe_ts = 0.0
+
+        self._restart_lock = threading.Lock()
+
+        # ----------------- 加速配置 -----------------
+        self.use_cuda_cfg = bool(use_cuda)
+        self.use_npu_cfg = bool(use_npu)
+        self.prefer = (prefer or "npu").lower().strip()
+
+        self.cuda_enabled = False
+        self.npu_enabled = False
+
+        # ----------------- 日志 -----------------
+        log_dir = os.path.dirname(log_path)
+        if log_dir and (not os.path.exists(log_dir)):
+            os.makedirs(log_dir, exist_ok=True)
+        if not os.path.exists(log_path):
+            open(log_path, 'w').close()
+        self.logger = self._init_logger(log_path)
+
+        # ----------------- 检测加速能力 -----------------
+        npu_ok = self._npu_available() if self.use_npu_cfg else False
+        cuda_ok = self._cuda_available() if self.use_cuda_cfg else False
+
+        # 按 prefer 选择
+        self._select_accel(npu_ok=npu_ok, cuda_ok=cuda_ok)
+
+        self.logger.info(
+            "Accel decision: prefer=%s, npu_ok=%s, cuda_ok=%s -> NPU=%s CUDA=%s",
+            self.prefer, npu_ok, cuda_ok, self.npu_enabled, self.cuda_enabled
+        )
+
+        # ----------------- 启动 -----------------
+        self._bootstrap()
+
+    # --------------------------------------------------------------------- #
+    #                              PUBLIC API                               #
+    # --------------------------------------------------------------------- #
+    def get_frame(self, timeout: float = 1.0) -> Optional[Tuple[np.ndarray, float]]:
+        """
+        - 正常返回: (frame: np.ndarray[BGR], timestamp: float)
+        - 超时 / 无帧: None
+        """
+        try:
+            return self.queue.get(timeout=timeout)
+        except queue.Empty:
+            if time.time() - self.last_frame_ts > self.no_frame_timeout:
+                self.logger.warning("No frame for %.1f sec, restarting...", self.no_frame_timeout)
+                self._restart()
+            return None
+
+    def stop(self):
+        self.running.clear()
+        if self.read_thread and self.read_thread.is_alive():
+            self.read_thread.join(timeout=2)
+
+        if self.proc:
+            self._kill_proc(self.proc)
+        self.proc = None
+        self.logger.info("RTSPFrameExtractor stopped.")
+
+    def close(self):
+        self.stop()
+
+    # --------------------------------------------------------------------- #
+    #                              INTERNAL                                 #
+    # --------------------------------------------------------------------- #
+    def _select_accel(self, npu_ok: bool, cuda_ok: bool):
+        if self.prefer == "npu":
+            self.npu_enabled = bool(npu_ok)
+            self.cuda_enabled = bool((not self.npu_enabled) and cuda_ok)
+        elif self.prefer == "cuda":
+            self.cuda_enabled = bool(cuda_ok)
+            self.npu_enabled = bool((not self.cuda_enabled) and npu_ok)
+        elif self.prefer == "cpu":
+            self.cuda_enabled = False
+            self.npu_enabled = False
+        else:
+            # 兜底:npu -> cuda -> cpu
+            self.npu_enabled = bool(npu_ok)
+            self.cuda_enabled = bool((not self.npu_enabled) and cuda_ok)
+
+    def _bootstrap(self):
+        if self.width is None or self.height is None or not self.codec_name:
+            self.logger.info("Probing RTSP resolution/codec...")
+            w, h, codec = self._probe_resolution_loop()
+            if self.width is None:
+                self.width = w
+            if self.height is None:
+                self.height = h
+            self.codec_name = codec or "h264"
+            self.logger.info("Got stream info: %dx%d codec=%s", self.width, self.height, self.codec_name)
+
+        self._recompute_output_format_and_size()
+        self._start_ffmpeg()
+        self._start_reader()
+
+    def _recompute_output_format_and_size(self):
+        """
+        关键:NPU 走 yuv420p 输出,CPU/CUDA 走 bgr24 输出
+        """
+        if not self.width or not self.height:
+            raise ValueError("width/height not set")
+
+        if self.npu_enabled:
+            # 为了稳定,NPU 输出 yuv420p (I420): size = w*h*3/2
+            self.output_pix_fmt = "yuv420p"
+            self.frame_size = int(self.width) * int(self.height) * 3 // 2
+        else:
+            self.output_pix_fmt = "bgr24"
+            self.frame_size = int(self.width) * int(self.height) * 3
+
+    # ----------------------------- probing -------------------------------- #
+    def _probe_resolution_loop(self) -> Tuple[int, int, str]:
+        while True:
+            w, h, c = self._probe_once()
+            if w and h:
+                return w, h, (c or "h264")
+            self.logger.warning("ffprobe failed, retry in 2s...")
+            time.sleep(2)
+
+    def _probe_once(self) -> Tuple[int, int, str]:
+        """
+        你给的特殊 ffprobe 解析方式
+        """
+        cmd = [
+            "ffprobe", "-v", "error", "-rtsp_transport", "tcp",
+            "-select_streams", "v:0",
+            "-show_entries", "stream=width,height,codec_name",
+            "-of", "csv=p=0",
+            self.rtsp_url
+        ]
+        try:
+            process = sp.run(cmd, stdout=sp.PIPE, stderr=sp.STDOUT, timeout=10)
+            output = process.stdout.decode('utf-8', errors='ignore')
+
+            for line in output.splitlines():
+                line = line.strip()
+                if not line:
+                    continue
+                ignore_keywords = ['interleave', 'TCP based', 'BMvid', 'libbmvideo', 'firmware', 'VERSION']
+                if any(k in line for k in ignore_keywords):
+                    continue
+
+                parts = [p.strip() for p in line.split(',') if p.strip()]
+                if len(parts) < 2:
+                    continue
+
+                nums = []
+                codec_candidate = ""
+                for p_str in parts:
+                    if p_str.isdigit():
+                        nums.append(int(p_str))
+                    elif len(p_str) < 20:
+                        codec_candidate = p_str.lower()
+
+                if len(nums) >= 2:
+                    codec_candidate = codec_candidate or "h264"
+                    if codec_candidate in ("h265", "hevc"):
+                        codec_candidate = "hevc"
+                    elif codec_candidate in ("h264", "avc1", "avc"):
+                        codec_candidate = "h264"
+                    return nums[0], nums[1], codec_candidate
+
+        except Exception as e:
+            self.logger.error(f"ffprobe unexpected error: {e}")
+        return 0, 0, ""
+
+    # --------------------------- ffmpeg 管理 ------------------------------ #
+    def _ffmpeg_cmd(self):
+        """
+        NPU(BM):
+          -c:v h264_bm/hevc_bm
+          -vf "scale_bm=...:format=yuv420p,fps=..."
+          输出:-f rawvideo -pix_fmt yuv420p -
+        CPU/CUDA:
+          输出:-f rawvideo -pix_fmt bgr24 -
+        """
+        cmd = ['ffmpeg', '-loglevel', 'error', '-rtsp_transport', 'tcp']
+
+        if self.npu_enabled:
+            codec = (self.codec_name or "h264").lower()
+            dec = "hevc_bm" if codec in ("hevc", "h265") else "h264_bm"
+
+            vf_parts = []
+            # 强制经过 scale_bm 并输出 yuv420p,减少 stride/格式不确定性
+            vf_parts.append(f"scale_bm=w={int(self.width)}:h={int(self.height)}:format=yuv420p")
+            vf_parts.append(f"fps={int(self.fps)}")
+            vf = ",".join(vf_parts)
+
+            cmd += ['-c:v', dec, '-i', self.rtsp_url, '-vf', vf, '-an',
+                    '-f', 'rawvideo', '-pix_fmt', 'yuv420p', '-']
+            return cmd
+
+        # CUDA(仅硬解,输出仍由 ffmpeg 转成 bgr24)
+        if self.cuda_enabled:
+            cmd += ['-hwaccel', 'cuda']
+
+        cmd += ['-i', self.rtsp_url,
+                '-vf', f'fps={int(self.fps)}',
+                '-an',
+                '-f', 'rawvideo',
+                '-pix_fmt', 'bgr24',
+                '-']
+        return cmd
+
+    def _start_ffmpeg(self):
+        if self.proc:
+            self._kill_proc(self.proc)
+
+        self.logger.info("Starting ffmpeg... (NPU=%s CUDA=%s codec=%s out_pix=%s)",
+                         self.npu_enabled, self.cuda_enabled, self.codec_name, self.output_pix_fmt)
+        self.logger.debug("ffmpeg cmd → %s", " ".join(self._ffmpeg_cmd()))
+
+        kwargs = {}
+        if os.name == 'posix':
+            kwargs['preexec_fn'] = os.setsid
+        elif os.name == 'nt':
+            kwargs['creationflags'] = sp.CREATE_NEW_PROCESS_GROUP
+
+        self.proc = sp.Popen(
+            self._ffmpeg_cmd(),
+            stdout=sp.PIPE,
+            stderr=sp.PIPE,
+            bufsize=self.frame_size * 10,
+            **kwargs
+        )
+        self.last_frame_ts = time.time()
+        self.last_probe_ts = time.time()
+
+    # -------------------------- Reader Thread ----------------------------- #
+    def _start_reader(self):
+        self.running.set()
+        self.read_thread = threading.Thread(target=self._reader_loop, daemon=True)
+        self.read_thread.start()
+
+    def _reader_loop(self):
+        self.logger.info("Reader thread started.")
+        w = int(self.width)
+        h = int(self.height)
+
+        while self.running.is_set():
+            try:
+                if not self.proc or not self.proc.stdout:
+                    time.sleep(0.1)
+                    continue
+
+                raw = self.proc.stdout.read(self.frame_size)
+                ts = time.time()
+            except Exception as e:
+                self.logger.error("Read error: %s", e)
+                raw, ts = b'', time.time()
+
+            if len(raw) != self.frame_size:
+                self.logger.warning("Incomplete frame (%d/%d bytes).", len(raw), self.frame_size)
+
+                # 进程退出则重启
+                if self.proc and (self.proc.poll() is not None):
+                    self.logger.warning("ffmpeg exited with code=%s, restarting...", self.proc.returncode)
+                    self._restart()
+                    return
+
+                time.sleep(0.05)
+                continue
+
+            # --------- 按输出像素格式解包 ---------
+            if self.output_pix_fmt == "yuv420p":
+                # I420: shape = (h*3/2, w)
+                yuv = np.frombuffer(raw, np.uint8).reshape((h * 3 // 2, w))
+                frame = cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR_I420)
+            else:
+                frame = np.frombuffer(raw, np.uint8).reshape((h, w, 3)).copy()
+
+            # 入队
+            try:
+                self.queue.put_nowait((frame, ts))
+            except queue.Full:
+                _ = self.queue.get_nowait()
+                self.queue.put_nowait((frame, ts))
+
+            self.last_frame_ts = time.time()
+
+            # 每小时 probe 一次(分辨率/codec 变化就重启)
+            if time.time() - self.last_probe_ts >= 3600:
+                self.last_probe_ts = time.time()
+                nw, nh, ncodec = self._probe_once()
+                if nw and nh:
+                    changed = (nw != self.width or nh != self.height or (ncodec and ncodec != self.codec_name))
+                    if changed:
+                        self.logger.warning("Stream info changed %dx%d/%s → %dx%d/%s, restarting...",
+                                            self.width, self.height, self.codec_name, nw, nh, ncodec)
+                        self.width, self.height = nw, nh
+                        self.codec_name = ncodec or self.codec_name
+                        self._recompute_output_format_and_size()
+                        self._restart()
+                        return
+
+        self.logger.info("Reader thread exit.")
+
+    # ----------------------------- Restart -------------------------------- #
+    def _restart(self):
+        with self._restart_lock:
+            self.running.clear()
+            if self.read_thread and self.read_thread.is_alive():
+                self.read_thread.join(timeout=2)
+
+            if self.proc:
+                self._kill_proc(self.proc)
+            self.proc = None
+
+            self.logger.info("Restarting, probing resolution/codec...")
+            w, h, codec = self._probe_resolution_loop()
+            self.width, self.height, self.codec_name = w, h, (codec or self.codec_name or "h264")
+
+            # 重启时重新检测设备能力
+            npu_ok = self._npu_available() if self.use_npu_cfg else False
+            cuda_ok = self._cuda_available() if self.use_cuda_cfg else False
+            self._select_accel(npu_ok=npu_ok, cuda_ok=cuda_ok)
+
+            self._recompute_output_format_and_size()
+
+            self.logger.info("New stream info: %dx%d codec=%s | Accel: NPU=%s CUDA=%s | out_pix=%s",
+                             self.width, self.height, self.codec_name, self.npu_enabled, self.cuda_enabled, self.output_pix_fmt)
+
+            self._start_ffmpeg()
+            self._start_reader()
+
+    # ----------------------------- Utils ---------------------------------- #
+    @staticmethod
+    def _kill_proc(proc: sp.Popen):
+        if proc and proc.poll() is None:
+            try:
+                if os.name == 'posix':
+                    os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
+                elif os.name == 'nt':
+                    proc.send_signal(signal.CTRL_BREAK_EVENT)
+            except Exception:
+                try:
+                    proc.terminate()
+                except Exception:
+                    pass
+            try:
+                proc.wait(timeout=3)
+            except sp.TimeoutExpired:
+                try:
+                    proc.kill()
+                except Exception:
+                    pass
+
+    def _cuda_available(self) -> bool:
+        try:
+            out = sp.check_output(
+                ['ffmpeg', '-hide_banner', '-hwaccels'],
+                stderr=sp.STDOUT, timeout=3
+            ).decode(errors='ignore').lower()
+            if 'cuda' not in out:
+                return False
+        except Exception:
+            return False
+
+        try:
+            sp.check_output(['nvidia-smi', '-L'], stderr=sp.STDOUT, timeout=3)
+        except Exception:
+            return False
+
+        return True
+
+    def _npu_available(self) -> bool:
+        """
+        检测 BM NPU 能力:
+        - codecs 里存在 h264_bm/hevc_bm
+        - filters 里存在 scale_bm
+        """
+        try:
+            codecs = sp.check_output(
+                ['ffmpeg', '-hide_banner', '-codecs'],
+                stderr=sp.STDOUT, timeout=3
+            ).decode(errors='ignore').lower()
+            if ('h264_bm' not in codecs) and ('hevc_bm' not in codecs):
+                return False
+        except Exception:
+            return False
+
+        try:
+            filters = sp.check_output(
+                ['ffmpeg', '-hide_banner', '-filters'],
+                stderr=sp.STDOUT, timeout=3
+            ).decode(errors='ignore').lower()
+            if 'scale_bm' not in filters:
+                return False
+        except Exception:
+            return False
+
+        return True
+
+    @staticmethod
+    def _init_logger(log_path: str):
+        logger = logging.getLogger("FrameExtractor")
+        if logger.handlers:
+            return logger
+        logger.setLevel(logging.INFO)
+
+        handler = RotatingFileHandler(
+            log_path, maxBytes=10 * 1024 * 1024,
+            backupCount=5, encoding='utf-8'
+        )
+        fmt = logging.Formatter(
+            fmt="%(asctime)s %(levelname)s: %(message)s",
+            datefmt="%Y-%m-%d %H:%M:%S"
+        )
+        handler.setFormatter(fmt)
+        logger.addHandler(handler)
+
+        console = logging.StreamHandler()
+        console.setFormatter(fmt)
+        logger.addHandler(console)
+        return logger
+
+
+# ----------------------------- Demo ---------------------------------- #
+if __name__ == "__main__":
+    RTSP = "rtsp://rtsp:newwater123@222.130.26.194:59371/streaming/channels/401"
+
+    extractor = RTSPFrameExtractor(
+        rtsp_url=RTSP,
+        fps=1,
+        use_npu=True,
+        use_cuda=True,
+        prefer="npu",
+        width=1920,
+        height=1080
+    )
+
+    try:
+        while True:
+            item = extractor.get_frame(timeout=2)
+            if item is None:
+                continue
+            frame, ts = item
+            cv2.putText(frame, f"Time: {ts:.3f}", (10, 30),
+                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
+            cv2.imshow("RTSP Stream", frame)
+            if cv2.waitKey(1) & 0xFF == ord('q'):
+                break
+    finally:
+        extractor.close()
+        cv2.destroyAllWindows()

BIN
runs/turbidity_resnet50_20260113-174312/events.out.tfevents.1768297392.240.724077.0


BIN
runs/turbidity_shufflenet-x2_20260113-161246/events.out.tfevents.1768291966.240.328038.0


BIN
runs/turbidity_shufflenet_20260113-154754/events.out.tfevents.1768290474.240.186034.0


BIN
runs/turbidity_squeezenet_20251226-105741/events.out.tfevents.1766717861.O5XVSKDYW6B57G7.32516.0


BIN
runs/turbidity_squeezenet_20251226-105857/events.out.tfevents.1766717937.O5XVSKDYW6B57G7.17992.0


BIN
runs/turbidity_squeezenet_20251226-110326/events.out.tfevents.1766718206.O5XVSKDYW6B57G7.28116.0


BIN
runs/turbidity_squeezenet_20260113-170816/events.out.tfevents.1768295296.240.558540.0



+ 49 - 69
train.py

@@ -6,7 +6,7 @@ import torch.optim as optim
 from torch.utils.data import DataLoader
 from torch.utils.data import DataLoader
 import torchvision.transforms as transforms
 import torchvision.transforms as transforms
 from torchvision.datasets import ImageFolder
 from torchvision.datasets import ImageFolder
-from torchvision.models import resnet18,resnet50, squeezenet1_0,shufflenet_v2_x1_0,shufflenet_v2_x2_0
+from model.model_zoon import load_model
 from torch.utils.tensorboard import SummaryWriter  # 添加 TensorBoard 支持
 from torch.utils.tensorboard import SummaryWriter  # 添加 TensorBoard 支持
 from datetime import datetime
 from datetime import datetime
 import os
 import os
@@ -28,7 +28,7 @@ def print_env_variables():
         print(f"{var}: {value}")
         print(f"{var}: {value}")
 
 
 class Trainer:
 class Trainer:
-    def __init__(self, batch_size, train_dir, val_dir, name, checkpoint):
+    def __init__(self, batch_size, train_dir, val_dir, name, checkpoint:bool=False):
         # 定义一些参数
         # 定义一些参数
         self.name = name  # 采用的模型名称
         self.name = name  # 采用的模型名称
         self.img_size = int(os.getenv('IMG_INPUT_SIZE', 224))  # 输入图片尺寸
         self.img_size = int(os.getenv('IMG_INPUT_SIZE', 224))  # 输入图片尺寸
@@ -48,8 +48,15 @@ class Trainer:
         else:
         else:
             self.device = torch.device("cpu")
             self.device = torch.device("cpu")
             print("CUDA不可用,使用CPU进行训练")
             print("CUDA不可用,使用CPU进行训练")
-        self.checkpoint = checkpoint
         self.__global_step = 0
         self.__global_step = 0
+        self.best_val_acc = 0.0
+        self.epoch = 0
+        self.best_val_loss = float('inf')
+        self.checkpoint_root_path = './checkpoints'
+        self.best_acc_model_path = os.path.join(self.checkpoint_root_path , f'{self.name}_best_model_acc.pth')
+        self.best_loss_model_path = os.path.join(self.checkpoint_root_path , f'{self.name}_best_model_loss.pth')
+        self.latest_checkpoint_path = os.path.join(self.checkpoint_root_path , f'{self.name}_latest_checkpoint.pth')
+
         self.workers = int(os.getenv('WORKERS', 0))
         self.workers = int(os.getenv('WORKERS', 0))
         # 创建日志目录
         # 创建日志目录
         timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")  # 获取当前时间戳
         timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")  # 获取当前时间戳
@@ -90,8 +97,8 @@ class Trainer:
         print(f"训练集图像数量: {len(self.train_dataset)}")
         print(f"训练集图像数量: {len(self.train_dataset)}")
         print(f"验证集图像数量: {len(self.val_dataset)}")
         print(f"验证集图像数量: {len(self.val_dataset)}")
         # 创建模型
         # 创建模型
-        self.model = None
-        self.model = self.__load_model()
+        self.model = load_model(name=self.name, imagenet=self.imagenet, num_classes=self.num_classes, device=self.device)
+
         # 定义损失函数
         # 定义损失函数
         self.loss = nn.CrossEntropyLoss()  # 多分类常用的交叉熵损失
         self.loss = nn.CrossEntropyLoss()  # 多分类常用的交叉熵损失
 
 
@@ -103,61 +110,38 @@ class Trainer:
         self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
         self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
             self.optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-7,cooldown=2
             self.optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-7,cooldown=2
         )
         )
-    def __load_model(self):
-        """加载模型结构"""
-        # 加载模型
-        pretrained = True if self.imagenet else False
-        if self.name == 'resnet50':
-            self.model = resnet50(pretrained=pretrained)
-        elif self.name == 'squeezenet':
-            self.model = squeezenet1_0(pretrained=pretrained)
-        elif self.name == 'shufflenet' or self.name == 'shufflenet-x1':
-            self.model = shufflenet_v2_x1_0(pretrained=pretrained)
-        elif self.name == 'shufflenet-x2':
-            self.model = shufflenet_v2_x2_0(pretrained=False)
-            self.imagenet = False
-            print('shufflenet-x2无预训练权重,重新训练所有权重')
-        else:
-            raise ValueError(f"Invalid model name: {self.name}")
-        # 如果采用预训练的神经网络,就需要冻结特征提取层,只训练最后几层
-        if self.imagenet:
-            for param in self.model.parameters():
-                param.requires_grad = False
-        # 替换最后的分类层以适应新的分类任务
-        print(self.model)
-        print(f"正在将模型{self.name}的分类层替换为{self.num_classes}个类别")
-        if hasattr(self.model, 'fc'):
-            # ResNet系列模型
-            self.model.fc = nn.Linear(int(self.model.fc.in_features), self.num_classes, bias=True)
-        elif hasattr(self.model, 'classifier'):
-            # SqueezeNet、ShuffleNet系列模型
-            if self.name == 'squeezenet':
-                # 获取SqueezeNet的最后一个卷积层的输入通道数
-                final_conv_in_channels = self.model.classifier[1].in_channels
-                # 替换classifier为新的Sequential,将输出改为2类
-                self.model.classifier = nn.Sequential(
-                    nn.Dropout(p=0.5),
-                    nn.Conv2d(final_conv_in_channels, self.num_classes, kernel_size=(1, 1)),
-                    nn.ReLU(inplace=True),
-                    nn.AdaptiveAvgPool2d((1, 1))
-                )
-            else:
-                # Swin Transformer等模型
-                self.model.classifier = nn.Linear(int(self.model.classifier.in_features), self.num_classes, bias=True)
-        elif hasattr(self.model, 'head'):
-            # Swin Transformer使用head层
-            self.model.head = nn.Linear(int(self.model.head.in_features), self.num_classes, bias=True)
-        else:
-            raise ValueError(f"Model {self.name} does not have recognizable classifier layer")
-        print(f'模型{self.name}结构已经加载,移动到设备{self.device}')
-        # 将模型移动到GPU/cpu
-        self.model = self.model.to(self.device)
-        return self.model
+        # 加载检查点
+        if checkpoint and os.path.exists(self.latest_checkpoint_path):
+            self.load_checkpoint()
+
+    def save_checkpoint(self):
+        if not os.path.exists(self.checkpoint_root_path ):
+            os.makedirs(self.checkpoint_root_path )
+        checkpoint = {
+            'epoch': self.epoch,
+            'model_state_dict': self.model.state_dict(),
+            'optimizer_state_dict': self.optimizer.state_dict(),
+            'scheduler_state_dict': self.scheduler.state_dict(),
+            'best_val_acc': self.best_val_acc,
+            'best_val_loss': self.best_val_loss
+        }
+        torch.save(checkpoint, self.latest_checkpoint_path)
+        print(f"已保存检查点到 {self.latest_checkpoint_path}")
+
+    def load_checkpoint(self, from_where='latest'):
+        checkpoint = torch.load(self.latest_checkpoint_path)
+        self.model.load_state_dict(checkpoint['model_state_dict'])
+        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+        self.best_val_acc = checkpoint['best_val_acc']
+        self.best_val_loss = checkpoint['best_val_loss']
+        self.epoch = checkpoint['epoch'] + 1
+        print(f"从 {self.latest_checkpoint_path} 加载检查点")
+
 
 
     def train_step(self):
     def train_step(self):
         """
         """
         单轮训练函数
         单轮训练函数
-
         Args:
         Args:
 
 
         Returns:
         Returns:
@@ -174,20 +158,15 @@ class Trainer:
             # 将数据移到指定设备上
             # 将数据移到指定设备上
             inputs = inputs.to(self.device)  # b c h w
             inputs = inputs.to(self.device)  # b c h w
             labels = labels.to(self.device)  # b,
             labels = labels.to(self.device)  # b,
-
             # 清零梯度缓存
             # 清零梯度缓存
             self.optimizer.zero_grad()
             self.optimizer.zero_grad()
-
             # 前向传播
             # 前向传播
             outputs = self.model(inputs)  # b, 2
             outputs = self.model(inputs)  # b, 2
             loss = self.loss(outputs, labels) # 标量
             loss = self.loss(outputs, labels) # 标量
-
             # 反向传播
             # 反向传播
             loss.backward()
             loss.backward()
-
             # 更新参数
             # 更新参数
             self.optimizer.step()
             self.optimizer.step()
-
             # 统计信息
             # 统计信息
             batch_loss = loss.item() * inputs.size(0)  # 批损失
             batch_loss = loss.item() * inputs.size(0)  # 批损失
             self.writer.add_scalar('Batch_Loss/Train', batch_loss, self.__global_step)
             self.writer.add_scalar('Batch_Loss/Train', batch_loss, self.__global_step)
@@ -250,13 +229,10 @@ class Trainer:
             val_losses: 每轮验证损失
             val_losses: 每轮验证损失
             val_accuracies: 每轮验证准确率
             val_accuracies: 每轮验证准确率
         """
         """
-
-        best_val_acc = 0.0
-        best_val_loss = float('inf')
         # 在你的代码中调用
         # 在你的代码中调用
         print_env_variables()
         print_env_variables()
         print("开始训练...")
         print("开始训练...")
-        for epoch in range(num_epochs):
+        for epoch in range(self.epoch, num_epochs):
             print(f'Epoch {epoch + 1}/{num_epochs}')
             print(f'Epoch {epoch + 1}/{num_epochs}')
             print('-' * 20)
             print('-' * 20)
 
 
@@ -280,22 +256,25 @@ class Trainer:
 
 
 
 
             # 保存最佳模型 (基于验证准确率)
             # 保存最佳模型 (基于验证准确率)
-            if val_acc > best_val_acc:
+            if val_acc > self.best_val_acc:
                 best_val_acc = val_acc
                 best_val_acc = val_acc
                 torch.save(self.model.state_dict(), f'{self.name}_best_model_acc.pth')
                 torch.save(self.model.state_dict(), f'{self.name}_best_model_acc.pth')
                 print(f"保存了新的最佳准确率模型,验证准确率: {best_val_acc:.4f}")
                 print(f"保存了新的最佳准确率模型,验证准确率: {best_val_acc:.4f}")
             
             
             # 保存最低验证损失模型
             # 保存最低验证损失模型
-            if val_loss < best_val_loss:
+            if val_loss < self.best_val_loss:
                 best_val_loss = val_loss
                 best_val_loss = val_loss
                 torch.save(self.model.state_dict(), f'{self.name}_best_model_loss.pth')
                 torch.save(self.model.state_dict(), f'{self.name}_best_model_loss.pth')
                 print(f"保存了新的最低损失模型,验证损失: {best_val_loss:.4f}")
                 print(f"保存了新的最低损失模型,验证损失: {best_val_loss:.4f}")
 
 
+            self.save_checkpoint()
+            self.epoch += 1
+
 
 
         # 关闭 TensorBoard writer
         # 关闭 TensorBoard writer
         self.writer.close()
         self.writer.close()
         
         
-        print(f"训练完成! 最佳验证准确率: {best_val_acc:.4f}, 最低验证损失: {best_val_loss:.4f}")
+        print(f"训练完成! 最佳验证准确率: {self.best_val_acc:.4f}, 最低验证损失: {self.best_val_loss:.4f}")
         return 1
         return 1
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
@@ -305,11 +284,12 @@ if __name__ == '__main__':
     parser.add_argument('--train_dir',default='./label_data/train',help='help')
     parser.add_argument('--train_dir',default='./label_data/train',help='help')
     parser.add_argument('--val_dir', default='./label_data/test',help='help')
     parser.add_argument('--val_dir', default='./label_data/test',help='help')
     parser.add_argument('--model', default='squeezenet',help='help')
     parser.add_argument('--model', default='squeezenet',help='help')
+    parser.add_argument('--resume', action='store_true',help='是否恢复继续训练')
     args = parser.parse_args()
     args = parser.parse_args()
     num_epochs = 100
     num_epochs = 100
     trainer = Trainer(batch_size=int(os.getenv('BATCH_SIZE', 32)),
     trainer = Trainer(batch_size=int(os.getenv('BATCH_SIZE', 32)),
                       train_dir=args.train_dir,
                       train_dir=args.train_dir,
                       val_dir=args.val_dir,
                       val_dir=args.val_dir,
                       name=args.model,
                       name=args.model,
-                      checkpoint=False)
+                      checkpoint=args.resume)
     trainer.train_and_validate(num_epochs)
     trainer.train_and_validate(num_epochs)

+ 4 - 52
video_test.py

@@ -3,7 +3,7 @@ import time
 import torch
 import torch
 import torch.nn as nn
 import torch.nn as nn
 from torchvision import transforms
 from torchvision import transforms
-from torchvision.models import resnet18,resnet50, squeezenet1_0, shufflenet_v2_x1_0
+from model.model_zoon import load_model
 import numpy as np
 import numpy as np
 from PIL import Image
 from PIL import Image
 import os
 import os
@@ -25,61 +25,13 @@ class Predictor:
         self.model_name = model_name
         self.model_name = model_name
         self.weights_path = weights_path
         self.weights_path = weights_path
         self.num_classes = num_classes
         self.num_classes = num_classes
-        self.model = None
-        self.use_bias = os.getenv('USE_BIAS', True)
+        # self.use_bias = os.getenv('USE_BIAS', True)
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         print(f"当前设备: {self.device}")
         print(f"当前设备: {self.device}")
-        # 加载模型
-        self.load_model()
-
+        self.model = self.load_model()
 
 
     def load_model(self):
     def load_model(self):
-        if self.model is not None:
-            return
-        print(f"正在加载模型: {self.model_name}")
-        # 加载模型
-        if self.model_name== 'resnet50':
-            self.model = resnet50()
-        elif self.model_name == 'squeezenet':
-
-            self.model = squeezenet1_0()
-        elif self.model_name == 'shufflenet':
-            self.model = shufflenet_v2_x1_0()
-        else:
-            raise ValueError(f"Invalid model name: {self.model_name}")
-        # 替换最后的分类层以适应新的分类任务
-        if hasattr(self.model, 'fc'):
-            # ResNet系列模型
-            self.model.fc = nn.Linear(int(self.model.fc.in_features), self.num_classes, bias=self.use_bias)
-        elif hasattr(self.model, 'classifier'):
-            # SqueezeNet、ShuffleNet系列模型
-            if self.model_name == 'squeezenet':
-                # 获取SqueezeNet的最后一个卷积层的输入通道数
-                final_conv_in_channels = self.model.classifier[1].in_channels
-                # 替换classifier为新的Sequential,将输出改为2类
-                self.model.classifier = nn.Sequential(
-                    nn.Dropout(p=0.5),
-                    nn.Conv2d(final_conv_in_channels, self.num_classes, kernel_size=(1, 1)),
-                    nn.ReLU(inplace=True),
-                    nn.AdaptiveAvgPool2d((1, 1))
-                )
-            else:
-                # Swin Transformer等模型
-                self.model.classifier = nn.Linear(int(self.model.classifier.in_features), self.num_classes, bias=True)
-        elif hasattr(self.model, 'head'):
-            # Swin Transformer使用head层
-            self.model.head = nn.Linear(int(self.model.head.in_features), self.num_classes, bias=self.use_bias)
-
-        else:
-            raise ValueError(f"Model {self.model_name} does not have recognizable classifier layer")
-        print(self.model)
-        # 加载训练好的权重
-        self.model.load_state_dict(torch.load(self.weights_path, map_location=torch.device('cpu')))
-        print(f"成功加载模型参数: {self.weights_path}")
-        # 将模型移动到GPU
-        self.model.eval()
-        self.model = self.model.to(self.device)
-        print(f"成功加载模型: {self.model_name}")
+        return load_model(name=self.model_name, num_classes=self.num_classes, weights_path=self.weights_path, device=self.device)
 
 
     def predict(self, image_tensor):
     def predict(self, image_tensor):
         """
         """

+ 295 - 0
video_test_rtsp.py

@@ -0,0 +1,295 @@
+import time
+
+import torch
+import torch.nn as nn
+from torchvision import transforms
+from torchvision.models import resnet18,resnet50, squeezenet1_0, shufflenet_v2_x1_0
+import numpy as np
+from PIL import Image
+import os
+import argparse
+from labelme.utils import draw_grid, draw_predict_grid
+import cv2
+import matplotlib.pyplot as plt
+from dotenv import load_dotenv
+load_dotenv()
+# os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
+patch_w = int(os.getenv('PATCH_WIDTH', 256))
+patch_h = int(os.getenv('PATCH_HEIGHT', 256))
+confidence_threshold = float(os.getenv('CONFIDENCE_THRESHOLD', 0.80))
+scale = 2
+
+
class Predictor:
    """Loads a turbidity classifier backbone and runs batched patch inference."""

    def __init__(self, model_name, weights_path, num_classes):
        self.model_name = model_name
        self.weights_path = weights_path
        self.num_classes = num_classes
        self.model = None
        # BUG FIX: os.getenv returns a *string* when the variable is set, and
        # bool('False') is True — so USE_BIAS=False silently behaved like True.
        # Parse the env var explicitly into a real boolean.
        self.use_bias = os.getenv('USE_BIAS', 'True').strip().lower() in ('1', 'true', 'yes')
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"当前设备: {self.device}")
        # Build the network and load the trained weights once at construction.
        self.load_model()

    def load_model(self):
        """Build the backbone, swap in a ``num_classes`` head and load weights.

        Raises:
            ValueError: if ``model_name`` is unknown or the backbone exposes
                no recognizable classifier layer.
        """
        if self.model is not None:
            return  # already loaded — idempotent
        print(f"正在加载模型: {self.model_name}")
        # Select the backbone architecture.
        if self.model_name == 'resnet50':
            self.model = resnet50()
        elif self.model_name == 'squeezenet':
            self.model = squeezenet1_0()
        elif self.model_name == 'shufflenet':
            self.model = shufflenet_v2_x1_0()
        else:
            raise ValueError(f"Invalid model name: {self.model_name}")
        # Replace the final classification layer for the new task.
        if hasattr(self.model, 'fc'):
            # ResNet family.
            self.model.fc = nn.Linear(int(self.model.fc.in_features), self.num_classes, bias=self.use_bias)
        elif hasattr(self.model, 'classifier'):
            # SqueezeNet / ShuffleNet family.
            if self.model_name == 'squeezenet':
                # SqueezeNet classifies via a 1x1 conv; rebuild the classifier
                # head so it emits num_classes channels.
                final_conv_in_channels = self.model.classifier[1].in_channels
                self.model.classifier = nn.Sequential(
                    nn.Dropout(p=0.5),
                    nn.Conv2d(final_conv_in_channels, self.num_classes, kernel_size=(1, 1)),
                    nn.ReLU(inplace=True),
                    nn.AdaptiveAvgPool2d((1, 1))
                )
            else:
                # Models with a Linear `classifier` attribute.
                self.model.classifier = nn.Linear(int(self.model.classifier.in_features), self.num_classes, bias=True)
        elif hasattr(self.model, 'head'):
            # Swin-Transformer-style models use a `head` layer.
            self.model.head = nn.Linear(int(self.model.head.in_features), self.num_classes, bias=self.use_bias)
        else:
            raise ValueError(f"Model {self.model_name} does not have recognizable classifier layer")
        print(self.model)
        # CPU map_location so GPU-trained weights load on any host.
        self.model.load_state_dict(torch.load(self.weights_path, map_location=torch.device('cpu')))
        print(f"成功加载模型参数: {self.weights_path}")
        # Inference mode, then move to the selected device.
        self.model.eval()
        self.model = self.model.to(self.device)
        print(f"成功加载模型: {self.model_name}")

    def predict(self, image_tensor):
        """Classify a batch of preprocessed patches.

        Args:
            image_tensor: tensor of shape (batch, C, H, W), already normalized.

        Returns:
            (confidence, predicted_class): per-patch max softmax probability
            and its class index, both numpy arrays of shape (batch,).
        """
        image_tensor = image_tensor.to(self.device)
        with torch.no_grad():
            outputs = self.model(image_tensor)
            probabilities = torch.softmax(outputs, dim=1)  # row-wise softmax
            confidence, predicted_class = torch.max(probabilities, 1)
        return confidence.cpu().numpy(), predicted_class.cpu().numpy()
+
+
def preprocess_image(img):
    """Tile a PIL image into patch tensors matching the training pipeline.

    The image is split into a grid of patch_w x patch_h crops (edge patches
    may be smaller); each crop is resized to 224x224 and ImageNet-normalized.
    Commented-out debug/plotting code from the original has been removed.

    Args:
        img: PIL.Image, assumed RGB — confirm at the caller.

    Returns:
        (imgs_index, imgs_patch): list of (left, top) origins for each patch,
        and a stacked tensor of shape (num_patches, 3, 224, 224).
    """
    # Same preprocessing steps as used at training time.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    img_w, img_h = img.size
    global patch_w, patch_h
    imgs_patch = []
    imgs_index = []
    for i in range(img_h // patch_h + 1):
        for j in range(img_w // patch_w + 1):
            left = j * patch_w          # patch origin, x
            top = i * patch_h           # patch origin, y
            right = min(left + patch_w, img_w)   # clipped to image width
            bottom = min(top + patch_h, img_h)   # clipped to image height
            # Skip the degenerate extra row/column produced by the `+ 1`
            # when the image size divides the patch size exactly.
            if right > left and bottom > top:
                imgs_patch.append(img.crop((left, top, right, bottom)))
                imgs_index.append((left, top))

    imgs_patch = torch.stack([transform(p) for p in imgs_patch])
    return imgs_index, imgs_patch
+
+
def visualize_prediction(image_path, predicted_class, confidence, class_names):
    """Display an image with its predicted class and confidence as the title.

    Args:
        image_path: path of the image to show.
        predicted_class: index into ``class_names``.
        confidence: prediction confidence in [0, 1].
        class_names: list of human-readable class labels.
    """
    img = Image.open(image_path).convert('RGB')

    plt.figure(figsize=(8, 6))
    plt.imshow(img)
    plt.axis('off')
    title_text = (f'Predicted: {class_names[predicted_class]}\n'
                  f'Confidence: {confidence:.4f}')
    plt.title(title_text, fontsize=14)
    plt.show()
+
def get_33_patch(arr: np.ndarray, center_row: int, center_col: int):
    """Extract the (up to) 3x3 neighborhood of ``arr`` centered at a cell.

    The window is clipped at the array borders, so edge/corner centers
    yield smaller (2x3 / 3x2 / 2x2) patches.
    """
    n_rows, n_cols = arr.shape
    row_lo = max(0, center_row - 1)
    row_hi = min(n_rows, center_row + 2)
    col_lo = max(0, center_col - 1)
    col_hi = min(n_cols, center_col + 2)
    return arr[row_lo:row_hi, col_lo:col_hi]
+
+
def fileter_prediction(predicted_class, confidence, pre_rows, pre_cols, filter_down_limit=3):
    """Spatially filter per-patch predictions to suppress isolated alarms.

    A patch kept as turbid (class 1) must have at least ``filter_down_limit``
    turbid patches (itself included) within its 3x3 neighborhood; otherwise it
    is reset to background with confidence 1.0.

    Args:
        predicted_class: flat array of 0/1 labels, length pre_rows * pre_cols.
        confidence: flat array of confidences, same length.
        pre_rows: rows of the patch grid.
        pre_cols: columns of the patch grid.
        filter_down_limit: minimum turbid count required in the 3x3 window.

    Returns:
        (filtered_confidence, filtered_class) as flat numpy arrays.
    """
    # BUG FIX: np.resize silently tiles/truncates data on a size mismatch,
    # masking upstream shape errors; np.reshape enforces the documented
    # length contract and raises loudly if it is violated.
    predicted_class_mat = np.reshape(predicted_class, (pre_rows, pre_cols))
    predicted_conf_mat = np.reshape(confidence, (pre_rows, pre_cols))
    new_predicted_class_mat = predicted_class_mat.copy()
    new_predicted_conf_mat = predicted_conf_mat.copy()
    for i in range(pre_rows):
        for j in range(pre_cols):
            if (1. - predicted_class_mat[i, j]) > 0.1:
                continue  # background patch — nothing to filter
            core_region = get_33_patch(predicted_class_mat, i, j)
            if np.sum(core_region) < filter_down_limit:
                new_predicted_class_mat[i, j] = 0   # reset to background
                new_predicted_conf_mat[i, j] = 1.0  # treat as confident background
    return new_predicted_conf_mat.flatten(), new_predicted_class_mat.flatten()
+
def discriminate_ratio(water_pre_list: list):
    """Strategy 1: flag turbidity when >= 60% of frames agree per patch.

    A patch counts as turbid if it was predicted positive in at least 60%
    of the frames; the period is flagged when more than two patches qualify.
    """
    stacked = np.asarray(water_pre_list, dtype=np.float32)
    per_patch_hits = np.sum(stacked, axis=0)
    frame_threshold = 0.6 * len(water_pre_list)
    turbid_patches = (per_patch_hits >= frame_threshold).astype(np.int32)
    # More than two qualifying patches are required to raise the flag.
    bad_flag = bool(np.sum(turbid_patches, dtype=np.int32) > 2)
    print(f'浑浊比例判别:该时间段是否存在浑浊水体:{bad_flag}')
    return bad_flag
+
+
def discriminate_count(pre_class_arr, continuous_count_mat):
    """Strategy 2: consecutive-frame discrimination (mutates the counters).

    Per-patch counters grow by 1 on a positive prediction and decay by 1 on
    a negative one (floored at 0). The period is flagged when more than two
    patches have been positive for over 15 consecutive-ish frames.
    """
    labels = np.array(pre_class_arr, dtype=np.int32)
    # Grow counters where turbidity was detected, decay them elsewhere.
    continuous_count_mat[labels == 0] -= 1
    continuous_count_mat[labels > 0] += 1
    # Floor at zero so decay never goes negative.
    np.maximum(continuous_count_mat, 0, out=continuous_count_mat)
    bad_flag = bool(np.sum(continuous_count_mat > 15) > 2)
    print(f'连续帧方式:该时间段是否存在浑浊水体:{bad_flag}')
    return bad_flag
+
def main():
    """Run turbidity detection over a folder of frames and visualize results.

    Per frame: tile into patches -> classify -> confidence-threshold filter ->
    3x3 spatial filter -> draw prediction grids -> combine a consecutive-frame
    vote with a 60%-ratio vote into the final turbidity flag.
    """

    # Initialize the model instance.
    # TODO: change the model network name / weights path / video path.
    predictor = Predictor(model_name='shufflenet',
                          weights_path=r'./shufflenet.pth',
                          num_classes=2)
    input_path = r'D:\code\water_turbidity_det\frame_data\1video_20251229124533_hunzhuo'
    # Collect the frame images (jpg/png only).
    all_imgs = os.listdir(input_path)
    all_imgs = [os.path.join(input_path, p) for p in all_imgs if p.split('.')[-1] in ['jpg', 'png']]
    image = Image.open(all_imgs[0]).convert('RGB')
    # Grid shape used when reshaping the flat predictions into a matrix.
    pre_rows = image.height // patch_h + 1
    pre_cols = image.width // patch_w + 1
    # Display size (half resolution) for the preview windows.
    resized_img_h = image.height // 2
    resized_img_w = image.width // 2
    # Predict every image.

    water_pre_list = []
    continuous_count_mat = np.zeros(pre_rows*pre_cols, dtype=np.int32)
    flag = False
    for img_path in all_imgs:
        image = Image.open(img_path).convert('RGB')
        # Preprocess into patch tensors.
        patches_index, image_tensor = preprocess_image(image) # patches_index:list[tuple, ...]
        # Inference.
        confidence, predicted_class  = predictor.predict(image_tensor)  # confidence: np.ndarray, shape=(x,), predicted_class: np.ndarray, shape=(x,), raw_outputs: np.ndarray, shape=(x,)
        # First false-alarm suppression: positives below the confidence
        # threshold are reset to background.
        for i in range(len(confidence)):
            if confidence[i] < confidence_threshold and predicted_class[i] == 1:
                confidence[i] = 1.0
                predicted_class[i] = 0
        # Second false-alarm suppression: spatial filtering.
        # Additional filtering logic goes here.
        # print('原始预测结果:', predicted_class)
        new_confidence, new_predicted_class = fileter_prediction(predicted_class, confidence, pre_rows, pre_cols, filter_down_limit=3)
        # print('过滤后预测结果:', new_predicted_class)
        # Visualize raw vs. filtered predictions side by side.
        image = cv2.imread(img_path)
        image = draw_grid(image, patch_w, patch_h)
        image = draw_predict_grid(image, patches_index, predicted_class, confidence)

        new_image = cv2.imread(img_path)
        new_image = draw_grid(new_image, patch_w, patch_h)
        new_image = draw_predict_grid(new_image, patches_index, new_predicted_class, new_confidence)
        image = cv2.resize(image, (resized_img_w, resized_img_h))
        new_img = cv2.resize(new_image, (resized_img_w, resized_img_h))

        cv2.imshow('image', image)
        cv2.imshow('image_filter', new_img)

        cv2.waitKey(25)
        water_pre_list.append(new_predicted_class)
        # Strategy 2: consecutive-frame discrimination (stateful counters).
        flag = discriminate_count(new_predicted_class, continuous_count_mat)
        # Strategy 1: ratio discrimination — needs more than 25 frames of history.
        if len(water_pre_list) > 25:
            flag = discriminate_ratio(water_pre_list) and flag
        print('综合判别结果:', flag)



if __name__ == "__main__":
    main()