From f5e00ec9049f2d87b021063c21210584de4b3f82 Mon Sep 17 00:00:00 2001 From: Hosang Yoon <156028780+hyoon1@users.noreply.github.com> Date: Sat, 18 Apr 2026 02:44:46 -0400 Subject: [PATCH] [CK_TILE] Skip padded k/n fragment work in qr_hpad FMHA fwd (#6450) ## Motivation `qr_hpad` currently executes work for padded head-dim fragments even when only a subset of the values are valid. This adds unnecessary computation for head dimensions that require padding, such as `hdim=72` and `hdim=80`, and hurts FMHA forward performance. The goal of this PR is to make the padded-head-dim path skip invalid work based on the actual valid fragment count, while preserving the existing behavior for the non-padded path. ## Technical Details This PR improves the `qr_hpad` FMHA forward path in three parts: - Skip padded `k`/`n` fragments in the GEMM/pipeline path when only part of the fragment is valid. - Add partial GEMM0 tail handling for `qr_hpad` so the kernel uses the valid fragment range instead of always computing over the padded extent. - Retune the gfx11 `qr_hpad` kernel configuration after enabling the partial-fragment path. To keep the existing path stable, the implementation adds overloads for the updated GEMM/pipeline interfaces. This allows existing full-tile callers to keep using the previous form, while the `qr_hpad` path can pass valid fragment counts when needed. ## Test Plan ./build/bin/tile_example_fmha_fwd -prec=bf16 -mode={0/1} -b=1 -h=16 -d={72/80} -s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1} ## Test Result - On gfx11 and gfx12, for head dimensions that require padding, `tile_example_fmha_fwd` shows about 20-30% performance improvement at `hdim=72/80`. ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 22 +- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 182 ++++++---- .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 329 ++++++++++++++++-- .../block/block_gemm_areg_bsmem_creg_v2.hpp | 73 +++- 4 files changed, 478 insertions(+), 128 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index c64a19104e..978c9d0a75 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1194,18 +1194,15 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory): if (problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128): return True - is_64x32_tile = kernel_ctx.tile.F_bm0 == 64 and kernel_ctx.tile.F_bn0 == 32 - pads_hdim = ( - kernel_ctx.pipeline.F_dpad == "t" and kernel_ctx.pipeline.F_dvpad == "t" - ) - exact_hdim = ( - kernel_ctx.pipeline.F_dpad == "f" and kernel_ctx.pipeline.F_dvpad == "f" - ) + # For (128, 128) head dims, partial-fragment support in qr_hpad removes the need + # for the previous qr_hpad-specific handling that was added to avoid register spill. + # qr_hpad now reuses the regular 128x64 tile choice. + # The 64x64 tile remains disabled for qr_hpad because it is consistently slower + # in our measurements. 
+ if kernel_ctx.tile.F_bm0 == 64 and kernel_ctx.tile.F_bn0 == 64: + return kernel_ctx.pipeline.tag != "qr_hpad" - if is_64x32_tile: - return pads_hdim - - return exact_hdim + return True rules.append(check_d128_tile_pipeline) return rules @@ -1218,8 +1215,7 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory): ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], ( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")), FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, -1)], - (128, 128) : [FmhaFwdTileSize( 64, 32, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, 6, CppConstraint("a.hdim_q != 128 || a.hdim_v != 128")), - FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")), + (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 2048")), FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)], (192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], (256, 256) : [FmhaFwdTileSize(128, 64, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)] diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 16f5b00bb1..b04205f2c2 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -39,6 +39,9 @@ struct FmhaFwdKernel using EpiloguePipeline = ck_tile::remove_cvref_t; static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + template + using has_hdim_tail_args = decltype(T::kUseHdimTailArgs); + static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; static_assert(kBlockPerCu > 0); static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; @@ -1891,6 +1894,35 @@ struct FmhaFwdKernel }(); BlockIndices block_indices{i_batch, i_nhead, i_nhead_k}; + constexpr bool kPassHdimTailArgs = [] { + if constexpr(ck_tile::is_detected::value) + return static_cast(FmhaPipeline::kUseHdimTailArgs); + else + return false; + }(); + auto invoke_fmha_pipeline = [&](auto&&... 
args) -> decltype(auto) { + if constexpr(kPassHdimTailArgs) + { + const ck_tile::index_t valid_k0_loops = + ck_tile::integer_divide_ceil(kargs.hdim_q, FmhaPipeline::kK0); + const ck_tile::index_t valid_last_k0_length = + kargs.hdim_q - (valid_k0_loops - 1) * FmhaPipeline::kK0; + const ck_tile::index_t valid_n1_length = [&]() { + const ck_tile::index_t remaining_n1 = kargs.hdim_v - i_n1; + return ck_tile::min(remaining_n1, + static_cast(FmhaPipeline::kN1)); + }(); + return FmhaPipeline{}(static_cast(args)..., + sink_value, + valid_k0_loops, + valid_last_k0_length, + valid_n1_length); + } + else + { + return FmhaPipeline{}(static_cast(args)..., sink_value); + } + }; auto o_acc_tile = [&, i_nhead_ = i_nhead, i_nhead_k_ = i_nhead_k]() { if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) @@ -1910,36 +1942,35 @@ struct FmhaFwdKernel else return ck_tile::scales>{scale_o}; }(); - return FmhaPipeline{}(q_dram_window, - identity{}, // q_element_func - k_dram_window, - identity{}, // k_element_func - v_dram_window, - identity{}, // v_element_func - bias_dram_window, - identity{}, // bias_element_func - randval_dram_window, - lse_dram_window, - identity{}, // lse_element_func - identity{}, // s_acc_element_func - scales>{ - scale_p}, // p_compute_element_func - o_acc_element_func, // o_acc_element_func - mask, - position_encoding, - variant_params.sm_scale, - variant, - variant_params, - block_indices, - smem_ptr, - dropout, - nullptr, - nullptr, - 1, - make_null_tile_window(make_tuple()), - make_null_tile_window(make_tuple()), - make_null_tile_window(make_tuple()), - sink_value); + return invoke_fmha_pipeline(q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + bias_dram_window, + identity{}, // bias_element_func + randval_dram_window, + lse_dram_window, + identity{}, // lse_element_func + identity{}, // s_acc_element_func + scales>{ + scale_p}, // p_compute_element_func + o_acc_element_func, // o_acc_element_func + mask, + position_encoding, + variant_params.sm_scale, + variant, + variant_params, + block_indices, + smem_ptr, + dropout, + nullptr, + nullptr, + 1, + make_null_tile_window(make_tuple()), + make_null_tile_window(make_tuple()), + make_null_tile_window(make_tuple())); } else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) { @@ -1964,7 +1995,7 @@ struct FmhaFwdKernel // Both P and rowsum are scaled by 2^shift, canceling in normalization // No additional scaling needed in p_compute_element_func or o_acc_element_func - return FmhaPipeline{}( + return invoke_fmha_pipeline( q_dram_window, identity{}, // q_element_func k_dram_window, @@ -1992,8 +2023,7 @@ struct FmhaFwdKernel kargs.block_scale_size_kv, make_null_tile_window(make_tuple()), make_null_tile_window(make_tuple()), - make_null_tile_window(make_tuple()), - sink_value); + make_null_tile_window(make_tuple())); } else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX) { @@ -2098,53 +2128,51 @@ struct FmhaFwdKernel number{}), {i_n1, 0}); - return FmhaPipeline{}(q_dram_window, - identity{}, // q_element_func - k_dram_window, - identity{}, // k_element_func - v_dram_window, - identity{}, // v_element_func - bias_dram_window, - identity{}, // bias_element_func - randval_dram_window, - lse_dram_window, - identity{}, // lse_element_func - identity{}, // s_acc_element_func - identity{}, // p_compute_element_func - identity{}, // o_acc_element_func - mask, - position_encoding, - kargs.scale_s, - variant, - 
variant_params, - block_indices, - smem_ptr, - dropout, - nullptr, - nullptr, - 1, - q_scale_dram_window, - k_scale_dram_window, - v_scale_dram_window, - sink_value); + return invoke_fmha_pipeline(q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + bias_dram_window, + identity{}, // bias_element_func + randval_dram_window, + lse_dram_window, + identity{}, // lse_element_func + identity{}, // s_acc_element_func + identity{}, // p_compute_element_func + identity{}, // o_acc_element_func + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout, + nullptr, + nullptr, + 1, + q_scale_dram_window, + k_scale_dram_window, + v_scale_dram_window); } else { - return FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - bias_dram_window, - randval_dram_window, - lse_dram_window, - mask, - position_encoding, - variant_params.sm_scale, - variant, - variant_params, - block_indices, - smem_ptr, - dropout, - sink_value); + return invoke_fmha_pipeline(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + randval_dram_window, + lse_dram_window, + mask, + position_encoding, + variant_params.sm_scale, + variant, + variant_params, + block_indices, + smem_ptr, + dropout); } }(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 48c79177d4..9b932462d0 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -39,6 +39,11 @@ struct BlockFmhaPipelineQRKSVS using AttentionVariant = remove_cvref_t; using FmhaMask = remove_cvref_t; + template + using has_partial_k_support = decltype(T::kSupportsPartialK); + template + using has_partial_n_support = decltype(T::kSupportsPartialN); + using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once @@ -68,6 +73,7 @@ struct BlockFmhaPipelineQRKSVS static constexpr auto QScaleEnum = Problem::QScaleEnum; static constexpr bool kHasSink = Problem::kHasSink; static constexpr bool kPaddedVecLoadStore = PaddedVecLoadStore_; + static constexpr bool kUseHdimTailArgs = kPadHeadDimQ || kPadHeadDimV; static constexpr ck_tile::index_t kQKScaleGranularity = Problem::kQKScaleGranularity; static constexpr ck_tile::index_t kVScaleGranularity = Problem::kVScaleGranularity; @@ -203,7 +209,10 @@ struct BlockFmhaPipelineQRKSVS k_scale_dram_block_window_tmp, // N0*(K0/kQKScaleGranularity) tile const VScaleDramBlockWindowTmp& v_scale_dram_block_window_tmp, // N1*(K1/kVScaleGranularity) tile - const float sink_v) const + const float sink_v, + const index_t valid_k0_loops, + const index_t valid_last_k0_length, + const index_t valid_n1_length) const { static_assert( std::is_same_v> && @@ -261,8 +270,30 @@ struct BlockFmhaPipelineQRKSVS v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); // Block GEMM - constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); - constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + using BlockGemm0 = remove_cvref_t; + using BlockGemm1 = remove_cvref_t; + constexpr bool kBlockGemm0SupportsPartialK = [] { + if 
constexpr(ck_tile::is_detected::value) + return static_cast(BlockGemm0::kSupportsPartialK); + else + return false; + }(); + constexpr bool kBlockGemm1SupportsPartialN = [] { + if constexpr(ck_tile::is_detected::value) + return static_cast(BlockGemm1::kSupportsPartialN); + else + return false; + }(); + + constexpr auto gemm_0_config = + BlockGemm0::Policy::template GetWarpGemmMWarpNWarp(); + using Gemm0WarpGemm = remove_cvref_t())>; + constexpr index_t kGemm0WarpK = Gemm0WarpGemm::kK; + constexpr index_t kGemm0KItersPerBlock = kK0 / kGemm0WarpK; + constexpr bool kUsePartialKForGemm0Tail = + kPadHeadDimQ && kBlockGemm0SupportsPartialK && (kGemm0KItersPerBlock > 1); auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), q_dram_block_window_tmp.get_window_lengths(), @@ -428,10 +459,26 @@ struct BlockFmhaPipelineQRKSVS index_t i_total_loops = 0; constexpr index_t k0_loops = kQKHeaddim / kK0; constexpr index_t k1_loops = kN0 / kK1; + // Number of k0 iterations prefetched ahead of the current compute iteration. + // The skip decision must be made this many iterations before the last k0 loop. + constexpr index_t kK0PrefetchDepth = 2; + const index_t gemm0_tail_k_iters = [&]() { + if constexpr(kUsePartialKForGemm0Tail) + { + return ck_tile::integer_divide_ceil(valid_last_k0_length, kGemm0WarpK); + } + return static_cast(kGemm0KItersPerBlock); + }(); + const bool skip_last_k0_loop = [&]() { + if constexpr(kPadHeadDimQ) + { + return valid_k0_loops == (k0_loops - 1); + } + return false; + }(); // Use compile-time conditional for group barrier sequence // (No runtime lambda selection) auto schedule_gemm_0 = [] { - using BlockGemm0 = remove_cvref_t; constexpr auto WarpGemmConfig = BlockGemm0::Policy::template GetWarpGemmMWarpNWarp(); using WarpGemm0 = remove_cvref_t())>; @@ -456,7 +503,7 @@ struct BlockFmhaPipelineQRKSVS } }; - static_assert(2 <= k0_loops); + static_assert(kK0PrefetchDepth <= k0_loops); static_assert(1 <= k1_loops); do { @@ -523,6 +570,46 @@ struct BlockFmhaPipelineQRKSVS } auto run_gemm_0 = [&](auto i_k0) { + if constexpr(kUsePartialKForGemm0Tail) + { + if(static_cast(i_k0.value) == (valid_k0_loops - 1) && + gemm0_tail_k_iters < kGemm0KItersPerBlock) + { + static_for<1, kGemm0KItersPerBlock, 1>{}([&](auto i_tail_k_iter) { + constexpr index_t kTailKIters = i_tail_k_iter; + constexpr index_t kTailK0 = kTailKIters * kGemm0WarpK; + + if(gemm0_tail_k_iters == kTailKIters) + { + using Gemm0TailProblem = BlockGemmProblem< + QDataType, + KDataType, + SaccDataType, + Problem::kNumGemm0Warps * get_warp_size(), + TileGemmShape< + sequence, + typename BlockFmhaShape::Gemm0BlockWarps, + sequence{}), + BlockFmhaShape::Gemm0WarpTile::at(number<1>{}), + kGemm0WarpK>>>; + constexpr auto gemm_0_tail = + BlockGemmARegBSmemCRegV2{}; + + auto q_slice = + get_slice_tile(q_tile, + sequence<0, i_k0 * kK0>{}, + sequence{}); + auto k_tail_window = make_tile_window( + k_lds, make_tuple(number{}, number{}), {0, 0}); + + gemm_0_tail(s_acc, q_slice, k_tail_window); + } + }); + return; + } + } + auto q_slice = get_slice_tile( q_tile, sequence<0, i_k0 * kK0>{}, sequence{}); if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX) @@ -540,19 +627,37 @@ struct BlockFmhaPipelineQRKSVS } }; - if constexpr(k0_loops > 2) + if constexpr(k0_loops > kK0PrefetchDepth) { - static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { + static_for<0, k0_loops - kK0PrefetchDepth, 1>{}([&](auto i_k0) { block_sync_lds(); run_gemm_0(number{}); block_sync_lds(); - move_tile_window(k_dram_window, {0, kK0}); + 
if constexpr(kPadHeadDimQ && i_k0 == (k0_loops - 1 - kK0PrefetchDepth)) + { + if(!skip_last_k0_loop) + { + move_tile_window(k_dram_window, {0, kK0}); + } - store_tile( - k_lds_window, - tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1 - k_block_tile = load_tile(k_dram_window); // global read i + 2 + store_tile( + k_lds_window, + tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1 + if(!skip_last_k0_loop) + { + k_block_tile = load_tile(k_dram_window); // global read i + 2 + } + } + else + { + move_tile_window(k_dram_window, {0, kK0}); + + store_tile( + k_lds_window, + tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1 + k_block_tile = load_tile(k_dram_window); // global read i + 2 + } k_scale_block_tile = load_k_scale_block_tile(); }); } @@ -577,16 +682,19 @@ struct BlockFmhaPipelineQRKSVS } { // tail block_sync_lds(); - run_gemm_0(number{}); - block_sync_lds(); + run_gemm_0(number{}); + if(!skip_last_k0_loop) + { + block_sync_lds(); - store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); - k_scale_block_tile = load_k_scale_block_tile(); + k_scale_block_tile = load_k_scale_block_tile(); - block_sync_lds(); + block_sync_lds(); - run_gemm_0(number{}); + run_gemm_0(number{}); + } } if constexpr(kVPrefetch == VPrefetchPoint::AfterGemm0Tail) { @@ -933,6 +1041,31 @@ struct BlockFmhaPipelineQRKSVS auto o_acc0 = decltype(o_acc){}; clear_tile(o_acc0); + constexpr auto gemm_1_config = + BlockGemm1::Policy::template GetWarpGemmMWarpNWarp(); + using Gemm1WarpGemm = remove_cvref_t())>; + constexpr index_t kGemm1NWarp = gemm_1_config.template at<2>(); + constexpr index_t kGemm1NPerIter = kGemm1NWarp * Gemm1WarpGemm::kN; + const index_t valid_n_iters = [&]() { + if constexpr(kPadHeadDimV && kBlockGemm1SupportsPartialN) + { + return ck_tile::integer_divide_ceil(valid_n1_length, kGemm1NPerIter); + } + return static_cast(0); + }(); + + auto run_gemm_1_impl = + [&](auto& o_acc_tensor, const auto& p_slice, const auto&... 
gemm_1_args) { + if constexpr(kPadHeadDimV && kBlockGemm1SupportsPartialN) + { + gemm_1(o_acc_tensor, p_slice, gemm_1_args..., valid_n_iters); + } + else + { + gemm_1(o_acc_tensor, p_slice, gemm_1_args...); + } + }; + auto run_gemm_1 = [&](auto i_k1) { auto p_slice = get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence{}); @@ -942,15 +1075,19 @@ struct BlockFmhaPipelineQRKSVS get_slice_tile(p_scale, sequence<0, i_k1*(kK1 / kVScaleGranularity)>{}, sequence{}); - gemm_1(o_acc, p_slice, p_scale_slice, v_lds_window, v_scale_block_tile); - } - else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - gemm_1(o_acc0, p_slice, v_lds_window); + run_gemm_1_impl( + o_acc, p_slice, p_scale_slice, v_lds_window, v_scale_block_tile); } else { - gemm_1(o_acc, p_slice, v_lds_window); + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + run_gemm_1_impl(o_acc0, p_slice, v_lds_window); + } + else + { + run_gemm_1_impl(o_acc, p_slice, v_lds_window); + } } }; @@ -1075,6 +1212,94 @@ struct BlockFmhaPipelineQRKSVS return o_acc; } + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + const LSEElementFunction& lse_element_func, + const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, + void* smem_ptr, + DropoutType& dropout, + const float* k_descale_ptr, + const float* v_descale_ptr, + const index_t block_scale_size_kv, + const QScaleDramBlockWindowTmp& + q_scale_dram_block_window_tmp, // M0*(K0/kQKScaleGranularity) tile + const KScaleDramBlockWindowTmp& + k_scale_dram_block_window_tmp, // N0*(K0/kQKScaleGranularity) tile + const VScaleDramBlockWindowTmp& + v_scale_dram_block_window_tmp, // N1*(K1/kVScaleGranularity) tile + const float sink_v) const + { + return operator()(q_dram_block_window_tmp, + q_element_func, + k_dram_block_window_tmp, + k_element_func, + v_dram_block_window_tmp, + v_element_func, + bias_dram_block_window_tmp, + bias_element_func, + randval_dram_block_window_tmp, + lse_dram_window_tmp, + lse_element_func, + s_acc_element_func, + p_compute_element_func, + o_acc_element_func, + mask, + position_encoding, + scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout, + k_descale_ptr, + v_descale_ptr, + block_scale_size_kv, + q_scale_dram_block_window_tmp, + k_scale_dram_block_window_tmp, + v_scale_dram_block_window_tmp, + sink_v, + kQKHeaddim / kK0, + kK0, + kN1); + } + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + 
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, + void* smem_ptr, + DropoutType& dropout, + const float sink_v) const + { + return operator()(q_dram_block_window_tmp, + k_dram_block_window_tmp, + v_dram_block_window_tmp, + bias_dram_block_window_tmp, + randval_dram_block_window_tmp, + lse_dram_block_window_tmp, + mask, + position_encoding, + scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout, + sink_v, + kQKHeaddim / kK0, + kK0, + kN1); } }; diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp index d292cade24..de5ba747d3 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp @@ -21,13 +21,18 @@ struct BlockGemmARegBSmemCRegV2 using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr bool kSupportsPartialK = true; + static constexpr bool kSupportsPartialN = true; - // C += A * B - template - CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, - const ABlockTensorTmp& a_block_tensor_tmp, - const BBlockWindowTmp& b_block_window_tmp) const + template + CK_TILE_DEVICE void Impl(CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp, + [[maybe_unused]] const index_t valid_n_iters) const { static_assert( std::is_same_v> && @@ -134,10 +139,7 @@ struct BlockGemmARegBSmemCRegV2 constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - // hot loop: - static_ford>{}([&](auto kn) { - constexpr auto kIter = number{}]>{}; - constexpr auto nIter = number{}]>{}; + auto run_n_iter = [&](auto kIter, auto nIter) { // read B warp tensor from B Block window const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); @@ -166,7 +168,44 @@ struct BlockGemmARegBSmemCRegV2 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), c_warp_tensor.get_thread_buffer()); }); - }); + }; + + // hot loop: + if constexpr(UsePartialN) + { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + if(static_cast(nIter.value) < valid_n_iters) + { + run_n_iter(kIter, nIter); + } + }); + }); + } + else + { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { run_n_iter(kIter, nIter); }); + }); + } + } + + // C += A * B (executing only the first valid_n_iters N sub-iterations) + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp, + const index_t valid_n_iters) const + { + Impl(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp, valid_n_iters); + } + + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + Impl(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp, 0); } template @@ -227,7 +266,17 @@ struct 
BlockGemmARegBSmemCRegV2 return c_block_tensor; } - // C = A * B + // C = A * B (executing only the first valid_n_iters N sub-iterations) + template + CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp, + const index_t valid_n_iters) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp, valid_n_iters); + return c_block_tensor; + } + template CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, const BBlockWindowTmp& b_block_window_tmp) const
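
## Implementation Notes

The overload compatibility described above rests on the detection idiom: the kernel probes the pipeline for `kUseHdimTailArgs` (and the pipeline in turn probes its block GEMMs for `kSupportsPartialK` / `kSupportsPartialN`) via `is_detected`, so any type that predates the flag silently keeps the full-tile call. The sketch below is self-contained and reimplements `is_detected` itself; the patch uses `ck_tile::is_detected` with the flag names shown.

```cpp
// Minimal sketch of the detection idiom gating the new tail arguments.
// is_detected is reimplemented here so the example is self-contained;
// the patch itself uses ck_tile::is_detected.
#include <type_traits>

template <typename T, template <typename> class Op, typename = void>
struct is_detected : std::false_type {};

template <typename T, template <typename> class Op>
struct is_detected<T, Op, std::void_t<Op<T>>> : std::true_type {};

template <typename T>
using has_hdim_tail_args = decltype(T::kUseHdimTailArgs);

struct LegacyPipeline {};  // no flag -> detection fails -> full-tile call
struct HpadPipeline { static constexpr bool kUseHdimTailArgs = true; };

template <typename Pipeline>
constexpr bool pass_hdim_tail_args()
{
    if constexpr (is_detected<Pipeline, has_hdim_tail_args>::value)
        return static_cast<bool>(Pipeline::kUseHdimTailArgs);
    else
        return false;
}

static_assert(!pass_hdim_tail_args<LegacyPipeline>());
static_assert(pass_hdim_tail_args<HpadPipeline>());

int main() {}
```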
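The tail bookkeeping computed in `invoke_fmha_pipeline` is plain ceiling-division arithmetic, and the new legacy-forwarding overloads simply pass the full-tile values `(kQKHeaddim / kK0, kK0, kN1)`. A worked sketch follows, using a hypothetical standalone function that mirrors the kernel variables: for `hdim=72` padded to `kQKHeaddim=128` with `kK0=32`, only 3 of 4 k0 loops carry valid data and only 8 of the last fragment's 32 columns are real, so `skip_last_k0_loop` fires; `hdim=80` behaves the same with a 16-column tail.

```cpp
// Sketch of the valid-fragment arithmetic in invoke_fmha_pipeline.
// compute_tail is a hypothetical standalone mirror of the kernel variables.
#include <algorithm>
#include <cstdio>

struct TailInfo { int valid_k0_loops, valid_last_k0_length, valid_n1_length; };

TailInfo compute_tail(int hdim_q, int hdim_v, int i_n1, int kK0, int kN1)
{
    const int valid_k0_loops       = (hdim_q + kK0 - 1) / kK0; // integer_divide_ceil
    const int valid_last_k0_length = hdim_q - (valid_k0_loops - 1) * kK0;
    const int valid_n1_length      = std::min(hdim_v - i_n1, kN1);
    return {valid_k0_loops, valid_last_k0_length, valid_n1_length};
}

int main()
{
    // hdim=72 padded to kQKHeaddim=128, kK0=32, kN1=128:
    // 3 valid k0 loops (k0_loops = 128/32 = 4) and an 8-column tail,
    // so skip_last_k0_loop == (valid_k0_loops == k0_loops - 1) holds.
    const TailInfo t = compute_tail(/*hdim_q=*/72, /*hdim_v=*/72, /*i_n1=*/0,
                                    /*kK0=*/32, /*kN1=*/128);
    std::printf("%d %d %d\n", t.valid_k0_loops, t.valid_last_k0_length, t.valid_n1_length);
}
```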
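`run_gemm_0`'s tail path needs the reduced K extent as a compile-time constant (it shapes the Q slice and the partial `BlockGemmProblem`), yet `gemm0_tail_k_iters` is a runtime value. The patch bridges the two by unrolling over every candidate count and executing the one whose constant matches at runtime. A simplified sketch, with `std::integral_constant` standing in for `ck_tile::number` and a recursive `static_for` in place of the `ck_tile` one:

```cpp
// Sketch of dispatching a runtime tail count to a compile-time K extent,
// mirroring the static_for<1, kGemm0KItersPerBlock, 1> loop in run_gemm_0.
#include <cstdio>
#include <utility>

template <int I, int N, typename F>
void static_for(F&& f)
{
    if constexpr (I < N)
    {
        f(std::integral_constant<int, I>{});
        static_for<I + 1, N>(f);
    }
}

template <int KItersPerBlock, int WarpK>
void run_gemm0_tail(int tail_k_iters) // precondition: 1 <= tail_k_iters < KItersPerBlock
{
    static_for<1, KItersPerBlock>([&](auto i) {
        constexpr int kTailKIters = decltype(i)::value;
        if (tail_k_iters == kTailKIters)
        {
            // kTailK0 is usable as a template argument here, e.g. to slice
            // the Q tile and shape the partial-K GEMM problem.
            constexpr int kTailK0 = kTailKIters * WarpK;
            std::printf("partial gemm_0 over %d of %d K columns\n",
                        kTailK0, KItersPerBlock * WarpK);
        }
    });
}

int main()
{
    // kK0=32 with a 16-wide warp GEMM gives 2 K iterations per block;
    // an 8-column tail rounds up to 1 valid iteration.
    run_gemm0_tail</*KItersPerBlock=*/2, /*WarpK=*/16>(/*tail_k_iters=*/1);
}
```

The full-count case never reaches this dispatch: when `gemm0_tail_k_iters == kGemm0KItersPerBlock`, the guard in `run_gemm_0` falls through to the regular full-fragment GEMM.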
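On the `gemm_1` side, `BlockGemmARegBSmemCRegV2` keeps its fully unrolled k/n iteration nest and merely guards each n sub-iteration with a runtime comparison, so padded `hdim_v` fragments cost no MMA work while the full-tile path compiles exactly as before. The hot loop reduces to the pattern below (self-contained sketch; `run_n_iter` stands in for the warp-GEMM body):

```cpp
// Sketch of the UsePartialN hot loop in BlockGemmARegBSmemCRegV2::Impl:
// the k/n nest stays compile-time unrolled; a runtime guard skips the
// n sub-iterations that fall entirely inside the padded region.
#include <cstdio>
#include <utility>

template <int I, int N, typename F>
void static_for(F&& f)
{
    if constexpr (I < N)
    {
        f(std::integral_constant<int, I>{});
        static_for<I + 1, N>(f);
    }
}

template <bool UsePartialN, int KIterPerWarp, int NIterPerWarp, typename RunFn>
void hot_loop(RunFn&& run_n_iter, int valid_n_iters)
{
    static_for<0, KIterPerWarp>([&](auto kIter) {
        static_for<0, NIterPerWarp>([&](auto nIter) {
            if constexpr (UsePartialN)
            {
                if (int(nIter) < valid_n_iters) // skip padded n fragments
                    run_n_iter(kIter, nIter);
            }
            else
            {
                run_n_iter(kIter, nIter);       // full-tile path, unchanged
            }
        });
    });
}

int main()
{
    auto body = [](auto k, auto n) { std::printf("mma k=%d n=%d\n", int(k), int(n)); };
    // e.g. hdim_v=72 over a 128-wide N1 tile: 3 of 4 n fragments are valid
    hot_loop</*UsePartialN=*/true, /*KIterPerWarp=*/2, /*NIterPerWarp=*/4>(body, 3);
}
```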
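Finally, because the k0 pipeline prefetches `kK0PrefetchDepth = 2` iterations ahead, skipping the last (fully padded) k0 loop has to be decided that many iterations early: the window advance and global read for the padded fragment are suppressed inside the main loop, and the tail then runs only its valid iteration. A hypothetical skeleton of that control flow, with stub functions in place of the tile primitives and the `block_sync_lds` calls omitted:

```cpp
// Hypothetical skeleton of the k0 main loop with the early skip decision.
// compute/store_to_lds/load_from_global are stand-ins for the real tile
// primitives; LDS synchronization is omitted for brevity.
#include <cstdio>

void compute(int i)          { std::printf("gemm_0 on k0 iter %d\n", i); }
void store_to_lds(int i)     { std::printf("LDS write %d\n", i); }
void load_from_global(int i) { std::printf("global read %d\n", i); }

int main()
{
    constexpr int kK0PrefetchDepth = 2;  // global reads run two iterations ahead
    constexpr int k0_loops = 4;          // e.g. kQKHeaddim=128, kK0=32
    const bool skip_last_k0_loop = true; // hdim_q=72 -> valid_k0_loops == k0_loops - 1

    for (int i = 0; i < k0_loops - kK0PrefetchDepth; ++i)
    {
        compute(i);           // consume the LDS contents of iteration i
        store_to_lds(i + 1);  // stage the tile fetched last round
        if (!(skip_last_k0_loop && i + kK0PrefetchDepth == k0_loops - 1))
            load_from_global(i + kK0PrefetchDepth); // never read past hdim_q
    }
    compute(k0_loops - 2);    // tail: always valid
    if (!skip_last_k0_loop)
    {
        store_to_lds(k0_loops - 1);
        compute(k0_loops - 1); // only when the last fragment holds real data
    }
}
```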