test_remove.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. from io import BytesIO
  2. from pathlib import Path
  3. from imagehash import phash as hash_img
  4. from PIL import Image
  5. from rembg import new_session, remove
# Absolute directory that contains this test module; the "fixtures" input
# images and "results" reference images are looked up relative to it.
here = Path(__file__).parent.resolve()
  7. def test_remove():
  8. kwargs = {
  9. "sam": {
  10. "car-1" : {
  11. "input_points": [[250, 200]],
  12. "input_labels": [1],
  13. },
  14. "cloth-1" : {
  15. "input_points": [[370, 495]],
  16. "input_labels": [1],
  17. }
  18. }
  19. }
  20. for model in [
  21. "u2net",
  22. "u2netp",
  23. "u2net_human_seg",
  24. "u2net_cloth_seg",
  25. "silueta",
  26. "isnet-general-use",
  27. "sam"
  28. ]:
  29. for picture in ["car-1", "cloth-1"]:
  30. image_path = Path(here / "fixtures" / f"{picture}.jpg")
  31. image = image_path.read_bytes()
  32. actual = remove(image, session=new_session(model), **kwargs.get(model, {}).get(picture, {}))
  33. actual_hash = hash_img(Image.open(BytesIO(actual)))
  34. expected_path = Path(here / "results" / f"{picture}.{model}.png")
  35. # Uncomment to update the expected results
  36. # f = open(expected_path, "ab")
  37. # f.write(actual)
  38. # f.close()
  39. expected = expected_path.read_bytes()
  40. expected_hash = hash_img(Image.open(BytesIO(expected)))
  41. print(f"image_path: {image_path}")
  42. print(f"expected_path: {expected_path}")
  43. print(f"actual_hash: {actual_hash}")
  44. print(f"expected_hash: {expected_hash}")
  45. print(f"actual_hash == expected_hash: {actual_hash == expected_hash}")
  46. print("---\n")
  47. assert actual_hash == expected_hash