Skip to content

DanLing

danling

AverageMeter

A lightweight utility to compute and store running averages of values.

AverageMeter provides an efficient way to track running statistics (current value, sum, count, average) with minimal memory overhead and optional distributed averaging. Scalar values stay scalar. Tensor values are preserved end to end as long as each update for the meter has the same shape.

Attributes:

Name Type Description
val float | Tensor

Most recent local value added to the meter

bat float | Tensor

Synchronized metric value for the current step

avg float | Tensor

Running average of all values, weighted by counts

sum float | Tensor

Sum of all values added to the meter

count int

Total count of values added, i.e. the sum of the weights `n` passed to every update

device

Device used when synchronising running averages across processes

Parameters:

Name Type Description Default

device

device | str | None

Optional device used for distributed reductions. When not provided, the device is detected automatically when synchronisation happens.

None

Examples:

Python Console Session
>>> meter = AverageMeter()
>>> meter.update(0.7)
>>> meter.val
0.7
>>> meter.bat  # Same as val in non-distributed mode
0.7
>>> meter.avg
0.7
>>> meter.update(0.9)
>>> meter.val
0.9
>>> meter.avg
0.8
>>> meter.sum
1.6
>>> meter.count
2
>>> # Weighted update
>>> meter.update(value=0.5, n=3)
>>> meter.avg
0.62
>>> meter.reset()
AverageMeter(val=nan, avg=nan)
See Also
  • MetricMeter: Memory-efficient metric tracker that averages metrics batch-by-batch.
Source code in danling/metrics/average_meter.py
Python
class AverageMeter:
    r"""
    A lightweight utility to compute and store running averages of values.

    AverageMeter provides an efficient way to track running statistics (current value, sum, count, average)
    with minimal memory overhead and optional distributed averaging.
    Scalar values stay scalar. Tensor values are preserved end to end as long as
    each update for the meter has the same shape.

    Attributes:
        val: Most recent local value added to the meter
        bat: Synchronized metric value for the current step
        avg: Running average of all values, weighted by counts
        sum: Sum of all values added to the meter
        count: Total count of values added (considering weights)
        device: Device used when synchronising running averages across processes

    Args:
        device: Optional device used for distributed reductions. When not provided,
            the device is detected automatically when synchronisation happens.

    Examples:
        >>> meter = AverageMeter()
        >>> meter.update(0.7)
        >>> meter.val
        0.7
        >>> meter.bat  # Same as val in non-distributed mode
        0.7
        >>> meter.avg
        0.7
        >>> meter.update(0.9)
        >>> meter.val
        0.9
        >>> meter.avg
        0.8
        >>> meter.sum
        1.6
        >>> meter.count
        2
        >>> # Weighted update
        >>> meter.update(value=0.5, n=3)
        >>> meter.avg
        0.62
        >>> meter.reset()
        AverageMeter(val=nan, avg=nan)

    See Also:
        - [`MetricMeter`][danling.metrics.stream_metrics.MetricMeter]:
            Memory-efficient metric tracker that averages metrics batch-by-batch.
    """

    _local_value: float | Tensor = 0.0
    _local_n: int = 0
    _local_sum: float | Tensor = 0.0
    _local_count: int = 0

    def __init__(self, *, device: torch.device | str | None = None, distributed: bool = True) -> None:
        self.distributed = distributed
        self.device = torch_device(device) if device is not None else None
        self.reset()

    # Lifecycle
    def reset(self, *, device: torch.device | str | None = None) -> Self:
        r"""
        Resets the meter.
        """

        if device is not None:
            self.device = torch_device(device)
        self._local_value = 0.0
        self._local_n = 0
        self._local_sum = 0.0
        self._local_count = 0
        self._tensor_template = None
        return self

    # Mutation
    def update(self, value: float | int | Tensor, n: int = 1) -> None:
        r"""
        Updates the average and current value in the meter.

        Args:
            value: Value to be added to the average.
            n: Number of values to be added.
        """

        if isinstance(value, Tensor):
            if value.numel() == 1:
                if self.device is None:
                    self.device = value.device
                value = float(value.detach().item())
            else:
                if self._local_count > 0 and not isinstance(self._local_sum, Tensor):
                    raise ValueError("AverageMeter cannot mix scalar and tensor values.")

                value = value.detach().to(dtype=torch.float64)
                if self.device is None:
                    self.device = value.device

                if isinstance(self._local_sum, Tensor):
                    if value.shape != self._local_sum.shape:
                        raise ValueError(
                            "AverageMeter requires consistent tensor shapes, "
                            f"but got {tuple(value.shape)} after {tuple(self._local_sum.shape)}."
                        )
                    if value.device != self._local_sum.device:
                        self._local_sum = self._local_sum.to(value.device)
                        if isinstance(self._local_value, Tensor):
                            self._local_value = self._local_value.to(value.device)
                        if self._tensor_template is not None:
                            self._tensor_template = self._tensor_template.to(value.device)
                else:
                    self._local_sum = torch.zeros_like(value, dtype=torch.float64, device=value.device)

                self._tensor_template = torch.empty_like(value, dtype=torch.float64, device=value.device)
                self._local_value = value
                if n > 0:
                    self._local_sum.add_(value * n)
        if not isinstance(value, Tensor):
            if self._tensor_template is not None:
                if n == 0:
                    tensor_value = torch.full_like(self._tensor_template, float(value), dtype=torch.float64)
                    self._local_value = tensor_value
                    self._local_n = 0
                    return
                raise ValueError("AverageMeter cannot mix tensor and scalar values.")
            value = float(value)
            if isinstance(self._local_sum, Tensor):
                tensor_value = torch.tensor(value, dtype=self._local_sum.dtype, device=self._local_sum.device)
                self._local_value = tensor_value
                if n > 0:
                    self._local_sum.add_(tensor_value * n)
            else:
                self._local_value = value
                if n > 0:
                    self._local_sum += value * n
        self._local_n = n
        if n > 0:
            self._local_count += n

    # Public reductions
    def value(self) -> float | Tensor:
        if self._local_count == 0:
            empty_tensor = self._empty_tensor_value()
            if empty_tensor is not None:
                return empty_tensor
            return nan
        return self._local_value

    def batch(self) -> float | Tensor:
        world_size = self._current_world_size()
        if world_size == 1:
            return self.value()

        if self._tensor_template is not None:
            return self._tensor_batch()

        device = self._sync_device()
        synced_tensor = torch.tensor([0.0, float(self._local_n)], dtype=torch.float64, device=device)
        if self._local_n > 0:
            if isinstance(self._local_value, Tensor):
                synced_tensor[0] = self._local_value.to(device=device, dtype=torch.float64) * self._local_n
            else:
                synced_tensor[0] = float(self._local_value) * self._local_n
        dist.all_reduce(synced_tensor)
        total, count = synced_tensor.tolist()
        if count == 0:
            return nan
        return total / count

    def average(self) -> float | Tensor:
        world_size = self._current_world_size()
        if world_size == 1:
            return self._local_average()
        if self._tensor_template is not None:
            return self._tensor_average()
        device = self._sync_device()
        synced_tensor = torch.tensor([0.0, float(self._local_count)], dtype=torch.float64, device=device)
        if isinstance(self._local_sum, Tensor):
            synced_tensor[0] = self._local_sum.to(device=device, dtype=torch.float64)
        else:
            synced_tensor[0] = float(self._local_sum)
        dist.all_reduce(synced_tensor)
        val, count = synced_tensor.tolist()
        if count == 0:
            return nan
        return val / count

    # Public aliases
    @property
    def val(self) -> float | Tensor:
        return self.value()

    @property
    def bat(self) -> float | Tensor:
        return self.batch()

    @property
    def avg(self) -> float | Tensor:
        return self.average()

    # Public state accessors
    @property
    def n(self) -> int:
        return self._local_n

    @property
    def sum(self) -> float | Tensor:
        return self._local_sum

    @property
    def count(self) -> int:
        return self._local_count

    # Formatting helpers
    def __format__(self, format_spec: str) -> str:
        value = self.value()
        average = self._local_average()
        return f"{self._format_value(value, format_spec)} ({self._format_value(average, format_spec)})"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(val={self.value()}, avg={self._local_average()})"

    # Internal helpers
    def _local_average(self) -> float | Tensor:
        if self._local_count == 0:
            empty_tensor = self._empty_tensor_value()
            if empty_tensor is not None:
                return empty_tensor
            return nan
        if isinstance(self._local_sum, Tensor) and self._tensor_template is not None:
            return self._local_sum / self._local_count
        if isinstance(self._local_sum, Tensor):
            return (self._local_sum / self._local_count).item()
        return self._local_sum / self._local_count

    def _current_world_size(self) -> int:
        if not self.distributed:
            return 1
        return get_world_size()

    def _sync_device(self) -> torch.device:
        if not (dist.is_available() and dist.is_initialized()) and self.device is not None:
            return self.device
        return infer_device()

    def _distributed_tensor_reduction(
        self,
        tensor: Tensor | None,
        count: int,
        *,
        scale_by_count: bool,
        template: Tensor | None = None,
    ) -> Tensor:
        world_size = self._current_world_size()
        device = self._sync_device()
        if template is None:
            template = self._resolve_tensor_template(world_size, device)
        else:
            template = template.to(device=device, dtype=torch.float64)
        if template is None:
            return torch.tensor(float("nan"), dtype=torch.float64, device=device)

        reduced = torch.zeros(template.numel() + 1, dtype=torch.float64, device=device)
        if count > 0 and tensor is not None:
            tensor = tensor.to(device=device, dtype=torch.float64)
            if tensor.shape != template.shape:
                raise ValueError(
                    "AverageMeter requires consistent tensor shapes across ranks, "
                    f"but got local shape {tuple(tensor.shape)} and expected {tuple(template.shape)}."
                )
            reduced[:-1] = ((tensor * count) if scale_by_count else tensor).reshape(-1)
        reduced[-1] = float(count)
        dist.all_reduce(reduced)

        total_count = int(round(reduced[-1].item()))
        if total_count == 0:
            return torch.full(template.shape, float("nan"), dtype=torch.float64, device=device)
        return (reduced[:-1] / total_count).reshape(template.shape)

    def _tensor_batch(self, template: Tensor | None = None) -> Tensor:
        return self._distributed_tensor_reduction(
            self._local_value if isinstance(self._local_value, Tensor) else None,
            self._local_n,
            scale_by_count=True,
            template=template,
        )

    def _tensor_average(self, template: Tensor | None = None) -> Tensor:
        local_sum = self._local_sum if isinstance(self._local_sum, Tensor) else None
        return self._distributed_tensor_reduction(
            local_sum,
            self._local_count,
            scale_by_count=False,
            template=template,
        )

    def _empty_tensor_value(self) -> Tensor | None:
        if self._tensor_template is None:
            return None
        return torch.full_like(self._tensor_template, float("nan"), dtype=torch.float64)

    def _resolve_tensor_template(self, world_size: int, device: torch.device) -> Tensor | None:
        if world_size == 1:
            if self._tensor_template is None:
                return None
            return self._tensor_template.to(device=device, dtype=torch.float64)

        if not (dist.is_available() and dist.is_initialized()):
            if self._tensor_template is None:
                return None
            return self._tensor_template.to(device=device, dtype=torch.float64)

        metadata = None
        if self._tensor_template is not None:
            metadata = (tuple(self._tensor_template.shape), str(self._tensor_template.dtype))
        metadata_list: list[tuple[tuple[int, ...], str] | None] = [None for _ in range(world_size)]
        dist.all_gather_object(metadata_list, metadata)
        references = [item for item in metadata_list if item is not None]
        if not references:
            return None
        if any(item != references[0] for item in references[1:]):
            raise ValueError(f"AverageMeter received inconsistent tensor metadata across ranks: {references!r}")

        shape, dtype_name = references[0]
        template = torch.empty(shape, dtype=getattr(torch, dtype_name.removeprefix("torch.")), device=device)
        template_device = self.device if self.device is not None else device
        self._tensor_template = template.to(device=template_device)
        return template

    @staticmethod
    def _format_value(value: float | Tensor, format_spec: str) -> str:
        if isinstance(value, Tensor):
            if value.numel() == 1:
                return value.item().__format__(format_spec)
            return str(value)
        return value.__format__(format_spec)

    def _tensor_spec(self) -> tuple[tuple[int, ...], str] | None:
        if self._tensor_template is None:
            return None
        return tuple(self._tensor_template.shape), str(self._tensor_template.dtype)

reset

Python
reset(*, device: device | str | None = None) -> Self

Resets the meter.

Source code in danling/metrics/average_meter.py
Python
def reset(self, *, device: torch.device | str | None = None) -> Self:
    r"""
    Resets the meter.
    """

    if device is not None:
        self.device = torch_device(device)
    self._local_value = 0.0
    self._local_n = 0
    self._local_sum = 0.0
    self._local_count = 0
    self._tensor_template = None
    return self

update

Python
update(value: float | int | Tensor, n: int = 1) -> None

Updates the average and current value in the meter.

Parameters:

Name Type Description Default
value
float | int | Tensor

Value to be added to the average.

required
n
int

Number of values to be added.

1
Source code in danling/metrics/average_meter.py
Python
def update(self, value: float | int | Tensor, n: int = 1) -> None:
    r"""
    Updates the average and current value in the meter.

    Args:
        value: Value to be added to the average.
        n: Number of values to be added.
    """

    if isinstance(value, Tensor):
        if value.numel() == 1:
            if self.device is None:
                self.device = value.device
            value = float(value.detach().item())
        else:
            if self._local_count > 0 and not isinstance(self._local_sum, Tensor):
                raise ValueError("AverageMeter cannot mix scalar and tensor values.")

            value = value.detach().to(dtype=torch.float64)
            if self.device is None:
                self.device = value.device

            if isinstance(self._local_sum, Tensor):
                if value.shape != self._local_sum.shape:
                    raise ValueError(
                        "AverageMeter requires consistent tensor shapes, "
                        f"but got {tuple(value.shape)} after {tuple(self._local_sum.shape)}."
                    )
                if value.device != self._local_sum.device:
                    self._local_sum = self._local_sum.to(value.device)
                    if isinstance(self._local_value, Tensor):
                        self._local_value = self._local_value.to(value.device)
                    if self._tensor_template is not None:
                        self._tensor_template = self._tensor_template.to(value.device)
            else:
                self._local_sum = torch.zeros_like(value, dtype=torch.float64, device=value.device)

            self._tensor_template = torch.empty_like(value, dtype=torch.float64, device=value.device)
            self._local_value = value
            if n > 0:
                self._local_sum.add_(value * n)
    if not isinstance(value, Tensor):
        if self._tensor_template is not None:
            if n == 0:
                tensor_value = torch.full_like(self._tensor_template, float(value), dtype=torch.float64)
                self._local_value = tensor_value
                self._local_n = 0
                return
            raise ValueError("AverageMeter cannot mix tensor and scalar values.")
        value = float(value)
        if isinstance(self._local_sum, Tensor):
            tensor_value = torch.tensor(value, dtype=self._local_sum.dtype, device=self._local_sum.device)
            self._local_value = tensor_value
            if n > 0:
                self._local_sum.add_(tensor_value * n)
        else:
            self._local_value = value
            if n > 0:
                self._local_sum += value * n
    self._local_n = n
    if n > 0:
        self._local_count += n

AverageMeters

Bases: MetersBase

Manages multiple average meters in one object.

Examples:

Python Console Session
>>> meters = AverageMeters()
>>> meters.update({"loss": 0.6, "auroc": 0.7, "r2": 0.8})
>>> f"{meters:.4f}"
'loss: 0.6000 (0.6000)\tauroc: 0.7000 (0.7000)\tr2: 0.8000 (0.8000)'
>>> meters['loss'].update(value=0.9, n=1)
>>> f"{meters:.4f}"
'loss: 0.9000 (0.7500)\tauroc: 0.7000 (0.7000)\tr2: 0.8000 (0.8000)'
>>> meters.sum.dict()
{'loss': 1.5, 'auroc': 0.7, 'r2': 0.8}
>>> meters.count.dict()
{'loss': 2, 'auroc': 1, 'r2': 1}
>>> meters.reset()
AverageMeters(...)
>>> f"{meters:.4f}"
'loss: nan (nan)\tauroc: nan (nan)\tr2: nan (nan)'
See Also
  • StreamMetrics: Memory-efficient metric tracker that averages multiple metrics batch-by-batch.
Source code in danling/metrics/average_meter.py
Python
class AverageMeters(MetersBase):
    r"""
    Manages multiple average meters in one object.

    Cross-process reductions are batched: all scalar meters that need
    synchronisation share a single `all_reduce`, while tensor meters are
    reduced one by one.

    Examples:
        >>> meters = AverageMeters()
        >>> meters.update({"loss": 0.6, "auroc": 0.7, "r2": 0.8})
        >>> f"{meters:.4f}"
        'loss: 0.6000 (0.6000)\tauroc: 0.7000 (0.7000)\tr2: 0.8000 (0.8000)'
        >>> meters['loss'].update(value=0.9, n=1)
        >>> f"{meters:.4f}"
        'loss: 0.9000 (0.7500)\tauroc: 0.7000 (0.7000)\tr2: 0.8000 (0.8000)'
        >>> meters.sum.dict()
        {'loss': 1.5, 'auroc': 0.7, 'r2': 0.8}
        >>> meters.count.dict()
        {'loss': 2, 'auroc': 1, 'r2': 1}
        >>> meters.reset()
        AverageMeters(...)
        >>> f"{meters:.4f}"
        'loss: nan (nan)\tauroc: nan (nan)\tr2: nan (nan)'

    See Also:
        - [`StreamMetrics`][danling.metrics.stream_metrics.StreamMetrics]:
            Memory-efficient metric tracker that averages multiple metrics batch-by-batch.
    """

    meter_cls = AverageMeter  # type: ignore[assignment]

    # Aggregate state accessors
    @property
    def n(self) -> RoundDict[str, int]:
        """Maps each meter name to the weight of its most recent update."""
        return RoundDict({key: meter.n for key, meter in self.all_items()})

    @property
    def sum(self) -> RoundDict[str, float | Tensor]:
        """Maps each meter name to its local running sum."""
        return RoundDict({key: meter.sum for key, meter in self.all_items()})

    @property
    def count(self) -> RoundDict[str, int]:
        """Maps each meter name to its local running count."""
        return RoundDict({key: meter.count for key, meter in self.all_items()})

    # Public reductions
    def batch(self) -> RoundDict[str, float | Tensor]:
        r"""
        Returns the synchronised latest value of every meter.

        Meters whose world size is one are read through their own `batch()`.
        When no meter needs synchronisation the base-class implementation is used.
        """

        items = list(self.all_items())
        sync_names = [name for name, meter in items if meter._current_world_size() > 1]
        if not sync_names:
            return super().batch()

        # The first synchronised meter decides the device for the shared buffer.
        device = self[sync_names[0]]._sync_device()
        tensor_templates = self._resolved_tensor_templates(sync_names, device)
        tensor_sync_names = set(tensor_templates)
        scalar_sync_names = [name for name in sync_names if name not in tensor_sync_names]
        if not scalar_sync_names:
            # Nothing to pack: tensor meters reduce individually, the rest answer locally.
            return RoundDict(
                {
                    name: (
                        meter._tensor_batch(tensor_templates.get(name)) if name in tensor_sync_names else meter.batch()
                    )
                    for name, meter in items
                }
            )

        # Two slots per scalar meter: [weighted value, weight].
        reduced = torch.zeros(len(scalar_sync_names) * 2, dtype=torch.float64, device=device)
        sync_indices = {name: idx for idx, name in enumerate(scalar_sync_names)}

        for name in scalar_sync_names:
            meter = self[name]
            offset = sync_indices[name] * 2
            if meter._local_n > 0:
                if isinstance(meter._local_value, Tensor):
                    reduced[offset] = meter._local_value.to(device=device, dtype=torch.float64) * meter._local_n
                else:
                    reduced[offset] = float(meter._local_value) * meter._local_n
            reduced[offset + 1] = float(meter._local_n)

        # One collective call covers every scalar meter.
        dist.all_reduce(reduced)

        batches: dict[str, float | Tensor] = {}
        for name, meter in items:
            sync_index = sync_indices.get(name)
            if sync_index is None:
                # Either a tensor meter (own reduction) or a meter that needs no sync.
                batches[name] = (
                    meter._tensor_batch(tensor_templates.get(name)) if name in tensor_sync_names else meter.batch()
                )
                continue

            total, count = reduced[sync_index * 2 : sync_index * 2 + 2].tolist()
            batches[name] = nan if count == 0 else total / count

        return RoundDict(batches)

    def average(self) -> RoundDict[str, float | Tensor]:
        r"""
        Returns the synchronised running average of every meter.

        Follows the same strategy as `batch()`, but contributes each meter's
        local sum and count instead of its latest value and weight.
        """

        items = list(self.all_items())
        sync_names = [name for name, meter in items if meter._current_world_size() > 1]
        if not sync_names:
            return super().average()

        # The first synchronised meter decides the device for the shared buffer.
        device = self[sync_names[0]]._sync_device()
        tensor_templates = self._resolved_tensor_templates(sync_names, device)
        tensor_sync_names = set(tensor_templates)
        scalar_sync_names = [name for name in sync_names if name not in tensor_sync_names]
        if not scalar_sync_names:
            # Nothing to pack: tensor meters reduce individually, the rest answer locally.
            return RoundDict(
                {
                    name: (
                        meter._tensor_average(tensor_templates.get(name))
                        if name in tensor_sync_names
                        else meter.average()
                    )
                    for name, meter in items
                }
            )

        # Two slots per scalar meter: [sum, count].
        reduced = torch.zeros(len(scalar_sync_names) * 2, dtype=torch.float64, device=device)
        sync_indices = {name: idx for idx, name in enumerate(scalar_sync_names)}

        for name in scalar_sync_names:
            meter = self[name]
            offset = sync_indices[name] * 2
            if isinstance(meter._local_sum, Tensor):
                reduced[offset] = meter._local_sum.to(device=device, dtype=torch.float64)
            else:
                reduced[offset] = float(meter._local_sum)
            reduced[offset + 1] = float(meter._local_count)

        # One collective call covers every scalar meter.
        dist.all_reduce(reduced)

        averages: dict[str, float | Tensor] = {}
        for name, meter in items:
            sync_index = sync_indices.get(name)
            if sync_index is None:
                # Either a tensor meter (own reduction) or a meter that needs no sync.
                averages[name] = (
                    meter._tensor_average(tensor_templates.get(name)) if name in tensor_sync_names else meter.average()
                )
                continue

            total, count = reduced[sync_index * 2 : sync_index * 2 + 2].tolist()
            averages[name] = nan if count == 0 else total / count

        return RoundDict(averages)

    def _resolved_tensor_templates(self, sync_names: list[str], device: torch.device) -> dict[str, Tensor]:
        r"""
        Works out which synchronised meters hold tensors and with what shape.

        Args:
            sync_names: Names of the meters that take part in synchronisation.
            device: Device on which the returned templates are allocated.

        Returns:
            A mapping from meter name to a float64 template tensor. Meters for
            which no rank reports a tensor are left out.

        Raises:
            ValueError: If ranks report different shapes or dtypes for the same meter.
        """

        if not sync_names:
            return {}

        local_specs = {name: self[name]._tensor_spec() for name in sync_names}
        if not (dist.is_available() and dist.is_initialized()):
            # Without a process group only the local knowledge is available.
            return {
                name: self[name]._resolve_tensor_template(self[name]._current_world_size(), device)
                for name, spec in local_specs.items()
                if spec is not None
            }

        world_size = self[sync_names[0]]._current_world_size()
        gathered_specs: list[dict[str, tuple[tuple[int, ...], str] | None]] = [{} for _ in range(world_size)]
        # A single gather shares the specs of all meters at once.
        dist.all_gather_object(gathered_specs, local_specs)

        templates: dict[str, Tensor] = {}
        for name in sync_names:
            references: list[tuple[tuple[int, ...], str]] = []
            for spec in gathered_specs:
                reference = spec.get(name)
                if reference is not None:
                    references.append(reference)
            if not references:
                # No rank holds a tensor for this meter, so it is treated as scalar.
                continue
            if any(reference != references[0] for reference in references[1:]):
                raise ValueError(
                    f"AverageMeters received inconsistent tensor metadata for meter {name!r}: {references!r}"
                )

            shape, dtype_name = references[0]
            # ``str(dtype)`` looks like "torch.float64"; strip the prefix to look the dtype up again.
            template = torch.empty(shape, dtype=getattr(torch, dtype_name.removeprefix("torch.")), device=device)
            meter = self[name]
            template_device = meter.device if meter.device is not None else device
            # Store the agreed shape on the meter so ranks without local data know it too.
            meter._tensor_template = template.to(device=template_device, dtype=torch.float64)
            templates[name] = template.to(dtype=torch.float64)
        return templates

    # Mutation
    def update(
        self, *args: Mapping[str, int | float | Tensor], **values: int | float | Tensor
    ) -> None:  # pylint: disable=W0237
        r"""
        Updates the average and current value in all meters.

        Each value is forwarded to the matching meter's `update` with its default weight.
        Keyword values take precedence over entries of the positional mapping.

        Args:
            values: Mapping or keyword values to be added to the corresponding meters.

        Raises:
            ValueError: If more than one positional argument is given.
        """  # noqa: E501

        if args:
            if len(args) > 1:
                raise ValueError("Expected only one positional argument, but got multiple.")
            values = dict(args[0]) | values

        for meter, value in values.items():
            self[meter].update(value)

update

Python
update(
    *args: Mapping[str, int | float | Tensor],
    **values: int | float | Tensor
) -> None

Updates the average and current value in all meters.

Parameters:

Name Type Description Default
values
int | float | Tensor

Mapping or keyword values to be added to the corresponding meters.

{}
Source code in danling/metrics/average_meter.py
Python
def update(
    self, *args: Mapping[str, int | float | Tensor], **values: int | float | Tensor
) -> None:  # pylint: disable=W0237
    r"""
    Updates the average and current value in all meters.

    Args:
        values: Mapping or keyword values to be added to the corresponding meters.
    """  # noqa: E501

    if args:
        if len(args) > 1:
            raise ValueError("Expected only one positional argument, but got multiple.")
        values = dict(args[0]) | values

    for meter, value in values.items():
        self[meter].update(value)

MetricMeter

Bases: AverageMeter

A memory-efficient metric tracker that computes and averages metrics across batches.

MetricMeter applies a metric function to each batch and maintains running averages without storing the complete history of predictions and labels. This makes it ideal for metrics that can be meaningfully averaged across batches (like accuracy or loss).

Attributes:

Name Type Description
metric Callable | MetricFunc

The metric function to compute on each batch

preprocess

Optional preprocessing function applied before the metric

val float | Tensor

Result from the most recent batch on the current rank

bat float | Tensor

Synchronized metric result for the current step

avg float | Tensor

Weighted average of all results so far

sum float | Tensor

Running sum of (metric × batch_size) values

count int

Running sum of batch sizes

Parameters:

Name Type Description Default

metric

Callable | MetricFunc

Function that computes a metric given input and target tensors

required

preprocess

Callable | None

Optional preprocessing function to apply before computing the metric

None

Examples:

Python Console Session
>>> import torch
>>> from danling.metrics.functional import accuracy
>>> meter = MetricMeter(accuracy)
>>> meter.update(torch.tensor([0.1, 0.8, 0.6, 0.2]), torch.tensor([0, 1, 0, 0]))
>>> meter.val
0.75
>>> meter.avg
0.75
>>> meter.update(torch.tensor([0.1, 0.7, 0.3, 0.2, 0.8, 0.4]), torch.tensor([0, 1, 1, 0, 0, 1]))
>>> meter.val
0.5
>>> meter.avg
0.6
>>> meter.sum
6.0
>>> meter.count
10
>>> meter.reset()
MetricMeter(accuracy)
>>> meter.val
nan
>>> meter.avg
nan
Notes
  • MetricMeter is more memory-efficient than GlobalMetrics because it only stores running statistics
  • Only suitable for metrics that can be meaningfully averaged batch-by-batch
  • Not suitable for metrics like AUROC that need the entire dataset
  • Metrics are evaluated once per update; batch-vs-sample semantics are determined by the metric itself
  • Stream metrics may return tensors; tensor outputs are averaged elementwise across batches
  • MetricFunc descriptors receive [MetricState][danling.metrics.MetricState]
  • Plain callables receive preprocessed input / target tensors
  • For multiple metrics, use StreamMetrics
See Also
  • AverageMeter: A lightweight utility to compute and store running averages of values.
Source code in danling/metrics/stream_metrics.py
Python
class MetricMeter(AverageMeter):
    r"""
    A memory-efficient metric tracker that computes and averages metrics across batches.

    MetricMeter applies a metric function to each batch and maintains running averages
    without storing the complete history of predictions and labels. This makes it ideal for
    metrics that can be meaningfully averaged across batches (like accuracy or loss).

    Attributes:
        metric: The metric function to compute on each batch
        preprocess: Optional preprocessing function applied before the metric
        val: Result from the most recent batch on the current rank
        bat: Synchronized metric result for the current step
        avg: Weighted average of all results so far
        sum: Running sum of (metric × batch_size) values
        count: Running sum of batch sizes

    Args:
        metric: Function that computes a metric given input and target tensors
        preprocess: Optional preprocessing function to apply before computing the metric
        device: Optional device, forwarded unchanged to
            [`AverageMeter`][danling.metrics.average_meter.AverageMeter]
        distributed: Distributed flag, forwarded unchanged to
            [`AverageMeter`][danling.metrics.average_meter.AverageMeter]

    Raises:
        ValueError: If `metric` is not callable.

    Examples:
        >>> import torch
        >>> from danling.metrics.functional import accuracy
        >>> meter = MetricMeter(accuracy)
        >>> meter.update(torch.tensor([0.1, 0.8, 0.6, 0.2]), torch.tensor([0, 1, 0, 0]))
        >>> meter.val
        0.75
        >>> meter.avg
        0.75
        >>> meter.update(torch.tensor([0.1, 0.7, 0.3, 0.2, 0.8, 0.4]), torch.tensor([0, 1, 1, 0, 0, 1]))
        >>> meter.val
        0.5
        >>> meter.avg
        0.6
        >>> meter.sum
        6.0
        >>> meter.count
        10
        >>> meter.reset()
        MetricMeter(accuracy)
        >>> meter.val
        nan
        >>> meter.avg
        nan

    Notes:
        - MetricMeter is more memory-efficient than [`GlobalMetrics`][danling.metrics.global_metrics.GlobalMetrics]
          because it only stores running statistics
        - Only suitable for metrics that can be meaningfully averaged batch-by-batch
        - Not suitable for metrics like AUROC that need the entire dataset
        - Metrics are evaluated once per update; batch-vs-sample semantics are determined by the metric itself
        - Stream metrics may return tensors; tensor outputs are averaged elementwise across batches
        - `MetricFunc` descriptors receive [`MetricState`][danling.metrics.MetricState]
        - Plain callables receive preprocessed `input` / `target` tensors
        - For multiple metrics, use [`StreamMetrics`][danling.metrics.stream_metrics.StreamMetrics]

    See Also:
        - [`AverageMeter`][danling.metrics.average_meter.AverageMeter]:
            A lightweight utility to compute and store running averages of values.
    """

    metric: Callable | MetricFunc
    output_name: str | None = None

    # Construction
    def __init__(
        self,
        metric: Callable | MetricFunc,
        *,
        preprocess: Callable | None = None,
        device: torch.device | str | None = None,
        distributed: bool = True,
    ) -> None:
        super().__init__(device=device, distributed=distributed)
        if not callable(metric):
            raise ValueError(f"Expected metric to be callable, but got {type(metric)}")
        self.metric = metric
        self.preprocess = preprocess
        # Only MetricFunc descriptors declare requirements; plain callables get ``None``.
        self._requirements = MetricState.collect_requirements((metric,)) if isinstance(metric, MetricFunc) else None

    # Mutation
    def update(  # type: ignore[override] # pylint: disable=W0237
        self,
        input: Tensor | NestedTensor,  # pylint: disable=W0622
        target: Tensor | NestedTensor,
        *,
        n: int | None = None,
    ) -> None:
        r"""
        Updates the average and current value in the meter.

        Args:
            input: Prediction tensor or nested tensor.
            target: Ground-truth tensor or nested tensor.
            n: Optional number of samples represented by this update. When omitted,
                the batch size is inferred from the inputs.
        """

        if self.preprocess is not None:
            input, target = self.preprocess(input, target)

        self._update_state(self._build_state(input, target), n=n)

    # Internal helpers
    def _update_state(self, state: MetricState, *, n: int | None = None) -> None:
        r"""
        Computes the metric for `state` and folds it into the running statistics.

        Args:
            state: Metric state built for the current batch.
            n: Weight of this update. When omitted, `len(state.preds)` is used,
                falling back to 1 if the predictions have no length.
        """
        if n is None:
            try:
                n = len(state.preds)
            except TypeError:
                # Predictions without a length (``len`` raised TypeError) count as one sample.
                n = 1

        value = self._compute_metric(state)
        super().update(value=value, n=n)

    def _compute_metric(self, state: MetricState) -> Tensor | float:
        r"""
        Evaluates the metric on `state` and normalizes the result.

        `MetricFunc` descriptors are called with the state itself; any other callable is
        called with `state.preds` and `state.targets`.
        """
        if isinstance(self.metric, MetricFunc):
            return self._normalize_value(self.metric(state))
        return self._normalize_value(self.metric(state.preds, state.targets))

    def _build_state(
        self,
        input: Tensor | NestedTensor | Sequence,
        target: Tensor | NestedTensor | Sequence,
    ) -> MetricState:
        r"""Builds the `MetricState` for one batch from this meter's requirements."""
        return MetricState.from_requirements(input, target, self._requirements)

    @staticmethod
    def _normalize_value(value: Tensor | float | int) -> Tensor | float:
        r"""
        Converts a raw metric output into the form stored by the meter.

        Empty tensors become a NaN tensor, single-element tensors are unwrapped with
        `item()`, larger tensors are detached, and non-tensor values are converted with `float()`.
        """
        if isinstance(value, Tensor):
            if value.numel() == 0:
                return torch.tensor(float("nan"))
            if value.numel() == 1:
                return value.item()
            return value.detach()
        return float(value)

    def __repr__(self):
        metric = self.metric
        if isinstance(metric, MetricFunc):
            return f"{self.__class__.__name__}({metric.name})"
        if isinstance(metric, partial):
            metric = metric.func
        # Callable objects (for example ``torch.nn.Module`` instances) have no ``__name__``;
        # fall back to their type name so that ``repr`` never raises.
        name = getattr(metric, "__name__", None) or type(metric).__name__
        return f"{self.__class__.__name__}({name})"

update

Python
update(
    input: Tensor | NestedTensor,
    target: Tensor | NestedTensor,
    *,
    n: int | None = None
) -> None

Updates the average and current value in the meter.

Parameters:

Name Type Description Default
input
Tensor | NestedTensor

Prediction tensor or nested tensor.

required
target
Tensor | NestedTensor

Ground-truth tensor or nested tensor.

required
n
int | None

Optional number of samples represented by this update. When omitted, the batch size is inferred from the inputs.

None
Source code in danling/metrics/stream_metrics.py
Python
def update(  # type: ignore[override] # pylint: disable=W0237
    self,
    input: Tensor | NestedTensor,  # pylint: disable=W0622
    target: Tensor | NestedTensor,
    *,
    n: int | None = None,
) -> None:
    r"""
    Feeds one batch into the meter, refreshing its current value and running average.

    Args:
        input: Prediction tensor or nested tensor.
        target: Ground-truth tensor or nested tensor.
        n: Optional number of samples this update stands for. If left as `None`,
            the batch size is inferred from the inputs.
    """

    transform = self.preprocess
    if transform is not None:
        # The optional preprocessing hook returns the (input, target) pair to use.
        input, target = transform(input, target)

    batch_state = self._build_state(input, target)
    self._update_state(batch_state, n=n)

StreamMetrics

Bases: AverageMeters

A container for managing multiple MetricMeter instances with shared preprocessing.

StreamMetrics allows you to organize and track multiple metrics in a unified interface, with consistent preprocessing applied to all inputs before computing each metric. This is particularly useful when you want to track several metrics that can be meaningfully averaged across batches.

Attributes:

Name Type Description
preprocess

Shared preprocessing function for all meters

val RoundDict[str, float | Tensor]

Dictionary of current local values from all meters

bat RoundDict[str, float | Tensor]

Dictionary of synchronized current-step values from all meters

avg RoundDict[str, float | Tensor]

Dictionary of running averages from all meters

sum RoundDict[str, float | Tensor]

Dictionary of sums from all meters

count RoundDict[str, int]

Dictionary of counts from all meters

Parameters:

Name Type Description Default

*args

Metric functions to register as meters

required

preprocess

Callable

Preprocessing function to apply to inputs before computing metrics

base_preprocess

**meters

Named MetricMeter instances or metric functions

{}

Examples:

Python Console Session
>>> import torch
>>> from danling.metrics.functional import accuracy
>>> meters = StreamMetrics(acc=accuracy)
>>> meters.update([0.1, 0.8, 0.6, 0.2], [0, 1, 0, 0])
>>> round(meters.val["acc"], 4)
0.75
>>> round(meters.avg["acc"], 4)
0.75
>>> meters["acc"].update(torch.tensor([0.2, 0.8]), torch.tensor([0, 1]))
>>> meters.count["acc"]
6
>>> meters.update(dict(loss=""))
Traceback (most recent call last):
TypeError: ...update() missing 1 required positional argument: 'target'
Notes
  • StreamMetrics manages multiple MetricMeter instances with shared preprocessing
  • Each metric is computed independently but uses the same inputs
  • All meters are updated simultaneously when you call update()
  • Individual meters can be accessed like dictionary items or attributes
  • Metrics are evaluated once per update; batch-vs-sample semantics are determined by the metric itself
  • Tensor-valued metrics are preserved and averaged elementwise across batches
  • Built-in MetricFunc stream values may be approximate rather than exact dataset-level metrics
See Also
  • AverageMeters: A container for managing multiple average meters in one object.
  • GlobalMetrics: Metric tracker that stores the complete prediction and target history.
Source code in danling/metrics/stream_metrics.py
Python
class StreamMetrics(AverageMeters):
    r"""
    A collection of `MetricMeter` instances that share one preprocessing step.

    StreamMetrics groups several batch-averaged metrics behind a single interface.
    Every call to `update()` preprocesses the inputs once and then feeds the same
    inputs to each registered meter, which is convenient when several metrics that
    can be meaningfully averaged across batches have to be tracked together.

    Attributes:
        preprocess: Shared preprocessing function for all meters
        val: Dictionary of current local values from all meters
        bat: Dictionary of synchronized current-step values from all meters
        avg: Dictionary of running averages from all meters
        sum: Dictionary of sums from all meters
        count: Dictionary of counts from all meters

    Args:
        *args: Metric functions to register as meters
        preprocess: Preprocessing function to apply to inputs before computing metrics
        **meters: Named MetricMeter instances or metric functions

    Examples:
        >>> import torch
        >>> from danling.metrics.functional import accuracy
        >>> meters = StreamMetrics(acc=accuracy)
        >>> meters.update([0.1, 0.8, 0.6, 0.2], [0, 1, 0, 0])
        >>> round(meters.val["acc"], 4)
        0.75
        >>> round(meters.avg["acc"], 4)
        0.75
        >>> meters["acc"].update(torch.tensor([0.2, 0.8]), torch.tensor([0, 1]))
        >>> meters.count["acc"]
        6
        >>> meters.update(dict(loss=""))  # doctest: +ELLIPSIS
        Traceback (most recent call last):
        TypeError: ...update() missing 1 required positional argument: 'target'

    Notes:
        - `StreamMetrics` manages multiple `MetricMeter` instances with shared preprocessing
        - Each metric is computed independently but uses the same inputs
        - All meters are updated simultaneously when you call `update()`
        - Individual meters can be accessed like dictionary items or attributes
        - Metrics are evaluated once per update; batch-vs-sample semantics are determined by the metric itself
        - Tensor-valued metrics are preserved and averaged elementwise across batches
        - Built-in `MetricFunc` stream values may be approximate rather than exact dataset-level metrics

    See Also:
        - [`AverageMeters`][danling.metrics.average_meter.AverageMeters]:
            A container for managing multiple average meters in one object.
        - [`GlobalMetrics`][danling.metrics.global_metrics.GlobalMetrics]:
            Metric tracker that stores the complete prediction and target history.
    """

    preprocess = base_preprocess
    meter_cls = MetricMeter  # type: ignore[assignment]

    # Construction
    def __init__(
        self,
        *metric_funcs,
        preprocess: Callable = base_preprocess,
        distributed: bool = True,
        device: torch.device | str | None = None,
        **meters,
    ) -> None:
        # While this flag is set, ``set`` skips the per-meter requirements refresh;
        # a single refresh happens once construction is complete.
        self.setattr("_initializing_meters", True)
        self.setattr("_requirements", None)

        unnamed: list[tuple[str, Callable | MetricMeter]] = []
        for func in iter_metric_funcs(metric_funcs):
            if not callable(func):
                raise ValueError(f"Expected metric to be callable, but got {type(func)}")
            unnamed.append((infer_metric_name(func), func))

        for candidate in meters.values():
            if not (isinstance(candidate, MetricMeter) or callable(candidate)):
                raise ValueError(f"Expected metric to be callable or MetricMeter, but got {type(candidate)}")
        keyed: dict[str, Callable | MetricMeter] = dict(meters)

        merged = merge_metric_entries(unnamed, keyed)
        self.setattr("preprocess", preprocess)
        self.setattr("distributed", distributed)
        resolved_device = None if device is None else torch.device(device)
        self.setattr("device", resolved_device)
        try:
            super().__init__(**merged)
        finally:
            self.setattr("_initializing_meters", False)
        self._refresh_requirements()

    # Meter registration
    def _coerce_metric(self, value: Callable | MetricFunc | MetricMeter) -> MetricMeter:
        r"""
        Turns `value` into a meter owned by this container.

        Existing meters are reconfigured in place (preprocessing disabled, device and
        distributed flag aligned with the container, requirements recomputed); other
        callables are wrapped in a new meter. Anything else raises `ValueError`.
        """
        meter_type: type[MetricMeter] = getattr(type(self), "meter_cls", MetricMeter)
        if not isinstance(value, meter_type):
            if callable(value):
                return meter_type(value, preprocess=None, device=self.device, distributed=self.distributed)
            raise ValueError(f"Expected meter to be an instance of {meter_type.__name__}, but got {type(value)}")

        # The container preprocesses inputs itself, so the meter must not do it again.
        value.preprocess = None
        if self.device is not None:
            value.device = self.device
        value.distributed = self.distributed
        if isinstance(value.metric, MetricFunc):
            value._requirements = MetricState.collect_requirements((value.metric,))
        else:
            value._requirements = None
        return value

    def _coerce_meter(self, value):  # type: ignore[override]
        r"""Delegates to `_coerce_metric`."""
        return self._coerce_metric(value)

    def set(self, name, value) -> None:  # type: ignore[override]
        r"""Stores `value` under `name` and refreshes the shared requirements unless still initializing."""
        super().set(name, value)
        if self.getattr("_initializing_meters", False):
            return
        self._refresh_requirements()

    # Mutation
    def update(  # type: ignore[override] # pylint: disable=W0221
        self,
        input: Tensor | NestedTensor | Sequence,  # pylint: disable=W0622
        target: Tensor | NestedTensor | Sequence,
        *,
        n: int | None = None,
    ) -> None:
        r"""
        Feeds one batch to every meter in the container.

        Args:
            input: Input values to compute the metrics.
            target: Target values to compute the metrics.
            n: Optional number of samples represented by this update. Defaults to
                the inferred batch size.
        """

        input, target = self.preprocess(input, target)  # type: ignore[arg-type]
        tensor_like = (Tensor, NestedTensor)
        input = input.detach() if isinstance(input, tensor_like) else input
        target = target.detach() if isinstance(target, tensor_like) else target
        # One state is built per update and handed to every MetricMeter.
        shared_state = self._build_state(input, target)
        for meter in self.values():
            if not isinstance(meter, MetricMeter):
                meter.update(input, target, n=n)
                continue
            meter._update_state(shared_state, n=n)

    # Internal helpers
    def _collect_requirements_from_meters(self):
        r"""Collects requirements of all `MetricFunc`-backed meters, or `None` when there are none."""
        descriptors = [
            meter.metric
            for meter in self.values()
            if isinstance(meter, MetricMeter) and isinstance(meter.metric, MetricFunc)
        ]
        return MetricState.collect_requirements(descriptors) if descriptors else None

    def _refresh_requirements(self) -> None:
        r"""Recomputes and stores the requirements shared by all meters."""
        self.setattr("_requirements", self._collect_requirements_from_meters())

    def _build_state(
        self,
        input: Tensor | NestedTensor | Sequence,
        target: Tensor | NestedTensor | Sequence,
    ) -> MetricState:
        r"""Builds the `MetricState` for one batch from the container-wide requirements."""
        return MetricState.from_requirements(input, target, self._requirements)

    # Formatting helpers
    def __repr__(self):
        names = tuple(self.keys())
        return f"{self.__class__.__name__}{names}"

update

Python
update(
    input: Tensor | NestedTensor | Sequence,
    target: Tensor | NestedTensor | Sequence,
    *,
    n: int | None = None
) -> None

Updates the average and current value in all meters.

Parameters:

Name Type Description Default
input
Tensor | NestedTensor | Sequence

Input values to compute the metrics.

required
target
Tensor | NestedTensor | Sequence

Target values to compute the metrics.

required
n
int | None

Optional number of samples represented by this update. Defaults to the inferred batch size.

None
Source code in danling/metrics/stream_metrics.py
Python
def update(  # type: ignore[override] # pylint: disable=W0221
    self,
    input: Tensor | NestedTensor | Sequence,  # pylint: disable=W0622
    target: Tensor | NestedTensor | Sequence,
    *,
    n: int | None = None,
) -> None:
    r"""
    Feeds one batch to every meter in the container.

    Args:
        input: Input values to compute the metrics.
        target: Target values to compute the metrics.
        n: Optional number of samples represented by this update. Defaults to
            the inferred batch size.
    """

    input, target = self.preprocess(input, target)  # type: ignore[arg-type]
    tensor_like = (Tensor, NestedTensor)
    input = input.detach() if isinstance(input, tensor_like) else input
    target = target.detach() if isinstance(target, tensor_like) else target
    # One state is built per update and handed to every MetricMeter.
    shared_state = self._build_state(input, target)
    for meter in self.values():
        if not isinstance(meter, MetricMeter):
            meter.update(input, target, n=n)
            continue
        meter._update_state(shared_state, n=n)

LRScheduler

Bases: _LRScheduler

General learning rate scheduler.

PyTorch LRScheduler is hard to extend. This class is a wrapper of PyTorch LRScheduler, which provides a more general interface. You only need to add a new scaling method that computes a learning rate ratio (ranging from 0 to 1) from the total progress (ranging from 0 to 1), and everything else will be done automatically.

Moreover, this class has warmup and cooldown built-in. By default, the first 5% and last 20% of training steps will be warmup and cooldown respectively. You can change these by passing warmup_steps and cooldown_steps, or disable them by setting them to 0.

Parameters:

Name Type Description Default

optimizer

Optimizer

Wrapped optimizer.

required

total_steps

int

Total number of trainable steps.

required

final_lr_ratio

Optional[float]

Final learning rate ratio to initial learning rate. Defaults to 1e-3.

None

final_lr

Optional[float]

Final learning rate.

None

min_lr

float

Minimal learning rate. Defaults to 1e-9.

1e-09

method

str

Scaling method. Defaults to “cosine”.

'cosine'

warmup_steps

Optional[int]

Number of warmup steps. Defaults to total_steps // 20.

None

cooldown_steps

Optional[int]

Number of cooldown steps. Defaults to total_steps // 5.

None

last_epoch

int

The index of last epoch. Defaults to -1.

-1

scaling

Optional[str]

Method to calculate learning rate given ratio, should be one of “percentile” or “numerical”. Defaults to “percentile” if final_lr_ratio is set, otherwise “numerical”.

None

Examples:

Python Console Session
>>> from danling.optim import LRScheduler
>>> import torch
>>> from torch import optim
>>> optimizer = optim.SGD([{'params': torch.tensor([0])}], lr=1, momentum=0.9)
>>> scheduler = LRScheduler(optimizer, total_steps=5, final_lr_ratio=1e-5, method='linear')
>>> lrs = []
>>> for epoch in range(5):
...     lrs.append(scheduler.get_lr()[0])
...     scheduler.step()
>>> [round(lr, 10) for lr in lrs]
[0.1, 0.01, 0.001, 0.0001, 1e-09]
>>> scheduler = LRScheduler(optimizer, total_steps=5, final_lr_ratio=1e-5, method='cosine')
>>> lrs = []
>>> for epoch in range(5):
...     lrs.append(scheduler.get_lr()[0])
...     scheduler.step()
>>> [round(lr, 10) for lr in lrs]
[0.3330753446, 0.0187302031, 0.000533897, 3.00232e-05, 1e-09]
>>> scheduler = LRScheduler(optimizer, total_steps=5, final_lr_ratio=1e-5, method='linear', scaling='numerical')
>>> lrs = []
>>> for epoch in range(5):
...     lrs.append(scheduler.get_lr()[0])
...     scheduler.step()
>>> [round(lr, 2) for lr in lrs]
[0.8, 0.6, 0.4, 0.2, 0.0]
Source code in danling/optim/lr_scheduler/lr_scheduler.py
Python
class LRScheduler(lr_scheduler._LRScheduler):  # pylint: disable=protected-access
    r"""
    General learning rate scheduler.

    PyTorch LRScheduler is hard to extend.
    This class is a wrapper of PyTorch LRScheduler, which provides a more general interface.
    You only need to add a new scaling method that computes a learning rate ratio (ranging from 0 to 1)
    from the total progress (ranging from 0 to 1), and everything else will be done automatically.
    Scaling methods are the public methods defined on this class and on its subclasses.

    Moreover, this class has warmup and cooldown built-in.
    By default, the first 5% and last 20% of training steps will be warmup and cooldown respectively.
    You can change these by passing `warmup_steps` and `cooldown_steps`, or disable them by setting them to 0.

    Args:
        optimizer: Wrapped optimizer.
        total_steps: Total number of trainable steps.
        final_lr_ratio: Final learning rate ratio to initial learning rate.
            Defaults to 1e-3 when neither `final_lr_ratio` nor `final_lr` is given.
        final_lr: Final learning rate.
        min_lr: Minimal learning rate.
            Defaults to 1e-9.
        method: Scaling method.
            Defaults to "cosine".
        warmup_steps: Number of warmup steps.
            Defaults to `total_steps // 20`.
        cooldown_steps: Number of cooldown steps.
            Defaults to `total_steps // 5`.
        last_epoch: The index of last epoch.
            Defaults to -1.
        scaling: Method to calculate learning rate given ratio, should be one of "percentile" or "numerical".
            Defaults to "percentile" if `final_lr_ratio` is set, otherwise "numerical".

    Raises:
        ValueError: If a step count, learning rate bound, or `method` is invalid,
            or if both `final_lr_ratio` and `final_lr` are given.

    Examples:
        >>> from danling.optim import LRScheduler
        >>> import torch
        >>> from torch import optim
        >>> optimizer = optim.SGD([{'params': torch.tensor([0])}], lr=1, momentum=0.9)
        >>> scheduler = LRScheduler(optimizer, total_steps=5, final_lr_ratio=1e-5, method='linear')
        >>> lrs = []
        >>> for epoch in range(5):
        ...     lrs.append(scheduler.get_lr()[0])
        ...     scheduler.step()
        >>> [round(lr, 10) for lr in lrs]
        [0.1, 0.01, 0.001, 0.0001, 1e-09]
        >>> scheduler = LRScheduler(optimizer, total_steps=5, final_lr_ratio=1e-5, method='cosine')
        >>> lrs = []
        >>> for epoch in range(5):
        ...     lrs.append(scheduler.get_lr()[0])
        ...     scheduler.step()
        >>> [round(lr, 10) for lr in lrs]
        [0.3330753446, 0.0187302031, 0.000533897, 3.00232e-05, 1e-09]
        >>> scheduler = LRScheduler(optimizer, total_steps=5, final_lr_ratio=1e-5, method='linear', scaling='numerical')
        >>> lrs = []
        >>> for epoch in range(5):
        ...     lrs.append(scheduler.get_lr()[0])
        ...     scheduler.step()
        >>> [round(lr, 2) for lr in lrs]
        [0.8, 0.6, 0.4, 0.2, 0.0]
    """  # noqa: E501

    def __init__(
        self,
        optimizer: Optimizer,
        total_steps: int,
        final_lr_ratio: Optional[float] = None,
        final_lr: Optional[float] = None,
        min_lr: float = 1e-9,
        method: str = "cosine",
        warmup_steps: Optional[int] = None,
        cooldown_steps: Optional[int] = None,
        last_epoch: int = -1,
        scaling: Optional[str] = None,
    ):
        if total_steps <= 0:
            raise ValueError(f"Total steps must be positive, but got {total_steps}")
        if warmup_steps is None:
            warmup_steps = total_steps // 20
        elif warmup_steps > total_steps:
            raise ValueError(f"Warmup steps must be less than total steps, but got {warmup_steps} > {total_steps}")
        elif warmup_steps < 0:
            raise ValueError(f"Warmup steps must be positive, but got {warmup_steps}")
        if cooldown_steps is None:
            cooldown_steps = total_steps // 5
        elif cooldown_steps > total_steps:
            raise ValueError(f"Cooldown steps must be less than total steps, but got {cooldown_steps} > {total_steps}")
        elif cooldown_steps < 0:
            raise ValueError(f"Cooldown steps must be positive, but got {cooldown_steps}")
        if warmup_steps + cooldown_steps > total_steps:
            raise ValueError(
                "Warmup steps + cooldown steps must be less than total steps, "
                f"but got {warmup_steps} + {cooldown_steps} > {total_steps}"
            )
        if final_lr_ratio is not None:
            if final_lr is not None:
                raise ValueError("Only one of `final_lr_ratio` and `final_lr` should be set, but not both")
            if final_lr_ratio < 0:
                raise ValueError(f"`final_lr_ratio` must be positive, but got {final_lr_ratio}")
            if scaling is None:
                scaling = "percentile"
        if final_lr is not None and final_lr < 0:
            raise ValueError(f"`final_lr` must be positive, but got {final_lr}")
        if min_lr < 0:
            raise ValueError(f"`min_lr` must be positive, but got {min_lr}")

        # Scaling strategies are the public callables declared on this class and its
        # subclasses. Walking the MRO keeps the built-in strategies available to
        # subclasses; classes belonging to PyTorch's scheduler are skipped so that
        # methods such as ``step`` are never offered as scaling methods. ``get_lr`` is
        # part of the scheduler protocol, not a progress -> ratio function, so it is
        # excluded explicitly.
        torch_bases = set(lr_scheduler._LRScheduler.__mro__)  # pylint: disable=protected-access
        strategies = {}
        for klass in reversed(type(self).__mro__):  # base classes first so subclasses override
            if klass in torch_bases:
                continue
            for name, attr in vars(klass).items():
                if callable(attr) and not name.startswith("_") and name != "get_lr":
                    strategies[name] = attr
        self.strategies = strategies
        if method not in self.strategies:
            raise ValueError(f"Scaling method must be one of {self.strategies.keys()}, but got {method}")

        if final_lr_ratio is None and final_lr is None:
            final_lr_ratio = 1e-3
            if scaling is None:
                scaling = "percentile"
        if final_lr is not None and min_lr > final_lr:
            # The lower clamp must never sit above the requested final learning rate.
            min_lr = final_lr
        if scaling is None:
            scaling = "numerical"

        self.final_lr_ratio = final_lr_ratio
        self.final_lr = final_lr
        self.total_steps = total_steps
        self.min_lr = min_lr
        self.method = method
        self.scaling = scaling
        self.warmup_steps = warmup_steps
        self.cooldown_steps = cooldown_steps
        self.cooldown_steps_begin = self.total_steps - self.cooldown_steps
        # The attributes above must be set before the base initializer runs,
        # because ``get_lr`` reads them.
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        r"""
        Computes the learning rate of every base learning rate at the current step.

        Emits a `RuntimeWarning` when the step count lies outside `[1, total_steps + 1]`.
        """
        step_count = self._step_count
        if step_count > self.total_steps + 1 or step_count < 1:
            warn(
                f"Step count {step_count} is out of range [1, {self.total_steps + 1}]",
                category=RuntimeWarning,
                stacklevel=2,
            )
        return [self._get_lr(lr, step_count) for lr in self.base_lrs]

    def _get_lr(
        self,
        lr: float,
        step_count: Optional[int] = None,
        progress: Optional[float] = None,
        warmup_ratio: Optional[float] = None,
        cooldown_ratio: Optional[float] = None,
        scaling: Optional[str] = None,
    ) -> float:
        r"""
        Computes the scheduled learning rate for a single base learning rate.

        Args:
            lr: Base learning rate.
            step_count: Step to evaluate. Defaults to the scheduler's current step count.
            progress: Training progress in `[0, 1]`.
                Defaults to `step_count / total_steps` clamped to `[0, 1]`.
            warmup_ratio: Interpolation factor used during warmup.
                Defaults to `step_count / warmup_steps`.
            cooldown_ratio: Interpolation factor used during cooldown.
                Defaults to the remaining fraction of the cooldown phase.
            scaling: "percentile" or "numerical". Defaults to the scheduler's `scaling`.

        Returns:
            The learning rate, never lower than `min_lr`.

        Raises:
            ValueError: If `scaling` is neither "percentile" nor "numerical".
        """
        scaling = scaling or self.scaling
        # Use explicit ``None`` checks: 0 and 0.0 are legitimate values for these arguments.
        if step_count is None:
            step_count = self._step_count
        if progress is None:
            progress = min(max(step_count / self.total_steps, 0.0), 1.0)
        final_lr = self.final_lr if self.final_lr is not None else lr * self.final_lr_ratio  # type: ignore[operator]
        ratio = getattr(self, self.method)(progress)
        if scaling == "percentile":
            # Geometric interpolation between the base and the final learning rate.
            lr *= pow(final_lr / lr, ratio)
        elif scaling == "numerical":
            # Linear interpolation between the base and the final learning rate.
            lr = (1 - ratio) * (lr - final_lr) + final_lr
        else:
            raise ValueError(f"Method must be one of ['percentile', 'numerical'], but got {scaling}")
        if self.warmup_steps > step_count > 0:
            if warmup_ratio is None:
                warmup_ratio = step_count / self.warmup_steps
            lr = warmup_ratio * (lr - self.min_lr) + self.min_lr
        elif self.cooldown_steps > 0 and step_count > self.cooldown_steps_begin:
            if cooldown_ratio is None:
                cooldown_ratio = 1 - (step_count - self.cooldown_steps_begin) / self.cooldown_steps
            lr = cooldown_ratio * (lr - self.min_lr) + self.min_lr
        return max(self.min_lr, lr)

    def linear(self, progress: float) -> float:
        r"""Scaling ratio that grows linearly with progress."""
        return progress

    def cosine(self, progress: float) -> float:
        r"""Scaling ratio that follows a half cosine from 0 to 1."""
        return 1 - ((1 + cos(pi * progress)) / 2)

    def constant(self, progress: float) -> float:  # pylint: disable=unused-argument
        r"""Scaling ratio that is always 0."""
        return 0.0

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}({self.method}, scaling={self.scaling}, "
            f"final_lr_ratio={self.final_lr_ratio}, total_steps={self.total_steps}, "
            f"warmup_steps={self.warmup_steps}, cooldown_steps={self.cooldown_steps})"
        )

BaseRunner

Backend-agnostic runner state and orchestration utilities.

BaseRunner intentionally keeps only the shared runtime contract used by concrete runners such as TorchRunner:

  • configuration and process lifecycle bootstrap
  • datasets/dataloaders/result containers
  • checkpoint/result persistence helpers
  • progress and score bookkeeping

Concrete runners are expected to customize runtime behavior through the explicit training/checkpoint hooks below, not by overriding bootstrap internals.

Construction lifecycle:

  1. Normalize config and create RunnerState.
  2. Bind workspace, containers, default FileCheckpointManager, and supervisor.
  3. Call early service hooks in order: init_distributed, init_checkpoint_manager, init_fault_tolerance, init_garbage_collection.
  4. Apply seed/determinism policy.
  5. Initialize logging, TensorBoard/W&B, print routing, signal handlers, and heartbeat.
  6. MetaRunner calls __post_init__. Concrete runners such as TorchRunner materialize models, optimizers, schedulers, and resume checkpoints there before delegating back to BaseRunner.__post_init__ for metadata persistence.

Override rule: early hooks run while the runner is only partially constructed; model/runtime hooks run in concrete __post_init__; loop hooks (train_step, evaluate_step, infer_step) run after all runtime components are bound.

Attributes:

Name Type Description
state RunnerState

Checkpointable aggregate state object.

config RunnerConfig

Runner configuration.

train_state RunnerTrainState

Training progress counters.

elastic_state RunnerElasticState

Torchelastic restart metadata.

rng_state RunnerRNGState

Python/NumPy/Torch RNG snapshots.

datasets FlatDict

Dataset mapping keyed by split.

dataloaders FlatDict

Dataloader mapping keyed by split.

checkpoint_manager CheckpointManager

Active checkpoint backend manager.

workspace RunnerWorkspace

Workspace, logging, metadata, and print-routing helper.

supervisor RunnerSupervisor

Signal, heartbeat, and garbage-collection helper.

ft FaultTolerance | None

Optional fault-tolerance runtime handle.

Source code in danling/runners/base_runner.py
Python
  60
  61
  62
  63
  64
  65
  66
  67
  68
  69
  70
  71
  72
  73
  74
  75
  76
  77
  78
  79
  80
  81
  82
  83
  84
  85
  86
  87
  88
  89
  90
  91
  92
  93
  94
  95
  96
  97
  98
  99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 220
 221
 222
 223
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 238
 239
 240
 241
 242
 243
 244
 245
 246
 247
 248
 249
 250
 251
 252
 253
 254
 255
 256
 257
 258
 259
 260
 261
 262
 263
 264
 265
 266
 267
 268
 269
 270
 271
 272
 273
 274
 275
 276
 277
 278
 279
 280
 281
 282
 283
 284
 285
 286
 287
 288
 289
 290
 291
 292
 293
 294
 295
 296
 297
 298
 299
 300
 301
 302
 303
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 332
 333
 334
 335
 336
 337
 338
 339
 340
 341
 342
 343
 344
 345
 346
 347
 348
 349
 350
 351
 352
 353
 354
 355
 356
 357
 358
 359
 360
 361
 362
 363
 364
 365
 366
 367
 368
 369
 370
 371
 372
 373
 374
 375
 376
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
class BaseRunner(metaclass=MetaRunner):
    """
    Backend-agnostic runner state and orchestration utilities.

    `BaseRunner` intentionally keeps only the shared runtime contract used by
    concrete runners such as `TorchRunner`:

    - configuration and process lifecycle bootstrap
    - datasets/dataloaders/result containers
    - checkpoint/result persistence helpers
    - progress and score bookkeeping

    Concrete runners are expected to customize runtime behavior through the
    explicit training/checkpoint hooks below, not by overriding bootstrap
    internals.

    **Construction lifecycle:**

    1. Normalize config and create `RunnerState`.
    2. Bind workspace, containers, default `FileCheckpointManager`, and
       supervisor.
    3. Call early service hooks in order: `init_distributed`,
       `init_checkpoint_manager`, `init_fault_tolerance`,
       `init_garbage_collection`.
    4. Apply seed/determinism policy.
    5. Initialize logging, TensorBoard/W&B, print routing, signal handlers, and
       heartbeat.
    6. `MetaRunner` calls `__post_init__`. Concrete runners such as
       `TorchRunner` materialize models, optimizers, schedulers, and resume
       checkpoints there before delegating back to `BaseRunner.__post_init__`
       for metadata persistence.

    **Override rule:** early hooks run while the runner is only partially
    constructed; model/runtime hooks run in concrete `__post_init__`; loop
    hooks (`train_step`, `evaluate_step`, `infer_step`) run after all runtime
    components are bound.

    Attributes:
        state: Checkpointable aggregate state object.
        config: Runner configuration.
        train_state: Training progress counters.
        elastic_state: Torchelastic restart metadata.
        rng_state: Python/NumPy/Torch RNG snapshots.
        datasets: Dataset mapping keyed by split.
        dataloaders: Dataloader mapping keyed by split.
        checkpoint_manager: Active checkpoint backend manager.
        workspace: Workspace, logging, metadata, and print-routing helper.
        supervisor: Signal, heartbeat, and garbage-collection helper.
        ft: Optional fault-tolerance runtime handle.
    """

    state: RunnerState
    config: RunnerConfig
    train_state: RunnerTrainState
    elastic_state: RunnerElasticState
    rng_state: RunnerRNGState

    model: Any | None = None
    ema: Any | None = None
    criterion: Callable | None = None
    optimizer: Any | None = None
    scheduler: Any | None = None

    datasets: FlatDict
    dataloaders: FlatDict
    split: str | None = None

    results: RoundDict
    meters: AverageMeters
    metrics: Any | None = None
    train_metrics: Any | None = None
    evaluate_metrics: Any | None = None

    logger: logging.Logger | None = None
    writer: Any | None = None
    wandb: Any | None = None

    checkpoint_manager: CheckpointManager
    workspace: RunnerWorkspace
    supervisor: RunnerSupervisor
    ft: FaultTolerance | None

    timestamp: str
    _print_process: int

    def __init__(self, config: RunnerConfig | Mapping[str, Any]) -> None:
        """Bootstrap the backend-agnostic part of the runner.

        Args:
            config: A `RunnerConfig`, or a mapping that is converted into one.

        The statement order below is the construction lifecycle: state and
        containers first, then the early service hooks, then seeding and
        determinism, then logging/reporting, and finally signal handlers and
        the heartbeat.
        """
        # Accept plain mappings by normalizing them into a RunnerConfig.
        if not isinstance(config, RunnerConfig):
            config = RunnerConfig(config)

        # The aggregate state owns the config and the sub-states; the
        # attributes below are aliases into that single object.
        state = RunnerState(config=config)
        self.state = state
        self.config = state.config
        self.train_state = state.train
        self.elastic_state = state.elastic
        self.rng_state = state.rng

        self.timestamp = get_time_str()
        self.workspace = RunnerWorkspace(self)
        # An explicit `name` in the config wins; otherwise it is derived from
        # the workspace lineage and experiment.
        self.name = str(self.config.get("name", f"{self.workspace.lineage}-{self.workspace.experiment}"))
        self.datasets = FlatDict()
        self.dataloaders = DataLoaderDict()
        self.results = RoundDict()
        self.meters = AverageMeters()
        self.mode = RunnerMode.train
        # Default checkpoint backend; `init_checkpoint_manager` may replace it.
        self.checkpoint_manager = FileCheckpointManager(self)
        self.supervisor = RunnerSupervisor(self)
        self.ft = None

        # Early service hooks: these run while the runner is only partially
        # constructed (no seeding, logging, or signal handling yet).
        self.init_distributed()
        self.init_checkpoint_manager()
        self.init_fault_tolerance()
        self.init_garbage_collection()

        if self.config.seed is not None:
            self.set_seed()

        if self.config.deterministic:
            self.set_deterministic()

        if self.config.log:
            self.workspace.init_logging()

        if self.config.tensorboard:
            self.init_tensorboard()
        if self.config.get("wandb.enabled", False):
            self.init_wandb()

        self.workspace.init_print()
        self.init_signal_handlers()
        self.init_heartbeat()

    @property
    def world_size(self) -> int:
        """Distributed world size from environment."""

        return int(os.getenv("WORLD_SIZE", "1"))

    @property
    def rank(self) -> int:
        """Global rank from environment."""

        return int(os.getenv("RANK", "0"))

    @property
    def local_rank(self) -> int:
        """Local rank from environment."""

        return int(os.getenv("LOCAL_RANK", "0"))

    @property
    def distributed(self) -> bool:
        """Whether distributed mode is active."""

        return self.world_size > 1

    @property
    def is_main_process(self) -> bool:
        """Whether current rank is global main process."""

        return self.rank == 0

    @property
    def is_local_main_process(self) -> bool:
        """Whether current rank is local main process."""

        return self.local_rank == 0

    @cached_property
    def code_id(self) -> str | None:
        """Code identity of the current checkout, as reported by `get_git_hash`.

        The value is computed once per runner because the attribute is a
        cached property.
        """

        git_hash = get_git_hash()
        return git_hash

    @cached_property
    def config_id(self) -> str:
        """Stable semantic config identity for this runner."""

        return format(hash(self.config) & ((1 << 48) - 1), "012x")

    @property
    def id(self) -> str:
        """Stable run identity derived from code identity and semantic config."""

        if self.code_id is None:
            return self.config_id
        return f"{self.code_id}-{self.config_id}"

    def __post_init__(self) -> None:
        """Hook called after `__init__` by `MetaRunner`."""
        self.workspace.save_metadata()

    @cached_property
    def score_split(self) -> str | None:
        """Split used for best-score selection."""

        if "score_split" in self.config and self.config.score_split is not None:
            return self.config.score_split

        splits = self.evaluate_splits
        if not splits:
            return None
        for split in splits:
            if split.lower().startswith("val"):
                return split
        return splits[0]

    @property
    def scores(self) -> FlatDict | None:
        """Map each result index to its ``config.score_name`` value on `score_split`.

        Result rows that lack the score split, whose split entry is not a
        mapping, or that lack the score name are skipped. Returns ``None``
        when there are no results, no score split, or nothing was collected.
        """

        if not self.results:
            return None

        split_name = self.score_split
        if split_name is None:
            return None

        collected = FlatDict()
        for index, row in self.results.items():
            if split_name in row:
                per_split = row[split_name]
                if isinstance(per_split, Mapping) and self.config.score_name in per_split:
                    collected[index] = per_split[self.config.score_name]

        return collected if collected else None

    @property
    def best_index(self) -> int:
        """Best result index according to configured score metric."""

        if not self.scores:
            return 0

        scores = self.scores
        indices = list(scores.keys())
        reducer = min if self.config.score_name == "loss" else max
        return reducer(reversed(indices), key=scores.get)

    @property
    def latest_result(self) -> RoundDict | None:
        """Copy of the last result row with its index stored under ``"index"``.

        Returns ``None`` when no results have been recorded.
        """

        results = self.results
        if not results:
            return None

        index = next(reversed(results))
        row = RoundDict(results[index])
        row["index"] = index
        return row

    @property
    def best_result(self) -> RoundDict | None:
        """Copy of the result row at `best_index` with that index stored under ``"index"``.

        Returns ``None`` when no results have been recorded.
        """

        results = self.results
        if not results:
            return None

        index = self.best_index
        row = RoundDict(results[index])
        row["index"] = index
        return row

    @property
    def latest_score(self) -> float | None:
        """Latest scalar score."""

        scores = self.scores
        if not scores:
            return None

        latest_index = next(reversed(scores))
        return scores[latest_index]

    @property
    def best_score(self) -> float | None:
        """Best scalar score."""

        if not self.scores:
            return None

        return self.scores[self.best_index]

    @property
    def is_best(self) -> bool:
        """Whether latest score matches current best score.

        Returns ``True`` only when comparable scalar scores are available and
        agree within tolerance. Returns ``True`` on the first iteration (no
        prior results), and ``False`` when scores cannot be resolved (e.g.,
        no `score_split`/`score_name` configured) — silently reporting best
        in that case would trigger phantom "best" checkpoint copies.
        """

        if not self.results:
            return True

        latest = self.latest_score
        best = self.best_score
        if latest is None or best is None:
            return False
        return abs(latest - best) < 1e-7

    def get_epoch_result(self) -> RoundDict:
        """Combine the meter averages with the metric averages into one row.

        Meter averages form the base. When `self.metrics` is set, each of its
        averaged entries is written on top, and a mapping holding exactly one
        item is collapsed to that item's value.
        """
        combined = RoundDict(self.meters.average())
        if self.metrics is None:
            return combined
        for name, entry in self.metrics.average().items():
            collapse = isinstance(entry, Mapping) and len(entry) == 1
            combined[name] = next(iter(entry.values())) if collapse else entry
        return combined

    def get_step_result(self) -> RoundDict:
        """Combine the current meter values with the current metric values into one row.

        Meter values form the base. When `self.metrics` is set, each of its
        current entries is written on top, and a mapping holding exactly one
        item is collapsed to that item's value.
        """
        combined = RoundDict(self.meters.value())
        if self.metrics is None:
            return combined
        for name, entry in self.metrics.value().items():
            collapse = isinstance(entry, Mapping) and len(entry) == 1
            combined[name] = next(iter(entry.values())) if collapse else entry
        return combined

    def append_result(self, result: RoundDict | Mapping[str, Any], index: int | None = None) -> None:
        """Record `result` under `index`, merging into an existing row if there is one.

        Args:
            result: Result row; plain mappings are wrapped in a `RoundDict`.
            index: Row key. Defaults to the current ``train_state.epoch``.
        """
        target = self.train_state.epoch if index is None else index
        row = result if isinstance(result, RoundDict) else RoundDict(result)

        if target not in self.results:
            self.results[target] = row
            return
        self.results[target].merge(row)

    def step_log(
        self,
        split: str,
        iteration: int,
        length: int | str | None = None,
        result: RoundDict[str, Any] | Mapping[str, Any] | None = None,
    ) -> RoundDict:
        """Print one step's result and, in train mode, forward it to the writers.

        Args:
            split: Split name used in the printed prefix and the writer tags.
            iteration: Current step counter shown in the progress prefix.
            length: Progress denominator. Defaults to the split's dataloader
                length minus one, or ``"∞"`` when that length is unavailable.
            result: Result to log. Defaults to `get_step_result()`.

        Returns:
            The logged result as a `RoundDict`.
        """
        if length is None:
            try:
                total = len(self.dataloaders[split]) - 1
            except (TypeError, NotImplementedError):
                # The dataloader cannot report a length.
                total = "∞"
        else:
            total = length

        if result is None:
            step_result = self.get_step_result()
        elif isinstance(result, RoundDict):
            step_result = result
        else:
            step_result = RoundDict(result)

        print(self.format_step_result(step_result, split, iteration, total))

        # Only training steps are mirrored to the scalar writers.
        if self.mode == RunnerMode.train:
            self.write_result(step_result, split)

        return step_result

    def format_epoch_result(
        self,
        result: RoundDict[str, Any],
        epochs: int | None = None,
        total_epochs: int | None = None,
    ) -> str:
        epochs = self.train_state.epoch if epochs is None else epochs
        total_epochs = self.epochs if total_epochs is None else total_epochs

        prefix = ""
        if total_epochs is not None:
            prefix = f"epoch [{epochs + 1}/{total_epochs}]"

        return f"{prefix}{self.format_result(result)}"

    def format_step_result(self, result: RoundDict[str, Any], split: str, steps: int, length: int | str) -> str:
        """Render a step result as ``"<activity> on <split> [steps/length]\\t<result>"``.

        The activity wording depends on the current runner mode; unrecognized
        modes fall back to ``"running in <mode>"``.
        """
        mode = self.mode
        if mode == RunnerMode.train:
            activity = "training"
        elif mode == RunnerMode.evaluate:
            activity = "evaluating"
        elif mode == RunnerMode.infer:
            activity = "inferring"
        else:
            activity = f"running in {mode}"

        return f"{activity} on {split} [{steps}/{length}]\t{self.format_result(result)}"

    def format_result(self, result: RoundDict[str, Any], format_spec: str = ".4f") -> str:
        """Format `result` by delegating to the module-level `format_result` helper."""
        rendered = format_result(result, format_spec=format_spec)
        return rendered

    def flatten_result(self, result: Mapping[str, Any]) -> FlatDict[str, Any]:
        """Flatten a (possibly nested) result into ``"a/b/c"``-style tags.

        Args:
            result: Result mapping; values may be mappings, sequences,
                `AverageMeter` instances, or leaf values.

        Returns:
            A `FlatDict` from slash-joined tag to leaf value. `AverageMeter`
            values are replaced by their ``avg``; sequence items are tagged by
            their position.
        """
        flat_result = FlatDict()

        def add_score(tag: str, score: Any) -> None:
            # Meters contribute their running average.
            if isinstance(score, AverageMeter):
                score = score.avg

            # Mappings: flatten with "/" as separator, then recurse so that
            # meters and sequences nested inside are expanded as well.
            if isinstance(score, Mapping):
                nested = RoundDict(score)
                nested.setattr("separator", "/")
                for nested_name, nested_score in nested.dict(flatten=True).items():
                    add_score(f"{tag}/{nested_name}", nested_score)
                return

            # Non-string sequences: one tag per element, suffixed by its index.
            if isinstance(score, Sequence) and not isinstance(score, (str, bytes)):
                for idx, nested_score in enumerate(score):
                    add_score(f"{tag}/{idx}", nested_score)
                return

            # Anything else is a leaf.
            flat_result[tag] = score

        flattened = RoundDict(result)
        flattened.setattr("separator", "/")
        for name, score in flattened.dict(flatten=True).items():
            add_score(str(name), score)

        return flat_result

    def write_result(self, result: RoundDict[str, Any], split: str, steps: int | None = None) -> None:
        if self.writer is None and self.wandb is None:
            return

        steps = self.train_state.global_step if steps is None else steps

        flat_result = self.flatten_result(result)

        for name, score in flat_result.items():
            self.write_score(name, score, split, steps)

        if self.wandb is not None:
            payload = {f"{split}/{name}": score for name, score in flat_result.items()}
            self.wandb.log(payload, step=steps)

    def write_score(self, name: str, score: float, split: str, steps: int) -> None:
        if self.writer is not None:
            self.writer.add_scalar(f"{split}/{name}", score, steps)

    @catch
    @on_main_process
    def save_result(self) -> None:
        """Write ``results.json`` and ``latest.json`` into the workspace directory.

        Nothing is written when there is no latest result. When the latest
        result is also the best one, ``latest.json`` is copied to
        ``best.json``. Values are rounded to 8 places before saving.
        """
        if not self.latest_result:
            return

        workspace_dir = self.workspace.dir

        # Full history.
        history = {
            "name": self.name,
            "id": self.id,
            "timestamp": self.timestamp,
            "results": round(self.results, 8),
        }
        self.save(history, os.path.join(workspace_dir, "results.json"), indent=4)

        # Latest row, with the run identity fields placed first.
        latest_row = round(self.latest_result, 8)
        latest_payload = {"name": self.name, "id": self.id, "timestamp": self.timestamp}
        latest_payload.update(dict(latest_row))

        latest_path = os.path.join(workspace_dir, "latest.json")
        self.save(latest_payload, latest_path, indent=4)

        if self.is_best:
            best_path = os.path.join(workspace_dir, "best.json")
            shutil.copy(latest_path, best_path)

    def auto_restore(self) -> None:
        """Auto-load resume/pretrained sources declared in config.

        Precedence:
            `config.resume` > `config.auto_resume` > `config.pretrained`.
        """

        restore_target = self._resolve_auto_restore_target()
        if restore_target is None:
            return

        restore_kind, restore_source = restore_target
        if restore_kind == "checkpoint":
            self.load_checkpoint(restore_source)
            return
        self.load_pretrained(restore_source)

    def _resolve_auto_restore_target(self) -> tuple[str, Mapping[Any, Any] | PathStr] | None:
        resume_source = self.config.get("resume")
        auto_resume = bool(self.config.get("auto_resume", False))
        pretrained_source = self.config.get("pretrained")

        specified_count = int(bool(resume_source)) + int(auto_resume) + int(bool(pretrained_source))
        if specified_count > 2:
            warn(
                "`config.resume`, `config.auto_resume`, and `config.pretrained` are all set; "
                "precedence is `resume` > `auto_resume` > `pretrained`",
                RuntimeWarning,
                stacklevel=2,
            )

        if resume_source:
            return ("checkpoint", resume_source)

        if auto_resume:
            return ("checkpoint", self._auto_resume_source())

        if pretrained_source:
            return ("pretrained", pretrained_source)

        return None

    def _auto_resume_source(self) -> str:
        backend = str(self.config.get("checkpoint.backend", "auto")).strip().lower()
        if backend == "dcp":
            return os.path.join(self.workspace.checkpoint_dir, "latest")
        return os.path.join(self.workspace.checkpoint_dir, "latest.pth")

    def init_distributed(self) -> None:
        """
        Initialize the distributed environment.

        The default is a no-op (single-process). Concrete runners override
        this hook to initialize the torch.distributed process group; see
        [`TorchRunner.init_distributed`][danling.runners.TorchRunner.init_distributed]
        for the canonical specification.

        This is the first early service hook called by `__init__`; it runs
        before seeding, logging, and signal handlers are set up.
        """

    def init_checkpoint_manager(self) -> None:
        """
        Bind the runner's checkpoint manager.

        The default is a no-op — `BaseRunner.__init__` already binds the
        `FileCheckpointManager`. Concrete runners override this hook to swap
        in the backend-appropriate manager via `set_checkpoint_manager(...)`;
        see
        [`TorchRunner.init_checkpoint_manager`][danling.runners.TorchRunner.init_checkpoint_manager]
        for the canonical specification.

        `__init__` calls this hook right after `init_distributed` and before
        `init_fault_tolerance`.
        """

    def init_fault_tolerance(self) -> None:
        """Create the `FaultTolerance` handle for this runner and store it on ``self.ft``."""

        handle = FaultTolerance(self)
        self.ft = handle

    def init_heartbeat(self) -> None:
        """Configure optional background heartbeat writer."""

        self.supervisor.init_heartbeat()

    def init_garbage_collection(self) -> None:
        """Configure optional runner-managed Python GC pacing."""

        self.supervisor.init_garbage_collection()

    def init_signal_handlers(self) -> None:
        """Install runner-owned signal handlers for graceful preemption."""

        self.supervisor.init_signal_handlers()

    def prepare_for_shutdown_checkpoint(self) -> None:
        """Finalize runner state before writing a forced shutdown checkpoint.

        The base implementation does nothing.
        """

    def set_checkpoint_manager(self, manager: CheckpointManager) -> None:
        current = getattr(self, "checkpoint_manager", None)
        if current is manager:
            return
        if current is not None:
            current.close(timeout=0.0)
        self.checkpoint_manager = manager

    @on_main_process
    def init_tensorboard(self, *args, **kwargs) -> None:
        """Placeholder TensorBoard hook: warns that this runner creates no writer.

        The arguments are accepted but ignored here.
        """

        message = "tensorboard is enabled, but this runner does not initialize a tensorboard writer"
        warn(message, RuntimeWarning, stacklevel=2)

    @on_main_process
    def init_wandb(self, *args, **kwargs) -> None:
        """Start a Weights & Biases run and store it on ``self.wandb``.

        Keyword arguments given by the caller always win. Missing ones are
        filled from ``config.wandb``, with workspace-derived fallbacks for
        ``project`` (lineage), ``group`` (experiment), ``name`` (run id) and
        ``dir`` (workspace directory). ``config`` defaults to the full runner
        config as a dict.

        Raises:
            RuntimeError: The `wandb` package is not installed.
        """

        try:
            import wandb
        except ImportError as exc:
            raise RuntimeError("wandb is enabled, but the `wandb` package is not installed") from exc

        settings = self.config.wandb

        if "project" not in kwargs:
            kwargs["project"] = settings.get("project") or self.workspace.lineage
        if "entity" not in kwargs:
            if settings.get("entity") is not None:
                kwargs["entity"] = settings.entity
        if "group" not in kwargs:
            kwargs["group"] = settings.get("group") or self.workspace.experiment
        if "name" not in kwargs:
            kwargs["name"] = settings.get("name") or self.id
        if "job_type" not in kwargs:
            if settings.get("job_type") is not None:
                kwargs["job_type"] = settings.job_type

        # A single tag given as a string becomes a one-element list.
        tags = settings.get("tags")
        if "tags" not in kwargs and tags is not None:
            if isinstance(tags, str):
                kwargs["tags"] = [tags]
            else:
                kwargs["tags"] = list(tags)

        if "dir" not in kwargs:
            kwargs["dir"] = settings.get("dir") or self.workspace.dir
        if "mode" not in kwargs:
            if settings.get("mode") is not None:
                kwargs["mode"] = settings.mode
        if "config" not in kwargs:
            kwargs["config"] = self.config.dict()

        wandb_module = cast(Any, wandb)
        self.wandb = wandb_module.init(*args, **kwargs)

    def set_seed(self, seed: int | None = None, bias: int | bool | None = None) -> int:
        """Seed the Python (and, if available, NumPy) RNGs and snapshot their state.

        Args:
            seed: Base seed. Defaults to `self.config.seed`.
            bias: Per-process offset added to the base seed. `None` uses
                `self.rank`; a falsy value adds nothing.

        Returns:
            The process-local seed after applying the bias.

        Raises:
            ValueError: Neither `seed` nor `self.config.seed` is set.
        """

        resolved = seed if seed is not None else self.config.seed
        if resolved is None:
            raise ValueError("cannot set seed: no seed is configured and no seed argument was provided")
        resolved = int(resolved)

        # Store the normalized base seed (not the biased one) on the config.
        self.config.seed = resolved

        offset = self.rank if bias is None else bias
        if offset:
            process_seed = resolved + int(offset)
        else:
            process_seed = resolved

        has_numpy = np_random is not None
        random.seed(process_seed)
        if has_numpy:
            np_random.seed(process_seed)

        rng_state = self.rng_state
        rng_state.python = random.getstate()
        rng_state.numpy = np_random.get_state() if has_numpy else None
        return process_seed

    def set_deterministic(self) -> None:
        """Enable deterministic behavior in subclass-specific backends.

        The base implementation does nothing. `__init__` calls this hook when
        ``config.deterministic`` is truthy.
        """

    def train(self, *args, **kwargs):
        """Run top-level training workflow.

        Not implemented in `BaseRunner`; meant to be provided by concrete
        runners.

        Raises:
            NotImplementedError: Always, in this base class.
        """

        raise NotImplementedError

    def train_epochs(self, *args, **kwargs):
        """Run epoch-mode training workflow.

        Not implemented in `BaseRunner`; meant to be provided by concrete
        runners.

        Raises:
            NotImplementedError: Always, in this base class.
        """

        raise NotImplementedError

    def train_epoch(self, *args, **kwargs):
        """Run one training epoch on a split.

        Not implemented in `BaseRunner`; meant to be provided by concrete
        runners.

        Raises:
            NotImplementedError: Always, in this base class.
        """

        raise NotImplementedError

    def train_steps(self, *args, **kwargs):
        """Run step-mode training workflow.

        Not implemented in `BaseRunner`; meant to be provided by concrete
        runners.

        Raises:
            NotImplementedError: Always, in this base class.
        """

        raise NotImplementedError

    def train_step(self, *args, **kwargs):
        """
        Run one training micro-step.

        Concrete runners define the override contract; see
        [`TorchRunner.train_step`][danling.runners.TorchRunner.train_step] for
        the canonical specification.

        Raises:
            NotImplementedError: Always, in this base class.
        """

        raise NotImplementedError

    def backward(self, loss, *args, **kwargs) -> None:
        """Run backward pass for one micro-step loss.

        Args:
            loss: Loss of the current micro-step.

        Raises:
            NotImplementedError: Always, in this base class.
        """

        raise NotImplementedError

    def step(self, *args, **kwargs) -> None:
        """Advance optimizer/scheduler state when accumulation is ready.

        Not implemented in `BaseRunner`; meant to be provided by concrete
        runners.

        Raises:
            NotImplementedError: Always, in this base class.
        """

        raise NotImplementedError

    def evaluate(self, *args, **kwargs):
        """Run top-level evaluation workflow.

        Not implemented in `BaseRunner`; meant to be provided by concrete
        runners.

        Raises:
            NotImplementedError: Always, in this base class.
        """

        raise NotImplementedError

    def evaluate_epoch(self, *args, **kwargs):
        """Run one full evaluation epoch on a split.

        Not implemented in `BaseRunner`; meant to be provided by concrete
        runners.

        Raises:
            NotImplementedError: Always, in this base class.
        """

        raise NotImplementedError

    def evaluate_steps(self, *args, **kwargs):
        """Run bounded evaluation steps on a split.

        Not implemented in `BaseRunner`; meant to be provided by concrete
        runners.

        Raises:
            NotImplementedError: Always, in this base class.
        """

        raise NotImplementedError

    def evaluate_step(self, *args, **kwargs):
        """
        Run one evaluation step.

        Concrete runners define the override contract; see
        [`TorchRunner.evaluate_step`][danling.runners.TorchRunner.evaluate_step]
        for the canonical specification.

        Raises:
            NotImplementedError: Always, in this base class.
        """

        raise NotImplementedError

    def infer(self, *args, **kwargs):
        """Run top-level inference workflow.

        Not implemented in `BaseRunner`; meant to be provided by concrete
        runners.

        Raises:
            NotImplementedError: Always, in this base class.
        """

        raise NotImplementedError

    def infer_step(self, *args, **kwargs):
        """
        Run one inference step.

        Concrete runners define the override contract; see
        [`TorchRunner.infer_step`][danling.runners.TorchRunner.infer_step] for
        the canonical specification.

        Raises:
            NotImplementedError: Always, in this base class.
        """

        raise NotImplementedError

    def unwrap(self, model: Any) -> Any:
        """Return an unwrapped model object."""

        return model

    def state_dict(self, cls: type = dict) -> Mapping:
        """
        Assemble the backend-neutral checkpoint payload for this runner.

        The payload holds the runner config, the mutable runner state
        (including RNG snapshots), and the dataloader resume state. Backend
        runners extend it with model/optimizer/scheduler state.

        **Called when:** checkpoint managers build a payload for
        `save_checkpoint`, and fault-tolerance callbacks need a runner state
        snapshot.

        Args:
            cls: Mapping factory used for the payload and its nested parts.
                Backends may pass `dict`-like containers to preserve their
                serialization format.

        Returns:
            Mapping with `runner`, `state`, and `dataloaders` keys.

        **Side effects:** refreshes the Python and NumPy RNG snapshots stored
        in `self.rng_state` before exporting.

        !!! danger "Do not"
            - Mutate model or optimizer state here.
            - Drop the `runner` config payload; resume validation depends on it.
            - Override without calling `super()` unless you fully replace the
              checkpoint format.
        """

        # Capture the RNG state first so the exported state includes it.
        rng_state = self.rng_state
        rng_state.python = random.getstate()
        rng_state.numpy = None if np_random is None else np_random.get_state()

        # Plain dict payloads are passed through untouched; any other factory
        # re-wraps each nested part.
        rewrap = cls is not dict

        state = self.state.state_dict()
        if rewrap:
            state = cls(state)

        dataloader_state = self.dataloaders.state_dict()
        if rewrap:
            dataloader_state = cls(dataloader_state)

        return cls(runner=self.config.dict(), state=state, dataloaders=dataloader_state)

    def load_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        """
        Restore backend-neutral runner state from a checkpoint payload.

        Only semantic runner state and Python/NumPy RNG state are restored
        here. Loading model, EMA, optimizer, scheduler, and dataloader
        components is owned by `load_checkpoint`.

        **Called when:** `load_checkpoint` restores a full checkpoint, and
        fault-tolerance callbacks receive a runner state payload.

        Args:
            checkpoint: Mapping produced by `state_dict` or a backend-specific
                superset of that payload.

        Raises:
            ValueError: checkpoint runner config differs semantically from the
                current runner config.

        **Side effects:** updates `self.state`, `self.train_state`,
        `self.elastic_state`, `self.rng_state`, and process RNG state.

        !!! danger "Do not"
            - Load model/optimizer/scheduler state here; use component loaders
              through `load_checkpoint`.
            - Suppress semantic config diffs unless you also update the resume
              policy deliberately.
        """

        saved_runner_config = checkpoint.get("runner")
        if saved_runner_config is not None:
            # Compare canonical forms; any remaining difference blocks the load.
            saved_canonical = RunnerConfig(saved_runner_config).canonical()
            live_canonical = self.config.canonical()
            semantic_diff = NestedDict(saved_canonical).diff(live_canonical).dict()
            if semantic_diff:
                raise ValueError(
                    "cannot load checkpoint: runner config is semantically different from current config; "
                    f"start a new experiment or align config. diff={semantic_diff}"
                )

        state_dict = checkpoint.get("state") or {}
        self.state.load_state_dict(dict(state_dict))

        # Process RNGs are reseeded from `self.rng_state`, and only for the
        # generators the checkpoint's `rng` entry actually mentions.
        saved_rng = state_dict.get("rng")
        if not isinstance(saved_rng, Mapping):
            return

        if "python" in saved_rng and self.rng_state.python is not None:
            random.setstate(self.rng_state.python)

        if np_random is not None and "numpy" in saved_rng and self.rng_state.numpy is not None:
            np_random.set_state(self.rng_state.numpy)

    @staticmethod
    def _normalize_checkpoint_exclude_path(path: str) -> tuple[str, ...]:
        """
        Split a dotted exclude path into a tuple of non-empty segments.

        Only the first segment is canonicalized: `data_loader`/`dataloader`
        become `dataloaders` and `lr_scheduler` becomes `scheduler`. An input
        with no non-empty segment yields an empty tuple.
        """
        segments = [segment for segment in str(path).split(".") if segment]
        if not segments:
            return ()
        head, *tail = segments
        canonical_heads = {
            "data_loader": "dataloaders",
            "dataloader": "dataloaders",
            "lr_scheduler": "scheduler",
        }
        return (canonical_heads.get(head, head), *tail)

    def checkpoint_exclude_from_loading(self) -> tuple[tuple[str, ...], ...]:
        """
        Return the normalized checkpoint paths excluded from loading.

        Reads `checkpoint.exclude_from_loading` from config. A single string is
        treated as a one-element collection, and entries that normalize to an
        empty path are dropped. Returns an empty tuple when the option is unset.
        """
        configured = self.config.get("checkpoint.exclude_from_loading")
        if configured is None:
            return ()
        entries = (configured,) if isinstance(configured, str) else configured

        normalized_paths = []
        for entry in entries:
            normalized = self._normalize_checkpoint_exclude_path(str(entry))
            if normalized:
                normalized_paths.append(normalized)
        return tuple(normalized_paths)

    @staticmethod
    def _drop_checkpoint_path(checkpoint: dict[str, Any], path: Sequence[str]) -> None:
        """
        Remove the entry addressed by `path` from `checkpoint` in place.

        Each intermediate mapping on the way down is replaced by a shallow
        `dict` copy before descending, so nested mappings shared with the
        caller's original payload are not mutated. The walk stops silently
        when an intermediate value is missing or is not a mapping. An empty
        path is a no-op.
        """
        if not path:
            return
        *parents, leaf = path
        node = checkpoint
        for key in parents:
            child = node.get(key)
            if not isinstance(child, Mapping):
                return
            detached = dict(child)
            node[key] = detached
            node = detached
        node.pop(leaf, None)

    def _filter_checkpoint_for_loading(
        self,
        checkpoint: Mapping[str, Any],
        excluded_paths: Sequence[Sequence[str]],
    ) -> dict[str, Any]:
        """
        Return a `dict` copy of `checkpoint` with every excluded path dropped.

        The top level is copied before any path is removed, so the caller's
        mapping keeps all of its top-level keys.
        """
        pruned = dict(checkpoint)
        for excluded in excluded_paths:
            self._drop_checkpoint_path(pruned, excluded)
        return pruned

    @staticmethod
    def _is_top_level_checkpoint_excluded(excluded_paths: Sequence[Sequence[str]], *keys: str) -> bool:
        """
        Tell whether any of `keys` is excluded as a whole top-level entry.

        Only single-segment paths count; a nested path such as
        `("optimizer", "state")` does not exclude `"optimizer"` itself.
        """
        wanted = frozenset(keys)
        for path in excluded_paths:
            if len(path) == 1 and path[0] in wanted:
                return True
        return False

    def save_checkpoint(
        self,
        name: str = "latest",
        epochs: int | None = None,
        save_best: bool = True,
        last_step: bool = False,
        force: bool = False,
    ) -> None:
        """
        Persist runner state through the active checkpoint manager.

        Whether a save is collective is decided by
        `checkpoint_manager.is_collective`. File-style managers save on the
        main process only; collective managers require every rank to enter this
        method together.

        **Called when:** training loops hit checkpoint cadence, final
        `last_step` saves run, or the supervisor handles a shutdown signal.

        Args:
            name: Logical checkpoint alias, usually `"latest"` or `"best"`.
            epochs: Epoch index used for history checkpoint naming. Defaults
                to `self.train_state.epoch`.
            save_best: Whether to publish/update the best-checkpoint alias
                when `self.is_best` is true.
            last_step: Whether this save is the final save for the run.
            force: Bypass cadence checks inside the manager.

        **Side effects:** delegates to
        `self.checkpoint_manager.save_checkpoint(...)`.

        !!! danger "Do not"
            - Add a main-process guard around calls to this method; DCP-style
              managers need all ranks to participate.
            - Bypass the checkpoint manager for normal runner checkpoints.
        """

        # Non-main ranks take part only when the manager saves collectively.
        if self.is_main_process or self.checkpoint_manager.is_collective:
            resolved_epochs = self.train_state.epoch if epochs is None else epochs
            self.checkpoint_manager.save_checkpoint(
                name=name,
                epochs=resolved_epochs,
                save_best=save_best,
                last_step=last_step,
                force=force,
            )

    def save_seed_checkpoint(self, name: str = "seed") -> None:
        """
        Persist an initialization checkpoint for cross-topology experiments.

        A seed checkpoint is meant to be written before training advances and
        later loaded with `checkpoint.load_only=True` or `resume`/`pretrained`
        to compare different parallel layouts from the same initial model
        state. The save goes through the final-checkpoint path, so
        `checkpoint.last_save_model_only=True` intentionally applies.

        A `RuntimeWarning` is emitted, and the save still proceeds, when the
        epoch or global step is already non-zero.
        """
        training_advanced = self.train_state.global_step != 0 or self.train_state.epoch != 0
        if training_advanced:
            message = (
                "save_seed_checkpoint() is intended before training advances; "
                f"current epoch={self.train_state.epoch}, global_step={self.train_state.global_step}"
            )
            warn(message, RuntimeWarning, stacklevel=2)
        self.save_checkpoint(name=name, epochs=0, save_best=False, last_step=True, force=True)

    def load_checkpoint(
        self,
        checkpoint: Mapping | bytes | str | os.PathLike,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Restore a full runner checkpoint.

        This is the full-state restore path: runtime state, model/EMA,
        optimizer, scheduler, and dataloader progress are restored when present
        and applicable to the current runner.

        Paths listed in `checkpoint.exclude_from_loading` are removed from the
        payload before anything is loaded; a component excluded as a whole
        top-level key is skipped entirely.

        **Called when:** users resume a run explicitly, `auto_restore` selects
        a resume source, `from_checkpoint` constructs a runner, or
        fault-tolerance callbacks restore a full runner payload.

        Args:
            checkpoint: In-memory checkpoint mapping or backend-specific path.
            *args: Forwarded to `read_checkpoint` and component loaders.
            **kwargs: Forwarded to `read_checkpoint` and component loaders.

        Raises:
            ValueError: checkpoint is missing required component state for an
                initialized component, or config validation fails.

        **Side effects:** updates runner state, model/EMA weights, optimizer,
        scheduler, dataloader progress, and `config.resume` for path inputs.

        !!! danger "Do not"
            - Use this for model-only finetuning payloads; use
              `load_pretrained` instead.
            - Override just to support a new path type; prefer overriding
              `read_checkpoint`.
        """

        # Normalize path/bytes/mapping input into an in-memory mapping.
        ckpt = self.read_checkpoint(checkpoint, *args, **kwargs)
        excluded_paths = self.checkpoint_exclude_from_loading()
        if excluded_paths:
            # Dropping "runner" means load_state_dict sees no runner config and
            # therefore performs no semantic diff; make that visible.
            if self._is_top_level_checkpoint_excluded(excluded_paths, "runner"):
                warn(
                    "`checkpoint.exclude_from_loading` contains 'runner'; "
                    "semantic runner config validation will be skipped for this load.",
                    RuntimeWarning,
                    stacklevel=2,
                )
            ckpt = self._filter_checkpoint_for_loading(ckpt, excluded_paths)

        # Runner state (and config validation) comes first, then components.
        self.load_state_dict(ckpt)
        if not self._is_top_level_checkpoint_excluded(excluded_paths, "model", "model_parts", "module"):
            # "model" takes precedence over "model_parts"; a missing payload is
            # an error only when this runner actually has a model.
            if "model" in ckpt:
                self.load_model(ckpt["model"], *args, **kwargs)
            elif "model_parts" in ckpt:
                self.load_model(ckpt["model_parts"], *args, **kwargs)
            elif self.model is not None:
                raise ValueError(
                    "cannot restore model: checkpoint has no model state\n"
                    "Use `load_pretrained` only for model-only checkpoints with model/ema payloads"
                )
        # Each remaining loader runs when the component exists on the runner or
        # the checkpoint carries its state; `.get` passes None for absent state.
        if not self._is_top_level_checkpoint_excluded(excluded_paths, "ema") and (
            self.ema is not None or "ema" in ckpt
        ):
            self.load_ema(ckpt.get("ema"), *args, **kwargs)
        if not self._is_top_level_checkpoint_excluded(excluded_paths, "optimizer") and (
            self.optimizer is not None or "optimizer" in ckpt
        ):
            self.load_optimizer(ckpt.get("optimizer"), *args, **kwargs)
        if not self._is_top_level_checkpoint_excluded(excluded_paths, "scheduler") and (
            self.scheduler is not None or "scheduler" in ckpt
        ):
            self.load_scheduler(ckpt.get("scheduler"), *args, **kwargs)
        if not self._is_top_level_checkpoint_excluded(excluded_paths, "dataloaders") and (
            self.dataloaders or "dataloaders" in ckpt
        ):
            self.load_dataloaders(ckpt.get("dataloaders"))
        # Record path-like sources in config; in-memory mappings leave
        # `config.resume` untouched.
        if isinstance(checkpoint, (str, bytes, os.PathLike)):
            self.config.resume = os.fsdecode(checkpoint)

    @staticmethod
    def _require_checkpoint_component_state(component: str, state_dict: Any | None) -> Any:
        """
        Return `state_dict` unchanged, or raise when it is `None`.

        Args:
            component: Component name used in the error message, e.g. `"ema"`.
            state_dict: State read from the checkpoint for that component.

        Raises:
            ValueError: `state_dict` is `None`.
        """
        if state_dict is not None:
            return state_dict

        known_labels = {
            "ema": "EMA state",
            "optimizer": "optimizer state",
            "scheduler": "scheduler state",
        }
        component_label = known_labels.get(component, f"{component} state")
        raise ValueError(
            f"cannot restore {component}: checkpoint has no {component_label}\n"
            "Use `load_pretrained` for model-only checkpoints instead of `load_checkpoint`"
        )

    def load_model(self, state_dict: Mapping[str, Any], *args, **kwargs) -> None:
        """
        Load `state_dict` into the unwrapped model.

        Extra positional and keyword arguments are forwarded to the model's
        `load_state_dict`.

        Raises:
            ValueError: the runner has no model.
        """
        if self.model is None:
            raise ValueError("cannot restore model: model is not initialized")
        target = self.unwrap(self.model)
        target.load_state_dict(state_dict, *args, **kwargs)

    def load_ema(self, state_dict: Mapping[str, Any] | None, *args, **kwargs) -> None:
        """
        Load EMA state when the runner has an EMA component.

        Does nothing without an EMA. Otherwise a missing (`None`) state is
        rejected by `_require_checkpoint_component_state` before loading.
        """
        if self.ema is not None:
            payload = self._require_checkpoint_component_state("ema", state_dict)
            self.ema.load_state_dict(payload, *args, **kwargs)

    def load_optimizer(self, state_dict: Mapping[str, Any] | None, *args, **kwargs) -> None:
        """
        Load optimizer state when the runner has an optimizer.

        Does nothing without an optimizer. Otherwise a missing (`None`) state
        is rejected by `_require_checkpoint_component_state` before loading.
        """
        if self.optimizer is not None:
            payload = self._require_checkpoint_component_state("optimizer", state_dict)
            self.optimizer.load_state_dict(payload, *args, **kwargs)

    def load_scheduler(self, state_dict: Mapping[str, Any] | None, *args, **kwargs) -> None:
        """
        Load scheduler state when the runner has a scheduler.

        Does nothing without a scheduler. Otherwise a missing (`None`) state
        is rejected by `_require_checkpoint_component_state` before loading.
        """
        if self.scheduler is not None:
            payload = self._require_checkpoint_component_state("scheduler", state_dict)
            self.scheduler.load_state_dict(payload, *args, **kwargs)

    def load_dataloaders(self, state_dict: Mapping[str, Any] | None) -> None:
        """
        Restore dataloader progress from `state_dict`.

        A `None` state is ignored; anything else is handed to
        `self.dataloaders.load_state_dict`.
        """
        if state_dict is not None:
            self.dataloaders.load_state_dict(state_dict)

    def load_pretrained(
        self,
        checkpoint: Mapping | bytes | str | os.PathLike,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Load model weights only from a checkpoint payload or path.

        The weight source is chosen in this order: a non-`None` `ema` entry,
        then `model`, then `model_parts`.

        **Called when:** users initialize from pretrained weights, or
        `auto_restore` selects `config.pretrained`.

        Args:
            checkpoint: In-memory payload or backend-specific path containing
                `ema`, `model`, or `model_parts`.
            *args: Forwarded to `read_checkpoint` and `load_model`.
            **kwargs: Forwarded to `read_checkpoint` and `load_model`.

        Raises:
            ValueError: model is not initialized, or the payload has no usable
                model/EMA state.

        **Side effects:** loads model weights and updates `config.pretrained`
        (the decoded path for path inputs, `None` for in-memory payloads).
        Optimizer, scheduler, runner state, and dataloaders are intentionally
        untouched.

        !!! danger "Do not"
            - Use this to resume training state; use `load_checkpoint` for
              full-state restore.
            - Load optimizer/scheduler state in this path.
        """

        if self.model is None:
            raise ValueError("cannot load pretrained weights: model is not initialized")

        ckpt = self.read_checkpoint(checkpoint, *args, **kwargs)
        # Pick the payload key first, then perform a single load.
        if ckpt.get("ema") is not None:
            source_key = "ema"
        elif "model" in ckpt:
            source_key = "model"
        elif "model_parts" in ckpt:
            source_key = "model_parts"
        else:
            raise ValueError(
                "cannot load pretrained weights: checkpoint has no EMA or model state\n"
                "Use `load_checkpoint` for full checkpoint restore instead of `load_pretrained`"
            )
        self.load_model(ckpt[source_key], *args, **kwargs)

        came_from_path = isinstance(checkpoint, (str, bytes, os.PathLike))
        self.config.pretrained = os.fsdecode(checkpoint) if came_from_path else None

    @classmethod
    def from_checkpoint(cls, checkpoint: Mapping | bytes | str | os.PathLike, *args, **kwargs) -> BaseRunner:
        """
        Instantiate a runner from the checkpoint's config and restore full state.

        `resume`, `auto_resume`, and `pretrained` are cleared on the config
        read from the checkpoint before the runner is constructed; the
        checkpoint is then loaded explicitly through `load_checkpoint`.
        """

        config = cls.read_config(checkpoint, *args, **kwargs)
        for option, neutral in (("resume", None), ("auto_resume", False), ("pretrained", None)):
            setattr(config, option, neutral)
        runner = cls(config)
        runner.load_checkpoint(checkpoint, *args, **kwargs)
        return runner

    @classmethod
    def read_config(
        cls,
        checkpoint: Mapping | bytes | str | os.PathLike,
        *args,
        **kwargs,
    ) -> RunnerConfig:
        """
        Read the runner config from a checkpoint mapping or file path.

        Path inputs are loaded with `map_location="cpu"` and
        `weights_only=False`, overriding any caller-supplied values.

        Note:
            BaseRunner only accepts file checkpoints for path input.
            Backend-specific directory checkpoints must be handled in subclasses.

        Raises:
            ValueError: the input is neither a mapping nor a path, the path is
                not a file, or the payload has no `runner` key.
        """

        if not isinstance(checkpoint, (Mapping, bytes, str, os.PathLike)):
            raise ValueError(
                "invalid checkpoint input: expected a mapping or path, "
                f"got {type(checkpoint).__name__}: {checkpoint!r}"
            )

        if isinstance(checkpoint, Mapping):
            ckpt = checkpoint
        else:
            if not os.path.isfile(os.fspath(checkpoint)):
                raise ValueError(
                    f"cannot read config from checkpoint path for {cls.__name__}: path must be a file; "
                    "use a backend-specific runner for directory-style checkpoints"
                )
            load_kwargs = {**kwargs, "map_location": "cpu", "weights_only": False}
            ckpt = load(checkpoint, *args, **load_kwargs)

        if "runner" not in ckpt:
            raise ValueError(
                "cannot read runner config: checkpoint is missing key 'runner'; "
                "use from_pretrained(...) for model-only checkpoints"
            )
        return RunnerConfig(ckpt["runner"])

    @classmethod
    def from_pretrained(
        cls,
        config: RunnerConfig | Mapping[str, Any],
        checkpoint: Mapping | bytes | str | os.PathLike,
        *args,
        **kwargs,
    ) -> BaseRunner:
        """
        Build a runner from config and load model weights only.

        The given config is wrapped in a `RunnerConfig` with `resume`,
        `auto_resume`, and `pretrained` cleared before the runner is
        constructed; weights are then loaded explicitly via `load_pretrained`.
        """

        prepared = RunnerConfig(config)
        for option, neutral in (("resume", None), ("auto_resume", False), ("pretrained", None)):
            setattr(prepared, option, neutral)
        runner = cls(prepared)
        runner.load_pretrained(checkpoint, *args, **kwargs)
        return runner

    def read_checkpoint(self, checkpoint: Mapping | bytes | str | os.PathLike, *args, **kwargs) -> Mapping[str, Any]:
        """
        Normalize checkpoint input into an in-memory mapping payload.

        Mappings are returned as-is. Path inputs are loaded with
        `map_location="cpu"` and `weights_only=False`, overriding any
        caller-supplied values.

        Raises:
            ValueError: the input is neither a mapping nor a path.
        """
        if isinstance(checkpoint, Mapping):
            return checkpoint
        if not isinstance(checkpoint, (bytes, str, os.PathLike)):
            raise ValueError(
                f"invalid checkpoint input: expected a mapping or path, got {type(checkpoint).__name__}: {checkpoint!r}"
            )
        load_kwargs = {**kwargs, "map_location": "cpu", "weights_only": False}
        return load(checkpoint, *args, **load_kwargs)

    def save(self, obj: Any, file: PathStr, main_process_only: bool = True, *args, **kwargs) -> File:
        """
        Save an object, optionally only on the main process.

        When `main_process_only` is set and this is not the main process,
        nothing is written and `file` is returned as-is. Otherwise the result
        of the module-level `save` call is returned.
        """

        if main_process_only and not self.is_main_process:
            return file
        return save(obj, file, *args, **kwargs)

    def close(self, timeout: float | None = None) -> bool:
        """
        Finalize checkpoint/log/writer resources before shutdown.

        Args:
            timeout: Passed to `checkpoint_manager.close`. Defaults to the
                `checkpoint.wait_timeout` config value when `None`.

        Returns:
            The drained flag reported by `checkpoint_manager.close`. `False`
            is returned early, before any other resource is released, when
            the manager reports that it did not drain.

        Raises:
            Exception: whatever `checkpoint_manager.close` raised, re-raised
                only after the remaining resources have been closed.
        """

        if timeout is None:
            timeout = self.config.get("checkpoint.wait_timeout")

        drained = True
        close_error: Exception | None = None
        # Capture a manager failure instead of propagating immediately so the
        # cleanup below still runs; it is re-raised at the end.
        try:
            drained = self.checkpoint_manager.close(timeout=timeout)
        except Exception as exc:
            close_error = exc

        # A clean call that did not drain aborts the shutdown: warn and leave
        # signal handlers, writers, and the workspace untouched.
        if close_error is None and not drained:
            warn("runner close: timed out while draining async checkpoints", RuntimeWarning, stacklevel=2)
            return False

        self.supervisor.restore_signal_handlers()
        writer = self.writer
        if writer is not None:
            writer.flush()
            writer.close()
            self.writer = None

        if self.wandb is not None:
            self.wandb.finish()

        self.workspace.close()
        self.supervisor.close()
        if self.ft is not None:
            self.ft.close()

        # Surface the deferred checkpoint-manager failure now that cleanup ran.
        if close_error is not None:
            raise close_error
        return drained

    @property
    def mode(self) -> RunnerMode:
        """Current runner mode."""
        return self._mode

    @mode.setter
    def mode(self, mode: str | RunnerMode) -> None:
        """Set the runner mode, coercing strings through `RunnerMode`; re-assigning the current mode is a no-op."""
        target = RunnerMode(mode) if isinstance(mode, str) else mode
        if getattr(self, "_mode", None) == target:
            return
        self._mode = target

    @property
    def batch_size(self) -> int:
        """
        Infer batch size from config or first dataloader.

        `dataloader.batch_size` from config wins. Otherwise the `batch_size`
        attribute of the first registered dataloader is used.

        Raises:
            AttributeError: neither source provides a batch size.
        """
        configured = self.config.get("dataloader.batch_size")
        if configured is not None:
            return configured

        if self.dataloaders:
            first_loader = next(iter(self.dataloaders.values()))
            inferred = getattr(first_loader, "batch_size", None)
            if inferred is not None:
                return inferred

        raise AttributeError("batch_size could not be inferred and is not in config")

    @staticmethod
    def _loader_length(loader: Any) -> int | None:
        """Return `len(loader)`, or `None` when it raises `TypeError` or `NotImplementedError`."""
        try:
            size = len(loader)
        except (TypeError, NotImplementedError):
            size = None
        return size

    @property
    def epochs(self) -> int | None:
        """Configured epoch budget, or `None` when `epochs` is not set in config."""
        return self.config.get("epochs")

    @epochs.setter
    def epochs(self, epochs: int) -> None:
        """Store the epoch budget in config."""
        self.config.epochs = epochs

    @property
    def steps(self) -> int | None:
        """
        Configured/derived optimizer-step budget.

        An explicit `steps` config value wins. Otherwise, in epoch mode with
        dataloaders present, the budget is the per-epoch optimizer step count
        (each train split's length divided by `accum_steps`, rounded up)
        multiplied by `epochs`. Returns `None` when it cannot be derived,
        including when any train loader has no usable length.
        """
        explicit = self.config.get("steps")
        if explicit is not None:
            return explicit
        if self.epochs is None or not self.dataloaders:
            return None

        optimizer_steps_per_epoch = 0
        for split in self.train_splits:
            micro_steps = self._loader_length(self.dataloaders[split])
            if micro_steps is None:
                return None
            # Ceiling division: a trailing partial accumulation window still
            # counts as one optimizer step.
            optimizer_steps_per_epoch += (micro_steps + self.accum_steps - 1) // self.accum_steps
        return optimizer_steps_per_epoch * self.epochs

    @steps.setter
    def steps(self, steps: int) -> None:
        """Store the optimizer-step budget in config."""
        self.config.steps = steps

    @property
    def is_step_mode(self) -> bool:
        """Whether the runner is in step mode, i.e. `epochs` is unset (`None`)."""
        return self.epochs is None

    @cached_property
    def accum_steps(self) -> int:
        """Gradient accumulation steps (config `accum_steps`, default `1`)."""
        return self.config.get("accum_steps", 1)

    @cached_property
    def precision(self) -> str | None:
        """Autocast precision mode (config `precision`, `None` when unset)."""
        return self.config.get("precision")

    @cached_property
    def max_grad_value(self) -> float | None:
        """Gradient value clipping threshold (config `max_grad_value`, `None` when unset)."""
        return self.config.get("max_grad_value")

    @cached_property
    def max_grad_norm(self) -> float | None:
        """Gradient norm clipping threshold (config `max_grad_norm`, `None` when unset)."""
        return self.config.get("max_grad_norm")

    @cached_property
    def skip_nonfinite_grad(self) -> bool:
        """Whether to skip optimizer updates when gradients are non-finite (default `False`)."""
        return self.config.get("skip_nonfinite_grad", False)

    @cached_property
    def patience(self) -> int | float:
        """Early-stop patience in epoch mode (config `patience`, default infinity)."""
        return self.config.get("patience", float("inf"))

    @property
    def progress(self) -> float:
        """
        Normalized training progress in `[0, 1]`.

        Uses `global_step / steps` when a step budget is available, otherwise
        `epoch / epochs`.

        Raises:
            ValueError: neither `steps` nor `epochs` is configured.
        """
        if self.steps is None:
            if self.epochs is None:
                raise ValueError("cannot compute progress: neither `steps` nor `epochs` is configured")
            return self.train_state.epoch / self.epochs
        return self.train_state.global_step / self.steps

    @property
    def train_splits(self) -> list[str]:
        """
        Configured or inferred training split names.

        A `train_splits` config entry wins. Otherwise a dataset counts as a
        training split when it is registered as `"train"`, has a truthy
        `train` attribute, or has `split == "train"`. The result is sorted and
        de-duplicated; it is empty when nothing is configured or registered.
        """
        if "train_splits" in self.config:
            return self._sorted_unique(self.config["train_splits"])
        if not self.datasets:
            return []

        def is_training(split, dataset):
            return split == "train" or getattr(dataset, "train", False) or getattr(dataset, "split", None) == "train"

        return self._sorted_unique([split for split, dataset in self.datasets.items() if is_training(split, dataset)])

    @property
    def evaluate_splits(self) -> list[str]:
        """
        Configured or inferred evaluation split names.

        An `evaluate_splits` config entry wins. Otherwise every registered
        dataset that is not a training split is returned, sorted by name.
        """
        if "evaluate_splits" in self.config:
            return self._sorted_unique(self.config["evaluate_splits"])
        if not self.datasets:
            return []
        training = set(self.train_splits)
        return sorted(split for split in self.datasets if split not in training)

    @staticmethod
    def _sorted_unique(values: Sequence[str] | str) -> list[str]:
        """
        Return the distinct values as a sorted list of strings.

        A bare string is treated as a single value rather than as a sequence
        of characters.
        """
        if isinstance(values, str):
            return [values]
        return sorted({str(value) for value in values})

    @property
    def checkpoint_interval(self) -> int:
        """
        Checkpoint cadence in optimizer steps (step mode) or epochs (epoch mode).

        `checkpoint.interval` from config wins. Otherwise the cadence is one
        epoch in epoch mode, one twentieth of the step budget (rounded up, at
        least 1) in step mode, and 8192 when no budget is known.
        """
        explicit = self.config.get("checkpoint.interval")
        if explicit is not None:
            return explicit
        if self.epochs is not None:
            return 1
        if self.steps is None:
            return 8_192
        return max(ceil(self.steps / 20), 1)

    @property
    def log_interval(self) -> int:
        """
        Step logging cadence.

        `log_interval` from config wins. Otherwise it is one hundredth of the
        step budget (rounded up, at least 1), or 1024 when no budget is known.
        """
        explicit = self.config.get("log_interval")
        if explicit is not None:
            return explicit
        if self.steps is None:
            return 1_024
        return max(ceil(self.steps / 100), 1)

world_size property

Python
world_size: int

Distributed world size from environment.

rank property

Python
rank: int

Global rank from environment.

local_rank property

Python
local_rank: int

Local rank from environment.

distributed property

Python
distributed: bool

Whether distributed mode is active.

is_main_process property

Python
is_main_process: bool

Whether current rank is global main process.

is_local_main_process property

Python
is_local_main_process: bool

Whether current rank is local main process.

code_id cached property

Python
code_id: str | None

Stable code identity for the current checkout.

config_id cached property

Python
config_id: str

Stable semantic config identity for this runner.

id property

Python
id: str

Stable run identity derived from code identity and semantic config.

score_split cached property

Python
score_split: str | None

Split used for best-score selection.

scores property

Python
scores: FlatDict | None

Index-to-score mapping extracted from score_split/score_name.

best_index property

Python
best_index: int

Best result index according to configured score metric.

latest_result property

Python
latest_result: RoundDict | None

Most recent appended result row.

best_result property

Python
best_result: RoundDict | None

Best result row according to configured score metric.

latest_score property

Python
latest_score: float | None

Latest scalar score.

best_score property

Python
best_score: float | None

Best scalar score.

is_best property

Python
is_best: bool

Whether latest score matches current best score.

Returns True when comparable scalar scores are available and agree within tolerance, and also on the first iteration (no prior results). Returns False when scores cannot be resolved (e.g., no score_split/score_name configured), because silently reporting best in that case would trigger phantom “best” checkpoint copies.

batch_size property

Python
batch_size: int

Infer batch size from config or first dataloader.

epochs property writable

Python
epochs: int | None

Configured epoch budget, if present.

steps property writable

Python
steps: int | None

Configured/derived optimizer-step budget.

is_step_mode property

Python
is_step_mode: bool

Whether runner is in step mode (epochs is unset).

accum_steps cached property

Python
accum_steps: int

Gradient accumulation steps.

precision cached property

Python
precision: str | None

Autocast precision mode.

max_grad_value cached property

Python
max_grad_value: float | None

Gradient value clipping threshold.

max_grad_norm cached property

Python
max_grad_norm: float | None

Gradient norm clipping threshold.

skip_nonfinite_grad cached property

Python
skip_nonfinite_grad: bool

Whether to skip optimizer updates when gradients are non-finite.

patience cached property

Python
patience: int | float

Early-stop patience in epoch mode.

progress property

Python
progress: float

Normalized training progress in [0, 1].

train_splits property

Python
train_splits: list[str]

Configured or inferred training split names.

evaluate_splits property

Python
evaluate_splits: list[str]

Configured or inferred evaluation split names.

checkpoint_interval property

Python
checkpoint_interval: int

Checkpoint cadence in optimizer steps (step mode) or epochs (epoch mode).

log_interval property

Python
log_interval: int

Step logging cadence.

__post_init__

Python
__post_init__() -> None

Hook called after __init__ by MetaRunner.

Source code in danling/runners/base_runner.py
Python
def __post_init__(self) -> None:
    """
    Hook called after `__init__` by `MetaRunner`.

    The base implementation saves the workspace metadata.
    """
    self.workspace.save_metadata()

auto_restore

Python
auto_restore() -> None

Auto-load resume/pretrained sources declared in config.

Precedence

config.resume > config.auto_resume > config.pretrained.

Source code in danling/runners/base_runner.py
Python
def auto_restore(self) -> None:
    """Auto-load resume/pretrained sources declared in config.

    Precedence:
        `config.resume` > `config.auto_resume` > `config.pretrained`.

    Nothing is loaded when no restore target resolves. A `"checkpoint"`
    target goes through `load_checkpoint`; any other kind goes through
    `load_pretrained`.
    """

    target = self._resolve_auto_restore_target()
    if target is None:
        return

    kind, source = target
    loader = self.load_checkpoint if kind == "checkpoint" else self.load_pretrained
    loader(source)

init_distributed

Python
init_distributed() -> None

Initialize the distributed environment.

The default is a no-op (single-process). Concrete runners override this hook to initialize the torch.distributed process group; see TorchRunner.init_distributed for the canonical specification.

Source code in danling/runners/base_runner.py
Python
def init_distributed(self) -> None:
    """
    Initialize the distributed environment.

    The default is a no-op (single-process). Concrete runners override
    this hook to initialize the torch.distributed process group; see
    [`TorchRunner.init_distributed`][danling.runners.TorchRunner.init_distributed]
    for the canonical specification.

    Returns:
        None. The base implementation contains no statements.
    """

init_checkpoint_manager

Python
init_checkpoint_manager() -> None

Bind the runner’s checkpoint manager.

The default is a no-op — BaseRunner.__init__ already binds the FileCheckpointManager. Concrete runners override this hook to swap in the backend-appropriate manager via set_checkpoint_manager(...); see TorchRunner.init_checkpoint_manager for the canonical specification.

Source code in danling/runners/base_runner.py
Python
def init_checkpoint_manager(self) -> None:
    """
    Bind the runner's checkpoint manager.

    The default is a no-op — `BaseRunner.__init__` already binds the
    `FileCheckpointManager`. Concrete runners override this hook to swap
    in the backend-appropriate manager via `set_checkpoint_manager(...)`;
    see
    [`TorchRunner.init_checkpoint_manager`][danling.runners.TorchRunner.init_checkpoint_manager]
    for the canonical specification.

    Returns:
        None. The base implementation contains no statements.
    """

init_fault_tolerance

Python
init_fault_tolerance() -> None

Initialize optional fault-tolerance runtime support.

Source code in danling/runners/base_runner.py
Python
def init_fault_tolerance(self) -> None:
    """
    Initialize optional fault-tolerance runtime support.

    **Side effects:** binds `self.ft` to a new `FaultTolerance` instance
    constructed with this runner as its only argument.
    """

    self.ft = FaultTolerance(self)

init_heartbeat

Python
init_heartbeat() -> None

Configure optional background heartbeat writer.

Source code in danling/runners/base_runner.py
Python
def init_heartbeat(self) -> None:
    """
    Configure optional background heartbeat writer.

    The runner holds no heartbeat logic itself; the call is delegated to
    `self.supervisor.init_heartbeat()`.
    """

    self.supervisor.init_heartbeat()

init_garbage_collection

Python
init_garbage_collection() -> None

Configure optional runner-managed Python GC pacing.

Source code in danling/runners/base_runner.py
Python
def init_garbage_collection(self) -> None:
    """
    Configure optional runner-managed Python GC pacing.

    The runner holds no GC logic itself; the call is delegated to
    `self.supervisor.init_garbage_collection()`.
    """

    self.supervisor.init_garbage_collection()

init_signal_handlers

Python
init_signal_handlers() -> None

Install runner-owned signal handlers for graceful preemption.

Source code in danling/runners/base_runner.py
Python
def init_signal_handlers(self) -> None:
    """
    Install runner-owned signal handlers for graceful preemption.

    The runner installs nothing directly; the call is delegated to
    `self.supervisor.init_signal_handlers()`.
    """

    self.supervisor.init_signal_handlers()

prepare_for_shutdown_checkpoint

Python
prepare_for_shutdown_checkpoint() -> None

Finalize runner state before writing a forced shutdown checkpoint.

Source code in danling/runners/base_runner.py
Python
def prepare_for_shutdown_checkpoint(self) -> None:
    """
    Finalize runner state before writing a forced shutdown checkpoint.

    This base implementation is a no-op: it has no body beyond its
    docstring and returns `None`.
    NOTE(review): presumably an override hook for backend runners — confirm
    against the subclasses and the supervisor's shutdown path.
    """

init_tensorboard

Python
init_tensorboard(*args, **kwargs) -> None

Initialize tensorboard writer.

Source code in danling/runners/base_runner.py
Python
@on_main_process
def init_tensorboard(self, *args, **kwargs) -> None:
    """
    Initialize tensorboard writer.

    This base implementation creates no writer. It ignores `*args` and
    `**kwargs` and only emits a `RuntimeWarning` stating that this runner
    does not initialize a tensorboard writer.

    Args:
        *args: Unused by the base implementation.
        **kwargs: Unused by the base implementation.
    """

    warn(
        "tensorboard is enabled, but this runner does not initialize a tensorboard writer",
        RuntimeWarning,
        stacklevel=2,
    )

init_wandb

Python
init_wandb(*args, **kwargs) -> None

Initialize Weights & Biases run for scalar logging.

Source code in danling/runners/base_runner.py
Python
@on_main_process
def init_wandb(self, *args, **kwargs) -> None:
    """
    Start a Weights & Biases run for scalar logging.

    Keyword arguments passed by the caller always win. Any of `project`,
    `entity`, `group`, `name`, `job_type`, `tags`, `dir`, `mode` and
    `config` that the caller leaves out is filled from `self.config.wandb`;
    `project`, `group`, `name` and `dir` additionally fall back to
    `self.workspace.lineage`, `self.workspace.experiment`, `self.id` and
    `self.workspace.dir` when the configured value is falsy.

    Args:
        *args: Forwarded to `wandb.init`.
        **kwargs: Forwarded to `wandb.init` after the defaults are filled in.

    Raises:
        RuntimeError: the `wandb` package cannot be imported.

    **Side effects:** binds the object returned by `wandb.init` to
    `self.wandb`.
    """

    try:
        import wandb
    except ImportError as exc:
        raise RuntimeError("wandb is enabled, but the `wandb` package is not installed") from exc

    settings = self.config.wandb

    def unset(key: str) -> bool:
        # True when the caller did not pass `key` explicitly.
        return key not in kwargs

    if unset("project"):
        kwargs["project"] = settings.get("project") or self.workspace.lineage
    # Purely optional keys are only forwarded when the config sets them.
    if unset("entity") and settings.get("entity") is not None:
        kwargs["entity"] = settings.entity
    if unset("group"):
        kwargs["group"] = settings.get("group") or self.workspace.experiment
    if unset("name"):
        kwargs["name"] = settings.get("name") or self.id
    if unset("job_type") and settings.get("job_type") is not None:
        kwargs["job_type"] = settings.job_type

    tags = settings.get("tags")
    if unset("tags") and tags is not None:
        # A bare string is one tag, not an iterable of characters.
        if isinstance(tags, str):
            kwargs["tags"] = [tags]
        else:
            kwargs["tags"] = list(tags)

    if unset("dir"):
        kwargs["dir"] = settings.get("dir") or self.workspace.dir
    if unset("mode") and settings.get("mode") is not None:
        kwargs["mode"] = settings.mode
    if unset("config"):
        kwargs["config"] = self.config.dict()

    run = cast(Any, wandb).init(*args, **kwargs)
    self.wandb = run

set_seed

Python
set_seed(
    seed: int | None = None, bias: int | bool | None = None
) -> int

Set Python/NumPy RNG seeds and snapshot the RNG state.

Parameters:

Name Type Description Default
seed
int | None

Base seed. Defaults to self.config.seed.

None
bias
int | bool | None

Optional per-process bias. None uses self.rank.

None

Returns:

Type Description
int

The process-local seed after applying bias.

Source code in danling/runners/base_runner.py
Python
def set_seed(self, seed: int | None = None, bias: int | bool | None = None) -> int:
    """Set python/numpy RNG seeds and snapshot RNG state.

    Args:
        seed: Base seed. Defaults to `self.config.seed`.
        bias: Optional per-process bias. `None` uses `self.rank`.

    Returns:
        The process-local seed after applying bias.
    """

    base_seed = self.config.seed if seed is None else seed
    if base_seed is None:
        raise ValueError("cannot set seed: no seed is configured and no seed argument was provided")
    base_seed = int(base_seed)

    self.config.seed = base_seed

    process_seed = base_seed
    if bias is None:
        bias = self.rank
    if bias:
        process_seed += int(bias)

    random.seed(process_seed)
    if np_random is not None:
        np_random.seed(process_seed)

    self.rng_state.python = random.getstate()
    self.rng_state.numpy = np_random.get_state() if np_random is not None else None
    return process_seed

set_deterministic

Python
set_deterministic() -> None

Enable deterministic behavior in subclass-specific backends.

Source code in danling/runners/base_runner.py
Python
def set_deterministic(self) -> None:
    """
    Enable deterministic behavior in subclass-specific backends.

    This base implementation is a no-op: it has no body beyond its
    docstring and returns `None`.
    """

train

Python
train(*args, **kwargs)

Run top-level training workflow.

Source code in danling/runners/base_runner.py
Python
def train(self, *args, **kwargs):
    """
    Run top-level training workflow.

    Raises:
        NotImplementedError: Always; this base implementation is a stub.
    """

    raise NotImplementedError

train_epochs

Python
train_epochs(*args, **kwargs)

Run epoch-mode training workflow.

Source code in danling/runners/base_runner.py
Python
def train_epochs(self, *args, **kwargs):
    """
    Run epoch-mode training workflow.

    Raises:
        NotImplementedError: Always; this base implementation is a stub.
    """

    raise NotImplementedError

train_epoch

Python
train_epoch(*args, **kwargs)

Run one training epoch on a split.

Source code in danling/runners/base_runner.py
Python
def train_epoch(self, *args, **kwargs):
    """
    Run one training epoch on a split.

    Raises:
        NotImplementedError: Always; this base implementation is a stub.
    """

    raise NotImplementedError

train_steps

Python
train_steps(*args, **kwargs)

Run step-mode training workflow.

Source code in danling/runners/base_runner.py
Python
def train_steps(self, *args, **kwargs):
    """
    Run step-mode training workflow.

    Raises:
        NotImplementedError: Always; this base implementation is a stub.
    """

    raise NotImplementedError

train_step

Python
train_step(*args, **kwargs)

Run one training micro-step.

Concrete runners define the override contract; see TorchRunner.train_step for the canonical specification.

Source code in danling/runners/base_runner.py
Python
def train_step(self, *args, **kwargs):
    """
    Run one training micro-step.

    Concrete runners define the override contract; see
    [`TorchRunner.train_step`][danling.runners.TorchRunner.train_step] for
    the canonical specification.

    Raises:
        NotImplementedError: Always; this base implementation is a stub.
    """

    raise NotImplementedError

backward

Python
backward(loss, *args, **kwargs) -> None

Run backward pass for one micro-step loss.

Source code in danling/runners/base_runner.py
Python
def backward(self, loss, *args, **kwargs) -> None:
    """
    Run backward pass for one micro-step loss.

    Args:
        loss: Loss of the current micro-step; unused by this stub.

    Raises:
        NotImplementedError: Always; this base implementation is a stub.
    """

    raise NotImplementedError

step

Python
step(*args, **kwargs) -> None

Advance optimizer/scheduler state when accumulation is ready.

Source code in danling/runners/base_runner.py
Python
def step(self, *args, **kwargs) -> None:
    """
    Advance optimizer/scheduler state when accumulation is ready.

    Raises:
        NotImplementedError: Always; this base implementation is a stub.
    """

    raise NotImplementedError

evaluate

Python
evaluate(*args, **kwargs)

Run top-level evaluation workflow.

Source code in danling/runners/base_runner.py
Python
def evaluate(self, *args, **kwargs):
    """
    Run top-level evaluation workflow.

    Raises:
        NotImplementedError: Always; this base implementation is a stub.
    """

    raise NotImplementedError

evaluate_epoch

Python
evaluate_epoch(*args, **kwargs)

Run one full evaluation epoch on a split.

Source code in danling/runners/base_runner.py
Python
def evaluate_epoch(self, *args, **kwargs):
    """
    Run one full evaluation epoch on a split.

    Raises:
        NotImplementedError: Always; this base implementation is a stub.
    """

    raise NotImplementedError

evaluate_steps

Python
evaluate_steps(*args, **kwargs)

Run bounded evaluation steps on a split.

Source code in danling/runners/base_runner.py
Python
def evaluate_steps(self, *args, **kwargs):
    """
    Run bounded evaluation steps on a split.

    Raises:
        NotImplementedError: Always; this base implementation is a stub.
    """

    raise NotImplementedError

evaluate_step

Python
evaluate_step(*args, **kwargs)

Run one evaluation step.

Concrete runners define the override contract; see TorchRunner.evaluate_step for the canonical specification.

Source code in danling/runners/base_runner.py
Python
def evaluate_step(self, *args, **kwargs):
    """
    Run one evaluation step.

    Concrete runners define the override contract; see
    [`TorchRunner.evaluate_step`][danling.runners.TorchRunner.evaluate_step]
    for the canonical specification.

    Raises:
        NotImplementedError: Always; this base implementation is a stub.
    """

    raise NotImplementedError

infer

Python
infer(*args, **kwargs)

Run top-level inference workflow.

Source code in danling/runners/base_runner.py
Python
def infer(self, *args, **kwargs):
    """
    Run top-level inference workflow.

    Raises:
        NotImplementedError: Always; this base implementation is a stub.
    """

    raise NotImplementedError

infer_step

Python
infer_step(*args, **kwargs)

Run one inference step.

Concrete runners define the override contract; see TorchRunner.infer_step for the canonical specification.

Source code in danling/runners/base_runner.py
Python
def infer_step(self, *args, **kwargs):
    """
    Run one inference step.

    Concrete runners define the override contract; see
    [`TorchRunner.infer_step`][danling.runners.TorchRunner.infer_step] for
    the canonical specification.

    Raises:
        NotImplementedError: Always; this base implementation is a stub.
    """

    raise NotImplementedError

unwrap

Python
unwrap(model: Any) -> Any

Return an unwrapped model object.

Source code in danling/runners/base_runner.py
Python
def unwrap(self, model: Any) -> Any:
    """
    Return an unwrapped model object.

    This base implementation performs no unwrapping: it returns `model`
    itself, unchanged.

    Args:
        model: Model object to unwrap.

    Returns:
        The same object that was passed in.
    """

    return model

state_dict

Python
state_dict(cls: type = dict) -> Mapping

Build the backend-neutral runner checkpoint payload.

The base payload contains semantic runner config, mutable runner state, RNG snapshots, and dataloader resume state. Backend runners extend this payload with model/optimizer/scheduler state.

Called when: checkpoint managers build a payload for save_checkpoint, and fault-tolerance callbacks need a runner state snapshot.

Parameters:

Name Type Description Default
cls
type

Mapping factory used for nested payloads. Backends may pass dict-like containers to preserve their serialization format.

dict

Returns:

Type Description
Mapping

Mapping with runner, state, and dataloaders keys.

Side effects: snapshots Python and NumPy RNG state into self.rng_state before exporting.

Do not

  • Mutate model or optimizer state here.
  • Drop the runner config payload; resume validation depends on it.
  • Override without calling super() unless you fully replace the checkpoint format.
Source code in danling/runners/base_runner.py
Python
def state_dict(self, cls: type = dict) -> Mapping:
    """
    Assemble the backend-neutral checkpoint payload for this runner.

    The payload carries the semantic runner config, the mutable runner
    state, RNG snapshots, and dataloader resume state. Backend runners add
    model/optimizer/scheduler state on top of it.

    **Called when:** checkpoint managers build a payload for
    `save_checkpoint`, and fault-tolerance callbacks need a runner state
    snapshot.

    Args:
        cls: Mapping factory used for the payload and its nested `state`
            and `dataloaders` entries. Backends may pass `dict`-like
            containers to preserve their serialization format.

    Returns:
        Mapping with `runner`, `state`, and `dataloaders` keys.

    **Side effects:** snapshots Python and NumPy RNG state into
    `self.rng_state` before exporting.

    !!! danger "Do not"
        - Mutate model or optimizer state here.
        - Drop the `runner` config payload; resume validation depends on it.
        - Override without calling `super()` unless you fully replace the
          checkpoint format.
    """

    # Refresh the RNG snapshot first so the exported state reflects "now".
    self.rng_state.python = random.getstate()
    self.rng_state.numpy = None if np_random is None else np_random.get_state()

    def wrap(payload):
        # With the default factory, nested payloads are passed through untouched.
        return payload if cls is dict else cls(payload)

    state = wrap(self.state.state_dict())
    dataloader_state = wrap(self.dataloaders.state_dict())
    return cls(runner=self.config.dict(), state=state, dataloaders=dataloader_state)

load_state_dict

Python
load_state_dict(checkpoint: Mapping[str, Any]) -> None

Restore backend-neutral runner state from a checkpoint payload.

This restores semantic runner state and Python/NumPy RNG state. Model, EMA, optimizer, scheduler, and dataloader component loading is owned by load_checkpoint.

Called when: load_checkpoint restores a full checkpoint, and fault-tolerance callbacks receive a runner state payload.

Parameters:

Name Type Description Default
checkpoint
Mapping[str, Any]

Mapping produced by state_dict or a backend-specific superset of that payload.

required

Raises:

Type Description
ValueError

checkpoint runner config differs semantically from the current runner config.

Side effects: updates self.state, self.train_state, self.elastic_state, self.rng_state, and process RNG state.

Do not

  • Load model/optimizer/scheduler state here; use component loaders through load_checkpoint.
  • Suppress semantic config diffs unless you also update the resume policy deliberately.
Source code in danling/runners/base_runner.py
Python
def load_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
    """
    Restore the backend-neutral part of a checkpoint payload.

    Only semantic runner state and Python/NumPy RNG state are restored
    here. Model, EMA, optimizer, scheduler, and dataloader component
    loading is owned by `load_checkpoint`.

    **Called when:** `load_checkpoint` restores a full checkpoint, and
    fault-tolerance callbacks receive a runner state payload.

    Args:
        checkpoint: Mapping produced by `state_dict` or a backend-specific
            superset of that payload.

    Raises:
        ValueError: checkpoint runner config differs semantically from the
            current runner config.

    **Side effects:** updates `self.state`, `self.train_state`,
    `self.elastic_state`, `self.rng_state`, and process RNG state.

    !!! danger "Do not"
        - Load model/optimizer/scheduler state here; use component loaders
          through `load_checkpoint`.
        - Suppress semantic config diffs unless you also update the resume
          policy deliberately.
    """

    saved_config = checkpoint.get("runner")
    if saved_config is not None:
        # Compare canonical forms so only semantic differences block the load.
        expected = RunnerConfig(saved_config).canonical()
        actual = self.config.canonical()
        mismatch = NestedDict(expected).diff(actual).dict()
        if mismatch:
            raise ValueError(
                "cannot load checkpoint: runner config is semantically different from current config; "
                f"start a new experiment or align config. diff={mismatch}"
            )

    payload = checkpoint.get("state") or {}
    self.state.load_state_dict(dict(payload))

    saved_rng = payload.get("rng")
    if not isinstance(saved_rng, Mapping):
        # No RNG section in the payload: leave the process RNGs untouched.
        return

    # The generators are re-seeded from `self.rng_state`, not from the raw payload.
    if "python" in saved_rng and self.rng_state.python is not None:
        random.setstate(self.rng_state.python)
    if np_random is not None and "numpy" in saved_rng and self.rng_state.numpy is not None:
        np_random.set_state(self.rng_state.numpy)

save_checkpoint

Python
save_checkpoint(
    name: str = "latest",
    epochs: int | None = None,
    save_best: bool = True,
    last_step: bool = False,
    force: bool = False,
) -> None

Persist runner state through the active checkpoint manager.

Backend collective semantics are owned by checkpoint_manager.is_collective. File-style managers save on the main process only; collective managers require every rank to enter this method together.

Called when: training loops hit checkpoint cadence, final last_step saves run, or the supervisor handles a shutdown signal.

Parameters:

Name Type Description Default
name
str

Logical checkpoint alias, usually "latest" or "best".

'latest'
epochs
int | None

Epoch index used for history checkpoint naming. Defaults to self.train_state.epoch.

None
save_best
bool

Whether to publish/update the best-checkpoint alias when self.is_best is true.

True
last_step
bool

Whether this save is the final save for the run.

False
force
bool

Bypass cadence checks inside the manager.

False

Side effects: delegates to self.checkpoint_manager.save_checkpoint(...).

Do not

  • Add a main-process guard around calls to this method; DCP-style managers need all ranks to participate.
  • Bypass the checkpoint manager for normal runner checkpoints.
Source code in danling/runners/base_runner.py
Python
def save_checkpoint(
    self,
    name: str = "latest",
    epochs: int | None = None,
    save_best: bool = True,
    last_step: bool = False,
    force: bool = False,
) -> None:
    """
    Persist runner state through the active checkpoint manager.

    Backend collective semantics are owned by
    `checkpoint_manager.is_collective`. File-style managers save on the
    main process only; collective managers require every rank to enter this
    method together.

    **Called when:** training loops hit checkpoint cadence, final
    `last_step` saves run, or the supervisor handles a shutdown signal.

    Args:
        name: Logical checkpoint alias, usually `"latest"` or `"best"`.
        epochs: Epoch index used for history checkpoint naming. Defaults
            to `self.train_state.epoch`.
        save_best: Whether to publish/update the best-checkpoint alias
            when `self.is_best` is true.
        last_step: Whether this save is the final save for the run.
        force: Bypass cadence checks inside the manager.

    **Side effects:** delegates to
    `self.checkpoint_manager.save_checkpoint(...)`.

    !!! danger "Do not"
        - Add a main-process guard around calls to this method; DCP-style
          managers need all ranks to participate.
        - Bypass the checkpoint manager for normal runner checkpoints.
    """

    if not (self.is_main_process or self.checkpoint_manager.is_collective):
        return
    epochs = self.train_state.epoch if epochs is None else epochs
    self.checkpoint_manager.save_checkpoint(
        name=name,
        epochs=epochs,
        save_best=save_best,
        last_step=last_step,
        force=force,
    )

save_seed_checkpoint

Python
save_seed_checkpoint(name: str = 'seed') -> None

Persist an initialization checkpoint for cross-topology experiments.

Seed checkpoints are intended to be created before training advances, then loaded with checkpoint.load_only=True or resume/pretrained when comparing different parallel layouts from the same initial model state. They are saved through the final-checkpoint path, so checkpoint.last_save_model_only=True intentionally applies.

Source code in danling/runners/base_runner.py
Python
def save_seed_checkpoint(self, name: str = "seed") -> None:
    """
    Write an initialization checkpoint for cross-topology experiments.

    A seed checkpoint is meant to be taken before training advances and
    later loaded with `checkpoint.load_only=True` or `resume`/`pretrained`
    so that different parallel layouts start from the same initial model
    state. The save goes through the final-checkpoint path, so
    `checkpoint.last_save_model_only=True` intentionally applies.

    Args:
        name: Checkpoint alias passed to `save_checkpoint`.

    **Side effects:** emits a `RuntimeWarning` when the epoch or global
    step is already non-zero, then forces a save with `epochs=0`.
    """
    progress = self.train_state
    already_advanced = progress.global_step != 0 or progress.epoch != 0
    if already_advanced:
        # Warn but still save: the caller asked for the checkpoint explicitly.
        warn(
            "save_seed_checkpoint() is intended before training advances; "
            f"current epoch={self.train_state.epoch}, global_step={self.train_state.global_step}",
            RuntimeWarning,
            stacklevel=2,
        )
    self.save_checkpoint(name=name, epochs=0, save_best=False, last_step=True, force=True)

load_checkpoint

Python
load_checkpoint(
    checkpoint: Mapping | bytes | str | PathLike,
    *args: Any,
    **kwargs: Any
) -> None

Restore a full runner checkpoint.

This is the full-state restore path: runtime state, model/EMA, optimizer, scheduler, and dataloader progress are restored when present and applicable to the current runner.

Called when: users resume a run explicitly, auto_restore selects a resume source, from_checkpoint constructs a runner, or fault-tolerance callbacks restore a full runner payload.

Parameters:

Name Type Description Default
checkpoint
Mapping | bytes | str | PathLike

In-memory checkpoint mapping or backend-specific path.

required
*args
Any

Forwarded to read_checkpoint and component loaders.

()
**kwargs
Any

Forwarded to read_checkpoint and component loaders.

{}

Raises:

Type Description
ValueError

checkpoint is missing required component state for an initialized component, or config validation fails.

Side effects: updates runner state, model/EMA weights, optimizer, scheduler, dataloader progress, and config.resume for path inputs.

Do not

  • Use this for model-only finetuning payloads; use load_pretrained instead.
  • Override just to support a new path type; prefer overriding read_checkpoint.
Source code in danling/runners/base_runner.py
Python
def load_checkpoint(
    self,
    checkpoint: Mapping | bytes | str | os.PathLike,
    *args: Any,
    **kwargs: Any,
) -> None:
    """
    Restore a full runner checkpoint.

    This is the full-state restore path: runtime state, model/EMA,
    optimizer, scheduler, and dataloader progress are restored when present
    and applicable to the current runner.

    **Called when:** users resume a run explicitly, `auto_restore` selects
    a resume source, `from_checkpoint` constructs a runner, or
    fault-tolerance callbacks restore a full runner payload.

    Args:
        checkpoint: In-memory checkpoint mapping or backend-specific path.
        *args: Forwarded to `read_checkpoint` and component loaders.
        **kwargs: Forwarded to `read_checkpoint` and component loaders.

    Raises:
        ValueError: checkpoint is missing required component state for an
            initialized component, or config validation fails.

    **Side effects:** updates runner state, model/EMA weights, optimizer,
    scheduler, dataloader progress, and `config.resume` for path inputs.

    !!! danger "Do not"
        - Use this for model-only finetuning payloads; use
          `load_pretrained` instead.
        - Override just to support a new path type; prefer overriding
          `read_checkpoint`.
    """

    ckpt = self.read_checkpoint(checkpoint, *args, **kwargs)
    excluded_paths = self.checkpoint_exclude_from_loading()
    if excluded_paths:
        if self._is_top_level_checkpoint_excluded(excluded_paths, "runner"):
            warn(
                "`checkpoint.exclude_from_loading` contains 'runner'; "
                "semantic runner config validation will be skipped for this load.",
                RuntimeWarning,
                stacklevel=2,
            )
        ckpt = self._filter_checkpoint_for_loading(ckpt, excluded_paths)

    self.load_state_dict(ckpt)
    if not self._is_top_level_checkpoint_excluded(excluded_paths, "model", "model_parts", "module"):
        if "model" in ckpt:
            self.load_model(ckpt["model"], *args, **kwargs)
        elif "model_parts" in ckpt:
            self.load_model(ckpt["model_parts"], *args, **kwargs)
        elif self.model is not None:
            raise ValueError(
                "cannot restore model: checkpoint has no model state\n"
                "Use `load_pretrained` only for model-only checkpoints with model/ema payloads"
            )
    if not self._is_top_level_checkpoint_excluded(excluded_paths, "ema") and (
        self.ema is not None or "ema" in ckpt
    ):
        self.load_ema(ckpt.get("ema"), *args, **kwargs)
    if not self._is_top_level_checkpoint_excluded(excluded_paths, "optimizer") and (
        self.optimizer is not None or "optimizer" in ckpt
    ):
        self.load_optimizer(ckpt.get("optimizer"), *args, **kwargs)
    if not self._is_top_level_checkpoint_excluded(excluded_paths, "scheduler") and (
        self.scheduler is not None or "scheduler" in ckpt
    ):
        self.load_scheduler(ckpt.get("scheduler"), *args, **kwargs)
    if not self._is_top_level_checkpoint_excluded(excluded_paths, "dataloaders") and (
        self.dataloaders or "dataloaders" in ckpt
    ):
        self.load_dataloaders(ckpt.get("dataloaders"))
    if isinstance(checkpoint, (str, bytes, os.PathLike)):
        self.config.resume = os.fsdecode(checkpoint)

load_model

Python
load_model(
    state_dict: Mapping[str, Any], *args, **kwargs
) -> None

Load model state.

Source code in danling/runners/base_runner.py
Python
def load_model(self, state_dict: Mapping[str, Any], *args, **kwargs) -> None:
    """
    Load model state into the unwrapped model.

    Args:
        state_dict: Model state to load.
        *args: Forwarded to the model's `load_state_dict`.
        **kwargs: Forwarded to the model's `load_state_dict`.

    Raises:
        ValueError: `self.model` is `None`.
    """
    if self.model is None:
        raise ValueError("cannot restore model: model is not initialized")
    # Load into whatever `unwrap` returns rather than into `self.model` directly.
    target = self.unwrap(self.model)
    target.load_state_dict(state_dict, *args, **kwargs)

load_ema

Python
load_ema(
    state_dict: Mapping[str, Any] | None, *args, **kwargs
) -> None

Load EMA state.

Source code in danling/runners/base_runner.py
Python
def load_ema(self, state_dict: Mapping[str, Any] | None, *args, **kwargs) -> None:
    """Load EMA state."""
    if self.ema is None:
        return
    state_dict = self._require_checkpoint_component_state("ema", state_dict)
    self.ema.load_state_dict(state_dict, *args, **kwargs)

load_optimizer

Python
load_optimizer(
    state_dict: Mapping[str, Any] | None, *args, **kwargs
) -> None

Load optimizer state.

Source code in danling/runners/base_runner.py
Python
def load_optimizer(self, state_dict: Mapping[str, Any] | None, *args, **kwargs) -> None:
    """Load optimizer state."""
    if self.optimizer is None:
        return
    state_dict = self._require_checkpoint_component_state("optimizer", state_dict)
    self.optimizer.load_state_dict(state_dict, *args, **kwargs)

load_scheduler

Python
load_scheduler(
    state_dict: Mapping[str, Any] | None, *args, **kwargs
) -> None

Load scheduler state.

Source code in danling/runners/base_runner.py
Python
def load_scheduler(self, state_dict: Mapping[str, Any] | None, *args, **kwargs) -> None:
    """Load scheduler state."""
    if self.scheduler is None:
        return
    state_dict = self._require_checkpoint_component_state("scheduler", state_dict)
    self.scheduler.load_state_dict(state_dict, *args, **kwargs)

load_dataloaders

Python
load_dataloaders(
    state_dict: Mapping[str, Any] | None,
) -> None

Load dataloader progress state when the current runner has matching loaders.

Source code in danling/runners/base_runner.py
Python
def load_dataloaders(self, state_dict: Mapping[str, Any] | None) -> None:
    """Load dataloader progress state when the current runner has matching loaders."""
    if state_dict is None:
        return
    self.dataloaders.load_state_dict(state_dict)

load_pretrained

Python
load_pretrained(
    checkpoint: Mapping | bytes | str | PathLike,
    *args: Any,
    **kwargs: Any
) -> None

Load model weights only from a checkpoint payload or path.

When the checkpoint payload provides EMA weights (ema), EMA is preferred as the pretrained source. Otherwise model is used, falling back to model_parts.

Called when: users initialize from pretrained weights, or auto_restore selects config.pretrained.

Parameters:

Name Type Description Default
checkpoint
Mapping | bytes | str | PathLike

In-memory payload or backend-specific path containing ema, model, or model_parts.

required
*args
Any

Forwarded to read_checkpoint and load_model.

()
**kwargs
Any

Forwarded to read_checkpoint and load_model.

{}

Raises:

Type Description
ValueError

model is not initialized, or the payload has no usable model/EMA state.

Side effects: loads model weights and updates config.pretrained for path inputs. Optimizer, scheduler, runner state, and dataloaders are intentionally untouched.

Do not

  • Use this to resume training state; use load_checkpoint for full-state restore.
  • Load optimizer/scheduler state in this path.
Source code in danling/runners/base_runner.py
Python
def load_pretrained(
    self,
    checkpoint: Mapping | bytes | str | os.PathLike,
    *args: Any,
    **kwargs: Any,
) -> None:
    """
    Load model weights only from a checkpoint payload or path.

    When checkpoint payload provides EMA weights (`ema`), EMA is preferred as
    the pretrained source. Otherwise `model` is used.

    **Called when:** users initialize from pretrained weights, or
    `auto_restore` selects `config.pretrained`.

    Args:
        checkpoint: In-memory payload or backend-specific path containing
            `ema`, `model`, or `model_parts`.
        *args: Forwarded to `read_checkpoint` and `load_model`.
        **kwargs: Forwarded to `read_checkpoint` and `load_model`.

    Raises:
        ValueError: model is not initialized, or the payload has no usable
            model/EMA state.

    **Side effects:** loads model weights and updates `config.pretrained`
    for path inputs. Optimizer, scheduler, runner state, and dataloaders
    are intentionally untouched.

    !!! danger "Do not"
        - Use this to resume training state; use `load_checkpoint` for
          full-state restore.
        - Load optimizer/scheduler state in this path.
    """

    if self.model is None:
        raise ValueError("cannot load pretrained weights: model is not initialized")

    ckpt = self.read_checkpoint(checkpoint, *args, **kwargs)
    if ckpt.get("ema") is not None:
        self.load_model(ckpt["ema"], *args, **kwargs)
    elif "model" in ckpt:
        self.load_model(ckpt["model"], *args, **kwargs)
    elif "model_parts" in ckpt:
        self.load_model(ckpt["model_parts"], *args, **kwargs)
    else:
        raise ValueError(
            "cannot load pretrained weights: checkpoint has no EMA or model state\n"
            "Use `load_checkpoint` for full checkpoint restore instead of `load_pretrained`"
        )
    if isinstance(checkpoint, (str, bytes, os.PathLike)):
        self.config.pretrained = os.fsdecode(checkpoint)
    else:
        self.config.pretrained = None

from_checkpoint classmethod

Python
from_checkpoint(
    checkpoint: Mapping | bytes | str | PathLike,
    *args,
    **kwargs
) -> BaseRunner

Instantiate runner from checkpoint config and restore full state.

Source code in danling/runners/base_runner.py
Python
@classmethod
def from_checkpoint(cls, checkpoint: Mapping | bytes | str | os.PathLike, *args, **kwargs) -> BaseRunner:
    """
    Build a runner from the config stored in a checkpoint, then restore it.

    The restore-related config fields (`resume`, `auto_resume`,
    `pretrained`) are cleared before the runner is constructed; the
    checkpoint is then loaded explicitly through `load_checkpoint`.

    Args:
        checkpoint: Checkpoint mapping or path, as accepted by
            `read_config` and `load_checkpoint`.
        *args: Forwarded to `read_config` and `load_checkpoint`.
        **kwargs: Forwarded to `read_config` and `load_checkpoint`.

    Returns:
        The newly constructed runner with the checkpoint loaded.
    """

    restored_config = cls.read_config(checkpoint, *args, **kwargs)
    restored_config.resume = None
    restored_config.auto_resume = False
    restored_config.pretrained = None

    instance = cls(restored_config)
    instance.load_checkpoint(checkpoint, *args, **kwargs)
    return instance

read_config classmethod

Python
read_config(
    checkpoint: Mapping | bytes | str | PathLike,
    *args,
    **kwargs
) -> RunnerConfig

Read runner config from checkpoint mapping or file path.

Note

BaseRunner only accepts file checkpoints for path input. Backend-specific directory checkpoints must be handled in subclasses.

Source code in danling/runners/base_runner.py
Python
@classmethod
def read_config(
    cls,
    checkpoint: Mapping | bytes | str | os.PathLike,
    *args,
    **kwargs,
) -> RunnerConfig:
    """
    Extract the runner config from a checkpoint mapping or checkpoint file.

    Note:
        BaseRunner only accepts file checkpoints for path input.
        Backend-specific directory checkpoints must be handled in subclasses.

    Args:
        checkpoint: In-memory checkpoint mapping, or a path to a checkpoint
            file.
        *args: Forwarded to `load` for path inputs.
        **kwargs: Forwarded to `load` for path inputs; `map_location` and
            `weights_only` are always overridden.

    Returns:
        A `RunnerConfig` built from the checkpoint's `runner` entry.

    Raises:
        ValueError: the input is neither a mapping nor a path, the path is
            not a file, or the checkpoint has no `runner` key.
    """

    if isinstance(checkpoint, Mapping):
        payload = checkpoint
    else:
        if not isinstance(checkpoint, (bytes, str, os.PathLike)):
            raise ValueError(
                "invalid checkpoint input: expected a mapping or path, "
                f"got {type(checkpoint).__name__}: {checkpoint!r}"
            )
        if not os.path.isfile(os.fspath(checkpoint)):
            raise ValueError(
                f"cannot read config from checkpoint path for {cls.__name__}: path must be a file; "
                "use a backend-specific runner for directory-style checkpoints"
            )
        # Copy the caller's kwargs and force the two load options.
        # NOTE(review): `weights_only=False` presumably allows full unpickling
        # in the underlying loader — only read trusted checkpoints; confirm.
        load_options = {**kwargs, "map_location": "cpu", "weights_only": False}
        payload = load(checkpoint, *args, **load_options)

    if "runner" not in payload:
        raise ValueError(
            "cannot read runner config: checkpoint is missing key 'runner'; "
            "use from_pretrained(...) for model-only checkpoints"
        )
    return RunnerConfig(payload["runner"])

from_pretrained classmethod

Python
from_pretrained(
    config: RunnerConfig | Mapping[str, Any],
    checkpoint: Mapping | bytes | str | PathLike,
    *args,
    **kwargs
) -> BaseRunner

Build a runner from config and load model weights only.

Source code in danling/runners/base_runner.py
Python
@classmethod
def from_pretrained(
    cls,
    config: RunnerConfig | Mapping[str, Any],
    checkpoint: Mapping | bytes | str | os.PathLike,
    *args,
    **kwargs,
) -> BaseRunner:
    """
    Build a runner from config and load model weights only.

    Args:
        config: Runner configuration; it is wrapped in `RunnerConfig` before use.
        checkpoint: In-memory checkpoint mapping or checkpoint path.
        *args: Forwarded to `load_pretrained`.
        **kwargs: Forwarded to `load_pretrained`.

    Returns:
        A runner constructed from the config with weights loaded through
        `load_pretrained`.
    """

    runner_config = RunnerConfig(config)
    # Reset the config-level restore fields; weights come from the explicit
    # `load_pretrained` call below.
    for field, value in (("resume", None), ("auto_resume", False), ("pretrained", None)):
        setattr(runner_config, field, value)
    instance = cls(runner_config)
    instance.load_pretrained(checkpoint, *args, **kwargs)
    return instance

read_checkpoint

Python
read_checkpoint(
    checkpoint: Mapping | bytes | str | PathLike,
    *args,
    **kwargs
) -> Mapping[str, Any]

Normalize checkpoint input into an in-memory mapping payload.

Source code in danling/runners/base_runner.py
Python
def read_checkpoint(self, checkpoint: Mapping | bytes | str | os.PathLike, *args, **kwargs) -> Mapping[str, Any]:
    """
    Normalize checkpoint input into an in-memory mapping payload.

    Args:
        checkpoint: A mapping (returned unchanged) or a path handed to `load`.
        *args: Forwarded to `load` for path input.
        **kwargs: Forwarded to `load` for path input; `map_location` and
            `weights_only` are always overridden.

    Raises:
        ValueError: `checkpoint` is neither a path nor a mapping.
    """
    if not isinstance(checkpoint, (bytes, str, os.PathLike)):
        if isinstance(checkpoint, Mapping):
            return checkpoint
        raise ValueError(
            f"invalid checkpoint input: expected a mapping or path, got {type(checkpoint).__name__}: {checkpoint!r}"
        )

    # Path input: always load onto CPU with full unpickling enabled.
    load_options = {**kwargs, "map_location": "cpu", "weights_only": False}
    return load(checkpoint, *args, **load_options)

save

Python
save(
    obj: Any,
    file: PathStr,
    main_process_only: bool = True,
    *args,
    **kwargs
) -> File

Save an object with optional main-process guard.

Source code in danling/runners/base_runner.py
Python
def save(self, obj: Any, file: PathStr, main_process_only: bool = True, *args, **kwargs) -> File:
    """
    Save an object with optional main-process guard.

    Args:
        obj: Object to persist.
        file: Destination handed to the module-level `save`.
        main_process_only: When true, only the main process writes.
        *args: Forwarded to the module-level `save`.
        **kwargs: Forwarded to the module-level `save`.

    Returns:
        The result of the module-level `save`, or `file` unchanged when the
        write is skipped on a non-main process.
    """

    # Skip the write only when the guard is on and this is not the main process.
    if main_process_only and not self.is_main_process:
        return file
    return save(obj, file, *args, **kwargs)

close

Python
close(timeout: float | None = None) -> bool

Finalize checkpoint/log/writer resources before shutdown.

Source code in danling/runners/base_runner.py
Python
def close(self, timeout: float | None = None) -> bool:
    """
    Finalize checkpoint/log/writer resources before shutdown.

    Args:
        timeout: Value passed to `checkpoint_manager.close`. When `None`, the
            config entry `checkpoint.wait_timeout` is used.

    Returns:
        `False` when the checkpoint manager reports that it did not drain
        (the remaining resources are left open in that case); otherwise the
        value returned by `checkpoint_manager.close`.

    Raises:
        Exception: Whatever `checkpoint_manager.close` raised, re-raised after
            the remaining resources have been released.
    """

    wait_timeout = self.config.get("checkpoint.wait_timeout") if timeout is None else timeout

    deferred_error: Exception | None = None
    drained = True
    try:
        drained = self.checkpoint_manager.close(timeout=wait_timeout)
    except Exception as exc:
        # Hold the failure so the teardown below still runs; it is re-raised at the end.
        deferred_error = exc

    if deferred_error is None and not drained:
        warn("runner close: timed out while draining async checkpoints", RuntimeWarning, stacklevel=2)
        return False

    self.supervisor.restore_signal_handlers()

    active_writer = self.writer
    if active_writer is not None:
        active_writer.flush()
        active_writer.close()
        self.writer = None

    if self.wandb is not None:
        self.wandb.finish()

    self.workspace.close()
    self.supervisor.close()
    if self.ft is not None:
        self.ft.close()

    if deferred_error is None:
        return drained
    raise deferred_error

DeepSpeedRunner

Bases: TorchRunner

DeepSpeed-backed runner focused on ZeRO-1/2 training flows.

Use this runner when DeepSpeed should own the training engine and optimizer update while DanLing still owns the outer lifecycle: dataloaders, metrics, accumulation normalization, result writing, and checkpoint alias policy.

DeepSpeed checkpoints are directory/tag based. DanLing writes lightweight pointer files (latest.pointer, best.pointer, and named aliases) so the public checkpoint API can keep using logical names.

Attributes:

Name Type Description
model DeepSpeedEngine

DeepSpeed engine after _finalize_runtime_components.

deepspeed_config dict[str, Any]

Effective DeepSpeed config passed to deepspeed.initialize.

Source code in danling/runners/deepspeed_runner.py
Python
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
class DeepSpeedRunner(TorchRunner):
    """
    DeepSpeed-backed runner focused on ZeRO-1/2 training flows.

    Use this runner when DeepSpeed should own the training engine and
    optimizer update while DanLing still owns the outer lifecycle: dataloaders,
    metrics, accumulation normalization, result writing, and checkpoint alias
    policy.

    DeepSpeed checkpoints are directory/tag based. DanLing writes lightweight
    pointer files (`latest.pointer`, `best.pointer`, and named aliases) so the
    public checkpoint API can keep using logical names.

    Attributes:
        model: DeepSpeed engine after `_finalize_runtime_components`.
        deepspeed_config: Effective DeepSpeed config passed to
            `deepspeed.initialize`.
    """

    model: deepspeed.DeepSpeedEngine
    deepspeed_config: dict[str, Any]
    _supports_torchft_runtime: bool = False

    def __init__(self, config) -> None:
        """
        Prepare the runner config for DeepSpeed and run the parent initializer.

        Args:
            config: A `RunnerConfig`, or any value accepted by the
                `RunnerConfig` constructor.

        **Side effects:** sets `config.stack` to `"deepspeed"` and rewrites
        `config.checkpoint.backend`; a `RuntimeWarning` is emitted when the
        requested backend was `"dcp"`.
        """
        ds.check()
        runner_config = config if isinstance(config, RunnerConfig) else RunnerConfig(config)
        runner_config.stack = "deepspeed"

        backend = str(runner_config.checkpoint.backend).strip().lower()
        # DeepSpeed always uses the file backend; "auto" and "dcp" both fold to "file",
        # but only the explicit "dcp" request is worth a warning.
        if backend == "dcp":
            warn(
                "DeepSpeedRunner overrides checkpoint.backend to 'file'",
                RuntimeWarning,
                stacklevel=2,
            )
            backend = "file"
        elif backend == "auto":
            backend = "file"
        runner_config.checkpoint.backend = self._validate_checkpoint_backend(backend)
        super().__init__(runner_config)

    def materialize_model(self) -> None:
        """
        Place the local model on the runner device and compile it.

        **Called when:** `TorchRunner.__post_init__` reaches
        `materialize_model`, ahead of `build_optimizer`, `build_scheduler`,
        and `_finalize_runtime_components`.

        **Precondition:** `self.model` is still the user-provided
        `nn.Module`; the DeepSpeed engine does not exist yet.

        Raises:
            ValueError: `self.model` is not initialized.

        **Side effects:** the model (and the EMA module, when present) is
        moved to `self.device`, the FP8 module policy is applied when
        `fp8_enabled` is set, and the model is compiled. Wrapping in a
        DeepSpeed engine is left to the engine-finalization step.

        !!! danger "Do not"
            - Call `deepspeed.initialize` here; the optimizer and scheduler
              are built after this hook.
            - DDP-wrap the model; DeepSpeed owns distributed wrapping.
        """
        if self.model is None:
            raise ValueError("cannot materialize DeepSpeed model: model is not initialized")

        self.model = self.model.to(self.device)
        if self.fp8_enabled:
            # The policy may swap `self.model`, so it is re-read before compiling.
            self.apply_fp8_module_policy_to_model_parts()
        self.model = self.compiler.compile(self.model)

        ema_module = self.ema
        if ema_module is not None:
            self.ema = ema_module.to(self.device)

    def get_deepspeed_config(self) -> dict[str, Any]:
        """
        Assemble the DeepSpeed config that will be handed to the engine.

        **Called when:** `_finalize_runtime_components` initializes the
        DeepSpeed engine.

        Returns:
            A mutable config dict suitable for `deepspeed.initialize`. When
            `self.deepspeed_config` is already set, a shallow copy of it is
            returned without further processing.

        Raises:
            ValueError: `config.deepspeed` is present but not a mapping.

        **Side effects:** no runner state is modified. The returned config
        always has `gradient_accumulation_steps=1` because DanLing owns
        accumulation boundaries (a `RuntimeWarning` is emitted when a
        different value was configured), takes
        `train_micro_batch_size_per_gpu` from `dataloader.batch_size` when
        absent, and enables the `fp16`/`bf16` section matching the runner
        precision unless that section is already present.
        """
        cached = getattr(self, "deepspeed_config", None)
        if cached is not None:
            return dict(cached)

        raw = self.config.get("deepspeed")
        if raw is None:
            merged: dict[str, Any] = {}
        elif isinstance(raw, Mapping):
            merged = dict(raw)
        else:
            raise ValueError(f"invalid deepspeed config: expected mapping, got {type(raw).__name__}")

        configured_accum = merged.get("gradient_accumulation_steps")
        if configured_accum is not None and configured_accum != 1:
            warn(
                "DeepSpeedRunner manages accumulation via config.accum_steps; overriding "
                "deepspeed.gradient_accumulation_steps to 1",
                RuntimeWarning,
                stacklevel=2,
            )
        merged["gradient_accumulation_steps"] = 1

        if "train_micro_batch_size_per_gpu" not in merged:
            micro_batch = self.config.get("dataloader.batch_size")
            if micro_batch is not None:
                merged["train_micro_batch_size_per_gpu"] = micro_batch

        precision = self.precision
        if precision is not None:
            precision_key = str(precision).lower().replace("-", "_")
            # DeepSpeed section name -> runner precision spellings that enable it.
            section_aliases = {
                "fp16": {"fp16", "float16", "half"},
                "bf16": {"bf16", "bfloat16"},
            }
            for section, spellings in section_aliases.items():
                if precision_key in spellings and section not in merged:
                    merged[section] = {"enabled": True}

        return merged

    def _resolve_deepspeed_scheduler(self, scheduler: object | None) -> object | None:
        if scheduler is None:
            return None
        sched_cfg = self._get_scheduler_config()
        interval = sched_cfg.get("interval") if sched_cfg is not None else None
        if normalize_scheduler_interval(interval, scheduler) != "step":
            return None
        return scheduler

    def _finalize_runtime_components(self) -> None:
        """
        Build the DeepSpeed engine once model, optimizer, and scheduler exist.

        **Called when:** `TorchRunner.__post_init__` has already run
        `materialize_model`, `build_optimizer`, and `build_scheduler`.

        **Side effects:** stores the effective config in
        `self.deepspeed_config`, calls `deepspeed.initialize`, replaces
        `self.model` with the engine and `self.optimizer` with the engine
        optimizer, sets `self._runner_owns_scheduler`, and hands step
        schedulers to DeepSpeed while epoch/metric schedulers stay under
        runner control.
        """
        effective_config = self.get_deepspeed_config()
        self.deepspeed_config = effective_config

        runner_scheduler = self.scheduler
        # DeepSpeed should own only per-step schedulers. Epoch and metric schedulers
        # still need the runner's explicit step boundary and metric resolution path.
        deepspeed_scheduler = self._resolve_deepspeed_scheduler(runner_scheduler)
        self._runner_owns_scheduler = runner_scheduler is not None and deepspeed_scheduler is None

        engine, engine_optimizer, _, engine_scheduler = deepspeed.initialize(
            model=self.model,
            optimizer=self.optimizer,
            lr_scheduler=deepspeed_scheduler,
            config=effective_config,
        )
        self.model = engine
        self.optimizer = engine_optimizer

        if deepspeed_scheduler is None:
            # Nothing was handed to DeepSpeed: keep whatever the runner had.
            self.scheduler = runner_scheduler
        elif engine_scheduler is not None:
            self.scheduler = engine_scheduler
        else:
            self.scheduler = deepspeed_scheduler

    def _bind_optimizer_container(self) -> None:
        self.optimizer_container = None

    def backward(self, loss: torch.Tensor) -> None:
        """
        Run one micro-step backward pass through the DeepSpeed engine.

        Args:
            loss: Raw micro-step loss from `train_step`.

        **Side effects:** the loss is first passed through
        `_scaled_loss_for_backward` (DanLing's loss-scaling/normalization
        policy) and the result is given to the engine's `backward`.
        """
        scaled_loss = self._scaled_loss_for_backward(loss)
        self.model.backward(scaled_loss)

    def optimizer_step(self) -> bool:
        """
        Apply one optimizer update through the DeepSpeed engine.

        DeepSpeed owns the concrete optimizer step; DanLing keeps accumulation
        normalization, runner state, profiler, timeout, and supervisor state in sync.

        Returns:
            Always `True`.
        """
        self.checkpoint_manager.maybe_wait_for_staging()

        pending_scale = self._gradient_scale_for_step()
        if pending_scale is not None:
            self._scale_optimizer_gradients(pending_scale)

        self.model.step()
        self._reset_accumulation_normalization()

        # Use the engine's step counter when it exposes one; otherwise count locally.
        engine_steps = getattr(self.model, "global_steps", None)
        if engine_steps is not None:
            self.train_state.global_step = int(engine_steps)
        else:
            self.train_state.global_step += 1

        self._step_profiler()
        self._maybe_reduce_train_process_group_timeout()
        self.supervisor.maybe_collect_garbage(self.train_state.global_step, scope="train")
        return True

    def _auto_resume_source(self) -> str:
        return self.workspace.checkpoint_dir

    def _checkpoint_pointer_path(self, name: str) -> str:
        return os.path.join(self.workspace.checkpoint_dir, f"{name}.pointer")

    def _write_checkpoint_pointer(self, name: str, target_tag: str) -> None:
        pointer_path = self._checkpoint_pointer_path(name)
        pointer_tmp_path = f"{pointer_path}.tmp-{self.id}"
        with open(pointer_tmp_path, "w", encoding="utf-8") as fp:
            fp.write(target_tag)
        os.replace(pointer_tmp_path, pointer_path)

    def _record_deepspeed_checkpoint_failure(
        self,
        exc: Exception,
        *,
        target: str,
        alias: str | None = None,
    ) -> None:
        self.checkpoint_manager.record_checkpoint_failure(exc, target=target, alias=alias)
        warn(f"deepspeed checkpoint save failed: {exc}", RuntimeWarning, stacklevel=2)
        self.checkpoint_manager.raise_checkpoint_error_if_requested()

    @staticmethod
    def _read_checkpoint_pointer(checkpoint_path: bytes | str | os.PathLike) -> str:
        pointer_path = os.fsdecode(checkpoint_path)
        with open(pointer_path, encoding="utf-8") as fp:
            tag = fp.read().strip()
        if not tag:
            raise ValueError(f"invalid DeepSpeed checkpoint pointer: {pointer_path!r} is empty")
        return tag

    def _resolve_physical_checkpoint_tag(self, *, name: str, epochs: int, should_update_best: bool) -> str:
        history_name = self.checkpoint_manager.resolve_history_name(epochs)
        if history_name is not None:
            return history_name
        if should_update_best:
            return f"ckpt-g{self.train_state.global_step:012d}"
        return name

    def save_checkpoint(
        self,
        name: str = "latest",
        epochs: int | None = None,
        save_best: bool = True,
        last_step: bool = False,
        force: bool = False,
    ) -> None:
        """
        Save a DeepSpeed checkpoint and publish DanLing pointer aliases.

        **Called when:** the training loop or shutdown supervisor requests a
        checkpoint save.

        Args:
            name: Logical alias to publish in addition to `latest`.
            epochs: Epoch index used for retention/history naming. Defaults
                to `self.train_state.epoch`.
            save_best: Whether to publish `best.pointer` when the current
                result is best.
            last_step: Whether this is the final checkpoint save.
            force: Bypass checkpoint manager cadence checks.

        **Side effects:** all ranks enter `DeepSpeedEngine.save_checkpoint`.
        The main process writes `runner.yaml` and pointer files for logical
        aliases. Success/failure is reported through the checkpoint manager;
        failures are routed through `_record_deepspeed_checkpoint_failure`
        and end the method early.

        !!! danger "Do not"
            - Guard the whole method with `is_main_process`; DeepSpeed saves
              are collective.
            - Write aliases before `save_checkpoint` succeeds.
            - Use the generic file checkpoint payload here; DeepSpeed owns the
              physical checkpoint layout.
        """
        epochs = self.train_state.epoch if epochs is None else epochs
        if not self.checkpoint_manager.should_persist_checkpoint(epochs=epochs, last_step=last_step, force=force):
            return

        # Runner-level state is stored as DeepSpeed `client_state` next to the engine state.
        client_state: dict = BaseRunner.state_dict(self, dict)  # type: ignore[assignment]
        client_state["ema"] = self.ema.state_dict() if self.ema else None
        # Scheduler state is only stored here when the runner, not DeepSpeed, owns the scheduler.
        client_state["scheduler"] = (
            self.scheduler.state_dict() if getattr(self, "_runner_owns_scheduler", False) and self.scheduler else None
        )
        should_update_best = bool(save_best and self.is_best)
        physical_tag = self._resolve_physical_checkpoint_tag(
            name=name, epochs=epochs, should_update_best=should_update_best
        )
        try:
            # `save_latest=False`: the `latest` alias is published below as a pointer file instead.
            self.model.save_checkpoint(
                self.workspace.checkpoint_dir,
                tag=physical_tag,
                client_state=client_state,
                save_latest=False,
            )
        except Exception as exc:
            self._record_deepspeed_checkpoint_failure(exc, target=physical_tag)
            return

        # Everything below (runner.yaml, pointer files, success bookkeeping) is main-process only.
        if self.distributed and not self.is_main_process:
            return

        tag_dir = os.path.join(self.workspace.checkpoint_dir, physical_tag)
        try:
            if os.path.isdir(tag_dir):
                self.config.yaml(os.path.join(tag_dir, "runner.yaml"))
        except Exception as exc:
            self._record_deepspeed_checkpoint_failure(exc, target=physical_tag)
            return

        # Aliases are published one at a time; the list tracks which ones actually landed.
        published_aliases: list[str] = []
        try:
            self._write_checkpoint_pointer("latest", physical_tag)
        except Exception as exc:
            self._record_deepspeed_checkpoint_failure(exc, target=physical_tag, alias="latest")
            return
        published_aliases.append("latest")

        # Skip the named alias when it would duplicate `latest` or the physical tag itself.
        if name not in {"latest", physical_tag}:
            try:
                self._write_checkpoint_pointer(name, physical_tag)
            except Exception as exc:
                # Record the aliases published so far as a success before reporting this alias failure.
                self.checkpoint_manager.record_checkpoint_success(target=physical_tag, aliases=tuple(published_aliases))
                self._record_deepspeed_checkpoint_failure(exc, target=physical_tag, alias=name)
                return
            published_aliases.append(name)

        if should_update_best:
            try:
                self._write_checkpoint_pointer("best", physical_tag)
            except Exception as exc:
                # Same partial-success bookkeeping as for the named alias above.
                self.checkpoint_manager.record_checkpoint_success(target=physical_tag, aliases=tuple(published_aliases))
                self._record_deepspeed_checkpoint_failure(exc, target=physical_tag, alias="best")
                return
            published_aliases.append("best")

        self.checkpoint_manager.record_checkpoint_success(target=physical_tag, aliases=tuple(published_aliases))

    @staticmethod
    def _resolve_deepspeed_checkpoint(checkpoint: bytes | str | os.PathLike) -> tuple[str, str]:
        checkpoint_path = os.fsdecode(checkpoint)

        if os.path.isfile(checkpoint_path):
            return os.path.dirname(checkpoint_path), DeepSpeedRunner._read_checkpoint_pointer(checkpoint_path)

        if os.path.isdir(checkpoint_path):
            latest_pointer = os.path.join(checkpoint_path, "latest.pointer")
            if os.path.isfile(latest_pointer):
                return checkpoint_path, DeepSpeedRunner._read_checkpoint_pointer(latest_pointer)
            latest_file = os.path.join(checkpoint_path, "latest")
            if os.path.isfile(latest_file):
                return checkpoint_path, DeepSpeedRunner._read_checkpoint_pointer(latest_file)
            if os.path.isdir(latest_file):
                return checkpoint_path, "latest"
            return os.path.dirname(checkpoint_path), os.path.basename(checkpoint_path)

        raise FileNotFoundError(f"checkpoint path does not exist: {checkpoint_path!r}")

    def load_checkpoint(
        self,
        checkpoint: Mapping | bytes | str | os.PathLike,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Restore the full training state from a DeepSpeed checkpoint.

        Mapping checkpoints are handed to `TorchRunner.load_checkpoint`. Path
        checkpoints are first resolved (pointer file or directory) to a
        DeepSpeed `(checkpoint_dir, tag)` pair; the engine state is then
        loaded, followed by the DanLing client state stored with it.

        Args:
            checkpoint: In-memory payload, pointer file, checkpoint directory,
                or tagged checkpoint directory.
            *args: Forwarded to component loaders for client state.
            **kwargs: Forwarded to component loaders for client state.

        **Side effects:** restores DeepSpeed engine state, runner state,
        optional EMA, runner-owned scheduler state, dataloader state, and
        `config.resume`; `optimizer_container` is reset to `None`.

        !!! danger "Do not"
            - Treat DeepSpeed pointer files as torch `load` payloads; resolve
              them to a tag first.
            - Rebind an `OptimizerContainer`; DeepSpeed owns optimizer
              stepping.
        """
        if isinstance(checkpoint, Mapping):
            super().load_checkpoint(checkpoint, *args, **kwargs)
            return

        directory, tag = self._resolve_deepspeed_checkpoint(checkpoint)
        # NOTE(review): the first tuple element (the load path) is discarded; DeepSpeed
        # documents it as None when loading fails - confirm whether that should be surfaced.
        _, client_state = self.model.load_checkpoint(directory, tag=tag)

        if client_state is not None:
            BaseRunner.load_state_dict(self, client_state)

            ema_state = client_state.get("ema")
            if self.ema is not None and ema_state is not None:
                self.load_ema(ema_state, *args, **kwargs)

            # Scheduler state lives in client state only when the runner owns the scheduler.
            scheduler_state = client_state.get("scheduler")
            if getattr(self, "_runner_owns_scheduler", False) and scheduler_state is not None:
                self.load_scheduler(scheduler_state, *args, **kwargs)

            if self.dataloaders or "dataloaders" in client_state:
                self.load_dataloaders(client_state.get("dataloaders"))

        self.config.resume = os.fsdecode(checkpoint)
        self.optimizer_container = None

    def load_pretrained(
        self,
        checkpoint: Mapping | bytes | str | os.PathLike,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Load model weights from a DeepSpeed checkpoint, leaving training state alone.

        Mapping checkpoints are handed to the generic pretrained path. Path
        checkpoints go through
        `DeepSpeedEngine.load_checkpoint(..., load_module_only=True)`. When
        the DanLing client state carries EMA weights, those are then loaded
        into the model as the pretrained source.

        Args:
            checkpoint: In-memory payload, pointer file, checkpoint directory,
                or tagged checkpoint directory.
            *args: Forwarded to model loading for client-state EMA payloads.
            **kwargs: Forwarded to model loading for client-state EMA payloads.

        **Side effects:** loads model weights through the DeepSpeed engine and
        updates `config.pretrained`. Optimizer, scheduler, dataloaders, and
        runner progress are untouched.
        """
        if isinstance(checkpoint, Mapping):
            return super().load_pretrained(checkpoint, *args, **kwargs)

        directory, tag = self._resolve_deepspeed_checkpoint(checkpoint)
        _, client_state = self.model.load_checkpoint(directory, tag=tag, load_module_only=True)

        ema_state = None if client_state is None else client_state.get("ema")
        if ema_state is not None:
            # EMA weights, when stored, take precedence over the raw module weights.
            self.load_model(ema_state, *args, **kwargs)

        self.config.pretrained = os.fsdecode(checkpoint)

    @classmethod
    def read_config(
        cls,
        checkpoint: Mapping | bytes | str | os.PathLike,
        *args,
        **kwargs,
    ) -> RunnerConfig:
        """
        Read the runner config stored with a DeepSpeed checkpoint.

        Mapping checkpoints are handled by the parent implementation. For
        path input, `runner.yaml` is looked up in this order and the first
        hit is returned as a `RunnerConfig`:

        1. directory input: `<dir>/runner.yaml`;
        2. directory input: `<dir>/<tag>/runner.yaml`, with the tag read from
           `latest.pointer`;
        3. directory input: `<dir>/<tag>/runner.yaml`, with the tag read from
           a `latest` file, or `<dir>/latest/runner.yaml` when `latest` is a
           directory;
        4. file input (pointer file): `<pointer dir>/<tag>/runner.yaml`.

        When none of these exists, the call falls back to the parent
        implementation.

        Args:
            checkpoint: In-memory payload, pointer file, checkpoint directory,
                or tagged checkpoint directory.
            *args: Forwarded to `RunnerConfig.from_yaml` or the parent
                implementation.
            **kwargs: Forwarded to `RunnerConfig.from_yaml` or the parent
                implementation.

        Raises:
            ValueError: A pointer file that had to be read is empty, or the
                parent implementation rejects the input.
        """
        if isinstance(checkpoint, Mapping):
            return super().read_config(checkpoint, *args, **kwargs)

        if isinstance(checkpoint, (bytes, str, os.PathLike)):
            checkpoint_path = os.fsdecode(checkpoint)

            def config_from(*parts: str) -> RunnerConfig | None:
                """Load `<parts>/runner.yaml` as a `RunnerConfig`, or return `None` when absent."""
                runner_yaml = os.path.join(*parts, "runner.yaml")
                if os.path.isfile(runner_yaml):
                    # Every branch goes through `from_yaml` so the declared return type always holds.
                    return RunnerConfig.from_yaml(runner_yaml, *args, **kwargs)
                return None

            found: RunnerConfig | None = None
            if os.path.isdir(checkpoint_path):
                found = config_from(checkpoint_path)

                # Pointer files are read lazily, only when the earlier lookups missed,
                # so an unrelated broken pointer cannot fail a lookup that would succeed.
                if found is None:
                    latest_pointer = os.path.join(checkpoint_path, "latest.pointer")
                    if os.path.isfile(latest_pointer):
                        found = config_from(checkpoint_path, cls._read_checkpoint_pointer(latest_pointer))

                if found is None:
                    latest_file = os.path.join(checkpoint_path, "latest")
                    if os.path.isfile(latest_file):
                        found = config_from(checkpoint_path, cls._read_checkpoint_pointer(latest_file))
                    elif os.path.isdir(latest_file):
                        found = config_from(latest_file)
            elif os.path.isfile(checkpoint_path):
                found = config_from(
                    os.path.dirname(checkpoint_path),
                    cls._read_checkpoint_pointer(checkpoint_path),
                )

            if found is not None:
                return found

        return super().read_config(checkpoint, *args, **kwargs)

materialize_model

Python
materialize_model() -> None

Move and compile the local model before DeepSpeed engine creation.

Called when: TorchRunner.__post_init__ reaches materialize_model, before build_optimizer, build_scheduler, and _finalize_runtime_components.

Precondition: self.model is the user-provided nn.Module, not yet a DeepSpeed engine.

Raises:

Type Description
ValueError

self.model is not initialized.

Side effects: moves the model and optional EMA module to self.device, applies FP8 policy when enabled, and compiles the model. DeepSpeed wrapping happens later in the engine-finalization step.

Do not

  • Call deepspeed.initialize here; optimizer and scheduler build happen after this hook.
  • DDP-wrap the model; DeepSpeed owns distributed wrapping.
Source code in danling/runners/deepspeed_runner.py
Python
def materialize_model(self) -> None:
    """
    Place the local model on the runner device and compile it.

    **Called when:** `TorchRunner.__post_init__` reaches
    `materialize_model`, ahead of `build_optimizer`, `build_scheduler`,
    and `_finalize_runtime_components`.

    **Precondition:** `self.model` is still the user-provided
    `nn.Module`; the DeepSpeed engine does not exist yet.

    Raises:
        ValueError: `self.model` is not initialized.

    **Side effects:** the model (and the EMA module, when present) is
    moved to `self.device`, the FP8 module policy is applied when
    `fp8_enabled` is set, and the model is compiled. Wrapping in a
    DeepSpeed engine is left to the engine-finalization step.

    !!! danger "Do not"
        - Call `deepspeed.initialize` here; the optimizer and scheduler
          are built after this hook.
        - DDP-wrap the model; DeepSpeed owns distributed wrapping.
    """
    if self.model is None:
        raise ValueError("cannot materialize DeepSpeed model: model is not initialized")

    self.model = self.model.to(self.device)
    if self.fp8_enabled:
        # The policy may swap `self.model`, so it is re-read before compiling.
        self.apply_fp8_module_policy_to_model_parts()
    self.model = self.compiler.compile(self.model)

    ema_module = self.ema
    if ema_module is not None:
        self.ema = ema_module.to(self.device)

get_deepspeed_config

Python
get_deepspeed_config() -> dict[str, Any]

Build the effective DeepSpeed config.

Called when: _finalize_runtime_components initializes the DeepSpeed engine.

Returns:

Type Description
dict[str, Any]

A mutable config dict suitable for deepspeed.initialize.

Raises:

Type Description
ValueError

config.deepspeed is present but not a mapping.

Side effects: none. The returned config forces gradient_accumulation_steps=1 because DanLing owns accumulation boundaries, fills train_micro_batch_size_per_gpu from the dataloader batch size when absent, and mirrors runner precision into DeepSpeed precision sections when possible.

Source code in danling/runners/deepspeed_runner.py
Python
def get_deepspeed_config(self) -> dict[str, Any]:
    """
    Assemble the DeepSpeed config that the engine will be initialized with.

    **Called when:** `_finalize_runtime_components` initializes the
    DeepSpeed engine.

    Returns:
        A mutable config dict suitable for `deepspeed.initialize`. When the
        runner already has a non-`None` `deepspeed_config` attribute, a
        shallow copy of that mapping is returned unchanged.

    Raises:
        ValueError: `config.deepspeed` is present but not a mapping.

    **Side effects:** none on runner state; a `RuntimeWarning` is emitted
    when a configured `gradient_accumulation_steps` other than 1 is
    overridden. The built config always sets
    `gradient_accumulation_steps=1` because DanLing owns accumulation
    boundaries, takes `train_micro_batch_size_per_gpu` from
    `dataloader.batch_size` when the key is absent, and enables the `fp16`
    or `bf16` section matching the runner precision unless that section is
    already configured.
    """
    preset = getattr(self, "deepspeed_config", None)
    if preset is not None:
        return dict(preset)

    user_section = self.config.get("deepspeed")
    if user_section is not None and not isinstance(user_section, Mapping):
        raise ValueError(f"invalid deepspeed config: expected mapping, got {type(user_section).__name__}")
    # Shallow copy so the keys added below never leak into the runner config.
    effective: dict[str, Any] = {} if user_section is None else dict(user_section)

    requested_accum = effective.get("gradient_accumulation_steps")
    if requested_accum is not None and requested_accum != 1:
        warn(
            "DeepSpeedRunner manages accumulation via config.accum_steps; overriding "
            "deepspeed.gradient_accumulation_steps to 1",
            RuntimeWarning,
            stacklevel=2,
        )
    effective["gradient_accumulation_steps"] = 1

    if "train_micro_batch_size_per_gpu" not in effective:
        micro_batch = self.config.get("dataloader.batch_size")
        if micro_batch is not None:
            effective["train_micro_batch_size_per_gpu"] = micro_batch

    runner_precision = self.precision
    if runner_precision is not None:
        precision_key = str(runner_precision).lower().replace("-", "_")
        # DeepSpeed section name -> runner precision spellings that enable it.
        section_aliases = (
            ("fp16", {"fp16", "float16", "half"}),
            ("bf16", {"bf16", "bfloat16"}),
        )
        for section, spellings in section_aliases:
            if precision_key in spellings and section not in effective:
                effective[section] = {"enabled": True}

    return effective

backward

Python
backward(loss: Tensor) -> None

Route one micro-step backward pass through the DeepSpeed engine.

Parameters:

Name Type Description Default
loss
Tensor

Raw micro-step loss from train_step.

required

Side effects: accumulates gradients inside the DeepSpeed engine after DanLing’s loss-scaling/normalization policy is applied.

Source code in danling/runners/deepspeed_runner.py
Python
def backward(self, loss: torch.Tensor) -> None:
    """
    Run the backward pass for one micro-step through the DeepSpeed engine.

    Args:
        loss: Raw micro-step loss from `train_step`.

    **Side effects:** the loss is first passed through
    `_scaled_loss_for_backward` (DanLing's loss-scaling/normalization
    policy) and the result is handed to the engine's `backward`, which
    accumulates gradients inside the DeepSpeed engine.
    """
    engine = self.model
    engine.backward(self._scaled_loss_for_backward(loss))

optimizer_step

Python
optimizer_step() -> bool

Perform one DeepSpeed engine optimizer update.

DeepSpeed owns the concrete optimizer step; DanLing keeps accumulation normalization, runner state, profiler, timeout, and supervisor state in sync.

Source code in danling/runners/deepspeed_runner.py
Python
def optimizer_step(self) -> bool:
    """
    Apply one optimizer update through the DeepSpeed engine.

    DeepSpeed performs the concrete optimizer step; DanLing keeps
    accumulation normalization, runner state, profiler, timeout, and
    supervisor state in sync around it.

    Returns:
        Always `True`.

    **Side effects:** waits on the checkpoint manager's staging hook, applies
    the per-step gradient scale when one is returned, steps the engine,
    resets accumulation normalization, and updates
    `train_state.global_step` (mirroring the engine's `global_steps` when
    the engine exposes it, otherwise incrementing by one) before running the
    profiler, timeout, and garbage-collection hooks.
    """
    self.checkpoint_manager.maybe_wait_for_staging()

    pending_scale = self._gradient_scale_for_step()
    if pending_scale is not None:
        self._scale_optimizer_gradients(pending_scale)

    self.model.step()
    self._reset_accumulation_normalization()

    # Prefer the engine's own step counter when it exposes one.
    engine_steps = getattr(self.model, "global_steps", None)
    if engine_steps is not None:
        self.train_state.global_step = int(engine_steps)
    else:
        self.train_state.global_step += 1

    self._step_profiler()
    self._maybe_reduce_train_process_group_timeout()
    self.supervisor.maybe_collect_garbage(self.train_state.global_step, scope="train")
    return True

save_checkpoint

Python
save_checkpoint(
    name: str = "latest",
    epochs: int | None = None,
    save_best: bool = True,
    last_step: bool = False,
    force: bool = False,
) -> None

Save a DeepSpeed checkpoint and publish DanLing pointer aliases.

Called when: the training loop or shutdown supervisor requests a checkpoint save.

Parameters:

Name Type Description Default
name
str

Logical alias to publish in addition to latest.

'latest'
epochs
int | None

Epoch index used for retention/history naming.

None
save_best
bool

Whether to publish best.pointer when the current result is best.

True
last_step
bool

Whether this is the final checkpoint save.

False
force
bool

Bypass checkpoint manager cadence checks.

False

Side effects: all ranks enter DeepSpeedEngine.save_checkpoint. The main process writes runner.yaml and pointer files for logical aliases. Success/failure is reported through the checkpoint manager.

Do not

  • Guard the whole method with is_main_process; DeepSpeed saves are collective.
  • Write aliases before save_checkpoint succeeds.
  • Use the generic file checkpoint payload here; DeepSpeed owns the physical checkpoint layout.
Source code in danling/runners/deepspeed_runner.py
Python
def save_checkpoint(
    self,
    name: str = "latest",
    epochs: int | None = None,
    save_best: bool = True,
    last_step: bool = False,
    force: bool = False,
) -> None:
    """
    Save a DeepSpeed checkpoint and publish DanLing pointer aliases.

    **Called when:** the training loop or shutdown supervisor requests a
    checkpoint save.

    Args:
        name: Logical alias to publish in addition to `latest`.
        epochs: Epoch index used for retention/history naming. Defaults to
            `train_state.epoch` when `None`.
        save_best: Whether to publish `best.pointer` when the current
            result is best.
        last_step: Whether this is the final checkpoint save.
        force: Bypass checkpoint manager cadence checks.

    **Side effects:** all ranks enter `DeepSpeedEngine.save_checkpoint`.
    The main process writes `runner.yaml` and pointer files for logical
    aliases. Success/failure is reported through the checkpoint manager:
    a failed alias write reports the aliases already published as a
    success first, then records the failure for the alias that broke.
    Exceptions from the engine save, the `runner.yaml` dump, and pointer
    writes are recorded rather than raised.

    !!! danger "Do not"
        - Guard the whole method with `is_main_process`; DeepSpeed saves
          are collective.
        - Write aliases before `save_checkpoint` succeeds.
        - Use the generic file checkpoint payload here; DeepSpeed owns the
          physical checkpoint layout.
    """
    epochs = self.train_state.epoch if epochs is None else epochs
    if not self.checkpoint_manager.should_persist_checkpoint(epochs=epochs, last_step=last_step, force=force):
        return

    client_state: dict = BaseRunner.state_dict(self, dict)  # type: ignore[assignment]
    # Test presence with `is not None` (as the sibling load/materialize
    # methods do): a component that defines `__len__` and reports 0 must
    # still be checkpointed rather than silently dropped.
    client_state["ema"] = self.ema.state_dict() if self.ema is not None else None
    owns_scheduler = getattr(self, "_runner_owns_scheduler", False)
    client_state["scheduler"] = (
        self.scheduler.state_dict() if owns_scheduler and self.scheduler is not None else None
    )
    should_update_best = bool(save_best and self.is_best)
    physical_tag = self._resolve_physical_checkpoint_tag(
        name=name, epochs=epochs, should_update_best=should_update_best
    )
    try:
        # Collective call: every rank must reach this, so it sits above the
        # main-process gate below.
        self.model.save_checkpoint(
            self.workspace.checkpoint_dir,
            tag=physical_tag,
            client_state=client_state,
            save_latest=False,
        )
    except Exception as exc:
        self._record_deepspeed_checkpoint_failure(exc, target=physical_tag)
        return

    # Only the main process (or a non-distributed run) publishes metadata.
    if self.distributed and not self.is_main_process:
        return

    tag_dir = os.path.join(self.workspace.checkpoint_dir, physical_tag)
    try:
        if os.path.isdir(tag_dir):
            self.config.yaml(os.path.join(tag_dir, "runner.yaml"))
    except Exception as exc:
        self._record_deepspeed_checkpoint_failure(exc, target=physical_tag)
        return

    # Aliases in publication order, without duplicates (e.g. name="best"
    # on a best result must not write and record `best` twice).
    aliases: list[str] = ["latest"]
    if name not in {"latest", physical_tag}:
        aliases.append(name)
    if should_update_best and "best" not in aliases:
        aliases.append("best")

    published_aliases: list[str] = []
    for alias in aliases:
        try:
            self._write_checkpoint_pointer(alias, physical_tag)
        except Exception as exc:
            # Aliases written before the failure are valid and get reported.
            if published_aliases:
                self.checkpoint_manager.record_checkpoint_success(
                    target=physical_tag, aliases=tuple(published_aliases)
                )
            self._record_deepspeed_checkpoint_failure(exc, target=physical_tag, alias=alias)
            return
        published_aliases.append(alias)

    self.checkpoint_manager.record_checkpoint_success(target=physical_tag, aliases=tuple(published_aliases))

load_checkpoint

Python
load_checkpoint(
    checkpoint: Mapping | bytes | str | PathLike,
    *args: Any,
    **kwargs: Any
) -> None

Restore a full DeepSpeed checkpoint.

Mapping checkpoints delegate to TorchRunner.load_checkpoint. Path checkpoints resolve pointer files/directories to a DeepSpeed (checkpoint_dir, tag) pair, then load engine state and DanLing client state.

Parameters:

Name Type Description Default
checkpoint
Mapping | bytes | str | PathLike

In-memory payload, pointer file, checkpoint directory, or tagged checkpoint directory.

required
*args
Any

Forwarded to component loaders for client state.

()
**kwargs
Any

Forwarded to component loaders for client state.

{}

Side effects: restores DeepSpeed engine state, runner state, optional EMA, runner-owned scheduler state, dataloader state, and config.resume.

Do not

  • Treat DeepSpeed pointer files as torch load payloads; resolve them to a tag first.
  • Rebind an OptimizerContainer; DeepSpeed owns optimizer stepping.
Source code in danling/runners/deepspeed_runner.py
Python
def load_checkpoint(
    self,
    checkpoint: Mapping | bytes | str | os.PathLike,
    *args: Any,
    **kwargs: Any,
) -> None:
    """
    Restore a full DeepSpeed checkpoint.

    Mapping checkpoints delegate to `TorchRunner.load_checkpoint`. Path
    checkpoints resolve pointer files/directories to a DeepSpeed
    `(checkpoint_dir, tag)` pair, then load engine state and DanLing client
    state.

    Args:
        checkpoint: In-memory payload, pointer file, checkpoint directory,
            or tagged checkpoint directory.
        *args: Forwarded to component loaders for client state.
        **kwargs: Forwarded to component loaders for client state.

    **Side effects:** restores DeepSpeed engine state, runner state,
    optional EMA, runner-owned scheduler state, dataloader state, and
    `config.resume`.

    !!! danger "Do not"
        - Treat DeepSpeed pointer files as torch `load` payloads; resolve
          them to a tag first.
        - Rebind an `OptimizerContainer`; DeepSpeed owns optimizer
          stepping.
    """
    if isinstance(checkpoint, Mapping):
        super().load_checkpoint(checkpoint, *args, **kwargs)
        return

    checkpoint_dir, checkpoint_tag = self._resolve_deepspeed_checkpoint(checkpoint)
    _, client_state = self.model.load_checkpoint(checkpoint_dir, tag=checkpoint_tag)

    if client_state is not None:
        BaseRunner.load_state_dict(self, client_state)
        if self.ema is not None and client_state.get("ema") is not None:
            self.load_ema(client_state["ema"], *args, **kwargs)
        if getattr(self, "_runner_owns_scheduler", False) and client_state.get("scheduler") is not None:
            self.load_scheduler(client_state["scheduler"], *args, **kwargs)
        if self.dataloaders or "dataloaders" in client_state:
            self.load_dataloaders(client_state.get("dataloaders"))

    self.config.resume = os.fsdecode(checkpoint)
    self.optimizer_container = None

load_pretrained

Python
load_pretrained(
    checkpoint: Mapping | bytes | str | PathLike,
    *args: Any,
    **kwargs: Any
) -> None

Load DeepSpeed model weights without restoring training state.

Mapping checkpoints delegate to the generic pretrained path. Path checkpoints use DeepSpeedEngine.load_checkpoint(..., load_module_only=True). If DanLing client state contains EMA weights, EMA is used as the pretrained source.

Parameters:

Name Type Description Default
checkpoint
Mapping | bytes | str | PathLike

In-memory payload, pointer file, checkpoint directory, or tagged checkpoint directory.

required
*args
Any

Forwarded to model loading for client-state EMA payloads.

()
**kwargs
Any

Forwarded to model loading for client-state EMA payloads.

{}

Side effects: loads model weights through the DeepSpeed engine and updates config.pretrained. Optimizer, scheduler, dataloaders, and runner progress are untouched.

Source code in danling/runners/deepspeed_runner.py
Python
def load_pretrained(
    self,
    checkpoint: Mapping | bytes | str | os.PathLike,
    *args: Any,
    **kwargs: Any,
) -> None:
    """
    Load DeepSpeed model weights without restoring training state.

    Mapping checkpoints delegate to the generic pretrained path. Path
    checkpoints use `DeepSpeedEngine.load_checkpoint(..., load_module_only=True)`.
    If DanLing client state contains EMA weights, EMA is used as the
    pretrained source.

    Args:
        checkpoint: In-memory payload, pointer file, checkpoint directory,
            or tagged checkpoint directory.
        *args: Forwarded to model loading for client-state EMA payloads.
        **kwargs: Forwarded to model loading for client-state EMA payloads.

    **Side effects:** loads model weights through the DeepSpeed engine and
    updates `config.pretrained`. Optimizer, scheduler, dataloaders, and
    runner progress are untouched.
    """
    if isinstance(checkpoint, Mapping):
        return super().load_pretrained(checkpoint, *args, **kwargs)

    checkpoint_dir, checkpoint_tag = self._resolve_deepspeed_checkpoint(checkpoint)
    _, client_state = self.model.load_checkpoint(
        checkpoint_dir,
        tag=checkpoint_tag,
        load_module_only=True,
    )

    if client_state is not None and client_state.get("ema") is not None:
        self.load_model(client_state["ema"], *args, **kwargs)

    self.config.pretrained = os.fsdecode(checkpoint)

ParallelRunner

Bases: TorchRunner

Torch runner for data, FSDP, pipeline, and model-parallel stacks.

Use this runner when training spans explicit parallel axes (replicate, shard, pipeline, tensor, context, expert, expert_tensor) rather than plain DDP. It keeps the TorchRunner outer lifecycle and replaces the distributed topology, sampler, model materialization, collective reduction, pipeline step, and checkpoint semantics.

Checkpoint invariants
  • Distributed parallel runs use checkpoint.backend="dcp" only.
  • Single-local-part checkpoints use torch.distributed.checkpoint state-dict APIs when available.
  • Restore order is model first, then optimizer, then scheduler.

Attributes:

Name Type Description
topology ParallelTopology

Rank/axis layout for the current world.

parallel ParallelContext

Process-group/device-mesh context built from topology.

model_parts list[Module]

Local pipeline/FSDP model parts. self.model is the first local part for compatibility with TorchRunner helpers.

pipeline_schedule Any | None

Optional PyTorch pipeline schedule.

pipeline_has_first_stage bool

Whether this rank owns pipeline input.

pipeline_has_last_stage bool

Whether this rank owns pipeline target/loss.

Source code in danling/runners/parallel_runner.py
Python
  68
  69
  70
  71
  72
  73
  74
  75
  76
  77
  78
  79
  80
  81
  82
  83
  84
  85
  86
  87
  88
  89
  90
  91
  92
  93
  94
  95
  96
  97
  98
  99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 220
 221
 222
 223
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 238
 239
 240
 241
 242
 243
 244
 245
 246
 247
 248
 249
 250
 251
 252
 253
 254
 255
 256
 257
 258
 259
 260
 261
 262
 263
 264
 265
 266
 267
 268
 269
 270
 271
 272
 273
 274
 275
 276
 277
 278
 279
 280
 281
 282
 283
 284
 285
 286
 287
 288
 289
 290
 291
 292
 293
 294
 295
 296
 297
 298
 299
 300
 301
 302
 303
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 332
 333
 334
 335
 336
 337
 338
 339
 340
 341
 342
 343
 344
 345
 346
 347
 348
 349
 350
 351
 352
 353
 354
 355
 356
 357
 358
 359
 360
 361
 362
 363
 364
 365
 366
 367
 368
 369
 370
 371
 372
 373
 374
 375
 376
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
class ParallelRunner(TorchRunner):
    """
    Torch runner for data, FSDP, pipeline, and model-parallel stacks.

    Use this runner when training spans explicit parallel axes (`replicate`,
    `shard`, `pipeline`, `tensor`, `context`, `expert`, `expert_tensor`) rather
    than plain DDP. It keeps the TorchRunner outer lifecycle and replaces the
    distributed topology, sampler, model materialization, collective reduction,
    pipeline step, and checkpoint semantics.

    Checkpoint invariants:
        - Distributed parallel runs use `checkpoint.backend="dcp"` only.
        - Single-local-part checkpoints use torch.distributed.checkpoint
          state-dict APIs when available.
        - Restore order is model first, then optimizer, then scheduler.

    Attributes:
        topology: Rank/axis layout for the current world.
        parallel: Process-group/device-mesh context built from `topology`.
        model_parts: Local pipeline/FSDP model parts. `self.model` is the
            first local part for compatibility with TorchRunner helpers.
        pipeline_schedule: Optional PyTorch pipeline schedule.
        pipeline_has_first_stage: Whether this rank owns pipeline input.
        pipeline_has_last_stage: Whether this rank owns pipeline target/loss.
    """

    topology: ParallelTopology
    parallel: ParallelContext
    pipeline_schedule: Any | None = None
    pipeline_has_first_stage: bool = True
    pipeline_has_last_stage: bool = True

    tensor_group = None
    pipeline_group = None
    replicate_group = None
    shard_group = None
    context_group = None
    expert_group = None
    expert_tensor_group = None
    device_mesh = None
    _parallel_groups_initialized: bool = False
    _supports_torchft_runtime: bool = True
    _ft_reduced_domains = frozenset({"data", "batch", "loss", "optimizer", "fsdp"})
    _pipeline_loss_divisor_local: float = 0.0
    _pipeline_loss_weighting: str | None = None

    model_parts: list[nn.Module]

    checkpoint_manager: TorchDistributedCheckpointManager

    def __init__(self, config: Mapping[str, Any]) -> None:
        """
        Normalize the config for the parallel stack and initialize the runner.

        Args:
            config: Runner configuration. A value that is not already a
                `RunnerConfig` is wrapped in one; a `RunnerConfig` is updated
                in place.

        **Side effects:** runs `dcp.check()`, sets `config.stack` to
        `"parallel"`, forces `config.checkpoint.backend` to `"dcp"` (with a
        `RuntimeWarning` when the requested backend was neither `"dcp"` nor
        `"auto"`), calls the parent initializer, and replaces
        `self.dataloaders` with a `_ParallelDataLoaderDict`.
        """
        dcp.check()
        runner_config = config if isinstance(config, RunnerConfig) else RunnerConfig(config)
        # Read the requested backend before anything on the config is rewritten.
        backend = str(runner_config.checkpoint.backend).strip().lower()
        runner_config.stack = "parallel"
        if backend not in ("dcp", "auto"):
            # An explicit non-DCP choice is overridden loudly; "auto" silently.
            warn(
                f"{self.__class__.__name__} overrides checkpoint.backend to 'dcp'",
                RuntimeWarning,
                stacklevel=2,
            )
        if backend != "dcp":
            runner_config.checkpoint.backend = "dcp"
        super().__init__(runner_config)
        self.dataloaders = _ParallelDataLoaderDict(self)

    @property
    def fsdp_enabled(self) -> bool:
        return bool(self.config.fsdp.get("enabled", False))

    def init_distributed(self) -> None:
        """
        Bring up the default distributed state, then the per-axis process groups.

        **Called when:** `BaseRunner.__init__` invokes `init_distributed`,
        ahead of checkpoint manager/fault-tolerance setup and ahead of model
        materialization.

        **Precondition:** `WORLD_SIZE > 1`, and the product of the configured
        parallel axis degrees equals `WORLD_SIZE`.

        Raises:
            RuntimeError: the run is not distributed, or the device-mesh
                process groups cannot be initialized.
            ValueError: `build_topology` rejects the configured axis product.

        **Side effects:** calls `TorchRunner.init_distributed`, assigns
        `self.topology`, and on the first call only resets and initializes
        the device mesh, the per-axis process groups, and `self.parallel`.

        !!! danger "Do not"
            - Create model/pipeline/FSDP objects here; that belongs to
              `materialize_model`.
            - Override this merely to change axis degrees; set
              `config.parallel.axes` or override `build_topology` instead.
        """
        super().init_distributed()
        if self.world_size <= 1:
            raise RuntimeError("ParallelRunner requires distributed mode (WORLD_SIZE > 1)")
        self.topology = self.build_topology()
        if self._parallel_groups_initialized:
            # Groups are built once per runner; later calls only refresh the topology.
            return
        self._reset_model_parallel_groups()
        self._init_model_parallel_groups()
        self._parallel_groups_initialized = True

    def build_topology(self) -> ParallelTopology:
        """
        Describe how ranks map onto the configured parallel axes.

        **Called when:** `init_distributed` has set up the default process
        group and needs the per-axis domains.

        Returns:
            `ParallelTopology` carrying the axis degrees, the current rank's
            coordinates, and the named reduction domains.

        Raises:
            ValueError: an axis degree is below one, or the axis degrees do not
                multiply to `WORLD_SIZE`.

        **Side effects:** none. Override only for non-standard axis/domain
        layouts; ordinary configuration goes through `config.parallel.axes`.
        """
        axis_names = ("replicate", "shard", "context", "pipeline", "tensor", "expert", "expert_tensor")
        axes_config = self.config.parallel.axes
        # Insertion order matters: the "optimizer" domain is built from the key order.
        axes = {name: int(getattr(axes_config, name)) for name in axis_names}

        data_axes = ("replicate", "shard")
        data_and_context_axes = ("replicate", "shard", "context")
        domains = {
            "data": data_axes,
            "batch": data_axes,
            "loss": data_and_context_axes,
            "optimizer": tuple(axes),
            "fsdp": data_and_context_axes,
            "context": ("context",),
            "pipeline": ("pipeline",),
            "tensor": ("tensor",),
            "expert": ("expert",),
            "expert_tensor": ("expert_tensor",),
        }
        return ParallelTopology(
            world_size=self.world_size,
            rank=self.rank,
            axes=axes,
            domains=domains,
            label="parallel topology",
        )

    def _reset_model_parallel_groups(self) -> None:
        self.tensor_group = None
        self.pipeline_group = None
        self.replicate_group = None
        self.shard_group = None
        self.context_group = None
        self.expert_group = None
        self.expert_tensor_group = None
        self.device_mesh = None
        if hasattr(self, "topology"):
            self.parallel = ParallelContext(self.topology)

    def _init_model_parallel_groups(self) -> None:
        use_device_mesh = self.config.parallel.use_device_mesh
        if not use_device_mesh:
            raise RuntimeError("cannot initialize parallel process groups: set `parallel.use_device_mesh=True`.")

        mesh_device_type = self.config.parallel.mesh_device_type
        if mesh_device_type is None:
            mesh_device_type = "cuda" if torch.cuda.is_available() else "cpu"
        self.device_mesh = init_device_mesh(
            mesh_device_type,
            mesh_shape=self.topology.mesh_shape,
            mesh_dim_names=self.topology.axis_names,
        )
        self.parallel = ParallelContext(
            self.topology,
            device_mesh=self.device_mesh,
            groups={axis: self.device_mesh.get_group(axis) for axis in self.topology.axis_names},
        )
        self.shard_group = self.parallel.group("shard")
        self.replicate_group = self.parallel.group("replicate")
        self.context_group = self.parallel.group("context")
        self.pipeline_group = self.parallel.group("pipeline")
        self.tensor_group = self.parallel.group("tensor")
        self.expert_group = self.parallel.group("expert")
        self.expert_tensor_group = self.parallel.group("expert_tensor")

    def _timeout_process_groups(self) -> tuple[Any | None, ...]:
        """Return the parent's timeout groups followed by every non-`None` group in `self.parallel`."""
        collected = [*super()._timeout_process_groups()]
        if hasattr(self, "parallel"):
            for group in self.parallel.groups.values():
                if group is not None:
                    collected.append(group)
        return tuple(collected)

    def __post_init__(self):
        self._pipeline_loss_divisor_local = 0.0
        self._pipeline_loss_weighting = None
        if self.fsdp_enabled:
            parallel_fsdp.check()
        torchft_config_supported = (
            self.fsdp_enabled
            and int(self.config.parallel.axes.pipeline) == 1
            and int(self.config.parallel.axes.tensor) == 1
            and int(self.config.parallel.axes.context) == 1
            and int(self.config.parallel.axes.expert) == 1
            and int(self.config.parallel.axes.expert_tensor) == 1
        )
        if self.ft is not None and self.ft.enabled and not torchft_config_supported:
            raise NotImplementedError(
                "ParallelRunner TorchFT integration currently requires FSDP with "
                "pipeline/tensor/context/expert axes set to 1"
            )
        if not self.model_parts:
            if self.model is None:
                raise ValueError("cannot initialize model_parts: model is not initialized")
            self.model_parts = [self.model]
        super().__post_init__()

    def materialize_model(self) -> None:
        """
        Turn the bound model into compiled (and optionally FSDP-wrapped) local parts.

        **Called when:** `TorchRunner.__post_init__` reaches
        `materialize_model`, after FP8 setup and before the optimizer is built.

        **Precondition:** `self.model` or `self.model_parts` is bound. A
        pipeline run may supply `self.pipeline_schedule` up front; otherwise a
        single local model is converted into pipeline stage(s) when
        `pipeline_degree > 1`.

        Raises:
            RuntimeError: FSDP prerequisites are unavailable.
            ValueError: model/model_parts are missing, or an unsupported
                auto-pipeline shape is requested.

        **Side effects:** prepares the local parts (device placement and
        `parallelize_model`), applies the FP8 module policy when enabled,
        compiles every part, with FSDP enabled wraps each part via
        `fully_shard` after `apply_activation_checkpointing` and compile,
        rebinds the pipeline schedule modules, installs the TorchFT
        all-reduce hook for FSDP, and moves EMA to `self.device`.

        !!! danger "Do not"
            - Build the optimizer before this hook; its parameters must come
              from the materialized/wrapped parts.
            - FSDP-wrap before `apply_activation_checkpointing`.
            - Replace `self.model_parts` without keeping `self.model` equal to
              the first local part.
        """
        if self.fsdp_enabled:
            self._check_fsdp_prerequisites()
        self._maybe_init_pipeline_schedule_from_single_part()
        local_parts = self._prepare_local_model_parts()
        if self.fp8_enabled:
            self.apply_fp8_module_policy_to_model_parts()
            # Re-read the parts after the FP8 policy has been applied.
            local_parts = list(self.model_parts)

        materialized = []
        if self.fsdp_enabled:
            shard_kwargs = self.fsdp_kwargs()
            for part in local_parts:
                # Order per part: activation checkpointing -> compile -> FSDP wrap.
                checkpointed = self.apply_activation_checkpointing(part)
                compiled = self.compiler.compile(checkpointed)
                materialized.append(fully_shard(compiled, **shard_kwargs))
        else:
            for part in local_parts:
                materialized.append(self.compiler.compile(part))

        self.model_parts = materialized
        self.model = materialized[0]
        self.bind_pipeline_modules(self.model_parts)

        if self.fsdp_enabled:
            self._apply_ft_all_reduce_hook()
        if self.ema is not None:
            self.ema = self.ema.to(self.device)

    def _check_fsdp_prerequisites(self) -> None:
        """
        Fail fast when FSDP cannot be used in this process.

        Raises:
            RuntimeError: `fully_shard` or `FSDPModule` resolved to `None`, or
                CUDA is not available.
        """
        fsdp_symbols = (fully_shard, FSDPModule)
        if any(symbol is None for symbol in fsdp_symbols):
            raise RuntimeError("cannot initialize ParallelRunner FSDP: torch.distributed.fsdp.fully_shard is required")
        cuda_ready = torch.cuda.is_available()
        if not cuda_ready:
            raise RuntimeError("ParallelRunner FSDP requires CUDA when WORLD_SIZE > 1")

    def _maybe_init_pipeline_schedule_from_single_part(self) -> None:
        if self.pipeline_schedule is not None or self.pipeline_degree <= 1 or self.pipeline_group is None:
            return
        if self.model_parts and len(self.model_parts) != 1:
            raise ValueError(
                "cannot auto-materialize pipeline from multiple local model_parts; "
                "provide `pipeline_schedule` explicitly when pre-partitioning local stages"
            )
        stage_model = self.model_parts[0] if self.model_parts else self.model
        if stage_model is None:
            raise ValueError("cannot materialize pipeline: model is not initialized")
        stage_models = self.build_pipeline_model_parts(stage_model)
        schedule_input: nn.Module | Sequence[nn.Module] = stage_models[0] if len(stage_models) == 1 else stage_models
        self.pipeline_schedule = self.build_pipeline_schedule(schedule_input)
        self.model_parts = stage_models
        self.model = stage_models[0]
        stage_indices = self.pipeline_stage_indices()
        self.pipeline_has_first_stage = 0 in stage_indices
        self.pipeline_has_last_stage = self._pipeline_num_stages() - 1 in stage_indices

    def _pipeline_num_stages(self) -> int:
        module_fqns_per_model_part = self.config.parallel.get("module_fqns_per_model_part")
        if module_fqns_per_model_part is not None:
            return len(module_fqns_per_model_part)
        return self.pipeline_degree

    def pipeline_stage_indices(self, num_stages: int | None = None) -> tuple[int, ...]:
        """
        Return the pipeline stage indices owned by this rank.

        The default supports the common looped virtual-stage mapping used by
        interleaved schedules: rank `r` owns `r`, `r + pp_degree`, ...
        Override this method for mirrored, zero-bubble, or other custom local
        stage placement.
        """
        if num_stages is None:
            num_stages = self._pipeline_num_stages()
        if num_stages < self.pipeline_degree:
            raise ValueError(
                "pipeline num_stages must be at least pipeline_degree " f"({self.pipeline_degree}), got {num_stages}"
            )
        if num_stages % self.pipeline_degree != 0:
            raise ValueError(
                "pipeline num_stages must be divisible by pipeline_degree "
                f"({self.pipeline_degree}), got {num_stages}"
            )

        stages_per_rank = num_stages // self.pipeline_degree
        if stages_per_rank == 1:
            return (self.pipeline_rank,)

        return tuple(self.pipeline_rank + offset * self.pipeline_degree for offset in range(stages_per_rank))

    def build_pipeline_model_part(self, model: nn.Module) -> nn.Module:
        """
        Build the model part for the first pipeline stage owned by this rank.

        Two user-facing contracts are supported by default:

        - A model that defines `build_pipeline_model_part(...)` is asked to
          produce its own part.
        - With `parallel.module_fqns_per_model_part` configured, the named
          modules for the stage are extracted; several FQNs are chained into a
          plain `nn.Sequential` in the configured order.

        Anything more involved (graph partitioning) belongs in the model hook
        or in an override of this method.
        """
        first_local_stage = self.pipeline_stage_indices()[0]
        stage_fqns = self._pipeline_module_fqns_for_stage(first_local_stage)
        total_stages = self._pipeline_num_stages()
        return self._build_pipeline_model_part(model, first_local_stage, total_stages, stage_fqns)

    def build_pipeline_model_parts(self, model: nn.Module) -> list[nn.Module]:
        """
        Build every pipeline model part owned by this rank.

        A rank with a single stage defers to `build_pipeline_model_part`.
        Override this when a schedule places several stages on each rank and
        the default FQN/model-owned partitioning cannot express the split.

        Raises:
            ValueError: several local stages are requested but neither
                `parallel.module_fqns_per_model_part` nor a
                `model.build_pipeline_model_part(...)` hook is available.
        """
        local_stage_indices = self.pipeline_stage_indices()
        if len(local_stage_indices) == 1:
            return [self.build_pipeline_model_part(model)]

        model_hook = getattr(model, "build_pipeline_model_part", None)
        fqns_configured = self.config.parallel.get("module_fqns_per_model_part") is not None
        if not (callable(model_hook) or fqns_configured):
            raise ValueError(
                "multiple local pipeline stages require `parallel.module_fqns_per_model_part`, "
                "`model.build_pipeline_model_part(...)`, or an override of "
                "`ParallelRunner.build_pipeline_model_parts`"
            )

        total_stages = self._pipeline_num_stages()
        local_parts = []
        for stage_index in local_stage_indices:
            stage_fqns = self._pipeline_module_fqns_for_stage(stage_index)
            local_parts.append(self._build_pipeline_model_part(model, stage_index, total_stages, stage_fqns))
        return local_parts

    def _build_pipeline_model_part(
        self,
        model: nn.Module,
        stage_index: int,
        num_stages: int,
        module_fqns: tuple[str, ...] | None,
    ) -> nn.Module:
        build_part = getattr(model, "build_pipeline_model_part", None)
        if callable(build_part):
            part = build_part(
                stage_index=stage_index,
                num_stages=num_stages,
                module_fqns=module_fqns,
                parallel=self.parallel,
            )
            if part is None:
                return model
            if not isinstance(part, nn.Module):
                raise TypeError(
                    "model.build_pipeline_model_part(...) must return an nn.Module or None, "
                    f"got {type(part).__name__}"
                )
            return part

        if module_fqns is None:
            return model
        return self._build_pipeline_model_part_from_fqns(model, module_fqns)

    def _pipeline_module_fqns_for_stage(self, stage_index: int) -> tuple[str, ...] | None:
        module_fqns_per_model_part = self.config.parallel.get("module_fqns_per_model_part")
        if module_fqns_per_model_part is None:
            return None
        if stage_index < 0 or stage_index >= len(module_fqns_per_model_part):
            raise ValueError(
                "pipeline stage index is outside parallel.module_fqns_per_model_part: "
                f"stage_index={stage_index}, num_stages={len(module_fqns_per_model_part)}"
            )
        module_fqns = module_fqns_per_model_part[stage_index]
        if isinstance(module_fqns, str):
            module_fqns = (module_fqns,)
        else:
            module_fqns = tuple(str(module_fqn) for module_fqn in module_fqns)
        if not module_fqns:
            raise ValueError(
                "parallel.module_fqns_per_model_part entries must not be empty; "
                f"pipeline stage {stage_index} has no modules"
            )
        return module_fqns

    def _pipeline_module_fqns_for_rank(self) -> tuple[str, ...] | None:
        return self._pipeline_module_fqns_for_stage(self.pipeline_stage_indices()[0])

    def _build_pipeline_model_part_from_fqns(self, model: nn.Module, module_fqns: Sequence[str]) -> nn.Module:
        modules = dict(model.named_modules())
        if "" in module_fqns:
            raise ValueError("parallel.module_fqns_per_model_part may not select the root module")
        missing = [module_fqn for module_fqn in module_fqns if module_fqn not in modules]
        if missing:
            raise ValueError(f"unknown pipeline module FQN(s): {missing}")
        if len(set(module_fqns)) != len(module_fqns):
            raise ValueError(f"duplicate pipeline module FQN(s): {list(module_fqns)}")
        if len(module_fqns) == 1:
            return modules[module_fqns[0]]
        return nn.Sequential(
            OrderedDict((module_fqn.replace(".", "_"), modules[module_fqn]) for module_fqn in module_fqns)
        )

    def _prepare_local_model_parts(self) -> list[nn.Module]:
        if self.pipeline_schedule is None:
            if self.model is None:
                if self.model_parts:
                    self.model = self.model_parts[0]
                else:
                    raise ValueError("cannot materialize parallel model: model is not initialized")
            parts: list[nn.Module] = [self.model]
        else:
            if not self.model_parts:
                if self.model is None:
                    raise ValueError("cannot materialize pipeline: model_parts are not initialized")
                self.model_parts = [self.model]
            parts = list(self.model_parts)
        parts = [part.to(self.device) for part in parts]
        parts = [self.parallelize_model(part) for part in parts]
        self.model_parts = parts
        self.model = parts[0]
        return parts

    def _apply_ft_all_reduce_hook(self) -> None:
        if self.ft is None:
            return
        group = self.ft.replicate_process_group
        if group is None:
            return

        def all_reduce_hook(output):
            dist.all_reduce(output, group=group, op=dist.ReduceOp.AVG)

        def apply_hook(module: nn.Module) -> None:
            set_all_reduce_hook = getattr(module, "set_all_reduce_hook", None)
            if callable(set_all_reduce_hook):
                set_all_reduce_hook(all_reduce_hook)

        for model in self.model_parts:
            model.apply(apply_hook)

    def parallelize_model(self, model: nn.Module) -> nn.Module:
        """
        Run the model-specific tensor/context/expert parallel transform.

        **Called when:** `_prepare_local_model_parts` materializes each local
        part, ahead of compile and FSDP wrapping.

        Args:
            model: Local model part to transform.

        Returns:
            The transformed model. A model that defines
            `model.parallelize(parallel)` may mutate itself and return `None`,
            in which case the same model object is returned.

        Raises:
            TypeError: `model.parallelize` returns something that is not a
                module.
            NotImplementedError: model-parallel axes are enabled and the model
                offers no transform hook.

        !!! danger "Do not"
            - Move the model to a device here; the surrounding
              `materialize_model` flow has already placed it.
            - Compile or FSDP-wrap here; both happen after this hook.
        """
        model_hook = getattr(model, "parallelize", None)
        if not callable(model_hook):
            if self.model_parallel_degree > 1:
                axes = ", ".join(self.model_parallel_axes)
                raise NotImplementedError(
                    f"parallel axes {axes} require model-specific parallelization. "
                    "Implement `model.parallelize(parallel)` or override "
                    "`ParallelRunner.parallelize_model`."
                )
            return model

        transformed = model_hook(self.parallel)
        if transformed is None:
            # In-place hooks signal success by returning nothing.
            return model
        if isinstance(transformed, nn.Module):
            return transformed
        raise TypeError(
            "model.parallelize(parallel) must return an nn.Module or None, "
            f"got {type(transformed).__name__}"
        )

    def fsdp_mesh(self):
        """
        Resolve the device mesh handed to FSDP.

        Returns:
            `config.fsdp.mesh` when it is set; otherwise the
            `("replicate", "shard")` sub-mesh of `self.device_mesh` when
            `replicate_degree > 1`, else its `"shard"` sub-mesh.

        Raises:
            RuntimeError: no explicit mesh is configured and the device mesh
                has not been initialized.
            NotImplementedError: no explicit mesh is configured and
                `context_degree > 1`.
        """
        explicit_mesh = self.config.fsdp.get("mesh")
        if explicit_mesh is not None:
            return explicit_mesh

        if self.device_mesh is None:
            raise RuntimeError("cannot initialize ParallelRunner FSDP: device mesh is not initialized")
        if self.context_degree > 1:
            raise NotImplementedError(
                "ParallelRunner FSDP with context parallelism requires a flattened FSDP mesh; "
                "set fsdp.mesh explicitly or keep parallel.axes.context=1."
            )

        mesh_key = ("replicate", "shard") if self.replicate_degree > 1 else "shard"
        return self.device_mesh[mesh_key]

    def build_mixed_precision_policy(self) -> object | None:
        """Build the FSDP mixed-precision policy from `config.fsdp.mp_policy` via the module-level helper."""
        configured_policy = self.config.fsdp.get("mp_policy")
        return build_mixed_precision_policy(
            policy=configured_policy,
            mixed_precision_policy_cls=MixedPrecisionPolicy,
            label="fsdp.mp_policy",
        )

    def build_offload_policy(self) -> object | None:
        """Build the FSDP offload policy from `config.fsdp.offload_policy` via the module-level helper."""
        configured_policy = self.config.fsdp.get("offload_policy")
        return build_offload_policy(
            policy=configured_policy,
            cpu_offload_policy_cls=CPUOffloadPolicy,
            label="fsdp.offload_policy",
        )

    def fsdp_kwargs(self) -> dict[str, Any]:
        """
        Collect the keyword arguments passed to `fully_shard` for each local part.

        The mesh, mixed-precision policy, and offload policy are resolved
        through `fsdp_mesh`, `build_mixed_precision_policy`, and
        `build_offload_policy`, then handed to `build_fsdp2_kwargs` together
        with `config.fsdp` and the set of supported config keys.
        """
        mesh = self.fsdp_mesh()
        mixed_precision_policy = self.build_mixed_precision_policy()
        offload_policy = self.build_offload_policy()
        supported_keys = {
            "enabled",
            "mesh",
            "reshard_after_forward",
            "shard_placement_fn",
            "mp_policy",
            "offload_policy",
            "ignored_params",
        }
        return build_fsdp2_kwargs(
            config=self.config.fsdp,
            mesh=mesh,
            mixed_precision_policy=mixed_precision_policy,
            offload_policy=offload_policy,
            config_name="fsdp",
            supported_keys=supported_keys,
            support_hint="mesh/reshard_after_forward/mp_policy/offload_policy",
        )

    def apply_activation_checkpointing(self, model: nn.Module) -> nn.Module:
        """
        Hook for adding activation checkpointing to a single local model part.

        **Called when:** `materialize_model` processes FSDP-enabled parts,
        ahead of compile and the FSDP wrap.

        Args:
            model: Local model part.

        Returns:
            The model part with activation checkpointing wrappers in place.

        **Side effects:** the default implementation returns `model`
        untouched. An override may mutate the module in place or hand back a
        wrapped module.

        !!! danger "Do not"
            - Alter parameter ownership or shard layout here; the model is not
              FSDP-wrapped yet.
            - Return anything that is not a module.
        """
        unchanged = model
        return unchanged

    def bind_pipeline_modules(self, modules: Sequence[nn.Module]) -> None:
        """
        Point the pipeline schedule's stage(s) at the given modules.

        No-op without a pipeline schedule. A schedule exposing `stages` gets
        modules assigned pairwise to each stage that has a `module` attribute.
        Otherwise the first module is assigned to `schedule.stage.module`, or,
        failing that, to `schedule.module` when that attribute exists.

        Args:
            modules: Local modules, in local stage order.
        """
        schedule = self.pipeline_schedule
        if schedule is None:
            return

        stages = getattr(schedule, "stages", None)
        if stages is not None:
            for stage, module in zip(stages, modules):
                if hasattr(stage, "module"):
                    stage.module = module
            return

        single_stage = getattr(schedule, "stage", None)
        if single_stage is not None and modules:
            single_stage.module = modules[0]
        elif hasattr(schedule, "module") and modules:
            schedule.module = modules[0]

    def iter_optimizer_parameters(self) -> Iterator[nn.Parameter]:
        """
        Yield the parameters of all local model parts via `_iter_unique_parameters`.

        Falls back to `self.model` when `model_parts` is empty; yields nothing
        when neither is available.
        """
        local_parts = list(self.model_parts or [])
        if not local_parts and self.model is not None:
            local_parts.append(self.model)
        if local_parts:
            yield from self._iter_unique_parameters(local_parts)

    def iter_optimizer_named_parameters(self) -> Iterator[tuple[str, nn.Parameter]]:
        """
        Yield `(name, parameter)` pairs via `_iter_unique_named_parameters`.

        Falls back to `self.model` when `model_parts` is empty; yields nothing
        when neither is available. A single part uses an empty name prefix;
        several parts are prefixed `part0.`, `part1.`, and so on.
        """
        local_parts = list(self.model_parts or [])
        if not local_parts and self.model is not None:
            local_parts.append(self.model)
        if not local_parts:
            return
        if len(local_parts) == 1:
            name_prefixes = ("",)
        else:
            name_prefixes = tuple(f"part{position}." for position, _ in enumerate(local_parts))
        yield from self._iter_unique_named_parameters(local_parts, name_prefixes)

    def unwrap(self, model: nn.Module) -> nn.Module:
        """
        Return the module underneath a wrapper.

        An `FSDPModule` instance yields its `module` attribute (or itself when
        it has none); anything else is delegated to the parent `unwrap`.
        """
        is_fsdp_wrapped = FSDPModule is not None and isinstance(model, FSDPModule)
        if not is_fsdp_wrapped:
            return super().unwrap(model)
        return getattr(model, "module", model)

    def _train_no_sync_targets(self) -> tuple[nn.Module, ...]:
        """
        Return the FSDP-wrapped modules among the local parts.

        When no local part is an `FSDPModule`, `self.model` is used if it is
        one; when nothing qualifies, the parent implementation decides.
        """

        def _is_fsdp(module) -> bool:
            return FSDPModule is not None and isinstance(module, FSDPModule)

        targets = tuple(filter(_is_fsdp, self.model_parts or []))
        if self.model is not None and not targets and _is_fsdp(self.model):
            targets = (self.model,)
        if not targets:
            return super()._train_no_sync_targets()
        return targets

    def _resolve_pipeline_n_microbatches(self) -> int:
        configured = self.config.parallel.get("pipeline_n_microbatches")
        if configured is not None:
            n_microbatches = int(configured)
            if n_microbatches <= 0:
                raise ValueError(
                    f"invalid parallel.pipeline_n_microbatches: expected a positive integer, got {configured}"
                )
            return n_microbatches

        microbatch_size = int(self.config.parallel.get("pipeline_microbatch_size", 1))
        if microbatch_size <= 0:
            raise ValueError(
                f"invalid parallel.pipeline_microbatch_size: expected a positive integer, got {microbatch_size}"
            )

        try:
            batch_size = int(self.batch_size)
        except (AttributeError, TypeError, ValueError) as exc:
            raise ValueError(
                "cannot infer pipeline microbatch count: set `parallel.pipeline_n_microbatches` "
                "or provide `dataloader.batch_size`."
            ) from exc

        if batch_size <= 0:
            raise ValueError(f"invalid batch size: expected a positive integer, got {batch_size}")
        if batch_size % microbatch_size != 0:
            raise ValueError(
                f"batch size ({batch_size}) must be divisible by parallel.pipeline_microbatch_size ({microbatch_size})"
            )

        n_microbatches = batch_size // microbatch_size
        if n_microbatches < self.pipeline_degree:
            warn(
                f"n_microbatches ({n_microbatches}) is less than pipeline_degree ({self.pipeline_degree}); "
                "pipeline utilization may be suboptimal.",
                RuntimeWarning,
                stacklevel=2,
            )
        return n_microbatches

    def _pipeline_loss(self, pred: Any, target: Any) -> torch.Tensor:
        if self.criterion is None:
            raise ValueError("cannot compute pipeline loss: criterion is not initialized")

        loss = self.criterion(pred, target)
        if loss is None:
            raise ValueError("cannot compute pipeline loss: criterion did not produce a loss")
        if loss.ndim > 0:
            loss = loss.mean()

        normalizer = None
        if isinstance(target, Mapping):
            normalizer = self._mapping_loss_normalizer(target)
        if normalizer is None:
            normalizer = self._tensor_loss_normalizer(target)
        divisor = float(max(int(normalizer), 1)) if normalizer is not None else 1.0
        if self._pipeline_loss_weighting is not None:
            self._pipeline_loss_divisor_local += divisor
            if self._pipeline_loss_weighting == "train":
                self._accumulation_divisor_local += divisor
            return loss * divisor
        return loss

    def build_pipeline_schedule(self, stage_model: nn.Module | Sequence[nn.Module]) -> Any:
        """
        Construct the PyTorch pipeline schedule owned by this rank.

        **Called when:** `materialize_model` finds `pipeline_degree > 1` with
        no explicit `pipeline_schedule` bound yet.

        Args:
            stage_model: The local stage module for this pipeline rank, or
                every local stage module for an interleaved/multi-stage
                schedule.

        Returns:
            A PyTorch pipeline schedule instance.

        Raises:
            ValueError: the microbatch count cannot be inferred or conflicts
                with the batch size, the number of stage modules differs from
                the number of local stage indices, or a single-stage schedule
                is given several local stages.

        **Side effects:** none beyond constructing the stages and schedule.
        The caller rebinds the schedule modules after compile/FSDP wrapping.

        !!! danger "Do not"
            - Pass `scale_grads=True`; DanLing owns gradient/loss scaling.
            - Build the optimizer here.
        """
        pipeline.check()
        schedule_name = str(self.config.parallel.get("pipeline_schedule", "1F1B")).strip() or "1F1B"
        n_microbatches = self._resolve_pipeline_n_microbatches()
        schedule_class = get_schedule_class(schedule_name)
        loss_fn = None if self.criterion is None else self._pipeline_loss

        local_modules = list(stage_model) if not isinstance(stage_model, nn.Module) else [stage_model]
        num_stages = self._pipeline_num_stages()
        stage_indices = self.pipeline_stage_indices(num_stages)
        if len(local_modules) != len(stage_indices):
            raise ValueError(
                "pipeline stage model count must match local pipeline stage indices: "
                f"{len(local_modules)} != {len(stage_indices)}"
            )

        stages = []
        for module, stage_index in zip(local_modules, stage_indices):
            stages.append(
                PipelineStage(
                    module,
                    stage_index=stage_index,
                    num_stages=num_stages,
                    device=self.device,
                    group=self.pipeline_group,
                )
            )

        schedule_kwargs = {"n_microbatches": n_microbatches, "loss_fn": loss_fn, "scale_grads": False}

        # The default stays non-interleaved 1F1B until pytorch/pytorch#164756
        # is addressed upstream; after that the default can move to
        # Interleaved1F1B.
        if not issubclass(schedule_class, PipelineScheduleMulti):
            if len(stages) != 1:
                raise ValueError(
                    f"pipeline schedule {schedule_name!r} accepts one local stage, got {len(stages)}; "
                    "choose an interleaved/multi-stage schedule or override `build_pipeline_schedule`."
                )
            return schedule_class(stages[0], **schedule_kwargs)
        return schedule_class(stages, **schedule_kwargs)

    def build_datasampler(self, dataset: Any, *, split: str, shuffle: bool) -> Any:
        """
        Create the data-parallel sampler for a single split.

        **Called when:** inherited `build_dataloaders` materializes a dataset
        split.

        Args:
            dataset: Dataset object for the split.
            split: Split name being materialized (not used by this
                implementation).
            shuffle: Whether the sampler should shuffle.

        Returns:
            `DistributedSampler` built from `data_degree`/`data_rank`, which
            are replaced by `self.ft.data_parallel_info(...)` when `self.ft`
            is set.
        """
        replicas, replica_rank = self.data_degree, self.data_rank
        if self.ft is not None:
            replicas, replica_rank = self.ft.data_parallel_info(replicas, replica_rank)
        sampler_cls = utils.data.distributed.DistributedSampler
        return sampler_cls(dataset, num_replicas=replicas, rank=replica_rank, shuffle=shuffle)

    def set_seed(self, seed: int | None = None, bias: int | bool | None = None) -> int:
        """
        Seed the run, defaulting the seed bias to the data-parallel rank.

        Args:
            seed: Forwarded unchanged to the parent `set_seed`.
            bias: Explicit bias. When `None`, the rank returned by
                `self.ft.data_parallel_info(...)` is used if `self.ft` is
                set, otherwise `self.data_rank`.

        Returns:
            Whatever the parent `set_seed` returns.
        """
        if bias is not None:
            return super().set_seed(seed=seed, bias=bias)
        if self.ft is None:
            bias = self.data_rank
        else:
            _, bias = self.ft.data_parallel_info(self.data_degree, self.data_rank)
        return super().set_seed(seed=seed, bias=bias)

    def _reduce_degree(self, domain: str = "data") -> int:
        degree = max(self.topology.domain_degree(domain), 1)
        if domain in self._ft_reduced_domains and self.ft is not None:
            group = self.ft.replicate_process_group
            if group is not None and dist.is_available() and dist.is_initialized():
                degree *= max(int(dist.get_world_size(group=group)), 1)
        return degree

    def all_reduce(self, tensor: torch.Tensor, *, domain: str = "data", op=dist.ReduceOp.SUM) -> torch.Tensor:
        """
        All-reduce `tensor` over a named reduction domain.

        The tensor is returned untouched when torch.distributed is not
        initialized. Otherwise it is reduced through `self.parallel` when the
        domain degree exceeds one, and additionally over the `self.ft`
        replicate process group when the domain is listed in
        `_ft_reduced_domains` and that group exists.

        Args:
            tensor: Tensor handed to the collectives.
            domain: Reduction domain name.
            op: Reduction operator passed to both collectives.

        Returns:
            The same `tensor` object that was passed in.
        """
        if not dist.is_available() or not dist.is_initialized():
            return tensor
        if self.topology.domain_degree(domain) > 1:
            self.parallel.all_reduce(tensor, domain=domain, op=op)
        replicate_group = None
        if domain in self._ft_reduced_domains and self.ft is not None:
            replicate_group = self.ft.replicate_process_group
        if replicate_group is not None:
            dist.all_reduce(tensor, op=op, group=replicate_group)
        return tensor

    def _sync_optimizer_skip_decision(self, should_skip: bool) -> bool:
        if not (self.distributed and dist.is_available() and dist.is_initialized()):
            return should_skip
        payload = torch.tensor(float(should_skip), device=self.all_reduce_device())
        self.all_reduce(payload, domain="optimizer", op=dist.ReduceOp.MAX)
        return payload.item() > 0

    def reduce(self, tensor):
        """
        Average `tensor` over the `"data"` reduction domain.

        The tensor is moved to `all_reduce_device()` when it lives elsewhere,
        sum-reduced, divided by `_reduce_degree("data")`, and moved back to
        its original device. The input is returned as-is when the degree is
        at most one or torch.distributed is not initialized.
        """
        degree = self._reduce_degree("data")
        if degree <= 1:
            return tensor
        if not (dist.is_available() and dist.is_initialized()):
            return tensor
        source_device = tensor.device
        reduce_device = self.all_reduce_device()
        if source_device == reduce_device:
            payload = tensor
        else:
            payload = tensor.to(reduce_device)
        self.all_reduce(payload)
        averaged = payload / degree
        if averaged.device != source_device:
            averaged = averaged.to(source_device)
        return averaged

    def reduce_loss_for_logging(self, loss: torch.Tensor | None, loss_n: int | None) -> torch.Tensor | None:
        """
        Reduce a step loss into a weighted mean suitable for logging.

        Without a pipeline schedule, the local loss (mean-reduced to a scalar)
        is weighted by `max(int(loss_n or 1), 1)`; the weighted loss and the
        weight are sum-reduced over the `"loss"` domain and their ratio is
        returned.

        With a pipeline schedule and an initialized process group, only
        reporter ranks (`pipeline_has_last_stage` and `tensor_rank == 0`)
        fill and reduce the payload. The payload is then broadcast on the
        default process group from the rank at the last pipeline coordinate
        with tensor coordinate 0. With a pipeline schedule but no initialized
        process group, the parent implementation is used.

        Args:
            loss: Local loss tensor, or `None` when this rank has no loss.
            loss_n: Weight of the local loss; falsy or non-positive values
                are treated as 1.

        Returns:
            Weighted mean loss as a float64 tensor, or `None` when there is no
            loss to report or the accumulated weight is not positive.
        """
        if self.pipeline_schedule is None:
            if loss is None:
                return None
            loss_value = loss.detach().to(dtype=torch.float64)
            if loss_value.ndim > 0:
                loss_value = loss_value.mean()
            normalizer = float(max(int(loss_n or 1), 1))
            payload_device = self.all_reduce_device()
            # payload = [weighted loss, weight]; both are summed across the domain.
            payload = torch.stack(
                (
                    loss_value.to(device=payload_device) * normalizer,
                    torch.tensor(normalizer, dtype=torch.float64, device=payload_device),
                )
            )
            self.all_reduce(payload, domain="loss", op=dist.ReduceOp.SUM)
            if payload[1].item() <= 0:
                return None
            return payload[0] / payload[1]
        if not (dist.is_available() and dist.is_initialized()):
            return super().reduce_loss_for_logging(loss, loss_n)
        # payload = [weighted loss, weight, has-loss flag].
        # NOTE(review): this payload lives on `self.device`, while the
        # non-pipeline branch uses `self.all_reduce_device()` — confirm the
        # difference is intentional.
        payload = torch.zeros((3,), dtype=torch.float64, device=self.device)
        is_reporter = self.pipeline_has_last_stage and self.tensor_rank == 0
        if is_reporter:
            if loss is not None:
                normalizer = float(max(int(loss_n or 1), 1))
                loss_value = loss.detach().to(dtype=torch.float64)
                if loss_value.ndim > 0:
                    loss_value = loss_value.mean()
                payload[0] = loss_value * normalizer
                payload[1] = normalizer
                payload[2] = 1.0
            # Only reporter ranks enter this all-reduce.
            # NOTE(review): presumably every rank in a reporter's "loss" domain
            # group is itself a reporter — confirm against the topology.
            self.all_reduce(payload, domain="loss")

        # Every rank joins the broadcast so non-reporters receive the reduced values.
        source_rank = self.topology.rank_from_coordinates({"pipeline": self.pipeline_degree - 1, "tensor": 0})
        dist.broadcast(payload, src=source_rank)
        if payload[2].item() <= 0 or payload[1].item() <= 0:
            return None
        return payload[0] / payload[1]

    @property
    def reports_batch_telemetry(self) -> bool:
        """Whether this rank holds the first pipeline stage and has tensor rank 0."""
        return self.pipeline_has_first_stage and self.tensor_rank == 0

    def _loss_normalizer_sync_divisor(self) -> int:
        if dist.is_available() and dist.is_initialized():
            return max(self._reduce_degree("loss"), 1)
        return 1

    def _reduce_loss_normalizer_total(self, local_total: float) -> float:
        if local_total <= 0:
            return local_total
        if self._loss_normalizer_sync_divisor() <= 1:
            return local_total
        if not (dist.is_available() and dist.is_initialized()):
            return local_total

        device = self.all_reduce_device()
        total_tensor = torch.tensor(local_total, dtype=torch.float64, device=device)
        self.all_reduce(total_tensor, domain="loss", op=dist.ReduceOp.SUM)
        return float(total_tensor.item())

    def _use_step_only_loader(self) -> bool:
        return (
            self.pipeline_schedule is not None
            and not self.pipeline_has_first_stage
            and not self.pipeline_has_last_stage
        )

    def _prepare_pipeline_batch(self, data: Any) -> tuple[Any | None, Any | None]:
        if self.pipeline_has_first_stage:
            if data is None:
                raise ValueError("cannot run pipeline stage: first stage requires dataloader inputs")
            data = self.to_device(data)
            if isinstance(data, Mapping):
                inputs = data["input"]
                target = data.get("target")
            elif isinstance(data, Sequence) and not isinstance(data, (str, bytes)):
                inputs = data[0]
                target = data[1] if len(data) > 1 else None
            else:
                inputs = data
                target = None
            if not self.pipeline_has_last_stage:
                target = None
            return inputs, target

        if not self.pipeline_has_last_stage or data is None:
            return None, None
        data = self.to_device(data)
        if isinstance(data, Mapping):
            if "target" not in data:
                return None, None
            return None, data["target"]
        if isinstance(data, Sequence) and not isinstance(data, (str, bytes)) and len(data) > 1:
            return None, data[1]
        target = None
        return None, target

    def _pipeline_loss_value(self, losses: list[torch.Tensor]) -> torch.Tensor | None:
        if not (self.pipeline_has_last_stage and losses):
            return None
        loss = torch.stack(losses).sum()
        if self._pipeline_loss_divisor_local > 0:
            loss = loss / self._pipeline_loss_divisor_local
        else:
            loss = loss / len(losses)
        return loss

    def _sync_pipeline_accumulation_divisor(self) -> None:
        if self.pipeline_degree <= 1:
            return
        if self.pipeline_group is None or not (dist.is_available() and dist.is_initialized()):
            return

        value = self._pipeline_loss_divisor_local if self.pipeline_has_last_stage else 0.0
        device = self.all_reduce_device()
        payload = torch.tensor(value, dtype=torch.float64, device=device)
        coordinates = dict(self.topology.ranks)
        coordinates["pipeline"] = self.pipeline_degree - 1
        source_rank = self.topology.rank_from_coordinates(coordinates)
        dist.broadcast(payload, src=source_rank, group=self.pipeline_group)
        if not self.pipeline_has_last_stage:
            self._accumulation_divisor_local += float(payload.item())

    @contextmanager
    def train_context(self):
        """
        Context manager wrapping one training micro-step.

        Uses the parent `train_context` when no pipeline schedule is bound;
        otherwise enters `_train_step_context` with the targets returned by
        `_train_no_sync_targets()`.
        """
        if self.pipeline_schedule is not None:
            context = self._train_step_context(no_sync_targets=self._train_no_sync_targets())
        else:
            context = super().train_context()
        with context:
            yield

    def train_step(self, data: Any) -> tuple[Any, torch.Tensor | None]:
        """
        Execute a single training micro-step, with or without a pipeline.

        Without a pipeline schedule this defers to the parent `train_step`.
        With one, it runs the schedule's `step`, derives the loss from the
        collected microbatch losses, synchronizes the accumulation divisor
        across the pipeline, and hands optimizer-boundary handling to
        `step()`.

        **Called when:** `train_epoch`/`train_steps` consume one micro-batch.

        Args:
            data: Micro-batch from the local loader. Non-first/non-last
                pipeline stages may receive `None` through `StepProxyLoader`.

        Returns:
            `(None, loss)` in pipeline mode, where `loss` is whatever
            `_pipeline_loss_value` yields for this rank. Non-pipeline mode
            returns the parent result.

        !!! danger "Do not"
            - Call the optimizer directly; use `step()`.
            - Update metrics from pipeline mode here; pipeline schedule outputs
              are not a normal full-batch prediction.
            - Manually divide gradients by pipeline microbatch count.
        """
        if self.pipeline_schedule is None:
            return super().train_step(data)

        with self.train_context():
            # Reset per-step loss bookkeeping before the schedule runs.
            self._pipeline_loss_divisor_local = 0.0
            self._pipeline_loss_weighting = "train"
            inputs, target = self._prepare_pipeline_batch(data)
            losses: list[torch.Tensor] = []
            schedule_kwargs = {
                "target": target if self.pipeline_has_last_stage else None,
                "losses": losses,
            }

            try:
                # Only first-stage ranks feed inputs into the schedule.
                schedule_args = (inputs,) if self.pipeline_has_first_stage else ()
                self.pipeline_schedule.step(*schedule_args, **schedule_kwargs)
            finally:
                # Always clear the weighting mode, even when the schedule raises.
                self._pipeline_loss_weighting = None

            loss = self._pipeline_loss_value(losses)
            self._sync_pipeline_accumulation_divisor()
            self.step()
        return None, loss

    def evaluate_step(self, data: Any) -> tuple[Any, torch.Tensor | None]:
        """
        Execute a single evaluation micro-step, with or without a pipeline.

        Without a pipeline schedule this defers to the parent
        `evaluate_step`. With one, it runs the schedule's `eval` under
        `forward_context()` and derives the loss from the collected
        microbatch losses.

        **Called when:** `evaluate_epoch`/`evaluate_steps` consume one
        micro-batch under inference mode.

        Args:
            data: Micro-batch from the local loader. Non-first/non-last
                pipeline stages may receive `None`.

        Returns:
            `(None, loss)` in pipeline mode. Non-pipeline mode returns the
            parent result.

        !!! danger "Do not"
            - Call backward or step.
            - Assume every rank has targets; only last-stage ranks need them.
        """
        if self.pipeline_schedule is None:
            return super().evaluate_step(data)

        with self.forward_context():
            # Reset per-step loss bookkeeping before the schedule runs.
            self._pipeline_loss_divisor_local = 0.0
            self._pipeline_loss_weighting = "eval"
            inputs, target = self._prepare_pipeline_batch(data)
            losses: list[torch.Tensor] = []
            schedule_kwargs = {
                "target": target if self.pipeline_has_last_stage else None,
                "losses": losses,
            }

            try:
                # Only first-stage ranks feed inputs into the schedule.
                schedule_args = (inputs,) if self.pipeline_has_first_stage else ()
                self.pipeline_schedule.eval(*schedule_args, **schedule_kwargs)
            finally:
                # Always clear the weighting mode, even when the schedule raises.
                self._pipeline_loss_weighting = None

            loss = self._pipeline_loss_value(losses)

        return None, loss

    @staticmethod
    def _normalize_infer_output(pred: Any) -> list[float]:
        if pred is None:
            return []
        if torch.is_tensor(pred):
            values = pred.detach().reshape(-1).cpu().tolist()
            if isinstance(values, list):
                return [float(value) for value in values]
            return [float(values)]
        if isinstance(pred, Mapping):
            mapped_values: list[float] = []
            for value in pred.values():
                mapped_values.extend(ParallelRunner._normalize_infer_output(value))
            return mapped_values
        if isinstance(pred, Sequence) and not isinstance(pred, (str, bytes)):
            seq_values: list[float] = []
            for value in pred:
                seq_values.extend(ParallelRunner._normalize_infer_output(value))
            return seq_values
        if isinstance(pred, (bool, int, float)):
            return [float(pred)]
        raise ValueError(
            "cannot normalize pipeline infer output: unsupported type "
            f"{type(pred).__name__}; override ParallelRunner.infer_step for custom formats"
        )

    @torch.inference_mode()
    def infer_step(self, data: Any) -> list[float]:
        """
        Execute a single inference micro-step, with or without a pipeline.

        Without a pipeline schedule this defers to the parent `infer_step`.
        With one, it runs the schedule's `eval` under `forward_context()`
        and flattens whatever it returns through `_normalize_infer_output`.

        Args:
            data: Micro-batch on first-stage ranks; `None` on non-first stages
                that only participate in pipeline communication.

        Returns:
            Flat list of numeric predictions. Ranks whose schedule returns
            `None` get an empty list.

        Raises:
            ValueError: pipeline output cannot be normalized into floats.
        """
        if self.pipeline_schedule is None:
            return super().infer_step(data)

        with self.forward_context():
            inputs, _ = self._prepare_pipeline_batch(data)
            # Only first-stage ranks feed inputs into the schedule.
            eval_args = (inputs,) if self.pipeline_has_first_stage else ()
            pred = self.pipeline_schedule.eval(*eval_args)
        return self._normalize_infer_output(pred)

    def infer(
        self,
        split: str = "infer",
        *,
        steps: int | None = None,
        stream: bool | None = None,
    ) -> list[float] | Iterator[list[float]]:
        """
        Run inference across a pipeline-aware loader.

        Non-pipeline configurations delegate to `TorchRunner.infer`. Pipeline
        configurations consume real dataloader batches only on first-stage
        ranks; other stages run `infer_step(None)` for the same number of
        steps.

        When `steps` is given, first-stage ranks stop pulling batches from the
        loader as soon as `steps` batches have been consumed, so a step limit
        also terminates on unbounded loaders.

        Args:
            split: Inference split name.
            steps: Optional maximum number of batches/stage ticks.
            stream: Whether to return a per-batch iterator instead of a
                flattened list. Defaults to streaming only when neither
                `steps` nor a loader length is available.

        Returns:
            Flattened predictions or a streaming iterator.

        Raises:
            ValueError: `steps` is negative, a first-stage rank is asked for a
                non-streaming result from an unsized loader without `steps`,
                or a non-first pipeline stage has an unsized loader and no
                explicit step count.
        """
        if self.pipeline_schedule is None:
            return super().infer(split=split, steps=steps, stream=stream)

        from itertools import islice

        self.mode = RunnerMode.infer
        self.split = split
        loader = self.dataloaders[split]

        if steps is not None and steps < 0:
            raise ValueError(f"invalid steps: expected a non-negative value, got {steps}")

        loader_length = self._loader_length(loader)
        if stream is None:
            stream = steps is None and loader_length is None

        if self.pipeline_has_first_stage:
            if not stream and loader_length is None and steps is None:
                raise ValueError("infer with stream=False requires `steps` for unsized loaders")
            # `islice` stops iterating once `steps` batches were taken. A
            # filtering condition would keep draining (and loading) the rest
            # of the loader, and never finish on an unbounded one.
            batches = loader if steps is None else islice(loader, steps)
            iterator = (self.infer_step(data) for data in batches)
            total = steps if steps is not None else loader_length
        else:
            if steps is None:
                if loader_length is None:
                    raise ValueError("infer for non-first pipeline stages requires `steps` for unsized loaders")
                steps = loader_length
            # Non-first stages only tick the schedule; they never read loader data.
            iterator = (self.infer_step(None) for _ in range(steps))
            total = steps

        if stream:
            return iterator

        output: list[float] = []
        for values in tqdm(iterator, total=total, disable=self.distributed and not self.is_main_process):
            output.extend(values)
        return output

    def _export_checkpoint_metadata(self, cls: type = dict) -> Mapping[str, Any]:
        state = cls({"parallel": cls({"axes": cls(self.parallel_axes_state(dict))})})
        if self.fsdp_enabled:
            state["fsdp"] = cls(
                {
                    "mode": self.fsdp_mode,
                    "data_degree": self.data_degree,
                    "shard_degree": self.shard_degree,
                    "replicate_degree": self.replicate_degree,
                    "context_degree": self.context_degree,
                }
            )
        return state

    def _export_checkpoint_components(self, cls: type = dict) -> Mapping[str, Any]:
        state = cls()
        state["ema"] = self.ema.state_dict() if self.ema else None
        state["scheduler"] = self.scheduler.state_dict() if self.scheduler else None
        if len(self.model_parts) != 1:
            state["optimizer"] = self.optimizer.state_dict() if self.optimizer else None
            state["model_parts"] = [self.unwrap(model).state_dict() for model in self.model_parts]
            return state

        model_state_dict, optim_state_dict = self.checkpoint_manager.export_model_optimizer_state(
            model=self.model_parts[0],
            optimizer=self.optimizer,
            options_cls=StateDictOptions,
            strict=True,
        )
        state["model"] = model_state_dict
        state["optimizer"] = optim_state_dict if self.optimizer is not None else None
        return state

    def _restore_model_checkpoint(
        self, state_dict: Mapping[str, Any] | list[Mapping[str, Any]], *args, **kwargs
    ) -> None:
        if isinstance(state_dict, list):
            state_dicts = state_dict
            if len(state_dicts) != len(self.model_parts):
                raise ValueError(
                    "cannot load parallel checkpoint: model_parts count mismatch: "
                    f"expected {len(self.model_parts)}, got {len(state_dicts)}"
                )
            for model, model_state_dict in zip(self.model_parts, state_dicts):
                self.unwrap(model).load_state_dict(model_state_dict, *args, **kwargs)
            return

        if len(self.model_parts) == 1:
            self.checkpoint_manager.load_model_state(
                model=self.model_parts[0],
                model_state_dict=state_dict,
                options_cls=StateDictOptions,
                strict=True,
            )
            return

        super()._restore_model_checkpoint(state_dict, *args, **kwargs)

    def _restore_optimizer_checkpoint(self, state_dict: Mapping[str, Any], *args, **kwargs) -> None:
        if len(self.model_parts) != 1:
            super()._restore_optimizer_checkpoint(state_dict, *args, **kwargs)
            return

        self.checkpoint_manager.load_optimizer_state(
            model=self.model_parts[0],
            optimizer=self.optimizer,
            optimizer_state_dict=state_dict,
            options_cls=StateDictOptions,
            strict=True,
        )

    def load_checkpoint(
        self,
        checkpoint: Mapping | bytes | str | os.PathLike,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Restore a parallel checkpoint after validating its topology.

        The checkpoint is read via `read_checkpoint`, its saved parallel axes
        (and, with FSDP, its FSDP metadata) are validated against the current
        run, and it is then restored through the parent `load_checkpoint`.
        When the saved axes differ from the current ones, a shallow copy of
        the checkpoint is rewritten to carry the current axes before
        restoring.

        Args:
            checkpoint: In-memory checkpoint mapping or checkpoint path.
            *args: Forwarded to checkpoint reading and the parent loader.
            **kwargs: Forwarded to checkpoint reading and the parent loader.

        Raises:
            ValueError: saved topology is incompatible with the current run,
                FSDP topology metadata is missing/changed, or a degree change
                is attempted with more than one local model part.

        **Side effects:** restores state through the parent loader and sets
        `config.resume` for path inputs.

        !!! danger "Do not"
            - Suppress topology validation for FSDP restores; shard metadata is
              part of the checkpoint contract.
            - Attempt degree-change restore with multiple local model parts.
        """
        ckpt = self.read_checkpoint(checkpoint, *args, **kwargs)
        saved_axes = self._validate_checkpoint_topology(ckpt)
        if self.fsdp_enabled:
            self._validate_fsdp_checkpoint_topology(ckpt)
        current_axes = self.parallel_axes_state(dict)

        if saved_axes != current_axes:
            if len(self.model_parts) != 1:
                raise ValueError(
                    "cannot restore parallel degree change: degree change restore requires DCP state-dict API "
                    "with a single local model part. "
                    "Either keep parallel axes unchanged, or restore with a single local model part."
                )

            # Work on a shallow copy so the caller's mapping is not mutated.
            ckpt = dict(ckpt)
            ckpt["parallel"] = {"axes": current_axes}
            saved_runner = ckpt.get("runner")
            runner_payload = dict(saved_runner) if isinstance(saved_runner, Mapping) else None
            saved_parallel = runner_payload.get("parallel") if runner_payload is not None else None
            if isinstance(saved_parallel, Mapping):
                # Overlay the current axes onto the saved runner config as well.
                parallel_payload = dict(saved_parallel)
                merged_axes = dict(parallel_payload.get("axes", {}))
                merged_axes.update(current_axes)
                parallel_payload["axes"] = merged_axes
                runner_payload["parallel"] = parallel_payload
                ckpt["runner"] = runner_payload

        super().load_checkpoint(ckpt, *args, **kwargs)
        if isinstance(checkpoint, (str, bytes, os.PathLike)):
            self.config.resume = os.fsdecode(checkpoint)

    def _validate_checkpoint_topology(self, checkpoint: Mapping[str, Any]) -> dict[str, int]:
        ckpt_topology = checkpoint.get("parallel")
        current = self.parallel_axes_state(dict)
        if not isinstance(ckpt_topology, Mapping):
            return dict(current)

        axes = ckpt_topology.get("axes", {})
        saved = dict(current)
        if isinstance(axes, Mapping):
            for axis in current:
                if axis in axes:
                    saved[axis] = int(axes[axis])
        if saved == current:
            return saved

        allow_degree_change = self.config.parallel.allow_degree_change
        if allow_degree_change:
            warn(
                "parallel degree changed across restart "
                f"(saved axes={saved}, current axes={current}). "
                "Attempting to restore with current runtime mapping.",
                RuntimeWarning,
                stacklevel=2,
            )
            return saved

        raise ValueError(
            "cannot restore checkpoint: parallel degree changed across restart "
            f"(saved axes={saved}, current axes={current}). "
            "Set `config.parallel.allow_degree_change=True` to proceed explicitly."
        )

    def _validate_fsdp_checkpoint_topology(self, checkpoint: Mapping[str, Any]) -> tuple[str, int, int, int]:
        ckpt_topology = checkpoint.get("fsdp")
        current = (
            self.fsdp_mode,
            self.replicate_degree,
            self.shard_degree,
            self.context_degree,
        )
        if not isinstance(ckpt_topology, Mapping):
            raise ValueError(
                "cannot restore parallel FSDP checkpoint: checkpoint is missing 'fsdp' topology metadata. "
                "Start a new run or use a checkpoint written by the current parallel FSDP runner."
            )

        saved = (
            str(ckpt_topology.get("mode", current[0])),
            int(ckpt_topology.get("replicate_degree", current[1])),
            int(ckpt_topology.get("shard_degree", current[2])),
            int(ckpt_topology.get("context_degree", current[3])),
        )
        if saved != current:
            raise ValueError(
                "cannot restore checkpoint: parallel FSDP topology changed across restart "
                f"(saved mode/replicate/shard/context={saved}, current mode/replicate/shard/context={current})."
            )
        return saved

    def close(self, timeout: float | None = None) -> bool:
        """
        Close the runner and release the model-parallel process groups.

        The groups are reset when the parent `close` reports a full drain or
        raises; they are left untouched when it returns a falsy result.

        Args:
            timeout: Forwarded to the parent `close`.

        Returns:
            `True` when the parent drained and the groups were reset,
            otherwise `False`.
        """

        def release_groups() -> None:
            self._reset_model_parallel_groups()
            self._parallel_groups_initialized = False

        try:
            drained = super().close(timeout=timeout)
        except Exception:
            # Still tear the groups down before propagating the failure.
            release_groups()
            raise
        if drained:
            release_groups()
            return True
        return False

    @property
    def tensor_degree(self) -> int:
        """Degree of the `"tensor"` axis, as reported by `topology.axis_degree`."""
        return self.topology.axis_degree("tensor")

    @property
    def pipeline_degree(self) -> int:
        """Degree of the `"pipeline"` axis, as reported by `topology.axis_degree`."""
        return self.topology.axis_degree("pipeline")

    @property
    def data_degree(self) -> int:
        """Degree of the `"data"` domain, as reported by `topology.domain_degree`."""
        return self.topology.domain_degree("data")

    @property
    def tensor_rank(self) -> int:
        """This rank's coordinate on the `"tensor"` axis, as reported by `topology.axis_rank`."""
        return self.topology.axis_rank("tensor")

    @property
    def pipeline_rank(self) -> int:
        """This rank's coordinate on the `"pipeline"` axis, as reported by `topology.axis_rank`."""
        return self.topology.axis_rank("pipeline")

    @property
    def data_rank(self) -> int:
        """This rank's position within the `"data"` domain, as reported by `topology.domain_rank`."""
        return self.topology.domain_rank("data")

    @property
    def model_parallel_axes(self) -> tuple[str, ...]:
        return tuple(
            axis for axis in ("tensor", "context", "expert", "expert_tensor") if self.topology.axis_degree(axis) > 1
        )

    @property
    def model_parallel_degree(self) -> int:
        degree = 1
        for axis in self.model_parallel_axes:
            degree *= self.topology.axis_degree(axis)
        return degree

    def parallel_axes_state(self, cls: type = dict) -> Mapping[str, int]:
        """
        Snapshot every topology axis and its degree.

        Args:
            cls: Mapping type used to wrap the result.

        Returns:
            `cls` built from `{axis name: axis degree}` for each name in
            `topology.axis_names`.
        """
        degrees = {}
        for axis in self.topology.axis_names:
            degrees[axis] = self.topology.axis_degree(axis)
        return cls(degrees)

    @property
    def fsdp_mode(self) -> str:
        if self.replicate_degree > 1:
            return "hybrid_shard"
        return "full_shard"

    @property
    def shard_degree(self) -> int:
        """Degree of the `"shard"` axis, as reported by `topology.axis_degree`."""
        return self.topology.axis_degree("shard")

    @property
    def replicate_degree(self) -> int:
        """Degree of the `"replicate"` axis; `default=1` is passed to `topology.axis_degree`."""
        return self.topology.axis_degree("replicate", default=1)

    @property
    def shard_rank(self) -> int:
        """This rank's coordinate on the `"shard"` axis, as reported by `topology.axis_rank`."""
        return self.topology.axis_rank("shard")

    @property
    def replicate_rank(self) -> int:
        """This rank's coordinate on the `"replicate"` axis; `default=0` is passed to `topology.axis_rank`."""
        return self.topology.axis_rank("replicate", default=0)

    @property
    def context_degree(self) -> int:
        """Degree of the `"context"` axis, as reported by `topology.axis_degree`."""
        return self.topology.axis_degree("context")

    @property
    def context_rank(self) -> int:
        """This rank's coordinate on the `"context"` axis, as reported by `topology.axis_rank`."""
        return self.topology.axis_rank("context")

    @property
    def expert_degree(self) -> int:
        """Degree of the `"expert"` axis, as reported by `topology.axis_degree`."""
        return self.topology.axis_degree("expert")

    @property
    def expert_rank(self) -> int:
        """This rank's coordinate on the `"expert"` axis, as reported by `topology.axis_rank`."""
        return self.topology.axis_rank("expert")

    @property
    def expert_tensor_degree(self) -> int:
        """Degree of the `"expert_tensor"` axis, as reported by `topology.axis_degree`."""
        return self.topology.axis_degree("expert_tensor")

    @property
    def expert_tensor_rank(self) -> int:
        """This rank's coordinate on the `"expert_tensor"` axis, as reported by `topology.axis_rank`."""
        return self.topology.axis_rank("expert_tensor")

init_distributed

Python
init_distributed() -> None

Initialize default distributed state and parallel process groups.

Called when: BaseRunner.__init__ invokes init_distributed, before checkpoint manager/fault-tolerance setup and before model materialization.

Precondition: WORLD_SIZE > 1 and the configured parallel axis product equals WORLD_SIZE.

Raises:

Type Description
RuntimeError

distributed mode is not active, or device-mesh process groups cannot be initialized.

ValueError

build_topology rejects the configured axis product.

Side effects: calls TorchRunner.init_distributed, builds self.topology, initializes the device mesh, binds per-axis process groups, and stores self.parallel.

Do not

  • Initialize model/pipeline/FSDP objects here; materialization happens in materialize_model.
  • Override this just to change axis degrees; set config.parallel.axes or override build_topology.
Source code in danling/runners/parallel_runner.py
Python
def init_distributed(self) -> None:
    """
    Bring up the default distributed state and the parallel process groups.

    **Called when:** `BaseRunner.__init__` invokes `init_distributed`, ahead
    of checkpoint manager/fault-tolerance setup and ahead of model
    materialization.

    **Precondition:** `WORLD_SIZE > 1`, and the product of the configured
    parallel axes equals `WORLD_SIZE`.

    Raises:
        RuntimeError: distributed mode is not active, or device-mesh process
            groups cannot be initialized.
        ValueError: `build_topology` rejects the configured axis product.

    **Side effects:** calls `TorchRunner.init_distributed`, builds
    `self.topology`, initializes the device mesh, binds per-axis process
    groups, and stores `self.parallel`.

    !!! danger "Do not"
        - Initialize model/pipeline/FSDP objects here; that is the job of
          `materialize_model`.
        - Override this just to change axis degrees; set
          `config.parallel.axes` or override `build_topology` instead.
    """
    super().init_distributed()
    if self.world_size <= 1:
        raise RuntimeError("ParallelRunner requires distributed mode (WORLD_SIZE > 1)")
    self.topology = self.build_topology()
    if self._parallel_groups_initialized:
        # Process groups were already set up by an earlier call; only the
        # topology is rebuilt on repeated invocations.
        return
    self._reset_model_parallel_groups()
    self._init_model_parallel_groups()
    self._parallel_groups_initialized = True

build_topology

Python
build_topology() -> ParallelTopology

Build the rank-to-axis topology for this parallel run.

Called when: init_distributed has initialized the default process group and needs per-axis domains.

Returns:

Type Description
ParallelTopology

ParallelTopology with axis degrees, current-rank coordinates, and named reduction domains.

Raises:

Type Description
ValueError

any axis degree is less than one, or the product of axis degrees does not equal WORLD_SIZE.

Side effects: none. Override this only for non-standard axis/domain layouts; normal users should configure config.parallel.axes.

Source code in danling/runners/parallel_runner.py
Python
def build_topology(self) -> ParallelTopology:
    """
    Construct the rank-to-axis topology used by this parallel run.

    **Called when:** `init_distributed` has initialized the default process
    group and needs per-axis domains.

    Returns:
        `ParallelTopology` carrying axis degrees, current-rank coordinates,
        and named reduction domains.

    Raises:
        ValueError: any axis degree is less than one, or the product of axis
            degrees does not equal `WORLD_SIZE`.

    **Side effects:** none. Override this only for non-standard axis/domain
    layouts; normal users should configure `config.parallel.axes`.
    """
    axis_config = self.config.parallel.axes
    # Insertion order matters: the "optimizer" domain below is `tuple(axes)`.
    axis_names = ("replicate", "shard", "context", "pipeline", "tensor", "expert", "expert_tensor")
    axes = {name: int(getattr(axis_config, name)) for name in axis_names}

    data_axes = ("replicate", "shard")
    sharded_axes = ("replicate", "shard", "context")
    domains = {
        "data": data_axes,
        "batch": data_axes,
        "loss": sharded_axes,
        "optimizer": tuple(axes),
        "fsdp": sharded_axes,
    }
    # Every remaining domain spans exactly the axis it is named after.
    for name in ("context", "pipeline", "tensor", "expert", "expert_tensor"):
        domains[name] = (name,)

    return ParallelTopology(
        world_size=self.world_size,
        rank=self.rank,
        axes=axes,
        domains=domains,
        label="parallel topology",
    )

materialize_model

Python
materialize_model() -> None

Materialize local model parts for FSDP/pipeline/model-parallel training.

Called when: TorchRunner.__post_init__ reaches materialize_model, after FP8 setup and before optimizer build.

Precondition: either self.model or self.model_parts is bound. Pipeline runs may also provide self.pipeline_schedule; otherwise a single local model is converted to a pipeline stage when pipeline_degree > 1.

Raises:

Type Description
RuntimeError

FSDP prerequisites are unavailable.

ValueError

model/model_parts are missing or an unsupported auto-pipeline shape is requested.

Side effects: moves local parts to self.device, calls parallelize_model, applies FP8 policy, compiles each part, optionally wraps parts with FSDP2 after apply_activation_checkpointing, binds pipeline schedule modules, installs TorchFT all-reduce hooks for FSDP, and moves EMA to device.

Do not

  • Build the optimizer before this hook; optimizer parameters must come from materialized/wrapped parts.
  • FSDP-wrap before apply_activation_checkpointing.
  • Replace self.model_parts without keeping self.model aligned to the first local part.
Source code in danling/runners/parallel_runner.py
Python
def materialize_model(self) -> None:
    """
    Turn the bound model into local, ready-to-train model parts.

    **Called when:** `TorchRunner.__post_init__` reaches
    `materialize_model`, after FP8 setup and before optimizer build.

    **Precondition:** either `self.model` or `self.model_parts` is bound.
    Pipeline runs may also provide `self.pipeline_schedule`; otherwise a
    single local model is converted to a pipeline stage when
    `pipeline_degree > 1`.

    Raises:
        RuntimeError: FSDP prerequisites are unavailable.
        ValueError: model/model_parts are missing or an unsupported
            auto-pipeline shape is requested.

    **Side effects:** moves local parts to `self.device`, calls
    `parallelize_model`, applies FP8 policy, compiles each part, optionally
    wraps parts with FSDP2 after `apply_activation_checkpointing`, binds
    pipeline schedule modules, installs TorchFT all-reduce hooks for FSDP,
    and moves EMA to device.

    !!! danger "Do not"
        - Build the optimizer before this hook; optimizer parameters must
          come from materialized/wrapped parts.
        - FSDP-wrap before `apply_activation_checkpointing`.
        - Replace `self.model_parts` without keeping `self.model` aligned
          to the first local part.
    """
    if self.fsdp_enabled:
        self._check_fsdp_prerequisites()
    self._maybe_init_pipeline_schedule_from_single_part()
    local_parts = self._prepare_local_model_parts()
    if self.fp8_enabled:
        self.apply_fp8_module_policy_to_model_parts()
        # Re-read the parts after the FP8 policy has been applied to them.
        local_parts = list(self.model_parts)

    materialized = []
    if self.fsdp_enabled:
        shard_options = self.fsdp_kwargs()
        for part in local_parts:
            # Order per part: activation checkpointing -> compile -> FSDP wrap.
            checkpointed = self.apply_activation_checkpointing(part)
            compiled = self.compiler.compile(checkpointed)
            materialized.append(fully_shard(compiled, **shard_options))
    else:
        for part in local_parts:
            materialized.append(self.compiler.compile(part))

    self.model_parts = materialized
    self.model = materialized[0]
    self.bind_pipeline_modules(self.model_parts)

    if self.fsdp_enabled:
        self._apply_ft_all_reduce_hook()
    ema = self.ema
    if ema is not None:
        self.ema = ema.to(self.device)

pipeline_stage_indices

Python
pipeline_stage_indices(
    num_stages: int | None = None,
) -> tuple[int, ...]

Return the pipeline stage indices owned by this rank.

The default supports the common looped virtual-stage mapping used by interleaved schedules: rank r owns r, r + pp_degree, … Override this method for mirrored, zero-bubble, or other custom local stage placement.

Source code in danling/runners/parallel_runner.py
Python
def pipeline_stage_indices(self, num_stages: int | None = None) -> tuple[int, ...]:
    """
    Return the pipeline stage indices owned by this rank.

    The default supports the common looped virtual-stage mapping used by
    interleaved schedules: rank `r` owns `r`, `r + pp_degree`, ...
    Override this method for mirrored, zero-bubble, or other custom local
    stage placement.
    """
    if num_stages is None:
        num_stages = self._pipeline_num_stages()
    if num_stages < self.pipeline_degree:
        raise ValueError(
            "pipeline num_stages must be at least pipeline_degree " f"({self.pipeline_degree}), got {num_stages}"
        )
    if num_stages % self.pipeline_degree != 0:
        raise ValueError(
            "pipeline num_stages must be divisible by pipeline_degree "
            f"({self.pipeline_degree}), got {num_stages}"
        )

    stages_per_rank = num_stages // self.pipeline_degree
    if stages_per_rank == 1:
        return (self.pipeline_rank,)

    return tuple(self.pipeline_rank + offset * self.pipeline_degree for offset in range(stages_per_rank))

build_pipeline_model_part

Python
build_pipeline_model_part(model: Module) -> Module

Return the local pipeline model part for this pipeline rank.

The default supports two user-facing contracts:

  • If the model defines build_pipeline_model_part(...), delegate to it.
  • If parallel.module_fqns_per_model_part is configured, extract those named modules for the current pipeline rank. Multiple FQNs become a simple nn.Sequential in the provided order.

Complex graph partitioning should be implemented in the model hook or by overriding this method.

Source code in danling/runners/parallel_runner.py
Python
def build_pipeline_model_part(self, model: nn.Module) -> nn.Module:
    """
    Build the local pipeline model part for this pipeline rank.

    Only the first stage index returned by `pipeline_stage_indices()` is
    built here; the work itself is delegated to
    `_build_pipeline_model_part`. The documented user-facing contracts are:

    - If the model defines `build_pipeline_model_part(...)`, delegate to it.
    - If `parallel.module_fqns_per_model_part` is configured, extract those
      named modules for the current pipeline rank. Multiple FQNs become a
      simple `nn.Sequential` in the provided order.

    Complex graph partitioning should be implemented in the model hook or
    by overriding this method.

    Args:
        model: Full model from which the local part is taken.

    Returns:
        The module produced by `_build_pipeline_model_part` for the first
        local stage.
    """
    local_stage = self.pipeline_stage_indices()[0]
    stage_fqns = self._pipeline_module_fqns_for_stage(local_stage)
    total_stages = self._pipeline_num_stages()
    return self._build_pipeline_model_part(model, local_stage, total_stages, stage_fqns)

build_pipeline_model_parts

Python
build_pipeline_model_parts(model: Module) -> list[Module]

Return all local pipeline model parts for this pipeline rank.

Override this when a schedule maps multiple stages to each local rank and the default FQN/model-owned partitioning is not expressive enough.

Source code in danling/runners/parallel_runner.py
Python
def build_pipeline_model_parts(self, model: nn.Module) -> list[nn.Module]:
    """
    Build every local pipeline model part owned by this pipeline rank.

    A rank with a single local stage goes through
    `build_pipeline_model_part`. A rank with several local stages builds one
    part per stage index and needs either a model-owned
    `build_pipeline_model_part` hook or `parallel.module_fqns_per_model_part`.

    Override this when a schedule maps multiple stages to each local rank
    and the default FQN/model-owned partitioning is not expressive enough.

    Args:
        model: Full model from which the local parts are taken.

    Returns:
        One module per local stage, in `pipeline_stage_indices()` order.

    Raises:
        ValueError: several local stages are requested but neither
            partitioning mechanism is available.
    """
    local_stages = self.pipeline_stage_indices()
    if len(local_stages) == 1:
        return [self.build_pipeline_model_part(model)]

    model_hook = getattr(model, "build_pipeline_model_part", None)
    fqn_partitions = self.config.parallel.get("module_fqns_per_model_part")
    if not (callable(model_hook) or fqn_partitions is not None):
        raise ValueError(
            "multiple local pipeline stages require `parallel.module_fqns_per_model_part`, "
            "`model.build_pipeline_model_part(...)`, or an override of "
            "`ParallelRunner.build_pipeline_model_parts`"
        )

    total_stages = self._pipeline_num_stages()
    local_parts = []
    for index in local_stages:
        stage_fqns = self._pipeline_module_fqns_for_stage(index)
        local_parts.append(self._build_pipeline_model_part(model, index, total_stages, stage_fqns))
    return local_parts

parallelize_model

Python
parallelize_model(model: Module) -> Module

Apply model-specific tensor/context/expert parallel transforms.

Called when: _prepare_local_model_parts materializes each local part, before compile and FSDP wrapping.

Parameters:

Name Type Description Default
model
Module

Local model part to transform.

required

Returns:

Type Description
Module

The transformed model. If the model defines model.parallelize(parallel), that method may mutate in place and return None.

Raises:

Type Description
TypeError

model.parallelize returns a non-module value.

NotImplementedError

model-parallel axes are enabled but no transform hook is available.

Do not

  • Move the model to device here; the surrounding materialize_model flow handles device placement before this hook runs.
  • Compile or FSDP-wrap here; those happen after this hook.
Source code in danling/runners/parallel_runner.py
Python
def parallelize_model(self, model: nn.Module) -> nn.Module:
    """
    Run the model-specific tensor/context/expert parallel transform.

    **Called when:** `_prepare_local_model_parts` materializes each local
    part, before compile and FSDP wrapping.

    Args:
        model: Local model part to transform.

    Returns:
        The transformed model. When the model provides
        `model.parallelize(parallel)` and that hook returns `None` (in-place
        mutation), the original `model` object is returned.

    Raises:
        TypeError: `model.parallelize` returns a non-module value.
        NotImplementedError: model-parallel axes are enabled but no
            transform hook is available.

    !!! danger "Do not"
        - Move the model to device here; the surrounding `materialize_model`
          flow handles device placement before this hook runs.
        - Compile or FSDP-wrap here; those happen after this hook.
    """
    hook = getattr(model, "parallelize", None)
    if not callable(hook):
        # Without a hook the model can only pass through untouched when no
        # model-parallel axis is active.
        if self.model_parallel_degree > 1:
            axes = ", ".join(self.model_parallel_axes)
            raise NotImplementedError(
                f"parallel axes {axes} require model-specific parallelization. "
                "Implement `model.parallelize(parallel)` or override "
                "`ParallelRunner.parallelize_model`."
            )
        return model

    result = hook(self.parallel)
    if result is None:
        return model
    if isinstance(result, nn.Module):
        return result
    raise TypeError(
        "model.parallelize(parallel) must return an nn.Module or None, "
        f"got {type(result).__name__}"
    )

apply_activation_checkpointing

Python
apply_activation_checkpointing(model: Module) -> Module

Apply activation checkpointing to one local model part.

Called when: materialize_model wraps FSDP-enabled parts, before compile/FSDP wrapping.

Parameters:

Name Type Description Default
model
Module

Local model part.

required

Returns:

Type Description
Module

Model part with activation checkpointing wrappers applied.

Side effects: default is a no-op. Overrides may mutate the module in place or return a wrapped module.

Do not

  • Change parameter ownership or shard layout here; FSDP has not wrapped the model yet.
  • Return a non-module value.
Source code in danling/runners/parallel_runner.py
Python
def apply_activation_checkpointing(self, model: nn.Module) -> nn.Module:
    """
    Hook for adding activation checkpointing to a single local model part.

    **Called when:** `materialize_model` wraps FSDP-enabled parts, before
    compile/FSDP wrapping.

    Args:
        model: Local model part.

    Returns:
        The model part with activation checkpointing wrappers applied. The
        default implementation hands `model` back unchanged.

    **Side effects:** none by default. Overrides may mutate the module in
    place or return a wrapped module.

    !!! danger "Do not"
        - Change parameter ownership or shard layout here; FSDP has not
          wrapped the model yet.
        - Return a non-module value.
    """
    # Default: no checkpointing is applied.
    return model

build_pipeline_schedule

Python
build_pipeline_schedule(
    stage_model: Module | Sequence[Module],
) -> Any

Build the PyTorch pipeline schedule for this rank.

Called when: materialize_model sees pipeline_degree > 1 and no explicit pipeline_schedule is already bound.

Parameters:

Name Type Description Default
stage_model
Module | Sequence[Module]

Local stage module for this pipeline rank, or all local stage modules for an interleaved/multi-stage schedule.

required

Returns:

Type Description
Any

A PyTorch pipeline schedule instance.

Raises:

Type Description
ValueError

pipeline microbatch count cannot be inferred or is inconsistent with batch size.

Side effects: none beyond schedule construction. The caller binds the schedule modules after compile/FSDP wrapping.

Do not

  • Set scale_grads=True; DanLing owns gradient/loss scaling.
  • Build the optimizer here.
Source code in danling/runners/parallel_runner.py
Python
def build_pipeline_schedule(self, stage_model: nn.Module | Sequence[nn.Module]) -> Any:
    """
    Construct the PyTorch pipeline schedule for this rank.

    **Called when:** `materialize_model` sees `pipeline_degree > 1` and no
    explicit `pipeline_schedule` is already bound.

    Args:
        stage_model: Local stage module for this pipeline rank, or all
            local stage modules for an interleaved/multi-stage schedule.

    Returns:
        A PyTorch pipeline schedule instance.

    Raises:
        ValueError: pipeline microbatch count cannot be inferred or is
            inconsistent with batch size, the number of stage modules does
            not match the local stage indices, or a single-stage schedule
            is given several local stages.

    **Side effects:** none beyond schedule construction. The caller binds
    the schedule modules after compile/FSDP wrapping.

    !!! danger "Do not"
        - Set `scale_grads=True`; DanLing owns gradient/loss scaling.
        - Build the optimizer here.
    """
    pipeline.check()
    configured_name = str(self.config.parallel.get("pipeline_schedule", "1F1B")).strip()
    # Default to non-interleaved 1F1B for pipeline schedules until
    # pytorch/pytorch#164756 is addressed upstream, then we can migrate the
    # default to Interleaved1F1B.
    schedule_name = configured_name if configured_name else "1F1B"
    n_microbatches = self._resolve_pipeline_n_microbatches()
    schedule_class = get_schedule_class(schedule_name)
    loss_fn = None if self.criterion is None else self._pipeline_loss

    if isinstance(stage_model, nn.Module):
        local_modules = [stage_model]
    else:
        local_modules = list(stage_model)
    num_stages = self._pipeline_num_stages()
    local_indices = self.pipeline_stage_indices(num_stages)
    if len(local_modules) != len(local_indices):
        raise ValueError(
            "pipeline stage model count must match local pipeline stage indices: "
            f"{len(local_modules)} != {len(local_indices)}"
        )

    stages = []
    for module, stage_index in zip(local_modules, local_indices):
        stages.append(
            PipelineStage(
                module,
                stage_index=stage_index,
                num_stages=num_stages,
                device=self.device,
                group=self.pipeline_group,
            )
        )

    # Identical keyword arguments for both schedule families; only the
    # first positional argument differs (list of stages vs. single stage).
    schedule_kwargs = {
        "n_microbatches": n_microbatches,
        "loss_fn": loss_fn,
        "scale_grads": False,
    }
    if issubclass(schedule_class, PipelineScheduleMulti):
        return schedule_class(stages, **schedule_kwargs)
    if len(stages) != 1:
        raise ValueError(
            f"pipeline schedule {schedule_name!r} accepts one local stage, got {len(stages)}; "
            "choose an interleaved/multi-stage schedule or override `build_pipeline_schedule`."
        )
    return schedule_class(stages[0], **schedule_kwargs)

build_datasampler

Python
build_datasampler(
    dataset: Any, *, split: str, shuffle: bool
) -> Any

Build a data-parallel sampler for one split.

Called when: inherited build_dataloaders materializes a dataset split.

Parameters:

Name Type Description Default
dataset
Any

Dataset object for the split.

required
split
str

Split name being materialized.

required
shuffle
bool

Whether to shuffle the split.

required

Returns:

Type Description
Any

DistributedSampler using topology data-parallel degree/rank, adjusted by TorchFT when active.

Source code in danling/runners/parallel_runner.py
Python
def build_datasampler(self, dataset: Any, *, split: str, shuffle: bool) -> Any:
    """
    Create the data-parallel sampler for one dataset split.

    **Called when:** inherited `build_dataloaders` materializes a dataset
    split.

    Args:
        dataset: Dataset object for the split.
        split: Split name being materialized (not used by this default
            implementation).
        shuffle: Whether to shuffle the split.

    Returns:
        `DistributedSampler` built from `self.data_degree` and
        `self.data_rank`; when `self.ft` is set, both values are first
        passed through `self.ft.data_parallel_info`.
    """
    replica_count, replica_rank = self.data_degree, self.data_rank
    fault_tolerance = self.ft
    if fault_tolerance is not None:
        replica_count, replica_rank = fault_tolerance.data_parallel_info(replica_count, replica_rank)
    sampler_class = utils.data.distributed.DistributedSampler
    return sampler_class(
        dataset,
        num_replicas=replica_count,
        rank=replica_rank,
        shuffle=shuffle,
    )

train_step

Python
train_step(data: Any) -> tuple[Any, Tensor | None]

Run one training micro-step for plain or pipeline-parallel execution.

Non-pipeline configurations delegate to TorchRunner.train_step. Pipeline configurations call the schedule, compute loss only on last stages, synchronize accumulation normalization across the pipeline, and then delegate optimizer-boundary handling to step().

Called when: train_epoch/train_steps consume one micro-batch.

Parameters:

Name Type Description Default
data
Any

Micro-batch from the local loader. Non-first/non-last pipeline stages may receive None through StepProxyLoader.

required

Returns:

Type Description
tuple[Any, Tensor | None]

(None, loss) for pipeline mode, where loss is present only on ranks that can report last-stage loss. Non-pipeline mode returns the TorchRunner result.

Do not

  • Call the optimizer directly; use step().
  • Update metrics from pipeline mode here; pipeline schedule outputs are not a normal full-batch prediction.
  • Manually divide gradients by pipeline microbatch count.
Source code in danling/runners/parallel_runner.py
Python
def train_step(self, data: Any) -> tuple[Any, torch.Tensor | None]:
    """
    Run one training micro-step for plain or pipeline-parallel execution.

    Non-pipeline configurations delegate to `TorchRunner.train_step`.
    Pipeline configurations call the schedule, compute loss only on last
    stages, synchronize accumulation normalization across the pipeline, and
    then delegate optimizer-boundary handling to `step()`.

    **Called when:** `train_epoch`/`train_steps` consume one micro-batch.

    Args:
        data: Micro-batch from the local loader. Non-first/non-last
            pipeline stages may receive `None` through `StepProxyLoader`.

    Returns:
        `(None, loss)` for pipeline mode, where `loss` is present only on
        ranks that can report last-stage loss. Non-pipeline mode returns
        the TorchRunner result.

    !!! danger "Do not"
        - Call the optimizer directly; use `step()`.
        - Update metrics from pipeline mode here; pipeline schedule outputs
          are not a normal full-batch prediction.
        - Manually divide gradients by pipeline microbatch count.
    """
    if self.pipeline_schedule is None:
        return super().train_step(data)

    with self.train_context():
        self._pipeline_loss_divisor_local = 0.0
        self._pipeline_loss_weighting = "train"
        inputs, target = self._prepare_pipeline_batch(data)
        losses: list[torch.Tensor] = []
        targets = target if self.pipeline_has_last_stage else None

        try:
            if self.pipeline_has_first_stage:
                self.pipeline_schedule.step(
                    inputs,
                    target=targets,
                    losses=losses,
                )
            else:
                self.pipeline_schedule.step(
                    target=targets,
                    losses=losses,
                )
        finally:
            self._pipeline_loss_weighting = None

        loss = self._pipeline_loss_value(losses)

        pred = None
        self._sync_pipeline_accumulation_divisor()
        self.step()
    return pred, loss

evaluate_step

Python
evaluate_step(data: Any) -> tuple[Any, Tensor | None]

Run one evaluation micro-step for plain or pipeline execution.

Non-pipeline configurations delegate to TorchRunner.evaluate_step. Pipeline configurations call the schedule in eval mode and report normalized loss from last-stage ranks.

Called when: evaluate_epoch/evaluate_steps consume one micro-batch under inference mode.

Parameters:

Name Type Description Default
data
Any

Micro-batch from the local loader. Non-first/non-last pipeline stages may receive None.

required

Returns:

Type Description
tuple[Any, Tensor | None]

(None, loss) for pipeline mode. Non-pipeline mode returns the TorchRunner result.

Do not

  • Call backward or step.
  • Assume every rank has targets; only last-stage ranks need them.
Source code in danling/runners/parallel_runner.py
Python
def evaluate_step(self, data: Any) -> tuple[Any, torch.Tensor | None]:
    """
    Run one evaluation micro-step for plain or pipeline execution.

    Non-pipeline configurations delegate to `TorchRunner.evaluate_step`.
    Pipeline configurations call the schedule in eval mode and report
    normalized loss from last-stage ranks.

    **Called when:** `evaluate_epoch`/`evaluate_steps` consume one
    micro-batch under inference mode.

    Args:
        data: Micro-batch from the local loader. Non-first/non-last
            pipeline stages may receive `None`.

    Returns:
        `(None, loss)` for pipeline mode. Non-pipeline mode returns the
        TorchRunner result.

    !!! danger "Do not"
        - Call backward or step.
        - Assume every rank has targets; only last-stage ranks need them.
    """
    if self.pipeline_schedule is None:
        return super().evaluate_step(data)

    with self.forward_context():
        self._pipeline_loss_divisor_local = 0.0
        self._pipeline_loss_weighting = "eval"
        inputs, target = self._prepare_pipeline_batch(data)
        losses: list[torch.Tensor] = []
        targets = target if self.pipeline_has_last_stage else None

        try:
            if self.pipeline_has_first_stage:
                self.pipeline_schedule.eval(
                    inputs,
                    target=targets,
                    losses=losses,
                )
            else:
                self.pipeline_schedule.eval(
                    target=targets,
                    losses=losses,
                )
        finally:
            self._pipeline_loss_weighting = None

        loss = self._pipeline_loss_value(losses)

    return None, loss

infer_step

Python
infer_step(data: Any) -> list[float]

Run one inference micro-step for plain or pipeline execution.

Non-pipeline configurations delegate to TorchRunner.infer_step. Pipeline configurations call the schedule in eval mode and normalize whatever the schedule returns into a flat list of floats.

Parameters:

Name Type Description Default
data
Any

Micro-batch on first-stage ranks; None on non-first stages that only participate in pipeline communication.

required

Returns:

Type Description
list[float]

Flat list of numeric predictions. Non-output ranks may return an empty list.

Raises:

Type Description
ValueError

pipeline output cannot be normalized into floats.

Source code in danling/runners/parallel_runner.py
Python
@torch.inference_mode()
def infer_step(self, data: Any) -> list[float]:
    """
    Execute one inference micro-step, with or without a pipeline schedule.

    When no pipeline schedule is bound, this delegates to
    `TorchRunner.infer_step`. Otherwise it calls the schedule in eval mode
    and passes whatever the schedule returns through
    `_normalize_infer_output`.

    Args:
        data: Micro-batch on first-stage ranks; `None` on non-first stages
            that only participate in pipeline communication.

    Returns:
        Flat list of numeric predictions. Non-output ranks may return an
        empty list.

    Raises:
        ValueError: pipeline output cannot be normalized into floats.
    """
    if self.pipeline_schedule is None:
        return super().infer_step(data)

    with self.forward_context():
        inputs, _unused_target = self._prepare_pipeline_batch(data)
        # Inputs are only fed to the schedule on ranks owning a first stage.
        positional = (inputs,) if self.pipeline_has_first_stage else ()
        raw_output = self.pipeline_schedule.eval(*positional)
    return self._normalize_infer_output(raw_output)

infer

Python
infer(
    split: str = "infer",
    *,
    steps: int | None = None,
    stream: bool | None = None
) -> list[float] | Iterator[list[float]]

Run inference across a pipeline-aware loader.

Non-pipeline configurations delegate to TorchRunner.infer. Pipeline configurations consume real dataloader batches only on first-stage ranks; other stages run infer_step(None) for the same number of steps.

Parameters:

Name Type Description Default
split
str

Inference split name.

'infer'
steps
int | None

Optional maximum number of batches/stage ticks.

None
stream
bool | None

Whether to return a per-batch iterator instead of a flattened list.

None

Returns:

Type Description
list[float] | Iterator[list[float]]

Flattened predictions or a streaming iterator.

Raises:

Type Description
ValueError

steps is negative, or a non-first pipeline stage has an unsized loader and no explicit step count.

Source code in danling/runners/parallel_runner.py
Python
def infer(
    self,
    split: str = "infer",
    *,
    steps: int | None = None,
    stream: bool | None = None,
) -> list[float] | Iterator[list[float]]:
    """
    Run inference across a pipeline-aware loader.

    Non-pipeline configurations delegate to `TorchRunner.infer`. Pipeline
    configurations consume real dataloader batches only on first-stage
    ranks; other stages run `infer_step(None)` for the same number of
    steps.

    Args:
        split: Inference split name.
        steps: Optional maximum number of batches/stage ticks. When the
            loader length is known, the effective step count is capped at
            that length so every pipeline stage runs the same number of
            ticks.
        stream: Whether to return a per-batch iterator instead of a
            flattened list. Defaults to streaming only when neither `steps`
            nor a loader length is available.

    Returns:
        Flattened predictions or a streaming iterator.

    Raises:
        ValueError: `steps` is negative, a first-stage rank is asked for a
            non-streaming result from an unsized loader without `steps`, or
            a non-first pipeline stage has an unsized loader and no
            explicit step count.
    """
    if self.pipeline_schedule is None:
        return super().infer(split=split, steps=steps, stream=stream)

    from itertools import islice

    self.mode = RunnerMode.infer
    self.split = split
    loader = self.dataloaders[split]

    if steps is not None and steps < 0:
        raise ValueError(f"invalid steps: expected a non-negative value, got {steps}")

    loader_length = self._loader_length(loader)
    if stream is None:
        stream = steps is None and loader_length is None

    if steps is not None and loader_length is not None:
        # A first-stage loader cannot yield more than `loader_length`
        # batches, so other stages must not wait for more ticks than that.
        steps = min(steps, loader_length)

    if self.pipeline_has_first_stage:
        if not stream and loader_length is None and steps is None:
            raise ValueError("infer with stream=False requires `steps` for unsized loaders")
        # `islice` stops pulling batches once `steps` have been consumed; a
        # filtering generator would keep draining (or never finish on) the
        # loader after the limit is reached.
        batches = loader if steps is None else islice(loader, steps)
        iterator = (self.infer_step(data) for data in batches)
        total = steps if steps is not None else loader_length
    else:
        if steps is None:
            if loader_length is None:
                raise ValueError("infer for non-first pipeline stages requires `steps` for unsized loaders")
            steps = loader_length
        # Non-first stages have no real data; they only tick the schedule.
        iterator = (self.infer_step(None) for _ in range(steps))
        total = steps

    if stream:
        return iterator

    output: list[float] = []
    for values in tqdm(iterator, total=total, disable=self.distributed and not self.is_main_process):
        output.extend(values)
    return output

load_checkpoint

Python
load_checkpoint(
    checkpoint: Mapping | bytes | str | PathLike,
    *args: Any,
    **kwargs: Any
) -> None

Restore a parallel checkpoint with topology validation.

The checkpoint is read through the active DCP manager, validated against current parallel axes, optionally remapped for allowed non-FSDP degree changes, and then restored through the TorchRunner component loaders.

Parameters:

Name Type Description Default
checkpoint
Mapping | bytes | str | PathLike

In-memory checkpoint mapping or DCP checkpoint path.

required
*args
Any

Forwarded to checkpoint reading and component loaders.

()
**kwargs
Any

Forwarded to checkpoint reading and component loaders.

{}

Raises:

Type Description
ValueError

saved topology is incompatible with the current run, or FSDP topology metadata is missing/changed.

Side effects: restores model/optimizer/scheduler/runner state and updates config.resume for path inputs.

Do not

  • Suppress topology validation for FSDP restores; shard metadata is part of the checkpoint contract.
  • Attempt degree-change restore with multiple local model parts.
Source code in danling/runners/parallel_runner.py
Python
def load_checkpoint(
    self,
    checkpoint: Mapping | bytes | str | os.PathLike,
    *args: Any,
    **kwargs: Any,
) -> None:
    """
    Restore a parallel checkpoint with topology validation.

    The checkpoint is read through the active DCP manager, validated against
    current parallel axes, optionally remapped for allowed non-FSDP degree
    changes, and then restored through the TorchRunner component loaders.

    Args:
        checkpoint: In-memory checkpoint mapping or DCP checkpoint path.
        *args: Forwarded to checkpoint reading and component loaders.
        **kwargs: Forwarded to checkpoint reading and component loaders.

    Raises:
        ValueError: saved topology is incompatible with the current run, or
            FSDP topology metadata is missing/changed.

    **Side effects:** restores model/optimizer/scheduler/runner state and
    updates `config.resume` for path inputs.

    !!! danger "Do not"
        - Suppress topology validation for FSDP restores; shard metadata is
          part of the checkpoint contract.
        - Attempt degree-change restore with multiple local model parts.
    """
    ckpt = self.read_checkpoint(checkpoint, *args, **kwargs)
    saved_topology = self._validate_checkpoint_topology(ckpt)
    if self.fsdp_enabled:
        self._validate_fsdp_checkpoint_topology(ckpt)
    current_topology = self.parallel_axes_state(dict)
    if saved_topology != current_topology:
        if len(self.model_parts) != 1:
            raise ValueError(
                "cannot restore parallel degree change: degree change restore requires DCP state-dict API "
                "with a single local model part. "
                "Either keep parallel axes unchanged, or restore with a single local model part."
            )

        ckpt = dict(ckpt)
        ckpt["parallel"] = {"axes": current_topology}
        runner_config = ckpt.get("runner")
        if isinstance(runner_config, Mapping):
            runner_payload = dict(runner_config)
            parallel_config = runner_payload.get("parallel")
            if isinstance(parallel_config, Mapping):
                updated_parallel_config = dict(parallel_config)
                axes = dict(updated_parallel_config.get("axes", {}))
                axes.update(current_topology)
                updated_parallel_config["axes"] = axes
                runner_payload["parallel"] = updated_parallel_config
                ckpt["runner"] = runner_payload

    super().load_checkpoint(ckpt, *args, **kwargs)
    if isinstance(checkpoint, (str, bytes, os.PathLike)):
        self.config.resume = os.fsdecode(checkpoint)

Runner

Bases: BaseRunner

Dynamic runner entrypoint that selects stack-specific runner classes.

Source code in danling/runners/runner.py
Python
class Runner(BaseRunner):
    """Entry-point runner that dispatches to the runner class registered for the configured stack."""

    @staticmethod
    def resolve_stack(config: Mapping[str, Any]) -> str:
        """Return the normalized stack name from `config`, reading `"stack"` with `"auto"` as fallback."""
        requested = config.get("stack", "auto")
        return normalize_stack_name(requested)

    @classmethod
    def resolve_runner_class(cls, config: Mapping[str, Any]) -> type[TorchRunner]:
        """
        Look up the runner class registered for the stack selected by `config`.

        Raises:
            ValueError: the resolved stack name is not a key of `RUNNER_REGISTRY`.
        """
        stack = cls.resolve_stack(config)
        if stack not in RUNNER_REGISTRY:
            known = ", ".join(sorted(RUNNER_REGISTRY))
            raise ValueError(f"Unknown stack: {stack!r}. Valid options are: {known}")
        return RUNNER_REGISTRY[stack]

    def __new__(cls, config):
        """
        Create a runner for the stack selected by `config`.

        Called on `Runner` itself, this returns the result of calling the resolved
        runner class with `config`. Called on a subclass, it builds a new class
        named after the subclass with bases `(subclass, resolved runner class)`
        and allocates an instance of that class.
        """
        target_cls = cls.resolve_runner_class(config)
        if cls is not Runner:
            combined_cls = type(cls.__name__, (cls, target_cls), {})
            return super().__new__(combined_cls)
        return target_cls(config)

RunnerConfig

Bases: Config

Configuration class for managing and persisting all states of a DanLing Runner.

The RunnerConfig class provides a hierarchical configuration system that handles:

  1. Parameter management: Hyperparameters, model settings, training options
  2. Experiment tracking: IDs, names, and other metadata for runs and experiments
  3. Serialization: Save/load configurations from files or command line
  4. Reproducibility: Tracking seeds and settings for reproducible runs

RunnerConfig inherits from Config and provides attribute-style access to nested values:

Python
1
2
3
4
5
6
7
8
9
config = RunnerConfig()

# Attribute-style access (recommended)
config.optim.lr = 1e-3
config.network.type = "resnet50"

# Dictionary-style access (alternative)
config["optim"]["lr"] = 1e-3
config["network"]["type"] = "resnet50"

RunnerConfig objects support three types of hierarchical attribute access patterns:

  1. Direct assignment for simple values:

    Python
    config.epochs = 10
    

  2. Auto-created nested objects for hierarchical settings:

    Python
    1
    2
    3
    # Auto-creates the nested objects
    config.optim.lr = 0.01
    config.optim.weight_decay = 1e-4
    

  3. Class-level annotations for typed properties with defaults:

    Python
    1
    2
    3
    class MyConfig(RunnerConfig):
        epochs: int = 10
        learning_rate: float = 0.001
    

Command-line integration is built-in. You can define a configuration and then override values via command line arguments:

Python
config = MyConfig()
config.parse()  # Parse CLI args, e.g., --epochs 20 --optim.lr 0.01

General:

Name Type Description
stack str

Runner stack selector used by danling.runners.Runner. Supported values: "auto", "ddp"/"torch", "graph", "deepspeed"/"ds", "parallel". Defaults to "auto" (resolved to "ddp" at runtime).

Reproducibility:

Name Type Description
seed int

Random seed for reproducibility. If not set, a random value is generated.

deterministic bool

Whether to enforce deterministic operations in PyTorch. Defaults to False for better performance. Set to True for exact reproducibility.

Progress:

Name Type Description
steps int | None

Final global step target for training. In step mode, training stops when global_step >= steps.

epochs int | None

Final epoch index boundary for training. In epoch mode, training iterates epochs until epoch == epochs.

Model Evaluation:

Name Type Description
score_split str

Dataset split to use for model selection. Defaults to None. If unset, the runner infers the split once (preferring val, then validate, then the first available split) and reuses it unless that split disappears from the results.

score_name str

Metric name to use for model selection. Defaults to “loss”.

scheduler.interval str

Scheduler advancement policy. Supported values: "step" and "epoch"/"validation". Non-metric schedulers default to "step". Metric schedulers such as ReduceLROnPlateau default to "epoch" and advance after the aggregated round result is available.

scheduler.monitor str

Optional metric selector for metric schedulers. Supports dotted paths such as "val.loss". When unset, the runner prefers score_split/score_name when available and otherwise resolves score_name from the aggregated result.

Optimization:

Name Type Description
optim.param_groups list[dict] | None

Optional regex-based optimizer parameter groups. Each entry requires pattern, matched against TorchRunner.iter_optimizer_named_parameters() with re.search semantics, and may provide optimizer group options directly. Anchor patterns with ^/$ when a full FQN position matters. lr_multiplier, weight_decay_multiplier, beta1, and beta2 derive group values from top-level optim.lr, optim.weight_decay, and optim.betas. Unmatched parameters keep the optimizer-level defaults.

I/O:

Name Type Description
workspace_root str

Root directory for experiments. Defaults to "experiments".

auto_resume bool

Auto-resume from backend latest checkpoint alias/path. When True, runner resolves the backend-native latest checkpoint source. Priority is resume > auto_resume > pretrained.

resume str | None

Optional full-state checkpoint source for resume workflows. This is a path-like identifier consumed by runner load_checkpoint(...).

pretrained str | None

Optional model-only checkpoint source for finetune workflows. This is a path-like identifier consumed by runner load_pretrained(...).

lineage str

Top-level lineage namespace. Defaults to "lin" when unset. RunnerWorkspace.dir appends code identity (-<git_hash>) when available.

experiment str

Experiment namespace. Defaults to "exp".

checkpoint.dir_name str

Subdirectory name for checkpoints. Defaults to "checkpoints".

checkpoint.async_enabled bool

Whether to persist checkpoints asynchronously. Defaults to True.

checkpoint.async_mode str | None

Checkpoint async behavior. Supported values: "disabled", "async", "async_with_pinned_mem". When unset (None), the runner derives the mode from checkpoint.async_enabled.

checkpoint.dedicated_async_process_group bool

Use a dedicated process group for async DCP checkpoint I/O to reduce interference with training collectives. Defaults to True.

checkpoint.async_process_group_backend str

Backend for the dedicated async checkpoint process group. Defaults to "gloo".

checkpoint.backend str

Checkpoint backend selected at runtime by the runner ("dcp" for distributed runs, "file" otherwise when set to "auto").

checkpoint.wait_timeout float

Timeout in seconds when draining async checkpoint writes during runner shutdown (None waits indefinitely).

parallel.axes.replicate int

Data-replication degree for DDP/HSDP-style replication. Defaults to 1.

parallel.axes.shard int

Data-sharding degree for FSDP-style sharding. Defaults to 1. Set one parallel axis, commonly shard, to -1 to auto-fill it from WORLD_SIZE and the other configured axes.

parallel.axes.context int

Context/sequence parallel degree. Defaults to 1.

parallel.axes.pipeline int

Pipeline-parallel degree. Defaults to 1.

parallel.axes.tensor int

Tensor-parallel degree. Defaults to 1.

parallel.axes.expert int

Expert-parallel degree for MoE models. Defaults to 1.

parallel.axes.expert_tensor int

Expert tensor-parallel degree for MoE models. Defaults to 1.

parallel.pipeline_schedule str

Pipeline schedule class name resolved by torch.distributed.pipelining.schedules.get_schedule_class. Defaults to "1F1B".

parallel.pipeline_microbatch_size int

Local microbatch size used to infer schedule microbatch count as dataloader.batch_size // pipeline_microbatch_size. Defaults to 1.

parallel.pipeline_n_microbatches int

Explicit schedule microbatch count. When set, overrides pipeline_microbatch_size-based inference.

parallel.module_fqns_per_model_part list[list[str]] | None

Optional module FQNs for simple pipeline stage extraction. The outer list length is the total pipeline stage count and must be divisible by parallel.axes.pipeline; complex partitioning should use model.build_pipeline_model_part(...) or override ParallelRunner.build_pipeline_model_part / ParallelRunner.build_pipeline_model_parts.

log bool

Whether to enable file logging. Defaults to True. Logging is initialized on the main process only.

tensorboard bool

Whether to use TensorBoard for visualization. Defaults to False.

wandb.enabled bool

Whether to enable Weights & Biases scalar logging. Defaults to False.

wandb.project str | None

Optional W&B project name. Defaults to lineage.

wandb.entity str | None

Optional W&B entity/team override.

wandb.group str | None

Optional W&B group name. Defaults to experiment.

wandb.name str | None

Optional W&B display name. Defaults to stable runner id.

wandb.job_type str | None

Optional W&B job type.

wandb.tags list[str] | str | None

Optional W&B run tags.

wandb.dir str | None

Optional local W&B run directory. Defaults to run dir.

wandb.mode str | None

Optional W&B mode such as "online" or "offline".

ft.enabled bool

Enable TorchFT-managed fault tolerance. Defaults to False.

ft.process_group str

TorchFT coordination backend. Supported values: "gloo" and "nccl". Defaults to "gloo".

ft.process_group_timeout_ms int

TorchFT process-group timeout in milliseconds. Defaults to 10000.

ft.replica_id int

Replica-group identifier for this run. Defaults to 0.

ft.group_size int

Number of replica groups participating in TorchFT. Defaults to 1.

ft.min_replica_size int

Minimum healthy replicas required by TorchFT per step. Defaults to 1.

log_interval int

Iterations between log outputs. If None, auto-calculated.

checkpoint.interval int

Interval between checkpoint save attempts for latest/best. The same cadence is used for history checkpoints. The interval is measured in epochs in epoch mode and in global steps in step mode. If unset, the runner's default for the active mode is used.

checkpoint.keep_latest_k int

Number of framework-generated history checkpoints to retain. 0 disables retention pruning.

checkpoint.load_only bool

Disable checkpoint persistence entirely while still allowing checkpoint loading.

checkpoint.enable_ft_dataloader_checkpoints bool

Enable per-replica dataloader checkpoints for FT recovery. Uses DCP and stores checkpoints under checkpoint.ft_dataloader_checkpoint_prefix-{checkpoint.ft_replica_id}.

checkpoint.ft_replica_id str | None

Replica identifier used for FT dataloader checkpoint directory naming. Defaults to FT_REPLICA_ID environment variable, then process rank.

checkpoint.ft_dataloader_checkpoint_prefix str

Prefix used for FT per-replica checkpoint directories. Defaults to "ft-replica".

checkpoint.exclude_from_loading list[str] | str | None

Checkpoint keys to skip during load_checkpoint, such as "optimizer", "scheduler", "dataloaders", or dotted nested keys. The aliases "data_loader", "dataloader", and "lr_scheduler" are accepted.

checkpoint.last_save_model_only bool

Save model-only payload on final last_step checkpoint.

checkpoint.export_dtype str

Optional dtype cast for final model-only export (fp32/fp16/bf16/fp64 aliases supported).

dataloader.batch_size int | None

Local dataloader batch size passed to StatefulDataLoader.

dataloader.shuffle bool | None

Optional shuffle override. When unset, train splits shuffle and non-train splits do not.

dataloader.drop_last bool | None

Optional drop-last override. When unset, train splits drop incomplete batches and non-train splits keep them.

dataloader.num_workers / persistent_workers / prefetch_factor / pin_memory

Standard PyTorch DataLoader kwargs forwarded to StatefulDataLoader.

dataloader.in_order bool

PyTorch DataLoader ordering flag.

dataloader.snapshot_every_n_steps int | None

StatefulDataLoader snapshot cadence.

dataloader.<split> dict

Split-specific overrides merged on top of default dataloader kwargs, for example dataloader.train.shuffle=False.

fsdp.enabled bool

Enable FSDP2 wrapping in ParallelRunner. The FSDP mesh is derived from parallel.axes.replicate, parallel.axes.shard, and later parallel.axes.context.

fsdp.reshard_after_forward bool | int | None

Optional FSDP2 reshard policy.

fsdp.mp_policy bool | int | None

Optional FSDP2 mixed precision policy.

fsdp.offload_policy bool | int | None

Optional FSDP2 CPU offload policy.

compile.enable bool

Whether to enable torch.compile for runner-selected model compilation points.

compile.backend str

Optional backend passed to torch.compile.

compile.fullgraph bool

Optional fullgraph flag for torch.compile.

compile.dynamic bool

Optional dynamic flag for torch.compile.

compile.mode str

Optional mode passed to torch.compile.

compile.options dict

Optional options passed to torch.compile.

compile.optimize_ddp str | None

Optional torch._dynamo.config.optimize_ddp value. Defaults to "ddp_optimizer" when model compile is enabled.

compile.precompile_artifact_dir str | None

Optional directory for GraphRunner torch compiler cache artifacts. Current eager runners ignore this setting.

compile.memory_policy str | None

Optional graph-memory policy label for experimental graph paths. GraphRunner currently accepts None/"default"; activation remat/offload policies require a dedicated graph pass pipeline.

comm.init_timeout_seconds int | None

Optional distributed process-group timeout used during initialization and early startup.

comm.train_timeout_seconds int | None

Optional tighter distributed process-group timeout applied once after the first successful optimizer step.

gc.interval int | None

Optional periodic Python GC cadence. When unset, runner-managed GC pacing is disabled.

gc.generation int

Python GC generation passed to gc.collect(...) when pacing is enabled. Defaults to 1.

gc.disable_automatic bool

Disable CPython automatic GC while runner-managed pacing is enabled. Defaults to True.

profiling.enabled bool

Enable bounded-step torch.profiler tracing. Defaults to False.

profiling.wait int

Profiler schedule wait steps before warmup. Defaults to 1.

profiling.warmup int

Profiler schedule warmup steps. Defaults to 1.

profiling.active int

Profiler schedule active trace steps. Defaults to 3.

profiling.repeat int | None

Optional profiler schedule repeat count.

profiling.record_shapes bool

Enable shape recording in traces. Defaults to False.

profiling.profile_memory bool

Enable profiler-side memory recording. Defaults to False.

profiling.with_stack bool

Include Python stack traces in profiler output. Defaults to False.

profiling.with_flops bool

Enable profiler FLOPs estimation when available. Defaults to False.

profiling.trace_dir str

Relative or absolute trace output directory. Defaults to "profiles".

heartbeat.enabled bool

Enable a machine-readable per-rank heartbeat/progress file. Defaults to False.

heartbeat.interval_seconds float

Heartbeat write interval in seconds. Defaults to 60.0.

heartbeat.dir_name str

Subdirectory under the run dir for heartbeat files. Defaults to "heartbeats".

Text Only
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Use in a runner
runner = Runner(config)
```

Custom config class with typed attributes:
```python
class TrainingConfig(RunnerConfig):
    # Type annotations provide auto-completion and validation
    epochs: int = 100
    batch_size: int = 32
    precision: str = "fp16"

    def __init__(self):
        super().__init__()
        # Initialize nested settings
        self.optim.type = "adamw"
        self.optim.lr = 1e-3

    def post(self):
        # Called after parsing CLI args
        super().post()
        # Create derived settings
        self.experiment = f"{self.network.type}_{self.optim.lr}"
```

Command-line integration:
```bash
# Override config settings via CLI
python train.py --epochs 50 --optim.lr 0.0005 --network.type resnet50
```
Note

Always store all parameters needed to reproduce a run in the RunnerConfig. The RunnerConfig is automatically saved with checkpoints, enabling exact resumption.

See Also
Source code in danling/runners/config.py
Python
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
class RunnerConfig(chanfig.Config):  # pylint: disable=too-many-instance-attributes
    r"""
    Configuration class for managing and persisting all states of a DanLing Runner.

    The RunnerConfig class provides a hierarchical configuration system that handles:

    1. **Parameter management**: Hyperparameters, model settings, training options
    2. **Experiment tracking**: IDs, names, and other metadata for runs and experiments
    3. **Serialization**: Save/load configurations from files or command line
    4. **Reproducibility**: Tracking seeds and settings for reproducible runs

    RunnerConfig inherits from [`Config`][chanfig.Config] and provides attribute-style access to nested values:

    ```python
    config = RunnerConfig()

    # Attribute-style access (recommended)
    config.optim.lr = 1e-3
    config.network.type = "resnet50"

    # Dictionary-style access (alternative)
    config["optim"]["lr"] = 1e-3
    config["network"]["type"] = "resnet50"
    ```

    RunnerConfig objects support three types of hierarchical attribute access patterns:

    1. **Direct assignment** for simple values:
       ```python
       config.epochs = 10
       ```

    2. **Auto-created nested objects** for hierarchical settings:
       ```python
       # Auto-creates the nested objects
       config.optim.lr = 0.01
       config.optim.weight_decay = 1e-4
       ```

    3. **Class-level annotations** for typed properties with defaults:
       ```python
       class MyConfig(RunnerConfig):
           epochs: int = 10
           learning_rate: float = 0.001
       ```

    Command-line integration is built-in. You can define a configuration and
    then override values via command line arguments:

    ```python
    config = MyConfig()
    config.parse()  # Parse CLI args, e.g., --epochs 20 --optim.lr 0.01
    ```

    Attributes: General:
        stack (str): Runner stack selector used by `danling.runners.Runner`.
            Supported values: `"auto"`, `"ddp"`/`"torch"`, `"graph"`,
            `"deepspeed"`/`"ds"`, `"parallel"`.
            Defaults to `"auto"` (resolved to `"ddp"` at runtime).

    Attributes: Reproducibility:
        seed (int): Random seed for reproducibility. If not set, a random value is generated.
        deterministic (bool): Whether to enforce deterministic operations in PyTorch.
            Defaults to `False` for better performance. Set to `True` for exact reproducibility.

    Attributes: Progress:
        steps (int | None): Final global step target for training.
            In step mode, training stops when `global_step >= steps`.
        epochs (int | None): Final epoch index boundary for training.
            In epoch mode, training iterates epochs until `epoch == epochs`.

    Attributes: Model Evaluation:
        score_split (str): Dataset split to use for model selection. Defaults to None.
            If unset, runner infers once (`val` -> `validate` -> first available) and reuses it
            unless that split disappears from results.
        score_name (str): Metric name to use for model selection. Defaults to "loss".
        scheduler.interval (str): Scheduler advancement policy.
            Supported values: `"step"` and `"epoch"`/`"validation"`.
            Non-metric schedulers default to `"step"`. Metric schedulers such as
            `ReduceLROnPlateau` default to `"epoch"` and advance after the aggregated
            round result is available.
        scheduler.monitor (str): Optional metric selector for metric schedulers.
            Supports dotted paths such as `"val.loss"`.
            When unset, the runner prefers `score_split/score_name` when available and
            otherwise resolves `score_name` from the aggregated result.

    Attributes: Optimization:
        optim.param_groups (list[dict] | None): Optional regex-based optimizer
            parameter groups. Each entry requires `pattern`, matched against
            `TorchRunner.iter_optimizer_named_parameters()` with `re.search`
            semantics, and may provide optimizer group options directly. Anchor
            patterns with `^`/`$` when a full FQN position matters.
            `lr_multiplier`,
            `weight_decay_multiplier`, `beta1`, and `beta2` derive group values
            from top-level `optim.lr`, `optim.weight_decay`, and `optim.betas`.
            Unmatched parameters keep the optimizer-level defaults.

    Attributes: I/O:
        workspace_root (str): Root directory for experiments. Defaults to `"experiments"`.
        auto_resume (bool): Auto-resume from backend latest checkpoint alias/path.
            When `True`, runner resolves the backend-native latest checkpoint source.
            Priority is `resume` > `auto_resume` > `pretrained`.
        resume (str | None): Optional full-state checkpoint source for resume workflows.
            This is a path-like identifier consumed by runner `load_checkpoint(...)`.
        pretrained (str | None): Optional model-only checkpoint source for finetune workflows.
            This is a path-like identifier consumed by runner `load_pretrained(...)`.
        lineage (str): Top-level lineage namespace.
            Defaults to `"lin"` when unset.
            `RunnerWorkspace.dir` appends code identity (`-<git_hash>`) when available.
        experiment (str): Experiment namespace. Defaults to `"exp"`.
        checkpoint.dir_name (str): Subdirectory name for checkpoints. Defaults to `"checkpoints"`.
        checkpoint.async_enabled (bool): Whether to persist checkpoints asynchronously.
            Defaults to `True`.
        checkpoint.async_mode (str | None): Checkpoint async behavior.
            Supported values: `"disabled"`, `"async"`, `"async_with_pinned_mem"`.
            When unset (`None`), the runner derives the mode from `checkpoint.async_enabled`.
        checkpoint.dedicated_async_process_group (bool): Use a dedicated process group for async DCP
            checkpoint I/O to reduce interference with training collectives. Defaults to `True`.
        checkpoint.async_process_group_backend (str): Backend for the dedicated async checkpoint process
            group. Defaults to `"gloo"`.
        checkpoint.backend (str): Checkpoint backend selected at runtime by the runner
            (`"dcp"` for distributed runs, `"file"` otherwise when set to `"auto"`).
        checkpoint.wait_timeout (float): Timeout in seconds when draining async checkpoint writes
            during runner shutdown (`None` waits indefinitely).
        parallel.axes.replicate (int): Data-replication degree for DDP/HSDP-style replication.
            Defaults to `1`.
        parallel.axes.shard (int): Data-sharding degree for FSDP-style sharding.
            Defaults to `1`. Set one parallel axis, commonly `shard`, to `-1`
            to auto-fill it from `WORLD_SIZE` and the other configured axes.
        parallel.axes.context (int): Context/sequence parallel degree. Defaults to `1`.
        parallel.axes.pipeline (int): Pipeline-parallel degree. Defaults to `1`.
        parallel.axes.tensor (int): Tensor-parallel degree. Defaults to `1`.
        parallel.axes.expert (int): Expert-parallel degree for MoE models. Defaults to `1`.
        parallel.axes.expert_tensor (int): Expert tensor-parallel degree for MoE models. Defaults to `1`.
        parallel.pipeline_schedule (str): Pipeline schedule class name resolved by
            `torch.distributed.pipelining.schedules.get_schedule_class`.
            Defaults to `"1F1B"`.
        parallel.pipeline_microbatch_size (int): Local microbatch size used to infer
            schedule microbatch count as `dataloader.batch_size // pipeline_microbatch_size`.
            Defaults to `1`.
        parallel.pipeline_n_microbatches (int): Explicit schedule microbatch count.
            When set, overrides `pipeline_microbatch_size`-based inference.
        parallel.module_fqns_per_model_part (list[list[str]] | None): Optional
            module FQNs for simple pipeline stage extraction. The outer list
            length is the total pipeline stage count and must be divisible by
            `parallel.axes.pipeline`; complex partitioning should use
            `model.build_pipeline_model_part(...)` or override
            `ParallelRunner.build_pipeline_model_part` /
            `ParallelRunner.build_pipeline_model_parts`.
        log (bool): Whether to enable file logging. Defaults to `True`.
            Logging is initialized on the main process only.
        tensorboard (bool): Whether to use TensorBoard for visualization. Defaults to `False`.
        wandb.enabled (bool): Whether to enable Weights & Biases scalar logging. Defaults to `False`.
        wandb.project (str | None): Optional W&B project name. Defaults to `lineage`.
        wandb.entity (str | None): Optional W&B entity/team override.
        wandb.group (str | None): Optional W&B group name. Defaults to `experiment`.
        wandb.name (str | None): Optional W&B display name. Defaults to stable runner `id`.
        wandb.job_type (str | None): Optional W&B job type.
        wandb.tags (list[str] | str | None): Optional W&B run tags.
        wandb.dir (str | None): Optional local W&B run directory. Defaults to run dir.
        wandb.mode (str | None): Optional W&B mode such as `"online"` or `"offline"`.
        ft.enabled (bool): Enable TorchFT-managed fault tolerance. Defaults to `False`.
        ft.process_group (str): TorchFT coordination backend. Supported values: `"gloo"` and `"nccl"`.
            Defaults to `"gloo"`.
        ft.process_group_timeout_ms (int): TorchFT process-group timeout in milliseconds.
            Defaults to `10000`.
        ft.replica_id (int): Replica-group identifier for this run. Defaults to `0`.
        ft.group_size (int): Number of replica groups participating in TorchFT. Defaults to `1`.
        ft.min_replica_size (int): Minimum healthy replicas required by TorchFT per step.
            Defaults to `1`.
        log_interval (int): Iterations between log outputs. If None, auto-calculated.
        checkpoint.interval (int): Interval between checkpoint save attempts for `latest`/`best`.
            The same cadence is used for history checkpoints.
            Uses epochs in epoch mode and global steps in step mode.
            If unset, runner defaults are used by mode.
        checkpoint.keep_latest_k (int): Number of framework-generated history checkpoints to retain.
            `0` disables retention pruning.
        checkpoint.load_only (bool): Disable checkpoint persistence entirely while still allowing checkpoint loading.
        checkpoint.enable_ft_dataloader_checkpoints (bool): Enable per-replica dataloader checkpoints for FT recovery.
            Uses DCP and stores checkpoints under
            `checkpoint.ft_dataloader_checkpoint_prefix-{checkpoint.ft_replica_id}`.
        checkpoint.ft_replica_id (str | None): Replica identifier used for FT dataloader checkpoint directory naming.
            Defaults to `FT_REPLICA_ID` environment variable, then process rank.
        checkpoint.ft_dataloader_checkpoint_prefix (str): Prefix used for FT per-replica checkpoint directories.
            Defaults to `"ft-replica"`.
        checkpoint.exclude_from_loading (list[str] | str | None): Checkpoint keys to skip during
            `load_checkpoint`, such as `"optimizer"`, `"scheduler"`, `"dataloaders"`, or dotted nested keys.
            The aliases `"data_loader"`, `"dataloader"`, and `"lr_scheduler"` are accepted.
        checkpoint.last_save_model_only (bool): Save model-only payload on final `last_step` checkpoint.
        checkpoint.export_dtype (str): Optional dtype cast for final model-only export
            (`fp32`/`fp16`/`bf16`/`fp64` aliases supported).
        dataloader.batch_size (int | None): Local dataloader batch size passed to
            `StatefulDataLoader`.
        dataloader.shuffle (bool | None): Optional shuffle override. When unset, train
            splits shuffle and non-train splits do not.
        dataloader.drop_last (bool | None): Optional drop-last override. When unset,
            train splits drop incomplete batches and non-train splits keep them.
        dataloader.num_workers / persistent_workers / prefetch_factor / pin_memory:
            Standard PyTorch DataLoader kwargs forwarded to `StatefulDataLoader`.
        dataloader.in_order (bool): PyTorch DataLoader ordering flag.
        dataloader.snapshot_every_n_steps (int | None): StatefulDataLoader snapshot cadence.
        dataloader.<split> (dict): Split-specific overrides merged on top of default
            dataloader kwargs, for example `dataloader.train.shuffle=False`.
        fsdp.enabled (bool): Enable FSDP2 wrapping in `ParallelRunner`.
            The FSDP mesh is derived from `parallel.axes.replicate`,
            `parallel.axes.shard`, and later `parallel.axes.context`.
        fsdp.reshard_after_forward (bool | int | None): Optional FSDP2 reshard policy.
        fsdp.mp_policy: Optional FSDP2 mixed precision policy.
        fsdp.offload_policy: Optional FSDP2 CPU offload policy.
        compile.enable (bool): Whether to enable `torch.compile` for runner-selected model compilation points.
        compile.backend (str): Optional backend passed to `torch.compile`.
        compile.fullgraph (bool): Optional `fullgraph` flag for `torch.compile`.
        compile.dynamic (bool): Optional `dynamic` flag for `torch.compile`.
        compile.mode (str): Optional mode passed to `torch.compile`.
        compile.options (dict): Optional options passed to `torch.compile`.
        compile.optimize_ddp (str | None): Optional `torch._dynamo.config.optimize_ddp` value.
            Defaults to `"ddp_optimizer"` when model compile is enabled.
        compile.precompile_artifact_dir (str | None): Optional directory for GraphRunner torch compiler
            cache artifacts. Current eager runners ignore this setting.
        compile.memory_policy (str | None): Optional graph-memory policy label for experimental graph paths.
            GraphRunner currently accepts `None`/`"default"`; activation remat/offload policies require a
            dedicated graph pass pipeline.
        comm.init_timeout_seconds (int | None): Optional distributed process-group timeout used during
            initialization and early startup.
        comm.train_timeout_seconds (int | None): Optional tighter distributed process-group timeout applied
            once after the first successful optimizer step.
        gc.interval (int | None): Optional periodic Python GC cadence.
            When unset, runner-managed GC pacing is disabled.
        gc.generation (int): Python GC generation passed to `gc.collect(...)` when pacing is enabled.
            Defaults to `1`.
        gc.disable_automatic (bool): Disable CPython automatic GC while runner-managed pacing is enabled.
            Defaults to `True`.
        profiling.enabled (bool): Enable bounded-step `torch.profiler` tracing. Defaults to `False`.
        profiling.wait (int): Profiler schedule wait steps before warmup. Defaults to `1`.
        profiling.warmup (int): Profiler schedule warmup steps. Defaults to `1`.
        profiling.active (int): Profiler schedule active trace steps. Defaults to `3`.
        profiling.repeat (int | None): Optional profiler schedule repeat count.
        profiling.record_shapes (bool): Enable shape recording in traces. Defaults to `False`.
        profiling.profile_memory (bool): Enable profiler-side memory recording. Defaults to `False`.
        profiling.with_stack (bool): Include Python stack traces in profiler output. Defaults to `False`.
        profiling.with_flops (bool): Enable profiler FLOPs estimation when available. Defaults to `False`.
        profiling.trace_dir (str): Relative or absolute trace output directory. Defaults to `"profiles"`.
        heartbeat.enabled (bool): Enable a machine-readable per-rank heartbeat/progress file. Defaults to `False`.
        heartbeat.interval_seconds (float): Heartbeat write interval in seconds. Defaults to `60.0`.
        heartbeat.dir_name (str): Subdirectory under the run dir for heartbeat files. Defaults to `"heartbeats"`.
    Examples:
        Basic usage:
        ```python
        # Create a config
        config = RunnerConfig()
        config.network.type = "resnet18"
        config.optim.lr = 0.001
        config.epochs = 10

        # Use in a runner
        runner = Runner(config)
        ```

        Custom config class with typed attributes:
        ```python
        class TrainingConfig(RunnerConfig):
            # Type annotations provide auto-completion and validation
            epochs: int = 100
            batch_size: int = 32
            precision: str = "fp16"

            def __init__(self):
                super().__init__()
                # Initialize nested settings
                self.optim.type = "adamw"
                self.optim.lr = 1e-3

            def post(self):
                # Called after parsing CLI args
                super().post()
                # Create derived settings
                self.experiment = f"{self.network.type}_{self.optim.lr}"
        ```

        Command-line integration:
        ```bash
        # Override config settings via CLI
        python train.py --epochs 50 --optim.lr 0.0005 --network.type resnet50
        ```

    Note:
        Always store all parameters needed to reproduce a run in the RunnerConfig.
        The RunnerConfig is automatically saved with checkpoints, enabling exact resumption.

    See Also:
        - [`Runner`][danling.runners.Runner]: Main runner class that uses this config.
        - [`chanfig.Config`](https://github.com/ultmaster/chanfig): Base config implementation.
    """

    # DO NOT set default value in class, as they won't be stored in `__dict__`.
    # NOTE(review): the annotated attributes below do carry class-level defaults,
    # which looks at odds with the warning above; confirm whether it still applies.

    # Runner stack selector.
    stack: str = "auto"

    # Reproducibility.
    seed: Optional[int] = None
    deterministic: bool = False

    # Training boundaries; `validate` rejects setting both `steps` and `epochs`.
    steps: Optional[int] = None
    epochs: Optional[int] = None
    accum_steps: int = 1

    # Model selection.
    score_split: Optional[str] = None
    score_name: str = "loss"

    # I/O, logging and resume sources.
    workspace_root: str = "experiments"
    auto_resume: bool = False
    resume: Optional[str] = None
    pretrained: Optional[str] = None
    log: bool = True
    tensorboard: bool = False
    wandb: WandbConfig
    ft: FaultToleranceConfig
    log_interval: Optional[int] = None

    # Nested sections; most are given a default instance in `__post_init__`.
    compile: CompileConfig
    comm: CommConfig
    gc: GcConfig
    profiling: ProfilingConfig
    heartbeat: HeartbeatConfig
    checkpoint: CheckpointConfig
    dataloader: DataloaderConfig
    fsdp: FsdpConfig
    parallel: ParallelConfig

    def __post_init__(self, *args, **kwargs) -> None:
        """
        Validate the config and fill in nested sections that are still missing.

        All positional and keyword arguments are forwarded untouched to the parent
        `__post_init__`. After validation, each absent nested section receives a
        default-constructed instance. An existing `parallel` value that is not a
        `ParallelConfig` is wrapped in one.
        """
        super().__post_init__(*args, **kwargs)
        self.validate()

        # Checked and assigned in this exact order.
        section_factories = (
            ("compile", CompileConfig),
            ("comm", CommConfig),
            ("gc", GcConfig),
            ("profiling", ProfilingConfig),
            ("heartbeat", HeartbeatConfig),
            ("wandb", WandbConfig),
            ("ft", FaultToleranceConfig),
            ("checkpoint", CheckpointConfig),
        )
        for section, factory in section_factories:
            if section not in self:
                setattr(self, section, factory())

        # `parallel` is the only section that is also coerced when already present.
        if "parallel" not in self:
            self.parallel = ParallelConfig()
            return
        if not isinstance(self.parallel, ParallelConfig):
            self.parallel = ParallelConfig(self.parallel)

    def post(self) -> None:
        """Run the parent `post` hook, then validate the config again."""
        super().post()
        self.validate()

    def validate(self) -> None:
        """
        Check that at most one training boundary is configured.

        Raises:
            ValueError: If both `steps` and `epochs` are set (neither is `None`).
        """
        if self.steps is None or self.epochs is None:
            return
        raise ValueError("`steps` and `epochs` are mutually exclusive; set only one training boundary")

    def canonical(self) -> chanfig.NestedDict:
        """
        Build a reduced copy of this config containing only its semantic content.

        The copy has a normalized `stack` name and drops every key listed in
        `NON_SEMANTIC_CONFIG_KEYS`. A mapping-valued `checkpoint` section gets a
        stripped, lower-cased `backend`, loses the keys in
        `NON_SEMANTIC_CHECKPOINT_KEYS`, and is removed entirely if nothing remains.
        The `fsdp` and `parallel` sections are dropped unless the normalized stack
        is `"parallel"`.

        Returns:
            A new `chanfig.NestedDict`; the config itself is not modified here.
        """
        snapshot = chanfig.NestedDict(self.dict())
        stack = normalize_stack_name(snapshot.get("stack", "auto"))
        snapshot["stack"] = stack
        for key in NON_SEMANTIC_CONFIG_KEYS:
            snapshot.pop(key, None)

        checkpoint = snapshot.get("checkpoint")
        if isinstance(checkpoint, Mapping):
            trimmed = chanfig.NestedDict(checkpoint)
            backend = trimmed.get("backend")
            if backend is not None:
                trimmed["backend"] = str(backend).strip().lower()
            for key in NON_SEMANTIC_CHECKPOINT_KEYS:
                trimmed.pop(key, None)
            if not trimmed:
                snapshot.pop("checkpoint", None)
            else:
                snapshot["checkpoint"] = trimmed

        if stack == "parallel":
            return snapshot
        for section in ("fsdp", "parallel"):
            snapshot.pop(section, None)
        return snapshot

    def __hash__(self) -> int:
        """
        Hash the canonical form of the config.

        The `yamls()` dump of `canonical()` is UTF-8 encoded and digested with
        SHA-1; the first 8 digest bytes are read as an unsigned big-endian integer.
        """
        payload = self.canonical().yamls().encode("utf-8")
        prefix = hashlib.sha1(payload).digest()[:8]
        return int.from_bytes(prefix, byteorder="big", signed=False)

RunnerState dataclass

Bases: _StatefulBase

Checkpointable state container for a runner instance.

Attributes:

- config (RunnerConfig): Runner configuration associated with this state object.
- train (RunnerTrainState): Training progress counters.
- elastic (RunnerElasticState): Torchelastic restart metadata.
- rng (RunnerRNGState): Python/NumPy/Torch RNG snapshots.

Source code in danling/runners/state.py
Python
@dataclass
class RunnerState(_StatefulBase):
    """
    Checkpointable state container for a runner instance.

    Attributes:
        config: Runner configuration associated with this state object.
        train: Training progress counters.
        elastic: Torchelastic restart metadata.
        rng: Python/NumPy/Torch RNG snapshots.
    """

    config: RunnerConfig
    train: RunnerTrainState = field(default_factory=RunnerTrainState)
    elastic: RunnerElasticState = field(default_factory=RunnerElasticState)
    rng: RunnerRNGState = field(default_factory=RunnerRNGState)

    def __post_init__(self) -> None:
        """Wrap `config` in a `RunnerConfig` when it was passed as another type."""
        if isinstance(self.config, RunnerConfig):
            return
        self.config = RunnerConfig(self.config)

    def state_dict(self) -> dict[str, Any]:
        """
        Collect the state dicts of the `train`, `elastic` and `rng` sub-states.

        Returns:
            A dict keyed by sub-state name. `config` is not part of the payload.
        """
        return {name: getattr(self, name).state_dict() for name in ("train", "elastic", "rng")}

    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
        """
        Restore the `train`, `elastic` and `rng` sub-states from `state_dict`.

        Entries that are missing or are not mappings are skipped silently, leaving
        the corresponding sub-state untouched.

        Args:
            state_dict: Mapping keyed by sub-state name, each value being that
                sub-state's own state dict.
        """
        for name in ("train", "elastic", "rng"):
            payload = state_dict.get(name)
            if not isinstance(payload, Mapping):
                continue
            getattr(self, name).load_state_dict(payload)

TorchRunner

Bases: Fp8Mixin, BaseRunner

PyTorch-native runner for training, evaluation, and inference.

Use this runner for single-model PyTorch training with optional DDP, autocast/FP8, torch.compile, stateful dataloaders, metric logging, and file or torch.distributed.checkpoint persistence.

Users must provide self.model before construction completes. Most training tasks also provide self.criterion, and either self.optimizer or config.optim. Datasets may be supplied through self.datasets and will be materialized into StatefulDataLoader instances during __post_init__.

The default batch contract is intentionally simple: mappings use input/target, sequences use index 0/1, and any other value is treated as model input with no target. Override train_step, evaluate_step, or infer_step when a task needs a different contract.

Attributes:

- model (Module): Local model module after materialization (possibly DDP-wrapped).
- ema (Module | None): Optional EMA/evaluation model.
- criterion (Callable | None): Loss callable used by default train/evaluate steps.
- optimizer (Optimizer | None): Optimizer used by the runner or backend engine.
- scheduler (Any | None): Optional LR scheduler.
- optimizer_container (OptimizerContainer | None): Helper that owns optimizer step, clipping, non-finite checks, and step-scheduler dispatch.
- compiler (Compiler): torch.compile policy object.
- scheduler_interval (str): Effective scheduler interval ("step" or epoch/metric-style interval).
- scheduler_monitor (str | None): Optional metric path used for metric schedulers.

Source code in danling/runners/torch_runner.py
Python
  68
  69
  70
  71
  72
  73
  74
  75
  76
  77
  78
  79
  80
  81
  82
  83
  84
  85
  86
  87
  88
  89
  90
  91
  92
  93
  94
  95
  96
  97
  98
  99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 220
 221
 222
 223
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 238
 239
 240
 241
 242
 243
 244
 245
 246
 247
 248
 249
 250
 251
 252
 253
 254
 255
 256
 257
 258
 259
 260
 261
 262
 263
 264
 265
 266
 267
 268
 269
 270
 271
 272
 273
 274
 275
 276
 277
 278
 279
 280
 281
 282
 283
 284
 285
 286
 287
 288
 289
 290
 291
 292
 293
 294
 295
 296
 297
 298
 299
 300
 301
 302
 303
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 332
 333
 334
 335
 336
 337
 338
 339
 340
 341
 342
 343
 344
 345
 346
 347
 348
 349
 350
 351
 352
 353
 354
 355
 356
 357
 358
 359
 360
 361
 362
 363
 364
 365
 366
 367
 368
 369
 370
 371
 372
 373
 374
 375
 376
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
2346
2347
2348
2349
2350
2351
2352
2353
2354
2355
2356
2357
2358
2359
2360
2361
2362
2363
2364
2365
2366
2367
2368
2369
2370
2371
2372
2373
2374
2375
2376
2377
2378
2379
2380
2381
2382
2383
2384
2385
2386
2387
2388
2389
2390
2391
2392
2393
2394
2395
2396
2397
2398
2399
2400
2401
2402
2403
2404
2405
2406
2407
2408
2409
2410
2411
2412
2413
2414
2415
2416
2417
2418
2419
2420
2421
2422
2423
2424
2425
2426
2427
2428
2429
2430
2431
2432
2433
2434
2435
2436
2437
2438
2439
2440
2441
2442
2443
2444
2445
2446
2447
2448
2449
2450
2451
2452
2453
2454
2455
2456
2457
2458
2459
2460
2461
2462
2463
2464
2465
2466
2467
2468
2469
2470
2471
2472
2473
2474
2475
2476
2477
2478
2479
2480
2481
2482
2483
2484
2485
2486
2487
2488
2489
2490
2491
2492
2493
2494
2495
2496
2497
2498
2499
2500
2501
2502
2503
2504
2505
2506
2507
2508
2509
2510
2511
2512
2513
2514
2515
2516
2517
2518
2519
2520
2521
2522
2523
2524
2525
2526
2527
2528
2529
2530
2531
2532
2533
2534
2535
2536
2537
2538
2539
2540
2541
2542
2543
2544
2545
2546
2547
2548
2549
2550
2551
2552
2553
2554
2555
2556
2557
2558
2559
2560
2561
2562
2563
2564
2565
2566
2567
2568
2569
2570
2571
2572
2573
2574
2575
2576
2577
2578
2579
2580
2581
2582
2583
class TorchRunner(Fp8Mixin, BaseRunner):
    r"""
    PyTorch-native runner for training, evaluation, and inference.

    Use this runner for single-model PyTorch training with optional DDP,
    autocast/FP8, `torch.compile`, stateful dataloaders, metric logging, and
    file or torch.distributed.checkpoint persistence.

    Users must provide `self.model` before construction completes. Most
    training tasks also provide `self.criterion`, and either `self.optimizer`
    or `config.optim`. Datasets may be supplied through `self.datasets` and
    will be materialized into `StatefulDataLoader` instances during
    `__post_init__`.

    The default batch contract is intentionally simple:
    mappings use `input`/`target`, sequences use index 0/1, and any other value
    is treated as model input with no target. Override `train_step`,
    `evaluate_step`, or `infer_step` when a task needs a different contract.

    Attributes:
        model: Local model module after materialization (possibly DDP-wrapped).
        ema: Optional EMA/evaluation model.
        criterion: Loss callable used by default train/evaluate steps.
        optimizer: Optimizer used by the runner or backend engine.
        scheduler: Optional LR scheduler.
        optimizer_container: Helper that owns optimizer step, clipping,
            non-finite checks, and step-scheduler dispatch.
        compiler: `torch.compile` policy object.
        scheduler_interval: Effective scheduler interval (`"step"` or
            epoch/metric-style interval).
        scheduler_monitor: Optional metric path used for metric schedulers.
    """

    model: nn.Module
    ema: nn.Module | None = None
    criterion: Callable | None = None
    optimizer: optim.Optimizer | None = None
    scheduler: Any | None = None
    optimizer_container: OptimizerContainer | None = None
    compiler: Compiler
    scheduler_interval: str = "step"
    scheduler_monitor: str | None = None
    _train_pg_timeout_reduced: bool = False
    _profiler_context: Any | None = None
    _profiler: Any | None = None
    _pending_loss_normalizer: int | None = None
    _accumulation_divisor_local: float = 0.0
    _accumulation_mode: str | None = None
    _train_window_will_flush: bool = False
    _optimizer_parameter_cache: OptimizerParameterCache | None = None
    _supports_torchft_runtime: bool = True

    _VALID_CHECKPOINT_BACKENDS = frozenset({"file", "dcp"})

    @classmethod
    def _validate_checkpoint_backend(cls, backend: str) -> str:
        """Normalize and validate a resolved checkpoint backend value."""
        backend = str(backend).strip().lower()
        if backend not in cls._VALID_CHECKPOINT_BACKENDS:
            raise ValueError(f"invalid checkpoint backend: {backend!r}. Expected one of: 'auto', 'file', 'dcp'.")
        return backend

    def __init__(self, config) -> None:
        """
        Normalize stack and checkpoint-backend settings, then run the base initializer.

        `config` is wrapped in `RunnerConfig` when it is not one already. A
        checkpoint backend of `"auto"` resolves to `"dcp"` when
        `self.world_size > 1` and to `"file"` otherwise; the resolved value
        is validated and written back to `config.checkpoint.backend`.
        """
        runner_config = config if isinstance(config, RunnerConfig) else RunnerConfig(config)
        runner_config.stack = normalize_stack_name(runner_config.get("stack", "ddp"))

        requested_backend = str(runner_config.checkpoint.backend).strip().lower()
        if requested_backend == "auto":
            if self.world_size > 1:
                requested_backend = "dcp"
            else:
                requested_backend = "file"
        runner_config.checkpoint.backend = self._validate_checkpoint_backend(requested_backend)

        super().__init__(runner_config)

    def __post_init__(self):
        """
        Build the runtime components in dependency order and restore state.

        Resets the accumulation/flush bookkeeping fields, builds dataloaders
        when `self.datasets` is truthy, then runs FP8 setup, model
        materialization, optimizer and scheduler construction, the
        `_finalize_runtime_components` hook, scheduler interval/monitor
        resolution, optimizer-container binding, checkpoint auto-restore, and
        profiler setup before delegating to `super().__post_init__()`.

        Raises:
            ValueError: `self.model` is `None`.
            NotImplementedError: fault tolerance is enabled while
                `_supports_torchft_runtime` is falsy for this runner class.
        """
        # Start from a clean accumulation window and an empty parameter cache.
        self._pending_loss_normalizer = None
        self._accumulation_divisor_local = 0.0
        self._accumulation_mode = None
        self._train_window_will_flush = False
        self._optimizer_parameter_cache = None
        if self.model is None:
            raise ValueError("cannot initialize TorchRunner: model is not initialized")
        if self.datasets:
            self.build_dataloaders()
        # Fail early when FT is requested on a runner that opted out of it.
        if self.ft is not None and self.ft.enabled and not self._supports_torchft_runtime:
            raise NotImplementedError(
                "TorchFT integration is currently supported by TorchRunner/DDP and ParallelRunner FSDP only"
            )
        # The compiler must exist before `materialize_model`, which calls into it.
        self.compiler = Compiler(self.config.compile)
        self.setup_fp8()
        self.materialize_model()
        self.build_optimizer()
        self.build_scheduler()
        self._finalize_runtime_components()
        # Resolve scheduler stepping policy; both values fall back to None
        # when no scheduler config is available.
        sched_cfg = self._get_scheduler_config()
        interval = sched_cfg.get("interval") if sched_cfg is not None else None
        monitor = sched_cfg.get("monitor") if sched_cfg is not None else None
        self.scheduler_interval = normalize_scheduler_interval(interval, self.scheduler)
        self.scheduler_monitor = None if monitor is None else str(monitor)
        self._bind_optimizer_container()
        # Restore runs after every component is built; the profiler is
        # initialized afterwards and reads `train_state.global_step`.
        self.auto_restore()
        self._init_profiler()
        super().__post_init__()

    def _finalize_runtime_components(self) -> None:
        """
        Hook for backend-specific engine/materialization after optimizer and scheduler build.

        The default implementation does nothing. `__post_init__` calls it
        right after `build_scheduler()` and before the scheduler
        interval/monitor are resolved and the optimizer container is bound.
        """

    def init_distributed(self) -> None:
        r"""
        Initialize the distributed environment.

        The default implementation initializes the default torch.distributed
        process group from `WORLD_SIZE`/`RANK`/`LOCAL_RANK` environment
        variables when `WORLD_SIZE > 1`, sets the active CUDA device,
        broadcasts `self.timestamp` from rank 0, and seeds
        `elastic_state.restart_count` from `TORCHELASTIC_RESTART_COUNT`.

        **Called when:** once during `BaseRunner.__init__`, before
        `init_checkpoint_manager`, `init_fault_tolerance`, and
        `init_garbage_collection`. The runner is partially constructed at
        this point — `self.config`, `self.workspace`, `self.timestamp`, the
        dataloader container, and the default `FileCheckpointManager` are
        bound, but the model is not materialized and optimizers/dataloaders
        are not built.

        **Precondition:** environment variables `WORLD_SIZE`, `RANK`,
        `LOCAL_RANK` are set when running distributed. The default
        torch.distributed process group is **not** already initialized when
        `WORLD_SIZE > 1` — the runner owns process-group lifecycle.

        Raises:
            RuntimeError: the default process group is already initialized
                when `WORLD_SIZE > 1`.
            ValueError: `comm.init_timeout_seconds` is non-positive.

        **Side effects:** when `WORLD_SIZE > 1`, calls
        `dist.init_process_group(...)`, sets the active CUDA device when
        CUDA is available, and broadcasts `self.timestamp` from rank 0.
        Reads `TORCHELASTIC_RESTART_COUNT` into `elastic_state.restart_count`.
        Always resets `_train_pg_timeout_reduced` to `False`.

        !!! danger "Do not"
            - Initialize a process group via `dist.init_process_group(...)`
              outside the runner; the runner owns its lifecycle.
            - Build the model or dataloaders here; those happen in
              `__post_init__`.
            - Bind the checkpoint manager here; `init_checkpoint_manager`
              runs next.

        **Backend notes:**

        - `ParallelRunner` extends this hook: after calling `super()`, it
          builds the parallel topology (`build_topology`) and initializes
          per-axis process groups via `init_device_mesh`.
        - `DeepSpeedRunner` inherits the default; DeepSpeed reuses the
          default process group initialized here.
        """

        # Config values take precedence; the BACKEND / INIT_METHOD environment
        # variables are only used as defaults.
        backend = self.config.get("backend", os.getenv("BACKEND"))
        init_method = self.config.get("init_method", os.getenv("INIT_METHOD"))
        init_timeout = self._comm_timeout("comm.init_timeout_seconds")
        world_size = int(os.getenv("WORLD_SIZE", "1"))
        rank = int(os.getenv("RANK", "0"))
        runtime_device = self.device
        use_cuda_runtime = torch.cuda.is_available() and runtime_device.type == "cuda"
        # Fall back to the local rank when the runtime device has no explicit index.
        runtime_device_index = runtime_device.index if runtime_device.index is not None else self.local_rank
        dist_ready = dist.is_available() and dist.is_initialized()
        # Refuse to adopt a process group that was created outside the runner.
        if world_size > 1 and dist_ready:
            raise RuntimeError(
                "default process group is already initialized; Runner requires owning process-group lifecycle"
            )
        if world_size > 1:
            # Select the CUDA device before the process group is created.
            if use_cuda_runtime:
                torch.cuda.set_device(runtime_device_index)
            init_kwargs: dict[str, Any] = {
                "backend": backend,
                "init_method": init_method,
                "world_size": world_size,
                "rank": rank,
            }
            # Only pass `timeout` when configured, so torch.distributed's own
            # default applies otherwise.
            if init_timeout is not None:
                init_kwargs["timeout"] = init_timeout
            dist.init_process_group(**init_kwargs)
            dist_ready = bool(dist.is_available() and dist.is_initialized())

        if dist_ready and use_cuda_runtime:
            torch.cuda.set_device(runtime_device_index)

        # NOTE(review): this gate reads `self.world_size`, while the branches
        # above use the `WORLD_SIZE` environment value — confirm they always agree.
        if dist_ready and self.world_size > 1:
            # Share one timestamp across ranks (broadcast uses the default source rank).
            object_list = [self.timestamp]
            dist.broadcast_object_list(object_list)
            self.timestamp = str(object_list[0])

        restart_count = os.getenv("TORCHELASTIC_RESTART_COUNT")
        if restart_count is not None:
            self.elastic_state.restart_count = int(restart_count)

        # A fresh process group starts with the init timeout; allow
        # `_maybe_reduce_train_process_group_timeout` to run again.
        self._train_pg_timeout_reduced = False

    def init_checkpoint_manager(self) -> None:
        """
        Bind the checkpoint manager corresponding to `config.checkpoint.backend`.

        The default dispatches by backend: when the backend is `"dcp"`, it
        binds a `TorchDistributedCheckpointManager` (or
        `TorchFTCheckpointManager` when FT dataloader checkpoints are
        enabled). For `"file"` it leaves the `FileCheckpointManager` already
        bound by `BaseRunner.__init__` in place.

        **Called when:** once during `BaseRunner.__init__`, after
        `init_distributed` and before `init_fault_tolerance`. The default
        `FileCheckpointManager` is already bound at this point — overrides
        should swap it via `set_checkpoint_manager(...)`, not by direct
        attribute assignment.

        **Precondition:** `config.checkpoint.backend` is normalized to one
        of `{"file", "dcp"}` (TorchRunner does this in `__init__`). When
        the backend is `"dcp"`, the default process group is initialized
        for distributed runs.

        **Side effects:** swaps `self.checkpoint_manager` via
        `set_checkpoint_manager(...)` when the backend differs from
        `"file"`. The prior manager is closed with a zero timeout.

        !!! danger "Do not"
            - Set `self.checkpoint_manager` directly; use
              `set_checkpoint_manager` so the prior manager is closed
              cleanly.
            - Initialize fault tolerance here; `init_fault_tolerance` runs
              next.
            - Bind the model or dataloaders here.

        **Backend notes:**

        - `DeepSpeedRunner` coerces `config.checkpoint.backend` to `"file"`
          in `__init__`, so this hook is a no-op for that backend.
        - `ParallelRunner` coerces the backend to `"dcp"`, so this hook
          always binds `TorchDistributedCheckpointManager` or
          `TorchFTCheckpointManager`.
        """
        checkpoint_backend = self.config.checkpoint.backend.lower()
        if checkpoint_backend == "dcp":
            ft_checkpoint_enabled = bool(
                self.config.get("ft.enabled", False)
                or self.config.get("checkpoint.enable_ft_dataloader_checkpoints", False)
            )
            manager_cls = TorchFTCheckpointManager if ft_checkpoint_enabled else TorchDistributedCheckpointManager
            self.set_checkpoint_manager(manager_cls(self))
            return
        # Backend is normalized to {"file", "dcp"} in `__init__`; "file" is the
        # remaining case and reuses the default `FileCheckpointManager` that
        # `BaseRunner.__init__` already bound.

    def _comm_timeout(self, key: str) -> timedelta | None:
        value = self.config.get(key)
        if value is None:
            return None
        seconds = int(value)
        if seconds <= 0:
            raise ValueError(f"{key} must be a positive integer, got {seconds}")
        return timedelta(seconds=seconds)

    def _timeout_process_groups(self) -> tuple[Any | None, ...]:
        groups: list[Any | None] = [None]
        if self.ft is not None and self.ft.replicate_process_group is not None:
            groups.append(self.ft.replicate_process_group)
        return tuple(groups)

    def _set_process_group_timeout(self, timeout: timedelta) -> None:
        if not (dist.is_available() and dist.is_initialized()):
            return
        set_pg_timeout = getattr(dist_c10d, "_set_pg_timeout", None)
        if not callable(set_pg_timeout):
            warn(
                "torch.distributed does not expose process-group timeout mutation; "
                "skipping comm.train_timeout_seconds update",
                RuntimeWarning,
                stacklevel=2,
            )
            return

        for group in self._timeout_process_groups():
            backend = str(dist.get_backend() if group is None else dist.get_backend(group)).lower()
            if backend != "nccl":
                continue

            barrier_kwargs = {} if group is None else {"group": group}
            if torch.cuda.is_available():
                dist.barrier(device_ids=[torch.cuda.current_device()], **barrier_kwargs)
                torch.cuda.synchronize()
            else:
                dist.barrier(**barrier_kwargs)

            try:
                set_pg_timeout(timeout, group)
            except TypeError:
                if group is not None:
                    warn(
                        "torch.distributed does not support subgroup timeout mutation; "
                        "skipping comm.train_timeout_seconds update for a non-default process group",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                    continue
                set_pg_timeout(timeout)
            except Exception as exc:
                group_name = "default" if group is None else "subgroup"
                warn(f"failed to update {group_name} process-group timeout: {exc}", RuntimeWarning, stacklevel=2)

    def _maybe_reduce_train_process_group_timeout(self) -> None:
        if self._train_pg_timeout_reduced:
            return
        if self.train_state.global_step != 1:
            return
        timeout = self._comm_timeout("comm.train_timeout_seconds")
        if timeout is None:
            return
        self._set_process_group_timeout(timeout)
        self._train_pg_timeout_reduced = True

    def destroy_process_group(self) -> None:
        if not (dist.is_available() and dist.is_initialized()):
            return
        try:
            dist.destroy_process_group()
        except Exception as exc:
            warn(f"failed to destroy default process group: {exc}", RuntimeWarning, stacklevel=2)

    def _init_profiler(self) -> None:
        profiling = self.config.get("profiling")
        if not isinstance(profiling, Mapping) or not bool(profiling.get("enabled", False)):
            return

        wait = int(profiling.get("wait", 1))
        warmup = int(profiling.get("warmup", 1))
        active = int(profiling.get("active", 3))
        repeat = profiling.get("repeat")
        if wait < 0:
            raise ValueError(f"profiling.wait must be a non-negative integer, got {wait}")
        if warmup < 0:
            raise ValueError(f"profiling.warmup must be a non-negative integer, got {warmup}")
        if active <= 0:
            raise ValueError(f"profiling.active must be a positive integer, got {active}")
        if repeat is not None:
            repeat = int(repeat)
            if repeat <= 0:
                raise ValueError(f"profiling.repeat must be a positive integer, got {repeat}")

        activities = [torch.profiler.ProfilerActivity.CPU]
        if torch.cuda.is_available() and self.device.type == "cuda":
            activities.append(torch.profiler.ProfilerActivity.CUDA)

        schedule_kwargs: dict[str, Any] = {"wait": wait, "warmup": warmup, "active": active}
        if repeat is not None:
            schedule_kwargs["repeat"] = repeat

        trace_dir = os.fsdecode(str(profiling.get("trace_dir", "profiles")))
        if not os.path.isabs(trace_dir):
            trace_dir = os.path.join(self.workspace.dir, trace_dir)
        trace_dir = os.path.join(trace_dir, self.timestamp, f"rank-{self.rank:05d}")
        os.makedirs(trace_dir, exist_ok=True)
        profiler_context = torch.profiler.profile(
            activities=activities,
            schedule=torch.profiler.schedule(**schedule_kwargs),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
            record_shapes=bool(profiling.get("record_shapes", False)),
            profile_memory=bool(profiling.get("profile_memory", False)),
            with_stack=bool(profiling.get("with_stack", False)),
            with_flops=bool(profiling.get("with_flops", False)),
        )
        profiler = profiler_context.__enter__()
        if hasattr(profiler, "step_num"):
            profiler.step_num = self.train_state.global_step
        self._profiler_context = profiler_context
        self._profiler = profiler

    def _step_profiler(self) -> None:
        if self._profiler is None:
            return
        self._profiler.step()

    def _close_profiler(self) -> None:
        profiler_context = self._profiler_context
        self._profiler_context = None
        self._profiler = None
        if profiler_context is None:
            return
        profiler_context.__exit__(None, None, None)

    @on_main_process
    def init_tensorboard(self, *args, **kwargs) -> None:
        r"""
        Create the TensorBoard `SummaryWriter` and store it as `self.writer`.

        When the caller passes no `log_dir`, the writer logs to
        `<workspace>/tensorboard/<timestamp>`. The writer's `add_scalar` is
        wrapped with `catch(OSError, verbose=False)`.
        """

        from torch.utils.tensorboard.writer import SummaryWriter  # pylint: disable=C0415

        if "log_dir" not in kwargs:
            default_log_dir = os.path.join(self.workspace.dir, "tensorboard", self.timestamp)
            kwargs["log_dir"] = default_log_dir

        self.writer = SummaryWriter(*args, **kwargs)
        guard_os_errors = catch(OSError, verbose=False)
        self.writer.add_scalar = guard_os_errors(self.writer.add_scalar)

    def set_seed(self, seed: int | None = None, bias: int | bool | None = None) -> int:
        r"""
        Set up random seed.

        Args:
            seed: Random seed to set.
                Defaults to `self.config.seed` (`config.seed`).

            bias: Make the seed different for each processes.
                This is used to ensure the data augmentation are applied differently on every processes.
                Defaults to `self.rank`.
                Set to `False` to disable this feature.
        Returns:
            Random seed set.
        """

        base_seed = seed if seed is not None else self.config.seed  # type: ignore[assignment]
        if base_seed is None:
            base_seed = random.randint(0, 2**32 - 1)
            if self.distributed and dist.is_initialized():
                object_list = [base_seed]
                dist.broadcast_object_list(object_list)
                base_seed = object_list[0]
        base_seed = int(base_seed)
        # Keep `config.seed` as the global/base seed (before per-rank bias).
        self.config.seed = base_seed

        process_seed = base_seed
        if bias is None:
            if self.ft is not None:
                _, bias = self.ft.data_parallel_info(self.world_size, self.rank)
            else:
                bias = self.rank
        if bias:
            process_seed += int(bias)

        torch.manual_seed(process_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(process_seed)
        if np_random is not None:
            np_random.seed(process_seed)
        random.seed(process_seed)
        self.rng_state.python = random.getstate()
        self.rng_state.numpy = np_random.get_state() if np_random is not None else None
        self.rng_state.torch_cpu = torch.get_rng_state()
        if torch.cuda.is_available():
            self.rng_state.torch_cuda = torch.cuda.get_rng_state_all()
        else:
            self.rng_state.torch_cuda = None
        return process_seed

    def set_deterministic(self) -> None:
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.use_deterministic_algorithms(True)

    def materialize_model(self) -> None:
        """
        Move the model to the runtime device, optionally compile, and wrap
        with DDP when distributed.

        The default is a single-module DDP-style materialization: it moves
        `self.model` to `self.device`, applies any FP8 module policy when
        FP8 is enabled, runs `torch.compile` via `self.compiler` (under the
        DDP-optimizer context when wrapping is needed), and wraps the result
        with `nn.parallel.DistributedDataParallel` when `self.distributed`
        is set and the model is not already a DDP/DataParallel wrapper.

        **Called when:** once during `__post_init__`, after `setup_fp8()`
        and before `build_optimizer()`. The order matters — the optimizer
        must see post-wrap parameters.

        **Precondition:** `self.model` is set (typically by the user before
        constructing the runner). `self.device` resolves to the runtime
        device.

        Raises:
            ValueError: `self.model` is not initialized.

        **Side effects:** moves `self.model` to `self.device`; applies FP8
        module policy when `self.fp8_enabled`; compiles via
        `self.compiler.compile(...)` under the DDP-optimizer context when
        wrapping is needed; wraps with `DistributedDataParallel` for world
        size > 1. Moves `self.ema` to device when EMA is bound.

        !!! danger "Do not"
            - Build the optimizer or scheduler here; they run after this
              hook.
            - Skip the device move when overriding (tensors must live on
              `self.device` before the forward pass).
            - Re-wrap an already-wrapped model (e.g., DDP-wrap a DDP module).

        **Backend notes:**

        - `DeepSpeedRunner` overrides this hook to move the model to device
          and compile only; the DeepSpeed engine wraps the model later in
          `_finalize_runtime_components`.
        - `ParallelRunner` overrides this hook for FSDP2, pipeline-parallel
          schedules, and tensor/expert/context parallelism (via the
          `parallelize_model` and `apply_activation_checkpointing` hooks).
        """
        if self.model is None:
            raise ValueError("cannot materialize model: model is not initialized")

        # Move first so the FP8, compile, and DDP steps below operate on the
        # device-resident module.
        model = self.model.to(self.device)
        self.model = model
        if self.fp8_enabled:
            self.apply_fp8_module_policy_to_model_parts()
            # Re-read `self.model` after the policy hook — presumably the hook
            # may replace it; verify against the FP8 mixin.
            model = self.model
        # Wrap only in distributed runs, and never re-wrap an existing
        # DDP/DataParallel wrapper.
        should_wrap_ddp = self.distributed and not isinstance(
            model, (nn.parallel.DistributedDataParallel, nn.parallel.DataParallel)
        )
        # Compile inside the compiler's DDP-optimizer context only when a DDP
        # wrap follows.
        with self.compiler.ddp_optimizer() if should_wrap_ddp else nullcontext():
            model = self.compiler.compile(model)
        if should_wrap_ddp:
            model = nn.parallel.DistributedDataParallel(model)
        self.model = model

        # The EMA model is only moved to the device; it is not compiled or
        # DDP-wrapped here.
        if self.ema is not None:
            self.ema = self.ema.to(self.device)

    def unwrap(self, model: nn.Module) -> nn.Module:
        if isinstance(model, (nn.parallel.DistributedDataParallel, nn.parallel.DataParallel)):
            return model.module
        return model

    def _iter_unique_parameters(self, modules: Sequence[nn.Module]) -> Iterator[nn.Parameter]:
        seen: set[int] = set()
        for module in modules:
            for parameter in module.parameters():
                parameter_id = id(parameter)
                if parameter_id in seen:
                    continue
                seen.add(parameter_id)
                yield parameter

    def _iter_unique_named_parameters(
        self, modules: Sequence[nn.Module], prefixes: Sequence[str] | None = None
    ) -> Iterator[tuple[str, nn.Parameter]]:
        seen: set[int] = set()
        if prefixes is None:
            prefixes = ("",) * len(modules)
        if len(prefixes) != len(modules):
            raise ValueError("prefix count must match module count")
        for module, prefix in zip(modules, prefixes):
            for name, parameter in module.named_parameters():
                parameter_id = id(parameter)
                if parameter_id in seen:
                    continue
                seen.add(parameter_id)
                yield f"{prefix}{name}", parameter

    def iter_optimizer_parameters(self) -> Iterator[nn.Parameter]:
        if self.model is None:
            return
        yield from self._iter_unique_parameters((self.unwrap(self.model),))

    def iter_optimizer_named_parameters(self) -> Iterator[tuple[str, nn.Parameter]]:
        if self.model is None:
            return
        yield from self._iter_unique_named_parameters((self.unwrap(self.model),))

    def _optimizer_param_group_options(
        self,
        group_cfg: Mapping[str, Any],
        optim_cfg: Mapping[str, Any],
        *,
        index: int,
    ) -> dict[str, Any]:
        options = {
            str(key): value
            for key, value in group_cfg.items()
            if key not in {"pattern", "params", "lr_multiplier", "weight_decay_multiplier", "beta1", "beta2"}
        }
        if "lr_multiplier" in group_cfg:
            if "lr" not in optim_cfg:
                raise ValueError(f"optim.param_groups[{index}].lr_multiplier requires optim.lr")
            options["lr"] = float(optim_cfg["lr"]) * float(group_cfg["lr_multiplier"])
        if "weight_decay_multiplier" in group_cfg:
            if "weight_decay" not in optim_cfg:
                raise ValueError(f"optim.param_groups[{index}].weight_decay_multiplier requires optim.weight_decay")
            options["weight_decay"] = float(optim_cfg["weight_decay"]) * float(group_cfg["weight_decay_multiplier"])

        beta1 = group_cfg.get("beta1")
        beta2 = group_cfg.get("beta2")
        if beta1 is not None or beta2 is not None:
            if "betas" not in optim_cfg:
                raise ValueError(f"optim.param_groups[{index}].beta1/beta2 requires optim.betas")
            beta1_default, beta2_default = optim_cfg["betas"]
            options["betas"] = (
                float(beta1_default if beta1 is None else beta1),
                float(beta2_default if beta2 is None else beta2),
            )
        return options

    def _build_optimizer_param_groups(self, optim_cfg: Mapping[str, Any]) -> list[nn.Parameter] | list[dict[str, Any]]:
        group_configs = optim_cfg.get("param_groups")
        if group_configs is None:
            return list(self.iter_optimizer_parameters())
        if isinstance(group_configs, (str, bytes, Mapping)) or not isinstance(group_configs, Sequence):
            raise ValueError("optim.param_groups must be a sequence of mappings")

        named_parameters = list(self.iter_optimizer_named_parameters())
        if not named_parameters:
            return []

        assigned: set[int] = set()
        param_groups: list[dict[str, Any]] = []
        for index, group_cfg in enumerate(group_configs):
            if not isinstance(group_cfg, Mapping):
                raise ValueError(f"optim.param_groups[{index}] must be a mapping")
            pattern = group_cfg.get("pattern")
            if pattern is None:
                raise ValueError(f"optim.param_groups[{index}] requires `pattern`")
            regex = re.compile(str(pattern))
            parameters = [
                parameter
                for name, parameter in named_parameters
                if id(parameter) not in assigned and regex.search(name) is not None
            ]
            if not parameters:
                warn(
                    f"optim.param_groups[{index}] pattern {pattern!r} matched no parameters",
                    RuntimeWarning,
                    stacklevel=2,
                )
                continue
            assigned.update(id(parameter) for parameter in parameters)
            param_groups.append(
                {
                    "params": parameters,
                    **self._optimizer_param_group_options(group_cfg, optim_cfg, index=index),
                }
            )

        unmatched = [parameter for _name, parameter in named_parameters if id(parameter) not in assigned]
        if unmatched:
            param_groups.append({"params": unmatched})
        return param_groups

    def build_optimizer(self) -> None:
        """
        Create `self.optimizer` from `config.optim`, falling back to
        `config.optimizer`, unless an optimizer is already bound.

        Parameters come from `_build_optimizer_param_groups`: a flat list from
        `iter_optimizer_parameters` by default, or regex-matched groups over
        `iter_optimizer_named_parameters` when `optim.param_groups` is set
        (parameters that no pattern matches keep the optimizer-level defaults).
        Every other config key is forwarded to the `OPTIMIZERS` registry.

        Nothing is built when the model is missing, when the config is absent,
        empty or not a mapping, or when no parameters are found.

        **Called when:** once during `TorchRunner.__post_init__`, following
        `materialize_model` (parameters then reflect DDP/FSDP wrapping) and
        ahead of `build_scheduler`.

        **Precondition:** `self.model` is materialized and placed on
        `self.device`; `self.optimizer` is `None` (a user-bound optimizer
        disables the auto-build).

        **Side effects:** binds the registry-built instance to
        `self.optimizer`.

        !!! danger "Do not"
            - Invoke this ahead of `materialize_model`; the parameters would
              not reflect DDP/FSDP wrapping.
            - Construct a scheduler in this hook.
            - Change how parameters are enumerated here; override
              `iter_optimizer_parameters` / `iter_optimizer_named_parameters`
              so subclass topology (e.g., `ParallelRunner.model_parts`) is
              kept intact.

        **Backend notes:**

        - `DeepSpeedRunner` inherits this hook; DeepSpeed may swap the
          optimizer for a DeepSpeed-managed instance during
          `_finalize_runtime_components`.
        - `ParallelRunner` inherits this hook and overrides
          `iter_optimizer_parameters` to enumerate `self.model_parts`.
        """
        if self.optimizer is not None:
            return
        if self.model is None:
            return

        config = self.config
        optim_cfg = config.get("optim")
        if optim_cfg is None:
            optim_cfg = config.get("optimizer")
        if not (isinstance(optim_cfg, Mapping) and optim_cfg):
            return

        parameters = self._build_optimizer_param_groups(optim_cfg)
        if not parameters:
            return

        # `param_groups` is consumed above; it is not an optimizer constructor argument.
        build_kwargs = dict(optim_cfg)
        build_kwargs.pop("param_groups", None)
        self.optimizer = OPTIMIZERS.build(params=parameters, **build_kwargs)

    def _get_scheduler_config(self) -> Mapping[str, Any] | None:
        sched_cfg = self.config.get("sched")
        if sched_cfg is None:
            sched_cfg = self.config.get("scheduler")
        if not isinstance(sched_cfg, Mapping):
            return None
        return sched_cfg

    def build_scheduler(self) -> None:
        """
        Create `self.scheduler` from `config.sched` (or `config.scheduler`)
        unless a scheduler is already bound.

        `interval` and `monitor` are stripped from the config before
        construction because they steer runner-level dispatch rather than the
        scheduler itself. When `total_steps` is not configured and `self.steps`
        is known, `total_steps` defaults to `self.steps`. The result is passed,
        together with `self.optimizer`, to the `SCHEDULERS` registry.

        Nothing is built when no optimizer is bound or the scheduler config is
        absent, empty or not a mapping.

        **Called when:** once during `TorchRunner.__post_init__`, right after
        `build_optimizer`.

        **Precondition:** `self.optimizer` is bound and `self.scheduler` is
        `None` (a user-bound scheduler disables the auto-build).

        **Side effects:** binds the registry-built instance to
        `self.scheduler`.

        !!! danger "Do not"
            - Invoke this ahead of `build_optimizer`; a scheduler has to wrap
              an optimizer.
            - Decide the scheduler interval or monitor here; they are
              configured through `config.sched.interval` /
              `config.sched.monitor`.

        **Backend notes:**

        - `DeepSpeedRunner` inherits this hook; when the effective interval is
          `"step"` the scheduler may be handed over to the DeepSpeed engine in
          `_finalize_runtime_components`, otherwise the runner keeps it.
        """
        if self.scheduler is not None:
            return
        if self.optimizer is None:
            return

        sched_cfg = self._get_scheduler_config()
        if not (isinstance(sched_cfg, Mapping) and sched_cfg):
            return

        scheduler_kwargs = dict(sched_cfg)
        for runner_only_key in ("interval", "monitor"):
            scheduler_kwargs.pop(runner_only_key, None)

        if "total_steps" not in scheduler_kwargs:
            planned_steps = self.steps
            if planned_steps is not None:
                scheduler_kwargs["total_steps"] = planned_steps

        self.scheduler = SCHEDULERS.build(self.optimizer, **scheduler_kwargs)

    def _bind_optimizer_container(self) -> None:
        if self.optimizer is None:
            self.optimizer_container = None
            return
        self.optimizer_container = OptimizerContainer(
            self.optimizer,
            scheduler=self.scheduler,
            scheduler_interval=self.scheduler_interval,
        )

    def _resolve_scheduler_metric(self, result: Mapping[str, Any]) -> Any:
        def scalarize(value: Any) -> Any:
            if isinstance(value, torch.Tensor):
                if value.numel() != 1:
                    raise ValueError(
                        "scheduler monitor must resolve to a scalar metric, "
                        f"but got tensor with shape {tuple(value.shape)}"
                    )
                return value.item()
            return value

        monitor = self.scheduler_monitor or self.config.score_name

        if "." in monitor:
            value: Any = result
            for key in monitor.split("."):
                if not isinstance(value, Mapping) or key not in value:
                    raise ValueError(
                        f"could not resolve scheduler.monitor={monitor!r} from aggregated result {dict(result)!r}"
                    )
                value = value[key]
            return scalarize(value)

        score_split = self.score_split
        if score_split is not None:
            split_result = result.get(score_split)
            if isinstance(split_result, Mapping) and monitor in split_result:
                return scalarize(split_result[monitor])

        if monitor in result and not isinstance(result[monitor], Mapping):
            return scalarize(result[monitor])

        matches: list[tuple[str, Any]] = []
        for split_name, split_result in result.items():
            if isinstance(split_result, Mapping) and monitor in split_result:
                matches.append((split_name, split_result[monitor]))

        if len(matches) == 1:
            return scalarize(matches[0][1])
        if len(matches) > 1:
            splits = ", ".join(split_name for split_name, _ in matches)
            raise ValueError(
                f"ambiguous scheduler.monitor={monitor!r}: matched multiple splits ({splits}). "
                "Use '<split>.<metric>' to disambiguate."
            )

        raise ValueError(f"could not resolve scheduler.monitor={monitor!r} from aggregated result {dict(result)!r}")

    def _step_epoch_scheduler(self, result: Mapping[str, Any]) -> bool:
        if self.scheduler is None or self.scheduler_interval != "epoch":
            return False

        scheduler_metric = SCHEDULER_METRIC_UNSET
        if scheduler_requires_metric(self.scheduler):
            scheduler_metric = self._resolve_scheduler_metric(result)

        if self.optimizer_container is not None:
            return self.optimizer_container.step_scheduler(scheduler_metric=scheduler_metric)
        return step_scheduler(self.scheduler, scheduler_metric=scheduler_metric)

    def build_dataloaders(self):
        """
        Create dataloaders for every dataset split that has none yet.

        For each pending split in `self.datasets`, the shared options found in
        `config.dataloader` are combined with the split-specific overrides under
        `config.dataloader.<split>` (overrides win). A sampler is obtained from
        `build_datasampler` and the dataset is wrapped in a `StatefulDataLoader`
        that uses `self.collate_fn`. Splits listed in `self.train_splits`
        default to `shuffle=True` and `drop_last=True`; all other splits default
        to `False` for both.

        **Called when:** once during `TorchRunner.__post_init__`, provided
        `self.datasets` is non-empty.

        **Precondition:** `self.datasets` has been filled in (usually by the
        user before the runner is constructed) and `self.dataloaders` is a
        default-constructed `DataLoaderDict`.

        **Side effects:** fills `self.dataloaders[split]` for each split of
        `self.datasets` that is not materialized yet; entries that already
        exist are not modified.

        !!! danger "Do not"
            - Put sampler logic here; override `build_datasampler`.
            - Change collation here; assign `self.collate_fn` or override the
              `collate_fn` classmethod.
            - Bind an optimizer or scheduler in this hook.

        **Backend notes:**

        - `ParallelRunner` replaces `self.dataloaders` with a proxying dict in
          `__init__`, giving non-first/last pipeline stages a `StepProxyLoader`
          view. The build logic itself is inherited.
        """
        pending = {split: dataset for split, dataset in self.datasets.items() if split not in self.dataloaders}
        dataloader_config = self.config.get("dataloader", NestedDict())

        # Keys named after a dataset split hold per-split overrides; everything
        # else is a default shared by all splits.
        shared_options = {}
        per_split_options = {}
        for key, value in dataloader_config.items():
            bucket = per_split_options if key in self.datasets else shared_options
            bucket[key] = value
        default_kwargs = NestedDict(shared_options)
        split_kwargs = NestedDict(per_split_options)

        for split, dataset in pending.items():
            kwargs = NestedDict(default_kwargs)
            if split in split_kwargs:
                kwargs.merge(split_kwargs[split], overwrite=True)
            trains = split in self.train_splits
            # `shuffle` belongs to the sampler, so it must not reach the loader.
            shuffle = kwargs.pop("shuffle", trains)
            kwargs.setdefault("drop_last", trains)
            sampler = self.build_datasampler(dataset, split=split, shuffle=shuffle)
            self.dataloaders[split] = StatefulDataLoader(
                dataset,
                sampler=sampler,
                collate_fn=self.collate_fn,
                **kwargs,
            )

    def build_datasampler(self, dataset: Any, *, split: str, shuffle: bool) -> Any:
        """
        Create the sampler used for a single dataset split.

        **Called when:** `build_dataloaders` materializes a split of
        `self.datasets`.

        Args:
            dataset: Dataset object of the split.
            split: Name of the split being materialized.
            shuffle: Whether samples should be drawn in shuffled order.

        Returns:
            In single-process mode a `RandomSampler` (shuffled) or a
            `SequentialSampler`; in distributed mode a `DistributedSampler`
            built from `self.world_size` / `self.rank`, remapped through
            `self.ft.data_parallel_info` when `self.ft` is set.

        **Backend notes:**

        - `ParallelRunner` overrides replica/rank selection so that
          data-parallel sampling follows its topology rather than the raw
          global rank.
        """
        if not self.distributed:
            sampler_cls = utils.data.RandomSampler if shuffle else utils.data.SequentialSampler
            return sampler_cls(dataset)

        num_replicas, rank = self.world_size, self.rank
        fault_tolerance = self.ft
        if fault_tolerance is not None:
            num_replicas, rank = fault_tolerance.data_parallel_info(num_replicas, rank)
        return utils.data.distributed.DistributedSampler(
            dataset,
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
        )

    @staticmethod
    def collate_fn(batch):
        return utils.data.dataloader.default_collate(batch)

    def to_device(self, data: Any):
        """Transfer one batch onto `self.device`; subclasses may override this for custom fast paths."""
        target_device = self.device
        return to_device(data, target_device)

    def _step_mode_split_budget(
        self,
        *,
        remaining_steps: int,
        remaining_splits: int,
        loader: Any,
    ) -> int:
        if remaining_steps <= 0:
            return 0
        if remaining_splits <= 0:
            return remaining_steps

        fair_share = max((remaining_steps + remaining_splits - 1) // remaining_splits, 1)
        loader_length = self._loader_length(loader)
        if loader_length is None:
            return fair_share

        loader_step_budget = max((loader_length + self.accum_steps - 1) // self.accum_steps, 1)
        return min(fair_share, loader_step_budget, remaining_steps)

    @staticmethod
    def _set_loader_epoch(loader: Any, epoch: int) -> None:
        batch_sampler = getattr(loader, "batch_sampler", None)
        if hasattr(batch_sampler, "set_epoch"):
            batch_sampler.set_epoch(epoch)  # type: ignore[union-attr]
        sampler = getattr(loader, "sampler", None)
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch)  # type: ignore[union-attr]

    def loop_time(self, *, sync: bool = False) -> float:
        """
        Read the performance counter, optionally after a CUDA synchronize.

        Args:
            sync: When true and the runner device is a CUDA device (with CUDA
                available), `torch.cuda.synchronize` is called on it first.

        Returns:
            The current `perf_counter()` reading.
        """
        needs_sync = sync and torch.cuda.is_available() and self.device.type == "cuda"
        if needs_sync:
            torch.cuda.synchronize(self.device)
        return perf_counter()

    @property
    def reports_batch_telemetry(self) -> bool:
        """
        Flag for batch-level telemetry reporting; this implementation always returns `True`.

        NOTE(review): the consumers of this flag are not visible here; presumably
        subclasses override it to opt out. Confirm against callers.
        """
        return True

    @staticmethod
    def _as_int_or_none(value: Any) -> int | None:
        if isinstance(value, bool):
            return int(value)
        if isinstance(value, int):
            return int(value)
        if isinstance(value, float):
            return int(value)
        if torch.is_tensor(value) and value.numel() == 1:
            return int(value.detach().item())
        return None

    def _mapping_loss_normalizer(self, mapping: Mapping[str, Any] | None) -> int | None:
        if mapping is None:
            return None
        for key in ("loss_normalizer", "num_valid_tokens", "valid_tokens", "num_tokens", "token_count"):
            if key in mapping:
                normalizer = self._as_int_or_none(mapping[key])
                if normalizer is not None:
                    return normalizer
        return None

    def _tensor_loss_normalizer(self, target: Any) -> int | None:
        if not torch.is_tensor(target):
            return None
        ignore_index = getattr(self.criterion, "ignore_index", None)
        if ignore_index is not None:
            return int((target != int(ignore_index)).sum().item())
        if getattr(self.criterion, "reduction", None) == "mean":
            return int(target.numel())
        return None

    def _get_loss_normalizer(self, data: Any) -> int | None:
        if isinstance(data, Mapping):
            explicit = self._mapping_loss_normalizer(data)
            if explicit is not None:
                return explicit

            target = data.get("target")
            if isinstance(target, Mapping):
                explicit = self._mapping_loss_normalizer(target)
                if explicit is not None:
                    return explicit
            if target is not None:
                normalizer = self._tensor_loss_normalizer(target)
                if normalizer is not None:
                    return normalizer

            inputs = data.get("input")
            if isinstance(inputs, Mapping):
                attention_mask = inputs.get("attention_mask")
                if isinstance(attention_mask, torch.Tensor):
                    return int(attention_mask.detach().sum().item())
            return None

        if isinstance(data, Sequence) and not isinstance(data, (str, bytes)):
            target = data[1] if len(data) > 1 else None
            if isinstance(target, Mapping):
                explicit = self._mapping_loss_normalizer(target)
                if explicit is not None:
                    return explicit
            if target is not None:
                normalizer = self._tensor_loss_normalizer(target)
                if normalizer is not None:
                    return normalizer
        return None

    def _loss_normalizer_sync_divisor(self) -> int:
        if self.ft is not None and self.ft.replicate_process_group is not None:
            return max(int(dist.get_world_size(group=self.ft.replicate_process_group)), 1)
        if dist.is_available() and dist.is_initialized():
            return max(self.world_size, 1)
        return 1

    def _reduce_loss_normalizer_total(self, local_total: float) -> float:
        if local_total <= 0:
            return local_total
        if self._loss_normalizer_sync_divisor() <= 1:
            return local_total
        if not (dist.is_available() and dist.is_initialized()):
            return local_total

        device = self.all_reduce_device()
        total_tensor = torch.tensor(local_total, dtype=torch.float64, device=device)
        self.all_reduce(total_tensor, op=dist.ReduceOp.SUM)
        return float(total_tensor.item())

    def all_reduce_device(self) -> torch.device:
        """
        Choose the device on which collective payloads should live.

        Returns:
            `self.device` when the runner is distributed, `torch.distributed`
            is initialized, the backend of the collective group (or of the
            default group when `all_reduce_group()` is `None`) contains
            `"nccl"`, and CUDA is available. The CPU device in every other
            case, including when the group's backend cannot be queried.
        """
        cpu = torch.device("cpu")
        if not (self.distributed and dist.is_available() and dist.is_initialized()):
            return cpu

        group = self.all_reduce_group()
        if group is None:
            backend = str(dist.get_backend()).lower()
        else:
            try:
                backend = str(dist.get_backend(group=group)).lower()
            except TypeError:
                # Retry positionally when the keyword form is not accepted.
                backend = str(dist.get_backend(group)).lower()
            except (RuntimeError, ValueError):
                return cpu

        if "nccl" in backend and torch.cuda.is_available():
            return self.device
        return cpu

    def all_reduce_group(self):
        """
        Process group used for the runner's collectives.

        Returns:
            `self.ft.replicate_process_group` when `self.ft` is set and that
            group is not `None`; otherwise `None`.
        """
        ft = self.ft
        return None if ft is None else ft.replicate_process_group

    def all_reduce(self, tensor: torch.Tensor, *, op=dist.ReduceOp.SUM) -> torch.Tensor:
        """
        Reduce `tensor` over the runner's replica/data-parallel collective domain.

        When `torch.distributed` is initialized, `dist.all_reduce` is applied to
        `tensor` with the group from `all_reduce_group()`; otherwise nothing
        happens.

        Args:
            tensor: Tensor handed to the collective.
            op: Reduction operator; defaults to `ReduceOp.SUM`.

        Returns:
            The same `tensor` object that was passed in.
        """
        if not dist.is_available():
            return tensor
        if dist.is_initialized():
            dist.all_reduce(tensor, op=op, group=self.all_reduce_group())
        return tensor

    def _sync_optimizer_skip_decision(self, should_skip: bool) -> bool:
        if not (self.distributed and dist.is_available() and dist.is_initialized()):
            return should_skip
        payload = torch.tensor(float(should_skip), device=self.all_reduce_device())
        self.all_reduce(payload, op=dist.ReduceOp.MAX)
        return payload.item() > 0

    def reduce(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Average-reduce tensor over the runner's collective domain.

        The tensor is summed across the group returned by `all_reduce_group()`
        (or the default group when that is `None`) and divided by the group
        size. The collective runs on `all_reduce_device()`; the result is moved
        back to the input's device when the two differ.

        The input tensor is never modified: the in-place collective always runs
        on a private copy.

        Args:
            tensor: Local value to average.

        Returns:
            `tensor` itself when `torch.distributed` is not initialized or the
            group has a single member; otherwise a new tensor holding the
            average.
        """
        if not (dist.is_available() and dist.is_initialized()):
            return tensor
        group = self.all_reduce_group()
        group_size = max(self.world_size if group is None else dist.get_world_size(group=group), 1)
        if group_size <= 1:
            return tensor

        original_device = tensor.device
        payload_device = self.all_reduce_device()
        # `all_reduce` writes into its argument. Reducing the caller's tensor
        # directly would leave it holding the cross-rank sum, so work on a copy
        # (a cross-device `.to()` already produces one).
        if original_device == payload_device:
            payload = tensor.clone()
        else:
            payload = tensor.to(payload_device)
        self.all_reduce(payload, op=dist.ReduceOp.SUM)
        payload = payload / group_size
        if payload.device != original_device:
            payload = payload.to(original_device)
        return payload

    def reduce_loss_for_logging(self, loss: torch.Tensor | None, loss_n: int | None) -> torch.Tensor | None:
        """Detach and all-reduce weighted loss tensor for logging."""
        if loss is None:
            return None
        loss_value = loss.detach().to(dtype=torch.float64)
        if loss_value.ndim > 0:
            loss_value = loss_value.mean()
        normalizer = float(max(int(loss_n or 1), 1))
        payload_device = self.all_reduce_device()
        payload = torch.stack(
            (
                loss_value.to(device=payload_device) * normalizer,
                torch.tensor(normalizer, dtype=torch.float64, device=payload_device),
            )
        )
        self.all_reduce(payload, op=dist.ReduceOp.SUM)
        if payload[1].item() <= 0:
            return None
        return payload[0] / payload[1]

    def _reset_accumulation_normalization(self) -> None:
        """Clear the per-window accumulation state.

        Called at loop start and on every optimizer flush. After reset, the
        next call to `_scaled_loss_for_backward` re-classifies the window mode
        from its first batch's normalizer.
        """
        self._pending_loss_normalizer = None
        self._accumulation_divisor_local = 0.0
        self._accumulation_mode = None

    def _loss_scale_for_backward(self) -> float:
        """Consume the pending loss-normalizer signal and return this micro-step's loss scale."""
        loss_normalizer = self._pending_loss_normalizer
        if self._accumulation_mode is None:
            self._accumulation_mode = (
                "weighted" if loss_normalizer is not None and int(loss_normalizer) > 0 else "uniform"
            )

        if self._accumulation_mode == "weighted":
            if loss_normalizer is None or int(loss_normalizer) <= 0:
                raise ValueError(
                    "loss normalizer became unavailable within the current accumulation window. "
                    "Override `train_step()` or provide consistent batch metadata for weighted normalization."
                )
            normalizer = float(int(loss_normalizer))
            self._accumulation_divisor_local += normalizer
            self._pending_loss_normalizer = None
            return normalizer

        self._accumulation_divisor_local += 1.0
        self._pending_loss_normalizer = None
        return 1.0

    def _scaled_loss_for_backward(self, loss: torch.Tensor) -> torch.Tensor:
        """Scale and accumulate loss for one micro-step inside an accumulation window.

        Accumulation contract (window-local; reset on optimizer flush):

        1. **Mode detection.** First micro-step decides the window mode from
           ``self._pending_loss_normalizer``:
              - non-empty positive normalizer → ``"weighted"``
              - ``None`` or non-positive → ``"uniform"``
        2. **Mode is sticky.** Once the window picks ``"weighted"``, every
           subsequent micro-step in that window MUST also publish a positive
           normalizer; missing one raises with guidance to override
           `train_step()` or homogenize batch metadata. A ``"uniform"`` window
           ignores per-batch normalizers entirely.
        3. **Producer/consumer.** ``train_epoch`` / ``train_steps`` set
           ``self._pending_loss_normalizer`` from ``_get_loss_normalizer(data)``
           before each ``train_step`` call; this method consumes and clears it.
        4. **Override safety.** Subclasses overriding ``train_step`` are
           responsible for keeping the normalizer signal consistent across the
           window — either always present (weighted mode) or always absent
           (uniform mode). Mixing within one window is a programmer error.

        See `_reset_accumulation_normalization` for window boundaries and
        `_gradient_scale_for_step` for the optimizer-side rescale.
        """
        return loss * self._loss_scale_for_backward()

    def _gradient_scale_for_step(self) -> float | None:
        if self._accumulation_divisor_local <= 0:
            return None
        total = self._reduce_loss_normalizer_total(self._accumulation_divisor_local)
        if total <= 0:
            return None
        return float(max(self._loss_normalizer_sync_divisor(), 1)) / total

    def _optimizer_parameters_for_scaling(self) -> list[nn.Parameter]:
        if self.optimizer is None:
            return []
        if self.optimizer_container is not None:
            return self.optimizer_container.parameter_cache.get_parameters_for_clipping(self.optimizer)

        parameter_cache = self._optimizer_parameter_cache
        if parameter_cache is None:
            parameter_cache = OptimizerParameterCache(self.optimizer)
            self._optimizer_parameter_cache = parameter_cache
        else:
            parameter_cache.bind(self.optimizer)
        return parameter_cache.get_parameters_for_clipping()

    def _scale_optimizer_gradients(self, scale: float) -> None:
        if scale == 1.0 or self.optimizer is None:
            return
        parameters = self._optimizer_parameters_for_scaling()
        for parameter in parameters:
            grad = parameter.grad
            if grad is None:
                continue
            grad.mul_(float(scale))

    @contextmanager
    def train_context(self):
        """Wrap one training micro-step: autocast plus, where applicable, DDP no_sync."""
        targets = self._train_no_sync_targets()
        step_context = self._train_step_context(no_sync_targets=targets)
        with step_context:
            yield

    def _should_train_no_sync(self) -> bool:
        if self._train_window_will_flush:
            return False
        micro_steps = self.train_state.micro_step + 1
        return self.accum_steps > 1 and micro_steps % self.accum_steps != 0

    def forward_context(self):
        """
        Build the precision context for train/eval/infer forward passes.

        Returns:
            `self.fp8_autocast()` when FP8 is enabled; a `nullcontext` when no
            precision is configured; otherwise `torch.autocast` for the runner's
            device type with the dtype resolved from `self.precision`.
        """
        if self.fp8_enabled:
            return self.fp8_autocast()

        precision = self.precision
        if precision is not None:
            return torch.autocast(self.device.type, dtype=get_precision(precision))
        return nullcontext()

    def _train_no_sync_targets(self) -> tuple[nn.Module, ...]:
        if isinstance(self.model, nn.parallel.DistributedDataParallel):
            return (self.model,)
        return ()

    @contextmanager
    def _train_step_context(self, *, no_sync_targets: tuple[nn.Module, ...] | list[nn.Module] = ()):
        autocast_context = self.forward_context()
        if self._should_train_no_sync() and no_sync_targets:
            with ExitStack() as stack:
                stack.enter_context(autocast_context)
                for module in no_sync_targets:
                    no_sync = getattr(module, "no_sync", None)
                    if callable(no_sync):
                        stack.enter_context(no_sync())
                        continue

                    set_requires_gradient_sync = getattr(module, "set_requires_gradient_sync", None)
                    if callable(set_requires_gradient_sync):
                        set_requires_gradient_sync(False)
                        stack.callback(set_requires_gradient_sync, True)
                        continue

                    raise TypeError(
                        "cannot disable gradient synchronization for "
                        f"{type(module).__name__}: expected `no_sync()` or `set_requires_gradient_sync(...)`"
                    )
                yield
            return

        with autocast_context:
            yield

    def train_step(self, data: Any) -> tuple[Any, torch.Tensor | None]:
        """
        Run one training micro-step.

        The default implementation runs forward → loss → metric update → backward
        → step for one micro-batch.

        **Called when:** once per micro-batch by `train_epoch`/`train_steps`. The
        caller seeds the loop's accumulation state before each invocation; this
        method consumes that state through `backward()` and `step()`.

        **Precondition:** `self.model`, `self.optimizer`, and `self.criterion`
        are bound; `self.mode == RunnerMode.train`.

        Args:
            data: One micro-batch. The default unpacks `data["input"]` /
                `data.get("target")` for mappings, `(data[0], data[1])` for
                non-string sequences, and `(data, None)` otherwise. Override
                `train_step` if your batch shape differs.

        Returns:
            `(pred, loss)`. `pred` is the model output (used by `metrics.update`).
            `loss` is the scalar loss returned to the caller for reduced logging.
            The default raises when `criterion` is missing or returns `None`;
            overrides may return `(pred, None)` to signal no loss available, in
            which case the caller skips loss bookkeeping.

        Raises:
            ValueError: `self.model` is not initialized, or `criterion` is missing
                or returned `None`.

        **Side effects:** moves `data` to `self.device`, runs forward under
        `train_context()` (autocast + optional DDP no-sync), updates
        `self.metrics` when bound, then calls `self.backward(loss)` and
        `self.step()` to scale gradients, advance accumulation state, and flush
        the optimizer on accumulation boundaries.

        !!! danger "Do not"
            - Zero gradients (`optimizer_step` does this on flush).
            - Call `self.optimizer.step()` directly (use `self.step()`).
            - Mutate `train_state.global_step` or `train_state.micro_step`.
            - Implement gradient scaling here (override `backward()` instead).
            - Call `save_checkpoint()` (cadence is owned by the loop method).

        **Backend notes:**

        - `DeepSpeedRunner` inherits the default; `backward`/`step` route
          through the DeepSpeed engine.
        - `ParallelRunner` overrides this method when a pipeline schedule is
          set; the schedule owns micro-batching and loss reduction.
        """
        data = self.to_device(data)
        with self.train_context():
            if isinstance(data, Mapping):
                inputs = data["input"]
                target = data.get("target")
            elif isinstance(data, Sequence) and not isinstance(data, (str, bytes)):
                inputs = data[0]
                target = data[1] if len(data) > 1 else None
            else:
                inputs = data
                target = None

            if self.model is None:
                raise ValueError("cannot run train_step: model is not initialized")
            pred = self.model(**inputs) if isinstance(inputs, Mapping) else self.model(inputs)
            loss = self.criterion(pred, target) if self.criterion is not None else None
            if loss is None:
                raise ValueError("cannot run train_step: criterion did not produce a loss")
            if self.metrics is not None and pred is not None and target is not None:
                self.metrics.update(pred, target)
            self.backward(loss)
            self.step()
        return pred, loss

    def backward(self, loss: torch.Tensor) -> None:
        """
        Backpropagate the loss of one micro-step.

        **Called when:** the default `train_step` has computed a loss tensor.
        The raw micro-step loss is passed in; accumulation scaling and
        loss-normalizer weighting are applied here, just ahead of
        `Tensor.backward()`.

        Args:
            loss: Loss tensor of the current micro-step.

        **Side effects:** gradients accumulate on the model parameters.

        !!! danger "Do not"
            - Step the optimizer from here; that is the job of
              `step()`/`optimizer_step()`.
            - Modify the `train_state` counters.

        **Backend notes:**

        - `DeepSpeedRunner` overrides this hook and calls the DeepSpeed
          engine's backward method.
        """
        scaled_loss = self._scaled_loss_for_backward(loss)
        scaled_loss.backward()

    def step(self) -> None:
        """
        Advance the accumulation state machine after one training micro-step.

        **Called when:** `train_step` finishes backward for a micro-batch.

        **Side effects:** increments `train_state.micro_step` and calls
        `optimizer_step()` only when the accumulation boundary is reached or
        the surrounding loop marks the current batch as the final flush in a
        partial window. After a partial-window flush, `micro_step` is realigned
        to the next multiple of `accum_steps` so the following window starts
        full.

        When `accum_steps <= 1` there is no accumulation: every micro-step
        triggers `optimizer_step()` and no modulo arithmetic is performed, so
        a zero or negative `accum_steps` cannot raise `ZeroDivisionError`.

        !!! danger "Do not"
            - Call this from evaluation/inference paths.
            - Call `optimizer_step()` in addition to this method from the same
              micro-step.
            - Adjust `train_state.micro_step` in `train_step` overrides.
        """
        micro_steps = self.train_state.micro_step + 1
        self.train_state.micro_step = micro_steps
        accum_steps = self.accum_steps
        if accum_steps <= 1:
            # No accumulation window: each micro-step is its own optimizer
            # step. Guarding first keeps `% accum_steps` away from zero.
            self.optimizer_step()
            return
        remainder = micro_steps % accum_steps
        if self._train_window_will_flush:
            self.optimizer_step()
            if remainder != 0:
                # Partial window was flushed early; skip ahead to the boundary.
                self.train_state.micro_step += accum_steps - remainder
            return
        if remainder == 0:
            self.optimizer_step()

    def optimizer_step(self) -> bool:
        """
        Apply a single optimizer update through the active backend.

        The default Torch path waits for checkpoint staging, rescales
        gradients for the accumulated loss, and then either drives
        `OptimizerContainer` (optional grad clipping, non-finite grad skip,
        optimizer/scheduler stepping, gradient zeroing) or falls back to the
        bare optimizer. A successful update also advances the profiler and
        the garbage-collection cadence.

        **Called when:** `step()` reaches an accumulation boundary, or
        `_flush_pending_optimizer_step()` flushes a partial boundary before
        shutdown.

        Returns:
            `True` when an optimizer update is applied, otherwise `False`.

        Raises:
            ValueError: neither `optimizer_container` nor `optimizer` is set.

        **Side effects:** may update optimizer/scheduler state; increments
        `train_state.global_step` only when an update is actually applied.
        Accumulation normalization is reset on every exit path.

        !!! danger "Do not"
            - Increment `global_step` on skipped updates.
            - Forget to zero gradients after a successful update or skipped
              non-finite update.
            - Bypass `checkpoint_manager.maybe_wait_for_staging()`.

        **Backend notes:**

        - `DeepSpeedRunner` overrides this hook because the DeepSpeed engine
          owns the concrete optimizer update.
        """
        if self.optimizer_container is None and self.optimizer is None:
            raise ValueError(
                "cannot perform optimizer step: no optimizer is configured; "
                "set `self.optimizer`, implement `build_optimizer()`, or override `optimizer_step()`"
            )

        self.checkpoint_manager.maybe_wait_for_staging()
        scale = self._gradient_scale_for_step()
        if scale is not None:
            self._scale_optimizer_gradients(scale)
        clip_value, clip_norm = self.max_grad_value, self.max_grad_norm
        skip_on_nonfinite = self.skip_nonfinite_grad

        # `applied` turns False on either skip path; both paths share the
        # normalization reset that follows the branch below.
        applied = True
        if self.optimizer_container is not None:
            if skip_on_nonfinite:
                found_nonfinite = self._sync_optimizer_skip_decision(self.optimizer_container.has_nan_inf_grad())
                if found_nonfinite:
                    # Drop the poisoned gradients instead of stepping.
                    self.optimizer_container.zero_grad()
                    applied = False
            # The non-finite check was already done (and synced) above, so the
            # container is told not to repeat it.
            if applied and not self.optimizer_container.step(
                max_grad_value=clip_value,
                max_grad_norm=clip_norm,
                zero_grad=True,
                skip_nonfinite_grad=False,
            ):
                applied = False
        elif self.optimizer is not None:
            self.optimizer.step()
            self.optimizer.zero_grad()

        self._reset_accumulation_normalization()
        if not applied:
            return False
        self.train_state.global_step += 1
        self._step_profiler()
        self._maybe_reduce_train_process_group_timeout()
        self.supervisor.maybe_collect_garbage(self.train_state.global_step, scope="train")
        return True

    def _flush_pending_optimizer_step(self) -> bool:
        """
        Apply the gradients of an unfinished accumulation window at a loop boundary.

        Nothing happens when accumulation is disabled (`accum_steps <= 1`) or
        when `train_state.micro_step` already sits on a window boundary. In
        distributed runs with no-sync targets the partial window is discarded
        rather than stepped.

        Returns:
            `True` when a boundary flush produced an optimizer update.
        """
        pending = self.train_state.micro_step % self.accum_steps if self.accum_steps > 1 else 0
        if pending == 0:
            return False
        if self.distributed and self._train_no_sync_targets():
            self._discard_pending_optimizer_step(pending)
            return False
        did_step = self.optimizer_step()
        # The flush consumed the partial window: jump to the next accumulation
        # boundary so the following loop begins with a complete window.
        self.train_state.micro_step += self.accum_steps - pending
        return did_step

    def _discard_pending_optimizer_step(self, remainder: int | None = None) -> None:
        if self.accum_steps <= 1:
            return
        if remainder is None:
            remainder = self.train_state.micro_step % self.accum_steps
        if remainder == 0:
            return
        if self.optimizer_container is not None:
            self.optimizer_container.zero_grad()
        elif self.optimizer is not None:
            self.optimizer.zero_grad()
        self._reset_accumulation_normalization()
        self.train_state.micro_step -= remainder

    def prepare_for_shutdown_checkpoint(self) -> None:
        """
        Flush any partially accumulated optimizer step before a shutdown checkpoint.

        Delegates to `_flush_pending_optimizer_step()`; whether that flush
        produced an update is not reported to the caller.
        """
        self._flush_pending_optimizer_step()
        return None

    def _iter_train_batches(self, loader: Any) -> Iterator[tuple[int, Any, bool]]:
        """
        Enumerate `loader` with a one-item look-ahead.

        Yields:
            `(iteration, data, will_flush)` triples. `will_flush` is `True`
            when the upcoming micro-step lands on an accumulation boundary
            (always the case when `accum_steps <= 1`) or when `data` is the
            final batch the loader produces.

        The boundary check reads `train_state.micro_step` lazily, right before
        each yield, so it reflects steps taken by the consumer between batches.
        The next batch is fetched from the loader before the current one is
        yielded.
        """
        exhausted = object()
        stream = enumerate(loader)
        pending = next(stream, exhausted)
        while pending is not exhausted:
            # Pull the following item first so the last batch can be detected.
            upcoming = next(stream, exhausted)
            index, batch = pending
            upcoming_micro_step = self.train_state.micro_step + 1
            if self.accum_steps <= 1:
                at_boundary = True
            else:
                at_boundary = upcoming_micro_step % self.accum_steps == 0
            yield index, batch, at_boundary or upcoming is exhausted
            pending = upcoming

    def _resolve_requested_splits(
        self,
        requested_splits: list[str] | None,
        available_splits: list[str],
        *,
        kind: str,
    ) -> list[str]:
        if requested_splits is None:
            return available_splits

        splits = self._sorted_unique(requested_splits)
        unknown_splits = sorted(set(splits).difference(available_splits))
        if unknown_splits:
            raise ValueError(
                f"unknown {kind} split(s): {unknown_splits}; " f"available {kind} split(s): {available_splits}"
            )
        return splits

    def train(
        self,
        train_splits: list[str] | None = None,
        evaluate_splits: list[str] | None = None,
    ) -> RoundDict:
        """
        Run the full training workflow.

        Selects epoch mode or step mode from `self.is_step_mode`, validates
        explicit split lists against the runner's configured/inferred splits,
        and delegates to `train_epochs` or `train_steps`.

        **Called when:** user code starts training.

        Args:
            train_splits: Optional training splits. When `None`, use `self.train_splits`.
            evaluate_splits: Optional evaluation splits. When `None`, use `self.evaluate_splits`.

        Returns:
            Aggregated runner results (`self.results`).

        Raises:
            ValueError: no valid training split can be resolved.

        **Side effects:** prints selected splits and runs the selected training
        loop. Checkpointing, result writing, scheduler stepping, and early stop
        are owned by the delegated loop method.
        """

        train_splits = self._resolve_requested_splits(train_splits, self.train_splits, kind="training")
        if not train_splits:
            raise ValueError("cannot start training: no valid training split was resolved")

        evaluate_splits = self._resolve_requested_splits(evaluate_splits, self.evaluate_splits, kind="evaluation")

        print(f"train: splits={train_splits}")
        print(f"evaluate: splits={evaluate_splits}")
        if self.is_step_mode:
            return self.train_steps(train_splits=train_splits, evaluate_splits=evaluate_splits)
        return self.train_epochs(train_splits=train_splits, evaluate_splits=evaluate_splits)

    def train_epochs(
        self,
        train_splits: list[str] | None = None,
        evaluate_splits: list[str] | None = None,
    ) -> RoundDict:
        """
        Run epoch-mode training until `self.epochs` is reached.

        Each epoch runs all train splits, then all evaluation splits, advances
        epoch/metric schedulers, appends and writes results, and saves periodic
        checkpoints.

        **Called when:** `train` dispatches while `config.epochs` is set, or
        user code explicitly wants epoch-mode semantics.

        Args:
            train_splits: Training splits for each epoch.
            evaluate_splits: Evaluation splits after each epoch.

        Returns:
            Aggregated runner results (`self.results`).

        Raises:
            ValueError: `config.epochs` is not set.
        """
        if train_splits is None:
            train_splits = self.train_splits
        if evaluate_splits is None:
            evaluate_splits = self.evaluate_splits

        total_epochs = self.epochs
        if total_epochs is None:
            raise ValueError("cannot run epoch-mode training: config.epochs is not set")
        print(f"train: epoch mode start epoch={self.train_state.epoch} total_epochs={total_epochs}")
        checkpoint_cadence = self.checkpoint_interval
        early_stop_counter = 0
        patience = self.patience
        for epoch in range(self.train_state.epoch, total_epochs):
            self.supervisor.maybe_handle_termination_signal()
            self.train_state.epoch = epoch
            result = RoundDict()
            for split in train_splits:
                result[split] = self.train_epoch(split)
                self.supervisor.maybe_handle_termination_signal()
            for split in evaluate_splits:
                result[split] = self.evaluate_epoch(split)
                self.supervisor.maybe_handle_termination_signal()
            self._step_epoch_scheduler(result)
            self.append_result(result, index=epoch)
            print(self.format_epoch_result(result, epochs=epoch, total_epochs=total_epochs))
            self.save_result()
            self.train_state.epoch = epoch + 1
            if checkpoint_cadence > 0 and self.train_state.epoch % checkpoint_cadence == 0:
                self.save_checkpoint(epochs=epoch)
            early_stop_counter = 0 if self.is_best else early_stop_counter + 1
            if early_stop_counter > patience:
                print("train: early-stop triggered")
                break
        self.save_checkpoint(last_step=True)
        return self.results

    def train_epoch(self, split: str = "train") -> RoundDict:
        """
        Run one full dataloader pass for a training split.

        This is the per-split epoch loop. It sets train mode, resets meters and
        train metrics, manages accumulation-window normalization, invokes
        `train_step` for each micro-batch, emits step logs, and records
        interval/epoch telemetry.

        **Called when:** `train_epochs` processes one train split.

        Args:
            split: Training split name.

        Returns:
            Epoch-level metric mapping for this split.

        **Side effects:** updates optimizer state through `train_step`,
        advances `train_state.global_step` on optimizer flushes, writes step
        logs, and may save step-cadence checkpoints.

        !!! danger "Do not"
            - Call this for evaluation data; use `evaluate_epoch`.
            - Override this just to change one batch's forward/loss logic;
              override `train_step`.
            - Manually manage gradient zeroing inside `train_step`; this loop
              and `optimizer_step` own accumulation boundaries.
            - Increment `train_state.epoch`; the surrounding `train_epochs`
              loop owns epoch progress.
            - Save result or checkpoint aliases here; `train_epochs` owns
              epoch-level persistence.

        See Also:
            [`train_steps`][danling.runners.TorchRunner.train_steps]:
                Step-mode counterpart that consumes splits against a global
                step budget instead of one epoch per split.
        """
        loader = self.dataloaders[split]
        loader_length = self._loader_length(loader)
        # 0-based index of the final batch, or None when the loader length is
        # unknown; compared against `iteration` to force a log on the last batch.
        length = loader_length - 1 if loader_length is not None else None
        # Remember the most recent loss so a trailing log can still be emitted
        # after the loop when the loader length is unknown.
        last_loss: torch.Tensor | None = None
        last_loss_n: int | None = None
        self._set_loader_epoch(loader, self.train_state.epoch)
        self.mode = RunnerMode.train
        self.split = split
        self.meters.reset()
        self.metrics = self.train_metrics
        if self.metrics is not None:
            self.metrics.reset()
        telemetry = LoopTelemetry(self, start_time=self.loop_time())
        # Begin the pass with clean accumulation state and zeroed gradients.
        self._reset_accumulation_normalization()
        if self.optimizer_container is not None:
            self.optimizer_container.zero_grad()
        elif self.optimizer is not None:
            self.optimizer.zero_grad()
        checkpoint_cadence = self.checkpoint_interval

        for iteration, data, will_flush in self._iter_train_batches(loader):
            self.supervisor.maybe_handle_termination_signal()
            # Snapshot so we can tell afterwards whether this micro-batch
            # produced an optimizer update.
            step_before = self.train_state.global_step
            # Positive int = weighted-loss signal; None = no signal (uniform window).
            # 0 or missing collapses to None so the accumulation state machine
            # picks "uniform" cleanly instead of being silently coerced to 1.
            loss_n = self._get_loss_normalizer(data)
            if loss_n is not None and loss_n <= 0:
                loss_n = None
            # These per-batch hints live on the instance only while train_step
            # runs; the finally block clears them even if the step raises.
            self._pending_loss_normalizer = loss_n
            self._train_window_will_flush = will_flush
            try:
                _, loss = self.train_step(data)
            finally:
                self._train_window_will_flush = False
                self._pending_loss_normalizer = None

            self.supervisor.mark_heartbeat_progress()
            self.supervisor.maybe_handle_termination_signal()
            current_time = self.loop_time()
            if self.scheduler is not None and hasattr(self.scheduler, "get_last_lr"):
                # Only the first entry of get_last_lr() is tracked.
                self.meters.lr.update(self.scheduler.get_last_lr()[0])
            if loss is not None:
                # `loss_n or 1` weights a missing normalizer as a single-sample meter update;
                # criteria that emit a real loss for zero-valid-token batches are not supported here.
                self.meters.loss.update(loss.detach(), n=loss_n or 1)
            telemetry.observe(iteration=iteration, data=data, current_time=current_time)

            step_after = self.train_state.global_step
            # Checkpoint only when global_step actually advanced on this
            # micro-batch and the new value is a multiple of the cadence.
            if checkpoint_cadence > 0 and step_after != step_before and step_after % checkpoint_cadence == 0:
                self.save_checkpoint()

            # Log every `log_interval` iterations (never at iteration 0) and
            # on the final batch when the loader length is known.
            if self.log_interval > 0 and (
                (iteration > 0 and iteration % self.log_interval == 0) or iteration == length
            ):
                telemetry.emit_log(split=split, iteration=iteration, length=length, loss=loss, loss_n=loss_n)
            last_loss = loss
            last_loss_n = loss_n

        # With an unknown loader length the in-loop "final batch" test can
        # never match, so emit the closing log here unless the last observed
        # iteration is the one telemetry already printed.
        if (
            length is None
            and self.log_interval > 0
            and telemetry.last_iteration is not None
            and telemetry.last_iteration != telemetry.last_print_iteration
        ):
            # Redundant with the condition above; presumably kept for static
            # type narrowing.
            assert telemetry.last_iteration is not None
            telemetry.emit_log(
                split=split,
                iteration=telemetry.last_iteration,
                length=length,
                loss=last_loss,
                loss_n=last_loss_n,
                reset_peak_stats=False,
            )
        result = self.get_epoch_result()
        telemetry.finalize_result(result, elapsed_seconds=self.loop_time(sync=True) - telemetry.start_time)
        return result

    def train_steps(
        self,
        train_splits: list[str] | None = None,
        evaluate_splits: list[str] | None = None,
    ) -> RoundDict:
        """
        Run step-mode training for the configured global step budget.

        Step mode consumes train splits in sorted split order until
        `train_state.global_step >= self.steps`, then optionally evaluates
        configured evaluation splits with `evaluate_steps`.

        **Called when:** `train` dispatches while `config.epochs` is unset, or
        user code explicitly wants a global-step budget.

        Args:
            train_splits: Training splits to consume in order.
            evaluate_splits: Evaluation splits to run after training steps finish.

        Returns:
            Aggregated runner results (`self.results`).

        Raises:
            ValueError: total step budget cannot be resolved.

        **Side effects:** updates epoch as an outer split-round counter,
        appends one result row indexed by `global_step`, writes result files,
        and saves the final checkpoint.

        !!! danger "Do not"
            - Assume a split is consumed exactly once; step mode can resume a
              split iterator across outer rounds.
            - Mutate `train_state.global_step` outside optimizer stepping.

        See Also:
            [`train_epoch`][danling.runners.TorchRunner.train_epoch]:
                Per-split epoch loop used by epoch-mode training.
        """
        if train_splits is None:
            train_splits = self.train_splits
        if evaluate_splits is None:
            evaluate_splits = self.evaluate_splits

        total_steps = self.steps
        if total_steps is None:
            raise ValueError("cannot run step-mode training: config.steps could not be resolved")
        print(f"train: step mode start global_step={self.train_state.global_step} steps={total_steps}")
        result = RoundDict()
        # One lazily created batch iterator per split (initially None), kept
        # across outer rounds so a split resumes where it previously stopped.
        step_mode_iterators: dict[str, Iterator[tuple[int, Any, bool]] | None] = dict.fromkeys(train_splits)
        # Epoch value handed to `_set_loader_epoch` each time a split's
        # iterator is (re)created; bumped whenever that iterator is exhausted.
        step_mode_sampler_epochs = {split: self.train_state.epoch for split in train_splits}
        while self.train_state.global_step < total_steps:
            self.supervisor.maybe_handle_termination_signal()
            # Used after the round to detect a pass that applied no update.
            round_start_step = self.train_state.global_step
            round_result = RoundDict()
            total_train_splits = len(train_splits)
            for split_index, split in enumerate(train_splits):
                self.supervisor.maybe_handle_termination_signal()
                self.mode = RunnerMode.train
                self.split = split
                remaining = total_steps - self.train_state.global_step
                if remaining <= 0:
                    break
                loader = self.dataloaders[split]
                # Counts the current split plus all splits after it in this round.
                remaining_splits = total_train_splits - split_index
                split_steps = self._step_mode_split_budget(
                    remaining_steps=remaining,
                    remaining_splits=remaining_splits,
                    loader=loader,
                )
                if split_steps <= 0:
                    break
                start_global_step = self.train_state.global_step
                target_global_step = start_global_step + split_steps
                # 0-based index of the last optimizer step budgeted for this
                # split (split_steps - 1, clamped at 0).
                length = max(target_global_step - self.train_state.global_step - 1, 0)
                self.meters.reset()
                self.metrics = self.train_metrics
                if self.metrics is not None:
                    self.metrics.reset()
                telemetry = LoopTelemetry(self, start_time=self.loop_time())
                # Begin the split with clean accumulation state and zeroed gradients.
                self._reset_accumulation_normalization()
                if self.optimizer_container is not None:
                    self.optimizer_container.zero_grad()
                elif self.optimizer is not None:
                    self.optimizer.zero_grad()
                checkpoint_cadence = self.checkpoint_interval
                # Counts micro-batches consumed for this split in this round;
                # starts at -1 so the first batch is iteration 0.
                batch_iteration = -1

                while self.train_state.global_step < target_global_step:
                    batch: tuple[int, Any, bool] | None = None
                    iterator = step_mode_iterators[split]
                    recreated = False
                    # Fetch the next batch, recreating an exhausted iterator at
                    # most once so an empty loader cannot spin forever.
                    while True:
                        if iterator is None:
                            if recreated:
                                break
                            self._set_loader_epoch(loader, step_mode_sampler_epochs[split])
                            iterator = self._iter_train_batches(loader)
                            step_mode_iterators[split] = iterator
                            recreated = True
                        try:
                            batch = next(iterator)
                            break
                        except StopIteration:
                            iterator = None
                            step_mode_iterators[split] = None
                            step_mode_sampler_epochs[split] += 1
                    if batch is None:
                        # A freshly created iterator yielded nothing.
                        break
                    # The loader-level index from the iterator is ignored;
                    # `batch_iteration` is the per-round counter used instead.
                    _, data, will_flush = batch
                    batch_iteration += 1
                    self.supervisor.maybe_handle_termination_signal()
                    step_before = self.train_state.global_step
                    # See `train_epoch` for normalizer semantics.
                    loss_n = self._get_loss_normalizer(data)
                    if loss_n is not None and loss_n <= 0:
                        loss_n = None
                    # Per-batch hints are exposed on the instance only while
                    # train_step runs; the finally block always clears them.
                    self._pending_loss_normalizer = loss_n
                    self._train_window_will_flush = will_flush
                    try:
                        _, loss = self.train_step(data)
                    finally:
                        self._train_window_will_flush = False
                        self._pending_loss_normalizer = None

                    self.supervisor.mark_heartbeat_progress()
                    self.supervisor.maybe_handle_termination_signal()
                    current_time = self.loop_time()
                    if self.scheduler is not None and hasattr(self.scheduler, "get_last_lr"):
                        self.meters.lr.update(self.scheduler.get_last_lr()[0])
                    if loss is not None:
                        self.meters.loss.update(loss.detach(), n=loss_n or 1)
                    telemetry.observe(iteration=batch_iteration, data=data, current_time=current_time)

                    step_after = self.train_state.global_step
                    # Checkpoint only when this micro-batch advanced global_step
                    # onto a multiple of the cadence.
                    if checkpoint_cadence > 0 and step_after != step_before and step_after % checkpoint_cadence == 0:
                        self.save_checkpoint()

                    # 0-based optimizer-step index within this split, or None
                    # when this micro-batch did not advance global_step.
                    step_iteration = step_after - start_global_step - 1 if step_after != step_before else None
                    # Logging is keyed on optimizer steps, not micro-batches.
                    if (
                        self.log_interval > 0
                        and step_iteration is not None
                        and (
                            (step_iteration > 0 and step_iteration % self.log_interval == 0) or step_iteration == length
                        )
                    ):
                        telemetry.emit_log(
                            split=split,
                            iteration=batch_iteration,
                            length=length,
                            loss=loss,
                            loss_n=loss_n,
                            display_iteration=step_iteration,
                        )

                round_result[split] = self.get_epoch_result()
                telemetry.finalize_result(
                    round_result[split], elapsed_seconds=self.loop_time(sync=True) - telemetry.start_time
                )
                self.supervisor.maybe_handle_termination_signal()

            # A whole round without a single optimizer update: warn and stop
            # instead of looping forever.
            if self.train_state.global_step == round_start_step:
                remaining_steps = total_steps - self.train_state.global_step
                warn(
                    f"step-mode training made no progress after one full split pass "
                    f"(target={total_steps}, reached={self.train_state.global_step}, remaining={remaining_steps})",
                    RuntimeWarning,
                    stacklevel=2,
                )
                break
            self._step_epoch_scheduler(round_result)
            # Only the most recent productive round is kept for the final row.
            result = round_result
            self.train_state.epoch += 1
        remaining_steps = total_steps - self.train_state.global_step
        if remaining_steps > 0:
            warn(
                f"step-mode training finished with {remaining_steps} step(s) remaining "
                f"(target={total_steps}, reached={self.train_state.global_step})",
                RuntimeWarning,
                stacklevel=2,
            )
        for split in evaluate_splits:
            result[split] = self.evaluate_steps(split=split)
        self.append_result(result, index=self.train_state.global_step)
        print(f"train: step mode result={result}")
        self.save_result()
        self.save_checkpoint(last_step=True)
        return self.results

    def evaluate_step(self, data: Any) -> tuple[Any, torch.Tensor | None]:
        """
        Run one evaluation micro-step.

        The default implementation runs forward → optional loss → optional
        metric update under `forward_context()`. No backward pass and no
        optimizer step.

        **Called when:** once per micro-batch by `evaluate_epoch`/`evaluate_steps`,
        which run under `torch.inference_mode()`.

        **Precondition:** at least one of `self.model` or `self.ema` is bound.
        `self.mode == RunnerMode.evaluate`. The default prefers `self.ema` over
        `self.model` when both are available.

        Args:
            data: One micro-batch. The default unpacks `data["input"]` /
                `data.get("target")` for mappings, `(data[0], data[1])` for
                non-string sequences, and `(data, None)` otherwise. Override
                `evaluate_step` if your batch shape differs.

        Returns:
            `(pred, loss)`. `pred` is the model output (used by `metrics.update`).
            `loss` is the scalar loss returned to the caller for reduced
            logging, or `None` when no `criterion` is set.

        Raises:
            ValueError: neither `self.model` nor `self.ema` is initialized.

        **Side effects:** moves `data` to `self.device`, runs forward through
        `self.ema or self.model` under `forward_context()`, computes loss when
        `criterion` is set, and updates `self.metrics` when bound.

        !!! danger "Do not"
            - Call `self.backward(...)` or `self.step()` (no optimizer here).
            - Mutate `train_state.global_step` or `train_state.micro_step`.
            - Switch the runner mode (the loop owns `self.mode`).
            - Call `save_checkpoint()` (cadence is owned by training loops only).

        **Backend notes:**

        - `ParallelRunner` overrides this method when a pipeline schedule is
          set; the schedule owns micro-batching and pipeline-stage loss
          reduction.
        """
        data = self.to_device(data)
        if isinstance(data, Mapping):
            inputs = data["input"]
            target = data.get("target")
        elif isinstance(data, Sequence) and not isinstance(data, (str, bytes)):
            inputs = data[0]
            target = data[1] if len(data) > 1 else None
        else:
            inputs = data
            target = None

        if self.model is None and self.ema is None:
            raise ValueError("cannot run evaluate_step: model is not initialized")
        model = self.ema or self.model
        with self.forward_context():
            pred = model(**inputs) if isinstance(inputs, Mapping) else model(inputs)
            loss = self.criterion(pred, target) if self.criterion is not None else None

        if self.metrics is not None and pred is not None and target is not None:
            self.metrics.update(pred, target)

        return pred, loss

    def evaluate(self, evaluate_splits: list[str] | None = None) -> RoundDict:
        """
        Run evaluation across splits with epoch-mode semantics.

        **Called when:** user code explicitly evaluates a runner, or training
        code delegates to evaluation helpers.

        Args:
            evaluate_splits: Optional evaluation splits. When `None`, use `self.evaluate_splits`.

        Returns:
            Mapping of split -> evaluation result for this call.

        Raises:
            ValueError: no valid evaluation split can be resolved.

        **Side effects:** sets evaluation mode per split, prints a formatted
        aggregate result, and writes scalar outputs through `evaluate_epoch`.
        """

        evaluate_splits = self._resolve_requested_splits(evaluate_splits, self.evaluate_splits, kind="evaluation")
        if not evaluate_splits:
            raise ValueError("cannot start evaluation: no valid evaluation split was resolved")
        print("evaluate: start")
        print(f"evaluate: splits={evaluate_splits}")
        result = RoundDict()
        for split in evaluate_splits:
            result[split] = self.evaluate_epoch(split=split)
        display_epoch = self.train_state.epoch
        if self.epochs is not None and display_epoch > 0:
            display_epoch -= 1
        print(self.format_epoch_result(result, epochs=display_epoch))
        return result

    @torch.inference_mode()
    def evaluate_epoch(self, split: str = "val") -> RoundDict:
        """
        Run one full dataloader pass for an evaluation split.

        Sets evaluation mode, resets meters/evaluation metrics, runs
        `evaluate_step` for every batch under inference mode, emits step logs,
        and writes the split result at the current epoch index.

        **Called when:** `evaluate` or `train_epochs` evaluates a split.

        Args:
            split: Evaluation split name.

        Returns:
            Epoch-level metric mapping for this split.

        **Side effects:** updates evaluation meters/metrics, emits logs, writes
        scalar results, and records telemetry. It does not update optimizer or
        training progress counters.
        """
        loader = self.dataloaders[split]
        loader_length = self._loader_length(loader)
        # 0-based index of the final batch, or None when the loader length is
        # unknown; compared against `iteration` to force a log on the last batch.
        length = loader_length - 1 if loader_length is not None else None

        # Remember the most recent loss so a trailing log can still be emitted
        # after the loop when the loader length is unknown.
        last_loss: torch.Tensor | None = None
        last_loss_n: int | None = None
        self.mode = RunnerMode.evaluate
        self.split = split
        self.meters.reset()
        self.metrics = self.evaluate_metrics
        if self.metrics is not None:
            self.metrics.reset()
        telemetry = LoopTelemetry(self, start_time=self.loop_time())
        # Number of batches actually consumed; stays 0 for an empty loader.
        consumed = 0
        for iteration, data in enumerate(loader):
            consumed = iteration + 1
            self.supervisor.maybe_handle_termination_signal()
            # Non-positive normalizers collapse to None; the loss meter below
            # then weights the batch as a single sample (`loss_n or 1`).
            loss_n = self._get_loss_normalizer(data)
            if loss_n is not None and loss_n <= 0:
                loss_n = None
            _, loss = self.evaluate_step(data)
            self.supervisor.mark_heartbeat_progress()
            self.supervisor.maybe_handle_termination_signal()
            current_time = self.loop_time()
            if loss is not None:
                self.meters.loss.update(loss.detach(), n=loss_n or 1)
            telemetry.observe(
                iteration=iteration,
                data=data,
                current_time=current_time,
            )
            # Evaluation has no optimizer step to hang GC cadence on, so it is
            # driven by the 1-based batch count under a per-split scope.
            self.supervisor.maybe_collect_garbage(iteration + 1, scope=f"evaluate:{split}")

            # Log every `log_interval` iterations (never at iteration 0) and
            # on the final batch when the loader length is known.
            if self.log_interval > 0 and (
                (iteration > 0 and iteration % self.log_interval == 0) or iteration == length
            ):
                telemetry.emit_log(split=split, iteration=iteration, length=length, loss=loss, loss_n=loss_n)
            last_loss = loss
            last_loss_n = loss_n

        # With an unknown loader length the in-loop "final batch" test can
        # never match, so emit the closing log here unless the last observed
        # iteration is the one telemetry already printed.
        if (
            length is None
            and self.log_interval > 0
            and telemetry.last_iteration is not None
            and telemetry.last_iteration != telemetry.last_print_iteration
        ):
            # Redundant with the condition above; presumably kept for static
            # type narrowing.
            assert telemetry.last_iteration is not None
            telemetry.emit_log(
                split=split,
                iteration=telemetry.last_iteration,
                length=length,
                loss=last_loss,
                loss_n=last_loss_n,
                reset_peak_stats=False,
            )
        result = self.get_epoch_result()
        telemetry.finalize_result(
            result, elapsed_seconds=self.loop_time(sync=True) - telemetry.start_time, steps=consumed
        )
        self.write_result(result, split, self.train_state.epoch)
        return result

    @torch.inference_mode()
    def evaluate_steps(self, split: str = "val", steps: int | None = None) -> RoundDict:
        """
        Run bounded evaluation steps on one split.

        Used by step-mode training to evaluate a small fixed number of batches
        without requiring a full evaluation pass. Exactly `steps` batches are
        pulled from the dataloader; no batch beyond the budget is fetched.

        **Called when:** `train_steps` evaluates configured splits after the
        step budget finishes, or user code requests bounded evaluation.

        Args:
            split: Evaluation split name.
            steps: Number of batches to evaluate. When `None`, defaults to `max(self.steps // 20, 1)`.

        Returns:
            Step-bounded evaluation metrics.

        Raises:
            ValueError: step budget cannot be inferred, `steps` is negative, or
                the dataloader exhausts before the requested number of steps.

        **Side effects:** sets evaluation mode/split, resets meters and
        evaluation metrics, and writes scalar results at
        `train_state.global_step`.
        """
        if steps is None:
            total_steps = self.steps
            if total_steps is None:
                raise ValueError("cannot infer evaluation steps: step budget is unavailable; pass `steps`")
            steps = max(total_steps // 20, 1)
        if steps < 0:
            raise ValueError(f"invalid steps: expected a non-negative value, got {steps}")
        loader = self.dataloaders[split]
        # Zero-based index of the final requested iteration.
        length = steps - 1

        self.mode = RunnerMode.evaluate
        self.split = split
        self.meters.reset()
        self.metrics = self.evaluate_metrics
        if self.metrics is not None:
            self.metrics.reset()

        if steps == 0:
            # Nothing to run: report the freshly reset state without touching the loader.
            result = self.get_epoch_result()
            self.write_result(result, split, self.train_state.global_step)
            return result

        telemetry = LoopTelemetry(self, start_time=self.loop_time())
        consumed = 0
        for iteration, data in enumerate(loader):
            consumed = iteration + 1
            self.supervisor.maybe_handle_termination_signal()
            loss_n = self._get_loss_normalizer(data)
            # A non-positive normalizer is discarded; the meter then falls back to weight 1 below.
            if loss_n is not None and loss_n <= 0:
                loss_n = None
            _, loss = self.evaluate_step(data)
            self.supervisor.mark_heartbeat_progress()
            self.supervisor.maybe_handle_termination_signal()
            current_time = self.loop_time()
            if loss is not None:
                self.meters.loss.update(loss.detach(), n=loss_n or 1)
            telemetry.observe(iteration=iteration, data=data, current_time=current_time)
            self.supervisor.maybe_collect_garbage(iteration + 1, scope=f"evaluate:{split}")

            if self.log_interval > 0 and (
                (iteration > 0 and iteration % self.log_interval == 0) or iteration == length
            ):
                telemetry.emit_log(split=split, iteration=iteration, length=length, loss=loss, loss_n=loss_n)

            # Stop as soon as the budget is met so the loader is never asked for an
            # extra batch that would only be thrown away.
            if consumed >= steps:
                break

        if consumed < steps:
            raise ValueError(
                f"evaluate steps exhausted early on split '{split}': requested {steps} step(s), got {consumed}"
            )
        result = self.get_epoch_result()
        telemetry.finalize_result(
            result, elapsed_seconds=self.loop_time(sync=True) - telemetry.start_time, steps=consumed
        )
        self.write_result(result, split, self.train_state.global_step)
        return result

    @torch.inference_mode()
    def infer_step(self, data: Any) -> list[float]:
        """
        Run a single inference micro-step and return predictions as a list.

        The batch is moved to the runtime device via `self.to_device`, the
        model input is extracted from it, and the forward pass runs through
        `self.ema or self.model` inside `self.forward_context()`. The output is
        converted with `pred.squeeze(-1).detach().cpu().tolist()`.

        Args:
            data: One micro-batch. Mappings contribute `data["input"]`,
                non-string sequences contribute `data[0]`, and anything else
                is passed to the model as-is. A mapping input is expanded as
                keyword arguments to the model.

        Returns:
            The converted predictions. When the conversion yields a bare
            scalar instead of a list, it is wrapped as a one-element list.

        Raises:
            ValueError: both `self.model` and `self.ema` are `None`.

        Override this method when the batch layout differs or the model does
        not emit one scalar per example.
        """
        batch = self.to_device(data)

        # Pick the model input out of the batch container.
        if isinstance(batch, Mapping):
            inputs = batch["input"]
        elif isinstance(batch, (str, bytes)) or not isinstance(batch, Sequence):
            inputs = batch
        else:
            inputs = batch[0]

        if self.model is None and self.ema is None:
            raise ValueError("cannot run infer_step: model is not initialized")
        network = self.ema or self.model

        with self.forward_context():
            if isinstance(inputs, Mapping):
                pred = network(**inputs)
            else:
                pred = network(inputs)

        values = pred.squeeze(-1).detach().cpu().tolist()
        # `tolist()` returns a plain number for zero-dimensional tensors.
        return values if isinstance(values, list) else [float(values)]

    def infer(
        self,
        split: str = "infer",
        *,
        steps: int | None = None,
        stream: bool | None = None,
    ) -> list[float] | Iterator[list[float]]:
        """
        Run prediction-only execution on one split.

        Non-stream mode drains the requested batches behind a progress bar and
        returns one flattened Python list. Stream mode hands back the lazy
        per-batch iterator and leaves consumption to the caller.

        Args:
            split: Inference split name.
            steps: Optional maximum number of batches to consume.
            stream: `True` returns the per-batch iterator, `False` returns a
                flattened list. `None` streams only when the loader is unsized
                and no `steps` limit is given.

        Returns:
            Flattened predictions, or an iterator of per-batch predictions.

        Raises:
            ValueError: `steps` is negative, or non-stream inference is
                requested for an unsized loader without an explicit `steps`.

        **Side effects:** sets `self.mode` to inference mode and `self.split`
        to `split`.
        """

        self.mode = RunnerMode.infer
        self.split = split
        loader = self.dataloaders[split]
        if steps is not None and steps < 0:
            raise ValueError(f"invalid steps: expected a non-negative value, got {steps}")

        loader_length = self._loader_length(loader)
        # No way to know how many batches there will be.
        unbounded = steps is None and loader_length is None
        if stream is None:
            stream = unbounded

        if unbounded and not stream:
            raise ValueError("infer with stream=False requires `steps` for unsized loaders")

        batches = self._iter_infer_batches(loader, steps=steps, split=split)
        if stream:
            return batches

        total = loader_length if steps is None else steps
        # Only the main process shows a progress bar in distributed runs.
        hide_progress = self.distributed and not self.is_main_process
        output: list[float] = []
        for values in tqdm(batches, total=total, disable=hide_progress):
            output.extend(values)
        return output

    def _iter_infer_batches(self, loader: Any, *, steps: int | None, split: str) -> Iterator[list[float]]:
        for iteration, data in enumerate(loader):
            if steps is not None and iteration >= steps:
                break
            values = self.infer_step(data)
            self.supervisor.mark_heartbeat_progress()
            yield values
            self.supervisor.maybe_collect_garbage(iteration + 1, scope=f"infer:{split}")

    def _export_checkpoint_metadata(self, cls: type = dict) -> Mapping[str, Any]:
        """Return extra checkpoint metadata; this implementation contributes an empty `cls()` mapping."""
        metadata = cls()
        return metadata

    def _export_checkpoint_components(self, cls: type = dict) -> Mapping[str, Any]:
        """
        Collect the state dicts of the torch components into a `cls()` mapping.

        Keys are written in the order `ema`, `optimizer`, `scheduler`, `model`.
        A component that is unset (falsy) is stored as `None`; the model is
        always exported from its unwrapped form.

        Raises:
            ValueError: `self.model` is `None`.
        """
        if self.model is None:
            raise ValueError("cannot build checkpoint state: model is not initialized")

        components = cls()
        if self.ema:
            components["ema"] = self.ema.state_dict()
        else:
            components["ema"] = None
        if self.optimizer:
            components["optimizer"] = self.optimizer.state_dict()
        else:
            components["optimizer"] = None
        if self.scheduler:
            components["scheduler"] = self.scheduler.state_dict()
        else:
            components["scheduler"] = None
        bare_model = self.unwrap(self.model)
        components["model"] = bare_model.state_dict()
        return components

    def state_dict(self, cls: type = dict) -> Mapping:
        """
        Build the TorchRunner checkpoint payload.

        Starts from the base runner's `state_dict(cls)` and layers on the
        output of `_export_checkpoint_metadata` followed by
        `_export_checkpoint_components` (EMA, optimizer, scheduler, and
        unwrapped model state). Later entries overwrite earlier keys.

        **Called when:** checkpoint managers persist a TorchRunner checkpoint.

        Args:
            cls: Mapping factory used for the payload and nested payloads.

        Returns:
            Mapping containing base runner state plus torch component state.

        **Side effects:** none in this method itself; any RNG snapshotting is
        presumably done by the base implementation (not visible here).
        """
        payload = cls(super().state_dict(cls))
        metadata = self._export_checkpoint_metadata(cls)
        payload.update(metadata)
        components = self._export_checkpoint_components(cls)
        payload.update(components)
        return payload

    def _restore_model_checkpoint(self, state_dict: Mapping[str, Any], *args, **kwargs) -> None:
        """
        Load `state_dict` into the unwrapped model.

        Extra positional and keyword arguments are forwarded to the model's
        `load_state_dict`.

        Raises:
            ValueError: `self.model` is `None`.
        """
        if self.model is None:
            raise ValueError("cannot load model weights: model is not initialized")
        bare_model = self.unwrap(self.model)
        bare_model.load_state_dict(state_dict, *args, **kwargs)

    def load_model(self, state_dict: Mapping[str, Any], *args, **kwargs) -> None:
        """Load model weights by delegating to `_restore_model_checkpoint` with all arguments forwarded."""
        restore = self._restore_model_checkpoint
        restore(state_dict, *args, **kwargs)

    def _restore_optimizer_checkpoint(self, state_dict: Mapping[str, Any], *args, **kwargs) -> None:
        """
        Load `state_dict` into the optimizer.

        Does nothing when `self.optimizer` is `None`. Extra arguments are
        forwarded to the optimizer's `load_state_dict`.
        """
        optimizer = self.optimizer
        if optimizer is not None:
            optimizer.load_state_dict(state_dict, *args, **kwargs)

    def load_optimizer(self, state_dict: Mapping[str, Any] | None, *args, **kwargs) -> None:
        if self.optimizer is None:
            return
        optimizer_state = self._require_checkpoint_component_state("optimizer", state_dict)
        self._restore_optimizer_checkpoint(optimizer_state, *args, **kwargs)

    def load_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        """
        Restore base runner state, then re-apply the Torch RNG state.

        Model, optimizer, scheduler, and dataloader components are restored by
        `load_checkpoint`; this method owns only runner/RNG state. The RNG
        values applied come from `self.rng_state`; the checkpoint's
        `state["rng"]` mapping is consulted only to decide which generators to
        restore.
        """
        super().load_state_dict(checkpoint)
        saved_rng = (checkpoint.get("state") or {}).get("rng")
        if not isinstance(saved_rng, Mapping):
            # No usable RNG section in the checkpoint: leave generators untouched.
            return

        if "torch_cpu" in saved_rng and self.rng_state.torch_cpu is not None:
            torch.set_rng_state(self.rng_state.torch_cpu)

        if not torch.cuda.is_available():
            return
        if "torch_cuda" in saved_rng and self.rng_state.torch_cuda is not None:
            torch.cuda.set_rng_state_all(self.rng_state.torch_cuda)

    def load_checkpoint(self, checkpoint: Mapping | bytes | str | os.PathLike, *args, **kwargs) -> None:
        """
        Load a full checkpoint and rebind optimizer/scheduler helpers.

        This delegates component restore to `BaseRunner.load_checkpoint`, then
        rebuilds the `OptimizerContainer` so scheduler and optimizer state stay
        bound after restore.

        Args:
            checkpoint: Checkpoint payload mapping, raw bytes, or a path to a
                checkpoint; forwarded unchanged to the base implementation.
            *args: Extra positional arguments forwarded to the base implementation.
            **kwargs: Extra keyword arguments forwarded to the base implementation.
        """
        super().load_checkpoint(checkpoint, *args, **kwargs)
        # Rebind only after the base restore has finished.
        self._bind_optimizer_container()

    # `save_checkpoint` is inherited from `BaseRunner`; collective vs main-only
    # dispatch is owned by `checkpoint_manager.is_collective`.

    def read_checkpoint(self, checkpoint: Mapping | bytes | str | os.PathLike, *args, **kwargs) -> Mapping[str, Any]:
        """
        Read a checkpoint payload from a mapping, file, or DCP directory.

        A mapping is returned unchanged. Otherwise, when
        `config.checkpoint.backend` is `"dcp"` (case-insensitive) the bound
        checkpoint manager loads it; any other backend defers to the base
        implementation with the extra arguments forwarded.
        """
        if isinstance(checkpoint, Mapping):
            return checkpoint

        uses_dcp = self.config.checkpoint.backend.lower() == "dcp"
        if not uses_dcp:
            return super().read_checkpoint(checkpoint, *args, **kwargs)
        return self.checkpoint_manager.load_checkpoint(checkpoint)

    @classmethod
    def read_config(
        cls,
        checkpoint: Mapping | bytes | str | os.PathLike,
        *args,
        **kwargs,
    ) -> RunnerConfig:
        """
        Read the runner config stored in a checkpoint.

        Non-mapping inputs that `TorchDistributedCheckpointManager` recognises
        as a checkpoint path are read through that manager. Everything else,
        mappings included, is handled by the base implementation.
        """
        is_dcp_path = not isinstance(checkpoint, Mapping) and TorchDistributedCheckpointManager.is_checkpoint_path(
            checkpoint
        )
        if is_dcp_path:
            return TorchDistributedCheckpointManager.read_config(checkpoint)
        return super().read_config(checkpoint, *args, **kwargs)

    @property
    def device(self):
        """Runtime device: `cuda:<local_rank>` when CUDA is available, otherwise CPU."""
        if not torch.cuda.is_available():
            return torch.device("cpu")
        return torch.device("cuda", self.local_rank)

    @property
    def mode(self) -> RunnerMode:
        """Current runner mode."""
        return self._mode

    @mode.setter
    def mode(self, mode: str | RunnerMode) -> None:
        """
        Switch the runner mode and toggle train/eval on the bound modules.

        String values are converted to `RunnerMode`. Setting the mode that is
        already active is a no-op. Otherwise every `nn.Module` in a non-empty
        `model_parts` sequence (or `self.model` when there is no such
        sequence) and `self.ema` are put in training mode exactly when the new
        mode is `RunnerMode.train`.
        """
        target = RunnerMode(mode) if isinstance(mode, str) else mode
        if getattr(self, "_mode", None) == target:
            return
        self._mode = target

        training = target == RunnerMode.train
        parts = getattr(self, "model_parts", None)
        if isinstance(parts, Sequence) and parts:
            for part in parts:
                # Entries that are not modules are skipped.
                if isinstance(part, nn.Module):
                    part.train(training)
        elif self.model is not None:
            self.model.train(training)
        if self.ema is not None:
            self.ema.train(training)

    @property
    def rank(self) -> int:
        """Global process rank: from the initialized process group, else the `RANK` env var (default 0)."""
        group_ready = dist.is_available() and dist.is_initialized()
        if not group_ready:
            return int(os.environ.get("RANK", "0"))
        return dist.get_rank()

    @property
    def world_size(self) -> int:
        r"""
        Number of processes.

        Taken from the initialized process group when there is one, otherwise
        from the `WORLD_SIZE` environment variable (default 1).
        """
        group_ready = dist.is_available() and dist.is_initialized()
        if not group_ready:
            return int(os.environ.get("WORLD_SIZE", "1"))
        return dist.get_world_size()

    @property
    def distributed(self) -> bool:
        """Whether more than one process participates in this run."""
        process_count = self.world_size
        return process_count > 1

    def close(self, timeout: float | None = None) -> bool:
        """
        Close runner resources.

        The base `close` runs first. If it raises, the profiler is closed and
        the process group destroyed before the exception propagates. If it
        reports that draining did not finish, `False` is returned and the
        profiler and process group are left alone. Otherwise both are torn
        down and the base result is returned.
        """
        try:
            drained = super().close(timeout=timeout)
        except Exception:
            # Still release torch resources when the base close fails.
            self._close_profiler()
            self.destroy_process_group()
            raise
        else:
            if drained:
                self._close_profiler()
                self.destroy_process_group()
                return drained
            return False

world_size property

Python
world_size: int

Number of Processes.

init_distributed

Python
init_distributed() -> None

Initialize the distributed environment.

The default implementation initializes the default torch.distributed process group from WORLD_SIZE/RANK/LOCAL_RANK environment variables when WORLD_SIZE > 1, sets the active CUDA device, broadcasts self.timestamp from rank 0, and seeds elastic_state.restart_count from TORCHELASTIC_RESTART_COUNT.

Called when: once during BaseRunner.__init__, before init_checkpoint_manager, init_fault_tolerance, and init_garbage_collection. The runner is partially constructed at this point — self.config, self.workspace, self.timestamp, the dataloader container, and the default FileCheckpointManager are bound, but the model is not materialized and optimizers/dataloaders are not built.

Precondition: environment variables WORLD_SIZE, RANK, LOCAL_RANK are set when running distributed. The default torch.distributed process group is not already initialized when WORLD_SIZE > 1 — the runner owns process-group lifecycle.

Raises:

Type Description
RuntimeError

the default process group is already initialized when WORLD_SIZE > 1.

ValueError

comm.init_timeout_seconds is non-positive.

Side effects: when WORLD_SIZE > 1, calls dist.init_process_group(...), sets the active CUDA device when CUDA is available, and broadcasts self.timestamp from rank 0. Reads TORCHELASTIC_RESTART_COUNT into elastic_state.restart_count.

Do not

  • Initialize a process group via dist.init_process_group(...) outside the runner; the runner owns its lifecycle.
  • Build the model or dataloaders here; those happen in __post_init__.
  • Bind the checkpoint manager here; init_checkpoint_manager runs next.

Backend notes:

  • ParallelRunner extends this hook: after calling super(), it builds the parallel topology (build_topology) and initializes per-axis process groups via init_device_mesh.
  • DeepSpeedRunner inherits the default; DeepSpeed reuses the default process group initialized here.
Source code in danling/runners/torch_runner.py
Python
def init_distributed(self) -> None:
    r"""
    Initialize the distributed environment.

    The default implementation initializes the default torch.distributed
    process group from `WORLD_SIZE`/`RANK`/`LOCAL_RANK` environment
    variables when `WORLD_SIZE > 1`, sets the active CUDA device,
    broadcasts `self.timestamp` from rank 0, and seeds
    `elastic_state.restart_count` from `TORCHELASTIC_RESTART_COUNT`.

    **Called when:** once during `BaseRunner.__init__`, before
    `init_checkpoint_manager`, `init_fault_tolerance`, and
    `init_garbage_collection`. The runner is partially constructed at
    this point — `self.config`, `self.workspace`, `self.timestamp`, the
    dataloader container, and the default `FileCheckpointManager` are
    bound, but the model is not materialized and optimizers/dataloaders
    are not built.

    **Precondition:** environment variables `WORLD_SIZE`, `RANK`,
    `LOCAL_RANK` are set when running distributed. The default
    torch.distributed process group is **not** already initialized when
    `WORLD_SIZE > 1` — the runner owns process-group lifecycle.

    Raises:
        RuntimeError: the default process group is already initialized
            when `WORLD_SIZE > 1`.
        ValueError: `comm.init_timeout_seconds` is non-positive.

    **Side effects:** when `WORLD_SIZE > 1`, calls
    `dist.init_process_group(...)`, sets the active CUDA device when
    CUDA is available, and broadcasts `self.timestamp` from rank 0.
    Reads `TORCHELASTIC_RESTART_COUNT` into `elastic_state.restart_count`.

    !!! danger "Do not"
        - Initialize a process group via `dist.init_process_group(...)`
          outside the runner; the runner owns its lifecycle.
        - Build the model or dataloaders here; those happen in
          `__post_init__`.
        - Bind the checkpoint manager here; `init_checkpoint_manager`
          runs next.

    **Backend notes:**

    - `ParallelRunner` extends this hook: after calling `super()`, it
      builds the parallel topology (`build_topology`) and initializes
      per-axis process groups via `init_device_mesh`.
    - `DeepSpeedRunner` inherits the default; DeepSpeed reuses the
      default process group initialized here.
    """

    # Config values take precedence; the BACKEND / INIT_METHOD environment variables are the fallback.
    backend = self.config.get("backend", os.getenv("BACKEND"))
    init_method = self.config.get("init_method", os.getenv("INIT_METHOD"))
    init_timeout = self._comm_timeout("comm.init_timeout_seconds")
    # Read directly from the environment: the process group does not exist yet.
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    rank = int(os.getenv("RANK", "0"))
    runtime_device = self.device
    use_cuda_runtime = torch.cuda.is_available() and runtime_device.type == "cuda"
    # Fall back to the local rank when the device carries no explicit index.
    runtime_device_index = runtime_device.index if runtime_device.index is not None else self.local_rank
    dist_ready = dist.is_available() and dist.is_initialized()
    # Refuse to reuse a process group that something else already created.
    if world_size > 1 and dist_ready:
        raise RuntimeError(
            "default process group is already initialized; Runner requires owning process-group lifecycle"
        )
    if world_size > 1:
        # Select the CUDA device before the process group is created.
        if use_cuda_runtime:
            torch.cuda.set_device(runtime_device_index)
        init_kwargs: dict[str, Any] = {
            "backend": backend,
            "init_method": init_method,
            "world_size": world_size,
            "rank": rank,
        }
        # Pass a timeout only when one is configured, so the library default applies otherwise.
        if init_timeout is not None:
            init_kwargs["timeout"] = init_timeout
        dist.init_process_group(**init_kwargs)
        dist_ready = bool(dist.is_available() and dist.is_initialized())

    # Also covers a pre-initialized group when WORLD_SIZE is 1 or unset.
    if dist_ready and use_cuda_runtime:
        torch.cuda.set_device(runtime_device_index)

    # Replace the local timestamp with the broadcast one so all ranks agree
    # (no `src` is passed, so the library's default source rank is used).
    if dist_ready and self.world_size > 1:
        object_list = [self.timestamp]
        dist.broadcast_object_list(object_list)
        self.timestamp = str(object_list[0])

    restart_count = os.getenv("TORCHELASTIC_RESTART_COUNT")
    if restart_count is not None:
        self.elastic_state.restart_count = int(restart_count)

    # Initial value of the flag; its consumers are outside this function.
    self._train_pg_timeout_reduced = False

init_checkpoint_manager

Python
init_checkpoint_manager() -> None

Bind the checkpoint manager corresponding to config.checkpoint.backend.

The default dispatches by backend: when the backend is "dcp", it binds a TorchDistributedCheckpointManager (or TorchFTCheckpointManager when FT dataloader checkpoints are enabled). For "file" it leaves the FileCheckpointManager already bound by BaseRunner.__init__ in place.

Called when: once during BaseRunner.__init__, after init_distributed and before init_fault_tolerance. The default FileCheckpointManager is already bound at this point — overrides should swap it via set_checkpoint_manager(...), not by direct attribute assignment.

Precondition: config.checkpoint.backend is normalized to one of {"file", "dcp"} (TorchRunner does this in __init__). When the backend is "dcp", the default process group is initialized for distributed runs.

Side effects: swaps self.checkpoint_manager via set_checkpoint_manager(...) when the backend differs from "file". The prior manager is closed with a zero timeout.

Do not

  • Set self.checkpoint_manager directly; use set_checkpoint_manager so the prior manager is closed cleanly.
  • Initialize fault tolerance here; init_fault_tolerance runs next.
  • Bind the model or dataloaders here.

Backend notes:

  • DeepSpeedRunner coerces config.checkpoint.backend to "file" in __init__, so this hook is a no-op for that backend.
  • ParallelRunner coerces the backend to "dcp", so this hook always binds TorchDistributedCheckpointManager or TorchFTCheckpointManager.
Source code in danling/runners/torch_runner.py
Python
def init_checkpoint_manager(self) -> None:
    """
    Bind the checkpoint manager that matches `config.checkpoint.backend`.

    Only the `"dcp"` backend (compared case-insensitively) triggers a swap:
    a `TorchFTCheckpointManager` is bound when either `ft.enabled` or
    `checkpoint.enable_ft_dataloader_checkpoints` is truthy in the config,
    otherwise a `TorchDistributedCheckpointManager`. For any other backend
    the manager that is already bound is left in place.

    **Called when:** once during `BaseRunner.__init__`, after
    `init_distributed` and before `init_fault_tolerance`.

    **Side effects:** for the `"dcp"` backend, replaces the checkpoint
    manager through `set_checkpoint_manager(...)`.

    !!! danger "Do not"
        - Set `self.checkpoint_manager` directly; use
          `set_checkpoint_manager` so the swap goes through one place.
        - Initialize fault tolerance or bind the model/dataloaders here.
    """
    if self.config.checkpoint.backend.lower() != "dcp":
        # Non-DCP backends keep whatever manager is already bound.
        return

    wants_ft_manager = bool(
        self.config.get("ft.enabled", False)
        or self.config.get("checkpoint.enable_ft_dataloader_checkpoints", False)
    )
    if wants_ft_manager:
        manager_cls = TorchFTCheckpointManager
    else:
        manager_cls = TorchDistributedCheckpointManager
    self.set_checkpoint_manager(manager_cls(self))

init_tensorboard

Python
init_tensorboard(*args, **kwargs) -> None

Set up TensorBoard SummaryWriter.

Source code in danling/runners/torch_runner.py
Python
@on_main_process
def init_tensorboard(self, *args, **kwargs) -> None:
    r"""
    Create the TensorBoard `SummaryWriter` and store it on `self.writer`.

    All arguments are forwarded to `SummaryWriter`. When the caller gives no
    `log_dir`, it defaults to `<workspace.dir>/tensorboard/<timestamp>`. The
    writer's `add_scalar` is wrapped with `catch(OSError, verbose=False)`.
    """

    from torch.utils.tensorboard.writer import SummaryWriter  # pylint: disable=C0415

    if "log_dir" not in kwargs:
        default_log_dir = os.path.join(self.workspace.dir, "tensorboard", self.timestamp)
        kwargs["log_dir"] = default_log_dir

    self.writer = SummaryWriter(*args, **kwargs)
    guard_os_errors = catch(OSError, verbose=False)
    self.writer.add_scalar = guard_os_errors(self.writer.add_scalar)

set_seed

Python
set_seed(
    seed: int | None = None, bias: int | bool | None = None
) -> int

Set up random seed.

Parameters:

Name Type Description Default
seed
int | None

Random seed to set. Defaults to self.config.seed (config.seed).

None
bias
int | bool | None

Make the seed different for each process. This ensures that data augmentation is applied differently on every process. Defaults to self.rank. Set to False to disable this feature.

None
Source code in danling/runners/torch_runner.py
Python
def set_seed(self, seed: int | None = None, bias: int | bool | None = None) -> int:
    r"""
    Set up random seed.

    Args:
        seed: Random seed to set.
            Defaults to `self.config.seed` (`config.seed`).

        bias: Make the seed different for each processes.
            This is used to ensure the data augmentation are applied differently on every processes.
            Defaults to `self.rank`.
            Set to `False` to disable this feature.
    Returns:
        Random seed set.
    """

    base_seed = seed if seed is not None else self.config.seed  # type: ignore[assignment]
    if base_seed is None:
        base_seed = random.randint(0, 2**32 - 1)
        if self.distributed and dist.is_initialized():
            object_list = [base_seed]
            dist.broadcast_object_list(object_list)
            base_seed = object_list[0]
    base_seed = int(base_seed)
    # Keep `config.seed` as the global/base seed (before per-rank bias).
    self.config.seed = base_seed

    process_seed = base_seed
    if bias is None:
        if self.ft is not None:
            _, bias = self.ft.data_parallel_info(self.world_size, self.rank)
        else:
            bias = self.rank
    if bias:
        process_seed += int(bias)

    torch.manual_seed(process_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(process_seed)
    if np_random is not None:
        np_random.seed(process_seed)
    random.seed(process_seed)
    self.rng_state.python = random.getstate()
    self.rng_state.numpy = np_random.get_state() if np_random is not None else None
    self.rng_state.torch_cpu = torch.get_rng_state()
    if torch.cuda.is_available():
        self.rng_state.torch_cuda = torch.cuda.get_rng_state_all()
    else:
        self.rng_state.torch_cuda = None
    return process_seed

materialize_model

Python
materialize_model() -> None

Move the model to the runtime device, optionally compile, and wrap with DDP when distributed.

The default is a single-module DDP-style materialization: it moves self.model to self.device, applies any FP8 module policy when FP8 is enabled, runs torch.compile via self.compiler (under the DDP-optimizer context when wrapping is needed), and wraps the result with nn.parallel.DistributedDataParallel when world size > 1.

Called when: once during __post_init__, after setup_fp8() and before build_optimizer(). The order matters — the optimizer must see post-wrap parameters.

Precondition: self.model is set (typically by the user before constructing the runner). self.device resolves to the runtime device.

Raises:

Type Description
ValueError

self.model is not initialized.

Side effects: moves self.model to self.device; applies FP8 module policy when self.fp8_enabled; compiles via self.compiler.compile(...) under the DDP-optimizer context when wrapping is needed; wraps with DistributedDataParallel for world size > 1. Moves self.ema to device when EMA is bound.

Do not

  • Build the optimizer or scheduler here; they run after this hook.
  • Skip the device move when overriding (tensors must live on self.device before the forward pass).
  • Re-wrap an already-wrapped model (e.g., DDP-wrap a DDP module).

Backend notes:

  • DeepSpeedRunner overrides this hook to move the model to device and compile only; the DeepSpeed engine wraps the model later in _finalize_runtime_components.
  • ParallelRunner overrides this hook for FSDP2, pipeline-parallel schedules, and tensor/expert/context parallelism (via the parallelize_model and apply_activation_checkpointing hooks).
Source code in danling/runners/torch_runner.py
Python
def materialize_model(self) -> None:
    """
    Place the model on the runtime device, compile it if configured, and
    add a DDP wrapper for distributed runs.

    This default performs single-module, DDP-style materialization:
    `self.model` is moved to `self.device`, the FP8 module policy is applied
    when FP8 is enabled, `torch.compile` is run through `self.compiler`
    (inside the DDP-optimizer context whenever a DDP wrap will follow), and
    the result is wrapped in `nn.parallel.DistributedDataParallel` when
    world size > 1.

    **Called when:** once during `__post_init__`, after `setup_fp8()` and
    before `build_optimizer()`. The ordering is significant — the optimizer
    has to see the post-wrap parameters.

    **Precondition:** `self.model` is set (usually by the user before the
    runner is constructed). `self.device` resolves to the runtime device.

    Raises:
        ValueError: `self.model` is not initialized.

    **Side effects:** moves `self.model` to `self.device`; applies the FP8
    module policy when `self.fp8_enabled`; compiles through
    `self.compiler.compile(...)` inside the DDP-optimizer context when a
    wrap is required; wraps with `DistributedDataParallel` for world
    size > 1. Moves `self.ema` to the device when EMA is bound.

    !!! danger "Do not"
        - Build the optimizer or scheduler here; both run after this hook.
        - Skip the device move when overriding (tensors must live on
          `self.device` before the forward pass).
        - Re-wrap an already-wrapped model (e.g., DDP-wrap a DDP module).

    **Backend notes:**

    - `DeepSpeedRunner` overrides this hook to move the model to device
      and compile only; the DeepSpeed engine wraps the model later in
      `_finalize_runtime_components`.
    - `ParallelRunner` overrides this hook for FSDP2, pipeline-parallel
      schedules, and tensor/expert/context parallelism (via the
      `parallelize_model` and `apply_activation_checkpointing` hooks).
    """
    if self.model is None:
        raise ValueError("cannot materialize model: model is not initialized")

    module = self.model.to(self.device)
    self.model = module
    if self.fp8_enabled:
        self.apply_fp8_module_policy_to_model_parts()
        # The FP8 policy may have swapped the module; pick up whatever is bound now.
        module = self.model

    parallel_wrappers = (nn.parallel.DistributedDataParallel, nn.parallel.DataParallel)
    needs_ddp = self.distributed and not isinstance(module, parallel_wrappers)

    # Compilation runs inside the DDP-optimizer context only when a DDP wrap follows.
    compile_scope = self.compiler.ddp_optimizer() if needs_ddp else nullcontext()
    with compile_scope:
        module = self.compiler.compile(module)
    if needs_ddp:
        module = nn.parallel.DistributedDataParallel(module)
    self.model = module

    if self.ema is not None:
        self.ema = self.ema.to(self.device)

build_optimizer

Python
build_optimizer() -> None

Auto-build the optimizer from config.optim (or config.optimizer) when self.optimizer is absent.

The default iterates parameters via iter_optimizer_parameters and dispatches to the OPTIMIZERS registry with the merged config. If optim.param_groups is configured, entries are matched by regex search against iter_optimizer_named_parameters; unmatched parameters keep the optimizer-level defaults.

Called when: once during TorchRunner.__post_init__, after materialize_model (so parameters reflect DDP/FSDP wrapping) and before build_scheduler.

Precondition: self.model is materialized and on self.device. self.optimizer is None (the auto-build is skipped when the user has already bound an optimizer).

Side effects: sets self.optimizer to the registry-built instance.

Do not

  • Run before materialize_model; parameters won’t reflect DDP/FSDP wrapping.
  • Build a scheduler here.
  • Override parameter enumeration here; override iter_optimizer_parameters / iter_optimizer_named_parameters instead so subclass topology (e.g., ParallelRunner.model_parts) is preserved.

Backend notes:

  • DeepSpeedRunner inherits this hook; DeepSpeed may replace the optimizer with a DeepSpeed-managed instance during _finalize_runtime_components.
  • ParallelRunner inherits this hook but overrides iter_optimizer_parameters to enumerate self.model_parts.
Source code in danling/runners/torch_runner.py
Python
def build_optimizer(self) -> None:
    """
    Construct the optimizer from `config.optim` (falling back to
    `config.optimizer`) when no optimizer has been bound yet.

    The default enumerates parameters through `iter_optimizer_parameters`
    and builds via the `OPTIMIZERS` registry using the merged config. When
    `optim.param_groups` is configured, its entries are matched by regex
    `search` against `iter_optimizer_named_parameters`; parameters that
    match nothing keep the optimizer-level defaults.

    **Called when:** once during `TorchRunner.__post_init__`, after
    `materialize_model` (so parameters reflect DDP/FSDP wrapping) and
    before `build_scheduler`.

    **Precondition:** `self.model` is materialized and on `self.device`.
    `self.optimizer` is `None` (auto-build is skipped if the user already
    bound an optimizer).

    **Side effects:** sets `self.optimizer` to the registry-built instance.
    Nothing is built when the model is missing, the config section is
    absent/empty/not a mapping, or no parameter groups are produced.

    !!! danger "Do not"
        - Run before `materialize_model`; parameters won't reflect
          DDP/FSDP wrapping.
        - Build a scheduler here.
        - Override parameter enumeration here; override
          `iter_optimizer_parameters` / `iter_optimizer_named_parameters`
          instead so subclass topology (e.g., `ParallelRunner.model_parts`)
          is preserved.

    **Backend notes:**

    - `DeepSpeedRunner` inherits this hook; DeepSpeed may replace the
      optimizer with a DeepSpeed-managed instance during
      `_finalize_runtime_components`.
    - `ParallelRunner` inherits this hook but overrides
      `iter_optimizer_parameters` to enumerate `self.model_parts`.
    """
    if not (self.optimizer is None and self.model is not None):
        return

    section = self.config.get("optim")
    if section is None:
        section = self.config.get("optimizer")
    if not (isinstance(section, Mapping) and section):
        return

    param_groups = self._build_optimizer_param_groups(section)
    if not param_groups:
        return

    # `param_groups` drives parameter grouping above; it is not an optimizer constructor argument.
    build_kwargs = dict(section)
    build_kwargs.pop("param_groups", None)
    self.optimizer = OPTIMIZERS.build(params=param_groups, **build_kwargs)

build_scheduler

Python
build_scheduler() -> None

Auto-build the LR scheduler from config.sched (or config.scheduler) when self.scheduler is absent.

The default pops interval and monitor from the config (those drive runner-level dispatch, not scheduler construction), defaults total_steps to self.steps when computable, and dispatches to the SCHEDULERS registry with self.optimizer and the merged config.

Called when: once during TorchRunner.__post_init__, after build_optimizer.

Precondition: self.optimizer is bound. self.scheduler is None (the auto-build is skipped when the user has already bound a scheduler).

Side effects: sets self.scheduler to the registry-built instance.

Do not

  • Run before build_optimizer; the scheduler must wrap an optimizer.
  • Set scheduler interval or monitor here; configure them via config.sched.interval / config.sched.monitor.

Backend notes:

  • DeepSpeedRunner inherits this hook; the scheduler may be handed to the DeepSpeed engine in _finalize_runtime_components when its effective interval is "step". Otherwise the runner retains it.
Source code in danling/runners/torch_runner.py
Python
def build_scheduler(self) -> None:
    """
    Construct the LR scheduler from `config.sched` (or `config.scheduler`)
    when no scheduler has been bound yet.

    The default removes `interval` and `monitor` from the config (they
    control runner-level dispatch rather than scheduler construction),
    fills `total_steps` from `self.steps` when that is computable, and
    builds via the `SCHEDULERS` registry with `self.optimizer` and the
    merged config.

    **Called when:** once during `TorchRunner.__post_init__`, after
    `build_optimizer`.

    **Precondition:** `self.optimizer` is bound. `self.scheduler` is
    `None` (auto-build is skipped if the user already bound a scheduler).

    **Side effects:** sets `self.scheduler` to the registry-built instance.
    Nothing is built when the optimizer is missing or the scheduler config
    is absent, empty, or not a mapping.

    !!! danger "Do not"
        - Run before `build_optimizer`; the scheduler must wrap an
          optimizer.
        - Set scheduler interval or monitor here; configure them via
          `config.sched.interval` / `config.sched.monitor`.

    **Backend notes:**

    - `DeepSpeedRunner` inherits this hook; the scheduler may be handed
      to the DeepSpeed engine in `_finalize_runtime_components` when
      its effective interval is `"step"`. Otherwise the runner retains
      it.
    """
    if not (self.scheduler is None and self.optimizer is not None):
        return

    config = self._get_scheduler_config()
    if not (isinstance(config, Mapping) and config):
        return

    build_kwargs = dict(config)
    # These keys steer the runner, not the scheduler constructor.
    for runner_only_key in ("interval", "monitor"):
        build_kwargs.pop(runner_only_key, None)

    if "total_steps" not in build_kwargs:
        total = self.steps
        if total is not None:
            build_kwargs["total_steps"] = total

    self.scheduler = SCHEDULERS.build(self.optimizer, **build_kwargs)

build_dataloaders

Python
build_dataloaders()

Build dataloaders for dataset splits not already materialized.

The default iterates self.datasets, merges config.dataloader defaults with split-specific overrides (config.dataloader.<split>), constructs a sampler via build_datasampler, and wraps each dataset in a StatefulDataLoader using self.collate_fn. Train splits default to shuffle=True and drop_last=True; non-train splits default to the opposite.

Called when: once during TorchRunner.__post_init__ when self.datasets is non-empty.

Precondition: self.datasets is populated (typically by the user before constructing the runner). self.dataloaders is bound to a default-constructed DataLoaderDict.

Side effects: populates self.dataloaders[split] for each split in self.datasets not already materialized. Existing entries in self.dataloaders are left untouched.

Do not

  • Override sampler logic here; override build_datasampler instead.
  • Override collation; set self.collate_fn or override collate_fn (classmethod) instead.
  • Bind the optimizer or scheduler here.

Backend notes:

  • ParallelRunner substitutes self.dataloaders with a proxying dict in __init__ so non-first/last pipeline stages receive a StepProxyLoader view. The build logic itself is inherited.
Source code in danling/runners/torch_runner.py
Python
def build_dataloaders(self):
    """
    Create dataloaders for every dataset split that does not have one yet.

    The default walks `self.datasets`, combines the `config.dataloader`
    defaults with any split-specific overrides (`config.dataloader.<split>`),
    obtains a sampler from `build_datasampler`, and wraps each dataset in a
    `StatefulDataLoader` using `self.collate_fn`. Train splits default to
    `shuffle=True` and `drop_last=True`; all other splits default to the
    opposite.

    **Called when:** once during `TorchRunner.__post_init__` when
    `self.datasets` is non-empty.

    **Precondition:** `self.datasets` is populated (usually by the user
    before the runner is constructed). `self.dataloaders` is bound to a
    default-constructed `DataLoaderDict`.

    **Side effects:** fills `self.dataloaders[split]` for each split in
    `self.datasets` that is not already materialized. Entries already
    present in `self.dataloaders` are not touched.

    !!! danger "Do not"
        - Override sampler logic here; override `build_datasampler`
          instead.
        - Override collation; set `self.collate_fn` or override
          `collate_fn` (classmethod) instead.
        - Bind the optimizer or scheduler here.

    **Backend notes:**

    - `ParallelRunner` substitutes `self.dataloaders` with a proxying
      dict in `__init__` so non-first/last pipeline stages receive a
      `StepProxyLoader` view. The build logic itself is inherited.
    """
    # Snapshot the splits that still need a loader before any loader is added.
    pending = {name: dataset for name, dataset in self.datasets.items() if name not in self.dataloaders}
    loader_config = self.config.get("dataloader", NestedDict())

    # Keys named after a dataset split are per-split overrides; everything else is a shared default.
    shared, overrides = {}, {}
    for key, value in loader_config.items():
        bucket = overrides if key in self.datasets else shared
        bucket[key] = value
    shared_kwargs = NestedDict(shared)
    split_overrides = NestedDict(overrides)

    for name, dataset in pending.items():
        loader_kwargs = NestedDict(shared_kwargs)
        if name in split_overrides:
            loader_kwargs.merge(split_overrides[name], overwrite=True)
        trains = name in self.train_splits
        # `shuffle` is consumed by the sampler, so it must not reach the loader constructor.
        shuffle = loader_kwargs.pop("shuffle", trains)
        loader_kwargs.setdefault("drop_last", trains)
        sampler = self.build_datasampler(dataset, split=name, shuffle=shuffle)
        self.dataloaders[name] = StatefulDataLoader(
            dataset,
            sampler=sampler,
            collate_fn=self.collate_fn,
            **loader_kwargs,
        )

build_datasampler

Python
build_datasampler(
    dataset: Any, *, split: str, shuffle: bool
) -> Any

Build the sampler for one dataset split.

Called when: build_dataloaders materializes a split from self.datasets.

Parameters:

Name Type Description Default
dataset
Any

Dataset object for the split.

required
split
str

Split name being materialized.

required
shuffle
bool

Whether this split should be sampled in shuffled order.

required

Returns:

Type Description
Any

A local random/sequential sampler in single-process mode, or a DistributedSampler in distributed mode.

Backend notes:

  • ParallelRunner overrides replica/rank selection so data-parallel sampling follows its topology instead of raw global rank.
Source code in danling/runners/torch_runner.py
Python
def build_datasampler(self, dataset: Any, *, split: str, shuffle: bool) -> Any:
    """
    Create the sampler used for a single dataset split.

    **Called when:** `build_dataloaders` materializes a split from
    `self.datasets`.

    Args:
        dataset: Dataset object for the split.
        split: Name of the split being materialized. The default
            implementation does not read it.
        shuffle: Whether the split should be sampled in shuffled order.

    Returns:
        A local random/sequential sampler in single-process mode, or a
        `DistributedSampler` in distributed mode.

    **Backend notes:**

    - `ParallelRunner` overrides replica/rank selection so data-parallel
      sampling follows its topology instead of raw global rank.
    """
    if not self.distributed:
        local_sampler = utils.data.RandomSampler if shuffle else utils.data.SequentialSampler
        return local_sampler(dataset)

    replicas, replica_rank = self.world_size, self.rank
    if self.ft is not None:
        # When `self.ft` is bound it supplies the replica count and rank to use.
        replicas, replica_rank = self.ft.data_parallel_info(replicas, replica_rank)
    return utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=replicas,
        rank=replica_rank,
        shuffle=shuffle,
    )

to_device

Python
to_device(data: Any)

Move one batch to runtime device; override in subclasses for custom fast paths.

Source code in danling/runners/torch_runner.py
Python
def to_device(self, data: Any):
    """Transfer a single batch to the runtime device; subclasses may override for custom fast paths."""
    runtime_device = self.device
    return to_device(data, runtime_device)

all_reduce

Python
all_reduce(tensor: Tensor, *, op=SUM) -> Tensor

Reduce tensor over the runner’s replica/data-parallel collective domain.

Source code in danling/runners/torch_runner.py
Python
def all_reduce(self, tensor: torch.Tensor, *, op=dist.ReduceOp.SUM) -> torch.Tensor:
    """
    Reduce tensor over the runner's replica/data-parallel collective domain.

    The reduction is applied to `tensor` in place over `self.all_reduce_group()`.
    When no process group is initialized, `tensor` is returned untouched.
    """
    if not dist.is_available() or not dist.is_initialized():
        return tensor
    dist.all_reduce(tensor, op=op, group=self.all_reduce_group())
    return tensor

reduce

Python
reduce(tensor: Tensor) -> Tensor

Average-reduce tensor over the runner’s collective domain.

Source code in danling/runners/torch_runner.py
Python
def reduce(self, tensor: torch.Tensor) -> torch.Tensor:
    """
    Average-reduce tensor over the runner's collective domain.

    Args:
        tensor: Local value to average. It is never modified by this call.

    Returns:
        `tensor` itself when no process group is initialized or the
        collective group has at most one member. Otherwise a new tensor on
        `tensor`'s original device holding the sum over the group divided
        by the group size.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    group = self.all_reduce_group()
    group_size = max(self.world_size if group is None else dist.get_world_size(group=group), 1)
    if group_size <= 1:
        return tensor

    original_device = tensor.device
    payload_device = self.all_reduce_device()
    # The collective writes the cross-rank sum into its argument. Reducing a
    # private copy keeps the caller's tensor (and anything sharing its storage,
    # such as a detached loss) intact regardless of which device it lives on.
    payload = tensor.to(payload_device, copy=True)
    self.all_reduce(payload, op=dist.ReduceOp.SUM)
    payload = payload / group_size
    if payload.device != original_device:
        payload = payload.to(original_device)
    return payload

reduce_loss_for_logging

Python
reduce_loss_for_logging(
    loss: Tensor | None, loss_n: int | None
) -> Tensor | None

Detach and all-reduce weighted loss tensor for logging.

Source code in danling/runners/torch_runner.py
Python
def reduce_loss_for_logging(self, loss: torch.Tensor | None, loss_n: int | None) -> torch.Tensor | None:
    """
    Detach and all-reduce weighted loss tensor for logging.

    The loss is detached, cast to float64, and averaged to a scalar if it
    has dimensions. It is then weighted by `loss_n` (treated as 1 when
    missing or below 1), and the weighted loss and the weight are summed
    together through `self.all_reduce`.

    Args:
        loss: Loss tensor for the current step, or `None`.
        loss_n: Weight for this loss; `None`, zero, or negative values
            fall back to 1.

    Returns:
        The summed weighted loss divided by the summed weight as a float64
        scalar tensor on `self.all_reduce_device()`, or `None` when `loss`
        is `None` or the summed weight is not positive.
    """
    if loss is None:
        return None

    scalar = loss.detach().to(dtype=torch.float64)
    if scalar.ndim > 0:
        scalar = scalar.mean()

    weight = float(max(int(loss_n or 1), 1))
    device = self.all_reduce_device()
    weighted_loss = scalar.to(device=device) * weight
    weight_tensor = torch.tensor(weight, dtype=torch.float64, device=device)

    # Pack numerator and denominator together so a single collective covers both.
    packed = torch.stack((weighted_loss, weight_tensor))
    self.all_reduce(packed, op=dist.ReduceOp.SUM)

    total_loss, total_weight = packed[0], packed[1]
    if total_weight.item() <= 0:
        return None
    return total_loss / total_weight

train_context

Python
train_context()

Context for one training micro-step (autocast + optional DDP no_sync).

Source code in danling/runners/torch_runner.py
Python
@contextmanager
def train_context(self):
    """Wrap a single training micro-step (autocast + optional DDP no_sync)."""
    no_sync_targets = self._train_no_sync_targets()
    step_scope = self._train_step_context(no_sync_targets=no_sync_targets)
    with step_scope:
        yield

forward_context

Python
forward_context()

Precision context used by train/eval/infer forward passes.

Source code in danling/runners/torch_runner.py
Python
def forward_context(self):
    """
    Precision context used by train/eval/infer forward passes.

    Returns `self.fp8_autocast()` when FP8 is enabled, a no-op context when
    `self.precision` is `None`, and otherwise a `torch.autocast` context for
    the runner's device type with the dtype resolved from `self.precision`.
    """
    if self.fp8_enabled:
        return self.fp8_autocast()

    configured = self.precision
    if configured is not None:
        return torch.autocast(self.device.type, dtype=get_precision(configured))
    return nullcontext()

train_step

Python
train_step(data: Any) -> tuple[Any, Tensor | None]

Run one training micro-step.

The default implementation runs forward → loss → metric update → backward → step for one micro-batch.

Called when: once per micro-batch by train_epoch/train_steps. The caller seeds the loop’s accumulation state before each invocation; this method consumes that state through backward() and step().

Precondition: self.model, self.optimizer, and self.criterion are bound; self.mode == RunnerMode.train.

Parameters:

Name Type Description Default
data
Any

One micro-batch. The default unpacks data["input"] / data.get("target") for mappings, (data[0], data[1]) for non-string sequences, and (data, None) otherwise. Override train_step if your batch shape differs.

required

Returns:

Type Description
tuple[Any, Tensor | None]

(pred, loss). pred is the model output (used by metrics.update). loss is the scalar loss returned to the caller for reduced logging. The default raises when criterion is missing or returns None; overrides may return (pred, None) to signal no loss available, in which case the caller skips loss bookkeeping.

Raises:

Type Description
ValueError

self.model is not initialized, or criterion is missing or returned None.

Side effects: moves data to self.device, runs forward under train_context() (autocast + optional DDP no-sync), updates self.metrics when bound, then calls self.backward(loss) and self.step() to scale gradients, advance accumulation state, and flush the optimizer on accumulation boundaries.

Do not

  • Zero gradients (optimizer_step does this on flush).
  • Call self.optimizer.step() directly (use self.step()).
  • Mutate train_state.global_step or train_state.micro_step.
  • Implement gradient scaling here (override backward() instead).
  • Call save_checkpoint() (cadence is owned by the loop method).

Backend notes:

  • DeepSpeedRunner inherits the default; backward/step route through the DeepSpeed engine.
  • ParallelRunner overrides this method when a pipeline schedule is set; the schedule owns micro-batching and loss reduction.
Source code in danling/runners/torch_runner.py
Python
def train_step(self, data: Any) -> tuple[Any, torch.Tensor | None]:
    """
    Execute a single training micro-step.

    The default implementation performs forward → loss → metric update →
    backward → step for one micro-batch.

    **Called when:** once per micro-batch by `train_epoch`/`train_steps`. The
    caller seeds the loop's accumulation state before each call; this method
    consumes that state through `backward()` and `step()`.

    **Precondition:** `self.model`, `self.optimizer`, and `self.criterion`
    are bound; `self.mode == RunnerMode.train`.

    Args:
        data: One micro-batch. The default unpacks `data["input"]` /
            `data.get("target")` for mappings, `(data[0], data[1])` for
            non-string sequences (the target is `None` for one-element
            sequences), and `(data, None)` otherwise. Mapping inputs are
            passed to the model as keyword arguments. Override `train_step`
            if your batch shape differs.

    Returns:
        `(pred, loss)`. `pred` is the model output (used by `metrics.update`).
        `loss` is the scalar loss returned to the caller for reduced logging.
        The default raises when `criterion` is missing or returns `None`;
        overrides may return `(pred, None)` to signal no loss available, in
        which case the caller skips loss bookkeeping.

    Raises:
        ValueError: `self.model` is not initialized, or `criterion` is missing
            or returned `None`.

    **Side effects:** moves `data` to `self.device`, runs forward under
    `train_context()` (autocast + optional DDP no-sync), updates
    `self.metrics` when bound and both prediction and target are available,
    then calls `self.backward(loss)` and `self.step()` to scale gradients,
    advance accumulation state, and flush the optimizer on accumulation
    boundaries.

    !!! danger "Do not"
        - Zero gradients (`optimizer_step` does this on flush).
        - Call `self.optimizer.step()` directly (use `self.step()`).
        - Mutate `train_state.global_step` or `train_state.micro_step`.
        - Implement gradient scaling here (override `backward()` instead).
        - Call `save_checkpoint()` (cadence is owned by the loop method).

    **Backend notes:**

    - `DeepSpeedRunner` inherits the default; `backward`/`step` route
      through the DeepSpeed engine.
    - `ParallelRunner` overrides this method when a pipeline schedule is
      set; the schedule owns micro-batching and loss reduction.
    """
    batch = self.to_device(data)
    with self.train_context():
        # Split the batch into model inputs and an optional target.
        if isinstance(batch, Mapping):
            inputs, target = batch["input"], batch.get("target")
        elif isinstance(batch, Sequence) and not isinstance(batch, (str, bytes)):
            inputs = batch[0]
            target = batch[1] if len(batch) > 1 else None
        else:
            inputs, target = batch, None

        model = self.model
        if model is None:
            raise ValueError("cannot run train_step: model is not initialized")
        if isinstance(inputs, Mapping):
            pred = model(**inputs)
        else:
            pred = model(inputs)

        criterion = self.criterion
        loss = None if criterion is None else criterion(pred, target)
        if loss is None:
            raise ValueError("cannot run train_step: criterion did not produce a loss")

        metrics = self.metrics
        if not (metrics is None or pred is None or target is None):
            metrics.update(pred, target)

        # Backward and the accumulation step stay inside the training context.
        self.backward(loss)
        self.step()
    return pred, loss

backward

Python
backward(loss: Tensor) -> None

Run backward pass on one micro-step loss.

Called when: the default train_step has produced a loss tensor. The method receives the raw micro-step loss; accumulation scaling and loss-normalizer weighting are applied before Tensor.backward().

Parameters:

Name Type Description Default
loss
Tensor

The loss tensor for this micro-step.

required

Side effects: accumulates gradients on model parameters.

Do not

  • Advance the optimizer here; optimizer stepping belongs to step()/optimizer_step().
  • Mutate train_state counters.

Backend notes:

  • DeepSpeedRunner overrides this hook to call the DeepSpeed engine’s backward method.
Source code in danling/runners/torch_runner.py
Python
def backward(self, loss: torch.Tensor) -> None:
    """
    Back-propagate the loss of a single micro-step.

    **Called when:** the default `train_step` has produced a loss tensor.
    This method is handed the raw micro-step loss; accumulation scaling and
    loss-normalizer weighting are applied (via
    `_scaled_loss_for_backward`) before `Tensor.backward()` runs.

    Args:
        loss: The loss tensor for this micro-step.

    **Side effects:** accumulates gradients on model parameters.

    !!! danger "Do not"
        - Advance the optimizer here; optimizer stepping belongs to
          `step()`/`optimizer_step()`.
        - Mutate `train_state` counters.

    **Backend notes:**

    - `DeepSpeedRunner` overrides this hook to call the DeepSpeed engine's
      backward method.
    """
    scaled_loss = self._scaled_loss_for_backward(loss)
    scaled_loss.backward()

step

Python
step() -> None

Advance the accumulation state machine after one training micro-step.

Called when: train_step finishes backward for a micro-batch.

Side effects: increments train_state.micro_step and calls optimizer_step() only when the accumulation boundary is reached or the surrounding loop marks the current batch as the final flush in a partial window.

Do not

  • Call this from evaluation/inference paths.
  • Call optimizer_step() in addition to this method from the same micro-step.
  • Adjust train_state.micro_step in train_step overrides.
Source code in danling/runners/torch_runner.py
Python
def step(self) -> None:
    """
    Move the accumulation state machine forward after a training micro-step.

    **Called when:** `train_step` finishes backward for a micro-batch.

    **Side effects:** increments `train_state.micro_step` and calls
    `optimizer_step()` only when the accumulation boundary is reached or
    the surrounding loop marks the current batch as the final flush in a
    partial window. After such a partial-window flush, `micro_step` is
    advanced to the next multiple of `accum_steps`.

    !!! danger "Do not"
        - Call this from evaluation/inference paths.
        - Call `optimizer_step()` in addition to this method from the same
          micro-step.
        - Adjust `train_state.micro_step` in `train_step` overrides.
    """
    completed = self.train_state.micro_step + 1
    self.train_state.micro_step = completed

    if not self._train_window_will_flush:
        # Regular path: flush on every micro-step without accumulation, else on window boundaries.
        if self.accum_steps <= 1 or completed % self.accum_steps == 0:
            self.optimizer_step()
        return

    # Forced flush: the surrounding loop marked this batch as the last of its window.
    self.optimizer_step()
    leftover = completed % self.accum_steps
    if self.accum_steps > 1 and leftover != 0:
        # Pad the counter so the next accumulation window starts on a boundary.
        self.train_state.micro_step += self.accum_steps - leftover

optimizer_step

Python
optimizer_step() -> bool

Perform one backend optimizer update.

The default Torch implementation waits for checkpoint staging, applies accumulated-loss gradient scaling, optional grad clipping, non-finite grad skip logic, optimizer/scheduler stepping through OptimizerContainer, gradient zeroing, profiler advancement, and garbage-collection cadence.

Called when: step() reaches an accumulation boundary, or _flush_pending_optimizer_step() flushes a partial boundary before shutdown.

Returns:

Type Description
bool

True when an optimizer update is applied, otherwise False.

Side effects: may update optimizer/scheduler state; increments train_state.global_step only when an update is actually applied.

Do not

  • Increment global_step on skipped updates.
  • Forget to zero gradients after a successful update or skipped non-finite update.
  • Bypass checkpoint_manager.maybe_wait_for_staging().

Backend notes:

  • DeepSpeedRunner overrides this hook because the DeepSpeed engine owns the concrete optimizer update.
Source code in danling/runners/torch_runner.py
Python
def optimizer_step(self) -> bool:
    """
    Perform one backend optimizer update.

    The default Torch implementation waits for checkpoint staging, applies
    accumulated-loss gradient scaling, optional grad clipping, non-finite
    grad skip logic, optimizer/scheduler stepping through
    `OptimizerContainer`, gradient zeroing, profiler advancement, and
    garbage-collection cadence.

    **Called when:** `step()` reaches an accumulation boundary, or
    `_flush_pending_optimizer_step()` flushes a partial boundary before
    shutdown.

    Returns:
        `True` when an optimizer update is applied, otherwise `False`.

    Raises:
        ValueError: neither `self.optimizer_container` nor `self.optimizer`
            is configured.

    **Side effects:** may update optimizer/scheduler state; increments
    `train_state.global_step` only when an update is actually applied.
    Accumulation normalization is reset on every path that returns, whether
    the update was applied or skipped.

    !!! danger "Do not"
        - Increment `global_step` on skipped updates.
        - Forget to zero gradients after a successful update or skipped
          non-finite update.
        - Bypass `checkpoint_manager.maybe_wait_for_staging()`.

    **Backend notes:**

    - `DeepSpeedRunner` overrides this hook because the DeepSpeed engine
      owns the concrete optimizer update.
    """
    if self.optimizer_container is None and self.optimizer is None:
        raise ValueError(
            "cannot perform optimizer step: no optimizer is configured; "
            "set `self.optimizer`, implement `build_optimizer()`, or override `optimizer_step()`"
        )

    # Wait for any in-flight checkpoint staging before touching gradients or parameters.
    self.checkpoint_manager.maybe_wait_for_staging()
    # A `None` scale means no gradient rescaling is needed for this step.
    grad_scale = self._gradient_scale_for_step()
    if grad_scale is not None:
        self._scale_optimizer_gradients(grad_scale)
    max_grad_value = self.max_grad_value
    max_grad_norm = self.max_grad_norm
    skip_nonfinite_grad = self.skip_nonfinite_grad
    if self.optimizer_container is not None:
        if skip_nonfinite_grad:
            # The local non-finite check is passed through `_sync_optimizer_skip_decision`
            # before acting on it (presumably so all ranks agree — confirm in that helper).
            has_nonfinite_grad = self.optimizer_container.has_nan_inf_grad()
            has_nonfinite_grad = self._sync_optimizer_skip_decision(has_nonfinite_grad)
            if has_nonfinite_grad:
                # Skipped update: drop the bad gradients and do not advance `global_step`.
                self.optimizer_container.zero_grad()
                self._reset_accumulation_normalization()
                return False

        # The skip decision was already made above, so the container's own check is disabled.
        stepped = self.optimizer_container.step(
            max_grad_value=max_grad_value,
            max_grad_norm=max_grad_norm,
            zero_grad=True,
            skip_nonfinite_grad=False,
        )
        if not stepped:
            self._reset_accumulation_normalization()
            return False
    elif self.optimizer is not None:
        # NOTE(review): this plain-optimizer fallback does not use the clipping limits
        # or the non-finite skip flag read above — confirm that is intended.
        self.optimizer.step()
        self.optimizer.zero_grad()

    # Bookkeeping that runs only after an update was actually applied.
    self._reset_accumulation_normalization()
    self.train_state.global_step += 1
    self._step_profiler()
    self._maybe_reduce_train_process_group_timeout()
    self.supervisor.maybe_collect_garbage(self.train_state.global_step, scope="train")
    return True

train

Python
train(
    train_splits: list[str] | None = None,
    evaluate_splits: list[str] | None = None,
) -> RoundDict

Run the full training workflow.

Selects epoch mode or step mode from self.is_step_mode, validates explicit split lists against the runner’s configured/inferred splits, and delegates to train_epochs or train_steps.

Called when: user code starts training.

Parameters:

Name Type Description Default
train_splits
list[str] | None

Optional training splits. When None, use self.train_splits.

None
evaluate_splits
list[str] | None

Optional evaluation splits. When None, use self.evaluate_splits.

None

Returns:

Type Description
RoundDict

Aggregated runner results (self.results).

Raises:

Type Description
ValueError

no valid training split can be resolved.

Side effects: prints selected splits and runs the selected training loop. Checkpointing, result writing, scheduler stepping, and early stop are owned by the delegated loop method.

Source code in danling/runners/torch_runner.py
Python
def train(
    self,
    train_splits: list[str] | None = None,
    evaluate_splits: list[str] | None = None,
) -> RoundDict:
    """
    Entry point for a complete training run.

    Resolves the requested split lists against the runner's
    configured/inferred splits, then hands control to `train_steps` when
    `self.is_step_mode` is true and to `train_epochs` otherwise.

    **Called when:** user code starts training.

    Args:
        train_splits: Optional training splits. `None` falls back to `self.train_splits`.
        evaluate_splits: Optional evaluation splits. `None` falls back to `self.evaluate_splits`.

    Returns:
        Aggregated runner results (`self.results`), as returned by the
            delegated loop.

    Raises:
        ValueError: no valid training split can be resolved.

    **Side effects:** prints the selected splits and runs the selected
    training loop. Checkpointing, result writing, scheduler stepping, and
    early stop all belong to the delegated loop method.
    """
    train_splits = self._resolve_requested_splits(
        train_splits,
        self.train_splits,
        kind="training",
    )
    if not train_splits:
        raise ValueError("cannot start training: no valid training split was resolved")
    evaluate_splits = self._resolve_requested_splits(
        evaluate_splits,
        self.evaluate_splits,
        kind="evaluation",
    )

    print(f"train: splits={train_splits}")
    print(f"evaluate: splits={evaluate_splits}")

    # Choose the loop once, then forward both resolved split lists to it.
    run_loop = self.train_steps if self.is_step_mode else self.train_epochs
    return run_loop(train_splits=train_splits, evaluate_splits=evaluate_splits)

train_epochs

Python
train_epochs(
    train_splits: list[str] | None = None,
    evaluate_splits: list[str] | None = None,
) -> RoundDict

Run epoch-mode training until self.epochs is reached.

Each epoch runs all train splits, then all evaluation splits, advances epoch/metric schedulers, appends and writes results, and saves periodic checkpoints.

Called when: train dispatches while config.epochs is set, or user code explicitly wants epoch-mode semantics.

Parameters:

Name Type Description Default
train_splits
list[str] | None

Training splits for each epoch.

None
evaluate_splits
list[str] | None

Evaluation splits after each epoch.

None

Returns:

Type Description
RoundDict

Aggregated runner results (self.results).

Raises:

Type Description
ValueError

config.epochs is not set.

Source code in danling/runners/torch_runner.py
Python
def train_epochs(
    self,
    train_splits: list[str] | None = None,
    evaluate_splits: list[str] | None = None,
) -> RoundDict:
    """
    Drive epoch-mode training from the current epoch up to `self.epochs`.

    Every epoch trains on each train split, evaluates each evaluation
    split, steps the epoch/metric schedulers, appends and writes the
    epoch result, and saves a checkpoint on the configured cadence. A
    final checkpoint is always saved when the loop ends.

    **Called when:** `train` dispatches in epoch mode, or user code
    explicitly wants epoch-mode semantics.

    Args:
        train_splits: Training splits for each epoch. `None` uses `self.train_splits`.
        evaluate_splits: Evaluation splits after each epoch. `None` uses `self.evaluate_splits`.

    Returns:
        Aggregated runner results (`self.results`).

    Raises:
        ValueError: `config.epochs` is not set.
    """
    train_splits = self.train_splits if train_splits is None else train_splits
    evaluate_splits = self.evaluate_splits if evaluate_splits is None else evaluate_splits

    total_epochs = self.epochs
    if total_epochs is None:
        raise ValueError("cannot run epoch-mode training: config.epochs is not set")
    print(f"train: epoch mode start epoch={self.train_state.epoch} total_epochs={total_epochs}")

    save_every = self.checkpoint_interval
    max_stale_epochs = self.patience
    stale_epochs = 0  # consecutive epochs that did not produce a new best result

    for epoch in range(self.train_state.epoch, total_epochs):
        self.supervisor.maybe_handle_termination_signal()
        self.train_state.epoch = epoch

        result = RoundDict()
        # Train splits run first, then evaluation splits; a termination
        # check follows every individual split.
        for splits, run_split in (
            (train_splits, self.train_epoch),
            (evaluate_splits, self.evaluate_epoch),
        ):
            for split in splits:
                result[split] = run_split(split)
                self.supervisor.maybe_handle_termination_signal()

        self._step_epoch_scheduler(result)
        self.append_result(result, index=epoch)
        print(self.format_epoch_result(result, epochs=epoch, total_epochs=total_epochs))
        self.save_result()

        # The epoch counter is advanced before the cadence test, so the
        # test is against the number of completed epochs.
        self.train_state.epoch = epoch + 1
        if save_every > 0 and self.train_state.epoch % save_every == 0:
            self.save_checkpoint(epochs=epoch)

        if self.is_best:
            stale_epochs = 0
        else:
            stale_epochs += 1
        if stale_epochs > max_stale_epochs:
            print("train: early-stop triggered")
            break

    self.save_checkpoint(last_step=True)
    return self.results

train_epoch

Python
train_epoch(split: str = 'train') -> RoundDict

Run one full dataloader pass for a training split.

This is the per-split epoch loop. It sets train mode, resets meters and train metrics, manages accumulation-window normalization, invokes train_step for each micro-batch, emits step logs, and records interval/epoch telemetry.

Called when: train_epochs processes one train split.

Parameters:

Name Type Description Default
split
str

Training split name.

'train'

Returns:

Type Description
RoundDict

Epoch-level metric mapping for this split.

Side effects: updates optimizer state through train_step, advances train_state.global_step on optimizer flushes, writes step logs, and may save step-cadence checkpoints.

Do not

  • Call this for evaluation data; use evaluate_epoch.
  • Override this just to change one batch’s forward/loss logic; override train_step.
  • Manually manage gradient zeroing inside train_step; this loop and optimizer_step own accumulation boundaries.
  • Increment train_state.epoch; the surrounding train_epochs loop owns epoch progress.
  • Save result or checkpoint aliases here; train_epochs owns epoch-level persistence.
See Also

train_steps: Step-mode counterpart that consumes splits against a global step budget instead of one epoch per split.

Source code in danling/runners/torch_runner.py
Python
def train_epoch(self, split: str = "train") -> RoundDict:
    """
    Consume one training split's dataloader from start to finish.

    This is the per-split epoch loop: it switches to train mode, resets
    meters and train metrics, resets accumulation-window normalization,
    calls `train_step` on every micro-batch, emits step logs, and records
    interval/epoch telemetry.

    **Called when:** `train_epochs` processes one train split.

    Args:
        split: Training split name.

    Returns:
        Epoch-level metric mapping for this split.

    **Side effects:** updates optimizer state through `train_step`,
    advances `train_state.global_step` on optimizer flushes, writes step
    logs, and may save step-cadence checkpoints.

    !!! danger "Do not"
        - Call this for evaluation data; use `evaluate_epoch`.
        - Override this just to change one batch's forward/loss logic;
          override `train_step`.
        - Manually manage gradient zeroing inside `train_step`; this loop
          and `optimizer_step` own accumulation boundaries.
        - Increment `train_state.epoch`; the surrounding `train_epochs`
          loop owns epoch progress.
        - Save result or checkpoint aliases here; `train_epochs` owns
          epoch-level persistence.

    See Also:
        [`train_steps`][danling.runners.TorchRunner.train_steps]:
            Step-mode counterpart that consumes splits against a global
            step budget instead of one epoch per split.
    """
    loader = self.dataloaders[split]
    batch_count = self._loader_length(loader)
    # Index of the final batch, or None when the loader length is unknown.
    final_index = None if batch_count is None else batch_count - 1
    prev_loss: torch.Tensor | None = None
    prev_loss_n: int | None = None

    self._set_loader_epoch(loader, self.train_state.epoch)
    self.mode = RunnerMode.train
    self.split = split
    self.meters.reset()
    self.metrics = self.train_metrics
    if self.metrics is not None:
        self.metrics.reset()
    telemetry = LoopTelemetry(self, start_time=self.loop_time())
    self._reset_accumulation_normalization()

    # Start the split with clean gradients; the container takes precedence
    # over a bare optimizer.
    grad_owner = self.optimizer_container if self.optimizer_container is not None else self.optimizer
    if grad_owner is not None:
        grad_owner.zero_grad()
    save_every = self.checkpoint_interval

    for iteration, data, will_flush in self._iter_train_batches(loader):
        self.supervisor.maybe_handle_termination_signal()
        step_before = self.train_state.global_step

        # A positive normalizer signals a weighted loss; a missing or
        # non-positive one becomes None so the accumulation window is
        # treated as uniform rather than being coerced to 1.
        raw_n = self._get_loss_normalizer(data)
        loss_n = None if (raw_n is not None and raw_n <= 0) else raw_n

        self._pending_loss_normalizer = loss_n
        self._train_window_will_flush = will_flush
        try:
            _, loss = self.train_step(data)
        finally:
            # Per-batch hints must never leak into the next batch, even
            # when `train_step` raises.
            self._train_window_will_flush = False
            self._pending_loss_normalizer = None

        self.supervisor.mark_heartbeat_progress()
        self.supervisor.maybe_handle_termination_signal()
        current_time = self.loop_time()
        if self.scheduler is not None and hasattr(self.scheduler, "get_last_lr"):
            self.meters.lr.update(self.scheduler.get_last_lr()[0])
        if loss is not None:
            # Without a normalizer the batch counts as a single-sample
            # meter update; criteria that emit a real loss for
            # zero-valid-token batches are not supported here.
            self.meters.loss.update(loss.detach(), n=loss_n or 1)
        telemetry.observe(iteration=iteration, data=data, current_time=current_time)

        # Step-cadence checkpoints fire only on batches that actually
        # advanced the global step.
        step_after = self.train_state.global_step
        if save_every > 0 and step_after != step_before:
            if step_after % save_every == 0:
                self.save_checkpoint()

        if self.log_interval > 0:
            on_cadence = iteration > 0 and iteration % self.log_interval == 0
            if on_cadence or iteration == final_index:
                telemetry.emit_log(
                    split=split,
                    iteration=iteration,
                    length=final_index,
                    loss=loss,
                    loss_n=loss_n,
                )
        prev_loss = loss
        prev_loss_n = loss_n

    # Loaders without a known length cannot detect their final batch
    # inside the loop, so flush one trailing log line if the last observed
    # iteration was not already printed.
    if final_index is None and self.log_interval > 0:
        trailing_iteration = telemetry.last_iteration
        if trailing_iteration is not None and trailing_iteration != telemetry.last_print_iteration:
            telemetry.emit_log(
                split=split,
                iteration=trailing_iteration,
                length=final_index,
                loss=prev_loss,
                loss_n=prev_loss_n,
                reset_peak_stats=False,
            )

    result = self.get_epoch_result()
    elapsed = self.loop_time(sync=True) - telemetry.start_time
    telemetry.finalize_result(result, elapsed_seconds=elapsed)
    return result

train_steps

Python
train_steps(
    train_splits: list[str] | None = None,
    evaluate_splits: list[str] | None = None,
) -> RoundDict

Run step-mode training for the configured global step budget.

Step mode consumes train splits in sorted split order until train_state.global_step >= self.steps, then optionally evaluates configured evaluation splits with evaluate_steps.

Called when: train dispatches while config.epochs is unset, or user code explicitly wants a global-step budget.

Parameters:

Name Type Description Default
train_splits
list[str] | None

Training splits to consume in order.

None
evaluate_splits
list[str] | None

Evaluation splits to run after training steps finish.

None

Returns:

Type Description
RoundDict

Aggregated runner results (self.results).

Raises:

Type Description
ValueError

total step budget cannot be resolved.

Side effects: updates epoch as an outer split-round counter, appends one result row indexed by global_step, writes result files, and saves the final checkpoint.

Do not

  • Assume a split is consumed exactly once; step mode can resume a split iterator across outer rounds.
  • Mutate train_state.global_step outside optimizer stepping.
See Also

train_epoch: Per-split epoch loop used by epoch-mode training.

Source code in danling/runners/torch_runner.py
Python
def train_steps(
    self,
    train_splits: list[str] | None = None,
    evaluate_splits: list[str] | None = None,
) -> RoundDict:
    """
    Run step-mode training for the configured global step budget.

    Step mode consumes train splits in the order they appear in
    `train_splits`, round after round, until
    `train_state.global_step >= self.steps`, then evaluates every
    evaluation split with `evaluate_steps`.

    **Called when:** `train` dispatches with `self.is_step_mode` set, or
    user code explicitly wants a global-step budget.

    Args:
        train_splits: Training splits to consume in order. `None` uses `self.train_splits`.
        evaluate_splits: Evaluation splits to run after training steps finish.
            `None` uses `self.evaluate_splits`.

    Returns:
        Aggregated runner results (`self.results`).

    Raises:
        ValueError: total step budget cannot be resolved (`self.steps` is `None`).

    **Side effects:** increments `train_state.epoch` once per completed
    round over the splits, appends one result row indexed by
    `global_step`, writes result files, and saves the final checkpoint.
    Emits a `RuntimeWarning` when a full round makes no progress or when
    the loop ends short of the step budget.

    !!! danger "Do not"
        - Assume a split is consumed exactly once; step mode can resume a
          split iterator across outer rounds.
        - Mutate `train_state.global_step` outside optimizer stepping.

    See Also:
        [`train_epoch`][danling.runners.TorchRunner.train_epoch]:
            Per-split epoch loop used by epoch-mode training.
    """
    if train_splits is None:
        train_splits = self.train_splits
    if evaluate_splits is None:
        evaluate_splits = self.evaluate_splits

    total_steps = self.steps
    if total_steps is None:
        raise ValueError("cannot run step-mode training: config.steps could not be resolved")
    print(f"train: step mode start global_step={self.train_state.global_step} steps={total_steps}")
    result = RoundDict()
    # Per-split batch iterators persist across rounds so a split resumes
    # where it stopped; `None` means "no live iterator yet".
    step_mode_iterators: dict[str, Iterator[tuple[int, Any, bool]] | None] = dict.fromkeys(train_splits)
    # Epoch value passed to `_set_loader_epoch` for each split; bumped every
    # time that split's iterator is exhausted.
    step_mode_sampler_epochs = {split: self.train_state.epoch for split in train_splits}
    # Outer loop: one "round" is a pass over all train splits.
    while self.train_state.global_step < total_steps:
        self.supervisor.maybe_handle_termination_signal()
        # Remembered to detect a round that advanced no optimizer step.
        round_start_step = self.train_state.global_step
        round_result = RoundDict()
        total_train_splits = len(train_splits)
        for split_index, split in enumerate(train_splits):
            self.supervisor.maybe_handle_termination_signal()
            self.mode = RunnerMode.train
            self.split = split
            remaining = total_steps - self.train_state.global_step
            if remaining <= 0:
                break
            loader = self.dataloaders[split]
            remaining_splits = total_train_splits - split_index
            # NOTE(review): presumably the number of optimizer steps this
            # split may take in this round -- confirm against
            # `_step_mode_split_budget`.
            split_steps = self._step_mode_split_budget(
                remaining_steps=remaining,
                remaining_splits=remaining_splits,
                loader=loader,
            )
            if split_steps <= 0:
                break
            start_global_step = self.train_state.global_step
            target_global_step = start_global_step + split_steps
            # Zero-based index of the last optimizer step in this split's
            # budget; used as the "final step" marker for logging.
            length = max(target_global_step - self.train_state.global_step - 1, 0)
            self.meters.reset()
            self.metrics = self.train_metrics
            if self.metrics is not None:
                self.metrics.reset()
            telemetry = LoopTelemetry(self, start_time=self.loop_time())
            self._reset_accumulation_normalization()
            # Start the split's budget with clean gradients.
            if self.optimizer_container is not None:
                self.optimizer_container.zero_grad()
            elif self.optimizer is not None:
                self.optimizer.zero_grad()
            checkpoint_cadence = self.checkpoint_interval
            # Counts micro-batches seen in this split budget (first batch is 0).
            batch_iteration = -1

            while self.train_state.global_step < target_global_step:
                batch: tuple[int, Any, bool] | None = None
                iterator = step_mode_iterators[split]
                recreated = False
                # Fetch the next batch, recreating the iterator at most once.
                # If a freshly created iterator is exhausted immediately,
                # `batch` stays None and the split is abandoned for this round.
                while True:
                    if iterator is None:
                        if recreated:
                            break
                        self._set_loader_epoch(loader, step_mode_sampler_epochs[split])
                        iterator = self._iter_train_batches(loader)
                        step_mode_iterators[split] = iterator
                        recreated = True
                    try:
                        batch = next(iterator)
                        break
                    except StopIteration:
                        # Drop the exhausted iterator and advance the sampler
                        # epoch used for the next recreation.
                        iterator = None
                        step_mode_iterators[split] = None
                        step_mode_sampler_epochs[split] += 1
                if batch is None:
                    break
                _, data, will_flush = batch
                batch_iteration += 1
                self.supervisor.maybe_handle_termination_signal()
                step_before = self.train_state.global_step
                # See `train_epoch` for normalizer semantics.
                loss_n = self._get_loss_normalizer(data)
                if loss_n is not None and loss_n <= 0:
                    loss_n = None
                self._pending_loss_normalizer = loss_n
                self._train_window_will_flush = will_flush
                try:
                    _, loss = self.train_step(data)
                finally:
                    # Per-batch hints are cleared even when `train_step` raises.
                    self._train_window_will_flush = False
                    self._pending_loss_normalizer = None

                self.supervisor.mark_heartbeat_progress()
                self.supervisor.maybe_handle_termination_signal()
                current_time = self.loop_time()
                if self.scheduler is not None and hasattr(self.scheduler, "get_last_lr"):
                    self.meters.lr.update(self.scheduler.get_last_lr()[0])
                if loss is not None:
                    self.meters.loss.update(loss.detach(), n=loss_n or 1)
                telemetry.observe(iteration=batch_iteration, data=data, current_time=current_time)

                # Checkpoint only on batches that advanced the global step.
                step_after = self.train_state.global_step
                if checkpoint_cadence > 0 and step_after != step_before and step_after % checkpoint_cadence == 0:
                    self.save_checkpoint()

                # Logging cadence is measured in optimizer steps within this
                # split budget, not micro-batches; batches that did not
                # advance the global step never log.
                step_iteration = step_after - start_global_step - 1 if step_after != step_before else None
                if (
                    self.log_interval > 0
                    and step_iteration is not None
                    and (
                        (step_iteration > 0 and step_iteration % self.log_interval == 0) or step_iteration == length
                    )
                ):
                    telemetry.emit_log(
                        split=split,
                        iteration=batch_iteration,
                        length=length,
                        loss=loss,
                        loss_n=loss_n,
                        display_iteration=step_iteration,
                    )

            round_result[split] = self.get_epoch_result()
            telemetry.finalize_result(
                round_result[split], elapsed_seconds=self.loop_time(sync=True) - telemetry.start_time
            )
            self.supervisor.maybe_handle_termination_signal()

        # Guard against spinning forever: a whole round that advanced no
        # optimizer step ends training with a warning.
        if self.train_state.global_step == round_start_step:
            remaining_steps = total_steps - self.train_state.global_step
            warn(
                f"step-mode training made no progress after one full split pass "
                f"(target={total_steps}, reached={self.train_state.global_step}, remaining={remaining_steps})",
                RuntimeWarning,
                stacklevel=2,
            )
            break
        self._step_epoch_scheduler(round_result)
        # Only the most recent productive round is kept as the reported result.
        result = round_result
        self.train_state.epoch += 1
    remaining_steps = total_steps - self.train_state.global_step
    if remaining_steps > 0:
        warn(
            f"step-mode training finished with {remaining_steps} step(s) remaining "
            f"(target={total_steps}, reached={self.train_state.global_step})",
            RuntimeWarning,
            stacklevel=2,
        )
    # Evaluation results are merged into the same row as the last training round.
    for split in evaluate_splits:
        result[split] = self.evaluate_steps(split=split)
    self.append_result(result, index=self.train_state.global_step)
    print(f"train: step mode result={result}")
    self.save_result()
    self.save_checkpoint(last_step=True)
    return self.results

evaluate_step

Python
evaluate_step(data: Any) -> tuple[Any, Tensor | None]

Run one evaluation micro-step.

The default implementation runs forward → optional loss → optional metric update under forward_context(). No backward pass and no optimizer step.

Called when: once per micro-batch by evaluate_epoch/evaluate_steps, which run under torch.inference_mode().

Precondition: at least one of self.model or self.ema is bound. self.mode == RunnerMode.evaluate. The default prefers self.ema over self.model when both are available.

Parameters:

Name Type Description Default
data
Any

One micro-batch. The default unpacks data["input"] / data.get("target") for mappings, (data[0], data[1]) for non-string sequences, and (data, None) otherwise. Override evaluate_step if your batch shape differs.

required

Returns:

Type Description
tuple[Any, Tensor | None]

(pred, loss). pred is the model output (used by metrics.update). loss is the scalar loss returned to the caller for reduced logging, or None when no criterion is set.

Raises:

Type Description
ValueError

neither self.model nor self.ema is initialized.

Side effects: moves data to self.device, runs forward through self.ema or self.model under forward_context(), computes loss when criterion is set, and updates self.metrics when bound.

Do not

  • Call self.backward(...) or self.step() (no optimizer here).
  • Mutate train_state.global_step or train_state.micro_step.
  • Switch the runner mode (the loop owns self.mode).
  • Call save_checkpoint() (cadence is owned by training loops only).

Backend notes:

  • ParallelRunner overrides this method when a pipeline schedule is set; the schedule owns micro-batching and pipeline-stage loss reduction.
Source code in danling/runners/torch_runner.py
Python
def evaluate_step(self, data: Any) -> tuple[Any, torch.Tensor | None]:
    """
    Run one evaluation micro-step.

    The default implementation runs forward → optional loss → optional
    metric update under `forward_context()`. No backward pass and no
    optimizer step.

    **Called when:** once per micro-batch by `evaluate_epoch`/`evaluate_steps`,
    which run under `torch.inference_mode()`.

    **Precondition:** at least one of `self.model` or `self.ema` is bound.
    `self.mode == RunnerMode.evaluate`. The default prefers `self.ema` over
    `self.model` whenever `self.ema` is not `None`.

    Args:
        data: One micro-batch. The default unpacks `data["input"]` /
            `data.get("target")` for mappings, `(data[0], data[1])` for
            non-string sequences (target is `None` for one-element
            sequences), and `(data, None)` otherwise. Override
            `evaluate_step` if your batch shape differs.

    Returns:
        `(pred, loss)`. `pred` is the model output (used by `metrics.update`).
            `loss` is the scalar loss returned to the caller for reduced
            logging, or `None` when no `criterion` is set.

    Raises:
        ValueError: neither `self.model` nor `self.ema` is initialized.

    **Side effects:** moves `data` to `self.device`, runs forward through
    `self.ema` (or `self.model` when no EMA is bound) under
    `forward_context()`, computes loss when `criterion` is set, and updates
    `self.metrics` when bound and both prediction and target are present.

    !!! danger "Do not"
        - Call `self.backward(...)` or `self.step()` (no optimizer here).
        - Mutate `train_state.global_step` or `train_state.micro_step`.
        - Switch the runner mode (the loop owns `self.mode`).
        - Call `save_checkpoint()` (cadence is owned by training loops only).

    **Backend notes:**

    - `ParallelRunner` overrides this method when a pipeline schedule is
      set; the schedule owns micro-batching and pipeline-stage loss
      reduction.
    """
    data = self.to_device(data)
    if isinstance(data, Mapping):
        inputs = data["input"]
        target = data.get("target")
    elif isinstance(data, Sequence) and not isinstance(data, (str, bytes)):
        inputs = data[0]
        target = data[1] if len(data) > 1 else None
    else:
        inputs = data
        target = None

    if self.model is None and self.ema is None:
        raise ValueError("cannot run evaluate_step: model is not initialized")
    # Choose by identity, not truthiness: an object that defines `__len__`
    # (for example an empty container module) is falsy, and `ema or model`
    # would silently skip a bound EMA in that case.
    model = self.ema if self.ema is not None else self.model
    with self.forward_context():
        # Mapping inputs are expanded as keyword arguments; anything else
        # is passed positionally.
        pred = model(**inputs) if isinstance(inputs, Mapping) else model(inputs)
        loss = self.criterion(pred, target) if self.criterion is not None else None

    if self.metrics is not None and pred is not None and target is not None:
        self.metrics.update(pred, target)

    return pred, loss

evaluate

Python
evaluate(
    evaluate_splits: list[str] | None = None,
) -> RoundDict

Run evaluation across splits with epoch-mode semantics.

Called when: user code explicitly evaluates a runner, or training code delegates to evaluation helpers.

Parameters:

Name Type Description Default
evaluate_splits
list[str] | None

Optional evaluation splits. When None, use self.evaluate_splits.

None

Returns:

Type Description
RoundDict

Mapping of split -> evaluation result for this call.

Raises:

Type Description
ValueError

no valid evaluation split can be resolved.

Side effects: sets evaluation mode per split, prints a formatted aggregate result, and writes scalar outputs through evaluate_epoch.

Source code in danling/runners/torch_runner.py
Python
def evaluate(self, evaluate_splits: list[str] | None = None) -> RoundDict:
    """
    Evaluate every requested split with epoch-mode semantics.

    **Called when:** user code explicitly evaluates a runner, or training
    code delegates to evaluation helpers.

    Args:
        evaluate_splits: Optional evaluation splits. `None` falls back to `self.evaluate_splits`.

    Returns:
        Mapping of split -> evaluation result for this call.

    Raises:
        ValueError: no valid evaluation split can be resolved.

    **Side effects:** runs `evaluate_epoch` once per split (which sets
    evaluation mode and writes scalar outputs) and prints a formatted
    aggregate result.
    """
    evaluate_splits = self._resolve_requested_splits(
        evaluate_splits,
        self.evaluate_splits,
        kind="evaluation",
    )
    if not evaluate_splits:
        raise ValueError("cannot start evaluation: no valid evaluation split was resolved")

    print("evaluate: start")
    print(f"evaluate: splits={evaluate_splits}")

    result = RoundDict()
    for split in evaluate_splits:
        result[split] = self.evaluate_epoch(split=split)

    # When an epoch budget is configured and at least one epoch has been
    # counted, report against the previous epoch index.
    current_epoch = self.train_state.epoch
    step_back = self.epochs is not None and current_epoch > 0
    display_epoch = current_epoch - 1 if step_back else current_epoch
    print(self.format_epoch_result(result, epochs=display_epoch))
    return result

evaluate_epoch

Python
evaluate_epoch(split: str = 'val') -> RoundDict

Run one full dataloader pass for an evaluation split.

Sets evaluation mode, resets meters/evaluation metrics, runs evaluate_step for every batch under inference mode, emits step logs, and writes the split result at the current epoch index.

Called when: evaluate or train_epochs evaluates a split.

Parameters:

Name Type Description Default
split
str

Evaluation split name.

'val'

Returns:

Type Description
RoundDict

Epoch-level metric mapping for this split.

Side effects: updates evaluation meters/metrics, emits logs, writes scalar results, and records telemetry. It does not update optimizer or training progress counters.

Source code in danling/runners/torch_runner.py
Python
@torch.inference_mode()
def evaluate_epoch(self, split: str = "val") -> RoundDict:
    """
    Consume one evaluation split's dataloader from start to finish.

    Switches to evaluation mode, resets meters and evaluation metrics,
    calls `evaluate_step` on every batch under inference mode, emits step
    logs, and writes the split result at the current epoch index.

    **Called when:** `evaluate` or `train_epochs` evaluates a split.

    Args:
        split: Evaluation split name.

    Returns:
        Epoch-level metric mapping for this split.

    **Side effects:** updates evaluation meters/metrics, emits logs, writes
    scalar results, and records telemetry. It does not update optimizer or
    training progress counters.
    """
    loader = self.dataloaders[split]
    batch_count = self._loader_length(loader)
    # Index of the final batch, or None when the loader length is unknown.
    final_index = None if batch_count is None else batch_count - 1

    prev_loss: torch.Tensor | None = None
    prev_loss_n: int | None = None
    self.mode = RunnerMode.evaluate
    self.split = split
    self.meters.reset()
    self.metrics = self.evaluate_metrics
    if self.metrics is not None:
        self.metrics.reset()
    telemetry = LoopTelemetry(self, start_time=self.loop_time())

    consumed = 0  # number of batches actually evaluated
    for iteration, data in enumerate(loader):
        consumed = iteration + 1
        self.supervisor.maybe_handle_termination_signal()

        # A non-positive normalizer is treated the same as a missing one.
        raw_n = self._get_loss_normalizer(data)
        loss_n = None if (raw_n is not None and raw_n <= 0) else raw_n

        _, loss = self.evaluate_step(data)
        self.supervisor.mark_heartbeat_progress()
        self.supervisor.maybe_handle_termination_signal()
        current_time = self.loop_time()
        if loss is not None:
            # Without a normalizer the batch counts as one sample.
            self.meters.loss.update(loss.detach(), n=loss_n or 1)
        telemetry.observe(iteration=iteration, data=data, current_time=current_time)
        self.supervisor.maybe_collect_garbage(consumed, scope=f"evaluate:{split}")

        if self.log_interval > 0:
            on_cadence = iteration > 0 and iteration % self.log_interval == 0
            if on_cadence or iteration == final_index:
                telemetry.emit_log(
                    split=split,
                    iteration=iteration,
                    length=final_index,
                    loss=loss,
                    loss_n=loss_n,
                )
        prev_loss = loss
        prev_loss_n = loss_n

    # Loaders without a known length cannot detect their final batch
    # inside the loop, so flush one trailing log line if the last observed
    # iteration was not already printed.
    if final_index is None and self.log_interval > 0:
        trailing_iteration = telemetry.last_iteration
        if trailing_iteration is not None and trailing_iteration != telemetry.last_print_iteration:
            telemetry.emit_log(
                split=split,
                iteration=trailing_iteration,
                length=final_index,
                loss=prev_loss,
                loss_n=prev_loss_n,
                reset_peak_stats=False,
            )

    result = self.get_epoch_result()
    elapsed = self.loop_time(sync=True) - telemetry.start_time
    telemetry.finalize_result(result, elapsed_seconds=elapsed, steps=consumed)
    self.write_result(result, split, self.train_state.epoch)
    return result

evaluate_steps

Python
evaluate_steps(
    split: str = "val", steps: int | None = None
) -> RoundDict

Run bounded evaluation steps on one split.

Used by step-mode training to evaluate a small fixed number of batches without requiring a full evaluation pass.

Called when: train_steps evaluates configured splits after the step budget finishes, or user code requests bounded evaluation.

Parameters:

Name Type Description Default
split
str

Evaluation split name.

'val'
steps
int | None

Number of batches to evaluate. When None, defaults to max(self.steps // 20, 1).

None

Returns:

Type Description
RoundDict

Step-bounded evaluation metrics.

Raises:

Type Description
ValueError

step budget cannot be inferred, steps is negative, or the dataloader exhausts before the requested number of steps.

Side effects: writes scalar results at train_state.global_step.

Source code in danling/runners/torch_runner.py
Python
@torch.inference_mode()
def evaluate_steps(self, split: str = "val", steps: int | None = None) -> RoundDict:
    """
    Run bounded evaluation steps on one split.

    Used by step-mode training to evaluate a small fixed number of batches
    without requiring a full evaluation pass. Exactly `steps` batches are
    pulled from the dataloader; iteration stops right after the last
    requested batch instead of fetching one more batch just to discard it.

    **Called when:** `train_steps` evaluates configured splits after the
    step budget finishes, or user code requests bounded evaluation.

    Args:
        split: Evaluation split name.
        steps: Number of batches to evaluate. When `None`, defaults to `max(self.steps // 20, 1)`.

    Returns:
        Step-bounded evaluation metrics.

    Raises:
        ValueError: step budget cannot be inferred, `steps` is negative, or
            the dataloader exhausts before the requested number of steps.

    **Side effects:** switches the runner to evaluate mode on `split`,
    resets meters and evaluation metrics, and writes scalar results at
    `train_state.global_step`.
    """
    if steps is None:
        total_steps = self.steps
        if total_steps is None:
            raise ValueError("cannot infer evaluation steps: step budget is unavailable; pass `steps`")
        steps = max(total_steps // 20, 1)
    if steps < 0:
        raise ValueError(f"invalid steps: expected a non-negative value, got {steps}")
    # Looked up before any state change so an unknown split fails early.
    loader = self.dataloaders[split]
    # Index of the final requested iteration; used to force a log line on the last step.
    length = steps - 1

    self.mode = RunnerMode.evaluate
    self.split = split
    self.meters.reset()
    self.metrics = self.evaluate_metrics
    if self.metrics is not None:
        self.metrics.reset()

    if steps == 0:
        # Nothing to evaluate: report the freshly reset state without touching the loader.
        result = self.get_epoch_result()
        self.write_result(result, split, self.train_state.global_step)
        return result

    telemetry = LoopTelemetry(self, start_time=self.loop_time())
    consumed = 0
    for iteration, data in enumerate(loader):
        consumed = iteration + 1
        self.supervisor.maybe_handle_termination_signal()
        loss_n = self._get_loss_normalizer(data)
        # A non-positive normalizer is treated as "no normalizer".
        if loss_n is not None and loss_n <= 0:
            loss_n = None
        _, loss = self.evaluate_step(data)
        self.supervisor.mark_heartbeat_progress()
        self.supervisor.maybe_handle_termination_signal()
        current_time = self.loop_time()
        if loss is not None:
            self.meters.loss.update(loss.detach(), n=loss_n or 1)
        telemetry.observe(iteration=iteration, data=data, current_time=current_time)
        self.supervisor.maybe_collect_garbage(iteration + 1, scope=f"evaluate:{split}")

        if self.log_interval > 0 and (
            (iteration > 0 and iteration % self.log_interval == 0) or iteration == length
        ):
            telemetry.emit_log(split=split, iteration=iteration, length=length, loss=loss, loss_n=loss_n)

        # Stop before asking the loader for another batch: checking the budget at
        # the top of the loop would fetch (and throw away) batch `steps + 1`.
        if consumed >= steps:
            break

    if consumed < steps:
        raise ValueError(
            f"evaluate steps exhausted early on split '{split}': requested {steps} step(s), got {consumed}"
        )
    result = self.get_epoch_result()
    telemetry.finalize_result(
        result, elapsed_seconds=self.loop_time(sync=True) - telemetry.start_time, steps=consumed
    )
    self.write_result(result, split, self.train_state.global_step)
    return result

infer_step

Python
infer_step(data: Any) -> list[float]

Run one inference micro-step.

The default implementation runs forward through self.ema or self.model, detaches scalar-per-example predictions, squeezes the trailing dimension, moves them to CPU, and returns them as a Python list.

Called when: once per micro-batch by infer/_iter_infer_batches. The method is decorated with torch.inference_mode().

Precondition: at least one of self.model or self.ema is bound. self.mode == RunnerMode.infer.

Parameters:

Name Type Description Default
data
Any

One micro-batch. The default unpacks data["input"] for mappings, data[0] for non-string sequences, and data itself otherwise. Override infer_step if your batch shape differs or you need to pass auxiliary tensors to the model.

required

Returns:

Type Description
list[float]

List of CPU floats for scalar-per-example predictions. The default converts with pred.squeeze(-1).detach().cpu().tolist(). Override if your model emits multi-dim tensors, mappings, or non-numeric outputs.

Raises:

Type Description
ValueError

neither self.model nor self.ema is initialized.

Side effects: moves data to self.device, runs forward through self.ema or self.model under forward_context(), then converts the output to a CPU list.

Do not

  • Compute or accumulate metrics (inference is metric-free).
  • Mutate runner state counters.
  • Return a torch.Tensor (callers expect list[float] for batched aggregation and streaming).
  • Call self.backward(...) or self.step().

Backend notes:

  • ParallelRunner overrides this method when a pipeline schedule is set; non-first-stage ranks pass data=None and the schedule routes activations through pipeline communication.
Source code in danling/runners/torch_runner.py
Python
@torch.inference_mode()
def infer_step(self, data: Any) -> list[float]:
    """
    Execute a single inference micro-step and return plain Python floats.

    The batch is moved to the runner device, the model input is extracted,
    and a forward pass runs through `self.ema or self.model`. The
    prediction has its trailing dimension squeezed, is detached, copied to
    CPU, and converted to a Python list.

    **Called when:** once per micro-batch by `infer`/`_iter_infer_batches`.
    The method is decorated with `torch.inference_mode()`.

    **Precondition:** at least one of `self.model` or `self.ema` is bound.
    `self.mode == RunnerMode.infer`.

    Args:
        data: One micro-batch. Mappings contribute `data["input"]`,
            non-string sequences contribute `data[0]`, and anything else is
            passed through unchanged. Override `infer_step` when your batch
            layout differs or the model needs auxiliary tensors.

    Returns:
        List of CPU floats for scalar-per-example predictions, produced by
        `pred.squeeze(-1).detach().cpu().tolist()`. A scalar result is
        wrapped in a one-element list. Override if your model emits
        multi-dim tensors, mappings, or non-numeric outputs.

    Raises:
        ValueError: neither `self.model` nor `self.ema` is initialized.

    **Side effects:** moves `data` to `self.device`, runs forward through
    `self.ema or self.model` under `forward_context()`, then converts the
    output to a CPU list.

    !!! danger "Do not"
        - Compute or accumulate metrics (inference is metric-free).
        - Mutate runner state counters.
        - Return a `torch.Tensor` (callers expect `list[float]` for
          batched aggregation and streaming).
        - Call `self.backward(...)` or `self.step()`.

    **Backend notes:**

    - `ParallelRunner` overrides this method when a pipeline schedule is
      set; non-first-stage ranks pass `data=None` and the schedule routes
      activations through pipeline communication.
    """
    batch = self.to_device(data)

    # Pick the model input out of the batch container.
    if isinstance(batch, Mapping):
        inputs = batch["input"]
    elif isinstance(batch, (str, bytes)) or not isinstance(batch, Sequence):
        inputs = batch
    else:
        inputs = batch[0]

    if self.model is None and self.ema is None:
        raise ValueError("cannot run infer_step: model is not initialized")
    network = self.ema or self.model

    with self.forward_context():
        if isinstance(inputs, Mapping):
            # Mapping inputs are expanded into keyword arguments.
            pred = network(**inputs)
        else:
            pred = network(inputs)

    flat = pred.squeeze(-1).detach().cpu().tolist()
    # `tolist()` yields a bare number for 0-dim tensors; normalise to a list.
    return flat if isinstance(flat, list) else [float(flat)]

infer

Python
infer(
    split: str = "infer",
    *,
    steps: int | None = None,
    stream: bool | None = None
) -> list[float] | Iterator[list[float]]

Run inference on one split.

In non-stream mode this consumes all requested batches and returns a flattened Python list. In stream mode it returns an iterator of per-batch outputs and leaves consumption to the caller.

Called when: user code requests prediction-only execution.

Parameters:

Name Type Description Default
split
str

Inference split name.

'infer'
steps
int | None

Optional max number of batches to consume.

None
stream
bool | None

True returns a generator of per-batch outputs, False returns a flattened list. When None, stream only for unsized loaders without explicit steps.

None

Returns:

Type Description
list[float] | Iterator[list[float]]

Flattened predictions or a streaming iterator of batch predictions.

Raises:

Type Description
ValueError

steps is negative, or non-stream inference is requested for an unsized loader without an explicit step count.

Side effects: sets inference mode/split. It does not update metrics or optimizer state.

Source code in danling/runners/torch_runner.py
Python
def infer(
    self,
    split: str = "infer",
    *,
    steps: int | None = None,
    stream: bool | None = None,
) -> list[float] | Iterator[list[float]]:
    """
    Produce predictions for one split.

    Without streaming, every requested batch is consumed here and the
    per-batch outputs are concatenated into one flat Python list. With
    streaming, the per-batch iterator is handed back untouched and the
    caller drives consumption.

    **Called when:** user code requests prediction-only execution.

    Args:
        split: Inference split name.
        steps: Optional max number of batches to consume.
        stream: `True` returns a generator of per-batch outputs, `False` returns a flattened list.
            When `None`, stream only for unsized loaders without explicit `steps`.

    Returns:
        Flattened predictions or a streaming iterator of batch predictions.

    Raises:
        ValueError: `steps` is negative, or non-stream inference is
            requested for an unsized loader without an explicit step count.

    **Side effects:** sets inference mode/split. It does not update metrics
    or optimizer state.
    """

    self.mode = RunnerMode.infer
    self.split = split
    loader = self.dataloaders[split]
    if steps is not None and steps < 0:
        raise ValueError(f"invalid steps: expected a non-negative value, got {steps}")

    known_length = self._loader_length(loader)
    # "Unbounded" means nothing tells us how many batches will be produced.
    unbounded = steps is None and known_length is None
    if stream is None:
        stream = unbounded

    if unbounded and not stream:
        raise ValueError("infer with stream=False requires `steps` for unsized loaders")

    batches = self._iter_infer_batches(loader, steps=steps, split=split)
    if stream:
        return batches

    expected = known_length if steps is None else steps
    # Only the main process draws a progress bar in distributed runs.
    hide_bar = self.distributed and not self.is_main_process
    flattened: list[float] = []
    for values in tqdm(batches, total=expected, disable=hide_bar):
        flattened.extend(values)
    return flattened

state_dict

Python
state_dict(cls: type = dict) -> Mapping

Return the TorchRunner checkpoint payload.

Extends BaseRunner.state_dict with backend metadata plus EMA, optimizer, scheduler, and unwrapped model state.

Called when: checkpoint managers persist a TorchRunner checkpoint.

Parameters:

Name Type Description Default
cls
type

Mapping factory used for nested payloads.

dict

Returns:

Type Description
Mapping

Mapping containing base runner state and torch component state.

Side effects: snapshots Python/NumPy/Torch RNG state before exporting.

Source code in danling/runners/torch_runner.py
Python
def state_dict(self, cls: type = dict) -> Mapping:
    """
    Build the checkpoint payload for a TorchRunner.

    Starts from `BaseRunner.state_dict` and layers backend metadata and
    then the torch component state (EMA, optimizer, scheduler, and
    unwrapped model) on top of it.

    **Called when:** checkpoint managers persist a TorchRunner checkpoint.

    Args:
        cls: Mapping factory used for nested payloads.

    Returns:
        Mapping containing base runner state and torch component state.

    **Side effects:** snapshots Python/NumPy/Torch RNG state before
    exporting.
    """
    payload = cls(super().state_dict(cls))
    # Metadata first, components second: later updates win on key clashes.
    backend_metadata = self._export_checkpoint_metadata(cls)
    payload.update(backend_metadata)
    component_state = self._export_checkpoint_components(cls)
    payload.update(component_state)
    return payload

load_state_dict

Python
load_state_dict(checkpoint: Mapping[str, Any]) -> None

Restore base runner state plus Torch RNG state.

Model, optimizer, scheduler, and dataloader components are restored by load_checkpoint; this method owns only runner/RNG state.

Source code in danling/runners/torch_runner.py
Python
def load_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
    """
    Restore base runner state, then re-apply the Torch RNG state.

    Model, optimizer, scheduler, and dataloader components are restored by
    `load_checkpoint`; this method owns only runner/RNG state.

    A Torch generator is only reseeded when the checkpoint's `state.rng`
    mapping carries the matching key (`torch_cpu` / `torch_cuda`) and the
    corresponding value on `self.rng_state` is not `None`. The CUDA state
    is additionally skipped when CUDA is unavailable.
    """
    super().load_state_dict(checkpoint)
    saved_rng = (checkpoint.get("state") or {}).get("rng")
    has_saved_rng = isinstance(saved_rng, Mapping)

    # The values applied come from `self.rng_state`, not from the raw
    # checkpoint mapping — presumably refreshed by the base-class call above.
    if has_saved_rng and "torch_cpu" in saved_rng:
        cpu_state = self.rng_state.torch_cpu
        if cpu_state is not None:
            torch.set_rng_state(cpu_state)

    if torch.cuda.is_available() and has_saved_rng and "torch_cuda" in saved_rng:
        cuda_states = self.rng_state.torch_cuda
        if cuda_states is not None:
            torch.cuda.set_rng_state_all(cuda_states)

load_checkpoint

Python
load_checkpoint(
    checkpoint: Mapping | bytes | str | PathLike,
    *args,
    **kwargs
) -> None

Load a full checkpoint and rebind optimizer/scheduler helpers.

This delegates component restore to BaseRunner.load_checkpoint, then rebuilds the OptimizerContainer so scheduler and optimizer state stay bound after restore.

Source code in danling/runners/torch_runner.py
Python
def load_checkpoint(self, checkpoint: Mapping | bytes | str | os.PathLike, *args, **kwargs) -> None:
    """
    Restore a full checkpoint, then re-bind the optimizer container.

    Component restoration is handled entirely by
    `BaseRunner.load_checkpoint`. Afterwards the `OptimizerContainer` is
    rebuilt so that scheduler and optimizer state remain bound to the
    restored objects.

    Args:
        checkpoint: Checkpoint mapping or a bytes/str/path-like reference,
            forwarded unchanged to the base implementation.
        *args: Extra positional arguments forwarded to the base implementation.
        **kwargs: Extra keyword arguments forwarded to the base implementation.
    """
    super().load_checkpoint(checkpoint, *args, **kwargs)
    # Must run after the restore so the container wraps the restored state.
    self._bind_optimizer_container()

read_checkpoint

Python
read_checkpoint(
    checkpoint: Mapping | bytes | str | PathLike,
    *args,
    **kwargs
) -> Mapping[str, Any]

Read checkpoint payload from mapping/file/DCP directory input.

Source code in danling/runners/torch_runner.py
Python
def read_checkpoint(self, checkpoint: Mapping | bytes | str | os.PathLike, *args, **kwargs) -> Mapping[str, Any]:
    """Read checkpoint payload from mapping/file/DCP directory input."""
    if isinstance(checkpoint, Mapping):
        return checkpoint

    if self.config.checkpoint.backend.lower() == "dcp":
        return self.checkpoint_manager.load_checkpoint(checkpoint)
    return super().read_checkpoint(checkpoint, *args, **kwargs)

read_config classmethod

Python
read_config(
    checkpoint: Mapping | bytes | str | PathLike,
    *args,
    **kwargs
) -> RunnerConfig

Read runner config from checkpoint payload, including DCP directory inputs.

Source code in danling/runners/torch_runner.py
Python
@classmethod
def read_config(
    cls,
    checkpoint: Mapping | bytes | str | os.PathLike,
    *args,
    **kwargs,
) -> RunnerConfig:
    """
    Extract the runner config stored in a checkpoint.

    Non-mapping inputs that `TorchDistributedCheckpointManager` recognises
    as a DCP checkpoint path are read through that manager. Everything
    else, including in-memory mappings, is delegated to the base-class
    reader.

    Args:
        checkpoint: Checkpoint mapping or a bytes/str/path-like reference.
        *args: Forwarded to the base reader.
        **kwargs: Forwarded to the base reader.

    Returns:
        The runner config read from the checkpoint.
    """
    is_dcp_path = not isinstance(checkpoint, Mapping) and TorchDistributedCheckpointManager.is_checkpoint_path(
        checkpoint
    )
    if is_dcp_path:
        return TorchDistributedCheckpointManager.read_config(checkpoint)
    return super().read_config(checkpoint, *args, **kwargs)

close

Python
close(timeout: float | None = None) -> bool

Close runner resources.

Source code in danling/runners/torch_runner.py
Python
def close(self, timeout: float | None = None) -> bool:
    """
    Close runner resources.

    The base-class `close` runs first. Once it has succeeded or raised,
    the profiler is closed and the process group is destroyed. The process
    group is destroyed even when closing the profiler fails, so a profiler
    error cannot leak it.

    Args:
        timeout: Forwarded to the base-class `close`.

    Returns:
        `False` when the base-class `close` reports that it did not drain;
        in that case the profiler and process group are left untouched.
        Otherwise the base-class result.

    Raises:
        Exception: anything raised by the base-class `close` is re-raised
            after the backend resources have been released.
    """

    def _release_backend() -> None:
        # `finally` guarantees the process group goes away even if the
        # profiler shutdown raises.
        try:
            self._close_profiler()
        finally:
            self.destroy_process_group()

    try:
        drained = super().close(timeout=timeout)
    except Exception:
        _release_backend()
        raise
    if not drained:
        return False
    _release_backend()
    return drained

NestedTensor

Bases: Tensor

A container for variable-length tensors that enables efficient batch operations.

NestedTensor solves a fundamental problem in deep learning: handling sequences of different lengths in batch operations. Instead of excessive padding or complex bucketing, NestedTensor provides an elegant solution that maintains both efficiency and usability.

The class provides three main views of the data:

- .tensor: A padded tensor with zeros (or other value) in place of missing elements
- .mask: A boolean mask indicating which elements are real vs padding
- .concat: The packed tensor containing all elements concatenated without padding

When indexing a NestedTensor, the behavior depends on the index type:

1. Integer index (nt[0]): Returns a single tensor without padding
2. Slice index (nt[:]): Returns a new NestedTensor containing the selected batch elements
3. Tuple index (nt[:, 1:]): Returns a new NestedTensor with the specified sliced shape

Attributes:

Name Type Description
_values Tensor

Packed tensor data

_offsets Tensor

Top-level cumulative element counts, shape (B+1,)

_permutation tuple[int, ...]

Canonical logical-to-packed dimension permutation

_physical_shape Tensor

Per-element physical shapes, shape (B, max_ndim)

batch_first bool

Whether the first dimension is the batch dimension (B, N, *). If False, the first dimension is the sequence dimension (N, B, *).

padding_value float

Value used for padding in the padded tensor

mask_value bool

Boolean fill value for padding positions in generated masks.

- mask_value=False (default): valid positions are True and padding is False.
- mask_value=True: padding positions are True and valid positions are False.

Parameters:

Name Type Description Default

*tensors

Variable-length tensors or sequences to store

required

batch_first

Whether to use batch-first representation.

required

padding_value

Value to use for padding.

required

mask_value

Boolean fill value used for padding positions in masks.

required

Raises:

Type Description
ValueError

If tensors is not an iterable

Examples:

Basic usage:

Python Console Session
>>> nested_tensor = NestedTensor(torch.tensor([1, 2, 3]), torch.tensor([4, 5]))
>>> nested_tensor.shape
torch.Size([2, 3])
>>> nested_tensor.tensor  # Padded representation
tensor([[1, 2, 3],
        [4, 5, 0]])
>>> nested_tensor.mask  # Mask showing real vs padding values
tensor([[ True,  True,  True],
        [ True,  True, False]])
>>> nested_tensor.concat  # Concatenated version (no padding)
tensor([1, 2, 3, 4, 5])
Python Console Session
>>> nested_tensor[0]  # First tensor (no padding)
tensor([1, 2, 3])
>>> nested_tensor[:2]  # Returns a NestedTensor slice
NestedTensor([
    [1, 2, 3],
    [4, 5]
])
>>> nested_tensor[:, 1:]  # Slice operations return a new NestedTensor
NestedTensor([
    [2, 3],
    [5]
])

Type conversion:

Python Console Session
1
2
3
4
5
6
>>> nested_tensor.to(torch.float).tensor
tensor([[1., 2., 3.],
        [4., 5., 0.]])
>>> nested_tensor.half().tensor
tensor([[1., 2., 3.],
        [4., 5., 0.]], dtype=torch.float16)

Conversion to Python types:

Python Console Session
>>> nested_tensor.tolist()
[[1, 2, 3], [4, 5]]

Creating from Python lists:

Python Console Session
1
2
3
4
5
>>> NestedTensor(*[[1, 2, 3], [4, 5]])
NestedTensor([
    [1, 2, 3],
    [4, 5]
])
Source code in danling/tensors/nested_tensor.py
Python
  55
  56
  57
  58
  59
  60
  61
  62
  63
  64
  65
  66
  67
  68
  69
  70
  71
  72
  73
  74
  75
  76
  77
  78
  79
  80
  81
  82
  83
  84
  85
  86
  87
  88
  89
  90
  91
  92
  93
  94
  95
  96
  97
  98
  99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 220
 221
 222
 223
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 238
 239
 240
 241
 242
 243
 244
 245
 246
 247
 248
 249
 250
 251
 252
 253
 254
 255
 256
 257
 258
 259
 260
 261
 262
 263
 264
 265
 266
 267
 268
 269
 270
 271
 272
 273
 274
 275
 276
 277
 278
 279
 280
 281
 282
 283
 284
 285
 286
 287
 288
 289
 290
 291
 292
 293
 294
 295
 296
 297
 298
 299
 300
 301
 302
 303
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 332
 333
 334
 335
 336
 337
 338
 339
 340
 341
 342
 343
 344
 345
 346
 347
 348
 349
 350
 351
 352
 353
 354
 355
 356
 357
 358
 359
 360
 361
 362
 363
 364
 365
 366
 367
 368
 369
 370
 371
 372
 373
 374
 375
 376
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
2346
2347
2348
2349
2350
2351
2352
2353
2354
2355
2356
2357
2358
2359
2360
2361
2362
2363
2364
2365
2366
2367
2368
2369
2370
2371
2372
2373
2374
2375
2376
2377
2378
2379
2380
2381
2382
2383
2384
2385
2386
2387
2388
2389
2390
2391
2392
2393
2394
2395
2396
2397
2398
2399
2400
2401
2402
2403
2404
2405
2406
2407
2408
2409
2410
2411
2412
2413
2414
2415
2416
2417
2418
2419
2420
2421
2422
2423
2424
2425
2426
2427
2428
2429
2430
2431
2432
2433
2434
2435
2436
2437
2438
2439
2440
2441
2442
2443
2444
2445
2446
2447
2448
2449
2450
2451
2452
2453
2454
2455
2456
2457
2458
2459
2460
2461
2462
2463
2464
2465
2466
2467
2468
2469
2470
2471
2472
2473
2474
2475
2476
2477
2478
2479
2480
2481
2482
2483
2484
2485
2486
2487
2488
2489
2490
2491
2492
2493
2494
2495
2496
2497
2498
2499
2500
2501
2502
2503
2504
2505
2506
2507
2508
2509
2510
2511
2512
2513
2514
2515
2516
2517
2518
2519
2520
2521
2522
2523
2524
2525
2526
2527
2528
2529
2530
2531
2532
2533
2534
2535
2536
2537
2538
2539
2540
2541
2542
2543
2544
2545
2546
2547
2548
2549
2550
2551
2552
2553
2554
2555
2556
2557
2558
2559
2560
2561
2562
2563
2564
2565
2566
2567
2568
2569
2570
2571
2572
2573
2574
2575
2576
2577
2578
2579
2580
2581
2582
2583
2584
2585
2586
2587
2588
2589
2590
2591
2592
2593
2594
2595
2596
2597
2598
2599
2600
2601
2602
2603
2604
2605
2606
2607
2608
2609
2610
2611
2612
2613
2614
2615
2616
2617
2618
2619
2620
2621
2622
2623
2624
2625
2626
2627
2628
2629
2630
2631
2632
2633
2634
2635
2636
2637
2638
2639
2640
2641
2642
2643
2644
2645
2646
2647
2648
2649
2650
2651
2652
2653
2654
2655
2656
2657
2658
2659
2660
2661
2662
2663
2664
2665
2666
2667
2668
2669
2670
2671
2672
2673
2674
2675
2676
2677
2678
2679
2680
2681
2682
2683
2684
2685
2686
2687
2688
2689
2690
2691
2692
2693
2694
2695
2696
2697
2698
2699
2700
2701
2702
2703
2704
2705
2706
2707
2708
2709
2710
2711
2712
2713
2714
2715
2716
2717
2718
2719
2720
2721
2722
2723
2724
2725
2726
2727
2728
2729
2730
2731
2732
2733
2734
2735
2736
2737
2738
2739
2740
2741
2742
2743
2744
2745
2746
2747
2748
2749
2750
2751
2752
2753
2754
2755
2756
2757
2758
2759
2760
2761
2762
2763
2764
2765
2766
2767
2768
2769
2770
2771
2772
2773
2774
2775
2776
2777
2778
2779
2780
2781
2782
2783
2784
2785
2786
2787
2788
2789
2790
2791
2792
2793
2794
2795
2796
2797
2798
2799
2800
2801
2802
2803
2804
2805
2806
2807
2808
2809
2810
2811
2812
2813
2814
2815
2816
2817
2818
2819
2820
2821
2822
2823
2824
2825
2826
2827
2828
2829
2830
2831
2832
2833
2834
2835
2836
2837
2838
2839
2840
2841
2842
2843
2844
2845
2846
2847
2848
2849
2850
2851
2852
2853
2854
2855
2856
2857
2858
2859
2860
2861
2862
2863
2864
2865
2866
2867
2868
2869
2870
2871
2872
2873
2874
2875
2876
2877
2878
2879
2880
2881
2882
2883
2884
2885
2886
2887
2888
2889
2890
2891
2892
2893
2894
2895
2896
2897
2898
2899
2900
2901
2902
2903
2904
2905
2906
2907
2908
2909
2910
2911
2912
2913
2914
2915
2916
2917
2918
2919
2920
2921
2922
2923
2924
2925
2926
2927
2928
2929
2930
2931
2932
2933
2934
2935
2936
2937
2938
2939
2940
2941
2942
2943
2944
2945
2946
2947
2948
2949
2950
2951
2952
2953
2954
2955
2956
2957
2958
2959
2960
2961
2962
2963
2964
2965
2966
2967
2968
2969
2970
2971
2972
2973
2974
2975
2976
2977
2978
2979
2980
2981
2982
2983
2984
2985
2986
2987
2988
2989
2990
2991
2992
2993
2994
2995
2996
2997
2998
2999
3000
3001
3002
3003
3004
3005
3006
3007
3008
3009
3010
3011
3012
3013
3014
3015
3016
3017
3018
3019
3020
3021
3022
3023
3024
3025
3026
3027
3028
3029
3030
3031
3032
3033
3034
3035
3036
3037
3038
3039
3040
3041
3042
3043
3044
3045
3046
3047
3048
3049
3050
3051
3052
3053
3054
3055
3056
3057
3058
3059
3060
3061
3062
3063
3064
3065
3066
3067
3068
3069
3070
3071
3072
3073
3074
3075
3076
3077
3078
3079
3080
3081
3082
3083
3084
3085
3086
3087
3088
3089
3090
3091
3092
3093
class NestedTensor(torch.Tensor):
    r"""
    A container for variable-length tensors that enables efficient batch operations.

    `NestedTensor` solves a fundamental problem in deep learning: handling sequences of different lengths
    in batch operations. Instead of excessive padding or complex bucketing, `NestedTensor` provides an
    elegant solution that maintains both efficiency and usability.

    The class provides three main views of the data:
    - `.tensor`: A padded tensor with zeros (or other value) in place of missing elements
    - `.mask`: A boolean mask indicating which elements are real vs padding
    - `.concat`: The packed tensor containing all elements concatenated without padding

    When indexing a `NestedTensor`, the behavior depends on the index type:
    1. Integer index (`nt[0]`): Returns a single tensor without padding
    2. Slice index (`nt[:]`): Returns a new `NestedTensor` containing the selected batch elements
    3. Tuple index (`nt[:, 1:]`): Returns a new `NestedTensor` with the specified sliced shape

    Attributes:
        _values: Packed tensor data
        _offsets: Top-level cumulative element counts, shape (B+1,)
        _permutation: Canonical logical-to-packed dimension permutation
        _physical_shape: Per-element physical shapes, shape (B, max_ndim)
        batch_first: Whether the first dimension is the batch dimension (B, N, *)
            If `False`, the first dimension is the sequence dimension (N, B, *)
        padding_value: Value used for padding in the padded tensor
        mask_value: Boolean fill value for padding positions in generated masks.
            - ``mask_value=False`` (default): valid positions are ``True`` and padding is ``False``.
            - ``mask_value=True``: padding positions are ``True`` and valid positions are ``False``.

    Args:
        *tensors: Variable-length tensors or sequences to store
        batch_first: Whether to use batch-first representation.
        padding_value: Value to use for padding.
        mask_value: Boolean fill value used for padding positions in masks.

    Raises:
        ValueError: If `tensors` is not an iterable

    Examples:
        Basic usage:
        >>> nested_tensor = NestedTensor(torch.tensor([1, 2, 3]), torch.tensor([4, 5]))
        >>> nested_tensor.shape
        torch.Size([2, 3])
        >>> nested_tensor.tensor  # Padded representation
        tensor([[1, 2, 3],
                [4, 5, 0]])
        >>> nested_tensor.mask  # Mask showing real vs padding values
        tensor([[ True,  True,  True],
                [ True,  True, False]])
        >>> nested_tensor.concat  # Concatenated version (no padding)
        tensor([1, 2, 3, 4, 5])

        Indexing:
        >>> nested_tensor[0]  # First tensor (no padding)
        tensor([1, 2, 3])
        >>> nested_tensor[:2]  # Returns a NestedTensor slice
        NestedTensor([
            [1, 2, 3],
            [4, 5]
        ])
        >>> nested_tensor[:, 1:]  # Slice operations return a new NestedTensor
        NestedTensor([
            [2, 3],
            [5]
        ])

        Type conversion:
        >>> nested_tensor.to(torch.float).tensor
        tensor([[1., 2., 3.],
                [4., 5., 0.]])
        >>> nested_tensor.half().tensor
        tensor([[1., 2., 3.],
                [4., 5., 0.]], dtype=torch.float16)

        Conversion to Python types:
        >>> nested_tensor.tolist()
        [[1, 2, 3], [4, 5]]

        Creating from Python lists:
        >>> NestedTensor(*[[1, 2, 3], [4, 5]])
        NestedTensor([
            [1, 2, 3],
            [4, 5]
        ])
    """

    _values: Tensor
    _offsets: Tensor
    _permutation: tuple[int, ...]
    _physical_shape: Tensor
    _flatten_sentinel: Tensor = torch.empty(0)
    _logical_shape: torch.Size
    _batch_first: bool
    _padding_value: float
    _mask_value: bool
    _pin_memory: bool
    _packed_sizes: tuple[int, ...] | None
    _element_shapes: tuple[tuple[int, ...], ...] | None
    _cached_storage: tuple[Tensor, ...] | None
    _cached_hierarchical_offsets: tuple[Tensor, ...] | None
    _cached_tensor_view: tuple[bool, float, tuple[int, int, int], Tensor] | None
    _cached_mask_view: tuple[bool, bool, tuple[int, int], Tensor] | None
    _SERIALIZATION_VERSION = 1

    # Construction & Initialization

    @staticmethod
    def __new__(
        cls,
        *tensors: Iterable[Tensor],
        dtype: torch.dtype | None = None,
        device: torch.device | None = None,
        requires_grad: bool | None = None,
        pin_memory: bool = False,
        batch_first: bool = True,
        padding_value: SupportsFloat = 0.0,
        mask_value: bool = False,
    ):
        r"""
        Build a `NestedTensor` wrapper around the packed form of `tensors`.

        The elements are coerced with `_coerce_tensors`, packed with `_pack`, wrapped with
        `torch.Tensor._make_wrapper_subclass`, and finally checked with `_validate_packed_metadata`.

        Args:
            *tensors: Elements to store. A single positional argument that is not a `Tensor`
                is treated as the iterable of elements and unpacked.
            dtype: Forwarded to `_coerce_tensors`; also the dtype of the wrapper when there are no elements
                (falling back to `torch.get_default_dtype()`).
            device: Forwarded to `_coerce_tensors`; also the device of the wrapper when there are no elements
                (falling back to CPU).
            requires_grad: Forwarded to `_coerce_tensors` and, when not `None`, enforced on the packed values.
            pin_memory: Forwarded to `_coerce_tensors` and `_maybe_pin_values`.
            batch_first: Stored through `_set_runtime_config` and used to compute the logical shape.
            padding_value: Stored through `_set_runtime_config`.
            mask_value: Stored through `_set_runtime_config`.

        Raises:
            ValueError: If the single positional argument is neither a `Tensor` nor an `Iterable`.
        """
        # A lone non-Tensor argument is interpreted as the collection of elements itself.
        if len(tensors) == 1 and not isinstance(tensors[0], Tensor):
            if isinstance(tensors[0], Iterable):
                tensors = tuple(tensors[0])  # type: ignore
            else:
                raise ValueError(f"tensors must be an Iterable, but got {type(tensors[0])}.")

        # Validate and convert tensors
        validated = cls._coerce_tensors(
            tensors, dtype=dtype, device=device, requires_grad=requires_grad, pin_memory=pin_memory
        )

        # Determine dtype/device from validated tensors or fallbacks
        out_dtype = validated[0].dtype if validated else (dtype or torch.get_default_dtype())
        out_device = validated[0].device if validated else (device or torch.device("cpu"))

        # Pack into values, offsets, tensor-shape metadata, and Python metadata.
        values, offsets, shape_tensor, packed_sizes, element_shapes = cls._pack(
            validated,
            dtype=out_dtype,
            device=out_device,
        )
        values = cls._maybe_pin_values(values, pin_memory)
        permutation = cls._permutation_from_element_shapes(element_shapes)

        # Compute logical shape
        logical_shape = cls._compute_logical_shape(validated, batch_first)
        # Only touch the flag when the caller asked for a value that differs from the packed result.
        if requires_grad is not None and values.requires_grad != requires_grad:
            values.requires_grad_(requires_grad)
        out_requires_grad = values.requires_grad

        result = torch.Tensor._make_wrapper_subclass(
            cls,
            logical_shape,
            dtype=out_dtype,
            device=out_device,
            requires_grad=out_requires_grad,
        )
        result._values = values
        result._offsets = offsets
        result._permutation = permutation
        result._physical_shape = shape_tensor
        result._logical_shape = logical_shape
        result._set_runtime_config(
            batch_first=batch_first,
            padding_value=padding_value,
            mask_value=mask_value,
        )
        # Record pinning only when the packed values really ended up pinned on CPU.
        result._pin_memory = bool(pin_memory and values.device.type == "cpu" and values.is_pinned())
        result._packed_sizes = packed_sizes
        result._element_shapes = element_shapes
        # Start with every lazily computed view cleared.
        result._invalidate_transient_caches()
        cls._validate_packed_metadata(
            result._values,
            result._offsets,
            result._physical_shape,
            permutation=result._permutation,
            logical_shape=result._logical_shape,
            batch_first=result.batch_first,
            packed_sizes=result._packed_sizes,
            element_shapes=result._element_shapes,
        )
        return result

    def __init__(self, *args, **kwargs):
        pass  # All init in __new__

    # ------------------------------------------------------------------
    # Packed representation helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _coerce_tensors(
        tensors: tuple,
        *,
        dtype: torch.dtype | None = None,
        device: torch.device | None = None,
        requires_grad: bool | None = None,
        pin_memory: bool = False,
    ) -> tuple[Tensor, ...]:
        r"""
        Convert every element to a `Tensor` and check that the elements are mutually compatible.

        Args:
            tensors: Elements to convert. A `Tensor` input is split with `unbind()`.
            dtype: Target dtype. When `None`, the elements are promoted to a common dtype
                with `torch.promote_types`.
            device: Target device passed to `torch.tensor` / `Tensor.to`.
            requires_grad: When not `None`, applied to every element with `requires_grad_`.
            pin_memory: Passed to `torch.tensor` for elements that are not already tensors.

        Returns:
            A tuple of tensors sharing one device, one rank and (after promotion) one dtype;
            an empty tuple when there are no elements.

        Raises:
            ValueError: If `tensors` is not iterable, or the elements differ in device or in `ndim`.
        """
        if not isinstance(tensors, Iterable):
            raise ValueError(f"tensors must be an Iterable, but got {type(tensors)}.")
        if isinstance(tensors, Tensor) and hasattr(tensors, "unbind"):
            tensors = tensors.unbind()

        result: list[Tensor] = []
        common_device: torch.device | None = None
        common_ndim: int | None = None
        # Only track dtype promotion when the caller did not specify an explicit dtype.
        # When dtype is given, t.to(device, dtype=dtype) already handles casting in
        # the first pass, so the promotion loop and second pass are both unnecessary.
        needs_promotion = dtype is None
        common_dtype: torch.dtype | None = None

        for t in tensors:
            if not isinstance(t, Tensor):
                # NOTE(review): pin_memory is forwarded together with `device`; presumably this fails for a
                # non-CPU device, while `_maybe_pin_values` guards on CPU - confirm the intended behaviour.
                t = torch.tensor(t, dtype=dtype, device=device, pin_memory=pin_memory)
            else:
                t = t.to(device, dtype=dtype)
            if requires_grad is not None:
                # NOTE(review): `Tensor.to` can return the input tensor itself when nothing needs converting,
                # so this in-place call may also flip the flag on the caller's tensor - confirm this is intended.
                t.requires_grad_(requires_grad)

            # The first element fixes the reference device; later elements must match it.
            if common_device is None:
                common_device = t.device
            elif t.device != common_device:
                raise ValueError(
                    f"All tensors in NestedTensor must be on the same device, but got {common_device} and {t.device}"
                )

            if needs_promotion:
                if common_dtype is None:
                    common_dtype = t.dtype
                else:
                    common_dtype = torch.promote_types(common_dtype, t.dtype)

            # The first element fixes the reference rank; later elements must match it.
            if common_ndim is None:
                common_ndim = t.ndim
            elif t.ndim != common_ndim:
                raise ValueError(
                    f"All tensors must have the same number of dimensions, got ndim {common_ndim} and {t.ndim}. "
                    "If using a DataLoader with drop_last=False, squeeze the last batch before constructing "
                    "NestedTensor."
                )

            result.append(t)

        if not result:
            return ()

        # Second pass only when dtype=None AND promotion actually changed the dtype.
        if needs_promotion and common_dtype is not None and any(t.dtype != common_dtype for t in result):
            return tuple(t.to(dtype=common_dtype) for t in result)
        return tuple(result)

    @staticmethod
    def _pack(
        tensors: tuple[Tensor, ...],
        *,
        dtype: torch.dtype | None = None,
        device: torch.device | None = None,
        permutation: tuple[int, ...] | None = None,
    ) -> tuple[Tensor, Tensor, Tensor, tuple[int, ...], tuple[tuple[int, ...], ...]]:
        r"""
        Pack a sequence of tensors into values, offsets, tensor metadata, and Python metadata.

        Args:
            tensors: Elements to pack.
            dtype: Dtype of the empty values tensor; only read when `tensors` is empty.
            device: Device of the empty values tensor; only read when `tensors` is empty.
            permutation: Optional explicit dimension order. When `None`, the order is the varying
                dimensions followed by the static ones, as reported by `_pack_layout_from_element_shapes`.

        Returns:
            A tuple `(values, offsets, shape_tensor, packed_sizes, element_shapes)`:
            `values` is the concatenation of the per-element packed tensors along dim 0,
            `offsets` holds a leading 0 followed by the cumulative packed sizes,
            `shape_tensor` holds one row of dimension sizes per element,
            `packed_sizes` and `element_shapes` are the same information as Python tuples.

        Raises:
            ValueError: If `permutation` is given but is not a permutation of `range(max_ndim)`.
        """
        if not tensors:
            # Empty batch: zero-length values, a single zero offset, and a (0, 0) shape table.
            return (
                torch.empty(0, dtype=dtype or torch.get_default_dtype(), device=device),
                torch.zeros(1, dtype=torch.long),
                torch.empty(0, 0, dtype=torch.long),
                (),
                (),
            )

        max_ndim = max(t.ndim for t in tensors)
        element_shapes = tuple(tuple(int(dim) for dim in t.shape) for t in tensors)

        # Offsets and shape_tensor are metadata - always on CPU to avoid CUDA syncs.
        # Rows shorter than max_ndim are right-padded with zeros.
        shape_tensor = torch.tensor([list(t.shape) + [0] * (max_ndim - t.ndim) for t in tensors], dtype=torch.long)
        if max_ndim == 0:
            # Scalars: each element contributes exactly one packed entry.
            values = torch.stack(tensors)
            sizes = torch.ones(len(tensors), dtype=torch.long)
            packed_sizes = tuple(1 for _ in tensors)
        else:
            if permutation is None:
                varying_dims, static_dims = NestedTensor._pack_layout_from_element_shapes(element_shapes)
                permutation = varying_dims + static_dims
            else:
                permutation = tuple(int(dim) for dim in permutation)
                if len(permutation) != max_ndim or tuple(sorted(permutation)) != tuple(range(max_ndim)):
                    raise ValueError(f"Invalid permutation dims {permutation} for tensors with rank {max_ndim}")
                # The leading entries of an explicit permutation are taken as the varying dims;
                # their count comes from the hierarchical level sizes of the element shapes.
                ragged_rank = len(NestedTensor._hierarchical_level_sizes_from_element_shapes(element_shapes))
                varying_dims = permutation[:ragged_rank]
                static_dims = permutation[ragged_rank:]
            packed = []
            packed_sizes_list = []
            identity_permutation = tuple(range(max_ndim))
            for tensor, shape in zip(tensors, element_shapes):
                packed_size = NestedTensor._packed_size_from_shape(shape, varying_dims)
                packed_sizes_list.append(packed_size)
                # Skip the permute call when the layout is already in packed order.
                packed_tensor = tensor if permutation == identity_permutation else tensor.permute(permutation)
                suffix_shape = tuple(shape[dim] for dim in static_dims)
                # Collapse the varying dims into one leading axis and keep the static dims as the suffix.
                packed.append(packed_tensor.reshape((packed_size, *suffix_shape) if suffix_shape else (packed_size,)))
            values = torch.cat(packed, dim=0)
            sizes = torch.tensor(packed_sizes_list, dtype=torch.long)
            packed_sizes = tuple(packed_sizes_list)
        # offsets[0] stays 0; offsets[1:] receives the running sum of the packed sizes.
        offsets = torch.zeros(len(tensors) + 1, dtype=torch.long)
        torch.cumsum(sizes, dim=0, out=offsets[1:])

        return values, offsets, shape_tensor, packed_sizes, element_shapes

    @staticmethod
    def _maybe_pin_values(values: Tensor, pin_memory: bool) -> Tensor:
        r"""Pin packed storage when requested and the values live on CPU."""
        if pin_memory and values.device.type == "cpu" and not values.is_pinned():
            return values.pin_memory()
        return values

    @staticmethod
    def _trim_shape(shape: Sequence[int]) -> tuple[int, ...]:
        end = len(shape)
        while end > 0 and shape[end - 1] == 0:
            end -= 1
        return tuple(int(shape[i]) for i in range(end))

    @staticmethod
    def _shape_numel(shape: tuple[int, ...]) -> int:
        size = 1
        for dim in shape:
            size *= int(dim)
        return size

    @classmethod
    def _permutation_from_element_shapes(cls, element_shapes: tuple[tuple[int, ...], ...]) -> tuple[int, ...]:
        varying_dims, static_dims = cls._pack_layout_from_element_shapes(element_shapes)
        return varying_dims + static_dims

    @classmethod
    def _permutation_from_physical_shape(
        cls,
        physical_shape: Tensor,
        element_shapes: tuple[tuple[int, ...], ...] | None,
    ) -> tuple[int, ...]:
        varying_dims, static_dims = cls._pack_layout_meta(physical_shape, element_shapes)
        return varying_dims + static_dims

    @staticmethod
    def _offsets_from_sizes(sizes: Sequence[int], *, dtype: torch.dtype = torch.long) -> Tensor:
        offsets = torch.empty((len(sizes) + 1,), dtype=dtype)
        offsets[0] = 0
        if sizes:
            offsets[1:] = torch.cumsum(torch.tensor(sizes, dtype=dtype), dim=0)
        return offsets

    @staticmethod
    def _meta_tensor_equal(lhs: Tensor, rhs: Tensor) -> bool:
        r"""
        Compare two metadata tensors.

        When either operand is a fake tensor only identity counts; otherwise the tensors are equal
        when they are the same object or have the same shape and `torch.equal` holds.
        """
        if _is_fake_tensor(lhs) or _is_fake_tensor(rhs):
            return lhs is rhs
        return lhs is rhs or (lhs.shape == rhs.shape and bool(torch.equal(lhs, rhs)))

    @classmethod
    def _hierarchical_level_sizes_from_element_shapes(
        cls,
        element_shapes: tuple[tuple[int, ...], ...],
    ) -> tuple[tuple[int, ...], ...]:
        if not element_shapes:
            return ()
        varying_dims, _ = cls._pack_layout_from_element_shapes(element_shapes)
        if not varying_dims:
            return ()

        level_sizes: list[tuple[int, ...]] = []
        prefix_products = [1] * len(element_shapes)
        for dim in varying_dims:
            sizes: list[int] = []
            next_prefix_products: list[int] = []
            for shape, prefix in zip(element_shapes, prefix_products):
                dim_size = int(shape[dim])
                sizes.extend([dim_size] * prefix)
                next_prefix_products.append(prefix * dim_size)
            level_sizes.append(tuple(sizes))
            prefix_products = next_prefix_products
        return tuple(level_sizes)

    @classmethod
    def _hierarchical_level_sizes_from_physical_shape(
        cls,
        physical_shape: Tensor,
        element_shapes: tuple[tuple[int, ...], ...] | None = None,
    ) -> tuple[tuple[int, ...], ...]:
        if physical_shape.numel() == 0:
            return ()
        if element_shapes is not None:
            return cls._hierarchical_level_sizes_from_element_shapes(element_shapes)
        if _is_fake_tensor(physical_shape):
            return ()

        varying_dims, _ = cls._pack_layout_meta(physical_shape, None)
        if not varying_dims:
            return ()

        shape_rows = tuple(cls._trim_shape(row) for row in physical_shape.tolist())
        level_sizes: list[tuple[int, ...]] = []
        prefix_products = [1] * len(shape_rows)
        for dim in varying_dims:
            sizes: list[int] = []
            next_prefix_products: list[int] = []
            for shape, prefix in zip(shape_rows, prefix_products):
                dim_size = int(shape[dim]) if dim < len(shape) else 0
                sizes.extend([dim_size] * prefix)
                next_prefix_products.append(prefix * dim_size)
            level_sizes.append(tuple(sizes))
            prefix_products = next_prefix_products
        return tuple(level_sizes)

    @staticmethod
    def _inverse_permutation(permutation: tuple[int, ...]) -> tuple[int, ...]:
        inverse = [0] * len(permutation)
        for index, dim in enumerate(permutation):
            inverse[dim] = index
        return tuple(inverse)

    @classmethod
    def _pack_layout_from_element_shapes(
        cls,
        element_shapes: tuple[tuple[int, ...], ...],
    ) -> tuple[tuple[int, ...], tuple[int, ...]]:
        if not element_shapes:
            return (), ()
        ndim = len(element_shapes[0])
        if ndim == 0:
            return (), ()
        reference = element_shapes[0]
        static_dims = [
            dim
            for dim in range(ndim)
            if all(len(shape) == ndim and shape[dim] == reference[dim] for shape in element_shapes[1:])
        ]
        if len(static_dims) == ndim:
            static_dims = list(range(1, ndim))
        static_dims_tuple = tuple(static_dims)
        varying_dims = tuple(dim for dim in range(ndim) if dim not in static_dims_tuple)
        return varying_dims, static_dims_tuple

    @classmethod
    def _pack_layout_meta(
        cls,
        physical_shape: Tensor,
        element_shapes: tuple[tuple[int, ...], ...] | None,
    ) -> tuple[tuple[int, ...], tuple[int, ...]]:
        if element_shapes is not None and (element_shapes or int(physical_shape.size(1)) == 0):
            return cls._pack_layout_from_element_shapes(element_shapes)
        ndim = int(physical_shape.size(1))
        if ndim == 0:
            return (), ()
        if physical_shape.size(0) == 0:
            return (0,), tuple(range(1, ndim))
        static_dims = tuple(
            dim
            for dim in range(ndim)
            if bool(torch.equal(physical_shape[:, dim], physical_shape[:1, dim].expand(physical_shape.size(0))))
        )
        if len(static_dims) == ndim:
            static_dims = tuple(range(1, ndim))
        varying_dims = tuple(dim for dim in range(ndim) if dim not in static_dims)
        return varying_dims, static_dims

    @staticmethod
    def _packed_size_from_shape(shape: tuple[int, ...], varying_dims: tuple[int, ...]) -> int:
        if not shape or not varying_dims:
            return 1
        size = 1
        for dim in varying_dims:
            size *= int(shape[dim])
        return size

    @classmethod
    def _python_meta_from_packed(
        cls,
        values: Tensor,
        offsets: Tensor,
        shape_tensor: Tensor,
        *,
        packed_sizes: tuple[int, ...] | None = None,
        element_shapes: tuple[tuple[int, ...], ...] | None = None,
    ) -> tuple[tuple[int, ...], tuple[tuple[int, ...], ...]]:
        if packed_sizes is None:
            packed_sizes = tuple(int(size) for size in (offsets[1:] - offsets[:-1]).tolist())
        if element_shapes is None:
            element_shapes = tuple(cls._trim_shape(shape) for shape in shape_tensor.tolist())
        return packed_sizes, element_shapes

    @classmethod
    @torch._dynamo.disable
    def _infer_python_meta_from_packed(
        cls,
        values: Tensor,
        offsets: Tensor,
        shape_tensor: Tensor,
        *,
        packed_sizes: tuple[int, ...] | None = None,
        element_shapes: tuple[tuple[int, ...], ...] | None = None,
    ) -> tuple[tuple[int, ...], tuple[tuple[int, ...], ...]]:
        return cls._python_meta_from_packed(
            values,
            offsets,
            shape_tensor,
            packed_sizes=packed_sizes,
            element_shapes=element_shapes,
        )

    @staticmethod
    def _compute_logical_shape(tensors: tuple[Tensor, ...], batch_first: bool) -> torch.Size:
        r"""Compute the logical shape [B, max_d0, max_d1, ...] from individual tensors."""
        if not tensors:
            return torch.Size((0,))
        if max(t.dim() for t in tensors) == 0:
            return torch.Size((len(tensors),))
        ndim = max(t.dim() for t in tensors)
        size = [max(t.shape[i] if i < len(t.shape) else 0 for t in tensors) for i in range(ndim)]
        size.insert(0 if batch_first else 1, len(tensors))
        return torch.Size(size)

    @staticmethod
    def _logical_shape_from_physical_shape(physical_shape: Tensor, offsets: Tensor, batch_first: bool) -> torch.Size:
        r"""Compute logical shape from packed metadata without unpacking elements."""
        batch_size = len(offsets) - 1
        if batch_size == 0:
            return torch.Size((0,))
        if physical_shape.numel() == 0:
            return torch.Size((batch_size,))
        size = [int(physical_shape[:, d].max().item()) for d in range(physical_shape.size(1))]
        while size and size[-1] == 0:
            size.pop()
        size.insert(0 if batch_first else 1, batch_size)
        return torch.Size(size)

    @staticmethod
    def _batch_dim_from_logical_shape(logical_shape: torch.Size, batch_first: bool) -> int:
        r"""Return the batch dimension index for a logical NestedTensor shape."""
        return 0 if len(logical_shape) <= 1 or batch_first else 1

    @classmethod
    def _validate_packed_metadata(
        cls,
        values: Tensor,
        offsets: Tensor,
        shape_tensor: Tensor,
        *,
        permutation: tuple[int, ...],
        logical_shape: torch.Size,
        batch_first: bool,
        packed_sizes: tuple[int, ...] | None,
        element_shapes: tuple[tuple[int, ...], ...] | None,
    ) -> None:
        r"""
        Validate that packed storage and metadata describe a coherent NestedTensor layout.

        Args:
            values: Packed values; only the length of its first dimension is checked.
            offsets: 1-D integer CPU tensor with one more entry than there are elements.
            shape_tensor: 2-D integer CPU tensor with one row per element.
            permutation: Must be a permutation of ``range(shape_tensor.size(1))``.
            logical_shape: Must have one more dimension than ``shape_tensor`` has columns, and its
                batch dimension must equal the number of rows of ``shape_tensor``.
            batch_first: Used to locate the batch dimension of ``logical_shape``.
            packed_sizes: Optional per-element packed sizes; checked for count, sign and total.
            element_shapes: Optional per-element shapes; checked for count, rank, sign and
                (for a non-fake ``shape_tensor``) equality with the rows of ``shape_tensor``.

        Raises:
            ValueError: On the first inconsistency found.
        """
        # Structural checks: device, rank and dtype of the two metadata tensors.
        if offsets.device.type != "cpu":
            raise ValueError(f"offsets must be on CPU, got {offsets.device}")
        if shape_tensor.device.type != "cpu":
            raise ValueError(f"shape_tensor must be on CPU, got {shape_tensor.device}")
        if offsets.dim() != 1:
            raise ValueError(f"offsets must be 1-D, got shape {tuple(offsets.shape)}")
        if shape_tensor.dim() != 2:
            raise ValueError(f"shape_tensor must be 2-D, got shape {tuple(shape_tensor.shape)}")
        if offsets.dtype.is_floating_point or offsets.dtype.is_complex or offsets.dtype == torch.bool:
            raise ValueError(f"offsets must use an integer dtype, got {offsets.dtype}")
        if shape_tensor.dtype.is_floating_point or shape_tensor.dtype.is_complex or shape_tensor.dtype == torch.bool:
            raise ValueError(f"shape_tensor must use an integer dtype, got {shape_tensor.dtype}")

        # The number of rows of shape_tensor is the reference batch size for everything below.
        batch_size = int(shape_tensor.size(0))
        if offsets.numel() != batch_size + 1:
            raise ValueError(
                "offsets length must equal batch size + 1, got "
                f"offsets.numel()={offsets.numel()}, batch_size={batch_size}"
            )

        physical_rank = int(shape_tensor.size(1))
        if len(logical_shape) != physical_rank + 1:
            raise ValueError(
                "logical shape rank must equal physical rank + 1, got "
                f"logical rank={len(logical_shape)}, physical rank={physical_rank}"
            )
        batch_dim = cls._batch_dim_from_logical_shape(logical_shape, batch_first)
        logical_batch = logical_shape[batch_dim]
        if logical_batch != batch_size:
            raise ValueError(f"logical batch size {logical_batch} does not match metadata batch size {batch_size}")

        # The permutation must list every physical dimension exactly once.
        if len(permutation) != physical_rank or tuple(sorted(int(dim) for dim in permutation)) != tuple(
            range(physical_rank)
        ):
            raise ValueError(f"Invalid permutation dims {permutation} for shape with {physical_rank} dims")

        if packed_sizes is not None:
            if len(packed_sizes) != batch_size:
                raise ValueError(
                    f"packed_sizes must have one entry per element, got {len(packed_sizes)} for batch size {batch_size}"
                )
            if any(int(size) < 0 for size in packed_sizes):
                raise ValueError("packed_sizes must be non-negative")
            if sum(int(size) for size in packed_sizes) != int(values.shape[0]):
                raise ValueError("packed_sizes must sum to the packed values length")

        if element_shapes is not None:
            if len(element_shapes) != batch_size:
                raise ValueError(
                    "element_shapes must have one entry per element, got "
                    f"{len(element_shapes)} for batch size {batch_size}"
                )
            normalized_shapes = tuple(tuple(int(dim) for dim in shape) for shape in element_shapes)
            if any(len(shape) != physical_rank for shape in normalized_shapes):
                raise ValueError(
                    f"element_shapes rank must match physical rank {physical_rank}, got {normalized_shapes}"
                )
            if any(any(dim < 0 for dim in shape) for shape in normalized_shapes):
                raise ValueError("element_shapes must be non-negative")
            # Reading shape_tensor values is skipped for fake tensors.
            if not _is_fake_tensor(shape_tensor):
                shape_rows = tuple(tuple(int(size) for size in row) for row in shape_tensor.tolist())
                if normalized_shapes != shape_rows:
                    raise ValueError("element_shapes must match shape_tensor exactly")

        # The remaining checks read tensor values, so they are skipped when either tensor is fake.
        if _is_fake_tensor(offsets) or _is_fake_tensor(shape_tensor):
            return

        if bool((shape_tensor < 0).any()):
            raise ValueError("shape_tensor must be non-negative")
        if int(offsets[0].item()) != 0:
            raise ValueError("offsets must start at 0")
        deltas = offsets[1:] - offsets[:-1]
        if bool((deltas < 0).any()):
            raise ValueError("offsets must be monotonically non-decreasing")
        # The last offset is compared with the values length only when packed_sizes was not supplied.
        if packed_sizes is None and int(offsets[-1].item()) != int(values.shape[0]):
            raise ValueError(
                f"offsets[-1] must equal packed values length, got offsets[-1]={int(offsets[-1].item())} "
                f"and values.shape[0]={int(values.shape[0])}"
            )

    def _validate_metadata(self) -> None:
        r"""Run ``_validate_packed_metadata`` against this instance's current storage and metadata."""
        storage = (self._values, self._offsets, self._physical_shape)
        metadata = {
            "permutation": self._permutation,
            "logical_shape": self._logical_shape,
            "batch_first": self.batch_first,
            "packed_sizes": self._packed_sizes,
            "element_shapes": self._element_shapes,
        }
        type(self)._validate_packed_metadata(*storage, **metadata)

    @staticmethod
    def _coerce_batch_first(value: bool) -> bool:
        if not isinstance(value, bool):
            raise TypeError(f"batch_first must be a bool, got {type(value).__name__}")
        return value

    @staticmethod
    def _coerce_mask_value(value: bool) -> bool:
        if not isinstance(value, bool):
            raise TypeError(f"mask_value must be a bool, got {type(value).__name__}")
        return value

    @staticmethod
    def _coerce_padding_value(value: SupportsFloat) -> float:
        try:
            return float(value)
        except (TypeError, ValueError) as exc:
            raise TypeError(f"padding_value must be float-convertible, got {type(value).__name__}") from exc

    def _set_runtime_config(
        self,
        *,
        batch_first: bool,
        padding_value: SupportsFloat,
        mask_value: bool,
    ) -> None:
        self._batch_first = type(self)._coerce_batch_first(batch_first)
        self._padding_value = type(self)._coerce_padding_value(padding_value)
        self._mask_value = type(self)._coerce_mask_value(mask_value)

    def _invalidate_transient_caches(self) -> None:
        r"""Drop all lazily materialized views derived from packed storage."""
        self._cached_storage = None
        self._cached_hierarchical_offsets = None
        self._cached_tensor_view = None
        self._cached_mask_view = None

    def _values_cache_token(self) -> tuple[int, int, int]:
        r"""Return a cache token for views that depend on packed values and layout metadata.

        Under ``torch.inference_mode`` tensors do not track version counters and
        in-place mutation is forbidden, so the cache is always valid.
        """
        if torch.is_inference_mode_enabled():
            return (0, 0, 0)
        return (int(self._values._version), int(self._offsets._version), int(self._physical_shape._version))

    def _shape_cache_token(self) -> tuple[int, int]:
        r"""Return a cache token for views that depend only on shape metadata."""
        if torch.is_inference_mode_enabled():
            return (0, 0)
        return (int(self._offsets._version), int(self._physical_shape._version))

    @classmethod
    def _validate_serialized_state(cls, state: Mapping) -> None:
        required = (
            "_state_version",
            "_values",
            "_offsets",
            "_permutation",
            "_physical_shape",
            "_logical_shape",
            "batch_first",
            "padding_value",
            "mask_value",
            "_pin_memory",
            "_packed_sizes",
            "_element_shapes",
        )
        missing = [key for key in required if key not in state]
        if missing:
            raise KeyError(f"Serialized NestedTensor state is missing required keys: {', '.join(missing)}")
        version = state["_state_version"]
        if version != cls._SERIALIZATION_VERSION:
            raise ValueError(f"Unsupported NestedTensor state version {version}; expected {cls._SERIALIZATION_VERSION}")

    @classmethod
    @torch._dynamo.disable
    def _from_packed(
        cls,
        values: Tensor,
        offsets: Tensor,
        shape_tensor: Tensor,
        *,
        permutation: tuple[int, ...] | None = None,
        batch_first: bool = True,
        padding_value: float = 0.0,
        mask_value: bool = False,
        pin_memory: bool = False,
        outer_size: torch.Size | tuple | None = None,
        packed_sizes: tuple[int, ...] | None = None,
        element_shapes: tuple[tuple[int, ...], ...] | None = None,
    ) -> Self:
        r"""Construct a NestedTensor directly from packed representation.

        Args:
            values: Packed storage. Its dtype, device and ``requires_grad`` become those of the wrapper.
            offsets: CPU tensor of packed boundaries; consecutive differences are used as the
                per-element packed sizes when ``packed_sizes`` is not given.
            shape_tensor: CPU tensor whose rows are read as per-element shapes when
                ``element_shapes`` is not given.
            permutation: Explicit layout permutation. When ``None`` it is derived via
                ``_permutation_from_physical_shape``.
            batch_first: Runtime flag, validated by ``_set_runtime_config``.
            padding_value: Runtime padding value, validated by ``_set_runtime_config``.
            mask_value: Runtime mask flag, validated by ``_set_runtime_config``.
            pin_memory: Forwarded to ``_maybe_pin_values``. The stored ``_pin_memory`` flag is only
                true when the resulting values live on CPU and report ``is_pinned()``.
            outer_size: Logical (outer) shape. When ``None`` it is derived via
                ``_logical_shape_from_physical_shape``.
            packed_sizes: Optional Python-side copy of the per-element packed sizes.
            element_shapes: Optional Python-side copy of the per-element shapes.

        Returns:
            A new wrapper-subclass instance whose metadata passed ``_validate_packed_metadata``.

        Raises:
            ValueError: If ``offsets`` or ``shape_tensor`` is not on CPU.
        """
        # offsets and shape_tensor MUST live on CPU to avoid implicit CUDA syncs
        # when handlers call .item() / .tolist() on them.
        if offsets.device.type != "cpu":
            raise ValueError(f"offsets must be on CPU, got {offsets.device}")
        if shape_tensor.device.type != "cpu":
            raise ValueError(f"shape_tensor must be on CPU, got {shape_tensor.device}")

        if outer_size is not None:
            logical_shape = torch.Size(outer_size)
        else:
            logical_shape = cls._logical_shape_from_physical_shape(shape_tensor, offsets, batch_first)
        # Python-side metadata is only materialized from real tensors; for fake tensors it stays None
        # unless the caller supplied it.
        if packed_sizes is None and not _is_fake_tensor(offsets):
            packed_sizes = tuple(int(size) for size in (offsets[1:] - offsets[:-1]).tolist())
        if element_shapes is None and not _is_fake_tensor(shape_tensor):
            element_shapes = tuple(cls._trim_shape(shape) for shape in shape_tensor.tolist())

        # Fake values paired with real metadata: convert the metadata tensors through the same fake mode.
        # This happens after the Python-side metadata above was read from the real tensors.
        if _is_fake_tensor(values) and not (_is_fake_tensor(offsets) and _is_fake_tensor(shape_tensor)):
            from torch._subclasses.fake_tensor import maybe_get_fake_mode

            fake_mode = maybe_get_fake_mode(values)
            if fake_mode is not None:
                if not _is_fake_tensor(offsets):
                    offsets = fake_mode.from_tensor(offsets, static_shapes=True, trace=False)
                if not _is_fake_tensor(shape_tensor):
                    shape_tensor = fake_mode.from_tensor(shape_tensor, static_shapes=True, trace=False)

        values = cls._maybe_pin_values(values, pin_memory)
        # The wrapper carries only outer metadata; the packed tensors are attached as attributes below.
        result = torch.Tensor._make_wrapper_subclass(
            cls,
            logical_shape,
            dtype=values.dtype,
            device=values.device,
            requires_grad=values.requires_grad,
        )
        result._values = values
        result._offsets = offsets
        result._permutation = (
            tuple(int(dim) for dim in permutation)
            if permutation is not None
            else cls._permutation_from_physical_shape(shape_tensor, element_shapes)
        )
        result._physical_shape = shape_tensor
        result._logical_shape = logical_shape
        result._set_runtime_config(
            batch_first=batch_first,
            padding_value=padding_value,
            mask_value=mask_value,
        )
        # Record pinning as it actually is, not merely as requested.
        result._pin_memory = bool(pin_memory and values.device.type == "cpu" and values.is_pinned())
        result._packed_sizes = packed_sizes
        result._element_shapes = element_shapes
        # Initializes the cache slots to None on the fresh instance.
        result._invalidate_transient_caches()
        cls._validate_packed_metadata(
            result._values,
            result._offsets,
            result._physical_shape,
            permutation=result._permutation,
            logical_shape=result._logical_shape,
            batch_first=result.batch_first,
            packed_sizes=result._packed_sizes,
            element_shapes=result._element_shapes,
        )
        return result

    # ------------------------------------------------------------------
    # torch.compile support
    # ------------------------------------------------------------------

    def __tensor_flatten__(self):
        # During tracing, wrapper instances can be inspected while being built.
        # Only expose tensor attrs that already exist so Dynamo/FakeTensor can
        # inspect partially constructed wrapper subclasses safely.
        instance_attrs = vars(self)
        inner_tensors = [name for name in ("_values", "_offsets", "_physical_shape") if name in instance_attrs]
        if not inner_tensors:
            inner_tensors = ["_flatten_sentinel"]
        return inner_tensors, {
            "batch_first": getattr(self, "batch_first", True),
            "padding_value": getattr(self, "padding_value", 0.0),
            "mask_value": getattr(self, "mask_value", False),
            "pin_memory": getattr(self, "_pin_memory", False),
            "packed_sizes": getattr(self, "_packed_sizes", ()),
            "element_shapes": getattr(self, "_element_shapes", ()),
            "permutation": getattr(self, "_permutation", ()),
        }

    @classmethod
    def __tensor_unflatten__(cls, inner_tensors, ctx, outer_size, outer_stride):
        r"""Rebuild a NestedTensor from the pieces produced by ``__tensor_flatten__``.

        Args:
            inner_tensors: Mapping from attribute name to tensor. ``_values`` is looked up first,
                then the ``_flatten_sentinel`` placeholder.
            ctx: Context dict with the non-tensor metadata emitted by ``__tensor_flatten__``.
            outer_size: Outer (logical) size requested for the rebuilt wrapper.
            outer_stride: Accepted for protocol compatibility; not used by this implementation.

        Returns:
            A fully validated instance via ``_from_packed`` when both ``_offsets`` and
            ``_physical_shape`` are available, otherwise a partially populated wrapper that
            carries whatever tensors were provided.

        Raises:
            RuntimeError: If neither ``_values`` nor the sentinel tensor is present.
        """
        values = inner_tensors.get("_values", inner_tensors.get("_flatten_sentinel"))
        if values is None:
            raise RuntimeError("NestedTensor requires _values during tensor unflatten.")

        offsets = inner_tensors.get("_offsets")
        shape_tensor = inner_tensors.get("_physical_shape")
        if offsets is not None and shape_tensor is not None:
            # During backward, outer_size may reflect a transposed view
            # (e.g., (seq, batch, hidden) from MHA's batch-dim transpose).
            # Detect and correct so _from_packed validation passes.
            batch_size = len(offsets) - 1
            outer = tuple(outer_size)
            batch_first = ctx.get("batch_first", True)
            # Swap the two leading dims only when the batch size sits in the "other" slot
            # and not in the slot implied by batch_first.
            if len(outer) >= 2 and (
                (batch_first and outer[0] != batch_size and outer[1] == batch_size)
                or (not batch_first and outer[1] != batch_size and outer[0] == batch_size)
            ):
                outer = (outer[1], outer[0], *outer[2:])
            # ctx keys are forwarded as keyword arguments, so they must match _from_packed's parameters.
            return cls._from_packed(
                values,
                offsets,
                shape_tensor,
                outer_size=outer,
                **ctx,
            )

        # Partial reconstruction: some metadata tensors are missing, so build the wrapper by hand
        # and skip _from_packed (and therefore its validation).
        result = torch.Tensor._make_wrapper_subclass(
            cls,
            torch.Size(outer_size),
            dtype=values.dtype,
            device=values.device,
            requires_grad=values.requires_grad,
        )
        result._values = values
        if offsets is not None:
            result._offsets = offsets
        if shape_tensor is not None:
            result._physical_shape = shape_tensor
        result._logical_shape = torch.Size(outer_size)
        # Unlike the branch above, this path indexes ctx directly and so requires every key to be present.
        result._set_runtime_config(
            batch_first=ctx["batch_first"],
            padding_value=ctx["padding_value"],
            mask_value=ctx["mask_value"],
        )
        result._pin_memory = ctx["pin_memory"]
        result._packed_sizes = ctx["packed_sizes"]
        result._element_shapes = ctx["element_shapes"]
        result._permutation = tuple(int(dim) for dim in ctx["permutation"])
        result._invalidate_transient_caches()
        return result

    # ------------------------------------------------------------------
    # Dispatch
    # ------------------------------------------------------------------

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None) -> Any:
        if kwargs is None:
            kwargs = {}

        # Handle size() specially to avoid infinite recursion
        if func is torch.Tensor.size:
            self = args[0]
            dim = args[1] if len(args) > 1 else kwargs.get("dim")
            return self.size(dim)

        from .ops import NestedTensorFuncRegistry, _compile_unsupported, _is_compiling

        handler = NestedTensorFuncRegistry.get(func)
        if handler is not None:
            if _is_compiling() and not NestedTensorFuncRegistry.is_compile_safe(func, args, kwargs):
                name = getattr(func, "__qualname__", getattr(func, "__name__", repr(func)))
                _compile_unsupported(name, "handler is marked eager-only")
            return handler(*args, **kwargs)

        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None) -> Any:
        r"""Route aten-level calls to a registered handler or to the per-element fallback.

        While compiling, two situations are reported through ``_compile_unsupported``:
        a registered handler that is not compile-safe for this call, and any op that
        has no registered handler (which would otherwise use the per-element fallback).
        """
        kwargs = {} if kwargs is None else kwargs

        from .ops import _compile_unsupported, _is_compiling

        registered = func in NestedTensorAtenRegistry
        reason = None
        if registered:
            if _is_compiling() and not NestedTensorAtenRegistry.is_compile_safe(func, args, kwargs):
                reason = "aten handler is marked eager-only"
        elif _is_compiling():
            reason = "would fall back to per-element eager execution"

        if reason is not None:
            # ``name`` may be a plain attribute or a callable returning the name.
            label = getattr(func, "name", None)
            if callable(label):
                label = label()
            _compile_unsupported(label or repr(func), reason)

        if registered:
            return NestedTensorAtenRegistry[func](func, args, kwargs)
        return per_element_fallback(func, args, kwargs)

    # ------------------------------------------------------------------
    # Layout & Metadata Helpers
    # ------------------------------------------------------------------

    def _unpack(self) -> tuple[Tensor, ...]:
        r"""Reconstruct individual tensors from packed representation."""
        batch_size = len(self._offsets) - 1
        if batch_size == 0:
            return ()

        packed_sizes = self._packed_sizes
        if packed_sizes is None:
            if _is_fake_tensor(self._offsets):
                raise RuntimeError("NestedTensor packed sizes are unavailable for this instance.")
            packed_sizes = tuple(int(size) for size in (self._offsets[1:] - self._offsets[:-1]).tolist())

        element_shapes = self._element_shapes
        if element_shapes is None:
            element_shapes = tuple(tuple(int(dim) for dim in shape) for shape in self._original_shapes())

        splits = self._values.split(packed_sizes, dim=0)
        permutation = self._permutation
        if permutation:
            varying_dims = self._varying_dims
            static_dims = self._static_dims
        else:
            varying_dims, static_dims = type(self)._pack_layout_meta(self._physical_shape, element_shapes)
            permutation = varying_dims + static_dims
        inverse_permutation = type(self)._inverse_permutation(permutation)

        result = []
        for chunk, shape in zip(splits, element_shapes):
            if not shape:
                result.append(chunk[0])
            else:
                packed_shape = tuple(shape[dim] for dim in varying_dims) + tuple(shape[dim] for dim in static_dims)
                unpacked = chunk.reshape(packed_shape)
                if permutation != tuple(range(len(shape))):
                    unpacked = unpacked.permute(inverse_permutation)
                result.append(unpacked)
        return tuple(result)

    def _repack(self, tensors: Sequence) -> None:
        r"""
        Re-pack from already-validated tensors. Skips coercion — callers must ensure
        tensors share device, dtype, and ndim (which is always true for internal paths
        since tensors originate from _unpack or __setitem__ validation)."""
        self._invalidate_transient_caches()
        tensors = tuple(tensors) if not isinstance(tensors, tuple) else tensors
        if tensors and len(self._permutation) != tensors[0].ndim:
            raise RuntimeError(
                "NestedTensor._repack received tensors with rank "
                f"{tensors[0].ndim} but current permutation has rank {len(self._permutation)}"
            )
        values, offsets, shape_tensor, packed_sizes, element_shapes = self._pack(
            tensors,
            permutation=self._permutation if tensors else None,
        )
        values = type(self)._maybe_pin_values(values, self._pin_memory)
        self._values = values
        self._offsets = offsets
        self._physical_shape = shape_tensor
        self._logical_shape = self._compute_logical_shape(tensors, self.batch_first)
        self._packed_sizes = packed_sizes
        self._element_shapes = element_shapes
        self._validate_metadata()

    @property
    def _hierarchical_offsets(self) -> tuple[Tensor, ...]:
        if self._cached_hierarchical_offsets is None:
            level_sizes = type(self)._hierarchical_level_sizes_from_physical_shape(
                self._physical_shape,
                self._element_shapes,
            )
            if not level_sizes:
                if self._element_shapes is None and self._packed_sizes is not None:
                    self._cached_hierarchical_offsets = (
                        type(self)._offsets_from_sizes(self._packed_sizes, dtype=self._offsets.dtype),
                    )
                elif self._element_shapes is None and _is_fake_tensor(self._physical_shape):
                    self._cached_hierarchical_offsets = (self._offsets,)
                else:
                    self._cached_hierarchical_offsets = ()
            else:
                self._cached_hierarchical_offsets = tuple(
                    type(self)._offsets_from_sizes(level_sizes[level], dtype=self._offsets.dtype)
                    for level in range(len(level_sizes))
                )
        return self._cached_hierarchical_offsets

    @property
    def _ragged_rank(self) -> int:
        return len(self._hierarchical_offsets)

    def _ragged_level_offsets(self, level: int = -1) -> Tensor:
        offsets = self._hierarchical_offsets
        if not offsets:
            return self._offsets
        return offsets[level]

    def _ragged_level_sizes(self, level: int = -1) -> Tensor:
        offsets = self._ragged_level_offsets(level)
        return offsets[1:] - offsets[:-1]

    @property
    def _varying_dims(self) -> tuple[int, ...]:
        ragged_rank = self._ragged_rank
        if ragged_rank <= 0:
            return ()
        if self._permutation:
            return tuple(int(dim) for dim in self._permutation[:ragged_rank])
        varying_dims, _ = type(self)._pack_layout_meta(self._physical_shape, self._element_shapes)
        return varying_dims

    @property
    def _static_dims(self) -> tuple[int, ...]:
        ragged_rank = self._ragged_rank
        if self._permutation:
            return tuple(int(dim) for dim in self._permutation[ragged_rank:])
        _, static_dims = type(self)._pack_layout_meta(self._physical_shape, self._element_shapes)
        return static_dims

    def _has_same_structure(self, other: Self) -> bool:
        if self.batch_first != other.batch_first or self._permutation != other._permutation:
            return False
        if self._element_shapes is not None and other._element_shapes is not None:
            lhs_levels = type(self)._hierarchical_level_sizes_from_element_shapes(self._element_shapes)
            rhs_levels = type(self)._hierarchical_level_sizes_from_element_shapes(other._element_shapes)
            if lhs_levels or rhs_levels:
                return lhs_levels == rhs_levels
            return len(self) == len(other)
        lhs_offsets = self._hierarchical_offsets
        rhs_offsets = other._hierarchical_offsets
        if len(lhs_offsets) != len(rhs_offsets):
            return False
        if lhs_offsets:
            return all(type(self)._meta_tensor_equal(lhs, rhs) for lhs, rhs in zip(lhs_offsets, rhs_offsets))
        return type(self)._meta_tensor_equal(self._offsets, other._offsets)

    def _has_same_layout(self, other: Self) -> bool:
        if not self._has_same_structure(other):
            return False
        if self._element_shapes is not None and other._element_shapes is not None:
            if self._element_shapes != other._element_shapes:
                return False
            if self._packed_sizes is not None and other._packed_sizes is not None:
                return self._packed_sizes == other._packed_sizes
            return True
        if (
            self._packed_sizes is not None
            and other._packed_sizes is not None
            and self._packed_sizes != other._packed_sizes
        ):
            return False
        if not type(self)._meta_tensor_equal(self._physical_shape, other._physical_shape):
            return False
        return type(self)._meta_tensor_equal(self._offsets, other._offsets)

    def _packed_flat_index(
        self,
        *,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.long,
    ) -> Tensor:
        target_device = self.device if device is None else device
        leading = self._values.size(0) if self._values.dim() > 0 else self._values.numel()
        return torch.arange(leading, device=target_device, dtype=dtype)

    def _packed_batch_local_indices(
        self,
        flat_idx: Tensor | None = None,
        *,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.long,
    ) -> tuple[Tensor, Tensor]:
        target_device = self.device if device is None else device
        if flat_idx is None:
            flat_idx = self._packed_flat_index(device=target_device, dtype=dtype)
        offsets = self._offsets.to(device=target_device, dtype=dtype)
        batch_idx = torch.searchsorted(offsets[1:], flat_idx, right=True)
        local_idx = flat_idx - offsets[batch_idx]
        return batch_idx, local_idx

    def _packed_varying_coords(
        self,
        batch_idx: Tensor,
        local_idx: Tensor,
        *,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.long,
    ) -> tuple[Tensor, ...]:
        target_device = self.device if device is None else device
        varying_dims = self._varying_dims
        if not varying_dims:
            return ()

        varying_sizes = self._physical_shape[:, list(varying_dims)].to(device=target_device, dtype=dtype)[batch_idx]
        strides = torch.ones_like(varying_sizes)
        running = torch.ones(varying_sizes.size(0), dtype=dtype, device=target_device)
        for dim in range(varying_sizes.size(1) - 1, -1, -1):
            strides[:, dim] = running
            running = running * varying_sizes[:, dim]

        coords: list[Tensor] = []
        remainder = local_idx
        for dim in range(varying_sizes.size(1)):
            coord = remainder // strides[:, dim]
            coords.append(coord)
            remainder = remainder - coord * strides[:, dim]
        return tuple(coords)

    def _packed_dense_index(
        self,
        flat_idx: Tensor | None = None,
        *,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.long,
    ) -> tuple[Tensor | slice, ...]:
        target_device = self.device if device is None else device
        batch_idx, local_idx = self._packed_batch_local_indices(flat_idx, device=target_device, dtype=dtype)
        varying_dims = self._varying_dims
        coords = self._packed_varying_coords(batch_idx, local_idx, device=target_device, dtype=dtype)

        dense_index: list[Tensor | slice] = [batch_idx]
        coord_iter = iter(coords)
        for dim in range(self._physical_shape.size(1)):
            dense_index.append(next(coord_iter) if dim in varying_dims else slice(None))
        return tuple(dense_index)

    def _physical_shape_like_batch_dense(self, batch_dense_shape: Sequence[int]) -> Tensor:
        r"""Return per-element shapes for a batch-leading dense tensor with this NestedTensor's ragged structure."""
        expected_ndim = self._physical_shape.size(1) + 1
        if len(batch_dense_shape) != expected_ndim:
            raise ValueError(
                "Batch-leading dense tensor rank does not match NestedTensor layout, "
                f"expected {expected_ndim}, got {len(batch_dense_shape)}"
            )
        shape, _, _ = self._shape_meta_from_components(
            replace_dims={int(dim): int(batch_dense_shape[dim + 1]) for dim in self._static_dims}
        )
        return shape

    def _element_shapes_like_batch_dense(
        self,
        batch_dense_shape: Sequence[int],
    ) -> tuple[tuple[int, ...], ...] | None:
        r"""Return Python element-shape metadata for a batch-leading dense tensor with this NestedTensor's layout."""
        expected_ndim = self._physical_shape.size(1) + 1
        if len(batch_dense_shape) != expected_ndim:
            raise ValueError(
                "Batch-leading dense tensor rank does not match NestedTensor layout, "
                f"expected {expected_ndim}, got {len(batch_dense_shape)}"
            )
        _, _, element_shapes = self._shape_meta_from_components(
            replace_dims={int(dim): int(batch_dense_shape[dim + 1]) for dim in self._static_dims}
        )
        return element_shapes

    def _shape_meta_from_components(
        self,
        *,
        prefix: Sequence[int] = (),
        keep_dims: Sequence[int] | None = None,
        suffix: Sequence[int] = (),
        replace_dims: Mapping[int, int] | None = None,
    ) -> tuple[Tensor, tuple[int, ...] | None, tuple[tuple[int, ...], ...] | None]:
        r"""Build packed shape metadata by keeping selected dims and applying constant prefix/suffix updates."""
        if keep_dims is None:
            keep_dims = tuple(range(self._physical_shape.size(1)))
        keep_dims = tuple(int(dim) for dim in keep_dims)
        prefix = tuple(int(size) for size in prefix)
        suffix = tuple(int(size) for size in suffix)
        updates = {int(dim): int(size) for dim, size in (replace_dims or {}).items()}

        if self._element_shapes:
            element_shapes_list: list[tuple[int, ...]] = []
            for element_shape in self._element_shapes:
                projected = [*prefix, *(int(element_shape[dim]) for dim in keep_dims), *suffix]
                for dim, size in updates.items():
                    projected[dim] = size
                element_shapes_list.append(tuple(projected))
            element_shapes = tuple(element_shapes_list)
            max_ndim = max(len(shape) for shape in element_shapes)
            shape = torch.tensor(
                [list(shape) + [0] * (max_ndim - len(shape)) for shape in element_shapes],
                dtype=torch.long,
            )
            return shape, self._packed_sizes_like(element_shapes), element_shapes

        parts: list[Tensor] = []
        batch_size = len(self)
        if prefix:
            parts.append(self._physical_shape.new_tensor(prefix).reshape(1, -1).expand(batch_size, -1))
        if keep_dims:
            parts.append(self._physical_shape[:, list(keep_dims)].clone())
        if suffix:
            parts.append(self._physical_shape.new_tensor(suffix).reshape(1, -1).expand(batch_size, -1))
        if parts:
            shape = torch.cat(parts, dim=1)
        else:
            shape = self._physical_shape.new_empty((batch_size, 0))
        for dim, size in updates.items():
            shape[:, dim] = size
        return shape, None, None

    def _max_physical_dims(self) -> tuple[int, ...]:
        r"""Return the maximum per-element size for each physical dimension (excluding batch)."""
        batch_dim = 0 if self.batch_first else 1
        return tuple(int(size) for index, size in enumerate(self._logical_shape) if index != batch_dim)

    def _logical_shape_from_physical_dims(self, physical_dims: Sequence[int]) -> torch.Size:
        r"""Build a logical outer shape from non-batch physical-dimension sizes."""
        physical_dims = tuple(int(size) for size in physical_dims)
        batch_size = len(self)
        if self.batch_first:
            return torch.Size((batch_size, *physical_dims))
        if not physical_dims:
            return torch.Size((batch_size,))
        return torch.Size((physical_dims[0], batch_size, *physical_dims[1:]))

    def _logical_shape_from_components(
        self,
        *,
        prefix: Sequence[int] = (),
        keep_dims: Sequence[int] | None = None,
        suffix: Sequence[int] = (),
        replace_dims: Mapping[int, int] | None = None,
    ) -> torch.Size:
        r"""Build a logical outer shape by projecting the current physical-dimension extents."""
        physical_dims = list(self._max_physical_dims())
        if keep_dims is None:
            keep_dims = tuple(range(len(physical_dims)))
        projected = [*(int(prefix_dim) for prefix_dim in prefix), *(physical_dims[int(dim)] for dim in keep_dims)]
        projected.extend(int(suffix_dim) for suffix_dim in suffix)
        for dim, size in (replace_dims or {}).items():
            projected[int(dim)] = int(size)
        return self._logical_shape_from_physical_dims(projected)

    def _leading_dim_preserving_meta(
        self,
        suffix: Sequence[int],
    ) -> tuple[Tensor, torch.Size, tuple[int, ...] | None, tuple[tuple[int, ...], ...] | None]:
        r"""Build metadata for ops that preserve the first per-element dim and replace all trailing dims uniformly."""
        keep_dims = (0,) if self._physical_shape.size(1) > 0 else ()
        shape, packed_sizes, element_shapes = self._shape_meta_from_components(keep_dims=keep_dims, suffix=suffix)
        return shape, self._leading_dim_preserving_outer_size(suffix), packed_sizes, element_shapes

    def _leading_dim_preserving_outer_size(self, suffix: Sequence[int]) -> torch.Size:
        r"""Return logical outer size for ops that preserve per-element dim-0 and replace trailing dims uniformly."""
        suffix = tuple(int(size) for size in suffix)
        batch_size = len(self)
        batch_dim = 0 if self.batch_first else 1
        logical = list(self._logical_shape)
        non_batch = [int(logical[index]) for index in range(len(logical)) if index != batch_dim]

        new_non_batch: list[int] = []
        if self._physical_shape.size(1) > 0 and non_batch:
            new_non_batch.append(non_batch[0])
        new_non_batch.extend(suffix)

        if self.batch_first:
            return torch.Size((batch_size, *new_non_batch))
        if not new_non_batch:
            return torch.Size((batch_size,))
        return torch.Size((new_non_batch[0], batch_size, *new_non_batch[1:]))

    def _drop_trailing_physical_dims_meta(
        self,
        count: int,
        *,
        suffix: Sequence[int] = (),
    ) -> tuple[Tensor, tuple[int, ...] | None, tuple[tuple[int, ...], ...] | None]:
        r"""Build metadata after dropping trailing per-element dims and optionally appending a dense suffix."""
        keep_dims = tuple(range(max(self._physical_shape.size(1) - int(count), 0)))
        return self._shape_meta_from_components(keep_dims=keep_dims, suffix=suffix)

    def _replace_trailing_physical_dims_meta(
        self,
        trailing_sizes: Sequence[int],
    ) -> tuple[Tensor, tuple[int, ...] | None, tuple[tuple[int, ...], ...] | None]:
        r"""Build metadata after replacing the last physical dims with uniform sizes."""
        trailing_sizes = tuple(int(size) for size in trailing_sizes)
        if not trailing_sizes:
            return self._shape_meta_from_components()
        ndim = self._physical_shape.size(1)
        if len(trailing_sizes) > ndim:
            raise ValueError(f"Cannot replace {len(trailing_sizes)} trailing dims for per-element rank {ndim}")
        start = ndim - len(trailing_sizes)
        return self._shape_meta_from_components(
            replace_dims={start + index: size for index, size in enumerate(trailing_sizes)}
        )

    def _permutation_after_dropping_trailing_dims(self, count: int) -> tuple[int, ...]:
        r"""Return the canonical permutation after dropping trailing physical dims."""
        count = int(count)
        new_rank = max(self._physical_shape.size(1) - count, 0)
        if not self._permutation:
            return tuple(range(new_rank))
        return tuple(int(dim) for dim in self._permutation if dim < new_rank)

    def _permutation_after_replacing_trailing_dims(self, removed_count: int, added_count: int) -> tuple[int, ...]:
        r"""Return the canonical permutation after replacing trailing physical dims with a new suffix."""
        removed_count = int(removed_count)
        added_count = int(added_count)
        retained_rank = max(self._physical_shape.size(1) - removed_count, 0)
        retained = self._permutation_after_dropping_trailing_dims(removed_count)
        appended = tuple(range(retained_rank, retained_rank + added_count))
        return retained + appended

    def _scalar_result_meta(
        self,
    ) -> tuple[Tensor, Tensor, torch.Size, tuple[int, ...] | None, tuple[tuple[int, ...], ...] | None]:
        r"""Build metadata for one-scalar-per-element outputs."""
        shape, packed_sizes, element_shapes = self._shape_meta_from_components(keep_dims=())
        offsets = torch.arange(len(self) + 1, dtype=self._offsets.dtype, device=self._offsets.device)
        logical_shape = type(self)._logical_shape_from_physical_shape(shape, self._offsets, self.batch_first)
        return offsets, shape, logical_shape, packed_sizes, element_shapes

    def _from_scalar_result_values(self, values: Tensor) -> Self:
        r"""Wrap one scalar per element using the canonical scalar-result metadata."""
        cls = type(self)
        offsets, shape, outer_size, packed_sizes, element_shapes = self._scalar_result_meta()
        return cls._from_packed(
            values,
            offsets,
            shape,
            batch_first=self.batch_first,
            padding_value=self.padding_value,
            mask_value=self.mask_value,
            pin_memory=self._pin_memory,
            outer_size=outer_size,
            packed_sizes=packed_sizes,
            element_shapes=element_shapes,
        )

    @classmethod
    def _cat_batch_packed(cls, tensors: Sequence[Self]) -> Self | None:
        r"""Merge batch-dim concatenation directly from packed storage when layouts are compatible."""
        if not tensors:
            raise ValueError("Expected at least one NestedTensor to concatenate.")

        ref = tensors[0]
        packed_rank = ref._values.dim()
        packed_tail = ref._values.shape[1:]
        reference_permutation = ref._permutation
        for tensor in tensors[1:]:
            if tensor._values.dim() != packed_rank:
                return None
            if tensor._permutation != reference_permutation:
                return None
            if packed_rank > 1 and tensor._values.shape[1:] != packed_tail:
                return None

        new_values = torch.cat([tensor._values for tensor in tensors], dim=0)

        offset_parts = []
        cumulative = 0
        for index, tensor in enumerate(tensors):
            offsets = tensor._offsets if index == 0 else tensor._offsets[1:] + cumulative
            offset_parts.append(offsets)
            cumulative += int(tensor._offsets[-1].item())
        new_offsets = torch.cat(offset_parts, dim=0)

        max_cols = max(tensor._physical_shape.size(1) for tensor in tensors)
        if max_cols > 0:
            padded_shapes = []
            for tensor in tensors:
                physical_shape = tensor._physical_shape
                if physical_shape.size(1) < max_cols:
                    physical_shape = torch.nn.functional.pad(physical_shape, (0, max_cols - physical_shape.size(1)))
                padded_shapes.append(physical_shape)
            new_physical_shape = torch.cat(padded_shapes, dim=0)
        else:
            new_physical_shape = torch.empty(len(new_offsets) - 1, 0, dtype=torch.long)

        batch_dim = 0 if ref.batch_first else 1
        out_logical = list(ref._logical_shape)
        if len(out_logical) <= batch_dim:
            out_logical.extend(0 for _ in range(batch_dim + 1 - len(out_logical)))
        out_logical[batch_dim] = sum(len(tensor) for tensor in tensors)
        for logical_dim in range(len(out_logical)):
            if logical_dim == batch_dim:
                continue
            out_logical[logical_dim] = max(
                int(tensor._logical_shape[logical_dim]) if logical_dim < len(tensor._logical_shape) else 0
                for tensor in tensors
            )

        packed_sizes = None
        if all(tensor._packed_sizes is not None for tensor in tensors):
            packed_sizes = tuple(size for tensor in tensors for size in cast(tuple[int, ...], tensor._packed_sizes))
        element_shapes = None
        if all(tensor._element_shapes is not None for tensor in tensors):
            element_shapes = tuple(
                shape for tensor in tensors for shape in cast(tuple[tuple[int, ...], ...], tensor._element_shapes)
            )

        return cls._from_packed(
            new_values,
            new_offsets,
            new_physical_shape,
            permutation=reference_permutation,
            batch_first=ref.batch_first,
            padding_value=ref.padding_value,
            mask_value=ref.mask_value,
            pin_memory=ref._pin_memory,
            outer_size=tuple(out_logical),
            packed_sizes=packed_sizes,
            element_shapes=element_shapes,
        )

    @property
    def _storage(self) -> tuple[Tensor, ...]:
        if self._cached_storage is None:
            self._cached_storage = self._unpack()
        return self._cached_storage

    @_storage.setter
    def _storage(self, tensors: Sequence) -> None:
        self._repack(tensors)

    # ------------------------------------------------------------------
    # Cached materialized views
    # ------------------------------------------------------------------

    def _tensor_cached_view(self) -> Tensor:
        cached = self._cached_tensor_view
        token = self._values_cache_token()
        if (
            cached is not None
            and cached[0] is self.batch_first
            and cached[1] == self.padding_value
            and cached[2] == token
        ):
            return cached[3]
        batch_leading = self._materialize_batch_leading(self.padding_value)
        tensor = batch_leading if self.batch_first else batch_leading.movedim(0, 1)
        self._cached_tensor_view = (self.batch_first, self.padding_value, token, tensor)
        return tensor

    def _mask_cached_view(self) -> Tensor:
        cached = self._cached_mask_view
        token = self._shape_cache_token()
        if cached is not None and cached[0] is self.batch_first and cached[1] is self.mask_value and cached[2] == token:
            return cached[3]
        mask = self._materialize_mask()
        self._cached_mask_view = (self.batch_first, self.mask_value, token, mask)
        return mask

    @property
    def tensor_mask(self) -> tuple[Tensor, Tensor]:
        r"""
        Padded tensor together with its padding mask, as a `(tensor, mask)` pair.

        Both halves come from the cached views, so repeated access does not
        rebuild them while the cache keys are unchanged.

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
            >>> nested_tensor.tensor_mask
            (tensor([[1, 2, 3],
                    [4, 5, 0]]), tensor([[ True,  True,  True],
                    [ True,  True, False]]))
        """
        padded = self._tensor_cached_view()
        padding_mask = self._mask_cached_view()
        return padded, padding_mask

    @property
    def tensor(self) -> Tensor:
        r"""
        Dense tensor obtained by padding every element to a common shape.

        The result is served from the cached padded view.

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
            >>> nested_tensor.tensor
            tensor([[1, 2, 3],
                    [4, 5, 0]])
        """
        padded = self._tensor_cached_view()
        return padded

    @property
    def mask(self) -> Tensor:
        r"""
        Boolean padding mask that accompanies `tensor`.

        Padding positions hold `mask_value`; valid positions hold its negation.
        With the default `mask_value=False`, `True` therefore marks valid data.

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
            >>> nested_tensor.mask
            tensor([[ True,  True,  True],
                    [ True,  True, False]])
        """
        padding_mask = self._mask_cached_view()
        return padding_mask

    def _mask_squeezes_channel(self) -> bool:
        return self._physical_shape.size(1) > 1 and (self._physical_shape.size(1) - 1) in self._static_dims

    def _materialize_mask(self) -> Tensor:
        batch_size = len(self)
        logical_shape = self._logical_shape
        squeeze_channel = self._mask_squeezes_channel()

        if batch_size == 0:
            mask_shape = logical_shape[:-1] if squeeze_channel else logical_shape
            return torch.empty(mask_shape, dtype=torch.bool, device=self.device)

        if self._physical_shape.size(1) == 0:
            return torch.full((batch_size,), not self.mask_value, dtype=torch.bool, device=self.device)

        effective_shape = logical_shape[:-1] if squeeze_channel else logical_shape
        batch_dim = 0 if self.batch_first else 1
        non_batch_sizes = [effective_shape[i] for i in range(len(effective_shape)) if i != batch_dim]

        sizes = self._physical_shape[:, :-1] if squeeze_channel else self._physical_shape
        sizes = sizes.to(device=self.device, dtype=torch.long)

        valid = _batch_leading_valid_mask_from_sizes(
            sizes,
            non_batch_sizes,
            device=self.device,
        )

        if not self.batch_first:
            valid = valid.movedim(0, 1)
        return valid if not self.mask_value else ~valid

    def _materialize_batch_leading(self, fill_value) -> Tensor:
        r"""
        Build the padded dense tensor with the batch dimension first.

        Args:
            fill_value: Value written to every position that no element covers.

        Returns:
            A dense tensor with the dtype of the packed values. The batch dimension
            is always dim 0 here, whatever `batch_first` says.
        """
        _check_execution_guard(_ExecutionGuardKind.PADDED_MATERIALIZATION, "NestedTensor._materialize_batch_leading")
        values = self._values
        logical_shape = self._logical_shape
        batch_size = len(self)

        if batch_size == 0:
            if self.batch_first:
                empty_shape = logical_shape
            elif len(logical_shape) <= 1:
                empty_shape = (0,)
            else:
                # Drop the batch entry (index 1) and lead with an empty batch instead.
                empty_shape = (0, *logical_shape[:1], *logical_shape[2:])
            return torch.empty(empty_shape, dtype=values.dtype, device=self.device)

        if self._physical_shape.size(1) == 0:
            return values.reshape((batch_size,))

        element_extent = list(logical_shape)
        element_extent.pop(0 if self.batch_first else 1)
        dense = values.new_full((batch_size, *element_extent), fill_value)
        if values.size(0) > 0:
            # Scatter the packed rows into their padded positions.
            dense[self._packed_dense_index(device=dense.device)] = values
        return dense

    def _original_shapes(self) -> tuple[torch.Size, ...]:
        if self._element_shapes is not None:
            return tuple(torch.Size(shape) for shape in self._element_shapes)
        if not _is_fake_tensor(self._physical_shape):
            return tuple(torch.Size(type(self)._trim_shape(row)) for row in self._physical_shape.tolist())
        raise RuntimeError("NestedTensor shape metadata is unavailable for this instance.")

    @property
    def concat(self) -> Tensor:
        r"""
        Packed values of all elements, concatenated along the ragged dimension without padding.

        Handy for computing a loss or feeding a `Linear` layer, because no work is
        spent on padding positions. An empty batch yields an empty 1-D tensor.

        Examples:
            >>> nested_tensor = NestedTensor([torch.randn(9, 8), torch.randn(11, 8)])
            >>> nested_tensor.concat.shape
            torch.Size([20, 8])
            >>> nested_tensor = NestedTensor([torch.randn(9, 9, 8), torch.randn(11, 11, 8)])
            >>> nested_tensor.concat.shape
            torch.Size([202, 8])
            >>> nested_tensor = NestedTensor([torch.randn(9, 9, 8, 6), torch.randn(11, 11, 8, 6)])
            >>> nested_tensor.concat.shape
            torch.Size([202, 8, 6])
            >>> nested_tensor = NestedTensor([torch.randn(9, 9, 8, 7), torch.randn(11, 11, 8, 6)])
            >>> nested_tensor.concat.shape
            torch.Size([1293, 8])
            >>> nested_tensor = NestedTensor([torch.randn(1, 9, 9, 5), torch.randn(1, 11, 11, 5)])
            >>> nested_tensor.concat.shape
            torch.Size([202, 1, 5])
        """
        has_elements = len(self._offsets) > 1
        if has_elements:
            return self._values
        return torch.empty(0, dtype=self._values.dtype, device=self.device)

    def concatenate(self) -> tuple[Tensor, tuple[torch.Size, ...]]:
        r"""
        Return the unpadded concatenation plus the shapes needed to undo it.

        Returns:
            A `(concat_tensor, shapes)` pair:
            - concat_tensor: the same data the `.concat` property exposes
            - shapes: the original shape of every element, in batch order

            An empty batch yields an empty 1-D tensor and an empty tuple.

        Examples:
            >>> nested_tensor = NestedTensor([torch.randn(9, 8), torch.randn(11, 8)])
            >>> concat_tensor, shapes = nested_tensor.concatenate()
            >>> concat_tensor.shape
            torch.Size([20, 8])
            >>> shapes
            (torch.Size([9, 8]), torch.Size([11, 8]))
            >>> reconstructed = NestedTensor.from_concatenated(concat_tensor, shapes)
            >>> torch.equal(nested_tensor.tensor, reconstructed.tensor)
            True
        """
        if len(self._offsets) == 1:
            # A single offset means the batch holds no elements.
            placeholder = torch.empty(0, dtype=self._values.dtype, device=self.device)
            return placeholder, ()
        return self._values, self._original_shapes()

    # ------------------------------------------------------------------
    # Container protocol
    # ------------------------------------------------------------------

    def __len__(self) -> int:
        r"""Return the number of tensors in the batch."""
        if not hasattr(self, "_offsets"):
            with torch._C.DisableTorchFunctionSubclass():
                full_size = torch.Tensor.size(self)
            if len(full_size) == 0:
                return 0
            batch_dim = 0 if getattr(self, "batch_first", True) else (1 if len(full_size) > 1 else 0)
            return int(full_size[batch_dim])
        return len(self._offsets) - 1

    def __repr__(self):
        r"""Return a human-readable string representation of the NestedTensor."""
        if torch._dynamo.is_compiling():
            try:
                shape = tuple(self.size())
            except Exception:
                shape = "?"
            return (
                f"{self.__class__.__name__}(shape={shape}, dtype={self.dtype}, "
                f"device={self.device}, batch_first={getattr(self, 'batch_first', True)})"
            )

        try:
            from torch._subclasses.fake_tensor import is_fake

            for name in ("_values", "_offsets", "_physical_shape"):
                value = self.__dict__.get(name)
                if isinstance(value, Tensor) and is_fake(value):
                    shape = tuple(self.size())
                    return (
                        f"{self.__class__.__name__}(shape={shape}, dtype={self.dtype}, "
                        f"device={self.device}, batch_first={getattr(self, 'batch_first', True)})"
                    )
        except Exception:
            pass

        if not all(name in self.__dict__ for name in ("_values", "_offsets", "_physical_shape")):
            try:
                shape = tuple(self.size())
            except Exception:
                shape = "?"
            return (
                f"{self.__class__.__name__}(shape={shape}, dtype={self.dtype}, "
                f"device={self.device}, batch_first={getattr(self, 'batch_first', True)})"
            )

        if len(self) == 0:
            return self.__class__.__name__ + "()"

        storage = self._storage
        truncated = len(storage) > 10
        if truncated:
            storage = storage[:5]

        indent = "    "

        # Strip "tensor(" wrapper from each element's repr,
        # keeping PyTorch's internal number formatting (precision, alignment).
        data_parts = []
        for t in storage:
            s = repr(t)
            paren_idx = s.index("(")
            data = s[paren_idx + 1 : -1]  # noqa: E203
            # Re-indent continuation lines for multi-line element reprs (e.g. 2D tensors)
            if "\n" in data:
                lines = data.split("\n")
                data = lines[0] + "\n" + "\n".join(indent + " " + line.lstrip() for line in lines[1:])
            data_parts.append(data)

        result_lines = [self.__class__.__name__ + "(["]
        for i, part in enumerate(data_parts):
            suffix = "," if i < len(data_parts) - 1 or truncated else ""
            result_lines.append(indent + part + suffix)
        if truncated:
            result_lines.append(indent + f"... ({len(self)} tensors)")
        result_lines.append("])")
        return "\n".join(result_lines)

    def __bool__(self) -> bool:
        r"""NestedTensor follows tensor-style truthiness and never acts like a Python container."""
        raise RuntimeError(
            "Boolean value of NestedTensor is ambiguous. Use .numel(), .any(), .all(), or an explicit reduction."
        )

    def __iter__(self):
        r"""Yield the unpacked element tensors in batch order, after the iteration guard check."""
        _check_execution_guard(_ExecutionGuardKind.ITERATION, "NestedTensor.__iter__")
        elements = self._storage
        return iter(elements)

    def __eq__(self, other):  # type: ignore[override]
        r"""Element-wise equality comparison."""
        try:
            return torch.eq(self, other)
        except TypeError:
            return NotImplemented

    def __ne__(self, other):  # type: ignore[override]
        r"""Element-wise inequality comparison."""
        try:
            return torch.ne(self, other)
        except TypeError:
            return NotImplemented

    # Defining __eq__ in a class body makes Python set __hash__ = None for that
    # class, which would leave instances unhashable.
    # Re-bind Tensor's hash explicitly so AOT/torch.compile memoization keeps working.
    __hash__ = Tensor.__hash__

    # Arithmetic, comparison, and in-place operators are handled by the base
    # Tensor class, which routes through C++ → aten → __torch_dispatch__ →
    # aten_functions.py. No Python-level overrides needed.

    # ------------------------------------------------------------------
    # Conversion & Factory Methods
    # ------------------------------------------------------------------

    @classmethod
    def from_concatenated(cls, concat_tensor: Tensor, shapes: tuple[torch.Size, ...], **kwargs) -> Self:
        r"""
        Reconstruct a NestedTensor from a concatenated tensor and shape information.

        Args:
            concat_tensor: The concatenated tensor returned by concatenate()
            shapes: Original element shapes returned by concatenate(). Each entry may be a
                `torch.Size` or any plain sequence of ints (tuple, list).
            **kwargs: Additional arguments to pass to NestedTensor constructor. For empty
                `shapes`, `dtype` and `device` default to those of `concat_tensor`.

        Returns:
            Reconstructed NestedTensor

        Raises:
            ValueError: If the number of elements in `concat_tensor` does not match `shapes`.

        Examples:
            >>> nested_tensor = NestedTensor([torch.randn(9, 9, 8), torch.randn(11, 11, 8)])
            >>> concat_tensor, shapes = nested_tensor.concatenate()
            >>> reconstructed = NestedTensor.from_concatenated(concat_tensor, shapes)
            >>> concat_tensor.shape
            torch.Size([202, 8])
            >>> reconstructed.shape
            torch.Size([2, 11, 11, 8])
            >>> torch.equal(nested_tensor.tensor, reconstructed.tensor)
            True
        """
        if not shapes:
            kwargs.setdefault("dtype", concat_tensor.dtype)
            kwargs.setdefault("device", concat_tensor.device)
            return cls([], **kwargs)

        # Normalise once to plain int tuples; everything below works on these so that
        # callers are not forced to pass ``torch.Size`` objects.
        element_shapes = tuple(tuple(int(dim) for dim in shape) for shape in shapes)
        total_expected = sum(torch.Size(shape).numel() for shape in element_shapes)
        num_provided = concat_tensor.numel()

        varying_dims, static_dims = cls._pack_layout_from_element_shapes(element_shapes)
        permutation = varying_dims + static_dims
        identity_permutation = tuple(range(len(element_shapes[0])))

        if (
            num_provided == total_expected
            and len(set(element_shapes)) == 1
            and permutation == identity_permutation
        ):
            shape = element_shapes[0]
            try:
                reshaped = concat_tensor.reshape(len(element_shapes), *shape)
            except (RuntimeError, ValueError):
                # The reshape fast path is opportunistic; a normal unpack fallback
                # is expected for non-view-compatible inputs.
                pass
            else:
                return cls([element.reshape(shape) for element in reshaped.unbind(0)], **kwargs)

        if num_provided != total_expected:
            raise ValueError(
                f"Concatenated tensor has {num_provided} elements "
                f"but expected {total_expected} based on shapes {shapes}"
            )

        packed_sizes = tuple(cls._packed_size_from_shape(shape, varying_dims) for shape in element_shapes)
        inverse_permutation = cls._inverse_permutation(permutation)

        tensors = []
        start = 0
        for shape, packed_size in zip(element_shapes, packed_sizes):
            chunk = concat_tensor.narrow(0, start, packed_size)
            # Packed layout stores varying dims first, then static dims.
            packed_shape = tuple(shape[dim] for dim in permutation)
            tensor_data = chunk.reshape(packed_shape)
            if permutation != tuple(range(len(shape))):
                # Undo the packing permutation to restore the element's own dim order.
                tensor_data = tensor_data.permute(inverse_permutation)
            tensors.append(tensor_data)
            start += packed_size

        return cls(tensors, **kwargs)

    @classmethod
    def from_tensor_mask(cls, tensor: Tensor, mask: Tensor, *, batched: bool = False, **kwargs):
        r"""
        Build a `NestedTensor` object from a padded `Tensor` and corresponding mask `Tensor`.

        Args:
            tensor: Padded Tensor.
            mask: Tensor Mask.
                The mask uses the same convention as ``mask_value``:
                padding positions equal ``mask_value`` and valid positions equal ``not mask_value``.
            batched: When ``True`` and ``mask.ndim == 1``, treat ``mask`` as a per-batch-element
                selector (each ``True`` entry selects a row from ``tensor``) rather than a
                contiguous-prefix length indicator.
            **kwargs: Forwarded to the constructor. ``mask_value``, ``batch_first`` and
                ``padding_value`` are also read here. ``dtype`` defaults to ``tensor.dtype``
                and may be overridden by the caller.

        Raises:
            ValueError: If the batch sizes of ``tensor`` and ``mask`` differ, or if a mask
                is not a (hierarchical) prefix mask, i.e. it has interior gaps.

        Examples:
            >>> padded_tensor = torch.tensor([[1, 2, 3, 0, 0],
            ...                                [4, 5, 0, 0, 0],
            ...                                [6, 7, 8, 9, 0]])
            >>> mask_tensor = torch.tensor([[1, 1, 1, 0, 0],
            ...                             [1, 1, 0, 0, 0],
            ...                             [1, 1, 1, 1, 0]])
            >>> nested_tensor = NestedTensor.from_tensor_mask(padded_tensor, mask_tensor)
            >>> nested_tensor
            NestedTensor([
                [1, 2, 3],
                [4, 5],
                [6, 7, 8, 9]
            ])
        """
        # Let an explicit ``dtype=`` from the caller win. Passing ``dtype=tensor.dtype``
        # alongside ``**kwargs`` would raise "got multiple values for keyword argument".
        kwargs.setdefault("dtype", tensor.dtype)

        mask = mask.to(dtype=torch.bool)
        mask_value = kwargs.get("mask_value", False)
        # Normalise so that True always marks valid data, whatever ``mask_value`` is.
        effective_mask = ~mask if mask_value else mask

        if mask.ndim == 1:
            if batched:
                indices = effective_mask.nonzero(as_tuple=False).flatten().tolist()
                return cls([tensor[index] for index in indices], **kwargs)
            return cls(tensor[effective_mask], **kwargs)
        # ndim >= 2: batch setup is shared, per-element trim differs by rank
        batch_first = kwargs.get("batch_first", True)
        tensor_iter = tensor if batch_first else tensor.transpose(0, 1)
        mask_iter = effective_mask if batch_first else effective_mask.transpose(0, 1)
        if tensor_iter.size(0) != mask_iter.size(0):
            raise ValueError("Tensor/mask batch dimension mismatch: " f"{tensor_iter.size(0)} vs {mask_iter.size(0)}")
        trimmed = []

        def _is_prefix_mask(mask_1d: Tensor) -> bool:
            # A prefix mask is ``count`` True entries followed only by False entries.
            count = int(mask_1d.sum().item())
            prefix = torch.arange(mask_1d.size(0), device=mask_1d.device, dtype=torch.long) < count
            return bool(torch.equal(mask_1d, prefix))

        def _is_hierarchical_prefix_mask(mask_nd: Tensor) -> bool:
            # The non-empty slices along dim 0 must form a prefix, and each of them
            # must itself satisfy the same rule one dimension down.
            if mask_nd.dim() == 1:
                return _is_prefix_mask(mask_nd)
            leading_valid = mask_nd.reshape(mask_nd.size(0), -1).any(dim=1)
            valid_count = int(leading_valid.sum().item())
            prefix = torch.arange(mask_nd.size(0), device=mask_nd.device, dtype=torch.long) < valid_count
            if not torch.equal(leading_valid, prefix):
                return False
            return all(_is_hierarchical_prefix_mask(mask_nd[index]) for index in range(valid_count))

        if mask.ndim == 2:
            # 1-D per-element mask: only contiguous-prefix masks can be reconstructed
            # via slicing without changing dense semantics.
            counts = mask_iter.sum(dim=1, dtype=torch.long)
            prefix = torch.arange(mask_iter.size(1), device=mask_iter.device, dtype=torch.long).unsqueeze(0)
            prefix = prefix < counts.unsqueeze(1)
            if not torch.equal(mask_iter, prefix):
                raise ValueError(
                    "from_tensor_mask() with 2-D masks requires each row to be a valid prefix mask; "
                    "interior False gaps are not supported."
                )
            for t, count in zip(tensor_iter, counts.tolist()):
                trimmed.append(t[:count])
        else:
            # N-D per-element mask: only hierarchical ragged-prefix masks are representable as NestedTensor.
            # ``extents[b, d]`` becomes the largest valid index + 1 of element ``b`` along dim ``d``.
            extents = torch.zeros((mask_iter.size(0), mask_iter.dim() - 1), dtype=torch.long, device=mask_iter.device)
            nonzero = mask_iter.nonzero(as_tuple=False)
            if nonzero.numel() > 0:
                batch_index = nonzero[:, :1].expand(-1, extents.size(1))
                extents.scatter_reduce_(0, batch_index, nonzero[:, 1:] + 1, reduce="amax", include_self=False)
            extent_rows = extents.cpu().tolist()
            for t, em, sizes in zip(tensor_iter, mask_iter, extent_rows):
                if not _is_hierarchical_prefix_mask(em):
                    raise ValueError(
                        "from_tensor_mask() with N-D masks requires each element mask to be a valid hierarchical "
                        "ragged prefix; "
                        "interior False gaps are not supported."
                    )
                slices = tuple(slice(0, size) for size in sizes)
                t_slice = t[slices]
                m_slice = em[slices]
                valid_mask = m_slice
                if t_slice.dim() > m_slice.dim():
                    # Broadcast the mask over trailing tensor dims it does not cover.
                    valid_mask = m_slice.view(m_slice.shape + (1,) * (t_slice.dim() - m_slice.dim()))
                # Positions inside the bounding box but outside the ragged region get the padding value.
                trimmed.append(t_slice.masked_fill(~valid_mask, kwargs.get("padding_value", 0.0)))
        return cls(trimmed, **kwargs)

    def _dense_to_packed_values(self, tensor: Tensor) -> Tensor | None:
        r"""
        Convert a batch-aligned dense tensor to ``self``'s packed ``_values`` layout.

        Returns ``None`` when the dense tensor does not cover the current logical
        padded extents and we must fall back to per-element slicing/repacking.
        """
        batch_leading = tensor.to(device=self.device)
        if self.dim() > 1 and not self.batch_first:
            batch_leading = batch_leading.movedim(1, 0)

        logical_shape = list(self.shape)
        if logical_shape:
            batch_dim = 0 if self.dim() <= 1 or self.batch_first else 1
            logical_shape.pop(batch_dim)
        if batch_leading.dim() != len(logical_shape) + 1:
            return None

        dense_sizes = tuple(int(batch_leading.size(dim + 1)) for dim in range(batch_leading.dim() - 1))
        if any(dense_sizes[dim] < int(size) for dim, size in enumerate(logical_shape)):
            return None

        if logical_shape:
            batch_leading = batch_leading[(slice(None), *[slice(0, int(size)) for size in logical_shape])]

        if batch_leading.dim() <= 1:
            return batch_leading.contiguous()

        return batch_leading[self._packed_dense_index(device=batch_leading.device)].contiguous()

    def _packed_sizes_like(self, element_shapes: tuple[tuple[int, ...], ...]) -> tuple[int, ...]:
        varying_dims, _ = type(self)._pack_layout_from_element_shapes(element_shapes)
        return tuple(type(self)._packed_size_from_shape(shape, varying_dims) for shape in element_shapes)

    def nested_like(self, tensor: Tensor, strict: bool = True) -> Self:
        r"""
        Create a new `NestedTensor` from a `Tensor`, using the element shapes of the current `NestedTensor`.

        A `NestedTensor` input is simply cloned.

        Args:
            tensor: The tensor to be converted to `NestedTensor`.
            strict: Check if the shape of `tensor` is the same as the current `NestedTensor`.

        Raises:
            ValueError: If `strict` is set and the shapes differ, or if the batch sizes differ.

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
            >>> (nested_tensor == nested_tensor.nested_like(nested_tensor)).all()
            tensor(True)
            >>> tensor = nested_tensor.tensor
            >>> (nested_tensor == nested_tensor.nested_like(tensor)).all()
            tensor(True)
            >>> f = nested_tensor.nested_like(torch.randn(2, 2))
            Traceback (most recent call last):
            ...
            ValueError: The shape of NestedTensor and input tensor does not match, ...
            >>> p = nested_tensor.nested_like(torch.randn(2, 2), False)
            >>> p = nested_tensor.nested_like(torch.randn(3, 3), False)
            Traceback (most recent call last):
            ...
            ValueError: The batch size of NestedTensor and input tensor does not match, 2 != 3
        """

        if isinstance(tensor, NestedTensor):
            return tensor.clone()

        if strict and self.shape != tensor.shape:
            raise ValueError(
                f"The shape of NestedTensor and input tensor does not match, {self.shape} != {tensor.shape}"
            )
        batch_dim = 0 if self.dim() <= 1 or self.batch_first else 1
        if len(self) != tensor.size(batch_dim):
            raise ValueError(
                "The batch size of NestedTensor and input tensor does not match, "
                f"{len(self)} != {tensor.size(batch_dim)}"
            )

        # Fast path: gather straight into the packed layout and reuse the metadata of ``self``.
        packed = self._dense_to_packed_values(tensor)
        if packed is not None:
            return self.__class__._from_packed(
                packed,
                self._offsets,
                self._physical_shape,
                batch_first=self.batch_first,
                padding_value=self.padding_value,
                mask_value=self.mask_value,
                pin_memory=self._pin_memory,
                outer_size=self._logical_shape,
                packed_sizes=self._packed_sizes,
                element_shapes=self._element_shapes,
            )

        # Fallback: cut every element out of the dense tensor individually.
        dense = tensor.to(device=self.device)
        pieces = []
        for idx, shape in enumerate(self._original_shapes()):
            extents = [slice(0, int(dim)) for dim in shape]
            if self.batch_first or not extents:
                index = (idx, *extents)
            else:
                # The batch dimension sits at position 1, after the first element dim.
                index = (extents[0], idx, *extents[1:])
            # .contiguous() ensures storage elements don't inherit non-trivial
            # strides from the padded tensor (e.g. after transpose).
            pieces.append(dense[index].contiguous())
        return self.__class__(pieces, dtype=tensor.dtype, **self._meta(include_dtype=False))

    @property
    def occupancy(self) -> float:
        r"""
        Ratio of `numel()` to the number of positions in the padded shape.

        Returns `0.0` for an empty batch or when the padded shape has no positions.

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3, 4]), torch.tensor([5, 6])])
            >>> nested_tensor.occupancy
            0.75
        """
        if len(self) == 0:
            return 0.0
        padded_positions = self.shape.numel()  # type: ignore[union-attr]
        if not padded_positions:
            return 0.0
        return self.numel() / padded_positions  # type: ignore[union-attr]

    def to_torch_nested(self) -> Tensor:
        r"""
        Create a `torch.nested.nested_tensor` object from `self`.

        Examples:
            >>> nested_tensor = NestedTensor([[2, 3, 5], [7, 8]])
            >>> nt = nested_tensor.to_torch_nested()
            >>> nt.layout == torch.jagged
            True
            >>> nt.values()
            tensor([2, 3, 5, 7, 8])
        """
        storage = list(self._storage)
        if not storage or all(t.dim() > 0 for t in storage):
            return nested.nested_tensor(storage, layout=torch.jagged)
        return nested.nested_tensor(storage)

    def unbind(self, dim: int = 0) -> tuple[Tensor, ...]:
        r"""
        Split the NestedTensor along `dim` by delegating to `torch.unbind`.

        Args:
            dim: Dimension to remove. Defaults to 0.
        """
        pieces = torch.unbind(self, dim=dim)
        return pieces

    def _maybe_exact_shape_nested_like(self, tensor: object) -> Self | None:
        r"""
        Convert an exact-shape dense tensor to this NestedTensor's layout.

        This is the shared policy boundary for dense-to-nested alignment used by
        operator helpers: only non-scalar dense tensors with logical shape exactly
        matching ``self.shape`` are converted, and the conversion always uses
        ``nested_like(..., strict=False)``.
        """
        if not isinstance(tensor, Tensor) or isinstance(tensor, type(self)):
            return None
        if tensor.dim() == 0 or tensor.shape != self.shape:
            return None
        return self.nested_like(tensor, strict=False)

    # ------------------------------------------------------------------
    # Indexing
    # ------------------------------------------------------------------

    def __getitem__(self, index: int | slice | list | tuple | Tensor | NestedTensor) -> Tensor | NestedTensor:
        r"""
        Retrieve element(s) by index, slice, list, tuple, or tensor mask.

        Dispatch by the type of ``index``:

        - ``int``: return the stored element at that batch position.
        - ``slice`` / ``list``: return a new instance holding the selected elements.
          A non-empty list made only of ``bool`` values is treated as a batch mask.
        - ``tuple``: the first item indexes the batch, the remaining items are applied
          to each selected element. ``Ellipsis`` is expanded to full slices first.
          A ``Tensor`` / ``NestedTensor`` in batch position falls back to indexing
          ``self.tensor`` with the whole tuple.
        - ``NestedTensor``: index each stored element with the matching element of
          ``index``.
        - ``Tensor``: a 0-dim integer tensor selects one element; a 1-D bool/uint8
          tensor is a batch mask; a 1-D integer tensor gathers batch elements; any
          other tensor is converted with ``nested_like(..., strict=False)`` and applied
          element by element.

        Raises:
            IndexError: If a boolean mask length differs from the batch size.
            ValueError: If the index type (or the batch index type inside a tuple) is
                unsupported, or a NestedTensor index has a different batch length.
        """
        if isinstance(index, int):
            return self._storage[index]
        if isinstance(index, (slice, list)):
            # A non-empty all-bool list is a mask: turn it into the positions that are True.
            if isinstance(index, list) and index and all(isinstance(i, bool) for i in index):
                if len(index) != len(self):
                    raise IndexError(f"Boolean index has length {len(index)} but batch size is {len(self)}")
                index = [i for i, flag in enumerate(index) if flag]
            storage = tuple(self._storage[index] if isinstance(index, slice) else [self._storage[i] for i in index])
            return self.__class__(storage, **self._meta(include_dtype=True))
        if isinstance(index, tuple):
            # An empty tuple selects everything; the same object is returned, not a copy.
            if len(index) == 0:
                return self

            # Expand Ellipsis: ``nt[..., :2]`` on a 4-D NestedTensor becomes
            # ``nt[:, :, :, :2]``.  The batch dim is consumed first, so Ellipsis
            # fills the gap between the number of explicit indices and the total
            # number of logical dimensions.
            if Ellipsis in index:
                eidx = index.index(Ellipsis)
                n_explicit = len(index) - 1  # exclude Ellipsis itself
                n_expand = self.dim() - n_explicit
                index = index[:eidx] + (slice(None),) * n_expand + index[eidx + 1 :]

            # First item addresses the batch; everything else addresses each element.
            batch_index, *rest = index

            # Tensor-valued batch index: defer to indexing ``self.tensor`` with the full tuple.
            if isinstance(batch_index, (Tensor, NestedTensor)):
                return self.tensor[index]

            # Same boolean-list mask handling as the plain list case above.
            if isinstance(batch_index, list) and batch_index and all(isinstance(i, bool) for i in batch_index):
                if len(batch_index) != len(self):
                    raise IndexError(f"Boolean index has length {len(batch_index)} but batch size is {len(self)}")
                batch_index = [i for i, flag in enumerate(batch_index) if flag]

            if isinstance(batch_index, int):
                # Single element: apply the remaining indices directly to that tensor.
                tensor = self._storage[batch_index]
                if rest:
                    return tensor[tuple(rest)]
                return tensor
            elif isinstance(batch_index, (slice, list)):
                if isinstance(batch_index, slice):
                    selected = self._storage[batch_index]
                else:
                    selected = tuple(self._storage[i] for i in batch_index)
                if rest:
                    # Apply the per-element part of the index to every selected tensor.
                    rest_tuple = tuple(rest)
                    selected = tuple(t[rest_tuple] for t in selected)
                return self.__class__(selected, **self._meta(include_dtype=True))
            raise ValueError(f"Unsupported batch index type {type(batch_index)}")
        if isinstance(index, NestedTensor):
            # Element-wise indexing requires one index element per stored element.
            if len(self) != len(index):
                raise ValueError(
                    "NestedTensor batch length mismatch between self and index: "
                    f"self={len(self)}, index={len(index)}"
                )
            return self.__class__(
                [t[i] for t, i in zip(self._storage, index._storage)], **self._meta(include_dtype=True)
            )
        if isinstance(index, Tensor):
            # 0-dim integer tensor behaves like a Python int batch index.
            if index.dim() == 0 and index.dtype in (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8):
                return self._storage[int(index.item())]
            if index.dim() == 1:
                # 1-D bool (and uint8) tensors are batch masks.
                if index.dtype in (torch.bool, torch.uint8):
                    if index.numel() != len(self):
                        raise IndexError(f"Boolean index has length {index.numel()} but batch size is {len(self)}")
                    selected = tuple(self._storage[i] for i, flag in enumerate(index.tolist()) if bool(flag))
                    return self.__class__(selected, **self._meta(include_dtype=True))
                # 1-D integer tensors gather batch elements. uint8 is listed here as well,
                # but a 1-D uint8 index has already been consumed by the mask branch above.
                if index.dtype in (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8):
                    return self.__class__(
                        [self._storage[int(i)] for i in index.tolist()],
                        **self._meta(include_dtype=True),
                    )
            # Any other tensor: convert it to this layout, then index element by element.
            index = self.nested_like(index, strict=False)
            return self.__class__(
                [t[i] for t, i in zip(self._storage, index._storage)], **self._meta(include_dtype=True)
            )
        raise ValueError(f"Unsupported index type {type(index)}")

    def __setitem__(self, index: int | slice | list | tuple, value: Tensor | NestedTensor) -> None:
        r"""
        Set values in the NestedTensor at the specified index.

        Three index forms are handled:

        - ``int``: replace one batch element by editing the packed buffers in place
          (or by repacking when the trailing payload shape changes).
        - ``slice`` / ``list``: replace several batch elements and reassign
          ``self._storage``. A non-empty all-``bool`` list is treated as a mask.
        - ``tuple``: the first item selects batch elements, the remaining items are
          applied inside a clone of each selected element.

        Args:
            index: The index to modify. Can be an integer, slice, list, or tuple.
            value: The new value to set. Can be a Tensor or NestedTensor.

        Raises:
            IndexError: If an integer index is out of range or a boolean mask length
                differs from the batch size.
            ValueError: If the value count does not match the number of selected
                indices, the assigned tensor has a different number of dimensions
                than the existing elements, a NestedTensor assigned to an integer
                index does not hold exactly one tensor, or the index type is
                unsupported.

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
            >>> nested_tensor[0] = torch.tensor([6, 7, 8])
            >>> nested_tensor[0]
            tensor([6, 7, 8])
            >>> nested_tensor[1] = torch.tensor([9, 10, 11, 12])
            >>> nested_tensor.shape
            torch.Size([2, 4])
        """
        if isinstance(index, int):
            # Caches are dropped up front, before any validation or mutation.
            self._invalidate_transient_caches()
            if isinstance(value, NestedTensor):
                # Only a one-element NestedTensor can stand in for a single batch entry.
                if len(value._storage) != 1:
                    raise ValueError(
                        f"When setting with an integer index, value must have a single tensor, but got {len(value)}"
                    )
                value = value._storage[0]
            # Bring the value onto this tensor's device and dtype.
            if not isinstance(value, Tensor):
                value = torch.tensor(value, device=self.device, dtype=self.dtype)
            else:
                value = value.to(device=self.device, dtype=self.dtype)
            if self.requires_grad:
                value.requires_grad_(True)

            # Normalize negative index
            idx = index + len(self) if index < 0 else index
            if idx < 0 or idx >= len(self):
                raise IndexError(f"index {index} is out of range for NestedTensor with {len(self)} elements")
            # Column count of the physical-shape table is the per-element ndim.
            expected_ndim = self._physical_shape.size(1)
            if value.dim() != expected_ndim:
                raise ValueError(
                    f"Assigned tensor ndim must match existing ndim {expected_ndim}, but got {value.dim()}"
                )

            # Span currently occupied by element ``idx`` inside the packed ``_values`` buffer.
            old_start = int(self._offsets[idx].item())
            old_end = int(self._offsets[idx + 1].item())
            old_size = old_end - old_start
            new_shape_row = torch.tensor(list(value.shape), dtype=self._physical_shape.dtype)

            # Pack the new value the same way existing elements are packed: apply the
            # stored permutation (if not identity), then flatten the varying dims into
            # one leading axis followed by the static dims.
            permutation = self._permutation
            identity_permutation = tuple(range(expected_ndim))
            varying_dims = self._varying_dims
            static_dims = self._static_dims
            packed_size = type(self)._packed_size_from_shape(tuple(int(dim) for dim in value.shape), varying_dims)
            packed_value = value if permutation == identity_permutation else value.permute(permutation)
            suffix_shape = tuple(int(value.shape[dim]) for dim in static_dims)
            new_payload = packed_value.reshape((packed_size, *suffix_shape) if suffix_shape else (packed_size,))
            new_size = packed_size

            # Trailing payload shape differs from the packed buffer: it cannot be spliced
            # in, so rebuild everything from the element list.
            if self._values.dim() > 1 and new_payload.shape[1:] != self._values.shape[1:]:
                storage_list = list(self._storage)
                storage_list[idx] = value
                self._repack(storage_list)
                return

            if new_size == old_size:
                # Same packed span size: direct overwrite keeps _values allocation.
                self._values[old_start:old_end] = new_payload
                self._physical_shape[idx] = new_shape_row
            else:
                # Different packed span size: splice _values and shift subsequent offsets.
                self._values = torch.cat([self._values[:old_start], new_payload, self._values[old_end:]], dim=0)
                delta = new_size - old_size
                # Offsets and shapes are cloned before mutation in this branch.
                self._offsets = self._offsets.clone()
                self._offsets[idx + 1 :] += delta  # noqa: E203
                self._physical_shape = self._physical_shape.clone()
                self._physical_shape[idx] = new_shape_row
            # Logical (outer) shape depends on the per-element shapes, so recompute it.
            self._logical_shape = self._logical_shape_from_physical_shape(
                self._physical_shape, self._offsets, self.batch_first
            )
            # Keep the optional per-element metadata in sync when both pieces are present.
            if self._element_shapes is not None and self._packed_sizes is not None:
                element_shapes = list(self._element_shapes)
                element_shapes[idx] = tuple(int(dim) for dim in value.shape)
                self._element_shapes = tuple(element_shapes)
                packed_sizes = list(self._packed_sizes)
                packed_sizes[idx] = self._packed_sizes_like((self._element_shapes[idx],))[0]
                self._packed_sizes = tuple(packed_sizes)
            self._validate_metadata()
        elif isinstance(index, (slice, list)):
            # A non-empty all-bool list is a mask: turn it into the positions that are True.
            if isinstance(index, list) and index and all(isinstance(i, bool) for i in index):
                if len(index) != len(self):
                    raise IndexError(f"Boolean index has length {len(index)} but batch size is {len(self)}")
                index = [i for i, flag in enumerate(index) if flag]

            # Dense tensor value: split along dim 0 when it has more than one row,
            # otherwise treat the whole tensor as a single element.
            if isinstance(value, Tensor) and not isinstance(value, NestedTensor):
                if value.dim() > 1 and value.size(0) > 1:
                    value = self.__class__(value.unbind(0), **self._meta())
                else:
                    value = self.__class__([value], **self._meta())

            if isinstance(index, slice):
                start, stop, step = index.indices(len(self))
                indices = range(start, stop, step)
            else:
                indices = index  # type: ignore[assignment]

            # One replacement element is required per selected position.
            if len(indices) != len(value._storage):
                raise ValueError(
                    f"Size mismatch: tried to assign {len(value._storage)} values to {len(indices)} indices"
                )

            storage_list = list(self._storage)
            for i, idx in enumerate(indices):
                storage_list[idx] = value._storage[i]
            self._storage = tuple(storage_list)
        elif isinstance(index, tuple):
            # Empty tuple: nothing to assign. Single-item tuple: same as the bare index.
            if len(index) == 0:
                return
            if len(index) == 1:
                self[index[0]] = value
                return

            # First item selects batch elements; the rest indexes inside each element.
            first_idx, rest_idx = index[0], index[1:]
            batch_indices: list[int]
            if isinstance(first_idx, int):
                batch_indices = [first_idx]
            elif isinstance(first_idx, (slice, list)):
                if isinstance(first_idx, list) and first_idx and all(isinstance(i, bool) for i in first_idx):
                    if len(first_idx) != len(self):
                        raise IndexError(f"Boolean index has length {len(first_idx)} but batch size is {len(self)}")
                    batch_indices = [i for i, flag in enumerate(first_idx) if flag]
                elif isinstance(first_idx, slice):
                    start, stop, step = first_idx.indices(len(self))
                    batch_indices = list(range(start, stop, step))
                else:
                    batch_indices = list(first_idx)  # type: ignore[arg-type]
            else:
                raise ValueError(f"Unsupported first index type {type(first_idx)}")

            if isinstance(value, NestedTensor):
                # One value element per selected batch element.
                if len(batch_indices) != len(value._storage):
                    raise ValueError(
                        f"Size mismatch: tried to assign {len(value._storage)} values to {len(batch_indices)} indices"
                    )
                assigned_values = list(value._storage)
            else:
                # The same value is assigned into every selected element.
                assigned_values = [value] * len(batch_indices)

            elems = list(self._storage)
            for position, idx in enumerate(batch_indices):
                # Work on a clone, then swap it into the element list.
                elem = elems[idx].clone()
                elem[rest_idx] = assigned_values[position]
                elems[idx] = elem
            self._storage = tuple(elems)
        else:
            raise ValueError(f"Unsupported index type {type(index)}")

    # ------------------------------------------------------------------
    # Properties: runtime config, dtype, device, requires_grad
    # ------------------------------------------------------------------

    @property
    def batch_first(self) -> bool:
        r"""Layout flag: ``True`` for a logical shape of ``(B, ...)``, ``False`` for ``(..., B, ...)``."""
        return self._batch_first

    @batch_first.setter
    def batch_first(self, value: bool):
        cls = type(self)
        coerced = cls._coerce_batch_first(value)
        previous = getattr(self, "_batch_first", None)
        self._batch_first = coerced

        # First assignment, or the flag did not actually change: nothing derived to refresh.
        unchanged = previous is None or previous == coerced
        if unchanged:
            return

        # The logical shape depends on the layout, so rebuild it once the packed
        # metadata it is derived from exists.
        has_shape_state = (
            hasattr(self, "_physical_shape") and hasattr(self, "_offsets") and hasattr(self, "_logical_shape")
        )
        if has_shape_state:
            self._logical_shape = cls._logical_shape_from_physical_shape(
                self._physical_shape,
                self._offsets,
                coerced,
            )
        if hasattr(self, "_cached_tensor_view"):
            self._invalidate_transient_caches()

    @property
    def padding_value(self) -> float:
        r"""Fill value written into padded positions when a dense view is materialized."""
        return self._padding_value

    @padding_value.setter
    def padding_value(self, value: SupportsFloat):
        coerced = type(self)._coerce_padding_value(value)
        previous = getattr(self, "_padding_value", None)
        self._padding_value = coerced

        # First assignment or same value: any cached dense view is still valid.
        unchanged = previous is None or previous == coerced
        if unchanged:
            return
        # The cached dense view was filled with the old padding value; drop it.
        if hasattr(self, "_cached_tensor_view"):
            self._cached_tensor_view = None

    @property
    def mask_value(self) -> bool:
        r"""Boolean that marks padding positions in generated masks."""
        return self._mask_value

    @mask_value.setter
    def mask_value(self, value: bool):
        coerced = type(self)._coerce_mask_value(value)
        previous = getattr(self, "_mask_value", None)
        self._mask_value = coerced

        # First assignment or same value: any cached mask is still valid.
        unchanged = previous is None or previous == coerced
        if unchanged:
            return
        # The cached mask was built with the old polarity; drop it.
        if hasattr(self, "_cached_mask_view"):
            self._cached_mask_view = None

    @property
    def dtype(self) -> torch.dtype:  # type: ignore[override]
        r"""Element data type, read from the packed ``_values`` buffer."""
        packed = self._values
        return packed.dtype

    @dtype.setter
    def dtype(self, value: torch.dtype | None):
        r"""Always raises: `dtype` cannot be assigned; convert with `.to(dtype=...)` instead."""
        message = "NestedTensor.dtype is read-only; use .to(dtype=...) to create a converted tensor."
        raise AttributeError(message)

    @property
    def device(self) -> torch.device:  # type: ignore[override]
        r"""Device of the packed ``_values`` buffer."""
        packed = self._values
        return packed.device

    @device.setter
    def device(self, value: torch.device | None):
        r"""Always raises: `device` cannot be assigned; move with `.to(device=...)` instead."""
        message = "NestedTensor.device is read-only; use .to(device=...) to create a moved tensor."
        raise AttributeError(message)

    @property
    def requires_grad(self) -> bool:  # type: ignore[override]
        r"""Gradient-tracking flag of the packed ``_values`` buffer."""
        packed = self._values
        return packed.requires_grad

    @requires_grad.setter
    def requires_grad(self, value: bool):
        r"""Toggle gradient tracking in place on the packed ``_values`` buffer."""
        packed = self._values
        packed.requires_grad_(value)

    # ------------------------------------------------------------------
    # State management
    # ------------------------------------------------------------------

    def _meta(self, *, include_dtype: bool | None = None) -> Mapping:
        r"""Metadata used for structure-preserving reconstruction."""
        if include_dtype is None:
            # Empty reconstructions cannot infer dtype from storage; include it by default.
            include_dtype = self._values.numel() == 0
        if include_dtype:
            return {
                "batch_first": self.batch_first,
                "padding_value": self.padding_value,
                "mask_value": self.mask_value,
                "pin_memory": self._pin_memory,
                "device": self._values.device,
                "dtype": self.dtype,
            }
        return {
            "batch_first": self.batch_first,
            "padding_value": self.padding_value,
            "mask_value": self.mask_value,
            "pin_memory": self._pin_memory,
            "device": self._values.device,
        }

    def __getstate__(self) -> dict:
        return {
            "_state_version": self._SERIALIZATION_VERSION,
            "_values": self._values,
            "_offsets": self._offsets,
            "_permutation": self._permutation,
            "_physical_shape": self._physical_shape,
            "_logical_shape": self._logical_shape,
            "batch_first": self.batch_first,
            "padding_value": self.padding_value,
            "mask_value": self.mask_value,
            "_pin_memory": self._pin_memory,
            "_packed_sizes": self._packed_sizes,
            "_element_shapes": self._element_shapes,
        }

    def __setstate__(self, state: Mapping) -> None:
        r"""
        Restore the instance from a mapping produced by ``__getstate__``.

        The state is validated first, then copied onto the instance. Offsets and
        physical shapes are moved to CPU, and the pin-memory flag is only kept when
        the restored values really are pinned CPU memory.
        """
        type(self)._validate_serialized_state(state)

        values = state["_values"]
        self._values = values
        self._offsets = state["_offsets"].cpu()
        self._permutation = tuple(map(int, state["_permutation"]))
        self._physical_shape = state["_physical_shape"].cpu()
        self._logical_shape = state["_logical_shape"]
        self._set_runtime_config(
            batch_first=state["batch_first"],
            padding_value=state["padding_value"],
            mask_value=state["mask_value"],
        )

        # Trust the stored flag only if the restored buffer is actually pinned on CPU.
        still_pinned = state["_pin_memory"] and values.device.type == "cpu" and values.is_pinned()
        self._pin_memory = bool(still_pinned)

        self._packed_sizes = state["_packed_sizes"]
        self._element_shapes = state["_element_shapes"]
        # Serialized state intentionally excludes transient caches.
        self._invalidate_transient_caches()
        self._validate_metadata()

    def __reduce__(self):
        r"""Pickle support: rebuild through ``_from_state`` using the ``__getstate__`` mapping."""
        rebuild = self.__class__._from_state
        state = self.__getstate__()
        return (rebuild, (state,))

    @classmethod
    def _from_state(cls, state: dict) -> Self:
        r"""
        Build an instance from a mapping produced by ``__getstate__``.

        The state is validated, then handed to ``_from_packed``. Offsets and physical
        shapes are moved to CPU first.
        """
        cls._validate_serialized_state(state)

        values = state["_values"]
        offsets = state["_offsets"].cpu()
        physical_shape = state["_physical_shape"].cpu()
        options = {
            "permutation": tuple(map(int, state["_permutation"])),
            "batch_first": state["batch_first"],
            "padding_value": state["padding_value"],
            "mask_value": state["mask_value"],
            "pin_memory": state["_pin_memory"],
            "outer_size": state["_logical_shape"],
            "packed_sizes": state["_packed_sizes"],
            "element_shapes": state["_element_shapes"],
        }
        return cls._from_packed(values, offsets, physical_shape, **options)

    def __copy__(self):
        r"""Shallow copy: a new NestedTensor built on the same underlying buffers (nothing is cloned)."""
        options = {
            "permutation": self._permutation,
            "batch_first": self.batch_first,
            "padding_value": self.padding_value,
            "mask_value": self.mask_value,
            "pin_memory": self._pin_memory,
            "outer_size": self._logical_shape,
            "packed_sizes": self._packed_sizes,
            "element_shapes": self._element_shapes,
        }
        # Values, offsets and physical shapes are passed through as-is, so they are shared.
        return self.__class__._from_packed(self._values, self._offsets, self._physical_shape, **options)

    def __deepcopy__(self, memo):
        r"""
        Deep copy: a new NestedTensor built on cloned buffers.

        Args:
            memo: The ``copy.deepcopy`` memo dictionary; the new object is registered
                in it under ``id(self)``.
        """
        values = self._values.clone()
        offsets = self._offsets.clone()
        physical_shape = self._physical_shape.clone()
        options = {
            "permutation": self._permutation,
            "batch_first": self.batch_first,
            "padding_value": self.padding_value,
            "mask_value": self.mask_value,
            "pin_memory": self._pin_memory,
            "outer_size": self._logical_shape,
            "packed_sizes": self._packed_sizes,
            "element_shapes": self._element_shapes,
        }
        duplicate = self.__class__._from_packed(values, offsets, physical_shape, **options)
        memo[id(self)] = duplicate
        return duplicate

    # ------------------------------------------------------------------
    # Tensor-like methods
    # ------------------------------------------------------------------

    def all(self, dim: int | None = None, keepdim: bool = False) -> bool | Tensor | NestedTensor:
        r"""
        Check whether every element is truthy, optionally along one dimension.

        Thin wrapper that forwards to `torch.all`.

        Args:
            dim: Dimension to reduce. ``None`` reduces over everything.
            keepdim: Whether the reduced dimension is kept with size 1.

        Examples:
            >>> nested_tensor = NestedTensor([torch.ones(2, 4, dtype=torch.bool), torch.ones(3, 5, dtype=torch.bool)])
            >>> nested_tensor.all()
            tensor(True)
            >>> nested_tensor.all(dim=0)
            tensor([True, True])
            >>> nested_tensor.all(dim=0, keepdim=True)
            tensor([[True, True]])
            >>> nested_tensor.all(dim=1)
            NestedTensor([
                [True, True, True, True],
                [True, True, True, True, True]
            ])
            >>> nested_tensor.all(dim=1, keepdim=True)
            NestedTensor([
                [[True, True, True, True]],
                [[True, True, True, True, True]]
            ])
            >>> nested_tensor.batch_first = False
            >>> nested_tensor.all(dim=1)
            tensor([True, True])
            >>> nested_tensor.all(dim=0)
            NestedTensor([
                [True, True, True, True],
                [True, True, True, True, True]
            ])
            >>> nested_tensor.all(dim=-2)
            tensor([True, True])
        """
        reduced = torch.all(self, dim=dim, keepdim=keepdim)
        return reduced

    def any(self, dim: int | None = None, keepdim: bool = False) -> bool | Tensor | NestedTensor:
        r"""
        Check whether at least one element is truthy, optionally along one dimension.

        Thin wrapper that forwards to `torch.any`.

        Args:
            dim: Dimension to reduce. ``None`` reduces over everything.
            keepdim: Whether the reduced dimension is kept with size 1.

        Examples:
            >>> nested_tensor = NestedTensor([torch.zeros(2, dtype=torch.bool), torch.ones(3, dtype=torch.bool)])
            >>> nested_tensor.any()
            tensor(True)
            >>> nested_tensor.any(dim=0)
            tensor([False,  True])
        """
        reduced = torch.any(self, dim=dim, keepdim=keepdim)
        return reduced

    def dim(self) -> int:
        r"""
        Return the number of logical dimensions (batch dimension included).

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
            >>> nested_tensor.dim()
            2
        """
        if hasattr(self, "_logical_shape"):
            return len(self._logical_shape)
        # ``_logical_shape`` is not set on this object yet: ask the base Tensor
        # directly, with subclass dispatch disabled.
        with torch._C.DisableTorchFunctionSubclass():
            return len(torch.Tensor.size(self))

    def max(self, dim: int | None = None, keepdim: bool = False) -> Tensor | NestedTensor:
        r"""
        Return the maximum, over everything or along ``dim``.

        Args:
            dim: Dimension to reduce. With ``None`` the global maximum is returned
                and ``keepdim`` is not forwarded.
            keepdim: Whether the reduced dimension is kept with size 1.
        """
        if dim is not None:
            return torch.max(self, dim=dim, keepdim=keepdim)
        return torch.max(self)

    def mean(
        self,
        dim: int | None = None,
        keepdim: bool = False,
        *,
        dtype: torch.dtype | None = None,  # type: ignore[name-defined]
    ) -> Tensor | NestedTensor:
        r"""
        Return the mean, over everything or along ``dim``.

        Thin wrapper that forwards to `torch.mean`.

        Args:
            dim: Dimension to reduce, or ``None``.
            keepdim: Whether the reduced dimension is kept with size 1.
            dtype: Optional dtype forwarded to `torch.mean`.
        """
        averaged = torch.mean(self, dim=dim, keepdim=keepdim, dtype=dtype)
        return averaged

    def min(self, dim: int | None = None, keepdim: bool = False) -> Tensor | NestedTensor:
        r"""
        Return the minimum, over everything or along ``dim``.

        Args:
            dim: Dimension to reduce. With ``None`` the global minimum is returned
                and ``keepdim`` is not forwarded.
            keepdim: Whether the reduced dimension is kept with size 1.
        """
        if dim is not None:
            return torch.min(self, dim=dim, keepdim=keepdim)
        return torch.min(self)

    @property
    def mT(self) -> Self:  # type: ignore[override]
        r"""
        Swap the last two non-batch dimensions.

        Raises:
            RuntimeError: If fewer than two non-batch dimensions exist.
        """
        rank = self.dim()
        batch_axis = 0 if self.batch_first else 1
        inner_axes = [axis for axis in range(rank) if axis != batch_axis]
        if len(inner_axes) < 2:
            raise RuntimeError(
                f"tensor.mT is only supported on matrices or batches of matrices. Got {len(inner_axes)}-D tensor."
            )
        second_last, last = inner_axes[-2:]
        return torch.transpose(self, second_last, last)

    @property
    def ndim(self) -> int:
        r"""
        Number of dimensions; property form of `dim()`.
        """
        rank = self.dim()
        return rank

    def numel(self) -> int:
        r"""
        Count the stored elements, i.e. the size of the packed ``_values`` buffer.

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
            >>> nested_tensor.numel()
            5
        """
        packed = self._values
        return packed.numel()

    def permute(self, *dims) -> Self:
        r"""
        Apply permutation to each tensor in the NestedTensor.

        The ordering may be given either as separate integers
        (``nt.permute(0, 3, 1, 2)``) or as one tuple, list or ``torch.Size``
        (``nt.permute((0, 3, 1, 2))``), matching how ``reshape`` accepts its shape.

        Args:
            *dims: The desired ordering of dimensions for the NestedTensor (including batch dimension).

        Returns:
            NestedTensor: A new NestedTensor with each tensor permuted.

        Examples:
            >>> nested_tensor = NestedTensor([torch.randn(3, 4, 5), torch.randn(2, 4, 5)])
            >>> permuted = nested_tensor.permute(0, 3, 1, 2)
            >>> permuted.shape
            torch.Size([2, 5, 3, 4])
        """
        # A single sequence argument would otherwise be forwarded as a nested
        # tuple ``((0, 3, 1, 2),)``; unwrap it to the flat ordering.
        if len(dims) == 1 and isinstance(dims[0], (tuple, list, torch.Size)):
            dims = tuple(dims[0])
        return torch.permute(self, dims)

    def moveaxis(self, source, destination) -> Self:
        r"""
        Move the dimension(s) at ``source`` to ``destination``.

        Thin wrapper that forwards to `torch.moveaxis`.
        """
        moved = torch.moveaxis(self, source, destination)
        return moved

    def movedim(self, source, destination) -> Self:
        r"""
        Move the dimension(s) at ``source`` to ``destination``.

        Thin wrapper that forwards to `torch.movedim`.
        """
        moved = torch.movedim(self, source, destination)
        return moved

    # to(), clone(), detach(), contiguous(), half(), float(), double(), etc.
    # are all handled by aten dispatch in aten_functions.py (aten._to_copy, aten.clone,
    # aten.detach). No custom Python methods needed.

    def pin_memory(self) -> Self:
        r"""
        Pin the underlying tensor memory for faster host-to-device transfer.

        Only the packed values are copied into pinned memory. Offsets, physical
        shapes and the remaining metadata are passed through from this instance,
        and the new instance is flagged as pinned.

        Returns:
            NestedTensor: A new NestedTensor backed by the pinned values.
        """
        return type(self)._from_packed(
            self._values.pin_memory(),
            self._offsets,
            self._physical_shape,
            # Carry the packed layout across, as ``__copy__`` / ``__deepcopy__`` do;
            # without it the pinned values would be rebuilt with the default
            # permutation instead of this tensor's one.
            permutation=self._permutation,
            batch_first=self.batch_first,
            padding_value=self.padding_value,
            mask_value=self.mask_value,
            pin_memory=True,
            outer_size=self._logical_shape,
            packed_sizes=self._packed_sizes,
            element_shapes=self._element_shapes,
        )

    def prod(
        self,
        dim: int | None = None,
        keepdim: bool = False,
        *,
        dtype: torch.dtype | None = None,  # type: ignore[name-defined]
    ) -> Tensor | NestedTensor:
        r"""
        Return the product of the elements, over everything or along ``dim``.

        Thin wrapper that forwards to `torch.prod`.

        Args:
            dim: Dimension to reduce, or ``None``.
            keepdim: Whether the reduced dimension is kept with size 1.
            dtype: Optional dtype forwarded to `torch.prod`.
        """
        product = torch.prod(self, dim=dim, keepdim=keepdim, dtype=dtype)
        return product

    def requires_grad_(self, requires_grad: bool = True):
        r"""
        Set the ``requires_grad`` flag in place and return ``self`` for chaining.

        Args:
            requires_grad: New value of the flag. Defaults to ``True``.
        """
        # Goes through the ``requires_grad`` attribute assignment on this object.
        setattr(self, "requires_grad", requires_grad)
        return self

    def reshape(self, *shape) -> Self:
        r"""
        Reshape each tensor in the NestedTensor.

        The target may be passed as separate integers or as one tuple, list or
        ``torch.Size``.

        Args:
            *shape: The desired size of each dimension for the underlying tensors.

        Returns:
            NestedTensor: A new NestedTensor with each tensor reshaped.

        Raises:
            TypeError: If no shape is given.

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6], [7, 8]])])
            >>> reshaped = nested_tensor.reshape(4)
            >>> reshaped.shape
            torch.Size([2, 4])
        """
        if len(shape) == 0:
            raise TypeError("reshape() missing shape")
        # A lone sequence argument is the shape itself; otherwise the varargs are.
        given_as_sequence = len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size))
        if given_as_sequence:
            target_shape = shape[0]
        else:
            target_shape = shape
        return torch.reshape(self, target_shape)

    def flatten(self, start_dim: int = 0, end_dim: int = -1):
        r"""
        Flatten the dimensions from ``start_dim`` through ``end_dim``.

        Thin wrapper that forwards to `torch.flatten`.

        Args:
            start_dim: First dimension to flatten. Defaults to ``0``.
            end_dim: Last dimension to flatten. Defaults to ``-1``.
        """
        flattened = torch.flatten(self, start_dim=start_dim, end_dim=end_dim)
        return flattened

    def flip(self, dims) -> Self:
        r"""
        Reverse the order of elements along the given dimensions.

        Thin wrapper that forwards to `torch.flip`.

        Args:
            dims: Dimensions to flip.
        """
        flipped = torch.flip(self, dims)
        return flipped

    @property
    def shape(self) -> torch.Size:  # type: ignore[override, name-defined]
        r"""
        Full logical shape; property form of `size()`.
        """
        full_size = self.size()
        return full_size

    def size(self, dim: int | None = None) -> torch.Size | int:  # type: ignore[override, name-defined]
        r"""
        Return the logical size of this `NestedTensor`.

        Args:
            dim: If not specified, the returned value is a `torch.Size`, a subclass of `tuple`.
                If specified, returns an `int` holding the size of that dimension.
                Negative values count from the end.
                Defaults to `None`.

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
            >>> nested_tensor.size()
            torch.Size([2, 3])
            >>> nested_tensor.size(0)
            2
            >>> nested_tensor[1] = torch.tensor([4, 5, 6, 7])
            >>> nested_tensor.shape
            torch.Size([2, 4])
            >>> nested_tensor.size(1)
            4
        """
        if not hasattr(self, "_logical_shape"):
            # ``_logical_shape`` is not set on this object yet: ask the base Tensor
            # directly, with subclass dispatch disabled.
            with torch._C.DisableTorchFunctionSubclass():
                full_size = torch.Tensor.size(self)
        else:
            full_size = self._logical_shape

        if dim is None:
            return full_size
        if dim < 0:
            dim = dim + len(full_size)
        return full_size[dim]

    def sum(
        self,
        dim: int | Sequence[int] | None = None,
        keepdim: bool = False,
        *,
        dtype: torch.dtype | None = None,  # type: ignore[name-defined]
    ) -> Tensor | NestedTensor:
        r"""
        Sum the elements over the given dimension(s).

        Thin wrapper that forwards to `torch.sum`.

        Args:
            dim: The dimension or dimensions to reduce. If None, sum over all dimensions.
                Supports int, Sequence[int], or None. Negative dimensions are supported.
            keepdim: Whether to retain reduced dimensions with size 1.
            dtype: The desired data type of returned tensor.

        Returns:
            Tensor or NestedTensor depending on the dimensions being reduced.

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
            >>> nested_tensor.sum()
            tensor(15)
            >>> nested_tensor.sum(dim=0)  # when dim=0, sum across batch dimension
            tensor([6, 9])
            >>> nested_tensor.sum(dim=1)
            tensor([6, 9])
            >>> nested_tensor.sum(dim=[0, 1])
            tensor(15)
            >>> nested_tensor.sum(dim=0, keepdim=True)
            tensor([[6, 9]])
            >>> nested_tensor.sum(dtype=torch.float32)
            tensor(15.)
        """
        total = torch.sum(self, dim=dim, keepdim=keepdim, dtype=dtype)
        return total

    @property
    def T(self) -> Self:  # type: ignore[override]
        r"""
        Reverse the order of the non-batch dimensions, leaving the batch dimension in place.

        Tensors with at most one dimension are returned unchanged.
        """
        rank = self.dim()
        if rank <= 1:
            return self
        batch_axis = 0 if self.batch_first else 1
        # Walk the axes backwards, skipping the batch axis, then put it back at its own position.
        reversed_axes = [axis for axis in range(rank - 1, -1, -1) if axis != batch_axis]
        reversed_axes.insert(batch_axis, batch_axis)
        return torch.permute(self, tuple(reversed_axes))

    def tolist(self) -> list:
        r"""
        Return the contents as a Python list with one (possibly nested) list per stored tensor.

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
            >>> nested_tensor.tolist()
            [[1, 2, 3], [4, 5]]
        """
        rows = []
        for element in self._storage:
            rows.append(element.tolist())
        return rows

    def transpose(self, dim0: int, dim1: int) -> Self:  # type: ignore[valid-type]
        r"""
        Swap dimensions ``dim0`` and ``dim1`` of each tensor in the NestedTensor.

        Thin wrapper that forwards to `torch.transpose`.

        Args:
            dim0: First dimension to transpose (in NestedTensor coordinate system).
            dim1: Second dimension to transpose (in NestedTensor coordinate system).

        Returns:
            NestedTensor: A new NestedTensor with each tensor transposed.

        Examples:
            >>> nested_tensor = NestedTensor([torch.randn(3, 4), torch.randn(2, 4)])
            >>> # NestedTensor shape is [2, 3, 4], underlying tensors are [3, 4] and [2, 4]
            >>> transposed = nested_tensor.transpose(1, 2)  # transpose dims 1 and 2
            >>> transposed.shape  # batch dimension is still first
            torch.Size([2, 4, 3])
        """
        swapped = torch.transpose(self, dim0, dim1)
        return swapped

    def swapaxes(self, axis0: int, axis1: int) -> Self:
        r"""
        Swap dimensions ``axis0`` and ``axis1``; same operation as `transpose()`.

        Thin wrapper that forwards to `torch.swapaxes`.
        """
        swapped = torch.swapaxes(self, axis0, axis1)
        return swapped

    def swapdims(self, dim0: int, dim1: int) -> Self:
        r"""
        Alias for `swapaxes()`.

        Args:
            dim0: First dimension to swap.
            dim1: Second dimension to swap.

        Returns:
            The result of ``torch.swapdims(self, dim0, dim1)``.
        """
        return torch.swapdims(self, dim0, dim1)

    def squeeze(self, dim: int | None = None) -> Self:  # type: ignore[valid-type]
        r"""
        Squeeze singleton dimensions from each tensor in the NestedTensor.

        Args:
            dim: Dimension to squeeze. The value, including the default ``None``,
                is forwarded unchanged as the ``dim`` keyword of ``torch.squeeze``.

        Returns:
            The result of ``torch.squeeze(self, dim=dim)``.
        """
        return torch.squeeze(self, dim=dim)

    def unsqueeze(self, dim: int) -> Self:  # type: ignore[valid-type]
        r"""
        Unsqueeze each tensor in the NestedTensor by adding a singleton dimension at the specified position.

        This method is a thin wrapper that forwards ``self`` and ``dim`` to ``torch.unsqueeze``.

        Args:
            dim: The dimension at which to add the singleton dimension. This is in the NestedTensor's
                coordinate system (where dim 0 is the batch dimension).

        Returns:
            NestedTensor: A new NestedTensor with each tensor unsqueezed at the specified dimension.

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
            >>> # Original shape: [2, 3] (batch_size=2, max_seq_len=3)
            >>> unsqueezed = nested_tensor.unsqueeze(1)
            >>> unsqueezed.shape
            torch.Size([2, 1, 3])
            >>> # Now each underlying tensor has shape [1, seq_len] instead of [seq_len]

            >>> nested_tensor_2d = NestedTensor([torch.randn(3, 4), torch.randn(2, 4)])
            >>> # Original shape: [2, 3, 4] (batch_size=2, max_len1=3, max_len2=4)
            >>> unsqueezed_2d = nested_tensor_2d.unsqueeze(2)
            >>> unsqueezed_2d.shape
            torch.Size([2, 3, 1, 4])
            >>> # Now each underlying tensor has shape [len1, 1, len2] instead of [len1, len2]
        """
        return torch.unsqueeze(self, dim)

    def unflatten(self, dim: int, sizes) -> Self:  # type: ignore[valid-type]
        r"""
        Unflatten one dimension of each tensor in the NestedTensor.

        Args:
            dim: Dimension to expand into several dimensions.
            sizes: Target sizes for the expanded dimension; forwarded unchanged to ``torch.unflatten``.

        Returns:
            The result of ``torch.unflatten(self, dim, sizes)``.
        """
        return torch.unflatten(self, dim, sizes)

    def roll(self, shifts, dims=None) -> Self:
        r"""
        Roll each tensor in the NestedTensor along the given dimensions.

        Args:
            shifts: Shift amount(s); forwarded unchanged to ``torch.roll``.
            dims: Dimension(s) to roll along. The default ``None`` is forwarded as-is.

        Returns:
            The result of ``torch.roll(self, shifts, dims=dims)``.
        """
        return torch.roll(self, shifts, dims=dims)

    def rot90(self, k: int = 1, dims: Sequence[int] = (0, 1)) -> Self:
        r"""
        Rotate each tensor in the NestedTensor by 90 degrees in the given plane.

        Args:
            k: Number of 90-degree rotations; defaults to ``1``.
            dims: The two dimensions that span the rotation plane; defaults to ``(0, 1)``.
                NOTE(review): whether these are batch-inclusive or per-element
                coordinates is decided by the ``torch.rot90`` handler -- confirm there.

        Returns:
            The result of ``torch.rot90(self, k, dims)``.
        """
        return torch.rot90(self, k, dims)

    def view(self, *shape) -> Self:
        r"""
        View each tensor in the NestedTensor with a different shape.

        The shape may be given either as separate integers (``view(2, 3)``) or as a
        single ``tuple`` / ``list`` / ``torch.Size`` (``view((2, 3))``).

        Args:
            *shape: The desired size of each dimension for the underlying tensors.

        Returns:
            NestedTensor: A new NestedTensor with each tensor viewed with the new shape.

        Raises:
            TypeError: If no shape is given.

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6], [7, 8]])])
            >>> viewed = nested_tensor.view(4)  # View each 2x2 tensor as 4
            >>> viewed.shape
            torch.Size([2, 4])
            >>> type(viewed).__name__
            'NestedTensor'
        """
        if len(shape) == 0:
            raise TypeError("view() missing shape")
        # Accept both ``view(a, b)`` and ``view((a, b))``.
        requested = shape
        if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
            requested = shape[0]
        aten_view = torch.ops.aten.view.default
        # Call the handler registered for the aten view op directly.
        handler = NestedTensorAtenRegistry[aten_view]
        return handler(aten_view, (self, list(requested)), {})

    def _view_shapes(self, shape) -> list[tuple[int, ...]]:  # type: ignore[valid-type]
        r"""
        Compute per-element view shapes, adjusting ragged dimensions.

        Batch-dim detection rules:
        1. If ``shape[batch_dim]`` does not match the batch size, batch dim is NOT included.
        2. If ``len(shape) != self.dim()``, batch dim IS included (unambiguous).
        3. If ``len(shape) == self.dim()`` (ambiguous), batch dim is included only when
           at least one other dimension matches max_sizes or is -1.

        For ragged dimensions, each target dimension that matches the corresponding
        max size is substituted with the element's actual size. When a target dimension
        matches a max size at a different position (e.g. after inserting a dim), a
        single-candidate search resolves the mapping.

        Args:
            shape: Requested shape, either as a sequence of ints or as a one-element
                sequence that wraps a ``tuple`` / ``list`` / ``torch.Size``.

        Returns:
            One shape tuple per batch element, in the order of the element shapes.
            A ``-1`` entry is left in place when it cannot be inferred here.

        Raises:
            ValueError: If the batch dimension is judged to be part of ``shape`` but
                its value differs from the batch size.
        """
        # Unwrap ``((a, b),)`` into ``(a, b)``.
        if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
            shape = tuple(shape[0])

        batch_dim = 0 if self.batch_first else 1
        batch_size = len(self)

        # Step 1: Determine if batch dim is in the target shape
        include_batch = False
        if len(shape) > batch_dim:
            if shape[batch_dim] == batch_size and len(shape) != self.dim():
                include_batch = True
            elif shape[batch_dim] in (-1, batch_size) and len(shape) == self.dim():
                # Ambiguous: same dim count → confirm via dimension matching
                max_sizes = list(self.size())  # type: ignore[arg-type]
                if max_sizes:
                    max_sizes.pop(batch_dim)
                non_batch = [i for i in range(len(shape)) if i != batch_dim]
                include_batch = any(
                    j < len(max_sizes) and (shape[d] == -1 or shape[d] == max_sizes[j]) for j, d in enumerate(non_batch)
                )

        # Step 2: Strip batch dim from target shape
        target = list(shape)
        if include_batch:
            if target[batch_dim] == -1:
                target[batch_dim] = batch_size
            if target[batch_dim] != batch_size:
                raise ValueError(f"Batch dimension mismatch: expected {batch_size} but got {target[batch_dim]}")
            target.pop(batch_dim)

        # Step 3: Per-element shape adjustment (ragged dim substitution)
        # ``max_sizes`` is recomputed because Step 1 only builds it in the ambiguous branch.
        max_sizes = list(self.size())  # type: ignore[arg-type]
        if max_sizes:
            max_sizes.pop(batch_dim)

        # Use the cached element shapes when present, otherwise derive them.
        element_shapes = self._element_shapes
        if element_shapes is None:
            element_shapes = tuple(tuple(shape) for shape in self._original_shapes())

        view_shapes = []
        for element_shape in element_shapes:
            adjusted = list(target)
            # Positions of ``max_sizes`` not yet claimed by a target dimension.
            available = list(range(len(max_sizes)))
            for i in range(min(len(adjusted), len(max_sizes))):
                if adjusted[i] == -1:
                    continue
                # Direct match: same position in max_sizes
                if adjusted[i] == max_sizes[i]:
                    adjusted[i] = element_shape[i]
                    if i in available:
                        available.remove(i)
                    continue
                # Indirect match: search remaining positions for unique candidate
                candidates = [j for j in available if max_sizes[j] == adjusted[i]]
                if len(candidates) == 1:
                    j = candidates[0]
                    adjusted[i] = element_shape[j]
                    available.remove(j)
            # Resolve a single remaining -1 from the element's total element count,
            # but only when the known dimensions divide it exactly.
            if adjusted.count(-1) == 1:
                missing = adjusted.index(-1)
                known = 1
                for dim in adjusted:
                    if dim != -1:
                        known *= dim
                element_numel = type(self)._shape_numel(element_shape)
                if known != 0 and element_numel % known == 0:
                    adjusted[missing] = element_numel // known
            view_shapes.append(tuple(adjusted))
        return view_shapes

    def where(self, condition: Tensor | NestedTensor, other: Tensor | NestedTensor | SupportsFloat) -> Self:
        r"""
        Return a NestedTensor of elements selected from either self or other, depending on condition.

        Values are taken from ``self`` where ``condition`` is true and from ``other``
        elsewhere; the call is forwarded as ``torch.where(condition, self, other)``.

        Args:
            condition: Boolean selector.
            other: Values used where ``condition`` is false.

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
            >>> nested_tensor.where(nested_tensor > 2, torch.tensor([[6, 5, 4], [3, 2, 1]]))
            NestedTensor([
                [6, 5, 3],
                [4, 5]
            ])
            >>> nested_tensor.where(nested_tensor > 2, NestedTensor([[6, 5, 4], [3, 2]]))
            NestedTensor([
                [6, 5, 3],
                [4, 5]
            ])
            >>> nested_tensor.where(torch.tensor(True), NestedTensor([[6, 5, 4], [3, 2]]))
            NestedTensor([
                [1, 2, 3],
                [4, 5]
            ])
        """
        return torch.where(condition, self, other)

tensor_mask property

Python
tensor_mask: tuple[Tensor, Tensor]

Return a tuple of padded tensor and mask tensor.

Examples:

Python Console Session
1
2
3
4
5
>>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
>>> nested_tensor.tensor_mask
(tensor([[1, 2, 3],
        [4, 5, 0]]), tensor([[ True,  True,  True],
        [ True,  True, False]]))

tensor property

Python
tensor: Tensor

Return a single tensor by padding all the tensors.

Examples:

Python Console Session
1
2
3
4
>>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
>>> nested_tensor.tensor
tensor([[1, 2, 3],
        [4, 5, 0]])

mask property

Python
mask: Tensor

Padding mask of tensor.

mask_value controls which boolean value denotes padding in this mask. With the default mask_value=False, True means valid data.

Examples:

Python Console Session
1
2
3
4
>>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
>>> nested_tensor.mask
tensor([[ True,  True,  True],
        [ True,  True, False]])

concat property

Python
concat: Tensor

Flatten elements and concatenate along the ragged dimension (no padding).

This is particularly useful when computing a loss or applying a Linear layer, because it avoids unnecessary computation on padding.

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.randn(9, 8), torch.randn(11, 8)])
>>> nested_tensor.concat.shape
torch.Size([20, 8])
>>> nested_tensor = NestedTensor([torch.randn(9, 9, 8), torch.randn(11, 11, 8)])
>>> nested_tensor.concat.shape
torch.Size([202, 8])
>>> nested_tensor = NestedTensor([torch.randn(9, 9, 8, 6), torch.randn(11, 11, 8, 6)])
>>> nested_tensor.concat.shape
torch.Size([202, 8, 6])
>>> nested_tensor = NestedTensor([torch.randn(9, 9, 8, 7), torch.randn(11, 11, 8, 6)])
>>> nested_tensor.concat.shape
torch.Size([1293, 8])
>>> nested_tensor = NestedTensor([torch.randn(1, 9, 9, 5), torch.randn(1, 11, 11, 5)])
>>> nested_tensor.concat.shape
torch.Size([202, 1, 5])

occupancy property

Python
occupancy: float

Occupancy of the NestedTensor.

Examples:

Python Console Session
1
2
3
>>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3, 4]), torch.tensor([5, 6])])
>>> nested_tensor.occupancy
0.75

batch_first property writable

Python
batch_first: bool

Whether the logical outer shape uses (B, ...) instead of (..., B, ...).

padding_value property writable

Python
padding_value: float

Padding fill value used when materializing dense views.

mask_value property writable

Python
mask_value: bool

Boolean value used to denote padding positions in generated masks.

dtype property writable

Python
dtype: dtype

Data type of the underlying tensor elements.

device property writable

Python
device: device

Device on which the underlying tensor data resides.

requires_grad property writable

Python
requires_grad: bool

Whether gradient computation is enabled for this tensor.

mT property

Python
mT: Self

Matrix transpose over the last two per-element dimensions.

ndim property

Python
ndim: int

Alias for dim().

shape property

Python
shape: Size

Alias for size().

T property

Python
T: Self

concatenate

Python
concatenate() -> tuple[Tensor, tuple[Size, ...]]

Concatenate tensors in padding dimension and return structural information for reconstruction.

Returns:

Type Description
tuple[Tensor, tuple[Size, ...]]

A tuple containing:

  • concat_tensor: The concatenated tensor (same as .concat property)
  • shapes: Tuple of original tensor shapes for reconstruction

Examples:

Python Console Session
1
2
3
4
5
6
7
8
9
>>> nested_tensor = NestedTensor([torch.randn(9, 8), torch.randn(11, 8)])
>>> concat_tensor, shapes = nested_tensor.concatenate()
>>> concat_tensor.shape
torch.Size([20, 8])
>>> shapes
(torch.Size([9, 8]), torch.Size([11, 8]))
>>> reconstructed = NestedTensor.from_concatenated(concat_tensor, shapes)
>>> torch.equal(nested_tensor.tensor, reconstructed.tensor)
True
Source code in danling/tensors/nested_tensor.py
Python
def concatenate(self) -> tuple[Tensor, tuple[torch.Size, ...]]:
    r"""
    Return the packed values together with the per-element shapes needed to rebuild the batch.

    Returns:
        A tuple containing:
        - concat_tensor: The concatenated tensor (same as .concat property)
        - shapes: Tuple of original tensor shapes for reconstruction

        For an empty batch this is a zero-length tensor (with the dtype of the
        packed values, on ``self.device``) and an empty tuple.

    Examples:
        >>> nested_tensor = NestedTensor([torch.randn(9, 8), torch.randn(11, 8)])
        >>> concat_tensor, shapes = nested_tensor.concatenate()
        >>> concat_tensor.shape
        torch.Size([20, 8])
        >>> shapes
        (torch.Size([9, 8]), torch.Size([11, 8]))
        >>> reconstructed = NestedTensor.from_concatenated(concat_tensor, shapes)
        >>> torch.equal(nested_tensor.tensor, reconstructed.tensor)
        True
    """
    # ``_offsets`` holds one more entry than there are elements, so a length of
    # one means the batch is empty.
    if len(self._offsets) == 1:
        empty = torch.empty(0, dtype=self._values.dtype, device=self.device)
        return empty, ()
    return self._values, self._original_shapes()

__len__

Python
__len__() -> int

Return the number of tensors in the batch.

Source code in danling/tensors/nested_tensor.py
Python
def __len__(self) -> int:
    r"""
    Return the number of tensors in the batch.

    When ``_offsets`` is available the count is derived from it. Otherwise the
    size of the underlying dense tensor along the batch dimension is used, and
    a 0-dimensional tensor reports a length of ``0``.
    """
    if hasattr(self, "_offsets"):
        return len(self._offsets) - 1
    # No ``_offsets`` attribute: read the raw tensor size with subclass
    # overrides of torch functions disabled.
    with torch._C.DisableTorchFunctionSubclass():
        dense_size = torch.Tensor.size(self)
    rank = len(dense_size)
    if rank == 0:
        return 0
    # The batch axis is 1 only for batch-second layouts with at least two dims.
    if getattr(self, "batch_first", True) or rank <= 1:
        batch_axis = 0
    else:
        batch_axis = 1
    return int(dense_size[batch_axis])

__repr__

Python
__repr__()

Return a human-readable string representation of the NestedTensor.

Source code in danling/tensors/nested_tensor.py
Python
def __repr__(self):
    r"""
    Return a human-readable string representation of the NestedTensor.

    A compact one-line summary (shape, dtype, device, batch_first) is returned
    while compiling, when a packed field is a fake tensor, or when the packed
    fields are not all present. Otherwise each element is printed on its own
    line, and batches with more than ten elements show only the first five.
    """
    packed_fields = ("_values", "_offsets", "_physical_shape")

    def summary(shape):
        return (
            f"{self.__class__.__name__}(shape={shape}, dtype={self.dtype}, "
            f"device={self.device}, batch_first={getattr(self, 'batch_first', True)})"
        )

    def shape_or_placeholder():
        try:
            return tuple(self.size())
        except Exception:
            return "?"

    if torch._dynamo.is_compiling():
        return summary(shape_or_placeholder())

    # Best effort: any failure while probing for fake tensors falls through
    # to the regular paths below.
    try:
        from torch._subclasses.fake_tensor import is_fake

        for field in packed_fields:
            candidate = self.__dict__.get(field)
            if isinstance(candidate, Tensor) and is_fake(candidate):
                return summary(tuple(self.size()))
    except Exception:
        pass

    if any(field not in self.__dict__ for field in packed_fields):
        return summary(shape_or_placeholder())

    if len(self) == 0:
        return self.__class__.__name__ + "()"

    elements = self._storage
    is_truncated = len(elements) > 10
    if is_truncated:
        elements = elements[:5]

    pad = "    "

    # Drop the leading "tensor(" and the trailing ")" from each element's repr
    # so that PyTorch's own number formatting is kept.
    rendered = []
    for element in elements:
        text = repr(element)
        body = text[text.index("(") + 1 : -1]  # noqa: E203
        if "\n" in body:
            # Multi-line reprs (e.g. 2D tensors): re-indent the continuation lines.
            head, *tail = body.split("\n")
            body = head + "\n" + "\n".join(pad + " " + line.lstrip() for line in tail)
        rendered.append(body)

    last = len(rendered) - 1
    out = [self.__class__.__name__ + "(["]
    for position, body in enumerate(rendered):
        trailing = "," if position < last or is_truncated else ""
        out.append(pad + body + trailing)
    if is_truncated:
        out.append(pad + f"... ({len(self)} tensors)")
    out.append("])")
    return "\n".join(out)

__bool__

Python
__bool__() -> bool

NestedTensor follows tensor-style truthiness and never acts like a Python container.

Source code in danling/tensors/nested_tensor.py
Python
def __bool__(self) -> bool:
    r"""
    NestedTensor follows tensor-style truthiness and never acts like a Python container.

    Raises:
        RuntimeError: Always. Truth-testing a NestedTensor is rejected; use ``.numel()``,
            ``.any()``, ``.all()``, or an explicit reduction instead.
    """
    raise RuntimeError(
        "Boolean value of NestedTensor is ambiguous. Use .numel(), .any(), .all(), or an explicit reduction."
    )

__iter__

Python
__iter__()

Iterate over the tensors in the batch.

Source code in danling/tensors/nested_tensor.py
Python
def __iter__(self):
    r"""
    Iterate over the tensors in the batch.

    Returns:
        An iterator over ``self._storage``.
    """
    # NOTE(review): the guard presumably rejects iteration in execution contexts where
    # it is not allowed -- confirm against `_check_execution_guard`.
    _check_execution_guard(_ExecutionGuardKind.ITERATION, "NestedTensor.__iter__")
    return iter(self._storage)

__eq__

Python
__eq__(other)

Element-wise equality comparison.

Source code in danling/tensors/nested_tensor.py
Python
def __eq__(self, other):  # type: ignore[override]
    r"""
    Element-wise equality comparison.

    Returns:
        The result of ``torch.eq(self, other)``, or ``NotImplemented`` when that
        call raises ``TypeError``.
    """
    try:
        return torch.eq(self, other)
    except TypeError:
        # Let Python try the reflected operation on ``other`` instead of failing.
        return NotImplemented

__ne__

Python
__ne__(other)

Element-wise inequality comparison.

Source code in danling/tensors/nested_tensor.py
Python
def __ne__(self, other):  # type: ignore[override]
    r"""
    Element-wise inequality comparison.

    Returns:
        The result of ``torch.ne(self, other)``, or ``NotImplemented`` when that
        call raises ``TypeError``.
    """
    try:
        return torch.ne(self, other)
    except TypeError:
        # Let Python try the reflected operation on ``other`` instead of failing.
        return NotImplemented

from_concatenated classmethod

Python
from_concatenated(
    concat_tensor: Tensor,
    shapes: tuple[Size, ...],
    **kwargs
) -> Self

Reconstruct a NestedTensor from a concatenated tensor and shape information.

Parameters:

Name Type Description Default
concat_tensor
Tensor

The concatenated tensor returned by concatenate()

required
shapes
tuple[Size, ...]

Tuple of original tensor shapes returned by concatenate()

required
**kwargs

Additional arguments to pass to NestedTensor constructor

{}

Returns:

Type Description
Self

Reconstructed NestedTensor

Examples:

Python Console Session
1
2
3
4
5
6
7
8
9
>>> nested_tensor = NestedTensor([torch.randn(9, 9, 8), torch.randn(11, 11, 8)])
>>> concat_tensor, shapes = nested_tensor.concatenate()
>>> reconstructed = NestedTensor.from_concatenated(concat_tensor, shapes)
>>> concat_tensor.shape
torch.Size([202, 8])
>>> reconstructed.shape
torch.Size([2, 11, 11, 8])
>>> torch.equal(nested_tensor.tensor, reconstructed.tensor)
True
Source code in danling/tensors/nested_tensor.py
Python
@classmethod
def from_concatenated(cls, concat_tensor: Tensor, shapes: tuple[torch.Size, ...], **kwargs) -> Self:
    r"""
    Rebuild a NestedTensor from a concatenated tensor plus the original element shapes.

    Args:
        concat_tensor: The concatenated tensor returned by concatenate()
        shapes: Tuple of original tensor shapes returned by concatenate()
        **kwargs: Additional arguments to pass to NestedTensor constructor

    Returns:
        Reconstructed NestedTensor

    Raises:
        ValueError: If ``concat_tensor`` does not hold exactly as many values
            as ``shapes`` describe.

    Examples:
        >>> nested_tensor = NestedTensor([torch.randn(9, 9, 8), torch.randn(11, 11, 8)])
        >>> concat_tensor, shapes = nested_tensor.concatenate()
        >>> reconstructed = NestedTensor.from_concatenated(concat_tensor, shapes)
        >>> concat_tensor.shape
        torch.Size([202, 8])
        >>> reconstructed.shape
        torch.Size([2, 11, 11, 8])
        >>> torch.equal(nested_tensor.tensor, reconstructed.tensor)
        True
    """
    if not shapes:
        # Empty batch: take dtype/device from the buffer unless the caller set them.
        kwargs.setdefault("dtype", concat_tensor.dtype)
        kwargs.setdefault("device", concat_tensor.device)
        return cls([], **kwargs)

    counts = [size.numel() for size in shapes]
    total = sum(counts)
    dims_per_element = tuple(tuple(int(extent) for extent in size) for size in shapes)
    varying_dims, static_dims = cls._pack_layout_from_element_shapes(dims_per_element)
    permutation = varying_dims + static_dims
    if dims_per_element and dims_per_element[0]:
        identity = tuple(range(len(dims_per_element[0])))
    else:
        identity = ()

    # Fast path: identical shapes in their natural dim order can be recovered
    # with a single reshape.
    if len(set(shapes)) == 1 and permutation == identity and concat_tensor.numel() == total:
        common = shapes[0]
        try:
            stacked = concat_tensor.reshape(len(shapes), *common)
        except (RuntimeError, ValueError):
            # The reshape is opportunistic; fall back to the general unpacking
            # below when the input cannot be reshaped this way.
            stacked = None
        if stacked is not None:
            return cls([piece.reshape(common) for piece in stacked.unbind(0)], **kwargs)

    packed_sizes = tuple(cls._packed_size_from_shape(dims, varying_dims) for dims in dims_per_element)
    provided = concat_tensor.numel()
    if provided != total:
        raise ValueError(
            f"Concatenated tensor has {provided} elements "
            f"but expected {total} based on shapes {shapes}"
        )

    inverse = cls._inverse_permutation(permutation)
    pieces = []
    cursor = 0
    for dims, length in zip(dims_per_element, packed_sizes):
        chunk = concat_tensor.narrow(0, cursor, length)
        # Packed layout puts the varying dims first and the static dims after them.
        packed_shape = tuple(dims[axis] for axis in permutation)
        piece = chunk.reshape(packed_shape)
        if permutation != tuple(range(len(dims))):
            piece = piece.permute(inverse)
        pieces.append(piece)
        cursor += length

    return cls(pieces, **kwargs)

from_tensor_mask classmethod

Python
from_tensor_mask(
    tensor: Tensor,
    mask: Tensor,
    *,
    batched: bool = False,
    **kwargs
)

Build a NestedTensor object from a padded Tensor and corresponding mask Tensor.

Parameters:

Name Type Description Default
tensor
Tensor

Padded Tensor.

required
mask
Tensor

Tensor Mask. The mask uses the same convention as mask_value: padding positions equal mask_value and valid positions equal not mask_value.

required
batched
bool

When True and mask.ndim == 1, treat mask as a per-batch-element selector (each True entry selects a row from tensor) rather than a contiguous-prefix length indicator.

False

Examples:

Python Console Session
>>> padded_tensor = torch.tensor([[1, 2, 3, 0, 0],
...                                [4, 5, 0, 0, 0],
...                                [6, 7, 8, 9, 0]])
>>> mask_tensor = torch.tensor([[1, 1, 1, 0, 0],
...                             [1, 1, 0, 0, 0],
...                             [1, 1, 1, 1, 0]])
>>> nested_tensor = NestedTensor.from_tensor_mask(padded_tensor, mask_tensor)
>>> nested_tensor
NestedTensor([
    [1, 2, 3],
    [4, 5],
    [6, 7, 8, 9]
])
Source code in danling/tensors/nested_tensor.py
Python
@classmethod
def from_tensor_mask(cls, tensor: Tensor, mask: Tensor, *, batched: bool = False, **kwargs):
    r"""
    Build a `NestedTensor` object from a padded `Tensor` and corresponding mask `Tensor`.

    Args:
        tensor: Padded Tensor.
        mask: Tensor Mask.
            The mask uses the same convention as ``mask_value``:
            padding positions equal ``mask_value`` and valid positions equal ``not mask_value``.
        batched: When ``True`` and ``mask.ndim == 1``, treat ``mask`` as a per-batch-element
            selector (each valid entry selects a row from ``tensor``) rather than a
            contiguous-prefix length indicator.
        **kwargs: Additional arguments passed to the constructor. ``mask_value``,
            ``batch_first`` and ``padding_value`` are also read here. ``dtype`` defaults
            to ``tensor.dtype`` and is honoured when supplied explicitly.

    Returns:
        A new instance of ``cls`` built from the valid region of every batch element.

    Raises:
        ValueError: If the batch sizes of ``tensor`` and ``mask`` differ, or if a mask
            is not a (hierarchical) prefix mask, i.e. it contains interior gaps.

    Examples:
        >>> padded_tensor = torch.tensor([[1, 2, 3, 0, 0],
        ...                                [4, 5, 0, 0, 0],
        ...                                [6, 7, 8, 9, 0]])
        >>> mask_tensor = torch.tensor([[1, 1, 1, 0, 0],
        ...                             [1, 1, 0, 0, 0],
        ...                             [1, 1, 1, 1, 0]])
        >>> nested_tensor = NestedTensor.from_tensor_mask(padded_tensor, mask_tensor)
        >>> nested_tensor
        NestedTensor([
            [1, 2, 3],
            [4, 5],
            [6, 7, 8, 9]
        ])
    """
    # Default the dtype to the padded tensor's dtype, but let an explicit ``dtype=``
    # from the caller win. Passing ``dtype=tensor.dtype`` next to ``**kwargs`` would
    # raise a duplicate-keyword TypeError whenever the caller supplies ``dtype``.
    kwargs.setdefault("dtype", tensor.dtype)

    mask = mask.to(dtype=torch.bool)
    mask_value = kwargs.get("mask_value", False)
    # Normalise so that ``True`` always marks valid data, whatever ``mask_value`` is.
    effective_mask = ~mask if mask_value else mask

    if mask.ndim == 1:
        if batched:
            indices = effective_mask.nonzero(as_tuple=False).flatten()
            return cls([tensor[int(i)] for i in indices], **kwargs)
        return cls(tensor[effective_mask], **kwargs)
    # ndim >= 2: batch setup is shared, per-element trim differs by rank
    batch_first = kwargs.get("batch_first", True)
    tensor_iter = tensor if batch_first else tensor.transpose(0, 1)
    mask_iter = effective_mask if batch_first else effective_mask.transpose(0, 1)
    if tensor_iter.size(0) != mask_iter.size(0):
        raise ValueError("Tensor/mask batch dimension mismatch: " f"{tensor_iter.size(0)} vs {mask_iter.size(0)}")
    trimmed = []

    def _is_prefix_mask(mask_1d: Tensor) -> bool:
        # True when all valid entries come first, with no gaps.
        count = int(mask_1d.sum().item())
        prefix = torch.arange(mask_1d.size(0), device=mask_1d.device, dtype=torch.long) < count
        return bool(torch.equal(mask_1d, prefix))

    def _is_hierarchical_prefix_mask(mask_nd: Tensor) -> bool:
        # The leading dim must be a prefix of non-empty slices, and every such
        # slice must itself satisfy the same rule recursively.
        if mask_nd.dim() == 1:
            return _is_prefix_mask(mask_nd)
        leading_valid = mask_nd.reshape(mask_nd.size(0), -1).any(dim=1)
        valid_count = int(leading_valid.sum().item())
        prefix = torch.arange(mask_nd.size(0), device=mask_nd.device, dtype=torch.long) < valid_count
        if not torch.equal(leading_valid, prefix):
            return False
        return all(_is_hierarchical_prefix_mask(mask_nd[index]) for index in range(valid_count))

    if mask.ndim == 2:
        # 1-D per-element mask: only contiguous-prefix masks can be reconstructed
        # via slicing without changing dense semantics.
        counts = mask_iter.sum(dim=1, dtype=torch.long)
        prefix = torch.arange(mask_iter.size(1), device=mask_iter.device, dtype=torch.long).unsqueeze(0)
        prefix = prefix < counts.unsqueeze(1)
        if not torch.equal(mask_iter, prefix):
            raise ValueError(
                "from_tensor_mask() with 2-D masks requires each row to be a valid prefix mask; "
                "interior False gaps are not supported."
            )
        for t, count in zip(tensor_iter, counts.tolist()):
            trimmed.append(t[:count])
    else:
        # N-D per-element mask: only hierarchical ragged-prefix masks are representable as NestedTensor.
        padding_value = kwargs.get("padding_value", 0.0)
        # extents[b, d] = 1 + the largest valid index of element b along dim d (0 if none).
        extents = torch.zeros((mask_iter.size(0), mask_iter.dim() - 1), dtype=torch.long, device=mask_iter.device)
        nonzero = mask_iter.nonzero(as_tuple=False)
        if nonzero.numel() > 0:
            batch_index = nonzero[:, :1].expand(-1, extents.size(1))
            extents.scatter_reduce_(0, batch_index, nonzero[:, 1:] + 1, reduce="amax", include_self=False)
        extent_rows = extents.cpu().tolist()
        for t, em, sizes in zip(tensor_iter, mask_iter, extent_rows):
            if not _is_hierarchical_prefix_mask(em):
                raise ValueError(
                    "from_tensor_mask() with N-D masks requires each element mask to be a valid hierarchical "
                    "ragged prefix; "
                    "interior False gaps are not supported."
                )
            slices = tuple(slice(0, size) for size in sizes)
            t_slice = t[slices]
            m_slice = em[slices]
            valid_mask = m_slice
            if t_slice.dim() > m_slice.dim():
                # Add trailing singleton dims so the mask broadcasts over the extra tensor dims.
                valid_mask = m_slice.view(m_slice.shape + (1,) * (t_slice.dim() - m_slice.dim()))
            # Positions inside the bounding box that are still invalid get the padding value.
            trimmed.append(t_slice.masked_fill(~valid_mask, padding_value))
    return cls(trimmed, **kwargs)

nested_like

Python
nested_like(tensor: Tensor, strict: bool = True) -> Self

Create a new NestedTensor from a Tensor. The newly created NestedTensor will have the same shape as current NestedTensor.

Parameters:

Name Type Description Default
tensor
Tensor

The tensor to be converted to NestedTensor.

required
strict
bool

Check if the shape of tensor is the same as the current NestedTensor.

True

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
>>> (nested_tensor == nested_tensor.nested_like(nested_tensor)).all()
tensor(True)
>>> tensor = nested_tensor.tensor
>>> (nested_tensor == nested_tensor.nested_like(tensor)).all()
tensor(True)
>>> f = nested_tensor.nested_like(torch.randn(2, 2))
Traceback (most recent call last):
...
ValueError: The shape of NestedTensor and input tensor does not match, ...
>>> p = nested_tensor.nested_like(torch.randn(2, 2), False)
>>> p = nested_tensor.nested_like(torch.randn(3, 3), False)
Traceback (most recent call last):
...
ValueError: The batch size of NestedTensor and input tensor does not match, 2 != 3
Source code in danling/tensors/nested_tensor.py
Python
def nested_like(self, tensor: Tensor, strict: bool = True) -> Self:
    r"""
    Wrap a `Tensor` into a `NestedTensor` that shares this instance's layout.

    A `NestedTensor` input is simply cloned. A dense input is cut down to the
    per-element shapes of the current `NestedTensor`.

    Args:
        tensor: The tensor to be converted to `NestedTensor`.
        strict: Check if the shape of `tensor` is the same as the current `NestedTensor`.

    Raises:
        ValueError: If ``strict`` is set and the shapes differ, or if the batch sizes differ.

    Examples:
        >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
        >>> (nested_tensor == nested_tensor.nested_like(nested_tensor)).all()
        tensor(True)
        >>> tensor = nested_tensor.tensor
        >>> (nested_tensor == nested_tensor.nested_like(tensor)).all()
        tensor(True)
        >>> f = nested_tensor.nested_like(torch.randn(2, 2))
        Traceback (most recent call last):
        ...
        ValueError: The shape of NestedTensor and input tensor does not match, ...
        >>> p = nested_tensor.nested_like(torch.randn(2, 2), False)
        >>> p = nested_tensor.nested_like(torch.randn(3, 3), False)
        Traceback (most recent call last):
        ...
        ValueError: The batch size of NestedTensor and input tensor does not match, 2 != 3
    """

    if isinstance(tensor, NestedTensor):
        return tensor.clone()

    if strict and self.shape != tensor.shape:
        raise ValueError(
            f"The shape of NestedTensor and input tensor does not match, {self.shape} != {tensor.shape}"
        )
    if self.dim() <= 1 or self.batch_first:
        batch_axis = 0
    else:
        batch_axis = 1
    if len(self) != tensor.size(batch_axis):
        raise ValueError(
            "The batch size of NestedTensor and input tensor does not match, "
            f"{len(self)} != {tensor.size(batch_axis)}"
        )

    # Preferred route: convert the dense tensor straight to packed values and
    # reuse this instance's packing metadata.
    packed = self._dense_to_packed_values(tensor)
    if packed is not None:
        cached_shapes = self._element_shapes
        return self.__class__._from_packed(
            packed,
            self._offsets,
            self._physical_shape,
            batch_first=self.batch_first,
            padding_value=self.padding_value,
            mask_value=self.mask_value,
            pin_memory=self._pin_memory,
            outer_size=self._logical_shape,
            packed_sizes=self._packed_sizes,
            element_shapes=cached_shapes,
        )

    # Fallback: slice each element out of the dense tensor.
    padded = tensor.to(device=self.device)

    def element_index(position, dims):
        extents = [slice(0, int(size)) for size in dims]
        if self.batch_first:
            return (position, *extents)
        if not extents:
            return (position,)
        # Batch-second layout: the batch index sits after the first element dim.
        return (extents[0], position, *extents[1:])

    # .contiguous() keeps the stored elements from inheriting non-trivial
    # strides from the padded tensor (e.g. after transpose).
    rebuilt = [
        padded[element_index(position, dims)].contiguous()
        for position, dims in enumerate(self._original_shapes())
    ]
    return self.__class__(rebuilt, dtype=tensor.dtype, **self._meta(include_dtype=False))

to_torch_nested

Python
to_torch_nested() -> Tensor

Create a torch.nested.nested_tensor object from self.

Examples:

Python Console Session
1
2
3
4
5
6
>>> nested_tensor = NestedTensor([[2, 3, 5], [7, 8]])
>>> nt = nested_tensor.to_torch_nested()
>>> nt.layout == torch.jagged
True
>>> nt.values()
tensor([2, 3, 5, 7, 8])
Source code in danling/tensors/nested_tensor.py
Python
def to_torch_nested(self) -> Tensor:
    r"""
    Create a `torch.nested.nested_tensor` object from `self`.

    Examples:
        >>> nested_tensor = NestedTensor([[2, 3, 5], [7, 8]])
        >>> nt = nested_tensor.to_torch_nested()
        >>> nt.layout == torch.jagged
        True
        >>> nt.values()
        tensor([2, 3, 5, 7, 8])
    """
    storage = list(self._storage)
    if not storage or all(t.dim() > 0 for t in storage):
        return nested.nested_tensor(storage, layout=torch.jagged)
    return nested.nested_tensor(storage)

unbind

Python
unbind(dim: int = 0) -> tuple[Tensor, ...]

Unbind the NestedTensor.

Source code in danling/tensors/nested_tensor.py
Python
def unbind(self, dim: int = 0) -> tuple[Tensor, ...]:
    r"""
    Unbind the NestedTensor along `dim`.

    Delegates to `torch.unbind` and returns its result unchanged.
    """
    pieces = torch.unbind(self, dim=dim)
    return pieces

__getitem__

Python
__getitem__(
    index: (
        int | slice | list | tuple | Tensor | NestedTensor
    ),
) -> Tensor | NestedTensor

Retrieve element(s) by index, slice, list, tuple, or tensor mask.

Source code in danling/tensors/nested_tensor.py
Python
def __getitem__(self, index: int | slice | list | tuple | Tensor | NestedTensor) -> Tensor | NestedTensor:
    r"""
    Retrieve element(s) by index, slice, list, tuple, or tensor mask.

    Dispatch by the type of `index`:

    - `int`: return the stored element at that batch position.
    - `slice` / `list`: return a new NestedTensor built from the selected elements.
      A non-empty list made only of `bool` values is treated as a batch mask.
    - `tuple`: the first item selects along the batch, the remaining items are
      applied to every selected element. `Ellipsis` is expanded to full slices.
    - `NestedTensor`: each element of `self` is indexed by the matching element of `index`.
    - `Tensor`: 0-dim integer tensors select one element, 1-D bool/uint8 tensors act
      as a batch mask, 1-D integer tensors gather elements; any other tensor is
      converted with `self.nested_like(index, strict=False)` and applied per element.

    Raises:
        IndexError: If a boolean mask does not have one entry per batch element.
        ValueError: If the index (or the batch index inside a tuple) has an unsupported
            type, or if a NestedTensor index has a different batch length.
    """
    if isinstance(index, int):
        return self._storage[index]
    if isinstance(index, (slice, list)):
        # A non-empty list of bools is a mask over the batch; convert it to positions.
        if isinstance(index, list) and index and all(isinstance(i, bool) for i in index):
            if len(index) != len(self):
                raise IndexError(f"Boolean index has length {len(index)} but batch size is {len(self)}")
            index = [i for i, flag in enumerate(index) if flag]
        storage = tuple(self._storage[index] if isinstance(index, slice) else [self._storage[i] for i in index])
        return self.__class__(storage, **self._meta(include_dtype=True))
    if isinstance(index, tuple):
        # An empty tuple selects everything: hand back the same object.
        if len(index) == 0:
            return self

        # Expand Ellipsis: ``nt[..., :2]`` on a 4-D NestedTensor becomes
        # ``nt[:, :, :, :2]``.  The batch dim is consumed first, so Ellipsis
        # fills the gap between the number of explicit indices and the total
        # number of logical dimensions.
        if Ellipsis in index:
            eidx = index.index(Ellipsis)
            n_explicit = len(index) - 1  # exclude Ellipsis itself
            n_expand = self.dim() - n_explicit
            index = index[:eidx] + (slice(None),) * n_expand + index[eidx + 1 :]

        # First item addresses the batch; the remainder addresses each element.
        batch_index, *rest = index

        # Tensor-valued batch indices are delegated to ``self.tensor``
        # (presumably the dense/padded view — confirm against the property).
        if isinstance(batch_index, (Tensor, NestedTensor)):
            return self.tensor[index]

        # Same boolean-mask handling as the plain list branch above.
        if isinstance(batch_index, list) and batch_index and all(isinstance(i, bool) for i in batch_index):
            if len(batch_index) != len(self):
                raise IndexError(f"Boolean index has length {len(batch_index)} but batch size is {len(self)}")
            batch_index = [i for i, flag in enumerate(batch_index) if flag]

        if isinstance(batch_index, int):
            # Single element: apply the remaining indices to it and return a plain result.
            tensor = self._storage[batch_index]
            if rest:
                return tensor[tuple(rest)]
            return tensor
        elif isinstance(batch_index, (slice, list)):
            if isinstance(batch_index, slice):
                selected = self._storage[batch_index]
            else:
                selected = tuple(self._storage[i] for i in batch_index)
            # Apply the trailing indices to every selected element independently.
            if rest:
                rest_tuple = tuple(rest)
                selected = tuple(t[rest_tuple] for t in selected)
            return self.__class__(selected, **self._meta(include_dtype=True))
        raise ValueError(f"Unsupported batch index type {type(batch_index)}")
    # NOTE(review): checked before the generic Tensor branch; presumably NestedTensor
    # is a Tensor subclass, which would make the order significant — confirm.
    if isinstance(index, NestedTensor):
        if len(self) != len(index):
            raise ValueError(
                "NestedTensor batch length mismatch between self and index: "
                f"self={len(self)}, index={len(index)}"
            )
        # Pairwise indexing: element k of self is indexed by element k of index.
        return self.__class__(
            [t[i] for t, i in zip(self._storage, index._storage)], **self._meta(include_dtype=True)
        )
    if isinstance(index, Tensor):
        # 0-dim integer tensor behaves like a Python int.
        if index.dim() == 0 and index.dtype in (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8):
            return self._storage[int(index.item())]
        if index.dim() == 1:
            # 1-D bool and uint8 tensors are both interpreted as batch masks.
            if index.dtype in (torch.bool, torch.uint8):
                if index.numel() != len(self):
                    raise IndexError(f"Boolean index has length {index.numel()} but batch size is {len(self)}")
                selected = tuple(self._storage[i] for i, flag in enumerate(index.tolist()) if bool(flag))
                return self.__class__(selected, **self._meta(include_dtype=True))
            # 1-D integer tensor: gather the listed batch positions.
            # NOTE(review): uint8 never reaches this check for 1-D input because the
            # mask branch above already returned for that dtype.
            if index.dtype in (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8):
                return self.__class__(
                    [self._storage[int(i)] for i in index.tolist()],
                    **self._meta(include_dtype=True),
                )
        # Any other tensor is converted to a NestedTensor via nested_like and then
        # applied element by element.
        index = self.nested_like(index, strict=False)
        return self.__class__(
            [t[i] for t, i in zip(self._storage, index._storage)], **self._meta(include_dtype=True)
        )
    raise ValueError(f"Unsupported index type {type(index)}")

__setitem__

Python
__setitem__(
    index: int | slice | list | tuple,
    value: Tensor | NestedTensor,
) -> None

Set values in the NestedTensor at the specified index.

Parameters:

Name Type Description Default
index
int | slice | list | tuple

The index to modify. Can be an integer, slice, list, or tuple.

required
value
Tensor | NestedTensor

The new value to set. Can be a Tensor or NestedTensor.

required

Examples:

Python Console Session
1
2
3
4
5
6
7
>>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
>>> nested_tensor[0] = torch.tensor([6, 7, 8])
>>> nested_tensor[0]
tensor([6, 7, 8])
>>> nested_tensor[1] = torch.tensor([9, 10, 11, 12])
>>> nested_tensor.shape
torch.Size([2, 4])
Source code in danling/tensors/nested_tensor.py
Python
def __setitem__(self, index: int | slice | list | tuple, value: Tensor | NestedTensor) -> None:
    r"""
    Set values in the NestedTensor at the specified index.

    Integer indices replace one element and update the packed buffers
    (`_values`, `_offsets`, `_physical_shape`) directly. Slice and list indices
    replace whole elements. Tuple indices of length two or more write into a
    sub-region of each selected element.

    Args:
        index: The index to modify. Can be an integer, slice, list, or tuple.
        value: The new value to set. Can be a Tensor or NestedTensor.

    Raises:
        IndexError: If an integer index is out of range, or a boolean mask does not
            have one entry per batch element.
        ValueError: If the value has the wrong number of dimensions, the number of
            assigned values does not match the number of selected indices, or the
            index type is unsupported.

    Examples:
        >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
        >>> nested_tensor[0] = torch.tensor([6, 7, 8])
        >>> nested_tensor[0]
        tensor([6, 7, 8])
        >>> nested_tensor[1] = torch.tensor([9, 10, 11, 12])
        >>> nested_tensor.shape
        torch.Size([2, 4])
    """
    if isinstance(index, int):
        # NOTE(review): caches are invalidated before validation, so a failing
        # assignment still drops them.
        self._invalidate_transient_caches()
        # A NestedTensor value is only accepted if it wraps exactly one element.
        if isinstance(value, NestedTensor):
            if len(value._storage) != 1:
                raise ValueError(
                    f"When setting with an integer index, value must have a single tensor, but got {len(value)}"
                )
            value = value._storage[0]
        # Coerce the value to this NestedTensor's device and dtype.
        if not isinstance(value, Tensor):
            value = torch.tensor(value, device=self.device, dtype=self.dtype)
        else:
            value = value.to(device=self.device, dtype=self.dtype)
        if self.requires_grad:
            value.requires_grad_(True)

        # Normalize negative index
        idx = index + len(self) if index < 0 else index
        if idx < 0 or idx >= len(self):
            raise IndexError(f"index {index} is out of range for NestedTensor with {len(self)} elements")
        # Each row of _physical_shape holds one element's shape, so its width is the element ndim.
        expected_ndim = self._physical_shape.size(1)
        if value.dim() != expected_ndim:
            raise ValueError(
                f"Assigned tensor ndim must match existing ndim {expected_ndim}, but got {value.dim()}"
            )

        # Span currently occupied by element idx inside the packed _values buffer.
        old_start = int(self._offsets[idx].item())
        old_end = int(self._offsets[idx + 1].item())
        old_size = old_end - old_start
        new_shape_row = torch.tensor(list(value.shape), dtype=self._physical_shape.dtype)

        # Bring the value into the packed layout: permute if the stored permutation is
        # not the identity, then collapse to (packed_size, *static dims).
        permutation = self._permutation
        identity_permutation = tuple(range(expected_ndim))
        varying_dims = self._varying_dims
        static_dims = self._static_dims
        packed_size = type(self)._packed_size_from_shape(tuple(int(dim) for dim in value.shape), varying_dims)
        packed_value = value if permutation == identity_permutation else value.permute(permutation)
        suffix_shape = tuple(int(value.shape[dim]) for dim in static_dims)
        new_payload = packed_value.reshape((packed_size, *suffix_shape) if suffix_shape else (packed_size,))
        new_size = packed_size

        # Trailing (static) dims differ from the packed buffer: in-place splicing is
        # impossible, so rebuild everything from the per-element list.
        if self._values.dim() > 1 and new_payload.shape[1:] != self._values.shape[1:]:
            storage_list = list(self._storage)
            storage_list[idx] = value
            self._repack(storage_list)
            return

        if new_size == old_size:
            # Same packed span size: direct overwrite keeps _values allocation.
            self._values[old_start:old_end] = new_payload
            self._physical_shape[idx] = new_shape_row
        else:
            # Different packed span size: splice _values and shift subsequent offsets.
            self._values = torch.cat([self._values[:old_start], new_payload, self._values[old_end:]], dim=0)
            delta = new_size - old_size
            # Clone before mutating so the previous offsets/shape tensors are left untouched.
            self._offsets = self._offsets.clone()
            self._offsets[idx + 1 :] += delta  # noqa: E203
            self._physical_shape = self._physical_shape.clone()
            self._physical_shape[idx] = new_shape_row
        # Recompute the logical (outer) shape from the updated per-element shapes.
        self._logical_shape = self._logical_shape_from_physical_shape(
            self._physical_shape, self._offsets, self.batch_first
        )
        # Keep the optional Python-side metadata in sync when both pieces are present.
        if self._element_shapes is not None and self._packed_sizes is not None:
            element_shapes = list(self._element_shapes)
            element_shapes[idx] = tuple(int(dim) for dim in value.shape)
            self._element_shapes = tuple(element_shapes)
            packed_sizes = list(self._packed_sizes)
            packed_sizes[idx] = self._packed_sizes_like((self._element_shapes[idx],))[0]
            self._packed_sizes = tuple(packed_sizes)
        self._validate_metadata()
    elif isinstance(index, (slice, list)):
        # A non-empty list of bools is a mask over the batch; convert it to positions.
        if isinstance(index, list) and index and all(isinstance(i, bool) for i in index):
            if len(index) != len(self):
                raise IndexError(f"Boolean index has length {len(index)} but batch size is {len(self)}")
            index = [i for i, flag in enumerate(index) if flag]

        # Wrap a plain tensor: split along dim 0 when it has more than one row,
        # otherwise treat it as a single element.
        if isinstance(value, Tensor) and not isinstance(value, NestedTensor):
            if value.dim() > 1 and value.size(0) > 1:
                value = self.__class__(value.unbind(0), **self._meta())
            else:
                value = self.__class__([value], **self._meta())

        if isinstance(index, slice):
            start, stop, step = index.indices(len(self))
            indices = range(start, stop, step)
        else:
            indices = index  # type: ignore[assignment]

        # One replacement element is required per selected batch position.
        if len(indices) != len(value._storage):
            raise ValueError(
                f"Size mismatch: tried to assign {len(value._storage)} values to {len(indices)} indices"
            )

        storage_list = list(self._storage)
        for i, idx in enumerate(indices):
            storage_list[idx] = value._storage[i]
        # NOTE(review): presumably assigning to _storage repacks the buffers (the int
        # branch calls _repack explicitly) — confirm against the _storage setter.
        self._storage = tuple(storage_list)
    elif isinstance(index, tuple):
        # Empty tuple: nothing to assign. Single-item tuple: same as the bare index.
        if len(index) == 0:
            return
        if len(index) == 1:
            self[index[0]] = value
            return

        # First item selects batch positions; the rest indexes inside each element.
        first_idx, rest_idx = index[0], index[1:]
        batch_indices: list[int]
        if isinstance(first_idx, int):
            batch_indices = [first_idx]
        elif isinstance(first_idx, (slice, list)):
            if isinstance(first_idx, list) and first_idx and all(isinstance(i, bool) for i in first_idx):
                if len(first_idx) != len(self):
                    raise IndexError(f"Boolean index has length {len(first_idx)} but batch size is {len(self)}")
                batch_indices = [i for i, flag in enumerate(first_idx) if flag]
            elif isinstance(first_idx, slice):
                start, stop, step = first_idx.indices(len(self))
                batch_indices = list(range(start, stop, step))
            else:
                batch_indices = list(first_idx)  # type: ignore[arg-type]
        else:
            raise ValueError(f"Unsupported first index type {type(first_idx)}")

        # A NestedTensor value supplies one tensor per selected element; any other
        # value is reused for every selected element.
        if isinstance(value, NestedTensor):
            if len(batch_indices) != len(value._storage):
                raise ValueError(
                    f"Size mismatch: tried to assign {len(value._storage)} values to {len(batch_indices)} indices"
                )
            assigned_values = list(value._storage)
        else:
            assigned_values = [value] * len(batch_indices)

        elems = list(self._storage)
        for position, idx in enumerate(batch_indices):
            # Write into a clone so the tensor previously held in storage is not mutated.
            elem = elems[idx].clone()
            elem[rest_idx] = assigned_values[position]
            elems[idx] = elem
        self._storage = tuple(elems)
    else:
        raise ValueError(f"Unsupported index type {type(index)}")

__copy__

Python
__copy__()

Shallow copy: new NestedTensor sharing underlying tensor data.

Source code in danling/tensors/nested_tensor.py
Python
def __copy__(self):
    r"""
    Shallow copy: new NestedTensor sharing underlying tensor data.

    The packed values, offsets and physical-shape tensors are passed to
    `_from_packed` as-is (no clone), together with the current metadata.
    """
    cls = self.__class__
    shared_values = self._values
    shared_offsets = self._offsets
    shared_physical_shape = self._physical_shape
    metadata = {
        "permutation": self._permutation,
        "batch_first": self.batch_first,
        "padding_value": self.padding_value,
        "mask_value": self.mask_value,
        "pin_memory": self._pin_memory,
        "outer_size": self._logical_shape,
        "packed_sizes": self._packed_sizes,
        "element_shapes": self._element_shapes,
    }
    return cls._from_packed(shared_values, shared_offsets, shared_physical_shape, **metadata)

__deepcopy__

Python
__deepcopy__(memo)

Deep copy: clones all tensor data.

Source code in danling/tensors/nested_tensor.py
Python
def __deepcopy__(self, memo):
    r"""
    Deep copy: clones all tensor data.

    The packed values, offsets and physical-shape tensors are cloned before
    being handed to `_from_packed`. The new object is recorded in `memo`
    under `id(self)` before it is returned.
    """
    cloned_values = self._values.clone()
    cloned_offsets = self._offsets.clone()
    cloned_physical_shape = self._physical_shape.clone()
    metadata = {
        "permutation": self._permutation,
        "batch_first": self.batch_first,
        "padding_value": self.padding_value,
        "mask_value": self.mask_value,
        "pin_memory": self._pin_memory,
        "outer_size": self._logical_shape,
        "packed_sizes": self._packed_sizes,
        "element_shapes": self._element_shapes,
    }
    duplicate = self.__class__._from_packed(cloned_values, cloned_offsets, cloned_physical_shape, **metadata)
    memo[id(self)] = duplicate
    return duplicate

all

Python
all(
    dim: int | None = None, keepdim: bool = False
) -> bool | Tensor | NestedTensor

Tests if all elements in NestedTensor evaluate to True.

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.ones(2, 4, dtype=torch.bool), torch.ones(3, 5, dtype=torch.bool)])
>>> nested_tensor.all()
tensor(True)
>>> nested_tensor.all(dim=0)
tensor([True, True])
>>> nested_tensor.all(dim=0, keepdim=True)
tensor([[True, True]])
>>> nested_tensor.all(dim=1)
NestedTensor([
    [True, True, True, True],
    [True, True, True, True, True]
])
>>> nested_tensor.all(dim=1, keepdim=True)
NestedTensor([
    [[True, True, True, True]],
    [[True, True, True, True, True]]
])
>>> nested_tensor.batch_first = False
>>> nested_tensor.all(dim=1)
tensor([True, True])
>>> nested_tensor.all(dim=0)
NestedTensor([
    [True, True, True, True],
    [True, True, True, True, True]
])
>>> nested_tensor.all(dim=-2)
tensor([True, True])
Source code in danling/tensors/nested_tensor.py
Python
def all(self, dim: int | None = None, keepdim: bool = False) -> bool | Tensor | NestedTensor:
    r"""
    Tests if all elements in NestedTensor evaluate to True.

    Forwards `dim` and `keepdim` to `torch.all` and returns its result.

    Examples:
        >>> nested_tensor = NestedTensor([torch.ones(2, 4, dtype=torch.bool), torch.ones(3, 5, dtype=torch.bool)])
        >>> nested_tensor.all()
        tensor(True)
        >>> nested_tensor.all(dim=0)
        tensor([True, True])
        >>> nested_tensor.all(dim=0, keepdim=True)
        tensor([[True, True]])
        >>> nested_tensor.all(dim=1)
        NestedTensor([
            [True, True, True, True],
            [True, True, True, True, True]
        ])
        >>> nested_tensor.all(dim=1, keepdim=True)
        NestedTensor([
            [[True, True, True, True]],
            [[True, True, True, True, True]]
        ])
        >>> nested_tensor.batch_first = False
        >>> nested_tensor.all(dim=1)
        tensor([True, True])
        >>> nested_tensor.all(dim=0)
        NestedTensor([
            [True, True, True, True],
            [True, True, True, True, True]
        ])
        >>> nested_tensor.all(dim=-2)
        tensor([True, True])
    """
    reduce_options = {"dim": dim, "keepdim": keepdim}
    return torch.all(self, **reduce_options)

any

Python
any(
    dim: int | None = None, keepdim: bool = False
) -> bool | Tensor | NestedTensor

Tests if any elements in NestedTensor evaluate to True.

Examples:

Python Console Session
1
2
3
4
5
>>> nested_tensor = NestedTensor([torch.zeros(2, dtype=torch.bool), torch.ones(3, dtype=torch.bool)])
>>> nested_tensor.any()
tensor(True)
>>> nested_tensor.any(dim=0)
tensor([False,  True])
Source code in danling/tensors/nested_tensor.py
Python
def any(self, dim: int | None = None, keepdim: bool = False) -> bool | Tensor | NestedTensor:
    r"""
    Tests if any elements in NestedTensor evaluate to True.

    Forwards `dim` and `keepdim` to `torch.any` and returns its result.

    Examples:
        >>> nested_tensor = NestedTensor([torch.zeros(2, dtype=torch.bool), torch.ones(3, dtype=torch.bool)])
        >>> nested_tensor.any()
        tensor(True)
        >>> nested_tensor.any(dim=0)
        tensor([False,  True])
    """
    reduce_options = {"dim": dim, "keepdim": keepdim}
    return torch.any(self, **reduce_options)

dim

Python
dim() -> int

Number of dimensions of the NestedTensor.

Examples:

Python Console Session
1
2
3
>>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
>>> nested_tensor.dim()
2
Source code in danling/tensors/nested_tensor.py
Python
def dim(self) -> int:
    r"""
    Number of dimensions of the NestedTensor.

    Uses the length of `_logical_shape` when that attribute exists; otherwise
    falls back to the length of `torch.Tensor.size(self)` with torch-function
    subclass dispatch disabled.

    Examples:
        >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
        >>> nested_tensor.dim()
        2
    """
    if hasattr(self, "_logical_shape"):
        return len(self._logical_shape)
    # No logical shape recorded on this object: ask the base Tensor implementation.
    with torch._C.DisableTorchFunctionSubclass():
        base_size = torch.Tensor.size(self)
        return len(base_size)

max

Python
max(
    dim: int | None = None, keepdim: bool = False
) -> Tensor | NestedTensor

Return the maximum value, optionally along a given dimension.

Source code in danling/tensors/nested_tensor.py
Python
def max(self, dim: int | None = None, keepdim: bool = False) -> Tensor | NestedTensor:
    r"""
    Return the maximum value, optionally along a given dimension.

    With `dim=None` this calls `torch.max(self)` and `keepdim` is not forwarded;
    otherwise `dim` and `keepdim` are both passed to `torch.max`.
    """
    if dim is not None:
        return torch.max(self, dim=dim, keepdim=keepdim)
    return torch.max(self)

mean

Python
mean(
    dim: int | None = None,
    keepdim: bool = False,
    *,
    dtype: dtype | None = None
) -> Tensor | NestedTensor

Return the mean value, optionally along a given dimension.

Source code in danling/tensors/nested_tensor.py
Python
def mean(
    self,
    dim: int | None = None,
    keepdim: bool = False,
    *,
    dtype: torch.dtype | None = None,  # type: ignore[name-defined]
) -> Tensor | NestedTensor:
    r"""
    Return the mean value, optionally along a given dimension.

    Forwards `dim`, `keepdim` and `dtype` to `torch.mean` and returns its result.
    """
    reduce_options = {"dim": dim, "keepdim": keepdim, "dtype": dtype}
    return torch.mean(self, **reduce_options)

min

Python
min(
    dim: int | None = None, keepdim: bool = False
) -> Tensor | NestedTensor

Return the minimum value, optionally along a given dimension.

Source code in danling/tensors/nested_tensor.py
Python
def min(self, dim: int | None = None, keepdim: bool = False) -> Tensor | NestedTensor:
    r"""
    Return the minimum value, optionally along a given dimension.

    With `dim=None` this calls `torch.min(self)` and `keepdim` is not forwarded;
    otherwise `dim` and `keepdim` are both passed to `torch.min`.
    """
    if dim is not None:
        return torch.min(self, dim=dim, keepdim=keepdim)
    return torch.min(self)

numel

Python
numel() -> int

Number of elements in the NestedTensor.

Examples:

Python Console Session
1
2
3
>>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
>>> nested_tensor.numel()
5
Source code in danling/tensors/nested_tensor.py
Python
def numel(self) -> int:
    r"""
    Number of elements in the NestedTensor.

    Reported as the element count of the packed `_values` tensor.

    Examples:
        >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
        >>> nested_tensor.numel()
        5
    """
    packed_values = self._values
    return packed_values.numel()

permute

Python
permute(*dims) -> Self

Apply permutation to each tensor in the NestedTensor.

Parameters:

Name Type Description Default
*dims

The desired ordering of dimensions for the NestedTensor (including batch dimension).

()

Returns:

Name Type Description
NestedTensor Self

A new NestedTensor with each tensor permuted.

Examples:

Python Console Session
1
2
3
4
>>> nested_tensor = NestedTensor([torch.randn(3, 4, 5), torch.randn(2, 4, 5)])
>>> permuted = nested_tensor.permute(0, 3, 1, 2)
>>> permuted.shape
torch.Size([2, 5, 3, 4])
Source code in danling/tensors/nested_tensor.py
Python
def permute(self, *dims) -> Self:
    r"""
    Apply permutation to each tensor in the NestedTensor.

    The positional arguments are collected into a tuple and forwarded to
    `torch.permute`.

    Args:
        *dims: The desired ordering of dimensions for the NestedTensor (including batch dimension).

    Returns:
        NestedTensor: A new NestedTensor with each tensor permuted.

    Examples:
        >>> nested_tensor = NestedTensor([torch.randn(3, 4, 5), torch.randn(2, 4, 5)])
        >>> permuted = nested_tensor.permute(0, 3, 1, 2)
        >>> permuted.shape
        torch.Size([2, 5, 3, 4])
    """
    dim_order = dims
    return torch.permute(self, dim_order)

moveaxis

Python
moveaxis(source, destination) -> Self

Move per-element dimensions to new positions.

Source code in danling/tensors/nested_tensor.py
Python
def moveaxis(self, source, destination) -> Self:
    r"""
    Move per-element dimensions to new positions.

    Forwards `source` and `destination` to `torch.moveaxis`.
    """
    moved = torch.moveaxis(self, source, destination)
    return moved

movedim

Python
movedim(source, destination) -> Self

Alias for moveaxis().

Source code in danling/tensors/nested_tensor.py
Python
def movedim(self, source, destination) -> Self:
    r"""
    Alias for `moveaxis()`.

    Forwards `source` and `destination` to `torch.movedim`.
    """
    moved = torch.movedim(self, source, destination)
    return moved

pin_memory

Python
pin_memory() -> Self

Pin the underlying tensor memory for faster host-to-device transfer.

Source code in danling/tensors/nested_tensor.py
Python
def pin_memory(self) -> Self:
    r"""
    Pin the underlying tensor memory for faster host-to-device transfer.

    Returns a new NestedTensor built with `_from_packed` from the pinned copy of
    `_values`; offsets, physical shape and the remaining metadata are reused
    unchanged, and the `pin_memory` flag of the result is set to `True`.

    Returns:
        NestedTensor: A new NestedTensor whose packed values live in pinned memory.
    """
    return type(self)._from_packed(
        self._values.pin_memory(),
        self._offsets,
        self._physical_shape,
        # Forward the packed-dimension permutation explicitly, as `__copy__` and
        # `__deepcopy__` do, so the layout of `_values` is carried over rather than
        # left to whatever `_from_packed` does when the argument is omitted.
        permutation=self._permutation,
        batch_first=self.batch_first,
        padding_value=self.padding_value,
        mask_value=self.mask_value,
        pin_memory=True,
        outer_size=self._logical_shape,
        packed_sizes=self._packed_sizes,
        element_shapes=self._element_shapes,
    )

prod

Python
prod(
    dim: int | None = None,
    keepdim: bool = False,
    *,
    dtype: dtype | None = None
) -> Tensor | NestedTensor

Return the product of elements, optionally along a given dimension.

Source code in danling/tensors/nested_tensor.py
Python
def prod(
    self,
    dim: int | None = None,
    keepdim: bool = False,
    *,
    dtype: torch.dtype | None = None,  # type: ignore[name-defined]
) -> Tensor | NestedTensor:
    r"""
    Return the product of elements, optionally along a given dimension.

    Forwards `dim`, `keepdim` and `dtype` to `torch.prod` and returns its result.
    """
    reduce_options = {"dim": dim, "keepdim": keepdim, "dtype": dtype}
    return torch.prod(self, **reduce_options)

requires_grad_

Python
requires_grad_(requires_grad: bool = True)

Enable or disable gradient computation in-place.

Source code in danling/tensors/nested_tensor.py
Python
def requires_grad_(self, requires_grad: bool = True):
    r"""
    Enable or disable gradient computation in-place.

    Assigns `requires_grad` to the `requires_grad` attribute and returns `self`
    so calls can be chained.
    """
    setattr(self, "requires_grad", requires_grad)
    return self

reshape

Python
reshape(*shape) -> Self

Reshape each tensor in the NestedTensor.

Parameters:

Name Type Description Default
*shape

The desired size of each dimension for the underlying tensors.

()

Returns:

Name Type Description
NestedTensor Self

A new NestedTensor with each tensor reshaped.

Examples:

Python Console Session
1
2
3
4
>>> nested_tensor = NestedTensor([torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6], [7, 8]])])
>>> reshaped = nested_tensor.reshape(4)
>>> reshaped.shape
torch.Size([2, 4])
Source code in danling/tensors/nested_tensor.py
Python
def reshape(self, *shape) -> Self:
    r"""
    Reshape each tensor in the NestedTensor.

    The shape may be given either as separate integers or as one
    tuple / list / `torch.Size`; both forms are forwarded to `torch.reshape`.

    Args:
        *shape: The desired size of each dimension for the underlying tensors.

    Returns:
        NestedTensor: A new NestedTensor with each tensor reshaped.

    Raises:
        TypeError: If no shape is given.

    Examples:
        >>> nested_tensor = NestedTensor([torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6], [7, 8]])])
        >>> reshaped = nested_tensor.reshape(4)
        >>> reshaped.shape
        torch.Size([2, 4])
    """
    if len(shape) == 0:
        raise TypeError("reshape() missing shape")
    # A single sequence argument is unwrapped; separate ints are used as they are.
    given_as_sequence = len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size))
    if given_as_sequence:
        target_shape = shape[0]
    else:
        target_shape = shape
    return torch.reshape(self, target_shape)

flatten

Python
flatten(start_dim: int = 0, end_dim: int = -1)

Flatten each tensor in the NestedTensor.

Source code in danling/tensors/nested_tensor.py
Python
def flatten(self, start_dim: int = 0, end_dim: int = -1):
    r"""
    Flatten each tensor in the NestedTensor.

    Forwards `start_dim` and `end_dim` to `torch.flatten`.
    """
    dim_range = {"start_dim": start_dim, "end_dim": end_dim}
    return torch.flatten(self, **dim_range)

flip

Python
flip(dims) -> Self

Flip each tensor in the NestedTensor along the given dimensions.

Source code in danling/tensors/nested_tensor.py
Python
def flip(self, dims) -> Self:
    r"""
    Flip each tensor in the NestedTensor along the given dimensions.

    Forwards `dims` to `torch.flip`.
    """
    flipped = torch.flip(self, dims)
    return flipped

size

Python
size(dim: int | None = None) -> Size | int

Returns the size of the self NestedTensor.

Parameters:

Name Type Description Default
dim
int | None

If not specified, the returned value is a torch.Size, a subclass of tuple. If specified, returns an int holding the size of that dimension. Defaults to None.

None

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
>>> nested_tensor.size()
torch.Size([2, 3])
>>> nested_tensor.size(0)
2
>>> nested_tensor[1] = torch.tensor([4, 5, 6, 7])
>>> nested_tensor.shape
torch.Size([2, 4])
>>> nested_tensor.size(1)
4
Source code in danling/tensors/nested_tensor.py
Python
def size(self, dim: int | None = None) -> torch.Size | int:  # type: ignore[override, name-defined]
    r"""
    Returns the size of the self `NestedTensor`.

    The size is taken from `_logical_shape` when that attribute exists; otherwise
    it falls back to `torch.Tensor.size(self)` with torch-function subclass
    dispatch disabled.

    Args:
        dim: If not specified, the returned value is a `torch.Size`, a subclass of `tuple`.
            If specified, returns an `int` holding the size of that dimension.
            Negative values count from the last dimension.
            Defaults to `None`.

    Raises:
        IndexError: If `dim` is outside `[-ndim, ndim - 1]`.

    Examples:
        >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
        >>> nested_tensor.size()
        torch.Size([2, 3])
        >>> nested_tensor.size(0)
        2
        >>> nested_tensor[1] = torch.tensor([4, 5, 6, 7])
        >>> nested_tensor.shape
        torch.Size([2, 4])
        >>> nested_tensor.size(1)
        4
    """
    if hasattr(self, "_logical_shape"):
        full_size = self._logical_shape
    else:
        with torch._C.DisableTorchFunctionSubclass():
            full_size = torch.Tensor.size(self)
    if dim is None:
        return full_size
    ndim = len(full_size)
    # Validate before indexing: adding ndim to an out-of-range negative dim
    # (e.g. -3 on a 2-D shape) would still be a valid negative tuple index
    # and silently return the size of the wrong dimension.
    if not -ndim <= dim < ndim:
        raise IndexError(f"Dimension out of range for a {ndim}-dimensional shape, but got {dim}")
    return full_size[dim]

sum

Python
sum(
    dim: int | Sequence[int] | None = None,
    keepdim: bool = False,
    *,
    dtype: dtype | None = None
) -> Tensor | NestedTensor

Returns the sum of each tensor over the given dimension(s).

Parameters:

Name Type Description Default
dim
int | Sequence[int] | None

The dimension or dimensions to reduce. If None, sum over all dimensions. Supports int, Sequence[int], or None. Negative dimensions are supported.

None
keepdim
bool

Whether to retain reduced dimensions with size 1.

False
dtype
dtype | None

The desired data type of returned tensor.

None

Returns:

Type Description
Tensor | NestedTensor

Tensor or NestedTensor depending on the dimensions being reduced.

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
>>> nested_tensor.sum()
tensor(15)
>>> nested_tensor.sum(dim=0)  # when dim=0, sum across batch dimension
tensor([6, 9])
>>> nested_tensor.sum(dim=1)
tensor([6, 9])
>>> nested_tensor.sum(dim=[0, 1])
tensor(15)
>>> nested_tensor.sum(dim=0, keepdim=True)
tensor([[6, 9]])
>>> nested_tensor.sum(dtype=torch.float32)
tensor(15.)
Source code in danling/tensors/nested_tensor.py
Python
def sum(
    self,
    dim: int | Sequence[int] | None = None,
    keepdim: bool = False,
    *,
    dtype: torch.dtype | None = None,  # type: ignore[name-defined]
) -> Tensor | NestedTensor:
    r"""
    Returns the sum of each tensor over the given dimension(s).

    Forwards `dim`, `keepdim` and `dtype` to `torch.sum` and returns its result.

    Args:
        dim: The dimension or dimensions to reduce. If None, sum over all dimensions.
            Supports int, Sequence[int], or None. Negative dimensions are supported.
        keepdim: Whether to retain reduced dimensions with size 1.
        dtype: The desired data type of returned tensor.

    Returns:
        Tensor or NestedTensor depending on the dimensions being reduced.

    Examples:
        >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
        >>> nested_tensor.sum()
        tensor(15)
        >>> nested_tensor.sum(dim=0)  # when dim=0, sum across batch dimension
        tensor([6, 9])
        >>> nested_tensor.sum(dim=1)
        tensor([6, 9])
        >>> nested_tensor.sum(dim=[0, 1])
        tensor(15)
        >>> nested_tensor.sum(dim=0, keepdim=True)
        tensor([[6, 9]])
        >>> nested_tensor.sum(dtype=torch.float32)
        tensor(15.)
    """
    reduce_options = {"dim": dim, "keepdim": keepdim, "dtype": dtype}
    return torch.sum(self, **reduce_options)

tolist

Python
tolist() -> list

Convert a NestedTensor to a list of lists of values.

Examples:

Python Console Session
1
2
3
>>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
>>> nested_tensor.tolist()
[[1, 2, 3], [4, 5]]
Source code in danling/tensors/nested_tensor.py
Python
def tolist(self) -> list:
    r"""
    Convert a NestedTensor to a list of lists of values.

    Each stored element is converted with its own `tolist()`, in storage order.

    Examples:
        >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
        >>> nested_tensor.tolist()
        [[1, 2, 3], [4, 5]]
    """
    rows = []
    for element in self._storage:
        rows.append(element.tolist())
    return rows

transpose

Python
transpose(dim0: int, dim1: int) -> Self

Transpose dimensions dim0 and dim1 for each tensor in the NestedTensor.

Parameters:

Name Type Description Default
dim0
int

First dimension to transpose (in NestedTensor coordinate system).

required
dim1
int

Second dimension to transpose (in NestedTensor coordinate system).

required

Returns:

Name Type Description
NestedTensor Self

A new NestedTensor with each tensor transposed.

Examples:

Python Console Session
1
2
3
4
5
>>> nested_tensor = NestedTensor([torch.randn(3, 4), torch.randn(2, 4)])
>>> # NestedTensor shape is [2, 3, 4], underlying tensors are [3, 4] and [2, 4]
>>> transposed = nested_tensor.transpose(1, 2)  # transpose dims 1 and 2
>>> transposed.shape  # batch dimension is still first
torch.Size([2, 4, 3])
Source code in danling/tensors/nested_tensor.py
Python
def transpose(self, dim0: int, dim1: int) -> Self:  # type: ignore[valid-type]
    r"""
    Swap two dimensions of every tensor stored in the NestedTensor.

    The call is forwarded to `torch.transpose` with the NestedTensor as input.

    Args:
        dim0: First dimension to transpose (in NestedTensor coordinate system).
        dim1: Second dimension to transpose (in NestedTensor coordinate system).

    Returns:
        NestedTensor: A new NestedTensor with each tensor transposed.

    Examples:
        >>> nested_tensor = NestedTensor([torch.randn(3, 4), torch.randn(2, 4)])
        >>> # NestedTensor shape is [2, 3, 4], underlying tensors are [3, 4] and [2, 4]
        >>> transposed = nested_tensor.transpose(1, 2)  # transpose dims 1 and 2
        >>> transposed.shape  # batch dimension is still first
        torch.Size([2, 4, 3])
    """
    swapped = torch.transpose(self, dim0, dim1)
    return swapped

swapaxes

Python
swapaxes(axis0: int, axis1: int) -> Self

Alias for transpose().

Source code in danling/tensors/nested_tensor.py
Python
def swapaxes(self, axis0: int, axis1: int) -> Self:
    r"""Exchange `axis0` and `axis1` via `torch.swapaxes`; an alias for `transpose()`."""
    swapped = torch.swapaxes(self, axis0, axis1)
    return swapped

swapdims

Python
swapdims(dim0: int, dim1: int) -> Self

Alias for swapaxes().

Source code in danling/tensors/nested_tensor.py
Python
def swapdims(self, dim0: int, dim1: int) -> Self:
    r"""Exchange `dim0` and `dim1` via `torch.swapdims`; an alias for `swapaxes()`."""
    swapped = torch.swapdims(self, dim0, dim1)
    return swapped

squeeze

Python
squeeze(dim: int | None = None) -> Self

Squeeze singleton dimensions from each tensor in the NestedTensor.

Source code in danling/tensors/nested_tensor.py
Python
def squeeze(self, dim: int | None = None) -> Self:  # type: ignore[valid-type]
    r"""Drop singleton dimensions from each tensor in the NestedTensor via `torch.squeeze`."""
    squeezed = torch.squeeze(self, dim=dim)
    return squeezed

unsqueeze

Python
unsqueeze(dim: int) -> Self

Unsqueeze each tensor in the NestedTensor by adding a singleton dimension at the specified position.

Parameters:

Name Type Description Default
dim
int

The dimension at which to add the singleton dimension. This is in the NestedTensor’s coordinate system (where dim 0 is the batch dimension).

required

Returns:

Name Type Description
NestedTensor Self

A new NestedTensor with each tensor unsqueezed at the specified dimension.

Examples:

Python Console Session
1
2
3
4
5
6
>>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
>>> # Original shape: [2, 3] (batch_size=2, max_seq_len=3)
>>> unsqueezed = nested_tensor.unsqueeze(1)
>>> unsqueezed.shape
torch.Size([2, 1, 3])
>>> # Now each underlying tensor has shape [1, seq_len] instead of [seq_len]
Python Console Session
1
2
3
4
5
6
>>> nested_tensor_2d = NestedTensor([torch.randn(3, 4), torch.randn(2, 4)])
>>> # Original shape: [2, 3, 4] (batch_size=2, max_len1=3, max_len2=4)
>>> unsqueezed_2d = nested_tensor_2d.unsqueeze(2)
>>> unsqueezed_2d.shape
torch.Size([2, 3, 1, 4])
>>> # Now each underlying tensor has shape [len1, 1, len2] instead of [len1, len2]
Source code in danling/tensors/nested_tensor.py
Python
def unsqueeze(self, dim: int) -> Self:  # type: ignore[valid-type]
    r"""
    Insert a singleton dimension into each tensor of the NestedTensor.

    The call is forwarded to `torch.unsqueeze` with the NestedTensor as input.

    Args:
        dim: Position of the new singleton dimension, expressed in the NestedTensor's
            coordinate system (where dim 0 is the batch dimension).

    Returns:
        NestedTensor: A new NestedTensor with each tensor unsqueezed at the specified dimension.

    Examples:
        >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
        >>> # Original shape: [2, 3] (batch_size=2, max_seq_len=3)
        >>> unsqueezed = nested_tensor.unsqueeze(1)
        >>> unsqueezed.shape
        torch.Size([2, 1, 3])
        >>> # Now each underlying tensor has shape [1, seq_len] instead of [seq_len]

        >>> nested_tensor_2d = NestedTensor([torch.randn(3, 4), torch.randn(2, 4)])
        >>> # Original shape: [2, 3, 4] (batch_size=2, max_len1=3, max_len2=4)
        >>> unsqueezed_2d = nested_tensor_2d.unsqueeze(2)
        >>> unsqueezed_2d.shape
        torch.Size([2, 3, 1, 4])
        >>> # Now each underlying tensor has shape [len1, 1, len2] instead of [len1, len2]
    """
    expanded = torch.unsqueeze(self, dim)
    return expanded

unflatten

Python
unflatten(dim: int, sizes) -> Self

Unflatten one dimension of each tensor in the NestedTensor.

Source code in danling/tensors/nested_tensor.py
Python
def unflatten(self, dim: int, sizes) -> Self:  # type: ignore[valid-type]
    r"""Expand dimension `dim` of each tensor in the NestedTensor into `sizes` via `torch.unflatten`."""
    reshaped = torch.unflatten(self, dim, sizes)
    return reshaped

roll

Python
roll(shifts, dims=None) -> Self

Roll each tensor in the NestedTensor along the given dimensions.

Source code in danling/tensors/nested_tensor.py
Python
def roll(self, shifts, dims=None) -> Self:
    r"""Shift each tensor in the NestedTensor by `shifts` along `dims` via `torch.roll`."""
    rolled = torch.roll(self, shifts, dims=dims)
    return rolled

rot90

Python
rot90(k: int = 1, dims: Sequence[int] = (0, 1)) -> Self

Rotate each tensor in the NestedTensor by 90 degrees in the given plane.

Source code in danling/tensors/nested_tensor.py
Python
def rot90(self, k: int = 1, dims: Sequence[int] = (0, 1)) -> Self:
    r"""Apply `torch.rot90` with `k` quarter turns in the plane `dims` to each tensor in the NestedTensor."""
    rotated = torch.rot90(self, k, dims)
    return rotated

view

Python
view(*shape) -> Self

View each tensor in the NestedTensor with a different shape.

Parameters:

Name Type Description Default
*shape

The desired size of each dimension for the underlying tensors.

()

Returns:

Name Type Description
NestedTensor Self

A new NestedTensor with each tensor viewed with the new shape.

Examples:

Python Console Session
1
2
3
4
5
6
>>> nested_tensor = NestedTensor([torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6], [7, 8]])])
>>> viewed = nested_tensor.view(4)  # View each 2x2 tensor as 4
>>> viewed.shape
torch.Size([2, 4])
>>> type(viewed).__name__
'NestedTensor'
Source code in danling/tensors/nested_tensor.py
Python
def view(self, *shape) -> Self:
    r"""
    Reinterpret each tensor in the NestedTensor with a new shape.

    The shape may be passed either as separate integers or as a single
    tuple / list / `torch.Size`; both spellings are normalised to a list before
    the registered `aten.view` handler is invoked.

    Args:
        *shape: The desired size of each dimension for the underlying tensors.

    Returns:
        NestedTensor: A new NestedTensor with each tensor viewed with the new shape.

    Raises:
        TypeError: If no shape is given.

    Examples:
        >>> nested_tensor = NestedTensor([torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6], [7, 8]])])
        >>> viewed = nested_tensor.view(4)  # View each 2x2 tensor as 4
        >>> viewed.shape
        torch.Size([2, 4])
        >>> type(viewed).__name__
        'NestedTensor'
    """
    if len(shape) == 0:
        raise TypeError("view() missing shape")
    # A single sequence argument (`view((2, 2))`) is unpacked; otherwise the varargs are the shape.
    packed_as_sequence = len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size))
    if packed_as_sequence:
        target_shape = shape[0]
    else:
        target_shape = shape
    handler = NestedTensorAtenRegistry[torch.ops.aten.view.default]
    return handler(torch.ops.aten.view.default, (self, list(target_shape)), {})

where

Python
where(
    condition: Tensor | NestedTensor,
    other: Tensor | NestedTensor | SupportsFloat,
) -> Self

Return a NestedTensor of elements selected from either self or other, depending on condition.

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
>>> nested_tensor.where(nested_tensor > 2, torch.tensor([[6, 5, 4], [3, 2, 1]]))
NestedTensor([
    [6, 5, 3],
    [4, 5]
])
>>> nested_tensor.where(nested_tensor > 2, NestedTensor([[6, 5, 4], [3, 2]]))
NestedTensor([
    [6, 5, 3],
    [4, 5]
])
>>> nested_tensor.where(torch.tensor(True), NestedTensor([[6, 5, 4], [3, 2]]))
NestedTensor([
    [1, 2, 3],
    [4, 5]
])
Source code in danling/tensors/nested_tensor.py
Python
def where(self, condition: Tensor | NestedTensor, other: Tensor | NestedTensor | SupportsFloat) -> Self:
    r"""
    Pick elements from `self` where `condition` holds and from `other` elsewhere.

    The call is forwarded to `torch.where(condition, self, other)`.

    Examples:
        >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
        >>> nested_tensor.where(nested_tensor > 2, torch.tensor([[6, 5, 4], [3, 2, 1]]))
        NestedTensor([
            [6, 5, 3],
            [4, 5]
        ])
        >>> nested_tensor.where(nested_tensor > 2, NestedTensor([[6, 5, 4], [3, 2]]))
        NestedTensor([
            [6, 5, 3],
            [4, 5]
        ])
        >>> nested_tensor.where(torch.tensor(True), NestedTensor([[6, 5, 4], [3, 2]]))
        NestedTensor([
            [1, 2, 3],
            [4, 5]
        ])
    """
    selected = torch.where(condition, self, other)
    return selected

PNTensor

Bases: Tensor

A tensor wrapper that can be collated into NestedTensor with PyTorch DataLoader.

PNTensor (Potential Nested Tensor) seamlessly bridges the gap between individual tensors and batched NestedTensor objects in PyTorch workflows. It’s designed specifically to work with PyTorch’s DataLoader collation mechanism, allowing datasets to return variable-length tensors that can be combined into a NestedTensor when batched.

The class provides three properties that mirror those of NestedTensor:

- .tensor: The tensor itself (self)
- .mask: An all-True boolean mask with the same shape as self
- .concat: The tensor itself (self)

Examples:

Basic usage with PyTorch DataLoader:

Python Console Session
>>> from torch.utils.data import Dataset, DataLoader
>>> from danling.tensors import PNTensor
>>> class VariableLengthDataset(Dataset):
...     def __init__(self, data):
...         self.data = data
...     def __len__(self):
...         return len(self.data)
...     def __getitem__(self, idx):
...         return PNTensor(self.data[idx])
>>> # Create a dataset with variable-length sequences
>>> dataset = VariableLengthDataset([[1, 2, 3], [4, 5], [6, 7, 8, 9]])
>>> dataloader = DataLoader(dataset, batch_size=3)
>>> # The DataLoader produces NestedTensor batches by default for PNTensor
>>> batch = next(iter(dataloader))
>>> batch
NestedTensor([
    [1., 2., 3.],
    [4., 5.],
    [6., 7., 8., 9.]
])

Using PNTensor directly:

Python Console Session
1
2
3
4
5
6
7
>>> tensor = PNTensor([1, 2, 3])
>>> tensor
PNTensor([1., 2., 3.])
>>> tensor.tensor
PNTensor([1., 2., 3.])
>>> tensor.mask
PNTensor([True, True, True])
Source code in danling/tensors/pn_tensor.py
Python
class PNTensor(Tensor):
    r"""
    A `torch.Tensor` subclass that DataLoader collation can batch into a NestedTensor.

    `PNTensor` (Potential Nested Tensor) lets a dataset return individual tensors of
    differing lengths; when PyTorch's DataLoader collates such samples, the batch
    comes out as a `NestedTensor` (see the example below).

    To stay interchangeable with `NestedTensor`, the class exposes the same three
    properties:
    - `.tensor`: the tensor itself (self)
    - `.mask`: an all-True boolean mask shaped like self
    - `.concat`: the tensor itself (self)

    Attributes:
        Inherits all attributes from torch.Tensor

    Methods:
        Inherits all methods from torch.Tensor

    Examples:
        Basic usage with PyTorch DataLoader:

        >>> from torch.utils.data import Dataset, DataLoader
        >>> from danling.tensors import PNTensor
        >>> class VariableLengthDataset(Dataset):
        ...     def __init__(self, data):
        ...         self.data = data
        ...     def __len__(self):
        ...         return len(self.data)
        ...     def __getitem__(self, idx):
        ...         return PNTensor(self.data[idx])
        >>> # Create a dataset with variable-length sequences
        >>> dataset = VariableLengthDataset([[1, 2, 3], [4, 5], [6, 7, 8, 9]])
        >>> dataloader = DataLoader(dataset, batch_size=3)
        >>> # The DataLoader produces NestedTensor batches by default for PNTensor
        >>> batch = next(iter(dataloader))
        >>> batch
        NestedTensor([
            [1., 2., 3.],
            [4., 5.],
            [6., 7., 8., 9.]
        ])

        Using PNTensor directly:

        >>> tensor = PNTensor([1, 2, 3])
        >>> tensor
        PNTensor([1., 2., 3.])
        >>> tensor.tensor
        PNTensor([1., 2., 3.])
        >>> tensor.mask
        PNTensor([True, True, True])
    """

    @property
    def tensor(self) -> Tensor:
        r"""
        The tensor itself; mirrors `NestedTensor.tensor`.

        Returns:
            (torch.Tensor):

        Examples:
            >>> tensor = torch.tensor([1, 2, 3])
            >>> pn_tensor = PNTensor([1, 2, 3])
            >>> bool((tensor == pn_tensor).all())
            True
            >>> bool((tensor == pn_tensor.tensor).all())
            True
        """

        return self

    @property
    def mask(self) -> Tensor:
        r"""
        Boolean mask that is True everywhere, since a PNTensor carries no padding.

        The mask is a single boolean scalar expanded to the shape of `self`, so it is
        a stride-0 view and allocates nothing beyond that scalar.

        Returns:
            (torch.Tensor):

        Examples:
            >>> tensor = torch.tensor([1, 2, 3])
            >>> pn_tensor = PNTensor([1, 2, 3])
            >>> bool((pn_tensor.mask == torch.ones_like(pn_tensor)).all().item())
            True
        """

        scalar_true = torch.ones((), dtype=torch.bool, device=self.device)
        return scalar_true.expand_as(self)

    @property
    def concat(self) -> Tensor:
        r"""
        The tensor itself; mirrors `NestedTensor.concat`.

        Returns:
            (torch.Tensor):

        Examples:
            >>> pn_tensor = PNTensor([1, 2, 3])
            >>> bool((pn_tensor == pn_tensor.concat).all())
            True
        """

        return self

    def new_empty(self, *args, **kwargs):
        r"""Create an empty tensor through `Tensor.new_empty` and wrap the result as a PNTensor."""
        empty = super().new_empty(*args, **kwargs)
        return PNTensor(empty)

tensor property

Python
tensor: Tensor

Identical to self.

Returns:

Type Description
Tensor

Examples:

Python Console Session
1
2
3
4
5
6
>>> tensor = torch.tensor([1, 2, 3])
>>> pn_tensor = PNTensor([1, 2, 3])
>>> bool((tensor == pn_tensor).all())
True
>>> bool((tensor == pn_tensor.tensor).all())
True

mask property

Python
mask: Tensor

All-True boolean mask (PNTensor has no padding).

Returns a stride-0 expanded view — no memory allocation beyond a scalar.

Returns:

Type Description
Tensor

Examples:

Python Console Session
1
2
3
4
>>> tensor = torch.tensor([1, 2, 3])
>>> pn_tensor = PNTensor([1, 2, 3])
>>> bool((pn_tensor.mask == torch.ones_like(pn_tensor)).all().item())
True

concat property

Python
concat: Tensor

Identical to self.

Returns:

Type Description
Tensor

Examples:

Python Console Session
1
2
3
>>> pn_tensor = PNTensor([1, 2, 3])
>>> bool((pn_tensor == pn_tensor.concat).all())
True

new_empty

Python
new_empty(*args, **kwargs)

Return a new empty PNTensor with the same type.

Source code in danling/tensors/pn_tensor.py
Python
def new_empty(self, *args, **kwargs):
    r"""Create an empty tensor through `Tensor.new_empty` and wrap the result as a PNTensor."""
    empty = super().new_empty(*args, **kwargs)
    return PNTensor(empty)

ensure_dir

Bases: property

Ensure a directory property exists.

Examples:

Python Console Session
1
2
3
>>> @ensure_dir
... def dir(self) -> str:
...     return os.path.join("path", "to", "dir")
Source code in danling/utils/descriptors.py
Python
class ensure_dir(property):
    r"""
    Ensure a directory property exists.

    Works like the built-in `property`, except that every read through an instance
    passes the returned path to `makedirs(..., exist_ok=True)` before handing it
    back, so callers can rely on the directory being present.

    Reading the attribute on the owning class (rather than an instance) returns the
    descriptor itself and creates nothing, matching the behaviour of `property`.

    Examples:
        >>> @ensure_dir
        ... def dir(self) -> str:
        ...     return os.path.join("path", "to", "dir")
    """

    def __get__(self, instance, owner=None):
        # Class-level access (`Owner.attr`, `help()`, documentation tooling) makes
        # `property.__get__` return the descriptor rather than a path; handing that
        # object to `makedirs` would raise `TypeError`.
        if instance is None:
            return self
        val = super().__get__(instance, owner)
        makedirs(val, exist_ok=True)
        return val

GlobalMetrics

Data container for metrics descriptors.

The container aggregates required artifacts (preds/targets, confusion matrix, running stats) only once, synchronises them across processes, and lets descriptors compute metric values without duplicating work.

Source code in danling/metrics/global_metrics.py
Python
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
class GlobalMetrics:
    """
    Data container for metrics descriptors.

    The container aggregates required artifacts (preds/targets, confusion
    matrix, running stats) only once, synchronises them across processes,
    and lets descriptors compute metric values without duplicating work.
    """

    _local_n: int = 0
    _local_count: int = 0

    def __init__(
        self,
        *metric_funcs,
        preprocess: Callable = base_preprocess,
        distributed: bool = True,
        device: torch.device | str | None = None,
        **meters,
    ) -> None:
        """
        Register the metric functions and initialise empty artifact and sync state.

        Args:
            *metric_funcs: Metric functions registered under their own `name`.
            preprocess: Callable applied to `(input, target)` at the start of `update`.
            distributed: Whether `sync` should query the distributed world size.
            device: Optional device; converted to `torch.device` when given.
            **meters: Metric functions registered under the given keyword names.

        Raises:
            ValueError: If any supplied metric is not a `MetricFunc` instance.
        """
        # Positional metrics are keyed by the metric's own name.
        positional: list[tuple[str, MetricFunc]] = [
            (checked.name, checked) for checked in map(self._coerce_metric, iter_metric_funcs(metric_funcs))
        ]
        # Keyword metrics are keyed by the keyword they were passed under.
        named: dict[str, MetricFunc] = {name: self._coerce_metric(metric) for name, metric in meters.items()}

        self.metrics = merge_metric_entries(positional, named)
        self.requirements = MetricState.collect_requirements(tuple(self.metrics.values()), require_nonempty=True)
        self.preprocess = preprocess
        self.distributed = distributed
        if device is None:
            self.device = None
        else:
            self.device = torch.device(device)

        self._artifacts = _ArtifactState()
        self._sync = _SyncState()
        self._local_n = 0
        self._local_count = 0

        self._artifact_version = 0
        self._cache: dict[str, tuple[int, Tensor | float]] = {}

    # Construction
    @staticmethod
    def _coerce_metric(value: MetricFunc) -> MetricFunc:
        """Return `value` unchanged if it is a `MetricFunc`; raise `ValueError` otherwise."""
        if isinstance(value, MetricFunc):
            return value
        raise ValueError(f"Expected metric functions to be MetricFunc instances, got {type(value)}")

    # Lifecycle
    def update(self, input: Tensor | NestedTensor | Sequence, target: Tensor | NestedTensor | Sequence) -> None:
        """
        Record one batch of predictions and targets.

        The inputs are run through `self.preprocess`, NestedTensors are replaced by
        their `.concat` form, and the artifacts demanded by `self.requirements` are
        stored. Cached metric values are discarded and the synced state is marked stale.

        Raises:
            MetricRequirementError: If a confusion matrix is required but
                `MetricState.compute_confmat` returns `None`.
        """
        state = self._artifacts
        preds, labels = self.preprocess(input, target)
        if isinstance(preds, NestedTensor):
            preds = preds.concat
        if isinstance(labels, NestedTensor):
            labels = labels.concat

        # Always remember the latest batch and keep the sample counters current.
        state.last_preds = preds.detach()
        state.last_targets = labels.detach()
        self._local_n = self._infer_batch_count(labels)
        self._local_count += self._local_n

        if self.requirements["preds_targets"]:
            kept_preds = self._detach_to_device(preds)
            kept_labels = self._detach_to_device(labels)
            state.preds.append(kept_preds)
            state.targets.append(kept_labels)
            # The pending lists hold what has not been through `sync` yet.
            state.pending_preds.append(kept_preds)
            state.pending_targets.append(kept_labels)
            state.preds_template = self._empty_artifact(kept_preds)
            state.targets_template = self._empty_artifact(kept_labels)
        if self.requirements["confmat"]:
            current = MetricState.compute_confmat(preds, labels, self.requirements)
            if current is None:
                raise MetricRequirementError("Confusion matrix requested but required tensors are not available.")
            state.last_confmat = current
            if state.confmat is None:
                state.confmat = current
            else:
                state.confmat = state.confmat + current

        self._artifact_version += 1
        self._cache.clear()
        self._mark_sync_stale()

    def sync(self) -> None:
        """
        Refresh the synchronised artifacts stored in `self._sync`.

        Returns immediately when the sync state is already marked as synced for the
        current world size. Otherwise the artifacts demanded by `self.requirements`
        are gathered / reduced across processes (or taken from local state when the
        world size is 1 or `self.distributed` is false), the pending chunk lists and
        the metric cache are cleared, and the sync state is marked as synced.
        """
        artifacts = self._artifacts
        sync = self._sync
        world_size = get_world_size() if self.distributed else 1
        # Nothing to do when the previous sync is still valid for this world size.
        if sync.synced and sync.world_size == world_size:
            return

        # Start from the previously synced values (or local data) and overwrite below as needed.
        synced_pred_chunks = sync.pred_chunks
        synced_target_chunks = sync.target_chunks
        synced_preds = sync.preds if sync.preds is not None else self._local_preds()
        synced_targets = sync.targets if sync.targets is not None else self._local_targets()
        synced_confmat = artifacts.confmat

        if self.distributed:
            if world_size > 1:
                if self.requirements["preds_targets"]:
                    if self._requires_full_artifact_resync(world_size):
                        # Full resync: gather everything stored locally and replace the synced chunks.
                        local_pred_tensor = self._local_artifact(artifacts.preds, artifacts.preds_template)
                        local_target_tensor = self._local_artifact(artifacts.targets, artifacts.targets_template)
                        synced_pred_chunks = self._gather_tensor_chunks(local_pred_tensor, world_size)
                        synced_target_chunks = self._gather_tensor_chunks(local_target_tensor, world_size)
                    else:
                        # Incremental sync: gather only the pending chunks and append them
                        # to the chunks kept from the previous sync.
                        delta_pred_tensor = self._local_artifact(artifacts.pending_preds, artifacts.preds_template)
                        delta_target_tensor = self._local_artifact(
                            artifacts.pending_targets, artifacts.targets_template
                        )
                        delta_pred_chunks = self._gather_tensor_chunks(delta_pred_tensor, world_size)
                        delta_target_chunks = self._gather_tensor_chunks(delta_target_tensor, world_size)
                        synced_pred_chunks = self._append_tensor_chunks(sync.pred_chunks, delta_pred_chunks)
                        synced_target_chunks = self._append_tensor_chunks(sync.target_chunks, delta_target_chunks)
                    synced_preds = self._concat_tensor_chunks(synced_pred_chunks)
                    synced_targets = self._concat_tensor_chunks(synced_target_chunks)

                if self.requirements["confmat"]:
                    # Reduce the confusion matrix together with the local sample count;
                    # a reduced count of zero means there is no confusion matrix to expose.
                    local_confmat = artifacts.confmat if artifacts.confmat is not None else self._empty_confmat()
                    synced_confmat, synced_count = self._all_reduce_confmat_count(local_confmat, self._local_count)
                    if synced_count == 0:
                        synced_confmat = None
                else:
                    synced_count = None
            else:
                # Distributed mode with a single process: local data is the synced data.
                if self.requirements["preds_targets"]:
                    synced_pred_chunks = [self._local_preds()]
                    synced_target_chunks = [self._local_targets()]
                    synced_preds = synced_pred_chunks[0]
                    synced_targets = synced_target_chunks[0]
                synced_count = self._local_count
        else:
            # Non-distributed mode: local data is the synced data.
            if self.requirements["preds_targets"]:
                synced_pred_chunks = [self._local_preds()]
                synced_target_chunks = [self._local_targets()]
                synced_preds = synced_pred_chunks[0]
                synced_targets = synced_target_chunks[0]
            synced_count = self._local_count
        # Only keep artifacts that some metric actually requires.
        sync.pred_chunks = synced_pred_chunks if self.requirements["preds_targets"] else None
        sync.target_chunks = synced_target_chunks if self.requirements["preds_targets"] else None
        sync.preds = synced_preds if self.requirements["preds_targets"] else None
        sync.targets = synced_targets if self.requirements["preds_targets"] else None
        sync.confmat = synced_confmat if self.requirements["confmat"] else None
        # Count precedence: local count (single process), then the count inferred from
        # the synced targets, then the reduced count, else unknown.
        if world_size == 1:
            sync.count = self._local_count
        elif sync.targets is not None:
            sync.count = self._infer_batch_count(sync.targets)
        elif synced_count is not None:
            sync.count = synced_count
        else:
            sync.count = None
        # Pending chunks have now been handled; cached metric values are dropped.
        artifacts.pending_preds.clear()
        artifacts.pending_targets.clear()
        self._cache.clear()
        sync.synced = True
        sync.world_size = world_size

    def reset(self) -> Self:
        """Discard all accumulated artifacts, counters, sync state and cached values; return `self`."""
        self._artifacts = _ArtifactState()
        self._clear_sync_state()
        self._local_n = self._local_count = 0
        self._artifact_version = 0
        self._cache.clear()
        return self

    # Public reductions
    def value(self) -> RoundDict:
        """Evaluate every metric on the most recent local batch, bypassing the cache."""
        state = self._last_state()
        results = {}
        for name, func in self.metrics.items():
            results[name] = self._run_metric(name, func, state, cache=False)
        return RoundDict(results)

    def batch(self) -> RoundDict:
        """
        Evaluate every metric on the current batch across all processes.

        With a world size of 1 this is the same as `value()`. When no metric needs
        preds/targets or a confusion matrix, the result comes from
        `_approximate_batch_values()`; otherwise the metrics run on the state built
        by `_batch_state()`.
        """
        world_size = self._current_world_size()
        if world_size == 1:
            return self.value()

        needs_artifacts = self.requirements["preds_targets"] or self.requirements["confmat"]
        if not needs_artifacts:
            return self._approximate_batch_values()

        state, _ = self._batch_state(world_size)
        results = {}
        for name, func in self.metrics.items():
            results[name] = self._run_metric(name, func, state, cache=False)
        return RoundDict(results)

    def average(self) -> RoundDict:
        """Synchronise, then evaluate every metric on the accumulated state, using the metric cache."""
        self.sync()
        state = self._average_state()
        results = {}
        for name, func in self.metrics.items():
            results[name] = self._run_metric(name, func, state, cache=True)
        return RoundDict(results)

    # Public aliases
    @property
    def val(self) -> RoundDict:
        """Property shorthand for `value()`."""
        latest = self.value()
        return latest

    @property
    def bat(self) -> RoundDict:
        """Property shorthand for `batch()`."""
        current = self.batch()
        return current

    @property
    def avg(self) -> RoundDict:
        """Property shorthand for `average()`."""
        overall = self.average()
        return overall

    # Public artifact accessors
    @property
    def preds(self) -> Tensor:
        """Synced predictions when they may be exposed and exist; otherwise the local predictions."""
        if not self._should_expose_synced_state():
            return self._local_preds()
        if self._sync.preds is None:
            return self._local_preds()
        return self._sync.preds

    @property
    def targets(self) -> Tensor:
        """Synced targets when they may be exposed and exist; otherwise the local targets."""
        if not self._should_expose_synced_state():
            return self._local_targets()
        if self._sync.targets is None:
            return self._local_targets()
        return self._sync.targets

    @property
    def confmat(self) -> Tensor | None:
        """Synced confusion matrix when it may be exposed and exists; otherwise the local one."""
        if not self._should_expose_synced_state():
            return self._artifacts.confmat
        if self._sync.confmat is None:
            return self._artifacts.confmat
        return self._sync.confmat

    # Public state accessors
    @property
    def n(self) -> int:
        """Local sample count of the most recent batch (`_local_n`)."""
        latest_batch_size = self._local_n
        return latest_batch_size

    @property
    def count(self) -> int:
        """Local running sample count (`_local_count`)."""
        seen_so_far = self._local_count
        return seen_so_far

    # Formatting helpers
    def __repr__(self) -> str:  # pragma: no cover - repr convenience
        """Render as the class name followed by the tuple of registered metric names."""
        metric_names = tuple(self.metrics.keys())
        class_name = self.__class__.__name__
        return f"{class_name}{metric_names}"

    def __format__(self, format_spec: str) -> str:
        """
        Format each metric as `name: <latest> (<local average>)`, joined by tabs.

        `format_spec` is applied to both numbers. The average is computed from the
        local accumulated state without triggering a sync.
        """
        val = self.value()
        state = self._local_average_state()
        averages = {}
        for name, func in self.metrics.items():
            averages[name] = self._run_metric(name, func, state, cache=False)
        avg = RoundDict(averages)
        parts = []
        for key in val:
            latest_text = val[key].__format__(format_spec)
            average_text = avg[key].__format__(format_spec)
            parts.append(f"{key}: {latest_text} ({average_text})")
        return "\t".join(parts)

    # State builders
    def _last_state(self) -> MetricState:
        """Build a `MetricState` from the artifacts of the most recent batch."""
        latest = self._artifacts
        return MetricState(
            preds=latest.last_preds,
            targets=latest.last_targets,
            confmat=latest.last_confmat,
        )

    def _local_average_state(self) -> MetricState:
        """Build a `MetricState` from everything accumulated locally, without synchronisation."""
        local_preds = self._local_preds()
        local_targets = self._local_targets()
        return MetricState(preds=local_preds, targets=local_targets, confmat=self._artifacts.confmat)

    def _average_state(self) -> MetricState:
        """Build a `MetricState` from synced artifacts, falling back to local ones where a synced value is `None`."""
        if self._sync.preds is not None:
            preds = self._sync.preds
        else:
            preds = self._local_preds()
        if self._sync.targets is not None:
            targets = self._sync.targets
        else:
            targets = self._local_targets()
        if self._sync.confmat is not None:
            confmat = self._sync.confmat
        else:
            confmat = self._artifacts.confmat
        return MetricState(preds=preds, targets=targets, confmat=confmat)

    def _batch_state(self, world_size: int) -> tuple[MetricState, int]:
        """
        Build the cross-process state for the most recent batch.

        Args:
            world_size: Number of processes passed on to the gather helper.

        Returns:
            A `MetricState` holding the gathered / reduced artifacts of the latest
            batch, and the corresponding sample count across processes.
        """
        artifacts = self._artifacts
        # Defaults are the local last-batch artifacts; required ones are replaced below.
        synced_preds = artifacts.last_preds
        synced_targets = artifacts.last_targets
        synced_confmat = artifacts.last_confmat
        synced_count: int | None = None

        if self.requirements["preds_targets"]:
            local_pred_tensor = self._current_batch_artifact(artifacts.last_preds, artifacts.preds_template)
            local_target_tensor = self._current_batch_artifact(artifacts.last_targets, artifacts.targets_template)
            synced_preds = self._gather_tensor(local_pred_tensor, world_size)
            synced_targets = self._gather_tensor(local_target_tensor, world_size)
            synced_count = self._infer_batch_count(synced_targets)

        if self.requirements["confmat"]:
            # Reduce the batch confusion matrix together with the local batch size;
            # a reduced count of zero means there is no confusion matrix to report.
            local_confmat = artifacts.last_confmat if artifacts.last_confmat is not None else self._empty_confmat()
            synced_confmat, confmat_count = self._all_reduce_confmat_count(local_confmat, self._local_n)
            if confmat_count == 0:
                synced_confmat = None
            # The count inferred from gathered targets takes precedence when available.
            if synced_count is None:
                synced_count = confmat_count

        # Neither artifact was required: reduce the batch size on its own.
        if synced_count is None:
            synced_count = self._all_reduce_count(self._local_n)

        return MetricState(preds=synced_preds, targets=synced_targets, confmat=synced_confmat), synced_count

    def _run_metric(self, name: str, func: MetricFunc, state: MetricState, cache: bool) -> Tensor | float:
        """
        Evaluate `func` on `state`, optionally memoising the result per artifact version.

        Args:
            name: Cache key for the metric.
            func: Metric callable invoked as `func(state)`.
            state: State handed to the metric.
            cache: When true, reuse a cached value stored for the current
                `_artifact_version` and store freshly computed values.
        """
        if not cache:
            return func(state)

        entry = self._cache.get(name)
        if entry and entry[0] == self._artifact_version:
            return entry[1]

        result = func(state)
        self._cache[name] = (self._artifact_version, result)
        return result

    # Local artifact helpers
    def _local_preds(self) -> Tensor:
        """Return all locally stored predictions as one tensor; an empty tensor when none are stored."""
        chunks = self._artifacts.preds
        if not chunks:
            return torch.empty(0, device=self.device or "cpu")
        # A single chunk is returned as-is; several are concatenated along dim 0.
        return chunks[0] if len(chunks) == 1 else torch.cat(chunks, dim=0)

    def _local_targets(self) -> Tensor:
        """Return all locally stored targets as one tensor; an empty tensor when none are stored."""
        chunks = self._artifacts.targets
        if not chunks:
            return torch.empty(0, device=self.device or "cpu")
        # A single chunk is returned as-is; several are concatenated along dim 0.
        return chunks[0] if len(chunks) == 1 else torch.cat(chunks, dim=0)

    def _detach_to_device(self, tensor: Tensor) -> Tensor:
        """Detach `tensor` and, when `self.device` is set, move the result to that device."""
        detached = tensor.detach()
        if self.device is None:
            return detached
        return detached.to(self.device)

    def _local_artifact(self, tensors: list[Tensor], template: Tensor | None) -> Tensor | None:
        """Collapse `tensors` into one tensor along dim 0; return `template` when the list is empty."""
        if tensors:
            return tensors[0] if len(tensors) == 1 else torch.cat(tensors, dim=0)
        return template

    def _approximate_batch_values(self) -> RoundDict:
        """
        Combine each process's local metric values into count-weighted batch values.

        The local values from `value()` are multiplied by the local batch size,
        packed into one float64 buffer together with that batch size, passed through
        `dist.all_reduce`, and divided by the reduced total count. Metrics come back
        as tensors when the local value was a tensor and as Python scalars otherwise;
        all values are NaN when the total count is zero.
        """
        local_values = self.value()
        device = self._sync_device()
        local_count = float(self._local_n)
        names: list[str] = []
        shapes: list[torch.Size] = []
        tensor_flags: list[bool] = []
        # Slot 0 of the packed buffer is reserved for the sample count.
        total_numel = 1

        # First pass: record each metric's shape and type to size the buffer.
        for name, value in local_values.items():
            tensor_value = torch.as_tensor(value, dtype=torch.float64, device=device)
            names.append(name)
            shapes.append(tensor_value.shape)
            tensor_flags.append(isinstance(value, Tensor))
            total_numel += tensor_value.numel()

        reduced = torch.zeros(total_numel, dtype=torch.float64, device=device)
        reduced[0] = local_count

        # Second pass: write each flattened metric, weighted by the local count.
        offset = 1
        for name in names:
            tensor_value = torch.as_tensor(local_values[name], dtype=torch.float64, device=device).reshape(-1)
            numel = tensor_value.numel()
            # With no local samples the slots stay zero, so the local values are never multiplied in.
            if local_count > 0:
                reduced[offset : offset + numel] = tensor_value * local_count
            offset += numel

        # NOTE(review): relies on all_reduce's default reduction being a sum — confirm for `dist`.
        dist.all_reduce(reduced)

        total_count = int(round(reduced[0].item()))

        # Unpack in the same order and divide by the total count.
        batch_values = RoundDict()
        offset = 1
        for name, shape, is_tensor in zip(names, shapes, tensor_flags):
            numel = int(torch.Size(shape).numel())
            values = reduced[offset : offset + numel]
            offset += numel
            if total_count == 0:
                reduced_value = torch.full(shape, float("nan"), dtype=torch.float64, device=device)
            else:
                reduced_value = (values / total_count).reshape(shape)
            batch_values[name] = reduced_value if is_tensor else reduced_value.item()

        return batch_values

    # Distributed synchronization helpers
    def _gather_tensor(self, tensor: Tensor | None, world_size: int) -> Tensor:
        gathered_chunks = self._gather_tensor_chunks(tensor, world_size)
        return self._concat_tensor_chunks(gathered_chunks)

    def _gather_tensor_chunks(self, tensor: Tensor | None, world_size: int) -> list[Tensor]:
        """
        Gather one tensor chunk per process.

        Each process first shares the size of its leading dimension, using
        ``-1`` to signal that it has no tensor. If any process reports ``-1``,
        ``_gather_tensor_with_metadata`` is consulted for a replacement tensor
        and the negative sizes are clamped to zero.

        Args:
            tensor: Local tensor, or ``None`` when this process has nothing to contribute.
            world_size: Number of entries to gather.

        Returns:
            ``world_size`` chunks; plain empty tensors when no tensor is available even after the metadata exchange.
        """
        device = self._sync_device()
        if tensor is None:
            leading = -1
        else:
            tensor = self._tensor_on_sync_device(tensor, device=device)
            leading = tensor.shape[0]

        local_size = torch.tensor([leading], dtype=torch.int64, device=device)
        gathered_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
        dist.all_gather(gathered_sizes, local_size)
        sizes = torch.cat(gathered_sizes)

        if (sizes < 0).any():
            tensor = self._gather_tensor_with_metadata(tensor, world_size)
            sizes = sizes.clamp_min(0)

        if tensor is not None:
            return self._gather_tensor_chunks_data(tensor, sizes, world_size)
        return [torch.empty(0, device=device) for _ in range(world_size)]

    def _gather_tensor_chunks_data(self, tensor: Tensor, sizes: Tensor, world_size: int) -> list[Tensor]:
        """
        Gather variable-length tensors by padding them to a common leading size.

        Args:
            tensor: Local tensor to contribute.
            sizes: Leading-dimension size of every process's tensor.
            world_size: Number of entries to gather.

        Returns:
            One chunk per process, trimmed back to its reported size; processes
            reporting size zero yield an empty tensor with the local trailing shape and dtype.
        """
        trailing_shape = tuple(tensor.shape[1:])
        longest = int(sizes.max().item())

        # Zero-padded send buffer so that every process contributes the same shape.
        send_buffer = torch.zeros((longest, *trailing_shape), dtype=tensor.dtype, device=tensor.device)
        local_length = tensor.shape[0]
        if local_length > 0:
            send_buffer[:local_length] = tensor

        receive_buffers = [torch.empty_like(send_buffer) for _ in range(world_size)]
        dist.all_gather(receive_buffers, send_buffer)

        placeholder = torch.empty((0, *trailing_shape), dtype=tensor.dtype, device=tensor.device)
        chunks = []
        for rank in range(world_size):
            length = sizes[rank]
            chunks.append(receive_buffers[rank][:length] if length > 0 else placeholder.clone())
        return chunks

    def _gather_tensor_with_metadata(self, tensor: Tensor | None, world_size: int) -> Tensor | None:
        """
        Build an empty stand-in tensor for a process that has no local tensor.

        Every process shares ``(trailing shape, dtype string)`` of its tensor,
        or ``None`` when it has none. A process with a tensor gets it back
        unchanged; a process without one receives a zero-length tensor shaped
        after the first non-``None`` metadata entry.

        Returns:
            The original tensor, an empty stand-in, or ``None`` when no process supplied metadata.
        """
        local_metadata = None if tensor is None else (tuple(tensor.shape[1:]), str(tensor.dtype))
        gathered: list[tuple[tuple[int, ...], str] | None] = [None] * world_size
        # The exchange always runs, even for processes that already hold a tensor.
        dist.all_gather_object(gathered, local_metadata)

        if tensor is not None:
            return tensor

        for metadata in gathered:
            if metadata is None:
                continue
            trailing_shape, dtype_name = metadata
            return torch.empty(
                (0, *trailing_shape),
                # ``str(dtype)`` looks like "torch.float32"; strip the prefix to look the dtype up again.
                dtype=getattr(torch, dtype_name.removeprefix("torch.")),
                device=self._sync_device(),
            )
        return tensor

    def _all_reduce_confmat_count(self, tensor: Tensor, count: int) -> tuple[Tensor, int]:
        """
        Reduce a confusion matrix and a sample count across processes in one call.

        The flattened matrix and the count are packed into a single buffer of
        the matrix's dtype, so only one ``dist.all_reduce`` is issued.

        Returns:
            The reduced matrix in its original shape and the reduced count as an ``int``.
        """
        tensor = self._tensor_on_sync_device(tensor)
        packed = torch.empty(tensor.numel() + 1, dtype=tensor.dtype, device=tensor.device)
        packed[:-1] = tensor.reshape(-1)
        packed[-1] = count  # the count rides in the last slot
        dist.all_reduce(packed)
        reduced_confmat = packed[:-1].reshape_as(tensor)
        reduced_count = int(round(float(packed[-1].item())))
        return reduced_confmat, reduced_count

    def _all_reduce_count(self, count: int) -> int:
        """Reduce ``count`` across processes with ``dist.all_reduce`` and return the rounded result."""
        buffer = torch.tensor([float(count)], dtype=torch.float64, device=self._sync_device())
        dist.all_reduce(buffer)
        total = buffer.item()
        return int(round(total))

    def _empty_artifact(self, tensor: Tensor) -> Tensor:
        return torch.empty((0, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)

    def _current_batch_artifact(self, tensor: Tensor, template: Tensor | None) -> Tensor | None:
        if tensor.numel() > 0 or tensor.ndim > 1:
            return tensor
        if template is not None:
            return template
        return None

    def _concat_tensor_chunks(self, chunks: list[Tensor] | None) -> Tensor:
        if not chunks:
            return torch.empty(0, device=self._sync_device())
        if len(chunks) == 1:
            return chunks[0]
        return torch.cat(chunks, dim=0)

    def _append_tensor_chunks(self, base: list[Tensor] | None, delta: list[Tensor]) -> list[Tensor]:
        if base is None:
            return delta
        return [torch.cat((current, update), dim=0) for current, update in zip(base, delta)]

    def _tensor_on_sync_device(self, tensor: Tensor, *, device: torch.device | None = None) -> Tensor:
        sync_device = self._sync_device() if device is None else device
        if tensor.device == sync_device:
            return tensor
        return tensor.to(sync_device)

    def _mark_sync_stale(self) -> None:
        self._sync.count = None
        self._sync.synced = False

    def _clear_sync_state(self) -> None:
        """Discard the current sync state by replacing ``self._sync`` with a fresh ``_SyncState``."""
        self._sync = _SyncState()

    # Generic helpers
    def _empty_confmat(self) -> Tensor:
        task = self.requirements["task"]
        device = self._sync_device()

        if task == "binary":
            return torch.zeros((2, 2), dtype=torch.long, device=device)

        if task == "multiclass":
            num_classes = self.requirements["num_classes"]
            return torch.zeros((num_classes, num_classes), dtype=torch.long, device=device)

        if task == "multilabel":
            num_labels = self.requirements["num_labels"]
            return torch.zeros((num_labels, 2, 2), dtype=torch.long, device=device)

        raise MetricRequirementError(f"Unsupported confusion matrix task: {task!r}")

    def _sync_device(self) -> torch.device:
        if not (dist.is_available() and dist.is_initialized()) and self.device is not None:
            return self.device
        return infer_device()

    def _current_world_size(self) -> int:
        if not self.distributed:
            return 1
        return get_world_size()

    def _should_expose_synced_state(self) -> bool:
        return self._sync.synced and self._sync.world_size == self._current_world_size()

    def _requires_full_artifact_resync(self, world_size: int) -> bool:
        return (
            self._sync.world_size != world_size
            or self._sync.pred_chunks is None
            or self._sync.target_chunks is None
            or not self._chunks_match_template(self._sync.pred_chunks, self._artifacts.preds_template)
            or not self._chunks_match_template(self._sync.target_chunks, self._artifacts.targets_template)
        )

    @staticmethod
    def _infer_batch_count(target: Tensor) -> int:
        if target.ndim == 0:
            return 1
        return int(target.shape[0])

    @staticmethod
    def _chunks_match_template(chunks: list[Tensor], template: Tensor | None) -> bool:
        if template is None:
            return True
        return all(chunk.dtype == template.dtype and chunk.shape[1:] == template.shape[1:] for chunk in chunks)

MultiTaskMetrics

Bases: MultiTaskBase

Examples:

Python Console Session
>>> from danling.metrics.functional import accuracy
>>> metrics = MultiTaskMetrics(aggregate="macro")
>>> metrics.dataset1 = StreamMetrics(acc=accuracy)
>>> metrics.dataset2 = StreamMetrics(acc=accuracy)
>>> metrics.update({"dataset1": {"input": [0.2, 0.4, 0.6, 0.7], "target": [0, 1, 0, 1]}, "dataset2": ([0.1, 0.4, 0.6, 0.8], [1, 0, 0, 0])})
>>> f"{metrics:.4f}"
'dataset1: acc: 0.5000 (0.5000)\ndataset2: acc: 0.2500 (0.2500)'
>>> metrics.update({"dataset1": ([0.1, 0.4, 0.6, 0.8], [0, 0, 1, 0]), "dataset2": {"input": [0.2, 0.3, 0.6, 0.7], "target": [0, 0, 0, 1]}})
>>> round(metrics.avg["aggregate"]["acc"], 4)
0.5625
>>> metrics.update(dict(loss=""))
Traceback (most recent call last):
ValueError: Task loss not found in ...
Notes
  • MultiTaskMetrics manages a flat collection of task-level metric containers
  • All task containers are updated simultaneously with a single update() call
  • Aggregation mode is configured at construction time via aggregate=...
  • aggregate="macro" gives equal task weight, aggregate="micro" weights by sample count, and aggregate="weighted" uses explicit aggregate_weights
  • Aggregate outputs are matched by exact relative metric path across tasks
  • Provides a structured way to track metrics across different tasks or model components
See Also
  • GlobalMetrics: Exact metrics container that stores prediction and target history.
  • StreamMetrics: Streaming metrics container for hot-path metric tracking.
Source code in danling/metrics/multitask.py
Python
class MultiTaskMetrics(MultiTaskBase):
    r"""
    Container that tracks one metric container per task and updates them together.

    Examples:
        >>> from danling.metrics.functional import accuracy
        >>> metrics = MultiTaskMetrics(aggregate="macro")
        >>> metrics.dataset1 = StreamMetrics(acc=accuracy)
        >>> metrics.dataset2 = StreamMetrics(acc=accuracy)
        >>> metrics.update({"dataset1": {"input": [0.2, 0.4, 0.6, 0.7], "target": [0, 1, 0, 1]}, "dataset2": ([0.1, 0.4, 0.6, 0.8], [1, 0, 0, 0])})
        >>> f"{metrics:.4f}"
        'dataset1: acc: 0.5000 (0.5000)\ndataset2: acc: 0.2500 (0.2500)'
        >>> metrics.update({"dataset1": ([0.1, 0.4, 0.6, 0.8], [0, 0, 1, 0]), "dataset2": {"input": [0.2, 0.3, 0.6, 0.7], "target": [0, 0, 0, 1]}})
        >>> round(metrics.avg["aggregate"]["acc"], 4)
        0.5625
        >>> metrics.update(dict(loss=""))  # doctest: +ELLIPSIS
        Traceback (most recent call last):
        ValueError: Task loss not found in ...

    Notes:
        - `MultiTaskMetrics` manages a flat collection of task-level metric containers
        - All task containers are updated simultaneously with a single `update()` call
        - Aggregation mode is configured at construction time via `aggregate=...`
        - `aggregate="macro"` gives equal task weight, `aggregate="micro"` weights by sample count,
          and `aggregate="weighted"` uses explicit `aggregate_weights`
        - Aggregate outputs are matched by exact relative metric path across tasks
        - Provides a structured way to track metrics across different tasks or model components

    See Also:
        - [`GlobalMetrics`][danling.metrics.global_metrics.GlobalMetrics]:
            Exact metrics container that stores prediction and target history.
        - [`StreamMetrics`][danling.metrics.stream_metrics.StreamMetrics]:
            Streaming metrics container for hot-path metric tracking.
    """  # noqa: E501

    def __init__(
        self,
        *args,
        aggregate: Literal["macro", "micro", "weighted"] | None = None,
        aggregate_weights: Mapping[str, float | int | Tensor] | None = None,
        **kwargs,
    ):
        # Everything is forwarded to the base class; the explicit keywords only document the options.
        super().__init__(*args, aggregate=aggregate, aggregate_weights=aggregate_weights, **kwargs)

    def update(  # type: ignore[override] # pylint: disable=W0221
        self,
        values: Mapping[str, Mapping[str, Tensor | NestedTensor | Sequence] | Sequence],
    ) -> None:
        r"""
        Updates all task metric containers.

        Args:
            values: Mapping from task names to update payloads.
                Mapping payloads are forwarded as keyword arguments to the
                child container's `update()`; sequence payloads are forwarded
                positionally.

        Raises:
            ValueError: If a task name is not registered, or a payload is
                neither a `Mapping` nor a `Sequence`.
        """

        for task, payload in values.items():
            if task not in self:
                raise ValueError(f"Task {task} not found in {self}")
            container = self[task]
            if isinstance(payload, Mapping):
                container.update(**payload)
                continue
            if isinstance(payload, Sequence):
                container.update(*payload)
                continue
            raise ValueError(
                f"Expected payload for task {task} to be a Mapping or Sequence, but got {type(payload)}"
            )

    def set(  # pylint: disable=W0237
        self,
        name: str,
        task_metrics: GlobalMetrics | MetricMeter | StreamMetrics | Callable,  # type: ignore[override]
    ) -> None:
        r"""
        Register the metric container for task `name`.

        Args:
            name: Task name.
            task_metrics: A `GlobalMetrics`, `StreamMetrics` or `MetricMeter`
                instance, or any other callable, which is wrapped in a `MetricMeter`.

        Raises:
            ValueError: If `task_metrics` is neither a supported container nor callable.
        """
        supported = (GlobalMetrics, StreamMetrics, MetricMeter)
        if callable(task_metrics) and not isinstance(task_metrics, supported):
            task_metrics = MetricMeter(task_metrics)
        if not isinstance(task_metrics, supported):
            raise ValueError(
                f"Expected task_metrics for {name} to be an instance of GlobalMetrics, "
                f"StreamMetrics, MetricMeter, or a callable, but got {type(task_metrics)}"
            )
        if isinstance(task_metrics, MetricMeter):
            task_metrics.output_name = self._metric_output_name(name, task_metrics)
        super().set(name, task_metrics)

    @staticmethod
    def _metric_output_name(task_name: str, metric: MetricMeter) -> str:
        r"""
        Choose the output name for a `MetricMeter`.

        Preference order: the meter's own `output_name` when it is a meaningful
        string, then the name inferred from `metric.metric`, then `task_name`.
        The empty string, `"<lambda>"` and `"__call__"` never count as meaningful.
        """
        uninformative = {"", "<lambda>", "__call__"}

        explicit = getattr(metric, "output_name", None)
        if isinstance(explicit, str) and explicit not in uninformative:
            return explicit

        try:
            inferred = infer_metric_name(metric.metric)
        except ValueError:
            return task_name
        return task_name if inferred in uninformative else inferred

update

Python
update(
    values: Mapping[
        str,
        Mapping[str, Tensor | NestedTensor | Sequence]
        | Sequence,
    ],
) -> None

Updates all task metric containers.

Parameters:

Name Type Description Default
values
Mapping[str, Mapping[str, Tensor | NestedTensor | Sequence] | Sequence]

Mapping from task names to update payloads. Mapping payloads are forwarded as keyword arguments to the child container’s update(); sequence payloads are forwarded positionally.

required
Source code in danling/metrics/multitask.py
Python
def update(  # type: ignore[override] # pylint: disable=W0221
    self,
    values: Mapping[str, Mapping[str, Tensor | NestedTensor | Sequence] | Sequence],
) -> None:
    r"""
    Updates all task metric containers.

    Args:
        values: Mapping from task names to update payloads.
            Mapping payloads are forwarded as keyword arguments to the
            child container's `update()`; sequence payloads are forwarded
            positionally.

    Raises:
        ValueError: If a task name is not registered, or a payload is
            neither a `Mapping` nor a `Sequence`.
    """

    for task, payload in values.items():
        if task not in self:
            raise ValueError(f"Task {task} not found in {self}")
        container = self[task]
        if isinstance(payload, Mapping):
            container.update(**payload)
            continue
        if isinstance(payload, Sequence):
            container.update(*payload)
            continue
        raise ValueError(
            f"Expected payload for task {task} to be a Mapping or Sequence, but got {type(payload)}"
        )

to_device

Python
to_device(data: Any, device: device)

Move data to device.

Source code in danling/data/utils.py
Python
def to_device(data: Any, device: torch.device):
    r"""
    Move data to device.

    Tensors, `FlatDict` instances and any other object exposing a `to`
    attribute are moved with `.to(device)`. Lists and tuples are rebuilt
    element by element, plain dicts are rebuilt as a `FlatDict`, and anything
    else is returned unchanged.
    """
    # Tensors are tested first and on their own so the common case returns immediately.
    if isinstance(data, torch.Tensor) or isinstance(data, FlatDict):
        return data.to(device)
    if isinstance(data, (list, tuple)):
        moved = [to_device(item, device) for item in data]
        return moved if isinstance(data, list) else tuple(moved)
    if isinstance(data, dict):
        return FlatDict({key: to_device(value, device) for key, value in data.items()})
    return data.to(device) if hasattr(data, "to") else data

tensor

Python
tensor(
    data: Any,
    dtype=None,
    device=None,
    requires_grad: bool = False,
    pin_memory: bool = False,
) -> PNTensor

Create a PNTensor from data, similar to torch.tensor() but returning a PNTensor.

This function is a convenient way to create PNTensor objects that can be collated into NestedTensor when used with PyTorch DataLoader after importing danling.tensors. The interface mirrors torch.tensor() to make it easy to switch between regular tensors and PNTensors.

Parameters:

Name Type Description Default

data

Any

Initial data for the tensor. Can be a list, tuple, NumPy ndarray, scalar, etc.

required

dtype

Desired data type of the returned tensor.

None

device

Device on which to place the tensor.

None

requires_grad

bool

If autograd should record operations on the returned tensor.

False

pin_memory

bool

If True, the tensor will be allocated in pinned memory.

False

Returns:

Name Type Description
PNTensor PNTensor

A tensor wrapper for NestedTensor-oriented collation

Examples:

Python Console Session
1
2
3
4
>>> from danling.tensors import tensor
>>> t = tensor([1, 2, 3])
>>> t
PNTensor([1, 2, 3])
Source code in danling/tensors/pn_tensor.py
Python
def tensor(data: Any, dtype=None, device=None, requires_grad: bool = False, pin_memory: bool = False) -> PNTensor:
    r"""
    Build a `PNTensor` from `data`, with the same arguments as `torch.tensor()`.

    The data is first turned into a regular tensor by `torch.tensor()` and then
    wrapped in `PNTensor`, so that it can be collated into a NestedTensor when
    used with the PyTorch DataLoader after importing ``danling.tensors``.

    Args:
        data: Initial data for the tensor. Can be a list, tuple, NumPy ndarray, scalar, etc.
        dtype: Desired data type of the returned tensor.
        device: Device on which to place the tensor.
        requires_grad: If autograd should record operations on the returned tensor.
        pin_memory: If True, the tensor will be allocated in pinned memory.

    Returns:
        PNTensor: A tensor wrapper for NestedTensor-oriented collation

    Examples:
        >>> from danling.tensors import tensor
        >>> t = tensor([1, 2, 3])
        >>> t
        PNTensor([1, 2, 3])
    """
    plain = torch.tensor(
        data,
        dtype=dtype,
        device=device,
        requires_grad=requires_grad,
        pin_memory=pin_memory,
    )
    return PNTensor(plain)

catch

Python
catch(
    error: Exceptions = Exception,
    exclude: Exceptions | None = None,
    callback: Callable = print_exc,
    *callback_args,
    **callback_kwargs
)

Decorator to catch error except for exclude. Detailed traceback will be printed to stderr.

catch is extremely useful for non-fatal errors. For example, Runner saves checkpoints regularly, but a save can fail if the disk is full. Decorating the save method with catch lets you catch such errors and keep the run going.

Parameters:

Name Type Description Default

error

Exceptions

Exceptions to be caught.

Exception

exclude

Exceptions | None

Exceptions to be excluded.

None

callback

Callable

Callback to be called when an error occurs. The first four arguments to callback are exc, func, args, kwargs. Additional arguments should be passed with *callback_args and **callback_kwargs.

print_exc

callback_args

Arguments to be passed to callback.

()

callback_kwargs

Keyword arguments to be passed to callback.

{}

Examples:

Python Console Session
>>> def file_not_found(*args, **kwargs):
...     raise FileNotFoundError
>>> func = file_not_found
>>> func()
Traceback (most recent call last):
FileNotFoundError
>>> func = catch(OSError)(file_not_found)
>>> func()
>>> func = catch(IOError)(file_not_found)
>>> func()
>>> func = catch(ZeroDivisionError)(file_not_found)
>>> func()
Traceback (most recent call last):
FileNotFoundError
Source code in danling/utils/decorators.py
Python
@flexible_decorator
def catch(  # pylint: disable=keyword-arg-before-vararg
    error: Exceptions = Exception,
    exclude: Exceptions | None = None,
    callback: Callable = print_exc,
    *callback_args,
    **callback_kwargs,
):
    r"""
    Decorator that catches `error` (except `exclude`) and hands it to `callback`.

    `catch` is useful for non-fatal errors.
    For example, `Runner` saves checkpoints regularly, but a save can fail if the disk is full.
    Decorating the `save` method with `catch` lets you catch such errors and keep the run going.

    When an exception is caught, the wrapped call returns `None` after `callback` has run.

    Args:
        error: Exceptions to be caught.
        exclude: Exceptions to be excluded; these are re-raised.
        callback: Callback to be called when an error occurs.
            The first four arguments to `callback` are `exc`, `func`, `args`, `kwargs`.
            Additional arguments should be passed with `*callback_args` and `**callback_kwargs`.
        callback_args: Arguments to be passed to `callback`.
        callback_kwargs: Keyword arguments to be passed to `callback`.

    Examples:
        >>> def file_not_found(*args, **kwargs):
        ...     raise FileNotFoundError
        >>> func = file_not_found
        >>> func()
        Traceback (most recent call last):
        FileNotFoundError
        >>> func = catch(OSError)(file_not_found)
        >>> func()
        >>> func = catch(IOError)(file_not_found)
        >>> func()
        >>> func = catch(ZeroDivisionError)(file_not_found)
        >>> func()
        Traceback (most recent call last):
        FileNotFoundError
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):  # pylint: disable=inconsistent-return-statements
            try:
                return func(*args, **kwargs)
            except error as exc:  # pylint: disable=broad-exception-caught
                excluded = exclude is not None and isinstance(exc, exclude)
                if excluded:
                    raise exc
                callback(exc, func, args, kwargs, *callback_args, **callback_kwargs)

        return wrapper

    # Give the returned decorator the same documentation as `catch` itself.
    decorator.__doc__ = catch.__doc__

    return decorator

debug

Python
debug(
    enable: bool = True,
    error: Exceptions = Exception,
    exclude: Optional[Exceptions] = None,
)

Contextmanager to enter debug mode on error except for exclude.

debug is intended to catch an error and enter debug mode. Since it is mainly for development purposes, it takes an enable argument so that it can be deactivated.

Parameters:

Name Type Description Default

enable

bool

Whether to enable the contextmanager. Defaults to True.

True

error

Exceptions

The error to catch. Defaults to Exception.

Exception

exclude

Optional[Exceptions]

The error to exclude. Defaults to None.

None
Source code in danling/utils/context_managers.py
Python
@contextmanager
def debug(
    enable: bool = True,
    error: Exceptions = Exception,
    exclude: Optional[Exceptions] = None,
):
    """
    Context manager that drops into the post-mortem debugger on `error`, except for `exclude`.

    `debug` is intended to catch an error and enter debug mode.
    Since it is mainly for development purposes, it takes an `enable` argument so that it can be deactivated.

    Args:
        enable: Whether to enable the context manager.
            Defaults to `True`.
        error: The error to catch.
            Defaults to `Exception`.
        exclude: The error to exclude; excluded errors are re-raised.
            Defaults to `None`.
    """

    if enable:
        try:
            yield
        except error as exc:  # pylint: disable=broad-exception-caught
            if exclude is not None and isinstance(exc, exclude):
                raise exc
            # Show the exception on stderr, then open pdb at the frame that raised it.
            print(repr(exc), file=sys.stderr)
            pdb.post_mortem(exc.__traceback__)
    else:
        yield

flexible_decorator

Python
flexible_decorator(
    maybe_decorator: Optional[Callable] = None,
)

Meta decorator to allow bracket-less decorator when no arguments are passed.

Examples:

For decorator defined as follows:

Python Console Session
1
2
3
4
5
>>> @flexible_decorator
... def decorator(*args, **kwargs):
...     def wrapper(func, *args, **kwargs):
...         pass
...     return wrapper

The following two are equivalent:

Python Console Session
1
2
3
>>> @decorator
... def func(*args, **kwargs):
...     pass
Python Console Session
1
2
3
>>> @decorator()
... def func(*args, **kwargs):
...     pass
Source code in danling/utils/decorators.py
Python
def flexible_decorator(maybe_decorator: Optional[Callable] = None):
    r"""
    Meta decorator that lets a decorator factory be used with or without brackets.

    The wrapped factory is treated as "used without brackets" when it receives
    exactly one positional argument and that argument is a plain function; the
    factory is then called with the keyword arguments only and its result is
    applied to that function.

    Examples:
        For decorator defined as follows:

        >>> @flexible_decorator
        ... def decorator(*args, **kwargs):
        ...     def wrapper(func, *args, **kwargs):
        ...         pass
        ...     return wrapper

        The following two are equivalent:

        >>> @decorator
        ... def func(*args, **kwargs):
        ...     pass

        >>> @decorator()
        ... def func(*args, **kwargs):
        ...     pass
    """

    def decorator(func: Callable):
        @wraps(func)
        def wrapper(*args, **kwargs):
            used_without_brackets = len(args) == 1 and isfunction(args[0])
            if not used_without_brackets:
                return func(*args, **kwargs)
            (target,) = args
            return func(**kwargs)(target)

        return wrapper

    return decorator if maybe_decorator is None else decorator(maybe_decorator)

is_json_serializable

Python
is_json_serializable(obj: Any) -> bool

Check if obj is JSON serializable.

Source code in danling/utils/io.py
Python
def is_json_serializable(obj: Any) -> bool:
    r"""
    Check if `obj` is JSON serializable.

    The check serializes `obj` with `json.dumps` and reports whether it succeeded.

    Args:
        obj: Object to test.

    Returns:
        `True` if `json.dumps(obj)` succeeds, `False` otherwise.
    """
    try:
        json.dumps(obj)
    # Besides TypeError/OverflowError, `json.dumps` raises ValueError for circular
    # references and RecursionError for excessively deep nesting; such objects are
    # not serializable either, so report False instead of letting the error escape.
    except (TypeError, ValueError, OverflowError, RecursionError):
        return False
    return True

load

Python
load(file: PathStr, *args: Any, **kwargs: Any) -> Any

Load any file with supported extensions.

Source code in danling/utils/io.py
Python
def load(file: PathStr, *args: Any, **kwargs: Any) -> Any:
    r"""
    Load any file with supported extensions.

    The loader is selected from the lower-cased file extension (without the
    leading dot). Extra positional and keyword arguments are forwarded to the
    selected loader.

    Args:
        file: Path of the file to load.
        *args: Positional arguments forwarded to the underlying loader.
        **kwargs: Keyword arguments forwarded to the underlying loader.
            For YAML files, `Loader` defaults to `yaml.FullLoader`.

    Returns:
        Whatever the selected loader returns.

    Raises:
        ValueError: If `file` is not an existing file or its extension is not supported.
        ImportError: If the file needs torch or numpy and that package is not installed.
    """
    if not os.path.isfile(file):
        raise ValueError(f"Trying to load {file!r} but it is not a file.")
    extension = os.path.splitext(file)[-1].lower()[1:]
    # NOTE(review): `torch.load`, `pickle.load` and `yaml.load` can execute code embedded
    # in the file depending on their arguments -- only load files from trusted sources.
    if extension in PYTORCH:
        if not TORCH_AVAILABLE:
            raise ImportError(f"Trying to load {file!r} but torch is not installed.")
        return torch.load(file, *args, **kwargs)
    if extension in NUMPY:
        if not NUMPY_AVAILABLE:
            raise ImportError(f"Trying to load {file!r} but numpy is not installed.")
        return numpy.load(file, *args, **kwargs)
    if extension in JSON:
        with open(file) as fp:
            return json.load(fp, *args, **kwargs)  # type: ignore[arg-type]
    if extension in YAML:
        with open(file) as fp:
            kwargs.setdefault("Loader", yaml.FullLoader)  # type: ignore[arg-type]
            return yaml.load(fp, *args, **kwargs)  # type: ignore[arg-type]
    if extension in PICKLE:
        with open(file, "rb") as fp:
            return pickle.load(fp, *args, **kwargs)  # type: ignore[arg-type]
    if extension in PANDAS_SUPPORTED:
        return load_pandas(file, *args, **kwargs)
    raise ValueError(f"Trying to load {file!r} with unsupported extension={extension!r}")

load_pandas

Python
load_pandas(
    file: PathStr, *args: Any, **kwargs: Any
) -> Any

Load any pandas data file with supported extensions.

Source code in danling/utils/io.py
Python
def load_pandas(file: PathStr, *args: Any, **kwargs: Any) -> Any:
    r"""
    Load any pandas data file with supported extensions.

    The pandas reader is selected from the lower-cased file extension (without
    the leading dot). Extra positional and keyword arguments are forwarded to
    that reader.

    Args:
        file: Path of the file to load.
        *args: Positional arguments forwarded to the pandas reader.
        **kwargs: Keyword arguments forwarded to the pandas reader.

    Returns:
        Whatever the selected pandas reader returns.

    Raises:
        ImportError: If pandas is not installed.
        ValueError: If `file` is not an existing file or its extension is not supported.
    """
    if not PANDAS_AVAILABLE:
        raise ImportError(f"Trying to load {file!r} but pandas is not installed.")
    if not os.path.isfile(file):
        raise ValueError(f"Trying to load {file!r} but it is not a file.")
    extension = os.path.splitext(file)[-1].lower()[1:]
    if extension in PANDAS or extension in PICKLE:
        # NOTE(review): unpickling can execute code from the file -- trusted sources only.
        return pandas.read_pickle(file, *args, **kwargs)
    if extension in PARQUET:
        return pandas.read_parquet(file, *args, **kwargs)
    if extension in H5:
        return pandas.read_hdf(file, *args, **kwargs)
    if extension in CSV:
        return pandas.read_csv(file, *args, **kwargs)
    if extension in JSON:
        return pandas.read_json(file, *args, **kwargs)
    if extension in EXCEL:
        return pandas.read_excel(file, *args, **kwargs)
    if extension in XML:
        return pandas.read_xml(file, *args, **kwargs)
    if extension in SQL:
        # NOTE(review): the file path is passed as the first argument of `read_sql`;
        # confirm this is the intended usage.
        return pandas.read_sql(file, *args, **kwargs)
    raise ValueError(f"Trying to load {file!r} with unsupported extension={extension!r}")

method_cache

Python
method_cache(
    maxsize: int | None = 128, typed: bool = False
)

Decorator to cache the result of an instance method.

functools.lru_cache uses a strong reference to the instance, which will make the instance immortal and break the garbage collection.

method_cache uses a weak reference to the instance to resolve this issue.

See Also
Source code in danling/utils/decorators.py
Python
@flexible_decorator
def method_cache(maxsize: int | None = 128, typed: bool = False):
    r"""
    Decorator to cache the result of an instance method.

    `functools.lru_cache` uses a strong reference to the instance,
    which will make the instance immortal and break the garbage collection.

    `method_cache` uses a weak reference to the instance to resolve this issue.
    On the first call it builds an `lru_cache`-wrapped function for that
    instance and stores it on the instance under the method's name.

    Args:
        maxsize: Forwarded to `functools.lru_cache`.
        typed: Forwarded to `functools.lru_cache`.

    See Also:
        https://rednafi.github.io/reflections/dont-wrap-instance-methods-with-functoolslru_cache-decorator-in-python.html
    """

    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            weak_self = ref(self)

            def call_through(*call_args, **call_kwargs):
                # Dereference the weak reference at call time so the cache never holds `self` strongly.
                return func(weak_self(), *call_args, **call_kwargs)

            cached_method = wraps(func)(lru_cache(maxsize=maxsize, typed=typed)(call_through))

            # Shadow the method on the instance so later lookups find the cached function directly.
            setattr(self, func.__name__, cached_method)
            return cached_method(*args, **kwargs)

        return wrapper

    return decorator

save

Python
save(
    obj: Any, file: PathStr, *args: Any, **kwargs: Any
) -> File

Save any file with supported extensions.

Source code in danling/utils/io.py
Python
def save(obj: Any, file: PathStr, *args: Any, **kwargs: Any) -> File:
    r"""
    Save an object to a file, choosing the writer from the file extension.

    The extension of `file` (lower-cased, without the leading dot) is matched
    against the extension groups `PYTORCH`, `NUMPY`, `PANDAS`, `PARQUET`, `CSV`,
    `JSON`, `YAML` and `PICKLE`, in that order.

    Args:
        obj: Object to save.
        file: Destination path.
        *args: Extra positional arguments forwarded to the underlying writer.
        **kwargs: Extra keyword arguments forwarded to the underlying writer.

    Returns:
        The `file` argument, unchanged.

    Raises:
        ImportError: If the extension requires torch, numpy, pandas or pyarrow
            and the required package is not installed.
        NotImplementedError: If the extension is a CSV extension and `obj` is
            not a `pandas.DataFrame`.
        ValueError: If the extension is not supported.
    """
    extension = os.path.splitext(file)[-1].lower()[1:]
    if extension in PYTORCH:
        if not TORCH_AVAILABLE:
            raise ImportError(f"Trying to save {obj} to {file!r} but torch is not installed.")
        torch.save(obj, file, *args, **kwargs)
    elif extension in NUMPY:
        if not NUMPY_AVAILABLE:
            raise ImportError(f"Trying to save {obj} to {file!r} but numpy is not installed.")
        numpy.save(file, obj, *args, **kwargs)
    elif extension in PANDAS:
        if not PANDAS_AVAILABLE:
            raise ImportError(f"Trying to save {obj} to {file!r} but pandas is not installed.")
        pandas.to_pickle(obj, file, *args, **kwargs)
    elif extension in PARQUET:
        # Guard on PANDAS_AVAILABLE so a missing pandas does not break the
        # pyarrow path (or mask the ImportError below) when touching `pandas`.
        if PANDAS_AVAILABLE and isinstance(obj, pandas.DataFrame):
            obj.to_parquet(file, *args, **kwargs)
        elif not PYARROW_AVAILABLE:
            raise ImportError(f"Trying to save {obj} to {file!r} but pyarrow is not installed.")
        else:
            pyarrow.parquet.write_table(obj, file, *args, **kwargs)
    elif extension in CSV:
        # Same guard as above: without pandas, `obj` cannot be a DataFrame.
        if PANDAS_AVAILABLE and isinstance(obj, pandas.DataFrame):
            obj.to_csv(file, *args, **kwargs)
        else:
            raise NotImplementedError(f"Trying to save {obj} to {file!r} but is not supported")
    elif extension in JSON:
        if isinstance(obj, FlatDict):
            # NOTE(review): extra args/kwargs are not forwarded here — confirm this is intended.
            obj.json(file)
        else:
            with open(file, "w") as fp:
                json.dump(obj, fp, *args, **kwargs)  # type: ignore[arg-type]
    elif extension in YAML:
        if isinstance(obj, FlatDict):
            # NOTE(review): extra args/kwargs are not forwarded here — confirm this is intended.
            obj.yaml(file)
        else:
            with open(file, "w") as fp:
                yaml.dump(obj, fp, *args, **kwargs)  # type: ignore[arg-type, call-overload]
    elif extension in PICKLE:
        with open(file, "wb") as fp:
            pickle.dump(obj, fp, *args, **kwargs)  # type: ignore[arg-type]
    else:
        raise ValueError(f"Trying to save {obj} to {file!r} with unsupported extension={extension!r}")
    return file