jiyuhang пре 3 месеци
родитељ
комит
2642348869

+ 16 - 0
.env

@@ -0,0 +1,16 @@
+# 图块高度
+PATCH_WIDTH=256
+# 图块宽度
+PATCH_HEIGHT=256
+# 测试精度时置信度阈值
+CONFIDENCE_THRESHOLD=0.6
+# 图片输入大小
+IMG_INPUT_SIZE=256
+# 工作线程数
+WORKERS=0
+# CUDA设备
+CUDA_VISIBLE_DEVICES=0
+# batch size
+BATCH_SIZE=128
+# pretrained
+PRETRAINED=True

+ 2 - 1
.gitignore

@@ -47,6 +47,7 @@ __pycache__/
 *.npz
 *.prototxt
 *.bmodel
-*.josn
+*.json
+*.profile
 # 日志文件
 runs/*

+ 106 - 0
bmodel_test.py

@@ -0,0 +1,106 @@
+import argparse
+import sophon.sail as sail
+import cv2
+import logging
+import numpy as np
+
class Predictor:
    """Thin wrapper around a Sophon SAIL engine for single-batch image classification."""

    def __init__(self, bmodel=None, dev_id=None):
        """Load the bmodel and cache the graph's I/O metadata.

        Args:
            bmodel: path to the compiled bmodel; defaults to the CLI ``--bmodel``.
            dev_id: TPU device id; defaults to the CLI ``--dev_id``.
        """
        # Fall back to the module-level CLI args so existing `Predictor()`
        # call sites keep working unchanged.
        if bmodel is None:
            bmodel = args.bmodel
        if dev_id is None:
            dev_id = args.dev_id
        self.net = sail.Engine(bmodel, dev_id, sail.IOMode.SYSIO)
        self.graph_name = self.net.get_graph_names()[0]
        self.input_names = self.net.get_input_names(self.graph_name)
        self.input_shapes = [self.net.get_input_shape(self.graph_name, name) for name in self.input_names]
        self.output_names = self.net.get_output_names(self.graph_name)
        self.output_shapes = [self.net.get_output_shape(self.graph_name, name) for name in self.output_names]  # e.g. [[1, 2]]
        self.input_name = self.input_names[0]
        self.input_shape = self.input_shapes[0]  # NCHW, e.g. [1, 3, 256, 256]
        self.batch_size = self.input_shape[0]
        self.net_h = self.input_shape[2]  # network input height
        self.net_w = self.input_shape[3]  # network input width
        # ImageNet normalisation constants (RGB order), matching the
        # pretrained backbone's training pipeline.
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        self.print_network_info()

    def __call__(self, img):
        return self.predict(img)

    def print_network_info(self):
        """Dump the cached network metadata to stdout for debugging."""
        info = {
            'Graph Name': self.graph_name,
            'Input Name': self.input_name,
            'Output Names': self.output_names,
            'Output Shapes': self.output_shapes,
            'Input Shape': self.input_shape,
            'Batch Size': self.batch_size,
            'Height': self.net_h,
            'Width': self.net_w,
            'Mean': self.mean,
            'Std': self.std
        }

        print("=" * 50)
        print("Network Configuration Info")
        print("=" * 50)
        for key, value in info.items():
            print(f"{key:<18}: {value}")
        print("=" * 50)

    def predict(self, input_img):
        """Run the engine on an already-preprocessed NCHW float batch and
        return the first output tensor."""
        input_data = {self.input_name: input_img}
        outputs = self.net.process(self.graph_name, input_data)
        print('predict fun:', outputs)
        print('predict fun return:', list(outputs.values())[0])
        return list(outputs.values())[0]

    def preprocess(self, img):
        """Convert a BGR uint8 image into a normalised CHW float32 tensor."""
        h, w, _ = img.shape
        # Resize to the network's actual input size. The old code compared
        # against a hard-coded 256 and silently skipped resizing for models
        # with any other input size.
        if h != self.net_h or w != self.net_w:
            img = cv2.resize(img, (self.net_w, self.net_h))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.astype('float32')
        # Normalisation must match training; the bmodel compilation does NOT
        # bake it in, so skipping this step ruins prediction accuracy.
        img = (img / 255 - self.mean) / self.std
        img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
        return img

    def postprocess(self, outputs):
        """Apply a numerically stable softmax and print the argmax class.

        Args:
            outputs: (batch, num_classes) raw logits.
        Returns:
            (batch, num_classes) softmax probabilities.
        """
        # Subtracting the per-row max before exp() prevents overflow for
        # large logits; the softmax result is mathematically unchanged.
        shifted = outputs - np.max(outputs, axis=1, keepdims=True)
        outputs_exp = np.exp(shifted)
        probs = outputs_exp / np.sum(outputs_exp, axis=1, keepdims=True)
        print('softmax res:', probs)
        predictions = np.argmax(probs, axis=1)
        print('预测结果:', predictions)
        return probs
+
def main(args):
    """Run single-image inference: read, preprocess, predict, postprocess.

    Args:
        args: parsed CLI namespace with ``input``, ``bmodel`` and ``dev_id``.
    """
    predictor = Predictor()
    filename = args.input
    src_img = cv2.imread(filename, cv2.IMREAD_COLOR)
    # Guard immediately after imread: the old code checked for None only
    # AFTER preprocess()/np.stack(), which would already have crashed.
    if src_img is None:
        logging.error("{} imread is None.".format(filename))
        return
    src_img = predictor.preprocess(src_img)
    src_img = np.stack([src_img])  # add the batch dimension
    print('图像输入shape:', src_img.shape)
    res = predictor(src_img)
    print('预测结果res:', res)
    predictor.postprocess(res)
+
+
def argsparser():
    """Build and evaluate the command-line interface for this script."""
    parser = argparse.ArgumentParser(prog=__file__)
    parser.add_argument('--input', type=str, default='./000000_256_512_1.jpg',
                        help='path of input, must be image directory')
    parser.add_argument('--bmodel', type=str, default='./shufflenet_f32.bmodel',
                        help='path of bmodel')
    parser.add_argument('--dev_id', type=int, default=0, help='tpu id')
    return parser.parse_args()
+
if __name__ == '__main__':
    # Script entry point: parse CLI arguments, then run one inference.
    args = argsparser()
    main(args)

+ 408 - 0
calculator.py

@@ -0,0 +1,408 @@
+import cv2
+import numpy as np
+import matplotlib
+matplotlib.use('Agg')
+from matplotlib import pyplot as plt
+from  matplotlib import rcParams
+import seaborn as sns
+import threading
+
+
def set_chinese_font():
    """Configure matplotlib so CJK labels and minus signs render correctly."""
    rcParams['axes.unicode_minus'] = False
    rcParams['font.sans-serif'] = ['WenQuanYi Micro Hei', 'Microsoft YaHei', 'SimSun']
+
def region_ndti(img, mask, is_night=False):
    """Compute the NDTI map of *img* restricted to *mask*.

    Daylight (visible-light) frames use NDTI = (R - G) / (R + G); infrared
    night frames use the red channel value directly.

    Args:
        img: 3-channel BGR image.
        mask: single-channel mask; non-zero pixels are kept. Values above 2
            are treated as a 0-255 mask and normalised to 0-1.
        is_night: True for infrared frames. Defaults to False so existing
            two-argument call sites (e.g. ``single_img``) no longer raise
            TypeError.
    Returns:
        float32 array of NDTI values, zero outside the mask.
    Raises:
        RuntimeError: if img is not 3-channel or mask is not single-channel.
    """
    if len(img.shape) != 3:
        raise RuntimeError('请输入三通道彩色图像')
    if len(mask.shape) != 2:
        raise RuntimeError('请输入单通道掩膜图像')

    img = img.astype(np.float32)
    mask = mask.astype(np.float32)
    g = img[:, :, 1]
    r = img[:, :, 2]
    if is_night:
        # Infrared: the raw sensor value itself is used as the proxy.
        ndti = r
    else:
        # Epsilon keeps the division finite on black pixels (r + g == 0).
        ndti = (r - g) / (r + g + 1e-6)
    # Accept both {0, 1} and {0, 255} masks.
    if np.max(mask) > 2:
        mask /= 255.
    # Zero out everything outside the region of interest.
    ndti *= mask
    return ndti
+
def plot_histogram_opencv(matrix, start, end, step, title):
    """Draw a histogram of *matrix* onto a BGR canvas with OpenCV primitives.

    Using OpenCV instead of matplotlib avoids GUI-thread conflicts.
    Returns a 450x600 image: red bars, black axes, *title* in the corner.
    """
    bin_count = int((end - start) / step)
    counts, _ = np.histogram(matrix.ravel(), bins=bin_count, range=(start, end))

    # White canvas; extra height below the baseline leaves room for labels.
    canvas = np.full((450, 600, 3), 255, dtype=np.uint8)

    # Scale bar heights into [0, 350] so they always fit above the x-axis.
    if len(counts) > 0 and np.max(counts) > 0:
        scaled = cv2.normalize(counts, None, 0, 350, cv2.NORM_MINMAX)
    else:
        scaled = np.zeros_like(counts)

    # Bar width: at least 1 pixel, otherwise split the 600 px width evenly.
    bar_w = max(1, 600 // len(counts)) if len(counts) > 0 else 1

    for idx in range(len(scaled)):
        if idx >= 600:  # canvas is only 600 px wide
            break
        left = idx * bar_w
        right = min(left + bar_w, 599)       # clamp to the canvas edge
        top = 400 - int(scaled[idx])         # bars grow upward from y=400
        if top < 400:                        # skip zero-height bars
            cv2.rectangle(canvas, (left, top), (right, 400), (0, 0, 255), -1)

    # Axes.
    cv2.line(canvas, (50, 400), (550, 400), (0, 0, 0), 2)
    cv2.line(canvas, (50, 400), (50, 50), (0, 0, 0), 2)

    # Title and axis labels, placed to avoid clipping.
    cv2.putText(canvas, title, (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)
    cv2.putText(canvas, 'gray', (280, 430),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
    cv2.putText(canvas, 'fre', (5, 200),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)

    return canvas
+
def plot_histogram_seaborn_simple(matrix, start, end, step, title):
    """Render a histogram of *matrix* with seaborn, save it to a temp PNG,
    then display that PNG through an OpenCV window.
    """
    # Make CJK axis labels render correctly.
    plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans']
    plt.rcParams['axes.unicode_minus'] = False

    plt.figure(figsize=(12, 6))

    # Let seaborn compute the histogram; bin count derives from the range/step.
    bin_count = int((end - start) / step)
    sns.histplot(matrix.ravel(), bins=bin_count, kde=True,
                 color='skyblue', alpha=0.7,
                 edgecolor='white', linewidth=0.5)

    plt.title(title + '_histplot')
    plt.xlabel('灰度值')
    plt.ylabel('像素频率')
    plt.tight_layout()

    # Save to a file instead of plt.show() to dodge GUI-thread conflicts,
    # then hand the rendered image to OpenCV for display.
    out_name = 'temp_' + title + '.png'
    plt.savefig(out_name)
    plt.close()  # release the figure's resources

    rendered = cv2.imread(out_name)
    cv2.imshow('Histogram ' + title, rendered)
+
def plot_histogram_seaborn(matrix, hist, bin_edges, start, end, step):
    """Plot a precomputed histogram (hist / bin_edges) as bars plus a KDE curve."""
    sns.set_style("whitegrid")
    plt.figure(figsize=(12, 6))

    samples = matrix.ravel()

    # Bars are drawn manually so the caller's custom bins are honoured;
    # each bar sits at the midpoint of its bin.
    centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    plt.bar(centers, hist, width=step * 0.8,
            alpha=0.7, color=sns.color_palette("viridis", 1)[0],
            edgecolor='white', linewidth=0.5)

    # Overlay a kernel-density estimate of the raw samples.
    sns.kdeplot(samples, color='red', linewidth=2, label='密度曲线')

    plt.title(f'灰度分布直方图', fontsize=16)
    plt.xlabel('灰度值', fontsize=12)
    plt.ylabel('像素频率', fontsize=12)
    plt.legend()
    plt.tight_layout()
    plt.show()
+
+
+
def colorful_ndti_matrix(ndti_matrix):
    """Map an NDTI matrix to a BGR image: negatives shade blue, positives red.

    Values in [-1, 0) scale to blue intensity 0-255; values in (0, 1] scale
    to red intensity 0-255; zeros stay black.
    """
    rows, cols = ndti_matrix.shape[:2]
    colored = np.zeros((rows, cols, 3), dtype=np.uint8)

    neg = ndti_matrix < 0
    pos = ndti_matrix > 0

    # Blue channel: magnitude of the negative values.
    if neg.any():
        blue = np.where(neg, ndti_matrix * -255., 0.)
        colored[:, :, 0] = np.clip(blue, 0, 255).astype(np.uint8)

    # Red channel: magnitude of the positive values.
    if pos.any():
        red = np.where(pos, ndti_matrix * 255., 0.)
        colored[:, :, 2] = np.clip(red, 0, 255).astype(np.uint8)

    return colored
+
def img_add(img1, img2, img1_w=1, img2_w=0, gamma=0):
    """Blend two same-rank images: img1 * img1_w + img2 * img2_w + gamma.

    Raises:
        ValueError: if the two images do not have the same number of dims.
    """
    if len(img1.shape) != len(img2.shape):
        raise ValueError('img1 and img2 must have the same shape')
    # cv2.addWeighted handles saturation/clipping for uint8 inputs.
    return cv2.addWeighted(img1, img1_w, img2, img2_w, gamma)
def callback(event, x, y, flags, param):
    """Mouse callback: on left click, print the clicked pixel's NDTI on the frame.

    Args:
        param: dict carrying 'Image' (BGR frame), 'window_name' and 'is_night'.
    """
    if event == cv2.EVENT_LBUTTONDOWN:
        is_night_ = param['is_night']
        frame = param['Image']
        w_name = param['window_name']
        g = frame[:, :, 1].astype(np.float32)
        r = frame[:, :, 2].astype(np.float32)
        if is_night_:
            # Infrared: the red channel value is used directly.
            ndti_value = r[y, x]
        else:
            # Epsilon keeps the division finite on black pixels (r + g == 0),
            # matching region_ndti()'s formula; the old code divided by zero.
            ndti_value = (r[y, x] - g[y, x]) / (r[y, x] + g[y, x] + 1e-6)
        cv2.putText(frame, f'{ndti_value:.2f}', (x, y), cv2.FONT_HERSHEY_DUPLEX, 0.6, (0, 255, 0), 2, cv2.LINE_AA)
        cv2.putText(frame, f'paused', (10, 45), cv2.FONT_HERSHEY_DUPLEX, 0.6, (0, 255, 0), 2, cv2.LINE_AA)
        cv2.imshow(w_name, frame)
def single_img(img_path, mask_path, id=0):
    """Analyse one still image: NDTI over the masked region plus statistics.

    Shows the original, the mask, the NDTI intensity/colour maps and an
    overlay; blocks until a key is pressed.

    Args:
        img_path: path of the BGR image to analyse.
        mask_path: path of the single-channel ROI mask.
        id: camera id; controls the display downscale factor.
    """
    frame = cv2.imread(img_path)
    if frame is None:
        raise RuntimeError('img open failed')
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise RuntimeError('mask open failed')
    # Camera 3 appears to stream at a different resolution -- TODO confirm.
    scale = 4 if id != 3 else 2
    mask = cv2.resize(mask, (mask.shape[1] // scale, mask.shape[0] // scale))
    frame = cv2.resize(frame, (frame.shape[1] // scale, frame.shape[0] // scale))
    # Still captures are treated as daylight frames. The old call omitted
    # the required is_night argument and raised TypeError.
    roi_ndti = region_ndti(frame, mask, False)
    # Colour the NDTI map (blue = negative, red = positive).
    color_ndti = colorful_ndti_matrix(roi_ndti)

    # Keep only the region-of-interest values.
    roi_index = mask > 0
    roi_ndti_fla = roi_ndti[roi_index]
    if roi_ndti_fla.size == 0:
        raise RuntimeError('mask selects no pixels')
    hist_img = plot_histogram_opencv(roi_ndti_fla, -1, 1, 0.01, f'NDTI Histogram {id}')

    # 3-sigma outlier rejection before reporting statistics.
    mean_value = np.mean(roi_ndti_fla)
    std_value = np.std(roi_ndti_fla)
    min_value = np.min(roi_ndti_fla)
    max_value = np.max(roi_ndti_fla)

    up_limit = mean_value + 3 * std_value
    down_limit = mean_value - 3 * std_value
    roi_ndti_fla = roi_ndti_fla[(up_limit >= roi_ndti_fla) & (roi_ndti_fla >= down_limit)]
    # Pixel count is the NUMBER of surviving samples; the old code reported
    # int(sum(values)), which is not a count.
    text = f'pixs:{roi_ndti_fla.size} adj_mean:{roi_ndti_fla.mean():.3f} adj_std:{roi_ndti_fla.std():.3f}'
    cv2.putText(frame, text, (10, 25), cv2.FONT_HERSHEY_DUPLEX, 0.6, (0, 255, 0), 2, cv2.LINE_AA)
    print(f"""统计信息:
    总像素数: {roi_ndti_fla.size:,}
    灰度范围: {min_value:.1f} - {max_value:.1f}
    平均值: {mean_value:.2f}
    标准差: {std_value:.2f}
    调整平均值:{roi_ndti_fla.mean():.2f}
    调整标准差:{roi_ndti_fla.std():.2f}
    """
          )
    # Show the per-frame results.
    cv2.imshow('original', frame)
    cv2.imshow('mask' + str(id), mask)
    roi_ndti = np.abs(roi_ndti * 255.).astype(np.uint8)
    cv2.imshow('ndti' + str(id), roi_ndti)       # grayscale NDTI intensity
    cv2.imshow('color_ndti' + str(id), color_ndti)  # coloured NDTI intensity
    # Overlay the coloured NDTI on the original.
    add_img = img_add(frame, color_ndti)
    cv2.imshow('add_ori_ndti' + str(id), add_img)
    # callback() reads 'window_name' and 'is_night' too; the old dict only
    # carried 'Image' and raised KeyError on the first click.
    param = {'Image': frame, 'window_name': 'original', 'is_night': False}
    cv2.setMouseCallback('original', callback, param=param)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
+
def main(video_dir, mask_dir, id):
    """Stream a video, overlay the masked-region NDTI on each frame, show it.

    Controls: <space> toggles pause, 'q' quits. While paused, left-clicking
    the overlay window annotates the NDTI value at the clicked pixel.

    Args:
        video_dir: path of the video file to analyse.
        mask_dir: path of the single-channel ROI mask image.
        id: camera id; selects the downscale factor and names the windows.
    """
    # Load the region-of-interest mask.
    mask = cv2.imread(mask_dir, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise RuntimeError('mask open failed')
    # Camera 3 appears to stream at a different resolution -- TODO confirm.
    scale = 4 if id != 3 else 2
    mask = cv2.resize(mask, (mask.shape[1] // scale, mask.shape[0] // scale))

    cap = cv2.VideoCapture(video_dir)
    pause = False
    if not cap.isOpened():
        print("错误:无法打开视频文件。")
        print("可能的原因:文件路径错误,或系统缺少必要的解码器。")
        exit()
    else:
        # NOTE(review): this discards the first frame before the loop reads
        # again below -- confirm that is intentional.
        ret, frame = cap.read()

    global_img = [None]  # last rendered overlay, shared with the mouse callback
    is_night = False     # sticky: once night mode is detected it stays on
    while True:
        if not pause:
            ret, frame = cap.read()
            if not ret:
                print("视频播放完毕。")
                break
            # Night (IR) frames are grey, i.e. B and G are almost identical.
            # Cast to a signed type first: the old uint8 subtraction wrapped
            # around (e.g. 1 - 2 == 255) and made this test unreliable.
            if np.mean(np.abs(frame[:, :, 0].astype(np.int16) - frame[:, :, 1].astype(np.int16))) < 0.1:
                is_night = True
            frame = cv2.resize(frame, (frame.shape[1] // scale, frame.shape[0] // scale))
            # NDTI over the masked region, coloured for display.
            roi_ndti = region_ndti(frame, mask, is_night)
            color_ndti = colorful_ndti_matrix(roi_ndti)

            # Keep only the region-of-interest values.
            roi_ndti_fla = roi_ndti[mask > 0]

            # 3-sigma outlier rejection before reporting statistics.
            mean_value = np.mean(roi_ndti_fla)
            print('调整前平均值:', mean_value)
            std_value = np.std(roi_ndti_fla)
            # Label fixed: this prints the std, not the mean.
            print('调整前标准差:', std_value)
            up_limit = mean_value + 3 * std_value
            down_limit = mean_value - 3 * std_value
            roi_ndti_fla = roi_ndti_fla[(up_limit >= roi_ndti_fla) & (roi_ndti_fla >= down_limit)]
            # Pixel count is the NUMBER of surviving samples; the old code
            # reported int(sum(values)), which is not a count.
            text = f'pixs:{roi_ndti_fla.size} adj_mean:{roi_ndti_fla.mean():.3f} adj_std:{roi_ndti_fla.std():.3f}'
            cv2.putText(frame, text, (10, 25), cv2.FONT_HERSHEY_DUPLEX, 0.6, (0, 255, 0), 2, cv2.LINE_AA)
            # Overlay the coloured NDTI on the frame and show it.
            add_img = img_add(frame, color_ndti)
            global_img[0] = add_img
            cv2.imshow('add_ori_ndti' + str(id), add_img)

        # ~2 FPS display; poll the keyboard while waiting.
        key = cv2.waitKey(500) & 0xFF
        if key == ord(' '):
            pause = not pause
            status = "已暂停" if pause else "播放中"
            print(f"{id} 状态: {status}")
        if key == ord('q'):
            break
        if pause and global_img[0] is not None:
            # While paused, clicking the overlay probes the NDTI value.
            param = {'Image': global_img[0], 'window_name': 'add_ori_ndti' + str(id), 'is_night': is_night}
            cv2.setMouseCallback('add_ori_ndti' + str(id), callback, param=param)

    # Release the capture; windows are closed by the caller.
    cap.release()
+
+
if __name__ == '__main__':
    # Entry point: configure CJK-capable fonts, then analyse one camera
    # stream. The commented blocks below are earlier single-image and
    # multi-threaded multi-camera runs kept for reference.
    set_chinese_font()

    # single_img(r'D:\code\water_turbidity_det\frame_data\day_202511211129\1_device_capture.jpg',
    #             r'D:\code\water_turbidity_det\draw_mask\mask\1_main_20251119102036_@1000000.png')
    # pass
    # path1 = r'D:\code\water_turbidity_det\frame_data\day_202511211129\1_video_202511211128.dav'
    # path2 = r'D:\code\water_turbidity_det\frame_data\day_202511211129\2_video_202511211128.dav'
    # path3 = r'D:\code\water_turbidity_det\frame_data\day_202511211129\3_video_202511211128.dav'
    # path4 = r'D:\code\water_turbidity_det\frame_data\day_202511211129\4_video_202511211128.dav'
    #
    # path1 = r'D:\code\water_turbidity_det\frame_data\night\1_video_20251120.dav'
    # path2 = r'D:\code\water_turbidity_det\frame_data\night\2_video_20251120_1801.dav'
    # path3 = r'D:\code\water_turbidity_det\frame_data\night\3_video_20251120_1759.dav'
    # path4 = r'D:\code\water_turbidity_det\frame_data\night\4_video_20251120_1800.dav'
    #
    # t1 = threading.Thread(target=main, kwargs={'video_dir': path1,
    #                                            'mask_dir': r'D:\code\water_turbidity_det\draw_mask\mask\1_main_20251119102036_@1000000.png',
    #                                            'id':1})
    # t2 = threading.Thread(target=main, kwargs={'video_dir': path2,
    #                                            'mask_dir': r'D:\code\water_turbidity_det\draw_mask\mask\2_main_20251119102038_@1000000.png',
    #                                            'id':2})
    # t3 = threading.Thread(target=main, kwargs={'video_dir': path3,
    #                                            'mask_dir': r'D:\code\water_turbidity_det\draw_mask\mask\3_main_20251119102042_@1000000.png',
    #                                            'id':3})
    # t4 = threading.Thread(target=main, kwargs={'video_dir': path4,
    #                                            'mask_dir': r'D:\code\water_turbidity_det\draw_mask\mask\4_main_20251119102044_@1000000.png',
    #                                            'id':4})
    # # threads = [t1, t2, t3, t4]
    # threads = [t4]
    # for t in threads:
    #     t.start()
    # for t in threads:
    #     if t.is_alive():
    #         t.join()
    # NOTE(review): the video path below mixes '/' and '\' separators while
    # the mask path is an absolute Windows path -- confirm both resolve on
    # the deployment OS before running.
    main(video_dir=r'/video/records_202512031553\video4_20251129053218_20251129060528.mp4',
         mask_dir=r'D:\code\water_turbidity_det\draw_mask\mask\4_device_capture.png',
         id=4)
    cv2.destroyAllWindows()
+
+

+ 0 - 0
check/检查标注结果是否正确.md


+ 80 - 0
draw_mask/draw.py

@@ -0,0 +1,80 @@
+import cv2
+import numpy as np
+import os
+
class DrawRectangle:
    """Interactive polygon-ROI annotator.

    Left click adds a vertex, right click closes the current polygon, and
    'q' or ESC finishes the session. A binary mask (255 inside the ROIs,
    at the ORIGINAL image resolution) is then written to ``mask_save_dir``
    under the source image's base name with a .png suffix.
    """

    def __init__(self, div_scale):
        # Directory where generated mask PNGs are written.
        self.mask_save_dir = './mask'
        # Display downscale factor; vertices are scaled back up for the mask.
        self.scale = div_scale
        self.current_roi_points = []  # vertices of the polygon being drawn
        self.rois = []                # finished polygons
        self.window_title = "Image - Select ROI"
        self.draw_complete = False

    def callback(self, event, x, y, flags, param):
        """Mouse callback: left click = add vertex, right click = close ROI."""
        drawing_image = param['Image']
        # Left button: append a vertex and draw the marker/edge.
        if event == cv2.EVENT_LBUTTONDOWN:
            self.current_roi_points.append((x, y))
            cv2.circle(drawing_image, (x, y), 4, (0, 255, 0), -1)
            if len(self.current_roi_points) > 1:
                cv2.line(drawing_image, self.current_roi_points[-2], self.current_roi_points[-1], (0, 255, 0), 2)
            cv2.imshow(self.window_title, drawing_image)
            print(f'添加感兴趣点:{y}行, {x}列')
        # Right button: close the polygon (needs at least 3 vertices).
        if event == cv2.EVENT_RBUTTONDOWN:
            if len(self.current_roi_points) < 3:
                print("[提示] ROI 至少需要 3 个点构成多边形!")
                return
            cv2.line(drawing_image, self.current_roi_points[-1], self.current_roi_points[0], (0, 255, 0), 2)
            cv2.imshow(self.window_title, drawing_image)
            # Commit the polygon and start collecting the next one.
            self.rois.append(self.current_roi_points)
            print(f'添加感兴趣区,包含点数:{len(self.current_roi_points)}个')
            self.current_roi_points = []

    def draw(self, img_path: str):
        """Let the user draw polygon ROIs on the image, then write the mask PNG."""
        ori_img = cv2.imread(img_path)
        # Fail fast on the ORIGINAL image: the old code checked the resized
        # copy for None, after cv2.resize had already crashed on None input.
        if ori_img is None:
            raise RuntimeError('Cannot read the image!')
        mask_base_name = os.path.splitext(os.path.basename(img_path))[0] + '.png'
        img = cv2.resize(ori_img, (ori_img.shape[1] // self.scale, ori_img.shape[0] // self.scale))
        param = {'Image': img}
        cv2.namedWindow(self.window_title)
        cv2.setMouseCallback(self.window_title, self.callback, param=param)

        # Show the image until the user quits.
        while True:
            cv2.imshow(self.window_title, img)
            key = cv2.waitKey(1) & 0xFF
            if key == ord('q') or key == 27:  # 'q' or ESC
                break

        # Build the mask at the ORIGINAL resolution from the collected polygons.
        mask = np.zeros((ori_img.shape[0], ori_img.shape[1]), dtype=np.uint8)
        for roi in self.rois:
            # Scale the display-space vertices back to the original resolution.
            roi_points = np.array(roi, np.int32).reshape((-1, 1, 2)) * self.scale
            cv2.fillPoly(mask, [roi_points], 255)

        # Save the mask image.
        os.makedirs(self.mask_save_dir, exist_ok=True)
        cv2.imwrite(os.path.join(self.mask_save_dir, mask_base_name), mask)
        cv2.destroyAllWindows()
+
if __name__ == '__main__':
    # Example run: annotate this capture at half display resolution.
    drawer = DrawRectangle(2)
    drawer.draw(r"D:\code\water_turbidity_det\draw_mask\mask\4_device_capture.jpg")

BIN
draw_mask/mask/1_device_capture.png


BIN
draw_mask/mask/2_device_capture.png


BIN
draw_mask/mask/3_device_capture.png


BIN
draw_mask/mask/4_device_capture.png


BIN
label_data.tar.gz


+ 62 - 0
labelme/check_label.py

@@ -0,0 +1,62 @@
+"""检查标注是否正确,读取标签,信息"""
+import cv2
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+from utils import draw_grid
+from dotenv import load_dotenv
+load_dotenv() # 加载环境变量
+
patch_w = int(os.getenv('PATCH_WIDTH', 256))   # patch width in px (from .env)
patch_h = int(os.getenv('PATCH_HEIGHT', 256))  # patch height in px (from .env)
scale = 2  # display downscale applied to both images and label coordinates
clicked_points = []  # u v w h cls
+
def main():
    """Overlay the saved grid labels on every image and write check images.

    Reads <imgs_path>/label.txt rows "u,v,w,h,cls" (full-resolution coords),
    scales them to display space, draws the patch grid plus one marker and
    class id per label, and saves the result under ``output_root``.
    """
    global patch_w
    global patch_h
    global scale
    global clicked_points
    # TODO: change to the image directory that should be checked.
    imgs_path = r'/frame_data/test/3_video_202511211127'
    output_root = r'D:\code\water_turbidity_det\check'
    label_path = os.path.join(imgs_path, 'label.txt')
    if os.path.exists(label_path):
        with open(label_path, 'r') as fr:
            lines = fr.readlines()
        # Coordinates are stored at full resolution; scale them down to
        # display space once, up front.
        for line in (l.strip() for l in lines):
            point = line.split(',')
            clicked_points.append([int(point[0]) // scale, int(point[1]) // scale,
                                   int(point[2]) // scale, int(point[3]) // scale,
                                   int(point[4])])
    # Where the rendered check images go.
    output_path = os.path.join(output_root, os.path.basename(imgs_path) + '_check')
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    # Collect all jpg/png images.
    all_imgs = [f for f in os.listdir(imgs_path) if f.endswith(('.jpg', '.png'))]

    for img_name in all_imgs:
        img_path = os.path.join(imgs_path, img_name)
        img = cv2.imread(img_path)
        img = cv2.resize(img, (img.shape[1] // scale, img.shape[0] // scale))
        # Draw the patch grid.
        img = draw_grid(img, patch_w // scale, patch_h // scale)
        # Draw each label's centre marker and class id.
        for point in clicked_points:
            # Centre = corner + half the (already display-scaled) extent.
            # The old code divided by `scale`, which is only correct while
            # scale == 2; // 2 is the intended "half of width/height".
            center_x = point[0] + point[2] // 2
            center_y = point[1] + point[3] // 2
            cv2.circle(img, (center_x, center_y), 5, (0, 0, 255), -1)
            cv2.putText(img, str(point[4]), (center_x + 10, center_y + 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
        cv2.imwrite(os.path.join(output_path, img_name), img)
        print(f"检查结果保存在: {os.path.join(output_path, img_name)}")
+
if __name__ == '__main__':
    # Script entry point.
    main()

+ 84 - 0
labelme/crop_patch.py

@@ -0,0 +1,84 @@
+# 根据标注文件,生成patch,每个类别放在一个文件夹下
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
+import numpy as np
+import cv2
+import gc
+from dotenv import load_dotenv
+load_dotenv()
+
# Cut patch_w x patch_h patches out of each image and attach their labels.
patch_w = int(os.getenv('PATCH_WIDTH', 256))   # patch width in px (from .env)
patch_h = int(os.getenv('PATCH_HEIGHT', 256))  # patch height in px (from .env)
+
def main():
    """Cut labelled images into patches, one output folder per class.

    Reads <input_path>/label.txt rows "x,y,w,h,cls":
      cls < 0  -> region blacked out (excluded from training)
      cls > 0  -> region saved under its class folder, then blacked out
    Whatever remains is tiled into patch_w x patch_h blocks saved as class
    0, skipping degenerate edge strips and (near-)black tiles.

    Raises:
        FileNotFoundError: if label.txt is missing (annotation is mandatory).
    """
    # TODO: change to the directory of annotated images.
    input_path = r'/frame_data/test/20251225/video4_20251129120320_20251129123514'
    # TODO: change to the root directory where patches are written.
    output_path_root = r'D:\code\water_turbidity_det\label_data\test'

    # Read the annotation file.
    label_path = os.path.join(input_path, 'label.txt')
    if not os.path.exists(label_path):  # annotations are mandatory
        raise FileNotFoundError(f"{label_path} 不存在")
    with open(label_path, 'r') as fr:
        lines = [line.strip() for line in fr.readlines()]
    # Rebuild the annotated grid cells: [x, y, w, h, cls].
    grids_info = []
    for line in lines:
        point = line.split(',')
        grids_info.append([int(point[0]), int(point[1]), int(point[2]), int(point[3]), int(point[4])])

    # The class-0 folder always exists; add one folder per positive class.
    os.makedirs(os.path.join(output_path_root, '0'), exist_ok=True)
    for grid in grids_info:
        if grid[4] > 0:
            os.makedirs(os.path.join(output_path_root, str(grid[4])), exist_ok=True)

    # Collect all jpg/png images.
    all_imgs = [os.path.join(input_path, i) for i in os.listdir(input_path)
                if i.endswith(('.jpg', '.png'))]
    for img_path in all_imgs:
        img_base_name = os.path.splitext(os.path.basename(img_path))[0]
        img = cv2.imread(img_path)
        if img is None:  # unreadable file: skip instead of crashing on .shape
            print(f"跳过无法读取的图片: {img_path}")
            continue
        img_h, img_w, _ = img.shape
        # First black out the regions excluded from training.
        for g in grids_info:
            if g[4] < 0:
                img[g[1]:min(g[1] + g[3], img_h), g[0]:min(g[0] + g[2], img_w), :] = 0
        # Then save positively labelled regions and black them out too.
        for g in grids_info:
            if g[4] > 0:
                # Patch name: <image>_<left-x>_<top-y>_<class>.jpg
                patch_name = f'{img_base_name}_{g[0]}_{g[1]}_{g[4]}.jpg'
                patch = img[g[1]:min(g[1] + g[3], img_h), g[0]:min(g[0] + g[2], img_w), :]
                cv2.imwrite(os.path.join(output_path_root, str(g[4]), patch_name), patch)
                img[g[1]:min(g[1] + g[3], img_h), g[0]:min(g[0] + g[2], img_w), :] = 0
        # Finally tile the remainder into class-0 patches. Ceil division
        # keeps partial edge strips without the empty extra row/column the
        # old "// + 1" bound produced when the size divided evenly.
        for i in range((img_h + patch_h - 1) // patch_h):
            for j in range((img_w + patch_w - 1) // patch_w):
                patch = img[i * patch_h:min(i * patch_h + patch_h, img_h),
                            j * patch_w:min(j * patch_w + patch_w, img_w), :]
                patch_name = f'{img_base_name}_{j * patch_w}_{i * patch_h}_0.jpg'
                # Drop overly thin edge strips (aspect-ratio filter).
                ratio = patch.shape[0] / (patch.shape[1] + 1e-6)
                if ratio > 1.314 or ratio < 0.75:
                    continue
                # Drop (almost) all-black patches, including blacked-out areas.
                if np.mean(patch) < 10.10:
                    continue
                cv2.imwrite(os.path.join(output_path_root, '0', patch_name), patch)
        print(f"处理图片: {img_path}完成")
+
if __name__ == '__main__':
    # Script entry point.
    main()

+ 226 - 0
labelme/fixed_label.py

@@ -0,0 +1,226 @@
+# 要求保证视频不能移动,且全过程没有任何遮挡物
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+from utils import draw_grid
+
+import cv2
+import numpy as np
+import tkinter as tk
+from tkinter import simpledialog
+from dotenv import load_dotenv
+load_dotenv()  # 加载环境变量
+
class DrawTool:
    """Re-renders the annotation image from scratch from the recorded points."""

    def __init__(self, patch_w, patch_h, scale):
        # Path of the image being annotated; filled in later via set_path().
        self.img_path = None

        self.patch_w = patch_w
        self.patch_h = patch_h
        self.scale = scale

    def draw_new_img(self, points):
        """Reload the base image, overlay the grid and every annotated patch."""
        canvas = cv2.imread(self.img_path)
        canvas = cv2.resize(canvas, (canvas.shape[1] // self.scale, canvas.shape[0] // self.scale))
        draw_grid(canvas, self.patch_w // self.scale, self.patch_h // self.scale)
        # Half patch size in display coordinates (offset from corner to centre).
        half_w = self.patch_w // (2 * self.scale)
        half_h = self.patch_h // (2 * self.scale)
        for point in points:
            corner = (point[0], point[1])
            center = (point[0] + half_w, point[1] + half_h)
            # Corner marker (blue), centre marker (red), then the class label.
            cv2.circle(canvas, corner, 5, (255, 0, 0), -1)
            cv2.circle(canvas, center, 5, (0, 0, 255), -1)
            cv2.putText(canvas, str(point[4]), (center[0] + 10, center[1] + 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
        return canvas

    def set_path(self, img_path):
        """Remember which image file to redraw."""
        self.img_path = img_path
# Width/height of the patches the image is split into (from .env, default 256).
patch_w = int(os.getenv('PATCH_WIDTH', 256))
patch_h = int(os.getenv('PATCH_HEIGHT', 256))
# Display downscale factor (the full-resolution frame is too large to show).
scale = 2
draw_tool = DrawTool(patch_w=patch_w, patch_h=patch_h, scale=scale)
# Recorded annotations: entries are [corner_x, corner_y, w, h, cls] in display coordinates.
clicked_points = []
+
def get_text_input(prompt):
    """Pop up a dialog and return the typed class id, or -1 for invalid input."""
    root = tk.Tk()
    root.withdraw()  # hide the main Tk window
    root.attributes('-topmost', True)  # keep the dialog on top
    answer = simpledialog.askstring(' ', prompt)
    root.destroy()
    # Cancelled dialog behaves like empty input; anything that is not a
    # plain non-negative integer maps to -1.
    text = "" if answer is None else answer
    return int(text) if text.strip().isdigit() else -1
def det_same_pos_point(x_cor, y_cor, points=None):
    """Check whether a point with the given grid corner was already annotated.

    Args:
        x_cor: corner x coordinate to look for.
        y_cor: corner y coordinate to look for.
        points: list of annotation entries ``[x, y, w, h, cls]``.  Defaults to
            the module-level ``clicked_points`` for backward compatibility.

    Returns:
        ``(index, True)`` of the first matching entry, or ``(-1, False)``.
    """
    if points is None:
        # Fall back to the module-level annotation list (original behavior).
        points = clicked_points
    for idx, point in enumerate(points):
        if point[0] == x_cor and point[1] == y_cor:
            return idx, True
    return -1, False
def mouse_callback(event, x, y, flags, param):
    """Mouse callback for the annotation window.

    Left click: snap the click to its grid cell, ask for a class label and
    record (or re-label) that cell.  Right click: undo the last annotation.
    ``param`` is the displayed BGR image and is drawn on in place.
    """
    global clicked_points
    global patch_w
    global patch_h
    global scale
    global draw_tool

    if event == cv2.EVENT_LBUTTONDOWN:  # left click
        # Patch size in display (scaled-down) coordinates.
        scale_patch_w = patch_w // scale
        scale_patch_h = patch_h // scale
        # Top-left corner of the grid cell that was clicked.
        circle_x_corner = (x // scale_patch_w)*scale_patch_w
        circle_y_corner = (y // scale_patch_h)*scale_patch_h
        # Centre of that grid cell.
        circle_x_center = circle_x_corner + scale_patch_w//2
        circle_y_center = circle_y_corner + scale_patch_h//2
        cv2.circle(param, (circle_x_center, circle_y_center), 5, (0, 0, 255), -1)
        cv2.circle(param, (circle_x_corner, circle_y_corner), 5, (255, 0, 0), -1)

        # Refresh the display before the blocking input dialog.
        cv2.imshow('img', param)

        cls = get_text_input('请输入类别:0.背景 1.浑浊 -1.不参与')
        # Draw the label text next to the cell centre.
        cv2.putText(param, str(cls), (circle_x_center + 10, circle_y_center + 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
        # Refresh the display.
        cv2.imshow('img', param)
        valid_cls = [0, 1, -1]
        if cls in valid_cls:
            print(f"点击网格角点: ({circle_x_corner}, {circle_y_corner}) 中心点: ({circle_x_center}, {circle_y_center}) 类别:{cls}")
            # Record annotation entry: corner u v w h cls.
            pos, is_exist = det_same_pos_point(circle_x_corner, circle_y_corner)  # duplicate click?
            if is_exist:
                print(f"已存在该点: ({clicked_points[pos][0]}, {clicked_points[pos][1]}) 类别: {clicked_points[pos][4]}")
                clicked_points[pos][4] = cls
                print(f'重新标注该点: ({clicked_points[pos][0]}, {clicked_points[pos][1]}) 类别: {clicked_points[pos][4]}')
            else:
                print(f"添加点: ({circle_x_corner}, {circle_y_corner}) 类别: {cls}")
                # BUGFIX: append only brand-new points.  Previously the append ran
                # unconditionally, so re-labelling an existing cell also appended
                # a duplicate entry (which then confused the right-click undo).
                clicked_points.append([circle_x_corner, circle_y_corner, scale_patch_w, scale_patch_h, cls])
        else:
            print("请输入正确的类别!")
    elif event == cv2.EVENT_RBUTTONDOWN:  # right click: undo last annotation
        # BUGFIX: popping from an empty list raised IndexError.
        if not clicked_points:
            print("没有可撤销的标注点!")
            return
        removed_point = clicked_points.pop()
        print(f"撤销标注点: ({removed_point[0]}, {removed_point[1]}) 类别: {removed_point[4]}")
        # Repaint the undone point's markers in gray.
        x = removed_point[0]
        y = removed_point[1]
        # Patch size in display coordinates.
        scale_patch_w = patch_w // scale
        scale_patch_h = patch_h // scale
        # Grid-cell corner.
        circle_x_corner = (x // scale_patch_w)*scale_patch_w
        circle_y_corner = (y // scale_patch_h)*scale_patch_h
        # Grid-cell centre.
        circle_x_center = circle_x_corner + scale_patch_w//2
        circle_y_center = circle_y_corner + scale_patch_h//2
        cv2.circle(param, (circle_x_center, circle_y_center), 5, (128, 128, 128), -1)
        cv2.circle(param, (circle_x_corner, circle_y_corner), 5, (128, 128, 128), -1)
        # Refresh the display.
        cv2.imshow('img', param)
+
+
def remove_duplicates(arr: list):
    """Return a new list with duplicates removed, keeping first-seen order.

    Items may be unhashable (annotation entries are lists), so membership is
    checked with ``in`` rather than via a set.
    """
    unique_list = []
    # The original used a list comprehension purely for its side effect;
    # a plain loop states the intent directly.
    for item in arr:
        if item not in unique_list:
            unique_list.append(item)
    return unique_list
+
def play_video(video_path):
    """Preview the extracted frame sequence before annotation.

    NOTE(review): despite the name, this iterates the image files in the
    directory containing *video_path* (the frames extracted beforehand),
    not the video file itself.
    """
    global scale
    dir_name = os.path.dirname(video_path)
    for i in os.listdir(dir_name):
        frame = cv2.imread(os.path.join(dir_name, i))
        if frame is None:  # skip unreadable / non-image files
            continue
        # Downscale for display.
        frame = cv2.resize(frame, (frame.shape[1] // scale, frame.shape[0] // scale))
        cv2.imshow('Video', frame)
        # ESC aborts the preview.
        # NOTE(review): if the loop finishes without ESC, the 'Video' window is left open.
        if cv2.waitKey(20) == 27:
            cv2.destroyAllWindows()
            break
+
def main():
    """
    Fixed-camera annotation: label a single image; later frames reuse the labels.
    1. The image is divided into patches and grid lines are drawn with cv2.
    2. Interact with the mouse: click a patch and type its label; press ESC to
       finish the interaction and save the labels.
    3. Label format: u,v,w,h,label where u,v is the patch top-left corner,
       w,h the patch width/height and label its class.
    """
    global clicked_points
    global patch_w
    global patch_h
    global scale
    # TODO: change to the image to annotate; uses 000000.jpg in the target
    # directory, results are written to label.txt next to it.
    img_path = r'/frame_data/test/video4_20251129120320_20251129123514\000000.jpg'
    play_video(img_path)
    img = cv2.imread(img_path)
    draw_tool.set_path(img_path)
    # Downscale: the full-resolution image is too large to display.
    img = cv2.resize(img, (img.shape[1] // scale, img.shape[0] // scale))
    # Overlay the grid lines.
    draw_grid(img, patch_w // scale, patch_h // scale)
    # Interactive annotation.
    print("操作说明:")
    print("- 点击鼠标左键在图像上添加红色标记点: 0.其他 1.浑浊 -1.忽略,不参与训练和测试")
    print("- 点击鼠标右键撤回上一个红色标记点")
    print("- 按 'c' 键清除所有标记点")
    print("- 按 ESC 键退出程序")
    cv2.namedWindow('img')
    cv2.setMouseCallback('img', mouse_callback, img)
    # Interactive annotation loop.
    while True:
        # Redraw the display from the recorded points on every iteration.
        cv2.imshow('img', draw_tool.draw_new_img(clicked_points))
        key = cv2.waitKey(1) & 0xFF

        # 'c' clears all annotations and resets the canvas.
        if key == ord('c'):
            img = cv2.imread(img_path)
            img = cv2.resize(img, (img.shape[1] // scale, img.shape[0] // scale))
            draw_grid(img, patch_w // scale, patch_h // scale)
            clicked_points.clear()
            cv2.setMouseCallback('img', mouse_callback, img)
            print("已清除所有标记点")

        # ESC finishes annotation.
        elif key == 27:  # ESC key
            break
    cv2.destroyAllWindows()
    # Report all clicked positions.
    # De-duplicate the list first.
    clicked_points = remove_duplicates(clicked_points)

    print(f"总共标记了 {len(clicked_points)} 个点:")
    for i, point in enumerate(clicked_points):
        print(f"  点 {i + 1}: ({point[0]}, {point[1]}, {point[2]}, {point[3]}, {point[4]})")
    # Scale coordinates back up to the original image size.
    clicked_points = [[p[0]*scale, p[1]*scale, p[2]*scale, p[3]*scale, p[4]] for p in clicked_points]
    # Write label.txt.
    if clicked_points:
        with open(os.path.join(os.path.dirname(img_path), 'label.txt'), 'w') as fw:
            for point in clicked_points:
                fw.write(f"{point[0]},{point[1]},{point[2]},{point[3]},{point[4]}\n")
        # Report the saved points.
        print(f"保存标记点 {len(clicked_points)} 个:")
        for i, point in enumerate(clicked_points):
            print(f"  点 {i + 1}: ({point[0]}, {point[1]}, {point[2]}, {point[3]}, {point[4]})")
    else :
        print("没有标记点!不保存任何文件!")



if __name__ == '__main__':
    main()

+ 25 - 0
labelme/random_del.py

@@ -0,0 +1,25 @@
+# 按照比例随机删除某个路径下的图像
+import os
+import random
+
def main():
    """Randomly delete a fixed fraction of the images in one directory."""
    # TODO: point this at the image directory to prune
    path = r'D:\code\water_turbidity_det\label_data\test\0'
    del_rate = 0.3
    img_path = [name for name in os.listdir(path) if name.split('.')[-1] in ('jpg', 'png')]
    random.shuffle(img_path)

    # After shuffling, the first del_rate fraction is the deletion set.
    n_delete = int(len(img_path) * del_rate)
    del_list = img_path[:n_delete]

    for victim in del_list:
        target_path = os.path.join(path, victim)
        if os.path.isfile(target_path):
            os.remove(target_path)
            print("文件删除成功。", target_path)

    print(f"文件数量: {len(img_path)}")
    print(f"删除比例: {del_rate}")
    print(f"删除数量: {len(del_list)}")
    print(f'剩余数量: {len(img_path)-len(del_list)}')


if __name__ == '__main__':
    main()

+ 28 - 0
labelme/statistic.py

@@ -0,0 +1,28 @@
+# 统计标注好的数据,同时给出统计结果保存为txt
+import os
def count_imgs(path: str, tag: str) -> str:
    """Summarize how many images each class sub-directory of ``path/tag`` holds.

    Returns a human-readable line such as ``"train data statistics: {'0': 12}"``.
    """
    target_path = os.path.join(path, tag)
    # One entry per class directory: class name -> number of files inside.
    sta_res = {
        cls_dir: len(os.listdir(os.path.join(target_path, cls_dir)))
        for cls_dir in os.listdir(target_path)
    }
    return f'{tag} data statistics: {sta_res}'
def main():
    """Collect per-split image statistics and write them to statistic.txt."""
    train_data_path = r'D:\code\water_turbidity_det\label_data'
    dirs = os.listdir(train_data_path)
    # Only summarize the splits that actually exist, in a fixed order.
    info = [count_imgs(train_data_path, tag)
            for tag in ('train', 'test', 'val') if tag in dirs]
    with open(os.path.join(train_data_path, 'statistic.txt'), 'w') as fw:
        for line in info:
            fw.write(line)
            fw.write('\n')
if __name__ == '__main__':
    main()

+ 26 - 0
labelme/utils.py

@@ -0,0 +1,26 @@
+import numpy as np
+import cv2
+
def draw_grid(img: np.ndarray, grid_w: int, grid_h: int):
    """Overlay green grid lines of the given cell size on ``img`` in place."""
    img_h, img_w, _ = img.shape
    # Horizontal lines.
    for row in range((img_h // grid_h) + 1):
        y = row * grid_h
        cv2.line(img, (0, y), (img_w, y), (0, 255, 0), 2)
    # Vertical lines.
    for col in range((img_w // grid_w) + 1):
        x = col * grid_w
        cv2.line(img, (x, 0), (x, img_h), (0, 255, 0), 2)
    return img
+
def draw_predict_grid(img, patches_index, predicted_class, confidence):
    """Annotate ``img`` with the predicted class and confidence of each patch."""
    for i, (idx_w, idx_h) in enumerate(patches_index):
        cv2.circle(img, (idx_w, idx_h), 10, (0, 255, 0), -1)
        # Red text for positive (non-zero) predictions, blue otherwise.
        color = (0, 0, 255) if predicted_class[i] else (255, 0, 0)
        cls_text = f'cls:{predicted_class[i]}'
        prob_text = f'prob:{confidence[i] * 100:.1f}%'
        cv2.putText(img, cls_text, (idx_w, idx_h + 128),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 2)
        cv2.putText(img, prob_text, (idx_w, idx_h + 128 + 25),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 2)
    return img

+ 44 - 0
labelme/video_depart.py

@@ -0,0 +1,44 @@
+# 将视频图像分离为单帧序列
+import cv2
+import os
+import shutil
def main():
    """Extract every ``interval``-th frame of one video into a per-video folder."""
    # TODO: set to your own video path; one video per run
    path = r'D:\code\water_turbidity_det\video\video20251225\4_video_20251223163145.mp4'
    output_rootpath = r'D:\code\water_turbidity_det\frame_data'  # root of the output tree
    # Frame-sampling interval.
    interval = 15
    # Frames go into a sub-directory of the root named after the video.
    img_base = os.path.basename(path).split('.')[0]
    imgs_output_path = os.path.join(output_rootpath, img_base)  # video name must not contain extra '.' characters
    # Remove the results of any previous extraction, then recreate the directory.
    if os.path.exists(imgs_output_path):
        shutil.rmtree(imgs_output_path)
    os.makedirs(imgs_output_path)

    # Open the video file.
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        raise Exception('错误:无法打开视频文件。')
    frame_count = 0
    save_count = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Save only every `interval`-th frame.
        if frame_count % interval == 0:
            # BUGFIX: the save path was computed (and printed as "saved") for
            # every frame, including the skipped ones; now both happen only
            # when a frame is actually written.
            img_save_path = os.path.join(imgs_output_path, f"{frame_count:06d}.jpg")
            cv2.imwrite(img_save_path, frame)
            save_count += 1
            print(f"已处理{frame_count}帧, 保存至{img_save_path}")
        frame_count += 1
    cap.release()  # BUGFIX: release the capture handle when done
    print(f"处理完成,共处理{frame_count}帧, 保存{save_count}帧")
if __name__ == '__main__':
    main()
+

+ 4 - 0
main.py

@@ -0,0 +1,4 @@
+import torch
+from torchvision.models import resnet50, ResNet50_Weights
+
+

+ 95 - 0
play.py

@@ -0,0 +1,95 @@
+import cv2
+import numpy as np
+from calculator import region_ndti, colorful_ndti_matrix
+import threading
+
+
def main(video_path, video_mask_path, video_id):
    """Play a video, compute NDTI over the masked ROI and display it colorized.

    Args:
        video_path: path of the video to play.
        video_mask_path: grayscale mask image; pixels > 0 are the ROI.
        video_id: identifier used to name the display windows.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("错误:无法打开视频文件。")
        print("可能的原因:文件路径错误,或系统缺少必要的解码器。")
        # BUGFIX: exit() raised SystemExit, killing the whole process when run
        # directly; returning lets sibling threads keep running.
        return
    scale = 2
    mask = cv2.imread(video_mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise RuntimeError('mask open failed')
    mask = cv2.resize(mask, (mask.shape[1] // scale, mask.shape[0] // scale))

    fps = cap.get(cv2.CAP_PROP_FPS)
    print("fps:", fps)
    while True:

        ret, frame = cap.read()  # ret: read succeeded; frame: image data
        if not ret:
            print("视频播放完毕。")
            break
        # Night mode: the B and G channels are nearly identical (IR image).
        # BUGFIX: cast to a signed type first - subtracting uint8 arrays wraps
        # around, which made the mean difference almost never fall below 0.1.
        bg_diff = np.abs(frame[:, :, 0].astype(np.int16) - frame[:, :, 1].astype(np.int16))
        is_night = np.mean(bg_diff) < 0.1
        frame = cv2.resize(frame, (frame.shape[1] // scale, frame.shape[0] // scale))

        # NDTI over the masked region.
        roi_ndti = region_ndti(frame, mask, is_night)

        # Colorize the NDTI map for display.
        color_ndti = colorful_ndti_matrix(roi_ndti)

        # Keep only the pixels inside the region of interest.
        roi_index = mask > 0
        roi_ndti_fla = roi_ndti[roi_index]

        # BUGFIX: nesting single quotes inside a single-quoted f-string is a
        # syntax error before Python 3.12; bind the tag first.
        day_night = 'night' if is_night else 'day'
        text = f'pixs:{int(np.sum(roi_ndti_fla))} mean:{roi_ndti_fla.mean():.3f} std:{roi_ndti_fla.std():.3f} {day_night}'
        cv2.putText(frame, text, (10, 25), cv2.FONT_HERSHEY_DUPLEX, 0.6, (0, 255, 0), 2, cv2.LINE_AA)
        cv2.imshow(f'original_{video_id}', frame)
        cv2.imshow(f'ndti_color_{video_id}', color_ndti)
        cv2.waitKey(30)

        # Exit on 'q'.
        key = cv2.waitKey(10) & 0xFF
        if key == ord('q'):
            break
    cap.release()  # BUGFIX: release the capture handle when done
+
if __name__ == '__main__':
    # Each main() call plays one video in its own thread; alternate
    # video/mask combinations are kept below, commented out, for quick swapping.
    # main(video_path=r'D:\code\water_turbidity_det\frame_data\night\4_video_20251120_1800.dav',
    #      video_mask_path=r'D:\code\water_turbidity_det\draw_mask\mask\4_device_capture.png',
    #      video_id=1)

    t1 = threading.Thread(target=main,
                     args=(r'D:\code\water_turbidity_det\data\records_202512031553\video4_20251129120320_20251129123514.mp4',
                           r'D:\code\water_turbidity_det\draw_mask\mask\4_device_capture.png',
                           4))

    # t2 = threading.Thread(target=main,
    #                  args=(r'D:\code\water_turbidity_det\frame_data\day_202511211129\4_video_202511211127.dav',
    #                        r'D:\code\water_turbidity_det\draw_mask\mask\4_device_capture.png',
    #                        4))

    # t1 = threading.Thread(target=main,
    #                  args=(r'D:\code\water_turbidity_det\frame_data\records_202512031553\video1_20251129055729_20251129063102.mp4',
    #                        r'D:\code\water_turbidity_det\draw_mask\mask\1_device_capture.png',
    #                        1))
    #
    # t2 = threading.Thread(target=main,
    #                  args=(r'D:\code\water_turbidity_det\frame_data\records_202512031553\video4_20251129053218_20251129060528.mp4',
    #                        r'D:\code\water_turbidity_det\draw_mask\mask\4_device_capture.png',
    #                        4))
    t1.start()
    # t2.start()

    # Wait for playback to finish before tearing down the windows.
    t1.join()
    # t2.join()
    cv2.destroyAllWindows()

+ 47 - 0
pth2onnx.py

@@ -0,0 +1,47 @@
+import torch
+import onnx
+import torch.onnx
+import onnxruntime as ort
+import numpy as np
+from torchvision.models import resnet50, shufflenet_v2_x1_0, shufflenet_v2_x2_0
+from torch import nn
+from simple_model import  SimpleModel
if __name__ == '__main__':

    # Build the model skeleton (alternatives kept for quick swapping).
    # model = SimpleModel()
    # model = resnet50(pretrained=True)

    model = shufflenet_v2_x1_0()
    model.fc = nn.Linear(int(model.fc.in_features), 2, bias=False)
    model.load_state_dict(torch.load(r'./shufflenet.pth')) # load the trained weights from the .pth file
    print("加载模型成功")
    model.eval() # switch to inference mode

    example_input = torch.randn(1, 3, 256, 256)          # [1,3,256,256] corresponds to [B,C,H,W]
   # print(model)
    torch.onnx.export(model,
                      example_input,
                      "shufflenet.onnx",
                      opset_version=13,
                      export_params=True,
                      do_constant_folding=True,
                      )  # export to ONNX; no operator fusion is performed here

    # Validate the exported model.

    onnx_model = onnx.load("shufflenet.onnx")  # separate variable from the torch model
    onnx.checker.check_model(onnx_model)  # verify model integrity
    # Run inference with ONNX Runtime.
    ort_session = ort.InferenceSession("shufflenet.onnx")
    ort_inputs = {ort_session.get_inputs()[0].name: example_input.detach().numpy()}
    ort_outs = ort_session.run(None, ort_inputs)

    # Compare against the original PyTorch output.
    with torch.no_grad():
        torch_out = model(example_input)

    # Report the maximum absolute difference.
    print("输出差异最大为:", np.max(np.abs(torch_out.numpy() - ort_outs[0])))
    # ImageNet normalization constants scaled to 0-255 for the MLIR toolchain:
    #mean_mlir = [0.485×255, 0.456×255, 0.406×255] = [123.675, 116.28, 103.53]
    #scale_mlir = [0.229*255, 0.224*255, 0.225*255] = [58.395, 57.12, 57.375]

+ 14 - 0
readme.md

@@ -0,0 +1,14 @@
+初始化
+
+
+浊度检测深度学习方案:
+一、样本标注
+我们首先将一张图片划分为多个小块,每个小块都对应一个标签。标签类别如下:
+0:表示背景;1:表示清澈水体;2:表示浑浊水体
+
+1.工具开发
+添加图片标注工具,对原始图像进行标注,生成对应的标签文件,存储格式为:
+文件名称:与图片名称一致,后缀为txt;
+文件内容:对应图片的标签信息,每行表示一个块;
+文件格式:u,v,w,h,label u,v为块左上角坐标,w,h为块的宽和高,label为块的标签
+添加标注信息展示工具,将图片和标签信息进行展示,方便用户查看标注信息

+ 265 - 0
roi_creator_mp.py

@@ -0,0 +1,265 @@
+# -*- coding: utf-8 -*-
+"""
+代码一:定义数据并保存(支持多个ROI)
+
+功能:
+1. 加载图一。
+2. 允许用户交互式地绘制多个多边形 ROI,并为每个ROI命名。
+3. 允许用户交互式地绘制多个矩形屏蔽区。
+4. 将所有ROI坐标、名称、屏蔽区坐标和相关参数打包保存到一个 JSON 文件中。
+"""
+import cv2
+import numpy as np
+import json
+import os
+import base64
+
+# ========== 全局变量 ==========
+# --- ROI 绘制相关 ---
+current_roi_points = []
+all_rois = []  # 存储所有ROI及其名称
+drawing_complete = False
+drawing_image = None
+original_image = None
+# --- 屏蔽区绘制相关 ---
+ignore_polygons = []
+drawing_ignore = None
+drawing_rect = False
+rect_start_point = None
+
+# ========== 可配置参数 ==========
+CONFIG_PARAMS = {
+    "MARGIN_RATIO": 0.25,
+    "MIN_MATCH_COUNT": 10,
+    "USE_CLAHE": True
+}
+
+
+# -------------------------------------------------
+# 1. 回调函数:选择 ROI
+# -------------------------------------------------
+def select_roi(event, x, y, flags, param):
+    """用于绘制 ROI 多边形的回调函数"""
+    global current_roi_points, drawing_complete, drawing_image
+
+    if drawing_complete:
+        return
+
+    if event == cv2.EVENT_LBUTTONDOWN:
+        current_roi_points.append((x, y))
+        cv2.circle(drawing_image, (x, y), 4, (0, 255, 0), -1)
+        if len(current_roi_points) > 1:
+            cv2.line(drawing_image, current_roi_points[-2], current_roi_points[-1], (0, 255, 0), 2)
+        cv2.imshow("Image 1 - Select ROI", drawing_image)
+
+    elif event == cv2.EVENT_RBUTTONDOWN:
+        if len(current_roi_points) < 3:
+            print("[提示] ROI 至少需要 3 个点构成多边形!")
+            return
+        drawing_complete = True
+        cv2.line(drawing_image, current_roi_points[-1], current_roi_points[0], (0, 255, 0), 2)
+        cv2.imshow("Image 1 - Select ROI", drawing_image)
+        print("[提示] 当前ROI选择完毕,按任意键继续...")
+
+
+# -------------------------------------------------
+# 2. 回调函数:选择屏蔽区
+# -------------------------------------------------
+def select_ignore_rect(event, x, y, flags, param):
+    """用于绘制屏蔽矩形的回调函数"""
+    global rect_start_point, drawing_rect, ignore_polygons, drawing_ignore
+
+    if event == cv2.EVENT_LBUTTONDOWN:
+        rect_start_point = (x, y)
+        drawing_rect = True
+
+    elif event == cv2.EVENT_MOUSEMOVE:
+        if drawing_rect:
+            preview_img = drawing_ignore.copy()
+            cv2.rectangle(preview_img, rect_start_point, (x, y), (0, 0, 255), 2)
+            cv2.imshow("Image 1 - Ignore Area", preview_img)
+
+    elif event == cv2.EVENT_LBUTTONUP:
+        if not drawing_rect:
+            return
+        drawing_rect = False
+        x_start, y_start = rect_start_point
+        x1, y1 = min(x_start, x), min(y_start, y)
+        x2, y2 = max(x_start, x), max(y_start, y)
+
+        if x2 - x1 > 5 and y2 - y1 > 5:
+            cv2.rectangle(drawing_ignore, (x1, y1), (x2, y2), (0, 0, 255), 2)
+            rect_as_poly = [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
+            ignore_polygons.append(rect_as_poly)
+            print(f"[提示] 已新增屏蔽区 #{len(ignore_polygons)},可继续拉框,完成后按 Enter 键。")
+        cv2.imshow("Image 1 - Ignore Area", drawing_ignore)
+
+
+# -------------------------------------------------
+# 3. 工具函数
+# -------------------------------------------------
+def read_image_chinese(path):
+    """读取可能包含中文路径的图片"""
+    try:
+        with open(path, 'rb') as f:
+            data = np.frombuffer(f.read(), np.uint8)
+        return cv2.imdecode(data, cv2.IMREAD_COLOR)
+    except Exception as e:
+        print(f"读取图片失败: {e}")
+        return None
+
+
+def draw_all_rois(image, rois_list):
+    """在图像上绘制所有已保存的ROI"""
+    result = image.copy()
+    colors = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0),
+              (255, 0, 255), (0, 255, 255), (128, 255, 0), (255, 128, 0)]
+
+    for idx, roi_data in enumerate(rois_list):
+        color = colors[idx % len(colors)]
+        points = roi_data["points"]
+        name = roi_data["name"]
+
+        # 绘制多边形
+        cv2.polylines(result, [np.array(points, np.int32)], True, color, 2)
+
+        # 计算中心点并显示名称
+        center = np.mean(points, axis=0).astype(int)
+        cv2.putText(result, name, tuple(center), cv2.FONT_HERSHEY_SIMPLEX,
+                    0.7, color, 2, cv2.LINE_AA)
+
+    return result
+
+
+# -------------------------------------------------
+# 4. 主函数
+# -------------------------------------------------
+def define_and_save_data(image1_path, output_json_path):
+    """
+    主流程函数:加载图像,引导用户选择多个ROI并命名,最终保存数据。
+    """
+    global current_roi_points, drawing_complete, drawing_image, all_rois
+    global drawing_ignore, ignore_polygons, original_image
+
+    # --- 步骤 1: 加载图像 ---
+    img1 = read_image_chinese(image1_path)
+    if img1 is None:
+        print(f"错误:无法加载图像 '{image1_path}',请检查路径!")
+        return
+
+    original_image = img1.copy()
+
+    # --- 步骤 2: 循环选择多个ROI ---
+    print("\n========== 开始选择ROI ==========")
+    roi_count = 0
+
+    while True:
+        roi_count += 1
+        current_roi_points = []
+        drawing_complete = False
+
+        # 显示已有的ROI
+        if all_rois:
+            drawing_image = draw_all_rois(original_image, all_rois)
+        else:
+            drawing_image = original_image.copy()
+
+        cv2.namedWindow("Image 1 - Select ROI")
+        cv2.setMouseCallback("Image 1 - Select ROI", select_roi)
+
+        print(f"\n【ROI #{roi_count}】")
+        print("用鼠标左键多点圈出ROI,右键闭合完成当前ROI。")
+        print("按 ESC 键结束所有ROI选择,按其他键继续添加下一个ROI。")
+        cv2.imshow("Image 1 - Select ROI", drawing_image)
+
+        key = cv2.waitKey(0)
+
+        if key == 27:  # ESC键,结束ROI选择
+            cv2.destroyWindow("Image 1 - Select ROI")
+            if len(current_roi_points) >= 3:
+                # 保存最后一个ROI
+                roi_name = input(f"请输入ROI #{roi_count} 的名称: ").strip()
+                if not roi_name:
+                    roi_name = f"ROI_{roi_count}"
+                all_rois.append({
+                    "name": roi_name,
+                    "points": current_roi_points
+                })
+                print(f"已保存ROI: {roi_name}")
+            break
+        else:
+            if len(current_roi_points) >= 3:
+                # 为当前ROI输入名称
+                roi_name = input(f"请输入ROI #{roi_count} 的名称: ").strip()
+                if not roi_name:
+                    roi_name = f"ROI_{roi_count}"
+
+                all_rois.append({
+                    "name": roi_name,
+                    "points": current_roi_points
+                })
+                print(f"已保存ROI: {roi_name}")
+            else:
+                print("当前ROI点数不足,跳过保存。")
+                roi_count -= 1
+
+    if not all_rois:
+        print("错误:未选择任何ROI,程序退出。")
+        return
+
+    print(f"\n总共选择了 {len(all_rois)} 个ROI")
+
+    # --- 步骤 3: 交互式选择屏蔽区 ---
+    drawing_ignore = draw_all_rois(original_image, all_rois)
+    cv2.namedWindow("Image 1 - Ignore Area")
+    cv2.setMouseCallback("Image 1 - Ignore Area", select_ignore_rect)
+    print("\n【步骤 2】在弹窗中,用鼠标左键拖拽拉框来选择屏蔽区,可画多个。")
+    print("         所有屏蔽区画完后,按 Enter 键保存并结束,或按 Esc 键退出。")
+    cv2.imshow("Image 1 - Ignore Area", drawing_ignore)
+
+    while True:
+        key = cv2.waitKey(0)
+        if key in [13, 10]:  # Enter键
+            break
+        elif key == 27:  # Esc
+            cv2.destroyAllWindows()
+            print("用户取消操作。")
+            return
+    cv2.destroyAllWindows()
+
+    # --- 步骤 4: 封装数据并保存到 JSON ---
+    try:
+        with open(image1_path, 'rb') as f:
+            image_bytes = f.read()
+        image_base64 = base64.b64encode(image_bytes).decode('utf-8')
+        image_filename = os.path.basename(image1_path)
+    except Exception as e:
+        print(f"\n[失败] 读取并编码图像文件时出错: {e}")
+        return
+
+    # 封装数据
+    data_to_save = {
+        "image1_filename": image_filename,
+        "image1_data_base64": image_base64,
+        "rois": all_rois,  # 保存所有ROI及其名称
+        "ignore_polygons": ignore_polygons,
+        "parameters": CONFIG_PARAMS
+    }
+
+    try:
+        with open(output_json_path, 'w', encoding='utf-8') as f:
+            json.dump(data_to_save, f, indent=4, ensure_ascii=False)
+        print(f"\n[成功] {len(all_rois)} 个ROI、屏蔽区和图像数据已成功保存到: {output_json_path}")
+    except Exception as e:
+        print(f"\n[失败] 保存 JSON 文件时出错: {e}")
+
+
+# -------------------------------------------------
+if __name__ == "__main__":
+    IMG1_PATH = r"20250615000002.jpg"
+    OUTPUT_JSON = "roi_json_dir/matching_data_multi_roi.json"
+
+    # 创建输出目录
+    os.makedirs(os.path.dirname(OUTPUT_JSON), exist_ok=True)
+
+    define_and_save_data(IMG1_PATH, OUTPUT_JSON)

+ 12 - 0
run.bash

@@ -0,0 +1,12 @@
# Train several model architectures back to back, sleeping 60s between runs
# so the previous job can fully release the GPU before the next one starts.
#python train.py --model swin_v2_b
#sleep 60
#python train.py --model swin_v2_s
#sleep 60
python train.py --model squeezenet
sleep 60
python train.py --model squeezenet-x2
sleep 60
python train.py --model shufflenet
sleep 60
python train.py --model resnet50
+

+ 20 - 0
simple_model.py

@@ -0,0 +1,20 @@
+#!/usr/bin/env python3
+import torch
+
+# Build a simple nn model
class SimpleModel(torch.nn.Module):
    """Two stacked 3x3 convolutions whose outputs are summed.

    The first conv (no padding) shrinks H and W by 2; the second (padding 1)
    keeps that size, so the two feature maps can be added element-wise.
    """

    def __init__(self):
        super().__init__()
        self.m1 = torch.nn.Conv2d(3, 8, 3, 1, 0)
        self.m2 = torch.nn.Conv2d(8, 8, 3, 1, 1)

    def forward(self, x):
        first = self.m1(x)
        second = self.m2(first)
        return first + second
+
# Create a SimpleModel and save its weight in the current directory
# NOTE(review): this runs on import, so merely importing this module
# writes simple.pth with freshly initialized (random) weights.
model = SimpleModel()
torch.save(model.state_dict(), "simple.pth")

+ 280 - 0
test.py

@@ -0,0 +1,280 @@
+import time
+
+import torch
+import torch.nn as nn
+from torchvision import transforms
+from torchvision.models import resnet18,resnet50, squeezenet1_0, shufflenet_v2_x1_0
+import numpy as np
+from PIL import Image
+import os
+import argparse
+from labelme.utils import draw_grid, draw_predict_grid
+import cv2
+import matplotlib.pyplot as plt
+from dotenv import load_dotenv
+load_dotenv()  # pull PATCH_* / CONFIDENCE_THRESHOLD settings from the .env file
+# os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
+# Tile size used to crop each frame into classifier patches.
+patch_w = int(os.getenv('PATCH_WIDTH', 256))
+patch_h = int(os.getenv('PATCH_HEIGHT', 256))
+# First-stage false-alarm threshold: positives below it become background.
+# NOTE(review): fallback 0.80 differs from the 0.6 shipped in .env — confirm
+# which default is intended.
+confidence_threshold = float(os.getenv('CONFIDENCE_THRESHOLD', 0.80))
+scale = 2  # NOTE(review): appears unused — main() resizes with a literal // 2
+
+
+class Predictor:
+    def __init__(self, model_name, weights_path, num_classes):
+        self.model_name = model_name
+        self.weights_path = weights_path
+        self.num_classes = num_classes
+        self.model = None
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        print(f"当前设备: {self.device}")
+        # 加载模型
+        self.load_model()
+
+
+    def load_model(self):
+        if self.model is not None:
+            return
+        print(f"正在加载模型: {self.model_name}")
+        name = self.model_name
+        # 加载模型
+        if name == 'resnet50':
+
+            self.model = resnet50()
+        elif name == 'squeezenet':
+
+            self.model = squeezenet1_0()
+        elif name == 'shufflenet':
+            self.model = shufflenet_v2_x1_0()
+        else:
+            raise ValueError(f"Invalid model name: {name}")
+        # 替换最后的分类层以适应新的分类任务
+        if hasattr(self.model, 'fc'):
+            # ResNet系列模型
+            self.model.fc = nn.Linear(int(self.model.fc.in_features), self.num_classes, bias=False)
+        elif hasattr(self.model, 'classifier'):
+            # Swin Transformer等模型
+            self.model.classifier = nn.Linear(int(self.model.classifier.in_features), self.num_classes, bias=False)
+        elif hasattr(self.model, 'head'):
+            # Swin Transformer使用head层
+            self.model.head = nn.Linear(int(self.model.head.in_features), self.num_classes, bias=False)
+
+        else:
+            raise ValueError(f"Model {name} does not have recognizable classifier layer")
+        print(self.model)
+        # 加载训练好的权重
+        self.model.load_state_dict(torch.load(self.weights_path, map_location=torch.device('cpu')))
+        print(f"成功加载模型参数: {self.weights_path}")
+        # 将模型移动到GPU
+        self.model.eval()
+        self.model = self.model.to(self.device)
+        print(f"成功加载模型: {self.model_name}")
+
+    def predict(self, image_tensor):
+        """
+        对单张图像进行预测
+
+        Args:
+            image_tensor: 预处理后的图像张量
+
+        Returns:
+            predicted_class: 预测的类别索引
+            confidence: 预测置信度
+            probabilities: 各类别的概率
+        """
+
+        image_tensor = image_tensor.to(self.device)
+
+        with torch.no_grad():
+            outputs = self.model(image_tensor)
+            probabilities = torch.softmax(outputs, dim=1)  # 沿行计算softmax
+            confidence, predicted_class = torch.max(probabilities, 1)
+
+        return confidence.cpu().numpy(), predicted_class.cpu().numpy()
+
+
+def preprocess_image(img):
+    """
+    预处理图像以匹配训练时的预处理
+    
+    Args:
+        img: PIL图像
+        
+    Returns:
+        tensor: 预处理后的图像张量
+    """
+    # 定义与训练时相同的预处理步骤
+    transform = transforms.Compose([
+        transforms.Resize((224, 224)),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    ])
+
+    # 打开并转换图像
+
+    img_w, img_h = img.size
+    global patch_w, patch_h
+    imgs_patch = []
+    imgs_index = []
+    # fig, axs = plt.subplots(img_h // patch_h + 1, img_w // patch_w + 1)
+    for i in range(img_h // patch_h + 1):
+        for j in range(img_w // patch_w + 1):
+            left = j * patch_w  # 裁剪区域左边框距离图像左边的像素值
+            top = i * patch_h  # 裁剪区域上边框距离图像上边的像素值
+            right = min(j * patch_w + patch_w, img_w)  # 裁剪区域右边框距离图像左边的像素值
+            bottom = min(i * patch_h + patch_h, img_h)  # 裁剪区域下边框距离图像上边的像素值
+            # 检查区域是否有效
+            if right > left and bottom > top:
+                patch = img.crop((left, top, right, bottom))
+                # 长宽比过滤
+                # rate = patch.height / (patch.width + 1e-6)
+                # if rate > 1.314 or rate < 0.75:
+                #     # print(f"长宽比过滤: {patch_name}")
+                #     continue
+                imgs_patch.append(patch)
+                imgs_index.append((left, top))
+                # axs[i, j].imshow(patch)
+                # axs[i, j].set_title(f'Image {i} {j}')
+                # axs[i, j].axis('off')
+
+    # plt.tight_layout()
+    # plt.show()
+    imgs_patch = torch.stack([transform(img) for img in imgs_patch])
+    # 添加批次维度
+    # image_tensor = image_tensor.unsqueeze(0)
+    return imgs_index, imgs_patch
+
+
+# NOTE: helper for single-image inspection; not called by main() in this module.
+def visualize_prediction(image_path, predicted_class, confidence, class_names):
+    """
+    Visualize a single-image prediction with matplotlib.
+
+    Args:
+        image_path: path of the image to display
+        predicted_class: predicted class index (int)
+        confidence: prediction confidence (float)
+        class_names: list mapping class index -> display name
+    """
+    image = Image.open(image_path).convert('RGB')
+
+    plt.figure(figsize=(8, 6))
+    plt.imshow(image)
+    plt.axis('off')
+    plt.title(f'Predicted: {class_names[predicted_class]}\n'
+              f'Confidence: {confidence:.4f}', fontsize=14)
+    plt.show()
+
+def get_33_patch(arr:np.ndarray, center_row:int, center_col:int):
+    """Return the 3x3 neighbourhood of arr centred at (center_row, center_col).
+
+    The window is clipped at the array borders, so the returned view may be
+    smaller than 3x3 (e.g. 2x2 in a corner).
+    """
+    # Clamp the window to the array bounds (slice ends are exclusive).
+    h,w = arr.shape
+    safe_row_up_limit = max(0, center_row-1)
+    safe_row_bottom_limit = min(h, center_row+2)
+    safe_col_left_limit = max(0, center_col-1)
+    safe_col_right_limit = min(w, center_col+2)
+    return arr[safe_row_up_limit:safe_row_bottom_limit, safe_col_left_limit:safe_col_right_limit]
+
+
+def fileter_prediction(predicted_class, confidence, pre_rows, pre_cols, filter_down_limit=3):
+    """预测结果矩阵滤波,九宫格内部存在浑浊水体的数量需要大于filter_down_limit,"""
+    predicted_class_mat = np.resize(predicted_class, (pre_rows, pre_cols))
+    predicted_conf_mat = np.resize(confidence, (pre_rows, pre_cols))
+    new_predicted_class_mat = predicted_class_mat.copy()
+    new_predicted_conf_mat = predicted_conf_mat.copy()
+    for i in range(pre_rows):
+        for j in range(pre_cols):
+            if (1. - predicted_class_mat[i, j]) > 0.1:
+                continue  # 跳过背景类
+            core_region = get_33_patch(predicted_class_mat, i, j)
+            if np.sum(core_region) < filter_down_limit:
+                new_predicted_class_mat[i, j] = 0  #  重置为背景类
+                new_predicted_conf_mat[i, j] = 1.0
+    return new_predicted_conf_mat.flatten(), new_predicted_class_mat.flatten()
+
+def discriminate_ratio(water_pre_list:list):
+    """Ratio decision: a patch counts as muddy if positive in >=60% of frames.
+
+    Args:
+        water_pre_list: list of flat per-patch class arrays (0/1), one per frame.
+
+    Returns:
+        np.int32 count of patches positive in >=60% of frames — an int, not a
+        bool; callers rely on its truthiness.
+    """
+    # Method 1: muddy water present in 60%+ of the frames in the window.
+    water_pre_arr = np.array(water_pre_list, dtype=np.float32)
+    # Per-patch mean over frames == fraction of frames the patch was positive.
+    water_pre_arr_mean = np.mean(water_pre_arr, axis=0)
+    bad_water = np.array(water_pre_arr_mean >= 0.6, dtype=np.int32)
+    bad_flag = np.sum(bad_water, dtype=np.int32)
+    print(f'浑浊比例方式:该时间段是否存在浑浊水体:{bool(bad_flag)}')
+    return bad_flag
+
+
+def discriminate_cont(pre_class_arr, continuous_count_mat):
+    """Consecutive-frame decision: flag when any patch stays positive > 30 frames.
+
+    Mutates `continuous_count_mat` in place: counters reset where the current
+    frame is background and increment where it is positive.
+
+    Args:
+        pre_class_arr: flat per-patch class array for the current frame (0/1).
+        continuous_count_mat: flat int32 run-length counters, same length.
+
+    Returns:
+        numpy bool, True if any run length exceeds 30 frames.
+    """
+    positive_index = np.array(pre_class_arr,dtype=np.int32) > 0
+    negative_index = np.array(pre_class_arr,dtype=np.int32) == 0
+    # Reset counters where the patch is background in this frame
+    continuous_count_mat[negative_index] = 0
+    # Increment counters where the patch is positive in this frame
+    continuous_count_mat[positive_index] += 1
+    # A patch positive for more than 30 consecutive frames flags the clip
+    bad_flag = np.max(continuous_count_mat) > 30
+    if bad_flag:
+        print(f'连续帧方式:该时间段是否存在浑浊水体:{bool(bad_flag)}')
+    return bad_flag
+
+def main():
+    """Run patch-level turbidity inference over a folder of frames and
+    combine the ratio-based and consecutive-frame decisions."""
+    # Initialize the model instance
+    # TODO: adjust model name / weights path / frame-folder path as needed
+    predictor = Predictor(model_name='shufflenet',
+                          weights_path=r'/shufflenet.pth',
+                          num_classes=2)
+    input_path = r'frame_data/train/20251225/4_video_202511211127'
+    # Collect the frame images (jpg/png) from the input folder.
+    # NOTE(review): os.listdir order is OS-dependent while the
+    # consecutive-frame logic assumes temporal order — consider sorting.
+    all_imgs = os.listdir(input_path)
+    all_imgs = [os.path.join(input_path, p) for p in all_imgs if p.split('.')[-1] in ['jpg', 'png']]
+    image = Image.open(all_imgs[0]).convert('RGB')
+    # Grid shape used when reshaping flat predictions into a matrix
+    pre_rows = image.height // patch_h + 1
+    pre_cols = image.width // patch_w + 1
+    # Display size for visualization (half resolution)
+    resized_img_h = image.height // 2
+    resized_img_w = image.width // 2
+    # Predict every frame
+
+    water_pre_list = []
+    continuous_count_mat = np.zeros(pre_rows*pre_cols, dtype=np.int32)
+    flag = False
+    for img_path in all_imgs:
+        image = Image.open(img_path).convert('RGB')
+        # Preprocess: tile the frame into normalized patches
+        patches_index, image_tensor = preprocess_image(image)
+        # Inference
+        confidence, predicted_class  = predictor.predict(image_tensor)
+        # First-stage false-alarm suppression: drop low-confidence positives
+        for i in range(len(confidence)):
+            if confidence[i] < confidence_threshold:
+                confidence[i] = 1.0
+                predicted_class[i] = 0
+        # Second-stage false-alarm suppression: spatial filtering
+        new_confidence, new_predicted_class = fileter_prediction(predicted_class, confidence, pre_rows, pre_cols, filter_down_limit=3)
+        # Visualize raw vs filtered predictions side by side
+        image = cv2.imread(img_path)
+        image = draw_grid(image, patch_w, patch_h)
+        image = draw_predict_grid(image, patches_index, predicted_class, confidence)
+
+        new_image = cv2.imread(img_path)
+        new_image = draw_grid(new_image, patch_w, patch_h)
+        new_image = draw_predict_grid(new_image, patches_index, new_predicted_class, new_confidence)
+        image = cv2.resize(image, (resized_img_w, resized_img_h))
+        new_img = cv2.resize(new_image, (resized_img_w, resized_img_h))
+
+        cv2.imshow('image', image)
+        cv2.imshow('image_filter', new_img)
+
+        cv2.waitKey(20)
+        # Decision method 1 (ratio over a ~100-frame window)
+        # NOTE(review): `and flag` only registers the ratio result when flag
+        # was already True, and `flag` is overwritten by method 2 below on
+        # every frame — confirm how the two decisions should combine.
+        if len(water_pre_list) > 100:
+            flag = discriminate_ratio(water_pre_list) and flag
+            water_pre_list = []
+            print('综合判别结果:', flag)
+        water_pre_list.append(new_predicted_class)
+        # Decision method 2 (consecutive frames)
+        flag = discriminate_cont(new_predicted_class, continuous_count_mat)
+
+if __name__ == "__main__":
+    main()

+ 302 - 0
train.py

@@ -0,0 +1,302 @@
+# 微调pytorch的预训练模型,在自己的数据上训练,完成分类任务。
+import time
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision.transforms as transforms
+from torchvision.datasets import ImageFolder
+from torchvision.models import resnet18,resnet50, squeezenet1_0,shufflenet_v2_x1_0,shufflenet_v2_x2_0
+from torch.utils.tensorboard import SummaryWriter  # 添加 TensorBoard 支持
+from datetime import datetime
+import os
+from dotenv import load_dotenv
+load_dotenv()
+os.environ['CUDA_VISIBLE_DEVICES'] = os.getenv('CUDA_VISIBLE_DEVICES', '0')
+# 读取并打印.env文件中的变量
+def print_env_variables():
+    print("从.env文件加载的变量:")
+    env_vars = {
+        'PATCH_WIDTH': os.getenv('PATCH_WIDTH'),
+        'PATCH_HEIGHT': os.getenv('PATCH_HEIGHT'),
+        'CONFIDENCE_THRESHOLD': os.getenv('CONFIDENCE_THRESHOLD'),
+        'IMG_INPUT_SIZE': os.getenv('IMG_INPUT_SIZE'),
+        'WORKERS': os.getenv('WORKERS'),
+        'CUDA_VISIBLE_DEVICES': os.getenv('CUDA_VISIBLE_DEVICES')
+    }
+    for var, value in env_vars.items():
+        print(f"{var}: {value}")
+
+class Trainer:
+    def __init__(self, batch_size, train_dir, val_dir, name, checkpoint):
+        # 定义一些参数
+        self.name = name  # 采用的模型名称
+        self.img_size = int(os.getenv('IMG_INPUT_SIZE', 224))  # 输入图片尺寸
+        self.batch_size = batch_size  # 批次大小
+        self.cls_map = {"0": "non-muddy", "1":"muddy"}  # 类别名称映射词典
+        self.imagenet = os.getenv('PRETRAINED', True)  # 是否使用ImageNet预训练权重
+        # 训练设备 - 优先使用GPU,如果不可用则使用CPU
+        if torch.cuda.is_available():
+            try:
+                # 尝试进行简单的CUDA操作以确认CUDA功能正常
+                _ = torch.zeros(1).cuda()
+                self.device = torch.device("cuda")
+                print("成功检测到CUDA设备,使用GPU进行训练")
+            except Exception as e:
+                print(f"CUDA设备存在问题: {e},回退到CPU")
+                self.device = torch.device("cpu")
+        else:
+            self.device = torch.device("cpu")
+            print("CUDA不可用,使用CPU进行训练")
+        self.checkpoint = checkpoint
+        self.__global_step = 0
+        self.workers = int(os.getenv('WORKERS', 0))
+        # 创建日志目录
+        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")  # 获取当前时间戳
+        log_dir = f'runs/turbidity_{self.name}_{timestamp}'# 按照模型的名称和时间戳创建日志目录
+        self.writer = SummaryWriter(log_dir)  # 创建 TensorBoard writer
+
+        # 定义数据增强和预处理层
+        self.train_transforms = transforms.Compose([
+            transforms.Resize((self.img_size, self.img_size)),  # 调整图像大小为256x256 (ResNet输入尺寸)
+            transforms.RandomHorizontalFlip(p=0.3),  # 随机水平翻转,增加数据多样性
+            transforms.RandomVerticalFlip(p=0.3), # 随机垂直翻转,增加数据多样性
+            transforms.RandomGrayscale(p=0.25), # 随机灰度化,增加数据多样性
+            transforms.RandomRotation(10),  # 随机旋转±10度
+            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0, hue=0),  # 颜色抖动
+            transforms.ToTensor(),  # 转换为tensor并归一化到[0,1]
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet标准化
+        ])
+        # 测试集基础变换
+        self.val_transforms = transforms.Compose([
+            transforms.Resize((self.img_size, self.img_size)),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ])
+
+        # 创建数据集对象
+        self.train_dataset = ImageFolder(root=train_dir, transform=self.train_transforms)
+        self.val_dataset = ImageFolder(root=val_dir, transform=self.val_transforms)
+        # 创建数据加载器 (Windows环境下设置num_workers=0避免多进程问题)
+        self.train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.workers)  # 可迭代对象,返回(inputs_tensor, label)
+        self.val_loader = DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.workers)
+        # 获取类别数量
+        self.num_classes = len(self.train_dataset.classes)
+        print(f"自动发现 {self.num_classes} 个类别:")
+        # 打印类别和名称
+        for cls in self.train_dataset.classes:
+            print(f"{cls}: {self.cls_map.get(cls,'None')}")
+        # 打印训练集和测试集图像数量
+        print(f"训练集图像数量: {len(self.train_dataset)}")
+        print(f"验证集图像数量: {len(self.val_dataset)}")
+        # 创建模型
+        self.model = None
+        self.model = self.__load_model()
+        # 定义损失函数
+        self.loss = nn.CrossEntropyLoss()  # 多分类常用的交叉熵损失
+
+        # 定义优化器
+        # 只更新requires_grad=True的参数
+        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3, weight_decay=1e-4)
+
+        # 基于验证损失动态调整,更智能
+        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
+            self.optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-7,cooldown=2
+        )
+    def __load_model(self):
+        """加载模型结构"""
+        # 加载模型
+        pretrained = True if self.imagenet else False
+        if self.name == 'resnet50':
+            self.model = resnet50(pretrained=pretrained)
+        elif self.name == 'squeezenet':
+            self.model = squeezenet1_0(pretrained=pretrained)
+        elif self.name == 'shufflenet' or self.name == 'shufflenet-x1':
+            self.model = shufflenet_v2_x1_0(pretrained=pretrained)
+        elif self.name == 'shufflenet-x2':
+            self.model = shufflenet_v2_x2_0(pretrained=False)
+            self.imagenet = False
+            print('shufflenet-x2无预训练权重,重新训练所有权重')
+        else:
+            raise ValueError(f"Invalid model name: {self.name}")
+        # 如果采用预训练的神经网络,就需要冻结特征提取层,只训练最后几层
+        if self.imagenet:
+            for param in self.model.parameters():
+                param.requires_grad = False
+        # 替换最后的分类层以适应新的分类任务
+        if hasattr(self.model, 'fc'):
+            # ResNet系列模型
+            self.model.fc = nn.Linear(int(self.model.fc.in_features), self.num_classes, bias=True)
+        elif hasattr(self.model, 'classifier'):
+            # Swin Transformer等模型
+            self.model.classifier = nn.Linear(int(self.model.classifier.in_features), self.num_classes, bias=True)
+        elif hasattr(self.model, 'head'):
+            # Swin Transformer使用head层
+            self.model.head = nn.Linear(int(self.model.head.in_features), self.num_classes, bias=True)
+        else:
+            raise ValueError(f"Model {self.name} does not have recognizable classifier layer")
+        print(self.model)
+        print(f'模型{self.name}结构已经加载,移动到设备{self.device}')
+        # 将模型移动到GPU/cpu
+        self.model = self.model.to(self.device)
+        return self.model
+
+    def train_step(self):
+        """
+        单轮训练函数
+
+        Args:
+
+        Returns:
+            average_loss: 平均损失
+            accuracy: 准确率
+        """
+        self.model.train()  # 设置模型为训练模式(启用dropout/batchnorm等)
+        epoch_loss = 0.0  # epoch损失
+        correct_predictions = 0.0  # 预测正确的样本数
+        total_samples = 0.0  # 已经训练的总样本
+
+        # 遍历训练数据
+        for inputs, labels in self.train_loader:
+            # 将数据移到指定设备上
+            inputs = inputs.to(self.device)  # b c h w
+            labels = labels.to(self.device)  # b,
+
+            # 清零梯度缓存
+            self.optimizer.zero_grad()
+
+            # 前向传播
+            outputs = self.model(inputs)  # b, 2
+            loss = self.loss(outputs, labels) # 标量
+
+            # 反向传播
+            loss.backward()
+
+            # 更新参数
+            self.optimizer.step()
+
+            # 统计信息
+            batch_loss = loss.item() * inputs.size(0)  # 批损失
+            self.writer.add_scalar('Batch_Loss/Train', batch_loss, self.__global_step)
+            print(f'Training | Batch Loss: {batch_loss:.4f}\t', end='\r',flush=True)
+            epoch_loss += batch_loss  # epoch损失
+            _, predicted = torch.max(outputs.data, 1) # predicted b,存储的是预测的类别,最大值的位置
+            total_samples += labels.size(0)  # 统计总样本数量
+            correct_predictions += (predicted == labels).sum().item()  # 统计预测正确的样本数
+
+        epoch_loss = epoch_loss / len(self.train_loader.dataset)
+        epoch_acc = correct_predictions / total_samples
+        return epoch_loss, epoch_acc
+
+
+    def val_step(self):
+        """
+        验证模型性能
+
+        Args:
+        Returns:
+            average_loss: 平均损失
+            accuracy: 准确率
+        """
+        self.model.eval()  # 设置模型为评估模式(关闭dropout/batchnorm等)
+
+        epoch_loss = 0.0
+        correct_predictions = 0.
+        total_samples = 0.
+
+        # 不计算梯度,提高推理速度
+        with torch.no_grad():
+            for inputs, labels in self.val_loader:
+                inputs = inputs.to(self.device)
+                labels = labels.to(self.device)
+
+                outputs = self.model(inputs)
+                loss = self.loss(outputs, labels)
+
+                epoch_loss += loss.item() * inputs.size(0)
+                _, predicted = torch.max(outputs.data, 1)
+                total_samples += labels.size(0)
+                correct_predictions += (predicted == labels).sum().item()
+
+        epoch_loss = epoch_loss / len(self.val_loader.dataset)  # 平均损失
+        epoch_acc = correct_predictions / total_samples
+
+        return epoch_loss, epoch_acc
+
+
+    def train_and_validate(self, num_epochs=25):
+        """
+        训练和验证
+
+        Args:
+            num_epochs: 训练轮数
+
+        Returns:
+            train_losses: 每轮训练损失
+            train_accuracies: 每轮训练准确率
+            val_losses: 每轮验证损失
+            val_accuracies: 每轮验证准确率
+        """
+
+        best_val_acc = 0.0
+        best_val_loss = float('inf')
+        # 在你的代码中调用
+        print_env_variables()
+        print("开始训练...")
+        for epoch in range(num_epochs):
+            print(f'Epoch {epoch + 1}/{num_epochs}')
+            print('-' * 20)
+
+            # 单步训练
+            train_loss, train_acc = self.train_step()
+            print(f'Train Loss: {train_loss:.4f} Acc: {train_acc:.4f}')
+
+            # 验证阶段
+            val_loss, val_acc = self.val_step()
+            print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')
+
+            # 学习率调度
+            self.scheduler.step(val_loss)
+
+            # 记录指标到 TensorBoard
+            self.writer.add_scalar('Loss/Train', train_loss, epoch)
+            self.writer.add_scalar('Loss/Validation', val_loss, epoch)
+            self.writer.add_scalar('Accuracy/Train', train_acc, epoch)
+            self.writer.add_scalar('Accuracy/Validation', val_acc, epoch)
+            self.writer.add_scalar('Learning Rate', self.optimizer.param_groups[0]['lr'], epoch)
+
+
+            # 保存最佳模型 (基于验证准确率)
+            if val_acc > best_val_acc:
+                best_val_acc = val_acc
+                torch.save(self.model.state_dict(), f'{self.name}_best_model_acc.pth')
+                print(f"保存了新的最佳准确率模型,验证准确率: {best_val_acc:.4f}")
+            
+            # 保存最低验证损失模型
+            if val_loss < best_val_loss:
+                best_val_loss = val_loss
+                torch.save(self.model.state_dict(), f'{self.name}_best_model_loss.pth')
+                print(f"保存了新的最低损失模型,验证损失: {best_val_loss:.4f}")
+
+
+        # 关闭 TensorBoard writer
+        self.writer.close()
+        
+        print(f"训练完成! 最佳验证准确率: {best_val_acc:.4f}, 最低验证损失: {best_val_loss:.4f}")
+        return 1
+
+if __name__ == '__main__':
+    # Start training from the command line
+    import argparse
+    parser = argparse.ArgumentParser('预训练模型调参')
+    # NOTE(review): help='help' is a placeholder on all three options.
+    parser.add_argument('--train_dir',default='./label_data/train',help='help')
+    parser.add_argument('--val_dir', default='./label_data/test',help='help')
+    parser.add_argument('--model', default='shufflenet',help='help')
+    args = parser.parse_args()
+    num_epochs = 100  # hard-coded epoch count passed to train_and_validate
+    trainer = Trainer(batch_size=int(os.getenv('BATCH_SIZE', 32)),
+                      train_dir=args.train_dir,
+                      val_dir=args.val_dir,
+                      name=args.model,
+                      checkpoint=False)
+    trainer.train_and_validate(num_epochs)

+ 0 - 0
video/原始视频文件.md