|
@@ -90,7 +90,10 @@ def remove(
|
|
|
if session is None:
|
|
|
session = ort_session("u2net")
|
|
|
|
|
|
- mask = predict(session, np.array(img.convert("RGB"))).convert("L")
|
|
|
+ img = img.convert("RGB")
|
|
|
+
|
|
|
+ mask = predict(session, np.array(img))
|
|
|
+ mask = mask.convert("L")
|
|
|
mask = mask.resize(img.size, Image.LANCZOS)
|
|
|
|
|
|
if only_mask:
|