mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
WIP: Tensor adaptors
This commit is contained in:
@@ -220,9 +220,9 @@ struct TensorAdaptorsKernel
|
||||
constexpr index_t M = 256;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t M0 = 4;
|
||||
constexpr index_t M1 = 64;
|
||||
constexpr index_t M1 = M / M0; // M1 = 64
|
||||
constexpr index_t K0 = 4;
|
||||
constexpr index_t K1 = 32;
|
||||
constexpr index_t K1 = K / K0; // K1 = 32
|
||||
|
||||
printf("Stage 1: [M=%ld, K=%ld] -> [M0=%ld, M1=%ld, K=%ld]\n",
|
||||
static_cast<long>(M),
|
||||
@@ -259,6 +259,30 @@ struct TensorAdaptorsKernel
|
||||
printf("\nTest: [M0=2, M1=32, K0=3, K1=16] -> [M=%ld, K=%ld]\n",
|
||||
static_cast<long>(bottom_idx[number<0>{}]),
|
||||
static_cast<long>(bottom_idx[number<1>{}]));
|
||||
|
||||
auto test_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
|
||||
make_unmerge_transform(make_tuple(number<K0>{}, number<K1>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}));
|
||||
auto test_bottom_idx = test_adaptor.calculate_bottom_index(top_idx);
|
||||
printf("\nTest1: [M0=2, M1=32, K0=3, K1=16] -> [M=%ld, K=%ld]\n",
|
||||
static_cast<long>(test_bottom_idx[number<0>{}]),
|
||||
static_cast<long>(test_bottom_idx[number<1>{}]));
|
||||
/*
|
||||
for(index_t m0 = 0; m0 < M0; m0++)
|
||||
for(index_t m1 = 0; m1 < M1; m1++)
|
||||
for(index_t k0 = 0; k0 < K0; k0++)
|
||||
for(index_t k1 = 0; k1 < K1; k1++)
|
||||
{
|
||||
index_t m = m0 * M1 + m1;
|
||||
index_t k = k0 * K1 + k1;
|
||||
index_t input_offset = m * K + k;
|
||||
// auto top_idx = make_tuple(m0, m1, k0, k1);
|
||||
// auto bottom_idx = final_adaptor.calculate_bottom_index(top_idx); //
|
||||
// [m, k]
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
printf("\n\n");
|
||||
@@ -278,9 +302,9 @@ struct TensorAdaptorsKernel
|
||||
constexpr index_t M = 128;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t M0 = 4;
|
||||
constexpr index_t M1 = 32;
|
||||
constexpr index_t M1 = M / M0; // M1 = 32;
|
||||
constexpr index_t K0 = 4;
|
||||
constexpr index_t K1 = 16;
|
||||
constexpr index_t K1 = K / K0; // K1 = 16;
|
||||
|
||||
printf("Adaptor A: [M=%ld, K=%ld] -> [M0=%ld, M1=%ld, K=%ld]\n",
|
||||
static_cast<long>(M),
|
||||
@@ -311,7 +335,7 @@ struct TensorAdaptorsKernel
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2, 3>{}));
|
||||
|
||||
auto chained = chain_tensor_adaptors(adaptor_a, adaptor_b);
|
||||
auto chained = chain_tensor_adaptors(adaptor_a, adaptor_b); // union of both
|
||||
|
||||
printf("\nChained: [M=%ld, K=%ld] -> [M0=%ld, M1=%ld, K0=%ld, K1=%ld]\n",
|
||||
static_cast<long>(M),
|
||||
@@ -343,8 +367,8 @@ struct TensorAdaptorsKernel
|
||||
constexpr index_t NWaves = 4;
|
||||
constexpr index_t MPerXDL = 16;
|
||||
constexpr index_t NPerXDL = 16;
|
||||
constexpr index_t M0 = M / (MWaves * MPerXDL);
|
||||
constexpr index_t N0 = N / (NWaves * NPerXDL);
|
||||
constexpr index_t M0 = M / (MWaves * MPerXDL); // M0 = 4
|
||||
constexpr index_t N0 = N / (NWaves * NPerXDL); // N0 = 4
|
||||
|
||||
printf("GEMM C Matrix: [M=%ld, N=%ld]\n", static_cast<long>(M), static_cast<long>(N));
|
||||
printf("Tiling: [M0=%ld, N0=%ld, M1=%ld, N1=%ld, M2=%ld, N2=%ld]\n",
|
||||
@@ -360,8 +384,8 @@ struct TensorAdaptorsKernel
|
||||
make_tuple(number<M0>{}, number<MWaves>{}, number<MPerXDL>{})),
|
||||
make_unmerge_transform(
|
||||
make_tuple(number<N0>{}, number<NWaves>{}, number<NPerXDL>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0, 2, 4>{}, sequence<1, 3, 5>{}));
|
||||
make_tuple(sequence<0>{}, sequence<1>{}), // [M, N]
|
||||
make_tuple(sequence<0, 2, 4>{}, sequence<1, 3, 5>{})); // [M0, N0, M1, N1, M2, N2]
|
||||
|
||||
auto top_idx = make_tuple(2, 3, 1, 2, 8, 12);
|
||||
auto bottom_idx = adaptor.calculate_bottom_index(top_idx);
|
||||
@@ -369,6 +393,61 @@ struct TensorAdaptorsKernel
|
||||
static_cast<long>(bottom_idx[number<0>{}]),
|
||||
static_cast<long>(bottom_idx[number<1>{}]));
|
||||
|
||||
// Check tile distribution
|
||||
for(index_t m0 = 0; m0 < M0; m0++)
|
||||
for(index_t n0 = 0; n0 < N0; n0++)
|
||||
for(index_t m1 = 0; m1 < MWaves; m1++)
|
||||
for(index_t n1 = 0; n1 < NWaves; n1++)
|
||||
{
|
||||
index_t tile_id = (m0 * N0 + n0) * (MWaves * NWaves) + (m1 * NWaves + n1);
|
||||
|
||||
auto tile_top_idx = make_tuple(m0, n0, m1, n1, 0, 0);
|
||||
auto tile_bottom_idx = adaptor.calculate_bottom_index(tile_top_idx);
|
||||
|
||||
using Coord2D = std::pair<index_t, index_t>;
|
||||
|
||||
Coord2D top_corner = {tile_bottom_idx[number<0>{}],
|
||||
tile_bottom_idx[number<1>{}]};
|
||||
|
||||
tile_top_idx = make_tuple(m0, n0, m1, n1, MPerXDL - 1, NPerXDL - 1);
|
||||
tile_bottom_idx = adaptor.calculate_bottom_index(tile_top_idx);
|
||||
|
||||
Coord2D bottom_corner = {tile_bottom_idx[number<0>{}],
|
||||
tile_bottom_idx[number<1>{}]};
|
||||
|
||||
printf("Tile ID %2ld: [M0=%ld, N0=%ld, M1=%ld, N1=%ld] = [%ld, %ld] -> "
|
||||
"[%ld, %ld]\n",
|
||||
static_cast<long>(tile_id),
|
||||
static_cast<long>(m0),
|
||||
static_cast<long>(n0),
|
||||
static_cast<long>(m1),
|
||||
static_cast<long>(n1),
|
||||
static_cast<long>(top_corner.first),
|
||||
static_cast<long>(top_corner.second),
|
||||
static_cast<long>(bottom_corner.first),
|
||||
static_cast<long>(bottom_corner.second));
|
||||
|
||||
if(tile_id == 171)
|
||||
{
|
||||
// print this tile
|
||||
printf("\nTile %2ld:\n", static_cast<long>(tile_id));
|
||||
|
||||
for(index_t m2 = 0; m2 < MPerXDL; m2++)
|
||||
{
|
||||
for(index_t n2 = 0; n2 < NPerXDL; n2++)
|
||||
{
|
||||
auto coord_top = make_tuple(m0, n0, m1, n1, m2, n2);
|
||||
auto coord_bottom = adaptor.calculate_bottom_index(coord_top);
|
||||
printf(" [%3ld, %3ld] ",
|
||||
static_cast<long>(coord_bottom[number<0>{}]),
|
||||
static_cast<long>(coord_bottom[number<1>{}]));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
printf("\n\n");
|
||||
}
|
||||
|
||||
@@ -416,10 +495,10 @@ struct TensorAdaptorsKernel
|
||||
for(index_t i = OrigSize; i < TotalSize; i++)
|
||||
{
|
||||
auto coord = make_tensor_coordinate(desc_padded, make_tuple(i));
|
||||
index_t offset = coord.get_offset();
|
||||
index_t offset = coord.get_offset(); // XXX: offset does not wrap!!!
|
||||
DataType val = p_data[offset];
|
||||
|
||||
printf(" Index %ld -> offset %ld -> value %.1f (wraps around)\n",
|
||||
printf(" Index %ld -> offset %ld -> value %.1f (wraps around\?\?\?)\n",
|
||||
static_cast<long>(i),
|
||||
static_cast<long>(offset),
|
||||
static_cast<float>(val));
|
||||
@@ -513,7 +592,7 @@ struct TensorAdaptorsKernel
|
||||
|
||||
// Add newline every 4 coordinates for readability
|
||||
count++;
|
||||
if(count % 4 == 0)
|
||||
if(count % 2 == 0)
|
||||
{
|
||||
printf("\n");
|
||||
}
|
||||
@@ -523,7 +602,7 @@ struct TensorAdaptorsKernel
|
||||
}
|
||||
}
|
||||
}
|
||||
if(count % 4 != 0)
|
||||
if(count % 2 != 0)
|
||||
printf("\n");
|
||||
|
||||
printf("\nKey Observations:\n");
|
||||
|
||||
Reference in New Issue
Block a user