diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp index 0fb5d2de7d..5c58fa3d60 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp @@ -111,8 +111,8 @@ struct BlockGemmASmemBSmemCReg // C += A * B template CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, - const ABlockWindowTmp& a_block_window_tmp, - const BBlockWindowTmp& b_block_window_tmp) const + [[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp, + [[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const { static_assert(std::is_same_v && std::is_same_v && @@ -127,14 +127,11 @@ struct BlockGemmASmemBSmemCReg KPerBlock == BlockGemmShape::kK, "wrong!"); - // constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); - - // using WarpGemm = remove_cvref_t())>; - constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; +#if !defined(ENABLE_PREFETCH) constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; @@ -142,7 +139,7 @@ struct BlockGemmASmemBSmemCReg const index_t iMWarp = get_warp_id() / NWarp; const index_t iNWarp = get_warp_id() % NWarp; - // construct A-warp-window + // Construct A-warp-window auto a_warp_window_tmp = make_tile_window( a_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), @@ -158,13 +155,12 @@ struct BlockGemmASmemBSmemCReg static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { a_warp_windows(mIter)(kIter) = a_warp_window_tmp; - move_tile_window(a_warp_windows(mIter)(kIter), {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); }); }); - // construct B-warp-window + // Construct B-warp-window auto b_warp_window_tmp = make_tile_window( b_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), @@ -180,16 +176,16 @@ struct BlockGemmASmemBSmemCReg static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - move_tile_window(b_warp_windows(nIter)(kIter), {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); }); }); +#endif // hot loop: static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block tensor + // Read A warp tensor from A block tensor AWarpTensor a_warp_tensor; #if defined(ENABLE_PREFETCH) #pragma message("local data share prefetch") @@ -200,7 +196,7 @@ struct BlockGemmASmemBSmemCReg a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); #endif static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B block tensor + // Read B warp tensor from B block tensor BWarpTensor b_warp_tensor; #if defined(ENABLE_PREFETCH) b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data( @@ -209,17 +205,17 @@ struct BlockGemmASmemBSmemCReg #else b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); #endif - // read C warp tensor from C block tensor + // Read C warp tensor from C block tensor CWarpTensor c_warp_tensor; c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM + // Warp GEMM WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // write C warp tensor into C block tensor + // Write C warp tensor into C block tensor c_block_tensor.set_y_sliced_thread_data( merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), @@ -231,8 +227,8 @@ struct BlockGemmASmemBSmemCReg // C = A * B template - CK_TILE_DEVICE auto operator()(const ABlockWindowTmp& a_block_window_tmp, - const BBlockWindowTmp& b_block_window_tmp) const + CK_TILE_DEVICE auto operator()([[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp, + [[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const { static_assert(std::is_same_v && std::is_same_v, @@ -246,14 +242,11 @@ struct BlockGemmASmemBSmemCReg KPerBlock == BlockGemmShape::kK, "wrong!"); - // constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); - - // using WarpGemm = remove_cvref_t{}))>; - constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; +#if !defined(ENABLE_PREFETCH) constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; @@ -261,7 +254,7 @@ struct BlockGemmASmemBSmemCReg const index_t iMWarp = get_warp_id() / NWarp; const index_t iNWarp = get_warp_id() % NWarp; - // construct A-warp-window + // Construct A-warp-window auto a_warp_window_tmp = make_tile_window( a_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), @@ -277,13 +270,12 @@ struct BlockGemmASmemBSmemCReg static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { a_warp_windows(mIter)(kIter) = a_warp_window_tmp; - move_tile_window(a_warp_windows(mIter)(kIter), {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); }); }); - // construct B-warp-window + // Construct B-warp-window auto b_warp_window_tmp = make_tile_window( b_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), @@ -299,11 +291,11 @@ struct BlockGemmASmemBSmemCReg static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - move_tile_window(b_warp_windows(nIter)(kIter), {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); }); }); +#endif static_assert(std::is_same_v, "wrong!"); @@ -323,10 +315,10 @@ struct BlockGemmASmemBSmemCReg auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); - // hot loop: + // Hot loop: static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block tensor + // Read A warp tensor from A block tensor AWarpTensor a_warp_tensor; #if defined(ENABLE_PREFETCH) a_warp_tensor.get_thread_buffer() = aWarpTile.get_y_sliced_thread_data( @@ -336,7 +328,7 @@ struct BlockGemmASmemBSmemCReg a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); #endif static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B block tensor + // Read B warp tensor from B block tensor BWarpTensor b_warp_tensor; #if defined(ENABLE_PREFETCH) b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data( @@ -345,10 +337,10 @@ struct BlockGemmASmemBSmemCReg #else b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); #endif - // read C warp tensor from C block tensor + // Read C warp tensor from C block tensor CWarpTensor c_warp_tensor; - // warp GEMM + // Warp GEMM if constexpr(KIterPerWarp == 0) { // c = a * b @@ -364,7 +356,7 @@ struct BlockGemmASmemBSmemCReg WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); } - // write C warp tensor into C block tensor + // Write C warp tensor into C block tensor c_block_tensor.set_y_sliced_thread_data( merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp index 57a0614c7f..effcc2b101 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp @@ -307,7 +307,7 @@ struct BlockGemmPipelineAGmemBGmemCReg // Gemm pipeline start #if defined(ENABLE_PREFETCH) - +#pragma message("global prefetch") // Initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -342,7 +342,7 @@ struct BlockGemmPipelineAGmemBGmemCReg // Main body if(num_loop > 2) { - index_t i = 0; + index_t iCounter = 0; do { block_sync_lds();