|
@@ -13,7 +13,10 @@ from .session_cloth import ClothSession
|
|
from .session_simple import SimpleSession
|
|
from .session_simple import SimpleSession
|
|
|
|
|
|
|
|
|
|
-def new_session(model_name: str = "u2net") -> BaseSession:
|
|
|
|
|
|
+def new_session(model_name: str = "u2net", output_size=None) -> BaseSession:
|
|
|
|
+ # Set output size if not set ( because isnet hat a different size )
|
|
|
|
+ output_size = output_size or (320, 320)
|
|
|
|
+
|
|
session_class: Type[BaseSession]
|
|
session_class: Type[BaseSession]
|
|
md5 = "60024c5c889badc19c04ad937298a77b"
|
|
md5 = "60024c5c889badc19c04ad937298a77b"
|
|
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx"
|
|
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx"
|
|
@@ -39,6 +42,10 @@ def new_session(model_name: str = "u2net") -> BaseSession:
|
|
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx"
|
|
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx"
|
|
)
|
|
)
|
|
session_class = SimpleSession
|
|
session_class = SimpleSession
|
|
|
|
+ elif model_name == "isnet-general-use":
|
|
|
|
+ md5 = "fc16ebd8b0c10d971d3513d564d01e29"
|
|
|
|
+ url = "https://github.com/Flippchen/rembg/releases/download/test/isnet-general-use.onnx"
|
|
|
|
+ session_class = SimpleSession
|
|
|
|
|
|
u2net_home = os.getenv(
|
|
u2net_home = os.getenv(
|
|
"U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
|
|
"U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
|
|
@@ -68,4 +75,5 @@ def new_session(model_name: str = "u2net") -> BaseSession:
|
|
providers=ort.get_available_providers(),
|
|
providers=ort.get_available_providers(),
|
|
sess_options=sess_opts,
|
|
sess_options=sess_opts,
|
|
),
|
|
),
|
|
|
|
+ output_size=output_size
|
|
)
|
|
)
|