Signed-off-by: ふぁ <[email protected]>
@@ -143,7 +143,16 @@ class SamSession(BaseSession):
Returns:
List[PILImage]: A list of masks generated by the decoder.
"""
- prompt = kwargs.get("sam_prompt", "{}")
+ prompt = kwargs.get(
+ "sam_prompt",
+ [
+ {
+ "type": "point",
+ "label": 1,
+ "data": [int(img.width / 2), int(img.height / 2)],
+ }
+ ],
+ )
schema = {
"type": "array",
"items": {