Explorar o código

Merge pull request #693 from fa0311/main

Fixed bugs related to sum
Daniel Gatis hai 9 meses
pai
achega
5435d2ffee
Modificáronse 1 ficheiros con 15 adicións e 5 borrados
  1. 15 5
      rembg/sessions/sam.py

+ 15 - 5
rembg/sessions/sam.py

@@ -1,6 +1,6 @@
 import os
 from copy import deepcopy
-from typing import Dict, List, Tuple
+from typing import List
 
 import cv2
 import numpy as np
@@ -105,9 +105,10 @@ class SamSession(BaseSession):
         valid_providers = []
         available_providers = ort.get_available_providers()
 
-        for provider in providers or []:
-            if provider in available_providers:
-                valid_providers.append(provider)
+        if providers:
+            for provider in providers or []:
+                if provider in available_providers:
+                    valid_providers.append(provider)
         else:
             valid_providers.extend(available_providers)
 
@@ -142,7 +143,16 @@ class SamSession(BaseSession):
         Returns:
             List[PILImage]: A list of masks generated by the decoder.
         """
-        prompt = kwargs.get("sam_prompt", "{}")
+        prompt = kwargs.get(
+            "sam_prompt",
+            [
+                {
+                    "type": "point",
+                    "label": 1,
+                    "data": [int(img.width / 2), int(img.height / 2)],
+                }
+            ],
+        )
         schema = {
             "type": "array",
             "items": {