Daniel Gatis 4 سال پیش
والد
کامیت
ca4be25f90
4فایلهای تغییر یافته به همراه40 افزوده شده و 12 حذف شده
  1. 9 1
      README.md
  2. 1 1
      setup.py
  3. 20 8
      src/rembg/cmd/cli.py
  4. 10 2
      src/rembg/cmd/server.py

+ 9 - 1
README.md

@@ -80,7 +80,15 @@ rembg -o path/to/output.png path/to/input.png
 
 Remove the background from all images in a folder
 ```bash
-rembg -p path/to/inputs
+rembg -p path/to/inputs path/to/output
+```
+
+### Add a custom model
+
+Copy the `custom-model.pth` file to `~/.u2net` and run:
+
+```bash
+curl -s http://input.png | rembg -m custom-model > output.png
 ```
 
 ### Usage as a server

+ 1 - 1
setup.py

@@ -11,7 +11,7 @@ with open("requirements.txt") as f:
 
 setup(
     name="rembg",
-    version="1.0.24",
+    version="1.0.25",
     description="Remove image background",
     long_description=long_description,
     long_description_content_type="text/markdown",

+ 20 - 8
src/rembg/cmd/cli.py

@@ -10,6 +10,13 @@ from ..bg import remove
 
 
 def main():
+    model_path = os.environ.get(
+        "U2NETP_PATH",
+        os.path.expanduser(os.path.join("~", ".u2net")),
+    )
+    model_choices = [os.path.splitext(os.path.basename(x))[0] for x in set(glob.glob(model_path + "/*"))]
+
+
     ap = argparse.ArgumentParser()
 
     ap.add_argument(
@@ -17,7 +24,7 @@ def main():
         "--model",
         default="u2net",
         type=str,
-        choices=("u2net", "u2net_human_seg", "u2netp"),
+        choices=model_choices,
         help="The model name.",
     )
 
@@ -66,8 +73,8 @@ def main():
     ap.add_argument(
         "-p",
         "--path",
-        nargs="+",
-        help="Path of a file or a folder of files.",
+        nargs=2,
+        help="An input folder and an output folder.",
     )
 
     ap.add_argument(
@@ -94,15 +101,20 @@ def main():
 
     if args.path:
         full_paths = [os.path.abspath(path) for path in args.path]
+        
+        input_paths = [full_paths[0]] 
+        output_path = full_paths[1]
+
+        if not os.path.exists(output_path):
+            os.makedirs(output_path)
+
         files = set()
 
-        for path in full_paths:
+        for path in input_paths:
             if os.path.isfile(path):
                 files.add(path)
             else:
-                full_paths += set(glob.glob(path + "/*")) - set(
-                    glob.glob(path + "/*.out.png")
-                )
+                input_paths += set(glob.glob(path + "/*"))
 
         for fi in tqdm(files):
             fi_type = filetype.guess(fi)
@@ -113,7 +125,7 @@ def main():
                 continue
 
             with open(fi, "rb") as input:
-                with open(os.path.splitext(fi)[0] + ".out.png", "wb") as output:
+                with open(os.path.join(output_path, os.path.splitext(os.path.basename(fi))[0] + ".png"), "wb") as output:
                     w(
                         output,
                         remove(

+ 10 - 2
src/rembg/cmd/server.py

@@ -1,3 +1,5 @@
+import os
+import glob
 import argparse
 from io import BytesIO
 from urllib.parse import unquote_plus
@@ -38,8 +40,14 @@ def index():
     az = request.values.get("az", type=int, default=1000)
 
     model = request.args.get("model", type=str, default="u2net")
-    if model not in ("u2net", "u2net_human_seg", "u2netp"):
-        return {"error": "invalid query param 'model'"}, 400
+    model_path = os.environ.get(
+        "U2NETP_PATH",
+        os.path.expanduser(os.path.join("~", ".u2net")),
+    )
+    model_choices = [os.path.splitext(os.path.basename(x))[0] for x in set(glob.glob(model_path + "/*"))]
+
+    if model not in model_choices:
+        return {"error": f"invalid query param 'model'. Available options are {model_choices}"}, 400
 
     try:
         return send_file(