mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 04:19:36 +00:00
Load Q through Lds
This commit is contained in:
@@ -74,6 +74,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
|
||||
static constexpr index_t kGemmSingleRepM = Policy::template GetQKBlockGemmSingleRepM<Problem>();
|
||||
static constexpr index_t kGemmNumRepM = kM0 / kGemmSingleRepM;
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Problem::kBlockPerCu != -1)
|
||||
return Problem::kBlockPerCu;
|
||||
@@ -220,10 +223,11 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
|
||||
OaccBlockTileType o_acc;
|
||||
|
||||
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kM0>{}, number<kSubQKHeaddim>{}),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQRegTileDistribution<Problem>());
|
||||
auto q_dram_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kGemmSingleRepM>{}, number<kSubQKHeaddim>{}),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQDramSingleRepMTileDistribution<Problem>());
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
@@ -235,6 +239,14 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
using q_dram_tile_type = decltype(load_tile(q_dram_window));
|
||||
statically_indexed_array<q_dram_tile_type, kGemmNumRepM> q_dram_tiles;
|
||||
|
||||
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
|
||||
q_dram_tiles[i_rep] = load_tile(q_dram_window);
|
||||
move_tile_window(q_dram_window, {kGemmSingleRepM, 0});
|
||||
});
|
||||
|
||||
using k_tile_type = decltype(load_tile(k_dram_window));
|
||||
|
||||
// only prefetch two k tiles to save vgprs consumption
|
||||
@@ -250,9 +262,22 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
auto q_tile = load_tile(q_dram_window);
|
||||
// provide partition_index for LDS tile window with so that warp_id is in vgpr
|
||||
array<index_t, 2> partition_index{get_warp_id<false>(), get_lane_id()};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
// Q tile in LDS
|
||||
QDataType* q_lds_ptr = static_cast<QDataType*>(smem_ptr);
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
auto q_lds_write_window = make_tile_window(
|
||||
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds,
|
||||
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeQRegSingleRepMTileDistribution<Problem>(),
|
||||
partition_index);
|
||||
|
||||
// K tile in LDS
|
||||
KDataType* k_lds_ptr = static_cast<KDataType*>(smem_ptr);
|
||||
@@ -354,19 +379,58 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
|
||||
}();
|
||||
|
||||
using q_reg_tile_type = decltype(make_static_distributed_tensor<QDataType>(
|
||||
Policy::template MakeQRegSingleRepMTileDistribution<Problem>()));
|
||||
statically_indexed_array<q_reg_tile_type, kGemmNumRepM> q_reg_tiles;
|
||||
|
||||
using q_tile_type = decltype(make_static_distributed_tensor<QDataType>(
|
||||
Policy::template MakeQRegTileDistribution<Problem>()));
|
||||
|
||||
q_tile_type q_tile;
|
||||
|
||||
{
|
||||
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
|
||||
store_tile(q_lds_write_window, q_dram_tiles[i_rep], partition_index);
|
||||
|
||||
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice written
|
||||
// by each wavefront is read by itself
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
|
||||
q_reg_tiles[i_rep] = load_tile(q_lds_read_window);
|
||||
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
|
||||
// the following codes will not generate actual instructions by the compiler
|
||||
set_slice_tile(q_tile,
|
||||
q_reg_tiles[i_rep],
|
||||
sequence<i_rep * kGemmSingleRepM, 0>{},
|
||||
sequence<(i_rep + 1) * kGemmSingleRepM, kQKHeaddim>{});
|
||||
|
||||
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice read
|
||||
// by each wavefront is over-written by itself
|
||||
});
|
||||
|
||||
clear_tile(o_acc);
|
||||
|
||||
set_tile(m, -numeric<CompDataType>::infinity());
|
||||
clear_tile(l);
|
||||
};
|
||||
|
||||
q_tile = tile_elementwise_in(q_element_func, q_tile);
|
||||
|
||||
auto seqlen_k_curr = seqlen_k_start;
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
using v_tile_type = decltype(load_tile(v_dram_window));
|
||||
|
||||
statically_indexed_array<v_tile_type, k1_loops> v_tiles;
|
||||
|
||||
// provide partition_index for LDS tile window with so that warp_id is in vgpr
|
||||
array<index_t, 2> partition_index{get_warp_id<false>(), get_lane_id()};
|
||||
|
||||
do
|
||||
{
|
||||
// STAGE 1, Gemm_0 ( S = Q@K )
|
||||
|
||||
@@ -40,6 +40,16 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
return 4;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegSingleRepMTileDistribution()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
constexpr index_t kBlockGemmM = GetQKBlockGemmSingleRepM<Problem>();
|
||||
|
||||
return BlockGemm::
|
||||
template MakeABlockTileDistribution<kBlockGemmM, Problem::BlockFmhaShape::kQKHeaddim>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
|
||||
{
|
||||
@@ -90,16 +100,28 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
return WG::WarpGemmAttribute::Impl::kCM1PerLane;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
|
||||
{
|
||||
if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
|
||||
return 8;
|
||||
else
|
||||
return 4;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
|
||||
{
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
|
||||
return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane);
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(QDataType);
|
||||
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
|
||||
|
||||
return min(MaxVectorSize, ElemPerThread);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -209,6 +231,118 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
GetVSingleSmemElementSpaceSize<Problem>());
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentQ<Problem>();
|
||||
|
||||
if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
|
||||
{
|
||||
static_assert(kKVector == kKPack);
|
||||
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
|
||||
constexpr index_t DataTypeSize = sizeof(QDataType);
|
||||
|
||||
// 128 contiguous bytes mapped to 32 banks with each bank 4 contiguous bytes
|
||||
constexpr auto MLdsLayer =
|
||||
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
|
||||
|
||||
constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMPerBlock / MLdsLayer>{},
|
||||
number<kKPerBlock / kKPack * MLdsLayer>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<kKPerBlock * MLdsLayer>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto q_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
q_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(number<kMPerBlock / MLdsLayer>{},
|
||||
number<kKPerBlock / kKPack * MLdsLayer>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto q_lds_block_desc_k0_mldslayer_m_k1 = transform_tensor_descriptor(
|
||||
q_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<kMPerBlock / MLdsLayer>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<MLdsLayer>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto q_lds_block_desc = transform_tensor_descriptor(
|
||||
q_lds_block_desc_k0_mldslayer_m_k1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<kMPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return q_lds_block_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(kKVector % kKPack == 0);
|
||||
|
||||
constexpr auto q_lds_block_desc_0 =
|
||||
make_naive_tensor_descriptor(make_tuple(number<kKPerBlock / kKVector>{},
|
||||
number<kKVector / kKPack>{},
|
||||
number<kMPerBlock>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<kMPerBlock * kKVector + kKPack>{},
|
||||
number<kMPerBlock * kKPack>{},
|
||||
number<kKPack>{},
|
||||
number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto q_lds_block_desc = transform_tensor_descriptor(
|
||||
q_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(number<kMPerBlock>{}),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKVector>{},
|
||||
number<kKVector / kKPack>{},
|
||||
number<kKPack>{}))),
|
||||
make_tuple(sequence<2>{}, sequence<0, 1, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return q_lds_block_desc;
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramSingleRepMTileDistribution()
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
|
||||
constexpr index_t kKVector = GetAlignmentQ<Problem>();
|
||||
|
||||
constexpr index_t KPerThread = kKVector;
|
||||
constexpr index_t KThreads = kKPerBlock / KPerThread;
|
||||
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
|
||||
|
||||
// for Q-Tile [64, 128], the encoding is [4W * 4T * 4E, 16T * 8E]
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NumWarps, MThreadPerWarp, MPerThread>,
|
||||
sequence<KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
|
||||
{
|
||||
@@ -317,19 +451,13 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
|
||||
{
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0Sub;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
|
||||
constexpr index_t kKVector = GetAlignmentK<Problem>();
|
||||
|
||||
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
|
||||
static_assert(0 < ElemPerThread);
|
||||
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
|
||||
|
||||
constexpr index_t KPerThread = kMaxVecLoad;
|
||||
constexpr index_t KPerThread = kKVector;
|
||||
constexpr index_t KThreads = kKPerBlock / KPerThread;
|
||||
constexpr index_t NThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
@@ -446,6 +574,13 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
sequence<1, 2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetQKBlockGemmSingleRepM()
|
||||
{
|
||||
return Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}) *
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
|
||||
{
|
||||
@@ -587,6 +722,13 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ()
|
||||
{
|
||||
return MakeQLdsBlockDescriptor<Problem>().get_element_space_size() *
|
||||
sizeof(typename Problem::QDataType);
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV()
|
||||
{
|
||||
@@ -605,7 +747,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>();
|
||||
return max(GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>(),
|
||||
GetSmemSizeQ<Problem>());
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user