ソースを参照

add torch.no_grad()

Youngmin Baek 5 年 前
コミット
e332dd8b71
1 ファイル変更4 行追加2 行削除
  1. 4 2
      test.py

+ 4 - 2
test.py

@@ -81,7 +81,8 @@ def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, r
         x = x.cuda()
 
     # forward pass
-    y, feature = net(x)
+    with torch.no_grad():
+        y, feature = net(x)
 
     # make score and link map
     score_text = y[0,:,:,0].cpu().data.numpy()
@@ -89,7 +90,8 @@ def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, r
 
     # refine link
     if refine_net is not None:
-        y_refiner = refine_net(y, feature)
+        with torch.no_grad():
+            y_refiner = refine_net(y, feature)
         score_link = y_refiner[0,:,:,0].cpu().data.numpy()
 
     t0 = time.time() - t0