跳转至

Utils

danling.metrics.utils

MetersBase

Bases: DefaultDict

Base container for collections of meter objects.

Subclasses can provide a meter_cls attribute to enforce the type of values stored in the dictionary and customise how callable objects are converted into meters.

Source code in danling/metrics/utils.py
Python
class MetersBase(DefaultDict):
    r"""Dictionary of named meters that share a common read-out interface.

    A subclass may set `meter_cls` to restrict what can be stored: once
    construction has finished, every value passed to `set` must be an
    instance of that class.  `meter_cls` also acts as the default factory
    when no explicit `default_factory` is supplied.
    """

    meter_cls: Optional[type] = None

    def __init__(self, *args: Mapping[str, Any] | None, default_factory=None, **meters: Any) -> None:
        r"""Create the container and register any initial meters.

        Args:
            *args: At most one positional mapping of name -> meter (or `None`).
            default_factory: Factory for missing keys; falls back to `meter_cls`
                when `None`.
            **meters: Additional name -> meter pairs; these override entries of
                the positional mapping that use the same name.

        Raises:
            TypeError: If more than one positional argument is given.
        """
        # While this flag is raised, `set` stores values verbatim, so whatever
        # the parent initialiser writes is not passed through `_coerce_meter`.
        self.__dict__["_metricsdict_initialising"] = True
        fallback = getattr(self, "meter_cls", None)
        super().__init__(default_factory=fallback if default_factory is None else default_factory)
        self.__dict__["_metricsdict_initialising"] = False
        # Drop a "meter_cls" item if one ended up among the dictionary entries
        # (presumably inserted by the parent initialiser -- confirm upstream).
        dict.pop(self, "meter_cls", None)

        if len(args) > 1:
            raise TypeError("MetersBase accepts at most one positional mapping argument.")
        pending: dict[str, Any] = {}
        if args and args[0] is not None:
            pending.update(dict(args[0]))
        # Keyword meters are merged last, so they win on duplicate names.
        pending.update(meters)
        for key, entry in pending.items():
            self.set(key, entry)

    def set(self, name: Any, value: Any) -> None:
        r"""Store `value` under `name`, validating it unless still initialising."""
        if not self.__dict__.get("_metricsdict_initialising", False):
            value = self._coerce_meter(value)
        super().set(name, value)

    def _coerce_meter(self, value: Any):
        r"""Return `value` unchanged if it is acceptable for this container.

        Raises:
            ValueError: If `meter_cls` is set and `value` is not an instance of it.
        """
        expected = getattr(self, "meter_cls", None)
        if expected is not None and not isinstance(value, expected):
            raise ValueError(f"Expected value to be an instance of {expected.__name__}, but got {type(value)}")
        return value

    def value(self) -> RoundDict[str, float]:
        r"""Map each name from `all_items()` to the result of its meter's `value()`."""
        readings = {}
        for name, meter in self.all_items():
            readings[name] = meter.value()
        return RoundDict(readings)

    def batch(self) -> RoundDict[str, float]:
        r"""Map each name from `all_items()` to the result of its meter's `batch()`."""
        readings = {}
        for name, meter in self.all_items():
            readings[name] = meter.batch()
        return RoundDict(readings)

    def average(self) -> RoundDict[str, float]:
        r"""Map each name from `all_items()` to the result of its meter's `average()`."""
        readings = {}
        for name, meter in self.all_items():
            readings[name] = meter.average()
        return RoundDict(readings)

    @property
    def val(self) -> RoundDict[str, float]:
        r"""Property shorthand for `value()`."""
        return self.value()

    @property
    def bat(self) -> RoundDict[str, float]:
        r"""Property shorthand for `batch()`."""
        return self.batch()

    @property
    def avg(self) -> RoundDict[str, float]:
        r"""Property shorthand for `average()`."""
        return self.average()

    def reset(self) -> Self:
        r"""Call `reset()` on every stored meter and return this container."""
        for meter in self.all_values():
            meter.reset()
        return self

    def __format__(self, format_spec: str) -> str:
        r"""Format every meter with `format_spec`, as tab-separated `name: text` pairs."""
        parts = [f"{name}: {meter.__format__(format_spec)}" for name, meter in self.all_items()]
        return "\t".join(parts)

MultiTaskBase

Bases: NestedDict

Container that groups meters for multiple tasks and aggregates them.

Source code in danling/metrics/utils.py
Python
class MultiTaskBase(NestedDict):
    r"""
    Container that groups meters for multiple tasks and aggregates them.

    Each entry yielded by `all_items()` is expected to expose `value()`,
    `batch()`, `average()`, `reset()` and `__format__`.  When `return_average`
    is enabled, the aggregated outputs gain an extra `"average"` entry produced
    by `compute_average`.
    """

    return_average = False

    def __init__(self, *args, return_average: bool = False, **kwargs) -> None:
        r"""Build the container.

        Args:
            *args: Forwarded to the parent initialiser.
            return_average: Whether aggregated outputs include an `"average"` entry.
            **kwargs: Forwarded to the parent initialiser.
        """
        super().__init__(*args, **kwargs)
        # Stored through `setattr` so the flag is kept as an attribute rather
        # than as a dictionary item.
        self.setattr("return_average", return_average)

    def _collect_task_outputs(self, method: str) -> RoundDict[str, float]:
        r"""Call `method` on every task entry and gather the results.

        Shared implementation behind `value`, `batch` and `average`.

        Args:
            method: Name of the zero-argument method to call on each entry.

        Returns:
            Mapping of task key -> that task's result.  A task is skipped when
            every value in its result is NaN; an empty result is skipped too,
            because `all()` over no values is true.  If `return_average` is
            set, an `"average"` entry computed from the kept results is added.
        """
        output = RoundDict()
        for key, metrics in self.all_items():
            result = getattr(metrics, method)()
            if all(isnan(v) for v in result.values()):
                continue
            output[key] = result
        if self.getattr("return_average", False):
            # Computed before the assignment, so "average" never feeds into itself.
            output["average"] = self.compute_average(output)
        return output

    def value(self) -> RoundDict[str, float]:
        r"""Gather each task's `value()`; see `_collect_task_outputs` for the rules."""
        return self._collect_task_outputs("value")

    def batch(self) -> RoundDict[str, float]:
        r"""Gather each task's `batch()`; see `_collect_task_outputs` for the rules."""
        return self._collect_task_outputs("batch")

    def average(self) -> RoundDict[str, float]:
        r"""Gather each task's `average()`; see `_collect_task_outputs` for the rules."""
        return self._collect_task_outputs("average")

    def compute_average(self, output: RoundDict[str, float]) -> RoundDict[str, float]:
        r"""Average same-named metrics across tasks.

        Entries from `output.all_items()` are grouped by the last dot-separated
        segment of their key, and each group is reduced to its arithmetic mean.
        No NaN filtering happens here.

        Args:
            output: Per-task results as gathered by `value`, `batch` or `average`.

        Returns:
            Mapping of metric name -> mean over all entries sharing that name.
        """
        average = DefaultDict(default_factory=list)
        for key, metric in output.all_items():
            average[key.rsplit(".", 1)[-1]].append(metric)
        return RoundDict({key: sum(values) / len(values) for key, values in average.items()})

    @property
    def val(self) -> RoundDict[str, float]:
        r"""Property shorthand for `value()`."""
        return self.value()

    @property
    def bat(self) -> RoundDict[str, float]:
        r"""Property shorthand for `batch()`."""
        return self.batch()

    @property
    def avg(self) -> RoundDict[str, float]:
        r"""Property shorthand for `average()`."""
        return self.average()

    def reset(self) -> Self:
        r"""Call `reset()` on every entry from `all_values()` and return this container."""
        for metric in self.all_values():
            metric.reset()
        return self

    def __format__(self, format_spec: str) -> str:
        r"""Format every entry with `format_spec`, one `key: text` line per entry."""
        return "\n".join(f"{key}: {metric.__format__(format_spec)}" for key, metric in self.all_items())