model_def.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. # -*- coding: utf-8 -*-
  2. """
  3. model_def.py - 卷积自编码器模型定义(部署版)
  4. =============================================
  5. 4层卷积自编码器,用于音频异常检测。
  6. 此文件与训练环境的model_def.py结构相同。
  7. 模型结构:
  8. 输入: [B, 1, 64, 504]
  9. 瓶颈: [B, 64, 4, 32] (压缩比4:1)
  10. 输出: [B, 1, 64, 504]
  11. """
  12. import torch
  13. import torch.nn as nn
  14. from .config import CFG
  15. from .utils import get_device
  16. class ConvAutoencoder(nn.Module):
  17. """
  18. 4层卷积自编码器
  19. 编码器: 4次stride=2下采样
  20. 解码器: 4次stride=2上采样
  21. 参数:
  22. in_ch: 输入通道数,默认1
  23. base_ch: 基础通道数,默认8
  24. """
  25. def __init__(self, in_ch=1, base_ch=8):
  26. """初始化模型"""
  27. super().__init__()
  28. # 编码器: 4层下采样
  29. self.encoder = nn.Sequential(
  30. # 第1层: 1→8通道
  31. nn.Conv2d(in_ch, base_ch, 3, stride=2, padding=1),
  32. nn.BatchNorm2d(base_ch),
  33. nn.ReLU(True),
  34. # 第2层: 8→16通道
  35. nn.Conv2d(base_ch, base_ch*2, 3, stride=2, padding=1),
  36. nn.BatchNorm2d(base_ch*2),
  37. nn.ReLU(True),
  38. # 第3层: 16→32通道
  39. nn.Conv2d(base_ch*2, base_ch*4, 3, stride=2, padding=1),
  40. nn.BatchNorm2d(base_ch*4),
  41. nn.ReLU(True),
  42. # 第4层: 32→64通道(瓶颈)
  43. nn.Conv2d(base_ch*4, base_ch*8, 3, stride=2, padding=1),
  44. nn.BatchNorm2d(base_ch*8),
  45. nn.ReLU(True),
  46. )
  47. # 解码器: 4层上采样
  48. self.decoder = nn.Sequential(
  49. # 第1层: 64→32通道
  50. nn.ConvTranspose2d(base_ch*8, base_ch*4, 3, stride=2, padding=1, output_padding=1),
  51. nn.BatchNorm2d(base_ch*4),
  52. nn.ReLU(True),
  53. # 第2层: 32→16通道
  54. nn.ConvTranspose2d(base_ch*4, base_ch*2, 3, stride=2, padding=1, output_padding=1),
  55. nn.BatchNorm2d(base_ch*2),
  56. nn.ReLU(True),
  57. # 第3层: 16→8通道
  58. nn.ConvTranspose2d(base_ch*2, base_ch, 3, stride=2, padding=1, output_padding=1),
  59. nn.BatchNorm2d(base_ch),
  60. nn.ReLU(True),
  61. # 第4层: 8→1通道
  62. nn.ConvTranspose2d(base_ch, in_ch, 3, stride=2, padding=1, output_padding=1),
  63. )
  64. def forward(self, x):
  65. """
  66. 前向传播
  67. 参数:
  68. x: 输入tensor [B, 1, H, W]
  69. 返回:
  70. 重构tensor [B, 1, H', W']
  71. """
  72. # 编码
  73. z = self.encoder(x)
  74. # 解码
  75. out = self.decoder(z)
  76. return out
  77. def load_trained_model():
  78. """
  79. 加载训练好的模型
  80. 返回:
  81. 元组 (model, device)
  82. """
  83. # 获取设备
  84. device = get_device()
  85. # 创建模型
  86. model = ConvAutoencoder().to(device)
  87. # 加载权重
  88. state = torch.load(CFG.AE_MODEL_PATH, map_location=device)
  89. model.load_state_dict(state)
  90. # 设置评估模式
  91. model.eval()
  92. return model, device