from typing import Union

import torch
from torch import nn
import torch.distributed as dist
from torch.optim.optimizer import Optimizer, ParamsT

from models.common import trunc_normal_init_
class CastedSparseEmbedding(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype):
        super().__init__()
        self.cast_to = cast_to

        # Real weights
        # Truncated LeCun normal init
        self.weights = nn.Buffer(
            trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True
        )

        # Local weights and IDs
        # Local embeddings, with gradient, not persistent
        self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False)
        # Local embedding IDs, not persistent
        self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        if not self.training:
            # Test mode, no gradient
            return self.weights[inputs].to(self.cast_to)

        # Training mode, fill puzzle embedding from weights
        with torch.no_grad():
            self.local_weights.copy_(self.weights[inputs])
            self.local_ids.copy_(inputs)

        return self.local_weights.to(self.cast_to)
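
# Usage sketch (illustrative only, not part of the module; sizes are hypothetical).
# In training mode the lookup is materialized into `local_weights`, which is the
# only tensor that requires grad, so backward leaves a dense
# (batch_size, embedding_dim) gradient on `local_weights.grad` while the full
# table stays a plain buffer:
#
#   emb = CastedSparseEmbedding(num_embeddings=1000, embedding_dim=64,
#                               batch_size=32, init_std=0.02,
#                               cast_to=torch.bfloat16)
#   emb.train()
#   out = emb(torch.randint(0, 1000, (32,)))  # (32, 64) bfloat16 view of local_weights
#   out.sum().backward()                      # gradient accumulates on emb.local_weights.grad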
class CastedSparseEmbeddingSignSGD_Distributed(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        world_size: int,
        lr: Union[float, torch.Tensor] = 1e-3,
        weight_decay: float = 1e-2,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            world_size=world_size
        )
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):  # type: ignore
        for group in self.param_groups:
            # Find the sparse embedding tensors: local_weights carries the
            # gradient, local_ids is the only 1-D tensor, and the full
            # embedding table is the remaining 2-D tensor.
            local_weights_grad = None
            local_ids = None
            weights = None

            assert len(group["params"]) == 3
            for p in group["params"]:
                if p.requires_grad:
                    local_weights_grad = p.grad
                elif p.ndim == 1:
                    local_ids = p
                elif p.ndim == 2:
                    weights = p
                else:
                    assert False

            assert local_ids is not None
            assert weights is not None

            # Apply SignSGD
            # Adam ≈ SignSGD if the gradient is very sparse: when a row is touched
            # rarely, its EMA moments are dominated by the current gradient, so
            # m / sqrt(v) ≈ g / |g| = sign(g)
            if local_weights_grad is not None:
                _sparse_emb_signsgd_dist(
                    local_weights_grad,
                    local_ids,
                    weights,

                    lr=group["lr"],
                    weight_decay=group["weight_decay"],
                    world_size=group["world_size"]
                )
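
# Wiring sketch (an assumption about how the pieces fit together, not taken from
# this file): the optimizer expects exactly the three tensors of one
# CastedSparseEmbedding in a single param group and tells them apart by
# requires_grad / ndim, as step() does above. With `emb` as in the earlier sketch:
#
#   opt = CastedSparseEmbeddingSignSGD_Distributed(
#       [emb.local_weights, emb.local_ids, emb.weights],
#       world_size=dist.get_world_size() if dist.is_initialized() else 1,
#       lr=1e-2,
#       weight_decay=0.1,
#   )
#   loss = emb(batch_ids).sum()   # toy loss; batch_ids is hypothetical
#   loss.backward()               # populates emb.local_weights.grad
#   opt.step()
#   opt.zero_grad()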
def _sparse_emb_signsgd_dist(
    local_weights_grad: torch.Tensor,
    local_ids: torch.Tensor,
    weights: torch.Tensor,

    lr: float,
    weight_decay: float,
    world_size: int
) -> None:
    N, D = local_weights_grad.shape

    # All-gather local gradients and IDs from every rank
    all_weights_grad = local_weights_grad
    all_ids = local_ids
    if world_size > 1:
        all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device)
        all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device)

        dist.all_gather_into_tensor(all_weights_grad, local_weights_grad)
        dist.all_gather_into_tensor(all_ids, local_ids)

    # Deduplicate IDs and sum the gradients of rows that share an ID
    grad_ids, inv = all_ids.unique(return_inverse=True)

    grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device)
    grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad)

    # SignSGD with decoupled weight decay
    p = weights[grad_ids]

    p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr)

    # Write updated slices back
    weights[grad_ids] = p
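
# Worked example of the aggregation above (hypothetical numbers): if the gathered
# ids are [3, 7, 3], unique() returns grad_ids=[3, 7] with inv=[0, 1, 0];
# scatter_add_ then sums gradient rows 0 and 2 into grad[0] (id 3) and puts row 1
# into grad[1] (id 7). Only rows 3 and 7 of `weights` are decayed and stepped by
# ±lr; every other row is left untouched.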