From 9fcc1ee9fd9730efd865f530afde505f2556954d Mon Sep 17 00:00:00 2001 From: linqunAMD Date: Tue, 19 Aug 2025 01:08:31 +0800 Subject: [PATCH 01/46] Support Wave32 in CK_TILE - Part 1 (#2594) * Support wave32/wave64 in CK_TILE - Part 1 * remove blocksize in kernel launch * fix build error * fix clang format * fix clang format 2 * fix clang format 3 * fix fmha build error * fix fmha build 2 * fix fmha build 3 * fix build error 4 * address review comment * update change log * replace KernelBlockSize with kBlockSize * fix CI fail * fix clang format * address review comment and rebase code. * fix universal test fail --------- Co-authored-by: Lin, Qun Co-authored-by: Thomas Ning --- CHANGELOG.md | 1 + CMakeLists.txt | 2 - .../01_fmha/codegen/ops/fmha_batch_prefill.py | 4 +- .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 24 +-- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 4 +- .../01_fmha/codegen/ops/fmha_fwd_appendkv.py | 4 +- .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 8 +- .../codegen/ops/fmha_pagedkv_prefill.py | 4 +- example/ck_tile/02_layernorm2d/generate.py | 4 +- example/ck_tile/03_gemm/gemm_basic.cpp | 11 +- .../03_gemm/gemm_splitk_two_stage_reduce.cpp | 17 +- .../03_gemm/gemm_weight_preshuffle.cpp | 13 +- example/ck_tile/03_gemm/universal_gemm.cpp | 13 +- .../ck_tile/04_img2col/image_to_column.cpp | 5 +- example/ck_tile/05_reduce/reduce.cpp | 24 +-- .../matrix_core_swizzle_kernel.hpp | 5 +- example/ck_tile/06_permute/permute.cpp | 24 +-- .../09_topk_softmax/topk_softmax_api.cpp | 8 +- .../10_rmsnorm2d/example_rmsnorm2d_fwd.cpp | 5 +- example/ck_tile/10_rmsnorm2d/generate.py | 4 +- .../example_add_rmsnorm2d_rdquant_fwd.cpp | 5 +- ..._rmsnorm2d_rdquant_fwd_instance_common.hpp | 4 +- .../12_smoothquant/example_smoothquant.cpp | 5 +- .../instances/smoothquant_instance_common.hpp | 4 +- .../13_moe_sorting/moe_sorting_api.cpp | 24 +-- .../moe_smoothquant_instance_common.hpp | 4 +- .../instances/fused_moegemm_api_internal.hpp | 4 +- 
.../instances/fused_moesorting_api.cpp | 6 +- .../ck_tile/16_batched_gemm/batched_gemm.cpp | 7 +- .../ck_tile/17_grouped_gemm/grouped_gemm.cpp | 9 +- example/ck_tile/18_flatmm/flatmm_basic.cpp | 15 +- .../ck_tile/18_flatmm/run_flatmm_example.inc | 14 +- .../19_gemm_multi_d/gemm_multi_d_fp16.cpp | 7 +- .../grouped_convolution_backward_weight.cpp | 7 +- .../grouped_convolution_forward.cpp | 7 +- .../21_elementwise/elementwise_example.cpp | 22 +-- .../elementwise_example_add_4d.cpp | 2 +- .../elementwise_example_transpose.cpp | 22 +-- .../elementwise_example_unary.cpp | 22 +-- .../batched_transpose_api.cpp | 8 +- .../38_block_scale_gemm/gemm_aquant_basic.cpp | 7 +- .../gemm_aquant_preshuffle.cpp | 7 +- example/ck_tile/39_copy/copy_basic.cpp | 20 +-- include/ck_tile/core/arch/arch.hpp | 19 +- include/ck_tile/core/config.hpp | 6 - include/ck_tile/host/kernel_launch.hpp | 12 +- .../add_rmsnorm2d_rdquant_fwd_kernel.hpp | 3 +- .../kernel/batched_transpose_kernel.hpp | 2 + .../batched_transpose_lds_problem.hpp | 5 +- .../elementwise/kernel/elementwise_kernel.hpp | 2 + .../ops/epilogue/cshuffle_epilogue.hpp | 3 +- .../ops/flatmm/kernel/flatmm_kernel.hpp | 16 +- ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 7 +- .../fmha/kernel/fmha_batch_prefill_kernel.hpp | 9 +- .../fmha/kernel/fmha_fwd_appendkv_kernel.hpp | 3 +- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 7 +- .../fmha/kernel/fmha_fwd_pagedkv_kernel.hpp | 1 + .../fmha_fwd_splitkv_combine_kernel.hpp | 3 +- .../fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 3 +- .../fused_moe/kernel/moe_sorting_kernel.hpp | 162 +++++++++--------- .../ops/gemm/kernel/batched_gemm_kernel.hpp | 12 +- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 1 + .../ops/gemm/kernel/gemm_multi_d_kernel.hpp | 1 + .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 8 +- .../ops/gemm/kernel/universal_gemm_kernel.hpp | 35 ++-- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 12 +- .../gemm_pipeline_ag_bg_cr_comp_v4.hpp | 12 +- .../gemm_pipeline_ag_bg_cr_comp_v5.hpp | 12 +- 
.../pipeline/gemm_pipeline_ag_bg_cr_mem.hpp | 12 +- .../gemm_pipeline_agmem_bgmem_creg_v1.hpp | 12 +- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 38 +++- .../wp_pipeline_agmem_bgmem_creg_v1.hpp | 6 +- .../wp_pipeline_agmem_bgmem_creg_v2.hpp | 6 +- .../kernel/gemm_aquant_kernel.hpp | 20 +-- ...ped_convolution_backward_weight_kernel.hpp | 4 +- .../grouped_convolution_forward_kernel.hpp | 4 +- .../kernel/image_to_column_kernel.hpp | 3 +- .../pipeline/tile_image_to_column_shape.hpp | 7 +- .../kernel/layernorm2d_fwd_kernel.hpp | 6 +- .../permute/kernel/generic_permute_kernel.hpp | 2 +- .../ops/reduce/kernel/reduce2d_kernel.hpp | 2 + .../rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp | 1 + .../kernel/moe_smoothquant_kernel.hpp | 1 + .../smoothquant/kernel/smoothquant_kernel.hpp | 1 + .../kernel/topk_softmax_kernel.hpp | 4 +- include/ck_tile/ref/naive_attention.hpp | 6 +- ..._rmsnorm2d_rdquant_fwd_instance_common.hpp | 2 +- .../batched_gemm/test_batched_gemm_util.hpp | 7 +- .../test_batched_transpose.cpp | 10 +- .../elementwise/test_elementwise_1d.cpp | 24 ++- .../test_gemm_pipeline_basic_run_test.inc | 7 +- .../test_gemm_pipeline_universal_run_test.inc | 13 +- test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 6 +- .../test_run_gemm_aquant_example.inc | 7 +- .../gemm_multi_d/test_gemm_multi_d_util.hpp | 7 +- .../test_gemm_pipeline_util.hpp | 5 +- .../grouped_gemm/test_grouped_gemm_util.hpp | 22 +-- .../test_tile_image_to_column.cpp | 4 +- test/ck_tile/layernorm2d/generate.py | 4 +- test/ck_tile/memory_copy/test_copy.cpp | 22 +-- test/ck_tile/memory_copy/test_copy.hpp | 3 +- .../moe_smoothquant_instance_common.hpp | 2 +- test/ck_tile/moe_sorting/moe_sorting_api.cpp | 24 +-- .../matrix_core_swizzle_kernel.hpp | 7 +- test/ck_tile/permute/test_permute_util.hpp | 8 +- test/ck_tile/reduce/test_reduce2d.cpp | 24 +-- test/ck_tile/rmsnorm2d/generate.py | 2 +- .../instances/smoothquant_instance_common.hpp | 2 +- .../topk_softmax/test_topk_softmax_api.cpp | 8 +- 
tile_engine/ops/gemm/codegen_utils.py | 1 - tile_engine/ops/gemm/gemm_instance_builder.py | 6 +- .../gemm_multi_d_codegen_utils.py | 1 - .../gemm_multi_d_instance_builder.py | 6 +- 113 files changed, 610 insertions(+), 531 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7c09271edc..1246248eac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,6 +48,7 @@ None * Number of instances in instance factory for grouped convolution forward NGCHW/GKYXC/NGKHW has been reduced. * Number of instances in instance factory for grouped convolution backward weight NGCHW/GKYXC/NGKHW has been reduced. * Number of instances in instance factory for grouped convolution backward data NGCHW/GKYXC/NGKHW has been reduced. +* Removed `BlockSize` in `make_kernel` and `CShuffleEpilogueProblem` to support Wave32 in CK_TILE (#2594) ### Known issues diff --git a/CMakeLists.txt b/CMakeLists.txt index 07d2e166bb..35ebba8085 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -327,7 +327,6 @@ endif() if(USE_OPT_GFX11) add_compile_options(-mcumode) add_compile_options(-mno-wavefrontsize64) - add_compile_definitions(CK_TILE_WAVE32_ENABLED) message(STATUS "CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}") endif() @@ -339,7 +338,6 @@ endif() if(USE_OPT_GFX12 AND (SUPPORTED_GPU_TARGETS MATCHES "gfx12")) add_compile_options(-mno-wavefrontsize64) - add_compile_definitions(CK_TILE_WAVE32_ENABLED) message(STATUS "CK compiled with USE_OPT_GFX12 set to ${USE_OPT_GFX12}") endif() diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 5d55e8bc36..0d8f366d8a 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -110,9 +110,9 @@ float fmha_batch_prefill_(const ck_tile::stream_config& s, fmha_b if(s.log_level_ > 0) std::cout << ", " << k_::GetName() << std::flush; auto [kargs, grids] = fmha_batch_prefill_create_kargs_and_grids(a); - 
constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} """ diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index bb3a0587e7..0391191fb2 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -136,10 +136,10 @@ float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s if(s.log_level_ > 0) std::cout << ", " << k_::GetName() << std::flush; auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} template <> @@ -148,9 +148,9 @@ void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_co {{ using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( ck_tile::stream_config{{s.stream_id_}}); }} @@ -425,10 +425,10 @@ float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s if(s.log_level_ > 0) std::cout << ", " << k_::GetName() << std::flush; auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel( - s, 
ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} template <> @@ -436,9 +436,9 @@ void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_co {{ using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( ck_tile::stream_config{{s.stream_id_}}); }} @@ -530,10 +530,10 @@ float fmha_bwd_convert_dq_(const ck_tile::stream_confi if(s.log_level_ > 0) std::cout << ", " << k_::GetName() << std::flush; auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} template <> @@ -542,9 +542,9 @@ void fmha_bwd_convert_dq_oneshot_(const ck_tile::strea {{ using k_ = fmha_bwd_convert_dq_kernel_{F_idx}; auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( ck_tile::stream_config{{s.stream_id_}}); }} diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index f614f42e6b..e59147a4f3 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -110,9 +110,9 @@ float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) if(s.log_level_ > 
0) std::cout << ", " << k_::GetName() << std::flush; auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} """ diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index 2e5bc2bd3d..0ebeaddf9c 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -60,9 +60,9 @@ float fmha_fwd_appendkv_(const ck_tile::stream_config& s, fmha_fw if(s.log_level_ > 0) std::cout << ", " << k_::GetName() << std::flush; auto [kargs, grids] = fmha_fwd_appendkv_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} """ diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index b2d962cd74..1dd8f0e3c6 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -108,9 +108,9 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ using k_ = fmha_kernel; auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); + 
ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); }} }}; }} @@ -208,9 +208,9 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ using k_ = fmha_kernel; auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); }} }}; }} diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index 650ebaf80e..e468e82ed5 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -109,9 +109,9 @@ float fmha_fwd_pagedkv_(const ck_tile::stream_config& s, fmha_fwd if(s.log_level_ > 0) std::cout << ", " << k_::GetName() << std::flush; auto [kargs, grids] = fmha_fwd_pagedkv_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} """ diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py index d77582630a..c4366f6662 100644 --- a/example/ck_tile/02_layernorm2d/generate.py +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -235,7 +235,7 @@ float layernorm2d_fwd_(const S& s, A a) using Kernel = ck_tile::Layernorm2dFwd; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t 
kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a); @@ -243,7 +243,7 @@ float layernorm2d_fwd_(const S& s, A a) std::cout << ", " << Kernel::GetName() << std::flush; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{{}}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{{}}, grids, blocks, 0, kargs)); }} """ diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 25781a4ae8..8cdbe39e86 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -65,7 +65,6 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) ck_tile::tuple<>, CLayout, ck_tile::element_wise::PassThrough, - CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, M_Warp, @@ -81,8 +80,8 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -100,10 +99,8 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) << std::endl; } - float ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp index a4a8039288..f42135a0b5 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) Advanced Micro Devices, Inc. 
All rights reserved. +// Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. #include @@ -208,7 +208,6 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config& DsLayout, ELayout, CDEElementWise, - UniversalGemmProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, GemmConfig::M_Warp, @@ -232,7 +231,7 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config& { grids = Kernel::GridSize(args.M, args.N, args.k_batch); } - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -279,15 +278,13 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config& ave_time = ck_tile::launch_kernel_time_mask( s, run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } else { - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } return ave_time; }; @@ -373,7 +370,7 @@ float reduce_stage2(const GemmSplitKHostArgs& args, const ck_tile::stream_config float ave_time = ck_tile::launch_kernel(s, - ck_tile::make_kernel( + ck_tile::make_kernel( Kernel{}, kGridSize, kBlockSize, diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp index 2057f1e4f5..0018db2c99 100644 --- a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp @@ -103,7 +103,6 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) DsLayout, ELayout, CDEElementWise, - UniversalGemmProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, GemmConfig::M_Warp, @@ -126,7 +125,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const 
ck_tile::stream_config& s) { grids = Kernel::GridSize(args.M, args.N, args.k_batch); } - constexpr dim3 blocks = Kernel::BlockSize(); + dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -172,15 +171,13 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) ave_time = ck_tile::launch_kernel_time_mask( s, run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } else { - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } return ave_time; }; diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index 149a8c2f0c..4e01710b4d 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -103,7 +103,6 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) DsLayout, ELayout, CDEElementWise, - UniversalGemmProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, GemmConfig::M_Warp, @@ -127,7 +126,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { grids = Kernel::GridSize(args.M, args.N, args.k_batch); } - constexpr dim3 blocks = Kernel::BlockSize(); + dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -173,15 +172,13 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) ave_time = ck_tile::launch_kernel_time_mask( s, run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } else { - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, 
kargs)); } return ave_time; }; diff --git a/example/ck_tile/04_img2col/image_to_column.cpp b/example/ck_tile/04_img2col/image_to_column.cpp index 299a2f3444..22b5d640d8 100644 --- a/example/ck_tile/04_img2col/image_to_column.cpp +++ b/example/ck_tile/04_img2col/image_to_column.cpp @@ -55,13 +55,12 @@ float image_to_column(const image_to_column_traits& traits, args.N * args.output_spatial_lengths[0] * args.output_spatial_lengths[1], args.filter_spatial_lengths[0] * args.filter_spatial_lengths[1] * args.C, args.G); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 2; float ave_time = ck_tile::launch_kernel( - stream_conf, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + stream_conf, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; } diff --git a/example/ck_tile/05_reduce/reduce.cpp b/example/ck_tile/05_reduce/reduce.cpp index cf816caa88..a110c2f98d 100644 --- a/example/ck_tile/05_reduce/reduce.cpp +++ b/example/ck_tile/05_reduce/reduce.cpp @@ -94,18 +94,18 @@ bool run(const ck_tile::ArgParser& arg_parser) throw std::runtime_error("Wrong! 
Arguments not supported!\n"); } - float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, - ck_tile::make_kernel( - Kernel{}, - kGridSize, - kBlockSize, - 0, - static_cast(x_buf.GetDeviceBuffer()), - static_cast(y_buf.GetDeviceBuffer()), - input_shape, - input_strides, - kept_dim, - reduce_dims)); + float ave_time = launch_kernel( + ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, + ck_tile::make_kernel(Kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(x_buf.GetDeviceBuffer()), + static_cast(y_buf.GetDeviceBuffer()), + input_shape, + input_strides, + kept_dim, + reduce_dims)); std::size_t num_btype = sizeof(XDataType) * N * C * H * W + sizeof(YDataType) * N * C; diff --git a/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp index 688f4f3d50..d486196fc3 100644 --- a/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp +++ b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -115,11 +115,12 @@ struct matrix_core_swizzle_kernel __host__ void operator()(const ck_tile::stream_config& s) const { - ck_tile::kentry<<>>(a); + ck_tile::kentry<1, kernel><<>>(a); } struct kernel { + static constexpr int kBlockSize = BLOCK_SIZE; __device__ static constexpr auto get_src_dist() { using namespace ck_tile; diff --git a/example/ck_tile/06_permute/permute.cpp b/example/ck_tile/06_permute/permute.cpp index 477ae370b9..aafece0f25 100644 --- a/example/ck_tile/06_permute/permute.cpp +++ b/example/ck_tile/06_permute/permute.cpp @@ -53,11 +53,11 @@ float permute(permute_traits t, permute_args a, const ck_tile::stream_config& s) auto kargs = Kernel::MakeKargs(a); - const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(a); + const dim3 blocks = Kernel::BlockSize(); - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + float ave_time = + ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(Kernel{}, grids, blocks, 0, kargs)); return ave_time; } @@ -69,11 +69,11 @@ float permute(permute_traits t, permute_args a, const ck_tile::stream_config& s) auto kargs = Kernel::MakeKargs(a); - const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(a); + const dim3 blocks = Kernel::BlockSize(); - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + float ave_time = + ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(Kernel{}, grids, blocks, 0, kargs)); return ave_time; } @@ -85,11 +85,11 @@ float permute(permute_traits t, permute_args a, const ck_tile::stream_config& s) auto kargs = Kernel::MakeKargs(a); - const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(a); + const dim3 blocks = Kernel::BlockSize(); - float ave_time = ck_tile::launch_kernel( - 
s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + float ave_time = + ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(Kernel{}, grids, blocks, 0, kargs)); return ave_time; } diff --git a/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp b/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp index 249a307b81..c2bad24cfe 100644 --- a/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp +++ b/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp @@ -13,11 +13,11 @@ \ auto kargs = kernel::MakeKargs(a); \ \ - const dim3 grids = kernel::GridSize(a); \ - constexpr dim3 blocks = kernel::BlockSize(); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(); \ \ - float ave_time = ck_tile::launch_kernel( \ - s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs)); \ + float ave_time = \ + ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \ \ return ave_time; diff --git a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp index e0a71452ea..511efeeaec 100644 --- a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp @@ -138,12 +138,11 @@ bool run(const ck_tile::ArgParser& arg_parser) auto kargs = Kernel::MakeKargs(args); const dim3 grids = Kernel::GridSize(args); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto s = ck_tile::stream_config{nullptr, true, 0, warmup, repeat}; - ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + ck_tile::launch_kernel(s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); bool pass = true; diff --git a/example/ck_tile/10_rmsnorm2d/generate.py b/example/ck_tile/10_rmsnorm2d/generate.py index b0ba400af1..ea8dfdf9ce 100644 --- a/example/ck_tile/10_rmsnorm2d/generate.py +++ 
b/example/ck_tile/10_rmsnorm2d/generate.py @@ -249,7 +249,7 @@ float rmsnorm2d_fwd_(const S& s, A a) using Kernel = ck_tile::Rmsnorm2dFwd; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a); @@ -257,7 +257,7 @@ float rmsnorm2d_fwd_(const S& s, A a) std::cout << ", " << Kernel::GetName() << std::flush; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{{}}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{{}}, grids, blocks, 0, kargs)); }} """ diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp index 449bc17e04..ace5fe0c4f 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp @@ -136,12 +136,11 @@ bool run(const ck_tile::ArgParser& arg_parser) auto kargs = Kernel::MakeKargs(args); const dim3 grids = Kernel::GridSize(args); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto s = ck_tile::stream_config{nullptr, true, 0, warmup, repeat}; - ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + ck_tile::launch_kernel(s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); bool pass = true; diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp index 25b10e1dc4..d997596414 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp +++ 
b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp @@ -58,7 +58,7 @@ float add_rmsnorm2d_rdquant_fwd_(const S& s, A a) using Kernel = ck_tile::AddRmsnorm2dRdquantFwd; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a); @@ -66,5 +66,5 @@ float add_rmsnorm2d_rdquant_fwd_(const S& s, A a) std::cout << ", " << Kernel::GetName() << std::flush; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } diff --git a/example/ck_tile/12_smoothquant/example_smoothquant.cpp b/example/ck_tile/12_smoothquant/example_smoothquant.cpp index 5fcacacee8..e688947d71 100644 --- a/example/ck_tile/12_smoothquant/example_smoothquant.cpp +++ b/example/ck_tile/12_smoothquant/example_smoothquant.cpp @@ -126,12 +126,11 @@ bool run(const ck_tile::ArgParser& arg_parser) auto kargs = Kernel::MakeKargs(args); const dim3 grids = Kernel::GridSize(args); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto s = ck_tile::stream_config{nullptr, true, 1, warmup, repeat}; - ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + ck_tile::launch_kernel(s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); bool pass = true; diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp b/example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp index 555159566e..873a474afb 100644 --- a/example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp @@ -50,7 +50,7 @@ float smoothquant_(const S& s, A a) using Kernel = ck_tile::Smoothquant; const dim3 grids = 
Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a); @@ -58,5 +58,5 @@ float smoothquant_(const S& s, A a) std::cout << ", " << Kernel::GetName() << std::flush; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp index a71c5e51a6..d614b8462a 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp @@ -209,7 +209,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() #define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \ @@ -227,7 +227,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() #if MOE_SORTING_SUPPORT_LARGE_EXPERT #define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_, local_token_) \ @@ -283,7 +283,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi const dim3 grids = kernel::GridSize(a); \ const dim3 blocks = kernel::BlockSize(a); \ const auto lds_size = kernel::GetSmemSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ + return 
ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ }() #define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ @@ -334,15 +334,15 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi } \ } -#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_) \ - [&]() { \ - using problem_ = \ - ck_tile::MoeSortingClearWorkspaceProblem; \ - using kernel = ck_tile::MoeSortingClearWorkspaceKernel; \ - auto kargs = kernel::MakeKargs(a); \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ +#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_) \ + [&]() { \ + using problem_ = \ + ck_tile::MoeSortingClearWorkspaceProblem; \ + using kernel = ck_tile::MoeSortingClearWorkspaceKernel; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_instance_common.hpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_instance_common.hpp index 885d9ff7bf..607217ea52 100644 --- a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_instance_common.hpp +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_instance_common.hpp @@ -53,7 +53,7 @@ float moe_smoothquant_(const S& s, A a) using Kernel = ck_tile::MoeSmoothquant; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a); @@ -61,5 +61,5 @@ float moe_smoothquant_(const S& s, A a) std::cout << ", " << Kernel::GetName() << std::flush; 
return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } diff --git a/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp index 6e54df9fde..9d1675386f 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp @@ -53,7 +53,7 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a) using f_kernel = ck_tile::FusedMoeGemmKernel; const dim3 grids = f_kernel::GridSize(a); - constexpr dim3 blocks = f_kernel::BlockSize(); + const dim3 blocks = f_kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; static int printed = 0; @@ -66,5 +66,5 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a) } return ck_tile::launch_kernel( - s, ck_tile::make_kernel(f_kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(f_kernel{}, grids, blocks, 0, kargs)); } diff --git a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp index 5f87393a0a..441aa84edf 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp @@ -213,7 +213,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() #define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \ @@ -231,7 +231,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til auto kargs = 
kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() #if MOE_SORTING_SUPPORT_LARGE_EXPERT #define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_, local_token_) \ @@ -287,7 +287,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til const dim3 grids = kernel::GridSize(a); \ const dim3 blocks = kernel::BlockSize(a); \ const auto lds_size = kernel::GetSmemSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ }() #define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.cpp b/example/ck_tile/16_batched_gemm/batched_gemm.cpp index 9616abb800..09ba010e00 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.cpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.cpp @@ -142,7 +142,6 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre DsLayout, CLayout, CDEElementWise, - GemmPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, M_Warp, @@ -156,8 +155,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre using Kernel = ck_tile::BatchedGemmKernel; auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); + const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -176,7 +175,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre } ave_time = ck_tile::launch_kernel( - s, 
ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index a821af0649..1e6844261f 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -82,7 +82,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, ck_tile::tuple<>, CLayout, ck_tile::element_wise::PassThrough, - GemmPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, GemmConfig::M_Warp, @@ -92,9 +91,9 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, GemmConfig::K_Warp_Tile, UniversalGemmProblem::TransposeC, memory_operation>>; - using Kernel = ck_tile::GroupedGemmKernel; - constexpr dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); if(s.log_level_ > 0) { @@ -105,7 +104,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, ave_time = ck_tile::launch_kernel(s, - ck_tile::make_kernel( + ck_tile::make_kernel( Kernel{}, grids, blocks, diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 50bf791207..93117e5b75 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -101,7 +101,6 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c DsLayout, ELayout, CDEElementWise, - CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, FlatmmConfig::M_Warp, @@ -119,8 +118,8 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, 
args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -171,15 +170,13 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c ave_time = ck_tile::launch_kernel_time_mask( s, run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } else { - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } return ave_time; }; diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc index 8f39b07be5..013db6715d 100644 --- a/example/ck_tile/18_flatmm/run_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -42,7 +42,9 @@ auto shuffle_b(const ck_tile::HostTensor& t) assert(t.get_lengths().size() == 2); int n_ = t.get_lengths()[1]; int k_ = t.get_lengths()[0]; - constexpr int divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4; + + int divisor = ck_tile::is_wave32() ? (FlatmmConfig::N_Warp_Tile == 32 ? 1 : 2) + : (FlatmmConfig::N_Warp_Tile == 32 ? 
2 : 4); ck_tile::HostTensor t_view({n_ / FlatmmConfig::N_Warp_Tile, FlatmmConfig::N_Warp_Tile, k_ / FlatmmConfig::K_Warp_Tile, @@ -213,6 +215,16 @@ int run_flatmm_example_with_layouts(int argc, ck_tile::FillUniformDistribution{1.f, 1.f}(a_host); ck_tile::FillUniformDistribution{1.f, 1.f}(b_origin_host); } + else if(init_method == 3) + { + ck_tile::FillUniformDistribution{-.5f, .5f}(a_host); + ck_tile::FillUniformDistribution{1.f, 1.f}(b_origin_host); + } + else if(init_method == 4) + { + ck_tile::FillUniformDistribution{1.f, 1.f}(a_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(b_origin_host); + } else { a_host.SetZero(); diff --git a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp index d7bf2b5c42..fc52cb66cc 100644 --- a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp +++ b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp @@ -146,7 +146,6 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config& DsLayout, CLayout, CDEElementWise, - GemmPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, M_Warp, @@ -160,8 +159,8 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config& using Kernel = ck_tile::GemmKernelMultiD; auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -176,7 +175,7 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config& } ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp 
b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp index 67db775e09..debbb6bc0c 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp @@ -78,7 +78,6 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args, typename GroupedConvTraitsType::ImplicitGemmDsLayout, ck_tile::tensor_layout::gemm::RowMajor, CDEElementWise, - CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, M_Warp, @@ -98,8 +97,8 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args, ConvEpilogue>; auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(kargs); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -123,7 +122,7 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args, float ave_time = ck_tile::launch_kernel_time_mask( s, Kernel::Preprocess(kargs, s), - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp index ce19c77bc1..6700970583 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp @@ -77,7 +77,6 @@ float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_til typename GroupedConvTraitsType::ImplicitGemmDsLayout, ck_tile::tensor_layout::gemm::RowMajor, CDEElementWise, - CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, M_Warp, @@ -97,8 +96,8 @@ float 
grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_til ConvEpilogue>; auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(kargs); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -120,7 +119,7 @@ float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_til } float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/example/ck_tile/21_elementwise/elementwise_example.cpp b/example/ck_tile/21_elementwise/elementwise_example.cpp index 469345b46c..2cc539e117 100644 --- a/example/ck_tile/21_elementwise/elementwise_example.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example.cpp @@ -167,17 +167,17 @@ bool run(const ck_tile::ArgParser& arg_parser) } // 4. 
Run the kernel - float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, - ck_tile::make_kernel( - Kernel{}, - kGridSize, - kBlockSize, - 0, - input_size, - ck_tile::make_tuple(N, 1), // Input Stride - ck_tile::make_tuple(N, 1), // Output Stride - input_tensors, - static_cast(y_buf.GetDeviceBuffer()))); + float ave_time = launch_kernel( + ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, + ck_tile::make_kernel(Kernel{}, + kGridSize, + kBlockSize, + 0, + input_size, + ck_tile::make_tuple(N, 1), // Input Stride + ck_tile::make_tuple(N, 1), // Output Stride + input_tensors, + static_cast(y_buf.GetDeviceBuffer()))); std::cout << "Average time: " << ave_time << " ms" << std::endl; diff --git a/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp b/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp index 4a031265c9..7087d092a2 100644 --- a/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp @@ -113,7 +113,7 @@ bool run(const ck_tile::ArgParser& arg_parser) // Run the kernel float ave_time = launch_kernel( ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, - ck_tile::make_kernel( + ck_tile::make_kernel( Kernel{}, kGridSize, kBlockSize, diff --git a/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp b/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp index aff74ae250..28cdaf27b9 100644 --- a/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp @@ -112,17 +112,17 @@ bool run(const ck_tile::ArgParser& arg_parser) } // 4. 
Run the kernel - float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, - ck_tile::make_kernel( - Kernel{}, - kGridSize, - kBlockSize, - 0, // Shared memory - op_lengths, // Logical dimensions for the operation (M, N) - input_strides, // Strides for input tensor(s) - output_strides, // Strides for output tensor (N, M) - input_tensors, - static_cast(y_buf.GetDeviceBuffer()))); + float ave_time = launch_kernel( + ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, + ck_tile::make_kernel(Kernel{}, + kGridSize, + kBlockSize, + 0, // Shared memory + op_lengths, // Logical dimensions for the operation (M, N) + input_strides, // Strides for input tensor(s) + output_strides, // Strides for output tensor (N, M) + input_tensors, + static_cast(y_buf.GetDeviceBuffer()))); std::cout << "Average time: " << ave_time << " ms" << std::endl; diff --git a/example/ck_tile/21_elementwise/elementwise_example_unary.cpp b/example/ck_tile/21_elementwise/elementwise_example_unary.cpp index d83592a033..782d3da24d 100644 --- a/example/ck_tile/21_elementwise/elementwise_example_unary.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example_unary.cpp @@ -99,17 +99,17 @@ bool run(const ck_tile::ArgParser& arg_parser) } // 4. 
Run the kernel - float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, - ck_tile::make_kernel( - Kernel{}, - kGridSize, - kBlockSize, - 0, - input_size, - ck_tile::make_tuple(N, 1), // Input Stride - ck_tile::make_tuple(N, 1), // Output Stride - input_tensors, - static_cast(y_buf.GetDeviceBuffer()))); + float ave_time = launch_kernel( + ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, + ck_tile::make_kernel(Kernel{}, + kGridSize, + kBlockSize, + 0, + input_size, + ck_tile::make_tuple(N, 1), // Input Stride + ck_tile::make_tuple(N, 1), // Output Stride + input_tensors, + static_cast(y_buf.GetDeviceBuffer()))); std::cout << "Average time: " << ave_time << " ms" << std::endl; diff --git a/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp b/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp index 1f0f0b9bc1..931a9dfa3c 100644 --- a/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp +++ b/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp @@ -74,8 +74,8 @@ float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_con auto kargs = kernel::MakeKargs(a); - const dim3 grids = kernel::GridSize(a); - constexpr dim3 blocks = kernel::BlockSize(); + const dim3 grids = kernel::GridSize(a); + const dim3 blocks = kernel::BlockSize(); printf("Pipeline: %d\n", Config::kPipelineId); printf("Grid: x=%u y=%u z=%u\n", grids.x, grids.y, grids.z); @@ -96,8 +96,8 @@ float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_con printf("Launching Kernel...\n"); - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs)); + float ave_time = + ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); printf("Kernel finished...\n"); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp index 2ac08c7343..2ea8530cb2 
100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp @@ -96,7 +96,6 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s ck_tile::tuple<>, CLayout, ck_tile::element_wise::PassThrough, - CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, M_Warp, @@ -111,8 +110,8 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); if(args.k_batch != 1) { @@ -136,7 +135,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s } float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_preshuffle.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_preshuffle.cpp index f4f1aa98d3..4adc3df94b 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_aquant_preshuffle.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_preshuffle.cpp @@ -96,7 +96,6 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s ck_tile::tuple<>, CLayout, ck_tile::element_wise::PassThrough, - CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, M_Warp, @@ -111,8 +110,8 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + 
const dim3 blocks = Kernel::BlockSize(); if(args.k_batch != 1) { @@ -136,7 +135,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s } float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/example/ck_tile/39_copy/copy_basic.cpp b/example/ck_tile/39_copy/copy_basic.cpp index 460036a641..3f36d7f4f0 100644 --- a/example/ck_tile/39_copy/copy_basic.cpp +++ b/example/ck_tile/39_copy/copy_basic.cpp @@ -99,16 +99,16 @@ bool run(const ck_tile::ArgParser& arg_parser) << ")" << std::endl; // Launch kernel - float ave_time = launch_kernel( - ck_tile::stream_config{nullptr, true, warmup, repeat, 1}, - ck_tile::make_kernel(Kernel{}, - kGridSize, - kBlockSize, - 0, - static_cast(x_buf.GetDeviceBuffer()), - static_cast(y_buf.GetDeviceBuffer()), - m, - n)); + float ave_time = + launch_kernel(ck_tile::stream_config{nullptr, true, warmup, repeat, 1}, + ck_tile::make_kernel<1>(Kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(x_buf.GetDeviceBuffer()), + static_cast(y_buf.GetDeviceBuffer()), + m, + n)); // Calculate and print performance metrics std::size_t num_btype = sizeof(XDataType) * m * n + sizeof(YDataType) * m * n; diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index ec5f49108e..234929d6e6 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -60,13 +60,30 @@ enum struct memory_operation_enum : std::uint16_t CK_TILE_HOST_DEVICE constexpr index_t get_warp_size() { -#if defined(__GFX9__) || (!defined(__HIP_DEVICE_COMPILE__) && !defined(CK_TILE_WAVE32_ENABLED)) +#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__) return 64; #else return 32; #endif } +CK_TILE_HOST bool is_wave32() +{ + hipDeviceProp_t props{}; + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) + { + return false; + } + 
status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) + { + return false; + } + return props.major > 9; +} + CK_TILE_DEVICE index_t get_grid_size() { return gridDim.x; } CK_TILE_DEVICE index_t get_block_size() { return blockDim.x; } diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index f94065da2b..7b5b862cb1 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -274,12 +274,6 @@ #define CK_TILE_WA_ISSUE_2028 0 #endif -#ifndef CK_TILE_WAVE32_ENABLED -#if defined(__gfx11__) || defined(__gfx12__) -#define CK_TILE_WAVE32_ENABLED -#endif -#endif - // Y pointed to R, we don't see a valuable use case. // Will enforce encoding to check Y not pointed to R if set to zero #ifndef CK_TILE_ENC_SUPPORT_Y_TO_R diff --git a/include/ck_tile/host/kernel_launch.hpp b/include/ck_tile/host/kernel_launch.hpp index 91ac3d5a0b..368a0594c5 100644 --- a/include/ck_tile/host/kernel_launch.hpp +++ b/include/ck_tile/host/kernel_launch.hpp @@ -15,9 +15,9 @@ namespace ck_tile { -template +template #if CK_TILE_USE_LAUNCH_BOUNDS -__launch_bounds__(MaxThreadPerBlock, MinBlockPerCu) +__launch_bounds__(Kernel::kBlockSize, MinBlockPerCu) #endif __global__ void kentry(Args... args) { @@ -35,15 +35,11 @@ __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu) // // the "static __device__ operator()(some_arg)" is the entry point of KernelImpl // -template +template CK_TILE_HOST auto make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... 
args) { - const auto kernel = kentry; - + const auto kernel = kentry; return [=](const stream_config& s) { kernel<<>>(args...); }; diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp index f06910db3d..c7717f08cd 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -53,6 +53,7 @@ struct AddRmsnorm2dRdquantFwd static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N; static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N; + static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; diff --git a/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp b/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp index a4150e8d84..b0f48f6c5b 100644 --- a/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp +++ b/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp @@ -34,6 +34,8 @@ struct BatchedTransposeKernel using Type = typename Problem::DataType; + static constexpr index_t kBlockSize = Problem::kBlockSize; + struct BatchedTransposeKargs { const void* p_input; diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp index 45803ae2da..b791bf9727 100644 --- 
a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -20,11 +20,10 @@ struct BatchedTransposeLdsProblem static constexpr index_t kRowWarps_ = NumWarps::at(number<0>{}); static constexpr index_t kColWarps_ = NumWarps::at(number<1>{}); - static constexpr index_t kBlockSize_ = get_warp_size() * kRowWarps_ * kColWarps_; static constexpr index_t kRowPerBlock_ = BlockTile::at(number<0>{}); static constexpr index_t kColPerBlock_ = BlockTile::at(number<1>{}); - static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kBlockSize = get_warp_size() * kRowWarps_ * kColWarps_; // warps per block static constexpr index_t kLeadNumWarps = kColWarps_; static constexpr index_t kSecondNumWarps = kRowWarps_; diff --git a/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp b/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp index 103468c5fa..2ec9414f42 100644 --- a/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp +++ b/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp @@ -20,6 +20,8 @@ struct ElementWiseKernel using YDataType = ck_tile::remove_cvref_t; using ElementWiseOperation = ck_tile::remove_cvref_t; + static constexpr index_t kBlockSize = Problem::BlockShape::kBlockSize; + template CK_TILE_DEVICE void operator()(Dims lens, Dims input_strides, diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index f773de9e7e..1d0a4c42f4 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -17,7 +17,6 @@ template ; using ELayout = remove_cvref_t; using CDElementwise = 
remove_cvref_t; - static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size(); static constexpr index_t kMPerBlock = kM_; static constexpr index_t kNPerBlock = kN_; static constexpr index_t MWave = MWave_; diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index 76df056ea6..20ca976590 100644 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -91,13 +91,13 @@ struct FlatmmKernel using FlatmmPipeline = remove_cvref_t; using BlockGemmShape = remove_cvref_t; // TileFlatmmShape - using EpiloguePipeline = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using ELayout = remove_cvref_t; - using DsLayout = remove_cvref_t; - using DsDataType = remove_cvref_t; - static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize; + using EpiloguePipeline = remove_cvref_t; + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using ELayout = remove_cvref_t; + using DsLayout = remove_cvref_t; + using DsDataType = remove_cvref_t; + static constexpr index_t kBlockSize = FlatmmPipeline::BlockSize; using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; @@ -127,7 +127,7 @@ struct FlatmmKernel return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } CK_TILE_HOST static constexpr KernelArgs MakeKernelArgs(const FlatmmHostArgs& hostArgs) diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 20783ea8bf..3ca79fc46e 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ 
b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -237,15 +237,16 @@ struct UniversalFlatmmPipelineAgBgCrPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad() { - using TileShape = typename Problem::BlockGemmShape; + using TileShape = typename Problem::BlockGemmShape; + constexpr index_t scale = get_warp_size() == 32 ? 2 : 1; if constexpr(TileShape::WarpTile::at(I1) == 32) { - return TileShape::WarpTile::at(I2) / 2; + return TileShape::WarpTile::at(I2) * scale / 2; } else { static_assert(TileShape::WarpTile::at(I1) == 16); - return TileShape::WarpTile::at(I2) / 4; + return TileShape::WarpTile::at(I2) * scale / 4; } } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index 0d0959ba27..2850ce3379 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -24,9 +24,10 @@ namespace ck_tile { template struct FmhaBatchPrefillWithPagedKVCacheKernel { - using FmhaPipeline = ck_tile::remove_cvref_t; - using EpiloguePipeline = ck_tile::remove_cvref_t; - static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + using FmhaPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; static_assert(kBlockPerCu > 0); static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index 5129f83532..81075d0ec6 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -16,6 +16,7 @@ struct FmhaFwdAppendKVKernel using FmhaPipeline = ck_tile::remove_cvref_t; static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 5b3d38d3e7..6d35afaa26 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -25,9 +25,10 @@ namespace ck_tile { template struct FmhaFwdKernel { - using FmhaPipeline = ck_tile::remove_cvref_t; - using EpiloguePipeline = ck_tile::remove_cvref_t; - static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + using FmhaPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; static_assert(kBlockPerCu > 0); static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp index d8cd006c60..9a3e8ac304 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp @@ -30,6 +30,7 @@ struct FmhaFwdPagedKVKernel using EpiloguePipeline = ck_tile::remove_cvref_t; static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; diff --git 
a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp index 99ee912db9..ee1236d465 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,6 +14,7 @@ struct FmhaFwdSplitKVCombineKernel static constexpr index_t kNumWarps = FmhaPipeline::kNumWarps; static constexpr index_t kBlockSize = FmhaPipeline::kBlockSize; static constexpr index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); static constexpr index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 501aa26667..c50537f3fe 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -26,6 +26,7 @@ struct FmhaFwdSplitKVKernel using EpiloguePipeline = ck_tile::remove_cvref_t; static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index a5f9f31d6a..faeb5cf6b3 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -213,7 +213,7 @@ struct MoeSortingKernel using Hargs = MoeSortingHostArgs; - static constexpr index_t BLOCK_SIZE = 256; + static constexpr index_t kBlockSize = 256; static constexpr index_t OCCUPANCY = 2; // hard coded struct Kargs @@ -487,8 +487,8 @@ struct MoeSortingKernel vector_type* p_buf = reinterpret_cast(buf); auto zero_ = vector_type{0}; - for(long_index_t i = (blockIdx.x - 1) * BLOCK_SIZE + threadIdx.x; i < total_elems; - i += (gridDim.x - 1) * BLOCK_SIZE) + for(long_index_t i = (blockIdx.x - 1) * kBlockSize + threadIdx.x; i < total_elems; + i += (gridDim.x - 1) * kBlockSize) { p_buf[i] = zero_; } @@ -1419,7 +1419,7 @@ template struct MoeSortingClearWorkspaceKernel { using Problem = remove_cvref_t; - static constexpr index_t BLOCK_SIZE = Problem::BlockSize; + static constexpr index_t kBlockSize = Problem::BlockSize; static constexpr index_t OCCUPANCY = Problem::Occu; using Hargs = MoeSortingHostArgs; @@ -1461,7 +1461,7 @@ struct MoeSortingClearWorkspaceKernel CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; } - CK_TILE_HOST 
static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } + CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); } // in byte CK_TILE_HOST static constexpr auto GetSmemSize() { return 0; } @@ -1499,8 +1499,8 @@ struct MoeSortingClearWorkspaceKernel vector_type* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); auto zero_ = vector_type{0}; - for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elems; - i += gridDim.x * BLOCK_SIZE) + for(index_t i = blockIdx.x * kBlockSize + threadIdx.x; i < total_elems; + i += gridDim.x * kBlockSize) { p_expert_mesh[i] = zero_; } @@ -1560,7 +1560,7 @@ struct MoeSortingMultiPhaseKernel_P0 using WeightType = typename Problem::WeightType; using MeshType = typename Problem::MeshType; - static constexpr index_t BLOCK_SIZE = 256; + static constexpr index_t kBlockSize = 256; static constexpr index_t OCCUPANCY = 2; // hard coded typedef MoeSortingHostArgs MoeSortingKargs; @@ -1604,7 +1604,7 @@ struct MoeSortingMultiPhaseKernel_P0 CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; } - CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } + CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); } // in byte CK_TILE_HOST static constexpr auto GetSmemSize() { return 0; } @@ -1647,8 +1647,8 @@ struct MoeSortingMultiPhaseKernel_P0 index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile; #pragma unroll Problem::SubTokenTile - for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem; - i += gridDim.x * BLOCK_SIZE) + for(index_t i = blockIdx.x * kBlockSize + threadIdx.x; i < total_elem; + i += gridDim.x * kBlockSize) { auto x = p_topk_ids[i]; static_for<0, Problem::SubTokenTile, 1>{}([&](auto j) { @@ -1678,7 +1678,7 @@ struct MoeSortingMultiPhaseKernel_P1 using WeightType = typename Problem::WeightType; using MeshType = typename 
Problem::MeshType; - static constexpr index_t BLOCK_SIZE = 256; + static constexpr index_t kBlockSize = 256; static constexpr index_t OCCUPANCY = 2; // hard coded typedef MoeSortingHostArgs MoeSortingKargs; @@ -1709,12 +1709,12 @@ struct MoeSortingMultiPhaseKernel_P1 CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { return dim3(h.num_experts); } - CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } + CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); } // in byte CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() { - return BLOCK_SIZE / get_warp_size() * sizeof(IndexType); + return kBlockSize / get_warp_size() * sizeof(IndexType); } CK_TILE_DEVICE void operator()(Kargs kargs) const @@ -1756,7 +1756,7 @@ struct MoeSortingMultiPhaseKernel_P1 r_t* p_expert_mesh = reinterpret_cast( reinterpret_cast(kargs.p_expert_mesh) + eid * mesh_stride); - int loops = (mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE; + int loops = (mesh_stride / index_pack + kBlockSize - 1) / kBlockSize; if constexpr(Problem::LocalExpertMasking) { @@ -1768,7 +1768,7 @@ struct MoeSortingMultiPhaseKernel_P1 index_t cnt = 0; // per-wave cnt for(int i = 0; i < loops; i++) { - int position = i * BLOCK_SIZE + threadIdx.x; + int position = i * kBlockSize + threadIdx.x; r_t v{0}; if(position < (mesh_stride / index_pack)) v = p_expert_mesh[position]; @@ -1792,7 +1792,7 @@ struct MoeSortingMultiPhaseKernel_P1 if(threadIdx.x == 0) { index_t c = 0; - for(auto i = 0; i < (BLOCK_SIZE / get_warp_size()); i++) + for(auto i = 0; i < (kBlockSize / get_warp_size()); i++) { c += s[i]; } @@ -1811,7 +1811,7 @@ struct MoeSortingMultiPhaseKernel_P01 using WeightType = typename Problem::WeightType; using MeshType = typename Problem::MeshType; - static constexpr index_t BLOCK_SIZE = 256; + static constexpr index_t kBlockSize = 256; static constexpr index_t OCCUPANCY = 2; // hard coded typedef MoeSortingHostArgs MoeSortingKargs; @@ 
-1878,12 +1878,12 @@ struct MoeSortingMultiPhaseKernel_P01 CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; } - CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } + CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); } CK_TILE_HOST static constexpr auto WGCounts(const Hargs& h) { index_t total_elem = h.tokens * h.topk / Problem::SubTokenTile; - index_t elem_cnt = (total_elem + BLOCK_SIZE - 1) / BLOCK_SIZE; + index_t elem_cnt = (total_elem + kBlockSize - 1) / kBlockSize; // no more than grid_size return min(elem_cnt, GridSize(h)); @@ -1892,7 +1892,7 @@ struct MoeSortingMultiPhaseKernel_P01 // in byte CK_TILE_HOST static constexpr auto GetSmemSize() { - return BLOCK_SIZE / get_warp_size() * sizeof(IndexType); + return kBlockSize / get_warp_size() * sizeof(IndexType); } CK_TILE_DEVICE void operator()(Kargs kargs) const @@ -1921,7 +1921,7 @@ struct MoeSortingMultiPhaseKernel_P01 if constexpr(Problem::LocalToken) { index_t total_elem = rounded_tokens * kargs.topk / Problem::SubTokenTile; - index_t elem_cnt = (total_elem + BLOCK_SIZE - 1) / BLOCK_SIZE; + index_t elem_cnt = (total_elem + kBlockSize - 1) / kBlockSize; // no more than grid_size return min(elem_cnt, kargs.wg_count); @@ -1940,8 +1940,8 @@ struct MoeSortingMultiPhaseKernel_P01 index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile; #pragma unroll Problem::SubTokenTile - for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem; - i += BLOCK_SIZE * gridDim.x) + for(index_t i = blockIdx.x * kBlockSize + threadIdx.x; i < total_elem; + i += kBlockSize * gridDim.x) { auto x = p_topk_ids[i]; static_for<0, Problem::SubTokenTile, 1>{}([&](auto j) { @@ -1996,7 +1996,7 @@ struct MoeSortingMultiPhaseKernel_P01 auto f_sum = [](auto x_, auto y_) { return x_ + y_; }; - int loops = (kargs.mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE; + int loops = 
(kargs.mesh_stride / index_pack + kBlockSize - 1) / kBlockSize; if constexpr(Problem::LocalExpertMasking) { @@ -2008,7 +2008,7 @@ struct MoeSortingMultiPhaseKernel_P01 index_t cnt = 0; // per-wave cnt for(int i = 0; i < loops; i++) { - int position = i * BLOCK_SIZE + threadIdx.x; + int position = i * kBlockSize + threadIdx.x; r_t v{0}; if(position < (kargs.mesh_stride / index_pack)) v = p_expert_mesh[position]; @@ -2033,7 +2033,7 @@ struct MoeSortingMultiPhaseKernel_P01 if(threadIdx.x == 0) { index_t c = 0; - for(auto i = 0; i < (BLOCK_SIZE / get_warp_size()); i++) + for(auto i = 0; i < (kBlockSize / get_warp_size()); i++) { c += s[i]; } @@ -2055,7 +2055,7 @@ struct MoeSortingMultiPhaseKernel_P2 using WeightType = typename Problem::WeightType; using MeshType = typename Problem::MeshType; - static constexpr index_t BLOCK_SIZE = 256; + static constexpr index_t kBlockSize = 256; static constexpr index_t OCCUPANCY = 2; // hard coded typedef MoeSortingHostArgs MoeSortingKargs; @@ -2123,17 +2123,17 @@ struct MoeSortingMultiPhaseKernel_P2 return dim3(h.num_experts + get_num_cu() * OCCUPANCY); #else // use 1 block to cumsum - return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16)); + return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, kBlockSize * 16)); #endif } - CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } + CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); } // in byte CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() { - // return 2 * BLOCK_SIZE * sizeof(IndexType); - return (4 + 2 * BLOCK_SIZE / get_warp_size()) * sizeof(IndexType); + // return 2 * kBlockSize * sizeof(IndexType); + return (4 + 2 * kBlockSize / get_warp_size()) * sizeof(IndexType); } // reduce single pixel within a wave @@ -2142,7 +2142,7 @@ struct MoeSortingMultiPhaseKernel_P2 if(blockIdx.x > 0) { #if MOE_SORTING_FMOE_2D_BUF - impl::moe_buf_set_zero_kernel_2d(kargs.p_moe_buf, + 
impl::moe_buf_set_zero_kernel_2d(kargs.p_moe_buf, kargs.tokens, kargs.moe_buf_interm_dim, kargs.moe_buf_elem_bytes, @@ -2150,7 +2150,7 @@ struct MoeSortingMultiPhaseKernel_P2 gridDim.x - 1); return; #else - impl::moe_buf_set_zero_kernel( + impl::moe_buf_set_zero_kernel( reinterpret_cast(kargs.p_moe_buf), kargs.moe_buf_bytes, blockIdx.x - 1); @@ -2167,7 +2167,7 @@ struct MoeSortingMultiPhaseKernel_P2 reinterpret_cast(kargs.p_total_tokens_post_pad); IndexType* p_sorted_expert_ids = reinterpret_cast(kargs.p_sorted_expert_ids); - const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE; + const index_t loops = (kargs.num_experts + kBlockSize - 1) / kBlockSize; index_t wave_id = threadIdx.x / get_warp_size(); index_t lane_id = threadIdx.x % get_warp_size(); @@ -2176,7 +2176,7 @@ struct MoeSortingMultiPhaseKernel_P2 for(index_t i = 0; i < loops; i++) { - index_t position = i * BLOCK_SIZE + threadIdx.x; + index_t position = i * kBlockSize + threadIdx.x; IndexType a_ = 0; // token count for a expert IndexType b_ = 0; // mask for a expert if(position < kargs.num_experts) @@ -2221,15 +2221,15 @@ struct MoeSortingMultiPhaseKernel_P2 if(lane_id == get_warp_size() - 1) { s[4 + wave_id] = cumsum_a; - s[4 + wave_id + BLOCK_SIZE / get_warp_size()] = cumsum_b; + s[4 + wave_id + kBlockSize / get_warp_size()] = cumsum_b; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) { + static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) { IndexType prev_a = s[4 + i_w]; - IndexType prev_b = s[4 + i_w + BLOCK_SIZE / get_warp_size()]; + IndexType prev_b = s[4 + i_w + kBlockSize / get_warp_size()]; prev_a = wave_id > i_w ? prev_a : 0; // mask out prev_b = wave_id > i_w ? 
prev_b : 0; // mask out cumsum_a += prev_a; @@ -2240,7 +2240,7 @@ struct MoeSortingMultiPhaseKernel_P2 cumsum_a += prev_cumsum_a; cumsum_b += prev_cumsum_b; - if(threadIdx.x == BLOCK_SIZE - 1) + if(threadIdx.x == kBlockSize - 1) { s[2] = cumsum_a; // store the last cumsum s[3] = cumsum_b; @@ -2297,7 +2297,7 @@ struct MoeSortingMultiPhaseKernel_P3 using WeightType = typename Problem::WeightType; using MeshType = typename Problem::MeshType; - static constexpr index_t BLOCK_SIZE = 256; + static constexpr index_t kBlockSize = 256; static constexpr index_t OCCUPANCY = 2; // hard coded typedef MoeSortingHostArgs MoeSortingKargs; @@ -2341,12 +2341,12 @@ struct MoeSortingMultiPhaseKernel_P3 CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { return dim3(h.num_experts); } - CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } + CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); } // in byte CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() { - return (4 + BLOCK_SIZE / get_warp_size()) * sizeof(IndexType); + return (4 + kBlockSize / get_warp_size()) * sizeof(IndexType); } CK_TILE_DEVICE void operator()(Kargs kargs) const @@ -2391,11 +2391,11 @@ struct MoeSortingMultiPhaseKernel_P3 } // cumsum one by one - int loops = (kargs.mesh_stride + BLOCK_SIZE - 1) / BLOCK_SIZE; + int loops = (kargs.mesh_stride + kBlockSize - 1) / kBlockSize; int prev_cumsum = 0; for(int i = 0; i < loops; i++) { - int i_token = i * BLOCK_SIZE + threadIdx.x; + int i_token = i * kBlockSize + threadIdx.x; IndexType x = 0; if(i_token < tokens) { @@ -2414,13 +2414,13 @@ struct MoeSortingMultiPhaseKernel_P3 __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) { + static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) { IndexType prev = s[4 + i_w]; prev = wave_id > i_w ? 
prev : 0; // mask out cumsum += prev; }); cumsum += prev_cumsum; // add previous round cumsum - if(threadIdx.x == BLOCK_SIZE - 1) + if(threadIdx.x == kBlockSize - 1) { s[0] = cumsum; } @@ -2441,7 +2441,7 @@ struct MoeSortingMultiPhaseKernel_P3 } } - for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE) + for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += kBlockSize) { #if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(tokens, kargs.topk_mdiv.divisor); @@ -2457,9 +2457,9 @@ namespace impl { // we use dynamic LDS size here CK_TILE_HOST constexpr auto moe_sorting_get_smem_size_p23(int num_experts_) { - constexpr index_t BLOCK_SIZE = 256; // hardcoded 256 + constexpr index_t kBlockSize = 256; // hardcoded 256 const index_t expert_cumsum_elem = num_experts_ + 1; - return (4 + 2 * BLOCK_SIZE / get_warp_size() + expert_cumsum_elem) * sizeof(int); + return (4 + 2 * kBlockSize / get_warp_size() + expert_cumsum_elem) * sizeof(int); } } // namespace impl @@ -2473,7 +2473,7 @@ struct MoeSortingMultiPhaseKernel_P23 using WeightType = typename Problem::WeightType; using MeshType = typename Problem::MeshType; - static constexpr index_t BLOCK_SIZE = 256; + static constexpr index_t kBlockSize = 256; static constexpr index_t OCCUPANCY = 2; // hard coded typedef MoeSortingHostArgs MoeSortingKargs; @@ -2563,18 +2563,18 @@ struct MoeSortingMultiPhaseKernel_P23 return dim3(h.num_experts + get_num_cu() * OCCUPANCY); #else // use 1 block to cumsum - // return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16)); - return dim3(h.num_experts + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16)); + // return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, kBlockSize * 16)); + return dim3(h.num_experts + ck_tile::integer_divide_ceil(h.moe_buf_bytes, kBlockSize * 16)); #endif } - CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } + 
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); } // only use this at host ! CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h) { const auto smem_23 = impl::moe_sorting_get_smem_size_p23(h.num_experts); - const auto smem_sf = BLOCK_SIZE * 4 * sizeof(IndexType); + const auto smem_sf = kBlockSize * 4 * sizeof(IndexType); return max(smem_23, smem_sf); } @@ -2595,7 +2595,7 @@ struct MoeSortingMultiPhaseKernel_P23 if(static_cast(blockIdx.x) >= kargs.num_experts) { #if MOE_SORTING_FMOE_2D_BUF - impl::moe_buf_set_zero_kernel_2d(kargs.p_moe_buf, + impl::moe_buf_set_zero_kernel_2d(kargs.p_moe_buf, tokens, kargs.moe_buf_interm_dim, kargs.moe_buf_elem_bytes, @@ -2603,7 +2603,7 @@ struct MoeSortingMultiPhaseKernel_P23 gridDim.x - kargs.num_experts); return; #else - impl::moe_buf_set_zero_kernel( + impl::moe_buf_set_zero_kernel( reinterpret_cast(kargs.p_moe_buf), kargs.moe_buf_bytes, blockIdx.x - kargs.num_experts); @@ -2618,13 +2618,13 @@ struct MoeSortingMultiPhaseKernel_P23 const IndexType* p_local_expert_mask = static_cast(kargs.p_local_expert_mask); IndexType* p_expert_cumsum = reinterpret_cast(kargs.p_expert_cumsum); - IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / get_warp_size(); + IndexType* p_expert_cumsum_smem = s + 4 + 2 * kBlockSize / get_warp_size(); IndexType* p_total_tokens_post_pad = reinterpret_cast(kargs.p_total_tokens_post_pad); IndexType* p_sorted_expert_ids = reinterpret_cast(kargs.p_sorted_expert_ids); - const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE; + const index_t loops = (kargs.num_experts + kBlockSize - 1) / kBlockSize; index_t wave_id = threadIdx.x / get_warp_size(); index_t lane_id = threadIdx.x % get_warp_size(); @@ -2633,7 +2633,7 @@ struct MoeSortingMultiPhaseKernel_P23 for(index_t i = 0; i < loops; i++) { - index_t position = i * BLOCK_SIZE + threadIdx.x; + index_t position = i * kBlockSize + threadIdx.x; IndexType a_ = 0; // token count for a expert 
IndexType b_ = 0; // mask for a expert if(position < kargs.num_experts) @@ -2678,15 +2678,15 @@ struct MoeSortingMultiPhaseKernel_P23 if(lane_id == get_warp_size() - 1) { s[4 + wave_id] = cumsum_a; - s[4 + wave_id + BLOCK_SIZE / get_warp_size()] = cumsum_b; + s[4 + wave_id + kBlockSize / get_warp_size()] = cumsum_b; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) { + static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) { IndexType prev_a = s[4 + i_w]; - IndexType prev_b = s[4 + i_w + BLOCK_SIZE / get_warp_size()]; + IndexType prev_b = s[4 + i_w + kBlockSize / get_warp_size()]; prev_a = wave_id > i_w ? prev_a : 0; // mask out prev_b = wave_id > i_w ? prev_b : 0; // mask out cumsum_a += prev_a; @@ -2697,7 +2697,7 @@ struct MoeSortingMultiPhaseKernel_P23 cumsum_a += prev_cumsum_a; cumsum_b += prev_cumsum_b; - if(threadIdx.x == BLOCK_SIZE - 1) + if(threadIdx.x == kBlockSize - 1) { s[2] = cumsum_a; // store the last cumsum s[3] = cumsum_b; @@ -2758,7 +2758,7 @@ struct MoeSortingMultiPhaseKernel_P23 IndexType* s = reinterpret_cast(smem); MeshType* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); IndexType* p_sorted_token_ids = reinterpret_cast(kargs.p_sorted_token_ids); - IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / get_warp_size(); + IndexType* p_expert_cumsum_smem = s + 4 + 2 * kBlockSize / get_warp_size(); const WeightType* p_weights = static_cast(kargs.p_weights); WeightType* p_sorted_weights = reinterpret_cast(kargs.p_sorted_weights); @@ -2795,13 +2795,13 @@ struct MoeSortingMultiPhaseKernel_P23 constexpr index_t index_pack = Problem::SubTokenTile; // always packed using r_t = ext_vector_t; // always use int32x4 using d_t = ext_vector_t; - int loops = (mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE; + int loops = (mesh_stride / index_pack + kBlockSize - 1) / kBlockSize; int prev_cumsum = 0; for(int i = 0; i < loops; i++) { - int i_token_pack = i * BLOCK_SIZE + 
threadIdx.x; + int i_token_pack = i * kBlockSize + threadIdx.x; r_t x_v = 0; if(i_token_pack < (tokens + index_pack - 1) / index_pack) { @@ -2819,7 +2819,7 @@ struct MoeSortingMultiPhaseKernel_P23 static_for<0, index_pack, 1>{}([&](auto j_) { constexpr auto j = j_.value; - x_r[j] = reinterpret_cast(s)[threadIdx.x + j * BLOCK_SIZE]; + x_r[j] = reinterpret_cast(s)[threadIdx.x + j * kBlockSize]; }); } #else @@ -2830,7 +2830,7 @@ struct MoeSortingMultiPhaseKernel_P23 #pragma unroll for(int j = 0; j < index_pack / 2; j++) { - int i_token = i * BLOCK_SIZE * index_pack + threadIdx.x + j * BLOCK_SIZE; + int i_token = i * kBlockSize * index_pack + threadIdx.x + j * kBlockSize; index_t x = x_d[j]; int i_topk = x - 1; // topk of this token int i_show = x != 0 ? 1 : 0; // has this token or not @@ -2845,13 +2845,13 @@ struct MoeSortingMultiPhaseKernel_P23 __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) { + static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) { IndexType prev = s[4 + i_w]; prev = wave_id > i_w ? prev : 0; // mask out cumsum += prev; }); cumsum += prev_cumsum; // add previous round cumsum - if(threadIdx.x == BLOCK_SIZE - 1) + if(threadIdx.x == kBlockSize - 1) { s[0] = cumsum; } @@ -2896,13 +2896,13 @@ struct MoeSortingMultiPhaseKernel_P23 __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) { + static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) { IndexType prev = s[4 + i_w]; prev = wave_id > i_w ? 
prev : 0; // mask out cumsum += prev; }); cumsum += prev_cumsum; // add previous round cumsum - if(threadIdx.x == BLOCK_SIZE - 1) + if(threadIdx.x == kBlockSize - 1) { s[0] = cumsum; } @@ -2912,10 +2912,10 @@ struct MoeSortingMultiPhaseKernel_P23 int position = cumsum - cumsum_store; static_for<0, index_pack, 1>{}([&](auto j_) { constexpr auto j = j_.value; - // int i_token = i * BLOCK_SIZE * index_pack + threadIdx.x + j * - // BLOCK_SIZE; + // int i_token = i * kBlockSize * index_pack + threadIdx.x + j * + // kBlockSize; int i_token = - i * BLOCK_SIZE * index_pack + threadIdx.x * index_pack + j; + i * kBlockSize * index_pack + threadIdx.x * index_pack + j; if(i_show[j]) { @@ -2932,7 +2932,7 @@ struct MoeSortingMultiPhaseKernel_P23 }); #if 0 - int i_token = i * BLOCK_SIZE * index_pack + threadIdx.x * 2 + j * BLOCK_SIZE * 2; + int i_token = i * kBlockSize * index_pack + threadIdx.x * 2 + j * kBlockSize * 2; index_t x = x_d[j]; index_t x0 = static_cast(x & 0xffff); index_t x1 = static_cast(x >> 16); @@ -2951,13 +2951,13 @@ struct MoeSortingMultiPhaseKernel_P23 __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) { + static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) { IndexType prev = s[4 + i_w]; prev = wave_id > i_w ? 
prev : 0; // mask out cumsum += prev; }); cumsum += prev_cumsum; // add previous round cumsum - if(threadIdx.x == BLOCK_SIZE - 1) + if(threadIdx.x == kBlockSize - 1) { s[0] = cumsum; } @@ -2996,7 +2996,7 @@ struct MoeSortingMultiPhaseKernel_P23 } } - for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE) + for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += kBlockSize) { #if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(tokens, kargs.topk_mdiv.divisor); diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index 9c1ce73eac..fcfbf9635f 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -64,6 +64,7 @@ struct BatchedGemmKernel /// functions. using UniversalGemmKernel = UniversalGemmKernel; + static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize; using TilePartitioner = remove_cvref_t; using GemmPipeline = remove_cvref_t; @@ -121,9 +122,16 @@ struct BatchedGemmKernel return dim3(TilePartitioner::GridSize(M, N), batch_count, KBatch); } - CK_TILE_HOST static constexpr auto BlockSize() -> dim3 + CK_TILE_HOST static auto BlockSize() -> dim3 { - return dim3(UniversalGemmKernel::KernelBlockSize); + if(ck_tile::is_wave32()) + { + return dim3(UniversalGemmKernel::kBlockSize / 2); + } + else + { + return dim3(UniversalGemmKernel::kBlockSize); + } } CK_TILE_HOST static constexpr BatchedGemmKernelArgs diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 079d3972d1..e37b4f36d4 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -113,6 +113,7 @@ struct GemmKernel static constexpr index_t NumATensor = 1; static constexpr index_t NumBTensor = 1; + static constexpr index_t kBlockSize = 
UniversalGemmKernel::kBlockSize; CK_TILE_HOST static auto GetName() -> const std::string { diff --git a/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp index 34340008d4..34c4e72b22 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp @@ -86,6 +86,7 @@ struct GemmKernelMultiD /// functions. using UniversalGemmKernel = UniversalGemmKernel; + static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize; using TilePartitioner = remove_cvref_t; using GemmPipeline = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 477a87d42f..c35435ee5e 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -128,7 +128,7 @@ struct GroupedGemmKernel using OffsetTile1DPartitioner = OffsettedTile1DPartitioner; using Kernel = GroupedGemmKernel; - static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; + static constexpr index_t kBlockSize = GemmPipeline::BlockSize; static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel; [[nodiscard]] CK_TILE_HOST static const std::string GetName() @@ -155,7 +155,7 @@ struct GroupedGemmKernel return group_count * sizeof(GemmTransKernelArg); } - CK_TILE_HOST static constexpr auto BlockSize() -> dim3 { return dim3(KernelBlockSize); } + CK_TILE_HOST static constexpr auto BlockSize() -> dim3 { return dim3(kBlockSize); } /** * @brief Get the maximum occupancy grid size for the persistent kernel on the current device. 
@@ -166,10 +166,10 @@ struct GroupedGemmKernel CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 { using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*; - const auto kernel = kentry; + const auto kernel = kentry<1, Kernel, ConstantPointer, index_t>; int occupancy; HIP_CHECK_ERROR( - hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0)); + hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0)); const int grid_size = get_available_compute_units(s) * occupancy; return dim3(grid_size, 1, 1); } diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index ec1cc2ddb4..8117d65758 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -196,7 +196,7 @@ struct UniversalGemmKernel using ELayout = remove_cvref_t; using EDataType = remove_cvref_t; - static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; + static constexpr index_t kBlockSize = GemmPipeline::BlockSize; // Get the persistent kernel if the pipeline has it available struct has_persistent_kernel @@ -275,15 +275,26 @@ struct UniversalGemmKernel CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 { using Kernel = UniversalGemmKernel; - const auto kernel = kentry; + const auto kernel = kentry<1, Kernel, KernelArgs>; int occupancy; hip_check_error( - hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0)); + hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize().x, 0)); + const int grid_size = get_available_compute_units(s) * occupancy; return dim3(grid_size, 1, 1); } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + CK_TILE_HOST static auto BlockSize() + { + if(ck_tile::is_wave32()) + { + return dim3(kBlockSize / 2); + } + else + { + return 
dim3(kBlockSize); + } + } CK_TILE_HOST static constexpr KernelArgs MakeKernelArgs(const UniversalGemmHostArgs& hostArgs) @@ -371,7 +382,9 @@ struct UniversalGemmKernel } } - bool AsTesnorIsValid = {true}; + const auto vectorSizeA = is_wave32() ? GemmPipeline::template GetVectorSizeA() + : GemmPipeline::template GetVectorSizeA(); + bool AsTesnorIsValid = {true}; static_for<0, NumATensor, 1>{}([&](auto index) { using AiLayout = remove_cvref_t>; if constexpr(std::is_same_v) @@ -387,7 +400,7 @@ struct UniversalGemmKernel } AsTesnorIsValid = false; } - if(kargs.K % GemmPipeline::GetVectorSizeA() != 0) + if(kargs.K % vectorSizeA != 0) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { @@ -407,7 +420,7 @@ struct UniversalGemmKernel } AsTesnorIsValid = false; } - if(kargs.M % GemmPipeline::GetVectorSizeA() != 0) + if(kargs.M % vectorSizeA != 0) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { @@ -418,7 +431,9 @@ struct UniversalGemmKernel } }); - bool BsTesnorIsValid = {true}; + bool BsTesnorIsValid = {true}; + const auto vectorSizeB = is_wave32() ? 
GemmPipeline::template GetVectorSizeB() + : GemmPipeline::template GetVectorSizeB(); static_for<0, NumBTensor, 1>{}([&](auto index) { using BiLayout = remove_cvref_t>; if constexpr(std::is_same_v) @@ -432,7 +447,7 @@ struct UniversalGemmKernel } BsTesnorIsValid = false; } - if(kargs.N % GemmPipeline::GetVectorSizeB() != 0) + if(kargs.N % vectorSizeB != 0) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { @@ -454,7 +469,7 @@ struct UniversalGemmKernel } BsTesnorIsValid = false; } - if(kargs.K % GemmPipeline::GetVectorSizeB() != 0) + if(kargs.K % vectorSizeB != 0) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 2d439c6970..5f4ee8987e 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -127,8 +127,16 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; - static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } - static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } + template + static constexpr index_t GetVectorSizeA() + { + return Policy::template GetVectorSizeA(); + } + template + static constexpr index_t GetVectorSizeB() + { + return Policy::template GetVectorSizeB(); + } static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } static constexpr index_t APackedSize = diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp index b0cd93a661..c835809b5d 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp +++ 
b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -124,8 +124,16 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; - static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } - static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } + template + static constexpr index_t GetVectorSizeA() + { + return Policy::template GetVectorSizeA(); + } + template + static constexpr index_t GetVectorSizeB() + { + return Policy::template GetVectorSizeB(); + } static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp index 0fdcc04d89..b05145890f 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp @@ -61,8 +61,16 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; - static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } - static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } + template + static constexpr index_t GetVectorSizeA() + { + return Policy::template GetVectorSizeA(); + } + template + static constexpr index_t GetVectorSizeB() + { + return Policy::template GetVectorSizeB(); + } static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } static constexpr bool kPadM = Problem::kPadM; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp 
b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index d62add7ef3..e1acfebc47 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -176,8 +176,16 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; - static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } - static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } + template + static constexpr index_t GetVectorSizeA() + { + return Policy::template GetVectorSizeA(); + } + template + static constexpr index_t GetVectorSizeB() + { + return Policy::template GetVectorSizeB(); + } static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index d8118a7f8f..e3b4863392 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -36,8 +36,16 @@ struct GemmPipelineAGmemBGmemCRegV1 static constexpr index_t kNPerBlock = BlockGemmShape::kN; static constexpr index_t kKPerBlock = BlockGemmShape::kK; - static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; } - static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; } + template + static constexpr index_t GetVectorSizeA() + { + return Problem::VectorSizeA; + } + template + static constexpr index_t GetVectorSizeB() + { + return Problem::VectorSizeB; + } static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; } static constexpr index_t GetSmemPackA() { return 
Policy::template GetSmemPackA(); } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index e4b3649595..40ee952b1b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -305,11 +305,15 @@ struct UniversalGemmBasePolicy * @tparam XPerTile The contiguous Tile dimension size. * @return Maximum DRAM vector load size. */ - template + template CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize() { - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t BlockSize = IsWave32Host ? Problem::kBlockSize / 2 : Problem::kBlockSize; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize; constexpr index_t PackedSize = ck_tile::numeric_traits>::PackedSize; @@ -349,7 +353,7 @@ struct UniversalGemmBasePolicy } } - template + template CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA() { using ALayout = remove_cvref_t; @@ -359,15 +363,23 @@ struct UniversalGemmBasePolicy if constexpr(std::is_same_v) { - return GetGlobalVectorLoadSize(); + return GetGlobalVectorLoadSize(); } else { - return GetGlobalVectorLoadSize(); + return GetGlobalVectorLoadSize(); } } - template + template CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB() { using BLayout = remove_cvref_t; @@ -377,11 +389,19 @@ struct UniversalGemmBasePolicy if constexpr(std::is_same_v) { - return GetGlobalVectorLoadSize(); + return GetGlobalVectorLoadSize(); } else { - return GetGlobalVectorLoadSize(); + return GetGlobalVectorLoadSize(); } } diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp index 
cadd77a61f..b91c211d91 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp @@ -59,13 +59,15 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp; static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp; + template static constexpr index_t GetVectorSizeA() { - return PipelinePolicy::template GetVectorSizeA(); + return PipelinePolicy::template GetVectorSizeA(); } + template static constexpr index_t GetVectorSizeB() { - return PipelinePolicy::template GetVectorSizeB(); + return PipelinePolicy::template GetVectorSizeB(); } static constexpr bool kPadM = Problem::kPadM; diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp index 9c0f257e8e..c507d8d8d8 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp @@ -76,13 +76,15 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp; static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp; + template static constexpr index_t GetVectorSizeA() { - return PipelinePolicy::template GetVectorSizeA(); + return PipelinePolicy::template GetVectorSizeA(); } + template static constexpr index_t GetVectorSizeB() { - return PipelinePolicy::template GetVectorSizeB(); + return PipelinePolicy::template GetVectorSizeB(); } static constexpr bool kPadM = Problem::kPadM; diff --git a/include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp b/include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp index 78a514d6cd..6973c80d57 100644 --- a/include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp +++ 
b/include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp @@ -99,15 +99,15 @@ struct AQuantGemmKernelArgs template struct AQuantGemmKernel { - using TilePartitioner = remove_cvref_t; - using GemmPipeline = remove_cvref_t; - using EpiloguePipeline = remove_cvref_t; - using ALayout = remove_cvref_t; - using AQLayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; - static constexpr bool Preshuffle = GemmPipeline::Preshuffle; + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + using ALayout = remove_cvref_t; + using AQLayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + static constexpr index_t kBlockSize = GemmPipeline::BlockSize; + static constexpr bool Preshuffle = GemmPipeline::Preshuffle; using ADataType = remove_cvref_t; using AQDataType = remove_cvref_t; @@ -131,7 +131,7 @@ struct AQuantGemmKernel return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } CK_TILE_HOST static constexpr AQuantGemmKernelArgs MakeKernelArgs(const AQuantGemmHostArgs& hostArgs) diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 115f6dea19..7ea2e31706 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -354,7 +354,7 @@ struct GroupedConvolutionBackwardWeightKernel using GemmDsLayout = remove_cvref_t; static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor; - static 
constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; + static constexpr index_t kBlockSize = GemmPipeline::BlockSize; using InDataType = remove_cvref_t; using WeiDataType = remove_cvref_t; @@ -393,7 +393,7 @@ struct GroupedConvolutionBackwardWeightKernel TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch); } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } CK_TILE_HOST static constexpr GroupedConvBwdWeightKernelArgsSpecialized MakeKernelArgs(const GroupedConvBwdWeightHostArgs& hostArgs) diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index 8cd1710043..d3a90ea144 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -361,7 +361,7 @@ struct GroupedConvolutionForwardKernel using GemmDsLayout = remove_cvref_t; static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor; - static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; + static constexpr index_t kBlockSize = GemmPipeline::BlockSize; using InDataType = remove_cvref_t; using WeiDataType = remove_cvref_t; @@ -398,7 +398,7 @@ struct GroupedConvolutionForwardKernel TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch); } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized MakeKernelArgs(const GroupedConvFwdHostArgs& hostArgs) diff --git a/include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp 
b/include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp index ee74f1588f..eb54807d88 100644 --- a/include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp +++ b/include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -31,6 +31,7 @@ struct ImageToColumn static constexpr index_t kMPerBlock = Problem::BlockShape::kMPerBlock; static constexpr index_t kKPerBlock = Problem::BlockShape::kKPerBlock; + static constexpr index_t kBlockSize = Problem::BlockShape::kBlockSize; struct Kargs { diff --git a/include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp b/include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp index ad513dbd11..05490ac3ed 100644 --- a/include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp +++ b/include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp @@ -14,11 +14,10 @@ struct TileImageToColumnShape static constexpr index_t kMPerThread = ThreadTile::at(number<0>{}); static constexpr index_t kKPerThread = ThreadTile::at(number<1>{}); - static constexpr index_t kMPerWarp = WarpTile::at(number<0>{}); - static constexpr index_t kKPerWarp = WarpTile::at(number<1>{}); - + static constexpr index_t kMPerWarp = WarpTile::at(number<0>{}); static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread; - static constexpr index_t kKThreadPerWarp = kKPerWarp / kKPerThread; + static constexpr index_t kKThreadPerWarp = get_warp_size() / kMThreadPerWarp; + static constexpr index_t kKPerWarp = kKPerThread * kKThreadPerWarp; static constexpr index_t kMPerBlock = BlockTile::at(number<0>{}); static constexpr index_t kKPerBlock = BlockTile::at(number<1>{}); diff --git a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp 
b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp index 146ac40fb7..6998b358d8 100644 --- a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp @@ -76,9 +76,9 @@ struct Layernorm2dFwd static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N; static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N; - - static constexpr auto I0 = number<0>{}; - static constexpr auto I1 = number<1>{}; + static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; struct Kargs { diff --git a/include/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp b/include/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp index 1c5cc4a11a..3578e3b375 100644 --- a/include/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp +++ b/include/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once diff --git a/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp b/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp index 0cae4023b7..5755f38475 100644 --- a/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp +++ b/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp @@ -25,6 +25,8 @@ struct Reduce using ComputeDataType = ck_tile::remove_cvref_t; using YDataType = ck_tile::remove_cvref_t; + static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; + private: // Helper function to calculate optimal vector size for input tensor template diff --git a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp index 6cb81b8856..e7f4ce0ba8 100644 --- a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp @@ -70,6 +70,7 @@ struct Rmsnorm2dFwd static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N; static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N; + static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; diff --git a/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp b/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp index cb934c6c52..b70e996617 100644 --- a/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp +++ b/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp @@ -48,6 +48,7 @@ struct MoeSmoothquant static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N; static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N; + static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; static constexpr auto I0 = number<0>{}; static 
constexpr auto I1 = number<1>{}; diff --git a/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp b/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp index 540fddd2e8..7dc913901e 100644 --- a/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp +++ b/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp @@ -45,6 +45,7 @@ struct Smoothquant static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N; static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N; + static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; diff --git a/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp b/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp index b8520ae61a..277049f6b0 100644 --- a/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp +++ b/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -34,6 +34,8 @@ struct TopkSoftmaxKernel using WeightType = typename Problem::WeightType; using IndexType = typename Problem::IndexType; + static constexpr index_t kBlockSize = Problem::BlockSize; + struct TopkSoftmaxKargs { const void* p_input; diff --git a/include/ck_tile/ref/naive_attention.hpp b/include/ck_tile/ref/naive_attention.hpp index 172fcee2e3..50e963bd72 100644 --- a/include/ck_tile/ref/naive_attention.hpp +++ b/include/ck_tile/ref/naive_attention.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -117,7 +117,7 @@ struct naive_attention_fwd_kernel std::is_same_v && std::is_same_v; static constexpr int v_per_token_quant_group_size = 64; - + static constexpr int kBlockSize = 256; // TODO: hardcode using SoftmaxType = float; // always using float to do softmax compute using QuantComputeType = float; // used for quant/dequant scale compute @@ -254,7 +254,7 @@ struct naive_attention_fwd_kernel __device__ T load(int i_s, int i_h, int i_d) { return base_ptr[get_offset(i_s, i_h, i_d)]; } }; - __device__ __host__ static constexpr int get_block_size() { return 256; } + __device__ __host__ static constexpr int get_block_size() { return kBlockSize; } // for simpliciy, 1 WG always compute 1 token along q, compute all token along kv // compute all hdim from q, compute WG_SIZE hdim from v diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp index 25b10e1dc4..dd90034064 100644 --- a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp @@ -66,5 +66,5 @@ float add_rmsnorm2d_rdquant_fwd_(const S& s, A a) std::cout << ", " << Kernel::GetName() << std::flush; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp index f654d1a917..f634e508e3 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp +++ b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp @@ -111,7 +111,6 @@ class TestCkTileBatchedGemm : public ::testing::Test DsLayout, CLayout, ck_tile::element_wise::PassThrough, - GemmPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, 
TilePartitioner::NPerBlock, M_Warp, @@ -124,8 +123,8 @@ class TestCkTileBatchedGemm : public ::testing::Test using Kernel = ck_tile::BatchedGemmKernel; auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); + const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -144,7 +143,7 @@ class TestCkTileBatchedGemm : public ::testing::Test } ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/test/ck_tile/batched_transpose/test_batched_transpose.cpp b/test/ck_tile/batched_transpose/test_batched_transpose.cpp index 77d5825eed..8812397946 100644 --- a/test/ck_tile/batched_transpose/test_batched_transpose.cpp +++ b/test/ck_tile/batched_transpose/test_batched_transpose.cpp @@ -137,11 +137,11 @@ class TestCkTileBatchedTranspose // N C H W layout_in== Config::BlockTile::at(1)}; auto kargs = Kernel::MakeKargs(host_args); - auto sc = ck_tile::stream_config{}; - const dim3 grid_size = Kernel::GridSize(host_args); - constexpr dim3 block_size = Kernel::BlockSize(); - ck_tile::launch_kernel( - sc, ck_tile::make_kernel(Kernel{}, grid_size, block_size, 0, kargs)); + auto sc = ck_tile::stream_config{}; + const dim3 grid_size = Kernel::GridSize(host_args); + const dim3 block_size = Kernel::BlockSize(); + ck_tile::launch_kernel(sc, + ck_tile::make_kernel<1>(Kernel{}, grid_size, block_size, 0, kargs)); y_dev.FromDevice(y_host.data()); ck_tile::reference_batched_transpose(x_host, y_ref, layout_in, layout_out); diff --git a/test/ck_tile/elementwise/test_elementwise_1d.cpp b/test/ck_tile/elementwise/test_elementwise_1d.cpp index 9966c369be..3ce6e78d1d 100644 --- a/test/ck_tile/elementwise/test_elementwise_1d.cpp +++ 
b/test/ck_tile/elementwise/test_elementwise_1d.cpp @@ -118,19 +118,17 @@ class TestCkTileElementwise : public ::testing::Test "The kernel configuration is not supported for the given input size."); } - ck_tile::launch_kernel( - s, - ck_tile::make_kernel // MinBlockPerCu - (ew_kernel, - grid, - block, - 0, // actual shared memory - lens, - strides, // input strides - strides, // output strides - d_x_ptrs_tuple, - p_y_device)); + ck_tile::launch_kernel(s, + ck_tile::make_kernel // MinBlockPerCu + (ew_kernel, + grid, + block, + 0, // actual shared memory + lens, + strides, // input strides + strides, // output strides + d_x_ptrs_tuple, + p_y_device)); d_y_mem.FromDevice(h_y.data()); diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc index 4321709ea5..53eff9ecc4 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc @@ -77,7 +77,6 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) ck_tile::tuple<>, CLayout, ck_tile::element_wise::PassThrough, - CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, M_Warp, @@ -93,8 +92,8 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -114,7 +113,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) } float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git 
a/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc index a22ecf2486..adae8dcf92 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc @@ -91,7 +91,6 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) DsLayout, ELayout, CDEElementWise, - GemmPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, GemmConfig::M_Warp, @@ -114,7 +113,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { grids = Kernel::GridSize(args.M, args.N, args.k_batch); } - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -165,15 +164,13 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) ave_time = ck_tile::launch_kernel_time_mask( s, run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } else { - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } return ave_time; }; diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index 26ff847841..af4f8d3d38 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -10,6 +10,7 @@ #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" +#include "ck_tile/core/numeric/math.hpp" template auto calculate_rtol_atol(const ck_tile::index_t K, @@ -184,7 +185,6 @@ class TestCkTileGemmPipeline : public ::testing::Test DsLayout, CLayout, ck_tile::element_wise::PassThrough, - GemmPipeline::BlockSize, 
TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, M_Warp, @@ -207,7 +207,7 @@ class TestCkTileGemmPipeline : public ::testing::Test { grids = Kernel::GridSize(args.M, args.N, args.k_batch); } - constexpr dim3 blocks = Kernel::BlockSize(); + dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -222,7 +222,7 @@ class TestCkTileGemmPipeline : public ::testing::Test } ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { diff --git a/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc b/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc index 0b886938b8..e8ff45fc5e 100644 --- a/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc +++ b/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc @@ -99,7 +99,6 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s ck_tile::tuple<>, CLayout, ck_tile::element_wise::PassThrough, - CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, M_Warp, @@ -114,8 +113,8 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); if(args.k_batch != 1) { @@ -139,7 +138,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s } float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp 
b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp index c08951435e..d21777c92b 100644 --- a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp @@ -178,7 +178,6 @@ class TestCkTileGemmMultiD : public ::testing::Test DsLayout, ELayout, CDEElementWise, - GemmPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, M_Warp, @@ -192,8 +191,8 @@ class TestCkTileGemmMultiD : public ::testing::Test using Kernel = ck_tile::GemmKernelMultiD; auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -212,7 +211,7 @@ class TestCkTileGemmMultiD : public ::testing::Test } ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp index af229aad29..5d52f15696 100644 --- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp @@ -183,7 +183,6 @@ class TestCkTileGemmPipeline : public ::testing::Test DsLayout, CLayout, ck_tile::element_wise::PassThrough, - GemmPipeline::BlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, GemmConfig::M_Warp, @@ -206,7 +205,7 @@ class TestCkTileGemmPipeline : public ::testing::Test { grids = Kernel::GridSize(args.M, args.N, args.k_batch); } - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -221,7 +220,7 @@ class TestCkTileGemmPipeline : public ::testing::Test } 
ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp index cededd38f9..5aca02a433 100644 --- a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp @@ -136,7 +136,6 @@ class TestCkTileGroupedGemm : public ::testing::Test DsLayout, CLayout, ck_tile::element_wise::PassThrough, - GemmPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, GroupedGemKernelParam::M_Warp, @@ -150,8 +149,8 @@ class TestCkTileGroupedGemm : public ::testing::Test auto kargs = Kernel::MakeKargs(gemm_descs); EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); - const dim3 grids = Kernel::GridSize(gemm_descs); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + const dim3 blocks = Kernel::BlockSize(); ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, kargs.data(), @@ -169,7 +168,7 @@ class TestCkTileGroupedGemm : public ::testing::Test ave_time = ck_tile::launch_kernel( s, - ck_tile::make_kernel( + ck_tile::make_kernel( Kernel{}, grids, blocks, @@ -227,12 +226,6 @@ class TestCkTileGroupedGemm : public ::testing::Test using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile::PersistentTileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; const auto Run = [&](const auto memory_operation_) { constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; @@ -268,7 +259,6 @@ class TestCkTileGroupedGemm : public ::testing::Test DsLayout, CLayout, ck_tile::element_wise::PassThrough, - GemmPipelineProblem::kBlockSize, 
TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, GroupedGemKernelParam::M_Warp, @@ -279,8 +269,8 @@ class TestCkTileGroupedGemm : public ::testing::Test UniversalGemmProblem::TransposeC, memory_operation>>; using Kernel = ck_tile::GroupedGemmKernel; - constexpr dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); if(s.log_level_ > 0) { @@ -291,7 +281,7 @@ class TestCkTileGroupedGemm : public ::testing::Test } ck_tile::launch_kernel(s, - ck_tile::make_kernel( + ck_tile::make_kernel( Kernel{}, grids, blocks, diff --git a/test/ck_tile/image_to_column/test_tile_image_to_column.cpp b/test/ck_tile/image_to_column/test_tile_image_to_column.cpp index 9c0746e972..c721f1073f 100644 --- a/test/ck_tile/image_to_column/test_tile_image_to_column.cpp +++ b/test/ck_tile/image_to_column/test_tile_image_to_column.cpp @@ -97,13 +97,13 @@ class TestCkTileImageToColumn : public ::testing::Test kargs.N * kargs.output_spatial_lengths[0] * kargs.output_spatial_lengths[1], kargs.filter_spatial_lengths[0] * kargs.filter_spatial_lengths[1] * kargs.C, kargs.G); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 2; ck_tile::launch_kernel( ck_tile::stream_config{}, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); // reference ck_tile::reference_im2col(in, out_host, conv_params); diff --git a/test/ck_tile/layernorm2d/generate.py b/test/ck_tile/layernorm2d/generate.py index d77582630a..c4366f6662 100644 --- a/test/ck_tile/layernorm2d/generate.py +++ b/test/ck_tile/layernorm2d/generate.py @@ -235,7 +235,7 @@ float layernorm2d_fwd_(const S& s, A a) using Kernel = ck_tile::Layernorm2dFwd; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = 
Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a); @@ -243,7 +243,7 @@ float layernorm2d_fwd_(const S& s, A a) std::cout << ", " << Kernel::GetName() << std::flush; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{{}}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{{}}, grids, blocks, 0, kargs)); }} """ diff --git a/test/ck_tile/memory_copy/test_copy.cpp b/test/ck_tile/memory_copy/test_copy.cpp index e8962dce29..30a2e60ea9 100644 --- a/test/ck_tile/memory_copy/test_copy.cpp +++ b/test/ck_tile/memory_copy/test_copy.cpp @@ -76,17 +76,17 @@ class TestCkTileMemoryCopy : public ::testing::TestWithParam( - Kernel{}, - kGridSize, - kBlockSize, - 0, - static_cast(x_buf.GetDeviceBuffer()), - static_cast(y_buf.GetDeviceBuffer()), - m, - n, - warp_id)); + auto ms = launch_kernel( + ck_tile::stream_config{nullptr, true}, + ck_tile::make_kernel(Kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(x_buf.GetDeviceBuffer()), + static_cast(y_buf.GetDeviceBuffer()), + m, + n, + warp_id)); auto bytes = 2 * m * n * sizeof(DataType); std::cout << "elapsed: " << ms << " (ms)" << std::endl; diff --git a/test/ck_tile/memory_copy/test_copy.hpp b/test/ck_tile/memory_copy/test_copy.hpp index a9840ba2c6..4833b29560 100644 --- a/test/ck_tile/memory_copy/test_copy.hpp +++ b/test/ck_tile/memory_copy/test_copy.hpp @@ -64,7 +64,8 @@ struct TileCopy using Problem = ck_tile::remove_cvref_t; using XDataType = typename Problem::XDataType; - static constexpr bool AsyncCopy = Problem::AsyncCopy; + static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; + static constexpr bool AsyncCopy = Problem::AsyncCopy; template CK_TILE_DEVICE static constexpr auto MakeDRAMDistribution() diff --git a/test/ck_tile/moe_smoothquant/instances/moe_smoothquant_instance_common.hpp b/test/ck_tile/moe_smoothquant/instances/moe_smoothquant_instance_common.hpp index 9d8c9caf00..f2875c72c8 100644 --- 
a/test/ck_tile/moe_smoothquant/instances/moe_smoothquant_instance_common.hpp +++ b/test/ck_tile/moe_smoothquant/instances/moe_smoothquant_instance_common.hpp @@ -61,5 +61,5 @@ float moe_smoothquant_(const S& s, A a) std::cout << ", " << Kernel::GetName() << std::flush; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } diff --git a/test/ck_tile/moe_sorting/moe_sorting_api.cpp b/test/ck_tile/moe_sorting/moe_sorting_api.cpp index 0f25e17867..0cf600d2b4 100644 --- a/test/ck_tile/moe_sorting/moe_sorting_api.cpp +++ b/test/ck_tile/moe_sorting/moe_sorting_api.cpp @@ -209,7 +209,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() #define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \ @@ -227,7 +227,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() #if MOE_SORTING_SUPPORT_LARGE_EXPERT #define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_, local_token_) \ @@ -283,7 +283,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi const dim3 grids = kernel::GridSize(a); \ const dim3 blocks = kernel::BlockSize(a); \ const auto lds_size = kernel::GetSmemSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ }() #define 
MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ @@ -334,15 +334,15 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi } \ } -#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_) \ - [&]() { \ - using problem_ = \ - ck_tile::MoeSortingClearWorkspaceProblem; \ - using kernel = ck_tile::MoeSortingClearWorkspaceKernel; \ - auto kargs = kernel::MakeKargs(a); \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ +#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_) \ + [&]() { \ + using problem_ = \ + ck_tile::MoeSortingClearWorkspaceProblem; \ + using kernel = ck_tile::MoeSortingClearWorkspaceKernel; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) diff --git a/test/ck_tile/permute/alternative_impl/matrix_core_swizzle_kernel.hpp b/test/ck_tile/permute/alternative_impl/matrix_core_swizzle_kernel.hpp index c94adc24c3..498d93b656 100644 --- a/test/ck_tile/permute/alternative_impl/matrix_core_swizzle_kernel.hpp +++ b/test/ck_tile/permute/alternative_impl/matrix_core_swizzle_kernel.hpp @@ -1,5 +1,5 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -115,11 +115,12 @@ struct matrix_core_swizzle_kernel __host__ void operator()(const ck_tile::stream_config& s) const { - ck_tile::kentry<<>>(a); + ck_tile::kentry<1, kernel><<>>(a); } struct kernel { + static constexpr ck_tile::index_t kBlockSize = BLOCK_SIZE; __device__ static constexpr auto get_src_dist() { using namespace ck_tile; diff --git a/test/ck_tile/permute/test_permute_util.hpp b/test/ck_tile/permute/test_permute_util.hpp index cca3148382..5494749541 100644 --- a/test/ck_tile/permute/test_permute_util.hpp +++ b/test/ck_tile/permute/test_permute_util.hpp @@ -54,11 +54,11 @@ float permute(permute_args a, const ck_tile::stream_config& s) auto kargs = Kernel::MakeKargs(a); - const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(a); + const dim3 blocks = Kernel::BlockSize(); - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + float ave_time = + ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(Kernel{}, grids, blocks, 0, kargs)); return ave_time; } diff --git a/test/ck_tile/reduce/test_reduce2d.cpp b/test/ck_tile/reduce/test_reduce2d.cpp index 821d0a6c3e..ff807e52c9 100644 --- a/test/ck_tile/reduce/test_reduce2d.cpp +++ b/test/ck_tile/reduce/test_reduce2d.cpp @@ -82,18 +82,18 @@ class TestCkTileReduce : public ::testing::Test throw std::runtime_error("Wrong! 
Arguments not supported!\n"); } - ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0}, - ck_tile::make_kernel( - Kernel{}, - kGridSize, - kBlockSize, - 0, - static_cast(d_x_mem.GetDeviceBuffer()), - static_cast(d_y_mem.GetDeviceBuffer()), - input_shape_tuple, - input_strides_tuple, - kept_dims, - reduce_dims)); + ck_tile::launch_kernel( + ck_tile::stream_config{nullptr, false, 0}, + ck_tile::make_kernel(Kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(d_x_mem.GetDeviceBuffer()), + static_cast(d_y_mem.GetDeviceBuffer()), + input_shape_tuple, + input_strides_tuple, + kept_dims, + reduce_dims)); // Get results back d_y_mem.FromDevice(h_y.data()); diff --git a/test/ck_tile/rmsnorm2d/generate.py b/test/ck_tile/rmsnorm2d/generate.py index 4296b7373e..1a1c842b3c 100644 --- a/test/ck_tile/rmsnorm2d/generate.py +++ b/test/ck_tile/rmsnorm2d/generate.py @@ -246,7 +246,7 @@ float rmsnorm2d_fwd_(const S& s, A a) std::cout << ", " << Kernel::GetName() << std::flush; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{{}}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{{}}, grids, blocks, 0, kargs)); }} """ diff --git a/test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp b/test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp index 19310beb94..8929289cdb 100644 --- a/test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp +++ b/test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp @@ -57,5 +57,5 @@ float smoothquant_(const S& s, A a) std::cout << ", " << Kernel::GetName() << std::flush; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } diff --git a/test/ck_tile/topk_softmax/test_topk_softmax_api.cpp b/test/ck_tile/topk_softmax/test_topk_softmax_api.cpp index 46c7abc697..7c90c8200c 100644 --- a/test/ck_tile/topk_softmax/test_topk_softmax_api.cpp +++ 
b/test/ck_tile/topk_softmax/test_topk_softmax_api.cpp @@ -13,11 +13,11 @@ \ auto kargs = kernel::MakeKargs(a); \ \ - const dim3 grids = kernel::GridSize(a); \ - constexpr dim3 blocks = kernel::BlockSize(); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(); \ \ - float ave_time = ck_tile::launch_kernel( \ - s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs)); \ + float ave_time = \ + ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \ \ return ave_time; diff --git a/tile_engine/ops/gemm/codegen_utils.py b/tile_engine/ops/gemm/codegen_utils.py index 4a990f3309..dd9de36865 100644 --- a/tile_engine/ops/gemm/codegen_utils.py +++ b/tile_engine/ops/gemm/codegen_utils.py @@ -54,7 +54,6 @@ CSHUFFLE_EPILOGUE = """ ck_tile::tuple<>, CLayout, ck_tile::element_wise::PassThrough, - GemmPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, WarpM, diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 6d713bdcb8..7def4e2691 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -297,7 +297,7 @@ struct GemmKernel {{ throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!"); }} - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); const dim3 grids = {'Kernel::MaxOccupancyGridSize(stream)' if persistent == 'true' else 'Kernel::GridSize(args.M, args.N, args.k_batch)'}; if(stream.log_level_ > 0) @@ -346,12 +346,12 @@ struct GemmKernel {{ ave_time = ck_tile::launch_kernel_time_mask( stream, run_flush_cache, - ck_tile::make_kernel( + ck_tile::make_kernel( Kernel{{}}, grids, blocks, 0, kargs)); }} else{{ ave_time = ck_tile::launch_kernel(stream, - ck_tile::make_kernel( + ck_tile::make_kernel( Kernel{{}}, grids, blocks, 0, kargs)); }} return ave_time; diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_codegen_utils.py b/tile_engine/ops/gemm_multi_d/gemm_multi_d_codegen_utils.py index 7d3629819d..9aca3407b1 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_codegen_utils.py +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_codegen_utils.py @@ -56,7 +56,6 @@ CSHUFFLE_EPILOGUE = """ DsLayout, ELayout, CDEElementWise, - GemmPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, WarpM, diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py b/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py index 6e65f6bf75..4b5acf1363 100755 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py @@ -443,8 +443,8 @@ struct GemmKernelMultiD {{ using Kernel = ck_tile::GemmKernelMultiD; auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) {{ @@ -460,7 +460,7 @@ struct GemmKernelMultiD {{ }} ave_time = ck_tile::launch_kernel(stream, - ck_tile::make_kernel( + ck_tile::make_kernel( Kernel{{}}, grids, blocks, 0, 
kargs)); return ave_time; From 8b55afcd9389d0c0d6ca8b6222e1b8be2417dbba Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 18 Aug 2025 11:16:25 -0700 Subject: [PATCH 02/46] Build ckProfiler package for all architectures. (#2701) * stash ckprofiler package built for all targets * build the lib for all instances in newer docker * make sure packages get posted --- Jenkinsfile | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index d1f1baf15f..b3b63098c2 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -401,7 +401,8 @@ def cmake_build(Map conf=[:]){ sh 'ninja -j64 package' archiveArtifacts artifacts: 'composablekernel-dev*.deb' sh 'mv composablekernel-dev_*.deb composablekernel-dev_all_targets_1.1.0_amd64.deb' - stash includes: "composablekernel-dev_all_targets_1.1.0_amd64.deb", name: "packages" + sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.1.0_amd64.deb' + stash includes: "composablekernel-**.deb", name: "packages" } } else{ @@ -571,19 +572,6 @@ def Build_CK(Map conf=[:]){ python3 -m pytest python/test/test_gen_instances.py """ } - dir("build"){ - if (params.RUN_FULL_QA && arch == 2 ){ - // build deb packages - echo "Build packages" - sh 'ninja package' - archiveArtifacts artifacts: 'composablekernel*.deb' - sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.1.0_amd64.deb' - sh 'mv composablekernel-dev_*.deb composablekernel-dev_1.1.0_amd64.deb' - sh 'mv composablekernel-examples_*.deb composablekernel-examples_1.1.0_amd64.deb' - sh 'mv composablekernel-tests_*.deb composablekernel-tests_1.1.0_amd64.deb' - stash includes: "composablekernel-**.deb", name: "packages" - } - } // run performance tests, stash the logs, results will be processed on the master node dir("script"){ if (params.RUN_PERFORMANCE_TESTS){ @@ -738,7 +726,7 @@ def process_results(Map conf=[:]){ echo "could not locate the FMHA performance logs: 
${err.getMessage()}." } } - if (params.RUN_FULL_QA || params.BUILD_INSTANCES_ONLY){ + if (params.BUILD_INSTANCES_ONLY){ // unstash deb packages unstash "packages" sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no composablekernel-*.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/" @@ -1440,7 +1428,7 @@ pipeline { -D CMAKE_BUILD_TYPE=Release \ -D CMAKE_CXX_FLAGS=" -O3 " .. && ninja -j64 """ - buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0") } cleanWs() } @@ -1517,7 +1505,7 @@ pipeline { stage("Process results"){ when { beforeAgent true - expression { params.RUN_PERFORMANCE_TESTS.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } + expression { (params.RUN_PERFORMANCE_TESTS.toBoolean() || params.BUILD_INSTANCES_ONLY.toBoolean()) && !params.BUILD_LEGACY_OS.toBoolean() } } agent { label 'mici' } steps{ From b4f3487d8423a0e14bbb90e1cf8305d3560f3d17 Mon Sep 17 00:00:00 2001 From: Geo Min Date: Mon, 18 Aug 2025 14:16:31 -0700 Subject: [PATCH 03/46] [TheRock CI] Adding presubmit check for CK (#2688) * Adding presubmit check for CK * Adding exclusion * Enable forks --- .github/scripts/therock_configure_ci.py | 112 ++++++++++++++++++++ .github/workflows/therock-ci-linux.yml | 8 +- .github/workflows/therock-ci.yml | 31 ++++++ .github/workflows/therock-test-packages.yml | 1 + 4 files changed, 149 insertions(+), 3 deletions(-) create mode 100644 .github/scripts/therock_configure_ci.py diff --git a/.github/scripts/therock_configure_ci.py b/.github/scripts/therock_configure_ci.py new file mode 100644 index 0000000000..557afe2d84 --- /dev/null +++ b/.github/scripts/therock_configure_ci.py @@ -0,0 +1,112 @@ +import fnmatch +import json +import os +from pathlib import Path +import 
subprocess +import sys +from typing import Iterable, Optional, Mapping + +def gha_set_output(vars: Mapping[str, str | Path]): + """Sets values in a step's output parameters. + + This appends to the file located at the $GITHUB_OUTPUT environment variable. + + See + * https://docs.github.com/en/actions/reference/workflow-commands-for-github-actions#setting-an-output-parameter + * https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/passing-information-between-jobs + """ + print(f"Setting github output:\n{vars}") + + step_output_file = os.getenv("GITHUB_OUTPUT") + if not step_output_file: + print(" Warning: GITHUB_OUTPUT env var not set, can't set github outputs") + return + + with open(step_output_file, "a") as f: + f.writelines(f"{k}={str(v)}" + "\n" for k, v in vars.items()) + +def get_modified_paths(base_ref: str) -> Optional[Iterable[str]]: + """Returns the paths of modified files relative to the base reference.""" + try: + return subprocess.run( + ["git", "diff", "--name-only", base_ref], + stdout=subprocess.PIPE, + check=True, + text=True, + timeout=60, + ).stdout.splitlines() + except TimeoutError: + print( + "Computing modified files timed out. Not using PR diff to determine" + " jobs to run.", + file=sys.stderr, + ) + return None + +# Paths matching any of these patterns are considered to have no influence over +# build or test workflows so any related jobs can be skipped if all paths +# modified by a commit/PR match a pattern in this list. 
+SKIPPABLE_PATH_PATTERNS = [ + "docs/*", + "*.gitignore", + "*.md", + "*.pre-commit-config.*", + "*LICENSE", + 'Jenkinsfile', + '.github/ISSUE_TEMPLATE/*', + '.github/CODEOWNERS', + '.github/*.md', + '.github/dependabot.yml', +] + +def is_path_skippable(path: str) -> bool: + """Determines if a given relative path to a file matches any skippable patterns.""" + return any(fnmatch.fnmatch(path, pattern) for pattern in SKIPPABLE_PATH_PATTERNS) + +def check_for_non_skippable_path(paths: Optional[Iterable[str]]) -> bool: + """Returns true if at least one path is not in the skippable set.""" + if paths is None: + return False + return any(not is_path_skippable(p) for p in paths) + +def should_ci_run_given_modified_paths(paths: Optional[Iterable[str]]) -> bool: + """Returns true if CI workflows should run given a list of modified paths.""" + + if paths is None: + print("No files were modified, skipping TheRock CI jobs") + return False + + paths_set = set(paths) + github_workflows_paths = set( + [p for p in paths if p.startswith(".github/workflows")] + ) + other_paths = paths_set - github_workflows_paths + + contains_other_non_skippable_files = check_for_non_skippable_path(other_paths) + + print("should_ci_run_given_modified_paths findings:") + print(f" contains_other_non_skippable_files: {contains_other_non_skippable_files}") + + if contains_other_non_skippable_files: + print("Enabling TheRock CI jobs since a non-skippable path was modified") + return True + else: + print( + "Only unrelated and/or skippable paths were modified, skipping TheRock CI jobs" + ) + return False + +def main(args): + base_ref = args.get("base_ref") + modified_paths = get_modified_paths(base_ref) + print("modified_paths (max 200):", modified_paths[:200]) + enable_jobs = should_ci_run_given_modified_paths(modified_paths) + output = { + 'enable_therock_ci': json.dumps(enable_jobs) + } + gha_set_output(output) + +if __name__ == "__main__": + args = {} + args["base_ref"] = os.environ.get("BASE_REF", 
"HEAD^1") + main(args) diff --git a/.github/workflows/therock-ci-linux.yml b/.github/workflows/therock-ci-linux.yml index 645a91c030..7db124d2a1 100644 --- a/.github/workflows/therock-ci-linux.yml +++ b/.github/workflows/therock-ci-linux.yml @@ -21,9 +21,11 @@ jobs: id-token: write container: image: ghcr.io/rocm/therock_build_manylinux_x86_64@sha256:044b113562629f4bd2ec5d2e64b32eee11562d48fb1a75d7493daec9dd8d8292 + options: -v /runner/config:/home/awsconfig/ env: AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }} TEATIME_FORCE_INTERACTIVE: 0 + AWS_SHARED_CREDENTIALS_FILE: /home/awsconfig/credentials.ini steps: - name: Checkout composable_kernel repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -83,9 +85,9 @@ jobs: echo "----------" du -h -d 1 TheRock/build/artifacts - - name: Configure AWS Credentials - if: always() - uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0 + - name: Configure AWS Credentials for non-forked repos + if: ${{ always() && !github.event.pull_request.head.repo.fork }} + uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4.3.1 with: aws-region: us-east-2 role-to-assume: arn:aws:iam::692859939525:role/therock-artifacts-external diff --git a/.github/workflows/therock-ci.yml b/.github/workflows/therock-ci.yml index 18411baa09..3232652b6b 100644 --- a/.github/workflows/therock-ci.yml +++ b/.github/workflows/therock-ci.yml @@ -5,6 +5,15 @@ on: branches: - develop workflow_dispatch: + pull_request: + types: + - opened + - synchronize + branches: + - mainline + - release/* + - release-staging/* + - develop permissions: contents: read @@ -18,8 +27,29 @@ concurrency: cancel-in-progress: true jobs: + setup: + runs-on: ubuntu-24.04 + env: + # The commit being checked out is the merge commit for a PR. Its first + # parent will be the tip of the base branch. 
+ BASE_REF: HEAD^ + outputs: + enable_therock_ci: ${{ steps.configure.outputs.enable_therock_ci }} + steps: + - name: "Checking out repository" + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + # We need the parent commit to do a diff + fetch-depth: 2 + + - name: "Configuring CI options" + id: configure + run: python .github/scripts/therock_configure_ci.py + therock-ci-linux: name: TheRock CI Linux + needs: setup + if: ${{ needs.setup.outputs.enable_therock_ci == 'true' }} permissions: contents: read id-token: write @@ -34,6 +64,7 @@ jobs: name: TheRock CI Summary if: always() needs: + - setup - therock-ci-linux runs-on: ubuntu-24.04 steps: diff --git a/.github/workflows/therock-test-packages.yml b/.github/workflows/therock-test-packages.yml index 439135743c..37ddd399ad 100644 --- a/.github/workflows/therock-test-packages.yml +++ b/.github/workflows/therock-test-packages.yml @@ -68,6 +68,7 @@ jobs: VENV_DIR: ${{ env.VENV_DIR }} FETCH_ARTIFACT_ARGS: ${{ matrix.components.fetch_artifact_args }} PLATFORM: ${{ inputs.platform }} + IS_PR_FROM_FORK: ${{ github.event.pull_request.head.repo.fork }} - name: Test timeout-minutes: ${{ matrix.components.timeout_minutes }} From 8f6dc23a899c1bbfd3fe89b1c6801fda5cd5c58c Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Tue, 19 Aug 2025 00:20:54 -0700 Subject: [PATCH 04/46] remove script (#2692) --- script/cmake-ck-release.sh | 34 ---------------------------------- 1 file changed, 34 deletions(-) delete mode 100755 script/cmake-ck-release.sh diff --git a/script/cmake-ck-release.sh b/script/cmake-ck-release.sh deleted file mode 100755 index 5263de92c8..0000000000 --- a/script/cmake-ck-release.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/bin/bash -set -euo pipefail -IFS=$'\n\t' - -rm -f CMakeCache.txt -rm -f *.cmake -rm -rf CMakeFiles - -MY_PROJECT_SOURCE=$1 - -if [ $# -ge 2 ] && [[ "$2" =~ ^gfx ]]; then - GPU_TARGETS=$2 - shift 2 - echo "GPU targets 
provided: $GPU_TARGETS" - REST_ARGS=$@ -else - echo "No GPU targets provided, using default targets: gfx908;gfx90a;gfx942" - GPU_TARGETS="gfx908;gfx90a;gfx942" - shift 1 - REST_ARGS=$@ -fi - -cmake \ --D CMAKE_PREFIX_PATH=/opt/rocm \ --D CMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ --D CMAKE_CXX_FLAGS="-O3" \ --D CMAKE_BUILD_TYPE=Release \ --D BUILD_DEV=OFF \ --D GPU_TARGETS=$GPU_TARGETS \ --D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ --D USE_BITINT_EXTENSION_INT4=OFF \ -$REST_ARGS \ -${MY_PROJECT_SOURCE} - From 696ef05784677173e16078a6253329284dd464ed Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Tue, 19 Aug 2025 00:22:23 -0700 Subject: [PATCH 05/46] [Dev infra] cmake_ck_dev.sh inline docs and refactor argument list (#2689) * invoke script directly * script fixup * keep the docs update separate * add newline * escape arg * use portable way of setting IFS --- script/cmake-ck-dev.sh | 47 ++++++++++++++++++++++-------------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 25a1590808..b93555901e 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -1,44 +1,47 @@ #!/bin/bash -set -euo pipefail -IFS=$'\n\t' +# exit when a command exits with non-zero status; also when an unbound variable is referenced +set -eu +# pipefail is supported by many shells, not supported by sh and dash +set -o pipefail 2>/dev/null | true +# when treating a string as a sequence, do not split on spaces +IFS=$(printf '\n\t') -rm -f CMakeCache.txt -rm -f *.cmake -rm -rf CMakeFiles +# clean the build system files +find . -name CMakeFiles -type d -exec rm -rfv {} + +find . -name CMakeCache.txt -type f -exec rm -rv {} + -MY_PROJECT_SOURCE=$1 +if [ $# -ge 1 ]; then + MY_PROJECT_SOURCE="$1" + shift 1 +else + MY_PROJECT_SOURCE=".." 
+fi +GPU_TARGETS="gfx908;gfx90a;gfx942" -if [ $# -ge 2 ]; then - case "$2" in - gfx*) - GPU_TARGETS=$2 - shift 2 +if [ $# -ge 1 ]; then + case "$1" in + gfx*) + GPU_TARGETS=$1 + shift 1 echo "GPU targets provided: $GPU_TARGETS" - REST_ARGS=$@ ;; *) - echo "No GPU targets provided, using default targets: gfx908;gfx90a;gfx942" - GPU_TARGETS="gfx908;gfx90a;gfx942" - shift 1 - REST_ARGS=$@ + echo "No GPU targets provided, using default targets: $GPU_TARGETS" ;; esac else - echo "No GPU targets provided, using default targets: gfx908;gfx90a;gfx942" - GPU_TARGETS="gfx908;gfx90a;gfx942" - shift 1 - REST_ARGS=$@ + echo "No GPU targets provided, using default targets: $GPU_TARGETS" fi cmake \ -D CMAKE_PREFIX_PATH=/opt/rocm/ \ -D CMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ --D CMAKE_CXX_FLAGS="-std=c++20 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ +-D CMAKE_CXX_FLAGS="-ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ -D CMAKE_BUILD_TYPE=Release \ -D BUILD_DEV=ON \ -D GPU_TARGETS=$GPU_TARGETS \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D USE_BITINT_EXTENSION_INT4=OFF \ -$REST_ARGS \ +$@ \ ${MY_PROJECT_SOURCE} From f38751fc2aa0f84bca7eab7ff4a588ae9cf16a24 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Tue, 19 Aug 2025 00:23:07 -0700 Subject: [PATCH 06/46] invoke script directly (#2687) --- example/ck_tile/01_fmha/README.md | 2 +- example/ck_tile/02_layernorm2d/README.md | 2 +- example/ck_tile/03_gemm/README.md | 2 +- example/ck_tile/04_img2col/README.md | 2 +- example/ck_tile/06_permute/README.md | 2 +- example/ck_tile/09_topk_softmax/README.md | 2 +- example/ck_tile/10_rmsnorm2d/README.md | 2 +- example/ck_tile/11_add_rmsnorm2d_rdquant/README.md | 2 +- example/ck_tile/12_smoothquant/README.md | 2 +- example/ck_tile/13_moe_sorting/README.md | 2 +- example/ck_tile/14_moe_smoothquant/README.md | 2 +- example/ck_tile/16_batched_gemm/README.md | 2 +- example/ck_tile/17_grouped_gemm/README.md | 2 +- 
example/ck_tile/18_flatmm/README.md | 2 +- example/ck_tile/19_gemm_multi_d/README.md | 2 +- example/ck_tile/35_batched_transpose/README.md | 2 +- example/ck_tile/38_block_scale_gemm/README.md | 2 +- example/ck_tile/39_copy/README.md | 2 +- test/ck_tile/memory_copy/README.md | 2 +- tile_engine/ops/gemm/README.md | 4 ++-- tile_engine/ops/gemm_multi_d/README.md | 4 ++-- 21 files changed, 23 insertions(+), 23 deletions(-) diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index 72109a660b..f72d7afa02 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -7,7 +7,7 @@ This folder contains example for fmha(fused multi-head attention) using ck_tile # in the root of ck_tile mkdir build && cd build # you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank -sh ../script/cmake-ck-dev.sh ../ +../script/cmake-ck-dev.sh ../ make tile_example_fmha_fwd -j ``` This will result in an executable `build/bin/tile_example_fmha_fwd` diff --git a/example/ck_tile/02_layernorm2d/README.md b/example/ck_tile/02_layernorm2d/README.md index 817f62dae7..da74e2e3c1 100644 --- a/example/ck_tile/02_layernorm2d/README.md +++ b/example/ck_tile/02_layernorm2d/README.md @@ -42,7 +42,7 @@ return hidden_states, per_token_scale ``` # in the root of ck_tile mkdir build && cd build -sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... 
make tile_example_layernorm2d_fwd -j ``` This will result in an executable `build/bin/tile_example_layernorm2d_fwd` diff --git a/example/ck_tile/03_gemm/README.md b/example/ck_tile/03_gemm/README.md index c9e392dbd5..6358b76fd9 100644 --- a/example/ck_tile/03_gemm/README.md +++ b/example/ck_tile/03_gemm/README.md @@ -7,7 +7,7 @@ This folder contains example for GEMM using ck_tile tile-programming implementat # in the root of ck_tile mkdir build && cd build # you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank -sh ../script/cmake-ck-dev.sh ../ +../script/cmake-ck-dev.sh ../ # The basic pipeline method on the gemm calculation make tile_example_gemm_basic -j # The memory bound pipeline on the gemm calculation diff --git a/example/ck_tile/04_img2col/README.md b/example/ck_tile/04_img2col/README.md index df5c51a9c0..3b1b6f999b 100644 --- a/example/ck_tile/04_img2col/README.md +++ b/example/ck_tile/04_img2col/README.md @@ -7,7 +7,7 @@ This folder contains example for Image to Column using ck_tile tile-programming # in the root of ck_tile mkdir build && cd build # you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank -sh ../script/cmake-ck-dev.sh ../ +../script/cmake-ck-dev.sh ../ make tile_example_img2col -j ``` This will result in an executable `build/bin/tile_example_img2col` diff --git a/example/ck_tile/06_permute/README.md b/example/ck_tile/06_permute/README.md index 03bd810ff4..5e88e71572 100644 --- a/example/ck_tile/06_permute/README.md +++ b/example/ck_tile/06_permute/README.md @@ -15,7 +15,7 @@ args: ``` # in the root of ck_tile mkdir build && cd build -sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... 
make tile_example_permute -j ``` This will result in an executable `build/bin/tile_example_permute` diff --git a/example/ck_tile/09_topk_softmax/README.md b/example/ck_tile/09_topk_softmax/README.md index 1043012900..2e15aeaae5 100644 --- a/example/ck_tile/09_topk_softmax/README.md +++ b/example/ck_tile/09_topk_softmax/README.md @@ -6,7 +6,7 @@ This folder contains example for topk-softmax kernel using ck_tile tile-programm ``` # in the root of ck_tile mkdir build && cd build -sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... make tile_example_topk_softmax -j ``` This will result in an executable `build/bin/tile_example_topk_softmax` diff --git a/example/ck_tile/10_rmsnorm2d/README.md b/example/ck_tile/10_rmsnorm2d/README.md index c067496477..1d27ad153e 100644 --- a/example/ck_tile/10_rmsnorm2d/README.md +++ b/example/ck_tile/10_rmsnorm2d/README.md @@ -6,7 +6,7 @@ This folder contains example for Rmsnorm2D forward using ck_tile tile-programmin ``` # in the root of ck_tile mkdir build && cd build -sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... make tile_rmsnorm2d_fwd -j ``` This will result in an executable `build/bin/tile_rmsnorm2d_fwd` diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/README.md b/example/ck_tile/11_add_rmsnorm2d_rdquant/README.md index 960369b78d..f9ba76c9e3 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/README.md +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/README.md @@ -6,7 +6,7 @@ This folder contains example for add + Rmsnorm2D + rowwise dynamic quantization ``` # in the root of ck_tile mkdir build && cd build -sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... 
make tile_add_rmsnorm2d_rdquant_fwd -j ``` This will result in an executable `build/bin/tile_add_rmsnorm2d_rdquant_fwd` diff --git a/example/ck_tile/12_smoothquant/README.md b/example/ck_tile/12_smoothquant/README.md index d6b815f8cf..6b3acd558b 100644 --- a/example/ck_tile/12_smoothquant/README.md +++ b/example/ck_tile/12_smoothquant/README.md @@ -6,7 +6,7 @@ This folder contains example for smoothquant using ck_tile tile-programming impl ``` # in the root of ck_tile mkdir build && cd build -sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... make tile_smoothquant -j ``` This will result in an executable `build/bin/tile_smoothquant` diff --git a/example/ck_tile/13_moe_sorting/README.md b/example/ck_tile/13_moe_sorting/README.md index 1822ff3a37..c99f40aa57 100644 --- a/example/ck_tile/13_moe_sorting/README.md +++ b/example/ck_tile/13_moe_sorting/README.md @@ -6,7 +6,7 @@ This folder contains example for moe-sorting kernel using ck_tile tile-programmi ``` # in the root of ck_tile mkdir build && cd build -sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... make tile_example_moe_sorting -j ``` This will result in an executable `build/bin/tile_example_moe_sorting` diff --git a/example/ck_tile/14_moe_smoothquant/README.md b/example/ck_tile/14_moe_smoothquant/README.md index 599b4c3489..c10a922607 100644 --- a/example/ck_tile/14_moe_smoothquant/README.md +++ b/example/ck_tile/14_moe_smoothquant/README.md @@ -9,7 +9,7 @@ Unlike standard smoothquant op, the input scale is from different expert `[exper ``` # in the root of ck_tile mkdir build && cd build -sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... 
make tile_example_moe_smoothquant -j ``` This will result in an executable `build/bin/tile_example_moe_smoothquant` diff --git a/example/ck_tile/16_batched_gemm/README.md b/example/ck_tile/16_batched_gemm/README.md index 34b56db526..8a64a3912c 100644 --- a/example/ck_tile/16_batched_gemm/README.md +++ b/example/ck_tile/16_batched_gemm/README.md @@ -7,7 +7,7 @@ This folder contains example for batched GEMM using ck_tile tile-programming imp # in the root of ck_tile mkdir build && cd build # you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank -sh ../script/cmake-ck-dev.sh ../ +../script/cmake-ck-dev.sh ../ make tile_example_batched_gemm -j ``` This will result in an executable `build/bin/tile_example_batched_gemm` diff --git a/example/ck_tile/17_grouped_gemm/README.md b/example/ck_tile/17_grouped_gemm/README.md index 29642e96c1..8715ee79e1 100644 --- a/example/ck_tile/17_grouped_gemm/README.md +++ b/example/ck_tile/17_grouped_gemm/README.md @@ -148,7 +148,7 @@ All the necessary parameters are set, the tiling is computed, the GEMM pipeline # in the root of ck_tile mkdir build && cd build # you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank -sh ../script/cmake-ck-dev.sh ../ +../script/cmake-ck-dev.sh ../ # The basic pipeline method on the gemm calculation make tile_example_grouped_gemm -j ``` diff --git a/example/ck_tile/18_flatmm/README.md b/example/ck_tile/18_flatmm/README.md index beaac785fc..eeaa7658bd 100644 --- a/example/ck_tile/18_flatmm/README.md +++ b/example/ck_tile/18_flatmm/README.md @@ -7,7 +7,7 @@ This folder contains example for FLATMM using ck_tile tile-programming implement # in the root of ck_tile mkdir build && cd build # you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank -sh ../script/cmake-ck-dev.sh ../ +../script/cmake-ck-dev.sh ../ # The basic pipeline method on the flatmm calculation make 
tile_example_flatmm_basic -j ``` diff --git a/example/ck_tile/19_gemm_multi_d/README.md b/example/ck_tile/19_gemm_multi_d/README.md index 7e8cd87546..2cf2b1ea03 100644 --- a/example/ck_tile/19_gemm_multi_d/README.md +++ b/example/ck_tile/19_gemm_multi_d/README.md @@ -8,7 +8,7 @@ This folder contains example for Multiple D GEMM using ck_tile tile-programming mkdir build && cd build #you can replace < arch> with the appropriate architecture(for example gfx90a or gfx942) or \ leave it blank -sh ../script/cmake-ck-dev.sh ../ +../script/cmake-ck-dev.sh ../ #The basic pipeline method on the gemm calculation make tile_example_gemm_multi_d_fp16 -j ``` diff --git a/example/ck_tile/35_batched_transpose/README.md b/example/ck_tile/35_batched_transpose/README.md index 38bb2b32e4..56e9610b35 100644 --- a/example/ck_tile/35_batched_transpose/README.md +++ b/example/ck_tile/35_batched_transpose/README.md @@ -6,7 +6,7 @@ This folder contains example for batched Transpose using ck_tile tile-programmin # in the root of ck_tile mkdir build && cd build # you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank -sh ../script/cmake-ck-dev.sh ../ +../script/cmake-ck-dev.sh ../ # Make the transpose executable make tile_example_batched_transpose -j ``` diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md index 742a88dee7..fc905790f1 100644 --- a/example/ck_tile/38_block_scale_gemm/README.md +++ b/example/ck_tile/38_block_scale_gemm/README.md @@ -7,7 +7,7 @@ This folder contains example for Block Scale GEMM using ck_tile tile-programming # in the root of ck_tile mkdir build && cd build # you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank -sh ../script/cmake-ck-dev.sh ../ +../script/cmake-ck-dev.sh ../ # The aquant pipeline method on the gemm calculation make tile_example_gemm_aquant_basic -j ``` diff --git a/example/ck_tile/39_copy/README.md 
b/example/ck_tile/39_copy/README.md index fa98cc1de6..b5bc5d56be 100644 --- a/example/ck_tile/39_copy/README.md +++ b/example/ck_tile/39_copy/README.md @@ -12,7 +12,7 @@ This experimental kernel is intended for novice CK developers. It introduces the mkdir build && cd build # you can replace with the appropriate architecture # (for example gfx90a or gfx942) or leave it blank -sh ../script/cmake-ck-dev.sh ../ +../script/cmake-ck-dev.sh ../ # Make the copy kernel executable make tile_example_copy -j ``` diff --git a/test/ck_tile/memory_copy/README.md b/test/ck_tile/memory_copy/README.md index 7856f0b4bd..9c56052b64 100644 --- a/test/ck_tile/memory_copy/README.md +++ b/test/ck_tile/memory_copy/README.md @@ -12,7 +12,7 @@ is moved to output DRAM window for a simple copy operation. mkdir build && cd build # you can replace with the appropriate architecture # (for example gfx90a or gfx942) or leave it blank -sh ../script/cmake-ck-dev.sh ../ +../script/cmake-ck-dev.sh ../ # Make the copy kernel executable make test_copy -j ``` diff --git a/tile_engine/ops/gemm/README.md b/tile_engine/ops/gemm/README.md index a16b74d297..79152a1a0d 100644 --- a/tile_engine/ops/gemm/README.md +++ b/tile_engine/ops/gemm/README.md @@ -20,7 +20,7 @@ mkdir build && cd build # replace [Arch] with the appropriate architecture or leave blank and # replace [Datatype1;Datatype2;...] in comma separated datatypes string (possible datatypes are [fp8, bf8, int8, fp16, bf16]) # replace [Layout1;Layout2;...] 
in comma separated datatypes string (possible layouts are [rcr, rrr, crr, ccr]) -sh ../script/cmake-ck-dev.sh ../ [Arch] -DGEMM_DATATYPE="[Datatype1;Datatype2]" -DGEMM_LAYOUT="[Layout1;Layout2]" +../script/cmake-ck-dev.sh ../ [Arch] -DGEMM_DATATYPE="[Datatype1;Datatype2]" -DGEMM_LAYOUT="[Layout1;Layout2]" # generate different executable for each passed datatype make benchmark_gemm_[Datatype1]_[Layout1] -j make benchmark_gemm_[Datatype1]_[Layout2] -j @@ -38,7 +38,7 @@ rm -rf tile_engine/ && make benchmark_gemm_[Datatypes]_[Layout] -j # rebuild ## For eaxmple build for gfx942 for fp8 and fp16 datatypes with rcr layout ``` bash mkdir build && cd build -sh ../script/cmake-ck-dev.sh ../ gfx942 -DGEMM_DATATYPE="fp8;fp16" -DGEMM_LAYOUT="rcr" +../script/cmake-ck-dev.sh ../ gfx942 -DGEMM_DATATYPE="fp8;fp16" -DGEMM_LAYOUT="rcr" make benchmark_gemm_fp8_rcr -j make benchmark_gemm_fp16_rcr -j ``` diff --git a/tile_engine/ops/gemm_multi_d/README.md b/tile_engine/ops/gemm_multi_d/README.md index 369553b121..66f0ed80af 100644 --- a/tile_engine/ops/gemm_multi_d/README.md +++ b/tile_engine/ops/gemm_multi_d/README.md @@ -21,7 +21,7 @@ mkdir build && cd build # replace [Datatype] in comma separated datatypes string (possible datatypes are [fp16]) # replace [Layout1;Layout2;...] in comma separated datatypes string (possible layouts are [rcr, rrr, crr, ccr]) # replace "mul" with either of mul,add,passthrough for Elementwise function as Multiply, Add or Passthrough respectively. If this is not specified it is considered as mul by default. 
-sh ../script/cmake-ck-dev.sh ../ [Arch] -DGEMM_MULTI_D_DATATYPE="[Datatype]" -DGEMM_MULTI_D_LAYOUT="[Layout1;Layout2]" -DGEMM_MULTI_D_ELEMENTWISE_FUNCTION="mul" +../script/cmake-ck-dev.sh ../ [Arch] -DGEMM_MULTI_D_DATATYPE="[Datatype]" -DGEMM_MULTI_D_LAYOUT="[Layout1;Layout2]" -DGEMM_MULTI_D_ELEMENTWISE_FUNCTION="mul" # generate different executable for each passed datatype make benchmark_gemm_multi_d_[Datatype]_[Layout1] -j make benchmark_gemm_multi_d_[Datatype]_[Layout2] -j @@ -37,7 +37,7 @@ rm -rf tile_engine/ && make benchmark_gemm_multi_d_[Datatype]_[Layout] -j # reb ## For eaxmple build for gfx942 for datatype with rcr layout ``` bash mkdir build && cd build -sh ../script/cmake-ck-dev.sh ../ gfx942 -DGEMM_MULTI_D_DATATYPE="fp16" -DGEMM_MULTI_D_LAYOUT="rcrr" +../script/cmake-ck-dev.sh ../ gfx942 -DGEMM_MULTI_D_DATATYPE="fp16" -DGEMM_MULTI_D_LAYOUT="rcrr" make benchmark_gemm_multi_d_fp16_rcrr -j ## benchmark_gemm inputs From 60320e90c11b80411cb2b70c9c5a5976a56abad5 Mon Sep 17 00:00:00 2001 From: mirchen-amd Date: Tue, 19 Aug 2025 04:19:17 -0400 Subject: [PATCH 07/46] Mirchen/gemm blockscale wp segfault fix (#2638) * Add stride validation to prevent segfault in blockscale GEMM * run clang-format * Update profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp Co-authored-by: rahjain-amd * added stride length checking to more gemm examples in ckprofiler * ran clang format * added validation header and implement in core gemm operations * remove ck_tile transpose and gemm stages from CI (#2646) * update CK build instruction step 4 (#2563) Co-authored-by: Aviral Goel * Fixes to "General 2D Reduction Kernel" (#2535) (#2656) * fix reduce2d - revret the combine_partial_results() chnages - remove auto from function def * clang-format * enable aiter test_mha in daily CI (#2659) * feat(copy_kernel): add basic copy kernel example with beginner friendly documentation (#2582) * feat(copy_kernel): add basic copy kernel example with documentation * docs(CHANGELOG): 
Updated changelog * chore: performed clang format * Update example/ck_tile/39_copy/copy_basic.cpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update example/ck_tile/39_copy/README.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update example/ck_tile/39_copy/README.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update example/ck_tile/39_copy/README.md Co-authored-by: spolifroni-amd * Update example/ck_tile/39_copy/README.md Co-authored-by: spolifroni-amd * Update example/ck_tile/39_copy/README.md Co-authored-by: spolifroni-amd * fix(terminology): follow amd terms * extract elementwise copy to a new kernel * fix(copy_kernel): bug in verification * add comments about vgpr usage * lint and nits * add notes and comments * print hostTensor via stream * print hostTensor via stream --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: spolifroni-amd * [CK_TILE] FMHA BWD Optimization For GFX950 (#2628) * simplify fmha_bwd_kernel MakeKargs & dq_dram_window * simply duplicate * trload pipeline * Try two-stage * add prefetch * optimize & iglp * Fix num_byte calculations to use nhead_k for K & V size (#2653) Simple fix just to calculate the number of bytes correctly for what's reported in the output. I was getting 6200 GB/s which is past the SoL of MI300. 
Before: ``` ./bin/tile_example_fmha_fwd -prec=bf16 -b=2 -s=1 -s_k=32768 -h=32 -h_k=8 -d=128 -page_block_size=128 -num_splits=8 -iperm=0 -operm=0 -v=0 -kname=1 [bf16|batch|bshd] b:2, h:32/8, s:1/32768, d:128/128, scale_s:0.0883883, bias:n, p_drop:0, lse:0, squant:0, mask:n, v:r, num_splits:8, page_block_size:128, fmha_fwd_splitkv_d128_bf16_batch_b16x64x64x128x64x128_r1x4x1_r1x4x1_w16x16x16_w16x16x16_qr_nwarp_sshuffle_vr_ps_nlogits_nbias_nmask_lse_nsquant_pagedkv, fmha_fwd_splitkv_combine_d128_bf16_batch_b32_unused_ps_nlse_nsquant, 0.173 ms, 6.20 TFlops, 6202.95 GB/s ``` After: ``` ./bin/tile_example_fmha_fwd -prec=bf16 -b=2 -s=1 -s_k=32768 -h=32 -h_k=8 -d=128 -page_block_size=128 -num_splits=8 -iperm=0 -operm=0 -v=0 -kname=1 [bf16|batch|bshd] b:2, h:32/8, s:1/32768, d:128/128, scale_s:0.0883883, bias:n, p_drop:0, lse:0, squant:0, mask:n, v:r, num_splits:8, page_block_size:128, fmha_fwd_splitkv_d128_bf16_batch_b16x64x64x128x64x128_r1x4x1_r1x4x1_w16x16x16_w16x16x16_qr_nwarp_sshuffle_vr_ps_nlogits_nbias_nmask_lse_nsquant_pagedkv, fmha_fwd_splitkv_combine_d128_bf16_batch_b32_unused_ps_nlse_nsquant, 0.163 ms, 6.58 TFlops, 1644.53 GB/s ``` * [CK_TILE] FMHA BWD Decode Pipeline (#2643) * Fix distr * Duplicate block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr * decode 16x16 o2 * fix (#2668) * Optimize fmha fwd decode & prefill for gfx950 (#2641) * Fix for fwd/bwd kernel build filter * fix bwd code * save an example for __bf16 type * temp save, waiting for debug * tempsave, fmha_decode * temp save, change all instance to 1wave * fix async copytest bug * Add block_sync_lds_direct_load utility * fix the s_waitcnt_imm calculation * Improve s_waitcnt_imm calculation * fix vmcnt shift * add input validation and bug fix * remove unnecessary output * move test_copy into test * temp save * tempsave * compile pass * tempsave, trload+asyncload done * tempsave. 
asynccopy+trload sanity checked * remove unnecessary features * fix the lds alignment caused performance regression * enable prefill overload operator(). * remove all lds bankconflict with xor layouts * enable larger tile size; upgrade xor pattern * upgrade prefill pipeline; simple iglp; consistent data produce and consume order * small refactor * Load Q through lds, implement xor; * add vmcnt guard before load ktile * Add v_permlaneb32 for block_reduce. Disable it as it will cause un-coexecutable packed math in FA * Add XOR fold strategy for hdim<128, but perf dropped; disable it by default; wait further perf debug * add __restrict__ to tr load * merge fa_decode pipeline into fmha_fwd api * remove unnecessary files; rename some files * Remove unnecessary changes * bug fix, clang format; * remove non-necessary change * fix clangformat with 18.1.3 * fix bugs * fix bug * fix bug on non-gfx950 * fix bugs in gemm * fix bug in pki4 * tempsave, update the blocksync functions * change the warp setting for hdim32 fmha fwd * clang format * fix conflict. disable all v-col instance for fmha fwd * Fix the bug * clang format --------- Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> * Revert "Optimize fmha fwd decode & prefill for gfx950 (#2641)" (#2670) This reverts commit b7322a521a91fe4762701237f0243dd2c94b7644. 
* added batch stride checking to batched gemm ops in profiler * removed batch stride validation * removed batched stride validation again * Update include/ck/library/utility/profiler_validation_common.hpp Co-authored-by: rahjain-amd * refactor function names * added gemm stride checking to more profiler gemm operations * run clang format * add stride checkign to 01 gemm example * rename from profiler to validation common, used for examples and profiler * build of ckProfiler success * update file headers --------- Co-authored-by: rahjain-amd Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: geozhai <44495440+geozhai@users.noreply.github.com> Co-authored-by: Aviral Goel Co-authored-by: Yashvardhan Agarwal Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: spolifroni-amd Co-authored-by: Yi DING Co-authored-by: Cameron Shinn Co-authored-by: Mateusz Ozga <110818320+mozga-amd@users.noreply.github.com> Co-authored-by: Haocong WANG Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Co-authored-by: asleepzzz --- example/01_gemm/run_gemm_example.inc | 14 +++++- example/01_gemm/run_gemm_example_v2.inc | 2 +- .../ck/library/utility/validation_common.hpp | 50 +++++++++++++++++++ .../profiler/profile_gemm_ab_scale_impl.hpp | 7 ++- .../profile_gemm_bias_add_reduce_impl.hpp | 6 ++- .../profile_gemm_blockscale_wp_impl.hpp | 5 ++ .../include/profiler/profile_gemm_impl.hpp | 6 ++- .../profiler/profile_gemm_reduce_impl.hpp | 6 ++- .../profiler/profile_gemm_splitk_impl.hpp | 6 ++- .../profiler/profile_gemm_streamk_impl.hpp | 6 ++- .../profiler/profile_gemm_universal_impl.hpp | 4 ++ ...profile_gemm_universal_preshuffle_impl.hpp | 4 ++ .../profile_gemm_universal_reduce_impl.hpp | 6 ++- .../profile_gemm_universal_streamk_impl.hpp | 6 ++- 14 files changed, 118 insertions(+), 10 deletions(-) create mode 100644 include/ck/library/utility/validation_common.hpp mode change 100755 => 100644 
profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 6c5d9f9fba..3e018aad1e 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -1,7 +1,8 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include "ck/library/utility/validation_common.hpp" template bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) @@ -53,6 +54,17 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + try + { + ck::utils::validate_gemm_strides_abc( + M, N, K, StrideA, StrideB, StrideC); + } + catch(const std::runtime_error& e) + { + std::cerr << "Error: " << e.what() << std::endl; + return false; + } + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); diff --git a/example/01_gemm/run_gemm_example_v2.inc b/example/01_gemm/run_gemm_example_v2.inc index 4adb6f896b..3d8cf32221 100644 --- a/example/01_gemm/run_gemm_example_v2.inc +++ b/example/01_gemm/run_gemm_example_v2.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/library/utility/validation_common.hpp b/include/ck/library/utility/validation_common.hpp new file mode 100644 index 0000000000..38933c6d7c --- /dev/null +++ b/include/ck/library/utility/validation_common.hpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include "ck/ck.hpp" +#include "ck/utility/type.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +namespace ck { +namespace utils { + +template +inline void +validate_gemm_stride(int M, int N, int stride, const std::string& stride_name = "Stride") +{ + if(ck::is_same_v) + { + if(stride < M) + { + throw std::runtime_error( + "Error: For ColumnMajor layout, " + stride_name + " (" + std::to_string(stride) + + ") must be greater than or equal to dim (" + std::to_string(M) + ")"); + } + } + else // RowMajor + { + if(stride < N) + { + throw std::runtime_error( + "Error: For RowMajor layout, " + stride_name + " (" + std::to_string(stride) + + ") must be greater than or equal to dim (" + std::to_string(N) + ")"); + } + } +} + +// Convenience functions for common GEMM patterns +template +inline void validate_gemm_strides_abc(int M, int N, int K, int StrideA, int StrideB, int StrideC) +{ + validate_gemm_stride(M, K, StrideA, "StrideA"); + validate_gemm_stride(K, N, StrideB, "StrideB"); + validate_gemm_stride(M, N, StrideC, "StrideC"); +} + +} // namespace utils +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_ab_scale_impl.hpp b/profiler/include/profiler/profile_gemm_ab_scale_impl.hpp index a84ad5269b..d68a1065ab 100644 --- a/profiler/include/profiler/profile_gemm_ab_scale_impl.hpp +++ b/profiler/include/profiler/profile_gemm_ab_scale_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -19,6 +19,7 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" +#include "ck/library/utility/validation_common.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" namespace ck { @@ -74,6 +75,10 @@ bool profile_gemm_ab_scale_impl(int do_verification, ? ((K + ScaleBlockK - 1) / ScaleBlockK) : ((N + ScaleBlockN - 1) / ScaleBlockN); + ck::utils::validate_gemm_stride(M, K, StrideA, "StrideA"); + ck::utils::validate_gemm_stride(K, N, StrideB, "StrideB"); + ck::utils::validate_gemm_stride(M, N, StrideE, "StrideE"); + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor a1_m_k(f_host_tensor_descriptor((M + ScaleBlockM - 1) / ScaleBlockM, (K + ScaleBlockK - 1) / ScaleBlockK, diff --git a/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp b/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp index c0ffea8a32..405a2359c2 100644 --- a/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp +++ b/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -15,6 +15,7 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" +#include "ck/library/utility/validation_common.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" namespace ck { @@ -93,6 +94,9 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, } }; + ck::utils::validate_gemm_strides_abc( + M, N, K, StrideA, StrideB, StrideC); + Tensor a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); diff --git a/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp b/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp index 53073a6c75..32bdf05771 100644 --- a/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp +++ b/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp @@ -20,6 +20,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/validation_common.hpp" namespace ck { namespace profiler { @@ -104,6 +105,10 @@ bool profile_gemm_blockscale_weighpreshuffle_impl(int do_verification, ? 
((K + ScaleBlockK - 1) / ScaleBlockK) : ((N + ScaleBlockN - 1) / ScaleBlockN); + ck::utils::validate_gemm_stride(M, K, StrideA, "StrideA"); + ck::utils::validate_gemm_stride(K, N, StrideB, "StrideB"); + ck::utils::validate_gemm_stride(M, N, StrideE, "StrideE"); + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor a1_m_k(f_host_tensor_descriptor((M + ScaleBlockM - 1) / ScaleBlockM, (K + ScaleBlockK - 1) / ScaleBlockK, diff --git a/profiler/include/profiler/profile_gemm_impl.hpp b/profiler/include/profiler/profile_gemm_impl.hpp index d2a38b2a81..fdcb3ad128 100644 --- a/profiler/include/profiler/profile_gemm_impl.hpp +++ b/profiler/include/profiler/profile_gemm_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -24,6 +24,7 @@ #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/library/utility/fill.hpp" +#include "ck/library/utility/validation_common.hpp" namespace ck { namespace profiler { @@ -64,6 +65,9 @@ int profile_gemm_impl(int do_verification, } }; + ck::utils::validate_gemm_strides_abc( + M, N, K, StrideA, StrideB, StrideC); + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); diff --git a/profiler/include/profiler/profile_gemm_reduce_impl.hpp b/profiler/include/profiler/profile_gemm_reduce_impl.hpp index ff801e8afd..a74d2a01d9 100644 --- a/profiler/include/profiler/profile_gemm_reduce_impl.hpp +++ b/profiler/include/profiler/profile_gemm_reduce_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -15,6 +15,7 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" +#include "ck/library/utility/validation_common.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" namespace ck { @@ -88,6 +89,9 @@ bool profile_gemm_reduce_impl(int do_verification, } }; + ck::utils::validate_gemm_strides_abc( + M, N, K, StrideA, StrideB, StrideC); + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); diff --git a/profiler/include/profiler/profile_gemm_splitk_impl.hpp b/profiler/include/profiler/profile_gemm_splitk_impl.hpp index 5d5ae1ad15..0640e95aba 100644 --- a/profiler/include/profiler/profile_gemm_splitk_impl.hpp +++ b/profiler/include/profiler/profile_gemm_splitk_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -19,6 +19,7 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" +#include "ck/library/utility/validation_common.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" namespace ck { @@ -62,6 +63,9 @@ bool profile_gemm_splitk_impl(int do_verification, } }; + ck::utils::validate_gemm_strides_abc( + M, N, K, StrideA, StrideB, StrideC); + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); diff --git a/profiler/include/profiler/profile_gemm_streamk_impl.hpp b/profiler/include/profiler/profile_gemm_streamk_impl.hpp index 71b54c1f47..d24ee1c7ea 100644 --- a/profiler/include/profiler/profile_gemm_streamk_impl.hpp +++ b/profiler/include/profiler/profile_gemm_streamk_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -19,6 +19,7 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" +#include "ck/library/utility/validation_common.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" namespace ck { @@ -59,6 +60,9 @@ bool profile_gemm_streamk_impl(int do_verification, } }; + ck::utils::validate_gemm_strides_abc( + M, N, K, StrideA, StrideB, StrideC); + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); diff --git a/profiler/include/profiler/profile_gemm_universal_impl.hpp b/profiler/include/profiler/profile_gemm_universal_impl.hpp index ed62828158..feb75c9660 100644 --- a/profiler/include/profiler/profile_gemm_universal_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_impl.hpp @@ -19,6 +19,7 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" +#include "ck/library/utility/validation_common.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" namespace ck { @@ -63,6 +64,9 @@ bool profile_gemm_universal_impl(int do_verification, } }; + ck::utils::validate_gemm_strides_abc( + M, N, K, StrideA, StrideB, StrideC); + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); diff --git a/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp b/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp index e218143857..271bc6ef59 100644 --- a/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp @@ -19,6 +19,7 
@@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" +#include "ck/library/utility/validation_common.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" namespace ck { @@ -91,6 +92,9 @@ bool profile_gemm_universal_preshuffle_impl(int do_verification, } }; + ck::utils::validate_gemm_strides_abc( + M, N, K, StrideA, StrideB, StrideC); + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); diff --git a/profiler/include/profiler/profile_gemm_universal_reduce_impl.hpp b/profiler/include/profiler/profile_gemm_universal_reduce_impl.hpp index d600de0978..a0ee6a6674 100644 --- a/profiler/include/profiler/profile_gemm_universal_reduce_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_reduce_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -19,6 +19,7 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" +#include "ck/library/utility/validation_common.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" namespace ck { @@ -64,6 +65,9 @@ bool profile_gemm_universal_reduce_impl(int do_verification, } }; + ck::utils::validate_gemm_strides_abc( + M, N, K, StrideA, StrideB, StrideC); + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); diff --git a/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp b/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp old mode 100755 new mode 100644 index 640b192baf..5c859b830d --- a/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -21,6 +21,7 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" +#include "ck/library/utility/validation_common.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/library/reference_tensor_operation/gpu/reference_gemm.hpp" @@ -67,6 +68,9 @@ bool profile_gemm_universal_streamk_impl(int do_verification, } }; + ck::utils::validate_gemm_strides_abc( + M, N, K, StrideA, StrideB, StrideC); + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); From a1589a9667517ddc73048c05c6f3c859db99851d Mon Sep 17 00:00:00 2001 From: joyeamd Date: Tue, 19 Aug 2025 16:20:43 +0800 Subject: [PATCH 08/46] fix grouped gemm example when wave32 enabled (#2707) 1, delete some unused variables 2, fix BlockSize when wave32 enabled --- example/ck_tile/17_grouped_gemm/grouped_gemm.cpp | 7 ------- .../ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp | 12 +++++++++++- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 1e6844261f..527ef1e466 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -29,10 +29,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, void* kargs_ptr, bool splitk) { - constexpr bool kPadM = false; - constexpr bool kPadN = false; - constexpr bool kPadK = false; - constexpr ck_tile::index_t TileParitionerGroupNum = 8; constexpr ck_tile::index_t TileParitionerM01 = 4; @@ -44,7 +40,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = 
ck_tile::PersistentTileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; float ave_time{0}; diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index c35435ee5e..eac7f547c1 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -155,7 +155,17 @@ struct GroupedGemmKernel return group_count * sizeof(GemmTransKernelArg); } - CK_TILE_HOST static constexpr auto BlockSize() -> dim3 { return dim3(kBlockSize); } + CK_TILE_HOST static auto BlockSize() -> dim3 + { + if(is_wave32()) + { + return dim3(kBlockSize / 2); + } + else + { + return dim3(kBlockSize); + } + } /** * @brief Get the maximum occupancy grid size for the persistent kernel on the current device. From 6ba9289b26b5df0960e0d314f2ade988f88ea35e Mon Sep 17 00:00:00 2001 From: jefyang1 <146495389+jefyang1@users.noreply.github.com> Date: Tue, 19 Aug 2025 09:58:28 -0700 Subject: [PATCH 09/46] Fix pk i4 v3 example test regression on gfx942 (#2706) Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- .../gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index dc8e98218e..57adcd4f6d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -44,10 +44,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) is_gfx950_build = false, #endif }; - // skip building the instances with K1>=32 on pre-gfx950 - if constexpr(((GridwiseGemm::AK1Number >= 32 || GridwiseGemm::BK1Number >= 32) && - static_cast(Arch::is_gfx950_build)) || - (GridwiseGemm::AK1Number < 32 && GridwiseGemm::BK1Number < 
32)) + // skip building the instances with K1>=32 && PackedSize != 2 on pre-gfx950 + if constexpr(static_cast(Arch::is_gfx950_build) || + (GridwiseGemm::AK1Number < 32 && GridwiseGemm::BK1Number < 32) || + (GridwiseGemm::AK1Number >= 32 && GridwiseGemm::APackedSize == 2) || + (GridwiseGemm::BK1Number >= 32 && GridwiseGemm::BPackedSize == 2)) { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -86,10 +87,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) is_gfx950_build = false, #endif }; - // skip building the instances with K1>=32 on pre-gfx950 - if constexpr(((GridwiseGemm::AK1Number >= 32 || GridwiseGemm::BK1Number >= 32) && - static_cast(Arch::is_gfx950_build)) || - (GridwiseGemm::AK1Number < 32 && GridwiseGemm::BK1Number < 32)) + // skip building the instances with K1>=32 && PackedSize != 2 on pre-gfx950 + if constexpr(static_cast(Arch::is_gfx950_build) || + (GridwiseGemm::AK1Number < 32 && GridwiseGemm::BK1Number < 32) || + (GridwiseGemm::AK1Number >= 32 && GridwiseGemm::APackedSize == 2) || + (GridwiseGemm::BK1Number >= 32 && GridwiseGemm::BPackedSize == 2)) { // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy From 071165919f1237bf187e2653437bf51d6cf87a6e Mon Sep 17 00:00:00 2001 From: Emily Martins <65371150+ecamartins@users.noreply.github.com> Date: Tue, 19 Aug 2025 15:08:52 -0600 Subject: [PATCH 10/46] [CK Tile] Stream K GEMM Kernel HostArgs and Kernel Classes (#2681) * CK Tile Stream K Device Ops Implementation of CK Tile StreamKHostArgs and StreamKKernel classes. The StreamKKernel class injects Universal Gemm and includes functions to facilitate kernel preparation for the GPU. * Stream K Device Ops Fixes - Update GetWorkSpaceSize to call TilePartitioner's GetWorkSpaceSize to ensure we get size needed for accumulation buffers and semaphores. 
- Pass in num_sk_blocks into TilePartitioner constructor - Update documentation * Add WarpTile dimensions to GetName function in StreamKKernel class * Fix typos in StreamKHostArgs class description. Co-authored-by: Christopher Millette <63608002+cgmillette@users.noreply.github.com> * Apply clang format on updated comment for StreamKHostArgs * Explicitly specify type for StreamKReductionStrategy enum * Remove unnecessary scopes * Unify the commenting style to inline comments * Add explicit casts for occupancy and num_cu in MakeKernelArgs function Both the static functions Occupancy and NumCU in the StreamKKernel class use functions from the HIP API that result in the returned occupancy and num_cu types being type int. The TilePartitioner interface for stream K will have occupancy and num_cu being type ck_tile::index_t which is int32_t. Thus, to be safe, this change ensures that both occupancy and num_cu are cast to int32_t. * Fix use of kentry due to interface update PR #2594 updated the interface for the kentry function in include/ck_tile/host/kernel_launch.hpp. As a result, the static function Occupancy was updated to work correctly with the new interface. PR #2594 also changed UniversalGemmKernel's KernelBlockSize static variable to kBlockSize, so the StreamKKernel class was updated to reflect this change. * Switch type of num_sk_blocks from uint32_t to int32_t This change switches the type of num_sk_blocks to type ck_tile::index_t which is int32_t. This was done because parallel work for the CK Tile StreamK TilePartitioner's constructor will have num_sk_blocks as ck_tile::index_t. Thus, this change will help unify the interfaces to avoid any type conversion errors.
--------- Co-authored-by: Christopher Millette <63608002+cgmillette@users.noreply.github.com> --- include/ck_tile/ops/gemm.hpp | 5 +- .../ops/gemm/kernel/streamk_gemm_kernel.hpp | 243 ++++++++++++++++++ 2 files changed, 246 insertions(+), 2 deletions(-) create mode 100644 include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 7a01420c51..28273f581d 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -1,5 +1,5 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once @@ -33,6 +33,7 @@ #include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" +#include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp new file mode 100644 index 0000000000..a05e7b2ad0 --- /dev/null +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp @@ -0,0 +1,243 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/host/concat.hpp" + +namespace ck_tile { + +enum StreamKReductionStrategy : uint32_t +{ + /// @brief Workgroups atomically add their results to the C tensor + Atomic = 0u, + /// @brief For a given tile in the C tensor, one workgroup accumulates results of other + /// contributing workgroups + Reduction = 1u +}; + +/// @brief The Stream K GEMM kernel host arguments. +/// +/// @par Overview +/// This structure is passed to @ref StreamKKernel "StreamKKernel" when creating the kernel +/// arguments object. It contains all necessary information required to build proper kernel +/// arguments and launch the kernel on GPU. This structure defines the GEMM problem +/// configuration by stating all required information like M,N,K sizes and respective strides. +struct StreamKHostArgs : public ck_tile::UniversalGemmHostArgs<> +{ + CK_TILE_HOST explicit StreamKHostArgs(const void* a_ptr_, + const void* b_ptr_, + void* c_ptr_, + index_t M_, + index_t N_, + index_t K_, + index_t stride_A_, + index_t stride_B_, + index_t stride_C_, + StreamKReductionStrategy reduction_strategy_, + index_t num_sk_blocks_ = -1) + : UniversalGemmHostArgs<>({a_ptr_}, + {b_ptr_}, + {/*ds_ptr*/}, + c_ptr_, + /*k_batch_ =*/1, + M_, + N_, + K_, + {stride_A_}, + {stride_B_}, + {/*stride_Ds_*/}, + stride_C_), + reduction_strategy{reduction_strategy_}, + num_sk_blocks{num_sk_blocks_} + { + } + + ck_tile::StreamKReductionStrategy reduction_strategy; + index_t num_sk_blocks; +}; + +template +struct StreamKKernel +{ + /// @brief Inject the UniversalGemmKernel base class to support execution of all necessary + /// functions. 
+ using UniversalGemmKernel = + UniversalGemmKernel; + + static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize; + + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + + /// @brief Specify the layout configurations for A, B, and C + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + /// @brief Specify the data type configurations for A, B, and C + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + /// @brief ALayout and ADataType are expected to be scalars, not a tuple. + static_assert(!is_detected::value && + !is_detected::value, + "ALayout and ADataType must be scalars."); + + /// @brief BLayout and BDataType are expected to be scalars, not a tuple. + static_assert(!is_detected::value && + !is_detected::value, + "BLayout and BDataType must be scalars."); + + /// @brief CLayout and CDataType are expected to be scalars, not a tuple. + static_assert(!is_detected::value && + !is_detected::value, + "CLayout and CDataType must be scalars."); + + struct StreamKKernelArgs : ck_tile::UniversalGemmKernelArgs<> + { + /// @brief The strategy used by work groups to compute final results in C tensor. + StreamKReductionStrategy reduction_strategy; + /// @brief The number of stream k blocks. + index_t num_sk_blocks; + /// @brief A pointer to a buffer in device memory for accumulating partial via reduction + /// strategy. + void* workspace_ptr; + /// @brief An instance of the TilePartioner class for assisting with mapping workgroups to + /// the C tensor. 
+ TilePartitioner tile_partitioner; + }; + + using KernelArgs = StreamKKernelArgs; + using Kernel = StreamKKernel; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + using P_ = GemmPipeline; + using WarpTile = typename P_::BlockGemmShape::WarpTile; + + return concat('_', "streamk", gemm_prec_str(), + concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock), + concat('x', WarpTile::at(number<0>{}), WarpTile::at(number<1>{}), WarpTile::at(number<2>{})), + concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()), + concat('x', P_::kPadM, P_::kPadN, P_::kPadK)); + // clang-format on + } + + /// @brief Compute the grid size for the Stream K kernel using the tile_partitioner. + /// @return The grid size. + CK_TILE_HOST static auto GridSize(const TilePartitioner& tile_partitioner) -> dim3 + { + return tile_partitioner.GridSize(); + } + + /// @brief Get the maximum occupancy grid size for the persistent kernel on the current device. + /// @return The maximum occupancy grid size. + /// @note This function queries the maximum occupancy of the kernel using + /// `hipOccupancyMaxActiveBlocksPerMultiprocessor`. 
+ CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 + { + return UniversalGemmKernel::MaxOccupancyGridSize(s); + } + + CK_TILE_HOST static constexpr auto BlockSize() -> dim3 + { + return UniversalGemmKernel::BlockSize(); + } + + CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args) + { + index_t occupancy = static_cast(Occupancy()); + index_t num_cu = static_cast(NumCU()); + + return StreamKKernelArgs{ + {host_args.as_ptr, + host_args.bs_ptr, + host_args.ds_ptr, + host_args.e_ptr, + host_args.M, + host_args.N, + host_args.K, + host_args.stride_As, + host_args.stride_Bs, + host_args.stride_Ds, + host_args.stride_E, + host_args.k_batch}, + host_args.reduction_strategy, + host_args.num_sk_blocks, + // The workspace pointer is set to nullptr because we must first + // instantiate the TilePartitioner to get the necessary size + /*workspace_ptr =*/nullptr, + TilePartitioner{ + host_args.M, host_args.N, host_args.K, num_cu, occupancy, host_args.num_sk_blocks}}; + } + + CK_TILE_HOST static bool + IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) + { + return UniversalGemmKernel::IsSupportedArgument(kargs); + } + + /// @brief Computes the buffer size needed to store accumulation results for Stream K. + /// @return The buffer size needed. + CK_TILE_HOST static uint32_t GetWorkSpaceSize(const StreamKKernelArgs& kargs) + { + // For reduction, we need to determine the amount of device space for acculumation + // results and semaphores. + if(kargs.reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction) + { + return kargs.tile_partitioner.GetWorkSpaceSize(sizeof(CDataType)); + } + + // Otherwise, no additional space is needed since blocks atomically store their results. + return 0; + } + + /// @brief Sets the kargs' current workspace_ptr to the given workspace_ptr. + /// @note Assumes that the given workspace_ptr points to allocated device memory. 
+ CK_TILE_HOST static void SetWorkSpacePointer(StreamKKernelArgs& kargs, void* workspace_ptr) + { + kargs.workspace_ptr = workspace_ptr; + } + + // Temporary placeholder to support the Occupancy() static function. + // Since the Occupancy function uses kentry, this class must have an operator() function + CK_TILE_DEVICE void operator()(StreamKKernelArgs /*kargs*/) const {} + + private: + CK_TILE_HOST static int NumCU() + { + hipDeviceProp_t dev_prop; + hipDevice_t dev; + hip_check_error(hipGetDevice(&dev)); + hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); + int num_cu = dev_prop.multiProcessorCount; + + return num_cu; + } + + /// @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the kernel + /// @return The occupancy + /// @note This function queries the maximum occupancy of the kernel using + /// `hipOccupancyMaxActiveBlocksPerMultiprocessor`. + CK_TILE_HOST static int Occupancy() + { + int occupancy; + + // Since occupancy of 1 is valid for stream k, we set min_num_block_per_cu to 1 + constexpr int min_block_per_cu = 1; + const auto kernel = kentry; + + hip_check_error( + hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0)); + + return occupancy; + } +}; + +} // namespace ck_tile From bf3e719c16846c704e8b93b0116954b321933d74 Mon Sep 17 00:00:00 2001 From: Vidyasagar Ananthan Date: Tue, 19 Aug 2025 18:12:06 -0700 Subject: [PATCH 11/46] Setting gpu target filtering for tile engine to gfx90a, gfx942 and gfx950. 
(#2709) --- tile_engine/ops/gemm/CMakeLists.txt | 21 +++++++++++++++++++++ tile_engine/ops/gemm_multi_d/CMakeLists.txt | 21 +++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/tile_engine/ops/gemm/CMakeLists.txt b/tile_engine/ops/gemm/CMakeLists.txt index fe9b7802a7..42c114b499 100644 --- a/tile_engine/ops/gemm/CMakeLists.txt +++ b/tile_engine/ops/gemm/CMakeLists.txt @@ -3,6 +3,24 @@ set(GEMM_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM (semicolon set(GEMM_LAYOUT "rcr" CACHE STRING "List of layout for GEMM (semicolon-separated)") function(build_gemm_for_datatype datatype layout) + # Filter GPU targets to only gfx90a, gfx942, and gfx950 + set(GEMM_GPU_TARGETS "") + set(DESIRED_TARGETS "gfx90a;gfx942;gfx950") + + foreach(target IN LISTS SUPPORTED_GPU_TARGETS) + if(target IN_LIST DESIRED_TARGETS) + list(APPEND GEMM_GPU_TARGETS ${target}) + endif() + endforeach() + + # Skip compilation if no matching targets found + if(NOT GEMM_GPU_TARGETS) + message(WARNING "Skipping Tile Engine GEMM compilation: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + return() + endif() + + message(STATUS "Building GEMM for GPU targets: ${GEMM_GPU_TARGETS}") + set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") # Comment this if-else block when using user_provided_config @@ -83,6 +101,7 @@ function(build_gemm_for_datatype datatype layout) if(chunk_files) set(sub_intermediate_lib_name "gemm_objlib_${name}_${i}_${datatype}_${layout}") add_library(${sub_intermediate_lib_name} OBJECT ${chunk_files}) + set_property(TARGET ${sub_intermediate_lib_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS}) list(APPEND sub_intermediate_libs ${sub_intermediate_lib_name}) endif() @@ -102,6 +121,7 @@ function(build_gemm_for_datatype datatype layout) add_library(${intermediate_lib_name} STATIC ${obj_exprs}) add_dependencies(${intermediate_lib_name} gemm_gen_${datatype}_${layout}) + 
set_property(TARGET ${intermediate_lib_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS}) #foreach(objlib IN LISTS sub_intermediate_libs) # target_sources(${intermediate_lib_name} PRIVATE $) #endforeach() @@ -132,6 +152,7 @@ function(build_gemm_for_datatype datatype layout) # Executable per datatype set(exec_name "benchmark_gemm_${datatype}_${layout}") add_executable(${exec_name} benchmark_gemm.cpp) + set_property(TARGET ${exec_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS}) target_link_libraries(${exec_name} PRIVATE gemm_host_api_${datatype}_${layout}) target_compile_options(${exec_name} PRIVATE -Wno-undefined-func-template diff --git a/tile_engine/ops/gemm_multi_d/CMakeLists.txt b/tile_engine/ops/gemm_multi_d/CMakeLists.txt index 3708dd3fee..dc08e9cad3 100644 --- a/tile_engine/ops/gemm_multi_d/CMakeLists.txt +++ b/tile_engine/ops/gemm_multi_d/CMakeLists.txt @@ -4,6 +4,24 @@ set(GEMM_MULTI_D_LAYOUT "rcrr" CACHE STRING "List of layout for GEMM Multi D(sem set(GEMM_MULTI_D_ELEMENTWISE_FUNCTION "mul" CACHE STRING "Elementwise function") function(build_gemm_multi_d_for_datatype_layout datatype layout) + # Filter GPU targets to only gfx90a, gfx942, and gfx950 + set(GEMM_GPU_TARGETS "") + set(DESIRED_TARGETS "gfx90a;gfx942;gfx950") + + foreach(target IN LISTS SUPPORTED_GPU_TARGETS) + if(target IN_LIST DESIRED_TARGETS) + list(APPEND GEMM_GPU_TARGETS ${target}) + endif() + endforeach() + + # Skip compilation if no matching targets found + if(NOT GEMM_GPU_TARGETS) + message(WARNING "Skipping Tile Engine GEMM Multi D compilation: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + return() + endif() + + message(STATUS "Building GEMM Multi D for GPU targets: ${GEMM_GPU_TARGETS}") + set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") # Comment this if-else block when using user_provided_config @@ -86,6 +104,7 @@ function(build_gemm_multi_d_for_datatype_layout datatype layout) if(chunk_files) 
set(sub_intermediate_lib_name "gemm_multi_d_objlib_${name}_${i}_${datatype}_${layout}") add_library(${sub_intermediate_lib_name} OBJECT ${chunk_files}) + set_property(TARGET ${sub_intermediate_lib_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS}) list(APPEND sub_intermediate_libs ${sub_intermediate_lib_name}) endif() @@ -105,6 +124,7 @@ function(build_gemm_multi_d_for_datatype_layout datatype layout) add_library(${intermediate_lib_name} STATIC ${obj_exprs}) add_dependencies(${intermediate_lib_name} gemm_multi_d_gen_${datatype}_${layout}) + set_property(TARGET ${intermediate_lib_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS}) #foreach(objlib IN LISTS sub_intermediate_libs) # target_sources(${intermediate_lib_name} PRIVATE $) #endforeach() @@ -136,6 +156,7 @@ function(build_gemm_multi_d_for_datatype_layout datatype layout) # Executable per datatype set(exec_name "benchmark_gemm_multi_d_${datatype}_${layout}") add_executable(${exec_name} benchmark_gemm_multi_d.cpp) + set_property(TARGET ${exec_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS}) target_link_libraries(${exec_name} PRIVATE gemm_multi_d_host_api_${datatype}_${layout}) target_compile_options(${exec_name} PRIVATE -Wno-undefined-func-template From 81b265cf91f489ee370639b9308051def413819c Mon Sep 17 00:00:00 2001 From: Haocong WANG Date: Wed, 20 Aug 2025 16:24:43 +0800 Subject: [PATCH 12/46] [CK_TILE] Update the fmhafwd dispatch logic (#2698) * update the fmhafwd dispatch logic * Fix fmha test scripts * Fix bash --------- Co-authored-by: Ding, Yi --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 2 +- .../ck_tile/01_fmha/script/run_full_test.sh | 2 ++ .../ck_tile/01_fmha/script/smoke_test_bwd.sh | 16 +++++++------- .../ck_tile/01_fmha/script/smoke_test_fwd.sh | 21 ++++++++++--------- 4 files changed, 23 insertions(+), 18 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index e59147a4f3..d9452206e7 100644 --- 
a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -385,7 +385,7 @@ class FmhaFwdApiPool: for i, dtype in enumerate(self.pool.keys()): per_hdim_case=str() for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): - traits=self.pool[dtype][(hdim, hdim_v)] + traits=[t for t in self.pool[dtype][(hdim, hdim_v)] if tr_load == t.tr_load] inners=str() for k, trait in enumerate(traits): if_k = 'if' if k == 0 else 'else if' diff --git a/example/ck_tile/01_fmha/script/run_full_test.sh b/example/ck_tile/01_fmha/script/run_full_test.sh index b5e6778aa5..e7babd2744 100755 --- a/example/ck_tile/01_fmha/script/run_full_test.sh +++ b/example/ck_tile/01_fmha/script/run_full_test.sh @@ -9,6 +9,8 @@ # host name : $hostname # gpu architecture: e.g., gfx90a, or gfx942, etc. +set -euo pipefail + #get the command line arguments: export env_type=$1 echo 'Environment type: ' $env_type diff --git a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh index 5ba3425e26..d123f842a2 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh @@ -1,5 +1,7 @@ -#!/bin/sh +#!/bin/bash # TODO: run this script from CK root or build directory +set -euo pipefail + EXE="$(find . 
-name tile_example_fmha_bwd -type f | head -n 1)" KNAME=1 @@ -17,12 +19,12 @@ for dbias in 0 ; do for p_drop in 0.0 0.2 ; do for deterministic in 0 ; do -$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -deterministic=$deterministic -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=1 
-h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS done done diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index dc2be933bd..3913a0d5c2 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -1,5 +1,7 @@ #!/bin/bash # TODO: run this script from CK root or build directory +set -euo pipefail + EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)" KNAME=1 @@ -51,19 +53,18 @@ run_fp16_bf16_tests() { for cache_batch_idx in $CACHE_BATCH_IDX ; do # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm 
-mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec 
-mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits 
-page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS done ; done ; done ; done ; done done ; done ; done ; done ; done - done ; } run_fp8_tests() { From 4212bbc170948292dc826c0f79aebea87b56d3f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 20 Aug 2025 14:29:57 +0200 Subject: [PATCH 13/46] [CK Tile] Grouped convolution backward data (#2652) * base working version for single groupped conv bwd data * Fix 2d descriptor * fix groups * Add 3d support * fixes * fixes * fixes --------- Co-authored-by: Jakub Piasecki --- .../20_grouped_convolution/CMakeLists.txt | 3 + .../grouped_convolution_backward_data.cpp | 216 ++++ ...n_grouped_convolution_bwd_data_example.inc | 188 +++ include/ck_tile/core/tensor/tensor_view.hpp | 1 + include/ck_tile/host.hpp | 1 + .../reference_grouped_conv_bwd_data.hpp | 227 ++++ include/ck_tile/ops/grouped_convolution.hpp | 2 + ...ouped_convolution_backward_data_kernel.hpp | 985 +++++++++++++++ ...ped_convolution_backward_weight_kernel.hpp | 85 +- .../grouped_convolution_forward_kernel.hpp | 84 +- .../utils/grouped_convolution_utils.hpp | 1 + .../utils/transform_conv_bwd_data_to_gemm.hpp | 1064 +++++++++++++++++ 12 files changed, 2771 insertions(+), 86 deletions(-) create mode 100644 example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp create mode 100644 example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_data_example.inc create mode 100644 include/ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp create mode 100644 include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp create mode 100644 include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp diff --git a/example/ck_tile/20_grouped_convolution/CMakeLists.txt b/example/ck_tile/20_grouped_convolution/CMakeLists.txt index c05dcac09c..5cb1d2650e 100644 --- a/example/ck_tile/20_grouped_convolution/CMakeLists.txt +++ 
b/example/ck_tile/20_grouped_convolution/CMakeLists.txt @@ -6,3 +6,6 @@ target_compile_options(tile_example_grouped_conv_fwd PRIVATE ${EXAMPLE_GEMM_COMP add_executable(tile_example_grouped_conv_bwd_weight EXCLUDE_FROM_ALL grouped_convolution_backward_weight.cpp) target_compile_options(tile_example_grouped_conv_bwd_weight PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + +add_executable(tile_example_grouped_conv_bwd_data EXCLUDE_FROM_ALL grouped_convolution_backward_data.cpp) +target_compile_options(tile_example_grouped_conv_bwd_data PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp new file mode 100644 index 0000000000..308961de5a --- /dev/null +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "grouped_convolution_utils.hpp" + +template , + typename DsLayout = ck_tile::tuple<>, + typename CDEElementWise = ck_tile::element_wise::PassThrough> +float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args, + const ck_tile::stream_config& s) +{ + constexpr int kBlockPerCu = 1; + + constexpr ck_tile::index_t M_Tile = 64; + constexpr ck_tile::index_t N_Tile = 64; + constexpr ck_tile::index_t K_Tile = 32; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr ck_tile::index_t VectorSizeA = 8; + constexpr ck_tile::index_t VectorSizeB = 8; + constexpr ck_tile::index_t VectorSizeC = 8; + + // Implicit GEMM Traits + using CodegenShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; + using TilePartitioner = ck_tile::GemmTile1DPartitioner; + using GroupedConvTraitsType = + ck_tile::GroupedConvTraits; + using CodegenPipelineProblem = + ck_tile::GemmPipelineProblem; + using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; + + using ConvEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping conv!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << '\n' + << "Vector size A: " << CodegenPipeline::GetVectorSizeA() + << ", Vector size B: " << CodegenPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + + if(args.k_batch == 1) + { + return Run(ck_tile::integral_constant{}); + } + else + { + return Run(ck_tile::integral_constant{}); + } +} + +#include "run_grouped_convolution_bwd_data_example.inc" + +template +int run_grouped_conv_bwd_data_example_prec_type( + std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[]) +{ + using NWGC = ck_tile::tensor_layout::convolution::NWGC; + using NHWGC = ck_tile::tensor_layout::convolution::NHWGC; + using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC; + + using GKXC = ck_tile::tensor_layout::convolution::GKXC; + using GKYXC = ck_tile::tensor_layout::convolution::GKYXC; + using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC; + + using NWGK = ck_tile::tensor_layout::convolution::NWGK; + using NHWGK = ck_tile::tensor_layout::convolution::NHWGK; + using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK; + + if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK") + { + return run_grouped_conv_bwd_data_example_with_layouts{}, + InPrecType, + WeiPrecType, + OutPrecType>( + argc, argv, NWGC{}, GKXC{}, NWGK{}); + } + else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == 
"NHWGK") + { + return run_grouped_conv_bwd_data_example_with_layouts{}, + InPrecType, + WeiPrecType, + OutPrecType>( + argc, argv, NHWGC{}, GKYXC{}, NHWGK{}); + } + else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK") + { + return run_grouped_conv_bwd_data_example_with_layouts{}, + InPrecType, + WeiPrecType, + OutPrecType>( + argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{}); + } + else + { + throw std::runtime_error("Unsupported memory layout!"); + } +} + +int run_grouped_conv_bwd_data_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + std::string data_type = arg_parser.get_str("prec"); + std::string in_layout = arg_parser.get_str("in_layout"); + std::string wei_layout = arg_parser.get_str("wei_layout"); + std::string out_layout = arg_parser.get_str("out_layout"); + + if(data_type == "fp16") + { + return run_grouped_conv_bwd_data_example_prec_type( + in_layout, wei_layout, out_layout, argc, argv); + } + else if(data_type == "bf16") + { + return run_grouped_conv_bwd_data_example_prec_type( + in_layout, wei_layout, out_layout, argc, argv); + } + else + { + throw std::runtime_error("Unsupported data type for this operation!"); + } +} + +int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_data_example(argc, argv); } diff --git a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_data_example.inc b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_data_example.inc new file mode 100644 index 0000000000..3e1c13c833 --- /dev/null +++ b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_data_example.inc @@ -0,0 +1,188 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+#pragma once + +template +float invoke_grouped_conv_bwd_data(ck_tile::GroupedConvBwdDataHostArgs& args, + int n_warmup, + int n_repeat) +{ + float ave_time = grouped_conv_bwd_data( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + std::size_t flop = args.GetFlops(); + std::size_t num_byte = args.GetByte(); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << std::endl; + + return ave_time; +} + +template +int run_grouped_conv_bwd_data_example_with_layouts( + int argc, char* argv[], const InLayout, const WeiLayout, const OutLayout) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using AccDataType = float; + + std::vector filter_spatial_lengths; + std::vector image_spatial_lengths; + std::vector strides; + std::vector dilations; + std::vector lpads; + std::vector rpads; + + const ck_tile::index_t num_dim_sp = fill_spatial_dimensions(filter_spatial_lengths, + image_spatial_lengths, + strides, + dilations, + lpads, + rpads, + arg_parser); + + ck_tile::conv::ConvParam conv_param{num_dim_sp, + arg_parser.get_int("g"), + arg_parser.get_int("n"), + arg_parser.get_int("k"), + arg_parser.get_int("c"), + filter_spatial_lengths, + image_spatial_lengths, + strides, + dilations, + lpads, + rpads}; + + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + ck_tile::index_t init_method = arg_parser.get_int("init"); + + const auto in_g_n_c_wis_desc = + ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + const auto wei_g_k_c_xs_desc = + ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + const auto out_g_n_k_wos_desc = + ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + ck_tile::HostTensor 
input(in_g_n_c_wis_desc); + ck_tile::HostTensor weight(wei_g_k_c_xs_desc); + ck_tile::HostTensor output(out_g_n_k_wos_desc); + + if(init_method == 0) + { + ck_tile::FillUniformDistribution{-1.f, 1.f}(weight); + ck_tile::FillUniformDistribution{-1.f, 1.f}(output); + } + else if(init_method == 1) + { + ck_tile::FillMonotonicSeq{}(weight); + ck_tile::FillMonotonicSeq{}(output); + } + else if(init_method == 2) + { + ck_tile::FillUniformDistribution{1.f, 1.f}(weight); + ck_tile::FillUniformDistribution{1.f, 1.f}(output); + } + else + { + weight.SetZero(); + output.SetZero(); + } + + ck_tile::DeviceMem input_dev_buf(input.get_element_space_size_in_bytes()); + ck_tile::DeviceMem weight_dev_buf(weight.get_element_space_size_in_bytes()); + ck_tile::DeviceMem output_dev_buf(output.get_element_space_size_in_bytes()); + + input_dev_buf.SetZero(); + weight_dev_buf.ToDevice(weight.data()); + output_dev_buf.ToDevice(output.data()); + + ck_tile::GroupedConvBwdDataHostArgs args(conv_param, + input_dev_buf.GetDeviceBuffer(), + weight_dev_buf.GetDeviceBuffer(), + {}, + output_dev_buf.GetDeviceBuffer(), + kbatch); + + std::cout << "Run Grouped Conv Bwd Data kernel" << std::endl; + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weight: " << weight.mDesc << std::endl; + std::cout << "output: " << output.mDesc << std::endl; + + invoke_grouped_conv_bwd_data(args, n_warmup, n_repeat); + + input_dev_buf.FromDevice(input.data()); + bool pass = true; + + if(arg_parser.get_int("v") == 1) + { + ck_tile::HostTensor input_host_ref(in_g_n_c_wis_desc); + input_host_ref.SetZero(); + + ck_tile:: + reference_grouped_conv_bwd_data( + input_host_ref, + weight, + output, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_); + const ck_tile::index_t GemmK = + weight.get_element_size() / (conv_param.G_ * conv_param.K_); + const float max_accumulated_value = + 
*std::max_element(input_host_ref.mData.begin(), input_host_ref.mData.end()); + const auto rtol_atol = + calculate_rtol_atol( + GemmK, kbatch, max_accumulated_value); + pass = ck_tile::check_err(input, + input_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; + } + else if(arg_parser.get_int("v") == 2) + { + throw std::runtime_error("Unsupported gpu verification !!!"); + } + + return pass; +} diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index 269465fae6..a85dbc6d00 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -445,6 +445,7 @@ struct null_tensor_view }; template diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index aa5afd25e5..41f5200413 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -27,6 +27,7 @@ #include "ck_tile/host/reference/reference_elementwise.hpp" #include "ck_tile/host/reference/reference_fused_moe.hpp" #include "ck_tile/host/reference/reference_gemm.hpp" +#include "ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp" #include "ck_tile/host/reference/reference_grouped_conv_bwd_weight.hpp" #include "ck_tile/host/reference/reference_grouped_conv_fwd.hpp" #include "ck_tile/host/reference/reference_im2col.hpp" diff --git a/include/ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp b/include/ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp new file mode 100644 index 0000000000..c8264800c9 --- /dev/null +++ b/include/ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp @@ -0,0 +1,227 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 
2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +namespace ck_tile { + +template +CK_TILE_HOST void reference_grouped_conv_bwd_data(HostTensor& input, + const HostTensor& weight, + const HostTensor& output, + std::vector conv_strides, + std::vector conv_dilations, + std::vector in_left_pads, + std::vector) +{ + if(!(input.get_num_of_dimension() == NDimSpatial + 3 && + weight.get_num_of_dimension() == NDimSpatial + 3 && + output.get_num_of_dimension() == NDimSpatial + 3)) + { + + printf("%lu %lu %lu", + input.get_num_of_dimension(), + weight.get_num_of_dimension(), + output.get_num_of_dimension()); + + throw std::runtime_error("wrong! inconsistent dimension"); + } + + if constexpr(NDimSpatial == 1) + { + auto func = [&](auto g, auto n, auto c, auto wi) { + std::size_t K = weight.get_lengths()[1]; + std::size_t X = weight.get_lengths()[3]; + + std::size_t Wo = output.get_lengths()[3]; + float v_acc = 0; + + for(std::size_t x = 0; x < X; ++x) + { + auto w_tmp = static_cast(wi) + + static_cast(in_left_pads[0]) - + static_cast(x * conv_dilations[0]); + + if(w_tmp % conv_strides[0] == 0) + { + auto wo = static_cast(w_tmp) / + static_cast(conv_strides[0]); + + if(wo >= 0 && ck_tile::type_convert(wo) < Wo) + { + for(std::size_t k = 0; k < K; ++k) + { + OutDataType v_out = output(g, n, k, wo); + WeiDataType v_wei = weight(g, k, c, x); + v_acc += ck_tile::type_convert(v_out) * + ck_tile::type_convert(v_wei); + } + } + } + } + InDataType v_acc_converted = ck_tile::type_convert(v_acc); + input(g, n, c, wi) = v_acc_converted; + }; + + make_ParallelTensorFunctor(func, + input.get_lengths()[0], + input.get_lengths()[1], + input.get_lengths()[2], + input.get_lengths()[3])(std::thread::hardware_concurrency()); + } + else if constexpr(NDimSpatial == 2) + { + auto func = [&](auto g, auto n, auto c, auto hi, auto wi) { + std::size_t K = weight.get_lengths()[1]; + 
std::size_t Y = weight.get_lengths()[3]; + std::size_t X = weight.get_lengths()[4]; + + std::size_t Ho = output.get_lengths()[3]; + std::size_t Wo = output.get_lengths()[4]; + + float v_acc = 0; + + for(std::size_t y = 0; y < Y; ++y) + { + auto h_tmp = static_cast(hi) + + static_cast(in_left_pads[0]) - + static_cast(y * conv_dilations[0]); + if(h_tmp % conv_strides[0] == 0) + { + auto ho = static_cast(h_tmp) / + static_cast(conv_strides[0]); + if(ho >= 0 && ck_tile::type_convert(ho) < Ho) + { + for(std::size_t x = 0; x < X; ++x) + { + auto w_tmp = static_cast(wi) + + static_cast(in_left_pads[1]) - + static_cast(x * conv_dilations[1]); + if(w_tmp % conv_strides[1] == 0) + { + auto wo = static_cast(w_tmp) / + static_cast(conv_strides[1]); + + if(wo >= 0 && ck_tile::type_convert(wo) < Wo) + { + for(std::size_t k = 0; k < K; ++k) + { + OutDataType v_out = output(g, n, k, ho, wo); + WeiDataType v_wei = weight(g, k, c, y, x); + v_acc += ck_tile::type_convert(v_out) * + ck_tile::type_convert(v_wei); + } + } + } + } + } + } + } + InDataType v_acc_converted = ck_tile::type_convert(v_acc); + input(g, n, c, hi, wi) = v_acc_converted; + }; + + make_ParallelTensorFunctor(func, + input.get_lengths()[0], + input.get_lengths()[1], + input.get_lengths()[2], + input.get_lengths()[3], + input.get_lengths()[4])(std::thread::hardware_concurrency()); + } + else if constexpr(NDimSpatial == 3) + { + auto func = [&](auto g, auto n, auto c, auto di, auto hi, auto wi) { + std::size_t K = weight.get_lengths()[1]; + std::size_t Z = weight.get_lengths()[3]; + std::size_t Y = weight.get_lengths()[4]; + std::size_t X = weight.get_lengths()[5]; + + std::size_t Do = output.get_lengths()[3]; + std::size_t Ho = output.get_lengths()[4]; + std::size_t Wo = output.get_lengths()[5]; + + float v_acc = 0; + + for(std::size_t z = 0; z < Z; ++z) + { + auto d_tmp = static_cast(di) + + static_cast(in_left_pads[0]) - + static_cast(z * conv_dilations[0]); + if(d_tmp % conv_strides[0] == 0) + { + auto do_ = 
static_cast(d_tmp) / + static_cast(conv_strides[0]); + if(do_ >= 0 && ck_tile::type_convert(do_) < Do) + { + for(std::size_t y = 0; y < Y; ++y) + { + auto h_tmp = static_cast(hi) + + static_cast(in_left_pads[1]) - + static_cast(y * conv_dilations[1]); + if(h_tmp % conv_strides[1] == 0) + { + auto ho = static_cast(h_tmp) / + static_cast(conv_strides[1]); + if(ho >= 0 && ck_tile::type_convert(ho) < Ho) + { + for(std::size_t x = 0; x < X; ++x) + { + auto w_tmp = + static_cast(wi) + + static_cast(in_left_pads[2]) - + static_cast(x * + conv_dilations[2]); + + if(w_tmp % conv_strides[2] == 0) + { + auto wo = + static_cast(w_tmp) / + static_cast(conv_strides[2]); + if(wo >= 0 && + ck_tile::type_convert(wo) < Wo) + { + for(std::size_t k = 0; k < K; ++k) + { + OutDataType v_out = + output(g, n, k, do_, ho, wo); + WeiDataType v_wei = weight(g, k, c, z, y, x); + v_acc += ck_tile::type_convert(v_out) * + ck_tile::type_convert(v_wei); + } + } + } + } + } + } + } + } + } + } + InDataType v_acc_converted = ck_tile::type_convert(v_acc); + input(g, n, c, di, hi, wi) = v_acc_converted; + }; + + make_ParallelTensorFunctor(func, + input.get_lengths()[0], + input.get_lengths()[1], + input.get_lengths()[2], + input.get_lengths()[3], + input.get_lengths()[4], + input.get_lengths()[5])(std::thread::hardware_concurrency()); + } + else + { + throw std::runtime_error( + "Ref_conv_bwd_data: number of dimensions must be between 1 and 3."); + } +} +} // namespace ck_tile diff --git a/include/ck_tile/ops/grouped_convolution.hpp b/include/ck_tile/ops/grouped_convolution.hpp index 29332f941a..09b50f26b0 100644 --- a/include/ck_tile/ops/grouped_convolution.hpp +++ b/include/ck_tile/ops/grouped_convolution.hpp @@ -3,10 +3,12 @@ #pragma once +#include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp" #include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp" #include 
"ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp" #include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp" #include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp" +#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp" #include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp" #include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp new file mode 100644 index 0000000000..282a187eae --- /dev/null +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp @@ -0,0 +1,985 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/host/concat.hpp" +#include "ck_tile/core/utility/env.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp" +#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp" + +namespace ck_tile { + +/// @brief The Grouped Convolution kernel device arguments. 
+template +struct GroupedConvBwdDataKernelArgs +{ + using TilePartitioner = remove_cvref_t; + + using ConvToGemmTransformer = + TransformConvBwdDataToGemm; + static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + + template < + typename InLay = typename GroupedConvTraitsType_::InLayout, + typename WeiLay = typename GroupedConvTraitsType_::WeiLayout, + typename OutLay = typename GroupedConvTraitsType_::OutLayout, + typename std::enable_if && + std::is_same_v && + std::is_same_v, + bool>::type = false> + CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs& args) + { + in_g_n_c_wis_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0])}; + wei_g_k_c_xs_lengths = {static_cast(args.G_), + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0])}; + out_g_n_k_wos_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.K_), + static_cast(args.output_spatial_lengths_[0])}; + + conv_filter_strides = {static_cast(args.conv_filter_strides_[0])}; + conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0])}; + input_left_pads = {static_cast(args.input_left_pads_[0])}; + input_right_pads = {static_cast(args.input_right_pads_[0])}; + + k_batch = args.k_batch; + + in_ptr = args.in_ptr; + wei_ptr = args.wei_ptr; + for(index_t d = 0; d < NumDTensor; d++) + { + ds_ptr[d] = args.ds_ptr[d]; + } + out_ptr = args.out_ptr; + + const index_t X = wei_g_k_c_xs_lengths[3]; + const index_t ConvStrideW = conv_filter_strides[0]; + const index_t ConvDilationW = conv_filter_dilations[0]; + const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW); + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + const auto XDotSlice = integer_divide_ceil(X - i_xtilde, 
XTilde); + + if(XDotSlice <= 0) + { + continue; + } + + if(gemm_count >= MaxGroupedGemmGroupsNum) + { + gemm_count++; + // Avoid array segfault + continue; + } + + tildes = {i_xtilde}; + + ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths, + wei_g_k_c_xs_lengths, + out_g_n_k_wos_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes}; + + auto grid_descs = + conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N< + GroupedConvTraitsType_::NDimSpatial>(1); + + a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{}); + b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{}); + c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{}); + + const index_t grid_size_grp = + TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0), + c_grid_descs_m_n[gemm_count].get_length(I1)); + + block_starts[gemm_count] = grid_size_; + block_ends[gemm_count] = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + ++gemm_count; + } + group_stride_a = args.K_; // A: Out NWGK + group_stride_b = args.K_ * args.C_ * + std::accumulate(args.filter_spatial_lengths_.begin(), + args.filter_spatial_lengths_.end(), + 1, + std::multiplies()); // B: Wei GKXC + group_stride_c = args.C_; // C: In NWGC + + GemmBatch = args.G_; + } + + template < + typename InLay = typename GroupedConvTraitsType_::InLayout, + typename WeiLay = typename GroupedConvTraitsType_::WeiLayout, + typename OutLay = typename GroupedConvTraitsType_::OutLayout, + typename std::enable_if && + std::is_same_v && + std::is_same_v, + bool>::type = false> + CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs& args) + { + in_g_n_c_wis_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0]), + static_cast(args.input_spatial_lengths_[1])}; + wei_g_k_c_xs_lengths = {static_cast(args.G_), + static_cast(args.K_), + 
static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0]), + static_cast(args.filter_spatial_lengths_[1])}; + out_g_n_k_wos_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.K_), + static_cast(args.output_spatial_lengths_[0]), + static_cast(args.output_spatial_lengths_[1])}; + + conv_filter_strides = {static_cast(args.conv_filter_strides_[0]), + static_cast(args.conv_filter_strides_[1])}; + conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0]), + static_cast(args.conv_filter_dilations_[1])}; + input_left_pads = {static_cast(args.input_left_pads_[0]), + static_cast(args.input_left_pads_[1])}; + input_right_pads = {static_cast(args.input_right_pads_[0]), + static_cast(args.input_right_pads_[1])}; + + k_batch = args.k_batch; + + in_ptr = args.in_ptr; + wei_ptr = args.wei_ptr; + for(index_t d = 0; d < NumDTensor; d++) + { + ds_ptr[d] = args.ds_ptr[d]; + } + out_ptr = args.out_ptr; + + const index_t Y = wei_g_k_c_xs_lengths[3]; + const index_t X = wei_g_k_c_xs_lengths[4]; + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + const auto GcdStrideDilationH = gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW); + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + const auto YDotSlice = integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde); + + if(XDotSlice * YDotSlice <= 0) + { + continue; + } + + if(gemm_count >= MaxGroupedGemmGroupsNum) + { + gemm_count++; + // Avoid array segfault + continue; + } + + tildes = {i_ytilde, i_xtilde}; + + ConvToGemmTransformer 
conv_to_gemm_transformer{in_g_n_c_wis_lengths, + wei_g_k_c_xs_lengths, + out_g_n_k_wos_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes}; + + auto grid_descs = conv_to_gemm_transformer + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N< + GroupedConvTraitsType_::NDimSpatial>(1); + + a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{}); + b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{}); + c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{}); + + const index_t grid_size_grp = + TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0), + c_grid_descs_m_n[gemm_count].get_length(I1)); + + block_starts[gemm_count] = grid_size_; + block_ends[gemm_count] = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + ++gemm_count; + } + } + group_stride_a = args.K_; // A: Out NWGK + group_stride_b = args.K_ * args.C_ * + std::accumulate(args.filter_spatial_lengths_.begin(), + args.filter_spatial_lengths_.end(), + 1, + std::multiplies()); // B: Wei GKXC + group_stride_c = args.C_; // C: In NWGC + + GemmBatch = args.G_; + } + + template < + typename InLay = typename GroupedConvTraitsType_::InLayout, + typename WeiLay = typename GroupedConvTraitsType_::WeiLayout, + typename OutLay = typename GroupedConvTraitsType_::OutLayout, + typename std::enable_if && + std::is_same_v && + std::is_same_v, + bool>::type = false> + CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs& args) + { + in_g_n_c_wis_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0]), + static_cast(args.input_spatial_lengths_[1]), + static_cast(args.input_spatial_lengths_[2])}; + wei_g_k_c_xs_lengths = {static_cast(args.G_), + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0]), + static_cast(args.filter_spatial_lengths_[1]), + static_cast(args.filter_spatial_lengths_[2])}; + 
out_g_n_k_wos_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.K_), + static_cast(args.output_spatial_lengths_[0]), + static_cast(args.output_spatial_lengths_[1]), + static_cast(args.output_spatial_lengths_[2])}; + + conv_filter_strides = {static_cast(args.conv_filter_strides_[0]), + static_cast(args.conv_filter_strides_[1]), + static_cast(args.conv_filter_strides_[2])}; + conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0]), + static_cast(args.conv_filter_dilations_[1]), + static_cast(args.conv_filter_dilations_[2])}; + input_left_pads = {static_cast(args.input_left_pads_[0]), + static_cast(args.input_left_pads_[1]), + static_cast(args.input_left_pads_[2])}; + input_right_pads = {static_cast(args.input_right_pads_[0]), + static_cast(args.input_right_pads_[1]), + static_cast(args.input_right_pads_[2])}; + + k_batch = args.k_batch; + + in_ptr = args.in_ptr; + wei_ptr = args.wei_ptr; + for(index_t d = 0; d < NumDTensor; d++) + { + ds_ptr[d] = args.ds_ptr[d]; + } + out_ptr = args.out_ptr; + + const index_t Z = wei_g_k_c_xs_lengths[3]; + const index_t Y = wei_g_k_c_xs_lengths[4]; + const index_t X = wei_g_k_c_xs_lengths[5]; + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + const auto GcdStrideDilationD = gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW); + const auto ZTilde = ConvStrideD / GcdStrideDilationD; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde) + { + for(index_t i_ytilde = 0; i_ytilde < YTilde; 
++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + const auto ZDotSlice = integer_divide_ceil(Z - i_ztilde, ZTilde); + const auto YDotSlice = integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde); + + if(ZDotSlice * XDotSlice * YDotSlice <= 0) + { + continue; + } + + if(gemm_count >= MaxGroupedGemmGroupsNum) + { + gemm_count++; + // Avoid array segfault + continue; + } + + tildes = {i_ztilde, i_ytilde, i_xtilde}; + + ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths, + wei_g_k_c_xs_lengths, + out_g_n_k_wos_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes}; + + auto grid_descs = conv_to_gemm_transformer + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N< + GroupedConvTraitsType_::NDimSpatial>(1); + + a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{}); + b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{}); + c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{}); + + const index_t grid_size_grp = + TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0), + c_grid_descs_m_n[gemm_count].get_length(I1)); + + block_starts[gemm_count] = grid_size_; + block_ends[gemm_count] = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + ++gemm_count; + } + } + } + + group_stride_a = args.K_; // A: Out NWGK + group_stride_b = args.K_ * args.C_ * + std::accumulate(args.filter_spatial_lengths_.begin(), + args.filter_spatial_lengths_.end(), + 1, + std::multiplies()); // B: Wei GKXC + group_stride_c = args.C_; // C: In NWGC + + GemmBatch = args.G_; // C: In NWGC + } + + static constexpr index_t MaxGroupedGemmGroupsNum = 128; + + using ABCGridDescs = + remove_cvref_t; + + using AGridDescMK = remove_cvref_t{}])>; + using BGridDescNK = remove_cvref_t{}])>; + using CGridDescMN = remove_cvref_t{}])>; + + static constexpr index_t NonSpatialDims = 3; + array in_g_n_c_wis_lengths; + 
array wei_g_k_c_xs_lengths; + array out_g_n_k_wos_lengths; + + array conv_filter_strides; + array conv_filter_dilations; + array input_left_pads; + array input_right_pads; + array tildes; + + index_t k_batch; + index_t GemmBatch; + index_t grid_size_ = 0; + index_t gemm_count = 0; + + const void* out_ptr; + void* in_ptr; + std::array ds_ptr; + const void* wei_ptr; + + array a_grid_descs_m_k; + array b_grid_descs_n_k; + array c_grid_descs_m_n; + + array block_starts; + array block_ends; + + long_index_t group_stride_a; + long_index_t group_stride_b; + long_index_t group_stride_c; +}; + +/// @brief The Grouped Convolution Backward Data kernel template. +/// +/// @paragraph Overview Overview +/// This class provides the grouped convolution backward data kernel template. By +/// semantic division of Implicit GEMM algorithm into following parts we achieve +/// flexible, versatile and robust kernel implementation. +/// +/// @li @b Prolog - The start of GEMM kernel implementation in @ref operator() +/// function call operator" which determines the work scope of each workgroup. +/// @li @b GemmPipeline - The core part @a "heart" of matrix multiplication algorithm. +/// This is the place where each workgroup is loading data from global memory and +/// carrying out dot products. +/// @li @b Epilogue - The @a "final" part of matrix multiplication implementation +/// responsible for storing results to global memory. This is also the place where +/// any additional operator fusion may take place. +/// +/// Additionally both @ref GemmPipeline_ "GemmPipeline" and @ref EpiloguePipeline_ +/// "EpiloguePipeline" are parameterized with so called @a Policy which determines all +/// internal details of those functional parts. You can think of it like both gemm and +/// epilogue pipelines provides the control-flow logic controlled by policies. Moreover +/// the policy is responsible for definition of all necessary data layouts and thread's +/// work distribution. 
+/// +/// @tparam GroupedConvTraitsType_ The type of class providing traits for grouped convolution. +/// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into +/// the +/// output data tile to be calculated. It determines the +/// workgroup to data relationship (or in other words - which +/// data would be processed and calculated by which workgroup). +/// @tparam GemmPipeline_ The type of class which provides the core part of matrix +/// multiplication. This class should provide implementation of +/// data loading from global memory and performing block-wise +/// matrix multiplication. You can think of it as a work done by +/// single workgroup point of view. +/// @tparam EpiloguePipeline_ The type of class providing the final part of matrix +/// multiplication implementation. It is responsible for storing +/// results calculated by @ref GemmPipeline_ "GemmPipeline" to +/// the output C tensor in global memory. +template +struct GroupedConvolutionBackwardDataKernel +{ + static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_; + static constexpr ConvolutionSpecialization ConvSpecialization = + GroupedConvTraitsType_::ConvSpecialization; + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + using GemmALayout = remove_cvref_t; + using GemmBLayout = remove_cvref_t; + using GemmCLayout = remove_cvref_t; + + using InLayout = remove_cvref_t; + using WeiLayout = remove_cvref_t; + using OutLayout = remove_cvref_t; + using DsLayout = remove_cvref_t; + + using GemmDsLayout = remove_cvref_t; + static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor; + + static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; + + using InDataType = remove_cvref_t; + using WeiDataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + + using OutDataType = remove_cvref_t; + + using GroupedConvBwdDataKernelArgsSpecialized = + 
GroupedConvBwdDataKernelArgs; + static constexpr index_t MaxGroupedGemmGroupsNum = + GroupedConvBwdDataKernelArgsSpecialized::MaxGroupedGemmGroupsNum; + + // TODO: Enable this + static constexpr bool IsSplitKSupported = false; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto I3 = number<3>(); + + static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK, + "Not supported!"); + static_assert(std::is_same_v, + "Not supported A GEMM layout!"); + static_assert(std::is_same_v, + "Not supported B GEMM layout!"); + static_assert(std::is_same_v, + "Not supported C GEMM layout!"); + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "grouped_convolution_backward_data", gemm_prec_str, GemmPipeline::GetName()); + // clang-format on + } + + CK_TILE_HOST static auto GridSize(const GroupedConvBwdDataKernelArgsSpecialized& kargs) + { + // enable batched grouped gemm + return dim3(kargs.grid_size_, kargs.GemmBatch, kargs.k_batch); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + + CK_TILE_HOST static constexpr GroupedConvBwdDataKernelArgsSpecialized + MakeKernelArgs(const GroupedConvBwdDataHostArgs& hostArgs) + { + return GroupedConvBwdDataKernelArgsSpecialized(hostArgs); + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_HOST static bool + IsSupportedArgument(const GroupedConvBwdDataKernelArgsSpecialized& kargs) + { + if constexpr((EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value) || + !IsSplitKSupported) + { + if(kargs.k_batch != 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Conditions not met for Kbatch >1 !"); + } + return false; + } + } + + if(kargs.gemm_count > MaxGroupedGemmGroupsNum) 
+ { + return false; + } + + const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}]; + const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}]; + + // check ConvSpecialization + if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3]; + const index_t ConvStride = kargs.conv_filter_strides[i]; + const index_t LeftPad = kargs.input_left_pads[i]; + const index_t RightPad = kargs.input_right_pads[i]; + + if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3]; + const index_t LeftPad = kargs.input_left_pads[i]; + const index_t RightPad = kargs.input_right_pads[i]; + + if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3) + { + if(ConvC != 1) + { + return false; + } + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3]; + + if(filter_spatial_dim != I3) + { + return false; + } + } + } + + namespace ctc = tensor_layout::convolution; + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v) + { + // Check access per C + if(ConvC % GemmPipeline::GetVectorSizeB() != 0) + { + CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!"); + return false; + } + } + else + { + CK_TILE_ERROR("Not supported input layout!"); + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + if(ConvC % EpiloguePipeline::GetVectorSizeC() 
!= 0) + { + CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!"); + return false; + } + } + else + { + CK_TILE_ERROR("Not supported weight layout!"); + return false; + } + + // check vector access of E + if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + if(ConvK % GemmPipeline::GetVectorSizeA() != 0) + { + CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!"); + return false; + } + } + else + { + CK_TILE_ERROR("Not supported output layout!"); + return false; + } + + return true; + } + + template + CK_TILE_DEVICE static auto + MakeGemmTensorViews(const OutDataType* a_ptr, + const InDataType* b_ptr, + const std::array& ds_ptr, + WeiDataType* c_ptr, + const GroupedConvBwdDataKernelArgsSpecialized& kargs, + const index_t group_id) + { + static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); + static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!"); + const auto& a_tensor_view = [&]() { + return make_tensor_view( + a_ptr, + kargs.a_grid_descs_m_k[group_id]); // A: out + }(); + + const auto& b_tensor_view = [&]() { + return make_tensor_view( + b_ptr, + kargs.b_grid_descs_n_k[group_id]); // B: weight + }(); + + const auto& c_tensor_view = [&]() { + return make_tensor_view(c_ptr, + kargs.c_grid_descs_m_n[group_id]); + }(); + + const auto& ds_tensor_view = generate_tuple( + [&](auto i) { + static_assert(std::is_same_v, OutLayout>, + "Not supported!"); + static_assert(std::is_same_v, + "Not supported!"); + static_assert(std::is_same_v, OutDataType>, + "Not supported!"); + + return make_tensor_view( + static_cast(ds_ptr[i]), kargs.c_grid_descs_m_n[group_id]); + }, + number{}); + + return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view); + } + + template + CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) + { + const auto& a_pad_view = [&]() { + const auto& a_tensor_view = views.at(I0); + return 
pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + }(); + + const auto& b_pad_view = [&]() { + const auto& b_tensor_view = views.at(I1); + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + }(); + + const auto& ds_tensor_view = views.at(I2); + const auto& ds_pad_view = generate_tuple( + [&](auto i) { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + }, + number{}); + + const auto& c_pad_view = [&]() { + const auto& c_tensor_view = views.at(I3); + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + }(); + + return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view); + } + + template + CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views, + const index_t i_m, + const index_t i_n, + const index_t i_k = 0) + { + const auto& a_pad_view = views.at(I0); + const auto& b_pad_view = views.at(I1); + const auto& ds_pad_view = views.at(I2); + const auto& c_pad_view = views.at(I3); + + const auto& a_block_window = [&]() { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {i_m, i_k}); + }(); + + const auto& b_block_window = [&]() { + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {i_n, i_k}); + }(); + + const auto ds_block_window = generate_tuple( + [&](auto i) { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {i_m, i_n}); + }, + number{}); + + auto c_block_window = make_tile_window( + c_pad_view, + make_tuple(number{}, number{}), + {i_m, i_n}); + + return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window); + } + + /** + * @brief Runs single GEMM problem cooperatively by whole workgroup. + * + * @param a_ptr input A pointer + * @param b_ptr input B pointer + * @param c_ptr output C pointer + * @param smem_ptr_0 The start memory pointer of the shared memory block. 
+ * @param kargs Grouped Convolution Backward Data kernel arguments + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. + * + */ + CK_TILE_DEVICE static void RunGemm(const OutDataType* a_ptr, + const InDataType* b_ptr, + const std::array& ds_ptr, + WeiDataType* c_ptr, + void* smem_ptr_0, + const GroupedConvBwdDataKernelArgsSpecialized& kargs, + const index_t block_idx_m, + const index_t block_idx_n, + const index_t group_id) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + MakeGemmTensorViews( + a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id); + + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + const index_t num_loop = __builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum( + gemm_pad_views.at(I0).get_tensor_descriptor().get_length(I1))); + + // Run GEMM cooperatively by whole workgroup. + const auto& a_block_window = gemm_tile_windows.at(I0); + const auto& b_block_window = gemm_tile_windows.at(I1); + const auto& d_block_window = gemm_tile_windows.at(I2); + + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, num_loop, smem_ptr_0); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I3); + + EpiloguePipeline{}.template operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + + /** + * @brief Runs single GEMM problem cooperatively by whole workgroup. + * + * @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism. + * + * @param a_ptr input A pointer + * @param b_ptr input B pointer + * @param c_ptr output C pointer + * @param smem_ptr_0 The starting pointer of 1st shared memory block. 
+ * @param smem_ptr_1 The starting pointer of 2nd shared memory block. + * @param kargs Grouped Convolution Backward Data kernel arguments + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. + * + */ + CK_TILE_DEVICE static void RunGemm2LDS(const OutDataType* a_ptr, + const InDataType* b_ptr, + const std::array& ds_ptr, + WeiDataType* c_ptr, + void* __restrict__ smem_ptr_0, + void* __restrict__ smem_ptr_1, + const GroupedConvBwdDataKernelArgsSpecialized& kargs, + const index_t block_idx_m, + const index_t block_idx_n, + const index_t group_id) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + MakeGemmTensorViews( + a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id); + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + const index_t num_loop = __builtin_amdgcn_readfirstlane( + TilePartitioner::GetLoopNum(gemm_tile_windows.at(I0).get_length(I1))); + + // Run GEMM cooperatively by whole workgroup. 
+ const auto& a_block_window = gemm_tile_windows.at(I0); + const auto& b_block_window = gemm_tile_windows.at(I1); + const auto& d_block_window = gemm_tile_windows.at(I2); + + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I3); + + EpiloguePipeline{}.template operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + + CK_TILE_DEVICE index_t FindGroupId(const GroupedConvBwdDataKernelArgsSpecialized& kargs, + index_t block_id) const + { + index_t left = 0; + index_t right = kargs.gemm_count; + index_t group_id = index_t((left + right) >> 1); + + while((!(block_id >= kargs.block_starts[group_id] && + block_id < kargs.block_ends[group_id])) && + left <= right) + { + if(block_id < kargs.block_starts[group_id]) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) >> 1); + } + + return group_id; + } + + CK_TILE_DEVICE void operator()(GroupedConvBwdDataKernelArgsSpecialized kargs) const + { + const auto blockIdX = __builtin_amdgcn_readfirstlane(blockIdx.x); + const index_t group_id = FindGroupId(kargs, blockIdX); + + const auto [iM, iN] = OffsettedTile1DPartitioner::GetOffsetedTileIndex( + kargs.block_starts[group_id], + kargs.c_grid_descs_m_n[group_id].get_length(I0), + kargs.c_grid_descs_m_n[group_id].get_length(I1)); + + const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + const auto blockIdY = __builtin_amdgcn_readfirstlane(blockIdx.y); + const auto group_offset_a = __builtin_amdgcn_readfirstlane(kargs.group_stride_a * blockIdY); + const auto group_offset_b = __builtin_amdgcn_readfirstlane(kargs.group_stride_b * blockIdY); + const auto group_offset_c = __builtin_amdgcn_readfirstlane(kargs.group_stride_c * blockIdY); + + // 
options + // conv_bwd_data = Out * Weight = In + const OutDataType* a_ptr = static_cast(kargs.out_ptr) + group_offset_a; + const WeiDataType* b_ptr = static_cast(kargs.wei_ptr) + group_offset_b; + InDataType* c_ptr = static_cast(kargs.in_ptr) + group_offset_c; + + // allocate LDS + __shared__ char smem_ptr_0[GetSmemSize()]; + + if constexpr(GemmPipeline::DoubleSmemBuffer == true) + { + __shared__ char smem_ptr_1[GetSmemSize()]; + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm2LDS(a_ptr, + b_ptr, + kargs.ds_ptr, + c_ptr, + smem_ptr_0, + smem_ptr_1, + kargs, + i_m, + i_n, + group_id); + } + } + else + { + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n, group_id); + } + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 7ea2e31706..2700353049 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -17,19 +17,19 @@ namespace ck_tile { /// @brief The Grouped Convolution kernel device arguments. 
-template +template struct GroupedConvBwdWeightKernelArgs { using ConvToGemmTransformer = - TransformConvBwdWeightToGemm; - static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor; + TransformConvBwdWeightToGemm; + static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor; template < - typename InLay = typename GroupedConvTraitsType::InLayout, - typename WeiLay = typename GroupedConvTraitsType::WeiLayout, - typename OutLay = typename GroupedConvTraitsType::OutLayout, + typename InLay = typename GroupedConvTraitsType_::InLayout, + typename WeiLay = typename GroupedConvTraitsType_::WeiLayout, + typename OutLay = typename GroupedConvTraitsType_::OutLayout, typename std::enable_if && std::is_same_v && std::is_same_v, @@ -75,7 +75,7 @@ struct GroupedConvBwdWeightKernelArgs // tuple auto grid_descs = conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N< - GroupedConvTraitsType::NDimSpatial>(); + GroupedConvTraitsType_::NDimSpatial>(); a_grid_desc_m_k = grid_descs.at(number<0>{}); b_grid_desc_n_k = grid_descs.at(number<1>{}); @@ -96,9 +96,9 @@ struct GroupedConvBwdWeightKernelArgs } template < - typename InLay = typename GroupedConvTraitsType::InLayout, - typename WeiLay = typename GroupedConvTraitsType::WeiLayout, - typename OutLay = typename GroupedConvTraitsType::OutLayout, + typename InLay = typename GroupedConvTraitsType_::InLayout, + typename WeiLay = typename GroupedConvTraitsType_::WeiLayout, + typename OutLay = typename GroupedConvTraitsType_::OutLayout, typename std::enable_if && std::is_same_v && std::is_same_v, @@ -151,7 +151,7 @@ struct GroupedConvBwdWeightKernelArgs // tuple auto grid_descs = conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N< - GroupedConvTraitsType::NDimSpatial>(); + GroupedConvTraitsType_::NDimSpatial>(); a_grid_desc_m_k = grid_descs.at(number<0>{}); b_grid_desc_n_k = grid_descs.at(number<1>{}); @@ -172,9 +172,9 @@ struct 
GroupedConvBwdWeightKernelArgs } template < - typename InLay = typename GroupedConvTraitsType::InLayout, - typename WeiLay = typename GroupedConvTraitsType::WeiLayout, - typename OutLay = typename GroupedConvTraitsType::OutLayout, + typename InLay = typename GroupedConvTraitsType_::InLayout, + typename WeiLay = typename GroupedConvTraitsType_::WeiLayout, + typename OutLay = typename GroupedConvTraitsType_::OutLayout, typename std::enable_if && std::is_same_v && std::is_same_v, @@ -234,7 +234,7 @@ struct GroupedConvBwdWeightKernelArgs // tuple auto grid_descs = conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N< - GroupedConvTraitsType::NDimSpatial>(); + GroupedConvTraitsType_::NDimSpatial>(); a_grid_desc_m_k = grid_descs.at(number<0>{}); b_grid_desc_n_k = grid_descs.at(number<1>{}); @@ -263,14 +263,14 @@ struct GroupedConvBwdWeightKernelArgs using CGridDescMN = remove_cvref_t{}])>; static constexpr index_t NonSpatialDims = 3; - array in_g_n_c_wis_lengths; - array wei_g_k_c_xs_lengths; - array out_g_n_k_wos_lengths; + array in_g_n_c_wis_lengths; + array wei_g_k_c_xs_lengths; + array out_g_n_k_wos_lengths; - array conv_filter_strides; - array conv_filter_dilations; - array input_left_pads; - array input_right_pads; + array conv_filter_strides; + array conv_filter_dilations; + array input_left_pads; + array input_right_pads; index_t k_batch; index_t GemmM; @@ -292,12 +292,12 @@ struct GroupedConvBwdWeightKernelArgs long_index_t group_stride_c; }; -/// @brief The Grouped Convolution Forward kernel template. +/// @brief The Grouped Convolution Backward Weight kernel template. /// /// @paragraph Overview Overview -/// This class provides the grouped convolution forward kernel template. By semantic -/// division of Implicit GEMM algorithm into following parts we achieve flexible, -/// versatile and robust kernel implementation. +/// This class provides the grouped convolution backward weight kernel template. 
By +/// semantic division of Implicit GEMM algorithm into following parts we achieve +/// flexible, versatile and robust kernel implementation. /// /// @li @b Prolog - The start of GEMM kernel implementation in @ref operator() /// function call operator" which determines the work scope of each workgroup. @@ -315,7 +315,7 @@ struct GroupedConvBwdWeightKernelArgs /// the policy is responsible for definition of all necessary data layouts and thread's /// work distribution. /// -/// tparam ConvSpecialization Tensor descriptors specialization. +/// @tparam GroupedConvTraitsType_ The type of class providing traits for grouped convolution. /// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into /// the /// output data tile to be calculated. It determines the @@ -330,15 +330,15 @@ struct GroupedConvBwdWeightKernelArgs /// multiplication implementation. It is responsible for storing /// results calculated by @ref GemmPipeline_ "GemmPipeline" to /// the output C tensor in global memory. 
-template struct GroupedConvolutionBackwardWeightKernel { - static constexpr index_t NDimSpatial = GroupedConvTraitsType::NDimSpatial_; + static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_; static constexpr ConvolutionSpecialization ConvSpecialization = - GroupedConvTraitsType::ConvSpecialization; + GroupedConvTraitsType_::ConvSpecialization; using TilePartitioner = remove_cvref_t; using GemmPipeline = remove_cvref_t; using EpiloguePipeline = remove_cvref_t; @@ -346,13 +346,13 @@ struct GroupedConvolutionBackwardWeightKernel using GemmBLayout = remove_cvref_t; using GemmCLayout = remove_cvref_t; - using InLayout = remove_cvref_t; - using WeiLayout = remove_cvref_t; - using OutLayout = remove_cvref_t; - using DsLayout = remove_cvref_t; + using InLayout = remove_cvref_t; + using WeiLayout = remove_cvref_t; + using OutLayout = remove_cvref_t; + using DsLayout = remove_cvref_t; using GemmDsLayout = remove_cvref_t; - static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor; + static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor; static constexpr index_t kBlockSize = GemmPipeline::BlockSize; @@ -363,7 +363,7 @@ struct GroupedConvolutionBackwardWeightKernel using OutDataType = remove_cvref_t; using GroupedConvBwdWeightKernelArgsSpecialized = - GroupedConvBwdWeightKernelArgs; + GroupedConvBwdWeightKernelArgs; // TODO: Enable this static constexpr bool IsSplitKSupported = true; @@ -594,12 +594,9 @@ struct GroupedConvolutionBackwardWeightKernel }(); const auto& c_tensor_view = [&]() { - return make_naive_tensor_view( + return make_tensor_view( c_ptr, - make_tuple(kargs.GemmM, kargs.GemmN), - make_tuple(kargs.GemmN, 1), - number{}, - number<1>{}); + kargs.c_grid_desc_m_n); // B: in }(); const auto& ds_tensor_view = generate_tuple( @@ -708,7 +705,7 @@ struct GroupedConvolutionBackwardWeightKernel * @param b_ptr input B pointer * @param c_ptr output C pointer * @param smem_ptr_0 The start memory pointer of the shared 
memory block. - * @param kargs Grouped Convolution Forward kernel arguments + * @param kargs Grouped Convolution Backward Weight kernel arguments * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. * @@ -758,7 +755,7 @@ struct GroupedConvolutionBackwardWeightKernel * @param c_ptr output C pointer * @param smem_ptr_0 The starting pointer of 1st shared memory block. * @param smem_ptr_1 The starting pointer of 2nd shared memory block. - * @param kargs Grouped Convolution Forward kernel arguments + * @param kargs Grouped Convolution Backward Weight kernel arguments * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. * diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index d3a90ea144..d4f4eca0d0 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -17,19 +17,19 @@ namespace ck_tile { /// @brief The Grouped Convolution kernel device arguments. 
-template +template struct GroupedConvFwdKernelArgs { using ConvToGemmFwdTransformer = - TransformConvFwdToGemm; - static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor; + TransformConvFwdToGemm; + static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor; template < - typename InLay = typename GroupedConvTraitsType::InLayout, - typename WeiLay = typename GroupedConvTraitsType::WeiLayout, - typename OutLay = typename GroupedConvTraitsType::OutLayout, + typename InLay = typename GroupedConvTraitsType_::InLayout, + typename WeiLay = typename GroupedConvTraitsType_::WeiLayout, + typename OutLay = typename GroupedConvTraitsType_::OutLayout, typename std::enable_if && std::is_same_v && std::is_same_v, @@ -79,13 +79,13 @@ struct GroupedConvFwdKernelArgs a_grid_desc_m_k = conv_to_gemm_transformer - .template MakeADescriptor_M_K(); + .template MakeADescriptor_M_K(); b_grid_desc_n_k = conv_to_gemm_transformer - .template MakeBDescriptor_N_K(); + .template MakeBDescriptor_N_K(); c_grid_desc_m_n = conv_to_gemm_transformer - .template MakeCDescriptor_M_N(); + .template MakeCDescriptor_M_N(); group_stride_a = args.C_; group_stride_b = args.K_ * args.C_ * @@ -97,9 +97,9 @@ struct GroupedConvFwdKernelArgs } template < - typename InLay = typename GroupedConvTraitsType::InLayout, - typename WeiLay = typename GroupedConvTraitsType::WeiLayout, - typename OutLay = typename GroupedConvTraitsType::OutLayout, + typename InLay = typename GroupedConvTraitsType_::InLayout, + typename WeiLay = typename GroupedConvTraitsType_::WeiLayout, + typename OutLay = typename GroupedConvTraitsType_::OutLayout, typename std::enable_if && std::is_same_v && std::is_same_v, @@ -156,13 +156,13 @@ struct GroupedConvFwdKernelArgs a_grid_desc_m_k = conv_to_gemm_transformer - .template MakeADescriptor_M_K(); + .template MakeADescriptor_M_K(); b_grid_desc_n_k = conv_to_gemm_transformer - .template MakeBDescriptor_N_K(); + .template MakeBDescriptor_N_K(); c_grid_desc_m_n = 
conv_to_gemm_transformer - .template MakeCDescriptor_M_N(); + .template MakeCDescriptor_M_N(); group_stride_a = args.C_; group_stride_b = args.K_ * args.C_ * @@ -174,9 +174,9 @@ struct GroupedConvFwdKernelArgs } template < - typename InLay = typename GroupedConvTraitsType::InLayout, - typename WeiLay = typename GroupedConvTraitsType::WeiLayout, - typename OutLay = typename GroupedConvTraitsType::OutLayout, + typename InLay = typename GroupedConvTraitsType_::InLayout, + typename WeiLay = typename GroupedConvTraitsType_::WeiLayout, + typename OutLay = typename GroupedConvTraitsType_::OutLayout, typename std::enable_if && std::is_same_v && std::is_same_v, @@ -242,13 +242,13 @@ struct GroupedConvFwdKernelArgs a_grid_desc_m_k = conv_to_gemm_transformer - .template MakeADescriptor_M_K(); + .template MakeADescriptor_M_K(); b_grid_desc_n_k = conv_to_gemm_transformer - .template MakeBDescriptor_N_K(); + .template MakeBDescriptor_N_K(); c_grid_desc_m_n = conv_to_gemm_transformer - .template MakeCDescriptor_M_N(); + .template MakeCDescriptor_M_N(); group_stride_a = args.C_; group_stride_b = args.K_ * args.C_ * @@ -261,23 +261,23 @@ struct GroupedConvFwdKernelArgs using AGridDescMK = remove_cvref_t< decltype(ConvToGemmFwdTransformer{} - .template MakeADescriptor_M_K())>; + .template MakeADescriptor_M_K())>; using BGridDescNK = remove_cvref_t< decltype(ConvToGemmFwdTransformer{} - .template MakeBDescriptor_N_K())>; + .template MakeBDescriptor_N_K())>; using CGridDescMN = remove_cvref_t< decltype(ConvToGemmFwdTransformer{} - .template MakeCDescriptor_M_N())>; + .template MakeCDescriptor_M_N())>; static constexpr index_t NonSpatialDims = 3; - array in_g_n_c_wis_lengths; - array wei_g_k_c_xs_lengths; - array out_g_n_k_wos_lengths; + array in_g_n_c_wis_lengths; + array wei_g_k_c_xs_lengths; + array out_g_n_k_wos_lengths; - array conv_filter_strides; - array conv_filter_dilations; - array input_left_pads; - array input_right_pads; + array conv_filter_strides; + array 
conv_filter_dilations; + array input_left_pads; + array input_right_pads; index_t k_batch; index_t GemmM; @@ -322,7 +322,7 @@ struct GroupedConvFwdKernelArgs /// the policy is responsible for definition of all necessary data layouts and thread's /// work distribution. /// -/// @tparam GroupedConvTraitsType The type of class providing traits for grouped convolution. +/// @tparam GroupedConvTraitsType_ The type of class providing traits for grouped convolution. /// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into /// the /// output data tile to be calculated. It determines the @@ -337,15 +337,15 @@ struct GroupedConvFwdKernelArgs /// multiplication implementation. It is responsible for storing /// results calculated by @ref GemmPipeline_ "GemmPipeline" to /// the output C tensor in global memory. -template struct GroupedConvolutionForwardKernel { - static constexpr index_t NDimSpatial = GroupedConvTraitsType::NDimSpatial; + static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial; static constexpr ConvolutionSpecialization ConvSpecialization = - GroupedConvTraitsType::ConvSpecialization; + GroupedConvTraitsType_::ConvSpecialization; using TilePartitioner = remove_cvref_t; using GemmPipeline = remove_cvref_t; using EpiloguePipeline = remove_cvref_t; @@ -353,13 +353,13 @@ struct GroupedConvolutionForwardKernel using GemmBLayout = remove_cvref_t; using GemmCLayout = remove_cvref_t; - using InLayout = remove_cvref_t; - using WeiLayout = remove_cvref_t; - using OutLayout = remove_cvref_t; - using DsLayout = remove_cvref_t; + using InLayout = remove_cvref_t; + using WeiLayout = remove_cvref_t; + using OutLayout = remove_cvref_t; + using DsLayout = remove_cvref_t; using GemmDsLayout = remove_cvref_t; - static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor; + static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor; static constexpr index_t kBlockSize = GemmPipeline::BlockSize; @@ -369,7 
+369,7 @@ struct GroupedConvolutionForwardKernel // Below type is actually accumulation data type - the output of block GEMM. using OutDataType = remove_cvref_t; - using GroupedConvFwdKernelArgsSpecialized = GroupedConvFwdKernelArgs; + using GroupedConvFwdKernelArgsSpecialized = GroupedConvFwdKernelArgs; // TODO: Enable this static constexpr bool IsSplitKSupported = false; diff --git a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp index b173ab25a1..3e5e87a975 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp @@ -42,6 +42,7 @@ struct GroupedConvHostArgs : public conv::ConvParam using GroupedConvFwdHostArgs = GroupedConvHostArgs; using GroupedConvBwdWeightHostArgs = GroupedConvHostArgs; +using GroupedConvBwdDataHostArgs = GroupedConvHostArgs; template +struct TransformConvBwdDataToGemm +{ + private: + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + static constexpr auto I3 = number<3>{}; + static constexpr auto I4 = number<4>{}; + static constexpr auto I5 = number<5>{}; +#if 0 // TODO: Enable these functionalities + template + static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths, + const ConvDimsType& strides, + index_t i) + { + long_index_t acc = 1; + for(; i < (NDimSpatial + 3); i++) + { + acc += + static_cast(lengths[i] - I1) * static_cast(strides[i]); + } + + return acc; + } + + template + static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& a_g_n_c_wis_strides, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvDimsType& c_g_n_k_wos_strides) + { + const long_index_t a_element_space_size = + calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1); + const 
long_index_t c_element_space_size = + calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1); + const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType), + c_element_space_size * sizeof(CDataType)); + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + const IndexType N = a_g_n_c_wis_lengths[I1]; + + if(element_space_size > TwoGB) + { + // Minimum divisor of N to not exceed 2GB + const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB); + + if(divisor <= static_cast(N)) + { + // Find least divisor of N larger than element_space_size / TwoGB + // Iterate up to sqrt(N). There are no divisors above this value. + for(IndexType least_divisor = divisor; least_divisor * least_divisor <= N; + least_divisor++) + { + if(N % least_divisor == 0) + { + return N / least_divisor; + } + } + // Not found, process one Convolution N per block + return 1; + } + else + { + // Split Convolution's N dimension into N workgroups. However + // this still might not result in sufficiently small tensor, + // but at least later on we could divide the image as well. + return 1; + } + } + else + { + // Split N is not needed. 
+ return N; + } + } +#endif + + public: + CK_TILE_HOST constexpr TransformConvBwdDataToGemm() {} + + template + CK_TILE_HOST + TransformConvBwdDataToGemm(const TransformConvBwdDataToGemmBase& transform_conv_to_gemm_base) + : G_{static_cast(transform_conv_to_gemm_base.G_)}, + N_{static_cast(transform_conv_to_gemm_base.N_)}, + Di_{static_cast(transform_conv_to_gemm_base.Di_)}, + Hi_{static_cast(transform_conv_to_gemm_base.Hi_)}, + Wi_{static_cast(transform_conv_to_gemm_base.Wi_)}, + Do_{static_cast(transform_conv_to_gemm_base.Do_)}, + Ho_{static_cast(transform_conv_to_gemm_base.Ho_)}, + Wo_{static_cast(transform_conv_to_gemm_base.Wo_)}, + Z_{static_cast(transform_conv_to_gemm_base.Z_)}, + Y_{static_cast(transform_conv_to_gemm_base.Y_)}, + X_{static_cast(transform_conv_to_gemm_base.X_)}, + K_{static_cast(transform_conv_to_gemm_base.K_)}, + C_{static_cast(transform_conv_to_gemm_base.C_)}, + ConvStrideD_{static_cast(transform_conv_to_gemm_base.ConvStrideD_)}, + ConvStrideH_{static_cast(transform_conv_to_gemm_base.ConvStrideH_)}, + ConvStrideW_{static_cast(transform_conv_to_gemm_base.ConvStrideW_)}, + ConvDilationD_{static_cast(transform_conv_to_gemm_base.ConvDilationD_)}, + ConvDilationH_{static_cast(transform_conv_to_gemm_base.ConvDilationH_)}, + ConvDilationW_{static_cast(transform_conv_to_gemm_base.ConvDilationW_)}, + InLeftPadD_{static_cast(transform_conv_to_gemm_base.InLeftPadD_)}, + InLeftPadH_{static_cast(transform_conv_to_gemm_base.InLeftPadH_)}, + InLeftPadW_{static_cast(transform_conv_to_gemm_base.InLeftPadW_)}, + InRightPadD_{static_cast(transform_conv_to_gemm_base.InRightPadD_)}, + InRightPadH_{static_cast(transform_conv_to_gemm_base.InRightPadH_)}, + InRightPadW_{static_cast(transform_conv_to_gemm_base.InRightPadW_)} + { + } + + template ::type = false> + CK_TILE_HOST TransformConvBwdDataToGemm(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvSpatialDimsType& 
conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads, + const ConvSpatialDimsType& tildes) + : G_{a_g_n_c_wis_lengths[I0]}, + N_{a_g_n_c_wis_lengths[I1]}, + Di_{I1}, + Hi_{I1}, + Wi_{a_g_n_c_wis_lengths[I3]}, + Do_{I1}, + Ho_{I1}, + Wo_{c_g_n_k_wos_lengths[I3]}, + Z_{I1}, + Y_{I1}, + X_{b_g_k_c_xs_lengths[I3]}, + K_{c_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + ConvStrideD_{I1}, + ConvStrideH_{I1}, + ConvStrideW_{conv_filter_strides[I0]}, + ConvDilationD_{I1}, + ConvDilationH_{I1}, + ConvDilationW_{conv_filter_dilations[I0]}, + InLeftPadD_{I0}, + InLeftPadH_{I0}, + InLeftPadW_{input_left_pads[I0]}, + InRightPadD_{I0}, + InRightPadH_{I0}, + InRightPadW_{input_right_pads[I0]}, + IdxZTilde_{I1}, + IdxYTilde_{I1}, + IdxXTilde_{tildes[I0]} + { +#if 0 // TODO: Enable these functionalities + if constexpr(SplitN) + { + N_ = GetSplitedNSize( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + } + else + { + N_ = c_g_n_k_wos_lengths[I1]; + } +#endif + + GcdStrideDilationW_ = gcd(ConvStrideW_, ConvDilationW_); + XTilde_ = ConvStrideW_ / GcdStrideDilationW_; + WTilde_ = Wo_ + integer_divide_ceil(ConvDilationW_ * (X_ - I1), ConvStrideW_); + XDot_ = integer_divide_ceil(X_, XTilde_); + } + + template ::type = false> + CK_TILE_HOST TransformConvBwdDataToGemm(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvSpatialDimsType& conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads, + const ConvSpatialDimsType& tildes) + : G_{a_g_n_c_wis_lengths[I0]}, + N_{a_g_n_c_wis_lengths[I1]}, + Di_{I1}, + Hi_{a_g_n_c_wis_lengths[I3]}, + Wi_{a_g_n_c_wis_lengths[I4]}, + Do_{I1}, + Ho_{c_g_n_k_wos_lengths[I3]}, + Wo_{c_g_n_k_wos_lengths[I4]}, + 
Z_{I1}, + Y_{b_g_k_c_xs_lengths[I3]}, + X_{b_g_k_c_xs_lengths[I4]}, + K_{c_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + ConvStrideD_{I1}, + ConvStrideH_{conv_filter_strides[I0]}, + ConvStrideW_{conv_filter_strides[I1]}, + ConvDilationD_{I1}, + ConvDilationH_{conv_filter_dilations[I0]}, + ConvDilationW_{conv_filter_dilations[I1]}, + InLeftPadD_{I0}, + InLeftPadH_{input_left_pads[I0]}, + InLeftPadW_{input_left_pads[I1]}, + InRightPadD_{I0}, + InRightPadH_{input_right_pads[I0]}, + InRightPadW_{input_right_pads[I1]}, + IdxZTilde_{I1}, + IdxYTilde_{tildes[I0]}, + IdxXTilde_{tildes[I1]} + { +#if 0 // TODO: Enable these functionalities + if constexpr(SplitN) + { + N_ = GetSplitedNSize( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + } + else + { + N_ = c_g_n_k_wos_lengths[I1]; + } +#endif + GcdStrideDilationW_ = gcd(ConvStrideW_, ConvDilationW_); + GcdStrideDilationH_ = gcd(ConvStrideH_, ConvDilationH_); + XTilde_ = ConvStrideW_ / GcdStrideDilationW_; + YTilde_ = ConvStrideH_ / GcdStrideDilationH_; + WTilde_ = Wo_ + integer_divide_ceil(ConvDilationW_ * (X_ - I1), ConvStrideW_); + HTilde_ = Ho_ + integer_divide_ceil(ConvDilationH_ * (Y_ - I1), ConvStrideH_); + XDot_ = integer_divide_ceil(X_, XTilde_); + YDot_ = integer_divide_ceil(Y_, YTilde_); + } + + template ::type = false> + CK_TILE_HOST TransformConvBwdDataToGemm(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvSpatialDimsType& conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads, + [[maybe_unused]] const ConvSpatialDimsType& tildes) + : G_{a_g_n_c_wis_lengths[I0]}, + N_{a_g_n_c_wis_lengths[I1]}, + Di_{a_g_n_c_wis_lengths[I3]}, + Hi_{a_g_n_c_wis_lengths[I4]}, + Wi_{a_g_n_c_wis_lengths[I5]}, + Do_{c_g_n_k_wos_lengths[I3]}, + Ho_{c_g_n_k_wos_lengths[I4]}, + 
Wo_{c_g_n_k_wos_lengths[I5]}, + Z_{b_g_k_c_xs_lengths[I3]}, + Y_{b_g_k_c_xs_lengths[I4]}, + X_{b_g_k_c_xs_lengths[I5]}, + K_{c_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + ConvStrideD_{conv_filter_strides[I0]}, + ConvStrideH_{conv_filter_strides[I1]}, + ConvStrideW_{conv_filter_strides[I2]}, + ConvDilationD_{conv_filter_dilations[I0]}, + ConvDilationH_{conv_filter_dilations[I1]}, + ConvDilationW_{conv_filter_dilations[I2]}, + InLeftPadD_{input_left_pads[I0]}, + InLeftPadH_{input_left_pads[I1]}, + InLeftPadW_{input_left_pads[I2]}, + InRightPadD_{input_right_pads[I0]}, + InRightPadH_{input_right_pads[I1]}, + InRightPadW_{input_right_pads[I2]}, + IdxZTilde_{tildes[I0]}, + IdxYTilde_{tildes[I1]}, + IdxXTilde_{tildes[I2]} + { +#if 0 // TODO: Enable these functionalities + if constexpr(SplitN) + { + N_ = GetSplitedNSize( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + } + else + { + N_ = c_g_n_k_wos_lengths[I1]; + } +#endif + GcdStrideDilationW_ = gcd(ConvStrideW_, ConvDilationW_); + GcdStrideDilationH_ = gcd(ConvStrideH_, ConvDilationH_); + GcdStrideDilationD_ = gcd(ConvStrideD_, ConvDilationD_); + XTilde_ = ConvStrideW_ / GcdStrideDilationW_; + YTilde_ = ConvStrideH_ / GcdStrideDilationH_; + ZTilde_ = ConvStrideD_ / GcdStrideDilationD_; + WTilde_ = Wo_ + integer_divide_ceil(ConvDilationW_ * (X_ - I1), ConvStrideW_); + HTilde_ = Ho_ + integer_divide_ceil(ConvDilationH_ * (Y_ - I1), ConvStrideH_); + DTilde_ = Do_ + integer_divide_ceil(ConvDilationD_ * (Z_ - I1), ConvStrideD_); + XDot_ = integer_divide_ceil(X_, XTilde_); + YDot_ = integer_divide_ceil(Y_, YTilde_); + ZDot_ = integer_divide_ceil(Z_, ZTilde_); + } + +#if 0 // TODO: Enable these functionalities + __host__ bool AreDescriptorsSmallerThan2GB() const + { + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + const long_index_t in_desc_space_size = + I1 + (N_ - I1) * NStrideTensorA_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ + + (Wi_ - I1) * 
WiStride_ + (C_ - I1) * CStrideTensorA_; + const long_index_t out_desc_space_size = + I1 + (N_ - I1) * NStrideTensorC_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ + + (Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorC_; + + bool is_a_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(ADataType)) <= TwoGB; + bool is_c_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(CDataType)) <= TwoGB; + + return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB; + } + + __host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base, + CDataType* c_grid_ptr_base) const + { + // Create copies + auto conv_to_gemm_transformer_left = *this; + auto conv_to_gemm_transformer_right = *this; + IndexType a_right_offset = 0; + IndexType c_right_offset = 0; + // Calculate real filter size + const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1; + const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1; + const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1; + // Calculate start position in input for right tensor + const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_; + const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_; + const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_; + // Calculate last position in input for left tensor + const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff; + const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff; + const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff; + // Allow to split if whole left padding will be in left tensor and right padding in right + // tensor + const bool is_possible_to_split_d = Do_ != 1 && + di_right_transformer_start_idx > InLeftPadD_ && + di_left_transformer_end_idx <= (InLeftPadD_ + Di_); + const bool is_possible_to_split_h = Ho_ != 1 && + hi_right_transformer_start_idx > InLeftPadH_ && + hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_); 
+ const bool is_possible_to_split_w = Wo_ != 1 && + wi_right_transformer_start_idx > InLeftPadW_ && + wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_); + + if(is_possible_to_split_d) + { + // Apply new sizes + // Split output on half + conv_to_gemm_transformer_left.Do_ = Do_ / 2; + conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2; + // Assign left padding to left convolution + conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_; + conv_to_gemm_transformer_right.InLeftPadD_ = 0; + // Assign right padding to right convolution + conv_to_gemm_transformer_left.InRightPadD_ = 0; + conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_; + // Calculate new input size + conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_; + conv_to_gemm_transformer_right.Di_ = + math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_), + (conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff); + ; + // Calcualte offsets + a_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_; + c_right_offset = (Do_ / 2) * DoStride_; + } + else if(is_possible_to_split_h) + { + conv_to_gemm_transformer_left.Ho_ = Ho_ / 2; + conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2; + + conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_; + conv_to_gemm_transformer_right.InLeftPadH_ = 0; + + conv_to_gemm_transformer_left.InRightPadH_ = 0; + conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_; + + conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_; + conv_to_gemm_transformer_right.Hi_ = + math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_), + (conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff); + a_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_; + c_right_offset = (Ho_ / 2) * HoStride_; + } + else if(is_possible_to_split_w) + { + conv_to_gemm_transformer_left.Wo_ = Wo_ / 2; + conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2; + + 
conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_; + conv_to_gemm_transformer_right.InLeftPadW_ = 0; + + conv_to_gemm_transformer_left.InRightPadW_ = 0; + conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_; + + conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_; + conv_to_gemm_transformer_right.Wi_ = + math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_), + (conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff); + + a_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_; + c_right_offset = (Wo_ / 2) * WoStride_; + } + // Return left transform, right transformer, right offset to Input and right offset to + // Output + return ck_tile::make_tuple(conv_to_gemm_transformer_left, + conv_to_gemm_transformer_right, + a_grid_ptr_base + a_right_offset, + c_grid_ptr_base + c_right_offset); + } +#endif + + template ::type = false> + CK_TILE_HOST auto make_out_grid_desc() const + { + // NWGK + const index_t NStride = Wo_ * G_ * K_; + const index_t WoStride = G_ * K_; + constexpr auto KStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + + return make_naive_tensor_descriptor(make_tuple(N_, Wo_, K_), + make_tuple(NStride, WoStride, KStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_wei_grid_desc() const + { + // GKXC + return make_naive_tensor_descriptor_packed(make_tuple(K_, X_, C_)); + } + + template ::type = false> + CK_TILE_HOST auto make_in_grid_desc() const + { + // NWGC + const index_t NStride = Wi_ * G_ * C_; + const index_t WiStride = G_ * C_; // GC? 
+ constexpr auto CStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_), + make_tuple(NStride, WiStride, CStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_out_grid_desc() const + { + // NHWGK + const index_t NStride = Ho_ * Wo_ * G_ * K_; + const index_t HoStride = Wo_ * G_ * K_; + const index_t WoStride = G_ * K_; + constexpr auto KStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + + return make_naive_tensor_descriptor(make_tuple(N_, Ho_, Wo_, K_), + make_tuple(NStride, HoStride, WoStride, KStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_in_grid_desc() const + { + // NHWGC + const index_t NStride = Hi_ * Wi_ * G_ * C_; + const index_t HiStride = Wi_ * G_ * C_; + const index_t WiStride = G_ * C_; + constexpr auto CStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_), + make_tuple(NStride, HiStride, WiStride, CStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_wei_grid_desc() const + { + // GKYXC + return make_naive_tensor_descriptor_packed(make_tuple(K_, Y_, X_, C_)); + } + + template ::type = false> + CK_TILE_HOST auto make_out_grid_desc() const + { + // NDHWGK + const index_t NStride = Do_ * Ho_ * Wo_ * G_ * K_; + const index_t DoStride = Ho_ * Wo_ * G_ * K_; + const index_t HoStride = Wo_ * G_ * K_; + const index_t WoStride = G_ * K_; + constexpr auto KStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + return make_naive_tensor_descriptor( + make_tuple(N_, Do_, Ho_, Wo_, K_), + make_tuple(NStride, DoStride, HoStride, WoStride, KStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_in_grid_desc() const + { + const index_t NStride = Di_ * Hi_ * Wi_ * G_ * C_; + const index_t DiStride = Hi_ * Wi_ * G_ * C_; + const index_t HiStride = Wi_ * G_ * C_; + const index_t WiStride = G_ * C_; + constexpr auto CStride = I1; + + // TODO 
Add support for NumGroupsToMerge > 1 + return make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, C_), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_wei_grid_desc() const + { + // GKZYXC + return make_naive_tensor_descriptor_packed(make_tuple(K_, Z_, Y_, X_, C_)); + } + // TODO: implement ck_tile::tensor_layout::convolution that describe packed/strided dimemsion as + // properties + + template ::type = false> + CK_TILE_HOST auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N([[maybe_unused]] const index_t GemmKBatch) const + { + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IWTildeSliceBegin = integer_divide_floor( + max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_); + + const auto IWTildeSliceEnd = + min(WTilde_, integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1); + + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto XDotSlice = integer_divide_ceil(X_ - IdxXTilde_, XTilde_); + + const auto out_grid_desc = make_out_grid_desc(); + const auto in_grid_desc = make_in_grid_desc(); + const auto wei_grid_desc = make_wei_grid_desc(); + + // A: output tensor comes in K_M + const auto out_n_wop_k_grid_desc = + transform_tensor_descriptor(out_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wo_, I0, I0), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_wop_k_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(XDot_, WTilde_), + make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + 
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + const auto out_n_xdotslice_wtildeslice_k_grid_desc = transform_tensor_descriptor( + out_n_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_slice_transform(XDot_, I0, XDotSlice), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + const auto out_gemmm_gemmkraw_grid_desc = transform_tensor_descriptor( + out_n_xdotslice_wtildeslice_k_grid_desc, + make_tuple(make_merge_transform(make_tuple(XDotSlice, K_)), + make_merge_transform(make_tuple(N_, WTildeSlice))), + make_tuple(sequence<1, 3>{}, sequence<0, 2>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + // B: weight tensor comes in K_N + const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_embed_transform(make_tuple(XDot_, XTilde_), + make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + const auto wei_k_xdotslice_c_grid_desc = transform_tensor_descriptor( + wei_k_xdot_xtilde_c_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_slice_transform(XDot_, I0, XDotSlice), + make_freeze_transform(IdxXTilde_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<>{}, sequence<2>{})); + + const auto wei_gemmn_gemmkraw_grid_desc = + transform_tensor_descriptor(wei_k_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(XDotSlice, K_)), + make_pass_through_transform(C_)), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1>{}, 
sequence<0>{})); + + // c: input + const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(XTilde_, WTilde_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_freeze_transform(IdxXTilde_), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<>{}, sequence<1>{}, sequence<2>{})); + + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N_, WTildeSlice)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return make_tuple(out_gemmm_gemmkraw_grid_desc, + wei_gemmn_gemmkraw_grid_desc, + in_gemmmraw_gemmnraw_grid_desc); + } + + template ::type = false> + CK_TILE_HOST auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N([[maybe_unused]] const index_t GemmKBatch) const + { + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IHTildeSliceBegin = integer_divide_floor( + max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), 
ConvStrideH_); + const auto IWTildeSliceBegin = integer_divide_floor( + max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_); + + const auto IHTildeSliceEnd = + min(HTilde_, integer_divide_ceil(InLeftPadH_ + Hi_ - I1, ConvStrideH_) + I1); + const auto IWTildeSliceEnd = + min(WTilde_, integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1); + + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto YDotSlice = integer_divide_ceil(Y_ - IdxYTilde_, YTilde_); + const auto XDotSlice = integer_divide_ceil(X_ - IdxXTilde_, XTilde_); + + const auto out_grid_desc = make_out_grid_desc(); + const auto in_grid_desc = make_in_grid_desc(); + const auto wei_grid_desc = make_wei_grid_desc(); + + // A: output tensor comes in K_M + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Ho_, I0, I0), + make_pad_transform(Wo_, I0, I0), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(YDot_, HTilde_), + make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), + make_embed_transform(make_tuple(XDot_, WTilde_), + make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); + + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + 
make_tuple(make_pass_through_transform(N_), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{})); + + const auto out_gemmm_gemmkraw_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)), + make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice))), + make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + // B: weight tensor comes in K_N + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_embed_transform(make_tuple(YDot_, YTilde_), + make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)), + make_embed_transform(make_tuple(XDot_, XTilde_), + make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); + + const auto wei_k_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_freeze_transform(IdxYTilde_), + make_freeze_transform(IdxXTilde_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<3>{}, + sequence<2>{}, + sequence<4>{}, + sequence<5>{}), + 
make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<>{}, + sequence<>{}, + sequence<3>{})); + + const auto wei_gemmn_gemmkraw_grid_desc = transform_tensor_descriptor( + wei_k_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)), + make_pass_through_transform(C_)), + make_tuple(sequence<1, 2, 0>{}, sequence<3>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + // c: input + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(YTilde_, HTilde_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(XTilde_, WTilde_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_freeze_transform(IdxYTilde_), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(IdxXTilde_), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<0>{}, + 
sequence<>{}, + sequence<1>{}, + sequence<>{}, + sequence<2>{}, + sequence<3>{})); + + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return make_tuple(out_gemmm_gemmkraw_grid_desc, + wei_gemmn_gemmkraw_grid_desc, + in_gemmmraw_gemmnraw_grid_desc); + } + + template ::type = false> + CK_TILE_HOST auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N([[maybe_unused]] const index_t GemmKBatch) const + { + // only work on DTilde, HTilde and WTilde that contribute to non-padding area of input + // tensor + const auto IDTildeSliceBegin = integer_divide_floor( + max(I0, InLeftPadD_ - ConvDilationD_ * (ZTilde_ - I1)), ConvStrideD_); + const auto IHTildeSliceBegin = integer_divide_floor( + max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), ConvStrideH_); + const auto IWTildeSliceBegin = integer_divide_floor( + max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_); + + const auto IDTildeSliceEnd = + min(DTilde_, integer_divide_ceil(InLeftPadD_ + Di_ - I1, ConvStrideD_) + I1); + const auto IHTildeSliceEnd = + min(HTilde_, integer_divide_ceil(InLeftPadH_ + Hi_ - I1, ConvStrideH_) + I1); + const auto IWTildeSliceEnd = + min(WTilde_, integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1); + + const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin; + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto ZDotSlice = integer_divide_ceil(Z_ - IdxZTilde_, ZTilde_); + const auto YDotSlice = integer_divide_ceil(Y_ - IdxYTilde_, YTilde_); + const auto XDotSlice = integer_divide_ceil(X_ - IdxXTilde_, XTilde_); + + const auto out_grid_desc = 
make_out_grid_desc(); + const auto in_grid_desc = make_in_grid_desc(); + const auto wei_grid_desc = make_wei_grid_desc(); + + // A: output tensor comes in K_M + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Do_, I0, I0), + make_pad_transform(Ho_, I0, I0), + make_pad_transform(Wo_, I0, I0), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); + + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(ZDot_, DTilde_), + make_tuple(-ConvDilationD_ / GcdStrideDilationD_, I1)), + make_embed_transform(make_tuple(YDot_, HTilde_), + make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), + make_embed_transform(make_tuple(XDot_, WTilde_), + make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5, 6>{}, + sequence<7>{})); + + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_slice_transform(ZDot_, I0, ZDotSlice), + make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, 
+ sequence<5>{}, + sequence<6>{}, + sequence<7>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}, + sequence<6>{}, + sequence<7>{})); + + const auto out_gemmm_gemmkraw_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, + make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)), + make_merge_transform(make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice))), + make_tuple(sequence<1, 3, 5, 7>{}, sequence<0, 2, 4, 6>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + // B: weight tensor comes in K_N + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_embed_transform(make_tuple(ZDot_, ZTilde_), + make_tuple(ConvStrideD_ / GcdStrideDilationD_, I1)), + make_embed_transform(make_tuple(YDot_, YTilde_), + make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)), + make_embed_transform(make_tuple(XDot_, XTilde_), + make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5, 6>{}, + sequence<7>{})); + + const auto wei_k_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_slice_transform(ZDot_, I0, ZDotSlice), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_freeze_transform(IdxZTilde_), + make_freeze_transform(IdxYTilde_), + make_freeze_transform(IdxXTilde_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<3>{}, + sequence<5>{}, + sequence<2>{}, + sequence<4>{}, + sequence<6>{}, + sequence<7>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, 
+ sequence<2>{}, + sequence<3>{}, + sequence<>{}, + sequence<>{}, + sequence<>{}, + sequence<4>{})); + + const auto wei_gemmn_gemmkraw_grid_desc = transform_tensor_descriptor( + wei_k_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)), + make_pass_through_transform(C_)), + make_tuple(sequence<1, 2, 3, 0>{}, sequence<4>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + // c: input + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); + + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(ZTilde_, DTilde_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(YTilde_, HTilde_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(XTilde_, WTilde_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5, 6>{}, + sequence<7>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_freeze_transform(IdxZTilde_), + make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice), + make_freeze_transform(IdxYTilde_), + 
make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(IdxXTilde_), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}, + sequence<6>{}, + sequence<7>{}), + make_tuple(sequence<0>{}, + sequence<>{}, + sequence<1>{}, + sequence<>{}, + sequence<2>{}, + sequence<>{}, + sequence<3>{}, + sequence<4>{})); + + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2, 3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return make_tuple(out_gemmm_gemmkraw_grid_desc, + wei_gemmn_gemmkraw_grid_desc, + in_gemmmraw_gemmnraw_grid_desc); + } + + IndexType G_, N_; + IndexType Di_, Hi_, Wi_; + IndexType Do_, Ho_, Wo_; + IndexType Z_, Y_, X_; + IndexType K_, C_; + IndexType ConvStrideD_, ConvStrideH_, ConvStrideW_; + IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_; + IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_; + IndexType InRightPadD_, InRightPadH_, InRightPadW_; + IndexType IdxZTilde_, IdxYTilde_, IdxXTilde_; + IndexType GcdStrideDilationD_, GcdStrideDilationH_, GcdStrideDilationW_; + IndexType ZTilde_, YTilde_, XTilde_; + IndexType DTilde_, HTilde_, WTilde_; + IndexType ZDot_, YDot_, XDot_; +}; + +} // namespace ck_tile From 49c6b05c72f50fd41ae452ab46036db4d52b1a79 Mon Sep 17 00:00:00 2001 From: dnovakovic-dxc Date: Wed, 20 Aug 2025 17:22:51 +0200 Subject: [PATCH 14/46] Script for generating list of files not referenced in tests (#2696) * script for generating list of not referenced files in tests, list is in json format * script comment added * added empty line at the end of the script * format changes --- 
...e_list_of_files_not_referenced_in_tests.py | 69 +++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 script/dependency-parser/generate_list_of_files_not_referenced_in_tests.py diff --git a/script/dependency-parser/generate_list_of_files_not_referenced_in_tests.py b/script/dependency-parser/generate_list_of_files_not_referenced_in_tests.py new file mode 100644 index 0000000000..7a15fee128 --- /dev/null +++ b/script/dependency-parser/generate_list_of_files_not_referenced_in_tests.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 + +# This script generate list of files that are not referenced from any test (list in JSON format) +# Script only looks at not referenced files from three directories: include, library and profiler +# CK needs to be built with ability to use dependency parser and generate dependencies + +# Usage: python3 generate_list_of_files_not_referenced_in_tests -f /path/to/enhanced_dependency_mapping/json/file + +import argparse +import subprocess +import json + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-f", + required=True, + help="Path to enhanced_dependency_mapping.json file generated by dependency parser", + ) + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + with open(args.f, "r") as file: + ref_files = json.load(file) + file_to_executables = ref_files["file_to_executables"] + + all_files = ( + subprocess.check_output( + 'find ../../include/ ../../library/ ../../profiler/ -type f -iname "*.cpp" -o -iname "*.hpp"', + shell=True, + ) + .decode("utf-8") + .split("\n") + ) + all_files = all_files[:-1] + all_files[:] = [x[6:] for x in all_files] + + all_referenced_files = [] + for v in file_to_executables: + if ( + "composablekernel/include/" in v + or "composablekernel/library/" in v + or "composablekernel/profiler/" in v + ): + exe_list = file_to_executables[v] + else: + continue + + found = any("bin/test_" in el for el in exe_list) + if found: + 
all_referenced_files.append(v) + + not_referenced_files = {"include": [], "library": [], "profiler": []} + for f in all_files: + found = any(f in el for el in all_referenced_files) + if not found: + pos = f.find("/") + not_referenced_files[f[:pos]].append(f) + + print(json.dumps(not_referenced_files, indent="\t")) + + +if __name__ == "__main__": + main() From 4cfa2c715876fb170bace7d564403b796d5045ba Mon Sep 17 00:00:00 2001 From: Yi DING Date: Fri, 22 Aug 2025 10:01:10 +0800 Subject: [PATCH 15/46] [CK_TILE] FMHA BWD Fix Compilation with Bias (#2682) * [CK_TILE] FMHA BWD Fix Compilation with Bias * Fix appendkv kApplyRoPE --- example/ck_tile/01_fmha/fmha_bwd.cpp | 14 -------- .../fmha/kernel/fmha_fwd_appendkv_kernel.hpp | 8 +++-- ...bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp | 18 +++++----- ...wd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp | 18 +++++----- ...mha_bwd_pipeline_trload_default_policy.hpp | 35 +++---------------- 5 files changed, 28 insertions(+), 65 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index 9c2907778f..9f1e0f6948 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -809,20 +809,6 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::stream_config stream_config_v{ nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")}; - - printf("\nfmha_bwd_traits: hdim_q=%d, hdim_v=%d, data_type=%s, is_group_mode=%d, mask_type=%d, " - "bias_type=%d, has_dbias=%d, has_dropout=%d, is_store_randval=%d, is_deterministic=%d\n", - fmha_traits.hdim_q, - fmha_traits.hdim_v, - fmha_traits.data_type.c_str(), - fmha_traits.is_group_mode, - static_cast(fmha_traits.mask_type), - static_cast(fmha_traits.bias_type), - fmha_traits.has_dbias, - fmha_traits.has_dropout, - fmha_traits.is_store_randval, - fmha_traits.is_deterministic); - fflush(stdout); fmha_bwd(fmha_traits, fmha_args, stream_config_v); dq_buf.FromDevice(dq_host.data()); diff --git 
a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index 81075d0ec6..66f51459af 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -649,8 +649,12 @@ struct FmhaFwdAppendKVKernel {0, i_n0}); // If kApplyRoPe is false, we set the rotary_dim to 0 - auto rotary_dim = kApplyRoPE ? kargs.rotary_dim : 0; - + auto rotary_dim = [&]() { + if constexpr(kApplyRoPE) + return kargs.rotary_dim; + else + return 0; + }(); FmhaPipeline{}(q_dram_window, k_dram_window, i_page_block_k, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp index 1d95bc2801..9a31498dd1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -347,22 +347,19 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = - make_tile_window(Policy::template TransformXDramTensorView( - bias_dram_block_window_tmp.get_bottom_tensor_view()), + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_window_lengths(), {seqlen_q_start, bias_origin.at(number<1>{})}, Policy::template MakeBiasTileDistribution()); auto bias_lds = make_tensor_view( - bias_lds_ptr, Policy::template MakeBiasLdsWriteBlockDescriptor()); + bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor()); auto bias_lds_write_window = make_tile_window(bias_lds, make_tuple(number{}, number{}), {0, 0}); - auto bias_lds_read = make_tensor_view( - bias_lds_ptr, Policy::template MakeBiasLdsReadBlockDescriptor()); auto bias_s_lds_read_window = - make_tile_window(bias_lds_read, - 
make_tuple(number{}, number{}), + make_tile_window(bias_lds_write_window.get_bottom_tensor_view(), + bias_lds_write_window.get_window_lengths(), bias_lds_write_window.get_window_origin(), Policy::template MakeBiasSTileDistribution()); @@ -500,8 +497,11 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - async_load_tile(bias_lds_write_window, bias_dram_window); - __builtin_amdgcn_s_waitcnt(3952); + const auto bias_tile = load_tile(bias_dram_window); + auto shuffled_bias_tile = make_static_distributed_tensor( + Policy::template MakeShuffledBiasTileDistribution()); + shuffle_tile(shuffled_bias_tile, bias_tile); + store_tile(bias_lds_write_window, shuffled_bias_tile); block_sync_lds(); auto bias_s_tile = load_tile(bias_s_lds_read_window); tile_elementwise_inout( diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp index 65f70c4f62..3112070271 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp @@ -323,22 +323,19 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = - make_tile_window(Policy::template TransformXDramTensorView( - bias_dram_block_window_tmp.get_bottom_tensor_view()), + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_window_lengths(), {bias_origin.at(number<0>{}), seqlen_kv_start}, Policy::template MakeBiasTileDistribution()); auto bias_lds = make_tensor_view( - bias_lds_ptr, Policy::template MakeBiasLdsWriteBlockDescriptor()); + bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor()); auto 
bias_lds_write_window = make_tile_window(bias_lds, make_tuple(number{}, number{}), {0, 0}); - auto bias_lds_read = make_tensor_view( - bias_lds_ptr, Policy::template MakeBiasLdsReadBlockDescriptor()); auto bias_s_lds_read_window = - make_tile_window(bias_lds_read, - make_tuple(number{}, number{}), + make_tile_window(bias_lds_write_window.get_bottom_tensor_view(), + bias_lds_write_window.get_window_lengths(), bias_lds_write_window.get_window_origin(), Policy::template MakeBiasSTileDistribution()); @@ -490,8 +487,11 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - async_load_tile(bias_lds_write_window, bias_dram_window); - __builtin_amdgcn_s_waitcnt(3952); + const auto bias_tile = load_tile(bias_dram_window); + auto shuffled_bias_tile = make_static_distributed_tensor( + Policy::template MakeShuffledBiasTileDistribution()); + shuffle_tile(shuffled_bias_tile, bias_tile); + store_tile(bias_lds_write_window, shuffled_bias_tile); block_sync_lds(); auto bias_s_tile = load_tile(bias_s_lds_read_window); tile_elementwise_inout( diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp index 7849c931f7..6259e5b473 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp @@ -551,11 +551,9 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy Problem::BlockFmhaShape::kQKHeaddim>(); } template - CK_TILE_HOST_DEVICE static constexpr auto MakeBiasLdsWriteBlockDescriptor() + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasLdsBlockDescriptor() { - return MakeXLdsWriteBlockDescriptor(); + return BlockFmhaBwdPipelineDefaultPolicy::MakeBiasLdsBlockDescriptor(); } template @@ -684,13 +682,6 @@ struct 
BlockFmhaBwdPipelineTrLoadDefaultPolicy Problem::BlockFmhaShape::kM0, Problem::BlockFmhaShape::kQKHeaddim>(); } - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBiasLdsReadBlockDescriptor() - { - return MakeXLdsReadBlockDescriptor(); - } template CK_TILE_HOST_DEVICE static constexpr auto MakeQRegSliceBlockDescriptor() @@ -966,25 +957,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution() { - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - - constexpr index_t N1 = min(static_cast(GetAlignmentBias()), - kMPerBlock * kNPerBlock / kBlockSize); - constexpr index_t N0 = kNPerBlock / N1; - constexpr index_t M0 = kBlockSize / get_warp_size(); - constexpr index_t M1 = get_warp_size() / N0; - constexpr index_t M2 = kMPerBlock / M1 / M0; - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, - sequence<2, 1>, - sequence<1, 2>>{}); + return BlockFmhaBwdPipelineDefaultPolicy::MakeShuffledBiasTileDistribution(); } template @@ -1048,7 +1021,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy { if constexpr(Problem::BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) return sizeof(typename Problem::BiasDataType) * - MakeBiasLdsWriteBlockDescriptor().get_element_space_size(); + MakeBiasLdsBlockDescriptor().get_element_space_size(); else return 0; } From 4a7ecce096fa9008934b38336bc2ea4f2066a16d Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Fri, 22 Aug 2025 10:13:47 +0800 Subject: [PATCH 16/46] [CK_TILE][FMHA] Enable dwordx4 loading in async_load_tile_raw() (#2549) * Support async load dwordx4 * Enlarge load size on gfx950 --- .../core/arch/amd_buffer_addressing.hpp | 73 ++++++++++++------- .../arch/amd_buffer_addressing_builtins.hpp | 73 ++++++++++++------- 
...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 8 +- 3 files changed, 103 insertions(+), 51 deletions(-) diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 07be65a150..037e86909d 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1276,26 +1276,46 @@ llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, index_t offset, index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds"); -template -CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem, - int32x4_t rsrc, - index_t voffset, - index_t /*soffset*/, - index_t ioffset /*max 0xFFF*/, - index_t /*flag*/ = 0, - bool_constant = {}) +template +CK_TILE_DEVICE void async_buffer_load_dwordxn_v(void* smem, + int32x4_t rsrc, + index_t voffset, + index_t /*soffset*/, + index_t ioffset /*max 0xFFF*/, + index_t /*flag*/ = 0, + bool_constant = {}) { - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "buffer_load_dword %1, %2, 0 offen offset:%3 lds" - : "=r"(smem) /*dummy dependency for smem*/ - : "v"(voffset), "s"(rsrc), "n"(ioffset) +#define CK_TILE_ASYNC_LOAD_WITH_INSTR(instr) \ + if constexpr(pre_nop) \ + asm volatile("s_nop 4\n" instr " %1, %2, 0 offen offset:%3 lds" \ + : "=r"(smem) /*dummy dependency for smem*/ \ + : "v"(voffset), "s"(rsrc), "n"(ioffset) \ + : "memory"); \ + else \ + asm volatile(instr " %1, %2, 0 offen offset:%3 lds" \ + : "=r"(smem) /*dummy dependency for smem*/ \ + : "v"(voffset), "s"(rsrc), "n"(ioffset) \ : "memory"); + + if constexpr(num_dwords == 1) + { + CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dword"); + } +#if defined(__gfx950__) + else if constexpr(num_dwords == 3) + { + CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dwordx3"); + } + else if constexpr(num_dwords == 4) + { + CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dwordx4"); + } +#endif else - asm volatile("buffer_load_dword %1, %2, 0 offen offset:%3 lds" - : "=r"(smem) /*dummy dependency for smem*/ - : 
"v"(voffset), "s"(rsrc), "n"(ioffset) - : "memory"); + { + static_assert(false, "wrong! not implemented data width"); + } +#undef CK_TILE_ASYNC_LOAD_WITH_INSTR } CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0) @@ -1766,15 +1786,18 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(CK_TILE_LDS_ADDR T* smem, index_t src_immediate_addr_offset = 0, bool_constant = {}) { - static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size"); + constexpr index_t num_bytes = sizeof(T) * N; + constexpr index_t num_words = num_bytes / 4; + static_assert(num_bytes % 4 == 0 && (num_words == 1 || num_words == 3 || num_words == 4), + "wrong! only support in dword, dwordx3, dwordx4"); - async_buffer_load_dword_v(smem, - src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - 0, - bool_constant{}); + async_buffer_load_dwordxn_v(smem, + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + src_immediate_addr_offset, + 0, + bool_constant{}); } template -CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem, - int32x4_t rsrc, - index_t voffset, - index_t /*soffset*/, - index_t ioffset /*max 0xFFF*/, - index_t /*flag*/ = 0, - bool_constant = {}) +template +CK_TILE_DEVICE void async_buffer_load_dwordxn_v(void* smem, + int32x4_t rsrc, + index_t voffset, + index_t /*soffset*/, + index_t ioffset /*max 0xFFF*/, + index_t /*flag*/ = 0, + bool_constant = {}) { - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "buffer_load_dword %1, %2, 0 offen offset:%3 lds" - : "=r"(smem) /*dummy dependency for smem*/ - : "v"(voffset), "s"(rsrc), "n"(ioffset) +#define CK_TILE_ASYNC_LOAD_WITH_INSTR(instr) \ + if constexpr(pre_nop) \ + asm volatile("s_nop 4\n" instr " %1, %2, 0 offen offset:%3 lds" \ + : "=r"(smem) /*dummy dependency for smem*/ \ + : "v"(voffset), "s"(rsrc), "n"(ioffset) \ + : "memory"); \ + else \ + asm volatile(instr " %1, %2, 0 offen offset:%3 lds" \ + : "=r"(smem) /*dummy dependency for smem*/ 
\ + : "v"(voffset), "s"(rsrc), "n"(ioffset) \ : "memory"); + + if constexpr(num_dwords == 1) + { + CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dword"); + } +#if defined(__gfx950__) + else if constexpr(num_dwords == 3) + { + CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dwordx3"); + } + else if constexpr(num_dwords == 4) + { + CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dwordx4"); + } +#endif else - asm volatile("buffer_load_dword %1, %2, 0 offen offset:%3 lds" - : "=r"(smem) /*dummy dependency for smem*/ - : "v"(voffset), "s"(rsrc), "n"(ioffset) - : "memory"); + { + static_assert(false, "wrong! not implemented data width"); + } +#undef CK_TILE_ASYNC_LOAD_WITH_INSTR } CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0) @@ -1536,15 +1556,18 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem, index_t src_immediate_addr_offset = 0, bool_constant = {}) { - static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size"); + constexpr index_t num_bytes = sizeof(T) * N; + constexpr index_t num_words = num_bytes / 4; + static_assert(num_bytes % 4 == 0 && (num_words == 1 || num_words == 3 || num_words == 4), + "wrong! 
only support in dword, dwordx3, dwordx4"); - async_buffer_load_dword_v(smem, - src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - 0, - bool_constant{}); + async_buffer_load_dwordxn_v(smem, + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + src_immediate_addr_offset, + 0, + bool_constant{}); } template ; if constexpr(AsyncCopy) { - return 4 / sizeof(KDataType); +#if defined(__gfx950__) + constexpr index_t MaxLoadSizeInBytes = 4 * 4; // dwordx4 +#else + constexpr index_t MaxLoadSizeInBytes = 4; // dword +#endif + + return MaxLoadSizeInBytes / sizeof(KDataType); } else { From 0db21053e68817a50b0ed0ceea87e88228ab2475 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Fri, 22 Aug 2025 10:17:05 +0800 Subject: [PATCH 17/46] [CK_TILE] Allow switching between SGPR/VGPR get_warp_id() return values (#2669) * Allow return VGPR get_warp_id() value * Avoid using SALU in async_load_raw() --- include/ck_tile/core/arch/arch.hpp | 13 +++++++++++-- include/ck_tile/core/tensor/tile_window.hpp | 7 +++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 234929d6e6..42f2390cde 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -98,9 +98,18 @@ CK_TILE_DEVICE index_t get_block_1d_id() { return blockIdx.x; } // Use these instead CK_TILE_DEVICE index_t get_lane_id() { return __lane_id(); } -CK_TILE_DEVICE index_t get_warp_id() +template +CK_TILE_DEVICE index_t get_warp_id(bool_constant = {}) { - return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size()); + const index_t warp_id = threadIdx.x / get_warp_size(); + if constexpr(ReturnSgpr) + { + return __builtin_amdgcn_readfirstlane(warp_id); + } + else + { + return warp_id; + } } CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; } diff --git a/include/ck_tile/core/tensor/tile_window.hpp 
b/include/ck_tile/core/tensor/tile_window.hpp index ad5902f16e..f5ddcd278c 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -288,8 +288,11 @@ struct tile_window_with_static_distribution sizeof(LdsDataType) - size_per_buf; - const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); - m0_set_with_memory(m0_init_value); // This should be wave independent + // Use VALU so the compiler can optimize redundant/repeated computations + const index_t m0_init_value = + size_per_buf + size_per_wave * get_warp_id(/*ReturnSgpr=*/bool_constant{}); + m0_set_with_memory( + __builtin_amdgcn_readfirstlane(m0_init_value)); // This should be wave independent using Traits = typename Base::Traits; From d6e49c5fdec1eedf9c6e6dbd59e7f788c2e2fc2e Mon Sep 17 00:00:00 2001 From: linqunAMD Date: Sat, 23 Aug 2025 05:46:30 +0800 Subject: [PATCH 18/46] Extend XDL kernel to Support RDNA3/4 - Part 1 (#2606) --- include/ck/host_utility/device_prop.hpp | 37 +- .../blockwise_gemm_pipeline_xdlops_base.hpp | 9 +- .../blockwise_gemm_pipeline_xdlops_v2.hpp | 8 +- ...kwise_gemm_pipeline_xdlops_v2_ab_scale.hpp | 5 +- ...ckwise_gemm_pipeline_xdlops_v2_b_scale.hpp | 8 +- .../blockwise_gemm_pipeline_xdlops_v3.hpp | 4 +- .../impl/device_gemm_xdl_cshuffle_v3.hpp | 155 ++++++- .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 161 ++++++-- .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 386 ++++++++++++++++-- include/ck/utility/blkgemmpipe_scheduler.hpp | 2 +- include/ck/utility/get_id.hpp | 35 +- 11 files changed, 683 insertions(+), 127 deletions(-) diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 5439bbe1f0..2bc5a4414e 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -52,10 +52,27 @@ inline std::string get_device_name() } } +inline bool is_gfx12_supported() +{ + return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201"; +} + +inline bool is_gfx11_supported() +{ + return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || + ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103" || + ck::get_device_name() == "gfx1150" || ck::get_device_name() == "gfx1151" || + ck::get_device_name() == "gfx1152"; +} + inline bool is_xdl_supported() { return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || - ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"; + ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950" +#if defined(CK_ENABLE_DYNAMIC_WARP_SIZE) + || is_gfx12_supported() || is_gfx11_supported() +#endif + ; } inline bool is_lds_direct_load_supported() @@ -67,7 +84,8 @@ inline bool is_lds_direct_load_supported() inline bool is_bf16_atomic_supported() { - return ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"; + return ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950" || + is_gfx12_supported(); } inline bool is_gfx101_supported() @@ -83,18 +101,5 @@ inline bool is_gfx103_supported() ck::get_device_name() == "gfx1035" || ck::get_device_name() == "gfx1036"; } -inline bool is_gfx11_supported() -{ - return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || - ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103" || - ck::get_device_name() == "gfx1150" || ck::get_device_name() == "gfx1151" || - ck::get_device_name() == "gfx1152"; -} - -inline bool is_gfx12_supported() -{ - return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201"; -} - } // namespace ck #endif diff --git 
a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp index cd13dbb836..acd1d2ae49 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp @@ -41,7 +41,9 @@ struct BlockwiseGemmXdlops_pipeline_base using ThisThreadBlock = ThisThreadBlock; // Hardcode to 64, as HIP-provided "WarpSize" would return 32 on RDNA GPUs. - static constexpr index_t WaveSize = 64; + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + static constexpr index_t WaveSize = BlockSize / MWaves / NWaves; static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0); @@ -74,9 +76,6 @@ struct BlockwiseGemmXdlops_pipeline_base return 1; }(); - static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); - static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); - using HotLoopInstList = ck::BlockwiseGemmXdlops_pipeline_hotloop_inst= 1 ? 4 * WarpSize / BlockSize : 1; + (4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1; static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( 32768 / WgpPerCU, (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); @@ -625,13 +626,14 @@ struct BlockwiseGemmXdlops_pipeline_v2= 1 ? 4 * WarpSize / BlockSize : 1; + (4 * WaveSize / BlockSize) >= 1 ? 
4 * WaveSize / BlockSize : 1; static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( 32768 / WgpPerCU, (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp index 0c030030fe..119f8a3306 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -141,9 +141,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale= 1 ? 4 * WarpSize / BlockSize : 1; + (4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1; static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( 32768 / WgpPerCU, (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp index 69002d7962..80c65515e8 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -139,9 +139,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale= 1 ? 4 * WarpSize / BlockSize : 1; + (4 * WaveSize / BlockSize) >= 1 ? 
4 * WaveSize / BlockSize : 1; static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( 32768 / WgpPerCU, (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); @@ -626,13 +627,14 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale= 1 ? 4 * WarpSize / BlockSize : 1; + (4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1; static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( 32768 / WgpPerCU, (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp index b5d6180ab3..7203348418 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -159,6 +159,7 @@ struct BlockwiseGemmXdlops_pipeline_v3 { + template + static constexpr auto GetNXdlPerWave() + { + constexpr index_t Waves = isWave64 ? 
BlockSize / 64 : BlockSize / 32; + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXDL); + static_assert(MWaves > 0); + + constexpr index_t NWaves = Waves / MWaves; + if constexpr(NWaves == 0) + { + return 0; + } + else + { + if constexpr(NPerBlock % (NPerXDL * NWaves) == 0) + { + return NPerBlock / (NWaves * NPerXDL); + } + else + { + return 0; + } + } + } // GridwiseGemm - using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3< + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + + template + using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, @@ -199,7 +227,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; static constexpr index_t APackedSize = []() { if constexpr(is_same_v, pk_i4_t>) @@ -254,12 +284,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -285,7 +312,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 rotating_mem( + ck::utility::RotatingMemWrapper rotating_mem( arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); rotating_mem.Print(); @@ -733,6 +760,31 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 0) + { + return RunImp(arg, stream_config); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return RunImp( + reinterpret_cast(arg), + stream_config); + } + } + return 0; + } // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -754,9 +806,39 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 && arg.KBatch > 1) + if(arg.KBatch > 1) { - return false; + if(is_gfx11_supported()) + { + return 
false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v) + { + return false; + } + + if(sizeof(CDataType) == 1) + { + return false; + } + } + + if(is_gfx11_supported() || is_gfx12_supported()) + { + if(MPerXDL != 16 || NPerXDL != 16) + { + return false; + } + } + + if(is_gfx11_supported()) + { + if constexpr(std::is_same_v || + std::is_same_v) + { + return false; + } } if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || @@ -767,7 +849,29 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + else + { + return false; + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + else + { + return false; + } + } } // polymorphic @@ -849,6 +953,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 0) + { + PrefetchStages = GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages; + AMmaKStride = GridwiseGemm64::BlockwiseGemmPipe::AMmaKStride; + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + PrefetchStages = GridwiseGemm32::BlockwiseGemmPipe::PrefetchStages; + AMmaKStride = GridwiseGemm32::BlockwiseGemmPipe::AMmaKStride; + } + } + // clang-format off str << "DeviceGemmXdlUniversal" << "<" @@ -872,9 +995,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2=32 && PackedSize != 2 on pre-gfx950 - if constexpr(static_cast(Arch::is_gfx950_build) || - (GridwiseGemm::AK1Number < 32 && GridwiseGemm::BK1Number < 32) || - (GridwiseGemm::AK1Number >= 32 && GridwiseGemm::APackedSize == 2) || - (GridwiseGemm::BK1Number >= 32 && GridwiseGemm::BPackedSize == 2)) +#if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -78,23 +66,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) 
kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - enum struct Arch : bool +#if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependency + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { -#if defined(__gfx950__) - is_gfx950_build = true, -#else - is_gfx950_build = false, -#endif - }; - // skip building the instances with K1>=32 && PackedSize != 2 on pre-gfx950 - if constexpr(static_cast(Arch::is_gfx950_build) || - (GridwiseGemm::AK1Number < 32 && GridwiseGemm::BK1Number < 32) || - (GridwiseGemm::AK1Number >= 32 && GridwiseGemm::APackedSize == 2) || - (GridwiseGemm::BK1Number >= 32 && GridwiseGemm::BPackedSize == 2)) - { - // Pass two lds pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -696,12 +672,23 @@ struct GridwiseGemm_xdl_cshuffle_v3 __host__ void Print() const { - std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " - << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC - << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 - << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " + // clang-format off + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << 
"MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; + // clang-format on } index_t M; @@ -831,6 +818,10 @@ struct GridwiseGemm_xdl_cshuffle_v3 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWaves = (NXdlPerWave * NPerXdl == 0) ? 0 : NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = (MWaves * NWaves == 0) ? 64 : BlockSize / (MWaves * NWaves); + // A matrix in LDS memory, dst of blockwise copy if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { @@ -888,7 +879,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto KThreadRead = WaveSize / MPerXdl; constexpr auto K0PerThreadRead = AK0Number / KThreadRead; constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) @@ -969,6 +960,9 @@ struct GridwiseGemm_xdl_cshuffle_v3 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWaves = (NXdlPerWave * NPerXdl == 0) ? 0 : NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = (MWaves * NWaves == 0) ?
64 : BlockSize / (MWaves * NWaves); // B matrix in LDS memory, dst of blockwise copy if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { @@ -1022,7 +1016,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; - constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto KThreadRead = WaveSize / NPerXdl; constexpr auto K0PerThreadRead = BK0Number / KThreadRead; constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) @@ -1169,12 +1163,99 @@ struct GridwiseGemm_xdl_cshuffle_v3 c_block_size * sizeof(CShuffleDataType)); } + template + __device__ static bool constexpr IsValidCompilationParameter() + { + enum struct Arch : bool + { +#if defined(__gfx950__) + is_gfx950_build = true, +#else + is_gfx950_build = false, +#endif + }; + + // skip building the instances with K1>=32 && PackedSize != 2 on pre-gfx950 + if constexpr(static_cast(Arch::is_gfx950_build) || + (AK1Number < 32 && BK1Number < 32) || + (AK1Number >= 32 && APackedSize == 2) || + (BK1Number >= 32 && BPackedSize == 2)) + { + + } + else + { + return false; + } + + // Check tile size +#if defined(__gfx11__) || defined(__gfx12__) + if constexpr(MPerXdl != 16 || NPerXdl != 16) + { + return false; + } +#endif + // Check atomic caps +#if defined(__gfx11__) + constexpr bool SupportMemOp = CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set; +#else + constexpr bool SupportMemOp = sizeof(CDataType) >= 2 || (CGlobalMemoryDataOperation == + InMemoryDataOperationEnum::Set); +#endif + if constexpr(SupportMemOp == false) + { + return false; + } + + // Check tile size + if constexpr(MXdlPerWave > 0 && NXdlPerWave > 0) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + if constexpr(MWaves > 0 && NWaves > 0) + { + constexpr index_t WaveSize = BlockSize / (MWaves 
* NWaves); + if constexpr(WaveSize == get_warp_size()) + { + return true; + } + else + { + return false; + } + } + else + { + return false; + } + } + else + { + return false; + } + } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { - static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && - (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, - "Invalid tuning param!"); + if constexpr((MPerXdl * MXdlPerWave) == 0 || (NXdlPerWave * NPerXdl) == 0) + { + return false; + } + else + { + if constexpr((MPerBlock % (MPerXdl * MXdlPerWave) != 0) || + (NPerBlock % (NXdlPerWave * NPerXdl) != 0)) + { + return false; + } + else + { + if(BlockwiseGemmPipe::WaveSize != get_warp_size()) + { + return false; + } + } + } if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 64d7f92750..2ce08e7044 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -6,6 +6,7 @@ #include "ck/utility/common_header.hpp" #include "ck/utility/math.hpp" #include "ck/utility/amd_xdlops.hpp" +#include "ck/utility/amd_wmma.hpp" namespace ck { /** @@ -76,7 +77,21 @@ enum struct MfmaInstr mfma_f32_32x32x64f8f6f4, mfma_f32_16x16x128f8f6f4, mfma_scale_f32_32x32x64f8f6f4, - mfma_scale_f32_16x16x128f8f6f4 + mfma_scale_f32_16x16x128f8f6f4, + // gfx11 + wmma_f32_16x16x16_f16, + wmma_f32_16x16x16_bf16, + wmma_i32_16x16x16_iu8, + wmma_unsupport_16x16_gfx11, + // gfx12 + wmma_f32_16x16x16_f16_gfx12, + wmma_f32_16x16x16_bf16_gfx12, + wmma_i32_16x16x16_iu8_gfx12, + wmma_f32_16x16x16_f8f8_gfx12, + wmma_f32_16x16x16_f8bf8_gfx12, + wmma_f32_16x16x16_bf8f8_gfx12, + wmma_f32_16x16x16_bf8bf8_gfx12, + wmma_unsupport_16x16_gfx12, }; template 
@@ -932,6 +947,175 @@ struct mfma_type } }; +// gfx11 +struct mfma_type_gfx11_base +{ + static constexpr index_t group_size = 8; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 8; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 32; + static constexpr index_t num_input_blks = 1; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 16; + static constexpr bool is_k_reduction = true; +}; + +template <> +struct mfma_type : public mfma_type_gfx11_base +{ + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_wmma_f32_16x16x16_f16_w32::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type : public mfma_type_gfx11_base +{ + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_wmma_f32_16x16x16_bf16_w32::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type : public mfma_type_gfx11_base +{ + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_wmma_i32_16x16x16_iu8_w32::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type : public mfma_type_gfx11_base +{ + static constexpr index_t k_per_blk = 2; + template + __device__ void run(const FloatA&, const FloatB&, FloatC&) const + { + // empty for all unsupported types. 
+ } +}; + +// gfx12 +struct mfma_type_gfx12_base +{ + static constexpr index_t group_size = 8; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 8; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 32; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; +}; + +template <> +struct mfma_type : public mfma_type_gfx12_base +{ + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_wmma_f32_16x16x16_f16_w32_gfx12::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type : public mfma_type_gfx12_base +{ + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_wmma_f32_16x16x16_bf16_w32_gfx12::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type : public mfma_type_gfx12_base +{ + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_wmma_i32_16x16x16_iu8_w32_gfx12::Run( + a, b, reg_c); + } +}; + +template <> +struct mfma_type : public mfma_type_gfx12_base +{ + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_wmma_f32_16x16x16_f8f8_w32_gfx12::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type : public mfma_type_gfx12_base +{ + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_wmma_f32_16x16x16_f8bf8_w32_gfx12::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type : public mfma_type_gfx12_base +{ + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_wmma_f32_16x16x16_bf8f8_w32_gfx12::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type : public mfma_type_gfx12_base +{ + 
template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_wmma_f32_16x16x16_bf8bf8_w32_gfx12::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type : public mfma_type_gfx12_base +{ + static constexpr index_t k_per_blk = 2; + template + __device__ void run(const FloatA&, const FloatB&, FloatC&) const + { + // empty for all unsupported types. + } +}; + template constexpr auto GetMfma() { +#if defined(__gfx12__) + return MfmaInstr::wmma_unsupport_16x16_gfx12; +#elif defined(__gfx11__) + return MfmaInstr::wmma_unsupport_16x16_gfx11; +#else return MfmaInstr::mfma_f64_16x16x4f64; +#endif } template <> @@ -993,7 +1183,13 @@ struct MfmaSelector template <> constexpr auto GetMfma() { +#if defined(__gfx12__) + return MfmaInstr::wmma_unsupport_16x16_gfx12; +#elif defined(__gfx11__) + return MfmaInstr::wmma_unsupport_16x16_gfx11; +#else return MfmaInstr::mfma_f32_16x16x4xf32; +#endif } template <> @@ -1026,7 +1222,11 @@ struct MfmaSelector template <> constexpr auto GetMfma() { -#if defined(__gfx950__) +#if defined(__gfx12__) + return MfmaInstr::wmma_f32_16x16x16_f16_gfx12; +#elif defined(__gfx11__) + return MfmaInstr::wmma_f32_16x16x16_f16; +#elif defined(__gfx950__) return MfmaInstr::mfma_f32_16x16x32f16; #else return MfmaInstr::mfma_f32_16x16x16f16; @@ -1036,7 +1236,13 @@ struct MfmaSelector template <> constexpr auto GetMfma() { +#if defined(__gfx12__) + return MfmaInstr::wmma_f32_16x16x16_f16_gfx12; +#elif defined(__gfx11__) + return MfmaInstr::wmma_f32_16x16x16_f16; +#else return MfmaInstr::mfma_f32_16x16x16f16; +#endif } template <> @@ -1082,7 +1288,11 @@ struct MfmaSelector template <> constexpr auto GetMfma() { -#if defined(__gfx950__) +#if defined(__gfx12__) + return MfmaInstr::wmma_f32_16x16x16_bf16_gfx12; +#elif defined(__gfx11__) + return MfmaInstr::wmma_f32_16x16x16_bf16; +#elif defined(__gfx950__) return MfmaInstr::mfma_f32_16x16x32bf16; #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP) return 
MfmaInstr::mfma_f32_16x16x16bf16_1k; @@ -1094,7 +1304,11 @@ struct MfmaSelector template <> constexpr auto GetMfma() { -#if defined(CK_USE_AMD_MFMA_BF16_1K_OP) +#if defined(__gfx12__) + return MfmaInstr::wmma_f32_16x16x16_bf16_gfx12; +#elif defined(__gfx11__) + return MfmaInstr::wmma_f32_16x16x16_bf16; +#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP) return MfmaInstr::mfma_f32_16x16x16bf16_1k; #else return MfmaInstr::mfma_f32_16x16x8bf16; @@ -1126,7 +1340,11 @@ struct MfmaSelector template <> constexpr auto GetMfma() { -#if defined(__gfx950__) +#if defined(__gfx12__) + return MfmaInstr::wmma_i32_16x16x16_iu8_gfx12; +#elif defined(__gfx11__) + return MfmaInstr::wmma_i32_16x16x16_iu8; +#elif defined(__gfx950__) return MfmaInstr::mfma_i32_16x16x64i8; #elif defined(__gfx942__) return MfmaInstr::mfma_i32_16x16x32i8; @@ -1138,7 +1356,11 @@ struct MfmaSelector template <> constexpr auto GetMfma() { -#if defined(__gfx942__) || defined(__gfx950__) +#if defined(__gfx12__) + return MfmaInstr::wmma_i32_16x16x16_iu8_gfx12; +#elif defined(__gfx11__) + return MfmaInstr::wmma_i32_16x16x16_iu8; +#elif defined(__gfx942__) || defined(__gfx950__) return MfmaInstr::mfma_i32_16x16x32i8; #else return MfmaInstr::mfma_i32_16x16x16i8; @@ -1186,13 +1408,23 @@ struct MfmaSelector template <> constexpr auto GetMfma() { +#if defined(__gfx12__) + return MfmaInstr::wmma_f32_16x16x16_f8f8_gfx12; +#elif defined(__gfx11__) + return MfmaInstr::wmma_unsupport_16x16_gfx11; +#else return MfmaInstr::mfma_f32_16x16x32f8f8; +#endif } template <> constexpr auto GetMfma() { -#if defined(__gfx950__) +#if defined(__gfx12__) + return MfmaInstr::wmma_f32_16x16x16_f8f8_gfx12; +#elif defined(__gfx11__) + return MfmaInstr::wmma_unsupport_16x16_gfx11; +#elif defined(__gfx950__) return MfmaInstr::mfma_f32_16x16x128f8f6f4; #else return MfmaInstr::mfma_f32_16x16x32f8f8; @@ -1263,13 +1495,23 @@ struct MfmaSelector template <> constexpr auto GetMfma() { +#if defined(__gfx12__) + return 
MfmaInstr::wmma_f32_16x16x16_bf8bf8_gfx12; +#elif defined(__gfx11__) + return MfmaInstr::wmma_unsupport_16x16_gfx11; +#else return MfmaInstr::mfma_f32_16x16x32bf8bf8; +#endif } template <> constexpr auto GetMfma() { -#if defined(__gfx950__) +#if defined(__gfx12__) + return MfmaInstr::wmma_f32_16x16x16_bf8bf8_gfx12; +#elif defined(__gfx11__) + return MfmaInstr::wmma_unsupport_16x16_gfx11; +#elif defined(__gfx950__) return MfmaInstr::mfma_f32_16x16x128f8f6f4; #else return MfmaInstr::mfma_f32_16x16x32bf8bf8; @@ -1295,13 +1537,23 @@ struct MfmaSelector template <> constexpr auto GetMfma() { +#if defined(__gfx12__) + return MfmaInstr::wmma_f32_16x16x16_f8bf8_gfx12; +#elif defined(__gfx11__) + return MfmaInstr::wmma_unsupport_16x16_gfx11; +#else return MfmaInstr::mfma_f32_16x16x32f8bf8; +#endif } template <> constexpr auto GetMfma() { -#if defined(__gfx950__) +#if defined(__gfx12__) + return MfmaInstr::wmma_f32_16x16x16_f8bf8_gfx12; +#elif defined(__gfx11__) + return MfmaInstr::wmma_unsupport_16x16_gfx11; +#elif defined(__gfx950__) return MfmaInstr::mfma_f32_16x16x128f8f6f4; #else return MfmaInstr::mfma_f32_16x16x32f8bf8; @@ -1327,13 +1579,23 @@ struct MfmaSelector template <> constexpr auto GetMfma() { +#if defined(__gfx12__) + return MfmaInstr::wmma_f32_16x16x16_bf8f8_gfx12; +#elif defined(__gfx11__) + return MfmaInstr::wmma_unsupport_16x16_gfx11; +#else return MfmaInstr::mfma_f32_16x16x32bf8f8; +#endif } template <> constexpr auto GetMfma() { -#if defined(__gfx950__) +#if defined(__gfx12__) + return MfmaInstr::wmma_f32_16x16x16_bf8f8_gfx12; +#elif defined(__gfx11__) + return MfmaInstr::wmma_unsupport_16x16_gfx11; +#elif defined(__gfx950__) return MfmaInstr::mfma_f32_16x16x128f8f6f4; #else return MfmaInstr::mfma_f32_16x16x32bf8f8; @@ -1355,10 +1617,18 @@ struct MfmaSelector static_assert(selected_mfma.num_threads_per_blk == selected_mfma.n_per_blk, "n_per_blk != num_threads_per_blk"); - +#if defined(__gfx11__) + if constexpr(MPerXdlops == 16 && NPerXdlops == 16) + { + 
static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks * 2 == + selected_mfma.m_per_blk, + "m_per_blk != num_input_blks * num_regs_per_blk"); + } +#else static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks == selected_mfma.m_per_blk, "m_per_blk != num_input_blks * num_regs_per_blk"); +#endif static_assert(selected_mfma.num_output_blks == selected_mfma.num_input_blks || selected_mfma.num_output_blks == 1, @@ -1424,8 +1694,9 @@ struct XdlopsGemm static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 || MPerXdlops == 64, "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops"); - +#if defined(__HIP_DEVICE_COMPILE__) static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack should be a multiple of k_per_blk"); +#endif } // XDL output supporting C = A * B @@ -1434,10 +1705,11 @@ struct XdlopsGemm __host__ __device__ static constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2) { - const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); - const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); - const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); - const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk; return transform_tensor_descriptor( c_desc_m0_n0_m1_n1_m2_n2, @@ -1446,7 +1718,7 @@ struct XdlopsGemm make_pass_through_transform(M1), make_pass_through_transform(N1), make_unmerge_transform(make_tuple(Number{}, - Number{}, + Number{}, Number{})), make_pass_through_transform(Number{})), make_tuple(Sequence<0>{}, @@ -1469,12 +1741,13 @@ struct XdlopsGemm __host__ __device__ static constexpr auto 
MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3( const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2) { - const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); - const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); - const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); - const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); - const auto M2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I4); - const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I5); + const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + const auto M2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I4); + const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I5); + constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk; return transform_tensor_descriptor( c_desc_m0_n0_m1_n1_m2_n2, @@ -1485,7 +1758,7 @@ struct XdlopsGemm make_pass_through_transform(M2), make_pass_through_transform(N2), make_unmerge_transform(make_tuple(Number{}, - Number{}, + Number{}, Number{})), make_pass_through_transform(Number{})), make_tuple(Sequence<0>{}, @@ -1512,10 +1785,11 @@ struct XdlopsGemm __host__ __device__ static constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2) { - const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); - const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); - const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); - const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk; return transform_tensor_descriptor( c_desc_m0_n0_m1_n1_m2_n2, @@ -1525,7 
+1799,7 @@ struct XdlopsGemm make_pass_through_transform(N1), make_pass_through_transform(Number{}), make_unmerge_transform(make_tuple(Number{}, - Number{}, + Number{}, Number{}))), make_tuple(Sequence<0>{}, Sequence<1>{}, @@ -1545,11 +1819,12 @@ struct XdlopsGemm __host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2) { - const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0); - const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1); - const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2); - const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3); - const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4); + const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3); + const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4); + constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk; return transform_tensor_descriptor( c_desc_g_m0_n0_m1_n1_m2_n2, @@ -1558,9 +1833,8 @@ struct XdlopsGemm make_pass_through_transform(N0), make_pass_through_transform(M1), make_pass_through_transform(N1), - make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk, - mfma_instr.num_input_blks, - mfma_instr.group_size)), + make_unmerge_transform(make_tuple( + mfma_instr.num_groups_per_blk, num_blks, mfma_instr.group_size)), make_pass_through_transform(mfma_instr.num_threads_per_blk)), make_tuple(Sequence<0>{}, Sequence<1>{}, @@ -1642,8 +1916,32 @@ struct XdlopsGemm __device__ static auto GetBlkIdx() { - const auto laneId = GetLaneId(); + const auto laneId = GetLaneId(); + constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk; + constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple( + make_merge_transform(make_tuple(1, num_blks, 
mfma_instr.num_threads_per_blk))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto blk_idx = + threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId)); + + const auto blk_id = blk_idx[I1]; + const auto blk_td = blk_idx[I2]; + + return make_tuple(blk_id, blk_td); + } + + template + __device__ static auto GetGfx11InputBlkIdx() + { + const auto laneId = GetLaneId() % mfma_instr.num_threads_per_blk; + if constexpr(SwizzleA) + { + laneId = ((laneId & 1) << 3) | (laneId >> 1); + } constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor( make_tuple(make_merge_transform( make_tuple(1, mfma_instr.num_input_blks, mfma_instr.num_threads_per_blk))), @@ -1661,8 +1959,12 @@ struct XdlopsGemm __host__ __device__ static auto CalculateAThreadOriginDataIndex() { - const auto laneId = GetLaneId(); + const auto laneId = GetLaneId(); +#if defined(__gfx11__) + const auto blk_idx = GetGfx11InputBlkIdx(); +#else const auto blk_idx = GetBlkIdx(); +#endif const auto blk_id = blk_idx[I0]; const auto blk_td = blk_idx[I1]; @@ -1679,8 +1981,12 @@ struct XdlopsGemm __host__ __device__ static auto CalculateBThreadOriginDataIndex() { - const auto laneId = GetLaneId(); + const auto laneId = GetLaneId(); +#if defined(__gfx11__) + const auto blk_idx = GetGfx11InputBlkIdx(); +#else const auto blk_idx = GetBlkIdx(); +#endif const auto blk_id = blk_idx[I0]; const auto blk_td = blk_idx[I1]; diff --git a/include/ck/utility/blkgemmpipe_scheduler.hpp b/include/ck/utility/blkgemmpipe_scheduler.hpp index 861b81b1f6..63466a36f2 100644 --- a/include/ck/utility/blkgemmpipe_scheduler.hpp +++ b/include/ck/utility/blkgemmpipe_scheduler.hpp @@ -75,9 +75,9 @@ template struct BlockwiseGemmXdlops_pipeline_hotloop_inst { - static constexpr index_t WaveSize = 64; static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL); static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL); + static constexpr index_t WaveSize = BlockSize / 
WaveNumM / WaveNumN; static constexpr index_t A_LDS_Read_Width = ALDSReadWidth; static constexpr index_t B_LDS_Read_Width = BLDSReadWidth; diff --git a/include/ck/utility/get_id.hpp b/include/ck/utility/get_id.hpp index fd0d1024b2..53e865767b 100644 --- a/include/ck/utility/get_id.hpp +++ b/include/ck/utility/get_id.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,6 +7,38 @@ namespace ck { +#if defined(CK_ENABLE_DYNAMIC_WARP_SIZE) +__device__ constexpr index_t get_warp_size() +{ +#if defined(__HIP_DEVICE_COMPILE__) +#if defined(__GFX9__) + return 64; +#else + return 32; +#endif +#else + return 64; +#endif +} + +inline __host__ index_t get_warp_size() +{ +#if !(defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC)) + int device = 0; + int result = 0; + auto status = hipGetDevice(&device); + if(status == hipSuccess) + { + status = hipDeviceGetAttribute(&result, hipDeviceAttributeWarpSize, device); + if(status == hipSuccess) + { + return result; + } + } +#endif + return 64; +} +#else __host__ __device__ constexpr index_t get_warp_size() { #if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__) @@ -15,6 +47,7 @@ __host__ __device__ constexpr index_t get_warp_size() return 32; #endif } +#endif __device__ index_t get_thread_local_1d_id() { return threadIdx.x; } From c71d7ddd7473b1c952f961e29b09f4a61f0a87d5 Mon Sep 17 00:00:00 2001 From: John Shumway Date: Sun, 24 Aug 2025 21:29:23 -0700 Subject: [PATCH 19/46] Remove unsupported use of c++20 concept. (#2719) Downstream libraries aren't migrated to c++20 yet, so replace a use of c++20 concept with equivalent SFINAE logic. The template checks for both the existence and the truthiness of the static member variable. 
--- .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 2 +- ...wd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp | 22 +++++++++++++++++-- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 8750c8b377..5e16fc563b 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -40,7 +40,7 @@ struct FmhaBwdDQDKDVKernel static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; static constexpr bool kUseQrQtrDorPipeline = - ck_tile::fmha_bwd_qr_qtr_dor_pipeline_c; + ck_tile::fmha_bwd_qr_qtr_dor_pipeline::value; static_assert(!kUseQrQtrDorPipeline || !std::is_same_v, "QrQtrDorPipeline needs QGradEpiloguePipeline"); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp index 3112070271..789cfb3ea4 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp @@ -738,6 +738,24 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR } }; -template -concept fmha_bwd_qr_qtr_dor_pipeline_c = T::is_qr_qtr_dor_pipeline; +// We don't support C++20 concepts yet, so we use SFINAE check the existence and truthiness +// of is_qr_qtr_dor_pipeline static member instead of using concepts directly. +// +// The template struct's value field is equivalent to the following commented concept definition. +// +// template +// concept fmha_bwd_qr_qtr_dor_pipeline_c = T::is_qr_qtr_dor_pipeline; + +// SFINAE test for existence and truthiness of static member is_qr_qtr_dor_pipeline. 
+template +struct fmha_bwd_qr_qtr_dor_pipeline : std::false_type +{ +}; + +template +struct fmha_bwd_qr_qtr_dor_pipeline> + : std::bool_constant +{ +}; + } // namespace ck_tile From de61e554938265a5d17a1bba8c148457125e80cd Mon Sep 17 00:00:00 2001 From: Yi DING Date: Mon, 25 Aug 2025 20:55:12 +0800 Subject: [PATCH 20/46] [CK_TILE] FMHA avoid unnecessary vmcnt0 (#2715) * FMHA avoid unnecessary vmcnt0 Squashed commit of the following: commit 7bdf6a7eef84d254cdcea1af01402307c566e6fe Author: aska-0096 Date: Fri Aug 22 03:15:51 2025 +0000 merge develop and solve conflicts commit f21e916a8c430de660abf480d54cefc80255c268 Merge: a7dd2a7d1 0db21053e Author: aska-0096 Date: Fri Aug 22 03:15:21 2025 +0000 Merge branch 'develop' of https://github.com/ROCm/composable_kernel into vmcnt0issue commit a7dd2a7d136e8796b1257d4124195f0a1b746ed9 Author: Ding, Yi Date: Tue Aug 19 02:17:43 2025 +0000 update bwd commit 380aa8f311875cf7281442bf3fa9be720218a78f Author: Kevin Choi Date: Mon Aug 18 19:36:38 2025 +0000 add restrict to applicable functions commit b85daba2a36fa9a15250c0a91949b63c63aee11e Author: Ding, Yi Date: Mon Aug 18 02:07:03 2025 +0000 bwd filter commit 75c4b9372fa73f2a45fd1c4f44b7504cc459b621 Author: Kevin Choi Date: Sat Aug 16 08:15:23 2025 +0000 remove noinline attr as it causes a lot more s_waitcnt's commit 598e3fec417eb0ff8089c260e758aa2c305ccd1d Author: Kevin Choi Date: Thu Aug 14 12:11:17 2025 +0000 remove innerloop, move restrict parameters to mainloop and add noinline attribute. 
commit 334040853749a931bd5c317170f17773967d377b Author: Kevin Choi Date: Thu Aug 14 07:06:51 2025 +0000 Create inner lambda with restrict parameters, add restrict to some parameters commit 3bc45ecbc7d4b630fd8fc436b89c0f2720a0449a Author: aska-0096 Date: Thu Aug 14 03:43:54 2025 +0000 save for debug commit de4db6c4c5d7cbe7b98ca597c48e300abe6dc4a1 Merge: 108abf00e 68694cb78 Author: aska-0096 Date: Wed Aug 13 02:15:22 2025 +0000 Merge branch 'wip-async-tr-fa' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa commit 108abf00e062552a1533f4339acf0dc831f671b7 Merge: 0810799e2 0f42a92fc Author: aska-0096 Date: Wed Aug 13 02:14:26 2025 +0000 Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa commit 68694cb781b76827a2ccf8b27bd8dd4cf81d1c80 Merge: 0810799e2 20288caa2 Author: asleepzzz Date: Wed Aug 13 00:34:11 2025 +0800 Merge branch 'develop' into wip-async-tr-fa commit 0810799e25c8b7a4c45eea9a027eaa5ca4acc767 Author: aska-0096 Date: Tue Aug 12 14:25:50 2025 +0000 refactor blockgemm change, isolate to v2; commit fd1eb323af1f0c1121fbbf0deccaaaa804fa3508 Author: aska-0096 Date: Tue Aug 12 09:26:13 2025 +0000 clang format commit 75f6f6bac4cd9921768bfd488f2887fdbd802c7f Merge: bcc05eee6 8e1eb0c1e Author: aska-0096 Date: Tue Aug 12 09:04:41 2025 +0000 Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa commit bcc05eee62ab82719bf69202022fd86fd5c69e70 Author: aska-0096 Date: Tue Aug 12 08:46:06 2025 +0000 Fix the bug commit 96d24497f5f94be894e9a06bd65cab25cacf20ac Author: aska-0096 Date: Tue Aug 12 04:02:41 2025 +0000 fix conflict. 
disable all v-col instance for fmha fwd commit 1716171be4a5e91a03f2030560c9eddd033b046f Merge: 1c9800790 4fde1646e Author: aska-0096 Date: Tue Aug 12 03:52:34 2025 +0000 Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa commit 1c98007901db992bf3b56186511fdab90f9d260c Author: aska-0096 Date: Tue Aug 12 01:53:31 2025 +0000 clang format commit f43e903b1dc41b91b6db1b457700822bdfe3d16f Merge: 3868ddd70 a7badc6ec Author: aska-0096 Date: Tue Aug 12 01:52:52 2025 +0000 Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa commit 3868ddd7082633518e7e84b0a9a7cc2aece58003 Merge: 498d234ab 191c62967 Author: aska-0096 Date: Mon Aug 11 15:59:40 2025 +0000 Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa commit 498d234ab875a05afc8236fc8952729741b70b28 Author: aska-0096 Date: Mon Aug 11 15:37:37 2025 +0000 change the warp setting for hdim32 fmha fwd commit b86f7786e2b36140f002f51959a47255a8d98251 Author: aska-0096 Date: Mon Aug 11 14:21:09 2025 +0000 tempsave, update the blocksync functions commit 7b8052d7ca0191d3142543633bca29036fb2d342 Author: aska-0096 Date: Sun Aug 10 06:00:51 2025 +0000 fix bug in pki4 commit 76cbbb84a2f0d9517f1d832b14b89d9445c23c1c Author: aska-0096 Date: Sat Aug 9 03:25:12 2025 +0000 fix bugs in gemm commit 8c101ccb884597eef9afc46a29abc24f5f56e7b1 Author: aska-0096 Date: Fri Aug 8 18:35:53 2025 +0000 fix bug on non-gfx950 commit efb854927966ca8ce605daa230a612aa3cc38ebf Author: aska-0096 Date: Fri Aug 8 17:53:19 2025 +0000 fix bug commit 729e8785fb6b9ecae6b71fc73233894be3e1fffb Author: aska-0096 Date: Fri Aug 8 15:42:15 2025 +0000 fix bugs commit 250dc13c75acc23850f03334cd603d76210a9429 Author: aska-0096 Date: Fri Aug 8 09:31:01 2025 +0000 fix clangformat with 18.1.3 commit 106edeecd9e1d1304b56d6f70d97d08e6cb93cc0 Author: aska-0096 Date: Fri Aug 8 09:07:40 2025 +0000 remove non-necessary change commit 78edd7303b8248e3c4fb266efc92b08fd17b9add 
Author: aska-0096 Date: Fri Aug 8 09:04:02 2025 +0000 bug fix, clang format; commit 3b9fb6af389dcdb45df9f78887b03921e4f4dff4 Author: aska-0096 Date: Fri Aug 8 08:08:03 2025 +0000 Remove unnecessary changes commit 6bb57c2c574234e2ed3b22c5a54d336bc0c63767 Merge: 1ecee378d ab2602683 Author: aska-0096 Date: Fri Aug 8 07:50:12 2025 +0000 Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa commit 1ecee378d528433f76876a892da41f07733ee935 Author: aska-0096 Date: Fri Aug 8 06:19:31 2025 +0000 remove unnecessary files; rename some files commit b4640a9de65a6e8310879a8691c260b47052361a Author: aska-0096 Date: Fri Aug 8 05:46:18 2025 +0000 merge fa_decode pipeline into fmha_fwd api commit fe63a646a459498e5677efd213fa3f8b714387c8 Author: aska-0096 Date: Wed Aug 6 05:58:43 2025 +0000 add __restrict__ to tr load commit 414cad667ba6cabf70165dadac85b74b791916de Author: aska-0096 Date: Tue Aug 5 07:23:51 2025 +0000 Add XOR fold strategy for hdim<128, but perf dropped; disable it by default; wait further perf debug commit 0d12fc944ff1078ca31beced4cc6235ef781c996 Author: aska-0096 Date: Mon Aug 4 10:27:42 2025 +0000 Add v_permlaneb32 for block_reduce. 
Disable it as it will cause un-coexecutable packed math in FA commit 4f31847de1d03e83eab539cdad792aded2ffed54 Author: aska-0096 Date: Mon Aug 4 10:02:17 2025 +0000 add vmcnt guard before load ktile commit 746f4ccb991b177099d557adb711db602128a578 Author: aska-0096 Date: Mon Aug 4 06:49:01 2025 +0000 Load Q through lds, implement xor; commit 2d4e73d2b449392b9fa3f1d011132d621e64f9a9 Author: aska-0096 Date: Fri Aug 1 10:44:54 2025 +0000 small refactor commit a28b6e67fedf0b6e934102eb98c9d3bd96ac8da5 Author: aska-0096 Date: Thu Jul 31 10:25:37 2025 +0000 upgrade prefill pipeline; simple iglp; consistent data produce and consume order commit 75cba48682ebba3586ac8574c4bc848773941a20 Author: aska-0096 Date: Thu Jul 31 05:13:27 2025 +0000 enable larger tile size; upgrade xor pattern commit 69890afc982e8a9d7932c5026f3313ee0b9c51d1 Author: aska-0096 Date: Wed Jul 30 12:25:33 2025 +0000 remove all lds bankconflict with xor layouts commit 8dacc35c4c74a391676140d180dd52099486f649 Author: aska-0096 Date: Wed Jul 30 03:51:06 2025 +0000 enable prefill overload operator(). commit 13bcc913de41823c68ed16cb1432c67f8ad0ea43 Author: aska-0096 Date: Fri Jul 25 07:10:01 2025 +0000 fix the lds alignment caused performance regression commit af28123cec1a0c8f6b81b97820e4923e00604f34 Author: aska-0096 Date: Wed Jul 23 09:05:57 2025 +0000 remove unnecessary features commit 14e0ab70c65be04f422157242c9be5711347d167 Author: aska-0096 Date: Tue Jul 22 08:04:05 2025 +0000 tempsave. 
asynccopy+trload sanity checked commit 1b468bac0bee62381fa7591ee2c114f8ea83061f Author: aska-0096 Date: Mon Jul 21 05:55:55 2025 +0000 tempsave, trload+asyncload done commit afd96d81807c39d5b3433739556b016758f09f7b Author: aska-0096 Date: Fri Jul 18 10:04:34 2025 +0000 compile pass commit 5616551115267174128f8fed7d6241d41baaf81d Merge: ae39c84f5 095393276 Author: aska-0096 Date: Fri Jul 18 05:17:27 2025 +0000 Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa commit ae39c84f559f5c5bbfc2ebe4535874df3a41462f Author: aska-0096 Date: Fri Jul 18 05:16:39 2025 +0000 tempsave commit 94b6430489a7be3611234322a9e1b88ebcf0564f Author: aska-0096 Date: Thu Jul 17 10:06:09 2025 +0000 temp save commit 7e330553dca887b4779dde988e1be57417c76199 Merge: 18669925c 804f77dce Author: aska-0096 Date: Thu Jul 17 07:24:32 2025 +0000 Merge branch 'test_copy_fix' of https://github.com/ROCm/composable_kernel into fa_decode_pipeline commit 804f77dce519a819ba29571791acd8db59dc5191 Author: aska-0096 Date: Thu Jul 17 03:10:46 2025 +0000 move test_copy into test commit 21627d7ca78d084c4fd38e9e9e6818fa129b6cf3 Author: aska-0096 Date: Thu Jul 17 02:41:31 2025 +0000 remove unnecessary output commit 287792c44a21f5996363757fae90efff694239dc Merge: a4221db30 21fd7e953 Author: aska-0096 Date: Thu Jul 17 02:26:13 2025 +0000 Merge branch 'test_copy_fix' of https://github.com/ROCm/composable_kernel into test_copy_fix commit a4221db30433cd3d2c7f7df6dc8be75c09151814 Author: aska-0096 Date: Thu Jul 17 02:26:10 2025 +0000 add input validation and bug fix commit 21fd7e953852b25c95afdacaaca5512e6dbfe82e Merge: d6df7bf85 6e76b8205 Author: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Wed Jul 16 11:23:57 2025 -0700 Merge branch 'develop' into test_copy_fix commit d6df7bf8512d5a97adb74aa47161ddef7405bc03 Author: aska-0096 Date: Wed Jul 16 08:55:50 2025 +0000 fix vmcnt shift commit 40e039e4e48ccf8eb4160ea628c47587cd1f695e Author: aska-0096 Date: Wed Jul 16 
08:37:07 2025 +0000 Improve s_waitcnt_imm calculation commit c30f8b709b6ded0c8600304a5da823355d6ed893 Author: aska-0096 Date: Wed Jul 16 05:39:50 2025 +0000 fix the s_waitcnt_imm calculation commit ec0a45b29fb7871aee01374b41974263558d3774 Merge: e5cc4af80 6b09f0823 Author: aska-0096 Date: Wed Jul 16 03:57:57 2025 +0000 Merge branch 'develop' of https://github.com/ROCm/composable_kernel into test_copy_fix commit e5cc4af808456f56425dd290bb82318650240dee Author: aska-0096 Date: Wed Jul 16 03:54:33 2025 +0000 Add block_sync_lds_direct_load utility commit eea58629cf141652115febfc6122227ab6f59d7d Author: aska-0096 Date: Tue Jul 15 09:39:03 2025 +0000 fix async copytest bug commit 18669925cc6a40c3296ef4e7abd942f5739b0c29 Author: aska-0096 Date: Thu Jul 10 04:29:33 2025 +0000 temp save, change all instance to 1wave commit 18686cfe5b83a2d16424fa6cc3d3eecc5e1a24ef Author: aska-0096 Date: Tue Jul 8 08:37:20 2025 +0000 tempsave, fmha_decode commit 47565f21a5ccfc25192cdc9beb1b62ac89caf921 Author: aska-0096 Date: Sat Jun 21 15:02:57 2025 +0000 temp save, waiting for debug commit e0a634ef9770116c7268b46d64152f116c981042 Author: aska-0096 Date: Thu Jun 19 05:11:52 2025 +0000 save an example for __bf16 type commit 4bd5fd4a3c0263d57b36b2e95bf94654833275d5 Author: aska-0096 Date: Wed Jun 18 07:27:24 2025 +0000 fix bwd code commit 69809d9513742e2e7cb7ffbdd7184396c71c5e43 Author: aska-0096 Date: Wed Jun 18 06:37:16 2025 +0000 Fix for fwd/bwd kernel build filter commit d5ec3d0e5768aafed7f77151b2a835e87b9f95ba Author: Ding, Yi Date: Tue Aug 19 08:13:18 2025 +0000 Add restrict to avoid unnecessary vmcnt --------- Co-authored-by: aska-0096 * Add comments for c-stype cast * Better comments --------- Co-authored-by: aska-0096 --- .../core/arch/amd_buffer_addressing.hpp | 39 ++--- .../arch/amd_buffer_addressing_builtins.hpp | 39 ++--- include/ck_tile/core/tensor/buffer_view.hpp | 20 +-- include/ck_tile/core/tensor/tensor_view.hpp | 6 +- .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 8 +- 
...k_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp | 4 +- ...a_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp | 4 +- ...bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp | 137 +++++++++++------- ...wd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp | 70 +++++---- ...ck_fmha_pipeline_qr_ks_vs_async_trload.hpp | 41 +++--- 10 files changed, 217 insertions(+), 151 deletions(-) diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 037e86909d..7a9c017eb2 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1833,14 +1833,17 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, if constexpr(oob_conditional_check) v_offset = flag ? v_offset : src_wave_buffer_resource[2]; - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - v_offset, - src_wave_addr_offset, - /*src_immediate_addr_offset*/ 0, - static_cast(coherence)); +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" + // Use C-style cast to change address space without dropping llvm noalias attribute + llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource, + (as3_uint32_ptr)(smem), + bytes, + v_offset, + src_wave_addr_offset, + /*src_immediate_addr_offset*/ 0, + static_cast(coherence)); +#pragma clang diagnostic pop } template & src_thread_ template __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) { +#define __LDS_ADDR __attribute__((address_space(3))) static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32), "We need to have the compatible compiler version to build this instruction"); + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" + // Use C-style cast to change address space without dropping llvm noalias attribute + const auto in_ptr_ = (__LDS_ADDR T*)(const_cast(in_ptr)); +#pragma clang diagnostic pop if 
constexpr(std::is_same_v, ck_tile::half_t>) { typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t; - __attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr = - reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>( - reinterpret_cast(in_ptr)); + auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_fp16x4_t*>(in_ptr_); return bit_cast>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr)); } else if constexpr(std::is_same_v, ck_tile::bf16_t>) { typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t; - __attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr = - reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>( - reinterpret_cast(in_ptr)); + auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_bf16x4_t*>(in_ptr_); return bit_cast>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr)); } else if constexpr(std::is_same_v, ck_tile::fp8_t> || @@ -2812,15 +2818,14 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) std::is_same_v, ck_tile::int8_t>) { typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t; - __attribute__((address_space(3))) llvm_i32x2_t* lds_ptr = - reinterpret_cast<__attribute__((address_space(3))) llvm_i32x2_t*>( - reinterpret_cast(in_ptr)); + auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_i32x2_t*>(in_ptr_); return bit_cast>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr)); } else { static_assert(false, "not implemented"); } +#undef __LDS_ADDR } #endif diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index d1e4eb3da3..4013b51479 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -1603,14 +1603,17 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, if constexpr(oob_conditional_check) v_offset = flag ? 
v_offset : src_wave_buffer_resource[2]; - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - v_offset, - src_wave_addr_offset, - /*src_immediate_addr_offset*/ 0, - static_cast(coherence)); +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" + // Use C-style cast to change address space without dropping llvm noalias attribute + llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource, + (as3_uint32_ptr)(smem), + bytes, + v_offset, + src_wave_addr_offset, + /*src_immediate_addr_offset*/ 0, + static_cast(coherence)); +#pragma clang diagnostic pop } template __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) { +#define __LDS_ADDR __attribute__((address_space(3))) static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32), "We need to have the compatible compiler version to build this instruction"); + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" + // Use C-style cast to change address space without dropping llvm noalias attribute + const auto in_ptr_ = (__LDS_ADDR T*)(const_cast(in_ptr)); +#pragma clang diagnostic pop if constexpr(std::is_same_v, ck_tile::half_t>) { typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t; - __attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr = - reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>( - reinterpret_cast(in_ptr)); + auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_fp16x4_t*>(in_ptr_); return bit_cast>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr)); } else if constexpr(std::is_same_v, ck_tile::bf16_t>) { typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t; - __attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr = - reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>( - reinterpret_cast(in_ptr)); + auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_bf16x4_t*>(in_ptr_); 
return bit_cast>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr)); } else if constexpr(std::is_same_v, ck_tile::fp8_t> || @@ -2630,15 +2636,14 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) std::is_same_v, ck_tile::int8_t>) { typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t; - __attribute__((address_space(3))) llvm_i32x2_t* lds_ptr = - reinterpret_cast<__attribute__((address_space(3))) llvm_i32x2_t*>( - reinterpret_cast(in_ptr)); + auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_i32x2_t*>(in_ptr_); return bit_cast>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr)); } else { static_assert(false, "not implemented"); } +#undef __LDS_ADDR } #endif diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index ca314a6abe..d1e770ef42 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -62,12 +62,12 @@ struct buffer_view -CK_TILE_HOST_DEVICE constexpr auto make_buffer_view(T* p, BufferSizeType buffer_size) +CK_TILE_HOST_DEVICE constexpr auto make_buffer_view(T* __restrict__ p, BufferSizeType buffer_size) { return buffer_view{p, buffer_size}; } @@ -1266,7 +1266,7 @@ template , remove_cvref_t>::value, bool>::type = false> CK_TILE_HOST_DEVICE constexpr auto -make_buffer_view(T* p, BufferSizeType buffer_size, X invalid_element_value) +make_buffer_view(T* __restrict__ p, BufferSizeType buffer_size, X invalid_element_value) { return buffer_view{ p, buffer_size, invalid_element_value}; diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index a85dbc6d00..6fa8f898e5 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -449,7 +449,7 @@ template -CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p, +CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* __restrict__ p, const 
tensor_descriptor& desc) { auto buffer_view = @@ -468,7 +468,7 @@ template ::type = false> CK_TILE_HOST_DEVICE constexpr auto -make_naive_tensor_view(DataType* p, +make_naive_tensor_view(DataType* __restrict__ p, const tuple& lengths, const tuple& strides, number = number<-1>{}, @@ -491,7 +491,7 @@ template CK_TILE_HOST_DEVICE constexpr auto -make_naive_tensor_view_packed(DataType* p, +make_naive_tensor_view_packed(DataType* __restrict__ p, const tuple& lengths, number = number<-1>{}) { diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 5e16fc563b..3f5bef366e 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -1115,7 +1115,8 @@ struct FmhaBwdDQDKDVKernel {i_n0, 0}); if constexpr(!kUseQrQtrDorPipeline) { - auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window, + auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(smem_ptr, + q_dram_window, k_dram_window, v_dram_window, bias_dram_window, @@ -1131,7 +1132,6 @@ struct FmhaBwdDQDKDVKernel kargs.scale, rp_undrop, scale_rp_undrop, - smem_ptr, dropout); KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile); @@ -1139,7 +1139,8 @@ struct FmhaBwdDQDKDVKernel } else { - FmhaPipeline{}(q_dram_window, + FmhaPipeline{}(smem_ptr, + q_dram_window, k_dram_window, v_dram_window, bias_dram_window, @@ -1160,7 +1161,6 @@ struct FmhaBwdDQDKDVKernel kargs.scale, rp_undrop, scale_rp_undrop, - smem_ptr, dropout); } } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp index d36f8ad724..5e63fb714a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp @@ -93,7 +93,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR typename 
BiasGradDramBlockWindowTmp, typename PositionEncoding> CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, + operator()(void* smem_ptr, + const QDramBlockWindowTmp& q_dram_block_window_tmp, const KDramBlockWindowTmp& k_dram_block_window_tmp, const VDramBlockWindowTmp& v_dram_block_window_tmp, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, @@ -109,7 +110,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR float scale, float rp_undrop, float scale_rp_undrop, - void* smem_ptr, FmhaDropout& dropout) const { static_assert( diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index 88fb1281aa..b883aad155 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -93,7 +93,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP typename BiasGradDramBlockWindowTmp, typename PositionEncoding> CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, + operator()(void* smem_ptr, + const QDramBlockWindowTmp& q_dram_block_window_tmp, const KDramBlockWindowTmp& k_dram_block_window_tmp, const VDramBlockWindowTmp& v_dram_block_window_tmp, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, @@ -109,7 +110,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP float scale, float rp_undrop, float scale_rp_undrop, - void* smem_ptr, FmhaDropout& dropout) const { static_assert( diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp index 9a31498dd1..9bd78b4077 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp +++ 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -90,6 +90,53 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR else return raw_lse; }; + template + CK_TILE_DEVICE auto operator()(void* smem_ptr, Ts&&... args) const + { + // LDS allocation + // cast to char* to do pointer arithmetic + const auto smem_ptr_ = reinterpret_cast(smem_ptr); + const auto k_lds_ptr = reinterpret_cast(smem_ptr_); + const auto v_lds_ptr = + reinterpret_cast(smem_ptr_ + Policy::template GetSmemSizeK()); + + const auto do_lds_ptr0 = reinterpret_cast(smem_ptr_); + const auto do_lds_ptr1 = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeOGrad()); + const auto q_lds_ptr0 = reinterpret_cast( // + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGrad()); + const auto q_lds_ptr1 = reinterpret_cast( // + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ()); + const auto lse_lds_ptr = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ()); + const auto d_lds_ptr = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeLSE()); + const auto ds_lds_ptr = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeLSE() + Policy::template GetSmemSizeD()); + const auto bias_lds_ptr = reinterpret_cast(ds_lds_ptr); + return run(k_lds_ptr, + v_lds_ptr, + do_lds_ptr0, + do_lds_ptr1, + q_lds_ptr0, + q_lds_ptr1, + lse_lds_ptr, + d_lds_ptr, + ds_lds_ptr, + bias_lds_ptr, + std::forward(args)...); + } template - CK_TILE_DEVICE auto 
operator()( // + CK_TILE_DEVICE auto run( // + KDataType* __restrict__ k_lds_ptr, + VDataType* __restrict__ v_lds_ptr, + OGradDataType* __restrict__ do_lds_ptr0, + OGradDataType* __restrict__ do_lds_ptr1, + QDataType* __restrict__ q_lds_ptr0, + QDataType* __restrict__ q_lds_ptr1, + LSEDataType* __restrict__ lse_lds_ptr, + DDataType* __restrict__ d_lds_ptr, + GemmDataType* __restrict__ ds_lds_ptr, + BiasDataType* __restrict__ bias_lds_ptr, const QDramBlockWindowTmp& q_dram_block_window_tmp, const KDramBlockWindowTmp& k_dram_block_window_tmp, const VDramBlockWindowTmp& v_dram_block_window_tmp, @@ -119,7 +176,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR float scale, float rp_undrop, float scale_rp_undrop, - void* smem_ptr, FmhaDropout& dropout) const { static_assert( @@ -184,40 +240,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR } } - // LDS allocation - const auto smem_ptr_ = - reinterpret_cast(smem_ptr); // cast to char* to do pointer arithmetic - - const auto k_lds_ptr = reinterpret_cast(smem_ptr_); - const auto v_lds_ptr = reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeK()); - - const auto do_lds_ptr0 = reinterpret_cast(smem_ptr_); - const auto do_lds_ptr1 = reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeOGrad()); - const auto q_lds_ptr0 = reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeOGrad()); - const auto q_lds_ptr1 = reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeQ()); - const auto lse_lds_ptr = reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ()); - const auto d_lds_ptr = reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ() + - 
Policy::template GetSmemSizeLSE()); - const auto ds_lds_ptr = reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ() + - Policy::template GetSmemSizeLSE() + Policy::template GetSmemSizeD()); - const auto bias_lds_ptr = reinterpret_cast(ds_lds_ptr); - auto k_lds = make_tensor_view( k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor()); auto k_lds_write_window = @@ -453,13 +475,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR decltype(load_tile(d_dram_window)) d_block_tile; index_t i_total_bodys = 0; - auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable { - const bool is_even = (i_total_bodys % 2 == 0); - QDataType* const __restrict__ q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0; - QDataType* const __restrict__ q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1; - OGradDataType* const __restrict__ do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0; - OGradDataType* const __restrict__ do_lds_ptr_next = is_even ? 
do_lds_ptr0 : do_lds_ptr1; - + auto main_body_impl = [&](auto is_prologue_, + auto is_epilogue_, + QDataType* const __restrict__ q_lds_ptr_curr, + QDataType* const __restrict__ q_lds_ptr_next, + OGradDataType* const __restrict__ do_lds_ptr_curr, + OGradDataType* const __restrict__ do_lds_ptr_next) mutable { constexpr bool is_prologue = is_prologue_.value; constexpr bool is_epilogue = is_epilogue_.value; static_assert(is_prologue || is_epilogue, "is_prologue or is_epilogue should be true"); @@ -467,19 +488,19 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR if constexpr(is_prologue) { + lse_block_tile = load_tile(lse_dram_window); + move_tile_window(lse_dram_window, {kM0}); + + d_block_tile = load_tile(d_dram_window); + move_tile_window(d_dram_window, {kM0}); + q_lds_write_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next); async_load_tile(q_lds_write_window, q_dram_window); move_tile_window(q_dram_window, {kM0, 0}); - lse_block_tile = load_tile(lse_dram_window); - move_tile_window(lse_dram_window, {kM0}); - do_lds_write_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next); async_load_tile(do_lds_write_window, do_dram_window); move_tile_window(do_dram_window, {kM0, 0}); - - d_block_tile = load_tile(d_dram_window); - move_tile_window(d_dram_window, {kM0}); } if constexpr(is_epilogue) { @@ -611,8 +632,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR constexpr auto i_j_idx = make_tuple(idx0, idx1); bool undrop_flag = p[i_j_idx] >= 0; ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag - ? (dp_acc[i_j_idx] - d[i_idx]) - : d[i_idx]); + ? (dp_acc[i_j_idx] - d[i_idx]) + : d[i_idx]); }); }); @@ -725,6 +746,20 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR } move_tile_window(dq_dram_window, {kM0, 0}); } + }; + + auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable { + const bool is_even = (i_total_bodys % 2 == 0); + const auto q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0; + const auto q_lds_ptr_next = is_even ? 
q_lds_ptr0 : q_lds_ptr1; + const auto do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0; + const auto do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1; + main_body_impl(is_prologue_, + is_epilogue_, + q_lds_ptr_curr, + q_lds_ptr_next, + do_lds_ptr_curr, + do_lds_ptr_next); i_total_bodys += 1; }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp index 789cfb3ea4..5adb64564d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp @@ -93,6 +93,42 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR return raw_lse; }; + template + CK_TILE_DEVICE auto operator()(void* smem_ptr, Ts&&... args) const + { + // LDS allocation + const auto smem_ptr_ = + reinterpret_cast(smem_ptr); // cast to char* to do pointer arithmetic + + const auto k_lds_ptr = reinterpret_cast(smem_ptr_); + const auto v_lds_ptr = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeK()); + + const auto do_lds_ptr = reinterpret_cast(smem_ptr_); + const auto q_lds_ptr = reinterpret_cast( // + smem_ptr_ + Policy::template GetSmemSizeOGrad()); + const auto lse_lds_ptr = reinterpret_cast( // + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ()); + const auto d_lds_ptr = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeLSE()); + + const auto ds_lds_ptr = + reinterpret_cast(smem_ptr_ + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeV()); + const auto bias_lds_ptr = reinterpret_cast(ds_lds_ptr); + return run(k_lds_ptr, + v_lds_ptr, + do_lds_ptr, + q_lds_ptr, + lse_lds_ptr, + d_lds_ptr, + ds_lds_ptr, + bias_lds_ptr, + std::forward(args)...); + } + template - CK_TILE_DEVICE auto 
operator()( // + CK_TILE_DEVICE auto run( // + KDataType* __restrict__ k_lds_ptr, + VDataType* __restrict__ v_lds_ptr, + OGradDataType* __restrict__ do_lds_ptr, + QDataType* __restrict__ q_lds_ptr, + LSEDataType* __restrict__ lse_lds_ptr, + DDataType* __restrict__ d_lds_ptr, + GemmDataType* __restrict__ ds_lds_ptr, + BiasDataType* __restrict__ bias_lds_ptr, const QDramBlockWindowTmp& q_dram_block_window_tmp, const KDramBlockWindowTmp& k_dram_block_window_tmp, const VDramBlockWindowTmp& v_dram_block_window_tmp, @@ -131,7 +175,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR float scale, float rp_undrop, float scale_rp_undrop, - void* smem_ptr, FmhaDropout& dropout) const { static_assert( @@ -181,29 +224,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR {seqlen_kv_start, 0}, Policy::template MakeKDramTileDistribution()); - // LDS allocation - const auto smem_ptr_ = - reinterpret_cast(smem_ptr); // cast to char* to do pointer arithmetic - - const auto k_lds_ptr = reinterpret_cast(smem_ptr_); - const auto v_lds_ptr = reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeK()); - - const auto do_lds_ptr = reinterpret_cast(smem_ptr_); - const auto q_lds_ptr = reinterpret_cast( // - smem_ptr_ + Policy::template GetSmemSizeOGrad()); - const auto lse_lds_ptr = reinterpret_cast( // - smem_ptr_ + Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeQ()); - const auto d_lds_ptr = reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeLSE()); - - const auto ds_lds_ptr = - reinterpret_cast(smem_ptr_ + Policy::template GetSmemSizeK() + - Policy::template GetSmemSizeV()); - const auto bias_lds_ptr = reinterpret_cast(ds_lds_ptr); - auto k_lds = make_tensor_view( k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor()); auto k_lds_write_window = diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index 39d8814692..aafe481d2b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -638,11 +638,11 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload typename LSEaccDramBlockWindowTmp, typename PositionEncoding> CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile + operator()(const QDramBlockWindowTmp& __restrict__ q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& __restrict__ k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& __restrict__ v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& __restrict__ bias_dram_block_window_tmp, // M0*N0 tile + LSEaccDramBlockWindowTmp& __restrict__ lse_acc_dram_window_tmp, // M0*1 tile FmhaMask mask, PositionEncoding position_encoding, float scale_s, @@ -854,18 +854,10 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload __builtin_amdgcn_sched_barrier(0); - auto mainloop = [&](index_t cur_loop) { - const bool is_even_loop = (cur_loop % 2 == 0); - - auto k_lds_write_ptr = is_even_loop ? static_cast(smem_ptrk0) - : static_cast(smem_ptrk1); - auto k_lds_read_ptr = is_even_loop ? static_cast(smem_ptrk1) - : static_cast(smem_ptrk0); - auto v_lds_write_ptr = is_even_loop ? 
static_cast(smem_ptrv1) - : static_cast(smem_ptrv0); - auto v_lds_read_ptr = is_even_loop ? static_cast(smem_ptrv0) - : static_cast(smem_ptrv1); - + auto mainloop = [&](KDataType* __restrict__ k_lds_write_ptr, + KDataType* __restrict__ k_lds_read_ptr, + KDataType* __restrict__ v_lds_write_ptr, + KDataType* __restrict__ v_lds_read_ptr) { // move V tile windows block_sync_lds(); move_tile_window(v_dram_window, {kN0, 0}); @@ -1110,11 +1102,20 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ }); - }; + }; // mainloop do { - mainloop(i_total_loops); + bool is_even_loop = i_total_loops % 2 == 0; + auto k_lds_write_ptr = is_even_loop ? static_cast(smem_ptrk0) + : static_cast(smem_ptrk1); + auto k_lds_read_ptr = is_even_loop ? static_cast(smem_ptrk1) + : static_cast(smem_ptrk0); + auto v_lds_write_ptr = is_even_loop ? static_cast(smem_ptrv1) + : static_cast(smem_ptrv0); + auto v_lds_read_ptr = is_even_loop ? static_cast(smem_ptrv0) + : static_cast(smem_ptrv1); + mainloop(k_lds_write_ptr, k_lds_read_ptr, v_lds_write_ptr, v_lds_read_ptr); i_total_loops++; } while(i_total_loops < num_total_loop); From 61806856885e9d6d500b1e112142128ee90ab997 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 25 Aug 2025 09:51:29 -0700 Subject: [PATCH 21/46] Resolve issues with performance logs in CI. 
(#2733) * update the performance test logic * fix unstash perf logs logic * untangle unstashing fmha logs for different archs * run process stage after running fmha tests * fix the processing of perf logs * fix arguments for run_performance scripts --- Jenkinsfile | 116 +++++++++++++++++++++------ script/process_perf_data.sh | 37 ++++++++- script/run_full_performance_tests.sh | 29 ++++--- script/run_performance_tests.sh | 10 ++- 4 files changed, 145 insertions(+), 47 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index b3b63098c2..6c79acb14b 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -578,31 +578,60 @@ def Build_CK(Map conf=[:]){ if (params.RUN_FULL_QA && arch == 1){ // run full tests on gfx90a echo "Run full performance tests" - sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}" - archiveArtifacts "perf_gemm.log" - archiveArtifacts "perf_resnet50_N256.log" - archiveArtifacts "perf_resnet50_N4.log" - archiveArtifacts "perf_batched_gemm.log" - archiveArtifacts "perf_grouped_gemm.log" - archiveArtifacts "perf_grouped_conv_fwd.log" - archiveArtifacts "perf_grouped_conv_bwd_data.log" - archiveArtifacts "perf_grouped_conv_bwd_weight.log" - archiveArtifacts "perf_gemm_bilinear.log" - archiveArtifacts "perf_reduction.log" - archiveArtifacts "perf_splitK_gemm.log" - archiveArtifacts "perf_onnx_gemm.log" - archiveArtifacts "perf_mixed_gemm.log" - stash includes: "perf_**.log", name: "perf_log" + sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx90a" + archiveArtifacts "perf_gemm_gfx90a.log" + archiveArtifacts "perf_resnet50_N256_gfx90a.log" + archiveArtifacts "perf_resnet50_N4_gfx90a.log" + archiveArtifacts "perf_batched_gemm_gfx90a.log" + archiveArtifacts "perf_grouped_gemm_gfx90a.log" + archiveArtifacts "perf_grouped_conv_fwd_gfx90a.log" + archiveArtifacts "perf_grouped_conv_bwd_data_gfx90a.log" + archiveArtifacts "perf_grouped_conv_bwd_weight_gfx90a.log" + 
archiveArtifacts "perf_gemm_bilinear_gfx90a.log" + archiveArtifacts "perf_reduction_gfx90a.log" + archiveArtifacts "perf_splitK_gemm_gfx90a.log" + archiveArtifacts "perf_onnx_gemm_gfx90a.log" + archiveArtifacts "perf_mixed_gemm_gfx90a.log" + stash includes: "perf_**.log", name: "perf_log_gfx90a" + } + if (params.RUN_FULL_QA && arch == 2){ + // run full tests on gfx942 + echo "Run full performance tests" + sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx942" + archiveArtifacts "perf_gemm_gfx942.log" + archiveArtifacts "perf_resnet50_N256_gfx942.log" + archiveArtifacts "perf_resnet50_N4_gfx942.log" + archiveArtifacts "perf_batched_gemm_gfx942.log" + archiveArtifacts "perf_grouped_gemm_gfx942.log" + archiveArtifacts "perf_grouped_conv_fwd_gfx942.log" + archiveArtifacts "perf_grouped_conv_bwd_data_gfx942.log" + archiveArtifacts "perf_grouped_conv_bwd_weight_gfx942.log" + archiveArtifacts "perf_gemm_bilinear_gfx942.log" + archiveArtifacts "perf_reduction_gfx942.log" + archiveArtifacts "perf_splitK_gemm_gfx942.log" + archiveArtifacts "perf_onnx_gemm_gfx942.log" + archiveArtifacts "perf_mixed_gemm_gfx942.log" + stash includes: "perf_**.log", name: "perf_log_gfx942" } else if ( arch == 1 ){ // run standard tests on gfx90a echo "Run performance tests" - sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}" - archiveArtifacts "perf_gemm.log" - archiveArtifacts "perf_onnx_gemm.log" - archiveArtifacts "perf_resnet50_N256.log" - archiveArtifacts "perf_resnet50_N4.log" - stash includes: "perf_**.log", name: "perf_log" + sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx90a" + archiveArtifacts "perf_gemm_gfx90a.log" + archiveArtifacts "perf_onnx_gemm_gfx90a.log" + archiveArtifacts "perf_resnet50_N256_gfx90a.log" + archiveArtifacts "perf_resnet50_N4_gfx90a.log" + stash includes: "perf_**.log", name: "perf_log_gfx90a" + } + else if ( arch 
== 2 ){ + // run standard tests on gfx942 + echo "Run performance tests" + sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx942" + archiveArtifacts "perf_gemm_gfx942.log" + archiveArtifacts "perf_onnx_gemm_gfx942.log" + archiveArtifacts "perf_resnet50_N256_gfx942.log" + archiveArtifacts "perf_resnet50_N4_gfx942.log" + stash includes: "perf_**.log", name: "perf_log_gfx942" } // disable performance tests on gfx1030 for now. //else if ( arch == 3){ @@ -720,10 +749,15 @@ def process_results(Map conf=[:]){ if (params.RUN_CK_TILE_FMHA_TESTS){ try{ unstash "perf_fmha_log_gfx942" + } + catch(Exception err){ + echo "could not locate the FMHA performance logs for gfx942: ${err.getMessage()}." + } + try{ unstash "perf_fmha_log_gfx90a" } catch(Exception err){ - echo "could not locate the FMHA performance logs: ${err.getMessage()}." + echo "could not locate the FMHA performance logs for gfx90a: ${err.getMessage()}." } } if (params.BUILD_INSTANCES_ONLY){ @@ -733,16 +767,46 @@ def process_results(Map conf=[:]){ } else{ // unstash perf files to master - unstash "perf_log" + try{ + unstash "perf_log_gfx90a" + } + catch(Exception err){ + echo "could not locate the gfx90a performance logs: ${err.getMessage()}." + } + try{ + unstash "perf_log_gfx942" + } + catch(Exception err){ + echo "could not locate the gfx942 performance logs: ${err.getMessage()}." + } + try{ + unstash "perf_log_gfx950" + } + catch(Exception err){ + echo "could not locate the gfx950 performance logs: ${err.getMessage()}." + } + try{ + unstash "perf_log_gfx908" + } + catch(Exception err){ + echo "could not locate the gfx908 performance logs: ${err.getMessage()}." + } try{ unstash "perf_log_gfx11" + } + catch(Exception err){ + echo "could not locate the gfx11 performance logs: ${err.getMessage()}." + } + try{ + unstash "perf_log_gfx12" } catch(Exception err){ - echo "could not locate the GEMM gfx11/gfx12 performance logs: ${err.getMessage()}." 
+ echo "could not locate the gfx12 performance logs: ${err.getMessage()}." } - sh "./process_perf_data.sh" } + // process the logs + sh "./process_perf_data.sh" } } catch(e){ @@ -1505,7 +1569,7 @@ pipeline { stage("Process results"){ when { beforeAgent true - expression { (params.RUN_PERFORMANCE_TESTS.toBoolean() || params.BUILD_INSTANCES_ONLY.toBoolean()) && !params.BUILD_LEGACY_OS.toBoolean() } + expression { (params.RUN_PERFORMANCE_TESTS.toBoolean() || params.BUILD_INSTANCES_ONLY.toBoolean() || params.RUN_CK_TILE_FMHA_TESTS.toBoolean()) && !params.BUILD_LEGACY_OS.toBoolean() } } agent { label 'mici' } steps{ diff --git a/script/process_perf_data.sh b/script/process_perf_data.sh index fc44064874..50c84924f5 100755 --- a/script/process_perf_data.sh +++ b/script/process_perf_data.sh @@ -10,10 +10,39 @@ # please contact Illia.Silin@amd.com for more details #process results -python3 process_perf_data.py perf_gemm.log -python3 process_perf_data.py perf_onnx_gemm.log -python3 process_perf_data.py perf_resnet50_N256.log -python3 process_perf_data.py perf_resnet50_N4.log +file=./perf_gemm_gfx90a.log +if [ -e "$file" ]; then + python3 process_perf_data.py perf_gemm_gfx90a.log +fi +file=./perf_onnx_gemm_gfx90a.log +if [ -e "$file" ]; then + python3 process_perf_data.py perf_onnx_gemm_gfx90a.log +fi +file=./perf_resnet50_N256_gfx90a.log +if [ -e "$file" ]; then + python3 process_perf_data.py perf_resnet50_N256_gfx90a.log +fi +file=./perf_resnet50_N4_gfx90a.log +if [ -e "$file" ]; then + python3 process_perf_data.py perf_resnet50_N4_gfx90a.log +fi + +file=./perf_gemm_gfx942.log +if [ -e "$file" ]; then + python3 process_perf_data.py perf_gemm_gfx942.log +fi +file=./perf_onnx_gemm_gfx942.log +if [ -e "$file" ]; then + python3 process_perf_data.py perf_onnx_gemm_gfx942.log +fi +file=./perf_resnet50_N256_gfx942.log +if [ -e "$file" ]; then + python3 process_perf_data.py perf_resnet50_N256_gfx942.log +fi +file=./perf_resnet50_N4_gfx942.log +if [ -e "$file" ]; then + python3 
process_perf_data.py perf_resnet50_N4_gfx942.log +fi file=./perf_onnx_gemm_gfx10.log if [ -e "$file" ]; then diff --git a/script/run_full_performance_tests.sh b/script/run_full_performance_tests.sh index ddc5c270b8..508200b21a 100755 --- a/script/run_full_performance_tests.sh +++ b/script/run_full_performance_tests.sh @@ -22,6 +22,9 @@ export branch=$3 echo 'Branch name: ' $branch export host_name=$4 echo 'Host name: ' $host_name +export arch=$5 +echo 'GPU architecture: ' $arch + function print_log_header(){ rm -f $1; echo 'On branch ' $3 &> $1; @@ -35,7 +38,7 @@ function print_log_header(){ } #run gemm tests -export gemm_log="perf_gemm.log" +export gemm_log="perf_gemm_$arch.log" print_log_header $gemm_log $env_type $branch $host_name ./profile_gemm.sh gemm 0 0 $verify 1 0 1 2>&1 | tee -a $gemm_log ./profile_gemm.sh gemm 1 0 $verify 1 0 1 2>&1 | tee -a $gemm_log @@ -55,7 +58,7 @@ print_log_header $gemm_log $env_type $branch $host_name ./profile_gemm.sh gemm 3 3 $verify 1 0 1 2>&1 | tee -a $gemm_log #run batched_gemm tests -export batched_gemm_log="perf_batched_gemm.log" +export batched_gemm_log="perf_batched_gemm_$arch.log" print_log_header $batched_gemm_log $env_type $branch $host_name ./profile_batched_gemm.sh batched_gemm 0 0 $verify 1 0 1 2>&1 | tee -a $batched_gemm_log ./profile_batched_gemm.sh batched_gemm 0 1 $verify 1 0 1 2>&1 | tee -a $batched_gemm_log @@ -75,7 +78,7 @@ print_log_header $batched_gemm_log $env_type $branch $host_name ./profile_batched_gemm.sh batched_gemm 3 3 $verify 1 0 1 2>&1 | tee -a $batched_gemm_log #run grouped_gemm tests -export grouped_gemm_log="perf_grouped_gemm.log" +export grouped_gemm_log="perf_grouped_gemm_$arch.log" print_log_header $grouped_gemm_log $env_type $branch $host_name ./profile_grouped_gemm.sh grouped_gemm 1 0 $verify 1 0 1 2>&1 | tee -a $grouped_gemm_log ./profile_grouped_gemm.sh grouped_gemm 1 1 $verify 1 0 1 2>&1 | tee -a $grouped_gemm_log @@ -83,7 +86,7 @@ print_log_header $grouped_gemm_log $env_type $branch 
$host_name ./profile_grouped_gemm.sh grouped_gemm 1 3 $verify 1 0 1 2>&1 | tee -a $grouped_gemm_log #run GEMM+Bilinear tests -export gemm_bilinear_log="perf_gemm_bilinear.log" +export gemm_bilinear_log="perf_gemm_bilinear_$arch.log" print_log_header $gemm_bilinear_log $env_type $branch $host_name ./profile_gemm_bilinear.sh gemm_bilinear 1 0 $verify 1 0 1 2>&1 | tee -a $gemm_bilinear_log ./profile_gemm_bilinear.sh gemm_bilinear 1 1 $verify 1 0 1 2>&1 | tee -a $gemm_bilinear_log @@ -91,21 +94,21 @@ print_log_header $gemm_bilinear_log $env_type $branch $host_name ./profile_gemm_bilinear.sh gemm_bilinear 1 3 $verify 1 0 1 2>&1 | tee -a $gemm_bilinear_log #run grouped_fwd tests -export grouped_conv_fwd_log="perf_grouped_conv_fwd.log" +export grouped_conv_fwd_log="perf_grouped_conv_fwd_$arch.log" print_log_header $grouped_conv_fwd_log $env_type $branch $host_name ./profile_grouped_conv_fwd.sh grouped_conv_fwd 0 1 0 $verify 1 0 1 256 2>&1 | tee -a $grouped_conv_fwd_log ./profile_grouped_conv_fwd.sh grouped_conv_fwd 1 1 0 $verify 1 0 1 256 2>&1 | tee -a $grouped_conv_fwd_log ./profile_grouped_conv_fwd.sh grouped_conv_fwd 2 1 0 $verify 1 0 1 256 2>&1 | tee -a $grouped_conv_fwd_log #run grouped_bwd_data tests -export grouped_conv_bwd_data_log="perf_grouped_conv_bwd_data.log" +export grouped_conv_bwd_data_log="perf_grouped_conv_bwd_data_$arch.log" print_log_header $grouped_conv_bwd_data_log $env_type $branch $host_name ./profile_grouped_conv_bwd_data.sh grouped_conv_bwd_data 0 1 $verify 1 0 1 256 2>&1 | tee -a $grouped_conv_bwd_data_log ./profile_grouped_conv_bwd_data.sh grouped_conv_bwd_data 1 1 $verify 1 0 1 256 2>&1 | tee -a $grouped_conv_bwd_data_log ./profile_grouped_conv_bwd_data.sh grouped_conv_bwd_data 2 1 $verify 1 0 1 256 2>&1 | tee -a $grouped_conv_bwd_data_log #run grouped_bwd_weight tests -export grouped_conv_bwd_weight_log="perf_grouped_conv_bwd_weight.log" +export grouped_conv_bwd_weight_log="perf_grouped_conv_bwd_weight_$arch.log" print_log_header 
$grouped_conv_bwd_weight_log $env_type $branch $host_name ./profile_grouped_conv_bwd_weight.sh grouped_conv_bwd_weight 0 2 $verify 1 0 1 256 1 2>&1 | tee -a $grouped_conv_bwd_weight_log ./profile_grouped_conv_bwd_weight.sh grouped_conv_bwd_weight 1 2 $verify 1 0 1 256 1 2>&1 | tee -a $grouped_conv_bwd_weight_log @@ -113,21 +116,21 @@ print_log_header $grouped_conv_bwd_weight_log $env_type $branch $host_name ./profile_grouped_conv_bwd_weight.sh grouped_conv_bwd_weight 1 2 $verify 1 0 1 256 4 2>&1 | tee -a $grouped_conv_bwd_weight_log #run resnet50 tests -export resnet256_log="perf_resnet50_N256.log" +export resnet256_log="perf_resnet50_N256_$arch.log" print_log_header $resnet256_log $env_type $branch $host_name ./profile_resnet50.sh conv_fwd_bias_relu 1 1 1 1 $verify 1 0 1 256 2>&1 | tee -a $resnet256_log -export resnet4_log="perf_resnet50_N4.log" +export resnet4_log="perf_resnet50_N4_$arch.log" print_log_header $resnet4_log $env_type $branch $host_name ./profile_resnet50.sh conv_fwd_bias_relu 1 1 1 1 $verify 1 0 1 4 2>&1 | tee -a $resnet4_log #run reduction tests -export reduction_log="perf_reduction.log" +export reduction_log="perf_reduction_$arch.log" print_log_header $reduction_log $env_type $branch $host_name ./profile_reduce_with_index.sh $verify 2 10 --half 2>&1 | tee -a $reduction_log ./profile_reduce_no_index.sh $verify 2 10 --half 2>&1 | tee -a $reduction_log #run splitK_gemm tests, first correctness verification, then performance -export splitK_gemm_log="perf_splitK_gemm.log" +export splitK_gemm_log="perf_splitK_gemm_$arch.log" print_log_header $splitK_gemm_log $env_type $branch $host_name ./profile_splitK_gemm.sh gemm_splitk 0 0 $verify 1 0 1 4 2>&1 | tee -a $splitK_gemm_log ./profile_splitK_gemm.sh gemm_splitk 0 1 $verify 1 0 1 4 2>&1 | tee -a $splitK_gemm_log @@ -139,13 +142,13 @@ print_log_header $splitK_gemm_log $env_type $branch $host_name ./profile_splitK_gemm.sh gemm_splitk 1 3 $verify 1 0 1 4 2>&1 | tee -a $splitK_gemm_log #run ONNX gemm tests 
-export onnx_log="perf_onnx_gemm.log" +export onnx_log="perf_onnx_gemm_$arch.log" print_log_header $onnx_log $env_type $branch $host_name ./profile_onnx_gemm.sh gemm 0 0 $verify 1 0 1 2>&1 | tee -a $onnx_log ./profile_onnx_gemm.sh gemm 1 0 $verify 1 0 1 2>&1 | tee -a $onnx_log #run mixed fp16/fp8 and fp8/fp16 gemm tests -export mixed_gemm_log="perf_mixed_gemm.log" +export mixed_gemm_log="perf_mixed_gemm_$arch.log" print_log_header $mixed_gemm_log $env_type $branch $host_name ./profile_mixed_gemm.sh gemm_splitk 4 0 $verify 2 0 1 16 2>&1 | tee -a $mixed_gemm_log ./profile_mixed_gemm.sh gemm_splitk 5 0 $verify 2 0 1 16 2>&1 | tee -a $mixed_gemm_log \ No newline at end of file diff --git a/script/run_performance_tests.sh b/script/run_performance_tests.sh index c8a281dc07..4e13b59d34 100755 --- a/script/run_performance_tests.sh +++ b/script/run_performance_tests.sh @@ -18,6 +18,8 @@ export branch=$3 echo 'Branch name: ' $branch export host_name=$4 echo 'Host name: ' $host_name +export arch=$5 +echo 'GPU architecture: ' $arch function print_log_header(){ rm -f $1; @@ -32,7 +34,7 @@ function print_log_header(){ } #run gemm tests -export gemm_log="perf_gemm.log" +export gemm_log="perf_gemm_$arch.log" print_log_header $gemm_log $env_type $branch $host_name ./profile_gemm.sh gemm 0 0 $verify 1 0 1 | tee -a $gemm_log ./profile_gemm.sh gemm 1 0 $verify 1 0 1 | tee -a $gemm_log @@ -52,15 +54,15 @@ print_log_header $gemm_log $env_type $branch $host_name ./profile_gemm.sh gemm 3 3 $verify 1 0 1 | tee -a $gemm_log #run ONNX gemm tests -export onnx_log="perf_onnx_gemm.log" +export onnx_log="perf_onnx_gemm_$arch.log" print_log_header $onnx_log $env_type $branch $host_name ./profile_onnx_gemm.sh gemm 0 0 $verify 1 0 1 2>&1 | tee -a $onnx_log ./profile_onnx_gemm.sh gemm 1 0 $verify 1 0 1 2>&1 | tee -a $onnx_log #run resnet50 tests -export resnet256_log="perf_resnet50_N256.log" +export resnet256_log="perf_resnet50_N256_$arch.log" print_log_header $resnet256_log $env_type $branch 
$host_name ./profile_resnet50.sh conv_fwd_bias_relu 1 1 1 1 $verify 1 0 1 256 | tee -a $resnet256_log -export resnet4_log="perf_resnet50_N4.log" +export resnet4_log="perf_resnet50_N4_$arch.log" print_log_header $resnet4_log $env_type $branch $host_name ./profile_resnet50.sh conv_fwd_bias_relu 1 1 1 1 $verify 1 0 1 4 | tee -a $resnet4_log From bb6132116fa55c3e7434a95a665f29629329f50e Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Mon, 25 Aug 2025 13:48:51 -0400 Subject: [PATCH 22/46] build!: Update composable kernel version to 1.2.0 for rocm 7.0 release (#2734) * build!: Update composable kernel version to 1.2.0 for rocm 7.0 release --- CHANGELOG.md | 2 +- CMakeLists.txt | 2 +- Jenkinsfile | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1246248eac..76fb46cdd9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/projects/composable_kernel/en/latest/](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/). 
-## Composable Kernel 1.1.0 for ROCm 7.0.0 +## Composable Kernel 1.2.0 for ROCm 7.0.0 ### Added diff --git a/CMakeLists.txt b/CMakeLists.txt index 35ebba8085..f77a41371f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -21,7 +21,7 @@ if(NOT WIN32) set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "") endif() -set(version 1.1.0) +set(version 1.2.0) # Check support for CUDA/HIP in Cmake project(composable_kernel VERSION ${version} LANGUAGES CXX HIP) include(CTest) diff --git a/Jenkinsfile b/Jenkinsfile index 6c79acb14b..8f5c724776 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -400,8 +400,8 @@ def cmake_build(Map conf=[:]){ echo "Build packages" sh 'ninja -j64 package' archiveArtifacts artifacts: 'composablekernel-dev*.deb' - sh 'mv composablekernel-dev_*.deb composablekernel-dev_all_targets_1.1.0_amd64.deb' - sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.1.0_amd64.deb' + sh 'mv composablekernel-dev_*.deb composablekernel-dev_all_targets_1.2.0_amd64.deb' + sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64.deb' stash includes: "composablekernel-**.deb", name: "packages" } } From 1d4a3341f088534b441127681efd88b9c584fad6 Mon Sep 17 00:00:00 2001 From: spolifroni-amd Date: Mon, 25 Aug 2025 14:16:57 -0400 Subject: [PATCH 23/46] removed the blog posts as as these are broken links (#2732) --- docs/Contributors_Guide.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/Contributors_Guide.rst b/docs/Contributors_Guide.rst index 3788ba609c..1b978ed63e 100644 --- a/docs/Contributors_Guide.rst +++ b/docs/Contributors_Guide.rst @@ -19,7 +19,6 @@ Getting started build the library. You can also find some of this information in the `README file `_ on the project's GitHub page. -#. **Additional reading:** The blog post `AMD Composable Kernel library: efficient fused kernels for AI apps with just a few lines of code `_ provides a deeper understanding of the CK library and showcases its performance capabilities. 
`_ from the AMD Community portal. It offers a deeper understanding of the library's objectives and showcases its performance capabilities. #. **General information:** For broader information about AMD products, consider exploring the From e9605ed36db7948491d21911267127823351991d Mon Sep 17 00:00:00 2001 From: Tianyuan Wu Date: Tue, 26 Aug 2025 03:55:35 +0800 Subject: [PATCH 24/46] [CK_TILE] Fix the Wrong Output Generated by Gemm Examples on GFX11/12 (#2713) * Introduce macro CK_TILE_USE_WMMA Signed-off-by: Tianyuan Wu * Make CK_TILE_USE_WMMA global for all examples Signed-off-by: Tianyuan Wu * Remove CK_TILE_USE_WMMA from config.hpp Signed-off-by: Tianyuan Wu --------- Signed-off-by: Tianyuan Wu --- CMakeLists.txt | 13 ++----------- example/ck_tile/03_gemm/gemm_basic.cpp | 10 ++++++++++ example/ck_tile/03_gemm/gemm_utils.hpp | 2 ++ example/ck_tile/03_gemm/universal_gemm.cpp | 4 ++++ 4 files changed, 18 insertions(+), 11 deletions(-) mode change 100755 => 100644 example/ck_tile/03_gemm/gemm_utils.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index f77a41371f..f148f31d25 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -225,6 +225,8 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1 message(STATUS "Enabling WMMA instances") add_definitions(-DCK_USE_WMMA) set(CK_USE_WMMA "ON") + add_definitions(-DCK_TILE_USE_WMMA) + set(CK_TILE_USE_WMMA "ON") endif() if (SUPPORTED_GPU_TARGETS MATCHES "gfx12") message(STATUS "Enabling WMMA FP8 gemms on native architectures") @@ -324,23 +326,12 @@ if(USE_BITINT_EXTENSION_INT4) message(STATUS "CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}") endif() -if(USE_OPT_GFX11) - add_compile_options(-mcumode) - add_compile_options(-mno-wavefrontsize64) - message(STATUS "CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}") -endif() - if(ENABLE_ASM_DUMP) add_compile_options(--save-temps) add_compile_options(-Wno-gnu-line-marker) message("CK compiled with ENABLE_ASM_DUMP set 
to ${ENABLE_ASM_DUMP}") endif() -if(USE_OPT_GFX12 AND (SUPPORTED_GPU_TARGETS MATCHES "gfx12")) - add_compile_options(-mno-wavefrontsize64) - message(STATUS "CK compiled with USE_OPT_GFX12 set to ${USE_OPT_GFX12}") -endif() - ## Threads set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 8cdbe39e86..99c943a7f1 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -26,6 +26,15 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) constexpr ck_tile::index_t N_Tile = 256; constexpr ck_tile::index_t K_Tile = 64; +#if CK_TILE_USE_WMMA + constexpr ck_tile::index_t M_Warp = 4; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 16; + constexpr ck_tile::index_t N_Warp_Tile = 16; + constexpr ck_tile::index_t K_Warp_Tile = 16; +#else constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; constexpr ck_tile::index_t K_Warp = 1; @@ -33,6 +42,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t K_Warp_Tile = 16; +#endif using CodegenGemmShape = ck_tile::TileGemmShape, diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp old mode 100755 new mode 100644 index eb0a6de8aa..ed2006d4b9 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -172,6 +172,7 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr int kBlockPerCu = 2; }; +#if CK_TILE_USE_WMMA template struct GemmConfigComputeV3_WMMA : public GemmConfigBase { @@ -192,6 +193,7 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase static constexpr int kBlockPerCu = 2; }; +#endif template struct 
GemmConfigComputeV4 : public GemmConfigBase diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index 4e01710b4d..b80d9991d4 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -335,7 +335,11 @@ int main(int argc, char* argv[]) try { +#if CK_TILE_USE_WMMA + return !run_gemm_example(arg_parser); +#else return !run_gemm_example(arg_parser); +#endif } catch(const std::runtime_error& e) { From c88e24ebe5e929b62373cc83106bd89879a6a915 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Mon, 25 Aug 2025 21:53:40 -0400 Subject: [PATCH 25/46] fix(gemm_universal): define CK_TILE_USE_WMMA with default value to stop compilation error (#2737) --- CMakeLists.txt | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f148f31d25..39eb815680 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -221,13 +221,20 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9 add_definitions(-DCK_USE_GFX94) set(CK_USE_GFX94 "ON") endif() + +# new macro CK_TILE_USE_WMMA in order to separately compile examples for MFMA/WMMA +set(CK_TILE_USE_WMMA 0) + if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") message(STATUS "Enabling WMMA instances") add_definitions(-DCK_USE_WMMA) set(CK_USE_WMMA "ON") - add_definitions(-DCK_TILE_USE_WMMA) - set(CK_TILE_USE_WMMA "ON") + set(CK_TILE_USE_WMMA 1) endif() + +# define the macro with the current value (0 or 1) +add_definitions(-DCK_TILE_USE_WMMA=${CK_TILE_USE_WMMA}) + if (SUPPORTED_GPU_TARGETS MATCHES "gfx12") message(STATUS "Enabling WMMA FP8 gemms on native architectures") add_definitions(-DCK_USE_WMMA_FP8) From 99d27aca17f19f4cfed938c055917c4d27d2507e Mon Sep 17 00:00:00 2001 From: John Shumway Date: Mon, 25 Aug 2025 18:56:58 -0700 Subject: [PATCH 26/46] Add a CMake property for c++ standard (17 or 20) (#2736) Configure C++ standard with 
a CMake variable. Defaults to C++20, but can be set to C++17 to test backwards compatibility. * Add validation for allowed C++ standards. * build CK in rehl8 docker with std=c++17 --------- Co-authored-by: illsilin_amdeng --- CMakeLists.txt | 11 ++++++++++- Jenkinsfile | 1 + 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 39eb815680..52bb2ccd2d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,6 +16,15 @@ else() "Choose the type of build, options are: None Debug Release RelWithDebInfo MinSizeRel.") endif() +# Allow user to specify the C++ standard. +# We must support C++17 builds until downstream users are migrated to C++20, but we default to C++20. +set(CK_CXX_STANDARD "20" CACHE STRING "C++ standard to use (e.g. 17 or 20)") +set(valid_cxx_standards 17 20) +set_property(CACHE CK_CXX_STANDARD PROPERTY STRINGS ${valid_cxx_standards}) +if(NOT CK_CXX_STANDARD IN_LIST valid_cxx_standards) + message(FATAL_ERROR "CK_CXX_STANDARD must be one of ${valid_cxx_standards}") +endif() + # Default installation path if(NOT WIN32) set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "") @@ -345,7 +354,7 @@ find_package(Threads REQUIRED) link_libraries(Threads::Threads) ## C++ -set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD ${CK_CXX_STANDARD}) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) message(STATUS "CMAKE_CXX_COMPILER: ${CMAKE_CXX_COMPILER}") diff --git a/Jenkinsfile b/Jenkinsfile index 8f5c724776..d590c01ba7 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1358,6 +1358,7 @@ pipeline { def docker_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_rhel8_rocm6.3" setup_args = """ -DGPU_TARGETS="gfx942" \ -DCMAKE_CXX_FLAGS=" -O3 " \ + -DCK_CXX_STANDARD="17" \ -DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 """ execute_args = " " } From d43228fbca5d903a032afee1487a089a83858b1b Mon Sep 17 00:00:00 2001 From: Mateusz Ozga <110818320+mozga-amd@users.noreply.github.com> Date: Tue, 26 Aug 2025 04:29:35 +0200 
Subject: [PATCH 27/46] [CK-TILE] Default epilogue, adding support for D (#2629) * Extend 2d-epilogue, D support * Added tests & update * Remove unused attribute * Extend tests --------- Co-authored-by: Thomas Ning --- .../ops/epilogue/default_2d_epilogue.hpp | 120 +++++-- .../ops/gemm/kernel/gemm_multi_d_kernel.hpp | 6 + test/ck_tile/gemm_multi_d/CMakeLists.txt | 6 +- ...i_d.cpp => test_gemm_multi_d_cshuffle.cpp} | 27 +- .../test_gemm_multi_d_default2d.cpp | 43 +++ .../test_gemm_multi_d_ut_cases.inc | 334 ------------------ .../test_gemm_multi_d_ut_cases_cshuffle.inc | 211 +++++++++++ .../test_gemm_multi_d_ut_cases_default2d.inc | 211 +++++++++++ .../gemm_multi_d/test_gemm_multi_d_util.hpp | 89 ++--- tile_engine/ops/gemm/codegen_utils.py | 5 + 10 files changed, 624 insertions(+), 428 deletions(-) rename test/ck_tile/gemm_multi_d/{test_gemm_multi_d.cpp => test_gemm_multi_d_cshuffle.cpp} (75%) create mode 100644 test/ck_tile/gemm_multi_d/test_gemm_multi_d_default2d.cpp delete mode 100644 test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases.inc create mode 100644 test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases_cshuffle.inc create mode 100644 test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases_default2d.inc diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index 8a0970f494..401f90f78f 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -29,9 +29,14 @@ struct Default2DEpilogueProblem template ; using BDataType = remove_cvref_t; using CLayout = remove_cvref_t; + using DsDataType = remove_cvref_t; + using DsLayout = remove_cvref_t; + using CDElementwise = remove_cvref_t; + static constexpr index_t kMPerBlock = kM_; + static constexpr index_t kNPerBlock = kN_; static constexpr index_t kMPerXdl = kMPerXdl_; static constexpr index_t kNPerXdl = kNPerXdl_; static constexpr index_t kKPerXdl = kKPerXdl_; static constexpr index_t 
isCTransposed = isCTransposed_; + + static constexpr index_t NumDTensor = DsDataType::size(); + + static_assert(NumDTensor == DsLayout::size(), + "The size of DsDataType and DsLayout should be the same"); }; template @@ -62,6 +77,7 @@ struct Default2DEpilogue using Problem = remove_cvref_t; using AccDataType = remove_cvref_t; using ODataType = remove_cvref_t; + using CDElementwise = remove_cvref_t; static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadN = Problem::kPadN; static constexpr bool UseRawStore = Problem::UseRawStore; @@ -71,43 +87,70 @@ struct Default2DEpilogue // TODO: this function assume store out vector size is the same as OAccTile last dimension size // how do we fix this ? - template - CK_TILE_DEVICE auto - operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr) const - { - // TODO: this is ugly - if constexpr(UseRawStore && (kPadM || kPadN)) - { - if constexpr(MemoryOperation == memory_operation_enum::set) - { - store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); - } - else - { - update_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); - } - buffer_store_fence(); - } - else - { - if constexpr(MemoryOperation == memory_operation_enum::set) - { - store_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); - } - else - { - update_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); - } - } - } - template CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, - const DsDramWindows& /* unused */, - void* = nullptr) const + const DsDramWindows& ds_dram_windows, + void* = nullptr) { - return operator()(o_dram_window_tmp, o_acc_tile); + const auto storeOrUpdateTile = [&](const auto& o_tile) { + // TODO: this is ugly + if constexpr(UseRawStore && (kPadM || kPadN)) + { + if constexpr(MemoryOperation == memory_operation_enum::set) + { + store_tile_raw(o_dram_window_tmp, cast_tile(o_tile)); + } + else + { + update_tile_raw(o_dram_window_tmp, cast_tile(o_tile)); + } + 
buffer_store_fence(); + } + else + { + if constexpr(MemoryOperation == memory_operation_enum::set) + { + store_tile(o_dram_window_tmp, cast_tile(o_tile)); + } + else + { + update_tile(o_dram_window_tmp, cast_tile(o_tile)); + } + } + }; + + if constexpr(Problem::NumDTensor >= 1) + { + using elementwise_result_t = decltype(load_tile( + make_tile_window(ds_dram_windows[number<0>{}].get_bottom_tensor_view(), + make_tuple(Problem::kMPerBlock, Problem::kNPerBlock), + ds_dram_windows[number<0>{}].get_window_origin(), + o_acc_tile.get_tile_distribution()))); + + elementwise_result_t elementwise_result; + + const auto d_tensor_tuple = generate_tuple( + [&](auto idx) { + const auto d_tile_window = + make_tile_window(ds_dram_windows[idx], o_acc_tile.get_tile_distribution()); + return load_tile(d_tile_window); + }, + number{}); + + const auto c_d_tuple = concat_tuple_of_reference( + tie(elementwise_result, o_acc_tile), + generate_tie([&](auto idx) -> const auto& { return d_tensor_tuple[idx]; }, + number{})); + + tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_d_tuple); + + storeOrUpdateTile(elementwise_result); + } + else + { + storeOrUpdateTile(o_acc_tile); + } } }; @@ -122,8 +165,9 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue // Used for weight-only quantization kernel, B would be dequantized to the same data type as A using BTypeToUse = std::conditional_t, ADataType, BDataType>; - using DsDataType = ck_tile::tuple<>; - using DsLayout = ck_tile::tuple<>; + using DsDataType = remove_cvref_t; + using DsLayout = remove_cvref_t; + using CDElementwise = remove_cvref_t; using CLayout = remove_cvref_t; static constexpr index_t kMPerXdl = Problem::kMPerXdl; static constexpr index_t kNPerXdl = Problem::kNPerXdl; @@ -192,7 +236,11 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue } } - CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD() { return 1; } + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD([[maybe_unused]] 
number index) + { + return GetVectorSizeC(); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp index 34c4e72b22..9d3ac8b901 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp @@ -175,6 +175,12 @@ struct GemmKernelMultiD CK_TILE_HOST static auto IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool { + // Currently MultiD kernel doesn't support k_batch > 1 + if(kargs.k_batch > 1) + { + return false; + } + return UniversalGemmKernel::IsSupportedArgument(kargs); } diff --git a/test/ck_tile/gemm_multi_d/CMakeLists.txt b/test/ck_tile/gemm_multi_d/CMakeLists.txt index a50de7178b..c9d53e53e2 100644 --- a/test/ck_tile/gemm_multi_d/CMakeLists.txt +++ b/test/ck_tile/gemm_multi_d/CMakeLists.txt @@ -5,6 +5,8 @@ if(CK_USE_OCP_FP8) endif() if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") - add_gtest_executable(test_ck_tile_gemm_multi_d test_gemm_multi_d.cpp) - target_compile_definitions(test_ck_tile_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_gemm_multi_d_cshuffle test_gemm_multi_d_cshuffle.cpp) + add_gtest_executable(test_gemm_multi_d_default2d test_gemm_multi_d_default2d.cpp) + target_compile_definitions(test_gemm_multi_d_cshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_definitions(test_gemm_multi_d_default2d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d.cpp b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_cshuffle.cpp similarity index 75% rename from test/ck_tile/gemm_multi_d/test_gemm_multi_d.cpp rename to test/ck_tile/gemm_multi_d/test_gemm_multi_d_cshuffle.cpp index a634d825b7..8ac847e888 100644 --- a/test/ck_tile/gemm_multi_d/test_gemm_multi_d.cpp +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_cshuffle.cpp @@ -18,22 +18,23 @@ using 
Col = ck_tile::tensor_layout::gemm::ColumnMajor; // clang-format off using KernelTypes = ::testing::Types< - // ALayout, BLayout, CLayout, D0Layout, D1Layout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, CDataType, CDElementWiseFn - std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, F16, ElementWiseAddAdd>, - std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd>, - std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd>, - std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, ElementWiseAddAdd>, - std::tuple< Row, Col, Row, Row, Row, F8, F8, F8, F8, F32, F16, ElementWiseAddAdd>, + // Has cshuffle epilogue enabled + // ALayout, BLayout, CLayout, D0Layout, D1Layout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, EDataType, CDElementWiseFn, UseCshuffleEpilog + std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, F16, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, F8, F8, F32, F16, ElementWiseAddAdd, std::true_type>, - std::tuple< Row, Col, Row, Row, Row, F16, F16, F16, F16, F32, F16, MultiplyMultiply>, - std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, F32, MultiplyMultiply>, - std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F32, MultiplyMultiply>, - std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, MultiplyMultiply>, - std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, MultiplyMultiply>, - std::tuple< Row, Col, Row, Row, Row, F8, F8, F8, F8, F32, F32, MultiplyMultiply> + std::tuple< Row, Col, Row, Row, Row, F16, F16, F16, F16, F32, F16, 
MultiplyMultiply, std::true_type>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, F32, MultiplyMultiply, std::true_type>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F32, MultiplyMultiply, std::true_type>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, MultiplyMultiply, std::true_type>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, MultiplyMultiply, std::true_type>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, F8, F8, F32, F32, MultiplyMultiply, std::true_type> >; // clang-format on TYPED_TEST_SUITE(TestCkTileGemmMultiD, KernelTypes); -#include "test_gemm_multi_d_ut_cases.inc" +#include "test_gemm_multi_d_ut_cases_cshuffle.inc" diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_default2d.cpp b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_default2d.cpp new file mode 100644 index 0000000000..4f14cc49f9 --- /dev/null +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_default2d.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_gemm_multi_d_util.hpp" + +using F16 = ck_tile::half_t; +using BF16 = ck_tile::bf16_t; +using F32 = float; +using F8 = ck_tile::fp8_t; + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// clang-format off +using KernelTypes = ::testing::Types< + // Has cshuffle epilogue disabled + // ALayout, BLayout, CLayout, D0Layout, D1Layout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, EDataType, CDElementWiseFn, UseCshuffleEpilog + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F32, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, BF16, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, BF16, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, F16, F16, F32, F16, ElementWiseAddAdd, std::false_type>, + + std::tuple< Row, Col, Row, Row, Row, F16, F16, F16, F16, F32, F16, MultiplyMultiply, std::false_type>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, MultiplyMultiply, std::false_type>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, MultiplyMultiply, std::false_type>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F32, MultiplyMultiply, std::false_type>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, BF16, MultiplyMultiply, std::false_type>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, BF16, MultiplyMultiply, std::false_type>, + std::tuple< Row, Col, Row, 
Row, Row, F8, F8, F16, F16, F32, F16, MultiplyMultiply, std::false_type> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGemmMultiD, KernelTypes); + +#include "test_gemm_multi_d_ut_cases_default2d.inc" diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases.inc b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases.inc deleted file mode 100644 index 22d887fa83..0000000000 --- a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases.inc +++ /dev/null @@ -1,334 +0,0 @@ -#pragma once - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_256x512x256) -{ - constexpr int M = 256; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x256x256) -{ - constexpr int M = 512; - constexpr int N = 256; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x512x256) -{ - constexpr int M = 512; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_256x256x256) -{ - constexpr int M = 256; - constexpr int N = 256; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x768x256) -{ - constexpr int M = 512; - constexpr int N = 768; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x1280x256) -{ - constexpr int M = 512; - constexpr int N = 1280; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_256x1280x256) -{ - constexpr int M = 256; - constexpr int N = 1280; - constexpr int K = 256; - constexpr int kBatch 
= 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_768x512x256) -{ - constexpr int M = 768; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_1280x512x256) -{ - constexpr int M = 1280; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_1280x256x256) -{ - constexpr int M = 1280; - constexpr int N = 256; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_256x512x256) -{ - constexpr int M = 256; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x256x256) -{ - constexpr int M = 512; - constexpr int N = 256; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x512x256) -{ - constexpr int M = 512; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_256x256x256) -{ - constexpr int M = 256; - constexpr int N = 256; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x768x256) -{ - constexpr int M = 512; - constexpr int N = 768; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x1280x256) -{ - constexpr int M = 512; - constexpr int N 
= 1280; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_256x1280x256) -{ - constexpr int M = 256; - constexpr int N = 1280; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_768x512x256) -{ - constexpr int M = 768; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_1280x512x256) -{ - constexpr int M = 1280; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_1280x256x256) -{ - constexpr int M = 1280; - constexpr int N = 256; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_256x256x512) -{ - constexpr int M = 256; - constexpr int N = 256; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_512x768x512) -{ - constexpr int M = 512; - constexpr int N = 768; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_512x1280x512) -{ - constexpr int M = 512; - constexpr int N = 1280; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_256x1280x512) -{ - constexpr int M = 256; - constexpr int N = 1280; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_768x512x512) -{ - 
constexpr int M = 768; - constexpr int N = 512; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_1280x512x512) -{ - constexpr int M = 1280; - constexpr int N = 512; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_1280x256x512) -{ - constexpr int M = 1280; - constexpr int N = 256; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_256x512x512) -{ - constexpr int M = 256; - constexpr int N = 512; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x256x512) -{ - constexpr int M = 512; - constexpr int N = 256; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x512x512) -{ - constexpr int M = 512; - constexpr int N = 512; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_256x256x512) -{ - constexpr int M = 256; - constexpr int N = 256; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x768x512) -{ - constexpr int M = 512; - constexpr int N = 768; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x1280x512) -{ - constexpr int M = 512; - constexpr int N = 1280; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - 
-TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_256x1280x512) -{ - constexpr int M = 256; - constexpr int N = 1280; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_768x512x512) -{ - constexpr int M = 768; - constexpr int N = 512; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_1280x512x512) -{ - constexpr int M = 1280; - constexpr int N = 512; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_1280x256x512) -{ - constexpr int M = 1280; - constexpr int N = 256; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases_cshuffle.inc b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases_cshuffle.inc new file mode 100644 index 0000000000..8d21c65692 --- /dev/null +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases_cshuffle.inc @@ -0,0 +1,211 @@ +#pragma once + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_256x512x256) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_512x256x256) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_512x512x256) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + 
+TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_512x768x256) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_1280x256x256) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_512x512x512) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + 
+TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_256x512x256) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_512x256x256) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_512x512x256) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_512x768x256) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 512; + constexpr 
int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_1280x256x256) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases_default2d.inc b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases_default2d.inc new file mode 100644 index 0000000000..35b40a896a --- /dev/null +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases_default2d.inc @@ -0,0 +1,211 @@ +#pragma once + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_256x512x256) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_512x256x256) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_512x512x256) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_512x768x256) +{ + constexpr int M 
= 512; + constexpr int N = 768; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_1280x256x256) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_512x512x512) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_256x512x256) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_512x256x256) +{ + constexpr int M = 512; + constexpr 
int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_512x512x256) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_512x768x256) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, 
TestCkTileGemmMultiDKBatch2Default_1280x256x256) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp index d21777c92b..8399bc7ee3 100644 --- a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp @@ -70,20 +70,21 @@ template class TestCkTileGemmMultiD : public ::testing::Test { protected: - using ALayout = std::tuple_element_t<0, Tuple>; - using BLayout = std::tuple_element_t<1, Tuple>; - using D0Layout = std::tuple_element_t<2, Tuple>; - using D1Layout = std::tuple_element_t<3, Tuple>; - using ELayout = std::tuple_element_t<4, Tuple>; - using ADataType = std::tuple_element_t<5, Tuple>; - using BDataType = std::tuple_element_t<6, Tuple>; - using D0DataType = std::tuple_element_t<7, Tuple>; - using D1DataType = std::tuple_element_t<8, Tuple>; - using AccDataType = std::tuple_element_t<9, Tuple>; - using EDataType = std::tuple_element_t<10, Tuple>; - using CDElementWiseFn = std::tuple_element_t<11, Tuple>; - using DsLayout = ck_tile::tuple; - using DsDataType = ck_tile::tuple; + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using D0Layout = std::tuple_element_t<2, Tuple>; + using D1Layout = std::tuple_element_t<3, Tuple>; + using ELayout = std::tuple_element_t<4, Tuple>; + using ADataType = std::tuple_element_t<5, Tuple>; + using BDataType = std::tuple_element_t<6, Tuple>; + using D0DataType = std::tuple_element_t<7, Tuple>; + using D1DataType = std::tuple_element_t<8, Tuple>; + using AccDataType = std::tuple_element_t<9, Tuple>; + using EDataType = std::tuple_element_t<10, Tuple>; + using CDElementWiseFn = std::tuple_element_t<11, Tuple>; + using UseCshuffleEpilog = std::tuple_element_t<12, Tuple>; + using DsLayout 
= ck_tile::tuple; + using DsDataType = ck_tile::tuple; template ; using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - using GemmEpilogue = ck_tile::CShuffleEpilogue< + + using DefaultGemmEpilogue = ck_tile::DefaultGemm2DEpilogue< + ck_tile::DefaultGemm2DEpilogueProblem>; + + using CShuffleGemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = std:: + conditional_t; + using Kernel = ck_tile::GemmKernelMultiD; auto kargs = Kernel::MakeKernelArgs(args); @@ -218,6 +243,7 @@ class TestCkTileGemmMultiD : public ::testing::Test const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) { + std::cout << "Run without SplitK" << std::endl; Run(has_hot_loop_, tail_number_, ck_tile::integral_constant{}); } }; - if(has_hot_loop) - { - if(tail_num == ck_tile::TailNumber::Full) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - std::ostringstream err; - err << "For compute pipeline tail number should always be Full, but have \"" - << tail_num << "\" which is not supported! PrefetchStages: " - << BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); - } - } - else - { - std::ostringstream err; - err << "Num K loop must be larger than number of prefetech stages." 
- << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages - << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); - } + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } public: - void Run(const int M, + bool Run(const int M, const int N, const int K, const int k_batch, @@ -401,6 +404,6 @@ class TestCkTileGemmMultiD : public ::testing::Test << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl; - EXPECT_TRUE(pass); + return pass; } }; diff --git a/tile_engine/ops/gemm/codegen_utils.py b/tile_engine/ops/gemm/codegen_utils.py index dd9de36865..392125aa0b 100644 --- a/tile_engine/ops/gemm/codegen_utils.py +++ b/tile_engine/ops/gemm/codegen_utils.py @@ -31,9 +31,14 @@ DEFAULT_EPILOGUE = """ using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue< ck_tile::DefaultGemm2DEpilogueProblem, AccDataType, CDataType, + ck_tile::tuple<>, CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, kPadM, kPadN, WarpTileM, From 5e85c38d7d86cad76af7130a7434d5dcccc20898 Mon Sep 17 00:00:00 2001 From: SamiAario-AMD Date: Tue, 26 Aug 2025 13:25:48 +0300 Subject: [PATCH 28/46] Lwpck 3548 gemm test cleanups (#2717) * Remove some unnecessary calls to create_args in basic and universal GEMM tests * Remove unnecessary include statements in universal GEMM tests * Improve compilation time of basic GEMM tests by only compiling the precision variants that we need * Universal GEMM PrecType should be the same as CDataType * Improve compilation time of universal GEMM tests by only compiling the precision variants that we need * Revert to constexpr when defining some constants --- .../gemm/test_gemm_pipeline_basic_bf16.cpp | 2 +- .../gemm/test_gemm_pipeline_basic_bf8.cpp | 2 +- .../gemm/test_gemm_pipeline_basic_fp16.cpp | 2 +- .../gemm/test_gemm_pipeline_basic_fp8.cpp | 2 +- .../test_gemm_pipeline_basic_run_test.inc | 127 ++++++--------- 
.../test_gemm_pipeline_smoke_run_test.inc | 7 +- .../test_gemm_pipeline_universal_bf16.cpp | 9 +- .../gemm/test_gemm_pipeline_universal_bf8.cpp | 9 +- .../test_gemm_pipeline_universal_fp16.cpp | 9 +- .../gemm/test_gemm_pipeline_universal_fp8.cpp | 9 +- .../test_gemm_pipeline_universal_run_test.inc | 148 +++++++----------- 11 files changed, 108 insertions(+), 218 deletions(-) diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_bf16.cpp b/test/ck_tile/gemm/test_gemm_pipeline_basic_bf16.cpp index af2cb398f5..4e3033782c 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_basic_bf16.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_bf16.cpp @@ -2,4 +2,4 @@ // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "test_gemm_pipeline_basic_run_test.inc" -int main() { return run_gemm_combinations("bf16"); } +int main() { return run_gemm_combinations(); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_bf8.cpp b/test/ck_tile/gemm/test_gemm_pipeline_basic_bf8.cpp index fd8c28ef17..61614fc6f5 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_basic_bf8.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_bf8.cpp @@ -2,4 +2,4 @@ // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "test_gemm_pipeline_basic_run_test.inc" -int main() { return run_gemm_combinations("bf8"); } +int main() { return run_gemm_combinations(); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_fp16.cpp b/test/ck_tile/gemm/test_gemm_pipeline_basic_fp16.cpp index 4a93d6046a..c667c08053 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_basic_fp16.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_fp16.cpp @@ -2,4 +2,4 @@ // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "test_gemm_pipeline_basic_run_test.inc" -int main() { return run_gemm_combinations("fp16"); } +int main() { return run_gemm_combinations(); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_fp8.cpp b/test/ck_tile/gemm/test_gemm_pipeline_basic_fp8.cpp index fd8c28ef17..9a3498b7ea 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_basic_fp8.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_fp8.cpp @@ -2,4 +2,4 @@ // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "test_gemm_pipeline_basic_run_test.inc" -int main() { return run_gemm_combinations("bf8"); } +int main() { return run_gemm_combinations(); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc index 53eff9ecc4..1fdf26f01c 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc @@ -131,7 +131,9 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) } template -bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +bool run_gemm_test_prec_type(std::string a_layout, + std::string b_layout, + ck_tile::ArgParser& arg_parser) { using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; @@ -141,12 +143,12 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg if(a_layout == "R" && b_layout == "C") { return run_gemm_test_with_layouts( - argc, argv, Row{}, Col{}, Row{}); + arg_parser, Row{}, Col{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { return run_gemm_test_with_layouts( - argc, argv, Col{}, Col{}, Row{}); + arg_parser, Col{}, Col{}, Row{}); } else { @@ -159,22 +161,22 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg if(a_layout == "R" && b_layout == "C") { return run_gemm_test_with_layouts( - argc, argv, Row{}, Col{}, 
Row{}); + arg_parser, Row{}, Col{}, Row{}); } else if(a_layout == "R" && b_layout == "R") { return run_gemm_test_with_layouts( - argc, argv, Row{}, Row{}, Row{}); + arg_parser, Row{}, Row{}, Row{}); } else if(a_layout == "C" && b_layout == "R") { return run_gemm_test_with_layouts( - argc, argv, Col{}, Row{}, Row{}); + arg_parser, Col{}, Row{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { return run_gemm_test_with_layouts( - argc, argv, Col{}, Col{}, Row{}); + arg_parser, Col{}, Col{}, Row{}); } else { @@ -183,60 +185,26 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg } } +template bool run_gemm_test(int argc, char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); if(!result) return false; - std::string data_type = arg_parser.get_str("prec"); - std::string a_layout = arg_parser.get_str("a_layout"); - std::string b_layout = arg_parser.get_str("b_layout"); + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); - if(data_type == "fp16") - { - return run_gemm_test_prec_type(a_layout, b_layout, argc, argv); - } - else if(data_type == "bf16") - { - return run_gemm_test_prec_type(a_layout, b_layout, argc, argv); - } - else if(data_type == "fp8") - { - return run_gemm_test_prec_type( - a_layout, b_layout, argc, argv); - } - else if(data_type == "bf8") - { - return run_gemm_test_prec_type( - a_layout, b_layout, argc, argv); - } - else if(data_type == "pk_int4_t") - { - // TODO: Add support for bhalf_t ADataType - if constexpr(GemmConfigBase::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) - { - return run_gemm_test_prec_type( - a_layout, b_layout, argc, argv); - } - else - { - throw std::runtime_error("Unsupported data type for this operation !!!"); - } - } - else - { - throw std::runtime_error("Unsupported data type for this operation !!!"); - } + return run_gemm_test_prec_type(a_layout, b_layout, arg_parser); } -int run_gemm_combinations(std::string const& 
data_type) +template +int run_gemm_combinations() { // Define possible values for each parameter - std::vector m_values = {"128", "1024"}; - std::vector n_values = {"128", "2048"}; - std::vector k_values = {"64", "128"}; - std::vector prec_values = {data_type}; + std::vector m_values = {"128", "1024"}; + std::vector n_values = {"128", "2048"}; + std::vector k_values = {"64", "128"}; // We'll store all our arguments as strings first std::vector arg_strings = {"./bin/tile_example_gemm_basic", @@ -246,13 +214,12 @@ int run_gemm_combinations(std::string const& data_type) "-stride_a=0", "-stride_b=0", "-stride_c=0", - "", // prec placeholder "-v=2", "-warmup=0", "-repeat=1"}; // Create an array of const char pointers for argv - constexpr size_t ARG_COUNT = 11; + constexpr size_t ARG_COUNT = 10; constexpr size_t ARG_MAX_LEN = 64; char args[ARG_COUNT][ARG_MAX_LEN]; char* argv[ARG_COUNT]; @@ -271,39 +238,35 @@ int run_gemm_combinations(std::string const& data_type) { arg_strings[3] = "-k=" + k; - for(const auto& prec : prec_values) + // Set up the argv array with pointers to the string data + for(size_t i = 0; i < ARG_COUNT; i++) { - arg_strings[7] = "-prec=" + prec; + strncpy(args[i], arg_strings[i].c_str(), ARG_MAX_LEN); + argv[i] = args[i]; + } - // Set up the argv array with pointers to the string data - for(size_t i = 0; i < ARG_COUNT; i++) - { - strncpy(args[i], arg_strings[i].c_str(), ARG_MAX_LEN); - argv[i] = args[i]; - } + std::cout << "Arguments received: "; + for(size_t i = 1; i < ARG_COUNT; ++i) + { + std::cout << argv[i] << " "; + } + std::cout << std::endl; - std::cout << "Arguments received: "; - for(size_t i = 1; i < ARG_COUNT; ++i) - { - std::cout << argv[i] << " "; - } - std::cout << std::endl; - - // Call the function with the current configuration - try - { - is_success = run_gemm_test(ARG_COUNT, argv) && is_success; - } - catch(const ArgumentsNotSupportedException& e) - { - std::cerr << "Caught ArgumentsNotSupportedException: " << e.what() << '\n'; - 
// ArgumentsNotSupportedException is not an error. Do not change is_success - } - catch(const std::runtime_error& e) - { - std::cerr << "Caught runtime error: " << e.what() << '\n'; - is_success = false; - } + // Call the function with the current configuration + try + { + is_success = run_gemm_test(ARG_COUNT, argv) && + is_success; + } + catch(const ArgumentsNotSupportedException& e) + { + std::cerr << "Caught ArgumentsNotSupportedException: " << e.what() << '\n'; + // ArgumentsNotSupportedException is not an error. Do not change is_success + } + catch(const std::runtime_error& e) + { + std::cerr << "Caught runtime error: " << e.what() << '\n'; + is_success = false; } } } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc index a967b92e7f..ab74e4e7b1 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc @@ -256,16 +256,11 @@ template -bool run_gemm_test_with_layouts(int argc, - char* argv[], +bool run_gemm_test_with_layouts(ck_tile::ArgParser& arg_parser, const ALayout a_layout = ALayout{}, const BLayout b_layout = BLayout{}, [[maybe_unused]] const CLayout c_layout = CLayout{}) { - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - return false; - using AccDataType = typename GemmTypeConfig::AccDataType; ck_tile::index_t M = arg_parser.get_int("m"); diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_bf16.cpp b/test/ck_tile/gemm/test_gemm_pipeline_universal_bf16.cpp index 0673272f5f..1336f6fd70 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_universal_bf16.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_bf16.cpp @@ -1,16 +1,9 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
-#include -#include - -#include -#include -#include - #include "ck_tile/host.hpp" #include "test_gemm_pipeline_smoke_util.hpp" #include "test_gemm_pipeline_smoke_run_test.inc" #include "test_gemm_pipeline_universal_run_test.inc" -int main() { return run_gemm_combinations("bf16"); } +int main() { return run_gemm_combinations(); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_bf8.cpp b/test/ck_tile/gemm/test_gemm_pipeline_universal_bf8.cpp index 70eae12e82..5d55f34b84 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_universal_bf8.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_bf8.cpp @@ -1,16 +1,9 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -#include -#include - -#include -#include -#include - #include "ck_tile/host.hpp" #include "test_gemm_pipeline_smoke_util.hpp" #include "test_gemm_pipeline_smoke_run_test.inc" #include "test_gemm_pipeline_universal_run_test.inc" -int main() { return run_gemm_combinations("bf8"); } +int main() { return run_gemm_combinations(); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_fp16.cpp b/test/ck_tile/gemm/test_gemm_pipeline_universal_fp16.cpp index 8ea192c7f3..0cebbcc721 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_universal_fp16.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_fp16.cpp @@ -1,16 +1,9 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
-#include -#include - -#include -#include -#include - #include "ck_tile/host.hpp" #include "test_gemm_pipeline_smoke_util.hpp" #include "test_gemm_pipeline_smoke_run_test.inc" #include "test_gemm_pipeline_universal_run_test.inc" -int main() { return run_gemm_combinations("fp16"); } +int main() { return run_gemm_combinations(); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_fp8.cpp b/test/ck_tile/gemm/test_gemm_pipeline_universal_fp8.cpp index 20414b4fec..29fb5f87ce 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_universal_fp8.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_fp8.cpp @@ -1,16 +1,9 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -#include -#include - -#include -#include -#include - #include "ck_tile/host.hpp" #include "test_gemm_pipeline_smoke_util.hpp" #include "test_gemm_pipeline_smoke_run_test.inc" #include "test_gemm_pipeline_universal_run_test.inc" -int main() { return run_gemm_combinations("fp8"); } +int main() { return run_gemm_combinations(); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc index adae8dcf92..fd50596f2f 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc @@ -200,7 +200,9 @@ template -bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +bool run_gemm_test_prec_type(std::string a_layout, + std::string b_layout, + ck_tile::ArgParser& arg_parser) { using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; @@ -210,12 +212,12 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg if(a_layout == "R" && b_layout == "C") { return run_gemm_test_with_layouts( - argc, argv, Row{}, Col{}, Row{}); + arg_parser, Row{}, Col{}, Row{}); } else if(a_layout == "C" && 
b_layout == "C") { return run_gemm_test_with_layouts( - argc, argv, Col{}, Col{}, Row{}); + arg_parser, Col{}, Col{}, Row{}); } else { @@ -228,22 +230,22 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg if(a_layout == "R" && b_layout == "R") { return run_gemm_test_with_layouts( - argc, argv, Row{}, Row{}, Row{}); + arg_parser, Row{}, Row{}, Row{}); } else if(a_layout == "R" && b_layout == "C") { return run_gemm_test_with_layouts( - argc, argv, Row{}, Col{}, Row{}); + arg_parser, Row{}, Col{}, Row{}); } else if(a_layout == "C" && b_layout == "R") { return run_gemm_test_with_layouts( - argc, argv, Col{}, Row{}, Row{}); + arg_parser, Col{}, Row{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { return run_gemm_test_with_layouts( - argc, argv, Col{}, Col{}, Row{}); + arg_parser, Col{}, Col{}, Row{}); } else { @@ -252,69 +254,27 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg } } -template