#!/usr/bin/env python3
"""Build a small two-convolution model and save its initial weights.

Running this file as a script writes the model's ``state_dict`` to
``simple.pth`` in the current working directory.
"""
import torch


class SimpleModel(torch.nn.Module):
    """Two stacked 3x3 convolutions with a residual add around the second.

    ``m1`` maps 3 input channels to 8 with no padding, so it shrinks each
    spatial dimension by 2. ``m2`` keeps 8 channels and uses padding 1, so
    its output has the same shape as its input; that is what makes the
    elementwise ``y0 + y1`` in :meth:`forward` valid.
    """

    def __init__(self) -> None:
        super().__init__()
        # Positional args: Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.m1 = torch.nn.Conv2d(3, 8, 3, 1, 0)
        self.m2 = torch.nn.Conv2d(8, 8, 3, 1, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``m1(x) + m2(m1(x))``.

        Assumes ``x`` is a batched image tensor of shape ``(N, 3, H, W)``
        with ``H, W >= 3``; the result then has shape ``(N, 8, H-2, W-2)``.
        """
        y0 = self.m1(x)
        y1 = self.m2(y0)
        # Residual connection: m2 preserves y0's shape, so the add is valid.
        y2 = y0 + y1
        return y2


def main(path: str = "simple.pth") -> None:
    """Create a freshly initialised SimpleModel and save its weights to *path*."""
    # Weights come from the default random initialisation (no seed is set),
    # so each run produces a different file.
    model = SimpleModel()
    torch.save(model.state_dict(), path)


# Guarded so that importing this module does not write a file as a side effect.
if __name__ == "__main__":
    main()