浏览代码

Merge pull request #142 from cclauss/pyupgrade

Make black, isort, and pyupgrade mandatory tests
Daniel Gatis 3 年之前
父节点
当前提交
9d16c4354a
共有 6 个文件被更改,包括 40 次插入29 次删除
  1. 7 8
      .github/workflows/lint_python.yml
  2. 1 1
      src/rembg/bg.py
  3. 10 2
      src/rembg/cmd/cli.py
  4. 9 4
      src/rembg/cmd/server.py
  5. 5 6
      src/rembg/u2net/data_loader.py
  6. 8 8
      src/rembg/u2net/u2net.py

+ 7 - 8
.github/workflows/lint_python.yml

@@ -10,15 +10,14 @@ jobs:
       - run: pip install bandit black codespell flake8 flake8-bugbear
                          flake8-comprehensions isort mypy pytest pyupgrade safety
       - run: bandit --recursive --skip B101,B104,B310,B311 .
-      - run: black --check . || true
+      - run: black --check --diff .
       - run: codespell
-      - run: flake8 . --count --ignore=B001,E203,E266,E722,E731,F401,F811,F841,W503
-                      --max-complexity=10 --max-line-length=121 --show-source --statistics
-      - run: isort --check-only --profile black . || true
-      - run: pip install -r requirements.txt || pip install --editable . || true
+      - run: flake8 . --count --ignore=E203,E266,E731,F401,F811,F841,W503
+                      --max-complexity=10 --max-line-length=103 --show-source --statistics
+      - run: isort --check-only --profile black .
+      - run: pip install -r requirements.txt
       - run: mkdir --parents --verbose .mypy_cache
       - run: mypy --ignore-missing-imports --install-types --non-interactive . || true
-      - run: pytest . || true
-      - run: pytest --doctest-modules . || true
-      - run: shopt -s globstar && pyupgrade --py36-plus **/*.py || true
+      - run: pytest . || pytest --doctest-modules . || true
+      - run: shopt -s globstar && pyupgrade --py36-plus **/*.py
       - run: safety check

+ 1 - 1
src/rembg/bg.py

@@ -101,7 +101,7 @@ def remove(
                 alpha_matting_erode_structure_size,
                 alpha_matting_base_size,
             )
-        except:
+        except Exception:
             cutout = naive_cutout(img, mask)
     else:
         cutout = naive_cutout(img, mask)

+ 10 - 2
src/rembg/cmd/cli.py

@@ -14,7 +14,10 @@ def main():
         "U2NETP_PATH",
         os.path.expanduser(os.path.join("~", ".u2net")),
     )
-    model_choices = [os.path.splitext(os.path.basename(x))[0] for x in set(glob.glob(model_path + "/*"))]
+    model_choices = [
+        os.path.splitext(os.path.basename(x))[0]
+        for x in set(glob.glob(model_path + "/*"))
+    ]
     if len(model_choices) == 0:
         model_choices = ["u2net", "u2netp", "u2net_human_seg"]
 
@@ -126,7 +129,12 @@ def main():
                 continue
 
             with open(fi, "rb") as input:
-                with open(os.path.join(output_path, os.path.splitext(os.path.basename(fi))[0] + ".png"), "wb") as output:
+                with open(
+                    os.path.join(
+                        output_path, os.path.splitext(os.path.basename(fi))[0] + ".png"
+                    ),
+                    "wb",
+                ) as output:
                     w(
                         output,
                         remove(

+ 9 - 4
src/rembg/cmd/server.py

@@ -1,6 +1,6 @@
-import os
-import glob
 import argparse
+import glob
+import os
 from io import BytesIO
 from urllib.parse import unquote_plus
 from urllib.request import urlopen
@@ -48,12 +48,17 @@ def index():
         "U2NETP_PATH",
         os.path.expanduser(os.path.join("~", ".u2net")),
     )
-    model_choices = [os.path.splitext(os.path.basename(x))[0] for x in set(glob.glob(model_path + "/*"))]
+    model_choices = [
+        os.path.splitext(os.path.basename(x))[0]
+        for x in set(glob.glob(model_path + "/*"))
+    ]
     if len(model_choices) == 0:
         model_choices = ["u2net", "u2netp", "u2net_human_seg"]
 
     if model not in model_choices:
-        return {"error": f"invalid query param 'model'. Available options are {model_choices}"}, 400
+        return {
+            "error": f"invalid query param 'model'. Available options are {model_choices}"
+        }, 400
 
     try:
         return send_file(

+ 5 - 6
src/rembg/u2net/data_loader.py

@@ -1,5 +1,4 @@
 # data loader
-from __future__ import division, print_function
 
 import random
 
@@ -13,7 +12,7 @@ from torchvision import transforms, utils
 
 
 # ==========================dataset load==========================
-class RescaleT(object):
+class RescaleT:
     def __init__(self, output_size):
         assert isinstance(output_size, (int, tuple))
         self.output_size = output_size
@@ -51,7 +50,7 @@ class RescaleT(object):
         return {"imidx": imidx, "image": img, "label": lbl}
 
 
-class Rescale(object):
+class Rescale:
     def __init__(self, output_size):
         assert isinstance(output_size, (int, tuple))
         self.output_size = output_size
@@ -84,7 +83,7 @@ class Rescale(object):
         return {"imidx": imidx, "image": img, "label": lbl}
 
 
-class RandomCrop(object):
+class RandomCrop:
     def __init__(self, output_size):
         assert isinstance(output_size, (int, tuple))
         if isinstance(output_size, int):
@@ -112,7 +111,7 @@ class RandomCrop(object):
         return {"imidx": imidx, "image": image, "label": label}
 
 
-class ToTensor(object):
+class ToTensor:
     """Convert ndarrays in sample to Tensors."""
 
     def __call__(self, sample):
@@ -151,7 +150,7 @@ class ToTensor(object):
         }
 
 
-class ToTensorLab(object):
+class ToTensorLab:
     """Convert ndarrays in sample to Tensors."""
 
     def __init__(self, flag=0):

+ 8 - 8
src/rembg/u2net/u2net.py

@@ -6,7 +6,7 @@ from torchvision import models
 
 class REBNCONV(nn.Module):
     def __init__(self, in_ch=3, out_ch=3, dirate=1):
-        super(REBNCONV, self).__init__()
+        super().__init__()
 
         self.conv_s1 = nn.Conv2d(
             in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate
@@ -33,7 +33,7 @@ def _upsample_like(src, tar):
 ### RSU-7 ###
 class RSU7(nn.Module):  # UNet07DRES(nn.Module):
     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
-        super(RSU7, self).__init__()
+        super().__init__()
 
         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
 
@@ -110,7 +110,7 @@ class RSU7(nn.Module):  # UNet07DRES(nn.Module):
 ### RSU-6 ###
 class RSU6(nn.Module):  # UNet06DRES(nn.Module):
     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
-        super(RSU6, self).__init__()
+        super().__init__()
 
         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
 
@@ -178,7 +178,7 @@ class RSU6(nn.Module):  # UNet06DRES(nn.Module):
 ### RSU-5 ###
 class RSU5(nn.Module):  # UNet05DRES(nn.Module):
     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
-        super(RSU5, self).__init__()
+        super().__init__()
 
         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
 
@@ -236,7 +236,7 @@ class RSU5(nn.Module):  # UNet05DRES(nn.Module):
 ### RSU-4 ###
 class RSU4(nn.Module):  # UNet04DRES(nn.Module):
     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
-        super(RSU4, self).__init__()
+        super().__init__()
 
         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
 
@@ -284,7 +284,7 @@ class RSU4(nn.Module):  # UNet04DRES(nn.Module):
 ### RSU-4F ###
 class RSU4F(nn.Module):  # UNet04FRES(nn.Module):
     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
-        super(RSU4F, self).__init__()
+        super().__init__()
 
         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
 
@@ -320,7 +320,7 @@ class RSU4F(nn.Module):  # UNet04FRES(nn.Module):
 ##### U^2-Net ####
 class U2NET(nn.Module):
     def __init__(self, in_ch=3, out_ch=1):
-        super(U2NET, self).__init__()
+        super().__init__()
 
         self.stage1 = RSU7(in_ch, 32, 64)
         self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
@@ -432,7 +432,7 @@ class U2NET(nn.Module):
 ### U^2-Net small ###
 class U2NETP(nn.Module):
     def __init__(self, in_ch=3, out_ch=1):
-        super(U2NETP, self).__init__()
+        super().__init__()
 
         self.stage1 = RSU7(in_ch, 16, 64)
         self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)