Daniel Gatis 3 anos atrás
pai
commit
972f82ccbb
4 arquivos alterados com 15 adições e 18 exclusões
  1. 5 13
      .github/workflows/lint_python.yml
  2. 6 1
      .gitignore
  3. 2 2
      src/rembg/cmd/server.py
  4. 2 2
      src/rembg/u2net/detect.py

+ 5 - 13
.github/workflows/lint_python.yml

@@ -7,17 +7,9 @@ jobs:
       - uses: actions/checkout@v2
       - uses: actions/setup-python@v2
       - run: pip install --upgrade pip wheel
-      - run: pip install bandit black codespell flake8 flake8-bugbear
-                         flake8-comprehensions isort mypy pytest pyupgrade safety
-      - run: bandit --recursive --skip B101,B104,B310,B311 .
-      - run: black --check --diff .
-      - run: codespell
-      - run: flake8 . --count --ignore=E203,E266,E731,F401,F811,F841,W503
-                      --max-complexity=10 --max-line-length=103 --show-source --statistics
-      - run: isort --check-only --profile black .
-      - run: pip install -r requirements.txt
-      - run: mkdir --parents --verbose .mypy_cache
-      - run: mypy --ignore-missing-imports --install-types --non-interactive . || true
-      - run: pytest . || pytest --doctest-modules . || true
-      - run: shopt -s globstar && pyupgrade --py36-plus **/*.py
+      - run: pip install bandit black flake8 flake8-bugbear flake8-comprehensions isort safety
+      - run: bandit --recursive --skip B101,B104,B310,B311,B303 ./src
+      - run: black --check --diff ./src
+      - run: flake8 ./src --count --ignore=E203,E266,E731,F401,F811,F841,W503 --max-complexity=15 --max-line-length=120 --show-source --statistics
+      - run: isort --check-only --profile black ./src
       - run: safety check

+ 6 - 1
.gitignore

@@ -1,13 +1,18 @@
 # general things to ignore
 build/
 dist/
+.venv/
+.direnv/
 *.egg-info/
 *.egg
 *.py[cod]
 __pycache__/
 *.so
-*~
+*~≈
+.envrc
+.python-version
 
 # due to using tox and pytest
 .tox
 .cache
+.mypy_cache

+ 2 - 2
src/rembg/cmd/server.py

@@ -2,7 +2,7 @@ import argparse
 import glob
 import os
 from io import BytesIO
-from urllib.parse import unquote_plus
+from urllib.parse import quote, unquote_plus
 from urllib.request import urlopen
 
 from flask import Flask, request, send_file
@@ -27,7 +27,7 @@ def index():
         url = request.args.get("url", type=str)
         if url is None:
             return {"error": "missing query param 'url'"}, 400
-        
+
         url = unquote_plus(url)
         if " " in url:
             url = quote(url, safe="/:")

+ 2 - 2
src/rembg/u2net/detect.py

@@ -2,6 +2,7 @@ import errno
 import os
 import sys
 import urllib.request
+from hashlib import md5
 
 import numpy as np
 import requests
@@ -9,7 +10,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision
-from hashlib import md5
 from PIL import Image
 from skimage import transform
 from torchvision import transforms
@@ -52,7 +52,7 @@ def download_file_from_google_drive(id, fname, destination):
 
 
 def load_model(model_name: str = "u2net"):
-    hashfile = lambda f: md5(open(f,"rb").read()).hexdigest()
+    hashfile = lambda f: md5(open(f, "rb").read()).hexdigest()
 
     if model_name == "u2netp":
         net = u2net.U2NETP(3, 1)