Browse Source

add u2net_human_seg model

Daniel Gatis 4 years ago
parent
commit
fd3b5a5e8f
4 changed files with 21 additions and 3 deletions
  1. 2 0
      src/rembg/bg.py
  2. 1 1
      src/rembg/cmd/cli.py
  3. 1 1
      src/rembg/cmd/server.py
  4. 17 1
      src/rembg/u2net/detect.py

+ 2 - 0
src/rembg/bg.py

@@ -72,6 +72,8 @@ def naive_cutout(img, mask):
 def get_model(model_name):
     if model_name == "u2netp":
         return detect.load_model(model_name="u2netp")
+    if model_name == "u2net_human_seg":
+        return detect.load_model(model_name="u2net_human_seg")
     else:
         return detect.load_model(model_name="u2net")
 

+ 1 - 1
src/rembg/cmd/cli.py

@@ -17,7 +17,7 @@ def main():
         "--model",
         default="u2net",
         type=str,
-        choices=("u2net", "u2netp"),
+        choices=("u2net", "u2net_human_seg", "u2netp"),
         help="The model name.",
     )
 

+ 1 - 1
src/rembg/cmd/server.py

@@ -38,7 +38,7 @@ def index():
     az = request.values.get("az", type=int, default=1000)
 
     model = request.args.get("model", type=str, default="u2net")
-    if model not in ("u2net", "u2netp"):
+    if model not in ("u2net", "u2net_human_seg", "u2netp"):
         return {"error": "invalid query param 'model'"}, 400
 
     try:

+ 17 - 1
src/rembg/u2net/detect.py

@@ -85,8 +85,24 @@ def load_model(model_name: str = "u2net"):
                 "u2net.pth",
                 path,
             )
+
+    elif model_name == "u2net_human_seg":
+        net = u2net.U2NET(3, 1)
+        path = os.environ.get(
+            "U2NET_PATH",
+            os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
+        )
+        if (
+            not os.path.exists(path)
+            or hasher.md5(path) != "09fb4e49b7f785c9f855baf94916840a"
+        ):
+            download_file_from_google_drive(
+                "1-Yg0cxgrNhHP-016FPdp902BR-kSsA4P",
+                "u2net_human_seg.pth",
+                path,
+            )
     else:
-        print("Choose between u2net or u2netp", file=sys.stderr)
+        print("Choose between u2net, u2net_human_seg or u2netp", file=sys.stderr)
 
     try:
         if torch.cuda.is_available():