mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
[GEMM] Add cache-aware WG schedule and adjust block tile
113 -> 121.7 TFops
This commit is contained in:
@@ -15,7 +15,7 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
#if 0
|
||||
#if 1
|
||||
#pragma message ("mfma k16")
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
|
||||
@@ -114,7 +114,7 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
// Acc register tile
|
||||
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
#pragma message ("prefetch")
|
||||
// prefetch
|
||||
// global read 0
|
||||
|
||||
@@ -27,6 +27,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
constexpr index_t kKPack = 8;
|
||||
|
||||
#if BANK_CONFLICT_K_FIRST
|
||||
#pragma message ("BANK_CONFLICT: K_FIRST")
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
|
||||
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
|
||||
@@ -41,6 +42,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
#elif PADDING_K_FIRST
|
||||
#pragma message ("BANK_CONFLICT: PADDING_K_FIRST")
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
|
||||
make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
@@ -55,6 +57,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
#elif PADDING_MN_FIRST
|
||||
#pragma message ("BANK_CONFLICT: PADDING_MN_FIRST")
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
|
||||
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
@@ -69,6 +72,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
#elif XOR
|
||||
#pragma message ("BANK_CONFLICT: XOR")
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
constexpr auto DataTypeSize = sizeof(ADataType);
|
||||
@@ -122,6 +126,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
constexpr index_t kKPack = 8;
|
||||
|
||||
#if BANK_CONFLICT_K_FIRST
|
||||
#pragma message ("BANK_CONFLICT: K_FIRST")
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
|
||||
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
|
||||
@@ -136,6 +141,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
#elif PADDING_K_FIRST
|
||||
#pragma message ("BANK_CONFLICT: PADDING_K_FIRST")
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
|
||||
make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
@@ -150,6 +156,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
#elif PADDING_MN_FIRST
|
||||
#pragma message ("BANK_CONFLICT: PADDING_MN_FIRST")
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kNPerBlock>{}, number<kKPack>{}),
|
||||
make_tuple(number<(kNPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
@@ -164,6 +171,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
#elif XOR
|
||||
#pragma message ("BANK_CONFLICT: XOR")
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr auto DataTypeSize = sizeof(BDataType);
|
||||
|
||||
@@ -82,9 +82,15 @@ int main(int argc, char* argv[])
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize = 256;
|
||||
|
||||
#if 1
|
||||
#pragma message ("(Increase KperBlock, reduce MperBlock) -> increase Grid size")
|
||||
constexpr ck_tile::index_t kGemmMPerBlock = 128;
|
||||
constexpr ck_tile::index_t kGemmKPerBlock = 64;
|
||||
#else
|
||||
constexpr ck_tile::index_t kGemmMPerBlock = 256;
|
||||
constexpr ck_tile::index_t kGemmNPerBlock = 128;
|
||||
constexpr ck_tile::index_t kGemmKPerBlock = 32;
|
||||
#endif
|
||||
constexpr ck_tile::index_t kGemmNPerBlock = 128;
|
||||
|
||||
ck_tile::index_t kGridSize = (M / kGemmMPerBlock) * (N / kGemmNPerBlock);
|
||||
|
||||
|
||||
@@ -80,10 +80,39 @@ struct Gemm
|
||||
static constexpr index_t kKPerBlock = kKPerBlock_;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t NumTilesM,
|
||||
index_t NumTilesN)
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0,
|
||||
index_t N0)
|
||||
{
|
||||
const auto unmerge = make_merge_transform(make_tuple(NumTilesN, NumTilesM));
|
||||
#if 1
|
||||
#pragma message ("Cache-aware work group sch")
|
||||
return [=](index_t block_1d_id) {
|
||||
constexpr index_t M01 = 2;
|
||||
constexpr index_t GroupNum = 4;
|
||||
|
||||
const auto group_size = integer_divide_ceil(M0 * N0, GroupNum);
|
||||
const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
|
||||
|
||||
const auto group_id_x = block_1d_id % GroupNum;
|
||||
|
||||
const auto remap_block_1d_id = (group_id_x <= big_group_num)
|
||||
? (group_id_x * group_size + block_1d_id / GroupNum)
|
||||
: (group_id_x * group_size + big_group_num - group_id_x);
|
||||
|
||||
const index_t idx_M0 = remap_block_1d_id / N0;
|
||||
const index_t idx_N0 = remap_block_1d_id % N0;
|
||||
|
||||
const index_t M0_mod_M01 = M0 % M01;
|
||||
|
||||
const auto M01_adapt = (idx_M0 < M0 - M0_mod_M01) ? M01 : M0_mod_M01;
|
||||
|
||||
const index_t idx_M00 = idx_M0 / M01;
|
||||
const index_t idx_M01 = idx_M0 % M01;
|
||||
const index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
|
||||
|
||||
return make_multi_index(idx_N0_M01_local % M01_adapt + idx_M00 * M01, idx_N0_M01_local / M01_adapt);
|
||||
};
|
||||
#else
|
||||
const auto unmerge = make_merge_transform(make_tuple(N0, M0));
|
||||
|
||||
return [unmerge](index_t block_id) {
|
||||
multi_index<2> unmerged;
|
||||
@@ -92,6 +121,7 @@ struct Gemm
|
||||
return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{}));
|
||||
|
||||
};
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
@@ -30,8 +30,8 @@ struct GridGemm
|
||||
// divide problem
|
||||
const auto id_block = get_block_id();
|
||||
|
||||
const auto num_tile_m = M / kMPerBlock;
|
||||
const auto num_tile_n = N / kNPerBlock;
|
||||
const auto num_tile_m = integer_divide_ceil(M, kMPerBlock);
|
||||
const auto num_tile_n = integer_divide_ceil(N, kNPerBlock);
|
||||
|
||||
const auto block2tile = Policy::template MakeBlock2TileMap<Problem>(num_tile_m, num_tile_n);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user