|
@@ -4,7 +4,13 @@ import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.init as init
|
|
|
from torchvision import models
|
|
|
-from torchvision.models.vgg import model_urls
|
|
|
+#from torchvision.models.vgg import model_urls
|
|
|
+
|
|
|
+
|
|
|
+model_urls = {
|
|
|
+ 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
|
|
|
+}
|
|
|
+
|
|
|
|
|
|
def init_weights(modules):
|
|
|
for m in modules:
|
|
@@ -19,6 +25,7 @@ def init_weights(modules):
|
|
|
m.weight.data.normal_(0, 0.01)
|
|
|
m.bias.data.zero_()
|
|
|
|
|
|
+
|
|
|
class vgg16_bn(torch.nn.Module):
|
|
|
def __init__(self, pretrained=True, freeze=True):
|
|
|
super(vgg16_bn, self).__init__()
|