Răsfoiți Sursa

Merge pull request #6 from mowshon/patch-1

Fix the TypeError: expected CUDA (got CPU) error
Daniel Gatis 5 ani în urmă
părinte
comite
17336f660f
1 a modificat fișierele cu 1 adăugiri și 1 ștergeri
  1. 1 1
      src/rembg/u2net/detect.py

+ 1 - 1
src/rembg/u2net/detect.py

@@ -107,7 +107,7 @@ def predict(net, item):
     with torch.no_grad():
     with torch.no_grad():
 
 
         if torch.cuda.is_available():
         if torch.cuda.is_available():
-            inputs_test = torch.cuda.FloatTensor(sample["image"].unsqueeze(0).float())
+            inputs_test = torch.cuda.FloatTensor(sample["image"].unsqueeze(0).cuda().float())
         else:
         else:
             inputs_test = torch.FloatTensor(sample["image"].unsqueeze(0).float())
             inputs_test = torch.FloatTensor(sample["image"].unsqueeze(0).float())