""" 图像预处理工具 """ import cv2 import numpy as np from PIL import Image, ImageEnhance, ImageFilter from typing import Tuple, Optional, Union import logging logger = logging.getLogger(__name__) class ImagePreprocessor: """图像预处理器""" def __init__(self): """初始化图像预处理器""" pass def resize(self, image: np.ndarray, size: Tuple[int, int], method: str = 'bilinear') -> np.ndarray: """调整图像大小""" if method == 'bilinear': return cv2.resize(image, size, interpolation=cv2.INTER_LINEAR) elif method == 'nearest': return cv2.resize(image, size, interpolation=cv2.INTER_NEAREST) elif method == 'cubic': return cv2.resize(image, size, interpolation=cv2.INTER_CUBIC) else: raise ValueError(f"不支持的插值方法: {method}") def crop(self, image: np.ndarray, x: int, y: int, width: int, height: int) -> np.ndarray: """裁剪图像""" return image[y:y+height, x:x+width] def center_crop(self, image: np.ndarray, size: Tuple[int, int]) -> np.ndarray: """中心裁剪""" h, w = image.shape[:2] crop_h, crop_w = size start_h = (h - crop_h) // 2 start_w = (w - crop_w) // 2 return image[start_h:start_h+crop_h, start_w:start_w+crop_w] def normalize(self, image: np.ndarray, mean: Tuple[float, float, float] = (0.485, 0.456, 0.406), std: Tuple[float, float, float] = (0.229, 0.224, 0.225)) -> np.ndarray: """标准化图像""" image = image.astype(np.float32) / 255.0 if len(image.shape) == 3: for i in range(3): image[:, :, i] = (image[:, :, i] - mean[i]) / std[i] else: image = (image - mean[0]) / std[0] return image def to_tensor(self, image: np.ndarray) -> np.ndarray: """转换为张量格式 (H, W, C) -> (C, H, W)""" if len(image.shape) == 3: return np.transpose(image, (2, 0, 1)) return image class ImageAugmenter: """图像增强器""" def __init__(self): """初始化图像增强器""" pass def random_horizontal_flip(self, image: np.ndarray, p: float = 0.5) -> np.ndarray: """随机水平翻转""" if np.random.random() < p: return cv2.flip(image, 1) return image def random_vertical_flip(self, image: np.ndarray, p: float = 0.5) -> np.ndarray: """随机垂直翻转""" if np.random.random() < p: return cv2.flip(image, 0) return image def random_rotation(self, image: np.ndarray, max_angle: float = 15) -> np.ndarray: """随机旋转""" angle = np.random.uniform(-max_angle, max_angle) h, w = image.shape[:2] center = (w // 2, h // 2) matrix = cv2.getRotationMatrix2D(center, angle, 1.0) return cv2.warpAffine(image, matrix, (w, h)) def random_brightness(self, image: np.ndarray, factor_range: Tuple[float, float] = (0.8, 1.2)) -> np.ndarray: """随机亮度调整""" factor = np.random.uniform(factor_range[0], factor_range[1]) return np.clip(image * factor, 0, 255).astype(np.uint8) def random_contrast(self, image: np.ndarray, factor_range: Tuple[float, float] = (0.8, 1.2)) -> np.ndarray: """随机对比度调整""" factor = np.random.uniform(factor_range[0], factor_range[1]) mean = np.mean(image) return np.clip((image - mean) * factor + mean, 0, 255).astype(np.uint8) def random_saturation(self, image: np.ndarray, factor_range: Tuple[float, float] = (0.8, 1.2)) -> np.ndarray: """随机饱和度调整""" if len(image.shape) == 3: hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV) factor = np.random.uniform(factor_range[0], factor_range[1]) hsv[:, :, 1] = np.clip(hsv[:, :, 1] * factor, 0, 255) return cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB) return image def random_noise(self, image: np.ndarray, noise_factor: float = 0.1) -> np.ndarray: """随机噪声""" noise = np.random.normal(0, noise_factor * 255, image.shape) return np.clip(image + noise, 0, 255).astype(np.uint8) def random_crop(self, image: np.ndarray, crop_size: Tuple[int, int]) -> np.ndarray: """随机裁剪""" h, w = image.shape[:2] crop_h, crop_w = crop_size if h < crop_h or w < crop_w: return cv2.resize(image, crop_size) start_h = np.random.randint(0, h - crop_h + 1) start_w = np.random.randint(0, w - crop_w + 1) return image[start_h:start_h+crop_h, start_w:start_w+crop_w] class ImageTransforms: """图像变换组合""" def __init__(self, transforms: list): """初始化变换组合""" self.transforms = transforms def __call__(self, image: np.ndarray) -> np.ndarray: """应用所有变换""" for transform in self.transforms: image = transform(image) return image @staticmethod def get_train_transforms(image_size: Tuple[int, int] = (224, 224)): """获取训练时的变换""" return ImageTransforms([ lambda img: ImageAugmenter().random_horizontal_flip(img, p=0.5), lambda img: ImageAugmenter().random_rotation(img, max_angle=15), lambda img: ImageAugmenter().random_brightness(img), lambda img: ImageAugmenter().random_contrast(img), lambda img: ImagePreprocessor().resize(img, image_size), lambda img: ImagePreprocessor().normalize(img) ]) @staticmethod def get_test_transforms(image_size: Tuple[int, int] = (224, 224)): """获取测试时的变换""" return ImageTransforms([ lambda img: ImagePreprocessor().resize(img, image_size), lambda img: ImagePreprocessor().normalize(img) ])