|
@@ -109,7 +109,10 @@ if __name__ == '__main__':
|
|
net = CRAFT() # initialize
|
|
net = CRAFT() # initialize
|
|
|
|
|
|
print('Loading weights from checkpoint (' + args.trained_model + ')')
|
|
print('Loading weights from checkpoint (' + args.trained_model + ')')
|
|
- net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
|
|
|
|
|
|
+ if args.cuda:
|
|
|
|
+ net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
|
|
|
|
+ else:
|
|
|
|
+ net.load_state_dict(copyStateDict(torch.load(args.trained_model, map_location='cpu')))
|
|
|
|
|
|
if args.cuda:
|
|
if args.cuda:
|
|
net = net.cuda()
|
|
net = net.cuda()
|