|
|
@@ -26,6 +26,7 @@ class Predictor:
|
|
|
self.weights_path = weights_path
|
|
|
self.num_classes = num_classes
|
|
|
self.model = None
|
|
|
+ self.use_bias = os.getenv('USE_BIAS', True)
|
|
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
print(f"当前设备: {self.device}")
|
|
|
# 加载模型
|
|
|
@@ -51,13 +52,13 @@ class Predictor:
|
|
|
# 替换最后的分类层以适应新的分类任务
|
|
|
if hasattr(self.model, 'fc'):
|
|
|
# ResNet系列模型
|
|
|
- self.model.fc = nn.Linear(int(self.model.fc.in_features), self.num_classes, bias=False)
|
|
|
+ self.model.fc = nn.Linear(int(self.model.fc.in_features), self.num_classes, bias=self.use_bias)
|
|
|
elif hasattr(self.model, 'classifier'):
|
|
|
# Swin Transformer等模型
|
|
|
- self.model.classifier = nn.Linear(int(self.model.classifier.in_features), self.num_classes, bias=False)
|
|
|
+ self.model.classifier = nn.Linear(int(self.model.classifier.in_features), self.num_classes, bias=self.use_bias)
|
|
|
elif hasattr(self.model, 'head'):
|
|
|
# Swin Transformer使用head层
|
|
|
- self.model.head = nn.Linear(int(self.model.head.in_features), self.num_classes, bias=False)
|
|
|
+ self.model.head = nn.Linear(int(self.model.head.in_features), self.num_classes, bias=self.use_bias)
|
|
|
|
|
|
else:
|
|
|
raise ValueError(f"Model {name} does not have recognizable classifier layer")
|
|
|
@@ -220,9 +221,9 @@ def main():
|
|
|
# 初始化模型实例
|
|
|
# TODO:修改模型网络名称/模型权重路径/视频路径
|
|
|
predictor = Predictor(model_name='shufflenet',
|
|
|
- weights_path=r'/shufflenet.pth',
|
|
|
+ weights_path=r'./shufflenet.pth',
|
|
|
num_classes=2)
|
|
|
- input_path = r'frame_data/train/20251225/4_video_202511211127'
|
|
|
+ input_path = r'D:\code\water_turbidity_det\frame_data\test\20251225\video4_20251129120320_20251129123514'
|
|
|
# 预处理图像
|
|
|
all_imgs = os.listdir(input_path)
|
|
|
all_imgs = [os.path.join(input_path, p) for p in all_imgs if p.split('.')[-1] in ['jpg', 'png']]
|
|
|
@@ -268,7 +269,7 @@ def main():
|
|
|
|
|
|
cv2.waitKey(20)
|
|
|
# 方式1判别
|
|
|
- if len(water_pre_list) > 100:
|
|
|
+ if len(water_pre_list) > 20:
|
|
|
flag = discriminate_ratio(water_pre_list) and flag
|
|
|
water_pre_list = []
|
|
|
print('综合判别结果:', flag)
|