|
|
@@ -0,0 +1,246 @@
|
|
|
+import argparse
|
|
|
+import sophon.sail as sail
|
|
|
+import cv2
|
|
|
+import os
|
|
|
+import logging
|
|
|
+import json
|
|
|
+import numpy as np
|
|
|
+
|
|
|
class Predictor:
    """Patch-wise turbid-water alarm built on a Sophon SAIL inference engine.

    A frame is normalised, cut into network-sized patches, classified patch by
    patch, and the per-patch classes then pass through three false-alarm
    filters (confidence threshold, 3x3 spatial neighbourhood, temporal
    history) before an alarm flag is returned.  Class id 1 is treated as
    "turbid water" and class id 0 as background throughout.
    """

    def __init__(self, bmodel=None, dev_id=None):
        """Load the bmodel and initialise the alarm state.

        Args:
            bmodel: Path of the bmodel file.  When omitted, the module-level
                ``args.bmodel`` is used, which is what the original
                zero-argument constructor did.
            dev_id: TPU device id.  When omitted, ``args.dev_id`` is used.
        """
        # Fall back to the script-level globals so `Predictor()` keeps working.
        if bmodel is None:
            bmodel = args.bmodel
        if dev_id is None:
            dev_id = args.dev_id

        # Inference engine and the metadata of its first graph.
        self.net = sail.Engine(bmodel, dev_id, sail.IOMode.SYSIO)
        self.graph_name = self.net.get_graph_names()[0]
        self.input_names = self.net.get_input_names(self.graph_name)
        self.input_shapes = [self.net.get_input_shape(self.graph_name, name) for name in self.input_names]
        self.output_names = self.net.get_output_names(self.graph_name)
        self.output_shapes = [self.net.get_output_shape(self.graph_name, name) for name in self.output_names]  # e.g. [[1, 2]]
        self.input_name = self.input_names[0]
        self.input_shape = self.input_shapes[0]  # e.g. [1, 3, 256, 256]
        self.batch_size = self.input_shape[0]
        self.net_h = self.input_shape[2]  # patch height expected by the network
        self.net_w = self.input_shape[3]  # patch width expected by the network

        # Normalisation constants (ImageNet pre-training statistics).
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

        # Patch grid of the current frame size; filled in by preprocess().
        self.current_patch_rows = 0
        self.current_patch_cols = 0
        # Height and width of the most recently seen full frame.
        self.input_image_h = 0
        self.input_image_w = 0

        self.print_network_info()

        # Turbid predictions below this confidence are treated as background.
        self.confidence_threshold = 0.6
        # Post-filter class vectors of the most recent frames (sliding window).
        self.continuous_detection_result = []
        # Number of frames the sliding window must hold before an alarm can fire.
        self.history_window = 20
        # Per-patch counter of consecutive turbid frames; sized in preprocess().
        self.continuous_counter_mat = np.array(0)
        # A patch counts as persistently turbid once its counter exceeds this.
        self.max_continuous_frames = 15

    def __call__(self, img) -> bool:
        """Run the full pipeline on one BGR frame and return the alarm flag."""
        # Pre-processing: normalised patches (the coordinates are not needed here).
        _, patches = self.preprocess(img)

        confidence = []
        predicted_class = []
        # Feed the patches to the engine in chunks of the model's batch size.
        for start in range(0, len(patches), self.batch_size):
            batch = patches[start:start + self.batch_size]
            missing = self.batch_size - len(batch)
            if missing > 0:
                # The last chunk may be short: pad it by repeating its last patch.
                batch = batch + [batch[-1]] * missing
            batch_conf, batch_cls = self.predict(np.stack(batch))
            confidence.extend(batch_conf)
            predicted_class.extend(batch_cls)

        # Drop the results that belong to padding patches.
        patch_count = len(patches)
        confidence = np.array(confidence[:patch_count])
        predicted_class = np.array(predicted_class[:patch_count])

        # Post-processing: alarm logic.
        return self.postprocess(confidence=confidence, predicted_class=predicted_class)

    def print_network_info(self):
        """Print the engine/graph configuration as an aligned table."""
        info = {
            'Graph Name': self.graph_name,
            'Input Name': self.input_name,
            'Output Names': self.output_names,
            'Output Shapes': self.output_shapes,
            'Input Shape': self.input_shape,
            'Batch Size': self.batch_size,
            'Height': self.net_h,
            'Width': self.net_w,
            'Mean': self.mean,
            'Std': self.std
        }

        print("=" * 50)
        print("Network Configuration Info")
        print("=" * 50)
        for key, value in info.items():
            print(f"{key:<18}: {value}")
        print("=" * 50)

    @staticmethod
    def _softmax(logits):
        """Return the row-wise softmax of a 2-D ``[batch, classes]`` array.

        The row maximum is subtracted first so that large logits cannot
        overflow ``exp`` and turn the probabilities into NaN.
        """
        logits = np.asarray(logits)
        shifted = logits - np.max(logits, axis=1, keepdims=True)
        exp = np.exp(shifted)
        return exp / np.sum(exp, axis=1, keepdims=True)

    @staticmethod
    def _count_patches(length: int, patch: int) -> int:
        """Return how many patches of size ``patch`` cover ``length`` pixels (ceil division)."""
        return (length + patch - 1) // patch

    def predict(self, input_img):
        """Classify one batch of patches.

        Args:
            input_img: Batch tensor passed to the engine under ``self.input_name``.

        Returns:
            ``(confidence, predictions)``: two Python lists with one entry per
            batch row, holding the highest softmax probability and the index
            of the class that reached it.
        """
        outputs = self.net.process(self.graph_name, {self.input_name: input_img})
        # Only the first output of the graph is used; softmax runs over axis 1.
        probabilities = self._softmax(list(outputs.values())[0])
        confidence = np.max(probabilities, axis=1)
        predictions = np.argmax(probabilities, axis=1)
        return confidence.tolist(), predictions.tolist()

    def preprocess(self, img: np.ndarray):
        """Normalise a BGR frame and cut it into network-sized patches.

        Patches are taken row by row, left to right.  Patches on the right and
        bottom edges that are smaller than the network input are resized
        (stretched) to ``(net_h, net_w)``.

        Args:
            img: Full frame in BGR channel order, shape ``(H, W, 3)``.

        Returns:
            ``(patch_index, patches)``: ``patch_index`` lists the
            ``(start_col, start_row)`` top-left corner of every patch and
            ``patches`` the matching channel-first arrays.
        """
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.astype('float32')
        # The compiled model does not normalise its input, so it is done here.
        # NOTE(review): subtracting/dividing by Python-float lists promotes the
        # array to float64 - confirm the engine accepts or casts that dtype.
        img = (img / 255. - self.mean) / self.std
        img_h, img_w, _ = img.shape
        img = np.transpose(img, (2, 0, 1))  # channel first

        # Recompute the patch grid only when the frame size changes.
        if img_h != self.input_image_h or img_w != self.input_image_w:
            self.input_image_h = img_h
            self.input_image_w = img_w
            # Ceil division: a size that is an exact multiple of the patch size
            # must not produce an extra, empty row/column (cv2.resize rejects
            # zero-sized input).
            self.current_patch_rows = self._count_patches(img_h, self.net_h)
            self.current_patch_cols = self._count_patches(img_w, self.net_w)
            print('初始化连续计数器矩阵')
            self.continuous_counter_mat = np.zeros(self.current_patch_rows*self.current_patch_cols, dtype=np.int32)
            # History recorded on the old grid has a different length per frame
            # and cannot be stacked with new results, so it is reset as well.
            self.continuous_detection_result = []

        patches = []
        patch_index = []
        for i in range(self.current_patch_rows):
            for j in range(self.current_patch_cols):
                start_row = i * self.net_h
                start_col = j * self.net_w
                end_row = min(start_row + self.net_h, self.input_image_h)
                end_col = min(start_col + self.net_w, self.input_image_w)
                patch = img[:, start_row:end_row, start_col:end_col]
                _, patch_h, patch_w = patch.shape
                if patch_h != self.net_h or patch_w != self.net_w:
                    # cv2.resize works on HWC data, so transpose there and back.
                    patch = np.transpose(patch, (1, 2, 0))
                    patch = cv2.resize(patch, (self.net_w, self.net_h))
                    patch = np.transpose(patch, (2, 0, 1))
                patches.append(patch)
                patch_index.append((start_col, start_row))  # top-left corner
        return patch_index, patches

    def postprocess(self, confidence, predicted_class):
        """Decide whether to raise the alarm for the current frame.

        Three filters are applied in order: confidence thresholding, the 3x3
        spatial filter of ``filter_prediction`` and the temporal checks of
        ``update_continuous_counter`` / ``discriminate_ratio``.

        Args:
            confidence: Per-patch confidence, one value per patch.
            predicted_class: Per-patch class id, one value per patch.

        Returns:
            ``False`` until ``self.history_window`` frames have been collected;
            afterwards ``True`` only when both temporal checks agree.
        """
        # Work on copies so the caller's arrays are left untouched.
        confidence = np.array(confidence)
        predicted_class = np.array(predicted_class)

        # Layer 1: turbid predictions below the threshold become confident background.
        weak = (confidence < self.confidence_threshold) & (predicted_class == 1)
        confidence[weak] = 1.0
        predicted_class[weak] = 0

        # Layer 2: spatial filtering.
        confidence, predicted_class = self.filter_prediction(predicted_class=predicted_class, confidence=confidence, filter_down_limit=3)

        # Layer 3: temporal filtering.  The counter is updated on every frame,
        # even while the history window is still filling up.
        self.continuous_detection_result.append(predicted_class)
        flag = self.update_continuous_counter(predicted_class)
        if len(self.continuous_detection_result) < self.history_window:
            return False

        flag = self.discriminate_ratio() and flag
        print(f'是否存在浑浊水体,综合判别结果:{flag}')
        # Slide the window: forget the oldest frame.
        self.continuous_detection_result.pop(0)
        return flag

    @staticmethod
    def get_33_patch(arr: np.ndarray, center_row: int, center_col: int):
        """Return the 3x3 neighbourhood of ``arr`` centred on the given cell.

        The window is clipped at the array borders, so corner and edge cells
        yield a smaller slice.
        """
        h, w = arr.shape
        top = max(0, center_row - 1)
        bottom = min(h, center_row + 2)
        left = max(0, center_col - 1)
        right = min(w, center_col + 2)
        return arr[top:bottom, left:right]

    def update_continuous_counter(self, pre_class_arr):
        """Update the per-patch persistence counters with one frame.

        Patches predicted positive gain one count, patches predicted as class
        0 lose one count, and counters never drop below zero.

        Args:
            pre_class_arr: Per-patch class ids of the current frame.

        Returns:
            ``True`` when more than two patches have a counter above
            ``self.max_continuous_frames``.
        """
        labels = np.array(pre_class_arr, dtype=np.int32)
        positive_index = labels > 0
        negative_index = labels == 0
        # Background patches decay by one (they are not reset to zero) ...
        self.continuous_counter_mat[negative_index] -= 1
        # ... turbid patches accumulate by one ...
        self.continuous_counter_mat[positive_index] += 1
        # ... and the counters are clamped so they never become negative.
        np.maximum(self.continuous_counter_mat, 0, out=self.continuous_counter_mat)

        persistent = np.sum(self.continuous_counter_mat > self.max_continuous_frames)
        bad_flag = bool(persistent > 2)
        print('连续帧信息:', self.continuous_counter_mat)
        print(f'连续帧判别:该时间段是否存在浑浊水体:{bad_flag}')
        return bad_flag

    def discriminate_ratio(self):
        """Vote over the stored frame history.

        A patch is considered turbid when the sum of its class ids over the
        stored frames reaches 60% of the number of stored frames.

        Returns:
            ``True`` when more than two patches satisfy that condition.
        """
        history = np.array(self.continuous_detection_result, dtype=np.float32)
        votes_per_patch = np.sum(history, axis=0)
        bad_water = np.array(votes_per_patch >= 0.6 * len(self.continuous_detection_result), dtype=np.int32)
        bad_flag = bool(np.sum(bad_water, dtype=np.int32) > 2)
        print(f'浑浊比例判别:该时间段是否存在浑浊水体:{bad_flag}')
        return bad_flag

    def filter_prediction(self, predicted_class, confidence, filter_down_limit=3):
        """Suppress isolated turbid patches with a 3x3 neighbourhood filter.

        A turbid patch is kept only when the class ids inside its (border
        clipped) 3x3 neighbourhood, the patch itself included, sum to at least
        ``filter_down_limit``; otherwise it becomes background with
        confidence 1.0.  Neighbourhoods are always read from the unfiltered
        grid, so the result does not depend on the scan order.

        Args:
            predicted_class: Flat per-patch class ids (row-major grid order).
            confidence: Flat per-patch confidences in the same order.
            filter_down_limit: Minimum neighbourhood sum needed to keep a patch.

        Returns:
            ``(confidence, predicted_class)`` as flat arrays after filtering.

        Raises:
            ValueError: If the inputs do not contain exactly
                ``current_patch_rows * current_patch_cols`` values.
        """
        grid_shape = (self.current_patch_rows, self.current_patch_cols)
        # reshape (unlike np.resize) refuses mismatched sizes instead of
        # silently repeating or truncating the data.
        predicted_class_mat = np.asarray(predicted_class).reshape(grid_shape)
        predicted_conf_mat = np.asarray(confidence).reshape(grid_shape)
        new_predicted_class_mat = predicted_class_mat.copy()
        new_predicted_conf_mat = predicted_conf_mat.copy()
        for i in range(self.current_patch_rows):
            for j in range(self.current_patch_cols):
                if (1. - predicted_class_mat[i, j]) > 0.1:
                    continue  # background patch, nothing to suppress
                core_region = self.get_33_patch(predicted_class_mat, i, j)
                if np.sum(core_region) < filter_down_limit:
                    new_predicted_class_mat[i, j] = 0  # reset to background
                    new_predicted_conf_mat[i, j] = 1.0
        return new_predicted_conf_mat.flatten(), new_predicted_class_mat.flatten()
|
|
|
def argsparser():
    """Parse the command-line options of the script.

    Returns:
        The ``argparse.Namespace`` with ``input`` (directory of frame images),
        ``bmodel`` (path of the bmodel file) and ``dev_id`` (TPU id).
    """
    cli = argparse.ArgumentParser(prog=__file__)
    # (flags, keyword arguments) for every supported option.
    option_table = (
        (('--input', '-i'),
         {'type': str, 'default': r'4_video_20251223163145', 'help': 'path of input, must be image directory'}),
        (('--bmodel', '-b'),
         {'type': str, 'default': './shufflenet_f32.bmodel', 'help': 'path of bmodel'}),
        (('--dev_id', '-d'),
         {'type': int, 'default': 0, 'help': 'tpu id'}),
    )
    for flags, spec in option_table:
        cli.add_argument(*flags, **spec)
    return cli.parse_args()
|
|
|
+
|
|
|
def main(args):
    """Run the alarm pipeline over a directory of video frames.

    Args:
        args: Parsed options; ``args.input`` is a directory whose entries are
            the frames of one video.  Entries are processed in sorted
            (lexicographic) file-name order.

    Every readable frame is passed to the predictor and the resulting alarm
    flag is printed; entries that cannot be decoded as an image are skipped.
    """
    # Load the inference engine.
    predictor = Predictor()

    # NOTE(review): lexicographic order equals temporal order only if the
    # frame names are zero-padded - confirm against the frame exporter.
    for entry in sorted(os.listdir(args.input)):
        img_path = os.path.join(args.input, entry)
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        # Skip unreadable or empty images.
        if img is None or img.size == 0:
            continue
        print("正在处理:", img_path)
        print('污水警报', predictor(img))
|
|
|
+
|
|
|
if __name__ == '__main__':
    # `args` is intentionally bound at module scope: besides being passed to
    # main(), it is read as a global when Predictor() is built without arguments.
    args = argsparser()
    main(args)
|