Replace s_acc and pcomp tile array by single tile object for simplification

This commit is contained in:
Qianfeng Zhang
2025-05-19 07:46:57 +00:00
parent 4e65469fe8
commit f582c21418

View File

@@ -229,8 +229,8 @@ struct HstuAttentionFwdPipelineQRKSVS
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
using PcompBlockTileType = decltype(cast_tile<CompDataType>(SaccBlockTileType{}));
statically_indexed_array<SaccBlockTileType, k1_loops> sacc_tiles;
statically_indexed_array<PcompBlockTileType, k1_loops> pcomp_tiles;
SaccBlockTileType sacc_tile;
PcompBlockTileType pcomp_tile;
// reduction function for softmax
const auto f_silu = [&](CompDataType& x) {
@@ -314,9 +314,9 @@ struct HstuAttentionFwdPipelineQRKSVS
block_sync_lds();
// execute current unroll of gemm_0
gemm_0(sacc_tiles[i_k1], q_tile, k_lds_windows[number<i_k1 % NumKVLdsBuffers>{}]);
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_k1 % NumKVLdsBuffers>{}]);
sacc_tiles[i_k1] = tile_elementwise_in(s_acc_element_func, sacc_tiles[i_k1]);
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
// STAGE 2, scale_s, add bias, mask, siLU
if constexpr(kHasBias)
@@ -327,15 +327,14 @@ struct HstuAttentionFwdPipelineQRKSVS
[&scale_s, &bias_element_func](auto& x, const auto& y) {
x = x * scale_s - type_convert<GemmAccDataType>(bias_element_func(y));
},
sacc_tiles[i_k1],
sacc_tile,
bias_tile);
move_tile_window(bias_dram_window, {0, kK1});
}
else
{
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; },
sacc_tiles[i_k1]);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, sacc_tile);
}
if constexpr(HstuMask::IsMasking)
@@ -346,15 +345,14 @@ struct HstuAttentionFwdPipelineQRKSVS
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
sacc_tiles[i_k1].get_tile_distribution(),
make_tuple(idx0, idx1));
sacc_tile.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
sacc_tiles[i_k1](i_j_idx) *= static_cast<GemmAccDataType>(
sacc_tile(i_j_idx) *= static_cast<GemmAccDataType>(
mask.IsTokenPairInsideMask(row, col));
});
});
@@ -370,15 +368,14 @@ struct HstuAttentionFwdPipelineQRKSVS
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
sacc_tiles[i_k1].get_tile_distribution(),
make_tuple(idx0, idx1));
sacc_tile.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
sacc_tiles[i_k1](i_j_idx) *= static_cast<GemmAccDataType>(
sacc_tile(i_j_idx) *= static_cast<GemmAccDataType>(
mask.IsTokenPairInsideMask(row, col));
});
});
@@ -393,22 +390,21 @@ struct HstuAttentionFwdPipelineQRKSVS
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
sacc_tiles[i_k1].get_tile_distribution(),
make_tuple(idx0, idx1));
sacc_tile.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
sacc_tiles[i_k1](i_j_idx) *= static_cast<GemmAccDataType>(
sacc_tile(i_j_idx) *= static_cast<GemmAccDataType>(
mask.IsTokenPairInsideMask(row, col));
});
});
}
}
pcomp_tiles[i_k1] = cast_tile<CompDataType>(sacc_tiles[i_k1]);
pcomp_tile = cast_tile<CompDataType>(sacc_tile);
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
@@ -438,7 +434,7 @@ struct HstuAttentionFwdPipelineQRKSVS
__builtin_amdgcn_sched_barrier(0);
tile_elementwise_inout(f_silu, pcomp_tiles[i_k1]);
tile_elementwise_inout(f_silu, pcomp_tile);
if constexpr(kHasDropout)
{
@@ -446,11 +442,11 @@ struct HstuAttentionFwdPipelineQRKSVS
Policy::template GetSmemSizeKV<Problem>();
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
randval_lds_ptr, seqlen_k_curr, pcomp_tiles[i_k1], null_randval_window);
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
}
auto p = cast_tile<PDataType>(
tile_elementwise_in(p_compute_element_func, pcomp_tiles[i_k1]));
auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
block_sync_lds();