From dc74e66e7be47fefe4383a025068fdeb6dcbd22d Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Mon, 2 Feb 2026 22:57:29 +0800 Subject: [PATCH] Add runtime nullptr checks for quantization parameters. --- .../ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index 03303a0683..873e65eac5 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -10,6 +10,7 @@ #include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" +#include #include #include #include @@ -1208,6 +1209,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel const float scale_s = [&] { if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) { + assert(kargs.q_descale_ptr != nullptr); + assert(kargs.k_descale_ptr != nullptr); float q_descale = *(reinterpret_cast(kargs.q_descale_ptr)); float k_descale = *(reinterpret_cast(kargs.k_descale_ptr)); @@ -1216,6 +1219,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) { // Q is per-tensor, K is per-page (handled in pipeline) + assert(kargs.q_descale_ptr != nullptr); float q_descale = *(reinterpret_cast(kargs.q_descale_ptr)); return kargs.scale_s * q_descale; } @@ -1251,6 +1255,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) { // TODO - move global load of descale to pipeline + assert(kargs.v_descale_ptr != nullptr); float v_descale = *(reinterpret_cast(kargs.v_descale_ptr)); float scale_p = ck_tile::type_convert(ck_tile::numeric::max()); @@ -1297,6 +1302,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel else if constexpr(QScaleEnum == 
BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) { // KV_BLOCKSCALE: K/V descale is per-page, handled in pipeline + assert(kargs.kv_block_descale_ptr != nullptr); const float* kv_block_descale_ptr = reinterpret_cast(kargs.kv_block_descale_ptr);