diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 911ee2f295..bc9e904348 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -209,6 +209,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; static constexpr auto kKVMemoryLayout = Problem::kKVMemoryLayout; + static constexpr auto QScaleEnum = Problem::QScaleEnum; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || @@ -236,6 +237,10 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static constexpr auto LOG2E = log2e_v; #endif + // BLOCKSCALE: shift added to the exp2 argument (exp2(x + shift)) scaling P to [0, 2^shift] + static constexpr float OCP_FP8_SHIFT = 8.0f; + static constexpr float FNUZ_FP8_SHIFT = 7.0f; + static constexpr index_t kBlockPerCu = []() { if constexpr(Problem::kBlockPerCu != -1) return Problem::kBlockPerCu;