fixed_label.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. # 要求保证视频不能移动,且全过程没有任何遮挡物
  2. import os
  3. import sys
  4. sys.path.append(os.path.dirname(os.path.abspath(__file__)))
  5. from utils import draw_grid
  6. import cv2
  7. import numpy as np
  8. import tkinter as tk
  9. from tkinter import simpledialog
  10. from dotenv import load_dotenv
  11. load_dotenv() # 加载环境变量
  12. class DrawTool:
  13. """重新绘制图像的工具"""
  14. def __init__(self, patch_w, patch_h, scale):
  15. self.img_path = None
  16. self.patch_w = patch_w
  17. self.patch_h = patch_h
  18. self.scale = scale
  19. def draw_new_img(self, points):
  20. img = cv2.imread(self.img_path)
  21. img = cv2.resize(img, (img.shape[1] // self.scale, img.shape[0] // self.scale))
  22. draw_grid(img, self.patch_w // self.scale, self.patch_h // self.scale)
  23. for p in points:
  24. # 绘制角点
  25. cv2.circle(img, (p[0], p[1]), 5, (255, 0, 0), -1)
  26. # 绘制中心点
  27. circle_x_center = p[0] + self.patch_w//(2*self.scale)
  28. circle_y_center = p[1] + self.patch_h//(2*self.scale)
  29. cv2.circle(img, (circle_x_center, circle_y_center), 5, (0, 0, 255), -1)
  30. # 标注类别
  31. cv2.putText(img, str(p[4]), (circle_x_center + 10, circle_y_center + 10),
  32. cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
  33. return img
  34. def set_path(self, img_path):
  35. """设置保存路径"""
  36. self.img_path = img_path
  37. # 所要切分的图块宽高
  38. patch_w = int(os.getenv('PATCH_WIDTH', 256))
  39. patch_h = int(os.getenv('PATCH_HEIGHT', 256))
  40. scale = 2
  41. draw_tool = DrawTool(patch_w=patch_w, patch_h=patch_h, scale=scale)
  42. # 存储标记点
  43. clicked_points = []
  44. def get_text_input(prompt):
  45. """创建弹窗获取文本输入"""
  46. root = tk.Tk()
  47. root.withdraw() # 隐藏主窗口
  48. root.attributes('-topmost', True) # 确保弹窗置顶
  49. result = simpledialog.askstring(' ', prompt)
  50. root.destroy()
  51. if result is None:
  52. result = ""
  53. if not result.strip().isdigit():
  54. result = -1
  55. return int(result)
  56. def det_same_pos_point(x_cor, y_cor):
  57. """判断是否重复点击"""
  58. global clicked_points
  59. for idx, point in enumerate(clicked_points):
  60. if point[0] == x_cor and point[1] == y_cor:
  61. return idx, True
  62. return -1, False
  63. def mouse_callback(event, x, y, flags, param):
  64. """
  65. 鼠标回调函数
  66. """
  67. global clicked_points
  68. global patch_w
  69. global patch_h
  70. global scale
  71. global draw_tool
  72. if event == cv2.EVENT_LBUTTONDOWN: # 左键点击
  73. # 在点击位置绘制红色圆点
  74. scale_patch_w = patch_w // scale
  75. scale_patch_h = patch_h // scale
  76. # 格子角点
  77. circle_x_corner = (x // scale_patch_w)*scale_patch_w
  78. circle_y_corner = (y // scale_patch_h)*scale_patch_h
  79. # 格子中心点
  80. circle_x_center = circle_x_corner + scale_patch_w//2
  81. circle_y_center = circle_y_corner + scale_patch_h//2
  82. cv2.circle(param, (circle_x_center, circle_y_center), 5, (0, 0, 255), -1)
  83. cv2.circle(param, (circle_x_corner, circle_y_corner), 5, (255, 0, 0), -1)
  84. # 更新显示
  85. cv2.imshow('img', param)
  86. cls = get_text_input('请输入类别:0.背景 1.浑浊 -1.不参与')
  87. # 显示标签文本
  88. cv2.putText(param, str(cls), (circle_x_center + 10, circle_y_center + 10),
  89. cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
  90. # 更新显示
  91. cv2.imshow('img', param)
  92. valid_cls = [0, 1, -1]
  93. if cls in valid_cls:
  94. print(f"点击网格角点: ({circle_x_corner}, {circle_y_corner}) 中心点: ({circle_x_center}, {circle_y_center}) 类别:{cls}")
  95. # 记录标注数据,角点 u v w h cls
  96. pos, is_exist = det_same_pos_point(circle_x_corner, circle_y_corner) # 判断是否重复点击
  97. if is_exist:
  98. print(f"已存在该点: ({clicked_points[pos][0]}, {clicked_points[pos][1]}) 类别: {clicked_points[pos][4]}")
  99. clicked_points[pos][4] = cls
  100. print(f'重新标注该点: ({clicked_points[pos][0]}, {clicked_points[pos][1]}) 类别: {clicked_points[pos][4]}')
  101. else:
  102. print(f"添加点: ({circle_x_corner}, {circle_y_corner}) 类别: {cls}")
  103. clicked_points.append([circle_x_corner, circle_y_corner, scale_patch_w, scale_patch_h, cls])
  104. else:
  105. print("请输入正确的类别!")
  106. elif event == cv2.EVENT_RBUTTONDOWN: # 右键点击
  107. removed_point = clicked_points.pop()
  108. print(f"撤销标注点: ({removed_point[0]}, {removed_point[1]}) 类别: {removed_point[4]}")
  109. # 将撤销点标记为黑色
  110. x = removed_point[0]
  111. y = removed_point[1]
  112. # 在点击位置绘制黑色圆点
  113. scale_patch_w = patch_w // scale
  114. scale_patch_h = patch_h // scale
  115. # 格子角点
  116. circle_x_corner = (x // scale_patch_w)*scale_patch_w
  117. circle_y_corner = (y // scale_patch_h)*scale_patch_h
  118. # 格子中心点
  119. circle_x_center = circle_x_corner + scale_patch_w//2
  120. circle_y_center = circle_y_corner + scale_patch_h//2
  121. cv2.circle(param, (circle_x_center, circle_y_center), 5, (128, 128, 128), -1)
  122. cv2.circle(param, (circle_x_corner, circle_y_corner), 5, (128, 128, 128), -1)
  123. # 更新显示
  124. cv2.imshow('img', param)
  125. def remove_duplicates(arr:list):
  126. """列表去重"""
  127. unique_list = []
  128. [unique_list.append(item) for item in arr if item not in unique_list]
  129. return unique_list
  130. def play_video(video_path):
  131. global scale
  132. dir_name = os.path.dirname(video_path)
  133. for i in os.listdir(dir_name):
  134. frame = cv2.imread(os.path.join(dir_name, i))
  135. if frame is None:
  136. continue
  137. # resize
  138. frame = cv2.resize(frame, (frame.shape[1] // scale, frame.shape[0] // scale))
  139. cv2.imshow('Video', frame)
  140. # 按esc退出
  141. if cv2.waitKey(20) == 27:
  142. cv2.destroyAllWindows()
  143. break
  144. def main():
  145. """
  146. 固定摄像头标注,只需要标注一张图像,后续图像保持一致
  147. 1.标注过程,先将图像划分为图框,用cv2划线工具在图像上划网格线
  148. 2.用鼠标进行交互,点击图块输入标签,按下空格键完成交互过程,保存标签
  149. 3.标签格式:u,v,w,h,label u,v为块左上角坐标,w,h为块的宽和高,label为块的标签
  150. """
  151. global clicked_points
  152. global patch_w
  153. global patch_h
  154. global scale
  155. # TODO: 需要更改为准备标注的图像路径,使用当前目录下的000000.jpg,结果保存在当前目录下label.txt
  156. img_path = r'/frame_data/test/video4_20251129120320_20251129123514\000000.jpg'
  157. play_video(img_path)
  158. img = cv2.imread(img_path)
  159. draw_tool.set_path(img_path)
  160. # resize 图像太大了显示不全
  161. img = cv2.resize(img, (img.shape[1] // scale, img.shape[0] // scale))
  162. # 绘制网格线
  163. draw_grid(img, patch_w // scale, patch_h // scale)
  164. # 交互标注
  165. print("操作说明:")
  166. print("- 点击鼠标左键在图像上添加红色标记点: 0.其他 1.浑浊 -1.忽略,不参与训练和测试")
  167. print("- 点击鼠标右键撤回上一个红色标记点")
  168. print("- 按 'c' 键清除所有标记点")
  169. print("- 按 ESC 键退出程序")
  170. cv2.namedWindow('img')
  171. cv2.setMouseCallback('img', mouse_callback, img)
  172. # 交互标注
  173. while True:
  174. # 更新显示
  175. cv2.imshow('img', draw_tool.draw_new_img(clicked_points))
  176. key = cv2.waitKey(1) & 0xFF
  177. # 按 'c' 键清除所有标记点
  178. if key == ord('c'):
  179. img = cv2.imread(img_path)
  180. img = cv2.resize(img, (img.shape[1] // scale, img.shape[0] // scale))
  181. draw_grid(img, patch_w // scale, patch_h // scale)
  182. clicked_points.clear()
  183. cv2.setMouseCallback('img', mouse_callback, img)
  184. print("已清除所有标记点")
  185. # 按 ESC 键退出
  186. elif key == 27: # ESC键
  187. break
  188. cv2.destroyAllWindows()
  189. # 输出所有点击位置
  190. # 列表去重
  191. clicked_points = remove_duplicates(clicked_points)
  192. print(f"总共标记了 {len(clicked_points)} 个点:")
  193. for i, point in enumerate(clicked_points):
  194. print(f" 点 {i + 1}: ({point[0]}, {point[1]}, {point[2]}, {point[3]}, {point[4]})")
  195. # 恢复尺寸
  196. clicked_points = [[p[0]*scale, p[1]*scale, p[2]*scale, p[3]*scale, p[4]] for p in clicked_points]
  197. # 写入txt
  198. if clicked_points:
  199. with open(os.path.join(os.path.dirname(img_path), 'label.txt'), 'w') as fw:
  200. for point in clicked_points:
  201. fw.write(f"{point[0]},{point[1]},{point[2]},{point[3]},{point[4]}\n")
  202. # 保存点
  203. print(f"保存标记点 {len(clicked_points)} 个:")
  204. for i, point in enumerate(clicked_points):
  205. print(f" 点 {i + 1}: ({point[0]}, {point[1]}, {point[2]}, {point[3]}, {point[4]})")
  206. else :
  207. print("没有标记点!不保存任何文件!")
  208. if __name__ == '__main__':
  209. main()