mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Fp8 block scale quantization for fmha fwd (#3330)
* add block scale parameters to kernel
* add block scale to kernel
* add smoke test
* format
* Revert "format"
This reverts commit 356c3c9706.
* only format my code
* format py
* fix auto not allowd in function prototype
* change instance tttt to ttff
* fix structured binding issue
* change s_acc elementwise op
* async pipeline add block scale
* add quantation P using shift exp2
* precompute (m - shift) once per row
* change blk scale seqstrt ptr name
* fix some name
* fix for deduction guide
* fix some comments
* add P scale to qr_ksvs_pipeline
* add comment to idx_identity
* change the method of calculating descale block index
* unify naming style: use block_scale_ as name prefix
* unify naming style
* update the CHANGELOG.md
* Add FP8 block scale quantization support for FMHA forward kernel
---------
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -77,11 +77,13 @@ def get_mask_cpp_check_expr(mask: str) -> str:
|
||||
QSCALE_MAP = {
|
||||
"no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE",
|
||||
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
|
||||
"blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
|
||||
}
|
||||
|
||||
QSCALE_CHECK_MAP = {
|
||||
"no": "quant_scale_enum::no_scale",
|
||||
"pertensor": "quant_scale_enum::pertensor",
|
||||
"blockscale": "quant_scale_enum::blockscale",
|
||||
}
|
||||
|
||||
BIAS_MAP = {
|
||||
|
||||
@@ -1018,7 +1018,7 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
|
||||
# no need lse/dropout kernels
|
||||
for logits, qscale, mask, bias, sink in itertools.product(
|
||||
["t", "f"],
|
||||
["no", "pertensor"],
|
||||
["no", "pertensor", "blockscale"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no"],
|
||||
["f", "t"],
|
||||
@@ -1146,7 +1146,10 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory):
|
||||
elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32:
|
||||
# no need lse/dropout kernels
|
||||
for logits, qscale, mask, bias in itertools.product(
|
||||
["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"]
|
||||
["f"],
|
||||
["no", "pertensor", "blockscale"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no"],
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip
|
||||
|
||||
@@ -230,6 +230,8 @@ struct fmha_fwd_args
|
||||
// array [batch + 1]. (Used with padding)
|
||||
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
|
||||
// array [batch + 1]. (Used with padding)
|
||||
const void* block_scale_seqstart_q_ptr;
|
||||
const void* block_scale_seqstart_k_ptr;
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
@@ -257,6 +259,9 @@ struct fmha_fwd_args
|
||||
ck_tile::index_t nhead_stride_randval;
|
||||
ck_tile::index_t nhead_stride_lse;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t nhead_stride_q_descale;
|
||||
ck_tile::index_t nhead_stride_k_descale;
|
||||
ck_tile::index_t nhead_stride_v_descale;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
@@ -264,6 +269,9 @@ struct fmha_fwd_args
|
||||
ck_tile::index_t batch_stride_randval;
|
||||
ck_tile::index_t batch_stride_lse;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
ck_tile::index_t batch_stride_q_descale;
|
||||
ck_tile::index_t batch_stride_k_descale;
|
||||
ck_tile::index_t batch_stride_v_descale;
|
||||
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
@@ -276,6 +284,9 @@ struct fmha_fwd_args
|
||||
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset;
|
||||
|
||||
ck_tile::index_t block_scale_size_q;
|
||||
ck_tile::index_t block_scale_size_kv;
|
||||
};
|
||||
|
||||
struct fmha_fwd_pagedkv_args
|
||||
@@ -615,6 +626,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_q_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.block_scale_seqstart_q_ptr,
|
||||
args.block_scale_seqstart_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
@@ -634,6 +647,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.nhead_stride_q_descale,
|
||||
args.nhead_stride_k_descale,
|
||||
args.nhead_stride_v_descale,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
@@ -642,6 +658,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.block_scale_size_q,
|
||||
args.block_scale_size_kv,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr,
|
||||
args.sink_ptr);
|
||||
@@ -679,6 +697,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.nhead_stride_q_descale,
|
||||
args.nhead_stride_k_descale,
|
||||
args.nhead_stride_v_descale,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
@@ -686,6 +707,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_lse,
|
||||
args.batch_stride_o,
|
||||
args.batch_stride_q_descale,
|
||||
args.batch_stride_k_descale,
|
||||
args.batch_stride_v_descale,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
@@ -693,6 +717,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.block_scale_size_q,
|
||||
args.block_scale_size_kv,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr,
|
||||
args.sink_ptr);
|
||||
|
||||
@@ -210,6 +210,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
const ck_tile::stream_config& stream_config,
|
||||
std::optional<std::string> json = std::nullopt)
|
||||
{
|
||||
// Note: block_scale_size_q_ and block_scale_size_kv_ should be greater than or equal to the
|
||||
// compute block size
|
||||
constexpr ck_tile::index_t block_scale_size_q_ = 128;
|
||||
constexpr ck_tile::index_t block_scale_size_kv_ = 128;
|
||||
|
||||
const std::string data_type = []() {
|
||||
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp32>)
|
||||
return "fp32";
|
||||
@@ -471,7 +476,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
std::size_t flop = 0, num_byte = 0;
|
||||
auto max_seqlen_q =
|
||||
std::numeric_limits<int32_t>::min(); // we will use max seqlen to decide grid size
|
||||
auto max_seqlen_k = std::numeric_limits<int32_t>::min();
|
||||
size_t i_block_scale_q = 0;
|
||||
size_t i_block_scale_k = 0;
|
||||
std::vector<int32_t> block_scale_seqstart_q_host = {0};
|
||||
std::vector<int32_t> block_scale_seqstart_k_host = {0};
|
||||
auto max_seqlen_k = std::numeric_limits<int32_t>::min();
|
||||
{
|
||||
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
|
||||
{
|
||||
@@ -487,6 +496,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
{
|
||||
max_seqlen_k = real_seqlen_k;
|
||||
}
|
||||
i_block_scale_q += ck_tile::integer_divide_ceil(real_seqlen_q, block_scale_size_q_);
|
||||
i_block_scale_k += ck_tile::integer_divide_ceil(real_seqlen_k, block_scale_size_kv_);
|
||||
block_scale_seqstart_q_host.push_back(i_block_scale_q);
|
||||
block_scale_seqstart_k_host.push_back(i_block_scale_k);
|
||||
|
||||
flop += nhead * (static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_q +
|
||||
static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_v);
|
||||
@@ -548,6 +561,15 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
? seqstart_k_with_padding_host.back()
|
||||
: seqstart_k_host.back()));
|
||||
|
||||
const ck_tile::index_t num_block_scale_q =
|
||||
(mode == mode_enum::batch)
|
||||
? ck_tile::integer_divide_ceil(shape_seqlen_q, block_scale_size_q_)
|
||||
: i_block_scale_q;
|
||||
const ck_tile::index_t num_block_scale_kv =
|
||||
(mode == mode_enum::batch)
|
||||
? ck_tile::integer_divide_ceil(shape_seqlen_k, block_scale_size_kv_)
|
||||
: i_block_scale_k;
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
ck_tile::HostTensor<SMPLComputeDataType> sink_host({nhead});
|
||||
@@ -599,9 +621,18 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
|
||||
|
||||
// TODO - change the tensor length for different quant scale
|
||||
ck_tile::HostTensor<float> q_descale_host(get_lengths(i_perm, 1, 1, 1, 1));
|
||||
ck_tile::HostTensor<float> k_descale_host(get_lengths(i_perm, 1, 1, 1, 1));
|
||||
ck_tile::HostTensor<float> v_descale_host(get_lengths(i_perm, 1, 1, 1, 1));
|
||||
ck_tile::HostTensor<float> q_descale_host(
|
||||
qscale.type == quant_scale_enum::blockscale
|
||||
? std::array<ck_tile::index_t, 3>{shape_batch, nhead, num_block_scale_q}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1});
|
||||
ck_tile::HostTensor<float> k_descale_host(
|
||||
qscale.type == quant_scale_enum::blockscale
|
||||
? std::array<ck_tile::index_t, 3>{shape_batch, nhead_k, num_block_scale_kv}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1});
|
||||
ck_tile::HostTensor<float> v_descale_host(
|
||||
qscale.type == quant_scale_enum::blockscale
|
||||
? std::array<ck_tile::index_t, 3>{shape_batch, nhead_k, num_block_scale_kv}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1});
|
||||
|
||||
// batch mode of lse data layout is [batch, nhead, seqlen_q]
|
||||
// group mode of lse data layout is [nhead, total_seqlen_q]
|
||||
@@ -717,6 +748,12 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
k_descale_host(0) = qkv_max / k_dtype_max;
|
||||
v_descale_host(0) = qkv_max / v_dtype_max;
|
||||
}
|
||||
else if(qscale.type == quant_scale_enum::blockscale)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<float>{0.012f, 0.015f, next_seed()}(q_descale_host);
|
||||
ck_tile::FillUniformDistribution<float>{0.012f, 0.015f, next_seed()}(k_descale_host);
|
||||
ck_tile::FillUniformDistribution<float>{0.012f, 0.015f, next_seed()}(v_descale_host);
|
||||
}
|
||||
|
||||
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine);
|
||||
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine);
|
||||
@@ -737,6 +774,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::DeviceMem q_descale_buf(q_descale_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_descale_buf(k_descale_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_descale_buf(v_descale_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem block_scale_seqstart_q_buf(block_scale_seqstart_q_host.size() *
|
||||
sizeof(int32_t));
|
||||
ck_tile::DeviceMem block_scale_seqstart_k_buf(block_scale_seqstart_k_host.size() *
|
||||
sizeof(int32_t));
|
||||
ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
|
||||
@@ -782,6 +823,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
q_descale_buf.ToDevice(q_descale_host.data());
|
||||
k_descale_buf.ToDevice(k_descale_host.data());
|
||||
v_descale_buf.ToDevice(v_descale_host.data());
|
||||
block_scale_seqstart_q_buf.ToDevice(block_scale_seqstart_q_host.data());
|
||||
block_scale_seqstart_k_buf.ToDevice(block_scale_seqstart_k_host.data());
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
// Keep logical starts in seqstart_k; pass padded K via separate pointer
|
||||
seqstart_k.ToDevice(seqstart_k_host.data());
|
||||
@@ -975,11 +1018,14 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
}();
|
||||
const ck_tile::index_t nhead_stride_bias =
|
||||
(i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q);
|
||||
const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q);
|
||||
const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
const ck_tile::index_t nhead_stride_q_descale = num_block_scale_q;
|
||||
const ck_tile::index_t nhead_stride_k_descale = num_block_scale_kv;
|
||||
const ck_tile::index_t nhead_stride_v_descale = num_block_scale_kv;
|
||||
// setup batch_stride_* arguments
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck_tile::index_t batch_stride_k =
|
||||
@@ -997,6 +1043,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
|
||||
const ck_tile::index_t batch_stride_q_descale = num_block_scale_q * nhead;
|
||||
const ck_tile::index_t batch_stride_k_descale = num_block_scale_kv * nhead_k;
|
||||
const ck_tile::index_t batch_stride_v_descale = num_block_scale_kv * nhead_k;
|
||||
// setup split_stride_* arguments (only used in split-kv kernel)
|
||||
const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q);
|
||||
const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v);
|
||||
@@ -1084,9 +1133,39 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
|
||||
{
|
||||
args.q_descale_ptr = q_descale_buf.GetDeviceBuffer();
|
||||
args.k_descale_ptr = k_descale_buf.GetDeviceBuffer();
|
||||
args.v_descale_ptr = v_descale_buf.GetDeviceBuffer();
|
||||
if(qscale.type == quant_scale_enum::blockscale)
|
||||
{
|
||||
args.q_descale_ptr =
|
||||
reinterpret_cast<const float*>(q_descale_buf.GetDeviceBuffer());
|
||||
args.k_descale_ptr =
|
||||
reinterpret_cast<const float*>(k_descale_buf.GetDeviceBuffer());
|
||||
args.v_descale_ptr =
|
||||
reinterpret_cast<const float*>(v_descale_buf.GetDeviceBuffer());
|
||||
|
||||
args.block_scale_seqstart_q_ptr =
|
||||
(mode == mode_enum::group ? block_scale_seqstart_q_buf.GetDeviceBuffer()
|
||||
: nullptr);
|
||||
args.block_scale_seqstart_k_ptr =
|
||||
(mode == mode_enum::group ? block_scale_seqstart_k_buf.GetDeviceBuffer()
|
||||
: nullptr);
|
||||
|
||||
args.nhead_stride_q_descale = nhead_stride_q_descale;
|
||||
args.nhead_stride_k_descale = nhead_stride_k_descale;
|
||||
args.nhead_stride_v_descale = nhead_stride_v_descale;
|
||||
|
||||
args.batch_stride_q_descale = batch_stride_q_descale;
|
||||
args.batch_stride_k_descale = batch_stride_k_descale;
|
||||
args.batch_stride_v_descale = batch_stride_v_descale;
|
||||
|
||||
args.block_scale_size_q = block_scale_size_q_;
|
||||
args.block_scale_size_kv = block_scale_size_kv_;
|
||||
}
|
||||
else
|
||||
{
|
||||
args.q_descale_ptr = q_descale_buf.GetDeviceBuffer();
|
||||
args.k_descale_ptr = k_descale_buf.GetDeviceBuffer();
|
||||
args.v_descale_ptr = v_descale_buf.GetDeviceBuffer();
|
||||
}
|
||||
|
||||
args.rand_val_ptr = randval_buf.GetDeviceBuffer();
|
||||
|
||||
@@ -1589,14 +1668,42 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
#endif
|
||||
|
||||
// reference
|
||||
ck_tile::
|
||||
reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
|
||||
if(qscale.type == quant_scale_enum::blockscale)
|
||||
{
|
||||
const ck_tile::index_t q_offset =
|
||||
(mode == mode_enum::batch) ? 0 : block_scale_seqstart_q_host[wb];
|
||||
const ck_tile::index_t k_offset =
|
||||
(mode == mode_enum::batch) ? 0 : block_scale_seqstart_k_host[wb];
|
||||
ck_tile::reference_batched_quant_gemm<QDataType,
|
||||
KDataType,
|
||||
SaccDataType,
|
||||
SMPLComputeDataType>(
|
||||
q_host_ref,
|
||||
k_host_ref,
|
||||
s_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale_s_host));
|
||||
ck_tile::idx_identity{},
|
||||
ck_tile::idx_identity{},
|
||||
[&](auto idx, auto value) {
|
||||
return value * scale_s *
|
||||
q_descale_host(b_idx,
|
||||
std::get<0>(idx),
|
||||
q_offset + std::get<1>(idx) / block_scale_size_q_) *
|
||||
k_descale_host(b_idx,
|
||||
std::get<0>(idx) / nr,
|
||||
k_offset + std::get<2>(idx) / block_scale_size_kv_);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::
|
||||
reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
|
||||
q_host_ref,
|
||||
k_host_ref,
|
||||
s_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale_s_host));
|
||||
}
|
||||
|
||||
if(0.f < logits_soft_cap)
|
||||
{
|
||||
@@ -1794,13 +1901,35 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
|
||||
p_host_ref,
|
||||
v_host_ref,
|
||||
o_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
oacc_element_func);
|
||||
if(qscale.type == quant_scale_enum::blockscale)
|
||||
{
|
||||
const ck_tile::index_t v_offset =
|
||||
(mode == mode_enum::batch) ? 0 : block_scale_seqstart_k_host[wb];
|
||||
ck_tile::
|
||||
reference_batched_quant_gemm<PDataType, VDataType, OaccDataType, ODataType>(
|
||||
p_host_ref,
|
||||
v_host_ref,
|
||||
o_host_ref,
|
||||
ck_tile::idx_identity{},
|
||||
[&](auto idx, auto value) {
|
||||
return ck_tile::type_convert<float>(value) *
|
||||
v_descale_host(b_idx,
|
||||
std::get<0>(idx) / nr,
|
||||
v_offset +
|
||||
std::get<2>(idx) / block_scale_size_kv_);
|
||||
},
|
||||
ck_tile::idx_identity{});
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
|
||||
p_host_ref,
|
||||
v_host_ref,
|
||||
o_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
oacc_element_func);
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<ODataType> o_host_result({nhead, real_seqlen_q, hdim_v});
|
||||
// clang-format off
|
||||
@@ -1808,7 +1937,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); });
|
||||
else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); });
|
||||
// clang-format on
|
||||
|
||||
auto [rtol, atol] = get_elimit<DataTypeConfig>(init_method);
|
||||
bool cur_pass = ck_tile::check_err(o_host_result,
|
||||
o_host_ref,
|
||||
@@ -1866,31 +1994,33 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
if(json)
|
||||
{
|
||||
dump_fmha_fwd_json_results(*json,
|
||||
data_type,
|
||||
mode == mode_enum::batch ? "batch" : "group",
|
||||
io_layout(i_perm, o_perm),
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_qs[0],
|
||||
seqlen_ks[0],
|
||||
seqlen_kpads[0],
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
scale_s,
|
||||
p_drop,
|
||||
lse,
|
||||
qscale.type == quant_scale_enum::no_scale ? "no_scale"
|
||||
: "pertensor",
|
||||
bias.type == bias_enum::elementwise_bias
|
||||
? "elementwise_bias"
|
||||
: (bias.type == bias_enum::alibi ? "alibi" : "no_bias"),
|
||||
is_v_rowmajor ? "r" : "c",
|
||||
pass,
|
||||
ave_time,
|
||||
tflops,
|
||||
gb_per_sec);
|
||||
dump_fmha_fwd_json_results(
|
||||
*json,
|
||||
data_type,
|
||||
mode == mode_enum::batch ? "batch" : "group",
|
||||
io_layout(i_perm, o_perm),
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_qs[0],
|
||||
seqlen_ks[0],
|
||||
seqlen_kpads[0],
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
scale_s,
|
||||
p_drop,
|
||||
lse,
|
||||
qscale.type == quant_scale_enum::no_scale
|
||||
? "no_scale"
|
||||
: (qscale.type == quant_scale_enum::pertensor ? "pertensor" : "blockscale"),
|
||||
bias.type == bias_enum::elementwise_bias
|
||||
? "elementwise_bias"
|
||||
: (bias.type == bias_enum::alibi ? "alibi" : "no_bias"),
|
||||
is_v_rowmajor ? "r" : "c",
|
||||
pass,
|
||||
ave_time,
|
||||
tflops,
|
||||
gb_per_sec);
|
||||
}
|
||||
|
||||
return pass ? fwd_result::success : fwd_result::failure;
|
||||
|
||||
@@ -13,6 +13,7 @@ enum class quant_scale_enum
|
||||
{
|
||||
no_scale = 0,
|
||||
pertensor = 1,
|
||||
blockscale,
|
||||
};
|
||||
|
||||
struct quant_scale_info
|
||||
@@ -25,6 +26,8 @@ struct quant_scale_info
|
||||
os << "n";
|
||||
else if(type == quant_scale_enum::pertensor)
|
||||
os << "pt";
|
||||
else if(type == quant_scale_enum::blockscale)
|
||||
os << "bs";
|
||||
}
|
||||
|
||||
static quant_scale_info decode(std::string str)
|
||||
@@ -38,6 +41,10 @@ struct quant_scale_info
|
||||
{
|
||||
info.type = quant_scale_enum::pertensor;
|
||||
}
|
||||
else if(str == "bs" || str == "2")
|
||||
{
|
||||
info.type = quant_scale_enum::blockscale;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("invalid quant scale value: " + str);
|
||||
|
||||
@@ -95,10 +95,11 @@ run_fp8bf16_tests() {
|
||||
for perm in 0 1 ; do
|
||||
for b in 1 2 ; do
|
||||
for hdim in 64 128 256 ; do
|
||||
for scale in 1 2; do
|
||||
|
||||
$EXE -prec=fp8bf16 -init=3 -b=$b -h=1 -d=$hdim -s=128 -iperm=$perm -operm=$perm -vlayout=r -qscale=1 -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=fp8bf16 -init=3 -b=$b -h=1 -d=$hdim -s=128 -iperm=$perm -operm=$perm -vlayout=r -qscale=$scale -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done ; done ; done
|
||||
done ; done ; done ; done
|
||||
}
|
||||
|
||||
run_fp8fp32_tests() {
|
||||
|
||||
Reference in New Issue
Block a user