From 98696413248802ab8007b709e5fc76988b5600b6 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Fri, 12 Dec 2025 09:27:12 -0800 Subject: [PATCH 01/10] disable test_tile_gemm_quant_bquant_preshuffle (#3420) --- test/ck_tile/gemm_block_scale/CMakeLists.txt | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 8309b14f0a..2b0ffaafa2 100755 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -24,10 +24,12 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") target_compile_options(test_tile_gemm_quant_bquant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) # BQuant tests (with PreshuffleB) - add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle - test_gemm_quant_bquant_preshuffle.cpp - ) - target_compile_options(test_tile_gemm_quant_bquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + # disabling this test until it can be built within reasonable time! + # currently taking ~50 minutes on gfx12! + #add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle + # test_gemm_quant_bquant_preshuffle.cpp + #) + #target_compile_options(test_tile_gemm_quant_bquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) # RowColQuant tests add_gtest_executable(test_tile_gemm_quant_rowcol From fc7bf0ab1c5ed28e5962681007f84a2e8d3ee051 Mon Sep 17 00:00:00 2001 From: linqunAMD Date: Sat, 13 Dec 2025 01:28:37 +0800 Subject: [PATCH 02/10] [CK_TILE] Port hw independent changes from internal repo to develop branch (#3301) * [CK_TILE] Port hw independent changes from internal repo to develop branch It includes PR#96, #114, #120, #121. * correct rebase error --- example/ck_tile/03_gemm/gemm_utils.hpp | 2 +- example/ck_tile/03_gemm/run_gemm_example.inc | 4 +- .../fused_moe/kernel/moe_sorting_kernel.hpp | 2 + .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 1 + .../ops/gemm/kernel/streamk_gemm_kernel.hpp | 7 +- .../ops/gemm/kernel/universal_gemm_kernel.hpp | 2 +- .../gemm_pipeline_agmem_bgmem_creg_v1.hpp | 32 +- .../gemm_pipeline_agmem_bgmem_creg_v2.hpp | 25 +- .../ops/gemm/pipeline/tile_gemm_traits.hpp | 5 +- .../ops/reduce/block/block_reduce2d.hpp | 2 +- include/ck_tile/utility/json_dump.hpp | 475 +++++++++--------- .../epilogue/test_cshuffle_epilogue_util.hpp | 2 +- test/ck_tile/gemm_multi_abd/CMakeLists.txt | 2 +- .../test_gemm_multi_abd_cshuffle.cpp | 15 +- .../test_gemm_multi_abd_default2d.cpp | 8 +- .../test_gemm_multi_abd_util.hpp | 36 +- .../test_gemm_pipeline_util.hpp | 24 +- .../grouped_gemm_multi_d/CMakeLists.txt | 2 +- .../test_grouped_gemm_multi_d.cpp | 53 +- .../grouped_gemm_preshuffle/CMakeLists.txt | 2 +- .../test_grouped_gemm_preshuffle.cpp | 12 +- .../test_grouped_gemm_preshuffle_util.hpp | 62 ++- 22 files changed, 465 insertions(+), 310 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index b25aec101b..47c47334e7 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -459,7 +459,7 @@ struct PipelineTypeTraits ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2; }; -auto create_args() +inline auto create_args() { ck_tile::ArgParser arg_parser; arg_parser.insert("m", "3840", "m dimension") diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index c4f100b36b..78f3a9b0b3 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -197,8 +197,8 @@ bool do_verify(const ck_tile::HostTensor& c_m_n_dev_result, return pass; } -std::tuple -parse_gemm_size(ck_tile::ArgParser& arg_parser) +std::tuple inline parse_gemm_size( + ck_tile::ArgParser& arg_parser) { ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t N = arg_parser.get_int("n"); diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 3445f063f5..52b2b86574 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -986,6 +986,8 @@ struct MoeSortingKernel p_sorted_expert_ids[unit_size_mdiv.div(i)] = expert_id; } } + __syncthreads(); + smem_cumdup(num_experts) = smem_cumsum(num_experts); // fill the p_sorted_token_ids/p_sorted_weights diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 63993c5eb6..838fc236d2 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -561,6 +561,7 @@ struct GroupedGemmKernel const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex( 0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d); Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d); + block_sync_lds(); block_id = block_id + grid_size; // advance to next block // NOTE: this check is redundant but helps the compiler avoid spilling some VGPR if(block_id >= cum_grid_size) diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp index 91f1358321..6130107cfe 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp @@ -631,6 +631,7 @@ struct StreamKKernel tile_idx += kargs.tile_partitioner.get_grid()) { BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0); + block_sync_lds(); } // Stream-K section @@ -679,8 +680,8 @@ struct StreamKKernel { hipDeviceProp_t dev_prop; hipDevice_t dev; - hip_check_error(hipGetDevice(&dev)); - hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); + ck_tile::hip_check_error(hipGetDevice(&dev)); + ck_tile::hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); int num_cu = dev_prop.multiProcessorCount; return num_cu; @@ -700,7 +701,7 @@ struct StreamKKernel constexpr int min_block_per_cu = 1; const auto kernel = kentry; - hip_check_error( + ck_tile::hip_check_error( hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0)); return max(occupancy, 1); diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 4b28ac3f12..866a4cc693 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -280,7 +280,7 @@ struct UniversalGemmKernel using Kernel = UniversalGemmKernel; const auto kernel = kentry<1, Kernel, KernelArgs>; int occupancy; - hip_check_error( + ck_tile::hip_check_error( hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize().x, 0)); const int grid_size = get_available_compute_units(s) * occupancy; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index 16ed8de22f..936c38ddf3 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -9,11 +9,35 @@ namespace ck_tile { +template +struct BaseGemmPipelineAGmemBGmemCRegV1 +{ + static constexpr index_t PrefetchStages = 1; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + static constexpr bool UsePersistentKernel = false; + + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } + + CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; } + + CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t) + { + return TailNumber::Empty; + } + + template + CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber) + { + return run_func(bool_constant{}, integral_constant{}); + } +}; + // A Tile Window: global memory // B Tile Window: global memory // C Distributed tensor: register template -struct GemmPipelineAGmemBGmemCRegV1 +struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1 { using AsDataType = remove_cvref_t; using BsDataType = remove_cvref_t; @@ -48,14 +72,14 @@ struct GemmPipelineAGmemBGmemCRegV1 template static constexpr index_t GetVectorSizeA() { - return Problem::VectorSizeA; + return Policy::template GetVectorSizeA(); } template static constexpr index_t GetVectorSizeB() { - return Problem::VectorSizeB; + return Policy::template GetVectorSizeB(); } - static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; } + static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp index 5dbcde80a6..c711c768ec 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp @@ -9,11 +9,34 @@ namespace ck_tile { +template +struct BaseGemmPipelineAGmemBGmemCRegV2 +{ + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel; + + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } + + CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; } + + CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t) + { + return TailNumber::Empty; + } + + template + CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber) + { + return run_func(bool_constant{}, integral_constant{}); + } +}; // A Tile Window: global memory // B Tile Window: global memory // C Distributed tensor: register template -struct GemmPipelineAGmemBGmemCRegV2 +struct GemmPipelineAGmemBGmemCRegV2 : public BaseGemmPipelineAGmemBGmemCRegV2 { using AsDataType = remove_cvref_t; using BsDataType = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp index d76fd6dc0f..47607a40f5 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -43,13 +43,14 @@ template + bool Preshuffle_ = false, + int VectorSize_ = 16> struct TileGemmUniversalTraits { static constexpr bool kPadM = kPadM_; static constexpr bool kPadN = kPadN_; static constexpr bool kPadK = kPadK_; - static constexpr int _VectorSize = 16; + static constexpr int _VectorSize = VectorSize_; static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_; using AsLayout = AsLayout_; diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp index cbf4afefb2..ba6ed27651 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp @@ -425,7 +425,7 @@ struct BlockReduce2dCrossWarpSync if constexpr(num_reduce_warps == 1) return; - + block_sync_lds(); // Each warp's lane 0 writes its partial results to shared memory const index_t smem_offset = warp_id; if(lane_id == 0) diff --git a/include/ck_tile/utility/json_dump.hpp b/include/ck_tile/utility/json_dump.hpp index b5bab28cac..03e97c0b76 100644 --- a/include/ck_tile/utility/json_dump.hpp +++ b/include/ck_tile/utility/json_dump.hpp @@ -160,23 +160,23 @@ void dump_gemm_json_results(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_batched_gemm_json_results(const std::string& json_filename, - const std::string& op_name, - int M, - int N, - int K, - int stride_A, - int stride_B, - int stride_C, - int batch_stride_A, - int batch_stride_B, - int batch_stride_C, - int batch_count, - bool pass, - float ave_time, - float tflops, - float gb_per_sec, - const std::string& kernel_name = "batched_gemm_basic") +inline void dump_batched_gemm_json_results(const std::string& json_filename, + const std::string& op_name, + int M, + int N, + int K, + int stride_A, + int stride_B, + int stride_C, + int batch_stride_A, + int batch_stride_B, + int batch_stride_C, + int batch_count, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "batched_gemm_basic") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -218,20 +218,20 @@ void dump_grouped_gemm_json_results(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_flatmm_json_results(const std::string& json_filename, - const std::string& datatype, - int M, - int N, - int K, - int stride_A, - int stride_B, - int stride_C, - int kbatch, - bool pass, - float ave_time, - float tflops, - float gb_per_sec, - const std::string& kernel_name = "flatmm_basic") +inline void dump_flatmm_json_results(const std::string& json_filename, + const std::string& datatype, + int M, + int N, + int K, + int stride_A, + int stride_B, + int stride_C, + int kbatch, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "flatmm_basic") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -248,21 +248,22 @@ void dump_flatmm_json_results(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_gemm_multi_d_fp16_json_results(const std::string& json_filename, - const std::string& op_name, - int M, - int N, - int K, - int StrideA, - int StrideB, - int StrideD0, - int StrideD1, - int StrideE, - bool pass, - float ave_time, - float tflops, - float gb_per_sec, - const std::string& kernel_name = "gemm_multi_d_fp16") +inline void +dump_gemm_multi_d_fp16_json_results(const std::string& json_filename, + const std::string& op_name, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideD0, + int StrideD1, + int StrideE, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "gemm_multi_d_fp16") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -280,14 +281,14 @@ void dump_gemm_multi_d_fp16_json_results(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_elementwise_json_results(const std::string& json_filename, - const std::string& prec, - int grid_size, - int block_size, - float ave_time, - float tflops, - float gb_per_sec, - const std::string& kernel_name = "elementwise") +inline void dump_elementwise_json_results(const std::string& json_filename, + const std::string& prec, + int grid_size, + int block_size, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "elementwise") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -298,22 +299,22 @@ void dump_elementwise_json_results(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_layernorm2d_fwd_json_results(const std::string& json_filename, - const std::string& prec_i, - const std::string& prec_o, - const std::string& prec_sm, - const std::string& prec_sy, - int m, - int n, - int x_stride, - int xr_stride, - int y_stride, - int yr_stride, - bool pass, - float ave_time, - float tflops, - float gb_per_sec, - const std::string& kernel_name = "layernorm2d_fwd") +inline void dump_layernorm2d_fwd_json_results(const std::string& json_filename, + const std::string& prec_i, + const std::string& prec_o, + const std::string& prec_sm, + const std::string& prec_sy, + int m, + int n, + int x_stride, + int xr_stride, + int y_stride, + int yr_stride, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "layernorm2d_fwd") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -357,13 +358,13 @@ void dump_reduce_json_results(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_permute_json_results(const std::string& json_filename, - const std::string& data_type, - bool pass, - float ave_time, - float tflop, - float gb_per_sec, - const std::string& kernel_name = "permute") +inline void dump_permute_json_results(const std::string& json_filename, + const std::string& data_type, + bool pass, + float ave_time, + float tflop, + float gb_per_sec, + const std::string& kernel_name = "permute") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -373,19 +374,19 @@ void dump_permute_json_results(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_topk_softmax_json(const std::string& json_filename, - const std::string& input_prec, - const std::string& weight_prec, - int tokens, - int experts, - int topk, - int stride_input, - int stride_output, - float ave_time, - float tflop, - float gb_per_sec, - bool pass, - const std::string& kernel_name = "topk_softmax") +inline void dump_topk_softmax_json(const std::string& json_filename, + const std::string& input_prec, + const std::string& weight_prec, + int tokens, + int experts, + int topk, + int stride_input, + int stride_output, + float ave_time, + float tflop, + float gb_per_sec, + bool pass, + const std::string& kernel_name = "topk_softmax") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -401,20 +402,20 @@ void dump_topk_softmax_json(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_rmsnorm2d_fwd_json(const std::string& json_filename, - const std::string& prec_str, - int m, - int n, - int x_stride, - int xr_stride, - int y_stride, - int yr_stride, - int use_model_sensitive_rmsnorm, - float ave_time, - float tflops, - float gb_per_sec, - bool pass, - const std::string& kernel_name = "rmsnorm2d_fwd") +inline void dump_rmsnorm2d_fwd_json(const std::string& json_filename, + const std::string& prec_str, + int m, + int n, + int x_stride, + int xr_stride, + int y_stride, + int yr_stride, + int use_model_sensitive_rmsnorm, + float ave_time, + float tflops, + float gb_per_sec, + bool pass, + const std::string& kernel_name = "rmsnorm2d_fwd") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -431,19 +432,19 @@ void dump_rmsnorm2d_fwd_json(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_add_rmsnorm2d_rdquant_fwd_json( - const std::string& json_filename, - const std::string& input_data_type, - const std::string& quantized_data_type, - int m, - int n, - int stride, - float epsilon, - float ave_time, - float tflops, - float gb_per_sec, - bool pass, - const std::string& kernel_name = "add_rmsnorm2d_rdquant_fwd") +inline void +dump_add_rmsnorm2d_rdquant_fwd_json(const std::string& json_filename, + const std::string& input_data_type, + const std::string& quantized_data_type, + int m, + int n, + int stride, + float epsilon, + float ave_time, + float tflops, + float gb_per_sec, + bool pass, + const std::string& kernel_name = "add_rmsnorm2d_rdquant_fwd") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -458,17 +459,17 @@ void dump_add_rmsnorm2d_rdquant_fwd_json( END_JSON_DUMP_FILE(); } -void dump_smoothquant_json(const std::string& json_filename, - const std::string& prec_str, - int m, - int n, - int x_stride, - int y_stride, - float ave_time, - float tflops, - float gb_per_sec, - bool pass, - const std::string& kernel_name = "smoothquant") +inline void dump_smoothquant_json(const std::string& json_filename, + const std::string& prec_str, + int m, + int n, + int x_stride, + int y_stride, + float ave_time, + float tflops, + float gb_per_sec, + bool pass, + const std::string& kernel_name = "smoothquant") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -482,19 +483,19 @@ void dump_smoothquant_json(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_moe_sorting_json(const std::string& json_filename, - const std::string& index_prec, - const std::string& weight_prec, - const std::string& workspace_size, - int dispatch_policy, - int tokens, - int num_experts, - int topk, - float ave_time, - float tflops, - float gb_per_sec, - bool pass, - const std::string& kernel_name = "moe_sorting") +inline void dump_moe_sorting_json(const std::string& json_filename, + const std::string& index_prec, + const std::string& weight_prec, + const std::string& workspace_size, + int dispatch_policy, + int tokens, + int num_experts, + int topk, + float ave_time, + float tflops, + float gb_per_sec, + bool pass, + const std::string& kernel_name = "moe_sorting") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -510,19 +511,19 @@ void dump_moe_sorting_json(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_batched_transpose_json(const std::string& json_filename, - int N, - int C, - int H, - int W, - const std::string& layout_in, - const std::string& layout_out, - const std::string& prec, - float ave_time, - float tflops, - float gb_per_sec, - bool pass, - const std::string& kernel_name = "batched_transpose") +inline void dump_batched_transpose_json(const std::string& json_filename, + int N, + int C, + int H, + int W, + const std::string& layout_in, + const std::string& layout_out, + const std::string& prec, + float ave_time, + float tflops, + float gb_per_sec, + bool pass, + const std::string& kernel_name = "batched_transpose") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -538,19 +539,19 @@ void dump_batched_transpose_json(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_moe_smoothquant_json(const std::string& json_filename, - const std::string& prec_i, - const std::string& prec_o, - int tokens, - int hidden_size, - int stride, - int experts, - int topk, - bool pass, - float ave_time, - float tflops, - float gb_per_sec, - const std::string& kernel_name = "moe_smoothquant") +inline void dump_moe_smoothquant_json(const std::string& json_filename, + const std::string& prec_i, + const std::string& prec_o, + int tokens, + int hidden_size, + int stride, + int experts, + int topk, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "moe_smoothquant") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -566,26 +567,26 @@ void dump_moe_smoothquant_json(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_fused_moe_json(const std::string& json_filename, - const std::string& api_str, - const std::string& prec_str, - int tokens, - bool is_local_token, - int local_tokens, - int experts, - int topk, - int hidden_size, - int intermediate_size, - int stride, - int block_m, - int activation, - bool gate_only, - bool fused_quant, - bool pass, - float ave_time, - float tflops, - float tb_per_sec, - const std::string& kernel_name = "fused_moe") +inline void dump_fused_moe_json(const std::string& json_filename, + const std::string& api_str, + const std::string& prec_str, + int tokens, + bool is_local_token, + int local_tokens, + int experts, + int topk, + int hidden_size, + int intermediate_size, + int stride, + int block_m, + int activation, + bool gate_only, + bool fused_quant, + bool pass, + float ave_time, + float tflops, + float tb_per_sec, + const std::string& kernel_name = "fused_moe") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -610,29 +611,29 @@ void dump_fused_moe_json(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_fmha_fwd_json_results(const std::string& json_filename, - const std::string& prec, - const std::string& mode, - const std::string& io_layout, - int batch, - int nhead, - int nhead_k, - int seqlen_qs, - int seqlen_ks, - int seqlen_kpads, - int hdim_q, - int hdim_v, - float scale_s, - float p_drop, - bool lse, - const std::string& qscale, - const std::string& bias, - const std::string& vlayout, - bool pass, - float ave_time, - float tflops, - float gb_per_sec, - const std::string& kernel_name = "fmha_fwd") +inline void dump_fmha_fwd_json_results(const std::string& json_filename, + const std::string& prec, + const std::string& mode, + const std::string& io_layout, + int batch, + int nhead, + int nhead_k, + int seqlen_qs, + int seqlen_ks, + int seqlen_kpads, + int hdim_q, + int hdim_v, + float scale_s, + float p_drop, + bool lse, + const std::string& qscale, + const std::string& bias, + const std::string& vlayout, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "fmha_fwd") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); @@ -658,33 +659,33 @@ void dump_fmha_fwd_json_results(const std::string& json_filename, END_JSON_DUMP_FILE(); } -void dump_fmha_bwd_json_results(const std::string& json_filename, - const std::string& data_type, - const std::string& mode, - const std::string& i_perm, - const std::string& o_perm, - int batch, - int nhead, - int nhead_k, - int seqlen_q, - int seqlen_k, - int hdim_q, - int hdim_v, - float scale, - const std::string& bias, - bool use_dbias, - float p_drop, - bool s_randval, - bool deterministic, - const std::string& mask, - int mask_left, - int mask_right, - int workspace_size, - bool pass, - float ave_time, - float tflops, - float gb_per_sec, - const std::string& kernel_name = "fmha_bwd") +inline void dump_fmha_bwd_json_results(const std::string& json_filename, + const std::string& data_type, + const std::string& mode, + const std::string& i_perm, + const std::string& o_perm, + int batch, + int nhead, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + float scale, + const std::string& bias, + bool use_dbias, + float p_drop, + bool s_randval, + bool deterministic, + const std::string& mask, + int mask_left, + int mask_right, + int workspace_size, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "fmha_bwd") { START_JSON_DUMP_FILE(json_filename); ADD_KEY_VALUE("name", kernel_name); diff --git a/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp b/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp index 4fdbf23864..9b90110c07 100644 --- a/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp +++ b/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp @@ -130,7 +130,7 @@ auto run_cshuffle_epilogue_test(ScaleType scale = ScaleType::None) constexpr index_t kMPerBlock = Problem::kMPerBlock; constexpr index_t kNPerBlock = Problem::kNPerBlock; - constexpr index_t kBlockSize = Problem::kBlockSize; + index_t kBlockSize = ck_tile::is_wave32() ? Problem::kBlockSize / 2 : Problem::kBlockSize; std::cout << "Running CShuffleEpilogue test with M=" << M << ", N=" << N << ", MPerBlock=" << kMPerBlock << ", NPerBlock=" << kNPerBlock diff --git a/test/ck_tile/gemm_multi_abd/CMakeLists.txt b/test/ck_tile/gemm_multi_abd/CMakeLists.txt index 2dccf9cd60..03759652cd 100644 --- a/test/ck_tile/gemm_multi_abd/CMakeLists.txt +++ b/test/ck_tile/gemm_multi_abd/CMakeLists.txt @@ -7,7 +7,7 @@ if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() -if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12") add_gtest_executable(test_ck_tile_gemm_multi_abd_cshuffle test_gemm_multi_abd_cshuffle.cpp) add_gtest_executable(test_ck_tile_gemm_multi_abd_default2d test_gemm_multi_abd_default2d.cpp) target_compile_definitions(test_ck_tile_gemm_multi_abd_cshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp index 08997529b2..ab00f16632 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp @@ -20,20 +20,21 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using KernelTypes = ::testing::Types< // Has cshuffle epilogue enabled // A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog +#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8 + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, +#endif std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type> - >; + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type> + >; // clang-format on TYPED_TEST_SUITE(TestCkTileGemmMultiABD, KernelTypes); diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp index dac33b4656..c4bfc3e7cb 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp @@ -20,17 +20,19 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using KernelTypes = ::testing::Types< // Has cshuffle epilogue disabled // A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog +#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8 + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>, +#endif std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type> + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type> >; // clang-format on diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp index ee045c7f48..8cee050db2 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp @@ -23,6 +23,28 @@ static constexpr inline auto is_row_major(Layout layout_) ck_tile::tensor_layout::gemm::RowMajor>>{}; } +template +constexpr ck_tile::index_t get_k_warp_tile() +{ +#if CK_TILE_USE_WMMA + return 16; +#else +#if defined(CK_GFX950_SUPPORT) + constexpr bool is_8bit_float = + std::is_same_v || std::is_same_v; + if constexpr(M_Warp_Tile == 32) + return is_8bit_float ? 64 : 16; + else + return is_8bit_float ? 128 : 32; +#else + if constexpr(M_Warp_Tile == 32) + return 16; + else + return 32; +#endif +#endif +} + template & args, const ck_tile::stream_config& s) { - constexpr ck_tile::index_t M_Tile = 256; - constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 32; + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 128; + constexpr ck_tile::index_t K_Tile = 64; constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; constexpr ck_tile::index_t K_Warp = 1; +#if CK_TILE_USE_WMMA + using ADataType = + ck_tile::remove_cvref_t{}, AsDataType>>; + constexpr ck_tile::index_t M_Warp_Tile = 16; + constexpr ck_tile::index_t N_Warp_Tile = 16; + constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); +#else constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t K_Warp_Tile = 16; +#endif constexpr bool DoubleSmemBuffer = false; diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp index 43a73738d9..7c085b5098 100644 --- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp @@ -13,6 +13,28 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" +template +constexpr ck_tile::index_t get_k_warp_tile() +{ +#if CK_TILE_USE_WMMA + return 16; +#else +#if defined(CK_GFX950_SUPPORT) + constexpr bool is_8bit_float = + std::is_same_v || std::is_same_v; + if constexpr(M_Warp_Tile == 32) + return is_8bit_float ? 64 : 16; + else + return is_8bit_float ? 128 : 32; +#else + if constexpr(M_Warp_Tile == 32) + return 16; + else + return 32; +#endif +#endif +} + template auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, @@ -80,7 +102,7 @@ struct config_wmma static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; template diff --git a/test/ck_tile/grouped_gemm_multi_d/CMakeLists.txt b/test/ck_tile/grouped_gemm_multi_d/CMakeLists.txt index f86da3c4d5..5363e365fc 100644 --- a/test/ck_tile/grouped_gemm_multi_d/CMakeLists.txt +++ b/test/ck_tile/grouped_gemm_multi_d/CMakeLists.txt @@ -9,7 +9,7 @@ endif() # Use standard asm for rtn bf16 conversion instead of turncate list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3) -if(GPU_TARGETS MATCHES "gfx94|gfx95") +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12") add_gtest_executable(test_ck_tile_grouped_gemm_multi_d test_grouped_gemm_multi_d.cpp) target_compile_options(test_ck_tile_grouped_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() diff --git a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp index 65c662199b..8d56c274aa 100644 --- a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp +++ b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp @@ -29,9 +29,6 @@ template ; - static constexpr int M_Tile_ = M_Tile_val_; - static constexpr int N_Tile_ = N_Tile_val_; - static constexpr int K_Tile_ = K_Tile_val_; - static constexpr int M_Warp_ = M_Warp_val_; - static constexpr int N_Warp_ = N_Warp_val_; - static constexpr int K_Warp_ = K_Warp_val_; - static constexpr int M_Warp_Tile_ = M_Warp_Tile_val_; - static constexpr int N_Warp_Tile_ = N_Warp_Tile_val_; - static constexpr int K_Warp_Tile_ = K_Warp_Tile_val_; + static constexpr int M_Tile_ = M_Tile_val_; + static constexpr int N_Tile_ = N_Tile_val_; + static constexpr int K_Tile_ = K_Tile_val_; + static constexpr int M_Warp_ = M_Warp_val_; + static constexpr int N_Warp_ = N_Warp_val_; + static constexpr int K_Warp_ = K_Warp_val_; +#if CK_TILE_USE_WMMA + static constexpr int M_Warp_Tile_ = 16; + static constexpr int N_Warp_Tile_ = 16; + static constexpr int K_Warp_Tile_ = 16; +#else + static constexpr int M_Warp_Tile_ = 32; + static constexpr int N_Warp_Tile_ = 32; + static constexpr int K_Warp_Tile_ = (M_Warp_val_ == 2) ? 16 : 8; +#endif static constexpr bool DoubleSmemBuffer_ = DoubleSmemBuffer_val_; static constexpr auto Scheduler_ = Scheduler_val_; static constexpr PipelineType Pipeline_ = Pipeline_val_; @@ -68,21 +71,21 @@ struct KernelConfig // clang-format off using KernelTypes = ::testing::Types< - // ALayout, BLayout, ELayout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, EDataType, M_N_KTiles, M_N_K_Warps, M_N_K_Warp_Tile, DoubleSmemBuffer, Scheduler, Pipeline, Persistent + // ALayout, BLayout, ELayout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, EDataType, M_N_KTiles, M_N_K_Warps, DoubleSmemBuffer, Scheduler, Pipeline, Persistent // FP16 A/B/D/E - KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory - KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory - KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3 - KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3 - KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4 - KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true>, // v4 + KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory + KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory + KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 128, 64, 2, 2, 1, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3 + KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 128, 64, 2, 2, 1, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3 + KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 128, 32, 2, 2, 1, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4 + KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 128, 32, 2, 2, 1, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true>, // v4 // BF16 A/B/D/E - KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory - KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory - KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3 - KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3 - KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4 - KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true> // v4 + KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 32, 64, 4, 1, 1, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory + KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 32, 64, 4, 1, 1, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory + KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 128, 64, 2, 2, 1, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3 + KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 128, 64, 2, 2, 1, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3 + KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 128, 32, 2, 2, 1, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4 + KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 128, 32, 2, 2, 1, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true> // v4 >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_preshuffle/CMakeLists.txt b/test/ck_tile/grouped_gemm_preshuffle/CMakeLists.txt index 08b413aea9..3a230aed0c 100644 --- a/test/ck_tile/grouped_gemm_preshuffle/CMakeLists.txt +++ b/test/ck_tile/grouped_gemm_preshuffle/CMakeLists.txt @@ -6,7 +6,7 @@ if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() -if(GPU_TARGETS MATCHES "gfx94|gfx95") +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_gtest_executable(test_ck_tile_grouped_gemm_preshuffle test_grouped_gemm_preshuffle.cpp) target_compile_options(test_ck_tile_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp index 623d0152d6..450b7b8f24 100644 --- a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp @@ -50,16 +50,16 @@ struct KernelConfig // clang-format off using KernelTypes = ::testing::Types< // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, Persistent ,M_Tile, N_Tile, K_Tile, BlockPerCu - KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 16, 64, 256, 1>, +#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8 KernelConfig< Row, Col, Row, F8, F8, F32, F16, False, 16, 64, 256, 1>, - KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 128, 128, 128, 2>, KernelConfig< Row, Col, Row, F8, F8, F32, F16, False, 128, 128, 128, 2>, - - KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 16, 64, 256, 1>, KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 16, 64, 256, 1>, - KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 128, 128, 128, 2>, KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 128, 128, 128, 2>, - +#endif + KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 16, 64, 256, 1>, + KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 128, 128, 128, 2>, + KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 16, 64, 256, 1>, + KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 128, 128, 128, 2>, KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, False, 16, 64, 256, 1>, KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, False, 16, 64, 256, 1>, KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, False, 128, 128, 128, 2>, diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp index 0eb388082b..5628b6feae 100644 --- a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp @@ -14,6 +14,9 @@ template constexpr ck_tile::index_t get_k_warp_tile_flatmm() { +#if CK_TILE_USE_WMMA + return 16; +#else #if defined(CK_GFX950_SUPPORT) if constexpr(M_Warp_Tile == 32) return sizeof(PrecType) == 2 ? 16 : 64; @@ -25,6 +28,7 @@ constexpr ck_tile::index_t get_k_warp_tile_flatmm() else return sizeof(PrecType) == 2 ? 32 : 64; #endif +#endif } template @@ -101,13 +105,40 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test auto shuffle_b(const ck_tile::HostTensor& t) { assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - constexpr int divisor = N_Warp_Tile == 32 ? 2 : 4; - ck_tile::HostTensor t_view( - {n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + + if(ck_tile::is_gfx12_supported()) + { + constexpr int divisor = 2; + constexpr int kABK1PerLane = 8; + constexpr int kABK0PerLane = K_Warp_Tile / divisor / kABK1PerLane; + ck_tile::HostTensor t_view({n_ / N_Warp_Tile, + N_Warp_Tile, + k_ / K_Warp_Tile, + kABK0PerLane, + divisor, + kABK1PerLane}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); + } + else + { + int divisor = 1; + if(ck_tile::is_gfx11_supported()) + { + divisor = 1; + } + else + { + assert(is_wave32() == false); + divisor = N_Warp_Tile == 32 ? 2 : 4; + } + ck_tile::HostTensor t_view( + {n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } } template @@ -115,6 +146,11 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test const ck_tile::stream_config& s, void* kargs_ptr) { + constexpr ck_tile::index_t WaveSize = 32; + constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * M_Warp_Tile); + constexpr bool SupportVectorSize16 = + (M_Warp_Tile * K_Warp_Tile * sizeof(ADataType) * MIterPerWarp / WaveSize) % 16 == 0; + constexpr int VectorSize = SupportVectorSize16 ? 16 : 8; using GemmShape = ck_tile::TileGemmShape, @@ -137,7 +173,8 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test /*UseStructuredSparsity*/ false, /*Persistent*/ false, /*NumWaveGroups*/ 1, - /*Preshuffle*/ true>; + /*Preshuffle*/ true, + VectorSize>; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem, ck_tile::sequence, @@ -230,7 +273,8 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test /*UseStructuredSparsity*/ false, /*Persistent*/ true, // Enable persistent mode /*NumWaveGroups*/ 1, - /*Preshuffle*/ true>; + /*Preshuffle*/ true, + VectorSize>; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem Date: Fri, 12 Dec 2025 19:26:47 +0100 Subject: [PATCH 03/10] Fix compilation ab scale multi target (#3413) --- .../gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp index ac5b7dd0c4..0974f45a2b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp @@ -527,11 +527,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale } else { -#if defined(__gfx11__) - // TODO: remove this restriction - static_assert(ScaleBlockM >= MPerWmma, - "ScaleBlockM must be greater equal than MPerWmma"); -#endif static_assert( ScaleBlockK >= WmmaSelector:: From 9707ddb444f42b490c73b7884babccde2988ed7e Mon Sep 17 00:00:00 2001 From: Cong Ma <142121551+CongMa13@users.noreply.github.com> Date: Fri, 12 Dec 2025 17:08:26 -0700 Subject: [PATCH 04/10] [CK TILE GEMM STREAMK] update identifier names according to the new code style (#3348) * [CK TILE GEMM STREAMK] update identifier names according to the new code style --- .../ck_tile/40_streamk_gemm/gemm_utils.hpp | 56 +-- .../40_streamk_gemm/run_gemm_example.inc | 380 +++++++++--------- .../40_streamk_gemm/streamk_gemm_basic.cpp | 204 +++++----- 3 files changed, 328 insertions(+), 312 deletions(-) diff --git a/example/ck_tile/40_streamk_gemm/gemm_utils.hpp b/example/ck_tile/40_streamk_gemm/gemm_utils.hpp index dad31ec637..34c6c6b0ae 100644 --- a/example/ck_tile/40_streamk_gemm/gemm_utils.hpp +++ b/example/ck_tile/40_streamk_gemm/gemm_utils.hpp @@ -7,46 +7,46 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" -struct GemmConfigBase +struct GemmConfigurationBase { - static constexpr bool kPadM = true; - static constexpr bool kPadN = true; - static constexpr bool kPadK = true; + static constexpr bool PAD_M = true; + static constexpr bool PAD_N = true; + static constexpr bool PAD_K = true; - static constexpr bool PermuteA = false; - static constexpr bool PermuteB = false; + static constexpr bool PERMUTE_A = false; + static constexpr bool PERMUTE_B = false; - static constexpr bool TransposeC = false; - static constexpr bool UseStructuredSparsity = false; + static constexpr bool TRANSPOSE_C = false; + static constexpr bool USE_STRUCTURED_SPARSITY = false; - static constexpr int kBlockPerCu = 1; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - static constexpr ck_tile::index_t NumWaveGroups = 1; - static constexpr bool Preshuffle = false; - static constexpr bool DoubleSmemBuffer = false; + static constexpr int BLOCK_PER_CU = 1; + static constexpr auto SCHEDULER = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::index_t NUM_WAVE_GROUPS = 1; + static constexpr bool PRESHUFFLE = false; + static constexpr bool DOUBLE_SMEM_BUFFER = false; }; -template -struct GemmConfigMemoryInterwave : public GemmConfigBase +template +struct GemmConfigurationMemoryInterwave : public GemmConfigurationBase { - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 16; + static constexpr ck_tile::index_t M_TILE = 256; + static constexpr ck_tile::index_t N_TILE = 256; + static constexpr ck_tile::index_t K_TILE = 16; - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; + static constexpr ck_tile::index_t M_WARP = 2; + static constexpr ck_tile::index_t N_WARP = 2; + static constexpr ck_tile::index_t K_WARP = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; + static constexpr ck_tile::index_t M_WARP_TILE = 32; + static constexpr ck_tile::index_t N_WARP_TILE = 32; + static constexpr ck_tile::index_t K_WARP_TILE = sizeof(PrecisionType) == 2 ? 8 : 16; - static constexpr bool Persistent = Persistent_; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool PERSISTENT = IsPersistent; + static constexpr auto SCHEDULER = ck_tile::GemmPipelineScheduler::Intrawave; }; template -struct StreamKGemmTypeConfig +struct StreamKGemmTypeConfiguration { using ADataType = ADataType_; using BDataType = BDataType_; @@ -54,7 +54,7 @@ struct StreamKGemmTypeConfig using CDataType = CDataType_; }; -auto create_args(int argc, char* argv[]) +auto createArgs(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; arg_parser.insert("m", "512", "m dimension") diff --git a/example/ck_tile/40_streamk_gemm/run_gemm_example.inc b/example/ck_tile/40_streamk_gemm/run_gemm_example.inc index d18ac2e68a..7442bd33f2 100644 --- a/example/ck_tile/40_streamk_gemm/run_gemm_example.inc +++ b/example/ck_tile/40_streamk_gemm/run_gemm_example.inc @@ -12,31 +12,35 @@ static constexpr inline auto is_row_major(Layout) } template -auto calculate_rtol_atol(const ck_tile::index_t K, - const ck_tile::index_t kbatch, - const float max_accumulated_value) +auto calculateRtolAtol(const ck_tile::index_t k_dim, + const ck_tile::index_t k_batch, + const float max_accumulated_value) { using ComputeType = std::conditional_t; // Calculate thresholds - const auto rtol = ck_tile::get_relative_threshold( - ck_tile::integer_divide_ceil(K, kbatch)); - const auto atol = ck_tile::get_absolute_threshold( - max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + const auto relative_tolerance = + ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(k_dim, k_batch)); + const auto absolute_tolerance = + ck_tile::get_absolute_threshold( + max_accumulated_value / k_batch, ck_tile::integer_divide_ceil(k_dim, k_batch)); // Calculate error due to multiple WGs working in the same C macro tile - const auto rtol_split_k = - ck_tile::get_relative_threshold(kbatch); - const auto atol_split_k = ck_tile::get_absolute_threshold( - max_accumulated_value, kbatch); + const auto relative_tolerance_split_k = + ck_tile::get_relative_threshold(k_batch); + const auto absolute_tolerance_split_k = + ck_tile::get_absolute_threshold(max_accumulated_value, + k_batch); // Use higher threshold - return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); + return ck_tile::make_tuple(std::max(relative_tolerance, relative_tolerance_split_k), + std::max(absolute_tolerance, absolute_tolerance_split_k)); } -template std::tuple gemm(const ck_tile::StreamKHostArgs& args, - const ck_tile::stream_config& s); + const ck_tile::stream_config& stream_config); -template -std::tuple -invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, - ck_tile::DeviceMem& b_k_n_dev_buf, - ck_tile::DeviceMem& c_m_n_dev_buf, - ck_tile::index_t M, - ck_tile::index_t N, - ck_tile::index_t K, - ck_tile::index_t stride_A, - ck_tile::index_t stride_B, - ck_tile::index_t stride_C, - int n_warmup, - int n_repeat, - bool flush_cache, - ck_tile::StreamKReductionStrategy reduction_strategy) +std::tuple invokeGemm(ck_tile::DeviceMem& a_m_k_device_memory, + ck_tile::DeviceMem& b_k_n_device_memory, + ck_tile::DeviceMem& c_m_n_device_memory, + ck_tile::index_t m_dim, + ck_tile::index_t n_dim, + ck_tile::index_t k_dim, + ck_tile::index_t stride_a, + ck_tile::index_t stride_b, + ck_tile::index_t stride_c, + int warmup_iterations, + int repeat_iterations, + bool flush_cache, + ck_tile::StreamKReductionStrategy reduction_strategy) { - ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(), - b_k_n_dev_buf.GetDeviceBuffer(), - c_m_n_dev_buf.GetDeviceBuffer(), - M, - N, - K, - stride_A, - stride_B, - stride_C}; + ck_tile::StreamKHostArgs args{a_m_k_device_memory.GetDeviceBuffer(), + b_k_n_device_memory.GetDeviceBuffer(), + c_m_n_device_memory.GetDeviceBuffer(), + m_dim, + n_dim, + k_dim, + stride_a, + stride_b, + stride_c}; - std::tuple ave_time_and_batch; + std::tuple average_time_and_batch; if(reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic) { - ave_time_and_batch = gemm( - args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, flush_cache}); + average_time_and_batch = gemm( + args, + ck_tile::stream_config{ + nullptr, true, 1, warmup_iterations, repeat_iterations, true, flush_cache}); } else /*Reduction*/ { - ave_time_and_batch = gemm( - args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, flush_cache}); + average_time_and_batch = gemm( + args, + ck_tile::stream_config{ + nullptr, true, 1, warmup_iterations, repeat_iterations, true, flush_cache}); } - return ave_time_and_batch; + return average_time_and_batch; } template -bool do_verify(const ck_tile::HostTensor& c_m_n_dev_result, - const ck_tile::HostTensor& c_m_n_ref, - const ck_tile::tuple& rtol_atol, - const char* variant) +bool doVerify(const ck_tile::HostTensor& c_m_n_device_result, + const ck_tile::HostTensor& c_m_n_reference, + const ck_tile::tuple& relative_absolute_tolerances, + const char* variant) { - bool pass = ck_tile::check_err(c_m_n_dev_result, - c_m_n_ref, + bool pass = ck_tile::check_err(c_m_n_device_result, + c_m_n_reference, "Error: Incorrect results!", - rtol_atol.at(ck_tile::number<0>{}), - rtol_atol.at(ck_tile::number<1>{})); + relative_absolute_tolerances.at(ck_tile::number<0>{}), + relative_absolute_tolerances.at(ck_tile::number<1>{})); - std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) - << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl; + std::cout << "Relative error threshold: " + << relative_absolute_tolerances.at(ck_tile::number<0>{}) + << " Absolute error threshold: " + << relative_absolute_tolerances.at(ck_tile::number<1>{}) << std::endl; std::cout << "The " << variant << " verification result is:" << (pass ? "correct" : "fail") << std::endl; return pass; } -ck_tile::StreamKReductionStrategy get_reduction_strategy_value(const std::string& strategy) +ck_tile::StreamKReductionStrategy getReductionStrategyValue(const std::string& strategy) { if(strategy == "atomic") { @@ -156,172 +165,169 @@ ck_tile::StreamKReductionStrategy get_reduction_strategy_value(const std::string } } -template -int run_gemm_example_with_layouts(int argc, - char* argv[], - const ALayout a_layout = ALayout{}, - const BLayout b_layout = BLayout{}, - [[maybe_unused]] const CLayout c_layout = CLayout{}) +int runGemmExampleWithLayouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const BLayout b_layout = BLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) { - auto [result, arg_parser] = create_args(argc, argv); + auto [result, arg_parser] = createArgs(argc, argv); if(!result) return -1; - static_assert(!GemmConfig::Preshuffle, "Not implemented"); - static_assert(!GemmConfig::UseStructuredSparsity, "Not implemented"); - static_assert(!GemmConfig::PermuteA, "Not implemented"); - static_assert(!GemmConfig::PermuteB, "Not implemented"); + static_assert(!GemmConfiguration::PRESHUFFLE, "Not implemented"); + static_assert(!GemmConfiguration::USE_STRUCTURED_SPARSITY, "Not implemented"); + static_assert(!GemmConfiguration::PERMUTE_A, "Not implemented"); + static_assert(!GemmConfiguration::PERMUTE_B, "Not implemented"); - using ADataType = typename TypeConfig::ADataType; - using BDataType = typename TypeConfig::BDataType; - using AccDataType = typename TypeConfig::AccDataType; - using CDataType = typename TypeConfig::CDataType; + using ADataType = typename TypeConfiguration::ADataType; + using BDataType = typename TypeConfiguration::BDataType; + using AccumulatorDataType = typename TypeConfiguration::AccDataType; + using CDataType = typename TypeConfiguration::CDataType; - ck_tile::index_t M = arg_parser.get_int("m"); - ck_tile::index_t N = arg_parser.get_int("n"); - ck_tile::index_t K = arg_parser.get_int("k"); + ck_tile::index_t m_dim = arg_parser.get_int("m"); + ck_tile::index_t n_dim = arg_parser.get_int("n"); + ck_tile::index_t k_dim = arg_parser.get_int("k"); - ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); - ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); - ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); - - int n_warmup = arg_parser.get_int("warmup"); - int n_repeat = arg_parser.get_int("repeat"); + ck_tile::index_t stride_a = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_b = arg_parser.get_int("stride_b"); + ck_tile::index_t stride_c = arg_parser.get_int("stride_c"); + int warmup_iterations = arg_parser.get_int("warmup"); + int repeat_iterations = arg_parser.get_int("repeat"); ck_tile::index_t init_method = arg_parser.get_int("init"); bool flush_cache = arg_parser.get_bool("flush_cache"); - ck_tile::StreamKReductionStrategy reduction_strategy = - get_reduction_strategy_value(arg_parser.get_str("reduction_strategy")); + getReductionStrategyValue(arg_parser.get_str("reduction_strategy")); - stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); - stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); - stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); + stride_a = ck_tile::get_default_stride(m_dim, k_dim, stride_a, is_row_major(a_layout)); + stride_b = ck_tile::get_default_stride(k_dim, n_dim, stride_b, is_row_major(b_layout)); + stride_c = ck_tile::get_default_stride(m_dim, n_dim, stride_c, is_row_major(CLayout{})); - ck_tile::HostTensor a_m_k( - ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); - ck_tile::HostTensor b_k_n( - ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); - ck_tile::HostTensor c_m_n_dev_result( - ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + ck_tile::HostTensor a_m_k_host( + ck_tile::host_tensor_descriptor(m_dim, k_dim, stride_a, is_row_major(a_layout))); + ck_tile::HostTensor b_k_n_host( + ck_tile::host_tensor_descriptor(k_dim, n_dim, stride_b, is_row_major(b_layout))); + ck_tile::HostTensor c_m_n_device_result( + ck_tile::host_tensor_descriptor(m_dim, n_dim, stride_c, is_row_major(CLayout{}))); if(init_method == 0) { - ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); - ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); + ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k_host); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n_host); } else if(init_method == 1) { - ck_tile::FillMonotonicSeq{}(a_m_k); - ck_tile::FillMonotonicSeq{}(b_k_n); + ck_tile::FillMonotonicSeq{}(a_m_k_host); + ck_tile::FillMonotonicSeq{}(b_k_n_host); } else if(init_method == 2) { - ck_tile::FillUniformDistribution{1.f, 1.f}(a_m_k); - ck_tile::FillUniformDistribution{1.f, 1.f}(b_k_n); + ck_tile::FillUniformDistribution{1.f, 1.f}(a_m_k_host); + ck_tile::FillUniformDistribution{1.f, 1.f}(b_k_n_host); } else { - a_m_k.SetZero(); - b_k_n.SetZero(); + a_m_k_host.SetZero(); + b_k_n_host.SetZero(); } - ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); - ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); - ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + ck_tile::DeviceMem a_m_k_device_memory(a_m_k_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_device_memory(b_k_n_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_device_memory(c_m_n_device_result.get_element_space_size_in_bytes()); - a_m_k_dev_buf.ToDevice(a_m_k.data()); - b_k_n_dev_buf.ToDevice(b_k_n.data()); - c_m_n_dev_buf.SetZero(); - c_m_n_dev_result.SetZero(); + a_m_k_device_memory.ToDevice(a_m_k_host.data()); + b_k_n_device_memory.ToDevice(b_k_n_host.data()); + c_m_n_device_memory.SetZero(); + c_m_n_device_result.SetZero(); + auto [average_time, num_wgs_per_tile] = invokeGemm, + AccumulatorDataType, + CDataType, + ALayout, + BLayout, + ck_tile::tuple<>, + CLayout>(a_m_k_device_memory, + b_k_n_device_memory, + c_m_n_device_memory, + m_dim, + n_dim, + k_dim, + stride_a, + stride_b, + stride_c, + warmup_iterations, + repeat_iterations, + flush_cache, + reduction_strategy); - auto [ave_time, num_wgs_per_tile] = invoke_gemm, - AccDataType, - CDataType, - ALayout, - BLayout, - ck_tile::tuple<>, - CLayout>(a_m_k_dev_buf, - b_k_n_dev_buf, - c_m_n_dev_buf, - M, - N, - K, - stride_A, - stride_B, - stride_C, - n_warmup, - n_repeat, - flush_cache, - reduction_strategy); + c_m_n_device_memory.FromDevice(c_m_n_device_result.data()); - c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_byte = - sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; - float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_byte / 1.E6 / ave_time; - - std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K - << " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C + std::size_t flop = std::size_t(2) * m_dim * n_dim * k_dim; + std::size_t num_byte = sizeof(ADataType) * m_dim * k_dim + sizeof(BDataType) * n_dim * k_dim + + sizeof(CDataType) * m_dim * n_dim; + float tflops = static_cast(flop) / 1.E9 / average_time; + float gb_per_sec = num_byte / 1.E6 / average_time; + std::cout << "Run Gemm kernel with M=" << m_dim << " N=" << n_dim << " K=" << k_dim + << " StrideA=" << stride_a << " StrideB=" << stride_b << " StrideC=" << stride_c << " A_Layout=" << ALayout::name << " B_Layout=" << BLayout::name << " C_Layout=" << CLayout::name << " A_Type=" << ck_tile::DataTypeTraits::name << " B_Type=" << ck_tile::DataTypeTraits::name << " C_Type=" << ck_tile::DataTypeTraits::name << " reduction_strategy=" << arg_parser.get_str("reduction_strategy") << " " - << " persistent_dp=" << arg_parser.get_str("persistent_dp") << " " << ave_time + << " persistent_dp=" << arg_parser.get_str("persistent_dp") << " " << average_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; - bool pass = false; // Memory on host to store gpu reference result - ck_tile::HostTensor c_m_n_ref( - ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); - c_m_n_ref.SetZero(); + ck_tile::HostTensor c_m_n_reference( + ck_tile::host_tensor_descriptor(m_dim, n_dim, stride_c, is_row_major(CLayout{}))); + c_m_n_reference.SetZero(); if(arg_parser.get_int("v") == 1) // Validate on the CPU { - ck_tile::reference_gemm( - a_m_k, b_k_n, c_m_n_ref); + ck_tile::reference_gemm( + a_m_k_host, b_k_n_host, c_m_n_reference); const float max_accumulated_value = - *std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol( - K, num_wgs_per_tile, max_accumulated_value); - pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "CPU"); + *std::max_element(c_m_n_reference.mData.begin(), c_m_n_reference.mData.end()); + const auto relative_absolute_tolerances = + calculateRtolAtol( + k_dim, num_wgs_per_tile, max_accumulated_value); + pass = doVerify(c_m_n_device_result, c_m_n_reference, relative_absolute_tolerances, "CPU"); } else if(arg_parser.get_int("v") == 2) // Validate on the GPU { // Memory on device to store gpu reference result - ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_ref.get_element_space_size_in_bytes()); - c_m_n_gpu_buf_ref.SetZero(); - - ADataType* d_A = static_cast(a_m_k_dev_buf.GetDeviceBuffer()); - BDataType* d_B = static_cast(b_k_n_dev_buf.GetDeviceBuffer()); - CDataType* d_C = static_cast(c_m_n_gpu_buf_ref.GetDeviceBuffer()); + ck_tile::DeviceMem c_m_n_gpu_buffer_reference( + c_m_n_reference.get_element_space_size_in_bytes()); + c_m_n_gpu_buffer_reference.SetZero(); + ADataType* d_A = static_cast(a_m_k_device_memory.GetDeviceBuffer()); + BDataType* d_B = static_cast(b_k_n_device_memory.GetDeviceBuffer()); + CDataType* d_C = static_cast(c_m_n_gpu_buffer_reference.GetDeviceBuffer()); ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); - - c_m_n_gpu_buf_ref.FromDevice(c_m_n_ref.data()); + CLayout>( + d_A, d_B, d_C, m_dim, n_dim, k_dim, stride_a, stride_b, stride_c); + c_m_n_gpu_buffer_reference.FromDevice(c_m_n_reference.data()); const float max_accumulated_value = - *std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol( - K, num_wgs_per_tile, max_accumulated_value); - pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "GPU"); + *std::max_element(c_m_n_reference.mData.begin(), c_m_n_reference.mData.end()); + const auto relative_absolute_tolerances = + calculateRtolAtol( + k_dim, num_wgs_per_tile, max_accumulated_value); + pass = doVerify(c_m_n_device_result, c_m_n_reference, relative_absolute_tolerances, "GPU"); } return pass; diff --git a/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp b/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp index 83795fbf6a..d3ee9fe9c6 100644 --- a/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp +++ b/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp @@ -4,11 +4,11 @@ #include "gemm_utils.hpp" #include "ck_tile/ops/common.hpp" -template std::tuple gemm(const ck_tile::StreamKHostArgs& args, - const ck_tile::stream_config& s) + const ck_tile::stream_config& stream_config) { - using GemmShape = ck_tile::TileGemmShape< - ck_tile::sequence, - ck_tile::sequence, - ck_tile:: - sequence, - GemmConfig::PermuteA, - GemmConfig::PermuteB>; + using GemmShape = ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence, + GemmConfiguration::PERMUTE_A, + GemmConfiguration::PERMUTE_B>; - using TilePartitioner = - ck_tile::StreamKTilePartitioner; + using TilePartitioner = ck_tile:: + StreamKTilePartitioner; - using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmUniversalTraits = + ck_tile::TileGemmUniversalTraits; - const auto Run = [&](const auto memory_operation) -> std::tuple { + const auto runKernel = [&](const auto memory_operation) -> std::tuple { // We create the GEMM pipeline without specifying has_hot_loop or tail_num. // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K // Kernel's RunGemm function. This is a similar pattern used by grouped GEMM. - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -61,39 +67,39 @@ std::tuple gemm(const ck_tile::StreamKHostArgs& args, ck_tile::CShuffleEpilogueProblem>; + GemmConfiguration::NUM_WAVE_GROUPS>>; using Kernel = ck_tile::StreamKKernel; - auto kargs = Kernel::MakeKernelArgs(args); - const auto workspace_size = Kernel::GetWorkSpaceSize(kargs); + auto kernel_args = Kernel::MakeKernelArgs(args); + const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args); ck_tile::DeviceMem workspace_data(workspace_size); workspace_data.SetZero(); - kargs.workspace_ptr = workspace_data.GetDeviceBuffer(); + kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer(); - dim3 grids = Kernel::GridSize(kargs.tile_partitioner); + dim3 grids = Kernel::GridSize(kernel_args.tile_partitioner); dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) + if(!Kernel::IsSupportedArgument(kernel_args)) { throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); } - if(s.log_level_ > 0) + if(stream_config.log_level_ > 0) { std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' << "shape: " << GemmShape::GetName() << '\n' @@ -109,7 +115,7 @@ std::tuple gemm(const ck_tile::StreamKHostArgs& args, { // Clear the output C tensor results after each repetition of the kernel hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_)); } else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction) { @@ -120,45 +126,47 @@ std::tuple gemm(const ck_tile::StreamKHostArgs& args, std::function preprocess = reset_data_buffers; - float ave_time = ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + float average_time = + ck_tile::launch_kernel_time_mask(stream_config, + preprocess, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kernel_args)); - ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile(); - return std::tuple{ave_time, num_wgs_per_tile}; + ck_tile::index_t num_wgs_per_tile = + kernel_args.tile_partitioner.estimate_num_wgs_per_tile(); + return std::tuple{average_time, num_wgs_per_tile}; }; if constexpr(ck_tile::StreamKReductionStrategy::Atomic == ReductionStrategy) { - return Run(ck_tile::integral_constant{}); + return runKernel(ck_tile::integral_constant{}); } else // We are using ck_tile::StreamKReductionStrategy::Reduction { - return Run(ck_tile::integral_constant{}); + return runKernel(ck_tile::integral_constant{}); } } #include "run_gemm_example.inc" -template -int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +template +int runGemmExamplePrecisionType(std::string a_layout, std::string b_layout, int argc, char* argv[]) { using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; if(a_layout == "R" && b_layout == "C") { - return run_gemm_example_with_layouts( + return runGemmExampleWithLayouts( argc, argv, Row{}, Col{}, Row{}); } else @@ -169,72 +177,74 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a return 0; } -template