Browse Source

added support for isnet model

Flippchen 2 years ago
parent
commit
3c20805f85
3 changed files with 17 additions and 6 deletions
  1. 3 3
      rembg/session_base.py
  2. 5 1
      rembg/session_factory.py
  3. 9 2
      rembg/session_simple.py

+ 3 - 3
rembg/session_base.py

@@ -7,18 +7,18 @@ from PIL.Image import Image as PILImage
 
 
 class BaseSession:
-    def __init__(self, model_name: str, inner_session: ort.InferenceSession):
+    def __init__(self, model_name: str, inner_session: ort.InferenceSession, output_size: Tuple[int, int] = (320, 320)):
         self.model_name = model_name
         self.inner_session = inner_session
+        self.output_size = output_size
 
     def normalize(
         self,
         img: PILImage,
         mean: Tuple[float, float, float],
         std: Tuple[float, float, float],
-        size: Tuple[int, int],
     ) -> Dict[str, np.ndarray]:
-        im = img.convert("RGB").resize(size, Image.LANCZOS)
+        im = img.convert("RGB").resize(self.output_size, Image.LANCZOS)
 
         im_ary = np.array(im)
         im_ary = im_ary / np.max(im_ary)

+ 5 - 1
rembg/session_factory.py

@@ -13,7 +13,10 @@ from .session_cloth import ClothSession
 from .session_simple import SimpleSession
 
 
-def new_session(model_name: str = "u2net") -> BaseSession:
+def new_session(model_name: str = "u2net", output_size=None) -> BaseSession:
+    # Set output size if not set (because isnet has a different size)
+    output_size = output_size or (320, 320)
+
     session_class: Type[BaseSession]
     md5 = "60024c5c889badc19c04ad937298a77b"
     url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx"
@@ -68,4 +71,5 @@ def new_session(model_name: str = "u2net") -> BaseSession:
             providers=ort.get_available_providers(),
             sess_options=sess_opts,
         ),
+        output_size=output_size
     )

+ 9 - 2
rembg/session_simple.py

@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Tuple
 
 import numpy as np
 from PIL import Image
@@ -9,10 +9,17 @@ from .session_base import BaseSession
 
 class SimpleSession(BaseSession):
     def predict(self, img: PILImage) -> List[PILImage]:
+        if self.model_name == "isnet-general-use":
+            mean = (0.5, 0.5, 0.5)
+            std = (1., 1., 1.)
+        else:
+            mean = (0.485, 0.456, 0.406)
+            std = (0.229, 0.224, 0.225)
+
         ort_outs = self.inner_session.run(
             None,
             self.normalize(
-                img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
+                img, mean, std
             ),
         )