mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Unify parameter/variable naming style
This commit is contained in:
@@ -76,27 +76,27 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename KnewDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename VnewDramBlockWindowTmp,
|
||||
template <typename QDramBlockWindow,
|
||||
typename KDramBlockWindow,
|
||||
typename KnewDramBlockWindow,
|
||||
typename VDramBlockWindow,
|
||||
typename VnewDramBlockWindow,
|
||||
typename QElementFunction,
|
||||
typename KnewElementFunction,
|
||||
typename VnewElementFunction,
|
||||
typename RotaryCosBlockWindowTemp,
|
||||
typename RotarySinBlockWindowTemp>
|
||||
typename RotaryCosDramBlockWindow,
|
||||
typename RotarySinDramBlockWindow>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
operator()(const QDramBlockWindow& q_dram_block_window, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const KnewDramBlockWindowTmp& knew_dram_block_window_tmp, // N0*K0 tile
|
||||
KDramBlockWindow& k_dram_block_window, // N0*K0 tile
|
||||
const KnewDramBlockWindow& knew_dram_block_window, // N0*K0 tile
|
||||
const KnewElementFunction& knew_element_func,
|
||||
VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const VnewDramBlockWindowTmp& vnew_dram_block_window_tmp, // N1*K1 tile
|
||||
VDramBlockWindow& v_dram_block_window, // N1*K1 tile
|
||||
const VnewDramBlockWindow& vnew_dram_block_window, // N1*K1 tile
|
||||
const VnewElementFunction& vnew_element_func,
|
||||
const RotaryCosBlockWindowTemp rotary_cos_block_window_tmp,
|
||||
const RotarySinBlockWindowTemp rotary_sin_block_window_tmp,
|
||||
const RotaryCosDramBlockWindow rotary_cos_dram_block_window,
|
||||
const RotarySinDramBlockWindow rotary_sin_dram_block_window,
|
||||
void* smem_ptr,
|
||||
index_t rotary_dim = 0) const
|
||||
{
|
||||
@@ -154,34 +154,29 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
#endif
|
||||
};
|
||||
|
||||
auto knew_dram_block_window =
|
||||
make_tile_window(knew_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
knew_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, 0});
|
||||
|
||||
auto knew_dram_window =
|
||||
auto knew_window =
|
||||
make_tile_window(knew_dram_block_window.get_bottom_tensor_view(),
|
||||
knew_dram_block_window.get_window_lengths(),
|
||||
knew_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeKnewDramTileDistribution<Problem>());
|
||||
|
||||
auto knew_tile = [&]() {
|
||||
auto knew = load_tile(knew_dram_window);
|
||||
auto knew = load_tile(knew_window);
|
||||
return tile_elementwise_in(knew_element_func, knew);
|
||||
}();
|
||||
|
||||
if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE)
|
||||
{
|
||||
auto rotary_cos_window =
|
||||
make_tile_window(rotary_cos_block_window_tmp.get_bottom_tensor_view(),
|
||||
rotary_cos_block_window_tmp.get_window_lengths(),
|
||||
rotary_cos_block_window_tmp.get_window_origin(),
|
||||
make_tile_window(rotary_cos_dram_block_window.get_bottom_tensor_view(),
|
||||
rotary_cos_dram_block_window.get_window_lengths(),
|
||||
rotary_cos_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
|
||||
|
||||
auto rotary_sin_window =
|
||||
make_tile_window(rotary_sin_block_window_tmp.get_bottom_tensor_view(),
|
||||
rotary_sin_block_window_tmp.get_window_lengths(),
|
||||
rotary_sin_block_window_tmp.get_window_origin(),
|
||||
make_tile_window(rotary_sin_dram_block_window.get_bottom_tensor_view(),
|
||||
rotary_sin_dram_block_window.get_window_lengths(),
|
||||
rotary_sin_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
|
||||
|
||||
// We assume that each thread owns contiguous elements on head dimention. And we will
|
||||
@@ -225,11 +220,11 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
{
|
||||
bool is_left = (start_x + KPerThread) <= (rotary_dim / 2);
|
||||
|
||||
auto knew_other_dram_window = knew_dram_window;
|
||||
move_tile_window(knew_other_dram_window,
|
||||
auto knew_other_window = knew_window;
|
||||
move_tile_window(knew_other_window,
|
||||
{0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
|
||||
|
||||
auto knew_other_tile = load_tile(knew_other_dram_window);
|
||||
auto knew_other_tile = load_tile(knew_other_window);
|
||||
|
||||
move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)});
|
||||
auto rotary_cos_tile = load_tile(rotary_cos_window);
|
||||
@@ -253,53 +248,49 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
}
|
||||
}
|
||||
print_tile(knew_tile, 7);
|
||||
store_tile(k_dram_block_window_tmp, knew_tile);
|
||||
store_tile(k_dram_block_window, knew_tile);
|
||||
|
||||
auto vnew_dram_block_window =
|
||||
make_tile_window(vnew_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
vnew_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, 0});
|
||||
|
||||
auto vnew_dram_window =
|
||||
auto vnew_window =
|
||||
make_tile_window(vnew_dram_block_window.get_bottom_tensor_view(),
|
||||
vnew_dram_block_window.get_window_lengths(),
|
||||
vnew_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeVnewDramTileDistribution<Problem>());
|
||||
|
||||
auto vnew_tile = [&]() {
|
||||
auto vnew = load_tile(vnew_dram_window);
|
||||
auto vnew = load_tile(vnew_window);
|
||||
return tile_elementwise_in(vnew_element_func, vnew);
|
||||
}();
|
||||
store_tile(v_dram_block_window_tmp, vnew_tile);
|
||||
store_tile(v_dram_block_window, vnew_tile);
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename KnewDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename VnewDramBlockWindowTmp,
|
||||
typename RotaryCosBlockWindowTemp,
|
||||
typename RotarySinBlockWindowTemp>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const KnewDramBlockWindowTmp& knew_dram_block_window_tmp,
|
||||
VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
const VnewDramBlockWindowTmp& vnew_dram_block_window_tmp,
|
||||
const RotaryCosBlockWindowTemp& rotary_cos_block_window_tmp,
|
||||
const RotarySinBlockWindowTemp& rotary_sin_block_window_tmp,
|
||||
void* smem_ptr,
|
||||
index_t rotary_dim = 0) const
|
||||
template <typename QDramBlockWindow,
|
||||
typename KDramBlockWindow,
|
||||
typename KnewDramBlockWindow,
|
||||
typename VDramBlockWindow,
|
||||
typename VnewDramBlockWindow,
|
||||
typename RotaryCosDramBlockWindow,
|
||||
typename RotarySinDramBlockWindow>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindow& q_dram_block_window,
|
||||
KDramBlockWindow& k_dram_block_window,
|
||||
const KnewDramBlockWindow& knew_dram_block_window,
|
||||
VDramBlockWindow& v_dram_block_window,
|
||||
const VnewDramBlockWindow& vnew_dram_block_window,
|
||||
const RotaryCosDramBlockWindow& rotary_cos_dram_block_window,
|
||||
const RotarySinDramBlockWindow& rotary_sin_dram_block_window,
|
||||
void* smem_ptr,
|
||||
index_t rotary_dim = 0) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
return operator()(q_dram_block_window,
|
||||
identity{},
|
||||
k_dram_block_window_tmp,
|
||||
knew_dram_block_window_tmp,
|
||||
k_dram_block_window,
|
||||
knew_dram_block_window,
|
||||
identity{},
|
||||
v_dram_block_window_tmp,
|
||||
vnew_dram_block_window_tmp,
|
||||
v_dram_block_window,
|
||||
vnew_dram_block_window,
|
||||
identity{},
|
||||
rotary_cos_block_window_tmp,
|
||||
rotary_sin_block_window_tmp,
|
||||
rotary_cos_dram_block_window,
|
||||
rotary_sin_dram_block_window,
|
||||
smem_ptr,
|
||||
rotary_dim);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user