add batch block scale parameters to kernel

This commit is contained in:
ltqin
2025-11-06 08:01:41 +00:00
parent 4626bace60
commit 9b341c5d6f
3 changed files with 187 additions and 37 deletions

View File

@@ -196,6 +196,10 @@ struct fmha_fwd_args
const void* seqstart_padded_q_ptr = nullptr; // [batch+1]
const void* seqstart_padded_k_ptr = nullptr; // [batch+1]
const float* q_scale_ptr = nullptr;
const float* k_scale_ptr = nullptr;
const float* v_scale_ptr = nullptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
@@ -224,6 +228,9 @@ struct fmha_fwd_args
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_lse;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t nhead_stride_q_scale;
ck_tile::index_t nhead_stride_k_scale;
ck_tile::index_t nhead_stride_v_scale;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
@@ -231,12 +238,18 @@ struct fmha_fwd_args
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_lse;
ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_q_scale;
ck_tile::index_t batch_stride_k_scale;
ck_tile::index_t batch_stride_v_scale;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
ck_tile::index_t min_seqlen_q;
ck_tile::index_t block_scale_m;
ck_tile::index_t block_scale_n;
float p_drop;
bool s_randval;
@@ -596,6 +609,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.rand_val_ptr,
args.lse_ptr,
args.o_ptr,
args.q_scale_ptr,
args.k_scale_ptr,
args.v_scale_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
@@ -619,6 +635,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_randval,
args.nhead_stride_lse,
args.nhead_stride_o,
args.nhead_stride_q_scale,
args.nhead_stride_k_scale,
args.nhead_stride_v_scale,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
@@ -626,12 +645,17 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.batch_stride_randval,
args.batch_stride_lse,
args.batch_stride_o,
args.batch_stride_q_scale,
args.batch_stride_k_scale,
args.batch_stride_v_scale,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset,
args.block_scale_m,
args.block_scale_n,
args.cu_seqlen_q_ptr,
args.cu_seqlen_kv_ptr);
}

View File

@@ -159,27 +159,29 @@ class BlockQuantizer
template <typename SrcTensor, typename DstTensor, typename ScaleTensor>
void quantize(const SrcTensor& in, DstTensor& out, ScaleTensor& block_scale, size_t block_size_)
{
using InDataType = typename std::remove_reference_t<decltype(in)>::DataType;
using OutDataType = typename std::remove_reference_t<decltype(out)>::DataType;
float dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<InDataType>::max());
size_t batch = in.get_length(0);
size_t head = in.get_length(i_perm ? 1 : 2);
size_t seq_len = in.get_length(i_perm ? 2 : 1);
size_t hdim = in.get_length(3);
using InDataType = typename std::remove_reference_t<decltype(in)>::DataType;
using OutDataType = typename std::remove_reference_t<decltype(out)>::DataType;
float dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<InDataType>::max());
size_t batch = in.get_length(0);
size_t head = in.get_length(i_perm ? 1 : 2);
size_t seq_len = in.get_length(i_perm ? 2 : 1);
size_t hdim = in.get_length(3);
size_t num_blocks_ = (seq_len + block_size_ - 1) / block_size_;
std::cout << "batch: " << batch << " head: " << head << " seq_len: " << seq_len
<< " hdim: " << hdim << " dtype_max: " << dtype_max
<< " num_blocks_: " << num_blocks_ << std::endl;
// std::cout << "batch: " << batch << " head: " << head << " seq_len: " << seq_len
// << " hdim: " << hdim << " dtype_max: " << dtype_max
// << " num_blocks_: " << num_blocks_ << std::endl;
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<float> dis(0.5, 2.0f);
for(size_t b = 0; b < batch; ++b){
for(size_t b = 0; b < batch; ++b)
{
for(size_t h = 0; h < head; ++h)
{
for(size_t block = 0; block < num_blocks_; ++block)
{
// get block max value
float max_value = ck_tile::type_convert<float>(ck_tile::numeric<InDataType>::min());
float max_value =
ck_tile::type_convert<float>(ck_tile::numeric<InDataType>::min());
for(size_t s = block * block_size_;
s < (block + 1) * block_size_ && s < seq_len;
++s)
@@ -196,11 +198,12 @@ class BlockQuantizer
}
// calculate block scale
max_value += dis(gen);
float scale = dtype_max / max_value;
block_scale(b,h,block) = scale;
std::cout << "block: " << block << " scale: " << scale << " max_value: " << max_value << " block_scale: " << block_scale << std::endl;
//quant
float scale = dtype_max / max_value;
block_scale(b, h, block) = scale;
// std::cout << "block: " << block << " scale: " << scale << " max_value: " <<
// max_value << " block_scale: " << block_scale << std::endl;
// quant
for(size_t s = block * block_size_;
s < (block + 1) * block_size_ && s < seq_len;
++s)
@@ -211,26 +214,28 @@ class BlockQuantizer
if(!i_perm)
idx = {b, s, h, d};
float val = ck_tile::type_convert<float>(in(idx));
out(idx) = ck_tile::type_convert<OutDataType>(val * scale);
out(idx) = ck_tile::type_convert<OutDataType>(val * scale);
}
}
}
}
}
}
}
template <typename SrcTensor, typename DstTensor, typename ScaleTensor>
void dequantize(const SrcTensor& in, DstTensor& out, ScaleTensor& block_scale, size_t block_size_)
void
dequantize(const SrcTensor& in, DstTensor& out, ScaleTensor& block_scale, size_t block_size_)
{
using OutDataType = typename std::remove_reference_t<decltype(out)>::DataType;
size_t batch = in.get_length(0);
size_t head = in.get_length(i_perm ? 1 : 2);
size_t seq_len = in.get_length(i_perm ? 2 : 1);
size_t hdim = in.get_length(3);
using OutDataType = typename std::remove_reference_t<decltype(out)>::DataType;
size_t batch = in.get_length(0);
size_t head = in.get_length(i_perm ? 1 : 2);
size_t seq_len = in.get_length(i_perm ? 2 : 1);
size_t hdim = in.get_length(3);
size_t num_blocks_ = (seq_len + block_size_ - 1) / block_size_;
//dequant
for(size_t b = 0; b < batch; ++b){
// dequant
for(size_t b = 0; b < batch; ++b)
{
for(size_t h = 0; h < head; ++h)
{
for(size_t block = 0; block < num_blocks_; ++block)
@@ -245,14 +250,13 @@ class BlockQuantizer
std::vector<size_t> idx = {b, h, s, d};
if(!i_perm)
idx = {b, s, h, d};
float val = ck_tile::type_convert<float>(in(idx));
out(idx) = ck_tile::type_convert<OutDataType>(val / scale);
float val = ck_tile::type_convert<float>(in(idx));
out(idx) = ck_tile::type_convert<OutDataType>(val / scale);
}
}
}
}
}
}
};
@@ -651,8 +655,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
0 < page_block_size
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
: get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
ck_tile::HostTensor<float> k_scale(std::array<ck_tile::index_t, 3>{
shape_batch, nhead_k, num_block_scale_n});
ck_tile::HostTensor<float> k_scale(
std::array<ck_tile::index_t, 3>{shape_batch, nhead_k, num_block_scale_n});
/// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode
ck_tile::HostTensor<KDataType> knew_host(
0 < seqlen_knew
@@ -665,8 +669,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
: get_lengths(i_perm, max_num_page_blocks, nhead_k, hdim_v, page_block_size))
: (is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
: get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)));
ck_tile::HostTensor<float> v_scale(std::array<ck_tile::index_t, 3>{
shape_batch, nhead_k, num_block_scale_n});
ck_tile::HostTensor<float> v_scale(
std::array<ck_tile::index_t, 3>{shape_batch, nhead_k, num_block_scale_n});
ck_tile::HostTensor<VDataType> vnew_host(
0 < seqlen_knew
? (is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v)
@@ -906,7 +910,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
ck_tile::DeviceMem k_scale_buf(k_scale.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_scale_buf(v_scale.get_element_space_size_in_bytes());
q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data());
v_buf.ToDevice(v_host.data());
@@ -940,7 +943,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
if(quant == 2)
{
//dequant data for host
// dequant data for host
BlockQuantizer quantizer(i_perm);
// q_host.savetxt("./q_quant.txt");
quantizer.dequantize(q_host, q_host, q_scale, block_scale_m_);
@@ -1125,6 +1128,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q_lse);
const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v);
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_q_scale = num_block_scale_m;
const ck_tile::index_t nhead_stride_k_scale = num_block_scale_n;
const ck_tile::index_t nhead_stride_v_scale = num_block_scale_n;
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k =
@@ -1142,6 +1148,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
const ck_tile::index_t batch_stride_q_scale = num_block_scale_m * nhead;
const ck_tile::index_t batch_stride_k_scale = num_block_scale_n * nhead_k;
const ck_tile::index_t batch_stride_v_scale = num_block_scale_n * nhead_k;
// setup split_stride_* arguments (only used in split-kv kernel)
const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q);
const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v);
@@ -1235,6 +1244,27 @@ fwd_result fmha_fwd_run(mode_enum mode,
if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
{
if(quant == 2)
{
args.q_scale_ptr =
reinterpret_cast<const float*>(q_scale_buf.GetDeviceBuffer());
args.k_scale_ptr =
reinterpret_cast<const float*>(k_scale_buf.GetDeviceBuffer());
args.v_scale_ptr =
reinterpret_cast<const float*>(v_scale_buf.GetDeviceBuffer());
args.nhead_stride_q_scale = nhead_stride_q_scale;
args.nhead_stride_k_scale = nhead_stride_k_scale;
args.nhead_stride_v_scale = nhead_stride_v_scale;
args.batch_stride_q_scale = batch_stride_q_scale;
args.batch_stride_k_scale = batch_stride_k_scale;
args.batch_stride_v_scale = batch_stride_v_scale;
args.block_scale_m = block_scale_m_;
args.block_scale_n = block_scale_n_;
}
args.rand_val_ptr = randval_buf.GetDeviceBuffer();
args.stride_randval = stride_randval;

View File

@@ -156,6 +156,27 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_o;
};
struct FmhaFwdCommonBlockScaleKargs
{
const float* q_scale_ptr = nullptr;
const float* k_scale_ptr = nullptr;
const float* v_scale_ptr = nullptr;
ck_tile::index_t nhead_stride_q_scale;
ck_tile::index_t nhead_stride_k_scale;
ck_tile::index_t nhead_stride_v_scale;
ck_tile::index_t block_scale_m;
ck_tile::index_t block_scale_n;
};
struct FmhaFwdBatchBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs
{
ck_tile::index_t batch_stride_q_scale;
ck_tile::index_t batch_stride_k_scale;
ck_tile::index_t batch_stride_v_scale;
};
struct FmhaFwdLogitsSoftCapKargs
{
FmhaFwdLogitsSoftCapKargs() = default;
@@ -287,7 +308,8 @@ struct FmhaFwdKernel
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
std::conditional_t<kDoFp8StaticQuant, FmhaFwdBatchBlockScaleKargs, FmhaFwdEmptyKargs<6>>
{
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
@@ -342,6 +364,9 @@ struct FmhaFwdKernel
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
const float* q_scale_ptr,
const float* k_scale_ptr,
const float* v_scale_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
@@ -365,6 +390,9 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_q_scale,
ck_tile::index_t nhead_stride_k_scale,
ck_tile::index_t nhead_stride_v_scale,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
@@ -372,6 +400,9 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t batch_stride_q_scale,
ck_tile::index_t batch_stride_k_scale,
ck_tile::index_t batch_stride_v_scale,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
@@ -379,6 +410,8 @@ struct FmhaFwdKernel
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset,
ck_tile::index_t block_scale_m,
ck_tile::index_t block_scale_n,
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
{
@@ -411,6 +444,7 @@ struct FmhaFwdKernel
{}, // placeholder for fp8_static_quant args
{}, // placeholder for dropout
{}, // placeholder for logits_soft_cap
{}, // palceholder for quant scale
batch_stride_q,
batch_stride_k,
batch_stride_v,
@@ -471,6 +505,24 @@ struct FmhaFwdKernel
kargs.init_logits_soft_cap(logits_soft_cap);
}
if constexpr(kDoFp8StaticQuant)
{
kargs.q_scale_ptr = q_scale_ptr;
kargs.k_scale_ptr = k_scale_ptr;
kargs.v_scale_ptr = v_scale_ptr;
kargs.nhead_stride_q_scale = nhead_stride_q_scale;
kargs.nhead_stride_k_scale = nhead_stride_k_scale;
kargs.nhead_stride_v_scale = nhead_stride_v_scale;
kargs.batch_stride_q_scale = batch_stride_q_scale;
kargs.batch_stride_k_scale = batch_stride_k_scale;
kargs.batch_stride_v_scale = batch_stride_v_scale;
kargs.block_scale_m = block_scale_m;
kargs.block_scale_n = block_scale_n;
}
kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr;
kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr;
return kargs;
@@ -486,6 +538,9 @@ struct FmhaFwdKernel
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
const float* q_scale_ptr,
const float* k_scale_ptr,
const float* v_scale_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
@@ -509,6 +564,9 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_q_scale,
ck_tile::index_t nhead_stride_k_scale,
ck_tile::index_t nhead_stride_v_scale,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
@@ -516,12 +574,17 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t batch_stride_q_scale,
ck_tile::index_t batch_stride_k_scale,
ck_tile::index_t batch_stride_v_scale,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
ck_tile::index_t block_scale_m,
ck_tile::index_t block_scale_n,
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
{
@@ -533,6 +596,9 @@ struct FmhaFwdKernel
rand_val_ptr,
lse_ptr,
o_ptr,
q_scale_ptr,
k_scale_ptr,
v_scale_ptr,
seqlen_q,
seqlen_k,
hdim_q,
@@ -556,6 +622,9 @@ struct FmhaFwdKernel
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
nhead_stride_q_scale,
nhead_stride_k_scale,
nhead_stride_v_scale,
batch_stride_q,
batch_stride_k,
batch_stride_v,
@@ -563,12 +632,17 @@ struct FmhaFwdKernel
batch_stride_randval,
batch_stride_lse,
batch_stride_o,
batch_stride_q_scale,
batch_stride_k_scale,
batch_stride_v_scale,
window_size_left,
window_size_right,
mask_type,
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
block_scale_m,
block_scale_n,
cu_seqlen_q_ptr,
cu_seqlen_kv_ptr);
}
@@ -583,6 +657,9 @@ struct FmhaFwdKernel
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
const float* q_scale_ptr,
const float* k_scale_ptr,
const float* v_scale_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
@@ -606,6 +683,9 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_q_scale,
ck_tile::index_t nhead_stride_k_scale,
ck_tile::index_t nhead_stride_v_scale,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
@@ -613,12 +693,17 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t batch_stride_q_scale,
ck_tile::index_t batch_stride_k_scale,
ck_tile::index_t batch_stride_v_scale,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<const void*, const void*>& drop_seed_offset,
ck_tile::index_t block_scale_m,
ck_tile::index_t block_scale_n,
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
{
@@ -630,6 +715,9 @@ struct FmhaFwdKernel
rand_val_ptr,
lse_ptr,
o_ptr,
q_scale_ptr,
k_scale_ptr,
v_scale_ptr,
seqlen_q,
seqlen_k,
hdim_q,
@@ -653,6 +741,9 @@ struct FmhaFwdKernel
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
nhead_stride_q_scale,
nhead_stride_k_scale,
nhead_stride_v_scale,
batch_stride_q,
batch_stride_k,
batch_stride_v,
@@ -660,12 +751,17 @@ struct FmhaFwdKernel
batch_stride_randval,
batch_stride_lse,
batch_stride_o,
batch_stride_q_scale,
batch_stride_k_scale,
batch_stride_v_scale,
window_size_left,
window_size_right,
mask_type,
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
block_scale_m,
block_scale_n,
cu_seqlen_q_ptr,
cu_seqlen_kv_ptr);
}