MetricMeter

danling.metrics.MetricMeter

Bases: AverageMeter

A memory-efficient metric tracker that computes and averages metrics across batches.

MetricMeter applies a metric function to each batch and maintains running averages without storing the complete history of predictions and labels. This makes it ideal for metrics that can be meaningfully averaged across batches (like accuracy or loss).

Attributes:

  • metric (Callable | MetricFunc): The metric function to compute on each batch
  • preprocess: Optional preprocessing function applied before the metric
  • val (float | Tensor): Result from the most recent batch on the current rank
  • bat (float | Tensor): Synchronized metric result for the current step
  • avg (float | Tensor): Weighted average of all results so far
  • sum (float | Tensor): Running sum of (metric × batch_size) values
  • count (int): Running sum of batch sizes
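
The relationship between val, sum, count, and avg can be sketched in plain Python. This is only an illustration of the bookkeeping described above, not DanLing's implementation; the variable names `batches` and `total` are made up for the sketch, and the numbers are the per-batch accuracies and batch sizes from the console session in Examples.

```python
# Illustrative bookkeeping only; this is not DanLing's code.
# Each entry is (metric value for the batch, batch size).
batches = [(0.75, 4), (0.5, 6)]

total = 0.0  # plays the role of `sum`: running sum of metric * batch_size
count = 0    # plays the role of `count`: running sum of batch sizes
avg = float("nan")  # undefined until at least one batch has been seen

for value, n in batches:
    total += value * n
    count += n
    avg = total / count  # weighted average over everything seen so far
    print(f"val={value} sum={total} count={count} avg={avg}")
```

Weighting each batch by its size is what keeps the average correct when batches differ in length, as the last batch of an epoch often does.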

Parameters:

  • metric (Callable | MetricFunc, required): Function that computes a metric given input and target tensors
  • preprocess (Callable | None, default: None): Optional preprocessing function to apply before computing the metric
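
A rough sketch of the preprocess contract, based on the `update` source shown further down this page: preprocess is called with `(input, target)` and must return an `(input, target)` pair, which is then handed to the metric. The helper names `apply_update`, `threshold`, and `match_rate` are hypothetical stand-ins, and plain lists replace tensors so the snippet runs without DanLing or PyTorch. The real class additionally wraps the inputs in a MetricState, which this sketch skips.

```python
def apply_update(metric, input, target, preprocess=None):
    """Hypothetical helper that mirrors the order of operations in update()."""
    if preprocess is not None:
        # preprocess receives both arguments and must return a pair
        input, target = preprocess(input, target)
    return metric(input, target)


def threshold(scores, labels, cutoff=0.5):
    """Stand-in preprocess: turn scores into 0/1 predictions, pass labels through."""
    return [int(s > cutoff) for s in scores], labels


def match_rate(preds, labels):
    """Stand-in metric: fraction of predictions equal to their label."""
    return sum(p == t for p, t in zip(preds, labels)) / len(labels)


value = apply_update(match_rate, [0.1, 0.8, 0.6, 0.2], [0, 1, 0, 0], preprocess=threshold)
print(value)  # 0.75
```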

Examples:

Python Console Session
>>> import torch
>>> from danling.metrics.functional import accuracy
>>> meter = MetricMeter(accuracy)
>>> meter.update(torch.tensor([0.1, 0.8, 0.6, 0.2]), torch.tensor([0, 1, 0, 0]))
>>> meter.val
0.75
>>> meter.avg
0.75
>>> meter.update(torch.tensor([0.1, 0.7, 0.3, 0.2, 0.8, 0.4]), torch.tensor([0, 1, 1, 0, 0, 1]))
>>> meter.val
0.5
>>> meter.avg
0.6
>>> meter.sum
6.0
>>> meter.count
10
>>> meter.reset()
MetricMeter(accuracy)
>>> meter.val
nan
>>> meter.avg
nan
Notes
  • MetricMeter is more memory-efficient than GlobalMetrics because it only stores running statistics
  • Only suitable for metrics that can be meaningfully averaged batch-by-batch
  • Not suitable for metrics like AUROC that need the entire dataset
  • Metrics are evaluated once per update; batch-vs-sample semantics are determined by the metric itself
  • Stream metrics may return tensors; tensor outputs are averaged elementwise across batches
  • MetricFunc descriptors receive a MetricState object
  • Plain callables receive preprocessed input / target tensors
  • For multiple metrics, use StreamMetrics
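
To make the "meaningfully averaged batch-by-batch" condition concrete, the following plain-Python sketch compares a size-weighted average of per-batch values against the value computed on the whole dataset. Precision is used here as a simpler stand-in for metrics such as AUROC. All function names and the toy data are assumptions made for the illustration, not part of DanLing.

```python
def accuracy(preds, targets):
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


def precision(preds, targets):
    true_pos = sum(p == 1 and t == 1 for p, t in zip(preds, targets))
    pred_pos = sum(p == 1 for p in preds)
    return true_pos / pred_pos


def batch_weighted(metric, batches):
    """Average per-batch values weighted by batch size (the MetricMeter-style estimate)."""
    total = sum(metric(p, t) * len(p) for p, t in batches)
    count = sum(len(p) for p, _ in batches)
    return total / count


def whole_dataset(metric, batches):
    """Concatenate all batches first, then evaluate the metric once."""
    preds = [x for p, _ in batches for x in p]
    targets = [x for _, t in batches for x in t]
    return metric(preds, targets)


# Two toy batches of (predictions, targets)
batches = [([1, 1, 0, 0], [1, 0, 0, 0]), ([1, 0, 0, 0], [1, 1, 0, 0])]

print(batch_weighted(accuracy, batches), whole_dataset(accuracy, batches))    # equal
print(batch_weighted(precision, batches), whole_dataset(precision, batches))  # differ
```

Accuracy is a per-sample mean, so the weighted batch average reproduces the dataset-level value. Precision divides by the number of predicted positives, which varies from batch to batch, so the two estimates diverge.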
See Also
  • AverageMeter: A lightweight utility to compute and store running averages of values.
Source code in danling/metrics/stream_metrics.py
Python
class MetricMeter(AverageMeter):
    r"""
    A memory-efficient metric tracker that computes and averages metrics across batches.

    MetricMeter applies a metric function to each batch and maintains running averages
    without storing the complete history of predictions and labels. This makes it ideal for
    metrics that can be meaningfully averaged across batches (like accuracy or loss).

    Attributes:
        metric: The metric function to compute on each batch
        preprocess: Optional preprocessing function applied before the metric
        val: Result from the most recent batch on the current rank
        bat: Synchronized metric result for the current step
        avg: Weighted average of all results so far
        sum: Running sum of (metric × batch_size) values
        count: Running sum of batch sizes

    Args:
        metric: Function that computes a metric given input and target tensors
        preprocess: Optional preprocessing function to apply before computing the metric

    Examples:
        >>> import torch
        >>> from danling.metrics.functional import accuracy
        >>> meter = MetricMeter(accuracy)
        >>> meter.update(torch.tensor([0.1, 0.8, 0.6, 0.2]), torch.tensor([0, 1, 0, 0]))
        >>> meter.val
        0.75
        >>> meter.avg
        0.75
        >>> meter.update(torch.tensor([0.1, 0.7, 0.3, 0.2, 0.8, 0.4]), torch.tensor([0, 1, 1, 0, 0, 1]))
        >>> meter.val
        0.5
        >>> meter.avg
        0.6
        >>> meter.sum
        6.0
        >>> meter.count
        10
        >>> meter.reset()
        MetricMeter(accuracy)
        >>> meter.val
        nan
        >>> meter.avg
        nan

    Notes:
        - MetricMeter is more memory-efficient than [`GlobalMetrics`][danling.metrics.global_metrics.GlobalMetrics]
          because it only stores running statistics
        - Only suitable for metrics that can be meaningfully averaged batch-by-batch
        - Not suitable for metrics like AUROC that need the entire dataset
        - Metrics are evaluated once per update; batch-vs-sample semantics are determined by the metric itself
        - Stream metrics may return tensors; tensor outputs are averaged elementwise across batches
        - `MetricFunc` descriptors receive [`MetricState`][danling.metrics.MetricState]
        - Plain callables receive preprocessed `input` / `target` tensors
        - For multiple metrics, use [`StreamMetrics`][danling.metrics.stream_metrics.StreamMetrics]

    See Also:
        - [`AverageMeter`][danling.metrics.average_meter.AverageMeter]:
            A lightweight utility to compute and store running averages of values.
    """

    metric: Callable | MetricFunc
    output_name: str | None = None

    # Construction
    def __init__(
        self,
        metric: Callable | MetricFunc,
        *,
        preprocess: Callable | None = None,
        device: torch.device | str | None = None,
        distributed: bool = True,
    ) -> None:
        super().__init__(device=device, distributed=distributed)
        if not callable(metric):
            raise ValueError(f"Expected metric to be callable, but got {type(metric)}")
        self.metric = metric
        self.preprocess = preprocess
        self._requirements = MetricState.collect_requirements((metric,)) if isinstance(metric, MetricFunc) else None

    # Mutation
    def update(  # type: ignore[override] # pylint: disable=W0237
        self,
        input: Tensor | NestedTensor,  # pylint: disable=W0622
        target: Tensor | NestedTensor,
        *,
        n: int | None = None,
    ) -> None:
        r"""
        Updates the average and current value in the meter.

        Args:
            input: Prediction tensor or nested tensor.
            target: Ground-truth tensor or nested tensor.
            n: Optional number of samples represented by this update. When omitted,
                the batch size is inferred from the inputs.
        """

        if self.preprocess is not None:
            input, target = self.preprocess(input, target)

        self._update_state(self._build_state(input, target), n=n)

    # Internal helpers
    def _update_state(self, state: MetricState, *, n: int | None = None) -> None:
        if n is None:
            try:
                n = len(state.preds)
            except TypeError:
                n = 1

        value = self._compute_metric(state)
        super().update(value=value, n=n)

    def _compute_metric(self, state: MetricState) -> Tensor | float:
        if isinstance(self.metric, MetricFunc):
            return self._normalize_value(self.metric(state))
        return self._normalize_value(self.metric(state.preds, state.targets))

    def _build_state(
        self,
        input: Tensor | NestedTensor | Sequence,
        target: Tensor | NestedTensor | Sequence,
    ) -> MetricState:
        return MetricState.from_requirements(input, target, self._requirements)

    @staticmethod
    def _normalize_value(value: Tensor | float | int) -> Tensor | float:
        if isinstance(value, Tensor):
            if value.numel() == 0:
                return torch.tensor(float("nan"))
            if value.numel() == 1:
                return value.item()
            return value.detach()
        return float(value)

    def __repr__(self):
        metric = self.metric
        if isinstance(metric, MetricFunc):
            return f"{self.__class__.__name__}({metric.name})"
        if isinstance(metric, partial):
            metric = metric.func
        return f"{self.__class__.__name__}({metric.__name__})"

update

Python
update(
    input: Tensor | NestedTensor,
    target: Tensor | NestedTensor,
    *,
    n: int | None = None
) -> None

Applies preprocess (if set), computes the metric on the given batch, and updates the current value and the running average.

Parameters:

  • input (Tensor | NestedTensor, required): Prediction tensor or nested tensor.
  • target (Tensor | NestedTensor, required): Ground-truth tensor or nested tensor.
  • n (int | None, default: None): Optional number of samples represented by this update. When omitted, the batch size is inferred from the inputs.
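
How n is inferred when omitted can be sketched as follows. This mirrors the fallback in the `_update_state` source shown earlier on this page (use `len()` of the predictions, or 1 if they have no length); `infer_n` is a hypothetical standalone name, and plain Python values stand in for tensors.

```python
def infer_n(preds, n=None):
    """Hypothetical helper: pick the weight used for one update."""
    if n is None:
        try:
            n = len(preds)  # first-dimension length, i.e. the batch size
        except TypeError:
            n = 1  # objects without a length count as a single sample
    return n


print(infer_n([0.1, 0.8, 0.6, 0.2]))  # 4
print(infer_n(0.5))                   # 1
print(infer_n([0.1, 0.8], n=10))      # 10, an explicit n always wins
```

Passing n explicitly may be useful when the first dimension of the input is not the number of samples that the metric value actually represents.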
Source code in danling/metrics/stream_metrics.py
Python
def update(  # type: ignore[override] # pylint: disable=W0237
    self,
    input: Tensor | NestedTensor,  # pylint: disable=W0622
    target: Tensor | NestedTensor,
    *,
    n: int | None = None,
) -> None:
    r"""
    Updates the average and current value in the meter.

    Args:
        input: Prediction tensor or nested tensor.
        target: Ground-truth tensor or nested tensor.
        n: Optional number of samples represented by this update. When omitted,
            the batch size is inferred from the inputs.
    """

    if self.preprocess is not None:
        input, target = self.preprocess(input, target)

    self._update_state(self._build_state(input, target), n=n)