Daniel Gatis, 3 years ago
parent
commit f264c105d4
3 files changed, 41 insertions and 32 deletions
  1. README.md (+26, -27)
  2. rembg/__init__.py (+2, -0)
  3. rembg/bg.py (+13, -5)

README.md (+26, -27)

@@ -87,45 +87,44 @@ Also you can send the file as a FormData (multipart/form-data):
 
 ### Usage as a library
 
-#### Example 1: Read from stdin and write to stdout
-
-In `app.py`
+Input and output as bytes
 ```python
-import sys
-from rembg.bg import remove
+from rembg import remove
 
-sys.stdout.buffer.write(remove(sys.stdin.buffer.read()))
-```
+input_path = 'input.png'
+output_path = 'output.png'
 
-Then run
+with open(input_path, 'rb') as i:
+    with open(output_path, 'wb') as o:
+        input = i.read()
+        output = remove(input)
+        o.write(output)
 ```
-cat input.png | python app.py > out.png
-```
-
-#### Example 2: Using PIL
 
-In `app.py`
+Input and output as a PIL image
 ```python
-from rembg.bg import remove
-import numpy as np
-import io
+from rembg import remove
 from PIL import Image
 
 input_path = 'input.png'
-output_path = 'out.png'
-
-# Uncomment the following line if working with trucated image formats (ex. JPEG / JPG)
-# ImageFile.LOAD_TRUNCATED_IMAGES = True
+output_path = 'output.png'
 
-f = np.fromfile(input_path)
-result = remove(f)
-img = Image.open(io.BytesIO(result)).convert("RGBA")
-img.save(output_path)
+input = Image.open(input_path)
+output = remove(input)
+output.save(output_path)
 ```
 
-Then run
-```
-python app.py
+Input and output as a numpy array
+```python
+from rembg import remove
+import cv2
+
+input_path = 'input.png'
+output_path = 'output.png'
+
+input = cv2.imread(input_path)
+output = remove(input)
+cv2.imwrite(output_path, output)
 ```
 
 ### Usage as a docker

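The stdin/stdout pipeline from the removed Example 1 is still easy to express with the bytes interface; a minimal sketch, using the new top-level import introduced by this commit:

```python
# app.py -- equivalent of the removed stdin/stdout example,
# rewritten against the new `from rembg import remove` entry point.
import sys

from rembg import remove

# bytes in -> bytes out, so the result can be piped straight to a file:
#   cat input.png | python app.py > output.png
sys.stdout.buffer.write(remove(sys.stdin.buffer.read()))
```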
rembg/__init__.py (+2, -0)

@@ -1,3 +1,5 @@
 from . import _version
 
 __version__ = _version.get_versions()["version"]
+
+from .bg import remove
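
Re-exporting `remove` here is what makes the shorter `from rembg import remove` used in the new README examples work; the old `rembg.bg` path keeps working as well. A quick sketch of the equivalence (assuming the package is installed):

```python
# Both import paths should now resolve to the same function object.
from rembg import remove as remove_top_level
from rembg.bg import remove as remove_from_bg

assert remove_top_level is remove_from_bg
```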

rembg/bg.py (+13, -5)

@@ -17,6 +17,7 @@ from .detect import ort_session, predict
 class ReturnType(Enum):
     BYTES = 0
     PILLOW = 1
+    NDARRAY = 2
 
 
 def alpha_matting_cutout(
@@ -65,27 +66,31 @@ def naive_cutout(img: Image, mask: Image) -> Image:
 
 
 def remove(
-    data: Union[bytes, PILImage],
+    data: Union[bytes, PILImage, np.ndarray],
     alpha_matting: bool = False,
     alpha_matting_foreground_threshold: int = 240,
     alpha_matting_background_threshold: int = 10,
     alpha_matting_erode_size: int = 10,
     session: Optional[ort.InferenceSession] = None,
     only_mask: bool = False,
-) -> Union[bytes, PILImage]:
+) -> Union[bytes, PILImage, np.ndarray]:
+
     if isinstance(data, PILImage):
         return_type = ReturnType.PILLOW
-        img = data.convert("RGB")
+        img = data
     elif isinstance(data, bytes):
         return_type = ReturnType.BYTES
-        img = Image.open(io.BytesIO(data)).convert("RGB")
+        img = Image.open(io.BytesIO(data))
+    elif isinstance(data, np.ndarray):
+        return_type = ReturnType.NDARRAY
+        img = Image.fromarray(data)
     else:
         raise ValueError("Input type {} is not supported.".format(type(data)))
 
     if session is None:
         session = ort_session("u2net")
 
-    mask = predict(session, np.array(img)).convert("L")
+    mask = predict(session, np.array(img.convert("RGB"))).convert("L")
     mask = mask.resize(img.size, Image.LANCZOS)
 
     if only_mask:
@@ -105,6 +110,9 @@ def remove(
     if ReturnType.PILLOW == return_type:
         return cutout
 
+    if ReturnType.NDARRAY == return_type:
+        return np.asarray(cutout)
+
     bio = io.BytesIO()
     cutout.save(bio, "PNG")
     bio.seek(0)
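
With the `NDARRAY` branch in place, `remove` mirrors the input type: bytes in gives bytes out, a PIL image gives a PIL image, and a numpy array gives a numpy array. A minimal sketch of the array path (assumes the default u2net model can be loaded; the input shape is only illustrative):

```python
import numpy as np

from rembg import remove

# An H x W x 3 uint8 array, e.g. what cv2.imread() returns.
frame = np.zeros((256, 256, 3), dtype=np.uint8)

cutout = remove(frame)

# The cutout comes back as an array as well, with an alpha channel
# added by the RGBA conversion inside remove().
print(type(cutout), cutout.shape)  # <class 'numpy.ndarray'> (256, 256, 4)
```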