# layers.py

from typing import Tuple

import einops
import torch
from torch import nn
import torch.nn.functional as F

# try:
#     from flash_attn_interface import flash_attn_func  # type: ignore[import]
# except ImportError:
#     # Fallback to FlashAttention 2
#     from flash_attn import flash_attn_func  # type: ignore[import]
from torch.nn.functional import scaled_dot_product_attention

from models.common import trunc_normal_init_


CosSin = Tuple[torch.Tensor, torch.Tensor]


def _find_multiple(a, b):
    # Smallest multiple of b that is >= a (ceiling division, then scale back up)
    return (-(a // -b)) * b
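
# Example: _find_multiple(683, 256) == 768 (683 rounded up to the next multiple of 256),
# while exact multiples pass through unchanged, e.g. _find_multiple(512, 256) == 512.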


def rotate_half(x: torch.Tensor):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    # q, k: [bs, seq_len, num_heads, head_dim]
    # cos, sin: [seq_len, head_dim]
    orig_dtype = q.dtype
    q = q.to(cos.dtype)
    k = k.to(cos.dtype)
    q_embed = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2))
    k_embed = (k * cos.unsqueeze(-2)) + (rotate_half(k) * sin.unsqueeze(-2))
    return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
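
# Shape sketch (illustrative sizes): for q, k of shape [2, 16, 4, 64] and cos, sin of
# shape [16, 64], unsqueeze(-2) lets cos/sin broadcast over the head axis, so q_embed
# and k_embed come back with the same [2, 16, 4, 64] shape and the original dtype.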


class CastedLinear(nn.Module):
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool):
        super().__init__()
        # Truncated LeCun normal init
        self.weight = nn.Parameter(
            trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5))
        )
        self.bias = None
        if bias:
            # Zero init bias
            self.bias = nn.Parameter(torch.zeros((out_features, )))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None)
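
# Note: the parameters above are created in the module's default dtype (typically float32)
# and are cast to the activation dtype at call time, so e.g. bfloat16 inputs do not
# require storing the weights themselves in bfloat16.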


class CastedEmbedding(nn.Module):
    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 init_std: float,
                 cast_to: torch.dtype):
        super().__init__()
        self.cast_to = cast_to

        # Truncated LeCun normal init
        self.embedding_weight = nn.Parameter(
            trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.embedding(input, self.embedding_weight.to(self.cast_to))


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings, base, device=None):
        super().__init__()

        # RoPE
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
        t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)
        # Differs from the paper's interleaved layout, but the permutation is equivalent and gives the same result
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
        self.sin_cached = nn.Buffer(emb.sin(), persistent=False)

    def forward(self):
        return self.cos_cached, self.sin_cached
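
# Usage sketch (sizes are illustrative assumptions): with dim=64 and
# max_position_embeddings=16, calling the module returns a (cos, sin) pair,
# each of shape [16, 64], suitable as the `cos_sin` argument of Attention.forward below.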


class Attention(nn.Module):
    def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):
        super().__init__()

        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.output_size = head_dim * num_heads
        self.num_heads = num_heads
        self.num_key_value_heads = num_key_value_heads
        self.causal = causal

        self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False)
        self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False)

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: [bs, seq_len, hidden_size]
        batch_size, seq_len, _ = hidden_states.shape

        qkv = self.qkv_proj(hidden_states)

        # Split heads: [bs, seq_len, num_heads + 2 * num_key_value_heads, head_dim]
        qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
        query = qkv[:, :, :self.num_heads]
        key = qkv[:, :, self.num_heads: self.num_heads + self.num_key_value_heads]
        value = qkv[:, :, self.num_heads + self.num_key_value_heads:]

        # RoPE
        if cos_sin is not None:
            cos, sin = cos_sin
            query, key = apply_rotary_pos_emb(query, key, cos, sin)

        # Attention via PyTorch SDPA (the flash_attn path is commented out above).
        # SDPA expects [bs, num_heads, seq_len, head_dim], whereas flash_attn_func takes
        # [bs, seq_len, num_heads, head_dim], hence the rearranges around the call.
        # Note: as written, this assumes num_key_value_heads == num_heads (no k/v head repetition).
        query, key, value = map(lambda t: einops.rearrange(t, 'B S H D -> B H S D'), (query, key, value))
        attn_output = scaled_dot_product_attention(query=query, key=key, value=value, is_causal=self.causal)
        attn_output = einops.rearrange(attn_output, 'B H S D -> B S H D')

        attn_output = attn_output.reshape(batch_size, seq_len, self.output_size)  # type: ignore
        return self.o_proj(attn_output)


class LinearSwish(nn.Module):
    def __init__(self, hidden_size: int, reverse=False):
        super().__init__()
        self.linear = CastedLinear(hidden_size, hidden_size, bias=False)
        self.reverse = reverse

    def forward(self, x):
        if self.reverse:
            # Linear first, then Swish
            return F.silu(self.linear(x))
        else:
            # Swish first, then Linear
            return self.linear(F.silu(x))


class SwiGLU(nn.Module):
    def __init__(self, hidden_size: int, expansion: float):
        super().__init__()
        inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)

        self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
        self.down_proj = CastedLinear(inter, hidden_size, bias=False)

    def forward(self, x):
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)
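
# Sizing sketch (hidden_size=256 and expansion=4.0 are illustrative assumptions):
# inter = _find_multiple(round(4.0 * 256 * 2 / 3), 256) = _find_multiple(683, 256) = 768,
# so gate_up_proj maps 256 -> 1536 and down_proj maps 768 -> 256. The 2/3 factor is the
# usual SwiGLU sizing, keeping parameters close to a non-gated MLP with the same expansion.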


def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:
    # RMSNorm without a learnable scale, computed in float32 for numerical stability
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)

    variance = hidden_states.square().mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    return hidden_states.to(input_dtype)
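

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original module. The sizes below are
    # assumptions chosen for demonstration; running it needs the same
    # models.common.trunc_normal_init_ dependency as the rest of this file.
    bs, seq_len, hidden_size, head_dim, num_heads = 2, 16, 256, 64, 4

    rope = RotaryEmbedding(dim=head_dim, max_position_embeddings=seq_len, base=10000.0)
    attn = Attention(hidden_size, head_dim, num_heads, num_key_value_heads=num_heads, causal=False)
    mlp = SwiGLU(hidden_size, expansion=4.0)

    x = torch.randn(bs, seq_len, hidden_size)
    x = rms_norm(x + attn(rope(), x), variance_epsilon=1e-5)  # attention sublayer + post-norm
    x = rms_norm(x + mlp(x), variance_epsilon=1e-5)           # SwiGLU sublayer + post-norm
    print(x.shape)  # expected: torch.Size([2, 16, 256])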