test.py

  1. """
  2. Copyright (c) 2019-present NAVER Corp.
  3. MIT License
  4. """
  5. # -*- coding: utf-8 -*-
  6. import sys
  7. import os
  8. import time
  9. import argparse
  10. import torch
  11. import torch.nn as nn
  12. import torch.backends.cudnn as cudnn
  13. from torch.autograd import Variable
  14. from PIL import Image
  15. import cv2
  16. from skimage import io
  17. import numpy as np
  18. import craft_utils
  19. import imgproc
  20. import file_utils
  21. import json
  22. import zipfile
  23. from craft import CRAFT
  24. from collections import OrderedDict
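
# Strip the "module." prefix that torch.nn.DataParallel prepends to parameter
# names, so checkpoints saved from a wrapped model load into a plain module.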
def copyStateDict(state_dict):
    if list(state_dict.keys())[0].startswith("module"):
        start_idx = 1
    else:
        start_idx = 0
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = ".".join(k.split(".")[start_idx:])
        new_state_dict[name] = v
    return new_state_dict
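
# Interpret common truthy strings ("yes", "true", "1", ...) from the command line.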
def str2bool(v):
    return v.lower() in ("yes", "y", "true", "t", "1")


parser = argparse.ArgumentParser(description='CRAFT Text Detection')
parser.add_argument('--trained_model', default='weights/craft_mlt_25k.pth', type=str, help='pretrained model')
parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold')
parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold')
parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda for inference')
parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')
parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
parser.add_argument('--show_time', default=False, action='store_true', help='show processing time')
parser.add_argument('--test_folder', default='/data/', type=str, help='folder path to input images')
parser.add_argument('--refine', default=False, action='store_true', help='enable link refiner')
parser.add_argument('--refiner_model', default='weights/craft_refiner_CTW1500.pth', type=str, help='pretrained refiner model')

args = parser.parse_args()


""" For test images in a folder """
image_list, _, _ = file_utils.get_files(args.test_folder)

result_folder = './result/'
if not os.path.isdir(result_folder):
    os.mkdir(result_folder)
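
# Run a single image through the detector: resize and normalize, forward the
# network to get region/affinity score maps, optionally refine the link map,
# then post-process the maps into boxes/polygons in original-image coordinates.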
def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net=None):
    t0 = time.time()

    # resize
    img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(image, args.canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=args.mag_ratio)
    ratio_h = ratio_w = 1 / target_ratio

    # preprocessing
    x = imgproc.normalizeMeanVariance(img_resized)
    x = torch.from_numpy(x).permute(2, 0, 1)    # [h, w, c] to [c, h, w]
    x = Variable(x.unsqueeze(0))                # [c, h, w] to [b, c, h, w]
    if cuda:
        x = x.cuda()

    # forward pass
    y, feature = net(x)

    # make score and link map
    score_text = y[0,:,:,0].cpu().data.numpy()
    score_link = y[0,:,:,1].cpu().data.numpy()
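    # score_text is the character-region heatmap; score_link is the affinity
    # (link) heatmap used to group characters into words.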

    # refine link
    if refine_net is not None:
        y_refiner = refine_net(y, feature)
        score_link = y_refiner[0,:,:,0].cpu().data.numpy()

    t0 = time.time() - t0
    t1 = time.time()

    # Post-processing
    boxes, polys = craft_utils.getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text, poly)

    # coordinate adjustment
    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
    for k in range(len(polys)):
        if polys[k] is None: polys[k] = boxes[k]

    t1 = time.time() - t1

    # render results (optional)
    render_img = score_text.copy()
    render_img = np.hstack((render_img, score_link))
    ret_score_text = imgproc.cvt2HeatmapImg(render_img)

    if args.show_time: print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

    return boxes, polys, ret_score_text
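
# Entry point: load the CRAFT weights (and the optional LinkRefiner), run
# detection on every image in --test_folder, and write results to ./result/.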
if __name__ == '__main__':
    # load net
    net = CRAFT()     # initialize

    print('Loading weights from checkpoint (' + args.trained_model + ')')
    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
    else:
        net.load_state_dict(copyStateDict(torch.load(args.trained_model, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    # LinkRefiner
    refine_net = None
    if args.refine:
        from refinenet import RefineNet
        refine_net = RefineNet()
        print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
        if args.cuda:
            refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model)))
            refine_net = refine_net.cuda()
            refine_net = torch.nn.DataParallel(refine_net)
        else:
            refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model, map_location='cpu')))

        refine_net.eval()
        args.poly = True

    t = time.time()

    # load data
    for k, image_path in enumerate(image_list):
        print("Test image {:d}/{:d}: {:s}".format(k+1, len(image_list), image_path), end='\r')
        image = imgproc.loadImage(image_path)

        bboxes, polys, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly, refine_net)

        # save score text
        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        mask_file = result_folder + "/res_" + filename + '_mask.jpg'
        cv2.imwrite(mask_file, score_text)

        file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=result_folder)

    print("elapsed time : {}s".format(time.time() - t))