
Functional

danling.metric.functional

base_preprocess

Python
base_preprocess(
    input: Tensor | NestedTensor | Sequence,
    target: Tensor | NestedTensor | Sequence,
    ignore_index: int | None = None,
    ignore_nan: bool = False,
)

Basic preprocessing function for metric inputs and targets.

This function handles common preprocessing tasks for metric computation:

1. Converting inputs/targets to tensors or nested tensors
2. Handling nested tensors by concatenating them
3. Removing ignored indices (e.g., padding)
4. Removing NaN values
5. Properly reshaping tensors for metric functions

Parameters:

input (Tensor | NestedTensor | Sequence):
    Model predictions or outputs. Can be a Tensor, NestedTensor, or a sequence that can be converted to a tensor. Required.
target (Tensor | NestedTensor | Sequence):
    Ground truth labels/values. Can be a Tensor, NestedTensor, or a sequence that can be converted to a tensor. Required.
ignore_index (int | None):
    Value in target to ignore (e.g., -100 for padding in classification tasks). Default: None.
ignore_nan (bool):
    Whether to remove NaN values (useful for regression tasks). Default: False.

Returns:

tuple:
    (processed_input, processed_target) as tensors ready for metric computation.

Examples:

Python Console Session
>>> # Basic usage with tensors
>>> input = torch.tensor([0.1, 0.8, 0.6, 0.2])
>>> target = torch.tensor([0, 1, 1, 0])
>>> proc_input, proc_target = base_preprocess(input, target)
>>> proc_input.shape, proc_target.shape
(torch.Size([4]), torch.Size([4]))
Python Console Session
>>> # Ignoring a specific value
>>> input = torch.tensor([0.1, 0.8, 0.6, 0.2])
>>> target = torch.tensor([0, -100, 1, 0])
>>> proc_input, proc_target = base_preprocess(input, target, ignore_index=-100)
>>> proc_input.shape, proc_target.shape
(torch.Size([3]), torch.Size([3]))
Python Console Session
>>> # Working with lists
>>> input = [0.1, 0.8, 0.6, 0.2]
>>> target = [0, 1, 1, 0]
>>> proc_input, proc_target = base_preprocess(input, target)
>>> proc_input.shape, proc_target.shape
(torch.Size([4]), torch.Size([4]))
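Two further sketches: removing NaN targets with ignore_nan (e.g., regression with missing labels), and ragged nested lists, which are assumed to fall back to NestedTensor and be concatenated, as the conversion branch in the source below suggests; the shapes shown follow that logic.

Python Console Session
>>> # Removing NaN targets (useful for regression)
>>> input = torch.tensor([0.1, 0.8, 0.6, 0.2])
>>> target = torch.tensor([0.0, float("nan"), 1.0, 0.0])
>>> proc_input, proc_target = base_preprocess(input, target, ignore_nan=True)
>>> proc_input.shape, proc_target.shape
(torch.Size([3]), torch.Size([3]))
>>> # Ragged lists are assumed to be converted to NestedTensor and concatenated
>>> input = [[0.1, 0.8], [0.6, 0.2, 0.9]]
>>> target = [[0, 1], [1, 0, 1]]
>>> proc_input, proc_target = base_preprocess(input, target)
>>> proc_input.shape, proc_target.shape
(torch.Size([5]), torch.Size([5]))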
Source code in danling/metric/functional/preprocess.py
Python
def base_preprocess(
    input: Tensor | NestedTensor | Sequence,
    target: Tensor | NestedTensor | Sequence,
    ignore_index: int | None = None,
    ignore_nan: bool = False,
):
    """
    Basic preprocessing function for metric inputs and targets.

    This function handles common preprocessing tasks for metric computation:
    1. Converting inputs/targets to tensors or nested tensors
    2. Handling nested tensors by concatenating them
    3. Removing ignored indices (e.g., padding)
    4. Removing NaN values
    5. Properly reshaping tensors for metric functions

    Args:
        input: Model predictions or outputs
            - Can be Tensor, NestedTensor, or a sequence that can be converted to tensor
        target: Ground truth labels/values
            - Can be Tensor, NestedTensor, or a sequence that can be converted to tensor
        ignore_index: Value in target to ignore (e.g., -100 for padding in classification tasks)
        ignore_nan: Whether to remove NaN values (useful for regression tasks)

    Returns:
        tuple: (processed_input, processed_target) as tensors ready for metric computation

    Examples:
        >>> # Basic usage with tensors
        >>> input = torch.tensor([0.1, 0.8, 0.6, 0.2])
        >>> target = torch.tensor([0, 1, 1, 0])
        >>> proc_input, proc_target = base_preprocess(input, target)
        >>> proc_input.shape, proc_target.shape
        (torch.Size([4]), torch.Size([4]))

        >>> # Ignoring a specific value
        >>> input = torch.tensor([0.1, 0.8, 0.6, 0.2])
        >>> target = torch.tensor([0, -100, 1, 0])
        >>> proc_input, proc_target = base_preprocess(input, target, ignore_index=-100)
        >>> proc_input.shape, proc_target.shape
        (torch.Size([3]), torch.Size([3]))

        >>> # Working with lists
        >>> input = [0.1, 0.8, 0.6, 0.2]
        >>> target = [0, 1, 1, 0]
        >>> proc_input, proc_target = base_preprocess(input, target)
        >>> proc_input.shape, proc_target.shape
        (torch.Size([4]), torch.Size([4]))
    """
    if not isinstance(input, (Tensor, NestedTensor)):
        try:
            input = torch.tensor(input)
        except ValueError:
            input = NestedTensor(input)
    if not isinstance(target, (Tensor, NestedTensor)):
        try:
            target = torch.tensor(target)
        except ValueError:
            target = NestedTensor(target)
    if isinstance(input, NestedTensor) or isinstance(target, NestedTensor):
        if isinstance(input, NestedTensor) and isinstance(target, Tensor):
            target = input.nested_like(target, strict=False)
        if isinstance(target, NestedTensor) and isinstance(input, Tensor):
            input = target.nested_like(input, strict=False)
        input, target = input.concat, target.concat
    if ignore_index is not None:
        mask = target != ignore_index
        input, target = input[mask], target[mask]
    if ignore_nan:
        mask = ~(torch.isnan(target))
        input, target = input[mask], target[mask]
    if input.numel() == target.numel():
        return input.squeeze(), target.squeeze()
    return input, target

with_preprocess

Python
with_preprocess(preprocess_fn: Callable, **default_kwargs)

Decorator to apply preprocessing to metric functions.

This decorator wraps metric functions to handle preprocessing of inputs before the metric is calculated. It handles common preprocessing tasks like ignoring certain values or converting input formats, making the metric functions more robust and easier to use.

Parameters:

preprocess_fn (Callable):
    The preprocessing function to apply. Required.
**default_kwargs:
    Default values for the preprocessing function's parameters. Default: {}.

Returns:

    Decorated function with preprocessing capability.

Examples:

Python Console Session
>>> import torch
>>> @with_preprocess(base_preprocess, ignore_index=-100)
... def my_regression_metric(input, target):
...     # Assumes input and target are already preprocessed
...     return input.mean() / target.mean()
Python Console Session
>>> # When called, preprocessing is automatically applied
>>> result = my_regression_metric(torch.rand(4), torch.rand(4))
Python Console Session
>>> # Preprocessing can be disabled
>>> result = my_regression_metric(torch.rand(4), torch.rand(4), preprocess=False)
Python Console Session
>>> # Additional preprocessing parameters can be passed
>>> result = my_regression_metric(torch.rand(4), torch.rand(4), ignore_index=-1)
Notes
  • The wrapper extracts parameters needed for preprocessing from **kwargs
  • If preprocess=False is passed, no preprocessing is applied
  • All other parameters are passed through to the metric function
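A minimal end-to-end sketch, combining the decorator with base_preprocess from above (the my_accuracy helper is hypothetical): ignore_index is consumed by preprocessing, so the metric only sees the non-ignored positions.

Python Console Session
>>> @with_preprocess(base_preprocess, ignore_index=-100)
... def my_accuracy(input, target):  # hypothetical metric, for illustration only
...     return (input.round() == target).float().mean()
>>> input = torch.tensor([0.1, 0.9, 0.7, 0.2])
>>> target = torch.tensor([0, 1, -100, 0])
>>> my_accuracy(input, target)  # the -100 position is dropped before the metric runs
tensor(1.)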
Source code in danling/metric/functional/preprocess.py
Python
def with_preprocess(preprocess_fn: Callable, **default_kwargs):
    """
    Decorator to apply preprocessing to metric functions.

    This decorator wraps metric functions to handle preprocessing of inputs before the metric is calculated.
    It handles common preprocessing tasks like ignoring certain values or converting input formats,
    making the metric functions more robust and easier to use.

    Args:
        preprocess_fn: The preprocessing function to apply
        **default_kwargs: Default values for the preprocessing function's parameters

    Returns:
        Decorated function with preprocessing capability

    Examples:
        >>> import torch
        >>> @with_preprocess(base_preprocess, ignore_index=-100)
        ... def my_regression_metric(input, target):
        ...     # Assumes input and target are already preprocessed
        ...     return input.mean() / target.mean()

        >>> # When called, preprocessing is automatically applied
        >>> result = my_regression_metric(torch.rand(4), torch.rand(4))

        >>> # Preprocessing can be disabled
        >>> result = my_regression_metric(torch.rand(4), torch.rand(4), preprocess=False)

        >>> # Additional preprocessing parameters can be passed
        >>> result = my_regression_metric(torch.rand(4), torch.rand(4), ignore_index=-1)

    Notes:
        - The wrapper extracts parameters needed for preprocessing from **kwargs
        - If preprocess=False is passed, no preprocessing is applied
        - All other parameters are passed through to the metric function
    """
    preprocess_params = set(signature(preprocess_fn).parameters.keys()) - {"input", "target"}

    def decorator(metric_fn: Callable) -> Callable:

        @wraps(metric_fn)
        def wrapper(
            input: Tensor | NestedTensor | Sequence,
            target: Tensor | NestedTensor | Sequence,
            *,
            preprocess: bool = True,
            **kwargs,
        ):
            metric_kwargs = {k: v for k, v in kwargs.items() if k not in default_kwargs}

            if not preprocess:
                return metric_fn(input, target, **metric_kwargs)

            preprocess_kwargs = copy(default_kwargs)
            for key in preprocess_params:
                if key in kwargs:
                    preprocess_kwargs[key] = kwargs[key]
            input, target = preprocess_fn(input, target, **preprocess_kwargs)
            return metric_fn(input, target, **metric_kwargs)

        return wrapper

    return decorator