jiyuhang 3 hónapja
szülő
commit
eb66da1583

+ 9 - 1
.env

@@ -1,3 +1,11 @@
+# 图块宽度
 PATCH_WIDTH=256
+# 图块高度
 PATCH_HEIGHT=256
-CONFIDENCE_THRESHOLD=0.90
+CONFIDENCE_THRESHOLD=0.90
+# 图片输入大小
+IMG_INPUT_SIZE=256
+# 工作线程数
+WORKERS=0
+# CUDA设备
+CUDA_VISIBLE_DEVICES=0

+ 4 - 1
.gitignore

@@ -15,7 +15,7 @@
 *.mkv
 *.flv
 *.wmv
-
+*.dat
 # Python 缓存文件
 __pycache__/
 *.pyc
@@ -40,3 +40,6 @@ __pycache__/
 
 # 模型文件
 *.pth
+*.txt
+# 日志文件
+runs/*

+ 0 - 7
data/1_video_202511211128/label.txt

@@ -1,7 +0,0 @@
-0,1280,256,256,-1
-256,1280,256,256,-1
-2304,0,256,256,-1
-2048,0,256,256,-1
-1792,0,256,256,-1
-2304,256,256,256,-1
-2560,256,256,256,-1

+ 0 - 19
data/2_video_202511211128/label.txt

@@ -1,19 +0,0 @@
-1280,1280,256,256,1
-1024,1280,256,256,1
-1024,1024,256,256,1
-1280,1024,256,256,1
-1536,1280,256,256,1
-1536,768,256,256,1
-1280,768,256,256,1
-1536,1024,256,256,1
-1792,256,256,256,-1
-1536,256,256,256,-1
-0,0,256,256,-1
-256,0,256,256,-1
-1024,512,256,256,-1
-768,512,256,256,-1
-512,512,256,256,-1
-256,512,256,256,-1
-256,768,256,256,-1
-512,768,256,256,-1
-0,768,256,256,-1

+ 0 - 8
data/3_video_202511211127/label.txt

@@ -1,8 +0,0 @@
-768,512,256,256,1
-512,512,256,256,1
-256,256,256,256,1
-512,256,256,256,1
-768,0,256,256,-1
-1024,0,256,256,-1
-0,0,256,256,-1
-256,0,256,256,-1

+ 0 - 19
data/4_video_202511211127/label.txt

@@ -1,19 +0,0 @@
-256,1280,256,256,-1
-0,1280,256,256,-1
-768,1280,256,256,1
-768,1024,256,256,1
-768,768,256,256,1
-768,512,256,256,1
-768,256,256,256,1
-768,0,256,256,1
-512,0,256,256,1
-512,256,256,256,1
-512,512,256,256,1
-512,768,256,256,1
-512,1024,256,256,1
-512,1280,256,256,1
-256,512,256,256,1
-256,768,256,256,1
-256,1024,256,256,1
-256,256,256,256,1
-256,0,256,256,-1

+ 0 - 6
data/video1_20251129120104_20251129123102/label.txt

@@ -1,6 +0,0 @@
-0,1280,256,256,-1
-256,1280,256,256,-1
-2304,0,256,256,-1
-1792,0,256,256,-1
-2048,0,256,256,-1
-2304,256,256,256,-1

+ 0 - 23
data/video4_20251129120320_20251129123514/label.txt

@@ -1,23 +0,0 @@
-0,1280,256,256,-1
-256,1280,256,256,-1
-256,0,256,256,-1
-0,768,256,256,-1
-0,1024,256,256,-1
-512,0,256,256,1
-768,0,256,256,1
-768,256,256,256,1
-512,256,256,256,1
-256,256,256,256,1
-256,512,256,256,1
-256,768,256,256,1
-512,1280,256,256,1
-512,1024,256,256,1
-256,1024,256,256,1
-512,768,256,256,1
-768,768,256,256,1
-768,512,256,256,1
-512,512,256,256,1
-768,1280,256,256,1
-768,1024,256,256,1
-1024,256,256,256,1
-1024,512,256,256,1

+ 13 - 0
labelme/utils.py

@@ -10,4 +10,17 @@ def draw_grid(img: np.ndarray, grid_w: int, grid_h:int):
     # 绘制纵向网格线
     for i in range((img_w // grid_w)+1):
         cv2.line(img, (i*grid_w, 0), (i*grid_w, img_h), (0, 255, 0), 2)
+    return img
+
+def draw_predict_grid(img, patches_index, predicted_class, confidence):
+    """在img上绘制预测结果"""
+    for i, (idx_w, idx_h) in enumerate(patches_index):
+        cv2.circle(img, (idx_w, idx_h), 10, (0, 255, 0), -1)
+        text1 = f'cls:{predicted_class[i]}'
+        text2 = f'prob:{confidence[i] * 100:.1f}%'
+        color = (0, 0, 255) if predicted_class[i] else (255, 0, 0)
+        cv2.putText(img, text1, (idx_w, idx_h + 128),
+                    cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 2)
+        cv2.putText(img, text2, (idx_w, idx_h + 128 + 25),
+                    cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 2)
     return img

+ 0 - 0
runs/turbidity_classification/events.out.tfevents.1766071975.240.2967095.0 → runs_dec19/runs/turbidity_classification/events.out.tfevents.1766071975.240.2967095.0


+ 0 - 0
runs/turbidity_classification/events.out.tfevents.1766081826.240.3294804.0 → runs_dec19/runs/turbidity_classification/events.out.tfevents.1766081826.240.3294804.0


+ 140 - 30
test.py

@@ -1,12 +1,15 @@
+import time
+
 import torch
 import torch.nn as nn
 from torchvision import transforms
-from torchvision.models import resnet18, resnet50
+from torchvision.models import resnet18, ResNet18_Weights,resnet50,ResNet50_Weights, squeezenet1_0, SqueezeNet1_0_Weights,\
+    shufflenet_v2_x1_0, ShuffleNet_V2_X1_0_Weights, swin_v2_s, Swin_V2_S_Weights, swin_v2_b, Swin_V2_B_Weights
 import numpy as np
 from PIL import Image
 import os
 import argparse
-from labelme.utils import draw_grid
+from labelme.utils import draw_grid, draw_predict_grid
 import cv2
 import matplotlib.pyplot as plt
 from dotenv import load_dotenv
@@ -29,22 +32,60 @@ class Predictor:
         # 加载模型
         self.load_model()
 
-        # 检查模型结构
-        print(self.model)
-
 
     def load_model(self):
         if self.model is not None:
             return
         print(f"正在加载模型: {self.model_name}")
-        if self.model_name == 'resnet18':
-            self.model = resnet18(weights=None)
-        elif self.model_name == 'resnet50':
-            self.model = resnet50(weights=None)
+        name = self.model_name
+        # 加载模型
+        if name == 'resnet50':
+            self.weights = ResNet50_Weights.IMAGENET1K_V2
+            self.model = resnet50(weights=self.weights)
+        elif name == 'squeezenet':
+            self.weights = SqueezeNet1_0_Weights.IMAGENET1K_V1
+            self.model = squeezenet1_0(weights=self.weights)
+        elif name == 'shufflenet':
+            self.weights = ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1
+            self.model = shufflenet_v2_x1_0(weights=self.weights)
+        elif name == 'swin_v2_s':
+            self.weights = Swin_V2_S_Weights.IMAGENET1K_V1
+            self.model = swin_v2_s(weights=self.weights)
+        elif name == 'swin_v2_b':
+            self.weights = Swin_V2_B_Weights.IMAGENET1K_V1
+            self.model = swin_v2_b(weights=self.weights)
         else:
-            raise ValueError(f"不支持的模型类型: {self.model_name}")
-        # 修改最后的全连接层
-        self.model.fc = nn.Linear(self.model.fc.in_features, self.num_classes)
+            raise ValueError(f"Invalid model name: {name}")
+        # 替换最后的分类层以适应新的分类任务
+        # NOTE(review): the head structure here must match train.py, which this
+        # commit changed to a single nn.Linear classifier. A Sequential head in
+        # test.py makes load_state_dict() fail on checkpoints saved by train.py.
+        if hasattr(self.model, 'fc'):
+            # ResNet/ShuffleNet系列模型
+            self.model.fc = nn.Linear(int(self.model.fc.in_features), self.num_classes, bias=False)
+        elif hasattr(self.model, 'classifier'):
+            # SqueezeNet等模型
+            self.model.classifier = nn.Linear(int(self.model.classifier.in_features), self.num_classes, bias=False)
+        elif hasattr(self.model, 'head'):
+            # Swin Transformer使用head层
+            self.model.head = nn.Linear(int(self.model.head.in_features), self.num_classes, bias=False)
+        else:
+            raise ValueError(f"Model {name} does not have recognizable classifier layer")
+        print(self.model)
         # 加载训练好的权重
         self.model.load_state_dict(torch.load(self.weights_path, map_location=torch.device('cpu')))
         print(f"成功加载模型参数: {self.weights_path}")
@@ -148,20 +189,85 @@ def visualize_prediction(image_path, predicted_class, confidence, class_names):
               f'Confidence: {confidence:.4f}', fontsize=14)
     plt.show()
 
+def get_33_patch(arr:np.ndarray, center_row:int, center_col:int):
+    """以(center_row,center_col)为中心,从arr中取出来3*3区域的数据"""
+    # 边界检查
+    h,w = arr.shape
+    safe_row_up_limit = max(0, center_row-1)
+    safe_row_bottom_limit = min(h, center_row+2)
+    safe_col_left_limit = max(0, center_col-1)
+    safe_col_right_limit = min(w, center_col+2)
+    return arr[safe_row_up_limit:safe_row_bottom_limit, safe_col_left_limit:safe_col_right_limit]
+
+
+def fileter_prediction(predicted_class, confidence, pre_rows, pre_cols, filter_down_limit=3):
+    """预测结果矩阵滤波,九宫格内部存在浑浊水体的数量需要大于filter_down_limit,"""
+    predicted_class_mat = np.resize(predicted_class, (pre_rows, pre_cols))
+    predicted_conf_mat = np.resize(confidence, (pre_rows, pre_cols))
+    new_predicted_class_mat = predicted_class_mat.copy()
+    new_predicted_conf_mat = predicted_conf_mat.copy()
+    for i in range(pre_rows):
+        for j in range(pre_cols):
+            if (1. - predicted_class_mat[i, j]) > 0.1:
+                continue  # 跳过背景类
+            core_region = get_33_patch(predicted_class_mat, i, j)
+            if np.sum(core_region) < filter_down_limit:
+                new_predicted_class_mat[i, j] = 0  #  重置为背景类
+                new_predicted_conf_mat[i, j] = 1.0
+    return new_predicted_conf_mat.flatten(), new_predicted_class_mat.flatten()
+
+def discriminate_ratio(water_pre_list:list):
+    # 方式一:60%以上的帧存在浑浊水体
+    water_pre_arr = np.array(water_pre_list, dtype=np.float32)
+    water_pre_arr_mean = np.mean(water_pre_arr, axis=0)
+    bad_water = np.array(water_pre_arr_mean >= 0.6, dtype=np.int32)
+    bad_flag = np.sum(bad_water, dtype=np.int32)
+    print(f'浑浊比例方式:该时间段是否存在浑浊水体:{bool(bad_flag)}')
+    return bad_flag
+
+
+def discriminate_cont(pre_class_arr, continuous_count_mat):
+    """连续帧判别"""
+    positive_index = np.array(pre_class_arr,dtype=np.int32) > 0
+    negative_index = np.array(pre_class_arr,dtype=np.int32) == 0
+    # 给负样本区域置零
+    continuous_count_mat[negative_index] = 0
+    # 给正样本区域加1
+    continuous_count_mat[positive_index] += 1
+    # 判断浑浊
+    bad_flag = np.max(continuous_count_mat) > 30
+    if bad_flag:
+        print(f'连续帧方式:该时间段是否存在浑浊水体:{bool(bad_flag)}')
+    return bad_flag
 
 def main():
 
     # 初始化模型实例
-    predictor = Predictor(model_name='resnet50',
-                          weights_path=r'D:\code\water_turbidity_det\resnet50_best_model_acc.pth',
+    # TODO:修改模型网络名称/模型权重路径/视频路径
+    predictor = Predictor(model_name='shufflenet',
+                          weights_path=r'D:\code\water_turbidity_det\shufflenet_best_model_acc.pth',
                           num_classes=2)
-    input_path = r'D:\code\water_turbidity_det\data\video1_20251129120104_20251129123102'
+    input_path = r'D:\code\water_turbidity_det\data\4_video_202511211127'
     # 预处理图像
     all_imgs = os.listdir(input_path)
     all_imgs = [os.path.join(input_path, p) for p in all_imgs if p.split('.')[-1] in ['jpg', 'png']]
+    image = Image.open(all_imgs[0]).convert('RGB')
+    # 将预测结果reshape为矩阵时的行列数量
+    pre_rows = image.height // patch_h + 1
+    pre_cols = image.width // patch_w + 1
+    # 图像显示时resize的尺寸
+    resized_img_h = image.height // 2
+    resized_img_w = image.width // 2
+    # 预测每张图像
+
+    water_pre_list = []
+    continuous_count_mat = np.zeros(pre_rows*pre_cols, dtype=np.int32)
+    flag = False
     for img_path in all_imgs:
         image = Image.open(img_path).convert('RGB')
+        # 预处理
         patches_index, image_tensor = preprocess_image(image)
+        # 推理
         confidence, predicted_class  = predictor.predict(image_tensor)
         # 第一层虚警抑制,置信度过滤,低于阈值将会被忽略
         for i in range(len(confidence)):
@@ -169,27 +275,31 @@ def main():
                 confidence[i] = 1.0
                 predicted_class[i] = 0
         # 第二层虚警抑制,空间滤波
-
-        predicted_class_mat = np.resize(predicted_class, (image.height//patch_h+1, image.width//patch_w+1))
+        # 在此处添加过滤逻辑
+        new_confidence, new_predicted_class = fileter_prediction(predicted_class, confidence, pre_rows, pre_cols, filter_down_limit=3)
         # 可视化预测结果
         image = cv2.imread(img_path)
         image = draw_grid(image, patch_w, patch_h)
-        dw = patch_w // 2
-        dh = patch_h // 2
-        resized_img_h = image.shape[0] // 2
-        resized_img_w = image.shape[1] // 2
-        for i, (idx_w, idx_h) in enumerate(patches_index):
-            cv2.circle(image, (idx_w, idx_h), 10, (0, 255, 0), -1)
-            text1 = f'cls:{predicted_class[i]}'
-            text2 = f'prob:{confidence[i]*100:.1f}%'
-            color = (0, 0, 255) if predicted_class[i] else (255, 0, 0)
-            cv2.putText(image, text1, (idx_w, idx_h + dh),
-                        cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 2)
-            cv2.putText(image, text2, (idx_w, idx_h + dh +25),
-                        cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 2)
+        image = draw_predict_grid(image, patches_index, predicted_class, confidence)
+
+        new_image = cv2.imread(img_path)
+        new_image = draw_grid(new_image, patch_w, patch_h)
+        new_image = draw_predict_grid(new_image, patches_index, new_predicted_class, new_confidence)
         image = cv2.resize(image, (resized_img_w, resized_img_h))
+        new_img = cv2.resize(new_image, (resized_img_w, resized_img_h))
+
         cv2.imshow('image', image)
+        cv2.imshow('image_filter', new_img)
+
         cv2.waitKey(20)
+        # 方式1判别: ratio-based decision, evaluated once per 100-frame window
+        if len(water_pre_list) > 100:
+            # was `and flag`: flag starts False, so this branch could never raise the alarm
+            flag = bool(discriminate_ratio(water_pre_list)) or flag
+            water_pre_list = []
+            print('综合判别结果:', flag)
+        water_pre_list.append(new_predicted_class)
+        # 方式2判别: consecutive-frame decision; `or` keeps an alarm already raised
+        flag = discriminate_cont(new_predicted_class, continuous_count_mat) or flag
 
 if __name__ == "__main__":
     main()

+ 118 - 110
train.py

@@ -1,9 +1,9 @@
 # 微调pytorch的预训练模型,在自己的数据上训练,完成分类任务。
 import time
-
 import torch
 import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import DataLoader
 import torchvision.transforms as transforms
 from torchvision.datasets import ImageFolder
@@ -14,31 +14,40 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter  # 添加 TensorBoard 支持
 from datetime import datetime
 import os
-os.environ['CUDA_VISIBLE_DEVICES'] = '1'
+from dotenv import load_dotenv
+load_dotenv()  # must run before os.getenv, otherwise .env values (CUDA_VISIBLE_DEVICES/IMG_INPUT_SIZE/WORKERS) are never loaded
+os.environ['CUDA_VISIBLE_DEVICES'] = os.getenv('CUDA_VISIBLE_DEVICES', '0')
 
 class Trainer:
     def __init__(self, batch_size, train_dir, val_dir, name, checkpoint):
-        # 初始化 TensorBoard writer
-        self.name = name
+        # 定义一些参数
+        self.name = name  # 采用的模型名称
+        self.img_size = int(os.getenv('IMG_INPUT_SIZE', 224))  # 输入图片尺寸
+        self.batch_size = batch_size  # 批次大小
+        self.cls_map = {"0": "non-muddy", "1":"muddy"}  # 类别名称映射词典
+        self.imagenet = True  # 是否使用ImageNet预训练权重
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 训练设备
         self.checkpoint = checkpoint
-        # 获取当前时间戳
-        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
-        log_dir = f'runs/turbidity_{self.name}_{timestamp}'
-        self.writer = SummaryWriter(log_dir)
-
-        # 定义数据增强和处理
+        self.__global_step = 0
+        self.workers = int(os.getenv('WORKERS', 0))
+        # 创建日志目录
+        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")  # 获取当前时间戳
+        log_dir = f'runs/turbidity_{self.name}_{timestamp}'# 按照模型的名称和时间戳创建日志目录
+        self.writer = SummaryWriter(log_dir)  # 创建 TensorBoard writer
+
+        # 定义数据增强和预处理层
         self.train_transforms = transforms.Compose([
-            transforms.Resize((256, 256)),  # 调整图像大小为256x256 (ResNet输入尺寸)
-            transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转,增加数据多样性
+            transforms.Resize((self.img_size, self.img_size)),  # 调整图像大小为256x256 (ResNet输入尺寸)
+            transforms.RandomHorizontalFlip(p=0.3),  # 随机水平翻转,增加数据多样性
+            transforms.RandomVerticalFlip(p=0.3), # 随机垂直翻转,增加数据多样性
+            transforms.RandomGrayscale(p=0.25), # 随机灰度化,增加数据多样性
             transforms.RandomRotation(10),  # 随机旋转±10度
             transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0, hue=0),  # 颜色抖动
             transforms.ToTensor(),  # 转换为tensor并归一化到[0,1]
             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet标准化
         ])
-
         # 测试集基础变换
         self.val_transforms = transforms.Compose([
-            transforms.Resize((256, 256)),
+            transforms.Resize((self.img_size, self.img_size)),
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         ])
@@ -46,85 +55,93 @@ class Trainer:
         # 创建数据集对象
         self.train_dataset = ImageFolder(root=train_dir, transform=self.train_transforms)
         self.val_dataset = ImageFolder(root=val_dir, transform=self.val_transforms)
-
         # 创建数据加载器 (Windows环境下设置num_workers=0避免多进程问题)
-        self.batch_size = batch_size
-        self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, num_workers=10)
-        self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, num_workers=10)
+        self.train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.workers)  # 可迭代对象,返回(inputs_tensor, label)
+        self.val_loader = DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.workers)
         # 获取类别数量
         self.num_classes = len(self.train_dataset.classes)
-        print(f"发现 {self.num_classes} 个类别: {self.train_dataset.classes}")
+        print(f"自动发现 {self.num_classes} 个类别:")
+        # 打印类别和名称
+        for cls in self.train_dataset.classes:
+            print(f"{cls}: {self.cls_map.get(cls,'None')}")
+
+        # 创建模型
+        self.model = None
+        self.model = self.__load_model()
+        # 定义损失函数
+        self.loss = nn.CrossEntropyLoss()  # 多分类常用的交叉熵损失
+
+        # 定义优化器
+        # 只更新requires_grad=True的参数
+        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3, weight_decay=1e-4)
 
+        # 基于验证损失动态调整,更智能
+        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
+            self.optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-7,cooldown=2
+        )
+    def __load_model(self):
+        """加载模型结构"""
         # 加载模型
-        if name == 'resnet50':
-            self.weights = ResNet50_Weights.IMAGENET1K_V2
+        if self.name == 'resnet50':
+            self.weights = ResNet50_Weights.IMAGENET1K_V2 if self.imagenet else None
             self.model = resnet50(weights=self.weights)
-        elif name == 'squeezenet':
-            self.weights = SqueezeNet1_0_Weights.IMAGENET1K_V1
+        elif self.name == 'squeezenet':
+            self.weights = SqueezeNet1_0_Weights.IMAGENET1K_V1 if self.imagenet else None
             self.model = squeezenet1_0(weights=self.weights)
-        elif name == 'shufflenet':
-            self.weights = ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1
+        elif self.name == 'shufflenet':
+            self.weights = ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1 if self.imagenet else None
             self.model = shufflenet_v2_x1_0(weights=self.weights)
-        elif name == 'swin_v2_s':
-            self.weights = Swin_V2_S_Weights.IMAGENET1K_V1
+        elif self.name == 'swin_v2_s':
+            self.weights = Swin_V2_S_Weights.IMAGENET1K_V1 if self.imagenet else None
             self.model = swin_v2_s(weights=self.weights)
-        elif name == 'swin_v2_b':
-            self.weights = Swin_V2_B_Weights.IMAGENET1K_V1
+        elif self.name == 'swin_v2_b':
+            self.weights = Swin_V2_B_Weights.IMAGENET1K_V1 if self.imagenet else None
             self.model = swin_v2_b(weights=self.weights)
         else:
-            raise ValueError(f"Invalid model name: {name}")
-        print(self.model)
-        # 冻结特征提取层,只训练最后几层,
-        for param in self.model.parameters():
-            param.requires_grad = False
-
+            raise ValueError(f"Invalid model name: {self.name}")
+        # 如果采用预训练的神经网络,就需要冻结特征提取层,只训练最后几层
+        if self.imagenet:
+            for param in self.model.parameters():
+                param.requires_grad = False
         # 替换最后的分类层以适应新的分类任务
         if hasattr(self.model, 'fc'):
             # ResNet系列模型
-            self.model.fc = nn.Sequential(
-                nn.Linear(int(self.model.fc.in_features), int(self.model.fc.in_features) // 2, bias=True),
-                nn.ReLU(inplace=True),
-                nn.Dropout(0.5),
-                nn.Linear(int(self.model.fc.in_features) // 2, self.num_classes, bias=False)
-            )
+            self.model.fc = nn.Linear(int(self.model.fc.in_features), self.num_classes, bias=False)
+            # self.model.fc = nn.Sequential(
+            #     nn.Linear(int(self.model.fc.in_features), int(self.model.fc.in_features) // 2, bias=False),
+            #     nn.ReLU(inplace=True),
+            #     nn.Dropout(0.5),
+            #     nn.Linear(int(self.model.fc.in_features) // 2, self.num_classes, bias=False)
+            # )
         elif hasattr(self.model, 'classifier'):
             # Swin Transformer等模型
-            self.model.classifier = nn.Sequential(
-                nn.Linear(int(self.model.classifier.in_features), int(self.model.classifier.in_features) // 2,
-                          bias=True),
-                nn.ReLU(inplace=True),
-                nn.Dropout(0.5),
-                nn.Linear(int(self.model.classifier.in_features) // 2, self.num_classes, bias=False)
-            )
+            self.model.classifier = nn.Linear(int(self.model.classifier.in_features), self.num_classes, bias=False)
+            # self.model.classifier = nn.Sequential(
+            #     nn.Linear(int(self.model.classifier.in_features), int(self.model.classifier.in_features) // 2,
+            #               bias=True),
+            #     nn.ReLU(inplace=True),
+            #     nn.Dropout(0.5),
+            #     nn.Linear(int(self.model.classifier.in_features) // 2, self.num_classes, bias=False)
+            # )
         elif hasattr(self.model, 'head'):
             # Swin Transformer使用head层
-            in_features = self.model.head.in_features
-            self.model.head = nn.Sequential(
-                nn.Linear(int(in_features), int(in_features) // 2, bias=True),
-                nn.ReLU(inplace=True),
-                nn.Dropout(0.5),
-                nn.Linear(int(in_features) // 2, self.num_classes, bias=False)
-            )
+            self.model.head = nn.Linear(int(self.model.head.in_features), self.num_classes, bias=False)
+            # in_features = self.model.head.in_features
+            # self.model.head = nn.Sequential(
+            #     nn.Linear(int(in_features), int(in_features) // 2, bias=True),
+            #     nn.ReLU(inplace=True),
+            #     nn.Dropout(0.5),
+            #     nn.Linear(int(in_features) // 2, self.num_classes, bias=False)
+            # )
         else:
-            raise ValueError(f"Model {name} does not have recognizable classifier layer")
-
-        # 将模型移动到GPU
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            raise ValueError(f"Model {self.name} does not have recognizable classifier layer")
+        print(self.model)
+        print(f'模型{self.name}结构已经加载,移动到设备{self.device}')
+        # 将模型移动到GPU/cpu
         self.model = self.model.to(self.device)
+        return self.model
 
-        # 定义损失函数
-        self.loss = nn.CrossEntropyLoss()  # 多分类常用的交叉熵损失
-
-        # 定义优化器
-        # 只更新requires_grad=True的参数
-        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3, weight_decay=1e-4)
-
-        # 基于验证损失动态调整,更智能
-        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
-            self.optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-7
-        )
-
-    def train_model(self):
+    def train_step(self):
         """
         单轮训练函数
 
@@ -135,22 +152,22 @@ class Trainer:
             accuracy: 准确率
         """
         self.model.train()  # 设置模型为训练模式(启用dropout/batchnorm等)
-        running_loss = 0.0
-        correct_predictions = 0
-        total_samples = 0
+        epoch_loss = 0.0  # epoch损失
+        correct_predictions = 0.0  # 预测正确的样本数
+        total_samples = 0.0  # 已经训练的总样本
 
         # 遍历训练数据
         for inputs, labels in self.train_loader:
             # 将数据移到指定设备上
-            inputs = inputs.to(self.device)
-            labels = labels.to(self.device)
+            inputs = inputs.to(self.device)  # b c h w
+            labels = labels.to(self.device)  # b,
 
             # 清零梯度缓存
             self.optimizer.zero_grad()
 
             # 前向传播
-            outputs = self.model(inputs)
-            loss = self.loss(outputs, labels)
+            outputs = self.model(inputs)  # b, 2
+            loss = self.loss(outputs, labels) # 标量
 
             # 反向传播
             loss.backward()
@@ -159,17 +176,20 @@ class Trainer:
             self.optimizer.step()
 
             # 统计信息
-            running_loss += loss.item() * inputs.size(0)
-            _, predicted = torch.max(outputs.data, 1)
-            total_samples += labels.size(0)
-            correct_predictions += (predicted == labels).sum().item()
-
-        epoch_loss = running_loss / len(self.train_loader.dataset)
+            batch_loss = loss.item() * inputs.size(0)  # 批损失
+            self.writer.add_scalar('Batch_Loss/Train', batch_loss, self.__global_step)
+            print(f'Training | Batch Loss: {batch_loss:.4f}\t', end='\r',flush=True)
+            epoch_loss += batch_loss  # epoch损失
+            _, predicted = torch.max(outputs.data, 1) # predicted b,存储的是预测的类别,最大值的位置
+            total_samples += labels.size(0)  # 统计总样本数量
+            correct_predictions += (predicted == labels).sum().item()  # 统计预测正确的样本数
+
+        epoch_loss = epoch_loss / len(self.train_loader.dataset)
         epoch_acc = correct_predictions / total_samples
         return epoch_loss, epoch_acc
 
 
-    def validate_model(self):
+    def val_step(self):
         """
         验证模型性能
 
@@ -180,9 +200,9 @@ class Trainer:
         """
         self.model.eval()  # 设置模型为评估模式(关闭dropout/batchnorm等)
 
-        running_loss = 0.0
-        correct_predictions = 0
-        total_samples = 0
+        epoch_loss = 0.0
+        correct_predictions = 0.
+        total_samples = 0.
 
         # 不计算梯度,提高推理速度
         with torch.no_grad():
@@ -193,12 +213,12 @@ class Trainer:
                 outputs = self.model(inputs)
                 loss = self.loss(outputs, labels)
 
-                running_loss += loss.item() * inputs.size(0)
+                epoch_loss += loss.item() * inputs.size(0)
                 _, predicted = torch.max(outputs.data, 1)
                 total_samples += labels.size(0)
                 correct_predictions += (predicted == labels).sum().item()
 
-        epoch_loss = running_loss / len(self.val_loader.dataset)
+        epoch_loss = epoch_loss / len(self.val_loader.dataset)  # 平均损失
         epoch_acc = correct_predictions / total_samples
 
         return epoch_loss, epoch_acc
@@ -217,11 +237,6 @@ class Trainer:
             val_losses: 每轮验证损失
             val_accuracies: 每轮验证准确率
         """
-        # 存储训练过程中的指标
-        train_losses = []
-        train_accuracies = []
-        val_losses = []
-        val_accuracies = []
 
         best_val_acc = 0.0
         best_val_loss = float('inf')
@@ -231,12 +246,12 @@ class Trainer:
             print(f'Epoch {epoch + 1}/{num_epochs}')
             print('-' * 20)
 
-            # 训练阶段
-            train_loss, train_acc = self.train_model()
+            # 单步训练
+            train_loss, train_acc = self.train_step()
             print(f'Train Loss: {train_loss:.4f} Acc: {train_acc:.4f}')
 
             # 验证阶段
-            val_loss, val_acc = self.validate_model()
+            val_loss, val_acc = self.val_step()
             print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')
 
             # 学习率调度
@@ -249,12 +264,6 @@ class Trainer:
             self.writer.add_scalar('Accuracy/Validation', val_acc, epoch)
             self.writer.add_scalar('Learning Rate', self.scheduler.get_last_lr()[0], epoch)
 
-            # 记录指标
-            train_losses.append(train_loss)
-            train_accuracies.append(train_acc)
-            val_losses.append(val_loss)
-            val_accuracies.append(val_acc)
-
             # 保存最佳模型 (基于验证准确率)
             if val_acc > best_val_acc:
                 best_val_acc = val_acc
@@ -262,7 +271,6 @@ class Trainer:
                 print(f"保存了新的最佳准确率模型,验证准确率: {best_val_acc:.4f}")
             
             # 保存最低验证损失模型
-            
             if val_loss < best_val_loss:
                 best_val_loss = val_loss
                 torch.save(self.model.state_dict(), f'{self.name}_best_model_loss.pth')
@@ -273,7 +281,7 @@ class Trainer:
         self.writer.close()
         
         print(f"训练完成! 最佳验证准确率: {best_val_acc:.4f}, 最低验证损失: {best_val_loss:.4f}")
-        return train_losses, train_accuracies, val_losses, val_accuracies
+        return 1
 
 if __name__ == '__main__':
     # 开始训练
@@ -281,12 +289,12 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser('预训练模型调参')
     parser.add_argument('--train_dir',default='./label_data/train',help='help')
     parser.add_argument('--val_dir', default='./label_data/test',help='help')
-    parser.add_argument('--model', default='resnet18',help='help')
+    parser.add_argument('--model', default='shufflenet',help='help')
     args = parser.parse_args()
     num_epochs = 100
-    trainer = Trainer(batch_size=64,
+    trainer = Trainer(batch_size=128,
                       train_dir=args.train_dir,
                       val_dir=args.val_dir,
                       name=args.model,
                       checkpoint=False)
-    train_losses, train_accuracies, val_losses, val_accuracies = trainer.train_and_validate(num_epochs)
+    trainer.train_and_validate(num_epochs)