mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Replace s_acc and pcomp tile array by single tile object for simplification
This commit is contained in:
@@ -229,8 +229,8 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
using PcompBlockTileType = decltype(cast_tile<CompDataType>(SaccBlockTileType{}));
|
||||
|
||||
statically_indexed_array<SaccBlockTileType, k1_loops> sacc_tiles;
|
||||
statically_indexed_array<PcompBlockTileType, k1_loops> pcomp_tiles;
|
||||
SaccBlockTileType sacc_tile;
|
||||
PcompBlockTileType pcomp_tile;
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_silu = [&](CompDataType& x) {
|
||||
@@ -314,9 +314,9 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
block_sync_lds();
|
||||
// execute current unroll of gemm_0
|
||||
gemm_0(sacc_tiles[i_k1], q_tile, k_lds_windows[number<i_k1 % NumKVLdsBuffers>{}]);
|
||||
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_k1 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tiles[i_k1] = tile_elementwise_in(s_acc_element_func, sacc_tiles[i_k1]);
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, siLU
|
||||
if constexpr(kHasBias)
|
||||
@@ -327,15 +327,14 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
[&scale_s, &bias_element_func](auto& x, const auto& y) {
|
||||
x = x * scale_s - type_convert<GemmAccDataType>(bias_element_func(y));
|
||||
},
|
||||
sacc_tiles[i_k1],
|
||||
sacc_tile,
|
||||
bias_tile);
|
||||
|
||||
move_tile_window(bias_dram_window, {0, kK1});
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; },
|
||||
sacc_tiles[i_k1]);
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, sacc_tile);
|
||||
}
|
||||
|
||||
if constexpr(HstuMask::IsMasking)
|
||||
@@ -346,15 +345,14 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
sacc_tiles[i_k1].get_tile_distribution(),
|
||||
make_tuple(idx0, idx1));
|
||||
sacc_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
sacc_tiles[i_k1](i_j_idx) *= static_cast<GemmAccDataType>(
|
||||
sacc_tile(i_j_idx) *= static_cast<GemmAccDataType>(
|
||||
mask.IsTokenPairInsideMask(row, col));
|
||||
});
|
||||
});
|
||||
@@ -370,15 +368,14 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
sacc_tiles[i_k1].get_tile_distribution(),
|
||||
make_tuple(idx0, idx1));
|
||||
sacc_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
sacc_tiles[i_k1](i_j_idx) *= static_cast<GemmAccDataType>(
|
||||
sacc_tile(i_j_idx) *= static_cast<GemmAccDataType>(
|
||||
mask.IsTokenPairInsideMask(row, col));
|
||||
});
|
||||
});
|
||||
@@ -393,22 +390,21 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
sacc_tiles[i_k1].get_tile_distribution(),
|
||||
make_tuple(idx0, idx1));
|
||||
sacc_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
sacc_tiles[i_k1](i_j_idx) *= static_cast<GemmAccDataType>(
|
||||
sacc_tile(i_j_idx) *= static_cast<GemmAccDataType>(
|
||||
mask.IsTokenPairInsideMask(row, col));
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pcomp_tiles[i_k1] = cast_tile<CompDataType>(sacc_tiles[i_k1]);
|
||||
pcomp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -438,7 +434,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
tile_elementwise_inout(f_silu, pcomp_tiles[i_k1]);
|
||||
tile_elementwise_inout(f_silu, pcomp_tile);
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
@@ -446,11 +442,11 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
Policy::template GetSmemSizeKV<Problem>();
|
||||
|
||||
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
|
||||
randval_lds_ptr, seqlen_k_curr, pcomp_tiles[i_k1], null_randval_window);
|
||||
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
|
||||
}
|
||||
|
||||
auto p = cast_tile<PDataType>(
|
||||
tile_elementwise_in(p_compute_element_func, pcomp_tiles[i_k1]));
|
||||
auto p =
|
||||
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user