diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp index 37a1536413..581f67eb8b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -67,8 +67,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { - /// TODO: add padding to avoid bank conflict - return (kM0 * kMaxSplits * sizeof(LSEDataType)); + return Policy::template GetSmemSize(); } template (static_cast(static_cast(smem_ptr))); +#if 0 + auto lse_acc_lds = make_tensor_view( + lse_acc_lds_ptr, Policy::template MakeLSEaccLdsBlockDescriptor()); + auto lse_acc_lds_window = + make_tile_window(lse_acc_lds, make_tuple(number{}, number{}), {0, 0}); +#endif + + auto lse_acc_lds_ms_m0_for_write = Policy::template MakeLSEaccLdsBlockDescriptor(); auto lse_acc_dist = Policy::template MakeLSEaccDramTileDistribution(); auto lse_acc_dram_window = @@ -116,13 +124,15 @@ struct BlockFmhaFwdSplitKVCombinePipeline const auto row = x_indices.at(number<0>{}); const auto col = x_indices.at(number<1>{}); + auto offset = + lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(row, col)); if(row < num_splits && col < seqlen_q) { - lse_acc_lds_ptr[row + col * kMaxSplits] = lse_acc(distributed_indices); + lse_acc_lds_ptr[offset] = lse_acc(distributed_indices); } else { - lse_acc_lds_ptr[row + col * kMaxSplits] = -numeric::infinity(); + lse_acc_lds_ptr[offset] = -numeric::infinity(); } }); }); @@ -148,9 +158,11 @@ struct BlockFmhaFwdSplitKVCombinePipeline const auto row = x_indices.at(number<0>{}); const auto col = x_indices.at(number<1>{}); + auto offset = + lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(col, row)); if(col < num_splits) { - lse_accum(distributed_indices) = lse_acc_lds_ptr[col + row * kMaxSplits]; + lse_accum(distributed_indices) = lse_acc_lds_ptr[offset]; } else { @@ -190,11 +202,13 @@ struct BlockFmhaFwdSplitKVCombinePipeline const auto row = x_indices.at(number<0>{}); const auto col = x_indices.at(number<1>{}); + auto offset = + lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(col, row)); if(col < num_splits) { // from shared memory - p_compute(i_j_idx) = ck_tile::exp(lse_acc_lds_ptr[col + row * kMaxSplits] - - get_validated_m(lse_max(i_idx))); + p_compute(i_j_idx) = + ck_tile::exp(lse_acc_lds_ptr[offset] - get_validated_m(lse_max(i_idx))); } }); }); @@ -240,8 +254,10 @@ struct BlockFmhaFwdSplitKVCombinePipeline for(index_t col = 0; col < num_splits; ++col) { - lse_acc_lds_ptr[col + row * kMaxSplits] = ck_tile::exp( - lse_acc_lds_ptr[col + row * kMaxSplits] - lse_logsum(distributed_indices)); + auto offset = + lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(col, row)); + lse_acc_lds_ptr[offset] = + ck_tile::exp(lse_acc_lds_ptr[offset] - lse_logsum(distributed_indices)); } }); } @@ -291,7 +307,10 @@ struct BlockFmhaFwdSplitKVCombinePipeline const auto row = x_indices.at(number<0>{}); - LSEDataType lse_scale = lse_acc_lds_ptr[i_split + row * kMaxSplits]; + auto offset = + lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(i_split, row)); + + LSEDataType lse_scale = lse_acc_lds_ptr[offset]; o_acc(distributed_indices) += lse_scale * o_tile(distributed_indices); }); }); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp index 7661e66a07..89e4fa29ef 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp @@ -31,6 +31,13 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy return 16 / sizeof(ODataType); } + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return sizeof(typename Problem::LSEDataType) * + MakeLSEaccLdsBlockDescriptor().get_element_space_size(); + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccDramTileDistribution() { @@ -73,6 +80,33 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy } } + // 3d + padding + // [kMaxSplits x kM0] + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsBlockDescriptor() + { + using LSEDataType = remove_cvref_t; + + constexpr index_t kMPerBlock = Problem::kMaxSplits; + constexpr index_t kNPerBlock = Problem::kM0; + constexpr index_t NPack = 16 / sizeof(LSEDataType); + + constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kMPerBlock + 1) * NPack>{}, number{}, number<1>{}), + number<8>{}, + number<1>{}); + + constexpr auto lse_acc_lds_block_desc = transform_tensor_descriptor( + lse_acc_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return lse_acc_lds_block_desc; + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccTDramTileDistribution() {