simple_model.py 519 B

1234567891011121314151617181920
  1. #!/usr/bin/env python3
  2. import torch
  3. # Build a simple nn model
  4. class SimpleModel(torch.nn.Module):
  5. def __init__(self):
  6. super(SimpleModel, self).__init__()
  7. self.m1 = torch.nn.Conv2d(3, 8, 3, 1, 0)
  8. self.m2 = torch.nn.Conv2d(8, 8, 3, 1, 1)
  9. def forward(self, x):
  10. y0 = self.m1(x)
  11. y1 = self.m2(y0)
  12. y2 = y0 + y1
  13. return y2
  14. # Create a SimpleModel and save its weight in the current directory
  15. model = SimpleModel()
  16. torch.save(model.state_dict(), "simple.pth")