diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp
index c9961bc259..0f374d9f94 100644
--- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp
+++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp
@@ -15,7 +15,7 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
     template
     CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
     {
-#if 0
+#if 1
 #pragma message ("mfma k16")
         if constexpr(std::is_same_v &&
                      std::is_same_v &&
diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp
index 27a75626f9..67a7985b48 100644
--- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp
+++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp
@@ -114,7 +114,7 @@ struct BlockGemmPipelineAGmemBGmemCReg
         // Acc register tile
         auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
 
-#if 0
+#if 1
 #pragma message ("prefetch")
         // prefetch
         // global read 0
diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp
index 071235ab0b..b9784c901f 100644
--- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp
+++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp
@@ -27,6 +27,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
         constexpr index_t kKPack = 8;
 
 #if BANK_CONFLICT_K_FIRST
+#pragma message ("BANK_CONFLICT: K_FIRST")
         constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
             make_tuple(number{}, number{}, number{}),
             make_tuple(number{}, number{}, number<1>{}),
@@ -41,6 +42,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
             make_tuple(sequence<0>{}, sequence<1>{}));
 
 #elif PADDING_K_FIRST
+#pragma message ("BANK_CONFLICT: PADDING_K_FIRST")
         constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
             make_tuple(number{}, number{}, number{}),
             make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number{}, number<1>{}),
@@ -55,6 +57,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
             make_tuple(sequence<0>{}, sequence<1>{}));
 
 #elif PADDING_MN_FIRST
+#pragma message ("BANK_CONFLICT: PADDING_MN_FIRST")
         constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
             make_tuple(number{}, number{}, number{}),
             make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number{}, number<1>{}),
@@ -69,6 +72,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
             make_tuple(sequence<0>{}, sequence<1>{}));
 
 #elif XOR
+#pragma message ("BANK_CONFLICT: XOR")
         using ADataType = remove_cvref_t;
 
         constexpr auto DataTypeSize = sizeof(ADataType);
@@ -122,6 +126,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
         constexpr index_t kKPack = 8;
 
 #if BANK_CONFLICT_K_FIRST
+#pragma message ("BANK_CONFLICT: K_FIRST")
         constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
             make_tuple(number{}, number{}, number{}),
             make_tuple(number{}, number{}, number<1>{}),
@@ -136,6 +141,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
             make_tuple(sequence<0>{}, sequence<1>{}));
 
 #elif PADDING_K_FIRST
+#pragma message ("BANK_CONFLICT: PADDING_K_FIRST")
         constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
             make_tuple(number{}, number{}, number{}),
             make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number{}, number<1>{}),
@@ -150,6 +156,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
             make_tuple(sequence<0>{}, sequence<1>{}));
 
 #elif PADDING_MN_FIRST
+#pragma message ("BANK_CONFLICT: PADDING_MN_FIRST")
         constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
             make_tuple(number{}, number{}, number{}),
             make_tuple(number<(kNPerBlock + 1) * kKPack>{}, number{}, number<1>{}),
@@ -164,6 +171,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
             make_tuple(sequence<0>{}, sequence<1>{}));
 
 #elif XOR
+#pragma message ("BANK_CONFLICT: XOR")
         using BDataType = remove_cvref_t;
 
         constexpr auto DataTypeSize = sizeof(BDataType);
diff --git a/example/ck_tile/99_toy_example/02_gemm/gemm.cpp b/example/ck_tile/99_toy_example/02_gemm/gemm.cpp
index 0c9bbd21eb..7aea4b376d 100644
--- a/example/ck_tile/99_toy_example/02_gemm/gemm.cpp
+++ b/example/ck_tile/99_toy_example/02_gemm/gemm.cpp
@@ -82,9 +82,15 @@ int main(int argc, char* argv[])
 
     constexpr ck_tile::index_t kBlockSize = 256;
 
+#if 1
+#pragma message ("(Increase KperBlock, reduce MperBlock) -> increase Grid size")
+    constexpr ck_tile::index_t kGemmMPerBlock = 128;
+    constexpr ck_tile::index_t kGemmKPerBlock = 64;
+#else
     constexpr ck_tile::index_t kGemmMPerBlock = 256;
-    constexpr ck_tile::index_t kGemmNPerBlock = 128;
     constexpr ck_tile::index_t kGemmKPerBlock = 32;
+#endif
+    constexpr ck_tile::index_t kGemmNPerBlock = 128;
 
     ck_tile::index_t kGridSize = (M / kGemmMPerBlock) * (N / kGemmNPerBlock);
 
diff --git a/example/ck_tile/99_toy_example/02_gemm/gemm.hpp b/example/ck_tile/99_toy_example/02_gemm/gemm.hpp
index 0dc6945002..adcf8c3867 100644
--- a/example/ck_tile/99_toy_example/02_gemm/gemm.hpp
+++ b/example/ck_tile/99_toy_example/02_gemm/gemm.hpp
@@ -80,10 +80,50 @@ struct Gemm
     static constexpr index_t kKPerBlock = kKPerBlock_;
 
+    // Builds the functor that maps a 1-D workgroup id to a (tile_m, tile_n) index.
+    // M0 / N0 are the tile counts along M / N. In the "#if 1" branch the ids are
+    // first regrouped into GroupNum interleaved groups and then swizzled in bands
+    // of M01 tile rows; the "#else" branch keeps the plain unmerge mapping.
     template
-    CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t NumTilesM,
-                                                                index_t NumTilesN)
+    CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0,
+                                                                index_t N0)
     {
-        const auto unmerge = make_merge_transform(make_tuple(NumTilesN, NumTilesM));
+#if 1
+#pragma message ("Cache-aware work group sch")
+        return [=](index_t block_1d_id) {
+            constexpr index_t M01 = 2;
+            constexpr index_t GroupNum = 4;
+
+            // Ids are dealt round-robin into GroupNum groups. Assuming M0 * N0 ids in
+            // total, the first big_group_num groups hold group_size ids, the rest one fewer.
+            const auto group_size = integer_divide_ceil(M0 * N0, GroupNum);
+            const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
+
+            const auto group_id_x = block_1d_id % GroupNum;
+            const auto group_id_y = block_1d_id / GroupNum;
+
+            // A group starts at group_id_x * group_size, minus one slot for every small
+            // group that precedes it. group_id_y is the position inside the group and
+            // has to be added in both arms, otherwise a small group collapses onto one tile.
+            const auto remap_block_1d_id = (group_id_x <= big_group_num)
+                ? (group_id_x * group_size + group_id_y)
+                : (group_id_x * group_size + big_group_num - group_id_x + group_id_y);
+
+            const index_t idx_M0 = remap_block_1d_id / N0;
+            const index_t idx_N0 = remap_block_1d_id % N0;
+
+            const index_t M0_mod_M01 = M0 % M01;
+
+            // The last band holds only M0 % M01 tile rows when M01 does not divide M0.
+            const auto M01_adapt = (idx_M0 < M0 - M0_mod_M01) ? M01 : M0_mod_M01;
+
+            const index_t idx_M00 = idx_M0 / M01;
+            const index_t idx_M01 = idx_M0 % M01;
+            const index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
+
+            return make_multi_index(idx_N0_M01_local % M01_adapt + idx_M00 * M01, idx_N0_M01_local / M01_adapt);
+        };
+#else
+        const auto unmerge = make_merge_transform(make_tuple(N0, M0));
 
         return [unmerge](index_t block_id) {
             multi_index<2> unmerged;
@@ -92,6 +132,7 @@ struct Gemm
 
             return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{}));
         };
+#endif
     }
 
     template
diff --git a/example/ck_tile/99_toy_example/02_gemm/grid_gemm.hpp b/example/ck_tile/99_toy_example/02_gemm/grid_gemm.hpp
index 468c47abe1..9400ec2146 100644
--- a/example/ck_tile/99_toy_example/02_gemm/grid_gemm.hpp
+++ b/example/ck_tile/99_toy_example/02_gemm/grid_gemm.hpp
@@ -30,8 +30,8 @@ struct GridGemm
         // divide problem
         const auto id_block = get_block_id();
 
-        const auto num_tile_m = M / kMPerBlock;
-        const auto num_tile_n = N / kNPerBlock;
+        const auto num_tile_m = integer_divide_ceil(M, kMPerBlock);
+        const auto num_tile_n = integer_divide_ceil(N, kNPerBlock);
 
         const auto block2tile = Policy::template MakeBlock2TileMap(num_tile_m, num_tile_n);
 