瀏覽代碼

update default sam_prompt structure to include point data

Signed-off-by: ふぁ <[email protected]>
ふぁ 9 月之前
父節點
當前提交
9e6c46184d
共有 1 個文件被更改,包括 10 次插入1 次删除
  1. 10 1
      rembg/sessions/sam.py

+ 10 - 1
rembg/sessions/sam.py

@@ -143,7 +143,16 @@ class SamSession(BaseSession):
         Returns:
             List[PILImage]: A list of masks generated by the decoder.
         """
-        prompt = kwargs.get("sam_prompt", "{}")
+        prompt = kwargs.get(
+            "sam_prompt",
+            [
+                {
+                    "type": "point",
+                    "label": 1,
+                    "data": [int(img.width / 2), int(img.height / 2)],
+                }
+            ],
+        )
         schema = {
             "type": "array",
             "items": {