浏览代码

fix project layout

Daniel Gatis 3 年之前
父节点
当前提交
0fd1236db4
共有 5 个文件被更改,包括 12 次插入12 次删除
  1. 5 5
      rembg/bg.py
  2. 0 0
      rembg/data_loader.py
  3. 7 7
      rembg/detect.py
  4. 0 0
      rembg/u2net.py
  5. 0 0
      rembg/u2net/__init__.py

+ 5 - 5
rembg/bg.py

@@ -8,7 +8,7 @@ from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
 from pymatting.util.util import stack_images
 from pymatting.util.util import stack_images
 from scipy.ndimage.morphology import binary_erosion
 from scipy.ndimage.morphology import binary_erosion
 
 
-from .u2net import detect
+from .detect import load_model, predict
 
 
 
 
 def alpha_matting_cutout(
 def alpha_matting_cutout(
@@ -71,11 +71,11 @@ def naive_cutout(img, mask):
 @functools.lru_cache(maxsize=None)
 @functools.lru_cache(maxsize=None)
 def get_model(model_name):
 def get_model(model_name):
     if model_name == "u2netp":
     if model_name == "u2netp":
-        return detect.load_model(model_name="u2netp")
+        return load_model(model_name="u2netp")
     if model_name == "u2net_human_seg":
     if model_name == "u2net_human_seg":
-        return detect.load_model(model_name="u2net_human_seg")
+        return load_model(model_name="u2net_human_seg")
     else:
     else:
-        return detect.load_model(model_name="u2net")
+        return load_model(model_name="u2net")
 
 
 
 
 def resize_image(img, width, height):
 def resize_image(img, width, height):
@@ -105,7 +105,7 @@ def remove(
         img = resize_image(img, width, height)
         img = resize_image(img, width, height)
 
 
     model = get_model(model_name)
     model = get_model(model_name)
-    mask = detect.predict(model, np.array(img)).convert("L")
+    mask = predict(model, np.array(img)).convert("L")
 
 
     if alpha_matting:
     if alpha_matting:
         try:
         try:

+ 0 - 0
rembg/u2net/data_loader.py → rembg/data_loader.py


+ 7 - 7
rembg/u2net/detect.py → rembg/detect.py

@@ -15,8 +15,8 @@ from skimage import transform
 from torchvision import transforms
 from torchvision import transforms
 from tqdm import tqdm
 from tqdm import tqdm
 
 
-from . import data_loader, u2net
-
+from .data_loader import RescaleT, ToTensorLab
+from .u2net import U2NETP, U2NET
 
 
 def download_file_from_google_drive(id, fname, destination):
 def download_file_from_google_drive(id, fname, destination):
     head, tail = os.path.split(destination)
     head, tail = os.path.split(destination)
@@ -55,7 +55,7 @@ def load_model(model_name: str = "u2net"):
     hashfile = lambda f: md5(open(f, "rb").read()).hexdigest()
     hashfile = lambda f: md5(open(f, "rb").read()).hexdigest()
 
 
     if model_name == "u2netp":
     if model_name == "u2netp":
-        net = u2net.U2NETP(3, 1)
+        net = U2NETP(3, 1)
         path = os.environ.get(
         path = os.environ.get(
             "U2NETP_PATH",
             "U2NETP_PATH",
             os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
             os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
@@ -71,7 +71,7 @@ def load_model(model_name: str = "u2net"):
             )
             )
 
 
     elif model_name == "u2net":
     elif model_name == "u2net":
-        net = u2net.U2NET(3, 1)
+        net = U2NET(3, 1)
         path = os.environ.get(
         path = os.environ.get(
             "U2NET_PATH",
             "U2NET_PATH",
             os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
             os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
@@ -82,12 +82,12 @@ def load_model(model_name: str = "u2net"):
         ):
         ):
             download_file_from_google_drive(
             download_file_from_google_drive(
                 "1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ",
                 "1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ",
-                "u2net.pth",
+                "pth",
                 path,
                 path,
             )
             )
 
 
     elif model_name == "u2net_human_seg":
     elif model_name == "u2net_human_seg":
-        net = u2net.U2NET(3, 1)
+        net = U2NET(3, 1)
         path = os.environ.get(
         path = os.environ.get(
             "U2NET_PATH",
             "U2NET_PATH",
             os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
             os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
@@ -149,7 +149,7 @@ def preprocess(image):
         label = label[:, :, np.newaxis]
         label = label[:, :, np.newaxis]
 
 
     transform = transforms.Compose(
     transform = transforms.Compose(
-        [data_loader.RescaleT(320), data_loader.ToTensorLab(flag=0)]
+        [RescaleT(320), ToTensorLab(flag=0)]
     )
     )
     sample = transform({"imidx": np.array([0]), "image": image, "label": label})
     sample = transform({"imidx": np.array([0]), "image": image, "label": label})
 
 

+ 0 - 0
rembg/u2net/u2net.py → rembg/u2net.py


+ 0 - 0
rembg/u2net/__init__.py