From 7fbc9d6c9770b9aade0121a334b9c38845820570 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Tue, 16 Sep 2025 11:32:38 +0800 Subject: [PATCH 01/28] [CK_TILE] FMHA FAv3 scheduling fine-tuning for performance (#2833) * Re-mapping thread block indices for causal=True kernels * Use more intuitive remap_opt value * Fallback to origin remapping if seqlen_q >= 64K * Use GenericAttentionMask to reduce mask computation * Avoid unnecessary boundary check for IsMasking=false case * Fix wrong kernel entry specifier * Add s_nop to prevent delay wave0-3 * Refine scheduling * Remove unnecessary sched_group_barrier() * Move sched_group_barrier() call to scheduler * Replace inline asm s_setprio with intrinsics * Rephrase comments * Expend some o_acc rescaling insts to avoid SIMD idle * Fix block idx special mapping logic * Tune block index mapping for causal=False cases * Tune block index mapping for causal=True cases * Fix wrong vmcnt() * Remove parameter name * Use boolean option for turn on/off causal mask * Update benchmark_fwd_v3.sh option usages * Add option if compiler support it --- example/ck_tile/01_fmha/CMakeLists.txt | 14 +- .../ck_tile/01_fmha/example_fmha_fwd_v3.cpp | 24 ++-- example/ck_tile/01_fmha/fmha_fwd_v3.hpp | 3 +- example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp | 22 ++- .../01_fmha/script/benchmark_fwd_v3.sh | 24 ++-- .../ck_tile/ops/fmha/block/block_masking.hpp | 41 +++--- .../ops/fmha/kernel/fmha_fwd_v3_kernel.hpp | 107 ++++++++++----- .../pipeline/block_fmha_fwd_v3_pipeline.hpp | 127 +++++++++++++----- 8 files changed, 250 insertions(+), 112 deletions(-) diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 68db468a7c..3d79f2f6d3 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -213,8 +213,20 @@ list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS -Wno-undefined-func-template --save-temps ) -target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS}) +set(EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS) +check_cxx_compiler_flag("-mllvm --amdgpu-disable-packed-fp32=1" HAS_DISABLE_PACKED_FP32) +if(HAS_DISABLE_PACKED_FP32) + list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS + -mllvm --amdgpu-disable-packed-fp32=1 + ) + list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS + -DCK_TILE_DISABLE_PACKED_FP32=1 + ) +endif() + +target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS}) +target_compile_definitions(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS}) # TODO: we have to turn off this global prop, otherwise the progress bar generated # by cmake will print too many files, execvp: /bin/sh: Argument list too long # however, this property may affect global diff --git a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp index d2428e5152..569c98a458 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp @@ -45,18 +45,7 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair(hdim)); - mask = mask_info::decode(args.get_str("mask"), seqlen_q, seqlen_k); + + const auto is_causal = args.get_bool("causal"); + if(is_causal) + { + mask = mask_info::decode("b:-1,0", seqlen_q, seqlen_k); + } + else + { + mask = mask_info::decode("0", seqlen_q, seqlen_k); + } input_layout = args.get_int("iperm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd; output_layout = args.get_int("operm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd; diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp index 5361d27f0f..10cb5149a4 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp @@ -34,7 +34,8 @@ struct fmha_fwd_v3_args index_t window_size_left; index_t window_size_right; - index_t mask_type; + index_t mask_type; // should be 0 for no mask; or 2 for causal mask (window_size_left < 0 and + // window_size_right == 0). const void* q_ptr; index_t stride_q; diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp index d6e4ac4c60..e0fbad39a5 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp @@ -18,6 +18,7 @@ #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "fmha_fwd_v3.hpp" +#include "mask.hpp" #define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \ template <> \ @@ -79,7 +80,7 @@ struct fmha_fwd_v3_kernel_traits -1 // kBlockPerCu >; - using fmha_mask = SimplifiedGenericAttentionMask; + using fmha_mask = GenericAttentionMask; using fmha_pipeline_problem = BlockFmhaFwdV3PipelineProblem::qkvp_dtype, @@ -112,6 +113,22 @@ struct fmha_fwd_v3_kernel_traits template float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_config& config) { + /// NOTICE: This was borrowed from Aiter. Make sure the selected remap_opt setting truly + /// maximizes the kernel's performance. + int remap_opt = 2; + if(args.mask_type != static_cast(mask_enum::no_mask) && + ((args.nhead_q % 8 != 0) || (16384 < args.seqlen_q))) + { + if(65536 <= args.seqlen_q) + { + remap_opt = 0; + } + else + { + remap_opt = 1; + } + } + auto kargs = Kernel::MakeKargs(args.q_ptr, args.k_ptr, args.v_ptr, @@ -140,7 +157,8 @@ float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_confi args.batch_stride_o, args.window_size_left, args.window_size_right, - args.mask_type); + args.mask_type, + remap_opt); dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v); constexpr dim3 blocks = Kernel::BlockSize(); diff --git a/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh b/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh index 9c500edf9d..b847e85398 100755 --- a/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh +++ b/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh @@ -8,22 +8,16 @@ for prec in "fp16" "bf16" ; do for hdim in 128 ; do for perm in 0 ; do -if [ $causal -eq 0 ]; then - mask=0 -else - mask=b:-1,0 -fi - -$EXE -prec=$prec -b=32 -h=16 -s=512 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=16 -h=16 -s=1024 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=8 -h=16 -s=2048 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=4 -h=16 -s=4096 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=2 -h=16 -s=8192 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=1 -h=16 -s=16384 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=32 -h=16 -s=512 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=16 -h=16 -s=1024 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=8 -h=16 -s=2048 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=4 -h=16 -s=4096 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=2 -h=16 -s=8192 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=1 -h=16 -s=16384 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=1 -h=64 -s=16384 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=1 -h=16 -h_k=1 -s=65536 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=1 -h=40 -s=37200 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=1 -h=64 -s=16384 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=1 -h=16 -h_k=1 -s=65536 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=1 -h=40 -s=37200 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID done done diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index f5c12e11d2..2c45945fac 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -203,27 +203,36 @@ struct GenericAttentionMask CK_TILE_HOST_DEVICE constexpr auto IsEdgeTile(index_t i_tile_top, index_t i_tile_left, number, number) const { - if constexpr(IsLocal) + if constexpr(!IsMasking) { - // check top-right corner > x or left-borrom corner < x - index_t i_tile_right = i_tile_left + TileWidth; - index_t i_tile_bottom = i_tile_top + TileHeight; - index_t x_end = min(i_tile_top + x, x_total); - - bool top_right_edge = i_tile_right > (i_tile_top + x); - bool bottom_left_edge = i_tile_bottom > (i_tile_left + y); - bool is_partial_out_of_bound = i_tile_right > x_end; // only consider right-pad for now - - return top_right_edge || bottom_left_edge || is_partial_out_of_bound; + // TODO: no need to check begin + return (i_tile_left + TileWidth) > x_total; } else { - // only need to check top-right corner > x - index_t i_tile_right = i_tile_left + TileWidth; - index_t x_end = min(i_tile_top + x, x_total); + if constexpr(IsLocal) + { + // check top-right corner > x or left-borrom corner < x + index_t i_tile_right = i_tile_left + TileWidth; + index_t i_tile_bottom = i_tile_top + TileHeight; + index_t x_end = min(i_tile_top + x, x_total); - bool top_right_edge = i_tile_right > x_end; - return top_right_edge; + bool top_right_edge = i_tile_right > (i_tile_top + x); + bool bottom_left_edge = i_tile_bottom > (i_tile_left + y); + bool is_partial_out_of_bound = + i_tile_right > x_end; // only consider right-pad for now + + return top_right_edge || bottom_left_edge || is_partial_out_of_bound; + } + else + { + // only need to check top-right corner > x + index_t i_tile_right = i_tile_left + TileWidth; + index_t x_end = min(i_tile_top + x, x_total); + + bool top_right_edge = i_tile_right > x_end; + return top_right_edge; + } } } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp index 87021354aa..c5e5745817 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp @@ -81,6 +81,7 @@ struct FmhaFwdV3Kernel // ck_tile::index_t window_size_left, window_size_right; ck_tile::index_t window_size_left, window_size_right; ck_tile::GenericAttentionMaskEnum mask_type; + ck_tile::index_t remap_opt; }; struct FmhaFwdCommonLSEKargs @@ -143,7 +144,8 @@ struct FmhaFwdV3Kernel ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t mask_type) + ck_tile::index_t mask_type, + ck_tile::index_t remap_opt) { Kargs kargs{{q_ptr, k_ptr, @@ -176,6 +178,7 @@ struct FmhaFwdV3Kernel kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; kargs.mask_type = static_cast(mask_type); + kargs.remap_opt = remap_opt; } if constexpr(kStoreLSE) { @@ -213,7 +216,8 @@ struct FmhaFwdV3Kernel ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t mask_type) + ck_tile::index_t mask_type, + ck_tile::index_t remap_opt) { Kargs kargs{{q_ptr, k_ptr, @@ -245,6 +249,7 @@ struct FmhaFwdV3Kernel kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; kargs.mask_type = static_cast(mask_type); + kargs.remap_opt = remap_opt; } if constexpr(kStoreLSE) { @@ -261,39 +266,81 @@ struct FmhaFwdV3Kernel ck_tile::index_t hdim_v_) { // TODO: this may need tuning - return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * - ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), - nhead_, - batch_size_); - } - - CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs) - { - using namespace ck_tile; - - // const index_t num_tile_m0 = seqlen_q / kM0; - const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); - - const index_t i_block = blockIdx.x; - const index_t i_nhead = blockIdx.y; - const index_t i_batch = blockIdx.z; - - const auto f = [](index_t dividend, index_t divisor) { - index_t quotient = dividend / divisor; - index_t modulus = dividend - quotient * divisor; - return ck_tile::make_tuple(quotient, modulus); - }; - - const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); - if constexpr(kHasMask) { - // assume that num_tile_n1 is always 1 - return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + return dim3(nhead_, + ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), + batch_size_); } else { - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + return dim3(nhead_, + ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), + batch_size_); + } + } + + CK_TILE_DEVICE static constexpr auto + RemapTileIndices(int32_t tg_idx, int32_t tg_idy, int32_t remap_option) + { + if(remap_option < 1) + { + return make_tuple(static_cast(gridDim.x - tg_idx - 1), tg_idy); + } + + int32_t remapped_tg_idx = tg_idx; + int32_t remapped_tg_idy = tg_idy; + + if(remap_option == 2) + { // special remapping + int32_t tmp0 = (remapped_tg_idy & 0x7) * gridDim.x + remapped_tg_idx; + int32_t tmp1 = tmp0 & 0x7; + + remapped_tg_idx = tmp0 >> 3; + remapped_tg_idy = (remapped_tg_idy & 0xfffffff8) + tmp1; + } + else + { // normal remapping + int32_t cus_per_xdim_per_xcc = gridDim.x >> 3; + int32_t tgs_cu_id = remapped_tg_idx >> 3; + + if(tgs_cu_id < cus_per_xdim_per_xcc) + { + int32_t tgs_xcc_id = remapped_tg_idx & 0x7; + int32_t new_tg_idx = tgs_xcc_id * cus_per_xdim_per_xcc + tgs_cu_id; + + remapped_tg_idx = new_tg_idx; + } + } + + return make_tuple(remapped_tg_idx, remapped_tg_idy); + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs&) + { + using namespace ck_tile; + + // const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, + // FmhaPipeline::kN1); + + // assume that num_tile_n1 is always 1 + if constexpr(kHasMask) + { + const index_t i_nhead = blockIdx.x; + const index_t i_block = blockIdx.y; + const index_t i_batch = blockIdx.z; + + return ck_tile::make_tuple(gridDim.y - 1 - i_block, 0, i_nhead, i_batch); + } + else + { + const index_t i_nhead = blockIdx.x; + const index_t i_block = blockIdx.y; + const index_t i_batch = blockIdx.z; + + return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch); } } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp index 20d84116d4..5e2a4e898b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -57,7 +57,11 @@ struct CoreLoopScheduler __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU }); } - else if constexpr(Phase == 1) {} + else if constexpr(Phase == 1) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } else if constexpr(Phase == 2) { #if !CK_TILE_DISABLE_PACKED_FP32 @@ -68,11 +72,19 @@ struct CoreLoopScheduler __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU }); } - else if constexpr(Phase == 3) {} + else if constexpr(Phase == 3) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } } else { - if constexpr(Phase == 0) {} + if constexpr(Phase == 0) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } else if constexpr(Phase == 1) { static_for<0, 8, 1>{}([&](auto) { @@ -81,7 +93,11 @@ struct CoreLoopScheduler __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU }); } - else if constexpr(Phase == 2) {} + else if constexpr(Phase == 2) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } else if constexpr(Phase == 3) { #if !CK_TILE_DISABLE_PACKED_FP32 @@ -115,7 +131,11 @@ struct CoreLoopScheduler __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU }); } - else if constexpr(Phase == 1) {} + else if constexpr(Phase == 1) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } else if constexpr(Phase == 2) { #if !CK_TILE_DISABLE_PACKED_FP32 @@ -126,11 +146,19 @@ struct CoreLoopScheduler __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU }); } - else if constexpr(Phase == 3) {} + else if constexpr(Phase == 3) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } } else { - if constexpr(Phase == 0) {} + if constexpr(Phase == 0) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } else if constexpr(Phase == 1) { static_for<0, 8, 1>{}([&](auto) { @@ -139,7 +167,11 @@ struct CoreLoopScheduler __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU }); } - else if constexpr(Phase == 2) {} + else if constexpr(Phase == 2) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } else if constexpr(Phase == 3) { #if !CK_TILE_DISABLE_PACKED_FP32 @@ -177,6 +209,15 @@ CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs) return result; } +CK_TILE_DEVICE float mul_impl_vv(float lhs, float rhs) +{ + float result; + asm volatile("v_mul_f32_e32 %[result], %[lhs], %[rhs]" + : [result] "=v"(result) + : [lhs] "v"(lhs), [rhs] "v"(rhs)); + return result; +} + CK_TILE_DEVICE fp16x2_t cvt_pk_fp16_f32(float a, float b) { fp16x2_t result; @@ -466,7 +507,7 @@ struct BlockFmhaFwdV3Pipeline statically_indexed_array sp; decltype(gemm_1.MakeCBlockTile()) o_acc; - constexpr index_t fmha_alu_D_reg_cnt = 0; // threshold to decide how many fmha_alu_D_upd() + constexpr index_t fmha_alu_D_reg_cnt = 6; // threshold to decide how many fmha_alu_D_upd() // instructions should we move to fmha_alu1() static_assert(fmha_alu_D_reg_cnt <= o_acc.thread_buf_.size()); @@ -631,8 +672,8 @@ struct BlockFmhaFwdV3Pipeline // K_mem_su_ld_insts = 1 for 32 x 128 // V_mem_su_ld_insts = 1 for 128 x 32 - static constexpr int K_mem_su_ld_insts = 1; - static constexpr int V_mem_su_ld_insts = 1; + constexpr int K_mem_su_ld_insts = k_dram_window.get_num_of_access(); + constexpr int V_mem_su_ld_insts = v_dram_window.get_num_of_access(); auto K_mem_load = [&](auto k_lds_write_idx) { async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); @@ -648,7 +689,6 @@ struct BlockFmhaFwdV3Pipeline auto V_mem_load = [&](auto v_lds_write_idx) { async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); - __builtin_amdgcn_sched_barrier(0); /// FIXME: use the future-predicting method to move the window move_tile_window(v_dram_window, {kK1, 0}); @@ -726,11 +766,12 @@ struct BlockFmhaFwdV3Pipeline #else block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); #endif - // update partial o_acc [0, 2) - static_for<0, ck_tile::min(2, fmha_alu_D_reg_cnt), 1>{}( - [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; }); // l{j} + /// Note: The compiler keeps moving the following instructions elsewhere because 'l' + /// is first consumed later. To anchor them here, we rewrite the final addition in + /// inline assembly to create a dependency, forcing the dependent instructions to + /// be emitted at this point. constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); @@ -739,13 +780,15 @@ struct BlockFmhaFwdV3Pipeline l(i_idx) = detail::add_impl_vv(tmp * l[i_idx], rowsum_p[i_idx]); }); - // update partial o_acc [2, fmha_alu_D_reg_cnt) - static_for<2, ck_tile::max(2, fmha_alu_D_reg_cnt), 1>{}( - [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; }); + // update partial o_acc [0, fmha_alu_D_reg_cnt) + static_for<0, fmha_alu_D_reg_cnt, 1>{}([&](auto idx) { + o_acc.thread_buf_[idx] = detail::mul_impl_vv(o_acc.thread_buf_[idx], o_acc_scale); + }); - /// NOTICE: Compiler keep moving the conversion instructions to other places. We rewite - /// the cast_tile() call into inline asm to force the conversion instructions to be - /// generated here. The fmha_alu1() call should be placed at the end of a phase. + /// Note: The compiler keeps sinking the conversion instructions because the + /// result 'p' is only consumed later. To anchor them here, we rewrite + /// the cast_tile() call as inline assembly, forcing the conversions to be + /// emitted at this point. static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 2 == 0); static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 2>{}([&](auto idx) { float x = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx]); @@ -763,6 +806,10 @@ struct BlockFmhaFwdV3Pipeline sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y; } }); + + /// Note: Place fmha_alu1() at the end of the phase. The surrounding inline assembly + /// can interfere with the behavior of sched_group_barrier(), so ending the phase here + /// avoids unintended reordering. }; auto gemm = [&](auto sp_reg_idx, auto gemm_idx) { @@ -937,9 +984,9 @@ struct BlockFmhaFwdV3Pipeline __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx); + Scheduler::schedule(cl_p, number<1>{}); fmha_mask(xdl_SP_p01_reg_idx); - Scheduler::schedule(cl_p, number<1>{}); __builtin_amdgcn_sched_barrier(0); // phase2 ASM_MARKER("phase2 Wave0-3"); @@ -947,6 +994,8 @@ struct BlockFmhaFwdV3Pipeline __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); + asm volatile("s_nop 0"); + __builtin_amdgcn_sched_barrier(0); cl_calc(xdl_SP_p23_reg_idx, gemm1); Scheduler::schedule(cl_p, number<2>{}); @@ -995,6 +1044,8 @@ struct BlockFmhaFwdV3Pipeline __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); + asm volatile("s_nop 1"); + __builtin_amdgcn_sched_barrier(0); cl_calc(xdl_SP_p01_reg_idx, gemm0); fmha_alu1(xdl_SP_p23_reg_idx); @@ -1005,9 +1056,9 @@ struct BlockFmhaFwdV3Pipeline __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); cl_load(memK, K_w4_lds_wr_idx, V_w4_lds_rd_idx); + Scheduler::schedule(cl_p, number<2>{}); fmha_mask(xdl_SP_p01_reg_idx); - Scheduler::schedule(cl_p, number<2>{}); kv_token_start += kN0; if(num_total_loop <= ++i_total_loops) { @@ -1021,6 +1072,8 @@ struct BlockFmhaFwdV3Pipeline __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); + asm volatile("s_nop 1"); + __builtin_amdgcn_sched_barrier(0); cl_calc(xdl_SP_p23_reg_idx, gemm1); Scheduler::schedule(cl_p, number<3>{}); @@ -1036,7 +1089,14 @@ struct BlockFmhaFwdV3Pipeline auto ps_pi = number<1>{} - d; auto V_lds_rd_idx = ps_pi; - s_waitcnt_vmcnt(); + if(1 < num_total_loop) + { + s_waitcnt_vmcnt(); + } + else + { + s_waitcnt_vmcnt<0>(); + } __builtin_amdgcn_s_barrier(); V_lds_load(V_lds_rd_idx); @@ -1102,14 +1162,14 @@ struct BlockFmhaFwdV3Pipeline V_mem_load(number<1>{}); // V1 K_lds_load(number<1>{}); // K1 - asm volatile("s_setprio 0"); + __builtin_amdgcn_s_setprio(0); __builtin_amdgcn_s_barrier(); while(core_loop(number<0>{})) ; } if(warp_group_id != 0) { - asm volatile("s_setprio 1"); + __builtin_amdgcn_s_setprio(1); __builtin_amdgcn_s_barrier(); while(core_loop(number<1>{})) ; @@ -1167,14 +1227,13 @@ struct BlockFmhaFwdV3Pipeline typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename LSEDramBlockWindowTmp> - CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile - FmhaMask mask, - float scale_s, - void* smem_ptr) const + CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, + float scale_s, + void* smem_ptr) const { using namespace ck_tile; From 59cb9064821de461c237736ae61691e96a07572d Mon Sep 17 00:00:00 2001 From: Haocong WANG Date: Tue, 16 Sep 2025 15:07:10 +0800 Subject: [PATCH 02/28] [CK_TILE] fix bug when iperm =0 in fmha fwd (#2820) * fix bug when iperm =0 in fmha fwd * Disable f8 fmha smoke test until fix pr merged --------- Co-authored-by: Po Yen Chen --- example/ck_tile/01_fmha/script/smoke_test_fwd.sh | 2 +- include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index 3913a0d5c2..dda3943454 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -97,7 +97,7 @@ run_fp16_appendkv_tests() { set -x run_fp16_bf16_tests -run_fp8_tests +# run_fp8_tests if [ $TEST_APPENDKV -eq 1 ] ; then run_fp16_appendkv_tests diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 9d848dfd7a..6405ca50df 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1868,7 +1868,7 @@ struct FmhaFwdKernel const auto v_dram_naive = make_naive_tensor_view( data, // will update this pointer if using paged-kvcache make_tuple(length, kargs.hdim_v), - make_tuple(kargs.hdim_v, 1), + make_tuple(kargs.stride_v, 1), number{}, number<1>{}); From 804065a36b12abbb708ed65eba4513a5df59a25d Mon Sep 17 00:00:00 2001 From: JH-Leon-KIM-AMD Date: Tue, 16 Sep 2025 16:56:11 +0300 Subject: [PATCH 03/28] [CK Tile] Grouped conv fwd splitn support (#2776) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What's New Add Split-N support for grouped convolution forward to handle tensors >2GB by splitting the batch dimension. ## Bug Fix Fixed 32-bit integer overflow that caused crashes with 6+ splits: - Use `long_index_t` for batch offset calculations - Remove redundant GemmM initialization in constructors ## How It Works - Automatically splits batch dimension when tensor exceeds 2GB - Uses grid.z dimension for parallel processing of splits - Each split processes a subset of batches independently ## Testing Verified with tile_example_grouped_conv_fwd: - n=3000 (6 splits) ✓ - n=3500 (7 splits) ✓ - n=10480 (40 splits) ✓ --- .../grouped_convolution_forward_kernel.hpp | 100 ++++++++++++++++-- .../utils/transform_conv_fwd_to_gemm.hpp | 74 ++++++++----- 2 files changed, 135 insertions(+), 39 deletions(-) diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index cf4eca7a2d..6fcef5502e 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -23,7 +23,8 @@ struct GroupedConvFwdKernelArgs using ConvToGemmFwdTransformer = TransformConvFwdToGemm; + GroupedConvTraitsType_::ConvSpecialization, + true>; // Split N enabled static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor; template < @@ -56,7 +57,7 @@ struct GroupedConvFwdKernelArgs k_batch = args.k_batch; - GemmM = args.N_ * args.output_spatial_lengths_[0]; + // GemmM will be set after Split-N calculation GemmN = args.K_; GemmK = args.C_ * args.filter_spatial_lengths_[0]; GemmBatch = args.G_; @@ -94,6 +95,19 @@ struct GroupedConvFwdKernelArgs 1, std::multiplies()); group_stride_c = args.K_; + + // Initialize Split-N support fields for 1D convolution (NWGC layout) + // Get the actual split N from transformer + n_per_split = conv_to_gemm_transformer.GetN(); + original_n = conv_to_gemm_transformer.GetOriginalN(); + n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split); + + // Calculate batch strides for NWGC layout + input_batch_stride = args.C_ * args.input_spatial_lengths_[0]; + output_batch_stride = args.K_ * args.output_spatial_lengths_[0]; + + // Update GemmM to use split N (not original N) + GemmM = n_per_split * args.output_spatial_lengths_[0]; } template < @@ -133,7 +147,7 @@ struct GroupedConvFwdKernelArgs k_batch = args.k_batch; - GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1]; + // Note: GemmM will be set after Split-N calculation GemmN = args.K_; GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1]; GemmBatch = args.G_; @@ -171,6 +185,21 @@ struct GroupedConvFwdKernelArgs 1, std::multiplies()); group_stride_c = args.K_; + + // Initialize Split-N support fields for 2D convolution (NHWGC layout) + // Get the actual split N from transformer + n_per_split = conv_to_gemm_transformer.GetN(); + original_n = conv_to_gemm_transformer.GetOriginalN(); + n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split); + + // Calculate batch strides for NHWGC layout + input_batch_stride = + args.C_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1]; + output_batch_stride = + args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1]; + + // Update GemmM to use split N (not original N) + GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1]; } template < @@ -217,8 +246,7 @@ struct GroupedConvFwdKernelArgs k_batch = args.k_batch; - GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] * - args.output_spatial_lengths_[2]; + // Note: GemmM will be set after Split-N calculation GemmN = args.K_; GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1] * args.filter_spatial_lengths_[2]; @@ -257,6 +285,22 @@ struct GroupedConvFwdKernelArgs 1, std::multiplies()); group_stride_c = args.K_; + + // Initialize Split-N support fields for 3D convolution (NDHWGC layout) + // Get the actual split N from transformer + n_per_split = conv_to_gemm_transformer.GetN(); + original_n = conv_to_gemm_transformer.GetOriginalN(); + n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split); + + // Calculate batch strides for NDHWGC layout + input_batch_stride = args.C_ * args.input_spatial_lengths_[0] * + args.input_spatial_lengths_[1] * args.input_spatial_lengths_[2]; + output_batch_stride = args.K_ * args.output_spatial_lengths_[0] * + args.output_spatial_lengths_[1] * args.output_spatial_lengths_[2]; + + // Update GemmM to use split N (not original N) + GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] * + args.output_spatial_lengths_[2]; } using AGridDescMK = remove_cvref_t< @@ -297,6 +341,13 @@ struct GroupedConvFwdKernelArgs long_index_t group_stride_a; long_index_t group_stride_b; long_index_t group_stride_c; + + // Split-N support fields - initialize to safe defaults + index_t n_splits = 1; // Number of batch splits (e.g., 2 for 128→64×2) + index_t n_per_split = 1; // Batches per split (N_ from transformer) + index_t original_n = 1; // Original batch size before splitting + index_t input_batch_stride = 0; // Stride to next batch in input tensor + index_t output_batch_stride = 0; // Stride to next batch in output tensor }; /// @brief The Grouped Convolution Forward kernel template. @@ -392,10 +443,10 @@ struct GroupedConvolutionForwardKernel // clang-format on } - CK_TILE_HOST static constexpr auto GridSize(const GroupedConvFwdKernelArgsSpecialized& kargs) + CK_TILE_HOST static auto GridSize(const GroupedConvFwdKernelArgsSpecialized& kargs) { return dim3( - TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch); + TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.n_splits); } CK_TILE_HOST static auto BlockSize() @@ -430,6 +481,17 @@ struct GroupedConvolutionForwardKernel } } + // Check Split-K and Split-N conflict (both use blockIdx.z) + if(kargs.k_batch > 1 && kargs.n_splits > 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Cannot use both Split-K and Split-N simultaneously (both use blockIdx.z)!"); + } + return false; + } + const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}]; const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}]; @@ -768,10 +830,26 @@ struct GroupedConvolutionForwardKernel const auto group_offset_b = __builtin_amdgcn_readfirstlane(kargs.group_stride_b * blockIdY); const auto group_offset_c = __builtin_amdgcn_readfirstlane(kargs.group_stride_c * blockIdY); - // options - const InDataType* a_ptr = static_cast(kargs.in_ptr) + group_offset_a; - const WeiDataType* b_ptr = static_cast(kargs.wei_ptr) + group_offset_b; - OutDataType* c_ptr = static_cast(kargs.out_ptr) + group_offset_c; + // Split-N handling: Get which split this workgroup handles + const auto blockIdZ = __builtin_amdgcn_readfirstlane(blockIdx.z); + + // Calculate batch offset for this split + const index_t batch_offset = __builtin_amdgcn_readfirstlane(blockIdZ * kargs.n_per_split); + + // Calculate memory offsets for this split + const long_index_t input_batch_offset = static_cast(batch_offset) * + static_cast(kargs.input_batch_stride); + const long_index_t output_batch_offset = + static_cast(batch_offset) * + static_cast(kargs.output_batch_stride); + + // Adjust pointers: combine group offset and batch offset + const InDataType* a_ptr = + static_cast(kargs.in_ptr) + group_offset_a + input_batch_offset; + const WeiDataType* b_ptr = static_cast(kargs.wei_ptr) + + group_offset_b; // No batch offset for weights! + OutDataType* c_ptr = + static_cast(kargs.out_ptr) + group_offset_c + output_batch_offset; // allocate LDS __shared__ char smem_ptr_0[GetSmemSize()]; diff --git a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp index c468ae4398..2663d8a494 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp @@ -24,7 +24,7 @@ struct TransformConvFwdToGemm static constexpr auto I3 = number<3>{}; static constexpr auto I4 = number<4>{}; static constexpr auto I5 = number<5>{}; -#if 0 // TODO: Enable these functionalities + template static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths, const ConvDimsType& strides, @@ -42,24 +42,40 @@ struct TransformConvFwdToGemm template static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths, - const ConvDimsType& a_g_n_c_wis_strides, - const ConvDimsType& c_g_n_k_wos_lengths, - const ConvDimsType& c_g_n_k_wos_strides) + const ConvDimsType& c_g_n_k_wos_lengths) { + // Calculate strides internally assuming contiguous memory layout + ConvDimsType a_g_n_c_wis_strides, c_g_n_k_wos_strides; + const index_t num_dims = a_g_n_c_wis_lengths.size(); + + // Calculate strides for input tensor (innermost to outermost) + a_g_n_c_wis_strides[num_dims - 1] = 1; + for(index_t i = num_dims - 2; i >= 0; i--) + { + a_g_n_c_wis_strides[i] = a_g_n_c_wis_strides[i + 1] * a_g_n_c_wis_lengths[i + 1]; + } + + // Calculate strides for output tensor + c_g_n_k_wos_strides[num_dims - 1] = 1; + for(index_t i = num_dims - 2; i >= 0; i--) + { + c_g_n_k_wos_strides[i] = c_g_n_k_wos_strides[i + 1] * c_g_n_k_wos_lengths[i + 1]; + } + const long_index_t a_element_space_size = calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1); const long_index_t c_element_space_size = calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1); - const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType), - c_element_space_size * sizeof(CDataType)); - constexpr long_index_t TwoGB = (long_index_t{1} << 31); + const long_index_t element_space_size = ck_tile::max( + a_element_space_size * sizeof(ADataType), c_element_space_size * sizeof(CDataType)); + constexpr long_index_t TwoGB = (long_index_t{1} << 31); // 2GB const IndexType N = a_g_n_c_wis_lengths[I1]; if(element_space_size > TwoGB) { // Minimum divisor of N to not exceed 2GB - const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB); + const auto divisor = ck_tile::integer_divide_ceil(element_space_size, TwoGB); if(divisor <= static_cast(N)) { @@ -70,7 +86,8 @@ struct TransformConvFwdToGemm { if(N % least_divisor == 0) { - return N / least_divisor; + IndexType result = N / least_divisor; + return result; } } // Not found, process one Convolution N per block @@ -90,9 +107,12 @@ struct TransformConvFwdToGemm return N; } } -#endif public: + // Public getter methods for Split-N support + CK_TILE_HOST constexpr IndexType GetN() const { return N_; } + CK_TILE_HOST constexpr IndexType GetOriginalN() const { return original_N_; } + CK_TILE_HOST constexpr TransformConvFwdToGemm() {} template @@ -100,6 +120,7 @@ struct TransformConvFwdToGemm TransformConvFwdToGemm(const TransformConvFwdToGemmBase& transform_conv_fwd_to_gemm_base) : G_{static_cast(transform_conv_fwd_to_gemm_base.G_)}, N_{static_cast(transform_conv_fwd_to_gemm_base.N_)}, + original_N_{static_cast(transform_conv_fwd_to_gemm_base.original_N_)}, Di_{static_cast(transform_conv_fwd_to_gemm_base.Di_)}, Hi_{static_cast(transform_conv_fwd_to_gemm_base.Hi_)}, Wi_{static_cast(transform_conv_fwd_to_gemm_base.Wi_)}, @@ -168,18 +189,14 @@ struct TransformConvFwdToGemm std::is_same_v>); static_assert(std::is_same_v> || std::is_same_v>); -#if 0 // TODO: Enable these functionalities if constexpr(SplitN) { - N_ = GetSplitedNSize( - a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths); } else { N_ = c_g_n_k_wos_lengths[I1]; } -#endif - N_ = c_g_n_k_wos_lengths[I1]; } template >); static_assert(std::is_same_v> || std::is_same_v>); -#if 0 // TODO: Enable these functionalities + + // Store original N + original_N_ = c_g_n_k_wos_lengths[I1]; + if constexpr(SplitN) { - N_ = GetSplitedNSize( - a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths); } else { - N_ = c_g_n_k_wos_lengths[I1]; + N_ = c_g_n_k_wos_lengths[I1]; + original_N_ = N_; } -#endif - N_ = c_g_n_k_wos_lengths[I1]; } template >); static_assert(std::is_same_v> || std::is_same_v>); -#if 0 // TODO: Enable these functionalities + + // Store original N before potential splitting + original_N_ = c_g_n_k_wos_lengths[I1]; + if constexpr(SplitN) { - N_ = GetSplitedNSize( - a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths); } else { - N_ = c_g_n_k_wos_lengths[I1]; + N_ = original_N_; } -#endif - N_ = c_g_n_k_wos_lengths[I1]; } #if 0 // TODO: Enable these functionalities @@ -1417,7 +1435,7 @@ struct TransformConvFwdToGemm } } - IndexType G_, N_; + IndexType G_, N_, original_N_; IndexType Di_, Hi_, Wi_; IndexType Do_, Ho_, Wo_; IndexType Z_, Y_, X_; From 78a9823cb41f65040d02034b73f029bdc1175c7a Mon Sep 17 00:00:00 2001 From: Cong Ma <142121551+CongMa13@users.noreply.github.com> Date: Tue, 16 Sep 2025 08:18:51 -0600 Subject: [PATCH 04/28] [CK TILE GEMM] Add support to convert i4 to OCP FP8/BF8 (#2853) --- .../unary_element_wise_operation.hpp | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index 9e3ccb025d..692d5ec504 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -162,6 +162,16 @@ CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q) */ CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a) { +#if CK_TILE_USE_OCP_FP8 + // register values [3, 2, 1, 0] + static constexpr uint32_t reg0 = 0xcaccced0; + // register values [7, 6, 5, 4] + static constexpr uint32_t reg1 = 0xb8c0c4c8; + // register values [-1, -2, -3, -4] + static constexpr uint32_t reg2 = 0x44403800; + // register values [-5, -6, -7, -8] + static constexpr uint32_t reg3 = 0x4e4c4a48; +#else // register values [3, 2, 1, 0] static constexpr uint32_t reg0 = 0xd2d4d6d8; // register values [7, 6, 5, 4] @@ -170,6 +180,7 @@ CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a) static constexpr uint32_t reg2 = 0x4C484000; // register values [-5, -6, -7, -8] static constexpr uint32_t reg3 = 0x56545250; +#endif uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel; @@ -227,6 +238,16 @@ CK_TILE_DEVICE float amd_assembly_bf8_to_fp32(uint32_t src) */ CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(uint32_t a) { +#if CK_TILE_USE_OCP_FP8 + // register values [3, 2, 1, 0] + static constexpr uint32_t reg0 = 0Xc5c6c7c8; + // register values [7, 6, 5, 4] + static constexpr uint32_t reg1 = 0Xbcc0c2c4; + // register values [11, 10, 9, 8] + static constexpr uint32_t reg2 = 0X42403c00; + // register values [15, 14, 13, 12] + static constexpr uint32_t reg3 = 0X47464544; +#else // register values [3, 2, 1, 0] static constexpr uint32_t reg0 = 0Xc9cacbcc; // register values [7, 6, 5, 4] @@ -235,6 +256,7 @@ CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(uint32_t a) static constexpr uint32_t reg2 = 0X46444000; // register values [15, 14, 13, 12] static constexpr uint32_t reg3 = 0X4b4a4948; +#endif uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel; From 671adb59c54875cdb7c485bb0be387045b83dfb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 16 Sep 2025 17:47:28 +0200 Subject: [PATCH 05/28] Disable GridwiseOp prints if env var is off (#2843) * Disable GridwiseOp prints if env var is off * Fixes --- ...d_contraction_multiple_d_wmma_cshuffle.hpp | 8 ++- ...ise_batched_gemm_gemm_wmma_cshuffle_v3.hpp | 51 ++++++++++------ ...atched_gemm_softmax_gemm_wmma_cshuffle.hpp | 59 ++++++++++++------- .../gpu/grid/gridwise_fpAintB_gemm_wmma.hpp | 32 ++++++---- ...gridwise_gemm_multiple_d_wmma_cshuffle.hpp | 28 +++++++-- .../gpu/grid/gridwise_gemm_wmma.hpp | 32 ++++++---- 6 files changed, 140 insertions(+), 70 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp index ab3f3856aa..537e6dab28 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp @@ -1,11 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include #include +#include "ck/utility/env.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" @@ -853,7 +854,10 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle arg.e_grid_desc_m_n_, arg.block_2_ctile_map_)) { - printf("GridwiseOp: Validity check failure\n"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: Validity check failure\n"); + } return false; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp index b61c7a09eb..fa7eb4faaa 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp @@ -398,41 +398,54 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) { - print("GridwiseOp: M/N Length err, A_M/N = %d, %d | C_M/N = %d, %d\n", - M, - N, - c_grid_desc_m_n.GetLength(I0), - c_grid_desc_m_n.GetLength(I1)); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + print("GridwiseOp: M/N Length err, A_M/N = %d, %d | C_M/N = %d, %d\n", + M, + N, + c_grid_desc_m_n.GetLength(I0), + c_grid_desc_m_n.GetLength(I1)); + } return false; } if(!(M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 && N % NPerBlock == 0)) { - print("GridwiseOp: M/L/K/N Division err, M/L/K/N = %d, %d, %d, %d | M/L/K/NPerBlock = " - "%d, %d, %d, %d\n", - M, - L, - K, - N, - MPerBlock, - LPerBlock, - KPerBlock, - NPerBlock); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + print("GridwiseOp: M/L/K/N Division err, M/L/K/N = %d, %d, %d, %d | " + "M/L/K/NPerBlock = " + "%d, %d, %d, %d\n", + M, + L, + K, + N, + MPerBlock, + LPerBlock, + KPerBlock, + NPerBlock); + } return false; } // check gemm1 gridwise gemm pipeline if(!(LPerBlock % LTilePerBlock == 0)) { - print("GridwiseOp: inner loop division, L/LTilePerblock: %d, %d\n", - LPerBlock, - LTilePerBlock); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + print("GridwiseOp: inner loop division, L/LTilePerblock: %d, %d\n", + LPerBlock, + LTilePerBlock); + } return false; } if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) { - print("GridwiseOp: invalid block_2_ctile_map\n"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + print("GridwiseOp: invalid block_2_ctile_map\n"); + } return false; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp index 1754e07e6a..502c449ef1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp @@ -1,8 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include "ck/utility/env.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" @@ -569,26 +570,33 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) { - printf("GridwiseOp: M/N Length err, A_M/N = %d, %d | C_M/N = %d, %d\n", - M, - N, - c_grid_desc_m_n.GetLength(I0), - c_grid_desc_m_n.GetLength(I1)); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: M/N Length err, A_M/N = %d, %d | C_M/N = %d, %d\n", + M, + N, + c_grid_desc_m_n.GetLength(I0), + c_grid_desc_m_n.GetLength(I1)); + } return false; } if(!(M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 && N % NPerBlock == 0)) { - printf("GridwiseOp: M/L/K/N Division err, M/L/K/N = %d, %d, %d, %d | M/L/K/NPerBlock = " - "%d, %d, %d, %d\n", - M, - L, - K, - N, - MPerBlock, - LPerBlock, - KPerBlock, - NPerBlock); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: M/L/K/N Division err, M/L/K/N = %d, %d, %d, %d | " + "M/L/K/NPerBlock = " + "%d, %d, %d, %d\n", + M, + L, + K, + N, + MPerBlock, + LPerBlock, + KPerBlock, + NPerBlock); + } return false; } @@ -596,23 +604,32 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma const auto num_gemm0_k_loop = K / KPerBlock; if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop)) { - printf("GridwiseOp: outer loop unsupport\n"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: outer loop unsupport\n"); + } return false; } // check gemm1 gridwise gemm pipeline if(!(LPerBlock % LTilePerBlock == 0)) { - printf("GridwiseOp: inner loop division, L/LTilePerblock: %d, %d\n", - LPerBlock, - LTilePerBlock); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: inner loop division, L/LTilePerblock: %d, %d\n", + LPerBlock, + LTilePerBlock); + } return false; } const auto num_gemm1_k_inner_loop = LPerBlock / LTilePerBlock; if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop)) { - printf("GridwiseOp: inner loop unsupport\n"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: inner loop unsupport\n"); + } return false; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp index 8011fa56d3..c8b154228f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp @@ -1,8 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include "ck/utility/env.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" @@ -466,20 +467,26 @@ struct GridwiseFpAintBGemm_Wmma if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && K == GetBProblemsizeNK()[I1])) { - printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n", - GetAProblemsizeMK()[I0], - GetAProblemsizeMK()[I1], - GetBProblemsizeNK()[I0], - GetBProblemsizeNK()[I1], - c_grid_desc_m_n.GetLength(I0), - c_grid_desc_m_n.GetLength(I1)); - printf("GridwiseOp err: ProblemSize check"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n", + GetAProblemsizeMK()[I0], + GetAProblemsizeMK()[I1], + GetBProblemsizeNK()[I0], + GetBProblemsizeNK()[I1], + c_grid_desc_m_n.GetLength(I0), + c_grid_desc_m_n.GetLength(I1)); + printf("GridwiseOp err: ProblemSize check"); + } return false; } if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) { - printf("GridwiseOp err: ProblemSize division"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp err: ProblemSize division"); + } return false; } @@ -488,7 +495,10 @@ struct GridwiseFpAintBGemm_Wmma if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { - printf("GridwiseOp err: Pipeline not support this k_loop"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp err: Pipeline not support this k_loop"); + } return false; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index 46979a5620..7d68d64ed8 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -1,8 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include "ck/utility/env.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" @@ -653,13 +654,19 @@ struct GridwiseGemmMultipleD_Wmma if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && K == GetBProblemsizeNK()[I1])) { - printf("GridwiseOp: ABE descriptor dimension cross check failure\n"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: ABE descriptor dimension cross check failure\n"); + } return false; } if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) { - printf("GridwiseOp: Problemsize descriptor dimension check failure\n"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: Problemsize descriptor dimension check failure\n"); + } return false; } @@ -747,20 +754,29 @@ struct GridwiseGemmMultipleD_Wmma if(!valid) { - printf("GridwiseOp: D descriptor dimension check failure\n"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: D descriptor dimension check failure\n"); + } return false; } if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && K == GetBProblemsizeNK()[I1])) { - printf("GridwiseOp: ABE descriptor dimension cross check failure\n"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: ABE descriptor dimension cross check failure\n"); + } return false; } if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) { - printf("GridwiseOp: Problemsize descriptor dimension check failure\n"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: Problemsize descriptor dimension check failure\n"); + } return false; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index 4a15958adb..65f74de3cf 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -1,8 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include "ck/utility/env.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" @@ -458,20 +459,26 @@ struct GridwiseGemm_Wmma if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && K == GetBProblemsizeNK()[I1])) { - printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n", - GetAProblemsizeMK()[I0], - GetAProblemsizeMK()[I1], - GetBProblemsizeNK()[I0], - GetBProblemsizeNK()[I1], - c_grid_desc_m_n.GetLength(I0), - c_grid_desc_m_n.GetLength(I1)); - printf("GridwiseOp err: ProblemSize check"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n", + GetAProblemsizeMK()[I0], + GetAProblemsizeMK()[I1], + GetBProblemsizeNK()[I0], + GetBProblemsizeNK()[I1], + c_grid_desc_m_n.GetLength(I0), + c_grid_desc_m_n.GetLength(I1)); + printf("GridwiseOp err: ProblemSize check"); + } return false; } if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) { - printf("GridwiseOp err: ProblemSize division"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp err: ProblemSize division"); + } return false; } @@ -480,7 +487,10 @@ struct GridwiseGemm_Wmma if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { - printf("GridwiseOp err: Pipeline not support this k_loop"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp err: Pipeline not support this k_loop"); + } return false; } From b7a806f2442ed04db9e835e3e4e14aaebe3db9b4 Mon Sep 17 00:00:00 2001 From: linqunAMD Date: Tue, 16 Sep 2025 23:47:55 +0800 Subject: [PATCH 06/28] =?UTF-8?q?[CK=5FTILE][REGRESSION]=20Correct=20block?= =?UTF-8?q?Size=20in=20Generic2dBlockShape=20(c254f=E2=80=A6=20(#2837)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [CK_TILE][REGRESSION] Correct blockSize in Generic2dBlockShape (c254f3d7b4ccc ) WarpPerBlock_M * WarpPerBlock_N are not equal with ThreadPerBlock_M * ThreadPerBlock_N /warpSize. we should calculate BlockSize from WarpPerBlock_M * WarpPerBlock_N To compatible with wave32, function GetBlockSize is added to calculate correct size in host side. * fix blocksize for all kernel related with generic2dblockshap * remove constexpr for blocks --- .../add_rmsnorm2d_rdquant_fwd_kernel.hpp | 6 ++- .../ops/common/generic_2d_block_shape.hpp | 51 ++++++++++++------- .../kernel/layernorm2d_fwd_kernel.hpp | 6 ++- .../rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp | 6 ++- .../kernel/moe_smoothquant_kernel.hpp | 6 ++- .../smoothquant/kernel/smoothquant_kernel.hpp | 6 ++- ..._rmsnorm2d_rdquant_fwd_instance_common.hpp | 2 +- .../moe_smoothquant_instance_common.hpp | 2 +- test/ck_tile/rmsnorm2d/generate.py | 2 +- .../instances/smoothquant_instance_common.hpp | 2 +- 10 files changed, 63 insertions(+), 26 deletions(-) diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp index c7717f08cd..b6eac45285 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp @@ -95,7 +95,11 @@ struct AddRmsnorm2dRdquantFwd return dim3(integer_divide_ceil(hargs.m, Block_M)); } - CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } + CK_TILE_HOST static constexpr auto BlockSize() + { + return is_wave32() ? Problem::BlockShape::template GetBlockSize() + : Problem::BlockShape::template GetBlockSize(); + } // clang-format off template struct t2s; diff --git a/include/ck_tile/ops/common/generic_2d_block_shape.hpp b/include/ck_tile/ops/common/generic_2d_block_shape.hpp index 333762e5d7..9c5d99efc3 100644 --- a/include/ck_tile/ops/common/generic_2d_block_shape.hpp +++ b/include/ck_tile/ops/common/generic_2d_block_shape.hpp @@ -45,47 +45,57 @@ struct Generic2dBlockShape static constexpr index_t Block_N = BlockTile_::at(number<1>{}); static constexpr index_t ThreadPerBlock_M = ThreadPerBlock_::at(number<0>{}); static constexpr index_t ThreadPerBlock_N = ThreadPerBlock_::at(number<1>{}); - static constexpr index_t BlockSize = ThreadPerBlock_M * ThreadPerBlock_N; // vector size along seq static constexpr index_t Vector_M = Vector_::at(number<0>{}); static constexpr index_t Vector_N = Vector_::at(number<1>{}); - static constexpr bool is_warp_per_row = ThreadPerBlock_N <= get_warp_size(); - static_assert((ThreadPerBlock_M * ThreadPerBlock_N) % get_warp_size() == 0); - static constexpr index_t total_warps = (ThreadPerBlock_M * ThreadPerBlock_N) / get_warp_size(); - // num warps along seq, within each block - static constexpr index_t WarpPerBlock_M = []() { + template + static constexpr index_t GetWarpPerBlock_M() + { + constexpr index_t warp_size = isHostWave32 ? 32 : get_warp_size(); + constexpr bool is_warp_per_row = ThreadPerBlock_N <= warp_size; + static_assert((ThreadPerBlock_M * ThreadPerBlock_N) % warp_size == 0); + constexpr index_t total_warps = (ThreadPerBlock_M * ThreadPerBlock_N) / warp_size; + if constexpr(is_warp_per_row) { - static_assert(get_warp_size() % ThreadPerBlock_N == 0); - return total_warps * (get_warp_size() / ThreadPerBlock_N); + static_assert(warp_size % ThreadPerBlock_N == 0); + return total_warps * (warp_size / ThreadPerBlock_N); } else { // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N / get_warp_size()); + return total_warps / (ThreadPerBlock_N / warp_size); } - }(); + }; // num of warps along n - static constexpr index_t WarpPerBlock_N = []() { + template + static constexpr index_t GetWarpPerBlock_N() + { + constexpr index_t warp_size = isHostWave32 ? 32 : get_warp_size(); + constexpr bool is_warp_per_row = ThreadPerBlock_N <= warp_size; if constexpr(is_warp_per_row) { - static_assert(get_warp_size() % ThreadPerBlock_N == 0); + static_assert(warp_size % ThreadPerBlock_N == 0); return 1; } else { - static_assert(ThreadPerBlock_N % get_warp_size() == 0); - return ThreadPerBlock_N / get_warp_size(); + static_assert(ThreadPerBlock_N % warp_size == 0); + return ThreadPerBlock_N / warp_size; } - }(); + } + + static constexpr index_t WarpPerBlock_M = GetWarpPerBlock_M(); + static constexpr index_t WarpPerBlock_N = GetWarpPerBlock_N(); // warp size - static constexpr index_t Warp_M = ThreadPerBlock_M / WarpPerBlock_M * Vector_M; - static constexpr index_t Warp_N = ThreadPerBlock_N / WarpPerBlock_N * Vector_N; + static constexpr index_t BlockSize = WarpPerBlock_M * WarpPerBlock_N * get_warp_size(); + static constexpr index_t Warp_M = ThreadPerBlock_M / WarpPerBlock_M * Vector_M; + static constexpr index_t Warp_N = ThreadPerBlock_N / WarpPerBlock_N * Vector_N; static_assert(Warp_M % Vector_M == 0); static_assert(Warp_N % Vector_N == 0); static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0); @@ -98,6 +108,13 @@ struct Generic2dBlockShape // num of threads along seq, within each warp static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N; + + template + static constexpr index_t GetBlockSize() + { + constexpr index_t warp_size = isHostWave32 ? 32 : get_warp_size(); + return GetWarpPerBlock_M() * GetWarpPerBlock_N() * warp_size; + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp index 6998b358d8..0181a3291f 100644 --- a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp @@ -134,7 +134,11 @@ struct Layernorm2dFwd return dim3(integer_divide_ceil(hargs.m, Block_M)); } - CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } + CK_TILE_HOST static constexpr auto BlockSize() + { + return is_wave32() ? Problem::BlockShape::template GetBlockSize() + : Problem::BlockShape::template GetBlockSize(); + } // clang-format off template struct t2s; diff --git a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp index e7f4ce0ba8..32586a6343 100644 --- a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp @@ -124,7 +124,11 @@ struct Rmsnorm2dFwd return dim3(integer_divide_ceil(hargs.m, Block_M)); } - CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } + CK_TILE_HOST static constexpr auto BlockSize() + { + return is_wave32() ? Problem::BlockShape::template GetBlockSize() + : Problem::BlockShape::template GetBlockSize(); + } // clang-format off template struct t2s; diff --git a/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp b/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp index b70e996617..2553b19fd8 100644 --- a/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp +++ b/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp @@ -93,7 +93,11 @@ struct MoeSmoothquant return dim3(hargs.topk, integer_divide_ceil(hargs.tokens, Block_M), 1); } - CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } + CK_TILE_HOST static constexpr auto BlockSize() + { + return is_wave32() ? Problem::BlockShape::template GetBlockSize() + : Problem::BlockShape::template GetBlockSize(); + } // clang-format off template struct t2s; diff --git a/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp b/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp index 7dc913901e..e0ea9692c5 100644 --- a/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp +++ b/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp @@ -82,7 +82,11 @@ struct Smoothquant return dim3(integer_divide_ceil(hargs.m, Block_M)); } - CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } + CK_TILE_HOST static constexpr auto BlockSize() + { + return is_wave32() ? Problem::BlockShape::template GetBlockSize() + : Problem::BlockShape::template GetBlockSize(); + } // clang-format off template struct t2s; diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp index dd90034064..d997596414 100644 --- a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp @@ -58,7 +58,7 @@ float add_rmsnorm2d_rdquant_fwd_(const S& s, A a) using Kernel = ck_tile::AddRmsnorm2dRdquantFwd; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a); diff --git a/test/ck_tile/moe_smoothquant/instances/moe_smoothquant_instance_common.hpp b/test/ck_tile/moe_smoothquant/instances/moe_smoothquant_instance_common.hpp index f2875c72c8..c6ef822f64 100644 --- a/test/ck_tile/moe_smoothquant/instances/moe_smoothquant_instance_common.hpp +++ b/test/ck_tile/moe_smoothquant/instances/moe_smoothquant_instance_common.hpp @@ -53,7 +53,7 @@ float moe_smoothquant_(const S& s, A a) using Kernel = ck_tile::MoeSmoothquant; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a); diff --git a/test/ck_tile/rmsnorm2d/generate.py b/test/ck_tile/rmsnorm2d/generate.py index 5eded8b310..3bcc427e83 100644 --- a/test/ck_tile/rmsnorm2d/generate.py +++ b/test/ck_tile/rmsnorm2d/generate.py @@ -201,7 +201,7 @@ float rmsnorm2d_fwd_(const S& s, A a) using Kernel = ck_tile::Rmsnorm2dFwd; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a); diff --git a/test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp b/test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp index 8929289cdb..138afcffaf 100644 --- a/test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp +++ b/test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp @@ -49,7 +49,7 @@ float smoothquant_(const S& s, A a) using Kernel = ck_tile::Smoothquant; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a); From dee185d80c4e6052c532a51949f28e2f74ddc27f Mon Sep 17 00:00:00 2001 From: Emily Martins <65371150+ecamartins@users.noreply.github.com> Date: Tue, 16 Sep 2025 16:21:47 -0600 Subject: [PATCH 07/28] [CK_TILE] Stream-K GEMM Implementation (#2781) * Change splitk_batch_offset parameter to k_size in UniversalGemmKernel::MakeGemmTensorViews function Prior to this change, the splitk_batch_offset parameter of MakeGemmTensorViews had type SplitKBatchOffset. But, the only member variable of the SplitKBatchOffset class used in the MakeGemmTensorViews function was splitted_k (an int32_t). The splitted_k value was used as part of defining the dimensions of the tensor view. That said, for Stream K, we do not need to use the SplitKBatchOffset class since we are not using Split K. Thus, this commit changes the splitk_batch_offset parameter to a int32_t called k_size. This will avoid the constraint of requiring a caller of MakeGemmTensorViews to use the SplitKBatchOffset class while still providing the same functionality. Calls to UniversalGemmKernel::MakeGemmTensorViews have been updated accordingly. * StreamK Kernel RunGemm Implementation Stream K cannot simply use UniversalGemmKernel's RunGemm for the following reasons: 1. The UniversalGemmKernel::RunGemm function computes num_loop based on a static function of the TilePartitioner. That said, for Stream K, num_loop must be computed using a member function (namely GetCurrentIterLength from PR #2708). 2. The UniversalGemmKernel::RunGemm function requires the use of a SplitKBatchOffset object which is not used for Stream K since we are not using Split K. Thus, this change adds a RunGemm function in the StreamKKernel class. * initial implementation for operator() for StreamKKernel: adding stream-k algorithm and calls to RunGemm * Fix indexing and offset issues for StreamK These changes do the following: - Ensure offsets along the M and N dimensions are multiplied by MPerblock or NPerBlock, respectively. This ensures tile window origins are at the correct locations. - Fix bug in the tile partitioner's GetTileIdxWithOffset. Now, we apply divmod to the given references to ensure correct values are available to the caller. - Added documentation in the Stream-K operator() * Initial gtests for Stream-K These changes add an initial gtest suite for the CK Tile Stream-K kernel. Currently, due to bugs in the StreamKTilePartitioner (which will be handled in a future PR), there are validation issues for certain cases which may differ on different architectures. Thus, we opted to run cases that are only fully data-parallel (skipping others). A guard was added to Stream-K's IsSupportedArgument method to ensure that callers are aware of this constraint. Additionally, to ensure testing reproducibility, options for setting the number of CUs and occupancy were added to MakeKernelArgs. * Use GemmPipeline operator() variant that takes hot loop and tail num In Stream-K, the num_loop value varies per WG and per iteration of a Stream-K loop. So instead, we use the version of the GemmPipeline's operator() function that takes in has_hot_loop and tail_num. This is similar to what is done in Grouped GEMM. * changes from review: comments, move readfirstlane, remove ifndef * Switch direction of C tensor traversal & add padding guard Prior to this change, WGs travelled backwards through their assigned macro tiles in the C tensor. For instance, if WG0 is responsible for C tiles 0 and 1, it would first visit tile 1 then tile 0. This means that the iter_end decrements in each iteration of the stream-K while loop. Since we are working with unsigned integers, the subtraction operation may not be safe. Thus, this change makes is such that WGs travel forward so that their iter_start is incremented and their iter_end remains fixed. Additionally, we added a guard against WGs that are neither sk_blocks nor dp_blocks to ensure such WGs do not participate in the GEMM. Together, these changes make is such that the algorithm is correct when sk_blocks is greater than zero. * Disable StreamK_M256_N256_K256_SKBlocks12 test case This instance involves >=3 WGs contributing to each macro tile in C. Due to the use of atomics, this is resulting in precision errors. These errors will not persist once the reduction strategy is implemented. We will re-enable this test then. --------- Co-authored-by: Astha Rai --- .../ops/gemm/kernel/gemm_tile_partitioner.hpp | 19 +- .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 4 +- .../ops/gemm/kernel/streamk_gemm_kernel.hpp | 157 +++++++++- .../ops/gemm/kernel/universal_gemm_kernel.hpp | 20 +- test/ck_tile/CMakeLists.txt | 1 + test/ck_tile/gemm_streamk/CMakeLists.txt | 7 + .../gemm_streamk/test_gemm_streamk.cpp | 14 + .../gemm_streamk/test_gemm_streamk_cases.inc | 118 ++++++++ .../gemm_streamk/test_gemm_streamk_types.hpp | 25 ++ .../gemm_streamk/test_gemm_streamk_util.hpp | 282 ++++++++++++++++++ 10 files changed, 612 insertions(+), 35 deletions(-) create mode 100644 test/ck_tile/gemm_streamk/CMakeLists.txt create mode 100644 test/ck_tile/gemm_streamk/test_gemm_streamk.cpp create mode 100644 test/ck_tile/gemm_streamk/test_gemm_streamk_cases.inc create mode 100644 test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp create mode 100644 test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index 92ae6411a5..a891d4df55 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -646,16 +646,13 @@ struct StreamKTilePartitioner * @brief Get length of loop iterations for stream-k loop */ CK_TILE_DEVICE uint32_t GetCurrentIterLength(uint32_t iter_start, - uint32_t iter_end, - uint32_t total_iter_length) const noexcept + uint32_t iter_end) const noexcept { - uint32_t iter_length_mod, iter_length_quo /*unused*/; - k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod); - uint32_t total_iter_length_val = static_cast(total_iter_length); - uint32_t current_iter_length = - min(iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, - total_iter_length_val); - return current_iter_length; + // A WG's iter_end is either in the current C macro tile or not. + // If it is not, then the macro tile boundary is where the WG must stop. + uint32_t distance_to_tile_boundary = + k_iters_per_tile.get() - (iter_start % k_iters_per_tile.get()); + return min(iter_start + distance_to_tile_boundary, iter_end) - iter_start; } /** @@ -672,9 +669,7 @@ struct StreamKTilePartitioner CK_TILE_DEVICE void GetTileIdxWithOffset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const noexcept { - uint32_t tile_idx_val = static_cast(tile_idx); - uint32_t iter_offset_val = static_cast(iter_offset); - k_iters_per_tile.divmod(iter, tile_idx_val, iter_offset_val); + k_iters_per_tile.divmod(iter, tile_idx, iter_offset); } /** diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 704d0d01ee..dda38bbc47 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -374,7 +374,7 @@ struct GroupedGemmKernel // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = Base::template MakeGemmTensorViews( - {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset); + {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k); const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = @@ -436,7 +436,7 @@ struct GroupedGemmKernel // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = Base::template MakeGemmTensorViews( - {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset); + {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k); const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp index 77c431e49c..5df1f092d7 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp @@ -141,11 +141,17 @@ struct StreamKKernel return UniversalGemmKernel::BlockSize(); } - CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args) + /// @brief Constructs kernel arguments for the Stream-K kernel. + /// @param host_args Stream-K host arguments. + /// @param num_cu Number of compute units (CUs). The default is the number of CUs on the device. + /// The caller may select their own to assist with test reproducibility, etc. + /// @param occupancy The maximum number of active blocks per CU for this kernel. The caller may + /// select their own to assist with test reproducibility, etc. + /// @return The kernel arguments for Stream-K. + CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args, + int num_cu = NumCU(), + int occupancy = Occupancy()) { - uint32_t occupancy = static_cast(Occupancy()); - uint32_t num_cu = static_cast(NumCU()); - return StreamKKernelArgs{{host_args.as_ptr, host_args.bs_ptr, host_args.ds_ptr, @@ -166,14 +172,71 @@ struct StreamKKernel TilePartitioner{static_cast(host_args.M), static_cast(host_args.N), static_cast(host_args.K), - num_cu, - occupancy, + static_cast(num_cu), + static_cast(occupancy), host_args.num_sk_blocks}}; } - CK_TILE_HOST static bool - IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) + template + CK_TILE_DEVICE static void + RunGemm(const std::array& as_ptr, + const std::array& bs_ptr, + const std::array& ds_ptr, + CDataType* c_ptr, + void* smem_ptr_0, + const typename UniversalGemmKernel::KernelArgs& kargs, + const index_t num_loop, + const index_t block_idx_m, + const index_t block_idx_n, + const index_t k_size) { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + UniversalGemmKernel::template MakeGemmTensorViews( + as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size); + + const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = + UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + // Run GEMM cooperatively by whole workgroup. + const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0); + const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1); + const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2); + + // Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute + // has_hot_loop and tail_num here. This is a similar pattern used by grouped GEMM. In this + // case, we call the GemmPipeline's operator() function that takes both has_hot_loop and + // tail_num. + const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); + const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); + + const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0], + bs_block_window[UniversalGemmKernel::I0], + num_loop, + has_hot_loop, + tail_num, + smem_ptr_0); + + if(UseDefaultScheduler || (get_warp_id() == 0)) + { + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3); + + EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } + } + + CK_TILE_HOST static bool IsSupportedArgument(const StreamKKernelArgs& kargs) + { + if(kargs.reduction_strategy == StreamKReductionStrategy::Reduction) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("CK Tile Stream-K only supports the atomic reduction strategy."); + } + return false; + } return UniversalGemmKernel::IsSupportedArgument(kargs); } @@ -199,9 +262,81 @@ struct StreamKKernel kargs.workspace_ptr = workspace_ptr; } - // Temporary placeholder to support the Occupancy() static function. - // Since the Occupancy function uses kentry, this class must have an operator() function - CK_TILE_DEVICE void operator()(StreamKKernelArgs /*kargs*/) const {} + /// @brief Entry point for the Stream-K Kernel, performing the main Stream-K loop. + CK_TILE_DEVICE void operator()(StreamKKernelArgs kargs) const + { + // Allocate LDS + __shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()]; + + uint32_t block_idx = ck_tile::get_block_1d_id(); + + bool is_padding_block = + __builtin_amdgcn_readfirstlane(block_idx >= kargs.tile_partitioner.sk_num_blocks && + block_idx < kargs.tile_partitioner.dp_start_block_idx); + + // Padding blocks make it such that the DP blocks are aligned with the number of CUs; they + // should not partake in the GEMM + if(is_padding_block) + return; + + // Determine the K offset of the first and final macro tile in the A and B tensors along the + // K dimension. + uint32_t iter_start, iter_end; + kargs.tile_partitioner.GetBlockItr(block_idx, iter_start, iter_end); + + // Main Stream-K loop + while(true) + { + // Determine the number of macro tiles in A and B this WG is resposible for in the + // current C macro tile. + uint32_t current_iter_length = __builtin_amdgcn_readfirstlane( + kargs.tile_partitioner.GetCurrentIterLength(iter_start, iter_end)); + + // Determine the 1D tile_idx and the iter_offset for this WG. + // The tile_idx is the 1D macro tile index in the C tensor. + // The iter_offset is the starting macro tile index in the K dimension for the WG in the + // current iteration of the while loop. + uint32_t tile_idx, iter_offset; + kargs.tile_partitioner.GetTileIdxWithOffset(iter_start, tile_idx, iter_offset); + + // Get the 2D tile index in the C tensor for this WG using the 1D index (i.e. tile_idx) + auto spatial_idx = kargs.tile_partitioner.GetOutputTileIndex(tile_idx); + + // Get the offsets in A, B, C tensors. + index_t i_m = static_cast(spatial_idx[UniversalGemmKernel::I0] * + TilePartitioner::MPerBlock); + index_t i_n = static_cast(spatial_idx[UniversalGemmKernel::I1] * + TilePartitioner::NPerBlock); + index_t i_k = static_cast(iter_offset) * TilePartitioner::KPerBlock; + + // Determine the total size along the K dimension the WG is using in this iteration + // (used to construct tensor views). + index_t k_size = static_cast(current_iter_length * TilePartitioner::KPerBlock); + + // Update pointer offsets for A, B, and C. + const ADataType* a_ptr = static_cast(kargs.as_ptr[0]) + i_k; + const BDataType* b_ptr = static_cast(kargs.bs_ptr[0]) + i_k; + CDataType* c_ptr = static_cast(kargs.e_ptr); + + // Run the GEMM pipeline and Epilogue. + RunGemm({a_ptr}, + {b_ptr}, + {/*ds_ptr*/}, + c_ptr, + smem_ptr_0, + kargs, + current_iter_length, + i_m, + i_n, + k_size); + + // Prepare for next Stream-K loop iteration. + iter_start += current_iter_length; + if(iter_end <= iter_start) + break; + block_sync_lds(); + } + } private: CK_TILE_HOST static int NumCU() diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 8117d65758..cfba8b6c9d 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -579,7 +579,7 @@ struct UniversalGemmKernel const std::array& ds_ptr, EDataType* e_ptr, const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset) + const index_t k_size) { static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); @@ -591,7 +591,7 @@ struct UniversalGemmKernel { return make_naive_tensor_view( static_cast(as_ptr[i]), - make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.M, k_size), make_tuple(kargs.stride_As[i], 1), number{}, number<1>{}); @@ -600,7 +600,7 @@ struct UniversalGemmKernel { return make_naive_tensor_view( static_cast(as_ptr[i]), - make_tuple(splitk_batch_offset.splitted_k, kargs.M), + make_tuple(k_size, kargs.M), make_tuple(kargs.stride_As[i], 1), number{}, number<1>{}); @@ -617,7 +617,7 @@ struct UniversalGemmKernel if constexpr(TilePartitioner::BlockGemmShape::PermuteB) { constexpr index_t K1 = GemmPipeline::GetSmemPackB(); - const index_t K0 = splitk_batch_offset.splitted_k / K1; + const index_t K0 = k_size / K1; constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); const auto b_k0_n_k1_desc = @@ -638,7 +638,7 @@ struct UniversalGemmKernel { return make_naive_tensor_view( bs_ptr[i], - make_tuple(splitk_batch_offset.splitted_k, kargs.N), + make_tuple(k_size, kargs.N), make_tuple(kargs.stride_Bs[i], 1), number{}, number<1>{}); @@ -649,7 +649,7 @@ struct UniversalGemmKernel if constexpr(TilePartitioner::BlockGemmShape::PermuteB) { constexpr index_t K1 = GemmPipeline::GetSmemPackB(); - const index_t K0 = splitk_batch_offset.splitted_k / K1; + const index_t K0 = k_size / K1; constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); const auto b_k0_n_k1_desc = @@ -672,7 +672,7 @@ struct UniversalGemmKernel { index_t kFlatK = GemmPipeline::BlockGemmShape::flatKPerWarp * - (splitk_batch_offset.splitted_k / + (k_size / TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{})); index_t kFlatN = kargs.N * kargs.K / kFlatK; @@ -687,7 +687,7 @@ struct UniversalGemmKernel { return make_naive_tensor_view( bs_ptr[i], - make_tuple(kargs.N, splitk_batch_offset.splitted_k), + make_tuple(kargs.N, k_size), make_tuple(kargs.stride_Bs[i], 1), number{}, number<1>{}); @@ -962,7 +962,7 @@ struct UniversalGemmKernel // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( - as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); + as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); @@ -1018,7 +1018,7 @@ struct UniversalGemmKernel // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( - as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); + as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 993df2ec40..32230bbce2 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -4,6 +4,7 @@ add_subdirectory(gemm_weight_preshuffle) add_subdirectory(batched_gemm) add_subdirectory(grouped_gemm) add_subdirectory(gemm_multi_d) +add_subdirectory(gemm_streamk) add_subdirectory(data_type) add_subdirectory(container) add_subdirectory(elementwise) diff --git a/test/ck_tile/gemm_streamk/CMakeLists.txt b/test/ck_tile/gemm_streamk/CMakeLists.txt new file mode 100644 index 0000000000..e00874ba07 --- /dev/null +++ b/test/ck_tile/gemm_streamk/CMakeLists.txt @@ -0,0 +1,7 @@ +# Currently test_ck_tile_streamk is only built on gfx9 +if(GPU_TARGETS MATCHES "gfx9") + #TODO: support all arches + add_gtest_executable(test_ck_tile_streamk test_gemm_streamk.cpp) +else() + message(DEBUG "Skipping test_ck_tile_streamk tests for current target") +endif() diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk.cpp b/test/ck_tile/gemm_streamk/test_gemm_streamk.cpp new file mode 100644 index 0000000000..99c3fb397f --- /dev/null +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk.cpp @@ -0,0 +1,14 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_types.hpp" +#include "test_gemm_streamk_util.hpp" +#include "gtest/gtest.h" + +#define TEST_SUITE_NAME TestCkTileStreamK + +TYPED_TEST_SUITE(TestCkTileStreamK, KernelTypesStreamK); + +#include "test_gemm_streamk_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_cases.inc b/test/ck_tile/gemm_streamk/test_gemm_streamk_cases.inc new file mode 100644 index 0000000000..1db7ef0fb0 --- /dev/null +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_cases.inc @@ -0,0 +1,118 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M256_N256_K256_DP) +{ + + ck_tile::index_t M = 256; + ck_tile::index_t N = 256; + ck_tile::index_t K = 256; + uint32_t num_sk_blocks = 0; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M256_N256_K256_SKBlocks4) +{ + + ck_tile::index_t M = 256; + ck_tile::index_t N = 256; + ck_tile::index_t K = 256; + uint32_t num_sk_blocks = 4; + + this->Run(M, N, K, num_sk_blocks); +} + +// TODO: Renable this test once reduction is implemented +TYPED_TEST(TEST_SUITE_NAME, StreamK_M256_N256_K256_SKBlocks12) +{ + GTEST_SKIP() << "Skipping this test: There are precision issues with atomics due to >=3 WGs " + "contributing to each macro tile in C"; + + ck_tile::index_t M = 256; + ck_tile::index_t N = 256; + ck_tile::index_t K = 256; + uint32_t num_sk_blocks = 12; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M256_N256_K256_SKBlocks8) +{ + + ck_tile::index_t M = 256; + ck_tile::index_t N = 256; + ck_tile::index_t K = 256; + uint32_t num_sk_blocks = 8; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M512_N512_K512_DP) +{ + + ck_tile::index_t M = 512; + ck_tile::index_t N = 512; + ck_tile::index_t K = 512; + uint32_t num_sk_blocks = 0; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M512_N512_K512_SKBlocks16) +{ + + ck_tile::index_t M = 512; + ck_tile::index_t N = 512; + ck_tile::index_t K = 512; + uint32_t num_sk_blocks = 16; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M512_N512_K512_SKBlocks8) +{ + + ck_tile::index_t M = 512; + ck_tile::index_t N = 512; + ck_tile::index_t K = 512; + uint32_t num_sk_blocks = 8; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M3840_N4096_K4096_DP) +{ + + ck_tile::index_t M = 3840; + ck_tile::index_t N = 4096; + ck_tile::index_t K = 4096; + uint32_t num_sk_blocks = 0; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M3840_N4096_K4096_SKBlocks64) +{ + + ck_tile::index_t M = 3840; + ck_tile::index_t N = 4096; + ck_tile::index_t K = 4096; + uint32_t num_sk_blocks = 64; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_Unsupported_Reduction) +{ + + ck_tile::index_t M = 3840; + ck_tile::index_t N = 4096; + ck_tile::index_t K = 4096; + uint32_t num_sk_blocks = 64; + + EXPECT_THROW(this->Run(M, N, K, num_sk_blocks, ck_tile::StreamKReductionStrategy::Reduction), + std::runtime_error); +} diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp new file mode 100644 index 0000000000..399f3f11e8 --- /dev/null +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp @@ -0,0 +1,25 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" + +using F16 = ck_tile::half_t; +using F32 = float; +using BF16 = ck_tile::bf16_t; + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// clang-format off +using KernelTypesStreamK = ::testing::Types< +// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType + std::tuple< Row, Col, Row, F16, F16, F32, F16>, + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16> +>; + +// clang-format on diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp new file mode 100644 index 0000000000..b8a55b024d --- /dev/null +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp @@ -0,0 +1,282 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + + // The logic below may need to become more advanced once bugs in Stream-K Tile Partitioner are + // resolved. Because the number of WGs contributing to a macro tile in C may not be the same for + // all macro tiles in C. + + // Calculate error due to more than 1 WG contributing to the same macro tile in C + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +class TestCkTileStreamK : public ::testing::Test +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = std::tuple_element_t<2, Tuple>; + using ADataType = std::tuple_element_t<3, Tuple>; + using BDataType = std::tuple_element_t<4, Tuple>; + using AccDataType = std::tuple_element_t<5, Tuple>; + using CDataType = std::tuple_element_t<6, Tuple>; + using DsLayout = ck_tile::tuple<>; + using DsDataType = ck_tile::tuple<>; + + template + void invoke_streamk(const ck_tile::StreamKHostArgs& args, + const ck_tile::stream_config& s, + int num_cu, + int occupancy) + { + + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 128; + constexpr ck_tile::index_t K_Tile = 32; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool kPadM = PadM; + constexpr bool kPadN = PadN; + constexpr bool kPadK = PadK; + constexpr bool preshuffle = Preshuffle; + + constexpr bool DoubleSmemBuffer = false; + constexpr int kBlockPerCu = 1; + constexpr bool StructuredSparsity = false; + constexpr bool NumWaveGroup = 1; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile::StreamKTilePartitioner; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + + // We create the GEMM pipeline without specifying has_hot_loop or tail_num. + // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K + // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K + // Kernel's RunGemm function. This is a similar pattern used by grouped GEMM. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + // For initial testing, we will just test with one pipeline. + // More extensive testing is coming later and will test other pipelines. + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + M_Warp, + N_Warp, + M_Warp_Tile, + N_Warp_Tile, + K_Warp_Tile, + UniversalGemmProblem::TransposeC, + memory_operation>>; + + using Kernel = ck_tile::StreamKKernel; + + auto kargs = Kernel::MakeKernelArgs(args, num_cu, occupancy); + + if(!Kernel::IsSupportedArgument(kargs)) + { + EXPECT_TRUE(false); + } + + dim3 grid_dims = Kernel::GridSize(kargs.tile_partitioner); + dim3 block_dims = Kernel::BlockSize(); + + ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grid_dims, block_dims, 0, kargs)); + }; + + Run(ck_tile::integral_constant{}); + } + + public: + // Since Stream-K is build on gfx9, the lower bound for CUs is 104. Thus, we default num_cu to + // 104 and occupancy to 1 to ensure tests are reproducible on different architectures. + void Run(ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + uint32_t num_sk_blocks = 0xffffffff, + ck_tile::StreamKReductionStrategy reduction_strategy = + ck_tile::StreamKReductionStrategy::Atomic, + int occupancy = 1, + int num_cu = 104, + ck_tile::index_t stride_A = 0, + ck_tile::index_t stride_B = 0, + ck_tile::index_t stride_C = 0) + { + + using namespace ck_tile::literals; + + if(reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction) + { + throw std::runtime_error("Reduction Strategy is current unsupported!\n"); + } + + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if constexpr(std::is_same_v) + { + return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + stride_A = f_get_default_stride(M, K, stride_A, ALayout{}); + stride_B = f_get_default_stride(K, N, stride_B, BLayout{}); + stride_C = f_get_default_stride(M, N, stride_C, CLayout{}); + + ck_tile::HostTensor a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{})); + ck_tile::HostTensor b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{})); + ck_tile::HostTensor c_m_n_dev_result( + f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + + ck_tile::FillUniformDistributionIntegerValue{-5, 5, /*seed*/ 11939}(a_m_k); + ck_tile::FillUniformDistributionIntegerValue{-5, 5, /*seed*/ 11940}(b_k_n); + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + a_m_k_dev_buf.ToDevice(a_m_k.data()); + b_k_n_dev_buf.ToDevice(b_k_n.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + c_m_n_dev_buf.GetDeviceBuffer(), + M, + N, + K, + stride_A, + stride_B, + stride_C, + reduction_strategy, + num_sk_blocks}; + + invoke_streamk( + args, ck_tile::stream_config{nullptr, false, 0, 0, 1}, num_cu, occupancy); + + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + + ck_tile::HostTensor c_m_n_host_ref( + f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + c_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_host_ref); + + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, /*kbatch*/ 1, max_accumulated_value); + + bool pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + EXPECT_TRUE(pass); + }; +}; From 48e08c64298893c282b869775da67b5a3a97f624 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Tue, 16 Sep 2025 18:43:30 -0400 Subject: [PATCH 08/28] test(grouped_gemm): add gtests for the example/grouped_gemm_preshuffle to ensure its integrity (#2811) * test(grouped_gemm): add gtests for the example to maintain its integrity * test(grouped_gemm_preshuffle): add prefill variant to testbed to cover wider range * fix: removed residue code to make b_shuffle() work again * test(grouped_gemm_preshuffle): limit the test suite to gfx942 arch as it fails on gfx90a * build: add gfx950 as build target for gtests * test(grouped_gemm_preshuffle): temporarily disable fp8 prec tests due to numerical errors * fix(grouped_gemm_preshuffle): resolved fp8 tests failure on gfx950 by adding correct compiler flag --- .../ck_tile/17_grouped_gemm/CMakeLists.txt | 7 + test/ck_tile/CMakeLists.txt | 1 + .../grouped_gemm_preshuffle/CMakeLists.txt | 9 + .../test_grouped_gemm_preshuffle.cpp | 58 +++ ..._grouped_gemm_preshuffle_prefill_cases.inc | 61 +++ .../test_grouped_gemm_preshuffle_ut_cases.inc | 53 +++ .../test_grouped_gemm_preshuffle_util.hpp | 374 ++++++++++++++++++ 7 files changed, 563 insertions(+) create mode 100644 test/ck_tile/grouped_gemm_preshuffle/CMakeLists.txt create mode 100644 test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp create mode 100644 test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_prefill_cases.inc create mode 100644 test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_ut_cases.inc create mode 100644 test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp diff --git a/example/ck_tile/17_grouped_gemm/CMakeLists.txt b/example/ck_tile/17_grouped_gemm/CMakeLists.txt index 8e8026d88d..f97cc03d2a 100644 --- a/example/ck_tile/17_grouped_gemm/CMakeLists.txt +++ b/example/ck_tile/17_grouped_gemm/CMakeLists.txt @@ -1,3 +1,10 @@ add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp) add_executable(tile_example_quant_grouped_gemm EXCLUDE_FROM_ALL quant_grouped_gemm.cpp) add_executable(tile_example_grouped_gemm_preshuffle EXCLUDE_FROM_ALL grouped_gemm_preshuffle.cpp) + + +set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() +target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS}) \ No newline at end of file diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 32230bbce2..9314d4b795 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -3,6 +3,7 @@ add_subdirectory(gemm) add_subdirectory(gemm_weight_preshuffle) add_subdirectory(batched_gemm) add_subdirectory(grouped_gemm) +add_subdirectory(grouped_gemm_preshuffle) add_subdirectory(gemm_multi_d) add_subdirectory(gemm_streamk) add_subdirectory(data_type) diff --git a/test/ck_tile/grouped_gemm_preshuffle/CMakeLists.txt b/test/ck_tile/grouped_gemm_preshuffle/CMakeLists.txt new file mode 100644 index 0000000000..68120efc7e --- /dev/null +++ b/test/ck_tile/grouped_gemm_preshuffle/CMakeLists.txt @@ -0,0 +1,9 @@ +set(EXAMPLE_GEMM_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + +if(GPU_TARGETS MATCHES "gfx94|gfx95") + add_gtest_executable(test_ck_tile_grouped_gemm_preshuffle test_grouped_gemm_preshuffle.cpp) + target_compile_options(test_ck_tile_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +endif() diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp new file mode 100644 index 0000000000..cf10853b3f --- /dev/null +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_grouped_gemm_preshuffle_util.hpp" + +using F16 = ck_tile::half_t; +using F8 = ck_tile::fp8_t; +using F32 = float; +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// Custom tuple-like structure for kernel configuration +template +struct KernelConfig +{ + using ALayoutType = ALayout_; + using BLayoutType = BLayout_; + using CLayoutType = CLayout_; + using ADataType = ADataType_; + using BDataType = BDataType_; + using AccDataType = AccDataType_; + using CDataType = CDataType_; + + static constexpr int M_Tile_ = M_Tile_val_; + static constexpr int N_Tile_ = N_Tile_val_; + static constexpr int K_Tile_ = K_Tile_val_; + static constexpr int BlockPerCu_ = BlockPerCu_val_; +}; + +// clang-format off +using KernelTypes = ::testing::Types< + // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, M_Tile, N_Tile, K_Tile, BlockPerCu + KernelConfig< Row, Col, Row, F16, F16, F32, F16, 16, 64, 256, 1>, + KernelConfig< Row, Col, Row, F8, F8, F32, F16, 16, 64, 256, 1>, + KernelConfig< Row, Col, Row, F16, F16, F32, F16, 128, 128, 128, 2>, + KernelConfig< Row, Col, Row, F8, F8, F32, F16, 128, 128, 128, 2> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGroupedGemmPreshuffle, KernelTypes); + +#include "test_grouped_gemm_preshuffle_ut_cases.inc" +#include "test_grouped_gemm_preshuffle_prefill_cases.inc" diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_prefill_cases.inc b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_prefill_cases.inc new file mode 100644 index 0000000000..340d807ba2 --- /dev/null +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_prefill_cases.inc @@ -0,0 +1,61 @@ +#pragma once + +// Test with prefill config struct +TYPED_TEST(TestCkTileGroupedGemmPreshuffle, PrefillVariant) +{ + const int group_count = 4; + const int kbatch = 1; + std::vector Ms; + std::vector Ns; + std::vector Ks; + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + for(int i = 0; i < group_count; i++) + { + + Ms.push_back(256 + 128 * i); + Ns.push_back(256 + 128 * i); + Ks.push_back(128 * (i + 1)); + + stride_As.push_back(Ks[i]); + stride_Bs.push_back(Ks[i]); + stride_Cs.push_back(Ns[i]); + } + + this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count); +} + +TYPED_TEST(TestCkTileGroupedGemmPreshuffle, VariedDimensions) +{ + const int group_count = 6; + const int kbatch = 1; + std::vector Ms; + std::vector Ns; + std::vector Ks; + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + std::vector> test_cases = {{64, 128, 256}, + {128, 256, 512}, + {256, 512, 1024}, + {512, 256, 128}, + {128, 128, 128}, + {64, 512, 256}}; + + for(int i = 0; i < group_count; i++) + { + auto [M, N, K] = test_cases[i]; + Ms.push_back(M); + Ns.push_back(N); + Ks.push_back(K); + + stride_As.push_back(Ks[i]); + stride_Bs.push_back(Ks[i]); + stride_Cs.push_back(Ns[i]); + } + + this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count); +} diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_ut_cases.inc b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_ut_cases.inc new file mode 100644 index 0000000000..beca5e62b5 --- /dev/null +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_ut_cases.inc @@ -0,0 +1,53 @@ +#pragma once + +// kPadK is not needed for these k values +TYPED_TEST(TestCkTileGroupedGemmPreshuffle, kPadKFalse) +{ + const int group_count = 4; + const int kbatch = 1; + std::vector Ms; + std::vector Ns; + std::vector Ks; + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + for(int i = 0; i < group_count; i++) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(256 + 512 * i); + Ks.push_back(512 + 256 * i); + + stride_As.push_back(Ks[i]); + stride_Bs.push_back(Ks[i]); + stride_Cs.push_back(Ns[i]); + } + + this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count); +} + +// kPadK is needed to be true for these k values +TYPED_TEST(TestCkTileGroupedGemmPreshuffle, kPadKTrue) +{ + const int group_count = 4; + const int kbatch = 1; + std::vector Ms; + std::vector Ns; + std::vector Ks; + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + for(int i = 0; i < group_count; i++) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(256 + 512 * i); + Ks.push_back(512 + 128 * i); + + stride_As.push_back(Ks[i]); + stride_Bs.push_back(Ks[i]); + stride_Cs.push_back(Ns[i]); + } + + this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count); +} diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp new file mode 100644 index 0000000000..799a5f2907 --- /dev/null +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp @@ -0,0 +1,374 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" + +template +constexpr ck_tile::index_t get_k_warp_tile_flatmm() +{ +#if defined(CK_GFX950_SUPPORT) + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 64; + else + return sizeof(PrecType) == 2 ? 32 : 128; +#else + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 32; + else + return sizeof(PrecType) == 2 ? 32 : 64; +#endif +} + +template +class TestCkTileGroupedGemmPreshuffle : public ::testing::Test +{ + protected: + using ALayout = typename Tuple::ALayoutType; + using BLayout = typename Tuple::BLayoutType; + using CLayout = typename Tuple::CLayoutType; + using ADataType = typename Tuple::ADataType; + using BDataType = typename Tuple::BDataType; + using AccDataType = typename Tuple::AccDataType; + using CDataType = typename Tuple::CDataType; + using PrecType = BDataType; + using DsLayout = ck_tile::tuple<>; // not used + using DsDataType = ck_tile::tuple<>; // not used + + static const bool kPadM = false; + static const bool kPadN = false; + static const bool kPadK = true; // preshuffle pipeline requires k padding + + static const int kBlockPerCu = Tuple::BlockPerCu_; + + // Tile dimensions from tuple + static const ck_tile::index_t M_Tile = Tuple::M_Tile_; + static const ck_tile::index_t N_Tile = Tuple::N_Tile_; + static const ck_tile::index_t K_Tile = Tuple::K_Tile_; + + static const ck_tile::index_t M_Warp = 1; + static const ck_tile::index_t N_Warp = 4; + static const ck_tile::index_t K_Warp = 1; + + static const ck_tile::index_t M_Warp_Tile = 16; + static const ck_tile::index_t N_Warp_Tile = 16; + static const ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + + static constexpr bool DoubleSmemBuffer = true; // preshuffle v2 uses ping-pong smem + static constexpr bool TransposeC = false; // transpose c is not supported + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + + template + auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) + { + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); + } + + using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; + inline std::size_t get_workspace_size(const std::vector& gemm_descs) + { + return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); + } + + template + auto shuffle_b(const ck_tile::HostTensor& t) + { + assert(t.get_lengths().size() == 2); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + constexpr int divisor = N_Warp_Tile == 32 ? 2 : 4; + ck_tile::HostTensor t_view( + {n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } + + template + void invoke_grouped_gemm(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) + { + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + + // for testing purposes, we can hardcode the values here as we what is compatible with + // pipeline + using GemmUniversalTraits = + ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * K_Tile; + const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t num_loop = + ck_tile::GemmSpatiallyLocalTilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto memory_operation = memory_operation_.value; + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = + ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); + const dim3 grids = Kernel::GridSize(gemm_descs); + const dim3 blocks = Kernel::BlockSize(); + + ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(gemm_descs[0].k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + // EXPECT TO FAIL because splitk is not supported + EXPECT_FALSE(true); + } + }; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + } + + public: + void Run(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + std::vector& stride_As, + std::vector& stride_Bs, + std::vector& stride_Cs, + const int kbatch = 1, + const int group_count = 16) + { + + using namespace ck_tile::literals; + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if constexpr(std::is_same_v) + { + return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + std::vector> a_m_k_tensors; + std::vector> b_k_n_tensors; + std::vector> c_m_n_tensors; + + a_m_k_tensors.reserve(group_count); + b_k_n_tensors.reserve(group_count); + c_m_n_tensors.reserve(group_count); + + std::vector> a_m_k_dev_buf; + std::vector> b_k_n_dev_buf; + std::vector> c_m_n_dev_buf; + + a_m_k_dev_buf.reserve(group_count); + b_k_n_dev_buf.reserve(group_count); + c_m_n_dev_buf.reserve(group_count); + + std::vector gemm_descs; + gemm_descs.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + const ck_tile::index_t M = Ms[i]; + const ck_tile::index_t N = Ns[i]; + const ck_tile::index_t K = Ks[i]; + + stride_As[i] = f_get_default_stride(M, K, stride_As[i], ALayout{}); + stride_Bs[i] = f_get_default_stride(K, N, stride_Bs[i], BLayout{}); + stride_Cs[i] = f_get_default_stride(M, N, stride_Cs[i], CLayout{}); + + a_m_k_tensors.push_back(ck_tile::HostTensor( + f_host_tensor_descriptor(M, K, stride_As[i], ALayout{}))); + b_k_n_tensors.push_back(ck_tile::HostTensor( + f_host_tensor_descriptor(K, N, stride_Bs[i], BLayout{}))); + c_m_n_tensors.push_back(ck_tile::HostTensor( + f_host_tensor_descriptor(M, N, stride_Cs[i], CLayout{}))); + + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); + + // Host-side preshuffle of B + auto b_shuffle_host = shuffle_b(b_k_n_tensors[i]); + + a_m_k_dev_buf.push_back(std::make_unique( + a_m_k_tensors[i].get_element_space_size_in_bytes())); + b_k_n_dev_buf.push_back(std::make_unique( + b_shuffle_host.get_element_space_size_in_bytes())); + c_m_n_dev_buf.push_back(std::make_unique( + c_m_n_tensors[i].get_element_space_size_in_bytes())); + + a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data()); + b_k_n_dev_buf[i]->ToDevice(b_shuffle_host.data()); + c_m_n_dev_buf[i]->SetZero(); + c_m_n_tensors[i].SetZero(); + + const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer(); + const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer(); + void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); + + gemm_descs.push_back( + {p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); + } + + ck_tile::DeviceMem gemm_workspace; + gemm_workspace.Realloc(get_workspace_size(gemm_descs)); + + invoke_grouped_gemm(gemm_descs, + ck_tile::stream_config{nullptr, false, 1}, + gemm_workspace.GetDeviceBuffer()); + + // Copy results back to host for validation + for(int i = 0; i < group_count; i++) + { + c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data()); + } + + bool pass{true}; + for(int i = 0; i < group_count; ++i) + { + ck_tile::HostTensor c_m_n_host_ref( + f_host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{})); + c_m_n_host_ref.SetZero(); + ck_tile::reference_gemm( + a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = + calculate_rtol_atol( + Ks[i], kbatch, max_accumulated_value); + pass &= ck_tile::check_err(c_m_n_tensors[i], + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + } + EXPECT_TRUE(pass); + } +}; From 2723dbd33245b76bfe716c5adc8c9fb577a4b68f Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Tue, 16 Sep 2025 18:47:21 -0400 Subject: [PATCH 09/28] feat(tile_window): print content of tile window for easier debugging (#2827) * feat(tile_window): add function to print content of tile windowof static length, given a 2D range * chore: make documentation less verbose --- include/ck_tile/core/tensor/tile_window.hpp | 52 +++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index f5ddcd278c..4cecf5fc8d 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -887,6 +887,58 @@ struct tile_window_with_static_lengths this->window_lengths_ = window_lengths; this->bottom_tensor_view_ = bottom_tensor_view; } + + /** + * @brief Print tile window elements for debugging. + * + * @tparam DataType Element data type (e.g., fp16_t, float, bf8_t) + * @param start_i Starting row (inclusive) + * @param end_i Ending row (exclusive) + * @param start_j Starting column (inclusive) + * @param end_j Ending column (exclusive) + * @param label Optional output label + * + * @note Tested on fp16. Custom types may need adjustments. + * @example tile_window.template print_tile_window_range(0, 4, 0, 8, "A"); + */ + template + CK_TILE_DEVICE void print_tile_window_range(index_t start_i, + index_t end_i, + index_t start_j, + index_t end_j, + const char* label = "") const + { + const auto& tensor_view = this->get_bottom_tensor_view(); + const auto window_origin = this->get_window_origin(); + + printf("%s Window Range [%d:%d, %d:%d] (origin: %d, %d):\n", + label, + start_i, + end_i - 1, + start_j, + end_j - 1, + window_origin[0], + window_origin[1]); + + for(index_t i = start_i; i < end_i; i++) + { + for(index_t j = start_j; j < end_j; j++) + { + // Create coordinate for this element relative to window origin + auto coord = + make_tensor_coordinate(tensor_view.get_tensor_descriptor(), + make_tuple(window_origin[0] + i, window_origin[1] + j)); + + // Get the element using thread buffer type directly + using ThreadBuf = thread_buffer; + auto buf = tensor_view.template get_vectorized_elements(coord, 0); + auto value = buf.at(number<0>{}); // Extract first element from thread buffer + printf(" %s[%d,%d] = %f", label, i, j, static_cast(value)); + } + printf("\n"); + } + printf("\n"); + } }; template From f97b2a3f5d331009188c4601bd986a7b53a1ce2b Mon Sep 17 00:00:00 2001 From: Wojciech Laskowski <77888887+wj-laskowski@users.noreply.github.com> Date: Wed, 17 Sep 2025 01:23:29 +0200 Subject: [PATCH 10/28] Added wmma support for gemm quantization: (#2841) - profiler for gemm quantization for DL/XDL - tests for gemm quantization for DL/XDL - implementation for gemm quantization for WMMA - profiler/tests for gemm qunatization for WMMA Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- example/14_gemm_quantization/CMakeLists.txt | 1 + .../gemm_wmma_quantization_int8.cpp | 211 ++++++++++++++++ .../device_gemm_wmma_cshuffle_v3_common.hpp | 5 +- .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 14 ++ .../gpu/quantization/gemm_quantization.hpp | 180 +++++++++++++- .../gpu/quantization/CMakeLists.txt | 6 + ...ation_wmma_c_shuffle_i8_i8_i8_instance.hpp | 79 ++++++ ...a_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp | 41 ++++ ...a_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp | 41 ++++ ...a_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp | 41 ++++ ...a_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp | 41 ++++ .../gemm/gemm_quantization_common.hpp | 5 +- .../profile_gemm_quantization_impl.hpp | 231 ++++++++++++++++++ profiler/src/CMakeLists.txt | 9 + profiler/src/profile_gemm_quantization.cpp | 115 +++++++++ test/CMakeLists.txt | 1 + test/quantization/CMakeLists.txt | 2 + test/quantization/gemm/CMakeLists.txt | 9 + .../gemm/test_gemm_quantization.cpp | 40 +++ .../gemm/test_gemm_quantization_ut_cases.inc | 41 ++++ .../gemm/test_gemm_quantization_util.hpp | 62 +++++ 21 files changed, 1167 insertions(+), 8 deletions(-) create mode 100644 example/14_gemm_quantization/gemm_wmma_quantization_int8.cpp create mode 100644 library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_instance.hpp create mode 100644 library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp create mode 100644 profiler/include/profiler/profile_gemm_quantization_impl.hpp create mode 100644 profiler/src/profile_gemm_quantization.cpp create mode 100644 test/quantization/CMakeLists.txt create mode 100644 test/quantization/gemm/CMakeLists.txt create mode 100644 test/quantization/gemm/test_gemm_quantization.cpp create mode 100644 test/quantization/gemm/test_gemm_quantization_ut_cases.inc create mode 100644 test/quantization/gemm/test_gemm_quantization_util.hpp diff --git a/example/14_gemm_quantization/CMakeLists.txt b/example/14_gemm_quantization/CMakeLists.txt index 8703fa3ed7..b058e7b0fa 100644 --- a/example/14_gemm_quantization/CMakeLists.txt +++ b/example/14_gemm_quantization/CMakeLists.txt @@ -1,3 +1,4 @@ add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp) +add_example_executable(example_gemm_wmma_quantization_int8 gemm_wmma_quantization_int8.cpp) add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp) add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp) diff --git a/example/14_gemm_quantization/gemm_wmma_quantization_int8.cpp b/example/14_gemm_quantization/gemm_wmma_quantization_int8.cpp new file mode 100644 index 0000000000..a3023997a1 --- /dev/null +++ b/example/14_gemm_quantization/gemm_wmma_quantization_int8.cpp @@ -0,0 +1,211 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using I8 = int8_t; +using I32 = int32_t; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ActivationOp = PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; + +using ADataType = I8; +using BDataType = I8; +using AccDataType = I32; +using CShuffleDataType = I32; +using DsDataType = ck::Tuple<>; +using EDataType = I8; + +using ALayout = Col; +using BLayout = Row; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3< + ALayout, + BLayout, + DsLayout, + ELayout, + ADataType, + BDataType, + DsDataType, + EDataType, + AccDataType, + CShuffleDataType, + ActivationOp, + ActivationOp, + CDEElementOp, + GemmDefault, + 256, + 128, + 128, + 64, + 8, + 8, + 16, + 16, + 4, + 2, + S<4, 64, 1>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 1, + 8, + true, + S<4, 64, 1>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 1, + 8, + true, + 1, + 1, + S<1, 32, 1, 8>, + S<1>, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + I8, + I8>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main(int /* argc */, char* /* argv */[]) +{ + bool do_verification = true; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideE = N; + + float requant_scale = 0.03; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = PassThrough{}; + auto b_element_op = PassThrough{}; + auto cde_element_op = CDEElementOp{requant_scale, ActivationOp{}}; + + // device GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + auto argument = gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + std::array{}, + static_cast(e_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + std::array{}, + StrideE, + 1, + a_element_op, + b_element_op, + cde_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, e_m_n_host_result, a_element_op, b_element_op, cde_element_op); + + ref_invoker.Run(ref_argument); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp index 55aa7b59ee..72191632d8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp @@ -3,6 +3,7 @@ #pragma once +#include #include #include @@ -171,8 +172,8 @@ struct DeviceGemm_Wmma_CShuffleV3_Common // other hand, Split K for 16-bit outputs uses packed atomics so ScalarPerVectors cannot // be odd. constexpr bool AtomicsImplementationExists = - !(std::is_same_v || - std::is_same_v) || + !(std::is_same_v || std::is_same_v || + std::is_same_v) || (CDEShuffleBlockTransferScalarPerVectors{}[0] % 2 == 0); if(has_main_k_block_loop) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index b226730a09..59d3a6a4c5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -1065,6 +1065,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base } } + if constexpr(is_same, int8_t>::value) + { + if(karg.KBatch > 1) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "int8_t does not support KBatch > 1. KBatch: " << karg.KBatch + << " " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return true; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp index 19600a90f8..9f148618ae 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -77,6 +77,8 @@ void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( Activation_Mul_Clamp>>>& instances); #endif + +#ifdef CK_USE_XDL // Layout(A, B, C) = [Col, Row, Row] void add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( std::vector>>>& instances); +#endif + +#ifdef CK_USE_WMMA +void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances( + std::vector>>>& + instances); + +void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances( + std::vector>>>& + instances); + +void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances( + std::vector>>>& + instances); + +void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances( + std::vector>>>& + instances); +#endif template && is_same_v && @@ -195,7 +258,9 @@ struct DeviceOperationInstanceFactory && is_same_v && @@ -206,7 +271,9 @@ struct DeviceOperationInstanceFactory && is_same_v && @@ -217,12 +284,117 @@ struct DeviceOperationInstanceFactory>; + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif + + return op_ptrs; + } +}; + +template +struct DeviceOperationInstanceFactory>> +{ + using DeviceOp = DeviceGemmMultipleDSplitK>; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_USE_WMMA + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v) + { + add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances( + op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v) + { + add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances( + op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v) + { + add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances( + op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v) + { + add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances( + op_ptrs); + } + } + } +#endif + return op_ptrs; } }; @@ -230,4 +402,4 @@ struct DeviceOperationInstanceFactory +using device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances = std::tuple< + // clang-format off + //################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| ComputeTypeA| ComputeTypeB| + //################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| | | + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, false, 1, 1, S<1, 16, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t> + // clang-format on + >; + +template +using device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances = std::tuple< + // clang-format off + //################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| ComputeTypeA| ComputeTypeB| + //################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| | | + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 64, 1, 2>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t> + // clang-format on + >; + +template +using device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances = std::tuple< + // clang-format off + //################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| ComputeTypeA| ComputeTypeB| + //################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| | | + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, false, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, false, 1, 1, S<1, 32, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t> + // clang-format on + >; + +template +using device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances = std::tuple< + // clang-format off + //################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| ComputeTypeA| ComputeTypeB| + //################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| | | + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 32, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 16, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..a3838bb398 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances< + Mul_Clamp, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v3>{}); + add_device_operation_instances( + instances, + device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances< + Mul_Clamp, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..31ff723166 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances< + Mul_Clamp, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v3>{}); + add_device_operation_instances( + instances, + device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances< + Mul_Clamp, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..07a632a77c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances< + Mul_Clamp, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v3>{}); + add_device_operation_instances( + instances, + device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances< + Mul_Clamp, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..ed9cc908ef --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances< + Mul_Clamp, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v3>{}); + add_device_operation_instances( + instances, + device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances< + Mul_Clamp, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/gemm_quantization_common.hpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/gemm_quantization_common.hpp index e7c2500fef..a4eb29c7a1 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/gemm/gemm_quantization_common.hpp +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/gemm_quantization_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -33,7 +33,8 @@ using Relu_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp< using Add_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; using Add_Relu_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; -static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; } // namespace instance } // namespace device diff --git a/profiler/include/profiler/profile_gemm_quantization_impl.hpp b/profiler/include/profiler/profile_gemm_quantization_impl.hpp new file mode 100644 index 0000000000..a115a41a34 --- /dev/null +++ b/profiler/include/profiler/profile_gemm_quantization_impl.hpp @@ -0,0 +1,231 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_gemm_quantization_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideE, + float requant_scale = 0.03f) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + } + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using MulClamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; + + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using ActivationOp = PassThrough; + using CDEElementOp = MulClamp; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{requant_scale, ActivationOp{}}; + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, + BLayout, + ck::Tuple<>, + ELayout, + ADataType, + BDataType, + ck::Tuple<>, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Activation_Mul_Clamp>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // run reference + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n)); + } + } + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + bool pass = true; + + // profile device operation instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init E to zero before profiling a kernel + e_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); + + if(do_log) + { + LogRangeAsType( + std::cout << "e_m_n_device_result: ", e_m_n_device_result.mData, ",") + << std::endl; + + LogRangeAsType( + std::cout << "e_m_n_host_result: ", e_m_n_host_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 7cfdc5bfc9..31f684fe75 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -32,6 +32,7 @@ set(PROFILER_OPS profile_conv_tensor_rearrange.cpp profile_transpose.cpp profile_permute_scale.cpp + profile_gemm_quantization.cpp ) if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") @@ -112,6 +113,10 @@ if(DL_KERNELS) list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp) endif() +if(CK_ENABLE_INT8) + list(APPEND PROFILER_OPS profile_gemm_quantization.cpp) +endif() + set(PROFILER_SOURCES profiler.cpp) foreach(SOURCE ${PROFILER_OPS}) string(REGEX REPLACE "profile_(.+)\.cpp" "\\1" OP_NAME ${SOURCE}) @@ -248,6 +253,10 @@ if(DL_KERNELS) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance) endif() +if(CK_ENABLE_INT8) + list(APPEND DEVICE_INSTANCES device_quantization_instance) +endif() + set(PROFILER_LIBS utility getopt::getopt) foreach(LIB ${DEVICE_INSTANCES}) string(REGEX REPLACE "device_(.+)_instance" "\\1" INSTANCE_NAME ${LIB}) diff --git a/profiler/src/profile_gemm_quantization.cpp b/profiler/src/profile_gemm_quantization.cpp new file mode 100644 index 0000000000..d28dd60dce --- /dev/null +++ b/profiler/src/profile_gemm_quantization.cpp @@ -0,0 +1,115 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "profiler/profile_gemm_quantization_impl.hpp" +#include "profiler_operation_registry.hpp" + +#define OP_NAME "gemm_quantization" +#define OP_DESC "GEMM Quantization" + +using INT8 = int8_t; +using INT32 = int32_t; + +int profile_gemm_quantization(int argc, char* argv[]) +{ + enum struct MatrixLayout + { + MK_KN_MN, // 0: + MK_NK_MN, // 1: + KM_KN_MN, // 2: + KM_NK_MN, // 3: + }; + + if(argc != 14) + { + // clang-format off + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: matrix layout (0: E[m, n] = A[m, k] * B[k, n];\n"); + printf(" 1: E[m, n] = A[m, k] * B[n, k];\n"); + printf(" 2: E[m, n] = A[k, m] * B[k, n];\n"); + printf(" 3: E[m, n] = A[k, m] * B[n, k])\n"); + printf("arg3: verification (0: no; 1: yes)\n"); + printf("arg4: initialization (0: no init; default: integer value)\n"); + printf("arg5: print tensor value (0: no; 1: yes)\n"); + printf("arg6: time kernel (0=no, 1=yes)\n"); + printf("arg7 to 12: M, N, K, StrideA, StrideB, StrideE\n"); + printf("arg13: requant_scale (float, e.g., 0.03)\n"); + // clang-format on + exit(1); + } + + const auto layout = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const bool time_kernel = std::stoi(argv[6]); + + const int M = std::stoi(argv[7]); + const int N = std::stoi(argv[8]); + const int K = std::stoi(argv[9]); + + const int StrideA = std::stoi(argv[10]); + const int StrideB = std::stoi(argv[11]); + const int StrideE = std::stoi(argv[12]); + + const float requant_scale = std::stof(argv[13]); + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_layout, auto b_layout, auto e_layout) { + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using ELayout = decltype(e_layout); + + bool pass = ck::profiler::profile_gemm_quantization_impl(do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + StrideA, + StrideB, + StrideE, + requant_scale); + + return pass ? 0 : 1; + }; + + if(layout == MatrixLayout::MK_KN_MN) + { + return profile(Row{}, Row{}, Row{}); + } + else if(layout == MatrixLayout::MK_NK_MN) + { + return profile(Row{}, Col{}, Row{}); + } + else if(layout == MatrixLayout::KM_KN_MN) + { + return profile(Col{}, Row{}, Row{}); + } + else if(layout == MatrixLayout::KM_NK_MN) + { + return profile(Col{}, Col{}, Row{}); + } + else + { + std::cout << "this layout is not implemented" << std::endl; + return 1; + } +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_quantization); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index f898f67685..cedac568db 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -277,6 +277,7 @@ add_subdirectory(conv_tensor_rearrange) add_subdirectory(transpose) add_subdirectory(permute_scale) add_subdirectory(wrapper) +add_subdirectory(quantization) if(SUPPORTED_GPU_TARGETS MATCHES "gfx11") add_subdirectory(wmma_op) endif() diff --git a/test/quantization/CMakeLists.txt b/test/quantization/CMakeLists.txt new file mode 100644 index 0000000000..89a99f5e5d --- /dev/null +++ b/test/quantization/CMakeLists.txt @@ -0,0 +1,2 @@ +add_custom_target(test_quantization) +add_subdirectory(gemm) diff --git a/test/quantization/gemm/CMakeLists.txt b/test/quantization/gemm/CMakeLists.txt new file mode 100644 index 0000000000..630e6e09c9 --- /dev/null +++ b/test/quantization/gemm/CMakeLists.txt @@ -0,0 +1,9 @@ +add_custom_target(test_gemm_quantization_targets) + +add_gtest_executable(test_gemm_quantization test_gemm_quantization.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_quantization PRIVATE utility device_quantization_instance) + add_dependencies(test_gemm_quantization_targets test_gemm_quantization) +endif() + +add_dependencies(test_quantization test_gemm_quantization_targets) diff --git a/test/quantization/gemm/test_gemm_quantization.cpp b/test/quantization/gemm/test_gemm_quantization.cpp new file mode 100644 index 0000000000..9981ae8a41 --- /dev/null +++ b/test/quantization/gemm/test_gemm_quantization.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_quantization_impl.hpp" +#include "test_gemm_quantization_util.hpp" + +using I8 = int8_t; +using I32 = int32_t; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +class TestGemmQuantization : public ck::test::TestGemmQuantizationCommon +{ + protected: + using ProfileCall = bool (*const)(int, int, bool, bool, int, int, int, int, int, int, float); + + ProfileCall GetImpl() override + { + return &ck::profiler::profile_gemm_quantization_impl< + typename ck::test::TestGemmQuantizationCommon::ADataType, + typename ck::test::TestGemmQuantizationCommon::BDataType, + typename ck::test::TestGemmQuantizationCommon::AccDataType, + typename ck::test::TestGemmQuantizationCommon::EDataType, + typename ck::test::TestGemmQuantizationCommon::ALayout, + typename ck::test::TestGemmQuantizationCommon::BLayout, + typename ck::test::TestGemmQuantizationCommon::ELayout>; + } +}; + +using KernelTypes = ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; + +TYPED_TEST_SUITE(TestGemmQuantization, KernelTypes); + +#include "test_gemm_quantization_ut_cases.inc" diff --git a/test/quantization/gemm/test_gemm_quantization_ut_cases.inc b/test/quantization/gemm/test_gemm_quantization_ut_cases.inc new file mode 100644 index 0000000000..83a13e4a85 --- /dev/null +++ b/test/quantization/gemm/test_gemm_quantization_ut_cases.inc @@ -0,0 +1,41 @@ +#pragma once + +TYPED_TEST(TestGemmQuantization, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + for(int M : Ms) + this->Run({{M, N, K}}); +} + +TYPED_TEST(TestGemmQuantization, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 1024; + constexpr int K = 320; + + for(int M : Ms) + this->Run({{M, N, K}}); +} + +TYPED_TEST(TestGemmQuantization, MNKPadded) +{ + const std::vector Ms{127, 150, 188, 210}; + constexpr int N = 136; + constexpr int K = 280; + + for(int M : Ms) + this->Run({{M, N, K}}); +} + +TYPED_TEST(TestGemmQuantization, Regular) +{ + constexpr int M = 512; + constexpr int N = 512; + std::vector Ks{512}; + + for(int K : Ks) + this->Run({{M, N, K}}); +} diff --git a/test/quantization/gemm/test_gemm_quantization_util.hpp b/test/quantization/gemm/test_gemm_quantization_util.hpp new file mode 100644 index 0000000000..e1ca0de2db --- /dev/null +++ b/test/quantization/gemm/test_gemm_quantization_util.hpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/data_type.hpp" + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using I8 = int8_t; +using I32 = int32_t; + +namespace ck { +namespace test { + +using TestMatrixSizes = std::vector>; + +static const TestMatrixSizes DefaultTestMatrixSizes = { + {16, 32, 64}, {512, 2048, 4096}, {2048, 1024, 16}}; + +template +class TestGemmQuantizationCommon : public ::testing::Test +{ + protected: + using ADataType = std::tuple_element_t<0, Tuple>; + using BDataType = std::tuple_element_t<1, Tuple>; + using AccDataType = std::tuple_element_t<2, Tuple>; + using EDataType = std::tuple_element_t<3, Tuple>; + using ALayout = std::tuple_element_t<4, Tuple>; + using BLayout = std::tuple_element_t<5, Tuple>; + using ELayout = std::tuple_element_t<6, Tuple>; + + using ProfileCall = bool (*const)(int, int, bool, bool, int, int, int, int, int, int, float); + + virtual ProfileCall GetImpl() = 0; + + void Run(const TestMatrixSizes& lengths = DefaultTestMatrixSizes) + { + bool all_success = true; + + for(auto length : lengths) + { + int M = length[0]; + int N = length[1]; + int K = length[2]; + int StrideA = ck::is_same_v ? K : M; + int StrideB = ck::is_same_v ? N : K; + int StrideE = ck::is_same_v ? N : M; + float requant_scale = 0.03f; + + all_success = + all_success & + GetImpl()(1, 1, false, true, M, N, K, StrideA, StrideB, StrideE, requant_scale); + } + + EXPECT_TRUE(all_success); + } +}; + +} // namespace test +} // namespace ck From db79fad16fe9c9d52b72c592715adf51d25e525e Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Tue, 16 Sep 2025 21:43:19 -0400 Subject: [PATCH 11/28] fix(grouped_gemm): pipeline selection when tail_num varies per group and leads to numerical error (#2863) * fix(grouped_gemm): numerical errors on gfx950 by correctly calculating the tail num * WIP: add temp config to stress test numerical error correction * refactor: remove comments --- .../ck_tile/17_grouped_gemm/CMakeLists.txt | 9 +++--- .../ck_tile/17_grouped_gemm/grouped_gemm.cpp | 4 ++- .../ck_tile/17_grouped_gemm/grouped_gemm.hpp | 25 +++++++++++++++- .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 30 ++----------------- 4 files changed, 33 insertions(+), 35 deletions(-) diff --git a/example/ck_tile/17_grouped_gemm/CMakeLists.txt b/example/ck_tile/17_grouped_gemm/CMakeLists.txt index f97cc03d2a..1a833df6c2 100644 --- a/example/ck_tile/17_grouped_gemm/CMakeLists.txt +++ b/example/ck_tile/17_grouped_gemm/CMakeLists.txt @@ -1,10 +1,9 @@ add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp) add_executable(tile_example_quant_grouped_gemm EXCLUDE_FROM_ALL quant_grouped_gemm.cpp) add_executable(tile_example_grouped_gemm_preshuffle EXCLUDE_FROM_ALL grouped_gemm_preshuffle.cpp) - - -set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS) +set(EXAMPLE_GEMM_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) - list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() -target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS}) \ No newline at end of file +target_compile_options(tile_example_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 9975f2024b..606d98d9e2 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -356,6 +356,8 @@ int main(int argc, char* argv[]) #if CK_TILE_USE_WMMA return !run_grouped_gemm_example(argc, argv); #else - return !run_grouped_gemm_example(argc, argv); + return !run_grouped_gemm_example(argc, argv) || + !run_grouped_gemm_example(argc, argv) || + !run_grouped_gemm_example(argc, argv); #endif } diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 1ae0844032..6493a542ba 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -91,7 +91,7 @@ struct GemmConfigBase static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool Preshuffle = false; - static constexpr bool Persistent = false; + static constexpr bool Persistent = true; static constexpr bool DoubleSmemBuffer = false; }; @@ -139,6 +139,29 @@ struct GemmConfigComputeV4 : public GemmConfigBase static constexpr int kBlockPerCu = 2; }; +template +struct GemmConfigComputeV4_V2 : public GemmConfigBase +{ + // Compute V4 only support Intrawave scheduler + // Using the ping pong reader in the lds level + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + + static constexpr int kBlockPerCu = 2; +}; + template struct GemmConfigPreshuffleDecode : public GemmConfigBase { diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index dda38bbc47..e38e49f5d1 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -292,34 +292,8 @@ struct GroupedGemmKernel { __shared__ char smem_ptr_1[GetSmemSize()]; - if constexpr(UsePersistentKernel || GemmPipeline::Preshuffle) - { - - RunGemmWithPipelineSelection2LDS(a_ptr, - b_ptr, - c_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - return; - } - else - { - - Base::RunGemm2LDS({a_ptr}, - {b_ptr}, - {/*ds_ptr*/}, - c_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } + RunGemmWithPipelineSelection2LDS( + a_ptr, b_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, splitk_batch_offset, i_m, i_n); } else // SingleSmemBuffer { From c2997f2b7f1ae2729b50745baf590efea858d3e9 Mon Sep 17 00:00:00 2001 From: Gino Lu Date: Wed, 17 Sep 2025 10:54:06 +0800 Subject: [PATCH 12/28] [CK_TILE] Refine pk_fp4's fill, pack, and unpack (#2845) * fix bug * let pack/unpack return pk_fp4_t * fix clang-format --- include/ck_tile/core/numeric/pk_fp4.hpp | 84 +++++++++++++++---------- include/ck_tile/host/fill.hpp | 13 +++- test/ck_tile/data_type/test_pk_fp4.cpp | 9 ++- 3 files changed, 70 insertions(+), 36 deletions(-) diff --git a/include/ck_tile/core/numeric/pk_fp4.hpp b/include/ck_tile/core/numeric/pk_fp4.hpp index f25b98f5a0..8b78990d08 100644 --- a/include/ck_tile/core/numeric/pk_fp4.hpp +++ b/include/ck_tile/core/numeric/pk_fp4.hpp @@ -23,7 +23,8 @@ using fp32x2_t = float __attribute__((ext_vector_type(2))); using fp16x2_t = _Float16 __attribute__((ext_vector_type(2))); using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2))); -CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float x, float scale = 1.f); +struct pk_float4_e2m1_t; +CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t float_to_pk_fp4(const float& x, float scale = 1.f); // TODO: Add stochastic method struct pk_float4_e2m1_t @@ -31,7 +32,7 @@ struct pk_float4_e2m1_t // TODO: Can we merge raw_type and type? using raw_type = uint8_t; using type = raw_type; - raw_type data; + type data; CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t() : data{type{}} {} template >> @@ -39,12 +40,12 @@ struct pk_float4_e2m1_t { } CK_TILE_HOST_DEVICE explicit constexpr pk_float4_e2m1_t(float init, float scale = 1.f) - : data{float_to_e2m1(init, scale)} + : data{float_to_pk_fp4(init, scale)} { } CK_TILE_HOST_DEVICE constexpr operator type() const { return data; } - CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; } - CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; } + CK_TILE_HOST_DEVICE constexpr type& get() { return data; } + CK_TILE_HOST_DEVICE constexpr type get() const { return data; } CK_TILE_HOST_DEVICE constexpr float to_float(float scale = 1.f) const; CK_TILE_HOST_DEVICE constexpr fp32x2_t to_fp32x2(float scale = 1.f) const; @@ -61,8 +62,19 @@ struct pk_float4_e2m1_t CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const { return to_bf16x2(); } template - CK_TILE_HOST_DEVICE constexpr raw_type unpack(number) const; - CK_TILE_HOST_DEVICE constexpr static pk_float4_e2m1_t pack(const type x0, const type x1) + CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t unpack(number) const + { + return _unpack(number{}); + } + CK_TILE_HOST_DEVICE constexpr static pk_float4_e2m1_t pack(const pk_float4_e2m1_t& x0, + const pk_float4_e2m1_t& x1) + { + return _pack(x0.get(), x1.get()); + } + + template + CK_TILE_HOST_DEVICE constexpr type _unpack(number) const; + CK_TILE_HOST_DEVICE constexpr static type _pack(const type x0, const type x1) { return (x1 << 4) | (x0 & 0b00001111); } @@ -92,7 +104,7 @@ struct pk_float4_e2m1_t }; using pk_fp4_t = pk_float4_e2m1_t; -using pk_fp4_raw_t = typename pk_fp4_t::raw_type; +using pk_fp4_raw_t = typename pk_fp4_t::type; template <> struct numeric_traits @@ -124,7 +136,7 @@ struct numeric CK_TILE_HOST_DEVICE static constexpr pk_fp4_t epsilon() { return binary_min_subnorm; } CK_TILE_HOST_DEVICE static constexpr pk_fp4_t round_error() { return binary_min_subnorm; } CK_TILE_HOST_DEVICE static constexpr pk_fp4_t zero() { return binary_zero; } - CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min() { return binary_min_subnorm; } + CK_TILE_HOST_DEVICE static constexpr pk_fp4_t denorm_min() { return binary_min_subnorm; } CK_TILE_HOST_DEVICE static constexpr bool has_inf() { return false; } // N/A @@ -136,7 +148,7 @@ struct numeric }; template -CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t pk_fp4_t::unpack(number) const +CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t pk_fp4_t::_unpack(number) const { static_assert(I < 2, "Index is out of range."); if constexpr(I == 1) @@ -202,7 +214,7 @@ CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_t::to_bf16(float scale) const #if CK_TILE_FP4_CVT_DEVICE return impl::_from_f4(data, scale); #else - return bf16_t{type_convert(convert_to_float(unpack(number<0>{}), scale))}; + return bf16_t{type_convert(convert_to_float(_unpack(number<0>{}), scale))}; #endif } @@ -211,13 +223,13 @@ CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_t::to_bf16x2(float scale) const #if CK_TILE_FP4_CVT_DEVICE return impl::_from_f4(data, scale); #else - return bf16x2_t{type_convert(convert_to_float(unpack(number<0>{}), scale)), - type_convert(convert_to_float(unpack(number<1>{}), scale))}; + return bf16x2_t{type_convert(convert_to_float(_unpack(number<0>{}), scale)), + type_convert(convert_to_float(_unpack(number<1>{}), scale))}; #endif } -// TODO: make float_to_e2m1 generic so that we can convert from directrly. -CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x, float scale) +// TODO: make it generic so that we can convert from directrly. +CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_mxfp4(float x, float scale) { #if CK_TILE_FP4_CVT_DEVICE return impl::_to_f4(x, scale); @@ -227,14 +239,20 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x, float scale) } CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x, float scale) { - return float_to_e2m1(x, scale); +#if CK_TILE_FP4_CVT_DEVICE + return impl::_to_f4(x, scale); +#else + auto res = convert_to_type(x, scale); + return pk_fp4_t::_pack(res, res); +#endif } CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float scale) { #if CK_TILE_FP4_CVT_DEVICE return impl::_to_f4(x, scale); #else - return float_to_e2m1(type_convert(x), scale); + auto res = float_to_mxfp4(type_convert(x), scale); + return pk_fp4_t::_pack(res, res); #endif } CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float scale) @@ -242,7 +260,8 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float sca #if CK_TILE_FP4_CVT_DEVICE return impl::_to_f4(x, scale); #else - return float_to_e2m1(type_convert(x), scale); + auto res = float_to_mxfp4(type_convert(x), scale); + return pk_fp4_t::_pack(res, res); #endif } CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float scale) @@ -250,7 +269,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float #if CK_TILE_FP4_CVT_DEVICE return impl::_to_f4(x, scale); #else - return pk_fp4_t::pack(float_to_e2m1(x[0], scale), float_to_e2m1(x[1], scale)); + return pk_fp4_t::_pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale)); #endif } CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale) @@ -258,7 +277,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float #if CK_TILE_FP4_CVT_DEVICE return impl::_to_f4(x, scale); #else - return pk_fp4_t::pack(float_to_e2m1(x[0], scale), float_to_e2m1(x[1], scale)); + return pk_fp4_t::_pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale)); #endif } CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale) @@ -266,7 +285,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float #if CK_TILE_FP4_CVT_DEVICE return impl::_to_f4(x, scale); #else - return pk_fp4_t::pack(float_to_e2m1(x[0], scale), float_to_e2m1(x[1], scale)); + return pk_fp4_t::_pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale)); #endif } @@ -301,7 +320,7 @@ CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const #if CK_TILE_FP4_CVT_DEVICE return impl::_from_f4(data, scale); #else - return convert_to_float(unpack(number<0>{}), scale); + return convert_to_float(_unpack(number<0>{}), scale); #endif } CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const @@ -309,8 +328,8 @@ CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const #if CK_TILE_FP4_CVT_DEVICE return impl::_from_f4(data, scale); #else - return fp32x2_t{convert_to_float(unpack(number<0>{}), scale), - convert_to_float(unpack(number<1>{}), scale)}; + return fp32x2_t{convert_to_float(_unpack(number<0>{}), scale), + convert_to_float(_unpack(number<1>{}), scale)}; #endif } @@ -319,7 +338,7 @@ CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const #if CK_TILE_FP4_CVT_DEVICE return impl::_from_f4(data, scale); #else - return fp16_t{type_convert(convert_to_float(unpack(number<0>{}), scale))}; + return fp16_t{type_convert(convert_to_float(_unpack(number<0>{}), scale))}; #endif } CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const @@ -327,28 +346,29 @@ CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const #if CK_TILE_FP4_CVT_DEVICE return impl::_from_f4(data, scale); #else - return fp16x2_t{type_convert(convert_to_float(unpack(number<0>{}), scale)), - type_convert(convert_to_float(unpack(number<1>{}), scale))}; + return fp16x2_t{type_convert(convert_to_float(_unpack(number<0>{}), scale)), + type_convert(convert_to_float(_unpack(number<1>{}), scale))}; #endif } #else CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const { - return e2m1_to_fp32_table[unpack(number<0>{})] * scale; + return e2m1_to_fp32_table[_unpack(number<0>{})] * scale; } CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const { - return fp32x2_t{e2m1_to_fp32_table[unpack(number<0>{})] * scale, e2m1_to_fp32_table[unpack(number<1>{}] * scale}; + return fp32x2_t{e2m1_to_fp32_table[_unpack(number<0>{})] * scale, e2m1_to_fp32_table[_unpack(number<1>{}] * scale}; } CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const { - return type_convert(e2m1_to_fp16_table[unpack(number<0>{})]) * scale; + return type_convert(e2m1_to_fp16_table[_unpack(number<0>{})]) * scale; } CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const { return fp16x2_t{ - type_convert(type_convert(e2m1_to_fp16_table[unpack(number<0>{})]) * scale), - type_convert(type_convert(e2m1_to_fp16_table[unpack(number<1>{})]) * scale)}; + type_convert(type_convert(e2m1_to_fp16_table[_unpack(number<0>{})]) * scale), + type_convert(type_convert(e2m1_to_fp16_table[_unpack(number<1>{})]) * + scale)}; } #endif diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp index e03881a1c7..817a46a8ea 100644 --- a/include/ck_tile/host/fill.hpp +++ b/include/ck_tile/host/fill.hpp @@ -67,7 +67,10 @@ struct FillUniformDistribution : std::random_device{}()); std::uniform_real_distribution dis(a_, b_); std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() { - return ck_tile::type_convert(dis(gen)); + if constexpr(numeric_traits::PackedSize == 2) + return ck_tile::type_convert(fp32x2_t{dis(gen), dis(gen)}); + else + return ck_tile::type_convert(dis(gen)); }); }; threads[it] = joinable_thread(thread_f); @@ -77,8 +80,12 @@ struct FillUniformDistribution { std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}()); std::uniform_real_distribution dis(a_, b_); - std::generate( - first, last, [&dis, &gen]() { return ck_tile::type_convert(dis(gen)); }); + std::generate(first, last, [&dis, &gen]() { + if constexpr(numeric_traits::PackedSize == 2) + return ck_tile::type_convert(fp32x2_t{dis(gen), dis(gen)}); + else + return ck_tile::type_convert(dis(gen)); + }); } } diff --git a/test/ck_tile/data_type/test_pk_fp4.cpp b/test/ck_tile/data_type/test_pk_fp4.cpp index 15f027e95d..b1e981624a 100644 --- a/test/ck_tile/data_type/test_pk_fp4.cpp +++ b/test/ck_tile/data_type/test_pk_fp4.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" +#include #include #include "ck_tile/core.hpp" @@ -29,6 +30,12 @@ TEST(PackedFp4, NumericLimits) EXPECT_EQ(ck_tile::numeric::epsilon(), pk_fp4_t{0b00010001}); EXPECT_EQ(ck_tile::numeric::round_error(), pk_fp4_t{0b00010001}); } +TEST(PackedFp4, fill) +{ + std::vector v_fp4(4); + ck_tile::FillUniformDistribution{1.f, 1.f}(v_fp4); + EXPECT_EQ(v_fp4[0].get(), pk_fp4_t{0b00100010}.get()); +} TEST(PackedFp4, ConvertBasic) { EXPECT_EQ(ck_tile::convert_to_type(0.0f), pk_fp4_t{0b00000000}.get()); @@ -102,7 +109,7 @@ struct SrcPkfp4Dst // ex: fp32_t -> fp4 -> bf16_t dst[i] = toDST(toPF4(src[i])); // ex: fp32x2_t -> pk_fp4 -> unpack<0> -> bf16_t - dst[i + 1] = toDST(toPF4(toPF4(input2).unpack(number<1>{}))); + dst[i + 1] = toDST(toPF4(input2).unpack(number<1>{})); } else { From 592d73ad733df36f02184b787e0548901ba414a7 Mon Sep 17 00:00:00 2001 From: pmaybank <113125070+pmaybank@users.noreply.github.com> Date: Wed, 17 Sep 2025 17:59:01 +0100 Subject: [PATCH 13/28] [CK_TILE] Add support for gfx12 in tile_engine for GEMM benchmarking (#2802) * initial work on adding support of gfx12 in tile_engine for GEMM benchmarking * add stage("Run TILE_ENGINE_GEMM Tests on gfx1201") to Jenkins config * make tile_[m/n/k] validation arch dependent --- Jenkinsfile | 62 ++++++++--- script/cmake-ck-dev.sh | 4 +- tile_engine/ops/gemm/CMakeLists.txt | 71 ++++++------ tile_engine/ops/gemm/codegen_utils.py | 5 + .../ops/gemm/configs/gfx120x_config.json | 102 ++++++++++++++++++ tile_engine/ops/gemm/validation_utils.py | 57 +++++++++- 6 files changed, 249 insertions(+), 52 deletions(-) create mode 100644 tile_engine/ops/gemm/configs/gfx120x_config.json diff --git a/Jenkinsfile b/Jenkinsfile index 9d1af7c5d9..efe08a7d41 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -157,9 +157,9 @@ def getDockerImage(Map conf=[:]){ image = getDockerImageName() echo "Using default docker: ${image}" } - //Check if image exists + //Check if image exists def retimage - try + try { echo "Pulling image: ${image}" retimage = docker.image("${image}") @@ -232,7 +232,7 @@ def cmake_build(Map conf=[:]){ def setup_args = conf.get("setup_args","") // make sure all unit tests always run on develop branch def runAllUnitTests = (env.BRANCH_NAME == "develop") ? true : params.RUN_ALL_UNIT_TESTS - + if (prefixpath != "/usr/local"){ setup_args = setup_args + " -DCMAKE_PREFIX_PATH=${prefixpath} " } @@ -357,7 +357,7 @@ def cmake_build(Map conf=[:]){ "build_cmd", "${build_envs} ninja -j${nt} ${config_targets}" ) - + cmd = conf.get("cmd", """ ${setup_cmd} ${build_cmd} @@ -449,7 +449,7 @@ def buildHipClangJob(Map conf=[:]){ checkout scm def prefixpath = conf.get("prefixpath", "/opt/rocm") - // Jenkins is complaining about the render group + // Jenkins is complaining about the render group def dockerOpts if ( params.BUILD_INSTANCES_ONLY ){ dockerOpts = "--group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" @@ -515,7 +515,7 @@ def Build_CK(Map conf=[:]){ checkout scm def prefixpath = conf.get("prefixpath", "/opt/rocm") - // Jenkins is complaining about the render group + // Jenkins is complaining about the render group def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " @@ -719,7 +719,7 @@ def process_results(Map conf=[:]){ def image = "${env.CK_DOCKERHUB}:ck_ub22.04_rocm6.3" def prefixpath = "/opt/rocm" - // Jenkins is complaining about the render group + // Jenkins is complaining about the render group def dockerOpts="--cap-add=SYS_PTRACE --security-opt seccomp=unconfined" if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " @@ -956,20 +956,20 @@ pipeline { defaultValue: '', description: 'If you want to use a custom docker image, please specify it here (default: leave blank).') string( - name: 'ROCMVERSION', + name: 'ROCMVERSION', defaultValue: '6.4.1', description: 'Specify which ROCM version to use: 6.4.1 (default).') string( - name: 'COMPILER_VERSION', - defaultValue: '', + name: 'COMPILER_VERSION', + defaultValue: '', description: 'Specify which version of compiler to use: release, amd-staging, amd-mainline, or leave blank (default).') string( - name: 'COMPILER_COMMIT', - defaultValue: '', + name: 'COMPILER_COMMIT', + defaultValue: '', description: 'Specify which commit of compiler branch to use: leave blank to use the latest commit (default), or use some specific commit of llvm-project branch.') string( - name: 'BUILD_COMPILER', - defaultValue: '/opt/rocm/llvm/bin/clang++', + name: 'BUILD_COMPILER', + defaultValue: '/opt/rocm/llvm/bin/clang++', description: 'Build CK with /opt/rocm/bin/hipcc, /llvm-project/build/bin/clang++, or with /opt/rocm/llvm/bin/clang++ (default).') booleanParam( name: "RUN_FULL_QA", @@ -1448,6 +1448,36 @@ pipeline { cleanWs() } } + stage("Run TILE_ENGINE_GEMM Tests on gfx1201") + { + when { + beforeAgent true + expression { params.RUN_TILE_ENGINE_GEMM_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx1201") } + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER="${build_compiler()}" \ + -D CMAKE_BUILD_TYPE=Release \ + -D GPU_TARGETS="gfx1201" \ + -D GEMM_DATATYPE="fp16" \ + -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ + -DGEMM_CONFIG_FILE=gfx120x_config.json \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && \ + ninja -j64 benchmark_gemm_all && \ + python3 ../tile_engine/ops/gemm/gemm_benchmark.py . --problem-sizes "1024,1024,1024" \ + --warmup 5 --repeat 5 --verbose --json results.json && \ + ninja -j64 benchmark_gemm_fp16_rcr && \ + ninja -j64 benchmark_gemm_fp16_rrr && \ + ninja -j64 benchmark_gemm_fp16_crr && \ + ninja -j64 benchmark_gemm_fp16_ccr """ + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } } } @@ -1591,7 +1621,7 @@ pipeline { agent{ label rocmnode("gfx942") } steps{ script { - def execute_args = params.NINJA_FTIME_TRACE ? + def execute_args = params.NINJA_FTIME_TRACE ? """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER="${build_compiler()}" \ -D CMAKE_BUILD_TYPE=Release \ @@ -1600,7 +1630,7 @@ pipeline { -D CMAKE_CXX_COMPILER="${build_compiler()}" \ -D CMAKE_BUILD_TYPE=Release \ -D CMAKE_CXX_FLAGS=" -O3 " .. && ninja -j64 """ - + buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0") } cleanWs() diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 086359a79f..6220009b03 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -20,7 +20,7 @@ fi GPU_TARGETS="gfx908;gfx90a;gfx942" if [ $# -ge 1 ]; then - case "$1" in + case "$1" in gfx*) GPU_TARGETS=$1 shift 1 @@ -38,7 +38,7 @@ fi cmake \ -D CMAKE_PREFIX_PATH=/opt/rocm/ \ -D CMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ --D CMAKE_CXX_FLAGS="-ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ +-D CMAKE_CXX_FLAGS="-ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker -fbracket-depth=512" \ -D CMAKE_BUILD_TYPE=Release \ -D BUILD_DEV=ON \ -D GPU_TARGETS=$GPU_TARGETS \ diff --git a/tile_engine/ops/gemm/CMakeLists.txt b/tile_engine/ops/gemm/CMakeLists.txt index d52351af2d..77165ae0fa 100644 --- a/tile_engine/ops/gemm/CMakeLists.txt +++ b/tile_engine/ops/gemm/CMakeLists.txt @@ -13,38 +13,38 @@ function(create_individual_gemm_target datatype layout trait tile_config config_ message(WARNING "Skipping individual GEMM target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets") return() endif() - + # Parse tile configuration: format is tile_mxtile_nxtile_k_warp_mxwarp_nxwarp_k_warp_tile_mxwarp_tile_nxwarp_tile_k # First split by underscore to get three groups string(REPLACE "_" ";" config_groups ${tile_config}) list(GET config_groups 0 tile_dims) # e.g., 256x256x32 list(GET config_groups 1 warp_dims) # e.g., 4x1x1 list(GET config_groups 2 warp_tile_dims) # e.g., 16x16x16 - + # Parse tile dimensions string(REPLACE "x" ";" tile_parts ${tile_dims}) list(GET tile_parts 0 tile_m) list(GET tile_parts 1 tile_n) list(GET tile_parts 2 tile_k) - + # Parse warp dimensions string(REPLACE "x" ";" warp_parts ${warp_dims}) list(GET warp_parts 0 warp_m) list(GET warp_parts 1 warp_n) list(GET warp_parts 2 warp_k) - + # Parse warp tile dimensions string(REPLACE "x" ";" warp_tile_parts ${warp_tile_dims}) list(GET warp_tile_parts 0 warp_tile_m) list(GET warp_tile_parts 1 warp_tile_n) list(GET warp_tile_parts 2 warp_tile_k) - + set(target_name "benchmark_gemm_${datatype}_${layout}_${trait}_${tile_config}") set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") - + # Generate the single instance header for this kernel set(instance_header "${working_path}/gemm_single_${datatype}_${layout}_${trait}_${tile_config}.hpp") - + # Add custom command to generate the header file at build time add_custom_command( OUTPUT ${instance_header} @@ -60,27 +60,27 @@ function(create_individual_gemm_target datatype layout trait tile_config config_ DEPENDS ${GEMM_SOURCE_DIR}/gemm_instance_builder.py ${config_json} COMMENT "Generating ${instance_header}" ) - + # Create the executable - add_executable(${target_name} + add_executable(${target_name} ${GEMM_SOURCE_DIR}/benchmark_gemm_single.cpp ${instance_header} ) - + # Set GPU architectures set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS_INDIVIDUAL}) - + # Set compile definitions target_compile_definitions(${target_name} PRIVATE GEMM_SINGLE_INSTANCE_HPP="${instance_header}" ) - + # Include directories target_include_directories(${target_name} PRIVATE ${GEMM_SOURCE_DIR} ${working_path} ) - + # Compile options target_compile_options(${target_name} PRIVATE -Wno-undefined-func-template @@ -88,19 +88,19 @@ function(create_individual_gemm_target datatype layout trait tile_config config_ --offload-compress -include ${instance_header} ) - + # Add to collection targets add_dependencies(benchmark_gemm_all ${target_name}) add_dependencies(benchmark_gemm_${datatype} ${target_name}) add_dependencies(benchmark_gemm_${layout} ${target_name}) add_dependencies(benchmark_gemm_${datatype}_${layout} ${target_name}) - + # Add to trait-specific targets string(REPLACE "_" ";" trait_parts ${trait}) list(GET trait_parts 0 pipeline) list(GET trait_parts 1 epilogue) list(GET trait_parts 2 scheduler) - + add_dependencies(benchmark_gemm_${pipeline} ${target_name}) add_dependencies(benchmark_gemm_${epilogue} ${target_name}) add_dependencies(benchmark_gemm_${scheduler} ${target_name}) @@ -109,13 +109,13 @@ endfunction() # Function to build individual GEMM targets function(build_individual_gemm_targets datatype layout) set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") - + # Choose config file # Priority order: # 1. Environment variable GEMM_CONFIG_FILE - # 2. CMake variable GEMM_CONFIG_FILE + # 2. CMake variable GEMM_CONFIG_FILE # 3. Default based on layout - + # Check environment variable first if(DEFINED ENV{GEMM_CONFIG_FILE} AND NOT "$ENV{GEMM_CONFIG_FILE}" STREQUAL "") set(config_filename "$ENV{GEMM_CONFIG_FILE}") @@ -130,12 +130,12 @@ function(build_individual_gemm_targets datatype layout) set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json") message(STATUS " Using default config for layout ${layout}") endif() - + # Check if config file exists if(NOT EXISTS ${json_blob}) message(FATAL_ERROR "Config file not found: ${json_blob}") endif() - + # Determine number of workers for parallel generation if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL}) set(num_workers $ENV{CMAKE_BUILD_PARALLEL_LEVEL}) @@ -147,17 +147,24 @@ function(build_individual_gemm_targets datatype layout) set(num_workers 8) endif() endif() - + # Generate individual kernel files using parallel version message(STATUS "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...") message(STATUS " Working path: ${working_path}") message(STATUS " Config file: ${json_blob}") message(STATUS " Python executable: ${Python3_EXECUTABLE}") message(STATUS " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py") - + # Create working directory first file(MAKE_DIRECTORY ${working_path}) - + + message(STATUS "COMMAND: ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py + --working_path ${working_path} + --datatype ${datatype} + --layout ${layout} + --config_json ${json_blob} + --list_kernels") + # First, just list the kernels (fast operation) message(STATUS " Listing kernel configurations...") execute_process( @@ -172,11 +179,11 @@ function(build_individual_gemm_targets datatype layout) OUTPUT_VARIABLE list_output ERROR_VARIABLE list_error ) - + if(NOT ret EQUAL 0) message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}") endif() - + # Read kernel count if(EXISTS ${working_path}/gemm_kernel_count.txt) file(READ ${working_path}/gemm_kernel_count.txt kernel_count) @@ -185,7 +192,7 @@ function(build_individual_gemm_targets datatype layout) else() message(FATAL_ERROR "Kernel count file not found") endif() - + # Read kernel list and create targets if(EXISTS ${working_path}/gemm_kernel_list.txt) file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines) @@ -195,7 +202,7 @@ function(build_individual_gemm_targets datatype layout) list(GET parts 0 kernel_name) list(GET parts 1 tile_config) list(GET parts 2 trait_combo) - + # Create individual target create_individual_gemm_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}") endforeach() @@ -210,9 +217,9 @@ message(STATUS "GEMM_DATATYPE: ${GEMM_DATATYPE}") message(STATUS "GEMM_LAYOUT: ${GEMM_LAYOUT}") message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") -# Filter GPU targets to only gfx90a, gfx942, and gfx950 +# Filter GPU targets to only gfx90a, gfx942, gfx950, gfx1201 set(GEMM_GPU_TARGETS_INDIVIDUAL "") -set(DESIRED_TARGETS "gfx90a;gfx942;gfx950") +set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201") foreach(target IN LISTS SUPPORTED_GPU_TARGETS) if(target IN_LIST DESIRED_TARGETS) @@ -223,13 +230,13 @@ endforeach() # Skip build if no matching targets found if(NOT GEMM_GPU_TARGETS_INDIVIDUAL) - message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") else() message(STATUS "Building individual GEMM targets for GPU targets: ${GEMM_GPU_TARGETS_INDIVIDUAL}") # Enable parallel compilation optimizations # Set up job pools for better parallel compilation control - set_property(GLOBAL PROPERTY JOB_POOLS + set_property(GLOBAL PROPERTY JOB_POOLS compile_heavy=4 # Limit heavy compilations to prevent OOM compile_normal=16 # Allow more parallel normal compilations ) diff --git a/tile_engine/ops/gemm/codegen_utils.py b/tile_engine/ops/gemm/codegen_utils.py index 6a87193043..98595933b8 100644 --- a/tile_engine/ops/gemm/codegen_utils.py +++ b/tile_engine/ops/gemm/codegen_utils.py @@ -179,6 +179,11 @@ warp_tile_supported_combinations = { [32, 32, 64], ], }, + "gfx1201": { + "fp16_fp16_fp16": [ + [16, 16, 16], + ], + }, } # To Do: remove some unsupported combinations diff --git a/tile_engine/ops/gemm/configs/gfx120x_config.json b/tile_engine/ops/gemm/configs/gfx120x_config.json new file mode 100644 index 0000000000..6c4a5d0ec0 --- /dev/null +++ b/tile_engine/ops/gemm/configs/gfx120x_config.json @@ -0,0 +1,102 @@ +{ + "problem": { + }, + "tile_config": { + "tile_m": { + "values": [ + 256, + 128, + 64 + ] + }, + "tile_n": { + "values": [ + 256, + 128, + 64 + ] + }, + "tile_k": { + "values": [ + 256, + 128, + 64 + ] + }, + "warp_m": { + "values": [ + 4, + 2, + 1 + ] + }, + "warp_n": { + "values": [ + 4, + 2, + 1 + ] + }, + "warp_k": { + "values": [ + 1 + ] + }, + "warp_tile_m": { + "values": [ + 16 + ] + }, + "warp_tile_n": { + "values": [ + 16 + ] + }, + "warp_tile_k": { + "values": [ + 16 + ] + } + }, + "trait_config": { + "pipeline": { + "values": [ + "compv3", + "mem" + ] + }, + "scheduler": { + "values": [ + "intrawave", + "interwave" + ] + }, + "epilogue": { + "values": [ + "cshuffle", + "default" + ] + }, + "pad_m": { + "values": [ + false + ] + }, + "pad_n": { + "values": [ + false + ] + }, + "pad_k": { + "values": [ + false + ] + }, + "persistent": { + "values": [ + false, + true + ] + } + } +} diff --git a/tile_engine/ops/gemm/validation_utils.py b/tile_engine/ops/gemm/validation_utils.py index 7367f2446d..c0e109bf11 100644 --- a/tile_engine/ops/gemm/validation_utils.py +++ b/tile_engine/ops/gemm/validation_utils.py @@ -103,6 +103,36 @@ WARP_TILE_SUPPORTED_COMBINATIONS = { [32, 32, 64], ], }, + "gfx1201": { + "fp16_fp16_fp16": [ + [16, 16, 16], + ], + }, +} + +# Supported warp tile combinations for different GPU architectures and data types +WARP_SUPPORTED_COMBINATIONS = { + "gfx90a": [ + [1, 4, 1], + [2, 2, 1], + [4, 1, 1], + ], + "gfx942": [ + [1, 4, 1], + [2, 2, 1], + [4, 1, 1], + ], + "gfx950": [ + [1, 4, 1], + [2, 2, 1], + [4, 1, 1], + ], + "gfx1201": [ + [2, 4, 1], + [1, 8, 1], + [8, 1, 1], + [4, 2, 1], + ], } # Unsupported trait combinations @@ -155,9 +185,32 @@ def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) -> return (pipeline, epilogue, scheduler) not in TRAIT_UNSUPPORTED_COMBINATIONS -def validate_warp_configuration(warp_m: int, warp_n: int, warp_k: int) -> bool: +def validate_warp_configuration( + warp_m: int, + warp_n: int, + warp_k: int, + gpu_name: str = None, +) -> bool: """Validate warp configuration.""" - return (warp_m, warp_n, warp_k) in [(1, 4, 1), (2, 2, 1), (4, 1, 1)] + if gpu_name is None: + gpu_name = get_gpu_name_by_id(0) + + current_combination = [warp_m, warp_n, warp_k] + + allowed_combinations = WARP_SUPPORTED_COMBINATIONS.get(gpu_name, {}) + if not allowed_combinations: + # If GPU not recognized, try to be permissive but log warning + logging.warning(f"No warp_[m/n/k] combinations found for GPU: {gpu_name}") + return True + + # Check if current combination is in the allowed list + if current_combination not in allowed_combinations: + error_msg = ( + f"Invalid warp tile combination: {current_combination} not in allowed list. " + ) + return False + + return True def validate_dimension_alignment( From 5c4f52a02ae6e9d6a368e607c81a3edd682e9613 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Kulikowski?= Date: Wed, 17 Sep 2025 19:39:48 +0200 Subject: [PATCH 14/28] [CK][Examples] - fixing grouped_conv_bwd_weight command parser. (#2840) -added parameter to change group count for grouped_gemm examples. Signed-off-by: Michal Kulikowski --- .../grouped_gemm_xdl_fixed_nk_bias_fp16.cpp | 41 +++++++++++------- .../grouped_gemm_xdl_fixed_nk_fp16.cpp | 42 ++++++++++++------- .../grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp | 41 +++++++++++------- .../grouped_gemm_xdl_splitk_fp16.cpp | 36 +++++++++------- example/20_grouped_conv_bwd_weight/common.hpp | 4 +- 5 files changed, 101 insertions(+), 63 deletions(-) diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp index 5bdc993192..2fcc0e3cb1 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp @@ -323,6 +323,31 @@ int main(int argc, char* argv[]) problem_size.Ms = {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0}; + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else if(argc == 6) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + problem_size.group_count = std::stoi(argv[5]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (>0)\n"); + printf("arg5: group count (default=16)"); + exit(0); + } + for(int i = 0; i < problem_size.group_count; i++) { problem_size.Ns.push_back(768); @@ -333,21 +358,5 @@ int main(int argc, char* argv[]) problem_size.stride_Cs.push_back(problem_size.Ns[i]); } - if(argc == 5) - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - config.k_batch = std::stoi(argv[4]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); - printf("arg4: k_batch (>0)\n"); - exit(0); - } - return !run_grouped_gemm(problem_size, config); } diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp index 6806bd1886..fb611fd444 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp @@ -296,6 +296,32 @@ int main(int argc, char* argv[]) problem_size.group_count = 16; + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else if(argc == 6) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + problem_size.group_count = std::stoi(argv[5]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (> 0)\n"); + printf("arg5: group count (default=16)"); + + exit(0); + } + for(int i = 0; i < problem_size.group_count; i++) { problem_size.Ms.push_back(128 + rand() % 128); @@ -307,21 +333,5 @@ int main(int argc, char* argv[]) problem_size.stride_Cs.push_back(problem_size.Ns[i]); } - if(argc == 5) - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - config.k_batch = std::stoi(argv[4]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); - printf("arg4: k_batch (> 0)\n"); - exit(0); - } - return !run_grouped_gemm(problem_size, config); } diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp index 8418c10f5e..47eb6637bd 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp @@ -297,6 +297,31 @@ int main(int argc, char* argv[]) problem_size.group_count = 16; + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else if(argc == 6) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + problem_size.group_count = std::stoi(argv[5]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (> 0)\n"); + printf("arg5: group count (default=16)"); + exit(0); + } + for(int i = 0; i < problem_size.group_count; i++) { problem_size.Ms.push_back(256 + 256 * i); @@ -308,21 +333,5 @@ int main(int argc, char* argv[]) problem_size.stride_Cs.push_back(problem_size.Ns[i]); } - if(argc == 5) - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - config.k_batch = std::stoi(argv[4]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); - printf("arg4: k_batch (> 0)\n"); - exit(0); - } - return !run_grouped_gemm(problem_size, config); } diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp index 9f8f6cb1e4..16d018936b 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp @@ -66,6 +66,28 @@ int main(int argc, char* argv[]) problem_size.group_count = 16; + if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + problem_size.group_count = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: group count (default=16)"); + exit(0); + } + for(int i = 0; i < problem_size.group_count; i++) { problem_size.Ms.push_back(256 + 256 * i); @@ -77,19 +99,5 @@ int main(int argc, char* argv[]) problem_size.stride_Cs.push_back(problem_size.Ns[i]); } - if(argc == 4) - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); - exit(0); - } - return !run_grouped_gemm(problem_size, config); } diff --git a/example/20_grouped_conv_bwd_weight/common.hpp b/example/20_grouped_conv_bwd_weight/common.hpp index e0034bf7eb..9159e51eaf 100644 --- a/example/20_grouped_conv_bwd_weight/common.hpp +++ b/example/20_grouped_conv_bwd_weight/common.hpp @@ -123,7 +123,9 @@ inline bool parse_cmd_args(int argc, const ck::index_t num_dim_spatial = std::stoi(argv[4]); conv_param = ck::utils::conv::parse_conv_param( - num_dim_spatial, threshold_to_catch_partial_args, argv); + num_dim_spatial, + threshold_to_catch_partial_args + 1, // +1 because we already parsed num_dim_spatial + argv); } else { From 7c934b72ab695ecbb5b07354cb23e34f6076d25b Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Wed, 17 Sep 2025 14:04:21 -0400 Subject: [PATCH 15/28] build(grouped_gemm): added appropriate compiler flag to resolve numerical error for fp8 on gfx950 (#2868) --- example/ck_tile/17_grouped_gemm/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/example/ck_tile/17_grouped_gemm/CMakeLists.txt b/example/ck_tile/17_grouped_gemm/CMakeLists.txt index 1a833df6c2..4f3b173c55 100644 --- a/example/ck_tile/17_grouped_gemm/CMakeLists.txt +++ b/example/ck_tile/17_grouped_gemm/CMakeLists.txt @@ -7,3 +7,4 @@ if(CK_USE_OCP_FP8) endif() target_compile_options(tile_example_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) From dd7af118d7a83a4612fd72a6108e4a433913a89c Mon Sep 17 00:00:00 2001 From: yinglu Date: Thu, 18 Sep 2025 05:50:15 +0800 Subject: [PATCH 16/28] TF32 POC in Conv3d on MI30x platform #2763 (second attempt) (#2852) * Revert "Revert "feature:tf32:add initial conv3d fwd kernel support (#2763)" (#2848)" This reverts commit 03b59f8c76e48cdee4b84782017ae41feaf3f98f. * fix compile error on gf12x * only run tf32 example on gfx942 * only build tf32 instance on gfx942 * ckProfiler:only support tf32 in gfx942 * delete unuseful messages --- example/01_gemm/CMakeLists.txt | 10 ++ example/01_gemm/common.hpp | 16 ++- .../gemm_xdl_lds_direct_load_fp32_tf32.cpp | 85 +++++++++++ example/01_gemm/run_gemm_example.inc | 13 +- example/09_convnd_fwd/CMakeLists.txt | 11 +- example/09_convnd_fwd/convnd_fwd_common.hpp | 29 ++-- .../convnd_fwd_xdl_fp32_tf32.cpp | 89 ++++++++++++ example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp | 4 + .../09_convnd_fwd/run_convnd_fwd_example.inc | 27 ++-- include/ck/host_utility/device_prop.hpp | 2 + include/ck/library/utility/check_err.hpp | 2 +- .../gpu/block/blockwise_gemm_xdlops.hpp | 85 ++++++----- ...vice_gemm_xdl_cshuffle_lds_direct_load.hpp | 12 +- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 51 ++++++- ...ridwise_gemm_multiple_abd_xdl_cshuffle.hpp | 4 +- .../gridwise_gemm_multiple_d_xdl_cshuffle.hpp | 39 ++--- ...ultiple_d_xdl_cshuffle_lds_direct_load.hpp | 16 ++- .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 136 +++++++++++++++--- include/ck/utility/amd_xdlops.hpp | 41 ++++++ include/ck/utility/data_type.hpp | 35 +++++ include/ck/utility/type_convert.hpp | 13 ++ .../cpu/reference_conv_fwd.hpp | 43 +++++- .../cpu/reference_gemm.hpp | 40 ++++-- .../gpu/reference_gemm.hpp | 19 ++- .../device_operation_instance_factory.hpp | 1 + ...ouped_conv_fwd_xdl_dynamic_op_instance.hpp | 1 + .../device_grouped_conv_fwd_xdl_instance.hpp | 43 +++++- .../gpu/grouped_convolution_forward.hpp | 6 + ...grouped_convolution_forward_bias_clamp.hpp | 8 ++ ...ped_convolution_forward_bias_clamp_xdl.inc | 16 +++ .../gpu/grouped_convolution_forward_clamp.hpp | 8 ++ .../grouped_convolution_forward_clamp_xdl.inc | 16 +++ ...grouped_convolution_forward_dynamic_op.hpp | 14 +- .../gpu/grouped_convolution_forward_xdl.inc | 16 +++ .../gpu/CMakeLists.txt | 7 +- .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 3 +- ...ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp | 56 ++++++++ .../CMakeLists.txt | 59 ++++---- ..._ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in | 81 +++++++++++ .../CMakeLists.txt | 4 +- ...dhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp | 60 ++++++++ .../grouped_conv3d_fwd_clamp/CMakeLists.txt | 4 +- ...dhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp | 60 ++++++++ library/src/utility/host_tensor.cpp | 2 +- profiler/src/profile_grouped_conv_fwd.cpp | 41 ++++-- 45 files changed, 1147 insertions(+), 181 deletions(-) create mode 100644 example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp create mode 100644 example/09_convnd_fwd/convnd_fwd_xdl_fp32_tf32.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 2f9d85d51d..03bde86421 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -105,6 +105,16 @@ foreach(gpu IN LISTS GPU_TARGETS) endif() endforeach() +list(APPEND gpu_list_tf32 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0) + add_example_executable(example_gemm_xdl_lds_direct_load_fp32_tf32 gemm_xdl_lds_direct_load_fp32_tf32.cpp) + add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32_tf32) + set(target 1) + endif() +endforeach() + add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8) diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index 434f549443..e482953e46 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -310,10 +310,14 @@ bool parse_cmd_args(int argc, return true; } -template +template inline __host__ __device__ constexpr double get_rtol() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v && std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) { return 1e-3; } @@ -351,10 +355,14 @@ inline __host__ __device__ constexpr double get_rtol() } } -template +template inline __host__ __device__ constexpr double get_atol() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v && std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) { return 1e-3; } diff --git a/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp b/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp new file mode 100644 index 0000000000..9b92fad779 --- /dev/null +++ b/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "common.hpp" + +#define USING_DIRECT_LOADS 1 +#if USING_DIRECT_LOADS +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp" +#else +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#endif + +#define EXAMPLE_WITH_COMPUTE_DATATYPE + +using F32 = float; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = F32; +using ComputeDataType = ck::tf32_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +#if USING_DIRECT_LOADS +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_LdsDirectLoad +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| +// ######| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockLds| +// ######| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler | pipeline ver | gemm type | +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| +// ######| XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, + 8, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, + 1, 1, S<1, 8, 1, 8>, 4, ck::LoopScheduler::Default, ck::PipelineVersion::v4, ComputeDataType>; +// clang-format on +#else +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 8, 1, 8>, 4>; +// clang-format on +#endif +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } + +#undef EXAMPLE_WITH_COMPUTE_DATATYPE diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 3e018aad1e..08e2b8c15f 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -4,6 +4,11 @@ #pragma once #include "ck/library/utility/validation_common.hpp" +// use macro to minimize code change +#ifndef EXAMPLE_WITH_COMPUTE_DATATYPE +using ComputeDataType = AccDataType; +#endif + template bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) { @@ -218,8 +223,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) pass &= ck::utils::check_err(c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", - get_rtol(), - get_atol()); + get_rtol(), + get_atol()); #endif } @@ -249,8 +254,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) pass &= ck::utils::check_err(c_m_n_device_result, c_m_n_device_ref_result, "Error: Incorrect results!", - get_rtol(), - get_atol()); + get_rtol(), + get_atol()); } return pass == true; diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index 91c072aef7..4f174bfcbb 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -19,4 +19,13 @@ foreach(gpu IN LISTS GPU_TARGETS) add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) set(target 1) endif() -endforeach() \ No newline at end of file +endforeach() + +list(APPEND gpu_list_tf32 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0) + add_example_executable(example_convnd_fwd_xdl_fp32_tf32 convnd_fwd_xdl_fp32_tf32.cpp) + set(target 1) + endif() +endforeach() diff --git a/example/09_convnd_fwd/convnd_fwd_common.hpp b/example/09_convnd_fwd/convnd_fwd_common.hpp index b0fd6a382a..d82b56ec00 100644 --- a/example/09_convnd_fwd/convnd_fwd_common.hpp +++ b/example/09_convnd_fwd/convnd_fwd_common.hpp @@ -27,10 +27,14 @@ void print_helper_msg() << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; } -template +template inline __host__ __device__ constexpr double get_rtol() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v && std::is_same_v) + { + return 5e-3; + } + else if constexpr(std::is_same_v) { return 1e-3; } @@ -68,10 +72,14 @@ inline __host__ __device__ constexpr double get_rtol() } } -template +template inline __host__ __device__ constexpr double get_atol() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v && std::is_same_v) + { + return 1e-2; + } + else if constexpr(std::is_same_v) { return 1e-3; } @@ -116,7 +124,8 @@ template + typename DeviceConvNDFwdInstance, + typename ComputeDataType = OutDataType> bool run_grouped_conv_fwd(bool do_verification, int init_method, bool time_kernel, @@ -228,7 +237,11 @@ bool run_grouped_conv_fwd(bool do_verification, OutDataType, InElementOp, WeiElementOp, - OutElementOp>(); + OutElementOp, + 0, + 0, + 0, + ComputeDataType>(); auto ref_invoker = ref_conv.MakeInvoker(); auto ref_argument = ref_conv.MakeArgument(in, @@ -249,8 +262,8 @@ bool run_grouped_conv_fwd(bool do_verification, return ck::utils::check_err(out_device, out_host, "Error: incorrect results!", - get_rtol(), - get_atol()); + get_rtol(), + get_atol()); } return true; diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp32_tf32.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp32_tf32.cpp new file mode 100644 index 0000000000..348da7e1ef --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp32_tf32.cpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +#define EXAMPLE_WITH_COMPUTE_DATATYPE + +using InDataType = float; +using WeiDataType = float; +using AccDataType = float; +using CShuffleDataType = float; +using OutDataType = float; +using ComputeDataType = ck::tf32_t; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, // ALayout + WeiLayout, // BLayout + ck::Tuple<>, // DsLayout + OutLayout, // ELayout + InDataType, // ADataType + WeiDataType, // BDataType + AccDataType, // AccDataType + CShuffleDataType, // CShuffleDataType + ck::Tuple<>, // DsDataType + OutDataType, // EDataType + InElementOp, // AElementwiseOperation + WeiElementOp, // BElementwiseOperation + OutElementOp, // CDEElementwiseOperation + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 128, // MPerBlock + 192, // NPerBlock + 16, // KPerBlock + 4, // AK1 + 4, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 3, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 4, // ABlockTransferSrcScalarPerVector + 4, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 4, // BBlockTransferSrcScalarPerVector + 4, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 16, 1, 16>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 4, // CDEBlockTransferScalarPerVector_NPerBlock + ComputeDataType, // AComputeDataType + ComputeDataType, // BComputeDataType + ck::LoopScheduler::Default, // LoopScheduler + 1 // NumGroupsToMerge + >; + +#include "run_convnd_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } + +#undef EXAMPLE_WITH_COMPUTE_DATATYPE diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp index fde0f51bc7..c635d01d8f 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp @@ -7,6 +7,8 @@ #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#define EXAMPLE_WITH_COMPUTE_DATATYPE + using InDataType = ck::f8_t; using WeiDataType = ck::f8_t; using AccDataType = float; @@ -87,3 +89,5 @@ int main(int argc, char* argv[]) } return run_convnd_fwd_example(argc, argv) ? 0 : 1; } + +#undef EXAMPLE_WITH_COMPUTE_DATATYPE diff --git a/example/09_convnd_fwd/run_convnd_fwd_example.inc b/example/09_convnd_fwd/run_convnd_fwd_example.inc index 49852ff667..016a189d4b 100644 --- a/example/09_convnd_fwd/run_convnd_fwd_example.inc +++ b/example/09_convnd_fwd/run_convnd_fwd_example.inc @@ -3,6 +3,11 @@ #pragma once +// use macro to minimize code change +#ifndef EXAMPLE_WITH_COMPUTE_DATATYPE +using ComputeDataType = AccDataType; +#endif + bool run_convnd_fwd_example(int argc, char* argv[]) { print_helper_msg(); @@ -65,17 +70,17 @@ bool run_convnd_fwd_example(int argc, char* argv[]) InElementOp, WeiElementOp, OutElementOp, - DeviceGroupedConvNDFwdInstance>( - do_verification, - init_method, - time_kernel, - conv_param, - in_g_n_c_wis_desc, - wei_g_k_c_xs_desc, - out_g_n_k_wos_desc, - in_element_op, - wei_element_op, - out_element_op); + DeviceGroupedConvNDFwdInstance, + ComputeDataType>(do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); }; namespace ctc = ck::tensor_layout::convolution; diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 9c3967d99b..0c4f056a46 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -129,5 +129,7 @@ inline bool is_wmma_supported() return is_gfx103_supported() || is_gfx11_supported() || is_gfx12_supported(); } +inline bool is_tf32_supported() { return (ck::get_device_name() == "gfx942") ? true : false; } + } // namespace ck #endif diff --git a/include/ck/library/utility/check_err.hpp b/include/ck/library/utility/check_err.hpp index d33ecaeef8..185166f7ec 100644 --- a/include/ck/library/utility/check_err.hpp +++ b/include/ck/library/utility/check_err.hpp @@ -180,13 +180,13 @@ check_err(const Range& out, if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) { max_err = err > max_err ? err : max_err; - err_count++; if(err_count < 5) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; } res = false; + err_count++; } } if(!res) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index e848ca35b5..55015dd30f 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -49,6 +49,11 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 using ThisThreadBlock = ThisThreadBlock; + using ElementDataTypeA = + conditional_t, float, ComputeTypeA>; + using ElementDataTypeB = + conditional_t, float, ComputeTypeB>; + static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); static constexpr index_t KPerBlock = @@ -64,7 +69,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static constexpr index_t WaveSize = BlockSize / MWaves / NWaves; static constexpr auto xdlops_gemm = - XdlopsGemm{}; + XdlopsGemm{}; static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; @@ -172,6 +177,11 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, "wrong!"); + if constexpr(is_same_v || is_same_v) + { + static_assert(is_same_v, + "ComputeTypeA and ComputeTypeB must be same when one of them is tf32"); + } } __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() @@ -297,9 +307,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const { - auto a_thread_buf = make_static_buffer( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); static_for<0, MRepeat, 1>{}([&](auto m0) { @@ -321,20 +331,20 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 b_thread_buf); static_for<0, KPerThread, KPack>{}([&](auto k) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = a_thread_buf + a_thread_vec.template AsType()(i) = a_thread_buf [Number{}]; - b_thread_vec.template AsType()(i) = b_thread_buf + b_thread_vec.template AsType()(i) = b_thread_buf [Number{}]; }); using mfma_input_type_a = - typename vector_type::type; + typename vector_type::type; using mfma_input_type_b = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -361,7 +371,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -371,7 +381,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 A_K1>; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -445,6 +455,11 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 using Base::KPerThread; using Base::xdlops_gemm; + using ElementDataTypeA = + conditional_t, float, ComputeTypeA>; + using ElementDataTypeB = + conditional_t, float, ComputeTypeB>; + static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack); // 2-wave optimized blockwise gemm @@ -453,9 +468,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const { - auto a_thread_buf = make_static_buffer( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) { @@ -499,22 +514,22 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = + a_thread_vec.template AsType()(i) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(i) = + b_thread_vec.template AsType()(i) = b_thread_buf[Number{}]; }); using mfma_input_type_a = - typename vector_type::type; + typename vector_type::type; using mfma_input_type_b = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -563,7 +578,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 make_tuple(Number{}, I1, I1, Number{})); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -573,7 +588,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 A_K1>; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -622,19 +637,21 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() } else if constexpr(LoopSched == LoopScheduler::Interwave) { - return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1< + BlockSize, + FloatA, + FloatB, + FloatAcc, + AK0MK1BlockDesc, + BK0NK1BlockDesc, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack, + ComputeTypeA, + ComputeTypeB, + CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>{}; } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp index 8daaafaed1..23b0faec67 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp @@ -119,7 +119,9 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm; + PipelineVer, + ComputeDataType>; + using GridwiseGemm64 = GridwiseGemmBase; using GridwiseGemm32 = GridwiseGemmBase; @@ -214,6 +216,14 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm) + { + if(!is_tf32_supported()) + { + return false; + } + } + // Check vector load/store. { using Row = ck::tensor_layout::gemm::RowMajor; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 1412c960c7..cc8561a09f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -1003,11 +1003,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle void Print() const { + std::cout << "AComputeDataType: " << get_type_name() + << "; BComputeDataType: " << get_type_name() + << "; EDataType: " << get_type_name() << std::endl; + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; static_for<0, NumDTensor, 1>{}( [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + + std::cout << "a grid desc" << a_grid_desc_ak0_m_ak1_ << std::endl; + std::cout << "b grid desc" << b_grid_desc_bk0_n_bk1_ << std::endl; + std::cout << "e grid desc" << e_grid_desc_mblock_mperblock_nblock_nperblock_ + << std::endl; } // private: @@ -1198,7 +1207,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle isMultiA, isMultiB, CTranspose>; - return launch_and_time_kernel( stream_config, kernel, @@ -1281,7 +1289,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float avg_time = 0.f; - if constexpr(NeedTransposeKernel) { const index_t a_grid_size = @@ -1686,7 +1693,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { return false; } - + if constexpr(is_same_v || + is_same_v) + { + if(!is_tf32_supported()) + { + return false; + } + if constexpr(!is_same_v) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "ComputeDataType for A and B should be same while using TF32" + << std::endl; + } + return false; + } + } // check Gridwise GEMM if(get_warp_size() == 64) { @@ -1766,6 +1789,28 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } } + if constexpr(is_same_v || + is_same_v) + + { + if(!(ck::get_device_name() == "gfx942")) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "TF32 is enabled on gfx942 only" << std::endl; + } + return false; + } + if constexpr(!is_same_v) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "ComputeDataType for A and B should be same while using TF32" + << std::endl; + } + return false; + } + } return false; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp index c198711dbb..cbad6a5673 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -708,7 +708,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle MXdlPerWave, NXdlPerWave, KPack, - LoopSched>(); + LoopSched, + AComputeDataType, + BComputeDataType>(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index 59d7f357ec..a97e4503a8 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -107,8 +107,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle using BComputeDataType = conditional_t, ck::bhalf_t, BComputeDataType_>; #else - using AComputeDataType = AComputeDataType_; - using BComputeDataType = BComputeDataType_; + using AComputeDataType = + conditional_t, float, AComputeDataType_>; + using BComputeDataType = + conditional_t, float, BComputeDataType_>; #endif __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() @@ -659,26 +661,27 @@ struct GridwiseGemmMultipleD_xdl_cshuffle : false; constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector::selected_mfma.k_per_blk); - - auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< - BlockSize, - AComputeDataType, - BComputeDataType, - AccDataType, - decltype(a_block_desc_ak0_m_ak1), - decltype(b_block_desc_bk0_n_bk1), - MPerXdl, - NPerXdl, - MXdlPerWave, - NXdlPerWave, - KPack, - LoopSched>(); + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + AComputeDataType, + BComputeDataType, + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched, + AComputeDataType_, + BComputeDataType_>(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp index 095b1c5d63..1e72e78349 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp @@ -144,7 +144,7 @@ template + typename BComputeDataType_ = AComputeDataType_> struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad { static constexpr index_t NumDTensor = DsDataType::Size(); @@ -172,7 +172,10 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad using AComputeDataType = conditional_t, ck::bhalf_t, AComputeDataType_>; #else - using AComputeDataType = AComputeDataType_; + using AComputeDataType = + conditional_t, float, AComputeDataType_>; + using BComputeDataType = + conditional_t, float, BComputeDataType_>; #endif __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() @@ -573,7 +576,6 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad // This forces m/n_block_data_idx_on_grid into SGPR. const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); - const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); @@ -640,10 +642,10 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector::selected_mfma.k_per_blk); @@ -659,7 +661,9 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad MXdlPerWave, NXdlPerWave, KPack, - LoopSched>(); + LoopSched, + AComputeDataType_, + BComputeDataType_>(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index deea6ae9cc..a97d9589cf 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -41,11 +41,11 @@ static constexpr bool scale_mfma_hw_support() enum struct MfmaInstr { - mfma_f32_32x32x1xf32 = 0, - mfma_f32_16x16x1xf32, - mfma_f32_4x4x1xf32, - mfma_f32_32x32x2xf32, - mfma_f32_16x16x4xf32, + mfma_f32_32x32x1f32 = 0, + mfma_f32_16x16x1f32, + mfma_f32_4x4x1f32, + mfma_f32_32x32x2f32, + mfma_f32_16x16x4f32, mfma_f32_32x32x4f16, mfma_f32_16x16x4f16, mfma_f32_4x4x4f16, @@ -78,6 +78,8 @@ enum struct MfmaInstr mfma_f32_16x16x128f8f6f4, mfma_scale_f32_32x32x64f8f6f4, mfma_scale_f32_16x16x128f8f6f4, + mfma_f32_16x16x8xf32, // tf32 + mfma_f32_32x32x4xf32, // gfx11 wmma_f32_16x16x16_f16, wmma_f32_16x16x16_bf16, @@ -98,7 +100,7 @@ template struct mfma_type; template <> -struct mfma_type +struct mfma_type { static constexpr index_t group_size = 4; static constexpr index_t num_groups_per_blk = 4; @@ -120,7 +122,7 @@ struct mfma_type }; template <> -struct mfma_type +struct mfma_type { static constexpr index_t group_size = 4; static constexpr index_t num_groups_per_blk = 4; @@ -142,7 +144,7 @@ struct mfma_type }; template <> -struct mfma_type +struct mfma_type { static constexpr index_t group_size = 4; static constexpr index_t num_groups_per_blk = 1; @@ -164,7 +166,7 @@ struct mfma_type }; template <> -struct mfma_type +struct mfma_type { static constexpr index_t group_size = 4; static constexpr index_t num_groups_per_blk = 1; @@ -187,7 +189,7 @@ struct mfma_type // treat 4x4x1 as a single-blk 4x64 mfma template <> -struct mfma_type +struct mfma_type { static constexpr index_t group_size = 4; static constexpr index_t num_groups_per_blk = 1; @@ -947,6 +949,70 @@ struct mfma_type } }; +/** + * num_threads_per_blk == n_per_blk + * num_regs_per_blk * num_input_blks == m_per_blk + * num_regs_per_blk * wave_size == m_per_blk * n_per_blk + * + * group_size * num_groups_per_blk == num_regs_per_blk + * + * num_regs_per_blk is output(CD) register size which is determined by the instruction. + * k_per_blk(K1PerXdlops) is input(AB) register size which is determined by the instruction. + * group_size is corresponding to CD rows mapping. see: GetBeginOfThreadBlk() + * + * is_k_reduction = (k_per_blk == KPerXdlops) ? false: true. + * + * if (is_k_reduction){ + * num_output_blks == 1; + * } else { + * num_input_blks == num_output_blks; + * } + */ +template <> +struct mfma_type +{ + static constexpr index_t wave_size = 64; // fixed + static constexpr index_t m_per_blk = 16; // from the instruction + static constexpr index_t n_per_blk = 16; // from the instruction + static constexpr index_t num_threads_per_blk = n_per_blk; // 16 + static constexpr index_t num_regs_per_blk = m_per_blk * n_per_blk / wave_size; // 4 + static constexpr index_t num_input_blks = m_per_blk / num_regs_per_blk; // 4 + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_output_blks = 1; + static constexpr index_t k_per_blk = 2; // k_per_blk(K1PerXdlops) should be 2. + static constexpr bool is_k_reduction = true; + + // AB register size : 2, register size: 4 + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x8xf32::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t wave_size = 64; // fixed + static constexpr index_t m_per_blk = 32; // from the instruction + static constexpr index_t n_per_blk = 32; // from the instruction + static constexpr index_t num_threads_per_blk = n_per_blk; // 32 + static constexpr index_t num_regs_per_blk = m_per_blk * n_per_blk / wave_size; // 16 + static constexpr index_t num_input_blks = m_per_blk / num_regs_per_blk; // 2 + static constexpr index_t group_size = 4; // corresponding to CD rows mapping + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t k_per_blk = 2; + static constexpr bool is_k_reduction = true; + // AB register size: 2, CD register size: 16 + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x4xf32::Run(a, b, reg_c); + } +}; + // gfx11 struct mfma_type_gfx11_base { @@ -1116,6 +1182,20 @@ struct mfma_type : public mfma_type_gfx12 } }; +/** + * @class MfmaSelector + * @brief Selects the appropriate MFMA instruction type and configuration for given data types + * and tile sizes on AMD GPUs. + * + * @tparam base_type The base data type for the matrix operation (e.g., float, half_t). + * @tparam MPerXdlops The number of rows per XDLops tile. + * @tparam NPerXdlops The number of columns per XDLops tile. + * @tparam additional_type (Optional) Additional data type for mixed-precision or special cases. + * Defaults to base_type. + * @tparam is_single_rate_mfma (Optional) Whether to use single-rate MFMA instructions. + * Defaults to false. + * @tparam is_scale_mfma (Optional) Whether to use scale MFMA instructions. Defaults to false. + */ template constexpr auto GetMfma() { - return MfmaInstr::mfma_f32_32x32x1xf32; + return MfmaInstr::mfma_f32_32x32x1f32; } template <> constexpr auto GetMfma() { - return MfmaInstr::mfma_f32_32x32x1xf32; + return MfmaInstr::mfma_f32_32x32x1f32; } template <> constexpr auto GetMfma() { - return MfmaInstr::mfma_f32_16x16x1xf32; + return MfmaInstr::mfma_f32_16x16x1f32; } template <> constexpr auto GetMfma() { - return MfmaInstr::mfma_f32_4x4x1xf32; + return MfmaInstr::mfma_f32_4x4x1f32; } template <> constexpr auto GetMfma() { - return MfmaInstr::mfma_f32_4x4x1xf32; + return MfmaInstr::mfma_f32_4x4x1f32; } template <> constexpr auto GetMfma() { - return MfmaInstr::mfma_f32_32x32x2xf32; + return MfmaInstr::mfma_f32_32x32x2f32; } template <> @@ -1188,10 +1268,22 @@ struct MfmaSelector #elif defined(__gfx11__) return MfmaInstr::wmma_unsupport_16x16_gfx11; #else - return MfmaInstr::mfma_f32_16x16x4xf32; + return MfmaInstr::mfma_f32_16x16x4f32; #endif } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x4xf32; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x8xf32; + } + template <> constexpr auto GetMfma() { @@ -1896,7 +1988,7 @@ struct XdlopsGemm __device__ __host__ static constexpr index_t GetRegSizePerXdlops() { - return MPerXdlops * NPerXdlops / mfma_instr.wave_size; + return mfma_instr.num_regs_per_blk; } __device__ static constexpr index_t GetWaveSize() { return mfma_instr.wave_size; } @@ -1906,12 +1998,12 @@ struct XdlopsGemm { static_assert( is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || (is_same::value && is_same::value) || (is_same::value && is_same::value), - "base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!"); + "base_type must be double, float, tf32_t, half, bfloat16, int8_t, f8_t or bf8_t!"); static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { if constexpr(!TransposeC) diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 02a7a72b8c..be3a5cea42 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -1636,4 +1636,45 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16> } }; +/******************* tf32 *************************************/ +template +struct intrin_mfma_f32_16x16x8xf32; + +template <> +struct intrin_mfma_f32_16x16x8xf32<16, 16> +{ + template + __device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8_xf32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +template +struct intrin_mfma_f32_32x32x4xf32; + +template <> +struct intrin_mfma_f32_32x32x4xf32<32, 32> +{ + template + __device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4_xf32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + } // namespace ck diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 5fbe30d21b..48b352986e 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -26,6 +26,7 @@ using byte = unsigned char; using std::byte; #endif +using tf32_t = _BitInt(19); // 1 sign bit, 8 exponent bits, 10 mantissa bits using bhalf_t = ushort; using half_t = _Float16; using int4_t = _BitInt(4); @@ -461,4 +462,38 @@ using int64_t = long long; using int64_t = long; #endif +template +inline const char* get_type_name() +{ + if constexpr(is_same_v) + return "fp16"; + else if constexpr(is_same_v) + return "bf16"; + else if constexpr(is_same_v) + return "tf32"; + else if constexpr(is_same_v) + return "int4"; + else if constexpr(is_same_v) + return "f4"; + else if constexpr(is_same_v) + return "f6"; + else if constexpr(is_same_v) + return "bf6"; + else if constexpr(is_same_v) + return "f8"; + else if constexpr(is_same_v) + return "bf8"; + else if constexpr(is_same_v) + return "e8m0"; + else if constexpr(is_same_v) + return "fp32"; +#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC) + else + return "unknown"; +#else + else + return typeid(T).name(); +#endif +} + } // namespace ck diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 8e53728ef6..290a6c8dd6 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -187,6 +187,19 @@ inline __host__ __device__ constexpr bf8_ocp_t type_convert(int return bf8_ocp_t{type_convert(x)}; } +template , bool> = false> +inline __host__ __device__ constexpr float type_convert(float x) +{ + union + { + float fp32; + uint32_t int32; + } u = {x}; + + u.int32 = u.int32 & 0xffffe000; + return u.fp32; +} + // Convert X to Y template __host__ __device__ constexpr Y type_convert_sp(X x) diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp index 3884902bbf..573571bc07 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp @@ -59,6 +59,7 @@ template = 1 && NDimSpatial <= 3, bool>::type = false> struct ReferenceConvFwd : public device::BaseOperator { @@ -163,8 +164,18 @@ struct ReferenceConvFwd : public device::BaseOperator k, c, x); - v_acc += - ck::type_convert(v_in) * ck::type_convert(v_wei); + if constexpr(is_same_v) + { + v_acc += ck::type_convert( + ck::type_convert(v_in)) * + ck::type_convert( + ck::type_convert(v_wei)); + } + else + { + v_acc += ck::type_convert(v_in) * + ck::type_convert(v_wei); + } } } } @@ -238,8 +249,18 @@ struct ReferenceConvFwd : public device::BaseOperator c, y, x); - v_acc += ck::type_convert(v_in) * - ck::type_convert(v_wei); + if constexpr(is_same_v) + { + v_acc += ck::type_convert( + ck::type_convert(v_in)) * + ck::type_convert( + ck::type_convert(v_wei)); + } + else + { + v_acc += ck::type_convert(v_in) * + ck::type_convert(v_wei); + } } } } @@ -327,8 +348,18 @@ struct ReferenceConvFwd : public device::BaseOperator z, y, x); - v_acc += ck::type_convert(v_in) * - ck::type_convert(v_wei); + if constexpr(is_same_v) + { + v_acc += ck::type_convert( + ck::type_convert(v_in)) * + ck::type_convert( + ck::type_convert(v_wei)); + } + else + { + v_acc += ck::type_convert(v_in) * + ck::type_convert(v_wei); + } } } } diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index ed07e53e6d..8b9b973b2d 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -25,6 +25,12 @@ template struct ReferenceGemm : public device::BaseOperator { + + using ElementDataTypeA = + ck::conditional_t, float, ComputeTypeA>; + using ElementDataTypeB = + ck::conditional_t, float, ComputeTypeB>; + // Argument struct Argument : public device::BaseArgument { @@ -63,8 +69,8 @@ struct ReferenceGemm : public device::BaseOperator const int K = arg.a_m_k_.mDesc.GetLengths()[1]; AccDataType v_acc{0}; - ComputeTypeA v_a{0}; - ComputeTypeB v_b{0}; + ElementDataTypeA v_a{0}; + ElementDataTypeB v_b{0}; for(int k = 0; k < K; ++k) { @@ -77,16 +83,16 @@ struct ReferenceGemm : public device::BaseOperator else i4 = (i4x2 >> 4) & 0xf; i4 = i4 - 8; - v_a = type_convert(i4); + v_a = type_convert(i4); } else if constexpr(is_same_v) { // TODO: add support for ColMajor layout as well if(k % 2 == 1) - v_a = type_convert( + v_a = type_convert( f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{}))); else - v_a = type_convert( + v_a = type_convert( f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{}))); } else if constexpr(is_same_v || @@ -94,7 +100,7 @@ struct ReferenceGemm : public device::BaseOperator is_same_v || is_same_v) { - v_a = type_convert( + v_a = type_convert( arg.a_m_k_(m, k).unpack(k % ADataType::packed_size)); } else @@ -111,16 +117,16 @@ struct ReferenceGemm : public device::BaseOperator else i4 = (i4x2 >> 4) & 0xf; i4 = i4 - 8; - v_b = type_convert(i4); + v_b = type_convert(i4); } else if constexpr(is_same_v) { // TODO: add support for RowMajor layout as well if(k % 2 == 1) - v_b = type_convert( + v_b = type_convert( f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{}))); else - v_b = type_convert( + v_b = type_convert( f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{}))); } else if constexpr(is_same_v || @@ -128,7 +134,7 @@ struct ReferenceGemm : public device::BaseOperator is_same_v || is_same_v) { - v_b = type_convert( + v_b = type_convert( arg.b_k_n_(k, n).unpack(k % BDataType::packed_size)); } else @@ -136,8 +142,18 @@ struct ReferenceGemm : public device::BaseOperator arg.b_element_op_(v_b, arg.b_k_n_(k, n)); } - v_acc += - ck::type_convert(v_a) * ck::type_convert(v_b); + if constexpr(is_same_v && + is_same_v) + { // only for tf32 now + v_acc += + ck::type_convert(ck::type_convert(v_a)) * + ck::type_convert(ck::type_convert(v_b)); + } + else + { + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); + } } CDataType v_c{0}; diff --git a/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp index 28274a5154..cf30bc7dda 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp @@ -38,6 +38,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const CDEElementwiseOperation c_element_op) { using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ElementDataTypeA = + ck::conditional_t, float, ComputeTypeA>; + using ElementDataTypeB = + ck::conditional_t, float, ComputeTypeB>; const int row_idx = blockIdx.x * blockDim.x + threadIdx.x; const int col_idx = blockIdx.y * blockDim.y + threadIdx.y; @@ -46,8 +50,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) { AccDataType v_acc{0}; - ComputeTypeA v_a{0}; - ComputeTypeB v_b{0}; + ElementDataTypeA v_a{0}; + ElementDataTypeB v_b{0}; CDataType v_c{0}; for(int k_idx = 0; k_idx < k; ++k_idx) @@ -76,7 +80,16 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) // apply b_element_op b_element_op(v_b, p_b_grid[element_idx_b]); // multiply and accumulate - v_acc += type_convert(v_a) * type_convert(v_b); + if constexpr(is_same_v && + is_same_v) + { // only for tf32 now + v_acc += ck::type_convert(ck::type_convert(v_a)) * + ck::type_convert(ck::type_convert(v_b)); + } + else + { + v_acc += type_convert(v_a) * type_convert(v_b); + } } // apply c_element_op c_element_op(v_c, v_acc); diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp index 7164f345cd..9aeca39718 100644 --- a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -16,6 +16,7 @@ namespace instance { // aliasing, for commonly used data type using F64 = double; using F32 = float; +using TF32 = ck::tf32_t; using F16 = ck::half_t; using BF16 = ck::bhalf_t; using I8 = int8_t; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp index 82c01a634b..568f0e0dc4 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp @@ -16,6 +16,7 @@ namespace instance { using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +using TF32 = ck::tf32_t; template using S = ck::Sequence; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp index 768fcbada0..52c389d020 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp @@ -24,6 +24,7 @@ using BF8 = ck::bf8_t; using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +using TF32 = ck::tf32_t; template using S = ck::Sequence; @@ -199,7 +200,7 @@ using device_grouped_conv_fwd_xdl_f16_nchw_instances = std::tuple< DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 8, 1, 8>, 1>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 8, 1, 8>, 1>, - // 32x32 instance + // 32x32 instance DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, @@ -284,7 +285,45 @@ using device_grouped_conv_fwd_xdl_f32_instances = std::tuple< DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +template , + typename OutElementOp = PassThrough> +using device_grouped_conv_fwd_xdl_f32_tf32_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| AComputeType| BComputeType| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| DATATYPE | DATATYPE | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32> // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index 545826650c..5a26abecc2 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -443,6 +443,12 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(op_ptrs); + } #endif #ifdef CK_ENABLE_FP8 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp index 43411b0031..11e827878c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp @@ -215,6 +215,14 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + } #endif } #endif // CK_USE_XDL diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_xdl.inc index aaaacb0d18..045d1623cf 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_xdl.inc @@ -578,6 +578,22 @@ void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_insta PassThrough, AddClamp>>>& instances); +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( std::vector && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + } #endif } #endif // CK_USE_XDL diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc index d5a8a5344a..b0061b966d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc @@ -578,6 +578,22 @@ void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( PassThrough, Clamp>>>& instances); +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( std::vector>>& instances); + #endif #ifdef CK_ENABLE_INT8 @@ -159,7 +160,8 @@ template + typename AComputeType, + typename BComputeType = AComputeType> struct DeviceOperationInstanceFactory> + AComputeType, + BComputeType>> { using DeviceOp = DeviceGroupedConvFwdMultipleABD; + AComputeType, + BComputeType>; static auto GetInstances() { @@ -207,7 +211,7 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v) { add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f16_instances( op_ptrs); @@ -244,7 +248,7 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_f16_instances( op_ptrs); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc index a3f2515099..af6041bbc5 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc @@ -559,6 +559,22 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt index bda9149227..6a776b4943 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt @@ -2,7 +2,7 @@ set(GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP) include(ShardInstantiation) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances @@ -11,7 +11,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances @@ -20,7 +20,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances @@ -29,7 +29,16 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) - + +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances + TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in + NUM_SHARDS 16 + SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + OUTPUT_DIR ${GENERATED_DIR}/xdl +) + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances @@ -38,7 +47,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances @@ -47,7 +56,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances @@ -58,7 +67,7 @@ generate_sharded_instantiations( ) # large tensor # NDHWGC, GKZYXC, NDHWGK - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances @@ -67,7 +76,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances @@ -76,7 +85,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances @@ -87,7 +96,7 @@ generate_sharded_instantiations( ) # merged groups # NDHWGC, GKZYXC, NDHWGK - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances @@ -96,7 +105,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances @@ -105,7 +114,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances @@ -116,7 +125,7 @@ generate_sharded_instantiations( ) #mem # NDHWGC, GKZYXC, NDHWGK - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances @@ -125,7 +134,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances @@ -134,7 +143,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances @@ -144,7 +153,7 @@ generate_sharded_instantiations( OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) # NDHWGC, GKZYXC, NDHWGK - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances @@ -153,7 +162,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances @@ -162,7 +171,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances @@ -173,7 +182,7 @@ generate_sharded_instantiations( ) #comp # NDHWGC, GKZYXC, NDHWGK - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances @@ -182,7 +191,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances @@ -191,7 +200,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances @@ -200,7 +209,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instances @@ -209,7 +218,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_2x_instances @@ -218,7 +227,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_part2_instances @@ -227,7 +236,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_part2_instances diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in new file mode 100644 index 0000000000..d7f3c87b83 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances = + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp, + TF32, + TF32>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances_shard( + device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances& instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NDHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NDHWGK, + ConvFwd1x1P0, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NDHWGK, + ConvFwd1x1S1P0, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt index 3bd6916cf0..bcc7020ca9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt @@ -23,6 +23,8 @@ set(GROUPED_CONV3D_FWD xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_intra_instance.cpp xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_comp_instance.cpp -) + + xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp + ) add_instance_library(device_grouped_conv3d_fwd_bias_clamp_instance ${GROUPED_CONV3D_FWD}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp new file mode 100644 index 0000000000..328838bff2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + AddClamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0, + Tuple, + AddClamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt index 234533244e..059d22f8d2 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt @@ -23,6 +23,8 @@ set(GROUPED_CONV3D_FWD xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_intra_instance.cpp xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_comp_instance.cpp -) + + xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp + ) add_instance_library(device_grouped_conv3d_fwd_clamp_instance ${GROUPED_CONV3D_FWD}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp new file mode 100644 index 0000000000..a1bf6562c2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0, + Tuple<>, + Clamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/utility/host_tensor.cpp b/library/src/utility/host_tensor.cpp index 7211552641..02bd562e43 100644 --- a/library/src/utility/host_tensor.cpp +++ b/library/src/utility/host_tensor.cpp @@ -53,7 +53,7 @@ std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc) os << "strides {"; LogRange(os, desc.GetStrides(), ", "); - os << "}"; + os << "} "; return os; } diff --git a/profiler/src/profile_grouped_conv_fwd.cpp b/profiler/src/profile_grouped_conv_fwd.cpp index a7714b4c73..a8d343405d 100644 --- a/profiler/src/profile_grouped_conv_fwd.cpp +++ b/profiler/src/profile_grouped_conv_fwd.cpp @@ -21,14 +21,15 @@ enum struct ConvLayout enum struct ConvDataType { - F32_F32_F32, // 0 - F16_F16_F16, // 1 - BF16_BF16_BF16, // 2 - INT8_INT8_INT8, // 3 - F8_F8_F8, // 4 - BF8_BF8_F8, // 5 - F8_BF8_F8, // 6 - BF8_F8_F8, // 7 + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 + F8_F8_F8, // 4 + BF8_BF8_F8, // 5 + F8_BF8_F8, // 6 + BF8_F8_F8, // 7 + F32_F32_F32_TF32, // 8 }; enum struct IndexType @@ -52,7 +53,8 @@ static void print_helper_msg() << " 4: Input fp8, Weight fp8, Output fp8\n" << " 5: Input bf8, Weight bf8, Output fp8\n" << " 6: Input fp8, Weight bf8, Output fp8\n" - << " 7: Input bf8, Weight fp8, Output fp8)\n" + << " 7: Input bf8, Weight fp8, Output fp8\n" + << " 8: Input fp32, Weight fp32, Output fp32, Compute tf32)\n" << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]\n" << " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, " @@ -103,6 +105,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) using INT8 = int8_t; using F8 = ck::f8_t; using BF8 = ck::bf8_t; +#if defined(__gfx942__) + using TF32 = ck::tf32_t; +#endif // using GNWC = ck::tensor_layout::convolution::GNWC; @@ -261,6 +266,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) return profile( I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); +#endif + } } // NHWGC_GKYXC_NHWGK else if(num_dim_spatial == 1 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) @@ -367,6 +378,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) { return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, F8{}, F8{}, BF8{}, F8{}); } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); +#endif + } } // NGCDHW_GKCZYX_NGKDHW else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKCYX_NGKHW) @@ -384,6 +401,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) return profile( I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); +#endif + } } std::cout << "this data_type & layout is not implemented" << std::endl; From 427dca076b228f7db32c7a0046652d5b37e88aa2 Mon Sep 17 00:00:00 2001 From: aledudek Date: Thu, 18 Sep 2025 01:43:41 +0200 Subject: [PATCH 17/28] [CK_TILE] Fix batched_gemm tests for gfx950 (#2869) --- test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc b/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc index b2f965764d..035377734b 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc +++ b/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc @@ -23,13 +23,16 @@ TYPED_TEST(TestCkTileBatchedGemm, Basic) std::vector gemmParams{{256, 256, 256, 1}, {256, 256, 256, 2}, {256, 256, 512, 2}, - {256, 256, 128, 2}, {256, 256, 64, 2}, {256, 256, 64, 3}, {256, 256, 64, 4}, {256, 256, 64, 8}, {256, 256, 64, 16}}; + if(ck_tile::get_device_name() != "gfx950") { + gemmParams.emplace_back(256, 256, 128, 2); + } + for(auto& params : gemmParams) { std::vector strideConfigs{{params.K, From 7ee7915e94ab0a8b13b734921978cdf226a4fd72 Mon Sep 17 00:00:00 2001 From: Yi DING Date: Thu, 18 Sep 2025 16:51:21 +0800 Subject: [PATCH 18/28] [CK_TILE] FMHA Test Ignore Known Errors (#2872) --- .../script/fmha_bwd_known_fails_gfx90a.txt | 2 + .../script/fmha_bwd_known_fails_gfx942.txt | 2 + .../script/fmha_bwd_known_fails_gfx950.txt | 31 +++++++++ .../script/fmha_fwd_known_fails_gfx90a.txt | 0 .../script/fmha_fwd_known_fails_gfx942.txt | 0 .../script/fmha_fwd_known_fails_gfx950.txt | 4 ++ .../ck_tile/01_fmha/script/smoke_test_bwd.sh | 57 +++++++++++++-- .../ck_tile/01_fmha/script/smoke_test_fwd.sh | 69 +++++++++++++++---- 8 files changed, 144 insertions(+), 21 deletions(-) create mode 100644 example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx90a.txt create mode 100644 example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx942.txt create mode 100644 example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx950.txt create mode 100644 example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx90a.txt create mode 100644 example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx942.txt create mode 100644 example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx950.txt diff --git a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx90a.txt b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx90a.txt new file mode 100644 index 0000000000..ea601ec002 --- /dev/null +++ b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx90a.txt @@ -0,0 +1,2 @@ +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 diff --git a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx942.txt b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx942.txt new file mode 100644 index 0000000000..ea601ec002 --- /dev/null +++ b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx942.txt @@ -0,0 +1,2 @@ +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 diff --git a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx950.txt b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx950.txt new file mode 100644 index 0000000000..1497d491bb --- /dev/null +++ b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx950.txt @@ -0,0 +1,31 @@ +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=32 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=32 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=64 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=64 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=32 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=32 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=64 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=64 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 diff --git a/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx90a.txt b/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx90a.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx942.txt b/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx942.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx950.txt b/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx950.txt new file mode 100644 index 0000000000..90c5e2b7fb --- /dev/null +++ b/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx950.txt @@ -0,0 +1,4 @@ +tile_example_fmha_fwd -prec=fp16 -mode=0 -b=2 -h=1 -d=128 -d_v=24 -s=3 -s_k=99 -bias=n -p_drop=0.0 -lse=0 -iperm=0 -operm=0 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 +tile_example_fmha_fwd -prec=fp16 -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1 -s_k=10 -s_kpad=32 -bias=n -p_drop=0.0 -lse=0 -iperm=0 -operm=0 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 +tile_example_fmha_fwd -prec=fp16 -mode=0 -b=2 -h=1 -d=128 -d_v=24 -s=3 -s_k=99 -bias=n -p_drop=0.0 -lse=0 -iperm=1 -operm=1 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 +tile_example_fmha_fwd -prec=fp16 -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1 -s_k=10 -s_kpad=32 -bias=n -p_drop=0.0 -lse=0 -iperm=1 -operm=1 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 diff --git a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh index d123f842a2..3b59505ff0 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh @@ -2,13 +2,35 @@ # TODO: run this script from CK root or build directory set -euo pipefail -EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)" +SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd) +EXE_NAME=tile_example_fmha_bwd +EXE="$(find . -name $EXE_NAME -type f | head -n 1)" KNAME=1 +GPU_arch=$GPU_arch +if [ -z "$GPU_arch" ] ; then + GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}') +fi export CK_WARMUP=0 export CK_REPEAT=1 +CURR_FAILS_FILE=${CURR_FAILS_FILE:-"fmha_bwd_fails_$GPU_arch.txt"} +rm -f $CURR_FAILS_FILE +touch $CURR_FAILS_FILE +KNOWN_FAILS_FILE=${KNOWN_FAILS_FILE:-"$SCRIPT_DIR/fmha_bwd_known_fails_$GPU_arch.txt"} + COMMON_ARGS='-v=1' + +run_exe() { + set +ex + $EXE $@ + local ret=$? + if [ $ret -ne 0 ] ; then + echo "$EXE_NAME $*" >> $CURR_FAILS_FILE + fi + set -ex +} + set -x for prec in "fp16" "bf16" ; do for perm in 0 1 ; do @@ -19,12 +41,12 @@ for dbias in 0 ; do for p_drop in 0.0 0.2 ; do for deterministic in 0 ; do -$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +run_exe -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +run_exe -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +run_exe -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +run_exe -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +run_exe -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +run_exe -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS done done @@ -35,3 +57,24 @@ done done done set +x + +new_fails_count=0 +known_fails_count=0 +if [ -f $KNOWN_FAILS_FILE ] ; then + echo "Comparing current fails ($CURR_FAILS_FILE) against known fails ($KNOWN_FAILS_FILE):" + while IFS= read -r line; do + if grep -Fxq "$line" $KNOWN_FAILS_FILE; then + echo "Known fail: $line" + known_fails_count=$(($known_fails_count + 1)) + else + echo "New fail: $line" + new_fails_count=$(($new_fails_count + 1)) + fi + done < $CURR_FAILS_FILE +else + new_fails_count=$(wc -l < $CURR_FAILS_FILE) + echo "No known fails file, all fails ($new_fails_count) are new:" + cat $CURR_FAILS_FILE +fi +echo "New fails count: $new_fails_count; Known fails count: $known_fails_count" +exit $(($new_fails_count != 0)) diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index dda3943454..c087a1fb3e 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -2,12 +2,23 @@ # TODO: run this script from CK root or build directory set -euo pipefail -EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)" +SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd) +EXE_NAME=tile_example_fmha_fwd +EXE="$(find . -name $EXE_NAME -type f | head -n 1)" KNAME=1 +GPU_arch=$GPU_arch +if [ -z "$GPU_arch" ] ; then + GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}') +fi export CK_WARMUP=0 export CK_REPEAT=1 +CURR_FAILS_FILE=${CURR_FAILS_FILE:-"fmha_fwd_fails_$GPU_arch.txt"} +rm -f $CURR_FAILS_FILE +touch $CURR_FAILS_FILE +KNOWN_FAILS_FILE=${KNOWN_FAILS_FILE:-"$SCRIPT_DIR/fmha_fwd_known_fails_$GPU_arch.txt"} + COMMON_ARGS='-v=1 -warmup=0 -repeat=1' # mode=0 # export HIP_VISIBLE_DEVICES=4 @@ -30,6 +41,16 @@ while getopts ":sa" opt; do esac done +run_exe() { + set +ex + $EXE $@ + local ret=$? + if [ $ret -ne 0 ] ; then + echo "$EXE_NAME $*" >> $CURR_FAILS_FILE + fi + set -ex +} + run_fp16_bf16_tests() { local NUM_SPLITS="1" local PAGE_BLOCK_SIZE="0" @@ -52,16 +73,16 @@ run_fp16_bf16_tests() { for page_block_size in $PAGE_BLOCK_SIZE ; do for cache_batch_idx in $CACHE_BATCH_IDX ; do - # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + # run_exe -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS done ; done ; done ; done ; done done ; done ; done ; done ; done @@ -73,8 +94,7 @@ run_fp8_tests() { for b in 1 2 ; do for hdim in 64 128 256 ; do - $EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS - + run_exe -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS done ; done ; done ; done } @@ -88,7 +108,7 @@ run_fp16_appendkv_tests() { for page_block_size in 0 128 ; do for cache_batch_idx in 0 1 ; do - $EXE -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS + run_exe -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS done ; done ; done ; done ; done done ; done ; done @@ -104,3 +124,24 @@ if [ $TEST_APPENDKV -eq 1 ] ; then fi set +x + +new_fails_count=0 +known_fails_count=0 +if [ -f $KNOWN_FAILS_FILE ] ; then + echo "Comparing current fails ($CURR_FAILS_FILE) against known fails ($KNOWN_FAILS_FILE):" + while IFS= read -r line; do + if grep -Fxq "$line" $KNOWN_FAILS_FILE; then + echo "Known fail: $line" + known_fails_count=$(($known_fails_count + 1)) + else + echo "New fail: $line" + new_fails_count=$(($new_fails_count + 1)) + fi + done < $CURR_FAILS_FILE +else + new_fails_count=$(wc -l < $CURR_FAILS_FILE) + echo "No known fails file, all fails ($new_fails_count) are new:" + cat $CURR_FAILS_FILE +fi +echo "New fails count: $new_fails_count; Known fails count: $known_fails_count" +exit $(($new_fails_count != 0)) From 14bbc545ea672e66cdce00a3edbf4c532e2657e8 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Thu, 18 Sep 2025 09:12:37 -0500 Subject: [PATCH 19/28] Fix UB caused by reinterpret_cast (#2849) * Use bit_cast instead of reinterpret_cast to avoid UB * Apply same fix in ck_tile --- include/ck/utility/random_gen.hpp | 5 +++-- include/ck_tile/core/utility/random.hpp | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/include/ck/utility/random_gen.hpp b/include/ck/utility/random_gen.hpp index c37d3922ca..2ff46457fc 100644 --- a/include/ck/utility/random_gen.hpp +++ b/include/ck/utility/random_gen.hpp @@ -3,6 +3,7 @@ #pragma once #include +#include #include "ck/ck.hpp" #ifdef CK_CODE_GEN_RTC @@ -17,7 +18,7 @@ namespace ck { template {}, bool> = false> __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t) { - uint32_t x = *(reinterpret_cast(&val)); + uint32_t x = bit_cast(val); uint32_t drop_bits = uint32_t(x) & 0xFFFFu; drop_bits ^= x >> 16; drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5); @@ -33,7 +34,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = template {}, bool> = false> __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t) { - uint16_t x = *(reinterpret_cast(&val)); + uint16_t x = bit_cast(val); uint32_t drop_bits = uint32_t(x) & 0xFFFFu; drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5); drop_bits *= 0x7000149; diff --git a/include/ck_tile/core/utility/random.hpp b/include/ck_tile/core/utility/random.hpp index f7fbfad4dd..6a38ad3bde 100644 --- a/include/ck_tile/core/utility/random.hpp +++ b/include/ck_tile/core/utility/random.hpp @@ -24,7 +24,7 @@ struct prand_generator_t { CK_TILE_HOST_DEVICE uint32_t operator()(int id, float val, uint32_t seed = seed_) { - uint32_t x = *(reinterpret_cast(&val)); + uint32_t x = bit_cast(val); uint32_t drop_bits = uint32_t(x) & 0xFFFFu; drop_bits ^= x >> 16; drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5); @@ -43,7 +43,7 @@ struct prand_generator_t { CK_TILE_HOST_DEVICE uint32_t operator()(int id, half_t val, uint32_t seed = seed_) { - uint16_t x = *(reinterpret_cast(&val)); + uint16_t x = bit_cast(val); uint32_t drop_bits = uint32_t(x) & 0xFFFFu; drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5); drop_bits *= 0x7000149; From 30ab1d6a7108b6f9b4463f8e8183e223428222c0 Mon Sep 17 00:00:00 2001 From: Mateusz Ozga <110818320+mozga-amd@users.noreply.github.com> Date: Fri, 19 Sep 2025 01:14:11 +0200 Subject: [PATCH 20/28] [CK_TILE] Multiple-ABD GEMM example (#2788) * Multi ABD - initial commit * Clang-foramt fix * block gemm, unify the name of CDataType * Apply chnages to mem-pipeline * Rollback prefix for DType and Layout * Gemm Kernel Basic, rename * WMMA config * Grouped GEMM * Clang-format * Dropout, name * Review v2 * Move element_wise fn to unnary, remov old ones fn * clang-format * Fix issue review * WP operator adjust to universal gemm * v2 prepare * Remove unused comment * Remove vectorsize * Rollback * Adjust pipeline for abd * Shuffle argument * CI-fail fix quant * Fix ag_br pipeline * Failing tests * Typo * Single argument support --- CHANGELOG.md | 1 + .../ck_tile/22_gemm_multi_abd/CMakeLists.txt | 1 + example/ck_tile/22_gemm_multi_abd/README.md | 35 ++ .../22_gemm_multi_abd/gemm_multi_abd_fp16.cpp | 184 +++++++ .../22_gemm_multi_abd/gemm_multi_abd_fp16.hpp | 186 +++++++ .../run_gemm_multi_abd_fp16_example.inc | 311 +++++++++++ example/ck_tile/22_gemm_multi_abd/utils.hpp | 38 ++ example/ck_tile/CMakeLists.txt | 1 + include/ck_tile/core/tensor/load_tile.hpp | 23 + include/ck_tile/core/tensor/tile_window.hpp | 143 +++++ .../ck_tile/host/reference/reference_gemm.hpp | 75 +++ .../unary_element_wise_operation.hpp | 17 + .../ops/epilogue/cshuffle_epilogue.hpp | 27 +- .../ops/epilogue/default_2d_epilogue.hpp | 32 +- include/ck_tile/ops/gemm.hpp | 1 + .../ops/gemm/kernel/batched_gemm_kernel.hpp | 4 +- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 8 +- .../ops/gemm/kernel/gemm_multi_abd_kernel.hpp | 193 +++++++ .../ops/gemm/kernel/gemm_multi_d_kernel.hpp | 8 +- .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 12 +- .../ops/gemm/kernel/universal_gemm_kernel.hpp | 42 +- .../pipeline/gemm_pipeline_ag_bg_cr_base.hpp | 129 ++++- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 253 ++++++--- .../gemm_pipeline_ag_bg_cr_comp_v4.hpp | 311 +++++++---- .../gemm_pipeline_ag_bg_cr_comp_v5.hpp | 146 +++-- .../pipeline/gemm_pipeline_ag_bg_cr_mem.hpp | 320 +++++++---- .../gemm_pipeline_agmem_bgmem_creg_v1.hpp | 152 ++++-- .../gemm_pipeline_agmem_bgmem_creg_v2.hpp | 164 ++++-- .../gemm/pipeline/gemm_pipeline_problem.hpp | 137 +++-- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 28 +- .../ops/gemm/pipeline/tile_gemm_traits.hpp | 28 +- .../wp_pipeline_agmem_bgmem_creg_v1.hpp | 60 ++- .../wp_pipeline_agmem_bgmem_creg_v2.hpp | 65 ++- .../pipeline/tile_gemm_quant_traits.hpp | 4 + test/ck_tile/CMakeLists.txt | 1 + test/ck_tile/gemm_multi_abd/CMakeLists.txt | 12 + .../test_gemm_multi_abd_cshuffle.cpp | 40 ++ .../test_gemm_multi_abd_default2d.cpp | 41 ++ .../test_gemm_multi_abd_ut_cases_cshuffle.inc | 211 ++++++++ ...test_gemm_multi_abd_ut_cases_default2d.inc | 211 ++++++++ .../test_gemm_multi_abd_util.hpp | 500 ++++++++++++++++++ 41 files changed, 3603 insertions(+), 552 deletions(-) create mode 100644 example/ck_tile/22_gemm_multi_abd/CMakeLists.txt create mode 100644 example/ck_tile/22_gemm_multi_abd/README.md create mode 100644 example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp create mode 100644 example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.hpp create mode 100644 example/ck_tile/22_gemm_multi_abd/run_gemm_multi_abd_fp16_example.inc create mode 100644 example/ck_tile/22_gemm_multi_abd/utils.hpp create mode 100644 include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp create mode 100644 test/ck_tile/gemm_multi_abd/CMakeLists.txt create mode 100644 test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp create mode 100644 test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp create mode 100644 test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc create mode 100644 test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_default2d.inc create mode 100644 test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 38669385f3..dafe1b5c87 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW). * Added support for Stream-K version of mixed fp8/bf16 GEMM * Added support for Multiple D GEMM +* Added support for Multiple ABD GEMM * Added GEMM pipeline for microscaling (MX) FP8/FP6/FP4 data types * Added support for FP16 2:4 structured sparsity to universal GEMM. * Added support for Split K for grouped convolution backward data. diff --git a/example/ck_tile/22_gemm_multi_abd/CMakeLists.txt b/example/ck_tile/22_gemm_multi_abd/CMakeLists.txt new file mode 100644 index 0000000000..f382e0cf45 --- /dev/null +++ b/example/ck_tile/22_gemm_multi_abd/CMakeLists.txt @@ -0,0 +1 @@ +add_executable(tile_example_gemm_multi_abd_fp16 EXCLUDE_FROM_ALL gemm_multi_abd_fp16.cpp) diff --git a/example/ck_tile/22_gemm_multi_abd/README.md b/example/ck_tile/22_gemm_multi_abd/README.md new file mode 100644 index 0000000000..c272df3fb5 --- /dev/null +++ b/example/ck_tile/22_gemm_multi_abd/README.md @@ -0,0 +1,35 @@ +#Multiple ABD GEMM + +This folder contains example for Multiple ABD GEMM using ck_tile tile-programming implementation. + +## build +``` +#in the root of ck_tile +mkdir build && cd build +#you can replace < arch> with the appropriate architecture(for example gfx90a or gfx942) or \ + leave it blank +sh ../script/cmake-ck-dev.sh ../ +#The basic pipeline method on the gemm calculation +make tile_example_gemm_multi_abd_fp16 -j +``` +This will result in an executable `build/bin/tile_example_gemm_multi_abd_fp16` + +## example +``` +args: + -m M dimensions - (Default: 3840) + -n N dimensions - (Default: 4096) + -k K dimensions - (Default: 4096) +-as_layout Tensor A layout (default:R) +-bs_layout Tensor B layout (default:C) +-ds_layout Tensor D layout (default:R) +-e_layout Tensor E layout (default:R) +-stride_as Tensor A strides - (Default: 0) +-stride_bs Tensor B strides - (Default: 0) +-stride_e Tensor C strides - (Default: 0) +-stride_ds Tensor D strides - (Default: 0) +-validate 0. No validation, 1. Validation on GPU. (Default: 1) + -warmup Number of iterations before benchmark the kernel. (Default: 10) + -repeat Number of iterations to benchmark the kernel. (Default: 100) + -kbatch kbatch for SplitK. (Default: 1) +``` \ No newline at end of file diff --git a/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp new file mode 100644 index 0000000000..6d955c3a09 --- /dev/null +++ b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp @@ -0,0 +1,184 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/host.hpp" +#include "gemm_multi_abd_fp16.hpp" +#include "utils.hpp" + +template +auto gemm_multi_abd(const gemm_multi_abd_kargs& args, const ck_tile::stream_config& s) -> float +{ + constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile; + constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile; + constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile; + + constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp; + constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp; + constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp; + + constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile; + constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile; + constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile; + + constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer; + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr bool TransposeC = false; + + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template UniversalGemmPipeline; + + const ck_tile::index_t k_grain = args.k_batch * K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = + [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using Kernel = ck_tile::GemmKernelMultiABD; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " + << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " + << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + + return ave_time; +} + +#include "run_gemm_multi_abd_fp16_example.inc" + +int main(int argc, char* argv[]) +{ +#if CK_TILE_USE_WMMA + return !run_multiple_abd_gemm_example(argc, argv); +#else + return !run_multiple_abd_gemm_example(argc, argv); +#endif +} diff --git a/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.hpp b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.hpp new file mode 100644 index 0000000000..35bc232eca --- /dev/null +++ b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.hpp @@ -0,0 +1,186 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" + +#define CK_TILE_PIPELINE_COMPUTE_V3 1 +#define CK_TILE_PIPELINE_MEMORY 2 +#define CK_TILE_PIPELINE_COMPUTE_V4 3 + +#ifndef CK_TILE_PIPELINE_DEFAULT +#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3 +#endif + +using A0DataType = ck_tile::half_t; +using A1DataType = ck_tile::half_t; + +using B0DataType = ck_tile::half_t; +using B1DataType = ck_tile::half_t; + +using D0DataType = ck_tile::half_t; +using D1DataType = ck_tile::half_t; + +using EDataType = ck_tile::half_t; + +using AsDataType = ck_tile::tuple; +using BsDataType = ck_tile::tuple; +using DsDataType = ck_tile::tuple; + +using AccDataType = float; + +struct GemmConfigMemory +{ + // Memory friendly for Interwave scheduler + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 8; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + +struct GemmConfigV3 +{ + // Compute friendly for Intrawave scheduler + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; +}; + +struct GemmConfigV4 +{ + // Compute friendly for Intrawave scheduler + // Using the ping pong reader in the lds level + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 32; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; +}; + +struct GemmConfigV3_Wmma +{ + // Compute friendly for Intrawave scheduler + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; +}; + +template +struct PipelineTypeTraits; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; +}; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "4096", "k dimension") + .insert("as_layout", "R", "As tensor data layout - Row by default") + .insert("bs_layout", "C", "Bs tensor data layout - Col by default") + .insert("ds_layout", "R", "Ds tensor data layout - Row by default") + .insert("e_layout", "R", "E tensor data layout - Row by default") + .insert("stride_as", "0", "Tensor A stride") + .insert("stride_bs", "0", "Tensor B stride") + .insert("stride_ds", "0", "Tensor Ds stride") + .insert("stride_e", "0", "Tensor E stride") + .insert("v", "1", "0. No validation, 1. Validation on GPU") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("kbatch", "1", "kbatch for SplitK"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} +using gemm_multi_abd_kargs = + ck_tile::GemmMultiABDHostArgs; + +template +float gemm_multi_abd(const gemm_multi_abd_kargs& kargs, const ck_tile::stream_config& s); diff --git a/example/ck_tile/22_gemm_multi_abd/run_gemm_multi_abd_fp16_example.inc b/example/ck_tile/22_gemm_multi_abd/run_gemm_multi_abd_fp16_example.inc new file mode 100644 index 0000000000..881961c9db --- /dev/null +++ b/example/ck_tile/22_gemm_multi_abd/run_gemm_multi_abd_fp16_example.inc @@ -0,0 +1,311 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include + +template +float invoke_gemm_multi_abd(const std::array& as_m_k_dev_buf, + const std::array& bs_k_n_dev_buf, + const std::array& ds_m_n_dev_buf, + void* e_m_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + const std::array& StrideAs, + const std::array& StrideBs, + const std::array& StrideDs, + ck_tile::index_t StrideE, + int n_warmup, + int n_repeat, + int k_batch) +{ + gemm_multi_abd_kargs gemm_descs({as_m_k_dev_buf, + bs_k_n_dev_buf, + ds_m_n_dev_buf, + e_m_n_dev_buf, + k_batch, + M, + N, + K, + StrideAs, + StrideBs, + StrideDs, + StrideE}); + + float ave_time = gemm_multi_abd( + gemm_descs, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + std::string op_name{"Gemm Multiple-ABD"}; + + std::size_t flop = 0, num_btype = 0; + + flop += std::size_t(2) * M * N * K; + + num_btype += + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Run Gemm Multiple-ABD kernel with:\n"; + std::cout << "M =" << M << " N =" << N << " K =" << K << "\n"; + std::cout << "StrideA = " << StrideAs[0] << " StrideB = " << StrideBs[0] + << " StrideE = " << StrideE << "\n"; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << "\n"; + + return ave_time; +} + +template +int run_gemm_multi_abd_example_with_layouts(int argc, + char* argv[], + const A0Layout a0_layout = A0Layout{}, + const A1Layout a1_layout = A1Layout{}, + const B0Layout b0_layout = B0Layout{}, + const B1Layout b1_layout = B1Layout{}, + const D0Layout d0_layout = D0Layout{}, + const D1Layout d1_layout = D1Layout{}, + const ELayout e_layout = ELayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + return -1; + } + using AElementWiseFn = ck_tile::element_wise::AddScale; + using BElementWiseFn = ck_tile::element_wise::AddScale; + using CDEElementWiseFn = ck_tile::element_wise::MultiDMultiply; + using AsLayout = ck_tile::tuple; + using BsLayout = ck_tile::tuple; + using DsLayout = ck_tile::tuple; + + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + ck_tile::index_t StrideA = arg_parser.get_int("stride_as"); + ck_tile::index_t StrideB = arg_parser.get_int("stride_bs"); + ck_tile::index_t StrideD = arg_parser.get_int("stride_ds"); + ck_tile::index_t StrideE = arg_parser.get_int("stride_e"); + + ck_tile::index_t StrideA0 = StrideA; + ck_tile::index_t StrideA1 = StrideA; + + ck_tile::index_t StrideB0 = StrideB; + ck_tile::index_t StrideB1 = StrideB; + + ck_tile::index_t StrideD0 = StrideD; + ck_tile::index_t StrideD1 = StrideD; + + const int n_warmup = arg_parser.get_int("warmup"); + const int n_repeat = arg_parser.get_int("repeat"); + const int k_batch = arg_parser.get_int("kbatch"); + + StrideA0 = get_default_stride(M, N, StrideA0, is_row_major(a1_layout)); + StrideA1 = get_default_stride(M, N, StrideA1, is_row_major(a1_layout)); + + StrideB0 = get_default_stride(K, N, StrideB0, is_row_major(b0_layout)); + StrideB1 = get_default_stride(K, N, StrideB1, is_row_major(b1_layout)); + + StrideD0 = get_default_stride(M, N, StrideD0, is_row_major(d0_layout)); + StrideD1 = get_default_stride(M, N, StrideD1, is_row_major(d1_layout)); + + StrideE = get_default_stride(M, N, StrideE, is_row_major(e_layout)); + + ck_tile::HostTensor a0_m_k_tesnor( + host_tensor_descriptor(M, K, StrideA0, is_row_major(a0_layout))); + ck_tile::HostTensor a1_m_k_tesnor( + host_tensor_descriptor(M, K, StrideA1, is_row_major(a1_layout))); + + ck_tile::HostTensor b0_k_n_tensors( + host_tensor_descriptor(K, N, StrideB0, is_row_major(b0_layout))); + ck_tile::HostTensor b1_k_n_tensors( + host_tensor_descriptor(K, N, StrideB1, is_row_major(b1_layout))); + + ck_tile::HostTensor d0_m_n_tensors( + host_tensor_descriptor(M, N, StrideD0, is_row_major(d0_layout))); + ck_tile::HostTensor d1_m_n_tensors( + host_tensor_descriptor(M, N, StrideD1, is_row_major(d1_layout))); + + ck_tile::HostTensor e_m_n_device_result( + host_tensor_descriptor(M, N, StrideE, is_row_major(e_layout))); + + ck_tile::FillUniformDistribution{-1.f, 1.f}(a0_m_k_tesnor); + ck_tile::FillUniformDistribution{-1.f, 1.f}(a1_m_k_tesnor); + + ck_tile::FillUniformDistribution{-1.f, 1.f}(b0_k_n_tensors); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b1_k_n_tensors); + + ck_tile::FillUniformDistribution{-1.f, 1.f}(d0_m_n_tensors); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d1_m_n_tensors); + + ck_tile::DeviceMem a0_m_k_dev_buf(a0_m_k_tesnor.get_element_space_size_in_bytes()); + ck_tile::DeviceMem a1_m_k_dev_buf(a1_m_k_tesnor.get_element_space_size_in_bytes()); + + ck_tile::DeviceMem b0_k_n_dev_buf(b0_k_n_tensors.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b1_k_n_dev_buf(b1_k_n_tensors.get_element_space_size_in_bytes()); + + ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n_tensors.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n_tensors.get_element_space_size_in_bytes()); + + ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes()); + + a0_m_k_dev_buf.ToDevice(a0_m_k_tesnor.mData.data()); + a1_m_k_dev_buf.ToDevice(a1_m_k_tesnor.mData.data()); + + b0_k_n_dev_buf.ToDevice(b0_k_n_tensors.mData.data()); + b1_k_n_dev_buf.ToDevice(b1_k_n_tensors.mData.data()); + + d0_m_n_dev_buf.ToDevice(d0_m_n_tensors.mData.data()); + d1_m_n_dev_buf.ToDevice(d1_m_n_tensors.mData.data()); + + e_m_n_dev_buf.SetZero(); + e_m_n_device_result.SetZero(); + + std::array as_ptr_buf = {a0_m_k_dev_buf.GetDeviceBuffer(), + a1_m_k_dev_buf.GetDeviceBuffer()}; + + std::array bs_ptr_buf = {b0_k_n_dev_buf.GetDeviceBuffer(), + b1_k_n_dev_buf.GetDeviceBuffer()}; + + std::array ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(), + d1_m_n_dev_buf.GetDeviceBuffer()}; + + std::array strideAs = {StrideA0, StrideA1}; + std::array strideBs = {StrideB0, StrideB1}; + std::array strideDs = {StrideD0, StrideD1}; + + invoke_gemm_multi_abd(as_ptr_buf, + bs_ptr_buf, + ds_ptr_buf, + e_m_n_dev_buf.GetDeviceBuffer(), + M, + N, + K, + strideAs, + strideBs, + strideDs, + StrideE, + n_warmup, + n_repeat, + k_batch); + + e_m_n_dev_buf.FromDevice(e_m_n_device_result.data()); + + ck_tile::HostTensor a_m_k_host_ref_element_result( + host_tensor_descriptor(M, K, StrideA0, is_row_major(a0_layout))); + ck_tile::HostTensor b_k_n_host_ref_element_result( + host_tensor_descriptor(K, N, StrideB0, is_row_major(b0_layout))); + ck_tile::HostTensor e_m_n_host_ref( + host_tensor_descriptor(M, N, StrideE, is_row_major(e_layout))); + a_m_k_host_ref_element_result.SetZero(); + b_k_n_host_ref_element_result.SetZero(); + e_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm_multiple_abd({a0_m_k_tesnor, a1_m_k_tesnor}, + {b0_k_n_tensors, b1_k_n_tensors}, + {d0_m_n_tensors, d1_m_n_tensors}, + a_m_k_host_ref_element_result, + b_k_n_host_ref_element_result, + e_m_n_host_ref); + + bool pass{true}; + if(arg_parser.get_int("v")) + { + const float max_accumulated_value = + *std::max_element(e_m_n_host_ref.mData.begin(), e_m_n_host_ref.mData.end()); + + const auto rtol_atol = calculate_rtol_atol(K, 1, max_accumulated_value); + + pass &= ck_tile::check_err(e_m_n_device_result, + e_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << std::endl; + std::cout << "Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The CPU veification result is: " << (pass ? "correct" : "fail") << std::endl; + } + return pass; +} + +template +int run_multiple_abd_gemm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + return -1; + } + + const std::string as_layout = arg_parser.get_str("as_layout"); + const std::string bs_layout = arg_parser.get_str("bs_layout"); + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + if(as_layout == "R" && bs_layout == "C") + { + return run_gemm_multi_abd_example_with_layouts( + argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}, Row{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } +} diff --git a/example/ck_tile/22_gemm_multi_abd/utils.hpp b/example/ck_tile/22_gemm_multi_abd/utils.hpp new file mode 100644 index 0000000000..38bf8623d4 --- /dev/null +++ b/example/ck_tile/22_gemm_multi_abd/utils.hpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeTypeAB = + std::conditional_t; + + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 8fce70ba04..75d32a5eb0 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -21,6 +21,7 @@ add_subdirectory(18_flatmm) add_subdirectory(19_gemm_multi_d) add_subdirectory(20_grouped_convolution) add_subdirectory(21_elementwise) +add_subdirectory(22_gemm_multi_abd) add_subdirectory(35_batched_transpose) add_subdirectory(38_block_scale_gemm) add_subdirectory(39_copy) diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index 8b7541bf23..c7c4702e22 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -26,6 +26,29 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, return tile_window.load(number{}, bool_constant{}); } +/** + * @brief Load tile with elementwise function + * + * @note This function is a modification of the existing load function. + * It has been extended with two additional parameters: it takes a tuple as input + * and an elementwise function. For each A = A0, A1… AN, the elementwise function + * is additionally applied during a single read. + */ +template +CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window, + ElementWise_ elementwise, + number = {}, + bool_constant = {}) +{ + // TODO: Tile windows should works with unknow number of params + // Load element_wise API works only when the input typle is a tuple-tyupe + return tile_window[number<0>{}].load( + tile_window, elementwise, number{}, bool_constant{}); +} + template + CK_TILE_DEVICE auto load(const TileWindow_& tile_window, + ElementWise_ elementwise, + number = {}, + bool_constant = {}) const + { + constexpr auto tile_dstr = typename Base::TileDstr{}; + auto dst_tensor = make_static_distributed_tensor(tile_dstr); + load(dst_tensor, + tile_window, + elementwise, + number{}, + bool_constant{}); + return dst_tensor; + } + + template + CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor, + const TileWindow_& tile_window, + ElementWise_ elementwise, + number = {}, + bool_constant = {}) const + { + + using Traits = typename Base::Traits; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + constexpr auto tile_dstr = typename Base::TileDstr{}; + constexpr auto sizeOfTuple = TileWindow_::size(); + // loop over thread tensor space [y0, y1, ...] + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = + tile_window[number<0>{}].pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = + tile_window[number<0>{}].pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + + // read from bottom tensor + const auto idx_vec_value = generate_tuple( + [&](auto jj) { + return tile_window[number{}] + .get_bottom_tensor_view() + .template get_vectorized_elements( + bottom_tensor_thread_coord, + 0, + bool_constant{}); + }, + number{}); + + // write into distributed tensor + static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { + constexpr auto idx_ys = generate_tuple( + [&](auto jj) { + return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) + : idx_ys_start[jj]; + }, + number{}); + + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + Traits::PackedSize; + + ck_tile::apply( + [&](auto&&... t) { + elementwise(dst_tensor.get_thread_buffer().template at(), + t.template get_as< + typename Base::DataType>()[j / Traits::PackedSize]...); + }, + idx_vec_value); + }); + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + idx_diff_ys); + + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + } + template @@ -857,6 +967,39 @@ CK_TILE_DEVICE void move_tile_window( window.move(step); } +template +CK_TILE_DEVICE void move_tile_window( + tuple>& window, + const typename tile_window_with_static_distribution::BottomTensorIndex& step) +{ + using T = tuple>; + + static constexpr auto N = T::size(); + static_for<0, N, 1>{}([&](auto Is) { window[number{}].move(step); }); +} + +template ::value>* = nullptr> +CK_TILE_DEVICE void move_tile_window(TileWindowWithStaticDistributionType& window, StepType& step) +{ + static constexpr auto N = TileWindowWithStaticDistributionType::size(); + static_for<0, N, 1>{}([&](auto Is) { window[number{}].move(step); }); +} + /** * @brief This class provides description of tile windowed view on the device memory. * diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index caa00e5994..d9379b4420 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -261,6 +261,81 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k, make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency()); } +template >, + typename BDataType = remove_cvref_t>, + typename DDataType = remove_cvref_t>> +CK_TILE_HOST void +reference_gemm_multiple_abd(const std::array, AsDataType::size()>& as_m_k, + const std::array, BsDataType::size()>& bs_k_n, + const std::array, DsDataType::size()>& ds_m_n, + HostTensor& a_m_k, + HostTensor& b_k_n, + HostTensor& c_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const CDElementOp& acc_element_op = {}) +{ + const std::size_t M = a_m_k.get_length(0); + const std::size_t N = b_k_n.get_length(1); + const std::size_t K = a_m_k.get_length(1); + + auto as_m_k_tuple = + generate_tie([&](auto idx) -> auto& { return as_m_k[idx]; }, number{}); + + auto bs_k_n_tuple = + generate_tie([&](auto idx) -> auto& { return bs_k_n[idx]; }, number{}); + + auto ds_m_n_tuple = + generate_tie([&](auto idx) -> auto& { return ds_m_n[idx]; }, number{}); + + // Apply elementwise function to A + auto a_elementwise_fn = [&](auto i, auto j) { + ck_tile::apply([&](auto&&... t) { a_element_op(a_m_k(i, j), t(i, j)...); }, as_m_k_tuple); + }; + + make_ParallelTensorFunctor(a_elementwise_fn, M, K)(std::thread::hardware_concurrency()); + + // Apply elementwise function to B + auto b_elementwise_fn = [&](auto i, auto j) { + ck_tile::apply([&](auto&&... t) { b_element_op(b_k_n(i, j), t(i, j)...); }, bs_k_n_tuple); + }; + + make_ParallelTensorFunctor(b_elementwise_fn, K, N)(std::thread::hardware_concurrency()); + + auto f_mk_kn_mn = [&](auto m, auto n) { + AccDataType v_acc = 0; + for(std::size_t k = 0; k < K; ++k) + { + ADataType v_a = a_m_k(m, k); + BDataType v_b = b_k_n(k, n); + v_acc += + ck_tile::type_convert(v_a) * ck_tile::type_convert(v_b); + } + + CDataType v_c = 0; + + ck_tile::apply( + [&](auto&&... t) { + acc_element_op(v_c, + ck_tile::type_convert(v_acc), + ck_tile::type_convert(t(m, n))...); + }, + ds_m_n_tuple); + + c_m_n(m, n) = ck_tile::type_convert(v_c); + }; + + make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency()); +} + template + CK_TILE_HOST_DEVICE constexpr void operator()(E& a, const As&... as) const + { + // Start with the base value c + float result = ck_tile::type_convert(0.0f); + + // Add by each D parameter using fold expression + ((result += ck_tile::type_convert(as)), ...); + + a = ck_tile::type_convert(scale * result); + } + + float scale = 1.0; +}; + struct MultiDMultiply { template diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 628af0e0b3..ebd97c1c66 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -28,8 +28,8 @@ struct GetDataType using type = typename T::DataType; // Use T::ScaleN::DataType }; -template struct CShuffleEpilogueProblem { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; using AccDataType = remove_cvref_t; using ODataType = remove_cvref_t; using DsDataType = remove_cvref_t; @@ -83,12 +83,27 @@ template struct CShuffleEpilogue { using Problem = remove_cvref_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; using AccDataType = remove_cvref_t; using ODataType = remove_cvref_t; using DsDataType = remove_cvref_t; using DsLayout = remove_cvref_t; + + static constexpr bool ADataTypeIsTuple = is_detected::value; + static constexpr bool BDataTypeIsTuple = is_detected::value; + + using AsDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + + using BsDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + + using ADataType = remove_cvref_t{}, AsDataTypeTuple>>; + using BDataType = remove_cvref_t{}, BsDataTypeTuple>>; + using ATypeToUse = std::conditional_t, BDataType, ADataType>; // Used for weight-only quantization kernel, B would be dequantized to the same data type as A diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index 54becd3c0f..2843966cd7 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -28,8 +28,8 @@ struct Default2DEpilogueProblem static constexpr index_t NumDTensor = 0; }; -template { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; using CLayout = remove_cvref_t; using DsDataType = remove_cvref_t; using CDElementwise = remove_cvref_t; @@ -157,14 +157,28 @@ struct Default2DEpilogue template struct DefaultGemm2DEpilogue : public Default2DEpilogue { - using Problem = remove_cvref_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; + using Problem = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + static constexpr bool ADataTypeIsTuple = is_detected::value; + static constexpr bool BDataTypeIsTuple = is_detected::value; + + using AsDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + + using BsDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + + using ADataType = remove_cvref_t{}, AsDataTypeTuple>>; + using BDataType = remove_cvref_t{}, BsDataTypeTuple>>; // Used for weight-only quantization kernel, B would be dequantized to the same data type as A using BTypeToUse = std::conditional_t, ADataType, BDataType>; + using DsDataType = remove_cvref_t; using DsLayout = remove_cvref_t; using CDElementwise = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index de13e305e0..6e07dbc00e 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -31,6 +31,7 @@ #include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp" diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index fcfbf9635f..588d903b25 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -90,10 +90,10 @@ struct BatchedGemmKernel !is_detected::value && !is_detected::value, "BLayout and BDataType must be scalars. Multiple parameters are not currently supported."); - /// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple. + /// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple. static_assert(!is_detected::value && !is_detected::value, - "C/ELayout and C/EDataType must be scalars."); + "C/CLayout and C/EDataType must be scalars."); struct BatchedGemmKernelArgs : ck_tile::UniversalGemmKernelArgs<> { diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index e37b4f36d4..d632b1596c 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -89,7 +89,7 @@ struct GemmKernel /// @brief Specify the layout configurations for A, B, E and D using ALayout = remove_cvref_t; using BLayout = remove_cvref_t; - using ELayout = remove_cvref_t; + using CLayout = remove_cvref_t; /// @brief Specify the data type configurations for A, B, E and D using ADataType = remove_cvref_t; @@ -106,10 +106,10 @@ struct GemmKernel !is_detected::value && !is_detected::value, "BLayout and BDataType must be scalars. Multiple parameters are not currently supported."); - /// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple. - static_assert(!is_detected::value && + /// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple. + static_assert(!is_detected::value && !is_detected::value, - "C/ELayout and C/EDataType must be scalars."); + "C/CLayout and C/EDataType must be scalars."); static constexpr index_t NumATensor = 1; static constexpr index_t NumBTensor = 1; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp new file mode 100644 index 0000000000..3b050e03ed --- /dev/null +++ b/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp @@ -0,0 +1,193 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/host/concat.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/host/stream_utils.hpp" +#include "ck_tile/core/utility/env.hpp" +#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +/// @brief The MultiABD GEMM kernel host arguments. +/// +/// @par Overview +/// This structure is passed to @ref GemmKernelMultiABD "GemmKernelMultiABD" when creating +/// kernel arguments object. It contain all necessary information required to build proper +/// kernel argument and launch kernel on GPU. This structure defines the GEMM problem +/// configuration by stating all required information like M,N,K sizes and respective strides. +/// NumATensor describes the number of A tensors. The minimum number of tensors is 1(required). +/// NumBTensor describes the number of B tensors. The minimum number of tensors is 1(required). +/// NumDTensor describes the number of D tensors. The minimum number of tensors is 0(not +/// required). +template +struct GemmMultiABDHostArgs +{ + CK_TILE_HOST GemmMultiABDHostArgs(const std::array& as_ptr_, + const std::array& bs_ptr_, + const std::array& ds_ptr_, + void* e_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + const std::array& stride_As_, + const std::array& stride_Bs_, + const std::array& stride_Ds_, + index_t stride_E_) + : as_ptr(as_ptr_), + bs_ptr(bs_ptr_), + ds_ptr(ds_ptr_), + e_ptr(e_ptr_), + M(M_), + N(N_), + K(K_), + stride_As(stride_As_), + stride_Bs(stride_Bs_), + stride_Ds(stride_Ds_), + stride_E(stride_E_), + k_batch(k_batch_) + { + } + + const std::array as_ptr; + const std::array bs_ptr; + const std::array ds_ptr; + union + { + void* e_ptr; + void* c_ptr; + }; + index_t M; + index_t N; + index_t K; + const std::array stride_As; + const std::array stride_Bs; + const std::array stride_Ds; + union + { + index_t stride_E; + index_t stride_C; + }; + + index_t k_batch; +}; + +template +struct GemmKernelMultiABD +{ + /// @brief Inject the UniversalGemmKernel base class to support execution of all necessary + /// functions. + using UniversalGemmKernel = + UniversalGemmKernel; + static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize; + + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + + /// @brief Specify the layout configurations for A, B, E and D + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + using DsLayout = remove_cvref_t; + + /// @brief Specify the data type configurations for A, B, E and D + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using EDataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + + /// @brief ALayout and ADataType are expected to be a tuple, not a scalar. + static_assert(is_detected::value && + is_detected::value, + "ALayout and ADataType must be a tuple."); + + /// @brief BLayout and BDataType are expected to be a tuple, not a scalar. + static_assert(is_detected::value && + is_detected::value, + "BLayout and BDataType must be a tuple."); + + /// @brief CLayout and EDataType are expected to be scalars, not a tuple. + static_assert(!is_detected::value && + !is_detected::value, + "CLayout and EDataType must be a scalar."); + + /// @brief DsLayout and DsDataType are expected to be tuple, not a scalar. + static_assert(is_detected::value && + is_detected::value && + DsLayout::size() == DsDataType::size() && DsLayout::size() > 0, + "DsLayout and DsDataType must be tuples and must have the same size."); + + /// @brief The sizes of NumATensor, NumBTensor and NumDTensor is set by the user." + static constexpr index_t NumATensor = AsDataType::size(); + static constexpr index_t NumBTensor = BsDataType::size(); + static constexpr index_t NumDTensor = DsDataType::size(); + + CK_TILE_HOST static auto GetName() -> const std::string + { + return UniversalGemmKernel::GetName(); + } + + CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3 + { + return UniversalGemmKernel::GridSize(M, N, KBatch); + } + + CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 + { + return UniversalGemmKernel::MaxOccupancyGridSize(s); + } + + CK_TILE_HOST static constexpr auto BlockSize() -> dim3 + { + return UniversalGemmKernel::BlockSize(); + } + + CK_TILE_HOST static constexpr auto + MakeKernelArgs(const GemmMultiABDHostArgs& hostArgs) -> + typename UniversalGemmKernel::KernelArgs + { + /// @brief Universal GEMM requires array objects and corresponding stride information for + /// matrices A, B, and D. + return UniversalGemmKernel::MakeKernelArgs( + UniversalGemmHostArgs(hostArgs.as_ptr, + hostArgs.bs_ptr, + hostArgs.ds_ptr, + hostArgs.e_ptr, + hostArgs.k_batch, + hostArgs.M, + hostArgs.N, + hostArgs.K, + hostArgs.stride_As, + hostArgs.stride_Bs, + hostArgs.stride_Ds, + hostArgs.stride_E)); + } + + CK_TILE_HOST static auto + IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool + { + // Currently MultiABD kernel doesn't support k_batch > 1 + if(kargs.k_batch > 1) + { + return false; + } + + return UniversalGemmKernel::IsSupportedArgument(kargs); + } + + CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void + { + UniversalGemmKernel{}.template operator()(kargs); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp index 9d3ac8b901..b0b2905cb4 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp @@ -95,7 +95,7 @@ struct GemmKernelMultiD /// @brief Specify the layout configurations for A, B, E and D using ALayout = remove_cvref_t; using BLayout = remove_cvref_t; - using ELayout = remove_cvref_t; + using CLayout = remove_cvref_t; using DsLayout = remove_cvref_t; /// @brief Specify the data type configurations for A, B, E and D @@ -114,10 +114,10 @@ struct GemmKernelMultiD !is_detected::value, "BLayout and BDataType must be scalars."); - /// @brief ELayout and EDataType are expected to be scalars, not a tuple. - static_assert(!is_detected::value && + /// @brief CLayout and EDataType are expected to be scalars, not a tuple. + static_assert(!is_detected::value && !is_detected::value, - "ELayout and EDataType must be scalars."); + "CLayout and EDataType must be scalars."); /// @brief DsLayout and DsDataType are expected to be tuple, not a scalar. static_assert(is_detected::value && diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index e38e49f5d1..df1d6c9e4f 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -120,10 +120,10 @@ struct GroupedGemmKernel !is_detected::value && !is_detected::value, "BLayout and BDataType must be scalars. Multiple parameters are not currently supported."); - /// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple. + /// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple. static_assert(!is_detected::value && !is_detected::value, - "C/ELayout and C/EDataType must be scalars."); + "C/CLayout and C/EDataType must be scalars."); using OffsetTile1DPartitioner = OffsettedTile1DPartitioner; using Kernel = GroupedGemmKernel; @@ -364,12 +364,8 @@ struct GroupedGemmKernel const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); // Run GEMM pipeline - const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window[Base::I0], - b_block_window[Base::I0], - num_loop, - has_hot_loop, - tail_num, - smem_ptr_0); + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0); // Run Epilogue Pipeline auto& c_block_window = gemm_tile_windows.at(Base::I3); EpiloguePipeline{}.template diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index cfba8b6c9d..8f44108cc4 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -157,23 +157,23 @@ struct UniversalGemmKernel using EpiloguePipeline = remove_cvref_t; static constexpr bool ADataTypeIsTuple = - is_detected::value; + is_detected::value; static constexpr bool BDataTypeIsTuple = - is_detected::value; + is_detected::value; static constexpr bool DDataTypeIsTuple = is_detected::value; static constexpr bool ALayoutIsTuple = - is_detected::value; + is_detected::value; static constexpr bool BLayoutIsTuple = - is_detected::value; + is_detected::value; static constexpr bool DLayoutIsTuple = is_detected::value; using AsLayout = std::conditional_t, + remove_cvref_t, remove_cvref_t>>; using BsLayout = std::conditional_t, + remove_cvref_t, remove_cvref_t>>; using DsLayout = std::conditional_t>>; using AsDataType = std::conditional_t, + remove_cvref_t, remove_cvref_t>>; using BsDataType = std::conditional_t, + remove_cvref_t, remove_cvref_t>>; using DsDataType = @@ -193,9 +193,12 @@ struct UniversalGemmKernel remove_cvref_t, remove_cvref_t>>; - using ELayout = remove_cvref_t; + using CLayout = remove_cvref_t; using EDataType = remove_cvref_t; + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; + static constexpr index_t kBlockSize = GemmPipeline::BlockSize; // Get the persistent kernel if the pipeline has it available @@ -483,7 +486,7 @@ struct UniversalGemmKernel bool DTesnorIsValid = {true}; static_for<0, NumDTensor, 1>{}([&](auto index) { using DiLayout = remove_cvref_t>; - if(std::is_same_v == false) + if(std::is_same_v == false) { DTesnorIsValid = false; } @@ -529,7 +532,7 @@ struct UniversalGemmKernel } }); - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) { @@ -724,7 +727,7 @@ struct UniversalGemmKernel // TODO: enable vector write for C in ColMajor const auto& e_tensor_view = [&]() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { return make_naive_tensor_view( e_ptr, @@ -818,7 +821,7 @@ struct UniversalGemmKernel // TODO vector write in for C in ColMajor const auto& e_pad_view = [&]() { const auto& e_tensor_view = views.at(I3); - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { return pad_tensor_view(e_tensor_view, make_tuple(number{}, @@ -975,8 +978,8 @@ struct UniversalGemmKernel const auto& bs_block_window = gemm_tile_windows.at(I1); const auto& ds_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = - GemmPipeline{}(as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0); + const auto& c_block_tile = GemmPipeline{}.template operator()( + as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0); if(UseDefaultScheduler || (get_warp_id() == 0)) { @@ -1031,8 +1034,13 @@ struct UniversalGemmKernel const auto& bs_block_window = gemm_tile_windows.at(I1); const auto& ds_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = GemmPipeline{}( - as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0, smem_ptr_1); + const auto& c_block_tile = GemmPipeline{}.template operator()(as_block_window, + AElementWise{}, + bs_block_window, + BElementWise{}, + num_loop, + smem_ptr_0, + smem_ptr_1); // Run Epilogue Pipeline auto& c_block_window = gemm_tile_windows.at(I3); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 2bee550b3c..b5584f98df 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -11,12 +11,17 @@ namespace ck_tile { template struct GemmPipelineAgBgCrImplBase { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; using BlockGemmShape = remove_cvref_t; + using ADataType = remove_cvref_t{}, AsDataType>>; + using ALayout = remove_cvref_t{}, AsLayout>>; + using BDataType = remove_cvref_t{}, BsDataType>>; + using BLayout = remove_cvref_t{}, BsLayout>>; + static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; @@ -57,6 +62,13 @@ struct GemmPipelineAgBgCrImplBase store_tile(lds_tile_window, block_tile_tmp); } + template + CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window, + const SrcBlockTile& src_block_tile) const + { + store_tile(lds_tile_window, src_block_tile); + } + template CK_TILE_DEVICE void LocalPrefetch(DstBlockTile& dst_block_tile, const SrcTileWindow& lds_tile_window, @@ -88,23 +100,100 @@ struct GemmPipelineAgBgCrImplBase return make_tuple(std::move(a_lds_block), std::move(b_lds_block)); } + template ::value, bool>* = + nullptr> + CK_TILE_DEVICE constexpr auto CopyADramWindow(const DramBlockWindowTmp& dram_block_window_tmp, + const array& offset = {0, 0}) const + { + constexpr bool is_col_major = std::is_same_v; + + using YPerTile = std::conditional_t, number>; + using XPerTile = std::conditional_t, number>; + // A DRAM tile window for load + auto a_copy_dram_window = generate_tuple( + [&](auto idx) { + return make_tile_window( + dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(YPerTile{}, XPerTile{}), + dram_block_window_tmp[number{}].get_window_origin() + offset, + Policy::template MakeADramTileDistribution()); + }, + number{}); + return std::move(a_copy_dram_window); + } + + template ::value, bool>* = + nullptr> + CK_TILE_DEVICE constexpr auto CopyADramWindow(const DramBlockWindowTmp& dram_block_window_tmp, + const array& offset = {0, 0}) const + { + constexpr bool is_col_major = std::is_same_v; + + using YPerTile = std::conditional_t, number>; + using XPerTile = std::conditional_t, number>; + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(YPerTile{}, XPerTile{}), + dram_block_window_tmp.get_window_origin() + offset, + Policy::template MakeADramTileDistribution()); + + return std::move(a_copy_dram_window); + } + + template ::value, bool>* = + nullptr> + CK_TILE_DEVICE constexpr auto CopyBDramWindow(const DramBlockWindowTmp& dram_block_window_tmp, + const array& offset = {0, 0}) const + { + constexpr bool is_row_major = std::is_same_v; + + using YPerTile = std::conditional_t, number>; + using XPerTile = std::conditional_t, number>; + // A DRAM tile window for load + auto a_copy_dram_window = generate_tuple( + [&](auto idx) { + return make_tile_window( + dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(YPerTile{}, XPerTile{}), + dram_block_window_tmp[number{}].get_window_origin() + offset, + Policy::template MakeBDramTileDistribution()); + }, + number{}); + return std::move(a_copy_dram_window); + } + + template ::value, bool>* = + nullptr> + CK_TILE_DEVICE constexpr auto CopyBDramWindow(const DramBlockWindowTmp& dram_block_window_tmp, + const array& offset = {0, 0}) const + { + constexpr bool is_row_major = std::is_same_v; + + using YPerTile = std::conditional_t, number>; + using XPerTile = std::conditional_t, number>; + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(YPerTile{}, XPerTile{}), + dram_block_window_tmp.get_window_origin() + offset, + Policy::template MakeBDramTileDistribution()); + + return std::move(a_copy_dram_window); + } + template CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp, const ALdsTensorView& a_lds_block_view, const ALdsLoadTileDistr&, const array& offset = {0, 0}) const { - constexpr bool is_col_major = std::is_same_v; - - using YPerTile = std::conditional_t, number>; - using XPerTile = std::conditional_t, number>; - // A DRAM tile window for load - auto a_copy_dram_window = - make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(YPerTile{}, XPerTile{}), - a_dram_block_window_tmp.get_window_origin() + offset, - Policy::template MakeADramTileDistribution()); + auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset); // A LDS tile window for store auto a_lds_shape = []() { @@ -138,16 +227,8 @@ struct GemmPipelineAgBgCrImplBase const BLdsLoadTileDistr&, const array& offset = {0, 0}) const { - constexpr bool is_row_major = std::is_same_v; - - using YPerTile = std::conditional_t, number>; - using XPerTile = std::conditional_t, number>; - - auto b_copy_dram_window = - make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(YPerTile{}, XPerTile{}), - b_dram_block_window_tmp.get_window_origin() + offset, - Policy::template MakeBDramTileDistribution()); + // A DRAM tile window for load + auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset); // TODO: Do we really need those two tile windows??? // They're exactly same... diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 5f4ee8987e..7159eda683 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -107,14 +107,23 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 using Base = BaseGemmPipelineAgBgCrCompV3; using PipelineImplBase = GemmPipelineAgBgCrImplBase; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; using BlockGemm = remove_cvref_t())>; using I0 = number<0>; @@ -386,17 +395,25 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem) const { + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + static_assert( std::is_same_v> && std::is_same_v auto block_gemm = BlockGemm(); auto c_block_tile = block_gemm.MakeCBlockTile(); - using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); - using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); - - using ABlockTile = - decltype(make_static_distributed_tensor(ABlockTileDistr{})); - using BBlockTile = - decltype(make_static_distributed_tensor(BBlockTileDistr{})); - - ABlockTile a_block_tile; - BBlockTile b_block_tile; - using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; @@ -470,45 +476,61 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 // ----------------------------------------------------------------------------------------- // Gemm pipeline start - - // prefetch - // global read 0 - Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); - // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + auto elementwise_As_res = + load_tile_with_elementwise(a_copy_dram_window, a_element_func); + + // Move each A — the enhanced function move_tile_window is executed, which takes a tuple + // as input. + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + // Load tile — during value loading, an elementwise function is executed for each B0, + // B1, … BN. The values B0, B1, … BN are read by the same thread. + auto elementwise_Bs_res = + load_tile_with_elementwise(b_copy_dram_window, b_element_func); + + // Move each B — the enhanced function move_tile_window is executed, which takes a tuple + // as input. + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + // LDS write 0 if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window, elementwise_As_res); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tile); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { - Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res); } - Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + // global read 1 + + elementwise_As_res = load_tile_with_elementwise(a_copy_dram_window, a_element_func); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + elementwise_Bs_res = load_tile_with_elementwise(b_copy_dram_window, b_element_func); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); block_sync_lds(); - block_gemm.LocalPrefetch( - a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); __builtin_amdgcn_sched_barrier(0); @@ -520,38 +542,42 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { block_sync_lds(); - if constexpr(is_a_col_major && !is_a_load_tr_v()) + if constexpr(is_a_col_major) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window, elementwise_As_res); } - if constexpr(is_b_row_major && !is_b_load_tr_v()) + if constexpr(is_b_row_major) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tile); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { - Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res); } - Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + elementwise_As_res = + load_tile_with_elementwise(a_copy_dram_window, a_element_func); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + elementwise_Bs_res = + load_tile_with_elementwise(b_copy_dram_window, b_element_func); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); - block_gemm.LocalPrefetch( - a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); @@ -574,27 +600,26 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window, elementwise_As_res); } if constexpr(is_b_row_major) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tile); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { - Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res); } block_sync_lds(); - block_gemm.LocalPrefetch( - a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); } // __builtin_amdgcn_sched_barrier(0); @@ -602,13 +627,16 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 } }; - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem) const @@ -628,9 +656,13 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 * @note This is used by the persistent gemm kernel variants that don't determine * hot loop and tail number on the host side, e.g. grouped gemm kernel. */ - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, index_t num_loop, bool has_hot_loop, TailNumber tail_number, @@ -639,7 +671,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { constexpr bool hot_loop = hot_loop_.value; constexpr auto tail_num = tail_num_.value; - constexpr auto PassThrough = [](const auto& x) { return x; }; + constexpr auto PassThrough = [](auto& e, const auto& x) { e = x; }; return PipelineImpl{}.template operator()( a_dram_block_window_tmp, PassThrough, @@ -658,20 +690,97 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 * @note This is used by the kernel variants that are able to determine * hot loop and tail number on the host side, e.g. non-persistent gemm kernel. */ - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, index_t num_loop, void* p_smem) const { return PipelineImpl{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](auto& e, const ADataType& a) { e = a; }, b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, + [](auto& e, const BDataType& b) { e = b; }, num_loop, p_smem); } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + a_element_func, + ck_tile::make_tuple(b_dram_block_window_tmp), + b_element_func, + num_loop, + p_smem); + } + + /** + * @brief Quant operator(), single input: This function runs the pipeline by wrapping it with + * the tail handler. + * + * @note This is used by the persistent gemm kernel variants that don't determine + * hot loop and tail number on the host side, e.g. grouped gemm kernel. + */ + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + bool has_hot_loop, + TailNumber tail_number, + void* p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + has_hot_loop, + tail_number, + p_smem); + } + + /** + * @brief Quant operator(), single input: This function runs the pipeline using compile-time + * known hot loop and tail number. + * @param num_loop The number of loop iterations. This is determined at runtime due to e.g. + * SplitK. + * @note This is used by the kernel variants that are able to determine + * hot loop and tail number on the host side, e.g. non-persistent gemm kernel. + */ + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + p_smem); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp index c835809b5d..b362f751c6 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -97,11 +97,24 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 using Base = BaseGemmPipelineAgBgCrCompV4; using PipelineImplBase = GemmPipelineAgBgCrImplBase; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; + static_assert(!std::is_same_v, "Not implemented"); static constexpr index_t APackedSize = @@ -109,10 +122,6 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 static constexpr index_t BPackedSize = ck_tile::numeric_traits>::PackedSize; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - using BlockGemm = remove_cvref_t())>; using I0 = number<0>; using I1 = number<1>; @@ -244,18 +253,26 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* __restrict__ p_smem_0, void* __restrict__ p_smem_1) const { + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + static_assert( std::is_same_v> && std::is_same_v KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), "B block window has incorrect lengths for defined BLayout!"); - ////////////// global window & register ///////////////// - // A DRAM tile window for load - auto a_copy_dram_window = - make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp.get_window_origin(), - Policy::template MakeADramTileDistribution()); - - // B DRAM tile window for load - auto b_copy_dram_window = - make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_dram_block_window_tmp.get_window_origin(), - Policy::template MakeBDramTileDistribution()); - - // A register tile for global load - constexpr auto ABlockTileDistr = a_copy_dram_window.get_tile_distribution(); - constexpr auto BBlockTileDistr = b_copy_dram_window.get_tile_distribution(); - using ABlockTile = decltype(make_static_distributed_tensor(ABlockTileDistr)); - using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr)); - ABlockTile a_global_load_tile; - BBlockTile b_global_load_tile; - using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; @@ -312,8 +306,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 // global prefetch 0 // global read 0 - Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step); + ////////////// LDS desc, window & register ///////////////// auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0); auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1); @@ -343,34 +336,75 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + // Generating a tuple with tile_windows for values A0, A1, ... AN + auto a_tile_windows = generate_tuple( + [&](auto idx) { + return make_tile_window( + a_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp[number{}].get_window_origin(), + Policy::template MakeADramTileDistribution()); + }, + number{}); + + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + auto elementwise_As_res = load_tile_with_elementwise(a_tile_windows, a_element_func); + + // Move each A — the enhanced function move_tile_window is executed, which takes a tuple + // as input. + move_tile_window(a_tile_windows, a_dram_tile_window_step); + + // Generating a tuple with tile_windows for values B0, B1, ... BN + auto b_tile_windows = generate_tuple( + [&](auto idx) { + return make_tile_window( + b_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp[number{}].get_window_origin(), + Policy::template MakeBDramTileDistribution()); + }, + number{}); + + // Load tile — during value loading, an elementwise function is executed for each B0, + // B1, … BN. The values B0, B1, … BN are read by the same thread. + auto elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func); + + // Move each B — the enhanced function move_tile_window is executed, which takes a tuple + // as input. + move_tile_window(b_tile_windows, b_dram_tile_window_step); + // LDS write 0 if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_global_load_tile); - Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp); } else { - Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window0, elementwise_As_res); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_global_load_tile); - Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp); } else { - Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func); + Base::LocalPrefill(b_copy_lds_window0, elementwise_Bs_res); } // global read 1 - Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step); + elementwise_As_res = load_tile_with_elementwise(a_tile_windows, a_element_func); + move_tile_window(a_tile_windows, a_dram_tile_window_step); + + elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func); + move_tile_window(b_tile_windows, b_dram_tile_window_step); block_sync_lds(); constexpr auto ALdsTileDistr = @@ -423,27 +457,32 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_global_load_tile); - Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp); } else { - Base::LocalPrefill(a_copy_lds_window1, a_global_load_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window1, elementwise_As_res); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_global_load_tile); - Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp, b_element_func); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp); } else { - Base::LocalPrefill(b_copy_lds_window1, b_global_load_tile, b_element_func); + Base::LocalPrefill(b_copy_lds_window1, elementwise_Bs_res); } - Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step); + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + elementwise_As_res = load_tile_with_elementwise(a_tile_windows, a_element_func); + move_tile_window(a_tile_windows, a_dram_tile_window_step); + + elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func); + move_tile_window(b_tile_windows, b_dram_tile_window_step); if(HasHotLoop) { @@ -461,31 +500,32 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_global_load_tile); - Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp); } else { - Base::LocalPrefill( - a_copy_lds_window0, a_global_load_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window0, elementwise_As_res); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_global_load_tile); - Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp); } else { - Base::LocalPrefill( - b_copy_lds_window0, b_global_load_tile, b_element_func); + Base::LocalPrefill(b_copy_lds_window0, elementwise_Bs_res); } - Base::GlobalPrefetch( - a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch( - b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step); + elementwise_As_res = + load_tile_with_elementwise(a_tile_windows, a_element_func); + move_tile_window(a_tile_windows, a_dram_tile_window_step); + + elementwise_Bs_res = + load_tile_with_elementwise(b_tile_windows, b_element_func); + move_tile_window(b_tile_windows, b_dram_tile_window_step); // gemm block_gemm(c_block_tile, a_block_tile0, b_block_tile0); HotLoopScheduler(); @@ -501,32 +541,34 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_global_load_tile); - Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp); } else { - Base::LocalPrefill( - a_copy_lds_window1, a_global_load_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window1, elementwise_As_res); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_global_load_tile); - Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp, b_element_func); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp); } else { - Base::LocalPrefill( - b_copy_lds_window1, b_global_load_tile, b_element_func); + Base::LocalPrefill(b_copy_lds_window1, elementwise_Bs_res); } block_sync_lds(); - Base::GlobalPrefetch( - a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch( - b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step); + elementwise_As_res = + load_tile_with_elementwise(a_tile_windows, a_element_func); + move_tile_window(a_tile_windows, a_dram_tile_window_step); + + elementwise_Bs_res = + load_tile_with_elementwise(b_tile_windows, b_element_func); + move_tile_window(b_tile_windows, b_dram_tile_window_step); + // gemm block_gemm(c_block_tile, a_block_tile1, b_block_tile1); HotLoopScheduler(); @@ -548,23 +590,23 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_global_load_tile); - Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp); } else { - Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window0, elementwise_As_res); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_global_load_tile); - Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp); } else { - Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func); + Base::LocalPrefill(b_copy_lds_window0, elementwise_Bs_res); } block_gemm(c_block_tile, a_block_tile0, b_block_tile0); } @@ -606,13 +648,17 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 } }; - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem_0, @@ -628,27 +674,34 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 p_smem_1); } - public: - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const index_t num_loop, void* __restrict__ p_smem_0, void* __restrict__ p_smem_1) const { return PipelineImpl{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](auto& e, const ADataType& a) { e = a; }, b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, + [](auto& e, const BDataType& b) { e = b; }, num_loop, p_smem_0, p_smem_1); } - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, index_t num_loop, bool has_hot_loop, TailNumber tail_number, @@ -658,7 +711,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { constexpr bool hot_loop = hot_loop_.value; constexpr auto tail_num = tail_num_.value; - constexpr auto PassThrough = [](const auto& x) { return x; }; + constexpr auto PassThrough = [](auto& e, const auto& x) { e = x; }; return PipelineImpl{}.template operator()( a_dram_block_window_tmp, PassThrough, @@ -670,5 +723,69 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 }; return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem_0, + void* p_smem_1) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + a_element_func, + ck_tile::make_tuple(b_dram_block_window_tmp), + b_element_func, + num_loop, + p_smem_0, + p_smem_1); + } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const index_t num_loop, + void* __restrict__ p_smem_0, + void* __restrict__ p_smem_1) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + p_smem_0, + p_smem_1); + } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + bool has_hot_loop, + TailNumber tail_number, + void* __restrict__ p_smem_0, + void* __restrict__ p_smem_1) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + has_hot_loop, + tail_number, + p_smem_0, + p_smem_1); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp index b83d37a790..474d1a5a21 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp @@ -41,15 +41,24 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 using Base = BaseGemmPipelineAgBgCrCompV5; using PipelineImplBase = GemmPipelineAgBgCrImplBase; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; using CDataType = remove_cvref_t; using ComputeDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; + + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; @@ -121,17 +130,25 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BsDramBlockWindowTmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* __restrict__ p_smem_0) const { + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + static_assert( std::is_same_v> && std::is_same_v BGemmTile b_tile_0, b_tile_1; // Register tile for A and B. - using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); - using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + using ABlockTileDistr = + decltype(a_copy_dram_window[number<0>{}].get_tile_distribution()); + using BBlockTileDistr = + decltype(b_copy_dram_window[number<0>{}].get_tile_distribution()); using ABlockTile = decltype(make_static_distributed_tensor(ABlockTileDistr{})); using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr{})); - ABlockTile a_global_load_tile; - BBlockTile b_global_load_tile; + ABlockTile elementwise_As_res; + BBlockTile elementwise_Bs_res; // Block GEMM auto block_gemm = BlockGemm(); @@ -248,33 +267,45 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 // define ping, pong steps here as lambda functions. auto MemoryOpsStep = [&](auto idx) { // Memory read half here. - Base::GlobalPrefetch( - a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch( - b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step); + + // Load tile — during value loading, an elementwise function is executed for each + // A0, A1, … AN. The values A0, A1, … AN are read by the same thread. + elementwise_As_res = load_tile_with_elementwise(a_copy_dram_window, a_element_func); + + // Move each A — the enhanced function move_tile_window is executed, which takes a + // tuple as input. + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + // Load tile — during value loading, an elementwise function is executed for each + // B0, B1, … BN. The values B0, B1, … BN are read by the same thread. + elementwise_Bs_res = load_tile_with_elementwise(b_copy_dram_window, b_element_func); + + // Move each B — the enhanced function move_tile_window is executed, which takes a + // tuple as input. + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); if constexpr(is_a_col_major) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_global_load_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { - Base::LocalPrefill(a_copy_lds_window, a_global_load_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window, elementwise_As_res); } if constexpr(is_b_row_major) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_global_load_tile); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { - Base::LocalPrefill(b_copy_lds_window, b_global_load_tile, b_element_func); + Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res); } if(idx == 0) @@ -351,13 +382,17 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 } }; - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem_0) const @@ -371,21 +406,62 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 p_smem_0); } - public: - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const index_t num_loop, void* __restrict__ p_smem_0) const { return PipelineImpl{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](auto& e, const ADataType& a) { e = a; }, b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, + [](auto& e, const BDataType& b) { e = b; }, num_loop, p_smem_0); } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem_0) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + a_element_func, + ck_tile::make_tuple(b_dram_block_window_tmp), + b_element_func, + num_loop, + p_smem_0); + } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const index_t num_loop, + void* __restrict__ p_smem_0) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + p_smem_0); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index e1acfebc47..9e522d4364 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -157,14 +157,23 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem using Base = BaseGemmPipelineAgBgCrMem; using PipelineImplBase = GemmPipelineAgBgCrImplBase; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; using BlockGemm = remove_cvref_t())>; @@ -236,17 +245,25 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem) const { + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + static_assert( std::is_same_v> && std::is_same_v auto block_gemm = BlockGemm(); auto c_block_tile = block_gemm.MakeCBlockTile(); - using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); - using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + using ABlockTileDistr = + decltype(a_copy_dram_window[number<0>{}].get_tile_distribution()); + using BBlockTileDistr = + decltype(b_copy_dram_window[number<0>{}].get_tile_distribution()); using ABlockTile = decltype(make_static_distributed_tensor(ABlockTileDistr{})); @@ -334,10 +353,21 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem // prefetch // global read 0 - Base::GlobalPrefetch( - a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch( - b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step); + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + a_block_tiles.at(I0{}) = load_tile_with_elementwise(a_copy_dram_window, a_element_func); + + // Move each A — the enhanced function move_tile_window is executed, which takes a tuple + // as input. + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + // Load tile — during value loading, an elementwise function is executed for each B0, + // B1, … BN. The values B0, B1, … BN are read by the same thread. + b_block_tiles.at(I0{}) = load_tile_with_elementwise(b_copy_dram_window, b_element_func); + + // Move each B — the enhanced function move_tile_window is executed, which takes a tuple + // as input. + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -348,32 +378,35 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{})); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { - Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); + Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{})); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{})); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { - Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func); + Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{})); } // Global prefetch [1, PrefetchStages] static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { - Base::GlobalPrefetch(a_block_tiles.get(number{}), - a_copy_dram_window, - a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tiles.get(number{}), - b_copy_dram_window, - b_dram_tile_window_step); + a_block_tiles.at(number{}) = + load_tile_with_elementwise(a_copy_dram_window, a_element_func); + + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + b_block_tiles.at(number{}) = + load_tile_with_elementwise(b_copy_dram_window, b_element_func); + + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); }); // main body @@ -397,14 +430,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem transpose_tile2d( a_shuffle_tmp, a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { Base::LocalPrefill( a_copy_lds_window, - a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), - a_element_func); + a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { @@ -413,22 +445,23 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem transpose_tile2d( b_shuffle_tmp, b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { Base::LocalPrefill( b_copy_lds_window, - b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), - b_element_func); + b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); } - Base::GlobalPrefetch(a_block_tiles.get(number{}), - a_copy_dram_window, - a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tiles.get(number{}), - b_copy_dram_window, - b_dram_tile_window_step); + a_block_tiles.at(number{}) = + load_tile_with_elementwise(a_copy_dram_window, a_element_func); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + b_block_tiles.at(number{}) = + load_tile_with_elementwise(b_copy_dram_window, b_element_func); + + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); }); i += PrefetchStages; @@ -450,26 +483,24 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number{})); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { Base::LocalPrefill(a_copy_lds_window, - a_block_tiles.get(number{}), - a_element_func); + a_block_tiles.get(number{})); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number{})); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { Base::LocalPrefill(b_copy_lds_window, - b_block_tiles.get(number{}), - b_element_func); + b_block_tiles.get(number{})); } }); @@ -526,17 +557,25 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem) const { + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + static_assert( std::is_same_v> && std::is_same_v auto block_gemm = BlockGemm(); auto c_block_tile = block_gemm.MakeCBlockTile(); - using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); - using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + using ABlockTileDistr = + decltype(a_copy_dram_window[number<0>{}].get_tile_distribution()); + using BBlockTileDistr = + decltype(b_copy_dram_window[number<0>{}].get_tile_distribution()); using ABlockTile = decltype(make_static_distributed_tensor(ABlockTileDistr{})); @@ -623,10 +664,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem // prefetch // global read 0 - Base::GlobalPrefetch( - a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch( - b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step); + + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + a_block_tiles.at(I0{}) = load_tile_with_elementwise(a_copy_dram_window, a_element_func); + + // Move each A — the enhanced function move_tile_window is executed, which takes a tuple + // as input. + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + // Load tile — during value loading, an elementwise function is executed for each B0, + // B1, … BN. The values B0, B1, … BN are read by the same thread. + b_block_tiles.at(I0{}) = load_tile_with_elementwise(b_copy_dram_window, b_element_func); + + // Move each B — the enhanced function move_tile_window is executed, which takes a tuple + // as input. + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -637,32 +690,35 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{})); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { - Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); + Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{})); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{})); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { - Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func); + Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{})); } // Global prefetch [1, PrefetchStages] static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { - Base::GlobalPrefetch(a_block_tiles.get(number{}), - a_copy_dram_window, - a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tiles.get(number{}), - b_copy_dram_window, - b_dram_tile_window_step); + a_block_tiles.at(number{}) = + load_tile_with_elementwise(a_copy_dram_window, a_element_func); + + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + b_block_tiles.at(number{}) = + load_tile_with_elementwise(b_copy_dram_window, b_element_func); + + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); }); // main body @@ -687,14 +743,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem transpose_tile2d( a_shuffle_tmp, a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { Base::LocalPrefill( a_copy_lds_window, - a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), - a_element_func); + a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { @@ -703,22 +758,24 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem transpose_tile2d( b_shuffle_tmp, b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { Base::LocalPrefill( b_copy_lds_window, - b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), - b_element_func); + b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); } - Base::GlobalPrefetch(a_block_tiles.get(number{}), - a_copy_dram_window, - a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tiles.get(number{}), - b_copy_dram_window, - b_dram_tile_window_step); + a_block_tiles.at(number{}) = + load_tile_with_elementwise(a_copy_dram_window, a_element_func); + + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + b_block_tiles.at(number{}) = + load_tile_with_elementwise(b_copy_dram_window, b_element_func); + + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); }); i += PrefetchStages; @@ -740,26 +797,24 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number{})); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { Base::LocalPrefill(a_copy_lds_window, - a_block_tiles.get(number{}), - a_element_func); + a_block_tiles.get(number{})); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number{})); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { Base::LocalPrefill(b_copy_lds_window, - b_block_tiles.get(number{}), - b_element_func); + b_block_tiles.get(number{})); } }); @@ -813,13 +868,16 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem } }; - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem) const @@ -833,9 +891,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem p_smem); } - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, index_t num_loop, bool has_hot_loop, TailNumber tail_number, @@ -844,7 +906,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { constexpr bool hot_loop = hot_loop_.value; constexpr auto tail_num = tail_num_.value; - constexpr auto PassThrough = [](const auto& x) { return x; }; + constexpr auto PassThrough = [](auto& e, const auto& x) { e = x; }; return PipelineImpl{}.template operator()( a_dram_block_window_tmp, PassThrough, @@ -856,20 +918,82 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, index_t num_loop, void* p_smem) const { return PipelineImpl{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](auto& e, const ADataType& a) { e = a; }, b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, + [](auto& e, const ADataType& a) { e = a; }, num_loop, p_smem); } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + a_element_func, + ck_tile::make_tuple(b_dram_block_window_tmp), + b_element_func, + num_loop, + p_smem); + } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + bool has_hot_loop, + TailNumber tail_number, + void* p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + has_hot_loop, + tail_number, + p_smem); + } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + p_smem); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index e3b4863392..eb363d59b8 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -15,14 +15,23 @@ namespace ck_tile { template struct GemmPipelineAGmemBGmemCRegV1 { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; using BlockGemm = remove_cvref_t())>; @@ -81,17 +90,25 @@ struct GemmPipelineAGmemBGmemCRegV1 return Policy::template GetSmemSize(); } - template - CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem) const { + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + static_assert( std::is_same_v> && std::is_same_v>, @@ -133,22 +150,30 @@ struct GemmPipelineAGmemBGmemCRegV1 auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); // A DRAM tile window for load - auto a_copy_dram_window = - make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp.get_window_origin(), - Policy::template MakeADramTileDistribution()); + auto as_copy_dram_window = generate_tuple( + [&](auto idx) { + return make_tile_window( + a_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp[number{}].get_window_origin(), + Policy::template MakeADramTileDistribution()); + }, + number{}); // A LDS tile window for store auto a_copy_lds_window = make_tile_window( a_lds_block, make_tuple(number{}, number{}), {0, 0}); // B DRAM tile window for load - auto b_copy_dram_window = - make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_dram_block_window_tmp.get_window_origin(), - Policy::template MakeBDramTileDistribution()); + auto bs_copy_dram_window = generate_tuple( + [&](auto idx) { + return make_tile_window( + b_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp[number{}].get_window_origin(), + Policy::template MakeBDramTileDistribution()); + }, + number{}); // B LDS tile window for store auto b_copy_lds_window = make_tile_window( @@ -182,13 +207,22 @@ struct GemmPipelineAGmemBGmemCRegV1 // prefetch // global read 0 - auto a_block_tile = load_tile(a_copy_dram_window); - auto b_block_tile = load_tile(b_copy_dram_window); + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + auto elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); + + // Load tile — during value loading, an elementwise function is executed for each B0, + // B1, … BN. The values B0, B1, … BN are read by the same thread. + auto elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); { // move to 1 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + // Move each A — the enhanced function move_tile_window is executed, which takes a tuple + // as input. + move_tile_window(as_copy_dram_window, {0, kKPerBlock}); + // Move each B — the enhanced function move_tile_window is executed, which takes a tuple + // as input. + move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -198,13 +232,12 @@ struct GemmPipelineAGmemBGmemCRegV1 { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tile); - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp); - store_tile(a_copy_lds_window, a_block_tile_tmp); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + store_tile(a_copy_lds_window, a_shuffle_tmp); } else { - store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile)); + store_tile(a_copy_lds_window, elementwise_As_res); } // LDS write 0 @@ -212,13 +245,12 @@ struct GemmPipelineAGmemBGmemCRegV1 { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tile); - const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp); - store_tile(b_copy_lds_window, b_block_tile_tmp); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + store_tile(b_copy_lds_window, b_shuffle_tmp); } else { - store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile)); + store_tile(b_copy_lds_window, elementwise_Bs_res); } } @@ -226,8 +258,8 @@ struct GemmPipelineAGmemBGmemCRegV1 while(iCounter > 0) { // global read i + 1 - a_block_tile = load_tile(a_copy_dram_window); - b_block_tile = load_tile(b_copy_dram_window); + elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); + elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); block_sync_lds(); @@ -237,22 +269,20 @@ struct GemmPipelineAGmemBGmemCRegV1 block_sync_lds(); // move to i + 2 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + move_tile_window(as_copy_dram_window, {0, kKPerBlock}); + move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); // LDS write i + 1 if constexpr(is_a_col_major) { auto a_shuffle_tmp_loop = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp_loop, a_block_tile); - store_tile(a_copy_lds_window, - tile_elementwise_in(a_element_func, a_shuffle_tmp_loop)); + transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res); + store_tile(a_copy_lds_window, a_shuffle_tmp_loop); } else { - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window, a_block_tile_tmp); + store_tile(a_copy_lds_window, elementwise_As_res); } // LDS write i + 1 @@ -260,14 +290,12 @@ struct GemmPipelineAGmemBGmemCRegV1 { auto b_shuffle_tmp_loop = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp_loop, b_block_tile); - store_tile(b_copy_lds_window, - tile_elementwise_in(b_element_func, b_shuffle_tmp_loop)); + transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res); + store_tile(b_copy_lds_window, b_shuffle_tmp_loop); } else { - const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); - store_tile(b_copy_lds_window, b_block_tile_tmp); + store_tile(b_copy_lds_window, elementwise_Bs_res); } iCounter--; @@ -284,20 +312,40 @@ struct GemmPipelineAGmemBGmemCRegV1 return c_block_tile; } - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, index_t num_loop, void* p_smem) const { return operator()( a_dram_block_window_tmp, - [](const ADataType & a) { return a; }, + [](auto& e, const ADataType & a) { e = a; }, b_dram_block_window_tmp, - [](const BDataType & b) { return b; }, + [](auto& e, const BDataType & b) { e = b; }, num_loop, p_smem); } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + p_smem); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp index b151cd6782..c309f8908a 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp @@ -15,30 +15,66 @@ namespace ck_tile { template struct GemmPipelineAGmemBGmemCRegV2 { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; using BlockGemmShape = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; + static constexpr index_t APackedSize = ck_tile::numeric_traits>::PackedSize; static constexpr index_t BPackedSize = ck_tile::numeric_traits>::PackedSize; - static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t BlockSize = Problem::kBlockSize; static constexpr index_t kMPerBlock = BlockGemmShape::kM; static constexpr index_t kNPerBlock = BlockGemmShape::kN; static constexpr index_t kKPerBlock = BlockGemmShape::kK; + template + static constexpr index_t GetVectorSizeA() + { + return Problem::VectorSizeA; + } + template + static constexpr index_t GetVectorSizeB() + { + return Problem::VectorSizeB; + } + static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; } + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr bool Preshuffle = Problem::Preshuffle; + + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + + // For the basic gemm pipelien DoubleSmemBuffer set to be false naturally. + static constexpr bool DoubleSmemBuffer = false; + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off return concat('_', "pipeline_AGmemBGmemCRegV2", - concat('x', kMPerBlock, kNPerBlock, kKPerBlock, kBlockSize)); + concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize)); // clang-format on } CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } @@ -56,17 +92,31 @@ struct GemmPipelineAGmemBGmemCRegV2 BPackedSize; } - template (); + } + + template - CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem) const { + + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + static_assert( std::is_same_v> && std::is_same_v>, @@ -98,32 +148,40 @@ struct GemmPipelineAGmemBGmemCRegV2 auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); // A DRAM tile window for load - auto a_copy_dram_window = - make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp.get_window_origin(), - Policy::template MakeADramTileDistribution()); + auto as_copy_dram_window = generate_tuple( + [&](auto idx) { + return make_tile_window( + a_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp[number{}].get_window_origin(), + Policy::template MakeADramTileDistribution()); + }, + number{}); // A LDS tile window for store auto a_copy_lds_window = make_tile_window(a_lds_block, make_tuple(number{}, number{}), {0, 0}, - a_copy_dram_window.get_tile_distribution()); + as_copy_dram_window[number<0>{}].get_tile_distribution()); // B DRAM tile window for load - auto b_copy_dram_window = - make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_dram_block_window_tmp.get_window_origin(), - Policy::template MakeBDramTileDistribution()); + auto bs_copy_dram_window = generate_tuple( + [&](auto idx) { + return make_tile_window( + b_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp[number{}].get_window_origin(), + Policy::template MakeBDramTileDistribution()); + }, + number{}); // B LDS tile window for store auto b_copy_lds_window = make_tile_window(b_lds_block, make_tuple(number{}, number{}), {0, 0}, - b_copy_dram_window.get_tile_distribution()); + bs_copy_dram_window[number<0>{}].get_tile_distribution()); // Block GEMM constexpr auto block_gemm = Policy::template GetBlockGemm(); @@ -153,28 +211,30 @@ struct GemmPipelineAGmemBGmemCRegV2 // prefetch // global read 0 - auto a_block_tile = load_tile(a_copy_dram_window); - auto b_block_tile = load_tile(b_copy_dram_window); + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + auto elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); + // Load tile — during value loading, an elementwise function is executed for each B0, + // B1, … BN. The values B0, B1, … BN are read by the same thread. + auto elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); { // move to 1 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + move_tile_window(as_copy_dram_window, {0, kKPerBlock}); + move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // LDS write 0 - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window, a_block_tile_tmp); + store_tile(a_copy_lds_window, elementwise_As_res); // global read 1 - a_block_tile = load_tile(a_copy_dram_window); + elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); // LDS write 0 - const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); - store_tile(b_copy_lds_window, b_block_tile_tmp); + store_tile(b_copy_lds_window, elementwise_Bs_res); // global read 1 - b_block_tile = load_tile(b_copy_dram_window); + elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); } index_t iCounter = num_loop - 2; @@ -189,20 +249,18 @@ struct GemmPipelineAGmemBGmemCRegV2 block_sync_lds(); // move to i + 2 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + move_tile_window(as_copy_dram_window, {0, kKPerBlock}); + move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); // LDS write i + 1 - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window, a_block_tile_tmp); + store_tile(a_copy_lds_window, elementwise_As_res); // global read i + 2 - a_block_tile = load_tile(a_copy_dram_window); + elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); // LDS write i + 1 - const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); - store_tile(b_copy_lds_window, b_block_tile_tmp); + store_tile(b_copy_lds_window, elementwise_Bs_res); // global read i + 2 - b_block_tile = load_tile(b_copy_dram_window); + elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); iCounter--; @@ -218,11 +276,9 @@ struct GemmPipelineAGmemBGmemCRegV2 block_sync_lds(); // LDS write num_loop - 1 - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window, a_block_tile_tmp); + store_tile(a_copy_lds_window, elementwise_As_res); - const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); - store_tile(b_copy_lds_window, b_block_tile_tmp); + store_tile(b_copy_lds_window, elementwise_Bs_res); block_sync_lds(); @@ -241,12 +297,28 @@ struct GemmPipelineAGmemBGmemCRegV2 { return operator()( a_dram_block_window_tmp, - [](const ADataType & a) { return a; }, + [](auto& e, const ADataType & a) { e = a; }, b_dram_block_window_tmp, - [](const BDataType & b) { return b; }, + [](auto& e, const BDataType & b) { e = b; }, num_loop, p_smem); } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + p_smem); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 52bd07c9e2..c73fa29245 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -5,16 +5,19 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "ck_tile/host/concat.hpp" namespace ck_tile { -template @@ -22,18 +25,49 @@ struct GemmPipelineProblemBase { using Traits = remove_cvref_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; // actually AccDataType - using ComputeDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; // actually AccDataType static constexpr bool FixedVectorSize = FixedVectorSize_; using BlockGemmShape = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; + + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + static constexpr bool ComputeDataTypeIsTuple = is_detected::value; + static constexpr bool ADataTypeIsTuple = is_detected::value; + static constexpr bool BDataTypeIsTuple = is_detected::value; + + static constexpr bool ALayoutIsTuple = is_detected::value; + static constexpr bool BLayoutIsTuple = is_detected::value; + + using ComputeDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + using AsLayoutTuple = std:: + conditional_t, remove_cvref_t>>; + using BsLayoutTuple = std:: + conditional_t, remove_cvref_t>>; + + using AsDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + + using BsDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + + using ComputeDataType = remove_cvref_t{}, ComputeDataTypeTuple>>; + using ADataType = remove_cvref_t{}, AsDataTypeTuple>>; + using ALayout = remove_cvref_t{}, AsLayoutTuple>>; + using BDataType = remove_cvref_t{}, BsDataTypeTuple>>; + using BLayout = remove_cvref_t{}, BsLayoutTuple>>; static constexpr bool TransposeC = Traits::TransposeC; static constexpr index_t NumWaveGroups = Traits::NumWaveGroups; @@ -66,7 +100,7 @@ struct GemmPipelineProblemBase { constexpr index_t PackedSize = ck_tile::numeric_traits>::PackedSize; - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { constexpr index_t pixels_per_thread = BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize; @@ -84,7 +118,7 @@ struct GemmPipelineProblemBase { constexpr index_t PackedSize = ck_tile::numeric_traits>::PackedSize; - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { constexpr index_t pixels_per_thread = BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize; @@ -125,7 +159,7 @@ struct GemmPipelineProblemBase { return VectorSizeA_; } - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) { return kPadK ? 1 : GetAlignmentA(); } @@ -140,7 +174,7 @@ struct GemmPipelineProblemBase { return VectorSizeB_; } - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) { return kPadN ? 1 : GetAlignmentB(); } @@ -161,35 +195,40 @@ struct GemmPipelineProblemBase }(); }; -// Alias for GemmPipelineProblem -template -using GemmPipelineProblem = GemmPipelineProblemBase; -template @@ -197,18 +236,48 @@ struct UniversalGemmPipelineProblem { using Traits = remove_cvref_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; // actually AccDataType - using ComputeDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; // actually AccDataType + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; static constexpr bool FixedVectorSize = FixedVectorSize_; using BlockGemmShape = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + static constexpr bool ComputeDataTypeIsTuple = is_detected::value; + static constexpr bool ADataTypeIsTuple = is_detected::value; + static constexpr bool BDataTypeIsTuple = is_detected::value; + + static constexpr bool ALayoutIsTuple = is_detected::value; + static constexpr bool BLayoutIsTuple = is_detected::value; + + using ComputeDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + using AsLayoutTuple = std:: + conditional_t, remove_cvref_t>>; + using BsLayoutTuple = std:: + conditional_t, remove_cvref_t>>; + + using AsDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + + using BsDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + + using ComputeDataType = remove_cvref_t{}, ComputeDataTypeTuple>>; + using ADataType = remove_cvref_t{}, AsDataTypeTuple>>; + using ALayout = remove_cvref_t{}, AsLayoutTuple>>; + using BDataType = remove_cvref_t{}, BsDataTypeTuple>>; + using BLayout = remove_cvref_t{}, BsLayoutTuple>>; static constexpr bool TransposeC = Traits::TransposeC; static constexpr index_t NumWaveGroups = Traits::NumWaveGroups; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 8d47ab878e..c8f874acd6 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -356,11 +356,14 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA() { - using ALayout = remove_cvref_t; - using ADataType = remove_cvref_t; + using AsLayout = remove_cvref_t; + using AsDataType = remove_cvref_t; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + using ALayout = remove_cvref_t{}, AsLayout>>; + using ADataType = remove_cvref_t{}, AsDataType>>; + if constexpr(std::is_same_v) { return GetGlobalVectorLoadSize CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB() { - using BLayout = remove_cvref_t; - using BDataType = remove_cvref_t; + using BsLayout = remove_cvref_t; + using BsDataType = remove_cvref_t; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + using BLayout = remove_cvref_t{}, BsLayout>>; + using BDataType = remove_cvref_t{}, BsDataType>>; + if constexpr(std::is_same_v) { return GetGlobalVectorLoadSize CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() { - using ALayout = remove_cvref_t; - constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; @@ -491,6 +495,8 @@ struct UniversalGemmBasePolicy Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA(); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + using ALayout = remove_cvref_t< + std::tuple_element_t{}, remove_cvref_t>>; // Tile: MPerBlock X KPerBlock if constexpr(std::is_same_v) { @@ -518,8 +524,6 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() { - using BLayout = remove_cvref_t; - constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; @@ -527,6 +531,8 @@ struct UniversalGemmBasePolicy Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB(); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + using BLayout = remove_cvref_t< + std::tuple_element_t{}, remove_cvref_t>>; // Tile: KPerBlock X NPerBlock if constexpr(std::is_same_v) { @@ -554,7 +560,8 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegTileDistribution() { - using ALayout = remove_cvref_t; + using ALayout = remove_cvref_t< + std::tuple_element_t{}, remove_cvref_t>>; static_assert(std::is_same_v); constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; @@ -574,7 +581,8 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegTileDistribution() { - using BLayout = remove_cvref_t; + using BLayout = remove_cvref_t< + std::tuple_element_t{}, remove_cvref_t>>; static_assert(std::is_same_v); constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp index 64900c9a97..96203b2cd2 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -10,8 +10,8 @@ namespace ck_tile { template struct TileGemmTraits @@ -23,9 +23,9 @@ struct TileGemmTraits // TODO this can't be hardcoded here! Should be in policy! static constexpr int _VectorSize = 16; - using ALayout = ALayout_; - using BLayout = BLayout_; - using CLayout = CLayout_; + using AsLayout = AsLayout_; + using BsLayout = BsLayout_; + using CLayout = CLayout_; static constexpr bool TransposeC = false; static constexpr bool UseStructuredSparsity = false; @@ -36,8 +36,8 @@ template @@ -76,8 +76,8 @@ using PersistentTileGemmUniversalTraits = TileGemmUniversalTraits { - using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV1; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV1; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; using BlockWeightPreshuffle = remove_cvref_t())>; @@ -188,7 +197,12 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 } } - template + template ::value && + !is_detected::value, + bool>* = nullptr> CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, @@ -455,7 +469,33 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 return c_block_tile; } - template + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + [[maybe_unused]] const AElementFunction& a_element_func, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + [[maybe_unused]] const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + return operator()( + a_dram_block_window_tmp[number<0>{}], + [](const ADataType & a) { return a; }, + b_flat_dram_block_window_tmp[number<0>{}], + num_loop, + p_smem); + } + + template ::value && + !is_detected::value, + bool>* = nullptr> CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, index_t num_loop, @@ -463,7 +503,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 { return operator()( a_dram_block_window_tmp, - [](const ADataType & a) { return a; }, + [](auto& e, const ADataType & a) { e = a; }, b_flat_dram_block_window_tmp, num_loop, p_smem); diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp index 129eac6557..356ad91448 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp @@ -53,14 +53,23 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 { using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV2; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; using BlockGemmShape = remove_cvref_t; // TileFlatmmShape - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; using BlockWeightPreshuffle = remove_cvref_t())>; @@ -502,7 +511,10 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 template + typename AElementFunction, + typename std::enable_if_t::value && + !is_detected::value, + bool>* = nullptr> CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, @@ -1001,8 +1013,37 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 return c_block_tile; } + // called from universal gemm kernel + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + [[maybe_unused]] const AElementFunction& a_element_func, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + [[maybe_unused]] const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem_ping, + void* p_smem_pong) const + { + return operator()( + a_dram_block_window_tmp[number<0>{}], + [](const ADataType& a) { return a; }, + b_flat_dram_block_window_tmp[number<0>{}], + num_loop, + p_smem_ping, + p_smem_pong); + } + // called from general gemm kernel - template + template ::value && + !is_detected::value, + bool>* = nullptr> CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, index_t num_loop, @@ -1019,9 +1060,13 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 } // called from grouped gemm kernel - template + template ::value && + !is_detected::value, + bool>* = nullptr> CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_flat_dram_block_window_tmp, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, index_t num_loop, TailNumber tail_number, void* __restrict__ p_smem_0, diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp index 44c6cd66c6..f505efe4e0 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp @@ -44,6 +44,10 @@ struct TileGemmQuantTraits using AQLayout = AQLayout_; using BQLayout = BQLayout_; + // TODO: It should be replaced to single value + using AsLayout = ALayout_; + using BsLayout = BLayout_; + static constexpr bool TransposeC = false; static constexpr bool UseStructuredSparsity = false; static constexpr index_t NumWaveGroups = 1; diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 9314d4b795..b08f0d8316 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -5,6 +5,7 @@ add_subdirectory(batched_gemm) add_subdirectory(grouped_gemm) add_subdirectory(grouped_gemm_preshuffle) add_subdirectory(gemm_multi_d) +add_subdirectory(gemm_multi_abd) add_subdirectory(gemm_streamk) add_subdirectory(data_type) add_subdirectory(container) diff --git a/test/ck_tile/gemm_multi_abd/CMakeLists.txt b/test/ck_tile/gemm_multi_abd/CMakeLists.txt new file mode 100644 index 0000000000..ac3b59d5d3 --- /dev/null +++ b/test/ck_tile/gemm_multi_abd/CMakeLists.txt @@ -0,0 +1,12 @@ +# Currently ck_tile is only built on gfx9 +set(EXAMPLE_GEMM_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + +if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") + add_gtest_executable(test_gemm_multi_abd_cshuffle test_gemm_multi_abd_cshuffle.cpp) + add_gtest_executable(test_gemm_multi_abd_default2d test_gemm_multi_abd_default2d.cpp) + target_compile_definitions(test_gemm_multi_abd_cshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_definitions(test_gemm_multi_abd_default2d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +endif() diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp new file mode 100644 index 0000000000..9821963458 --- /dev/null +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_gemm_multi_abd_util.hpp" + +using F16 = ck_tile::half_t; +using BF16 = ck_tile::bf16_t; +using F32 = float; +using F8 = ck_tile::fp8_t; + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// clang-format off +using KernelTypes = ::testing::Types< + // Has cshuffle epilogue enabled + // A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGemmMultiABD, KernelTypes); + +#include "test_gemm_multi_abd_ut_cases_cshuffle.inc" diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp new file mode 100644 index 0000000000..b3a89aba05 --- /dev/null +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_gemm_multi_abd_util.hpp" + +using F16 = ck_tile::half_t; +using BF16 = ck_tile::bf16_t; +using F32 = float; +using F8 = ck_tile::fp8_t; + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// clang-format off +using KernelTypes = ::testing::Types< + // Has cshuffle epilogue disabled + // A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGemmMultiABD, KernelTypes); + +#include "test_gemm_multi_abd_ut_cases_default2d.inc" diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc new file mode 100644 index 0000000000..5aa113608f --- /dev/null +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc @@ -0,0 +1,211 @@ +#pragma once + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x512x256) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x256x256) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x512x256) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x768x256) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_1280x256x256) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x512x512) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_256x512x256) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x256x256) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x512x256) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x768x256) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_1280x256x256) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_default2d.inc b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_default2d.inc new file mode 100644 index 0000000000..cc7603164c --- /dev/null +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_default2d.inc @@ -0,0 +1,211 @@ +#pragma once + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_256x512x256) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_512x256x256) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_512x512x256) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_512x768x256) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_1280x256x256) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_512x512x512) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_256x512x256) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_512x256x256) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_512x512x256) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_512x768x256) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_1280x256x256) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp new file mode 100644 index 0000000000..428bed4e25 --- /dev/null +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp @@ -0,0 +1,500 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" + +struct AddScale +{ + template + CK_TILE_HOST_DEVICE constexpr void operator()(E& a, const A0& a0, const A1& a1) const + { + a = scale * (ck_tile::type_convert(a0) + ck_tile::type_convert(a1)); + } + + float scale = 1.0; +}; + +struct MultiplyMultiply +{ + template + CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void + { + const float x0_f = ck_tile::type_convert(c) * ck_tile::type_convert(d0) * + ck_tile::type_convert(d1); + + e = ck_tile::type_convert(x0_f); + } +}; + +struct ElementWiseAddAdd +{ + template + CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void + { + const float x0_f = ck_tile::type_convert(c) + ck_tile::type_convert(d0) + + ck_tile::type_convert(d1); + + e = ck_tile::type_convert(x0_f); + } +}; + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeTypeAB = + std::conditional_t; + + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +class TestCkTileGemmMultiABD : public ::testing::Test +{ + protected: + using A0Layout = std::tuple_element_t<0, Tuple>; + using A1Layout = std::tuple_element_t<1, Tuple>; + using B0Layout = std::tuple_element_t<2, Tuple>; + using B1Layout = std::tuple_element_t<3, Tuple>; + using D0Layout = std::tuple_element_t<4, Tuple>; + using D1Layout = std::tuple_element_t<5, Tuple>; + using ELayout = std::tuple_element_t<6, Tuple>; + using A0DataType = std::tuple_element_t<7, Tuple>; + using A1DataType = std::tuple_element_t<8, Tuple>; + using B0DataType = std::tuple_element_t<9, Tuple>; + using B1DataType = std::tuple_element_t<10, Tuple>; + using D0DataType = std::tuple_element_t<11, Tuple>; + using D1DataType = std::tuple_element_t<12, Tuple>; + using AccDataType = std::tuple_element_t<13, Tuple>; + using EDataType = std::tuple_element_t<14, Tuple>; + using AElementWiseFn = std::tuple_element_t<15, Tuple>; + using BElementWiseFn = std::tuple_element_t<16, Tuple>; + using CDElementWiseFn = std::tuple_element_t<17, Tuple>; + using UseCshuffleEpilog = std::tuple_element_t<18, Tuple>; + + using AsLayout = ck_tile::tuple; + using AsDataType = ck_tile::tuple; + using BsLayout = ck_tile::tuple; + using BsDataType = ck_tile::tuple; + using DsLayout = ck_tile::tuple; + using DsDataType = ck_tile::tuple; + + template + void invoke_gemm_multi_abd(const ck_tile::GemmMultiABDHostArgs& args, + const ck_tile::stream_config& s) + { + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 32; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool DoubleSmemBuffer = false; + + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr bool TransposeC = false; + + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + + const ck_tile::index_t k_grain = args.k_batch * K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + using DefaultGemmEpilogue = ck_tile::DefaultGemm2DEpilogue< + ck_tile::DefaultGemm2DEpilogueProblem>; + + using CShuffleGemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using GemmEpilogue = std:: + conditional_t; + + using Kernel = ck_tile::GemmKernelMultiABD; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << GemmPipelineProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + std::cout << "Run without SplitK" << std::endl; + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + std::cout << "Run using SplitK" << std::endl; + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + } + + public: + bool Run(const int M, + const int N, + const int K, + const int k_batch, + int StrideA0 = 0, + int StrideA1 = 0, + int StrideB0 = 0, + int StrideB1 = 0, + int StrideD0 = 0, + int StrideD1 = 0, + int StrideE = 0) + { + using namespace ck_tile::literals; + + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if constexpr(std::is_same_v) + { + return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + StrideA0 = f_get_default_stride(M, K, StrideA0, A0Layout{}); + StrideA1 = f_get_default_stride(M, K, StrideA1, A1Layout{}); + + StrideB0 = f_get_default_stride(K, N, StrideB0, B0Layout{}); + StrideB1 = f_get_default_stride(K, N, StrideB1, B1Layout{}); + + StrideD0 = f_get_default_stride(M, N, StrideD0, D0Layout{}); + StrideD1 = f_get_default_stride(M, N, StrideD1, D1Layout{}); + + StrideE = f_get_default_stride(M, N, StrideE, ELayout{}); + + ck_tile::HostTensor a0_m_k_tesnor( + f_host_tensor_descriptor(M, K, StrideA0, A0Layout{})); + ck_tile::HostTensor a1_m_k_tesnor( + f_host_tensor_descriptor(M, K, StrideA1, A1Layout{})); + + ck_tile::HostTensor b0_k_n_tensors( + f_host_tensor_descriptor(K, N, StrideB0, B0Layout{})); + ck_tile::HostTensor b1_k_n_tensors( + f_host_tensor_descriptor(K, N, StrideB1, B1Layout{})); + + ck_tile::HostTensor d0_m_n_tensors( + f_host_tensor_descriptor(M, N, StrideD0, D0Layout{})); + ck_tile::HostTensor d1_m_n_tensors( + f_host_tensor_descriptor(M, N, StrideD1, D1Layout{})); + + ck_tile::HostTensor e_m_n_device_result( + f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + ck_tile::FillUniformDistribution{-1.f, 1.f}(a0_m_k_tesnor); + ck_tile::FillUniformDistribution{-1.f, 1.f}(a1_m_k_tesnor); + + ck_tile::FillUniformDistribution{-1.f, 1.f}(b0_k_n_tensors); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b1_k_n_tensors); + + ck_tile::FillUniformDistribution{-1.f, 1.f}(d0_m_n_tensors); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d1_m_n_tensors); + + ck_tile::DeviceMem a0_m_k_dev_buf(a0_m_k_tesnor.get_element_space_size_in_bytes()); + ck_tile::DeviceMem a1_m_k_dev_buf(a1_m_k_tesnor.get_element_space_size_in_bytes()); + + ck_tile::DeviceMem b0_k_n_dev_buf(b0_k_n_tensors.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b1_k_n_dev_buf(b1_k_n_tensors.get_element_space_size_in_bytes()); + + ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n_tensors.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n_tensors.get_element_space_size_in_bytes()); + + ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes()); + + a0_m_k_dev_buf.ToDevice(a0_m_k_tesnor.mData.data()); + a1_m_k_dev_buf.ToDevice(a1_m_k_tesnor.mData.data()); + + b0_k_n_dev_buf.ToDevice(b0_k_n_tensors.mData.data()); + b1_k_n_dev_buf.ToDevice(b1_k_n_tensors.mData.data()); + + d0_m_n_dev_buf.ToDevice(d0_m_n_tensors.mData.data()); + d1_m_n_dev_buf.ToDevice(d1_m_n_tensors.mData.data()); + + e_m_n_dev_buf.SetZero(); + e_m_n_device_result.SetZero(); + + std::array as_ptr_buf = {a0_m_k_dev_buf.GetDeviceBuffer(), + a1_m_k_dev_buf.GetDeviceBuffer()}; + + std::array bs_ptr_buf = {b0_k_n_dev_buf.GetDeviceBuffer(), + b1_k_n_dev_buf.GetDeviceBuffer()}; + + std::array ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(), + d1_m_n_dev_buf.GetDeviceBuffer()}; + + std::array strideAs = {StrideA0, StrideA1}; + std::array strideBs = {StrideB0, StrideB1}; + std::array strideDs = {StrideD0, StrideD1}; + + ck_tile::GemmMultiABDHostArgs + args({as_ptr_buf, + bs_ptr_buf, + ds_ptr_buf, + e_m_n_dev_buf.GetDeviceBuffer(), + k_batch, + M, + N, + K, + strideAs, + strideBs, + strideDs, + StrideE}); + + invoke_gemm_multi_abd(args, ck_tile::stream_config{nullptr, false}); + + std::cout << "Run kernel with M =" << M << " N =" << N << " K =" << K + << " StrideA0 =" << StrideA0 << " StrideA1 =" << StrideA1 + << " StrideB0 =" << StrideB0 << " StrideB1 =" << StrideB1 + << " StrideE =" << StrideE << " StrideD0 =" << StrideD0 + << " StrideD1 =" << StrideD1 << std::endl; + + e_m_n_dev_buf.FromDevice(e_m_n_device_result.data()); + bool pass = true; + + ck_tile::HostTensor a_m_k_host_ref_element_result( + f_host_tensor_descriptor(M, K, StrideA0, A0Layout{})); + ck_tile::HostTensor b_k_n_host_ref_element_result( + f_host_tensor_descriptor(K, N, StrideB0, B0Layout{})); + ck_tile::HostTensor e_m_n_host_ref( + f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + a_m_k_host_ref_element_result.SetZero(); + b_k_n_host_ref_element_result.SetZero(); + e_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm_multiple_abd({a0_m_k_tesnor, a1_m_k_tesnor}, + {b0_k_n_tensors, b1_k_n_tensors}, + {d0_m_n_tensors, d1_m_n_tensors}, + a_m_k_host_ref_element_result, + b_k_n_host_ref_element_result, + e_m_n_host_ref); + + const float max_accumulated_value = + *std::max_element(e_m_n_host_ref.mData.begin(), e_m_n_host_ref.mData.end()); + const auto rtol_atol = + calculate_rtol_atol( + K, k_batch, max_accumulated_value); + pass = ck_tile::check_err(e_m_n_device_result, + e_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + + return pass; + } +}; From 47cd0d5cff77658adc1c9f184c012ec3496e8214 Mon Sep 17 00:00:00 2001 From: SamiAario-AMD Date: Fri, 19 Sep 2025 07:26:10 +0300 Subject: [PATCH 21/28] Add gemm weight preshuffle pk_int_t support (#2858) * Factor out the three separate copies of load_interleaved_pk_type into a common utility class * Add preprocessing with optional cache flushing and clearing of output for k_batch > 1 to the weight preshuffle GEMM example * Remove a duplicate function * Add support for B tensor type pk_int4_t for the weight preshuffle GEMM, with tests included * I4 support introduced more failing test cases that mirror the existing ones for F8 * Simplify the check for which tests to skip (they all have F8 as A tensor type) * Add a changelog entry * add the test for v2 wp pipeline, polish the code, add the support of int4 for v2 wp pipeline * have a workable version for atomic add * Revert "have a workable version for atomic add" This reverts commit 792377a590c26cfff9c8f545d9a9e8484a7422eb. --------- Co-authored-by: ThomasNing --- CHANGELOG.md | 1 + .../ops/common/load_interleaved_pk_type.hpp | 58 +++++++++++++++++++ .../block/block_universal_gemm_as_bs_cr.hpp | 37 ++++-------- ..._pipeline_agmem_bgmem_creg_base_policy.hpp | 18 +++--- .../wp_pipeline_agmem_bgmem_creg_v1.hpp | 28 +++++---- .../wp_pipeline_agmem_bgmem_creg_v2.hpp | 28 +++++---- .../block_universal_gemm_as_aquant_bs_cr.hpp | 30 +++------- .../block_universal_gemm_as_bs_bquant_cr.hpp | 31 +++------- .../test_batched_gemm_ut_cases.inc | 3 +- .../test_gemm_pipeline_smoke_run_test.inc | 57 +----------------- .../test_gemm_pipeline_kernel_types.hpp | 25 ++++---- .../test_gemm_pipeline_ut_cases.inc | 8 +-- .../test_gemm_pipeline_util.hpp | 36 +++++++++--- 13 files changed, 183 insertions(+), 177 deletions(-) create mode 100644 include/ck_tile/ops/common/load_interleaved_pk_type.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index dafe1b5c87..6dd06195c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ## Composable Kernel 1.2.0 for ROCm 7.0.0 ### Added +* Added support for B Tensor type pk_int4_t in the CK TILE weight preshuffle GEMM. * Added support for B Tensor Preshuffle in CK TILE Grouped GEMM. * Added a basic copy kernel example and supporting documentation for new CK Tile developers. * Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data diff --git a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp new file mode 100644 index 0000000000..f8432b9da0 --- /dev/null +++ b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/ops/elementwise.hpp" + +namespace ck_tile { + +template +struct is_pk_int4 : std::false_type +{ +}; +template <> +struct is_pk_int4 : std::true_type +{ +}; + +template +struct InterleavedPKTypeLoader +{ + template + CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile, + const WarpWindow& warp_window) + { + const element_wise::PassThroughPack8 elementwise_op{}; + + static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); + constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; + const auto in_dstr_tensors = load_tile(warp_window); + + using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize))); + static_for<0, thread_buffer_size, 1>{}([&](auto i) { + elementwise_op(warp_tile.get_thread_buffer().template get_as()(i), + in_dstr_tensors.get_thread_buffer().template get_as()[i]); + }); + } +}; + +template +CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src) +{ + if constexpr(is_pk_int4>::value) + { + InterleavedPKTypeLoader::load_interleaved_pk_type(dst, src); + } + else + { + dst = load_tile(src); + } +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index e1b0792ecf..94adb42880 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/elementwise.hpp" @@ -13,7 +14,9 @@ namespace ck_tile { // A is block window on shared memory // B is block window on shared memory // C is block distributed tensor -template +template struct BlockUniversalGemmAsBsCr { private: @@ -91,6 +94,7 @@ struct BlockUniversalGemmAsBsCr using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; + using Loader = remove_cvref_t>; using WarpGemm = remove_cvref_t; static constexpr index_t KIterPerWarp = Traits::KIterPerWarp; @@ -179,25 +183,6 @@ struct BlockUniversalGemmAsBsCr return b_block_dstr_encode; } - private: - template - CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile, - const WarpWindow& warp_window) - { - constexpr index_t UnaryOpSize = 8; - const element_wise::PassThroughPack8 elementwise_op{}; - constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; - const auto in_dstr_tensors = load_tile(warp_window); - - static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); - - using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize))); - static_for<0, thread_buffer_size, 1>{}([&](auto i) { - elementwise_op(warp_tile.get_thread_buffer().template get_as()(i), - in_dstr_tensors.get_thread_buffer().template get_as()[i]); - }); - } - template struct BlockGemmImpl { @@ -239,7 +224,7 @@ struct BlockUniversalGemmAsBsCr if constexpr(std::is_same_v) { - load_interleaved_pk_type(a_warp_tile_, a_block_window); + Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window); } else { @@ -247,7 +232,7 @@ struct BlockUniversalGemmAsBsCr } if constexpr(std::is_same_v) { - load_interleaved_pk_type(b_warp_tile_, b_block_window); + Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window); } else { @@ -317,7 +302,7 @@ struct BlockUniversalGemmAsBsCr { if constexpr(std::is_same_v) { - load_interleaved_pk_type(a_warp_tile_, a_block_window); + Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window); } else if constexpr(ALoadTranspose) { @@ -329,7 +314,7 @@ struct BlockUniversalGemmAsBsCr } if constexpr(std::is_same_v) { - load_interleaved_pk_type(b_warp_tile_, b_block_window); + Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window); } else if constexpr(BLoadTranspose) { @@ -468,7 +453,7 @@ struct BlockUniversalGemmAsBsCr if constexpr(std::is_same_v) { - load_interleaved_pk_type(a_warp_tile_, a_block_window); + Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window); } else if constexpr(ALoadTranspose) { @@ -480,7 +465,7 @@ struct BlockUniversalGemmAsBsCr } if constexpr(std::is_same_v) { - load_interleaved_pk_type(b_warp_tile_, b_block_window); + Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window); } else if constexpr(BLoadTranspose) { diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index 71ca907c07..f1c8f2ec9b 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -289,13 +289,17 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy { using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; - using WarpGemm = WarpGemmDispatcher; + using BTypeToUse = + std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>; + using WarpGemm = WarpGemmDispatcher; using BlockWeightPreshufflePolicy = BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy::value && !is_detected::value, - bool>* = nullptr> + bool>* = nullptr, + index_t UnaryOpSize_ = 8> CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, @@ -310,14 +312,14 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 NIterPerWarp> b_flat_dram_windows; - statically_indexed_array< - statically_indexed_array, - NIterPerWarp> + using BTypeToUse = + std::conditional_t, ADataType, BDataType>; + using BTileType = decltype(make_static_distributed_tensor(b_flat_distribution)); + + statically_indexed_array, NIterPerWarp> b_warp_tensor; - statically_indexed_array< - statically_indexed_array, - NIterPerWarp> + statically_indexed_array, NIterPerWarp> b_warp_tensor_2; static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { @@ -327,7 +329,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + load_int4_tile( + b_warp_tensor(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -375,7 +378,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + load_int4_tile( + b_warp_tensor_2(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -408,7 +412,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + load_int4_tile( + b_warp_tensor(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -445,7 +450,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + load_int4_tile( + b_warp_tensor_2(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp index 356ad91448..670f4b0575 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" #include "ck_tile/host/concat.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp" @@ -514,7 +515,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 typename AElementFunction, typename std::enable_if_t::value && !is_detected::value, - bool>* = nullptr> + bool>* = nullptr, + index_t UnaryOpSize_ = 8> CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, @@ -631,19 +633,19 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 b_flat_distribution); // pingpong buffer for B + using BTypeToUse = + std::conditional_t, ADataType, BDataType>; + using BTileType = decltype(make_static_distributed_tensor(b_flat_distribution)); + statically_indexed_array< statically_indexed_array, NIterPerWarp> b_flat_dram_windows; - statically_indexed_array< - statically_indexed_array, - NIterPerWarp> + statically_indexed_array, NIterPerWarp> b_warp_tensor_ping; - statically_indexed_array< - statically_indexed_array, - NIterPerWarp> + statically_indexed_array, NIterPerWarp> b_warp_tensor_pong; // Prefetch A0 @@ -659,7 +661,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + load_int4_tile( + b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); // move B window to next flat K @@ -706,7 +709,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + load_int4_tile( + b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -782,7 +786,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + load_int4_tile( + b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -862,7 +867,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + load_int4_tile( + b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); diff --git a/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 182d9251b1..f75d02f1a6 100644 --- a/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -5,19 +5,19 @@ #include "ck_tile/core.hpp" #include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/elementwise.hpp" namespace ck_tile { -template +template struct BlockGemmAQuantBase { using AQDataType = remove_cvref_t; using ComputeDataType = remove_cvref_t; - static constexpr index_t UnaryOpSize = UnaryOpSize_; template CK_TILE_DEVICE static float cvt_scale_to_fp32(T scale) { @@ -42,23 +42,6 @@ struct BlockGemmAQuantBase } return scale_reg_f; } - - template - CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile, - const WarpWindow& warp_window) - { - const element_wise::PassThroughPack8 elementwise_op{}; - - static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); - constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; - const auto in_dstr_tensors = load_tile(warp_window); - - using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize))); - static_for<0, thread_buffer_size, 1>{}([&](auto i) { - elementwise_op(warp_tile.get_thread_buffer().template get_as()(i), - in_dstr_tensors.get_thread_buffer().template get_as()[i]); - }); - } }; // A is block window on shared memory @@ -66,7 +49,9 @@ struct BlockGemmAQuantBase // Consecutive kQuantGroupSize elements of A are quantized with a separate scale. // B is block window on shared memory // C is block distributed tensor -template +template struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase { private: @@ -172,6 +157,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase using Base = BlockGemmAQuantBase; + using Loader = remove_cvref_t>; using WarpGemm = remove_cvref_t; static constexpr index_t KIterPerWarp = Traits::KIterPerWarp; @@ -292,7 +278,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase { static_assert(std::is_same_v || std::is_same_v); - Base::load_interleaved_pk_type(a_warp_tile_, a_block_window); + Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window); } else { @@ -302,7 +288,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase { static_assert(std::is_same_v || std::is_same_v); - Base::load_interleaved_pk_type(b_warp_tile_, b_block_window); + Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window); } else { diff --git a/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 7e28ea8fa9..077d0d8fe2 100644 --- a/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -5,19 +5,19 @@ #include "ck_tile/core.hpp" #include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/elementwise.hpp" namespace ck_tile { -template +template struct BlockGemmBQuantBase { using BQDataType = remove_cvref_t; using ComputeDataType = remove_cvref_t; - static constexpr index_t UnaryOpSize = UnaryOpSize_; template CK_TILE_DEVICE static float cvt_scale_to_fp32(T scale) { @@ -42,24 +42,6 @@ struct BlockGemmBQuantBase } return scale_reg_f; } - - // can be inherited from A - template - CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile, - const WarpWindow& warp_window) - { - const element_wise::PassThroughPack8 elementwise_op{}; - - static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); - constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; - const auto in_dstr_tensors = load_tile(warp_window); - - using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize))); - static_for<0, thread_buffer_size, 1>{}([&](auto i) { - elementwise_op(warp_tile.get_thread_buffer().template get_as()(i), - in_dstr_tensors.get_thread_buffer().template get_as()[i]); - }); - } }; // A is block window on shared memory @@ -67,7 +49,9 @@ struct BlockGemmBQuantBase // Consecutive kQuantGroupSize elements of B are quantized with a separate scale. // B is block window on shared memory // C is block distributed tensor -template +template struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase { private: @@ -170,6 +154,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase using Base = BlockGemmBQuantBase; + using Loader = remove_cvref_t>; using WarpGemm = remove_cvref_t; static constexpr index_t KIterPerWarp = Traits::KIterPerWarp; @@ -291,7 +276,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase { static_assert(std::is_same_v || std::is_same_v); - Base::load_interleaved_pk_type(a_warp_tile_, a_block_window); + Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window); } else { @@ -301,7 +286,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase { static_assert(std::is_same_v || std::is_same_v); - Base::load_interleaved_pk_type(b_warp_tile_, b_block_window); + Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window); } else { diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc b/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc index 035377734b..8f24c9bfe1 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc +++ b/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc @@ -29,7 +29,8 @@ TYPED_TEST(TestCkTileBatchedGemm, Basic) {256, 256, 64, 8}, {256, 256, 64, 16}}; - if(ck_tile::get_device_name() != "gfx950") { + if(ck_tile::get_device_name() != "gfx950") + { gemmParams.emplace_back(256, 256, 128, 2); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc index ab74e4e7b1..57feefceab 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc @@ -2,6 +2,8 @@ // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include "ck_tile/host/permute_pk_int4.hpp" + template static constexpr inline auto is_row_major(Layout layout_) { @@ -91,61 +93,6 @@ void permute_tensor_b(Tensor& tensor) } } -template -void permute_vectors_i4x4_b(Tensor& tensor) -{ - const ck_tile::index_t K = tensor.get_length(0); - const ck_tile::index_t N = tensor.get_length(1); - // vector pk_i4x4 permute - for(int i = 0; i < N; i++) - { - for(int j = 0; j < K; j += 8) - { - int8_t input[8]; - - for(int k = 0; k < 4; k++) - { - int8_t i4x2 = tensor(j + k * 2, i).data; - input[k * 2 + 0] = (i4x2 >> 4) & 0xf; - input[k * 2 + 1] = (i4x2 >> 0) & 0xf; - } - - // permute 01234567->20643175 - { - int8_t hi = input[2]; - int8_t lo = input[0]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 0, i) = i4x2; - } - - { - int8_t hi = input[6]; - int8_t lo = input[4]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 2, i) = i4x2; - } - - { - int8_t hi = input[3]; - int8_t lo = input[1]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 4, i) = i4x2; - } - - { - int8_t hi = input[7]; - int8_t lo = input[5]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 6, i) = i4x2; - } - } - } -} - template ; -using WeightPreshuffle = - ck_tile::integral_constant; - -// Adding alias for the F8 parameters to facilitate skipping tests. -// This alias can be removed once test failures are fixed. -using F8Types = std::tuple; +using WeightPreshuffleV1 = + ck_tile::integral_constant; +using WeightPreshuffleV2 = + ck_tile::integral_constant; // clang-format off using KernelTypesWeightPreshuffle = ::testing::Types< - std::tuple< Row, Col, Row, F16, F16, F32, F16, Default, WeightPreshuffle>, - std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, Default, WeightPreshuffle> -#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8 - , F8Types + std::tuple< Row, Col, Row, F16, F16, F32, F16, Default, WeightPreshuffleV1>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, Default, WeightPreshuffleV2>, + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, Default, WeightPreshuffleV2>, + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, Default, WeightPreshuffleV1> +#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8 + , + std::tuple< Row, Col, Row, F8, F8, F32, F16, Default, WeightPreshuffleV1>, + std::tuple< Row, Col, Row, F8, F8, F32, F16, Default, WeightPreshuffleV2>, + std::tuple< Row, Col, Row, F8, I4, F32, F16, Default, WeightPreshuffleV2>, + std::tuple< Row, Col, Row, F8, I4, F32, F16, Default, WeightPreshuffleV1> #endif >; diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_ut_cases.inc b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_ut_cases.inc index 389e0d53ea..bb56c63413 100644 --- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_ut_cases.inc +++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_ut_cases.inc @@ -20,7 +20,7 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle) TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x128x128) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v, F8>) { GTEST_SKIP() << "Skipping this test due to failures with F8"; } @@ -48,7 +48,7 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x128x4096) TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x2048x128) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v, F8>) { GTEST_SKIP() << "Skipping this test due to failures with F8"; } @@ -77,7 +77,7 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x2048x4096) TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_1024x128x128) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v, F8>) { GTEST_SKIP() << "Skipping this test due to failures with F8"; } @@ -106,7 +106,7 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_1024x128x4096) TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_1024x2048x128) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v, F8>) { GTEST_SKIP() << "Skipping this test due to failures with F8"; } diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp index 42d0149498..62f819ac1e 100644 --- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp @@ -8,6 +8,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" #include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/host/permute_pk_int4.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" @@ -34,20 +35,31 @@ auto calculate_rtol_atol(const ck_tile::index_t K, enum struct GemmPipelineType { - WeightPreshuffle + WeightPreshuffleV1, + WeightPreshuffleV2 }; template struct GemmPipelineTypeSelector; template -struct GemmPipelineTypeSelector +struct GemmPipelineTypeSelector { using base_pipeline = ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1; using pipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1; - static constexpr auto GetName() { return "GemmPipelineAgBgCrWeightPreshuffle"; } + static constexpr auto GetName() { return "GemmPipelineAgBgCrWeightPreshuffleV1"; } }; + +template +struct GemmPipelineTypeSelector +{ + using base_pipeline = ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2; + using pipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; + + static constexpr auto GetName() { return "GemmPipelineAgBgCrWeightPreshuffleV2"; } +}; + template struct config { @@ -122,7 +134,8 @@ class TestCkTileGemmPipeline : public ::testing::Test constexpr bool kPadK = PadK; constexpr bool preshuffle = Preshuffle; - constexpr bool DoubleSmemBuffer = false; + constexpr bool DoubleSmemBuffer = + (PipelineType == GemmPipelineType::WeightPreshuffleV2) ? true : false; // TODO: For now - but this should also be a test parameter constexpr bool TransposeC = false; @@ -391,10 +404,19 @@ class TestCkTileGemmPipeline : public ::testing::Test ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); - ck_tile::HostTensor b_shuffle_host = shuffle_b(b_k_n); - a_m_k_dev_buf.ToDevice(a_m_k.data()); - b_k_n_dev_buf.ToDevice(b_shuffle_host.data()); + ck_tile::HostTensor b_shuffle_host = shuffle_b(b_k_n); + if constexpr(std::is_same_v) + { + // Permute vector pk_i4x4 data for device implementation + ck_tile::HostTensor b_shuffle_host_dev = b_shuffle_host; + ck_tile::permute_vectors_i4x4_b(b_shuffle_host_dev); + b_k_n_dev_buf.ToDevice(b_shuffle_host_dev.data()); + } + else + { + b_k_n_dev_buf.ToDevice(b_shuffle_host.data()); + } c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); From e469fee0460bb33cef2daa8ef9e05175b02195bc Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Thu, 18 Sep 2025 22:51:01 -0700 Subject: [PATCH 22/28] poc convert fnuz fp8 to non-native dtype similar to ocp (#2871) --- include/ck/utility/amd_ck_fp8.hpp | 30 +++++++++++++++++++++++++++-- include/ck/utility/data_type.hpp | 6 +++--- include/ck/utility/dtype_vector.hpp | 12 ++++++++++++ include/ck/utility/f8_utils.hpp | 29 ++++++++++++++-------------- include/ck/utility/type_convert.hpp | 12 ++++++------ test/data_type/test_bf8_fnuz.cpp | 20 ++++++++----------- test/data_type/test_fp8_fnuz.cpp | 20 ++++++++----------- 7 files changed, 80 insertions(+), 49 deletions(-) diff --git a/include/ck/utility/amd_ck_fp8.hpp b/include/ck/utility/amd_ck_fp8.hpp index 2edbb7c789..0b73f76155 100644 --- a/include/ck/utility/amd_ck_fp8.hpp +++ b/include/ck/utility/amd_ck_fp8.hpp @@ -33,8 +33,34 @@ namespace ck { -using f8_fnuz_t = _BitInt(8); -using bf8_fnuz_t = unsigned _BitInt(8); +struct f8_fnuz_t +{ + using data_type = unsigned char; + data_type m_data; + __host__ __device__ explicit constexpr f8_fnuz_t(data_type in_data) : m_data(in_data) {} + __host__ __device__ explicit constexpr f8_fnuz_t() = default; + __host__ __device__ bool constexpr operator==(f8_fnuz_t other) const + { + return m_data == other.m_data; + } + __host__ __device__ explicit constexpr operator data_type() const { return m_data; } +}; + +struct bf8_fnuz_t +{ + using data_type = unsigned char; + data_type m_data; + __host__ __device__ explicit constexpr bf8_fnuz_t(data_type in_data) : m_data(in_data) {} + __host__ __device__ explicit constexpr bf8_fnuz_t() = default; + __host__ __device__ bool constexpr operator==(bf8_fnuz_t other) const + { + return m_data == other.m_data; + } + __host__ __device__ explicit constexpr operator data_type() const { return m_data; } +}; + +static_assert(1 == sizeof(f8_fnuz_t)); +static_assert(1 == sizeof(bf8_fnuz_t)); typedef unsigned char fp8_storage_t; diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 48b352986e..984bb4d862 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -205,7 +205,7 @@ inline constexpr bool is_native_type() return is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value; + is_same_v || is_same_v || is_same::value; } // scalar_type @@ -300,14 +300,14 @@ struct scalar_type template <> struct scalar_type { - using type = f8_fnuz_t; + using type = f8_fnuz_t::data_type; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = bf8_fnuz_t; + using type = bf8_fnuz_t::data_type; static constexpr index_t vector_size = 1; }; diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp index ae0edb35ee..27a7545a0e 100644 --- a/include/ck/utility/dtype_vector.hpp +++ b/include/ck/utility/dtype_vector.hpp @@ -1294,6 +1294,18 @@ struct nnvb_data_t_selector using type = bf8_ocp_t::data_type; }; +template <> +struct nnvb_data_t_selector +{ + using type = f8_fnuz_t::data_type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = bf8_fnuz_t::data_type; +}; + template <> struct nnvb_data_t_selector { diff --git a/include/ck/utility/f8_utils.hpp b/include/ck/utility/f8_utils.hpp index 799683ae65..748aa07f9e 100644 --- a/include/ck/utility/f8_utils.hpp +++ b/include/ck/utility/f8_utils.hpp @@ -39,7 +39,7 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng) int exponent, bias; uint32_t head, mantissa, sign; // nan code is same for float and half - constexpr Y nan_code = 0x80; + constexpr uint8_t nan_code = 0x80; constexpr uint32_t nan_mask = NumericUtils::nan_mask; // convert to bitwise @@ -60,17 +60,17 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng) if constexpr(negative_zero_nan) { if((x_bitwise & nan_mask) == nan_mask) - return nan_code; + return Y{nan_code}; } else { if((x_bitwise & nan_mask) == nan_mask) - return signed_inf + (mantissa != 0 ? 1 : 0); + return Y{static_cast(signed_inf + (mantissa != 0 ? 1 : 0))}; } // check if x is 0.0 if(x_bitwise == 0) - return 0; + return Y{0}; // First need to check if it is normal or denorm as there is a difference of implict 1 // Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift @@ -178,9 +178,10 @@ In this case, the fp16 mantissa should be shift left by 1 */ // check if x is 0.0 or -0.0 if(out_exponent == 0 && mantissa == 0) - return negative_zero_nan ? 0 : (sign << (out_exp + out_mant)); + return Y{negative_zero_nan ? 0 : static_cast(sign << (out_exp + out_mant))}; mantissa &= (1 << out_mant) - 1; - return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa; + return Y{static_cast((sign << (out_exp + out_mant)) | (out_exponent << out_mant) | + mantissa)}; } template @@ -195,8 +196,8 @@ __host__ __device__ Y run_cast_from_f8(X x) constexpr int out_mant = NumericUtils::mant; // prepare the codes - constexpr X nan_code = 0x80; - using T_bitwise = typename NumericUtils::bitwise_type; + constexpr uint8_t nan_code = 0x80; + using T_bitwise = typename NumericUtils::bitwise_type; constexpr T_bitwise Inf_bitwise = NumericUtils::Inf; constexpr T_bitwise NegInf_bitwise = NumericUtils::NegInf; @@ -209,13 +210,13 @@ __host__ __device__ Y run_cast_from_f8(X x) constexpr Y Neg0 = bit_cast(Neg0_bitwise); // check if x is 0.0 - if(x == 0) + if(!static_cast(x)) return static_cast(0); // unpack the input - uint32_t sign = x >> (in_exp + in_mant); - uint32_t mantissa = x & ((1 << in_mant) - 1); - int exponent = (x & 0x7F) >> in_mant; + uint32_t sign = static_cast(x) >> (in_exp + in_mant); + uint32_t mantissa = static_cast(x) & ((1 << in_mant) - 1); + int exponent = (static_cast(x) & 0x7F) >> in_mant; constexpr int exp_low_cutoff = (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); @@ -223,12 +224,12 @@ __host__ __device__ Y run_cast_from_f8(X x) if constexpr(negative_zero_nan) { - if(x == nan_code) + if(static_cast(x) == nan_code) return NaN; } else { - if(x == nan_code) + if(static_cast(x) == nan_code) return Neg0; if(exponent == ((1 << in_exp) - 1)) return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN; diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 290a6c8dd6..913557fc7a 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -351,7 +351,7 @@ inline __host__ __device__ f8_fnuz_t f8_convert_sr(float x) val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8); ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos val.i32val = ival; - return val.i8val[0]; // little endian + return f8_t{val.i8val[0]}; // little endian #else constexpr bool negative_zero_nan = true; constexpr bool clip = true; @@ -419,7 +419,7 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_sr(float x) val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8); ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos val.i32val = ival; - return val.i8val[0]; // little endian + return bf8_t{val.i8val[0]}; // little endian #else constexpr bool negative_zero_nan = true; constexpr bool clip = true; @@ -655,7 +655,7 @@ inline __host__ __device__ f8_fnuz_t f8_convert_rne(float x) val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8); ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0 val.i32val = ival; - return val.i8val[0]; + return f8_t{val.i8val[0]}; #else constexpr bool negative_zero_nan = true; constexpr bool clip = true; @@ -707,7 +707,7 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_rne(float x) val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8); ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0 val.i32val = ival; - return val.i8val[0]; + return bf8_t{val.i8val[0]}; #else constexpr bool negative_zero_nan = true; constexpr bool clip = true; @@ -924,7 +924,7 @@ inline __host__ __device__ float type_convert(f8_fnuz_t x) { #if defined(__gfx94__) float fval; - uint32_t i32val = static_cast(x); + uint32_t i32val = static_cast(static_cast(x)); fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0); // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); return fval; @@ -1430,7 +1430,7 @@ inline __host__ __device__ float type_convert(bf8_fnuz_t x) { #if defined(__gfx94__) float fval; - uint32_t i32val = static_cast(x); + uint32_t i32val = static_cast(static_cast(x)); fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0); // asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); return fval; diff --git a/test/data_type/test_bf8_fnuz.cpp b/test/data_type/test_bf8_fnuz.cpp index 4ff796a614..f028c0da73 100644 --- a/test/data_type/test_bf8_fnuz.cpp +++ b/test/data_type/test_bf8_fnuz.cpp @@ -43,9 +43,8 @@ TEST(BF8FNUZ, ConvertFP32Nearest) type_convert(f8_convert_rne(std::numeric_limits::max())), abs_tol); // convert inf float to bf8_fnuz_t and check if it is qNan - ASSERT_NEAR(ck::NumericLimits::QuietNaN(), - f8_convert_rne(std::numeric_limits::infinity()), - abs_tol); + ASSERT_EQ(ck::NumericLimits::QuietNaN(), + f8_convert_rne(std::numeric_limits::infinity())); // positive norm float value to bf8 and back, check if holds float pos_float = 0.0000762939f; ASSERT_NEAR(pos_float, type_convert(f8_convert_rne(pos_float)), abs_tol); @@ -80,9 +79,8 @@ TEST(BF8FNUZ, ConvertFP32Stochastic) type_convert(f8_convert_sr(std::numeric_limits::max())), abs_tol); // convert inf float to bf8_fnuz_t and check if it is qNan - ASSERT_NEAR(ck::NumericLimits::QuietNaN(), - f8_convert_sr(std::numeric_limits::infinity()), - abs_tol); + ASSERT_EQ(ck::NumericLimits::QuietNaN(), + f8_convert_sr(std::numeric_limits::infinity())); // positive norm float value to bf8 and back, check if holds float pos_float = 0.0000762939f; ASSERT_NEAR(pos_float, type_convert(f8_convert_sr(pos_float)), abs_tol); @@ -118,9 +116,8 @@ TEST(BF8FNUZ, ConvertFP16Nearest) type_convert(f8_convert_rne(ck::NumericLimits::Max())), abs_tol); // convert QuietNaN fp16 to bf8_fnuz_t and check if it is QuietNaN - ASSERT_NEAR(ck::NumericLimits::QuietNaN(), - f8_convert_rne(ck::NumericLimits::QuietNaN()), - abs_tol); + ASSERT_EQ(ck::NumericLimits::QuietNaN(), + f8_convert_rne(ck::NumericLimits::QuietNaN())); // positive norm fp16 value to bf8 and back, check if holds half_t pos_half = half_t{0.0000762939}; ASSERT_NEAR(pos_half, type_convert(f8_convert_rne(pos_half)), abs_tol); @@ -155,9 +152,8 @@ TEST(BF8FNUZ, ConvertFP16Stochastic) type_convert(f8_convert_sr(ck::NumericLimits::Max())), abs_tol); // convert QuietNaN fp16 to bf8_fnuz_t and check if it is QuietNaN - ASSERT_NEAR(ck::NumericLimits::QuietNaN(), - f8_convert_sr(ck::NumericLimits::QuietNaN()), - abs_tol); + ASSERT_EQ(ck::NumericLimits::QuietNaN(), + f8_convert_sr(ck::NumericLimits::QuietNaN())); // positive norm fp16 value to bf8 and back, check if holds half_t pos_half = half_t{0.0000762939}; ASSERT_NEAR(pos_half, type_convert(f8_convert_sr(pos_half)), abs_tol); diff --git a/test/data_type/test_fp8_fnuz.cpp b/test/data_type/test_fp8_fnuz.cpp index c2ec6dad94..0cf775f947 100644 --- a/test/data_type/test_fp8_fnuz.cpp +++ b/test/data_type/test_fp8_fnuz.cpp @@ -48,9 +48,8 @@ TEST(FP8FNUZ, ConvertFP32Nearest) type_convert(f8_convert_rne(std::numeric_limits::max())), abs_tol); // convert inf float to f8_fnuz_t and check if it is qNan - ASSERT_NEAR(ck::NumericLimits::QuietNaN(), - f8_convert_rne(std::numeric_limits::infinity()), - abs_tol); + ASSERT_EQ(ck::NumericLimits::QuietNaN(), + f8_convert_rne(std::numeric_limits::infinity())); // positive norm float value to fp8 and back, check if holds float pos_float = 0.017578125f; ASSERT_NEAR(pos_float, type_convert(f8_convert_rne(pos_float)), abs_tol); @@ -85,9 +84,8 @@ TEST(FP8FNUZ, ConvertFP32Stochastic) type_convert(f8_convert_sr(std::numeric_limits::max())), abs_tol); // convert inf float to f8_fnuz_t and check if it is qNan - ASSERT_NEAR(ck::NumericLimits::QuietNaN(), - f8_convert_sr(std::numeric_limits::infinity()), - abs_tol); + ASSERT_EQ(ck::NumericLimits::QuietNaN(), + f8_convert_sr(std::numeric_limits::infinity())); // positive norm float value to fp8 and back, check if holds float pos_float = 0.017578125f; ASSERT_NEAR(pos_float, type_convert(f8_convert_sr(pos_float)), abs_tol); @@ -122,9 +120,8 @@ TEST(FP8FNUZ, ConvertFP16Nearest) type_convert(f8_convert_rne(ck::NumericLimits::Max())), abs_tol); // convert QuietNaN fp16 to f8_fnuz_t and check if it is QuietNaN - ASSERT_NEAR(ck::NumericLimits::QuietNaN(), - f8_convert_rne(ck::NumericLimits::QuietNaN()), - abs_tol); + ASSERT_EQ(ck::NumericLimits::QuietNaN(), + f8_convert_rne(ck::NumericLimits::QuietNaN())); // positive norm fp16 value to fp8 and back, check if holds half_t pos_half = half_t{0.017578125}; ASSERT_NEAR(pos_half, type_convert(f8_convert_rne(pos_half)), abs_tol); @@ -159,9 +156,8 @@ TEST(FP8FNUZ, ConvertFP16Stochastic) type_convert(f8_convert_sr(ck::NumericLimits::Max())), abs_tol); // convert QuietNaN fp16 to f8_fnuz_t and check if it is QuietNaN - ASSERT_NEAR(ck::NumericLimits::QuietNaN(), - f8_convert_sr(ck::NumericLimits::QuietNaN()), - abs_tol); + ASSERT_EQ(ck::NumericLimits::QuietNaN(), + f8_convert_sr(ck::NumericLimits::QuietNaN())); // positive norm fp16 value to fp8 and back, check if holds half_t pos_half = half_t{0.017578125}; ASSERT_NEAR(pos_half, type_convert(f8_convert_sr(pos_half)), abs_tol); From dd249f1cd6c516f7a1d45663f7f26eb4a4c086ca Mon Sep 17 00:00:00 2001 From: ltqin Date: Fri, 19 Sep 2025 14:26:43 +0800 Subject: [PATCH 23/28] Add input fp8 and output bf16 attention (#2726) * change host using fp16 to check * fp8 to fp8 compare * rewrite input parameters * add not squant * remove some output code * for scale = 1 * format * saturates only for fp8 * add fp8bf16 data type * add fp8bf16 data type * fix test fp8 code * add run_fp8bf16_tests * change fmha fwd example parameter(adding fp8bf16) * Support fp8bf16 for Aiter * Support aiter fp8bf16 in c++ * fix comment about fp8 in readme.md * add fp8fp32 * add fp8fp32 test * remove range_q etc. * format * fix test parameters about squant and fmha example input fp8bf16 fp8fp32 data type * add fp8bf16 to data_type function * change colmajor to rowmajor in test_ck_tile_fmha_fwd_fp8 * format * reset atol for fp8 * fix bug for atol --------- Co-authored-by: rocking Co-authored-by: asleepzzz --- example/ck_tile/01_fmha/README.md | 2 +- .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 3 +- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 48 +++-- .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 1 - .../codegen/ops/fmha_pagedkv_prefill.py | 10 +- example/ck_tile/01_fmha/example_fmha_fwd.cpp | 30 ++-- example/ck_tile/01_fmha/fmha_fwd.hpp | 36 ++++ example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 170 +++++++++++------- .../ck_tile/01_fmha/script/smoke_test_fwd.sh | 29 ++- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 52 +++--- test/ck_tile/fmha/test_fmha_fwd.inc | 21 +-- test/ck_tile/fmha/test_fmha_fwd_fp8.cpp | 13 +- 12 files changed, 262 insertions(+), 153 deletions(-) diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index cb6cd44f64..7f55d7412f 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -131,4 +131,4 @@ TBD ## FP8 experimental support As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_fmha_fwd`, on a gfx942 machine and ROCm 6.0+. -Currently we only support `-vlayout=c`( `hdim*seqlen` for V matrix) and `-squant=1`(static quantization) with `hdim=128` for fp8 now. Full feature support will come later. +Currently we only support `-vlayout=r`( `seqlen*hdim` for V matrix) for fp8 and fp8bf16 now. Full feature support will come later. diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 42a9d5148a..802c9e51d7 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -7,7 +7,8 @@ FWD_DTYPE_MAP = { "bf16" : "FmhaFwdBf16", "fp8" : "FmhaFwdFp8", "fp8fp16": "FmhaFwdFp8Fp16", - "fp8bf16": "FmhaFwdFp8Bf16" + "fp8bf16": "FmhaFwdFp8Bf16", + "fp8fp32": "FmhaFwdFp8Fp32" } BWD_DTYPE_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index d9452206e7..cfb96b7d53 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -163,7 +163,7 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); }}; - + const bool has_load_tr = ck_tile::is_load_tr_supported(); {F_dispatch} @@ -248,11 +248,11 @@ class FmhaFwdApiTrait: if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.seqlen_q % {self.bm0} == 0' else: assert False - + @property def seqtune(self) -> str: if self.bm0 == 128: return 'true/*fall back to largest tile*/' # group mode only generate spad/skpad == true - else: + else: return f'a.seqlen_q <= {self.bm0}' @property @@ -351,7 +351,7 @@ class FmhaFwdPipeline: if self.F_squant == 't' : n += '_squant' else: n += '_nsquant' - + if self.F_trload == 't' : n += '_trload' else: n += '_ntrload' @@ -378,7 +378,7 @@ class FmhaFwdApiPool: "t": "has_load_tr", "f": "true" } - + per_tr_load =str() for tr_load in ["t", "f"]: per_dtypes=str() @@ -550,12 +550,16 @@ class KernelComponentFactory: (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], } - elif dtype == 'fp8' or dtype == 'bf8': + elif dtype == 'fp8' or dtype == 'fp8bf16': return { (64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], } + elif dtype == 'fp8fp32': + return { + (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + } else: return None @@ -567,9 +571,9 @@ class KernelComponentFactory: # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr pipeline, let 't' padding to appear later!! # TODO: how to design this more generic? - squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: + squant = 'f' for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): if hdim == 256 and hdim_v == 256: pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) @@ -589,11 +593,12 @@ class KernelComponentFactory: pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't')) if receipt == 1 and bias != "bias": pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim - elif dtype in ['fp8', 'bf8']: + elif dtype in ['fp8', 'fp8bf16', 'fp8fp32']: # no need lse/dropout kernels - for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) - elif dtype in ['fp8fp16', 'fp8bf16']: + for logits, squant, mask, bias in itertools.product(["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) + elif dtype in ['fp8fp16', 'bf8']: # TODO None else: @@ -674,25 +679,34 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl continue # Aiter(mha_fwd) integration elif receipt == 100: - cond = dtype in ['fp16', 'bf16'] + cond = dtype in ['fp16', 'bf16', 'fp8bf16'] cond &= mode == 'batch' cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + if dtype == 'fp8bf16': + cond &= hdim == 128 if not cond: continue # Aiter(mha_varlen_fwd) integration elif receipt == 200: - cond = dtype in ['fp16', 'bf16'] + cond = dtype in ['fp16', 'bf16', 'fp8bf16'] cond &= mode == 'group' cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + if dtype == 'fp8bf16': + cond &= hdim == 128 if not cond: continue # aiter::mha_fwd C++ api integration elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] + cond = dtype in ['fp16', 'bf16', 'fp8bf16'] cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + if dtype == 'fp8bf16': + cond &= hdim == 128 + if not cond: + continue + elif receipt == 888: + cond = dtype in ['fp8', 'fp8bf16', 'fp8fp32'] + cond &= pipeline.F_vlayout == 'row' + cond &= hdim == 128 if not cond: continue diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 3b48b3d005..cee1505486 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -645,7 +645,6 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: return { '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), } else: return None diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index 7b93e9654c..df6b422981 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -465,14 +465,14 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): - pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) + for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t"], ["f"]): pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) elif dtype in ['fp8', 'bf8']: - # TODO - None + # no need lse/dropout kernels + for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f')) + pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f')) elif dtype in ['fp8fp16', 'fp8bf16']: # TODO None diff --git a/example/ck_tile/01_fmha/example_fmha_fwd.cpp b/example/ck_tile/01_fmha/example_fmha_fwd.cpp index c3bbb7a558..91cb9f55be 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd.cpp @@ -44,21 +44,15 @@ auto create_args(int argc, char* argv[]) .insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim).\n" - "note when squant=1, this value will be modified by range_q/k") + "note when squant=1, this value will be modified") .insert("logits_soft_cap", "0", "attention logits soft capping value.") - .insert("range_q", "16", "per-tensor quantization range of q. used if squant=1.") - .insert("range_k", "16", "per-tensor quantization range of k. used if squant=1.") - .insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.") - .insert("range_p", "1", "per-tensor quantization range of p [e^(s-m)]. used if squant=1.") - .insert("range_o", "16", "per-tensor quantization range of o (p*v). used if squant=1.") .insert("squant", "auto", "if using static quantization fusion or not. auto: fp8 will default use squant, " "other will not\n" "0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to " "P and O.\n" - "calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, " - "range_p, range_o") + "calculate scale_s, scale_p, scale_o auto") .insert("iperm", "1", "permute input\n" @@ -89,7 +83,7 @@ auto create_args(int argc, char* argv[]) "uf", "init method:\n ui or 0 - uniform random int\n ni - normalized random int" "\n uf or 1 - uniform random float\n nf - normalized random float" - "\n tf or 2 - trig float\n uf:q or ufq or 3 - fp8 quantization") + "\n tf or 2 - trig float\n") .insert("seed", "11939", "random seed used for initializing input tensors. 0 for " @@ -148,11 +142,6 @@ auto run(const ck_tile::ArgParser& arg_parser) uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); bool drop_prefs = arg_parser.get_bool("drop_prefs"); std::string mask_str = arg_parser.get_str("mask"); - float range_q = arg_parser.get_float("range_q"); - float range_k = arg_parser.get_float("range_k"); - float range_v = arg_parser.get_float("range_v"); - float range_p = arg_parser.get_float("range_p"); - float range_o = arg_parser.get_float("range_o"); bool is_rotary_interleaved = arg_parser.get_bool("rotary_interleaved"); ck_tile::index_t num_splits = arg_parser.get_int("num_splits"); std::string init_method = arg_parser.get_str("init"); @@ -201,11 +190,6 @@ auto run(const ck_tile::ArgParser& arg_parser) drop_offset, drop_prefs, mask_str, - range_q, - range_k, - range_v, - range_p, - range_o, squant, is_rotary_interleaved, num_splits, @@ -237,6 +221,14 @@ int main(int argc, char* argv[]) { return run(arg_parser) == fwd_result::success ? 0 : -2; } + else if(data_type == "fp8bf16") + { + return run(arg_parser) == fwd_result::success ? 0 : -2; + } + else if(data_type == "fp8fp32") + { + return run(arg_parser) == fwd_result::success ? 0 : -2; + } std::cerr << "Unsupported precision: " << data_type << std::endl; return -1; } diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index df1e9e5699..c41e48e6aa 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -41,6 +41,10 @@ struct FmhaFwdFp8Bf16 { }; +struct FmhaFwdFp8Fp32 +{ +}; + template struct FmhaFwdTypeConfig; @@ -108,6 +112,38 @@ struct FmhaFwdTypeConfig using ODataType = ck_tile::bf8_t; }; +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::fp8_t; + using KDataType = ck_tile::fp8_t; + using VDataType = ck_tile::fp8_t; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf16_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::fp8_t; + using KDataType = ck_tile::fp8_t; + using VDataType = ck_tile::fp8_t; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = float; +}; + struct FmhaMasks { using NoMask = ck_tile::GenericAttentionMask; diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 397245ab32..43f484fe14 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -50,20 +50,30 @@ auto get_elimit(std::string /*init_method*/) } template <> -auto get_elimit(std::string init_method) +auto get_elimit(std::string /*init_method*/) { - if(init_method == "ui" || init_method == "ni") - { - unsigned max_rounding_point_distance = 0; - double atol = 2e-3; - return ck_tile::make_tuple(max_rounding_point_distance, atol); - } - else - { - unsigned max_rounding_point_distance = 1; - double atol = 0.0625; - return ck_tile::make_tuple(max_rounding_point_distance, atol); - } + using TypeConfig = FmhaFwdTypeConfig; + using ODataType = typename TypeConfig::ODataType; + float o_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + double rtol = 0; + double atol = 16 * (o_dtype_max > 240 ? 2 : 1); + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1.8e-1; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1.8e-1; + return ck_tile::make_tuple(rtol, atol); } int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int max_splits) @@ -157,11 +167,6 @@ fwd_result fmha_fwd_run(mode_enum mode, uint64_t drop_offset, bool drop_prefs, std::string mask_str, - float range_q, - float range_k, - float range_v, - float range_p, - float range_o, bool squant, bool is_rotary_interleaved, ck_tile::index_t num_splits, @@ -180,6 +185,10 @@ fwd_result fmha_fwd_run(mode_enum mode, return "fp8"; else if constexpr(std::is_same_v) return "bf8"; + else if constexpr(std::is_same_v) + return "fp8bf16"; + else if constexpr(std::is_same_v) + return "fp8fp32"; else static_assert(false); }(); @@ -367,22 +376,6 @@ fwd_result fmha_fwd_run(mode_enum mode, using OaccDataType = typename TypeConfig::OaccDataType; using ODataType = typename TypeConfig::ODataType; - float q_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - float k_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - float v_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - float p_dtype_max = v_dtype_max; // assume p and v is the same type - float o_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - - float scale_p = 1.f; - float scale_o = 1.f; - - if(squant) - { - scale_s = scale_s * (range_q / q_dtype_max) * (range_k / k_dtype_max); - scale_p = p_dtype_max / range_p; - scale_o = (o_dtype_max / range_o) * (range_p / p_dtype_max) * (range_v / v_dtype_max); - } - // accumulation numbers for performance evaluation std::size_t flop = 0, num_byte = 0; auto max_seqlen_q = @@ -528,7 +521,7 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::HostTensor cache_batch_idx_host(use_cache_batch_idx ? std::array{batch} : std::array{1}); - + float max_o = 5.0; if(init_method == "ui" || init_method == "0") { ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(q_host); @@ -576,32 +569,6 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::FillTrigValue{}(vnew_host); ck_tile::FillTrigValue{}(bias_host); } - else if(init_method == "ufq" || init_method == "uf:q" || init_method == "3") - { - // suitable for fp8 quantization - if(!squant) - { - std::cerr << "init method " << init_method << " can not be used without quantization" - << std::endl; - return fwd_result::invalid_args; - } - ck_tile::FillUniformDistribution{0.f, q_dtype_max, next_seed()}(q_host); - ck_tile::FillUniformDistribution{0.f, k_dtype_max, next_seed()}(k_host); - ck_tile::FillUniformDistribution{0.f, k_dtype_max, next_seed()}(knew_host); - ck_tile::FillUniformDistribution{0.f, v_dtype_max, next_seed()}(v_host); - ck_tile::FillUniformDistribution{0.f, v_dtype_max, next_seed()}(vnew_host); - - // bias_fp8 = qscale_bias * bias_fp32 - float qscale_bias = (q_dtype_max / range_q) * (k_dtype_max / range_k); - // Assume bias is in [0.f, 1.f] in original fp32 - ck_tile::FillUniformDistribution{0.f, qscale_bias, next_seed()}(bias_host); - } - else - { - std::cerr << "Unknown value for init argument: " << init_method << std::endl; - return fwd_result::invalid_args; - } - if(bias.type == bias_enum::alibi) { auto slopes = ck_tile::get_alibi_slopes(nhead); @@ -625,8 +592,8 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes()); @@ -650,10 +617,79 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes()); + float scale_p = 1.f; + float scale_o = 1.f; + if(squant) + { + float q_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float k_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float v_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float p_dtype_max = v_dtype_max; // assume p and v is the same type + // Q tensor + { + float max_value = ck_tile::type_convert(ck_tile::numeric::min()); + q_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + if(val > max_value) + max_value = val; + }); + + float scale = q_dtype_max / max_value; + + q_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + self(idx) = ck_tile::type_convert(val * scale); + }); + scale_s = scale_s / scale; + } + + // K tensor + { + float max_value = ck_tile::type_convert(ck_tile::numeric::min()); + k_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + if(val > max_value) + max_value = val; + }); + float scale = k_dtype_max / max_value; + k_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + self(idx) = ck_tile::type_convert(val * scale); + }); + scale_s = scale_s / scale; + } + + // V tensor + { + float max_value = ck_tile::type_convert(ck_tile::numeric::min()); + v_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + if(val > max_value) + max_value = val; + }); + + float scale = k_dtype_max / max_value; + v_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + self(idx) = ck_tile::type_convert(val * scale); + }); + + scale_o = (1.0 / p_dtype_max) / scale; + } + + scale_p = p_dtype_max; + + if constexpr(std::is_same_v) + { + float o_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + scale_o = scale_o * o_dtype_max / max_o; + } + } + q_buf.ToDevice(q_host.data()); k_buf.ToDevice(k_host.data()); - knew_buf.ToDevice(knew_host.data()); v_buf.ToDevice(v_host.data()); + knew_buf.ToDevice(knew_host.data()); vnew_buf.ToDevice(vnew_host.data()); bias_buf.ToDevice(bias_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); @@ -1103,7 +1139,9 @@ fwd_result fmha_fwd_run(mode_enum mode, lse_buf.FromDevice(lse_host.data()); randval_buf.FromDevice(randval_host.data()); - constexpr bool supports_squant = std::is_same_v; + constexpr bool supports_squant = std::is_same_v || + std::is_same_v || + std::is_same_v; auto p_compute_element_func = [&]() { if constexpr(supports_squant) @@ -1113,9 +1151,11 @@ fwd_result fmha_fwd_run(mode_enum mode, }(); auto oacc_element_func = [&]() { - if constexpr(supports_squant) + if constexpr(std::is_same_v && supports_squant) return ck_tile::composes(ck_tile::saturates{}, ck_tile::scales{scale_o}); + else if constexpr(supports_squant) + return ck_tile::scales{scale_o}; else return ck_tile::identity{}; }(); diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index c087a1fb3e..afd0c728c6 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -94,7 +94,30 @@ run_fp8_tests() { for b in 1 2 ; do for hdim in 64 128 256 ; do - run_exe -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS + $EXE -prec=fp8 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS + + done ; done ; done ; done +} + +run_fp8bf16_tests() { + for perm in 0 1 ; do + for bias in "n" "e" "a" ; do + for b in 1 2 ; do + for hdim in 64 128 256 ; do + + $EXE -prec=fp8bf16 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS + + done ; done ; done ; done +} + +run_fp8fp32_tests() { + for perm in 0 1 ; do + for bias in "n" "e" "a" ; do + for b in 1 2 ; do + for hdim in 64 128 256 ; do + + $EXE -prec=fp8fp32 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS + done ; done ; done ; done } @@ -117,7 +140,9 @@ run_fp16_appendkv_tests() { set -x run_fp16_bf16_tests -# run_fp8_tests +run_fp8_tests +run_fp8bf16_tests +run_fp8fp32_tests if [ $TEST_APPENDKV -eq 1 ] ; then run_fp16_appendkv_tests diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 6405ca50df..58fdad149a 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1446,29 +1446,35 @@ struct FmhaFwdKernel auto o_acc_tile = [&]() { if constexpr(kDoFp8StaticQuant) { - return FmhaPipeline{}( - q_dram_window, - identity{}, // q_element_func - k_dram_window, - identity{}, // k_element_func - v_dram_window, - identity{}, // v_element_func - bias_dram_window, - identity{}, // bias_element_func - randval_dram_window, - lse_dram_window, - identity{}, // lse_element_func - identity{}, // s_acc_element_func - scales{kargs.scale_p}, // p_compute_element_func - composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func - mask, - position_encoding, - kargs.scale_s, - variant, - variant_params, - block_indices, - smem_ptr, - dropout); + auto o_acc_element_func = [&]() { + if constexpr(std::is_same_v) + return ck_tile::composes(ck_tile::saturates{}, + ck_tile::scales{kargs.scale_o}); + else + return ck_tile::scales{kargs.scale_o}; + }(); + return FmhaPipeline{}(q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + bias_dram_window, + identity{}, // bias_element_func + randval_dram_window, + lse_dram_window, + identity{}, // lse_element_func + identity{}, // s_acc_element_func + scales{kargs.scale_p}, // p_compute_element_func + o_acc_element_func, // o_acc_element_func + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout); } else { diff --git a/test/ck_tile/fmha/test_fmha_fwd.inc b/test/ck_tile/fmha/test_fmha_fwd.inc index f02ef1e55e..08abd3358d 100644 --- a/test/ck_tile/fmha/test_fmha_fwd.inc +++ b/test/ck_tile/fmha/test_fmha_fwd.inc @@ -32,9 +32,6 @@ const ck_tile::stream_config stream_config{ 1, // rotating_count_ }; -// range_q, range_k, range_v, range_p, range_o, squant -#define QUANT_ARGS 1, 1, 1, 1, 1, squant - #define COMMON_ARGS \ init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, \ stream_config @@ -117,7 +114,7 @@ TEST_P(AllLong, Test) 1024, // drop_offset false, // drop_prefs mask_str, // mask_str - QUANT_ARGS, + squant, true, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -179,7 +176,7 @@ TEST_P(HDimPadding, Test) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - QUANT_ARGS, + squant, true, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -236,7 +233,7 @@ TEST_P(ElementwiseBias, Test) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - QUANT_ARGS, + squant, true, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -292,7 +289,7 @@ TEST_P(Alibi, Test) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - QUANT_ARGS, + squant, true, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -350,7 +347,7 @@ TEST_P(Dropout, Test) drop_offset, // drop_offset drop_prefs, // drop_prefs mask_str, // mask_str - QUANT_ARGS, + squant, true, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -410,7 +407,7 @@ TEST_P(PagedKV, Test) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - QUANT_ARGS, + squant, true, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -476,7 +473,7 @@ TEST_P(SplitKV, Test) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - QUANT_ARGS, + squant, true, // is_rotary_interleaved num_splits, // num_splits COMMON_ARGS); @@ -548,7 +545,7 @@ TEST_P(AppendKV, Test) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - QUANT_ARGS, + squant, false, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -618,7 +615,7 @@ TEST_P(AppendKVRoPE, Test) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - QUANT_ARGS, + squant, is_rotary_interleaved, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); diff --git a/test/ck_tile/fmha/test_fmha_fwd_fp8.cpp b/test/ck_tile/fmha/test_fmha_fwd_fp8.cpp index 46ed8f4125..b99c304d1f 100644 --- a/test/ck_tile/fmha/test_fmha_fwd_fp8.cpp +++ b/test/ck_tile/fmha/test_fmha_fwd_fp8.cpp @@ -17,22 +17,21 @@ using DataTypeConfig = FmhaFwdFp8; // instances are added), however the corresponding tests are not disabled (they will be skipped) // in case such instances will be added in the future. -const auto HDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}); +const auto HDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1}); -const auto SplitKVHDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}); +const auto SplitKVHDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1}); -const auto AppendKVHDimValues = - Values(std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}); +const auto AppendKVHDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1}); // There are no fp8 instances with seqlen padding (mode_enum::group requires it) const auto ModeValues = Values(mode_enum::batch); const auto IsVRowmajorValues = Values(false); -const bool squant = true; -const std::string init_method = "ufq"; +const auto squant = true; +const std::string init_method = "uf"; const bool def_lse = false; -const bool def_is_v_rowmajor = false; +const bool def_is_v_rowmajor = true; int adjust_seqlen(int seqlen) { From 2aec38f9ec67bfbdccbdb3a5c25913e5a9ba6136 Mon Sep 17 00:00:00 2001 From: Anton Gorenko Date: Fri, 19 Sep 2025 12:34:45 +0600 Subject: [PATCH 24/28] [CK_TILE] FMHA Fix synchronization issues in BWD pipelines (#2876) * Run ctest with --output-on-failure * Fix synchronization issues in bwd pipelines The bwd kernel reuses the same area of LDS for ds (SGrad), bias and dbias (BiasGrad). This means that there must be block_sync_lds between loading one tensor and storing another to the same area. Heavy instructions like MFMA/WMMA and global loads are executed between reuses of the same memory so in MOST cases loading is finished by all warps before storing is started. However, sometimes warps progress at different speeds. Running the tests multiple times and, preferably, with multiple processes on the same GPU helps to trigger this issue: bin/test_ck_tile_fmha_bwd_bf16 --gtest_repeat=-1 --gtest_shuffle --gtest_throw_on_failure --- ...fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp | 16 +++++++++++----- ...ha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp | 6 ++++++ ...a_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp | 6 ++++++ .../block_fmha_bwd_pipeline_default_policy.hpp | 2 +- script/launch_tests.sh | 4 +--- 5 files changed, 25 insertions(+), 9 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index b883aad155..c402eaeac4 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -559,6 +559,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP auto shuffled_bias_tile = make_static_distributed_tensor( Policy::template MakeShuffledBiasTileDistribution()); shuffle_tile(shuffled_bias_tile, bias_tile); + // SGrad and Bias use the same address in LDS, finish loading ds on the previous + // iteration to reuse LDS. + block_sync_lds(); store_tile(bias_lds_write_window, shuffled_bias_tile); block_sync_lds(); auto bias_s_tile = load_tile(bias_s_lds_read_window); @@ -814,6 +817,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP auto shuffled_bias_tile = make_static_distributed_tensor( Policy::template MakeShuffledBiasTileDistribution()); shuffle_tile(shuffled_bias_tile, bias_tile); + // SGrad and Bias use the same address in LDS, finish loading ds in the hot loop to + // reuse LDS. + block_sync_lds(); store_tile(bias_lds_write_window, shuffled_bias_tile); block_sync_lds(); auto bias_s_tile = load_tile(bias_s_lds_read_window); @@ -956,6 +962,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP return cast_tile(ds); } }(); + // Finish loading bias_s to reuse LDS. + block_sync_lds(); store_tile(bias_lds_write_window, dbias); block_sync_lds(); auto shuffled_dbias_tile = load_tile(dbias_lds_read_window); @@ -975,11 +983,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); - if constexpr(kHasBiasGrad) - { - // SGrad and BiasGrad use the same address in LDS. - block_sync_lds(); - } + // SGrad and Bias/BiasGrad use the same address in LDS, finish loading bias/dbias or, when + // bias is not used, loading ds in the hot loop to reuse LDS. + block_sync_lds(); store_tile(ds_lds_window, ds_gemm); block_sync_lds(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp index 81950bd30a..41cb4fc306 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -698,6 +698,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer(); gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); + if constexpr(kHasBiasGrad) + { + // SGrad and BiasGrad use the same address in LDS, finish loading dbias to reuse + // LDS. + block_sync_lds(); + } store_tile(ds_lds_window, ds_gemm); } s_waitcnt(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp index 16d9f695df..8c8d2af486 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp @@ -656,6 +656,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer(); dk_acc = gemm_3(dst_reg_tensor, qt_reg_tensor); + if constexpr(kHasBiasGrad) + { + // SGrad and BiasGrad use the same address in LDS, finish loading dbias to reuse + // LDS. + block_sync_lds(); + } store_tile(ds_lds_window, ds_gemm); } __builtin_amdgcn_s_waitcnt(3952); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index 68ead7c765..ad9e2959f5 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -1941,7 +1941,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t smem_size_stage0_0 = smem_size_k + smem_size_kt; constexpr index_t smem_size_stage0_1 = smem_size_v; - constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + +smem_size_dot + + constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + smem_size_dot + smem_size_do + smem_size_lse + smem_size_d + max(smem_size_bias, smem_size_ds); diff --git a/script/launch_tests.sh b/script/launch_tests.sh index 5e71e25478..17a99e62a3 100755 --- a/script/launch_tests.sh +++ b/script/launch_tests.sh @@ -49,7 +49,7 @@ with open('$TEST_FILE', 'r') as f: if tests: # Extract just the filename after the last '/' clean_tests = [os.path.basename(test) for test in tests] - print('ctest -R \"' + '|'.join(clean_tests) + '\"') + print('ctest --output-on-failure -R \"' + '|'.join(clean_tests) + '\"') else: print('# No tests to run') ") @@ -57,5 +57,3 @@ with open('$TEST_FILE', 'r') as f: echo "$command" eval "$command" - - From 86dd59cd01e41a4190bf2405a0fb0e89d9498b4c Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Fri, 19 Sep 2025 17:36:49 +0800 Subject: [PATCH 25/28] =?UTF-8?q?[CK=5FTILE]=20Add=20sequence=20padding=20?= =?UTF-8?q?and=20variable=20length=20support=20in=20fmha=20(a=E2=80=A6=20(?= =?UTF-8?q?#2851)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [CK_TILE] Add sequence padding and variable length support in fmha (and v3) - Group Mode Padding: Introduces the `-s_qpad` argument to support physically padded layouts. Kernels now use padded start pointers (`seqstart_padded_*_ptr`) for memory addressing. - Batch Mode Variable Length: Adds `-q_eff_lens` and `-kv_eff_lens` arguments for efficient processing of variable-length sequences by passing cumulative effective lengths (`cu_seqlen_*_ptr`) to the kernel. - FMHA examples: Support padding and variable length both in group and batch mode. Dispatcher is updated as well (dispatch to kPadSeqLenK enabled pipeline). - New padding test cases: Add padding test cases to `smoke_test_fwd.sh`, and add benchmarks to `benchmark_fwd.sh` and `benchmark_fwd_v3.sh` as well. These test cases and benchmarks that specifically validate/benchmark the new padding and variable-length functionalities in both group and batch modes. * [CK_TILE] Fix build error in fmha unit tests --------- Co-authored-by: Po Yen Chen Co-authored-by: Yi DING --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 6 +- example/ck_tile/01_fmha/example_fmha_fwd.cpp | 20 +- .../ck_tile/01_fmha/example_fmha_fwd_v3.cpp | 148 ++++++++- example/ck_tile/01_fmha/fmha_fwd.hpp | 17 +- example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 127 +++++++- example/ck_tile/01_fmha/fmha_fwd_v3.hpp | 5 + example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp | 4 +- .../ck_tile/01_fmha/script/benchmark_fwd.sh | 33 ++ .../01_fmha/script/benchmark_fwd_v3.sh | 17 ++ .../ck_tile/01_fmha/script/smoke_test_fwd.sh | 109 +++++++ .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 285 ++++++++++++++++-- .../ops/fmha/kernel/fmha_fwd_v3_kernel.hpp | 180 ++++++++++- test/ck_tile/fmha/test_fmha_fwd.inc | 141 +++++++++ 13 files changed, 1032 insertions(+), 60 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index cfb96b7d53..da0c9ca931 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -259,11 +259,11 @@ class FmhaFwdApiTrait: def skcheck(self) -> str: if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true if self.pipeline_tag == 'qr_async': - if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' - else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' + if self.skpad == 't' : return f'(a.cu_seqlen_kv_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)' + else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)' elif self.pipeline_tag in ['qr', 'qs']: if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_k % {self.bn0} == 0' + else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)' elif self.pipeline_tag == 'qr_async_trload': if self.skpad == 't' : return 'true' else: return 'true' diff --git a/example/ck_tile/01_fmha/example_fmha_fwd.cpp b/example/ck_tile/01_fmha/example_fmha_fwd.cpp index 91cb9f55be..79fda6d564 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd.cpp @@ -33,6 +33,10 @@ auto create_args(int argc, char* argv[]) "0", "seqlen_k for new key/value, 0 means not to use this at all; " "-1 to choose s_knew in [1, s] randomly.") + .insert("s_qpad", + "-1", + "seqlen_q stride between 2 batches (group-mode optional).\n" + "Provide positive strides per-batch to simulate physical padding on Q.") .insert("s_kpad", "-1", "seqlen_k stride between 2 batches, currently used in group-mode only\n" @@ -107,7 +111,15 @@ auto create_args(int argc, char* argv[]) .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel") .insert("json", "0", "0: No Json, 1: Dump Results in Json format") - .insert("jsonfile", "fmha_fwd.json", "json file name to dump results"); + .insert("jsonfile", "fmha_fwd.json", "json file name to dump results") + .insert("q_eff_lens", + "", + "Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n" + "Comma-separated list of length 'b'. If empty, no override.") + .insert("kv_eff_lens", + "", + "Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n" + "Comma-separated list of length 'b'. If empty, no override."); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -127,6 +139,9 @@ auto run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); ck_tile::index_t seqlen_knew = arg_parser.get_int("s_knew"); auto seqlen_kpads = arg_parser.get_int_vec("s_kpad"); + auto seqlen_qpads = arg_parser.get_int_vec("s_qpad"); + auto q_eff_lens_per_batch = arg_parser.get_int_vec("q_eff_lens"); + auto kv_eff_lens_per_batch = arg_parser.get_int_vec("kv_eff_lens"); ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim"); bool i_perm = arg_parser.get_bool("iperm"); bool o_perm = arg_parser.get_bool("operm"); @@ -174,7 +189,10 @@ auto run(const ck_tile::ArgParser& arg_parser) hdim_q, hdim_v, seqlen_knew, + seqlen_qpads, seqlen_kpads, + q_eff_lens_per_batch, + kv_eff_lens_per_batch, rotary_dim, i_perm, o_perm, diff --git a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp index 569c98a458..7ddb65a2db 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp @@ -52,7 +52,16 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair get_query_shape() const @@ -172,6 +183,8 @@ struct Problem mask_info mask; TensorLayout input_layout; TensorLayout output_layout; + std::vector q_eff_lens; + std::vector kv_eff_lens; }; struct RunConfig @@ -326,8 +339,10 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) q_buf.ToDevice(q.data()); k_buf.ToDevice(k.data()); v_buf.ToDevice(v.data()); + // Ensure output buffer is zero-initialized so padded regions compare cleanly + o_buf.SetZero(); - ck_tile::fmha_fwd_v3_args args; + ck_tile::fmha_fwd_v3_args args{}; args.data_type = problem.data_type; args.batch = problem.batch; @@ -380,6 +395,60 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) : problem.seqlen_q * problem.hdim; args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim; + // Optional cumulative seqlen overrides (exclude PAD) + const bool has_varlen_q = !problem.q_eff_lens.empty() && problem.q_eff_lens[0] != -1; + const bool has_varlen_k = !problem.kv_eff_lens.empty() && problem.kv_eff_lens[0] != -1; + + auto make_effective_vec = [&](const std::vector& opt_vec, ck_tile::index_t fallback) { + std::vector eff; + if(!opt_vec.empty() && opt_vec[0] != -1) + { + eff.assign(opt_vec.begin(), opt_vec.end()); + if(eff.size() < static_cast(problem.batch)) + { + eff.resize(problem.batch, eff.back()); + } + } + else + { + eff.assign(problem.batch, fallback); + } + return eff; + }; + + const auto eff_q_vec = make_effective_vec(problem.q_eff_lens, problem.seqlen_q); + const auto eff_kv_vec = make_effective_vec(problem.kv_eff_lens, problem.seqlen_k); + + // Calculate cumulative sums for kernel arguments if varlen is used + std::vector cuq_cum, cukv_cum; + auto calculate_cumulative = [&](const std::vector& per_batch_vec, + std::vector& cum_vec) { + cum_vec.resize(per_batch_vec.size() + 1); + cum_vec[0] = 0; + for(std::size_t i = 0; i < per_batch_vec.size(); ++i) + cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i]; + }; + + if(has_varlen_q) + { + calculate_cumulative(eff_q_vec, cuq_cum); + } + if(has_varlen_k) + { + calculate_cumulative(eff_kv_vec, cukv_cum); + } + + ck_tile::DeviceMem cuq_buf(!cuq_cum.empty() ? cuq_cum.size() * sizeof(ck_tile::index_t) : 0); + ck_tile::DeviceMem cukv_buf(!cukv_cum.empty() ? cukv_cum.size() * sizeof(ck_tile::index_t) : 0); + cuq_buf.ToDevice(!cuq_cum.empty() ? cuq_cum.data() : nullptr); + cukv_buf.ToDevice(!cukv_cum.empty() ? cukv_cum.data() : nullptr); + args.cu_seqlen_q_ptr = + !cuq_cum.empty() ? reinterpret_cast(cuq_buf.GetDeviceBuffer()) + : nullptr; + args.cu_seqlen_kv_ptr = + !cukv_cum.empty() ? reinterpret_cast(cukv_buf.GetDeviceBuffer()) + : nullptr; + ck_tile::stream_config stream_config{nullptr, true, /*log_level=*/0, @@ -442,15 +511,72 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) o_ref = o_ref.transpose({0, 2, 1, 3}); } - host::fmha_fwd(q, - k, - v, - problem.mask, - o_ref, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales{problem.softmax_scale}); + // If variable lengths are provided, compute per-batch references + // with the effective lengths; else compute a single full reference. + if(has_varlen_q || has_varlen_k) + { + // Variable-length aware verification: zero-fill padded region and only compute valid part. + o_ref.SetZero(); + + for(int b = 0; b < problem.batch; ++b) + { + const ck_tile::index_t seqlen_q_eff = eff_q_vec[b]; + const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b]; + + if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0) + continue; + + // Slice current batch from inputs (bshd) and build single-batch tensors + ck_tile::HostTensor q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + ck_tile::HostTensor k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); + ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); + ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + + // Copy effective region + q_b.ForEach([&](auto& self, auto idx) { + // idx: [0, s, h, d] + self(idx) = q(b, idx[1], idx[2], idx[3]); + }); + k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); }); + v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); }); + + // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors) + host::fmha_fwd(q_b, + k_b, + v_b, + problem.mask, + o_b, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales{problem.softmax_scale}); + + // Scatter into o_ref's bshd descriptor memory + for(int s = 0; s < seqlen_q_eff; ++s) + { + for(int h = 0; h < problem.nhead_q; ++h) + { + for(int d = 0; d < problem.hdim; ++d) + { + o_ref(b, s, h, d) = o_b(0, s, h, d); + } + } + } + } + } + else + { + // No varlen override: compute the full reference once + host::fmha_fwd(q, + k, + v, + problem.mask, + o_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales{problem.softmax_scale}); + } ck_tile::HostTensor o(problem.get_output_shape()); o_buf.FromDevice(o.data()); diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index c41e48e6aa..f5dd42a6bd 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -162,11 +162,20 @@ struct fmha_fwd_args void* lse_ptr; void* o_ptr; + // Optional cumulative sequence length arrays + // Batch mode: cu_seqlen_* override effective per-batch lengths (exclude PAD) + const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] + const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] + const void* seqstart_q_ptr; const void* seqstart_k_ptr; const void* seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr + // Group mode: seqstart_padded_* provide physical starts including PAD (optional) + const void* seqstart_padded_q_ptr = nullptr; // [batch+1] + const void* seqstart_padded_k_ptr = nullptr; // [batch+1] + ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; ck_tile::index_t batch; @@ -554,7 +563,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.min_seqlen_q, args.p_drop, args.s_randval, - args.drop_seed_offset); + args.drop_seed_offset, + args.seqstart_padded_q_ptr, + args.seqstart_padded_k_ptr); } else { // create batch mode kernel arguments @@ -600,7 +611,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.mask_type, args.p_drop, args.s_randval, - args.drop_seed_offset); + args.drop_seed_offset, + args.cu_seqlen_q_ptr, + args.cu_seqlen_kv_ptr); } }(); diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 43f484fe14..cb5827975e 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -151,7 +151,10 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t seqlen_knew, + std::vector seqlen_qpads, std::vector seqlen_kpads, + std::vector q_eff_lens_per_batch, + std::vector kv_eff_lens_per_batch, ck_tile::index_t rotary_dim, bool i_perm, bool o_perm, @@ -362,6 +365,44 @@ fwd_result fmha_fwd_run(mode_enum mode, const auto seqstart_k_host = to_seqstarts(seqlen_ks); const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads); + // Optional padded Q seqstarts (group-mode only) + std::vector seqstart_q_with_padding_host; + if(mode == mode_enum::group && !seqlen_qpads.empty() && seqlen_qpads[0] != -1) + { + if(seqlen_qpads.size() < static_cast(batch)) + { + seqlen_qpads.resize(batch, seqlen_qpads.back()); + } + if(seqlen_qpads.size() == static_cast(batch)) + { + seqstart_q_with_padding_host = to_seqstarts( + ck_tile::span(seqlen_qpads.data(), seqlen_qpads.size())); + } + } + + // Optional batch-mode cumulative seqlen overrides + std::vector cuq_cum, cukv_cum; + if(mode == mode_enum::batch) + { + auto calculate_cumulative = [&](std::vector& per_batch_vec, + std::vector& cum_vec) { + if(!per_batch_vec.empty() && per_batch_vec[0] != -1) + { + if(per_batch_vec.size() < static_cast(batch)) + { + per_batch_vec.resize(batch, per_batch_vec.back()); + } + cum_vec.resize(batch + 1); + cum_vec[0] = 0; + for(int i = 0; i < batch; ++i) + cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i]; + } + }; + + calculate_cumulative(q_eff_lens_per_batch, cuq_cum); + calculate_cumulative(kv_eff_lens_per_batch, cukv_cum); + } + using TypeConfig = FmhaFwdTypeConfig; using QDataType = typename TypeConfig::QDataType; @@ -445,8 +486,15 @@ fwd_result fmha_fwd_run(mode_enum mode, // host memory for storing all the tensor elements const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); - const ck_tile::index_t shape_seqlen_q = + // logical(unpadded) total seqlen_q for group; batch uses fixed seqlen + const ck_tile::index_t shape_seqlen_q_lse = (mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back()); + // physical(padded) total seqlen_q for group when s_qpad is provided; else use logical + const ck_tile::index_t shape_seqlen_q = + (mode == mode_enum::batch + ? seqlen_qs[0] + : (seqstart_q_with_padding_host.empty() ? seqstart_q_host.back() + : seqstart_q_with_padding_host.back())); const ck_tile::index_t shape_seqlen_k = (mode == mode_enum::batch ? seqlen_ks[0] : (seqlen_kpads[0] < 0 ? seqstart_k_host.back() @@ -504,7 +552,7 @@ fwd_result fmha_fwd_run(mode_enum mode, // batch mode of lse data layout is [batch, nhead, seqlen_q] // group mode of lse data layout is [nhead, total_seqlen_q] ck_tile::HostTensor lse_host( - lse ? std::array{shape_batch, nhead, shape_seqlen_q} + lse ? std::array{shape_batch, nhead, shape_seqlen_q_lse} : std::array{1, 1, 1} /* dummy shape for simplifying code */); ck_tile::HostTensor o_host( @@ -602,6 +650,16 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqstart_q_padded_buf(seqstart_q_with_padding_host.empty() + ? 0 + : seqstart_q_with_padding_host.size() * + sizeof(int32_t)); + ck_tile::DeviceMem seqstart_k_padded_buf( + seqlen_kpads[0] < 0 ? 0 : seqstart_k_with_padding_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0 + : cuq_cum.size() * sizeof(ck_tile::index_t)); + ck_tile::DeviceMem cu_seqlen_kv_buf( + cukv_cum.empty() ? 0 : cukv_cum.size() * sizeof(ck_tile::index_t)); ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) @@ -693,8 +751,14 @@ fwd_result fmha_fwd_run(mode_enum mode, vnew_buf.ToDevice(vnew_host.data()); bias_buf.ToDevice(bias_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); - seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data() - : seqstart_k_with_padding_host.data()); + // Keep logical starts in seqstart_k; pass padded K via separate pointer + seqstart_k.ToDevice(seqstart_k_host.data()); + seqstart_q_padded_buf.ToDevice( + seqstart_q_with_padding_host.empty() ? nullptr : seqstart_q_with_padding_host.data()); + seqstart_k_padded_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr + : seqstart_k_with_padding_host.data()); + cu_seqlen_q_buf.ToDevice(cuq_cum.empty() ? nullptr : cuq_cum.data()); + cu_seqlen_kv_buf.ToDevice(cukv_cum.empty() ? nullptr : cukv_cum.data()); seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0] ? seqlen_ks.data() : nullptr); @@ -830,8 +894,8 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; - const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); + const ck_tile::index_t nhead_stride_lse = shape_seqlen_q_lse; + const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q_lse); const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments @@ -846,8 +910,8 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew); const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); - const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q); + const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q_lse); + const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q_lse); const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); @@ -961,6 +1025,29 @@ fwd_result fmha_fwd_run(mode_enum mode, { args.drop_seed_offset = std::make_pair(drop_seed, drop_offset); } + + // Group-mode: optional physical padded starts for Q/K + if(mode == mode_enum::group) + { + args.seqstart_padded_q_ptr = (seqstart_q_with_padding_host.empty() + ? nullptr + : seqstart_q_padded_buf.GetDeviceBuffer()); + args.seqstart_padded_k_ptr = + (seqlen_kpads[0] < 0 ? nullptr : seqstart_k_padded_buf.GetDeviceBuffer()); + } + + // Batch-mode: optional cumulative effective seqlen overrides + if(mode == mode_enum::batch) + { + args.cu_seqlen_q_ptr = cuq_cum.empty() + ? nullptr + : reinterpret_cast( + cu_seqlen_q_buf.GetDeviceBuffer()); + args.cu_seqlen_kv_ptr = cukv_cum.empty() + ? nullptr + : reinterpret_cast( + cu_seqlen_kv_buf.GetDeviceBuffer()); + } } else if constexpr(std::is_same_v>) { @@ -1167,15 +1254,29 @@ fwd_result fmha_fwd_run(mode_enum mode, for(ck_tile::index_t wb = 0; wb < batch; ++wb) { - const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; - const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + if(mode == mode_enum::batch) + { + if(!cuq_cum.empty()) + { + real_seqlen_q = cuq_cum[wb + 1] - cuq_cum[wb]; + } + if(!cukv_cum.empty()) + { + real_seqlen_k = cukv_cum[wb + 1] - cukv_cum[wb]; + } + } // adjust matrix index according to the mode const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0); const ck_tile::index_t cache_b_idx = (use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx); const ck_tile::index_t query_offset = - (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); + (mode == mode_enum::batch + ? 0 + : (seqstart_q_with_padding_host.empty() ? seqstart_q_host[wb] + : seqstart_q_with_padding_host[wb])); const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 @@ -1538,8 +1639,10 @@ fwd_result fmha_fwd_run(mode_enum mode, if(lse) { ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); + const ck_tile::index_t query_offset_lse = + (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); lse_host_result.ForEach([&](auto& self, auto idx) { - self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset); + self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset_lse); }); cur_pass = ck_tile::check_err(lse_host_result, diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp index 10cb5149a4..4bd1d1a367 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp @@ -56,6 +56,11 @@ struct fmha_fwd_v3_args index_t stride_o; index_t nhead_stride_o; index_t batch_stride_o; + + // Optional batch-mode cumulative seqlen overrides (exclude PAD) + // If provided, they override per-batch effective lengths to skip tail padding. + const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] + const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] }; std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type); diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp index e0fbad39a5..194675f962 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp @@ -158,7 +158,9 @@ float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_confi args.window_size_left, args.window_size_right, args.mask_type, - remap_opt); + remap_opt, + args.cu_seqlen_q_ptr, + args.cu_seqlen_kv_ptr); dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v); constexpr dim3 blocks = Kernel::BlockSize(); diff --git a/example/ck_tile/01_fmha/script/benchmark_fwd.sh b/example/ck_tile/01_fmha/script/benchmark_fwd.sh index 88c16cceb6..31ad800039 100755 --- a/example/ck_tile/01_fmha/script/benchmark_fwd.sh +++ b/example/ck_tile/01_fmha/script/benchmark_fwd.sh @@ -18,3 +18,36 @@ $EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kn done done done + +#Padding Benchmarks: batch mode (baseline vs low/med/high pad) +prec="fp16" +base_batch_args="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID" + +# baseline (no pad) +$EXE $base_batch_args + +# low pad (≈90–95% effective) +$EXE $base_batch_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896 + +# medium pad (≈60–75% effective) +$EXE $base_batch_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640 + +# high pad (≈30–40% effective) +$EXE $base_batch_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320 + +# Padding Benchmarks: group mode (baseline vs low/med/high physical pad) +seqlens_q="1024,768,512,256" +seqlens_k="1024,768,512,256" +base_group_args="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID" + +# baseline (no physical pad) +$EXE $base_group_args + +# low physical pad +$EXE $base_group_args -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320 + +# medium physical pad +$EXE $base_group_args -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384 + +# high physical pad +$EXE $base_group_args -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512 diff --git a/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh b/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh index b847e85398..a3f7d68eb3 100755 --- a/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh +++ b/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh @@ -23,3 +23,20 @@ done done done done + +# Padding benchmark comparisons for v3 (batch mode only) +# ==== V3 Padding Benchmarks: batch mode (baseline vs low/med/high pad) ==== +prec="fp16" +base_v3_args="-prec=$prec -b=4 -h=16 -d=128 -s=1024 -mask=0 -iperm=0 -operm=0 -v=$VALID" + +# baseline (no pad) +$EXE $base_v3_args + +# low pad (≈90–95% effective) +$EXE $base_v3_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896 + +# medium pad (≈60–75% effective) +$EXE $base_v3_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640 + +# high pad (≈30–40% effective) +$EXE $base_v3_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320 diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index afd0c728c6..fca6b8d0cd 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -137,9 +137,118 @@ run_fp16_appendkv_tests() { done ; done ; done } +run_padding_smoke_tests() { + # Padding-only smoke tests for batch/group mode using COMMON_ARGS + local prec="fp16" + + # Batch mode: padding via effective lengths (exclude PAD) + # Use lse=1 to select a non-trload kernel and avoid overly strict tolerance mismatches + local base_batch="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=1 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS" + # low pad (≈90–95% effective) + $EXE $base_batch -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896 + # medium pad (≈60–75% effective) + $EXE $base_batch -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640 + # high pad (≈30–40% effective) + $EXE $base_batch -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320 + + # Group mode: padding via physical stride along seqlen + local seqlens_q="1024,768,512,256" + local seqlens_k="1024,768,512,256" + local base_group="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS" + # low physical pad + $EXE $base_group -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320 + # medium physical pad + $EXE $base_group -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384 + # high physical pad + $EXE $base_group -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512 +} + +run_padding_basic_boundary_tests() { + # Basic padding and boundary tests (reference: smoke_test_fwd_pad.sh) + local prec + local perm + + # Group mode: Q&K padded with per-batch different strides + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=16 -d_v=32 \ + -s=55 -s_k=256 -s_qpad=64,60 -s_kpad=272,260 \ + -bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # slightly larger, uneven padding strides + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=1 -b=3 -h=2 -h_k=1 -d=64 -d_v=64 \ + -s=50,60,40 -s_k=128,256,192 -s_qpad=64,64,64 -s_kpad=160,288,224 \ + -bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # only K padded; Q unpadded + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 \ + -s=55 -s_k=256 -s_kpad=272,260 \ + -bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # use cu_seqlen overrides to skip tail PAD + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=0 -b=4 -h=8 -h_k=8 -d=128 -s=3 -s_k=3 \ + -q_eff_lens=1,2,1,2 -kv_eff_lens=1,2,1,2 \ + -bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + + $EXE -prec=$prec -mode=0 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 -s=64 -s_k=256 \ + -q_eff_lens=55,60 -kv_eff_lens=200,256 \ + -bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # no padding (equal), mixed Q/KV, all len=1 + for prec in fp16 bf16 ; do + $EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \ + -q_eff_lens=128,128,128,128 -kv_eff_lens=128,128,128,128 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + + $EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \ + -q_eff_lens=10,20,30,40 -kv_eff_lens=40,30,20,10 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + + $EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \ + -q_eff_lens=1,1,1,1 -kv_eff_lens=1,1,1,1 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + done + + # highly variable logical lengths + for prec in fp16 bf16 ; do + $EXE -prec=$prec -mode=1 -b=4 -h=4 -d=32 \ + -s=1,127,3,65 -s_k=1,127,3,65 -s_kpad=128 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + done + + # GQA + Alibi + Causal mask (keep vlayout row-major for fp16/bf16 + for prec in fp16 bf16 ; do + $EXE -prec=$prec -mode=1 -b=2 -h=16 -h_k=4 -d=128 \ + -s=256,129 -s_k=256,129 -s_kpad=256 \ + -bias=a -mask=t -lse=1 -iperm=0 -operm=0 -vlayout=r \ + -kname=$KNAME $COMMON_ARGS + done +} + set -x run_fp16_bf16_tests +run_padding_smoke_tests +run_padding_basic_boundary_tests run_fp8_tests run_fp8bf16_tests run_fp8fp32_tests diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 58fdad149a..3f417bc125 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -291,6 +291,11 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_o; + + // Optional cumulative sequence length pointers for batch mode + // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding. + const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD + const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // cumulative, length without PAD }; struct FmhaFwdGroupModeKargs @@ -310,6 +315,11 @@ struct FmhaFwdKernel const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; const int32_t* seqlen_k_ptr; + + // Optional cumulative padded sequence starts (including PAD tokens) + // Used solely to compute memory offsets when sequences are physically padded. + const int32_t* seqstart_padded_q_ptr = nullptr; + const int32_t* seqstart_padded_k_ptr = nullptr; }; using Kargs = std::conditional_t; @@ -460,6 +470,105 @@ struct FmhaFwdKernel return kargs; } + // Overload: Batch mode with optional cu_seqlen pointers (unpadded cumulative lengths) + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargsImpl(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + float scale_p, + float scale_o, + float logits_soft_cap, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_randval, + ck_tile::index_t batch_stride_lse, + ck_tile::index_t batch_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + std::variant, std::pair> + drop_seed_offset, + const ck_tile::index_t* cu_seqlen_q_ptr, + const ck_tile::index_t* cu_seqlen_kv_ptr) + { + auto kargs = MakeKargsImpl(q_ptr, + k_ptr, + v_ptr, + bias_ptr, + rand_val_ptr, + lse_ptr, + o_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale_s, + scale_p, + scale_o, + logits_soft_cap, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_randval, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_randval, + nhead_stride_lse, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_bias, + batch_stride_randval, + batch_stride_lse, + batch_stride_o, + window_size_left, + window_size_right, + mask_type, + p_drop, + s_randval, + drop_seed_offset); + + kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr; + kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr; + return kargs; + } + // std::variant<> can't take in a list initializer, overload for backward compatibility template CK_TILE_HOST static constexpr std::enable_if_t @@ -781,6 +890,95 @@ struct FmhaFwdKernel return kargs; } + // Overload: Group mode with optional padded seqstarts for memory offsets + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargsImpl(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + float scale_p, + float scale_o, + float logits_soft_cap, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + ck_tile::index_t min_seqlen_q, + float p_drop, + bool s_randval, + std::variant, std::pair> + drop_seed_offset, + const void* seqstart_padded_q_ptr, + const void* seqstart_padded_k_ptr) + { + auto kargs = MakeKargsImpl(q_ptr, + k_ptr, + v_ptr, + bias_ptr, + rand_val_ptr, + lse_ptr, + o_ptr, + seqstart_q_ptr, + seqstart_k_ptr, + seqlen_k_ptr, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale_s, + scale_p, + scale_o, + logits_soft_cap, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_randval, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_randval, + nhead_stride_lse, + nhead_stride_o, + window_size_left, + window_size_right, + mask_type, + min_seqlen_q, + p_drop, + s_randval, + drop_seed_offset); + + kargs.seqstart_padded_q_ptr = reinterpret_cast(seqstart_padded_q_ptr); + kargs.seqstart_padded_k_ptr = reinterpret_cast(seqstart_padded_k_ptr); + return kargs; + } + // std::variant<> can't take in a list initializer, overload for backward compatibility template CK_TILE_HOST static constexpr std::enable_if_t @@ -1073,35 +1271,44 @@ struct FmhaFwdKernel if constexpr(kIsGroupMode) { - // get starting offset for each batch - const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + // logical and physical (padded) starts + const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch]; - batch_offset_q = query_start * kargs.stride_q; - batch_offset_k = key_start * kargs.stride_k; + const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr + ? kargs.seqstart_padded_q_ptr[i_batch] + : query_start_unpadded; + const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr + ? kargs.seqstart_padded_k_ptr[i_batch] + : key_start_unpadded; + + // DRAM base offsets use physical padded starts + batch_offset_q = query_start_padded * kargs.stride_q; + batch_offset_k = key_start_padded * kargs.stride_k; if constexpr(std::is_same_v) { - batch_offset_v = key_start * kargs.stride_v; + batch_offset_v = key_start_padded * kargs.stride_v; } else { - batch_offset_v = key_start; + batch_offset_v = key_start_padded; } if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - batch_offset_bias = query_start * kargs.stride_bias; + batch_offset_bias = query_start_padded * kargs.stride_bias; } if constexpr(kStoreLSE) { - batch_offset_lse = query_start; + // LSE stays indexed by unpadded starts + batch_offset_lse = query_start_unpadded; } if constexpr(kHasDropout) { - batch_offset_randval = query_start * kargs.stride_randval; + batch_offset_randval = query_start_padded * kargs.stride_randval; } - batch_offset_o = query_start * kargs.stride_o; + batch_offset_o = query_start_padded * kargs.stride_o; - // get real # queries & # keys under group mode + // real logical lengths (exclude PAD) const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; @@ -1113,8 +1320,7 @@ struct FmhaFwdKernel } } - // # of required blocks is different in each groups, terminate unnecessary blocks - // earlier + // terminate unnecessary blocks earlier if(kargs.seqlen_q <= i_m0) { return; @@ -1150,6 +1356,18 @@ struct FmhaFwdKernel static_cast(i_batch) * kargs.batch_stride_randval; } batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + + // If cumulative seqlen pointers are provided, override per-batch effective lengths + if(kargs.cu_seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = + kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; + } + if(kargs.cu_seqlen_kv_ptr != nullptr) + { + kargs.seqlen_k = + kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch]; + } } // for simplicity, batch stride we just modify the pointer @@ -1548,26 +1766,35 @@ struct FmhaFwdKernel if constexpr(kIsGroupMode) { // get starting offset for each batch - const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch]; - batch_offset_q = query_start * kargs.stride_q; - batch_offset_k = key_start * kargs.stride_k; + const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr + ? kargs.seqstart_padded_q_ptr[i_batch] + : query_start_unpadded; + const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr + ? kargs.seqstart_padded_k_ptr[i_batch] + : key_start_unpadded; + + batch_offset_q = query_start_padded * kargs.stride_q; + batch_offset_k = key_start_padded * kargs.stride_k; if constexpr(std::is_same_v) { - batch_offset_v = key_start * kargs.stride_v; + batch_offset_v = key_start_padded * kargs.stride_v; } else { - batch_offset_v = key_start; + // col-major V: offset along seqlen dimension is scalar index + batch_offset_v = key_start_padded; } if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - batch_offset_bias = query_start * kargs.stride_bias; + batch_offset_bias = query_start_padded * kargs.stride_bias; } - batch_offset_lse = query_start; - batch_offset_o = query_start * kargs.stride_o; + // LSE layout is [nhead, total_seqlen], index by unpadded start + batch_offset_lse = query_start_unpadded; + batch_offset_o = query_start_padded * kargs.stride_o; // get real # queries & # keys under group mode kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch]; @@ -1605,6 +1832,18 @@ struct FmhaFwdKernel batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; } + + // If cumulative seqlen pointers are provided, override per-batch effective lengths + if(kargs.cu_seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = + kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; + } + if(kargs.cu_seqlen_kv_ptr != nullptr) + { + kargs.seqlen_k = + kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch]; + } } // for simplicity, batch stride we just modify the pointer diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp index c5e5745817..52b9da40b8 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp @@ -100,6 +100,11 @@ struct FmhaFwdV3Kernel ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_o; + + // Optional cumulative sequence length pointers for batch mode + // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding. + const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] + const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] }; struct FmhaFwdGroupModeKargs @@ -110,6 +115,11 @@ struct FmhaFwdV3Kernel const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; const int32_t* seqlen_k_ptr; + + // Optional cumulative padded sequence starts (including PAD tokens) + // Used solely to compute memory offsets when sequences are physically padded. + const int32_t* seqstart_padded_q_ptr = nullptr; // [batch+1] + const int32_t* seqstart_padded_k_ptr = nullptr; // [batch+1] }; using Kargs = std::conditional_t; @@ -190,6 +200,78 @@ struct FmhaFwdV3Kernel return kargs; } + // Overload: Batch mode with optional cu_seqlen pointers + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_lse, + ck_tile::index_t batch_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + ck_tile::index_t remap_opt, + const ck_tile::index_t* cu_seqlen_q_ptr, + const ck_tile::index_t* cu_seqlen_kv_ptr) + { + auto kargs = MakeKargs(q_ptr, + k_ptr, + v_ptr, + lse_ptr, + o_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale_s, + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_lse, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_lse, + batch_stride_o, + window_size_left, + window_size_right, + mask_type, + remap_opt); + + kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr; + kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr; + return kargs; + } + template CK_TILE_HOST static constexpr std::enable_if_t MakeKargs(const void* q_ptr, @@ -260,6 +342,70 @@ struct FmhaFwdV3Kernel return kargs; } + // Overload: Group mode with optional padded seqstarts for memory offsets + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + void* lse_ptr, + void* o_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + ck_tile::index_t remap_opt, + const void* seqstart_padded_q_ptr, + const void* seqstart_padded_k_ptr) + { + auto kargs = MakeKargs(q_ptr, + k_ptr, + v_ptr, + lse_ptr, + o_ptr, + seqstart_q_ptr, + seqstart_k_ptr, + seqlen_k_ptr, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale_s, + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_lse, + nhead_stride_o, + window_size_left, + window_size_right, + mask_type, + remap_opt); + + kargs.seqstart_padded_q_ptr = reinterpret_cast(seqstart_padded_q_ptr); + kargs.seqstart_padded_k_ptr = reinterpret_cast(seqstart_padded_k_ptr); + return kargs; + } + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, @@ -373,18 +519,26 @@ struct FmhaFwdV3Kernel if constexpr(kIsGroupMode) { // get starting offset for each batch - const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch]; - batch_offset_q = query_start * kargs.stride_q; - batch_offset_k = key_start * kargs.stride_k; - batch_offset_v = key_start * kargs.stride_v; + const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr + ? kargs.seqstart_padded_q_ptr[i_batch] + : query_start_unpadded; + const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr + ? kargs.seqstart_padded_k_ptr[i_batch] + : key_start_unpadded; + + batch_offset_q = query_start_padded * kargs.stride_q; + batch_offset_k = key_start_padded * kargs.stride_k; + batch_offset_v = key_start_padded * kargs.stride_v; if constexpr(kStoreLSE) { - batch_offset_lse = query_start; + // LSE layout is [nhead, total_seqlen], index by unpadded start + batch_offset_lse = query_start_unpadded; } - batch_offset_o = query_start * kargs.stride_o; + batch_offset_o = query_start_padded * kargs.stride_o; // get real # queries & # keys under group mode const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; @@ -417,6 +571,18 @@ struct FmhaFwdV3Kernel batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; } batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + + // If cumulative seqlen pointers are provided, override per-batch effective lengths + if(kargs.cu_seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = + kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; + } + if(kargs.cu_seqlen_kv_ptr != nullptr) + { + kargs.seqlen_k = + kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch]; + } } // for simplicity, batch stride we just modify the pointer diff --git a/test/ck_tile/fmha/test_fmha_fwd.inc b/test/ck_tile/fmha/test_fmha_fwd.inc index 08abd3358d..66d4e3dc21 100644 --- a/test/ck_tile/fmha/test_fmha_fwd.inc +++ b/test/ck_tile/fmha/test_fmha_fwd.inc @@ -98,7 +98,10 @@ TEST_P(AllLong, Test) hdim_q, hdim_v, 0, // seqlen_knew + {-1}, // seqlen_qpads {seqlen_kpad}, // seqlen_kpads + {}, // q_eff_lens_per_batch + {}, // kv_eff_lens_per_batch 0, // rotary_dim perm, // i_perm perm, // o_perm @@ -160,7 +163,10 @@ TEST_P(HDimPadding, Test) hdim_q, hdim_v, 0, // seqlen_knew + {-1}, // seqlen_qpads {seqlen_kpad}, // seqlen_kpads + {}, // q_eff_lens_per_batch + {}, // kv_eff_lens_per_batch 0, // rotary_dim perm, // i_perm perm, // o_perm @@ -217,7 +223,10 @@ TEST_P(ElementwiseBias, Test) hdim_q, hdim_v, 0, // seqlen_knew + {-1}, // seqlen_qpads {-1}, // seqlen_kpads + {}, // q_eff_lens_per_batch + {}, // kv_eff_lens_per_batch 0, // rotary_dim i_perm, // i_perm false, // o_perm @@ -273,7 +282,10 @@ TEST_P(Alibi, Test) hdim_q, hdim_v, 0, // seqlen_knew + {-1}, // seqlen_qpads {-1}, // seqlen_kpads + {}, // q_eff_lens_per_batch + {}, // kv_eff_lens_per_batch 0, // rotary_dim true, // i_perm true, // o_perm @@ -331,7 +343,10 @@ TEST_P(Dropout, Test) hdim_q, hdim_v, 0, // seqlen_knew + {-1}, // seqlen_qpads {-1}, // seqlen_kpads + {}, // q_eff_lens_per_batch + {}, // kv_eff_lens_per_batch 0, // rotary_dim false, // i_perm false, // o_perm @@ -391,7 +406,10 @@ TEST_P(PagedKV, Test) hdim_q, hdim_v, 0, // seqlen_knew + {-1}, // seqlen_qpads {-1}, // seqlen_kpads + {}, // q_eff_lens_per_batch + {}, // kv_eff_lens_per_batch 0, // rotary_dim i_perm, // i_perm false, // o_perm @@ -457,7 +475,10 @@ TEST_P(SplitKV, Test) hdim_q, hdim_v, 0, // seqlen_knew + {-1}, // seqlen_qpads {-1}, // seqlen_kpads + {}, // q_eff_lens_per_batch + {}, // kv_eff_lens_per_batch 0, // rotary_dim i_perm, // i_perm false, // o_perm @@ -529,7 +550,10 @@ TEST_P(AppendKV, Test) hdim_q, hdim_v, seqlen_knew, // seqlen_knew + {-1}, // seqlen_qpads {-1}, // seqlen_kpads + {}, // q_eff_lens_per_batch + {}, // kv_eff_lens_per_batch 0, // rotary_dim i_perm, // i_perm true, // o_perm @@ -599,7 +623,10 @@ TEST_P(AppendKVRoPE, Test) hdim_q, hdim_v, seqlen_knew, // seqlen_knew + {-1}, // seqlen_qpads {-1}, // seqlen_kpads + {}, // q_eff_lens_per_batch + {}, // kv_eff_lens_per_batch rotary_dim, // rotary_dim i_perm, // i_perm true, // o_perm @@ -623,3 +650,117 @@ TEST_P(AppendKVRoPE, Test) } #endif // CK_TILE_FMHA_FWD_APPENDKV_API + +// --------------------------------------------------------------- +// Additional padding tests (q/kv physical padding & effective len) +// --------------------------------------------------------------- + +// Simple batch-mode test with per-batch Q/KV padding strides and effective lengths +TEST(TestCkTileFmhaFwd, BatchModeQKvPadding) +{ + if constexpr(std::is_same_v) + { + GTEST_SKIP() << "Skip for fp8"; + } + const mode_enum mode = mode_enum::batch; + const int batch = 3; + const int nhead = 2; + const int nhead_k = -1; + const int seqlen_q = 128; + const int seqlen_k = 128; + const int hdim_q = 64; + const int hdim_v = 64; + const int seqlen_knew = 0; + const std::vector seqlen_qpads{}; + const std::vector seqlen_kpads{}; + const std::vector q_eff_lens{120, 128, 100}; + const std::vector kv_eff_lens{110, 128, 90}; + + auto result = fmha_fwd_run(mode, + batch, + nhead, + nhead_k, + {adjust_seqlen(seqlen_q)}, + {adjust_seqlen(seqlen_k)}, + hdim_q, + hdim_v, + seqlen_knew, // seqlen_knew + seqlen_qpads, // seqlen_qpads + seqlen_kpads, // seqlen_kpads + q_eff_lens, // q_eff_lens_per_batch + kv_eff_lens, // kv_eff_lens_per_batch + 0, // rotary_dim + true, // i_perm + true, // o_perm + 0, // scale_s + 0, // logits_soft_cap + def_is_v_rowmajor, + def_lse, // lse + 0, // page_block_size + false, // use_cache_batch_idx + "n", // bias_str + 0.0f, // p_drop + 0, // drop_seed + 0, // drop_offset + false, // drop_prefs + "0", // mask_str + QUANT_ARGS, + true, // is_rotary_interleaved + 1, // num_splits + COMMON_ARGS); + CHECK_RESULT(result); +} + +// Simple group-mode test with uniform seqlen but per-batch padding & effective lengths +TEST(TestCkTileFmhaFwd, GroupModeQKvPadding) +{ + if constexpr(std::is_same_v) + { + GTEST_SKIP() << "Skip for fp8"; + } + const mode_enum mode = mode_enum::group; + const int batch = 2; + const int nhead = 2; + const int nhead_k = -1; + const std::vector seqlen_q{96, 128}; // unpadded + const std::vector seqlen_k{96, 128}; // unpadded + const int hdim_q = 64; + const int hdim_v = 64; + const int seqlen_knew = 0; + const std::vector seqlen_qpads{128, 160}; + const std::vector seqlen_kpads{128, 160}; + + auto result = fmha_fwd_run(mode, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + seqlen_knew, // seqlen_knew + seqlen_qpads, // seqlen_qpads + seqlen_kpads, // seqlen_kpads + {}, // q_eff_lens_per_batch + {}, // kv_eff_lens_per_batch + 0, // rotary_dim + true, // i_perm + true, // o_perm + 0, // scale_s + 0, // logits_soft_cap + def_is_v_rowmajor, + def_lse, // lse + 0, // page_block_size + false, // use_cache_batch_idx + "n", // bias_str + 0.0f, // p_drop + 0, // drop_seed + 0, // drop_offset + false, // drop_prefs + "0", // mask_str + QUANT_ARGS, + true, // is_rotary_interleaved + 1, // num_splits + COMMON_ARGS); + CHECK_RESULT(result); +} From 6cf3fdd21c502249767f814a087fbd9be88013eb Mon Sep 17 00:00:00 2001 From: Yi DING Date: Fri, 19 Sep 2025 21:45:02 +0800 Subject: [PATCH 26/28] [CK_TILE] FMHA BWD Fix Decode Accuracy (#2881) * [CK_TILE] FMHA BWD Fix Decode Accuracy * use s_waitcnt utils --- .../block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp index 8c8d2af486..6d90429407 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp @@ -489,7 +489,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR move_tile_window(k_dram_window, {kN0, 0}); async_load_tile(v_lds_write_window, v_dram_window); move_tile_window(v_dram_window, {kN0, 0}); - // __builtin_amdgcn_s_waitcnt(0); + s_waitcnt(); k_reg_tensor = load_tile(k_lds_read_window); v_reg_tensor = load_tile(v_lds_read_window); kt_reg_tensor = load_tile_transpose(kt_lds_read_window); @@ -636,7 +636,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR } }(); store_tile(bias_lds_write_window, dbias); - __builtin_amdgcn_s_waitcnt(3952); + s_waitcnt(); block_sync_lds(); auto shuffled_dbias_tile = load_tile(dbias_lds_read_window); auto dbias_tile = make_static_distributed_tensor( @@ -664,7 +664,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR } store_tile(ds_lds_window, ds_gemm); } - __builtin_amdgcn_s_waitcnt(3952); + s_waitcnt(); block_sync_lds(); if constexpr(is_epilogue) { From 29446da1d57170a8bda47a452113ef7e44363a04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Fri, 19 Sep 2025 16:27:50 +0200 Subject: [PATCH 27/28] Disable bwd weight split-k autodeduce for single stage kernels (#2856) * Disable bwd weight split-k autodeduce for single stage kernels * update interface tests --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- .../device/device_grouped_conv_bwd_weight.hpp | 2 + ...ice_grouped_conv_bwd_weight_multiple_d.hpp | 2 + ...e_grouped_conv_bwd_weight_explicit_xdl.hpp | 47 +++++++++++++--- ...onv_bwd_weight_multiple_d_xdl_cshuffle.hpp | 8 +++ ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 8 +++ ...rouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 9 ++++ ...rouped_convnd_bwd_weight_interface_xdl.cpp | 53 ++++++++++--------- 7 files changed, 96 insertions(+), 33 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp index 7296e4faaa..18223c78f7 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp @@ -11,6 +11,8 @@ namespace ck { namespace tensor_operation { namespace device { +#define DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS 1 + template ()) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 934dc7ee8e..987a1e273a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -671,6 +671,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); +#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS if(split_k < 0) { ck::index_t gemmM, gemmN; @@ -683,6 +684,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle grid_size); } else +#endif { k_batch_ = split_k; } @@ -939,6 +941,12 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { +#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if(arg.k_batch_ < 0) + { + return false; + } +#endif if(!ck::is_xdl_wmma_supported()) { return false; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index b361409e38..22fc13bae4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -553,6 +553,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths, e_g_k_c_xs_strides); +#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS if(split_k < 0) { ck::index_t gemmM, gemmN; @@ -565,6 +566,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle grid_size); } else +#endif { k_batch_ = split_k; } @@ -934,6 +936,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { +#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if(arg.k_batch_ < 0) + { + return false; + } +#endif if(!ck::is_xdl_wmma_supported()) { return false; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 8bf188be2e..735eebbdf6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -524,6 +524,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); +#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS if(split_k < 0) { ck::index_t gemmM, gemmN, gemmK; @@ -549,6 +550,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 } } else +#endif { k_batch_ = split_k; } @@ -1275,6 +1277,13 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { +#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if(arg.k_batch_ < 0) + { + return false; + } +#endif + const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp index 2a9421fcd1..354d1fc23b 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp @@ -52,7 +52,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test // clang-format on ck::utils::conv::ConvParam conv_param; - std::vector split_ks{-1, 2}; + ck::index_t split_k_ = 2; template bool Run() @@ -96,30 +96,24 @@ class TestGroupedConvndBwdWeight : public ::testing::Test auto conv = GroupedConvBwdWeightDeviceInstance{}; - bool is_supported = true; - - for(const auto split_k : split_ks) - { - auto argument = conv.MakeArgument(nullptr, - nullptr, - nullptr, - input_lengths, - input_strides, - filter_lengths, - weights_strides, - output_lengths, - output_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}, - split_k); - is_supported &= conv.IsSupportedArgument(argument); - } - return is_supported; + auto argument = conv.MakeArgument(nullptr, + nullptr, + nullptr, + input_lengths, + input_strides, + filter_lengths, + weights_strides, + output_lengths, + output_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}, + split_k_); + return conv.IsSupportedArgument(argument); } }; @@ -183,3 +177,12 @@ TYPED_TEST(TestGroupedConvndBwdWeightDefault, VectorLoadCheck) is_supported = this->template Run<2>(); EXPECT_FALSE(is_supported); } + +TYPED_TEST(TestGroupedConvndBwdWeightDefault, SingleStageAutoDeduce) +{ + // Supported version but with auto deduce and single stage + this->conv_param = {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + this->split_k_ = -1; + bool is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); +} From b765fe78f37c85a9ca10c24fec6b7247a170034f Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Fri, 19 Sep 2025 08:15:02 -0700 Subject: [PATCH 28/28] =?UTF-8?q?Revert=20"[CK=5FTILE]=20Add=20sequence=20?= =?UTF-8?q?padding=20and=20variable=20length=20support=20in=20fmha=20(a?= =?UTF-8?q?=E2=80=A6"=20(#2883)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 86dd59cd01e41a4190bf2405a0fb0e89d9498b4c. --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 6 +- example/ck_tile/01_fmha/example_fmha_fwd.cpp | 20 +- .../ck_tile/01_fmha/example_fmha_fwd_v3.cpp | 148 +-------- example/ck_tile/01_fmha/fmha_fwd.hpp | 17 +- example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 127 +------- example/ck_tile/01_fmha/fmha_fwd_v3.hpp | 5 - example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp | 4 +- .../ck_tile/01_fmha/script/benchmark_fwd.sh | 33 -- .../01_fmha/script/benchmark_fwd_v3.sh | 17 -- .../ck_tile/01_fmha/script/smoke_test_fwd.sh | 109 ------- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 285 ++---------------- .../ops/fmha/kernel/fmha_fwd_v3_kernel.hpp | 180 +---------- test/ck_tile/fmha/test_fmha_fwd.inc | 141 --------- 13 files changed, 60 insertions(+), 1032 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index da0c9ca931..cfb96b7d53 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -259,11 +259,11 @@ class FmhaFwdApiTrait: def skcheck(self) -> str: if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true if self.pipeline_tag == 'qr_async': - if self.skpad == 't' : return f'(a.cu_seqlen_kv_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)' - else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)' + if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' + else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' elif self.pipeline_tag in ['qr', 'qs']: if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)' + else : return f'a.seqlen_k % {self.bn0} == 0' elif self.pipeline_tag == 'qr_async_trload': if self.skpad == 't' : return 'true' else: return 'true' diff --git a/example/ck_tile/01_fmha/example_fmha_fwd.cpp b/example/ck_tile/01_fmha/example_fmha_fwd.cpp index 79fda6d564..91cb9f55be 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd.cpp @@ -33,10 +33,6 @@ auto create_args(int argc, char* argv[]) "0", "seqlen_k for new key/value, 0 means not to use this at all; " "-1 to choose s_knew in [1, s] randomly.") - .insert("s_qpad", - "-1", - "seqlen_q stride between 2 batches (group-mode optional).\n" - "Provide positive strides per-batch to simulate physical padding on Q.") .insert("s_kpad", "-1", "seqlen_k stride between 2 batches, currently used in group-mode only\n" @@ -111,15 +107,7 @@ auto create_args(int argc, char* argv[]) .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel") .insert("json", "0", "0: No Json, 1: Dump Results in Json format") - .insert("jsonfile", "fmha_fwd.json", "json file name to dump results") - .insert("q_eff_lens", - "", - "Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n" - "Comma-separated list of length 'b'. If empty, no override.") - .insert("kv_eff_lens", - "", - "Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n" - "Comma-separated list of length 'b'. If empty, no override."); + .insert("jsonfile", "fmha_fwd.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -139,9 +127,6 @@ auto run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); ck_tile::index_t seqlen_knew = arg_parser.get_int("s_knew"); auto seqlen_kpads = arg_parser.get_int_vec("s_kpad"); - auto seqlen_qpads = arg_parser.get_int_vec("s_qpad"); - auto q_eff_lens_per_batch = arg_parser.get_int_vec("q_eff_lens"); - auto kv_eff_lens_per_batch = arg_parser.get_int_vec("kv_eff_lens"); ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim"); bool i_perm = arg_parser.get_bool("iperm"); bool o_perm = arg_parser.get_bool("operm"); @@ -189,10 +174,7 @@ auto run(const ck_tile::ArgParser& arg_parser) hdim_q, hdim_v, seqlen_knew, - seqlen_qpads, seqlen_kpads, - q_eff_lens_per_batch, - kv_eff_lens_per_batch, rotary_dim, i_perm, o_perm, diff --git a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp index 7ddb65a2db..569c98a458 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp @@ -52,16 +52,7 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair get_query_shape() const @@ -183,8 +172,6 @@ struct Problem mask_info mask; TensorLayout input_layout; TensorLayout output_layout; - std::vector q_eff_lens; - std::vector kv_eff_lens; }; struct RunConfig @@ -339,10 +326,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) q_buf.ToDevice(q.data()); k_buf.ToDevice(k.data()); v_buf.ToDevice(v.data()); - // Ensure output buffer is zero-initialized so padded regions compare cleanly - o_buf.SetZero(); - ck_tile::fmha_fwd_v3_args args{}; + ck_tile::fmha_fwd_v3_args args; args.data_type = problem.data_type; args.batch = problem.batch; @@ -395,60 +380,6 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) : problem.seqlen_q * problem.hdim; args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim; - // Optional cumulative seqlen overrides (exclude PAD) - const bool has_varlen_q = !problem.q_eff_lens.empty() && problem.q_eff_lens[0] != -1; - const bool has_varlen_k = !problem.kv_eff_lens.empty() && problem.kv_eff_lens[0] != -1; - - auto make_effective_vec = [&](const std::vector& opt_vec, ck_tile::index_t fallback) { - std::vector eff; - if(!opt_vec.empty() && opt_vec[0] != -1) - { - eff.assign(opt_vec.begin(), opt_vec.end()); - if(eff.size() < static_cast(problem.batch)) - { - eff.resize(problem.batch, eff.back()); - } - } - else - { - eff.assign(problem.batch, fallback); - } - return eff; - }; - - const auto eff_q_vec = make_effective_vec(problem.q_eff_lens, problem.seqlen_q); - const auto eff_kv_vec = make_effective_vec(problem.kv_eff_lens, problem.seqlen_k); - - // Calculate cumulative sums for kernel arguments if varlen is used - std::vector cuq_cum, cukv_cum; - auto calculate_cumulative = [&](const std::vector& per_batch_vec, - std::vector& cum_vec) { - cum_vec.resize(per_batch_vec.size() + 1); - cum_vec[0] = 0; - for(std::size_t i = 0; i < per_batch_vec.size(); ++i) - cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i]; - }; - - if(has_varlen_q) - { - calculate_cumulative(eff_q_vec, cuq_cum); - } - if(has_varlen_k) - { - calculate_cumulative(eff_kv_vec, cukv_cum); - } - - ck_tile::DeviceMem cuq_buf(!cuq_cum.empty() ? cuq_cum.size() * sizeof(ck_tile::index_t) : 0); - ck_tile::DeviceMem cukv_buf(!cukv_cum.empty() ? cukv_cum.size() * sizeof(ck_tile::index_t) : 0); - cuq_buf.ToDevice(!cuq_cum.empty() ? cuq_cum.data() : nullptr); - cukv_buf.ToDevice(!cukv_cum.empty() ? cukv_cum.data() : nullptr); - args.cu_seqlen_q_ptr = - !cuq_cum.empty() ? reinterpret_cast(cuq_buf.GetDeviceBuffer()) - : nullptr; - args.cu_seqlen_kv_ptr = - !cukv_cum.empty() ? reinterpret_cast(cukv_buf.GetDeviceBuffer()) - : nullptr; - ck_tile::stream_config stream_config{nullptr, true, /*log_level=*/0, @@ -511,72 +442,15 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) o_ref = o_ref.transpose({0, 2, 1, 3}); } - // If variable lengths are provided, compute per-batch references - // with the effective lengths; else compute a single full reference. - if(has_varlen_q || has_varlen_k) - { - // Variable-length aware verification: zero-fill padded region and only compute valid part. - o_ref.SetZero(); - - for(int b = 0; b < problem.batch; ++b) - { - const ck_tile::index_t seqlen_q_eff = eff_q_vec[b]; - const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b]; - - if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0) - continue; - - // Slice current batch from inputs (bshd) and build single-batch tensors - ck_tile::HostTensor q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); - ck_tile::HostTensor k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); - ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); - ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); - - // Copy effective region - q_b.ForEach([&](auto& self, auto idx) { - // idx: [0, s, h, d] - self(idx) = q(b, idx[1], idx[2], idx[3]); - }); - k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); }); - v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); }); - - // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors) - host::fmha_fwd(q_b, - k_b, - v_b, - problem.mask, - o_b, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales{problem.softmax_scale}); - - // Scatter into o_ref's bshd descriptor memory - for(int s = 0; s < seqlen_q_eff; ++s) - { - for(int h = 0; h < problem.nhead_q; ++h) - { - for(int d = 0; d < problem.hdim; ++d) - { - o_ref(b, s, h, d) = o_b(0, s, h, d); - } - } - } - } - } - else - { - // No varlen override: compute the full reference once - host::fmha_fwd(q, - k, - v, - problem.mask, - o_ref, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales{problem.softmax_scale}); - } + host::fmha_fwd(q, + k, + v, + problem.mask, + o_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales{problem.softmax_scale}); ck_tile::HostTensor o(problem.get_output_shape()); o_buf.FromDevice(o.data()); diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index f5dd42a6bd..c41e48e6aa 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -162,20 +162,11 @@ struct fmha_fwd_args void* lse_ptr; void* o_ptr; - // Optional cumulative sequence length arrays - // Batch mode: cu_seqlen_* override effective per-batch lengths (exclude PAD) - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] - const void* seqstart_q_ptr; const void* seqstart_k_ptr; const void* seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr - // Group mode: seqstart_padded_* provide physical starts including PAD (optional) - const void* seqstart_padded_q_ptr = nullptr; // [batch+1] - const void* seqstart_padded_k_ptr = nullptr; // [batch+1] - ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; ck_tile::index_t batch; @@ -563,9 +554,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.min_seqlen_q, args.p_drop, args.s_randval, - args.drop_seed_offset, - args.seqstart_padded_q_ptr, - args.seqstart_padded_k_ptr); + args.drop_seed_offset); } else { // create batch mode kernel arguments @@ -611,9 +600,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.mask_type, args.p_drop, args.s_randval, - args.drop_seed_offset, - args.cu_seqlen_q_ptr, - args.cu_seqlen_kv_ptr); + args.drop_seed_offset); } }(); diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index cb5827975e..43f484fe14 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -151,10 +151,7 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t seqlen_knew, - std::vector seqlen_qpads, std::vector seqlen_kpads, - std::vector q_eff_lens_per_batch, - std::vector kv_eff_lens_per_batch, ck_tile::index_t rotary_dim, bool i_perm, bool o_perm, @@ -365,44 +362,6 @@ fwd_result fmha_fwd_run(mode_enum mode, const auto seqstart_k_host = to_seqstarts(seqlen_ks); const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads); - // Optional padded Q seqstarts (group-mode only) - std::vector seqstart_q_with_padding_host; - if(mode == mode_enum::group && !seqlen_qpads.empty() && seqlen_qpads[0] != -1) - { - if(seqlen_qpads.size() < static_cast(batch)) - { - seqlen_qpads.resize(batch, seqlen_qpads.back()); - } - if(seqlen_qpads.size() == static_cast(batch)) - { - seqstart_q_with_padding_host = to_seqstarts( - ck_tile::span(seqlen_qpads.data(), seqlen_qpads.size())); - } - } - - // Optional batch-mode cumulative seqlen overrides - std::vector cuq_cum, cukv_cum; - if(mode == mode_enum::batch) - { - auto calculate_cumulative = [&](std::vector& per_batch_vec, - std::vector& cum_vec) { - if(!per_batch_vec.empty() && per_batch_vec[0] != -1) - { - if(per_batch_vec.size() < static_cast(batch)) - { - per_batch_vec.resize(batch, per_batch_vec.back()); - } - cum_vec.resize(batch + 1); - cum_vec[0] = 0; - for(int i = 0; i < batch; ++i) - cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i]; - } - }; - - calculate_cumulative(q_eff_lens_per_batch, cuq_cum); - calculate_cumulative(kv_eff_lens_per_batch, cukv_cum); - } - using TypeConfig = FmhaFwdTypeConfig; using QDataType = typename TypeConfig::QDataType; @@ -486,15 +445,8 @@ fwd_result fmha_fwd_run(mode_enum mode, // host memory for storing all the tensor elements const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); - // logical(unpadded) total seqlen_q for group; batch uses fixed seqlen - const ck_tile::index_t shape_seqlen_q_lse = - (mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back()); - // physical(padded) total seqlen_q for group when s_qpad is provided; else use logical const ck_tile::index_t shape_seqlen_q = - (mode == mode_enum::batch - ? seqlen_qs[0] - : (seqstart_q_with_padding_host.empty() ? seqstart_q_host.back() - : seqstart_q_with_padding_host.back())); + (mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back()); const ck_tile::index_t shape_seqlen_k = (mode == mode_enum::batch ? seqlen_ks[0] : (seqlen_kpads[0] < 0 ? seqstart_k_host.back() @@ -552,7 +504,7 @@ fwd_result fmha_fwd_run(mode_enum mode, // batch mode of lse data layout is [batch, nhead, seqlen_q] // group mode of lse data layout is [nhead, total_seqlen_q] ck_tile::HostTensor lse_host( - lse ? std::array{shape_batch, nhead, shape_seqlen_q_lse} + lse ? std::array{shape_batch, nhead, shape_seqlen_q} : std::array{1, 1, 1} /* dummy shape for simplifying code */); ck_tile::HostTensor o_host( @@ -650,16 +602,6 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); - ck_tile::DeviceMem seqstart_q_padded_buf(seqstart_q_with_padding_host.empty() - ? 0 - : seqstart_q_with_padding_host.size() * - sizeof(int32_t)); - ck_tile::DeviceMem seqstart_k_padded_buf( - seqlen_kpads[0] < 0 ? 0 : seqstart_k_with_padding_host.size() * sizeof(int32_t)); - ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0 - : cuq_cum.size() * sizeof(ck_tile::index_t)); - ck_tile::DeviceMem cu_seqlen_kv_buf( - cukv_cum.empty() ? 0 : cukv_cum.size() * sizeof(ck_tile::index_t)); ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) @@ -751,14 +693,8 @@ fwd_result fmha_fwd_run(mode_enum mode, vnew_buf.ToDevice(vnew_host.data()); bias_buf.ToDevice(bias_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); - // Keep logical starts in seqstart_k; pass padded K via separate pointer - seqstart_k.ToDevice(seqstart_k_host.data()); - seqstart_q_padded_buf.ToDevice( - seqstart_q_with_padding_host.empty() ? nullptr : seqstart_q_with_padding_host.data()); - seqstart_k_padded_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr - : seqstart_k_with_padding_host.data()); - cu_seqlen_q_buf.ToDevice(cuq_cum.empty() ? nullptr : cuq_cum.data()); - cu_seqlen_kv_buf.ToDevice(cukv_cum.empty() ? nullptr : cukv_cum.data()); + seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data() + : seqstart_k_with_padding_host.data()); seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0] ? seqlen_ks.data() : nullptr); @@ -894,8 +830,8 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t nhead_stride_lse = shape_seqlen_q_lse; - const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q_lse); + const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; + const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments @@ -910,8 +846,8 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew); const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q_lse); - const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q_lse); + const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); + const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q); const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); @@ -1025,29 +961,6 @@ fwd_result fmha_fwd_run(mode_enum mode, { args.drop_seed_offset = std::make_pair(drop_seed, drop_offset); } - - // Group-mode: optional physical padded starts for Q/K - if(mode == mode_enum::group) - { - args.seqstart_padded_q_ptr = (seqstart_q_with_padding_host.empty() - ? nullptr - : seqstart_q_padded_buf.GetDeviceBuffer()); - args.seqstart_padded_k_ptr = - (seqlen_kpads[0] < 0 ? nullptr : seqstart_k_padded_buf.GetDeviceBuffer()); - } - - // Batch-mode: optional cumulative effective seqlen overrides - if(mode == mode_enum::batch) - { - args.cu_seqlen_q_ptr = cuq_cum.empty() - ? nullptr - : reinterpret_cast( - cu_seqlen_q_buf.GetDeviceBuffer()); - args.cu_seqlen_kv_ptr = cukv_cum.empty() - ? nullptr - : reinterpret_cast( - cu_seqlen_kv_buf.GetDeviceBuffer()); - } } else if constexpr(std::is_same_v>) { @@ -1254,29 +1167,15 @@ fwd_result fmha_fwd_run(mode_enum mode, for(ck_tile::index_t wb = 0; wb < batch; ++wb) { - ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; - ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; - if(mode == mode_enum::batch) - { - if(!cuq_cum.empty()) - { - real_seqlen_q = cuq_cum[wb + 1] - cuq_cum[wb]; - } - if(!cukv_cum.empty()) - { - real_seqlen_k = cukv_cum[wb + 1] - cukv_cum[wb]; - } - } + const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; // adjust matrix index according to the mode const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0); const ck_tile::index_t cache_b_idx = (use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx); const ck_tile::index_t query_offset = - (mode == mode_enum::batch - ? 0 - : (seqstart_q_with_padding_host.empty() ? seqstart_q_host[wb] - : seqstart_q_with_padding_host[wb])); + (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 @@ -1639,10 +1538,8 @@ fwd_result fmha_fwd_run(mode_enum mode, if(lse) { ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); - const ck_tile::index_t query_offset_lse = - (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); lse_host_result.ForEach([&](auto& self, auto idx) { - self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset_lse); + self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset); }); cur_pass = ck_tile::check_err(lse_host_result, diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp index 4bd1d1a367..10cb5149a4 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp @@ -56,11 +56,6 @@ struct fmha_fwd_v3_args index_t stride_o; index_t nhead_stride_o; index_t batch_stride_o; - - // Optional batch-mode cumulative seqlen overrides (exclude PAD) - // If provided, they override per-batch effective lengths to skip tail padding. - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] }; std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type); diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp index 194675f962..e0fbad39a5 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp @@ -158,9 +158,7 @@ float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_confi args.window_size_left, args.window_size_right, args.mask_type, - remap_opt, - args.cu_seqlen_q_ptr, - args.cu_seqlen_kv_ptr); + remap_opt); dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v); constexpr dim3 blocks = Kernel::BlockSize(); diff --git a/example/ck_tile/01_fmha/script/benchmark_fwd.sh b/example/ck_tile/01_fmha/script/benchmark_fwd.sh index 31ad800039..88c16cceb6 100755 --- a/example/ck_tile/01_fmha/script/benchmark_fwd.sh +++ b/example/ck_tile/01_fmha/script/benchmark_fwd.sh @@ -18,36 +18,3 @@ $EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kn done done done - -#Padding Benchmarks: batch mode (baseline vs low/med/high pad) -prec="fp16" -base_batch_args="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID" - -# baseline (no pad) -$EXE $base_batch_args - -# low pad (≈90–95% effective) -$EXE $base_batch_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896 - -# medium pad (≈60–75% effective) -$EXE $base_batch_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640 - -# high pad (≈30–40% effective) -$EXE $base_batch_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320 - -# Padding Benchmarks: group mode (baseline vs low/med/high physical pad) -seqlens_q="1024,768,512,256" -seqlens_k="1024,768,512,256" -base_group_args="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID" - -# baseline (no physical pad) -$EXE $base_group_args - -# low physical pad -$EXE $base_group_args -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320 - -# medium physical pad -$EXE $base_group_args -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384 - -# high physical pad -$EXE $base_group_args -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512 diff --git a/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh b/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh index a3f7d68eb3..b847e85398 100755 --- a/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh +++ b/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh @@ -23,20 +23,3 @@ done done done done - -# Padding benchmark comparisons for v3 (batch mode only) -# ==== V3 Padding Benchmarks: batch mode (baseline vs low/med/high pad) ==== -prec="fp16" -base_v3_args="-prec=$prec -b=4 -h=16 -d=128 -s=1024 -mask=0 -iperm=0 -operm=0 -v=$VALID" - -# baseline (no pad) -$EXE $base_v3_args - -# low pad (≈90–95% effective) -$EXE $base_v3_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896 - -# medium pad (≈60–75% effective) -$EXE $base_v3_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640 - -# high pad (≈30–40% effective) -$EXE $base_v3_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320 diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index fca6b8d0cd..afd0c728c6 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -137,118 +137,9 @@ run_fp16_appendkv_tests() { done ; done ; done } -run_padding_smoke_tests() { - # Padding-only smoke tests for batch/group mode using COMMON_ARGS - local prec="fp16" - - # Batch mode: padding via effective lengths (exclude PAD) - # Use lse=1 to select a non-trload kernel and avoid overly strict tolerance mismatches - local base_batch="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=1 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS" - # low pad (≈90–95% effective) - $EXE $base_batch -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896 - # medium pad (≈60–75% effective) - $EXE $base_batch -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640 - # high pad (≈30–40% effective) - $EXE $base_batch -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320 - - # Group mode: padding via physical stride along seqlen - local seqlens_q="1024,768,512,256" - local seqlens_k="1024,768,512,256" - local base_group="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS" - # low physical pad - $EXE $base_group -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320 - # medium physical pad - $EXE $base_group -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384 - # high physical pad - $EXE $base_group -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512 -} - -run_padding_basic_boundary_tests() { - # Basic padding and boundary tests (reference: smoke_test_fwd_pad.sh) - local prec - local perm - - # Group mode: Q&K padded with per-batch different strides - for prec in fp16 bf16 ; do - for perm in 0 1 ; do - $EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=16 -d_v=32 \ - -s=55 -s_k=256 -s_qpad=64,60 -s_kpad=272,260 \ - -bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \ - -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS - done - done - - # slightly larger, uneven padding strides - for prec in fp16 bf16 ; do - for perm in 0 1 ; do - $EXE -prec=$prec -mode=1 -b=3 -h=2 -h_k=1 -d=64 -d_v=64 \ - -s=50,60,40 -s_k=128,256,192 -s_qpad=64,64,64 -s_kpad=160,288,224 \ - -bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \ - -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS - done - done - - # only K padded; Q unpadded - for prec in fp16 bf16 ; do - for perm in 0 1 ; do - $EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 \ - -s=55 -s_k=256 -s_kpad=272,260 \ - -bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \ - -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS - done - done - - # use cu_seqlen overrides to skip tail PAD - for prec in fp16 bf16 ; do - for perm in 0 1 ; do - $EXE -prec=$prec -mode=0 -b=4 -h=8 -h_k=8 -d=128 -s=3 -s_k=3 \ - -q_eff_lens=1,2,1,2 -kv_eff_lens=1,2,1,2 \ - -bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \ - -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS - - $EXE -prec=$prec -mode=0 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 -s=64 -s_k=256 \ - -q_eff_lens=55,60 -kv_eff_lens=200,256 \ - -bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \ - -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS - done - done - - # no padding (equal), mixed Q/KV, all len=1 - for prec in fp16 bf16 ; do - $EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \ - -q_eff_lens=128,128,128,128 -kv_eff_lens=128,128,128,128 \ - -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS - - $EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \ - -q_eff_lens=10,20,30,40 -kv_eff_lens=40,30,20,10 \ - -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS - - $EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \ - -q_eff_lens=1,1,1,1 -kv_eff_lens=1,1,1,1 \ - -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS - done - - # highly variable logical lengths - for prec in fp16 bf16 ; do - $EXE -prec=$prec -mode=1 -b=4 -h=4 -d=32 \ - -s=1,127,3,65 -s_k=1,127,3,65 -s_kpad=128 \ - -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS - done - - # GQA + Alibi + Causal mask (keep vlayout row-major for fp16/bf16 - for prec in fp16 bf16 ; do - $EXE -prec=$prec -mode=1 -b=2 -h=16 -h_k=4 -d=128 \ - -s=256,129 -s_k=256,129 -s_kpad=256 \ - -bias=a -mask=t -lse=1 -iperm=0 -operm=0 -vlayout=r \ - -kname=$KNAME $COMMON_ARGS - done -} - set -x run_fp16_bf16_tests -run_padding_smoke_tests -run_padding_basic_boundary_tests run_fp8_tests run_fp8bf16_tests run_fp8fp32_tests diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 3f417bc125..58fdad149a 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -291,11 +291,6 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_o; - - // Optional cumulative sequence length pointers for batch mode - // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding. - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // cumulative, length without PAD }; struct FmhaFwdGroupModeKargs @@ -315,11 +310,6 @@ struct FmhaFwdKernel const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; const int32_t* seqlen_k_ptr; - - // Optional cumulative padded sequence starts (including PAD tokens) - // Used solely to compute memory offsets when sequences are physically padded. - const int32_t* seqstart_padded_q_ptr = nullptr; - const int32_t* seqstart_padded_k_ptr = nullptr; }; using Kargs = std::conditional_t; @@ -470,105 +460,6 @@ struct FmhaFwdKernel return kargs; } - // Overload: Batch mode with optional cu_seqlen pointers (unpadded cumulative lengths) - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargsImpl(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* bias_ptr, - void* rand_val_ptr, - void* lse_ptr, - void* o_ptr, - ck_tile::index_t seqlen_q, - ck_tile::index_t seqlen_k, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale_s, - float scale_p, - float scale_o, - float logits_soft_cap, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, - ck_tile::index_t stride_randval, - ck_tile::index_t stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_lse, - ck_tile::index_t nhead_stride_o, - ck_tile::index_t batch_stride_q, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_bias, - ck_tile::index_t batch_stride_randval, - ck_tile::index_t batch_stride_lse, - ck_tile::index_t batch_stride_o, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - bool s_randval, - std::variant, std::pair> - drop_seed_offset, - const ck_tile::index_t* cu_seqlen_q_ptr, - const ck_tile::index_t* cu_seqlen_kv_ptr) - { - auto kargs = MakeKargsImpl(q_ptr, - k_ptr, - v_ptr, - bias_ptr, - rand_val_ptr, - lse_ptr, - o_ptr, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale_s, - scale_p, - scale_o, - logits_soft_cap, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_randval, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_randval, - nhead_stride_lse, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_randval, - batch_stride_lse, - batch_stride_o, - window_size_left, - window_size_right, - mask_type, - p_drop, - s_randval, - drop_seed_offset); - - kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr; - kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr; - return kargs; - } - // std::variant<> can't take in a list initializer, overload for backward compatibility template CK_TILE_HOST static constexpr std::enable_if_t @@ -890,95 +781,6 @@ struct FmhaFwdKernel return kargs; } - // Overload: Group mode with optional padded seqstarts for memory offsets - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargsImpl(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* bias_ptr, - void* rand_val_ptr, - void* lse_ptr, - void* o_ptr, - const void* seqstart_q_ptr, - const void* seqstart_k_ptr, - const void* seqlen_k_ptr, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale_s, - float scale_p, - float scale_o, - float logits_soft_cap, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, - ck_tile::index_t stride_randval, - ck_tile::index_t stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_lse, - ck_tile::index_t nhead_stride_o, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - ck_tile::index_t min_seqlen_q, - float p_drop, - bool s_randval, - std::variant, std::pair> - drop_seed_offset, - const void* seqstart_padded_q_ptr, - const void* seqstart_padded_k_ptr) - { - auto kargs = MakeKargsImpl(q_ptr, - k_ptr, - v_ptr, - bias_ptr, - rand_val_ptr, - lse_ptr, - o_ptr, - seqstart_q_ptr, - seqstart_k_ptr, - seqlen_k_ptr, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale_s, - scale_p, - scale_o, - logits_soft_cap, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_randval, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_randval, - nhead_stride_lse, - nhead_stride_o, - window_size_left, - window_size_right, - mask_type, - min_seqlen_q, - p_drop, - s_randval, - drop_seed_offset); - - kargs.seqstart_padded_q_ptr = reinterpret_cast(seqstart_padded_q_ptr); - kargs.seqstart_padded_k_ptr = reinterpret_cast(seqstart_padded_k_ptr); - return kargs; - } - // std::variant<> can't take in a list initializer, overload for backward compatibility template CK_TILE_HOST static constexpr std::enable_if_t @@ -1271,44 +1073,35 @@ struct FmhaFwdKernel if constexpr(kIsGroupMode) { - // logical and physical (padded) starts - const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch]; + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; - const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr - ? kargs.seqstart_padded_q_ptr[i_batch] - : query_start_unpadded; - const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr - ? kargs.seqstart_padded_k_ptr[i_batch] - : key_start_unpadded; - - // DRAM base offsets use physical padded starts - batch_offset_q = query_start_padded * kargs.stride_q; - batch_offset_k = key_start_padded * kargs.stride_k; + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; if constexpr(std::is_same_v) { - batch_offset_v = key_start_padded * kargs.stride_v; + batch_offset_v = key_start * kargs.stride_v; } else { - batch_offset_v = key_start_padded; + batch_offset_v = key_start; } if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - batch_offset_bias = query_start_padded * kargs.stride_bias; + batch_offset_bias = query_start * kargs.stride_bias; } if constexpr(kStoreLSE) { - // LSE stays indexed by unpadded starts - batch_offset_lse = query_start_unpadded; + batch_offset_lse = query_start; } if constexpr(kHasDropout) { - batch_offset_randval = query_start_padded * kargs.stride_randval; + batch_offset_randval = query_start * kargs.stride_randval; } - batch_offset_o = query_start_padded * kargs.stride_o; + batch_offset_o = query_start * kargs.stride_o; - // real logical lengths (exclude PAD) + // get real # queries & # keys under group mode const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; @@ -1320,7 +1113,8 @@ struct FmhaFwdKernel } } - // terminate unnecessary blocks earlier + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier if(kargs.seqlen_q <= i_m0) { return; @@ -1356,18 +1150,6 @@ struct FmhaFwdKernel static_cast(i_batch) * kargs.batch_stride_randval; } batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; - - // If cumulative seqlen pointers are provided, override per-batch effective lengths - if(kargs.cu_seqlen_q_ptr != nullptr) - { - kargs.seqlen_q = - kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; - } - if(kargs.cu_seqlen_kv_ptr != nullptr) - { - kargs.seqlen_k = - kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch]; - } } // for simplicity, batch stride we just modify the pointer @@ -1766,35 +1548,26 @@ struct FmhaFwdKernel if constexpr(kIsGroupMode) { // get starting offset for each batch - const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch]; + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; - const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr - ? kargs.seqstart_padded_q_ptr[i_batch] - : query_start_unpadded; - const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr - ? kargs.seqstart_padded_k_ptr[i_batch] - : key_start_unpadded; - - batch_offset_q = query_start_padded * kargs.stride_q; - batch_offset_k = key_start_padded * kargs.stride_k; + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; if constexpr(std::is_same_v) { - batch_offset_v = key_start_padded * kargs.stride_v; + batch_offset_v = key_start * kargs.stride_v; } else { - // col-major V: offset along seqlen dimension is scalar index - batch_offset_v = key_start_padded; + batch_offset_v = key_start; } if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - batch_offset_bias = query_start_padded * kargs.stride_bias; + batch_offset_bias = query_start * kargs.stride_bias; } - // LSE layout is [nhead, total_seqlen], index by unpadded start - batch_offset_lse = query_start_unpadded; - batch_offset_o = query_start_padded * kargs.stride_o; + batch_offset_lse = query_start; + batch_offset_o = query_start * kargs.stride_o; // get real # queries & # keys under group mode kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch]; @@ -1832,18 +1605,6 @@ struct FmhaFwdKernel batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; } - - // If cumulative seqlen pointers are provided, override per-batch effective lengths - if(kargs.cu_seqlen_q_ptr != nullptr) - { - kargs.seqlen_q = - kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; - } - if(kargs.cu_seqlen_kv_ptr != nullptr) - { - kargs.seqlen_k = - kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch]; - } } // for simplicity, batch stride we just modify the pointer diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp index 52b9da40b8..c5e5745817 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp @@ -100,11 +100,6 @@ struct FmhaFwdV3Kernel ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_o; - - // Optional cumulative sequence length pointers for batch mode - // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding. - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] }; struct FmhaFwdGroupModeKargs @@ -115,11 +110,6 @@ struct FmhaFwdV3Kernel const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; const int32_t* seqlen_k_ptr; - - // Optional cumulative padded sequence starts (including PAD tokens) - // Used solely to compute memory offsets when sequences are physically padded. - const int32_t* seqstart_padded_q_ptr = nullptr; // [batch+1] - const int32_t* seqstart_padded_k_ptr = nullptr; // [batch+1] }; using Kargs = std::conditional_t; @@ -200,78 +190,6 @@ struct FmhaFwdV3Kernel return kargs; } - // Overload: Batch mode with optional cu_seqlen pointers - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - void* lse_ptr, - void* o_ptr, - ck_tile::index_t seqlen_q, - ck_tile::index_t seqlen_k, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale_s, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_lse, - ck_tile::index_t nhead_stride_o, - ck_tile::index_t batch_stride_q, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_lse, - ck_tile::index_t batch_stride_o, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - ck_tile::index_t remap_opt, - const ck_tile::index_t* cu_seqlen_q_ptr, - const ck_tile::index_t* cu_seqlen_kv_ptr) - { - auto kargs = MakeKargs(q_ptr, - k_ptr, - v_ptr, - lse_ptr, - o_ptr, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale_s, - stride_q, - stride_k, - stride_v, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_lse, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_lse, - batch_stride_o, - window_size_left, - window_size_right, - mask_type, - remap_opt); - - kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr; - kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr; - return kargs; - } - template CK_TILE_HOST static constexpr std::enable_if_t MakeKargs(const void* q_ptr, @@ -342,70 +260,6 @@ struct FmhaFwdV3Kernel return kargs; } - // Overload: Group mode with optional padded seqstarts for memory offsets - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - void* lse_ptr, - void* o_ptr, - const void* seqstart_q_ptr, - const void* seqstart_k_ptr, - const void* seqlen_k_ptr, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale_s, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_lse, - ck_tile::index_t nhead_stride_o, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - ck_tile::index_t remap_opt, - const void* seqstart_padded_q_ptr, - const void* seqstart_padded_k_ptr) - { - auto kargs = MakeKargs(q_ptr, - k_ptr, - v_ptr, - lse_ptr, - o_ptr, - seqstart_q_ptr, - seqstart_k_ptr, - seqlen_k_ptr, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale_s, - stride_q, - stride_k, - stride_v, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_lse, - nhead_stride_o, - window_size_left, - window_size_right, - mask_type, - remap_opt); - - kargs.seqstart_padded_q_ptr = reinterpret_cast(seqstart_padded_q_ptr); - kargs.seqstart_padded_k_ptr = reinterpret_cast(seqstart_padded_k_ptr); - return kargs; - } - CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, @@ -519,26 +373,18 @@ struct FmhaFwdV3Kernel if constexpr(kIsGroupMode) { // get starting offset for each batch - const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch]; + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; - const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr - ? kargs.seqstart_padded_q_ptr[i_batch] - : query_start_unpadded; - const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr - ? kargs.seqstart_padded_k_ptr[i_batch] - : key_start_unpadded; - - batch_offset_q = query_start_padded * kargs.stride_q; - batch_offset_k = key_start_padded * kargs.stride_k; - batch_offset_v = key_start_padded * kargs.stride_v; + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + batch_offset_v = key_start * kargs.stride_v; if constexpr(kStoreLSE) { - // LSE layout is [nhead, total_seqlen], index by unpadded start - batch_offset_lse = query_start_unpadded; + batch_offset_lse = query_start; } - batch_offset_o = query_start_padded * kargs.stride_o; + batch_offset_o = query_start * kargs.stride_o; // get real # queries & # keys under group mode const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; @@ -571,18 +417,6 @@ struct FmhaFwdV3Kernel batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; } batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; - - // If cumulative seqlen pointers are provided, override per-batch effective lengths - if(kargs.cu_seqlen_q_ptr != nullptr) - { - kargs.seqlen_q = - kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; - } - if(kargs.cu_seqlen_kv_ptr != nullptr) - { - kargs.seqlen_k = - kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch]; - } } // for simplicity, batch stride we just modify the pointer diff --git a/test/ck_tile/fmha/test_fmha_fwd.inc b/test/ck_tile/fmha/test_fmha_fwd.inc index 66d4e3dc21..08abd3358d 100644 --- a/test/ck_tile/fmha/test_fmha_fwd.inc +++ b/test/ck_tile/fmha/test_fmha_fwd.inc @@ -98,10 +98,7 @@ TEST_P(AllLong, Test) hdim_q, hdim_v, 0, // seqlen_knew - {-1}, // seqlen_qpads {seqlen_kpad}, // seqlen_kpads - {}, // q_eff_lens_per_batch - {}, // kv_eff_lens_per_batch 0, // rotary_dim perm, // i_perm perm, // o_perm @@ -163,10 +160,7 @@ TEST_P(HDimPadding, Test) hdim_q, hdim_v, 0, // seqlen_knew - {-1}, // seqlen_qpads {seqlen_kpad}, // seqlen_kpads - {}, // q_eff_lens_per_batch - {}, // kv_eff_lens_per_batch 0, // rotary_dim perm, // i_perm perm, // o_perm @@ -223,10 +217,7 @@ TEST_P(ElementwiseBias, Test) hdim_q, hdim_v, 0, // seqlen_knew - {-1}, // seqlen_qpads {-1}, // seqlen_kpads - {}, // q_eff_lens_per_batch - {}, // kv_eff_lens_per_batch 0, // rotary_dim i_perm, // i_perm false, // o_perm @@ -282,10 +273,7 @@ TEST_P(Alibi, Test) hdim_q, hdim_v, 0, // seqlen_knew - {-1}, // seqlen_qpads {-1}, // seqlen_kpads - {}, // q_eff_lens_per_batch - {}, // kv_eff_lens_per_batch 0, // rotary_dim true, // i_perm true, // o_perm @@ -343,10 +331,7 @@ TEST_P(Dropout, Test) hdim_q, hdim_v, 0, // seqlen_knew - {-1}, // seqlen_qpads {-1}, // seqlen_kpads - {}, // q_eff_lens_per_batch - {}, // kv_eff_lens_per_batch 0, // rotary_dim false, // i_perm false, // o_perm @@ -406,10 +391,7 @@ TEST_P(PagedKV, Test) hdim_q, hdim_v, 0, // seqlen_knew - {-1}, // seqlen_qpads {-1}, // seqlen_kpads - {}, // q_eff_lens_per_batch - {}, // kv_eff_lens_per_batch 0, // rotary_dim i_perm, // i_perm false, // o_perm @@ -475,10 +457,7 @@ TEST_P(SplitKV, Test) hdim_q, hdim_v, 0, // seqlen_knew - {-1}, // seqlen_qpads {-1}, // seqlen_kpads - {}, // q_eff_lens_per_batch - {}, // kv_eff_lens_per_batch 0, // rotary_dim i_perm, // i_perm false, // o_perm @@ -550,10 +529,7 @@ TEST_P(AppendKV, Test) hdim_q, hdim_v, seqlen_knew, // seqlen_knew - {-1}, // seqlen_qpads {-1}, // seqlen_kpads - {}, // q_eff_lens_per_batch - {}, // kv_eff_lens_per_batch 0, // rotary_dim i_perm, // i_perm true, // o_perm @@ -623,10 +599,7 @@ TEST_P(AppendKVRoPE, Test) hdim_q, hdim_v, seqlen_knew, // seqlen_knew - {-1}, // seqlen_qpads {-1}, // seqlen_kpads - {}, // q_eff_lens_per_batch - {}, // kv_eff_lens_per_batch rotary_dim, // rotary_dim i_perm, // i_perm true, // o_perm @@ -650,117 +623,3 @@ TEST_P(AppendKVRoPE, Test) } #endif // CK_TILE_FMHA_FWD_APPENDKV_API - -// --------------------------------------------------------------- -// Additional padding tests (q/kv physical padding & effective len) -// --------------------------------------------------------------- - -// Simple batch-mode test with per-batch Q/KV padding strides and effective lengths -TEST(TestCkTileFmhaFwd, BatchModeQKvPadding) -{ - if constexpr(std::is_same_v) - { - GTEST_SKIP() << "Skip for fp8"; - } - const mode_enum mode = mode_enum::batch; - const int batch = 3; - const int nhead = 2; - const int nhead_k = -1; - const int seqlen_q = 128; - const int seqlen_k = 128; - const int hdim_q = 64; - const int hdim_v = 64; - const int seqlen_knew = 0; - const std::vector seqlen_qpads{}; - const std::vector seqlen_kpads{}; - const std::vector q_eff_lens{120, 128, 100}; - const std::vector kv_eff_lens{110, 128, 90}; - - auto result = fmha_fwd_run(mode, - batch, - nhead, - nhead_k, - {adjust_seqlen(seqlen_q)}, - {adjust_seqlen(seqlen_k)}, - hdim_q, - hdim_v, - seqlen_knew, // seqlen_knew - seqlen_qpads, // seqlen_qpads - seqlen_kpads, // seqlen_kpads - q_eff_lens, // q_eff_lens_per_batch - kv_eff_lens, // kv_eff_lens_per_batch - 0, // rotary_dim - true, // i_perm - true, // o_perm - 0, // scale_s - 0, // logits_soft_cap - def_is_v_rowmajor, - def_lse, // lse - 0, // page_block_size - false, // use_cache_batch_idx - "n", // bias_str - 0.0f, // p_drop - 0, // drop_seed - 0, // drop_offset - false, // drop_prefs - "0", // mask_str - QUANT_ARGS, - true, // is_rotary_interleaved - 1, // num_splits - COMMON_ARGS); - CHECK_RESULT(result); -} - -// Simple group-mode test with uniform seqlen but per-batch padding & effective lengths -TEST(TestCkTileFmhaFwd, GroupModeQKvPadding) -{ - if constexpr(std::is_same_v) - { - GTEST_SKIP() << "Skip for fp8"; - } - const mode_enum mode = mode_enum::group; - const int batch = 2; - const int nhead = 2; - const int nhead_k = -1; - const std::vector seqlen_q{96, 128}; // unpadded - const std::vector seqlen_k{96, 128}; // unpadded - const int hdim_q = 64; - const int hdim_v = 64; - const int seqlen_knew = 0; - const std::vector seqlen_qpads{128, 160}; - const std::vector seqlen_kpads{128, 160}; - - auto result = fmha_fwd_run(mode, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - seqlen_knew, // seqlen_knew - seqlen_qpads, // seqlen_qpads - seqlen_kpads, // seqlen_kpads - {}, // q_eff_lens_per_batch - {}, // kv_eff_lens_per_batch - 0, // rotary_dim - true, // i_perm - true, // o_perm - 0, // scale_s - 0, // logits_soft_cap - def_is_v_rowmajor, - def_lse, // lse - 0, // page_block_size - false, // use_cache_batch_idx - "n", // bias_str - 0.0f, // p_drop - 0, // drop_seed - 0, // drop_offset - false, // drop_prefs - "0", // mask_str - QUANT_ARGS, - true, // is_rotary_interleaved - 1, // num_splits - COMMON_ARGS); - CHECK_RESULT(result); -}