# video_test.py
  1. import time
  2. import torch
  3. import torch.nn as nn
  4. from torchvision import transforms
  5. from model.model_zoon import load_model
  6. import numpy as np
  7. from PIL import Image
  8. import os
  9. import argparse
  10. from labelme.utils import draw_grid, draw_predict_grid
  11. import cv2
  12. import matplotlib.pyplot as plt
  13. from dotenv import load_dotenv
# Load environment variables (python-dotenv) before the settings below are read.
load_dotenv()
# os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
# Patch width/height in pixels used to tile each frame; overridable through the
# PATCH_WIDTH / PATCH_HEIGHT environment variables (default 256).
patch_w = int(os.getenv('PATCH_WIDTH', 256))
patch_h = int(os.getenv('PATCH_HEIGHT', 256))
# Minimum confidence a positive (class 1) patch prediction needs in order to be
# kept; overridable through CONFIDENCE_THRESHOLD (default 0.80).
confidence_threshold = float(os.getenv('CONFIDENCE_THRESHOLD', 0.80))
# NOTE(review): `scale` is not referenced anywhere else in the visible code.
scale = 2
  20. class Predictor:
  21. def __init__(self, model_name, weights_path, num_classes):
  22. self.model_name = model_name
  23. self.weights_path = weights_path
  24. self.num_classes = num_classes
  25. # self.use_bias = os.getenv('USE_BIAS', True)
  26. self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  27. print(f"当前设备: {self.device}")
  28. self.model = self.load_model()
  29. # 判定有浑浊水体时的连续帧数量
  30. self.max_continuous_frames = 100
  31. # 比例判定的起始帧数
  32. self.start_frame_num = self.max_continuous_frames
  33. # 比例判定的阈值
  34. self.ratio_threshold = 0.90
  35. # 报警的置信度阈值
  36. self.confidence_threshold = 0.90
  37. def load_model(self):
  38. return load_model(name=self.model_name, num_classes=self.num_classes, weights_path=self.weights_path, device=self.device)
  39. def predict(self, image_tensor):
  40. """
  41. 对单张图像进行预测
  42. Args:
  43. image_tensor: 预处理后的图像张量
  44. Returns:
  45. predicted_class: 预测的类别索引
  46. confidence: 预测置信度
  47. probabilities: 各类别的概率
  48. """
  49. image_tensor = image_tensor.to(self.device)
  50. with torch.no_grad():
  51. outputs = self.model(image_tensor)
  52. probabilities = torch.softmax(outputs, dim=1) # 沿行计算softmax
  53. confidence, predicted_class = torch.max(probabilities, 1)
  54. return confidence.cpu().numpy(), predicted_class.cpu().numpy()
  55. def preprocess_image(img):
  56. """
  57. 预处理图像以匹配训练时的预处理
  58. Args:
  59. img: PIL图像
  60. Returns:
  61. tensor: 预处理后的图像张量
  62. """
  63. # 定义与训练时相同的预处理步骤
  64. transform = transforms.Compose([
  65. transforms.Resize((224, 224)),
  66. transforms.ToTensor(),
  67. transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  68. ])
  69. # 打开并转换图像
  70. img_w, img_h = img.size
  71. global patch_w, patch_h
  72. imgs_patch = []
  73. imgs_index = []
  74. # fig, axs = plt.subplots(img_h // patch_h + 1, img_w // patch_w + 1)
  75. for i in range(img_h // patch_h + 1):
  76. for j in range(img_w // patch_w + 1):
  77. left = j * patch_w # 裁剪区域左边框距离图像左边的像素值
  78. top = i * patch_h # 裁剪区域上边框距离图像上边的像素值
  79. right = min(j * patch_w + patch_w, img_w) # 裁剪区域右边框距离图像左边的像素值
  80. bottom = min(i * patch_h + patch_h, img_h) # 裁剪区域下边框距离图像上边的像素值
  81. # 检查区域是否有效
  82. if right > left and bottom > top:
  83. patch = img.crop((left, top, right, bottom))
  84. # 长宽比过滤
  85. # rate = patch.height / (patch.width + 1e-6)
  86. # if rate > 1.314 or rate < 0.75:
  87. # # print(f"长宽比过滤: {patch_name}")
  88. # continue
  89. imgs_patch.append(patch)
  90. imgs_index.append((left, top))
  91. # axs[i, j].imshow(patch)
  92. # axs[i, j].set_title(f'Image {i} {j}')
  93. # axs[i, j].axis('off')
  94. # plt.tight_layout()
  95. # plt.show()
  96. imgs_patch = torch.stack([transform(img) for img in imgs_patch])
  97. # 添加批次维度
  98. # image_tensor = image_tensor.unsqueeze(0)
  99. return imgs_index, imgs_patch
  100. def visualize_prediction(image_path, predicted_class, confidence, class_names):
  101. """
  102. 可视化预测结果
  103. Args:
  104. image_path: 图像路径
  105. predicted_class: 预测的类别索引
  106. confidence: 预测置信度
  107. class_names: 类别名称列表
  108. """
  109. image = Image.open(image_path).convert('RGB')
  110. plt.figure(figsize=(8, 6))
  111. plt.imshow(image)
  112. plt.axis('off')
  113. plt.title(f'Predicted: {class_names[predicted_class]}\n'
  114. f'Confidence: {confidence:.4f}', fontsize=14)
  115. plt.show()
  116. def get_33_patch(arr:np.ndarray, center_row:int, center_col:int):
  117. """以(center_row,center_col)为中心,从arr中取出来3*3区域的数据"""
  118. # 边界检查
  119. h,w = arr.shape
  120. safe_row_up_limit = max(0, center_row-1)
  121. safe_row_bottom_limit = min(h, center_row+2)
  122. safe_col_left_limit = max(0, center_col-1)
  123. safe_col_right_limit = min(w, center_col+2)
  124. return arr[safe_row_up_limit:safe_row_bottom_limit, safe_col_left_limit:safe_col_right_limit]
  125. def fileter_prediction(predicted_class, confidence, pre_rows, pre_cols, filter_down_limit=3):
  126. """预测结果矩阵滤波,九宫格内部存在浑浊水体的数量需要大于filter_down_limit,"""
  127. predicted_class_mat = np.resize(predicted_class, (pre_rows, pre_cols))
  128. predicted_conf_mat = np.resize(confidence, (pre_rows, pre_cols))
  129. new_predicted_class_mat = predicted_class_mat.copy()
  130. new_predicted_conf_mat = predicted_conf_mat.copy()
  131. for i in range(pre_rows):
  132. for j in range(pre_cols):
  133. if (1. - predicted_class_mat[i, j]) > 0.1:
  134. continue # 跳过背景类
  135. core_region = get_33_patch(predicted_class_mat, i, j)
  136. if np.sum(core_region) < filter_down_limit:
  137. new_predicted_class_mat[i, j] = 0 # 重置为背景类
  138. new_predicted_conf_mat[i, j] = 1.0
  139. return new_predicted_conf_mat.flatten(), new_predicted_class_mat.flatten()
  140. def discriminate_ratio(water_pre_list:list, right_ratio:float):
  141. # 方式一:60%以上的帧存在浑浊水体
  142. water_pre_arr = np.array(water_pre_list, dtype=np.float32)
  143. water_pre_arr_sum = np.sum(water_pre_arr, axis=0)
  144. bad_water = np.array(water_pre_arr_sum >= right_ratio * len(water_pre_list), dtype=np.int32)
  145. bad_flag = bool(np.sum(bad_water, dtype=np.int32) > 2) # 大于两个patch符合要求才可以
  146. print(f'浑浊比例判别:该时间段是否存在浑浊水体:{bad_flag}')
  147. return bad_flag
  148. def discriminate_count(pre_class_arr, continuous_count_mat,max_continuous_frames):
  149. """连续帧判别"""
  150. positive_index = np.array(pre_class_arr,dtype=np.int32) > 0
  151. negative_index = np.array(pre_class_arr,dtype=np.int32) == 0
  152. # 给负样本区域置零
  153. continuous_count_mat[negative_index] -= 3
  154. # 给正样本区域加1
  155. continuous_count_mat[positive_index] += 1
  156. # 保证不出现负数
  157. continuous_count_mat[continuous_count_mat<0] = 0
  158. # 判断浑浊
  159. bad_flag = bool(np.sum(continuous_count_mat > max_continuous_frames) > 2)
  160. print(f'连续帧方式:该时间段是否存在浑浊水体:{bad_flag}')
  161. return bad_flag
  162. def main():
  163. # 初始化模型实例
  164. # TODO:修改模型网络名称/模型权重路径/视频路径
  165. predictor = Predictor(model_name='shufflenet-x2',
  166. weights_path=r'./shufflenet-x2.pth',
  167. num_classes=2)
  168. input_path = r'D:\code\water_turbidity_det\tem_test\2_ch52_20260113011503_0'
  169. # 预处理图像
  170. all_imgs = os.listdir(input_path)
  171. all_imgs = [os.path.join(input_path, p) for p in all_imgs if p.split('.')[-1] in ['jpg', 'png']]
  172. image = Image.open(all_imgs[0]).convert('RGB')
  173. # 将预测结果reshape为矩阵时的行列数量
  174. pre_rows = image.height // patch_h + 1
  175. pre_cols = image.width // patch_w + 1
  176. # 图像显示时resize的尺寸
  177. resized_img_h = image.height // 2
  178. resized_img_w = image.width // 2
  179. # 预测每张图像
  180. water_pre_list = []
  181. continuous_count_mat = np.zeros(pre_rows*pre_cols, dtype=np.int32)
  182. flag = False
  183. for img_path in all_imgs:
  184. image = Image.open(img_path).convert('RGB')
  185. # 预处理
  186. patches_index, image_tensor = preprocess_image(image) # patches_index:list[tuple, ...]
  187. # 推理
  188. confidence, predicted_class = predictor.predict(image_tensor) # confidence: np.ndarray, shape=(x,), predicted_class: np.ndarray, shape=(x,), raw_outputs: np.ndarray, shape=(x,)
  189. # 第一层虚警抑制,置信度过滤,低于阈值将会被忽略
  190. for i in range(len(confidence)):
  191. if confidence[i] < confidence_threshold and predicted_class[i] == 1:
  192. confidence[i] = 1.0
  193. predicted_class[i] = 0
  194. # 第二层虚警抑制,空间滤波
  195. # 在此处添加过滤逻辑
  196. # print('原始预测结果:', predicted_class)
  197. new_confidence, new_predicted_class = fileter_prediction(predicted_class, confidence, pre_rows, pre_cols, filter_down_limit=3)
  198. # print('过滤后预测结果:', new_predicted_class)
  199. # 可视化预测结果
  200. image = cv2.imread(img_path)
  201. image = draw_grid(image, patch_w, patch_h)
  202. image = draw_predict_grid(image, patches_index, predicted_class, confidence)
  203. new_image = cv2.imread(img_path)
  204. new_image = draw_grid(new_image, patch_w, patch_h)
  205. new_image = draw_predict_grid(new_image, patches_index, new_predicted_class, new_confidence)
  206. image = cv2.resize(image, (resized_img_w, resized_img_h))
  207. new_img = cv2.resize(new_image, (resized_img_w, resized_img_h))
  208. cv2.imshow('image', image)
  209. cv2.imshow('image_filter', new_img)
  210. cv2.waitKey(25)
  211. water_pre_list.append(new_predicted_class)
  212. # 方式2判别
  213. flag = discriminate_count(new_predicted_class, continuous_count_mat, predictor.max_continuous_frames)
  214. # 方式1判别
  215. if len(water_pre_list) > predictor.start_frame_num:
  216. flag = discriminate_ratio(water_pre_list, predictor.ratio_threshold) and flag
  217. print('综合判别结果:', flag)
# Run the detection pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()