Unify parameter/variable naming style

This commit is contained in:
PoYen, Chen
2024-07-23 02:59:17 +00:00
parent c0bc097758
commit c26c60db4c

View File

@@ -76,27 +76,27 @@ struct BlockFmhaFwdAppendKVPipeline
return Policy::template GetSmemSize<Problem>();
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename KnewDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename VnewDramBlockWindowTmp,
template <typename QDramBlockWindow,
typename KDramBlockWindow,
typename KnewDramBlockWindow,
typename VDramBlockWindow,
typename VnewDramBlockWindow,
typename QElementFunction,
typename KnewElementFunction,
typename VnewElementFunction,
typename RotaryCosBlockWindowTemp,
typename RotarySinBlockWindowTemp>
typename RotaryCosDramBlockWindow,
typename RotarySinDramBlockWindow>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
operator()(const QDramBlockWindow& q_dram_block_window, // M0*K0 tile
const QElementFunction& q_element_func,
KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const KnewDramBlockWindowTmp& knew_dram_block_window_tmp, // N0*K0 tile
KDramBlockWindow& k_dram_block_window, // N0*K0 tile
const KnewDramBlockWindow& knew_dram_block_window, // N0*K0 tile
const KnewElementFunction& knew_element_func,
VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const VnewDramBlockWindowTmp& vnew_dram_block_window_tmp, // N1*K1 tile
VDramBlockWindow& v_dram_block_window, // N1*K1 tile
const VnewDramBlockWindow& vnew_dram_block_window, // N1*K1 tile
const VnewElementFunction& vnew_element_func,
const RotaryCosBlockWindowTemp rotary_cos_block_window_tmp,
const RotarySinBlockWindowTemp rotary_sin_block_window_tmp,
const RotaryCosDramBlockWindow rotary_cos_dram_block_window,
const RotarySinDramBlockWindow rotary_sin_dram_block_window,
void* smem_ptr,
index_t rotary_dim = 0) const
{
@@ -154,34 +154,29 @@ struct BlockFmhaFwdAppendKVPipeline
#endif
};
auto knew_dram_block_window =
make_tile_window(knew_dram_block_window_tmp.get_bottom_tensor_view(),
knew_dram_block_window_tmp.get_window_lengths(),
{0, 0});
auto knew_dram_window =
auto knew_window =
make_tile_window(knew_dram_block_window.get_bottom_tensor_view(),
knew_dram_block_window.get_window_lengths(),
knew_dram_block_window.get_window_origin(),
Policy::template MakeKnewDramTileDistribution<Problem>());
auto knew_tile = [&]() {
auto knew = load_tile(knew_dram_window);
auto knew = load_tile(knew_window);
return tile_elementwise_in(knew_element_func, knew);
}();
if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE)
{
auto rotary_cos_window =
make_tile_window(rotary_cos_block_window_tmp.get_bottom_tensor_view(),
rotary_cos_block_window_tmp.get_window_lengths(),
rotary_cos_block_window_tmp.get_window_origin(),
make_tile_window(rotary_cos_dram_block_window.get_bottom_tensor_view(),
rotary_cos_dram_block_window.get_window_lengths(),
rotary_cos_dram_block_window.get_window_origin(),
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
auto rotary_sin_window =
make_tile_window(rotary_sin_block_window_tmp.get_bottom_tensor_view(),
rotary_sin_block_window_tmp.get_window_lengths(),
rotary_sin_block_window_tmp.get_window_origin(),
make_tile_window(rotary_sin_dram_block_window.get_bottom_tensor_view(),
rotary_sin_dram_block_window.get_window_lengths(),
rotary_sin_dram_block_window.get_window_origin(),
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
// We assume that each thread owns contiguous elements on head dimension. And we will
@@ -225,11 +220,11 @@ struct BlockFmhaFwdAppendKVPipeline
{
bool is_left = (start_x + KPerThread) <= (rotary_dim / 2);
auto knew_other_dram_window = knew_dram_window;
move_tile_window(knew_other_dram_window,
auto knew_other_window = knew_window;
move_tile_window(knew_other_window,
{0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
auto knew_other_tile = load_tile(knew_other_dram_window);
auto knew_other_tile = load_tile(knew_other_window);
move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)});
auto rotary_cos_tile = load_tile(rotary_cos_window);
@@ -253,53 +248,49 @@ struct BlockFmhaFwdAppendKVPipeline
}
}
print_tile(knew_tile, 7);
store_tile(k_dram_block_window_tmp, knew_tile);
store_tile(k_dram_block_window, knew_tile);
auto vnew_dram_block_window =
make_tile_window(vnew_dram_block_window_tmp.get_bottom_tensor_view(),
vnew_dram_block_window_tmp.get_window_lengths(),
{0, 0});
auto vnew_dram_window =
auto vnew_window =
make_tile_window(vnew_dram_block_window.get_bottom_tensor_view(),
vnew_dram_block_window.get_window_lengths(),
vnew_dram_block_window.get_window_origin(),
Policy::template MakeVnewDramTileDistribution<Problem>());
auto vnew_tile = [&]() {
auto vnew = load_tile(vnew_dram_window);
auto vnew = load_tile(vnew_window);
return tile_elementwise_in(vnew_element_func, vnew);
}();
store_tile(v_dram_block_window_tmp, vnew_tile);
store_tile(v_dram_block_window, vnew_tile);
}
// Convenience overload of operator() without element-function parameters.
// It forwards every argument unchanged to the overload that accepts
// q/knew/vnew element functions, passing identity{} in each of those slots.
// NOTE(review): this hunk shows the function twice, line-interleaved - the
// *Tmp/*Temp-named lines appear to be the pre-rename text and the lines that
// follow them the renamed replacements; confirm against the actual file.
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename KnewDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename VnewDramBlockWindowTmp,
typename RotaryCosBlockWindowTemp,
typename RotarySinBlockWindowTemp>
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
KDramBlockWindowTmp& k_dram_block_window_tmp,
const KnewDramBlockWindowTmp& knew_dram_block_window_tmp,
VDramBlockWindowTmp& v_dram_block_window_tmp,
const VnewDramBlockWindowTmp& vnew_dram_block_window_tmp,
const RotaryCosBlockWindowTemp& rotary_cos_block_window_tmp,
const RotarySinBlockWindowTemp& rotary_sin_block_window_tmp,
void* smem_ptr,
index_t rotary_dim = 0) const
template <typename QDramBlockWindow,
typename KDramBlockWindow,
typename KnewDramBlockWindow,
typename VDramBlockWindow,
typename VnewDramBlockWindow,
typename RotaryCosDramBlockWindow,
typename RotarySinDramBlockWindow>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindow& q_dram_block_window,
KDramBlockWindow& k_dram_block_window,
const KnewDramBlockWindow& knew_dram_block_window,
VDramBlockWindow& v_dram_block_window,
const VnewDramBlockWindow& vnew_dram_block_window,
const RotaryCosDramBlockWindow& rotary_cos_dram_block_window,
const RotarySinDramBlockWindow& rotary_sin_dram_block_window,
void* smem_ptr,
index_t rotary_dim = 0) const
{
return operator()(q_dram_block_window_tmp,
return operator()(q_dram_block_window,
identity{}, // fills the q_element_func slot
k_dram_block_window_tmp,
knew_dram_block_window_tmp,
k_dram_block_window,
knew_dram_block_window,
identity{}, // fills the knew_element_func slot
v_dram_block_window_tmp,
vnew_dram_block_window_tmp,
v_dram_block_window,
vnew_dram_block_window,
identity{}, // fills the vnew_element_func slot
rotary_cos_block_window_tmp,
rotary_sin_block_window_tmp,
rotary_cos_dram_block_window,
rotary_sin_dram_block_window,
smem_ptr,
rotary_dim);
}