[GEMM] Add cache-aware WG schedule and adjust block tile

Performance: 113 -> 121.7 TFLOPS
This commit is contained in:
Clement Lin
2025-03-21 09:15:17 +08:00
parent 93193e42ea
commit 1f604e9b0a
6 changed files with 52 additions and 8 deletions

View File

@@ -15,7 +15,7 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
#if 0
#if 1
#pragma message ("mfma k16")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&

View File

@@ -114,7 +114,7 @@ struct BlockGemmPipelineAGmemBGmemCReg
// Acc register tile
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
#if 0
#if 1
#pragma message ("prefetch")
// prefetch
// global read 0

View File

@@ -27,6 +27,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
constexpr index_t kKPack = 8;
#if BANK_CONFLICT_K_FIRST
#pragma message ("BANK_CONFLICT: K_FIRST")
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
@@ -41,6 +42,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
make_tuple(sequence<0>{}, sequence<1>{}));
#elif PADDING_K_FIRST
#pragma message ("BANK_CONFLICT: PADDING_K_FIRST")
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
@@ -55,6 +57,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
make_tuple(sequence<0>{}, sequence<1>{}));
#elif PADDING_MN_FIRST
#pragma message ("BANK_CONFLICT: PADDING_MN_FIRST")
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
@@ -69,6 +72,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
make_tuple(sequence<0>{}, sequence<1>{}));
#elif XOR
#pragma message ("BANK_CONFLICT: XOR")
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr auto DataTypeSize = sizeof(ADataType);
@@ -122,6 +126,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
constexpr index_t kKPack = 8;
#if BANK_CONFLICT_K_FIRST
#pragma message ("BANK_CONFLICT: K_FIRST")
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
@@ -136,6 +141,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
make_tuple(sequence<0>{}, sequence<1>{}));
#elif PADDING_K_FIRST
#pragma message ("BANK_CONFLICT: PADDING_K_FIRST")
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
@@ -150,6 +156,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
make_tuple(sequence<0>{}, sequence<1>{}));
#elif PADDING_MN_FIRST
#pragma message ("BANK_CONFLICT: PADDING_MN_FIRST")
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kNPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kNPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
@@ -164,6 +171,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
make_tuple(sequence<0>{}, sequence<1>{}));
#elif XOR
#pragma message ("BANK_CONFLICT: XOR")
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr auto DataTypeSize = sizeof(BDataType);

View File

@@ -82,9 +82,15 @@ int main(int argc, char* argv[])
constexpr ck_tile::index_t kBlockSize = 256;
#if 1
#pragma message ("(Increase KperBlock, reduce MperBlock) -> increase Grid size")
constexpr ck_tile::index_t kGemmMPerBlock = 128;
constexpr ck_tile::index_t kGemmKPerBlock = 64;
#else
constexpr ck_tile::index_t kGemmMPerBlock = 256;
constexpr ck_tile::index_t kGemmNPerBlock = 128;
constexpr ck_tile::index_t kGemmKPerBlock = 32;
#endif
constexpr ck_tile::index_t kGemmNPerBlock = 128;
ck_tile::index_t kGridSize = (M / kGemmMPerBlock) * (N / kGemmNPerBlock);

View File

@@ -80,10 +80,39 @@ struct Gemm
static constexpr index_t kKPerBlock = kKPerBlock_;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t NumTilesM,
index_t NumTilesN)
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0,
index_t N0)
{
const auto unmerge = make_merge_transform(make_tuple(NumTilesN, NumTilesM));
#if 1
#pragma message ("Cache-aware work group sch")
// Maps a linear work-group id to a (tile-m, tile-n) multi-index.
//
// Step 1: consecutive block ids are dealt round-robin into GroupNum groups;
//         each group is then assigned a contiguous range of remapped ids.
// Step 2: the remapped id is decomposed into (idx_M0, idx_N0) and swizzled
//         inside bands of M01 tile rows.
//
// M0 / N0 are the tile counts passed to MakeBlock2TileMap and are captured
// by value.
return [=](index_t block_1d_id) {
    constexpr index_t M01      = 2; // tile rows per swizzle band
    constexpr index_t GroupNum = 4; // number of round-robin groups

    // Largest number of tiles any group holds.
    const auto group_size = integer_divide_ceil(M0 * N0, GroupNum);
    // (group_size * GroupNum - M0 * N0) groups are one tile short; the
    // remaining big_group_num groups hold the full group_size tiles.
    const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);

    const auto group_id_x = block_1d_id % GroupNum; // which group
    const auto group_id_y = block_1d_id / GroupNum; // position inside the group

    // Start offset of the group plus the position inside it. The position
    // must be added in BOTH branches; without it every block of a short
    // group would collapse onto the same tile whenever M0 * N0 is not a
    // multiple of GroupNum.
    const auto remap_block_1d_id =
        (group_id_x <= big_group_num)
            ? (group_id_x * group_size + group_id_y)
            : (group_id_x * group_size + big_group_num - group_id_x + group_id_y);

    const index_t idx_M0 = remap_block_1d_id / N0;
    const index_t idx_N0 = remap_block_1d_id % N0;

    // The last band may contain fewer than M01 rows when M0 % M01 != 0.
    const index_t M0_mod_M01 = M0 % M01;
    const auto M01_adapt     = (idx_M0 < M0 - M0_mod_M01) ? M01 : M0_mod_M01;

    const index_t idx_M00 = idx_M0 / M01; // band index
    const index_t idx_M01 = idx_M0 % M01; // row inside the band

    // Linear position inside the band, then re-split so that consecutive
    // positions advance along the band's rows first.
    const index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

    return make_multi_index(idx_N0_M01_local % M01_adapt + idx_M00 * M01,
                            idx_N0_M01_local / M01_adapt);
};
#else
const auto unmerge = make_merge_transform(make_tuple(N0, M0));
return [unmerge](index_t block_id) {
multi_index<2> unmerged;
@@ -92,6 +121,7 @@ struct Gemm
return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{}));
};
#endif
}
template <typename Problem>

View File

@@ -30,8 +30,8 @@ struct GridGemm
// divide problem
const auto id_block = get_block_id();
const auto num_tile_m = M / kMPerBlock;
const auto num_tile_n = N / kNPerBlock;
const auto num_tile_m = integer_divide_ceil(M, kMPerBlock);
const auto num_tile_n = integer_divide_ceil(N, kNPerBlock);
const auto block2tile = Policy::template MakeBlock2TileMap<Problem>(num_tile_m, num_tile_n);