- from typing import List, Tuple, Optional
- import os
- import time
- import torch
- from torch.autograd import Variable
- from PIL import Image
- import numpy as np
- import cv2
- from huggingface_hub import hf_hub_url, cached_download
- from CRAFT.craft import CRAFT, init_CRAFT_model
- from CRAFT.refinenet import RefineNet, init_refiner_model
- from CRAFT.craft_utils import adjustResultCoordinates, getDetBoxes
- from CRAFT.imgproc import resize_aspect_ratio, normalizeMeanVariance
# Hugging Face Hub locations of the pretrained weights: the base CRAFT
# detector and the optional CTW1500 link-refiner checkpoint.
HF_MODELS = {
    'craft': dict(
        repo_id='boomb0om/CRAFT-text-detector',
        filename='craft_mlt_25k.pth',
    ),
    'refiner': dict(
        repo_id='boomb0om/CRAFT-text-detector',
        filename='craft_refiner_CTW1500.pth',
    )
}
-
def preprocess_image(image: np.ndarray, canvas_size: int, mag_ratio: float):
    """Resize and normalize an image for the CRAFT forward pass.

    Args:
        image: HWC RGB image array (as produced by ``np.array(pil_image)``).
        canvas_size: longest-side limit passed to ``resize_aspect_ratio``.
        mag_ratio: magnification ratio for resizing (e.g. 1.5) — a float,
            not a bool; the previous annotation was incorrect.

    Returns:
        Tuple ``(x, ratio_w, ratio_h)`` where ``x`` is a ``[1, c, h, w]``
        float tensor and the ratios map score-map coordinates back to the
        original image scale.
    """
    # resize (keeps aspect ratio, pads to canvas_size)
    img_resized, target_ratio, size_heatmap = resize_aspect_ratio(
        image, canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio
    )
    ratio_h = ratio_w = 1 / target_ratio

    # normalize and convert to a batched CHW tensor
    x = normalizeMeanVariance(img_resized)
    x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] -> [c, h, w]
    # torch.autograd.Variable is a deprecated no-op since PyTorch 0.4;
    # unsqueeze alone adds the batch dimension.
    x = x.unsqueeze(0)  # [c, h, w] -> [1, c, h, w]
    return x, ratio_w, ratio_h
class CRAFTModel:
    """CRAFT scene-text detector with an optional link-refiner network.

    Downloads pretrained checkpoints from the Hugging Face Hub (unless
    ``local_files_only`` is set) into ``cache_dir`` and exposes polygon-
    and box-level detection over PIL images.
    """

    def __init__(
        self,
        cache_dir: str,
        device: torch.device,
        local_files_only: bool = False,
        use_refiner: bool = True,
        fp16: bool = True,
        canvas_size: int = 1280,
        mag_ratio: float = 1.5,
        text_threshold: float = 0.7,
        link_threshold: float = 0.4,
        low_text: float = 0.4
    ):
        """
        Args:
            cache_dir: directory where checkpoint files are stored.
            device: torch device the networks are loaded onto.
            local_files_only: if True, never hit the network; expect the
                checkpoint files to already exist in ``cache_dir``.
            use_refiner: load the link-refiner net to improve link scores.
            fp16: load the CRAFT net in half precision.
            canvas_size: longest-side limit used when resizing inputs.
            mag_ratio: magnification ratio used when resizing inputs.
            text_threshold: region-score threshold for text pixels.
            link_threshold: affinity-score threshold for linking chars.
            low_text: low-bound text score used when growing regions.
        """
        self.cache_dir = cache_dir
        self.use_refiner = use_refiner
        self.device = device
        self.fp16 = fp16

        self.canvas_size = canvas_size
        self.mag_ratio = mag_ratio
        self.text_threshold = text_threshold
        self.link_threshold = link_threshold
        self.low_text = low_text

        # Download/resolve only the checkpoints that will actually be
        # loaded (previously the refiner file was fetched even when
        # use_refiner=False).
        # NOTE(review): cached_download is deprecated in recent
        # huggingface_hub releases — consider hf_hub_download.
        needed = ['craft'] + (['refiner'] if use_refiner else [])
        paths = {}
        for model_name in needed:
            config = HF_MODELS[model_name]
            paths[model_name] = os.path.join(cache_dir, config['filename'])
            if not local_files_only:
                config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
                cached_download(config_file_url, cache_dir=cache_dir, force_filename=config['filename'])

        self.net = init_CRAFT_model(paths['craft'], device, fp16=fp16)
        if self.use_refiner:
            self.refiner = init_refiner_model(paths['refiner'], device)
        else:
            self.refiner = None

    def get_text_map(self, x: torch.Tensor, ratio_w: float, ratio_h: float) -> Tuple[np.ndarray, np.ndarray]:
        """Run the detector (and refiner, if loaded) on a preprocessed batch.

        Args:
            x: ``[1, c, h, w]`` tensor from ``preprocess_image``.
            ratio_w: width rescale ratio (unused here; kept for interface
                compatibility — they are floats, not ints as previously
                annotated).
            ratio_h: height rescale ratio (unused here, see above).

        Returns:
            ``(score_text, score_link)`` numpy score maps.
        """
        x = x.to(self.device)

        # forward pass
        with torch.no_grad():
            y, feature = self.net(x)

        # make score and link map
        score_text = y[0, :, :, 0].cpu().data.numpy()
        score_link = y[0, :, :, 1].cpu().data.numpy()

        # refine the link map with the refiner net when available
        if self.refiner:
            with torch.no_grad():
                y_refiner = self.refiner(y, feature)
            score_link = y_refiner[0, :, :, 0].cpu().data.numpy()

        return score_text, score_link

    def get_polygons(self, image: Image.Image) -> List[List[List[int]]]:
        """Detect text regions and return them as integer polygons.

        Returns:
            A list of polygons, each a list of ``[x, y]`` points in
            original-image coordinates.
        """
        x, ratio_w, ratio_h = preprocess_image(np.array(image), self.canvas_size, self.mag_ratio)

        score_text, score_link = self.get_text_map(x, ratio_w, ratio_h)

        # Post-processing: poly=True requests full polygons; entries may be
        # None where only a quadrilateral box was found.
        boxes, polys = getDetBoxes(
            score_text, score_link,
            self.text_threshold, self.link_threshold,
            self.low_text, True
        )
        boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
        for k in range(len(polys)):
            if polys[k] is None:
                # fall back to the (already rescaled) box
                polys[k] = boxes[k]
            else:
                polys[k] = adjustResultCoordinates(polys[k], ratio_w, ratio_h)

        res = []
        for poly in polys:
            res.append(poly.astype(np.int32).tolist())
        return res

    def _get_boxes_preproc(self, x, ratio_w, ratio_h) -> List[List[List[int]]]:
        """Detect axis-aligned-ish boxes on an already-preprocessed batch.

        Returns:
            A list of ``[[x0, y0], [x1, y1]]`` pairs — the first and third
            corner of each detected quadrilateral.
        """
        score_text, score_link = self.get_text_map(x, ratio_w, ratio_h)

        # Post-processing: poly=False — quadrilateral boxes only
        boxes, polys = getDetBoxes(
            score_text, score_link,
            self.text_threshold, self.link_threshold,
            self.low_text, False
        )

        boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
        boxes_final = []
        if len(boxes) > 0:
            boxes = boxes.astype(np.int32).tolist()
            for box in boxes:
                # keep opposite corners (top-left, bottom-right)
                boxes_final.append([box[0], box[2]])
        return boxes_final

    def get_boxes(self, image: Image.Image) -> List[List[List[int]]]:
        """Detect text regions in a PIL image and return corner-pair boxes."""
        x, ratio_w, ratio_h = preprocess_image(np.array(image), self.canvas_size, self.mag_ratio)

        boxes_final = self._get_boxes_preproc(x, ratio_w, ratio_h)
        return boxes_final