draw.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import cv2
  2. import numpy as np
  3. import os
  4. class DrawRectangle:
  5. def __init__(self, div_scale):
  6. self.mask_save_dir = './mask'
  7. self.scale = div_scale
  8. self.current_roi_points = []
  9. self.rois = []
  10. self.window_title = "Image - Select ROI"
  11. self.draw_complete = False
  12. pass
  13. def callback(self, event, x, y, flags, param):
  14. drawing_image = param['Image']
  15. # 左键添加感兴趣点
  16. if event == cv2.EVENT_LBUTTONDOWN:
  17. # 添加感兴趣点
  18. self.current_roi_points.append((x, y))
  19. # 绘制标记点
  20. cv2.circle(drawing_image, (x, y), 4, (0, 255, 0), -1)
  21. # 绘制多边形
  22. if len(self.current_roi_points) > 1:
  23. cv2.line(drawing_image, self.current_roi_points[-2], self.current_roi_points[-1], (0, 255, 0), 2)
  24. # 显示图像
  25. cv2.imshow(self.window_title, drawing_image)
  26. print(f'添加感兴趣点:{y}行, {x}列')
  27. # 右键闭合感兴趣区域
  28. if event == cv2.EVENT_RBUTTONDOWN:
  29. if len(self.current_roi_points) < 3:
  30. print("[提示] ROI 至少需要 3 个点构成多边形!")
  31. return
  32. cv2.line(drawing_image, self.current_roi_points[-1], self.current_roi_points[0], (0, 255, 0), 2)
  33. cv2.imshow(self.window_title, drawing_image)
  34. # 清理
  35. self.rois.append(self.current_roi_points)
  36. print(f'添加感兴趣区,包含点数:{len(self.current_roi_points)}个')
  37. self.current_roi_points = []
  38. def draw(self, img_path: str):
  39. """在输入图像中绘制多边形区域,然后生成相应的mask图片"""
  40. # 读取图像
  41. ori_img = cv2.imread(img_path)
  42. mask_base_name = os.path.splitext(os.path.basename(img_path))[0] + '.png'
  43. img = cv2.resize(ori_img, (ori_img.shape[1] // self.scale, ori_img.shape[0] // self.scale))
  44. if img is None:
  45. raise RuntimeError('Cannot read the image!')
  46. param = {'Image': img}
  47. cv2.namedWindow(self.window_title)
  48. cv2.setMouseCallback(self.window_title, self.callback, param=param)
  49. # 显示图像并等待退出
  50. while True:
  51. cv2.imshow(self.window_title, img)
  52. key = cv2.waitKey(1) & 0xFF
  53. if key == ord('q') or key == 27: # 按'q'或ESC键退出
  54. break
  55. # 为原图生成掩膜
  56. mask = np.zeros((ori_img.shape[0], ori_img.shape[1]),dtype=np.uint8) # shape等于原始输入图像
  57. for roi in self.rois:
  58. roi_points = np.array(roi, np.int32).reshape((-1, 1, 2)) * self.scale # 兴趣点的缩放处理
  59. cv2.fillPoly(mask, [roi_points], 255)
  60. # 保存掩膜图像
  61. if not os.path.exists(self.mask_save_dir):
  62. os.makedirs(self.mask_save_dir)
  63. cv2.imwrite(os.path.join(self.mask_save_dir, mask_base_name), mask)
  64. # cv2.imshow("mask", mask)
  65. # cv2.waitKey(0)
  66. cv2.destroyAllWindows()
  67. if __name__ == '__main__':
  68. drawer = DrawRectangle(2)
  69. drawer.draw(r"D:\code\water_turbidity_det\draw_mask\mask\4_device_capture.jpg")