Fp8 block scale quantization for fmha fwd (#3330)

* add block scale parameters to kernel

* add block scale to kernel

* add smoke test

* format

* Revert "format"

This reverts commit 356c3c9706.

* only format my code

* format py

* fix auto not allowd in function prototype

* change instance tttt to ttff

* fix structured binding issue

* change s_acc elementwise op

* async pipeline add block scale

* add quantation P using shift exp2

* precompute (m - shift) once per row

* change blk scale seqstrt ptr name

* fix some name

* fix for  deduction guide

* fix some comments

* add P scale to qr_ksvs_pipeline

* add comment to idx_identity

* change the method of calculating descale block index

* unify naming style: use block_scale_ as name prefix

* unify naming style

* update the CHANGELOG.md

* Add FP8 block scale quantization support for FMHA forward kernel

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
ltqin
2026-01-22 12:58:26 +08:00
committed by GitHub
parent 4c2c18ef48
commit dd0b4294af
14 changed files with 667 additions and 84 deletions

View File

@@ -37,6 +37,13 @@ struct scales
return lhs_ * rhs;
}
template <typename OtherScale>
CK_TILE_HOST_DEVICE constexpr auto operator*(OtherScale other) const
{
auto new_scale = lhs_ * other;
return scales<std::decay_t<decltype(new_scale)>>(new_scale);
}
private:
Scale lhs_;
};

View File

@@ -119,6 +119,18 @@ struct identity
}
};
// Similar to identity, but takes an additional index parameter as the first argument.
// The index is ignored and only the second argument (value) is forwarded.
// Useful for indexed element-wise operations where the functor signature requires an index.
struct idx_identity
{
template <typename I, typename T>
CK_TILE_HOST_DEVICE constexpr T&& operator()(I&& /*idx*/, T&& arg) const noexcept
{
return std::forward<T>(arg);
}
};
namespace detail {
// RemainLengths: sequence<...>

View File

@@ -47,4 +47,44 @@ CK_TILE_HOST void reference_batched_gemm(const HostTensor<ADataType>& a_b_m_k,
make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])(
std::thread::hardware_concurrency());
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename AElementOp = ck_tile::idx_identity,
typename BElementOp = ck_tile::idx_identity,
typename ACCElementOp = ck_tile::idx_identity>
CK_TILE_HOST void reference_batched_quant_gemm(const HostTensor<ADataType>& a_b_m_k,
const HostTensor<BDataType>& b_b_n_k,
HostTensor<CDataType>& c_b_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
const int N = b_b_n_k.mDesc.get_lengths()[1];
const int K = b_b_n_k.mDesc.get_lengths()[2];
auto f = [&](auto batch, auto m) {
for(int n = 0; n < N; ++n)
{
AccDataType v_acc = 0;
for(int k = 0; k < K; ++k)
{
AccDataType v_a = ck_tile::type_convert<AccDataType>(
a_element_op(std::make_tuple(batch, m, k), a_b_m_k(batch, m, k)));
AccDataType v_b = ck_tile::type_convert<AccDataType>(
b_element_op(std::make_tuple(batch, n, k), b_b_n_k(batch, n, k)));
v_acc += v_a * v_b;
}
c_b_m_n(batch, m, n) = ck_tile::type_convert<CDataType>(
acc_element_op(std::make_tuple(batch, m, n), v_acc));
}
};
make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])(
std::thread::hardware_concurrency());
}
} // namespace ck_tile

View File

@@ -12,6 +12,7 @@ enum class BlockAttentionQuantScaleEnum
{
NO_SCALE = 0,
PERTENSOR = 1,
BLOCKSCALE,
};
template <BlockAttentionQuantScaleEnum>
@@ -27,5 +28,10 @@ struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::PERTENSOR
{
static constexpr const char* name = "pertensor";
};
template <>
struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::BLOCKSCALE>
{
static constexpr const char* name = "blockscale";
};
} // namespace ck_tile

View File

@@ -168,6 +168,29 @@ struct FmhaFwdKernel
const void* v_descale_ptr = nullptr;
};
struct FmhaFwdCommonBlockScaleKargs : public FmhaFwdCommonQScaleKargs
{
ck_tile::index_t nhead_stride_q_descale;
ck_tile::index_t nhead_stride_k_descale;
ck_tile::index_t nhead_stride_v_descale;
ck_tile::index_t block_scale_size_q;
ck_tile::index_t block_scale_size_kv;
};
struct FmhaFwdBatchBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs
{
ck_tile::index_t batch_stride_q_descale;
ck_tile::index_t batch_stride_k_descale;
ck_tile::index_t batch_stride_v_descale;
};
struct FmhaFwdGroupBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs
{
const int32_t* block_scale_seqstart_q_ptr;
const int32_t* block_scale_seqstart_k_ptr;
};
struct FmhaFwdCommonLSEKargs
{
void* lse_ptr = nullptr;
@@ -243,9 +266,12 @@ struct FmhaFwdKernel
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
FmhaFwdCommonQScaleKargs,
FmhaFwdEmptyKargs<3>>,
std::conditional_t<
QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
FmhaFwdCommonQScaleKargs,
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE,
FmhaFwdBatchBlockScaleKargs,
FmhaFwdEmptyKargs<3>>>,
std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
{
@@ -269,9 +295,12 @@ struct FmhaFwdKernel
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
FmhaFwdCommonQScaleKargs,
FmhaFwdEmptyKargs<3>>,
std::conditional_t<
QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
FmhaFwdCommonQScaleKargs,
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE,
FmhaFwdGroupBlockScaleKargs,
FmhaFwdEmptyKargs<3>>>,
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
@@ -328,6 +357,9 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_q_descale,
ck_tile::index_t nhead_stride_k_descale,
ck_tile::index_t nhead_stride_v_descale,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
@@ -335,6 +367,9 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t batch_stride_q_descale,
ck_tile::index_t batch_stride_k_descale,
ck_tile::index_t batch_stride_v_descale,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
@@ -343,6 +378,8 @@ struct FmhaFwdKernel
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset,
ck_tile::index_t block_scale_size_q,
ck_tile::index_t block_scale_size_kv,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr)
@@ -413,6 +450,23 @@ struct FmhaFwdKernel
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
}
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
kargs.q_descale_ptr = q_descale_ptr;
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
kargs.nhead_stride_q_descale = nhead_stride_q_descale;
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
kargs.batch_stride_q_descale = batch_stride_q_descale;
kargs.batch_stride_k_descale = batch_stride_k_descale;
kargs.batch_stride_v_descale = batch_stride_v_descale;
kargs.block_scale_size_q = block_scale_size_q;
kargs.block_scale_size_kv = block_scale_size_kv;
}
if constexpr(kHasDropout)
{
if(drop_seed_offset.index() == 0) // seed & offset come from host
@@ -478,6 +532,9 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_q_descale,
ck_tile::index_t nhead_stride_k_descale,
ck_tile::index_t nhead_stride_v_descale,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
@@ -485,6 +542,9 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t batch_stride_q_descale,
ck_tile::index_t batch_stride_k_descale,
ck_tile::index_t batch_stride_v_descale,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
@@ -492,6 +552,8 @@ struct FmhaFwdKernel
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
ck_tile::index_t block_scale_size_q,
ck_tile::index_t block_scale_size_kv,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr)
@@ -528,6 +590,9 @@ struct FmhaFwdKernel
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
nhead_stride_q_descale,
nhead_stride_k_descale,
nhead_stride_v_descale,
batch_stride_q,
batch_stride_k,
batch_stride_v,
@@ -535,6 +600,9 @@ struct FmhaFwdKernel
batch_stride_randval,
batch_stride_lse,
batch_stride_o,
batch_stride_q_descale,
batch_stride_k_descale,
batch_stride_v_descale,
window_size_left,
window_size_right,
sink_size,
@@ -542,6 +610,8 @@ struct FmhaFwdKernel
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
block_scale_size_q,
block_scale_size_kv,
cu_seqlen_q_ptr,
cu_seqlen_k_ptr,
sink_ptr);
@@ -581,6 +651,9 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_q_descale,
ck_tile::index_t nhead_stride_k_descale,
ck_tile::index_t nhead_stride_v_descale,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
@@ -588,6 +661,9 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t batch_stride_q_descale,
ck_tile::index_t batch_stride_k_descale,
ck_tile::index_t batch_stride_v_descale,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
@@ -595,6 +671,8 @@ struct FmhaFwdKernel
float p_drop,
bool s_randval,
const std::tuple<const void*, const void*>& drop_seed_offset,
ck_tile::index_t block_scale_size_q,
ck_tile::index_t block_scale_size_kv,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr)
@@ -631,6 +709,9 @@ struct FmhaFwdKernel
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
nhead_stride_q_descale,
nhead_stride_k_descale,
nhead_stride_v_descale,
batch_stride_q,
batch_stride_k,
batch_stride_v,
@@ -638,6 +719,9 @@ struct FmhaFwdKernel
batch_stride_randval,
batch_stride_lse,
batch_stride_o,
batch_stride_q_descale,
batch_stride_k_descale,
batch_stride_v_descale,
window_size_left,
window_size_right,
sink_size,
@@ -645,6 +729,8 @@ struct FmhaFwdKernel
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
block_scale_size_q,
block_scale_size_kv,
cu_seqlen_q_ptr,
cu_seqlen_k_ptr,
sink_ptr);
@@ -666,6 +752,8 @@ struct FmhaFwdKernel
const void* seqstart_k_ptr,
const void* seqlen_q_ptr,
const void* seqlen_k_ptr,
const void* block_scale_seqstart_q_ptr,
const void* block_scale_seqstart_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -685,6 +773,9 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_q_descale,
ck_tile::index_t nhead_stride_k_descale,
ck_tile::index_t nhead_stride_v_descale,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
@@ -694,6 +785,8 @@ struct FmhaFwdKernel
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset,
ck_tile::index_t block_scale_size_q,
ck_tile::index_t block_scale_size_kv,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr)
@@ -763,6 +856,24 @@ struct FmhaFwdKernel
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
}
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
kargs.q_descale_ptr = q_descale_ptr;
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
kargs.nhead_stride_q_descale = nhead_stride_q_descale;
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
kargs.block_scale_size_q = block_scale_size_q;
kargs.block_scale_size_kv = block_scale_size_kv;
kargs.block_scale_seqstart_q_ptr =
reinterpret_cast<const int32_t*>(block_scale_seqstart_q_ptr);
kargs.block_scale_seqstart_k_ptr =
reinterpret_cast<const int32_t*>(block_scale_seqstart_k_ptr);
}
if constexpr(kHasDropout)
{
if(drop_seed_offset.index() == 0) // seed & offset come from host
@@ -814,6 +925,8 @@ struct FmhaFwdKernel
const void* seqstart_k_ptr,
const void* seqlen_q_ptr,
const void* seqlen_k_ptr,
const void* block_scale_seqstart_q_ptr,
const void* block_scale_seqstart_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -833,6 +946,9 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_q_descale,
ck_tile::index_t nhead_stride_k_descale,
ck_tile::index_t nhead_stride_v_descale,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
@@ -841,6 +957,8 @@ struct FmhaFwdKernel
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
ck_tile::index_t block_scale_size_q,
ck_tile::index_t block_scale_size_kv,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr)
@@ -860,6 +978,8 @@ struct FmhaFwdKernel
seqstart_k_ptr,
seqlen_q_ptr,
seqlen_k_ptr,
block_scale_seqstart_q_ptr,
block_scale_seqstart_k_ptr,
hdim_q,
hdim_v,
num_head_q,
@@ -879,6 +999,9 @@ struct FmhaFwdKernel
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
nhead_stride_q_descale,
nhead_stride_k_descale,
nhead_stride_v_descale,
window_size_left,
window_size_right,
sink_size,
@@ -887,6 +1010,8 @@ struct FmhaFwdKernel
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
block_scale_size_q,
block_scale_size_kv,
cu_seqlen_q_ptr,
cu_seqlen_k_ptr,
sink_ptr);
@@ -909,6 +1034,8 @@ struct FmhaFwdKernel
const void* seqstart_k_ptr,
const void* seqlen_q_ptr,
const void* seqlen_k_ptr,
const void* block_scale_seqstart_q_ptr,
const void* block_scale_seqstart_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -928,6 +1055,9 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_q_descale,
ck_tile::index_t nhead_stride_k_descale,
ck_tile::index_t nhead_stride_v_descale,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
@@ -936,6 +1066,8 @@ struct FmhaFwdKernel
float p_drop,
bool s_randval,
const std::tuple<const void*, const void*>& drop_seed_offset,
ck_tile::index_t block_scale_size_q,
ck_tile::index_t block_scale_size_kv,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr)
@@ -955,6 +1087,8 @@ struct FmhaFwdKernel
seqstart_k_ptr,
seqlen_q_ptr,
seqlen_k_ptr,
block_scale_seqstart_q_ptr,
block_scale_seqstart_k_ptr,
hdim_q,
hdim_v,
num_head_q,
@@ -974,6 +1108,9 @@ struct FmhaFwdKernel
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
nhead_stride_q_descale,
nhead_stride_k_descale,
nhead_stride_v_descale,
window_size_left,
window_size_right,
sink_size,
@@ -982,6 +1119,8 @@ struct FmhaFwdKernel
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
block_scale_size_q,
block_scale_size_kv,
cu_seqlen_q_ptr,
cu_seqlen_k_ptr,
sink_ptr);
@@ -1111,13 +1250,16 @@ struct FmhaFwdKernel
const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
long_index_t batch_offset_q = 0;
long_index_t batch_offset_k = 0;
long_index_t batch_offset_v = 0;
long_index_t batch_offset_bias = 0;
long_index_t batch_offset_randval = 0;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
long_index_t batch_offset_q = 0;
long_index_t batch_offset_k = 0;
long_index_t batch_offset_v = 0;
long_index_t batch_offset_bias = 0;
long_index_t batch_offset_randval = 0;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
long_index_t batch_offset_q_descale = 0;
long_index_t batch_offset_k_descale = 0;
long_index_t batch_offset_v_descale = 0;
const float sink_value =
kargs.sink_ptr != nullptr
? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
@@ -1153,6 +1295,14 @@ struct FmhaFwdKernel
{
batch_offset_randval = query_start * kargs.stride_randval;
}
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
const long_index_t bquery_start = kargs.block_scale_seqstart_q_ptr[i_batch];
const long_index_t bkey_start = kargs.block_scale_seqstart_k_ptr[i_batch];
batch_offset_q_descale = bquery_start;
batch_offset_k_descale = bkey_start;
batch_offset_v_descale = bkey_start;
}
batch_offset_o = query_start * kargs.stride_o;
// real logical lengths (exclude PAD)
@@ -1220,6 +1370,15 @@ struct FmhaFwdKernel
batch_offset_randval =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
}
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
batch_offset_q_descale =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_q_descale;
batch_offset_k_descale =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_k_descale;
batch_offset_v_descale =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_v_descale;
}
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
// If cumulative seqlen pointers are provided, override per-batch effective lengths
@@ -1540,7 +1699,8 @@ struct FmhaFwdKernel
}();
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
auto o_acc_tile = [&]() {
auto o_acc_tile = [&, i_nhead_ = i_nhead]() {
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
// TODO - move global load of descale to pipeline
@@ -1581,8 +1741,62 @@ struct FmhaFwdKernel
block_indices,
smem_ptr,
dropout,
nullptr,
nullptr,
1,
sink_value);
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
const float* q_descale_ptr =
reinterpret_cast<const float*>(kargs.q_descale_ptr) +
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_q_descale +
batch_offset_q_descale;
const float* k_descale_ptr =
reinterpret_cast<const float*>(kargs.k_descale_ptr) +
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
kargs.nhead_stride_k_descale +
batch_offset_k_descale;
const float* v_descale_ptr =
reinterpret_cast<const float*>(kargs.v_descale_ptr) +
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
kargs.nhead_stride_v_descale +
batch_offset_v_descale;
size_t idx = i_m0 / kargs.block_scale_size_q;
float q_descale = q_descale_ptr[idx];
// BLOCKSCALE: P is scaled in exp2(x+shift) where shift=7 or 8
// Both P and rowsum are scaled by 2^shift, canceling in normalization
// No additional scaling needed in p_compute_element_func or o_acc_element_func
return FmhaPipeline{}(
q_dram_window,
identity{}, // q_element_func
k_dram_window,
identity{}, // k_element_func
v_dram_window,
identity{}, // v_element_func
bias_dram_window,
identity{}, // bias_element_func
randval_dram_window,
lse_dram_window,
identity{}, // lse_element_func
scales<float>(q_descale), // s_acc_element_func
identity{}, // p_compute_element_func - No scaling (done in exp2)
identity{}, // o_acc_element_func - No dequant needed (canceled by rowsum)
mask,
position_encoding,
kargs.scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
dropout,
k_descale_ptr,
v_descale_ptr,
kargs.block_scale_size_kv,
sink_value);
}
else
{
return FmhaPipeline{}(q_dram_window,

View File

@@ -57,8 +57,13 @@ struct BlockFmhaPipelineQRKSVS
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr auto QScaleEnum = Problem::QScaleEnum;
static constexpr bool kHasSink = Problem::kHasSink;
// For BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift]
static constexpr float OCP_FP8_SHIFT = 8.0f;
static constexpr float FNUZ_FP8_SHIFT = 7.0f;
static constexpr uint32_t DS_READ = 0x100; // Barrier for DS (data share) read
static constexpr uint32_t MFMA = 0x008; // Barrier for MFMA (matrix multiply-accumulate)
@@ -167,6 +172,9 @@ struct BlockFmhaPipelineQRKSVS
const BlockIndices& block_indices,
void* smem_ptr,
DropoutType& dropout,
const float* k_descale_ptr,
const float* v_descale_ptr,
const index_t block_scale_size_kv,
const float sink_v) const
{
static_assert(
@@ -358,6 +366,13 @@ struct BlockFmhaPipelineQRKSVS
static_assert(1 <= k1_loops);
do
{
float k_descale = 1.0f;
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
// K and V share the same seqlen_k position within a block
const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv;
k_descale = k_descale_ptr[kv_idx];
}
// STAGE 1, QK gemm
auto k_dram_window = make_tile_window(
k_dram_block_window.get_bottom_tensor_view(),
@@ -427,11 +442,20 @@ struct BlockFmhaPipelineQRKSVS
k_lds_window);
schedule_gemm0();
}
// dequant
auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() {
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
return s_acc_element_func * k_descale;
}
else
return s_acc_element_func;
}();
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
s_acc = tile_elementwise_in(s_acc_element_func_, s_acc);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
@@ -449,7 +473,7 @@ struct BlockFmhaPipelineQRKSVS
{
const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
s_acc = tile_elementwise_in(s_acc_element_func_, s_acc);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
@@ -466,7 +490,7 @@ struct BlockFmhaPipelineQRKSVS
}
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
s_acc = tile_elementwise_in(s_acc_element_func_, s_acc);
if constexpr(kHasLogitsSoftCap)
{
auto apply_logits_transform =
@@ -571,7 +595,21 @@ struct BlockFmhaPipelineQRKSVS
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_max = scale_s * get_validated_m(m[i_idx]);
// For BLOCKSCALE: precompute (m - shift) once per row
// Bias/Alibi/SoftCap: exp2(s - m + shift) = exp2(s - (m - shift))
// else: exp2(scale_s*s - scale_s*m + shift) = exp2(scale_s*s - (scale_s*m - shift))
auto validated_m = get_validated_m(m[i_idx]);
auto row_max = scale_s * validated_m;
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
#if CK_TILE_USE_OCP_FP8
validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap
row_max -= OCP_FP8_SHIFT; // for else branch
#else
validated_m -= FNUZ_FP8_SHIFT;
row_max -= FNUZ_FP8_SHIFT;
#endif
}
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
@@ -579,13 +617,13 @@ struct BlockFmhaPipelineQRKSVS
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m);
}
else
{
if constexpr(kHasLogitsSoftCap)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m);
}
else
{
@@ -676,18 +714,39 @@ struct BlockFmhaPipelineQRKSVS
store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
}
move_tile_window(v_dram_window, {0, kK1});
const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
float v_descale = 1.0f;
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
// K and V share the same seqlen_k position within a block
const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv;
v_descale = v_descale_ptr[kv_idx];
}
// STAGE 3, KV gemm
auto o_acc0 = decltype(o_acc){};
clear_tile(o_acc0);
auto& o_acc_ = [&o_acc0, &o_acc]() -> auto& {
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
return o_acc0;
}
else
{
return o_acc;
}
}();
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
const auto v = load_tile(v_dram_window); // load next v
block_sync_lds();
gemm_1(o_acc,
gemm_1(o_acc_,
get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
v_lds_window);
@@ -722,11 +781,16 @@ struct BlockFmhaPipelineQRKSVS
// tail
{
block_sync_lds();
gemm_1(o_acc,
gemm_1(o_acc_,
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
v_lds_window);
block_sync_lds();
}
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
tile_elementwise_inout(
[&v_descale](auto& o, auto& o0) { o += o0 * v_descale; }, o_acc, o_acc0);
}
} while(++i_total_loops < num_total_loop);
// store lse
@@ -846,6 +910,9 @@ struct BlockFmhaPipelineQRKSVS
block_indices,
smem_ptr,
dropout,
nullptr,
nullptr,
1,
sink_v);
}
};

View File

@@ -46,6 +46,7 @@ struct BlockFmhaPipelineQRKSVSAsync
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static constexpr auto QScaleEnum = Problem::QScaleEnum;
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
@@ -64,6 +65,10 @@ struct BlockFmhaPipelineQRKSVSAsync
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr bool kHasSink = Problem::kHasSink;
// For BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift]
static constexpr float OCP_FP8_SHIFT = 8.0f;
static constexpr float FNUZ_FP8_SHIFT = 7.0f;
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
!kHasLogitsSoftCap)) ||
@@ -190,6 +195,9 @@ struct BlockFmhaPipelineQRKSVSAsync
const BlockIndices& block_indices,
void* smem_ptr,
DropoutType& dropout,
const float* k_descale_ptr,
const float* v_descale_ptr,
const index_t block_scale_size_kv,
const float sink_v) const
{
static_assert(
@@ -403,6 +411,13 @@ struct BlockFmhaPipelineQRKSVSAsync
// main loop
do
{
float k_descale = 1.0f;
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
// K and V share the same seqlen_k position within a block
const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv;
k_descale = k_descale_ptr[kv_idx];
}
// STAGE 1, QK gemm
clear_tile(s_acc); // initialize C
if constexpr(k0_loops > 1)
@@ -449,11 +464,20 @@ struct BlockFmhaPipelineQRKSVSAsync
sequence<(LdsSeq.at(number<k0_loops - 1>{}) + 1) * kN0, kK0>{}));
}
__builtin_amdgcn_sched_barrier(1);
// dequant
auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() {
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
return s_acc_element_func * k_descale;
}
else
return s_acc_element_func;
}();
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
s_acc = tile_elementwise_in(s_acc_element_func_, s_acc);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
@@ -471,7 +495,7 @@ struct BlockFmhaPipelineQRKSVSAsync
{
const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
s_acc = tile_elementwise_in(s_acc_element_func_, s_acc);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
@@ -488,7 +512,7 @@ struct BlockFmhaPipelineQRKSVSAsync
}
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
s_acc = tile_elementwise_in(s_acc_element_func_, s_acc);
if constexpr(kHasLogitsSoftCap)
{
auto apply_logits_transform =
@@ -630,7 +654,21 @@ struct BlockFmhaPipelineQRKSVSAsync
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_max = scale_s * get_validated_m(m[i_idx]);
// For BLOCKSCALE: precompute (m - shift) once per row
// Bias/Alibi/SoftCap: exp2(s - m + shift) = exp2(s - (m - shift))
// else: exp2(scale_s*s - scale_s*m + shift) = exp2(scale_s*s - (scale_s*m - shift))
auto validated_m = get_validated_m(m[i_idx]);
auto row_max = scale_s * validated_m;
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
#if CK_TILE_USE_OCP_FP8
validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap
row_max -= OCP_FP8_SHIFT; // for else branch
#else
validated_m -= FNUZ_FP8_SHIFT;
row_max -= FNUZ_FP8_SHIFT;
#endif
}
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
@@ -638,13 +676,13 @@ struct BlockFmhaPipelineQRKSVSAsync
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m);
}
else
{
if constexpr(kHasLogitsSoftCap)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m);
}
else
{
@@ -735,7 +773,27 @@ struct BlockFmhaPipelineQRKSVSAsync
#endif
}();
float v_descale = 1.0f;
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
// K and V share the same seqlen_k position within a block
const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv;
v_descale = v_descale_ptr[kv_idx];
}
// STAGE 3, KV gemm
auto o_acc0 = decltype(o_acc){};
clear_tile(o_acc0);
auto& o_acc_ = [&o_acc0, &o_acc]() -> auto& {
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
return o_acc0;
}
else
{
return o_acc;
}
}();
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
@@ -745,7 +803,7 @@ struct BlockFmhaPipelineQRKSVSAsync
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
}
block_sync_lds();
gemm_1(o_acc,
gemm_1(o_acc_,
get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
get_slice_tile(
@@ -808,13 +866,19 @@ struct BlockFmhaPipelineQRKSVSAsync
{
block_sync_lds();
gemm_1(
o_acc,
o_acc_,
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1, kK1>{}));
}
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
tile_elementwise_inout(
[&v_descale](auto& o, auto& o0) { o += o0 * v_descale; }, o_acc, o_acc0);
}
} while(i_total_loops < num_total_loop);
// store lse
@@ -922,6 +986,9 @@ struct BlockFmhaPipelineQRKSVSAsync
block_indices,
smem_ptr,
dropout,
nullptr,
nullptr,
1,
sink_v);
}
};