diff --git a/example/ck_tile/17_grouped_gemm/CMakeLists.txt b/example/ck_tile/17_grouped_gemm/CMakeLists.txt index 8e8026d88d..f97cc03d2a 100644 --- a/example/ck_tile/17_grouped_gemm/CMakeLists.txt +++ b/example/ck_tile/17_grouped_gemm/CMakeLists.txt @@ -1,3 +1,10 @@ add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp) add_executable(tile_example_quant_grouped_gemm EXCLUDE_FROM_ALL quant_grouped_gemm.cpp) add_executable(tile_example_grouped_gemm_preshuffle EXCLUDE_FROM_ALL grouped_gemm_preshuffle.cpp) + + +set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() +target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS}) \ No newline at end of file diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index f5ddcd278c..4cecf5fc8d 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -887,6 +887,58 @@ struct tile_window_with_static_lengths this->window_lengths_ = window_lengths; this->bottom_tensor_view_ = bottom_tensor_view; } + + /** + * @brief Print tile window elements for debugging. + * + * @tparam DataType Element data type (e.g., fp16_t, float, bf8_t) + * @param start_i Starting row (inclusive) + * @param end_i Ending row (exclusive) + * @param start_j Starting column (inclusive) + * @param end_j Ending column (exclusive) + * @param label Optional output label + * + * @note Tested on fp16. Custom types may need adjustments. + * @example tile_window.template print_tile_window_range(0, 4, 0, 8, "A"); + */ + template + CK_TILE_DEVICE void print_tile_window_range(index_t start_i, + index_t end_i, + index_t start_j, + index_t end_j, + const char* label = "") const + { + const auto& tensor_view = this->get_bottom_tensor_view(); + const auto window_origin = this->get_window_origin(); + + printf("%s Window Range [%d:%d, %d:%d] (origin: %d, %d):\n", + label, + start_i, + end_i - 1, + start_j, + end_j - 1, + window_origin[0], + window_origin[1]); + + for(index_t i = start_i; i < end_i; i++) + { + for(index_t j = start_j; j < end_j; j++) + { + // Create coordinate for this element relative to window origin + auto coord = + make_tensor_coordinate(tensor_view.get_tensor_descriptor(), + make_tuple(window_origin[0] + i, window_origin[1] + j)); + + // Get the element using thread buffer type directly + using ThreadBuf = thread_buffer; + auto buf = tensor_view.template get_vectorized_elements(coord, 0); + auto value = buf.at(number<0>{}); // Extract first element from thread buffer + printf(" %s[%d,%d] = %f", label, i, j, static_cast(value)); + } + printf("\n"); + } + printf("\n"); + } }; template diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index 92ae6411a5..a891d4df55 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -646,16 +646,13 @@ struct StreamKTilePartitioner * @brief Get length of loop iterations for stream-k loop */ CK_TILE_DEVICE uint32_t GetCurrentIterLength(uint32_t iter_start, - uint32_t iter_end, - uint32_t total_iter_length) const noexcept + uint32_t iter_end) const noexcept { - uint32_t iter_length_mod, iter_length_quo /*unused*/; - k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod); - uint32_t total_iter_length_val = static_cast(total_iter_length); - uint32_t current_iter_length = - min(iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, - total_iter_length_val); - return current_iter_length; + // A WG's iter_end is either in the current C macro tile or not. + // If it is not, then the macro tile boundary is where the WG must stop. + uint32_t distance_to_tile_boundary = + k_iters_per_tile.get() - (iter_start % k_iters_per_tile.get()); + return min(iter_start + distance_to_tile_boundary, iter_end) - iter_start; } /** @@ -672,9 +669,7 @@ struct StreamKTilePartitioner CK_TILE_DEVICE void GetTileIdxWithOffset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const noexcept { - uint32_t tile_idx_val = static_cast(tile_idx); - uint32_t iter_offset_val = static_cast(iter_offset); - k_iters_per_tile.divmod(iter, tile_idx_val, iter_offset_val); + k_iters_per_tile.divmod(iter, tile_idx, iter_offset); } /** diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 704d0d01ee..dda38bbc47 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -374,7 +374,7 @@ struct GroupedGemmKernel // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = Base::template MakeGemmTensorViews( - {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset); + {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k); const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = @@ -436,7 +436,7 @@ struct GroupedGemmKernel // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = Base::template MakeGemmTensorViews( - {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset); + {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k); const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp index 77c431e49c..5df1f092d7 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp @@ -141,11 +141,17 @@ struct StreamKKernel return UniversalGemmKernel::BlockSize(); } - CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args) + /// @brief Constructs kernel arguments for the Stream-K kernel. + /// @param host_args Stream-K host arguments. + /// @param num_cu Number of compute units (CUs). The default is the number of CUs on the device. + /// The caller may select their own to assist with test reproducibility, etc. + /// @param occupancy The maximum number of active blocks per CU for this kernel. The caller may + /// select their own to assist with test reproducibility, etc. + /// @return The kernel arguments for Stream-K. + CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args, + int num_cu = NumCU(), + int occupancy = Occupancy()) { - uint32_t occupancy = static_cast(Occupancy()); - uint32_t num_cu = static_cast(NumCU()); - return StreamKKernelArgs{{host_args.as_ptr, host_args.bs_ptr, host_args.ds_ptr, @@ -166,14 +172,71 @@ struct StreamKKernel TilePartitioner{static_cast(host_args.M), static_cast(host_args.N), static_cast(host_args.K), - num_cu, - occupancy, + static_cast(num_cu), + static_cast(occupancy), host_args.num_sk_blocks}}; } - CK_TILE_HOST static bool - IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) + template + CK_TILE_DEVICE static void + RunGemm(const std::array& as_ptr, + const std::array& bs_ptr, + const std::array& ds_ptr, + CDataType* c_ptr, + void* smem_ptr_0, + const typename UniversalGemmKernel::KernelArgs& kargs, + const index_t num_loop, + const index_t block_idx_m, + const index_t block_idx_n, + const index_t k_size) { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + UniversalGemmKernel::template MakeGemmTensorViews( + as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size); + + const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = + UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + // Run GEMM cooperatively by whole workgroup. + const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0); + const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1); + const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2); + + // Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute + // has_hot_loop and tail_num here. This is a similar pattern used by grouped GEMM. In this + // case, we call the GemmPipeline's operator() function that takes both has_hot_loop and + // tail_num. + const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); + const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); + + const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0], + bs_block_window[UniversalGemmKernel::I0], + num_loop, + has_hot_loop, + tail_num, + smem_ptr_0); + + if(UseDefaultScheduler || (get_warp_id() == 0)) + { + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3); + + EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } + } + + CK_TILE_HOST static bool IsSupportedArgument(const StreamKKernelArgs& kargs) + { + if(kargs.reduction_strategy == StreamKReductionStrategy::Reduction) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("CK Tile Stream-K only supports the atomic reduction strategy."); + } + return false; + } return UniversalGemmKernel::IsSupportedArgument(kargs); } @@ -199,9 +262,81 @@ struct StreamKKernel kargs.workspace_ptr = workspace_ptr; } - // Temporary placeholder to support the Occupancy() static function. - // Since the Occupancy function uses kentry, this class must have an operator() function - CK_TILE_DEVICE void operator()(StreamKKernelArgs /*kargs*/) const {} + /// @brief Entry point for the Stream-K Kernel, performing the main Stream-K loop. + CK_TILE_DEVICE void operator()(StreamKKernelArgs kargs) const + { + // Allocate LDS + __shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()]; + + uint32_t block_idx = ck_tile::get_block_1d_id(); + + bool is_padding_block = + __builtin_amdgcn_readfirstlane(block_idx >= kargs.tile_partitioner.sk_num_blocks && + block_idx < kargs.tile_partitioner.dp_start_block_idx); + + // Padding blocks make it such that the DP blocks are aligned with the number of CUs; they + // should not partake in the GEMM + if(is_padding_block) + return; + + // Determine the K offset of the first and final macro tile in the A and B tensors along the + // K dimension. + uint32_t iter_start, iter_end; + kargs.tile_partitioner.GetBlockItr(block_idx, iter_start, iter_end); + + // Main Stream-K loop + while(true) + { + // Determine the number of macro tiles in A and B this WG is resposible for in the + // current C macro tile. + uint32_t current_iter_length = __builtin_amdgcn_readfirstlane( + kargs.tile_partitioner.GetCurrentIterLength(iter_start, iter_end)); + + // Determine the 1D tile_idx and the iter_offset for this WG. + // The tile_idx is the 1D macro tile index in the C tensor. + // The iter_offset is the starting macro tile index in the K dimension for the WG in the + // current iteration of the while loop. + uint32_t tile_idx, iter_offset; + kargs.tile_partitioner.GetTileIdxWithOffset(iter_start, tile_idx, iter_offset); + + // Get the 2D tile index in the C tensor for this WG using the 1D index (i.e. tile_idx) + auto spatial_idx = kargs.tile_partitioner.GetOutputTileIndex(tile_idx); + + // Get the offsets in A, B, C tensors. + index_t i_m = static_cast(spatial_idx[UniversalGemmKernel::I0] * + TilePartitioner::MPerBlock); + index_t i_n = static_cast(spatial_idx[UniversalGemmKernel::I1] * + TilePartitioner::NPerBlock); + index_t i_k = static_cast(iter_offset) * TilePartitioner::KPerBlock; + + // Determine the total size along the K dimension the WG is using in this iteration + // (used to construct tensor views). + index_t k_size = static_cast(current_iter_length * TilePartitioner::KPerBlock); + + // Update pointer offsets for A, B, and C. + const ADataType* a_ptr = static_cast(kargs.as_ptr[0]) + i_k; + const BDataType* b_ptr = static_cast(kargs.bs_ptr[0]) + i_k; + CDataType* c_ptr = static_cast(kargs.e_ptr); + + // Run the GEMM pipeline and Epilogue. + RunGemm({a_ptr}, + {b_ptr}, + {/*ds_ptr*/}, + c_ptr, + smem_ptr_0, + kargs, + current_iter_length, + i_m, + i_n, + k_size); + + // Prepare for next Stream-K loop iteration. + iter_start += current_iter_length; + if(iter_end <= iter_start) + break; + block_sync_lds(); + } + } private: CK_TILE_HOST static int NumCU() diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 8117d65758..cfba8b6c9d 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -579,7 +579,7 @@ struct UniversalGemmKernel const std::array& ds_ptr, EDataType* e_ptr, const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset) + const index_t k_size) { static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); @@ -591,7 +591,7 @@ struct UniversalGemmKernel { return make_naive_tensor_view( static_cast(as_ptr[i]), - make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.M, k_size), make_tuple(kargs.stride_As[i], 1), number{}, number<1>{}); @@ -600,7 +600,7 @@ struct UniversalGemmKernel { return make_naive_tensor_view( static_cast(as_ptr[i]), - make_tuple(splitk_batch_offset.splitted_k, kargs.M), + make_tuple(k_size, kargs.M), make_tuple(kargs.stride_As[i], 1), number{}, number<1>{}); @@ -617,7 +617,7 @@ struct UniversalGemmKernel if constexpr(TilePartitioner::BlockGemmShape::PermuteB) { constexpr index_t K1 = GemmPipeline::GetSmemPackB(); - const index_t K0 = splitk_batch_offset.splitted_k / K1; + const index_t K0 = k_size / K1; constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); const auto b_k0_n_k1_desc = @@ -638,7 +638,7 @@ struct UniversalGemmKernel { return make_naive_tensor_view( bs_ptr[i], - make_tuple(splitk_batch_offset.splitted_k, kargs.N), + make_tuple(k_size, kargs.N), make_tuple(kargs.stride_Bs[i], 1), number{}, number<1>{}); @@ -649,7 +649,7 @@ struct UniversalGemmKernel if constexpr(TilePartitioner::BlockGemmShape::PermuteB) { constexpr index_t K1 = GemmPipeline::GetSmemPackB(); - const index_t K0 = splitk_batch_offset.splitted_k / K1; + const index_t K0 = k_size / K1; constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); const auto b_k0_n_k1_desc = @@ -672,7 +672,7 @@ struct UniversalGemmKernel { index_t kFlatK = GemmPipeline::BlockGemmShape::flatKPerWarp * - (splitk_batch_offset.splitted_k / + (k_size / TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{})); index_t kFlatN = kargs.N * kargs.K / kFlatK; @@ -687,7 +687,7 @@ struct UniversalGemmKernel { return make_naive_tensor_view( bs_ptr[i], - make_tuple(kargs.N, splitk_batch_offset.splitted_k), + make_tuple(kargs.N, k_size), make_tuple(kargs.stride_Bs[i], 1), number{}, number<1>{}); @@ -962,7 +962,7 @@ struct UniversalGemmKernel // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( - as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); + as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); @@ -1018,7 +1018,7 @@ struct UniversalGemmKernel // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( - as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); + as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 993df2ec40..9314d4b795 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -3,7 +3,9 @@ add_subdirectory(gemm) add_subdirectory(gemm_weight_preshuffle) add_subdirectory(batched_gemm) add_subdirectory(grouped_gemm) +add_subdirectory(grouped_gemm_preshuffle) add_subdirectory(gemm_multi_d) +add_subdirectory(gemm_streamk) add_subdirectory(data_type) add_subdirectory(container) add_subdirectory(elementwise) diff --git a/test/ck_tile/gemm_streamk/CMakeLists.txt b/test/ck_tile/gemm_streamk/CMakeLists.txt new file mode 100644 index 0000000000..e00874ba07 --- /dev/null +++ b/test/ck_tile/gemm_streamk/CMakeLists.txt @@ -0,0 +1,7 @@ +# Currently test_ck_tile_streamk is only built on gfx9 +if(GPU_TARGETS MATCHES "gfx9") + #TODO: support all arches + add_gtest_executable(test_ck_tile_streamk test_gemm_streamk.cpp) +else() + message(DEBUG "Skipping test_ck_tile_streamk tests for current target") +endif() diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk.cpp b/test/ck_tile/gemm_streamk/test_gemm_streamk.cpp new file mode 100644 index 0000000000..99c3fb397f --- /dev/null +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk.cpp @@ -0,0 +1,14 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_types.hpp" +#include "test_gemm_streamk_util.hpp" +#include "gtest/gtest.h" + +#define TEST_SUITE_NAME TestCkTileStreamK + +TYPED_TEST_SUITE(TestCkTileStreamK, KernelTypesStreamK); + +#include "test_gemm_streamk_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_cases.inc b/test/ck_tile/gemm_streamk/test_gemm_streamk_cases.inc new file mode 100644 index 0000000000..1db7ef0fb0 --- /dev/null +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_cases.inc @@ -0,0 +1,118 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M256_N256_K256_DP) +{ + + ck_tile::index_t M = 256; + ck_tile::index_t N = 256; + ck_tile::index_t K = 256; + uint32_t num_sk_blocks = 0; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M256_N256_K256_SKBlocks4) +{ + + ck_tile::index_t M = 256; + ck_tile::index_t N = 256; + ck_tile::index_t K = 256; + uint32_t num_sk_blocks = 4; + + this->Run(M, N, K, num_sk_blocks); +} + +// TODO: Renable this test once reduction is implemented +TYPED_TEST(TEST_SUITE_NAME, StreamK_M256_N256_K256_SKBlocks12) +{ + GTEST_SKIP() << "Skipping this test: There are precision issues with atomics due to >=3 WGs " + "contributing to each macro tile in C"; + + ck_tile::index_t M = 256; + ck_tile::index_t N = 256; + ck_tile::index_t K = 256; + uint32_t num_sk_blocks = 12; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M256_N256_K256_SKBlocks8) +{ + + ck_tile::index_t M = 256; + ck_tile::index_t N = 256; + ck_tile::index_t K = 256; + uint32_t num_sk_blocks = 8; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M512_N512_K512_DP) +{ + + ck_tile::index_t M = 512; + ck_tile::index_t N = 512; + ck_tile::index_t K = 512; + uint32_t num_sk_blocks = 0; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M512_N512_K512_SKBlocks16) +{ + + ck_tile::index_t M = 512; + ck_tile::index_t N = 512; + ck_tile::index_t K = 512; + uint32_t num_sk_blocks = 16; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M512_N512_K512_SKBlocks8) +{ + + ck_tile::index_t M = 512; + ck_tile::index_t N = 512; + ck_tile::index_t K = 512; + uint32_t num_sk_blocks = 8; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M3840_N4096_K4096_DP) +{ + + ck_tile::index_t M = 3840; + ck_tile::index_t N = 4096; + ck_tile::index_t K = 4096; + uint32_t num_sk_blocks = 0; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M3840_N4096_K4096_SKBlocks64) +{ + + ck_tile::index_t M = 3840; + ck_tile::index_t N = 4096; + ck_tile::index_t K = 4096; + uint32_t num_sk_blocks = 64; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_Unsupported_Reduction) +{ + + ck_tile::index_t M = 3840; + ck_tile::index_t N = 4096; + ck_tile::index_t K = 4096; + uint32_t num_sk_blocks = 64; + + EXPECT_THROW(this->Run(M, N, K, num_sk_blocks, ck_tile::StreamKReductionStrategy::Reduction), + std::runtime_error); +} diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp new file mode 100644 index 0000000000..399f3f11e8 --- /dev/null +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp @@ -0,0 +1,25 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" + +using F16 = ck_tile::half_t; +using F32 = float; +using BF16 = ck_tile::bf16_t; + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// clang-format off +using KernelTypesStreamK = ::testing::Types< +// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType + std::tuple< Row, Col, Row, F16, F16, F32, F16>, + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16> +>; + +// clang-format on diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp new file mode 100644 index 0000000000..b8a55b024d --- /dev/null +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp @@ -0,0 +1,282 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + + // The logic below may need to become more advanced once bugs in Stream-K Tile Partitioner are + // resolved. Because the number of WGs contributing to a macro tile in C may not be the same for + // all macro tiles in C. + + // Calculate error due to more than 1 WG contributing to the same macro tile in C + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +class TestCkTileStreamK : public ::testing::Test +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = std::tuple_element_t<2, Tuple>; + using ADataType = std::tuple_element_t<3, Tuple>; + using BDataType = std::tuple_element_t<4, Tuple>; + using AccDataType = std::tuple_element_t<5, Tuple>; + using CDataType = std::tuple_element_t<6, Tuple>; + using DsLayout = ck_tile::tuple<>; + using DsDataType = ck_tile::tuple<>; + + template + void invoke_streamk(const ck_tile::StreamKHostArgs& args, + const ck_tile::stream_config& s, + int num_cu, + int occupancy) + { + + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 128; + constexpr ck_tile::index_t K_Tile = 32; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool kPadM = PadM; + constexpr bool kPadN = PadN; + constexpr bool kPadK = PadK; + constexpr bool preshuffle = Preshuffle; + + constexpr bool DoubleSmemBuffer = false; + constexpr int kBlockPerCu = 1; + constexpr bool StructuredSparsity = false; + constexpr bool NumWaveGroup = 1; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile::StreamKTilePartitioner; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + + // We create the GEMM pipeline without specifying has_hot_loop or tail_num. + // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K + // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K + // Kernel's RunGemm function. This is a similar pattern used by grouped GEMM. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + // For initial testing, we will just test with one pipeline. + // More extensive testing is coming later and will test other pipelines. + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + M_Warp, + N_Warp, + M_Warp_Tile, + N_Warp_Tile, + K_Warp_Tile, + UniversalGemmProblem::TransposeC, + memory_operation>>; + + using Kernel = ck_tile::StreamKKernel; + + auto kargs = Kernel::MakeKernelArgs(args, num_cu, occupancy); + + if(!Kernel::IsSupportedArgument(kargs)) + { + EXPECT_TRUE(false); + } + + dim3 grid_dims = Kernel::GridSize(kargs.tile_partitioner); + dim3 block_dims = Kernel::BlockSize(); + + ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grid_dims, block_dims, 0, kargs)); + }; + + Run(ck_tile::integral_constant{}); + } + + public: + // Since Stream-K is build on gfx9, the lower bound for CUs is 104. Thus, we default num_cu to + // 104 and occupancy to 1 to ensure tests are reproducible on different architectures. + void Run(ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + uint32_t num_sk_blocks = 0xffffffff, + ck_tile::StreamKReductionStrategy reduction_strategy = + ck_tile::StreamKReductionStrategy::Atomic, + int occupancy = 1, + int num_cu = 104, + ck_tile::index_t stride_A = 0, + ck_tile::index_t stride_B = 0, + ck_tile::index_t stride_C = 0) + { + + using namespace ck_tile::literals; + + if(reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction) + { + throw std::runtime_error("Reduction Strategy is current unsupported!\n"); + } + + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if constexpr(std::is_same_v) + { + return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + stride_A = f_get_default_stride(M, K, stride_A, ALayout{}); + stride_B = f_get_default_stride(K, N, stride_B, BLayout{}); + stride_C = f_get_default_stride(M, N, stride_C, CLayout{}); + + ck_tile::HostTensor a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{})); + ck_tile::HostTensor b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{})); + ck_tile::HostTensor c_m_n_dev_result( + f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + + ck_tile::FillUniformDistributionIntegerValue{-5, 5, /*seed*/ 11939}(a_m_k); + ck_tile::FillUniformDistributionIntegerValue{-5, 5, /*seed*/ 11940}(b_k_n); + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + a_m_k_dev_buf.ToDevice(a_m_k.data()); + b_k_n_dev_buf.ToDevice(b_k_n.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + c_m_n_dev_buf.GetDeviceBuffer(), + M, + N, + K, + stride_A, + stride_B, + stride_C, + reduction_strategy, + num_sk_blocks}; + + invoke_streamk( + args, ck_tile::stream_config{nullptr, false, 0, 0, 1}, num_cu, occupancy); + + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + + ck_tile::HostTensor c_m_n_host_ref( + f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + c_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_host_ref); + + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, /*kbatch*/ 1, max_accumulated_value); + + bool pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + EXPECT_TRUE(pass); + }; +}; diff --git a/test/ck_tile/grouped_gemm_preshuffle/CMakeLists.txt b/test/ck_tile/grouped_gemm_preshuffle/CMakeLists.txt new file mode 100644 index 0000000000..68120efc7e --- /dev/null +++ b/test/ck_tile/grouped_gemm_preshuffle/CMakeLists.txt @@ -0,0 +1,9 @@ +set(EXAMPLE_GEMM_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + +if(GPU_TARGETS MATCHES "gfx94|gfx95") + add_gtest_executable(test_ck_tile_grouped_gemm_preshuffle test_grouped_gemm_preshuffle.cpp) + target_compile_options(test_ck_tile_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +endif() diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp new file mode 100644 index 0000000000..cf10853b3f --- /dev/null +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_grouped_gemm_preshuffle_util.hpp" + +using F16 = ck_tile::half_t; +using F8 = ck_tile::fp8_t; +using F32 = float; +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// Custom tuple-like structure for kernel configuration +template +struct KernelConfig +{ + using ALayoutType = ALayout_; + using BLayoutType = BLayout_; + using CLayoutType = CLayout_; + using ADataType = ADataType_; + using BDataType = BDataType_; + using AccDataType = AccDataType_; + using CDataType = CDataType_; + + static constexpr int M_Tile_ = M_Tile_val_; + static constexpr int N_Tile_ = N_Tile_val_; + static constexpr int K_Tile_ = K_Tile_val_; + static constexpr int BlockPerCu_ = BlockPerCu_val_; +}; + +// clang-format off +using KernelTypes = ::testing::Types< + // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, M_Tile, N_Tile, K_Tile, BlockPerCu + KernelConfig< Row, Col, Row, F16, F16, F32, F16, 16, 64, 256, 1>, + KernelConfig< Row, Col, Row, F8, F8, F32, F16, 16, 64, 256, 1>, + KernelConfig< Row, Col, Row, F16, F16, F32, F16, 128, 128, 128, 2>, + KernelConfig< Row, Col, Row, F8, F8, F32, F16, 128, 128, 128, 2> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGroupedGemmPreshuffle, KernelTypes); + +#include "test_grouped_gemm_preshuffle_ut_cases.inc" +#include "test_grouped_gemm_preshuffle_prefill_cases.inc" diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_prefill_cases.inc b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_prefill_cases.inc new file mode 100644 index 0000000000..340d807ba2 --- /dev/null +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_prefill_cases.inc @@ -0,0 +1,61 @@ +#pragma once + +// Test with prefill config struct +TYPED_TEST(TestCkTileGroupedGemmPreshuffle, PrefillVariant) +{ + const int group_count = 4; + const int kbatch = 1; + std::vector Ms; + std::vector Ns; + std::vector Ks; + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + for(int i = 0; i < group_count; i++) + { + + Ms.push_back(256 + 128 * i); + Ns.push_back(256 + 128 * i); + Ks.push_back(128 * (i + 1)); + + stride_As.push_back(Ks[i]); + stride_Bs.push_back(Ks[i]); + stride_Cs.push_back(Ns[i]); + } + + this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count); +} + +TYPED_TEST(TestCkTileGroupedGemmPreshuffle, VariedDimensions) +{ + const int group_count = 6; + const int kbatch = 1; + std::vector Ms; + std::vector Ns; + std::vector Ks; + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + std::vector> test_cases = {{64, 128, 256}, + {128, 256, 512}, + {256, 512, 1024}, + {512, 256, 128}, + {128, 128, 128}, + {64, 512, 256}}; + + for(int i = 0; i < group_count; i++) + { + auto [M, N, K] = test_cases[i]; + Ms.push_back(M); + Ns.push_back(N); + Ks.push_back(K); + + stride_As.push_back(Ks[i]); + stride_Bs.push_back(Ks[i]); + stride_Cs.push_back(Ns[i]); + } + + this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count); +} diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_ut_cases.inc b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_ut_cases.inc new file mode 100644 index 0000000000..beca5e62b5 --- /dev/null +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_ut_cases.inc @@ -0,0 +1,53 @@ +#pragma once + +// kPadK is not needed for these k values +TYPED_TEST(TestCkTileGroupedGemmPreshuffle, kPadKFalse) +{ + const int group_count = 4; + const int kbatch = 1; + std::vector Ms; + std::vector Ns; + std::vector Ks; + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + for(int i = 0; i < group_count; i++) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(256 + 512 * i); + Ks.push_back(512 + 256 * i); + + stride_As.push_back(Ks[i]); + stride_Bs.push_back(Ks[i]); + stride_Cs.push_back(Ns[i]); + } + + this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count); +} + +// kPadK is needed to be true for these k values +TYPED_TEST(TestCkTileGroupedGemmPreshuffle, kPadKTrue) +{ + const int group_count = 4; + const int kbatch = 1; + std::vector Ms; + std::vector Ns; + std::vector Ks; + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + for(int i = 0; i < group_count; i++) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(256 + 512 * i); + Ks.push_back(512 + 128 * i); + + stride_As.push_back(Ks[i]); + stride_Bs.push_back(Ks[i]); + stride_Cs.push_back(Ns[i]); + } + + this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count); +} diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp new file mode 100644 index 0000000000..799a5f2907 --- /dev/null +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp @@ -0,0 +1,374 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" + +template +constexpr ck_tile::index_t get_k_warp_tile_flatmm() +{ +#if defined(CK_GFX950_SUPPORT) + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 64; + else + return sizeof(PrecType) == 2 ? 32 : 128; +#else + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 32; + else + return sizeof(PrecType) == 2 ? 32 : 64; +#endif +} + +template +class TestCkTileGroupedGemmPreshuffle : public ::testing::Test +{ + protected: + using ALayout = typename Tuple::ALayoutType; + using BLayout = typename Tuple::BLayoutType; + using CLayout = typename Tuple::CLayoutType; + using ADataType = typename Tuple::ADataType; + using BDataType = typename Tuple::BDataType; + using AccDataType = typename Tuple::AccDataType; + using CDataType = typename Tuple::CDataType; + using PrecType = BDataType; + using DsLayout = ck_tile::tuple<>; // not used + using DsDataType = ck_tile::tuple<>; // not used + + static const bool kPadM = false; + static const bool kPadN = false; + static const bool kPadK = true; // preshuffle pipeline requires k padding + + static const int kBlockPerCu = Tuple::BlockPerCu_; + + // Tile dimensions from tuple + static const ck_tile::index_t M_Tile = Tuple::M_Tile_; + static const ck_tile::index_t N_Tile = Tuple::N_Tile_; + static const ck_tile::index_t K_Tile = Tuple::K_Tile_; + + static const ck_tile::index_t M_Warp = 1; + static const ck_tile::index_t N_Warp = 4; + static const ck_tile::index_t K_Warp = 1; + + static const ck_tile::index_t M_Warp_Tile = 16; + static const ck_tile::index_t N_Warp_Tile = 16; + static const ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + + static constexpr bool DoubleSmemBuffer = true; // preshuffle v2 uses ping-pong smem + static constexpr bool TransposeC = false; // transpose c is not supported + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + + template + auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) + { + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); + } + + using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; + inline std::size_t get_workspace_size(const std::vector& gemm_descs) + { + return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); + } + + template + auto shuffle_b(const ck_tile::HostTensor& t) + { + assert(t.get_lengths().size() == 2); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + constexpr int divisor = N_Warp_Tile == 32 ? 2 : 4; + ck_tile::HostTensor t_view( + {n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } + + template + void invoke_grouped_gemm(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) + { + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + + // for testing purposes, we can hardcode the values here as we what is compatible with + // pipeline + using GemmUniversalTraits = + ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * K_Tile; + const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t num_loop = + ck_tile::GemmSpatiallyLocalTilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto memory_operation = memory_operation_.value; + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = + ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); + const dim3 grids = Kernel::GridSize(gemm_descs); + const dim3 blocks = Kernel::BlockSize(); + + ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(gemm_descs[0].k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + // EXPECT TO FAIL because splitk is not supported + EXPECT_FALSE(true); + } + }; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + } + + public: + void Run(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + std::vector& stride_As, + std::vector& stride_Bs, + std::vector& stride_Cs, + const int kbatch = 1, + const int group_count = 16) + { + + using namespace ck_tile::literals; + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if constexpr(std::is_same_v) + { + return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + std::vector> a_m_k_tensors; + std::vector> b_k_n_tensors; + std::vector> c_m_n_tensors; + + a_m_k_tensors.reserve(group_count); + b_k_n_tensors.reserve(group_count); + c_m_n_tensors.reserve(group_count); + + std::vector> a_m_k_dev_buf; + std::vector> b_k_n_dev_buf; + std::vector> c_m_n_dev_buf; + + a_m_k_dev_buf.reserve(group_count); + b_k_n_dev_buf.reserve(group_count); + c_m_n_dev_buf.reserve(group_count); + + std::vector gemm_descs; + gemm_descs.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + const ck_tile::index_t M = Ms[i]; + const ck_tile::index_t N = Ns[i]; + const ck_tile::index_t K = Ks[i]; + + stride_As[i] = f_get_default_stride(M, K, stride_As[i], ALayout{}); + stride_Bs[i] = f_get_default_stride(K, N, stride_Bs[i], BLayout{}); + stride_Cs[i] = f_get_default_stride(M, N, stride_Cs[i], CLayout{}); + + a_m_k_tensors.push_back(ck_tile::HostTensor( + f_host_tensor_descriptor(M, K, stride_As[i], ALayout{}))); + b_k_n_tensors.push_back(ck_tile::HostTensor( + f_host_tensor_descriptor(K, N, stride_Bs[i], BLayout{}))); + c_m_n_tensors.push_back(ck_tile::HostTensor( + f_host_tensor_descriptor(M, N, stride_Cs[i], CLayout{}))); + + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); + + // Host-side preshuffle of B + auto b_shuffle_host = shuffle_b(b_k_n_tensors[i]); + + a_m_k_dev_buf.push_back(std::make_unique( + a_m_k_tensors[i].get_element_space_size_in_bytes())); + b_k_n_dev_buf.push_back(std::make_unique( + b_shuffle_host.get_element_space_size_in_bytes())); + c_m_n_dev_buf.push_back(std::make_unique( + c_m_n_tensors[i].get_element_space_size_in_bytes())); + + a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data()); + b_k_n_dev_buf[i]->ToDevice(b_shuffle_host.data()); + c_m_n_dev_buf[i]->SetZero(); + c_m_n_tensors[i].SetZero(); + + const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer(); + const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer(); + void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); + + gemm_descs.push_back( + {p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); + } + + ck_tile::DeviceMem gemm_workspace; + gemm_workspace.Realloc(get_workspace_size(gemm_descs)); + + invoke_grouped_gemm(gemm_descs, + ck_tile::stream_config{nullptr, false, 1}, + gemm_workspace.GetDeviceBuffer()); + + // Copy results back to host for validation + for(int i = 0; i < group_count; i++) + { + c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data()); + } + + bool pass{true}; + for(int i = 0; i < group_count; ++i) + { + ck_tile::HostTensor c_m_n_host_ref( + f_host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{})); + c_m_n_host_ref.SetZero(); + ck_tile::reference_gemm( + a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = + calculate_rtol_atol( + Ks[i], kbatch, max_accumulated_value); + pass &= ck_tile::check_err(c_m_n_tensors[i], + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + } + EXPECT_TRUE(pass); + } +};