Daniel Gatis 2 years ago
parent
commit
2ef798d152

+ 4 - 0
rembg/sessions/base.py

@@ -46,6 +46,10 @@ class BaseSession:
     def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
         raise NotImplementedError
 
+    @classmethod
+    def checksum_disabled(cls, *args, **kwargs):
+        return os.getenv("MODEL_CHECKSUM_DISABLED", None) != None
+
     @classmethod
     def u2net_home(cls, *args, **kwargs):
         return os.path.expanduser(

+ 4 - 2
rembg/sessions/dis.py

@@ -34,9 +34,11 @@ class DisSession(BaseSession):
         fname = f"{cls.name()}.onnx"
         pooch.retrieve(
             "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx",
-            "md5:fc16ebd8b0c10d971d3513d564d01e29",
+            None
+            if cls.checksum_disabled(*args, **kwargs)
+            else "md5:fc16ebd8b0c10d971d3513d564d01e29",
             fname=fname,
-            path=cls.u2net_home(),
+            path=cls.u2net_home(*args, **kwargs),
             progressbar=True,
         )
 

+ 8 - 4
rembg/sessions/sam.py

@@ -141,17 +141,21 @@ class SamSession(BaseSession):
 
         pooch.retrieve(
             "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
-            "md5:13d97c5c79ab13ef86d67cbde5f1b250",
+            None
+            if cls.checksum_disabled(*args, **kwargs)
+            else "md5:13d97c5c79ab13ef86d67cbde5f1b250",
             fname=fname_encoder,
-            path=cls.u2net_home(),
+            path=cls.u2net_home(*args, **kwargs),
             progressbar=True,
         )
 
         pooch.retrieve(
             "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx",
-            "md5:fa3d1c36a3187d3de1c8deebf33dd127",
+            None
+            if cls.checksum_disabled(*args, **kwargs)
+            else "md5:fa3d1c36a3187d3de1c8deebf33dd127",
             fname=fname_decoder,
-            path=cls.u2net_home(),
+            path=cls.u2net_home(*args, **kwargs),
             progressbar=True,
         )
 

+ 4 - 2
rembg/sessions/silueta.py

@@ -36,9 +36,11 @@ class SiluetaSession(BaseSession):
         fname = f"{cls.name()}.onnx"
         pooch.retrieve(
             "https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx",
-            "md5:55e59e0d8062d2f5d013f4725ee84782",
+            None
+            if cls.checksum_disabled(*args, **kwargs)
+            else "md5:55e59e0d8062d2f5d013f4725ee84782",
             fname=fname,
-            path=cls.u2net_home(),
+            path=cls.u2net_home(*args, **kwargs),
             progressbar=True,
         )
 

+ 4 - 2
rembg/sessions/u2net.py

@@ -36,9 +36,11 @@ class U2netSession(BaseSession):
         fname = f"{cls.name()}.onnx"
         pooch.retrieve(
             "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx",
-            "md5:60024c5c889badc19c04ad937298a77b",
+            None
+            if cls.checksum_disabled(*args, **kwargs)
+            else "md5:60024c5c889badc19c04ad937298a77b",
             fname=fname,
-            path=cls.u2net_home(),
+            path=cls.u2net_home(*args, **kwargs),
             progressbar=True,
         )
 

+ 4 - 2
rembg/sessions/u2net_cloth_seg.py

@@ -97,9 +97,11 @@ class Unet2ClothSession(BaseSession):
         fname = f"{cls.name()}.onnx"
         pooch.retrieve(
             "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx",
-            "md5:2434d1f3cb744e0e49386c906e5a08bb",
+            None
+            if cls.checksum_disabled(*args, **kwargs)
+            else "md5:2434d1f3cb744e0e49386c906e5a08bb",
             fname=fname,
-            path=cls.u2net_home(),
+            path=cls.u2net_home(*args, **kwargs),
             progressbar=True,
         )
 

+ 4 - 2
rembg/sessions/u2net_human_seg.py

@@ -36,9 +36,11 @@ class U2netHumanSegSession(BaseSession):
         fname = f"{cls.name()}.onnx"
         pooch.retrieve(
             "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx",
-            "md5:c09ddc2e0104f800e3e1bb4652583d1f",
+            None
+            if cls.checksum_disabled(*args, **kwargs)
+            else "md5:c09ddc2e0104f800e3e1bb4652583d1f",
             fname=fname,
-            path=cls.u2net_home(),
+            path=cls.u2net_home(*args, **kwargs),
             progressbar=True,
         )
 

+ 4 - 2
rembg/sessions/u2netp.py

@@ -36,9 +36,11 @@ class U2netpSession(BaseSession):
         fname = f"{cls.name()}.onnx"
         pooch.retrieve(
             "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx",
-            "md5:8e83ca70e441ab06c318d82300c84806",
+            None
+            if cls.checksum_disabled(*args, **kwargs)
+            else "md5:8e83ca70e441ab06c318d82300c84806",
             fname=fname,
-            path=cls.u2net_home(),
+            path=cls.u2net_home(*args, **kwargs),
             progressbar=True,
         )