mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 20:27:42 +00:00
Improved CShuffle test.
This commit is contained in:
@@ -637,9 +637,6 @@ __global__ void test_4x4_matrix_get_2x2_blocks_with_sfc_and_lds_kernel(int* inpu
|
||||
// similar to the output of the MFMA when 2 conv groups are merged.
|
||||
// We want to copy only the diagonal 2x2 blocks to the output, similar to the epilogue
|
||||
// part of batched iGEMM for 2 conv groups.
|
||||
|
||||
constexpr index_t MPerIterationShuffle = 2;
|
||||
constexpr index_t NPerIterationShuffle = 2;
|
||||
constexpr index_t NumGroupsToMerge = 2; // Number of merged groups
|
||||
|
||||
// Create tensor descriptor for the output 4x2 matrix (2 diagonal blocks stacked vertically)
|
||||
@@ -652,25 +649,31 @@ __global__ void test_4x4_matrix_get_2x2_blocks_with_sfc_and_lds_kernel(int* inpu
|
||||
number<NPerBlock/NumGroupsToMerge>{}),
|
||||
{0, 0}); // We have only threadblock
|
||||
|
||||
//------------------------------------------------------------
|
||||
// CShuffle epilogue simulation
|
||||
//------------------------------------------------------------
|
||||
|
||||
// Allocate and prepare LDS
|
||||
__shared__ char p_smem[MPerBlock * NPerBlock * sizeof(int)];
|
||||
|
||||
constexpr index_t MPerThread = 2;
|
||||
constexpr index_t NPerThread = 2;
|
||||
// Initialize the LDS to zero
|
||||
if(threadIdx.x == 0 && blockIdx.x == 0)
|
||||
{
|
||||
int* lds_data = reinterpret_cast<int*>(p_smem);
|
||||
for (index_t i = 0; i < MPerBlock; i++)
|
||||
{
|
||||
for (index_t j = 0; j < NPerBlock; j++)
|
||||
{
|
||||
lds_data[i * NPerBlock + j] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
constexpr auto lds_tile_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<
|
||||
sequence<1, 1, 2, MPerThread>,
|
||||
sequence<1, 1, 2, NPerThread>>,
|
||||
tuple<sequence<1, 2>, sequence<1,2>>,
|
||||
tuple<sequence<1, 1>, sequence<2,2>>,
|
||||
sequence<1, 1, 2, 2>,
|
||||
sequence<0, 3, 0, 3>>{};
|
||||
constexpr index_t MPerIterationShuffle = 2;
|
||||
constexpr index_t NPerIterationShuffle = 2;
|
||||
|
||||
auto lds_tile_distribution = make_static_tile_distribution(lds_tile_encoding);
|
||||
|
||||
auto lds_tile = make_static_distributed_tensor<int>(lds_tile_distribution);
|
||||
auto lds_tile = make_static_distributed_tensor<int>(distribution);
|
||||
|
||||
constexpr auto lds_block_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(number<MPerBlock>{}, number<NPerBlock>{}),
|
||||
@@ -681,13 +684,14 @@ __global__ void test_4x4_matrix_get_2x2_blocks_with_sfc_and_lds_kernel(int* inpu
|
||||
|
||||
auto in_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
make_tuple(number<MPerThread>{}, number<NPerThread>{}),
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
{0, 0},
|
||||
lds_tile_distribution);
|
||||
distribution);
|
||||
|
||||
// Row major thread to data mapping in LDS
|
||||
auto out_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
make_tuple(number<MPerThread>{}, number<NPerThread>{}),
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
{0, 0});
|
||||
|
||||
// Set-up traversing the 2x2 blocks
|
||||
@@ -703,8 +707,8 @@ __global__ void test_4x4_matrix_get_2x2_blocks_with_sfc_and_lds_kernel(int* inpu
|
||||
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<
|
||||
4, // Block size
|
||||
MPerThread,
|
||||
NPerThread,
|
||||
MPerIterationShuffle,
|
||||
NPerIterationShuffle,
|
||||
2, // Vector size
|
||||
tile_distribution_pattern::sparse_row,
|
||||
1>; // Number of wave groups
|
||||
@@ -712,54 +716,36 @@ __global__ void test_4x4_matrix_get_2x2_blocks_with_sfc_and_lds_kernel(int* inpu
|
||||
constexpr auto output_tile_distribution =
|
||||
TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
|
||||
// Copy the diagonal 2x2 block from register to global memory via LDS.
|
||||
// Copy the tile in one go from registers to LDS.
|
||||
block_sync_lds();
|
||||
|
||||
lds_tile.get_thread_buffer() = input_tensor.get_y_sliced_thread_data(
|
||||
sequence<0, 0, 0, 0>{},
|
||||
sequence<1, 1, MPerIterationShuffle, MPerIterationShuffle>{});
|
||||
|
||||
store_tile(in_lds_window, lds_tile);
|
||||
block_sync_lds();
|
||||
|
||||
// Print the contents of LDS
|
||||
if (threadIdx.x == 0 && blockIdx.x == 0)
|
||||
{
|
||||
printf("LDS contents:\n");
|
||||
int* lds_data = reinterpret_cast<int*>(p_smem);
|
||||
for (index_t i = 0; i < 4; i++)
|
||||
{
|
||||
for (index_t j = 0; j < 4; j++)
|
||||
{
|
||||
printf("%3d ", lds_data[i * 4 + j]);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
// For the output tensor, we need to copy only the diagonal 2x2 blocks to global memory.
|
||||
static_for<0, NumGroupsToMerge, 1>{}
|
||||
(
|
||||
[&](auto group)
|
||||
{
|
||||
constexpr auto iAccess = number<group * NumGroupsToMerge + group>{};
|
||||
if constexpr(group == 0)
|
||||
{
|
||||
block_sync_lds();
|
||||
constexpr auto idx_y_start = SFC::get_index(iAccess);
|
||||
|
||||
static_assert(idx_y_start.size() == 2, "wrong!");
|
||||
|
||||
printf("Thread id: %u, Group %d, idx_y_start: (%d, %d)\n",
|
||||
threadIdx.x, group.value, idx_y_start.at(number<0>{}).value, idx_y_start.at(number<1>{}).value);
|
||||
|
||||
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
|
||||
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
|
||||
|
||||
printf("Thread id: %u, Group %d, mIter %d, nIter %d\n", threadIdx.x, group.value, mIter.value, nIter.value);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
lds_tile.get_thread_buffer() = input_tensor.get_y_sliced_thread_data(
|
||||
sequence<0, 0,
|
||||
mIter * MPerIterationShuffle,
|
||||
nIter * NPerIterationShuffle>{},
|
||||
sequence<1, 1, MPerThread, NPerThread>{});
|
||||
|
||||
store_tile(in_lds_window, lds_tile);
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
// Print the contents of LDS
|
||||
if (threadIdx.x == 0 && blockIdx.x == 0)
|
||||
{
|
||||
printf("LDS contents after loading group %d:\n", group.value);
|
||||
int* lds_data = reinterpret_cast<int*>(p_smem);
|
||||
for (index_t i = 0; i < 4; i++)
|
||||
{
|
||||
for (index_t j = 0; j < 4; j++)
|
||||
{
|
||||
printf("%3d ", lds_data[i * 4 + j]);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
auto out_tensor = load_tile(make_tile_window(out_lds_window, output_tile_distribution));
|
||||
|
||||
store_tile(output_window, out_tensor);
|
||||
@@ -769,12 +755,14 @@ __global__ void test_4x4_matrix_get_2x2_blocks_with_sfc_and_lds_kernel(int* inpu
|
||||
if (threadIdx.x == 0 && blockIdx.x == 0)
|
||||
{
|
||||
|
||||
printf("Output tensor contents after loading group %d:\n", group.value);
|
||||
for (index_t i = 0; i < 4; i++)
|
||||
{
|
||||
for (index_t j = 0; j < 2; j++)
|
||||
{
|
||||
printf("Output(%d, %d) = %d\n", i, j, output[i * 2 + j]);
|
||||
printf("%3d", output[i * 2 + j]);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
@@ -785,7 +773,7 @@ __global__ void test_4x4_matrix_get_2x2_blocks_with_sfc_and_lds_kernel(int* inpu
|
||||
constexpr auto step = SFC_dram::get_forward_step(group);
|
||||
move_tile_window(output_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
|
||||
// TODO: This should not be needed.
|
||||
constexpr auto iAccess = number<group * NumGroupsToMerge + group>{};
|
||||
constexpr auto next_iAccess = number<(group+1) * NumGroupsToMerge + (group+1)>{};
|
||||
constexpr auto step_lds = SFC::get_step_between(iAccess, next_iAccess);
|
||||
move_tile_window(out_lds_window, {step_lds.at(number<0>{}), step_lds.at(number<1>{})});
|
||||
|
||||
Reference in New Issue
Block a user