# check_label.py (2.4 KB)
  1. """检查标注是否正确,读取标签,信息"""
  2. import cv2
  3. import os
  4. import sys
  5. sys.path.append(os.path.dirname(os.path.abspath(__file__)))
  6. from utils import draw_grid
  7. from dotenv import load_dotenv
  8. load_dotenv() # 加载环境变量
  9. patch_w = int(os.getenv('PATCH_WIDTH', 256))
  10. patch_h = int(os.getenv('PATCH_HEIGHT', 256))
  11. scale = 2
  12. clicked_points = [] # u v w h cls
  13. def main():
  14. """检查标注结果是否正确"""
  15. # 标注文件路径
  16. global patch_w
  17. global patch_h
  18. global scale
  19. global clicked_points
  20. # TODO:修改为要检查的图片路径
  21. imgs_path = r'/frame_data/test/3_video_202511211127'
  22. output_root = r'D:\code\water_turbidity_det\check'
  23. label_path = os.path.join(imgs_path, 'label.txt')
  24. if os.path.exists(label_path):
  25. with open(label_path, 'r') as fr:
  26. lines = fr.readlines()
  27. lines = [line.strip() for line in lines]
  28. for line in lines:
  29. point = line.split(',')
  30. clicked_points.append([int(point[0])//scale, int(point[1])//scale, int(point[2])//scale, int(point[3])//scale, int(point[4])])
  31. del lines
  32. # 检查结果输出路径
  33. output_path = os.path.join(output_root, os.path.basename(imgs_path)+'_check')
  34. if not os.path.exists(output_path):
  35. os.makedirs(output_path)
  36. # 获取所有照片
  37. all_imgs = os.listdir(imgs_path)
  38. all_imgs = [img for img in all_imgs if img.split('.')[-1] == 'jpg' or img.split('.')[-1] == 'png']
  39. for img in all_imgs:
  40. img_path = os.path.join(imgs_path, img)
  41. img = cv2.imread(img_path)
  42. img = cv2.resize(img, (img.shape[1] // scale, img.shape[0] // scale))
  43. # 绘制网格线
  44. img = draw_grid(img, patch_w // scale, patch_h // scale)
  45. # 绘制标记点
  46. for point in clicked_points:
  47. # 计算中心点
  48. center_x = point[0]+point[2]//scale
  49. center_y = point[1]+point[3]//scale
  50. cv2.circle(img, (center_x, center_y), 5, (0, 0, 255), -1)
  51. # 显示标签文本
  52. cv2.putText(img, str(point[4]), (center_x + 10, center_y + 10),
  53. cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
  54. cv2.imwrite(os.path.join(output_path, os.path.basename(img_path)),
  55. img)
  56. print(f"检查结果保存在: {os.path.join(output_path, os.path.basename(img_path))}")
  57. if __name__ == '__main__':
  58. main()