# arc.py
  1. from typing import Dict, Sequence, Optional
  2. import os
  3. import json
  4. import torch
  5. import numpy as np
  6. from numba import njit
  7. import torch.distributed as dist
  8. from dataset.build_arc_dataset import inverse_aug, grid_hash, arc_grid_to_np
  9. from dataset.common import PuzzleDatasetMetadata
  10. @njit
  11. def _crop(grid: np.ndarray):
  12. """Find maximum-sized rectangle without any EOS token inside. """
  13. grid = grid.reshape(30, 30)
  14. max_area = 0
  15. max_size = (0, 0)
  16. nr, nc = grid.shape
  17. num_c = nc
  18. for num_r in range(1, nr + 1):
  19. # Scan for maximum c
  20. for c in range(1, num_c + 1):
  21. x = grid[num_r - 1, c - 1]
  22. if (x < 2) | (x > 11):
  23. num_c = c - 1
  24. break
  25. area = num_r * num_c
  26. if area > max_area:
  27. max_area = area
  28. max_size = (num_r, num_c)
  29. return (grid[:max_size[0], :max_size[1]] - 2).astype(np.uint8)
  30. class ARC:
  31. required_outputs = {"inputs", "puzzle_identifiers", "q_halt_logits", "preds"}
  32. def __init__(self, data_path: str,
  33. eval_metadata: PuzzleDatasetMetadata,
  34. submission_K: int = 2,
  35. pass_Ks: Sequence[int] = (1, 2, 5, 10, 100, 1000),
  36. aggregated_voting: bool = True):
  37. super().__init__()
  38. self.pass_Ks = pass_Ks
  39. self.submission_K = submission_K
  40. self.aggregated_voting = aggregated_voting
  41. self.blank_identifier_id = eval_metadata.blank_identifier_id
  42. # Load identifiers and test puzzles
  43. with open(os.path.join(data_path, "identifiers.json"), "r") as f:
  44. self.identifier_map = json.load(f)
  45. with open(os.path.join(data_path, "test_puzzles.json"), "r") as f:
  46. self.test_puzzles = json.load(f)
  47. # States
  48. self._local_hmap = {}
  49. self._local_preds = {}
  50. def begin_eval(self):
  51. if not self.aggregated_voting:
  52. # Clear previous predictions
  53. self._local_hmap = {}
  54. self._local_preds = {}
  55. def update_batch(self, batch: Dict[str, torch.Tensor], preds: Dict[str, torch.Tensor]):
  56. # Collect required outputs to CPU
  57. outputs = {}
  58. q_values = None
  59. for collection in (batch, preds):
  60. for k, v in collection.items():
  61. if k in self.required_outputs:
  62. if k == "q_halt_logits":
  63. q_values = v.to(torch.float64).sigmoid().cpu()
  64. else:
  65. outputs[k] = v.cpu()
  66. assert q_values is not None
  67. # Remove padding from outputs
  68. mask = outputs["puzzle_identifiers"] != self.blank_identifier_id
  69. outputs = {k: v[mask] for k, v in outputs.items()}
  70. # Get predictions
  71. for identifier, input, pred, q in zip(outputs["puzzle_identifiers"].numpy(), outputs["inputs"].numpy(), outputs["preds"].numpy(), q_values.numpy()):
  72. name = self.identifier_map[identifier]
  73. orig_name, _inverse_fn = inverse_aug(name)
  74. input_hash = grid_hash(_inverse_fn(_crop(input)))
  75. pred = _inverse_fn(_crop(pred))
  76. assert np.all((pred >= 0) & (pred <= 9)), f"Puzzle {name}'s prediction out of 0-9 range." # Sanity check
  77. # Store into local state
  78. pred_hash = grid_hash(pred)
  79. self._local_hmap[pred_hash] = pred
  80. self._local_preds.setdefault(orig_name, {})
  81. self._local_preds[orig_name].setdefault(input_hash, [])
  82. self._local_preds[orig_name][input_hash].append((pred_hash, float(q)))
  83. def result(self, save_path: Optional[str], rank: int, world_size: int, group: Optional[torch.distributed.ProcessGroup] = None) -> Optional[Dict[str, float]]:
  84. # Gather predictions to rank 0 for voting
  85. global_hmap_preds = [None for _ in range(world_size)] if rank == 0 else None
  86. dist.gather_object((self._local_hmap, self._local_preds), global_hmap_preds, dst=0, group=group)
  87. # Rank 0 logic
  88. if rank != 0:
  89. return
  90. submission = {}
  91. correct = [0.0 for _ in range(len(self.pass_Ks))]
  92. for name, puzzle in self.test_puzzles.items():
  93. # Process test examples in this puzzle
  94. submission[name] = []
  95. num_test_correct = [0 for _ in range(len(self.pass_Ks))]
  96. for pair in puzzle["test"]:
  97. input_hash = grid_hash(arc_grid_to_np(pair["input"]))
  98. label_hash = grid_hash(arc_grid_to_np(pair["output"]))
  99. p_map = {}
  100. for hmap, preds in global_hmap_preds: # type: ignore
  101. for h, q in preds.get(name, {}).get(input_hash, {}):
  102. p_map.setdefault(h, [0, 0])
  103. p_map[h][0] += 1
  104. p_map[h][1] += q
  105. if not len(p_map):
  106. print (f"Puzzle {name} has no predictions.")
  107. continue
  108. for h, stats in p_map.items():
  109. stats[1] /= stats[0]
  110. p_map = sorted(p_map.items(), key=lambda kv: kv[1], reverse=True)
  111. # vote for different Ks
  112. for i, k in enumerate(self.pass_Ks):
  113. ok = False
  114. for h, stats in p_map[:k]:
  115. ok |= h == label_hash
  116. num_test_correct[i] += ok
  117. # Query grids
  118. pred_grids = []
  119. for h, stats in p_map[:self.submission_K]:
  120. for hmap, preds in global_hmap_preds: # type: ignore
  121. if h in hmap:
  122. pred_grids.append(hmap[h])
  123. break
  124. # Pad to K
  125. while len(pred_grids) < self.submission_K:
  126. pred_grids.append(pred_grids[0])
  127. submission[name].append({f"attempt_{i + 1}": grid.tolist() for i, grid in enumerate(pred_grids)})
  128. # Total correctness
  129. for i in range(len(self.pass_Ks)):
  130. correct[i] += num_test_correct[i] / len(puzzle["test"])
  131. # Save submission
  132. if save_path is not None:
  133. with open(os.path.join(save_path, "submission.json"), "w") as f:
  134. json.dump(submission, f)
  135. # Final result
  136. all_results = {f"ARC/pass@{k}": correct[i] / len(self.test_puzzles) for i, k in enumerate(self.pass_Ks)}
  137. return all_results