From 68b20e1d4f3956eae5cbbeb6b5bd1d3c671b3a70 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Mon, 18 Aug 2025 17:12:50 +0000 Subject: [PATCH] Merge commit '9fcc1ee9fd9730efd865f530afde505f2556954d' into develop --- CHANGELOG.md | 1 + CMakeLists.txt | 2 - .../01_fmha/codegen/ops/fmha_batch_prefill.py | 4 +- .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 24 +-- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 4 +- .../01_fmha/codegen/ops/fmha_fwd_appendkv.py | 4 +- .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 8 +- .../codegen/ops/fmha_pagedkv_prefill.py | 4 +- example/ck_tile/02_layernorm2d/generate.py | 4 +- example/ck_tile/03_gemm/gemm_basic.cpp | 11 +- .../03_gemm/gemm_splitk_two_stage_reduce.cpp | 17 +- .../03_gemm/gemm_weight_preshuffle.cpp | 13 +- example/ck_tile/03_gemm/universal_gemm.cpp | 13 +- .../ck_tile/04_img2col/image_to_column.cpp | 5 +- example/ck_tile/05_reduce/reduce.cpp | 24 +-- .../matrix_core_swizzle_kernel.hpp | 5 +- example/ck_tile/06_permute/permute.cpp | 24 +-- .../09_topk_softmax/topk_softmax_api.cpp | 8 +- .../10_rmsnorm2d/example_rmsnorm2d_fwd.cpp | 5 +- example/ck_tile/10_rmsnorm2d/generate.py | 4 +- .../example_add_rmsnorm2d_rdquant_fwd.cpp | 5 +- ..._rmsnorm2d_rdquant_fwd_instance_common.hpp | 4 +- .../12_smoothquant/example_smoothquant.cpp | 5 +- .../instances/smoothquant_instance_common.hpp | 4 +- .../13_moe_sorting/moe_sorting_api.cpp | 24 +-- .../moe_smoothquant_instance_common.hpp | 4 +- .../instances/fused_moegemm_api_internal.hpp | 4 +- .../instances/fused_moesorting_api.cpp | 6 +- .../ck_tile/16_batched_gemm/batched_gemm.cpp | 7 +- .../ck_tile/17_grouped_gemm/grouped_gemm.cpp | 9 +- example/ck_tile/18_flatmm/flatmm_basic.cpp | 15 +- .../ck_tile/18_flatmm/run_flatmm_example.inc | 14 +- .../19_gemm_multi_d/gemm_multi_d_fp16.cpp | 7 +- .../grouped_convolution_backward_weight.cpp | 7 +- .../grouped_convolution_forward.cpp | 7 +- .../21_elementwise/elementwise_example.cpp | 22 +-- .../elementwise_example_add_4d.cpp | 2 +- .../elementwise_example_transpose.cpp | 22 +-- .../elementwise_example_unary.cpp | 22 +-- .../batched_transpose_api.cpp | 8 +- .../38_block_scale_gemm/gemm_aquant_basic.cpp | 7 +- .../gemm_aquant_preshuffle.cpp | 7 +- example/ck_tile/39_copy/copy_basic.cpp | 20 +-- include/ck_tile/core/arch/arch.hpp | 19 +- include/ck_tile/core/config.hpp | 6 - include/ck_tile/host/kernel_launch.hpp | 12 +- .../add_rmsnorm2d_rdquant_fwd_kernel.hpp | 3 +- .../kernel/batched_transpose_kernel.hpp | 2 + .../batched_transpose_lds_problem.hpp | 5 +- .../elementwise/kernel/elementwise_kernel.hpp | 2 + .../ops/epilogue/cshuffle_epilogue.hpp | 3 +- .../ops/flatmm/kernel/flatmm_kernel.hpp | 16 +- ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 7 +- .../fmha/kernel/fmha_batch_prefill_kernel.hpp | 9 +- .../fmha/kernel/fmha_fwd_appendkv_kernel.hpp | 3 +- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 7 +- .../fmha/kernel/fmha_fwd_pagedkv_kernel.hpp | 1 + .../fmha_fwd_splitkv_combine_kernel.hpp | 3 +- .../fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 3 +- .../fused_moe/kernel/moe_sorting_kernel.hpp | 162 +++++++++--------- .../ops/gemm/kernel/batched_gemm_kernel.hpp | 12 +- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 1 + .../ops/gemm/kernel/gemm_multi_d_kernel.hpp | 1 + .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 8 +- .../ops/gemm/kernel/universal_gemm_kernel.hpp | 35 ++-- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 12 +- .../gemm_pipeline_ag_bg_cr_comp_v4.hpp | 12 +- .../gemm_pipeline_ag_bg_cr_comp_v5.hpp | 12 +- .../pipeline/gemm_pipeline_ag_bg_cr_mem.hpp | 12 +- .../gemm_pipeline_agmem_bgmem_creg_v1.hpp | 12 +- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 38 +++- .../wp_pipeline_agmem_bgmem_creg_v1.hpp | 6 +- .../wp_pipeline_agmem_bgmem_creg_v2.hpp | 6 +- .../kernel/gemm_aquant_kernel.hpp | 20 +-- ...ped_convolution_backward_weight_kernel.hpp | 4 +- .../grouped_convolution_forward_kernel.hpp | 4 +- .../kernel/image_to_column_kernel.hpp | 3 +- .../pipeline/tile_image_to_column_shape.hpp | 7 +- .../kernel/layernorm2d_fwd_kernel.hpp | 6 +- .../permute/kernel/generic_permute_kernel.hpp | 2 +- .../ops/reduce/kernel/reduce2d_kernel.hpp | 2 + .../rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp | 1 + .../kernel/moe_smoothquant_kernel.hpp | 1 + .../smoothquant/kernel/smoothquant_kernel.hpp | 1 + .../kernel/topk_softmax_kernel.hpp | 4 +- include/ck_tile/ref/naive_attention.hpp | 6 +- ..._rmsnorm2d_rdquant_fwd_instance_common.hpp | 2 +- .../batched_gemm/test_batched_gemm_util.hpp | 7 +- .../test_batched_transpose.cpp | 10 +- .../elementwise/test_elementwise_1d.cpp | 24 ++- .../test_gemm_pipeline_basic_run_test.inc | 7 +- .../test_gemm_pipeline_universal_run_test.inc | 13 +- test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 6 +- .../test_run_gemm_aquant_example.inc | 7 +- .../gemm_multi_d/test_gemm_multi_d_util.hpp | 7 +- .../test_gemm_pipeline_util.hpp | 5 +- .../grouped_gemm/test_grouped_gemm_util.hpp | 22 +-- .../test_tile_image_to_column.cpp | 4 +- test/ck_tile/layernorm2d/generate.py | 4 +- test/ck_tile/memory_copy/test_copy.cpp | 22 +-- test/ck_tile/memory_copy/test_copy.hpp | 3 +- .../moe_smoothquant_instance_common.hpp | 2 +- test/ck_tile/moe_sorting/moe_sorting_api.cpp | 24 +-- .../matrix_core_swizzle_kernel.hpp | 7 +- test/ck_tile/permute/test_permute_util.hpp | 8 +- test/ck_tile/reduce/test_reduce2d.cpp | 24 +-- test/ck_tile/rmsnorm2d/generate.py | 2 +- .../instances/smoothquant_instance_common.hpp | 2 +- .../topk_softmax/test_topk_softmax_api.cpp | 8 +- tile_engine/ops/gemm/codegen_utils.py | 1 - tile_engine/ops/gemm/gemm_instance_builder.py | 6 +- .../gemm_multi_d_codegen_utils.py | 1 - .../gemm_multi_d_instance_builder.py | 6 +- 113 files changed, 610 insertions(+), 531 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7c09271edc..1246248eac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,6 +48,7 @@ None * Number of instances in instance factory for grouped convolution forward NGCHW/GKYXC/NGKHW has been reduced. * Number of instances in instance factory for grouped convolution backward weight NGCHW/GKYXC/NGKHW has been reduced. * Number of instances in instance factory for grouped convolution backward data NGCHW/GKYXC/NGKHW has been reduced. +* Removed `BlockSize` in `make_kernel` and `CShuffleEpilogueProblem` to support Wave32 in CK_TILE (#2594) ### Known issues diff --git a/CMakeLists.txt b/CMakeLists.txt index 07d2e166bb..35ebba8085 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -327,7 +327,6 @@ endif() if(USE_OPT_GFX11) add_compile_options(-mcumode) add_compile_options(-mno-wavefrontsize64) - add_compile_definitions(CK_TILE_WAVE32_ENABLED) message(STATUS "CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}") endif() @@ -339,7 +338,6 @@ endif() if(USE_OPT_GFX12 AND (SUPPORTED_GPU_TARGETS MATCHES "gfx12")) add_compile_options(-mno-wavefrontsize64) - add_compile_definitions(CK_TILE_WAVE32_ENABLED) message(STATUS "CK compiled with USE_OPT_GFX12 set to ${USE_OPT_GFX12}") endif() diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 5d55e8bc36..0d8f366d8a 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -110,9 +110,9 @@ float fmha_batch_prefill_(const ck_tile::stream_config& s, fmha_b if(s.log_level_ > 0) std::cout << ", " << k_::GetName() << std::flush; auto [kargs, grids] = fmha_batch_prefill_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} """ diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index bb3a0587e7..0391191fb2 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -136,10 +136,10 @@ float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s if(s.log_level_ > 0) std::cout << ", " << k_::GetName() << std::flush; auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} template <> @@ -148,9 +148,9 @@ void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_co {{ using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( ck_tile::stream_config{{s.stream_id_}}); }} @@ -425,10 +425,10 @@ float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s if(s.log_level_ > 0) std::cout << ", " << k_::GetName() << std::flush; auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} template <> @@ -436,9 +436,9 @@ void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_co {{ using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( ck_tile::stream_config{{s.stream_id_}}); }} @@ -530,10 +530,10 @@ float fmha_bwd_convert_dq_(const ck_tile::stream_confi if(s.log_level_ > 0) std::cout << ", " << k_::GetName() << std::flush; auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} template <> @@ -542,9 +542,9 @@ void fmha_bwd_convert_dq_oneshot_(const ck_tile::strea {{ using k_ = fmha_bwd_convert_dq_kernel_{F_idx}; auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( ck_tile::stream_config{{s.stream_id_}}); }} diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index f614f42e6b..e59147a4f3 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -110,9 +110,9 @@ float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) if(s.log_level_ > 0) std::cout << ", " << k_::GetName() << std::flush; auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} """ diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index 2e5bc2bd3d..0ebeaddf9c 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -60,9 +60,9 @@ float fmha_fwd_appendkv_(const ck_tile::stream_config& s, fmha_fw if(s.log_level_ > 0) std::cout << ", " << k_::GetName() << std::flush; auto [kargs, grids] = fmha_fwd_appendkv_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} """ diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index b2d962cd74..1dd8f0e3c6 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -108,9 +108,9 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ using k_ = fmha_kernel; auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); }} }}; }} @@ -208,9 +208,9 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ using k_ = fmha_kernel; auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); }} }}; }} diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index 650ebaf80e..e468e82ed5 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -109,9 +109,9 @@ float fmha_fwd_pagedkv_(const ck_tile::stream_config& s, fmha_fwd if(s.log_level_ > 0) std::cout << ", " << k_::GetName() << std::flush; auto [kargs, grids] = fmha_fwd_pagedkv_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} """ diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py index d77582630a..c4366f6662 100644 --- a/example/ck_tile/02_layernorm2d/generate.py +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -235,7 +235,7 @@ float layernorm2d_fwd_(const S& s, A a) using Kernel = ck_tile::Layernorm2dFwd; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a); @@ -243,7 +243,7 @@ float layernorm2d_fwd_(const S& s, A a) std::cout << ", " << Kernel::GetName() << std::flush; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{{}}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{{}}, grids, blocks, 0, kargs)); }} """ diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 25781a4ae8..8cdbe39e86 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -65,7 +65,6 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) ck_tile::tuple<>, CLayout, ck_tile::element_wise::PassThrough, - CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, M_Warp, @@ -81,8 +80,8 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -100,10 +99,8 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) << std::endl; } - float ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp index a4a8039288..f42135a0b5 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. #include @@ -208,7 +208,6 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config& DsLayout, ELayout, CDEElementWise, - UniversalGemmProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, GemmConfig::M_Warp, @@ -232,7 +231,7 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config& { grids = Kernel::GridSize(args.M, args.N, args.k_batch); } - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -279,15 +278,13 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config& ave_time = ck_tile::launch_kernel_time_mask( s, run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } else { - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } return ave_time; }; @@ -373,7 +370,7 @@ float reduce_stage2(const GemmSplitKHostArgs& args, const ck_tile::stream_config float ave_time = ck_tile::launch_kernel(s, - ck_tile::make_kernel( + ck_tile::make_kernel( Kernel{}, kGridSize, kBlockSize, diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp index 2057f1e4f5..0018db2c99 100644 --- a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp @@ -103,7 +103,6 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) DsLayout, ELayout, CDEElementWise, - UniversalGemmProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, GemmConfig::M_Warp, @@ -126,7 +125,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { grids = Kernel::GridSize(args.M, args.N, args.k_batch); } - constexpr dim3 blocks = Kernel::BlockSize(); + dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -172,15 +171,13 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) ave_time = ck_tile::launch_kernel_time_mask( s, run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } else { - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } return ave_time; }; diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index 149a8c2f0c..4e01710b4d 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -103,7 +103,6 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) DsLayout, ELayout, CDEElementWise, - UniversalGemmProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, GemmConfig::M_Warp, @@ -127,7 +126,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { grids = Kernel::GridSize(args.M, args.N, args.k_batch); } - constexpr dim3 blocks = Kernel::BlockSize(); + dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -173,15 +172,13 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) ave_time = ck_tile::launch_kernel_time_mask( s, run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } else { - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } return ave_time; }; diff --git a/example/ck_tile/04_img2col/image_to_column.cpp b/example/ck_tile/04_img2col/image_to_column.cpp index 299a2f3444..22b5d640d8 100644 --- a/example/ck_tile/04_img2col/image_to_column.cpp +++ b/example/ck_tile/04_img2col/image_to_column.cpp @@ -55,13 +55,12 @@ float image_to_column(const image_to_column_traits& traits, args.N * args.output_spatial_lengths[0] * args.output_spatial_lengths[1], args.filter_spatial_lengths[0] * args.filter_spatial_lengths[1] * args.C, args.G); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 2; float ave_time = ck_tile::launch_kernel( - stream_conf, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + stream_conf, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; } diff --git a/example/ck_tile/05_reduce/reduce.cpp b/example/ck_tile/05_reduce/reduce.cpp index cf816caa88..a110c2f98d 100644 --- a/example/ck_tile/05_reduce/reduce.cpp +++ b/example/ck_tile/05_reduce/reduce.cpp @@ -94,18 +94,18 @@ bool run(const ck_tile::ArgParser& arg_parser) throw std::runtime_error("Wrong! Arguments not supported!\n"); } - float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, - ck_tile::make_kernel( - Kernel{}, - kGridSize, - kBlockSize, - 0, - static_cast(x_buf.GetDeviceBuffer()), - static_cast(y_buf.GetDeviceBuffer()), - input_shape, - input_strides, - kept_dim, - reduce_dims)); + float ave_time = launch_kernel( + ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, + ck_tile::make_kernel(Kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(x_buf.GetDeviceBuffer()), + static_cast(y_buf.GetDeviceBuffer()), + input_shape, + input_strides, + kept_dim, + reduce_dims)); std::size_t num_btype = sizeof(XDataType) * N * C * H * W + sizeof(YDataType) * N * C; diff --git a/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp index 688f4f3d50..d486196fc3 100644 --- a/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp +++ b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -115,11 +115,12 @@ struct matrix_core_swizzle_kernel __host__ void operator()(const ck_tile::stream_config& s) const { - ck_tile::kentry<<>>(a); + ck_tile::kentry<1, kernel><<>>(a); } struct kernel { + static constexpr int kBlockSize = BLOCK_SIZE; __device__ static constexpr auto get_src_dist() { using namespace ck_tile; diff --git a/example/ck_tile/06_permute/permute.cpp b/example/ck_tile/06_permute/permute.cpp index 477ae370b9..aafece0f25 100644 --- a/example/ck_tile/06_permute/permute.cpp +++ b/example/ck_tile/06_permute/permute.cpp @@ -53,11 +53,11 @@ float permute(permute_traits t, permute_args a, const ck_tile::stream_config& s) auto kargs = Kernel::MakeKargs(a); - const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(a); + const dim3 blocks = Kernel::BlockSize(); - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + float ave_time = + ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(Kernel{}, grids, blocks, 0, kargs)); return ave_time; } @@ -69,11 +69,11 @@ float permute(permute_traits t, permute_args a, const ck_tile::stream_config& s) auto kargs = Kernel::MakeKargs(a); - const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(a); + const dim3 blocks = Kernel::BlockSize(); - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + float ave_time = + ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(Kernel{}, grids, blocks, 0, kargs)); return ave_time; } @@ -85,11 +85,11 @@ float permute(permute_traits t, permute_args a, const ck_tile::stream_config& s) auto kargs = Kernel::MakeKargs(a); - const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(a); + const dim3 blocks = Kernel::BlockSize(); - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + float ave_time = + ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(Kernel{}, grids, blocks, 0, kargs)); return ave_time; } diff --git a/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp b/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp index 249a307b81..c2bad24cfe 100644 --- a/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp +++ b/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp @@ -13,11 +13,11 @@ \ auto kargs = kernel::MakeKargs(a); \ \ - const dim3 grids = kernel::GridSize(a); \ - constexpr dim3 blocks = kernel::BlockSize(); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(); \ \ - float ave_time = ck_tile::launch_kernel( \ - s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs)); \ + float ave_time = \ + ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \ \ return ave_time; diff --git a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp index e0a71452ea..511efeeaec 100644 --- a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp @@ -138,12 +138,11 @@ bool run(const ck_tile::ArgParser& arg_parser) auto kargs = Kernel::MakeKargs(args); const dim3 grids = Kernel::GridSize(args); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto s = ck_tile::stream_config{nullptr, true, 0, warmup, repeat}; - ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + ck_tile::launch_kernel(s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); bool pass = true; diff --git a/example/ck_tile/10_rmsnorm2d/generate.py b/example/ck_tile/10_rmsnorm2d/generate.py index b0ba400af1..ea8dfdf9ce 100644 --- a/example/ck_tile/10_rmsnorm2d/generate.py +++ b/example/ck_tile/10_rmsnorm2d/generate.py @@ -249,7 +249,7 @@ float rmsnorm2d_fwd_(const S& s, A a) using Kernel = ck_tile::Rmsnorm2dFwd; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a); @@ -257,7 +257,7 @@ float rmsnorm2d_fwd_(const S& s, A a) std::cout << ", " << Kernel::GetName() << std::flush; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{{}}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{{}}, grids, blocks, 0, kargs)); }} """ diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp index 449bc17e04..ace5fe0c4f 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp @@ -136,12 +136,11 @@ bool run(const ck_tile::ArgParser& arg_parser) auto kargs = Kernel::MakeKargs(args); const dim3 grids = Kernel::GridSize(args); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto s = ck_tile::stream_config{nullptr, true, 0, warmup, repeat}; - ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + ck_tile::launch_kernel(s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); bool pass = true; diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp index 25b10e1dc4..d997596414 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp @@ -58,7 +58,7 @@ float add_rmsnorm2d_rdquant_fwd_(const S& s, A a) using Kernel = ck_tile::AddRmsnorm2dRdquantFwd; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a); @@ -66,5 +66,5 @@ float add_rmsnorm2d_rdquant_fwd_(const S& s, A a) std::cout << ", " << Kernel::GetName() << std::flush; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } diff --git a/example/ck_tile/12_smoothquant/example_smoothquant.cpp b/example/ck_tile/12_smoothquant/example_smoothquant.cpp index 5fcacacee8..e688947d71 100644 --- a/example/ck_tile/12_smoothquant/example_smoothquant.cpp +++ b/example/ck_tile/12_smoothquant/example_smoothquant.cpp @@ -126,12 +126,11 @@ bool run(const ck_tile::ArgParser& arg_parser) auto kargs = Kernel::MakeKargs(args); const dim3 grids = Kernel::GridSize(args); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto s = ck_tile::stream_config{nullptr, true, 1, warmup, repeat}; - ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + ck_tile::launch_kernel(s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); bool pass = true; diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp b/example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp index 555159566e..873a474afb 100644 --- a/example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp @@ -50,7 +50,7 @@ float smoothquant_(const S& s, A a) using Kernel = ck_tile::Smoothquant; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a); @@ -58,5 +58,5 @@ float smoothquant_(const S& s, A a) std::cout << ", " << Kernel::GetName() << std::flush; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp index a71c5e51a6..d614b8462a 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp @@ -209,7 +209,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() #define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \ @@ -227,7 +227,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() #if MOE_SORTING_SUPPORT_LARGE_EXPERT #define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_, local_token_) \ @@ -283,7 +283,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi const dim3 grids = kernel::GridSize(a); \ const dim3 blocks = kernel::BlockSize(a); \ const auto lds_size = kernel::GetSmemSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ }() #define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ @@ -334,15 +334,15 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi } \ } -#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_) \ - [&]() { \ - using problem_ = \ - ck_tile::MoeSortingClearWorkspaceProblem; \ - using kernel = ck_tile::MoeSortingClearWorkspaceKernel; \ - auto kargs = kernel::MakeKargs(a); \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ +#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_) \ + [&]() { \ + using problem_ = \ + ck_tile::MoeSortingClearWorkspaceProblem; \ + using kernel = ck_tile::MoeSortingClearWorkspaceKernel; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_instance_common.hpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_instance_common.hpp index 885d9ff7bf..607217ea52 100644 --- a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_instance_common.hpp +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_instance_common.hpp @@ -53,7 +53,7 @@ float moe_smoothquant_(const S& s, A a) using Kernel = ck_tile::MoeSmoothquant; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a); @@ -61,5 +61,5 @@ float moe_smoothquant_(const S& s, A a) std::cout << ", " << Kernel::GetName() << std::flush; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } diff --git a/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp index 6e54df9fde..9d1675386f 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp @@ -53,7 +53,7 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a) using f_kernel = ck_tile::FusedMoeGemmKernel; const dim3 grids = f_kernel::GridSize(a); - constexpr dim3 blocks = f_kernel::BlockSize(); + const dim3 blocks = f_kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; static int printed = 0; @@ -66,5 +66,5 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a) } return ck_tile::launch_kernel( - s, ck_tile::make_kernel(f_kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(f_kernel{}, grids, blocks, 0, kargs)); } diff --git a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp index 5f87393a0a..441aa84edf 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp @@ -213,7 +213,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() #define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \ @@ -231,7 +231,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() #if MOE_SORTING_SUPPORT_LARGE_EXPERT #define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_, local_token_) \ @@ -287,7 +287,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til const dim3 grids = kernel::GridSize(a); \ const dim3 blocks = kernel::BlockSize(a); \ const auto lds_size = kernel::GetSmemSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ }() #define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.cpp b/example/ck_tile/16_batched_gemm/batched_gemm.cpp index 9616abb800..09ba010e00 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.cpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.cpp @@ -142,7 +142,6 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre DsLayout, CLayout, CDEElementWise, - GemmPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, M_Warp, @@ -156,8 +155,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre using Kernel = ck_tile::BatchedGemmKernel; auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); + const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -176,7 +175,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre } ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index a821af0649..1e6844261f 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -82,7 +82,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, ck_tile::tuple<>, CLayout, ck_tile::element_wise::PassThrough, - GemmPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, GemmConfig::M_Warp, @@ -92,9 +91,9 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, GemmConfig::K_Warp_Tile, UniversalGemmProblem::TransposeC, memory_operation>>; - using Kernel = ck_tile::GroupedGemmKernel; - constexpr dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); if(s.log_level_ > 0) { @@ -105,7 +104,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, ave_time = ck_tile::launch_kernel(s, - ck_tile::make_kernel( + ck_tile::make_kernel( Kernel{}, grids, blocks, diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 50bf791207..93117e5b75 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -101,7 +101,6 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c DsLayout, ELayout, CDEElementWise, - CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, FlatmmConfig::M_Warp, @@ -119,8 +118,8 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -171,15 +170,13 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c ave_time = ck_tile::launch_kernel_time_mask( s, run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } else { - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } return ave_time; }; diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc index 8f39b07be5..013db6715d 100644 --- a/example/ck_tile/18_flatmm/run_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -42,7 +42,9 @@ auto shuffle_b(const ck_tile::HostTensor& t) assert(t.get_lengths().size() == 2); int n_ = t.get_lengths()[1]; int k_ = t.get_lengths()[0]; - constexpr int divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4; + + int divisor = ck_tile::is_wave32() ? (FlatmmConfig::N_Warp_Tile == 32 ? 1 : 2) + : (FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4); ck_tile::HostTensor t_view({n_ / FlatmmConfig::N_Warp_Tile, FlatmmConfig::N_Warp_Tile, k_ / FlatmmConfig::K_Warp_Tile, @@ -213,6 +215,16 @@ int run_flatmm_example_with_layouts(int argc, ck_tile::FillUniformDistribution{1.f, 1.f}(a_host); ck_tile::FillUniformDistribution{1.f, 1.f}(b_origin_host); } + else if(init_method == 3) + { + ck_tile::FillUniformDistribution{-.5f, .5f}(a_host); + ck_tile::FillUniformDistribution{1.f, 1.f}(b_origin_host); + } + else if(init_method == 4) + { + ck_tile::FillUniformDistribution{1.f, 1.f}(a_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(b_origin_host); + } else { a_host.SetZero(); diff --git a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp index d7bf2b5c42..fc52cb66cc 100644 --- a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp +++ b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp @@ -146,7 +146,6 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config& DsLayout, CLayout, CDEElementWise, - GemmPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, M_Warp, @@ -160,8 +159,8 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config& using Kernel = ck_tile::GemmKernelMultiD; auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -176,7 +175,7 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config& } ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp index 67db775e09..debbb6bc0c 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp @@ -78,7 +78,6 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args, typename GroupedConvTraitsType::ImplicitGemmDsLayout, ck_tile::tensor_layout::gemm::RowMajor, CDEElementWise, - CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, M_Warp, @@ -98,8 +97,8 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args, ConvEpilogue>; auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(kargs); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -123,7 +122,7 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args, float ave_time = ck_tile::launch_kernel_time_mask( s, Kernel::Preprocess(kargs, s), - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp index ce19c77bc1..6700970583 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp @@ -77,7 +77,6 @@ float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_til typename GroupedConvTraitsType::ImplicitGemmDsLayout, ck_tile::tensor_layout::gemm::RowMajor, CDEElementWise, - CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, M_Warp, @@ -97,8 +96,8 @@ float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_til ConvEpilogue>; auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(kargs); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -120,7 +119,7 @@ float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_til } float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/example/ck_tile/21_elementwise/elementwise_example.cpp b/example/ck_tile/21_elementwise/elementwise_example.cpp index 469345b46c..2cc539e117 100644 --- a/example/ck_tile/21_elementwise/elementwise_example.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example.cpp @@ -167,17 +167,17 @@ bool run(const ck_tile::ArgParser& arg_parser) } // 4. Run the kernel - float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, - ck_tile::make_kernel( - Kernel{}, - kGridSize, - kBlockSize, - 0, - input_size, - ck_tile::make_tuple(N, 1), // Input Stride - ck_tile::make_tuple(N, 1), // Output Stride - input_tensors, - static_cast(y_buf.GetDeviceBuffer()))); + float ave_time = launch_kernel( + ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, + ck_tile::make_kernel(Kernel{}, + kGridSize, + kBlockSize, + 0, + input_size, + ck_tile::make_tuple(N, 1), // Input Stride + ck_tile::make_tuple(N, 1), // Output Stride + input_tensors, + static_cast(y_buf.GetDeviceBuffer()))); std::cout << "Average time: " << ave_time << " ms" << std::endl; diff --git a/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp b/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp index 4a031265c9..7087d092a2 100644 --- a/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp @@ -113,7 +113,7 @@ bool run(const ck_tile::ArgParser& arg_parser) // Run the kernel float ave_time = launch_kernel( ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, - ck_tile::make_kernel( + ck_tile::make_kernel( Kernel{}, kGridSize, kBlockSize, diff --git a/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp b/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp index aff74ae250..28cdaf27b9 100644 --- a/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp @@ -112,17 +112,17 @@ bool run(const ck_tile::ArgParser& arg_parser) } // 4. Run the kernel - float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, - ck_tile::make_kernel( - Kernel{}, - kGridSize, - kBlockSize, - 0, // Shared memory - op_lengths, // Logical dimensions for the operation (M, N) - input_strides, // Strides for input tensor(s) - output_strides, // Strides for output tensor (N, M) - input_tensors, - static_cast(y_buf.GetDeviceBuffer()))); + float ave_time = launch_kernel( + ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, + ck_tile::make_kernel(Kernel{}, + kGridSize, + kBlockSize, + 0, // Shared memory + op_lengths, // Logical dimensions for the operation (M, N) + input_strides, // Strides for input tensor(s) + output_strides, // Strides for output tensor (N, M) + input_tensors, + static_cast(y_buf.GetDeviceBuffer()))); std::cout << "Average time: " << ave_time << " ms" << std::endl; diff --git a/example/ck_tile/21_elementwise/elementwise_example_unary.cpp b/example/ck_tile/21_elementwise/elementwise_example_unary.cpp index d83592a033..782d3da24d 100644 --- a/example/ck_tile/21_elementwise/elementwise_example_unary.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example_unary.cpp @@ -99,17 +99,17 @@ bool run(const ck_tile::ArgParser& arg_parser) } // 4. Run the kernel - float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, - ck_tile::make_kernel( - Kernel{}, - kGridSize, - kBlockSize, - 0, - input_size, - ck_tile::make_tuple(N, 1), // Input Stride - ck_tile::make_tuple(N, 1), // Output Stride - input_tensors, - static_cast(y_buf.GetDeviceBuffer()))); + float ave_time = launch_kernel( + ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, + ck_tile::make_kernel(Kernel{}, + kGridSize, + kBlockSize, + 0, + input_size, + ck_tile::make_tuple(N, 1), // Input Stride + ck_tile::make_tuple(N, 1), // Output Stride + input_tensors, + static_cast(y_buf.GetDeviceBuffer()))); std::cout << "Average time: " << ave_time << " ms" << std::endl; diff --git a/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp b/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp index 1f0f0b9bc1..931a9dfa3c 100644 --- a/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp +++ b/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp @@ -74,8 +74,8 @@ float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_con auto kargs = kernel::MakeKargs(a); - const dim3 grids = kernel::GridSize(a); - constexpr dim3 blocks = kernel::BlockSize(); + const dim3 grids = kernel::GridSize(a); + const dim3 blocks = kernel::BlockSize(); printf("Pipeline: %d\n", Config::kPipelineId); printf("Grid: x=%u y=%u z=%u\n", grids.x, grids.y, grids.z); @@ -96,8 +96,8 @@ float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_con printf("Launching Kernel...\n"); - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs)); + float ave_time = + ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); printf("Kernel finished...\n"); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp index 2ac08c7343..2ea8530cb2 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp @@ -96,7 +96,6 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s ck_tile::tuple<>, CLayout, ck_tile::element_wise::PassThrough, - CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, M_Warp, @@ -111,8 +110,8 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); if(args.k_batch != 1) { @@ -136,7 +135,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s } float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_preshuffle.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_preshuffle.cpp index f4f1aa98d3..4adc3df94b 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_aquant_preshuffle.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_preshuffle.cpp @@ -96,7 +96,6 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s ck_tile::tuple<>, CLayout, ck_tile::element_wise::PassThrough, - CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, M_Warp, @@ -111,8 +110,8 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); if(args.k_batch != 1) { @@ -136,7 +135,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s } float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/example/ck_tile/39_copy/copy_basic.cpp b/example/ck_tile/39_copy/copy_basic.cpp index 460036a641..3f36d7f4f0 100644 --- a/example/ck_tile/39_copy/copy_basic.cpp +++ b/example/ck_tile/39_copy/copy_basic.cpp @@ -99,16 +99,16 @@ bool run(const ck_tile::ArgParser& arg_parser) << ")" << std::endl; // Launch kernel - float ave_time = launch_kernel( - ck_tile::stream_config{nullptr, true, warmup, repeat, 1}, - ck_tile::make_kernel(Kernel{}, - kGridSize, - kBlockSize, - 0, - static_cast(x_buf.GetDeviceBuffer()), - static_cast(y_buf.GetDeviceBuffer()), - m, - n)); + float ave_time = + launch_kernel(ck_tile::stream_config{nullptr, true, warmup, repeat, 1}, + ck_tile::make_kernel<1>(Kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(x_buf.GetDeviceBuffer()), + static_cast(y_buf.GetDeviceBuffer()), + m, + n)); // Calculate and print performance metrics std::size_t num_btype = sizeof(XDataType) * m * n + sizeof(YDataType) * m * n; diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index ec5f49108e..234929d6e6 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -60,13 +60,30 @@ enum struct memory_operation_enum : std::uint16_t CK_TILE_HOST_DEVICE constexpr index_t get_warp_size() { -#if defined(__GFX9__) || (!defined(__HIP_DEVICE_COMPILE__) && !defined(CK_TILE_WAVE32_ENABLED)) +#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__) return 64; #else return 32; #endif } +CK_TILE_HOST bool is_wave32() +{ + hipDeviceProp_t props{}; + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) + { + return false; + } + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) + { + return false; + } + return props.major > 9; +} + CK_TILE_DEVICE index_t get_grid_size() { return gridDim.x; } CK_TILE_DEVICE index_t get_block_size() { return blockDim.x; } diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index f94065da2b..7b5b862cb1 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -274,12 +274,6 @@ #define CK_TILE_WA_ISSUE_2028 0 #endif -#ifndef CK_TILE_WAVE32_ENABLED -#if defined(__gfx11__) || defined(__gfx12__) -#define CK_TILE_WAVE32_ENABLED -#endif -#endif - // Y pointed to R, we don't see a valuable use case. // Will enforce encoding to check Y not pointed to R if set to zero #ifndef CK_TILE_ENC_SUPPORT_Y_TO_R diff --git a/include/ck_tile/host/kernel_launch.hpp b/include/ck_tile/host/kernel_launch.hpp index 91ac3d5a0b..368a0594c5 100644 --- a/include/ck_tile/host/kernel_launch.hpp +++ b/include/ck_tile/host/kernel_launch.hpp @@ -15,9 +15,9 @@ namespace ck_tile { -template +template #if CK_TILE_USE_LAUNCH_BOUNDS -__launch_bounds__(MaxThreadPerBlock, MinBlockPerCu) +__launch_bounds__(Kernel::kBlockSize, MinBlockPerCu) #endif __global__ void kentry(Args... args) { @@ -35,15 +35,11 @@ __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu) // // the "static __device__ operator()(some_arg)" is the entry point of KernelImpl // -template +template CK_TILE_HOST auto make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args) { - const auto kernel = kentry; - + const auto kernel = kentry; return [=](const stream_config& s) { kernel<<>>(args...); }; diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp index f06910db3d..c7717f08cd 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -53,6 +53,7 @@ struct AddRmsnorm2dRdquantFwd static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N; static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N; + static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; diff --git a/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp b/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp index a4150e8d84..b0f48f6c5b 100644 --- a/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp +++ b/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp @@ -34,6 +34,8 @@ struct BatchedTransposeKernel using Type = typename Problem::DataType; + static constexpr index_t kBlockSize = Problem::kBlockSize; + struct BatchedTransposeKargs { const void* p_input; diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp index 45803ae2da..b791bf9727 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -20,11 +20,10 @@ struct BatchedTransposeLdsProblem static constexpr index_t kRowWarps_ = NumWarps::at(number<0>{}); static constexpr index_t kColWarps_ = NumWarps::at(number<1>{}); - static constexpr index_t kBlockSize_ = get_warp_size() * kRowWarps_ * kColWarps_; static constexpr index_t kRowPerBlock_ = BlockTile::at(number<0>{}); static constexpr index_t kColPerBlock_ = BlockTile::at(number<1>{}); - static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kBlockSize = get_warp_size() * kRowWarps_ * kColWarps_; // warps per block static constexpr index_t kLeadNumWarps = kColWarps_; static constexpr index_t kSecondNumWarps = kRowWarps_; diff --git a/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp b/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp index 103468c5fa..2ec9414f42 100644 --- a/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp +++ b/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp @@ -20,6 +20,8 @@ struct ElementWiseKernel using YDataType = ck_tile::remove_cvref_t; using ElementWiseOperation = ck_tile::remove_cvref_t; + static constexpr index_t kBlockSize = Problem::BlockShape::kBlockSize; + template CK_TILE_DEVICE void operator()(Dims lens, Dims input_strides, diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index f773de9e7e..1d0a4c42f4 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -17,7 +17,6 @@ template ; using ELayout = remove_cvref_t; using CDElementwise = remove_cvref_t; - static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size(); static constexpr index_t kMPerBlock = kM_; static constexpr index_t kNPerBlock = kN_; static constexpr index_t MWave = MWave_; diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index 76df056ea6..20ca976590 100644 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -91,13 +91,13 @@ struct FlatmmKernel using FlatmmPipeline = remove_cvref_t; using BlockGemmShape = remove_cvref_t; // TileFlatmmShape - using EpiloguePipeline = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using ELayout = remove_cvref_t; - using DsLayout = remove_cvref_t; - using DsDataType = remove_cvref_t; - static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize; + using EpiloguePipeline = remove_cvref_t; + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using ELayout = remove_cvref_t; + using DsLayout = remove_cvref_t; + using DsDataType = remove_cvref_t; + static constexpr index_t kBlockSize = FlatmmPipeline::BlockSize; using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; @@ -127,7 +127,7 @@ struct FlatmmKernel return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } CK_TILE_HOST static constexpr KernelArgs MakeKernelArgs(const FlatmmHostArgs& hostArgs) diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 20783ea8bf..3ca79fc46e 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -237,15 +237,16 @@ struct UniversalFlatmmPipelineAgBgCrPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad() { - using TileShape = typename Problem::BlockGemmShape; + using TileShape = typename Problem::BlockGemmShape; + constexpr index_t scale = get_warp_size() == 32 ? 2 : 1; if constexpr(TileShape::WarpTile::at(I1) == 32) { - return TileShape::WarpTile::at(I2) / 2; + return TileShape::WarpTile::at(I2) * scale / 2; } else { static_assert(TileShape::WarpTile::at(I1) == 16); - return TileShape::WarpTile::at(I2) / 4; + return TileShape::WarpTile::at(I2) * scale / 4; } } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index 0d0959ba27..2850ce3379 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -24,9 +24,10 @@ namespace ck_tile { template struct FmhaBatchPrefillWithPagedKVCacheKernel { - using FmhaPipeline = ck_tile::remove_cvref_t; - using EpiloguePipeline = ck_tile::remove_cvref_t; - static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + using FmhaPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; static_assert(kBlockPerCu > 0); static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index 5129f83532..81075d0ec6 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -16,6 +16,7 @@ struct FmhaFwdAppendKVKernel using FmhaPipeline = ck_tile::remove_cvref_t; static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 5b3d38d3e7..6d35afaa26 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -25,9 +25,10 @@ namespace ck_tile { template struct FmhaFwdKernel { - using FmhaPipeline = ck_tile::remove_cvref_t; - using EpiloguePipeline = ck_tile::remove_cvref_t; - static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + using FmhaPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; static_assert(kBlockPerCu > 0); static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp index d8cd006c60..9a3e8ac304 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp @@ -30,6 +30,7 @@ struct FmhaFwdPagedKVKernel using EpiloguePipeline = ck_tile::remove_cvref_t; static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp index 99ee912db9..ee1236d465 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,6 +14,7 @@ struct FmhaFwdSplitKVCombineKernel static constexpr index_t kNumWarps = FmhaPipeline::kNumWarps; static constexpr index_t kBlockSize = FmhaPipeline::kBlockSize; static constexpr index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); static constexpr index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 501aa26667..c50537f3fe 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -26,6 +26,7 @@ struct FmhaFwdSplitKVKernel using EpiloguePipeline = ck_tile::remove_cvref_t; static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index a5f9f31d6a..faeb5cf6b3 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -213,7 +213,7 @@ struct MoeSortingKernel using Hargs = MoeSortingHostArgs; - static constexpr index_t BLOCK_SIZE = 256; + static constexpr index_t kBlockSize = 256; static constexpr index_t OCCUPANCY = 2; // hard coded struct Kargs @@ -487,8 +487,8 @@ struct MoeSortingKernel vector_type* p_buf = reinterpret_cast(buf); auto zero_ = vector_type{0}; - for(long_index_t i = (blockIdx.x - 1) * BLOCK_SIZE + threadIdx.x; i < total_elems; - i += (gridDim.x - 1) * BLOCK_SIZE) + for(long_index_t i = (blockIdx.x - 1) * kBlockSize + threadIdx.x; i < total_elems; + i += (gridDim.x - 1) * kBlockSize) { p_buf[i] = zero_; } @@ -1419,7 +1419,7 @@ template struct MoeSortingClearWorkspaceKernel { using Problem = remove_cvref_t; - static constexpr index_t BLOCK_SIZE = Problem::BlockSize; + static constexpr index_t kBlockSize = Problem::BlockSize; static constexpr index_t OCCUPANCY = Problem::Occu; using Hargs = MoeSortingHostArgs; @@ -1461,7 +1461,7 @@ struct MoeSortingClearWorkspaceKernel CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; } - CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } + CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); } // in byte CK_TILE_HOST static constexpr auto GetSmemSize() { return 0; } @@ -1499,8 +1499,8 @@ struct MoeSortingClearWorkspaceKernel vector_type* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); auto zero_ = vector_type{0}; - for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elems; - i += gridDim.x * BLOCK_SIZE) + for(index_t i = blockIdx.x * kBlockSize + threadIdx.x; i < total_elems; + i += gridDim.x * kBlockSize) { p_expert_mesh[i] = zero_; } @@ -1560,7 +1560,7 @@ struct MoeSortingMultiPhaseKernel_P0 using WeightType = typename Problem::WeightType; using MeshType = typename Problem::MeshType; - static constexpr index_t BLOCK_SIZE = 256; + static constexpr index_t kBlockSize = 256; static constexpr index_t OCCUPANCY = 2; // hard coded typedef MoeSortingHostArgs MoeSortingKargs; @@ -1604,7 +1604,7 @@ struct MoeSortingMultiPhaseKernel_P0 CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; } - CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } + CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); } // in byte CK_TILE_HOST static constexpr auto GetSmemSize() { return 0; } @@ -1647,8 +1647,8 @@ struct MoeSortingMultiPhaseKernel_P0 index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile; #pragma unroll Problem::SubTokenTile - for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem; - i += gridDim.x * BLOCK_SIZE) + for(index_t i = blockIdx.x * kBlockSize + threadIdx.x; i < total_elem; + i += gridDim.x * kBlockSize) { auto x = p_topk_ids[i]; static_for<0, Problem::SubTokenTile, 1>{}([&](auto j) { @@ -1678,7 +1678,7 @@ struct MoeSortingMultiPhaseKernel_P1 using WeightType = typename Problem::WeightType; using MeshType = typename Problem::MeshType; - static constexpr index_t BLOCK_SIZE = 256; + static constexpr index_t kBlockSize = 256; static constexpr index_t OCCUPANCY = 2; // hard coded typedef MoeSortingHostArgs MoeSortingKargs; @@ -1709,12 +1709,12 @@ struct MoeSortingMultiPhaseKernel_P1 CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { return dim3(h.num_experts); } - CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } + CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); } // in byte CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() { - return BLOCK_SIZE / get_warp_size() * sizeof(IndexType); + return kBlockSize / get_warp_size() * sizeof(IndexType); } CK_TILE_DEVICE void operator()(Kargs kargs) const @@ -1756,7 +1756,7 @@ struct MoeSortingMultiPhaseKernel_P1 r_t* p_expert_mesh = reinterpret_cast( reinterpret_cast(kargs.p_expert_mesh) + eid * mesh_stride); - int loops = (mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE; + int loops = (mesh_stride / index_pack + kBlockSize - 1) / kBlockSize; if constexpr(Problem::LocalExpertMasking) { @@ -1768,7 +1768,7 @@ struct MoeSortingMultiPhaseKernel_P1 index_t cnt = 0; // per-wave cnt for(int i = 0; i < loops; i++) { - int position = i * BLOCK_SIZE + threadIdx.x; + int position = i * kBlockSize + threadIdx.x; r_t v{0}; if(position < (mesh_stride / index_pack)) v = p_expert_mesh[position]; @@ -1792,7 +1792,7 @@ struct MoeSortingMultiPhaseKernel_P1 if(threadIdx.x == 0) { index_t c = 0; - for(auto i = 0; i < (BLOCK_SIZE / get_warp_size()); i++) + for(auto i = 0; i < (kBlockSize / get_warp_size()); i++) { c += s[i]; } @@ -1811,7 +1811,7 @@ struct MoeSortingMultiPhaseKernel_P01 using WeightType = typename Problem::WeightType; using MeshType = typename Problem::MeshType; - static constexpr index_t BLOCK_SIZE = 256; + static constexpr index_t kBlockSize = 256; static constexpr index_t OCCUPANCY = 2; // hard coded typedef MoeSortingHostArgs MoeSortingKargs; @@ -1878,12 +1878,12 @@ struct MoeSortingMultiPhaseKernel_P01 CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; } - CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } + CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); } CK_TILE_HOST static constexpr auto WGCounts(const Hargs& h) { index_t total_elem = h.tokens * h.topk / Problem::SubTokenTile; - index_t elem_cnt = (total_elem + BLOCK_SIZE - 1) / BLOCK_SIZE; + index_t elem_cnt = (total_elem + kBlockSize - 1) / kBlockSize; // no more than grid_size return min(elem_cnt, GridSize(h)); @@ -1892,7 +1892,7 @@ struct MoeSortingMultiPhaseKernel_P01 // in byte CK_TILE_HOST static constexpr auto GetSmemSize() { - return BLOCK_SIZE / get_warp_size() * sizeof(IndexType); + return kBlockSize / get_warp_size() * sizeof(IndexType); } CK_TILE_DEVICE void operator()(Kargs kargs) const @@ -1921,7 +1921,7 @@ struct MoeSortingMultiPhaseKernel_P01 if constexpr(Problem::LocalToken) { index_t total_elem = rounded_tokens * kargs.topk / Problem::SubTokenTile; - index_t elem_cnt = (total_elem + BLOCK_SIZE - 1) / BLOCK_SIZE; + index_t elem_cnt = (total_elem + kBlockSize - 1) / kBlockSize; // no more than grid_size return min(elem_cnt, kargs.wg_count); @@ -1940,8 +1940,8 @@ struct MoeSortingMultiPhaseKernel_P01 index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile; #pragma unroll Problem::SubTokenTile - for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem; - i += BLOCK_SIZE * gridDim.x) + for(index_t i = blockIdx.x * kBlockSize + threadIdx.x; i < total_elem; + i += kBlockSize * gridDim.x) { auto x = p_topk_ids[i]; static_for<0, Problem::SubTokenTile, 1>{}([&](auto j) { @@ -1996,7 +1996,7 @@ struct MoeSortingMultiPhaseKernel_P01 auto f_sum = [](auto x_, auto y_) { return x_ + y_; }; - int loops = (kargs.mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE; + int loops = (kargs.mesh_stride / index_pack + kBlockSize - 1) / kBlockSize; if constexpr(Problem::LocalExpertMasking) { @@ -2008,7 +2008,7 @@ struct MoeSortingMultiPhaseKernel_P01 index_t cnt = 0; // per-wave cnt for(int i = 0; i < loops; i++) { - int position = i * BLOCK_SIZE + threadIdx.x; + int position = i * kBlockSize + threadIdx.x; r_t v{0}; if(position < (kargs.mesh_stride / index_pack)) v = p_expert_mesh[position]; @@ -2033,7 +2033,7 @@ struct MoeSortingMultiPhaseKernel_P01 if(threadIdx.x == 0) { index_t c = 0; - for(auto i = 0; i < (BLOCK_SIZE / get_warp_size()); i++) + for(auto i = 0; i < (kBlockSize / get_warp_size()); i++) { c += s[i]; } @@ -2055,7 +2055,7 @@ struct MoeSortingMultiPhaseKernel_P2 using WeightType = typename Problem::WeightType; using MeshType = typename Problem::MeshType; - static constexpr index_t BLOCK_SIZE = 256; + static constexpr index_t kBlockSize = 256; static constexpr index_t OCCUPANCY = 2; // hard coded typedef MoeSortingHostArgs MoeSortingKargs; @@ -2123,17 +2123,17 @@ struct MoeSortingMultiPhaseKernel_P2 return dim3(h.num_experts + get_num_cu() * OCCUPANCY); #else // use 1 block to cumsum - return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16)); + return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, kBlockSize * 16)); #endif } - CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } + CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); } // in byte CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() { - // return 2 * BLOCK_SIZE * sizeof(IndexType); - return (4 + 2 * BLOCK_SIZE / get_warp_size()) * sizeof(IndexType); + // return 2 * kBlockSize * sizeof(IndexType); + return (4 + 2 * kBlockSize / get_warp_size()) * sizeof(IndexType); } // reduce single pixel within a wave @@ -2142,7 +2142,7 @@ struct MoeSortingMultiPhaseKernel_P2 if(blockIdx.x > 0) { #if MOE_SORTING_FMOE_2D_BUF - impl::moe_buf_set_zero_kernel_2d(kargs.p_moe_buf, + impl::moe_buf_set_zero_kernel_2d(kargs.p_moe_buf, kargs.tokens, kargs.moe_buf_interm_dim, kargs.moe_buf_elem_bytes, @@ -2150,7 +2150,7 @@ struct MoeSortingMultiPhaseKernel_P2 gridDim.x - 1); return; #else - impl::moe_buf_set_zero_kernel( + impl::moe_buf_set_zero_kernel( reinterpret_cast(kargs.p_moe_buf), kargs.moe_buf_bytes, blockIdx.x - 1); @@ -2167,7 +2167,7 @@ struct MoeSortingMultiPhaseKernel_P2 reinterpret_cast(kargs.p_total_tokens_post_pad); IndexType* p_sorted_expert_ids = reinterpret_cast(kargs.p_sorted_expert_ids); - const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE; + const index_t loops = (kargs.num_experts + kBlockSize - 1) / kBlockSize; index_t wave_id = threadIdx.x / get_warp_size(); index_t lane_id = threadIdx.x % get_warp_size(); @@ -2176,7 +2176,7 @@ struct MoeSortingMultiPhaseKernel_P2 for(index_t i = 0; i < loops; i++) { - index_t position = i * BLOCK_SIZE + threadIdx.x; + index_t position = i * kBlockSize + threadIdx.x; IndexType a_ = 0; // token count for a expert IndexType b_ = 0; // mask for a expert if(position < kargs.num_experts) @@ -2221,15 +2221,15 @@ struct MoeSortingMultiPhaseKernel_P2 if(lane_id == get_warp_size() - 1) { s[4 + wave_id] = cumsum_a; - s[4 + wave_id + BLOCK_SIZE / get_warp_size()] = cumsum_b; + s[4 + wave_id + kBlockSize / get_warp_size()] = cumsum_b; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) { + static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) { IndexType prev_a = s[4 + i_w]; - IndexType prev_b = s[4 + i_w + BLOCK_SIZE / get_warp_size()]; + IndexType prev_b = s[4 + i_w + kBlockSize / get_warp_size()]; prev_a = wave_id > i_w ? prev_a : 0; // mask out prev_b = wave_id > i_w ? prev_b : 0; // mask out cumsum_a += prev_a; @@ -2240,7 +2240,7 @@ struct MoeSortingMultiPhaseKernel_P2 cumsum_a += prev_cumsum_a; cumsum_b += prev_cumsum_b; - if(threadIdx.x == BLOCK_SIZE - 1) + if(threadIdx.x == kBlockSize - 1) { s[2] = cumsum_a; // store the last cumsum s[3] = cumsum_b; @@ -2297,7 +2297,7 @@ struct MoeSortingMultiPhaseKernel_P3 using WeightType = typename Problem::WeightType; using MeshType = typename Problem::MeshType; - static constexpr index_t BLOCK_SIZE = 256; + static constexpr index_t kBlockSize = 256; static constexpr index_t OCCUPANCY = 2; // hard coded typedef MoeSortingHostArgs MoeSortingKargs; @@ -2341,12 +2341,12 @@ struct MoeSortingMultiPhaseKernel_P3 CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { return dim3(h.num_experts); } - CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } + CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); } // in byte CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() { - return (4 + BLOCK_SIZE / get_warp_size()) * sizeof(IndexType); + return (4 + kBlockSize / get_warp_size()) * sizeof(IndexType); } CK_TILE_DEVICE void operator()(Kargs kargs) const @@ -2391,11 +2391,11 @@ struct MoeSortingMultiPhaseKernel_P3 } // cumsum one by one - int loops = (kargs.mesh_stride + BLOCK_SIZE - 1) / BLOCK_SIZE; + int loops = (kargs.mesh_stride + kBlockSize - 1) / kBlockSize; int prev_cumsum = 0; for(int i = 0; i < loops; i++) { - int i_token = i * BLOCK_SIZE + threadIdx.x; + int i_token = i * kBlockSize + threadIdx.x; IndexType x = 0; if(i_token < tokens) { @@ -2414,13 +2414,13 @@ struct MoeSortingMultiPhaseKernel_P3 __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) { + static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) { IndexType prev = s[4 + i_w]; prev = wave_id > i_w ? prev : 0; // mask out cumsum += prev; }); cumsum += prev_cumsum; // add previous round cumsum - if(threadIdx.x == BLOCK_SIZE - 1) + if(threadIdx.x == kBlockSize - 1) { s[0] = cumsum; } @@ -2441,7 +2441,7 @@ struct MoeSortingMultiPhaseKernel_P3 } } - for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE) + for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += kBlockSize) { #if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(tokens, kargs.topk_mdiv.divisor); @@ -2457,9 +2457,9 @@ namespace impl { // we use dynamic LDS size here CK_TILE_HOST constexpr auto moe_sorting_get_smem_size_p23(int num_experts_) { - constexpr index_t BLOCK_SIZE = 256; // hardcoded 256 + constexpr index_t kBlockSize = 256; // hardcoded 256 const index_t expert_cumsum_elem = num_experts_ + 1; - return (4 + 2 * BLOCK_SIZE / get_warp_size() + expert_cumsum_elem) * sizeof(int); + return (4 + 2 * kBlockSize / get_warp_size() + expert_cumsum_elem) * sizeof(int); } } // namespace impl @@ -2473,7 +2473,7 @@ struct MoeSortingMultiPhaseKernel_P23 using WeightType = typename Problem::WeightType; using MeshType = typename Problem::MeshType; - static constexpr index_t BLOCK_SIZE = 256; + static constexpr index_t kBlockSize = 256; static constexpr index_t OCCUPANCY = 2; // hard coded typedef MoeSortingHostArgs MoeSortingKargs; @@ -2563,18 +2563,18 @@ struct MoeSortingMultiPhaseKernel_P23 return dim3(h.num_experts + get_num_cu() * OCCUPANCY); #else // use 1 block to cumsum - // return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16)); - return dim3(h.num_experts + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16)); + // return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, kBlockSize * 16)); + return dim3(h.num_experts + ck_tile::integer_divide_ceil(h.moe_buf_bytes, kBlockSize * 16)); #endif } - CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } + CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); } // only use this at host ! CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h) { const auto smem_23 = impl::moe_sorting_get_smem_size_p23(h.num_experts); - const auto smem_sf = BLOCK_SIZE * 4 * sizeof(IndexType); + const auto smem_sf = kBlockSize * 4 * sizeof(IndexType); return max(smem_23, smem_sf); } @@ -2595,7 +2595,7 @@ struct MoeSortingMultiPhaseKernel_P23 if(static_cast(blockIdx.x) >= kargs.num_experts) { #if MOE_SORTING_FMOE_2D_BUF - impl::moe_buf_set_zero_kernel_2d(kargs.p_moe_buf, + impl::moe_buf_set_zero_kernel_2d(kargs.p_moe_buf, tokens, kargs.moe_buf_interm_dim, kargs.moe_buf_elem_bytes, @@ -2603,7 +2603,7 @@ struct MoeSortingMultiPhaseKernel_P23 gridDim.x - kargs.num_experts); return; #else - impl::moe_buf_set_zero_kernel( + impl::moe_buf_set_zero_kernel( reinterpret_cast(kargs.p_moe_buf), kargs.moe_buf_bytes, blockIdx.x - kargs.num_experts); @@ -2618,13 +2618,13 @@ struct MoeSortingMultiPhaseKernel_P23 const IndexType* p_local_expert_mask = static_cast(kargs.p_local_expert_mask); IndexType* p_expert_cumsum = reinterpret_cast(kargs.p_expert_cumsum); - IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / get_warp_size(); + IndexType* p_expert_cumsum_smem = s + 4 + 2 * kBlockSize / get_warp_size(); IndexType* p_total_tokens_post_pad = reinterpret_cast(kargs.p_total_tokens_post_pad); IndexType* p_sorted_expert_ids = reinterpret_cast(kargs.p_sorted_expert_ids); - const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE; + const index_t loops = (kargs.num_experts + kBlockSize - 1) / kBlockSize; index_t wave_id = threadIdx.x / get_warp_size(); index_t lane_id = threadIdx.x % get_warp_size(); @@ -2633,7 +2633,7 @@ struct MoeSortingMultiPhaseKernel_P23 for(index_t i = 0; i < loops; i++) { - index_t position = i * BLOCK_SIZE + threadIdx.x; + index_t position = i * kBlockSize + threadIdx.x; IndexType a_ = 0; // token count for a expert IndexType b_ = 0; // mask for a expert if(position < kargs.num_experts) @@ -2678,15 +2678,15 @@ struct MoeSortingMultiPhaseKernel_P23 if(lane_id == get_warp_size() - 1) { s[4 + wave_id] = cumsum_a; - s[4 + wave_id + BLOCK_SIZE / get_warp_size()] = cumsum_b; + s[4 + wave_id + kBlockSize / get_warp_size()] = cumsum_b; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) { + static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) { IndexType prev_a = s[4 + i_w]; - IndexType prev_b = s[4 + i_w + BLOCK_SIZE / get_warp_size()]; + IndexType prev_b = s[4 + i_w + kBlockSize / get_warp_size()]; prev_a = wave_id > i_w ? prev_a : 0; // mask out prev_b = wave_id > i_w ? prev_b : 0; // mask out cumsum_a += prev_a; @@ -2697,7 +2697,7 @@ struct MoeSortingMultiPhaseKernel_P23 cumsum_a += prev_cumsum_a; cumsum_b += prev_cumsum_b; - if(threadIdx.x == BLOCK_SIZE - 1) + if(threadIdx.x == kBlockSize - 1) { s[2] = cumsum_a; // store the last cumsum s[3] = cumsum_b; @@ -2758,7 +2758,7 @@ struct MoeSortingMultiPhaseKernel_P23 IndexType* s = reinterpret_cast(smem); MeshType* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); IndexType* p_sorted_token_ids = reinterpret_cast(kargs.p_sorted_token_ids); - IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / get_warp_size(); + IndexType* p_expert_cumsum_smem = s + 4 + 2 * kBlockSize / get_warp_size(); const WeightType* p_weights = static_cast(kargs.p_weights); WeightType* p_sorted_weights = reinterpret_cast(kargs.p_sorted_weights); @@ -2795,13 +2795,13 @@ struct MoeSortingMultiPhaseKernel_P23 constexpr index_t index_pack = Problem::SubTokenTile; // always packed using r_t = ext_vector_t; // always use int32x4 using d_t = ext_vector_t; - int loops = (mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE; + int loops = (mesh_stride / index_pack + kBlockSize - 1) / kBlockSize; int prev_cumsum = 0; for(int i = 0; i < loops; i++) { - int i_token_pack = i * BLOCK_SIZE + threadIdx.x; + int i_token_pack = i * kBlockSize + threadIdx.x; r_t x_v = 0; if(i_token_pack < (tokens + index_pack - 1) / index_pack) { @@ -2819,7 +2819,7 @@ struct MoeSortingMultiPhaseKernel_P23 static_for<0, index_pack, 1>{}([&](auto j_) { constexpr auto j = j_.value; - x_r[j] = reinterpret_cast(s)[threadIdx.x + j * BLOCK_SIZE]; + x_r[j] = reinterpret_cast(s)[threadIdx.x + j * kBlockSize]; }); } #else @@ -2830,7 +2830,7 @@ struct MoeSortingMultiPhaseKernel_P23 #pragma unroll for(int j = 0; j < index_pack / 2; j++) { - int i_token = i * BLOCK_SIZE * index_pack + threadIdx.x + j * BLOCK_SIZE; + int i_token = i * kBlockSize * index_pack + threadIdx.x + j * kBlockSize; index_t x = x_d[j]; int i_topk = x - 1; // topk of this token int i_show = x != 0 ? 1 : 0; // has this token or not @@ -2845,13 +2845,13 @@ struct MoeSortingMultiPhaseKernel_P23 __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) { + static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) { IndexType prev = s[4 + i_w]; prev = wave_id > i_w ? prev : 0; // mask out cumsum += prev; }); cumsum += prev_cumsum; // add previous round cumsum - if(threadIdx.x == BLOCK_SIZE - 1) + if(threadIdx.x == kBlockSize - 1) { s[0] = cumsum; } @@ -2896,13 +2896,13 @@ struct MoeSortingMultiPhaseKernel_P23 __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) { + static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) { IndexType prev = s[4 + i_w]; prev = wave_id > i_w ? prev : 0; // mask out cumsum += prev; }); cumsum += prev_cumsum; // add previous round cumsum - if(threadIdx.x == BLOCK_SIZE - 1) + if(threadIdx.x == kBlockSize - 1) { s[0] = cumsum; } @@ -2912,10 +2912,10 @@ struct MoeSortingMultiPhaseKernel_P23 int position = cumsum - cumsum_store; static_for<0, index_pack, 1>{}([&](auto j_) { constexpr auto j = j_.value; - // int i_token = i * BLOCK_SIZE * index_pack + threadIdx.x + j * - // BLOCK_SIZE; + // int i_token = i * kBlockSize * index_pack + threadIdx.x + j * + // kBlockSize; int i_token = - i * BLOCK_SIZE * index_pack + threadIdx.x * index_pack + j; + i * kBlockSize * index_pack + threadIdx.x * index_pack + j; if(i_show[j]) { @@ -2932,7 +2932,7 @@ struct MoeSortingMultiPhaseKernel_P23 }); #if 0 - int i_token = i * BLOCK_SIZE * index_pack + threadIdx.x * 2 + j * BLOCK_SIZE * 2; + int i_token = i * kBlockSize * index_pack + threadIdx.x * 2 + j * kBlockSize * 2; index_t x = x_d[j]; index_t x0 = static_cast(x & 0xffff); index_t x1 = static_cast(x >> 16); @@ -2951,13 +2951,13 @@ struct MoeSortingMultiPhaseKernel_P23 __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) { + static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) { IndexType prev = s[4 + i_w]; prev = wave_id > i_w ? prev : 0; // mask out cumsum += prev; }); cumsum += prev_cumsum; // add previous round cumsum - if(threadIdx.x == BLOCK_SIZE - 1) + if(threadIdx.x == kBlockSize - 1) { s[0] = cumsum; } @@ -2996,7 +2996,7 @@ struct MoeSortingMultiPhaseKernel_P23 } } - for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE) + for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += kBlockSize) { #if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(tokens, kargs.topk_mdiv.divisor); diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index 9c1ce73eac..fcfbf9635f 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -64,6 +64,7 @@ struct BatchedGemmKernel /// functions. using UniversalGemmKernel = UniversalGemmKernel; + static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize; using TilePartitioner = remove_cvref_t; using GemmPipeline = remove_cvref_t; @@ -121,9 +122,16 @@ struct BatchedGemmKernel return dim3(TilePartitioner::GridSize(M, N), batch_count, KBatch); } - CK_TILE_HOST static constexpr auto BlockSize() -> dim3 + CK_TILE_HOST static auto BlockSize() -> dim3 { - return dim3(UniversalGemmKernel::KernelBlockSize); + if(ck_tile::is_wave32()) + { + return dim3(UniversalGemmKernel::kBlockSize / 2); + } + else + { + return dim3(UniversalGemmKernel::kBlockSize); + } } CK_TILE_HOST static constexpr BatchedGemmKernelArgs diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 079d3972d1..e37b4f36d4 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -113,6 +113,7 @@ struct GemmKernel static constexpr index_t NumATensor = 1; static constexpr index_t NumBTensor = 1; + static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize; CK_TILE_HOST static auto GetName() -> const std::string { diff --git a/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp index 34340008d4..34c4e72b22 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp @@ -86,6 +86,7 @@ struct GemmKernelMultiD /// functions. using UniversalGemmKernel = UniversalGemmKernel; + static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize; using TilePartitioner = remove_cvref_t; using GemmPipeline = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 477a87d42f..c35435ee5e 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -128,7 +128,7 @@ struct GroupedGemmKernel using OffsetTile1DPartitioner = OffsettedTile1DPartitioner; using Kernel = GroupedGemmKernel; - static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; + static constexpr index_t kBlockSize = GemmPipeline::BlockSize; static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel; [[nodiscard]] CK_TILE_HOST static const std::string GetName() @@ -155,7 +155,7 @@ struct GroupedGemmKernel return group_count * sizeof(GemmTransKernelArg); } - CK_TILE_HOST static constexpr auto BlockSize() -> dim3 { return dim3(KernelBlockSize); } + CK_TILE_HOST static constexpr auto BlockSize() -> dim3 { return dim3(kBlockSize); } /** * @brief Get the maximum occupancy grid size for the persistent kernel on the current device. @@ -166,10 +166,10 @@ struct GroupedGemmKernel CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 { using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*; - const auto kernel = kentry; + const auto kernel = kentry<1, Kernel, ConstantPointer, index_t>; int occupancy; HIP_CHECK_ERROR( - hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0)); + hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0)); const int grid_size = get_available_compute_units(s) * occupancy; return dim3(grid_size, 1, 1); } diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index ec1cc2ddb4..8117d65758 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -196,7 +196,7 @@ struct UniversalGemmKernel using ELayout = remove_cvref_t; using EDataType = remove_cvref_t; - static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; + static constexpr index_t kBlockSize = GemmPipeline::BlockSize; // Get the persistent kernel if the pipeline has it available struct has_persistent_kernel @@ -275,15 +275,26 @@ struct UniversalGemmKernel CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 { using Kernel = UniversalGemmKernel; - const auto kernel = kentry; + const auto kernel = kentry<1, Kernel, KernelArgs>; int occupancy; hip_check_error( - hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0)); + hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize().x, 0)); + const int grid_size = get_available_compute_units(s) * occupancy; return dim3(grid_size, 1, 1); } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + CK_TILE_HOST static auto BlockSize() + { + if(ck_tile::is_wave32()) + { + return dim3(kBlockSize / 2); + } + else + { + return dim3(kBlockSize); + } + } CK_TILE_HOST static constexpr KernelArgs MakeKernelArgs(const UniversalGemmHostArgs& hostArgs) @@ -371,7 +382,9 @@ struct UniversalGemmKernel } } - bool AsTesnorIsValid = {true}; + const auto vectorSizeA = is_wave32() ? GemmPipeline::template GetVectorSizeA() + : GemmPipeline::template GetVectorSizeA(); + bool AsTesnorIsValid = {true}; static_for<0, NumATensor, 1>{}([&](auto index) { using AiLayout = remove_cvref_t>; if constexpr(std::is_same_v) @@ -387,7 +400,7 @@ struct UniversalGemmKernel } AsTesnorIsValid = false; } - if(kargs.K % GemmPipeline::GetVectorSizeA() != 0) + if(kargs.K % vectorSizeA != 0) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { @@ -407,7 +420,7 @@ struct UniversalGemmKernel } AsTesnorIsValid = false; } - if(kargs.M % GemmPipeline::GetVectorSizeA() != 0) + if(kargs.M % vectorSizeA != 0) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { @@ -418,7 +431,9 @@ struct UniversalGemmKernel } }); - bool BsTesnorIsValid = {true}; + bool BsTesnorIsValid = {true}; + const auto vectorSizeB = is_wave32() ? GemmPipeline::template GetVectorSizeB() + : GemmPipeline::template GetVectorSizeB(); static_for<0, NumBTensor, 1>{}([&](auto index) { using BiLayout = remove_cvref_t>; if constexpr(std::is_same_v) @@ -432,7 +447,7 @@ struct UniversalGemmKernel } BsTesnorIsValid = false; } - if(kargs.N % GemmPipeline::GetVectorSizeB() != 0) + if(kargs.N % vectorSizeB != 0) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { @@ -454,7 +469,7 @@ struct UniversalGemmKernel } BsTesnorIsValid = false; } - if(kargs.K % GemmPipeline::GetVectorSizeB() != 0) + if(kargs.K % vectorSizeB != 0) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 2d439c6970..5f4ee8987e 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -127,8 +127,16 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; - static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } - static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } + template + static constexpr index_t GetVectorSizeA() + { + return Policy::template GetVectorSizeA(); + } + template + static constexpr index_t GetVectorSizeB() + { + return Policy::template GetVectorSizeB(); + } static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } static constexpr index_t APackedSize = diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp index b0cd93a661..c835809b5d 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -124,8 +124,16 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; - static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } - static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } + template + static constexpr index_t GetVectorSizeA() + { + return Policy::template GetVectorSizeA(); + } + template + static constexpr index_t GetVectorSizeB() + { + return Policy::template GetVectorSizeB(); + } static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp index 0fdcc04d89..b05145890f 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp @@ -61,8 +61,16 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; - static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } - static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } + template + static constexpr index_t GetVectorSizeA() + { + return Policy::template GetVectorSizeA(); + } + template + static constexpr index_t GetVectorSizeB() + { + return Policy::template GetVectorSizeB(); + } static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } static constexpr bool kPadM = Problem::kPadM; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index d62add7ef3..e1acfebc47 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -176,8 +176,16 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; - static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } - static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } + template + static constexpr index_t GetVectorSizeA() + { + return Policy::template GetVectorSizeA(); + } + template + static constexpr index_t GetVectorSizeB() + { + return Policy::template GetVectorSizeB(); + } static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index d8118a7f8f..e3b4863392 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -36,8 +36,16 @@ struct GemmPipelineAGmemBGmemCRegV1 static constexpr index_t kNPerBlock = BlockGemmShape::kN; static constexpr index_t kKPerBlock = BlockGemmShape::kK; - static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; } - static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; } + template + static constexpr index_t GetVectorSizeA() + { + return Problem::VectorSizeA; + } + template + static constexpr index_t GetVectorSizeB() + { + return Problem::VectorSizeB; + } static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; } static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index e4b3649595..40ee952b1b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -305,11 +305,15 @@ struct UniversalGemmBasePolicy * @tparam XPerTile The contiguous Tile dimension size. * @return Maximum DRAM vector load size. */ - template + template CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize() { - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t BlockSize = IsWave32Host ? Problem::kBlockSize / 2 : Problem::kBlockSize; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize; constexpr index_t PackedSize = ck_tile::numeric_traits>::PackedSize; @@ -349,7 +353,7 @@ struct UniversalGemmBasePolicy } } - template + template CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA() { using ALayout = remove_cvref_t; @@ -359,15 +363,23 @@ struct UniversalGemmBasePolicy if constexpr(std::is_same_v) { - return GetGlobalVectorLoadSize(); + return GetGlobalVectorLoadSize(); } else { - return GetGlobalVectorLoadSize(); + return GetGlobalVectorLoadSize(); } } - template + template CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB() { using BLayout = remove_cvref_t; @@ -377,11 +389,19 @@ struct UniversalGemmBasePolicy if constexpr(std::is_same_v) { - return GetGlobalVectorLoadSize(); + return GetGlobalVectorLoadSize(); } else { - return GetGlobalVectorLoadSize(); + return GetGlobalVectorLoadSize(); } } diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp index cadd77a61f..b91c211d91 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp @@ -59,13 +59,15 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp; static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp; + template static constexpr index_t GetVectorSizeA() { - return PipelinePolicy::template GetVectorSizeA(); + return PipelinePolicy::template GetVectorSizeA(); } + template static constexpr index_t GetVectorSizeB() { - return PipelinePolicy::template GetVectorSizeB(); + return PipelinePolicy::template GetVectorSizeB(); } static constexpr bool kPadM = Problem::kPadM; diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp index 9c0f257e8e..c507d8d8d8 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp @@ -76,13 +76,15 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp; static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp; + template static constexpr index_t GetVectorSizeA() { - return PipelinePolicy::template GetVectorSizeA(); + return PipelinePolicy::template GetVectorSizeA(); } + template static constexpr index_t GetVectorSizeB() { - return PipelinePolicy::template GetVectorSizeB(); + return PipelinePolicy::template GetVectorSizeB(); } static constexpr bool kPadM = Problem::kPadM; diff --git a/include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp b/include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp index 78a514d6cd..6973c80d57 100644 --- a/include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp +++ b/include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp @@ -99,15 +99,15 @@ struct AQuantGemmKernelArgs template struct AQuantGemmKernel { - using TilePartitioner = remove_cvref_t; - using GemmPipeline = remove_cvref_t; - using EpiloguePipeline = remove_cvref_t; - using ALayout = remove_cvref_t; - using AQLayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; - static constexpr bool Preshuffle = GemmPipeline::Preshuffle; + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + using ALayout = remove_cvref_t; + using AQLayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + static constexpr index_t kBlockSize = GemmPipeline::BlockSize; + static constexpr bool Preshuffle = GemmPipeline::Preshuffle; using ADataType = remove_cvref_t; using AQDataType = remove_cvref_t; @@ -131,7 +131,7 @@ struct AQuantGemmKernel return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } CK_TILE_HOST static constexpr AQuantGemmKernelArgs MakeKernelArgs(const AQuantGemmHostArgs& hostArgs) diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 115f6dea19..7ea2e31706 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -354,7 +354,7 @@ struct GroupedConvolutionBackwardWeightKernel using GemmDsLayout = remove_cvref_t; static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor; - static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; + static constexpr index_t kBlockSize = GemmPipeline::BlockSize; using InDataType = remove_cvref_t; using WeiDataType = remove_cvref_t; @@ -393,7 +393,7 @@ struct GroupedConvolutionBackwardWeightKernel TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch); } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } CK_TILE_HOST static constexpr GroupedConvBwdWeightKernelArgsSpecialized MakeKernelArgs(const GroupedConvBwdWeightHostArgs& hostArgs) diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index 8cd1710043..d3a90ea144 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -361,7 +361,7 @@ struct GroupedConvolutionForwardKernel using GemmDsLayout = remove_cvref_t; static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor; - static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; + static constexpr index_t kBlockSize = GemmPipeline::BlockSize; using InDataType = remove_cvref_t; using WeiDataType = remove_cvref_t; @@ -398,7 +398,7 @@ struct GroupedConvolutionForwardKernel TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch); } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized MakeKernelArgs(const GroupedConvFwdHostArgs& hostArgs) diff --git a/include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp b/include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp index ee74f1588f..eb54807d88 100644 --- a/include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp +++ b/include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -31,6 +31,7 @@ struct ImageToColumn static constexpr index_t kMPerBlock = Problem::BlockShape::kMPerBlock; static constexpr index_t kKPerBlock = Problem::BlockShape::kKPerBlock; + static constexpr index_t kBlockSize = Problem::BlockShape::kBlockSize; struct Kargs { diff --git a/include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp b/include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp index ad513dbd11..05490ac3ed 100644 --- a/include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp +++ b/include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp @@ -14,11 +14,10 @@ struct TileImageToColumnShape static constexpr index_t kMPerThread = ThreadTile::at(number<0>{}); static constexpr index_t kKPerThread = ThreadTile::at(number<1>{}); - static constexpr index_t kMPerWarp = WarpTile::at(number<0>{}); - static constexpr index_t kKPerWarp = WarpTile::at(number<1>{}); - + static constexpr index_t kMPerWarp = WarpTile::at(number<0>{}); static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread; - static constexpr index_t kKThreadPerWarp = kKPerWarp / kKPerThread; + static constexpr index_t kKThreadPerWarp = get_warp_size() / kMThreadPerWarp; + static constexpr index_t kKPerWarp = kKPerThread * kKThreadPerWarp; static constexpr index_t kMPerBlock = BlockTile::at(number<0>{}); static constexpr index_t kKPerBlock = BlockTile::at(number<1>{}); diff --git a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp index 146ac40fb7..6998b358d8 100644 --- a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp @@ -76,9 +76,9 @@ struct Layernorm2dFwd static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N; static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N; - - static constexpr auto I0 = number<0>{}; - static constexpr auto I1 = number<1>{}; + static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; struct Kargs { diff --git a/include/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp b/include/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp index 1c5cc4a11a..3578e3b375 100644 --- a/include/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp +++ b/include/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp b/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp index 0cae4023b7..5755f38475 100644 --- a/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp +++ b/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp @@ -25,6 +25,8 @@ struct Reduce using ComputeDataType = ck_tile::remove_cvref_t; using YDataType = ck_tile::remove_cvref_t; + static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; + private: // Helper function to calculate optimal vector size for input tensor template diff --git a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp index 6cb81b8856..e7f4ce0ba8 100644 --- a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp @@ -70,6 +70,7 @@ struct Rmsnorm2dFwd static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N; static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N; + static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; diff --git a/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp b/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp index cb934c6c52..b70e996617 100644 --- a/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp +++ b/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp @@ -48,6 +48,7 @@ struct MoeSmoothquant static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N; static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N; + static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; diff --git a/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp b/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp index 540fddd2e8..7dc913901e 100644 --- a/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp +++ b/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp @@ -45,6 +45,7 @@ struct Smoothquant static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N; static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N; + static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; diff --git a/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp b/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp index b8520ae61a..277049f6b0 100644 --- a/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp +++ b/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -34,6 +34,8 @@ struct TopkSoftmaxKernel using WeightType = typename Problem::WeightType; using IndexType = typename Problem::IndexType; + static constexpr index_t kBlockSize = Problem::BlockSize; + struct TopkSoftmaxKargs { const void* p_input; diff --git a/include/ck_tile/ref/naive_attention.hpp b/include/ck_tile/ref/naive_attention.hpp index 172fcee2e3..50e963bd72 100644 --- a/include/ck_tile/ref/naive_attention.hpp +++ b/include/ck_tile/ref/naive_attention.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -117,7 +117,7 @@ struct naive_attention_fwd_kernel std::is_same_v && std::is_same_v; static constexpr int v_per_token_quant_group_size = 64; - + static constexpr int kBlockSize = 256; // TODO: hardcode using SoftmaxType = float; // always using float to do softmax compute using QuantComputeType = float; // used for quant/dequant scale compute @@ -254,7 +254,7 @@ struct naive_attention_fwd_kernel __device__ T load(int i_s, int i_h, int i_d) { return base_ptr[get_offset(i_s, i_h, i_d)]; } }; - __device__ __host__ static constexpr int get_block_size() { return 256; } + __device__ __host__ static constexpr int get_block_size() { return kBlockSize; } // for simpliciy, 1 WG always compute 1 token along q, compute all token along kv // compute all hdim from q, compute WG_SIZE hdim from v diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp index 25b10e1dc4..dd90034064 100644 --- a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp @@ -66,5 +66,5 @@ float add_rmsnorm2d_rdquant_fwd_(const S& s, A a) std::cout << ", " << Kernel::GetName() << std::flush; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp index f654d1a917..f634e508e3 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp +++ b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp @@ -111,7 +111,6 @@ class TestCkTileBatchedGemm : public ::testing::Test DsLayout, CLayout, ck_tile::element_wise::PassThrough, - GemmPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, M_Warp, @@ -124,8 +123,8 @@ class TestCkTileBatchedGemm : public ::testing::Test using Kernel = ck_tile::BatchedGemmKernel; auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); + const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -144,7 +143,7 @@ class TestCkTileBatchedGemm : public ::testing::Test } ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/test/ck_tile/batched_transpose/test_batched_transpose.cpp b/test/ck_tile/batched_transpose/test_batched_transpose.cpp index 77d5825eed..8812397946 100644 --- a/test/ck_tile/batched_transpose/test_batched_transpose.cpp +++ b/test/ck_tile/batched_transpose/test_batched_transpose.cpp @@ -137,11 +137,11 @@ class TestCkTileBatchedTranspose // N C H W layout_in== Config::BlockTile::at(1)}; auto kargs = Kernel::MakeKargs(host_args); - auto sc = ck_tile::stream_config{}; - const dim3 grid_size = Kernel::GridSize(host_args); - constexpr dim3 block_size = Kernel::BlockSize(); - ck_tile::launch_kernel( - sc, ck_tile::make_kernel(Kernel{}, grid_size, block_size, 0, kargs)); + auto sc = ck_tile::stream_config{}; + const dim3 grid_size = Kernel::GridSize(host_args); + const dim3 block_size = Kernel::BlockSize(); + ck_tile::launch_kernel(sc, + ck_tile::make_kernel<1>(Kernel{}, grid_size, block_size, 0, kargs)); y_dev.FromDevice(y_host.data()); ck_tile::reference_batched_transpose(x_host, y_ref, layout_in, layout_out); diff --git a/test/ck_tile/elementwise/test_elementwise_1d.cpp b/test/ck_tile/elementwise/test_elementwise_1d.cpp index 9966c369be..3ce6e78d1d 100644 --- a/test/ck_tile/elementwise/test_elementwise_1d.cpp +++ b/test/ck_tile/elementwise/test_elementwise_1d.cpp @@ -118,19 +118,17 @@ class TestCkTileElementwise : public ::testing::Test "The kernel configuration is not supported for the given input size."); } - ck_tile::launch_kernel( - s, - ck_tile::make_kernel // MinBlockPerCu - (ew_kernel, - grid, - block, - 0, // actual shared memory - lens, - strides, // input strides - strides, // output strides - d_x_ptrs_tuple, - p_y_device)); + ck_tile::launch_kernel(s, + ck_tile::make_kernel // MinBlockPerCu + (ew_kernel, + grid, + block, + 0, // actual shared memory + lens, + strides, // input strides + strides, // output strides + d_x_ptrs_tuple, + p_y_device)); d_y_mem.FromDevice(h_y.data()); diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc index 4321709ea5..53eff9ecc4 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc @@ -77,7 +77,6 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) ck_tile::tuple<>, CLayout, ck_tile::element_wise::PassThrough, - CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, M_Warp, @@ -93,8 +92,8 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -114,7 +113,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) } float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc index a22ecf2486..adae8dcf92 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc @@ -91,7 +91,6 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) DsLayout, ELayout, CDEElementWise, - GemmPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, GemmConfig::M_Warp, @@ -114,7 +113,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { grids = Kernel::GridSize(args.M, args.N, args.k_batch); } - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -165,15 +164,13 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) ave_time = ck_tile::launch_kernel_time_mask( s, run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } else { - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } return ave_time; }; diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index 26ff847841..af4f8d3d38 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -10,6 +10,7 @@ #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" +#include "ck_tile/core/numeric/math.hpp" template auto calculate_rtol_atol(const ck_tile::index_t K, @@ -184,7 +185,6 @@ class TestCkTileGemmPipeline : public ::testing::Test DsLayout, CLayout, ck_tile::element_wise::PassThrough, - GemmPipeline::BlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, M_Warp, @@ -207,7 +207,7 @@ class TestCkTileGemmPipeline : public ::testing::Test { grids = Kernel::GridSize(args.M, args.N, args.k_batch); } - constexpr dim3 blocks = Kernel::BlockSize(); + dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -222,7 +222,7 @@ class TestCkTileGemmPipeline : public ::testing::Test } ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { diff --git a/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc b/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc index 0b886938b8..e8ff45fc5e 100644 --- a/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc +++ b/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc @@ -99,7 +99,6 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s ck_tile::tuple<>, CLayout, ck_tile::element_wise::PassThrough, - CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, M_Warp, @@ -114,8 +113,8 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); if(args.k_batch != 1) { @@ -139,7 +138,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s } float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp index c08951435e..d21777c92b 100644 --- a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp @@ -178,7 +178,6 @@ class TestCkTileGemmMultiD : public ::testing::Test DsLayout, ELayout, CDEElementWise, - GemmPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, M_Warp, @@ -192,8 +191,8 @@ class TestCkTileGemmMultiD : public ::testing::Test using Kernel = ck_tile::GemmKernelMultiD; auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -212,7 +211,7 @@ class TestCkTileGemmMultiD : public ::testing::Test } ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp index af229aad29..5d52f15696 100644 --- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp @@ -183,7 +183,6 @@ class TestCkTileGemmPipeline : public ::testing::Test DsLayout, CLayout, ck_tile::element_wise::PassThrough, - GemmPipeline::BlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, GemmConfig::M_Warp, @@ -206,7 +205,7 @@ class TestCkTileGemmPipeline : public ::testing::Test { grids = Kernel::GridSize(args.M, args.N, args.k_batch); } - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -221,7 +220,7 @@ class TestCkTileGemmPipeline : public ::testing::Test } ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp index cededd38f9..5aca02a433 100644 --- a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp @@ -136,7 +136,6 @@ class TestCkTileGroupedGemm : public ::testing::Test DsLayout, CLayout, ck_tile::element_wise::PassThrough, - GemmPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, GroupedGemKernelParam::M_Warp, @@ -150,8 +149,8 @@ class TestCkTileGroupedGemm : public ::testing::Test auto kargs = Kernel::MakeKargs(gemm_descs); EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); - const dim3 grids = Kernel::GridSize(gemm_descs); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + const dim3 blocks = Kernel::BlockSize(); ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, kargs.data(), @@ -169,7 +168,7 @@ class TestCkTileGroupedGemm : public ::testing::Test ave_time = ck_tile::launch_kernel( s, - ck_tile::make_kernel( + ck_tile::make_kernel( Kernel{}, grids, blocks, @@ -227,12 +226,6 @@ class TestCkTileGroupedGemm : public ::testing::Test using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile::PersistentTileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; const auto Run = [&](const auto memory_operation_) { constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; @@ -268,7 +259,6 @@ class TestCkTileGroupedGemm : public ::testing::Test DsLayout, CLayout, ck_tile::element_wise::PassThrough, - GemmPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, GroupedGemKernelParam::M_Warp, @@ -279,8 +269,8 @@ class TestCkTileGroupedGemm : public ::testing::Test UniversalGemmProblem::TransposeC, memory_operation>>; using Kernel = ck_tile::GroupedGemmKernel; - constexpr dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); if(s.log_level_ > 0) { @@ -291,7 +281,7 @@ class TestCkTileGroupedGemm : public ::testing::Test } ck_tile::launch_kernel(s, - ck_tile::make_kernel( + ck_tile::make_kernel( Kernel{}, grids, blocks, diff --git a/test/ck_tile/image_to_column/test_tile_image_to_column.cpp b/test/ck_tile/image_to_column/test_tile_image_to_column.cpp index 9c0746e972..c721f1073f 100644 --- a/test/ck_tile/image_to_column/test_tile_image_to_column.cpp +++ b/test/ck_tile/image_to_column/test_tile_image_to_column.cpp @@ -97,13 +97,13 @@ class TestCkTileImageToColumn : public ::testing::Test kargs.N * kargs.output_spatial_lengths[0] * kargs.output_spatial_lengths[1], kargs.filter_spatial_lengths[0] * kargs.filter_spatial_lengths[1] * kargs.C, kargs.G); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 2; ck_tile::launch_kernel( ck_tile::stream_config{}, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); // reference ck_tile::reference_im2col(in, out_host, conv_params); diff --git a/test/ck_tile/layernorm2d/generate.py b/test/ck_tile/layernorm2d/generate.py index d77582630a..c4366f6662 100644 --- a/test/ck_tile/layernorm2d/generate.py +++ b/test/ck_tile/layernorm2d/generate.py @@ -235,7 +235,7 @@ float layernorm2d_fwd_(const S& s, A a) using Kernel = ck_tile::Layernorm2dFwd; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a); @@ -243,7 +243,7 @@ float layernorm2d_fwd_(const S& s, A a) std::cout << ", " << Kernel::GetName() << std::flush; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{{}}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{{}}, grids, blocks, 0, kargs)); }} """ diff --git a/test/ck_tile/memory_copy/test_copy.cpp b/test/ck_tile/memory_copy/test_copy.cpp index e8962dce29..30a2e60ea9 100644 --- a/test/ck_tile/memory_copy/test_copy.cpp +++ b/test/ck_tile/memory_copy/test_copy.cpp @@ -76,17 +76,17 @@ class TestCkTileMemoryCopy : public ::testing::TestWithParam( - Kernel{}, - kGridSize, - kBlockSize, - 0, - static_cast(x_buf.GetDeviceBuffer()), - static_cast(y_buf.GetDeviceBuffer()), - m, - n, - warp_id)); + auto ms = launch_kernel( + ck_tile::stream_config{nullptr, true}, + ck_tile::make_kernel(Kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(x_buf.GetDeviceBuffer()), + static_cast(y_buf.GetDeviceBuffer()), + m, + n, + warp_id)); auto bytes = 2 * m * n * sizeof(DataType); std::cout << "elapsed: " << ms << " (ms)" << std::endl; diff --git a/test/ck_tile/memory_copy/test_copy.hpp b/test/ck_tile/memory_copy/test_copy.hpp index a9840ba2c6..4833b29560 100644 --- a/test/ck_tile/memory_copy/test_copy.hpp +++ b/test/ck_tile/memory_copy/test_copy.hpp @@ -64,7 +64,8 @@ struct TileCopy using Problem = ck_tile::remove_cvref_t; using XDataType = typename Problem::XDataType; - static constexpr bool AsyncCopy = Problem::AsyncCopy; + static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; + static constexpr bool AsyncCopy = Problem::AsyncCopy; template CK_TILE_DEVICE static constexpr auto MakeDRAMDistribution() diff --git a/test/ck_tile/moe_smoothquant/instances/moe_smoothquant_instance_common.hpp b/test/ck_tile/moe_smoothquant/instances/moe_smoothquant_instance_common.hpp index 9d8c9caf00..f2875c72c8 100644 --- a/test/ck_tile/moe_smoothquant/instances/moe_smoothquant_instance_common.hpp +++ b/test/ck_tile/moe_smoothquant/instances/moe_smoothquant_instance_common.hpp @@ -61,5 +61,5 @@ float moe_smoothquant_(const S& s, A a) std::cout << ", " << Kernel::GetName() << std::flush; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } diff --git a/test/ck_tile/moe_sorting/moe_sorting_api.cpp b/test/ck_tile/moe_sorting/moe_sorting_api.cpp index 0f25e17867..0cf600d2b4 100644 --- a/test/ck_tile/moe_sorting/moe_sorting_api.cpp +++ b/test/ck_tile/moe_sorting/moe_sorting_api.cpp @@ -209,7 +209,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() #define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \ @@ -227,7 +227,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() #if MOE_SORTING_SUPPORT_LARGE_EXPERT #define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_, local_token_) \ @@ -283,7 +283,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi const dim3 grids = kernel::GridSize(a); \ const dim3 blocks = kernel::BlockSize(a); \ const auto lds_size = kernel::GetSmemSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ }() #define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ @@ -334,15 +334,15 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi } \ } -#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_) \ - [&]() { \ - using problem_ = \ - ck_tile::MoeSortingClearWorkspaceProblem; \ - using kernel = ck_tile::MoeSortingClearWorkspaceKernel; \ - auto kargs = kernel::MakeKargs(a); \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ +#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_) \ + [&]() { \ + using problem_ = \ + ck_tile::MoeSortingClearWorkspaceProblem; \ + using kernel = ck_tile::MoeSortingClearWorkspaceKernel; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) diff --git a/test/ck_tile/permute/alternative_impl/matrix_core_swizzle_kernel.hpp b/test/ck_tile/permute/alternative_impl/matrix_core_swizzle_kernel.hpp index c94adc24c3..498d93b656 100644 --- a/test/ck_tile/permute/alternative_impl/matrix_core_swizzle_kernel.hpp +++ b/test/ck_tile/permute/alternative_impl/matrix_core_swizzle_kernel.hpp @@ -1,5 +1,5 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -115,11 +115,12 @@ struct matrix_core_swizzle_kernel __host__ void operator()(const ck_tile::stream_config& s) const { - ck_tile::kentry<<>>(a); + ck_tile::kentry<1, kernel><<>>(a); } struct kernel { + static constexpr ck_tile::index_t kBlockSize = BLOCK_SIZE; __device__ static constexpr auto get_src_dist() { using namespace ck_tile; diff --git a/test/ck_tile/permute/test_permute_util.hpp b/test/ck_tile/permute/test_permute_util.hpp index cca3148382..5494749541 100644 --- a/test/ck_tile/permute/test_permute_util.hpp +++ b/test/ck_tile/permute/test_permute_util.hpp @@ -54,11 +54,11 @@ float permute(permute_args a, const ck_tile::stream_config& s) auto kargs = Kernel::MakeKargs(a); - const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(a); + const dim3 blocks = Kernel::BlockSize(); - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + float ave_time = + ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(Kernel{}, grids, blocks, 0, kargs)); return ave_time; } diff --git a/test/ck_tile/reduce/test_reduce2d.cpp b/test/ck_tile/reduce/test_reduce2d.cpp index 821d0a6c3e..ff807e52c9 100644 --- a/test/ck_tile/reduce/test_reduce2d.cpp +++ b/test/ck_tile/reduce/test_reduce2d.cpp @@ -82,18 +82,18 @@ class TestCkTileReduce : public ::testing::Test throw std::runtime_error("Wrong! Arguments not supported!\n"); } - ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0}, - ck_tile::make_kernel( - Kernel{}, - kGridSize, - kBlockSize, - 0, - static_cast(d_x_mem.GetDeviceBuffer()), - static_cast(d_y_mem.GetDeviceBuffer()), - input_shape_tuple, - input_strides_tuple, - kept_dims, - reduce_dims)); + ck_tile::launch_kernel( + ck_tile::stream_config{nullptr, false, 0}, + ck_tile::make_kernel(Kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(d_x_mem.GetDeviceBuffer()), + static_cast(d_y_mem.GetDeviceBuffer()), + input_shape_tuple, + input_strides_tuple, + kept_dims, + reduce_dims)); // Get results back d_y_mem.FromDevice(h_y.data()); diff --git a/test/ck_tile/rmsnorm2d/generate.py b/test/ck_tile/rmsnorm2d/generate.py index 4296b7373e..1a1c842b3c 100644 --- a/test/ck_tile/rmsnorm2d/generate.py +++ b/test/ck_tile/rmsnorm2d/generate.py @@ -246,7 +246,7 @@ float rmsnorm2d_fwd_(const S& s, A a) std::cout << ", " << Kernel::GetName() << std::flush; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{{}}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{{}}, grids, blocks, 0, kargs)); }} """ diff --git a/test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp b/test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp index 19310beb94..8929289cdb 100644 --- a/test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp +++ b/test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp @@ -57,5 +57,5 @@ float smoothquant_(const S& s, A a) std::cout << ", " << Kernel::GetName() << std::flush; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } diff --git a/test/ck_tile/topk_softmax/test_topk_softmax_api.cpp b/test/ck_tile/topk_softmax/test_topk_softmax_api.cpp index 46c7abc697..7c90c8200c 100644 --- a/test/ck_tile/topk_softmax/test_topk_softmax_api.cpp +++ b/test/ck_tile/topk_softmax/test_topk_softmax_api.cpp @@ -13,11 +13,11 @@ \ auto kargs = kernel::MakeKargs(a); \ \ - const dim3 grids = kernel::GridSize(a); \ - constexpr dim3 blocks = kernel::BlockSize(); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(); \ \ - float ave_time = ck_tile::launch_kernel( \ - s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs)); \ + float ave_time = \ + ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \ \ return ave_time; diff --git a/tile_engine/ops/gemm/codegen_utils.py b/tile_engine/ops/gemm/codegen_utils.py index 4a990f3309..dd9de36865 100644 --- a/tile_engine/ops/gemm/codegen_utils.py +++ b/tile_engine/ops/gemm/codegen_utils.py @@ -54,7 +54,6 @@ CSHUFFLE_EPILOGUE = """ ck_tile::tuple<>, CLayout, ck_tile::element_wise::PassThrough, - GemmPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, WarpM, diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 6d713bdcb8..7def4e2691 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -297,7 +297,7 @@ struct GemmKernel {{ throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!"); }} - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); const dim3 grids = {'Kernel::MaxOccupancyGridSize(stream)' if persistent == 'true' else 'Kernel::GridSize(args.M, args.N, args.k_batch)'}; if(stream.log_level_ > 0) @@ -346,12 +346,12 @@ struct GemmKernel {{ ave_time = ck_tile::launch_kernel_time_mask( stream, run_flush_cache, - ck_tile::make_kernel( + ck_tile::make_kernel( Kernel{{}}, grids, blocks, 0, kargs)); }} else{{ ave_time = ck_tile::launch_kernel(stream, - ck_tile::make_kernel( + ck_tile::make_kernel( Kernel{{}}, grids, blocks, 0, kargs)); }} return ave_time; diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_codegen_utils.py b/tile_engine/ops/gemm_multi_d/gemm_multi_d_codegen_utils.py index 7d3629819d..9aca3407b1 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_codegen_utils.py +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_codegen_utils.py @@ -56,7 +56,6 @@ CSHUFFLE_EPILOGUE = """ DsLayout, ELayout, CDEElementWise, - GemmPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, WarpM, diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py b/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py index 6e65f6bf75..4b5acf1363 100755 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py @@ -443,8 +443,8 @@ struct GemmKernelMultiD {{ using Kernel = ck_tile::GemmKernelMultiD; auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) {{ @@ -460,7 +460,7 @@ struct GemmKernelMultiD {{ }} ave_time = ck_tile::launch_kernel(stream, - ck_tile::make_kernel( + ck_tile::make_kernel( Kernel{{}}, grids, blocks, 0, kargs)); return ave_time;