From acd77f9c2c74ad786b9de3916d604311ee402b51 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Mon, 12 Jan 2026 18:17:03 +0000 Subject: [PATCH] Merge commit '5aaa0313503305ad697f6614836be87f8e0b281a' into develop --- .../block/block_universal_gemm_as_bs_cr.hpp | 91 +++---------------- .../ops/gemm/kernel/universal_gemm_kernel.hpp | 5 +- .../gemm/pipeline/gemm_pipeline_problem.hpp | 2 +- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 2 +- ..._pipeline_agmem_bgmem_creg_base_policy.hpp | 56 ++---------- .../gemm_aquant_pipeline_ag_bg_cr_mem.hpp | 4 +- 6 files changed, 22 insertions(+), 138 deletions(-) diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index f6e26ad206..79030fcd51 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -194,83 +194,6 @@ struct BlockUniversalGemmAsBsCr { }; - template - struct BlockGemmImpl - { - static constexpr auto ALdsTileDistr = - decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){}; - static constexpr auto BLdsTileDistr = - decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; - - using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); - using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); - - ALdsTile a_warp_tile_; - BLdsTile b_warp_tile_; - - // C += A * B - template - CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, - const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window, - bool_constant = {}, - bool_constant = {}) - { - static_assert(std::is_same_v, - "The CDataType as defined in traits should be the same as correspoinding " - "C block tensor data type!"); - static_assert(std::is_same_v && - std::is_same_v, - "The ADataType and BDataType as defined in " - "traits should be the same as correspoinding block window data type!"); - - load_int4_tile(a_warp_tile_, - a_block_window); - load_int4_tile(b_warp_tile_, - b_block_window); - // hot loop: - static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; - - a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B block tensor - BWarpTensor b_warp_tensor; - - b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - - // read C warp tensor from C block tensor- - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - }); - }); - } - }; - template struct BlockGemmImpl { @@ -450,7 +373,9 @@ struct BlockUniversalGemmAsBsCr // hot loop: static_for<0, KRepeat, 1>{}([&](auto kIter) { LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); - __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_sched_barrier( + 0); // Complete scheduling all pending instruction groups before this point + // NOTE: Synchronize threads in a workgroup at the start of each MAC // cluster, but except the first, as we can shorten non-MAC cluster a bit // and there's no observable negative impact. The desired effect is waves in @@ -460,8 +385,14 @@ struct BlockUniversalGemmAsBsCr // sync point. if constexpr(kIter.value != 0 || KRepeat == 1) { - __builtin_amdgcn_s_barrier(); - __builtin_amdgcn_sched_barrier(0); + // This pattern ensures: + // At runtime: All waves synchronize (hardware barrier) + // At compile-time: Instructions after the barrier don't get moved before it + // (scheduling barrier) + __builtin_amdgcn_s_barrier(); // Blocks execution until all waves (threads) in + // the workgroup reach this point + __builtin_amdgcn_sched_barrier( + 0); // Prevents instruction reordering across this boundary } static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) { diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 628f5f7dc8..9583ac8a3f 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -1035,7 +1035,6 @@ struct UniversalGemmKernel * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. * */ - template CK_TILE_DEVICE static void RunGemm(const std::array& as_ptr, const std::array& bs_ptr, const std::array& ds_ptr, @@ -1161,9 +1160,7 @@ struct UniversalGemmKernel // allocate LDS __shared__ char smem_ptr[GetSmemSize()]; - constexpr auto scheduler_type = - GemmPipeline::DoubleSmemBuffer || (GemmPipeline::NumWaveGroups == 1); - RunGemm( + RunGemm( as_ptr, bs_ptr, kargs.ds_ptr, e_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 46c1f69b12..3597590c0f 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -80,7 +80,7 @@ struct GemmPipelineProblemBase static constexpr bool kPadK = Traits::kPadK; static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer; - static constexpr auto Scheduler = GemmPipelineScheduler::Default; + static constexpr auto Scheduler = GemmPipelineScheduler::Intrawave; static constexpr index_t VectorLoadSize = Traits::_VectorSize; // In the base situation, the Preshuffle setting should be false. diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 6199142d98..e123cee9e1 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -112,7 +112,7 @@ struct UniversalGemmBasePolicy using ADataType = OverrideADataType; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPack = GetSmemPackA(); + constexpr index_t KPack = Derived::template GetSmemPackA(); if constexpr(is_a_load_tr) { diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index e90c6a27d7..1ff95b157c 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -14,56 +14,6 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy { using BasePolicy = UniversalGemmBasePolicy; - // 3d + padding - template - CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() - { - using namespace ck_tile; - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t kKPack = GetSmemPackA(); - using ADataType = remove_cvref_t; - - constexpr auto DataTypeSize = sizeof(ADataType); - constexpr auto MLdsLayer = - (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); - - constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc_0, - make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); - - constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); - - constexpr auto a_lds_block_desc = transform_tensor_descriptor( - a_lds_block_desc_xk0_mnldslayer_mn_xk1, - make_tuple( - make_merge_transform( - make_tuple(number{}, number{})), - make_merge_transform(make_tuple(number{}, number{}))), - make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return a_lds_block_desc; - } - template CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() { @@ -291,6 +241,12 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy } } + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + return GetBlockWeightPreshuffle(); + } + template CK_TILE_HOST_DEVICE static constexpr auto GetBlockWeightPreshuffle() { diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp index 91dfc8494a..2f6497fdba 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp @@ -164,7 +164,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem }; template <> - struct PipelineImpl : public PipelineImplBase + struct PipelineImpl : public PipelineImplBase { using Base = PipelineImplBase; @@ -491,7 +491,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem void* p_smem, index_t m = 0) const { - return PipelineImpl{} + return PipelineImpl{} .template operator()( a_dram_block_window_tmp, [](const BDataType& a) { return a; },