Improved CShuffle test.

This commit is contained in:
Ville Pietilä
2025-09-19 13:45:39 +00:00
parent af6838e5dc
commit 437599c517

View File

@@ -637,9 +637,6 @@ __global__ void test_4x4_matrix_get_2x2_blocks_with_sfc_and_lds_kernel(int* inpu
// similar to the output of the MFMA when 2 conv groups are merged.
// We want to copy only the diagonal 2x2 blocks to the output, similar to the epilogue
// part of batched iGEMM for 2 conv groups.
constexpr index_t MPerIterationShuffle = 2;
constexpr index_t NPerIterationShuffle = 2;
constexpr index_t NumGroupsToMerge = 2; // Number of merged groups
// Create tensor descriptor for the output 4x2 matrix (2 diagonal blocks stacked vertically)
@@ -652,25 +649,31 @@ __global__ void test_4x4_matrix_get_2x2_blocks_with_sfc_and_lds_kernel(int* inpu
number<NPerBlock/NumGroupsToMerge>{}),
{0, 0}); // We have only threadblock
//------------------------------------------------------------
// CShuffle epilogue similation
//------------------------------------------------------------
// Allocate and prepare LDS
__shared__ char p_smem[MPerBlock * NPerBlock * sizeof(int)];
constexpr index_t MPerThread = 2;
constexpr index_t NPerThread = 2;
// Initialize the LDS to zero
if(threadIdx.x == 0 && blockIdx.x == 0)
{
int* lds_data = reinterpret_cast<int*>(p_smem);
for (index_t i = 0; i < MPerBlock; i++)
{
for (index_t j = 0; j < NPerBlock; j++)
{
lds_data[i * NPerBlock + j] = 0;
}
}
}
__syncthreads();
constexpr auto lds_tile_encoding = tile_distribution_encoding<
sequence<>,
tuple<
sequence<1, 1, 2, MPerThread>,
sequence<1, 1, 2, NPerThread>>,
tuple<sequence<1, 2>, sequence<1,2>>,
tuple<sequence<1, 1>, sequence<2,2>>,
sequence<1, 1, 2, 2>,
sequence<0, 3, 0, 3>>{};
constexpr index_t MPerIterationShuffle = 2;
constexpr index_t NPerIterationShuffle = 2;
auto lds_tile_distribution = make_static_tile_distribution(lds_tile_encoding);
auto lds_tile = make_static_distributed_tensor<int>(lds_tile_distribution);
auto lds_tile = make_static_distributed_tensor<int>(distribution);
constexpr auto lds_block_desc = make_naive_tensor_descriptor(
make_tuple(number<MPerBlock>{}, number<NPerBlock>{}),
@@ -681,13 +684,14 @@ __global__ void test_4x4_matrix_get_2x2_blocks_with_sfc_and_lds_kernel(int* inpu
auto in_lds_window = make_tile_window(
o_lds_block,
make_tuple(number<MPerThread>{}, number<NPerThread>{}),
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0},
lds_tile_distribution);
distribution);
// Row major thread to data mapping in LDS
auto out_lds_window = make_tile_window(
o_lds_block,
make_tuple(number<MPerThread>{}, number<NPerThread>{}),
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0});
// Set-up traversing the 2x2 blocks
@@ -703,8 +707,8 @@ __global__ void test_4x4_matrix_get_2x2_blocks_with_sfc_and_lds_kernel(int* inpu
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<
4, // Block size
MPerThread,
NPerThread,
MPerIterationShuffle,
NPerIterationShuffle,
2, // Vector size
tile_distribution_pattern::sparse_row,
1>; // Number of wave groups
@@ -712,54 +716,36 @@ __global__ void test_4x4_matrix_get_2x2_blocks_with_sfc_and_lds_kernel(int* inpu
constexpr auto output_tile_distribution =
TileEncodingPattern::make_2d_static_tile_distribution();
// Copy the diagonal 2x2 block from register to global memrory via LDS.
// Copy the tile at one go from register to LDS.
block_sync_lds();
lds_tile.get_thread_buffer() = input_tensor.get_y_sliced_thread_data(
sequence<0, 0, 0, 0>{},
sequence<1, 1, MPerIterationShuffle, MPerIterationShuffle>{});
store_tile(in_lds_window, lds_tile);
block_sync_lds();
// Print the contents of LDS
if (threadIdx.x == 0 && blockIdx.x == 0)
{
printf("LDS contents:\n");
int* lds_data = reinterpret_cast<int*>(p_smem);
for (index_t i = 0; i < 4; i++)
{
for (index_t j = 0; j < 4; j++)
{
printf("%3d ", lds_data[i * 4 + j]);
}
printf("\n");
}
}
// For the output tensor, we need to copy only the diagonal 2x2 blocks to global memory.
static_for<0, NumGroupsToMerge, 1>{}
(
[&](auto group)
{
constexpr auto iAccess = number<group * NumGroupsToMerge + group>{};
if constexpr(group == 0)
{
block_sync_lds();
constexpr auto idx_y_start = SFC::get_index(iAccess);
static_assert(idx_y_start.size() == 2, "wrong!");
printf("Thread id: %u, Group %d, idx_y_start: (%d, %d)\n",
threadIdx.x, group.value, idx_y_start.at(number<0>{}).value, idx_y_start.at(number<1>{}).value);
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
printf("Thread id: %u, Group %d, mIter %d, nIter %d\n", threadIdx.x, group.value, mIter.value, nIter.value);
__syncthreads();
lds_tile.get_thread_buffer() = input_tensor.get_y_sliced_thread_data(
sequence<0, 0,
mIter * MPerIterationShuffle,
nIter * NPerIterationShuffle>{},
sequence<1, 1, MPerThread, NPerThread>{});
store_tile(in_lds_window, lds_tile);
block_sync_lds();
}
// Print the contents of LDS
if (threadIdx.x == 0 && blockIdx.x == 0)
{
printf("LDS contents after loading group %d:\n", group.value);
int* lds_data = reinterpret_cast<int*>(p_smem);
for (index_t i = 0; i < 4; i++)
{
for (index_t j = 0; j < 4; j++)
{
printf("%3d ", lds_data[i * 4 + j]);
}
printf("\n");
}
}
auto out_tensor = load_tile(make_tile_window(out_lds_window, output_tile_distribution));
store_tile(output_window, out_tensor);
@@ -769,12 +755,14 @@ __global__ void test_4x4_matrix_get_2x2_blocks_with_sfc_and_lds_kernel(int* inpu
if (threadIdx.x == 0 && blockIdx.x == 0)
{
printf("Output tensor contents after loading group %d:\n", group.value);
for (index_t i = 0; i < 4; i++)
{
for (index_t j = 0; j < 2; j++)
{
printf("Output(%d, %d) = %d\n", i, j, output[i * 2 + j]);
printf("%3d", output[i * 2 + j]);
}
printf("\n");
}
}
__syncthreads();
@@ -785,7 +773,7 @@ __global__ void test_4x4_matrix_get_2x2_blocks_with_sfc_and_lds_kernel(int* inpu
constexpr auto step = SFC_dram::get_forward_step(group);
move_tile_window(output_window, {step.at(number<0>{}), step.at(number<1>{})});
// TODO: This should not be needed.
constexpr auto iAccess = number<group * NumGroupsToMerge + group>{};
constexpr auto next_iAccess = number<(group+1) * NumGroupsToMerge + (group+1)>{};
constexpr auto step_lds = SFC::get_step_between(iAccess, next_iAccess);
move_tile_window(out_lds_window, {step_lds.at(number<0>{}), step_lds.at(number<1>{})});