Procházet zdrojové kódy

add trained model on IC15 and link refiner

root před 6 roky
rodič
revize
3cd65f5a7b
4 změnil soubory, kde provedl 106 přidání a 6 odebrání
  1. 13 1
      README.md
  2. 1 1
      craft_utils.py
  3. 65 0
      refinenet.py
  4. 27 4
      test.py

+ 13 - 1
README.md

@@ -15,6 +15,7 @@ PyTorch implementation for CRAFT text detector that effectively detect text area
 ## Updates
 **13 Jun, 2019**: Initial update
 **20 Jul, 2019**: Added post-processing for polygon result
+**28 Sep, 2019**: Added the trained model on IC15 and the link refiner
 
 
 ## Getting started
@@ -33,7 +34,14 @@ The code for training is not included in this repository, and we cannot release
 
 
 ### Test instruction using pretrained model
-- Download [Trained Model on IC13,IC17](https://drive.google.com/open?id=1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ)
+- Download the trained models
+ 
+ *Model name* | *Used datasets* | *Languages* | *Purpose* | *Model Link* |
+ | :--- | :--- | :--- | :--- | :--- |
+General | SynthText, IC13, IC17 | Eng + MLT | For general purpose | [Click](https://drive.google.com/open?id=1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ)
+IC15 | SynthText, IC15 | Eng | For IC15 only | [Click](https://drive.google.com/open?id=1i2R7UIUqmkUtF0jv_3MXTqmQ_9wuAnLf)
+LinkRefiner | CTW1500 | - | Used with the General Model | [Click](https://drive.google.com/open?id=1XSaFwBkOaFOdtk4Ane3DFyJGPRw6v5bO)
+
 * Run with pretrained model
 ``` (with python 3.7)
 python test.py --trained_model=[weightfile] --test_folder=[folder path to test images]
@@ -46,11 +54,15 @@ The result image and score maps will be saved to `./result` by default.
 * `--text_threshold`: text confidence threshold
 * `--low_text`: text low-bound score
 * `--link_threshold`: link confidence threshold
+* `--cuda`: use cuda for inference (default:True)
 * `--canvas_size`: max image size for inference
 * `--mag_ratio`: image magnification ratio
 * `--poly`: enable polygon type result
 * `--show_time`: show processing time
 * `--test_folder`: folder path to input images
+* `--refine`: use link refiner for sentence-level dataset
+* `--refiner_model`: pretrained refiner model
+
 
 ## Links
 - WebDemo : https://demo.ocr.clova.ai/

+ 1 - 1
craft_utils.py

@@ -90,7 +90,7 @@ def getPoly_core(boxes, labels, mapper, linkmap):
     for k, box in enumerate(boxes):
         # size filter for small instance
         w, h = int(np.linalg.norm(box[0] - box[1]) + 1), int(np.linalg.norm(box[1] - box[2]) + 1)
-        if w < 30 or h < 30:
+        if w < 10 or h < 10:
             polys.append(None); continue
 
         # warp image

+ 65 - 0
refinenet.py

@@ -0,0 +1,65 @@
+"""  
+Copyright (c) 2019-present NAVER Corp.
+MIT License
+"""
+
+# -*- coding: utf-8 -*-
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+from basenet.vgg16_bn import init_weights
+
+
+class RefineNet(nn.Module):
+    """Link-refinement head: a small conv stack followed by four dilated branches.
+
+    The module concatenates two inputs along the channel axis, passes them
+    through three 3x3 conv/BN/ReLU layers (``last_conv``), then feeds the
+    result to four parallel branches (``aspp1`` .. ``aspp4``) that differ only
+    in the dilation of their first 3x3 convolution (6, 12, 18, 24). Each
+    branch ends in a single-channel 1x1 convolution and the four outputs are
+    summed element-wise.
+    """
+
+    def __init__(self):
+        """Build the conv stack and the four dilated branches, then initialise them."""
+        super(RefineNet, self).__init__()
+
+        # Shared trunk. The first conv expects 34 input channels, i.e. the
+        # channel count of the tensor concatenated in forward().
+        self.last_conv = nn.Sequential(
+            nn.Conv2d(34, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
+            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
+            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)
+        )
+
+        # Each branch below uses padding == dilation on its 3x3 conv, so the
+        # spatial size of the feature map is preserved (stride is the default 1).
+        # The names suggest atrous spatial pyramid pooling -- presumably
+        # DeepLab-style; confirm against the paper if it matters.
+
+        # Branch 1: dilation 6.
+        self.aspp1 = nn.Sequential(
+            nn.Conv2d(64, 128, kernel_size=3, dilation=6, padding=6), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
+            nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
+            nn.Conv2d(128, 1, kernel_size=1)
+        )
+
+        # Branch 2: dilation 12.
+        self.aspp2 = nn.Sequential(
+            nn.Conv2d(64, 128, kernel_size=3, dilation=12, padding=12), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
+            nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
+            nn.Conv2d(128, 1, kernel_size=1)
+        )
+
+        # Branch 3: dilation 18.
+        self.aspp3 = nn.Sequential(
+            nn.Conv2d(64, 128, kernel_size=3, dilation=18, padding=18), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
+            nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
+            nn.Conv2d(128, 1, kernel_size=1)
+        )
+
+        # Branch 4: dilation 24.
+        self.aspp4 = nn.Sequential(
+            nn.Conv2d(64, 128, kernel_size=3, dilation=24, padding=24), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
+            nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
+            nn.Conv2d(128, 1, kernel_size=1)
+        )
+
+        # Initialise every sub-module with the project's init_weights helper
+        # (imported from basenet.vgg16_bn; its exact scheme is not shown here).
+        init_weights(self.last_conv.modules())
+        init_weights(self.aspp1.modules())
+        init_weights(self.aspp2.modules())
+        init_weights(self.aspp3.modules())
+        init_weights(self.aspp4.modules())
+
+    def forward(self, y, upconv4):
+        """Return the summed single-channel output of the four dilated branches.
+
+        Args:
+            y: 4-D tensor whose last axis is moved to the channel position
+                before concatenation -- assumes a channels-last
+                (batch, H, W, C) layout; TODO confirm against the caller.
+            upconv4: 4-D tensor concatenated with ``y`` along dim 1. Its
+                channel count plus that of ``y`` must equal 34, the input
+                width of ``last_conv``.
+
+        Returns:
+            The element-wise sum of the four branch outputs, permuted with
+            (0, 2, 3, 1) so the single output channel is the last axis.
+        """
+        # Move y's last axis to dim 1 so both inputs can be joined on channels.
+        refine = torch.cat([y.permute(0,3,1,2), upconv4], dim=1)
+        refine = self.last_conv(refine)
+
+        aspp1 = self.aspp1(refine)
+        aspp2 = self.aspp2(refine)
+        aspp3 = self.aspp3(refine)
+        aspp4 = self.aspp4(refine)
+
+        #out = torch.add([aspp1, aspp2, aspp3, aspp4], dim=1)
+        # Fuse the branches by plain element-wise addition.
+        out = aspp1 + aspp2 + aspp3 + aspp4
+        return out.permute(0, 2, 3, 1)  # , refine.permute(0,2,3,1)

+ 27 - 4
test.py

@@ -47,12 +47,14 @@ parser.add_argument('--trained_model', default='weights/craft_mlt_25k.pth', type
 parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold')
 parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
 parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold')
-parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda to train model')
+parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda for inference')
 parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
 parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')
 parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
 parser.add_argument('--show_time', default=False, action='store_true', help='show processing time')
 parser.add_argument('--test_folder', default='/data/', type=str, help='folder path to input images')
+parser.add_argument('--refine', default=False, action='store_true', help='enable link refiner')
+parser.add_argument('--refiner_model', default='weights/craft_refiner_CTW1500.pth', type=str, help='pretrained refiner model')
 
 args = parser.parse_args()
 
@@ -64,7 +66,7 @@ result_folder = './result/'
 if not os.path.isdir(result_folder):
     os.mkdir(result_folder)
 
-def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly):
+def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net=None):
     t0 = time.time()
 
     # resize
@@ -79,12 +81,17 @@ def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly):
         x = x.cuda()
 
     # forward pass
-    y, _ = net(x)
+    y, feature = net(x)
 
     # make score and link map
     score_text = y[0,:,:,0].cpu().data.numpy()
     score_link = y[0,:,:,1].cpu().data.numpy()
 
+    # refine link
+    if refine_net is not None:
+        y_refiner = refine_net(y, feature)
+        score_link = y_refiner[0,:,:,0].cpu().data.numpy()
+
     t0 = time.time() - t0
     t1 = time.time()
 
@@ -127,6 +134,22 @@ if __name__ == '__main__':
 
     net.eval()
 
+    # LinkRefiner
+    refine_net = None
+    if args.refine:
+        from refinenet import RefineNet
+        refine_net = RefineNet()
+        print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
+        if args.cuda:
+            refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model)))
+            refine_net = refine_net.cuda()
+            refine_net = torch.nn.DataParallel(refine_net)
+        else:
+            refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model, map_location='cpu')))
+
+        refine_net.eval()
+        args.poly = True
+
     t = time.time()
 
     # load data
@@ -134,7 +157,7 @@ if __name__ == '__main__':
         print("Test image {:d}/{:d}: {:s}".format(k+1, len(image_list), image_path), end='\r')
         image = imgproc.loadImage(image_path)
 
-        bboxes, polys, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly)
+        bboxes, polys, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly, refine_net)
 
         # save score text
         filename, file_ext = os.path.splitext(os.path.basename(image_path))