add qscaleenum and shift value

This commit is contained in:
ltqin
2026-01-29 08:17:31 +00:00
parent 9e338c5b47
commit a9d85dfe16

View File

@@ -209,6 +209,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr auto kKVMemoryLayout = Problem::kKVMemoryLayout;
static constexpr auto QScaleEnum = Problem::QScaleEnum;
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
@@ -236,6 +237,10 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
static constexpr auto LOG2E = log2e_v<SaccDataType>;
#endif
// For BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift]
static constexpr float OCP_FP8_SHIFT = 8.0f;
static constexpr float FNUZ_FP8_SHIFT = 7.0f;
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;