Registry for extending PyTorch functions to work with custom tensor types like NestedTensor.
TorchFuncRegistry provides a clean interface for implementing PyTorch function
overrides for custom tensor types such as NestedTensor. It’s used internally by
NestedTensor to register implementations for various torch functions like
torch.cat, torch.mean, torch.stack, etc.
This mechanism enables NestedTensor to behave like a regular torch.Tensor
when used with standard PyTorch functions by providing custom implementations
that understand the NestedTensor structure.
Usage:

```python
# Create a registry
registry = TorchFuncRegistry("my_tensor_registry")

# Register an implementation for torch.mean
@registry.implement(torch.mean)
def mean_implementation(input, dim=None, keepdim=False, **kwargs):
    # Custom implementation for your tensor type
    pass

# The registry can be used to look up the implementation
registry[torch.mean]  # Returns mean_implementation
```
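The dispatch side of this registry (what `NestedTensor.__torch_function__` does with it) can be sketched in plain Python. `FuncRegistry`, `original_sum`, `nested_sum`, and `dispatch` below are hypothetical stand-ins so the example runs without torch:

```python
from collections.abc import Callable


class FuncRegistry(dict):
    """Hypothetical stand-in for TorchFuncRegistry: maps functions to overrides."""

    def implement(self, func: Callable) -> Callable:
        def register(impl: Callable) -> Callable:
            self[func] = impl
            return impl
        return register


def original_sum(values):
    """Stand-in for a torch function such as torch.mean."""
    return sum(values)


registry = FuncRegistry()


@registry.implement(original_sum)
def nested_sum(values):
    # Override that understands a "nested" structure: sum each inner list first
    return sum(sum(inner) for inner in values)


def dispatch(func, *args, **kwargs):
    # Analogue of __torch_function__: prefer the registered override when present
    if func in registry:
        return registry[func](*args, **kwargs)
    return func(*args, **kwargs)


print(dispatch(original_sum, [[1, 2], [3, 4]]))  # → 10
```

The key design point this mirrors is that the original function object itself is the registry key, so the override lookup is a single dict access at dispatch time.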
````python
class TorchFuncRegistry(Registry):  # pylint: disable=too-few-public-methods
    """
    Registry for extending PyTorch functions to work with custom tensor types like NestedTensor.

    `TorchFuncRegistry` provides a clean interface for implementing PyTorch function
    overrides for custom tensor types such as NestedTensor. It's used internally by
    NestedTensor to register implementations for various torch functions like
    torch.cat, torch.mean, torch.stack, etc.

    This mechanism enables NestedTensor to behave like a regular torch.Tensor
    when used with standard PyTorch functions by providing custom implementations
    that understand the NestedTensor structure.

    Usage:

    ```python
    # Create a registry
    registry = TorchFuncRegistry("my_tensor_registry")

    # Register an implementation for torch.mean
    @registry.implement(torch.mean)
    def mean_implementation(input, dim=None, keepdim=False, **kwargs):
        # Custom implementation for your tensor type
        pass

    # The registry can be used to look up the implementation
    registry[torch.mean]  # Returns mean_implementation
    ```
    """

    def implement(self, torch_function: Callable) -> Callable:
        r"""
        Register a custom implementation for a PyTorch function.

        Use this decorator to provide implementations for PyTorch functions
        that will work with custom tensor types like NestedTensor. This is
        the key mechanism that allows NestedTensor to integrate seamlessly
        with the PyTorch ecosystem.

        Args:
            torch_function: The original PyTorch function to override
                (e.g., torch.mean, torch.cat)

        Returns:
            Callable: A decorator function that registers the implementation

        Raises:
            ValueError: If the function is already registered and override=False

        Examples:
            >>> import torch
            >>> registry = TorchFuncRegistry("test")
            >>> @registry.implement(torch.mean)
            ... def mean(input):
            ...     return input.mean()
            >>> registry[torch.mean]  # doctest: +ELLIPSIS
            <function mean at ...>

        Note:
            This is primarily used internally by NestedTensor.__torch_function__
            to provide implementations for various PyTorch functions. You can
            use the same mechanism to extend NestedTensor with additional
            function implementations.
        """
        if torch_function in self and not self.override:
            raise ValueError(f"Torch function {torch_function.__name__} already registered.")

        @wraps(self.register)
        def register(function):
            self.set(torch_function, function)
            return function

        return register
````
Register a custom implementation for a PyTorch function.
Use this decorator to provide implementations for PyTorch functions
that will work with custom tensor types like NestedTensor. This is
the key mechanism that allows NestedTensor to integrate seamlessly
with the PyTorch ecosystem.
Args:
    torch_function: The original PyTorch function to override (e.g., torch.mean, torch.cat)

Returns:
    Callable: A decorator function that registers the implementation

Raises:
    ValueError: If the function is already registered and override is False

Examples:

```python
>>> import torch
>>> registry = TorchFuncRegistry("test")
>>> @registry.implement(torch.mean)
... def mean(input):
...     return input.mean()
>>> registry[torch.mean]  # doctest: +ELLIPSIS
<function mean at ...>
```
Note:
    This is primarily used internally by `NestedTensor.__torch_function__`
    to provide implementations for various PyTorch functions. You can
    use the same mechanism to extend NestedTensor with additional
    function implementations.
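The double-registration guard described in the Raises section can be illustrated with a standalone mock. `MockRegistry` below is a hypothetical stand-in (not the real `TorchFuncRegistry`), and the builtin `max` stands in for a torch function so the example runs without torch:

```python
class MockRegistry(dict):
    """Hypothetical stand-in mirroring TorchFuncRegistry.implement's guard."""

    override = False

    def implement(self, torch_function):
        # Reject duplicate registrations unless override is enabled
        if torch_function in self and not self.override:
            raise ValueError(f"Torch function {torch_function.__name__} already registered.")

        def register(function):
            self[torch_function] = function
            return function

        return register


reg = MockRegistry()


@reg.implement(max)  # builtin max stands in for a torch function
def my_max(input):
    return max(input)


try:
    @reg.implement(max)  # second registration of the same function
    def my_max_again(input):
        return max(input)
except ValueError as err:
    print(err)  # the guard rejects the duplicate
```

Note that the guard fires when `implement(...)` is called, before the decorated function is even defined.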
```python
class NestedTensorFuncWrapper:  # pylint: disable=R0903
    r"""
    Function Wrapper to handle NestedTensor as input.
    """

    __storage: Sequence[Callable] = []
    state: Mapping = {}

    def __init__(self, *callables: Iterable[Callable], state: Mapping | None = None) -> None:
        # Accept either a sequence of callables or callables as positional arguments
        if len(callables) == 1 and isinstance(callables[0], Sequence):
            callables = callables[0]  # type: ignore
        self._storage = callables  # type: ignore
        if state is None:
            state = {}
        self.state = state
        self.device = self.state.get("device")

    @property
    def _storage(self):
        return self.__storage

    @_storage.setter
    def _storage(self, callables: Sequence):
        if not isinstance(callables, Sequence):
            raise ValueError(f"callables must be a Sequence, but got {type(callables)}")
        if len(callables) == 0:
            raise ValueError("callables must be a non-empty Sequence.")
        if not callable(callables[0]):
            raise ValueError(f"callables must be a Sequence of Callable, but got {type(callables[0])}")
        self.__storage = callables

    def __call__(self, *args, **kwargs) -> NestedTensor | Tensor | Sequence[Tensor]:
        from .nested_tensor import NestedTensor

        # Fan the call out to every stored callable
        ret = [call(*args, **kwargs) for call in self._storage]
        elem = ret[0]
        if isinstance(elem, Tensor):
            try:
                return torch.stack(ret, dim=0)
            except (ValueError, RuntimeError):
                return NestedTensor(ret, **self.state)
        # Collapse to a single value when every callable returned the same hashable result
        if elem.__hash__ is not None and len(set(ret)) == 1:
            return elem
        return ret
```
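The result-combination logic of `NestedTensorFuncWrapper.__call__` can be sketched without torch. `FuncWrapper` below is a hypothetical, simplified stand-in where plain Python values replace tensors, keeping only the "collapse identical hashable results" branch:

```python
from collections.abc import Sequence


class FuncWrapper:
    """Hypothetical stand-in for NestedTensorFuncWrapper; plain values replace tensors."""

    def __init__(self, *callables):
        # Accept either FuncWrapper(f, g) or FuncWrapper([f, g])
        if len(callables) == 1 and isinstance(callables[0], Sequence):
            callables = tuple(callables[0])
        if not callables or not callable(callables[0]):
            raise ValueError("callables must be a non-empty sequence of callables")
        self._storage = callables

    def __call__(self, *args, **kwargs):
        # Fan the call out to every stored callable
        results = [call(*args, **kwargs) for call in self._storage]
        first = results[0]
        # Collapse to a single value when every callable agreed (hashable results only)
        if first.__hash__ is not None and len(set(results)) == 1:
            return first
        return results


wrapper = FuncWrapper(lambda x: x + 1, lambda x: x + 1)
print(wrapper(1))  # → 2: both callables agree, so the result collapses
```

When the callables disagree, the wrapper returns the list of per-callable results instead of a single value; the real class additionally stacks `Tensor` results with `torch.stack` or falls back to building a `NestedTensor`.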