|
@@ -8,10 +8,10 @@ import numpy as np
|
|
|
import cv2
|
|
|
from huggingface_hub import hf_hub_url, cached_download
|
|
|
|
|
|
-from CRAFT.craft import CRAFT, init_CRAFT_model
|
|
|
-from CRAFT.refinenet import RefineNet, init_refiner_model
|
|
|
-import CRAFT.craft_utils as craft_utils
|
|
|
-import CRAFT.imgproc as imgproc
|
|
|
+from .craft import CRAFT, init_CRAFT_model
|
|
|
+from .refinenet import RefineNet, init_refiner_model
|
|
|
+from .craft_utils import adjustResultCoordinates, getDetBoxes
|
|
|
+from .imgproc import resize_aspect_ratio, normalizeMeanVariance
|
|
|
|
|
|
|
|
|
HF_MODELS = {
|
|
@@ -28,13 +28,13 @@ HF_MODELS = {
|
|
|
|
|
|
def preprocess_image(image: np.ndarray, canvas_size: int, mag_ratio: bool):
|
|
|
# resize
|
|
|
- img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
|
|
|
+ img_resized, target_ratio, size_heatmap = resize_aspect_ratio(
|
|
|
image, canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio
|
|
|
)
|
|
|
ratio_h = ratio_w = 1 / target_ratio
|
|
|
|
|
|
# preprocessing
|
|
|
- x = imgproc.normalizeMeanVariance(img_resized)
|
|
|
+ x = normalizeMeanVariance(img_resized)
|
|
|
x = torch.from_numpy(x).permute(2, 0, 1) # [h, w, c] to [c, h, w]
|
|
|
x = Variable(x.unsqueeze(0)) # [c, h, w] to [b, c, h, w]
|
|
|
return x, ratio_w, ratio_h
|
|
@@ -106,14 +106,14 @@ class CRAFTModel:
|
|
|
score_text, score_link = self.get_text_map(x, ratio_w, ratio_h)
|
|
|
|
|
|
# Post-processing
|
|
|
- boxes, polys = craft_utils.getDetBoxes(
|
|
|
+ boxes, polys = getDetBoxes(
|
|
|
score_text, score_link,
|
|
|
self.text_threshold, self.link_threshold,
|
|
|
self.low_text, True
|
|
|
)
|
|
|
|
|
|
- boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
|
|
|
- polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
|
|
|
+ boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
|
|
|
+ polys = adjustResultCoordinates(polys, ratio_w, ratio_h)
|
|
|
for k in range(len(polys)):
|
|
|
if polys[k] is None: polys[k] = boxes[k]
|
|
|
|
|
@@ -126,13 +126,13 @@ class CRAFTModel:
|
|
|
score_text, score_link = self.get_text_map(x, ratio_w, ratio_h)
|
|
|
|
|
|
# Post-processing
|
|
|
- boxes, polys = craft_utils.getDetBoxes(
|
|
|
+ boxes, polys = getDetBoxes(
|
|
|
score_text, score_link,
|
|
|
self.text_threshold, self.link_threshold,
|
|
|
self.low_text, False
|
|
|
)
|
|
|
|
|
|
- boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
|
|
|
+ boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
|
|
|
boxes_final = []
|
|
|
if len(boxes)>0:
|
|
|
boxes = boxes.astype(np.int32).tolist()
|