NestedTensor

danling.tensors.nested_tensor

Core NestedTensor class: a packed variable-length tensor with __torch_function__ / __torch_dispatch__ integration.

This module defines the storage contract (_values, _offsets, _physical_shape, _permutation), construction, metadata management, materialization (.tensor, .mask, .concat), serialization, and dispatch entry points.
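To make the storage contract concrete, here is a minimal plain-Python sketch of the packing idea for 1-D sequences. It is not the library's implementation: the `pack` and `unpack` helpers are hypothetical names, lists stand in for tensors, and `_permutation` is left out because it only matters for multi-dimensional elements.

```python
from itertools import accumulate


def pack(sequences):
    """Pack 1-D sequences into (values, offsets, physical_shape)."""
    # values: every element laid end to end, with no padding
    values = [item for seq in sequences for item in seq]
    sizes = [len(seq) for seq in sequences]
    # offsets: cumulative element counts, length B + 1, always starting at 0
    offsets = [0, *accumulate(sizes)]
    # physical_shape: one row per element holding its shape
    physical_shape = [[size] for size in sizes]
    return values, offsets, physical_shape


def unpack(values, offsets):
    """Recover the original sequences from the packed form."""
    return [values[start:end] for start, end in zip(offsets, offsets[1:])]


values, offsets, physical_shape = pack([[1, 2, 3], [4, 5]])
print(values)                   # [1, 2, 3, 4, 5]
print(offsets)                  # [0, 3, 5]
print(physical_shape)           # [[3], [2]]
print(unpack(values, offsets))  # [[1, 2, 3], [4, 5]]
```

Element i occupies `values[offsets[i]:offsets[i + 1]]`, which is why the offsets have shape (B+1,).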

NestedTensor

Bases: Tensor

A container for variable-length tensors that enables efficient batch operations.

NestedTensor addresses a common problem in deep learning: batching sequences of different lengths. Instead of padding every sequence up front or bucketing sequences by length, it stores the elements in packed form and exposes padded and masked views of the same data.

The class provides three main views of the data:

- .tensor: A padded tensor with zeros (or another padding value) in place of missing elements
- .mask: A boolean mask indicating which elements are real and which are padding
- .concat: The packed tensor containing all elements concatenated without padding
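The relationship between the three views can be sketched with plain Python lists. This is an illustration only, assuming 1-D elements and the default mask convention; `padded_views` is a hypothetical helper, not part of the NestedTensor API.

```python
def padded_views(sequences, padding_value=0):
    """Return (tensor, mask, concat) analogues for a batch of 1-D sequences."""
    max_len = max((len(seq) for seq in sequences), default=0)
    # tensor: each sequence right-padded to the longest length
    tensor = [list(seq) + [padding_value] * (max_len - len(seq)) for seq in sequences]
    # mask: True where the padded tensor holds a real element
    mask = [[True] * len(seq) + [False] * (max_len - len(seq)) for seq in sequences]
    # concat: all real elements, in order, without padding
    concat = [item for seq in sequences for item in seq]
    return tensor, mask, concat


tensor, mask, concat = padded_views([[1, 2, 3], [4, 5]])
print(tensor)  # [[1, 2, 3], [4, 5, 0]]
print(mask)    # [[True, True, True], [True, True, False]]
print(concat)  # [1, 2, 3, 4, 5]
```

Selecting the positions of the padded view where the mask is True yields the same elements as the concatenated view.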

When indexing a NestedTensor, the behavior depends on the index type:

1. Integer index (nt[0]): Returns a single tensor without padding
2. Slice index (nt[:]): Returns a new NestedTensor containing the selected batch elements
3. Tuple index (nt[:, 1:]): Returns a new NestedTensor with the specified sliced shape
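The three indexing cases can be mimicked on a list of lists. The sketch below is a simplified model and not the library's dispatch logic: `index_batch` is a hypothetical function, and it only handles an integer, a slice, or a two-item tuple of (batch index, per-element index).

```python
def index_batch(sequences, index):
    """Mimic integer, slice, and (batch, element) tuple indexing on a ragged batch."""
    if isinstance(index, int):
        # Integer index: one element, returned without padding
        return sequences[index]
    if isinstance(index, slice):
        # Slice index: a sub-batch of whole elements
        return sequences[index]
    # Tuple index: select batch elements, then slice inside each one
    batch_index, element_index = index
    return [seq[element_index] for seq in sequences[batch_index]]


batch = [[1, 2, 3], [4, 5]]
print(index_batch(batch, 0))                              # [1, 2, 3]
print(index_batch(batch, slice(None, 2)))                 # [[1, 2, 3], [4, 5]]
print(index_batch(batch, (slice(None), slice(1, None))))  # [[2, 3], [5]]
```

The last call corresponds to nt[:, 1:]: the inner slice is applied to each element separately, so the result stays ragged.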

Attributes:

Name | Type | Description
--- | --- | ---
_values | Tensor | Packed tensor data
_offsets | Tensor | Top-level cumulative element counts, shape (B+1,)
_permutation | tuple[int, ...] | Canonical logical-to-packed dimension permutation
_physical_shape | Tensor | Per-element physical shapes, shape (B, max_ndim)
batch_first | bool | Whether the first dimension is the batch dimension (B, N, *). If False, the first dimension is the sequence dimension (N, B, *).
padding_value | float | Value used for padding in the padded tensor
mask_value | bool | Boolean fill value for padding positions in generated masks. With mask_value=False (default), valid positions are True and padding is False. With mask_value=True, padding positions are True and valid positions are False.
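The effect of mask_value is easiest to see side by side. The following is a plain-Python sketch of the convention described above, assuming 1-D elements described only by their lengths; `build_mask` is a hypothetical helper rather than a NestedTensor method.

```python
def build_mask(lengths, mask_value=False):
    """Build a padding mask in which padding positions equal mask_value."""
    max_len = max(lengths, default=0)
    return [
        # A position is valid when it lies inside the sequence; valid
        # positions get the opposite of mask_value, padding gets mask_value.
        [(position < length) != mask_value for position in range(max_len)]
        for length in lengths
    ]


print(build_mask([3, 2]))
# [[True, True, True], [True, True, False]]
print(build_mask([3, 2], mask_value=True))
# [[False, False, False], [False, False, True]]
```

The two masks are element-wise negations of each other, so code that expects a "True means padding" mask can use mask_value=True instead of inverting the default mask.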

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
*tensors | Iterable[Tensor] | Variable-length tensors or sequences to store | required
batch_first | bool | Whether to use batch-first representation. | True
padding_value | SupportsFloat | Value to use for padding. | 0.0
mask_value | bool | Boolean fill value used for padding positions in masks. | False

Raises:

Type | Description
--- | ---
ValueError | If tensors is not an iterable

Examples:

Basic usage:

>>> nested_tensor = NestedTensor(torch.tensor([1, 2, 3]), torch.tensor([4, 5]))
>>> nested_tensor.shape
torch.Size([2, 3])
>>> nested_tensor.tensor  # Padded representation
tensor([[1, 2, 3],
        [4, 5, 0]])
>>> nested_tensor.mask  # Mask showing real vs padding values
tensor([[ True,  True,  True],
        [ True,  True, False]])
>>> nested_tensor.concat  # Concatenated version (no padding)
tensor([1, 2, 3, 4, 5])

Indexing:

>>> nested_tensor[0]  # First tensor (no padding)
tensor([1, 2, 3])
>>> nested_tensor[:2]  # Returns a NestedTensor slice
NestedTensor([
    [1, 2, 3],
    [4, 5]
])
>>> nested_tensor[:, 1:]  # Slice operations return a new NestedTensor
NestedTensor([
    [2, 3],
    [5]
])

Type conversion:

>>> nested_tensor.to(torch.float).tensor
tensor([[1., 2., 3.],
        [4., 5., 0.]])
>>> nested_tensor.half().tensor
tensor([[1., 2., 3.],
        [4., 5., 0.]], dtype=torch.float16)

Conversion to Python types:

>>> nested_tensor.tolist()
[[1, 2, 3], [4, 5]]

Creating from Python lists:

>>> NestedTensor(*[[1, 2, 3], [4, 5]])
NestedTensor([
    [1, 2, 3],
    [4, 5]
])
Source code in danling/tensors/nested_tensor.py
class NestedTensor(torch.Tensor):
    r"""
    A container for variable-length tensors that enables efficient batch operations.

    `NestedTensor` solves a fundamental problem in deep learning: handling sequences of different lengths
    in batch operations. Instead of excessive padding or complex bucketing, `NestedTensor` provides an
    elegant solution that maintains both efficiency and usability.

    The class provides three main views of the data:
    - `.tensor`: A padded tensor with zeros (or other value) in place of missing elements
    - `.mask`: A boolean mask indicating which elements are real vs padding
    - `.concat`: The packed tensor containing all elements concatenated without padding

    When indexing a `NestedTensor`, the behavior depends on the index type:
    1. Integer index (`nt[0]`): Returns a single tensor without padding
    2. Slice index (`nt[:]`): Returns a new `NestedTensor` containing the selected batch elements
    3. Tuple index (`nt[:, 1:]`): Returns a new `NestedTensor` with the specified sliced shape

    Attributes:
        _values: Packed tensor data
        _offsets: Top-level cumulative element counts, shape (B+1,)
        _permutation: Canonical logical-to-packed dimension permutation
        _physical_shape: Per-element physical shapes, shape (B, max_ndim)
        batch_first: Whether the first dimension is the batch dimension (B, N, *)
            If `False`, the first dimension is the sequence dimension (N, B, *)
        padding_value: Value used for padding in the padded tensor
        mask_value: Boolean fill value for padding positions in generated masks.
            - ``mask_value=False`` (default): valid positions are ``True`` and padding is ``False``.
            - ``mask_value=True``: padding positions are ``True`` and valid positions are ``False``.

    Args:
        *tensors: Variable-length tensors or sequences to store
        batch_first: Whether to use batch-first representation.
        padding_value: Value to use for padding.
        mask_value: Boolean fill value used for padding positions in masks.

    Raises:
        ValueError: If `tensors` is not an iterable

    Examples:
        Basic usage:
        >>> nested_tensor = NestedTensor(torch.tensor([1, 2, 3]), torch.tensor([4, 5]))
        >>> nested_tensor.shape
        torch.Size([2, 3])
        >>> nested_tensor.tensor  # Padded representation
        tensor([[1, 2, 3],
                [4, 5, 0]])
        >>> nested_tensor.mask  # Mask showing real vs padding values
        tensor([[ True,  True,  True],
                [ True,  True, False]])
        >>> nested_tensor.concat  # Concatenated version (no padding)
        tensor([1, 2, 3, 4, 5])

        Indexing:
        >>> nested_tensor[0]  # First tensor (no padding)
        tensor([1, 2, 3])
        >>> nested_tensor[:2]  # Returns a NestedTensor slice
        NestedTensor([
            [1, 2, 3],
            [4, 5]
        ])
        >>> nested_tensor[:, 1:]  # Slice operations return a new NestedTensor
        NestedTensor([
            [2, 3],
            [5]
        ])

        Type conversion:
        >>> nested_tensor.to(torch.float).tensor
        tensor([[1., 2., 3.],
                [4., 5., 0.]])
        >>> nested_tensor.half().tensor
        tensor([[1., 2., 3.],
                [4., 5., 0.]], dtype=torch.float16)

        Conversion to Python types:
        >>> nested_tensor.tolist()
        [[1, 2, 3], [4, 5]]

        Creating from Python lists:
        >>> NestedTensor(*[[1, 2, 3], [4, 5]])
        NestedTensor([
            [1, 2, 3],
            [4, 5]
        ])
    """

    _values: Tensor
    _offsets: Tensor
    _permutation: tuple[int, ...]
    _physical_shape: Tensor
    _flatten_sentinel: Tensor = torch.empty(0)
    _logical_shape: torch.Size
    _batch_first: bool
    _padding_value: float
    _mask_value: bool
    _pin_memory: bool
    _packed_sizes: tuple[int, ...] | None
    _element_shapes: tuple[tuple[int, ...], ...] | None
    _cached_storage: tuple[Tensor, ...] | None
    _cached_hierarchical_offsets: tuple[Tensor, ...] | None
    _cached_tensor_view: tuple[bool, float, tuple[int, int, int], Tensor] | None
    _cached_mask_view: tuple[bool, bool, tuple[int, int], Tensor] | None
    _SERIALIZATION_VERSION = 1

    # Construction & Initialization

    @staticmethod
    def __new__(
        cls,
        *tensors: Iterable[Tensor],
        dtype: torch.dtype | None = None,
        device: torch.device | None = None,
        requires_grad: bool | None = None,
        pin_memory: bool = False,
        batch_first: bool = True,
        padding_value: SupportsFloat = 0.0,
        mask_value: bool = False,
    ):
        if len(tensors) == 1 and not isinstance(tensors[0], Tensor):
            if isinstance(tensors[0], Iterable):
                tensors = tuple(tensors[0])  # type: ignore
            else:
                raise ValueError(f"tensors must be an Iterable, but got {type(tensors[0])}.")

        # Validate and convert tensors
        validated = cls._coerce_tensors(
            tensors, dtype=dtype, device=device, requires_grad=requires_grad, pin_memory=pin_memory
        )

        # Determine dtype/device from validated tensors or fallbacks
        out_dtype = validated[0].dtype if validated else (dtype or torch.get_default_dtype())
        out_device = validated[0].device if validated else (device or torch.device("cpu"))

        # Pack into values, offsets, tensor-shape metadata, and Python metadata.
        values, offsets, shape_tensor, packed_sizes, element_shapes = cls._pack(
            validated,
            dtype=out_dtype,
            device=out_device,
        )
        values = cls._maybe_pin_values(values, pin_memory)
        permutation = cls._permutation_from_element_shapes(element_shapes)

        # Compute logical shape
        logical_shape = cls._compute_logical_shape(validated, batch_first)
        if requires_grad is not None and values.requires_grad != requires_grad:
            values.requires_grad_(requires_grad)
        out_requires_grad = values.requires_grad

        result = torch.Tensor._make_wrapper_subclass(
            cls,
            logical_shape,
            dtype=out_dtype,
            device=out_device,
            requires_grad=out_requires_grad,
        )
        result._values = values
        result._offsets = offsets
        result._permutation = permutation
        result._physical_shape = shape_tensor
        result._logical_shape = logical_shape
        result._set_runtime_config(
            batch_first=batch_first,
            padding_value=padding_value,
            mask_value=mask_value,
        )
        result._pin_memory = bool(pin_memory and values.device.type == "cpu" and values.is_pinned())
        result._packed_sizes = packed_sizes
        result._element_shapes = element_shapes
        result._invalidate_transient_caches()
        cls._validate_packed_metadata(
            result._values,
            result._offsets,
            result._physical_shape,
            permutation=result._permutation,
            logical_shape=result._logical_shape,
            batch_first=result.batch_first,
            packed_sizes=result._packed_sizes,
            element_shapes=result._element_shapes,
        )
        return result

    def __init__(self, *args, **kwargs):
        pass  # All init in __new__

    # ------------------------------------------------------------------
    # Packed representation helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _coerce_tensors(
        tensors: tuple,
        *,
        dtype: torch.dtype | None = None,
        device: torch.device | None = None,
        requires_grad: bool | None = None,
        pin_memory: bool = False,
    ) -> tuple[Tensor, ...]:
        if not isinstance(tensors, Iterable):
            raise ValueError(f"tensors must be an Iterable, but got {type(tensors)}.")
        if isinstance(tensors, Tensor) and hasattr(tensors, "unbind"):
            tensors = tensors.unbind()

        result: list[Tensor] = []
        common_device: torch.device | None = None
        common_ndim: int | None = None
        # Only track dtype promotion when the caller did not specify an explicit dtype.
        # When dtype is given, t.to(device, dtype=dtype) already handles casting in
        # the first pass, so the promotion loop and second pass are both unnecessary.
        needs_promotion = dtype is None
        common_dtype: torch.dtype | None = None

        for t in tensors:
            if not isinstance(t, Tensor):
                t = torch.tensor(t, dtype=dtype, device=device, pin_memory=pin_memory)
            else:
                t = t.to(device, dtype=dtype)
            if requires_grad is not None:
                t.requires_grad_(requires_grad)

            if common_device is None:
                common_device = t.device
            elif t.device != common_device:
                raise ValueError(
                    f"All tensors in NestedTensor must be on the same device, but got {common_device} and {t.device}"
                )

            if needs_promotion:
                if common_dtype is None:
                    common_dtype = t.dtype
                else:
                    common_dtype = torch.promote_types(common_dtype, t.dtype)

            if common_ndim is None:
                common_ndim = t.ndim
            elif t.ndim != common_ndim:
                raise ValueError(
                    f"All tensors must have the same number of dimensions, got ndim {common_ndim} and {t.ndim}. "
                    "If using a DataLoader with drop_last=False, squeeze the last batch before constructing "
                    "NestedTensor."
                )

            result.append(t)

        if not result:
            return ()

        # Second pass only when dtype=None AND promotion actually changed the dtype.
        if needs_promotion and common_dtype is not None and any(t.dtype != common_dtype for t in result):
            return tuple(t.to(dtype=common_dtype) for t in result)
        return tuple(result)

    @staticmethod
    def _pack(
        tensors: tuple[Tensor, ...],
        *,
        dtype: torch.dtype | None = None,
        device: torch.device | None = None,
        permutation: tuple[int, ...] | None = None,
    ) -> tuple[Tensor, Tensor, Tensor, tuple[int, ...], tuple[tuple[int, ...], ...]]:
        r"""Pack a sequence of tensors into values, offsets, tensor metadata, and Python metadata."""
        if not tensors:
            return (
                torch.empty(0, dtype=dtype or torch.get_default_dtype(), device=device),
                torch.zeros(1, dtype=torch.long),
                torch.empty(0, 0, dtype=torch.long),
                (),
                (),
            )

        max_ndim = max(t.ndim for t in tensors)
        element_shapes = tuple(tuple(int(dim) for dim in t.shape) for t in tensors)

        # Offsets and shape_tensor are metadata - always on CPU to avoid CUDA syncs.
        shape_tensor = torch.tensor([list(t.shape) + [0] * (max_ndim - t.ndim) for t in tensors], dtype=torch.long)
        if max_ndim == 0:
            values = torch.stack(tensors)
            sizes = torch.ones(len(tensors), dtype=torch.long)
            packed_sizes = tuple(1 for _ in tensors)
        else:
            if permutation is None:
                varying_dims, static_dims = NestedTensor._pack_layout_from_element_shapes(element_shapes)
                permutation = varying_dims + static_dims
            else:
                permutation = tuple(int(dim) for dim in permutation)
                if len(permutation) != max_ndim or tuple(sorted(permutation)) != tuple(range(max_ndim)):
                    raise ValueError(f"Invalid permutation dims {permutation} for tensors with rank {max_ndim}")
                ragged_rank = len(NestedTensor._hierarchical_level_sizes_from_element_shapes(element_shapes))
                varying_dims = permutation[:ragged_rank]
                static_dims = permutation[ragged_rank:]
            packed = []
            packed_sizes_list = []
            identity_permutation = tuple(range(max_ndim))
            for tensor, shape in zip(tensors, element_shapes):
                packed_size = NestedTensor._packed_size_from_shape(shape, varying_dims)
                packed_sizes_list.append(packed_size)
                packed_tensor = tensor if permutation == identity_permutation else tensor.permute(permutation)
                suffix_shape = tuple(shape[dim] for dim in static_dims)
                packed.append(packed_tensor.reshape((packed_size, *suffix_shape) if suffix_shape else (packed_size,)))
            values = torch.cat(packed, dim=0)
            sizes = torch.tensor(packed_sizes_list, dtype=torch.long)
            packed_sizes = tuple(packed_sizes_list)
        offsets = torch.zeros(len(tensors) + 1, dtype=torch.long)
        torch.cumsum(sizes, dim=0, out=offsets[1:])

        return values, offsets, shape_tensor, packed_sizes, element_shapes

    @staticmethod
    def _maybe_pin_values(values: Tensor, pin_memory: bool) -> Tensor:
        r"""Pin packed storage when requested and the values live on CPU."""
        if pin_memory and values.device.type == "cpu" and not values.is_pinned():
            return values.pin_memory()
        return values

    @staticmethod
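    def _trim_shape(shape: Sequence[int]) -> tuple[int, ...]:
        r"""Drop trailing zero entries (the rank padding in a ``shape_tensor`` row) and return the rest as ints.

        A genuine trailing size-0 dim is indistinguishable from padding and is dropped as well.
        """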
        end = len(shape)
        while end > 0 and shape[end - 1] == 0:
            end -= 1
        return tuple(int(shape[i]) for i in range(end))

    @staticmethod
    def _shape_numel(shape: tuple[int, ...]) -> int:
        r"""Return the number of elements implied by ``shape`` (1 for an empty shape)."""
        size = 1
        for dim in shape:
            size *= int(dim)
        return size

    @classmethod
    def _permutation_from_element_shapes(cls, element_shapes: tuple[tuple[int, ...], ...]) -> tuple[int, ...]:
        varying_dims, static_dims = cls._pack_layout_from_element_shapes(element_shapes)
        return varying_dims + static_dims

    @classmethod
    def _permutation_from_physical_shape(
        cls,
        physical_shape: Tensor,
        element_shapes: tuple[tuple[int, ...], ...] | None,
    ) -> tuple[int, ...]:
        varying_dims, static_dims = cls._pack_layout_meta(physical_shape, element_shapes)
        return varying_dims + static_dims

    @staticmethod
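    def _offsets_from_sizes(sizes: Sequence[int], *, dtype: torch.dtype = torch.long) -> Tensor:
        r"""Build a CPU offsets tensor of shape ``(len(sizes) + 1,)``: a leading 0, then the cumulative sum of ``sizes``."""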
        offsets = torch.empty((len(sizes) + 1,), dtype=dtype)
        offsets[0] = 0
        if sizes:
            offsets[1:] = torch.cumsum(torch.tensor(sizes, dtype=dtype), dim=0)
        return offsets
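
    # Worked example (a hand-traced sketch, not a doctest): per-element packed
    # sizes (3, 5, 2) give a leading 0 followed by the running total.
    #     NestedTensor._offsets_from_sizes((3, 5, 2))  # tensor([0, 3, 8, 10])
    #     NestedTensor._offsets_from_sizes(())         # tensor([0])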

    @staticmethod
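    def _meta_tensor_equal(lhs: Tensor, rhs: Tensor) -> bool:
        r"""Compare two metadata tensors by value; if either is a fake tensor, compare by identity only."""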
        if _is_fake_tensor(lhs) or _is_fake_tensor(rhs):
            return lhs is rhs
        if lhs is rhs:
            return True
        if lhs.shape != rhs.shape:
            return False
        return bool(torch.equal(lhs, rhs))

    @classmethod
    def _hierarchical_level_sizes_from_element_shapes(
        cls,
        element_shapes: tuple[tuple[int, ...], ...],
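    ) -> tuple[tuple[int, ...], ...]:
        r"""Return one tuple of segment sizes per varying dim, in packing order.

        Level ``k`` lists, element by element, the size along the ``k``-th varying dim,
        repeated once for every index combination over the preceding varying dims.
        """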
        if not element_shapes:
            return ()
        varying_dims, _ = cls._pack_layout_from_element_shapes(element_shapes)
        if not varying_dims:
            return ()

        level_sizes: list[tuple[int, ...]] = []
        prefix_products = [1] * len(element_shapes)
        for dim in varying_dims:
            sizes: list[int] = []
            next_prefix_products: list[int] = []
            for shape, prefix in zip(element_shapes, prefix_products):
                dim_size = int(shape[dim])
                sizes.extend([dim_size] * prefix)
                next_prefix_products.append(prefix * dim_size)
            level_sizes.append(tuple(sizes))
            prefix_products = next_prefix_products
        return tuple(level_sizes)
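
    # Worked example (a hand-traced sketch, not a doctest): for element shapes
    # ((2, 3), (4, 5)) both dims vary, so there are two levels.
    #     level 0: (2, 4)              # size of dim 0 for each element
    #     level 1: (3, 3, 5, 5, 5, 5)  # size of dim 1, repeated 2 and 4 times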

    @classmethod
    def _hierarchical_level_sizes_from_physical_shape(
        cls,
        physical_shape: Tensor,
        element_shapes: tuple[tuple[int, ...], ...] | None = None,
    ) -> tuple[tuple[int, ...], ...]:
        if physical_shape.numel() == 0:
            return ()
        if element_shapes is not None:
            return cls._hierarchical_level_sizes_from_element_shapes(element_shapes)
        if _is_fake_tensor(physical_shape):
            return ()

        varying_dims, _ = cls._pack_layout_meta(physical_shape, None)
        if not varying_dims:
            return ()

        shape_rows = tuple(cls._trim_shape(row) for row in physical_shape.tolist())
        level_sizes: list[tuple[int, ...]] = []
        prefix_products = [1] * len(shape_rows)
        for dim in varying_dims:
            sizes: list[int] = []
            next_prefix_products: list[int] = []
            for shape, prefix in zip(shape_rows, prefix_products):
                dim_size = int(shape[dim]) if dim < len(shape) else 0
                sizes.extend([dim_size] * prefix)
                next_prefix_products.append(prefix * dim_size)
            level_sizes.append(tuple(sizes))
            prefix_products = next_prefix_products
        return tuple(level_sizes)

    @staticmethod
    def _inverse_permutation(permutation: tuple[int, ...]) -> tuple[int, ...]:
        r"""Return ``inverse`` such that ``inverse[permutation[i]] == i`` for every ``i``."""
        inverse = [0] * len(permutation)
        for index, dim in enumerate(permutation):
            inverse[dim] = index
        return tuple(inverse)
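
    # Worked example (a hand-traced sketch, not a doctest):
    #     NestedTensor._inverse_permutation((2, 0, 1))  # (1, 2, 0)
    # so x.permute((2, 0, 1)).permute((1, 2, 0)) restores the original dim order of x.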

    @classmethod
    def _pack_layout_from_element_shapes(
        cls,
        element_shapes: tuple[tuple[int, ...], ...],
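    ) -> tuple[tuple[int, ...], tuple[int, ...]]:
        r"""Split element dims into ``(varying_dims, static_dims)``.

        A dim is static when every element has the same size along it. If every dim
        is static, dim 0 is reported as the varying dim and the rest as static.
        """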
        if not element_shapes:
            return (), ()
        ndim = len(element_shapes[0])
        if ndim == 0:
            return (), ()
        reference = element_shapes[0]
        static_dims = [
            dim
            for dim in range(ndim)
            if all(len(shape) == ndim and shape[dim] == reference[dim] for shape in element_shapes[1:])
        ]
        if len(static_dims) == ndim:
            static_dims = list(range(1, ndim))
        static_dims_tuple = tuple(static_dims)
        varying_dims = tuple(dim for dim in range(ndim) if dim not in static_dims_tuple)
        return varying_dims, static_dims_tuple
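
    # Worked examples (hand-traced sketches, not doctests), shown as
    # element_shapes -> (varying_dims, static_dims):
    #     ((3, 8), (5, 8)) -> ((0,), (1,))  # only dim 0 differs
    #     ((3, 8), (3, 9)) -> ((1,), (0,))  # only dim 1 differs
    #     ((3, 8), (3, 8)) -> ((0,), (1,))  # nothing differs, dim 0 is used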

    @classmethod
    def _pack_layout_meta(
        cls,
        physical_shape: Tensor,
        element_shapes: tuple[tuple[int, ...], ...] | None,
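    ) -> tuple[tuple[int, ...], tuple[int, ...]]:
        r"""Return ``(varying_dims, static_dims)``, preferring ``element_shapes`` when available.

        Without usable ``element_shapes``, a dim is static when its ``physical_shape``
        column is constant. An empty batch reports dim 0 as varying.
        """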
        if element_shapes is not None and (element_shapes or int(physical_shape.size(1)) == 0):
            return cls._pack_layout_from_element_shapes(element_shapes)
        ndim = int(physical_shape.size(1))
        if ndim == 0:
            return (), ()
        if physical_shape.size(0) == 0:
            return (0,), tuple(range(1, ndim))
        static_dims = tuple(
            dim
            for dim in range(ndim)
            if bool(torch.equal(physical_shape[:, dim], physical_shape[:1, dim].expand(physical_shape.size(0))))
        )
        if len(static_dims) == ndim:
            static_dims = tuple(range(1, ndim))
        varying_dims = tuple(dim for dim in range(ndim) if dim not in static_dims)
        return varying_dims, static_dims

    @staticmethod
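    def _packed_size_from_shape(shape: tuple[int, ...], varying_dims: tuple[int, ...]) -> int:
        r"""Return how many packed rows one element contributes: the product of its sizes over ``varying_dims``.

        Returns 1 when ``shape`` or ``varying_dims`` is empty.
        """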
        if not shape or not varying_dims:
            return 1
        size = 1
        for dim in varying_dims:
            size *= int(shape[dim])
        return size

    @classmethod
    def _python_meta_from_packed(
        cls,
        values: Tensor,
        offsets: Tensor,
        shape_tensor: Tensor,
        *,
        packed_sizes: tuple[int, ...] | None = None,
        element_shapes: tuple[tuple[int, ...], ...] | None = None,
    ) -> tuple[tuple[int, ...], tuple[tuple[int, ...], ...]]:
        if packed_sizes is None:
            packed_sizes = tuple(int(size) for size in (offsets[1:] - offsets[:-1]).tolist())
        if element_shapes is None:
            element_shapes = tuple(cls._trim_shape(shape) for shape in shape_tensor.tolist())
        return packed_sizes, element_shapes

    @classmethod
    @torch._dynamo.disable
    def _infer_python_meta_from_packed(
        cls,
        values: Tensor,
        offsets: Tensor,
        shape_tensor: Tensor,
        *,
        packed_sizes: tuple[int, ...] | None = None,
        element_shapes: tuple[tuple[int, ...], ...] | None = None,
    ) -> tuple[tuple[int, ...], tuple[tuple[int, ...], ...]]:
        return cls._python_meta_from_packed(
            values,
            offsets,
            shape_tensor,
            packed_sizes=packed_sizes,
            element_shapes=element_shapes,
        )

    @staticmethod
    def _compute_logical_shape(tensors: tuple[Tensor, ...], batch_first: bool) -> torch.Size:
        r"""Compute the logical shape [B, max_d0, max_d1, ...] from individual tensors."""
        if not tensors:
            return torch.Size((0,))
        ndim = max(t.dim() for t in tensors)
        if ndim == 0:
            return torch.Size((len(tensors),))
        size = [max(t.shape[i] if i < len(t.shape) else 0 for t in tensors) for i in range(ndim)]
        size.insert(0 if batch_first else 1, len(tensors))
        return torch.Size(size)
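
    # Worked example (a hand-traced sketch, not a doctest): two tensors with
    # shapes (3, 8) and (5, 8) have per-dim maxima [5, 8]; the batch size 2 is
    # then inserted at index 0 or 1.
    #     batch_first=True  -> torch.Size([2, 5, 8])
    #     batch_first=False -> torch.Size([5, 2, 8])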

    @staticmethod
    def _logical_shape_from_physical_shape(physical_shape: Tensor, offsets: Tensor, batch_first: bool) -> torch.Size:
        r"""Compute logical shape from packed metadata without unpacking elements."""
        batch_size = len(offsets) - 1
        if batch_size == 0:
            return torch.Size((0,))
        if physical_shape.numel() == 0:
            return torch.Size((batch_size,))
        size = [int(physical_shape[:, d].max().item()) for d in range(physical_shape.size(1))]
        while size and size[-1] == 0:
            size.pop()
        size.insert(0 if batch_first else 1, batch_size)
        return torch.Size(size)

    @staticmethod
    def _batch_dim_from_logical_shape(logical_shape: torch.Size, batch_first: bool) -> int:
        r"""Return the batch dimension index for a logical NestedTensor shape."""
        return 0 if len(logical_shape) <= 1 or batch_first else 1

    @classmethod
    def _validate_packed_metadata(
        cls,
        values: Tensor,
        offsets: Tensor,
        shape_tensor: Tensor,
        *,
        permutation: tuple[int, ...],
        logical_shape: torch.Size,
        batch_first: bool,
        packed_sizes: tuple[int, ...] | None,
        element_shapes: tuple[tuple[int, ...], ...] | None,
    ) -> None:
        r"""Validate that packed storage and metadata describe a coherent NestedTensor layout."""
        if offsets.device.type != "cpu":
            raise ValueError(f"offsets must be on CPU, got {offsets.device}")
        if shape_tensor.device.type != "cpu":
            raise ValueError(f"shape_tensor must be on CPU, got {shape_tensor.device}")
        if offsets.dim() != 1:
            raise ValueError(f"offsets must be 1-D, got shape {tuple(offsets.shape)}")
        if shape_tensor.dim() != 2:
            raise ValueError(f"shape_tensor must be 2-D, got shape {tuple(shape_tensor.shape)}")
        if offsets.dtype.is_floating_point or offsets.dtype.is_complex or offsets.dtype == torch.bool:
            raise ValueError(f"offsets must use an integer dtype, got {offsets.dtype}")
        if shape_tensor.dtype.is_floating_point or shape_tensor.dtype.is_complex or shape_tensor.dtype == torch.bool:
            raise ValueError(f"shape_tensor must use an integer dtype, got {shape_tensor.dtype}")

        batch_size = int(shape_tensor.size(0))
        if offsets.numel() != batch_size + 1:
            raise ValueError(
                "offsets length must equal batch size + 1, got "
                f"offsets.numel()={offsets.numel()}, batch_size={batch_size}"
            )

        physical_rank = int(shape_tensor.size(1))
        if len(logical_shape) != physical_rank + 1:
            raise ValueError(
                "logical shape rank must equal physical rank + 1, got "
                f"logical rank={len(logical_shape)}, physical rank={physical_rank}"
            )
        batch_dim = cls._batch_dim_from_logical_shape(logical_shape, batch_first)
        logical_batch = logical_shape[batch_dim]
        if logical_batch != batch_size:
            raise ValueError(f"logical batch size {logical_batch} does not match metadata batch size {batch_size}")

        if len(permutation) != physical_rank or tuple(sorted(int(dim) for dim in permutation)) != tuple(
            range(physical_rank)
        ):
            raise ValueError(f"Invalid permutation dims {permutation} for shape with {physical_rank} dims")

        if packed_sizes is not None:
            if len(packed_sizes) != batch_size:
                raise ValueError(
                    f"packed_sizes must have one entry per element, got {len(packed_sizes)} for batch size {batch_size}"
                )
            if any(int(size) < 0 for size in packed_sizes):
                raise ValueError("packed_sizes must be non-negative")
            if sum(int(size) for size in packed_sizes) != int(values.shape[0]):
                raise ValueError("packed_sizes must sum to the packed values length")

        if element_shapes is not None:
            if len(element_shapes) != batch_size:
                raise ValueError(
                    "element_shapes must have one entry per element, got "
                    f"{len(element_shapes)} for batch size {batch_size}"
                )
            normalized_shapes = tuple(tuple(int(dim) for dim in shape) for shape in element_shapes)
            if any(len(shape) != physical_rank for shape in normalized_shapes):
                raise ValueError(
                    f"element_shapes rank must match physical rank {physical_rank}, got {normalized_shapes}"
                )
            if any(any(dim < 0 for dim in shape) for shape in normalized_shapes):
                raise ValueError("element_shapes must be non-negative")
            if not _is_fake_tensor(shape_tensor):
                shape_rows = tuple(tuple(int(size) for size in row) for row in shape_tensor.tolist())
                if normalized_shapes != shape_rows:
                    raise ValueError("element_shapes must match shape_tensor exactly")

        if _is_fake_tensor(offsets) or _is_fake_tensor(shape_tensor):
            return

        if bool((shape_tensor < 0).any()):
            raise ValueError("shape_tensor must be non-negative")
        if int(offsets[0].item()) != 0:
            raise ValueError("offsets must start at 0")
        deltas = offsets[1:] - offsets[:-1]
        if bool((deltas < 0).any()):
            raise ValueError("offsets must be monotonically non-decreasing")
        if packed_sizes is None and int(offsets[-1].item()) != int(values.shape[0]):
            raise ValueError(
                f"offsets[-1] must equal packed values length, got offsets[-1]={int(offsets[-1].item())} "
                f"and values.shape[0]={int(values.shape[0])}"
            )

    def _validate_metadata(self) -> None:
        r"""Validate the current packed storage and metadata."""
        type(self)._validate_packed_metadata(
            self._values,
            self._offsets,
            self._physical_shape,
            permutation=self._permutation,
            logical_shape=self._logical_shape,
            batch_first=self.batch_first,
            packed_sizes=self._packed_sizes,
            element_shapes=self._element_shapes,
        )

    @staticmethod
    def _coerce_batch_first(value: bool) -> bool:
        if not isinstance(value, bool):
            raise TypeError(f"batch_first must be a bool, got {type(value).__name__}")
        return value

    @staticmethod
    def _coerce_mask_value(value: bool) -> bool:
        if not isinstance(value, bool):
            raise TypeError(f"mask_value must be a bool, got {type(value).__name__}")
        return value

    @staticmethod
    def _coerce_padding_value(value: SupportsFloat) -> float:
        try:
            return float(value)
        except (TypeError, ValueError) as exc:
            raise TypeError(f"padding_value must be float-convertible, got {type(value).__name__}") from exc

    def _set_runtime_config(
        self,
        *,
        batch_first: bool,
        padding_value: SupportsFloat,
        mask_value: bool,
    ) -> None:
        self._batch_first = type(self)._coerce_batch_first(batch_first)
        self._padding_value = type(self)._coerce_padding_value(padding_value)
        self._mask_value = type(self)._coerce_mask_value(mask_value)

    def _invalidate_transient_caches(self) -> None:
        r"""Drop all lazily materialized views derived from packed storage."""
        self._cached_storage = None
        self._cached_hierarchical_offsets = None
        self._cached_tensor_view = None
        self._cached_mask_view = None

    def _values_cache_token(self) -> tuple[int, int, int]:
        r"""Return a cache token for views that depend on packed values and layout metadata.

        Tensors created under ``torch.inference_mode`` do not track version counters,
        so reading ``_version`` on them raises. A constant token is returned while
        inference mode is enabled, which treats cached views as always valid there.
        """
        if torch.is_inference_mode_enabled():
            return (0, 0, 0)
        return (int(self._values._version), int(self._offsets._version), int(self._physical_shape._version))

    def _shape_cache_token(self) -> tuple[int, int]:
        r"""Return a cache token for views that depend only on shape metadata."""
        if torch.is_inference_mode_enabled():
            return (0, 0)
        return (int(self._offsets._version), int(self._physical_shape._version))

    @classmethod
    def _validate_serialized_state(cls, state: Mapping) -> None:
        required = (
            "_state_version",
            "_values",
            "_offsets",
            "_permutation",
            "_physical_shape",
            "_logical_shape",
            "batch_first",
            "padding_value",
            "mask_value",
            "_pin_memory",
            "_packed_sizes",
            "_element_shapes",
        )
        missing = [key for key in required if key not in state]
        if missing:
            raise KeyError(f"Serialized NestedTensor state is missing required keys: {', '.join(missing)}")
        version = state["_state_version"]
        if version != cls._SERIALIZATION_VERSION:
            raise ValueError(f"Unsupported NestedTensor state version {version}; expected {cls._SERIALIZATION_VERSION}")

    @classmethod
    @torch._dynamo.disable
    def _from_packed(
        cls,
        values: Tensor,
        offsets: Tensor,
        shape_tensor: Tensor,
        *,
        permutation: tuple[int, ...] | None = None,
        batch_first: bool = True,
        padding_value: float = 0.0,
        mask_value: bool = False,
        pin_memory: bool = False,
        outer_size: torch.Size | tuple | None = None,
        packed_sizes: tuple[int, ...] | None = None,
        element_shapes: tuple[tuple[int, ...], ...] | None = None,
    ) -> Self:
        r"""Construct a NestedTensor directly from packed representation."""
        # offsets and shape_tensor MUST live on CPU to avoid implicit CUDA syncs
        # when handlers call .item() / .tolist() on them.
        if offsets.device.type != "cpu":
            raise ValueError(f"offsets must be on CPU, got {offsets.device}")
        if shape_tensor.device.type != "cpu":
            raise ValueError(f"shape_tensor must be on CPU, got {shape_tensor.device}")

        if outer_size is not None:
            logical_shape = torch.Size(outer_size)
        else:
            logical_shape = cls._logical_shape_from_physical_shape(shape_tensor, offsets, batch_first)
        if packed_sizes is None and not _is_fake_tensor(offsets):
            packed_sizes = tuple(int(size) for size in (offsets[1:] - offsets[:-1]).tolist())
        if element_shapes is None and not _is_fake_tensor(shape_tensor):
            element_shapes = tuple(cls._trim_shape(shape) for shape in shape_tensor.tolist())

        if _is_fake_tensor(values) and not (_is_fake_tensor(offsets) and _is_fake_tensor(shape_tensor)):
            from torch._subclasses.fake_tensor import maybe_get_fake_mode

            fake_mode = maybe_get_fake_mode(values)
            if fake_mode is not None:
                if not _is_fake_tensor(offsets):
                    offsets = fake_mode.from_tensor(offsets, static_shapes=True, trace=False)
                if not _is_fake_tensor(shape_tensor):
                    shape_tensor = fake_mode.from_tensor(shape_tensor, static_shapes=True, trace=False)

        values = cls._maybe_pin_values(values, pin_memory)
        result = torch.Tensor._make_wrapper_subclass(
            cls,
            logical_shape,
            dtype=values.dtype,
            device=values.device,
            requires_grad=values.requires_grad,
        )
        result._values = values
        result._offsets = offsets
        result._permutation = (
            tuple(int(dim) for dim in permutation)
            if permutation is not None
            else cls._permutation_from_physical_shape(shape_tensor, element_shapes)
        )
        result._physical_shape = shape_tensor
        result._logical_shape = logical_shape
        result._set_runtime_config(
            batch_first=batch_first,
            padding_value=padding_value,
            mask_value=mask_value,
        )
        result._pin_memory = bool(pin_memory and values.device.type == "cpu" and values.is_pinned())
        result._packed_sizes = packed_sizes
        result._element_shapes = element_shapes
        result._invalidate_transient_caches()
        cls._validate_packed_metadata(
            result._values,
            result._offsets,
            result._physical_shape,
            permutation=result._permutation,
            logical_shape=result._logical_shape,
            batch_first=result.batch_first,
            packed_sizes=result._packed_sizes,
            element_shapes=result._element_shapes,
        )
        return result

    # ------------------------------------------------------------------
    # torch.compile support
    # ------------------------------------------------------------------

    def __tensor_flatten__(self):
        # During tracing, wrapper instances can be inspected while being built.
        # Only expose tensor attrs that already exist so Dynamo/FakeTensor can
        # inspect partially constructed wrapper subclasses safely.
        instance_attrs = vars(self)
        inner_tensors = [name for name in ("_values", "_offsets", "_physical_shape") if name in instance_attrs]
        if not inner_tensors:
            inner_tensors = ["_flatten_sentinel"]
        return inner_tensors, {
            "batch_first": getattr(self, "batch_first", True),
            "padding_value": getattr(self, "padding_value", 0.0),
            "mask_value": getattr(self, "mask_value", False),
            "pin_memory": getattr(self, "_pin_memory", False),
            "packed_sizes": getattr(self, "_packed_sizes", ()),
            "element_shapes": getattr(self, "_element_shapes", ()),
            "permutation": getattr(self, "_permutation", ()),
        }

    @classmethod
    def __tensor_unflatten__(cls, inner_tensors, ctx, outer_size, outer_stride):
        values = inner_tensors.get("_values", inner_tensors.get("_flatten_sentinel"))
        if values is None:
            raise RuntimeError("NestedTensor requires _values during tensor unflatten.")

        offsets = inner_tensors.get("_offsets")
        shape_tensor = inner_tensors.get("_physical_shape")
        if offsets is not None and shape_tensor is not None:
            # During backward, outer_size may reflect a transposed view
            # (e.g., (seq, batch, hidden) from MHA's batch-dim transpose).
            # Detect and correct so _from_packed validation passes.
            batch_size = len(offsets) - 1
            outer = tuple(outer_size)
            batch_first = ctx.get("batch_first", True)
            if len(outer) >= 2 and (
                (batch_first and outer[0] != batch_size and outer[1] == batch_size)
                or (not batch_first and outer[1] != batch_size and outer[0] == batch_size)
            ):
                outer = (outer[1], outer[0], *outer[2:])
            return cls._from_packed(
                values,
                offsets,
                shape_tensor,
                outer_size=outer,
                **ctx,
            )

        result = torch.Tensor._make_wrapper_subclass(
            cls,
            torch.Size(outer_size),
            dtype=values.dtype,
            device=values.device,
            requires_grad=values.requires_grad,
        )
        result._values = values
        if offsets is not None:
            result._offsets = offsets
        if shape_tensor is not None:
            result._physical_shape = shape_tensor
        result._logical_shape = torch.Size(outer_size)
        result._set_runtime_config(
            batch_first=ctx["batch_first"],
            padding_value=ctx["padding_value"],
            mask_value=ctx["mask_value"],
        )
        result._pin_memory = ctx["pin_memory"]
        result._packed_sizes = ctx["packed_sizes"]
        result._element_shapes = ctx["element_shapes"]
        result._permutation = tuple(int(dim) for dim in ctx["permutation"])
        result._invalidate_transient_caches()
        return result

    # ------------------------------------------------------------------
    # Dispatch
    # ------------------------------------------------------------------

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None) -> Any:
        if kwargs is None:
            kwargs = {}

        # Handle size() specially to avoid infinite recursion
        if func is torch.Tensor.size:
            self = args[0]
            dim = args[1] if len(args) > 1 else kwargs.get("dim")
            return self.size(dim)

        from .ops import NestedTensorFuncRegistry, _compile_unsupported, _is_compiling

        handler = NestedTensorFuncRegistry.get(func)
        if handler is not None:
            if _is_compiling() and not NestedTensorFuncRegistry.is_compile_safe(func, args, kwargs):
                name = getattr(func, "__qualname__", getattr(func, "__name__", repr(func)))
                _compile_unsupported(name, "handler is marked eager-only")
            return handler(*args, **kwargs)

        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None) -> Any:
        if kwargs is None:
            kwargs = {}

        from .ops import _compile_unsupported, _is_compiling

        if func in NestedTensorAtenRegistry:
            if _is_compiling() and not NestedTensorAtenRegistry.is_compile_safe(func, args, kwargs):
                name = getattr(func, "name", None)
                if callable(name):
                    name = name()
                _compile_unsupported(name or repr(func), "aten handler is marked eager-only")
            return NestedTensorAtenRegistry[func](func, args, kwargs)

        if _is_compiling():
            name = getattr(func, "name", None)
            if callable(name):
                name = name()
            _compile_unsupported(name or repr(func), "would fall back to per-element eager execution")
        return per_element_fallback(func, args, kwargs)

    # ------------------------------------------------------------------
    # Layout & Metadata Helpers
    # ------------------------------------------------------------------

    def _unpack(self) -> tuple[Tensor, ...]:
        r"""Reconstruct individual tensors from packed representation."""
        batch_size = len(self._offsets) - 1
        if batch_size == 0:
            return ()

        packed_sizes = self._packed_sizes
        if packed_sizes is None:
            if _is_fake_tensor(self._offsets):
                raise RuntimeError("NestedTensor packed sizes are unavailable for this instance.")
            packed_sizes = tuple(int(size) for size in (self._offsets[1:] - self._offsets[:-1]).tolist())

        element_shapes = self._element_shapes
        if element_shapes is None:
            element_shapes = tuple(tuple(int(dim) for dim in shape) for shape in self._original_shapes())

        splits = self._values.split(packed_sizes, dim=0)
        permutation = self._permutation
        if permutation:
            varying_dims = self._varying_dims
            static_dims = self._static_dims
        else:
            varying_dims, static_dims = type(self)._pack_layout_meta(self._physical_shape, element_shapes)
            permutation = varying_dims + static_dims
        inverse_permutation = type(self)._inverse_permutation(permutation)

        result = []
        for chunk, shape in zip(splits, element_shapes):
            if not shape:
                result.append(chunk[0])
            else:
                packed_shape = tuple(shape[dim] for dim in varying_dims) + tuple(shape[dim] for dim in static_dims)
                unpacked = chunk.reshape(packed_shape)
                if permutation != tuple(range(len(shape))):
                    unpacked = unpacked.permute(inverse_permutation)
                result.append(unpacked)
        return tuple(result)

    def _repack(self, tensors: Sequence) -> None:
        r"""Re-pack from already-validated tensors.

        Skips coercion: callers must ensure the tensors share device, dtype, and ndim.
        Internal paths satisfy this because their tensors originate from ``_unpack``
        or from ``__setitem__`` validation.
        """
        self._invalidate_transient_caches()
        tensors = tuple(tensors) if not isinstance(tensors, tuple) else tensors
        if tensors and len(self._permutation) != tensors[0].ndim:
            raise RuntimeError(
                "NestedTensor._repack received tensors with rank "
                f"{tensors[0].ndim} but current permutation has rank {len(self._permutation)}"
            )
        values, offsets, shape_tensor, packed_sizes, element_shapes = self._pack(
            tensors,
            permutation=self._permutation if tensors else None,
        )
        values = type(self)._maybe_pin_values(values, self._pin_memory)
        self._values = values
        self._offsets = offsets
        self._physical_shape = shape_tensor
        self._logical_shape = self._compute_logical_shape(tensors, self.batch_first)
        self._packed_sizes = packed_sizes
        self._element_shapes = element_shapes
        self._validate_metadata()

    @property
    def _hierarchical_offsets(self) -> tuple[Tensor, ...]:
        if self._cached_hierarchical_offsets is None:
            level_sizes = type(self)._hierarchical_level_sizes_from_physical_shape(
                self._physical_shape,
                self._element_shapes,
            )
            if not level_sizes:
                if self._element_shapes is None and self._packed_sizes is not None:
                    self._cached_hierarchical_offsets = (
                        type(self)._offsets_from_sizes(self._packed_sizes, dtype=self._offsets.dtype),
                    )
                elif self._element_shapes is None and _is_fake_tensor(self._physical_shape):
                    self._cached_hierarchical_offsets = (self._offsets,)
                else:
                    self._cached_hierarchical_offsets = ()
            else:
                self._cached_hierarchical_offsets = tuple(
                    type(self)._offsets_from_sizes(sizes, dtype=self._offsets.dtype) for sizes in level_sizes
                )
        return self._cached_hierarchical_offsets

    @property
    def _ragged_rank(self) -> int:
        return len(self._hierarchical_offsets)

    def _ragged_level_offsets(self, level: int = -1) -> Tensor:
        offsets = self._hierarchical_offsets
        if not offsets:
            return self._offsets
        return offsets[level]

    def _ragged_level_sizes(self, level: int = -1) -> Tensor:
        offsets = self._ragged_level_offsets(level)
        return offsets[1:] - offsets[:-1]

    @property
    def _varying_dims(self) -> tuple[int, ...]:
        ragged_rank = self._ragged_rank
        if ragged_rank <= 0:
            return ()
        if self._permutation:
            return tuple(int(dim) for dim in self._permutation[:ragged_rank])
        varying_dims, _ = type(self)._pack_layout_meta(self._physical_shape, self._element_shapes)
        return varying_dims

    @property
    def _static_dims(self) -> tuple[int, ...]:
        ragged_rank = self._ragged_rank
        if self._permutation:
            return tuple(int(dim) for dim in self._permutation[ragged_rank:])
        _, static_dims = type(self)._pack_layout_meta(self._physical_shape, self._element_shapes)
        return static_dims

    def _has_same_structure(self, other: Self) -> bool:
        r"""Return whether `other` has the same batch orientation, permutation, and ragged structure."""
        if self.batch_first != other.batch_first or self._permutation != other._permutation:
            return False
        if self._element_shapes is not None and other._element_shapes is not None:
            lhs_levels = type(self)._hierarchical_level_sizes_from_element_shapes(self._element_shapes)
            rhs_levels = type(self)._hierarchical_level_sizes_from_element_shapes(other._element_shapes)
            if lhs_levels or rhs_levels:
                return lhs_levels == rhs_levels
            return len(self) == len(other)
        lhs_offsets = self._hierarchical_offsets
        rhs_offsets = other._hierarchical_offsets
        if len(lhs_offsets) != len(rhs_offsets):
            return False
        if lhs_offsets:
            return all(type(self)._meta_tensor_equal(lhs, rhs) for lhs, rhs in zip(lhs_offsets, rhs_offsets))
        return type(self)._meta_tensor_equal(self._offsets, other._offsets)

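    def _has_same_layout(self, other: Self) -> bool:
        r"""Return whether `other` has the same structure, per-element shapes, and packed sizes (when known)."""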
        if not self._has_same_structure(other):
            return False
        if self._element_shapes is not None and other._element_shapes is not None:
            if self._element_shapes != other._element_shapes:
                return False
            if self._packed_sizes is not None and other._packed_sizes is not None:
                return self._packed_sizes == other._packed_sizes
            return True
        if (
            self._packed_sizes is not None
            and other._packed_sizes is not None
            and self._packed_sizes != other._packed_sizes
        ):
            return False
        if not type(self)._meta_tensor_equal(self._physical_shape, other._physical_shape):
            return False
        return type(self)._meta_tensor_equal(self._offsets, other._offsets)

    def _packed_flat_index(
        self,
        *,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.long,
    ) -> Tensor:
        r"""Return `arange` indices over the leading (packed) dimension of `_values`."""
        target_device = self.device if device is None else device
        leading = self._values.size(0) if self._values.dim() > 0 else self._values.numel()
        return torch.arange(leading, device=target_device, dtype=dtype)

    def _packed_batch_local_indices(
        self,
        flat_idx: Tensor | None = None,
        *,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.long,
    ) -> tuple[Tensor, Tensor]:
        r"""Map packed indices to `(batch index, index within that batch element)` using `_offsets`."""
        target_device = self.device if device is None else device
        if flat_idx is None:
            flat_idx = self._packed_flat_index(device=target_device, dtype=dtype)
        offsets = self._offsets.to(device=target_device, dtype=dtype)
        batch_idx = torch.searchsorted(offsets[1:], flat_idx, right=True)
        local_idx = flat_idx - offsets[batch_idx]
        return batch_idx, local_idx
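
    # Illustrative sketch for `_packed_batch_local_indices` above (standalone
    # torch; the offsets are made-up example values, not read from a real
    # NestedTensor). `searchsorted` against the upper offsets finds the owning
    # element, and subtracting that element's start offset gives the local index:
    #
    #     offsets = torch.tensor([0, 3, 5])        # two elements: 3 rows, then 2 rows
    #     flat_idx = torch.arange(5)               # packed row indices 0..4
    #     batch_idx = torch.searchsorted(offsets[1:], flat_idx, right=True)
    #     local_idx = flat_idx - offsets[batch_idx]
    #     # batch_idx -> tensor([0, 0, 0, 1, 1])
    #     # local_idx -> tensor([0, 1, 2, 0, 1])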

    def _packed_varying_coords(
        self,
        batch_idx: Tensor,
        local_idx: Tensor,
        *,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.long,
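    ) -> tuple[Tensor, ...]:
        r"""Decompose element-local indices into one coordinate per varying dim, using row-major strides."""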
        target_device = self.device if device is None else device
        varying_dims = self._varying_dims
        if not varying_dims:
            return ()

        varying_sizes = self._physical_shape[:, list(varying_dims)].to(device=target_device, dtype=dtype)[batch_idx]
        strides = torch.ones_like(varying_sizes)
        running = torch.ones(varying_sizes.size(0), dtype=dtype, device=target_device)
        for dim in range(varying_sizes.size(1) - 1, -1, -1):
            strides[:, dim] = running
            running = running * varying_sizes[:, dim]

        coords: list[Tensor] = []
        remainder = local_idx
        for dim in range(varying_sizes.size(1)):
            coord = remainder // strides[:, dim]
            coords.append(coord)
            remainder = remainder - coord * strides[:, dim]
        return tuple(coords)
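
    # Illustrative sketch for `_packed_varying_coords` above (standalone torch;
    # sizes and indices are made-up example values). With two varying dims of
    # sizes (2, 3), the row-major strides are (3, 1), so a local index splits
    # into coordinates by repeated floor division:
    #
    #     sizes = torch.tensor([[2, 3], [2, 3]])   # varying-dim sizes for two packed rows
    #     local_idx = torch.tensor([4, 5])
    #     stride0 = sizes[:, 1]                    # tensor([3, 3])
    #     coord0 = local_idx // stride0            # tensor([1, 1])
    #     coord1 = local_idx - coord0 * stride0    # tensor([1, 2]), since the last stride is 1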

    def _packed_dense_index(
        self,
        flat_idx: Tensor | None = None,
        *,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.long,
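    ) -> tuple[Tensor | slice, ...]:
        r"""Build an index into the batch-leading padded tensor: batch index, coords for varying dims, slices otherwise."""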
        target_device = self.device if device is None else device
        batch_idx, local_idx = self._packed_batch_local_indices(flat_idx, device=target_device, dtype=dtype)
        varying_dims = self._varying_dims
        coords = self._packed_varying_coords(batch_idx, local_idx, device=target_device, dtype=dtype)

        dense_index: list[Tensor | slice] = [batch_idx]
        coord_iter = iter(coords)
        for dim in range(self._physical_shape.size(1)):
            dense_index.append(next(coord_iter) if dim in varying_dims else slice(None))
        return tuple(dense_index)

    def _physical_shape_like_batch_dense(self, batch_dense_shape: Sequence[int]) -> Tensor:
        r"""Return per-element shapes for a batch-leading dense tensor with this NestedTensor's ragged structure."""
        expected_ndim = self._physical_shape.size(1) + 1
        if len(batch_dense_shape) != expected_ndim:
            raise ValueError(
                "Batch-leading dense tensor rank does not match NestedTensor layout, "
                f"expected {expected_ndim}, got {len(batch_dense_shape)}"
            )
        shape, _, _ = self._shape_meta_from_components(
            replace_dims={int(dim): int(batch_dense_shape[dim + 1]) for dim in self._static_dims}
        )
        return shape

    def _element_shapes_like_batch_dense(
        self,
        batch_dense_shape: Sequence[int],
    ) -> tuple[tuple[int, ...], ...] | None:
        r"""Return Python element-shape metadata for a batch-leading dense tensor with this NestedTensor's layout."""
        expected_ndim = self._physical_shape.size(1) + 1
        if len(batch_dense_shape) != expected_ndim:
            raise ValueError(
                "Batch-leading dense tensor rank does not match NestedTensor layout, "
                f"expected {expected_ndim}, got {len(batch_dense_shape)}"
            )
        _, _, element_shapes = self._shape_meta_from_components(
            replace_dims={int(dim): int(batch_dense_shape[dim + 1]) for dim in self._static_dims}
        )
        return element_shapes

    def _shape_meta_from_components(
        self,
        *,
        prefix: Sequence[int] = (),
        keep_dims: Sequence[int] | None = None,
        suffix: Sequence[int] = (),
        replace_dims: Mapping[int, int] | None = None,
    ) -> tuple[Tensor, tuple[int, ...] | None, tuple[tuple[int, ...], ...] | None]:
        r"""Build packed shape metadata by keeping selected dims and applying constant prefix/suffix updates."""
        if keep_dims is None:
            keep_dims = tuple(range(self._physical_shape.size(1)))
        keep_dims = tuple(int(dim) for dim in keep_dims)
        prefix = tuple(int(size) for size in prefix)
        suffix = tuple(int(size) for size in suffix)
        updates = {int(dim): int(size) for dim, size in (replace_dims or {}).items()}

        if self._element_shapes:
            element_shapes_list: list[tuple[int, ...]] = []
            for element_shape in self._element_shapes:
                projected = [*prefix, *(int(element_shape[dim]) for dim in keep_dims), *suffix]
                for dim, size in updates.items():
                    projected[dim] = size
                element_shapes_list.append(tuple(projected))
            element_shapes = tuple(element_shapes_list)
            max_ndim = max(len(shape) for shape in element_shapes)
            shape = torch.tensor(
                [list(shape) + [0] * (max_ndim - len(shape)) for shape in element_shapes],
                dtype=torch.long,
            )
            return shape, self._packed_sizes_like(element_shapes), element_shapes

        parts: list[Tensor] = []
        batch_size = len(self)
        if prefix:
            parts.append(self._physical_shape.new_tensor(prefix).reshape(1, -1).expand(batch_size, -1))
        if keep_dims:
            parts.append(self._physical_shape[:, list(keep_dims)].clone())
        if suffix:
            parts.append(self._physical_shape.new_tensor(suffix).reshape(1, -1).expand(batch_size, -1))
        if parts:
            shape = torch.cat(parts, dim=1)
        else:
            shape = self._physical_shape.new_empty((batch_size, 0))
        for dim, size in updates.items():
            shape[:, dim] = size
        return shape, None, None

    def _max_physical_dims(self) -> tuple[int, ...]:
        r"""Return the maximum per-element size for each physical dimension (excluding batch)."""
        batch_dim = 0 if self.batch_first else 1
        return tuple(int(size) for index, size in enumerate(self._logical_shape) if index != batch_dim)

    def _logical_shape_from_physical_dims(self, physical_dims: Sequence[int]) -> torch.Size:
        r"""Build a logical outer shape from non-batch physical-dimension sizes."""
        physical_dims = tuple(int(size) for size in physical_dims)
        batch_size = len(self)
        if self.batch_first:
            return torch.Size((batch_size, *physical_dims))
        if not physical_dims:
            return torch.Size((batch_size,))
        return torch.Size((physical_dims[0], batch_size, *physical_dims[1:]))
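
    # Illustrative sketch for `_logical_shape_from_physical_dims` above (plain
    # Python plus `torch.Size`; the sizes are made-up example values). The batch
    # size goes first when `batch_first` is true, otherwise into position 1:
    #
    #     batch_size, physical_dims = 4, (7, 8)
    #     torch.Size((batch_size, *physical_dims))                        # torch.Size([4, 7, 8])
    #     torch.Size((physical_dims[0], batch_size, *physical_dims[1:]))  # torch.Size([7, 4, 8])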

    def _logical_shape_from_components(
        self,
        *,
        prefix: Sequence[int] = (),
        keep_dims: Sequence[int] | None = None,
        suffix: Sequence[int] = (),
        replace_dims: Mapping[int, int] | None = None,
    ) -> torch.Size:
        r"""Build a logical outer shape by projecting the current physical-dimension extents."""
        physical_dims = list(self._max_physical_dims())
        if keep_dims is None:
            keep_dims = tuple(range(len(physical_dims)))
        projected = [*(int(prefix_dim) for prefix_dim in prefix), *(physical_dims[int(dim)] for dim in keep_dims)]
        projected.extend(int(suffix_dim) for suffix_dim in suffix)
        for dim, size in (replace_dims or {}).items():
            projected[int(dim)] = int(size)
        return self._logical_shape_from_physical_dims(projected)

    def _leading_dim_preserving_meta(
        self,
        suffix: Sequence[int],
    ) -> tuple[Tensor, torch.Size, tuple[int, ...] | None, tuple[tuple[int, ...], ...] | None]:
        r"""Build metadata for ops that preserve the first per-element dim and replace all trailing dims uniformly."""
        keep_dims = (0,) if self._physical_shape.size(1) > 0 else ()
        shape, packed_sizes, element_shapes = self._shape_meta_from_components(keep_dims=keep_dims, suffix=suffix)
        return shape, self._leading_dim_preserving_outer_size(suffix), packed_sizes, element_shapes

    def _leading_dim_preserving_outer_size(self, suffix: Sequence[int]) -> torch.Size:
        r"""Return logical outer size for ops that preserve per-element dim-0 and replace trailing dims uniformly."""
        suffix = tuple(int(size) for size in suffix)
        batch_size = len(self)
        batch_dim = 0 if self.batch_first else 1
        logical = list(self._logical_shape)
        non_batch = [int(logical[index]) for index in range(len(logical)) if index != batch_dim]

        new_non_batch: list[int] = []
        if self._physical_shape.size(1) > 0 and non_batch:
            new_non_batch.append(non_batch[0])
        new_non_batch.extend(suffix)

        if self.batch_first:
            return torch.Size((batch_size, *new_non_batch))
        if not new_non_batch:
            return torch.Size((batch_size,))
        return torch.Size((new_non_batch[0], batch_size, *new_non_batch[1:]))

    def _drop_trailing_physical_dims_meta(
        self,
        count: int,
        *,
        suffix: Sequence[int] = (),
    ) -> tuple[Tensor, tuple[int, ...] | None, tuple[tuple[int, ...], ...] | None]:
        r"""Build metadata after dropping trailing per-element dims and optionally appending a dense suffix."""
        keep_dims = tuple(range(max(self._physical_shape.size(1) - int(count), 0)))
        return self._shape_meta_from_components(keep_dims=keep_dims, suffix=suffix)

    def _replace_trailing_physical_dims_meta(
        self,
        trailing_sizes: Sequence[int],
    ) -> tuple[Tensor, tuple[int, ...] | None, tuple[tuple[int, ...], ...] | None]:
        r"""Build metadata after replacing the last physical dims with uniform sizes."""
        trailing_sizes = tuple(int(size) for size in trailing_sizes)
        if not trailing_sizes:
            return self._shape_meta_from_components()
        ndim = self._physical_shape.size(1)
        if len(trailing_sizes) > ndim:
            raise ValueError(f"Cannot replace {len(trailing_sizes)} trailing dims for per-element rank {ndim}")
        start = ndim - len(trailing_sizes)
        return self._shape_meta_from_components(
            replace_dims={start + index: size for index, size in enumerate(trailing_sizes)}
        )

    def _permutation_after_dropping_trailing_dims(self, count: int) -> tuple[int, ...]:
        r"""Return the canonical permutation after dropping trailing physical dims."""
        count = int(count)
        new_rank = max(self._physical_shape.size(1) - count, 0)
        if not self._permutation:
            return tuple(range(new_rank))
        return tuple(int(dim) for dim in self._permutation if dim < new_rank)

    def _permutation_after_replacing_trailing_dims(self, removed_count: int, added_count: int) -> tuple[int, ...]:
        r"""Return the canonical permutation after replacing trailing physical dims with a new suffix."""
        removed_count = int(removed_count)
        added_count = int(added_count)
        retained_rank = max(self._physical_shape.size(1) - removed_count, 0)
        retained = self._permutation_after_dropping_trailing_dims(removed_count)
        appended = tuple(range(retained_rank, retained_rank + added_count))
        return retained + appended

    def _scalar_result_meta(
        self,
    ) -> tuple[Tensor, Tensor, torch.Size, tuple[int, ...] | None, tuple[tuple[int, ...], ...] | None]:
        r"""Build metadata for one-scalar-per-element outputs."""
        shape, packed_sizes, element_shapes = self._shape_meta_from_components(keep_dims=())
        offsets = torch.arange(len(self) + 1, dtype=self._offsets.dtype, device=self._offsets.device)
        logical_shape = type(self)._logical_shape_from_physical_shape(shape, self._offsets, self.batch_first)
        return offsets, shape, logical_shape, packed_sizes, element_shapes

    def _from_scalar_result_values(self, values: Tensor) -> Self:
        r"""Wrap one scalar per element using the canonical scalar-result metadata."""
        cls = type(self)
        offsets, shape, outer_size, packed_sizes, element_shapes = self._scalar_result_meta()
        return cls._from_packed(
            values,
            offsets,
            shape,
            batch_first=self.batch_first,
            padding_value=self.padding_value,
            mask_value=self.mask_value,
            pin_memory=self._pin_memory,
            outer_size=outer_size,
            packed_sizes=packed_sizes,
            element_shapes=element_shapes,
        )

    @classmethod
    def _cat_batch_packed(cls, tensors: Sequence[Self]) -> Self | None:
        r"""Merge batch-dim concatenation directly from packed storage when layouts are compatible."""
        if not tensors:
            raise ValueError("Expected at least one NestedTensor to concatenate.")

        ref = tensors[0]
        packed_rank = ref._values.dim()
        packed_tail = ref._values.shape[1:]
        reference_permutation = ref._permutation
        for tensor in tensors[1:]:
            if tensor._values.dim() != packed_rank:
                return None
            if tensor._permutation != reference_permutation:
                return None
            if packed_rank > 1 and tensor._values.shape[1:] != packed_tail:
                return None

        new_values = torch.cat([tensor._values for tensor in tensors], dim=0)

        offset_parts = []
        cumulative = 0
        for index, tensor in enumerate(tensors):
            offsets = tensor._offsets if index == 0 else tensor._offsets[1:] + cumulative
            offset_parts.append(offsets)
            cumulative += int(tensor._offsets[-1].item())
        new_offsets = torch.cat(offset_parts, dim=0)

        max_cols = max(tensor._physical_shape.size(1) for tensor in tensors)
        if max_cols > 0:
            padded_shapes = []
            for tensor in tensors:
                physical_shape = tensor._physical_shape
                if physical_shape.size(1) < max_cols:
                    physical_shape = torch.nn.functional.pad(physical_shape, (0, max_cols - physical_shape.size(1)))
                padded_shapes.append(physical_shape)
            new_physical_shape = torch.cat(padded_shapes, dim=0)
        else:
            new_physical_shape = torch.empty(len(new_offsets) - 1, 0, dtype=torch.long)

        batch_dim = 0 if ref.batch_first else 1
        out_logical = list(ref._logical_shape)
        if len(out_logical) <= batch_dim:
            out_logical.extend(0 for _ in range(batch_dim + 1 - len(out_logical)))
        out_logical[batch_dim] = sum(len(tensor) for tensor in tensors)
        for logical_dim in range(len(out_logical)):
            if logical_dim == batch_dim:
                continue
            out_logical[logical_dim] = max(
                int(tensor._logical_shape[logical_dim]) if logical_dim < len(tensor._logical_shape) else 0
                for tensor in tensors
            )

        packed_sizes = None
        if all(tensor._packed_sizes is not None for tensor in tensors):
            packed_sizes = tuple(size for tensor in tensors for size in cast(tuple[int, ...], tensor._packed_sizes))
        element_shapes = None
        if all(tensor._element_shapes is not None for tensor in tensors):
            element_shapes = tuple(
                shape for tensor in tensors for shape in cast(tuple[tuple[int, ...], ...], tensor._element_shapes)
            )

        return cls._from_packed(
            new_values,
            new_offsets,
            new_physical_shape,
            permutation=reference_permutation,
            batch_first=ref.batch_first,
            padding_value=ref.padding_value,
            mask_value=ref.mask_value,
            pin_memory=ref._pin_memory,
            outer_size=tuple(out_logical),
            packed_sizes=packed_sizes,
            element_shapes=element_shapes,
        )
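
    # Illustrative sketch of the offset merge in `_cat_batch_packed` above
    # (standalone torch; both offset tensors are made-up example values). Each
    # later tensor drops its leading 0 and is shifted by the packed length
    # accumulated so far:
    #
    #     a = torch.tensor([0, 3, 5])
    #     b = torch.tensor([0, 2, 6])
    #     merged = torch.cat([a, b[1:] + a[-1]])    # tensor([0, 3, 5, 7, 11])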

    @property
    def _storage(self) -> tuple[Tensor, ...]:
        r"""Return the per-element tensors, unpacked on first access and cached afterwards."""
        if self._cached_storage is None:
            self._cached_storage = self._unpack()
        return self._cached_storage

    @_storage.setter
    def _storage(self, tensors: Sequence) -> None:
        self._repack(tensors)

    # ------------------------------------------------------------------
    # Cached materialized views
    # ------------------------------------------------------------------

    def _tensor_cached_view(self) -> Tensor:
        cached = self._cached_tensor_view
        token = self._values_cache_token()
        if (
            cached is not None
            and cached[0] is self.batch_first
            and cached[1] == self.padding_value
            and cached[2] == token
        ):
            return cached[3]
        batch_leading = self._materialize_batch_leading(self.padding_value)
        tensor = batch_leading if self.batch_first else batch_leading.movedim(0, 1)
        self._cached_tensor_view = (self.batch_first, self.padding_value, token, tensor)
        return tensor

    def _mask_cached_view(self) -> Tensor:
        cached = self._cached_mask_view
        token = self._shape_cache_token()
        if cached is not None and cached[0] is self.batch_first and cached[1] is self.mask_value and cached[2] == token:
            return cached[3]
        mask = self._materialize_mask()
        self._cached_mask_view = (self.batch_first, self.mask_value, token, mask)
        return mask

    @property
    def tensor_mask(self) -> tuple[Tensor, Tensor]:
        r"""
        Return the padded tensor and its padding mask as a tuple `(tensor, mask)`.

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
            >>> nested_tensor.tensor_mask
            (tensor([[1, 2, 3],
                    [4, 5, 0]]), tensor([[ True,  True,  True],
                    [ True,  True, False]]))
        """
        return self._tensor_cached_view(), self._mask_cached_view()

    @property
    def tensor(self) -> Tensor:
        r"""
        Return a single dense tensor in which every element is padded with `padding_value` to a common shape.

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
            >>> nested_tensor.tensor
            tensor([[1, 2, 3],
                    [4, 5, 0]])
        """
        return self._tensor_cached_view()

    @property
    def mask(self) -> Tensor:
        r"""
        Padding mask of `tensor`.

        `mask_value` controls which boolean value denotes padding in this mask.
        With the default `mask_value=False`, `True` means valid data.

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
            >>> nested_tensor.mask
            tensor([[ True,  True,  True],
                    [ True,  True, False]])
        """
        return self._mask_cached_view()

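    def _mask_squeezes_channel(self) -> bool:
        r"""Return whether the mask omits the last per-element dim (more than one dim, and the last is static)."""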
        return self._physical_shape.size(1) > 1 and (self._physical_shape.size(1) - 1) in self._static_dims

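    def _materialize_mask(self) -> Tensor:
        r"""Build the boolean padding mask from per-element sizes, honoring `batch_first` and `mask_value`."""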
        batch_size = len(self)
        logical_shape = self._logical_shape
        squeeze_channel = self._mask_squeezes_channel()

        if batch_size == 0:
            mask_shape = logical_shape[:-1] if squeeze_channel else logical_shape
            return torch.empty(mask_shape, dtype=torch.bool, device=self.device)

        if self._physical_shape.size(1) == 0:
            return torch.full((batch_size,), not self.mask_value, dtype=torch.bool, device=self.device)

        effective_shape = logical_shape[:-1] if squeeze_channel else logical_shape
        batch_dim = 0 if self.batch_first else 1
        non_batch_sizes = [effective_shape[i] for i in range(len(effective_shape)) if i != batch_dim]

        sizes = self._physical_shape[:, :-1] if squeeze_channel else self._physical_shape
        sizes = sizes.to(device=self.device, dtype=torch.long)

        valid = _batch_leading_valid_mask_from_sizes(
            sizes,
            non_batch_sizes,
            device=self.device,
        )

        if not self.batch_first:
            valid = valid.movedim(0, 1)
        return valid if not self.mask_value else ~valid
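
    # Illustrative sketch for `_materialize_mask` above (standalone torch). The
    # helper `_batch_leading_valid_mask_from_sizes` is not shown in this section,
    # so the comparison below is only an assumption about what it computes in the
    # simplest case of a single ragged dim, with made-up lengths:
    #
    #     sizes = torch.tensor([3, 2])
    #     valid = torch.arange(3).unsqueeze(0) < sizes.unsqueeze(1)
    #     # valid  -> tensor([[ True,  True,  True], [ True,  True, False]])
    #     # ~valid is what `mask_value=True` would return (padding marked True)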

    def _materialize_batch_leading(self, fill_value) -> Tensor:
        r"""Materialize a padded dense tensor with the batch dimension in front."""
        _check_execution_guard(_ExecutionGuardKind.PADDED_MATERIALIZATION, "NestedTensor._materialize_batch_leading")
        logical_shape = self._logical_shape
        batch_size = len(self)
        if batch_size == 0:
            if self.batch_first:
                return torch.empty(logical_shape, dtype=self._values.dtype, device=self.device)
            if len(logical_shape) <= 1:
                return torch.empty((0,), dtype=self._values.dtype, device=self.device)
            non_batch = list(logical_shape)
            non_batch.pop(1)
            return torch.empty((0, *non_batch), dtype=self._values.dtype, device=self.device)

        if self._physical_shape.size(1) == 0:
            return self._values.reshape((batch_size,))

        tensor_shape = list(logical_shape)
        tensor_shape.pop(0 if self.batch_first else 1)
        batch_leading = self._values.new_full((batch_size, *tensor_shape), fill_value)
        if self._values.size(0) > 0:
            batch_leading[self._packed_dense_index(device=batch_leading.device)] = self._values
        return batch_leading
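
    # Illustrative sketch for `_materialize_batch_leading` above (standalone
    # torch; the values and indices are made-up and mirror the `.tensor` docstring
    # example). Packed values are scattered into a pre-filled dense tensor with
    # one advanced-indexing assignment:
    #
    #     values = torch.tensor([1, 2, 3, 4, 5])
    #     batch_idx = torch.tensor([0, 0, 0, 1, 1])
    #     local_idx = torch.tensor([0, 1, 2, 0, 1])
    #     padded = values.new_full((2, 3), 0)
    #     padded[batch_idx, local_idx] = values
    #     # padded -> tensor([[1, 2, 3], [4, 5, 0]])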

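    def _original_shapes(self) -> tuple[torch.Size, ...]:
        r"""Return each element's shape from `_element_shapes` when set, else from the rows of `_physical_shape`."""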
        if self._element_shapes is not None:
            return tuple(torch.Size(shape) for shape in self._element_shapes)
        if not _is_fake_tensor(self._physical_shape):
            return tuple(torch.Size(type(self)._trim_shape(row)) for row in self._physical_shape.tolist())
        raise RuntimeError("NestedTensor shape metadata is unavailable for this instance.")

    @property
    def concat(self) -> Tensor:
        r"""
        Flatten elements and concatenate along the ragged dimension (no padding).

        This is particularly useful when computing a loss or applying a `Linear` layer, since no work is spent on padding.

        Examples:
            >>> nested_tensor = NestedTensor([torch.randn(9, 8), torch.randn(11, 8)])
            >>> nested_tensor.concat.shape
            torch.Size([20, 8])
            >>> nested_tensor = NestedTensor([torch.randn(9, 9, 8), torch.randn(11, 11, 8)])
            >>> nested_tensor.concat.shape
            torch.Size([202, 8])
            >>> nested_tensor = NestedTensor([torch.randn(9, 9, 8, 6), torch.randn(11, 11, 8, 6)])
            >>> nested_tensor.concat.shape
            torch.Size([202, 8, 6])
            >>> nested_tensor = NestedTensor([torch.randn(9, 9, 8, 7), torch.randn(11, 11, 8, 6)])
            >>> nested_tensor.concat.shape
            torch.Size([1293, 8])
            >>> nested_tensor = NestedTensor([torch.randn(1, 9, 9, 5), torch.randn(1, 11, 11, 5)])
            >>> nested_tensor.concat.shape
            torch.Size([202, 1, 5])
        """
        if len(self._offsets) <= 1:
            return torch.empty(0, dtype=self._values.dtype, device=self.device)
        return self._values
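
    # Illustrative sketch of why `.concat` saves work (standalone torch; the
    # tensors are stand-ins with the shapes from the first docstring example,
    # not outputs of this class). A `Linear` layer applied to the packed rows
    # processes 20 rows, while the padded form processes 2 * 11 = 22:
    #
    #     linear = torch.nn.Linear(8, 4)
    #     packed = torch.randn(20, 8)       # 9 + 11 real rows
    #     padded = torch.zeros(2, 11, 8)    # 2 of the 22 rows are padding
    #     linear(packed).shape              # torch.Size([20, 4])
    #     linear(padded).shape              # torch.Size([2, 11, 4])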

    def concatenate(self) -> tuple[Tensor, tuple[torch.Size, ...]]:
        r"""
        Concatenate tensors along the ragged (padded) dimension and return the shapes needed for reconstruction.

        Returns:
            A tuple containing:
            - concat_tensor: The concatenated tensor (same as .concat property)
            - shapes: Tuple of original tensor shapes for reconstruction

        Examples:
            >>> nested_tensor = NestedTensor([torch.randn(9, 8), torch.randn(11, 8)])
            >>> concat_tensor, shapes = nested_tensor.concatenate()
            >>> concat_tensor.shape
            torch.Size([20, 8])
            >>> shapes
            (torch.Size([9, 8]), torch.Size([11, 8]))
            >>> reconstructed = NestedTensor.from_concatenated(concat_tensor, shapes)
            >>> torch.equal(nested_tensor.tensor, reconstructed.tensor)
            True
        """
        batch_size = len(self._offsets) - 1
        if batch_size == 0:
            return torch.empty(0, dtype=self._values.dtype, device=self.device), ()
        return self._values, self._original_shapes()

    # ------------------------------------------------------------------
    # Container protocol
    # ------------------------------------------------------------------

    def __len__(self) -> int:
        r"""Return the number of tensors in the batch."""
        if not hasattr(self, "_offsets"):
            with torch._C.DisableTorchFunctionSubclass():
                full_size = torch.Tensor.size(self)
            if len(full_size) == 0:
                return 0
            batch_dim = 0 if getattr(self, "batch_first", True) else (1 if len(full_size) > 1 else 0)
            return int(full_size[batch_dim])
        return len(self._offsets) - 1

    def __repr__(self):
        r"""Return a human-readable string representation of the NestedTensor."""
        if torch._dynamo.is_compiling():
            try:
                shape = tuple(self.size())
            except Exception:
                shape = "?"
            return (
                f"{self.__class__.__name__}(shape={shape}, dtype={self.dtype}, "
                f"device={self.device}, batch_first={getattr(self, 'batch_first', True)})"
            )

        try:
            from torch._subclasses.fake_tensor import is_fake

            for name in ("_values", "_offsets", "_physical_shape"):
                value = self.__dict__.get(name)
                if isinstance(value, Tensor) and is_fake(value):
                    shape = tuple(self.size())
                    return (
                        f"{self.__class__.__name__}(shape={shape}, dtype={self.dtype}, "
                        f"device={self.device}, batch_first={getattr(self, 'batch_first', True)})"
                    )
        except Exception:
            pass

        if not all(name in self.__dict__ for name in ("_values", "_offsets", "_physical_shape")):
            try:
                shape = tuple(self.size())
            except Exception:
                shape = "?"
            return (
                f"{self.__class__.__name__}(shape={shape}, dtype={self.dtype}, "
                f"device={self.device}, batch_first={getattr(self, 'batch_first', True)})"
            )

        if len(self) == 0:
            return self.__class__.__name__ + "()"

        storage = self._storage
        truncated = len(storage) > 10
        if truncated:
            storage = storage[:5]

        indent = "    "

        # Strip "tensor(" wrapper from each element's repr,
        # keeping PyTorch's internal number formatting (precision, alignment).
        data_parts = []
        for t in storage:
            s = repr(t)
            paren_idx = s.index("(")
            data = s[paren_idx + 1 : -1]  # noqa: E203
            # Re-indent continuation lines for multi-line element reprs (e.g. 2D tensors)
            if "\n" in data:
                lines = data.split("\n")
                data = lines[0] + "\n" + "\n".join(indent + " " + line.lstrip() for line in lines[1:])
            data_parts.append(data)

        result_lines = [self.__class__.__name__ + "(["]
        for i, part in enumerate(data_parts):
            suffix = "," if i < len(data_parts) - 1 or truncated else ""
            result_lines.append(indent + part + suffix)
        if truncated:
            result_lines.append(indent + f"... ({len(self)} tensors)")
        result_lines.append("])")
        return "\n".join(result_lines)

    def __bool__(self) -> bool:
        r"""Always raise: NestedTensor truthiness is ambiguous, and it never acts like a Python container."""
        raise RuntimeError(
            "Boolean value of NestedTensor is ambiguous. Use .numel(), .any(), .all(), or an explicit reduction."
        )

    def __iter__(self):
        r"""Iterate over the tensors in the batch."""
        _check_execution_guard(_ExecutionGuardKind.ITERATION, "NestedTensor.__iter__")
        return iter(self._storage)

    def __eq__(self, other):  # type: ignore[override]
        r"""Element-wise equality comparison."""
        try:
            return torch.eq(self, other)
        except TypeError:
            return NotImplemented

    def __ne__(self, other):  # type: ignore[override]
        r"""Element-wise inequality comparison."""
        try:
            return torch.ne(self, other)
        except TypeError:
            return NotImplemented

    # Python sets __hash__ = None when __eq__ is overridden in a subclass.
    # Preserve Tensor's identity hash so AOT/torch.compile memoization works.
    __hash__ = Tensor.__hash__

    # Arithmetic, comparison, and in-place operators are handled by the base
    # Tensor class, which routes through C++ → aten → __torch_dispatch__ →
    # aten_functions.py. No Python-level overrides needed.

    # ------------------------------------------------------------------
    # Conversion & Factory Methods
    # ------------------------------------------------------------------

    @classmethod
    def from_concatenated(cls, concat_tensor: Tensor, shapes: tuple[torch.Size, ...], **kwargs) -> Self:
        r"""
        Reconstruct a NestedTensor from a concatenated tensor and shape information.

        Args:
            concat_tensor: The concatenated tensor returned by concatenate()
            shapes: Tuple of original tensor shapes returned by concatenate()
            **kwargs: Additional arguments to pass to NestedTensor constructor

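        Returns:
            Reconstructed NestedTensor

        Raises:
            ValueError: If `shapes` is non-empty and the number of elements in `concat_tensor`
                differs from the total implied by `shapes`.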

        Examples:
            >>> nested_tensor = NestedTensor([torch.randn(9, 9, 8), torch.randn(11, 11, 8)])
            >>> concat_tensor, shapes = nested_tensor.concatenate()
            >>> reconstructed = NestedTensor.from_concatenated(concat_tensor, shapes)
            >>> concat_tensor.shape
            torch.Size([202, 8])
            >>> reconstructed.shape
            torch.Size([2, 11, 11, 8])
            >>> torch.equal(nested_tensor.tensor, reconstructed.tensor)
            True
        """
        if not shapes:
            if "dtype" not in kwargs:
                kwargs["dtype"] = concat_tensor.dtype
            if "device" not in kwargs:
                kwargs["device"] = concat_tensor.device
            return cls([], **kwargs)

        num_elements = [shape.numel() for shape in shapes]
        element_shapes = tuple(tuple(int(dim) for dim in shape) for shape in shapes)
        varying_dims, static_dims = cls._pack_layout_from_element_shapes(element_shapes)
        permutation = varying_dims + static_dims
        identity_permutation = tuple(range(len(element_shapes[0]))) if element_shapes and element_shapes[0] else ()

        if len(set(shapes)) == 1 and permutation == identity_permutation:
            shape = shapes[0]
            total_elements = sum(num_elements)
            if concat_tensor.numel() == total_elements:
                try:
                    reshaped = concat_tensor.reshape(len(shapes), *shape)
                except (RuntimeError, ValueError):
                    # The reshape fast path is opportunistic; a normal unpack fallback
                    # is expected for non-view-compatible inputs.
                    pass
                else:
                    tensors = [t.reshape(shape) for t in reshaped.unbind(0)]
                    return cls(tensors, **kwargs)

        packed_sizes = tuple(cls._packed_size_from_shape(shape, varying_dims) for shape in element_shapes)
        total_expected = sum(num_elements)
        num_provided = concat_tensor.numel()
        if num_provided != total_expected:
            raise ValueError(
                f"Concatenated tensor has {num_provided} elements "
                f"but expected {total_expected} based on shapes {shapes}"
            )

        tensors = []
        start = 0
        inverse_permutation = cls._inverse_permutation(permutation)
        for shape, packed_size in zip(element_shapes, packed_sizes):
            end = start + packed_size
            chunk = concat_tensor.narrow(0, start, packed_size)
            packed_shape = tuple(shape[dim] for dim in varying_dims) + tuple(shape[dim] for dim in static_dims)
            tensor_data = chunk.reshape(packed_shape)
            if permutation != tuple(range(len(shape))):
                tensor_data = tensor_data.permute(inverse_permutation)
            tensors.append(tensor_data)
            start = end

        return cls(tensors, **kwargs)

    @classmethod
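    def from_tensor_mask(cls, tensor: Tensor, mask: Tensor, *, batched: bool = False, **kwargs) -> Self:
        r"""
        Build a `NestedTensor` object from a padded `Tensor` and corresponding mask `Tensor`.

        Args:
            tensor: Padded tensor.
            mask: Mask tensor; it is cast to ``torch.bool`` before use.
                The mask uses the same convention as ``mask_value``:
                padding positions equal ``mask_value`` and valid positions equal ``not mask_value``.
            batched: When ``True`` and ``mask.ndim == 1``, treat ``mask`` as a per-batch-element
                selector (each valid entry selects a row from ``tensor``) rather than a
                contiguous-prefix length indicator.
            **kwargs: Additional arguments to pass to the `NestedTensor` constructor
                (``batch_first``, ``padding_value``, and ``mask_value`` are also read here).

        Raises:
            ValueError: If the batch dimensions of ``tensor`` and ``mask`` differ, or if a mask
                with two or more dimensions is not a valid (hierarchical) prefix mask.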

        Examples:
            >>> padded_tensor = torch.tensor([[1, 2, 3, 0, 0],
            ...                                [4, 5, 0, 0, 0],
            ...                                [6, 7, 8, 9, 0]])
            >>> mask_tensor = torch.tensor([[1, 1, 1, 0, 0],
            ...                             [1, 1, 0, 0, 0],
            ...                             [1, 1, 1, 1, 0]])
            >>> nested_tensor = NestedTensor.from_tensor_mask(padded_tensor, mask_tensor)
            >>> nested_tensor
            NestedTensor([
                [1, 2, 3],
                [4, 5],
                [6, 7, 8, 9]
            ])
        """
        mask = mask.to(dtype=torch.bool)
        mask_value = kwargs.get("mask_value", False)
        effective_mask = ~mask if mask_value else mask

        if mask.ndim == 1:
            if batched:
                indices = effective_mask.nonzero(as_tuple=False).flatten()
                return cls([tensor[int(i)] for i in indices], dtype=tensor.dtype, **kwargs)
            return cls(tensor[effective_mask], dtype=tensor.dtype, **kwargs)
        # ndim >= 2: batch setup is shared, per-element trim differs by rank
        batch_first = kwargs.get("batch_first", True)
        tensor_iter = tensor if batch_first else tensor.transpose(0, 1)
        mask_iter = effective_mask if batch_first else effective_mask.transpose(0, 1)
        if tensor_iter.size(0) != mask_iter.size(0):
            raise ValueError(f"Tensor/mask batch dimension mismatch: {tensor_iter.size(0)} vs {mask_iter.size(0)}")
        trimmed = []

        def _is_prefix_mask(mask_1d: Tensor) -> bool:
            count = int(mask_1d.sum().item())
            prefix = torch.arange(mask_1d.size(0), device=mask_1d.device, dtype=torch.long) < count
            return bool(torch.equal(mask_1d, prefix))

        def _is_hierarchical_prefix_mask(mask_nd: Tensor) -> bool:
            if mask_nd.dim() == 1:
                return _is_prefix_mask(mask_nd)
            leading_valid = mask_nd.reshape(mask_nd.size(0), -1).any(dim=1)
            valid_count = int(leading_valid.sum().item())
            prefix = torch.arange(mask_nd.size(0), device=mask_nd.device, dtype=torch.long) < valid_count
            if not torch.equal(leading_valid, prefix):
                return False
            return all(_is_hierarchical_prefix_mask(mask_nd[index]) for index in range(valid_count))

        if mask.ndim == 2:
            # 1-D per-element mask: only contiguous-prefix masks can be reconstructed
            # via slicing without changing dense semantics.
            counts = mask_iter.sum(dim=1, dtype=torch.long)
            prefix = torch.arange(mask_iter.size(1), device=mask_iter.device, dtype=torch.long).unsqueeze(0)
            prefix = prefix < counts.unsqueeze(1)
            if not torch.equal(mask_iter, prefix):
                raise ValueError(
                    "from_tensor_mask() with 2-D masks requires each row to be a valid prefix mask; "
                    "interior False gaps are not supported."
                )
            for t, count in zip(tensor_iter, counts.tolist()):
                trimmed.append(t[:count])
        else:
            # N-D per-element mask: only hierarchical ragged-prefix masks are representable as NestedTensor.
            extents = torch.zeros((mask_iter.size(0), mask_iter.dim() - 1), dtype=torch.long, device=mask_iter.device)
            nonzero = mask_iter.nonzero(as_tuple=False)
            if nonzero.numel() > 0:
                batch_index = nonzero[:, :1].expand(-1, extents.size(1))
                extents.scatter_reduce_(0, batch_index, nonzero[:, 1:] + 1, reduce="amax", include_self=False)
            extent_rows = extents.cpu().tolist()
            for t, em, sizes in zip(tensor_iter, mask_iter, extent_rows):
                if not _is_hierarchical_prefix_mask(em):
                    raise ValueError(
                        "from_tensor_mask() with N-D masks requires each element mask to be a valid hierarchical "
                        "ragged prefix; "
                        "interior False gaps are not supported."
                    )
                slices = tuple(slice(0, size) for size in sizes)
                t_slice = t[slices]
                m_slice = em[slices]
                valid_mask = m_slice
                if t_slice.dim() > m_slice.dim():
                    valid_mask = m_slice.view(m_slice.shape + (1,) * (t_slice.dim() - m_slice.dim()))
                trimmed.append(t_slice.masked_fill(~valid_mask, kwargs.get("padding_value", 0.0)))
        return cls(trimmed, dtype=tensor.dtype, **kwargs)

    def _dense_to_packed_values(self, tensor: Tensor) -> Tensor | None:
        r"""
        Convert a batch-aligned dense tensor to ``self``'s packed ``_values`` layout.

        Returns ``None`` when the dense tensor does not cover the current logical
        padded extents and we must fall back to per-element slicing/repacking.
        """
        batch_leading = tensor.to(device=self.device)
        if self.dim() > 1 and not self.batch_first:
            batch_leading = batch_leading.movedim(1, 0)

        logical_shape = list(self.shape)
        if logical_shape:
            batch_dim = 0 if self.dim() <= 1 or self.batch_first else 1
            logical_shape.pop(batch_dim)
        if batch_leading.dim() != len(logical_shape) + 1:
            return None

        dense_sizes = tuple(int(batch_leading.size(dim + 1)) for dim in range(batch_leading.dim() - 1))
        if any(dense_sizes[dim] < int(size) for dim, size in enumerate(logical_shape)):
            return None

        if logical_shape:
            batch_leading = batch_leading[(slice(None), *[slice(0, int(size)) for size in logical_shape])]

        if batch_leading.dim() <= 1:
            return batch_leading.contiguous()

        return batch_leading[self._packed_dense_index(device=batch_leading.device)].contiguous()

    def _packed_sizes_like(self, element_shapes: tuple[tuple[int, ...], ...]) -> tuple[int, ...]:
        varying_dims, _ = type(self)._pack_layout_from_element_shapes(element_shapes)
        return tuple(type(self)._packed_size_from_shape(shape, varying_dims) for shape in element_shapes)

    def nested_like(self, tensor: Tensor, strict: bool = True) -> Self:
        r"""
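        Create a new `NestedTensor` from a padded `Tensor`, reusing the layout of the current `NestedTensor`.
        Each element of the result is sliced from `tensor` using the corresponding element shape of `self`.
        If `tensor` is already a `NestedTensor`, a clone of it is returned.

        Args:
            tensor: The tensor to be converted to `NestedTensor`.
            strict: If `True`, raise `ValueError` unless `tensor.shape` equals the shape of the current
                `NestedTensor`. If `False`, only the batch size is checked.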

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
            >>> (nested_tensor == nested_tensor.nested_like(nested_tensor)).all()
            tensor(True)
            >>> tensor = nested_tensor.tensor
            >>> (nested_tensor == nested_tensor.nested_like(tensor)).all()
            tensor(True)
            >>> f = nested_tensor.nested_like(torch.randn(2, 2))
            Traceback (most recent call last):
            ...
            ValueError: The shape of NestedTensor and input tensor does not match, ...
            >>> p = nested_tensor.nested_like(torch.randn(2, 2), False)
            >>> p = nested_tensor.nested_like(torch.randn(3, 3), False)
            Traceback (most recent call last):
            ...
            ValueError: The batch size of NestedTensor and input tensor does not match, 2 != 3
        """

        if isinstance(tensor, NestedTensor):
            return tensor.clone()

        if strict and self.shape != tensor.shape:
            raise ValueError(
                f"The shape of NestedTensor and input tensor does not match, {self.shape} != {tensor.shape}"
            )
        batch_dim = 0 if self.dim() <= 1 or self.batch_first else 1
        if len(self) != tensor.size(batch_dim):
            raise ValueError(
                "The batch size of NestedTensor and input tensor does not match, "
                f"{len(self)} != {tensor.size(batch_dim)}"
            )
        values = self._dense_to_packed_values(tensor)
        if values is not None:
            element_shapes = self._element_shapes
            return self.__class__._from_packed(
                values,
                self._offsets,
                self._physical_shape,
                batch_first=self.batch_first,
                padding_value=self.padding_value,
                mask_value=self.mask_value,
                pin_memory=self._pin_memory,
                outer_size=self._logical_shape,
                packed_sizes=self._packed_sizes,
                element_shapes=element_shapes,
            )
        dense_tensor = tensor.to(device=self.device)
        element_shapes = self._original_shapes()
        new_storage = []
        for idx, shape in enumerate(element_shapes):
            if self.batch_first:
                slices = (idx, *[slice(0, int(dim)) for dim in shape])
            else:
                if len(shape) == 0:
                    slices = (idx,)
                else:
                    slices = (slice(0, int(shape[0])), idx, *[slice(0, int(dim)) for dim in shape[1:]])
            # .contiguous() ensures storage elements don't inherit non-trivial
            # strides from the padded tensor (e.g. after transpose).
            new_storage.append(dense_tensor[slices].contiguous())
        return self.__class__(new_storage, dtype=tensor.dtype, **self._meta(include_dtype=False))

    @property
    def occupancy(self) -> float:
        r"""
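        Occupancy of the NestedTensor: the ratio of stored elements (`numel()`) to the number of
        elements in the padded shape. Returns 0.0 for an empty batch or a zero-sized padded shape.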

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3, 4]), torch.tensor([5, 6])])
            >>> nested_tensor.occupancy
            0.75
        """
        if len(self) == 0:
            return 0.0
        denom = self.shape.numel()  # type: ignore[union-attr]
        if denom == 0:
            return 0.0
        return self.numel() / denom  # type: ignore[union-attr]

    def to_torch_nested(self) -> Tensor:
        r"""
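        Create a `torch.nested.nested_tensor` object from `self`.

        The jagged layout is used when the batch is empty or every element has at least one dimension;
        otherwise the default layout of `torch.nested.nested_tensor` is used.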

        Examples:
            >>> nested_tensor = NestedTensor([[2, 3, 5], [7, 8]])
            >>> nt = nested_tensor.to_torch_nested()
            >>> nt.layout == torch.jagged
            True
            >>> nt.values()
            tensor([2, 3, 5, 7, 8])
        """
        storage = list(self._storage)
        if not storage or all(t.dim() > 0 for t in storage):
            return nested.nested_tensor(storage, layout=torch.jagged)
        return nested.nested_tensor(storage)

    def unbind(self, dim: int = 0) -> tuple[Tensor, ...]:
        r"""
        Unbind the NestedTensor along ``dim`` by delegating to `torch.unbind`.
        """
        return torch.unbind(self, dim=dim)

    def _maybe_exact_shape_nested_like(self, tensor: object) -> Self | None:
        r"""
        Convert an exact-shape dense tensor to this NestedTensor's layout.

        This is the shared policy boundary for dense-to-nested alignment used by
        operator helpers: only non-scalar dense tensors with logical shape exactly
        matching ``self.shape`` are converted, and the conversion always uses
        ``nested_like(..., strict=False)``.
        """
        if not isinstance(tensor, Tensor) or isinstance(tensor, type(self)):
            return None
        if tensor.dim() == 0 or tensor.shape != self.shape:
            return None
        return self.nested_like(tensor, strict=False)

    # ------------------------------------------------------------------
    # Indexing
    # ------------------------------------------------------------------

    def __getitem__(self, index: int | slice | list | tuple | Tensor | NestedTensor) -> Tensor | NestedTensor:
        r"""Retrieve element(s) by index, slice, list, tuple, or tensor mask."""
        if isinstance(index, int):
            return self._storage[index]
        if isinstance(index, (slice, list)):
            if isinstance(index, list) and index and all(isinstance(i, bool) for i in index):
                if len(index) != len(self):
                    raise IndexError(f"Boolean index has length {len(index)} but batch size is {len(self)}")
                index = [i for i, flag in enumerate(index) if flag]
            storage = tuple(self._storage[index] if isinstance(index, slice) else [self._storage[i] for i in index])
            return self.__class__(storage, **self._meta(include_dtype=True))
        if isinstance(index, tuple):
            if len(index) == 0:
                return self

            # Expand Ellipsis: ``nt[..., :2]`` on a 4-D NestedTensor becomes
            # ``nt[:, :, :, :2]``.  The batch dim is consumed first, so Ellipsis
            # fills the gap between the number of explicit indices and the total
            # number of logical dimensions.
            if Ellipsis in index:
                eidx = index.index(Ellipsis)
                n_explicit = len(index) - 1  # exclude Ellipsis itself
                n_expand = self.dim() - n_explicit
                index = index[:eidx] + (slice(None),) * n_expand + index[eidx + 1 :]

            batch_index, *rest = index

            if isinstance(batch_index, (Tensor, NestedTensor)):
                return self.tensor[index]

            if isinstance(batch_index, list) and batch_index and all(isinstance(i, bool) for i in batch_index):
                if len(batch_index) != len(self):
                    raise IndexError(f"Boolean index has length {len(batch_index)} but batch size is {len(self)}")
                batch_index = [i for i, flag in enumerate(batch_index) if flag]

            if isinstance(batch_index, int):
                tensor = self._storage[batch_index]
                if rest:
                    return tensor[tuple(rest)]
                return tensor
            elif isinstance(batch_index, (slice, list)):
                if isinstance(batch_index, slice):
                    selected = self._storage[batch_index]
                else:
                    selected = tuple(self._storage[i] for i in batch_index)
                if rest:
                    rest_tuple = tuple(rest)
                    selected = tuple(t[rest_tuple] for t in selected)
                return self.__class__(selected, **self._meta(include_dtype=True))
            raise ValueError(f"Unsupported batch index type {type(batch_index)}")
        if isinstance(index, NestedTensor):
            if len(self) != len(index):
                raise ValueError(
                    "NestedTensor batch length mismatch between self and index: "
                    f"self={len(self)}, index={len(index)}"
                )
            return self.__class__(
                [t[i] for t, i in zip(self._storage, index._storage)], **self._meta(include_dtype=True)
            )
        if isinstance(index, Tensor):
            if index.dim() == 0 and index.dtype in (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8):
                return self._storage[int(index.item())]
            if index.dim() == 1:
                if index.dtype in (torch.bool, torch.uint8):
                    if index.numel() != len(self):
                        raise IndexError(f"Boolean index has length {index.numel()} but batch size is {len(self)}")
                    selected = tuple(self._storage[i] for i, flag in enumerate(index.tolist()) if bool(flag))
                    return self.__class__(selected, **self._meta(include_dtype=True))
                if index.dtype in (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8):
                    return self.__class__(
                        [self._storage[int(i)] for i in index.tolist()],
                        **self._meta(include_dtype=True),
                    )
            index = self.nested_like(index, strict=False)
            return self.__class__(
                [t[i] for t, i in zip(self._storage, index._storage)], **self._meta(include_dtype=True)
            )
        raise ValueError(f"Unsupported index type {type(index)}")

    def __setitem__(self, index: int | slice | list | tuple, value: Tensor | NestedTensor) -> None:
        r"""
        Set values in the NestedTensor at the specified index.

        Args:
            index: The index to modify. Can be an integer, slice, list, or tuple.
            value: The new value to set. Can be a Tensor or NestedTensor.

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
            >>> nested_tensor[0] = torch.tensor([6, 7, 8])
            >>> nested_tensor[0]
            tensor([6, 7, 8])
            >>> nested_tensor[1] = torch.tensor([9, 10, 11, 12])
            >>> nested_tensor.shape
            torch.Size([2, 4])
        """
        if isinstance(index, int):
            self._invalidate_transient_caches()
            if isinstance(value, NestedTensor):
                if len(value._storage) != 1:
                    raise ValueError(
                        f"When setting with an integer index, value must have a single tensor, but got {len(value)}"
                    )
                value = value._storage[0]
            if not isinstance(value, Tensor):
                value = torch.tensor(value, device=self.device, dtype=self.dtype)
            else:
                value = value.to(device=self.device, dtype=self.dtype)
            if self.requires_grad:
                value.requires_grad_(True)

            # Normalize negative index
            idx = index + len(self) if index < 0 else index
            if idx < 0 or idx >= len(self):
                raise IndexError(f"index {index} is out of range for NestedTensor with {len(self)} elements")
            expected_ndim = self._physical_shape.size(1)
            if value.dim() != expected_ndim:
                raise ValueError(
                    f"Assigned tensor ndim must match existing ndim {expected_ndim}, but got {value.dim()}"
                )

            old_start = int(self._offsets[idx].item())
            old_end = int(self._offsets[idx + 1].item())
            old_size = old_end - old_start
            new_shape_row = torch.tensor(list(value.shape), dtype=self._physical_shape.dtype)

            permutation = self._permutation
            identity_permutation = tuple(range(expected_ndim))
            varying_dims = self._varying_dims
            static_dims = self._static_dims
            packed_size = type(self)._packed_size_from_shape(tuple(int(dim) for dim in value.shape), varying_dims)
            packed_value = value if permutation == identity_permutation else value.permute(permutation)
            suffix_shape = tuple(int(value.shape[dim]) for dim in static_dims)
            new_payload = packed_value.reshape((packed_size, *suffix_shape) if suffix_shape else (packed_size,))
            new_size = packed_size

            if self._values.dim() > 1 and new_payload.shape[1:] != self._values.shape[1:]:
                storage_list = list(self._storage)
                storage_list[idx] = value
                self._repack(storage_list)
                return

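            if new_size == old_size:
                # Same packed span size: direct overwrite keeps _values allocation.
                self._values[old_start:old_end] = new_payload
                # Clone before writing so shape metadata shared with other instances
                # (e.g. those created by nested_like) is not mutated.
                self._physical_shape = self._physical_shape.clone()
                self._physical_shape[idx] = new_shape_row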
            else:
                # Different packed span size: splice _values and shift subsequent offsets.
                self._values = torch.cat([self._values[:old_start], new_payload, self._values[old_end:]], dim=0)
                delta = new_size - old_size
                self._offsets = self._offsets.clone()
                self._offsets[idx + 1 :] += delta  # noqa: E203
                self._physical_shape = self._physical_shape.clone()
                self._physical_shape[idx] = new_shape_row
            self._logical_shape = self._logical_shape_from_physical_shape(
                self._physical_shape, self._offsets, self.batch_first
            )
            if self._element_shapes is not None and self._packed_sizes is not None:
                element_shapes = list(self._element_shapes)
                element_shapes[idx] = tuple(int(dim) for dim in value.shape)
                self._element_shapes = tuple(element_shapes)
                packed_sizes = list(self._packed_sizes)
                packed_sizes[idx] = self._packed_sizes_like((self._element_shapes[idx],))[0]
                self._packed_sizes = tuple(packed_sizes)
            self._validate_metadata()
        elif isinstance(index, (slice, list)):
            if isinstance(index, list) and index and all(isinstance(i, bool) for i in index):
                if len(index) != len(self):
                    raise IndexError(f"Boolean index has length {len(index)} but batch size is {len(self)}")
                index = [i for i, flag in enumerate(index) if flag]

            if isinstance(value, Tensor) and not isinstance(value, NestedTensor):
                if value.dim() > 1 and value.size(0) > 1:
                    value = self.__class__(value.unbind(0), **self._meta())
                else:
                    value = self.__class__([value], **self._meta())

            if isinstance(index, slice):
                start, stop, step = index.indices(len(self))
                indices = range(start, stop, step)
            else:
                indices = index  # type: ignore[assignment]

            if len(indices) != len(value._storage):
                raise ValueError(
                    f"Size mismatch: tried to assign {len(value._storage)} values to {len(indices)} indices"
                )

            storage_list = list(self._storage)
            for i, idx in enumerate(indices):
                storage_list[idx] = value._storage[i]
            self._storage = tuple(storage_list)
        elif isinstance(index, tuple):
            if len(index) == 0:
                return
            if len(index) == 1:
                self[index[0]] = value
                return

            first_idx, rest_idx = index[0], index[1:]
            batch_indices: list[int]
            if isinstance(first_idx, int):
                batch_indices = [first_idx]
            elif isinstance(first_idx, (slice, list)):
                if isinstance(first_idx, list) and first_idx and all(isinstance(i, bool) for i in first_idx):
                    if len(first_idx) != len(self):
                        raise IndexError(f"Boolean index has length {len(first_idx)} but batch size is {len(self)}")
                    batch_indices = [i for i, flag in enumerate(first_idx) if flag]
                elif isinstance(first_idx, slice):
                    start, stop, step = first_idx.indices(len(self))
                    batch_indices = list(range(start, stop, step))
                else:
                    batch_indices = list(first_idx)  # type: ignore[arg-type]
            else:
                raise ValueError(f"Unsupported first index type {type(first_idx)}")

            if isinstance(value, NestedTensor):
                if len(batch_indices) != len(value._storage):
                    raise ValueError(
                        f"Size mismatch: tried to assign {len(value._storage)} values to {len(batch_indices)} indices"
                    )
                assigned_values = list(value._storage)
            else:
                assigned_values = [value] * len(batch_indices)

            elems = list(self._storage)
            for position, idx in enumerate(batch_indices):
                elem = elems[idx].clone()
                elem[rest_idx] = assigned_values[position]
                elems[idx] = elem
            self._storage = tuple(elems)
        else:
            raise ValueError(f"Unsupported index type {type(index)}")

    # ------------------------------------------------------------------
    # Properties: runtime config, dtype, device, requires_grad
    # ------------------------------------------------------------------

    @property
    def batch_first(self) -> bool:
        r"""Whether the logical outer shape uses ``(B, N, ...)`` instead of ``(N, B, ...)``."""
        return self._batch_first

    @batch_first.setter
    def batch_first(self, value: bool):
        new_value = type(self)._coerce_batch_first(value)
        old_value = getattr(self, "_batch_first", None)
        self._batch_first = new_value
        if old_value is None or old_value == new_value:
            return
        if hasattr(self, "_physical_shape") and hasattr(self, "_offsets") and hasattr(self, "_logical_shape"):
            self._logical_shape = type(self)._logical_shape_from_physical_shape(
                self._physical_shape,
                self._offsets,
                new_value,
            )
        if hasattr(self, "_cached_tensor_view"):
            self._invalidate_transient_caches()

    @property
    def padding_value(self) -> float:
        r"""Padding fill value used when materializing dense views."""
        return self._padding_value

    @padding_value.setter
    def padding_value(self, value: SupportsFloat):
        new_value = type(self)._coerce_padding_value(value)
        old_value = getattr(self, "_padding_value", None)
        self._padding_value = new_value
        if old_value is None or old_value == new_value:
            return
        if hasattr(self, "_cached_tensor_view"):
            self._cached_tensor_view = None

    @property
    def mask_value(self) -> bool:
        r"""Boolean value used to denote padding positions in generated masks."""
        return self._mask_value

    @mask_value.setter
    def mask_value(self, value: bool):
        new_value = type(self)._coerce_mask_value(value)
        old_value = getattr(self, "_mask_value", None)
        self._mask_value = new_value
        if old_value is None or old_value == new_value:
            return
        if hasattr(self, "_cached_mask_view"):
            self._cached_mask_view = None

    @property
    def dtype(self) -> torch.dtype:  # type: ignore[override]
        r"""Data type of the underlying tensor elements."""
        return self._values.dtype

    @dtype.setter
    def dtype(self, value: torch.dtype | None):
        r"""`dtype` is read-only; use `.to(dtype=...)` to convert."""
        raise AttributeError("NestedTensor.dtype is read-only; use .to(dtype=...) to create a converted tensor.")

    @property
    def device(self) -> torch.device:  # type: ignore[override]
        r"""Device on which the underlying tensor data resides."""
        return self._values.device

    @device.setter
    def device(self, value: torch.device | None):
        r"""`device` is read-only; use `.to(device=...)` to move tensors."""
        raise AttributeError("NestedTensor.device is read-only; use .to(device=...) to create a moved tensor.")

    @property
    def requires_grad(self) -> bool:  # type: ignore[override]
        r"""Whether gradient computation is enabled for this tensor."""
        return self._values.requires_grad

    @requires_grad.setter
    def requires_grad(self, value: bool):
        r"""Enable or disable gradient computation for this tensor."""
        self._values.requires_grad_(value)

    # ------------------------------------------------------------------
    # State management
    # ------------------------------------------------------------------

    def _meta(self, *, include_dtype: bool | None = None) -> Mapping:
        r"""Metadata used for structure-preserving reconstruction."""
        if include_dtype is None:
            # Empty reconstructions cannot infer dtype from storage; include it by default.
            include_dtype = self._values.numel() == 0
        meta = {
            "batch_first": self.batch_first,
            "padding_value": self.padding_value,
            "mask_value": self.mask_value,
            "pin_memory": self._pin_memory,
            "device": self._values.device,
        }
        if include_dtype:
            meta["dtype"] = self.dtype
        return meta

    def __getstate__(self) -> dict:
        return {
            "_state_version": self._SERIALIZATION_VERSION,
            "_values": self._values,
            "_offsets": self._offsets,
            "_permutation": self._permutation,
            "_physical_shape": self._physical_shape,
            "_logical_shape": self._logical_shape,
            "batch_first": self.batch_first,
            "padding_value": self.padding_value,
            "mask_value": self.mask_value,
            "_pin_memory": self._pin_memory,
            "_packed_sizes": self._packed_sizes,
            "_element_shapes": self._element_shapes,
        }

    def __setstate__(self, state: Mapping) -> None:
        type(self)._validate_serialized_state(state)
        self._values = state["_values"]
        self._offsets = state["_offsets"].cpu()
        self._permutation = tuple(int(dim) for dim in state["_permutation"])
        self._physical_shape = state["_physical_shape"].cpu()
        self._logical_shape = state["_logical_shape"]
        self._set_runtime_config(
            batch_first=state["batch_first"],
            padding_value=state["padding_value"],
            mask_value=state["mask_value"],
        )
        self._pin_memory = bool(state["_pin_memory"] and self._values.device.type == "cpu" and self._values.is_pinned())
        self._packed_sizes = state["_packed_sizes"]
        self._element_shapes = state["_element_shapes"]
        # Serialized state intentionally excludes transient caches.
        self._invalidate_transient_caches()
        self._validate_metadata()

    def __reduce__(self):
        return (self.__class__._from_state, (self.__getstate__(),))

    @classmethod
    def _from_state(cls, state: dict) -> Self:
        cls._validate_serialized_state(state)
        return cls._from_packed(
            state["_values"],
            state["_offsets"].cpu(),
            state["_physical_shape"].cpu(),
            permutation=tuple(int(dim) for dim in state["_permutation"]),
            batch_first=state["batch_first"],
            padding_value=state["padding_value"],
            mask_value=state["mask_value"],
            pin_memory=state["_pin_memory"],
            outer_size=state["_logical_shape"],
            packed_sizes=state["_packed_sizes"],
            element_shapes=state["_element_shapes"],
        )

    def __copy__(self):
        r"""Shallow copy: new NestedTensor sharing underlying tensor data."""
        return self.__class__._from_packed(
            self._values,
            self._offsets,
            self._physical_shape,
            permutation=self._permutation,
            batch_first=self.batch_first,
            padding_value=self.padding_value,
            mask_value=self.mask_value,
            pin_memory=self._pin_memory,
            outer_size=self._logical_shape,
            packed_sizes=self._packed_sizes,
            element_shapes=self._element_shapes,
        )

    def __deepcopy__(self, memo):
        r"""Deep copy: clones all tensor data."""
        result = self.__class__._from_packed(
            self._values.clone(),
            self._offsets.clone(),
            self._physical_shape.clone(),
            permutation=self._permutation,
            batch_first=self.batch_first,
            padding_value=self.padding_value,
            mask_value=self.mask_value,
            pin_memory=self._pin_memory,
            outer_size=self._logical_shape,
            packed_sizes=self._packed_sizes,
            element_shapes=self._element_shapes,
        )
        memo[id(self)] = result
        return result

    # ------------------------------------------------------------------
    # Tensor-like methods
    # ------------------------------------------------------------------

    def all(self, dim: int | None = None, keepdim: bool = False) -> bool | Tensor | NestedTensor:
        r"""
        Tests if all elements in NestedTensor evaluate to True.

        Examples:
            >>> nested_tensor = NestedTensor([torch.ones(2, 4, dtype=torch.bool), torch.ones(3, 5, dtype=torch.bool)])
            >>> nested_tensor.all()
            tensor(True)
            >>> nested_tensor.all(dim=0)
            tensor([True, True])
            >>> nested_tensor.all(dim=0, keepdim=True)
            tensor([[True, True]])
            >>> nested_tensor.all(dim=1)
            NestedTensor([
                [True, True, True, True],
                [True, True, True, True, True]
            ])
            >>> nested_tensor.all(dim=1, keepdim=True)
            NestedTensor([
                [[True, True, True, True]],
                [[True, True, True, True, True]]
            ])
            >>> nested_tensor.batch_first = False
            >>> nested_tensor.all(dim=1)
            tensor([True, True])
            >>> nested_tensor.all(dim=0)
            NestedTensor([
                [True, True, True, True],
                [True, True, True, True, True]
            ])
            >>> nested_tensor.all(dim=-2)
            tensor([True, True])
        """
        return torch.all(self, dim=dim, keepdim=keepdim)

    def any(self, dim: int | None = None, keepdim: bool = False) -> bool | Tensor | NestedTensor:
        r"""
        Tests if any elements in NestedTensor evaluate to True.

        Examples:
            >>> nested_tensor = NestedTensor([torch.zeros(2, dtype=torch.bool), torch.ones(3, dtype=torch.bool)])
            >>> nested_tensor.any()
            tensor(True)
            >>> nested_tensor.any(dim=0)
            tensor([False,  True])
        """
        return torch.any(self, dim=dim, keepdim=keepdim)

    def dim(self) -> int:
        r"""
        Number of dimensions of the NestedTensor.

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
            >>> nested_tensor.dim()
            2
        """
        if not hasattr(self, "_logical_shape"):
            with torch._C.DisableTorchFunctionSubclass():
                return len(torch.Tensor.size(self))
        return len(self._logical_shape)

    def max(self, dim: int | None = None, keepdim: bool = False) -> Tensor | NestedTensor:
        r"""Return the maximum value, optionally along a given dimension."""
        if dim is None:
            return torch.max(self)
        return torch.max(self, dim=dim, keepdim=keepdim)

    def mean(
        self,
        dim: int | None = None,
        keepdim: bool = False,
        *,
        dtype: torch.dtype | None = None,  # type: ignore[name-defined]
    ) -> Tensor | NestedTensor:
        r"""Return the mean value, optionally along a given dimension."""
        return torch.mean(self, dim=dim, keepdim=keepdim, dtype=dtype)

    def min(self, dim: int | None = None, keepdim: bool = False) -> Tensor | NestedTensor:
        r"""Return the minimum value, optionally along a given dimension."""
        if dim is None:
            return torch.min(self)
        return torch.min(self, dim=dim, keepdim=keepdim)

    @property
    def mT(self) -> Self:  # type: ignore[override]
        r"""Matrix transpose over the last two per-element dimensions."""
        ndims = self.dim()
        batch_dim = 0 if self.batch_first else 1
        elem_dims = [d for d in range(ndims) if d != batch_dim]
        if len(elem_dims) < 2:
            raise RuntimeError(
                f"tensor.mT is only supported on matrices or batches of matrices. Got {len(elem_dims)}-D tensor."
            )
        return torch.transpose(self, elem_dims[-2], elem_dims[-1])

    @property
    def ndim(self) -> int:
        r"""
        Alias for `dim()`.
        """
        return self.dim()

    def numel(self) -> int:
        r"""
        Number of elements in the NestedTensor.

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
            >>> nested_tensor.numel()
            5
        """
        return self._values.numel()

    def permute(self, *dims) -> Self:
        r"""
        Apply permutation to each tensor in the NestedTensor.

        Args:
            *dims: The desired ordering of dimensions for the NestedTensor (including batch dimension).

        Returns:
            NestedTensor: A new NestedTensor with each tensor permuted.

        Examples:
            >>> nested_tensor = NestedTensor([torch.randn(3, 4, 5), torch.randn(2, 4, 5)])
            >>> permuted = nested_tensor.permute(0, 3, 1, 2)
            >>> permuted.shape
            torch.Size([2, 5, 3, 4])
        """
        return torch.permute(self, dims)

    def moveaxis(self, source, destination) -> Self:
        r"""Move per-element dimensions to new positions."""
        return torch.moveaxis(self, source, destination)

    def movedim(self, source, destination) -> Self:
        r"""Alias for `moveaxis()`."""
        return torch.movedim(self, source, destination)

    # to(), clone(), detach(), contiguous(), half(), float(), double(), etc.
    # are all handled by aten dispatch in aten_functions.py (aten._to_copy, aten.clone,
    # aten.detach). No custom Python methods needed.

    def pin_memory(self) -> Self:
        r"""Pin the underlying tensor memory for faster host-to-device transfer."""
        return type(self)._from_packed(
            self._values.pin_memory(),
            self._offsets,
            self._physical_shape,
            batch_first=self.batch_first,
            padding_value=self.padding_value,
            mask_value=self.mask_value,
            pin_memory=True,
            outer_size=self._logical_shape,
            packed_sizes=self._packed_sizes,
            element_shapes=self._element_shapes,
        )

    def prod(
        self,
        dim: int | None = None,
        keepdim: bool = False,
        *,
        dtype: torch.dtype | None = None,  # type: ignore[name-defined]
    ) -> Tensor | NestedTensor:
        r"""Return the product of elements, optionally along a given dimension."""
        return torch.prod(self, dim=dim, keepdim=keepdim, dtype=dtype)

    def requires_grad_(self, requires_grad: bool = True):
        r"""Enable or disable gradient computation in-place."""
        self.requires_grad = requires_grad
        return self

    def reshape(self, *shape) -> Self:
        r"""
        Reshape each tensor in the NestedTensor.

        Args:
            *shape: The desired size of each dimension for the underlying tensors.

        Returns:
            NestedTensor: A new NestedTensor with each tensor reshaped.

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6], [7, 8]])])
            >>> reshaped = nested_tensor.reshape(4)
            >>> reshaped.shape
            torch.Size([2, 4])
        """
        if not shape:
            raise TypeError("reshape() missing shape")
        target_shape = shape[0] if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)) else shape
        return torch.reshape(self, target_shape)

    def flatten(self, start_dim: int = 0, end_dim: int = -1):
        r"""Flatten each tensor in the NestedTensor."""
        return torch.flatten(self, start_dim=start_dim, end_dim=end_dim)

    def flip(self, dims) -> Self:
        r"""Flip each tensor in the NestedTensor along the given dimensions."""
        return torch.flip(self, dims)

    @property
    def shape(self) -> torch.Size:  # type: ignore[override, name-defined]
        r"""
        Alias for `size()`.
        """
        return self.size()

    def size(self, dim: int | None = None) -> torch.Size | int:  # type: ignore[override, name-defined]
        r"""
        Return the size of this `NestedTensor`.

        Args:
            dim: If not specified, the returned value is a `torch.Size`, a subclass of `tuple`.
                If specified, returns an `int` holding the size of that dimension.
                Defaults to `None`.

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
            >>> nested_tensor.size()
            torch.Size([2, 3])
            >>> nested_tensor.size(0)
            2
            >>> nested_tensor[1] = torch.tensor([4, 5, 6, 7])
            >>> nested_tensor.shape
            torch.Size([2, 4])
            >>> nested_tensor.size(1)
            4
        """
        if hasattr(self, "_logical_shape"):
            full_size = self._logical_shape
        else:
            with torch._C.DisableTorchFunctionSubclass():
                full_size = torch.Tensor.size(self)
        if dim is not None:
            dim = dim + len(full_size) if dim < 0 else dim
            return full_size[dim]
        return full_size

    def sum(
        self,
        dim: int | Sequence[int] | None = None,
        keepdim: bool = False,
        *,
        dtype: torch.dtype | None = None,  # type: ignore[name-defined]
    ) -> Tensor | NestedTensor:
        r"""
        Returns the sum of each tensor over the given dimension(s).

        Args:
            dim: The dimension or dimensions to reduce. If None, sum over all dimensions.
                Supports int, Sequence[int], or None. Negative dimensions are supported.
            keepdim: Whether to retain reduced dimensions with size 1.
            dtype: The desired data type of returned tensor.

        Returns:
            Tensor or NestedTensor depending on the dimensions being reduced.

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
            >>> nested_tensor.sum()
            tensor(15)
            >>> nested_tensor.sum(dim=0)  # when dim=0, sum across batch dimension
            tensor([6, 9])
            >>> nested_tensor.sum(dim=1)
            tensor([6, 9])
            >>> nested_tensor.sum(dim=[0, 1])
            tensor(15)
            >>> nested_tensor.sum(dim=0, keepdim=True)
            tensor([[6, 9]])
            >>> nested_tensor.sum(dtype=torch.float32)
            tensor(15.)
        """
        return torch.sum(self, dim=dim, keepdim=keepdim, dtype=dtype)

    @property
    def T(self) -> Self:  # type: ignore[override]
        r"""Transpose: reverse per-element dims while keeping batch dim fixed."""
        ndims = self.dim()
        if ndims <= 1:
            return self
        batch_dim = 0 if self.batch_first else 1
        elem_dims = [d for d in range(ndims) if d != batch_dim]
        order = list(reversed(elem_dims))
        order.insert(batch_dim, batch_dim)
        return torch.permute(self, tuple(order))

    def tolist(self) -> list:
        r"""
        Convert a NestedTensor to a list of lists of values.

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
            >>> nested_tensor.tolist()
            [[1, 2, 3], [4, 5]]
        """
        return [t.tolist() for t in self._storage]

    def transpose(self, dim0: int, dim1: int) -> Self:  # type: ignore[valid-type]
        r"""
        Transpose dimensions dim0 and dim1 for each tensor in the NestedTensor.

        Args:
            dim0: First dimension to transpose (in NestedTensor coordinate system).
            dim1: Second dimension to transpose (in NestedTensor coordinate system).

        Returns:
            NestedTensor: A new NestedTensor with each tensor transposed.

        Examples:
            >>> nested_tensor = NestedTensor([torch.randn(3, 4), torch.randn(2, 4)])
            >>> # NestedTensor shape is [2, 3, 4], underlying tensors are [3, 4] and [2, 4]
            >>> transposed = nested_tensor.transpose(1, 2)  # transpose dims 1 and 2
            >>> transposed.shape  # batch dimension is still first
            torch.Size([2, 4, 3])
        """
        return torch.transpose(self, dim0, dim1)

    def swapaxes(self, axis0: int, axis1: int) -> Self:
        r"""Alias for `transpose()`."""
        return torch.swapaxes(self, axis0, axis1)

    def swapdims(self, dim0: int, dim1: int) -> Self:
        r"""Alias for `swapaxes()`."""
        return torch.swapdims(self, dim0, dim1)

    def squeeze(self, dim: int | None = None) -> Self:  # type: ignore[valid-type]
        r"""Squeeze singleton dimensions from each tensor in the NestedTensor."""
        return torch.squeeze(self, dim=dim)

    def unsqueeze(self, dim: int) -> Self:  # type: ignore[valid-type]
        r"""
        Unsqueeze each tensor in the NestedTensor by adding a singleton dimension at the specified position.

        Args:
            dim: The dimension at which to add the singleton dimension. This is in the NestedTensor's
                coordinate system (where dim 0 is the batch dimension).

        Returns:
            NestedTensor: A new NestedTensor with each tensor unsqueezed at the specified dimension.

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
            >>> # Original shape: [2, 3] (batch_size=2, max_seq_len=3)
            >>> unsqueezed = nested_tensor.unsqueeze(1)
            >>> unsqueezed.shape
            torch.Size([2, 1, 3])
            >>> # Now each underlying tensor has shape [1, seq_len] instead of [seq_len]

            >>> nested_tensor_2d = NestedTensor([torch.randn(3, 4), torch.randn(2, 4)])
            >>> # Original shape: [2, 3, 4] (batch_size=2, max_len1=3, max_len2=4)
            >>> unsqueezed_2d = nested_tensor_2d.unsqueeze(2)
            >>> unsqueezed_2d.shape
            torch.Size([2, 3, 1, 4])
            >>> # Now each underlying tensor has shape [len1, 1, len2] instead of [len1, len2]
        """
        return torch.unsqueeze(self, dim)

    def unflatten(self, dim: int, sizes) -> Self:  # type: ignore[valid-type]
        r"""Unflatten one dimension of each tensor in the NestedTensor."""
        return torch.unflatten(self, dim, sizes)

    def roll(self, shifts, dims=None) -> Self:
        r"""Roll each tensor in the NestedTensor along the given dimensions."""
        return torch.roll(self, shifts, dims=dims)

    def rot90(self, k: int = 1, dims: Sequence[int] = (0, 1)) -> Self:
        r"""Rotate each tensor in the NestedTensor by 90 degrees in the given plane."""
        return torch.rot90(self, k, dims)

    def view(self, *shape) -> Self:
        r"""
        View each tensor in the NestedTensor with a different shape.

        Args:
            *shape: The desired size of each dimension for the underlying tensors.

        Returns:
            NestedTensor: A new NestedTensor with each tensor viewed with the new shape.

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6], [7, 8]])])
            >>> viewed = nested_tensor.view(4)  # View each 2x2 tensor as 4
            >>> viewed.shape
            torch.Size([2, 4])
            >>> type(viewed).__name__
            'NestedTensor'
        """
        if not shape:
            raise TypeError("view() missing shape")
        target_shape = shape[0] if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)) else shape
        return NestedTensorAtenRegistry[torch.ops.aten.view.default](
            torch.ops.aten.view.default, (self, list(target_shape)), {}
        )

    def _view_shapes(self, shape) -> list[tuple[int, ...]]:  # type: ignore[valid-type]
        r"""
        Compute per-element view shapes, adjusting ragged dimensions.

        Batch-dim detection rules:
        1. If ``shape[batch_dim]`` does not match the batch size, batch dim is NOT included.
        2. If ``len(shape) != self.dim()``, batch dim IS included (unambiguous).
        3. If ``len(shape) == self.dim()`` (ambiguous), batch dim is included only when
           at least one other dimension matches max_sizes or is -1.

        For ragged dimensions, each target dimension that matches the corresponding
        max size is substituted with the element's actual size. When a target dimension
        matches a max size at a different position (e.g. after inserting a dim), a
        single-candidate search resolves the mapping.
        """
        if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
            shape = tuple(shape[0])

        batch_dim = 0 if self.batch_first else 1
        batch_size = len(self)

        # Step 1: Determine if batch dim is in the target shape
        include_batch = False
        if len(shape) > batch_dim:
            if shape[batch_dim] == batch_size and len(shape) != self.dim():
                include_batch = True
            elif shape[batch_dim] in (-1, batch_size) and len(shape) == self.dim():
                # Ambiguous: same dim count → confirm via dimension matching
                max_sizes = list(self.size())  # type: ignore[arg-type]
                if max_sizes:
                    max_sizes.pop(batch_dim)
                non_batch = [i for i in range(len(shape)) if i != batch_dim]
                include_batch = any(
                    j < len(max_sizes) and (shape[d] == -1 or shape[d] == max_sizes[j]) for j, d in enumerate(non_batch)
                )

        # Step 2: Strip batch dim from target shape
        target = list(shape)
        if include_batch:
            if target[batch_dim] == -1:
                target[batch_dim] = batch_size
            if target[batch_dim] != batch_size:
                raise ValueError(f"Batch dimension mismatch: expected {batch_size} but got {target[batch_dim]}")
            target.pop(batch_dim)

        # Step 3: Per-element shape adjustment (ragged dim substitution)
        max_sizes = list(self.size())  # type: ignore[arg-type]
        if max_sizes:
            max_sizes.pop(batch_dim)

        element_shapes = self._element_shapes
        if element_shapes is None:
            element_shapes = tuple(tuple(shape) for shape in self._original_shapes())

        view_shapes = []
        for element_shape in element_shapes:
            adjusted = list(target)
            available = list(range(len(max_sizes)))
            for i in range(min(len(adjusted), len(max_sizes))):
                if adjusted[i] == -1:
                    continue
                # Direct match: same position in max_sizes
                if adjusted[i] == max_sizes[i]:
                    adjusted[i] = element_shape[i]
                    if i in available:
                        available.remove(i)
                    continue
                # Indirect match: search remaining positions for unique candidate
                candidates = [j for j in available if max_sizes[j] == adjusted[i]]
                if len(candidates) == 1:
                    j = candidates[0]
                    adjusted[i] = element_shape[j]
                    available.remove(j)
            if adjusted.count(-1) == 1:
                missing = adjusted.index(-1)
                known = 1
                for dim in adjusted:
                    if dim != -1:
                        known *= dim
                element_numel = type(self)._shape_numel(element_shape)
                if known != 0 and element_numel % known == 0:
                    adjusted[missing] = element_numel // known
            view_shapes.append(tuple(adjusted))
        return view_shapes

    def where(self, condition: Tensor | NestedTensor, other: Tensor | NestedTensor | SupportsFloat) -> Self:
        r"""
        Return a NestedTensor of elements selected from either self or other, depending on condition.

        Examples:
            >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
            >>> nested_tensor.where(nested_tensor > 2, torch.tensor([[6, 5, 4], [3, 2, 1]]))
            NestedTensor([
                [6, 5, 3],
                [4, 5]
            ])
            >>> nested_tensor.where(nested_tensor > 2, NestedTensor([[6, 5, 4], [3, 2]]))
            NestedTensor([
                [6, 5, 3],
                [4, 5]
            ])
            >>> nested_tensor.where(torch.tensor(True), NestedTensor([[6, 5, 4], [3, 2]]))
            NestedTensor([
                [1, 2, 3],
                [4, 5]
            ])
        """
        return torch.where(condition, self, other)
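
The ragged-dimension substitution described in the `_view_shapes` docstring can be illustrated outside the class. The following is a simplified sketch, not the library's implementation. The function name `resolve_view_shape` is hypothetical. It covers only the direct-match rule (a target dimension equal to the maximum size at the same position) and the single `-1` inference, and it assumes the batch dimension has already been stripped from the target shape.

```python
from math import prod


def resolve_view_shape(target, max_sizes, element_shape):
    """Hypothetical helper: adapt a padded-shape target to one element's real shape."""
    adjusted = list(target)
    for i in range(min(len(adjusted), len(max_sizes))):
        # Direct match: the target uses the padded (max) size at this position,
        # so substitute the element's actual size.
        if adjusted[i] != -1 and adjusted[i] == max_sizes[i]:
            adjusted[i] = element_shape[i]
    if adjusted.count(-1) == 1:
        # Infer the single unknown dimension from the element's total size.
        known = prod(d for d in adjusted if d != -1)
        numel = prod(element_shape)
        if known != 0 and numel % known == 0:
            adjusted[adjusted.index(-1)] = numel // known
    return tuple(adjusted)


# Elements of shape (3, 4) and (2, 4) give padded per-element sizes (3, 4).
max_sizes = (3, 4)
long_view = resolve_view_shape((3, 2, 2), max_sizes, (3, 4))
short_view = resolve_view_shape((3, 2, 2), max_sizes, (2, 4))
inferred = resolve_view_shape((3, -1), max_sizes, (2, 4))
```

Here the shorter element receives `(2, 2, 2)` rather than `(3, 2, 2)`, so its element count (8) is preserved. The indirect-match search in the real method is intentionally left out of this sketch.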

tensor_mask property

Python
tensor_mask: tuple[Tensor, Tensor]

Return a tuple of padded tensor and mask tensor.

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
>>> nested_tensor.tensor_mask
(tensor([[1, 2, 3],
        [4, 5, 0]]), tensor([[ True,  True,  True],
        [ True,  True, False]]))

tensor property

Python
tensor: Tensor

Return a single tensor by padding all the tensors.

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
>>> nested_tensor.tensor
tensor([[1, 2, 3],
        [4, 5, 0]])

mask property

Python
mask: Tensor

Padding mask of tensor.

mask_value controls which boolean value denotes padding in this mask. With the default mask_value=False, True means valid data.

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
>>> nested_tensor.mask
tensor([[ True,  True,  True],
        [ True,  True, False]])
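
To make the relationship between `padding_value`, `mask_value`, and the two dense views concrete, here is a minimal pure-Python sketch for 1-D sequences. The helper `pad_and_mask` is hypothetical and uses plain lists instead of tensors. It only mirrors the semantics documented above and is not how the library materializes these views.

```python
def pad_and_mask(sequences, padding_value=0, mask_value=False):
    """Hypothetical helper: pad ragged lists and build the matching boolean mask."""
    max_len = max((len(seq) for seq in sequences), default=0)
    padded, mask = [], []
    for seq in sequences:
        pad = max_len - len(seq)
        padded.append(list(seq) + [padding_value] * pad)
        # Padding positions receive mask_value; valid positions receive its negation.
        mask.append([not mask_value] * len(seq) + [mask_value] * pad)
    return padded, mask


padded, mask = pad_and_mask([[1, 2, 3], [4, 5]])
_, inverted = pad_and_mask([[1, 2, 3], [4, 5]], mask_value=True)
```

With the default `mask_value=False`, `mask` marks valid data with `True`, matching the example above. With `mask_value=True`, the same positions are flipped so that `True` marks padding.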

concat property

Python
concat: Tensor

Flatten elements and concatenate along the ragged dimension (no padding).

This is particularly useful when calculating a loss or passing data through a Linear layer, because no computation is spent on padding.

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.randn(9, 8), torch.randn(11, 8)])
>>> nested_tensor.concat.shape
torch.Size([20, 8])
>>> nested_tensor = NestedTensor([torch.randn(9, 9, 8), torch.randn(11, 11, 8)])
>>> nested_tensor.concat.shape
torch.Size([202, 8])
>>> nested_tensor = NestedTensor([torch.randn(9, 9, 8, 6), torch.randn(11, 11, 8, 6)])
>>> nested_tensor.concat.shape
torch.Size([202, 8, 6])
>>> nested_tensor = NestedTensor([torch.randn(9, 9, 8, 7), torch.randn(11, 11, 8, 6)])
>>> nested_tensor.concat.shape
torch.Size([1293, 8])
>>> nested_tensor = NestedTensor([torch.randn(1, 9, 9, 5), torch.randn(1, 11, 11, 5)])
>>> nested_tensor.concat.shape
torch.Size([202, 1, 5])
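
The shapes in these examples are consistent with a simple rule: every dimension whose size differs between elements is flattened into the leading dimension, and the remaining dimensions are kept. The sketch below encodes that rule. It is inferred from the examples only, the function name `concat_shape` is hypothetical, and it assumes all elements have the same number of dimensions and at least one ragged dimension.

```python
from math import prod


def concat_shape(shapes):
    """Hypothetical helper: predict the packed shape from per-element shapes."""
    ndim = len(shapes[0])
    # A dimension is treated as ragged if its size differs between elements.
    ragged = [d for d in range(ndim) if len({s[d] for s in shapes}) > 1]
    static = [shapes[0][d] for d in range(ndim) if d not in ragged]
    # Each element contributes the product of its ragged dimensions as rows.
    rows = sum(prod(s[d] for d in ragged) for s in shapes)
    return (rows, *static)


two_dim = concat_shape([(9, 8), (11, 8)])
mixed = concat_shape([(9, 9, 8, 7), (11, 11, 8, 6)])
leading_static = concat_shape([(1, 9, 9, 5), (1, 11, 11, 5)])
```

For the fourth example, the arithmetic is 9 · 9 · 7 + 11 · 11 · 6 = 567 + 726 = 1293 rows, with the shared dimension of size 8 kept.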

occupancy property

Python
occupancy: float

Occupancy of the NestedTensor: the fraction of positions in the padded tensor that hold real (non-padding) elements.

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3, 4]), torch.tensor([5, 6])])
>>> nested_tensor.occupancy
0.75
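
The value in this example can be reproduced by hand: 6 real elements occupy a padded 2 × 4 grid, so 6 / 8 = 0.75. A minimal sketch for 1-D sequences follows. The helper `occupancy_1d` is hypothetical, and the definition (real elements divided by padded positions) is an assumption that matches the example rather than a statement about the library's code.

```python
def occupancy_1d(lengths):
    """Hypothetical helper: real elements divided by padded positions (1-D case)."""
    if not lengths or max(lengths) == 0:
        return 0.0  # assumption: treat an empty batch as zero occupancy
    return sum(lengths) / (len(lengths) * max(lengths))


ratio = occupancy_1d([4, 2])
```

A value close to 1.0 means little padding would be wasted in the dense view; a low value suggests that working on the packed representation saves more computation.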

batch_first property writable

Python
batch_first: bool

Whether the logical outer shape uses (B, ...) instead of (..., B, ...).

padding_value property writable

Python
padding_value: float

Padding fill value used when materializing dense views.

mask_value property writable

Python
mask_value: bool

Boolean value used to denote padding positions in generated masks.

dtype property writable

Python
dtype: dtype

Data type of the underlying tensor elements. Assignment raises AttributeError; use .to(dtype=...) to create a converted tensor.

device property writable

Python
device: device

Device on which the underlying tensor data resides. Assignment raises AttributeError; use .to(device=...) to create a moved tensor.

requires_grad property writable

Python
requires_grad: bool

Whether gradient computation is enabled for this tensor.

mT property

Python
mT: Self

Matrix transpose over the last two per-element dimensions.

ndim property

Python
ndim: int

Alias for dim().

shape property

Python
shape: Size

Alias for size().

T property

Python
T: Self

Transpose: reverse per-element dimensions while keeping the batch dimension fixed.

concatenate

Python
concatenate() -> tuple[Tensor, tuple[Size, ...]]

Concatenate tensors along the ragged (padded) dimension and return the structural information needed for reconstruction.

Returns:

Type Description
tuple[Tensor, tuple[Size, ...]]

A tuple containing:

  • concat_tensor: The concatenated tensor (same as .concat property)
  • shapes: Tuple of original tensor shapes for reconstruction

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.randn(9, 8), torch.randn(11, 8)])
>>> concat_tensor, shapes = nested_tensor.concatenate()
>>> concat_tensor.shape
torch.Size([20, 8])
>>> shapes
(torch.Size([9, 8]), torch.Size([11, 8]))
>>> reconstructed = NestedTensor.from_concatenated(concat_tensor, shapes)
>>> torch.equal(nested_tensor.tensor, reconstructed.tensor)
True
Source code in danling/tensors/nested_tensor.py
Python
def concatenate(self) -> tuple[Tensor, tuple[torch.Size, ...]]:
    r"""
    Concatenate tensors along the ragged (padded) dimension and return the structural information needed for reconstruction.

    Returns:
        A tuple containing:
        - concat_tensor: The concatenated tensor (same as .concat property)
        - shapes: Tuple of original tensor shapes for reconstruction

    Examples:
        >>> nested_tensor = NestedTensor([torch.randn(9, 8), torch.randn(11, 8)])
        >>> concat_tensor, shapes = nested_tensor.concatenate()
        >>> concat_tensor.shape
        torch.Size([20, 8])
        >>> shapes
        (torch.Size([9, 8]), torch.Size([11, 8]))
        >>> reconstructed = NestedTensor.from_concatenated(concat_tensor, shapes)
        >>> torch.equal(nested_tensor.tensor, reconstructed.tensor)
        True
    """
    batch_size = len(self._offsets) - 1
    if batch_size == 0:
        return torch.empty(0, dtype=self._values.dtype, device=self.device), ()
    return self._values, self._original_shapes()
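
The round trip between a packed tensor and its per-element shapes can be sketched with nested lists. Both helper names (`concatenate_rows`, `split_rows`) are hypothetical, the elements are restricted to 2-D, and the splitting logic is an assumption about what reconstruction requires, not a copy of `from_concatenated`.

```python
def concatenate_rows(elements):
    """Hypothetical helper: stack rows of 2-D elements and record each shape."""
    rows = [row for elem in elements for row in elem]
    shapes = tuple((len(elem), len(elem[0]) if elem else 0) for elem in elements)
    return rows, shapes


def split_rows(rows, shapes):
    """Hypothetical helper: rebuild the elements from packed rows and shapes."""
    elements, start = [], 0
    for n_rows, _ in shapes:
        elements.append(rows[start:start + n_rows])
        start += n_rows
    return elements


original = [[[1, 2], [3, 4], [5, 6]], [[7, 8]]]
rows, shapes = concatenate_rows(original)
rebuilt = split_rows(rows, shapes)
```

The recorded shapes are what make the packed data reversible: without them, the row boundaries between elements would be lost.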

__len__

Python
__len__() -> int

Return the number of tensors in the batch.

Source code in danling/tensors/nested_tensor.py
Python
def __len__(self) -> int:
    r"""Return the number of tensors in the batch."""
    if not hasattr(self, "_offsets"):
        with torch._C.DisableTorchFunctionSubclass():
            full_size = torch.Tensor.size(self)
        if len(full_size) == 0:
            return 0
        batch_dim = 0 if getattr(self, "batch_first", True) else (1 if len(full_size) > 1 else 0)
        return int(full_size[batch_dim])
    return len(self._offsets) - 1
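
The `len(self._offsets) - 1` expression follows from the storage contract, where `_offsets` holds cumulative element counts of shape `(B+1,)`. A minimal list-based sketch of that layout is shown below. The names `pack` and `element` are hypothetical, and the real class stores tensors rather than lists.

```python
from itertools import accumulate


def pack(sequences):
    """Hypothetical helper: flatten 1-D sequences and record cumulative offsets."""
    values = [x for seq in sequences for x in seq]
    # offsets[i] is where element i starts; the final entry is the total count.
    offsets = [0, *accumulate(len(seq) for seq in sequences)]
    return values, offsets


def element(values, offsets, i):
    """Recover element i as the slice between two consecutive offsets."""
    return values[offsets[i]:offsets[i + 1]]


values, offsets = pack([[1, 2, 3], [4, 5]])
batch_size = len(offsets) - 1
second = element(values, offsets, 1)
```

Because the offsets always start at 0, an empty batch has a single offset and therefore a length of 0.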

__repr__

Python
__repr__()

Return a human-readable string representation of the NestedTensor.

Source code in danling/tensors/nested_tensor.py
Python
def __repr__(self):
    r"""Return a human-readable string representation of the NestedTensor."""
    if torch._dynamo.is_compiling():
        try:
            shape = tuple(self.size())
        except Exception:
            shape = "?"
        return (
            f"{self.__class__.__name__}(shape={shape}, dtype={self.dtype}, "
            f"device={self.device}, batch_first={getattr(self, 'batch_first', True)})"
        )

    try:
        from torch._subclasses.fake_tensor import is_fake

        for name in ("_values", "_offsets", "_physical_shape"):
            value = self.__dict__.get(name)
            if isinstance(value, Tensor) and is_fake(value):
                shape = tuple(self.size())
                return (
                    f"{self.__class__.__name__}(shape={shape}, dtype={self.dtype}, "
                    f"device={self.device}, batch_first={getattr(self, 'batch_first', True)})"
                )
    except Exception:
        pass

    if not all(name in self.__dict__ for name in ("_values", "_offsets", "_physical_shape")):
        try:
            shape = tuple(self.size())
        except Exception:
            shape = "?"
        return (
            f"{self.__class__.__name__}(shape={shape}, dtype={self.dtype}, "
            f"device={self.device}, batch_first={getattr(self, 'batch_first', True)})"
        )

    if len(self) == 0:
        return self.__class__.__name__ + "()"

    storage = self._storage
    truncated = len(storage) > 10
    if truncated:
        storage = storage[:5]

    indent = "    "

    # Strip "tensor(" wrapper from each element's repr,
    # keeping PyTorch's internal number formatting (precision, alignment).
    data_parts = []
    for t in storage:
        s = repr(t)
        paren_idx = s.index("(")
        data = s[paren_idx + 1 : -1]  # noqa: E203
        # Re-indent continuation lines for multi-line element reprs (e.g. 2D tensors)
        if "\n" in data:
            lines = data.split("\n")
            data = lines[0] + "\n" + "\n".join(indent + " " + line.lstrip() for line in lines[1:])
        data_parts.append(data)

    result_lines = [self.__class__.__name__ + "(["]
    for i, part in enumerate(data_parts):
        suffix = "," if i < len(data_parts) - 1 or truncated else ""
        result_lines.append(indent + part + suffix)
    if truncated:
        result_lines.append(indent + f"... ({len(self)} tensors)")
    result_lines.append("])")
    return "\n".join(result_lines)
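
The truncation rule used above (show only the first five elements once the batch holds more than ten, then append a count line) can be sketched on plain Python lists. This is a simplified stand-in rather than the library's code: `format_batch` is a hypothetical helper, and it calls `repr` on each item instead of stripping a `tensor(` wrapper.

```python
def format_batch(items, name="NestedTensor", indent="    "):
    """Mimic the truncation layout of NestedTensor.__repr__ on plain lists."""
    if len(items) == 0:
        return name + "()"
    truncated = len(items) > 10
    shown = items[:5] if truncated else items
    lines = [name + "(["]
    for i, item in enumerate(shown):
        # The last shown element keeps a trailing comma only when a
        # "... (N tensors)" line follows it.
        suffix = "," if i < len(shown) - 1 or truncated else ""
        lines.append(indent + repr(item) + suffix)
    if truncated:
        lines.append(indent + f"... ({len(items)} tensors)")
    lines.append("])")
    return "\n".join(lines)


short = format_batch([[1, 2, 3], [4, 5]])
long = format_batch([[i] for i in range(12)])
print(short)
print(long)
```

With twelve items, the output lists five of them followed by a `... (12 tensors)` line, so the printed size stays bounded regardless of batch size.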

__bool__

Python
__bool__() -> bool

NestedTensor follows tensor-style truthiness and never acts like a Python container.

Source code in danling/tensors/nested_tensor.py
Python
def __bool__(self) -> bool:
    r"""NestedTensor follows tensor-style truthiness and never acts like a Python container."""
    raise RuntimeError(
        "Boolean value of NestedTensor is ambiguous. Use .numel(), .any(), .all(), or an explicit reduction."
    )
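
The effect of raising in `__bool__` can be illustrated with a small pure-Python stand-in. `Ragged` is a hypothetical class, not part of danling; it only shows why the override matters when `__len__` is also defined.

```python
class Ragged:
    """Hypothetical stand-in for a batch of variable-length rows."""

    def __init__(self, rows):
        self.rows = [list(row) for row in rows]

    def __len__(self):
        return len(self.rows)

    def __bool__(self):
        # Without this override, Python would fall back to __len__ and treat
        # any non-empty batch as truthy, like a list.
        raise RuntimeError("Boolean value of Ragged is ambiguous. Use .any() or .all().")

    def any(self):
        return any(value for row in self.rows for value in row)

    def all(self):
        return all(value for row in self.rows for value in row)


batch = Ragged([[0, 0, 0], [0, 1]])
try:
    bool(batch)
    ambiguous = False
except RuntimeError:
    ambiguous = True
```

Callers therefore write `if batch.any():` or `if len(batch) > 0:` and state which question they are asking.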

__iter__

Python
__iter__()

Iterate over the tensors in the batch.

Source code in danling/tensors/nested_tensor.py
Python
def __iter__(self):
    r"""Iterate over the tensors in the batch."""
    _check_execution_guard(_ExecutionGuardKind.ITERATION, "NestedTensor.__iter__")
    return iter(self._storage)

__eq__

Python
__eq__(other)

Element-wise equality comparison via torch.eq. Returns NotImplemented if torch.eq raises a TypeError.

Source code in danling/tensors/nested_tensor.py
Python
def __eq__(self, other):  # type: ignore[override]
    r"""Element-wise equality comparison."""
    try:
        return torch.eq(self, other)
    except TypeError:
        return NotImplemented

__ne__

Python
__ne__(other)

Element-wise inequality comparison via torch.ne. Returns NotImplemented if torch.ne raises a TypeError.

Source code in danling/tensors/nested_tensor.py
Python
def __ne__(self, other):  # type: ignore[override]
    r"""Element-wise inequality comparison."""
    try:
        return torch.ne(self, other)
    except TypeError:
        return NotImplemented

from_concatenated classmethod

Python
from_concatenated(
    concat_tensor: Tensor,
    shapes: tuple[Size, ...],
    **kwargs
) -> Self

Reconstruct a NestedTensor from a concatenated tensor and shape information.

Parameters:

Name Type Description Default
concat_tensor
Tensor

The concatenated tensor returned by concatenate()

required
shapes
tuple[Size, ...]

Tuple of original tensor shapes returned by concatenate()

required
**kwargs

Additional arguments to pass to NestedTensor constructor

{}

Returns:

Type Description
Self

Reconstructed NestedTensor

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.randn(9, 9, 8), torch.randn(11, 11, 8)])
>>> concat_tensor, shapes = nested_tensor.concatenate()
>>> reconstructed = NestedTensor.from_concatenated(concat_tensor, shapes)
>>> concat_tensor.shape
torch.Size([202, 8])
>>> reconstructed.shape
torch.Size([2, 11, 11, 8])
>>> torch.equal(nested_tensor.tensor, reconstructed.tensor)
True
Source code in danling/tensors/nested_tensor.py
Python
@classmethod
def from_concatenated(cls, concat_tensor: Tensor, shapes: tuple[torch.Size, ...], **kwargs) -> Self:
    r"""
    Reconstruct a NestedTensor from a concatenated tensor and shape information.

    Args:
        concat_tensor: The concatenated tensor returned by concatenate()
        shapes: Tuple of original tensor shapes returned by concatenate()
        **kwargs: Additional arguments to pass to NestedTensor constructor

    Returns:
        Reconstructed NestedTensor

    Examples:
        >>> nested_tensor = NestedTensor([torch.randn(9, 9, 8), torch.randn(11, 11, 8)])
        >>> concat_tensor, shapes = nested_tensor.concatenate()
        >>> reconstructed = NestedTensor.from_concatenated(concat_tensor, shapes)
        >>> concat_tensor.shape
        torch.Size([202, 8])
        >>> reconstructed.shape
        torch.Size([2, 11, 11, 8])
        >>> torch.equal(nested_tensor.tensor, reconstructed.tensor)
        True
    """
    if not shapes:
        if "dtype" not in kwargs:
            kwargs["dtype"] = concat_tensor.dtype
        if "device" not in kwargs:
            kwargs["device"] = concat_tensor.device
        return cls([], **kwargs)

    num_elements = [shape.numel() for shape in shapes]
    element_shapes = tuple(tuple(int(dim) for dim in shape) for shape in shapes)
    varying_dims, static_dims = cls._pack_layout_from_element_shapes(element_shapes)
    permutation = varying_dims + static_dims
    identity_permutation = tuple(range(len(element_shapes[0]))) if element_shapes and element_shapes[0] else ()

    if len(set(shapes)) == 1 and permutation == identity_permutation:
        shape = shapes[0]
        total_elements = sum(num_elements)
        if concat_tensor.numel() == total_elements:
            try:
                reshaped = concat_tensor.reshape(len(shapes), *shape)
            except (RuntimeError, ValueError):
                # The reshape fast path is opportunistic; a normal unpack fallback
                # is expected for non-view-compatible inputs.
                pass
            else:
                tensors = [t.reshape(shape) for t in reshaped.unbind(0)]
                return cls(tensors, **kwargs)

    packed_sizes = tuple(cls._packed_size_from_shape(shape, varying_dims) for shape in element_shapes)
    total_expected = sum(num_elements)
    num_provided = concat_tensor.numel()
    if num_provided != total_expected:
        raise ValueError(
            f"Concatenated tensor has {num_provided} elements "
            f"but expected {total_expected} based on shapes {shapes}"
        )

    tensors = []
    start = 0
    inverse_permutation = cls._inverse_permutation(permutation)
    for shape, packed_size in zip(element_shapes, packed_sizes):
        end = start + packed_size
        chunk = concat_tensor.narrow(0, start, packed_size)
        packed_shape = tuple(shape[dim] for dim in varying_dims) + tuple(shape[dim] for dim in static_dims)
        tensor_data = chunk.reshape(packed_shape)
        if permutation != tuple(range(len(shape))):
            tensor_data = tensor_data.permute(inverse_permutation)
        tensors.append(tensor_data)
        start = end

    return cls(tensors, **kwargs)
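
The general unpack path above walks the concatenated buffer once, cutting out one span per recorded shape. A minimal sketch of that idea on flat Python lists follows. It assumes plain row-major order and counts scalars rather than packed rows, so it ignores the varying/static dimension permutation the real method handles; `reshape` and `split_concatenated` are hypothetical helpers.

```python
from math import prod


def reshape(flat, shape):
    """Reshape a flat list into nested lists in row-major order."""
    if len(shape) == 0:
        return flat[0]
    if len(shape) == 1:
        return list(flat)
    step = prod(shape[1:])
    return [reshape(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


def split_concatenated(flat, shapes):
    """Cut a flat buffer into consecutive spans, one per shape, and reshape each."""
    total = sum(prod(shape) for shape in shapes)
    if len(flat) != total:
        raise ValueError(f"buffer has {len(flat)} elements but shapes require {total}")
    parts, start = [], 0
    for shape in shapes:
        size = prod(shape)
        parts.append(reshape(flat[start:start + size], shape))
        start += size
    return parts


parts = split_concatenated(list(range(10)), [(2, 2), (3, 2)])
```

Here `parts[0]` is a 2x2 block built from the first four values and `parts[1]` is a 3x2 block built from the remaining six.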

from_tensor_mask classmethod

Python
from_tensor_mask(
    tensor: Tensor,
    mask: Tensor,
    *,
    batched: bool = False,
    **kwargs
)

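Build a NestedTensor from a padded Tensor and its corresponding mask Tensor. Masks with two or more dimensions must mark a contiguous prefix of valid positions in each element; interior gaps raise a ValueError.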

Parameters:

Name Type Description Default
tensor
Tensor

Padded Tensor.

required
mask
Tensor

Mask tensor, converted to bool before use. The mask uses the same convention as mask_value: padding positions equal mask_value and valid positions equal not mask_value.

required
batched
bool

When True and mask.ndim == 1, treat mask as a per-batch-element selector (each True entry selects a row from tensor) rather than a contiguous-prefix length indicator.

False

Examples:

Python Console Session
>>> padded_tensor = torch.tensor([[1, 2, 3, 0, 0],
...                                [4, 5, 0, 0, 0],
...                                [6, 7, 8, 9, 0]])
>>> mask_tensor = torch.tensor([[1, 1, 1, 0, 0],
...                             [1, 1, 0, 0, 0],
...                             [1, 1, 1, 1, 0]])
>>> nested_tensor = NestedTensor.from_tensor_mask(padded_tensor, mask_tensor)
>>> nested_tensor
NestedTensor([
    [1, 2, 3],
    [4, 5],
    [6, 7, 8, 9]
])
Source code in danling/tensors/nested_tensor.py
Python
@classmethod
def from_tensor_mask(cls, tensor: Tensor, mask: Tensor, *, batched: bool = False, **kwargs):
    r"""
    Build a `NestedTensor` object from a padded `Tensor` and corresponding mask `Tensor`.

    Args:
        tensor: Padded Tensor.
        mask: Tensor Mask.
            The mask uses the same convention as ``mask_value``:
            padding positions equal ``mask_value`` and valid positions equal ``not mask_value``.
        batched: When ``True`` and ``mask.ndim == 1``, treat ``mask`` as a per-batch-element
            selector (each ``True`` entry selects a row from ``tensor``) rather than a
            contiguous-prefix length indicator.

    Examples:
        >>> padded_tensor = torch.tensor([[1, 2, 3, 0, 0],
        ...                                [4, 5, 0, 0, 0],
        ...                                [6, 7, 8, 9, 0]])
        >>> mask_tensor = torch.tensor([[1, 1, 1, 0, 0],
        ...                             [1, 1, 0, 0, 0],
        ...                             [1, 1, 1, 1, 0]])
        >>> nested_tensor = NestedTensor.from_tensor_mask(padded_tensor, mask_tensor)
        >>> nested_tensor
        NestedTensor([
            [1, 2, 3],
            [4, 5],
            [6, 7, 8, 9]
        ])
    """
    mask = mask.to(dtype=torch.bool)
    mask_value = kwargs.get("mask_value", False)
    effective_mask = ~mask if mask_value else mask

    if mask.ndim == 1:
        if batched:
            indices = effective_mask.nonzero(as_tuple=False).flatten()
            return cls([tensor[int(i)] for i in indices], dtype=tensor.dtype, **kwargs)
        return cls(tensor[effective_mask], dtype=tensor.dtype, **kwargs)
    # ndim >= 2: batch setup is shared, per-element trim differs by rank
    batch_first = kwargs.get("batch_first", True)
    tensor_iter = tensor if batch_first else tensor.transpose(0, 1)
    mask_iter = effective_mask if batch_first else effective_mask.transpose(0, 1)
    if tensor_iter.size(0) != mask_iter.size(0):
        raise ValueError(f"Tensor/mask batch dimension mismatch: {tensor_iter.size(0)} vs {mask_iter.size(0)}")
    trimmed = []

    def _is_prefix_mask(mask_1d: Tensor) -> bool:
        count = int(mask_1d.sum().item())
        prefix = torch.arange(mask_1d.size(0), device=mask_1d.device, dtype=torch.long) < count
        return bool(torch.equal(mask_1d, prefix))

    def _is_hierarchical_prefix_mask(mask_nd: Tensor) -> bool:
        if mask_nd.dim() == 1:
            return _is_prefix_mask(mask_nd)
        leading_valid = mask_nd.reshape(mask_nd.size(0), -1).any(dim=1)
        valid_count = int(leading_valid.sum().item())
        prefix = torch.arange(mask_nd.size(0), device=mask_nd.device, dtype=torch.long) < valid_count
        if not torch.equal(leading_valid, prefix):
            return False
        return all(_is_hierarchical_prefix_mask(mask_nd[index]) for index in range(valid_count))

    if mask.ndim == 2:
        # 1-D per-element mask: only contiguous-prefix masks can be reconstructed
        # via slicing without changing dense semantics.
        counts = mask_iter.sum(dim=1, dtype=torch.long)
        prefix = torch.arange(mask_iter.size(1), device=mask_iter.device, dtype=torch.long).unsqueeze(0)
        prefix = prefix < counts.unsqueeze(1)
        if not torch.equal(mask_iter, prefix):
            raise ValueError(
                "from_tensor_mask() with 2-D masks requires each row to be a valid prefix mask; "
                "interior False gaps are not supported."
            )
        for t, count in zip(tensor_iter, counts.tolist()):
            trimmed.append(t[:count])
    else:
        # N-D per-element mask: only hierarchical ragged-prefix masks are representable as NestedTensor.
        extents = torch.zeros((mask_iter.size(0), mask_iter.dim() - 1), dtype=torch.long, device=mask_iter.device)
        nonzero = mask_iter.nonzero(as_tuple=False)
        if nonzero.numel() > 0:
            batch_index = nonzero[:, :1].expand(-1, extents.size(1))
            extents.scatter_reduce_(0, batch_index, nonzero[:, 1:] + 1, reduce="amax", include_self=False)
        extent_rows = extents.cpu().tolist()
        for t, em, sizes in zip(tensor_iter, mask_iter, extent_rows):
            if not _is_hierarchical_prefix_mask(em):
                raise ValueError(
                    "from_tensor_mask() with N-D masks requires each element mask to be a valid hierarchical "
                    "ragged prefix; "
                    "interior False gaps are not supported."
                )
            slices = tuple(slice(0, size) for size in sizes)
            t_slice = t[slices]
            m_slice = em[slices]
            valid_mask = m_slice
            if t_slice.dim() > m_slice.dim():
                valid_mask = m_slice.view(m_slice.shape + (1,) * (t_slice.dim() - m_slice.dim()))
            trimmed.append(t_slice.masked_fill(~valid_mask, kwargs.get("padding_value", 0.0)))
    return cls(trimmed, dtype=tensor.dtype, **kwargs)
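
The 2-D branch above accepts a mask row only if it is a contiguous prefix of valid positions, then trims each padded row to that prefix. A minimal pure-Python sketch of the same check, assuming the default `mask_value=False` convention (truthy means valid); `is_prefix_mask` and `trim_by_mask` are hypothetical names:

```python
def is_prefix_mask(row):
    """True if every valid (truthy) entry comes before every padding entry."""
    count = sum(bool(v) for v in row)
    return [bool(v) for v in row] == [i < count for i in range(len(row))]


def trim_by_mask(padded, mask):
    """Trim each padded row to its valid prefix, rejecting masks with interior gaps."""
    trimmed = []
    for row, row_mask in zip(padded, mask):
        if not is_prefix_mask(row_mask):
            raise ValueError("each mask row must be a contiguous prefix of valid positions")
        trimmed.append(row[: sum(bool(v) for v in row_mask)])
    return trimmed


padded = [[1, 2, 3, 0, 0], [4, 5, 0, 0, 0], [6, 7, 8, 9, 0]]
mask = [[1, 1, 1, 0, 0], [1, 1, 0, 0, 0], [1, 1, 1, 1, 0]]
rows = trim_by_mask(padded, mask)
```

A row such as `[1, 0, 1]` fails the check because slicing to a prefix could not reproduce it.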

nested_like

Python
nested_like(tensor: Tensor, strict: bool = True) -> Self

Create a new NestedTensor from a Tensor. The newly created NestedTensor will have the same shape as the current NestedTensor.

Parameters:

Name Type Description Default
tensor
Tensor

The tensor to be converted to NestedTensor.

required
strict
bool

If True, raise a ValueError unless tensor has exactly the same shape as the current NestedTensor. If False, only the batch size is checked.

True

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
>>> (nested_tensor == nested_tensor.nested_like(nested_tensor)).all()
tensor(True)
>>> tensor = nested_tensor.tensor
>>> (nested_tensor == nested_tensor.nested_like(tensor)).all()
tensor(True)
>>> f = nested_tensor.nested_like(torch.randn(2, 2))
Traceback (most recent call last):
...
ValueError: The shape of NestedTensor and input tensor does not match, ...
>>> p = nested_tensor.nested_like(torch.randn(2, 2), False)
>>> p = nested_tensor.nested_like(torch.randn(3, 3), False)
Traceback (most recent call last):
...
ValueError: The batch size of NestedTensor and input tensor does not match, 2 != 3
Source code in danling/tensors/nested_tensor.py
Python
def nested_like(self, tensor: Tensor, strict: bool = True) -> Self:
    r"""
    Create a new `NestedTensor` from a `Tensor`.
    The newly created `NestedTensor` will have the same shape as the current `NestedTensor`.

    Args:
        tensor: The tensor to be converted to `NestedTensor`.
        strict: Check if the shape of `tensor` is the same as the current `NestedTensor`.

    Examples:
        >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
        >>> (nested_tensor == nested_tensor.nested_like(nested_tensor)).all()
        tensor(True)
        >>> tensor = nested_tensor.tensor
        >>> (nested_tensor == nested_tensor.nested_like(tensor)).all()
        tensor(True)
        >>> f = nested_tensor.nested_like(torch.randn(2, 2))
        Traceback (most recent call last):
        ...
        ValueError: The shape of NestedTensor and input tensor does not match, ...
        >>> p = nested_tensor.nested_like(torch.randn(2, 2), False)
        >>> p = nested_tensor.nested_like(torch.randn(3, 3), False)
        Traceback (most recent call last):
        ...
        ValueError: The batch size of NestedTensor and input tensor does not match, 2 != 3
    """

    if isinstance(tensor, NestedTensor):
        return tensor.clone()

    if strict and self.shape != tensor.shape:
        raise ValueError(
            f"The shape of NestedTensor and input tensor does not match, {self.shape} != {tensor.shape}"
        )
    batch_dim = 0 if self.dim() <= 1 or self.batch_first else 1
    if len(self) != tensor.size(batch_dim):
        raise ValueError(
            "The batch size of NestedTensor and input tensor does not match, "
            f"{len(self)} != {tensor.size(batch_dim)}"
        )
    values = self._dense_to_packed_values(tensor)
    if values is not None:
        element_shapes = self._element_shapes
        return self.__class__._from_packed(
            values,
            self._offsets,
            self._physical_shape,
            batch_first=self.batch_first,
            padding_value=self.padding_value,
            mask_value=self.mask_value,
            pin_memory=self._pin_memory,
            outer_size=self._logical_shape,
            packed_sizes=self._packed_sizes,
            element_shapes=element_shapes,
        )
    dense_tensor = tensor.to(device=self.device)
    element_shapes = self._original_shapes()
    new_storage = []
    for idx, shape in enumerate(element_shapes):
        if self.batch_first:
            slices = (idx, *[slice(0, int(dim)) for dim in shape])
        else:
            if len(shape) == 0:
                slices = (idx,)
            else:
                slices = (slice(0, int(shape[0])), idx, *[slice(0, int(dim)) for dim in shape[1:]])
        # .contiguous() ensures storage elements don't inherit non-trivial
        # strides from the padded tensor (e.g. after transpose).
        new_storage.append(dense_tensor[slices].contiguous())
    return self.__class__(new_storage, dtype=tensor.dtype, **self._meta(include_dtype=False))
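
The fallback loop above slices each batch entry of the dense tensor back to the reference element's shape. A minimal sketch for the one-dimensional, batch-first case on plain lists follows; this `nested_like` function is a hypothetical stand-in that takes the reference lengths directly.

```python
def nested_like(lengths, padded, strict=True):
    """Slice each padded row back to the reference length for that batch entry."""
    ref_shape = (len(lengths), max(lengths, default=0))
    got_shape = (len(padded), len(padded[0]) if padded else 0)
    if strict and ref_shape != got_shape:
        raise ValueError(f"shape mismatch, {ref_shape} != {got_shape}")
    # The batch size is checked even when strict is False.
    if len(lengths) != len(padded):
        raise ValueError(f"batch size mismatch, {len(lengths)} != {len(padded)}")
    return [row[:n] for row, n in zip(padded, lengths)]


lengths = [3, 2]
result = nested_like(lengths, [[10, 20, 30], [40, 50, 0]])
loose = nested_like(lengths, [[1, 2], [3, 4]], strict=False)
```

With `strict=False`, a narrower input is accepted and each row is simply sliced as far as it goes, which mirrors the `torch.randn(2, 2)` case in the examples.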

to_torch_nested

Python
to_torch_nested() -> Tensor

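Create a torch.nested.nested_tensor object from self. The jagged layout (torch.jagged) is used when the batch is empty or every element has at least one dimension; otherwise the default layout of torch.nested.nested_tensor is used.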

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([[2, 3, 5], [7, 8]])
>>> nt = nested_tensor.to_torch_nested()
>>> nt.layout == torch.jagged
True
>>> nt.values()
tensor([2, 3, 5, 7, 8])
Source code in danling/tensors/nested_tensor.py
Python
def to_torch_nested(self) -> Tensor:
    r"""
    Create a `torch.nested.nested_tensor` object from `self`.

    Examples:
        >>> nested_tensor = NestedTensor([[2, 3, 5], [7, 8]])
        >>> nt = nested_tensor.to_torch_nested()
        >>> nt.layout == torch.jagged
        True
        >>> nt.values()
        tensor([2, 3, 5, 7, 8])
    """
    storage = list(self._storage)
    if not storage or all(t.dim() > 0 for t in storage):
        return nested.nested_tensor(storage, layout=torch.jagged)
    return nested.nested_tensor(storage)

unbind

Python
unbind(dim: int = 0) -> tuple[Tensor, ...]

Unbind the NestedTensor along dim and return the resulting tensors as a tuple. Delegates to torch.unbind.

Source code in danling/tensors/nested_tensor.py
Python
def unbind(self, dim: int = 0) -> tuple[Tensor, ...]:
    r"""
    Unbind the NestedTensor.
    """
    return torch.unbind(self, dim=dim)

__getitem__

Python
__getitem__(
    index: (
        int | slice | list | tuple | Tensor | NestedTensor
    ),
) -> Tensor | NestedTensor

Retrieve element(s) by index, slice, list, tuple, or tensor mask.

Source code in danling/tensors/nested_tensor.py
Python
def __getitem__(self, index: int | slice | list | tuple | Tensor | NestedTensor) -> Tensor | NestedTensor:
    r"""Retrieve element(s) by index, slice, list, tuple, or tensor mask."""
    if isinstance(index, int):
        return self._storage[index]
    if isinstance(index, (slice, list)):
        if isinstance(index, list) and index and all(isinstance(i, bool) for i in index):
            if len(index) != len(self):
                raise IndexError(f"Boolean index has length {len(index)} but batch size is {len(self)}")
            index = [i for i, flag in enumerate(index) if flag]
        storage = tuple(self._storage[index] if isinstance(index, slice) else [self._storage[i] for i in index])
        return self.__class__(storage, **self._meta(include_dtype=True))
    if isinstance(index, tuple):
        if len(index) == 0:
            return self

        # Expand Ellipsis: ``nt[..., :2]`` on a 4-D NestedTensor becomes
        # ``nt[:, :, :, :2]``.  The batch dim is consumed first, so Ellipsis
        # fills the gap between the number of explicit indices and the total
        # number of logical dimensions.
        if Ellipsis in index:
            eidx = index.index(Ellipsis)
            n_explicit = len(index) - 1  # exclude Ellipsis itself
            n_expand = self.dim() - n_explicit
            index = index[:eidx] + (slice(None),) * n_expand + index[eidx + 1 :]

        batch_index, *rest = index

        if isinstance(batch_index, (Tensor, NestedTensor)):
            return self.tensor[index]

        if isinstance(batch_index, list) and batch_index and all(isinstance(i, bool) for i in batch_index):
            if len(batch_index) != len(self):
                raise IndexError(f"Boolean index has length {len(batch_index)} but batch size is {len(self)}")
            batch_index = [i for i, flag in enumerate(batch_index) if flag]

        if isinstance(batch_index, int):
            tensor = self._storage[batch_index]
            if rest:
                return tensor[tuple(rest)]
            return tensor
        elif isinstance(batch_index, (slice, list)):
            if isinstance(batch_index, slice):
                selected = self._storage[batch_index]
            else:
                selected = tuple(self._storage[i] for i in batch_index)
            if rest:
                rest_tuple = tuple(rest)
                selected = tuple(t[rest_tuple] for t in selected)
            return self.__class__(selected, **self._meta(include_dtype=True))
        raise ValueError(f"Unsupported batch index type {type(batch_index)}")
    if isinstance(index, NestedTensor):
        if len(self) != len(index):
            raise ValueError(
                "NestedTensor batch length mismatch between self and index: "
                f"self={len(self)}, index={len(index)}"
            )
        return self.__class__(
            [t[i] for t, i in zip(self._storage, index._storage)], **self._meta(include_dtype=True)
        )
    if isinstance(index, Tensor):
        if index.dim() == 0 and index.dtype in (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8):
            return self._storage[int(index.item())]
        if index.dim() == 1:
            if index.dtype in (torch.bool, torch.uint8):
                if index.numel() != len(self):
                    raise IndexError(f"Boolean index has length {index.numel()} but batch size is {len(self)}")
                selected = tuple(self._storage[i] for i, flag in enumerate(index.tolist()) if bool(flag))
                return self.__class__(selected, **self._meta(include_dtype=True))
            if index.dtype in (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8):
                return self.__class__(
                    [self._storage[int(i)] for i in index.tolist()],
                    **self._meta(include_dtype=True),
                )
        index = self.nested_like(index, strict=False)
        return self.__class__(
            [t[i] for t, i in zip(self._storage, index._storage)], **self._meta(include_dtype=True)
        )
    raise ValueError(f"Unsupported index type {type(index)}")
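
The Ellipsis handling in the tuple branch can be isolated as a small pure-Python function. This sketch mirrors the arithmetic in the listing above; `expand_ellipsis` is a hypothetical name, and it assumes the index contains at most one Ellipsis and only plain Python entries (integers and slices).

```python
def expand_ellipsis(index, ndim):
    """Replace a single Ellipsis with enough full slices to cover `ndim` dimensions."""
    if Ellipsis not in index:
        return index
    eidx = index.index(Ellipsis)
    n_explicit = len(index) - 1  # the Ellipsis itself consumes no dimension
    n_expand = ndim - n_explicit
    return index[:eidx] + (slice(None),) * n_expand + index[eidx + 1:]


expanded = expand_ellipsis((Ellipsis, slice(None, 2)), ndim=4)
leading = expand_ellipsis((0, Ellipsis), ndim=3)
```

For a 4-D input, `(..., :2)` becomes `(:, :, :, :2)`, so the first entry of the expanded tuple always addresses the batch dimension.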

__setitem__

Python
__setitem__(
    index: int | slice | list | tuple,
    value: Tensor | NestedTensor,
) -> None

Set values in the NestedTensor at the specified index.

Parameters:

Name Type Description Default
index
int | slice | list | tuple

The index to modify. Can be an integer, slice, list, or tuple.

required
value
Tensor | NestedTensor

The new value to set. Can be a Tensor or NestedTensor.

required

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
>>> nested_tensor[0] = torch.tensor([6, 7, 8])
>>> nested_tensor[0]
tensor([6, 7, 8])
>>> nested_tensor[1] = torch.tensor([9, 10, 11, 12])
>>> nested_tensor.shape
torch.Size([2, 4])
Source code in danling/tensors/nested_tensor.py
Python
def __setitem__(self, index: int | slice | list | tuple, value: Tensor | NestedTensor) -> None:
    r"""
    Set values in the NestedTensor at the specified index.

    Args:
        index: The index to modify. Can be an integer, slice, list, or tuple.
        value: The new value to set. Can be a Tensor or NestedTensor.

    Examples:
        >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
        >>> nested_tensor[0] = torch.tensor([6, 7, 8])
        >>> nested_tensor[0]
        tensor([6, 7, 8])
        >>> nested_tensor[1] = torch.tensor([9, 10, 11, 12])
        >>> nested_tensor.shape
        torch.Size([2, 4])
    """
    if isinstance(index, int):
        self._invalidate_transient_caches()
        if isinstance(value, NestedTensor):
            if len(value._storage) != 1:
                raise ValueError(
                    f"When setting with an integer index, value must have a single tensor, but got {len(value)}"
                )
            value = value._storage[0]
        if not isinstance(value, Tensor):
            value = torch.tensor(value, device=self.device, dtype=self.dtype)
        else:
            value = value.to(device=self.device, dtype=self.dtype)
        if self.requires_grad:
            value.requires_grad_(True)

        # Normalize negative index
        idx = index + len(self) if index < 0 else index
        if idx < 0 or idx >= len(self):
            raise IndexError(f"index {index} is out of range for NestedTensor with {len(self)} elements")
        expected_ndim = self._physical_shape.size(1)
        if value.dim() != expected_ndim:
            raise ValueError(
                f"Assigned tensor ndim must match existing ndim {expected_ndim}, but got {value.dim()}"
            )

        old_start = int(self._offsets[idx].item())
        old_end = int(self._offsets[idx + 1].item())
        old_size = old_end - old_start
        new_shape_row = torch.tensor(list(value.shape), dtype=self._physical_shape.dtype)

        permutation = self._permutation
        identity_permutation = tuple(range(expected_ndim))
        varying_dims = self._varying_dims
        static_dims = self._static_dims
        packed_size = type(self)._packed_size_from_shape(tuple(int(dim) for dim in value.shape), varying_dims)
        packed_value = value if permutation == identity_permutation else value.permute(permutation)
        suffix_shape = tuple(int(value.shape[dim]) for dim in static_dims)
        new_payload = packed_value.reshape((packed_size, *suffix_shape) if suffix_shape else (packed_size,))
        new_size = packed_size

        if self._values.dim() > 1 and new_payload.shape[1:] != self._values.shape[1:]:
            storage_list = list(self._storage)
            storage_list[idx] = value
            self._repack(storage_list)
            return

        if new_size == old_size:
            # Same packed span size: direct overwrite keeps _values allocation.
            self._values[old_start:old_end] = new_payload
            self._physical_shape[idx] = new_shape_row
        else:
            # Different packed span size: splice _values and shift subsequent offsets.
            self._values = torch.cat([self._values[:old_start], new_payload, self._values[old_end:]], dim=0)
            delta = new_size - old_size
            self._offsets = self._offsets.clone()
            self._offsets[idx + 1 :] += delta  # noqa: E203
            self._physical_shape = self._physical_shape.clone()
            self._physical_shape[idx] = new_shape_row
        self._logical_shape = self._logical_shape_from_physical_shape(
            self._physical_shape, self._offsets, self.batch_first
        )
        if self._element_shapes is not None and self._packed_sizes is not None:
            element_shapes = list(self._element_shapes)
            element_shapes[idx] = tuple(int(dim) for dim in value.shape)
            self._element_shapes = tuple(element_shapes)
            packed_sizes = list(self._packed_sizes)
            packed_sizes[idx] = self._packed_sizes_like((self._element_shapes[idx],))[0]
            self._packed_sizes = tuple(packed_sizes)
        self._validate_metadata()
    elif isinstance(index, (slice, list)):
        if isinstance(index, list) and index and all(isinstance(i, bool) for i in index):
            if len(index) != len(self):
                raise IndexError(f"Boolean index has length {len(index)} but batch size is {len(self)}")
            index = [i for i, flag in enumerate(index) if flag]

        if isinstance(value, Tensor) and not isinstance(value, NestedTensor):
            if value.dim() > 1 and value.size(0) > 1:
                value = self.__class__(value.unbind(0), **self._meta())
            else:
                value = self.__class__([value], **self._meta())

        if isinstance(index, slice):
            start, stop, step = index.indices(len(self))
            indices = range(start, stop, step)
        else:
            indices = index  # type: ignore[assignment]

        if len(indices) != len(value._storage):
            raise ValueError(
                f"Size mismatch: tried to assign {len(value._storage)} values to {len(indices)} indices"
            )

        storage_list = list(self._storage)
        for i, idx in enumerate(indices):
            storage_list[idx] = value._storage[i]
        self._storage = tuple(storage_list)
    elif isinstance(index, tuple):
        if len(index) == 0:
            return
        if len(index) == 1:
            self[index[0]] = value
            return

        first_idx, rest_idx = index[0], index[1:]
        batch_indices: list[int]
        if isinstance(first_idx, int):
            batch_indices = [first_idx]
        elif isinstance(first_idx, (slice, list)):
            if isinstance(first_idx, list) and first_idx and all(isinstance(i, bool) for i in first_idx):
                if len(first_idx) != len(self):
                    raise IndexError(f"Boolean index has length {len(first_idx)} but batch size is {len(self)}")
                batch_indices = [i for i, flag in enumerate(first_idx) if flag]
            elif isinstance(first_idx, slice):
                start, stop, step = first_idx.indices(len(self))
                batch_indices = list(range(start, stop, step))
            else:
                batch_indices = list(first_idx)  # type: ignore[arg-type]
        else:
            raise ValueError(f"Unsupported first index type {type(first_idx)}")

        if isinstance(value, NestedTensor):
            if len(batch_indices) != len(value._storage):
                raise ValueError(
                    f"Size mismatch: tried to assign {len(value._storage)} values to {len(batch_indices)} indices"
                )
            assigned_values = list(value._storage)
        else:
            assigned_values = [value] * len(batch_indices)

        elems = list(self._storage)
        for position, idx in enumerate(batch_indices):
            elem = elems[idx].clone()
            elem[rest_idx] = assigned_values[position]
            elems[idx] = elem
        self._storage = tuple(elems)
    else:
        raise ValueError(f"Unsupported index type {type(index)}")
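
The integer-index branch above either overwrites a span in place or, when the new element has a different packed size, splices the values buffer and shifts the later offsets. A minimal sketch of the splice on plain lists, assuming 1-D elements so that the packed size equals the element length; `replace_element` is a hypothetical helper:

```python
def replace_element(values, offsets, idx, payload):
    """Replace packed element `idx`, returning new (values, offsets) lists."""
    start, end = offsets[idx], offsets[idx + 1]
    delta = len(payload) - (end - start)
    new_values = values[:start] + list(payload) + values[end:]
    # Offsets up to and including `idx` are unchanged; later ones shift by delta.
    new_offsets = offsets[: idx + 1] + [o + delta for o in offsets[idx + 1:]]
    return new_values, new_offsets


values = [1, 2, 3, 4, 5]   # elements [1, 2, 3] and [4, 5], packed
offsets = [0, 3, 5]        # cumulative element counts, length B + 1
values, offsets = replace_element(values, offsets, 1, [9, 10, 11, 12])
```

After the call the second element spans four values, which corresponds to the padded shape growing from `[2, 3]` to `[2, 4]` in the example above.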

__copy__

Python
__copy__()

Shallow copy: new NestedTensor sharing underlying tensor data.

Source code in danling/tensors/nested_tensor.py
Python
def __copy__(self):
    r"""Shallow copy: new NestedTensor sharing underlying tensor data."""
    return self.__class__._from_packed(
        self._values,
        self._offsets,
        self._physical_shape,
        permutation=self._permutation,
        batch_first=self.batch_first,
        padding_value=self.padding_value,
        mask_value=self.mask_value,
        pin_memory=self._pin_memory,
        outer_size=self._logical_shape,
        packed_sizes=self._packed_sizes,
        element_shapes=self._element_shapes,
    )

__deepcopy__

Python
__deepcopy__(memo)

Deep copy: clones all tensor data.

Source code in danling/tensors/nested_tensor.py
Python
def __deepcopy__(self, memo):
    r"""Deep copy: clones all tensor data."""
    result = self.__class__._from_packed(
        self._values.clone(),
        self._offsets.clone(),
        self._physical_shape.clone(),
        permutation=self._permutation,
        batch_first=self.batch_first,
        padding_value=self.padding_value,
        mask_value=self.mask_value,
        pin_memory=self._pin_memory,
        outer_size=self._logical_shape,
        packed_sizes=self._packed_sizes,
        element_shapes=self._element_shapes,
    )
    memo[id(self)] = result
    return result
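
The difference between the two copy hooks can be seen with a small stand-in class that stores its packed data in Python lists. `Packed` is hypothetical and only mirrors the sharing behaviour described above: `__copy__` reuses the underlying buffers, while `__deepcopy__` clones them and records the result in `memo`.

```python
import copy


class Packed:
    """Hypothetical container holding a values buffer and an offsets buffer."""

    def __init__(self, values, offsets):
        self.values = values
        self.offsets = offsets

    def __copy__(self):
        # New wrapper object, same underlying buffers.
        return Packed(self.values, self.offsets)

    def __deepcopy__(self, memo):
        # New wrapper object with independent buffers.
        result = Packed(list(self.values), list(self.offsets))
        memo[id(self)] = result
        return result


original = Packed([1, 2, 3, 4, 5], [0, 3, 5])
shallow = copy.copy(original)
deep = copy.deepcopy(original)
original.values[0] = 99  # visible through the shallow copy only
```

An in-place write to the original shows up in `shallow` but not in `deep`.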

all

Python
all(
    dim: int | None = None, keepdim: bool = False
) -> bool | Tensor | NestedTensor

Tests if all elements in NestedTensor evaluate to True.

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.ones(2, 4, dtype=torch.bool), torch.ones(3, 5, dtype=torch.bool)])
>>> nested_tensor.all()
tensor(True)
>>> nested_tensor.all(dim=0)
tensor([True, True])
>>> nested_tensor.all(dim=0, keepdim=True)
tensor([[True, True]])
>>> nested_tensor.all(dim=1)
NestedTensor([
    [True, True, True, True],
    [True, True, True, True, True]
])
>>> nested_tensor.all(dim=1, keepdim=True)
NestedTensor([
    [[True, True, True, True]],
    [[True, True, True, True, True]]
])
>>> nested_tensor.batch_first = False
>>> nested_tensor.all(dim=1)
tensor([True, True])
>>> nested_tensor.all(dim=0)
NestedTensor([
    [True, True, True, True],
    [True, True, True, True, True]
])
>>> nested_tensor.all(dim=-2)
tensor([True, True])
Source code in danling/tensors/nested_tensor.py
Python
def all(self, dim: int | None = None, keepdim: bool = False) -> bool | Tensor | NestedTensor:
    r"""
    Tests if all elements in NestedTensor evaluate to True.

    Examples:
        >>> nested_tensor = NestedTensor([torch.ones(2, 4, dtype=torch.bool), torch.ones(3, 5, dtype=torch.bool)])
        >>> nested_tensor.all()
        tensor(True)
        >>> nested_tensor.all(dim=0)
        tensor([True, True])
        >>> nested_tensor.all(dim=0, keepdim=True)
        tensor([[True, True]])
        >>> nested_tensor.all(dim=1)
        NestedTensor([
            [True, True, True, True],
            [True, True, True, True, True]
        ])
        >>> nested_tensor.all(dim=1, keepdim=True)
        NestedTensor([
            [[True, True, True, True]],
            [[True, True, True, True, True]]
        ])
        >>> nested_tensor.batch_first = False
        >>> nested_tensor.all(dim=1)
        tensor([True, True])
        >>> nested_tensor.all(dim=0)
        NestedTensor([
            [True, True, True, True],
            [True, True, True, True, True]
        ])
        >>> nested_tensor.all(dim=-2)
        tensor([True, True])
    """
    return torch.all(self, dim=dim, keepdim=keepdim)

any

Python
any(
    dim: int | None = None, keepdim: bool = False
) -> bool | Tensor | NestedTensor

Tests if any elements in NestedTensor evaluate to True.

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.zeros(2, dtype=torch.bool), torch.ones(3, dtype=torch.bool)])
>>> nested_tensor.any()
tensor(True)
>>> nested_tensor.any(dim=0)
tensor([False,  True])
Source code in danling/tensors/nested_tensor.py
Python
def any(self, dim: int | None = None, keepdim: bool = False) -> bool | Tensor | NestedTensor:
    r"""
    Tests if any elements in NestedTensor evaluate to True.

    Examples:
        >>> nested_tensor = NestedTensor([torch.zeros(2, dtype=torch.bool), torch.ones(3, dtype=torch.bool)])
        >>> nested_tensor.any()
        tensor(True)
        >>> nested_tensor.any(dim=0)
        tensor([False,  True])
    """
    return torch.any(self, dim=dim, keepdim=keepdim)
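
The doctest output above is consistent with `dim=0` producing one result per batch element. A plain-Python sketch of that reading follows; the helper names are hypothetical and nested lists stand in for the variable-length tensors.

```python
def any_per_element(batch):
    """One boolean per batch element: True if that element has any truthy value."""
    return [any(element) for element in batch]


def any_overall(batch):
    """A single boolean for the whole batch."""
    return any(any(element) for element in batch)


# Same data as the doctest: two False values, then three True values.
batch = [[False, False], [True, True, True]]
per_element = any_per_element(batch)
overall = any_overall(batch)
```

Because each element is reduced on its own, padding never enters the computation in this sketch.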

dim

Python
dim() -> int

Number of dimensions of the NestedTensor.

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
>>> nested_tensor.dim()
2
Source code in danling/tensors/nested_tensor.py
Python
def dim(self) -> int:
    r"""
    Number of dimensions of the NestedTensor.

    Examples:
        >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
        >>> nested_tensor.dim()
        2
    """
    if not hasattr(self, "_logical_shape"):
        with torch._C.DisableTorchFunctionSubclass():
            return len(torch.Tensor.size(self))
    return len(self._logical_shape)

max

Python
max(
    dim: int | None = None, keepdim: bool = False
) -> Tensor | NestedTensor

Return the maximum value, optionally along a given dimension.

Source code in danling/tensors/nested_tensor.py
Python
def max(self, dim: int | None = None, keepdim: bool = False) -> Tensor | NestedTensor:
    r"""Return the maximum value, optionally along a given dimension."""
    if dim is None:
        return torch.max(self)
    return torch.max(self, dim=dim, keepdim=keepdim)

mean

Python
mean(
    dim: int | None = None,
    keepdim: bool = False,
    *,
    dtype: dtype | None = None
) -> Tensor | NestedTensor

Return the mean value, optionally along a given dimension.

Source code in danling/tensors/nested_tensor.py
Python
def mean(
    self,
    dim: int | None = None,
    keepdim: bool = False,
    *,
    dtype: torch.dtype | None = None,  # type: ignore[name-defined]
) -> Tensor | NestedTensor:
    r"""Return the mean value, optionally along a given dimension."""
    return torch.mean(self, dim=dim, keepdim=keepdim, dtype=dtype)

min

Python
min(
    dim: int | None = None, keepdim: bool = False
) -> Tensor | NestedTensor

Return the minimum value, optionally along a given dimension.

Source code in danling/tensors/nested_tensor.py
Python
def min(self, dim: int | None = None, keepdim: bool = False) -> Tensor | NestedTensor:
    r"""Return the minimum value, optionally along a given dimension."""
    if dim is None:
        return torch.min(self)
    return torch.min(self, dim=dim, keepdim=keepdim)

numel

Python
numel() -> int

Number of elements in the NestedTensor.

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
>>> nested_tensor.numel()
5
Source code in danling/tensors/nested_tensor.py
Python
def numel(self) -> int:
    r"""
    Number of elements in the NestedTensor.

    Examples:
        >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
        >>> nested_tensor.numel()
        5
    """
    return self._values.numel()
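
Because `numel()` reads the packed values, it counts only real elements and not padding. The sketch below contrasts the two counts for one-dimensional sequences; both helper names are hypothetical and the sequence lengths stand in for the actual tensors.

```python
def packed_numel(lengths):
    """Number of real elements when sequences are stored back to back."""
    return sum(lengths)


def padded_numel(lengths):
    """Number of slots in a padded (batch, max_len) layout."""
    return len(lengths) * max(lengths) if lengths else 0


lengths = [3, 2]  # matches the doctest: [1, 2, 3] and [4, 5]
packed = packed_numel(lengths)
padded = padded_numel(lengths)
```

For the doctest input this gives 5 packed elements against 6 padded slots.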

permute

Python
permute(*dims) -> Self

Apply permutation to each tensor in the NestedTensor.

Parameters:

Name Type Description Default
*dims

The desired ordering of dimensions for the NestedTensor (including batch dimension).

()

Returns:

Name Type Description
NestedTensor Self

A new NestedTensor with each tensor permuted.

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.randn(3, 4, 5), torch.randn(2, 4, 5)])
>>> permuted = nested_tensor.permute(0, 3, 1, 2)
>>> permuted.shape
torch.Size([2, 5, 3, 4])
Source code in danling/tensors/nested_tensor.py
Python
def permute(self, *dims) -> Self:
    r"""
    Apply permutation to each tensor in the NestedTensor.

    Args:
        *dims: The desired ordering of dimensions for the NestedTensor (including batch dimension).

    Returns:
        NestedTensor: A new NestedTensor with each tensor permuted.

    Examples:
        >>> nested_tensor = NestedTensor([torch.randn(3, 4, 5), torch.randn(2, 4, 5)])
        >>> permuted = nested_tensor.permute(0, 3, 1, 2)
        >>> permuted.shape
        torch.Size([2, 5, 3, 4])
    """
    return torch.permute(self, dims)

moveaxis

Python
moveaxis(source, destination) -> Self

Move per-element dimensions to new positions.

Source code in danling/tensors/nested_tensor.py
Python
def moveaxis(self, source, destination) -> Self:
    r"""Move per-element dimensions to new positions."""
    return torch.moveaxis(self, source, destination)

movedim

Python
movedim(source, destination) -> Self

Alias for moveaxis().

Source code in danling/tensors/nested_tensor.py
Python
def movedim(self, source, destination) -> Self:
    r"""Alias for `moveaxis()`."""
    return torch.movedim(self, source, destination)

pin_memory

Python
pin_memory() -> Self

Pin the underlying tensor memory for faster host-to-device transfer.

Source code in danling/tensors/nested_tensor.py
Python
def pin_memory(self) -> Self:
    r"""Pin the underlying tensor memory for faster host-to-device transfer."""
    return type(self)._from_packed(
        self._values.pin_memory(),
        self._offsets,
        self._physical_shape,
        batch_first=self.batch_first,
        padding_value=self.padding_value,
        mask_value=self.mask_value,
        pin_memory=True,
        outer_size=self._logical_shape,
        packed_sizes=self._packed_sizes,
        element_shapes=self._element_shapes,
    )

prod

Python
prod(
    dim: int | None = None,
    keepdim: bool = False,
    *,
    dtype: dtype | None = None
) -> Tensor | NestedTensor

Return the product of elements, optionally along a given dimension.

Source code in danling/tensors/nested_tensor.py
Python
def prod(
    self,
    dim: int | None = None,
    keepdim: bool = False,
    *,
    dtype: torch.dtype | None = None,  # type: ignore[name-defined]
) -> Tensor | NestedTensor:
    r"""Return the product of elements, optionally along a given dimension."""
    return torch.prod(self, dim=dim, keepdim=keepdim, dtype=dtype)

requires_grad_

Python
requires_grad_(requires_grad: bool = True)

Enable or disable gradient computation in-place.

Source code in danling/tensors/nested_tensor.py
Python
def requires_grad_(self, requires_grad: bool = True):
    r"""Enable or disable gradient computation in-place."""
    self.requires_grad = requires_grad
    return self

reshape

Python
reshape(*shape) -> Self

Reshape each tensor in the NestedTensor.

Parameters:

Name Type Description Default
*shape

The desired size of each dimension for the underlying tensors.

()

Returns:

Name Type Description
NestedTensor Self

A new NestedTensor with each tensor reshaped.

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6], [7, 8]])])
>>> reshaped = nested_tensor.reshape(4)
>>> reshaped.shape
torch.Size([2, 4])
Source code in danling/tensors/nested_tensor.py
Python
def reshape(self, *shape) -> Self:
    r"""
    Reshape each tensor in the NestedTensor.

    Args:
        *shape: The desired size of each dimension for the underlying tensors.

    Returns:
        NestedTensor: A new NestedTensor with each tensor reshaped.

    Examples:
        >>> nested_tensor = NestedTensor([torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6], [7, 8]])])
        >>> reshaped = nested_tensor.reshape(4)
        >>> reshaped.shape
        torch.Size([2, 4])
    """
    if not shape:
        raise TypeError("reshape() missing shape")
    target_shape = shape[0] if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)) else shape
    return torch.reshape(self, target_shape)
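
The `target_shape` line lets callers pass the shape either as separate integers or as one sequence. A torch-free sketch of that normalization is shown below; `normalize_shape` is a hypothetical name, and `torch.Size` is left out of the `isinstance` check because it is a `tuple` subclass and is already covered.

```python
def normalize_shape(*shape):
    """Accept reshape(2, 2), reshape((2, 2)) or reshape([2, 2]) alike."""
    if not shape:
        raise TypeError("missing shape")
    # A single tuple or list argument is unwrapped; anything else is used as-is.
    if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
        return tuple(shape[0])
    return shape


from_varargs = normalize_shape(2, 2)
from_tuple = normalize_shape((2, 2))
from_list = normalize_shape([4])
from_int = normalize_shape(4)
```

All four call styles end up as a plain tuple of integers.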

flatten

Python
flatten(start_dim: int = 0, end_dim: int = -1)

Flatten each tensor in the NestedTensor.

Source code in danling/tensors/nested_tensor.py
Python
def flatten(self, start_dim: int = 0, end_dim: int = -1):
    r"""Flatten each tensor in the NestedTensor."""
    return torch.flatten(self, start_dim=start_dim, end_dim=end_dim)

flip

Python
flip(dims) -> Self

Flip each tensor in the NestedTensor along the given dimensions.

Source code in danling/tensors/nested_tensor.py
Python
def flip(self, dims) -> Self:
    r"""Flip each tensor in the NestedTensor along the given dimensions."""
    return torch.flip(self, dims)

size

Python
size(dim: int | None = None) -> Size | int

Returns the size of the NestedTensor.

Parameters:

Name Type Description Default
dim
int | None

If not specified, the returned value is a torch.Size, a subclass of tuple. If specified, returns an int holding the size of that dimension. Defaults to None.

None

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
>>> nested_tensor.size()
torch.Size([2, 3])
>>> nested_tensor.size(0)
2
>>> nested_tensor[1] = torch.tensor([4, 5, 6, 7])
>>> nested_tensor.shape
torch.Size([2, 4])
>>> nested_tensor.size(1)
4
Source code in danling/tensors/nested_tensor.py
Python
def size(self, dim: int | None = None) -> torch.Size | int:  # type: ignore[override, name-defined]
    r"""
    Returns the size of the `NestedTensor`.

    Args:
        dim: If not specified, the returned value is a `torch.Size`, a subclass of `tuple`.
            If specified, returns an `int` holding the size of that dimension.
            Defaults to `None`.

    Examples:
        >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
        >>> nested_tensor.size()
        torch.Size([2, 3])
        >>> nested_tensor.size(0)
        2
        >>> nested_tensor[1] = torch.tensor([4, 5, 6, 7])
        >>> nested_tensor.shape
        torch.Size([2, 4])
        >>> nested_tensor.size(1)
        4
    """
    if hasattr(self, "_logical_shape"):
        full_size = self._logical_shape
    else:
        with torch._C.DisableTorchFunctionSubclass():
            full_size = torch.Tensor.size(self)
    if dim is not None:
        dim = dim + len(full_size) if dim < 0 else dim
        return full_size[dim]
    return full_size
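
The doctest shapes are consistent with a logical shape made of the batch size followed by the per-dimension maximum over all elements. The sketch below reproduces that rule and the negative-index handling in plain Python. It assumes `batch_first=True` and elements of equal rank, and both function names are hypothetical.

```python
def logical_shape(element_shapes):
    """Batch size, then the largest extent seen in each element dimension."""
    if not element_shapes:
        return (0,)
    return (len(element_shapes), *(max(dims) for dims in zip(*element_shapes)))


def size(shape, dim=None):
    """Return the whole shape, or one entry with negative dims wrapped."""
    if dim is None:
        return shape
    dim = dim + len(shape) if dim < 0 else dim
    return shape[dim]


before = logical_shape([(3,), (2,)])  # elements [1, 2, 3] and [4, 5]
after = logical_shape([(3,), (4,)])   # second element replaced by 4 values
last = size(after, -1)
```

Replacing the second element with a longer one grows the second entry of the shape from 3 to 4, as in the doctest.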

sum

Python
sum(
    dim: int | Sequence[int] | None = None,
    keepdim: bool = False,
    *,
    dtype: dtype | None = None
) -> Tensor | NestedTensor

Returns the sum of each tensor over the given dimension(s).

Parameters:

Name Type Description Default
dim
int | Sequence[int] | None

The dimension or dimensions to reduce. If None, sum over all dimensions. Supports int, Sequence[int], or None. Negative dimensions are supported.

None
keepdim
bool

Whether to retain reduced dimensions with size 1.

False
dtype
dtype | None

The desired data type of returned tensor.

None

Returns:

Type Description
Tensor | NestedTensor

Tensor or NestedTensor depending on the dimensions being reduced.

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
>>> nested_tensor.sum()
tensor(15)
>>> nested_tensor.sum(dim=0)  # when dim=0, sum across batch dimension
tensor([6, 9])
>>> nested_tensor.sum(dim=1)
tensor([6, 9])
>>> nested_tensor.sum(dim=[0, 1])
tensor(15)
>>> nested_tensor.sum(dim=0, keepdim=True)
tensor([[6, 9]])
>>> nested_tensor.sum(dtype=torch.float32)
tensor(15.)
Source code in danling/tensors/nested_tensor.py
Python
def sum(
    self,
    dim: int | Sequence[int] | None = None,
    keepdim: bool = False,
    *,
    dtype: torch.dtype | None = None,  # type: ignore[name-defined]
) -> Tensor | NestedTensor:
    r"""
    Returns the sum of each tensor over the given dimension(s).

    Args:
        dim: The dimension or dimensions to reduce. If None, sum over all dimensions.
            Supports int, Sequence[int], or None. Negative dimensions are supported.
        keepdim: Whether to retain reduced dimensions with size 1.
        dtype: The desired data type of returned tensor.

    Returns:
        Tensor or NestedTensor depending on the dimensions being reduced.

    Examples:
        >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
        >>> nested_tensor.sum()
        tensor(15)
        >>> nested_tensor.sum(dim=0)  # when dim=0, sum across batch dimension
        tensor([6, 9])
        >>> nested_tensor.sum(dim=1)
        tensor([6, 9])
        >>> nested_tensor.sum(dim=[0, 1])
        tensor(15)
        >>> nested_tensor.sum(dim=0, keepdim=True)
        tensor([[6, 9]])
        >>> nested_tensor.sum(dtype=torch.float32)
        tensor(15.)
    """
    return torch.sum(self, dim=dim, keepdim=keepdim, dtype=dtype)

tolist

Python
tolist() -> list

Convert a NestedTensor to a list of lists of values.

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
>>> nested_tensor.tolist()
[[1, 2, 3], [4, 5]]
Source code in danling/tensors/nested_tensor.py
Python
def tolist(self) -> list:
    r"""
    Convert a NestedTensor to a list of lists of values.

    Examples:
        >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
        >>> nested_tensor.tolist()
        [[1, 2, 3], [4, 5]]
    """
    return [t.tolist() for t in self._storage]
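
The nested list returned by `tolist()` can also be read off a packed buffer and its cumulative offsets. The sketch below shows that relationship for one-dimensional elements only. `unpack` is a hypothetical helper, and the flat list and offsets are plain-Python stand-ins for `_values` and `_offsets`.

```python
def unpack(values, offsets):
    """Split a flat buffer into per-element lists using cumulative offsets."""
    return [values[start:stop] for start, stop in zip(offsets, offsets[1:])]


values = [1, 2, 3, 4, 5]  # all elements back to back, no padding
offsets = [0, 3, 5]       # B + 1 entries for B = 2 elements
nested = unpack(values, offsets)
```

Consecutive offset pairs give the start and end of each element, so `B + 1` offsets describe `B` elements.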

transpose

Python
transpose(dim0: int, dim1: int) -> Self

Transpose dimensions dim0 and dim1 for each tensor in the NestedTensor.

Parameters:

Name Type Description Default
dim0
int

First dimension to transpose (in NestedTensor coordinate system).

required
dim1
int

Second dimension to transpose (in NestedTensor coordinate system).

required

Returns:

Name Type Description
NestedTensor Self

A new NestedTensor with each tensor transposed.

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.randn(3, 4), torch.randn(2, 4)])
>>> # NestedTensor shape is [2, 3, 4], underlying tensors are [3, 4] and [2, 4]
>>> transposed = nested_tensor.transpose(1, 2)  # transpose dims 1 and 2
>>> transposed.shape  # batch dimension is still first
torch.Size([2, 4, 3])
Source code in danling/tensors/nested_tensor.py
Python
def transpose(self, dim0: int, dim1: int) -> Self:  # type: ignore[valid-type]
    r"""
    Transpose dimensions dim0 and dim1 for each tensor in the NestedTensor.

    Args:
        dim0: First dimension to transpose (in NestedTensor coordinate system).
        dim1: Second dimension to transpose (in NestedTensor coordinate system).

    Returns:
        NestedTensor: A new NestedTensor with each tensor transposed.

    Examples:
        >>> nested_tensor = NestedTensor([torch.randn(3, 4), torch.randn(2, 4)])
        >>> # NestedTensor shape is [2, 3, 4], underlying tensors are [3, 4] and [2, 4]
        >>> transposed = nested_tensor.transpose(1, 2)  # transpose dims 1 and 2
        >>> transposed.shape  # batch dimension is still first
        torch.Size([2, 4, 3])
    """
    return torch.transpose(self, dim0, dim1)
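
To make the coordinate system concrete, the sketch below transposes each two-dimensional element of a ragged batch in plain Python and recomputes the logical shape. Nested lists stand in for tensors, `batch_first=True` is assumed, and both helper names are hypothetical.

```python
def transpose_elements(batch):
    """Swap the two per-element dimensions (NestedTensor dims 1 and 2)."""
    return [[list(column) for column in zip(*element)] for element in batch]


def logical_shape(batch):
    """Batch size, then the maximum rows and columns over all elements."""
    rows = max(len(element) for element in batch)
    cols = max(len(element[0]) for element in batch)
    return (len(batch), rows, cols)


# Two elements shaped (3, 4) and (2, 4), as in the doctest.
batch = [
    [[0] * 4 for _ in range(3)],
    [[0] * 4 for _ in range(2)],
]
transposed = transpose_elements(batch)
shape_before = logical_shape(batch)
shape_after = logical_shape(transposed)
```

The batch dimension stays in front while the per-element extents swap, so `(2, 3, 4)` becomes `(2, 4, 3)`.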

swapaxes

Python
swapaxes(axis0: int, axis1: int) -> Self

Alias for transpose().

Source code in danling/tensors/nested_tensor.py
Python
def swapaxes(self, axis0: int, axis1: int) -> Self:
    r"""Alias for `transpose()`."""
    return torch.swapaxes(self, axis0, axis1)

swapdims

Python
swapdims(dim0: int, dim1: int) -> Self

Alias for swapaxes().

Source code in danling/tensors/nested_tensor.py
Python
def swapdims(self, dim0: int, dim1: int) -> Self:
    r"""Alias for `swapaxes()`."""
    return torch.swapdims(self, dim0, dim1)

squeeze

Python
squeeze(dim: int | None = None) -> Self

Squeeze singleton dimensions from each tensor in the NestedTensor.

Source code in danling/tensors/nested_tensor.py
Python
def squeeze(self, dim: int | None = None) -> Self:  # type: ignore[valid-type]
    r"""Squeeze singleton dimensions from each tensor in the NestedTensor."""
    return torch.squeeze(self, dim=dim)

unsqueeze

Python
unsqueeze(dim: int) -> Self

Unsqueeze each tensor in the NestedTensor by adding a singleton dimension at the specified position.

Parameters:

Name Type Description Default
dim
int

The dimension at which to add the singleton dimension. This is in the NestedTensor’s coordinate system (where dim 0 is the batch dimension).

required

Returns:

Name Type Description
NestedTensor Self

A new NestedTensor with each tensor unsqueezed at the specified dimension.

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
>>> # Original shape: [2, 3] (batch_size=2, max_seq_len=3)
>>> unsqueezed = nested_tensor.unsqueeze(1)
>>> unsqueezed.shape
torch.Size([2, 1, 3])
>>> # Now each underlying tensor has shape [1, seq_len] instead of [seq_len]
Python Console Session
>>> nested_tensor_2d = NestedTensor([torch.randn(3, 4), torch.randn(2, 4)])
>>> # Original shape: [2, 3, 4] (batch_size=2, max_len1=3, max_len2=4)
>>> unsqueezed_2d = nested_tensor_2d.unsqueeze(2)
>>> unsqueezed_2d.shape
torch.Size([2, 3, 1, 4])
>>> # Now each underlying tensor has shape [len1, 1, len2] instead of [len1, len2]
Source code in danling/tensors/nested_tensor.py
Python
def unsqueeze(self, dim: int) -> Self:  # type: ignore[valid-type]
    r"""
    Unsqueeze each tensor in the NestedTensor by adding a singleton dimension at the specified position.

    Args:
        dim: The dimension at which to add the singleton dimension. This is in the NestedTensor's
            coordinate system (where dim 0 is the batch dimension).

    Returns:
        NestedTensor: A new NestedTensor with each tensor unsqueezed at the specified dimension.

    Examples:
        >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
        >>> # Original shape: [2, 3] (batch_size=2, max_seq_len=3)
        >>> unsqueezed = nested_tensor.unsqueeze(1)
        >>> unsqueezed.shape
        torch.Size([2, 1, 3])
        >>> # Now each underlying tensor has shape [1, seq_len] instead of [seq_len]

        >>> nested_tensor_2d = NestedTensor([torch.randn(3, 4), torch.randn(2, 4)])
        >>> # Original shape: [2, 3, 4] (batch_size=2, max_len1=3, max_len2=4)
        >>> unsqueezed_2d = nested_tensor_2d.unsqueeze(2)
        >>> unsqueezed_2d.shape
        torch.Size([2, 3, 1, 4])
        >>> # Now each underlying tensor has shape [len1, 1, len2] instead of [len1, len2]
    """
    return torch.unsqueeze(self, dim)

unflatten

Python
unflatten(dim: int, sizes) -> Self

Unflatten one dimension of each tensor in the NestedTensor.

Source code in danling/tensors/nested_tensor.py
Python
def unflatten(self, dim: int, sizes) -> Self:  # type: ignore[valid-type]
    r"""Unflatten one dimension of each tensor in the NestedTensor."""
    return torch.unflatten(self, dim, sizes)

roll

Python
roll(shifts, dims=None) -> Self

Roll each tensor in the NestedTensor along the given dimensions.

Source code in danling/tensors/nested_tensor.py
Python
def roll(self, shifts, dims=None) -> Self:
    r"""Roll each tensor in the NestedTensor along the given dimensions."""
    return torch.roll(self, shifts, dims=dims)

rot90

Python
rot90(k: int = 1, dims: Sequence[int] = (0, 1)) -> Self

Rotate each tensor in the NestedTensor by 90 degrees in the given plane.

Source code in danling/tensors/nested_tensor.py
Python
def rot90(self, k: int = 1, dims: Sequence[int] = (0, 1)) -> Self:
    r"""Rotate each tensor in the NestedTensor by 90 degrees in the given plane."""
    return torch.rot90(self, k, dims)

view

Python
view(*shape) -> Self

View each tensor in the NestedTensor with a different shape.

Parameters:

Name Type Description Default
*shape

The desired size of each dimension for the underlying tensors.

()

Returns:

Name Type Description
NestedTensor Self

A new NestedTensor with each tensor viewed with the new shape.

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6], [7, 8]])])
>>> viewed = nested_tensor.view(4)  # View each 2x2 tensor as 4
>>> viewed.shape
torch.Size([2, 4])
>>> type(viewed).__name__
'NestedTensor'
Source code in danling/tensors/nested_tensor.py
Python
def view(self, *shape) -> Self:
    r"""
    View each tensor in the NestedTensor with a different shape.

    Args:
        *shape: The desired size of each dimension for the underlying tensors.

    Returns:
        NestedTensor: A new NestedTensor with each tensor viewed with the new shape.

    Examples:
        >>> nested_tensor = NestedTensor([torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6], [7, 8]])])
        >>> viewed = nested_tensor.view(4)  # View each 2x2 tensor as 4
        >>> viewed.shape
        torch.Size([2, 4])
        >>> type(viewed).__name__
        'NestedTensor'
    """
    if not shape:
        raise TypeError("view() missing shape")
    target_shape = shape[0] if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)) else shape
    return NestedTensorAtenRegistry[torch.ops.aten.view.default](
        torch.ops.aten.view.default, (self, list(target_shape)), {}
    )

where

Python
where(
    condition: Tensor | NestedTensor,
    other: Tensor | NestedTensor | SupportsFloat,
) -> Self

Return a NestedTensor of elements selected from either self or other, depending on condition.

Examples:

Python Console Session
>>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
>>> nested_tensor.where(nested_tensor > 2, torch.tensor([[6, 5, 4], [3, 2, 1]]))
NestedTensor([
    [6, 5, 3],
    [4, 5]
])
>>> nested_tensor.where(nested_tensor > 2, NestedTensor([[6, 5, 4], [3, 2]]))
NestedTensor([
    [6, 5, 3],
    [4, 5]
])
>>> nested_tensor.where(torch.tensor(True), NestedTensor([[6, 5, 4], [3, 2]]))
NestedTensor([
    [1, 2, 3],
    [4, 5]
])
Source code in danling/tensors/nested_tensor.py
Python
def where(self, condition: Tensor | NestedTensor, other: Tensor | NestedTensor | SupportsFloat) -> Self:
    r"""
    Return a NestedTensor of elements selected from either self or other, depending on condition.

    Examples:
        >>> nested_tensor = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
        >>> nested_tensor.where(nested_tensor > 2, torch.tensor([[6, 5, 4], [3, 2, 1]]))
        NestedTensor([
            [6, 5, 3],
            [4, 5]
        ])
        >>> nested_tensor.where(nested_tensor > 2, NestedTensor([[6, 5, 4], [3, 2]]))
        NestedTensor([
            [6, 5, 3],
            [4, 5]
        ])
        >>> nested_tensor.where(torch.tensor(True), NestedTensor([[6, 5, 4], [3, 2]]))
        NestedTensor([
            [1, 2, 3],
            [4, 5]
        ])
    """
    return torch.where(condition, self, other)
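
Since the method calls `torch.where(condition, self, other)`, values from `self` are kept where the condition holds and values from `other` fill in elsewhere. The plain-Python sketch below replays the first two doctest cases on nested lists. `ragged_where` is a hypothetical helper, and it handles a padded `other` simply by ignoring positions past each element's length.

```python
def ragged_where(condition, this, other):
    """Pick from `this` where condition is true, otherwise from `other`."""
    return [
        # zip stops at the shortest row, so padded tails in `other` are dropped.
        [a if c else b for c, a, b in zip(cond_row, this_row, other_row)]
        for cond_row, this_row, other_row in zip(condition, this, other)
    ]


data = [[1, 2, 3], [4, 5]]
condition = [[value > 2 for value in row] for row in data]

from_ragged = ragged_where(condition, data, [[6, 5, 4], [3, 2]])
from_padded = ragged_where(condition, data, [[6, 5, 4], [3, 2, 1]])
```

Both calls produce `[[6, 5, 3], [4, 5]]`, matching the doctest output.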