transformers_baseline.py

  1. """
  2. HRM ACT V2: Transformer Baseline for Architecture Ablation
  3. This is an architecture ablation of the Hierarchical Reasoning Model (HRM).
  4. Key changes from V1:
  5. 1. REMOVED hierarchical split (no separate H and L levels)
  6. 2. REMOVED inner cycles (no H_cycles/L_cycles loops within reasoning)
  7. 3. KEPT ACT outer loop structure intact
  8. 4. KEPT all data preprocessing, embeddings, and evaluation infrastructure
  9. Architecture: Single-level transformer that processes the full 30x30 grid as a
  10. 900-token sequence, with the same positional encodings and sparse embeddings as V1.
  11. """
from typing import Tuple, List, Dict
from dataclasses import dataclass
import math

import torch
import torch.nn.functional as F
from torch import nn
from pydantic import BaseModel

from models.common import trunc_normal_init_
from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
from models.sparse_embedding import CastedSparseEmbedding


@dataclass
class Model_ACTV2InnerCarry:
    z_H: torch.Tensor


@dataclass
class Model_ACTV2Carry:
    inner_carry: Model_ACTV2InnerCarry
    steps: torch.Tensor
    halted: torch.Tensor
    current_data: Dict[str, torch.Tensor]


class Model_ACTV2Config(BaseModel):
    batch_size: int
    seq_len: int
    puzzle_emb_ndim: int = 0
    num_puzzle_identifiers: int
    vocab_size: int

    H_cycles: int  # Unused in this ablation (inner cycles removed); retained for config compatibility
    H_layers: int

    # Transformer config
    hidden_size: int
    expansion: float
    num_heads: int
    pos_encodings: str

    rms_norm_eps: float = 1e-5
    rope_theta: float = 10000.0

    # Halting Q-learning config
    halt_max_steps: int
    halt_exploration_prob: float
    act_enabled: bool = True  # If False, always run halt_max_steps (no early stopping during training)
    act_inference: bool = False  # If True, use adaptive computation during inference

    forward_dtype: str = "bfloat16"
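
# A minimal config sketch (values are purely illustrative, not taken from any
# experiment): a 30x30 grid flattened to a 900-token sequence, as described in
# the module docstring.
#
#   config = Model_ACTV2Config(
#       batch_size=32, seq_len=900, puzzle_emb_ndim=512,
#       num_puzzle_identifiers=1000, vocab_size=12,
#       H_cycles=1, H_layers=8,
#       hidden_size=512, expansion=4.0, num_heads=8, pos_encodings="rope",
#       halt_max_steps=16, halt_exploration_prob=0.1,
#   )
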
class Model_ACTV2Block(nn.Module):
    def __init__(self, config: Model_ACTV2Config) -> None:
        super().__init__()
        self.self_attn = Attention(
            hidden_size=config.hidden_size,
            head_dim=config.hidden_size // config.num_heads,
            num_heads=config.num_heads,
            num_key_value_heads=config.num_heads,
            causal=False,
        )
        self.mlp = SwiGLU(
            hidden_size=config.hidden_size,
            expansion=config.expansion,
        )
        self.norm_eps = config.rms_norm_eps

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        # Post-norm: each residual sum is normalized after the sublayer
        # Self attention
        hidden_states = rms_norm(
            hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states),
            variance_epsilon=self.norm_eps,
        )
        # Fully connected
        hidden_states = rms_norm(hidden_states + self.mlp(hidden_states), variance_epsilon=self.norm_eps)
        return hidden_states


class Model_ACTV2ReasoningModule(nn.Module):
    def __init__(self, layers: List[Model_ACTV2Block]):
        super().__init__()
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
        # Input injection (add)
        hidden_states = hidden_states + input_injection
        # Layers
        for layer in self.layers:
            hidden_states = layer(hidden_states=hidden_states, **kwargs)
        return hidden_states


class Model_ACTV2_Inner(nn.Module):
    def __init__(self, config: Model_ACTV2Config) -> None:
        super().__init__()
        self.config = config
        self.forward_dtype = getattr(torch, self.config.forward_dtype)

        # I/O
        self.embed_scale = math.sqrt(self.config.hidden_size)
        embed_init_std = 1.0 / self.embed_scale

        self.embed_tokens = CastedEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            init_std=embed_init_std,
            cast_to=self.forward_dtype,
        )
        self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
        self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)

        self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)  # ceil division
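        # Worked example of the ceil division above (pure arithmetic): with
        # puzzle_emb_ndim=512 and hidden_size=256, -(512 // -256) = 2, so the
        # puzzle embedding occupies two prefix token slots; with
        # puzzle_emb_ndim=0 it occupies none.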
        if self.config.puzzle_emb_ndim > 0:
            # Zero-init puzzle embeddings
            self.puzzle_emb = CastedSparseEmbedding(
                self.config.num_puzzle_identifiers,
                self.config.puzzle_emb_ndim,
                batch_size=self.config.batch_size,
                init_std=0,
                cast_to=self.forward_dtype,
            )

        # LM blocks
        if self.config.pos_encodings == "rope":
            self.rotary_emb = RotaryEmbedding(
                dim=self.config.hidden_size // self.config.num_heads,
                max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
                base=self.config.rope_theta,
            )
        elif self.config.pos_encodings == "learned":
            self.embed_pos = CastedEmbedding(
                self.config.seq_len + self.puzzle_emb_len,
                self.config.hidden_size,
                init_std=embed_init_std,
                cast_to=self.forward_dtype,
            )
        else:
            raise NotImplementedError()

        # Reasoning layers (single level; the H/L split of V1 is removed)
        self.H_level = Model_ACTV2ReasoningModule(
            layers=[Model_ACTV2Block(self.config) for _i in range(self.config.H_layers)]
        )

        # Initial state
        self.H_init = nn.Buffer(
            trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1),
            persistent=True,
        )

        # Q head special init
        # Init Q to (almost) zero for faster learning during bootstrapping
        with torch.no_grad():
            self.q_head.weight.zero_()
            self.q_head.bias.fill_(-5)  # type: ignore
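        # With zero weights and bias -5, both Q logits start at -5 and
        # sigmoid(-5) ≈ 0.0067, so the initial halt/continue value estimates
        # are near zero and early halting is driven only by the max-step cap.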
    def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
        # Token embedding
        embedding = self.embed_tokens(input.to(torch.int32))

        # Puzzle embeddings, prepended as prefix tokens
        if self.config.puzzle_emb_ndim > 0:
            puzzle_embedding = self.puzzle_emb(puzzle_identifiers)

            pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
            if pad_count > 0:
                puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))

            embedding = torch.cat(
                (puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2
            )

        # Position embeddings
        if self.config.pos_encodings == "learned":
            # Scale by 1/sqrt(2) to maintain forward variance: the sum of two
            # independent unit-variance terms has variance 2.
            embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))

        # Scale
        return self.embed_scale * embedding
    def empty_carry(self, batch_size: int):
        return Model_ACTV2InnerCarry(
            z_H=torch.empty(
                batch_size,
                self.config.seq_len + self.puzzle_emb_len,
                self.config.hidden_size,
                dtype=self.forward_dtype,
            ),
        )

    def reset_carry(self, reset_flag: torch.Tensor, carry: Model_ACTV2InnerCarry):
        return Model_ACTV2InnerCarry(
            z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
        )
    def forward(
        self, carry: Model_ACTV2InnerCarry, batch: Dict[str, torch.Tensor]
    ) -> Tuple[Model_ACTV2InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        seq_info = dict(
            cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
        )

        # Input encoding
        input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])

        # 1-step gradient: only this segment's forward pass carries gradient;
        # the carry handed to the next ACT step is detached below.
        z_H = self.H_level(carry.z_H, input_embeddings, **seq_info)

        # LM outputs
        new_carry = Model_ACTV2InnerCarry(
            z_H=z_H.detach(),
        )  # New carry, no grad
        output = self.lm_head(z_H)[:, self.puzzle_emb_len :]

        # Q head (read from position 0)
        q_logits = self.q_head(z_H[:, 0]).to(torch.float32)

        return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
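    # Shape sketch for the forward pass above (B = batch size, S = seq_len,
    # P = puzzle_emb_len, D = hidden_size; read off this file, not measured):
    #   batch["inputs"]    (B, S)        -> input_embeddings (B, P + S, D)
    #   carry.z_H          (B, P + S, D) -> z_H              (B, P + S, D)
    #   output logits      (B, S, vocab_size)  # puzzle prefix stripped
    #   q_logits           (B, 2)              # (halt, continue) from position 0
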
class Model_ACTV2(nn.Module):
    """ACT wrapper."""

    def __init__(self, config_dict: dict):
        super().__init__()
        self.config = Model_ACTV2Config(**config_dict)
        self.inner = Model_ACTV2_Inner(self.config)

    @property
    def puzzle_emb(self):
        # Only defined on the inner model when puzzle_emb_ndim > 0
        return self.inner.puzzle_emb

    def initial_carry(self, batch: Dict[str, torch.Tensor]):
        batch_size = batch["inputs"].shape[0]

        return Model_ACTV2Carry(
            inner_carry=self.inner.empty_carry(
                batch_size
            ),  # Empty is expected; it will be reset in the first pass, as all sequences start halted.
            steps=torch.zeros((batch_size,), dtype=torch.int32),
            halted=torch.ones((batch_size,), dtype=torch.bool),  # Default to halted
            current_data={k: torch.empty_like(v) for k, v in batch.items()},
        )
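    # Expected batch layout, as used in this file (any extra keys are carried
    # through current_data untouched):
    #   batch["inputs"]:             (B, seq_len)  integer token ids
    #   batch["puzzle_identifiers"]: (B,)          integer puzzle ids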
    def forward(
        self,
        carry: Model_ACTV2Carry,
        batch: Dict[str, torch.Tensor],
        compute_target_q: bool = False,
    ) -> Tuple[Model_ACTV2Carry, Dict[str, torch.Tensor]]:
        # Update data and carry (halted sequences get fresh batch data and a reset carry)
        new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)

        new_steps = torch.where(carry.halted, 0, carry.steps)

        new_current_data = {
            k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v)
            for k, v in carry.current_data.items()
        }

        # Forward inner model
        new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(
            new_inner_carry, new_current_data
        )

        outputs = {"logits": logits, "q_halt_logits": q_halt_logits, "q_continue_logits": q_continue_logits}

        with torch.no_grad():
            # Step
            new_steps = new_steps + 1
            is_last_step = new_steps >= self.config.halt_max_steps

            halted = is_last_step

            # Check whether adaptive computation should be used
            use_adaptive = (self.config.halt_max_steps > 1) and (
                (self.training and self.config.act_enabled)
                or (not self.training and self.config.act_inference)
            )
            if use_adaptive:
                # Halt signal based on Q-values (but always halt at max steps)
                q_halt_signal = q_halt_logits > q_continue_logits
                halted = halted | q_halt_signal

                # Store actual steps used for logging (only during inference)
                if not self.training:
                    outputs["actual_steps"] = new_steps.float()

                # Exploration (only during training): with probability
                # halt_exploration_prob, force a sequence to run at least a
                # random number of steps in [2, halt_max_steps] so that
                # Q-learning also sees longer trajectories; for the remaining
                # sequences min_halt_steps is 0 and the mask below is a no-op.
                if self.training:
                    min_halt_steps = (
                        torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob
                    ) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)

                    halted = halted & (new_steps >= min_halt_steps)
            # Compute target Q (only during training)
            # NOTE: No replay buffer or target network is used for the target
            # Q-value. As the batch size is large, there are many parallel envs.
            # Similar concept to PQN: https://arxiv.org/abs/2407.04811
            if self.training and compute_target_q:
                next_q_halt_logits, next_q_continue_logits = self.inner(
                    new_inner_carry, new_current_data
                )[-1]

                outputs["target_q_continue"] = torch.sigmoid(
                    torch.where(
                        is_last_step,
                        next_q_halt_logits,
                        torch.maximum(next_q_halt_logits, next_q_continue_logits),
                    )
                )
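                # In equation form, with Q' the Q-logits from the extra
                # (no-grad) forward pass on the new carry:
                #   target_q_continue = sigmoid(Q'_halt)                    at the last step
                #   target_q_continue = sigmoid(max(Q'_halt, Q'_continue))  otherwise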
        return Model_ACTV2Carry(
            new_inner_carry, new_steps, halted, new_current_data
        ), outputs
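
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the actual training loop lives elsewhere in
# the codebase). It shows the intended ACT driver pattern: carry state threads
# through repeated forward calls until every sequence in the batch has halted.
#
#   model = Model_ACTV2(config_dict)
#   carry = model.initial_carry(batch)
#   while True:
#       carry, outputs = model(carry, batch, compute_target_q=model.training)
#       if carry.halted.all():
#           break
# ---------------------------------------------------------------------------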