|
|
@@ -3,19 +3,29 @@ import time
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.optim as optim
|
|
|
-from fontTools.misc.timeTools import epoch_diff
|
|
|
from torch.utils.data import DataLoader
|
|
|
import torchvision.transforms as transforms
|
|
|
from torchvision.datasets import ImageFolder
|
|
|
-from torchvision.models import resnet18, ResNet18_Weights,resnet50,ResNet50_Weights, squeezenet1_0, SqueezeNet1_0_Weights,\
|
|
|
- shufflenet_v2_x1_0, ShuffleNet_V2_X1_0_Weights, swin_v2_s, Swin_V2_S_Weights, swin_v2_b, Swin_V2_B_Weights
|
|
|
-import matplotlib.pyplot as plt
|
|
|
-import numpy as np
|
|
|
+from torchvision.models import resnet18,resnet50, squeezenet1_0,shufflenet_v2_x1_0,shufflenet_v2_x2_0
|
|
|
from torch.utils.tensorboard import SummaryWriter # 添加 TensorBoard 支持
|
|
|
from datetime import datetime
|
|
|
import os
|
|
|
from dotenv import load_dotenv
|
|
|
+load_dotenv()
|
|
|
os.environ['CUDA_VISIBLE_DEVICES'] = os.getenv('CUDA_VISIBLE_DEVICES', '0')
|
|
|
+# 读取并打印.env文件中的变量
|
|
|
+def print_env_variables():
|
|
|
+ print("从.env文件加载的变量:")
|
|
|
+ env_vars = {
|
|
|
+ 'PATCH_WIDTH': os.getenv('PATCH_WIDTH'),
|
|
|
+ 'PATCH_HEIGHT': os.getenv('PATCH_HEIGHT'),
|
|
|
+ 'CONFIDENCE_THRESHOLD': os.getenv('CONFIDENCE_THRESHOLD'),
|
|
|
+ 'IMG_INPUT_SIZE': os.getenv('IMG_INPUT_SIZE'),
|
|
|
+ 'WORKERS': os.getenv('WORKERS'),
|
|
|
+ 'CUDA_VISIBLE_DEVICES': os.getenv('CUDA_VISIBLE_DEVICES')
|
|
|
+ }
|
|
|
+ for var, value in env_vars.items():
|
|
|
+ print(f"{var}: {value}")
|
|
|
|
|
|
class Trainer:
|
|
|
def __init__(self, batch_size, train_dir, val_dir, name, checkpoint):
|
|
|
@@ -25,7 +35,19 @@ class Trainer:
|
|
|
self.batch_size = batch_size # 批次大小
|
|
|
self.cls_map = {"0": "non-muddy", "1":"muddy"} # 类别名称映射词典
|
|
|
self.imagenet = True # 是否使用ImageNet预训练权重
|
|
|
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 训练设备
|
|
|
+ # 训练设备 - 优先使用GPU,如果不可用则使用CPU
|
|
|
+ if torch.cuda.is_available():
|
|
|
+ try:
|
|
|
+ # 尝试进行简单的CUDA操作以确认CUDA功能正常
|
|
|
+ _ = torch.zeros(1).cuda()
|
|
|
+ self.device = torch.device("cuda")
|
|
|
+ print("成功检测到CUDA设备,使用GPU进行训练")
|
|
|
+ except Exception as e:
|
|
|
+ print(f"CUDA设备存在问题: {e},回退到CPU")
|
|
|
+ self.device = torch.device("cpu")
|
|
|
+ else:
|
|
|
+ self.device = torch.device("cpu")
|
|
|
+ print("CUDA不可用,使用CPU进行训练")
|
|
|
self.checkpoint = checkpoint
|
|
|
self.__global_step = 0
|
|
|
self.workers = int(os.getenv('WORKERS', 0))
|
|
|
@@ -82,21 +104,17 @@ class Trainer:
|
|
|
def __load_model(self):
|
|
|
"""加载模型结构"""
|
|
|
# 加载模型
|
|
|
+ pretrained = True if self.imagenet else False
|
|
|
if self.name == 'resnet50':
|
|
|
- self.weights = ResNet50_Weights.IMAGENET1K_V2 if self.imagenet else None
|
|
|
- self.model = resnet50(weights=self.weights)
|
|
|
+ self.model = resnet50(pretrained=pretrained)
|
|
|
elif self.name == 'squeezenet':
|
|
|
- self.weights = SqueezeNet1_0_Weights.IMAGENET1K_V1 if self.imagenet else None
|
|
|
- self.model = squeezenet1_0(weights=self.weights)
|
|
|
- elif self.name == 'shufflenet':
|
|
|
- self.weights = ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1 if self.imagenet else None
|
|
|
- self.model = shufflenet_v2_x1_0(weights=self.weights)
|
|
|
- elif self.name == 'swin_v2_s':
|
|
|
- self.weights = Swin_V2_S_Weights.IMAGENET1K_V1 if self.imagenet else None
|
|
|
- self.model = swin_v2_s(weights=self.weights)
|
|
|
- elif self.name == 'swin_v2_b':
|
|
|
- self.weights = Swin_V2_B_Weights.IMAGENET1K_V1 if self.imagenet else None
|
|
|
- self.model = swin_v2_b(weights=self.weights)
|
|
|
+ self.model = squeezenet1_0(pretrained=pretrained)
|
|
|
+ elif self.name == 'shufflenet' or self.name == 'shufflenet-x1':
|
|
|
+ self.model = shufflenet_v2_x1_0(pretrained=pretrained)
|
|
|
+ elif self.name == 'shufflenet-x2':
|
|
|
+ self.model = shufflenet_v2_x2_0(pretrained=False)
|
|
|
+ self.imagenet = False
|
|
|
+ print('shufflenet-x2无预训练权重,重新训练所有权重')
|
|
|
else:
|
|
|
raise ValueError(f"Invalid model name: {self.name}")
|
|
|
# 如果采用预训练的神经网络,就需要冻结特征提取层,只训练最后几层
|
|
|
@@ -240,7 +258,8 @@ class Trainer:
|
|
|
|
|
|
best_val_acc = 0.0
|
|
|
best_val_loss = float('inf')
|
|
|
-
|
|
|
+ # 在你的代码中调用
|
|
|
+ print_env_variables()
|
|
|
print("开始训练...")
|
|
|
for epoch in range(num_epochs):
|
|
|
print(f'Epoch {epoch + 1}/{num_epochs}')
|