| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166 |
- """
- 图像预处理工具
- """
- import cv2
- import numpy as np
- from PIL import Image, ImageEnhance, ImageFilter
- from typing import Tuple, Optional, Union
- import logging
# Module-level logger named after this module; no code in this section logs through it.
logger = logging.getLogger(__name__)
class ImagePreprocessor:
    """Deterministic image preprocessing operations.

    Images are NumPy arrays shaped ``(H, W)`` or ``(H, W, C)``. The
    preprocessor keeps no state, so a single instance can be shared freely.
    """

    def __init__(self):
        """Create a preprocessor (no configuration is needed)."""

    def resize(self, image: np.ndarray, size: Tuple[int, int],
               method: str = 'bilinear') -> np.ndarray:
        """Resize ``image`` with OpenCV.

        Args:
            image: Input image.
            size: Passed unchanged to ``cv2.resize`` as ``dsize``, which OpenCV
                interprets as ``(width, height)``. Note that this is the
                opposite order from the ``(height, width)`` used by
                :meth:`center_crop`.
            method: One of ``'bilinear'``, ``'nearest'`` or ``'cubic'``.

        Returns:
            The resized image.

        Raises:
            ValueError: If ``method`` is not a supported interpolation name.
        """
        interpolations = {
            'bilinear': cv2.INTER_LINEAR,
            'nearest': cv2.INTER_NEAREST,
            'cubic': cv2.INTER_CUBIC,
        }
        if method not in interpolations:
            raise ValueError(f"不支持的插值方法: {method}")
        return cv2.resize(image, size, interpolation=interpolations[method])

    def crop(self, image: np.ndarray, x: int, y: int,
             width: int, height: int) -> np.ndarray:
        """Return the ``height`` x ``width`` region whose top-left corner is ``(x, y)``.

        The region is intersected with the image. Coordinates left of or above
        the image are clamped to 0, and NumPy slicing already truncates at the
        right and bottom edges. The result is a slice (view) of ``image``, not
        a copy.
        """
        # Clamp at 0 so negative coordinates are not reinterpreted by Python
        # slicing as offsets from the end of the axis.
        top, bottom = max(y, 0), max(y + height, 0)
        left, right = max(x, 0), max(x + width, 0)
        return image[top:bottom, left:right]

    def center_crop(self, image: np.ndarray, size: Tuple[int, int]) -> np.ndarray:
        """Crop a ``size = (height, width)`` region from the center of ``image``.

        If the requested size exceeds the image along an axis, the full extent
        of that axis is kept; the image is not padded.
        """
        h, w = image.shape[:2]
        crop_h, crop_w = size
        # Without the clamp, an oversized crop produces a negative start index.
        # That index slices from the end of the axis and returns a sliver.
        top = max((h - crop_h) // 2, 0)
        left = max((w - crop_w) // 2, 0)
        return image[top:top + crop_h, left:left + crop_w]

    def normalize(self, image: np.ndarray,
                  mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
                  std: Tuple[float, float, float] = (0.229, 0.224, 0.225)) -> np.ndarray:
        """Scale ``image`` to [0, 1] and standardize it per channel.

        Computes ``(image / 255 - mean[c]) / std[c]`` in float32.

        * For ``(H, W, C)`` inputs, channel ``c`` uses ``mean[c]`` and
          ``std[c]``. Only the first ``min(C, len(mean), len(std))`` channels
          are standardized; any extra channels (e.g. alpha) are just scaled
          by 1/255.
        * A 2-D image uses ``mean[0]`` and ``std[0]``.

        The defaults look like the commonly used ImageNet RGB statistics.
        NOTE(review): confirm callers pass RGB rather than OpenCV's BGR order.

        Returns:
            A new float32 array; the input is not modified.
        """
        # astype() always copies, so the in-place channel update below never
        # touches the caller's array.
        img = image.astype(np.float32) / 255.0
        mean_arr = np.asarray(mean, dtype=np.float32)
        std_arr = np.asarray(std, dtype=np.float32)
        if img.ndim == 3:
            # Limit to the channels that exist, so (H, W, 1) inputs no longer
            # raise IndexError.
            n = min(img.shape[2], mean_arr.size, std_arr.size)
            img[..., :n] = (img[..., :n] - mean_arr[:n]) / std_arr[:n]
        else:
            img = (img - mean_arr[0]) / std_arr[0]
        return img

    def to_tensor(self, image: np.ndarray) -> np.ndarray:
        """Reorder a 3-D image from ``(H, W, C)`` to ``(C, H, W)``.

        Arrays of any other rank are returned unchanged. ``np.transpose``
        returns a view, so no pixel data is copied.
        """
        return np.transpose(image, (2, 0, 1)) if image.ndim == 3 else image
class ImageAugmenter:
    """Random image augmentations for training.

    Every method draws from NumPy's global random state (``np.random``), so
    results are reproducible by seeding it. Images are NumPy arrays shaped
    ``(H, W)`` or ``(H, W, C)``.
    """

    def __init__(self):
        """Create an augmenter (it holds no state)."""

    def random_horizontal_flip(self, image: np.ndarray, p: float = 0.5) -> np.ndarray:
        """Mirror ``image`` left-right with probability ``p``; otherwise return it as-is."""
        if np.random.random() < p:
            return cv2.flip(image, 1)
        return image

    def random_vertical_flip(self, image: np.ndarray, p: float = 0.5) -> np.ndarray:
        """Mirror ``image`` top-bottom with probability ``p``; otherwise return it as-is."""
        if np.random.random() < p:
            return cv2.flip(image, 0)
        return image

    def random_rotation(self, image: np.ndarray, max_angle: float = 15) -> np.ndarray:
        """Rotate ``image`` by an angle drawn uniformly from ``[-max_angle, max_angle]`` degrees.

        The rotation is about ``(w // 2, h // 2)`` and the output keeps the
        input's ``(h, w)`` size. Corners rotated out of frame are lost, and
        uncovered areas use ``cv2.warpAffine``'s default border fill.
        """
        angle = np.random.uniform(-max_angle, max_angle)
        h, w = image.shape[:2]
        center = (w // 2, h // 2)
        matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
        return cv2.warpAffine(image, matrix, (w, h))

    def random_brightness(self, image: np.ndarray, factor_range: Tuple[float, float] = (0.8, 1.2)) -> np.ndarray:
        """Multiply pixel values by a random factor from ``factor_range``.

        The result is clipped to [0, 255] and returned as uint8, whatever the
        input dtype.
        """
        factor = np.random.uniform(factor_range[0], factor_range[1])
        return np.clip(image * factor, 0, 255).astype(np.uint8)

    def random_contrast(self, image: np.ndarray, factor_range: Tuple[float, float] = (0.8, 1.2)) -> np.ndarray:
        """Scale pixel deviations from the global mean by a random factor.

        The mean is taken over all pixels and channels. The result is clipped
        to [0, 255] and returned as uint8.
        """
        factor = np.random.uniform(factor_range[0], factor_range[1])
        mean = np.mean(image)
        return np.clip((image - mean) * factor + mean, 0, 255).astype(np.uint8)

    def random_saturation(self, image: np.ndarray, factor_range: Tuple[float, float] = (0.8, 1.2)) -> np.ndarray:
        """Scale the HSV saturation channel by a random factor from ``factor_range``.

        2-D (grayscale) images are returned unchanged.
        NOTE(review): assumes a 3-channel uint8 RGB image. Images loaded with
        ``cv2.imread`` are BGR, and float images have S in [0, 1], so the
        [0, 255] clip below would not fit them. Confirm with callers.
        """
        if len(image.shape) == 3:
            hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
            factor = np.random.uniform(factor_range[0], factor_range[1])
            # The float product is written back into the integer HSV array,
            # which truncates it to that array's dtype.
            hsv[:, :, 1] = np.clip(hsv[:, :, 1] * factor, 0, 255)
            return cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
        return image

    def random_noise(self, image: np.ndarray, noise_factor: float = 0.1) -> np.ndarray:
        """Add zero-mean Gaussian noise with std ``noise_factor * 255``.

        The result is clipped to [0, 255] and returned as uint8.
        """
        noise = np.random.normal(0, noise_factor * 255, image.shape)
        return np.clip(image + noise, 0, 255).astype(np.uint8)

    def random_crop(self, image: np.ndarray, crop_size: Tuple[int, int]) -> np.ndarray:
        """Cut a randomly placed ``crop_size = (height, width)`` window out of ``image``.

        If the image is smaller than the crop along either axis, the whole
        image is resized to ``(height, width)`` instead.
        """
        h, w = image.shape[:2]
        crop_h, crop_w = crop_size

        if h < crop_h or w < crop_w:
            # cv2.resize takes dsize as (width, height). Passing crop_size
            # directly would transpose the output shape for non-square crops.
            return cv2.resize(image, (crop_w, crop_h))

        # randint's upper bound is exclusive; +1 allows the window to touch
        # the bottom/right edge.
        start_h = np.random.randint(0, h - crop_h + 1)
        start_w = np.random.randint(0, w - crop_w + 1)

        return image[start_h:start_h + crop_h, start_w:start_w + crop_w]
class ImageTransforms:
    """A pipeline of image callables applied in order.

    Each transform receives the output of the previous one, and the pipeline
    returns the output of the last.
    """

    def __init__(self, transforms: list):
        """Store the transforms to apply.

        Args:
            transforms: Callables ``image -> image``, as a list or any other
                iterable. The transforms are copied into a new list, so
                one-shot iterables such as generators keep working across
                repeated calls.
        """
        self.transforms = list(transforms)

    def __call__(self, image: np.ndarray) -> np.ndarray:
        """Run ``image`` through every transform in order and return the result."""
        for transform in self.transforms:
            image = transform(image)
        return image

    @staticmethod
    def get_train_transforms(image_size: Tuple[int, int] = (224, 224)) -> 'ImageTransforms':
        """Build the training pipeline.

        Steps: random horizontal flip (p=0.5), random rotation (up to 15
        degrees), random brightness, random contrast, resize to ``image_size``
        (handed to ``cv2.resize``, i.e. ``(width, height)``), then normalize.
        """
        # Build the stateless helpers once per pipeline, not once per image.
        augmenter = ImageAugmenter()
        preprocessor = ImagePreprocessor()
        return ImageTransforms([
            lambda img: augmenter.random_horizontal_flip(img, p=0.5),
            lambda img: augmenter.random_rotation(img, max_angle=15),
            lambda img: augmenter.random_brightness(img),
            lambda img: augmenter.random_contrast(img),
            lambda img: preprocessor.resize(img, image_size),
            lambda img: preprocessor.normalize(img),
        ])

    @staticmethod
    def get_test_transforms(image_size: Tuple[int, int] = (224, 224)) -> 'ImageTransforms':
        """Build the evaluation pipeline.

        Steps: resize to ``image_size`` (``(width, height)`` for
        ``cv2.resize``), then normalize. No randomness is involved.
        """
        preprocessor = ImagePreprocessor()
        return ImageTransforms([
            lambda img: preprocessor.resize(img, image_size),
            lambda img: preprocessor.normalize(img),
        ])
|