浏览代码

edited session factory to support models with more than one model

Flippchen 2 年之前
父节点
当前提交
106254c42d
共有 1 个文件被更改,包括 44 次插入13 次删除
  1. 44 13
      rembg/session_factory.py

+ 44 - 13
rembg/session_factory.py

@@ -12,9 +12,29 @@ from .session_base import BaseSession
 from .session_cloth import ClothSession
 from .session_dis import DisSession
 from .session_simple import SimpleSession
+from .session_sam import SamSession
+
+
+def download_model(url: str, md5: str, fname: str, path: Path):
+    pooch.retrieve(
+        url,
+        f"md5:{md5}",
+        fname=fname,
+        path=path,
+        progressbar=True,
+    )
 
 
 def new_session(model_name: str = "u2net") -> BaseSession:
+    # Define the model path
+    u2net_home = os.getenv(
+        "U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
+    )
+
+    fname = f"{model_name}.onnx"
+    path = Path(u2net_home).expanduser()
+    full_path = Path(u2net_home).expanduser() / fname
+
     session_class: Type[BaseSession]
     md5 = "60024c5c889badc19c04ad937298a77b"
     url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx"
@@ -44,22 +64,33 @@ def new_session(model_name: str = "u2net") -> BaseSession:
         md5 = "fc16ebd8b0c10d971d3513d564d01e29"
         url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx"
         session_class = DisSession
+    elif model_name == "SAM":
+        path = Path(u2net_home).expanduser()
 
-    u2net_home = os.getenv(
-        "U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
-    )
+        fname_encoder = f"{model_name}_encoder.onnx"
+        encoder_md5 = "13d97c5c79ab13ef86d67cbde5f1b250"
+        encoder_url = "https://github.com/Flippchen/rembg/releases/download/test/vit_b-encoder-quant.onnx"
 
-    fname = f"{model_name}.onnx"
-    path = Path(u2net_home).expanduser()
-    full_path = Path(u2net_home).expanduser() / fname
+        fname_decoder = f"{model_name}_decoder.onnx"
+        decoder_md5 = "fa3d1c36a3187d3de1c8deebf33dd127"
+        decoder_url = "https://github.com/Flippchen/rembg/releases/download/test/vit_b-decoder-quant.onnx"
 
-    pooch.retrieve(
-        url,
-        f"md5:{md5}",
-        fname=fname,
-        path=Path(u2net_home).expanduser(),
-        progressbar=True,
-    )
+
+        download_model(encoder_url, encoder_md5, fname_encoder, path)
+        download_model(decoder_url, decoder_md5, fname_decoder, path)
+
+        sess_opts = ort.SessionOptions()
+
+        if "OMP_NUM_THREADS" in os.environ:
+            sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
+
+        return SamSession(
+            model_name,
+            ort.InferenceSession(str(path / fname_encoder), providers=ort.get_available_providers(), sess_options=sess_opts),
+            ort.InferenceSession(str(path / fname_decoder), providers=ort.get_available_providers(), sess_options=sess_opts)
+        )
+
+    download_model(url, md5, fname, path)
 
     sess_opts = ort.SessionOptions()