Move silu calculation to gemm1 iteration and try to interleave gemm_1 and silu

This commit is contained in:
Qianfeng Zhang
2025-04-23 13:10:02 +00:00
parent 2d2e1941a8
commit ce4665262b

View File

@@ -356,7 +356,7 @@ struct HstuAttentionFwdPipelineQRKSVS
set_tile_if(
sacc_tiles[i_k1], type_convert<GemmAccDataType>(0), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
const auto col = seqlen_k_curr + i_k1 * kK1 + tile_idx.at(number<1>{});
return !mask.IsTokenPairInsideMask(row, col);
});
}
@@ -368,26 +368,14 @@ struct HstuAttentionFwdPipelineQRKSVS
sacc_tiles[i_k1], type_convert<GemmAccDataType>(0), [&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
const auto col =
seqlen_k_curr + i_k1 * kK1 + tile_idx.at(number<1>{});
return !mask.IsTokenPairInsideMask(row, col);
});
}
}
pcomp_tiles[i_k1] = cast_tile<CompDataType>(sacc_tiles[i_k1]);
tile_elementwise_inout(f_silu, pcomp_tiles[i_k1]);
if constexpr(kHasDropout)
{
auto randval_lds_ptr = reinterpret_cast<char*>(smem_ptr) +
Policy::template GetSmemSizeKV<Problem>();
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
randval_lds_ptr, seqlen_k_curr, pcomp_tiles[i_k1], null_randval_window);
}
seqlen_k_curr += kK1;
});
// load one k_tile for next iteration
@@ -419,16 +407,29 @@ struct HstuAttentionFwdPipelineQRKSVS
tile_elementwise_in(v_element_func, v_tiles[I0])); // store the prefetch
};
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
const auto p = [&]() {
if constexpr(std::is_same_v<PDataType, fp16_t>)
return impl::cast_tile_pk_fp16_fp32<PDataType>(
tile_elementwise_in(p_compute_element_func, pcomp_tiles[i_k1]));
else
return cast_tile<PDataType>(
tile_elementwise_in(p_compute_element_func, pcomp_tiles[i_k1]));
}();
tile_elementwise_inout(f_silu, pcomp_tiles[I0]);
if constexpr(kHasDropout)
{
auto randval_lds_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
randval_lds_ptr, seqlen_k_curr, pcomp_tiles[I0], null_randval_window);
}
seqlen_k_curr += kK1;
auto p = [&]() {
if constexpr(std::is_same_v<PDataType, fp16_t>)
return impl::cast_tile_pk_fp16_fp32<PDataType>(
tile_elementwise_in(p_compute_element_func, pcomp_tiles[I0]));
else
return cast_tile<PDataType>(
tile_elementwise_in(p_compute_element_func, pcomp_tiles[I0]));
}();
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
if constexpr(i_k1 < k1_loops - NumPrefetchV)
{
v_tiles[number<i_k1 % NumPrefetchV>{}] = load_tile(v_dram_window);
@@ -436,31 +437,56 @@ struct HstuAttentionFwdPipelineQRKSVS
};
block_sync_lds();
gemm_1(o_acc, p, v_lds_windows[number<i_k1 % NumVLdsBuffers>{}]);
tile_elementwise_inout(f_silu, pcomp_tiles[number<i_k1 + 1>{}]);
if constexpr(i_k1 < k1_loops - 1)
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<QKVDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]);
auto v_shuffle_tmp = make_static_distributed_tensor<QKVDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]);
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
tile_elementwise_in(v_element_func,
v_shuffle_tmp)); // store the prefetch
}
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
tile_elementwise_in(v_element_func,
v_shuffle_tmp)); // store the prefetch
}
else
{
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
tile_elementwise_in(
v_element_func,
v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}])); // store the
// prefetch
}
if constexpr(kHasDropout)
{
auto randval_lds_ptr = reinterpret_cast<char*>(smem_ptr) +
Policy::template GetSmemSizeKV<Problem>();
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
randval_lds_ptr,
seqlen_k_curr,
pcomp_tiles[number<i_k1 + 1>{}],
null_randval_window);
}
seqlen_k_curr += kK1;
p = [&]() {
if constexpr(std::is_same_v<PDataType, fp16_t>)
return impl::cast_tile_pk_fp16_fp32<PDataType>(tile_elementwise_in(
p_compute_element_func, pcomp_tiles[number<i_k1 + 1>{}]));
else
{
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
tile_elementwise_in(
v_element_func,
v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}])); // store the
// prefetch
}
};
return cast_tile<PDataType>(tile_elementwise_in(
p_compute_element_func, pcomp_tiles[number<i_k1 + 1>{}]));
}();
});
block_sync_lds();
gemm_1(o_acc, p, v_lds_windows[number<(k1_loops - 1) % NumVLdsBuffers>{}]);
// the over-lap only occurs when k1_loops is 3/5/7, NumVLdsBuffers is 2
if constexpr(Policy::template IsFirstKLdsBufferOverlapLastVLdsBuffer<Problem>())
__builtin_amdgcn_s_barrier();