diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 628e9194ae..cb452043d1 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -61,6 +61,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using BQDataType = remove_cvref_t; + using BLayout = remove_cvref_t; using BQLayout = remove_cvref_t; using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; @@ -156,9 +157,11 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase using CDataType = remove_cvref_t; // BDataType gets converted from PkInt4 during loading - using OverrideBDataType = - std::conditional_t, ADataType, BDataType>; - + using OverrideBDataType = std::conditional_t< + std::is_same_v && + std::is_same_v, + ADataType, + BDataType>; using Base = BlockGemmBQuantBase; using WarpGemm = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index 2c191cc2b4..f6ebbd9228 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -33,9 +33,17 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3; using QuantGroupSize = remove_cvref_t; + using ALayout = remove_cvref_t; + using BQLayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + // BDataType gets converted from PkInt4 during loading using OverrideBDataType = - std::conditional_t, ADataType, BDataType>; + std::conditional_t && + std::is_same_v, + ADataType, + BDataType>; static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!"); using I0 = number<0>; @@ -50,11 +58,6 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3>::PackedSize; - using ALayout = remove_cvref_t; - using BQLayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - using BlockGemm = remove_cvref_t())>; static constexpr index_t BlockSize = Problem::kBlockSize; @@ -184,6 +187,23 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(b_block_tile, b_dram_window); } + template + CK_TILE_DEVICE void + BGlobalPrefetch(BBlockTile_& b_block_tile, + BDramWindow& b_copy_dram_window, + const BDramTileWindowStep& b_dram_tile_window_step) const + { + if constexpr(!std::is_same_v) + { + LoadAndConvertBTile(b_block_tile, b_copy_dram_window); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + } + else + { + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + } + } + template (ABlockTileDistr{})); using BBlockTile = - decltype(make_static_distributed_tensor(BBlockTileDistr{})); + decltype(make_static_distributed_tensor(BBlockTileDistr{})); using BQBlockTile = decltype(make_static_distributed_tensor(BQBlockTileDistr{})); @@ -289,8 +309,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); @@ -322,8 +341,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3