mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
Move make_tile_window out of the hot loop
This commit is contained in:
@@ -38,10 +38,10 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
|
||||
{
|
||||
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
|
||||
int32x4_t r = __builtin_bit_cast(int32x4_t, res);
|
||||
r.x = __builtin_amdgcn_readfirstlane(r.x);
|
||||
r.y = __builtin_amdgcn_readfirstlane(r.y);
|
||||
r.z = __builtin_amdgcn_readfirstlane(r.z);
|
||||
r.w = __builtin_amdgcn_readfirstlane(r.w);
|
||||
// r.x = __builtin_amdgcn_readfirstlane(r.x);
|
||||
// r.y = __builtin_amdgcn_readfirstlane(r.y);
|
||||
// r.z = __builtin_amdgcn_readfirstlane(r.z);
|
||||
// r.w = __builtin_amdgcn_readfirstlane(r.w);
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
@@ -29,10 +29,10 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
|
||||
{
|
||||
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
|
||||
int32x4_t r = __builtin_bit_cast(int32x4_t, res);
|
||||
r.x = __builtin_amdgcn_readfirstlane(r.x);
|
||||
r.y = __builtin_amdgcn_readfirstlane(r.y);
|
||||
r.z = __builtin_amdgcn_readfirstlane(r.z);
|
||||
r.w = __builtin_amdgcn_readfirstlane(r.w);
|
||||
// r.x = __builtin_amdgcn_readfirstlane(r.x);
|
||||
// r.y = __builtin_amdgcn_readfirstlane(r.y);
|
||||
// r.z = __builtin_amdgcn_readfirstlane(r.z);
|
||||
// r.w = __builtin_amdgcn_readfirstlane(r.w);
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
@@ -68,48 +68,51 @@ struct BlockFlatmmASmemBSmemCRegV1
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockWindow, typename BFlatBlockTensor>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockWindow& a_block_window,
|
||||
ABlockWindow& a_warp_windows,
|
||||
BFlatBlockTensor& b_warp_tensor) const
|
||||
{
|
||||
constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}];
|
||||
// constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}];
|
||||
// constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && KPerBlock == BlockGemmShape::kK, "wrong!");
|
||||
// static_assert(MPerBlock == BlockGemmShape::kM && KPerBlock == BlockGemmShape::kK, "wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
// constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp =
|
||||
BlockTile::at(idxN) / (WarpTile::at(idxN) * BlockWarps::at(idxN));
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
// constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
|
||||
// constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
// const index_t iMWarp = get_warp_id() / NWarp;
|
||||
|
||||
// construct A-warp-window
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
|
||||
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
// auto a_warp_window_tmp = make_tile_window(
|
||||
// a_block_window.get_bottom_tensor_view(),
|
||||
// make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
// a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
|
||||
// make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
|
||||
// statically_indexed_array<
|
||||
// statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
|
||||
// MIterPerWarp>
|
||||
// a_warp_windows;
|
||||
// static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
// a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
// move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
// {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
// });
|
||||
// });
|
||||
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
@@ -92,16 +92,17 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
constexpr index_t B_Buffer_Load_Inst_Num = NIterPerWarp * KIterPerWarp;
|
||||
// constexpr index_t A_LDS_Read_Inst_Remain = A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num;
|
||||
|
||||
static_for<0, A_LDS_Read_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_LDS_Read_Inst_Num-A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
|
||||
});
|
||||
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
@@ -110,7 +111,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
|
||||
});
|
||||
}
|
||||
|
||||
@@ -134,15 +135,21 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
// constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
|
||||
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
|
||||
|
||||
constexpr index_t KFlatPerBlockPerIter = flatKPerWarp;
|
||||
constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
|
||||
|
||||
constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
|
||||
// A tile in LDS
|
||||
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
|
||||
|
||||
@@ -166,6 +173,25 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
auto a_lds_gemm_window = make_tile_window(
|
||||
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_lds_gemm_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
a_lds_gemm_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
|
||||
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// Block GEMM
|
||||
auto block_flatmm = BlockFlatmm();
|
||||
|
||||
@@ -245,7 +271,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
|
||||
// GEMM i
|
||||
block_flatmm(c_block_tile, a_lds_gemm_window, b_warp_tensor);
|
||||
block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
@@ -278,7 +304,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
|
||||
// GEMM i
|
||||
block_flatmm(c_block_tile, a_lds_gemm_window, b_warp_tensor_2);
|
||||
block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_2);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
@@ -315,7 +341,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
|
||||
// GEMM i
|
||||
block_flatmm(c_block_tile, a_lds_gemm_window, b_warp_tensor);
|
||||
block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
@@ -340,10 +366,11 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
// move to next flat K
|
||||
// move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
HotLoopScheduler();
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM num_loop - 1
|
||||
block_flatmm(c_block_tile, a_lds_gemm_window, b_warp_tensor_2);
|
||||
block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_2);
|
||||
}
|
||||
|
||||
return c_block_tile;
|
||||
|
||||
@@ -19,11 +19,12 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
|
||||
#if 0
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = GetSmemPackA<Problem>();
|
||||
#if 1
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
|
||||
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
@@ -39,6 +40,9 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
#endif
|
||||
/*xor*/
|
||||
#if 0
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = GetSmemPackA<Problem>();
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
constexpr auto DataTypeSize = sizeof(ADataType);
|
||||
@@ -79,7 +83,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
#endif
|
||||
#if 0 /*reduce transform layers,compare with old ck*/
|
||||
#if 1 /*reduce transform layers,compare with old ck*/
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
|
||||
Reference in New Issue
Block a user