From 9175bef6795962ec0249ad6770c71eced32448d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= Date: Tue, 16 Sep 2025 15:29:01 +0000 Subject: [PATCH] Improved integration test. --- test/ck_tile/tensor_view/test_tensor_view.cpp | 99 ++++++++++++++----- 1 file changed, 76 insertions(+), 23 deletions(-) diff --git a/test/ck_tile/tensor_view/test_tensor_view.cpp b/test/ck_tile/tensor_view/test_tensor_view.cpp index 8c91f83c7e..63d03be48a 100644 --- a/test/ck_tile/tensor_view/test_tensor_view.cpp +++ b/test/ck_tile/tensor_view/test_tensor_view.cpp @@ -555,7 +555,18 @@ __global__ void test_4x4_matrix_get_2x2_blocks_kernel(int* input, int* output) distribution); auto distributed_tensor = tile_window.load(); + // Up to this point, we have set-up the distributed tensor for the 4x4 matrix + // similar to the output of the MFMA when 2 conv groups are merged. + // We want to copy only the diagonal 2x2 blocks to the output, similar to the epilogue + // part of batched iGEMM for 2 conv groups. + // Create output encoding for 4x2 matrix (2 diagonal blocks stacked vertically) + // The threads access the output distributed tensor in a sliced manner + // T0 -> row 0 + // T1 -> row 1 (this is masked out when copying from input) + // T2 -> row 2 (this is masked out when copying from input) + // T3 -> row 3 + // Each thread copies 2 elements (a row of a 2x2 block) constexpr auto output_encoding = tile_distribution_encoding< sequence<>, tuple< @@ -573,23 +584,74 @@ __global__ void test_4x4_matrix_get_2x2_blocks_kernel(int* input, int* output) auto output_global_view = make_naive_tensor_view_packed( output, make_tuple(4, 2)); - static_for<0, 2, 1>{}([&](auto row_offset) + auto get_block_number = [&]() -> index_t { - if (threadIdx.x == 0) - { - output_distributed_tensor.get_thread_buffer() = distributed_tensor.get_y_sliced_thread_data( - sequence<0, 0, row_offset, 0>{}, - sequence<1, 1, 1, 2>{}); // copy one row of a 2x2 block - } + const auto x_space_coord = distribution.calculate_index(); + if (x_space_coord[0] == 0 && x_space_coord[1] == 0) + { + return 0; + } + else if (x_space_coord[0] == 0 && x_space_coord[1] == 2) + { + return 1; + } + else if (x_space_coord[0] == 2 && x_space_coord[1] == 0) + { + return 2; + } + else if (x_space_coord[0] == 2 && x_space_coord[1] == 2) + { + return 3; + } + return -1; + }; - if (threadIdx.x == 3) + auto mask = [&]() -> bool + { + // Only blocks 0 and 3 are diagonal + const auto blockId = get_block_number(); + return (blockId == 0 || blockId == 3); + }; + + auto get_output_row_offset = [&](auto row) -> index_t + { + const auto blockId = get_block_number(); + if (blockId == 0) { - output_distributed_tensor.get_thread_buffer() = distributed_tensor.get_y_sliced_thread_data( - sequence<0, 0, 1-row_offset, 0>{}, - sequence<1, 1, 1, 2>{}); // copy one row of a 2x2 block + return row; + } + else if (blockId == 3) + { + return -row; + } + else + { + return -1000; // Invalid for other threads + } + }; + + // Because we copy one row at the time, we need to loop over the 2 rows of the 2x2 blocks. + // We mask out the threads that do contribute to the diagonal blocks. + static_for<0, 2, 1>{}([&](auto row) + { + if (mask()) + { + //const auto row_offset = input_row_offset(get_block_number()); + const auto block_id = get_block_number(); + if (block_id == 0) + { + output_distributed_tensor.get_thread_buffer() = distributed_tensor.get_y_sliced_thread_data( + sequence<0, 0, row, 0>{}, + sequence<1, 1, 1, 2>{}); // copy one row of a 2x2 block + } + else if (block_id == 3) + { + output_distributed_tensor.get_thread_buffer() = distributed_tensor.get_y_sliced_thread_data( + sequence<0, 0, 1 - row, 0>{}, //row 0 -> row 1, row 1 -> row 0 + sequence<1, 1, 1, 2>{}); // copy one row of a 2x2 block + } } - // Print the output distributed tensor for verification if constexpr (DebugOutput) { block_sync_lds(); @@ -608,8 +670,9 @@ __global__ void test_4x4_matrix_get_2x2_blocks_kernel(int* input, int* output) }); } - if (threadIdx.x == 0) + if (mask()) { + const auto row_offset = get_output_row_offset(row); auto output_tile_window = make_tile_window(output_global_view, make_tuple(4, 2), {row_offset, 0}, @@ -617,16 +680,6 @@ __global__ void test_4x4_matrix_get_2x2_blocks_kernel(int* input, int* output) output_tile_window.store(output_distributed_tensor); } - - if (threadIdx.x == 3) - { - auto output_tile_window = make_tile_window(output_global_view, - make_tuple(4, 2), - {-row_offset, 0}, - output_distribution); - - output_tile_window.store(output_distributed_tensor); - } }); }