mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Add kN0Sub to separate the n0_loop and k1_loop tile size for more flexible tuning
This commit is contained in:
@@ -145,7 +145,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
using KDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1;
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN0Sub;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
|
||||
@@ -199,7 +199,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1;
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN0Sub;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentK<Problem>();
|
||||
@@ -401,7 +401,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t NumKLdsBuffers = GetNumKVLdsBuffers<Problem>();
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1;
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN0Sub;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentK<Problem>();
|
||||
@@ -508,7 +508,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
using QKVDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1;
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN0Sub;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType);
|
||||
@@ -719,7 +719,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
typename Problem::GemmAccDataType,
|
||||
Problem::kNumGemm0Warps * get_warp_size(),
|
||||
TileGemmShape<sequence<Problem::HstuAttentionTileSetting::kM0,
|
||||
Problem::HstuAttentionTileSetting::kK1,
|
||||
Problem::HstuAttentionTileSetting::kN0Sub,
|
||||
Problem::HstuAttentionTileSetting::kQKHeaddim>,
|
||||
typename Problem::HstuAttentionTileSetting::Gemm0BlockWarps,
|
||||
typename Problem::HstuAttentionTileSetting::Gemm0WarpTile>>;
|
||||
|
||||
@@ -16,12 +16,12 @@ using HstuAttentionFwdWarpTile3 = ck_tile::sequence<32, 32, 16>;
|
||||
template <ck_tile::index_t MaxK>
|
||||
struct HstuAttentionNoSoftmaxFwdBlockTile;
|
||||
|
||||
// Tile-sizes: M N0 N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0)
|
||||
// Tile-sizes: M N0 N0Sub N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0, N0 % N0Sub == 0)
|
||||
//
|
||||
template <>
|
||||
struct HstuAttentionNoSoftmaxFwdBlockTile<32>
|
||||
{
|
||||
using type = ck_tile::sequence<64, 64, 32, 32, 32>;
|
||||
using type = ck_tile::sequence<64, 64, 32, 32, 32, 32>;
|
||||
using gemm0_warps = ck_tile::sequence<2, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<2, 1, 1>;
|
||||
};
|
||||
@@ -29,7 +29,7 @@ struct HstuAttentionNoSoftmaxFwdBlockTile<32>
|
||||
template <>
|
||||
struct HstuAttentionNoSoftmaxFwdBlockTile<64>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 64, 64, 32, 64>;
|
||||
using type = ck_tile::sequence<128, 64, 32, 64, 32, 64>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
@@ -37,7 +37,7 @@ struct HstuAttentionNoSoftmaxFwdBlockTile<64>
|
||||
template <>
|
||||
struct HstuAttentionNoSoftmaxFwdBlockTile<128>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 32, 128, 16, 128>;
|
||||
using type = ck_tile::sequence<128, 32, 16, 128, 16, 128>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
@@ -45,7 +45,7 @@ struct HstuAttentionNoSoftmaxFwdBlockTile<128>
|
||||
template <>
|
||||
struct HstuAttentionNoSoftmaxFwdBlockTile<256>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 32, 256, 16, 256>;
|
||||
using type = ck_tile::sequence<128, 32, 16, 256, 16, 256>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
@@ -53,12 +53,12 @@ struct HstuAttentionNoSoftmaxFwdBlockTile<256>
|
||||
template <ck_tile::index_t MaxK>
|
||||
struct HstuAttentionWithSoftmaxFwdBlockTile;
|
||||
|
||||
// Tile-sizes: M N0 N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0)
|
||||
// Tile-sizes: M N0 N0Sub N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0, N0 % N0Sub == 0)
|
||||
//
|
||||
template <>
|
||||
struct HstuAttentionWithSoftmaxFwdBlockTile<32>
|
||||
{
|
||||
using type = ck_tile::sequence<64, 64, 32, 32, 32>;
|
||||
using type = ck_tile::sequence<64, 64, 32, 32, 32, 32>;
|
||||
using gemm0_warps = ck_tile::sequence<2, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<2, 1, 1>;
|
||||
};
|
||||
@@ -66,7 +66,7 @@ struct HstuAttentionWithSoftmaxFwdBlockTile<32>
|
||||
template <>
|
||||
struct HstuAttentionWithSoftmaxFwdBlockTile<64>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 64, 64, 32, 64>;
|
||||
using type = ck_tile::sequence<128, 64, 32, 64, 32, 64>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
@@ -74,7 +74,7 @@ struct HstuAttentionWithSoftmaxFwdBlockTile<64>
|
||||
template <>
|
||||
struct HstuAttentionWithSoftmaxFwdBlockTile<128>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 64, 128, 16, 128>;
|
||||
using type = ck_tile::sequence<128, 64, 16, 128, 16, 128>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
@@ -82,7 +82,7 @@ struct HstuAttentionWithSoftmaxFwdBlockTile<128>
|
||||
template <>
|
||||
struct HstuAttentionWithSoftmaxFwdBlockTile<256>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 32, 256, 16, 256>;
|
||||
using type = ck_tile::sequence<128, 32, 16, 256, 16, 256>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
@@ -186,12 +186,12 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<256>
|
||||
template <ck_tile::index_t MaxK>
|
||||
struct HstuAttentionNoSoftmaxFwdBlockTile;
|
||||
|
||||
// Tile-sizes: M N0 N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0)
|
||||
// Tile-sizes: M N0 N0Sub N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0, N0 % N0Sub == 0)
|
||||
//
|
||||
template <>
|
||||
struct HstuAttentionNoSoftmaxFwdBlockTile<32>
|
||||
{
|
||||
using type = ck_tile::sequence<64, 64, 32, 32, 32>;
|
||||
using type = ck_tile::sequence<64, 64, 32, 32, 32, 32>;
|
||||
using gemm0_warps = ck_tile::sequence<2, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<2, 1, 1>;
|
||||
};
|
||||
@@ -199,7 +199,7 @@ struct HstuAttentionNoSoftmaxFwdBlockTile<32>
|
||||
template <>
|
||||
struct HstuAttentionNoSoftmaxFwdBlockTile<64>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 64, 64, 32, 64>;
|
||||
using type = ck_tile::sequence<128, 64, 32, 64, 32, 64>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
@@ -207,7 +207,7 @@ struct HstuAttentionNoSoftmaxFwdBlockTile<64>
|
||||
template <>
|
||||
struct HstuAttentionNoSoftmaxFwdBlockTile<128>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 32, 128, 32, 128>;
|
||||
using type = ck_tile::sequence<128, 32, 32, 128, 32, 128>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
@@ -215,7 +215,7 @@ struct HstuAttentionNoSoftmaxFwdBlockTile<128>
|
||||
template <>
|
||||
struct HstuAttentionNoSoftmaxFwdBlockTile<256>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 32, 256, 32, 256>;
|
||||
using type = ck_tile::sequence<128, 32, 32, 256, 32, 256>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
@@ -223,12 +223,12 @@ struct HstuAttentionNoSoftmaxFwdBlockTile<256>
|
||||
template <ck_tile::index_t MaxK>
|
||||
struct HstuAttentionWithSoftmaxFwdBlockTile;
|
||||
|
||||
// Tile-sizes: M N0 N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0)
|
||||
// Tile-sizes: M N0 N0Sub N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0, N0 % N0Sub == 0)
|
||||
//
|
||||
template <>
|
||||
struct HstuAttentionWithSoftmaxFwdBlockTile<32>
|
||||
{
|
||||
using type = ck_tile::sequence<64, 64, 32, 32, 32>;
|
||||
using type = ck_tile::sequence<64, 64, 32, 32, 32, 32>;
|
||||
using gemm0_warps = ck_tile::sequence<2, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<2, 1, 1>;
|
||||
};
|
||||
@@ -236,7 +236,7 @@ struct HstuAttentionWithSoftmaxFwdBlockTile<32>
|
||||
template <>
|
||||
struct HstuAttentionWithSoftmaxFwdBlockTile<64>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 64, 64, 32, 64>;
|
||||
using type = ck_tile::sequence<128, 64, 32, 64, 32, 64>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
@@ -244,7 +244,7 @@ struct HstuAttentionWithSoftmaxFwdBlockTile<64>
|
||||
template <>
|
||||
struct HstuAttentionWithSoftmaxFwdBlockTile<128>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 64, 128, 32, 128>;
|
||||
using type = ck_tile::sequence<128, 64, 32, 128, 32, 128>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
@@ -252,7 +252,7 @@ struct HstuAttentionWithSoftmaxFwdBlockTile<128>
|
||||
template <>
|
||||
struct HstuAttentionWithSoftmaxFwdBlockTile<256>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 64, 256, 32, 256>;
|
||||
using type = ck_tile::sequence<128, 64, 32, 256, 32, 256>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
|
||||
@@ -31,6 +31,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
static constexpr index_t kM0 = HstuAttentionTileSetting::kM0;
|
||||
static constexpr index_t kN0 = HstuAttentionTileSetting::kN0;
|
||||
static constexpr index_t kN0Sub = HstuAttentionTileSetting::kN0Sub;
|
||||
static constexpr index_t kN1 = HstuAttentionTileSetting::kN1;
|
||||
static constexpr index_t kK1 = HstuAttentionTileSetting::kK1;
|
||||
static constexpr index_t kQKHeaddim = HstuAttentionTileSetting::kQKHeaddim;
|
||||
@@ -158,17 +159,20 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t n0_loops = kN0 / kN0Sub;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
static_assert(n0_loops == k1_loops, "n0_loops == k1_loops required by this pipeline");
|
||||
|
||||
constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers<Problem>();
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
|
||||
// SaccBlockTile size is [kM0, kK1]
|
||||
// SaccBlockTile size is [kM0, kN0Sub]
|
||||
// PcompBlockTile size is [kM0, kN0]
|
||||
using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kK1>());
|
||||
using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0Sub>());
|
||||
using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0>());
|
||||
using PcompBlockTileType = decltype(cast_tile<CompDataType>(CombineSaccBlockTileType{}));
|
||||
|
||||
@@ -190,7 +194,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kK1>{}, number<kSubQKHeaddim>{}),
|
||||
make_tuple(number<kN0Sub>{}, number<kSubQKHeaddim>{}),
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
@@ -204,11 +208,11 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
using k_tile_type = decltype(load_tile(k_dram_window));
|
||||
|
||||
statically_indexed_array<k_tile_type, k1_loops> k_tiles;
|
||||
statically_indexed_array<k_tile_type, n0_loops> k_tiles;
|
||||
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
k_tiles[i_k1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
k_tiles[i_n0] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -238,11 +242,11 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
using k_lds_write_window_type = decltype(get_slice_tile(
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kK1, kSubQKHeaddim>{}));
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kSubQKHeaddim>{}));
|
||||
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
using k_lds_read_window_type =
|
||||
decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
|
||||
using k_lds_read_window_type = decltype(get_slice_tile(
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kQKHeaddim>{}));
|
||||
|
||||
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
|
||||
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
|
||||
@@ -250,11 +254,12 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
k_lds_write_windows[i_buf] =
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kK1, 0>{},
|
||||
sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{});
|
||||
k_lds_read_windows[i_buf] = get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kK1, 0>{},
|
||||
sequence<(i_buf + 1) * kK1, kQKHeaddim>{});
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kSubQKHeaddim>{});
|
||||
k_lds_read_windows[i_buf] =
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
|
||||
});
|
||||
|
||||
// V tile in LDS
|
||||
@@ -376,13 +381,13 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
do
|
||||
{
|
||||
// STAGE 1, Gemm_0 ( S = Q@K )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(k_lds_write_windows[i_k1], k_tiles[i_k1], partition_index);
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_write_windows[i_n0], k_tiles[i_n0], partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// load v_tiles used in current iteration
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
v_tiles[i_n0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
@@ -390,7 +395,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
block_sync_lds();
|
||||
|
||||
// execute current unroll of gemm_0
|
||||
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
|
||||
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
|
||||
@@ -398,8 +403,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
set_slice_tile(pcomp_tile,
|
||||
tmp_tile,
|
||||
sequence<0, i_k1 * kK1>{},
|
||||
sequence<kM0, (i_k1 + 1) * kK1>{});
|
||||
sequence<0, i_n0 * kN0Sub>{},
|
||||
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
@@ -487,7 +492,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
// load k_tiles used by next iteration
|
||||
k_tiles[i_k1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
static constexpr index_t kM0 = HstuAttentionTileSetting::kM0;
|
||||
static constexpr index_t kN0 = HstuAttentionTileSetting::kN0;
|
||||
static constexpr index_t kN0Sub = HstuAttentionTileSetting::kN0Sub;
|
||||
static constexpr index_t kN1 = HstuAttentionTileSetting::kN1;
|
||||
static constexpr index_t kK1 = HstuAttentionTileSetting::kK1;
|
||||
static constexpr index_t kQKHeaddim = HstuAttentionTileSetting::kQKHeaddim;
|
||||
@@ -156,17 +157,20 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t n0_loops = kN0 / kN0Sub;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
static_assert(n0_loops == k1_loops, "n0_loops == k1_loops required by this pipeline");
|
||||
|
||||
constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers<Problem>();
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
|
||||
// SaccBlockTile size is [kM0, kK1]
|
||||
// SaccBlockTile size is [kM0, kN0Sub]
|
||||
// PcompBlockTile size is [kM0, kN0]
|
||||
using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kK1>());
|
||||
using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0Sub>());
|
||||
using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0>());
|
||||
using PcompBlockTileType = decltype(cast_tile<CompDataType>(CombineSaccBlockTileType{}));
|
||||
|
||||
@@ -188,7 +192,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kK1>{}, number<kSubQKHeaddim>{}),
|
||||
make_tuple(number<kN0Sub>{}, number<kSubQKHeaddim>{}),
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
@@ -196,11 +200,11 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
using k_tile_type = decltype(load_tile(k_dram_window));
|
||||
|
||||
statically_indexed_array<k_tile_type, k1_loops> k_tiles;
|
||||
statically_indexed_array<k_tile_type, n0_loops> k_tiles;
|
||||
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
k_tiles[i_k1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
k_tiles[i_n0] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -230,11 +234,11 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
using k_lds_write_window_type = decltype(get_slice_tile(
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kK1, kSubQKHeaddim>{}));
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kSubQKHeaddim>{}));
|
||||
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
using k_lds_read_window_type =
|
||||
decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
|
||||
using k_lds_read_window_type = decltype(get_slice_tile(
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kQKHeaddim>{}));
|
||||
|
||||
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
|
||||
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
|
||||
@@ -242,11 +246,12 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
k_lds_write_windows[i_buf] =
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kK1, 0>{},
|
||||
sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{});
|
||||
k_lds_read_windows[i_buf] = get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kK1, 0>{},
|
||||
sequence<(i_buf + 1) * kK1, kQKHeaddim>{});
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kSubQKHeaddim>{});
|
||||
k_lds_read_windows[i_buf] =
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
|
||||
});
|
||||
|
||||
// V tile in LDS
|
||||
@@ -344,15 +349,15 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
do
|
||||
{
|
||||
// STAGE 1, Gemm_0 ( S = Q@K )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
k_tiles[i_k1],
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_write_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
k_tiles[i_n0],
|
||||
partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// load v_tiles used in current iteration
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
v_tiles[i_n0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
@@ -360,7 +365,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
block_sync_lds();
|
||||
|
||||
// execute current unroll of gemm_0
|
||||
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
|
||||
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
|
||||
@@ -368,8 +373,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
set_slice_tile(pcomp_tile,
|
||||
tmp_tile,
|
||||
sequence<0, i_k1 * kK1>{},
|
||||
sequence<kM0, (i_k1 + 1) * kK1>{});
|
||||
sequence<0, i_n0 * kN0Sub>{},
|
||||
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
@@ -449,7 +454,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
// load k_tiles used by next iteration
|
||||
k_tiles[i_k1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ struct HstuAttentionFwdTileSettingClass
|
||||
using Gemm1BlockWarps = remove_cvref_t<Gemm1BlockWarps_>;
|
||||
using Gemm1WarpTile = remove_cvref_t<Gemm1WarpTile_>;
|
||||
|
||||
static_assert(BlockTile::size() == 5, "Check failed!");
|
||||
static_assert(BlockTile::size() == 6, "Check failed!");
|
||||
static_assert(Gemm0BlockWarps::size() == 3, "Check failed!");
|
||||
static_assert(Gemm0WarpTile::size() == 3, "Check failed!");
|
||||
static_assert(Gemm1BlockWarps::size() == 3, "Check failed!");
|
||||
@@ -50,12 +50,13 @@ struct HstuAttentionFwdTileSettingClass
|
||||
|
||||
static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
|
||||
|
||||
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
|
||||
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
|
||||
static constexpr index_t kN1 = BlockTile::at(number<2>{}); // tile size along v head_dim
|
||||
static constexpr index_t kK1 = BlockTile::at(number<3>{}); // tile size along kv gemm unroll
|
||||
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
|
||||
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
|
||||
static constexpr index_t kN0Sub = BlockTile::at(number<2>{}); // tile size along k seqlen
|
||||
static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim
|
||||
static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll
|
||||
static constexpr index_t kQKHeaddim =
|
||||
BlockTile::at(number<4>{}); // total length of K0, used for pipeline that need load Q at
|
||||
BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at
|
||||
// once (or repeatedly load Q as a whole tile)
|
||||
|
||||
static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(kQKHeaddim);
|
||||
|
||||
@@ -31,6 +31,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
static constexpr index_t kM0 = HstuAttentionTileSetting::kM0;
|
||||
static constexpr index_t kN0 = HstuAttentionTileSetting::kN0;
|
||||
static constexpr index_t kN0Sub = HstuAttentionTileSetting::kN0Sub;
|
||||
static constexpr index_t kN1 = HstuAttentionTileSetting::kN1;
|
||||
static constexpr index_t kK1 = HstuAttentionTileSetting::kK1;
|
||||
static constexpr index_t kQKHeaddim = HstuAttentionTileSetting::kQKHeaddim;
|
||||
@@ -160,8 +161,10 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t n0_loops = kN0 / kN0Sub;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
static_assert(n0_loops >= k1_loops, "n0_loops >= k1_loops required by this pipeline");
|
||||
static_assert(k1_loops >= 2,
|
||||
"k1_loops >= 2 required due to pre-storing two v_tiles to Lds");
|
||||
|
||||
@@ -171,9 +174,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
|
||||
// SaccBlockTile size is [kM0, kK1]
|
||||
// SaccBlockTile size is [kM0, kN0Sub]
|
||||
// PcompBlockTile size is [kM0, kN0]
|
||||
using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kK1>());
|
||||
using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0Sub>());
|
||||
using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0>());
|
||||
using PcompBlockTileType = decltype(cast_tile<CompDataType>(CombineSaccBlockTileType{}));
|
||||
|
||||
@@ -205,7 +208,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kK1>{}, number<kSubQKHeaddim>{}),
|
||||
make_tuple(number<kN0Sub>{}, number<kSubQKHeaddim>{}),
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
@@ -228,7 +231,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
static_for<0, NumPrefetchK, 1>{}([&](auto i_k1) {
|
||||
k_tiles[i_k1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -258,11 +261,11 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
using k_lds_write_window_type = decltype(get_slice_tile(
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kK1, kSubQKHeaddim>{}));
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kSubQKHeaddim>{}));
|
||||
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
using k_lds_read_window_type =
|
||||
decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
|
||||
using k_lds_read_window_type = decltype(get_slice_tile(
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kQKHeaddim>{}));
|
||||
|
||||
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
|
||||
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
|
||||
@@ -270,11 +273,12 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
k_lds_write_windows[i_buf] =
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kK1, 0>{},
|
||||
sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{});
|
||||
k_lds_read_windows[i_buf] = get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kK1, 0>{},
|
||||
sequence<(i_buf + 1) * kK1, kQKHeaddim>{});
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kSubQKHeaddim>{});
|
||||
k_lds_read_windows[i_buf] =
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
|
||||
});
|
||||
|
||||
// V tile in LDS
|
||||
@@ -396,23 +400,27 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
do
|
||||
{
|
||||
// STAGE 1, Gemm_0 ( S = Q@K )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
k_tiles[number<i_k1 % NumPrefetchK>{}],
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_write_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
k_tiles[number<i_n0 % NumPrefetchK>{}],
|
||||
partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
if constexpr(i_k1 < k1_loops - NumPrefetchK)
|
||||
if constexpr(i_n0 < n0_loops - NumPrefetchK)
|
||||
{
|
||||
k_tiles[number<i_k1 % NumPrefetchK>{}] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
k_tiles[number<i_n0 % NumPrefetchK>{}] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
// load v_tiles used in current iteration
|
||||
v_tiles[number<i_k1 - (k1_loops - NumPrefetchK)>{}] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
if constexpr(i_n0 - (n0_loops - NumPrefetchK) < k1_loops)
|
||||
{
|
||||
// load v_tiles used in current iteration
|
||||
v_tiles[number<i_n0 - (n0_loops - NumPrefetchK)>{}] =
|
||||
load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
}
|
||||
};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
@@ -420,7 +428,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
block_sync_lds();
|
||||
|
||||
// execute current unroll of gemm_0
|
||||
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
|
||||
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
|
||||
@@ -428,8 +436,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
set_slice_tile(pcomp_tile,
|
||||
tmp_tile,
|
||||
sequence<0, i_k1 * kK1>{},
|
||||
sequence<kM0, (i_k1 + 1) * kK1>{});
|
||||
sequence<0, i_n0 * kN0Sub>{},
|
||||
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
|
||||
});
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, siLU
|
||||
@@ -511,7 +519,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
static_for<NumPrefetchK, k1_loops, 1>{}([&](auto i_k1) {
|
||||
static_for<min(NumPrefetchK, k1_loops), k1_loops, 1>{}([&](auto i_k1) {
|
||||
// load v_tiles used in current iteration
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
@@ -600,7 +608,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
{
|
||||
// load k_tiles used by next iteration
|
||||
k_tiles[i_k1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -31,6 +31,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
static constexpr index_t kM0 = HstuAttentionTileSetting::kM0;
|
||||
static constexpr index_t kN0 = HstuAttentionTileSetting::kN0;
|
||||
static constexpr index_t kN0Sub = HstuAttentionTileSetting::kN0Sub;
|
||||
static constexpr index_t kN1 = HstuAttentionTileSetting::kN1;
|
||||
static constexpr index_t kK1 = HstuAttentionTileSetting::kK1;
|
||||
static constexpr index_t kQKHeaddim = HstuAttentionTileSetting::kQKHeaddim;
|
||||
@@ -158,8 +159,10 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t n0_loops = kN0 / kN0Sub;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
static_assert(n0_loops >= k1_loops, "n0_loops >= k1_loops required by this pipeline");
|
||||
static_assert(k1_loops >= 2,
|
||||
"k1_loops >= 2 required due to pre-storing two v_tiles to Lds");
|
||||
|
||||
@@ -169,9 +172,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
|
||||
// SaccBlockTile size is [kM0, kK1]
|
||||
// SaccBlockTile size is [kM0, kN0Sub]
|
||||
// PcompBlockTile size is [kM0, kN0]
|
||||
using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kK1>());
|
||||
using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0Sub>());
|
||||
using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0>());
|
||||
using PcompBlockTileType = decltype(cast_tile<CompDataType>(CombineSaccBlockTileType{}));
|
||||
|
||||
@@ -203,7 +206,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kK1>{}, number<kSubQKHeaddim>{}),
|
||||
make_tuple(number<kN0Sub>{}, number<kSubQKHeaddim>{}),
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
@@ -223,7 +226,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
static_for<0, NumPrefetchK, 1>{}([&](auto i_k1) {
|
||||
k_tiles[i_k1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -252,11 +255,11 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
using k_lds_write_window_type = decltype(get_slice_tile(
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kK1, kSubQKHeaddim>{}));
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kSubQKHeaddim>{}));
|
||||
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
using k_lds_read_window_type =
|
||||
decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
|
||||
using k_lds_read_window_type = decltype(get_slice_tile(
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kQKHeaddim>{}));
|
||||
|
||||
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
|
||||
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
|
||||
@@ -264,11 +267,12 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
k_lds_write_windows[i_buf] =
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kK1, 0>{},
|
||||
sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{});
|
||||
k_lds_read_windows[i_buf] = get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kK1, 0>{},
|
||||
sequence<(i_buf + 1) * kK1, kQKHeaddim>{});
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kSubQKHeaddim>{});
|
||||
k_lds_read_windows[i_buf] =
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
|
||||
});
|
||||
|
||||
// V tile in LDS
|
||||
@@ -366,23 +370,27 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
do
|
||||
{
|
||||
// STAGE 1, Gemm_0 ( S = Q@K )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
k_tiles[number<i_k1 % NumPrefetchK>{}],
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_write_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
k_tiles[number<i_n0 % NumPrefetchK>{}],
|
||||
partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
if constexpr(i_k1 < k1_loops - NumPrefetchK)
|
||||
if constexpr(i_n0 < n0_loops - NumPrefetchK)
|
||||
{
|
||||
k_tiles[number<i_k1 % NumPrefetchK>{}] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
k_tiles[number<i_n0 % NumPrefetchK>{}] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
// load v_tiles used in current iteration
|
||||
v_tiles[number<i_k1 - (k1_loops - NumPrefetchK)>{}] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
if constexpr(i_n0 - (n0_loops - NumPrefetchK) < k1_loops)
|
||||
{
|
||||
// load v_tiles used in current iteration
|
||||
v_tiles[number<i_n0 - (n0_loops - NumPrefetchK)>{}] =
|
||||
load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
}
|
||||
};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
@@ -390,7 +398,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
block_sync_lds();
|
||||
|
||||
// execute current unroll of gemm_0
|
||||
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
|
||||
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
|
||||
@@ -398,8 +406,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
set_slice_tile(pcomp_tile,
|
||||
tmp_tile,
|
||||
sequence<0, i_k1 * kK1>{},
|
||||
sequence<kM0, (i_k1 + 1) * kK1>{});
|
||||
sequence<0, i_n0 * kN0Sub>{},
|
||||
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
@@ -477,7 +485,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
static_for<NumPrefetchK, k1_loops, 1>{}([&](auto i_k1) {
|
||||
static_for<min(NumPrefetchK, k1_loops), k1_loops, 1>{}([&](auto i_k1) {
|
||||
// load v_tiles used in current iteration
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
@@ -572,7 +580,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
{
|
||||
// load k_tiles used by next iteration
|
||||
k_tiles[i_k1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
Reference in New Issue
Block a user