浏览代码

chore: Refactor remove function to handle unsupported input types

Daniel Gatis 1 年之前
父节点
当前提交
688b34841b
共有 1 个文件被更改,包括 6 次插入2 次删除
  1. 6 2
      rembg/bg.py

+ 6 - 2
rembg/bg.py

@@ -241,7 +241,7 @@ def remove(
     """
     """
     if isinstance(data, bytes) or force_return_bytes:
     if isinstance(data, bytes) or force_return_bytes:
         return_type = ReturnType.BYTES
         return_type = ReturnType.BYTES
-        img = Image.open(io.BytesIO(data))
+        img = Image.open(io.BytesIO(cast(bytes, data)))
     elif isinstance(data, PILImage):
     elif isinstance(data, PILImage):
         return_type = ReturnType.PILLOW
         return_type = ReturnType.PILLOW
         img = data
         img = data
@@ -249,7 +249,11 @@ def remove(
         return_type = ReturnType.NDARRAY
         return_type = ReturnType.NDARRAY
         img = Image.fromarray(data)
         img = Image.fromarray(data)
     else:
     else:
-        raise ValueError("Input type {} is not supported. Try using force_return_bytes=True to force python bytes output".format(type(data)))
+        raise ValueError(
+            "Input type {} is not supported. Try using force_return_bytes=True to force python bytes output".format(
+                type(data)
+            )
+        )
 
 
     putalpha = kwargs.pop("putalpha", False)
     putalpha = kwargs.pop("putalpha", False)