Functions

NestedTensor operation support is split by dispatch layer, not by a single utils.py module. The public documentation follows that structure:

Torch Functions

danling.tensors.torch_functions registers handlers for torch.* functions through __torch_function__, including torch.cat, torch.stack, reductions, indexing, and matrix operations.

danling.tensors.torch_functions

torch.* function overrides for NestedTensor via __torch_function__.

This module is the Level 2 dispatch layer. When a torch.* call (e.g. torch.cat, torch.mean, torch.einsum) involves a NestedTensor, __torch_function__ checks NestedTensorFuncRegistry for a registered handler.

Handlers here pick a strategy based on what the op needs:

  • Packed fast-path — ops that work directly on the concatenated _values tensor via NestedTensor._from_packed without knowing element boundaries.
  • Per-element dispatch — ops that must be applied to each element individually via _map_storage_serial, e.g. when dimension indices need translation or output shapes differ per element.
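
To make the difference between the two strategies concrete, here is a minimal sketch that uses plain Python lists in place of tensors. The names `packed_elementwise` and `per_element` are hypothetical; they only mimic the roles that `NestedTensor._from_packed` and `_map_storage_serial` play above, and the real handlers operate on torch tensors and may differ in detail.

```python
from itertools import accumulate

# Hypothetical stand-in for packed storage: three ragged rows concatenated
# into one flat list, plus offsets marking where each row starts and ends.
rows = [[1.0, 2.0, 3.0], [4.0], [5.0, 6.0]]
values = [x for row in rows for x in row]
offsets = [0, *accumulate(len(row) for row in rows)]


def packed_elementwise(fn, values):
    """Packed fast path: apply fn to every value; boundaries are never consulted."""
    return [fn(v) for v in values]


def per_element(fn, values, offsets):
    """Per-element dispatch: slice each row out of the packed buffer and apply fn to it."""
    return [fn(values[start:end]) for start, end in zip(offsets, offsets[1:])]


# Doubling is elementwise, so the flat buffer is enough and the offsets are reused.
doubled = packed_elementwise(lambda v: v * 2, values)

# A per-row sum depends on where each row ends, so it needs the boundaries.
row_sums = per_element(sum, values, offsets)
```

The elementwise case touches the buffer once and never looks at `offsets`, which is what makes the packed path cheap. The per-row sum cannot be computed without the boundaries, so it pays for one slice and one call per element.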

If no handler is registered here, the call falls through to aten decomposition and then to __torch_dispatch__ (see aten_functions).

NN Functions

danling.tensors.nn_functions registers handlers for torch.nn.functional.* functions, including attention, embedding, normalization, pooling, convolution, and loss functions.

danling.tensors.nn_functions

torch.nn.functional.* overrides for NestedTensor via __torch_function__.

This module is the Level 3 dispatch layer, registering handlers for F.linear, F.conv*, F.max_pool*, F.embedding, F.layer_norm, F.scaled_dot_product_attention, and other torch.nn.functional ops.

The design rule is:

  • use one canonical NestedTensor implementation path per op whenever possible
  • reserve compile_safe for packed-first training hot paths only
  • keep convenience APIs honest by marking densifying or repacking handlers eager-only under torch.compile
  • leave non-hot spatial ops eager-only rather than carrying speculative packed fast paths

That means most spatial operators here use per-element dispatch, while a small Tier A set of transformer-hot packed handlers stays compile-safe.

danling.tensors.nn_functions.create_flex_block_mask

Python
create_flex_block_mask(
    mask_mod: Callable,
    query: NestedTensor,
    key: NestedTensor | None = None,
    *,
    num_heads: int | None = None,
    block_size: int | tuple[int, int] = 128,
    compile_mask: bool = False
)

Create a FlexAttention block mask directly from DanLing ragged attention storage.

Source code in danling/tensors/nn_functions.py
Python
def create_flex_block_mask(
    mask_mod: Callable,
    query: NestedTensor,
    key: NestedTensor | None = None,
    *,
    num_heads: int | None = None,
    block_size: int | tuple[int, int] = 128,
    compile_mask: bool = False,
):
    r"""Create a FlexAttention block mask directly from DanLing ragged attention storage."""
    if _torch_create_block_mask is None:
        raise RuntimeError("FlexAttention is unavailable in this PyTorch build.")
    if key is None:
        key = query
    if len(query) != len(key):
        raise ValueError(
            "NestedTensor batch length mismatch between query and key: " f"query={len(query)}, key={len(key)}"
        )
    q_sizes = _packed_sizes_tuple(query)
    k_sizes = _packed_sizes_tuple(key)
    compile_requested = compile_mask or torch.compiler.is_compiling()
    if mask_mod is _flex_allow_all and not compile_requested:
        return _cached_same_sequence_block_mask(
            q_sizes,
            k_sizes,
            num_heads,
            block_size,
            query.device.type,
            query.device.index,
        )
    wrapped_mask_mod = _flex_wrap_mask_mod(mask_mod, query, key)
    return _torch_create_block_mask(
        wrapped_mask_mod,
        1,
        num_heads,
        int(sum(q_sizes)),
        int(sum(k_sizes)),
        device=query.device,
        BLOCK_SIZE=block_size,
        _compile=compile_requested,
    )
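
The source above builds the block mask with a batch size of 1 and the summed sequence lengths, so every sequence shares one packed axis. The internals of `_flex_wrap_mask_mod` are not shown, so the following is only a sketch of one plausible wrapping scheme, written with plain integers instead of tensors. The function `wrap_mask_mod`, its use of `bisect`, and the choice to pass the sequence index as the batch argument are all assumptions made for illustration.

```python
from bisect import bisect_right
from itertools import accumulate


def wrap_mask_mod(mask_mod, q_sizes, k_sizes):
    """Hypothetical wrapper: translate packed indices into per-sequence positions."""
    q_offsets = [0, *accumulate(q_sizes)]
    k_offsets = [0, *accumulate(k_sizes)]

    def wrapped(b, h, q_idx, kv_idx):
        # Locate the sequence that owns each packed index.
        q_seq = bisect_right(q_offsets, q_idx) - 1
        k_seq = bisect_right(k_offsets, kv_idx) - 1
        if q_seq != k_seq:
            return False  # never attend across sequence boundaries
        # Re-express both indices relative to the start of their own sequence.
        return mask_mod(q_seq, h, q_idx - q_offsets[q_seq], kv_idx - k_offsets[k_seq])

    return wrapped


def causal(b, h, q_idx, kv_idx):
    """Ordinary causal mask, written against per-sequence positions."""
    return q_idx >= kv_idx


sizes = (3, 2)  # two sequences packed into a single axis of length 5
wrapped = wrap_mask_mod(causal, sizes, sizes)
total = sum(sizes)
grid = [[int(wrapped(0, 0, q, k)) for k in range(total)] for q in range(total)]
```

In this sketch `grid` is block-diagonal, with a lower-triangular block per sequence: the user-supplied `causal` function only ever sees local positions, and the wrapper is responsible for keeping sequences apart.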

Aten Functions

danling.tensors.aten_functions registers packed-storage __torch_dispatch__ handlers and fallback behavior for aten ops.

danling.tensors.aten_functions

__torch_dispatch__ handlers for NestedTensor aten ops (Level 1 dispatch).

This module implements the dispatch table that maps aten ops to optimized handlers operating on the packed representation (_values, _offsets, _physical_shape).

Architecture
  • Elementwise ops operate directly on _values (no unpack/repack overhead)
  • Structural ops (clone, detach, to_copy) operate on all inner tensors
  • Unregistered ops fall back to per-element application via _storage
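
The lookup-then-fallback behaviour described above can be sketched with a plain dict and list-based storage. Everything in this example (`Packed`, `TABLE`, `dispatch`, and the two toy ops) is hypothetical and stands in for the real aten ops and tensor storage; it only illustrates the control flow, not DanLing's implementation.

```python
class Packed:
    """Hypothetical packed container: flat values plus row offsets."""

    def __init__(self, values, offsets):
        self.values = values
        self.offsets = offsets

    def rows(self):
        return [self.values[s:e] for s, e in zip(self.offsets, self.offsets[1:])]


def negate(row):
    return [-x for x in row]


def reverse(row):
    return row[::-1]


def negate_packed(packed):
    # Elementwise handler: one pass over the flat buffer, offsets reused unchanged.
    return Packed(negate(packed.values), packed.offsets)


# Dispatch table: only ops with a packed handler are listed.
TABLE = {negate: negate_packed}


def dispatch(op, packed):
    handler = TABLE.get(op)
    if handler is not None:
        return handler(packed)
    # Fallback: apply the op to each row separately, then repack the results.
    rows = [op(row) for row in packed.rows()]
    offsets = [0]
    for row in rows:
        offsets.append(offsets[-1] + len(row))
    return Packed([x for row in rows for x in row], offsets)


x = Packed([1, 2, 3, 4, 5], [0, 3, 5])
fast = dispatch(negate, x)   # registered: handled on the packed buffer
slow = dispatch(reverse, x)  # unregistered: per-row fallback and repack
```

Reversal is a useful counterexample because applying it to the flat buffer would mix rows together, so it has to take the per-element route.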

Dispatch Registries

danling.tensors.ops provides registry types, dispatch tables, and diagnostic helpers used to extend or test NestedTensor operation support.

danling.tensors.ops

Internal helpers shared across NestedTensor function registrations.

NestedTensorFuncRegistry module-attribute

Python
NestedTensorFuncRegistry = TorchFuncRegistry()

NestedTensorAtenRegistry module-attribute

Python
NestedTensorAtenRegistry = TorchFuncRegistry()

TorchFuncRegistry

Bases: dict

Plain dict mapping functions/ops to their NestedTensor handlers.

Uses dict directly for O(1) lookup with minimal overhead (~30 ns) instead of chanfig.Registry (~700-2300 ns).

Used for both __torch_function__ (torch/nn ops) and __torch_dispatch__ (aten ops) dispatch tables.

Source code in danling/tensors/ops.py
Python
class TorchFuncRegistry(dict):
    r"""
    Plain dict mapping functions/ops to their NestedTensor handlers.

    Uses ``dict`` directly for O(1) lookup with minimal overhead (~30 ns)
    instead of chanfig.Registry (~700-2300 ns).

    Used for both ``__torch_function__`` (torch/nn ops) and
    ``__torch_dispatch__`` (aten ops) dispatch tables.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._compile_safe: dict[Callable, bool] = {}
        self._compile_guard: dict[Callable, Callable[[tuple, dict[str, object]], bool]] = {}

    def register(
        self,
        func: Callable,
        handler: Callable,
        *,
        compile_safe: bool = False,
        compile_guard: Callable[[tuple, dict[str, object]], bool] | None = None,
    ) -> Callable:
        r"""Register *handler* for *func* and record whether the path is compile-safe by default."""
        self[func] = handler
        self._compile_safe[func] = bool(compile_safe)
        if compile_guard is not None:
            self._compile_guard[func] = compile_guard
        else:
            self._compile_guard.pop(func, None)
        return handler

    def implement(
        self,
        func: Callable,
        *,
        compile_safe: bool = False,
        compile_guard: Callable[[tuple, dict[str, object]], bool] | None = None,
    ) -> Callable:
        r"""Decorator to register a handler for *func*."""

        def wrapper(handler: Callable) -> Callable:
            return self.register(func, handler, compile_safe=compile_safe, compile_guard=compile_guard)

        return wrapper

    def is_compile_safe(
        self, func: Callable, args: tuple | None = None, kwargs: dict[str, object] | None = None
    ) -> bool:
        r"""Return whether *func* is allowed to run while ``torch.compile`` is tracing."""
        if not bool(self._compile_safe.get(func, False)):
            return False
        guard = self._compile_guard.get(func)
        if guard is None or args is None:
            return True
        return bool(guard(args, kwargs or {}))

    def set_compile_safe(self, func: Callable, compile_safe: bool = True) -> None:
        r"""Update compile policy for an already-registered handler."""
        if func not in self:
            raise KeyError(f"{func} is not registered")
        self._compile_safe[func] = bool(compile_safe)

    def set_compile_guard(self, func: Callable, guard: Callable[[tuple, dict[str, object]], bool] | None) -> None:
        r"""Set or clear the runtime compile guard for an already-registered handler."""
        if func not in self:
            raise KeyError(f"{func} is not registered")
        if guard is None:
            self._compile_guard.pop(func, None)
        else:
            self._compile_guard[func] = guard

    def get_compile_guard(self, func: Callable) -> Callable[[tuple, dict[str, object]], bool] | None:
        r"""Return the runtime compile guard for *func*, if any."""
        return self._compile_guard.get(func)
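
The following usage sketch shows how the decorator form and a compile guard interact. So that it runs without DanLing or PyTorch installed, it uses `MiniRegistry`, a condensed stand-in that copies only the `implement` and `is_compile_safe` logic from the source above. The op `scale`, its handler, and the guard `only_integer_factor` are hypothetical.

```python
class MiniRegistry(dict):
    """Condensed stand-in for TorchFuncRegistry, for illustration only."""

    def __init__(self):
        super().__init__()
        self._compile_safe = {}
        self._compile_guard = {}

    def implement(self, func, *, compile_safe=False, compile_guard=None):
        def wrapper(handler):
            self[func] = handler
            self._compile_safe[func] = bool(compile_safe)
            if compile_guard is not None:
                self._compile_guard[func] = compile_guard
            else:
                self._compile_guard.pop(func, None)
            return handler

        return wrapper

    def is_compile_safe(self, func, args=None, kwargs=None):
        if not self._compile_safe.get(func, False):
            return False
        guard = self._compile_guard.get(func)
        if guard is None or args is None:
            return True
        return bool(guard(args, kwargs or {}))


registry = MiniRegistry()


def scale(x, factor=1):
    """Hypothetical op standing in for a torch function."""
    return [v * factor for v in x]


def only_integer_factor(args, kwargs):
    # The guard receives the call's positional and keyword arguments.
    return isinstance(kwargs.get("factor", 1), int)


@registry.implement(scale, compile_safe=True, compile_guard=only_integer_factor)
def scale_handler(x, factor=1):
    return [v * factor for v in x]


static_answer = registry.is_compile_safe(scale)                      # guard skipped: no args
int_answer = registry.is_compile_safe(scale, ([1],), {"factor": 2})
float_answer = registry.is_compile_safe(scale, ([1],), {"factor": 0.5})
unknown_answer = registry.is_compile_safe(len)                       # never registered
```

Note that the guard is only consulted when `args` is supplied. Calling `is_compile_safe(func)` with no arguments answers the static question "is this handler compile-safe by default?", while passing the call arguments asks about one specific invocation.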

register

Python
register(
    func: Callable,
    handler: Callable,
    *,
    compile_safe: bool = False,
    compile_guard: (
        Callable[[tuple, dict[str, object]], bool] | None
    ) = None
) -> Callable

Register handler for func and record whether the path is compile-safe by default.

Source code in danling/tensors/ops.py
Python
def register(
    self,
    func: Callable,
    handler: Callable,
    *,
    compile_safe: bool = False,
    compile_guard: Callable[[tuple, dict[str, object]], bool] | None = None,
) -> Callable:
    r"""Register *handler* for *func* and record whether the path is compile-safe by default."""
    self[func] = handler
    self._compile_safe[func] = bool(compile_safe)
    if compile_guard is not None:
        self._compile_guard[func] = compile_guard
    else:
        self._compile_guard.pop(func, None)
    return handler

implement

Python
implement(
    func: Callable,
    *,
    compile_safe: bool = False,
    compile_guard: (
        Callable[[tuple, dict[str, object]], bool] | None
    ) = None
) -> Callable

Decorator to register a handler for func.

Source code in danling/tensors/ops.py
Python
def implement(
    self,
    func: Callable,
    *,
    compile_safe: bool = False,
    compile_guard: Callable[[tuple, dict[str, object]], bool] | None = None,
) -> Callable:
    r"""Decorator to register a handler for *func*."""

    def wrapper(handler: Callable) -> Callable:
        return self.register(func, handler, compile_safe=compile_safe, compile_guard=compile_guard)

    return wrapper

is_compile_safe

Python
is_compile_safe(
    func: Callable,
    args: tuple | None = None,
    kwargs: dict[str, object] | None = None,
) -> bool

Return whether func is allowed to run while torch.compile is tracing.

Source code in danling/tensors/ops.py
Python
def is_compile_safe(
    self, func: Callable, args: tuple | None = None, kwargs: dict[str, object] | None = None
) -> bool:
    r"""Return whether *func* is allowed to run while ``torch.compile`` is tracing."""
    if not bool(self._compile_safe.get(func, False)):
        return False
    guard = self._compile_guard.get(func)
    if guard is None or args is None:
        return True
    return bool(guard(args, kwargs or {}))

set_compile_safe

Python
set_compile_safe(
    func: Callable, compile_safe: bool = True
) -> None

Update compile policy for an already-registered handler.

Source code in danling/tensors/ops.py
Python
def set_compile_safe(self, func: Callable, compile_safe: bool = True) -> None:
    r"""Update compile policy for an already-registered handler."""
    if func not in self:
        raise KeyError(f"{func} is not registered")
    self._compile_safe[func] = bool(compile_safe)

set_compile_guard

Python
set_compile_guard(
    func: Callable,
    guard: (
        Callable[[tuple, dict[str, object]], bool] | None
    ),
) -> None

Set or clear the runtime compile guard for an already-registered handler.

Source code in danling/tensors/ops.py
Python
def set_compile_guard(self, func: Callable, guard: Callable[[tuple, dict[str, object]], bool] | None) -> None:
    r"""Set or clear the runtime compile guard for an already-registered handler."""
    if func not in self:
        raise KeyError(f"{func} is not registered")
    if guard is None:
        self._compile_guard.pop(func, None)
    else:
        self._compile_guard[func] = guard

get_compile_guard

Python
get_compile_guard(
    func: Callable,
) -> Callable[[tuple, dict[str, object]], bool] | None

Return the runtime compile guard for func, if any.

Source code in danling/tensors/ops.py
Python
def get_compile_guard(self, func: Callable) -> Callable[[tuple, dict[str, object]], bool] | None:
    r"""Return the runtime compile guard for *func*, if any."""
    return self._compile_guard.get(func)

nested_execution_guard

Python
nested_execution_guard(
    *,
    forbid_iteration: bool = False,
    forbid_storage_map: bool = False,
    forbid_eager_fallback: bool = False,
    forbid_padded_materialization: bool = False,
    forbid_dense_repack: bool = False
)

Temporarily forbid selected slow paths while exercising NestedTensor hot paths.

This is intended for transformer-critical regression checks, where falling back to Python loops or padded materialization is considered a bug.

Source code in danling/tensors/ops.py
Python
@contextmanager
def nested_execution_guard(
    *,
    forbid_iteration: bool = False,
    forbid_storage_map: bool = False,
    forbid_eager_fallback: bool = False,
    forbid_padded_materialization: bool = False,
    forbid_dense_repack: bool = False,
):
    r"""
    Temporarily forbid selected slow paths while exercising NestedTensor hot paths.

    This is intended for transformer-critical regression checks, where falling
    back to Python loops or padded materialization is considered a bug.
    """
    current = _EXECUTION_GUARD.get()
    merged = _ExecutionGuard(
        forbid_iteration=forbid_iteration or (current.forbid_iteration if current is not None else False),
        forbid_storage_map=forbid_storage_map or (current.forbid_storage_map if current is not None else False),
        forbid_eager_fallback=forbid_eager_fallback
        or (current.forbid_eager_fallback if current is not None else False),
        forbid_padded_materialization=forbid_padded_materialization
        or (current.forbid_padded_materialization if current is not None else False),
        forbid_dense_repack=forbid_dense_repack or (current.forbid_dense_repack if current is not None else False),
    )
    token: Token[_ExecutionGuard | None] = _EXECUTION_GUARD.set(merged)
    try:
        yield
    finally:
        _EXECUTION_GUARD.reset(token)
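
The merge-and-restore behaviour of nested guards can be tried out in isolation with a reduced sketch built on `contextvars`. The definition of `_ExecutionGuard` and the places where DanLing checks it are not shown above, so the `Guard` dataclass (cut down to two flags), the `iterate` helper, and the choice to yield the merged guard for inspection are all assumptions made for this example.

```python
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass


@dataclass(frozen=True)
class Guard:
    """Hypothetical stand-in for _ExecutionGuard, reduced to two flags."""

    forbid_iteration: bool = False
    forbid_dense_repack: bool = False


_GUARD: ContextVar = ContextVar("guard", default=None)


@contextmanager
def execution_guard(*, forbid_iteration=False, forbid_dense_repack=False):
    current = _GUARD.get()
    if current is None:
        current = Guard()
    # Flags are OR-ed with the enclosing guard, so inner scopes can only add restrictions.
    merged = Guard(
        forbid_iteration=forbid_iteration or current.forbid_iteration,
        forbid_dense_repack=forbid_dense_repack or current.forbid_dense_repack,
    )
    token = _GUARD.set(merged)
    try:
        yield merged  # yielded here only so the example can inspect it
    finally:
        _GUARD.reset(token)  # restore the enclosing guard, even on error


def iterate(rows):
    """Hypothetical slow path that consults the active guard before looping."""
    guard = _GUARD.get()
    if guard is not None and guard.forbid_iteration:
        raise RuntimeError("iteration is forbidden inside this guard")
    return list(rows)


with execution_guard(forbid_iteration=True):
    with execution_guard(forbid_dense_repack=True) as inner:
        flags = (inner.forbid_iteration, inner.forbid_dense_repack)
        try:
            iterate([[1], [2]])
            blocked = False
        except RuntimeError:
            blocked = True

after = _GUARD.get()  # both scopes have exited, so no guard is active
```

Because each level ORs its flags with the enclosing guard, an inner `with` block cannot relax a restriction set by an outer one. Resetting the token in `finally` restores the previous state on exit, so a failing regression check does not leak restrictions into later code.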

The files under danling.tensors.functions are specialized implementations used by nn_functions for convolution, pooling, and channel operators. They are kept out of the docs navigation because users normally call the corresponding PyTorch or torch.nn.functional API directly.