Browse Source

add gradio

Daniel Gatis 2 years ago
parent
commit
f6159b45e0
4 changed files with 56 additions and 7 deletions
  1. 3 3
      README.md
  2. 51 4
      rembg/commands/s_command.py
  3. 1 0
      requirements.txt
  4. 1 0
      setup.py

+ 3 - 3
README.md

@@ -182,18 +182,18 @@ rembg p -w path/to/input path/to/output
 
 Used to start http server.
 
-To see the complete endpoints documentation, go to: `http://localhost:5000/docs`.
+To see the complete endpoints documentation, go to: `http://localhost:5000/api`.
 
 Remove the background from an image url
 
 ```
-curl -s "http://localhost:5000/?url=http://input.png" -o output.png
+curl -s "http://localhost:5000/api/remove?url=http://input.png" -o output.png
 ```
 
 Remove the background from an uploaded image
 
 ```
-curl -s -F file=@/path/to/input.jpg "http://localhost:5000"  -o output.png
+curl -s -F file=@/path/to/input.jpg "http://localhost:5000/api/remove"  -o output.png
 ```
 
 ### rembg `b`

+ 51 - 4
rembg/commands/s_command.py

@@ -1,8 +1,11 @@
 import json
+import os
+import webbrowser
 from typing import Optional, Tuple, cast
 
 import aiohttp
 import click
+import gradio as gr
 import uvicorn
 from asyncer import asyncify
 from fastapi import Depends, FastAPI, File, Form, Query
@@ -70,6 +73,7 @@ def s_command(port: int, log_level: str, threads: int) -> None:
             "url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
         },
         openapi_tags=tags_metadata,
+        docs_url="/api",
     )
 
     app.add_middleware(
@@ -190,13 +194,18 @@ def s_command(port: int, log_level: str, threads: int) -> None:
                 only_mask=commons.om,
                 post_process_mask=commons.ppm,
                 bgcolor=commons.bgc,
-                **kwargs
+                **kwargs,
             ),
             media_type="image/png",
         )
 
     @app.on_event("startup")
     def startup():
+        try:
+            webbrowser.open(f"http://localhost:{port}")
+        except:
+            pass
+
         if threads is not None:
             from anyio import CapacityLimiter
             from anyio.lowlevel import RunVar
@@ -204,7 +213,7 @@ def s_command(port: int, log_level: str, threads: int) -> None:
             RunVar("_default_thread_limiter").set(CapacityLimiter(threads))
 
     @app.get(
-        path="/",
+        path="/api/remove",
         tags=["Background Removal"],
         summary="Remove from URL",
         description="Removes the background from an image obtained by retrieving an URL.",
@@ -221,7 +230,7 @@ def s_command(port: int, log_level: str, threads: int) -> None:
                 return await asyncify(im_without_bg)(file, commons)
 
     @app.post(
-        path="/",
+        path="/api/remove",
         tags=["Background Removal"],
         summary="Remove from Stream",
         description="Removes the background from an image sent within the request itself.",
@@ -235,4 +244,42 @@ def s_command(port: int, log_level: str, threads: int) -> None:
     ):
         return await asyncify(im_without_bg)(file, commons)  # type: ignore
 
-    uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level)
+    def gr_app(app):
+        def inference(input_path, model):
+            output_path = "output.png"
+            with open(input_path, "rb") as i:
+                with open(output_path, "wb") as o:
+                    input = i.read()
+                    output = remove(input, session=new_session(model))
+                    o.write(output)
+            return os.path.join(output_path)
+
+        interface = gr.Interface(
+            inference,
+            [
+                gr.components.Image(type="filepath", label="Input"),
+                gr.components.Dropdown(
+                    [
+                        "u2net",
+                        "u2netp",
+                        "u2net_human_seg",
+                        "u2net_cloth_seg",
+                        "silueta",
+                        "isnet-general-use",
+                        "isnet-anime",
+                    ],
+                    value="u2net",
+                    label="Models",
+                ),
+            ],
+            gr.components.Image(type="filepath", label="Output"),
+        )
+
+        interface.queue(concurrency_count=3)
+        app = gr.mount_gradio_app(app, interface, path="/")
+        return app
+
+    print(f"To access the API documentation, go to http://localhost:{port}/api")
+    print(f"To access the UI, go to http://localhost:{port}")
+
+    uvicorn.run(gr_app(app), host="0.0.0.0", port=port, log_level=log_level)

+ 1 - 0
requirements.txt

@@ -3,6 +3,7 @@ asyncer==0.0.2
 click==8.1.3
 fastapi==0.92.0
 filetype==1.2.0
+gradio==3.32.0
 imagehash==4.3.1
 numpy==1.23.5
 onnxruntime==1.14.1

+ 1 - 0
setup.py

@@ -42,6 +42,7 @@ setup(
         "click>=8.1.3",
         "fastapi>=0.92.0",
         "filetype>=1.2.0",
+        "gradio>=3.32.0",
         "imagehash>=4.3.1",
         "numpy>=1.23.5",
         "onnxruntime>=1.14.1",