Browse Source

add a cache for models

Daniel Gatis 4 years ago
parent
commit
8d5e4aba5d
3 changed files with 36 additions and 23 deletions
  1. 1 1
      setup.py
  2. 10 8
      src/rembg/bg.py
  3. 25 14
      src/rembg/u2net/detect.py

+ 1 - 1
setup.py

@@ -11,7 +11,7 @@ with open("requirements.txt") as f:
 
 setup(
     name="rembg",
-    version="1.0.13",
+    version="1.0.14",
     description="Remove image background",
     long_description=long_description,
     long_description_content_type="text/markdown",

+ 10 - 8
src/rembg/bg.py

@@ -1,5 +1,6 @@
 import io
 
+import functools
 import numpy as np
 from PIL import Image
 from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
@@ -9,9 +10,6 @@ from scipy.ndimage.morphology import binary_erosion
 
 from .u2net import detect
 
-model_u2net = detect.load_model(model_name="u2net")
-model_u2netp = detect.load_model(model_name="u2netp")
-
 
 def alpha_matting_cutout(
     img, mask, foreground_threshold, background_threshold, erode_structure_size,
@@ -66,6 +64,14 @@ def naive_cutout(img, mask):
     return cutout
 
 
[email protected]_cache
+def get_model(model_name):
+    if model_name == "u2netp":
+        return detect.load_model(model_name="u2netp")
+    else:
+        return detect.load_model(model_name="u2net")
+
+
 def remove(
     data,
     model_name="u2net",
@@ -74,11 +80,7 @@ def remove(
     alpha_matting_background_threshold=10,
     alpha_matting_erode_structure_size=10,
 ):
-    model = model_u2net
-
-    if model == "u2netp":
-        model = model_u2netp
-
+    model = get_model(model_name)
     img = Image.open(io.BytesIO(data)).convert("RGB")
     mask = detect.predict(model, np.array(img)).convert("L")
 

+ 25 - 14
src/rembg/u2net/detect.py

@@ -17,16 +17,31 @@ from tqdm import tqdm
 from . import data_loader, u2net
 
 
-def download(url, fname, path):
-    if os.path.exists(path):
+def download_file_from_google_drive(id, fname, destination):
+    if os.path.exists(destination):
         return
 
-    resp = requests.get(url, stream=True)
-    total = int(resp.headers.get("content-length", 0))
-    with open(path, "wb") as file, tqdm(
+    URL = "https://docs.google.com/uc?export=download"
+
+    session = requests.Session()
+    response = session.get(URL, params={"id": id}, stream=True)
+
+    token = None
+    for key, value in response.cookies.items():
+        if key.startswith("download_warning"):
+            token = value
+            break
+
+    if token:
+        params = {"id": id, "confirm": token}
+        response = session.get(URL, params=params, stream=True)
+
+    total = int(response.headers.get("content-length", 0))
+
+    with open(destination, "wb") as file, tqdm(
         desc=fname, total=total, unit="iB", unit_scale=True, unit_divisor=1024,
     ) as bar:
-        for data in resp.iter_content(chunk_size=1024):
+        for data in response.iter_content(chunk_size=1024):
             size = file.write(data)
             bar.update(size)
 
@@ -37,18 +52,14 @@ def load_model(model_name: str = "u2net"):
     if model_name == "u2netp":
         net = u2net.U2NETP(3, 1)
         path = os.path.expanduser(os.path.join("~", ".u2net", model_name))
-        download(
-            "https://www.dropbox.com/s/usb1fyiuh8as5gi/u2netp.pth?dl=1",
-            "u2netp.pth",
-            path,
+        download_file_from_google_drive(
+            "1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy", "u2netp.pth", path,
         )
     elif model_name == "u2net":
         net = u2net.U2NET(3, 1)
         path = os.path.expanduser(os.path.join("~", ".u2net", model_name))
-        download(
-            "https://www.dropbox.com/s/kdu5mhose1clds0/u2net.pth?dl=1",
-            "u2net.pth",
-            path,
+        download_file_from_google_drive(
+            "1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ", "u2net.pth", path,
         )
     else:
         print("Choose between u2net or u2netp", file=sys.stderr)