diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp index 7f6ec4307f..4419d31e28 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp @@ -43,7 +43,7 @@ struct BlockGemmASmemBSmemCReg static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; -#if defined(ENABLE_INSTRUCTION_SCH) +#if defined(ENABLE_PREFETCH) // A block tile distribution for load from lds CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() { @@ -92,16 +92,16 @@ struct BlockGemmASmemBSmemCReg using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); - ALdsTile a_warp_tile_; - ALdsTile b_warp_tile_; + ALdsTile aWarpTile; + BLdsTile bWarpTile; // Prefetch from LDS to warp register template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, const BSmemBlockWindow& b_block_window) { - load_tile(a_warp_tile_, a_block_window); - load_tile(b_warp_tile_, b_block_window); + aWarpTile = load_tile(a_block_window); + bWarpTile = load_tile(b_block_window); } #endif @@ -178,23 +178,23 @@ struct BlockGemmASmemBSmemCReg static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { // read A warp tensor from A block tensor AWarpTensor a_warp_tensor; -#if defined(ENABLE_INSTRUCTION_SCH) +#if defined(ENABLE_PREFETCH) #pragma message ("local data share prefetch") - a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( + a_warp_tensor.get_thread_buffer() = aWarpTile.get_y_sliced_thread_data( merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); #else - load_tile(a_warp_tensor, a_warp_windows(mIter)(kIter)); + a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); #endif static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // read B warp tensor from B block tensor BWarpTensor b_warp_tensor; -#if defined(ENABLE_INSTRUCTION_SCH) - b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( +#if defined(ENABLE_PREFETCH) + b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data( merge_sequences(sequence{}, b_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); #else - load_tile(b_warp_tensor, b_warp_windows(nIter)(kIter)); + b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); #endif // read C warp tensor from C block tensor CWarpTensor c_warp_tensor; @@ -305,22 +305,22 @@ struct BlockGemmASmemBSmemCReg static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { // read A warp tensor from A block tensor AWarpTensor a_warp_tensor; -#if defined(ENABLE_INSTRUCTION_SCH) - a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( +#if defined(ENABLE_PREFETCH) + a_warp_tensor.get_thread_buffer() = aWarpTile.get_y_sliced_thread_data( merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); #else - load_tile(a_warp_tensor, a_warp_windows(mIter)(kIter)); + a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); #endif static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // read B warp tensor from B block tensor BWarpTensor b_warp_tensor; -#if defined(ENABLE_INSTRUCTION_SCH) - b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( +#if defined(ENABLE_PREFETCH) + b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data( merge_sequences(sequence{}, b_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); #else - load_tile(b_warp_tensor, b_warp_windows(nIter)(kIter)); + b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); #endif // read C warp tensor from C block tensor CWarpTensor c_warp_tensor; diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp index a6bc4e7563..ec71f9c76a 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp @@ -17,19 +17,31 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { +#if defined(ADJUST_BLOCK_TILE_SHAPE) + constexpr index_t kMWarp = 2; + constexpr index_t kNWarp = 2; +#else + constexpr index_t kMWarp = 4; + constexpr index_t kNWarp = 1; +#endif + #if defined(NAIVE_IMPLEMENTATION) #pragma message ("mfma m32 n32 k8") if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 2, 2); + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, + kMWarp, + kNWarp); } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 2, 2); + return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, + kMWarp, + kNWarp); } #elif defined(USING_MFMA_32x32x_8x2) #pragma message ("mfma m32 n32 k16") @@ -37,13 +49,17 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 2, 2); + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, + kMWarp, + kNWarp); } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 2, 2); + return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, + kMWarp, + kNWarp); } #elif defined(USING_MFMA_16x16x16) #pragma message ("mfma m16 n16 k16") @@ -51,13 +67,17 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 2); + return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, + kMWarp, + kNWarp); } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, 2, 2); + return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, + kMWarp, + kNWarp); } #elif defined(USING_MFMA_16x16x_16x2) #pragma message ("mfma m16 n16 k32") @@ -65,13 +85,17 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 2); + return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, + kMWarp, + kNWarp); } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{}, 2, 2); + return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{}, + kMWarp, + kNWarp); } #endif else diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp index d96f27b66b..26d5618330 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp @@ -42,15 +42,8 @@ struct BlockGemmPipelineAGmemBGmemCReg } #if defined(ENABLE_INSTRUCTION_SCH) - static constexpr index_t APackedSize = + static constexpr index_t kPackedSize = ck_tile::numeric_traits>::PackedSize; - static constexpr index_t BPackedSize = - ck_tile::numeric_traits>::PackedSize; - - static constexpr index_t BlockSize = Problem::kBlockSize; - static constexpr index_t MPerBlock = BlockGemmShape::kM; - static constexpr index_t NPerBlock = BlockGemmShape::kN; - static constexpr index_t KPerBlock = BlockGemmShape::kK; static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } @@ -72,31 +65,31 @@ struct BlockGemmPipelineAGmemBGmemCReg constexpr index_t AB_LDS_RW_Width = GetSmemPack(); constexpr index_t A_Buffer_Load_Inst_Num = - MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + kMPerBlock * kKPerBlock / (kBlockSize * GetVectorSizeA()); constexpr index_t B_Buffer_Load_Inst_Num = - NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + kNPerBlock * kKPerBlock / (kBlockSize * GetVectorSizeB()); constexpr index_t A_LDS_Write_Inst_Num = - MPerBlock * KPerBlock / (BlockSize * AB_LDS_RW_Width); + kMPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width); constexpr index_t B_LDS_Write_Inst_Num = - NPerBlock * KPerBlock / (BlockSize * AB_LDS_RW_Width); + kNPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width); constexpr index_t A_LDS_Read_Inst_Num = - WaveNumN * MPerBlock * KPerBlock / (BlockSize * AB_LDS_RW_Width); + WaveNumN * kMPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width); constexpr index_t B_LDS_Read_Inst_Num = - WaveNumM * NPerBlock * KPerBlock / (BlockSize * AB_LDS_RW_Width); + WaveNumM * kNPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width); - constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / - (BlockSize / WaveSize) / + constexpr index_t C_MFMA_Inst_Num = kMPerBlock * kNPerBlock * kKPerBlock / + (kBlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); // A/B split schedule // compiler is likely to use ds_read2 when instruction width smaller than 16bytes constexpr auto num_ds_read_inst_a = - AB_LDS_RW_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num : + AB_LDS_RW_Width * sizeof(ADataType) / kPackedSize == 16 ? A_LDS_Read_Inst_Num : A_LDS_Read_Inst_Num / 2; constexpr auto num_ds_read_inst_b = - AB_LDS_RW_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num : + AB_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? B_LDS_Read_Inst_Num : B_LDS_Read_Inst_Num / 2; constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num; @@ -109,9 +102,9 @@ struct BlockGemmPipelineAGmemBGmemCReg constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; constexpr auto ds_read_a_issue_cycle = - AB_LDS_RW_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4; + AB_LDS_RW_Width * sizeof(ADataType) / kPackedSize == 16 ? 8 : 4; constexpr auto ds_read_b_issue_cycle = - AB_LDS_RW_Width * sizeof(BDataType) / BPackedSize == 16 ? 8 : 4; + AB_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4; constexpr auto ds_read_a_mfma_rate = (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); constexpr auto ds_read_b_mfma_rate = @@ -266,15 +259,15 @@ struct BlockGemmPipelineAGmemBGmemCReg {0, 0}, b_copy_dram_window.get_tile_distribution()); -#if defined(ENABLE_INSTRUCTION_SCH) +#if defined(ENABLE_PREFETCH) // A LDS tile for block GEMM auto a_lds_gemm_window = make_tile_window( - a_lds_block, make_tuple(number{}, number{}), {0, 0}, + a_lds_block, make_tuple(number{}, number{}), {0, 0}, make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode())); // B LDS tile for block GEMM auto b_lds_gemm_window = make_tile_window( - b_lds_block, make_tuple(number{}, number{}), {0, 0}, + b_lds_block, make_tuple(number{}, number{}), {0, 0}, make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode())); #else // A LDS tile for block GEMM @@ -303,23 +296,23 @@ struct BlockGemmPipelineAGmemBGmemCReg ABlockTile a_block_tile; BBlockTile b_block_tile; + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kKPerBlock); // ------------------------------------------------------------------------------------- // Gemm pipeline start -#if defined(ENABLE_INSTRUCTION_SCH) - using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; - using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; - constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, KPerBlock); - constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, KPerBlock); +#if defined(ENABLE_PREFETCH) // Initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // Prefetch // Global read 0 - load_tile(a_block_tile, a_copy_dram_window); - load_tile(b_block_tile, b_copy_dram_window); + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); if (num_loop > 1) { @@ -331,8 +324,8 @@ struct BlockGemmPipelineAGmemBGmemCReg store_tile(b_copy_lds_window, b_block_tile); // Global read 0 - load_tile(a_block_tile, a_copy_dram_window); - load_tile(b_block_tile, b_copy_dram_window); + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); move_tile_window(a_copy_dram_window, a_dram_tile_window_step); move_tile_window(b_copy_dram_window, b_dram_tile_window_step); @@ -357,8 +350,8 @@ struct BlockGemmPipelineAGmemBGmemCReg store_tile(b_copy_lds_window, b_block_tile); // Global read 0 - load_tile(a_block_tile, a_copy_dram_window); - load_tile(b_block_tile, b_copy_dram_window); + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); move_tile_window(a_copy_dram_window, a_dram_tile_window_step); move_tile_window(b_copy_dram_window, b_dram_tile_window_step); @@ -369,12 +362,14 @@ struct BlockGemmPipelineAGmemBGmemCReg // Prefetch from LDS to warp register in block gemm block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); +#if defined(ENABLE_INSTRUCTION_SCH) HotLoopScheduler(); +#endif __builtin_amdgcn_sched_barrier(0); - i += 1; - } while(i < (num_loop - 2)); + iCounter += 1; + } while(iCounter < (num_loop - 2)); } // Tail @@ -388,84 +383,12 @@ struct BlockGemmPipelineAGmemBGmemCReg block_sync_lds(); block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - -#elif defined(ENABLE_PREFETCH) - // Prefetch - // Global read 0 - load_tile(a_block_tile, a_copy_dram_window); - load_tile(b_block_tile, b_copy_dram_window); - - { - // Move to 1 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); - - // Initialize C - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - - // LDS write 0 - store_tile(a_copy_lds_window, a_block_tile); - // Global read 1 - load_tile(a_block_tile, a_copy_dram_window); - - // LDS write 0 - store_tile(b_copy_lds_window, b_block_tile); - // Global read 1 - load_tile(b_block_tile, b_copy_dram_window); - } - - index_t iCounter = num_loop - 2; - - do - { - block_sync_lds(); - - // GEMM i - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - - block_sync_lds(); - - // Move to i + 2 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); - - // LDS write i + 1 - store_tile(a_copy_lds_window, a_block_tile); - // Global read i + 2 - load_tile(a_block_tile, a_copy_dram_window); - - // LDS write i + 1 - store_tile(b_copy_lds_window, b_block_tile); - // Global read i + 2 - load_tile(b_block_tile, b_copy_dram_window); - - iCounter--; - - } while(iCounter > 0); - - // Tail - { - block_sync_lds(); - - // GEMM num_loop - 2 - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - - block_sync_lds(); - - // LDS write num_loop - 1 - store_tile(a_copy_lds_window, a_block_tile); - - store_tile(b_copy_lds_window, b_block_tile); - - block_sync_lds(); - - // GEMM num_loop - 1 - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - } #else // non-prefetch - load_tile(a_block_tile, a_copy_dram_window); - load_tile(b_block_tile, b_copy_dram_window); + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); store_tile(a_copy_lds_window, a_block_tile); store_tile(b_copy_lds_window, b_block_tile); @@ -477,11 +400,10 @@ struct BlockGemmPipelineAGmemBGmemCReg while (iCounter > 0) { - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); - - load_tile(a_block_tile, a_copy_dram_window); - load_tile(b_block_tile, b_copy_dram_window); + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); store_tile(a_copy_lds_window, a_block_tile); store_tile(b_copy_lds_window, b_block_tile);