From 0a2132d758042e6fb0292f4e354909b8a4d1c118 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 12 Jun 2024 05:00:06 +0000 Subject: [PATCH] Add constraint to kMaxSplits --- .../block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp | 3 ++- include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp index b57c694678..0bf5030135 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp @@ -137,7 +137,8 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = max(Problem::kMaxSplits, get_warp_size()); + static_assert(0 < Problem::kMaxSplits && Problem::kMaxSplits % get_warp_size() == 0); + constexpr index_t kNPerBlock = Problem::kMaxSplits; constexpr index_t kMPerBlock = Problem::kM0; constexpr index_t NumElements = (kMPerBlock * kNPerBlock); diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 5ba1004eee..71440ebca4 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -45,7 +45,8 @@ struct TileFmhaFwdSplitKVCombineTraits static constexpr bool kStoreLSE = kStoreLSE_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; - static constexpr index_t kMaxSplits = kMaxSplits_; + static constexpr index_t kMaxSplits = kMaxSplits_; + static_assert(0 < kMaxSplits && kMaxSplits % get_warp_size() == 0); static constexpr index_t kBlockPerCu = kBlockPerCu_; };