mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
Add runtime nullptr checks to guard against null quantization parameters.
This commit is contained in:
@@ -10,6 +10,7 @@
|
||||
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
|
||||
#include <cassert>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
@@ -1208,6 +1209,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
const float scale_s = [&] {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
{
|
||||
assert(kargs.q_descale_ptr != nullptr);
|
||||
assert(kargs.k_descale_ptr != nullptr);
|
||||
float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
|
||||
float k_descale = *(reinterpret_cast<const float*>(kargs.k_descale_ptr));
|
||||
|
||||
@@ -1216,6 +1219,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
|
||||
{
|
||||
// Q is per-tensor, K is per-page (handled in pipeline)
|
||||
assert(kargs.q_descale_ptr != nullptr);
|
||||
float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
|
||||
return kargs.scale_s * q_descale;
|
||||
}
|
||||
@@ -1251,6 +1255,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
{
|
||||
// TODO - move global load of descale to pipeline
|
||||
assert(kargs.v_descale_ptr != nullptr);
|
||||
float v_descale = *(reinterpret_cast<const float*>(kargs.v_descale_ptr));
|
||||
|
||||
float scale_p = ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
|
||||
@@ -1297,6 +1302,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
|
||||
{
|
||||
// KV_BLOCKSCALE: K/V descale is per-page, handled in pipeline
|
||||
assert(kargs.kv_block_descale_ptr != nullptr);
|
||||
const float* kv_block_descale_ptr =
|
||||
reinterpret_cast<const float*>(kargs.kv_block_descale_ptr);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user