|
@@ -10,6 +10,13 @@ from ..bg import remove
|
|
|
|
|
|
|
|
|
def main():
|
|
|
+ model_path = os.environ.get(
|
|
|
+ "U2NETP_PATH",
|
|
|
+ os.path.expanduser(os.path.join("~", ".u2net")),
|
|
|
+ )
|
|
|
+ model_choices = [os.path.splitext(os.path.basename(x))[0] for x in set(glob.glob(model_path + "/*"))]
|
|
|
+
|
|
|
+
|
|
|
ap = argparse.ArgumentParser()
|
|
|
|
|
|
ap.add_argument(
|
|
@@ -17,7 +24,7 @@ def main():
|
|
|
"--model",
|
|
|
default="u2net",
|
|
|
type=str,
|
|
|
- choices=("u2net", "u2net_human_seg", "u2netp"),
|
|
|
+ choices=model_choices,
|
|
|
help="The model name.",
|
|
|
)
|
|
|
|
|
@@ -66,8 +73,8 @@ def main():
|
|
|
ap.add_argument(
|
|
|
"-p",
|
|
|
"--path",
|
|
|
- nargs="+",
|
|
|
- help="Path of a file or a folder of files.",
|
|
|
+ nargs=2,
|
|
|
+ help="An input folder and an output folder.",
|
|
|
)
|
|
|
|
|
|
ap.add_argument(
|
|
@@ -94,15 +101,20 @@ def main():
|
|
|
|
|
|
if args.path:
|
|
|
full_paths = [os.path.abspath(path) for path in args.path]
|
|
|
+
|
|
|
+ input_paths = [full_paths[0]]
|
|
|
+ output_path = full_paths[1]
|
|
|
+
|
|
|
+ if not os.path.exists(output_path):
|
|
|
+ os.makedirs(output_path)
|
|
|
+
|
|
|
files = set()
|
|
|
|
|
|
- for path in full_paths:
|
|
|
+ for path in input_paths:
|
|
|
if os.path.isfile(path):
|
|
|
files.add(path)
|
|
|
else:
|
|
|
- full_paths += set(glob.glob(path + "/*")) - set(
|
|
|
- glob.glob(path + "/*.out.png")
|
|
|
- )
|
|
|
+ input_paths += set(glob.glob(path + "/*"))
|
|
|
|
|
|
for fi in tqdm(files):
|
|
|
fi_type = filetype.guess(fi)
|
|
@@ -113,7 +125,7 @@ def main():
|
|
|
continue
|
|
|
|
|
|
with open(fi, "rb") as input:
|
|
|
- with open(os.path.splitext(fi)[0] + ".out.png", "wb") as output:
|
|
|
+ with open(os.path.join(output_path, os.path.splitext(os.path.basename(fi))[0] + ".png"), "wb") as output:
|
|
|
w(
|
|
|
output,
|
|
|
remove(
|