add group parameters for block quant

This commit is contained in:
ltqin
2025-11-06 12:32:36 +00:00
parent 9b341c5d6f
commit 0aea348865
2 changed files with 106 additions and 1 deletions

View File

@@ -569,6 +569,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
args.q_scale_ptr,
args.k_scale_ptr,
args.v_scale_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
@@ -590,6 +593,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_randval,
args.nhead_stride_lse,
args.nhead_stride_o,
args.nhead_stride_q_scale,
args.nhead_stride_k_scale,
args.nhead_stride_v_scale,
args.window_size_left,
args.window_size_right,
args.mask_type,
@@ -597,6 +603,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.p_drop,
args.s_randval,
args.drop_seed_offset,
args.block_scale_m,
args.block_scale_n,
args.seqstart_padded_q_ptr,
args.seqstart_padded_k_ptr);
}

View File

@@ -334,7 +334,8 @@ struct FmhaFwdKernel
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>,
std::conditional_t<kDoFp8StaticQuant, FmhaFwdCommonBlockScaleKargs, FmhaFwdEmptyKargs<7>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
@@ -778,6 +779,9 @@ struct FmhaFwdKernel
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
const float* q_scale_ptr,
const float* k_scale_ptr,
const float* v_scale_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -799,6 +803,9 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_q_scale,
ck_tile::index_t nhead_stride_k_scale,
ck_tile::index_t nhead_stride_v_scale,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
@@ -807,6 +814,8 @@ struct FmhaFwdKernel
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset,
ck_tile::index_t block_scale_m,
ck_tile::index_t block_scale_n,
const void* seqstart_padded_q_ptr = nullptr,
const void* seqstart_padded_k_ptr = nullptr)
{
@@ -840,6 +849,7 @@ struct FmhaFwdKernel
{}, // placeholder for dropout
{}, // placeholder for logits_soft_cap
{}, // placeholder for min_seqlen_q
{}, // placeholder for block quant scale
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
@@ -899,6 +909,19 @@ struct FmhaFwdKernel
{
kargs.min_seqlen_q = min_seqlen_q;
}
if constexpr(kDoFp8StaticQuant)
{
kargs.q_scale_ptr = q_scale_ptr;
kargs.k_scale_ptr = k_scale_ptr;
kargs.v_scale_ptr = v_scale_ptr;
kargs.nhead_stride_q_scale = nhead_stride_q_scale;
kargs.nhead_stride_k_scale = nhead_stride_k_scale;
kargs.nhead_stride_v_scale = nhead_stride_v_scale;
kargs.block_scale_m = block_scale_m;
kargs.block_scale_n = block_scale_n;
}
kargs.seqstart_padded_q_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_q_ptr);
kargs.seqstart_padded_k_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_k_ptr);
@@ -918,6 +941,9 @@ struct FmhaFwdKernel
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
const float* q_scale_ptr,
const float* k_scale_ptr,
const float* v_scale_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -939,6 +965,9 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_q_scale,
ck_tile::index_t nhead_stride_k_scale,
ck_tile::index_t nhead_stride_v_scale,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
@@ -946,6 +975,8 @@ struct FmhaFwdKernel
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
ck_tile::index_t block_scale_m,
ck_tile::index_t block_scale_n,
const void* seqstart_padded_q_ptr = nullptr,
const void* seqstart_padded_k_ptr = nullptr)
{
@@ -960,6 +991,9 @@ struct FmhaFwdKernel
seqstart_q_ptr,
seqstart_k_ptr,
seqlen_k_ptr,
q_scale_ptr,
k_scale_ptr,
v_scale_ptr,
hdim_q,
hdim_v,
num_head_q,
@@ -981,6 +1015,9 @@ struct FmhaFwdKernel
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
nhead_stride_q_scale,
nhead_stride_k_scale,
nhead_stride_v_scale,
window_size_left,
window_size_right,
mask_type,
@@ -988,6 +1025,8 @@ struct FmhaFwdKernel
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
block_scale_m,
block_scale_n,
seqstart_padded_q_ptr,
seqstart_padded_k_ptr);
}
@@ -1005,6 +1044,9 @@ struct FmhaFwdKernel
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
const float* q_scale_ptr,
const float* k_scale_ptr,
const float* v_scale_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -1026,6 +1068,9 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_q_scale,
ck_tile::index_t nhead_stride_k_scale,
ck_tile::index_t nhead_stride_v_scale,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
@@ -1033,6 +1078,8 @@ struct FmhaFwdKernel
float p_drop,
bool s_randval,
const std::tuple<const void*, const void*>& drop_seed_offset,
ck_tile::index_t block_scale_m,
ck_tile::index_t block_scale_n,
const void* seqstart_padded_q_ptr = nullptr,
const void* seqstart_padded_k_ptr = nullptr)
{
@@ -1047,6 +1094,9 @@ struct FmhaFwdKernel
seqstart_q_ptr,
seqstart_k_ptr,
seqlen_k_ptr,
q_scale_ptr,
k_scale_ptr,
v_scale_ptr,
hdim_q,
hdim_v,
num_head_q,
@@ -1068,6 +1118,9 @@ struct FmhaFwdKernel
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
nhead_stride_q_scale,
nhead_stride_k_scale,
nhead_stride_v_scale,
window_size_left,
window_size_right,
mask_type,
@@ -1075,6 +1128,8 @@ struct FmhaFwdKernel
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
block_scale_m,
block_scale_n,
seqstart_padded_q_ptr,
seqstart_padded_k_ptr);
}
@@ -1202,6 +1257,9 @@ struct FmhaFwdKernel
long_index_t batch_offset_randval = 0;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
long_index_t batch_offset_q_scale = 0;
long_index_t batch_offset_k_scale = 0;
long_index_t batch_offset_v_scale = 0;
if constexpr(kIsGroupMode)
{
@@ -1240,6 +1298,15 @@ struct FmhaFwdKernel
{
batch_offset_randval = query_start_padded * kargs.stride_randval;
}
if constexpr(kDoFp8StaticQuant)
{
if(kargs.q_scale_ptr)
{
batch_offset_q_scale = query_start_padded;
batch_offset_k_scale = key_start_padded;
batch_offset_v_scale = key_start_padded;
}
}
batch_offset_o = query_start_padded * kargs.stride_o;
// real logical lengths (exclude PAD)
@@ -1289,6 +1356,18 @@ struct FmhaFwdKernel
batch_offset_randval =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
}
if constexpr(kDoFp8StaticQuant)
{
if(kargs.q_scale_ptr)
{
batch_offset_q_scale =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_q_scale;
batch_offset_k_scale =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_k_scale;
batch_offset_v_scale =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_v_scale;
}
}
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
// If cumulative seqlen pointers are provided, override per-batch effective lengths
@@ -1320,6 +1399,24 @@ struct FmhaFwdKernel
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
batch_offset_o;
const float* q_scale_ptr = nullptr;
const float* k_scale_ptr = nullptr;
const float* v_scale_ptr = nullptr;
if constexpr(kDoFp8StaticQuant)
{
q_scale_ptr = reinterpret_cast<const float*>(kargs.q_scale_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q_scale +
batch_offset_q_scale;
k_scale_ptr = reinterpret_cast<const float*>(kargs.k_scale_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_k_scale +
batch_offset_k_scale;
v_scale_ptr = reinterpret_cast<const float*>(kargs.v_scale_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_v_scale +
batch_offset_v_scale;
}
ck_tile::ignore = q_scale_ptr;
ck_tile::ignore = k_scale_ptr;
ck_tile::ignore = v_scale_ptr;
// Q/K/V DRAM and DRAM window
const auto q_dram = [&]() {
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(