# -*- coding: utf-8 -*-
"""
datasets.py - PyTorch Dataset wrappers
======================================
Provides the dataset classes needed for training and inference.
"""
from pathlib import Path
from typing import List, Optional, Union

import numpy as np
import torch
from torch.utils.data import Dataset

from .config import CFG


class MelNPYDataset(Dataset):
    """
    Mel-spectrogram dataset.

    Loads mel-spectrogram arrays from ``.npy`` files.
    Two initialisation styles are supported:
    1. Pass a directory path (it is scanned recursively for ``*.npy``).
    2. Pass an explicit list/tuple of files.
    """

    def __init__(
        self,
        mel_files_or_dir: Optional[
            Union[str, Path, List[Union[str, Path]]]
        ] = None,
        target_frames: Optional[int] = None,
    ):
        """
        Initialise the dataset.

        Args:
            mel_files_or_dir: Either a list/tuple of ``.npy`` files or a
                directory containing ``.npy`` files. When ``None``, the
                directory ``CFG.AUDIO_DIR`` is scanned instead.
            target_frames: Number of frames every sample is aligned to.
                Any falsy value (``None`` or ``0``) falls back to
                ``CFG.TARGET_FRAMES``.
        """
        if isinstance(mel_files_or_dir, (list, tuple)):
            # Explicit file list: keep the caller's order, only normalise
            # each entry to a Path.
            self.files = [Path(f) for f in mel_files_or_dir]
        else:
            # Directory mode: fall back to the configured directory when
            # nothing was supplied. Sorting makes the index -> file mapping
            # deterministic across runs.
            if mel_files_or_dir is None:
                mel_dir = Path(CFG.AUDIO_DIR)
            else:
                mel_dir = Path(mel_files_or_dir)
            self.files = sorted(mel_dir.glob("**/*.npy"))

        # Frame count used for truncation / padding in __getitem__.
        self.target_frames = target_frames or CFG.TARGET_FRAMES

    def __len__(self) -> int:
        """Return the number of ``.npy`` files in the dataset."""
        return len(self.files)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """
        Load one sample and align it to ``target_frames``.

        Args:
            idx: Index into the file list.

        Returns:
            A float tensor of shape ``[1, n_mels, target_frames]``.

        Raises:
            IndexError: If ``idx`` is outside the file list.
            ValueError: If the loaded array is not 2-D.
        """
        # An out-of-range idx raises IndexError here, which is what the
        # sequence-iteration protocol relies on to stop `for x in dataset`.
        path = self.files[idx]

        # Stored array is expected to be 2-D: [n_mels, frames].
        arr = np.load(path)

        # Reject malformed files explicitly. Without this check a 1-D array
        # would raise IndexError on `shape[1]`, which plain iteration over
        # the dataset interprets as "end of data" and silently stops early.
        if arr.ndim != 2:
            raise ValueError(
                f"{path}: expected a 2-D mel array [n_mels, frames], "
                f"got shape {arr.shape}"
            )

        n_frames = arr.shape[1]
        if n_frames > self.target_frames:
            # Too long: keep only the leading target_frames frames.
            arr = arr[:, :self.target_frames]
        elif n_frames < self.target_frames:
            # Too short: zero-pad at the end of the frame axis.
            pad_width = self.target_frames - n_frames
            arr = np.pad(arr, ((0, 0), (0, pad_width)), mode='constant')

        # Add a leading channel axis -> [1, n_mels, frames].
        arr = np.expand_dims(arr, 0)

        # Convert to a float32 PyTorch tensor.
        return torch.from_numpy(arr).float()