From 437599c5178af4fae60b3c8de3e6ff2f4c81b7e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= Date: Fri, 19 Sep 2025 13:45:39 +0000 Subject: [PATCH] Improved CShuffle test. --- test/ck_tile/tensor_view/test_tensor_view.cpp | 124 ++++++++---------- 1 file changed, 56 insertions(+), 68 deletions(-) diff --git a/test/ck_tile/tensor_view/test_tensor_view.cpp b/test/ck_tile/tensor_view/test_tensor_view.cpp index b8328b58d9..c6eafac817 100644 --- a/test/ck_tile/tensor_view/test_tensor_view.cpp +++ b/test/ck_tile/tensor_view/test_tensor_view.cpp @@ -637,9 +637,6 @@ __global__ void test_4x4_matrix_get_2x2_blocks_with_sfc_and_lds_kernel(int* inpu // similar to the output of the MFMA when 2 conv groups are merged. // We want to copy only the diagonal 2x2 blocks to the output, similar to the epilogue // part of batched iGEMM for 2 conv groups. - - constexpr index_t MPerIterationShuffle = 2; - constexpr index_t NPerIterationShuffle = 2; constexpr index_t NumGroupsToMerge = 2; // Number of merged groups // Create tensor descriptor for the output 4x2 matrix (2 diagonal blocks stacked vertically) @@ -652,25 +649,31 @@ __global__ void test_4x4_matrix_get_2x2_blocks_with_sfc_and_lds_kernel(int* inpu number{}), {0, 0}); // We have only threadblock + //------------------------------------------------------------ + // CShuffle epilogue simulation + //------------------------------------------------------------ + // Allocate and prepare LDS __shared__ char p_smem[MPerBlock * NPerBlock * sizeof(int)]; - constexpr index_t MPerThread = 2; - constexpr index_t NPerThread = 2; + // Initialize the LDS to zero + if(threadIdx.x == 0 && blockIdx.x == 0) + { + int* lds_data = reinterpret_cast(p_smem); + for (index_t i = 0; i < MPerBlock; i++) + { + for (index_t j = 0; j < NPerBlock; j++) + { + lds_data[i * NPerBlock + j] = 0; + } + } + } + __syncthreads(); - constexpr auto lds_tile_encoding = tile_distribution_encoding< - sequence<>, - tuple< - sequence<1, 1, 2, MPerThread>, 
sequence<1, 1, 2, NPerThread>>, - tuple, sequence<1,2>>, - tuple, sequence<2,2>>, - sequence<1, 1, 2, 2>, - sequence<0, 3, 0, 3>>{}; + constexpr index_t MPerIterationShuffle = 2; + constexpr index_t NPerIterationShuffle = 2; - auto lds_tile_distribution = make_static_tile_distribution(lds_tile_encoding); - - auto lds_tile = make_static_distributed_tensor(lds_tile_distribution); + auto lds_tile = make_static_distributed_tensor(distribution); constexpr auto lds_block_desc = make_naive_tensor_descriptor( make_tuple(number{}, number{}), @@ -681,13 +684,14 @@ __global__ void test_4x4_matrix_get_2x2_blocks_with_sfc_and_lds_kernel(int* inpu auto in_lds_window = make_tile_window( o_lds_block, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {0, 0}, - lds_tile_distribution); + distribution); + // Row major thread to data mapping in LDS auto out_lds_window = make_tile_window( o_lds_block, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {0, 0}); // Set-up traversing the 2x2 blocks @@ -703,8 +707,8 @@ __global__ void test_4x4_matrix_get_2x2_blocks_with_sfc_and_lds_kernel(int* inpu using TileEncodingPattern = tile_distribution_encoding_pattern_2d< 4, // Block size - MPerThread, - NPerThread, + MPerIterationShuffle, + NPerIterationShuffle, 2, // Vector size tile_distribution_pattern::sparse_row, 1>; // Number of wave groups @@ -712,54 +716,36 @@ __global__ void test_4x4_matrix_get_2x2_blocks_with_sfc_and_lds_kernel(int* inpu constexpr auto output_tile_distribution = TileEncodingPattern::make_2d_static_tile_distribution(); - // Copy the diagonal 2x2 block from register to global memrory via LDS. + // Copy the tile at one go from register to LDS. 
+ block_sync_lds(); + + lds_tile.get_thread_buffer() = input_tensor.get_y_sliced_thread_data( + sequence<0, 0, 0, 0>{}, + sequence<1, 1, MPerIterationShuffle, MPerIterationShuffle>{}); + + store_tile(in_lds_window, lds_tile); + block_sync_lds(); + + // Print the contents of LDS + if (threadIdx.x == 0 && blockIdx.x == 0) + { + printf("LDS contents:\n"); + int* lds_data = reinterpret_cast(p_smem); + for (index_t i = 0; i < 4; i++) + { + for (index_t j = 0; j < 4; j++) + { + printf("%3d ", lds_data[i * 4 + j]); + } + printf("\n"); + } + } + + // For the output tensor, we need to copy only the diagonal 2x2 blocks to global memory. static_for<0, NumGroupsToMerge, 1>{} ( [&](auto group) { - constexpr auto iAccess = number{}; - if constexpr(group == 0) - { - block_sync_lds(); - constexpr auto idx_y_start = SFC::get_index(iAccess); - - static_assert(idx_y_start.size() == 2, "wrong!"); - - printf("Thread id: %u, Group %d, idx_y_start: (%d, %d)\n", - threadIdx.x, group.value, idx_y_start.at(number<0>{}).value, idx_y_start.at(number<1>{}).value); - - constexpr auto mIter = number{}) / (MPerIterationShuffle)>{}; - constexpr auto nIter = number{}) / (NPerIterationShuffle)>{}; - - printf("Thread id: %u, Group %d, mIter %d, nIter %d\n", threadIdx.x, group.value, mIter.value, nIter.value); - - __syncthreads(); - - lds_tile.get_thread_buffer() = input_tensor.get_y_sliced_thread_data( - sequence<0, 0, - mIter * MPerIterationShuffle, - nIter * NPerIterationShuffle>{}, - sequence<1, 1, MPerThread, NPerThread>{}); - - store_tile(in_lds_window, lds_tile); - block_sync_lds(); - } - - // Print the contents of LDS - if (threadIdx.x == 0 && blockIdx.x == 0) - { - printf("LDS contents after loading group %d:\n", group.value); - int* lds_data = reinterpret_cast(p_smem); - for (index_t i = 0; i < 4; i++) - { - for (index_t j = 0; j < 4; j++) - { - printf("%3d ", lds_data[i * 4 + j]); - } - printf("\n"); - } - } - auto out_tensor = load_tile(make_tile_window(out_lds_window, 
output_tile_distribution)); store_tile(output_window, out_tensor); @@ -769,12 +755,14 @@ __global__ void test_4x4_matrix_get_2x2_blocks_with_sfc_and_lds_kernel(int* inpu if (threadIdx.x == 0 && blockIdx.x == 0) { + printf("Output tensor contents after loading group %d:\n", group.value); for (index_t i = 0; i < 4; i++) { for (index_t j = 0; j < 2; j++) { - printf("Output(%d, %d) = %d\n", i, j, output[i * 2 + j]); + printf("%3d", output[i * 2 + j]); } + printf("\n"); } } __syncthreads(); @@ -785,7 +773,7 @@ __global__ void test_4x4_matrix_get_2x2_blocks_with_sfc_and_lds_kernel(int* inpu constexpr auto step = SFC_dram::get_forward_step(group); move_tile_window(output_window, {step.at(number<0>{}), step.at(number<1>{})}); - // TODO: This should not be needed. + constexpr auto iAccess = number{}; constexpr auto next_iAccess = number<(group+1) * NumGroupsToMerge + (group+1)>{}; constexpr auto step_lds = SFC::get_step_between(iAccess, next_iAccess); move_tile_window(out_lds_window, {step_lds.at(number<0>{}), step_lds.at(number<1>{})});