Просмотр исходного кода

fix: fix model urls and polygons

boomb0om 1 год назад
Родитель
Сommit
20f0c94d6b
2 измененных файлов с 12 добавлено и 4 удалено
  1. 8 1
      CRAFT/basenet/vgg16_bn.py
  2. 4 3
      CRAFT/model.py

+ 8 - 1
CRAFT/basenet/vgg16_bn.py

@@ -4,7 +4,13 @@ import torch
 import torch.nn as nn
 import torch.nn.init as init
 from torchvision import models
-from torchvision.models.vgg import model_urls
+#from torchvision.models.vgg import model_urls
+
+
+model_urls = {
+    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
+}
+
 
 def init_weights(modules):
     for m in modules:
@@ -19,6 +25,7 @@ def init_weights(modules):
             m.weight.data.normal_(0, 0.01)
             m.bias.data.zero_()
 
+
 class vgg16_bn(torch.nn.Module):
     def __init__(self, pretrained=True, freeze=True):
         super(vgg16_bn, self).__init__()

+ 4 - 3
CRAFT/model.py

@@ -111,11 +111,12 @@ class CRAFTModel:
             self.text_threshold, self.link_threshold, 
             self.low_text, True
         )
-        
         boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
-        polys = adjustResultCoordinates(polys, ratio_w, ratio_h)
         for k in range(len(polys)):
-            if polys[k] is None: polys[k] = boxes[k]
+            if polys[k] is None: 
+                polys[k] = boxes[k]
+            else:
+                polys[k] = adjustResultCoordinates(polys[k], ratio_w, ratio_h)
 
         res = []
         for poly in polys: