mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Put two gemms call inside one n0loop unroll
This commit is contained in:
@@ -155,17 +155,10 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
static_assert(2 <= k1_loops);
|
||||
|
||||
constexpr auto NumKLdsBuffers = Policy::template GetNumKLdsBuffers<Problem>();
|
||||
constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers<Problem>();
|
||||
constexpr auto NumPrefetchV = Policy::template GetNumPrefetchV<Problem>();
|
||||
|
||||
static_assert(NumKLdsBuffers >= 2);
|
||||
static_assert(NumPrefetchV >= 2);
|
||||
constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers<Problem>();
|
||||
|
||||
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
@@ -199,9 +192,9 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
using k_lds_window_type =
|
||||
decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
|
||||
|
||||
statically_indexed_array<k_lds_window_type, NumKLdsBuffers> k_lds_windows;
|
||||
statically_indexed_array<k_lds_window_type, NumKVLdsBuffers> k_lds_windows;
|
||||
|
||||
static_for<0, NumKLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
k_lds_windows[i_buf] = get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kK1, 0>{},
|
||||
sequence<(i_buf + 1) * kK1, kQKHeaddim>{});
|
||||
@@ -219,16 +212,12 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
using v_tile_type = decltype(load_tile(v_dram_window));
|
||||
|
||||
statically_indexed_array<v_tile_type, NumPrefetchV> v_tiles;
|
||||
|
||||
using v_lds_window_type =
|
||||
decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kN1, kK1>{}));
|
||||
|
||||
statically_indexed_array<v_lds_window_type, NumVLdsBuffers> v_lds_windows;
|
||||
statically_indexed_array<v_lds_window_type, NumKVLdsBuffers> v_lds_windows;
|
||||
|
||||
static_for<0, NumVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
v_lds_windows[i_buf] = get_slice_tile(
|
||||
v_lds_window, sequence<i_buf * kN1, 0>{}, sequence<(i_buf + 1) * kN1, kK1>{});
|
||||
});
|
||||
@@ -307,6 +296,9 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
q_tile = tile_elementwise_in(q_element_func, q_tile);
|
||||
|
||||
auto v_tile = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
|
||||
auto seqlen_k_curr = seqlen_k_start;
|
||||
|
||||
index_t i_loop = 0;
|
||||
@@ -314,25 +306,18 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
do
|
||||
{
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(k_lds_windows[number<i_k1 % NumKLdsBuffers>{}],
|
||||
store_tile(k_lds_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tile));
|
||||
|
||||
if constexpr(i_k1 < k1_loops - 1)
|
||||
{
|
||||
k_tile = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
};
|
||||
// for i_k1 = k1_loop-1, the loading is for next iteration
|
||||
k_tile = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
|
||||
clear_tile(sacc_tiles[i_k1]);
|
||||
|
||||
block_sync_lds();
|
||||
// execute current unroll of gemm_0
|
||||
gemm_0(sacc_tiles[i_k1], q_tile, k_lds_windows[number<i_k1 % NumKLdsBuffers>{}]);
|
||||
gemm_0(sacc_tiles[i_k1], q_tile, k_lds_windows[number<i_k1 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tiles[i_k1] = tile_elementwise_in(s_acc_element_func, sacc_tiles[i_k1]);
|
||||
|
||||
@@ -426,95 +411,33 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
pcomp_tiles[i_k1] = cast_tile<CompDataType>(sacc_tiles[i_k1]);
|
||||
|
||||
seqlen_k_curr += kK1;
|
||||
});
|
||||
|
||||
static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) {
|
||||
v_tiles[i_buf] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
});
|
||||
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_tiles[I0]);
|
||||
|
||||
// ensure gemm_0 has finished access of k-Lds for all warps
|
||||
// the over-lap only occurs when k0_loops is 3/5/7, NumKLdsBuffers is 2
|
||||
if constexpr(Policy::template IsFirstVLdsBufferOverlapLastKLdsBuffer<Problem>())
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
store_tile(
|
||||
v_lds_windows[I0],
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
|
||||
}
|
||||
else
|
||||
{
|
||||
// ensure gemm_0 has finished access of k-Lds for all warps
|
||||
if constexpr(Policy::template IsFirstVLdsBufferOverlapLastKLdsBuffer<Problem>())
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
store_tile(v_lds_windows[I0],
|
||||
tile_elementwise_in(v_element_func, v_tiles[I0])); // store the prefetch
|
||||
};
|
||||
|
||||
tile_elementwise_inout(f_silu, pcomp_tiles[I0]);
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
auto randval_lds_ptr =
|
||||
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
|
||||
|
||||
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
|
||||
randval_lds_ptr, seqlen_k_curr - kN0, pcomp_tiles[I0], null_randval_window);
|
||||
}
|
||||
|
||||
auto p = [&]() {
|
||||
if constexpr(std::is_same_v<PDataType, fp16_t>)
|
||||
return impl::cast_tile_pk_fp16_fp32<PDataType>(
|
||||
tile_elementwise_in(p_compute_element_func, pcomp_tiles[I0]));
|
||||
else
|
||||
return cast_tile<PDataType>(
|
||||
tile_elementwise_in(p_compute_element_func, pcomp_tiles[I0]));
|
||||
}();
|
||||
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
if constexpr(i_k1 < k1_loops - NumPrefetchV)
|
||||
{
|
||||
v_tiles[number<i_k1 % NumPrefetchV>{}] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
}
|
||||
else if constexpr(i_k1 == k1_loops - NumPrefetchV)
|
||||
{
|
||||
// load one k_tile for next iteration
|
||||
k_tile = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
gemm_1(o_acc, p, v_lds_windows[number<i_k1 % NumVLdsBuffers>{}]);
|
||||
tile_elementwise_inout(f_silu, pcomp_tiles[number<i_k1 + 1>{}]);
|
||||
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]);
|
||||
shuffle_tile(v_shuffle_tmp, v_tile);
|
||||
|
||||
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func,
|
||||
v_shuffle_tmp)); // store the prefetch
|
||||
// if K in this unroll uses Lds-buffer i, then V in this uroll uses Lds-buffer
|
||||
// i+1, No overlap occurs between V and K in the same unroll, and V in current
|
||||
// unroll and K in next unroll or first unrool in next iteration
|
||||
store_tile(
|
||||
v_lds_windows[number<(i_k1 + 1) % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
|
||||
tile_elementwise_in(
|
||||
v_element_func,
|
||||
v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}])); // store the
|
||||
// prefetch
|
||||
}
|
||||
// if K in this unroll uses Lds-buffer i, then V in this uroll uses Lds-buffer
|
||||
// i+1, No overlap occurs between V and K in the same unroll, and V in current
|
||||
// unroll and K in next unroll or first unrool in next iteration
|
||||
store_tile(v_lds_windows[number<(i_k1 + 1) % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_tile)); // store the prefetch
|
||||
};
|
||||
|
||||
// for i_k1 = k1_loops-1, the loading is for next iteration
|
||||
v_tile = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
|
||||
tile_elementwise_inout(f_silu, pcomp_tiles[i_k1]);
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
@@ -522,26 +445,26 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
Policy::template GetSmemSizeKV<Problem>();
|
||||
|
||||
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
|
||||
randval_lds_ptr,
|
||||
seqlen_k_curr - kN0 + (i_k1 + 1) * kK1,
|
||||
pcomp_tiles[number<i_k1 + 1>{}],
|
||||
null_randval_window);
|
||||
randval_lds_ptr, seqlen_k_curr, pcomp_tiles[i_k1], null_randval_window);
|
||||
}
|
||||
|
||||
p = [&]() {
|
||||
auto p = [&]() {
|
||||
if constexpr(std::is_same_v<PDataType, fp16_t>)
|
||||
return impl::cast_tile_pk_fp16_fp32<PDataType>(tile_elementwise_in(
|
||||
p_compute_element_func, pcomp_tiles[number<i_k1 + 1>{}]));
|
||||
return impl::cast_tile_pk_fp16_fp32<PDataType>(
|
||||
tile_elementwise_in(p_compute_element_func, pcomp_tiles[i_k1]));
|
||||
else
|
||||
return cast_tile<PDataType>(tile_elementwise_in(
|
||||
p_compute_element_func, pcomp_tiles[number<i_k1 + 1>{}]));
|
||||
return cast_tile<PDataType>(
|
||||
tile_elementwise_in(p_compute_element_func, pcomp_tiles[i_k1]));
|
||||
}();
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
gemm_1(o_acc, p, v_lds_windows[number<(i_k1 + 1) % NumKVLdsBuffers>{}]);
|
||||
|
||||
seqlen_k_curr += kK1;
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc, p, v_lds_windows[number<(k1_loops - 1) % NumVLdsBuffers>{}]);
|
||||
|
||||
// the over-lap only occurs when k1_loops is 3/5/7, NumVLdsBuffers is 2
|
||||
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 3
|
||||
if constexpr(Policy::template IsFirstKLdsBufferOverlapLastVLdsBuffer<Problem>())
|
||||
__builtin_amdgcn_s_barrier();
|
||||
} while(++i_loop < num_loops);
|
||||
|
||||
@@ -12,35 +12,14 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
|
||||
/* AsyncCopy = */ false,
|
||||
/* NumPrefetchK = */ -1,
|
||||
/* NumPrefetchV = */ 2>
|
||||
/* NumPrefetchV = */ 1>
|
||||
{
|
||||
static constexpr index_t NumPrefetchV = 2;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto GetNumKLdsBuffers()
|
||||
CK_TILE_DEVICE static constexpr auto GetNumKVLdsBuffers()
|
||||
{
|
||||
return 2;
|
||||
return 3;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto GetNumPrefetchV()
|
||||
{
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
|
||||
constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
return min(NumPrefetchV, k1_loops);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetNumVLdsBuffers()
|
||||
{
|
||||
return 2;
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
|
||||
{
|
||||
@@ -120,7 +99,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t NumKLdsBuffers = GetNumKLdsBuffers<Problem>();
|
||||
constexpr index_t NumKLdsBuffers = GetNumKVLdsBuffers<Problem>();
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
@@ -234,7 +213,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
{
|
||||
using QKVDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
|
||||
constexpr index_t NumVLdsBuffers = GetNumVLdsBuffers<Problem>();
|
||||
constexpr index_t NumVLdsBuffers = GetNumKVLdsBuffers<Problem>();
|
||||
|
||||
constexpr index_t Banks = 32; // TODO: need change based on arch
|
||||
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QKVDataType);
|
||||
@@ -422,33 +401,21 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsFirstVLdsBufferOverlapLastKLdsBuffer()
|
||||
{
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
|
||||
constexpr index_t k0_loops = BlockFmhaShape::kQKHeaddim / BlockFmhaShape::kK0;
|
||||
constexpr index_t num_k_lds_buffers = GetNumKLdsBuffers<Problem>();
|
||||
|
||||
return (k0_loops - 1) % num_k_lds_buffers == 0;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsFirstKLdsBufferOverlapLastVLdsBuffer()
|
||||
{
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
|
||||
constexpr index_t k1_loops = BlockFmhaShape::kN0 / BlockFmhaShape::kK1;
|
||||
constexpr index_t num_v_lds_buffers = GetNumVLdsBuffers<Problem>();
|
||||
constexpr index_t k1_loops = BlockFmhaShape::kN0 / BlockFmhaShape::kK1;
|
||||
constexpr index_t num_kv_lds_buffers = GetNumKVLdsBuffers<Problem>();
|
||||
|
||||
return (k1_loops - 1) % num_v_lds_buffers == 0;
|
||||
return (k1_loops - 1 + 1) % num_kv_lds_buffers == 0;
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV()
|
||||
{
|
||||
constexpr index_t num_kv_lds_buffers =
|
||||
max(GetNumKLdsBuffers<Problem>(), GetNumVLdsBuffers<Problem>());
|
||||
constexpr index_t num_kv_lds_buffers = GetNumKVLdsBuffers<Problem>();
|
||||
|
||||
return num_kv_lds_buffers * GetSingleSmemElementSpaceSize<Problem>() *
|
||||
sizeof(typename Problem::QKVDataType);
|
||||
|
||||
Reference in New Issue
Block a user