# test.py — patch-wise water-turbidity inference over a directory of video frames.
import argparse
import math
import os
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from dotenv import load_dotenv
from PIL import Image
from torchvision import transforms
from torchvision.models import (
    ResNet18_Weights,
    ResNet50_Weights,
    ShuffleNet_V2_X1_0_Weights,
    SqueezeNet1_0_Weights,
    Swin_V2_B_Weights,
    Swin_V2_S_Weights,
    resnet18,
    resnet50,
    shufflenet_v2_x1_0,
    squeezenet1_0,
    swin_v2_b,
    swin_v2_s,
)

from labelme.utils import draw_grid, draw_predict_grid
  15. load_dotenv()
  16. # os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
  17. patch_w = int(os.getenv('PATCH_WIDTH', 256))
  18. patch_h = int(os.getenv('PATCH_HEIGHT', 256))
  19. confidence_threshold = float(os.getenv('CONFIDENCE_THRESHOLD', 0.80))
  20. scale = 2
  21. class Predictor:
  22. def __init__(self, model_name, weights_path, num_classes):
  23. self.model_name = model_name
  24. self.weights_path = weights_path
  25. self.num_classes = num_classes
  26. self.model = None
  27. self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  28. print(f"当前设备: {self.device}")
  29. # 加载模型
  30. self.load_model()
  31. def load_model(self):
  32. if self.model is not None:
  33. return
  34. print(f"正在加载模型: {self.model_name}")
  35. name = self.model_name
  36. # 加载模型
  37. if name == 'resnet50':
  38. self.weights = ResNet50_Weights.IMAGENET1K_V2
  39. self.model = resnet50(weights=self.weights)
  40. elif name == 'squeezenet':
  41. self.weights = SqueezeNet1_0_Weights.IMAGENET1K_V1
  42. self.model = squeezenet1_0(weights=self.weights)
  43. elif name == 'shufflenet':
  44. self.weights = ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1
  45. self.model = shufflenet_v2_x1_0(weights=self.weights)
  46. elif name == 'swin_v2_s':
  47. self.weights = Swin_V2_S_Weights.IMAGENET1K_V1
  48. self.model = swin_v2_s(weights=self.weights)
  49. elif name == 'swin_v2_b':
  50. self.weights = Swin_V2_B_Weights.IMAGENET1K_V1
  51. self.model = swin_v2_b(weights=self.weights)
  52. else:
  53. raise ValueError(f"Invalid model name: {name}")
  54. # 替换最后的分类层以适应新的分类任务
  55. if hasattr(self.model, 'fc'):
  56. # ResNet系列模型
  57. self.model.fc = nn.Sequential(
  58. nn.Linear(int(self.model.fc.in_features), int(self.model.fc.in_features) // 2, bias=True),
  59. nn.ReLU(inplace=True),
  60. nn.Dropout(0.5),
  61. nn.Linear(int(self.model.fc.in_features) // 2, self.num_classes, bias=False)
  62. )
  63. elif hasattr(self.model, 'classifier'):
  64. # Swin Transformer等模型
  65. self.model.classifier = nn.Sequential(
  66. nn.Linear(int(self.model.classifier.in_features), int(self.model.classifier.in_features) // 2,
  67. bias=True),
  68. nn.ReLU(inplace=True),
  69. nn.Dropout(0.5),
  70. nn.Linear(int(self.model.classifier.in_features) // 2, self.num_classes, bias=False)
  71. )
  72. elif hasattr(self.model, 'head'):
  73. # Swin Transformer使用head层
  74. in_features = self.model.head.in_features
  75. self.model.head = nn.Sequential(
  76. nn.Linear(int(in_features), int(in_features) // 2, bias=True),
  77. nn.ReLU(inplace=True),
  78. nn.Dropout(0.5),
  79. nn.Linear(int(in_features) // 2, self.num_classes, bias=False)
  80. )
  81. else:
  82. raise ValueError(f"Model {name} does not have recognizable classifier layer")
  83. print(self.model)
  84. # 加载训练好的权重
  85. self.model.load_state_dict(torch.load(self.weights_path, map_location=torch.device('cpu')))
  86. print(f"成功加载模型参数: {self.weights_path}")
  87. # 将模型移动到GPU
  88. self.model.eval()
  89. self.model = self.model.to(self.device)
  90. print(f"成功加载模型: {self.model_name}")
  91. def predict(self, image_tensor):
  92. """
  93. 对单张图像进行预测
  94. Args:
  95. image_tensor: 预处理后的图像张量
  96. Returns:
  97. predicted_class: 预测的类别索引
  98. confidence: 预测置信度
  99. probabilities: 各类别的概率
  100. """
  101. image_tensor = image_tensor.to(self.device)
  102. with torch.no_grad():
  103. outputs = self.model(image_tensor)
  104. probabilities = torch.nn.functional.softmax(outputs, dim=1) # 沿行计算softmax
  105. confidence, predicted_class = torch.max(probabilities, 1)
  106. return confidence.cpu().numpy(), predicted_class.cpu().numpy()
  107. def preprocess_image(img):
  108. """
  109. 预处理图像以匹配训练时的预处理
  110. Args:
  111. img: PIL图像
  112. Returns:
  113. tensor: 预处理后的图像张量
  114. """
  115. # 定义与训练时相同的预处理步骤
  116. transform = transforms.Compose([
  117. transforms.Resize((224, 224)),
  118. transforms.ToTensor(),
  119. transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  120. ])
  121. # 打开并转换图像
  122. img_w, img_h = img.size
  123. global patch_w, patch_h
  124. imgs_patch = []
  125. imgs_index = []
  126. # fig, axs = plt.subplots(img_h // patch_h + 1, img_w // patch_w + 1)
  127. for i in range(img_h // patch_h + 1):
  128. for j in range(img_w // patch_w + 1):
  129. left = j * patch_w # 裁剪区域左边框距离图像左边的像素值
  130. top = i * patch_h # 裁剪区域上边框距离图像上边的像素值
  131. right = min(j * patch_w + patch_w, img_w) # 裁剪区域右边框距离图像左边的像素值
  132. bottom = min(i * patch_h + patch_h, img_h) # 裁剪区域下边框距离图像上边的像素值
  133. # 检查区域是否有效
  134. if right > left and bottom > top:
  135. patch = img.crop((left, top, right, bottom))
  136. # 长宽比过滤
  137. # rate = patch.height / (patch.width + 1e-6)
  138. # if rate > 1.314 or rate < 0.75:
  139. # # print(f"长宽比过滤: {patch_name}")
  140. # continue
  141. imgs_patch.append(patch)
  142. imgs_index.append((left, top))
  143. # axs[i, j].imshow(patch)
  144. # axs[i, j].set_title(f'Image {i} {j}')
  145. # axs[i, j].axis('off')
  146. # plt.tight_layout()
  147. # plt.show()
  148. imgs_patch = torch.stack([transform(img) for img in imgs_patch])
  149. # 添加批次维度
  150. # image_tensor = image_tensor.unsqueeze(0)
  151. return imgs_index, imgs_patch
  152. def visualize_prediction(image_path, predicted_class, confidence, class_names):
  153. """
  154. 可视化预测结果
  155. Args:
  156. image_path: 图像路径
  157. predicted_class: 预测的类别索引
  158. confidence: 预测置信度
  159. class_names: 类别名称列表
  160. """
  161. image = Image.open(image_path).convert('RGB')
  162. plt.figure(figsize=(8, 6))
  163. plt.imshow(image)
  164. plt.axis('off')
  165. plt.title(f'Predicted: {class_names[predicted_class]}\n'
  166. f'Confidence: {confidence:.4f}', fontsize=14)
  167. plt.show()
  168. def get_33_patch(arr:np.ndarray, center_row:int, center_col:int):
  169. """以(center_row,center_col)为中心,从arr中取出来3*3区域的数据"""
  170. # 边界检查
  171. h,w = arr.shape
  172. safe_row_up_limit = max(0, center_row-1)
  173. safe_row_bottom_limit = min(h, center_row+2)
  174. safe_col_left_limit = max(0, center_col-1)
  175. safe_col_right_limit = min(w, center_col+2)
  176. return arr[safe_row_up_limit:safe_row_bottom_limit, safe_col_left_limit:safe_col_right_limit]
  177. def fileter_prediction(predicted_class, confidence, pre_rows, pre_cols, filter_down_limit=3):
  178. """预测结果矩阵滤波,九宫格内部存在浑浊水体的数量需要大于filter_down_limit,"""
  179. predicted_class_mat = np.resize(predicted_class, (pre_rows, pre_cols))
  180. predicted_conf_mat = np.resize(confidence, (pre_rows, pre_cols))
  181. new_predicted_class_mat = predicted_class_mat.copy()
  182. new_predicted_conf_mat = predicted_conf_mat.copy()
  183. for i in range(pre_rows):
  184. for j in range(pre_cols):
  185. if (1. - predicted_class_mat[i, j]) > 0.1:
  186. continue # 跳过背景类
  187. core_region = get_33_patch(predicted_class_mat, i, j)
  188. if np.sum(core_region) < filter_down_limit:
  189. new_predicted_class_mat[i, j] = 0 # 重置为背景类
  190. new_predicted_conf_mat[i, j] = 1.0
  191. return new_predicted_conf_mat.flatten(), new_predicted_class_mat.flatten()
  192. def discriminate_ratio(water_pre_list:list):
  193. # 方式一:60%以上的帧存在浑浊水体
  194. water_pre_arr = np.array(water_pre_list, dtype=np.float32)
  195. water_pre_arr_mean = np.mean(water_pre_arr, axis=0)
  196. bad_water = np.array(water_pre_arr_mean >= 0.6, dtype=np.int32)
  197. bad_flag = np.sum(bad_water, dtype=np.int32)
  198. print(f'浑浊比例方式:该时间段是否存在浑浊水体:{bool(bad_flag)}')
  199. return bad_flag
  200. def discriminate_cont(pre_class_arr, continuous_count_mat):
  201. """连续帧判别"""
  202. positive_index = np.array(pre_class_arr,dtype=np.int32) > 0
  203. negative_index = np.array(pre_class_arr,dtype=np.int32) == 0
  204. # 给负样本区域置零
  205. continuous_count_mat[negative_index] = 0
  206. # 给正样本区域加1
  207. continuous_count_mat[positive_index] += 1
  208. # 判断浑浊
  209. bad_flag = np.max(continuous_count_mat) > 30
  210. if bad_flag:
  211. print(f'连续帧方式:该时间段是否存在浑浊水体:{bool(bad_flag)}')
  212. return bad_flag
  213. def main():
  214. # 初始化模型实例
  215. # TODO:修改模型网络名称/模型权重路径/视频路径
  216. predictor = Predictor(model_name='shufflenet',
  217. weights_path=r'D:\code\water_turbidity_det\shufflenet_best_model_acc.pth',
  218. num_classes=2)
  219. input_path = r'D:\code\water_turbidity_det\data\4_video_202511211127'
  220. # 预处理图像
  221. all_imgs = os.listdir(input_path)
  222. all_imgs = [os.path.join(input_path, p) for p in all_imgs if p.split('.')[-1] in ['jpg', 'png']]
  223. image = Image.open(all_imgs[0]).convert('RGB')
  224. # 将预测结果reshape为矩阵时的行列数量
  225. pre_rows = image.height // patch_h + 1
  226. pre_cols = image.width // patch_w + 1
  227. # 图像显示时resize的尺寸
  228. resized_img_h = image.height // 2
  229. resized_img_w = image.width // 2
  230. # 预测每张图像
  231. water_pre_list = []
  232. continuous_count_mat = np.zeros(pre_rows*pre_cols, dtype=np.int32)
  233. flag = False
  234. for img_path in all_imgs:
  235. image = Image.open(img_path).convert('RGB')
  236. # 预处理
  237. patches_index, image_tensor = preprocess_image(image)
  238. # 推理
  239. confidence, predicted_class = predictor.predict(image_tensor)
  240. # 第一层虚警抑制,置信度过滤,低于阈值将会被忽略
  241. for i in range(len(confidence)):
  242. if confidence[i] < confidence_threshold:
  243. confidence[i] = 1.0
  244. predicted_class[i] = 0
  245. # 第二层虚警抑制,空间滤波
  246. # 在此处添加过滤逻辑
  247. new_confidence, new_predicted_class = fileter_prediction(predicted_class, confidence, pre_rows, pre_cols, filter_down_limit=3)
  248. # 可视化预测结果
  249. image = cv2.imread(img_path)
  250. image = draw_grid(image, patch_w, patch_h)
  251. image = draw_predict_grid(image, patches_index, predicted_class, confidence)
  252. new_image = cv2.imread(img_path)
  253. new_image = draw_grid(new_image, patch_w, patch_h)
  254. new_image = draw_predict_grid(new_image, patches_index, new_predicted_class, new_confidence)
  255. image = cv2.resize(image, (resized_img_w, resized_img_h))
  256. new_img = cv2.resize(new_image, (resized_img_w, resized_img_h))
  257. cv2.imshow('image', image)
  258. cv2.imshow('image_filter', new_img)
  259. cv2.waitKey(20)
  260. # 方式1判别
  261. if len(water_pre_list) > 100:
  262. flag = discriminate_ratio(water_pre_list) and flag
  263. water_pre_list = []
  264. print('综合判别结果:', flag)
  265. water_pre_list.append(new_predicted_class)
  266. # 方式2判别
  267. flag = discriminate_cont(new_predicted_class, continuous_count_mat)
  268. if __name__ == "__main__":
  269. main()