
Metrics

danling.metrics provides metric containers and metric descriptors for large-scale training. The design is exact-by-default while keeping a lighter streaming path available for hot training loops.

Design Summary

  • Exact by default: factory functions return GlobalMetrics unless mode="stream" is set.
  • Shared state: metric descriptors declare required artifacts (preds/targets, confmat) so containers build them once.
  • Symmetric API: GlobalMetrics and StreamMetrics share the same constructor signature.
  • Extensible: users can provide custom MetricFunc implementations (or plain callables for StreamMetrics).
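The shared-state idea can be sketched in plain Python. Everything below is hypothetical and written only for illustration (`Descriptor`, `build_confmat`, `evaluate`, and the `"confmat"` requirement label are not DanLing APIs). The point is that two metrics which both need a confusion matrix trigger a single build.

```python
class Descriptor:
    """Hypothetical descriptor: a name, the artifacts it needs, and a scoring function."""

    def __init__(self, name, requires, fn):
        self.name = name
        self.requires = frozenset(requires)
        self.fn = fn


build_calls = {"confmat": 0}  # counts how often the shared artifact is built


def build_confmat(preds, targets):
    """Binary confusion counts, built once per evaluation."""
    build_calls["confmat"] += 1
    pairs = list(zip(preds, targets))
    return {
        "tp": sum(1 for p, t in pairs if p == 1 and t == 1),
        "fp": sum(1 for p, t in pairs if p == 1 and t == 0),
        "fn": sum(1 for p, t in pairs if p == 0 and t == 1),
        "tn": sum(1 for p, t in pairs if p == 0 and t == 0),
    }


def evaluate(descriptors, preds, targets):
    """Build each required artifact once, then let every descriptor read from it."""
    needed = set().union(*(d.requires for d in descriptors))
    state = {}
    if "confmat" in needed:
        state["confmat"] = build_confmat(preds, targets)
    return {d.name: d.fn(state) for d in descriptors}


precision = Descriptor(
    "precision", {"confmat"}, lambda s: s["confmat"]["tp"] / (s["confmat"]["tp"] + s["confmat"]["fp"])
)
recall = Descriptor(
    "recall", {"confmat"}, lambda s: s["confmat"]["tp"] / (s["confmat"]["tp"] + s["confmat"]["fn"])
)

scores = evaluate([precision, recall], preds=[1, 1, 0, 1], targets=[1, 0, 0, 1])
```

Both descriptors read the same `state["confmat"]`, so adding further confusion-matrix metrics in this sketch would not add further passes over the data.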

Core Components

  • GlobalMetrics
      • Stores exact artifacts for global (dataset-level) computation.
      • Computes values from the shared MetricState (danling.metrics.MetricState).
      • bat synchronizes the current-step exact state: reduced current-step confusion matrices when sufficient, gathered current-step preds/targets otherwise.
      • Performs distributed synchronization lazily in average().
      • In distributed exact mode, descriptors that require preds/targets gather full artifacts on average(), so this path is best reserved for eval/reporting rather than hot training-loop logging.
  • StreamMetrics
      • Computes streaming scores online and tracks running averages.
      • Uses the same metric descriptors and preprocess contract as GlobalMetrics.
      • Metrics are evaluated once per update; batch-vs-sample semantics are determined by the metric itself.
      • Suitable for high-throughput training loops.
  • MetricMeter
      • Single-metric streaming meter used internally by StreamMetrics.
  • METRICS registry (danling.metrics.METRICS)
      • Task factory registry with explicit mode.
  • MultiTaskMetrics
      • Flat task container for multi-head / multi-dataset evaluation.
      • Aggregates matching metric paths with a plain mean across tasks.
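Why a reduction is sometimes sufficient and a gather is sometimes required can be shown without DanLing. The sketch below is plain Python with hypothetical helper names (`confusion_matrix`, `add_confmats`, `accuracy_from_confmat`); it is not how DanLing implements synchronization. It shows that per-rank confusion matrices sum to the confusion matrix of the combined data, so count-based metrics such as accuracy only need an elementwise reduction. Rank-based metrics such as AUROC depend on the ordering of all scores, which is why they need the gathered preds/targets.

```python
def confusion_matrix(preds, targets, num_classes):
    """Return a num_classes x num_classes count matrix indexed [target][pred]."""
    mat = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(preds, targets):
        mat[t][p] += 1
    return mat


def add_confmats(a, b):
    """Elementwise sum, standing in for a distributed all-reduce."""
    return [[x + y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]


def accuracy_from_confmat(mat):
    """Accuracy is the diagonal mass divided by the total count."""
    correct = sum(mat[i][i] for i in range(len(mat)))
    total = sum(sum(row) for row in mat)
    return correct / total if total else float("nan")


# Two hypothetical ranks, each holding its own shard of predictions.
rank0_preds, rank0_targets = [0, 1, 1, 0], [0, 1, 0, 0]
rank1_preds, rank1_targets = [1, 1, 0], [1, 0, 0]

# Path 1: reduce the small per-rank matrices.
reduced = add_confmats(
    confusion_matrix(rank0_preds, rank0_targets, 2),
    confusion_matrix(rank1_preds, rank1_targets, 2),
)

# Path 2: gather every prediction and target, then build one matrix.
gathered = confusion_matrix(
    rank0_preds + rank1_preds, rank0_targets + rank1_targets, 2
)

reduced_accuracy = accuracy_from_confmat(reduced)
gathered_accuracy = accuracy_from_confmat(gathered)
```

Both paths give the same matrix here, but the reduced path communicates a fixed `num_classes x num_classes` payload per rank, while the gathered path grows with the number of samples.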

Quick Start

Exact Metrics (Default)

Python
import torch
import danling as dl

metrics = dl.metrics.binary_metrics()  # mode="global" by default -> GlobalMetrics
metrics.update(torch.randn(32), torch.randint(2, (32,)))

print(metrics.val)  # last update
print(metrics.avg)  # exact average over all accumulated state

Streaming Metrics

Python
import torch
import danling as dl

metrics = dl.metrics.multiclass_metrics(num_classes=10, mode="stream")  # -> StreamMetrics
metrics.update(torch.randn(64, 10), torch.randint(10, (64,)))

print(metrics.val)  # current batch metric
print(metrics.avg)  # running average

StreamMetrics semantics:

  • val is the local value for the most recent update.
  • bat is the synchronized current-step metric.
  • avg is a sample-count-weighted running average.
  • Metrics are evaluated once per update.
  • Stream metrics preserve tensor outputs and average them elementwise across batches.
  • Plain callables receive preprocessed input / target tensors; MetricFunc descriptors receive MetricState.
  • Stream metrics with the same names as exact global metrics may still be running approximations rather than exact dataset-level values.
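The last point is easy to underestimate, so here is a minimal sketch of it in plain Python. `RunningMeter` and `precision` are hypothetical stand-ins written for this example, not DanLing classes. The meter keeps a sample-count-weighted running average of per-batch precision, which is then compared with precision computed once over the pooled data.

```python
class RunningMeter:
    """Hypothetical meter: keeps the last value and a sample-count-weighted mean."""

    def __init__(self):
        self.val = float("nan")
        self.total = 0.0
        self.count = 0

    def update(self, value, n):
        self.val = value
        self.total += value * n
        self.count += n

    @property
    def avg(self):
        return self.total / self.count if self.count else float("nan")


def precision(preds, targets):
    """Binary precision: true positives / predicted positives (0.0 if none predicted)."""
    tp = sum(1 for p, t in zip(preds, targets) if p == 1 and t == 1)
    predicted_pos = sum(preds)
    return tp / predicted_pos if predicted_pos else 0.0


batches = [
    ([1, 1, 1, 1], [1, 1, 1, 0]),  # batch precision 3/4
    ([1, 0, 0, 0], [0, 0, 0, 1]),  # batch precision 0/1
]

meter = RunningMeter()
for preds, targets in batches:
    meter.update(precision(preds, targets), n=len(preds))

all_preds = [p for preds, _ in batches for p in preds]
all_targets = [t for _, targets in batches for t in targets]

streamed = meter.avg                       # (0.75 * 4 + 0.0 * 4) / 8 = 0.375
exact = precision(all_preds, all_targets)  # 3 true positives / 5 predicted positives = 0.6
```

The two numbers differ because precision is a ratio whose denominator varies from batch to batch. Metrics that are plain per-sample means (accuracy, MSE) do not have this gap under sample-count weighting.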

Global vs Stream

| Aspect | GlobalMetrics | StreamMetrics |
| --- | --- | --- |
| Default factory mode | mode="global" | mode="stream" |
| State | Stores full required artifacts | Stores running meter stats |
| Sync pattern | bat() syncs current-step exact state; average() syncs accumulated exact state | bat() syncs current-step metric; average() syncs running stats |
| Typical use | Exact eval, AUROC/AUPRC/correlation | Fast training logs |
| Memory | Higher | Lower |

Shared Constructor Contract

GlobalMetrics and StreamMetrics intentionally share this signature:

Python
(*metric_funcs, preprocess=..., distributed=True, device=None, **metrics)

Rules:

  • Positional *metric_funcs can be metric descriptors (or iterables of descriptors). StreamMetrics also accepts plain callables.
  • Keyword **metrics are named metrics and override positional metrics with the same name.
  • preprocess is applied once per update.
  • device controls where internal artifacts/stat reductions live.
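The first two rules can be pictured with a small resolver. This is a sketch under assumptions, not DanLing's constructor: `resolve_metrics` and `Named` are hypothetical, and it assumes positional descriptors are keyed by a `name` attribute.

```python
class Named:
    """Hypothetical descriptor stand-in that only carries a name and an origin tag."""

    def __init__(self, name, origin):
        self.name = name
        self.origin = origin


def resolve_metrics(*metric_funcs, **metrics):
    """Flatten positionals by name, then let keyword metrics win on name clashes."""
    resolved = {}
    for item in metric_funcs:
        # A positional may be one descriptor or an iterable of descriptors.
        group = item if isinstance(item, (list, tuple)) else [item]
        for func in group:
            resolved[func.name] = func
    resolved.update(metrics)  # keyword metrics override same-named positionals
    return resolved


resolved = resolve_metrics(
    Named("acc", "positional"),
    [Named("f1", "positional"), Named("mcc", "positional")],
    acc=Named("acc", "keyword"),
)
```

In this sketch `resolved["acc"]` is the keyword version, while `f1` and `mcc` come from the positional iterable.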

Factory Functions

All factories accept:

  • mode="global" | "stream" ("global" default)
  • *metric_funcs: if provided, defaults are replaced
  • **metrics: named extra metrics (or overrides)
  • task-specific arguments (num_classes, num_labels, num_outputs, ignore_index, etc.)

danling.metrics.functional.classification is kept as a thin convenience layer for one-shot DanLing metric calls. These wrappers apply DanLing preprocessing first (for example: nested-tensor alignment, shape normalization, ignore_index filtering, and probability normalization where applicable), then forward to the corresponding TorchMetrics functional implementation. Extra keyword arguments are still forwarded, but they operate on the preprocessed tensors and therefore must be compatible with the resulting shapes. Container-facing code should prefer MetricFunc descriptors, which let GlobalMetrics and StreamMetrics build shared state once.

Example:

Python
import danling as dl
from danling.metrics.functional import binary_precision, binary_recall

# Keep defaults and add metrics
metrics = dl.metrics.binary_metrics(
    mode="global",
    precision=binary_precision(),
    recall=binary_recall(),
)

# Replace defaults completely
metrics_only_pr = dl.metrics.binary_metrics(
    binary_precision(),
    binary_recall(),
    mode="global",
)

Default Metric Sets

Factories keep defaults minimal:

  • Binary / Multiclass / Multilabel: auroc, auprc, acc, f1, mcc
  • Regression: pearson, spearman, r2, mse, rmse

Additional built-ins (opt-in):

  • Classification: precision, recall, fbeta, specificity, balanced_accuracy, jaccard, iou, hamming_loss
  • Regression: mae

multiclass_accuracy also supports top-k via k. For multiclass classification, balanced_accuracy is the class-balanced recall and only supports the standard definition: average="macro" with k=1.
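As a worked example of the standard definition, balanced accuracy is the unweighted mean of per-class recall. The function below is a plain-Python illustration with a hypothetical name and signature, not the DanLing or TorchMetrics implementation. Skipping classes that never appear in the targets is an assumption of this sketch.

```python
def balanced_accuracy(preds, targets, num_classes):
    """Macro-averaged recall: mean over classes of hits_c / support_c."""
    recalls = []
    for c in range(num_classes):
        support = sum(1 for t in targets if t == c)
        if support == 0:
            continue  # assumption: classes absent from targets are skipped
        hits = sum(1 for p, t in zip(preds, targets) if t == c and p == c)
        recalls.append(hits / support)
    return sum(recalls) / len(recalls) if recalls else float("nan")


# Imbalanced toy data: eight samples of class 0, two of class 1.
targets = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
preds = [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]

plain_accuracy = sum(p == t for p, t in zip(preds, targets)) / len(targets)  # 9 / 10
balanced = balanced_accuracy(preds, targets, num_classes=2)  # (8/8 + 1/2) / 2
```

Plain accuracy is 0.9 because the majority class dominates, while balanced accuracy is 0.75 because the missed minority sample halves the recall of class 1.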

Custom Metric Descriptor (MetricFunc)

For consistent behavior across both containers, implement MetricFunc and read from MetricState.

Python
import torch
from danling.metrics import GlobalMetrics, StreamMetrics
from danling.metrics.functional import MetricFunc

class MeanBias(MetricFunc):
    def __init__(self, name: str = "mean_bias") -> None:
        super().__init__(name=name, preds_targets=True)

    def __call__(self, state):
        if state.preds.numel() == 0 or state.targets.numel() == 0:
            return torch.tensor(float("nan"))
        return (state.preds - state.targets).mean()

metric = MeanBias()

global_metrics = GlobalMetrics(metric, distributed=False)
stream_metrics = StreamMetrics(metric)

pred = torch.randn(16)
target = torch.randn(16)

global_metrics.update(pred, target)
stream_metrics.update(pred, target)

Multi-Task Usage

Python
import torch
import danling as dl

metrics = dl.metrics.MultiTaskMetrics()
metrics.cls = dl.metrics.binary_metrics(mode="stream")
metrics.reg = dl.metrics.regression_metrics(num_outputs=4, mode="global", distributed=False)

metrics.update(
    {
        "cls": (torch.randn(32), torch.randint(2, (32,))),
        "reg": (torch.randn(32, 4), torch.randn(32, 4)),
    }
)

print(metrics.avg)

Pass aggregate="macro" if you want equal task weighting, aggregate="micro" if you want sample-count weighting, or aggregate="weighted" together with aggregate_weights={"task": weight, ...} if you want explicit task weights. Aggregate outputs match metrics by exact relative metric path, so tasks with different metric namespaces stay separate rather than being merged by leaf name alone.
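The difference between the three options is only in the weights applied to each task. The sketch below is plain Python with a hypothetical `aggregate` helper. It assumes the weights are normalized to sum to one, which may not match DanLing's exact handling.

```python
def aggregate(values, counts, mode="macro", weights=None):
    """Combine one metric across tasks; all arguments are dicts keyed by task name."""
    if mode == "macro":
        w = {task: 1.0 for task in values}  # equal task weighting
    elif mode == "micro":
        w = {task: float(counts[task]) for task in values}  # sample-count weighting
    elif mode == "weighted":
        if weights is None:
            raise ValueError("weighted aggregation needs explicit weights")
        w = {task: float(weights[task]) for task in values}
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    total = sum(w.values())
    return sum(values[task] * w[task] for task in values) / total


# The same metric path reported by two hypothetical tasks of very different size.
acc = {"small": 0.9, "large": 0.6}
n = {"small": 100, "large": 900}

macro = aggregate(acc, n, "macro")  # (0.9 + 0.6) / 2
micro = aggregate(acc, n, "micro")  # (0.9 * 100 + 0.6 * 900) / 1000
weighted = aggregate(acc, n, "weighted", weights={"small": 3, "large": 1})  # (2.7 + 0.6) / 4
```

With these numbers the macro value is 0.75, the micro value is 0.63, and the explicitly weighted value is 0.825, so the choice matters most when task sizes are very uneven.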

Registry Usage

Python
from danling.metrics import METRICS

metrics = METRICS.build(type="multiclass", mode="stream", num_classes=10)