Skip to content

Utils

danling.metric.utils

MetricsDict

Bases: DefaultDict

A MetricsDict that provides better support for AverageMeters.

Source code in danling/metric/utils.py
Python
class MetricsDict(DefaultDict):
    r"""
    Flat mapping of metric meters (e.g. `AverageMeters`) with dict-wide accessors.

    Each accessor queries every stored meter and gathers the results into a
    `RoundDict` keyed the same way as this mapping.
    """

    def _map_metrics(self, method: str) -> RoundDict[str, float]:
        # Call the accessor named `method` on each stored meter, in iteration order.
        results = {}
        for name, meter in self.all_items():
            results[name] = getattr(meter, method)()
        return RoundDict(results)

    def value(self) -> RoundDict[str, float]:
        r"""Return `meter.value()` for every stored meter."""
        return self._map_metrics("value")

    def batch(self) -> RoundDict[str, float]:
        r"""Return `meter.batch()` for every stored meter."""
        return self._map_metrics("batch")

    def average(self) -> RoundDict[str, float]:
        r"""Return `meter.average()` for every stored meter."""
        return self._map_metrics("average")

    @property
    def val(self) -> RoundDict[str, float]:
        r"""Shorthand for `value()`."""
        return self.value()

    @property
    def bat(self) -> RoundDict[str, float]:
        r"""Shorthand for `batch()`."""
        return self.batch()

    @property
    def avg(self) -> RoundDict[str, float]:
        r"""Shorthand for `average()`."""
        return self.average()

    def reset(self) -> Self:
        r"""Call `reset()` on every stored meter and return `self` for chaining."""
        for meter in self.all_values():
            meter.reset()
        return self

    def __format__(self, format_spec: str) -> str:
        # One "name: formatted-meter" pair per entry, joined by tabs on a single line.
        pieces = [f"{name}: {meter.__format__(format_spec)}" for name, meter in self.all_items()]
        return "\t".join(pieces)

MultiTaskDict

Bases: NestedDict

A MultiTaskDict that provides better support for multi-task metrics.

Source code in danling/metric/utils.py
Python
class MultiTaskDict(NestedDict):
    r"""
    A `MultiTaskDict` for better multi-task support.

    Each leaf of this nested mapping holds one task's metrics container. A leaf must
    expose `value()`, `batch()`, `average()` (each returning a mapping of metric
    name to number) and `reset()`.
    NOTE(review): presumably each leaf is a `MetricsDict`-like object; confirm against callers.

    Args:
        return_average: If `True`, `value()`, `batch()` and `average()` additionally
            include an `"average"` entry that averages each metric name across tasks
            (see `compute_average`).
    """

    return_average = False

    def __init__(self, *args, return_average: bool = False, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # NOTE(review): `setattr` is used instead of `self.return_average = ...`,
        # presumably so the flag is stored as an object attribute rather than as a
        # mapping entry by the base class; confirm.
        self.setattr("return_average", return_average)

    def _collect_metrics(self, method: str) -> RoundDict[str, float]:
        r"""
        Call `method` on every task's metrics and gather the results.

        Tasks whose results are all NaN are omitted. This includes tasks with no
        results at all, because `all()` of an empty iterable is `True`. When
        `return_average` is enabled, an `"average"` entry produced by
        `compute_average` is added after all tasks are collected.

        Args:
            method: Name of the per-task metrics method to call
                (`"value"`, `"batch"` or `"average"`).

        Returns:
            A `RoundDict` mapping task keys to that task's metric results.
        """
        output = RoundDict()
        for key, metrics in self.all_items():
            result = getattr(metrics, method)()
            if all(isnan(v) for v in result.values()):
                continue
            output[key] = result
        if self.getattr("return_average", False):
            # Computed before the "average" key exists, so it only sees real tasks.
            output["average"] = self.compute_average(output)
        return output

    def value(self) -> RoundDict[str, float]:
        r"""Return every task's `value()`, skipping all-NaN tasks."""
        return self._collect_metrics("value")

    def batch(self) -> RoundDict[str, float]:
        r"""Return every task's `batch()`, skipping all-NaN tasks."""
        return self._collect_metrics("batch")

    def average(self) -> RoundDict[str, float]:
        r"""Return every task's `average()`, skipping all-NaN tasks."""
        return self._collect_metrics("average")

    def compute_average(self, output: RoundDict[str, float]) -> RoundDict[str, float]:
        r"""
        Average each metric name across tasks.

        Leaf keys of `output` are grouped by their last `.`-separated component
        (the metric name). The arithmetic mean of each group is returned.

        Args:
            output: Per-task metric results, as built by `value()`, `batch()` or `average()`.

        Returns:
            A `RoundDict` mapping each metric name to its mean over the tasks that report it.
        """
        grouped: dict = {}
        for key, metric in output.all_items():
            grouped.setdefault(key.rsplit(".", 1)[-1], []).append(metric)
        return RoundDict({key: sum(values) / len(values) for key, values in grouped.items()})

    @property
    def val(self) -> RoundDict[str, float]:
        r"""Shorthand for `value()`."""
        return self.value()

    @property
    def bat(self) -> RoundDict[str, float]:
        r"""Shorthand for `batch()`."""
        return self.batch()

    @property
    def avg(self) -> RoundDict[str, float]:
        r"""Shorthand for `average()`."""
        return self.average()

    def reset(self) -> Self:
        r"""Call `reset()` on every task's metrics and return `self` for chaining."""
        for metric in self.all_values():
            metric.reset()
        return self

    def __format__(self, format_spec: str) -> str:
        # One "task: formatted-metrics" line per task.
        return "\n".join(f"{key}: {metric.__format__(format_spec)}" for key, metric in self.all_items())