crop_patch.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. # 根据标注文件,生成patch,每个类别放在一个文件夹下
  2. import os
  3. import sys
  4. sys.path.append(os.path.dirname(os.path.abspath(__file__)))
  5. import numpy as np
  6. import cv2
  7. from dotenv import load_dotenv
  8. load_dotenv()
  9. # 从图像中截取patch_w*patch_h大小的图像,并打上标签
  10. patch_w = int(os.getenv('PATCH_WIDTH', 256))
  11. patch_h = int(os.getenv('PATCH_HEIGHT', 256))
  12. def main():
  13. # TODO:需要修改为标注好的图片路径
  14. input_path = r'D:\code\water_turbidity_det\data\video4_20251129120320_20251129123514'
  15. # TODO: 需要修改为保存patch的根目录
  16. output_path_root = r'D:\code\water_turbidity_det\label_data\train'
  17. # 读取标注文件
  18. label_path = os.path.join(input_path, 'label.txt')
  19. if not os.path.exists(label_path): # 强制要求必须标注点什么
  20. raise FileNotFoundError(f"{label_path} 不存在")
  21. with open(label_path, 'r') as fr:
  22. lines = fr.readlines()
  23. lines = [line.strip() for line in lines]
  24. # 恢复标注的格网
  25. grids_info = [] # clicked_points
  26. for line in lines:
  27. point = line.split(',')
  28. grids_info.append([int(point[0]), int(point[1]), int(point[2]), int(point[3]), int(point[4])])
  29. # 我们先创建一些类别文件夹
  30. # 0类
  31. if not os.path.exists(os.path.join(output_path_root, str(0))):
  32. os.makedirs(os.path.join(output_path_root, str(0)))
  33. # 其余类
  34. for grid in grids_info:
  35. if grid[4] <= 0:
  36. continue
  37. if not os.path.exists(os.path.join(output_path_root, str(grid[4]))):
  38. os.makedirs(os.path.join(output_path_root, str(grid[4])))
  39. # 获取图像
  40. all_imgs = [os.path.join(input_path, i) for i in os.listdir(input_path) if i.split('.')[-1] == 'jpg' or i.split('.')[-1] == 'png']
  41. for img_path in all_imgs:
  42. img_base_name = os.path.basename(img_path).split('.')[0]
  43. img = cv2.imread(img_path)
  44. # 获取图像高宽
  45. img_h, img_w, _ = img.shape
  46. # 先将不参与训练的patch重置为0
  47. for g in grids_info:
  48. if g[4] < 0: # 标签小于零的不参与训练
  49. img[g[1]:min(g[1]+g[3], img_h), g[0]:min(g[0]+g[2], img_w), :] = 0
  50. # 再将大于0的patch保存到对应的类别文件夹下
  51. for g in grids_info:
  52. if g[4] > 0: # 标签大于零的放到相应的文件夹下
  53. patch_name = f'{img_base_name}_{g[0]}_{g[1]}_{g[4]}.jpg' # 图块保存名称:图片名_左上角x_左上角y_类别.jpg
  54. patch = img[g[1]:min(g[1]+g[3], img_h), g[0]:min(g[0]+g[2], img_w), :]
  55. # 保存图块
  56. cv2.imwrite(os.path.join(output_path_root,str(g[4]), patch_name), patch)
  57. # 置零已经保存的patch区域
  58. img[g[1]:min(g[1]+g[3], img_h), g[0]:min(g[0]+g[2], img_w), :] = 0
  59. # 最后将剩余的patch保存到0类文件夹下
  60. for i in range(img_h // patch_h + 1):
  61. for j in range(img_w // patch_w + 1):
  62. patch = img[i*patch_h:min(i*patch_h+patch_h, img_h), j*patch_w:min(j*patch_w+patch_w, img_w), :]
  63. patch_name = f'{img_base_name}_{j*patch_w}_{i*patch_h}_0.jpg'
  64. # 长宽比过滤
  65. if patch.shape[0] / patch.shape[1] > 1.314 or patch.shape[0] / patch.shape[1] < 0.75:
  66. print(f"长宽比过滤: {patch_name}")
  67. continue
  68. # 纯黑图像过滤
  69. if np.mean(patch) < 10.10:
  70. print(f"纯黑图像过滤: {patch_name}")
  71. continue
  72. cv2.imwrite(os.path.join(output_path_root, '0', patch_name), patch)
  73. print(f"保存图块: {patch_name}到{os.path.join(output_path_root, '0', patch_name)}")
# Run the patch extraction only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()