Browse Source

update default sam_prompt structure to include point data

Signed-off-by: ふぁ <[email protected]>
ふぁ 9 months ago
parent
commit
9e6c46184d
1 changed files with 10 additions and 1 deletions
  1. 10 1
      rembg/sessions/sam.py

+ 10 - 1
rembg/sessions/sam.py

@@ -143,7 +143,16 @@ class SamSession(BaseSession):
         Returns:
             List[PILImage]: A list of masks generated by the decoder.
         """
-        prompt = kwargs.get("sam_prompt", "{}")
+        prompt = kwargs.get(
+            "sam_prompt",
+            [
+                {
+                    "type": "point",
+                    "label": 1,
+                    "data": [int(img.width / 2), int(img.height / 2)],
+                }
+            ],
+        )
         schema = {
             "type": "array",
             "items": {