Procházet zdrojové kódy

add torch.no_grad()

Youngmin Baek před 5 roky
rodič
revize
e332dd8b71
1 změnil soubory, kde provedl 4 přidání a 2 odebrání
  1. 4 2
      test.py

+ 4 - 2
test.py

@@ -81,7 +81,8 @@ def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, r
         x = x.cuda()
 
     # forward pass
-    y, feature = net(x)
+    with torch.no_grad():
+        y, feature = net(x)
 
     # make score and link map
     score_text = y[0,:,:,0].cpu().data.numpy()
@@ -89,7 +90,8 @@ def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, r
 
     # refine link
     if refine_net is not None:
-        y_refiner = refine_net(y, feature)
+        with torch.no_grad():
+            y_refiner = refine_net(y, feature)
         score_link = y_refiner[0,:,:,0].cpu().data.numpy()
 
     t0 = time.time() - t0