losses.py

from typing import Any, Tuple, Dict, Sequence, Optional

import torch
import torch.nn.functional as F
from torch import nn

IGNORE_LABEL_ID = -100

def s(x, epsilon=1e-30):
    # Piecewise-rational replacement for exp(): positive and monotonic,
    # but growing only linearly for x >= 0, so large logits cannot overflow.
    return torch.where(
        x < 0,
        1 / (1 - x + epsilon),
        x + 1
    )


def log_stablemax(x, dim=-1):
    # Log of a softmax-like distribution built from s(x) instead of exp(x).
    s_x = s(x)
    return torch.log(s_x / torch.sum(s_x, dim=dim, keepdim=True))
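
# Illustrative worked example (for clarity only): log_stablemax behaves like a
# log-softmax whose "exp" is the piecewise-rational s(x). For logits [0., 1., -1.]:
#   s(x)  = [1.0, 2.0, 0.5]                 (x + 1 for x >= 0, 1 / (1 - x) for x < 0)
#   probs = s(x) / sum(s(x)) = [1.0, 2.0, 0.5] / 3.5 ≈ [0.286, 0.571, 0.143]
# and log_stablemax returns the elementwise log of these probabilities.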


def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
    # Cross-entropy computed with log_stablemax, in float64 for numerical stability.
    logprobs = log_stablemax(logits.to(torch.float64), dim=-1)

    if valid_mask is None:
        valid_mask = (labels != ignore_index)
    # Replace ignored labels with a dummy class 0 so gather() stays in range.
    transformed_labels = torch.where(valid_mask, labels, 0)
    prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)

    # Per-token loss; ignored positions contribute exactly 0.
    return -torch.where(valid_mask, prediction_logprobs, 0)


def softmax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
    # `valid_mask` is accepted only for signature compatibility with
    # stablemax_cross_entropy; ignore_index already handles masking here.
    # Cast logits to f32 and flatten to [B * SeqLen, Vocab] for F.cross_entropy.
    return F.cross_entropy(
        logits.to(torch.float32).view(-1, logits.shape[-1]),
        labels.to(torch.long).view(-1),
        ignore_index=ignore_index,
        reduction="none",
    ).view(labels.shape)
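

# Minimal sanity-check sketch (illustrative; `_loss_shape_demo` is a hypothetical
# helper, not part of the training code). Both loss functions return per-token
# losses shaped like `labels`, with ignored positions contributing zero.
def _loss_shape_demo():
    logits = torch.randn(2, 4, 10)                 # [B, SeqLen, Vocab]
    labels = torch.randint(0, 10, (2, 4))
    labels[0, 0] = IGNORE_LABEL_ID                 # mask one position
    sm = softmax_cross_entropy(logits, labels, ignore_index=IGNORE_LABEL_ID)
    st = stablemax_cross_entropy(logits, labels, ignore_index=IGNORE_LABEL_ID)
    assert sm.shape == st.shape == labels.shape
    assert sm[0, 0] == 0 and st[0, 0] == 0         # ignored token adds no loss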


class ACTLossHead(nn.Module):
    def __init__(self, model: nn.Module, loss_type: str):
        super().__init__()
        self.model = model
        # Resolve the loss function by name, e.g. "stablemax_cross_entropy"
        self.loss_fn = globals()[loss_type]

    def initial_carry(self, *args, **kwargs):
        return self.model.initial_carry(*args, **kwargs)  # type: ignore

    def forward(
        self,
        return_keys: Sequence[str],
        # Model args
        **model_kwargs,
    ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
        # Run the wrapped model; logits are B x SeqLen x Vocab
        new_carry, outputs = self.model(**model_kwargs)
        labels = new_carry.current_data["labels"]

        # Metrics are computed without gradients
        with torch.no_grad():
            # Preds
            outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)

            # Correctness
            mask = (labels != IGNORE_LABEL_ID)
            loss_counts = mask.sum(-1)
            loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1)  # Avoid NaNs in division

            is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
            seq_is_correct = is_correct.sum(-1) == loss_counts

            # Metrics (halted sequences only)
            valid_metrics = new_carry.halted & (loss_counts > 0)
            metrics = {
                "count": valid_metrics.sum(),
                "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
                "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
                "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
                "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
            }

        # Losses
        # LM loss: per-token loss averaged over each sequence's valid tokens, then summed over the batch
        lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
        # Q-halt head is trained to predict whether the whole sequence is already correct
        q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")

        metrics.update({
            "lm_loss": lm_loss.detach(),
            "q_halt_loss": q_halt_loss.detach(),
        })

        # Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unnecessary
        q_continue_loss = 0
        if "target_q_continue" in outputs:
            q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
            metrics["q_continue_loss"] = q_continue_loss.detach()

        # Filter outputs for return
        detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}

        # Returns: (carry, total loss, metrics, detached outputs, all-sequences-halted flag)
        return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
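

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): ACTLossHead wraps a model that exposes an
# ACT-style carry with `current_data`, `halted` and `steps`, and that returns
# (new_carry, outputs) containing "logits" and "q_halt_logits". `_ToyCarry`
# and `_ToyModel` below are hypothetical stand-ins, assumed purely for this demo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class _ToyCarry:
        current_data: Dict[str, torch.Tensor]
        halted: torch.Tensor  # bool [B]
        steps: torch.Tensor   # long [B]

    class _ToyModel(nn.Module):
        def __init__(self, vocab_size: int = 11, seq_len: int = 8):
            super().__init__()
            self.vocab_size, self.seq_len = vocab_size, seq_len
            self.proj = nn.Linear(seq_len, seq_len * vocab_size)

        def initial_carry(self, batch):
            bsz = batch["labels"].shape[0]
            return _ToyCarry(current_data=batch,
                             halted=torch.zeros(bsz, dtype=torch.bool),
                             steps=torch.zeros(bsz, dtype=torch.long))

        def forward(self, carry, batch):
            bsz = batch["labels"].shape[0]
            logits = self.proj(batch["inputs"].float()).view(bsz, self.seq_len, self.vocab_size)
            outputs = {"logits": logits, "q_halt_logits": logits.mean(dim=(-2, -1))}
            # Halt every sequence after a single step, just for the demo.
            new_carry = _ToyCarry(current_data=batch,
                                  halted=torch.ones(bsz, dtype=torch.bool),
                                  steps=carry.steps + 1)
            return new_carry, outputs

    batch = {"inputs": torch.randint(0, 11, (4, 8)),
             "labels": torch.randint(0, 11, (4, 8))}
    head = ACTLossHead(_ToyModel(), loss_type="stablemax_cross_entropy")
    carry = head.initial_carry(batch)

    # One ACT outer step: total loss is LM loss + 0.5 * (q_halt + q_continue) losses.
    carry, loss, metrics, outs, all_halted = head(return_keys=["preds"], carry=carry, batch=batch)
    loss.backward()
    print({k: float(v) for k, v in metrics.items()}, bool(all_halted))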