Re-order pipeline paremeters

This commit is contained in:
PoYen, Chen
2024-08-13 07:38:41 +00:00
parent 19c19d8bd3
commit 3dd6ef61ef
4 changed files with 48 additions and 48 deletions

View File

@@ -738,17 +738,17 @@ struct FmhaFwdAppendKVKernel
FmhaPipeline{}(q_dram_window,
k_dram_window_tmp,
i_page_block_k,
k_page_block_navigator,
knew_dram_window,
v_dram_window_tmp,
i_page_block_v,
v_page_block_navigator,
vnew_dram_window,
q_rotary_cos_dram_window,
q_rotary_sin_dram_window,
knew_rotary_cos_dram_window,
knew_rotary_sin_dram_window,
kargs.rotary_dim,
k_page_block_navigator,
v_page_block_navigator,
kargs.seqlen_q <= i_m0,
kargs.seqlen_knew <= i_n0);
}
@@ -757,17 +757,17 @@ struct FmhaFwdAppendKVKernel
FmhaPipeline{}(q_dram_window,
k_dram_window_tmp,
i_page_block_k,
k_page_block_navigator,
knew_dram_window,
v_dram_window_tmp,
i_page_block_v,
v_page_block_navigator,
vnew_dram_window,
q_rotary_cos_dram_window,
q_rotary_sin_dram_window,
knew_rotary_cos_dram_window,
knew_rotary_sin_dram_window,
0, // rotary_dim not used
k_page_block_navigator,
v_page_block_navigator,
kargs.seqlen_q <= i_m0,
kargs.seqlen_knew <= i_n0);
}

View File

@@ -829,8 +829,10 @@ struct FmhaFwdSplitKVKernel
return FmhaPipeline{}(q_dram_window,
identity{}, // q_element_func
k_dram_window,
k_page_block_navigator,
identity{}, // k_element_func
v_dram_window,
v_page_block_navigator,
identity{}, // v_element_func
bias_dram_window,
identity{}, // bias_element_func
@@ -844,15 +846,15 @@ struct FmhaFwdSplitKVKernel
mask,
position_encoding,
kargs.scale_s,
smem_ptr,
k_page_block_navigator,
v_page_block_navigator);
smem_ptr);
}
else
{
return FmhaPipeline{}(q_dram_window,
k_dram_window,
k_page_block_navigator,
v_dram_window,
v_page_block_navigator,
bias_dram_window,
lse_acc_dram_window,
kargs.num_splits,
@@ -860,9 +862,7 @@ struct FmhaFwdSplitKVKernel
mask,
position_encoding,
kargs.scale_s,
smem_ptr,
k_page_block_navigator,
v_page_block_navigator);
smem_ptr);
}
}();

View File

@@ -74,8 +74,10 @@ struct BlockFmhaFwdAppendKVPipeline
template <typename QDramBlockWindow,
typename KDramBlockWindow,
typename KPageBlockNavigator,
typename KnewDramBlockWindow,
typename VDramBlockWindow,
typename VPageBlockNavigator,
typename VnewDramBlockWindow,
typename QElementFunction,
typename KnewElementFunction,
@@ -83,18 +85,18 @@ struct BlockFmhaFwdAppendKVPipeline
typename QRotaryCosDramBlockWindow,
typename QRotarySinDramBlockWindow,
typename KnewRotaryCosDramBlockWindow,
typename KnewRotarySinDramBlockWindow,
typename KPageBlockNavigator,
typename VPageBlockNavigator>
typename KnewRotarySinDramBlockWindow>
CK_TILE_HOST_DEVICE auto
operator()(QDramBlockWindow& q_dram_block_window, // M0*K0 tile
const QElementFunction& q_element_func,
KDramBlockWindow& k_dram_block_window, // N0*K0 tile
index_t i_page_block_k,
const KPageBlockNavigator& k_page_block_navigator,
const KnewDramBlockWindow& knew_dram_block_window, // N0*K0 tile
const KnewElementFunction& knew_element_func,
VDramBlockWindow& v_dram_block_window, // N1*N0 tile
index_t i_page_block_v,
const VPageBlockNavigator& v_page_block_navigator,
const VnewDramBlockWindow& vnew_dram_block_window, // N1*N0 tile
const VnewElementFunction& vnew_element_func,
const QRotaryCosDramBlockWindow q_rotary_cos_dram_block_window,
@@ -102,12 +104,10 @@ struct BlockFmhaFwdAppendKVPipeline
const KnewRotaryCosDramBlockWindow knew_rotary_cos_dram_block_window,
const KnewRotarySinDramBlockWindow knew_rotary_sin_dram_block_window,
index_t rotary_dim,
const KPageBlockNavigator& k_page_block_navigator,
const VPageBlockNavigator& v_page_block_navigator,
bool skip_transform_q,
bool skip_append_kv) const
bool skip_rotate_q,
bool skip_rotate_append_kv) const
{
if(!skip_append_kv)
if(!skip_rotate_append_kv)
{
// append Knew to K
auto knew_window = make_tile_window(
@@ -190,7 +190,7 @@ struct BlockFmhaFwdAppendKVPipeline
}
}
if(!skip_transform_q)
if(!skip_rotate_q)
{
// optionally apply rotary embedding to Q
if constexpr(RotaryEnum != RotaryEmbeddingEnum::NONE)
@@ -231,41 +231,43 @@ struct BlockFmhaFwdAppendKVPipeline
template <typename QDramBlockWindow,
typename KDramBlockWindow,
typename KPageBlockNavigator,
typename KnewDramBlockWindow,
typename VDramBlockWindow,
typename VPageBlockNavigator,
typename VnewDramBlockWindow,
typename QRotaryCosDramBlockWindow,
typename QRotarySinDramBlockWindow,
typename KnewRotaryCosDramBlockWindow,
typename KnewRotarySinDramBlockWindow,
typename KPageBlockNavigator,
typename VPageBlockNavigator>
typename KnewRotarySinDramBlockWindow>
CK_TILE_HOST_DEVICE auto
operator()(QDramBlockWindow& q_dram_block_window,
KDramBlockWindow& k_dram_block_window,
index_t i_page_block_k,
const KPageBlockNavigator& k_page_block_navigator,
const KnewDramBlockWindow& knew_dram_block_window,
VDramBlockWindow& v_dram_block_window,
index_t i_page_block_v,
const VPageBlockNavigator& v_page_block_navigator,
const VnewDramBlockWindow& vnew_dram_block_window,
const QRotaryCosDramBlockWindow& q_rotary_cos_dram_block_window,
const QRotarySinDramBlockWindow& q_rotary_sin_dram_block_window,
const KnewRotaryCosDramBlockWindow& knew_rotary_cos_dram_block_window,
const KnewRotarySinDramBlockWindow& knew_rotary_sin_dram_block_window,
index_t rotary_dim,
const KPageBlockNavigator& k_page_block_navigator,
const VPageBlockNavigator& v_page_block_navigator,
bool skip_transform_q,
bool skip_append_kv) const
bool skip_rotate_q,
bool skip_rotate_append_kv) const
{
return operator()(q_dram_block_window,
identity{},
k_dram_block_window,
i_page_block_k,
k_page_block_navigator,
knew_dram_block_window,
identity{},
v_dram_block_window,
i_page_block_v,
v_page_block_navigator,
vnew_dram_block_window,
identity{},
q_rotary_cos_dram_block_window,
@@ -273,10 +275,8 @@ struct BlockFmhaFwdAppendKVPipeline
knew_rotary_cos_dram_block_window,
knew_rotary_sin_dram_block_window,
rotary_dim,
k_page_block_navigator,
v_page_block_navigator,
skip_transform_q,
skip_append_kv);
skip_rotate_q,
skip_rotate_append_kv);
}
};

View File

@@ -105,7 +105,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename KPageBlockNavigator,
typename VDramBlockWindowTmp,
typename VPageBlockNavigator,
typename BiasDramBlockWindowTmp,
typename LSEaccDramBlockWindowTmp,
typename QElementFunction,
@@ -116,15 +118,15 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename PositionEncoding,
typename KPageBlockNavigator,
typename VPageBlockNavigator>
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const KPageBlockNavigator& k_page_block_navigator,
const KElementFunction& k_element_func,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const VPageBlockNavigator& v_page_block_navigator,
const VElementFunction& v_element_func,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
@@ -138,9 +140,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
const KPageBlockNavigator& k_page_block_navigator,
const VPageBlockNavigator& v_page_block_navigator) const
void* smem_ptr) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -634,16 +634,18 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename KPageBlockNavigator,
typename VDramBlockWindowTmp,
typename VPageBlockNavigator,
typename BiasDramBlockWindowTmp,
typename LSEaccDramBlockWindowTmp,
typename PositionEncoding,
typename KPageBlockNavigator,
typename VPageBlockNavigator>
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const KPageBlockNavigator& k_page_block_navigator,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const VPageBlockNavigator& v_page_block_navigator,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile
index_t num_splits,
@@ -651,15 +653,15 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
const KPageBlockNavigator& k_page_block_navigator,
const VPageBlockNavigator& v_page_block_navigator) const
void* smem_ptr) const
{
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,
k_page_block_navigator,
identity{},
v_dram_block_window_tmp,
v_page_block_navigator,
identity{},
bias_dram_block_window_tmp,
identity{},
@@ -673,9 +675,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
mask,
position_encoding,
scale_s,
smem_ptr,
k_page_block_navigator,
v_page_block_navigator);
smem_ptr);
}
};