|
@@ -1,6 +1,6 @@
|
|
|
import os
|
|
|
from copy import deepcopy
|
|
|
-from typing import Dict, List, Tuple
|
|
|
+from typing import List
|
|
|
|
|
|
import cv2
|
|
|
import numpy as np
|
|
@@ -105,9 +105,10 @@ class SamSession(BaseSession):
|
|
|
valid_providers = []
|
|
|
available_providers = ort.get_available_providers()
|
|
|
|
|
|
- for provider in providers or []:
|
|
|
- if provider in available_providers:
|
|
|
- valid_providers.append(provider)
|
|
|
+ if providers:
|
|
|
+ for provider in providers or []:
|
|
|
+ if provider in available_providers:
|
|
|
+ valid_providers.append(provider)
|
|
|
else:
|
|
|
valid_providers.extend(available_providers)
|
|
|
|
|
@@ -142,7 +143,16 @@ class SamSession(BaseSession):
|
|
|
Returns:
|
|
|
List[PILImage]: A list of masks generated by the decoder.
|
|
|
"""
|
|
|
- prompt = kwargs.get("sam_prompt", "{}")
|
|
|
+ prompt = kwargs.get(
|
|
|
+ "sam_prompt",
|
|
|
+ [
|
|
|
+ {
|
|
|
+ "type": "point",
|
|
|
+ "label": 1,
|
|
|
+ "data": [int(img.width / 2), int(img.height / 2)],
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ )
|
|
|
schema = {
|
|
|
"type": "array",
|
|
|
"items": {
|