Browse Source

add gradio

Daniel Gatis 2 years ago
parent
commit
f6159b45e0
4 changed files with 56 additions and 7 deletions
  1. 3 3
      README.md
  2. 51 4
      rembg/commands/s_command.py
  3. 1 0
      requirements.txt
  4. 1 0
      setup.py

+ 3 - 3
README.md

@@ -182,18 +182,18 @@ rembg p -w path/to/input path/to/output
 
 
 Used to start http server.
 Used to start http server.
 
 
-To see the complete endpoints documentation, go to: `http://localhost:5000/docs`.
+To see the complete endpoints documentation, go to: `http://localhost:5000/api`.
 
 
 Remove the background from an image url
 Remove the background from an image url
 
 
 ```
 ```
-curl -s "http://localhost:5000/?url=http://input.png" -o output.png
+curl -s "http://localhost:5000/api/remove?url=http://input.png" -o output.png
 ```
 ```
 
 
 Remove the background from an uploaded image
 Remove the background from an uploaded image
 
 
 ```
 ```
-curl -s -F file=@/path/to/input.jpg "http://localhost:5000"  -o output.png
+curl -s -F file=@/path/to/input.jpg "http://localhost:5000/api/remove"  -o output.png
 ```
 ```
 
 
 ### rembg `b`
 ### rembg `b`

+ 51 - 4
rembg/commands/s_command.py

@@ -1,8 +1,11 @@
 import json
 import json
+import os
+import webbrowser
 from typing import Optional, Tuple, cast
 from typing import Optional, Tuple, cast
 
 
 import aiohttp
 import aiohttp
 import click
 import click
+import gradio as gr
 import uvicorn
 import uvicorn
 from asyncer import asyncify
 from asyncer import asyncify
 from fastapi import Depends, FastAPI, File, Form, Query
 from fastapi import Depends, FastAPI, File, Form, Query
@@ -70,6 +73,7 @@ def s_command(port: int, log_level: str, threads: int) -> None:
             "url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
             "url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
         },
         },
         openapi_tags=tags_metadata,
         openapi_tags=tags_metadata,
+        docs_url="/api",
     )
     )
 
 
     app.add_middleware(
     app.add_middleware(
@@ -190,13 +194,18 @@ def s_command(port: int, log_level: str, threads: int) -> None:
                 only_mask=commons.om,
                 only_mask=commons.om,
                 post_process_mask=commons.ppm,
                 post_process_mask=commons.ppm,
                 bgcolor=commons.bgc,
                 bgcolor=commons.bgc,
-                **kwargs
+                **kwargs,
             ),
             ),
             media_type="image/png",
             media_type="image/png",
         )
         )
 
 
     @app.on_event("startup")
     @app.on_event("startup")
     def startup():
     def startup():
+        try:
+            webbrowser.open(f"http://localhost:{port}")
+        except:
+            pass
+
         if threads is not None:
         if threads is not None:
             from anyio import CapacityLimiter
             from anyio import CapacityLimiter
             from anyio.lowlevel import RunVar
             from anyio.lowlevel import RunVar
@@ -204,7 +213,7 @@ def s_command(port: int, log_level: str, threads: int) -> None:
             RunVar("_default_thread_limiter").set(CapacityLimiter(threads))
             RunVar("_default_thread_limiter").set(CapacityLimiter(threads))
 
 
     @app.get(
     @app.get(
-        path="/",
+        path="/api/remove",
         tags=["Background Removal"],
         tags=["Background Removal"],
         summary="Remove from URL",
         summary="Remove from URL",
         description="Removes the background from an image obtained by retrieving an URL.",
         description="Removes the background from an image obtained by retrieving an URL.",
@@ -221,7 +230,7 @@ def s_command(port: int, log_level: str, threads: int) -> None:
                 return await asyncify(im_without_bg)(file, commons)
                 return await asyncify(im_without_bg)(file, commons)
 
 
     @app.post(
     @app.post(
-        path="/",
+        path="/api/remove",
         tags=["Background Removal"],
         tags=["Background Removal"],
         summary="Remove from Stream",
         summary="Remove from Stream",
         description="Removes the background from an image sent within the request itself.",
         description="Removes the background from an image sent within the request itself.",
@@ -235,4 +244,42 @@ def s_command(port: int, log_level: str, threads: int) -> None:
     ):
     ):
         return await asyncify(im_without_bg)(file, commons)  # type: ignore
         return await asyncify(im_without_bg)(file, commons)  # type: ignore
 
 
-    uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level)
+    def gr_app(app):
+        def inference(input_path, model):
+            output_path = "output.png"
+            with open(input_path, "rb") as i:
+                with open(output_path, "wb") as o:
+                    input = i.read()
+                    output = remove(input, session=new_session(model))
+                    o.write(output)
+            return os.path.join(output_path)
+
+        interface = gr.Interface(
+            inference,
+            [
+                gr.components.Image(type="filepath", label="Input"),
+                gr.components.Dropdown(
+                    [
+                        "u2net",
+                        "u2netp",
+                        "u2net_human_seg",
+                        "u2net_cloth_seg",
+                        "silueta",
+                        "isnet-general-use",
+                        "isnet-anime",
+                    ],
+                    value="u2net",
+                    label="Models",
+                ),
+            ],
+            gr.components.Image(type="filepath", label="Output"),
+        )
+
+        interface.queue(concurrency_count=3)
+        app = gr.mount_gradio_app(app, interface, path="/")
+        return app
+
+    print(f"To access the API documentation, go to http://localhost:{port}/api")
+    print(f"To access the UI, go to http://localhost:{port}")
+
+    uvicorn.run(gr_app(app), host="0.0.0.0", port=port, log_level=log_level)

+ 1 - 0
requirements.txt

@@ -3,6 +3,7 @@ asyncer==0.0.2
 click==8.1.3
 click==8.1.3
 fastapi==0.92.0
 fastapi==0.92.0
 filetype==1.2.0
 filetype==1.2.0
+gradio==3.32.0
 imagehash==4.3.1
 imagehash==4.3.1
 numpy==1.23.5
 numpy==1.23.5
 onnxruntime==1.14.1
 onnxruntime==1.14.1

+ 1 - 0
setup.py

@@ -42,6 +42,7 @@ setup(
         "click>=8.1.3",
         "click>=8.1.3",
         "fastapi>=0.92.0",
         "fastapi>=0.92.0",
         "filetype>=1.2.0",
         "filetype>=1.2.0",
+        "gradio>=3.32.0",
         "imagehash>=4.3.1",
         "imagehash>=4.3.1",
         "numpy>=1.23.5",
         "numpy>=1.23.5",
         "onnxruntime>=1.14.1",
         "onnxruntime>=1.14.1",