"""ARC evaluator: collects per-rank predictions, aggregates them by majority
vote (mean Q-halt confidence breaks ties), and reports pass@K metrics."""
from typing import Dict, Sequence, Optional
import os
import json

import torch
import numpy as np
from numba import njit
import torch.distributed as dist

from dataset.build_arc_dataset import inverse_aug, grid_hash, arc_grid_to_np
from dataset.common import PuzzleDatasetMetadata


@njit
def _crop(grid: np.ndarray):
    """Find the maximum-area top-left rectangle that contains no EOS/PAD token.

    Color tokens occupy the range 2..11; the returned grid is shifted back
    into the 0..9 color range.
    """
    grid = grid.reshape(30, 30)
    max_area = 0
    max_size = (0, 0)
    nr, nc = grid.shape

    num_c = nc
    for num_r in range(1, nr + 1):
        # Shrink the column count to the longest run of color tokens in this row
        for c in range(1, num_c + 1):
            x = grid[num_r - 1, c - 1]
            if (x < 2) | (x > 11):
                num_c = c - 1
                break

        area = num_r * num_c
        if area > max_area:
            max_area = area
            max_size = (num_r, num_c)

    return (grid[:max_size[0], :max_size[1]] - 2).astype(np.uint8)
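
# Illustrative example (assumes tokens outside 2..11, e.g. PAD/EOS, fill the
# rest of the 30x30 canvas, per the encoding handled above):
#
#   >>> g = np.zeros(30 * 30, dtype=np.uint8)
#   >>> g.reshape(30, 30)[:3, :4] = 5   # a 3x4 block of color 3 (token 5)
#   >>> _crop(g)
#   array([[3, 3, 3, 3],
#          [3, 3, 3, 3],
#          [3, 3, 3, 3]], dtype=uint8)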


class ARC:
    """Distributed ARC evaluator: accumulates predictions across ranks (and,
    optionally, across eval rounds), votes per test input, and writes a
    submission.json with attempt_{i} entries."""

    required_outputs = {"inputs", "puzzle_identifiers", "q_halt_logits", "preds"}

    def __init__(self, data_path: str,
                 eval_metadata: PuzzleDatasetMetadata,
                 submission_K: int = 2,
                 pass_Ks: Sequence[int] = (1, 2, 5, 10, 100, 1000),
                 aggregated_voting: bool = True):
        super().__init__()
        self.pass_Ks = pass_Ks
        self.submission_K = submission_K
        self.aggregated_voting = aggregated_voting
        self.blank_identifier_id = eval_metadata.blank_identifier_id
        # Load identifiers and test puzzles
        with open(os.path.join(data_path, "identifiers.json"), "r") as f:
            self.identifier_map = json.load(f)
        with open(os.path.join(data_path, "test_puzzles.json"), "r") as f:
            self.test_puzzles = json.load(f)

        # States
        self._local_hmap = {}   # pred_hash -> decoded grid (np.uint8, values 0..9)
        self._local_preds = {}  # puzzle name -> input_hash -> [(pred_hash, q), ...]

    def begin_eval(self):
        # With aggregated voting enabled, state is kept so that votes
        # accumulate across successive evaluation rounds.
        if not self.aggregated_voting:
            # Clear previous predictions
            self._local_hmap = {}
            self._local_preds = {}

    def update_batch(self, batch: Dict[str, torch.Tensor], preds: Dict[str, torch.Tensor]):
        # Collect required outputs to CPU
        outputs = {}
        q_values = None
        for collection in (batch, preds):
            for k, v in collection.items():
                if k in self.required_outputs:
                    if k == "q_halt_logits":
                        q_values = v.to(torch.float64).sigmoid().cpu()
                    else:
                        outputs[k] = v.cpu()

        assert q_values is not None
        # Remove padding from outputs, keeping Q-values aligned
        mask = outputs["puzzle_identifiers"] != self.blank_identifier_id
        outputs = {k: v[mask] for k, v in outputs.items()}
        q_values = q_values[mask]
        # Get predictions
        for identifier, inp, pred, q in zip(outputs["puzzle_identifiers"].numpy(), outputs["inputs"].numpy(), outputs["preds"].numpy(), q_values.numpy()):
            # Undo dataset augmentation so predictions are keyed by the original puzzle
            name = self.identifier_map[identifier]
            orig_name, _inverse_fn = inverse_aug(name)

            input_hash = grid_hash(_inverse_fn(_crop(inp)))

            pred = _inverse_fn(_crop(pred))
            assert np.all((pred >= 0) & (pred <= 9)), f"Puzzle {name}'s prediction out of 0-9 range."  # Sanity check
            # Store into local state
            pred_hash = grid_hash(pred)
            self._local_hmap[pred_hash] = pred

            self._local_preds.setdefault(orig_name, {})
            self._local_preds[orig_name].setdefault(input_hash, [])
            self._local_preds[orig_name][input_hash].append((pred_hash, float(q)))
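
    # NOTE: result() must be called on every rank, since dist.gather_object is
    # a collective operation; only rank 0 receives the gathered predictions,
    # performs the voting, and returns a metrics dict (other ranks get None).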
    def result(self, save_path: Optional[str], rank: int, world_size: int, group: Optional[torch.distributed.ProcessGroup] = None) -> Optional[Dict[str, float]]:
        # Gather predictions to rank 0 for voting
        global_hmap_preds = [None for _ in range(world_size)] if rank == 0 else None
        dist.gather_object((self._local_hmap, self._local_preds), global_hmap_preds, dst=0, group=group)

        # Rank 0 logic
        if rank != 0:
            return None

        submission = {}
        correct = [0.0 for _ in range(len(self.pass_Ks))]
        for name, puzzle in self.test_puzzles.items():
            # Process test examples in this puzzle
            submission[name] = []
            num_test_correct = [0 for _ in range(len(self.pass_Ks))]
            for pair in puzzle["test"]:
                input_hash = grid_hash(arc_grid_to_np(pair["input"]))
                label_hash = grid_hash(arc_grid_to_np(pair["output"]))

                # Tally votes: p_map maps pred_hash -> [vote count, summed q]
                p_map = {}
                for hmap, preds in global_hmap_preds:  # type: ignore
                    for h, q in preds.get(name, {}).get(input_hash, []):
                        p_map.setdefault(h, [0, 0.0])
                        p_map[h][0] += 1
                        p_map[h][1] += q

                if not p_map:
                    print(f"Puzzle {name} has no predictions.")
                    continue

                # Convert summed q into mean q
                for h, stats in p_map.items():
                    stats[1] /= stats[0]

                # Rank by (vote count, mean q), descending
                p_map = sorted(p_map.items(), key=lambda kv: kv[1], reverse=True)
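                # Hypothetical illustration of the ranking:
                #   [(h1, [3, 0.91]), (h2, [3, 0.40]), (h3, [1, 0.99])]
                # h1 and h2 tie on votes, so h1 wins on mean q; h3 ranks last.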

                # Vote for different Ks
                for i, k in enumerate(self.pass_Ks):
                    ok = False
                    for h, stats in p_map[:k]:
                        ok |= h == label_hash

                    num_test_correct[i] += ok

                # Query grids for the submission attempts
                pred_grids = []
                for h, stats in p_map[:self.submission_K]:
                    for hmap, preds in global_hmap_preds:  # type: ignore
                        if h in hmap:
                            pred_grids.append(hmap[h])
                            break

                # Pad to K by repeating the top-voted grid
                while len(pred_grids) < self.submission_K:
                    pred_grids.append(pred_grids[0])

                # One {"attempt_1": ..., "attempt_2": ...} entry per test input
                submission[name].append({f"attempt_{i + 1}": grid.tolist() for i, grid in enumerate(pred_grids)})

            # Total correctness, averaged over this puzzle's test inputs
            for i in range(len(self.pass_Ks)):
                correct[i] += num_test_correct[i] / len(puzzle["test"])

        # Save submission
        if save_path is not None:
            with open(os.path.join(save_path, "submission.json"), "w") as f:
                json.dump(submission, f)

        # Final result
        all_results = {f"ARC/pass@{k}": correct[i] / len(self.test_puzzles) for i, k in enumerate(self.pass_Ks)}
        return all_results
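

# Minimal usage sketch (illustrative, not part of this module): `model`,
# `loader`, and `metadata` are hypothetical stand-ins supplied by the training
# harness; batches/preds must carry the keys listed in ARC.required_outputs.
#
#   evaluator = ARC(data_path="data/arc", eval_metadata=metadata)
#   evaluator.begin_eval()
#   for batch in loader:
#       preds = model(batch)  # dict with "preds" and "q_halt_logits"
#       evaluator.update_batch(batch, preds)
#   metrics = evaluator.result(save_path="checkpoints", rank=dist.get_rank(),
#                              world_size=dist.get_world_size())
#   if metrics is not None:  # rank 0 only
#       print(metrics)       # {"ARC/pass@1": ..., "ARC/pass@2": ..., ...}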