Bases: Registry
Registry for extending PyTorch functions to work with custom tensor types like NestedTensor.
`TorchFuncRegistry` provides a clean interface for implementing PyTorch function overrides for custom tensor types such as NestedTensor. It is used internally by NestedTensor to register implementations of torch functions such as `torch.cat`, `torch.mean`, and `torch.stack`.
This mechanism enables NestedTensor to behave like a regular torch.Tensor
when used with standard PyTorch functions by providing custom implementations
that understand the NestedTensor structure.
Usage:
```python
import torch
from danling.tensor.utils import TorchFuncRegistry

# Create a registry
registry = TorchFuncRegistry("my_tensor_registry")

# Register an implementation for torch.mean
@registry.implement(torch.mean)
def mean_implementation(input, dim=None, keepdim=False, **kwargs):
    # Custom implementation for your tensor type
    pass

# The registry can be used to look up the implementation
registry[torch.mean]  # Returns mean_implementation
```
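The dispatch mechanism described above can be sketched without danling. The following is a minimal sketch, assuming a plain dict (`HANDLED_FUNCTIONS`) in place of a `TorchFuncRegistry` and a hypothetical ragged container `ToyNested`. Neither name comes from danling, and the real NestedTensor logic is more involved. PyTorch calls the `__torch_function__` classmethod of any argument that defines it, and that hook is where a registry lookup can route `torch.mean` to a custom implementation.

```python
import torch

# Hypothetical stand-in for a TorchFuncRegistry: a plain dict that maps
# torch functions to their custom implementations.
HANDLED_FUNCTIONS = {}


def implements(torch_function):
    """Return a decorator that records a function as the override for `torch_function`."""

    def decorator(function):
        HANDLED_FUNCTIONS[torch_function] = function
        return function

    return decorator


class ToyNested:
    """Hypothetical ragged container holding 1-D tensors of different lengths."""

    def __init__(self, tensors):
        self._storage = list(tensors)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # PyTorch calls this hook whenever a ToyNested is passed to a torch function.
        kwargs = kwargs or {}
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented  # PyTorch then raises a TypeError
        return HANDLED_FUNCTIONS[func](*args, **kwargs)


@implements(torch.mean)
def nested_mean(input, dim=None, keepdim=False, **kwargs):
    # This sketch only supports reducing over every element of every sequence.
    if dim is not None:
        raise NotImplementedError("only full reductions are sketched here")
    return torch.cat(input._storage).mean()


nested = ToyNested([torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0, 5.0])])
print(torch.mean(nested))  # tensor(3.)
```

Returning `NotImplemented` for unregistered functions makes PyTorch raise a `TypeError`, so a missing override fails loudly instead of silently producing a wrong result.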
Source code in danling/tensor/utils.py
````python
class TorchFuncRegistry(Registry):  # pylint: disable=too-few-public-methods
    """
    Registry for extending PyTorch functions to work with custom tensor types like NestedTensor.

    `TorchFuncRegistry` provides a clean interface for implementing PyTorch function
    overrides for custom tensor types such as NestedTensor. It's used internally by
    NestedTensor to register implementations for various torch functions like
    torch.cat, torch.mean, torch.stack, etc.

    This mechanism enables NestedTensor to behave like a regular torch.Tensor
    when used with standard PyTorch functions by providing custom implementations
    that understand the NestedTensor structure.

    Usage:
    ```python
    # Create a registry
    registry = TorchFuncRegistry("my_tensor_registry")

    # Register an implementation for torch.mean
    @registry.implement(torch.mean)
    def mean_implementation(input, dim=None, keepdim=False, **kwargs):
        # Custom implementation for your tensor type
        pass

    # The registry can be used to look up the implementation
    registry[torch.mean]  # Returns mean_implementation
    ```
    """

    def implement(self, torch_function: Callable) -> Callable:
        r"""
        Register a custom implementation for a PyTorch function.

        Use this decorator to provide implementations for PyTorch functions
        that will work with custom tensor types like NestedTensor. This is
        the key mechanism that allows NestedTensor to integrate seamlessly
        with the PyTorch ecosystem.

        Args:
            torch_function: The original PyTorch function to override (e.g., torch.mean, torch.cat)

        Returns:
            Callable: A decorator function that registers the implementation

        Raises:
            ValueError: If the function is already registered and override=False

        Examples:
            >>> import torch
            >>> registry = TorchFuncRegistry("test")
            >>> @registry.implement(torch.mean)
            ... def mean(input):
            ...     return input.mean()
            >>> registry[torch.mean]  # doctest: +ELLIPSIS
            <function mean at ...>

        Note:
            This is primarily used internally by NestedTensor.__torch_function__
            to provide implementations for various PyTorch functions. You can
            use the same mechanism to extend NestedTensor with additional
            function implementations.
        """
        if torch_function in self and not self.override:
            raise ValueError(f"Torch function {torch_function.__name__} already registered.")

        @wraps(self.register)
        def register(function):
            self.set(torch_function, function)
            return function

        return register
````
implement
Register a custom implementation for a PyTorch function.
Use this decorator to provide implementations for PyTorch functions
that will work with custom tensor types like NestedTensor. This is
the key mechanism that allows NestedTensor to integrate seamlessly
with the PyTorch ecosystem.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `torch_function` | `Callable` | The original PyTorch function to override (e.g., `torch.mean`, `torch.cat`) | required |
Returns:

| Type | Description |
|---|---|
| `Callable` | A decorator function that registers the implementation |
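Raises:

| Type | Description |
|---|---|
| `ValueError` | If `torch_function` is already registered and the registry's `override` attribute is `False`. The check runs when `implement(torch_function)` is called, before the returned decorator is applied. |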
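The registration and override rules above can be illustrated with a small pure-Python sketch. `MiniFuncRegistry` is a hypothetical dict subclass, not danling's `Registry`; it mirrors the logic of `implement` in the source listing, including the fact that the duplicate check runs when `implement(...)` is called. The built-in `max` stands in for a torch function so that the sketch runs without PyTorch.

```python
from typing import Callable


class MiniFuncRegistry(dict):
    """Hypothetical dict-based stand-in for TorchFuncRegistry (not danling's class)."""

    def __init__(self, override: bool = False):
        super().__init__()
        self.override = override

    def implement(self, torch_function: Callable) -> Callable:
        # The duplicate check runs as soon as implement(...) is called,
        # before any function has been decorated.
        if torch_function in self and not self.override:
            raise ValueError(f"Torch function {torch_function.__name__} already registered.")

        def register(function: Callable) -> Callable:
            self[torch_function] = function
            return function  # the decorated function itself is returned unchanged

        return register


strict = MiniFuncRegistry()


@strict.implement(max)  # built-in max stands in for a torch function
def first_max(*args):
    return "first"


try:
    strict.implement(max)
except ValueError as error:
    print(error)  # Torch function max already registered.

lenient = MiniFuncRegistry(override=True)
lenient.implement(max)(lambda: "v1")
lenient.implement(max)(lambda: "v2")
print(lenient[max]())  # v2
```

With `override=True`, a later registration silently replaces the earlier one; with the default, the second `implement(max)` call fails before any decorator is returned.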
Examples:
```pycon
>>> import torch
>>> registry = TorchFuncRegistry("test")
>>> @registry.implement(torch.mean)
... def mean(input):
...     return input.mean()
>>> registry[torch.mean]
<function mean at ...>
```
Note
This is primarily used internally by `NestedTensor.__torch_function__` to provide implementations for various PyTorch functions. You can use the same mechanism to extend NestedTensor with additional function implementations.