From c40c1daff079447e4208207537c82e9adcabdbbf Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 26 Jun 2024 18:02:28 +0000 Subject: [PATCH] Extract common logics --- .../block_fmha_fwd_appendkv_pipeline_default_policy.hpp | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp index b8706b8a23..067cc89b4f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp @@ -90,11 +90,11 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy template CK_TILE_DEVICE static constexpr auto MakeVnewDramTileDistribution() { - using VLayout = remove_cvref_t; + using VLayout = remove_cvref_t; + using VDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; - static_assert(!std::is_same_v); if constexpr(std::is_same_v) { constexpr index_t kNPerBlock = Problem::kTileSizeDv; @@ -105,7 +105,6 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy constexpr index_t N2 = get_warp_size() / K0; constexpr index_t N1 = kBlockSize / get_warp_size(); constexpr index_t N0 = kNPerBlock / (N2 * N1); - static_assert(N0 != 0); return make_static_tile_distribution( tile_distribution_encoding, @@ -117,8 +116,6 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy } else { - using VDataType = remove_cvref_t; - constexpr index_t kNPerBlock = Problem::kTileSizeDv; constexpr index_t kKPerBlock = Problem::kTileSizeSk;