diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index f88a0efa3b..cbc9f255f1 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -60,7 +60,7 @@ struct FmhaFwdAppendKVKernel _SS_("fmha_fwd_appendkv_d") + _TS_(FmhaPipeline::kK0) + "_" + _SS_(t2s::name) + "_" + (kIsGroupMode ? "group" : "batch") + "_" "b" + _TS_(FmhaPipeline::kM0) + "x" + _TS_(FmhaPipeline::kN0) + "x" + _TS_(FmhaPipeline::kK0) + "x" + - _TS_(FmhaPipeline::kTileSizeDv) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + + _TS_(FmhaPipeline::kN1) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + (!kApplyRoPE ? _SS_("") : (_SS_("_") + RotaryEmbeddingEnumToStr::name)); #undef _SS_ @@ -439,7 +439,7 @@ struct FmhaFwdAppendKVKernel return pad_tensor_view( v_dram_transposed, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), sequence{}); } else @@ -453,7 +453,7 @@ struct FmhaFwdAppendKVKernel return pad_tensor_view( v_dram_naive, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), sequence{}); } }(); @@ -476,7 +476,7 @@ struct FmhaFwdAppendKVKernel return pad_tensor_view( vnew_dram_transposed, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), sequence{}); } else @@ -490,7 +490,7 @@ struct FmhaFwdAppendKVKernel return pad_tensor_view( vnew_dram_naive, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), sequence{}); } }(); @@ -616,15 +616,15 @@ struct FmhaFwdAppendKVKernel make_tuple(number{}, number{}), {i_n0, 0}); - auto v_dram_window = make_tile_window( - v_dram, - make_tuple(number{}, number{}), - {0, kargs.seqlen_k + i_n0}); + auto v_dram_window = + make_tile_window(v_dram, + make_tuple(number{}, number{}), + {0, kargs.seqlen_k + i_n0}); - auto vnew_dram_window = make_tile_window( - vnew_dram, - make_tuple(number{}, number{}), - {0, i_n0}); + auto vnew_dram_window = + make_tile_window(vnew_dram, + make_tuple(number{}, number{}), + {0, i_n0}); if constexpr(kApplyRoPE) { diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp index 5451f25ed0..97c9b960c2 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp @@ -7,15 +7,15 @@ namespace ck_tile { -template +template struct FmhaFwdAppendKVTilePartitioner { - static constexpr ck_tile::index_t kM0 = kM0_; - static constexpr ck_tile::index_t kN0 = kN0_; - static constexpr ck_tile::index_t kK0 = kK0_; - static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN0 = kN0_; + static constexpr ck_tile::index_t kK0 = kK0_; + static constexpr ck_tile::index_t kN1 = kN1_; - static_assert(kK0 == kTileSizeDv); + static_assert(kK0 == kN1); CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp index c28fde636e..a6e012df1e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp @@ -22,10 +22,10 @@ struct BlockFmhaFwdAppendKVPipeline static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t kM0 = Problem::kM0; - static constexpr index_t kN0 = Problem::kN0; - static constexpr index_t kK0 = Problem::kK0; - static constexpr index_t kTileSizeDv = Problem::kTileSizeDv; + static constexpr index_t kM0 = Problem::kM0; + static constexpr index_t kN0 = Problem::kN0; + static constexpr index_t kK0 = Problem::kK0; + static constexpr index_t kN1 = Problem::kN1; static constexpr bool kIsGroupMode = Problem::kIsGroupMode; static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; @@ -94,8 +94,8 @@ struct BlockFmhaFwdAppendKVPipeline KDramBlockWindow& k_dram_block_window, // N0*K0 tile const KnewDramBlockWindow& knew_dram_block_window, // N0*K0 tile const KnewElementFunction& knew_element_func, - VDramBlockWindow& v_dram_block_window, // N1*K1 tile - const VnewDramBlockWindow& vnew_dram_block_window, // N1*K1 tile + VDramBlockWindow& v_dram_block_window, // N1*N0 tile + const VnewDramBlockWindow& vnew_dram_block_window, // N1*N0 tile const VnewElementFunction& vnew_element_func, const QRotaryCosDramBlockWindow q_rotary_cos_dram_block_window, const QRotarySinDramBlockWindow q_rotary_sin_dram_block_window, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp index ae7af58db6..f716e1eed1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp @@ -34,7 +34,7 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::kN0; - constexpr index_t kKPerBlock = Problem::kTileSizeDv; + constexpr index_t kKPerBlock = Problem::kN1; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; // TODO: not correct! @@ -188,7 +188,7 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy using VDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::kTileSizeDv; + constexpr index_t kNPerBlock = Problem::kN1; constexpr index_t kKPerBlock = Problem::kN0; if constexpr(std::is_same_v) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_problem.hpp index 198cdfebb4..3d66b38222 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_problem.hpp @@ -13,7 +13,7 @@ template @@ -27,10 +27,10 @@ struct BlockFmhaFwdAppendKVPipelineProblem static constexpr index_t kBlockSize = 256; static constexpr bool kIsGroupMode = kIsGroupMode_; - static constexpr index_t kM0 = kM0_; - static constexpr index_t kN0 = kN0_; - static constexpr index_t kK0 = kK0_; - static constexpr index_t kTileSizeDv = kTileSizeDv_; + static constexpr index_t kM0 = kM0_; + static constexpr index_t kN0 = kN0_; + static constexpr index_t kK0 = kK0_; + static constexpr index_t kN1 = kN1_; using VLayout = std::conditional_t