jiyuhang 2 mesiacov pred
rodič
commit
360d563aff

+ 246 - 0
bmodel_application.py

@@ -0,0 +1,246 @@
+import argparse
+import sophon.sail as sail
+import cv2
+import os
+import logging
+import json
+import numpy as np
+
class Predictor:
    """Turbid-water detector running a classification bmodel on a Sophon TPU.

    Each incoming frame is cut into network-sized patches, every patch is
    classified (0 = background, 1 = turbid), and three layers of false-alarm
    suppression (confidence threshold, spatial 3x3 filter, temporal filters)
    decide whether the alarm fires.
    """

    def __init__(self, cli_args=None):
        """Load the inference engine and initialize the alarm state.

        Args:
            cli_args: optional argparse.Namespace with ``bmodel`` and
                ``dev_id``. Defaults to the module-level ``args`` (set by
                ``argsparser()``) so the historical parameterless
                ``Predictor()`` call keeps working.
        """
        cfg = cli_args if cli_args is not None else args
        # Load the inference engine.
        self.net = sail.Engine(cfg.bmodel, cfg.dev_id, sail.IOMode.SYSIO)
        self.graph_name = self.net.get_graph_names()[0]
        self.input_names = self.net.get_input_names(self.graph_name)
        self.input_shapes = [self.net.get_input_shape(self.graph_name, name) for name in self.input_names]
        self.output_names = self.net.get_output_names(self.graph_name)
        self.output_shapes = [self.net.get_output_shape(self.graph_name, name) for name in self.output_names]  # e.g. [[1, 2]]
        self.input_name = self.input_names[0]
        self.input_shape = self.input_shapes[0]  # e.g. [1, 3, 256, 256]
        self.batch_size = self.input_shape[0]
        self.net_h = self.input_shape[2]  # patch height expected by the network
        self.net_w = self.input_shape[3]  # patch width expected by the network
        # Normalization constants (ImageNet pre-training statistics).
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        # Number of patch rows/columns for the current frame size.
        self.current_patch_rows = 0
        self.current_patch_cols = 0
        # Size of the incoming full frame (0 until the first frame is seen).
        self.input_image_h = 0
        self.input_image_w = 0
        self.print_network_info()
        # Minimum confidence for a "turbid" prediction to be kept.
        self.confidence_threshold = 0.6
        # Per-frame detection results, consumed by the temporal ratio filter.
        self.continuous_detection_result = []
        # Per-patch counter of consecutive "turbid" frames (sized in preprocess()).
        self.continuous_counter_mat = np.array(0)
        # Consecutive-frame count above which a patch counts as turbid.
        self.max_continuous_frames = 15

    def __call__(self, img) -> bool:
        """Classify one frame; return True when the turbidity alarm fires."""
        # Preprocess: patch sequence plus each patch's top-left coordinate.
        patches_index, patches = self.preprocess(img)
        confidence = []
        predicted_class = []
        # Batch the patches according to the model's input batch size.
        for i in range(0, len(patches), self.batch_size):
            batch_patches = patches[i:i + self.batch_size]
            # Tail handling: pad the last incomplete batch with its final patch.
            if len(batch_patches) < self.batch_size:
                batch_patches += [batch_patches[-1]] * (self.batch_size - len(batch_patches))
            patches_tensor = np.stack(batch_patches)
            # Run the engine; returns per-patch confidences and class indices.
            batch_confi, batch_cls = self.predict(patches_tensor)
            confidence += batch_confi
            predicted_class += batch_cls
        # Drop results produced by the padded tail.
        confidence = np.array(confidence[:len(patches)])
        predicted_class = np.array(predicted_class[:len(patches)])
        # Post-process: alarm logic.
        alarm = self.postprocess(confidence=confidence, predicted_class=predicted_class)
        return alarm

    def print_network_info(self):
        """Pretty-print the loaded network's I/O configuration."""
        info = {
            'Graph Name': self.graph_name,
            'Input Name': self.input_name,
            'Output Names': self.output_names,
            'Output Shapes': self.output_shapes,
            'Input Shape': self.input_shape,
            'Batch Size': self.batch_size,
            'Height': self.net_h,
            'Width': self.net_w,
            'Mean': self.mean,
            'Std': self.std
        }

        print("=" * 50)
        print("Network Configuration Info")
        print("=" * 50)
        for key, value in info.items():
            print(f"{key:<18}: {value}")
        print("=" * 50)

    def predict(self, input_img):
        """Run one batch through the inference engine.

        Args:
            input_img: float32 array, shape (batch, 3, net_h, net_w) — assumed
                from the input shape reported by the engine.

        Returns:
            (confidence, predictions): per-sample max softmax probability and
            argmax class index, both as plain Python lists.
        """
        input_data = {self.input_name: input_img}
        outputs = self.net.process(self.graph_name, input_data)
        # Single-output network: take the first (only) output tensor.
        outputs = list(outputs.values())[0]
        # Softmax over the class axis (the engine returns raw scores).
        outputs_exp = np.exp(outputs)
        outputs = outputs_exp / np.sum(outputs_exp, axis=1)[:, None]
        confidence = np.max(outputs, axis=1)
        predictions = np.argmax(outputs, axis=1)  # most probable class per sample
        return confidence.tolist(), predictions.tolist()

    def preprocess(self, img: np.ndarray):
        """Cut a full frame into patches, left-to-right, top-to-bottom.

        Args:
            img: full frame as an HxWx3 uint8 BGR array (OpenCV layout).

        Returns:
            (patch_index, patches): top-left (x, y) of each patch, and the
            normalized CHW float32 patches in the same order.
        """
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.astype('float32')
        # Normalization is required here: the compiled model does NOT do it,
        # and skipping it makes the predictions wildly inaccurate.
        img = (img / 255. - self.mean) / self.std
        img_h, img_w, _ = img.shape
        img = np.transpose(img, (2, 0, 1))  # channel first
        patches = []
        patch_index = []
        # Recompute the patch grid only when the input frame size changes.
        if img_h != self.input_image_h or img_w != self.input_image_w:
            self.input_image_h = img_h
            self.input_image_w = img_w
            # Ceiling division. The previous ``h // net_h + 1`` produced an
            # extra zero-height/width row/column (and a cv2.resize crash on
            # the empty patch) whenever the frame dimension was an exact
            # multiple of the patch size.
            self.current_patch_rows = (img_h + self.net_h - 1) // self.net_h
            self.current_patch_cols = (img_w + self.net_w - 1) // self.net_w
            # Re-initialize the consecutive-frame counter for the new grid.
            print('初始化连续计数器矩阵')
            self.continuous_counter_mat = np.zeros(self.current_patch_rows * self.current_patch_cols, dtype=np.int32)
        # Crop the patches row by row.
        for i in range(self.current_patch_rows):
            for j in range(self.current_patch_cols):
                start_row = i * self.net_h
                start_col = j * self.net_w
                end_row = min(start_row + self.net_h, self.input_image_h)
                end_col = min(start_col + self.net_w, self.input_image_w)
                patch = img[:, start_row:end_row, start_col:end_col]
                _, patch_h, patch_w = patch.shape
                # Border patches may be smaller than the network input; resize them.
                if patch_h != self.net_h or patch_w != self.net_w:
                    patch = np.transpose(patch, (1, 2, 0))
                    patch = cv2.resize(patch, (self.net_w, self.net_h))
                    patch = np.transpose(patch, (2, 0, 1))
                patches.append(patch)
                patch_index.append((start_col, start_row))  # top-left (x, y)
        return patch_index, patches

    def postprocess(self, confidence, predicted_class):
        """Decide whether to raise the alarm from per-patch predictions."""
        # Layer 1 of false-alarm suppression: low-confidence "turbid" patches
        # are reset to background (with confidence 1.0).
        for i in range(len(confidence)):
            if confidence[i] < self.confidence_threshold and predicted_class[i] == 1:
                confidence[i] = 1.0
                predicted_class[i] = 0
        # Layer 2: spatial filtering over the patch grid.
        confidence, predicted_class = self.filter_prediction(predicted_class=predicted_class, confidence=confidence, filter_down_limit=3)
        # Layer 3: temporal filtering.
        self.continuous_detection_result.append(predicted_class)
        flag = self.update_continuous_counter(predicted_class)  # consecutive-frame filter
        if len(self.continuous_detection_result) >= 20:
            flag = self.discriminate_ratio() and flag
            print(f'是否存在浑浊水体,综合判别结果:{flag}')
            self.continuous_detection_result.pop(0)  # keep a sliding 20-frame window
            return flag
        return False

    @staticmethod
    def get_33_patch(arr: np.ndarray, center_row: int, center_col: int):
        """Return the 3x3 region of ``arr`` centered at (center_row, center_col),
        clipped at the array borders (so corners/edges yield 2x2 or 2x3)."""
        h, w = arr.shape
        safe_row_up_limit = max(0, center_row - 1)
        safe_row_bottom_limit = min(h, center_row + 2)
        safe_col_left_limit = max(0, center_col - 1)
        safe_col_right_limit = min(w, center_col + 2)
        return arr[safe_row_up_limit:safe_row_bottom_limit, safe_col_left_limit:safe_col_right_limit]

    def update_continuous_counter(self, pre_class_arr):
        """Consecutive-frame filter: track how long each patch stays turbid.

        Returns:
            True when more than two patches have stayed turbid for more than
            ``max_continuous_frames`` consecutive frames.
        """
        positive_index = np.array(pre_class_arr, dtype=np.int32) > 0
        negative_index = np.array(pre_class_arr, dtype=np.int32) == 0
        # Decrement patches predicted as background this frame...
        self.continuous_counter_mat[negative_index] -= 1
        # ...and increment patches predicted as turbid.
        self.continuous_counter_mat[positive_index] += 1
        # Clamp at zero so counters never go negative.
        self.continuous_counter_mat[self.continuous_counter_mat < 0] = 0
        # Turbid only when more than two patches exceed the consecutive limit.
        bad_flag = bool(np.sum(self.continuous_counter_mat > self.max_continuous_frames) > 2)
        print('连续帧信息:', self.continuous_counter_mat)
        print(f'连续帧判别:该时间段是否存在浑浊水体:{bad_flag}')
        return bad_flag

    def discriminate_ratio(self):
        """Ratio filter: a patch is turbid if it was positive in >= 60% of the
        buffered frames; alarm only when more than two such patches exist."""
        water_pre_list = self.continuous_detection_result.copy()
        water_pre_arr = np.array(water_pre_list, dtype=np.float32)
        # Per-patch count of positive frames across the buffered window.
        water_pre_arr_sum = np.sum(water_pre_arr, axis=0)
        bad_water = np.array(water_pre_arr_sum >= 0.6 * len(water_pre_list), dtype=np.int32)
        bad_flag = bool(np.sum(bad_water, dtype=np.int32) > 2)  # more than two patches required
        print(f'浑浊比例判别:该时间段是否存在浑浊水体:{bad_flag}')
        return bad_flag

    def filter_prediction(self, predicted_class, confidence, filter_down_limit=3):
        """Spatial filter over the patch grid.

        A turbid patch is kept only if its 3x3 neighborhood (itself included)
        contains at least ``filter_down_limit`` turbid patches; otherwise it is
        reset to background with confidence 1.0.
        """
        # reshape (not np.resize): the arrays are exactly rows*cols long, and
        # reshape raises on a size mismatch instead of silently wrapping.
        predicted_class_mat = np.reshape(predicted_class, (self.current_patch_rows, self.current_patch_cols))
        predicted_conf_mat = np.reshape(confidence, (self.current_patch_rows, self.current_patch_cols))
        new_predicted_class_mat = predicted_class_mat.copy()
        new_predicted_conf_mat = predicted_conf_mat.copy()
        for i in range(self.current_patch_rows):
            for j in range(self.current_patch_cols):
                if (1. - predicted_class_mat[i, j]) > 0.1:
                    continue  # skip background patches
                core_region = self.get_33_patch(predicted_class_mat, i, j)
                if np.sum(core_region) < filter_down_limit:
                    new_predicted_class_mat[i, j] = 0  # reset to background
                    new_predicted_conf_mat[i, j] = 1.0
        return new_predicted_conf_mat.flatten(), new_predicted_class_mat.flatten()
def argsparser(argv=None):
    """Parse command-line arguments for the bmodel alarm application.

    Args:
        argv: optional list of argument strings; defaults to ``sys.argv[1:]``.
              Accepting an explicit list makes the parser testable and
              embeddable (backward compatible: existing calls pass nothing).

    Returns:
        argparse.Namespace with ``input`` (frame directory), ``bmodel``
        (compiled model path) and ``dev_id`` (TPU device id).
    """
    parser = argparse.ArgumentParser(prog=__file__)
    parser.add_argument('--input', '-i', type=str, default=r'4_video_20251223163145',
                        help='path of input, must be image directory')
    parser.add_argument('--bmodel', '-b', type=str, default='./shufflenet_f32.bmodel',
                        help='path of bmodel')
    parser.add_argument('--dev_id', '-d', type=int, default=0, help='tpu id')
    return parser.parse_args(argv)
+
def main(args):
    """Run the turbidity-alarm pipeline over one video's frame directory.

    Input: ``args.input`` — a directory holding the time-ordered frame images
    of a single video.
    Output: prints, per frame, whether turbid water is confirmed
    (True = alarm, False = no obvious turbidity).
    """
    # Build the inference engine once and reuse it for every frame.
    predictor = Predictor()
    # Collect frame paths in temporal (lexicographic) order.
    frame_paths = [os.path.join(args.input, name) for name in sorted(os.listdir(args.input))]
    for frame_path in frame_paths:
        frame = cv2.imread(frame_path, cv2.IMREAD_COLOR)
        # Skip unreadable or empty images.
        if frame is None or frame.size == 0:
            continue
        print("正在处理:", frame_path)
        print('污水警报', predictor(frame))

if __name__ == '__main__':
    args = argsparser()
    main(args)

+ 4 - 4
labelme/crop_patch.py

@@ -15,9 +15,9 @@ patch_h = int(os.getenv('PATCH_HEIGHT', 256))
 
 def main():
     # TODO:需要修改为标注好的图片路径
-    input_path = r'/frame_data/test/20251225/video4_20251129120320_20251129123514'
+    input_path = r'D:\code\water_turbidity_det\frame_data\test\20251230\4video_20251229133103'
     # TODO: 需要修改为保存patch的根目录
-    output_path_root = r'D:\code\water_turbidity_det\label_data\test'
+    output_path_root = r'D:\code\water_turbidity_det\label_data_tem\test'
 
     # 读取标注文件
     label_path = os.path.join(input_path, 'label.txt')
@@ -55,7 +55,7 @@ def main():
         # 再将大于0的patch保存到对应的类别文件夹下
         for g in grids_info:
             if g[4] > 0:  # 标签大于零的放到相应的文件夹下
-                patch_name = f'{img_base_name}_{g[0]}_{g[1]}_{g[4]}.jpg'  # 图块保存名称:图片名_左上角x_左上角y_类别.jpg
+                patch_name = f'{os.path.basename(input_path)}_{img_base_name}_{g[0]}_{g[1]}_{g[4]}.jpg'  # 图块保存名称:video名_图片名_左上角x_左上角y_类别.jpg
                 patch = img[g[1]:min(g[1]+g[3], img_h), g[0]:min(g[0]+g[2], img_w), :]
                 # 保存图块
                 cv2.imwrite(os.path.join(output_path_root,str(g[4]), patch_name), patch)
@@ -65,7 +65,7 @@ def main():
         for i in range(img_h // patch_h + 1):
             for j in range(img_w // patch_w + 1):
                 patch = img[i*patch_h:min(i*patch_h+patch_h, img_h), j*patch_w:min(j*patch_w+patch_w, img_w), :]
-                patch_name = f'{img_base_name}_{j*patch_w}_{i*patch_h}_0.jpg'
+                patch_name = f'{os.path.basename(input_path)}_{img_base_name}_{j*patch_w}_{i*patch_h}_0.jpg'
                 # 长宽比过滤
                 if patch.shape[0] / (patch.shape[1]+1e-6) > 1.314 or patch.shape[0] / (patch.shape[1]+1e-6) < 0.75:
                     #print(f"长宽比过滤: {patch_name}")

+ 1 - 1
labelme/fixed_label.py

@@ -164,7 +164,7 @@ def main():
     global patch_h
     global scale
     # TODO: 需要更改为准备标注的图像路径,使用当前目录下的000000.jpg,结果保存在当前目录下label.txt
-    img_path = r'/frame_data/test/video4_20251129120320_20251129123514\000000.jpg'
+    img_path = r'D:\code\water_turbidity_det\frame_data\train\20251230\4video_20251229160514\000040.jpg'
     play_video(img_path)
     img = cv2.imread(img_path)
     draw_tool.set_path(img_path)

+ 1 - 1
labelme/random_del.py

@@ -4,7 +4,7 @@ import random
 
 def main():
     # TODO:需要修改图像路径
-    path = r'D:\code\water_turbidity_det\label_data\test\0'
+    path = r'D:\code\water_turbidity_det\label_data_tem\train\0'
     del_rate = 0.3
     img_path = [i for i in os.listdir(path) if i.split('.')[-1] in ['jpg', 'png'] ]
     random.shuffle(img_path)

+ 97 - 14
labelme/statistic.py

@@ -1,28 +1,111 @@
 # 统计标注好的数据,同时给出统计结果保存为txt
 import os
-def count_imgs(path:str, tag:str)->str:
+
+def count_imgs(path: str, tag: str) -> dict:
     target_path = os.path.join(path, tag)
     # 获取类别子目录
     sta_res = {}
+    total_count = 0
+    
     for c in os.listdir(target_path):
         cls_path = os.path.join(target_path, c)
-        # 获取图片
-        imgs = os.listdir(cls_path)
-        sta_res[c] = len(imgs)
-    return f'{tag} data statistics: {sta_res}'
+        if os.path.isdir(cls_path):  # 确保是目录
+            # 获取图片
+            imgs = [f for f in os.listdir(cls_path) if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.gif'))]
+            count = len(imgs)
+            sta_res[c] = count
+            total_count += count
+    
+    return {
+        'dataset_type': tag,
+        'class_counts': sta_res,
+        'total_count': total_count,
+        'num_classes': len(sta_res)
+    }
+
+def format_statistics(stats: dict) -> str:
+    """格式化统计数据为易读的字符串"""
+    dataset_type = stats['dataset_type']
+    class_counts = stats['class_counts']
+    total_count = stats['total_count']
+    num_classes = stats['num_classes']
+    
+    formatted = []
+    formatted.append(f"{'='*50}")
+    formatted.append(f"{dataset_type.upper()} 数据集统计")
+    formatted.append(f"{'='*50}")
+    formatted.append(f"总图片数量: {total_count}")
+    formatted.append(f"类别数量: {num_classes}")
+    formatted.append("-" * 30)
+    formatted.append("各类别分布:")
+    
+    # 按类别名称排序
+    for class_name in sorted(class_counts.keys()):
+        count = class_counts[class_name]
+        percentage = (count / total_count * 100) if total_count > 0 else 0
+        formatted.append(f"  {class_name:<20}: {count:>6} 张 ({percentage:>5.1f}%)")
+    
+    formatted.append(f"{'='*50}")
+    return "\n".join(formatted)
+
 def main():
-    train_data_path = r'D:\code\water_turbidity_det\label_data'
+    # TODO:修改数据集路径
+    train_data_path = r'D:\code\water_turbidity_det\label_data_tem'
     dirs = os.listdir(train_data_path)
-    info = []
+    
+    # 检查数据集目录是否存在
+    dataset_types = []
     if 'train' in dirs:
-        info.append(count_imgs(train_data_path, 'train'))
+        dataset_types.append('train')
     if 'test' in dirs:
-        info.append(count_imgs(train_data_path, 'test'))
+        dataset_types.append('test')
     if 'val' in dirs:
-        info.append(count_imgs(train_data_path, 'val'))
-    with open(os.path.join(train_data_path, 'statistic.txt'), 'w') as fw:
-        for i in info:
-            fw.write(i)
-            fw.write('\n')
+        dataset_types.append('val')
+    
+    if not dataset_types:
+        print(f"在 {train_data_path} 中未找到 train, test, 或 val 目录")
+        return
+    
+    all_stats = []
+    summary_stats = {
+        'total_overall': 0,
+        'datasets': {}
+    }
+    
+    for dataset_type in dataset_types:
+        stats = count_imgs(train_data_path, dataset_type)
+        all_stats.append(stats)
+        summary_stats['datasets'][dataset_type] = stats
+        summary_stats['total_overall'] += stats['total_count']
+    
+    # 写入详细的统计信息到文件
+    output_path = os.path.join(train_data_path, 'statistic.txt')
+    with open(output_path, 'w', encoding='utf-8') as fw:
+        fw.write("数据集详细统计报告\n")
+        fw.write("="*60 + "\n\n")
+        
+        for stats in all_stats:
+            fw.write(format_statistics(stats))
+            fw.write("\n\n")
+        
+        # 添加汇总统计
+        fw.write("汇总统计\n")
+        fw.write("="*30 + "\n")
+        fw.write(f"总体图片数量: {summary_stats['total_overall']}\n")
+        fw.write(f"数据集类型: {', '.join(summary_stats['datasets'].keys())}\n")
+        
+        if len(summary_stats['datasets']) > 1:
+            fw.write("\n数据集分布:\n")
+            for dataset_type, stats in summary_stats['datasets'].items():
+                percentage = (stats['total_count'] / summary_stats['total_overall'] * 100) if summary_stats['total_overall'] > 0 else 0
+                fw.write(f"  {dataset_type:<10}: {stats['total_count']:>6} 张 ({percentage:>5.1f}%)\n")
+    
+    # 同时在控制台输出简要统计
+    print("数据集统计完成!详细信息已保存到:", output_path)
+    print("\n简要统计:")
+    for stats in all_stats:
+        print(f"{stats['dataset_type']}集: {stats['total_count']}张图片, {stats['num_classes']}个类别")
+    print(f"总计: {summary_stats['total_overall']}张图片")
+
 if __name__ == '__main__':
     main()

+ 2 - 2
labelme/video_depart.py

@@ -5,10 +5,10 @@ import shutil
 def main():
     # 视频路径
     # TODO: 修改视频路径为自己的视频路径,每次指定一个视频
-    path = r'D:\code\water_turbidity_det\video\video20251225\4_video_20251223163145.mp4'
+    path = r'D:\code\water_turbidity_det\video\20251230day\4video_20251229160514.mp4'
     output_rootpath = r'D:\code\water_turbidity_det\frame_data'  # 输出路径的根目录
     # 抽帧间隔
-    interval = 15
+    interval = 20
     # 我们将图像输出到根目录下的子目录中,子目录和视频名称相同
 
     img_base = os.path.basename(path).split('.')[0]

+ 18 - 5
pth2onnx.py

@@ -3,17 +3,30 @@ import onnx
 import torch.onnx
 import onnxruntime as ort
 import numpy as np
-from torchvision.models import resnet50, shufflenet_v2_x1_0, shufflenet_v2_x2_0
+from torchvision.models import resnet50, shufflenet_v2_x1_0, shufflenet_v2_x2_0, squeezenet1_0
 from torch import nn
 # from simple_model import  SimpleModel
 if __name__ == '__main__':
 
     # 载入模型框架
     # model = SimpleModel()
-    model = resnet50(pretrained=False)
-    # model = shufflenet_v2_x1_0()
-    model_name = "resnet50"
-    model.fc = nn.Linear(int(model.fc.in_features), 2, bias=True)
+    # model = resnet50(pretrained=False)
+    model = shufflenet_v2_x1_0()
+    # model = shufflenet_v2_x2_0()
+    # model = squeezenet1_0()
+    model_name = "shufflenet"
+    if model_name == "squeezenet":
+        # 获取SqueezeNet的最后一个卷积层的输入通道数
+        final_conv_in_channels = model.classifier[1].in_channels
+        # 替换classifier为新的Sequential,将输出改为2类
+        model.classifier = nn.Sequential(
+            nn.Dropout(p=0.5),
+            nn.Conv2d(final_conv_in_channels, 2, kernel_size=(1, 1)),
+            nn.ReLU(inplace=True),
+            nn.AdaptiveAvgPool2d((1, 1))
+        )
+    if model_name == "shufflenet":
+        model.fc = nn.Linear(int(model.fc.in_features), 2, bias=True)
     model.load_state_dict(torch.load(rf'./{model_name}.pth')) # xxx.pth表示.pth文件, 这一步载入模型权重
     print("加载模型成功")
     model.eval() # 设置模型为推理模式

BIN
runs/turbidity_squeezenet_20251226-105741/events.out.tfevents.1766717861.O5XVSKDYW6B57G7.32516.0


BIN
runs/turbidity_squeezenet_20251226-105857/events.out.tfevents.1766717937.O5XVSKDYW6B57G7.17992.0


BIN
runs/turbidity_squeezenet_20251226-110326/events.out.tfevents.1766718206.O5XVSKDYW6B57G7.28116.0


+ 17 - 4
train.py

@@ -124,18 +124,31 @@ class Trainer:
             for param in self.model.parameters():
                 param.requires_grad = False
         # 替换最后的分类层以适应新的分类任务
+        print(self.model)
+        print(f"正在将模型{self.name}的分类层替换为{self.num_classes}个类别")
         if hasattr(self.model, 'fc'):
             # ResNet系列模型
             self.model.fc = nn.Linear(int(self.model.fc.in_features), self.num_classes, bias=True)
         elif hasattr(self.model, 'classifier'):
-            # Swin Transformer等模型
-            self.model.classifier = nn.Linear(int(self.model.classifier.in_features), self.num_classes, bias=True)
+            # SqueezeNet、ShuffleNet系列模型
+            if self.name == 'squeezenet':
+                # 获取SqueezeNet的最后一个卷积层的输入通道数
+                final_conv_in_channels = self.model.classifier[1].in_channels
+                # 替换classifier为新的Sequential,将输出改为2类
+                self.model.classifier = nn.Sequential(
+                    nn.Dropout(p=0.5),
+                    nn.Conv2d(final_conv_in_channels, self.num_classes, kernel_size=(1, 1)),
+                    nn.ReLU(inplace=True),
+                    nn.AdaptiveAvgPool2d((1, 1))
+                )
+            else:
+                # Swin Transformer等模型
+                self.model.classifier = nn.Linear(int(self.model.classifier.in_features), self.num_classes, bias=True)
         elif hasattr(self.model, 'head'):
             # Swin Transformer使用head层
             self.model.head = nn.Linear(int(self.model.head.in_features), self.num_classes, bias=True)
         else:
             raise ValueError(f"Model {self.name} does not have recognizable classifier layer")
-        print(self.model)
         print(f'模型{self.name}结构已经加载,移动到设备{self.device}')
         # 将模型移动到GPU/cpu
         self.model = self.model.to(self.device)
@@ -291,7 +304,7 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser('预训练模型调参')
     parser.add_argument('--train_dir',default='./label_data/train',help='help')
     parser.add_argument('--val_dir', default='./label_data/test',help='help')
-    parser.add_argument('--model', default='shufflenet',help='help')
+    parser.add_argument('--model', default='squeezenet',help='help')
     args = parser.parse_args()
     num_epochs = 100
     trainer = Trainer(batch_size=int(os.getenv('BATCH_SIZE', 32)),

+ 31 - 19
video_test.py

@@ -37,31 +37,41 @@ class Predictor:
         if self.model is not None:
             return
         print(f"正在加载模型: {self.model_name}")
-        name = self.model_name
         # 加载模型
-        if name == 'resnet50':
-
+        if self.model_name== 'resnet50':
             self.model = resnet50()
-        elif name == 'squeezenet':
+        elif self.model_name == 'squeezenet':
 
             self.model = squeezenet1_0()
-        elif name == 'shufflenet':
+        elif self.model_name == 'shufflenet':
             self.model = shufflenet_v2_x1_0()
         else:
-            raise ValueError(f"Invalid model name: {name}")
+            raise ValueError(f"Invalid model name: {self.model_name}")
         # 替换最后的分类层以适应新的分类任务
         if hasattr(self.model, 'fc'):
             # ResNet系列模型
             self.model.fc = nn.Linear(int(self.model.fc.in_features), self.num_classes, bias=self.use_bias)
         elif hasattr(self.model, 'classifier'):
-            # Swin Transformer等模型
-            self.model.classifier = nn.Linear(int(self.model.classifier.in_features), self.num_classes, bias=self.use_bias)
+            # SqueezeNet、ShuffleNet系列模型
+            if self.model_name == 'squeezenet':
+                # 获取SqueezeNet的最后一个卷积层的输入通道数
+                final_conv_in_channels = self.model.classifier[1].in_channels
+                # 替换classifier为新的Sequential,将输出改为2类
+                self.model.classifier = nn.Sequential(
+                    nn.Dropout(p=0.5),
+                    nn.Conv2d(final_conv_in_channels, self.num_classes, kernel_size=(1, 1)),
+                    nn.ReLU(inplace=True),
+                    nn.AdaptiveAvgPool2d((1, 1))
+                )
+            else:
+                # Swin Transformer等模型
+                self.model.classifier = nn.Linear(int(self.model.classifier.in_features), self.num_classes, bias=True)
         elif hasattr(self.model, 'head'):
             # Swin Transformer使用head层
             self.model.head = nn.Linear(int(self.model.head.in_features), self.num_classes, bias=self.use_bias)
 
         else:
-            raise ValueError(f"Model {name} does not have recognizable classifier layer")
+            raise ValueError(f"Model {self.model_name} does not have recognizable classifier layer")
         print(self.model)
         # 加载训练好的权重
         self.model.load_state_dict(torch.load(self.weights_path, map_location=torch.device('cpu')))
@@ -195,10 +205,10 @@ def fileter_prediction(predicted_class, confidence, pre_rows, pre_cols, filter_d
 def discriminate_ratio(water_pre_list:list):
     # 方式一:60%以上的帧存在浑浊水体
     water_pre_arr = np.array(water_pre_list, dtype=np.float32)
-    water_pre_arr_mean = np.mean(water_pre_arr, axis=0)
-    bad_water = np.array(water_pre_arr_mean >= 0.6, dtype=np.int32)
-    bad_flag = np.sum(bad_water, dtype=np.int32)
-    print(f'浑浊比例方式:该时间段是否存在浑浊水体:{bool(bad_flag)}')
+    water_pre_arr_sum = np.sum(water_pre_arr, axis=0)
+    bad_water = np.array(water_pre_arr_sum >= 0.6*len(water_pre_list), dtype=np.int32)
+    bad_flag = bool(np.sum(bad_water, dtype=np.int32) > 2)  # 大于两个patch符合要求才可以
+    print(f'浑浊比例判别:该时间段是否存在浑浊水体:{bad_flag}')
     return bad_flag
 
 
@@ -223,7 +233,7 @@ def main():
     predictor = Predictor(model_name='shufflenet',
                           weights_path=r'./shufflenet.pth',
                           num_classes=2)
-    input_path = r'D:\code\water_turbidity_det\frame_data\test\20251225\video4_20251129120320_20251129123514'
+    input_path = r'D:\code\water_turbidity_det\frame_data\train\20251230\4video_20251229160514'
     # 预处理图像
     all_imgs = os.listdir(input_path)
     all_imgs = [os.path.join(input_path, p) for p in all_imgs if p.split('.')[-1] in ['jpg', 'png']]
@@ -242,17 +252,19 @@ def main():
     for img_path in all_imgs:
         image = Image.open(img_path).convert('RGB')
         # 预处理
-        patches_index, image_tensor = preprocess_image(image)
+        patches_index, image_tensor = preprocess_image(image) # patches_index:list[tuple, ...]
         # 推理
-        confidence, predicted_class  = predictor.predict(image_tensor)
+        confidence, predicted_class  = predictor.predict(image_tensor)  # confidence: np.ndarray, shape=(x,), predicted_class: np.ndarray, shape=(x,), raw_outputs: np.ndarray, shape=(x,)
         # 第一层虚警抑制,置信度过滤,低于阈值将会被忽略
         for i in range(len(confidence)):
-            if confidence[i] < confidence_threshold:
+            if confidence[i] < confidence_threshold and predicted_class[i] == 1:
                 confidence[i] = 1.0
                 predicted_class[i] = 0
         # 第二层虚警抑制,空间滤波
         # 在此处添加过滤逻辑
+        # print('原始预测结果:', predicted_class)
         new_confidence, new_predicted_class = fileter_prediction(predicted_class, confidence, pre_rows, pre_cols, filter_down_limit=3)
+        # print('过滤后预测结果:', new_predicted_class)
         # 可视化预测结果
         image = cv2.imread(img_path)
         image = draw_grid(image, patch_w, patch_h)
@@ -267,9 +279,9 @@ def main():
         cv2.imshow('image', image)
         cv2.imshow('image_filter', new_img)
 
-        cv2.waitKey(20)
+        cv2.waitKey(25)
         # 方式1判别
-        if len(water_pre_list) > 20:
+        if len(water_pre_list) > 25:
             flag = discriminate_ratio(water_pre_list) and flag
             water_pre_list = []
             print('综合判别结果:', flag)