mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
Change the pipeline so that it is easier to add support for using softmax
This commit is contained in:
@@ -165,8 +165,12 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
|
||||
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
using PcompBlockTileType = decltype(cast_tile<CompDataType>(SaccBlockTileType{}));
|
||||
// SaccBlockTile size is [kM0, kK1]
|
||||
// PcompBlockTile size is [kM0, kN0]
|
||||
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
// using PcompBlockTileType = decltype(cast_tile<CompDataType>(SaccBlockTileType{}));
|
||||
using PcompBlockTileType = decltype(make_static_distributed_tensor<CompDataType>(
|
||||
Policy::template MakePRegTileDistribution<Problem>()));
|
||||
|
||||
SaccBlockTileType sacc_tile;
|
||||
PcompBlockTileType pcomp_tile;
|
||||
@@ -198,8 +202,14 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
move_tile_window(q_dram_window, {kGemmSingleRepM, 0});
|
||||
});
|
||||
|
||||
auto k_tile = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
using k_tile_type = decltype(load_tile(k_dram_window));
|
||||
|
||||
statically_indexed_array<k_tile_type, k1_loops> k_tiles;
|
||||
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
k_tiles[i_k1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -323,14 +333,6 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
q_tile_type q_tile;
|
||||
|
||||
{
|
||||
constexpr index_t complete_tile_thread_buf_size = q_tile_type::get_thread_buffer_size();
|
||||
constexpr index_t splitted_tile_thread_buf_size =
|
||||
q_reg_tile_type::get_thread_buffer_size();
|
||||
|
||||
static_assert(complete_tile_thread_buf_size ==
|
||||
kGemmNumRepM * splitted_tile_thread_buf_size,
|
||||
"Check failed!");
|
||||
|
||||
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
|
||||
store_tile(q_lds_write_window, q_dram_tiles[i_rep]);
|
||||
|
||||
@@ -368,128 +370,142 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
using v_tile_type = decltype(load_tile(v_dram_window));
|
||||
|
||||
v_tile_type v_tile;
|
||||
|
||||
store_tile(k_lds_write_windows[number<0>{}], tile_elementwise_in(k_element_func, k_tile));
|
||||
statically_indexed_array<v_tile_type, k1_loops> v_tiles;
|
||||
|
||||
do
|
||||
{
|
||||
// STAGE 1, Gemm_0 ( S = Q@K )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
// load v_tile for current unroll
|
||||
v_tile = load_tile(v_dram_window);
|
||||
store_tile(k_lds_write_windows[i_k1],
|
||||
tile_elementwise_in(k_element_func, k_tiles[i_k1]));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// load v_tiles used in current iteration
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
|
||||
// for i_k1 = k1_loop-1, the loading is for next iteration
|
||||
k_tile = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// execute current unroll of gemm_0
|
||||
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, siLU
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
auto sacc_tile_tmp = cast_tile<CompDataType>(sacc_tile);
|
||||
|
||||
tile_elementwise_inout(
|
||||
[&scale_s, &bias_element_func](auto& x, const auto& y) {
|
||||
x = x * scale_s - type_convert<GemmAccDataType>(bias_element_func(y));
|
||||
},
|
||||
sacc_tile,
|
||||
bias_tile);
|
||||
using pcomp_tile_tmp_type =
|
||||
decltype(get_slice_tile(pcomp_tile, sequence<0, 0>{}, sequence<kM0, kK1>{}));
|
||||
|
||||
move_tile_window(bias_dram_window, {0, kK1});
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, sacc_tile);
|
||||
}
|
||||
pcomp_tile_tmp_type pcomp_tile_tmp;
|
||||
|
||||
if(!mask.IsFullTileInsideMask(
|
||||
q_origin.at(number<0>{}), seqlen_k_curr, number<kK1>{}, number<kM0>{}))
|
||||
{
|
||||
constexpr auto s_spans = SaccBlockTileType::get_distributed_spans();
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
sacc_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
pcomp_tile_tmp.get_thread_buffer() = sacc_tile_tmp.get_thread_buffer();
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
set_slice_tile(pcomp_tile,
|
||||
pcomp_tile_tmp,
|
||||
sequence<0, i_k1 * kK1>{},
|
||||
sequence<kM0, (i_k1 + 1) * kK1>{});
|
||||
});
|
||||
|
||||
sacc_tile(i_j_idx) *=
|
||||
static_cast<GemmAccDataType>(mask.IsTokenPairInsideMask(row, col));
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, siLU
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
const auto bias_tile = load_tile(bias_dram_window);
|
||||
|
||||
tile_elementwise_inout(
|
||||
[&scale_s, &bias_element_func](auto& x, const auto& y) {
|
||||
x = x * scale_s - type_convert<CompDataType>(bias_element_func(y));
|
||||
},
|
||||
pcomp_tile,
|
||||
bias_tile);
|
||||
|
||||
move_tile_window(bias_dram_window, {0, kK1});
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile);
|
||||
}
|
||||
|
||||
if(!mask.IsFullTileInsideMask(
|
||||
q_origin.at(number<0>{}), seqlen_k_curr, number<kN0>{}, number<kM0>{}))
|
||||
{
|
||||
constexpr auto p_spans = PcompBlockTileType::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
pcomp_tile(i_j_idx) *=
|
||||
static_cast<CompDataType>(mask.IsTokenPairInsideMask(row, col));
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pcomp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
tile_elementwise_inout(f_silu, pcomp_tile);
|
||||
|
||||
tile_elementwise_inout(f_silu, pcomp_tile);
|
||||
seqlen_k_curr += kN0;
|
||||
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeShuffledVRegTileDistribution<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_tile);
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
auto randval_lds_ptr =
|
||||
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
|
||||
|
||||
// if K in this unroll uses Lds-buffer i, then V in this unroll uses Lds-buffer
|
||||
// i+2, No overlap occurs between V and K in the same unroll, and V in current
|
||||
// unroll and K in next unroll or first unroll in next iteration
|
||||
store_tile(
|
||||
v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
|
||||
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
|
||||
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
|
||||
}
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
auto randval_lds_ptr = reinterpret_cast<char*>(smem_ptr) +
|
||||
Policy::template GetSmemSizeKV<Problem>();
|
||||
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
|
||||
|
||||
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
|
||||
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
|
||||
}
|
||||
using v_shuffled_tile_type = decltype(make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeShuffledVRegTileDistribution<Problem>()));
|
||||
|
||||
auto p =
|
||||
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
|
||||
statically_indexed_array<v_shuffled_tile_type, k1_loops> v_shuffled_tiles;
|
||||
|
||||
static_for<0, k1_loops, 1>{}(
|
||||
[&](auto i_k1) { shuffle_tile(v_shuffled_tiles[i_k1], v_tiles[i_k1]); });
|
||||
|
||||
// check whether first V-LdsBuffer overlaps with next K-LdsBuffer,
|
||||
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
|
||||
if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
// STAGE 3, Gemm_1 ( O = P@V )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_shuffled_tiles[i_k1]));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// load k_tiles used by next iteration
|
||||
k_tiles[i_k1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
gemm_1(o_acc, p, v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]);
|
||||
|
||||
seqlen_k_curr += kK1;
|
||||
|
||||
if constexpr(i_k1 < k1_loops - 1)
|
||||
{
|
||||
// check whether current V-LdsBuffer overlaps with next K-LdsBuffer,
|
||||
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
|
||||
if constexpr((i_k1 + 2) % NumKVLdsBuffers == (i_k1 + 1) % NumKVLdsBuffers)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
store_tile(k_lds_write_windows[number<(i_k1 + 1) % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tile));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
}
|
||||
else
|
||||
{
|
||||
// check whether last V-LdsBuffer overlap with first K-LdsBuffer,
|
||||
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
|
||||
if constexpr((i_k1 + 2) % NumKVLdsBuffers == 0)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
store_tile(k_lds_write_windows[number<0>{}],
|
||||
tile_elementwise_in(k_element_func, k_tile));
|
||||
}
|
||||
gemm_1(
|
||||
o_acc,
|
||||
get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]);
|
||||
});
|
||||
|
||||
// check whether last V-LdsBuffer overlap with first K-LdsBuffer,
|
||||
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
|
||||
if constexpr((k1_loops - 1 + 2) % NumKVLdsBuffers == 0)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
} while(seqlen_k_curr < seqlen_k_end);
|
||||
|
||||
tile_elementwise_inout([&](auto& x) { x = x * type_convert<GemmAccDataType>(scale_p); },
|
||||
|
||||
@@ -68,12 +68,20 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
return WG::WarpGemmAttribute::kKPerThread;
|
||||
};
|
||||
|
||||
// Returns the A-operand block tile distribution of the KV block gemm
// (the type produced by GetKVBlockGemm<Problem>()), instantiated with the
// kM0 and kN0 extents taken from Problem::HstuAttentionTileSetting.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakePRegTileDistribution()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
|
||||
|
||||
return BlockGemm::template MakeABlockTileDistribution<
|
||||
Problem::HstuAttentionTileSetting::kM0,
|
||||
Problem::HstuAttentionTileSetting::kN0>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasDramTileDistribution()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
|
||||
return BlockGemm::MakeCBlockTile().get_tile_distribution();
|
||||
return MakePRegTileDistribution<Problem>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
Reference in New Issue
Block a user