| 1234567891011121314151617181920 |
- #!/usr/bin/env python3
- import torch
class SimpleModel(torch.nn.Module):
    """Small two-convolution network whose output is ``m1(x) + m2(m1(x))``.

    Layers (the attribute names ``m1``/``m2`` are the ``state_dict`` key
    prefixes, so they must not be renamed):

    * ``m1``: ``Conv2d(3 -> 8, kernel 3, stride 1, padding 0)``. Without
      padding, each spatial dimension shrinks by 2.
    * ``m2``: ``Conv2d(8 -> 8, kernel 3, stride 1, padding 1)``. Padding 1
      keeps the spatial size, so its output can be added to its input.
    """

    def __init__(self):
        super().__init__()
        # Registration order (m1, then m2) determines the state_dict key order.
        self.m1 = torch.nn.Conv2d(
            in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=0
        )
        self.m2 = torch.nn.Conv2d(
            in_channels=8, out_channels=8, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        """Return ``m1(x) + m2(m1(x))``.

        ``x`` presumably is an image batch of shape ``(N, 3, H, W)`` (3 input
        channels, as required by ``m1``). TODO: confirm with callers. The
        result has 8 channels and spatial size ``(H - 2, W - 2)``.
        """
        features = self.m1(x)
        # Residual-style sum: m2 preserves shape, so the addition is valid.
        return features + self.m2(features)
# Instantiate SimpleModel and write its parameters (the state_dict, not the
# pickled module object) to "simple.pth" in the current working directory.
# NOTE(review): this executes at import time; consider moving it under an
# `if __name__ == "__main__":` guard if this file is ever imported.
model = SimpleModel()
torch.save(obj=model.state_dict(), f="simple.pth")
|