class MultiTaskDict(NestedDict):
    r"""
    A nested container of per-task metric objects with multi-task aggregation.

    Every leaf reached through ``all_items()`` is called with ``value()``, ``batch()``,
    ``average()`` and ``reset()``. The first three are expected to return a mapping of
    metric name to number (their ``.values()`` are checked with ``isnan``).

    Args:
        return_average: If ``True``, ``value()``, ``batch()`` and ``average()`` additionally
            include an ``"average"`` entry. It holds, for each metric name, the mean of that
            metric across all reported tasks (see :meth:`compute_average`).
    """

    return_average = False

    def __init__(self, *args, return_average: bool = False, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # NOTE(review): stored through NestedDict's `setattr` helper, presumably so the flag
        # is kept as an attribute rather than a nested key — confirm against NestedDict.
        self.setattr("return_average", return_average)

    def _collect(self, method: str) -> RoundDict[str, float]:
        # Shared body of `value` / `batch` / `average`: call the named method on every leaf
        # metric and gather the non-empty results under the leaf's key.
        output = RoundDict()
        for key, metrics in self.all_items():
            value = getattr(metrics, method)()
            # Skip tasks whose results are all NaN. `all()` is also True for an empty
            # mapping, so tasks reporting nothing are skipped too.
            if all(isnan(v) for v in value.values()):
                continue
            output[key] = value
        if self.getattr("return_average", False):
            output["average"] = self.compute_average(output)
        return output

    def value(self) -> RoundDict[str, float]:
        """Return each task's ``value()`` results, plus ``"average"`` if enabled."""
        return self._collect("value")

    def batch(self) -> RoundDict[str, float]:
        """Return each task's ``batch()`` results, plus ``"average"`` if enabled."""
        return self._collect("batch")

    def average(self) -> RoundDict[str, float]:
        """Return each task's ``average()`` results, plus ``"average"`` if enabled."""
        return self._collect("average")

    def compute_average(self, output: RoundDict[str, float]) -> RoundDict[str, float]:
        """
        Average each metric across tasks.

        Leaves of ``output`` are grouped by the last dotted segment of their key (presumably
        ``"<task>.<metric>"`` paths, so the group key is the metric name). Each group becomes
        the arithmetic mean of its values. A NaN in any task propagates into that metric's
        mean.

        Args:
            output: The per-task results gathered by ``value`` / ``batch`` / ``average``.

        Returns:
            A ``RoundDict`` mapping metric name to its cross-task mean.
        """
        average = DefaultDict(default_factory=list)
        for key, metric in output.all_items():
            average[key.rsplit(".", 1)[-1]].append(metric)
        return RoundDict({key: sum(values) / len(values) for key, values in average.items()})

    @property
    def val(self) -> RoundDict[str, float]:
        """Shorthand for :meth:`value`."""
        return self.value()

    @property
    def bat(self) -> RoundDict[str, float]:
        """Shorthand for :meth:`batch`."""
        return self.batch()

    @property
    def avg(self) -> RoundDict[str, float]:
        """Shorthand for :meth:`average`."""
        return self.average()

    def reset(self) -> Self:
        """Call ``reset()`` on every leaf metric and return ``self`` for chaining."""
        for metric in self.all_values():
            metric.reset()
        return self

    def __format__(self, format_spec: str) -> str:
        # One line per leaf: "<key>: <leaf formatted with format_spec>".
        return "\n".join(f"{key}: {format(metric, format_spec)}" for key, metric in self.all_items())