| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970 |
- # -*- coding: utf-8 -*-
- """
- datasets.py - PyTorch Dataset封装
- =================================
- 提供训练和推理所需的数据集类
- """
- from pathlib import Path
- from typing import List, Union
- import numpy as np
- import torch
- from torch.utils.data import Dataset
- from .config import CFG
class MelNPYDataset(Dataset):
    """Dataset of mel spectrograms stored as ``.npy`` files.

    Each file is expected to hold a 2-D array laid out as
    ``[n_mels, frames]``. Samples are truncated or zero-padded along the
    frame axis to a fixed length and returned as float tensors of shape
    ``[1, n_mels, target_frames]``.

    The file set can be supplied in two ways:
      1. a directory path, which is scanned recursively for ``*.npy``;
      2. an explicit list/tuple of file paths, used in the given order.
    """

    def __init__(self,
                 mel_files_or_dir: Union[str, Path, List[Path], None] = None,
                 target_frames: Union[int, None] = None):
        """Collect the sample files and fix the per-sample frame count.

        Args:
            mel_files_or_dir: A list/tuple of ``.npy`` paths, or a directory
                to scan recursively. When ``None``, ``CFG.AUDIO_DIR`` is
                scanned.
            target_frames: Number of frames every sample is aligned to.
                ``None`` (or any other falsy value such as ``0``) falls back
                to ``CFG.TARGET_FRAMES``.
        """
        if isinstance(mel_files_or_dir, (list, tuple)):
            # Explicit file list: keep the caller's order untouched.
            self.files = [Path(f) for f in mel_files_or_dir]
        else:
            # Directory mode; sort so that indexing is deterministic.
            root = Path(CFG.AUDIO_DIR if mel_files_or_dir is None
                        else mel_files_or_dir)
            self.files = sorted(root.glob("**/*.npy"))

        # Falsy fallback is kept on purpose for backward compatibility.
        self.target_frames = target_frames or CFG.TARGET_FRAMES

    def __len__(self):
        """Return the number of sample files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as a float tensor.

        Args:
            idx: Index into the file list.

        Returns:
            A ``torch.float32`` tensor with a leading channel axis of size 1
            and the frame axis (axis 1 of the stored array) truncated or
            zero-padded at the end to ``target_frames``.

        Raises:
            IndexError: If ``idx`` is outside the file list.
            ValueError: If the stored array has fewer than two dimensions.
        """
        path = self.files[idx]
        arr = np.load(path)

        # A 0-D/1-D array would make ``arr.shape[1]`` raise IndexError. Under
        # the legacy ``__getitem__`` iteration protocol (``for x in ds``,
        # ``list(ds)``) that is taken as "end of sequence", silently dropping
        # every remaining sample. Fail loudly with a different exception.
        if arr.ndim < 2:
            raise ValueError(
                f"{path}: expected a 2-D array [n_mels, frames], "
                f"got shape {arr.shape}"
            )

        n_frames = arr.shape[1]
        if n_frames > self.target_frames:
            # Too long: keep only the leading frames.
            arr = arr[:, :self.target_frames]
        elif n_frames < self.target_frames:
            # Too short: zero-pad at the end of the frame axis.
            missing = self.target_frames - n_frames
            arr = np.pad(arr, ((0, 0), (0, missing)), mode='constant')

        # Add the channel axis -> [1, n_mels, frames], then convert.
        arr = np.expand_dims(arr, 0)
        return torch.from_numpy(arr).float()
|