mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
add blockscale parameters to kernel
This commit is contained in:
@@ -66,11 +66,13 @@ def get_mask_check_map(mask: str):
|
||||
QSCALE_MAP = {
|
||||
"no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE",
|
||||
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
|
||||
"blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
|
||||
}
|
||||
|
||||
QSCALE_CHECK_MAP = {
|
||||
"no": "quant_scale_enum::no_scale",
|
||||
"pertensor": "quant_scale_enum::pertensor",
|
||||
"blockscale": "quant_scale_enum::blockscale",
|
||||
}
|
||||
|
||||
BIAS_MAP = {
|
||||
|
||||
@@ -343,7 +343,7 @@ class FmhaFwdPipeline:
|
||||
F_bias: str # true/false
|
||||
F_lse: str #
|
||||
F_dropout: str #
|
||||
F_qscale: str # no/pertensor
|
||||
F_qscale: str # no/pertensor/blockscale
|
||||
F_mask: str # value from MASK_MAP
|
||||
F_skip: str # true/false
|
||||
F_trload: str # true/false
|
||||
@@ -739,7 +739,7 @@ class KernelComponentFactoryGfx9:
|
||||
# no need lse/dropout kernels
|
||||
for logits, qscale, mask, bias in itertools.product(
|
||||
["f"],
|
||||
["no", "pertensor"],
|
||||
["no", "pertensor", "blockscale"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no"],
|
||||
):
|
||||
@@ -830,7 +830,7 @@ class KernelComponentFactoryGfx12:
|
||||
elif dtype in ["fp8", "fp8bf16", "fp8fp32"]:
|
||||
# no need lse/dropout kernels
|
||||
for logits, qscale, mask, bias in itertools.product(
|
||||
["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"]
|
||||
["f"], ["no", "pertensor", "blockscale"], get_mask_map(mask_impl).keys(), ["no"]
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
|
||||
|
||||
@@ -256,6 +256,9 @@ struct fmha_fwd_args
|
||||
ck_tile::index_t nhead_stride_randval;
|
||||
ck_tile::index_t nhead_stride_lse;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t nhead_stride_q_descale;
|
||||
ck_tile::index_t nhead_stride_k_descale;
|
||||
ck_tile::index_t nhead_stride_v_descale;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
@@ -263,6 +266,9 @@ struct fmha_fwd_args
|
||||
ck_tile::index_t batch_stride_randval;
|
||||
ck_tile::index_t batch_stride_lse;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
ck_tile::index_t batch_stride_q_descale;
|
||||
ck_tile::index_t batch_stride_k_descale;
|
||||
ck_tile::index_t batch_stride_v_descale;
|
||||
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
@@ -274,6 +280,9 @@ struct fmha_fwd_args
|
||||
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset;
|
||||
|
||||
ck_tile::index_t block_scale_m;
|
||||
ck_tile::index_t block_scale_n;
|
||||
};
|
||||
|
||||
struct fmha_fwd_pagedkv_args
|
||||
@@ -604,6 +613,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q_descale,
|
||||
args.nhead_stride_k_descale,
|
||||
args.nhead_stride_v_descale,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
@@ -618,6 +630,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.block_scale_m,
|
||||
args.block_scale_n,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr);
|
||||
}
|
||||
@@ -654,6 +668,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.nhead_stride_q_descale,
|
||||
args.nhead_stride_k_descale,
|
||||
args.nhead_stride_v_descale,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
@@ -661,12 +678,17 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_lse,
|
||||
args.batch_stride_o,
|
||||
args.batch_stride_q_descale,
|
||||
args.batch_stride_k_descale,
|
||||
args.batch_stride_v_descale,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.block_scale_m,
|
||||
args.block_scale_n,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr);
|
||||
}
|
||||
|
||||
@@ -187,6 +187,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
const ck_tile::stream_config& stream_config,
|
||||
std::optional<std::string> json = std::nullopt)
|
||||
{
|
||||
constexpr ck_tile::index_t block_scale_m_ = 128;
|
||||
constexpr ck_tile::index_t block_scale_n_ = 128;
|
||||
|
||||
const std::string data_type = []() {
|
||||
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp32>)
|
||||
return "fp32";
|
||||
@@ -448,7 +451,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
std::size_t flop = 0, num_byte = 0;
|
||||
auto max_seqlen_q =
|
||||
std::numeric_limits<int32_t>::min(); // we will use max seqlen to decide grid size
|
||||
auto max_seqlen_k = std::numeric_limits<int32_t>::min();
|
||||
size_t num_block_scale_q = 0;
|
||||
size_t num_block_scale_k = 0;
|
||||
auto max_seqlen_k = std::numeric_limits<int32_t>::min();
|
||||
{
|
||||
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
|
||||
{
|
||||
@@ -464,6 +469,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
{
|
||||
max_seqlen_k = real_seqlen_k;
|
||||
}
|
||||
num_block_scale_q += ck_tile::integer_divide_ceil(real_seqlen_q, block_scale_m_);
|
||||
num_block_scale_k += ck_tile::integer_divide_ceil(real_seqlen_k, block_scale_n_);
|
||||
|
||||
flop += nhead * (static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_q +
|
||||
static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_v);
|
||||
@@ -525,6 +532,13 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
? seqstart_k_with_padding_host.back()
|
||||
: seqstart_k_host.back()));
|
||||
|
||||
const ck_tile::index_t num_block_scale_m =
|
||||
(mode == mode_enum::batch) ? ck_tile::integer_divide_ceil(shape_seqlen_q, block_scale_m_)
|
||||
: num_block_scale_q;
|
||||
const ck_tile::index_t num_block_scale_n =
|
||||
(mode == mode_enum::batch) ? ck_tile::integer_divide_ceil(shape_seqlen_k, block_scale_n_)
|
||||
: num_block_scale_k;
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
ck_tile::HostTensor<KDataType> k_host(
|
||||
@@ -575,9 +589,18 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
|
||||
|
||||
// TODO - change the tensor length for different quant scale
|
||||
ck_tile::HostTensor<float> q_descale_host(get_lengths(i_perm, 1, 1, 1, 1));
|
||||
ck_tile::HostTensor<float> k_descale_host(get_lengths(i_perm, 1, 1, 1, 1));
|
||||
ck_tile::HostTensor<float> v_descale_host(get_lengths(i_perm, 1, 1, 1, 1));
|
||||
ck_tile::HostTensor<float> q_descale_host(
|
||||
qscale.type == quant_scale_enum::blockscale
|
||||
? std::array<ck_tile::index_t, 3>{shape_batch, nhead, num_block_scale_m}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1});
|
||||
ck_tile::HostTensor<float> k_descale_host(
|
||||
qscale.type == quant_scale_enum::blockscale
|
||||
? std::array<ck_tile::index_t, 3>{shape_batch, nhead, num_block_scale_n}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1});
|
||||
ck_tile::HostTensor<float> v_descale_host(
|
||||
qscale.type == quant_scale_enum::blockscale
|
||||
? std::array<ck_tile::index_t, 3>{shape_batch, nhead, num_block_scale_n}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1});
|
||||
|
||||
// batch mode of lse data layout is [batch, nhead, seqlen_q]
|
||||
// group mode of lse data layout is [nhead, total_seqlen_q]
|
||||
@@ -692,6 +715,12 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
k_descale_host(0) = qkv_max / k_dtype_max;
|
||||
v_descale_host(0) = qkv_max / v_dtype_max;
|
||||
}
|
||||
else if(qscale.type == quant_scale_enum::blockscale)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<float>{0.015f, 0.02f, next_seed()}(q_descale_host);
|
||||
ck_tile::FillUniformDistribution<float>{0.015f, 0.02f, next_seed()}(k_descale_host);
|
||||
ck_tile::FillUniformDistribution<float>{0.015f, 0.02f, next_seed()}(v_descale_host);
|
||||
}
|
||||
|
||||
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine);
|
||||
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine);
|
||||
@@ -941,11 +970,14 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
}();
|
||||
const ck_tile::index_t nhead_stride_bias =
|
||||
(i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q);
|
||||
const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q);
|
||||
const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
const ck_tile::index_t nhead_stride_q_descale = num_block_scale_m;
|
||||
const ck_tile::index_t nhead_stride_k_descale = num_block_scale_n;
|
||||
const ck_tile::index_t nhead_stride_v_descale = num_block_scale_n;
|
||||
// setup batch_stride_* arguments
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck_tile::index_t batch_stride_k =
|
||||
@@ -963,6 +995,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
|
||||
const ck_tile::index_t batch_stride_q_descale = num_block_scale_m * nhead;
|
||||
const ck_tile::index_t batch_stride_k_descale = num_block_scale_n * nhead_k;
|
||||
const ck_tile::index_t batch_stride_v_descale = num_block_scale_n * nhead_k;
|
||||
// setup split_stride_* arguments (only used in split-kv kernel)
|
||||
const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q);
|
||||
const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v);
|
||||
@@ -1046,9 +1081,32 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
|
||||
{
|
||||
args.q_descale_ptr = q_descale_buf.GetDeviceBuffer();
|
||||
args.k_descale_ptr = k_descale_buf.GetDeviceBuffer();
|
||||
args.v_descale_ptr = v_descale_buf.GetDeviceBuffer();
|
||||
if(qscale.type == quant_scale_enum::blockscale)
|
||||
{
|
||||
args.q_descale_ptr =
|
||||
reinterpret_cast<const float*>(q_descale_buf.GetDeviceBuffer());
|
||||
args.k_descale_ptr =
|
||||
reinterpret_cast<const float*>(k_descale_buf.GetDeviceBuffer());
|
||||
args.v_descale_ptr =
|
||||
reinterpret_cast<const float*>(v_descale_buf.GetDeviceBuffer());
|
||||
|
||||
args.nhead_stride_q_descale = nhead_stride_q_descale;
|
||||
args.nhead_stride_k_descale = nhead_stride_k_descale;
|
||||
args.nhead_stride_v_descale = nhead_stride_v_descale;
|
||||
|
||||
args.batch_stride_q_descale = batch_stride_q_descale;
|
||||
args.batch_stride_k_descale = batch_stride_k_descale;
|
||||
args.batch_stride_v_descale = batch_stride_v_descale;
|
||||
|
||||
args.block_scale_m = block_scale_m_;
|
||||
args.block_scale_n = block_scale_n_;
|
||||
}
|
||||
else
|
||||
{
|
||||
args.q_descale_ptr = q_descale_buf.GetDeviceBuffer();
|
||||
args.k_descale_ptr = k_descale_buf.GetDeviceBuffer();
|
||||
args.v_descale_ptr = v_descale_buf.GetDeviceBuffer();
|
||||
}
|
||||
|
||||
args.rand_val_ptr = randval_buf.GetDeviceBuffer();
|
||||
|
||||
@@ -1788,31 +1846,33 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
if(json)
|
||||
{
|
||||
dump_fmha_fwd_json_results(*json,
|
||||
data_type,
|
||||
mode == mode_enum::batch ? "batch" : "group",
|
||||
io_layout(i_perm, o_perm),
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_qs[0],
|
||||
seqlen_ks[0],
|
||||
seqlen_kpads[0],
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
scale_s,
|
||||
p_drop,
|
||||
lse,
|
||||
qscale.type == quant_scale_enum::no_scale ? "no_scale"
|
||||
: "pertensor",
|
||||
bias.type == bias_enum::elementwise_bias
|
||||
? "elementwise_bias"
|
||||
: (bias.type == bias_enum::alibi ? "alibi" : "no_bias"),
|
||||
is_v_rowmajor ? "r" : "c",
|
||||
pass,
|
||||
ave_time,
|
||||
tflops,
|
||||
gb_per_sec);
|
||||
dump_fmha_fwd_json_results(
|
||||
*json,
|
||||
data_type,
|
||||
mode == mode_enum::batch ? "batch" : "group",
|
||||
io_layout(i_perm, o_perm),
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_qs[0],
|
||||
seqlen_ks[0],
|
||||
seqlen_kpads[0],
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
scale_s,
|
||||
p_drop,
|
||||
lse,
|
||||
qscale.type == quant_scale_enum::no_scale
|
||||
? "no_scale"
|
||||
: (qscale.type == quant_scale_enum::pertensor ? "pertensor" : "blockscale"),
|
||||
bias.type == bias_enum::elementwise_bias
|
||||
? "elementwise_bias"
|
||||
: (bias.type == bias_enum::alibi ? "alibi" : "no_bias"),
|
||||
is_v_rowmajor ? "r" : "c",
|
||||
pass,
|
||||
ave_time,
|
||||
tflops,
|
||||
gb_per_sec);
|
||||
}
|
||||
|
||||
return pass ? fwd_result::success : fwd_result::failure;
|
||||
|
||||
@@ -13,6 +13,7 @@ enum class quant_scale_enum
|
||||
{
|
||||
no_scale = 0,
|
||||
pertensor = 1,
|
||||
blockscale,
|
||||
};
|
||||
|
||||
struct quant_scale_info
|
||||
@@ -25,6 +26,8 @@ struct quant_scale_info
|
||||
os << "n";
|
||||
else if(type == quant_scale_enum::pertensor)
|
||||
os << "pt";
|
||||
else if(type == quant_scale_enum::blockscale)
|
||||
os << "bs";
|
||||
}
|
||||
|
||||
static quant_scale_info decode(std::string str)
|
||||
@@ -38,6 +41,10 @@ struct quant_scale_info
|
||||
{
|
||||
info.type = quant_scale_enum::pertensor;
|
||||
}
|
||||
else if(str == "bs" || str == "2")
|
||||
{
|
||||
info.type = quant_scale_enum::blockscale;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("invalid quant scale value: " + str);
|
||||
|
||||
@@ -12,6 +12,7 @@ enum class BlockAttentionQuantScaleEnum
|
||||
{
|
||||
NO_SCALE = 0,
|
||||
PERTENSOR = 1,
|
||||
BLOCKSCALE,
|
||||
};
|
||||
|
||||
template <BlockAttentionQuantScaleEnum>
|
||||
@@ -27,5 +28,11 @@ struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::PERTENSOR
|
||||
{
|
||||
static constexpr const char* name = "pertensor";
|
||||
};
|
||||
template <>
|
||||
struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::BLOCKSCALE>
|
||||
{
|
||||
static constexpr const char* name = "blockscale";
|
||||
};
|
||||
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -214,6 +214,23 @@ struct FmhaFwdKernel
|
||||
const void* v_descale_ptr = nullptr;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonBlockScaleKargs : public FmhaFwdCommonQScaleKargs
|
||||
{
|
||||
ck_tile::index_t nhead_stride_q_descale;
|
||||
ck_tile::index_t nhead_stride_k_descale;
|
||||
ck_tile::index_t nhead_stride_v_descale;
|
||||
|
||||
ck_tile::index_t block_scale_m;
|
||||
ck_tile::index_t block_scale_n;
|
||||
};
|
||||
|
||||
struct FmhaFwdBatchBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs
|
||||
{
|
||||
ck_tile::index_t batch_stride_q_descale;
|
||||
ck_tile::index_t batch_stride_k_descale;
|
||||
ck_tile::index_t batch_stride_v_descale;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonLSEKargs
|
||||
{
|
||||
void* lse_ptr = nullptr;
|
||||
@@ -289,9 +306,12 @@ struct FmhaFwdKernel
|
||||
FmhaFwdEmptyKargs<0>>>,
|
||||
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
FmhaFwdEmptyKargs<3>>,
|
||||
std::conditional_t<
|
||||
QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE,
|
||||
FmhaFwdBatchBlockScaleKargs,
|
||||
FmhaFwdEmptyKargs<3>>>,
|
||||
std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
|
||||
{
|
||||
@@ -315,9 +335,12 @@ struct FmhaFwdKernel
|
||||
FmhaFwdEmptyKargs<0>>>,
|
||||
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
FmhaFwdEmptyKargs<3>>,
|
||||
std::conditional_t<
|
||||
QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE,
|
||||
FmhaFwdCommonBlockScaleKargs,
|
||||
FmhaFwdEmptyKargs<3>>>,
|
||||
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
|
||||
std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
|
||||
@@ -374,6 +397,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_descale,
|
||||
ck_tile::index_t nhead_stride_k_descale,
|
||||
ck_tile::index_t nhead_stride_v_descale,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
@@ -381,6 +407,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t batch_stride_q_descale,
|
||||
ck_tile::index_t batch_stride_k_descale,
|
||||
ck_tile::index_t batch_stride_v_descale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
@@ -388,6 +417,8 @@ struct FmhaFwdKernel
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset,
|
||||
ck_tile::index_t block_scale_m,
|
||||
ck_tile::index_t block_scale_n,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
{
|
||||
@@ -455,6 +486,23 @@ struct FmhaFwdKernel
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
}
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
kargs.q_descale_ptr = q_descale_ptr;
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
|
||||
kargs.nhead_stride_q_descale = nhead_stride_q_descale;
|
||||
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
|
||||
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
|
||||
|
||||
kargs.batch_stride_q_descale = batch_stride_q_descale;
|
||||
kargs.batch_stride_k_descale = batch_stride_k_descale;
|
||||
kargs.batch_stride_v_descale = batch_stride_v_descale;
|
||||
|
||||
kargs.block_scale_m = block_scale_m;
|
||||
kargs.block_scale_n = block_scale_n;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
@@ -520,6 +568,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_descale,
|
||||
ck_tile::index_t nhead_stride_k_descale,
|
||||
ck_tile::index_t nhead_stride_v_descale,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
@@ -527,12 +578,17 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t batch_stride_q_descale,
|
||||
ck_tile::index_t batch_stride_k_descale,
|
||||
ck_tile::index_t batch_stride_v_descale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
|
||||
ck_tile::index_t block_scale_m,
|
||||
ck_tile::index_t block_scale_n,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
{
|
||||
@@ -568,6 +624,9 @@ struct FmhaFwdKernel
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
nhead_stride_q_descale,
|
||||
nhead_stride_k_descale,
|
||||
nhead_stride_v_descale,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
@@ -575,12 +634,17 @@ struct FmhaFwdKernel
|
||||
batch_stride_randval,
|
||||
batch_stride_lse,
|
||||
batch_stride_o,
|
||||
batch_stride_q_descale,
|
||||
batch_stride_k_descale,
|
||||
batch_stride_v_descale,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
block_scale_m,
|
||||
block_scale_n,
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr);
|
||||
}
|
||||
@@ -619,6 +683,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_descale,
|
||||
ck_tile::index_t nhead_stride_k_descale,
|
||||
ck_tile::index_t nhead_stride_v_descale,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
@@ -626,12 +693,17 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t batch_stride_q_descale,
|
||||
ck_tile::index_t batch_stride_k_descale,
|
||||
ck_tile::index_t batch_stride_v_descale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<const void*, const void*>& drop_seed_offset,
|
||||
ck_tile::index_t block_scale_m,
|
||||
ck_tile::index_t block_scale_n,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
{
|
||||
@@ -667,6 +739,9 @@ struct FmhaFwdKernel
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
nhead_stride_q_descale,
|
||||
nhead_stride_k_descale,
|
||||
nhead_stride_v_descale,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
@@ -674,12 +749,17 @@ struct FmhaFwdKernel
|
||||
batch_stride_randval,
|
||||
batch_stride_lse,
|
||||
batch_stride_o,
|
||||
batch_stride_q_descale,
|
||||
batch_stride_k_descale,
|
||||
batch_stride_v_descale,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
block_scale_m,
|
||||
block_scale_n,
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr);
|
||||
}
|
||||
@@ -719,6 +799,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_descale,
|
||||
ck_tile::index_t nhead_stride_k_descale,
|
||||
ck_tile::index_t nhead_stride_v_descale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
@@ -727,6 +810,8 @@ struct FmhaFwdKernel
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset,
|
||||
ck_tile::index_t block_scale_m,
|
||||
ck_tile::index_t block_scale_n,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
{
|
||||
@@ -793,6 +878,19 @@ struct FmhaFwdKernel
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
}
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
kargs.q_descale_ptr = q_descale_ptr;
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
|
||||
kargs.nhead_stride_q_descale = nhead_stride_q_descale;
|
||||
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
|
||||
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
|
||||
|
||||
kargs.block_scale_m = block_scale_m;
|
||||
kargs.block_scale_n = block_scale_n;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
@@ -863,6 +961,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_descale,
|
||||
ck_tile::index_t nhead_stride_k_descale,
|
||||
ck_tile::index_t nhead_stride_v_descale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
@@ -870,6 +971,8 @@ struct FmhaFwdKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
|
||||
ck_tile::index_t block_scale_m,
|
||||
ck_tile::index_t block_scale_n,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
{
|
||||
@@ -907,6 +1010,9 @@ struct FmhaFwdKernel
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
nhead_stride_q_descale,
|
||||
nhead_stride_k_descale,
|
||||
nhead_stride_v_descale,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
@@ -914,6 +1020,8 @@ struct FmhaFwdKernel
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
block_scale_m,
|
||||
block_scale_n,
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr);
|
||||
}
|
||||
@@ -954,6 +1062,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_descale,
|
||||
ck_tile::index_t nhead_stride_k_descale,
|
||||
ck_tile::index_t nhead_stride_v_descale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
@@ -961,6 +1072,8 @@ struct FmhaFwdKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<const void*, const void*>& drop_seed_offset,
|
||||
ck_tile::index_t block_scale_m,
|
||||
ck_tile::index_t block_scale_n,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
{
|
||||
@@ -998,6 +1111,9 @@ struct FmhaFwdKernel
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
nhead_stride_q_descale,
|
||||
nhead_stride_k_descale,
|
||||
nhead_stride_v_descale,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
@@ -1005,6 +1121,8 @@ struct FmhaFwdKernel
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
block_scale_m,
|
||||
block_scale_n,
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user