Skip to content

TorchRunner

danling.runner.TorchRunner

Bases: BaseRunner

Set up everything for running a job with PyTorch.

PyTorch backend is the most basic and flexible distributed backend. If you wish to extend the Runner, this is the best choice.

Source code in danling/runner/torch_runner.py
Python
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
class TorchRunner(BaseRunner):
    r"""
    Set up everything for running a job with PyTorch.

    PyTorch backend is the most basic and flexible distributed backend.
    If you wish to extend the Runner, this is the best choice.
    """

    model: nn.Module
    ema: nn.Module | None = None
    criterion: nn.Module
    optimizer: optim.Optimizer
    scheduler: optim.lr_scheduler._LRScheduler

    def __post_init__(self):
        super().__post_init__()
        if self.datasets:
            self.build_dataloaders()
        # Default log_interval to roughly 10 logs per epoch (at least 1).
        # Guard on self.dataloaders: max() over an empty sequence raises ValueError.
        if self.config.get("log_interval") is None and self.dataloaders:
            self.config.log_interval = max(ceil(max(len(d) for d in self.dataloaders.values()) / 10), 1)
        self.model = self.model.to(self.device)
        if self.ema is not None:
            self.ema = self.ema.to(self.device)
        if self.distributed and not isinstance(
            self.model, (nn.parallel.DistributedDataParallel, nn.parallel.DataParallel)
        ):
            self.model = nn.parallel.DistributedDataParallel(self.model)

    def train(self, train_splits: list[str] | None = None, evaluate_splits: list[str] | None = None) -> NestedDict:
        r"""
        Perform training on `train_splits` and evaluation on `evaluate_splits` each epoch.

        Args:
            train_splits: list of split to run train.
                Defaults to `["train"]`.
            evaluate_splits: list of split to run evaluate.
                Defaults to `self.dataloaders` except for those in `train_splits`.

        Return:
            NestedDict: train results
        """

        early_stop_counter = 0
        if train_splits is None:
            train_splits = ["train"] if "train" in self.dataloaders else []
        self.train_splits = sorted(train_splits)
        if not train_splits:
            warn("No training split is found. Will only evaluate for one epoch.", stacklevel=2)
            self.epoch_end = self.epoch_begin + 1
        if evaluate_splits is None:
            evaluate_splits = [s for s in self.dataloaders if s not in train_splits]
        self.evaluate_splits = sorted(evaluate_splits)
        print(f"Begin training from {self.epoch_begin} to {self.epoch_end}")
        print(f"Training splits: {self.train_splits}")
        print(f"Evaluation splits: {self.evaluate_splits}")
        patience = self.config.get("patience", float("inf"))
        for epoch in range(self.epoch_begin, self.epoch_end):  # type: ignore
            self.epochs = epoch
            result = NestedDict()
            result.setattr("convert_mapping", True)
            for split in self.train_splits:
                result[split] = self.train_epoch(split)
            for split in self.evaluate_splits:
                result[split] = self.evaluate_epoch(split)
            self.append_result(result)
            print(self.format_epoch_result(result))
            self.save_result()
            if self.config.save_interval is not None:
                self.save_checkpoint()
            """@nni.report_intermediate_result(self.latest_score)"""
            # Reset the early-stop counter whenever a new best score is reached.
            early_stop_counter = 0 if self.is_best else early_stop_counter + 1
            if early_stop_counter > patience:
                print("early stop")
                break
        """@nni.report_final_result(self.latest_score)"""
        return self.results

    def train_epoch(self, split: str = "train") -> NestedDict:
        r"""
        Train one epoch on `split`.

        Args:
            split (str): split to run train

        Return:
            NestedDict: train result
        """

        self.mode = "train"  # type: ignore
        self.split = split
        loader = self.dataloaders[split]
        length = len(loader) - 1
        last_print_iteration = -1
        log_interval = self.config.get("log_interval", -1)
        self.meters.reset()
        if self.train_metrics is not None:
            self.metrics = self.train_metrics
        if self.metrics is not None:
            self.metrics.reset()
        batch_time = time()
        # Re-seed shuffling samplers so each epoch sees a different data order.
        if hasattr(loader.batch_sampler, "set_epoch"):
            loader.batch_sampler.set_epoch(self.epochs)
        if hasattr(loader.sampler, "set_epoch"):
            loader.sampler.set_epoch(self.epochs)

        for iteration, data in enumerate(loader):
            _, loss = self.train_step(data)

            if log_interval > 0 and (iteration > 0 and iteration % log_interval == 0 or iteration == length):
                interval = iteration - last_print_iteration
                # Compare device *type*: `torch.device("cuda", rank) == torch.device("cuda")`
                # is False, so the original equality check never triggered the sync.
                if self.device.type == "cuda":
                    torch.cuda.synchronize()
                if self.scheduler is not None:
                    self.meters.lr.update(self.scheduler.get_last_lr()[0])
                self.meters.time.update((time() - batch_time) / interval)
                batch_time = time()
                reduced_loss = self.reduce(loss).item()
                self.meters.loss.update(reduced_loss)
                self.step_log(split, iteration, length)
                last_print_iteration = iteration

        result = self.get_epoch_result()
        return result

    def train_step(self, data) -> Tuple[Any, torch.Tensor]:
        r"""
        Run a single forward/backward step on one batch.

        `data` may be a Mapping with "input"/"target" keys or a sequence
        where index 0 is the input and index 1 is the target.
        """
        with self.autocast(), self.accumulate():
            input = data["input"] if isinstance(data, Mapping) else data[0]
            target = data["target"] if isinstance(data, Mapping) else data[1]
            pred = self.model(**input) if isinstance(input, Mapping) else self.model(input)
            loss = self.criterion(pred, target)
            if self.metrics is not None:
                self.metrics.update(pred.squeeze(-1), target)
            self.advance(loss)
        return pred, loss

    def advance(self, loss) -> None:
        r"""
        Backward loss and step optimizer & scheduler.

        Args:
            loss: The loss tensor from which to backpropagate.
        """

        self.backward(loss / self.accum_steps)
        if self.accum_steps <= 1 or self.steps % self.accum_steps == 0:
            if self.config.get("max_grad_value") is not None:
                clip_grad_value_(self.model.parameters(), self.config["max_grad_value"])
            if self.config.get("max_grad_norm") is not None:
                clip_grad_norm_(self.model.parameters(), self.config["max_grad_norm"])
            self.optimizer.step()
            self.optimizer.zero_grad()
            if self.ema is not None:
                self.ema.update()
            if self.scheduler is not None:
                self.scheduler.step()
            self.steps += 1

    def evaluate(self, evaluate_splits: list[str] | None = None) -> NestedDict:
        r"""
        Perform evaluation on `evaluate_splits`.

        Args:
            evaluate_splits: list of split to run evaluate.
                Defaults to `["evaluate"]`.

        Return:
            NestedDict: evaluation result
        """

        if evaluate_splits is None:
            evaluate_splits = ["evaluate"]
        self.evaluate_splits = sorted(evaluate_splits)

        print("Begin evaluation")
        print(f"Evaluation splits: {self.evaluate_splits}")
        result = NestedDict()
        result.setattr("convert_mapping", True)
        for split in self.evaluate_splits:
            result[split] = self.evaluate_epoch(split=split)
        print(self.format_epoch_result(result))
        return result

    # torch.inference_mode cause experiments to hang
    # @torch.inference_mode()
    def evaluate_epoch(self, split: str = "val") -> NestedDict:
        r"""
        Evaluate one epoch on `split`.

        Args:
            split (str): split to run evaluate

        Return:
            NestedDict: evaluation result
        """

        self.mode = RunnerMode.evaluate
        self.split = split
        loader = self.dataloaders[split]
        length = len(loader) - 1
        last_print_iteration = -1
        log_interval = self.config.get("log_interval", -1)
        self.meters.reset()
        if self.evaluate_metrics is not None:
            self.metrics = self.evaluate_metrics
        if self.metrics is not None:
            self.metrics.reset()
        batch_time = time()

        for iteration, data in enumerate(loader):
            _, loss = self.evaluate_step(data)

            if log_interval > 0 and (iteration > 0 and iteration % log_interval == 0 or iteration == length):
                interval = iteration - last_print_iteration
                # See train_epoch: compare device type, not the full device object.
                if self.device.type == "cuda":
                    torch.cuda.synchronize()
                self.meters.time.update((time() - batch_time) / interval)
                batch_time = time()
                reduced_loss = self.reduce(loss).item()
                self.meters.loss.update(reduced_loss)
                self.step_log(split, iteration, length)
                last_print_iteration = iteration

        result = self.get_epoch_result()
        self.write_result(result, split, self.epochs)
        return result

    def evaluate_step(self, data) -> Tuple[Any, torch.Tensor]:
        r"""
        Run a single forward pass on one evaluation batch.

        Uses the EMA model when available, otherwise the raw model.
        """
        input = data["input"] if isinstance(data, Mapping) else data[0]
        target = data["target"] if isinstance(data, Mapping) else data[1]
        model = self.ema or self.model
        pred = model(**input) if isinstance(input, Mapping) else model(input)
        loss = self.criterion(pred, target)
        if self.metrics is not None:
            self.metrics.update(pred.squeeze(-1), target)
        return pred, loss

    @torch.inference_mode()
    def infer(self, split: str = "infer") -> list[float]:
        r"""
        Perform inference on `split`.

        Args:
            split (str): split to run inference

        Return:
            Tensor: inference outputs
        """

        self.mode = RunnerMode.infer
        loader = self.dataloaders[split]
        output: list[float] = []
        model = self.ema or self.model
        for _, data in tqdm(enumerate(loader), total=len(loader)):
            input = data["input"] if isinstance(data, Mapping) else data[0]
            pred = model(**input) if isinstance(input, Mapping) else model(input)
            output.extend(pred.squeeze(-1).tolist())

        if self.distributed:
            torch.cuda.synchronize()
            output = self.gather_for_metrics(output)
        return output

    def backward(self, loss: torch.Tensor) -> None:
        r"""
        Backward loss.

        Args:
            loss: Loss to backward.
        """

        loss.backward()

    def has_nan_inf_grad(self, model: nn.Module | None = None) -> bool:
        r"""
        Check if model has NaN or Inf gradients.

        Args:
            model: Model to check.
                Defaults to `self.model`.

        Return:
            bool: True if NaN or Inf is detected in gradients.
        """
        model = model or self.model
        for name, param in model.named_parameters():
            if param.grad is not None:
                if torch.isnan(param.grad).any():
                    print(f"NaN detected in gradients of parameter: {name}")
                    return True
                if torch.isinf(param.grad).any():
                    print(f"Inf detected in gradients of parameter: {name}")
                    return True
        return False

    def init_distributed(self) -> None:
        r"""
        Set up distributed training.

        Initialise process group and set up DDP variables.
        """

        backend = self.config.get("backend", os.getenv("BACKEND"))
        init_method = self.config.get("init_method", os.getenv("INIT_METHOD"))
        world_size = int(self.config.get("world_size", os.getenv("WORLD_SIZE", "1")))
        rank = int(self.config.get("rank", os.getenv("RANK", "0")))
        if world_size > 1:
            if torch.cuda.is_available():
                torch.cuda.set_device(self.local_rank)
            dist.init_process_group(backend, init_method, world_size=world_size, rank=rank)
            # Broadcast run identity from rank 0 so all processes agree on id/timestamp.
            object_list = [self.id, self.timestamp]
            dist.broadcast_object_list(object_list)
            self.id, self.timestamp = object_list

    @on_main_process
    def init_tensorboard(self, *args, **kwargs) -> None:
        r"""
        Set up TensorBoard SummaryWriter.
        """
        from torch.utils.tensorboard.writer import SummaryWriter  # pylint: disable=C0415

        if "log_dir" not in kwargs:
            kwargs["log_dir"] = self.dir

        self.writer = SummaryWriter(*args, **kwargs)
        # Writing scalars is best-effort: swallow OSError (e.g. full disk) quietly.
        self.writer.add_scalar = catch(OSError, verbose=False)(self.writer.add_scalar)

    def set_seed(self, seed: int | None = None, bias: int | None = None) -> int:
        r"""
        Set up random seed.

        Args:
            seed: Random seed to set.
                Defaults to `self.config.seed` (`config.seed`).

            bias: Make the seed different for each processes.
                This is used to ensure the data augmentation are applied differently on every processes.
                Defaults to `self.rank`.
                Set to `False` to disable this feature.
        Returns:
            Random seed set.
        """

        if seed is None:
            seed = self.config.get("seed")
        if seed is None:
            if self.inited:
                # No seed configured: draw one on rank 0 and share it with all
                # processes so every rank starts from the same base seed.
                seed = random.randint(0, 2**32 - 1)
                if self.distributed:
                    object_list = [seed]
                    dist.broadcast_object_list(object_list)
                    seed = object_list[0]
                self.config.seed = seed
            else:
                # Not initialised yet: fall back to the library default seed.
                # (Previously an explicitly passed seed was overwritten here.)
                seed = defaults.SEED
        # `bias=False` disables per-process offsetting; only default to rank when unset.
        if bias is None:
            bias = self.rank
        if bias:
            seed += bias
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        if np_random is not None:
            np_random.seed(seed)
        random.seed(seed)
        return seed

    def set_deterministic(self) -> None:
        r"""
        Make cuDNN and PyTorch operations deterministic (may reduce performance).
        """
        cudnn.benchmark = False
        cudnn.deterministic = True
        # Feature-detect instead of comparing version strings:
        # lexicographic comparison breaks for e.g. "1.10.0" < "1.8.0".
        if hasattr(torch, "use_deterministic_algorithms"):
            torch.use_deterministic_algorithms(True)

    def state_dict(self, cls: Callable = dict) -> Mapping:
        r"""
        Return a mapping with runner config and model/optimizer/scheduler states.

        Args:
            cls: Mapping factory used to build the returned container.

        Raises:
            ValueError: If no model is defined.
        """
        if self.model is None:
            raise ValueError("Model must be defined when calling state_dict")
        return cls(
            runner=self.config.dict(),
            model=self.unwrap(self.model).state_dict(),
            optimizer=self.optimizer.state_dict() if self.optimizer else None,
            scheduler=self.scheduler.state_dict() if self.scheduler else None,
        )

    def unwrap(self, model: nn.Module) -> nn.Module:
        r"""
        Return the underlying module of a (Distributed)DataParallel wrapper.
        """
        if isinstance(model, (nn.parallel.DistributedDataParallel, nn.parallel.DataParallel)):
            return model.module
        return model

    def build_dataloaders(self):
        r"""
        Build a DataLoader for every dataset that does not have one yet.

        Per-split kwargs under `config.dataloader.<split>` override the shared
        defaults in `config.dataloader`.
        """
        datasets = {k: d for k, d in self.datasets.items() if k not in self.dataloaders}
        default_kwargs = self.config.get("dataloader", NestedDict())
        dataloader_kwargs = NestedDict({k: default_kwargs.pop(k) for k in self.datasets if k in default_kwargs})
        for k, d in datasets.items():
            dataloader_kwargs.setdefault(k, NestedDict())
            dataloader_kwargs[k].merge(default_kwargs, overwrite=False)
            # Shuffle by default for training datasets (dataset.train attribute, default True).
            shuffle = dataloader_kwargs[k].pop("shuffle", getattr(d, "train", True))
            if self.distributed:
                sampler = utils.data.distributed.DistributedSampler(d, shuffle=shuffle)
            else:
                sampler = utils.data.RandomSampler(d) if shuffle else utils.data.SequentialSampler(d)
            # Drop the last incomplete batch only for evaluation-style datasets.
            dataloader_kwargs[k].setdefault("drop_last", not getattr(d, "train", True))
            self.dataloaders[k] = utils.data.DataLoader(
                d, sampler=sampler, collate_fn=self.collate_fn, **dataloader_kwargs[k]
            )

    def collate_fn(self, batch):
        r"""
        Collate a batch with the default collator and move it to `self.device`.
        """
        return to_device(utils.data.dataloader.default_collate(batch), self.device)

    @contextmanager
    def autocast(self):
        r"""
        Context manager enabling mixed-precision autocast when `config.precision` is set.
        """
        if self.config.get("precision") is None:
            yield
        else:
            # Must *enter* the autocast context around the yield; merely yielding
            # the context object (as before) never actually enabled autocast.
            with torch.autocast(self.device.type, dtype=get_precision(self.config.precision)):
                yield

    @contextmanager
    def accumulate(self):
        r"""
        Context manager suppressing DDP gradient sync on non-stepping accumulation steps.
        """
        if self.accum_steps <= 1 or self.steps % self.accum_steps == 0:
            yield
        else:
            # Must *enter* no_sync around the yield; merely yielding the context
            # object (as before) never actually suppressed gradient sync.
            with self.model.no_sync():
                yield

    def get_optimizer(self, name: str):
        r"""
        Resolve an optimizer class by name.

        Prefers DeepSpeed fused/CPU optimizers when DeepSpeed is available,
        falling back to the matching `torch.optim` class.

        Raises:
            ValueError: If `name` does not match any known optimizer.
        """
        name = name.lower()
        if name == "sgd":
            return optim.SGD
        if name == "asgd":
            return optim.ASGD
        if name in {"torch_adam", "torch_adamw"}:
            return optim.Adam
        if ds is not None:
            if name == "adagrad":
                return ds.ops.adagrad.DeepSpeedCPUAdagrad
            if name in {"adam", "adamw"}:
                if torch.cuda.device_count() > 0:
                    return ds.ops.adam.FusedAdam
                return ds.ops.adam.DeepSpeedCPUAdam
            if name in {"cpu", "cpu_adam", "cpuadam", "cpu_adamw", "cpuadamw"}:
                return ds.ops.adam.DeepSpeedCPUAdam
            if name == "lamb":
                if torch.cuda.device_count() > 0:
                    return ds.ops.lamb.FusedLamb
                return ds.ops.lamb.DeepSpeedCPULamb
            if name in {"cpulamb", "cpu_lamb"}:
                return ds.ops.lamb.DeepSpeedCPULamb
            if name == "lion":
                if torch.cuda.device_count() > 0:
                    return ds.ops.lion.FusedLion
                return ds.ops.lion.DeepSpeedCPULion
            if name in {"cpulion", "cpu_lion"}:
                return ds.ops.lion.DeepSpeedCPULion
        if name in {"adam", "adamw"}:
            return optim.AdamW
        if name == "adadelta":
            return optim.Adadelta
        if name == "adafactor":
            return optim.Adafactor
        if name == "adagrad":
            return optim.Adagrad
        if name == "adamax":
            return optim.Adamax
        if name == "lbfgs":
            return optim.LBFGS
        if name == "nadam":
            return optim.NAdam
        if name == "radam":
            return optim.RAdam
        if name == "rmsprop":
            return optim.RMSprop
        if name == "rprop":
            return optim.Rprop
        # Previously fell through returning None, which caused an obscure
        # TypeError at the call site; fail loudly with the offending name instead.
        raise ValueError(f"Unknown optimizer: {name}")

    def get_deepspeed_config(self, config: NestedDict | str | None = None) -> NestedDict:  # pylint: disable=R0912, R0915
        r"""
        Preprocess DeepSpeed config.
        """

        if config is None and "deepspeed" in self.config:
            config = self.config.deepspeed
        if isinstance(config, str):
            config = NestedDict(config)
        if config is None:
            config = NestedDict()
        if config.get("steps_per_print", "auto") == "auto":
            config["steps_per_print"] = self.config.log_interval
        if config.get("train_micro_batch_size_per_gpu", "auto") == "auto":
            config["train_micro_batch_size_per_gpu"] = self.batch_size
        if config.get("gradient_accumulation_steps", "auto") == "auto":
            if self.accum_steps > 1:
                config["gradient_accumulation_steps"] = self.accum_steps
            else:
                config.pop("gradient_accumulation_steps", None)
        if "amp" in config:
            amp = config["amp"]
            if amp.get("enabled", "auto") == "auto":
                amp["enabled"] = "true"
            if amp.get("opt_level", "auto") == "auto":
                amp["opt_level"] = "O1"
        if "zero_optimization" in config:
            zero = config["zero_optimization"]
            if zero.get("allgather_bucket_size") == "auto":
                zero["allgather_bucket_size"] = 1e6
            if zero.get("reduce_bucket_size") == "auto":
                zero["reduce_bucket_size"] = 1e6
            if zero.get("stage3_max_live_parameters") == "auto":
                zero["stage3_max_live_parameters"] = 1e8
            if zero.get("stage3_max_live_gradients") == "auto":
                zero["stage3_max_live_gradients"] = 1e8
            if zero.get("stage3_max_reuse_distance") == "auto":
                zero["stage3_max_reuse_distance"] = 1e8
            if zero.get("stage3_prefetch_bucket_size") == "auto":
                zero["stage3_prefetch_bucket_size"] = 1e6
            if zero.get("stage3_param_persistence_threshold") == "auto":
                zero["stage3_param_persistence_threshold"] = 1e8
            if "amp" in config:
                if "fp16" not in config:
                    config["fp16"] = NestedDict()
                if config["fp16"].get("enabled", "auto"):
                    config["fp16"]["enabled"] = config["amp"]["enabled"]
                warn(
                    f"AMP is not compatible with ZeRO. Automatically set 'fp16' to {config['amp']['enabled']}",
                    stacklevel=2,
                )
                del config["amp"]
        if "optimizer" in config:
            if config["optimizer"].get("type", "auto") == "auto":
                config["optimizer"]["type"] = "Adam"
            if "params" not in config["optimizer"]:
                config["optimizer"]["params"] = NestedDict()
            optimizer = config["optimizer"]["params"]
            if optimizer.get("lr", "auto") == "auto":
                optimizer["lr"] = self.config.get("optim.lr", 1e-3)
            if optimizer.get("weight_decay", "auto") == "auto":
                optimizer["weight_decay"] = self.config.get("optim.weight_decay", 1e-2)
            if optimizer.get("betas") == "auto":
                optimizer["betas"] = (0.9, 0.999)
            if optimizer.get("eps") == "auto":
                optimizer["eps"] = 1e-8
        if "scheduler" in config:
            if config["scheduler"].get("type", "auto") == "auto":
                config["scheduler"]["type"] = "WarmupCosineLR"
            if "params" not in config["scheduler"]:
                config["scheduler"]["params"] = NestedDict()
            scheduler = config["scheduler"]["params"]
            if scheduler.get("total_num_steps", "auto") == "auto":
                scheduler["total_num_steps"] = self.total_steps
            if scheduler.get("warmup_num_steps", "auto") == "auto":
                scheduler["warmup_num_steps"] = scheduler["total_num_steps"] // 20
            if config["scheduler"]["type"] in ("WarmupLR", "WarmupDecayLR"):
                if scheduler.get("warmup_max_lr", "auto") == "auto":
                    if self.optimizer:
                        scheduler["warmup_max_lr"] = self.optimizer.param_groups[0]["lr"]
                    elif "optimizer" in config:
                        scheduler["warmup_max_lr"] = config["optimizer"]["params"]["lr"]
                    else:
                        scheduler["warmup_max_lr"] = self.config.get("optim.lr", 1e-3)
                if scheduler.get("warmup_min_lr", "auto") == "auto":
                    scheduler["warmup_min_lr"] = 1e-9
            else:
                scheduler.pop("warmup_max_lr", None)
                scheduler.pop("warmup_min_lr", None)
        if config.get("gradient_clipping", "auto") == "auto" and self.config.get("max_grad_norm") is not None:
            config["gradient_clipping"] = self.config["max_grad_norm"]
        return config

    @property
    def device(self):
        r"""
        Current device: `cuda:<local_rank>` when CUDA is available, else CPU.
        """
        return torch.device("cuda" if torch.cuda.is_available() else "cpu", self.local_rank)

    @property
    def mode(self) -> RunnerMode:
        r"""
        Current runner mode (train / evaluate / infer).
        """
        return self._mode

    @mode.setter
    def mode(self, mode: str | RunnerMode) -> None:
        if isinstance(mode, str):
            mode = RunnerMode(mode)
        self._mode = mode
        # Keep model (and EMA) train/eval state in lock-step with the runner mode.
        if self.model is not None:
            self.model.train(mode == RunnerMode.train)
        if self.ema is not None:
            self.ema.train(mode == RunnerMode.train)

    @property
    def rank(self) -> int:
        r"""
        Global rank of this process (0 when not distributed).
        """
        if self.distributed:
            return dist.get_rank()
        return 0

    @property
    def local_rank(self) -> int:
        r"""
        Local (per-node) rank, read from the LOCAL_RANK environment variable.
        """
        if local_rank := os.getenv("LOCAL_RANK"):
            return int(local_rank)
        return 0

    @property
    def world_size(self) -> int:
        r"""
        Number of Processes.
        """

        if torch.distributed.is_available() and torch.distributed.is_initialized():
            return dist.get_world_size()
        return 1

    @property
    def distributed(self) -> bool:
        r"""
        Whether this run spans more than one process.
        """
        return self.world_size > 1

    @cached_property
    def accum_steps(self) -> int:
        r"""
        Number of gradient accumulation steps (defaults to 1).
        """
        return self.config.get("accum_steps", 1)

    @staticmethod
    def reduce(tensor: torch.Tensor) -> torch.Tensor:
        r"""
        All-reduce (sum) `tensor` across processes when distributed; identity otherwise.
        """
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            dist.all_reduce(tensor)
        return tensor
world_size property

Python
world_size: int

Number of Processes.

train

Python
train(
    train_splits: list[str] | None = None,
    evaluate_splits: list[str] | None = None,
) -> NestedDict

Perform training on split.

Parameters:

Name Type Description Default

train_splits

list[str] | None

list of split to run train. Defaults to ["train"].

None

evaluate_splits

list[str] | None

list of split to run evaluate. Defaults to self.dataloaders except for those in train_splits.

None
Return
Source code in danling/runner/torch_runner.py
Python
def train(self, train_splits: list[str] | None = None, evaluate_splits: list[str] | None = None) -> NestedDict:
    r"""
    Perform training on `split`.

    Args:
        train_splits: list of split to run train.
            Defaults to `["train"]`.
        evaluate_splits: list of split to run evaluate.
            Defaults to `self.dataloaders` except for those in `train_splits`.

    Return:
        NestedDict: train results
    """

    early_stop_counter = 0
    if train_splits is None:
        train_splits = ["train"] if "train" in self.dataloaders else []
    self.train_splits = sorted(train_splits)
    if not train_splits:
        warn("No training split is found. Will only evaluate for one epoch.", stacklevel=2)
        self.epoch_end = self.epoch_begin + 1
    if evaluate_splits is None:
        evaluate_splits = [s for s in self.dataloaders if s not in train_splits]
    self.evaluate_splits = sorted(evaluate_splits)
    print(f"Begin training from {self.epoch_begin} to {self.epoch_end}")
    print(f"Training splits: {self.train_splits}")
    print(f"Evaluation splits: {self.evaluate_splits}")
    patience = self.config.get("patience", float("inf"))
    for epoch in range(self.epoch_begin, self.epoch_end):  # type: ignore
        self.epochs = epoch
        result = NestedDict()
        result.setattr("convert_mapping", True)
        for split in self.train_splits:
            result[split] = self.train_epoch(split)
        for split in self.evaluate_splits:
            result[split] = self.evaluate_epoch(split)
        self.append_result(result)
        print(self.format_epoch_result(result))
        self.save_result()
        if self.config.save_interval is not None:
            self.save_checkpoint()
        """@nni.report_intermediate_result(self.latest_score)"""
        early_stop_counter = 0 if self.is_best else early_stop_counter + 1
        if early_stop_counter > patience:
            print("early stop")
            break
    """@nni.report_final_result(self.latest_score)"""
    return self.results

train_epoch

Python
train_epoch(split: str = 'train') -> NestedDict

Train one epoch on split.

Parameters:

Name Type Description Default

split

str

split to run train

'train'
Return
Source code in danling/runner/torch_runner.py
Python
def train_epoch(self, split: str = "train") -> NestedDict:
    r"""
    Train one epoch on `split`.

    Args:
        split (str): split to run train

    Return:
        NestedDict: train result
    """

    self.mode = "train"  # type: ignore
    self.split = split
    loader = self.dataloaders[split]
    length = len(loader) - 1
    last_print_iteration = -1
    log_interval = self.config.get("log_interval", -1)
    self.meters.reset()
    if self.train_metrics is not None:
        self.metrics = self.train_metrics
    if self.metrics is not None:
        self.metrics.reset()
    batch_time = time()
    if hasattr(loader.batch_sampler, "set_epoch"):
        loader.batch_sampler.set_epoch(self.epochs)
    if hasattr(loader.sampler, "set_epoch"):
        loader.sampler.set_epoch(self.epochs)

    for iteration, data in enumerate(loader):
        _, loss = self.train_step(data)

        if log_interval > 0 and (iteration > 0 and iteration % log_interval == 0 or iteration == length):
            interval = iteration - last_print_iteration
            if self.device == torch.device("cuda"):
                torch.cuda.synchronize()
            if self.scheduler is not None:
                self.meters.lr.update(self.scheduler.get_last_lr()[0])
            self.meters.time.update((time() - batch_time) / interval)
            batch_time = time()
            reduced_loss = self.reduce(loss).item()
            self.meters.loss.update(reduced_loss)
            self.step_log(split, iteration, length)
            last_print_iteration = iteration

    result = self.get_epoch_result()
    return result

advance

Python
advance(loss) -> None

Backward loss and step optimizer & scheduler.

Parameters:

Name Type Description Default

loss

The loss tensor from which to backpropagate.

required
Source code in danling/runner/torch_runner.py
Python
def advance(self, loss) -> None:
    r"""
    Backward loss and step optimizer & scheduler.

    Args:
        loss: The loss tensor from which to backpropagate.
    """

    # Scale the loss so accumulated gradients average over accum_steps.
    self.backward(loss / self.accum_steps)
    should_step = self.accum_steps <= 1 or self.steps % self.accum_steps == 0
    if not should_step:
        return
    max_grad_value = self.config.get("max_grad_value")
    if max_grad_value is not None:
        clip_grad_value_(self.model.parameters(), max_grad_value)
    max_grad_norm = self.config.get("max_grad_norm")
    if max_grad_norm is not None:
        clip_grad_norm_(self.model.parameters(), max_grad_norm)
    self.optimizer.step()
    self.optimizer.zero_grad()
    if self.ema is not None:
        self.ema.update()
    if self.scheduler is not None:
        self.scheduler.step()
    self.steps += 1

evaluate

Python
evaluate(
    evaluate_splits: list[str] | None = None,
) -> NestedDict

Perform evaluation on evaluate_splits.

Parameters:

Name Type Description Default

evaluate_splits

list[str] | None

list of splits to evaluate. Defaults to ["evaluate"].

None
Return
Source code in danling/runner/torch_runner.py
Python
def evaluate(self, evaluate_splits: list[str] | None = None) -> NestedDict:
    r"""
    Perform evaluation on `evaluate_splits`.

    Runs `evaluate_epoch` on every split (in sorted order) and collects the
    per-split results into one nested mapping.

    Args:
        evaluate_splits: list of split to run evaluate.
            Defaults to `["evaluate"]`.

    Return:
        NestedDict: evaluation result
    """

    splits = evaluate_splits if evaluate_splits is not None else ["evaluate"]
    self.evaluate_splits = sorted(splits)

    print("Begin evaluation")
    print(f"Evaluation splits: {self.evaluate_splits}")
    result = NestedDict()
    result.setattr("convert_mapping", True)
    for name in self.evaluate_splits:
        result[name] = self.evaluate_epoch(split=name)
    print(self.format_epoch_result(result))
    return result

evaluate_epoch

Python
evaluate_epoch(split: str = 'val') -> NestedDict

Evaluate one epoch on split.

Parameters:

Name Type Description Default

split

str

split to run evaluate

'val'
Return
Source code in danling/runner/torch_runner.py
Python
def evaluate_epoch(self, split: str = "val") -> NestedDict:
    r"""
    Evaluate the model for one epoch on the dataloader of `split`.

    Args:
        split (str): name of the dataloader split to evaluate on.
            Defaults to `"val"`.

    Return:
        NestedDict: evaluation result for this epoch.
    """

    self.mode = RunnerMode.evaluate
    self.split = split
    loader = self.dataloaders[split]
    # Index of the last iteration; used to force a log on the final batch.
    length = len(loader) - 1
    last_print_iteration = -1
    # log_interval <= 0 disables periodic step logging.
    log_interval = self.config.get("log_interval", -1)
    self.meters.reset()
    if self.evaluate_metrics is not None:
        self.metrics = self.evaluate_metrics
    if self.metrics is not None:
        self.metrics.reset()
    batch_time = time()

    for iteration, data in enumerate(loader):
        _, loss = self.evaluate_step(data)

        # Log every `log_interval` iterations and always on the last batch.
        if log_interval > 0 and (iteration > 0 and iteration % log_interval == 0 or iteration == length):
            interval = iteration - last_print_iteration
            if self.device == torch.device("cuda"):
                # Wait for pending CUDA kernels so timing is accurate.
                torch.cuda.synchronize()
            self.meters.time.update((time() - batch_time) / interval)
            batch_time = time()
            # Reduce loss across processes before recording it.
            reduced_loss = self.reduce(loss).item()
            self.meters.loss.update(reduced_loss)
            self.step_log(split, iteration, length)
            last_print_iteration = iteration

    result = self.get_epoch_result()
    # Record the epoch result for this split (see `self.write_result`).
    self.write_result(result, split, self.epochs)
    return result

infer

Python
infer(split: str = 'infer') -> list[float]

Perform inference on split.

Parameters:

Name Type Description Default

split

str

split to run inference

'infer'
Return
Source code in danling/runner/torch_runner.py
Python
@torch.inference_mode()
def infer(self, split: str = "infer") -> list[float]:
    r"""
    Perform inference on `split`.

    Args:
        split (str): split to run inference

    Return:
        list[float]: inference outputs, flattened across all batches
    """

    self.mode = RunnerMode.infer
    loader = self.dataloaders[split]
    output: list[float] = []
    # Prefer EMA weights when available.
    model = self.ema or self.model
    for _, data in tqdm(enumerate(loader), total=len(loader)):
        if isinstance(data, Mapping):
            batch_input = data["input"]
        else:
            batch_input = data[0]
        if isinstance(batch_input, Mapping):
            pred = model(**batch_input)
        else:
            pred = model(batch_input)
        output.extend(pred.squeeze(-1).tolist())

    if self.distributed:
        torch.cuda.synchronize()
        output = self.gather_for_metrics(output)
    return output

backward

Python
backward(loss: Tensor) -> None

Backward loss.

Parameters:

Name Type Description Default

loss

Tensor

Loss to backward.

required
Source code in danling/runner/torch_runner.py
Python
def backward(self, loss: torch.Tensor) -> None:
    r"""
    Backpropagate through `loss`.

    Args:
        loss: Loss to backward.
    """
    # Delegate directly to autograd; subclasses may override for AMP/DeepSpeed.
    loss.backward()

has_nan_inf_grad

Python
has_nan_inf_grad(model: Module | None = None) -> bool

Check if model has NaN or Inf gradients.

Parameters:

Name Type Description Default

model

Module | None

Model to check. Defaults to self.model.

None
Return
Source code in danling/runner/torch_runner.py
Python
def has_nan_inf_grad(self, model: nn.Module | None = None) -> bool:
    r"""
    Check if model has NaN or Inf gradients.

    Args:
        model: Model to check.
            Defaults to `self.model`.

    Return:
        bool: True if NaN or Inf is detected in gradients.
    """
    model = model or self.model
    for name, param in model.named_parameters():
        if param.grad is not None:
            if torch.isnan(param.grad).any():
                print(f"NaN detected in gradients of parameter: {name}")
                return True
            if torch.isinf(param.grad).any():
                print(f"Inf detected in gradients of parameter: {name}")
                return True
    return False

init_distributed

Python
init_distributed() -> None

Set up distributed training.

Initialise process group and set up DDP variables.

Source code in danling/runner/torch_runner.py
Python
def init_distributed(self) -> None:
    r"""
    Set up distributed training.

    Initialise process group and set up DDP variables.
    Config values take precedence over the corresponding environment
    variables (`BACKEND`, `INIT_METHOD`, `WORLD_SIZE`, `RANK`).
    """

    backend = self.config.get("backend", os.getenv("BACKEND"))
    init_method = self.config.get("init_method", os.getenv("INIT_METHOD"))
    world_size = int(self.config.get("world_size", os.getenv("WORLD_SIZE", "1")))
    rank = int(self.config.get("rank", os.getenv("RANK", "0")))
    if world_size <= 1:
        # Single-process run; nothing to initialise.
        return
    if torch.cuda.is_available():
        torch.cuda.set_device(self.local_rank)
    dist.init_process_group(backend, init_method, world_size=world_size, rank=rank)
    # Broadcast run identity from rank 0 so every process shares the same id/timestamp.
    shared = [self.id, self.timestamp]
    dist.broadcast_object_list(shared)
    self.id, self.timestamp = shared

init_tensorboard

Python
init_tensorboard(*args, **kwargs) -> None

Set up Tensorboard SummaryWriter.

Source code in danling/runner/torch_runner.py
Python
@on_main_process
def init_tensorboard(self, *args, **kwargs) -> None:
    r"""
    Set up Tensorboard SummaryWriter.

    Arguments are forwarded to `SummaryWriter`; `log_dir` falls back to
    `self.dir` when not supplied.
    """
    # Imported lazily so tensorboard is only required when actually used.
    from torch.utils.tensorboard.writer import SummaryWriter  # pylint: disable=C0415

    if "log_dir" not in kwargs:
        kwargs["log_dir"] = self.dir

    writer = SummaryWriter(*args, **kwargs)
    # Tensorboard writes may fail with OSError on flaky filesystems; swallow them.
    writer.add_scalar = catch(OSError, verbose=False)(writer.add_scalar)
    self.writer = writer

set_seed

Python
set_seed(seed: int = None, bias: int = None) -> int

Set up random seed.

Parameters:

Name Type Description Default

seed

int

Random seed to set. Defaults to self.config.seed (config.seed).

None

bias

int

Make the seed different for each process. This is used to ensure the data augmentations are applied differently on every process. Defaults to self.rank. Set to False to disable this feature.

None
Source code in danling/runner/torch_runner.py
Python
def set_seed(self, seed: int | None = None, bias: int | None = None) -> int:
    r"""
    Set up random seed.

    Seeds `torch`, `torch.cuda`, `numpy` (when available) and `random`.

    Args:
        seed: Random seed to set.
            Defaults to `self.config.seed` (`config.seed`).

        bias: Make the seed different for each process.
            This is used to ensure the data augmentations are applied differently on every process.
            Defaults to `self.rank`.
            Set to `False` to disable this feature.
    Returns:
        Random seed set (including `bias`).
    """

    # Use an explicit `is None` check (not `or`) so a legitimate seed of 0 is honoured.
    if seed is None:
        seed = self.config.seed
    if seed is None:
        if self.inited:
            # Draw a fresh seed and broadcast it so every rank agrees on the value.
            seed = random.randint(0, 2**32 - 1)
            if self.distributed:
                object_list = [seed]
                dist.broadcast_object_list(object_list)
                seed = object_list[0]
            self.config.seed = seed
        else:
            # Fix: this fallback was previously the `else` of `if seed is None`,
            # which overwrote any explicitly provided seed with `defaults.SEED`.
            seed = defaults.SEED
    bias = bias or self.rank
    if bias:
        seed += bias
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    if np_random is not None:
        np_random.seed(seed)
    random.seed(seed)
    return seed

get_deepspeed_config

Python
get_deepspeed_config(
    config: NestedDict | str = None,
) -> NestedDict

Preprocess DeepSpeed config.

Source code in danling/runner/torch_runner.py
Python
def get_deepspeed_config(self, config: NestedDict | str | None = None) -> NestedDict:  # pylint: disable=R0912, R0915
    r"""
    Preprocess DeepSpeed config.

    Resolves every `"auto"` placeholder against this runner's settings
    (batch size, accumulation steps, optimizer/scheduler hyper-parameters,
    gradient clipping) so the result can be handed to DeepSpeed directly.

    Args:
        config: DeepSpeed config as a `NestedDict` or a string
            (presumably a config file path loaded via `NestedDict` — TODO confirm).
            Defaults to `self.config.deepspeed` when present, else an empty config.

    Return:
        NestedDict: the resolved DeepSpeed config.
    """

    if config is None and "deepspeed" in self.config:
        config = self.config.deepspeed
    if isinstance(config, str):
        config = NestedDict(config)
    if config is None:
        config = NestedDict()
    # Top-level "auto" values resolve to the runner's own settings.
    if config.get("steps_per_print", "auto") == "auto":
        config["steps_per_print"] = self.config.log_interval
    if config.get("train_micro_batch_size_per_gpu", "auto") == "auto":
        config["train_micro_batch_size_per_gpu"] = self.batch_size
    if config.get("gradient_accumulation_steps", "auto") == "auto":
        if self.accum_steps > 1:
            config["gradient_accumulation_steps"] = self.accum_steps
        else:
            # No accumulation: let DeepSpeed use its own default.
            config.pop("gradient_accumulation_steps", None)
    if "amp" in config:
        amp = config["amp"]
        if amp.get("enabled", "auto") == "auto":
            amp["enabled"] = "true"
        if amp.get("opt_level", "auto") == "auto":
            amp["opt_level"] = "O1"
    if "zero_optimization" in config:
        zero = config["zero_optimization"]
        # Conservative bucket/threshold defaults for ZeRO "auto" values.
        if zero.get("allgather_bucket_size") == "auto":
            zero["allgather_bucket_size"] = 1e6
        if zero.get("reduce_bucket_size") == "auto":
            zero["reduce_bucket_size"] = 1e6
        if zero.get("stage3_max_live_parameters") == "auto":
            zero["stage3_max_live_parameters"] = 1e8
        if zero.get("stage3_max_live_gradients") == "auto":
            zero["stage3_max_live_gradients"] = 1e8
        if zero.get("stage3_max_reuse_distance") == "auto":
            zero["stage3_max_reuse_distance"] = 1e8
        if zero.get("stage3_prefetch_bucket_size") == "auto":
            zero["stage3_prefetch_bucket_size"] = 1e6
        if zero.get("stage3_param_persistence_threshold") == "auto":
            zero["stage3_param_persistence_threshold"] = 1e8
        if "amp" in config:
            # AMP cannot be combined with ZeRO; migrate the setting to fp16.
            if "fp16" not in config:
                config["fp16"] = NestedDict()
            # NOTE(review): this is a truthiness check, so "auto" (truthy) triggers
            # the overwrite while `False` does not — possibly meant `== "auto"`; confirm.
            if config["fp16"].get("enabled", "auto"):
                config["fp16"]["enabled"] = config["amp"]["enabled"]
            warn(
                f"AMP is not compatible with ZeRO. Automatically set 'fp16' to {config['amp']['enabled']}",
                stacklevel=2,
            )
            del config["amp"]
    if "optimizer" in config:
        if config["optimizer"].get("type", "auto") == "auto":
            config["optimizer"]["type"] = "Adam"
        if "params" not in config["optimizer"]:
            config["optimizer"]["params"] = NestedDict()
        optimizer = config["optimizer"]["params"]
        # Fall back to the runner's optim settings for lr / weight_decay.
        if optimizer.get("lr", "auto") == "auto":
            optimizer["lr"] = self.config.get("optim.lr", 1e-3)
        if optimizer.get("weight_decay", "auto") == "auto":
            optimizer["weight_decay"] = self.config.get("optim.weight_decay", 1e-2)
        if optimizer.get("betas") == "auto":
            optimizer["betas"] = (0.9, 0.999)
        if optimizer.get("eps") == "auto":
            optimizer["eps"] = 1e-8
    if "scheduler" in config:
        if config["scheduler"].get("type", "auto") == "auto":
            config["scheduler"]["type"] = "WarmupCosineLR"
        if "params" not in config["scheduler"]:
            config["scheduler"]["params"] = NestedDict()
        scheduler = config["scheduler"]["params"]
        if scheduler.get("total_num_steps", "auto") == "auto":
            scheduler["total_num_steps"] = self.total_steps
        if scheduler.get("warmup_num_steps", "auto") == "auto":
            # Default warmup is 5% of total steps.
            scheduler["warmup_num_steps"] = scheduler["total_num_steps"] // 20
        if config["scheduler"]["type"] in ("WarmupLR", "WarmupDecayLR"):
            # Only these scheduler types accept warmup_max_lr / warmup_min_lr.
            if scheduler.get("warmup_max_lr", "auto") == "auto":
                if self.optimizer:
                    scheduler["warmup_max_lr"] = self.optimizer.param_groups[0]["lr"]
                elif "optimizer" in config:
                    scheduler["warmup_max_lr"] = config["optimizer"]["params"]["lr"]
                else:
                    scheduler["warmup_max_lr"] = self.config.get("optim.lr", 1e-3)
            if scheduler.get("warmup_min_lr", "auto") == "auto":
                scheduler["warmup_min_lr"] = 1e-9
        else:
            # Other scheduler types reject these keys; drop them if present.
            scheduler.pop("warmup_max_lr", None)
            scheduler.pop("warmup_min_lr", None)
    if config.get("gradient_clipping", "auto") == "auto" and self.config.get("max_grad_norm") is not None:
        config["gradient_clipping"] = self.config["max_grad_norm"]
    return config