diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp index 2f6497fdba..650cd947f7 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp @@ -28,7 +28,11 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem using BDataType = remove_cvref_t; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using QuantGroupSize = remove_cvref_t; + using QuantGroupSize = remove_cvref_t; + // When ADataType is pk_int4_t, use BDataType instead for transpose operations + // since packed 4-bit integers cannot be directly transposed (requires at least 8-bit precision) + using OverrideADataType = + std::conditional_t, BDataType, ADataType>; static_assert(QuantGroupSize::kM == 1, "no block for M supported yet!"); static_assert(QuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!"); @@ -228,9 +232,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem "B block window has incorrect lengths for defined BLayout!"); // A/B tiles in LDS - using the same approach as regular gemm pipeline - auto ab_lds_blocks = Base::template GetABLdsTensorViews(p_smem); - auto& a_lds_block = ab_lds_blocks.at(I0{}); - auto& b_lds_block = ab_lds_blocks.at(I1{}); + auto ab_lds_blocks = + Base::template GetABLdsTensorViews(p_smem); + auto& a_lds_block = ab_lds_blocks.at(I0{}); + auto& b_lds_block = ab_lds_blocks.at(I1{}); // Tile distribution for load from lds constexpr auto a_lds_load_tile_distr = @@ -260,7 +265,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution()); using ABlockTile = - decltype(make_static_distributed_tensor(ABlockTileDistr{})); + decltype(make_static_distributed_tensor(ABlockTileDistr{})); using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr{})); using AQBlockTile = @@ -295,7 +300,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem // LDS prefill - VGPRs to LDS if constexpr(is_a_col_major && !is_a_load_tr_v()) { - auto a_shuffle_tmp = make_static_distributed_tensor( + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{})); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); @@ -346,7 +351,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem // Prepare next iteration data if constexpr(is_a_col_major && !is_a_load_tr_v()) { - auto a_shuffle_tmp = make_static_distributed_tensor( + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d( a_shuffle_tmp, @@ -406,7 +411,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem if constexpr(is_a_col_major && !is_a_load_tr_v()) { - auto a_shuffle_tmp = make_static_distributed_tensor( + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number{})); @@ -494,7 +499,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem return PipelineImpl{} .template operator()( a_dram_block_window_tmp, - [](const BDataType& a) { return a; }, + [](const OverrideADataType& a) { return a; }, b_dram_block_window_tmp, [](const BDataType& b) { return b; }, aq_dram_block_window_tmp, diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp index 22dd78e070..71e4a74400 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -25,7 +25,11 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using QuantGroupSize = remove_cvref_t; + using QuantGroupSize = remove_cvref_t; + // When ADataType is pk_int4_t, use BDataType instead for transpose operations + // since packed 4-bit integers cannot be directly transposed (requires at least 8-bit precision) + using OverrideADataType = + std::conditional_t, BDataType, ADataType>; static_assert(QuantGroupSize::kM == 1, "no block for M supported yet!"); static_assert(QuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!"); @@ -164,14 +168,17 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 - CK_TILE_DEVICE static void LoadAndConvertATile(ABlockTile_& a_block_tile, - const ADramWindow& a_dram_window) + template + CK_TILE_DEVICE static void + LoadAndConvertATile(ABlockTile_& a_block_tile, + ADramWindow& a_dram_window, + const DramTileWindowStep& dram_tile_window_step) { using DestDataType = typename ABlockTile_::DataType; using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType; constexpr index_t UnaryOpSize = 8; load_int4_tile(a_block_tile, a_dram_window); + move_tile_window(a_dram_window, dram_tile_window_step); } template (p_smem); + Base::template GetABLdsTensorViews(p_smem); constexpr auto a_lds_load_tile_distr = make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); @@ -241,11 +248,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(ABlockTileDistr{})); + decltype(make_static_distributed_tensor(ABlockTileDistr{})); using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr{})); using AQBlockTile = @@ -274,8 +278,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); @@ -306,8 +309,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); @@ -349,8 +351,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); @@ -430,10 +431,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}.template operator()( a_dram_block_window_tmp, - // Note: a_element_func takes BDataType (not ADataType) because A tiles are - // converted from ADataType (e.g., pk_int4_t) to BDataType (e.g., fp8) in - // LoadAndConvertATile before the element function is applied. - [](const BDataType& a) { return a; }, + [](const OverrideADataType& a) { return a; }, b_dram_block_window_tmp, [](const BDataType& b) { return b; }, aq_dram_block_window_tmp, @@ -476,7 +474,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](const OverrideADataType& a) { return a; }, b_dram_block_window_tmp, [](const BDataType& b) { return b; }, aq_dram_block_window_tmp,