Bases: Tensor
Wrapper for tensors to be converted to `NestedTensor`.

`PNTensor` is a subclass of `torch.Tensor`. It implements three additional properties, as `NestedTensor` does: `tensor`, `mask`, and `concat`.

Although it is possible to construct a `NestedTensor` directly in a dataset, the best practice is to do so in `collate_fn`. `PNTensor` is introduced to smooth this process: convert the tensors that should become a `NestedTensor` into a `PNTensor`, and the PyTorch `DataLoader` will automatically collate `PNTensor`s into a `NestedTensor`.
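The collation idea can be pictured with a small self-contained sketch. It does not use danling: `MiniPN` and `pad_collate` are hypothetical stand-ins written only for illustration, and the real collation path inside danling may work differently. The sketch marks each dataset item with a `Tensor` subclass, then a `collate_fn` zero-pads the variable-length items and builds a boolean validity mask, roughly the pair of tensors that `NestedTensor.tensor` and `NestedTensor.mask` are assumed to expose.

```python
import torch
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


class MiniPN(Tensor):
    """Hypothetical stand-in for PNTensor: a plain Tensor subclass that marks items."""


def pad_collate(batch):
    """Hypothetical collate_fn: turn a list of 1-D MiniPN items into (padded, mask)."""
    # Drop the subclass so the padding below works on ordinary tensors.
    items = [item.as_subclass(Tensor) for item in batch]
    # Right-pad every item with zeros up to the longest length in the batch.
    padded = pad_sequence(items, batch_first=True, padding_value=0)
    # A position is valid when its index is smaller than that item's own length.
    lengths = torch.tensor([item.size(0) for item in items])
    mask = torch.arange(padded.size(1)).unsqueeze(0) < lengths.unsqueeze(1)
    return padded, mask


# Each dataset item is wrapped as the marker type before batching.
dataset = [
    torch.tensor([1, 2, 3]).as_subclass(MiniPN),
    torch.tensor([4, 5]).as_subclass(MiniPN),
]
loader = DataLoader(dataset, batch_size=2, collate_fn=pad_collate)
padded, mask = next(iter(loader))
# padded.tolist() == [[1, 2, 3], [4, 5, 0]]
# mask.tolist()   == [[True, True, True], [True, True, False]]
```

Keeping the wrapping in the dataset and the padding in `collate_fn` means each item only needs to know its own length; only the batch-level step has to know the longest length.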
Source code in danling/tensors/nested_tensor.py
```python
import torch
from torch import Tensor


class PNTensor(Tensor):
    r"""
    Wrapper for tensors to be converted to `NestedTensor`.

    `PNTensor` is a subclass of `torch.Tensor`.
    It implements three additional properties, as `NestedTensor` does: `tensor`, `mask`, and `concat`.

    Although it is possible to construct a `NestedTensor` directly in a dataset,
    the best practice is to do so in `collate_fn`.
    `PNTensor` is introduced to smooth this process.
    Convert tensors that should become a `NestedTensor` into a `PNTensor`,
    and the PyTorch `DataLoader` will automatically collate `PNTensor` into `NestedTensor`.
    """

    @property
    def tensor(self) -> Tensor:
        r"""
        Identical to `self`.

        Returns:
            (torch.Tensor):

        Examples:
            >>> tensor = torch.tensor([1, 2, 3])
            >>> pn_tensor = PNTensor([1, 2, 3])
            >>> bool((tensor == pn_tensor).all())
            True
            >>> bool((tensor == pn_tensor.tensor).all())
            True
        """
        return self

    @property
    def mask(self) -> Tensor:
        r"""
        Identical to `torch.ones_like(self)`.

        Returns:
            (torch.Tensor):

        Examples:
            >>> pn_tensor = PNTensor([1, 2, 3])
            >>> bool((pn_tensor.mask == torch.ones_like(pn_tensor)).all().item())
            True
        """
        return torch.ones_like(self)

    @property
    def concat(self) -> Tensor:
        r"""
        Identical to `self`.

        Returns:
            (torch.Tensor):

        Examples:
            >>> tensor = torch.tensor([1, 2, 3])
            >>> pn_tensor = PNTensor([1, 2, 3])
            >>> bool((tensor == pn_tensor).all())
            True
            >>> bool((tensor == pn_tensor.concat).all())
            True
        """
        return self

    def new_empty(self, *args, **kwargs):
        # Wrap the result so `new_empty` returns a PNTensor rather than a plain Tensor.
        return PNTensor(super().new_empty(*args, **kwargs))
```
tensor
property
Identical to `self`.
Returns:
`torch.Tensor`
Examples:
```pycon
>>> tensor = torch.tensor([1, 2, 3])
>>> pn_tensor = PNTensor([1, 2, 3])
>>> bool((tensor == pn_tensor).all())
True
>>> bool((tensor == pn_tensor.tensor).all())
True
```
mask
property
Identical to `torch.ones_like(self)`.
Returns:
`torch.Tensor`
Examples:
```pycon
>>> pn_tensor = PNTensor([1, 2, 3])
>>> bool((pn_tensor.mask == torch.ones_like(pn_tensor)).all().item())
True
```
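Why an all-ones mask is the right value for a single item can be seen by padding. This sketch uses plain tensors rather than `PNTensor`, and it assumes (this page does not confirm it) that a batch-level mask marks valid positions with 1 and padded positions with 0. Every element of an unpadded item is valid, so its mask is all ones; zero-padding those per-item masks then yields a mask aligned with the zero-padded batch.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two variable-length items, as they would look before batching.
items = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]

# Per-item masks in the style of PNTensor.mask: torch.ones_like of each item.
item_masks = [torch.ones_like(item) for item in items]

# Zero-padding the per-item masks marks padded positions as invalid (0),
# giving a batch-level mask with the same shape as the padded batch.
batch = pad_sequence(items, batch_first=True, padding_value=0)
batch_mask = pad_sequence(item_masks, batch_first=True, padding_value=0)
# batch.tolist()      == [[1, 2, 3], [4, 5, 0]]
# batch_mask.tolist() == [[1, 1, 1], [1, 1, 0]]
```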
concat
property
Identical to `self`.
Returns:
`torch.Tensor`
Examples:
```pycon
>>> tensor = torch.tensor([1, 2, 3])
>>> pn_tensor = PNTensor([1, 2, 3])
>>> bool((tensor == pn_tensor).all())
True
>>> bool((tensor == pn_tensor.concat).all())
True
```
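The reason `concat` can simply return `self` is easiest to see next to a batch. This sketch assumes, without confirmation from this page, that `NestedTensor.concat` joins the valid (unpadded) elements of all items along the first dimension; it uses plain tensors only. The same flat result can be recovered from a padded batch by boolean indexing with its mask, and for a single item the concatenation is just the item itself.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

items = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]

# Assumed meaning of `concat` for a batch: valid elements of every item joined along dim 0.
batch_concat = torch.cat(items)

# The same values recovered from a padded batch and its boolean mask (row-major order).
padded = pad_sequence(items, batch_first=True, padding_value=0)
mask = pad_sequence([torch.ones_like(t) for t in items], batch_first=True).bool()
from_padded = padded[mask]

# With a single item there is nothing to join, so concatenation is the identity,
# which matches PNTensor.concat returning `self`.
single_concat = torch.cat([items[0]])
# batch_concat.tolist() == [1, 2, 3, 4, 5]
```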