boomb0om 2 anni fa
parent
commit
bcac40da7a
4 ha cambiato i file con 38 aggiunte e 38 eliminazioni
  1. 3 3
      CRAFT/craft.py
  2. 11 11
      CRAFT/model.py
  3. 2 2
      CRAFT/refinenet.py
  4. 22 22
      example.ipynb

+ 3 - 3
CRAFT/craft.py

@@ -8,9 +8,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from CRAFT.basenet.vgg16_bn import vgg16_bn, init_weights
-from CRAFT.utils import copyStateDict
-from CRAFT.fp16 import FP16Module
+from .basenet.vgg16_bn import vgg16_bn, init_weights
+from .utils import copyStateDict
+from .fp16 import FP16Module
 
 
 class double_conv(nn.Module):

+ 11 - 11
CRAFT/model.py

@@ -8,10 +8,10 @@ import numpy as np
 import cv2
 from huggingface_hub import hf_hub_url, cached_download
 
-from CRAFT.craft import CRAFT, init_CRAFT_model
-from CRAFT.refinenet import RefineNet, init_refiner_model
-import CRAFT.craft_utils as craft_utils
-import CRAFT.imgproc as imgproc
+from .craft import CRAFT, init_CRAFT_model
+from .refinenet import RefineNet, init_refiner_model
+from .craft_utils import adjustResultCoordinates, getDetBoxes
+from .imgproc import resize_aspect_ratio, normalizeMeanVariance
 
 
 HF_MODELS = {
@@ -28,13 +28,13 @@ HF_MODELS = {
     
 def preprocess_image(image: np.ndarray, canvas_size: int, mag_ratio: bool):
     # resize
-    img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
+    img_resized, target_ratio, size_heatmap = resize_aspect_ratio(
         image, canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio
     )
     ratio_h = ratio_w = 1 / target_ratio
 
     # preprocessing
-    x = imgproc.normalizeMeanVariance(img_resized)
+    x = normalizeMeanVariance(img_resized)
     x = torch.from_numpy(x).permute(2, 0, 1)    # [h, w, c] to [c, h, w]
     x = Variable(x.unsqueeze(0))                # [c, h, w] to [b, c, h, w]
     return x, ratio_w, ratio_h
@@ -106,14 +106,14 @@ class CRAFTModel:
         score_text, score_link = self.get_text_map(x, ratio_w, ratio_h)
         
         # Post-processing
-        boxes, polys = craft_utils.getDetBoxes(
+        boxes, polys = getDetBoxes(
             score_text, score_link, 
             self.text_threshold, self.link_threshold, 
             self.low_text, True
         )
         
-        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
-        polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
+        boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
+        polys = adjustResultCoordinates(polys, ratio_w, ratio_h)
         for k in range(len(polys)):
             if polys[k] is None: polys[k] = boxes[k]
 
@@ -126,13 +126,13 @@ class CRAFTModel:
         score_text, score_link = self.get_text_map(x, ratio_w, ratio_h)
         
         # Post-processing
-        boxes, polys = craft_utils.getDetBoxes(
+        boxes, polys = getDetBoxes(
             score_text, score_link, 
             self.text_threshold, self.link_threshold, 
             self.low_text, False
         )
         
-        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
+        boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
         boxes_final = []
         if len(boxes)>0:
             boxes = boxes.astype(np.int32).tolist()

+ 2 - 2
CRAFT/refinenet.py

@@ -8,8 +8,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.autograd import Variable
-from CRAFT.basenet.vgg16_bn import init_weights
-from CRAFT.utils import copyStateDict
+from .basenet.vgg16_bn import init_weights
+from .utils import copyStateDict
 
 
 class RefineNet(nn.Module):

+ 22 - 22
example.ipynb

@@ -3,7 +3,7 @@
   {
    "cell_type": "code",
    "execution_count": 1,
-   "id": "bba67631",
+   "id": "69761026",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -16,7 +16,7 @@
   {
    "cell_type": "code",
    "execution_count": 2,
-   "id": "6045288a",
+   "id": "f32678d6",
    "metadata": {},
    "outputs": [
     {
@@ -35,7 +35,7 @@
   {
    "cell_type": "code",
    "execution_count": 3,
-   "id": "000251de",
+   "id": "dc3a7fc5",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -44,7 +44,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "1f6755fb",
+   "id": "b47dd639",
    "metadata": {},
    "source": [
     "## Predict boxes"
@@ -52,16 +52,16 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
-   "id": "a00bd58a",
+   "execution_count": 13,
+   "id": "472a1a05",
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "CPU times: user 77.6 ms, sys: 20.9 ms, total: 98.5 ms\n",
-      "Wall time: 67.8 ms\n"
+      "CPU times: user 88 ms, sys: 8.28 ms, total: 96.3 ms\n",
+      "Wall time: 63.9 ms\n"
      ]
     }
    ],
@@ -74,7 +74,7 @@
   {
    "cell_type": "code",
    "execution_count": 5,
-   "id": "10c9ffcf",
+   "id": "af29bd5f",
    "metadata": {},
    "outputs": [
     {
@@ -96,7 +96,7 @@
   {
    "cell_type": "code",
    "execution_count": 6,
-   "id": "b69f82fd",
+   "id": "29cf4256",
    "metadata": {},
    "outputs": [
     {
@@ -118,7 +118,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "1b9e564c",
+   "id": "2ff48e55",
    "metadata": {},
    "source": [
     "## Predict polygons"
@@ -126,16 +126,16 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
-   "id": "65751da8",
+   "execution_count": 7,
+   "id": "94bff170",
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "CPU times: user 77 ms, sys: 57.7 ms, total: 135 ms\n",
-      "Wall time: 80.3 ms\n"
+      "CPU times: user 89.1 ms, sys: 40.9 ms, total: 130 ms\n",
+      "Wall time: 77.3 ms\n"
      ]
     },
     {
@@ -155,8 +155,8 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
-   "id": "f18ababd",
+   "execution_count": 8,
+   "id": "dea33597",
    "metadata": {},
    "outputs": [
     {
@@ -166,7 +166,7 @@
        "<PIL.Image.Image image mode=RGB size=500x603>"
       ]
      },
-     "execution_count": 9,
+     "execution_count": 8,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -177,8 +177,8 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
-   "id": "183a32f0",
+   "execution_count": 9,
+   "id": "23c82ca1",
    "metadata": {},
    "outputs": [
     {
@@ -187,7 +187,7 @@
        "0.23886567164179104"
       ]
      },
-     "execution_count": 10,
+     "execution_count": 9,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -201,7 +201,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "17190321",
+   "id": "575f383e",
    "metadata": {},
    "outputs": [],
    "source": []