Add constraint to kMaxSplits

This commit is contained in:
PoYen, Chen
2024-06-12 05:00:06 +00:00
parent e00ff9d246
commit 0a2132d758
2 changed files with 4 additions and 2 deletions

View File

@@ -137,7 +137,8 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = max(Problem::kMaxSplits, get_warp_size());
static_assert(0 < Problem::kMaxSplits && Problem::kMaxSplits % get_warp_size() == 0);
constexpr index_t kNPerBlock = Problem::kMaxSplits;
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t NumElements = (kMPerBlock * kNPerBlock);

View File

@@ -45,7 +45,8 @@ struct TileFmhaFwdSplitKVCombineTraits
static constexpr bool kStoreLSE = kStoreLSE_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr index_t kMaxSplits = kMaxSplits_;
static constexpr index_t kMaxSplits = kMaxSplits_;
static_assert(0 < kMaxSplits && kMaxSplits % get_warp_size() == 0);
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};