# test_remove.py
import os
from io import BytesIO
from pathlib import Path

from imagehash import phash as hash_img
from PIL import Image
from rembg import new_session, remove
  6. here = Path(__file__).parent.resolve()
  7. def test_remove():
  8. kwargs = {
  9. "sam": {
  10. "anime-girl-1" : {
  11. "input_points": [[400, 165]],
  12. "input_labels": [1],
  13. },
  14. "car-1" : {
  15. "input_points": [[250, 200]],
  16. "input_labels": [1],
  17. },
  18. "cloth-1" : {
  19. "input_points": [[370, 495]],
  20. "input_labels": [1],
  21. },
  22. }
  23. }
  24. for model in [
  25. "u2net",
  26. "u2netp",
  27. "u2net_human_seg",
  28. "u2net_cloth_seg",
  29. "silueta",
  30. "isnet-general-use",
  31. "isnet-anime",
  32. "sam"
  33. ]:
  34. for picture in ["anime-girl-1", "car-1", "cloth-1"]:
  35. image_path = Path(here / "fixtures" / f"{picture}.jpg")
  36. image = image_path.read_bytes()
  37. actual = remove(image, session=new_session(model), **kwargs.get(model, {}).get(picture, {}))
  38. actual_hash = hash_img(Image.open(BytesIO(actual)))
  39. expected_path = Path(here / "results" / f"{picture}.{model}.png")
  40. # Uncomment to update the expected results
  41. # f = open(expected_path, "ab")
  42. # f.write(actual)
  43. # f.close()
  44. expected = expected_path.read_bytes()
  45. expected_hash = hash_img(Image.open(BytesIO(expected)))
  46. print(f"image_path: {image_path}")
  47. print(f"expected_path: {expected_path}")
  48. print(f"actual_hash: {actual_hash}")
  49. print(f"expected_hash: {expected_hash}")
  50. print(f"actual_hash == expected_hash: {actual_hash == expected_hash}")
  51. print("---\n")
  52. assert actual_hash == expected_hash