pretrain.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654
  1. from typing import Optional, Any, Sequence, List
  2. from dataclasses import dataclass
  3. import os
  4. import math
  5. import yaml
  6. import shutil
  7. import copy
  8. import torch
  9. import torch.distributed as dist
  10. from torch import nn
  11. from torch.utils.data import DataLoader
  12. import tqdm
  13. import wandb
  14. import coolname
  15. import hydra
  16. import pydantic
  17. from omegaconf import DictConfig
  18. from adam_atan2 import AdamATan2
  19. from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata
  20. from utils.functions import load_model_class, get_model_source_path
  21. from models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed
  22. from models.ema import EMAHelper
  23. class LossConfig(pydantic.BaseModel):
  24. model_config = pydantic.ConfigDict(extra='allow')
  25. name: str
  26. class ArchConfig(pydantic.BaseModel):
  27. model_config = pydantic.ConfigDict(extra='allow')
  28. name: str
  29. loss: LossConfig
  30. class EvaluatorConfig(pydantic.BaseModel):
  31. model_config = pydantic.ConfigDict(extra="allow")
  32. name: str
  33. class PretrainConfig(pydantic.BaseModel):
  34. # Config
  35. arch: ArchConfig
  36. # Data
  37. data_paths: List[str]
  38. data_paths_test: List[str] = []
  39. # Evaluators
  40. evaluators: List[EvaluatorConfig] = []
  41. # Hyperparams
  42. global_batch_size: int
  43. epochs: int
  44. lr: float
  45. lr_min_ratio: float
  46. lr_warmup_steps: int
  47. weight_decay: float
  48. beta1: float
  49. beta2: float
  50. # Puzzle embedding
  51. puzzle_emb_lr: float
  52. puzzle_emb_weight_decay: float
  53. # Names
  54. project_name: Optional[str] = None
  55. run_name: Optional[str] = None
  56. load_checkpoint: Optional[str] = None
  57. checkpoint_path: Optional[str] = None
  58. # Extras
  59. seed: int = 0
  60. checkpoint_every_eval: bool = False
  61. eval_interval: Optional[int] = None
  62. min_eval_interval: Optional[int] = 0 # when to start eval
  63. eval_save_outputs: List[str] = []
  64. ema: bool = False # use Exponential-Moving-Average
  65. ema_rate: float = 0.999 # EMA-rate
  66. freeze_weights: bool = False # If True, freeze weights and only learn the embeddings
  67. @dataclass
  68. class TrainState:
  69. model: nn.Module
  70. optimizers: Sequence[torch.optim.Optimizer]
  71. optimizer_lrs: Sequence[float]
  72. carry: Any
  73. step: int
  74. total_steps: int
  75. def create_dataloader(config: PretrainConfig, split: str, rank: int, world_size: int, **kwargs):
  76. dataset = PuzzleDataset(PuzzleDatasetConfig(
  77. seed=config.seed,
  78. dataset_paths=config.data_paths_test if len(config.data_paths_test)>0 and split=="test" else config.data_paths,
  79. rank=rank,
  80. num_replicas=world_size,
  81. **kwargs
  82. ), split=split)
  83. dataloader = DataLoader(
  84. dataset,
  85. batch_size=None,
  86. num_workers=1,
  87. prefetch_factor=8,
  88. pin_memory=True,
  89. persistent_workers=True
  90. )
  91. return dataloader, dataset.metadata
  92. def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, rank: int, world_size: int):
  93. model_cfg = dict(
  94. **config.arch.__pydantic_extra__, # type: ignore
  95. batch_size=config.global_batch_size // world_size,
  96. vocab_size=train_metadata.vocab_size,
  97. seq_len=train_metadata.seq_len,
  98. num_puzzle_identifiers=train_metadata.num_puzzle_identifiers,
  99. causal=False # Non-autoregressive
  100. )
  101. # Instantiate model with loss head
  102. model_cls = load_model_class(config.arch.name)
  103. loss_head_cls = load_model_class(config.arch.loss.name)
  104. with torch.device("cuda"):
  105. model: nn.Module = model_cls(model_cfg)
  106. print(model)
  107. model = loss_head_cls(model, **config.arch.loss.__pydantic_extra__) # type: ignore
  108. if "DISABLE_COMPILE" not in os.environ:
  109. model = torch.compile(model) # type: ignore
  110. # Load checkpoint
  111. if rank == 0:
  112. load_checkpoint(model, config)
  113. # Broadcast parameters from rank 0
  114. if world_size > 1:
  115. with torch.no_grad():
  116. for param in list(model.parameters()) + list(model.buffers()):
  117. dist.broadcast(param, src=0)
  118. # Optimizers and lr
  119. if config.arch.puzzle_emb_ndim == 0:
  120. optimizers = [
  121. AdamATan2(
  122. model.parameters(),
  123. lr=0, # Needs to be set by scheduler
  124. weight_decay=config.weight_decay,
  125. betas=(config.beta1, config.beta2)
  126. )
  127. ]
  128. optimizer_lrs = [
  129. config.lr
  130. ]
  131. elif config.freeze_weights:
  132. optimizers = [
  133. CastedSparseEmbeddingSignSGD_Distributed(
  134. model.model.puzzle_emb.buffers(), # type: ignore
  135. lr=0, # Needs to be set by scheduler
  136. weight_decay=config.puzzle_emb_weight_decay,
  137. world_size=world_size
  138. )
  139. ]
  140. optimizer_lrs = [
  141. config.puzzle_emb_lr
  142. ]
  143. else:
  144. optimizers = [
  145. CastedSparseEmbeddingSignSGD_Distributed(
  146. model.model.puzzle_emb.buffers(), # type: ignore
  147. lr=0, # Needs to be set by scheduler
  148. weight_decay=config.puzzle_emb_weight_decay,
  149. world_size=world_size
  150. ),
  151. AdamATan2(
  152. model.parameters(),
  153. lr=0, # Needs to be set by scheduler
  154. weight_decay=config.weight_decay,
  155. betas=(config.beta1, config.beta2)
  156. )
  157. ]
  158. optimizer_lrs = [
  159. config.puzzle_emb_lr,
  160. config.lr
  161. ]
  162. return model, optimizers, optimizer_lrs
  163. def mix_weights_direct(device, alpha, net, nets):
  164. sd = []
  165. for i in range(len(nets)):
  166. sd += [nets[i].state_dict()]
  167. sd_alpha = {}
  168. for k in sd[0].keys():
  169. comb_net = alpha[0]*sd[0][k].to(device)
  170. for i in range(1,len(nets)):
  171. comb_net += alpha[i]*sd[i][k].to(device)
  172. sd_alpha[k] = comb_net
  173. net.load_state_dict(sd_alpha)
  174. return net
  175. def cosine_schedule_with_warmup_lr_lambda(
  176. current_step: int, *, base_lr: float, num_warmup_steps: int, num_training_steps: int, min_ratio: float = 0.0, num_cycles: float = 0.5
  177. ):
  178. if current_step < num_warmup_steps:
  179. return base_lr * float(current_step) / float(max(1, num_warmup_steps))
  180. progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
  181. return base_lr * (min_ratio + max(0.0, (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))))
  182. def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, rank: int, world_size: int):
  183. # Estimated total training steps
  184. total_steps = int(config.epochs * train_metadata.total_groups * train_metadata.mean_puzzle_examples / config.global_batch_size)
  185. # Model
  186. model, optimizers, optimizer_lrs = create_model(config, train_metadata, rank=rank, world_size=world_size)
  187. return TrainState(
  188. step=0,
  189. total_steps=total_steps,
  190. model=model,
  191. optimizers=optimizers,
  192. optimizer_lrs=optimizer_lrs,
  193. carry=None
  194. )
  195. def save_train_state(config: PretrainConfig, train_state: TrainState):
  196. # FIXME: Only saved model.
  197. if config.checkpoint_path is None:
  198. return
  199. os.makedirs(config.checkpoint_path, exist_ok=True)
  200. torch.save(train_state.model.state_dict(), os.path.join(config.checkpoint_path, f"step_{train_state.step}"))
  201. def load_checkpoint(model: nn.Module, config: PretrainConfig):
  202. if config.load_checkpoint is not None:
  203. print(f"Loading checkpoint {config.load_checkpoint}")
  204. # Load state dict
  205. state_dict = torch.load(config.load_checkpoint, map_location="cuda")
  206. # Resize and reset puzzle emb if needed
  207. puzzle_emb_name = "_orig_mod.model.inner.puzzle_emb.weights"
  208. expected_shape: torch.Size = model.model.puzzle_emb.weights.shape # type: ignore
  209. if puzzle_emb_name in state_dict:
  210. puzzle_emb = state_dict[puzzle_emb_name]
  211. if puzzle_emb.shape != expected_shape:
  212. print(f"Resetting puzzle embedding as shape is different. Found {puzzle_emb.shape}, Expected {expected_shape}")
  213. # Re-initialize using mean
  214. state_dict[puzzle_emb_name] = (
  215. torch.mean(puzzle_emb, dim=0, keepdim=True).expand(expected_shape).contiguous()
  216. )
  217. model.load_state_dict(state_dict, assign=True)
  218. def compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState):
  219. return cosine_schedule_with_warmup_lr_lambda(
  220. current_step=train_state.step,
  221. base_lr=base_lr,
  222. num_warmup_steps=round(config.lr_warmup_steps),
  223. num_training_steps=train_state.total_steps,
  224. min_ratio=config.lr_min_ratio
  225. )
  226. def create_evaluators(config: PretrainConfig, eval_metadata: PuzzleDatasetMetadata) -> List[Any]:
  227. data_paths =config.data_paths_test if len(config.data_paths_test)>0 else config.data_paths
  228. # Initialize evaluators
  229. evaluators = []
  230. for cfg in config.evaluators:
  231. for data_path in data_paths:
  232. cls = load_model_class(cfg.name, "evaluators.")(
  233. data_path=data_path, eval_metadata=eval_metadata, **cfg.__pydantic_extra__
  234. ) # type: ignore
  235. evaluators.append(cls)
  236. return evaluators
  237. def train_batch(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int):
  238. train_state.step += 1
  239. if train_state.step > train_state.total_steps: # At most train_total_steps
  240. return
  241. # To device
  242. batch = {k: v.cuda() for k, v in batch.items()}
  243. # Init carry if it is None
  244. if train_state.carry is None:
  245. with torch.device("cuda"):
  246. train_state.carry = train_state.model.initial_carry(batch) # type: ignore
  247. # Forward
  248. train_state.carry, loss, metrics, _, _ = train_state.model(carry=train_state.carry, batch=batch, return_keys=[])
  249. ((1 / global_batch_size) * loss).backward()
  250. # Allreduce
  251. if world_size > 1:
  252. for param in train_state.model.parameters():
  253. if param.grad is not None:
  254. dist.all_reduce(param.grad)
  255. # Apply optimizer
  256. lr_this_step = None
  257. for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs):
  258. lr_this_step = compute_lr(base_lr, config, train_state)
  259. for param_group in optim.param_groups:
  260. param_group['lr'] = lr_this_step
  261. optim.step()
  262. optim.zero_grad()
  263. # Reduce metrics
  264. if len(metrics):
  265. assert not any(v.requires_grad for v in metrics.values())
  266. metric_keys = list(sorted(metrics.keys())) # Sort keys to guarantee all processes use the same order.
  267. # Reduce and reconstruct
  268. metric_values = torch.stack([metrics[k] for k in metric_keys])
  269. if world_size > 1:
  270. dist.reduce(metric_values, dst=0)
  271. if rank == 0:
  272. metric_values = metric_values.cpu().numpy()
  273. reduced_metrics = {k: metric_values[i] for i, k in enumerate(metric_keys)}
  274. # Postprocess
  275. count = max(reduced_metrics["count"], 1) # Avoid NaNs
  276. reduced_metrics = {f"train/{k}": v / (global_batch_size if k.endswith("loss") else count) for k, v in reduced_metrics.items()}
  277. reduced_metrics["train/lr"] = lr_this_step
  278. return reduced_metrics
  279. def evaluate(
  280. config: PretrainConfig,
  281. train_state: TrainState,
  282. eval_loader: torch.utils.data.DataLoader,
  283. eval_metadata: PuzzleDatasetMetadata,
  284. evaluators: List[Any],
  285. rank: int,
  286. world_size: int,
  287. cpu_group: Optional[dist.ProcessGroup],
  288. ):
  289. reduced_metrics = None
  290. with torch.inference_mode():
  291. return_keys = set(config.eval_save_outputs)
  292. for evaluator in evaluators:
  293. evaluator.begin_eval()
  294. return_keys.update(evaluator.required_outputs)
  295. # Run evaluation
  296. set_ids = {k: idx for idx, k in enumerate(eval_metadata.sets)}
  297. save_preds = {}
  298. metric_keys = []
  299. metric_values = None
  300. carry = None
  301. processed_batches = 0
  302. for set_name, batch, global_batch_size in eval_loader:
  303. processed_batches += 1
  304. if rank == 0:
  305. print(f"Processing batch {processed_batches}: {set_name}")
  306. # To device
  307. batch = {k: v.cuda() for k, v in batch.items()}
  308. with torch.device("cuda"):
  309. carry = train_state.model.initial_carry(batch) # type: ignore
  310. # Forward
  311. inference_steps = 0
  312. while True:
  313. carry, loss, metrics, preds, all_finish = train_state.model(
  314. carry=carry, batch=batch, return_keys=return_keys
  315. )
  316. inference_steps += 1
  317. if all_finish:
  318. break
  319. if rank == 0:
  320. print(f" Completed inference in {inference_steps} steps")
  321. for collection in (batch, preds):
  322. for k, v in collection.items():
  323. if k in config.eval_save_outputs:
  324. save_preds.setdefault(k, [])
  325. save_preds[k].append(v.cpu()) # Move to CPU for saving GPU memory
  326. for evaluator in evaluators:
  327. evaluator.update_batch(batch, preds)
  328. del carry, loss, preds, batch, all_finish
  329. # Aggregate metrics
  330. set_id = set_ids[set_name]
  331. if metric_values is None:
  332. metric_keys = list(
  333. sorted(metrics.keys())
  334. ) # Sort keys to guarantee all processes use the same order.
  335. metric_values = torch.zeros(
  336. (len(set_ids), len(metrics.values())), dtype=torch.float32, device="cuda"
  337. )
  338. metric_values[set_id] += torch.stack([metrics[k] for k in metric_keys])
  339. del metrics
  340. # concatenate save preds
  341. save_preds = {k: torch.cat(v, dim=0) for k, v in save_preds.items()}
  342. # Save preds
  343. if config.checkpoint_path is not None and len(save_preds):
  344. # Each rank save predictions independently
  345. os.makedirs(os.path.dirname(config.checkpoint_path), exist_ok=True)
  346. torch.save(
  347. save_preds, os.path.join(config.checkpoint_path, f"step_{train_state.step}_all_preds.{rank}")
  348. )
  349. del save_preds
  350. # Reduce to rank 0
  351. if metric_values is not None:
  352. if world_size > 1:
  353. dist.reduce(metric_values, dst=0)
  354. if rank == 0:
  355. reduced_metrics = metric_values.cpu().numpy()
  356. reduced_metrics = {
  357. set_name: {
  358. metric_name: reduced_metrics[set_id, metric_id]
  359. for metric_id, metric_name in enumerate(metric_keys)
  360. }
  361. for set_id, set_name in enumerate(set_ids)
  362. }
  363. # Postprocess
  364. for set_name, m in reduced_metrics.items():
  365. count = m.pop("count")
  366. reduced_metrics[set_name] = {k: v / count for k, v in m.items()}
  367. # Run evaluators
  368. if rank == 0:
  369. print(f"\nRunning {len(evaluators)} evaluator(s)...")
  370. for i, evaluator in enumerate(evaluators):
  371. if rank == 0:
  372. print(f"Running evaluator {i+1}/{len(evaluators)}: {evaluator.__class__.__name__}")
  373. # Path for saving
  374. evaluator_save_path = None
  375. if config.checkpoint_path is not None:
  376. evaluator_save_path = os.path.join(
  377. config.checkpoint_path,
  378. f"evaluator_{evaluator.__class__.__name__}_step_{train_state.step}",
  379. )
  380. os.makedirs(evaluator_save_path, exist_ok=True)
  381. # Run and log
  382. metrics = evaluator.result(evaluator_save_path, rank=rank, world_size=world_size, group=cpu_group)
  383. if rank == 0 and metrics is not None:
  384. if reduced_metrics is None:
  385. reduced_metrics = {}
  386. reduced_metrics.update(metrics)
  387. print(f" Completed {evaluator.__class__.__name__}")
  388. if rank == 0:
  389. print("All evaluators completed!")
  390. return reduced_metrics
  391. def save_code_and_config(config: PretrainConfig):
  392. if config.checkpoint_path is None or wandb.run is None:
  393. return
  394. os.makedirs(config.checkpoint_path, exist_ok=True)
  395. # Copy code
  396. code_list = [
  397. get_model_source_path(config.arch.name),
  398. get_model_source_path(config.arch.loss.name)
  399. ]
  400. for code_file in code_list:
  401. if code_file is not None:
  402. code_name = os.path.basename(code_file)
  403. shutil.copy(code_file, os.path.join(config.checkpoint_path, code_name))
  404. # Dump config as yaml
  405. config_file = os.path.join(config.checkpoint_path, "all_config.yaml")
  406. with open(config_file, "wt") as f:
  407. yaml.dump(config.model_dump(), f)
  408. # Log code
  409. wandb.run.log_code(config.checkpoint_path)
  410. def load_synced_config(hydra_config: DictConfig, rank: int, world_size: int) -> PretrainConfig:
  411. objects = [None]
  412. if rank == 0:
  413. config = PretrainConfig(**hydra_config) # type: ignore
  414. # Naming
  415. if config.project_name is None:
  416. config.project_name = f"{os.path.basename(config.data_paths[0]).capitalize()}-ACT-torch"
  417. if config.run_name is None:
  418. config.run_name = f"{config.arch.name.split('@')[-1]} {coolname.generate_slug(2)}"
  419. if config.checkpoint_path is None:
  420. config.checkpoint_path = os.path.join("checkpoints", config.project_name, config.run_name)
  421. objects = [config]
  422. if world_size > 1:
  423. dist.broadcast_object_list(objects, src=0)
  424. return objects[0] # type: ignore
  425. @hydra.main(config_path="config", config_name="cfg_pretrain", version_base=None)
  426. def launch(hydra_config: DictConfig):
  427. RANK = 0
  428. WORLD_SIZE = 1
  429. CPU_PROCESS_GROUP = None
  430. # Initialize distributed training if in distributed environment (e.g. torchrun)
  431. if "LOCAL_RANK" in os.environ:
  432. # Initialize distributed, default device and dtype
  433. dist.init_process_group(backend="nccl")
  434. RANK = dist.get_rank()
  435. WORLD_SIZE = dist.get_world_size()
  436. torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
  437. # CPU GLOO process group
  438. CPU_PROCESS_GROUP = dist.new_group(backend="gloo")
  439. assert (
  440. dist.get_rank(CPU_PROCESS_GROUP) == RANK and dist.get_world_size(CPU_PROCESS_GROUP) == WORLD_SIZE
  441. )
  442. # Load sync'ed config
  443. config = load_synced_config(hydra_config, rank=RANK, world_size=WORLD_SIZE)
  444. # Seed RNGs to ensure consistency
  445. torch.random.manual_seed(config.seed + RANK)
  446. # Dataset
  447. train_epochs_per_iter = config.eval_interval if config.eval_interval is not None else config.epochs
  448. total_iters = config.epochs // train_epochs_per_iter
  449. assert config.epochs % train_epochs_per_iter == 0, "Eval interval must be a divisor of total epochs."
  450. train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=train_epochs_per_iter, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
  451. try:
  452. eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
  453. except:
  454. print("NO EVAL DATA FOUND")
  455. eval_loader = eval_metadata = None
  456. try:
  457. evaluators = create_evaluators(config, eval_metadata)
  458. except:
  459. print("No evaluator found")
  460. evaluators = []
  461. # Train state
  462. train_state = init_train_state(config, train_metadata, rank=RANK, world_size=WORLD_SIZE)
  463. # Progress bar and logger
  464. progress_bar = None
  465. ema_helper = None
  466. if RANK == 0:
  467. progress_bar = tqdm.tqdm(total=train_state.total_steps)
  468. wandb.init(project=config.project_name, name=config.run_name, config=config.model_dump(), settings=wandb.Settings(_disable_stats=True)) # type: ignore
  469. wandb.log({"num_params": sum(x.numel() for x in train_state.model.parameters())}, step=0)
  470. save_code_and_config(config)
  471. if config.ema:
  472. print('Setup EMA')
  473. ema_helper = EMAHelper(mu=config.ema_rate)
  474. ema_helper.register(train_state.model)
  475. # Training Loop
  476. for _iter_id in range(total_iters):
  477. print (f"[Rank {RANK}, World Size {WORLD_SIZE}]: Epoch {_iter_id * train_epochs_per_iter}")
  478. ############ Train Iter
  479. if RANK == 0:
  480. print("TRAIN")
  481. train_state.model.train()
  482. for set_name, batch, global_batch_size in train_loader:
  483. metrics = train_batch(config, train_state, batch, global_batch_size, rank=RANK, world_size=WORLD_SIZE)
  484. if RANK == 0 and metrics is not None:
  485. wandb.log(metrics, step=train_state.step)
  486. progress_bar.update(train_state.step - progress_bar.n) # type: ignore
  487. if config.ema:
  488. ema_helper.update(train_state.model)
  489. if _iter_id >= config.min_eval_interval:
  490. ############ Evaluation
  491. if RANK == 0:
  492. print("EVALUATE")
  493. if config.ema:
  494. print("SWITCH TO EMA")
  495. train_state_eval = copy.deepcopy(train_state)
  496. train_state_eval.model = ema_helper.ema_copy(train_state_eval.model)
  497. else:
  498. train_state_eval = train_state
  499. train_state_eval.model.eval()
  500. metrics = evaluate(config,
  501. train_state_eval,
  502. eval_loader,
  503. eval_metadata,
  504. evaluators,
  505. rank=RANK,
  506. world_size=WORLD_SIZE,
  507. cpu_group=CPU_PROCESS_GROUP)
  508. if RANK == 0 and metrics is not None:
  509. wandb.log(metrics, step=train_state.step)
  510. ############ Checkpointing
  511. if RANK == 0:
  512. print("SAVE CHECKPOINT")
  513. if RANK == 0 and (config.checkpoint_every_eval or (_iter_id == total_iters - 1)):
  514. save_train_state(config, train_state_eval)
  515. if config.ema:
  516. del train_state_eval
  517. # finalize
  518. if dist.is_initialized():
  519. dist.destroy_process_group()
  520. wandb.finish()
  521. if __name__ == "__main__":
  522. launch()