"""Turbid-water alarm demo running a patch classifier on a SOPHON TPU.

Each input frame is cut into network-sized patches, every patch is classified
(class 1 is treated as "turbid water"), and three false-alarm suppression
stages (confidence gate, 3x3 spatial filter, temporal filters) decide whether
an alarm is raised for the frame.
"""
import argparse
import sophon.sail as sail
import cv2
import os
import logging
import json
import numpy as np


def _ceil_div(numerator: int, denominator: int) -> int:
    """Return ceil(numerator / denominator) using integer arithmetic only."""
    return -(-numerator // denominator)


class Predictor:
    """Patch-based classifier wrapper with spatial and temporal alarm filtering."""

    # Number of most recent frames inspected by the ratio-based temporal filter.
    HISTORY_WINDOW = 20

    def __init__(self, bmodel=None, dev_id=None):
        """Load the bmodel and initialise all filtering state.

        Args:
            bmodel: Path of the bmodel file. When None, the value is read from
                the module-level ``args`` namespace (legacy behaviour, only
                available when the file is run as a script).
            dev_id: TPU device id. When None, read from the module-level
                ``args`` namespace as well.
        """
        if bmodel is None:
            bmodel = args.bmodel
        if dev_id is None:
            dev_id = args.dev_id

        # Load the inference engine.
        self.net = sail.Engine(bmodel, dev_id, sail.IOMode.SYSIO)
        self.graph_name = self.net.get_graph_names()[0]
        self.input_names = self.net.get_input_names(self.graph_name)
        self.input_shapes = [
            self.net.get_input_shape(self.graph_name, name)
            for name in self.input_names
        ]
        self.output_names = self.net.get_output_names(self.graph_name)
        self.output_shapes = [
            self.net.get_output_shape(self.graph_name, name)
            for name in self.output_names
        ]  # e.g. [[1, 2]]

        self.input_name = self.input_names[0]
        self.input_shape = self.input_shapes[0]  # e.g. [1, 3, 256, 256]
        self.batch_size = self.input_shape[0]
        self.net_h = self.input_shape[2]  # height of one input patch
        self.net_w = self.input_shape[3]  # width of one input patch

        # Normalisation parameters (ImageNet pre-training statistics).
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

        # Number of patch rows / columns for the current frame size.
        self.current_patch_rows = 0
        self.current_patch_cols = 0
        # Height and width of the full input frame seen last.
        self.input_image_h = 0
        self.input_image_w = 0

        self.print_network_info()

        # Confidence threshold below which a positive prediction is ignored.
        self.confidence_threshold = 0.6
        # Per-frame filtered predictions of the most recent frames.
        self.continuous_detection_result = []
        # Per-patch counter of consecutive positive frames; it is re-created
        # with one entry per patch by preprocess() once the frame size is known.
        self.continuous_counter_mat = np.array(0)
        # A patch counts as persistently turbid once its counter exceeds this.
        self.max_continuous_frames = 15

    def __call__(self, img) -> bool:
        """Run the full pipeline on one BGR frame and return the alarm flag."""
        # Pre-processing: cut the frame into a sequence of patches.
        _, patches = self.preprocess(img)

        # Inference, one network batch at a time.
        confidence = []
        predicted_class = []
        for start in range(0, len(patches), self.batch_size):
            batch_patches = patches[start:start + self.batch_size]
            # The last batch may be short: pad it by repeating its last patch.
            missing = self.batch_size - len(batch_patches)
            if missing > 0:
                batch_patches += [batch_patches[-1]] * missing
            patches_tensor = np.stack(batch_patches)
            batch_confi, batch_cls = self.predict(patches_tensor)
            confidence.extend(batch_confi)
            predicted_class.extend(batch_cls)

        # Drop the results that belong to padding patches.
        confidence = np.array(confidence[:len(patches)])
        predicted_class = np.array(predicted_class[:len(patches)])

        # Post-processing: alarm logic.
        return self.postprocess(confidence=confidence,
                                predicted_class=predicted_class)

    def print_network_info(self):
        """Print the network configuration read from the loaded bmodel."""
        info = {
            'Graph Name': self.graph_name,
            'Input Name': self.input_name,
            'Output Names': self.output_names,
            'Output Shapes': self.output_shapes,
            'Input Shape': self.input_shape,
            'Batch Size': self.batch_size,
            'Height': self.net_h,
            'Width': self.net_w,
            'Mean': self.mean,
            'Std': self.std
        }
        print("=" * 50)
        print("Network Configuration Info")
        print("=" * 50)
        for key, value in info.items():
            print(f"{key:<18}: {value}")
        print("=" * 50)

    def predict(self, input_img):
        """Run the engine on one batch and return per-patch results.

        Args:
            input_img: Batch array fed to the network's first input.

        Returns:
            Tuple ``(confidence, predictions)`` of Python lists: the maximum
            softmax probability and the arg-max class index of every row of
            the first network output.
        """
        input_data = {self.input_name: input_img}
        outputs = self.net.process(self.graph_name, input_data)
        logits = np.asarray(list(outputs.values())[0])

        # Numerically stable softmax: subtracting the row maximum does not
        # change the result but prevents exp() from overflowing to inf/NaN.
        shifted = logits - np.max(logits, axis=1, keepdims=True)
        exp_scores = np.exp(shifted)
        probabilities = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)

        confidence = np.max(probabilities, axis=1)
        predictions = np.argmax(probabilities, axis=1)  # most probable class
        return confidence.tolist(), predictions.tolist()

    def preprocess(self, img: np.ndarray):
        """Normalise a frame and cut it into network-sized patches.

        Patches are produced row by row (top to bottom), left to right within
        a row. Border patches smaller than the network input are resized to
        the network input size.

        Args:
            img: Full frame in BGR channel order, shape (H, W, 3).

        Returns:
            Tuple ``(patch_index, patches)`` where ``patch_index`` holds the
            ``(start_col, start_row)`` top-left corner of every patch and
            ``patches`` holds the channel-first patch arrays.
        """
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.astype('float32')
        # Normalisation must be done here because the compiled model does not
        # include it. mean/std are cast to float32 so the arithmetic does not
        # silently promote the whole frame to float64.
        mean = np.asarray(self.mean, dtype=np.float32)
        std = np.asarray(self.std, dtype=np.float32)
        img = (img / 255. - mean) / std
        img_h, img_w, _ = img.shape
        img = np.transpose(img, (2, 0, 1))  # channel first

        patches = []
        patch_index = []
        if img_h != self.input_image_h or img_w != self.input_image_w:
            # Recompute the patch grid only when the frame size changes.
            self.input_image_h = img_h
            self.input_image_w = img_w
            # Ceiling division: a frame whose size is an exact multiple of the
            # patch size must not get an extra, empty row/column (an empty
            # patch cannot be resized).
            self.current_patch_rows = _ceil_div(self.input_image_h, self.net_h)
            self.current_patch_cols = _ceil_div(self.input_image_w, self.net_w)
            # Temporal state is per patch, so it is invalid for a new grid.
            print('初始化连续计数器矩阵')
            self.continuous_counter_mat = np.zeros(
                self.current_patch_rows * self.current_patch_cols,
                dtype=np.int32)
            self.continuous_detection_result.clear()

        # Crop the patches one by one.
        for i in range(self.current_patch_rows):
            for j in range(self.current_patch_cols):
                start_row = i * self.net_h
                start_col = j * self.net_w
                end_row = min(start_row + self.net_h, self.input_image_h)
                end_col = min(start_col + self.net_w, self.input_image_w)
                patch = img[:, start_row:end_row, start_col:end_col]
                _, patch_h, patch_w = patch.shape
                if patch_h != self.net_h or patch_w != self.net_w:
                    # cv2.resize expects channel-last data.
                    patch = np.transpose(patch, (1, 2, 0))
                    patch = cv2.resize(patch, (self.net_w, self.net_h))
                    patch = np.transpose(patch, (2, 0, 1))
                patches.append(patch)
                patch_index.append((start_col, start_row))  # top-left corner
        return patch_index, patches

    def postprocess(self, confidence, predicted_class):
        """Decide whether to raise an alarm for the current frame.

        Args:
            confidence: Per-patch maximum probability (modified in place when
                it is a NumPy array).
            predicted_class: Per-patch predicted class (modified in place when
                it is a NumPy array).

        Returns:
            True only when the history holds at least ``HISTORY_WINDOW``
            frames and both temporal filters report turbid water; False
            otherwise.
        """
        confidence = np.asarray(confidence)
        predicted_class = np.asarray(predicted_class)

        # Stage 1: confidence gate. Positive predictions below the threshold
        # are turned into background with confidence 1.0.
        weak_positive = ((confidence < self.confidence_threshold)
                         & (predicted_class == 1))
        confidence[weak_positive] = 1.0
        predicted_class[weak_positive] = 0

        # Stage 2: spatial filtering over the patch grid.
        confidence, predicted_class = self.filter_prediction(
            predicted_class=predicted_class,
            confidence=confidence,
            filter_down_limit=3)

        # Stage 3: temporal filtering.
        self.continuous_detection_result.append(predicted_class)
        flag = self.update_continuous_counter(predicted_class)
        if len(self.continuous_detection_result) >= self.HISTORY_WINDOW:
            flag = self.discriminate_ratio() and flag
            print(f'是否存在浑浊水体,综合判别结果:{flag}')
            # Slide the window forward by one frame.
            self.continuous_detection_result.pop(0)
            return flag
        return False

    @staticmethod
    def get_33_patch(arr: np.ndarray, center_row: int, center_col: int):
        """Return the 3x3 neighbourhood of ``arr`` centred on the given cell.

        The window is clipped at the array borders, so the result is smaller
        than 3x3 for cells on an edge or corner.
        """
        h, w = arr.shape
        row_top = max(0, center_row - 1)
        row_bottom = min(h, center_row + 2)
        col_left = max(0, center_col - 1)
        col_right = min(w, center_col + 2)
        return arr[row_top:row_bottom, col_left:col_right]

    def update_continuous_counter(self, pre_class_arr):
        """Update the per-patch persistence counters and evaluate them.

        Counters of patches predicted positive are incremented, counters of
        patches predicted as class 0 are decremented, and no counter drops
        below zero.

        Returns:
            True when more than two patches have a counter above
            ``max_continuous_frames``.
        """
        classes = np.asarray(pre_class_arr, dtype=np.int32)
        positive_index = classes > 0
        negative_index = classes == 0
        # Background patches decay by one instead of being reset outright.
        self.continuous_counter_mat[negative_index] -= 1
        self.continuous_counter_mat[positive_index] += 1
        # Counters never become negative.
        self.continuous_counter_mat[self.continuous_counter_mat < 0] = 0
        # More than two patches must exceed the frame limit.
        persistent = self.continuous_counter_mat > self.max_continuous_frames
        bad_flag = bool(np.sum(persistent) > 2)
        print(f'连续帧判别:该时间段是否存在浑浊水体:{bad_flag}')
        return bad_flag

    def discriminate_ratio(self):
        """Evaluate the ratio-based temporal filter over the stored history.

        A patch qualifies when the sum of its class values over the history
        is at least 60% of the number of stored frames.

        Returns:
            True when more than two patches qualify.
        """
        water_pre_list = self.continuous_detection_result.copy()
        water_pre_arr = np.array(water_pre_list, dtype=np.float32)
        water_pre_arr_sum = np.sum(water_pre_arr, axis=0)
        bad_water = np.array(water_pre_arr_sum >= 0.6 * len(water_pre_list),
                             dtype=np.int32)
        # More than two patches must meet the ratio requirement.
        bad_flag = bool(np.sum(bad_water, dtype=np.int32) > 2)
        print(f'浑浊比例判别:该时间段是否存在浑浊水体:{bad_flag}')
        return bad_flag

    def filter_prediction(self, predicted_class, confidence, filter_down_limit=3):
        """Suppress isolated positive patches with a 3x3 spatial filter.

        A positive patch is kept only when the sum of class values inside its
        (border-clipped) 3x3 neighbourhood, the patch itself included, is at
        least ``filter_down_limit``; otherwise it is reset to background with
        confidence 1.0. The neighbourhood is always read from the unfiltered
        predictions.

        Returns:
            Tuple ``(confidence, predicted_class)`` of flattened arrays.
        """
        grid_shape = (self.current_patch_rows, self.current_patch_cols)
        predicted_class_mat = np.resize(predicted_class, grid_shape)
        predicted_conf_mat = np.resize(confidence, grid_shape)
        new_predicted_class_mat = predicted_class_mat.copy()
        new_predicted_conf_mat = predicted_conf_mat.copy()
        for i in range(self.current_patch_rows):
            for j in range(self.current_patch_cols):
                if (1. - predicted_class_mat[i, j]) > 0.1:
                    continue  # skip background patches
                core_region = self.get_33_patch(predicted_class_mat, i, j)
                if np.sum(core_region) < filter_down_limit:
                    new_predicted_class_mat[i, j] = 0  # reset to background
                    new_predicted_conf_mat[i, j] = 1.0
        return new_predicted_conf_mat.flatten(), new_predicted_class_mat.flatten()


def argsparser():
    """Parse the command-line arguments of the demo."""
    parser = argparse.ArgumentParser(prog=__file__)
    parser.add_argument('--input', '-i', type=str, default=r'4_video_20251223163145', help='path of input, must be image directory')
    parser.add_argument('--bmodel', '-b', type=str, default='./shufflenet_f32.bmodel', help='path of bmodel')
    parser.add_argument('--dev_id', '-d', type=int, default=0, help='tpu id')
    args = parser.parse_args()
    return args


def main(args):
    """Run the alarm pipeline over a directory of video frames.

    Args:
        args: Parsed arguments providing ``input`` (directory holding the
            frame images of one video), ``bmodel`` and ``dev_id``.

    For every readable frame the alarm decision is printed: True means turbid
    water is confirmed, False means no clearly turbid water.
    """
    # Load the inference engine.
    predictor = Predictor(args.bmodel, args.dev_id)
    # Collect the frames.
    # NOTE(review): the order is lexicographic; the temporal filters assume it
    # matches the time order of the frames, which requires zero-padded names.
    all_imgs = [os.path.join(args.input, i) for i in sorted(os.listdir(args.input))]
    for img_path in all_imgs:
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        # Skip unreadable or empty images.
        if img is None or img.size == 0:
            continue
        print("正在处理:", img_path)
        print('污水警报', predictor(img))


if __name__ == '__main__':
    args = argsparser()
    main(args)