Re-order split-kv pipeline call operator arguments

This commit is contained in:
PoYen, Chen
2024-06-11 19:23:19 +00:00
parent 6ee71c2bf6
commit df4fc8f26c
3 changed files with 24 additions and 24 deletions

View File

@@ -865,13 +865,13 @@ struct FmhaFwdSplitKVKernel
identity{}, // s_acc_element_func
scales{kargs.scale_p}, // p_compute_element_func
identity{}, // o_acc_element_func
kargs.num_splits,
i_split_,
mask,
position_encoding,
kargs.scale_s,
smem_ptr,
dropout,
i_split_,
kargs.num_splits);
dropout);
}
else
{
@@ -881,13 +881,13 @@ struct FmhaFwdSplitKVKernel
bias_dram_window,
randval_dram_window,
lse_acc_dram_window,
kargs.num_splits,
i_split_,
mask,
position_encoding,
kargs.scale_s,
smem_ptr,
dropout,
i_split_,
kargs.num_splits);
dropout);
}
}();

View File

@@ -133,13 +133,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
index_t num_splits,
index_t i_split,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
BlockDropout& dropout,
index_t i_split,
index_t num_splits) const
BlockDropout& dropout) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -629,13 +629,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile
index_t num_splits,
index_t i_split,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
BlockDropout& dropout,
index_t i_split,
index_t num_splits) const
BlockDropout& dropout) const
{
return operator()(q_dram_block_window_tmp,
identity{},
@@ -651,13 +651,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
identity{},
identity{},
identity{},
num_splits,
i_split,
mask,
position_encoding,
scale_s,
smem_ptr,
dropout,
i_split,
num_splits);
dropout);
}
};

View File

@@ -145,13 +145,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
index_t num_splits,
index_t i_split,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
BlockDropout& dropout,
index_t i_split,
index_t num_splits) const
BlockDropout& dropout) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -730,13 +730,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile
index_t num_splits,
index_t i_split,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
BlockDropout& dropout,
index_t i_split,
index_t num_splits) const
BlockDropout& dropout) const
{
return operator()(q_dram_block_window_tmp,
identity{},
@@ -752,13 +752,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
identity{},
identity{},
identity{},
num_splits,
i_split,
mask,
position_encoding,
scale_s,
smem_ptr,
dropout,
i_split,
num_splits);
dropout);
}
};