mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
Load Q directly from global memory to registers for BlockGemm
This commit is contained in:
@@ -96,9 +96,6 @@ struct BlockFmhaPipelineProblem
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr bool kUseTrLoad = kUseTrLoad_;
|
||||
|
||||
// ToDo: should we define kUseTrLoad and kLoadWholeQTileOnceThrough Lds here ?
|
||||
static constexpr bool kLoadWholeQTileOnceThroughLds = kUseTrLoad ? true : false;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
|
||||
@@ -75,9 +75,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
|
||||
static constexpr index_t kGemmSingleRepM = Policy::template GetQKBlockGemmSingleRepM<Problem>();
|
||||
static constexpr index_t kGemmNumRepM = kM0 / kGemmSingleRepM;
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Problem::kBlockPerCu != -1)
|
||||
return Problem::kBlockPerCu;
|
||||
@@ -223,11 +220,10 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
|
||||
OaccBlockTileType o_acc;
|
||||
|
||||
auto q_dram_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQDramSingleRepMTileDistribution<Problem>());
|
||||
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQRegTileDistribution<Problem>());
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
@@ -239,14 +235,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
using q_dram_tile_type = decltype(load_tile(q_dram_window));
|
||||
statically_indexed_array<q_dram_tile_type, kGemmNumRepM> q_dram_tiles;
|
||||
|
||||
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
|
||||
q_dram_tiles[i_rep] = load_tile(q_dram_window);
|
||||
move_tile_window(q_dram_window, {kGemmSingleRepM, 0});
|
||||
});
|
||||
|
||||
using k_tile_type = decltype(load_tile(k_dram_window));
|
||||
|
||||
// only prefetch two k tiles to save vgprs consumption
|
||||
@@ -260,24 +248,13 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
k_tiles[I0] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
|
||||
auto q_tile = load_tile(q_dram_window);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// provide partition_index for LDS tile window with so that warp_id is in vgpr
|
||||
array<index_t, 2> partition_index{get_warp_id<false>(), get_lane_id()};
|
||||
|
||||
// Q tile in LDS
|
||||
QDataType* q_lds_ptr = static_cast<QDataType*>(smem_ptr);
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
auto q_lds_write_window = make_tile_window(
|
||||
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds,
|
||||
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeQRegSingleRepMTileDistribution<Problem>(),
|
||||
partition_index);
|
||||
|
||||
// K tile in LDS
|
||||
KDataType* k_lds_ptr = static_cast<KDataType*>(smem_ptr);
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
@@ -368,47 +345,10 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
|
||||
}();
|
||||
|
||||
using q_reg_tile_type = decltype(make_static_distributed_tensor<QDataType>(
|
||||
Policy::template MakeQRegSingleRepMTileDistribution<Problem>()));
|
||||
statically_indexed_array<q_reg_tile_type, kGemmNumRepM> q_reg_tiles;
|
||||
|
||||
using q_tile_type = decltype(make_static_distributed_tensor<QDataType>(
|
||||
Policy::template MakeQRegTileDistribution<Problem>()));
|
||||
|
||||
q_tile_type q_tile;
|
||||
|
||||
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
|
||||
store_tile(q_lds_write_window, q_dram_tiles[i_rep], partition_index);
|
||||
|
||||
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice written
|
||||
// by each wavefront is read by itself
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
|
||||
q_reg_tiles[i_rep] = load_tile(q_lds_read_window);
|
||||
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
|
||||
// the following codes will not generate actual instructions by the compiler
|
||||
set_slice_tile(q_tile,
|
||||
q_reg_tiles[i_rep],
|
||||
sequence<i_rep * kGemmSingleRepM, 0>{},
|
||||
sequence<(i_rep + 1) * kGemmSingleRepM, kQKHeaddim>{});
|
||||
|
||||
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice read
|
||||
// by each wavefront is over-written by itself
|
||||
});
|
||||
|
||||
q_tile = tile_elementwise_in(q_element_func, q_tile);
|
||||
|
||||
auto seqlen_k_curr = seqlen_k_start;
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
using v_tile_type = decltype(load_tile(v_dram_window));
|
||||
|
||||
statically_indexed_array<v_tile_type, k1_loops> v_tiles;
|
||||
|
||||
@@ -43,16 +43,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
return 4;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegSingleRepMTileDistribution()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
constexpr index_t kBlockGemmM = GetQKBlockGemmSingleRepM<Problem>();
|
||||
|
||||
return BlockGemm::
|
||||
template MakeABlockTileDistribution<kBlockGemmM, Problem::BlockFmhaShape::kQKHeaddim>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
|
||||
{
|
||||
@@ -103,33 +93,16 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
return WG::WarpGemmAttribute::Impl::kCM1PerLane;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
|
||||
{
|
||||
if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
|
||||
return 8;
|
||||
else
|
||||
return 4;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
|
||||
{
|
||||
if constexpr(Problem::kLoadWholeQTileOnceThroughLds)
|
||||
{
|
||||
return Problem::GetQDramTileAccessMaxVectorSize();
|
||||
}
|
||||
else
|
||||
{
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
return detail::
|
||||
GetDramTileAccessMaxVectorSize<QDataType, kBlockSize, kMPerBlock, kKPerBlock>();
|
||||
};
|
||||
return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -257,217 +230,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
GetVSingleSmemElementSpaceSize<Problem>());
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::kLoadWholeQTileOnceThroughLds
|
||||
? Problem::BlockFmhaShape::kM0
|
||||
: GetQKBlockGemmSingleRepM<Problem>();
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentQ<Problem>();
|
||||
|
||||
// for hdim96 and hdim160, use simplest layout
|
||||
if constexpr(kKPerBlock < Problem::BlockFmhaShape::kSubQKHeaddim)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
make_tuple(number<kKPerBlock>{}, number<1>{}),
|
||||
number<kKVector>{},
|
||||
number<1>{});
|
||||
}
|
||||
else if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
|
||||
{
|
||||
static_assert(kKVector == kKPack);
|
||||
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
|
||||
constexpr index_t DataTypeSize = sizeof(QDataType);
|
||||
|
||||
// 128 contiguous bytes mapped to 32 banks with each bank 4 contiguous bytes
|
||||
constexpr auto MLdsLayer =
|
||||
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
|
||||
|
||||
constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMPerBlock / MLdsLayer>{},
|
||||
number<kKPerBlock / kKPack * MLdsLayer>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<kKPerBlock * MLdsLayer>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto q_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
q_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(number<kMPerBlock / MLdsLayer>{},
|
||||
number<kKPerBlock / kKPack * MLdsLayer>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto q_lds_block_desc_k0_mldslayer_m_k1 = transform_tensor_descriptor(
|
||||
q_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<kMPerBlock / MLdsLayer>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<MLdsLayer>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto q_lds_block_desc = transform_tensor_descriptor(
|
||||
q_lds_block_desc_k0_mldslayer_m_k1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<kMPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return q_lds_block_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(kKVector % kKPack == 0);
|
||||
|
||||
constexpr auto q_lds_block_desc_0 =
|
||||
make_naive_tensor_descriptor(make_tuple(number<kKPerBlock / kKVector>{},
|
||||
number<kKVector / kKPack>{},
|
||||
number<kMPerBlock>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<kMPerBlock * kKVector + kKPack>{},
|
||||
number<kMPerBlock * kKPack>{},
|
||||
number<kKPack>{},
|
||||
number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto q_lds_block_desc = transform_tensor_descriptor(
|
||||
q_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(number<kMPerBlock>{}),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKVector>{},
|
||||
number<kKVector / kKPack>{},
|
||||
number<kKPack>{}))),
|
||||
make_tuple(sequence<2>{}, sequence<0, 1, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return q_lds_block_desc;
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramSingleRepMTileDistribution()
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t kKVector = GetAlignmentQ<Problem>();
|
||||
constexpr index_t OtherK = kKPerBlock / kKVector;
|
||||
|
||||
if constexpr(kKPerBlock == Problem::BlockFmhaShape::kSubQKHeaddim)
|
||||
// for kKPerBlock=32,64,128,256
|
||||
{
|
||||
static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!");
|
||||
|
||||
constexpr index_t KPerThread = kKVector;
|
||||
constexpr index_t KThreads = OtherK;
|
||||
|
||||
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
|
||||
|
||||
// for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E]
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
|
||||
sequence<KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
}
|
||||
else // for kKPerBlock=96,160
|
||||
{
|
||||
static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!");
|
||||
|
||||
// ToDo: need more considieration for hdim72
|
||||
constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 3 : 5;
|
||||
constexpr index_t KThreads = OtherK / KRepPerThread;
|
||||
|
||||
static_assert((KThreads & (KThreads - 1)) == 0, "Check failed!");
|
||||
|
||||
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
|
||||
sequence<KRepPerThread, KThreads, kKVector>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 1>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<1, 0, 2>>{});
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t kKVector = GetAlignmentQ<Problem>();
|
||||
constexpr index_t OtherK = kKPerBlock / kKVector;
|
||||
|
||||
if constexpr(kKPerBlock == Problem::BlockFmhaShape::kSubQKHeaddim)
|
||||
// for kKPerBlock=32,64,128,256
|
||||
{
|
||||
static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!");
|
||||
|
||||
constexpr index_t KPerThread = kKVector;
|
||||
constexpr index_t KThreads = OtherK;
|
||||
|
||||
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
|
||||
|
||||
// for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E]
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
|
||||
sequence<KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
}
|
||||
else // for kKPerBlock=96,160
|
||||
{
|
||||
static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!");
|
||||
|
||||
// ToDo: need more considieration for hdim72
|
||||
constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 3 : 5;
|
||||
constexpr index_t KThreads = OtherK / KRepPerThread;
|
||||
|
||||
static_assert((KThreads & (KThreads - 1)) == 0, "Check failed!");
|
||||
|
||||
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
|
||||
sequence<KRepPerThread, KThreads, kKVector>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 1>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<1, 0, 2>>{});
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
|
||||
{
|
||||
@@ -823,13 +585,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
sequence<1, 2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetQKBlockGemmSingleRepM()
|
||||
{
|
||||
return Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}) *
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
|
||||
{
|
||||
@@ -976,13 +731,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ()
|
||||
{
|
||||
return MakeQLdsBlockDescriptor<Problem>().get_element_space_size() *
|
||||
sizeof(typename Problem::QDataType);
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV()
|
||||
{
|
||||
@@ -1001,8 +749,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return max(GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>(),
|
||||
GetSmemSizeQ<Problem>());
|
||||
return GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -65,8 +65,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
|
||||
|
||||
static constexpr bool kUseTrLoad = true;
|
||||
|
||||
static_assert(Problem::kLoadWholeQTileOnceThroughLds == true, "Check failed!");
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
@@ -226,11 +224,10 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
|
||||
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
|
||||
OaccBlockTileType o_acc;
|
||||
|
||||
auto q_dram_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQDramSingleRepMTileDistribution<Problem>());
|
||||
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQRegTileDistribution<Problem>());
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
@@ -242,7 +239,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
auto q_dram_tile = load_tile(q_dram_window);
|
||||
auto q_tile = load_tile(q_dram_window);
|
||||
|
||||
using k_tile_type = decltype(load_tile(k_dram_window));
|
||||
|
||||
@@ -262,19 +259,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
|
||||
// provide partition_index for LDS tile window with so that warp_id is in vgpr
|
||||
array<index_t, 2> partition_index{get_warp_id<false>(), get_lane_id()};
|
||||
|
||||
// Q tile in LDS
|
||||
QDataType* q_lds_ptr = static_cast<QDataType*>(smem_ptr);
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
auto q_lds_write_window = make_tile_window(
|
||||
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds,
|
||||
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeQRegTileDistribution<Problem>(),
|
||||
partition_index);
|
||||
|
||||
// K tile in LDS
|
||||
KDataType* k_lds_ptr = static_cast<KDataType*>(smem_ptr);
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
@@ -361,18 +345,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
|
||||
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
|
||||
}();
|
||||
|
||||
store_tile(q_lds_write_window, q_dram_tile, partition_index);
|
||||
|
||||
clear_tile(o_acc);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
auto q_tile = load_tile(q_lds_read_window);
|
||||
|
||||
q_tile = tile_elementwise_in(q_element_func, q_tile);
|
||||
|
||||
set_tile(m, -numeric<CompDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
@@ -380,13 +353,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
|
||||
|
||||
auto seqlen_k_curr = seqlen_k_start;
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
using v_tile_type = decltype(load_tile(v_dram_window));
|
||||
|
||||
statically_indexed_array<v_tile_type, k1_loops> v_tiles;
|
||||
|
||||
Reference in New Issue
Block a user