diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp index 7eddc3cede..79a114770e 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp @@ -332,15 +332,15 @@ struct FmhaFwdSplitKVCombineKernel make_tuple(number<1>{}, number{}, number{}), sequence{}); - const index_t new_seqlen_q = - integer_divide_ceil(kargs.seqlen_q, FmhaPipeline::kM0) * FmhaPipeline::kM0; - const index_t new_hdim_v = - integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1) * FmhaPipeline::kN1; + const index_t padded_max_seqlen_q = + o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}]; + const index_t padded_hdim_v = + o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}]; return transform_tensor_view( o_acc_dram_view, - make_tuple(make_merge_transform(make_tuple(kargs.num_splits, new_seqlen_q)), - make_pass_through_transform(new_hdim_v)), + make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_max_seqlen_q)), + make_pass_through_transform(padded_hdim_v)), make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0>{}, sequence<1>{})); }(); @@ -399,6 +399,7 @@ struct FmhaFwdSplitKVCombineKernel composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func smem_ptr, kargs.num_splits, + kargs.seqlen_q, kargs.max_seqlen_q); } else @@ -408,6 +409,7 @@ struct FmhaFwdSplitKVCombineKernel lse_dram_window, smem_ptr, kargs.num_splits, + kargs.seqlen_q, kargs.max_seqlen_q); } }(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp index 0e11cbb98d..af22256883 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -75,11 +75,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline return (kM0 * kMaxSplits * sizeof(LSEDataType)); } -#define MARKER(msg) \ - __builtin_amdgcn_sched_barrier(0); \ - asm volatile("; [POYENC] " msg ::); \ - __builtin_amdgcn_sched_barrier(0) - template {}); const auto col = x_indices.at(number<1>{}); - lse_acc_lds_ptr[row + col * kMaxSplits] = lse_acc(distributed_indices); + if(row < num_splits && col < real_seqlen_q) + { + lse_acc_lds_ptr[row + col * kMaxSplits] = lse_acc(distributed_indices); + } + else + { + lse_acc_lds_ptr[row + col * kMaxSplits] = -numeric::infinity(); + } }); }); } @@ -138,7 +141,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline auto lse_accum_dist = Policy::template MakeLSEaccTDramTileDistribution(); auto lse_accum = make_static_distributed_tensor(lse_accum_dist); - // copy LDS to lse_accum + // copy LDS to lse_accum (transpose) { using DataType = LSEDataType; using StaticTileDistribution = decltype(lse_accum_dist); @@ -162,14 +165,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline { lse_accum(distributed_indices) = -numeric::infinity(); } - - DEBUG_STMTS - { - printf("[POYENC][DEVICE] lse_accum[%2d,%2d]: %11.7f\n", - row, - col, - lse_accum(distributed_indices)); - } }); }); } @@ -204,16 +199,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline static const auto get_validated_m = [](LSEDataType raw_m) { /// NOTICE: bias might be materialized mask including -inf values, need /// consideration - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - FmhaMask::IsMasking) - { - return raw_m == -numeric::infinity() ? type_convert(0.f) - : raw_m; - } - else - { - return raw_m; - } + return raw_m == -numeric::infinity() ? type_convert(0.f) + : raw_m; }; auto p_compute = make_static_distributed_tensor( @@ -240,15 +227,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline p_compute(i_j_idx) = ck_tile::exp(lse_acc_lds_ptr[col + row * kMaxSplits] - get_validated_m(lse_max(i_idx))); } -#endif -#if 0 - DEBUG_STMTS - { - printf("[POYENC][DEVICE] p_compute[%2d,%2d]: %11.7f\n", - row, - col, - p_compute(i_j_idx)); - } #endif }); }); @@ -312,21 +290,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline }); } -#if defined(PRINT_LSE_ACCUM) - DEBUG_STMTS - { - for(index_t row = 0; row < kM0; ++row) - { - printf("[POYENC][DEVICE] lse_accum[%d] = ", row); - for(index_t col = 0; col < num_splits; ++col) - { - printf("%11.7f", lse_acc_lds_ptr[col + row * kMaxSplits]); - } - printf("\n"); - } - } -#endif - // write lse scales into LDS { constexpr auto out_spans = @@ -365,7 +328,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline if constexpr(kStoreLSE) { - static_assert(kBlockSize == 256); store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum)); } @@ -378,7 +340,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline auto o_acc = make_static_distributed_tensor(o_acc_dist); // Pcompute{j} clear_tile(o_acc); - // [POYENC] added + const index_t new_max_seqlen_q = integer_divide_ceil(max_seqlen_q, kM0) * kM0; + for(index_t i_split = 0; i_split < num_splits; ++i_split) { auto o_tile = load_tile(o_acc_dram_window); @@ -394,14 +357,25 @@ struct BlockFmhaFwdSplitKVCombinePipeline get_x_indices_from_distributed_indices(o_acc_dist, distributed_indices); const auto row = x_indices.at(number<0>{}); + const auto col = x_indices.at(number<1>{}); LSEDataType lse_scale = lse_acc_lds_ptr[i_split + row * kMaxSplits]; o_acc(distributed_indices) += lse_scale * o_tile(distributed_indices); + DEBUG_STMTS + { + printf("[POYENC][DEVICE] [%3d,%3d], o_acc(%11.7f) = lse_scale(%11.7f) " + "* o_tile(%11.7f)\n", + row, + col, + o_acc(distributed_indices), + lse_scale, + o_tile(distributed_indices)); + } }); }); } - move_tile_window(o_acc_dram_window, {max_seqlen_q, 0}); + move_tile_window(o_acc_dram_window, {new_max_seqlen_q, 0}); } o_acc = tile_elementwise_in(o_acc_element_func, o_acc); @@ -417,6 +391,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline LSEDramBlockWindow& lse_dram_block_window, void* smem_ptr, index_t num_splits, + index_t real_seqlen_q, index_t max_seqlen_q) const { return operator()(lse_acc_dram_block_window, @@ -426,6 +401,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline identity{}, smem_ptr, num_splits, + real_seqlen_q, max_seqlen_q); } };