diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp index fc98792615..92b98a2f11 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -10,200 +10,6 @@ namespace ck_tile { -// A Tile Window: global memory -// B Tile Window: global memory -// C Distributed tensor: register -template -struct BlockGemmPipelineAGmemBGmemCReg -{ - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - using Policy = BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy; - - static constexpr index_t kBlockSize = Problem::kBlockSize; - - static constexpr index_t kMPerBlock = BlockGemmShape::kM; - static constexpr index_t kNPerBlock = BlockGemmShape::kN; - static constexpr index_t kKPerBlock = BlockGemmShape::kK; - - // Move this part into Policy? - __host__ __device__ static constexpr index_t GetStaticLdsSize() - { - return sizeof(BDataType) * - Policy::template MakeBLdsBlockDescriptor().get_element_space_size(); - } - - template - __host__ __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, - const BElementFunction& b_element_func, - index_t num_loop, - void* p_smem) const - { - static_assert( - std::is_same_v> && - std::is_same_v>, - "wrong!"); - - static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], - "wrong!"); - - // A tile in Reg,blockTensor - // This tensor distribution used to construct both distributed tensor for local buffer store - // and read. without buffer address info - constexpr auto a_reg_block_dstr = Policy::template MakeARegBlockDescriptor(); - - // B tile in LDS, blockWindow - BDataType* p_b_lds = - static_cast(static_cast(static_cast(p_smem))); - - constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); - - // This tensor view used to construct both tile window for lds store and read, with buffer - // address info - auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); - - // A DRAM tile window for load - auto a_copy_dram_window = - make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp.get_window_origin(), - Policy::template MakeADramTileDistribution()); - - // A Reg tensor for store, also used for block GEMM - auto a_copy_reg_tensor = make_static_distributed_tensor(a_reg_block_dstr); - - // B DRAM tile window for load - auto b_copy_dram_window = - make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_dram_block_window_tmp.get_window_origin(), - Policy::template MakeBDramTileDistribution()); - - // B LDS tile window for store - auto b_copy_lds_window = - make_tile_window(b_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - b_copy_dram_window.get_tile_distribution()); - - // B LDS tile for block GEMM - auto b_lds_gemm_window = make_tile_window( - b_lds_block, make_tuple(number{}, number{}), {0, 0}); - - // Block GEMM - constexpr auto block_gemm = Policy::template GetBlockGemm(); - - // Acc register tile - auto c_block_tile = decltype(block_gemm(a_copy_reg_tensor, b_lds_gemm_window)){}; - - // prefetch - // global read 0 - auto a_block_tile = load_tile(a_copy_dram_window); - auto b_block_tile = load_tile(b_copy_dram_window); - - { - // move to 1 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); - - // Initialize C - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - - // block buffer write 0 - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - // store_tile -> shuffle store tile - store_tile(a_copy_reg_tensor, a_block_tile_tmp); - // global read 1 - a_block_tile = load_tile(a_copy_dram_window); - - // LDS write 0 - const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); - store_tile(b_copy_lds_window, b_block_tile_tmp); - // global read 1 - b_block_tile = load_tile(b_copy_dram_window); - } - - index_t iCounter = num_loop - 2; - - do - { - block_sync_lds(); - - // GEMM i - block_gemm(c_block_tile, a_copy_reg_tensor, b_lds_gemm_window); - - block_sync_lds(); - - // move to i + 2 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); - - // LDS write i + 1 - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_reg_tensor, a_block_tile_tmp); - // global read i + 2 - a_block_tile = load_tile(a_copy_dram_window); - - // LDS write i + 1 - const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); - store_tile(b_copy_lds_window, b_block_tile_tmp); - // global read i + 2 - b_block_tile = load_tile(b_copy_dram_window); - - iCounter--; - - } while(iCounter > 0); - - // tail - { - block_sync_lds(); - - // GEMM num_loop - 2 - block_gemm(c_block_tile, a_copy_reg_tensor, b_lds_gemm_window); - - block_sync_lds(); - - // LDS write num_loop - 1 - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_reg_tensor, a_block_tile_tmp); - - const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); - store_tile(b_copy_lds_window, b_block_tile_tmp); - - block_sync_lds(); - - // GEMM num_loop - 1 - block_gemm(c_block_tile, a_copy_reg_tensor, b_lds_gemm_window); - } - - return c_block_tile; - } - - template - __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, - index_t num_loop, - void* p_smem) const - { - return operator()( - a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, - b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, - num_loop, - p_smem); - } -}; - // A Tile Window: global memory // B Tile Window: global memory // C Distributed tensor: register @@ -248,12 +54,12 @@ struct BlockGemmPipelineAGmemBGmemCReg< { static_assert( std::is_same_v> && - std::is_same_v>, + std::is_same_v>, "wrong!"); static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); ignore = a_element_func; @@ -312,12 +118,13 @@ struct BlockGemmPipelineAGmemBGmemCReg< auto a_block_tile = load_tile(a_copy_dram_window); auto b_block_tile = load_tile(b_copy_dram_window); + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + if constexpr(k_loops > 1) { move_tile_window(a_copy_dram_window, {0, kKPerBlock}); move_tile_window(b_copy_dram_window, {0, kKPerBlock}); - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - set_slice_tile(a_copy_reg_tensor, a_block_tile, sequence<0, 0>{}, @@ -327,6 +134,7 @@ struct BlockGemmPipelineAGmemBGmemCReg< store_tile(b_copy_lds_window, b_block_tile); b_block_tile = load_tile(b_copy_dram_window); } + if constexpr(k_loops > 2) { static_for<0, k_loops - 2, 1>{}([&](auto i_k0) { @@ -334,7 +142,7 @@ struct BlockGemmPipelineAGmemBGmemCReg< block_gemm(c_block_tile, get_slice_tile(a_copy_reg_tensor, - sequence<0, (i_k0)*kKPerBlock>{}, + sequence<0, i_k0 * kKPerBlock>{}, sequence{}), b_copy_lds_window); @@ -356,15 +164,18 @@ struct BlockGemmPipelineAGmemBGmemCReg< // tail { - block_sync_lds(); + if constexpr(k_loops > 1) + { + block_sync_lds(); - block_gemm(c_block_tile, - get_slice_tile(a_copy_reg_tensor, - sequence<0, (k_loops - 2) * kKPerBlock>{}, - sequence{}), - b_copy_lds_window); + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + sequence<0, (k_loops - 2) * kKPerBlock>{}, + sequence{}), + b_copy_lds_window); - block_sync_lds(); + block_sync_lds(); + } set_slice_tile(a_copy_reg_tensor, a_block_tile, @@ -378,11 +189,10 @@ struct BlockGemmPipelineAGmemBGmemCReg< block_gemm(c_block_tile, get_slice_tile(a_copy_reg_tensor, sequence<0, (k_loops - 1) * kKPerBlock>{}, - sequence{}), + sequence{}), b_copy_lds_window); } - // store_tile(a_reg_block_tensor_tmp, a_copy_reg_tensor); set_slice_tile(a_reg_block_tensor_tmp, a_copy_reg_tensor, sequence<0, 0>{}, @@ -402,7 +212,7 @@ struct BlockGemmPipelineAGmemBGmemCReg< "wrong!"); static_assert(kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); ignore = b_element_func; @@ -414,7 +224,6 @@ struct BlockGemmPipelineAGmemBGmemCReg< // A Reg tensor for store, also used for block GEMM auto a_copy_reg_tensor = make_static_distributed_tensor(a_reg_block_dstr); - // store_tile(a_copy_reg_tensor, a_reg_block_tensor_tmp); set_slice_tile(a_copy_reg_tensor, a_reg_block_tensor_tmp, @@ -458,14 +267,16 @@ struct BlockGemmPipelineAGmemBGmemCReg< b_lds_gemm_window)){}; auto b_block_tile = load_tile(b_copy_dram_window); + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + if constexpr(k_loops > 1) { move_tile_window(b_copy_dram_window, {0, kKPerBlock}); - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - store_tile(b_copy_lds_window, b_block_tile); b_block_tile = load_tile(b_copy_dram_window); } + if constexpr(k_loops > 2) { static_for<0, k_loops - 2, 1>{}([&](auto i_k0) { @@ -473,7 +284,7 @@ struct BlockGemmPipelineAGmemBGmemCReg< block_gemm(c_block_tile, get_slice_tile(a_copy_reg_tensor, - sequence<0, (i_k0)*kKPerBlock>{}, + sequence<0, i_k0 * kKPerBlock>{}, sequence{}), b_copy_lds_window); @@ -488,16 +299,18 @@ struct BlockGemmPipelineAGmemBGmemCReg< // tail { - block_sync_lds(); + if constexpr(k_loops > 1) + { + block_sync_lds(); - block_gemm(c_block_tile, - get_slice_tile(a_copy_reg_tensor, - sequence<0, (k_loops - 2) * kKPerBlock>{}, - sequence{}), - b_copy_lds_window); - - block_sync_lds(); + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + sequence<0, (k_loops - 2) * kKPerBlock>{}, + sequence{}), + b_copy_lds_window); + block_sync_lds(); + } store_tile(b_copy_lds_window, b_block_tile); block_sync_lds(); @@ -505,7 +318,7 @@ struct BlockGemmPipelineAGmemBGmemCReg< block_gemm(c_block_tile, get_slice_tile(a_copy_reg_tensor, sequence<0, (k_loops - 1) * kKPerBlock>{}, - sequence{}), + sequence{}), b_copy_lds_window); } diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp index fc98792615..92b98a2f11 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -10,200 +10,6 @@ namespace ck_tile { -// A Tile Window: global memory -// B Tile Window: global memory -// C Distributed tensor: register -template -struct BlockGemmPipelineAGmemBGmemCReg -{ - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - using Policy = BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy; - - static constexpr index_t kBlockSize = Problem::kBlockSize; - - static constexpr index_t kMPerBlock = BlockGemmShape::kM; - static constexpr index_t kNPerBlock = BlockGemmShape::kN; - static constexpr index_t kKPerBlock = BlockGemmShape::kK; - - // Move this part into Policy? - __host__ __device__ static constexpr index_t GetStaticLdsSize() - { - return sizeof(BDataType) * - Policy::template MakeBLdsBlockDescriptor().get_element_space_size(); - } - - template - __host__ __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, - const BElementFunction& b_element_func, - index_t num_loop, - void* p_smem) const - { - static_assert( - std::is_same_v> && - std::is_same_v>, - "wrong!"); - - static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], - "wrong!"); - - // A tile in Reg,blockTensor - // This tensor distribution used to construct both distributed tensor for local buffer store - // and read. without buffer address info - constexpr auto a_reg_block_dstr = Policy::template MakeARegBlockDescriptor(); - - // B tile in LDS, blockWindow - BDataType* p_b_lds = - static_cast(static_cast(static_cast(p_smem))); - - constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); - - // This tensor view used to construct both tile window for lds store and read, with buffer - // address info - auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); - - // A DRAM tile window for load - auto a_copy_dram_window = - make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp.get_window_origin(), - Policy::template MakeADramTileDistribution()); - - // A Reg tensor for store, also used for block GEMM - auto a_copy_reg_tensor = make_static_distributed_tensor(a_reg_block_dstr); - - // B DRAM tile window for load - auto b_copy_dram_window = - make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_dram_block_window_tmp.get_window_origin(), - Policy::template MakeBDramTileDistribution()); - - // B LDS tile window for store - auto b_copy_lds_window = - make_tile_window(b_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - b_copy_dram_window.get_tile_distribution()); - - // B LDS tile for block GEMM - auto b_lds_gemm_window = make_tile_window( - b_lds_block, make_tuple(number{}, number{}), {0, 0}); - - // Block GEMM - constexpr auto block_gemm = Policy::template GetBlockGemm(); - - // Acc register tile - auto c_block_tile = decltype(block_gemm(a_copy_reg_tensor, b_lds_gemm_window)){}; - - // prefetch - // global read 0 - auto a_block_tile = load_tile(a_copy_dram_window); - auto b_block_tile = load_tile(b_copy_dram_window); - - { - // move to 1 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); - - // Initialize C - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - - // block buffer write 0 - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - // store_tile -> shuffle store tile - store_tile(a_copy_reg_tensor, a_block_tile_tmp); - // global read 1 - a_block_tile = load_tile(a_copy_dram_window); - - // LDS write 0 - const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); - store_tile(b_copy_lds_window, b_block_tile_tmp); - // global read 1 - b_block_tile = load_tile(b_copy_dram_window); - } - - index_t iCounter = num_loop - 2; - - do - { - block_sync_lds(); - - // GEMM i - block_gemm(c_block_tile, a_copy_reg_tensor, b_lds_gemm_window); - - block_sync_lds(); - - // move to i + 2 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); - - // LDS write i + 1 - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_reg_tensor, a_block_tile_tmp); - // global read i + 2 - a_block_tile = load_tile(a_copy_dram_window); - - // LDS write i + 1 - const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); - store_tile(b_copy_lds_window, b_block_tile_tmp); - // global read i + 2 - b_block_tile = load_tile(b_copy_dram_window); - - iCounter--; - - } while(iCounter > 0); - - // tail - { - block_sync_lds(); - - // GEMM num_loop - 2 - block_gemm(c_block_tile, a_copy_reg_tensor, b_lds_gemm_window); - - block_sync_lds(); - - // LDS write num_loop - 1 - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_reg_tensor, a_block_tile_tmp); - - const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); - store_tile(b_copy_lds_window, b_block_tile_tmp); - - block_sync_lds(); - - // GEMM num_loop - 1 - block_gemm(c_block_tile, a_copy_reg_tensor, b_lds_gemm_window); - } - - return c_block_tile; - } - - template - __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, - index_t num_loop, - void* p_smem) const - { - return operator()( - a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, - b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, - num_loop, - p_smem); - } -}; - // A Tile Window: global memory // B Tile Window: global memory // C Distributed tensor: register @@ -248,12 +54,12 @@ struct BlockGemmPipelineAGmemBGmemCReg< { static_assert( std::is_same_v> && - std::is_same_v>, + std::is_same_v>, "wrong!"); static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); ignore = a_element_func; @@ -312,12 +118,13 @@ struct BlockGemmPipelineAGmemBGmemCReg< auto a_block_tile = load_tile(a_copy_dram_window); auto b_block_tile = load_tile(b_copy_dram_window); + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + if constexpr(k_loops > 1) { move_tile_window(a_copy_dram_window, {0, kKPerBlock}); move_tile_window(b_copy_dram_window, {0, kKPerBlock}); - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - set_slice_tile(a_copy_reg_tensor, a_block_tile, sequence<0, 0>{}, @@ -327,6 +134,7 @@ struct BlockGemmPipelineAGmemBGmemCReg< store_tile(b_copy_lds_window, b_block_tile); b_block_tile = load_tile(b_copy_dram_window); } + if constexpr(k_loops > 2) { static_for<0, k_loops - 2, 1>{}([&](auto i_k0) { @@ -334,7 +142,7 @@ struct BlockGemmPipelineAGmemBGmemCReg< block_gemm(c_block_tile, get_slice_tile(a_copy_reg_tensor, - sequence<0, (i_k0)*kKPerBlock>{}, + sequence<0, i_k0 * kKPerBlock>{}, sequence{}), b_copy_lds_window); @@ -356,15 +164,18 @@ struct BlockGemmPipelineAGmemBGmemCReg< // tail { - block_sync_lds(); + if constexpr(k_loops > 1) + { + block_sync_lds(); - block_gemm(c_block_tile, - get_slice_tile(a_copy_reg_tensor, - sequence<0, (k_loops - 2) * kKPerBlock>{}, - sequence{}), - b_copy_lds_window); + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + sequence<0, (k_loops - 2) * kKPerBlock>{}, + sequence{}), + b_copy_lds_window); - block_sync_lds(); + block_sync_lds(); + } set_slice_tile(a_copy_reg_tensor, a_block_tile, @@ -378,11 +189,10 @@ struct BlockGemmPipelineAGmemBGmemCReg< block_gemm(c_block_tile, get_slice_tile(a_copy_reg_tensor, sequence<0, (k_loops - 1) * kKPerBlock>{}, - sequence{}), + sequence{}), b_copy_lds_window); } - // store_tile(a_reg_block_tensor_tmp, a_copy_reg_tensor); set_slice_tile(a_reg_block_tensor_tmp, a_copy_reg_tensor, sequence<0, 0>{}, @@ -402,7 +212,7 @@ struct BlockGemmPipelineAGmemBGmemCReg< "wrong!"); static_assert(kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); ignore = b_element_func; @@ -414,7 +224,6 @@ struct BlockGemmPipelineAGmemBGmemCReg< // A Reg tensor for store, also used for block GEMM auto a_copy_reg_tensor = make_static_distributed_tensor(a_reg_block_dstr); - // store_tile(a_copy_reg_tensor, a_reg_block_tensor_tmp); set_slice_tile(a_copy_reg_tensor, a_reg_block_tensor_tmp, @@ -458,14 +267,16 @@ struct BlockGemmPipelineAGmemBGmemCReg< b_lds_gemm_window)){}; auto b_block_tile = load_tile(b_copy_dram_window); + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + if constexpr(k_loops > 1) { move_tile_window(b_copy_dram_window, {0, kKPerBlock}); - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - store_tile(b_copy_lds_window, b_block_tile); b_block_tile = load_tile(b_copy_dram_window); } + if constexpr(k_loops > 2) { static_for<0, k_loops - 2, 1>{}([&](auto i_k0) { @@ -473,7 +284,7 @@ struct BlockGemmPipelineAGmemBGmemCReg< block_gemm(c_block_tile, get_slice_tile(a_copy_reg_tensor, - sequence<0, (i_k0)*kKPerBlock>{}, + sequence<0, i_k0 * kKPerBlock>{}, sequence{}), b_copy_lds_window); @@ -488,16 +299,18 @@ struct BlockGemmPipelineAGmemBGmemCReg< // tail { - block_sync_lds(); + if constexpr(k_loops > 1) + { + block_sync_lds(); - block_gemm(c_block_tile, - get_slice_tile(a_copy_reg_tensor, - sequence<0, (k_loops - 2) * kKPerBlock>{}, - sequence{}), - b_copy_lds_window); - - block_sync_lds(); + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + sequence<0, (k_loops - 2) * kKPerBlock>{}, + sequence{}), + b_copy_lds_window); + block_sync_lds(); + } store_tile(b_copy_lds_window, b_block_tile); block_sync_lds(); @@ -505,7 +318,7 @@ struct BlockGemmPipelineAGmemBGmemCReg< block_gemm(c_block_tile, get_slice_tile(a_copy_reg_tensor, sequence<0, (k_loops - 1) * kKPerBlock>{}, - sequence{}), + sequence{}), b_copy_lds_window); }