浏览代码

Merge pull request #6 from mowshon/patch-1

Fix the TypeError: expected CUDA (got CPU) error
Daniel Gatis 5 年之前
父节点
当前提交
17336f660f
共有 1 个文件被更改,包括 1 次插入1 次删除
  1. 1 1
      src/rembg/u2net/detect.py

+ 1 - 1
src/rembg/u2net/detect.py

@@ -107,7 +107,7 @@ def predict(net, item):
     with torch.no_grad():
 
         if torch.cuda.is_available():
-            inputs_test = torch.cuda.FloatTensor(sample["image"].unsqueeze(0).float())
+            inputs_test = torch.cuda.FloatTensor(sample["image"].unsqueeze(0).cuda().float())
         else:
             inputs_test = torch.FloatTensor(sample["image"].unsqueeze(0).float())