Selaa lähdekoodia

update .gitignore

jiyuhang 3 kuukautta sitten
vanhempi
commit
28fb013700

+ 0 - 11
.env

@@ -1,11 +0,0 @@
-# 图块高度
-PATCH_WIDTH=256
-# 图块宽度
-PATCH_HEIGHT=256
-CONFIDENCE_THRESHOLD=0.6
-# 图片输入大小
-IMG_INPUT_SIZE=256
-# 工作线程数
-WORKERS=0
-# CUDA设备
-CUDA_VISIBLE_DEVICES=0

+ 5 - 0
.gitignore

@@ -43,5 +43,10 @@ __pycache__/
 *.txt
 *.onnx
 *.onnx.data
+*.mlir
+*.npz
+*.prototxt
+*.bmodel
+*.josn
 # 日志文件
 runs/*

+ 0 - 408
calculator.py

@@ -1,408 +0,0 @@
-import cv2
-import numpy as np
-import matplotlib
-matplotlib.use('Agg')
-from matplotlib import pyplot as plt
-from  matplotlib import rcParams
-import seaborn as sns
-import threading
-
-
-def set_chinese_font():
-    rcParams['font.sans-serif'] = ['WenQuanYi Micro Hei', 'Microsoft YaHei', 'SimSun']
-    rcParams['axes.unicode_minus'] = False
-
-def region_ndti(img, mask, is_night):
-    """对于可见光图像计算ndti,对于红外图像直接取值"""
-    if len(img.shape) != 3:
-        raise RuntimeError('请输入三通道彩色图像')
-
-    if len(mask.shape) != 2:
-        raise RuntimeError('请输入单通道掩膜图像')
-
-    img = img.astype(np.float32)
-    mask = mask.astype(np.float32)
-    # 取通道
-    g = img[:,:,1]
-    r = img[:,:,2]
-    # 判断是否为红外
-    if is_night: # 红外直接取值
-        ndti = r
-    else:  # 可见光计算ndti
-        ndti = (r - g) / (r + g + 1e-6)
-    # 掩膜归一化
-    if np.max(mask) > 2:
-        mask /= 255.
-    # 仅保留掩膜区域的ndti
-    ndti *= mask
-    return ndti
-
-def plot_histogram_opencv(matrix, start, end, step, title):
-    """使用OpenCV绘制直方图,避免GUI冲突"""
-    # 计算直方图
-    num_bins = int((end - start) / step)
-    hist, bins = np.histogram(matrix.ravel(), bins=num_bins, range=(start, end))
-
-    # 创建直方图图像(增加高度以便显示标签)
-    hist_img = np.zeros((450, 600, 3), dtype=np.uint8)
-    hist_img.fill(255)  # 白色背景
-
-    # 归一化直方图:确保高度在合理范围内
-    if len(hist) > 0 and np.max(hist) > 0:
-        # 保留10像素的边距
-        hist_normalized = cv2.normalize(hist, None, 0, 350, cv2.NORM_MINMAX)
-    else:
-        hist_normalized = np.zeros_like(hist)
-
-    # 计算合理的矩形宽度(确保至少1像素宽)
-    if len(hist) > 0:
-        bin_width = max(1, 600 // len(hist))  # 确保最小宽度为1
-    else:
-        bin_width = 1
-
-    # 绘制直方图矩形
-    for i in range(len(hist_normalized)):
-        if i >= 600:  # 防止索引越界
-            break
-
-        x1 = i * bin_width
-        x2 = min(x1 + bin_width, 599)  # 确保不超出图像边界
-        y1 = 400 - int(hist_normalized[i])  # 从底部开始计算高度
-        y2 = 400
-
-        # 只绘制有高度的矩形
-        if y1 < y2:
-            cv2.rectangle(hist_img, (x1, y1), (x2, y2), (0, 0, 255), -1)
-
-    # 添加坐标轴
-    cv2.line(hist_img, (50, 400), (550, 400), (0, 0, 0), 2)  # x轴
-    cv2.line(hist_img, (50, 400), (50, 50), (0, 0, 0), 2)  # y轴
-
-    # 添加标题和标签(调整位置避免被裁剪)
-    cv2.putText(hist_img, title, (10, 30),
-                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)
-    cv2.putText(hist_img, 'gray', (280, 430),
-                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
-    cv2.putText(hist_img, 'fre', (5, 200),
-                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
-
-    return hist_img
-
-def plot_histogram_seaborn_simple(matrix, start, end, step, title):
-    """
-    使用Seaborn的histplot函数直接绘制(最简单的方法)
-    """
-    # 解决中文显示问题
-    plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans']
-    plt.rcParams['axes.unicode_minus'] = False
-
-    plt.figure(figsize=(12, 6))
-
-    # 直接使用Seaborn的histplot,设置bins参数[3](@ref)
-    num_bins = int((end - start) / step)
-    sns.histplot(matrix.ravel(), bins=num_bins, kde=True,
-                 color='skyblue', alpha=0.7,
-                 edgecolor='white', linewidth=0.5)
-
-    plt.title(title+ '_histplot')
-    plt.xlabel('灰度值')
-    plt.ylabel('像素频率')
-    plt.tight_layout()
-    # plt.show()
-    plt.savefig('temp_' + title +'.png')  # 保存到文件而不是显示
-    plt.close()  # 关闭图形,释放资源
-
-    # 如果需要显示,可以用cv2读取并显示
-    hist_img = cv2.imread('temp_' + title +'.png')
-    cv2.imshow('Histogram '+title, hist_img)
-
-def plot_histogram_seaborn(matrix, hist, bin_edges, start, end, step):
-    """
-    使用Seaborn绘制更美观的直方图
-    """
-    # 设置Seaborn样式
-    sns.set_style("whitegrid")
-    plt.figure(figsize=(12, 6))
-
-    # 将数据转换为适合Seaborn的格式
-    flattened_data = matrix.ravel()
-
-    # 使用Seaborn的histplot(会自动计算直方图,但我们用自定义的)
-    # 这里我们手动绘制以确保使用我们的自定义bins
-    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
-
-    # 创建条形图
-    bars = plt.bar(bin_centers, hist, width=step * 0.8,
-                   alpha=0.7, color=sns.color_palette("viridis", 1)[0],
-                   edgecolor='white', linewidth=0.5)
-
-    # 添加KDE曲线(核密度估计)
-    sns.kdeplot(flattened_data, color='red', linewidth=2, label='密度曲线')
-
-    plt.title(f'灰度分布直方图', fontsize=16)
-    plt.xlabel('灰度值', fontsize=12)
-    plt.ylabel('像素频率', fontsize=12)
-    plt.legend()
-    plt.tight_layout()
-    plt.show()
-
-
-
-def colorful_ndti_matrix(ndti_matrix):
-    """给NDTI矩阵着色"""
-    # 存放着色后的NDTI矩阵
-    color_ndti_matrix = np.zeros((ndti_matrix.shape[0], ndti_matrix.shape[1], 3), dtype=np.uint8)
-    negative_mask = ndti_matrix < 0
-    positive_mask = ndti_matrix > 0
-    # 处理负值
-    if np.any(negative_mask):
-        blue_intensity = ndti_matrix.copy()
-        blue_intensity[positive_mask] = 0
-        blue_intensity[negative_mask] *= -255.
-        blue_intensity = np.clip(blue_intensity, 0, 255).astype(np.uint8)
-        color_ndti_matrix[:, :,0] = blue_intensity
-    # 处理正值区域(红色渐变)
-    if np.any(positive_mask):
-        # 将0到+1映射到0-255的红色强度
-        red_intensity = ndti_matrix.copy()
-        red_intensity[negative_mask] = 0
-        red_intensity[positive_mask] *= 255
-        red_intensity = np.clip(red_intensity, 0, 255).astype(np.uint8)
-        color_ndti_matrix[:, :, 2] = red_intensity
-
-    # 返回着色后的ndti矩阵
-    return color_ndti_matrix
-
-def img_add(img1, img2, img1_w=1, img2_w=0, gamma=0):
-
-    if len(img1.shape) != len(img2.shape):
-        raise ValueError('img1 and img2 must have the same shape')
-
-    # 设置权重参数(透明度)
-    alpha = img1_w  # 第一张图像的权重
-    beta = img2_w # 第二张图像的权重
-    gamma =gamma  # 亮度调节参数
-
-    # 执行加权叠加
-    result = cv2.addWeighted(img1, alpha, img2, beta, gamma)
-    return result
-def callback(event, x, y, flags, param):
-    if event == cv2.EVENT_LBUTTONDOWN:
-        is_night_ = param['is_night']
-        frame = param['Image']
-        w_name = param['window_name']
-        b = frame[:,:,0].astype(np.float32)
-        g = frame[:,:,1].astype(np.float32)
-        r = frame[:,:,2].astype(np.float32)
-        if is_night_:
-            ndti_value = r[y, x]
-        else:
-            ndti_value = (r[y,x] - g[y,x]) / (r[y,x] + g[y,x])
-        cv2.putText(frame, f'{ndti_value:.2f}', (x, y), cv2.FONT_HERSHEY_DUPLEX, 0.6, (0, 255, 0), 2, cv2.LINE_AA)
-        cv2.putText(frame, f'paused', (10, 45), cv2.FONT_HERSHEY_DUPLEX, 0.6, (0, 255, 0), 2, cv2.LINE_AA)
-        cv2.imshow(w_name, frame)
-def single_img(img_path,mask_path, id=0):
-
-    frame = cv2.imread(img_path)
-    if frame is None:
-        raise RuntimeError('img open failed')
-    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
-    if mask is None:
-        raise RuntimeError('mask open failed')
-    scale = 4 if id != 3 else 2
-    mask = cv2.resize(mask, (mask.shape[1] // scale, mask.shape[0] // scale))
-    frame = cv2.resize(frame, (frame.shape[1] // scale, frame.shape[0] // scale))
-    roi_ndti = region_ndti(frame, mask)
-    # 给ndti着色
-    color_ndti = colorful_ndti_matrix(roi_ndti)
-
-    # 绘制直方图
-    roi_index = mask > 0
-    roi_ndti_fla = roi_ndti[roi_index]  # 仅保留感兴趣区的计算结果
-    hist_img = plot_histogram_opencv(roi_ndti[roi_index], -1, 1, 0.01, f'NDTI Histogram {id}')
-
-    # 打印统计信息
-    # 3σ原则
-    mean_value = np.mean(roi_ndti_fla)
-    std_value = np.std(roi_ndti_fla)
-    min_value = np.min(roi_ndti_fla)
-    max_value = np.max(roi_ndti_fla)
-
-    up_limit = mean_value + 3 * std_value
-    down_limit = mean_value - 3 * std_value
-    # 调整
-    roi_ndti_fla = roi_ndti_fla[(up_limit >= roi_ndti_fla) & (roi_ndti_fla >= down_limit)]
-    text = f'pixs:{int(np.sum(roi_ndti_fla))} adj_mean:{roi_ndti_fla.mean():.3f} adj_std:{roi_ndti_fla.std():.3f}'
-    cv2.putText(frame, text, (10, 25), cv2.FONT_HERSHEY_DUPLEX, 0.6, (0, 255, 0), 2, cv2.LINE_AA)
-    print(f"""统计信息:
-    总像素数: {np.sum(roi_ndti_fla):,}
-    灰度范围: {min_value:.1f} - {max_value:.1f}
-    平均值: {mean_value:.2f}
-    标准差: {std_value:.2f}
-    调整平均值:{roi_ndti_fla.mean():.2f}
-    调整标准差:{roi_ndti_fla.std():.2f}
-    """
-          )
-    # 显示当前帧处理结果
-    cv2.imshow('original', frame)  # 原图
-    cv2.imshow('mask'+ str(id), mask)   # 掩膜
-    roi_ndti = np.abs(roi_ndti*255.).astype(np.uint8)
-    cv2.imshow('ndti'+ str(id), roi_ndti)  # ndti黑白强度
-    cv2.imshow('color_ndti'+ str(id), color_ndti)  # # ndti彩色强度
-    # 图像叠加
-    add_img = img_add(frame, color_ndti)
-    cv2.imshow('add_ori_ndti' + str(id), add_img)
-    param = {'Image': frame}
-    cv2.setMouseCallback('original', callback, param=param)
-    cv2.waitKey(0)
-    cv2.destroyAllWindows()
-
-def main(video_dir, mask_dir, id):
-    # 视频分析浊度
-    video_path = video_dir
-    # 加载掩膜
-    mask_path = mask_dir
-    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
-    if mask is None:
-        raise RuntimeError('mask open failed')
-    # 除数倍率
-    scale = 4 if id != 3 else 2
-
-    mask = cv2.resize(mask, (mask.shape[1] // scale, mask.shape[0] // scale))
-    # 视频读取
-    cap = cv2.VideoCapture(video_path)
-    pause = False
-    # 检查视频是否成功打开
-    if not cap.isOpened():
-        print("错误:无法打开视频文件。")
-        print("可能的原因:文件路径错误,或系统缺少必要的解码器。")
-        exit()
-    else:
-        ret, frame = cap.read()
-
-    # 成功打开后,逐帧读取和显示视频
-    global_img = [None]
-    is_night = False
-    while True:
-        if not pause:
-            ret, frame = cap.read()  # ret指示是否成功读取帧,frame是图像数据
-
-            # 如果读取帧失败(可能是文件损坏或已到结尾),退出循环
-            if not ret:
-                print("视频播放完毕。")
-                break
-            # 判断是否为夜间模式
-            if np.mean(np.abs(frame[:, :, 0] - frame[:, :, 1])) < 0.1:
-                is_night = True
-            # 处理当前帧
-            # 缩放
-            frame = cv2.resize(frame, (frame.shape[1] // scale, frame.shape[0] // scale))
-            # cv2.imshow('original', frame)
-            # 滤波
-            # frame = cv2.GaussianBlur(frame, (5, 5), 1.5)
-            # 计算掩膜区域的NDTI值
-            roi_ndti = region_ndti(frame, mask, is_night)
-            # 给ndti着色
-            color_ndti = colorful_ndti_matrix(roi_ndti)
-
-            # 绘制直方图
-            roi_index = mask > 0
-            roi_ndti_fla = roi_ndti[roi_index]  # 仅保留感兴趣区的计算结果
-            # plot_histogram_seaborn_simple(roi_ndti_fla,-1,1,0.01, str(id))
-
-            # 打印统计信息
-            # 3σ原则
-            mean_value  = np.mean(roi_ndti_fla)
-            print('调整前平均值:', mean_value)
-            std_value = np.std(roi_ndti_fla)
-            print('调整前平均值:', std_value)
-            min_value = np.min(roi_ndti_fla)
-            max_value = np.max(roi_ndti_fla)
-
-            up_limit = mean_value + 3 * std_value
-            down_limit = mean_value - 3 * std_value
-            # 调整
-            roi_ndti_fla = roi_ndti_fla[(up_limit >= roi_ndti_fla) & (roi_ndti_fla >= down_limit)]
-            text = f'pixs:{int(np.sum(roi_ndti_fla))} adj_mean:{roi_ndti_fla.mean():.3f} adj_std:{roi_ndti_fla.std():.3f}'
-            cv2.putText(frame, text, (10, 25), cv2.FONT_HERSHEY_DUPLEX, 0.6, (0, 255, 0), 2, cv2.LINE_AA)
-            # print(f"""统计信息:
-            # 总像素数: {np.sum(roi_ndti_fla):,}
-            # 灰度范围: {min_value:.1f} - {max_value:.1f}
-            # 平均值: {mean_value:.2f}
-            # 标准差: {std_value:.2f}
-            # 调整平均值:{roi_ndti_fla.mean():.2f}
-            # 调整标准差:{roi_ndti_fla.std():.2f}
-            # """
-            #       )
-            # 显示当前帧处理结果
-            #cv2.imshow('original', frame)  # 原图
-            #cv2.imshow('mask'+ str(id), mask)   # 掩膜
-            #roi_ndti = np.abs(roi_ndti*255.).astype(np.uint8)
-            #cv2.imshow('ndti'+ str(id), roi_ndti)  # ndti黑白强度
-            #cv2.imshow('color_ndti'+ str(id), color_ndti)  # # ndti彩色强度
-            # 图像叠加
-            add_img = img_add(frame, color_ndti)
-            global_img[0] = add_img
-            cv2.imshow('add_ori_ndti'+ str(id), add_img)
-            #cv2.imshow('Histogram', hist_img)
-
-        # 播放帧率为25FPS,如果期间按下'q'键则退出循环
-        key = cv2.waitKey(500) & 0xFF
-        if key == ord(' '):
-            pause = not pause
-            status = "已暂停" if pause else "播放中"
-            print(f"{id} 状态: {status}")
-
-        if key  == ord('q'):
-            break
-        if pause:
-
-            if global_img[0] is not None:
-                param = {'Image': global_img[0],'window_name': 'add_ori_ndti' + str(id), 'is_night': is_night}
-                cv2.setMouseCallback('add_ori_ndti' + str(id), callback, param=param)
-
-    # 释放资源并关闭所有窗口
-    cap.release()
-
-
-if __name__ == '__main__':
-    set_chinese_font()
-
-    # single_img(r'D:\code\water_turbidity_det\data\day_202511211129\1_device_capture.jpg',
-    #             r'D:\code\water_turbidity_det\draw_mask\mask\1_main_20251119102036_@1000000.png')
-    # pass
-    # path1 = r'D:\code\water_turbidity_det\data\day_202511211129\1_video_202511211128.dav'
-    # path2 = r'D:\code\water_turbidity_det\data\day_202511211129\2_video_202511211128.dav'
-    # path3 = r'D:\code\water_turbidity_det\data\day_202511211129\3_video_202511211128.dav'
-    # path4 = r'D:\code\water_turbidity_det\data\day_202511211129\4_video_202511211128.dav'
-    #
-    # path1 = r'D:\code\water_turbidity_det\data\night\1_video_20251120.dav'
-    # path2 = r'D:\code\water_turbidity_det\data\night\2_video_20251120_1801.dav'
-    # path3 = r'D:\code\water_turbidity_det\data\night\3_video_20251120_1759.dav'
-    # path4 = r'D:\code\water_turbidity_det\data\night\4_video_20251120_1800.dav'
-    #
-    # t1 = threading.Thread(target=main, kwargs={'video_dir': path1,
-    #                                            'mask_dir': r'D:\code\water_turbidity_det\draw_mask\mask\1_main_20251119102036_@1000000.png',
-    #                                            'id':1})
-    # t2 = threading.Thread(target=main, kwargs={'video_dir': path2,
-    #                                            'mask_dir': r'D:\code\water_turbidity_det\draw_mask\mask\2_main_20251119102038_@1000000.png',
-    #                                            'id':2})
-    # t3 = threading.Thread(target=main, kwargs={'video_dir': path3,
-    #                                            'mask_dir': r'D:\code\water_turbidity_det\draw_mask\mask\3_main_20251119102042_@1000000.png',
-    #                                            'id':3})
-    # t4 = threading.Thread(target=main, kwargs={'video_dir': path4,
-    #                                            'mask_dir': r'D:\code\water_turbidity_det\draw_mask\mask\4_main_20251119102044_@1000000.png',
-    #                                            'id':4})
-    # # threads = [t1, t2, t3, t4]
-    # threads = [t4]
-    # for t in threads:
-    #     t.start()
-    # for t in threads:
-    #     if t.is_alive():
-    #         t.join()
-    main(video_dir=r'/video/records_202512031553\video4_20251129053218_20251129060528.mp4',
-         mask_dir=r'D:\code\water_turbidity_det\draw_mask\mask\4_device_capture.png',
-         id=4)
-    cv2.destroyAllWindows()
-
-

+ 0 - 0
check/检查标注结果是否正确.md


+ 0 - 0
data/标注好的视频帧.md


+ 0 - 80
draw_mask/draw.py

@@ -1,80 +0,0 @@
-import cv2
-import numpy as np
-import os
-
-class DrawRectangle:
-
-    def __init__(self, div_scale):
-        self.mask_save_dir = './mask'
-        self.scale = div_scale
-        self.current_roi_points = []
-        self.rois = []
-        self.window_title = "Image - Select ROI"
-        self.draw_complete = False
-        pass
-
-    def callback(self, event, x, y, flags, param):
-
-        drawing_image = param['Image']
-        # 左键添加感兴趣点
-        if event == cv2.EVENT_LBUTTONDOWN:
-            # 添加感兴趣点
-            self.current_roi_points.append((x, y))
-            # 绘制标记点
-            cv2.circle(drawing_image, (x, y), 4, (0, 255, 0), -1)
-            # 绘制多边形
-            if len(self.current_roi_points) > 1:
-                cv2.line(drawing_image, self.current_roi_points[-2], self.current_roi_points[-1], (0, 255, 0), 2)
-            # 显示图像
-            cv2.imshow(self.window_title, drawing_image)
-            print(f'添加感兴趣点:{y}行, {x}列')
-        # 右键闭合感兴趣区域
-        if event == cv2.EVENT_RBUTTONDOWN:
-            if len(self.current_roi_points) < 3:
-                print("[提示] ROI 至少需要 3 个点构成多边形!")
-                return
-            cv2.line(drawing_image, self.current_roi_points[-1], self.current_roi_points[0], (0, 255, 0), 2)
-            cv2.imshow(self.window_title, drawing_image)
-            # 清理
-            self.rois.append(self.current_roi_points)
-            print(f'添加感兴趣区,包含点数:{len(self.current_roi_points)}个')
-            self.current_roi_points = []
-
-
-    def draw(self, img_path: str):
-        """在输入图像中绘制多边形区域,然后生成相应的mask图片"""
-        # 读取图像
-        ori_img = cv2.imread(img_path)
-        mask_base_name = os.path.splitext(os.path.basename(img_path))[0] + '.png'
-        img = cv2.resize(ori_img, (ori_img.shape[1] // self.scale, ori_img.shape[0] // self.scale))
-        if img is None:
-            raise RuntimeError('Cannot read the image!')
-        param = {'Image': img}
-        cv2.namedWindow(self.window_title)
-        cv2.setMouseCallback(self.window_title, self.callback, param=param)
-
-
-        # 显示图像并等待退出
-        while True:
-            cv2.imshow(self.window_title, img)
-            key = cv2.waitKey(1) & 0xFF
-            if key == ord('q') or key == 27:  # 按'q'或ESC键退出
-                break
-
-        # 为原图生成掩膜
-        mask = np.zeros((ori_img.shape[0], ori_img.shape[1]),dtype=np.uint8)  # shape等于原始输入图像
-        for roi in self.rois:
-            roi_points = np.array(roi, np.int32).reshape((-1, 1, 2)) * self.scale  # 兴趣点的缩放处理
-            cv2.fillPoly(mask, [roi_points], 255)
-
-        # 保存掩膜图像
-        if not os.path.exists(self.mask_save_dir):
-            os.makedirs(self.mask_save_dir)
-        cv2.imwrite(os.path.join(self.mask_save_dir, mask_base_name), mask)
-        # cv2.imshow("mask", mask)
-        # cv2.waitKey(0)
-        cv2.destroyAllWindows()
-
-if __name__ == '__main__':
-    drawer = DrawRectangle(2)
-    drawer.draw(r"D:\code\water_turbidity_det\draw_mask\mask\4_device_capture.jpg")

BIN
draw_mask/mask/1_device_capture.png


BIN
draw_mask/mask/2_device_capture.png


BIN
draw_mask/mask/3_device_capture.png


BIN
draw_mask/mask/4_device_capture.png


+ 0 - 62
labelme/check_label.py

@@ -1,62 +0,0 @@
-"""检查标注是否正确,读取标签,信息"""
-import cv2
-import os
-import sys
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-from utils import draw_grid
-from dotenv import load_dotenv
-load_dotenv() # 加载环境变量
-
-patch_w = int(os.getenv('PATCH_WIDTH', 256))
-patch_h = int(os.getenv('PATCH_HEIGHT', 256))
-scale = 2
-clicked_points = []  # u v w h cls
-
-def main():
-    """检查标注结果是否正确"""
-    # 标注文件路径
-    global patch_w
-    global patch_h
-    global scale
-    global clicked_points
-    # TODO:修改为要检查的图片路径
-    imgs_path = r'D:\code\water_turbidity_det\data\video4_20251129120320_20251129123514'
-    output_root = r'D:\code\water_turbidity_det\check'
-    label_path = os.path.join(imgs_path, 'label.txt')
-    if os.path.exists(label_path):
-        with open(label_path, 'r') as fr:
-            lines = fr.readlines()
-        lines = [line.strip() for line in lines]
-        for line in lines:
-            point = line.split(',')
-            clicked_points.append([int(point[0])//scale, int(point[1])//scale, int(point[2])//scale, int(point[3])//scale, int(point[4])])
-        del lines
-    # 检查结果输出路径
-    output_path = os.path.join(output_root, os.path.basename(imgs_path)+'_check')
-    if not os.path.exists(output_path):
-        os.makedirs(output_path)
-    # 获取所有照片
-    all_imgs = os.listdir(imgs_path)
-    all_imgs = [img for img in all_imgs if img.split('.')[-1] == 'jpg' or img.split('.')[-1] == 'png']
-
-    for img in all_imgs:
-        img_path = os.path.join(imgs_path, img)
-        img = cv2.imread(img_path)
-        img = cv2.resize(img, (img.shape[1] // scale, img.shape[0] // scale))
-        # 绘制网格线
-        img = draw_grid(img, patch_w // scale, patch_h // scale)
-        # 绘制标记点
-        for point in clicked_points:
-            # 计算中心点
-            center_x = point[0]+point[2]//scale
-            center_y = point[1]+point[3]//scale
-            cv2.circle(img, (center_x, center_y), 5, (0, 0, 255), -1)
-            # 显示标签文本
-            cv2.putText(img, str(point[4]), (center_x + 10, center_y + 10),
-                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
-        cv2.imwrite(os.path.join(output_path, os.path.basename(img_path)),
-                    img)
-        print(f"检查结果保存在: {os.path.join(output_path, os.path.basename(img_path))}")
-
-if __name__ == '__main__':
-    main()

+ 0 - 84
labelme/crop_patch.py

@@ -1,84 +0,0 @@
-# 根据标注文件,生成patch,每个类别放在一个文件夹下
-import os
-import sys
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-import numpy as np
-import cv2
-import gc
-from dotenv import load_dotenv
-load_dotenv()
-
-# 从图像中截取patch_w*patch_h大小的图像,并打上标签
-patch_w = int(os.getenv('PATCH_WIDTH', 256))
-patch_h = int(os.getenv('PATCH_HEIGHT', 256))
-
-def main():
-    # TODO:需要修改为标注好的图片路径
-    input_path = r'D:\code\water_turbidity_det\data\4_video_202511211127'
-    # TODO: 需要修改为保存patch的根目录
-    output_path_root = r'D:\code\water_turbidity_det\label_data\test'
-
-    # 读取标注文件
-    label_path = os.path.join(input_path, 'label.txt')
-    if not os.path.exists(label_path):  # 强制要求必须标注点什么
-        raise FileNotFoundError(f"{label_path} 不存在")
-    with open(label_path, 'r') as fr:
-        lines = fr.readlines()
-        lines = [line.strip() for line in lines]
-    # 恢复标注的格网
-    grids_info = [] # clicked_points
-    for line in lines:
-        point = line.split(',')
-        grids_info.append([int(point[0]), int(point[1]), int(point[2]), int(point[3]), int(point[4])])
-    # 我们先创建一些类别文件夹
-    # 0类
-    if not os.path.exists(os.path.join(output_path_root, str(0))):
-        os.makedirs(os.path.join(output_path_root, str(0)))
-    # 其余类
-    for grid in grids_info:
-        if grid[4] <= 0:
-            continue
-        if not os.path.exists(os.path.join(output_path_root, str(grid[4]))):
-            os.makedirs(os.path.join(output_path_root, str(grid[4])))
-    # 获取图像
-    all_imgs = [os.path.join(input_path, i) for i in os.listdir(input_path) if i.split('.')[-1] == 'jpg' or i.split('.')[-1] == 'png']
-    for img_path in all_imgs:
-        img_base_name = os.path.basename(img_path).split('.')[0]
-        img = cv2.imread(img_path)
-        # 获取图像高宽
-        img_h, img_w, _ = img.shape
-        # 先将不参与训练的patch重置为0
-        for g in grids_info:
-            if g[4] < 0:  # 标签小于零的不参与训练
-                img[g[1]:min(g[1]+g[3], img_h), g[0]:min(g[0]+g[2], img_w), :] = 0
-        # 再将大于0的patch保存到对应的类别文件夹下
-        for g in grids_info:
-            if g[4] > 0:  # 标签大于零的放到相应的文件夹下
-                patch_name = f'{img_base_name}_{g[0]}_{g[1]}_{g[4]}.jpg'  # 图块保存名称:图片名_左上角x_左上角y_类别.jpg
-                patch = img[g[1]:min(g[1]+g[3], img_h), g[0]:min(g[0]+g[2], img_w), :]
-                # 保存图块
-                cv2.imwrite(os.path.join(output_path_root,str(g[4]), patch_name), patch)
-                # 置零已经保存的patch区域
-                img[g[1]:min(g[1]+g[3], img_h), g[0]:min(g[0]+g[2], img_w), :] = 0
-        # 最后将剩余的patch保存到0类文件夹下
-        for i in range(img_h // patch_h + 1):
-            for j in range(img_w // patch_w + 1):
-                patch = img[i*patch_h:min(i*patch_h+patch_h, img_h), j*patch_w:min(j*patch_w+patch_w, img_w), :]
-                patch_name = f'{img_base_name}_{j*patch_w}_{i*patch_h}_0.jpg'
-                # 长宽比过滤
-                if patch.shape[0] / (patch.shape[1]+1e-6) > 1.314 or patch.shape[0] / (patch.shape[1]+1e-6) < 0.75:
-                    #print(f"长宽比过滤: {patch_name}")
-                    continue
-                # 纯黑图像过滤
-                if np.mean(patch) < 10.10:
-                    #print(f"纯黑图像过滤: {patch_name}")
-                    continue
-                cv2.imwrite(os.path.join(output_path_root, '0', patch_name), patch)
-                #print(f"保存图块: {patch_name}到{os.path.join(output_path_root, '0', patch_name)}")
-        print(f"处理图片: {img_path}完成")
-        # del patch, img
-        # gc.collect()
-
-if __name__ == '__main__':
-    main()

+ 0 - 165
labelme/fixed_label.py

@@ -1,165 +0,0 @@
-# 要求保证视频不能移动,且全过程没有任何遮挡物
-import os
-import sys
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-from utils import draw_grid
-
-import cv2
-import numpy as np
-import tkinter as tk
-from tkinter import simpledialog
-from dotenv import load_dotenv
-load_dotenv()  # 加载环境变量
-
-# 所要切分的图块宽高
-patch_w = int(os.getenv('PATCH_WIDTH', 256))
-patch_h = int(os.getenv('PATCH_HEIGHT', 256))
-scale = 2
-# 存储标记点
-clicked_points = []
-
-def get_text_input(prompt):
-    """创建弹窗获取文本输入"""
-    root = tk.Tk()
-    root.withdraw()  # 隐藏主窗口
-    root.attributes('-topmost', True)  # 确保弹窗置顶
-    result = simpledialog.askstring(' ', prompt)
-    root.destroy()
-    if result is None:
-        result = ""
-    if not result.strip().isdigit():
-        result = -1
-    return int(result)
-
-def mouse_callback(event, x, y, flags, param):
-    """
-    鼠标回调函数
-    """
-    global clicked_points
-    global patch_w
-    global patch_h
-    global scale
-
-    if event == cv2.EVENT_LBUTTONDOWN:  # 左键点击
-        # 在点击位置绘制红色圆点
-        scale_patch_w = patch_w // scale
-        scale_patch_h = patch_h // scale
-        # 格子角点
-        circle_x_corner = (x // scale_patch_w)*scale_patch_w
-        circle_y_corner = (y // scale_patch_h)*scale_patch_h
-        # 格子中心点
-        circle_x_center = circle_x_corner + scale_patch_w//2
-        circle_y_center = circle_y_corner + scale_patch_h//2
-        cv2.circle(param, (circle_x_center, circle_y_center), 5, (0, 0, 255), -1)
-        cv2.circle(param, (circle_x_corner, circle_y_corner), 5, (255, 0, 0), -1)
-
-        # 更新显示
-        cv2.imshow('img', param)
-
-        cls = get_text_input('请输入类别:0.背景 1.浑浊 -1.不参与')
-        # 显示标签文本
-        cv2.putText(param, str(cls), (circle_x_center + 10, circle_y_center + 10),
-                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
-        # 更新显示
-        cv2.imshow('img', param)
-        valid_cls = [0, 1, -1]
-        if cls in valid_cls:
-            print(f"点击网格角点: ({circle_x_corner}, {circle_y_corner}) 中心点: ({circle_x_center}, {circle_y_center}) 类别:{cls}")
-            # 记录标注数据,角点 u v w h cls
-            clicked_points.append([circle_x_corner, circle_y_corner, scale_patch_w, scale_patch_h, cls])
-        else:
-            print("请输入正确的类别!")
-    elif event == cv2.EVENT_RBUTTONDOWN:  # 右键点击
-        removed_point = clicked_points.pop()
-        print(f"撤销标注点: ({removed_point[0]}, {removed_point[1]}) 类别: {removed_point[4]}")
-        # 将撤销点标记为黑色
-        x = removed_point[0]
-        y = removed_point[1]
-        # 在点击位置绘制黑色圆点
-        scale_patch_w = patch_w // scale
-        scale_patch_h = patch_h // scale
-        # 格子角点
-        circle_x_corner = (x // scale_patch_w)*scale_patch_w
-        circle_y_corner = (y // scale_patch_h)*scale_patch_h
-        # 格子中心点
-        circle_x_center = circle_x_corner + scale_patch_w//2
-        circle_y_center = circle_y_corner + scale_patch_h//2
-        cv2.circle(param, (circle_x_center, circle_y_center), 5, (128, 128, 128), -1)
-        cv2.circle(param, (circle_x_corner, circle_y_corner), 5, (128, 128, 128), -1)
-        # 更新显示
-        cv2.imshow('img', param)
-
-
-def remove_duplicates(arr:list):
-    """列表去重"""
-    unique_list = []
-    [unique_list.append(item) for item in arr if item not in unique_list]
-    return unique_list
-
-def main():
-    """
-    固定摄像头标注,只需要标注一张图像,后续图像保持一致
-    1.标注过程,先将图像划分为图框,用cv2划线工具在图像上划网格线
-    2.用鼠标进行交互,点击图块输入标签,按下空格键完成交互过程,保存标签
-    3.标签格式:u,v,w,h,label u,v为块左上角坐标,w,h为块的宽和高,label为块的标签
-    """
-    global clicked_points
-    global patch_w
-    global patch_h
-    global scale
-    # TODO: 需要更改为准备标注的图像路径,使用当前目录下的000000.jpg,结果保存在当前目录下label.txt
-    img_path = r'D:\code\water_turbidity_det\data\video1_20251129120104_20251129123102\000000.jpg'
-    img = cv2.imread(img_path)
-    # resize 图像太大了显示不全
-    img = cv2.resize(img, (img.shape[1] // scale, img.shape[0] // scale))
-    # 绘制网格线
-    draw_grid(img, patch_w // scale, patch_h // scale)
-    # 交互标注
-    print("操作说明:")
-    print("- 点击鼠标左键在图像上添加红色标记点: 0.其他 1.浑浊 -1.忽略,不参与训练和测试")
-    print("- 按 'c' 键清除所有标记点")
-    print("- 按 ESC 键退出程序")
-    cv2.namedWindow('img')
-    cv2.setMouseCallback('img', mouse_callback, img)
-    # 交互标注
-    while True:
-        cv2.imshow('img', img)
-        key = cv2.waitKey(1) & 0xFF
-
-        # 按 'c' 键清除所有标记点
-        if key == ord('c'):
-            img = cv2.imread(img_path)
-            img = cv2.resize(img, (img.shape[1] // scale, img.shape[0] // scale))
-            draw_grid(img, patch_w // scale, patch_h // scale)
-            clicked_points.clear()
-            cv2.setMouseCallback('img', mouse_callback, img)
-
-        # 按 ESC 键退出
-        elif key == 27:  # ESC键
-            break
-    cv2.destroyAllWindows()
-    # 输出所有点击位置
-    # 列表去重
-    clicked_points = remove_duplicates(clicked_points)
-
-    print(f"总共标记了 {len(clicked_points)} 个点:")
-    for i, point in enumerate(clicked_points):
-        print(f"  点 {i + 1}: ({point[0]}, {point[1]}, {point[2]}, {point[3]}, {point[4]})")
-    # 恢复尺寸
-    clicked_points = [[p[0]*scale, p[1]*scale, p[2]*scale, p[3]*scale, p[4]] for p in clicked_points]
-    # 写入txt
-    if clicked_points:
-        with open(os.path.join(os.path.dirname(img_path), 'label.txt'), 'w') as fw:
-            for point in clicked_points:
-                fw.write(f"{point[0]},{point[1]},{point[2]},{point[3]},{point[4]}\n")
-        # 保存点
-        print(f"保存标记点 {len(clicked_points)} 个:")
-        for i, point in enumerate(clicked_points):
-            print(f"  点 {i + 1}: ({point[0]}, {point[1]}, {point[2]}, {point[3]}, {point[4]})")
-    else :
-        print("没有标记点!不保存任何文件!")
-
-
-
-if __name__ == '__main__':
-    main()

+ 0 - 26
labelme/utils.py

@@ -1,26 +0,0 @@
-import numpy as np
-import cv2
-
-def draw_grid(img: np.ndarray, grid_w: int, grid_h:int):
-    """划格网"""
-    img_h, img_w, _ = img.shape
-    # 绘制横向网格线
-    for i in range((img_h // grid_h)+1):
-        cv2.line(img, (0, i*grid_h), (img_w, i*grid_h), (0, 255, 0), 2)
-    # 绘制纵向网格线
-    for i in range((img_w // grid_w)+1):
-        cv2.line(img, (i*grid_w, 0), (i*grid_w, img_h), (0, 255, 0), 2)
-    return img
-
-def draw_predict_grid(img, patches_index, predicted_class, confidence):
-    """在img上绘制预测结果"""
-    for i, (idx_w, idx_h) in enumerate(patches_index):
-        cv2.circle(img, (idx_w, idx_h), 10, (0, 255, 0), -1)
-        text1 = f'cls:{predicted_class[i]}'
-        text2 = f'prob:{confidence[i] * 100:.1f}%'
-        color = (0, 0, 255) if predicted_class[i] else (255, 0, 0)
-        cv2.putText(img, text1, (idx_w, idx_h + 128),
-                    cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 2)
-        cv2.putText(img, text2, (idx_w, idx_h + 128 + 25),
-                    cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 2)
-    return img

+ 0 - 34
labelme/video_depart.py

@@ -1,34 +0,0 @@
-# 将视频图像分离为单帧序列
-import cv2
-import os
-
-def main():
-    # 视频路径
-    # TODO: 修改视频路径为自己的视频路径
-    path = r'/video/day_202511211129\4_video_202511211127.dav'
-    output_rootpath = r'D:\code\water_turbidity_det\data'  # 输出路径的根目录
-
-    # 我们将图像输出到根目录下的子目录中,子目录和视频名称相同
-    # 创建路径
-    img_base = os.path.basename(path).split('.')[0]
-    imgs_output_path = os.path.join(output_rootpath, img_base)  # 视频名称不要有.符号混淆后缀名
-    if not os.path.exists(imgs_output_path):
-        os.mkdir(imgs_output_path)
-
-    # 打开视频文件
-    cap = cv2.VideoCapture(path)
-    # 检查视频是否成功打开
-    if not cap.isOpened():
-        raise Exception('错误:无法打开视频文件。')
-    frame_count = 0
-    while True:
-        ret, frame = cap.read()
-        if not ret:
-            break
-        img_save_path = os.path.join(imgs_output_path,f"{frame_count:06d}.jpg")
-        cv2.imwrite(img_save_path, frame)
-        frame_count += 1
-        print(f"已处理{frame_count}帧, 保存至{img_save_path}")
-if __name__ == '__main__':
-    main()
-

+ 0 - 4
main.py

@@ -1,4 +0,0 @@
-import torch
-from torchvision.models import resnet50, ResNet50_Weights
-
-

+ 0 - 95
play.py

@@ -1,95 +0,0 @@
-import cv2
-import numpy as np
-from calculator import region_ndti, colorful_ndti_matrix
-import threading
-
-
-def main(video_path, video_mask_path, video_id):
-
-    records = []
-
-    cap = cv2.VideoCapture(video_path)
-    if not cap.isOpened():
-        print("错误:无法打开视频文件。")
-        print("可能的原因:文件路径错误,或系统缺少必要的解码器。")
-        exit()
-    scale = 2
-    mask = cv2.imread(video_mask_path, cv2.IMREAD_GRAYSCALE)
-    if mask is None:
-        raise RuntimeError('mask open failed')
-    mask = cv2.resize(mask, (mask.shape[1] // scale, mask.shape[0] // scale))
-
-    fps = cap.get(cv2.CAP_PROP_FPS)
-    print("fps:", fps)
-    while True:
-
-        ret, frame = cap.read()  # ret指示是否成功读取帧,frame是图像数据
-        if not ret:
-            print("视频播放完毕。")
-            break
-        # 判断是否为夜间模式
-        if np.mean(np.abs(frame[:, :, 0] - frame[:, :, 1])) < 0.1:
-            is_night = True
-        else:
-            is_night = False
-        frame = cv2.resize(frame, (frame.shape[1] // scale, frame.shape[0] // scale))
-
-        # 计算掩膜区域的NDTI值
-        roi_ndti = region_ndti(frame, mask, is_night)
-
-
-        # 给ndti着色
-        color_ndti = colorful_ndti_matrix(roi_ndti)
-
-
-        # 仅保留感兴趣区的计算结果
-        roi_index = mask > 0
-        roi_ndti_fla = roi_ndti[roi_index]
-        # 3σ原则
-        # mean_value = np.mean(roi_ndti_fla)
-        # std_value = np.std(roi_ndti_fla)
-        # up_limit = mean_value + 3 * std_value
-        # down_limit = mean_value - 3 * std_value
-        # roi_ndti_fla = roi_ndti_fla[(up_limit >= roi_ndti_fla) & (roi_ndti_fla >= down_limit)]
-
-        text = f'pixs:{int(np.sum(roi_ndti_fla))} mean:{roi_ndti_fla.mean():.3f} std:{roi_ndti_fla.std():.3f} {'night' if is_night else 'day'}'
-        cv2.putText(frame, text, (10, 25), cv2.FONT_HERSHEY_DUPLEX, 0.6, (0, 255, 0), 2, cv2.LINE_AA)
-        cv2.imshow(f'original_{video_id}', frame)
-        cv2.imshow(f'ndti_color_{video_id}', color_ndti)
-        cv2.waitKey(30)
-
-        # 退出逻辑
-        key = cv2.waitKey(10) & 0xFF
-        if key  == ord('q'):
-            break
-
-if __name__ == '__main__':
-    # main(video_path=r'D:\code\water_turbidity_det\data\night\4_video_20251120_1800.dav',
-    #      video_mask_path=r'D:\code\water_turbidity_det\draw_mask\mask\4_device_capture.png',
-    #      video_id=1)
-
-    t1 = threading.Thread(target=main,
-                     args=(r'D:\code\water_turbidity_det\data\records_202512031553\video4_20251129120320_20251129123514.mp4',
-                           r'D:\code\water_turbidity_det\draw_mask\mask\4_device_capture.png',
-                           4))
-
-    # t2 = threading.Thread(target=main,
-    #                  args=(r'D:\code\water_turbidity_det\data\day_202511211129\4_video_202511211127.dav',
-    #                        r'D:\code\water_turbidity_det\draw_mask\mask\4_device_capture.png',
-    #                        4))
-
-    # t1 = threading.Thread(target=main,
-    #                  args=(r'D:\code\water_turbidity_det\data\records_202512031553\video1_20251129055729_20251129063102.mp4',
-    #                        r'D:\code\water_turbidity_det\draw_mask\mask\1_device_capture.png',
-    #                        1))
-    #
-    # t2 = threading.Thread(target=main,
-    #                  args=(r'D:\code\water_turbidity_det\data\records_202512031553\video4_20251129053218_20251129060528.mp4',
-    #                        r'D:\code\water_turbidity_det\draw_mask\mask\4_device_capture.png',
-    #                        4))
-    t1.start()
-    # t2.start()
-
-    t1.join()
-    # t2.join()
-    cv2.destroyAllWindows()

+ 0 - 39
pth2onnx.py

@@ -1,39 +0,0 @@
-import torch
-import torch.onnx
-from torchvision.models import resnet50, ResNet50_Weights
-from torch import nn
-
-if __name__ == '__main__':
-    input = torch.randn(1, 3, 256, 256)          # [1,3,224,224]分别对应[B,C,H,W]
-    # 载入模型框架
-    model = resnet50()
-    # model.fc = nn.Sequential(
-    #     nn.Linear(int(model.fc.in_features), int(model.fc.in_features) // 2, bias=True),
-    #     nn.ReLU(inplace=True),
-    #     nn.Dropout(0.5),
-    #     nn.Linear(int(model.fc.in_features) // 2, 2, bias=False)
-    # )
-
-    # model.load_state_dict(torch.load("resnet50_best_model_acc.pth")) # xxx.pth表示.pth文件, 这一步载入模型权重
-    model.load_state_dict(torch.load(r'D:\code\water_turbidity_det\resnet50-11ad3fa6.pth')) # xxx.pth表示.pth文件, 这一步载入模型权重
-    model.eval()                                 # 设置模型为推理模式
-   # print(model)
-   # model = torch.jit.script(model)  # 先转换为TorchScript
-    torch.onnx.export(model,
-                      input,
-                      "resnet50_best_model_acc.onnx",
-                      training=torch.onnx.TrainingMode.EVAL,
-                      opset_version=18,
-                      export_params=True,
-                      do_constant_folding=True,
-                      input_names=['input'],
-                      output_names=['output']
-                      )  # xxx.onnx表示.onnx文件, 这一步导出为onnx模型, 并不做任何算子融合操作。
-
-    # 验证模型
-    import onnx
-    model = onnx.load("resnet50_best_model_acc.onnx")
-    onnx.checker.check_model(model)  # 验证模型完整性
-
-    #mean_mlir = [0.485×255, 0.456×255, 0.406×255] = [123.675, 116.28, 103.53]
-    #scale_mlir = [0.229*255, 0.224*255, 0.225*255] = [58.395, 57.12, 57.375]

+ 0 - 14
readme.md

@@ -1,14 +0,0 @@
-初始化
-
-
-浊度检测深度学习方案:
-一、样本标注
-我们首先将一张图片划分为多个小块,每个小块都对应一个标签。标签类别如下:
-0:表示背景;1:表示清澈水体;2:表示浑浊水体
-
-1.工具开发
-添加图片标注工具,对原始图像进行标注,生成对应的标签文件,存储格式为:
-文件名称:与图片名称一致,后缀为txt;
-文件内容:对应图片的标签信息,每行表示一个块;
-文件格式:u,v,w,h,label u,v为块左上角坐标,w,h为块的宽和高,label为块的标签
-添加标注信息展示工具,将图片和标签信息进行展示,方便用户查看标注信息

+ 0 - 265
roi_creator_mp.py

@@ -1,265 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-代码一:定义数据并保存(支持多个ROI)
-
-功能:
-1. 加载图一。
-2. 允许用户交互式地绘制多个多边形 ROI,并为每个ROI命名。
-3. 允许用户交互式地绘制多个矩形屏蔽区。
-4. 将所有ROI坐标、名称、屏蔽区坐标和相关参数打包保存到一个 JSON 文件中。
-"""
-import cv2
-import numpy as np
-import json
-import os
-import base64
-
-# ========== 全局变量 ==========
-# --- ROI 绘制相关 ---
-current_roi_points = []
-all_rois = []  # 存储所有ROI及其名称
-drawing_complete = False
-drawing_image = None
-original_image = None
-# --- 屏蔽区绘制相关 ---
-ignore_polygons = []
-drawing_ignore = None
-drawing_rect = False
-rect_start_point = None
-
-# ========== 可配置参数 ==========
-CONFIG_PARAMS = {
-    "MARGIN_RATIO": 0.25,
-    "MIN_MATCH_COUNT": 10,
-    "USE_CLAHE": True
-}
-
-
-# -------------------------------------------------
-# 1. 回调函数:选择 ROI
-# -------------------------------------------------
-def select_roi(event, x, y, flags, param):
-    """用于绘制 ROI 多边形的回调函数"""
-    global current_roi_points, drawing_complete, drawing_image
-
-    if drawing_complete:
-        return
-
-    if event == cv2.EVENT_LBUTTONDOWN:
-        current_roi_points.append((x, y))
-        cv2.circle(drawing_image, (x, y), 4, (0, 255, 0), -1)
-        if len(current_roi_points) > 1:
-            cv2.line(drawing_image, current_roi_points[-2], current_roi_points[-1], (0, 255, 0), 2)
-        cv2.imshow("Image 1 - Select ROI", drawing_image)
-
-    elif event == cv2.EVENT_RBUTTONDOWN:
-        if len(current_roi_points) < 3:
-            print("[提示] ROI 至少需要 3 个点构成多边形!")
-            return
-        drawing_complete = True
-        cv2.line(drawing_image, current_roi_points[-1], current_roi_points[0], (0, 255, 0), 2)
-        cv2.imshow("Image 1 - Select ROI", drawing_image)
-        print("[提示] 当前ROI选择完毕,按任意键继续...")
-
-
-# -------------------------------------------------
-# 2. 回调函数:选择屏蔽区
-# -------------------------------------------------
-def select_ignore_rect(event, x, y, flags, param):
-    """用于绘制屏蔽矩形的回调函数"""
-    global rect_start_point, drawing_rect, ignore_polygons, drawing_ignore
-
-    if event == cv2.EVENT_LBUTTONDOWN:
-        rect_start_point = (x, y)
-        drawing_rect = True
-
-    elif event == cv2.EVENT_MOUSEMOVE:
-        if drawing_rect:
-            preview_img = drawing_ignore.copy()
-            cv2.rectangle(preview_img, rect_start_point, (x, y), (0, 0, 255), 2)
-            cv2.imshow("Image 1 - Ignore Area", preview_img)
-
-    elif event == cv2.EVENT_LBUTTONUP:
-        if not drawing_rect:
-            return
-        drawing_rect = False
-        x_start, y_start = rect_start_point
-        x1, y1 = min(x_start, x), min(y_start, y)
-        x2, y2 = max(x_start, x), max(y_start, y)
-
-        if x2 - x1 > 5 and y2 - y1 > 5:
-            cv2.rectangle(drawing_ignore, (x1, y1), (x2, y2), (0, 0, 255), 2)
-            rect_as_poly = [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
-            ignore_polygons.append(rect_as_poly)
-            print(f"[提示] 已新增屏蔽区 #{len(ignore_polygons)},可继续拉框,完成后按 Enter 键。")
-        cv2.imshow("Image 1 - Ignore Area", drawing_ignore)
-
-
-# -------------------------------------------------
-# 3. 工具函数
-# -------------------------------------------------
-def read_image_chinese(path):
-    """读取可能包含中文路径的图片"""
-    try:
-        with open(path, 'rb') as f:
-            data = np.frombuffer(f.read(), np.uint8)
-        return cv2.imdecode(data, cv2.IMREAD_COLOR)
-    except Exception as e:
-        print(f"读取图片失败: {e}")
-        return None
-
-
-def draw_all_rois(image, rois_list):
-    """在图像上绘制所有已保存的ROI"""
-    result = image.copy()
-    colors = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0),
-              (255, 0, 255), (0, 255, 255), (128, 255, 0), (255, 128, 0)]
-
-    for idx, roi_data in enumerate(rois_list):
-        color = colors[idx % len(colors)]
-        points = roi_data["points"]
-        name = roi_data["name"]
-
-        # 绘制多边形
-        cv2.polylines(result, [np.array(points, np.int32)], True, color, 2)
-
-        # 计算中心点并显示名称
-        center = np.mean(points, axis=0).astype(int)
-        cv2.putText(result, name, tuple(center), cv2.FONT_HERSHEY_SIMPLEX,
-                    0.7, color, 2, cv2.LINE_AA)
-
-    return result
-
-
-# -------------------------------------------------
-# 4. 主函数
-# -------------------------------------------------
-def define_and_save_data(image1_path, output_json_path):
-    """
-    主流程函数:加载图像,引导用户选择多个ROI并命名,最终保存数据。
-    """
-    global current_roi_points, drawing_complete, drawing_image, all_rois
-    global drawing_ignore, ignore_polygons, original_image
-
-    # --- 步骤 1: 加载图像 ---
-    img1 = read_image_chinese(image1_path)
-    if img1 is None:
-        print(f"错误:无法加载图像 '{image1_path}',请检查路径!")
-        return
-
-    original_image = img1.copy()
-
-    # --- 步骤 2: 循环选择多个ROI ---
-    print("\n========== 开始选择ROI ==========")
-    roi_count = 0
-
-    while True:
-        roi_count += 1
-        current_roi_points = []
-        drawing_complete = False
-
-        # 显示已有的ROI
-        if all_rois:
-            drawing_image = draw_all_rois(original_image, all_rois)
-        else:
-            drawing_image = original_image.copy()
-
-        cv2.namedWindow("Image 1 - Select ROI")
-        cv2.setMouseCallback("Image 1 - Select ROI", select_roi)
-
-        print(f"\n【ROI #{roi_count}】")
-        print("用鼠标左键多点圈出ROI,右键闭合完成当前ROI。")
-        print("按 ESC 键结束所有ROI选择,按其他键继续添加下一个ROI。")
-        cv2.imshow("Image 1 - Select ROI", drawing_image)
-
-        key = cv2.waitKey(0)
-
-        if key == 27:  # ESC键,结束ROI选择
-            cv2.destroyWindow("Image 1 - Select ROI")
-            if len(current_roi_points) >= 3:
-                # 保存最后一个ROI
-                roi_name = input(f"请输入ROI #{roi_count} 的名称: ").strip()
-                if not roi_name:
-                    roi_name = f"ROI_{roi_count}"
-                all_rois.append({
-                    "name": roi_name,
-                    "points": current_roi_points
-                })
-                print(f"已保存ROI: {roi_name}")
-            break
-        else:
-            if len(current_roi_points) >= 3:
-                # 为当前ROI输入名称
-                roi_name = input(f"请输入ROI #{roi_count} 的名称: ").strip()
-                if not roi_name:
-                    roi_name = f"ROI_{roi_count}"
-
-                all_rois.append({
-                    "name": roi_name,
-                    "points": current_roi_points
-                })
-                print(f"已保存ROI: {roi_name}")
-            else:
-                print("当前ROI点数不足,跳过保存。")
-                roi_count -= 1
-
-    if not all_rois:
-        print("错误:未选择任何ROI,程序退出。")
-        return
-
-    print(f"\n总共选择了 {len(all_rois)} 个ROI")
-
-    # --- 步骤 3: 交互式选择屏蔽区 ---
-    drawing_ignore = draw_all_rois(original_image, all_rois)
-    cv2.namedWindow("Image 1 - Ignore Area")
-    cv2.setMouseCallback("Image 1 - Ignore Area", select_ignore_rect)
-    print("\n【步骤 2】在弹窗中,用鼠标左键拖拽拉框来选择屏蔽区,可画多个。")
-    print("         所有屏蔽区画完后,按 Enter 键保存并结束,或按 Esc 键退出。")
-    cv2.imshow("Image 1 - Ignore Area", drawing_ignore)
-
-    while True:
-        key = cv2.waitKey(0)
-        if key in [13, 10]:  # Enter键
-            break
-        elif key == 27:  # Esc
-            cv2.destroyAllWindows()
-            print("用户取消操作。")
-            return
-    cv2.destroyAllWindows()
-
-    # --- 步骤 4: 封装数据并保存到 JSON ---
-    try:
-        with open(image1_path, 'rb') as f:
-            image_bytes = f.read()
-        image_base64 = base64.b64encode(image_bytes).decode('utf-8')
-        image_filename = os.path.basename(image1_path)
-    except Exception as e:
-        print(f"\n[失败] 读取并编码图像文件时出错: {e}")
-        return
-
-    # 封装数据
-    data_to_save = {
-        "image1_filename": image_filename,
-        "image1_data_base64": image_base64,
-        "rois": all_rois,  # 保存所有ROI及其名称
-        "ignore_polygons": ignore_polygons,
-        "parameters": CONFIG_PARAMS
-    }
-
-    try:
-        with open(output_json_path, 'w', encoding='utf-8') as f:
-            json.dump(data_to_save, f, indent=4, ensure_ascii=False)
-        print(f"\n[成功] {len(all_rois)} 个ROI、屏蔽区和图像数据已成功保存到: {output_json_path}")
-    except Exception as e:
-        print(f"\n[失败] 保存 JSON 文件时出错: {e}")
-
-
-# -------------------------------------------------
-if __name__ == "__main__":
-    IMG1_PATH = r"20250615000002.jpg"
-    OUTPUT_JSON = "roi_json_dir/matching_data_multi_roi.json"
-
-    # 创建输出目录
-    os.makedirs(os.path.dirname(OUTPUT_JSON), exist_ok=True)
-
-    define_and_save_data(IMG1_PATH, OUTPUT_JSON)

+ 0 - 12
run.bash

@@ -1,12 +0,0 @@
-#python train.py --model swin_v2_b
-#sleep 60
-#python train.py --model swin_v2_s
-#sleep 60
-python train.py --model squeezenet
-sleep 60
-python train.py --model squeezenet-x2
-sleep 60
-python train.py --model shufflenet
-sleep 60
-python train.py --model resnet50
-

BIN
runs_dec19/runs/turbidity_classification/events.out.tfevents.1766071975.240.2967095.0


BIN
runs_dec19/runs/turbidity_classification/events.out.tfevents.1766081826.240.3294804.0


+ 0 - 304
test.py

@@ -1,304 +0,0 @@
-import time
-
-import torch
-import torch.nn as nn
-from torchvision import transforms
-from torchvision.models import resnet18, ResNet18_Weights,resnet50,ResNet50_Weights, squeezenet1_0, SqueezeNet1_0_Weights,\
-    shufflenet_v2_x1_0, ShuffleNet_V2_X1_0_Weights, swin_v2_s, Swin_V2_S_Weights, swin_v2_b, Swin_V2_B_Weights
-import numpy as np
-from PIL import Image
-import os
-import argparse
-from labelme.utils import draw_grid, draw_predict_grid
-import cv2
-import matplotlib.pyplot as plt
-from dotenv import load_dotenv
-load_dotenv()
-# os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
-patch_w = int(os.getenv('PATCH_WIDTH', 256))
-patch_h = int(os.getenv('PATCH_HEIGHT', 256))
-confidence_threshold = float(os.getenv('CONFIDENCE_THRESHOLD', 0.80))
-scale = 2
-
-
-class Predictor:
-    def __init__(self, model_name, weights_path, num_classes):
-        self.model_name = model_name
-        self.weights_path = weights_path
-        self.num_classes = num_classes
-        self.model = None
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        print(f"当前设备: {self.device}")
-        # 加载模型
-        self.load_model()
-
-
-    def load_model(self):
-        if self.model is not None:
-            return
-        print(f"正在加载模型: {self.model_name}")
-        name = self.model_name
-        # 加载模型
-        if name == 'resnet50':
-            self.weights = ResNet50_Weights.IMAGENET1K_V2
-            self.model = resnet50(weights=self.weights)
-        elif name == 'squeezenet':
-            self.weights = SqueezeNet1_0_Weights.IMAGENET1K_V1
-            self.model = squeezenet1_0(weights=self.weights)
-        elif name == 'shufflenet':
-            self.weights = ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1
-            self.model = shufflenet_v2_x1_0(weights=self.weights)
-        elif name == 'swin_v2_s':
-            self.weights = Swin_V2_S_Weights.IMAGENET1K_V1
-            self.model = swin_v2_s(weights=self.weights)
-        elif name == 'swin_v2_b':
-            self.weights = Swin_V2_B_Weights.IMAGENET1K_V1
-            self.model = swin_v2_b(weights=self.weights)
-        else:
-            raise ValueError(f"Invalid model name: {name}")
-        # 替换最后的分类层以适应新的分类任务
-        if hasattr(self.model, 'fc'):
-            # ResNet系列模型
-            self.model.fc = nn.Sequential(
-                nn.Linear(int(self.model.fc.in_features), int(self.model.fc.in_features) // 2, bias=True),
-                nn.ReLU(inplace=True),
-                nn.Dropout(0.5),
-                nn.Linear(int(self.model.fc.in_features) // 2, self.num_classes, bias=False)
-            )
-        elif hasattr(self.model, 'classifier'):
-            # Swin Transformer等模型
-            self.model.classifier = nn.Sequential(
-                nn.Linear(int(self.model.classifier.in_features), int(self.model.classifier.in_features) // 2,
-                          bias=True),
-                nn.ReLU(inplace=True),
-                nn.Dropout(0.5),
-                nn.Linear(int(self.model.classifier.in_features) // 2, self.num_classes, bias=False)
-            )
-        elif hasattr(self.model, 'head'):
-            # Swin Transformer使用head层
-            in_features = self.model.head.in_features
-            self.model.head = nn.Sequential(
-                nn.Linear(int(in_features), int(in_features) // 2, bias=True),
-                nn.ReLU(inplace=True),
-                nn.Dropout(0.5),
-                nn.Linear(int(in_features) // 2, self.num_classes, bias=False)
-            )
-        else:
-            raise ValueError(f"Model {name} does not have recognizable classifier layer")
-        print(self.model)
-        # 加载训练好的权重
-        self.model.load_state_dict(torch.load(self.weights_path, map_location=torch.device('cpu')))
-        print(f"成功加载模型参数: {self.weights_path}")
-        # 将模型移动到GPU
-        self.model.eval()
-        self.model = self.model.to(self.device)
-        print(f"成功加载模型: {self.model_name}")
-
-    def predict(self, image_tensor):
-        """
-        对单张图像进行预测
-
-        Args:
-            image_tensor: 预处理后的图像张量
-
-        Returns:
-            predicted_class: 预测的类别索引
-            confidence: 预测置信度
-            probabilities: 各类别的概率
-        """
-
-        image_tensor = image_tensor.to(self.device)
-
-        with torch.no_grad():
-            outputs = self.model(image_tensor)
-            probabilities = torch.nn.functional.softmax(outputs, dim=1)  # 沿行计算softmax
-            confidence, predicted_class = torch.max(probabilities, 1)
-
-        return confidence.cpu().numpy(), predicted_class.cpu().numpy()
-
-
-def preprocess_image(img):
-    """
-    预处理图像以匹配训练时的预处理
-    
-    Args:
-        img: PIL图像
-        
-    Returns:
-        tensor: 预处理后的图像张量
-    """
-    # 定义与训练时相同的预处理步骤
-    transform = transforms.Compose([
-        transforms.Resize((224, 224)),
-        transforms.ToTensor(),
-        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-    ])
-
-    # 打开并转换图像
-
-    img_w, img_h = img.size
-    global patch_w, patch_h
-    imgs_patch = []
-    imgs_index = []
-    # fig, axs = plt.subplots(img_h // patch_h + 1, img_w // patch_w + 1)
-    for i in range(img_h // patch_h + 1):
-        for j in range(img_w // patch_w + 1):
-            left = j * patch_w  # 裁剪区域左边框距离图像左边的像素值
-            top = i * patch_h  # 裁剪区域上边框距离图像上边的像素值
-            right = min(j * patch_w + patch_w, img_w)  # 裁剪区域右边框距离图像左边的像素值
-            bottom = min(i * patch_h + patch_h, img_h)  # 裁剪区域下边框距离图像上边的像素值
-            # 检查区域是否有效
-            if right > left and bottom > top:
-                patch = img.crop((left, top, right, bottom))
-                # 长宽比过滤
-                # rate = patch.height / (patch.width + 1e-6)
-                # if rate > 1.314 or rate < 0.75:
-                #     # print(f"长宽比过滤: {patch_name}")
-                #     continue
-                imgs_patch.append(patch)
-                imgs_index.append((left, top))
-                # axs[i, j].imshow(patch)
-                # axs[i, j].set_title(f'Image {i} {j}')
-                # axs[i, j].axis('off')
-
-    # plt.tight_layout()
-    # plt.show()
-    imgs_patch = torch.stack([transform(img) for img in imgs_patch])
-    # 添加批次维度
-    # image_tensor = image_tensor.unsqueeze(0)
-    return imgs_index, imgs_patch
-
-
-def visualize_prediction(image_path, predicted_class, confidence, class_names):
-    """
-    可视化预测结果
-    
-    Args:
-        image_path: 图像路径
-        predicted_class: 预测的类别索引
-        confidence: 预测置信度
-        class_names: 类别名称列表
-    """
-    image = Image.open(image_path).convert('RGB')
-    
-    plt.figure(figsize=(8, 6))
-    plt.imshow(image)
-    plt.axis('off')
-    plt.title(f'Predicted: {class_names[predicted_class]}\n'
-              f'Confidence: {confidence:.4f}', fontsize=14)
-    plt.show()
-
-def get_33_patch(arr:np.ndarray, center_row:int, center_col:int):
-    """以(center_row,center_col)为中心,从arr中取出来3*3区域的数据"""
-    # 边界检查
-    h,w = arr.shape
-    safe_row_up_limit = max(0, center_row-1)
-    safe_row_bottom_limit = min(h, center_row+2)
-    safe_col_left_limit = max(0, center_col-1)
-    safe_col_right_limit = min(w, center_col+2)
-    return arr[safe_row_up_limit:safe_row_bottom_limit, safe_col_left_limit:safe_col_right_limit]
-
-
-def fileter_prediction(predicted_class, confidence, pre_rows, pre_cols, filter_down_limit=3):
-    """预测结果矩阵滤波,九宫格内部存在浑浊水体的数量需要大于filter_down_limit,"""
-    predicted_class_mat = np.resize(predicted_class, (pre_rows, pre_cols))
-    predicted_conf_mat = np.resize(confidence, (pre_rows, pre_cols))
-    new_predicted_class_mat = predicted_class_mat.copy()
-    new_predicted_conf_mat = predicted_conf_mat.copy()
-    for i in range(pre_rows):
-        for j in range(pre_cols):
-            if (1. - predicted_class_mat[i, j]) > 0.1:
-                continue  # 跳过背景类
-            core_region = get_33_patch(predicted_class_mat, i, j)
-            if np.sum(core_region) < filter_down_limit:
-                new_predicted_class_mat[i, j] = 0  #  重置为背景类
-                new_predicted_conf_mat[i, j] = 1.0
-    return new_predicted_conf_mat.flatten(), new_predicted_class_mat.flatten()
-
-def discriminate_ratio(water_pre_list:list):
-    # 方式一:60%以上的帧存在浑浊水体
-    water_pre_arr = np.array(water_pre_list, dtype=np.float32)
-    water_pre_arr_mean = np.mean(water_pre_arr, axis=0)
-    bad_water = np.array(water_pre_arr_mean >= 0.6, dtype=np.int32)
-    bad_flag = np.sum(bad_water, dtype=np.int32)
-    print(f'浑浊比例方式:该时间段是否存在浑浊水体:{bool(bad_flag)}')
-    return bad_flag
-
-
-def discriminate_cont(pre_class_arr, continuous_count_mat):
-    """连续帧判别"""
-    positive_index = np.array(pre_class_arr,dtype=np.int32) > 0
-    negative_index = np.array(pre_class_arr,dtype=np.int32) == 0
-    # 给负样本区域置零
-    continuous_count_mat[negative_index] = 0
-    # 给正样本区域加1
-    continuous_count_mat[positive_index] += 1
-    # 判断浑浊
-    bad_flag = np.max(continuous_count_mat) > 30
-    if bad_flag:
-        print(f'连续帧方式:该时间段是否存在浑浊水体:{bool(bad_flag)}')
-    return bad_flag
-
-def main():
-
-    # 初始化模型实例
-    # TODO:修改模型网络名称/模型权重路径/视频路径
-    predictor = Predictor(model_name='shufflenet',
-                          weights_path=r'D:\code\water_turbidity_det\shufflenet_best_model_acc.pth',
-                          num_classes=2)
-    input_path = r'D:\code\water_turbidity_det\data\4_video_202511211127'
-    # 预处理图像
-    all_imgs = os.listdir(input_path)
-    all_imgs = [os.path.join(input_path, p) for p in all_imgs if p.split('.')[-1] in ['jpg', 'png']]
-    image = Image.open(all_imgs[0]).convert('RGB')
-    # 将预测结果reshape为矩阵时的行列数量
-    pre_rows = image.height // patch_h + 1
-    pre_cols = image.width // patch_w + 1
-    # 图像显示时resize的尺寸
-    resized_img_h = image.height // 2
-    resized_img_w = image.width // 2
-    # 预测每张图像
-
-    water_pre_list = []
-    continuous_count_mat = np.zeros(pre_rows*pre_cols, dtype=np.int32)
-    flag = False
-    for img_path in all_imgs:
-        image = Image.open(img_path).convert('RGB')
-        # 预处理
-        patches_index, image_tensor = preprocess_image(image)
-        # 推理
-        confidence, predicted_class  = predictor.predict(image_tensor)
-        # 第一层虚警抑制,置信度过滤,低于阈值将会被忽略
-        for i in range(len(confidence)):
-            if confidence[i] < confidence_threshold:
-                confidence[i] = 1.0
-                predicted_class[i] = 0
-        # 第二层虚警抑制,空间滤波
-        # 在此处添加过滤逻辑
-        new_confidence, new_predicted_class = fileter_prediction(predicted_class, confidence, pre_rows, pre_cols, filter_down_limit=3)
-        # 可视化预测结果
-        image = cv2.imread(img_path)
-        image = draw_grid(image, patch_w, patch_h)
-        image = draw_predict_grid(image, patches_index, predicted_class, confidence)
-
-        new_image = cv2.imread(img_path)
-        new_image = draw_grid(new_image, patch_w, patch_h)
-        new_image = draw_predict_grid(new_image, patches_index, new_predicted_class, new_confidence)
-        image = cv2.resize(image, (resized_img_w, resized_img_h))
-        new_img = cv2.resize(new_image, (resized_img_w, resized_img_h))
-
-        cv2.imshow('image', image)
-        cv2.imshow('image_filter', new_img)
-
-        cv2.waitKey(20)
-        # 方式1判别
-        if len(water_pre_list) > 100:
-            flag = discriminate_ratio(water_pre_list) and flag
-            water_pre_list = []
-            print('综合判别结果:', flag)
-        water_pre_list.append(new_predicted_class)
-        # 方式2判别
-        flag = discriminate_cont(new_predicted_class, continuous_count_mat)
-
-if __name__ == "__main__":
-    main()

+ 0 - 319
train.py

@@ -1,319 +0,0 @@
-# 微调pytorch的预训练模型,在自己的数据上训练,完成分类任务。
-import time
-import torch
-import torch.nn as nn
-import torch.optim as optim
-from torch.utils.data import DataLoader
-import torchvision.transforms as transforms
-from torchvision.datasets import ImageFolder
-from torchvision.models import resnet18,resnet50, squeezenet1_0,shufflenet_v2_x1_0,shufflenet_v2_x2_0
-from torch.utils.tensorboard import SummaryWriter  # 添加 TensorBoard 支持
-from datetime import datetime
-import os
-from dotenv import load_dotenv
-load_dotenv()
-os.environ['CUDA_VISIBLE_DEVICES'] = os.getenv('CUDA_VISIBLE_DEVICES', '0')
-# 读取并打印.env文件中的变量
-def print_env_variables():
-    print("从.env文件加载的变量:")
-    env_vars = {
-        'PATCH_WIDTH': os.getenv('PATCH_WIDTH'),
-        'PATCH_HEIGHT': os.getenv('PATCH_HEIGHT'),
-        'CONFIDENCE_THRESHOLD': os.getenv('CONFIDENCE_THRESHOLD'),
-        'IMG_INPUT_SIZE': os.getenv('IMG_INPUT_SIZE'),
-        'WORKERS': os.getenv('WORKERS'),
-        'CUDA_VISIBLE_DEVICES': os.getenv('CUDA_VISIBLE_DEVICES')
-    }
-    for var, value in env_vars.items():
-        print(f"{var}: {value}")
-
-class Trainer:
-    def __init__(self, batch_size, train_dir, val_dir, name, checkpoint):
-        # 定义一些参数
-        self.name = name  # 采用的模型名称
-        self.img_size = int(os.getenv('IMG_INPUT_SIZE', 224))  # 输入图片尺寸
-        self.batch_size = batch_size  # 批次大小
-        self.cls_map = {"0": "non-muddy", "1":"muddy"}  # 类别名称映射词典
-        self.imagenet = True  # 是否使用ImageNet预训练权重
-        # 训练设备 - 优先使用GPU,如果不可用则使用CPU
-        if torch.cuda.is_available():
-            try:
-                # 尝试进行简单的CUDA操作以确认CUDA功能正常
-                _ = torch.zeros(1).cuda()
-                self.device = torch.device("cuda")
-                print("成功检测到CUDA设备,使用GPU进行训练")
-            except Exception as e:
-                print(f"CUDA设备存在问题: {e},回退到CPU")
-                self.device = torch.device("cpu")
-        else:
-            self.device = torch.device("cpu")
-            print("CUDA不可用,使用CPU进行训练")
-        self.checkpoint = checkpoint
-        self.__global_step = 0
-        self.workers = int(os.getenv('WORKERS', 0))
-        # 创建日志目录
-        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")  # 获取当前时间戳
-        log_dir = f'runs/turbidity_{self.name}_{timestamp}'# 按照模型的名称和时间戳创建日志目录
-        self.writer = SummaryWriter(log_dir)  # 创建 TensorBoard writer
-
-        # 定义数据增强和预处理层
-        self.train_transforms = transforms.Compose([
-            transforms.Resize((self.img_size, self.img_size)),  # 调整图像大小为256x256 (ResNet输入尺寸)
-            transforms.RandomHorizontalFlip(p=0.3),  # 随机水平翻转,增加数据多样性
-            transforms.RandomVerticalFlip(p=0.3), # 随机垂直翻转,增加数据多样性
-            transforms.RandomGrayscale(p=0.25), # 随机灰度化,增加数据多样性
-            transforms.RandomRotation(10),  # 随机旋转±10度
-            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0, hue=0),  # 颜色抖动
-            transforms.ToTensor(),  # 转换为tensor并归一化到[0,1]
-            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet标准化
-        ])
-        # 测试集基础变换
-        self.val_transforms = transforms.Compose([
-            transforms.Resize((self.img_size, self.img_size)),
-            transforms.ToTensor(),
-            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-        ])
-
-        # 创建数据集对象
-        self.train_dataset = ImageFolder(root=train_dir, transform=self.train_transforms)
-        self.val_dataset = ImageFolder(root=val_dir, transform=self.val_transforms)
-        # 创建数据加载器 (Windows环境下设置num_workers=0避免多进程问题)
-        self.train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.workers)  # 可迭代对象,返回(inputs_tensor, label)
-        self.val_loader = DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.workers)
-        # 获取类别数量
-        self.num_classes = len(self.train_dataset.classes)
-        print(f"自动发现 {self.num_classes} 个类别:")
-        # 打印类别和名称
-        for cls in self.train_dataset.classes:
-            print(f"{cls}: {self.cls_map.get(cls,'None')}")
-
-        # 创建模型
-        self.model = None
-        self.model = self.__load_model()
-        # 定义损失函数
-        self.loss = nn.CrossEntropyLoss()  # 多分类常用的交叉熵损失
-
-        # 定义优化器
-        # 只更新requires_grad=True的参数
-        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3, weight_decay=1e-4)
-
-        # 基于验证损失动态调整,更智能
-        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
-            self.optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-7,cooldown=2
-        )
-    def __load_model(self):
-        """加载模型结构"""
-        # 加载模型
-        pretrained = True if self.imagenet else False
-        if self.name == 'resnet50':
-            self.model = resnet50(pretrained=pretrained)
-        elif self.name == 'squeezenet':
-            self.model = squeezenet1_0(pretrained=pretrained)
-        elif self.name == 'shufflenet' or self.name == 'shufflenet-x1':
-            self.model = shufflenet_v2_x1_0(pretrained=pretrained)
-        elif self.name == 'shufflenet-x2':
-            self.model = shufflenet_v2_x2_0(pretrained=False)
-            self.imagenet = False
-            print('shufflenet-x2无预训练权重,重新训练所有权重')
-        else:
-            raise ValueError(f"Invalid model name: {self.name}")
-        # 如果采用预训练的神经网络,就需要冻结特征提取层,只训练最后几层
-        if self.imagenet:
-            for param in self.model.parameters():
-                param.requires_grad = False
-        # 替换最后的分类层以适应新的分类任务
-        if hasattr(self.model, 'fc'):
-            # ResNet系列模型
-            self.model.fc = nn.Linear(int(self.model.fc.in_features), self.num_classes, bias=False)
-            # self.model.fc = nn.Sequential(
-            #     nn.Linear(int(self.model.fc.in_features), int(self.model.fc.in_features) // 2, bias=False),
-            #     nn.ReLU(inplace=True),
-            #     nn.Dropout(0.5),
-            #     nn.Linear(int(self.model.fc.in_features) // 2, self.num_classes, bias=False)
-            # )
-        elif hasattr(self.model, 'classifier'):
-            # Swin Transformer等模型
-            self.model.classifier = nn.Linear(int(self.model.classifier.in_features), self.num_classes, bias=False)
-            # self.model.classifier = nn.Sequential(
-            #     nn.Linear(int(self.model.classifier.in_features), int(self.model.classifier.in_features) // 2,
-            #               bias=True),
-            #     nn.ReLU(inplace=True),
-            #     nn.Dropout(0.5),
-            #     nn.Linear(int(self.model.classifier.in_features) // 2, self.num_classes, bias=False)
-            # )
-        elif hasattr(self.model, 'head'):
-            # Swin Transformer使用head层
-            self.model.head = nn.Linear(int(self.model.head.in_features), self.num_classes, bias=False)
-            # in_features = self.model.head.in_features
-            # self.model.head = nn.Sequential(
-            #     nn.Linear(int(in_features), int(in_features) // 2, bias=True),
-            #     nn.ReLU(inplace=True),
-            #     nn.Dropout(0.5),
-            #     nn.Linear(int(in_features) // 2, self.num_classes, bias=False)
-            # )
-        else:
-            raise ValueError(f"Model {self.name} does not have recognizable classifier layer")
-        print(self.model)
-        print(f'模型{self.name}结构已经加载,移动到设备{self.device}')
-        # 将模型移动到GPU/cpu
-        self.model = self.model.to(self.device)
-        return self.model
-
-    def train_step(self):
-        """
-        单轮训练函数
-
-        Args:
-
-        Returns:
-            average_loss: 平均损失
-            accuracy: 准确率
-        """
-        self.model.train()  # 设置模型为训练模式(启用dropout/batchnorm等)
-        epoch_loss = 0.0  # epoch损失
-        correct_predictions = 0.0  # 预测正确的样本数
-        total_samples = 0.0  # 已经训练的总样本
-
-        # 遍历训练数据
-        for inputs, labels in self.train_loader:
-            # 将数据移到指定设备上
-            inputs = inputs.to(self.device)  # b c h w
-            labels = labels.to(self.device)  # b,
-
-            # 清零梯度缓存
-            self.optimizer.zero_grad()
-
-            # 前向传播
-            outputs = self.model(inputs)  # b, 2
-            loss = self.loss(outputs, labels) # 标量
-
-            # 反向传播
-            loss.backward()
-
-            # 更新参数
-            self.optimizer.step()
-
-            # 统计信息
-            batch_loss = loss.item() * inputs.size(0)  # 批损失
-            self.writer.add_scalar('Batch_Loss/Train', batch_loss, self.__global_step)
-            print(f'Training | Batch Loss: {batch_loss:.4f}\t', end='\r',flush=True)
-            epoch_loss += batch_loss  # epoch损失
-            _, predicted = torch.max(outputs.data, 1) # predicted b,存储的是预测的类别,最大值的位置
-            total_samples += labels.size(0)  # 统计总样本数量
-            correct_predictions += (predicted == labels).sum().item()  # 统计预测正确的样本数
-
-        epoch_loss = epoch_loss / len(self.train_loader.dataset)
-        epoch_acc = correct_predictions / total_samples
-        return epoch_loss, epoch_acc
-
-
-    def val_step(self):
-        """
-        验证模型性能
-
-        Args:
-        Returns:
-            average_loss: 平均损失
-            accuracy: 准确率
-        """
-        self.model.eval()  # 设置模型为评估模式(关闭dropout/batchnorm等)
-
-        epoch_loss = 0.0
-        correct_predictions = 0.
-        total_samples = 0.
-
-        # 不计算梯度,提高推理速度
-        with torch.no_grad():
-            for inputs, labels in self.val_loader:
-                inputs = inputs.to(self.device)
-                labels = labels.to(self.device)
-
-                outputs = self.model(inputs)
-                loss = self.loss(outputs, labels)
-
-                epoch_loss += loss.item() * inputs.size(0)
-                _, predicted = torch.max(outputs.data, 1)
-                total_samples += labels.size(0)
-                correct_predictions += (predicted == labels).sum().item()
-
-        epoch_loss = epoch_loss / len(self.val_loader.dataset)  # 平均损失
-        epoch_acc = correct_predictions / total_samples
-
-        return epoch_loss, epoch_acc
-
-
-    def train_and_validate(self, num_epochs=25):
-        """
-        训练和验证
-
-        Args:
-            num_epochs: 训练轮数
-
-        Returns:
-            train_losses: 每轮训练损失
-            train_accuracies: 每轮训练准确率
-            val_losses: 每轮验证损失
-            val_accuracies: 每轮验证准确率
-        """
-
-        best_val_acc = 0.0
-        best_val_loss = float('inf')
-        # 在你的代码中调用
-        print_env_variables()
-        print("开始训练...")
-        for epoch in range(num_epochs):
-            print(f'Epoch {epoch + 1}/{num_epochs}')
-            print('-' * 20)
-
-            # 单步训练
-            train_loss, train_acc = self.train_step()
-            print(f'Train Loss: {train_loss:.4f} Acc: {train_acc:.4f}')
-
-            # 验证阶段
-            val_loss, val_acc = self.val_step()
-            print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')
-
-            # 学习率调度
-            self.scheduler.step(val_loss)
-
-            # 记录指标到 TensorBoard
-            self.writer.add_scalar('Loss/Train', train_loss, epoch)
-            self.writer.add_scalar('Loss/Validation', val_loss, epoch)
-            self.writer.add_scalar('Accuracy/Train', train_acc, epoch)
-            self.writer.add_scalar('Accuracy/Validation', val_acc, epoch)
-            self.writer.add_scalar('Learning Rate', self.scheduler.get_last_lr()[0], epoch)
-
-            # 保存最佳模型 (基于验证准确率)
-            if val_acc > best_val_acc:
-                best_val_acc = val_acc
-                torch.save(self.model.state_dict(), f'{self.name}_best_model_acc.pth')
-                print(f"保存了新的最佳准确率模型,验证准确率: {best_val_acc:.4f}")
-            
-            # 保存最低验证损失模型
-            if val_loss < best_val_loss:
-                best_val_loss = val_loss
-                torch.save(self.model.state_dict(), f'{self.name}_best_model_loss.pth')
-                print(f"保存了新的最低损失模型,验证损失: {best_val_loss:.4f}")
-
-
-        # 关闭 TensorBoard writer
-        self.writer.close()
-        
-        print(f"训练完成! 最佳验证准确率: {best_val_acc:.4f}, 最低验证损失: {best_val_loss:.4f}")
-        return 1
-
-if __name__ == '__main__':
-    # 开始训练
-    import argparse
-    parser = argparse.ArgumentParser('预训练模型调参')
-    parser.add_argument('--train_dir',default='./label_data/train',help='help')
-    parser.add_argument('--val_dir', default='./label_data/test',help='help')
-    parser.add_argument('--model', default='shufflenet',help='help')
-    args = parser.parse_args()
-    num_epochs = 100
-    trainer = Trainer(batch_size=128,
-                      train_dir=args.train_dir,
-                      val_dir=args.val_dir,
-                      name=args.model,
-                      checkpoint=False)
-    trainer.train_and_validate(num_epochs)

+ 0 - 0
video/原始视频文件.md