Prechádzať zdrojové kódy

fix error in loading the model in cpu mode

Youngmin Baek 6 rokov pred
rodič
commit
ce07620c12
1 zmenil súbory, kde vykonal 15 pridanie a 2 odobranie
  1. 15 2
      test.py

+ 15 - 2
test.py

@@ -27,6 +27,18 @@ import zipfile
 
 from craft import CRAFT
 
+from collections import OrderedDict
+def copyStateDict(state_dict):
+    if list(state_dict.keys())[0].startswith("module"):
+        start_idx = 1
+    else:
+        start_idx = 0
+    new_state_dict = OrderedDict()
+    for k, v in state_dict.items():
+        name = ".".join(k.split(".")[start_idx:])
+        new_state_dict[name] = v
+    return new_state_dict
+
 def str2bool(v):
     return v.lower() in ("yes", "y", "true", "t", "1")
 
@@ -96,13 +108,14 @@ if __name__ == '__main__':
     # load net
     net = CRAFT()     # initialize
 
+    print('Loading weights from checkpoint (' + args.trained_model + ')')
+    net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
+
     if args.cuda:
         net = net.cuda()
         net = torch.nn.DataParallel(net)
         cudnn.benchmark = False
 
-    print('Loading weights from checkpoint (' + args.trained_model + ')')
-    net.load_state_dict(torch.load(args.trained_model))
     net.eval()
 
     t = time.time()