|
@@ -27,6 +27,18 @@ import zipfile
|
|
|
|
|
|
from craft import CRAFT
|
|
|
|
|
|
+from collections import OrderedDict
|
|
|
+def copyStateDict(state_dict):
|
|
|
+ if list(state_dict.keys())[0].startswith("module"):
|
|
|
+ start_idx = 1
|
|
|
+ else:
|
|
|
+ start_idx = 0
|
|
|
+ new_state_dict = OrderedDict()
|
|
|
+ for k, v in state_dict.items():
|
|
|
+ name = ".".join(k.split(".")[start_idx:])
|
|
|
+ new_state_dict[name] = v
|
|
|
+ return new_state_dict
|
|
|
+
|
|
|
def str2bool(v):
|
|
|
return v.lower() in ("yes", "y", "true", "t", "1")
|
|
|
|
|
@@ -96,13 +108,14 @@ if __name__ == '__main__':
|
|
|
# load net
|
|
|
net = CRAFT() # initialize
|
|
|
|
|
|
+ print('Loading weights from checkpoint (' + args.trained_model + ')')
|
|
|
+ net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
|
|
|
+
|
|
|
if args.cuda:
|
|
|
net = net.cuda()
|
|
|
net = torch.nn.DataParallel(net)
|
|
|
cudnn.benchmark = False
|
|
|
|
|
|
- print('Loading weights from checkpoint (' + args.trained_model + ')')
|
|
|
- net.load_state_dict(torch.load(args.trained_model))
|
|
|
net.eval()
|
|
|
|
|
|
t = time.time()
|