common.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. from typing import List, Optional
  2. import pydantic
  3. import numpy as np
  4. # Global list mapping each dihedral transform id to its inverse.
  5. # Index corresponds to the original tid, and the value is its inverse.
  6. DIHEDRAL_INVERSE = [0, 3, 2, 1, 4, 5, 6, 7]
# Pydantic model describing the metadata of a puzzle dataset.
# NOTE(review): the per-field meanings below are inferred from the field
# names only -- confirm against the code that builds / consumes this model.
class PuzzleDatasetMetadata(pydantic.BaseModel):
    pad_id: int  # presumably the token id used for padding
    # May be None. No default is declared: under pydantic v2 the field must
    # still be passed explicitly (None allowed); pydantic v1 would default it
    # to None. Presumably a label id to be ignored (e.g. in a loss) -- confirm.
    ignore_label_id: Optional[int]
    blank_identifier_id: int  # presumably the puzzle-identifier id reserved for "blank"/none
    vocab_size: int  # presumably the number of distinct token ids
    seq_len: int  # presumably the (fixed) sequence length of each example
    num_puzzle_identifiers: int  # presumably the count of distinct puzzle identifiers
    total_groups: int  # presumably the number of example groups in the dataset
    mean_puzzle_examples: float  # presumably the average number of examples per puzzle
    total_puzzles: int  # presumably the number of puzzles in the dataset
    sets: List[str]  # presumably the names of the dataset subsets
  18. def dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
  19. """8 dihedral symmetries by rotate, flip and mirror"""
  20. if tid == 0:
  21. return arr # identity
  22. elif tid == 1:
  23. return np.rot90(arr, k=1)
  24. elif tid == 2:
  25. return np.rot90(arr, k=2)
  26. elif tid == 3:
  27. return np.rot90(arr, k=3)
  28. elif tid == 4:
  29. return np.fliplr(arr) # horizontal flip
  30. elif tid == 5:
  31. return np.flipud(arr) # vertical flip
  32. elif tid == 6:
  33. return arr.T # transpose (reflection along main diagonal)
  34. elif tid == 7:
  35. return np.fliplr(np.rot90(arr, k=1)) # anti-diagonal reflection
  36. else:
  37. return arr
  38. def inverse_dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
  39. return dihedral_transform(arr, DIHEDRAL_INVERSE[tid])