diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index f6e26ad206..79030fcd51 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -194,83 +194,6 @@ struct BlockUniversalGemmAsBsCr { }; - template - struct BlockGemmImpl - { - static constexpr auto ALdsTileDistr = - decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){}; - static constexpr auto BLdsTileDistr = - decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; - - using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); - using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); - - ALdsTile a_warp_tile_; - BLdsTile b_warp_tile_; - - // C += A * B - template - CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, - const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window, - bool_constant = {}, - bool_constant = {}) - { - static_assert(std::is_same_v, - "The CDataType as defined in traits should be the same as correspoinding " - "C block tensor data type!"); - static_assert(std::is_same_v && - std::is_same_v, - "The ADataType and BDataType as defined in " - "traits should be the same as correspoinding block window data type!"); - - load_int4_tile(a_warp_tile_, - a_block_window); - load_int4_tile(b_warp_tile_, - b_block_window); - // hot loop: - static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; - - a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B block tensor - BWarpTensor b_warp_tensor; - - b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - - // read C warp tensor from C block tensor- - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - }); - }); - } - }; - template struct BlockGemmImpl { @@ -450,7 +373,9 @@ struct BlockUniversalGemmAsBsCr // hot loop: static_for<0, KRepeat, 1>{}([&](auto kIter) { LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); - __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_sched_barrier( + 0); // Complete scheduling all pending instruction groups before this point + // NOTE: Synchronize threads in a workgroup at the start of each MAC // cluster, but except the first, as we can shorten non-MAC cluster a bit // and there's no observable negative impact. The desired effect is waves in @@ -460,8 +385,14 @@ struct BlockUniversalGemmAsBsCr // sync point. if constexpr(kIter.value != 0 || KRepeat == 1) { - __builtin_amdgcn_s_barrier(); - __builtin_amdgcn_sched_barrier(0); + // This pattern ensures: + // At runtime: All waves synchronize (hardware barrier) + // At compile-time: Instructions after the barrier don't get moved before it + // (scheduling barrier) + __builtin_amdgcn_s_barrier(); // Blocks execution until all waves (threads) in + // the workgroup reach this point + __builtin_amdgcn_sched_barrier( + 0); // Prevents instruction reordering across this boundary } static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) { diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 628f5f7dc8..9583ac8a3f 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -1035,7 +1035,6 @@ struct UniversalGemmKernel * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. * */ - template CK_TILE_DEVICE static void RunGemm(const std::array& as_ptr, const std::array& bs_ptr, const std::array& ds_ptr, @@ -1161,9 +1160,7 @@ struct UniversalGemmKernel // allocate LDS __shared__ char smem_ptr[GetSmemSize()]; - constexpr auto scheduler_type = - GemmPipeline::DoubleSmemBuffer || (GemmPipeline::NumWaveGroups == 1); - RunGemm( + RunGemm( as_ptr, bs_ptr, kargs.ds_ptr, e_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 46c1f69b12..3597590c0f 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -80,7 +80,7 @@ struct GemmPipelineProblemBase static constexpr bool kPadK = Traits::kPadK; static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer; - static constexpr auto Scheduler = GemmPipelineScheduler::Default; + static constexpr auto Scheduler = GemmPipelineScheduler::Intrawave; static constexpr index_t VectorLoadSize = Traits::_VectorSize; // In the base situation, the Preshuffle setting should be false. diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp index 91dfc8494a..2f6497fdba 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp @@ -164,7 +164,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem }; template <> - struct PipelineImpl : public PipelineImplBase + struct PipelineImpl : public PipelineImplBase { using Base = PipelineImplBase; @@ -491,7 +491,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem void* p_smem, index_t m = 0) const { - return PipelineImpl{} + return PipelineImpl{} .template operator()( a_dram_block_window_tmp, [](const BDataType& a) { return a; },