From afd3d2bd10da068dcfe891d6c7b0a4e428e0be47 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Wed, 28 Jan 2026 17:15:05 +0000 Subject: [PATCH] Clean up pipeline --- include/ck_tile/core/numeric/type_convert.hpp | 6 - .../block_universal_gemm_as_bs_bquant_cr.hpp | 2 +- .../pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp | 409 +++++++++--------- 3 files changed, 208 insertions(+), 209 deletions(-) diff --git a/include/ck_tile/core/numeric/type_convert.hpp b/include/ck_tile/core/numeric/type_convert.hpp index a581eab62d..634b845725 100644 --- a/include/ck_tile/core/numeric/type_convert.hpp +++ b/include/ck_tile/core/numeric/type_convert.hpp @@ -69,12 +69,6 @@ CK_TILE_TYPE_CONVERT(fp16x2_t, fp16x2, fp32x2_t, fp32x2) CK_TILE_TYPE_CONVERT(bf16x2_t, bf16x2, fp32x2_t, fp32x2) #undef CK_TILE_TYPE_CONVERT -template <> -CK_TILE_HOST_DEVICE constexpr bf16_t type_convert(bf8_t x) -{ - return float_to_bf16(bf8_to_float(x)); -} - } // namespace ck_tile #include "ck_tile/core/numeric/pk_fp4.hpp" diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 65a9281166..51bc35efe9 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -239,7 +239,7 @@ struct BQuantBlockUniversalGemmAsBsCr bool BLoadTranspose = false> CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, const BSmemBlockWindow& b_block_window, - BQRegBlockTile& bq_block_tensor, + const BQRegBlockTile& bq_block_tensor, bool_constant = {}, bool_constant = {}) { diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp index 03381f569a..de92d45763 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp @@ -27,9 +27,12 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; - using BDqDataType = std::conditional_t, - remove_cvref_t, - BDataType>; + using BDqDataType = remove_cvref_t; + + static constexpr bool IsCastBeforeLDS = Problem::BCastPolicy == CastPolicy::BeforeLDSWrite; + + using BLDSType = std::conditional_t; + using BQDataType = remove_cvref_t; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; @@ -43,17 +46,16 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 static constexpr index_t APackedSize = ck_tile::numeric_traits>::PackedSize; - static constexpr index_t BDqPackedSize = - ck_tile::numeric_traits>::PackedSize; static constexpr index_t BPackedSize = - std::is_same_v - ? 2 - : ck_tile::numeric_traits>::PackedSize; + ck_tile::numeric_traits>::PackedSize; static constexpr index_t BQPackedSize = ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BLDSPackedSize = + ck_tile::numeric_traits>::PackedSize; + using ALayout = remove_cvref_t; using BQLayout = remove_cvref_t; using BLayout = remove_cvref_t; @@ -90,8 +92,6 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 static constexpr auto TailNum = Problem::TailNum; static constexpr auto Scheduler = Problem::Scheduler; - static constexpr bool IsCastBeforeLDS = Problem::BCastPolicy == CastPolicy::BeforeLDSWrite; - using Base::PrefetchStages; [[nodiscard]] CK_TILE_HOST static const std::string GetName() @@ -175,6 +175,11 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { using Base = PipelineImplBase; + static constexpr bool is_a_col_major = + std::is_same_v; + static constexpr bool is_b_row_major = + std::is_same_v; + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() { constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; @@ -217,7 +222,7 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num : A_LDS_Read_Inst_Num / 2; constexpr auto num_ds_read_inst_b = - B_LDS_Read_Width * sizeof(BDqDataType) / BDqPackedSize == 16 + B_LDS_Read_Width * sizeof(BLDSType) / BLDSPackedSize == 16 ? B_LDS_Read_Inst_Num : B_LDS_Read_Inst_Num / 2; @@ -233,7 +238,7 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 constexpr auto ds_read_a_issue_cycle = A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4; constexpr auto ds_read_b_issue_cycle = - B_LDS_Read_Width * sizeof(BDqDataType) / BDqPackedSize == 16 ? 8 : 4; + B_LDS_Read_Width * sizeof(BLDSType) / BLDSPackedSize == 16 ? 8 : 4; constexpr auto ds_read_a_mfma_rate = (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); constexpr auto ds_read_b_mfma_rate = @@ -316,6 +321,139 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 }); } + template + CK_TILE_DEVICE static void + ScaleTile(TileType& block_tile, CastTileType& block_tile_cast, ScaleTileType& scale_tile) + { + if constexpr(IsCastBeforeLDS) + { + constexpr auto b_block = TileType::get_distributed_spans(); + constexpr auto idx1_js = tile_distributed_index<0>{}; + + // Internally this is using V_CVT_SCALEF32_PK_BF16_FP4 or V_CVT_SCALEF32_PK_FP16_FP4 + // on gfx950 + auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) { + if constexpr(std::is_same_v) + { + return pk_fp4_to_fp16x2(pk_mxfp4, fscale); + } + else if constexpr(std::is_same_v) + { + return pk_fp4_to_bf16x2(pk_mxfp4, fscale); + } + else + { + static_assert(false, "unsupported compute type"); + } + }; + + sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { + sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); + auto scale = scale_tile(i_j_idx_scale); + auto b_scale_uint = uint32_t(scale.data) << 23; + if constexpr(std::is_same_v) + { + if constexpr(idx1.impl_.at(0) % BPackedSize == 0) + { + constexpr auto idx1_lo = tile_distributed_index{}; + constexpr auto idx1_hi = + tile_distributed_index{}; + constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); + constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); + auto b_pack = block_tile(i_j_idx); + auto cvt = + pk_mxfp4_to_compute_v2(b_pack, bit_cast(b_scale_uint)); + block_tile_cast(i_j_idx_lo) = cvt.x; + block_tile_cast(i_j_idx_hi) = cvt.y; + } + } + else + { + auto b_pack = block_tile(i_j_idx); + block_tile_cast(i_j_idx) = type_convert( + type_convert(b_pack) * bit_cast(b_scale_uint)); + } + }); + }); + } + } + + template + CK_TILE_DEVICE void ALocalPrefill(WindowType& lds_window, + const TileType& block_tile, + const ElementwiseFunc& element_func) const + { + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, block_tile); + Base::LocalPrefill(lds_window, a_shuffle_tmp, element_func); + } + else + { + Base::LocalPrefill(lds_window, block_tile, element_func); + } + } + + template + CK_TILE_DEVICE void BLocalPrefill(WindowType& lds_window, + const TileType& block_tile, + const TileTypeCast& block_tile_cast, + const ElementwiseFunc& element_func) const + { + // Fill LDS and apply the scale if IsCastBeforeLDS + auto get_b_block_tile = [](auto& b_block_tile_orig, auto& b_block_tile_cast) { + if constexpr(IsCastBeforeLDS) + { + return b_block_tile_cast; + } + else + { + return b_block_tile_orig; + } + }; + + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, get_b_block_tile(block_tile, block_tile_cast)); + Base::LocalPrefill(lds_window, b_shuffle_tmp, element_func); + } + else + { + Base::LocalPrefill( + lds_window, get_b_block_tile(block_tile, block_tile_cast), element_func); + } + } + + template + CK_TILE_DEVICE void LocalPrefetch(BlockGemmType& block_gemm, + const AWindowType& a_lds_window, + const BWindowType& b_lds_window, + const QTileType& q_block_tile) const + { + // Load from LDS + // It can apply the scale and cast if we scale after reading from LDS + if constexpr(IsCastBeforeLDS) + { + block_gemm.LocalPrefetch(a_lds_window, b_lds_window); + } + else + { + block_gemm.LocalPrefetch(a_lds_window, b_lds_window, q_block_tile); + } + } + template index_t num_loop, void* p_smem) const { + // ----------------------------------------------------------------------------------------- + // Pipeline checks static_assert( std::is_same_v> && std::is_same_v "A/B/BQ Dram block window should have the same data type as appropriate " "([A|B|BQ]DataType) defined in Problem definition!"); - constexpr bool is_a_col_major = - std::is_same_v; constexpr bool is_bq_col_major = std::is_same_v; - constexpr bool is_b_row_major = std::is_same_v; static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)"); static_assert(NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && @@ -393,6 +530,11 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 auto bq_block_tile = decltype(load_tile(bq_copy_dram_window)){}; + // This defines the scaled and casted block tile for B matrix. + // Effectively, it is used only if we scale and cast before writing to LDS. + auto bdq_block_tile = make_static_distributed_tensor( + Policy::template MakeBRegTileDistribution()); + // Block GEMM auto block_gemm = BlockGemm(); auto c_block_tile = block_gemm.MakeCBlockTile(); @@ -405,7 +547,7 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 decltype(make_static_distributed_tensor(BBlockTileDistr{})); ABlockTile a_block_tile; - BBlockTile b_fp4_block_tile; + BBlockTile b_block_tile; using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; @@ -419,137 +561,44 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 // ----------------------------------------------------------------------------------------- // Gemm pipeline start - // prefetch - // global read 0 - // auto a_scale_block_tile = decltype(load_tile(a_scale_copy_dram_window)){}; - Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); - // BDataType - auto b_block_tile = make_static_distributed_tensor( - Policy::template MakeBRegTileDistribution()); + // prefetch stages + // Vmem -> Vgpr 0 + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + + // Vmem -> Vgpr 0 (Q matrix) + // Scale and cast tile before writing to LDS (if IsCastBeforeLDS) bq_block_tile = load_tile(bq_copy_dram_window); move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); + ScaleTile(b_block_tile, bdq_block_tile, bq_block_tile); - constexpr auto idx1_js = tile_distributed_index<0>{}; - constexpr auto b_block = decltype(b_fp4_block_tile)::get_distributed_spans(); - - // Internally this is using V_CVT_SCALEF32_PK_BF16_FP4 or V_CVT_SCALEF32_PK_FP16_FP4 on - // gfx950 - auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) { - if constexpr(std::is_same_v) - { - return pk_fp4_to_fp16x2(pk_mxfp4, fscale); - } - else if constexpr(std::is_same_v) - { - return pk_fp4_to_bf16x2(pk_mxfp4, fscale); - } - else - { - static_assert(sizeof(pk_mxfp4) == 0, "unsupported compute type"); - } - }; - - auto get_b_block_tile = [](auto& b_block_tile_orig, auto& b_block_tile_cast) { - if constexpr(IsCastBeforeLDS) - { - return b_block_tile_cast; - } - else - { - return b_block_tile_orig; - } - }; - - auto apply_scale_func = [&]() { - sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { - sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); - auto scale = bq_block_tile(i_j_idx_scale); - auto b_scale_uint = uint32_t(scale.data) << 23; - if constexpr(std::is_same_v) - { - if constexpr(idx1.impl_.at(0) % BPackedSize == 0) - { - constexpr auto idx1_lo = tile_distributed_index{}; - constexpr auto idx1_hi = - tile_distributed_index{}; - constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); - constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); - auto b_pack = b_fp4_block_tile(i_j_idx); - - auto cvt = - pk_mxfp4_to_compute_v2(b_pack, bit_cast(b_scale_uint)); - b_block_tile(i_j_idx_lo) = cvt.x; - b_block_tile(i_j_idx_hi) = cvt.y; - } - } - else - { - auto b_pack = b_fp4_block_tile(i_j_idx); - b_block_tile(i_j_idx) = type_convert( - type_convert(b_pack) * bit_cast(b_scale_uint)); - } - }); - }); - }; - - if constexpr(IsCastBeforeLDS) - apply_scale_func(); - - // initialize C + // initialize C tile to zero tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); block_sync_lds(); - // LDS write 0 - if constexpr(is_a_col_major) - { - auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); - } - else - { - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); - } - - if constexpr(is_b_row_major) - { - auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegTileDistribution()); - auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile); - transpose_tile2d(b_shuffle_tmp, b_block_tile_); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); - } - else - { - auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile); - Base::LocalPrefill(b_copy_lds_window, b_block_tile_, b_element_func); - } + // Vgpr -> LDS 0 + ALocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + BLocalPrefill(b_copy_lds_window, b_block_tile, bdq_block_tile, b_element_func); + // Vmem -> Vgpr 1 Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + // If we scale and cast before writing to LDS, + // we need to read another tile of Q matrix from Vmem, then scale and cast tile if constexpr(IsCastBeforeLDS) { bq_block_tile = load_tile(bq_copy_dram_window); move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); - - apply_scale_func(); - - block_sync_lds(); - - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); } - else - { - block_sync_lds(); + ScaleTile(b_block_tile, bdq_block_tile, bq_block_tile); + + block_sync_lds(); + + // LDS -> Vgpr 0 + LocalPrefetch(block_gemm, a_lds_gemm_window, b_lds_gemm_window, bq_block_tile); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window, bq_block_tile); - } __builtin_amdgcn_sched_barrier(0); // main body @@ -560,58 +609,34 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { block_sync_lds(); - if constexpr(is_a_col_major) - { - auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); - } - else - { - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); - } - if constexpr(is_b_row_major) - { - auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegTileDistribution()); - auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile); - transpose_tile2d(b_shuffle_tmp, b_block_tile_); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); - } - else - { - auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile); - Base::LocalPrefill(b_copy_lds_window, b_block_tile_, b_element_func); - } + // Vgpr -> LDS + ALocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + BLocalPrefill(b_copy_lds_window, b_block_tile, bdq_block_tile, b_element_func); + // Vmem -> Vgpr Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch( - b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + // Vmem -> Vgpr (Q matrix) + // Scale and cast tile before writing to LDS (if IsCastBeforeLDS) bq_block_tile = load_tile(bq_copy_dram_window); move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); + ScaleTile(b_block_tile, bdq_block_tile, bq_block_tile); - if constexpr(IsCastBeforeLDS) - apply_scale_func(); - + // Consume tile block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); - if constexpr(IsCastBeforeLDS) - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); - else - block_gemm.LocalPrefetch( - a_lds_gemm_window, b_lds_gemm_window, bq_block_tile); + // LDS -> Vgpr + LocalPrefetch(block_gemm, a_lds_gemm_window, b_lds_gemm_window, bq_block_tile); HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); i += 1; - // b_block_stride +=1; } while(i < (num_loop - 1)); } - // tile_elementwise_inout([](auto& c) { c = 0; }, acc_block_tile); + // tail if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) { @@ -621,50 +646,31 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 } else { + // If we scale and cast after reading from LDS, + // we didn't read the second tile of Q matrix from Vmem during prefetch stages, + // so we need to read the last tile here. + // This is not a problem because we have all block_gemm instructions to hide the + // latency. if constexpr(!IsCastBeforeLDS) { bq_block_tile = load_tile(bq_copy_dram_window); move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); } + // Consume second to last tile block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); - if constexpr(is_a_col_major) - { - auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); - } - else - { - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); - } - if constexpr(is_b_row_major) - { - auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegTileDistribution()); - auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile); - transpose_tile2d(b_shuffle_tmp, b_block_tile_); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); - } - else - { - auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile); - Base::LocalPrefill(b_copy_lds_window, b_block_tile_, b_element_func); - } + // Vgpr -> LDS last tile + ALocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + BLocalPrefill(b_copy_lds_window, b_block_tile, bdq_block_tile, b_element_func); block_sync_lds(); - if constexpr(IsCastBeforeLDS) - { - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); - } - else - { - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window, bq_block_tile); - } + // LDS -> Vgpr last tile + LocalPrefetch(block_gemm, a_lds_gemm_window, b_lds_gemm_window, bq_block_tile); + + // Consume last tile block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); } @@ -690,13 +696,12 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 void* p_smem, index_t n = 0) const { - using BElementwise = std::conditional_t; - ck_tile::ignore = n; + ck_tile::ignore = n; return PipelineImpl{}.template operator()( a_dram_block_window_tmp, [](const ADataType& a) { return a; }, b_dram_block_window_tmp, - [](const BElementwise& b) { return b; }, + [](const BLDSType& b) { return b; }, bq_dram_block_window_tmp, num_loop, p_smem);