Add runtime check nullptr for prevent quantization parameters.

This commit is contained in:
Jeff Huang
2026-02-02 22:57:29 +08:00
parent 0aa1142bb5
commit dc74e66e7b

View File

@@ -10,6 +10,7 @@
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"
#include <cassert>
#include <string>
#include <type_traits>
#include <utility>
@@ -1208,6 +1209,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
const float scale_s = [&] {
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
assert(kargs.q_descale_ptr != nullptr);
assert(kargs.k_descale_ptr != nullptr);
float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
float k_descale = *(reinterpret_cast<const float*>(kargs.k_descale_ptr));
@@ -1216,6 +1219,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
{
// Q is per-tensor, K is per-page (handled in pipeline)
assert(kargs.q_descale_ptr != nullptr);
float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
return kargs.scale_s * q_descale;
}
@@ -1251,6 +1255,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
// TODO - move global load of descale to pipeline
assert(kargs.v_descale_ptr != nullptr);
float v_descale = *(reinterpret_cast<const float*>(kargs.v_descale_ptr));
float scale_p = ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
@@ -1297,6 +1302,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
{
// KV_BLOCKSCALE: K/V descale is per-page, handled in pipeline
assert(kargs.kv_block_descale_ptr != nullptr);
const float* kv_block_descale_ptr =
reinterpret_cast<const float*>(kargs.kv_block_descale_ptr);