PNTensor

danling.tensors.PNTensor

Bases: Tensor

Wrapper for tensors to be converted to NestedTensor.

PNTensor is a subclass of torch.Tensor. It implements the same three additional properties as NestedTensor: tensor, mask, and concat (spelled contact in the source listing below).

Although it is possible to construct a NestedTensor directly in a dataset, the best practice is to do so in collate_fn. PNTensor is introduced to smooth this process.

Wrap each tensor that should end up in a NestedTensor in a PNTensor, and the PyTorch DataLoader will automatically collate the PNTensor objects into a NestedTensor.
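To illustrate why collate_fn is the natural place to assemble a batch from variable-length samples, here is a minimal sketch that uses only PyTorch. It does not use DanLing's NestedTensor: `VariableLengthDataset` and `pad_collate` are hypothetical names, and padding plus a boolean mask is an assumed stand-in for the batched representation.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset


class VariableLengthDataset(Dataset):
    """Hypothetical dataset whose samples have different lengths."""

    def __init__(self):
        self.samples = [
            torch.tensor([1, 2, 3]),
            torch.tensor([4, 5]),
            torch.tensor([6]),
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        # Each sample is returned as-is; batching is deferred to collate_fn.
        return self.samples[index]


def pad_collate(batch):
    """Assumed stand-in for nested collation: pad to the longest sample and build a mask."""
    lengths = torch.tensor([len(sample) for sample in batch])
    padded = pad_sequence(batch, batch_first=True, padding_value=0)
    # mask[i, j] is True where padded[i, j] holds real data rather than padding.
    mask = torch.arange(padded.size(1)).unsqueeze(0) < lengths.unsqueeze(1)
    return padded, mask


loader = DataLoader(VariableLengthDataset(), batch_size=3, collate_fn=pad_collate)
padded, mask = next(iter(loader))
```

The dataset stays unaware of batch shapes, and all padding decisions live in one function. With PNTensor, the same division of labour applies without writing the collate function by hand.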

Source code in danling/tensors/nested_tensor.py
Python
class PNTensor(Tensor):
    r"""
    Wrapper for tensors to be converted to `NestedTensor`.

    `PNTensor` is a subclass of `torch.Tensor`.
    It implements the same three additional properties as `NestedTensor`: `tensor`, `mask`, and `concat`.

    Although it is possible to construct a `NestedTensor` directly in a dataset,
    the best practice is to do so in `collate_fn`.
    `PNTensor` is introduced to smooth this process.

    Wrap each tensor that should end up in a `NestedTensor` in a `PNTensor`,
    and the PyTorch `DataLoader` will automatically collate the `PNTensor` objects into a `NestedTensor`.
    """

    @property
    def tensor(self) -> Tensor:
        r"""
        Identical to `self`.

        Returns:
            (torch.Tensor):

        Examples:
            >>> tensor = torch.tensor([1, 2, 3])
            >>> pn_tensor = PNTensor([1, 2, 3])
            >>> bool((tensor == pn_tensor).all())
            True
            >>> bool((tensor == pn_tensor.tensor).all())
            True
        """

        return self

    @property
    def mask(self) -> Tensor:
        r"""
        Identical to `torch.ones_like(self)`.

        Returns:
            (torch.Tensor):

        Examples:
            >>> tensor = torch.tensor([1, 2, 3])
            >>> pn_tensor = PNTensor([1, 2, 3])
            >>> bool((pn_tensor.mask == torch.ones_like(pn_tensor)).all().item())
            True
        """

        return torch.ones_like(self)

    @property
    def contact(self) -> Tensor:
        r"""
        Identical to `self`.

        Returns:
            (torch.Tensor):

        Examples:
            >>> tensor = torch.tensor([1, 2, 3])
            >>> pn_tensor = PNTensor([1, 2, 3])
            >>> bool((tensor == pn_tensor).all())
            True
            >>> bool((tensor == pn_tensor.contact).all())
            True
        """

        return self

    def new_empty(self, *args, **kwargs):
        return PNTensor(super().new_empty(*args, **kwargs))

tensor property

Python
tensor: Tensor

Identical to self.

Returns:

Type Description
Tensor

Examples:

Python Console Session
>>> tensor = torch.tensor([1, 2, 3])
>>> pn_tensor = PNTensor([1, 2, 3])
>>> bool((tensor == pn_tensor).all())
True
>>> bool((tensor == pn_tensor.tensor).all())
True

mask property

Python
mask: Tensor

Identical to torch.ones_like(self).

Returns:

Type Description
Tensor

Examples:

Python Console Session
>>> tensor = torch.tensor([1, 2, 3])
>>> pn_tensor = PNTensor([1, 2, 3])
>>> bool((pn_tensor.mask == torch.ones_like(pn_tensor)).all().item())
True
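A single sample has no padding, so an all-ones mask marks every element as valid. The sketch below assumes a hypothetical `SingleTensor` class that mirrors the `tensor` and `mask` properties documented here (it is not DanLing's PNTensor) and a hypothetical `masked_sum` helper. It shows how code written against `.tensor` and `.mask` can accept an unbatched sample unchanged.

```python
import torch
from torch import Tensor


class SingleTensor(Tensor):
    """Hypothetical stand-in mirroring the `tensor` and `mask` properties."""

    @property
    def tensor(self) -> Tensor:
        # The data of a single sample is the tensor itself.
        return self

    @property
    def mask(self) -> Tensor:
        # Every position is valid, so the mask is all ones.
        return torch.ones_like(self)


def masked_sum(x) -> Tensor:
    """Sum the valid elements of any object exposing `.tensor` and `.mask`."""
    return (x.tensor * x.mask).sum()


sample = SingleTensor([1.0, 2.0, 3.0])
total = masked_sum(sample)
```

Because the mask is all ones, `masked_sum` reduces to a plain sum for a single sample, and the same function could be reused on a padded batch whose mask zeroes out the padding.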

contact property

Python
contact: Tensor

Identical to self.

Returns:

Type Description
Tensor

Examples:

Python Console Session
>>> tensor = torch.tensor([1, 2, 3])
>>> pn_tensor = PNTensor([1, 2, 3])
>>> bool((tensor == pn_tensor).all())
True
>>> bool((tensor == pn_tensor.contact).all())
True