mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Re-order split-kv pipeline call operator arguments
This commit is contained in:
@@ -865,13 +865,13 @@ struct FmhaFwdSplitKVKernel
|
||||
identity{}, // s_acc_element_func
|
||||
scales{kargs.scale_p}, // p_compute_element_func
|
||||
identity{}, // o_acc_element_func
|
||||
kargs.num_splits,
|
||||
i_split_,
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
i_split_,
|
||||
kargs.num_splits);
|
||||
dropout);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -881,13 +881,13 @@ struct FmhaFwdSplitKVKernel
|
||||
bias_dram_window,
|
||||
randval_dram_window,
|
||||
lse_acc_dram_window,
|
||||
kargs.num_splits,
|
||||
i_split_,
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
i_split_,
|
||||
kargs.num_splits);
|
||||
dropout);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -133,13 +133,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
index_t num_splits,
|
||||
index_t i_split,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
void* smem_ptr,
|
||||
BlockDropout& dropout,
|
||||
index_t i_split,
|
||||
index_t num_splits) const
|
||||
BlockDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -629,13 +629,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile
|
||||
index_t num_splits,
|
||||
index_t i_split,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
void* smem_ptr,
|
||||
BlockDropout& dropout,
|
||||
index_t i_split,
|
||||
index_t num_splits) const
|
||||
BlockDropout& dropout) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -651,13 +651,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
num_splits,
|
||||
i_split,
|
||||
mask,
|
||||
position_encoding,
|
||||
scale_s,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
i_split,
|
||||
num_splits);
|
||||
dropout);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -145,13 +145,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
index_t num_splits,
|
||||
index_t i_split,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
void* smem_ptr,
|
||||
BlockDropout& dropout,
|
||||
index_t i_split,
|
||||
index_t num_splits) const
|
||||
BlockDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -730,13 +730,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile
|
||||
index_t num_splits,
|
||||
index_t i_split,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
void* smem_ptr,
|
||||
BlockDropout& dropout,
|
||||
index_t i_split,
|
||||
index_t num_splits) const
|
||||
BlockDropout& dropout) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -752,13 +752,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
num_splits,
|
||||
i_split,
|
||||
mask,
|
||||
position_encoding,
|
||||
scale_s,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
i_split,
|
||||
num_splits);
|
||||
dropout);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user