瀏覽代碼

fix docker

Daniel Gatis 2 年之前
父節點
當前提交
ccaa9005af
共有 2 個文件被更改,包括 6 次插入2 次删除
  1. 5 0
      rembg/sessions/base.py
  2. 1 2
      rembg/sessions/u2net_custom.py

+ 5 - 0
rembg/sessions/base.py

@@ -28,6 +28,11 @@ class BaseSession:
         else:
             self.providers.extend(_providers)
 
+        model_path = kwargs.get("model_path")
+
+        if model_path is None:
+            raise ValueError("model_path is required")
+
         self.inner_session = ort.InferenceSession(
             str(self.__class__.download_models(*args, **kwargs)),
             providers=self.providers,

+ 1 - 2
rembg/sessions/u2net_custom.py

@@ -34,9 +34,8 @@ class U2netCustomSession(BaseSession):
     @classmethod
     def download_models(cls, *args, **kwargs):
         model_path = kwargs.get("model_path")
-
         if model_path is None:
-            raise ValueError("model_path is required")
+            return
 
         return os.path.abspath(os.path.expanduser(model_path))