# test.py — 7.0 KB (file-viewer header; the line-number gutter that followed was extraction residue and has been dropped)

import argparse
import math
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from dotenv import load_dotenv
from labelme.utils import draw_grid
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet18, resnet50
  13. load_dotenv()
  14. # os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
  15. patch_w = int(os.getenv('PATCH_WIDTH', 256))
  16. patch_h = int(os.getenv('PATCH_HEIGHT', 256))
  17. confidence_threshold = float(os.getenv('CONFIDENCE_THRESHOLD', 0.80))
  18. scale = 2
  19. class Predictor:
  20. def __init__(self, model_name, weights_path, num_classes):
  21. self.model_name = model_name
  22. self.weights_path = weights_path
  23. self.num_classes = num_classes
  24. self.model = None
  25. self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  26. print(f"当前设备: {self.device}")
  27. # 加载模型
  28. self.load_model()
  29. # 检查模型结构
  30. print(self.model)
  31. def load_model(self):
  32. if self.model is not None:
  33. return
  34. print(f"正在加载模型: {self.model_name}")
  35. if self.model_name == 'resnet18':
  36. self.model = resnet18(weights=None)
  37. elif self.model_name == 'resnet50':
  38. self.model = resnet50(weights=None)
  39. else:
  40. raise ValueError(f"不支持的模型类型: {self.model_name}")
  41. # 修改最后的全连接层
  42. self.model.fc = nn.Linear(self.model.fc.in_features, self.num_classes)
  43. # 加载训练好的权重
  44. self.model.load_state_dict(torch.load(self.weights_path, map_location=torch.device('cpu')))
  45. print(f"成功加载模型参数: {self.weights_path}")
  46. # 将模型移动到GPU
  47. self.model.eval()
  48. self.model = self.model.to(self.device)
  49. print(f"成功加载模型: {self.model_name}")
  50. def predict(self, image_tensor):
  51. """
  52. 对单张图像进行预测
  53. Args:
  54. image_tensor: 预处理后的图像张量
  55. Returns:
  56. predicted_class: 预测的类别索引
  57. confidence: 预测置信度
  58. probabilities: 各类别的概率
  59. """
  60. image_tensor = image_tensor.to(self.device)
  61. with torch.no_grad():
  62. outputs = self.model(image_tensor)
  63. probabilities = torch.nn.functional.softmax(outputs, dim=1) # 沿行计算softmax
  64. confidence, predicted_class = torch.max(probabilities, 1)
  65. return confidence.cpu().numpy(), predicted_class.cpu().numpy()
  66. def preprocess_image(img):
  67. """
  68. 预处理图像以匹配训练时的预处理
  69. Args:
  70. img: PIL图像
  71. Returns:
  72. tensor: 预处理后的图像张量
  73. """
  74. # 定义与训练时相同的预处理步骤
  75. transform = transforms.Compose([
  76. transforms.Resize((224, 224)),
  77. transforms.ToTensor(),
  78. transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  79. ])
  80. # 打开并转换图像
  81. img_w, img_h = img.size
  82. global patch_w, patch_h
  83. imgs_patch = []
  84. imgs_index = []
  85. # fig, axs = plt.subplots(img_h // patch_h + 1, img_w // patch_w + 1)
  86. for i in range(img_h // patch_h + 1):
  87. for j in range(img_w // patch_w + 1):
  88. left = j * patch_w # 裁剪区域左边框距离图像左边的像素值
  89. top = i * patch_h # 裁剪区域上边框距离图像上边的像素值
  90. right = min(j * patch_w + patch_w, img_w) # 裁剪区域右边框距离图像左边的像素值
  91. bottom = min(i * patch_h + patch_h, img_h) # 裁剪区域下边框距离图像上边的像素值
  92. # 检查区域是否有效
  93. if right > left and bottom > top:
  94. patch = img.crop((left, top, right, bottom))
  95. # 长宽比过滤
  96. # rate = patch.height / (patch.width + 1e-6)
  97. # if rate > 1.314 or rate < 0.75:
  98. # # print(f"长宽比过滤: {patch_name}")
  99. # continue
  100. imgs_patch.append(patch)
  101. imgs_index.append((left, top))
  102. # axs[i, j].imshow(patch)
  103. # axs[i, j].set_title(f'Image {i} {j}')
  104. # axs[i, j].axis('off')
  105. # plt.tight_layout()
  106. # plt.show()
  107. imgs_patch = torch.stack([transform(img) for img in imgs_patch])
  108. # 添加批次维度
  109. # image_tensor = image_tensor.unsqueeze(0)
  110. return imgs_index, imgs_patch
  111. def visualize_prediction(image_path, predicted_class, confidence, class_names):
  112. """
  113. 可视化预测结果
  114. Args:
  115. image_path: 图像路径
  116. predicted_class: 预测的类别索引
  117. confidence: 预测置信度
  118. class_names: 类别名称列表
  119. """
  120. image = Image.open(image_path).convert('RGB')
  121. plt.figure(figsize=(8, 6))
  122. plt.imshow(image)
  123. plt.axis('off')
  124. plt.title(f'Predicted: {class_names[predicted_class]}\n'
  125. f'Confidence: {confidence:.4f}', fontsize=14)
  126. plt.show()
  127. def main():
  128. # 初始化模型实例
  129. predictor = Predictor(model_name='resnet50',
  130. weights_path=r'D:\code\water_turbidity_det\resnet50_best_model_acc.pth',
  131. num_classes=2)
  132. input_path = r'D:\code\water_turbidity_det\data\video1_20251129120104_20251129123102'
  133. # 预处理图像
  134. all_imgs = os.listdir(input_path)
  135. all_imgs = [os.path.join(input_path, p) for p in all_imgs if p.split('.')[-1] in ['jpg', 'png']]
  136. for img_path in all_imgs:
  137. image = Image.open(img_path).convert('RGB')
  138. patches_index, image_tensor = preprocess_image(image)
  139. confidence, predicted_class = predictor.predict(image_tensor)
  140. # 第一层虚警抑制,置信度过滤,低于阈值将会被忽略
  141. for i in range(len(confidence)):
  142. if confidence[i] < confidence_threshold:
  143. confidence[i] = 1.0
  144. predicted_class[i] = 0
  145. # 第二层虚警抑制,空间滤波
  146. predicted_class_mat = np.resize(predicted_class, (image.height//patch_h+1, image.width//patch_w+1))
  147. # 可视化预测结果
  148. image = cv2.imread(img_path)
  149. image = draw_grid(image, patch_w, patch_h)
  150. dw = patch_w // 2
  151. dh = patch_h // 2
  152. resized_img_h = image.shape[0] // 2
  153. resized_img_w = image.shape[1] // 2
  154. for i, (idx_w, idx_h) in enumerate(patches_index):
  155. cv2.circle(image, (idx_w, idx_h), 10, (0, 255, 0), -1)
  156. text1 = f'cls:{predicted_class[i]}'
  157. text2 = f'prob:{confidence[i]*100:.1f}%'
  158. color = (0, 0, 255) if predicted_class[i] else (255, 0, 0)
  159. cv2.putText(image, text1, (idx_w, idx_h + dh),
  160. cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 2)
  161. cv2.putText(image, text2, (idx_w, idx_h + dh +25),
  162. cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 2)
  163. image = cv2.resize(image, (resized_img_w, resized_img_h))
  164. cv2.imshow('image', image)
  165. cv2.waitKey(20)
  166. if __name__ == "__main__":
  167. main()