mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
add batch block scale parameters to kernel
This commit is contained in:
@@ -196,6 +196,10 @@ struct fmha_fwd_args
|
||||
const void* seqstart_padded_q_ptr = nullptr; // [batch+1]
|
||||
const void* seqstart_padded_k_ptr = nullptr; // [batch+1]
|
||||
|
||||
const float* q_scale_ptr = nullptr;
|
||||
const float* k_scale_ptr = nullptr;
|
||||
const float* v_scale_ptr = nullptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t batch;
|
||||
@@ -224,6 +228,9 @@ struct fmha_fwd_args
|
||||
ck_tile::index_t nhead_stride_randval;
|
||||
ck_tile::index_t nhead_stride_lse;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t nhead_stride_q_scale;
|
||||
ck_tile::index_t nhead_stride_k_scale;
|
||||
ck_tile::index_t nhead_stride_v_scale;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
@@ -231,12 +238,18 @@ struct fmha_fwd_args
|
||||
ck_tile::index_t batch_stride_randval;
|
||||
ck_tile::index_t batch_stride_lse;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
ck_tile::index_t batch_stride_q_scale;
|
||||
ck_tile::index_t batch_stride_k_scale;
|
||||
ck_tile::index_t batch_stride_v_scale;
|
||||
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
ck_tile::index_t min_seqlen_q;
|
||||
|
||||
ck_tile::index_t block_scale_m;
|
||||
ck_tile::index_t block_scale_n;
|
||||
|
||||
float p_drop;
|
||||
bool s_randval;
|
||||
|
||||
@@ -596,6 +609,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.q_scale_ptr,
|
||||
args.k_scale_ptr,
|
||||
args.v_scale_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
@@ -619,6 +635,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.nhead_stride_q_scale,
|
||||
args.nhead_stride_k_scale,
|
||||
args.nhead_stride_v_scale,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
@@ -626,12 +645,17 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_lse,
|
||||
args.batch_stride_o,
|
||||
args.batch_stride_q_scale,
|
||||
args.batch_stride_k_scale,
|
||||
args.batch_stride_v_scale,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.block_scale_m,
|
||||
args.block_scale_n,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_kv_ptr);
|
||||
}
|
||||
|
||||
@@ -159,27 +159,29 @@ class BlockQuantizer
|
||||
template <typename SrcTensor, typename DstTensor, typename ScaleTensor>
|
||||
void quantize(const SrcTensor& in, DstTensor& out, ScaleTensor& block_scale, size_t block_size_)
|
||||
{
|
||||
using InDataType = typename std::remove_reference_t<decltype(in)>::DataType;
|
||||
using OutDataType = typename std::remove_reference_t<decltype(out)>::DataType;
|
||||
float dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<InDataType>::max());
|
||||
size_t batch = in.get_length(0);
|
||||
size_t head = in.get_length(i_perm ? 1 : 2);
|
||||
size_t seq_len = in.get_length(i_perm ? 2 : 1);
|
||||
size_t hdim = in.get_length(3);
|
||||
using InDataType = typename std::remove_reference_t<decltype(in)>::DataType;
|
||||
using OutDataType = typename std::remove_reference_t<decltype(out)>::DataType;
|
||||
float dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<InDataType>::max());
|
||||
size_t batch = in.get_length(0);
|
||||
size_t head = in.get_length(i_perm ? 1 : 2);
|
||||
size_t seq_len = in.get_length(i_perm ? 2 : 1);
|
||||
size_t hdim = in.get_length(3);
|
||||
size_t num_blocks_ = (seq_len + block_size_ - 1) / block_size_;
|
||||
std::cout << "batch: " << batch << " head: " << head << " seq_len: " << seq_len
|
||||
<< " hdim: " << hdim << " dtype_max: " << dtype_max
|
||||
<< " num_blocks_: " << num_blocks_ << std::endl;
|
||||
// std::cout << "batch: " << batch << " head: " << head << " seq_len: " << seq_len
|
||||
// << " hdim: " << hdim << " dtype_max: " << dtype_max
|
||||
// << " num_blocks_: " << num_blocks_ << std::endl;
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_real_distribution<float> dis(0.5, 2.0f);
|
||||
for(size_t b = 0; b < batch; ++b){
|
||||
for(size_t b = 0; b < batch; ++b)
|
||||
{
|
||||
for(size_t h = 0; h < head; ++h)
|
||||
{
|
||||
for(size_t block = 0; block < num_blocks_; ++block)
|
||||
{
|
||||
// get block max value
|
||||
float max_value = ck_tile::type_convert<float>(ck_tile::numeric<InDataType>::min());
|
||||
float max_value =
|
||||
ck_tile::type_convert<float>(ck_tile::numeric<InDataType>::min());
|
||||
for(size_t s = block * block_size_;
|
||||
s < (block + 1) * block_size_ && s < seq_len;
|
||||
++s)
|
||||
@@ -196,11 +198,12 @@ class BlockQuantizer
|
||||
}
|
||||
// calculate block scale
|
||||
max_value += dis(gen);
|
||||
float scale = dtype_max / max_value;
|
||||
block_scale(b,h,block) = scale;
|
||||
std::cout << "block: " << block << " scale: " << scale << " max_value: " << max_value << " block_scale: " << block_scale << std::endl;
|
||||
|
||||
//quant
|
||||
float scale = dtype_max / max_value;
|
||||
block_scale(b, h, block) = scale;
|
||||
// std::cout << "block: " << block << " scale: " << scale << " max_value: " <<
|
||||
// max_value << " block_scale: " << block_scale << std::endl;
|
||||
|
||||
// quant
|
||||
for(size_t s = block * block_size_;
|
||||
s < (block + 1) * block_size_ && s < seq_len;
|
||||
++s)
|
||||
@@ -211,26 +214,28 @@ class BlockQuantizer
|
||||
if(!i_perm)
|
||||
idx = {b, s, h, d};
|
||||
float val = ck_tile::type_convert<float>(in(idx));
|
||||
out(idx) = ck_tile::type_convert<OutDataType>(val * scale);
|
||||
out(idx) = ck_tile::type_convert<OutDataType>(val * scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcTensor, typename DstTensor, typename ScaleTensor>
|
||||
void dequantize(const SrcTensor& in, DstTensor& out, ScaleTensor& block_scale, size_t block_size_)
|
||||
void
|
||||
dequantize(const SrcTensor& in, DstTensor& out, ScaleTensor& block_scale, size_t block_size_)
|
||||
{
|
||||
using OutDataType = typename std::remove_reference_t<decltype(out)>::DataType;
|
||||
size_t batch = in.get_length(0);
|
||||
size_t head = in.get_length(i_perm ? 1 : 2);
|
||||
size_t seq_len = in.get_length(i_perm ? 2 : 1);
|
||||
size_t hdim = in.get_length(3);
|
||||
using OutDataType = typename std::remove_reference_t<decltype(out)>::DataType;
|
||||
size_t batch = in.get_length(0);
|
||||
size_t head = in.get_length(i_perm ? 1 : 2);
|
||||
size_t seq_len = in.get_length(i_perm ? 2 : 1);
|
||||
size_t hdim = in.get_length(3);
|
||||
size_t num_blocks_ = (seq_len + block_size_ - 1) / block_size_;
|
||||
|
||||
//dequant
|
||||
for(size_t b = 0; b < batch; ++b){
|
||||
// dequant
|
||||
for(size_t b = 0; b < batch; ++b)
|
||||
{
|
||||
for(size_t h = 0; h < head; ++h)
|
||||
{
|
||||
for(size_t block = 0; block < num_blocks_; ++block)
|
||||
@@ -245,14 +250,13 @@ class BlockQuantizer
|
||||
std::vector<size_t> idx = {b, h, s, d};
|
||||
if(!i_perm)
|
||||
idx = {b, s, h, d};
|
||||
float val = ck_tile::type_convert<float>(in(idx));
|
||||
out(idx) = ck_tile::type_convert<OutDataType>(val / scale);
|
||||
float val = ck_tile::type_convert<float>(in(idx));
|
||||
out(idx) = ck_tile::type_convert<OutDataType>(val / scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
@@ -651,8 +655,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
0 < page_block_size
|
||||
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
|
||||
: get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
|
||||
ck_tile::HostTensor<float> k_scale(std::array<ck_tile::index_t, 3>{
|
||||
shape_batch, nhead_k, num_block_scale_n});
|
||||
ck_tile::HostTensor<float> k_scale(
|
||||
std::array<ck_tile::index_t, 3>{shape_batch, nhead_k, num_block_scale_n});
|
||||
/// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode
|
||||
ck_tile::HostTensor<KDataType> knew_host(
|
||||
0 < seqlen_knew
|
||||
@@ -665,8 +669,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
: get_lengths(i_perm, max_num_page_blocks, nhead_k, hdim_v, page_block_size))
|
||||
: (is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
|
||||
: get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)));
|
||||
ck_tile::HostTensor<float> v_scale(std::array<ck_tile::index_t, 3>{
|
||||
shape_batch, nhead_k, num_block_scale_n});
|
||||
ck_tile::HostTensor<float> v_scale(
|
||||
std::array<ck_tile::index_t, 3>{shape_batch, nhead_k, num_block_scale_n});
|
||||
ck_tile::HostTensor<VDataType> vnew_host(
|
||||
0 < seqlen_knew
|
||||
? (is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v)
|
||||
@@ -906,7 +910,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::DeviceMem k_scale_buf(k_scale.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_scale_buf(v_scale.get_element_space_size_in_bytes());
|
||||
|
||||
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
v_buf.ToDevice(v_host.data());
|
||||
@@ -940,7 +943,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
if(quant == 2)
|
||||
{
|
||||
//dequant data for host
|
||||
// dequant data for host
|
||||
BlockQuantizer quantizer(i_perm);
|
||||
// q_host.savetxt("./q_quant.txt");
|
||||
quantizer.dequantize(q_host, q_host, q_scale, block_scale_m_);
|
||||
@@ -1125,6 +1128,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q_lse);
|
||||
const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
const ck_tile::index_t nhead_stride_q_scale = num_block_scale_m;
|
||||
const ck_tile::index_t nhead_stride_k_scale = num_block_scale_n;
|
||||
const ck_tile::index_t nhead_stride_v_scale = num_block_scale_n;
|
||||
// setup batch_stride_* arguments
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck_tile::index_t batch_stride_k =
|
||||
@@ -1142,6 +1148,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
|
||||
const ck_tile::index_t batch_stride_q_scale = num_block_scale_m * nhead;
|
||||
const ck_tile::index_t batch_stride_k_scale = num_block_scale_n * nhead_k;
|
||||
const ck_tile::index_t batch_stride_v_scale = num_block_scale_n * nhead_k;
|
||||
// setup split_stride_* arguments (only used in split-kv kernel)
|
||||
const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q);
|
||||
const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v);
|
||||
@@ -1235,6 +1244,27 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
|
||||
{
|
||||
if(quant == 2)
|
||||
{
|
||||
args.q_scale_ptr =
|
||||
reinterpret_cast<const float*>(q_scale_buf.GetDeviceBuffer());
|
||||
args.k_scale_ptr =
|
||||
reinterpret_cast<const float*>(k_scale_buf.GetDeviceBuffer());
|
||||
args.v_scale_ptr =
|
||||
reinterpret_cast<const float*>(v_scale_buf.GetDeviceBuffer());
|
||||
|
||||
args.nhead_stride_q_scale = nhead_stride_q_scale;
|
||||
args.nhead_stride_k_scale = nhead_stride_k_scale;
|
||||
args.nhead_stride_v_scale = nhead_stride_v_scale;
|
||||
|
||||
args.batch_stride_q_scale = batch_stride_q_scale;
|
||||
args.batch_stride_k_scale = batch_stride_k_scale;
|
||||
args.batch_stride_v_scale = batch_stride_v_scale;
|
||||
|
||||
args.block_scale_m = block_scale_m_;
|
||||
args.block_scale_n = block_scale_n_;
|
||||
}
|
||||
|
||||
args.rand_val_ptr = randval_buf.GetDeviceBuffer();
|
||||
|
||||
args.stride_randval = stride_randval;
|
||||
|
||||
@@ -156,6 +156,27 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonBlockScaleKargs
|
||||
{
|
||||
const float* q_scale_ptr = nullptr;
|
||||
const float* k_scale_ptr = nullptr;
|
||||
const float* v_scale_ptr = nullptr;
|
||||
|
||||
ck_tile::index_t nhead_stride_q_scale;
|
||||
ck_tile::index_t nhead_stride_k_scale;
|
||||
ck_tile::index_t nhead_stride_v_scale;
|
||||
|
||||
ck_tile::index_t block_scale_m;
|
||||
ck_tile::index_t block_scale_n;
|
||||
};
|
||||
|
||||
struct FmhaFwdBatchBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs
|
||||
{
|
||||
ck_tile::index_t batch_stride_q_scale;
|
||||
ck_tile::index_t batch_stride_k_scale;
|
||||
ck_tile::index_t batch_stride_v_scale;
|
||||
};
|
||||
|
||||
struct FmhaFwdLogitsSoftCapKargs
|
||||
{
|
||||
FmhaFwdLogitsSoftCapKargs() = default;
|
||||
@@ -287,7 +308,8 @@ struct FmhaFwdKernel
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
|
||||
std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, FmhaFwdBatchBlockScaleKargs, FmhaFwdEmptyKargs<6>>
|
||||
{
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
@@ -342,6 +364,9 @@ struct FmhaFwdKernel
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
const float* q_scale_ptr,
|
||||
const float* k_scale_ptr,
|
||||
const float* v_scale_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
ck_tile::index_t hdim_q,
|
||||
@@ -365,6 +390,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_scale,
|
||||
ck_tile::index_t nhead_stride_k_scale,
|
||||
ck_tile::index_t nhead_stride_v_scale,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
@@ -372,6 +400,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t batch_stride_q_scale,
|
||||
ck_tile::index_t batch_stride_k_scale,
|
||||
ck_tile::index_t batch_stride_v_scale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
@@ -379,6 +410,8 @@ struct FmhaFwdKernel
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset,
|
||||
ck_tile::index_t block_scale_m,
|
||||
ck_tile::index_t block_scale_n,
|
||||
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
|
||||
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
|
||||
{
|
||||
@@ -411,6 +444,7 @@ struct FmhaFwdKernel
|
||||
{}, // placeholder for fp8_static_quant args
|
||||
{}, // placeholder for dropout
|
||||
{}, // placeholder for logits_soft_cap
|
||||
{}, // placeholder for quant scale
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
@@ -471,6 +505,24 @@ struct FmhaFwdKernel
|
||||
kargs.init_logits_soft_cap(logits_soft_cap);
|
||||
}
|
||||
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
{
|
||||
kargs.q_scale_ptr = q_scale_ptr;
|
||||
kargs.k_scale_ptr = k_scale_ptr;
|
||||
kargs.v_scale_ptr = v_scale_ptr;
|
||||
|
||||
kargs.nhead_stride_q_scale = nhead_stride_q_scale;
|
||||
kargs.nhead_stride_k_scale = nhead_stride_k_scale;
|
||||
kargs.nhead_stride_v_scale = nhead_stride_v_scale;
|
||||
|
||||
kargs.batch_stride_q_scale = batch_stride_q_scale;
|
||||
kargs.batch_stride_k_scale = batch_stride_k_scale;
|
||||
kargs.batch_stride_v_scale = batch_stride_v_scale;
|
||||
|
||||
kargs.block_scale_m = block_scale_m;
|
||||
kargs.block_scale_n = block_scale_n;
|
||||
}
|
||||
|
||||
kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr;
|
||||
kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr;
|
||||
return kargs;
|
||||
@@ -486,6 +538,9 @@ struct FmhaFwdKernel
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
const float* q_scale_ptr,
|
||||
const float* k_scale_ptr,
|
||||
const float* v_scale_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
ck_tile::index_t hdim_q,
|
||||
@@ -509,6 +564,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_scale,
|
||||
ck_tile::index_t nhead_stride_k_scale,
|
||||
ck_tile::index_t nhead_stride_v_scale,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
@@ -516,12 +574,17 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t batch_stride_q_scale,
|
||||
ck_tile::index_t batch_stride_k_scale,
|
||||
ck_tile::index_t batch_stride_v_scale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
|
||||
ck_tile::index_t block_scale_m,
|
||||
ck_tile::index_t block_scale_n,
|
||||
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
|
||||
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
|
||||
{
|
||||
@@ -533,6 +596,9 @@ struct FmhaFwdKernel
|
||||
rand_val_ptr,
|
||||
lse_ptr,
|
||||
o_ptr,
|
||||
q_scale_ptr,
|
||||
k_scale_ptr,
|
||||
v_scale_ptr,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
@@ -556,6 +622,9 @@ struct FmhaFwdKernel
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
nhead_stride_q_scale,
|
||||
nhead_stride_k_scale,
|
||||
nhead_stride_v_scale,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
@@ -563,12 +632,17 @@ struct FmhaFwdKernel
|
||||
batch_stride_randval,
|
||||
batch_stride_lse,
|
||||
batch_stride_o,
|
||||
batch_stride_q_scale,
|
||||
batch_stride_k_scale,
|
||||
batch_stride_v_scale,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
block_scale_m,
|
||||
block_scale_n,
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_kv_ptr);
|
||||
}
|
||||
@@ -583,6 +657,9 @@ struct FmhaFwdKernel
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
const float* q_scale_ptr,
|
||||
const float* k_scale_ptr,
|
||||
const float* v_scale_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
ck_tile::index_t hdim_q,
|
||||
@@ -606,6 +683,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_scale,
|
||||
ck_tile::index_t nhead_stride_k_scale,
|
||||
ck_tile::index_t nhead_stride_v_scale,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
@@ -613,12 +693,17 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t batch_stride_q_scale,
|
||||
ck_tile::index_t batch_stride_k_scale,
|
||||
ck_tile::index_t batch_stride_v_scale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<const void*, const void*>& drop_seed_offset,
|
||||
ck_tile::index_t block_scale_m,
|
||||
ck_tile::index_t block_scale_n,
|
||||
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
|
||||
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
|
||||
{
|
||||
@@ -630,6 +715,9 @@ struct FmhaFwdKernel
|
||||
rand_val_ptr,
|
||||
lse_ptr,
|
||||
o_ptr,
|
||||
q_scale_ptr,
|
||||
k_scale_ptr,
|
||||
v_scale_ptr,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
@@ -653,6 +741,9 @@ struct FmhaFwdKernel
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
nhead_stride_q_scale,
|
||||
nhead_stride_k_scale,
|
||||
nhead_stride_v_scale,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
@@ -660,12 +751,17 @@ struct FmhaFwdKernel
|
||||
batch_stride_randval,
|
||||
batch_stride_lse,
|
||||
batch_stride_o,
|
||||
batch_stride_q_scale,
|
||||
batch_stride_k_scale,
|
||||
batch_stride_v_scale,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
block_scale_m,
|
||||
block_scale_n,
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_kv_ptr);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user