Youngmin Baek 6 лет назад
Родитель
Сommit
e4bf4b23f6
1 измененных файлов с 4 добавлено и 1 удалено
  1. 4 1
      test.py

+ 4 - 1
test.py

@@ -109,7 +109,10 @@ if __name__ == '__main__':
     net = CRAFT()     # initialize
 
     print('Loading weights from checkpoint (' + args.trained_model + ')')
-    net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
+    if args.cuda:
+        net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
+    else:
+        net.load_state_dict(copyStateDict(torch.load(args.trained_model, map_location='cpu')))
 
     if args.cuda:
         net = net.cuda()