# -*- coding: utf-8 -*- """ model_def.py - 卷积自编码器模型定义(部署版) ============================================= 4层卷积自编码器,用于音频异常检测。 此文件与训练环境的model_def.py结构相同。 模型结构: 输入: [B, 1, 64, 504] 瓶颈: [B, 64, 4, 32] (压缩比4:1) 输出: [B, 1, 64, 504] """ import torch import torch.nn as nn from .config import CFG from .utils import get_device class ConvAutoencoder(nn.Module): """ 4层卷积自编码器 编码器: 4次stride=2下采样 解码器: 4次stride=2上采样 参数: in_ch: 输入通道数,默认1 base_ch: 基础通道数,默认8 """ def __init__(self, in_ch=1, base_ch=8): """初始化模型""" super().__init__() # 编码器: 4层下采样 self.encoder = nn.Sequential( # 第1层: 1→8通道 nn.Conv2d(in_ch, base_ch, 3, stride=2, padding=1), nn.BatchNorm2d(base_ch), nn.ReLU(True), # 第2层: 8→16通道 nn.Conv2d(base_ch, base_ch*2, 3, stride=2, padding=1), nn.BatchNorm2d(base_ch*2), nn.ReLU(True), # 第3层: 16→32通道 nn.Conv2d(base_ch*2, base_ch*4, 3, stride=2, padding=1), nn.BatchNorm2d(base_ch*4), nn.ReLU(True), # 第4层: 32→64通道(瓶颈) nn.Conv2d(base_ch*4, base_ch*8, 3, stride=2, padding=1), nn.BatchNorm2d(base_ch*8), nn.ReLU(True), ) # 解码器: 4层上采样 self.decoder = nn.Sequential( # 第1层: 64→32通道 nn.ConvTranspose2d(base_ch*8, base_ch*4, 3, stride=2, padding=1, output_padding=1), nn.BatchNorm2d(base_ch*4), nn.ReLU(True), # 第2层: 32→16通道 nn.ConvTranspose2d(base_ch*4, base_ch*2, 3, stride=2, padding=1, output_padding=1), nn.BatchNorm2d(base_ch*2), nn.ReLU(True), # 第3层: 16→8通道 nn.ConvTranspose2d(base_ch*2, base_ch, 3, stride=2, padding=1, output_padding=1), nn.BatchNorm2d(base_ch), nn.ReLU(True), # 第4层: 8→1通道 nn.ConvTranspose2d(base_ch, in_ch, 3, stride=2, padding=1, output_padding=1), ) def forward(self, x): """ 前向传播 参数: x: 输入tensor [B, 1, H, W] 返回: 重构tensor [B, 1, H', W'] """ # 编码 z = self.encoder(x) # 解码 out = self.decoder(z) return out def load_trained_model(): """ 加载训练好的模型 返回: 元组 (model, device) """ # 获取设备 device = get_device() # 创建模型 model = ConvAutoencoder().to(device) # 加载权重 state = torch.load(CFG.AE_MODEL_PATH, map_location=device) model.load_state_dict(state) # 设置评估模式 model.eval() return model, device