sparse_embedding.py

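"""Sparse embedding layer whose full weight table lives in a plain buffer and is
updated by the custom distributed SignSGD optimizer defined below
(CastedSparseEmbeddingSignSGD_Distributed), rather than by autograd on an nn.Parameter."""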
from typing import Union

import torch
from torch import nn
import torch.distributed as dist
from torch.optim.optimizer import Optimizer, ParamsT

from models.common import trunc_normal_init_


class CastedSparseEmbedding(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype):
        super().__init__()
        self.cast_to = cast_to

        # Real weights
        # Truncated LeCun normal init
        self.weights = nn.Buffer(
            trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True
        )

        # Local weights and IDs
        # Local embeddings, with gradient, not persistent
        self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False)
        # Local embedding IDs, not persistent
        self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        if not self.training:
            # Test mode, no gradient
            return self.weights[inputs].to(self.cast_to)

        # Training mode, fill puzzle embedding from weights
        with torch.no_grad():
            self.local_weights.copy_(self.weights[inputs])
            self.local_ids.copy_(inputs)

        return self.local_weights.to(self.cast_to)
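
# Note (added commentary): `weights` is deliberately a buffer rather than an nn.Parameter,
# so autograd only tracks the small per-batch `local_weights` slice. The optimizer below
# expects exactly these three tensors in one param group and tells them apart by shape and
# requires_grad: the tensor with requires_grad carries the gradient to scatter back, the
# 1-D tensor holds the gathered IDs, and the remaining 2-D tensor is the full table.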

class CastedSparseEmbeddingSignSGD_Distributed(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        world_size: int,
        lr: Union[float, torch.Tensor] = 1e-3,
        weight_decay: float = 1e-2,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            world_size=world_size
        )
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):  # type: ignore
        for group in self.param_groups:
            # Find the sparse embedding weights
            local_weights_grad = None
            local_ids = None
            weights = None

            assert len(group["params"]) == 3
            for p in group["params"]:
                if p.requires_grad:
                    local_weights_grad = p.grad
                elif p.ndim == 1:
                    local_ids = p
                elif p.ndim == 2:
                    weights = p
                else:
                    assert False

            assert local_ids is not None
            assert weights is not None

            # Apply SignSGD
            # Adam ≈ SignSGD if gradient is very sparse
            if local_weights_grad is not None:
                _sparse_emb_signsgd_dist(
                    local_weights_grad,
                    local_ids,
                    weights,

                    lr=group["lr"],
                    weight_decay=group["weight_decay"],
                    world_size=group["world_size"]
                )
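
# Added note: across `world_size` ranks, each embedding row i that appears in any rank's
# batch receives
#   w[i] <- (1 - lr * weight_decay) * w[i] - lr * sign(sum of grad[i] over ranks and duplicates)
# i.e. decoupled weight decay followed by a sign-of-summed-gradient step, applied only to
# the rows that were actually used.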

def _sparse_emb_signsgd_dist(
    local_weights_grad: torch.Tensor,
    local_ids: torch.Tensor,
    weights: torch.Tensor,
    lr: float,
    weight_decay: float,
    world_size: int
) -> None:
    N, D = local_weights_grad.shape

    # All-gather
    all_weights_grad = local_weights_grad
    all_ids = local_ids
    if world_size > 1:
        all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device)
        all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device)

        dist.all_gather_into_tensor(all_weights_grad, local_weights_grad)
        dist.all_gather_into_tensor(all_ids, local_ids)

    # Unique
    grad_ids, inv = all_ids.unique(return_inverse=True)

    grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device)
    grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad)

    # SignSGD with decoupled weight decay
    p = weights[grad_ids]
    p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr)

    # Write updated slices back
    weights[grad_ids] = p
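

if __name__ == "__main__":
    # Minimal single-process usage sketch (added example, not part of the original module).
    # It assumes models.common.trunc_normal_init_ is importable; with world_size=1 no
    # process group is needed, so the distributed all-gather path is skipped.
    emb = CastedSparseEmbedding(num_embeddings=16, embedding_dim=8, batch_size=4,
                                init_std=0.02, cast_to=torch.float32)
    opt = CastedSparseEmbeddingSignSGD_Distributed(
        [emb.local_weights, emb.local_ids, emb.weights], world_size=1, lr=1e-2
    )

    emb.train()
    ids = torch.tensor([0, 3, 3, 7], dtype=torch.int32)
    out = emb(ids)                  # copies the selected rows into local_weights
    out.square().sum().backward()   # gradient is accumulated on local_weights only
    opt.step()                      # sign-SGD update scattered back into emb.weights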