|
@@ -18,7 +18,9 @@ def new_session(model_name: str) -> BaseSession:
|
|
|
|
|
|
if model_name == "u2netp":
|
|
|
md5 = "8e83ca70e441ab06c318d82300c84806"
|
|
|
- url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx"
|
|
|
+ url = (
|
|
|
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx"
|
|
|
+ )
|
|
|
session_class = SimpleSession
|
|
|
elif model_name == "u2net":
|
|
|
md5 = "60024c5c889badc19c04ad937298a77b"
|
|
@@ -50,7 +52,7 @@ def new_session(model_name: str) -> BaseSession:
|
|
|
f"md5:{md5}",
|
|
|
fname=fname,
|
|
|
path=Path(u2net_home).expanduser(),
|
|
|
- progressbar=True
|
|
|
+ progressbar=True,
|
|
|
)
|
|
|
|
|
|
sess_opts = ort.SessionOptions()
|
|
@@ -61,6 +63,8 @@ def new_session(model_name: str) -> BaseSession:
|
|
|
return session_class(
|
|
|
model_name,
|
|
|
ort.InferenceSession(
|
|
|
- str(full_path), providers=ort.get_available_providers(), sess_options=sess_opts
|
|
|
+ str(full_path),
|
|
|
+ providers=ort.get_available_providers(),
|
|
|
+ sess_options=sess_opts,
|
|
|
),
|
|
|
)
|