image.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. """
  2. 图像预处理工具
  3. """
  4. import cv2
  5. import numpy as np
  6. from PIL import Image, ImageEnhance, ImageFilter
  7. from typing import Tuple, Optional, Union
  8. import logging
  9. logger = logging.getLogger(__name__)
  10. class ImagePreprocessor:
  11. """图像预处理器"""
  12. def __init__(self):
  13. """初始化图像预处理器"""
  14. pass
  15. def resize(self, image: np.ndarray, size: Tuple[int, int],
  16. method: str = 'bilinear') -> np.ndarray:
  17. """调整图像大小"""
  18. if method == 'bilinear':
  19. return cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
  20. elif method == 'nearest':
  21. return cv2.resize(image, size, interpolation=cv2.INTER_NEAREST)
  22. elif method == 'cubic':
  23. return cv2.resize(image, size, interpolation=cv2.INTER_CUBIC)
  24. else:
  25. raise ValueError(f"不支持的插值方法: {method}")
  26. def crop(self, image: np.ndarray, x: int, y: int,
  27. width: int, height: int) -> np.ndarray:
  28. """裁剪图像"""
  29. return image[y:y+height, x:x+width]
  30. def center_crop(self, image: np.ndarray, size: Tuple[int, int]) -> np.ndarray:
  31. """中心裁剪"""
  32. h, w = image.shape[:2]
  33. crop_h, crop_w = size
  34. start_h = (h - crop_h) // 2
  35. start_w = (w - crop_w) // 2
  36. return image[start_h:start_h+crop_h, start_w:start_w+crop_w]
  37. def normalize(self, image: np.ndarray,
  38. mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
  39. std: Tuple[float, float, float] = (0.229, 0.224, 0.225)) -> np.ndarray:
  40. """标准化图像"""
  41. image = image.astype(np.float32) / 255.0
  42. if len(image.shape) == 3:
  43. for i in range(3):
  44. image[:, :, i] = (image[:, :, i] - mean[i]) / std[i]
  45. else:
  46. image = (image - mean[0]) / std[0]
  47. return image
  48. def to_tensor(self, image: np.ndarray) -> np.ndarray:
  49. """转换为张量格式 (H, W, C) -> (C, H, W)"""
  50. if len(image.shape) == 3:
  51. return np.transpose(image, (2, 0, 1))
  52. return image
  53. class ImageAugmenter:
  54. """图像增强器"""
  55. def __init__(self):
  56. """初始化图像增强器"""
  57. pass
  58. def random_horizontal_flip(self, image: np.ndarray, p: float = 0.5) -> np.ndarray:
  59. """随机水平翻转"""
  60. if np.random.random() < p:
  61. return cv2.flip(image, 1)
  62. return image
  63. def random_vertical_flip(self, image: np.ndarray, p: float = 0.5) -> np.ndarray:
  64. """随机垂直翻转"""
  65. if np.random.random() < p:
  66. return cv2.flip(image, 0)
  67. return image
  68. def random_rotation(self, image: np.ndarray, max_angle: float = 15) -> np.ndarray:
  69. """随机旋转"""
  70. angle = np.random.uniform(-max_angle, max_angle)
  71. h, w = image.shape[:2]
  72. center = (w // 2, h // 2)
  73. matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
  74. return cv2.warpAffine(image, matrix, (w, h))
  75. def random_brightness(self, image: np.ndarray, factor_range: Tuple[float, float] = (0.8, 1.2)) -> np.ndarray:
  76. """随机亮度调整"""
  77. factor = np.random.uniform(factor_range[0], factor_range[1])
  78. return np.clip(image * factor, 0, 255).astype(np.uint8)
  79. def random_contrast(self, image: np.ndarray, factor_range: Tuple[float, float] = (0.8, 1.2)) -> np.ndarray:
  80. """随机对比度调整"""
  81. factor = np.random.uniform(factor_range[0], factor_range[1])
  82. mean = np.mean(image)
  83. return np.clip((image - mean) * factor + mean, 0, 255).astype(np.uint8)
  84. def random_saturation(self, image: np.ndarray, factor_range: Tuple[float, float] = (0.8, 1.2)) -> np.ndarray:
  85. """随机饱和度调整"""
  86. if len(image.shape) == 3:
  87. hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
  88. factor = np.random.uniform(factor_range[0], factor_range[1])
  89. hsv[:, :, 1] = np.clip(hsv[:, :, 1] * factor, 0, 255)
  90. return cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
  91. return image
  92. def random_noise(self, image: np.ndarray, noise_factor: float = 0.1) -> np.ndarray:
  93. """随机噪声"""
  94. noise = np.random.normal(0, noise_factor * 255, image.shape)
  95. return np.clip(image + noise, 0, 255).astype(np.uint8)
  96. def random_crop(self, image: np.ndarray, crop_size: Tuple[int, int]) -> np.ndarray:
  97. """随机裁剪"""
  98. h, w = image.shape[:2]
  99. crop_h, crop_w = crop_size
  100. if h < crop_h or w < crop_w:
  101. return cv2.resize(image, crop_size)
  102. start_h = np.random.randint(0, h - crop_h + 1)
  103. start_w = np.random.randint(0, w - crop_w + 1)
  104. return image[start_h:start_h+crop_h, start_w:start_w+crop_w]
  105. class ImageTransforms:
  106. """图像变换组合"""
  107. def __init__(self, transforms: list):
  108. """初始化变换组合"""
  109. self.transforms = transforms
  110. def __call__(self, image: np.ndarray) -> np.ndarray:
  111. """应用所有变换"""
  112. for transform in self.transforms:
  113. image = transform(image)
  114. return image
  115. @staticmethod
  116. def get_train_transforms(image_size: Tuple[int, int] = (224, 224)):
  117. """获取训练时的变换"""
  118. return ImageTransforms([
  119. lambda img: ImageAugmenter().random_horizontal_flip(img, p=0.5),
  120. lambda img: ImageAugmenter().random_rotation(img, max_angle=15),
  121. lambda img: ImageAugmenter().random_brightness(img),
  122. lambda img: ImageAugmenter().random_contrast(img),
  123. lambda img: ImagePreprocessor().resize(img, image_size),
  124. lambda img: ImagePreprocessor().normalize(img)
  125. ])
  126. @staticmethod
  127. def get_test_transforms(image_size: Tuple[int, int] = (224, 224)):
  128. """获取测试时的变换"""
  129. return ImageTransforms([
  130. lambda img: ImagePreprocessor().resize(img, image_size),
  131. lambda img: ImagePreprocessor().normalize(img)
  132. ])