|
@@ -15,7 +15,8 @@ def main():
|
|
|
os.path.expanduser(os.path.join("~", ".u2net")),
|
|
|
)
|
|
|
model_choices = [os.path.splitext(os.path.basename(x))[0] for x in set(glob.glob(model_path + "/*"))]
|
|
|
-
|
|
|
+ if len(model_choices) == 0:
|
|
|
+ model_choices = ["u2net", "u2netp", "u2net_human_seg"]
|
|
|
|
|
|
ap = argparse.ArgumentParser()
|
|
|
|
|
@@ -101,8 +102,8 @@ def main():
|
|
|
|
|
|
if args.path:
|
|
|
full_paths = [os.path.abspath(path) for path in args.path]
|
|
|
-
|
|
|
- input_paths = [full_paths[0]]
|
|
|
+
|
|
|
+ input_paths = [full_paths[0]]
|
|
|
output_path = full_paths[1]
|
|
|
|
|
|
if not os.path.exists(output_path):
|