
TorchFuncRegistry

danling.tensor.TorchFuncRegistry

Bases: Registry

Registry for extending PyTorch functions to work with custom tensor types like NestedTensor.

TorchFuncRegistry provides a clean interface for implementing PyTorch function overrides for custom tensor types such as NestedTensor. It’s used internally by NestedTensor to register implementations for various torch functions like torch.cat, torch.mean, torch.stack, etc.

This mechanism enables NestedTensor to behave like a regular torch.Tensor when used with standard PyTorch functions by providing custom implementations that understand the NestedTensor structure.
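The dispatch idea described above can be sketched with PyTorch's `__torch_function__` protocol. This is a minimal illustration, not DanLing's actual implementation: `HANDLED_FUNCTIONS`, the module-level `implement` helper, and `ScaledTensor` are hypothetical names, and a plain dictionary stands in for the registry.

```python
import torch

# Hypothetical stand-in for a registry: maps torch functions to overrides.
HANDLED_FUNCTIONS = {}


def implement(torch_function):
    """Decorator that records an override for `torch_function`."""

    def register(function):
        HANDLED_FUNCTIONS[torch_function] = function
        return function

    return register


class ScaledTensor:
    """Hypothetical tensor-like type: a plain tensor plus a scalar factor."""

    def __init__(self, data, scale=1.0):
        self.data = torch.as_tensor(data, dtype=torch.float32)
        self.scale = scale

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # PyTorch calls this hook when a torch function receives a ScaledTensor.
        kwargs = kwargs or {}
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)


@implement(torch.mean)
def scaled_mean(input):
    # Apply the scale factor on top of the ordinary mean.
    return input.data.mean() * input.scale


# torch.mean is routed to scaled_mean through __torch_function__.
result = torch.mean(ScaledTensor([1.0, 2.0, 3.0], scale=2.0))
```

Functions without a registered override return `NotImplemented` here, which makes PyTorch raise a `TypeError`; a real tensor type may prefer a different fallback.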

Usage:

Python
import torch
from danling.tensor import TorchFuncRegistry

# Create a registry
registry = TorchFuncRegistry("my_tensor_registry")

# Register an implementation for torch.mean
@registry.implement(torch.mean)
def mean_implementation(input, dim=None, keepdim=False, **kwargs):
    # Custom implementation for your tensor type
    pass

# The registry can be used to look up the implementation
registry[torch.mean]  # Returns mean_implementation

Source code in danling/tensor/utils.py
Python
class TorchFuncRegistry(Registry):  # pylint: disable=too-few-public-methods
    """
    Registry for extending PyTorch functions to work with custom tensor types like NestedTensor.

    `TorchFuncRegistry` provides a clean interface for implementing PyTorch function
    overrides for custom tensor types such as NestedTensor. It's used internally by
    NestedTensor to register implementations for various torch functions like
    torch.cat, torch.mean, torch.stack, etc.

    This mechanism enables NestedTensor to behave like a regular torch.Tensor
    when used with standard PyTorch functions by providing custom implementations
    that understand the NestedTensor structure.

    Usage:
    ```python
    # Create a registry
    registry = TorchFuncRegistry("my_tensor_registry")

    # Register an implementation for torch.mean
    @registry.implement(torch.mean)
    def mean_implementation(input, dim=None, keepdim=False, **kwargs):
        # Custom implementation for your tensor type
        pass

    # The registry can be used to look up the implementation
    registry[torch.mean]  # Returns mean_implementation
    ```
    """

    def implement(self, torch_function: Callable) -> Callable:
        r"""
        Register a custom implementation for a PyTorch function.

        Use this decorator to provide implementations for PyTorch functions
        that will work with custom tensor types like NestedTensor. This is
        the key mechanism that allows NestedTensor to integrate seamlessly
        with the PyTorch ecosystem.

        Args:
            torch_function: The original PyTorch function to override (e.g., torch.mean, torch.cat)

        Returns:
            Callable: A decorator function that registers the implementation

        Raises:
            ValueError: If the function is already registered and override=False

        Examples:
            >>> import torch
            >>> registry = TorchFuncRegistry("test")
            >>> @registry.implement(torch.mean)
            ... def mean(input):
            ...     return input.mean()
            >>> registry[torch.mean]  # doctest: +ELLIPSIS
            <function mean at ...>

        Note:
            This is primarily used internally by NestedTensor.__torch_function__
            to provide implementations for various PyTorch functions. You can
            use the same mechanism to extend NestedTensor with additional
            function implementations.
        """

        if torch_function in self and not self.override:
            raise ValueError(f"Torch function {torch_function.__name__} already registered.")

        @wraps(self.register)
        def register(function):
            self.set(torch_function, function)
            return function

        return register

implement

Python
implement(torch_function: Callable) -> Callable

Register a custom implementation for a PyTorch function.

Use this decorator to provide implementations for PyTorch functions that will work with custom tensor types like NestedTensor. This is the key mechanism that allows NestedTensor to integrate seamlessly with the PyTorch ecosystem.

Parameters:

Name | Type | Description | Default
torch_function | Callable | The original PyTorch function to override (e.g., torch.mean, torch.cat) | required

Returns:

Name | Type | Description
Callable | Callable | A decorator function that registers the implementation

Raises:

Type | Description
ValueError | If the function is already registered and override=False
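The duplicate-registration rule can be illustrated without PyTorch. The sketch below is a simplified, dict-backed stand-in rather than the real class: `MiniFuncRegistry` is a hypothetical name, passing `override` to the constructor is an assumption (the source only shows that `self.override` is read), and the built-in `sum` stands in for a torch function.

```python
from typing import Callable


class MiniFuncRegistry(dict):
    """Hypothetical, dict-backed stand-in for a function registry."""

    def __init__(self, override: bool = False):
        super().__init__()
        self.override = override

    def implement(self, target: Callable) -> Callable:
        # Refuse to replace an existing entry unless overriding is allowed.
        if target in self and not self.override:
            raise ValueError(f"Function {target.__name__} already registered.")

        def register(function: Callable) -> Callable:
            self[target] = function
            return function

        return register


# Strict registry: a second registration for the same key is rejected.
strict = MiniFuncRegistry()
strict.implement(sum)(lambda xs: "first")

try:
    strict.implement(sum)(lambda xs: "second")
    duplicate_rejected = False
except ValueError:
    duplicate_rejected = True

# Lenient registry: the later registration replaces the earlier one.
lenient = MiniFuncRegistry(override=True)
lenient.implement(sum)(lambda xs: "first")
lenient.implement(sum)(lambda xs: "second")
```

As in the source shown on this page, the check runs when `implement(...)` is called, so the error surfaces at the decorator line rather than when the decorated function is first used.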

Examples:

Python Console Session
>>> import torch
>>> registry = TorchFuncRegistry("test")
>>> @registry.implement(torch.mean)
... def mean(input):
...     return input.mean()
>>> registry[torch.mean]
<function mean at ...>
Note

This is primarily used internally by NestedTensor.__torch_function__ to provide implementations for various PyTorch functions. You can use the same mechanism to extend NestedTensor with additional function implementations.

Source code in danling/tensor/utils.py
Python
def implement(self, torch_function: Callable) -> Callable:
    r"""
    Register a custom implementation for a PyTorch function.

    Use this decorator to provide implementations for PyTorch functions
    that will work with custom tensor types like NestedTensor. This is
    the key mechanism that allows NestedTensor to integrate seamlessly
    with the PyTorch ecosystem.

    Args:
        torch_function: The original PyTorch function to override (e.g., torch.mean, torch.cat)

    Returns:
        Callable: A decorator function that registers the implementation

    Raises:
        ValueError: If the function is already registered and override=False

    Examples:
        >>> import torch
        >>> registry = TorchFuncRegistry("test")
        >>> @registry.implement(torch.mean)
        ... def mean(input):
        ...     return input.mean()
        >>> registry[torch.mean]  # doctest: +ELLIPSIS
        <function mean at ...>

    Note:
        This is primarily used internally by NestedTensor.__torch_function__
        to provide implementations for various PyTorch functions. You can
        use the same mechanism to extend NestedTensor with additional
        function implementations.
    """

    if torch_function in self and not self.override:
        raise ValueError(f"Torch function {torch_function.__name__} already registered.")

    @wraps(self.register)
    def register(function):
        self.set(torch_function, function)
        return function

    return register