datasets.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
# -*- coding: utf-8 -*-
"""
datasets.py - PyTorch Dataset wrappers
======================================
Provides the dataset classes needed for training and inference.
"""
  7. from pathlib import Path
  8. from typing import List, Union
  9. import numpy as np
  10. import torch
  11. from torch.utils.data import Dataset
  12. from .config import CFG
  13. class MelNPYDataset(Dataset):
  14. """
  15. Mel频谱数据集
  16. 从npy文件加载mel频谱数据。
  17. 支持两种初始化方式:
  18. 1. 传入目录路径(自动扫描)
  19. 2. 传入文件列表
  20. """
  21. def __init__(self, mel_files_or_dir: Union[Path, List[Path]] = None,
  22. target_frames: int = None):
  23. """
  24. 初始化
  25. 参数:
  26. mel_files_or_dir: npy文件列表或包含npy文件的目录
  27. target_frames: 目标帧数,用于对齐(默认使用CFG.TARGET_FRAMES)
  28. """
  29. # 处理文件列表或目录
  30. if mel_files_or_dir is None:
  31. mel_dir = Path(CFG.AUDIO_DIR)
  32. self.files = sorted(mel_dir.glob("**/*.npy"))
  33. elif isinstance(mel_files_or_dir, (list, tuple)):
  34. self.files = [Path(f) for f in mel_files_or_dir]
  35. else:
  36. mel_dir = Path(mel_files_or_dir)
  37. self.files = sorted(mel_dir.glob("**/*.npy"))
  38. # 目标帧数
  39. self.target_frames = target_frames or CFG.TARGET_FRAMES
  40. def __len__(self):
  41. # 返回数据集大小
  42. return len(self.files)
  43. def __getitem__(self, idx):
  44. # 加载npy文件,形状为 [n_mels, frames]
  45. arr = np.load(self.files[idx])
  46. # 对齐帧数
  47. if arr.shape[1] > self.target_frames:
  48. # 截断
  49. arr = arr[:, :self.target_frames]
  50. elif arr.shape[1] < self.target_frames:
  51. # 填充
  52. pad_width = self.target_frames - arr.shape[1]
  53. arr = np.pad(arr, ((0, 0), (0, pad_width)), mode='constant')
  54. # 增加通道维度,变为 [1, n_mels, frames]
  55. arr = np.expand_dims(arr, 0)
  56. # 转换为PyTorch tensor
  57. return torch.from_numpy(arr).float()