WIP: demonstrate_single_stage

This commit is contained in:
Andriy Roshchenko
2026-01-30 06:53:10 +00:00
parent b73e82e2e6
commit f1fcb64b37
2 changed files with 106 additions and 17 deletions

View File

@@ -502,10 +502,10 @@ int main()
std::cout << "Using GPU: " << props.name << "\n";
// Small tensor for demonstration
constexpr index_t H = 2;
constexpr index_t W = 3;
constexpr index_t C = 2;
constexpr index_t size = H * W * C;
constexpr index_t H = 2; // height
constexpr index_t W = 3; // width
constexpr index_t C = 2; // # of channels
constexpr index_t size = H * W * C; // total tensor size
std::cout << "\nTensor configuration:\n";
std::cout << " Shape: [" << H << ", " << W << ", " << C << "]\n";

View File

@@ -30,7 +30,7 @@ struct TensorAdaptorsKernel
// Part 1: make_single_stage_tensor_adaptor examples
CK_TILE_DEVICE static void demonstrate_single_stage()
{
printf("PART 1: make_single_stage_tensor_adaptor\n");
printf("PART 1: <demonstrate_single_stage> make_single_stage_tensor_adaptor\n");
printf("=========================================\n\n");
printf(
@@ -44,22 +44,48 @@ struct TensorAdaptorsKernel
constexpr index_t M = 128;
constexpr index_t K = 64;
constexpr index_t M0 = 4;
constexpr index_t M1 = 32;
constexpr index_t M1 = M / M0; // M1 = 32
printf("Input layout: [M=%ld, K=%ld]\n", static_cast<long>(M), static_cast<long>(K));
printf("Goal: Split M into [M0=%ld, M1=%ld] for tiling\n",
static_cast<long>(M0),
static_cast<long>(M1));
/*
for(index_t m = 0; m < M; m++)
{
index_t m0 = m / M1;
index_t m1 = m % M1;
// printf(" M=%ld -> [M0=%ld, M1=%ld]\n", static_cast<long>(m),
// static_cast<long>(m0), static_cast<long>(m1));
for(index_t k = 0; k < K; k++)
{
index_t input_offset = m * K + k;
index_t output_offset = (m0 * M1 + m1) * K + k;
// printf(" K=%ld: input_offset=%ld -> output_offset=%ld\n",
// static_cast<long>(k), static_cast<long>(input_offset),
// static_cast<long>(output_offset));
}
}
*/
auto transforms =
make_tuple(make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
make_pass_through_transform(number<K>{}));
auto lower_dims = make_tuple(sequence<0>{}, sequence<1>{});
auto upper_dims = make_tuple(sequence<0, 1>{}, sequence<2>{});
auto lower_dims = make_tuple(sequence<0>{}, sequence<1>{}); // 2D bottom index
auto upper_dims = make_tuple(sequence<0, 1>{}, sequence<2>{}); // 3D top index
auto adaptor = make_single_stage_tensor_adaptor(transforms, lower_dims, upper_dims);
/*
for(index_t m0 = 0; m0 < M0; m0++)
for(index_t m1 = 0; m1 < M1; m1++)
for(index_t k = 0; k < K; k++)
{
index_t m = m0 * M1 + m1;
index_t input_offset = m * K + k;
auto top_idx = make_tuple(m0, m1, k);
auto bottom_idx = adaptor.calculate_bottom_index(top_idx); // [m, k]
}
*/
printf("\nAdaptor created:\n");
printf(" Input: [M, K] = [%ld, %ld]\n", static_cast<long>(M), static_cast<long>(K));
printf(" Output: [M0, M1, K] = [%ld, %ld, %ld]\n",
@@ -68,7 +94,7 @@ struct TensorAdaptorsKernel
static_cast<long>(K));
auto top_idx = make_tuple(1, 16, 32);
auto bottom_idx = adaptor.calculate_bottom_index(top_idx);
auto bottom_idx = adaptor.calculate_bottom_index(top_idx); //[1 * M1 + 16, 32]
printf("\nTest: [M0=1, M1=16, K=32] -> [M=%ld, K=%ld]\n",
static_cast<long>(bottom_idx[number<0>{}]),
@@ -84,9 +110,9 @@ struct TensorAdaptorsKernel
constexpr index_t M = 256;
constexpr index_t N = 256;
constexpr index_t M0 = 4;
constexpr index_t M1 = 64;
constexpr index_t M1 = M / M0; // M1 = 64
constexpr index_t N0 = 4;
constexpr index_t N1 = 64;
constexpr index_t N1 = N / N0; // N1 = 64
printf("Input: [M=%ld, N=%ld]\n", static_cast<long>(M), static_cast<long>(N));
printf("Output: [M0=%ld, N0=%ld, M1=%ld, N1=%ld] (interleaved)\n",
@@ -95,19 +121,82 @@ struct TensorAdaptorsKernel
static_cast<long>(M1),
static_cast<long>(N1));
for(index_t m = 0; m < M; m++)
{
index_t m0 = m / M1;
index_t m1 = m % M1;
for(index_t n = 0; n < N; n++)
{
index_t n0 = n / N1;
index_t n1 = n % N1;
if(m0 == 2 && n0 == 3 && m1 == 16 && n1 == 32)
{
index_t input_offset = m * N + n;
index_t natural_output_offset =
m0 * M1 * N0 * N1 + m1 * N0 * N1 + n0 * N1 + n1;
index_t transf_output_offset =
m0 * M1 * N0 * N1 + n0 * M1 * N1 + m1 * N1 + n1;
printf(
"\n[m=%ld, n=%ld] -> [m0=%ld, n0=%ld, m1=%ld, n1=%ld] -> [%ld*M1+%ld, "
"%ld*N1+%ld] | input_offset=%ld=%ld*N+%ld -> "
"natural_output_offset=%ld=%ld * M1 * N0 * N1 + %ld * N0 * N1 + %ld * "
"N1 + %ld -> transf_output_offset=%ld=%ld * M1 * N0 * N1 + %ld * M1 * "
"N1 + %ld * N1 + %ld",
static_cast<long>(m),
static_cast<long>(n),
static_cast<long>(m0),
static_cast<long>(n0),
static_cast<long>(m1),
static_cast<long>(n1),
static_cast<long>(m0),
static_cast<long>(m1),
static_cast<long>(n0),
static_cast<long>(n1),
static_cast<long>(input_offset),
static_cast<long>(m),
static_cast<long>(n),
static_cast<long>(natural_output_offset),
static_cast<long>(m0),
static_cast<long>(m1),
static_cast<long>(n0),
static_cast<long>(n1),
static_cast<long>(transf_output_offset),
static_cast<long>(m0),
static_cast<long>(n0),
static_cast<long>(m1),
static_cast<long>(n1));
}
}
}
auto transforms =
make_tuple(make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
make_unmerge_transform(make_tuple(number<N0>{}, number<N1>{})));
auto lower_dims = make_tuple(sequence<0>{}, sequence<1>{});
auto upper_dims = make_tuple(sequence<0, 2>{}, // M splits to dims 0,2
sequence<1, 3>{} // N splits to dims 1,3
);
auto lower_dims = make_tuple(sequence<0>{}, sequence<1>{}); // 2D bottom index
auto upper_dims = make_tuple(sequence<0, 2>{}, // M splits to dims 0,2
sequence<1, 3>{} // N splits to dims 1,3
); // 4D top index [M0, N0, M1, N1]
auto adaptor = make_single_stage_tensor_adaptor(transforms, lower_dims, upper_dims);
/*
for(index_t m0 = 0; m0 < M0; m0++)
for(index_t n0 = 0; n0 < N0; n0++)
for(index_t m1 = 0; m1 < M1; m1++)
for(index_t n1 = 0; n1 < N1; n1++)
{
index_t m = m0 * M1 + m1;
index_t n = n0 * N1 + n1;
index_t input_offset = m * N + n;
// auto top_idx = make_tuple(m0, n0, m1, n1);
// auto bottom_idx = adaptor.calculate_bottom_index(top_idx); // [m, n]
}
*/
auto top_idx = make_tuple(2, 3, 16, 32);
auto bottom_idx = adaptor.calculate_bottom_index(top_idx);
auto bottom_idx = adaptor.calculate_bottom_index(top_idx); //[2 * M1 + 16, 3 * N1 + 32]
// NOTE: The bottom index is computed by fusing top dimensions 0 and 2 into M, and top dimensions 1 and 3 into N
printf("\nTest: [M0=2, N0=3, M1=16, N1=32] -> [M=%ld, N=%ld]\n",
static_cast<long>(bottom_idx[number<0>{}]),
static_cast<long>(bottom_idx[number<1>{}]));