# hrm.py

from typing import Tuple, List, Dict, Optional
from dataclasses import dataclass
import math

import torch
import torch.nn.functional as F
from torch import nn
from pydantic import BaseModel

from models.common import trunc_normal_init_
from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
from models.sparse_embedding import CastedSparseEmbedding

@dataclass
class HierarchicalReasoningModel_ACTV1InnerCarry:
    z_H: torch.Tensor
    z_L: torch.Tensor


@dataclass
class HierarchicalReasoningModel_ACTV1Carry:
    inner_carry: HierarchicalReasoningModel_ACTV1InnerCarry

    steps: torch.Tensor
    halted: torch.Tensor

    current_data: Dict[str, torch.Tensor]

class HierarchicalReasoningModel_ACTV1Config(BaseModel):
    batch_size: int
    seq_len: int
    puzzle_emb_ndim: int = 0
    num_puzzle_identifiers: int
    vocab_size: int

    H_cycles: int
    L_cycles: int

    H_layers: int
    L_layers: int

    # Transformer config
    hidden_size: int
    expansion: float
    num_heads: int
    pos_encodings: str

    rms_norm_eps: float = 1e-5
    rope_theta: float = 10000.0

    # Halting Q-learning config
    halt_max_steps: int
    halt_exploration_prob: float

    forward_dtype: str = "bfloat16"

    # Alexia: added
    mlp_t: bool = False  # replace the transformer block with an MLP over the sequence (L) dimension
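
# Example instantiation (hypothetical values chosen for illustration; they are not
# defaults taken from this file):
#
#   config = HierarchicalReasoningModel_ACTV1Config(
#       batch_size=32, seq_len=81, puzzle_emb_ndim=512, num_puzzle_identifiers=1000,
#       vocab_size=11, H_cycles=2, L_cycles=2, H_layers=4, L_layers=4,
#       hidden_size=256, expansion=4.0, num_heads=8, pos_encodings="rope",
#       halt_max_steps=8, halt_exploration_prob=0.1,
#   )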

class HierarchicalReasoningModel_ACTV1Block(nn.Module):
    def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
        super().__init__()
        self.config = config

        if self.config.mlp_t:
            self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)
            self.mlp_t = SwiGLU(
                hidden_size=self.config.seq_len + self.puzzle_emb_len,  # L
                expansion=config.expansion,
            )
        else:
            self.self_attn = Attention(
                hidden_size=config.hidden_size,
                head_dim=config.hidden_size // config.num_heads,
                num_heads=config.num_heads,
                num_key_value_heads=config.num_heads,
                causal=False
            )
            self.mlp = SwiGLU(
                hidden_size=config.hidden_size,
                expansion=config.expansion,
            )
        self.norm_eps = config.rms_norm_eps

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        # B, L, D = hidden_states.shape
        # Post Norm
        if self.config.mlp_t:
            hidden_states = hidden_states.transpose(1, 2)
            out = self.mlp_t(hidden_states)
            hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
            hidden_states = hidden_states.transpose(1, 2)
        else:
            # Self Attention
            hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
            # Fully Connected
            out = self.mlp(hidden_states)
            hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
        return hidden_states
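
# Shape sketch for the mlp_t path above (B = batch, D = hidden_size,
# L = seq_len + puzzle_emb_len): the first transpose maps (B, L, D) -> (B, D, L), the
# SwiGLU (built with hidden_size = L) therefore mixes across token positions rather
# than channels, and the second transpose restores (B, L, D).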

class HierarchicalReasoningModel_ACTV1ReasoningModule(nn.Module):
    def __init__(self, layers: List[HierarchicalReasoningModel_ACTV1Block]):
        super().__init__()
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
        # Input injection (add)
        hidden_states = hidden_states + input_injection
        # Layers
        for layer in self.layers:
            hidden_states = layer(hidden_states=hidden_states, **kwargs)
        return hidden_states

class HierarchicalReasoningModel_ACTV1_Inner(nn.Module):
    def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
        super().__init__()
        self.config = config
        self.forward_dtype = getattr(torch, self.config.forward_dtype)

        # I/O
        self.embed_scale = math.sqrt(self.config.hidden_size)
        embed_init_std = 1.0 / self.embed_scale

        self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
        self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)

        self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)  # ceil div
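        # e.g. puzzle_emb_ndim=512, hidden_size=256 gives -(512 // -256) = 2 prefix slots;
        # a non-multiple such as puzzle_emb_ndim=300 also rounds up: -(300 // -256) = 2.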
        if self.config.puzzle_emb_ndim > 0:
            # Zero init puzzle embeddings
            self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
                                                    batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)

        # LM Blocks
        if self.config.pos_encodings == "rope":
            self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
                                              max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
                                              base=self.config.rope_theta)
        elif self.config.pos_encodings == "learned":
            self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        else:
            pass

        # Reasoning Layers
        self.H_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.H_layers)])
        self.L_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])

        # Initial states
        self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
        self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)

        # Q head special init
        # Init Q to (almost) zero for faster learning during bootstrapping
        with torch.no_grad():
            self.q_head.weight.zero_()
            self.q_head.bias.fill_(-5)  # type: ignore
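        # With zero weights and bias = -5, both Q-logits start at -5, and
        # sigmoid(-5) ≈ 0.0067, so the predicted halt/continue values begin near zero.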

    def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
        # Token embedding
        embedding = self.embed_tokens(input.to(torch.int32))

        # Puzzle embeddings
        if self.config.puzzle_emb_ndim > 0:
            puzzle_embedding = self.puzzle_emb(puzzle_identifiers)

            pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
            if pad_count > 0:
                puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))

            embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)

        # Position embeddings
        if self.config.pos_encodings == "learned":
            # scale by 1/sqrt(2) to maintain forward variance
            embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
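            # Why 1/sqrt(2): for roughly independent zero-mean terms,
            # Var(tok + pos) = Var(tok) + Var(pos), about twice either variance,
            # and 0.707106781 ≈ 1/sqrt(2) rescales the sum back to unit variance.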

        # Scale
        return self.embed_scale * embedding

    def empty_carry(self, batch_size: int):
        return HierarchicalReasoningModel_ACTV1InnerCarry(
            z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
            z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
        )

    def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalReasoningModel_ACTV1InnerCarry):
        return HierarchicalReasoningModel_ACTV1InnerCarry(
            z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
            z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
        )
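    # Broadcasting note for reset_carry: H_init / L_init have shape (hidden_size,) and
    # reset_flag.view(-1, 1, 1) has shape (batch, 1, 1), so torch.where broadcasts both
    # against (batch, seq, hidden_size); every position of a reset sequence gets the
    # same initial state vector.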

    def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        seq_info = dict(
            cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
        )

        # Input encoding
        input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])

        # Forward iterations
        with torch.no_grad():
            z_H, z_L = carry.z_H, carry.z_L

            for _H_step in range(self.config.H_cycles):
                for _L_step in range(self.config.L_cycles):
                    if not ((_H_step == self.config.H_cycles - 1) and (_L_step == self.config.L_cycles - 1)):
                        z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)

                if not (_H_step == self.config.H_cycles - 1):
                    z_H = self.H_level(z_H, z_L, **seq_info)

        assert not z_H.requires_grad and not z_L.requires_grad
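        # Only the final L/H update below receives gradient; every earlier cycle ran
        # under torch.no_grad() above, a one-step gradient approximation that keeps
        # activation memory constant in H_cycles * L_cycles.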
        # 1-step grad
        z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
        z_H = self.H_level(z_H, z_L, **seq_info)

        # LM Outputs
        new_carry = HierarchicalReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())  # New carry no grad
        output = self.lm_head(z_H)[:, self.puzzle_emb_len:]

        # Q head
        q_logits = self.q_head(z_H[:, 0]).to(torch.float32)

        return new_carry, output, (q_logits[..., 0], q_logits[..., 1])

class HierarchicalReasoningModel_ACTV1(nn.Module):
    """ACT wrapper."""

    def __init__(self, config_dict: dict):
        super().__init__()
        self.config = HierarchicalReasoningModel_ACTV1Config(**config_dict)
        self.inner = HierarchicalReasoningModel_ACTV1_Inner(self.config)

    @property
    def puzzle_emb(self):
        return self.inner.puzzle_emb

    def initial_carry(self, batch: Dict[str, torch.Tensor]):
        batch_size = batch["inputs"].shape[0]

        return HierarchicalReasoningModel_ACTV1Carry(
            inner_carry=self.inner.empty_carry(batch_size),  # Empty is expected; it will be reset on the first pass, since all sequences start halted.
            steps=torch.zeros((batch_size, ), dtype=torch.int32),
            halted=torch.ones((batch_size, ), dtype=torch.bool),  # Default to halted
            current_data={k: torch.empty_like(v) for k, v in batch.items()}
        )

    def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
        # Update data, carry (removing halted sequences)
        new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)

        new_steps = torch.where(carry.halted, 0, carry.steps)

        new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}

        # Forward inner model
        new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)

        outputs = {
            "logits": logits,
            "q_halt_logits": q_halt_logits,
            "q_continue_logits": q_continue_logits
        }

        with torch.no_grad():
            # Step
            new_steps = new_steps + 1
            is_last_step = new_steps >= self.config.halt_max_steps

            halted = is_last_step

            # If training, and ACT is enabled
            if self.training and (self.config.halt_max_steps > 1):
                # Halt signal
                # NOTE: During evaluation we always run the maximum number of steps,
                # to guarantee identical halting behaviour within a batch (for batching purposes).
                halted = halted | (q_halt_logits > q_continue_logits)

                # Exploration
                min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
                halted = halted & (new_steps >= min_halt_steps)

                # Compute target Q
                # NOTE: No replay buffer or target network is used for the target Q-value;
                # since the batch size is large, there are many parallel environments.
                # Similar in concept to PQN: https://arxiv.org/abs/2407.04811
                next_q_halt_logits, next_q_continue_logits = self.inner(new_inner_carry, new_current_data)[-1]

                outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))

        return HierarchicalReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
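
# Usage sketch (hypothetical driver loop; the variable names are illustrative, and the
# batch is assumed to be a dict of tensors including "inputs" and "puzzle_identifiers",
# as read by the inner model above):
#
#   model = HierarchicalReasoningModel_ACTV1(config_dict)
#   carry = model.initial_carry(batch)
#   while True:
#       carry, outputs = model(carry, batch)   # one ACT step for the whole batch
#       if carry.halted.all():
#           break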