mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 21:58:13 +00:00
add group parameters for block quant
This commit is contained in:
@@ -569,6 +569,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.q_scale_ptr,
|
||||
args.k_scale_ptr,
|
||||
args.v_scale_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
@@ -590,6 +593,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.nhead_stride_q_scale,
|
||||
args.nhead_stride_k_scale,
|
||||
args.nhead_stride_v_scale,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
@@ -597,6 +603,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.block_scale_m,
|
||||
args.block_scale_n,
|
||||
args.seqstart_padded_q_ptr,
|
||||
args.seqstart_padded_k_ptr);
|
||||
}
|
||||
|
||||
@@ -334,7 +334,8 @@ struct FmhaFwdKernel
|
||||
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
|
||||
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
|
||||
std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
|
||||
std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, FmhaFwdCommonBlockScaleKargs, FmhaFwdEmptyKargs<7>>
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
@@ -778,6 +779,9 @@ struct FmhaFwdKernel
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
const float* q_scale_ptr,
|
||||
const float* k_scale_ptr,
|
||||
const float* v_scale_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -799,6 +803,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_scale,
|
||||
ck_tile::index_t nhead_stride_k_scale,
|
||||
ck_tile::index_t nhead_stride_v_scale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
@@ -807,6 +814,8 @@ struct FmhaFwdKernel
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset,
|
||||
ck_tile::index_t block_scale_m,
|
||||
ck_tile::index_t block_scale_n,
|
||||
const void* seqstart_padded_q_ptr = nullptr,
|
||||
const void* seqstart_padded_k_ptr = nullptr)
|
||||
{
|
||||
@@ -840,6 +849,7 @@ struct FmhaFwdKernel
|
||||
{}, // placeholder for dropout
|
||||
{}, // placeholder for logits_soft_cap
|
||||
{}, // placeholder for min_seqlen_q
|
||||
{}, // placeholder for block quant scale
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
|
||||
@@ -899,6 +909,19 @@ struct FmhaFwdKernel
|
||||
{
|
||||
kargs.min_seqlen_q = min_seqlen_q;
|
||||
}
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
{
|
||||
kargs.q_scale_ptr = q_scale_ptr;
|
||||
kargs.k_scale_ptr = k_scale_ptr;
|
||||
kargs.v_scale_ptr = v_scale_ptr;
|
||||
|
||||
kargs.nhead_stride_q_scale = nhead_stride_q_scale;
|
||||
kargs.nhead_stride_k_scale = nhead_stride_k_scale;
|
||||
kargs.nhead_stride_v_scale = nhead_stride_v_scale;
|
||||
|
||||
kargs.block_scale_m = block_scale_m;
|
||||
kargs.block_scale_n = block_scale_n;
|
||||
}
|
||||
|
||||
kargs.seqstart_padded_q_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_q_ptr);
|
||||
kargs.seqstart_padded_k_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_k_ptr);
|
||||
@@ -918,6 +941,9 @@ struct FmhaFwdKernel
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
const float* q_scale_ptr,
|
||||
const float* k_scale_ptr,
|
||||
const float* v_scale_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -939,6 +965,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_scale,
|
||||
ck_tile::index_t nhead_stride_k_scale,
|
||||
ck_tile::index_t nhead_stride_v_scale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
@@ -946,6 +975,8 @@ struct FmhaFwdKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
|
||||
ck_tile::index_t block_scale_m,
|
||||
ck_tile::index_t block_scale_n,
|
||||
const void* seqstart_padded_q_ptr = nullptr,
|
||||
const void* seqstart_padded_k_ptr = nullptr)
|
||||
{
|
||||
@@ -960,6 +991,9 @@ struct FmhaFwdKernel
|
||||
seqstart_q_ptr,
|
||||
seqstart_k_ptr,
|
||||
seqlen_k_ptr,
|
||||
q_scale_ptr,
|
||||
k_scale_ptr,
|
||||
v_scale_ptr,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
@@ -981,6 +1015,9 @@ struct FmhaFwdKernel
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
nhead_stride_q_scale,
|
||||
nhead_stride_k_scale,
|
||||
nhead_stride_v_scale,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
@@ -988,6 +1025,8 @@ struct FmhaFwdKernel
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
block_scale_m,
|
||||
block_scale_n,
|
||||
seqstart_padded_q_ptr,
|
||||
seqstart_padded_k_ptr);
|
||||
}
|
||||
@@ -1005,6 +1044,9 @@ struct FmhaFwdKernel
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
const float* q_scale_ptr,
|
||||
const float* k_scale_ptr,
|
||||
const float* v_scale_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -1026,6 +1068,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_scale,
|
||||
ck_tile::index_t nhead_stride_k_scale,
|
||||
ck_tile::index_t nhead_stride_v_scale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
@@ -1033,6 +1078,8 @@ struct FmhaFwdKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<const void*, const void*>& drop_seed_offset,
|
||||
ck_tile::index_t block_scale_m,
|
||||
ck_tile::index_t block_scale_n,
|
||||
const void* seqstart_padded_q_ptr = nullptr,
|
||||
const void* seqstart_padded_k_ptr = nullptr)
|
||||
{
|
||||
@@ -1047,6 +1094,9 @@ struct FmhaFwdKernel
|
||||
seqstart_q_ptr,
|
||||
seqstart_k_ptr,
|
||||
seqlen_k_ptr,
|
||||
q_scale_ptr,
|
||||
k_scale_ptr,
|
||||
v_scale_ptr,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
@@ -1068,6 +1118,9 @@ struct FmhaFwdKernel
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
nhead_stride_q_scale,
|
||||
nhead_stride_k_scale,
|
||||
nhead_stride_v_scale,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
@@ -1075,6 +1128,8 @@ struct FmhaFwdKernel
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
block_scale_m,
|
||||
block_scale_n,
|
||||
seqstart_padded_q_ptr,
|
||||
seqstart_padded_k_ptr);
|
||||
}
|
||||
@@ -1202,6 +1257,9 @@ struct FmhaFwdKernel
|
||||
long_index_t batch_offset_randval = 0;
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
long_index_t batch_offset_q_scale = 0;
|
||||
long_index_t batch_offset_k_scale = 0;
|
||||
long_index_t batch_offset_v_scale = 0;
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
@@ -1240,6 +1298,15 @@ struct FmhaFwdKernel
|
||||
{
|
||||
batch_offset_randval = query_start_padded * kargs.stride_randval;
|
||||
}
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
{
|
||||
if(kargs.q_scale_ptr)
|
||||
{
|
||||
batch_offset_q_scale = query_start_padded;
|
||||
batch_offset_k_scale = key_start_padded;
|
||||
batch_offset_v_scale = key_start_padded;
|
||||
}
|
||||
}
|
||||
batch_offset_o = query_start_padded * kargs.stride_o;
|
||||
|
||||
// real logical lengths (exclude PAD)
|
||||
@@ -1289,6 +1356,18 @@ struct FmhaFwdKernel
|
||||
batch_offset_randval =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
|
||||
}
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
{
|
||||
if(kargs.q_scale_ptr)
|
||||
{
|
||||
batch_offset_q_scale =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_q_scale;
|
||||
batch_offset_k_scale =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_k_scale;
|
||||
batch_offset_v_scale =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_v_scale;
|
||||
}
|
||||
}
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
|
||||
// If cumulative seqlen pointers are provided, override per-batch effective lengths
|
||||
@@ -1320,6 +1399,24 @@ struct FmhaFwdKernel
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
|
||||
batch_offset_o;
|
||||
|
||||
const float* q_scale_ptr = nullptr;
|
||||
const float* k_scale_ptr = nullptr;
|
||||
const float* v_scale_ptr = nullptr;
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
{
|
||||
q_scale_ptr = reinterpret_cast<const float*>(kargs.q_scale_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q_scale +
|
||||
batch_offset_q_scale;
|
||||
k_scale_ptr = reinterpret_cast<const float*>(kargs.k_scale_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_k_scale +
|
||||
batch_offset_k_scale;
|
||||
v_scale_ptr = reinterpret_cast<const float*>(kargs.v_scale_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_v_scale +
|
||||
batch_offset_v_scale;
|
||||
}
|
||||
ck_tile::ignore = q_scale_ptr;
|
||||
ck_tile::ignore = k_scale_ptr;
|
||||
ck_tile::ignore = v_scale_ptr;
|
||||
// Q/K/V DRAM and DRAM window
|
||||
const auto q_dram = [&]() {
|
||||
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
|
||||
Reference in New Issue
Block a user