From b32fd8d3f45020390093af57094212c8cb5ba683 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 16 Jul 2024 06:27:28 +0000 Subject: [PATCH] Rename variables used in distribution encoding --- ...a_fwd_appendkv_pipeline_default_policy.hpp | 50 ++++++++++++------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp index 82fe412798..45c773e07d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp @@ -64,15 +64,16 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy constexpr index_t kNPerBlock = Problem::kTileSizeSk; constexpr index_t kKPerBlock = Problem::kTileSizeD; - constexpr index_t K1 = 16 / sizeof(KDataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; - constexpr index_t N1 = kBlockSize / get_warp_size(); - constexpr index_t N0 = kNPerBlock / (N2 * N1); + constexpr index_t KPerThread = 16 / sizeof(KDataType); + constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread; + constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t NPerThread = kNPerBlock / (NumWarps * NThreadPerWarp); return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, + tuple, + sequence>, tuple, sequence<1, 2>>, tuple, sequence<2, 0>>, sequence<1, 2>, @@ -100,15 +101,16 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy if constexpr(std::is_same_v) { - constexpr index_t N1 = 16 / sizeof(VDataType); - constexpr index_t N0 = kNPerBlock / N1; - constexpr index_t K2 = get_warp_size() / N0; - constexpr index_t K1 = kBlockSize / get_warp_size(); - constexpr index_t K0 = kKPerBlock / (K2 * K1); + constexpr
index_t NPerThread = 16 / sizeof(VDataType); + constexpr index_t NThreadPerBlock = kNPerBlock / NPerThread; + constexpr index_t KThreadPerWarp = get_warp_size() / NThreadPerBlock; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t KPerThread = kKPerBlock / (NumWarps * KThreadPerWarp); return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, + tuple, + sequence>, tuple, sequence<1, 2>>, tuple, sequence<0, 2>>, sequence<1, 2>, @@ -116,21 +118,33 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy } else { - constexpr index_t K1 = 16 / sizeof(VDataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; - constexpr index_t N1 = kBlockSize / get_warp_size(); - constexpr index_t N0 = kNPerBlock / (N2 * N1); + constexpr index_t KPerThread = 16 / sizeof(VDataType); + constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread; + constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t NPerThread = kNPerBlock / (NumWarps * NThreadPerWarp); return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, + tuple, + sequence>, tuple, sequence<1, 2>>, tuple, sequence<2, 0>>, sequence<1, 2>, sequence<0, 1>>{}); } } + + // 4 vals per load + template + CK_TILE_DEVICE static constexpr auto MakeRotaryCosSinInterleaveDramTileDistribution() + { + } + + template + CK_TILE_DEVICE static constexpr auto MakeRotaryCosSinContiguousDramTileDistribution() + { + } }; } // namespace ck_tile