mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Fp8 block scale quantization for fmha fwd (#3330)
* add block scale parameters to kernel
* add block scale to kernel
* add smoke test
* format
* Revert "format"
This reverts commit 356c3c9706.
* only format my code
* format py
* fix auto not allowd in function prototype
* change instance tttt to ttff
* fix structured binding issue
* change s_acc elementwise op
* async pipeline add block scale
* add quantation P using shift exp2
* precompute (m - shift) once per row
* change blk scale seqstrt ptr name
* fix some name
* fix for deduction guide
* fix some comments
* add P scale to qr_ksvs_pipeline
* add comment to idx_identity
* change the method of calculating descale block index
* unify naming style: use block_scale_ as name prefix
* unify naming style
* update the CHANGELOG.md
* Add FP8 block scale quantization support for FMHA forward kernel
---------
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -16,6 +16,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
* Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations.
|
||||
* Added gpt-oss sink support for FMHA FWD, include qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines.
|
||||
* Added persistent async input scheduler for CK Tile universal GEMM kernels to support asynchronous input streaming.
|
||||
* Added FP8 block scale quantization for FMHA forward kernel.
|
||||
|
||||
### Changed
|
||||
|
||||
|
||||
@@ -77,11 +77,13 @@ def get_mask_cpp_check_expr(mask: str) -> str:
|
||||
QSCALE_MAP = {
|
||||
"no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE",
|
||||
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
|
||||
"blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
|
||||
}
|
||||
|
||||
QSCALE_CHECK_MAP = {
|
||||
"no": "quant_scale_enum::no_scale",
|
||||
"pertensor": "quant_scale_enum::pertensor",
|
||||
"blockscale": "quant_scale_enum::blockscale",
|
||||
}
|
||||
|
||||
BIAS_MAP = {
|
||||
|
||||
@@ -1018,7 +1018,7 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
|
||||
# no need lse/dropout kernels
|
||||
for logits, qscale, mask, bias, sink in itertools.product(
|
||||
["t", "f"],
|
||||
["no", "pertensor"],
|
||||
["no", "pertensor", "blockscale"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no"],
|
||||
["f", "t"],
|
||||
@@ -1146,7 +1146,10 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory):
|
||||
elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32:
|
||||
# no need lse/dropout kernels
|
||||
for logits, qscale, mask, bias in itertools.product(
|
||||
["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"]
|
||||
["f"],
|
||||
["no", "pertensor", "blockscale"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no"],
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip
|
||||
|
||||
@@ -230,6 +230,8 @@ struct fmha_fwd_args
|
||||
// array [batch + 1]. (Used with padding)
|
||||
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
|
||||
// array [batch + 1]. (Used with padding)
|
||||
const void* block_scale_seqstart_q_ptr;
|
||||
const void* block_scale_seqstart_k_ptr;
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
@@ -257,6 +259,9 @@ struct fmha_fwd_args
|
||||
ck_tile::index_t nhead_stride_randval;
|
||||
ck_tile::index_t nhead_stride_lse;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t nhead_stride_q_descale;
|
||||
ck_tile::index_t nhead_stride_k_descale;
|
||||
ck_tile::index_t nhead_stride_v_descale;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
@@ -264,6 +269,9 @@ struct fmha_fwd_args
|
||||
ck_tile::index_t batch_stride_randval;
|
||||
ck_tile::index_t batch_stride_lse;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
ck_tile::index_t batch_stride_q_descale;
|
||||
ck_tile::index_t batch_stride_k_descale;
|
||||
ck_tile::index_t batch_stride_v_descale;
|
||||
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
@@ -276,6 +284,9 @@ struct fmha_fwd_args
|
||||
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset;
|
||||
|
||||
ck_tile::index_t block_scale_size_q;
|
||||
ck_tile::index_t block_scale_size_kv;
|
||||
};
|
||||
|
||||
struct fmha_fwd_pagedkv_args
|
||||
@@ -615,6 +626,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_q_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.block_scale_seqstart_q_ptr,
|
||||
args.block_scale_seqstart_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
@@ -634,6 +647,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.nhead_stride_q_descale,
|
||||
args.nhead_stride_k_descale,
|
||||
args.nhead_stride_v_descale,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
@@ -642,6 +658,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.block_scale_size_q,
|
||||
args.block_scale_size_kv,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr,
|
||||
args.sink_ptr);
|
||||
@@ -679,6 +697,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.nhead_stride_q_descale,
|
||||
args.nhead_stride_k_descale,
|
||||
args.nhead_stride_v_descale,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
@@ -686,6 +707,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_lse,
|
||||
args.batch_stride_o,
|
||||
args.batch_stride_q_descale,
|
||||
args.batch_stride_k_descale,
|
||||
args.batch_stride_v_descale,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
@@ -693,6 +717,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.block_scale_size_q,
|
||||
args.block_scale_size_kv,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr,
|
||||
args.sink_ptr);
|
||||
|
||||
@@ -210,6 +210,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
const ck_tile::stream_config& stream_config,
|
||||
std::optional<std::string> json = std::nullopt)
|
||||
{
|
||||
// Note: block_scale_size_q_ and block_scale_size_kv_ should be greater than or equal to the
|
||||
// compute block size
|
||||
constexpr ck_tile::index_t block_scale_size_q_ = 128;
|
||||
constexpr ck_tile::index_t block_scale_size_kv_ = 128;
|
||||
|
||||
const std::string data_type = []() {
|
||||
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp32>)
|
||||
return "fp32";
|
||||
@@ -471,7 +476,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
std::size_t flop = 0, num_byte = 0;
|
||||
auto max_seqlen_q =
|
||||
std::numeric_limits<int32_t>::min(); // we will use max seqlen to decide grid size
|
||||
auto max_seqlen_k = std::numeric_limits<int32_t>::min();
|
||||
size_t i_block_scale_q = 0;
|
||||
size_t i_block_scale_k = 0;
|
||||
std::vector<int32_t> block_scale_seqstart_q_host = {0};
|
||||
std::vector<int32_t> block_scale_seqstart_k_host = {0};
|
||||
auto max_seqlen_k = std::numeric_limits<int32_t>::min();
|
||||
{
|
||||
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
|
||||
{
|
||||
@@ -487,6 +496,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
{
|
||||
max_seqlen_k = real_seqlen_k;
|
||||
}
|
||||
i_block_scale_q += ck_tile::integer_divide_ceil(real_seqlen_q, block_scale_size_q_);
|
||||
i_block_scale_k += ck_tile::integer_divide_ceil(real_seqlen_k, block_scale_size_kv_);
|
||||
block_scale_seqstart_q_host.push_back(i_block_scale_q);
|
||||
block_scale_seqstart_k_host.push_back(i_block_scale_k);
|
||||
|
||||
flop += nhead * (static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_q +
|
||||
static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_v);
|
||||
@@ -548,6 +561,15 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
? seqstart_k_with_padding_host.back()
|
||||
: seqstart_k_host.back()));
|
||||
|
||||
const ck_tile::index_t num_block_scale_q =
|
||||
(mode == mode_enum::batch)
|
||||
? ck_tile::integer_divide_ceil(shape_seqlen_q, block_scale_size_q_)
|
||||
: i_block_scale_q;
|
||||
const ck_tile::index_t num_block_scale_kv =
|
||||
(mode == mode_enum::batch)
|
||||
? ck_tile::integer_divide_ceil(shape_seqlen_k, block_scale_size_kv_)
|
||||
: i_block_scale_k;
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
ck_tile::HostTensor<SMPLComputeDataType> sink_host({nhead});
|
||||
@@ -599,9 +621,18 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
|
||||
|
||||
// TODO - change the tensor length for different quant scale
|
||||
ck_tile::HostTensor<float> q_descale_host(get_lengths(i_perm, 1, 1, 1, 1));
|
||||
ck_tile::HostTensor<float> k_descale_host(get_lengths(i_perm, 1, 1, 1, 1));
|
||||
ck_tile::HostTensor<float> v_descale_host(get_lengths(i_perm, 1, 1, 1, 1));
|
||||
ck_tile::HostTensor<float> q_descale_host(
|
||||
qscale.type == quant_scale_enum::blockscale
|
||||
? std::array<ck_tile::index_t, 3>{shape_batch, nhead, num_block_scale_q}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1});
|
||||
ck_tile::HostTensor<float> k_descale_host(
|
||||
qscale.type == quant_scale_enum::blockscale
|
||||
? std::array<ck_tile::index_t, 3>{shape_batch, nhead_k, num_block_scale_kv}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1});
|
||||
ck_tile::HostTensor<float> v_descale_host(
|
||||
qscale.type == quant_scale_enum::blockscale
|
||||
? std::array<ck_tile::index_t, 3>{shape_batch, nhead_k, num_block_scale_kv}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1});
|
||||
|
||||
// batch mode of lse data layout is [batch, nhead, seqlen_q]
|
||||
// group mode of lse data layout is [nhead, total_seqlen_q]
|
||||
@@ -717,6 +748,12 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
k_descale_host(0) = qkv_max / k_dtype_max;
|
||||
v_descale_host(0) = qkv_max / v_dtype_max;
|
||||
}
|
||||
else if(qscale.type == quant_scale_enum::blockscale)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<float>{0.012f, 0.015f, next_seed()}(q_descale_host);
|
||||
ck_tile::FillUniformDistribution<float>{0.012f, 0.015f, next_seed()}(k_descale_host);
|
||||
ck_tile::FillUniformDistribution<float>{0.012f, 0.015f, next_seed()}(v_descale_host);
|
||||
}
|
||||
|
||||
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine);
|
||||
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine);
|
||||
@@ -737,6 +774,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::DeviceMem q_descale_buf(q_descale_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_descale_buf(k_descale_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_descale_buf(v_descale_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem block_scale_seqstart_q_buf(block_scale_seqstart_q_host.size() *
|
||||
sizeof(int32_t));
|
||||
ck_tile::DeviceMem block_scale_seqstart_k_buf(block_scale_seqstart_k_host.size() *
|
||||
sizeof(int32_t));
|
||||
ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
|
||||
@@ -782,6 +823,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
q_descale_buf.ToDevice(q_descale_host.data());
|
||||
k_descale_buf.ToDevice(k_descale_host.data());
|
||||
v_descale_buf.ToDevice(v_descale_host.data());
|
||||
block_scale_seqstart_q_buf.ToDevice(block_scale_seqstart_q_host.data());
|
||||
block_scale_seqstart_k_buf.ToDevice(block_scale_seqstart_k_host.data());
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
// Keep logical starts in seqstart_k; pass padded K via separate pointer
|
||||
seqstart_k.ToDevice(seqstart_k_host.data());
|
||||
@@ -975,11 +1018,14 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
}();
|
||||
const ck_tile::index_t nhead_stride_bias =
|
||||
(i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q);
|
||||
const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q);
|
||||
const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
const ck_tile::index_t nhead_stride_q_descale = num_block_scale_q;
|
||||
const ck_tile::index_t nhead_stride_k_descale = num_block_scale_kv;
|
||||
const ck_tile::index_t nhead_stride_v_descale = num_block_scale_kv;
|
||||
// setup batch_stride_* arguments
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck_tile::index_t batch_stride_k =
|
||||
@@ -997,6 +1043,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
|
||||
const ck_tile::index_t batch_stride_q_descale = num_block_scale_q * nhead;
|
||||
const ck_tile::index_t batch_stride_k_descale = num_block_scale_kv * nhead_k;
|
||||
const ck_tile::index_t batch_stride_v_descale = num_block_scale_kv * nhead_k;
|
||||
// setup split_stride_* arguments (only used in split-kv kernel)
|
||||
const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q);
|
||||
const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v);
|
||||
@@ -1084,9 +1133,39 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
|
||||
{
|
||||
args.q_descale_ptr = q_descale_buf.GetDeviceBuffer();
|
||||
args.k_descale_ptr = k_descale_buf.GetDeviceBuffer();
|
||||
args.v_descale_ptr = v_descale_buf.GetDeviceBuffer();
|
||||
if(qscale.type == quant_scale_enum::blockscale)
|
||||
{
|
||||
args.q_descale_ptr =
|
||||
reinterpret_cast<const float*>(q_descale_buf.GetDeviceBuffer());
|
||||
args.k_descale_ptr =
|
||||
reinterpret_cast<const float*>(k_descale_buf.GetDeviceBuffer());
|
||||
args.v_descale_ptr =
|
||||
reinterpret_cast<const float*>(v_descale_buf.GetDeviceBuffer());
|
||||
|
||||
args.block_scale_seqstart_q_ptr =
|
||||
(mode == mode_enum::group ? block_scale_seqstart_q_buf.GetDeviceBuffer()
|
||||
: nullptr);
|
||||
args.block_scale_seqstart_k_ptr =
|
||||
(mode == mode_enum::group ? block_scale_seqstart_k_buf.GetDeviceBuffer()
|
||||
: nullptr);
|
||||
|
||||
args.nhead_stride_q_descale = nhead_stride_q_descale;
|
||||
args.nhead_stride_k_descale = nhead_stride_k_descale;
|
||||
args.nhead_stride_v_descale = nhead_stride_v_descale;
|
||||
|
||||
args.batch_stride_q_descale = batch_stride_q_descale;
|
||||
args.batch_stride_k_descale = batch_stride_k_descale;
|
||||
args.batch_stride_v_descale = batch_stride_v_descale;
|
||||
|
||||
args.block_scale_size_q = block_scale_size_q_;
|
||||
args.block_scale_size_kv = block_scale_size_kv_;
|
||||
}
|
||||
else
|
||||
{
|
||||
args.q_descale_ptr = q_descale_buf.GetDeviceBuffer();
|
||||
args.k_descale_ptr = k_descale_buf.GetDeviceBuffer();
|
||||
args.v_descale_ptr = v_descale_buf.GetDeviceBuffer();
|
||||
}
|
||||
|
||||
args.rand_val_ptr = randval_buf.GetDeviceBuffer();
|
||||
|
||||
@@ -1589,14 +1668,42 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
#endif
|
||||
|
||||
// reference
|
||||
ck_tile::
|
||||
reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
|
||||
if(qscale.type == quant_scale_enum::blockscale)
|
||||
{
|
||||
const ck_tile::index_t q_offset =
|
||||
(mode == mode_enum::batch) ? 0 : block_scale_seqstart_q_host[wb];
|
||||
const ck_tile::index_t k_offset =
|
||||
(mode == mode_enum::batch) ? 0 : block_scale_seqstart_k_host[wb];
|
||||
ck_tile::reference_batched_quant_gemm<QDataType,
|
||||
KDataType,
|
||||
SaccDataType,
|
||||
SMPLComputeDataType>(
|
||||
q_host_ref,
|
||||
k_host_ref,
|
||||
s_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale_s_host));
|
||||
ck_tile::idx_identity{},
|
||||
ck_tile::idx_identity{},
|
||||
[&](auto idx, auto value) {
|
||||
return value * scale_s *
|
||||
q_descale_host(b_idx,
|
||||
std::get<0>(idx),
|
||||
q_offset + std::get<1>(idx) / block_scale_size_q_) *
|
||||
k_descale_host(b_idx,
|
||||
std::get<0>(idx) / nr,
|
||||
k_offset + std::get<2>(idx) / block_scale_size_kv_);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::
|
||||
reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
|
||||
q_host_ref,
|
||||
k_host_ref,
|
||||
s_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale_s_host));
|
||||
}
|
||||
|
||||
if(0.f < logits_soft_cap)
|
||||
{
|
||||
@@ -1794,13 +1901,35 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
|
||||
p_host_ref,
|
||||
v_host_ref,
|
||||
o_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
oacc_element_func);
|
||||
if(qscale.type == quant_scale_enum::blockscale)
|
||||
{
|
||||
const ck_tile::index_t v_offset =
|
||||
(mode == mode_enum::batch) ? 0 : block_scale_seqstart_k_host[wb];
|
||||
ck_tile::
|
||||
reference_batched_quant_gemm<PDataType, VDataType, OaccDataType, ODataType>(
|
||||
p_host_ref,
|
||||
v_host_ref,
|
||||
o_host_ref,
|
||||
ck_tile::idx_identity{},
|
||||
[&](auto idx, auto value) {
|
||||
return ck_tile::type_convert<float>(value) *
|
||||
v_descale_host(b_idx,
|
||||
std::get<0>(idx) / nr,
|
||||
v_offset +
|
||||
std::get<2>(idx) / block_scale_size_kv_);
|
||||
},
|
||||
ck_tile::idx_identity{});
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
|
||||
p_host_ref,
|
||||
v_host_ref,
|
||||
o_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
oacc_element_func);
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<ODataType> o_host_result({nhead, real_seqlen_q, hdim_v});
|
||||
// clang-format off
|
||||
@@ -1808,7 +1937,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); });
|
||||
else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); });
|
||||
// clang-format on
|
||||
|
||||
auto [rtol, atol] = get_elimit<DataTypeConfig>(init_method);
|
||||
bool cur_pass = ck_tile::check_err(o_host_result,
|
||||
o_host_ref,
|
||||
@@ -1866,31 +1994,33 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
if(json)
|
||||
{
|
||||
dump_fmha_fwd_json_results(*json,
|
||||
data_type,
|
||||
mode == mode_enum::batch ? "batch" : "group",
|
||||
io_layout(i_perm, o_perm),
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_qs[0],
|
||||
seqlen_ks[0],
|
||||
seqlen_kpads[0],
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
scale_s,
|
||||
p_drop,
|
||||
lse,
|
||||
qscale.type == quant_scale_enum::no_scale ? "no_scale"
|
||||
: "pertensor",
|
||||
bias.type == bias_enum::elementwise_bias
|
||||
? "elementwise_bias"
|
||||
: (bias.type == bias_enum::alibi ? "alibi" : "no_bias"),
|
||||
is_v_rowmajor ? "r" : "c",
|
||||
pass,
|
||||
ave_time,
|
||||
tflops,
|
||||
gb_per_sec);
|
||||
dump_fmha_fwd_json_results(
|
||||
*json,
|
||||
data_type,
|
||||
mode == mode_enum::batch ? "batch" : "group",
|
||||
io_layout(i_perm, o_perm),
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_qs[0],
|
||||
seqlen_ks[0],
|
||||
seqlen_kpads[0],
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
scale_s,
|
||||
p_drop,
|
||||
lse,
|
||||
qscale.type == quant_scale_enum::no_scale
|
||||
? "no_scale"
|
||||
: (qscale.type == quant_scale_enum::pertensor ? "pertensor" : "blockscale"),
|
||||
bias.type == bias_enum::elementwise_bias
|
||||
? "elementwise_bias"
|
||||
: (bias.type == bias_enum::alibi ? "alibi" : "no_bias"),
|
||||
is_v_rowmajor ? "r" : "c",
|
||||
pass,
|
||||
ave_time,
|
||||
tflops,
|
||||
gb_per_sec);
|
||||
}
|
||||
|
||||
return pass ? fwd_result::success : fwd_result::failure;
|
||||
|
||||
@@ -13,6 +13,7 @@ enum class quant_scale_enum
|
||||
{
|
||||
no_scale = 0,
|
||||
pertensor = 1,
|
||||
blockscale,
|
||||
};
|
||||
|
||||
struct quant_scale_info
|
||||
@@ -25,6 +26,8 @@ struct quant_scale_info
|
||||
os << "n";
|
||||
else if(type == quant_scale_enum::pertensor)
|
||||
os << "pt";
|
||||
else if(type == quant_scale_enum::blockscale)
|
||||
os << "bs";
|
||||
}
|
||||
|
||||
static quant_scale_info decode(std::string str)
|
||||
@@ -38,6 +41,10 @@ struct quant_scale_info
|
||||
{
|
||||
info.type = quant_scale_enum::pertensor;
|
||||
}
|
||||
else if(str == "bs" || str == "2")
|
||||
{
|
||||
info.type = quant_scale_enum::blockscale;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("invalid quant scale value: " + str);
|
||||
|
||||
@@ -95,10 +95,11 @@ run_fp8bf16_tests() {
|
||||
for perm in 0 1 ; do
|
||||
for b in 1 2 ; do
|
||||
for hdim in 64 128 256 ; do
|
||||
for scale in 1 2; do
|
||||
|
||||
$EXE -prec=fp8bf16 -init=3 -b=$b -h=1 -d=$hdim -s=128 -iperm=$perm -operm=$perm -vlayout=r -qscale=1 -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=fp8bf16 -init=3 -b=$b -h=1 -d=$hdim -s=128 -iperm=$perm -operm=$perm -vlayout=r -qscale=$scale -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done ; done ; done
|
||||
done ; done ; done ; done
|
||||
}
|
||||
|
||||
run_fp8fp32_tests() {
|
||||
|
||||
@@ -37,6 +37,13 @@ struct scales
|
||||
return lhs_ * rhs;
|
||||
}
|
||||
|
||||
template <typename OtherScale>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(OtherScale other) const
|
||||
{
|
||||
auto new_scale = lhs_ * other;
|
||||
return scales<std::decay_t<decltype(new_scale)>>(new_scale);
|
||||
}
|
||||
|
||||
private:
|
||||
Scale lhs_;
|
||||
};
|
||||
|
||||
@@ -119,6 +119,18 @@ struct identity
|
||||
}
|
||||
};
|
||||
|
||||
// Similar to identity, but takes an additional index parameter as the first argument.
|
||||
// The index is ignored and only the second argument (value) is forwarded.
|
||||
// Useful for indexed element-wise operations where the functor signature requires an index.
|
||||
struct idx_identity
|
||||
{
|
||||
template <typename I, typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr T&& operator()(I&& /*idx*/, T&& arg) const noexcept
|
||||
{
|
||||
return std::forward<T>(arg);
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
// RemainLengths: sequence<...>
|
||||
|
||||
@@ -47,4 +47,44 @@ CK_TILE_HOST void reference_batched_gemm(const HostTensor<ADataType>& a_b_m_k,
|
||||
make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AElementOp = ck_tile::idx_identity,
|
||||
typename BElementOp = ck_tile::idx_identity,
|
||||
typename ACCElementOp = ck_tile::idx_identity>
|
||||
CK_TILE_HOST void reference_batched_quant_gemm(const HostTensor<ADataType>& a_b_m_k,
|
||||
const HostTensor<BDataType>& b_b_n_k,
|
||||
HostTensor<CDataType>& c_b_m_n,
|
||||
const AElementOp& a_element_op = {},
|
||||
const BElementOp& b_element_op = {},
|
||||
const ACCElementOp& acc_element_op = {})
|
||||
{
|
||||
const int N = b_b_n_k.mDesc.get_lengths()[1];
|
||||
const int K = b_b_n_k.mDesc.get_lengths()[2];
|
||||
|
||||
auto f = [&](auto batch, auto m) {
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
AccDataType v_a = ck_tile::type_convert<AccDataType>(
|
||||
a_element_op(std::make_tuple(batch, m, k), a_b_m_k(batch, m, k)));
|
||||
AccDataType v_b = ck_tile::type_convert<AccDataType>(
|
||||
b_element_op(std::make_tuple(batch, n, k), b_b_n_k(batch, n, k)));
|
||||
|
||||
v_acc += v_a * v_b;
|
||||
}
|
||||
|
||||
c_b_m_n(batch, m, n) = ck_tile::type_convert<CDataType>(
|
||||
acc_element_op(std::make_tuple(batch, m, n), v_acc));
|
||||
}
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -12,6 +12,7 @@ enum class BlockAttentionQuantScaleEnum
|
||||
{
|
||||
NO_SCALE = 0,
|
||||
PERTENSOR = 1,
|
||||
BLOCKSCALE,
|
||||
};
|
||||
|
||||
template <BlockAttentionQuantScaleEnum>
|
||||
@@ -27,5 +28,10 @@ struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::PERTENSOR
|
||||
{
|
||||
static constexpr const char* name = "pertensor";
|
||||
};
|
||||
template <>
|
||||
struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::BLOCKSCALE>
|
||||
{
|
||||
static constexpr const char* name = "blockscale";
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -168,6 +168,29 @@ struct FmhaFwdKernel
|
||||
const void* v_descale_ptr = nullptr;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonBlockScaleKargs : public FmhaFwdCommonQScaleKargs
|
||||
{
|
||||
ck_tile::index_t nhead_stride_q_descale;
|
||||
ck_tile::index_t nhead_stride_k_descale;
|
||||
ck_tile::index_t nhead_stride_v_descale;
|
||||
|
||||
ck_tile::index_t block_scale_size_q;
|
||||
ck_tile::index_t block_scale_size_kv;
|
||||
};
|
||||
|
||||
struct FmhaFwdBatchBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs
|
||||
{
|
||||
ck_tile::index_t batch_stride_q_descale;
|
||||
ck_tile::index_t batch_stride_k_descale;
|
||||
ck_tile::index_t batch_stride_v_descale;
|
||||
};
|
||||
|
||||
struct FmhaFwdGroupBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs
|
||||
{
|
||||
const int32_t* block_scale_seqstart_q_ptr;
|
||||
const int32_t* block_scale_seqstart_k_ptr;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonLSEKargs
|
||||
{
|
||||
void* lse_ptr = nullptr;
|
||||
@@ -243,9 +266,12 @@ struct FmhaFwdKernel
|
||||
FmhaFwdEmptyKargs<0>>>,
|
||||
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
FmhaFwdEmptyKargs<3>>,
|
||||
std::conditional_t<
|
||||
QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE,
|
||||
FmhaFwdBatchBlockScaleKargs,
|
||||
FmhaFwdEmptyKargs<3>>>,
|
||||
std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
|
||||
{
|
||||
@@ -269,9 +295,12 @@ struct FmhaFwdKernel
|
||||
FmhaFwdEmptyKargs<0>>>,
|
||||
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
FmhaFwdEmptyKargs<3>>,
|
||||
std::conditional_t<
|
||||
QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE,
|
||||
FmhaFwdGroupBlockScaleKargs,
|
||||
FmhaFwdEmptyKargs<3>>>,
|
||||
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
|
||||
std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
|
||||
@@ -328,6 +357,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_descale,
|
||||
ck_tile::index_t nhead_stride_k_descale,
|
||||
ck_tile::index_t nhead_stride_v_descale,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
@@ -335,6 +367,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t batch_stride_q_descale,
|
||||
ck_tile::index_t batch_stride_k_descale,
|
||||
ck_tile::index_t batch_stride_v_descale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
@@ -343,6 +378,8 @@ struct FmhaFwdKernel
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset,
|
||||
ck_tile::index_t block_scale_size_q,
|
||||
ck_tile::index_t block_scale_size_kv,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
@@ -413,6 +450,23 @@ struct FmhaFwdKernel
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
}
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
kargs.q_descale_ptr = q_descale_ptr;
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
|
||||
kargs.nhead_stride_q_descale = nhead_stride_q_descale;
|
||||
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
|
||||
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
|
||||
|
||||
kargs.batch_stride_q_descale = batch_stride_q_descale;
|
||||
kargs.batch_stride_k_descale = batch_stride_k_descale;
|
||||
kargs.batch_stride_v_descale = batch_stride_v_descale;
|
||||
|
||||
kargs.block_scale_size_q = block_scale_size_q;
|
||||
kargs.block_scale_size_kv = block_scale_size_kv;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
@@ -478,6 +532,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_descale,
|
||||
ck_tile::index_t nhead_stride_k_descale,
|
||||
ck_tile::index_t nhead_stride_v_descale,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
@@ -485,6 +542,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t batch_stride_q_descale,
|
||||
ck_tile::index_t batch_stride_k_descale,
|
||||
ck_tile::index_t batch_stride_v_descale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
@@ -492,6 +552,8 @@ struct FmhaFwdKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
|
||||
ck_tile::index_t block_scale_size_q,
|
||||
ck_tile::index_t block_scale_size_kv,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
@@ -528,6 +590,9 @@ struct FmhaFwdKernel
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
nhead_stride_q_descale,
|
||||
nhead_stride_k_descale,
|
||||
nhead_stride_v_descale,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
@@ -535,6 +600,9 @@ struct FmhaFwdKernel
|
||||
batch_stride_randval,
|
||||
batch_stride_lse,
|
||||
batch_stride_o,
|
||||
batch_stride_q_descale,
|
||||
batch_stride_k_descale,
|
||||
batch_stride_v_descale,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
sink_size,
|
||||
@@ -542,6 +610,8 @@ struct FmhaFwdKernel
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
block_scale_size_q,
|
||||
block_scale_size_kv,
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr,
|
||||
sink_ptr);
|
||||
@@ -581,6 +651,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_descale,
|
||||
ck_tile::index_t nhead_stride_k_descale,
|
||||
ck_tile::index_t nhead_stride_v_descale,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
@@ -588,6 +661,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t batch_stride_q_descale,
|
||||
ck_tile::index_t batch_stride_k_descale,
|
||||
ck_tile::index_t batch_stride_v_descale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
@@ -595,6 +671,8 @@ struct FmhaFwdKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<const void*, const void*>& drop_seed_offset,
|
||||
ck_tile::index_t block_scale_size_q,
|
||||
ck_tile::index_t block_scale_size_kv,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
@@ -631,6 +709,9 @@ struct FmhaFwdKernel
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
nhead_stride_q_descale,
|
||||
nhead_stride_k_descale,
|
||||
nhead_stride_v_descale,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
@@ -638,6 +719,9 @@ struct FmhaFwdKernel
|
||||
batch_stride_randval,
|
||||
batch_stride_lse,
|
||||
batch_stride_o,
|
||||
batch_stride_q_descale,
|
||||
batch_stride_k_descale,
|
||||
batch_stride_v_descale,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
sink_size,
|
||||
@@ -645,6 +729,8 @@ struct FmhaFwdKernel
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
block_scale_size_q,
|
||||
block_scale_size_kv,
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr,
|
||||
sink_ptr);
|
||||
@@ -666,6 +752,8 @@ struct FmhaFwdKernel
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_q_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
const void* block_scale_seqstart_q_ptr,
|
||||
const void* block_scale_seqstart_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -685,6 +773,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_descale,
|
||||
ck_tile::index_t nhead_stride_k_descale,
|
||||
ck_tile::index_t nhead_stride_v_descale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
@@ -694,6 +785,8 @@ struct FmhaFwdKernel
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset,
|
||||
ck_tile::index_t block_scale_size_q,
|
||||
ck_tile::index_t block_scale_size_kv,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
@@ -763,6 +856,24 @@ struct FmhaFwdKernel
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
}
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
kargs.q_descale_ptr = q_descale_ptr;
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
|
||||
kargs.nhead_stride_q_descale = nhead_stride_q_descale;
|
||||
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
|
||||
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
|
||||
|
||||
kargs.block_scale_size_q = block_scale_size_q;
|
||||
kargs.block_scale_size_kv = block_scale_size_kv;
|
||||
|
||||
kargs.block_scale_seqstart_q_ptr =
|
||||
reinterpret_cast<const int32_t*>(block_scale_seqstart_q_ptr);
|
||||
kargs.block_scale_seqstart_k_ptr =
|
||||
reinterpret_cast<const int32_t*>(block_scale_seqstart_k_ptr);
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
@@ -814,6 +925,8 @@ struct FmhaFwdKernel
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_q_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
const void* block_scale_seqstart_q_ptr,
|
||||
const void* block_scale_seqstart_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -833,6 +946,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_descale,
|
||||
ck_tile::index_t nhead_stride_k_descale,
|
||||
ck_tile::index_t nhead_stride_v_descale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
@@ -841,6 +957,8 @@ struct FmhaFwdKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
|
||||
ck_tile::index_t block_scale_size_q,
|
||||
ck_tile::index_t block_scale_size_kv,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
@@ -860,6 +978,8 @@ struct FmhaFwdKernel
|
||||
seqstart_k_ptr,
|
||||
seqlen_q_ptr,
|
||||
seqlen_k_ptr,
|
||||
block_scale_seqstart_q_ptr,
|
||||
block_scale_seqstart_k_ptr,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
@@ -879,6 +999,9 @@ struct FmhaFwdKernel
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
nhead_stride_q_descale,
|
||||
nhead_stride_k_descale,
|
||||
nhead_stride_v_descale,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
sink_size,
|
||||
@@ -887,6 +1010,8 @@ struct FmhaFwdKernel
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
block_scale_size_q,
|
||||
block_scale_size_kv,
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr,
|
||||
sink_ptr);
|
||||
@@ -909,6 +1034,8 @@ struct FmhaFwdKernel
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_q_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
const void* block_scale_seqstart_q_ptr,
|
||||
const void* block_scale_seqstart_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -928,6 +1055,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_descale,
|
||||
ck_tile::index_t nhead_stride_k_descale,
|
||||
ck_tile::index_t nhead_stride_v_descale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
@@ -936,6 +1066,8 @@ struct FmhaFwdKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<const void*, const void*>& drop_seed_offset,
|
||||
ck_tile::index_t block_scale_size_q,
|
||||
ck_tile::index_t block_scale_size_kv,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
@@ -955,6 +1087,8 @@ struct FmhaFwdKernel
|
||||
seqstart_k_ptr,
|
||||
seqlen_q_ptr,
|
||||
seqlen_k_ptr,
|
||||
block_scale_seqstart_q_ptr,
|
||||
block_scale_seqstart_k_ptr,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
@@ -974,6 +1108,9 @@ struct FmhaFwdKernel
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
nhead_stride_q_descale,
|
||||
nhead_stride_k_descale,
|
||||
nhead_stride_v_descale,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
sink_size,
|
||||
@@ -982,6 +1119,8 @@ struct FmhaFwdKernel
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
block_scale_size_q,
|
||||
block_scale_size_kv,
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr,
|
||||
sink_ptr);
|
||||
@@ -1111,13 +1250,16 @@ struct FmhaFwdKernel
|
||||
const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
|
||||
const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
|
||||
|
||||
long_index_t batch_offset_q = 0;
|
||||
long_index_t batch_offset_k = 0;
|
||||
long_index_t batch_offset_v = 0;
|
||||
long_index_t batch_offset_bias = 0;
|
||||
long_index_t batch_offset_randval = 0;
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
long_index_t batch_offset_q = 0;
|
||||
long_index_t batch_offset_k = 0;
|
||||
long_index_t batch_offset_v = 0;
|
||||
long_index_t batch_offset_bias = 0;
|
||||
long_index_t batch_offset_randval = 0;
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
long_index_t batch_offset_q_descale = 0;
|
||||
long_index_t batch_offset_k_descale = 0;
|
||||
long_index_t batch_offset_v_descale = 0;
|
||||
const float sink_value =
|
||||
kargs.sink_ptr != nullptr
|
||||
? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
|
||||
@@ -1153,6 +1295,14 @@ struct FmhaFwdKernel
|
||||
{
|
||||
batch_offset_randval = query_start * kargs.stride_randval;
|
||||
}
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
const long_index_t bquery_start = kargs.block_scale_seqstart_q_ptr[i_batch];
|
||||
const long_index_t bkey_start = kargs.block_scale_seqstart_k_ptr[i_batch];
|
||||
batch_offset_q_descale = bquery_start;
|
||||
batch_offset_k_descale = bkey_start;
|
||||
batch_offset_v_descale = bkey_start;
|
||||
}
|
||||
batch_offset_o = query_start * kargs.stride_o;
|
||||
|
||||
// real logical lengths (exclude PAD)
|
||||
@@ -1220,6 +1370,15 @@ struct FmhaFwdKernel
|
||||
batch_offset_randval =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
|
||||
}
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
batch_offset_q_descale =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_q_descale;
|
||||
batch_offset_k_descale =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_k_descale;
|
||||
batch_offset_v_descale =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_v_descale;
|
||||
}
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
|
||||
// If cumulative seqlen pointers are provided, override per-batch effective lengths
|
||||
@@ -1540,7 +1699,8 @@ struct FmhaFwdKernel
|
||||
}();
|
||||
|
||||
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
|
||||
auto o_acc_tile = [&]() {
|
||||
|
||||
auto o_acc_tile = [&, i_nhead_ = i_nhead]() {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
{
|
||||
// TODO - move global load of descale to pipeline
|
||||
@@ -1581,8 +1741,62 @@ struct FmhaFwdKernel
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
nullptr,
|
||||
nullptr,
|
||||
1,
|
||||
sink_value);
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
const float* q_descale_ptr =
|
||||
reinterpret_cast<const float*>(kargs.q_descale_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_q_descale +
|
||||
batch_offset_q_descale;
|
||||
const float* k_descale_ptr =
|
||||
reinterpret_cast<const float*>(kargs.k_descale_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
|
||||
kargs.nhead_stride_k_descale +
|
||||
batch_offset_k_descale;
|
||||
const float* v_descale_ptr =
|
||||
reinterpret_cast<const float*>(kargs.v_descale_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
|
||||
kargs.nhead_stride_v_descale +
|
||||
batch_offset_v_descale;
|
||||
|
||||
size_t idx = i_m0 / kargs.block_scale_size_q;
|
||||
float q_descale = q_descale_ptr[idx];
|
||||
// BLOCKSCALE: P is scaled in exp2(x+shift) where shift=7 or 8
|
||||
// Both P and rowsum are scaled by 2^shift, canceling in normalization
|
||||
// No additional scaling needed in p_compute_element_func or o_acc_element_func
|
||||
|
||||
return FmhaPipeline{}(
|
||||
q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window,
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
scales<float>(q_descale), // s_acc_element_func
|
||||
identity{}, // p_compute_element_func - No scaling (done in exp2)
|
||||
identity{}, // o_acc_element_func - No dequant needed (canceled by rowsum)
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
k_descale_ptr,
|
||||
v_descale_ptr,
|
||||
kargs.block_scale_size_kv,
|
||||
sink_value);
|
||||
}
|
||||
else
|
||||
{
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
|
||||
@@ -57,8 +57,13 @@ struct BlockFmhaPipelineQRKSVS
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr auto QScaleEnum = Problem::QScaleEnum;
|
||||
static constexpr bool kHasSink = Problem::kHasSink;
|
||||
|
||||
// For BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift]
|
||||
static constexpr float OCP_FP8_SHIFT = 8.0f;
|
||||
static constexpr float FNUZ_FP8_SHIFT = 7.0f;
|
||||
|
||||
static constexpr uint32_t DS_READ = 0x100; // Barrier for DS (data share) read
|
||||
static constexpr uint32_t MFMA = 0x008; // Barrier for MFMA (matrix multiply-accumulate)
|
||||
|
||||
@@ -167,6 +172,9 @@ struct BlockFmhaPipelineQRKSVS
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout,
|
||||
const float* k_descale_ptr,
|
||||
const float* v_descale_ptr,
|
||||
const index_t block_scale_size_kv,
|
||||
const float sink_v) const
|
||||
{
|
||||
static_assert(
|
||||
@@ -358,6 +366,13 @@ struct BlockFmhaPipelineQRKSVS
|
||||
static_assert(1 <= k1_loops);
|
||||
do
|
||||
{
|
||||
float k_descale = 1.0f;
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
// K and V share the same seqlen_k position within a block
|
||||
const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv;
|
||||
k_descale = k_descale_ptr[kv_idx];
|
||||
}
|
||||
// STAGE 1, QK gemm
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window.get_bottom_tensor_view(),
|
||||
@@ -427,11 +442,20 @@ struct BlockFmhaPipelineQRKSVS
|
||||
k_lds_window);
|
||||
schedule_gemm0();
|
||||
}
|
||||
// dequant
|
||||
auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
return s_acc_element_func * k_descale;
|
||||
}
|
||||
else
|
||||
return s_acc_element_func;
|
||||
}();
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, softmax
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
s_acc = tile_elementwise_in(s_acc_element_func_, s_acc);
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
@@ -449,7 +473,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
s_acc = tile_elementwise_in(s_acc_element_func_, s_acc);
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
@@ -466,7 +490,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
}
|
||||
else
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
s_acc = tile_elementwise_in(s_acc_element_func_, s_acc);
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
auto apply_logits_transform =
|
||||
@@ -571,7 +595,21 @@ struct BlockFmhaPipelineQRKSVS
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
// For BLOCKSCALE: precompute (m - shift) once per row
|
||||
// Bias/Alibi/SoftCap: exp2(s - m + shift) = exp2(s - (m - shift))
|
||||
// else: exp2(scale_s*s - scale_s*m + shift) = exp2(scale_s*s - (scale_s*m - shift))
|
||||
auto validated_m = get_validated_m(m[i_idx]);
|
||||
auto row_max = scale_s * validated_m;
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
#if CK_TILE_USE_OCP_FP8
|
||||
validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap
|
||||
row_max -= OCP_FP8_SHIFT; // for else branch
|
||||
#else
|
||||
validated_m -= FNUZ_FP8_SHIFT;
|
||||
row_max -= FNUZ_FP8_SHIFT;
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
@@ -579,13 +617,13 @@ struct BlockFmhaPipelineQRKSVS
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m);
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -676,18 +714,39 @@ struct BlockFmhaPipelineQRKSVS
|
||||
store_tile(v_lds_window,
|
||||
tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
|
||||
}
|
||||
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
|
||||
const auto p =
|
||||
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
|
||||
float v_descale = 1.0f;
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
// K and V share the same seqlen_k position within a block
|
||||
const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv;
|
||||
v_descale = v_descale_ptr[kv_idx];
|
||||
}
|
||||
// STAGE 3, KV gemm
|
||||
auto o_acc0 = decltype(o_acc){};
|
||||
clear_tile(o_acc0);
|
||||
|
||||
auto& o_acc_ = [&o_acc0, &o_acc]() -> auto& {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
return o_acc0;
|
||||
}
|
||||
else
|
||||
{
|
||||
return o_acc;
|
||||
}
|
||||
}();
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
const auto v = load_tile(v_dram_window); // load next v
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
gemm_1(o_acc_,
|
||||
get_slice_tile(
|
||||
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_window);
|
||||
@@ -722,11 +781,16 @@ struct BlockFmhaPipelineQRKSVS
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
gemm_1(o_acc_,
|
||||
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
tile_elementwise_inout(
|
||||
[&v_descale](auto& o, auto& o0) { o += o0 * v_descale; }, o_acc, o_acc0);
|
||||
}
|
||||
} while(++i_total_loops < num_total_loop);
|
||||
|
||||
// store lse
|
||||
@@ -846,6 +910,9 @@ struct BlockFmhaPipelineQRKSVS
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
nullptr,
|
||||
nullptr,
|
||||
1,
|
||||
sink_v);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -46,6 +46,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
|
||||
static constexpr auto QScaleEnum = Problem::QScaleEnum;
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
@@ -64,6 +65,10 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr bool kHasSink = Problem::kHasSink;
|
||||
|
||||
// For BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift]
|
||||
static constexpr float OCP_FP8_SHIFT = 8.0f;
|
||||
static constexpr float FNUZ_FP8_SHIFT = 7.0f;
|
||||
|
||||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||
!kHasLogitsSoftCap)) ||
|
||||
@@ -190,6 +195,9 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout,
|
||||
const float* k_descale_ptr,
|
||||
const float* v_descale_ptr,
|
||||
const index_t block_scale_size_kv,
|
||||
const float sink_v) const
|
||||
{
|
||||
static_assert(
|
||||
@@ -403,6 +411,13 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
// main loop
|
||||
do
|
||||
{
|
||||
float k_descale = 1.0f;
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
// K and V share the same seqlen_k position within a block
|
||||
const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv;
|
||||
k_descale = k_descale_ptr[kv_idx];
|
||||
}
|
||||
// STAGE 1, QK gemm
|
||||
clear_tile(s_acc); // initialize C
|
||||
if constexpr(k0_loops > 1)
|
||||
@@ -449,11 +464,20 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
sequence<(LdsSeq.at(number<k0_loops - 1>{}) + 1) * kN0, kK0>{}));
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(1);
|
||||
// dequant
|
||||
auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
return s_acc_element_func * k_descale;
|
||||
}
|
||||
else
|
||||
return s_acc_element_func;
|
||||
}();
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, softmax
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
s_acc = tile_elementwise_in(s_acc_element_func_, s_acc);
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
@@ -471,7 +495,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
s_acc = tile_elementwise_in(s_acc_element_func_, s_acc);
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
@@ -488,7 +512,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
}
|
||||
else
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
s_acc = tile_elementwise_in(s_acc_element_func_, s_acc);
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
auto apply_logits_transform =
|
||||
@@ -630,7 +654,21 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
// For BLOCKSCALE: precompute (m - shift) once per row
|
||||
// Bias/Alibi/SoftCap: exp2(s - m + shift) = exp2(s - (m - shift))
|
||||
// else: exp2(scale_s*s - scale_s*m + shift) = exp2(scale_s*s - (scale_s*m - shift))
|
||||
auto validated_m = get_validated_m(m[i_idx]);
|
||||
auto row_max = scale_s * validated_m;
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
#if CK_TILE_USE_OCP_FP8
|
||||
validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap
|
||||
row_max -= OCP_FP8_SHIFT; // for else branch
|
||||
#else
|
||||
validated_m -= FNUZ_FP8_SHIFT;
|
||||
row_max -= FNUZ_FP8_SHIFT;
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
@@ -638,13 +676,13 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m);
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -735,7 +773,27 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
#endif
|
||||
}();
|
||||
|
||||
float v_descale = 1.0f;
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
// K and V share the same seqlen_k position within a block
|
||||
const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv;
|
||||
v_descale = v_descale_ptr[kv_idx];
|
||||
}
|
||||
// STAGE 3, KV gemm
|
||||
auto o_acc0 = decltype(o_acc){};
|
||||
clear_tile(o_acc0);
|
||||
|
||||
auto& o_acc_ = [&o_acc0, &o_acc]() -> auto& {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
return o_acc0;
|
||||
}
|
||||
else
|
||||
{
|
||||
return o_acc;
|
||||
}
|
||||
}();
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
@@ -745,7 +803,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
|
||||
}
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
gemm_1(o_acc_,
|
||||
get_slice_tile(
|
||||
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
get_slice_tile(
|
||||
@@ -808,13 +866,19 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_1(
|
||||
o_acc,
|
||||
o_acc_,
|
||||
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
|
||||
get_slice_tile(
|
||||
v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1, kK1>{}));
|
||||
}
|
||||
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
tile_elementwise_inout(
|
||||
[&v_descale](auto& o, auto& o0) { o += o0 * v_descale; }, o_acc, o_acc0);
|
||||
}
|
||||
} while(i_total_loops < num_total_loop);
|
||||
|
||||
// store lse
|
||||
@@ -922,6 +986,9 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
nullptr,
|
||||
nullptr,
|
||||
1,
|
||||
sink_v);
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user