test_remove.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. from io import BytesIO
  2. from pathlib import Path
  3. from imagehash import phash as hash_img
  4. from PIL import Image
  5. from rembg import new_session, remove
  6. here = Path(__file__).parent.resolve()
  7. def test_remove():
  8. kwargs = {
  9. "sam": {
  10. "anime-girl-1" : {
  11. "sam_prompt" :[{"type": "point", "data": [400, 165], "label": 1}],
  12. },
  13. "car-1" : {
  14. "sam_prompt" :[{"type": "point", "data": [250, 200], "label": 1}],
  15. },
  16. "cloth-1" : {
  17. "sam_prompt" :[{"type": "point", "data": [370, 495], "label": 1}],
  18. },
  19. "plants-1" : {
  20. "sam_prompt" :[{"type": "point", "data": [724, 740], "label": 1}],
  21. },
  22. }
  23. }
  24. for model in [
  25. "u2net",
  26. "u2netp",
  27. "u2net_human_seg",
  28. "u2net_cloth_seg",
  29. "silueta",
  30. "isnet-general-use",
  31. "isnet-anime",
  32. "sam",
  33. "birefnet-general",
  34. "birefnet-general-lite",
  35. "birefnet-portrait",
  36. "birefnet-dis",
  37. "birefnet-hrsod",
  38. "birefnet-cod",
  39. "birefnet-massive"
  40. ]:
  41. for picture in ["anime-girl-1", "car-1", "cloth-1", "plants-1"]:
  42. image_path = Path(here / "fixtures" / f"{picture}.jpg")
  43. image = image_path.read_bytes()
  44. actual = remove(image, session=new_session(model), **kwargs.get(model, {}).get(picture, {}))
  45. actual_hash = hash_img(Image.open(BytesIO(actual)))
  46. expected_path = Path(here / "results" / f"{picture}.{model}.png")
  47. # Uncomment to update the expected results
  48. # f = open(expected_path, "wb")
  49. # f.write(actual)
  50. # f.close()
  51. expected = expected_path.read_bytes()
  52. expected_hash = hash_img(Image.open(BytesIO(expected)))
  53. print(f"image_path: {image_path}")
  54. print(f"expected_path: {expected_path}")
  55. print(f"actual_hash: {actual_hash}")
  56. print(f"expected_hash: {expected_hash}")
  57. print(f"actual_hash == expected_hash: {actual_hash == expected_hash}")
  58. print("---\n")
  59. assert actual_hash == expected_hash