puzzle_dataset.py
import os
import json
from typing import List

import numpy as np
import pydantic
import torch
from torch.utils.data import IterableDataset, get_worker_info

from models.losses import IGNORE_LABEL_ID
from dataset.common import PuzzleDatasetMetadata


def _sample_batch(rng: np.random.Generator, group_order: np.ndarray, puzzle_indices: np.ndarray, group_indices: np.ndarray, start_index: int, global_batch_size: int):
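    """Pack one global batch of example indices by walking the shuffled group order.

    Starting at ``start_index`` in ``group_order``, pick one puzzle uniformly at random
    from each group and append a random sample of its examples (without replacement)
    until ``global_batch_size`` examples are collected or the group order is exhausted.

    Returns the advanced ``start_index``, the example indices, and the matching puzzle indices.
    """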
    # Pack examples into a full batch
    batch = []
    batch_puzzle_indices = []
    current_size = 0
    while (start_index < group_order.size) and (current_size < global_batch_size):
        # Pick a group and a puzzle from that group
        group_id = group_order[start_index]
        puzzle_id = rng.integers(group_indices[group_id], group_indices[group_id + 1])
        start_index += 1

        # Get range of the puzzle
        puzzle_start = puzzle_indices[puzzle_id]
        puzzle_size = int(puzzle_indices[puzzle_id + 1] - puzzle_start)
        append_size = min(puzzle_size, global_batch_size - current_size)

        # Put into batch; sample with the seeded generator so batches are reproducible
        batch_puzzle_indices.append(np.full(append_size, puzzle_id, dtype=np.int32))
        batch.append(puzzle_start + rng.choice(puzzle_size, append_size, replace=False))
        current_size += append_size

    return start_index, np.concatenate(batch), np.concatenate(batch_puzzle_indices)


class PuzzleDatasetConfig(pydantic.BaseModel):
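    """Configuration for PuzzleDataset; ``rank`` and ``num_replicas`` select the data-parallel shard."""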
    seed: int
    dataset_paths: List[str]
    global_batch_size: int
    test_set_mode: bool

    epochs_per_iter: int  # Batch multiple epochs into one iteration to reduce overhead.

    rank: int
    num_replicas: int


class PuzzleDataset(IterableDataset):
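    """Iterable dataset that streams pre-collated puzzle batches.

    Metadata from all configured dataset paths is merged at construction time, and
    iteration yields ``(set_name, batch, effective_batch_size)`` tuples for the
    current data-parallel rank.
    """
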
    def __init__(self, config: PuzzleDatasetConfig, split: str = "train"):
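        """Merge the metadata of every dataset path and check that they are compatible."""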
        super().__init__()
        self.config = config
        self.split = split

        # Merge metadata from all dataset paths
        prev_seq_len = None
        prev_vocab_size = None
        prev_pad_id = None
        prev_ignore_label_id = None
        prev_blank_identifier_id = None
        prev_sets = None
        prev_num_identifiers = None

        mean_puzzle_examples = 0
        total_puzzles = 0
        total_groups = 0
        num_identifiers = 0

        for dataset_path in config.dataset_paths:
            current_metadata = self._load_metadata(dataset_path)

            if prev_seq_len is None:
                prev_seq_len = current_metadata.seq_len
                prev_vocab_size = current_metadata.vocab_size
                prev_pad_id = current_metadata.pad_id
                prev_ignore_label_id = current_metadata.ignore_label_id
                prev_blank_identifier_id = current_metadata.blank_identifier_id
                prev_sets = current_metadata.sets
                prev_num_identifiers = current_metadata.num_puzzle_identifiers
            else:
                assert prev_seq_len == current_metadata.seq_len
                assert prev_vocab_size == current_metadata.vocab_size
                assert prev_pad_id == current_metadata.pad_id
                assert prev_ignore_label_id == current_metadata.ignore_label_id
                assert prev_blank_identifier_id == current_metadata.blank_identifier_id
                assert prev_sets == current_metadata.sets
                assert prev_num_identifiers == current_metadata.num_puzzle_identifiers

            mean_puzzle_examples += current_metadata.mean_puzzle_examples * current_metadata.total_puzzles
            total_puzzles += current_metadata.total_puzzles
            total_groups += current_metadata.total_groups
            num_identifiers += current_metadata.num_puzzle_identifiers

        mean_puzzle_examples = mean_puzzle_examples / total_puzzles

        self.metadata = PuzzleDatasetMetadata(
            seq_len=prev_seq_len,
            vocab_size=prev_vocab_size,
            pad_id=prev_pad_id,
            ignore_label_id=prev_ignore_label_id,
            blank_identifier_id=prev_blank_identifier_id,
            num_puzzle_identifiers=num_identifiers,
            total_groups=total_groups,
            mean_puzzle_examples=mean_puzzle_examples,
            total_puzzles=total_puzzles,
            sets=prev_sets
        )

        # Checks
        assert self.config.global_batch_size % self.config.num_replicas == 0, f"Global batch size {self.config.global_batch_size} must be a multiple of the number of replicas ({self.config.num_replicas})."

        self.local_batch_size = self.config.global_batch_size // self.config.num_replicas

        # State
        self._data = None
        self._iters = 0
    def _load_metadata(self, dataset_path) -> PuzzleDatasetMetadata:
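        """Read ``dataset.json`` for the current split from ``dataset_path``."""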
        with open(os.path.join(dataset_path, self.split, "dataset.json"), "r") as f:
            return PuzzleDatasetMetadata(**json.load(f))
    def _lazy_load_dataset(self):
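        """Memory-map the example arrays of every set on first use; index arrays are kept in RAM."""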
        if self._data is not None:
            return

        field_mmap_modes = {
            "inputs": "r",
            "labels": "r",

            # Keep indices in memory
            "puzzle_identifiers": None,
            "puzzle_indices": None,
            "group_indices": None
        }

        # Load data
        self._data = {}
        for set_name in self.metadata.sets:  # Load subset
            for i, dataset_path in enumerate(self.config.dataset_paths):
                if i > 0:
                    set_name_ = set_name + str(i)
                else:
                    set_name_ = set_name

                self._data[set_name_] = {
                    field_name: np.load(os.path.join(dataset_path, self.split, f"{set_name}__{field_name}.npy"), mmap_mode=mmap_mode)
                    for field_name, mmap_mode in field_mmap_modes.items()
                }
    def _collate_batch(self, batch):
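        """Convert a numpy batch to int32 tensors, remap ignore labels, and pad to the local batch size."""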
        # Convert dtype
        batch = {k: v.astype(np.int32) for k, v in batch.items()}

        # Convert ignore label IDs
        if self.metadata.ignore_label_id is not None:
            batch["labels"][batch["labels"] == self.metadata.ignore_label_id] = IGNORE_LABEL_ID

        # Pad
        if batch["puzzle_identifiers"].size < self.local_batch_size:
            pad_size = self.local_batch_size - batch["puzzle_identifiers"].size
            pad_values = {
                "inputs": self.metadata.pad_id,
                "labels": IGNORE_LABEL_ID,
                "puzzle_identifiers": self.metadata.blank_identifier_id
            }
            batch = {k: np.pad(v, ((0, pad_size), ) + ((0, 0), ) * (v.ndim - 1), constant_values=pad_values[k]) for k, v in batch.items()}

        # To tensor
        return {k: torch.from_numpy(v) for k, v in batch.items()}
    def _iter_test(self):
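        """Iterate every example of every set exactly once, in order, sharded across ranks."""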
        for set_name, dataset in self._data.items():  # type: ignore
            total_examples = len(dataset["inputs"])

            # Load examples one by one
            start_index = 0
            while start_index < total_examples:
                # Compute indices
                end_index = min(total_examples, start_index + self.config.global_batch_size)
                local_start = start_index + self.config.rank * self.local_batch_size
                local_end = min(start_index + (self.config.rank + 1) * self.local_batch_size, end_index)

                # Get batch of examples, and also puzzle IDs
                puzzle_indices = []
                puzzle_index = np.searchsorted(dataset["puzzle_indices"], local_start, side="right") - 1
                for i in range(local_start, local_end):
                    while puzzle_index + 1 < len(dataset["puzzle_indices"]) and i >= dataset["puzzle_indices"][puzzle_index + 1]:
                        puzzle_index += 1
                    puzzle_indices.append(puzzle_index)

                batch = self._collate_batch({
                    "inputs": dataset["inputs"][local_start: local_end],
                    "labels": dataset["labels"][local_start: local_end],
                    "puzzle_identifiers": dataset["puzzle_identifiers"][puzzle_indices]
                })

                yield set_name, batch, end_index - start_index

                # Advance to next batch
                start_index += self.config.global_batch_size
    def _iter_train(self):
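        """Yield shuffled training batches; groups are re-permuted each iteration and one puzzle is sampled per group."""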
        for set_name, dataset in self._data.items():  # type: ignore
            # Increase iteration count
            self._iters += 1

            # Randomly shuffle groups
            rng = np.random.Generator(np.random.Philox(seed=self.config.seed + self._iters))
            group_order = np.concatenate([rng.permutation(dataset["group_indices"].size - 1) for _i in range(self.config.epochs_per_iter)])

            start_index = 0
            while start_index < group_order.size:
                start_index, batch_indices, batch_puzzle_indices = _sample_batch(
                    rng,
                    group_order=group_order,
                    puzzle_indices=dataset["puzzle_indices"],
                    group_indices=dataset["group_indices"],
                    start_index=start_index,
                    global_batch_size=self.config.global_batch_size,
                )

                # Select current rank and collate
                global_effective_batch_size = batch_puzzle_indices.size  # Global effective batch size, excluding pads

                # Drop last (incomplete) batch
                if global_effective_batch_size < self.config.global_batch_size:
                    break

                batch_indices = batch_indices[self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size]
                batch_puzzle_indices = batch_puzzle_indices[self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size]

                batch = self._collate_batch({
                    "inputs": dataset["inputs"][batch_indices],
                    "labels": dataset["labels"][batch_indices],
                    "puzzle_identifiers": dataset["puzzle_identifiers"][batch_puzzle_indices]
                })

                yield set_name, batch, global_effective_batch_size
    def __iter__(self):
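        """Yield ``(set_name, batch, effective_batch_size)`` tuples in test or train mode."""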
        worker_info = get_worker_info()
        assert worker_info is None or worker_info.num_workers == 1, "Multithreaded data loading is not currently supported."

        self._lazy_load_dataset()

        # Iterate using specified mode
        if self.config.test_set_mode:
            yield from self._iter_test()
        else:
            yield from self._iter_train()
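

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original module).
    # The dataset path and batch sizes below are placeholder assumptions; running this
    # requires the corresponding dataset.json and .npy files to exist for the split.
    config = PuzzleDatasetConfig(
        seed=0,
        dataset_paths=["data/example-puzzles"],  # placeholder path
        global_batch_size=32,
        test_set_mode=False,
        epochs_per_iter=1,
        rank=0,
        num_replicas=1,
    )
    dataset = PuzzleDataset(config, split="train")

    # The dataset yields pre-collated batches, so no extra batching is needed.
    for set_name, batch, effective_batch_size in dataset:
        print(set_name, {k: tuple(v.shape) for k, v in batch.items()}, effective_batch_size)
        break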