ソースを参照

Initial commit

boomb0om 2 年 前
コミット
ebe3b1f3ee
17 ファイル変更967 行追加1 行削除
  1. 2 0
      .gitignore
  2. 2 0
      CRAFT/__init__.py
  3. 0 0
      CRAFT/basenet/__init__.py
  4. 73 0
      CRAFT/basenet/vgg16_bn.py
  5. 95 0
      CRAFT/craft.py
  6. 243 0
      CRAFT/craft_utils.py
  7. 64 0
      CRAFT/fp16.py
  8. 70 0
      CRAFT/imgproc.py
  9. 144 0
      CRAFT/model.py
  10. 75 0
      CRAFT/refinenet.py
  11. 68 0
      CRAFT/utils.py
  12. 30 1
      README.md
  13. 81 0
      example.ipynb
  14. BIN
      images/cafe_sign.jpg
  15. BIN
      images/result.jpg
  16. 5 0
      requirements.txt
  17. 15 0
      test.py

+ 2 - 0
.gitignore

@@ -1,3 +1,5 @@
+weights/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]

+ 2 - 0
CRAFT/__init__.py

@@ -0,0 +1,2 @@
+from .model import CRAFTModel
+from .utils import draw_boxes, draw_polygons, boxes_area, polygons_area

+ 0 - 0
CRAFT/basenet/__init__.py


+ 73 - 0
CRAFT/basenet/vgg16_bn.py

@@ -0,0 +1,73 @@
+from collections import namedtuple
+
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+from torchvision import models
+from torchvision.models.vgg import model_urls
+
+def init_weights(modules):
+    for m in modules:
+        if isinstance(m, nn.Conv2d):
+            init.xavier_uniform_(m.weight.data)
+            if m.bias is not None:
+                m.bias.data.zero_()
+        elif isinstance(m, nn.BatchNorm2d):
+            m.weight.data.fill_(1)
+            m.bias.data.zero_()
+        elif isinstance(m, nn.Linear):
+            m.weight.data.normal_(0, 0.01)
+            m.bias.data.zero_()
+
+class vgg16_bn(torch.nn.Module):
+    def __init__(self, pretrained=True, freeze=True):
+        super(vgg16_bn, self).__init__()
+        model_urls['vgg16_bn'] = model_urls['vgg16_bn'].replace('https://', 'http://')
+        vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
+        self.slice1 = torch.nn.Sequential()
+        self.slice2 = torch.nn.Sequential()
+        self.slice3 = torch.nn.Sequential()
+        self.slice4 = torch.nn.Sequential()
+        self.slice5 = torch.nn.Sequential()
+        for x in range(12):         # conv2_2
+            self.slice1.add_module(str(x), vgg_pretrained_features[x])
+        for x in range(12, 19):         # conv3_3
+            self.slice2.add_module(str(x), vgg_pretrained_features[x])
+        for x in range(19, 29):         # conv4_3
+            self.slice3.add_module(str(x), vgg_pretrained_features[x])
+        for x in range(29, 39):         # conv5_3
+            self.slice4.add_module(str(x), vgg_pretrained_features[x])
+
+        # fc6, fc7 without atrous conv
+        self.slice5 = torch.nn.Sequential(
+                nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
+                nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
+                nn.Conv2d(1024, 1024, kernel_size=1)
+        )
+
+        if not pretrained:
+            init_weights(self.slice1.modules())
+            init_weights(self.slice2.modules())
+            init_weights(self.slice3.modules())
+            init_weights(self.slice4.modules())
+
+        init_weights(self.slice5.modules())        # no pretrained model for fc6 and fc7
+
+        if freeze:
+            for param in self.slice1.parameters():      # only first conv
+                param.requires_grad= False
+
+    def forward(self, X):
+        h = self.slice1(X)
+        h_relu2_2 = h
+        h = self.slice2(h)
+        h_relu3_2 = h
+        h = self.slice3(h)
+        h_relu4_3 = h
+        h = self.slice4(h)
+        h_relu5_3 = h
+        h = self.slice5(h)
+        h_fc7 = h
+        vgg_outputs = namedtuple("VggOutputs", ['fc7', 'relu5_3', 'relu4_3', 'relu3_2', 'relu2_2'])
+        out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)
+        return out

+ 95 - 0
CRAFT/craft.py

@@ -0,0 +1,95 @@
+"""  
+Copyright (c) 2019-present NAVER Corp.
+MIT License
+"""
+
+# -*- coding: utf-8 -*-
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from CRAFT.basenet.vgg16_bn import vgg16_bn, init_weights
+from CRAFT.utils import copyStateDict
+from CRAFT.fp16 import FP16Module
+
+
+class double_conv(nn.Module):
+    
+    def __init__(self, in_ch, mid_ch, out_ch):
+        super(double_conv, self).__init__()
+        self.conv = nn.Sequential(
+            nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
+            nn.BatchNorm2d(mid_ch),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
+            nn.BatchNorm2d(out_ch),
+            nn.ReLU(inplace=True)
+        )
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x
+
+
+class CRAFT(nn.Module):
+    
+    def __init__(self, pretrained=False, freeze=False):
+        super(CRAFT, self).__init__()
+
+        """ Base network """
+        self.basenet = vgg16_bn(pretrained, freeze)
+
+        """ U network """
+        self.upconv1 = double_conv(1024, 512, 256)
+        self.upconv2 = double_conv(512, 256, 128)
+        self.upconv3 = double_conv(256, 128, 64)
+        self.upconv4 = double_conv(128, 64, 32)
+
+        num_class = 2
+        self.conv_cls = nn.Sequential(
+            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
+            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
+            nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True),
+            nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True),
+            nn.Conv2d(16, num_class, kernel_size=1),
+        )
+
+        init_weights(self.upconv1.modules())
+        init_weights(self.upconv2.modules())
+        init_weights(self.upconv3.modules())
+        init_weights(self.upconv4.modules())
+        init_weights(self.conv_cls.modules())
+        
+    def forward(self, x):
+        """ Base network """
+        sources = self.basenet(x)
+
+        """ U network """
+        y = torch.cat([sources[0], sources[1]], dim=1)
+        y = self.upconv1(y)
+
+        y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False)
+        y = torch.cat([y, sources[2]], dim=1)
+        y = self.upconv2(y)
+
+        y = F.interpolate(y, size=sources[3].size()[2:], mode='bilinear', align_corners=False)
+        y = torch.cat([y, sources[3]], dim=1)
+        y = self.upconv3(y)
+
+        y = F.interpolate(y, size=sources[4].size()[2:], mode='bilinear', align_corners=False)
+        y = torch.cat([y, sources[4]], dim=1)
+        feature = self.upconv4(y)
+
+        y = self.conv_cls(feature)
+
+        return y.permute(0,2,3,1), feature
+
+
+def init_CRAFT_model(chekpoint_path: str, device: str, fp16: bool = True) -> CRAFT:
+    net = CRAFT()
+    net.load_state_dict(copyStateDict(torch.load(chekpoint_path, map_location=torch.device('cpu'))))
+    if fp16:
+        net = FP16Module(net)
+    net = net.to(device)
+    net.eval()
+    return net

+ 243 - 0
CRAFT/craft_utils.py

@@ -0,0 +1,243 @@
+"""  
+Copyright (c) 2019-present NAVER Corp.
+MIT License
+"""
+
+# -*- coding: utf-8 -*-
+import numpy as np
+import cv2
+import math
+
+""" auxilary functions """
+# unwarp corodinates
+def warpCoord(Minv, pt):
+    out = np.matmul(Minv, (pt[0], pt[1], 1))
+    return np.array([out[0]/out[2], out[1]/out[2]])
+""" end of auxilary functions """
+
+
+def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text):
+    # prepare data
+    linkmap = linkmap.copy()
+    textmap = textmap.copy()
+    img_h, img_w = textmap.shape
+
+    """ labeling method """
+    ret, text_score = cv2.threshold(textmap, low_text, 1, 0)
+    ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0)
+
+    text_score_comb = np.clip(text_score + link_score, 0, 1)
+    nLabels, labels, stats, centroids = cv2.connectedComponentsWithStats(text_score_comb.astype(np.uint8), connectivity=4)
+
+    det = []
+    mapper = []
+    for k in range(1,nLabels):
+        # size filtering
+        size = stats[k, cv2.CC_STAT_AREA]
+        if size < 10: continue
+
+        # thresholding
+        if np.max(textmap[labels==k]) < text_threshold: continue
+
+        # make segmentation map
+        segmap = np.zeros(textmap.shape, dtype=np.uint8)
+        segmap[labels==k] = 255
+        segmap[np.logical_and(link_score==1, text_score==0)] = 0   # remove link area
+        x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP]
+        w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT]
+        niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2)
+        sx, ex, sy, ey = x - niter, x + w + niter + 1, y - niter, y + h + niter + 1
+        # boundary check
+        if sx < 0 : sx = 0
+        if sy < 0 : sy = 0
+        if ex >= img_w: ex = img_w
+        if ey >= img_h: ey = img_h
+        kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(1 + niter, 1 + niter))
+        segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel)
+
+        # make box
+        np_contours = np.roll(np.array(np.where(segmap!=0)),1,axis=0).transpose().reshape(-1,2)
+        rectangle = cv2.minAreaRect(np_contours)
+        box = cv2.boxPoints(rectangle)
+
+        # align diamond-shape
+        w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])
+        box_ratio = max(w, h) / (min(w, h) + 1e-5)
+        if abs(1 - box_ratio) <= 0.1:
+            l, r = min(np_contours[:,0]), max(np_contours[:,0])
+            t, b = min(np_contours[:,1]), max(np_contours[:,1])
+            box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32)
+
+        # make clock-wise order
+        startidx = box.sum(axis=1).argmin()
+        box = np.roll(box, 4-startidx, 0)
+        box = np.array(box)
+
+        det.append(box)
+        mapper.append(k)
+
+    return det, labels, mapper
+
+def getPoly_core(boxes, labels, mapper, linkmap):
+    # configs
+    num_cp = 5
+    max_len_ratio = 0.7
+    expand_ratio = 1.45
+    max_r = 2.0
+    step_r = 0.2
+
+    polys = []  
+    for k, box in enumerate(boxes):
+        # size filter for small instance
+        w, h = int(np.linalg.norm(box[0] - box[1]) + 1), int(np.linalg.norm(box[1] - box[2]) + 1)
+        if w < 10 or h < 10:
+            polys.append(None); continue
+
+        # warp image
+        tar = np.float32([[0,0],[w,0],[w,h],[0,h]])
+        M = cv2.getPerspectiveTransform(box, tar)
+        word_label = cv2.warpPerspective(labels, M, (w, h), flags=cv2.INTER_NEAREST)
+        try:
+            Minv = np.linalg.inv(M)
+        except:
+            polys.append(None); continue
+
+        # binarization for selected label
+        cur_label = mapper[k]
+        word_label[word_label != cur_label] = 0
+        word_label[word_label > 0] = 1
+
+        """ Polygon generation """
+        # find top/bottom contours
+        cp = []
+        max_len = -1
+        for i in range(w):
+            region = np.where(word_label[:,i] != 0)[0]
+            if len(region) < 2 : continue
+            cp.append((i, region[0], region[-1]))
+            length = region[-1] - region[0] + 1
+            if length > max_len: max_len = length
+
+        # pass if max_len is similar to h
+        if h * max_len_ratio < max_len:
+            polys.append(None); continue
+
+        # get pivot points with fixed length
+        tot_seg = num_cp * 2 + 1
+        seg_w = w / tot_seg     # segment width
+        pp = [None] * num_cp    # init pivot points
+        cp_section = [[0, 0]] * tot_seg
+        seg_height = [0] * num_cp
+        seg_num = 0
+        num_sec = 0
+        prev_h = -1
+        for i in range(0,len(cp)):
+            (x, sy, ey) = cp[i]
+            if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg:
+                # average previous segment
+                if num_sec == 0: break
+                cp_section[seg_num] = [cp_section[seg_num][0] / num_sec, cp_section[seg_num][1] / num_sec]
+                num_sec = 0
+
+                # reset variables
+                seg_num += 1
+                prev_h = -1
+
+            # accumulate center points
+            cy = (sy + ey) * 0.5
+            cur_h = ey - sy + 1
+            cp_section[seg_num] = [cp_section[seg_num][0] + x, cp_section[seg_num][1] + cy]
+            num_sec += 1
+
+            if seg_num % 2 == 0: continue # No polygon area
+
+            if prev_h < cur_h:
+                pp[int((seg_num - 1)/2)] = (x, cy)
+                seg_height[int((seg_num - 1)/2)] = cur_h
+                prev_h = cur_h
+
+        # processing last segment
+        if num_sec != 0:
+            cp_section[-1] = [cp_section[-1][0] / num_sec, cp_section[-1][1] / num_sec]
+
+        # pass if num of pivots is not sufficient or segment widh is smaller than character height 
+        if None in pp or seg_w < np.max(seg_height) * 0.25:
+            polys.append(None); continue
+
+        # calc median maximum of pivot points
+        half_char_h = np.median(seg_height) * expand_ratio / 2
+
+        # calc gradiant and apply to make horizontal pivots
+        new_pp = []
+        for i, (x, cy) in enumerate(pp):
+            dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0]
+            dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1]
+            if dx == 0:     # gradient if zero
+                new_pp.append([x, cy - half_char_h, x, cy + half_char_h])
+                continue
+            rad = - math.atan2(dy, dx)
+            c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad)
+            new_pp.append([x - s, cy - c, x + s, cy + c])
+
+        # get edge points to cover character heatmaps
+        isSppFound, isEppFound = False, False
+        grad_s = (pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0]) + (pp[2][1] - pp[1][1]) / (pp[2][0] - pp[1][0])
+        grad_e = (pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0]) + (pp[-3][1] - pp[-2][1]) / (pp[-3][0] - pp[-2][0])
+        for r in np.arange(0.5, max_r, step_r):
+            dx = 2 * half_char_h * r
+            if not isSppFound:
+                line_img = np.zeros(word_label.shape, dtype=np.uint8)
+                dy = grad_s * dx
+                p = np.array(new_pp[0]) - np.array([dx, dy, dx, dy])
+                cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
+                if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
+                    spp = p
+                    isSppFound = True
+            if not isEppFound:
+                line_img = np.zeros(word_label.shape, dtype=np.uint8)
+                dy = grad_e * dx
+                p = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy])
+                cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
+                if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
+                    epp = p
+                    isEppFound = True
+            if isSppFound and isEppFound:
+                break
+
+        # pass if boundary of polygon is not found
+        if not (isSppFound and isEppFound):
+            polys.append(None); continue
+
+        # make final polygon
+        poly = []
+        poly.append(warpCoord(Minv, (spp[0], spp[1])))
+        for p in new_pp:
+            poly.append(warpCoord(Minv, (p[0], p[1])))
+        poly.append(warpCoord(Minv, (epp[0], epp[1])))
+        poly.append(warpCoord(Minv, (epp[2], epp[3])))
+        for p in reversed(new_pp):
+            poly.append(warpCoord(Minv, (p[2], p[3])))
+        poly.append(warpCoord(Minv, (spp[2], spp[3])))
+
+        # add to final result
+        polys.append(np.array(poly))
+
+    return polys
+
+def getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text, poly=False):
+    boxes, labels, mapper = getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text)
+
+    if poly:
+        polys = getPoly_core(boxes, labels, mapper, linkmap)
+    else:
+        polys = [None] * len(boxes)
+
+    return boxes, polys
+
+def adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net = 2):
+    if len(polys) > 0:
+        polys = np.array(polys)
+        for k in range(len(polys)):
+            if polys[k] is not None:
+                polys[k] *= (ratio_w * ratio_net, ratio_h * ratio_net)
+    return polys

+ 64 - 0
CRAFT/fp16.py

@@ -0,0 +1,64 @@
+# -*- coding: utf-8 -*-
+import torch
+from torch import nn
+from torch.autograd import Variable
+from torch.nn.parameter import Parameter
+
+FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
+HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
+
+
+def conversion_helper(val, conversion):
+    """Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure."""
+    if not isinstance(val, (tuple, list)):
+        return conversion(val)
+    rtn = [conversion_helper(v, conversion) for v in val]
+    if isinstance(val, tuple):
+        rtn = tuple(rtn)
+    return rtn
+
+
+def fp32_to_fp16(val):
+    """Convert fp32 `val` to fp16"""
+    def half_conversion(val):
+        val_typecheck = val
+        if isinstance(val_typecheck, (Parameter, Variable)):
+            val_typecheck = val.data
+        if isinstance(val_typecheck, FLOAT_TYPES):
+            val = val.half()
+        return val
+    return conversion_helper(val, half_conversion)
+
+
+def fp16_to_fp32(val):
+    """Convert fp16 `val` to fp32"""
+    def float_conversion(val):
+        val_typecheck = val
+        if isinstance(val_typecheck, (Parameter, Variable)):
+            val_typecheck = val.data
+        if isinstance(val_typecheck, HALF_TYPES):
+            val = val.float()
+        return val
+    return conversion_helper(val, float_conversion)
+
+
+class FP16Module(nn.Module):
+    def __init__(self, module):
+        super(FP16Module, self).__init__()
+        self.add_module('module', module.half())
+
+    def forward(self, *inputs, **kwargs):
+        return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs))
+
+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+        return self.module.state_dict(destination, prefix, keep_vars)
+
+    def load_state_dict(self, state_dict, strict=True):
+        self.module.load_state_dict(state_dict, strict=strict)
+
+    def get_param(self, item):
+        return self.module.get_param(item)
+
+    def to(self, device, *args, **kwargs):
+        self.module.to(device)
+        return super().to(device, *args, **kwargs)

+ 70 - 0
CRAFT/imgproc.py

@@ -0,0 +1,70 @@
+"""  
+Copyright (c) 2019-present NAVER Corp.
+MIT License
+"""
+
+# -*- coding: utf-8 -*-
+import numpy as np
+from skimage import io
+import cv2
+
+def loadImage(img_file):
+    img = io.imread(img_file)           # RGB order
+    if img.shape[0] == 2: img = img[0]
+    if len(img.shape) == 2 : img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+    if img.shape[2] == 4:   img = img[:,:,:3]
+    img = np.array(img)
+
+    return img
+
+def normalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)):
+    # should be RGB order
+    img = in_img.copy().astype(np.float32)
+
+    img -= np.array([mean[0] * 255.0, mean[1] * 255.0, mean[2] * 255.0], dtype=np.float32)
+    img /= np.array([variance[0] * 255.0, variance[1] * 255.0, variance[2] * 255.0], dtype=np.float32)
+    return img
+
+def denormalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)):
+    # should be RGB order
+    img = in_img.copy()
+    img *= variance
+    img += mean
+    img *= 255.0
+    img = np.clip(img, 0, 255).astype(np.uint8)
+    return img
+
+def resize_aspect_ratio(img, square_size, interpolation, mag_ratio=1):
+    height, width, channel = img.shape
+
+    # magnify image size
+    target_size = mag_ratio * max(height, width)
+
+    # set original image size
+    if target_size > square_size:
+        target_size = square_size
+    
+    ratio = target_size / max(height, width)    
+
+    target_h, target_w = int(height * ratio), int(width * ratio)
+    proc = cv2.resize(img, (target_w, target_h), interpolation = interpolation)
+
+
+    # make canvas and paste image
+    target_h32, target_w32 = target_h, target_w
+    if target_h % 32 != 0:
+        target_h32 = target_h + (32 - target_h % 32)
+    if target_w % 32 != 0:
+        target_w32 = target_w + (32 - target_w % 32)
+    resized = np.zeros((target_h32, target_w32, channel), dtype=np.float32)
+    resized[0:target_h, 0:target_w, :] = proc
+    target_h, target_w = target_h32, target_w32
+
+    size_heatmap = (int(target_w/2), int(target_h/2))
+
+    return resized, ratio, size_heatmap
+
+def cvt2HeatmapImg(img):
+    img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
+    img = cv2.applyColorMap(img, cv2.COLORMAP_JET)
+    return img

+ 144 - 0
CRAFT/model.py

@@ -0,0 +1,144 @@
+from typing import List, Tuple, Optional
+import os
+import time
+import torch
+from torch.autograd import Variable
+from PIL import Image
+import numpy as np
+import cv2
+from huggingface_hub import hf_hub_url, cached_download
+
+from CRAFT.craft import CRAFT, init_CRAFT_model
+from CRAFT.refinenet import RefineNet, init_refiner_model
+import CRAFT.craft_utils as craft_utils
+import CRAFT.imgproc as imgproc
+
+
+HF_MODELS = {
+    'craft': dict(
+        repo_id='boomb0om/CRAFT-text-detector',
+        filename='craft_mlt_25k.pth',
+    ),
+    'refiner': dict(
+        repo_id='boomb0om/CRAFT-text-detector',
+        filename='craft_refiner_CTW1500.pth',
+    )
+}
+
+    
+def preprocess_image(image: np.ndarray, canvas_size: int, mag_ratio: bool):
+    # resize
+    img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
+        image, canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio
+    )
+    ratio_h = ratio_w = 1 / target_ratio
+
+    # preprocessing
+    x = imgproc.normalizeMeanVariance(img_resized)
+    x = torch.from_numpy(x).permute(2, 0, 1)    # [h, w, c] to [c, h, w]
+    x = Variable(x.unsqueeze(0))                # [c, h, w] to [b, c, h, w]
+    return x, ratio_w, ratio_h
+
+
+class CRAFTModel:
+    
+    def __init__(
+        self, 
+        cache_dir: str,
+        device: torch.device,
+        local_files_only: bool = False,
+        use_refiner: bool = True,
+        fp16: bool = True,
+        canvas_size: int = 1280,
+        mag_ratio: float = 1.5,
+        text_threshold: float = 0.7,
+        link_threshold: float = 0.4,
+        low_text: float = 0.4
+    ):
+        self.cache_dir = cache_dir
+        self.use_refiner = use_refiner
+        self.device = device
+        self.fp16 = fp16
+        
+        self.canvas_size = canvas_size
+        self.mag_ratio = mag_ratio
+        self.text_threshold = text_threshold
+        self.link_threshold = link_threshold
+        self.low_text = low_text
+        
+        # loading models
+        paths = {}
+        for model_name in ['craft', 'refiner']:
+            config = HF_MODELS[model_name]
+            paths[model_name] = os.path.join(cache_dir, config['filename'])
+            if not local_files_only:
+                config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
+                cached_download(config_file_url, cache_dir=cache_dir, force_filename=config['filename'])
+            
+        self.net = init_CRAFT_model(paths['craft'], device, fp16=fp16)
+        if self.use_refiner:
+            self.refiner = init_refiner_model(paths['refiner'], device)
+        else:
+            self.refiner = None
+        
+    def get_text_map(self, x: torch.Tensor, ratio_w: int, ratio_h: int) -> Tuple[np.ndarray, np.ndarray]:
+        x = x.to(self.device)
+
+        # forward pass
+        with torch.no_grad():
+            y, feature = self.net(x)
+
+        # make score and link map
+        score_text = y[0,:,:,0].cpu().data.numpy()
+        score_link = y[0,:,:,1].cpu().data.numpy()
+
+        # refine link
+        if self.refiner:
+            with torch.no_grad():
+                y_refiner = self.refiner(y, feature)
+            score_link = y_refiner[0,:,:,0].cpu().data.numpy()
+            
+        return score_text, score_link
+
+    def get_polygons(self, image: Image.Image) -> List[List[List[int]]]:
+        x, ratio_w, ratio_h = preprocess_image(np.array(image), self.canvas_size, self.mag_ratio)
+        
+        score_text, score_link = self.get_text_map(x, ratio_w, ratio_h)
+        
+        # Post-processing
+        boxes, polys = craft_utils.getDetBoxes(
+            score_text, score_link, 
+            self.text_threshold, self.link_threshold, 
+            self.low_text, True
+        )
+        
+        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
+        polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
+        for k in range(len(polys)):
+            if polys[k] is None: polys[k] = boxes[k]
+
+        res = []
+        for poly in polys:
+            res.append(poly.astype(np.int32).tolist())
+        return res
+    
+    def get_boxes(self, image: Image.Image) -> List[List[List[int]]]:
+        x, ratio_w, ratio_h = preprocess_image(np.array(image), self.canvas_size, self.mag_ratio)
+        
+        score_text, score_link = self.get_text_map(x, ratio_w, ratio_h)
+        
+        # Post-processing
+        boxes, polys = craft_utils.getDetBoxes(
+            score_text, score_link, 
+            self.text_threshold, self.link_threshold, 
+            self.low_text, False
+        )
+        
+        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
+        boxes_final = []
+        if len(boxes)>0:
+            boxes = boxes.astype(np.int32).tolist()
+            for box in boxes:
+                boxes_final.append([box[0], box[2]])
+
+        return boxes_final

+ 75 - 0
CRAFT/refinenet.py

@@ -0,0 +1,75 @@
+"""  
+Copyright (c) 2019-present NAVER Corp.
+MIT License
+"""
+
+# -*- coding: utf-8 -*-
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+from CRAFT.basenet.vgg16_bn import init_weights
+from CRAFT.utils import copyStateDict
+
+
+class RefineNet(nn.Module):
+    
+    def __init__(self):
+        super(RefineNet, self).__init__()
+
+        self.last_conv = nn.Sequential(
+            nn.Conv2d(34, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
+            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
+            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)
+        )
+
+        self.aspp1 = nn.Sequential(
+            nn.Conv2d(64, 128, kernel_size=3, dilation=6, padding=6), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
+            nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
+            nn.Conv2d(128, 1, kernel_size=1)
+        )
+
+        self.aspp2 = nn.Sequential(
+            nn.Conv2d(64, 128, kernel_size=3, dilation=12, padding=12), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
+            nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
+            nn.Conv2d(128, 1, kernel_size=1)
+        )
+
+        self.aspp3 = nn.Sequential(
+            nn.Conv2d(64, 128, kernel_size=3, dilation=18, padding=18), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
+            nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
+            nn.Conv2d(128, 1, kernel_size=1)
+        )
+
+        self.aspp4 = nn.Sequential(
+            nn.Conv2d(64, 128, kernel_size=3, dilation=24, padding=24), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
+            nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
+            nn.Conv2d(128, 1, kernel_size=1)
+        )
+
+        init_weights(self.last_conv.modules())
+        init_weights(self.aspp1.modules())
+        init_weights(self.aspp2.modules())
+        init_weights(self.aspp3.modules())
+        init_weights(self.aspp4.modules())
+
+    def forward(self, y, upconv4):
+        refine = torch.cat([y.permute(0,3,1,2), upconv4], dim=1)
+        refine = self.last_conv(refine)
+
+        aspp1 = self.aspp1(refine)
+        aspp2 = self.aspp2(refine)
+        aspp3 = self.aspp3(refine)
+        aspp4 = self.aspp4(refine)
+
+        #out = torch.add([aspp1, aspp2, aspp3, aspp4], dim=1)
+        out = aspp1 + aspp2 + aspp3 + aspp4
+        return out.permute(0, 2, 3, 1)  # , refine.permute(0,2,3,1)
+    
+    
+def init_refiner_model(chekpoint_path: str, device: torch.device) -> RefineNet:
+    refine_net = RefineNet()
+    refine_net.load_state_dict(copyStateDict(torch.load(chekpoint_path, map_location=torch.device('cpu'))))
+    refine_net = refine_net.to(device)
+    refine_net.eval()
+    return refine_net

+ 68 - 0
CRAFT/utils.py

@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+from typing import List
+import os
+from PIL import Image
+import numpy as np
+import cv2
+from shapely.geometry import Polygon
+from collections import OrderedDict
+
+
+def str2bool(v: str) -> bool:
+    return v.lower() in ("yes", "y", "true", "t", "1")
+
+
+def copyStateDict(state_dict):
+    if list(state_dict.keys())[0].startswith("module"):
+        start_idx = 1
+    else:
+        start_idx = 0
+    new_state_dict = OrderedDict()
+    for k, v in state_dict.items():
+        name = ".".join(k.split(".")[start_idx:])
+        new_state_dict[name] = v
+    return new_state_dict
+
+
+def draw_boxes(image: Image.Image, boxes: List[List[List[int]]], line_thickness: int = 2) -> Image.Image:
+    img = np.array(image)
+    for i, box in enumerate(boxes):
+        poly_ = np.array(box_to_poly(box)).astype(np.int32).reshape((-1))
+        poly_ = poly_.reshape(-1, 2)
+        cv2.polylines(img, [poly_.reshape((-1, 1, 2))], True, color=(0, 0, 255), thickness=line_thickness)
+        ptColor = (0, 255, 255)
+    return Image.fromarray(img)
+
+
+def draw_polygons(image: Image.Image, polygons: List[List[List[int]]], line_thickness: int = 2) -> Image.Image:
+    img = np.array(image)
+    for i, poly in enumerate(polygons):
+        poly_ = np.array(poly).astype(np.int32).reshape((-1))
+        poly_ = poly_.reshape(-1, 2)
+        cv2.polylines(img, [poly_.reshape((-1, 1, 2))], True, color=(0, 0, 255), thickness=line_thickness)
+        ptColor = (0, 255, 255)
+    return Image.fromarray(img)
+
+
+def box_to_poly(box: List[List[int]]) -> List[List[int]]:
+    return [box[0], [box[0][0], box[1][1]], box[1], [box[1][0], box[0][1]]]
+
+
+def boxes_area(img: Image.Image, bboxes: List[List[List[int]]]) -> int:
+    img_s = img.size[0]*img.size[1]
+    total_S = 0
+    for box in bboxes:
+        pgon = Polygon(box_to_poly(box)) 
+        S = pgon.area
+        total_S+=S
+    return total_S/img_s
+
+
+def polygons_area(img: Image.Image, polygons: List[List[List[int]]]) -> int:
+    img_s = img.size[0]*img.size[1]
+    total_S = 0
+    for poly in polygons:
+        pgon = Polygon(poly) 
+        S = pgon.area
+        total_S+=S
+    return total_S/img_s

+ 30 - 1
README.md

@@ -1,2 +1,31 @@
 # CRAFT-text-detection
-An unofficial PyTorch implementation of CRAFT text detector with better and more user-friendly interface
+
+An unofficial PyTorch implementation of CRAFT text detector with better interface and fp16 support
+
+> This is not an official implementation. I partially use code from the [original repository](https://github.com/clovaai/CRAFT-pytorch)
+
+Main features of this implementation:
+- User-friendly interface 
+- Easier to integrate this model in your project
+- fp16 inference support
+- Automatically downloading weights from [huggingface](https://huggingface.co/boomb0om/CRAFT-text-detector/tree/main)
+
+## Installation
+
+```bash
+git clone https://github.com/boomb0om/CRAFT-text-detection
+cd CRAFT-text-detection/
+pip install -r requirements.txt
+```
+
+To test the model you can run `test.py` file.
+
+## Examples
+
+You can find usage examples in [example.ipynb](example.ipynb)
+
+![](images/cafe_sign.jpg)
+
+Detected polygons:
+
+![](images/result.jpg)

ファイルの差分が大きいため隠しています
+ 81 - 0
example.ipynb


BIN
images/cafe_sign.jpg


BIN
images/result.jpg


+ 5 - 0
requirements.txt

@@ -0,0 +1,5 @@
+torch
+numpy
+opencv-python
+shapely
+huggingface_hub

+ 15 - 0
test.py

@@ -0,0 +1,15 @@
+import torch
+from PIL import Image
+import numpy as np
+from CRAFT import CRAFTModel, draw_polygons
+
+
+if __name__ == "__main__":
+    model = CRAFTModel('weights/', 'cuda', use_refiner=True, fp16=True)
+    
+    img = Image.open('images/cafe_sign.jpg')
+    polygons = model.get_polygons(img)
+    
+    result = draw_polygons(img, polygons)
+    result.save('images/result.jpg')
+    print(f'Result saved to: images/result.jpg')

この差分においてかなりの量のファイルが変更されているため、一部のファイルを表示していません