vgg16_bn.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. from collections import namedtuple
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.init as init
  5. from torchvision import models
  6. #from torchvision.models.vgg import model_urls
  7. model_urls = {
  8. 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
  9. }
  10. def init_weights(modules):
  11. for m in modules:
  12. if isinstance(m, nn.Conv2d):
  13. init.xavier_uniform_(m.weight.data)
  14. if m.bias is not None:
  15. m.bias.data.zero_()
  16. elif isinstance(m, nn.BatchNorm2d):
  17. m.weight.data.fill_(1)
  18. m.bias.data.zero_()
  19. elif isinstance(m, nn.Linear):
  20. m.weight.data.normal_(0, 0.01)
  21. m.bias.data.zero_()
  22. class vgg16_bn(torch.nn.Module):
  23. def __init__(self, pretrained=True, freeze=True):
  24. super(vgg16_bn, self).__init__()
  25. model_urls['vgg16_bn'] = model_urls['vgg16_bn'].replace('https://', 'http://')
  26. vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
  27. self.slice1 = torch.nn.Sequential()
  28. self.slice2 = torch.nn.Sequential()
  29. self.slice3 = torch.nn.Sequential()
  30. self.slice4 = torch.nn.Sequential()
  31. self.slice5 = torch.nn.Sequential()
  32. for x in range(12): # conv2_2
  33. self.slice1.add_module(str(x), vgg_pretrained_features[x])
  34. for x in range(12, 19): # conv3_3
  35. self.slice2.add_module(str(x), vgg_pretrained_features[x])
  36. for x in range(19, 29): # conv4_3
  37. self.slice3.add_module(str(x), vgg_pretrained_features[x])
  38. for x in range(29, 39): # conv5_3
  39. self.slice4.add_module(str(x), vgg_pretrained_features[x])
  40. # fc6, fc7 without atrous conv
  41. self.slice5 = torch.nn.Sequential(
  42. nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
  43. nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
  44. nn.Conv2d(1024, 1024, kernel_size=1)
  45. )
  46. if not pretrained:
  47. init_weights(self.slice1.modules())
  48. init_weights(self.slice2.modules())
  49. init_weights(self.slice3.modules())
  50. init_weights(self.slice4.modules())
  51. init_weights(self.slice5.modules()) # no pretrained model for fc6 and fc7
  52. if freeze:
  53. for param in self.slice1.parameters(): # only first conv
  54. param.requires_grad= False
  55. def forward(self, X):
  56. h = self.slice1(X)
  57. h_relu2_2 = h
  58. h = self.slice2(h)
  59. h_relu3_2 = h
  60. h = self.slice3(h)
  61. h_relu4_3 = h
  62. h = self.slice4(h)
  63. h_relu5_3 = h
  64. h = self.slice5(h)
  65. h_fc7 = h
  66. vgg_outputs = namedtuple("VggOutputs", ['fc7', 'relu5_3', 'relu4_3', 'relu3_2', 'relu2_2'])
  67. out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)
  68. return out