loader.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
"""
General-purpose data loaders.
"""
import logging
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

import numpy as np
import pandas as pd
# Module-level logger; every loader class in this file reports through it.
logger = logging.getLogger(__name__)
  10. class DataLoader:
  11. """通用数据加载器"""
  12. def __init__(self, data_path: Union[str, Path]):
  13. """初始化数据加载器"""
  14. self.data_path = Path(data_path)
  15. def load_csv(self, **kwargs) -> pd.DataFrame:
  16. """加载CSV文件"""
  17. try:
  18. df = pd.read_csv(self.data_path, **kwargs)
  19. logger.info(f"成功加载CSV文件: {self.data_path}")
  20. return df
  21. except Exception as e:
  22. logger.error(f"加载CSV文件失败: {e}")
  23. raise
  24. def load_json(self, **kwargs) -> pd.DataFrame:
  25. """加载JSON文件"""
  26. try:
  27. df = pd.read_json(self.data_path, **kwargs)
  28. logger.info(f"成功加载JSON文件: {self.data_path}")
  29. return df
  30. except Exception as e:
  31. logger.error(f"加载JSON文件失败: {e}")
  32. raise
  33. def load_parquet(self, **kwargs) -> pd.DataFrame:
  34. """加载Parquet文件"""
  35. try:
  36. df = pd.read_parquet(self.data_path, **kwargs)
  37. logger.info(f"成功加载Parquet文件: {self.data_path}")
  38. return df
  39. except Exception as e:
  40. logger.error(f"加载Parquet文件失败: {e}")
  41. raise
  42. def save_csv(self, df: pd.DataFrame, **kwargs) -> None:
  43. """保存为CSV文件"""
  44. try:
  45. df.to_csv(self.data_path, **kwargs)
  46. logger.info(f"成功保存CSV文件: {self.data_path}")
  47. except Exception as e:
  48. logger.error(f"保存CSV文件失败: {e}")
  49. raise
  50. def save_json(self, df: pd.DataFrame, **kwargs) -> None:
  51. """保存为JSON文件"""
  52. try:
  53. df.to_json(self.data_path, **kwargs)
  54. logger.info(f"成功保存JSON文件: {self.data_path}")
  55. except Exception as e:
  56. logger.error(f"保存JSON文件失败: {e}")
  57. raise
  58. def save_parquet(self, df: pd.DataFrame, **kwargs) -> None:
  59. """保存为Parquet文件"""
  60. try:
  61. df.to_parquet(self.data_path, **kwargs)
  62. logger.info(f"成功保存Parquet文件: {self.data_path}")
  63. except Exception as e:
  64. logger.error(f"保存Parquet文件失败: {e}")
  65. raise
  66. class ImageDataLoader:
  67. """图像数据加载器"""
  68. def __init__(self, data_dir: Union[str, Path]):
  69. """初始化图像数据加载器"""
  70. self.data_dir = Path(data_dir)
  71. def load_images(self, extensions: list = ['.jpg', '.jpeg', '.png', '.bmp']) -> list:
  72. """加载图像文件路径"""
  73. image_paths = []
  74. for ext in extensions:
  75. image_paths.extend(self.data_dir.glob(f'**/*{ext}'))
  76. image_paths.extend(self.data_dir.glob(f'**/*{ext.upper()}'))
  77. logger.info(f"找到 {len(image_paths)} 个图像文件")
  78. return sorted(image_paths)
  79. def create_dataset_info(self, image_paths: list, label_func: callable = None) -> pd.DataFrame:
  80. """创建数据集信息DataFrame"""
  81. data = []
  82. for img_path in image_paths:
  83. if label_func:
  84. label = label_func(img_path)
  85. else:
  86. # 默认从文件夹名获取标签
  87. label = img_path.parent.name
  88. data.append({
  89. 'image_path': str(img_path),
  90. 'label': label,
  91. 'filename': img_path.name
  92. })
  93. return pd.DataFrame(data)
  94. class TextDataLoader:
  95. """文本数据加载器"""
  96. def __init__(self, data_path: Union[str, Path]):
  97. """初始化文本数据加载器"""
  98. self.data_path = Path(data_path)
  99. def load_text(self, encoding: str = 'utf-8') -> str:
  100. """加载文本文件"""
  101. try:
  102. with open(self.data_path, 'r', encoding=encoding) as f:
  103. text = f.read()
  104. logger.info(f"成功加载文本文件: {self.data_path}")
  105. return text
  106. except Exception as e:
  107. logger.error(f"加载文本文件失败: {e}")
  108. raise
  109. def load_lines(self, encoding: str = 'utf-8') -> list:
  110. """按行加载文本文件"""
  111. try:
  112. with open(self.data_path, 'r', encoding=encoding) as f:
  113. lines = f.readlines()
  114. logger.info(f"成功加载文本文件,共 {len(lines)} 行: {self.data_path}")
  115. return lines
  116. except Exception as e:
  117. logger.error(f"加载文本文件失败: {e}")
  118. raise