|
@@ -47,12 +47,14 @@ parser.add_argument('--trained_model', default='weights/craft_mlt_25k.pth', type
|
|
parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold')
|
|
parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold')
|
|
parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
|
|
parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
|
|
parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold')
|
|
parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold')
|
|
-parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda to train model')
|
|
|
|
|
|
+parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda for inference')
|
|
parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
|
|
parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
|
|
parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')
|
|
parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')
|
|
parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
|
|
parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
|
|
parser.add_argument('--show_time', default=False, action='store_true', help='show processing time')
|
|
parser.add_argument('--show_time', default=False, action='store_true', help='show processing time')
|
|
parser.add_argument('--test_folder', default='/data/', type=str, help='folder path to input images')
|
|
parser.add_argument('--test_folder', default='/data/', type=str, help='folder path to input images')
|
|
|
|
+parser.add_argument('--refine', default=False, action='store_true', help='enable link refiner')
|
|
|
|
+parser.add_argument('--refiner_model', default='weights/craft_refiner_CTW1500.pth', type=str, help='pretrained refiner model')
|
|
|
|
|
|
args = parser.parse_args()
|
|
args = parser.parse_args()
|
|
|
|
|
|
@@ -64,7 +66,7 @@ result_folder = './result/'
|
|
if not os.path.isdir(result_folder):
|
|
if not os.path.isdir(result_folder):
|
|
os.mkdir(result_folder)
|
|
os.mkdir(result_folder)
|
|
|
|
|
|
-def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly):
|
|
|
|
|
|
+def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net=None):
|
|
t0 = time.time()
|
|
t0 = time.time()
|
|
|
|
|
|
# resize
|
|
# resize
|
|
@@ -79,12 +81,17 @@ def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly):
|
|
x = x.cuda()
|
|
x = x.cuda()
|
|
|
|
|
|
# forward pass
|
|
# forward pass
|
|
- y, _ = net(x)
|
|
|
|
|
|
+ y, feature = net(x)
|
|
|
|
|
|
# make score and link map
|
|
# make score and link map
|
|
score_text = y[0,:,:,0].cpu().data.numpy()
|
|
score_text = y[0,:,:,0].cpu().data.numpy()
|
|
score_link = y[0,:,:,1].cpu().data.numpy()
|
|
score_link = y[0,:,:,1].cpu().data.numpy()
|
|
|
|
|
|
|
|
+ # refine link
|
|
|
|
+ if refine_net is not None:
|
|
|
|
+ y_refiner = refine_net(y, feature)
|
|
|
|
+ score_link = y_refiner[0,:,:,0].cpu().data.numpy()
|
|
|
|
+
|
|
t0 = time.time() - t0
|
|
t0 = time.time() - t0
|
|
t1 = time.time()
|
|
t1 = time.time()
|
|
|
|
|
|
@@ -127,6 +134,22 @@ if __name__ == '__main__':
|
|
|
|
|
|
net.eval()
|
|
net.eval()
|
|
|
|
|
|
|
|
+ # LinkRefiner
|
|
|
|
+ refine_net = None
|
|
|
|
+ if args.refine:
|
|
|
|
+ from refinenet import RefineNet
|
|
|
|
+ refine_net = RefineNet()
|
|
|
|
+ print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
|
|
|
|
+ if args.cuda:
|
|
|
|
+ refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model)))
|
|
|
|
+ refine_net = refine_net.cuda()
|
|
|
|
+ refine_net = torch.nn.DataParallel(refine_net)
|
|
|
|
+ else:
|
|
|
|
+ refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model, map_location='cpu')))
|
|
|
|
+
|
|
|
|
+ refine_net.eval()
|
|
|
|
+ args.poly = True
|
|
|
|
+
|
|
t = time.time()
|
|
t = time.time()
|
|
|
|
|
|
# load data
|
|
# load data
|
|
@@ -134,7 +157,7 @@ if __name__ == '__main__':
|
|
print("Test image {:d}/{:d}: {:s}".format(k+1, len(image_list), image_path), end='\r')
|
|
print("Test image {:d}/{:d}: {:s}".format(k+1, len(image_list), image_path), end='\r')
|
|
image = imgproc.loadImage(image_path)
|
|
image = imgproc.loadImage(image_path)
|
|
|
|
|
|
- bboxes, polys, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly)
|
|
|
|
|
|
+ bboxes, polys, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly, refine_net)
|
|
|
|
|
|
# save score text
|
|
# save score text
|
|
filename, file_ext = os.path.splitext(os.path.basename(image_path))
|
|
filename, file_ext = os.path.splitext(os.path.basename(image_path))
|