move make tile window out of hotloop

This commit is contained in:
AMD-dteng
2025-04-27 13:50:44 +08:00
parent 0d503ae0a8
commit 272bd532f6
5 changed files with 81 additions and 47 deletions

View File

@@ -38,10 +38,10 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
{
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
int32x4_t r = __builtin_bit_cast(int32x4_t, res);
r.x = __builtin_amdgcn_readfirstlane(r.x);
r.y = __builtin_amdgcn_readfirstlane(r.y);
r.z = __builtin_amdgcn_readfirstlane(r.z);
r.w = __builtin_amdgcn_readfirstlane(r.w);
// r.x = __builtin_amdgcn_readfirstlane(r.x);
// r.y = __builtin_amdgcn_readfirstlane(r.y);
// r.z = __builtin_amdgcn_readfirstlane(r.z);
// r.w = __builtin_amdgcn_readfirstlane(r.w);
return r;
}

View File

@@ -29,10 +29,10 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
{
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
int32x4_t r = __builtin_bit_cast(int32x4_t, res);
r.x = __builtin_amdgcn_readfirstlane(r.x);
r.y = __builtin_amdgcn_readfirstlane(r.y);
r.z = __builtin_amdgcn_readfirstlane(r.z);
r.w = __builtin_amdgcn_readfirstlane(r.w);
// r.x = __builtin_amdgcn_readfirstlane(r.x);
// r.y = __builtin_amdgcn_readfirstlane(r.y);
// r.z = __builtin_amdgcn_readfirstlane(r.z);
// r.w = __builtin_amdgcn_readfirstlane(r.w);
return r;
}

View File

@@ -68,48 +68,51 @@ struct BlockFlatmmASmemBSmemCRegV1
// C += A * B
template <typename CBlockTensor, typename ABlockWindow, typename BFlatBlockTensor>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockWindow& a_block_window,
ABlockWindow& a_warp_windows,
BFlatBlockTensor& b_warp_tensor) const
{
constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}];
// constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}];
// constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}];
static_assert(MPerBlock == BlockGemmShape::kM && KPerBlock == BlockGemmShape::kK, "wrong!");
// static_assert(MPerBlock == BlockGemmShape::kM && KPerBlock == BlockGemmShape::kK, "wrong!");
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t KPerBlock = BlockGemmShape::kK;
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
// constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp =
BlockTile::at(idxN) / (WarpTile::at(idxN) * BlockWarps::at(idxN));
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
// constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
// constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iMWarp = get_warp_id() / NWarp;
// const index_t iMWarp = get_warp_id() / NWarp;
// construct A-warp-window
auto a_warp_window_tmp = make_tile_window(
a_block_window.get_bottom_tensor_view(),
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
MIterPerWarp>
a_warp_windows;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
// auto a_warp_window_tmp = make_tile_window(
// a_block_window.get_bottom_tensor_view(),
// make_tuple(number<WG::kM>{}, number<WG::kK>{}),
// a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
// make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
// statically_indexed_array<
// statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
// MIterPerWarp>
// a_warp_windows;
// static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
// move_tile_window(a_warp_windows(mIter)(kIter),
// {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
// });
// });
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;

View File

@@ -92,16 +92,17 @@ struct FlatmmPipelineAGmemBGmemCRegV1
constexpr index_t B_Buffer_Load_Inst_Num = NIterPerWarp * KIterPerWarp;
// constexpr index_t A_LDS_Read_Inst_Remain = A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num;
static_for<0, A_LDS_Read_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, A_LDS_Read_Inst_Num-A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
});
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
@@ -110,7 +111,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
});
}
@@ -134,15 +135,21 @@ struct FlatmmPipelineAGmemBGmemCRegV1
using WG = remove_cvref_t<decltype(config.template at<0>())>;
// constexpr index_t MWarp = config.template at<1>();
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
constexpr index_t KFlatPerBlockPerIter = flatKPerWarp;
constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
const index_t iMWarp = get_warp_id() / NWarp;
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
@@ -166,6 +173,25 @@ struct FlatmmPipelineAGmemBGmemCRegV1
auto a_lds_gemm_window = make_tile_window(
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto a_warp_window_tmp = make_tile_window(
a_lds_gemm_window.get_bottom_tensor_view(),
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
a_lds_gemm_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
MIterPerWarp>
a_warp_windows;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
// Block GEMM
auto block_flatmm = BlockFlatmm();
@@ -245,7 +271,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
a_block_tile = load_tile(a_copy_dram_window);
// GEMM i
block_flatmm(c_block_tile, a_lds_gemm_window, b_warp_tensor);
block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor);
block_sync_lds();
@@ -278,7 +304,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
a_block_tile = load_tile(a_copy_dram_window);
// GEMM i
block_flatmm(c_block_tile, a_lds_gemm_window, b_warp_tensor_2);
block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_2);
block_sync_lds();
@@ -315,7 +341,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
a_block_tile = load_tile(a_copy_dram_window);
// GEMM i
block_flatmm(c_block_tile, a_lds_gemm_window, b_warp_tensor);
block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor);
block_sync_lds();
@@ -340,10 +366,11 @@ struct FlatmmPipelineAGmemBGmemCRegV1
// move to next flat K
// move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
HotLoopScheduler();
block_sync_lds();
// GEMM num_loop - 1
block_flatmm(c_block_tile, a_lds_gemm_window, b_warp_tensor_2);
block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_2);
}
return c_block_tile;

View File

@@ -19,11 +19,12 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using namespace ck_tile;
#if 0
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t kKPack = GetSmemPackA<Problem>();
#if 1
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
@@ -39,6 +40,9 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
#endif
/*xor*/
#if 0
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t kKPack = GetSmemPackA<Problem>();
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr auto DataTypeSize = sizeof(ADataType);
@@ -79,7 +83,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#endif
#if 0 /*reduce transform layers,compare with old ck*/
#if 1 /*reduce transform layers,compare with old ck*/
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackA<Problem>();