Parcourir la source

fix model choices

Daniel Gatis il y a 3 ans
Parent
commit
9405b96d80
4 fichiers modifiés avec 7 ajouts et 7 suppressions
  1. 2 2
      requirements.txt
  2. 1 1
      setup.py
  3. 2 2
      src/rembg/cmd/cli.py
  4. 2 2
      src/rembg/cmd/server.py

+ 2 - 2
requirements.txt

@@ -2,8 +2,8 @@ flask>=1.1.2
 numpy>=1.19.5
 pillow>=8.0.1
 scikit-image>=0.17.2
-torch>=1.7.0
-torchvision>=0.8.1
+torch>=1.9.1
+torchvision>=0.10.1
 waitress>=1.4.4
 tqdm>=4.51.0
 requests>=2.24.0

+ 1 - 1
setup.py

@@ -11,7 +11,7 @@ with open("requirements.txt") as f:
 
 setup(
     name="rembg",
-    version="1.0.27",
+    version="1.0.28",
     description="Remove image background",
     long_description=long_description,
     long_description_content_type="text/markdown",

+ 2 - 2
src/rembg/cmd/cli.py

@@ -18,8 +18,8 @@ def main():
         os.path.splitext(os.path.basename(x))[0]
         for x in set(glob.glob(model_path + "/*"))
     ]
-    if len(model_choices) == 0:
-        model_choices = ["u2net", "u2netp", "u2net_human_seg"]
+
+    model_choices = list(set(model_choices + ["u2net", "u2netp", "u2net_human_seg"]))
 
     ap = argparse.ArgumentParser()
 

+ 2 - 2
src/rembg/cmd/server.py

@@ -52,8 +52,8 @@ def index():
         os.path.splitext(os.path.basename(x))[0]
         for x in set(glob.glob(model_path + "/*"))
     ]
-    if len(model_choices) == 0:
-        model_choices = ["u2net", "u2netp", "u2net_human_seg"]
+
+    model_choices = list(set(model_choices + ["u2net", "u2netp", "u2net_human_seg"]))
 
     if model not in model_choices:
         return {