add blockscale parameters to kernel

This commit is contained in:
ltqin
2025-11-19 06:19:18 +00:00
parent 907dc988dc
commit a183b4dc29
7 changed files with 262 additions and 46 deletions

View File

@@ -66,11 +66,13 @@ def get_mask_check_map(mask: str):
# Lookup from the qscale option key (no/pertensor/blockscale) to the fully
# qualified ck_tile::BlockAttentionQuantScaleEnum enumerator spelling.
QSCALE_MAP = dict(
    (
        ("no", "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE"),
        ("pertensor", "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR"),
        ("blockscale", "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE"),
    )
)
# Lookup from the qscale option key (no/pertensor/blockscale) to the matching
# quant_scale_enum enumerator spelling.
# NOTE(review): presumably consumed by the generated dispatch checks - confirm at the use site.
QSCALE_CHECK_MAP = dict(
    (
        ("no", "quant_scale_enum::no_scale"),
        ("pertensor", "quant_scale_enum::pertensor"),
        ("blockscale", "quant_scale_enum::blockscale"),
    )
)
BIAS_MAP = {

View File

@@ -343,7 +343,7 @@ class FmhaFwdPipeline:
F_bias: str # true/false
F_lse: str #
F_dropout: str #
F_qscale: str # no/pertensor
F_qscale: str # no/pertensor/blockscale
F_mask: str # value from MASK_MAP
F_skip: str # true/false
F_trload: str # true/false
@@ -739,7 +739,7 @@ class KernelComponentFactoryGfx9:
# no need lse/dropout kernels
for logits, qscale, mask, bias in itertools.product(
["f"],
["no", "pertensor"],
["no", "pertensor", "blockscale"],
get_mask_map(mask_impl).keys(),
["no"],
):
@@ -830,7 +830,7 @@ class KernelComponentFactoryGfx12:
elif dtype in ["fp8", "fp8bf16", "fp8fp32"]:
# no need lse/dropout kernels
for logits, qscale, mask, bias in itertools.product(
["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"]
["f"], ["no", "pertensor", "blockscale"], get_mask_map(mask_impl).keys(), ["no"]
):
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip

View File

@@ -256,6 +256,9 @@ struct fmha_fwd_args
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_lse;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t nhead_stride_q_descale;
ck_tile::index_t nhead_stride_k_descale;
ck_tile::index_t nhead_stride_v_descale;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
@@ -263,6 +266,9 @@ struct fmha_fwd_args
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_lse;
ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_q_descale;
ck_tile::index_t batch_stride_k_descale;
ck_tile::index_t batch_stride_v_descale;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
@@ -274,6 +280,9 @@ struct fmha_fwd_args
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset;
ck_tile::index_t block_scale_m;
ck_tile::index_t block_scale_n;
};
struct fmha_fwd_pagedkv_args
@@ -604,6 +613,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.stride_bias,
args.stride_randval,
args.stride_o,
args.nhead_stride_q_descale,
args.nhead_stride_k_descale,
args.nhead_stride_v_descale,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
@@ -618,6 +630,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.p_drop,
args.s_randval,
args.drop_seed_offset,
args.block_scale_m,
args.block_scale_n,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr);
}
@@ -654,6 +668,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_randval,
args.nhead_stride_lse,
args.nhead_stride_o,
args.nhead_stride_q_descale,
args.nhead_stride_k_descale,
args.nhead_stride_v_descale,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
@@ -661,12 +678,17 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.batch_stride_randval,
args.batch_stride_lse,
args.batch_stride_o,
args.batch_stride_q_descale,
args.batch_stride_k_descale,
args.batch_stride_v_descale,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset,
args.block_scale_m,
args.block_scale_n,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr);
}

View File

@@ -187,6 +187,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
const ck_tile::stream_config& stream_config,
std::optional<std::string> json = std::nullopt)
{
constexpr ck_tile::index_t block_scale_m_ = 128;
constexpr ck_tile::index_t block_scale_n_ = 128;
const std::string data_type = []() {
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp32>)
return "fp32";
@@ -448,7 +451,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
std::size_t flop = 0, num_byte = 0;
auto max_seqlen_q =
std::numeric_limits<int32_t>::min(); // we will use max seqlen to decide grid size
auto max_seqlen_k = std::numeric_limits<int32_t>::min();
size_t num_block_scale_q = 0;
size_t num_block_scale_k = 0;
auto max_seqlen_k = std::numeric_limits<int32_t>::min();
{
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{
@@ -464,6 +469,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
{
max_seqlen_k = real_seqlen_k;
}
num_block_scale_q += ck_tile::integer_divide_ceil(real_seqlen_q, block_scale_m_);
num_block_scale_k += ck_tile::integer_divide_ceil(real_seqlen_k, block_scale_n_);
flop += nhead * (static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_q +
static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_v);
@@ -525,6 +532,13 @@ fwd_result fmha_fwd_run(mode_enum mode,
? seqstart_k_with_padding_host.back()
: seqstart_k_host.back()));
const ck_tile::index_t num_block_scale_m =
(mode == mode_enum::batch) ? ck_tile::integer_divide_ceil(shape_seqlen_q, block_scale_m_)
: num_block_scale_q;
const ck_tile::index_t num_block_scale_n =
(mode == mode_enum::batch) ? ck_tile::integer_divide_ceil(shape_seqlen_k, block_scale_n_)
: num_block_scale_k;
ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
ck_tile::HostTensor<KDataType> k_host(
@@ -575,9 +589,18 @@ fwd_result fmha_fwd_run(mode_enum mode,
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
// TODO - change the tensor length for different quant scale
ck_tile::HostTensor<float> q_descale_host(get_lengths(i_perm, 1, 1, 1, 1));
ck_tile::HostTensor<float> k_descale_host(get_lengths(i_perm, 1, 1, 1, 1));
ck_tile::HostTensor<float> v_descale_host(get_lengths(i_perm, 1, 1, 1, 1));
ck_tile::HostTensor<float> q_descale_host(
qscale.type == quant_scale_enum::blockscale
? std::array<ck_tile::index_t, 3>{shape_batch, nhead, num_block_scale_m}
: std::array<ck_tile::index_t, 3>{1, 1, 1});
// The k/v descale batch strides handed to the kernel are num_block_scale_n * nhead_k, so the
// block-scale tensors must use nhead_k (not nhead) as their head extent. Sizing them with
// nhead makes the host layout disagree with those strides whenever nhead_k != nhead.
// For every other quant-scale mode a single-element tensor is allocated.
ck_tile::HostTensor<float> k_descale_host(
    qscale.type == quant_scale_enum::blockscale
        ? std::array<ck_tile::index_t, 3>{shape_batch, nhead_k, num_block_scale_n}
        : std::array<ck_tile::index_t, 3>{1, 1, 1});
ck_tile::HostTensor<float> v_descale_host(
    qscale.type == quant_scale_enum::blockscale
        ? std::array<ck_tile::index_t, 3>{shape_batch, nhead_k, num_block_scale_n}
        : std::array<ck_tile::index_t, 3>{1, 1, 1});
// batch mode of lse data layout is [batch, nhead, seqlen_q]
// group mode of lse data layout is [nhead, total_seqlen_q]
@@ -692,6 +715,12 @@ fwd_result fmha_fwd_run(mode_enum mode,
k_descale_host(0) = qkv_max / k_dtype_max;
v_descale_host(0) = qkv_max / v_dtype_max;
}
else if(qscale.type == quant_scale_enum::blockscale)
{
ck_tile::FillUniformDistribution<float>{0.015f, 0.02f, next_seed()}(q_descale_host);
ck_tile::FillUniformDistribution<float>{0.015f, 0.02f, next_seed()}(k_descale_host);
ck_tile::FillUniformDistribution<float>{0.015f, 0.02f, next_seed()}(v_descale_host);
}
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine);
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine);
@@ -941,11 +970,14 @@ fwd_result fmha_fwd_run(mode_enum mode,
}();
const ck_tile::index_t nhead_stride_bias =
(i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q);
const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v);
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q);
const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v);
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_q_descale = num_block_scale_m;
const ck_tile::index_t nhead_stride_k_descale = num_block_scale_n;
const ck_tile::index_t nhead_stride_v_descale = num_block_scale_n;
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k =
@@ -963,6 +995,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
const ck_tile::index_t batch_stride_q_descale = num_block_scale_m * nhead;
const ck_tile::index_t batch_stride_k_descale = num_block_scale_n * nhead_k;
const ck_tile::index_t batch_stride_v_descale = num_block_scale_n * nhead_k;
// setup split_stride_* arguments (only used in split-kv kernel)
const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q);
const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v);
@@ -1046,9 +1081,32 @@ fwd_result fmha_fwd_run(mode_enum mode,
if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
{
args.q_descale_ptr = q_descale_buf.GetDeviceBuffer();
args.k_descale_ptr = k_descale_buf.GetDeviceBuffer();
args.v_descale_ptr = v_descale_buf.GetDeviceBuffer();
if(qscale.type == quant_scale_enum::blockscale)
{
args.q_descale_ptr =
reinterpret_cast<const float*>(q_descale_buf.GetDeviceBuffer());
args.k_descale_ptr =
reinterpret_cast<const float*>(k_descale_buf.GetDeviceBuffer());
args.v_descale_ptr =
reinterpret_cast<const float*>(v_descale_buf.GetDeviceBuffer());
args.nhead_stride_q_descale = nhead_stride_q_descale;
args.nhead_stride_k_descale = nhead_stride_k_descale;
args.nhead_stride_v_descale = nhead_stride_v_descale;
args.batch_stride_q_descale = batch_stride_q_descale;
args.batch_stride_k_descale = batch_stride_k_descale;
args.batch_stride_v_descale = batch_stride_v_descale;
args.block_scale_m = block_scale_m_;
args.block_scale_n = block_scale_n_;
}
else
{
args.q_descale_ptr = q_descale_buf.GetDeviceBuffer();
args.k_descale_ptr = k_descale_buf.GetDeviceBuffer();
args.v_descale_ptr = v_descale_buf.GetDeviceBuffer();
}
args.rand_val_ptr = randval_buf.GetDeviceBuffer();
@@ -1788,31 +1846,33 @@ fwd_result fmha_fwd_run(mode_enum mode,
if(json)
{
dump_fmha_fwd_json_results(*json,
data_type,
mode == mode_enum::batch ? "batch" : "group",
io_layout(i_perm, o_perm),
batch,
nhead,
nhead_k,
seqlen_qs[0],
seqlen_ks[0],
seqlen_kpads[0],
hdim_q,
hdim_v,
scale_s,
p_drop,
lse,
qscale.type == quant_scale_enum::no_scale ? "no_scale"
: "pertensor",
bias.type == bias_enum::elementwise_bias
? "elementwise_bias"
: (bias.type == bias_enum::alibi ? "alibi" : "no_bias"),
is_v_rowmajor ? "r" : "c",
pass,
ave_time,
tflops,
gb_per_sec);
dump_fmha_fwd_json_results(
*json,
data_type,
mode == mode_enum::batch ? "batch" : "group",
io_layout(i_perm, o_perm),
batch,
nhead,
nhead_k,
seqlen_qs[0],
seqlen_ks[0],
seqlen_kpads[0],
hdim_q,
hdim_v,
scale_s,
p_drop,
lse,
qscale.type == quant_scale_enum::no_scale
? "no_scale"
: (qscale.type == quant_scale_enum::pertensor ? "pertensor" : "blockscale"),
bias.type == bias_enum::elementwise_bias
? "elementwise_bias"
: (bias.type == bias_enum::alibi ? "alibi" : "no_bias"),
is_v_rowmajor ? "r" : "c",
pass,
ave_time,
tflops,
gb_per_sec);
}
return pass ? fwd_result::success : fwd_result::failure;

View File

@@ -13,6 +13,7 @@ enum class quant_scale_enum
{
no_scale = 0,
pertensor = 1,
blockscale,
};
struct quant_scale_info
@@ -25,6 +26,8 @@ struct quant_scale_info
os << "n";
else if(type == quant_scale_enum::pertensor)
os << "pt";
else if(type == quant_scale_enum::blockscale)
os << "bs";
}
static quant_scale_info decode(std::string str)
@@ -38,6 +41,10 @@ struct quant_scale_info
{
info.type = quant_scale_enum::pertensor;
}
else if(str == "bs" || str == "2")
{
info.type = quant_scale_enum::blockscale;
}
else
{
throw std::invalid_argument("invalid quant scale value: " + str);

View File

@@ -12,6 +12,7 @@ enum class BlockAttentionQuantScaleEnum
{
NO_SCALE = 0,
PERTENSOR = 1,
BLOCKSCALE,
};
template <BlockAttentionQuantScaleEnum>
@@ -27,5 +28,11 @@ struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::PERTENSOR
{
static constexpr const char* name = "pertensor";
};
// Specialization that gives the BLOCKSCALE quant-scale mode its printable name.
template <>
struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::BLOCKSCALE>
{
static constexpr const char* name = "blockscale";
};
} // namespace ck_tile

View File

@@ -214,6 +214,23 @@ struct FmhaFwdKernel
const void* v_descale_ptr = nullptr;
};
// Kernel arguments used when QScaleEnum is BLOCKSCALE: the q/k/v descale pointers inherited
// from FmhaFwdCommonQScaleKargs plus the per-head strides into each descale tensor and the
// two block extents supplied by the caller.
// Every member carries a default so a default-constructed instance is fully initialized,
// matching the defaulted pointers in the base struct.
struct FmhaFwdCommonBlockScaleKargs : public FmhaFwdCommonQScaleKargs
{
    // Element stride between consecutive heads of the q/k/v descale tensors.
    ck_tile::index_t nhead_stride_q_descale = 0;
    ck_tile::index_t nhead_stride_k_descale = 0;
    ck_tile::index_t nhead_stride_v_descale = 0;

    // Block extents forwarded from the caller.
    // NOTE(review): presumably the number of seqlen_q (m) / seqlen_k (n) positions covered by
    // one scale factor - confirm against the pipeline that consumes them.
    ck_tile::index_t block_scale_m = 0;
    ck_tile::index_t block_scale_n = 0;
};
// Batch-mode variant of the block-scale kernel arguments: extends
// FmhaFwdCommonBlockScaleKargs with the per-batch strides of the q/k/v descale tensors.
// Members are zero-defaulted so a default-constructed instance is fully initialized.
struct FmhaFwdBatchBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs
{
    // Element stride between consecutive batches of the q/k/v descale tensors.
    ck_tile::index_t batch_stride_q_descale = 0;
    ck_tile::index_t batch_stride_k_descale = 0;
    ck_tile::index_t batch_stride_v_descale = 0;
};
struct FmhaFwdCommonLSEKargs
{
void* lse_ptr = nullptr;
@@ -289,9 +306,12 @@ struct FmhaFwdKernel
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
FmhaFwdCommonQScaleKargs,
FmhaFwdEmptyKargs<3>>,
std::conditional_t<
QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
FmhaFwdCommonQScaleKargs,
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE,
FmhaFwdBatchBlockScaleKargs,
FmhaFwdEmptyKargs<3>>>,
std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
{
@@ -315,9 +335,12 @@ struct FmhaFwdKernel
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
FmhaFwdCommonQScaleKargs,
FmhaFwdEmptyKargs<3>>,
std::conditional_t<
QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
FmhaFwdCommonQScaleKargs,
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE,
FmhaFwdCommonBlockScaleKargs,
FmhaFwdEmptyKargs<3>>>,
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
@@ -374,6 +397,9 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_q_descale,
ck_tile::index_t nhead_stride_k_descale,
ck_tile::index_t nhead_stride_v_descale,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
@@ -381,6 +407,9 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t batch_stride_q_descale,
ck_tile::index_t batch_stride_k_descale,
ck_tile::index_t batch_stride_v_descale,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
@@ -388,6 +417,8 @@ struct FmhaFwdKernel
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset,
ck_tile::index_t block_scale_m,
ck_tile::index_t block_scale_n,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr)
{
@@ -455,6 +486,23 @@ struct FmhaFwdKernel
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
}
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
kargs.q_descale_ptr = q_descale_ptr;
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
kargs.nhead_stride_q_descale = nhead_stride_q_descale;
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
kargs.batch_stride_q_descale = batch_stride_q_descale;
kargs.batch_stride_k_descale = batch_stride_k_descale;
kargs.batch_stride_v_descale = batch_stride_v_descale;
kargs.block_scale_m = block_scale_m;
kargs.block_scale_n = block_scale_n;
}
if constexpr(kHasDropout)
{
if(drop_seed_offset.index() == 0) // seed & offset come from host
@@ -520,6 +568,9 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_q_descale,
ck_tile::index_t nhead_stride_k_descale,
ck_tile::index_t nhead_stride_v_descale,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
@@ -527,12 +578,17 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t batch_stride_q_descale,
ck_tile::index_t batch_stride_k_descale,
ck_tile::index_t batch_stride_v_descale,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
ck_tile::index_t block_scale_m,
ck_tile::index_t block_scale_n,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr)
{
@@ -568,6 +624,9 @@ struct FmhaFwdKernel
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
nhead_stride_q_descale,
nhead_stride_k_descale,
nhead_stride_v_descale,
batch_stride_q,
batch_stride_k,
batch_stride_v,
@@ -575,12 +634,17 @@ struct FmhaFwdKernel
batch_stride_randval,
batch_stride_lse,
batch_stride_o,
batch_stride_q_descale,
batch_stride_k_descale,
batch_stride_v_descale,
window_size_left,
window_size_right,
mask_type,
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
block_scale_m,
block_scale_n,
cu_seqlen_q_ptr,
cu_seqlen_k_ptr);
}
@@ -619,6 +683,9 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_q_descale,
ck_tile::index_t nhead_stride_k_descale,
ck_tile::index_t nhead_stride_v_descale,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
@@ -626,12 +693,17 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t batch_stride_q_descale,
ck_tile::index_t batch_stride_k_descale,
ck_tile::index_t batch_stride_v_descale,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<const void*, const void*>& drop_seed_offset,
ck_tile::index_t block_scale_m,
ck_tile::index_t block_scale_n,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr)
{
@@ -667,6 +739,9 @@ struct FmhaFwdKernel
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
nhead_stride_q_descale,
nhead_stride_k_descale,
nhead_stride_v_descale,
batch_stride_q,
batch_stride_k,
batch_stride_v,
@@ -674,12 +749,17 @@ struct FmhaFwdKernel
batch_stride_randval,
batch_stride_lse,
batch_stride_o,
batch_stride_q_descale,
batch_stride_k_descale,
batch_stride_v_descale,
window_size_left,
window_size_right,
mask_type,
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
block_scale_m,
block_scale_n,
cu_seqlen_q_ptr,
cu_seqlen_k_ptr);
}
@@ -719,6 +799,9 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_q_descale,
ck_tile::index_t nhead_stride_k_descale,
ck_tile::index_t nhead_stride_v_descale,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
@@ -727,6 +810,8 @@ struct FmhaFwdKernel
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset,
ck_tile::index_t block_scale_m,
ck_tile::index_t block_scale_n,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr)
{
@@ -793,6 +878,19 @@ struct FmhaFwdKernel
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
}
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
kargs.q_descale_ptr = q_descale_ptr;
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
kargs.nhead_stride_q_descale = nhead_stride_q_descale;
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
kargs.block_scale_m = block_scale_m;
kargs.block_scale_n = block_scale_n;
}
if constexpr(kHasDropout)
{
if(drop_seed_offset.index() == 0) // seed & offset come from host
@@ -863,6 +961,9 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_q_descale,
ck_tile::index_t nhead_stride_k_descale,
ck_tile::index_t nhead_stride_v_descale,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
@@ -870,6 +971,8 @@ struct FmhaFwdKernel
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
ck_tile::index_t block_scale_m,
ck_tile::index_t block_scale_n,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr)
{
@@ -907,6 +1010,9 @@ struct FmhaFwdKernel
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
nhead_stride_q_descale,
nhead_stride_k_descale,
nhead_stride_v_descale,
window_size_left,
window_size_right,
mask_type,
@@ -914,6 +1020,8 @@ struct FmhaFwdKernel
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
block_scale_m,
block_scale_n,
cu_seqlen_q_ptr,
cu_seqlen_k_ptr);
}
@@ -954,6 +1062,9 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_q_descale,
ck_tile::index_t nhead_stride_k_descale,
ck_tile::index_t nhead_stride_v_descale,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
@@ -961,6 +1072,8 @@ struct FmhaFwdKernel
float p_drop,
bool s_randval,
const std::tuple<const void*, const void*>& drop_seed_offset,
ck_tile::index_t block_scale_m,
ck_tile::index_t block_scale_n,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr)
{
@@ -998,6 +1111,9 @@ struct FmhaFwdKernel
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
nhead_stride_q_descale,
nhead_stride_k_descale,
nhead_stride_v_descale,
window_size_left,
window_size_right,
mask_type,
@@ -1005,6 +1121,8 @@ struct FmhaFwdKernel
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
block_scale_m,
block_scale_n,
cu_seqlen_q_ptr,
cu_seqlen_k_ptr);
}