ソースを参照

throw on bogus model name

Previously if a bogus model name was provided,
it would silently ignore the error and load u2net.

It still default to u2net when no model name is given, but throw if a bogus model name is given.
divinity76 3 ヶ月 前
コミット
58e10239c5
1 ファイル変更10 行追加3 行削除
  1. 10 3
      rembg/session_factory.py

+ 10 - 3
rembg/session_factory.py

@@ -22,20 +22,27 @@ def new_session(model_name: str = "u2net", *args, **kwargs) -> BaseSession:
         *args: Additional positional arguments.
         **kwargs: Additional keyword arguments.
 
+    Raises:
+        ValueError: If no session class with the given `model_name` is found.
+
     Returns:
         BaseSession: The created session object.
     """
-    session_class: Type[BaseSession] = U2netSession
+    session_class: Optional[Type[BaseSession]] = None
 
     for sc in sessions_class:
         if sc.name() == model_name:
             session_class = sc
             break
 
+    if session_class is None:
+        raise ValueError(f"No session class found for model '{model_name}'")
+
     sess_opts = ort.SessionOptions()
 
     if "OMP_NUM_THREADS" in os.environ:
-        sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
-        sess_opts.intra_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
+        threads = int(os.environ["OMP_NUM_THREADS"])
+        sess_opts.inter_op_num_threads = threads
+        sess_opts.intra_op_num_threads = threads
 
     return session_class(model_name, sess_opts, *args, **kwargs)