mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Fp8 block scale quantization for fmha fwd (#3330)
* add block scale parameters to kernel
* add block scale to kernel
* add smoke test
* format
* Revert "format"
This reverts commit 356c3c9706.
* only format my code
* format py
* fix auto not allowd in function prototype
* change instance tttt to ttff
* fix structured binding issue
* change s_acc elementwise op
* async pipeline add block scale
* add quantation P using shift exp2
* precompute (m - shift) once per row
* change blk scale seqstrt ptr name
* fix some name
* fix for deduction guide
* fix some comments
* add P scale to qr_ksvs_pipeline
* add comment to idx_identity
* change the method of calculating descale block index
* unify naming style: use block_scale_ as name prefix
* unify naming style
* update the CHANGELOG.md
* Add FP8 block scale quantization support for FMHA forward kernel
---------
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -168,6 +168,29 @@ struct FmhaFwdKernel
|
||||
const void* v_descale_ptr = nullptr;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonBlockScaleKargs : public FmhaFwdCommonQScaleKargs
|
||||
{
|
||||
ck_tile::index_t nhead_stride_q_descale;
|
||||
ck_tile::index_t nhead_stride_k_descale;
|
||||
ck_tile::index_t nhead_stride_v_descale;
|
||||
|
||||
ck_tile::index_t block_scale_size_q;
|
||||
ck_tile::index_t block_scale_size_kv;
|
||||
};
|
||||
|
||||
struct FmhaFwdBatchBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs
|
||||
{
|
||||
ck_tile::index_t batch_stride_q_descale;
|
||||
ck_tile::index_t batch_stride_k_descale;
|
||||
ck_tile::index_t batch_stride_v_descale;
|
||||
};
|
||||
|
||||
struct FmhaFwdGroupBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs
|
||||
{
|
||||
const int32_t* block_scale_seqstart_q_ptr;
|
||||
const int32_t* block_scale_seqstart_k_ptr;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonLSEKargs
|
||||
{
|
||||
void* lse_ptr = nullptr;
|
||||
@@ -243,9 +266,12 @@ struct FmhaFwdKernel
|
||||
FmhaFwdEmptyKargs<0>>>,
|
||||
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
FmhaFwdEmptyKargs<3>>,
|
||||
std::conditional_t<
|
||||
QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE,
|
||||
FmhaFwdBatchBlockScaleKargs,
|
||||
FmhaFwdEmptyKargs<3>>>,
|
||||
std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
|
||||
{
|
||||
@@ -269,9 +295,12 @@ struct FmhaFwdKernel
|
||||
FmhaFwdEmptyKargs<0>>>,
|
||||
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
FmhaFwdEmptyKargs<3>>,
|
||||
std::conditional_t<
|
||||
QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE,
|
||||
FmhaFwdGroupBlockScaleKargs,
|
||||
FmhaFwdEmptyKargs<3>>>,
|
||||
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
|
||||
std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
|
||||
@@ -328,6 +357,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_descale,
|
||||
ck_tile::index_t nhead_stride_k_descale,
|
||||
ck_tile::index_t nhead_stride_v_descale,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
@@ -335,6 +367,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t batch_stride_q_descale,
|
||||
ck_tile::index_t batch_stride_k_descale,
|
||||
ck_tile::index_t batch_stride_v_descale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
@@ -343,6 +378,8 @@ struct FmhaFwdKernel
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset,
|
||||
ck_tile::index_t block_scale_size_q,
|
||||
ck_tile::index_t block_scale_size_kv,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
@@ -413,6 +450,23 @@ struct FmhaFwdKernel
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
}
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
kargs.q_descale_ptr = q_descale_ptr;
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
|
||||
kargs.nhead_stride_q_descale = nhead_stride_q_descale;
|
||||
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
|
||||
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
|
||||
|
||||
kargs.batch_stride_q_descale = batch_stride_q_descale;
|
||||
kargs.batch_stride_k_descale = batch_stride_k_descale;
|
||||
kargs.batch_stride_v_descale = batch_stride_v_descale;
|
||||
|
||||
kargs.block_scale_size_q = block_scale_size_q;
|
||||
kargs.block_scale_size_kv = block_scale_size_kv;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
@@ -478,6 +532,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_descale,
|
||||
ck_tile::index_t nhead_stride_k_descale,
|
||||
ck_tile::index_t nhead_stride_v_descale,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
@@ -485,6 +542,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t batch_stride_q_descale,
|
||||
ck_tile::index_t batch_stride_k_descale,
|
||||
ck_tile::index_t batch_stride_v_descale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
@@ -492,6 +552,8 @@ struct FmhaFwdKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
|
||||
ck_tile::index_t block_scale_size_q,
|
||||
ck_tile::index_t block_scale_size_kv,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
@@ -528,6 +590,9 @@ struct FmhaFwdKernel
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
nhead_stride_q_descale,
|
||||
nhead_stride_k_descale,
|
||||
nhead_stride_v_descale,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
@@ -535,6 +600,9 @@ struct FmhaFwdKernel
|
||||
batch_stride_randval,
|
||||
batch_stride_lse,
|
||||
batch_stride_o,
|
||||
batch_stride_q_descale,
|
||||
batch_stride_k_descale,
|
||||
batch_stride_v_descale,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
sink_size,
|
||||
@@ -542,6 +610,8 @@ struct FmhaFwdKernel
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
block_scale_size_q,
|
||||
block_scale_size_kv,
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr,
|
||||
sink_ptr);
|
||||
@@ -581,6 +651,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_descale,
|
||||
ck_tile::index_t nhead_stride_k_descale,
|
||||
ck_tile::index_t nhead_stride_v_descale,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
@@ -588,6 +661,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t batch_stride_q_descale,
|
||||
ck_tile::index_t batch_stride_k_descale,
|
||||
ck_tile::index_t batch_stride_v_descale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
@@ -595,6 +671,8 @@ struct FmhaFwdKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<const void*, const void*>& drop_seed_offset,
|
||||
ck_tile::index_t block_scale_size_q,
|
||||
ck_tile::index_t block_scale_size_kv,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
@@ -631,6 +709,9 @@ struct FmhaFwdKernel
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
nhead_stride_q_descale,
|
||||
nhead_stride_k_descale,
|
||||
nhead_stride_v_descale,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
@@ -638,6 +719,9 @@ struct FmhaFwdKernel
|
||||
batch_stride_randval,
|
||||
batch_stride_lse,
|
||||
batch_stride_o,
|
||||
batch_stride_q_descale,
|
||||
batch_stride_k_descale,
|
||||
batch_stride_v_descale,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
sink_size,
|
||||
@@ -645,6 +729,8 @@ struct FmhaFwdKernel
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
block_scale_size_q,
|
||||
block_scale_size_kv,
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr,
|
||||
sink_ptr);
|
||||
@@ -666,6 +752,8 @@ struct FmhaFwdKernel
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_q_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
const void* block_scale_seqstart_q_ptr,
|
||||
const void* block_scale_seqstart_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -685,6 +773,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_descale,
|
||||
ck_tile::index_t nhead_stride_k_descale,
|
||||
ck_tile::index_t nhead_stride_v_descale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
@@ -694,6 +785,8 @@ struct FmhaFwdKernel
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset,
|
||||
ck_tile::index_t block_scale_size_q,
|
||||
ck_tile::index_t block_scale_size_kv,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
@@ -763,6 +856,24 @@ struct FmhaFwdKernel
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
}
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
kargs.q_descale_ptr = q_descale_ptr;
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
|
||||
kargs.nhead_stride_q_descale = nhead_stride_q_descale;
|
||||
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
|
||||
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
|
||||
|
||||
kargs.block_scale_size_q = block_scale_size_q;
|
||||
kargs.block_scale_size_kv = block_scale_size_kv;
|
||||
|
||||
kargs.block_scale_seqstart_q_ptr =
|
||||
reinterpret_cast<const int32_t*>(block_scale_seqstart_q_ptr);
|
||||
kargs.block_scale_seqstart_k_ptr =
|
||||
reinterpret_cast<const int32_t*>(block_scale_seqstart_k_ptr);
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
@@ -814,6 +925,8 @@ struct FmhaFwdKernel
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_q_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
const void* block_scale_seqstart_q_ptr,
|
||||
const void* block_scale_seqstart_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -833,6 +946,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_descale,
|
||||
ck_tile::index_t nhead_stride_k_descale,
|
||||
ck_tile::index_t nhead_stride_v_descale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
@@ -841,6 +957,8 @@ struct FmhaFwdKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
|
||||
ck_tile::index_t block_scale_size_q,
|
||||
ck_tile::index_t block_scale_size_kv,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
@@ -860,6 +978,8 @@ struct FmhaFwdKernel
|
||||
seqstart_k_ptr,
|
||||
seqlen_q_ptr,
|
||||
seqlen_k_ptr,
|
||||
block_scale_seqstart_q_ptr,
|
||||
block_scale_seqstart_k_ptr,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
@@ -879,6 +999,9 @@ struct FmhaFwdKernel
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
nhead_stride_q_descale,
|
||||
nhead_stride_k_descale,
|
||||
nhead_stride_v_descale,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
sink_size,
|
||||
@@ -887,6 +1010,8 @@ struct FmhaFwdKernel
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
block_scale_size_q,
|
||||
block_scale_size_kv,
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr,
|
||||
sink_ptr);
|
||||
@@ -909,6 +1034,8 @@ struct FmhaFwdKernel
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_q_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
const void* block_scale_seqstart_q_ptr,
|
||||
const void* block_scale_seqstart_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -928,6 +1055,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_descale,
|
||||
ck_tile::index_t nhead_stride_k_descale,
|
||||
ck_tile::index_t nhead_stride_v_descale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
@@ -936,6 +1066,8 @@ struct FmhaFwdKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<const void*, const void*>& drop_seed_offset,
|
||||
ck_tile::index_t block_scale_size_q,
|
||||
ck_tile::index_t block_scale_size_kv,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
@@ -955,6 +1087,8 @@ struct FmhaFwdKernel
|
||||
seqstart_k_ptr,
|
||||
seqlen_q_ptr,
|
||||
seqlen_k_ptr,
|
||||
block_scale_seqstart_q_ptr,
|
||||
block_scale_seqstart_k_ptr,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
@@ -974,6 +1108,9 @@ struct FmhaFwdKernel
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
nhead_stride_q_descale,
|
||||
nhead_stride_k_descale,
|
||||
nhead_stride_v_descale,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
sink_size,
|
||||
@@ -982,6 +1119,8 @@ struct FmhaFwdKernel
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
block_scale_size_q,
|
||||
block_scale_size_kv,
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr,
|
||||
sink_ptr);
|
||||
@@ -1111,13 +1250,16 @@ struct FmhaFwdKernel
|
||||
const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
|
||||
const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
|
||||
|
||||
long_index_t batch_offset_q = 0;
|
||||
long_index_t batch_offset_k = 0;
|
||||
long_index_t batch_offset_v = 0;
|
||||
long_index_t batch_offset_bias = 0;
|
||||
long_index_t batch_offset_randval = 0;
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
long_index_t batch_offset_q = 0;
|
||||
long_index_t batch_offset_k = 0;
|
||||
long_index_t batch_offset_v = 0;
|
||||
long_index_t batch_offset_bias = 0;
|
||||
long_index_t batch_offset_randval = 0;
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
long_index_t batch_offset_q_descale = 0;
|
||||
long_index_t batch_offset_k_descale = 0;
|
||||
long_index_t batch_offset_v_descale = 0;
|
||||
const float sink_value =
|
||||
kargs.sink_ptr != nullptr
|
||||
? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
|
||||
@@ -1153,6 +1295,14 @@ struct FmhaFwdKernel
|
||||
{
|
||||
batch_offset_randval = query_start * kargs.stride_randval;
|
||||
}
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
const long_index_t bquery_start = kargs.block_scale_seqstart_q_ptr[i_batch];
|
||||
const long_index_t bkey_start = kargs.block_scale_seqstart_k_ptr[i_batch];
|
||||
batch_offset_q_descale = bquery_start;
|
||||
batch_offset_k_descale = bkey_start;
|
||||
batch_offset_v_descale = bkey_start;
|
||||
}
|
||||
batch_offset_o = query_start * kargs.stride_o;
|
||||
|
||||
// real logical lengths (exclude PAD)
|
||||
@@ -1220,6 +1370,15 @@ struct FmhaFwdKernel
|
||||
batch_offset_randval =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
|
||||
}
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
batch_offset_q_descale =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_q_descale;
|
||||
batch_offset_k_descale =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_k_descale;
|
||||
batch_offset_v_descale =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_v_descale;
|
||||
}
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
|
||||
// If cumulative seqlen pointers are provided, override per-batch effective lengths
|
||||
@@ -1540,7 +1699,8 @@ struct FmhaFwdKernel
|
||||
}();
|
||||
|
||||
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
|
||||
auto o_acc_tile = [&]() {
|
||||
|
||||
auto o_acc_tile = [&, i_nhead_ = i_nhead]() {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
{
|
||||
// TODO - move global load of descale to pipeline
|
||||
@@ -1581,8 +1741,62 @@ struct FmhaFwdKernel
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
nullptr,
|
||||
nullptr,
|
||||
1,
|
||||
sink_value);
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
const float* q_descale_ptr =
|
||||
reinterpret_cast<const float*>(kargs.q_descale_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_q_descale +
|
||||
batch_offset_q_descale;
|
||||
const float* k_descale_ptr =
|
||||
reinterpret_cast<const float*>(kargs.k_descale_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
|
||||
kargs.nhead_stride_k_descale +
|
||||
batch_offset_k_descale;
|
||||
const float* v_descale_ptr =
|
||||
reinterpret_cast<const float*>(kargs.v_descale_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
|
||||
kargs.nhead_stride_v_descale +
|
||||
batch_offset_v_descale;
|
||||
|
||||
size_t idx = i_m0 / kargs.block_scale_size_q;
|
||||
float q_descale = q_descale_ptr[idx];
|
||||
// BLOCKSCALE: P is scaled in exp2(x+shift) where shift=7 or 8
|
||||
// Both P and rowsum are scaled by 2^shift, canceling in normalization
|
||||
// No additional scaling needed in p_compute_element_func or o_acc_element_func
|
||||
|
||||
return FmhaPipeline{}(
|
||||
q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window,
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
scales<float>(q_descale), // s_acc_element_func
|
||||
identity{}, // p_compute_element_func - No scaling (done in exp2)
|
||||
identity{}, // o_acc_element_func - No dequant needed (canceled by rowsum)
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
k_descale_ptr,
|
||||
v_descale_ptr,
|
||||
kargs.block_scale_size_kv,
|
||||
sink_value);
|
||||
}
|
||||
else
|
||||
{
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
|
||||
Reference in New Issue
Block a user