WIP: Tensor adaptors

This commit is contained in:
Andriy Roshchenko
2026-01-31 06:49:26 +00:00
parent f1fcb64b37
commit 7d182d8628

View File

@@ -220,9 +220,9 @@ struct TensorAdaptorsKernel
constexpr index_t M = 256;
constexpr index_t K = 128;
constexpr index_t M0 = 4;
constexpr index_t M1 = 64;
constexpr index_t M1 = M / M0; // M1 = 64
constexpr index_t K0 = 4;
constexpr index_t K1 = 32;
constexpr index_t K1 = K / K0; // K1 = 32
printf("Stage 1: [M=%ld, K=%ld] -> [M0=%ld, M1=%ld, K=%ld]\n",
static_cast<long>(M),
@@ -259,6 +259,30 @@ struct TensorAdaptorsKernel
printf("\nTest: [M0=2, M1=32, K0=3, K1=16] -> [M=%ld, K=%ld]\n",
static_cast<long>(bottom_idx[number<0>{}]),
static_cast<long>(bottom_idx[number<1>{}]));
auto test_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
make_unmerge_transform(make_tuple(number<K0>{}, number<K1>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}));
auto test_bottom_idx = test_adaptor.calculate_bottom_index(top_idx);
printf("\nTest1: [M0=2, M1=32, K0=3, K1=16] -> [M=%ld, K=%ld]\n",
static_cast<long>(test_bottom_idx[number<0>{}]),
static_cast<long>(test_bottom_idx[number<1>{}]));
/*
for(index_t m0 = 0; m0 < M0; m0++)
for(index_t m1 = 0; m1 < M1; m1++)
for(index_t k0 = 0; k0 < K0; k0++)
for(index_t k1 = 0; k1 < K1; k1++)
{
index_t m = m0 * M1 + m1;
index_t k = k0 * K1 + k1;
index_t input_offset = m * K + k;
// auto top_idx = make_tuple(m0, m1, k0, k1);
// auto bottom_idx = final_adaptor.calculate_bottom_index(top_idx); //
// [m, k]
}
*/
}
printf("\n\n");
@@ -278,9 +302,9 @@ struct TensorAdaptorsKernel
constexpr index_t M = 128;
constexpr index_t K = 64;
constexpr index_t M0 = 4;
constexpr index_t M1 = 32;
constexpr index_t M1 = M / M0; // M1 = 32;
constexpr index_t K0 = 4;
constexpr index_t K1 = 16;
constexpr index_t K1 = K / K0; // K1 = 16;
printf("Adaptor A: [M=%ld, K=%ld] -> [M0=%ld, M1=%ld, K=%ld]\n",
static_cast<long>(M),
@@ -311,7 +335,7 @@ struct TensorAdaptorsKernel
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2, 3>{}));
auto chained = chain_tensor_adaptors(adaptor_a, adaptor_b);
auto chained = chain_tensor_adaptors(adaptor_a, adaptor_b); // union of both
printf("\nChained: [M=%ld, K=%ld] -> [M0=%ld, M1=%ld, K0=%ld, K1=%ld]\n",
static_cast<long>(M),
@@ -343,8 +367,8 @@ struct TensorAdaptorsKernel
constexpr index_t NWaves = 4;
constexpr index_t MPerXDL = 16;
constexpr index_t NPerXDL = 16;
constexpr index_t M0 = M / (MWaves * MPerXDL);
constexpr index_t N0 = N / (NWaves * NPerXDL);
constexpr index_t M0 = M / (MWaves * MPerXDL); // M0 = 4
constexpr index_t N0 = N / (NWaves * NPerXDL); // N0 = 4
printf("GEMM C Matrix: [M=%ld, N=%ld]\n", static_cast<long>(M), static_cast<long>(N));
printf("Tiling: [M0=%ld, N0=%ld, M1=%ld, N1=%ld, M2=%ld, N2=%ld]\n",
@@ -360,8 +384,8 @@ struct TensorAdaptorsKernel
make_tuple(number<M0>{}, number<MWaves>{}, number<MPerXDL>{})),
make_unmerge_transform(
make_tuple(number<N0>{}, number<NWaves>{}, number<NPerXDL>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0, 2, 4>{}, sequence<1, 3, 5>{}));
make_tuple(sequence<0>{}, sequence<1>{}), // [M, N]
make_tuple(sequence<0, 2, 4>{}, sequence<1, 3, 5>{})); // [M0, N0, M1, N1, M2, N2]
auto top_idx = make_tuple(2, 3, 1, 2, 8, 12);
auto bottom_idx = adaptor.calculate_bottom_index(top_idx);
@@ -369,6 +393,61 @@ struct TensorAdaptorsKernel
static_cast<long>(bottom_idx[number<0>{}]),
static_cast<long>(bottom_idx[number<1>{}]));
// Check tile distribution
for(index_t m0 = 0; m0 < M0; m0++)
for(index_t n0 = 0; n0 < N0; n0++)
for(index_t m1 = 0; m1 < MWaves; m1++)
for(index_t n1 = 0; n1 < NWaves; n1++)
{
index_t tile_id = (m0 * N0 + n0) * (MWaves * NWaves) + (m1 * NWaves + n1);
auto tile_top_idx = make_tuple(m0, n0, m1, n1, 0, 0);
auto tile_bottom_idx = adaptor.calculate_bottom_index(tile_top_idx);
using Coord2D = std::pair<index_t, index_t>;
Coord2D top_corner = {tile_bottom_idx[number<0>{}],
tile_bottom_idx[number<1>{}]};
tile_top_idx = make_tuple(m0, n0, m1, n1, MPerXDL - 1, NPerXDL - 1);
tile_bottom_idx = adaptor.calculate_bottom_index(tile_top_idx);
Coord2D bottom_corner = {tile_bottom_idx[number<0>{}],
tile_bottom_idx[number<1>{}]};
printf("Tile ID %2ld: [M0=%ld, N0=%ld, M1=%ld, N1=%ld] = [%ld, %ld] -> "
"[%ld, %ld]\n",
static_cast<long>(tile_id),
static_cast<long>(m0),
static_cast<long>(n0),
static_cast<long>(m1),
static_cast<long>(n1),
static_cast<long>(top_corner.first),
static_cast<long>(top_corner.second),
static_cast<long>(bottom_corner.first),
static_cast<long>(bottom_corner.second));
if(tile_id == 171)
{
// print this tile
printf("\nTile %2ld:\n", static_cast<long>(tile_id));
for(index_t m2 = 0; m2 < MPerXDL; m2++)
{
for(index_t n2 = 0; n2 < NPerXDL; n2++)
{
auto coord_top = make_tuple(m0, n0, m1, n1, m2, n2);
auto coord_bottom = adaptor.calculate_bottom_index(coord_top);
printf(" [%3ld, %3ld] ",
static_cast<long>(coord_bottom[number<0>{}]),
static_cast<long>(coord_bottom[number<1>{}]));
}
printf("\n");
}
printf("\n");
}
}
printf("\n\n");
}
@@ -416,10 +495,10 @@ struct TensorAdaptorsKernel
for(index_t i = OrigSize; i < TotalSize; i++)
{
auto coord = make_tensor_coordinate(desc_padded, make_tuple(i));
index_t offset = coord.get_offset();
index_t offset = coord.get_offset(); // XXX: offset does not wrap!!!
DataType val = p_data[offset];
printf(" Index %ld -> offset %ld -> value %.1f (wraps around)\n",
printf(" Index %ld -> offset %ld -> value %.1f (wraps around\?\?\?)\n",
static_cast<long>(i),
static_cast<long>(offset),
static_cast<float>(val));
@@ -513,7 +592,7 @@ struct TensorAdaptorsKernel
// Add newline every 2 coordinates for readability
count++;
if(count % 4 == 0)
if(count % 2 == 0)
{
printf("\n");
}
@@ -523,7 +602,7 @@ struct TensorAdaptorsKernel
}
}
}
if(count % 4 != 0)
if(count % 2 != 0)
printf("\n");
printf("\nKey Observations:\n");