model.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. from typing import List, Tuple, Optional
  2. import os
  3. import time
  4. import torch
  5. from torch.autograd import Variable
  6. from PIL import Image
  7. import numpy as np
  8. import cv2
  9. from huggingface_hub import hf_hub_url, cached_download
  10. from CRAFT.craft import CRAFT, init_CRAFT_model
  11. from CRAFT.refinenet import RefineNet, init_refiner_model
  12. from CRAFT.craft_utils import adjustResultCoordinates, getDetBoxes
  13. from CRAFT.imgproc import resize_aspect_ratio, normalizeMeanVariance
  14. HF_MODELS = {
  15. 'craft': dict(
  16. repo_id='boomb0om/CRAFT-text-detector',
  17. filename='craft_mlt_25k.pth',
  18. ),
  19. 'refiner': dict(
  20. repo_id='boomb0om/CRAFT-text-detector',
  21. filename='craft_refiner_CTW1500.pth',
  22. )
  23. }
  24. def preprocess_image(image: np.ndarray, canvas_size: int, mag_ratio: bool):
  25. # resize
  26. img_resized, target_ratio, size_heatmap = resize_aspect_ratio(
  27. image, canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio
  28. )
  29. ratio_h = ratio_w = 1 / target_ratio
  30. # preprocessing
  31. x = normalizeMeanVariance(img_resized)
  32. x = torch.from_numpy(x).permute(2, 0, 1) # [h, w, c] to [c, h, w]
  33. x = Variable(x.unsqueeze(0)) # [c, h, w] to [b, c, h, w]
  34. return x, ratio_w, ratio_h
  35. class CRAFTModel:
  36. def __init__(
  37. self,
  38. cache_dir: str,
  39. device: torch.device,
  40. local_files_only: bool = False,
  41. use_refiner: bool = True,
  42. fp16: bool = True,
  43. canvas_size: int = 1280,
  44. mag_ratio: float = 1.5,
  45. text_threshold: float = 0.7,
  46. link_threshold: float = 0.4,
  47. low_text: float = 0.4
  48. ):
  49. self.cache_dir = cache_dir
  50. self.use_refiner = use_refiner
  51. self.device = device
  52. self.fp16 = fp16
  53. self.canvas_size = canvas_size
  54. self.mag_ratio = mag_ratio
  55. self.text_threshold = text_threshold
  56. self.link_threshold = link_threshold
  57. self.low_text = low_text
  58. # loading models
  59. paths = {}
  60. for model_name in ['craft', 'refiner']:
  61. config = HF_MODELS[model_name]
  62. paths[model_name] = os.path.join(cache_dir, config['filename'])
  63. if not local_files_only:
  64. config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
  65. cached_download(config_file_url, cache_dir=cache_dir, force_filename=config['filename'])
  66. self.net = init_CRAFT_model(paths['craft'], device, fp16=fp16)
  67. if self.use_refiner:
  68. self.refiner = init_refiner_model(paths['refiner'], device)
  69. else:
  70. self.refiner = None
  71. def get_text_map(self, x: torch.Tensor, ratio_w: int, ratio_h: int) -> Tuple[np.ndarray, np.ndarray]:
  72. x = x.to(self.device)
  73. # forward pass
  74. with torch.no_grad():
  75. y, feature = self.net(x)
  76. # make score and link map
  77. score_text = y[0,:,:,0].cpu().data.numpy()
  78. score_link = y[0,:,:,1].cpu().data.numpy()
  79. # refine link
  80. if self.refiner:
  81. with torch.no_grad():
  82. y_refiner = self.refiner(y, feature)
  83. score_link = y_refiner[0,:,:,0].cpu().data.numpy()
  84. return score_text, score_link
  85. def get_polygons(self, image: Image.Image) -> List[List[List[int]]]:
  86. x, ratio_w, ratio_h = preprocess_image(np.array(image), self.canvas_size, self.mag_ratio)
  87. score_text, score_link = self.get_text_map(x, ratio_w, ratio_h)
  88. # Post-processing
  89. boxes, polys = getDetBoxes(
  90. score_text, score_link,
  91. self.text_threshold, self.link_threshold,
  92. self.low_text, True
  93. )
  94. boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
  95. for k in range(len(polys)):
  96. if polys[k] is None:
  97. polys[k] = boxes[k]
  98. else:
  99. polys[k] = adjustResultCoordinates(polys[k], ratio_w, ratio_h)
  100. res = []
  101. for poly in polys:
  102. res.append(poly.astype(np.int32).tolist())
  103. return res
  104. def _get_boxes_preproc(self, x, ratio_w, ratio_h) -> List[List[List[int]]]:
  105. score_text, score_link = self.get_text_map(x, ratio_w, ratio_h)
  106. # Post-processing
  107. boxes, polys = getDetBoxes(
  108. score_text, score_link,
  109. self.text_threshold, self.link_threshold,
  110. self.low_text, False
  111. )
  112. boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
  113. boxes_final = []
  114. if len(boxes)>0:
  115. boxes = boxes.astype(np.int32).tolist()
  116. for box in boxes:
  117. boxes_final.append([box[0], box[2]])
  118. return boxes_final
  119. def get_boxes(self, image: Image.Image) -> List[List[List[int]]]:
  120. x, ratio_w, ratio_h = preprocess_image(np.array(image), self.canvas_size, self.mag_ratio)
  121. boxes_final = self._get_boxes_preproc(x, ratio_w, ratio_h)
  122. return boxes_final