mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Re-order pipeline paremeters
This commit is contained in:
@@ -738,17 +738,17 @@ struct FmhaFwdAppendKVKernel
|
||||
FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window_tmp,
|
||||
i_page_block_k,
|
||||
k_page_block_navigator,
|
||||
knew_dram_window,
|
||||
v_dram_window_tmp,
|
||||
i_page_block_v,
|
||||
v_page_block_navigator,
|
||||
vnew_dram_window,
|
||||
q_rotary_cos_dram_window,
|
||||
q_rotary_sin_dram_window,
|
||||
knew_rotary_cos_dram_window,
|
||||
knew_rotary_sin_dram_window,
|
||||
kargs.rotary_dim,
|
||||
k_page_block_navigator,
|
||||
v_page_block_navigator,
|
||||
kargs.seqlen_q <= i_m0,
|
||||
kargs.seqlen_knew <= i_n0);
|
||||
}
|
||||
@@ -757,17 +757,17 @@ struct FmhaFwdAppendKVKernel
|
||||
FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window_tmp,
|
||||
i_page_block_k,
|
||||
k_page_block_navigator,
|
||||
knew_dram_window,
|
||||
v_dram_window_tmp,
|
||||
i_page_block_v,
|
||||
v_page_block_navigator,
|
||||
vnew_dram_window,
|
||||
q_rotary_cos_dram_window,
|
||||
q_rotary_sin_dram_window,
|
||||
knew_rotary_cos_dram_window,
|
||||
knew_rotary_sin_dram_window,
|
||||
0, // rotary_dim not used
|
||||
k_page_block_navigator,
|
||||
v_page_block_navigator,
|
||||
kargs.seqlen_q <= i_m0,
|
||||
kargs.seqlen_knew <= i_n0);
|
||||
}
|
||||
|
||||
@@ -829,8 +829,10 @@ struct FmhaFwdSplitKVKernel
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window,
|
||||
k_page_block_navigator,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window,
|
||||
v_page_block_navigator,
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
@@ -844,15 +846,15 @@ struct FmhaFwdSplitKVKernel
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
smem_ptr,
|
||||
k_page_block_navigator,
|
||||
v_page_block_navigator);
|
||||
smem_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
k_page_block_navigator,
|
||||
v_dram_window,
|
||||
v_page_block_navigator,
|
||||
bias_dram_window,
|
||||
lse_acc_dram_window,
|
||||
kargs.num_splits,
|
||||
@@ -860,9 +862,7 @@ struct FmhaFwdSplitKVKernel
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
smem_ptr,
|
||||
k_page_block_navigator,
|
||||
v_page_block_navigator);
|
||||
smem_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -74,8 +74,10 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
|
||||
template <typename QDramBlockWindow,
|
||||
typename KDramBlockWindow,
|
||||
typename KPageBlockNavigator,
|
||||
typename KnewDramBlockWindow,
|
||||
typename VDramBlockWindow,
|
||||
typename VPageBlockNavigator,
|
||||
typename VnewDramBlockWindow,
|
||||
typename QElementFunction,
|
||||
typename KnewElementFunction,
|
||||
@@ -83,18 +85,18 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
typename QRotaryCosDramBlockWindow,
|
||||
typename QRotarySinDramBlockWindow,
|
||||
typename KnewRotaryCosDramBlockWindow,
|
||||
typename KnewRotarySinDramBlockWindow,
|
||||
typename KPageBlockNavigator,
|
||||
typename VPageBlockNavigator>
|
||||
typename KnewRotarySinDramBlockWindow>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(QDramBlockWindow& q_dram_block_window, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
KDramBlockWindow& k_dram_block_window, // N0*K0 tile
|
||||
index_t i_page_block_k,
|
||||
const KPageBlockNavigator& k_page_block_navigator,
|
||||
const KnewDramBlockWindow& knew_dram_block_window, // N0*K0 tile
|
||||
const KnewElementFunction& knew_element_func,
|
||||
VDramBlockWindow& v_dram_block_window, // N1*N0 tile
|
||||
index_t i_page_block_v,
|
||||
const VPageBlockNavigator& v_page_block_navigator,
|
||||
const VnewDramBlockWindow& vnew_dram_block_window, // N1*N0 tile
|
||||
const VnewElementFunction& vnew_element_func,
|
||||
const QRotaryCosDramBlockWindow q_rotary_cos_dram_block_window,
|
||||
@@ -102,12 +104,10 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
const KnewRotaryCosDramBlockWindow knew_rotary_cos_dram_block_window,
|
||||
const KnewRotarySinDramBlockWindow knew_rotary_sin_dram_block_window,
|
||||
index_t rotary_dim,
|
||||
const KPageBlockNavigator& k_page_block_navigator,
|
||||
const VPageBlockNavigator& v_page_block_navigator,
|
||||
bool skip_transform_q,
|
||||
bool skip_append_kv) const
|
||||
bool skip_rotate_q,
|
||||
bool skip_rotate_append_kv) const
|
||||
{
|
||||
if(!skip_append_kv)
|
||||
if(!skip_rotate_append_kv)
|
||||
{
|
||||
// append Knew to K
|
||||
auto knew_window = make_tile_window(
|
||||
@@ -190,7 +190,7 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
}
|
||||
}
|
||||
|
||||
if(!skip_transform_q)
|
||||
if(!skip_rotate_q)
|
||||
{
|
||||
// optionally apply rotary embedding to Q
|
||||
if constexpr(RotaryEnum != RotaryEmbeddingEnum::NONE)
|
||||
@@ -231,41 +231,43 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
|
||||
template <typename QDramBlockWindow,
|
||||
typename KDramBlockWindow,
|
||||
typename KPageBlockNavigator,
|
||||
typename KnewDramBlockWindow,
|
||||
typename VDramBlockWindow,
|
||||
typename VPageBlockNavigator,
|
||||
typename VnewDramBlockWindow,
|
||||
typename QRotaryCosDramBlockWindow,
|
||||
typename QRotarySinDramBlockWindow,
|
||||
typename KnewRotaryCosDramBlockWindow,
|
||||
typename KnewRotarySinDramBlockWindow,
|
||||
typename KPageBlockNavigator,
|
||||
typename VPageBlockNavigator>
|
||||
typename KnewRotarySinDramBlockWindow>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(QDramBlockWindow& q_dram_block_window,
|
||||
KDramBlockWindow& k_dram_block_window,
|
||||
index_t i_page_block_k,
|
||||
const KPageBlockNavigator& k_page_block_navigator,
|
||||
const KnewDramBlockWindow& knew_dram_block_window,
|
||||
VDramBlockWindow& v_dram_block_window,
|
||||
index_t i_page_block_v,
|
||||
const VPageBlockNavigator& v_page_block_navigator,
|
||||
const VnewDramBlockWindow& vnew_dram_block_window,
|
||||
const QRotaryCosDramBlockWindow& q_rotary_cos_dram_block_window,
|
||||
const QRotarySinDramBlockWindow& q_rotary_sin_dram_block_window,
|
||||
const KnewRotaryCosDramBlockWindow& knew_rotary_cos_dram_block_window,
|
||||
const KnewRotarySinDramBlockWindow& knew_rotary_sin_dram_block_window,
|
||||
index_t rotary_dim,
|
||||
const KPageBlockNavigator& k_page_block_navigator,
|
||||
const VPageBlockNavigator& v_page_block_navigator,
|
||||
bool skip_transform_q,
|
||||
bool skip_append_kv) const
|
||||
bool skip_rotate_q,
|
||||
bool skip_rotate_append_kv) const
|
||||
{
|
||||
return operator()(q_dram_block_window,
|
||||
identity{},
|
||||
k_dram_block_window,
|
||||
i_page_block_k,
|
||||
k_page_block_navigator,
|
||||
knew_dram_block_window,
|
||||
identity{},
|
||||
v_dram_block_window,
|
||||
i_page_block_v,
|
||||
v_page_block_navigator,
|
||||
vnew_dram_block_window,
|
||||
identity{},
|
||||
q_rotary_cos_dram_block_window,
|
||||
@@ -273,10 +275,8 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
knew_rotary_cos_dram_block_window,
|
||||
knew_rotary_sin_dram_block_window,
|
||||
rotary_dim,
|
||||
k_page_block_navigator,
|
||||
v_page_block_navigator,
|
||||
skip_transform_q,
|
||||
skip_append_kv);
|
||||
skip_rotate_q,
|
||||
skip_rotate_append_kv);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -105,7 +105,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename KPageBlockNavigator,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename VPageBlockNavigator,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEaccDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
@@ -116,15 +118,15 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction,
|
||||
typename PositionEncoding,
|
||||
typename KPageBlockNavigator,
|
||||
typename VPageBlockNavigator>
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const KPageBlockNavigator& k_page_block_navigator,
|
||||
const KElementFunction& k_element_func,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const VPageBlockNavigator& v_page_block_navigator,
|
||||
const VElementFunction& v_element_func,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
@@ -138,9 +140,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
void* smem_ptr,
|
||||
const KPageBlockNavigator& k_page_block_navigator,
|
||||
const VPageBlockNavigator& v_page_block_navigator) const
|
||||
void* smem_ptr) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -634,16 +634,18 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename KPageBlockNavigator,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename VPageBlockNavigator,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEaccDramBlockWindowTmp,
|
||||
typename PositionEncoding,
|
||||
typename KPageBlockNavigator,
|
||||
typename VPageBlockNavigator>
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const KPageBlockNavigator& k_page_block_navigator,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const VPageBlockNavigator& v_page_block_navigator,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile
|
||||
index_t num_splits,
|
||||
@@ -651,15 +653,15 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
void* smem_ptr,
|
||||
const KPageBlockNavigator& k_page_block_navigator,
|
||||
const VPageBlockNavigator& v_page_block_navigator) const
|
||||
void* smem_ptr) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
k_dram_block_window_tmp,
|
||||
k_page_block_navigator,
|
||||
identity{},
|
||||
v_dram_block_window_tmp,
|
||||
v_page_block_navigator,
|
||||
identity{},
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -673,9 +675,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
mask,
|
||||
position_encoding,
|
||||
scale_s,
|
||||
smem_ptr,
|
||||
k_page_block_navigator,
|
||||
v_page_block_navigator);
|
||||
smem_ptr);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user