bmodel_application.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. import argparse
  2. import sophon.sail as sail
  3. import cv2
  4. import os
  5. import logging
  6. import json
  7. import numpy as np
  8. class Predictor:
  9. def __init__(self):
  10. # 加载推理引擎
  11. self.net = sail.Engine(args.bmodel, args.dev_id, sail.IOMode.SYSIO)
  12. self.graph_name = self.net.get_graph_names()[0]
  13. self.input_names = self.net.get_input_names(self.graph_name)
  14. self.input_shapes = [self.net.get_input_shape(self.graph_name, name) for name in self.input_names]
  15. self.output_names = self.net.get_output_names(self.graph_name)
  16. self.output_shapes = [self.net.get_output_shape(self.graph_name, name) for name in self.output_names] # [[1, 2]]
  17. self.input_name = self.input_names[0]
  18. self.input_shape = self.input_shapes[0] # [1, 3, 256, 256]
  19. self.batch_size = self.input_shape[0]
  20. self.net_h = self.input_shape[2] # 输入图像patch的高
  21. self.net_w = self.input_shape[3] # 输入图像patch的宽
  22. # 归一化参数,采用imagenet预训练参数
  23. self.mean = [0.485, 0.456, 0.406]
  24. self.std = [0.229, 0.224, 0.225]
  25. # 一张图像有多少行、列的patch
  26. self.current_patch_rows = 0
  27. self.current_patch_cols = 0
  28. # 输入的大图高度和宽度
  29. self.input_image_h = 0
  30. self.input_image_w = 0
  31. # 报警的置信度阈值
  32. self.confidence_threshold = 0.85
  33. # 存储连续的帧的检测结果
  34. self.continuous_detection_result = []
  35. # 连续浑浊的patch计数器
  36. self.continuous_counter_mat = np.array(0)
  37. # 判定有浑浊水体时的连续帧数量
  38. self.max_continuous_frames = 50
  39. # 比例判定的起始帧数
  40. self.start_frame_num = self.max_continuous_frames
  41. # 比例判定的阈值
  42. self.ratio_threshold = 0.85
  43. self.print_network_info()
  44. def __call__(self, img) -> bool:
  45. # 预处理,获取输入图像的patch序列和左上角坐标
  46. patches_index, patches = self.preprocess(img)
  47. # 推理
  48. confidence = []
  49. predicted_class = []
  50. for i in range(0, len(patches), self.batch_size):# 根据模型的输入batch size创建图像的tensor
  51. batch_patches = patches[i:i + self.batch_size]
  52. # 处理尾部, 当最后一批数据不足batch size时,使用最后一个patch填充
  53. if len(batch_patches) < self.batch_size:
  54. batch_patches += [batch_patches[-1]] * (self.batch_size - len(batch_patches))
  55. patches_tensor = np.stack(batch_patches)
  56. # print('推理中:', patches_tensor.shape)
  57. # 调用推理引擎
  58. batch_confi, batch_cls = self.predict(patches_tensor) # 返回值是二维数组,形状为[batch_size, cls]
  59. confidence += batch_confi
  60. predicted_class += batch_cls
  61. confidence = np.array(confidence[:len(patches)])
  62. predicted_class = np.array(predicted_class[:len(patches)])
  63. # 后处理, 报警逻辑
  64. # print('推理置信度:', confidence)
  65. # print('原始预测结果:', predicted_class)
  66. alarm = self.postprocess(confidence=confidence, predicted_class=predicted_class)
  67. return alarm
  68. def print_network_info(self):
  69. info = {
  70. 'Graph Name': self.graph_name,
  71. 'Input Name': self.input_name,
  72. 'Output Names': self.output_names,
  73. 'Output Shapes': self.output_shapes,
  74. 'Input Shape': self.input_shape,
  75. 'Batch Size': self.batch_size,
  76. 'Height': self.net_h,
  77. 'Width': self.net_w,
  78. 'Mean': self.mean,
  79. 'Std': self.std,
  80. 'Input Image Size': [self.input_image_h, self.input_image_w],
  81. 'Confidence Threshold': self.confidence_threshold,
  82. 'Max Continuous Frames': self.max_continuous_frames,
  83. 'Ratio Threshold': self.ratio_threshold,
  84. 'Start Frame Num': self.start_frame_num
  85. }
  86. print("=" * 50)
  87. print("Network Configuration Info")
  88. print("=" * 50)
  89. for key, value in info.items():
  90. print(f"{key:<18}: {value}")
  91. print("=" * 50)
  92. def predict(self, input_img):
  93. input_data = {self.input_name: input_img}
  94. outputs = self.net.process(self.graph_name, input_data)
  95. # print('predict fun:', outputs)
  96. outputs = list(outputs.values())[0]
  97. # print('predict fun return:', outputs)
  98. outputs_exp = np.exp(outputs)
  99. # print('exp res:', outputs_exp)
  100. outputs = outputs_exp / np.sum(outputs_exp, axis=1)[:, None]
  101. # print('softmax res:', outputs)
  102. confidence = np.max(outputs, axis=1)
  103. # print('最大概率:', confidence)
  104. predictions = np.argmax(outputs, axis=1) # 返回最大概率的类别
  105. # print('预测结果:', predictions)
  106. return confidence.tolist(), predictions.tolist()
  107. def preprocess(self, img: np.ndarray):
  108. """用于视频报警的预处理,将一张图像从左到右从上到下以此剪裁为patch序列
  109. 输入:完整图像
  110. 输出:
  111. """
  112. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  113. img = img.astype('float32')
  114. img = (img / 255. - self.mean) / self.std # 这一步是很有必要的, 因为编译过程并不会帮你做归一化,所以这里要自己做归一化,否则预测数值可能会非常不准确
  115. img_h, img_w, _ = img.shape
  116. img = np.transpose(img, (2, 0, 1)) # channel first
  117. # 自上而下,从左到右
  118. patches = []
  119. patch_index = []
  120. if img_h != self.input_image_h or img_w != self.input_image_w: # 只在输入图像尺寸改变时重新计算patch的行列数量
  121. self.input_image_h = img_h
  122. self.input_image_w = img_w
  123. self.current_patch_rows = self.input_image_h // self.net_h + 1
  124. self.current_patch_cols = self.input_image_w // self.net_w + 1
  125. # 此处初始化
  126. print('初始化连续计数器矩阵')
  127. self.continuous_counter_mat = np.zeros(self.current_patch_rows*self.current_patch_cols, dtype=np.int32)
  128. # 以此剪裁patch
  129. for i in range(self.current_patch_rows):
  130. for j in range(self.current_patch_cols):
  131. start_row = i * self.net_h
  132. start_col = j * self.net_w
  133. end_row = min(start_row + self.net_h, self.input_image_h)
  134. end_col = min(start_col + self.net_w, self.input_image_w)
  135. patch = img[::, start_row:end_row, start_col:end_col]
  136. _, patch_h, patch_w = patch.shape
  137. if patch_h != self.net_h or patch_w != self.net_w:
  138. patch = np.transpose(patch, (1, 2, 0))
  139. patch = cv2.resize(patch, (self.net_w, self.net_h))
  140. patch = np.transpose(patch, (2, 0, 1))
  141. patches.append(patch) # 图块
  142. patch_index.append((start_col, start_row)) # 图块的左上角坐标
  143. return patch_index, patches
  144. def postprocess(self, confidence, predicted_class):
  145. """根据预测结果判定是否报警"""
  146. # 第一层虚警抑制,置信度过滤,低于阈值将会被忽略
  147. for i in range(len(confidence)):
  148. if confidence[i] < self.confidence_threshold and predicted_class[i] == 1:
  149. confidence[i] = 1.0
  150. predicted_class[i] = 0
  151. # 第二层虚警抑制,空间滤波
  152. confidence, predicted_class = self.filter_prediction(predicted_class=predicted_class, confidence=confidence, filter_down_limit=3)
  153. # print('最终结果:', confidence)
  154. # print('后处理,空间滤波后结果:', predicted_class)
  155. # 第三层 时间滤波
  156. self.continuous_detection_result.append(predicted_class)
  157. flag = self.update_continuous_counter(predicted_class) # 连续帧滤波
  158. if len(self.continuous_detection_result) >= self.start_frame_num:
  159. flag = self.discriminate_ratio() and flag
  160. print(f'是否存在浑浊水体,综合判别结果:{flag}')
  161. self.continuous_detection_result.pop(0)
  162. return flag
  163. return False
  164. @staticmethod
  165. def get_33_patch(arr: np.ndarray, center_row: int, center_col: int):
  166. """以(center_row,center_col)为中心,从arr中取出来3*3区域的数据"""
  167. # 边界检查
  168. h, w = arr.shape
  169. safe_row_up_limit = max(0, center_row - 1)
  170. safe_row_bottom_limit = min(h, center_row + 2)
  171. safe_col_left_limit = max(0, center_col - 1)
  172. safe_col_right_limit = min(w, center_col + 2)
  173. return arr[safe_row_up_limit:safe_row_bottom_limit, safe_col_left_limit:safe_col_right_limit]
  174. def update_continuous_counter(self, pre_class_arr):
  175. """连续帧判别"""
  176. positive_index = np.array(pre_class_arr, dtype=np.int32) > 0
  177. negative_index = np.array(pre_class_arr, dtype=np.int32) == 0
  178. # 给负样本区域置零
  179. self.continuous_counter_mat[negative_index] -= 3
  180. # 给正样本区域加1
  181. self.continuous_counter_mat[positive_index] += 1
  182. # 保证不出现负数
  183. self.continuous_counter_mat[self.continuous_counter_mat<0] = 0
  184. # 判断浑浊
  185. bad_flag = bool(np.sum(self.continuous_counter_mat > self.max_continuous_frames) > 2) # 两个以上的patch满足条件
  186. # print('连续帧信息:', self.continuous_counter_mat)
  187. print(f'连续帧判别:该时间段是否存在浑浊水体:{bad_flag}')
  188. return bad_flag
  189. def discriminate_ratio(self):
  190. water_pre_list = self.continuous_detection_result.copy()
  191. # 方式一:60%以上的帧存在浑浊水体
  192. water_pre_arr = np.array(water_pre_list, dtype=np.float32)
  193. water_pre_arr_sum = np.sum(water_pre_arr, axis=0)
  194. bad_water = np.array(water_pre_arr_sum >= self.ratio_threshold * len(water_pre_list), dtype=np.int32)
  195. bad_flag = bool(np.sum(bad_water, dtype=np.int32) > 2) # 大于两个patch符合要求才可以
  196. # print('比例信息:',water_pre_arr_sum )
  197. print(f'浑浊比例判别:该时间段是否存在浑浊水体:{bad_flag}')
  198. return bad_flag
  199. def filter_prediction(self, predicted_class, confidence, filter_down_limit=3):
  200. """预测结果矩阵滤波,九宫格内部存在浑浊水体的数量需要大于filter_down_limit,"""
  201. predicted_class_mat = np.resize(predicted_class, (self.current_patch_rows, self.current_patch_cols))
  202. predicted_conf_mat = np.resize(confidence, (self.current_patch_rows, self.current_patch_cols))
  203. new_predicted_class_mat = predicted_class_mat.copy()
  204. new_predicted_conf_mat = predicted_conf_mat.copy()
  205. for i in range(self.current_patch_rows):
  206. for j in range(self.current_patch_cols):
  207. if (1. - predicted_class_mat[i, j]) > 0.1:
  208. continue # 跳过背景类
  209. core_region = self.get_33_patch(predicted_class_mat, i, j)
  210. if np.sum(core_region) < filter_down_limit:
  211. new_predicted_class_mat[i, j] = 0 # 重置为背景类
  212. new_predicted_conf_mat[i, j] = 1.0
  213. return new_predicted_conf_mat.flatten(), new_predicted_class_mat.flatten()
  214. def argsparser():
  215. parser = argparse.ArgumentParser(prog=__file__)
  216. parser.add_argument('--input','-i', type=str, default=r'./4_video_20251223163145', help='path of input, must be image directory')
  217. parser.add_argument('--bmodel','-b', type=str, default='./shufflenet_f32.bmodel', help='path of bmodel')
  218. parser.add_argument('--dev_id','-d', type=int, default=0, help='tpu id')
  219. args = parser.parse_args()
  220. return args
  221. def main(args):
  222. """函数的目的是为了实现一个能够报警的完整业务逻辑
  223. 输入:路径,包含了一个视频的帧序列,按照时间展开
  224. 输出:确认存在浑浊的水体,是表示报警,否表示无明显浑浊水体
  225. """
  226. # 加载推理引擎
  227. predictor = Predictor()
  228. # 获取图片
  229. all_imgs = [os.path.join(args.input, i) for i in sorted(os.listdir(args.input))]
  230. for img_path in all_imgs:
  231. img = cv2.imread(img_path, cv2.IMREAD_COLOR)
  232. # 跳过空图片
  233. if img is None:
  234. continue
  235. if img.size == 0:
  236. continue
  237. print("正在处理:", img_path)
  238. print('污水警报', predictor(img))
if __name__ == '__main__':
    # Script entry point. The parsed options are bound to the module-level
    # name `args` on purpose: Predictor.__init__ looks up the global `args`
    # for the bmodel path and the device id, so this name must not change.
    args = argsparser()
    main(args)