Przeglądaj źródła

Merge pull request #6 from mowshon/patch-1

Fix the TypeError: expected CUDA (got CPU) error
Daniel Gatis 5 lat temu
rodzic
commit
17336f660f
1 zmienionych plików z 1 dodań i 1 usunięć
  1. 1 1
      src/rembg/u2net/detect.py

+ 1 - 1
src/rembg/u2net/detect.py

@@ -107,7 +107,7 @@ def predict(net, item):
     with torch.no_grad():
 
         if torch.cuda.is_available():
-            inputs_test = torch.cuda.FloatTensor(sample["image"].unsqueeze(0).float())
+            inputs_test = torch.cuda.FloatTensor(sample["image"].unsqueeze(0).cuda().float())
         else:
             inputs_test = torch.FloatTensor(sample["image"].unsqueeze(0).float())