Patch notes (text before the first "diff --git" header is skipped by git apply / patch):
  * GetSmemSize(): replaces the placeholder `return 1;` with
    sizeof(KDataType) * Problem::kTileSizeSk * (Problem::kTileSizeD / 2).
  * MakeRotaryCosSinInterleaveDramTileDistribution(): fills in the previously empty body.
    The tile is kTileSizeSk rows by kTileSizeD / 2 columns; each thread covers
    8 / sizeof(KDataType) consecutive column elements, threads of a warp are spread first
    across columns and then across rows, and warps are spread across rows.
  * NOTE(review): this patch had been flattened onto one line and had its template argument
    lists stripped (`template <typename Problem>`, `remove_cvref_t<...>`, and the
    `sequence<...>` arguments of tile_distribution_encoding). They are restored here from the
    constants the code computes and from the hunk line counts; indentation and alignment
    whitespace are likewise reconstructed. Confirm against the upstream commit before applying.
  * NOTE(review): the integer divisions assume kTileSizeD / 2 is a multiple of KPerThread,
    that KThreadPerBlock divides the warp size, and that kTileSizeSk is a multiple of
    NumWarps * NThreadPerWarp; nothing in this patch checks that.

diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp
index 45c773e07d..cd38c22e8f 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp
@@ -52,7 +52,9 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
     {
-        return 1;
+        using KDataType = remove_cvref_t<typename Problem::KDataType>;
+
+        return sizeof(KDataType) * Problem::kTileSizeSk * (Problem::kTileSizeD / 2);
     }
 
     template <typename Problem>
@@ -139,6 +141,26 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
     template <typename Problem>
     CK_TILE_DEVICE static constexpr auto MakeRotaryCosSinInterleaveDramTileDistribution()
     {
+        using KDataType = remove_cvref_t<typename Problem::KDataType>;
+
+        constexpr index_t kBlockSize = Problem::kBlockSize;
+        constexpr index_t kNPerBlock = Problem::kTileSizeSk;
+        constexpr index_t kKPerBlock = Problem::kTileSizeD / 2;
+
+        constexpr index_t KPerThread = 8 / sizeof(KDataType);
+        constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
+        constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock;
+        constexpr index_t NumWarps = kBlockSize / get_warp_size();
+        constexpr index_t NPerThread = kNPerBlock / (NumWarps * NThreadPerWarp);
+
+        return make_static_tile_distribution(
+            tile_distribution_encoding<sequence<1>,
+                                       tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
+                                             sequence<KThreadPerBlock, KPerThread>>,
+                                       tuple<sequence<1>, sequence<1, 2>>,
+                                       tuple<sequence<1>, sequence<2, 0>>,
+                                       sequence<1, 2>,
+                                       sequence<0, 1>>{});
     }
 
     template <typename Problem>