diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index e35587b71d..8593975e9c 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -299,10 +299,10 @@ struct FmhaFwdAppendKVKernel __shared__ char smem_ptr[GetSmemSize()]; // divide problem - const auto [i_tile_sk, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v); + const auto [i_tile, i_nhead, i_batch] = TilePartitioner{}(); - const index_t i_sk = __builtin_amdgcn_readfirstlane(i_tile_sk * FmhaPipeline::kTileSizeSk); - // const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kTileSizeS); + const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kTileSizeSk); long_index_t batch_offset_q = 0; long_index_t batch_offset_k = 0; @@ -514,9 +514,8 @@ struct FmhaFwdAppendKVKernel sequence{}); }(); - /// TODO: use tile idx for q return make_tile_window( - rotary_cos_dram, q_rotary_cos_sin_dram_window_lengths, {i_sk, 0}); + rotary_cos_dram, q_rotary_cos_sin_dram_window_lengths, {i_m0, 0}); } else { @@ -540,9 +539,8 @@ struct FmhaFwdAppendKVKernel sequence{}); }(); - /// TODO: use tile idx for q return make_tile_window( - rotary_sin_dram, q_rotary_cos_sin_dram_window_lengths, {i_sk, 0}); + rotary_sin_dram, q_rotary_cos_sin_dram_window_lengths, {i_m0, 0}); } else { @@ -570,7 +568,7 @@ struct FmhaFwdAppendKVKernel }(); return make_tile_window( - rotary_cos_dram, knew_rotary_cos_sin_dram_window_lengths, {i_sk, 0}); + rotary_cos_dram, knew_rotary_cos_sin_dram_window_lengths, {i_n0, 0}); } else { @@ -595,7 +593,7 @@ struct FmhaFwdAppendKVKernel }(); return make_tile_window( - rotary_sin_dram, knew_rotary_cos_sin_dram_window_lengths, {i_sk, 0}); + rotary_sin_dram, knew_rotary_cos_sin_dram_window_lengths, {i_n0, 0}); } else { @@ -603,31 +601,30 @@ struct FmhaFwdAppendKVKernel } }(); - /// TODO: use tile idx for q auto q_dram_window = make_tile_window( q_dram, make_tuple(number{}, number{}), - {0, 0}); + {i_m0, 0}); auto k_dram_window = make_tile_window( k_dram, make_tuple(number{}, number{}), - {kargs.seqlen_k + i_sk, 0}); + {kargs.seqlen_k + i_n0, 0}); auto knew_dram_window = make_tile_window( knew_dram, make_tuple(number{}, number{}), - {i_sk, 0}); + {i_n0, 0}); auto v_dram_window = make_tile_window( v_dram, make_tuple(number{}, number{}), - {0, kargs.seqlen_k + i_sk}); + {0, kargs.seqlen_k + i_n0}); auto vnew_dram_window = make_tile_window( vnew_dram, make_tuple(number{}, number{}), - {0, i_sk}); + {0, i_n0}); if constexpr(kApplyRoPE) { @@ -640,6 +637,8 @@ struct FmhaFwdAppendKVKernel q_rotary_sin_dram_window, knew_rotary_cos_dram_window, knew_rotary_sin_dram_window, + kargs.seqlen_q <= i_m0, + kargs.seqlen_knew <= i_n0, smem_ptr, kargs.rotary_dim); } @@ -654,6 +653,8 @@ struct FmhaFwdAppendKVKernel q_rotary_sin_dram_window, knew_rotary_cos_dram_window, knew_rotary_sin_dram_window, + kargs.seqlen_q <= i_m0, + kargs.seqlen_knew <= i_n0, smem_ptr); } } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp index f63d13b4cb..07e73359d6 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp @@ -19,35 +19,23 @@ struct FmhaFwdAppendKVTilePartitioner CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead, - ck_tile::index_t seqlen_knew, - ck_tile::index_t hdim_v) + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_knew) { - assert(ck_tile::integer_divide_ceil(hdim_v, kTileSizeD) == 1); -#ifdef NDEBUG - ignore = hdim_v; -#endif - // TODO: this may need tuning - return dim3(ck_tile::integer_divide_ceil(seqlen_knew, kTileSizeSk), nhead, batch_size); + return dim3(std::max(ck_tile::integer_divide_ceil(seqlen_q, kTileSizeS), + ck_tile::integer_divide_ceil(seqlen_knew, kTileSizeSk)), + nhead, + batch_size); } - CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t /*hdim_v*/) + CK_TILE_DEVICE auto operator()() { - // const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1); - - const index_t i_block = blockIdx.x; + const index_t i_tile = blockIdx.x; const index_t i_nhead = blockIdx.y; const index_t i_batch = blockIdx.z; - const auto f = [](index_t dividend, index_t divisor) { - index_t quotient = dividend / divisor; - index_t modulus = dividend - quotient * divisor; - return ck_tile::make_tuple(quotient, modulus); - }; - (void)f; - // const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); - - return ck_tile::make_tuple(i_block, i_nhead, i_batch); + return ck_tile::make_tuple(i_tile, i_nhead, i_batch); } };