| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119 |
- # -*- coding: utf-8 -*-
- """
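- model_def.py - Convolutional autoencoder model definition (deployment build)
- =============================================
- Four-layer convolutional autoencoder used for audio anomaly detection.
- This file has the same structure as model_def.py in the training environment.
- Model structure:
- Input: [B, 1, 64, 504]
- Bottleneck: [B, 64, 4, 32] (about 4:1 compression: 64*504 values -> 64*4*32 values)
- Output: [B, 1, 64, 512] (each spatial size is rounded up to a multiple of 16, so the 504-wide input comes back 512 wide; crop to 504 before comparing with the input)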
- """
- import torch
- import torch.nn as nn
- from .config import CFG
- from .utils import get_device
class ConvAutoencoder(nn.Module):
    """Four-stage convolutional autoencoder.

    Encoder: four stride-2 3x3 convolutions, each followed by batch
    normalisation and an in-place ReLU. Channel widths go
    in_ch -> base_ch -> 2*base_ch -> 4*base_ch -> 8*base_ch.

    Decoder: four stride-2 3x3 transposed convolutions that mirror the
    encoder widths back down to in_ch. The first three are followed by
    batch normalisation and an in-place ReLU; the last one has no
    normalisation or activation.

    Args:
        in_ch: number of input (and output) channels, default 1.
        base_ch: channel count after the first encoder stage, default 8.
    """

    def __init__(self, in_ch=1, base_ch=8):
        """Create the encoder and decoder stacks."""
        super().__init__()

        # Channel widths along the encoder path; the decoder walks the
        # same list backwards.
        widths = [in_ch] + [base_ch * factor for factor in (1, 2, 4, 8)]

        # Encoder: every stage is Conv -> BatchNorm -> ReLU.
        down = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            down.append(nn.Conv2d(c_in, c_out, 3, stride=2, padding=1))
            down.append(nn.BatchNorm2d(c_out))
            down.append(nn.ReLU(True))
        self.encoder = nn.Sequential(*down)

        # Decoder: ConvTranspose -> BatchNorm -> ReLU, except the final
        # stage, which is a bare transposed convolution.
        mirrored = widths[::-1]
        final_stage = len(mirrored) - 2
        up = []
        for stage, (c_in, c_out) in enumerate(zip(mirrored[:-1], mirrored[1:])):
            up.append(
                nn.ConvTranspose2d(
                    c_in, c_out, 3, stride=2, padding=1, output_padding=1
                )
            )
            if stage != final_stage:
                up.append(nn.BatchNorm2d(c_out))
                up.append(nn.ReLU(True))
        self.decoder = nn.Sequential(*up)

    def forward(self, x):
        """Encode ``x`` and decode it back into a reconstruction.

        Args:
            x: input tensor of shape [B, in_ch, H, W].

        Returns:
            Reconstruction tensor of shape [B, in_ch, H', W']. Each encoder
            stage maps a spatial size n to ceil(n / 2) and each decoder
            stage doubles it, so H' and W' are H and W rounded up to a
            multiple of 16 (for example, W = 504 gives W' = 512).
        """
        return self.decoder(self.encoder(x))
def load_trained_model():
    """Build a ConvAutoencoder, load its saved weights and put it in eval mode.

    The device is whatever ``get_device()`` returns, and the weights are
    read from ``CFG.AE_MODEL_PATH``. The model is constructed with its
    default arguments, so the stored state dict must match that layout.

    Returns:
        Tuple ``(model, device)``.
    """
    device = get_device()
    model = ConvAutoencoder().to(device)

    # map_location makes torch.load place the stored tensors on the chosen
    # device. NOTE(review): torch.load unpickles the file, so the
    # checkpoint path should only ever point at a trusted file.
    weights = torch.load(CFG.AE_MODEL_PATH, map_location=device)
    model.load_state_dict(weights)

    # eval() switches the module to evaluation mode and returns the module.
    return model.eval(), device
|