class MultiTaskBase(FlatDict):
    r"""
    Container that groups meters for multiple tasks and aggregates them.

    Each reduction (`value`, `batch`, `average`) is collected from every contained
    meter; entries whose values are all NaN are left out of the result. When
    `aggregate` is set, an extra `"aggregate"` entry combines the per-task outputs
    for each metric path:

    - `"macro"`: unweighted mean across tasks.
    - `"micro"`: mean weighted by each task's sample counts (the `n` attribute for
      the `value` and `batch` reductions, `count` for `average`).
    - `"weighted"`: mean weighted by the user-supplied `aggregate_weights`.
    """
def __init__(
    self,
    *args,
    aggregate: Literal["macro", "micro", "weighted"] | None = None,
    aggregate_weights: Mapping[str, float | int | Tensor] | None = None,
    **kwargs,
) -> None:
    r"""
    Build the container and validate the aggregation settings.

    Args:
        *args: Positional arguments forwarded to the parent constructor.
        aggregate: Aggregation mode, one of `None`, `"macro"`, `"micro"` or `"weighted"`.
        aggregate_weights: Per-task weights. Required when `aggregate="weighted"` and
            rejected otherwise. Task names are converted to `str` and every weight is
            validated with `_normalize_weight_value`.
        **kwargs: Keyword arguments forwarded to the parent constructor.

    Raises:
        ValueError: If `aggregate` is not a supported mode, or if `aggregate_weights`
            is missing for `"weighted"` / supplied for any other mode.
    """
    super().__init__(*args, **kwargs)
    if aggregate not in {None, "macro", "micro", "weighted"}:
        raise ValueError(f"aggregate must be one of None, 'macro', 'micro', or 'weighted', but got {aggregate!r}")

    wants_weights = aggregate == "weighted"
    has_weights = aggregate_weights is not None
    if wants_weights and not has_weights:
        raise ValueError("aggregate_weights is required when aggregate='weighted'")
    if has_weights and not wants_weights:
        raise ValueError("aggregate_weights is only supported when aggregate='weighted'")

    # The mode is stored before the weights are normalized, as normalization may raise.
    self.setattr("_aggregate", aggregate)

    normalized_weights = None
    if has_weights:
        normalized_weights = {}
        for task, raw_weight in aggregate_weights.items():
            normalized_weights[str(task)] = self._normalize_weight_value(
                raw_weight, label=f"aggregate weight for task {task!r}"
            )
    self.setattr("_aggregate_weights", normalized_weights)
# Normalization helpers
@classmethod
def _normalize_metric_output(
    cls, name: str, metrics: Any, value: Mapping[str, float | Tensor] | float | Tensor
) -> RoundDict[str, float | Tensor]:
    r"""
    Wrap a meter's reduction result in a `RoundDict`.

    Args:
        name: Key under which the meter is stored; used as the fallback output name.
        metrics: The meter that produced `value`; its optional `output_name`
            attribute names a non-mapping result.
        value: Result of the reduction, either a mapping or a single value.

    Returns:
        `value` itself when it already is a `RoundDict`, a `RoundDict` copy of any
        other mapping, or a single-entry `RoundDict` for a plain value.
    """
    if isinstance(value, Mapping):
        return value if isinstance(value, RoundDict) else RoundDict(value)

    entry_name = name
    declared = getattr(metrics, "output_name", None)
    if declared is not None:
        declared = str(declared)
        # Empty or auto-generated callable names fall back to the task key.
        if declared not in {"", "<lambda>", "__call__"}:
            entry_name = declared
    return RoundDict({entry_name: value})
@classmethod
def _flatten_mapping(cls, mapping: Mapping[str, Any]) -> dict[str, Any]:
if hasattr(mapping, "all_items"):
return dict(mapping.all_items())
flat: dict[str, Any] = {}
for key, value in mapping.items():
path = str(key)
if isinstance(value, Mapping):
for nested_key, nested_value in cls._flatten_mapping(value).items():
flat[f"{path}.{nested_key}"] = nested_value
else:
flat[path] = value
return flat
@staticmethod
def _is_nan_value(value: Any) -> bool:
if isinstance(value, Tensor):
return bool(value.isnan().all().item())
try:
return isnan(value)
except TypeError:
return False
@staticmethod
def _to_averageable(value: Any, *, label: str) -> float | Tensor:
if isinstance(value, Tensor):
return value.detach().to(dtype=torch.float64)
return float(value)
@staticmethod
def _normalize_weight_value(value: Any, *, label: str) -> float:
if isinstance(value, Tensor):
if value.numel() != 1:
raise ValueError(f"{label} must be scalar, but got shape {tuple(value.shape)}")
value = float(value.item())
else:
value = float(value)
if not isfinite(value) or value < 0:
raise ValueError(f"{label} must be a non-negative finite scalar, but got {value!r}")
return value
@staticmethod
def _reduction_weight_attr(reduction: Reduction) -> str:
if reduction in {"value", "batch"}:
return "n"
if reduction == "average":
return "count"
raise ValueError(f"Unsupported reduction: {reduction!r}")
@classmethod
def _weight_for_path(
cls,
task_name: str,
path: str,
weight_source: Mapping[str, Any] | Any,
*,
label_prefix: str,
) -> float:
if isinstance(weight_source, Mapping):
flat_weights = cls._flatten_mapping(weight_source)
if path not in flat_weights:
raise ValueError(f"{label_prefix} is missing a weight for metric '{task_name}.{path}'")
return cls._normalize_weight_value(
flat_weights[path], label=f"{label_prefix} for metric '{task_name}.{path}'"
)
return cls._normalize_weight_value(weight_source, label=f"{label_prefix} for task {task_name!r}")
@staticmethod
def _sync_weights(weights: list[float]) -> list[float]:
if not weights or not (dist.is_available() and dist.is_initialized()):
return weights
device = infer_device()
reduced = torch.tensor(weights, dtype=torch.float64, device=device)
dist.all_reduce(reduced)
return reduced.tolist()
# Public reductions
def _collect_output(self, reduction: Reduction) -> RoundDict[str, float | Tensor]:
    r"""
    Gather one reduction from every contained meter.

    Args:
        reduction: Name of the meter method to call on each meter.

    Returns:
        A `RoundDict` keyed like the container. Meters whose normalized output
        consists only of NaN values are omitted. When an aggregate mode is
        configured, the result of `compute_aggregate` is stored under `"aggregate"`.
    """
    collected = RoundDict()
    for key, meter in self.all_items():
        reduce_fn = getattr(meter, reduction)
        normalized = self._normalize_metric_output(key, meter, reduce_fn())
        if not all(self._is_nan_value(item) for item in normalized.all_values()):
            collected[key] = normalized
    if self.getattr("_aggregate", None) is not None:
        collected["aggregate"] = self.compute_aggregate(collected, reduction)
    return collected
def value(self) -> RoundDict[str, float | Tensor]:
return self._collect_output("value")
def batch(self) -> RoundDict[str, float | Tensor]:
return self._collect_output("batch")
def average(self) -> RoundDict[str, float | Tensor]:
return self._collect_output("average")
def compute_aggregate(
self,
output: RoundDict[str, float | Tensor],
reduction: Reduction,
) -> RoundDict[str, float | Tensor]:
aggregate = self.getattr("_aggregate", None)
if aggregate is None:
return RoundDict()
if aggregate == "macro":
return self.compute_average(output)
if aggregate == "micro":
return self.compute_weighted_average(output, reduction=reduction, mode="micro")
return self.compute_weighted_average(output, reduction=reduction, mode="weighted")
def compute_average(self, output: RoundDict[str, float | Tensor]) -> RoundDict[str, float | Tensor]:
    r"""
    Compute the unweighted mean of every metric path across tasks.

    The `"aggregate"` entry of `output` and NaN values are ignored. Each task's
    output is flattened first, so metrics are matched by flattened path.

    Args:
        output: Per-task outputs keyed by task name.

    Returns:
        A `RoundDict` mapping each metric path to its mean over the tasks that
        reported a non-NaN value for it.

    Raises:
        ValueError: If a path mixes tensor and scalar values across tasks, or
            its tensors have different shapes.
    """
    sums: dict[str, float | Tensor] = {}
    seen: dict[str, int] = {}
    for task_name, task_output in output.items():
        if task_name == "aggregate":
            continue
        for path, raw in self._flatten_mapping(task_output).items():
            if self._is_nan_value(raw):
                continue
            current = self._to_averageable(raw, label=f"metric '{task_name}.{path}'")
            current_is_tensor = isinstance(current, Tensor)
            if path in sums:
                running = sums[path]
                running_is_tensor = isinstance(running, Tensor)
                if current_is_tensor and not running_is_tensor:
                    raise ValueError(f"metric '{path}' mixes scalar and tensor outputs across tasks")
                if running_is_tensor and not current_is_tensor:
                    raise ValueError(f"metric '{path}' mixes tensor and scalar outputs across tasks")
                if current_is_tensor and running.shape != current.shape:
                    raise ValueError(
                        f"metric '{path}' has inconsistent tensor shapes across tasks: "
                        f"{tuple(running.shape)} vs {tuple(current.shape)}"
                    )
                sums[path] = running + current
            else:
                # First occurrence: tensors are cloned before being stored as the running sum.
                sums[path] = current.clone() if current_is_tensor else current
            seen[path] = seen.get(path, 0) + 1

    average = RoundDict()
    for path, running in sums.items():
        average[path] = running / seen[path]
    return average
def compute_weighted_average(
    self,
    output: RoundDict[str, float | Tensor],
    *,
    reduction: Reduction,
    mode: Literal["micro", "weighted"],
) -> RoundDict[str, float | Tensor]:
    r"""
    Compute a weighted mean of every metric path across tasks.

    Args:
        output: Per-task outputs keyed by task name; the `"aggregate"` entry is ignored.
        reduction: The reduction that produced `output`. In `"micro"` mode it selects
            the meter attribute used as the sample count (see `_reduction_weight_attr`).
        mode: `"micro"` weights each task by the sample count read from its meter;
            `"weighted"` uses the weights stored in `_aggregate_weights`.

    Returns:
        A `RoundDict` mapping each metric path to its weighted mean. NaN values and
        entries whose weight is not positive do not contribute.

    Raises:
        ValueError: If weights are missing or name unknown tasks (`"weighted"`), if a
            meter lacks the sample-count attribute (`"micro"`), if a weight fails
            validation, or if a path mixes tensors and scalars or tensor shapes.
    """
    if mode == "weighted":
        task_weights = self.getattr("_aggregate_weights", None)
        if task_weights is None:
            raise ValueError("aggregate_weights is required when aggregate='weighted'")
        # Weights for tasks that are not in the container are rejected.
        unknown_tasks = set(task_weights) - {str(name) for name in self.keys()}
        if unknown_tasks:
            raise ValueError(f"aggregate_weights contains unknown tasks: {sorted(unknown_tasks)!r}")
    else:
        task_weights = None
    # Phase 1: collect one (path, value, weight) entry per non-NaN metric.
    weighted_entries: list[tuple[str, float | Tensor, float]] = []
    for task_name, task_output in output.items():
        if task_name == "aggregate":
            continue
        flat_output = self._flatten_mapping(task_output)
        if mode == "micro":
            # Sample counts are read from the meter stored under the task name.
            weight_attr = self._reduction_weight_attr(reduction)
            if not hasattr(self[task_name], weight_attr):
                raise ValueError(
                    f"micro aggregate requires task {task_name!r} to expose {weight_attr!r} sample counts"
                )
            weight_source = getattr(self[task_name], weight_attr)
            label_prefix = "sample weight"
        else:
            if task_name not in task_weights:
                raise ValueError(f"aggregate_weights is missing a weight for task {task_name!r}")
            weight_source = task_weights[task_name]
            label_prefix = "aggregate weight"
        for path, value in flat_output.items():
            if self._is_nan_value(value):
                continue
            averageable_value = self._to_averageable(value, label=f"metric '{task_name}.{path}'")
            weight = self._weight_for_path(task_name, path, weight_source, label_prefix=label_prefix)
            weighted_entries.append((path, averageable_value, weight))
    # Phase 2: in micro mode, batch/average sample counts go through `_sync_weights`.
    # NOTE(review): the sync is positional, so it presumably relies on every process
    # building the same entries in the same order -- confirm against callers.
    if mode == "micro" and reduction in {"batch", "average"}:
        synced_weights = self._sync_weights([weight for _, _, weight in weighted_entries])
    else:
        synced_weights = [weight for _, _, weight in weighted_entries]
    # Phase 3: accumulate weighted sums and total weights per metric path.
    totals: dict[str, float | Tensor] = {}
    weight_totals: dict[str, float] = {}
    for (path, averageable_value, _), weight in zip(weighted_entries, synced_weights):
        # Entries without positive weight contribute nothing to the mean.
        if weight <= 0:
            continue
        if path not in totals:
            totals[path] = averageable_value * weight
        else:
            total = totals[path]
            if isinstance(averageable_value, Tensor):
                if not isinstance(total, Tensor):
                    raise ValueError(f"metric '{path}' mixes scalar and tensor outputs across tasks")
                if total.shape != averageable_value.shape:
                    raise ValueError(
                        f"metric '{path}' has inconsistent tensor shapes across tasks: "
                        f"{tuple(total.shape)} vs {tuple(averageable_value.shape)}"
                    )
                totals[path] = total + averageable_value * weight
            else:
                if isinstance(total, Tensor):
                    raise ValueError(f"metric '{path}' mixes tensor and scalar outputs across tasks")
                totals[path] = total + averageable_value * weight
        weight_totals[path] = weight_totals.get(path, 0.0) + weight
    # Phase 4: divide each weighted sum by its total weight.
    average = RoundDict()
    for path, total in totals.items():
        total_weight = weight_totals[path]
        if total_weight > 0:
            average[path] = total / total_weight
    return average
# Public aliases
@property
def val(self) -> RoundDict[str, float | Tensor]:
return self.value()
@property
def bat(self) -> RoundDict[str, float | Tensor]:
return self.batch()
@property
def avg(self) -> RoundDict[str, float | Tensor]:
return self.average()
# Lifecycle
def reset(self) -> Self:
for metric in self.all_values():
metric.reset()
return self
# Formatting helpers
def __format__(self, format_spec: str) -> str:
return "\n".join(f"{key}: {metric.__format__(format_spec)}" for key, metric in self.all_items())