Skip to content

TorchFuncRegistry

danling.tensors.TorchFuncRegistry

Bases: Registry

TorchFuncRegistry for extending PyTorch Tensor.

Source code in danling/tensors/utils.py
Python
class TorchFuncRegistry(Registry):  # pylint: disable=too-few-public-methods
    """
    `TorchFuncRegistry` for extending PyTorch Tensor.

    Entries are keyed by the torch function object itself (e.g. `torch.mean`);
    the value is the custom implementation registered through `implement`.
    """

    def implement(self, torch_function: Callable) -> Callable:
        r"""
        Register an implementation for a torch function.

        Args:
            torch_function: The torch function (e.g. `torch.mean`) to implement.

        Returns:
            A decorator that registers the decorated function as the implementation
            of `torch_function` and returns the decorated function unchanged.

        Raises:
            ValueError: If `torch_function` is already registered and `TorchFuncRegistry.override=False`.

        Examples:
            >>> import torch
            >>> registry = TorchFuncRegistry("test")
            >>> @registry.implement(torch.mean)
            ... def mean(input):
            ...     return input.mean()
            >>> registry  # doctest: +ELLIPSIS
            TorchFuncRegistry(
              (<built-in method mean of type object at ...>): <function mean at ...>
            )
        """

        # The duplicate check runs eagerly, when `implement(...)` is called,
        # not later when the returned decorator is applied.
        if torch_function in self and not self.override:
            # Not every callable defines `__name__` (e.g. `functools.partial`);
            # fall back to its repr so the documented ValueError is still raised.
            name = getattr(torch_function, "__name__", repr(torch_function))
            raise ValueError(f"Torch function {name} already registered.")

        # Copies `self.register`'s metadata (name, docstring, `__wrapped__`) onto the decorator.
        @wraps(self.register)
        def register(function):
            self.set(torch_function, function)
            return function

        return register

implement

Python
implement(torch_function: Callable) -> Callable

Register an implementation for a torch function.

Parameters:

Name Type Description Default

torch_function

The torch function to implement.

required

Returns:

Name Type Description
function Callable

A decorator that registers the decorated function as the implementation and returns it unchanged.

Raises:

Type Description
ValueError

If the torch function is already registered and TorchFuncRegistry.override=False.

Examples:

Python Console Session
>>> import torch
>>> registry = TorchFuncRegistry("test")
>>> @registry.implement(torch.mean)
... def mean(input):
...     return input.mean()
>>> registry
TorchFuncRegistry(
  (<built-in method mean of type object at ...>): <function mean at ...>
)
Source code in danling/tensors/utils.py
Python
def implement(self, torch_function: Callable) -> Callable:
    r"""
    Register an implementation for a torch function.

    Args:
        torch_function: The torch function (e.g. `torch.mean`) to implement.

    Returns:
        A decorator that registers the decorated function as the implementation
        of `torch_function` and returns the decorated function unchanged.

    Raises:
        ValueError: If `torch_function` is already registered and `TorchFuncRegistry.override=False`.

    Examples:
        >>> import torch
        >>> registry = TorchFuncRegistry("test")
        >>> @registry.implement(torch.mean)
        ... def mean(input):
        ...     return input.mean()
        >>> registry  # doctest: +ELLIPSIS
        TorchFuncRegistry(
          (<built-in method mean of type object at ...>): <function mean at ...>
        )
    """

    # The duplicate check runs eagerly, when `implement(...)` is called,
    # not later when the returned decorator is applied.
    if torch_function in self and not self.override:
        # Not every callable defines `__name__` (e.g. `functools.partial`);
        # fall back to its repr so the documented ValueError is still raised.
        name = getattr(torch_function, "__name__", repr(torch_function))
        raise ValueError(f"Torch function {name} already registered.")

    # Copies `self.register`'s metadata (name, docstring, `__wrapped__`) onto the decorator.
    @wraps(self.register)
    def register(function):
        self.set(torch_function, function)
        return function

    return register