فهرست منبع

add torch.no_grad()

Youngmin Baek 5 سال پیش
والد
کامیت
e332dd8b71
1فایلهای تغییر یافته به همراه4 افزوده شده و 2 حذف شده
  1. 4 2
      test.py

+ 4 - 2
test.py

@@ -81,7 +81,8 @@ def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, r
         x = x.cuda()
         x = x.cuda()
 
 
     # forward pass
     # forward pass
-    y, feature = net(x)
+    with torch.no_grad():
+        y, feature = net(x)
 
 
     # make score and link map
     # make score and link map
     score_text = y[0,:,:,0].cpu().data.numpy()
     score_text = y[0,:,:,0].cpu().data.numpy()
@@ -89,7 +90,8 @@ def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, r
 
 
     # refine link
     # refine link
     if refine_net is not None:
     if refine_net is not None:
-        y_refiner = refine_net(y, feature)
+        with torch.no_grad():
+            y_refiner = refine_net(y, feature)
         score_link = y_refiner[0,:,:,0].cpu().data.numpy()
         score_link = y_refiner[0,:,:,0].cpu().data.numpy()
 
 
     t0 = time.time() - t0
     t0 = time.time() - t0