diff --git a/CHANGELOG.md b/CHANGELOG.md index 8fd3c36255..b07e322fe1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ### Added * Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle. +* Added Col-Col-Row-Col layout support for aquant mode in blockscale GEMM. * Added support for mixed precision fp8 x bf8 universal GEMM and weight preshuffle GEMM * Added a compute async pipeline in the CK TILE universal GEMM on gfx950 * Added support for B Tensor type pk_int4_t in the CK TILE weight preshuffle GEMM. diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 44d0736ad3..396a54c7c2 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -21,7 +21,9 @@ template ; @@ -67,12 +69,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str using BaseGemmPipeline = std::conditional_t< GemmConfig::PreshuffleB == true, ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2, - std::conditional_t< - QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true, - ck_tile::BaseGemmPipelineAgBgCrCompV3, - std::conditional_t, - ck_tile::BaseGemmPipelineAgBgCrCompV3>>>; + ck_tile::BaseGemmPipelineAgBgCrCompV3>; const ck_tile::index_t K_split = (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile; @@ -131,9 +128,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str ck_tile::GemmPipelineAgBgCrCompV3, std::conditional_t< QuantMode == ck_tile::QuantType::AQuantGrouped, - std::conditional_t, - ck_tile::AQuantGemmPipelineAgBgCrMem>, + ck_tile::AQuantGemmPipelineAgBgCrCompV3, std::conditional_t, ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>; @@ -289,7 +284,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, float ave_time = gemm_calc_quant( + arg_parser, Row{}, Row{}, Row{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "R") + { + return run_gemm_example_with_layouts( + arg_parser, Col{}, Row{}, Row{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_gemm_example_with_layouts( + arg_parser, Col{}, Col{}, Col{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported memory layout for the input matrices!"); + } + } + else { throw std::runtime_error("Unsupported memory layout for the input matrices!"); diff --git a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp index 08163e27ad..10c2a1e4df 100644 --- a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp +++ b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp @@ -48,7 +48,7 @@ CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src) } else { - dst = load_tile(src); + load_tile(dst, src); } } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index ad7a032e52..f39d41a653 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -26,18 +26,32 @@ struct GemmPipelineAgBgCrImplBase static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; #if defined(__gfx950__) - // The combination of pk_int4_t and transposed loading causes numerical errors. + // The combination of pk_int4_t and transposed loading causes compilation errors. // Therefore do not use transposed loading in this case. + // Also, transpose load (ds_read_tr) requires specific tile distribution patterns + // that only work for certain K warp tile sizes based on data type size: + // - For 1-byte types (fp8/bf8): K warp tile <= 64 + // - For 2-byte types (fp16/bf16): K warp tile <= 32 static constexpr bool is_a_load_tr = []() { + using WarpTile = typename BlockGemmShape::WarpTile; + constexpr index_t kKWarpTile = WarpTile::at(number<2>{}); + constexpr index_t kMaxKWarpTile = (sizeof(ADataType) == 1) ? 64 : 32; if constexpr(std::is_same_v) return false; + else if constexpr(kKWarpTile > kMaxKWarpTile) + return false; else return std::is_same_v; }(); static constexpr bool is_b_load_tr = []() { + using WarpTile = typename BlockGemmShape::WarpTile; + constexpr index_t kKWarpTile = WarpTile::at(number<2>{}); + constexpr index_t kMaxKWarpTile = (sizeof(BDataType) == 1) ? 64 : 32; if constexpr(std::is_same_v) return false; + else if constexpr(kKWarpTile > kMaxKWarpTile) + return false; else return std::is_same_v; }(); @@ -93,19 +107,21 @@ struct GemmPipelineAgBgCrImplBase load_tile(dst_block_tile, lds_tile_window); } + template CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const { // A tile in LDS - ADataType* __restrict__ p_a_lds = static_cast(p_smem); - constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); + OverrideADataType* __restrict__ p_a_lds = static_cast(p_smem); + constexpr auto a_lds_block_desc = + Policy::template MakeALdsBlockDescriptor(); auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); // TODO: LDS alignment should come from Policy! constexpr index_t a_lds_block_space_size_aligned = integer_least_multiple( - sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16); + sizeof(OverrideADataType) * a_lds_block_desc.get_element_space_size(), 16); // B tile in LDS - BDataType* __restrict__ p_b_lds = static_cast( + OverrideBDataType* __restrict__ p_b_lds = static_cast( static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp index 4872ea34a9..ffe889af41 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp @@ -18,7 +18,8 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy static constexpr auto ATileAccessPattern = tile_distribution_pattern::warp_raked; static constexpr auto BTileAccessPattern = tile_distribution_pattern::warp_raked; - template + template > CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() { constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index c7c161e710..d843916f5e 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -37,11 +37,22 @@ struct UniversalGemmBasePolicy #if defined(__gfx950__) // The combination of pk_int4_t and transposed loading causes numerical errors. // Therefore do not use transposed loading in this case. + // Also, transpose load (ds_read_tr) requires specific tile distribution patterns + // that only work for certain K warp tile sizes based on data type size: + // - For 1-byte types (fp8/bf8): K warp tile <= 64 + // - For 2-byte types (fp16/bf16): K warp tile <= 32 template static constexpr bool is_a_load_tr = []() { - using BDataType = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + constexpr index_t kKWarpTile = WarpTile::at(number<2>{}); + // Max K warp tile for transpose load based on data type size + constexpr index_t kMaxKWarpTile = (sizeof(ADataType) == 1) ? 64 : 32; if constexpr(std::is_same_v) return false; + else if constexpr(kKWarpTile > kMaxKWarpTile) + return false; else return std::is_same_v, tensor_layout::gemm::ColumnMajor>; @@ -49,9 +60,15 @@ struct UniversalGemmBasePolicy template static constexpr bool is_b_load_tr = []() { - using BDataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + constexpr index_t kKWarpTile = WarpTile::at(number<2>{}); + // Max K warp tile for transpose load based on data type size + constexpr index_t kMaxKWarpTile = (sizeof(BDataType) == 1) ? 64 : 32; if constexpr(std::is_same_v) return false; + else if constexpr(kKWarpTile > kMaxKWarpTile) + return false; else return std::is_same_v, tensor_layout::gemm::RowMajor>; @@ -87,13 +104,12 @@ struct UniversalGemmBasePolicy return DefaultBTileAccessPattern; } - template + template > CK_TILE_DEVICE static constexpr auto MakeALdsBlockDescriptor() { - using ALayout = remove_cvref_t; - using ADataType = remove_cvref_t; - - using ADataType = remove_cvref_t; + using ALayout = remove_cvref_t; + using ADataType = OverrideADataType; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPack = GetSmemPackA(); diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index ad4d0baab2..5100de58ac 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -435,12 +435,22 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase ALdsTile a_warp_tile_; BLdsTile b_warp_tile_; - template + template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant = {}, + bool_constant = {}) { - load_int4_tile(a_warp_tile_, a_block_window); - load_int4_tile(b_warp_tile_, b_block_window); + // while ADatatype might not be the same as BDataType at the time of problem + // initialization, we can safely use BDataType here because when A would be int4 we will + // ensure A is converted to BDataType prior to loading + load_int4_tile( + a_warp_tile_, a_block_window); + load_int4_tile( + b_warp_tile_, b_block_window); } // C += A * B @@ -522,11 +532,16 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase MakeCBlockTile(); } - template + template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant a_load_tr = {}, + bool_constant b_load_tr = {}) { - block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window); + block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); } // C += A * B diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 012b53bbd4..f6cf4ce9be 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -414,7 +414,6 @@ struct QuantGemmKernel if constexpr(kQuantType == QuantType::AQuantGrouped) { - static_assert(std::is_same_v); if(kargs.QK_A % GemmPipeline::GetVectorSizeAQ() != 0) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) @@ -655,13 +654,24 @@ struct QuantGemmKernel } else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant) { - static_assert(std::is_same_v); - return make_naive_tensor_view( - aq_ptr, - make_tuple(kargs.M, kargs.QK_A), - make_tuple(kargs.stride_AQ, 1), - number{}, - number<1>{}); + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + aq_ptr, + make_tuple(kargs.M, kargs.QK_A), + make_tuple(kargs.stride_AQ, 1), + number{}, + number<1>{}); + } + else // Column major AQ + { + return make_naive_tensor_view( + aq_ptr, + make_tuple(kargs.QK_A, kargs.M), // Swapped dimensions + make_tuple(kargs.stride_AQ, 1), // Same stride pattern + number{}, + number<1>{}); + } } else if constexpr(kQuantType == QuantType::RowColQuant) { @@ -946,14 +956,21 @@ struct QuantGemmKernel } else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant) { - static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; - constexpr auto block_m = TilePartitioner::MPerBlock; - constexpr auto block_k = TilePartitioner::KPerBlock; - return make_tile_window( - aq_pad_view, - make_tuple(number{}, number{}), - {i_m, 0}); + using QuantGroupSize = remove_cvref_t; + constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; + constexpr auto block_m = TilePartitioner::MPerBlock; + if constexpr(std::is_same_v) + { + return make_tile_window(aq_pad_view, + make_tuple(number{}, number{}), + {i_m, 0}); + } + else // Column major AQ + { + return make_tile_window(aq_pad_view, + make_tuple(number{}, number{}), + {0, i_m}); + } } else if constexpr(kQuantType == QuantType::RowColQuant) { diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp index 6b0a47f3c0..e3ad883440 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp @@ -20,8 +20,6 @@ struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase; - using AQLayout = remove_cvref_t; - static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; @@ -36,8 +34,6 @@ struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase); - auto aq_copy_dram_window = make_tile_window(aq_dram_block_window_tmp.get_bottom_tensor_view(), aq_dram_block_window_tmp.get_window_lengths(), diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp index 88565f96ef..9681156e1a 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp @@ -18,13 +18,11 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC template CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeAQ() { - using AQLayout = remove_cvref_t; using AQDataType = remove_cvref_t; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlockAQ = KPerBlock / Problem::QuantGroupSize::kK; - static_assert(std::is_same_v); return GetABQGlobalVectorLoadSize(); } @@ -49,7 +47,6 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC WarpTile::at(I2), Problem::TransposeC>; - static_assert(std::is_same_v); if constexpr(PreshuffleQuant) { using TileEncodingPattern = tile_distribution_encoding_pattern_aq< @@ -68,6 +65,8 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC { if constexpr(Problem::TransposeC) { + static_assert(std::is_same_v, + "TransposeC currently only supports RowMajor layout"); using TileEncodingPatternTransposeC = tile_distribution_encoding_pattern_aq_transposed_c; + // !Problem::TransposeC + if constexpr(std::is_same_v) + { + using TileEncodingPattern = + tile_distribution_encoding_pattern_aq; - return TileEncodingPattern::make_2d_static_tile_distribution(); + return TileEncodingPattern::make_2d_static_tile_distribution(); + } + else + { + using TileEncodingPattern = + tile_distribution_encoding_pattern_aq; + return TileEncodingPattern::make_2d_static_tile_distribution_transposed(); + } } } } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp index fcbac3ff66..30b9d70eb8 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -77,6 +77,9 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}; + static constexpr auto is_b_load_tr_v = bool_constant{}; + using Base::PrefetchStages; [[nodiscard]] CK_TILE_HOST static const std::string GetName() @@ -161,6 +164,16 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 + CK_TILE_DEVICE static void LoadAndConvertATile(ABlockTile_& a_block_tile, + const ADramWindow& a_dram_window) + { + using DestDataType = typename ABlockTile_::DataType; + using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType; + constexpr index_t UnaryOpSize = 8; + load_int4_tile(a_block_tile, a_dram_window); + } + template > && std::is_same_v; constexpr bool is_b_row_major = std::is_same_v; - static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)"); - static_assert(is_a_col_major ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) @@ -211,7 +223,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(p_smem); constexpr auto a_lds_load_tile_distr = make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); @@ -228,8 +241,11 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(ABlockTileDistr{})); + decltype(make_static_distributed_tensor(ABlockTileDistr{})); using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr{})); using AQBlockTile = @@ -251,23 +267,25 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( - Policy::template make_shuffled_2d_static_tile_distribution()); + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); } @@ -276,10 +294,10 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( - Policy::template make_shuffled_2d_static_tile_distribution()); + Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); } @@ -288,12 +306,14 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); @@ -315,7 +335,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( Policy::template MakeShuffledBRegTileDistribution()); @@ -327,7 +347,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); @@ -374,7 +396,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( Policy::template MakeShuffledBRegTileDistribution()); @@ -386,7 +408,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + // Note: a_element_func takes BDataType (not ADataType) because A tiles are + // converted from ADataType (e.g., pk_int4_t) to BDataType (e.g., fp8) in + // LoadAndConvertATile before the element function is applied. + [](const BDataType& a) { return a; }, b_dram_block_window_tmp, [](const BDataType& b) { return b; }, aq_dram_block_window_tmp, diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp index f2142b4fdf..b51dee752d 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp @@ -110,6 +110,27 @@ struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding sequence<0, 0>>{}); } } + CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution_transposed() + { + + constexpr index_t Y0 = YPerTile; + constexpr index_t X0 = 1; + constexpr index_t X1 = MIterPerWarp ? MIterPerWarp : 1; + constexpr index_t X2 = MWarps; + constexpr index_t X3 = WarpGemm::kM; + + static_assert(X3 >= WarpGemm::kM, "Scales for all rows must be available within the warp."); + static_assert(X0 * X1 * X2 * X3 == XPerTile, + "X0, X1, X2, X3 must cover the blocktile along X."); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 2>>, + tuple, sequence<0, 3>>, + sequence<2, 1>, + sequence<1, 0>>{}); + } }; template +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using AQuantGrouped = std::integral_constant; +using BQuantGrouped = std::integral_constant; +using RowColQuant = std::integral_constant; +using TensorQuant = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; +using GroupSize64 = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D8N = ck_tile::QuantGroupShape>; +using GroupSize2D16N = ck_tile::QuantGroupShape>; +using GroupSize2D32N = ck_tile::QuantGroupShape>; +using GroupSize2D64N = ck_tile::QuantGroupShape>; +using GroupSize2D128N = ck_tile::QuantGroupShape>; + +// Type combinations for AQuant tests +// Tuple format: +// clang-format off +using AQuantTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // RRR layout (RowMajor A, RowMajor B, RowMajor C with RowMajor AQ) + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // CRR layout (ColumnMajor A, RowMajor B, RowMajor C with RowMajor AQ) + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // CCR layout (ColumnMajor A, ColumnMajor B, RowMajor C with ColumnMajor AQ) - NEW layout support + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // RCR layout - with the Prefill BlockTile Config. + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // PreshuffleQuant = false && TransposeC = true (with RowMajor AQ) + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // PreshuffleQuant = true && TransposeC = false (with RowMajor AQ - PreshuffleQuant only supports RowMajor) + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // PreshuffleQuant = true && TransposeC = true (with RowMajor AQ - PreshuffleQuant only supports RowMajor) + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for AQuant +TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 08232f81be..38bd59b882 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -29,13 +29,14 @@ class TestCkTileGemmQuantBase : public ::testing::Test using ALayout = std::tuple_element_t<0, Tuple>; using BLayout = std::tuple_element_t<1, Tuple>; using CLayout = std::tuple_element_t<2, Tuple>; - using ADataType = std::tuple_element_t<3, Tuple>; - using BDataType = std::tuple_element_t<4, Tuple>; - using QDataType = std::tuple_element_t<5, Tuple>; - using CDataType = std::tuple_element_t<6, Tuple>; - static constexpr auto QuantType = std::tuple_element_t<7, Tuple>::value; - using GemmConfig = std::tuple_element_t<8, Tuple>; - using QuantGroupSize = std::tuple_element_t<9, Tuple>; + using AQLayout = std::tuple_element_t<3, Tuple>; + using ADataType = std::tuple_element_t<4, Tuple>; + using BDataType = std::tuple_element_t<5, Tuple>; + using QDataType = std::tuple_element_t<6, Tuple>; + using CDataType = std::tuple_element_t<7, Tuple>; + static constexpr auto QuantType = std::tuple_element_t<8, Tuple>::value; + using GemmConfig = std::tuple_element_t<9, Tuple>; + using QuantGroupSize = std::tuple_element_t<10, Tuple>; using AccDataType = float; // accumulate always in float // Get the quant-type specific data types from traits @@ -85,6 +86,9 @@ class TestCkTileGemmQuantBase : public ::testing::Test using TilePartitioner = ck_tile::GemmTile1DPartitioner; + // BQLayout is always ColumnMajor for BQuant + using BQLayout = ck_tile::tensor_layout::gemm::ColumnMajor; + using CodegenGemmTraits = ck_tile::TileGemmQuantTraits +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; +using GroupSize64 = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D8N = ck_tile::QuantGroupShape>; +using GroupSize2D16N = ck_tile::QuantGroupShape>; +using GroupSize2D32N = ck_tile::QuantGroupShape>; +using GroupSize2D64N = ck_tile::QuantGroupShape>; +using GroupSize2D128N = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests (without PreshuffleB) +// Tuple format: +// clang-format off +using BQuantTypes = ::testing::Types< + // 1d cases with grouping only on k axis (AQLayout is always RowMajor for BQuant) + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // 2d cases with grouping also on the n axis + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant (without PreshuffleB) +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantTypes); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp new file mode 100644 index 0000000000..59b267842f --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests with PreshuffleB +// Tuple format: +// clang-format off +using BPreshuffleBQuantTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant with PreshuffleB +TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleBQuantTypes); + +// BQuant PreshuffleB tests +TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 685f52cdac..3b62d8073e 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -135,6 +135,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBaseis_row_major(ALayout{})); + const ck_tile::index_t stride_B = + ck_tile::get_default_stride(K, N, 0, this->is_row_major(BLayout{})); + const ck_tile::index_t stride_C = + ck_tile::get_default_stride(M, N, 0, this->is_row_major(CLayout{})); // AQuant uses grouped quantization for A matrix const ck_tile::index_t AQK = ck_tile::integer_divide_ceil(K, QuantGroupSize::kK); + // AQLayout is parameterized in the test tuple (can be RowMajor or ColumnMajor for AQuant) const ck_tile::index_t stride_AQ = - ck_tile::get_default_stride(M, AQK, 0, this->is_row_major(ALayout{})); + ck_tile::get_default_stride(M, AQK, 0, this->is_row_major(AQLayout{})); // Generate test data ck_tile::HostTensor a_m_k( ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{}))); + // AQLayout is independently specified for each test case ck_tile::HostTensor aq_m_aqk( - ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, this->is_row_major(ALayout{}))); + ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, this->is_row_major(AQLayout{}))); ck_tile::HostTensor b_k_n( ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{}))); @@ -407,8 +413,9 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBaseis_row_major(ALayout{}))); ck_tile::HostTensor b_k_n( ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{}))); + // BQ is always ColumnMajor ck_tile::HostTensor bq_bqk_bqn( - ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, this->is_row_major(BLayout{}))); + ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, ck_tile::bool_constant{})); // Initialize data with random values ck_tile::FillUniformDistribution{-0.5f, 0.5f}(a_m_k); diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_rowcol.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_rowcol.cpp new file mode 100644 index 0000000000..5a58ed886a --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_rowcol.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using RowColQuant = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for RowColQuant tests +// Tuple format: +// clang-format off +using RowColQuantTypes = ::testing::Types< + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for RowColQuant +TYPED_TEST_SUITE(TestCkTileGemmRowColQuant, RowColQuantTypes); + +// RowColQuant tests +TYPED_TEST(TestCkTileGemmRowColQuant, RowColQuantTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_tensor.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_tensor.cpp new file mode 100644 index 0000000000..0fa4048dab --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_tensor.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using TensorQuant = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for TensorQuant tests +// Tuple format: +// clang-format off +using TensorQuantTypes = ::testing::Types< + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for TensorQuant +TYPED_TEST_SUITE(TestCkTileGemmTensorQuant, TensorQuantTypes); + +// TensorQuant tests +TYPED_TEST(TestCkTileGemmTensorQuant, TensorQuantTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp deleted file mode 100644 index 07aed62804..0000000000 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "ck_tile/host.hpp" -#include "ck_tile/ops/gemm.hpp" - -#include -#include - -#include "test_gemm_quant_fixtures.hpp" - -// Type aliases for readability -using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; -using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; -using FP8 = ck_tile::fp8_t; -using BF8 = ck_tile::bf8_t; -using Half = ck_tile::half_t; -using PkInt4 = ck_tile::pk_int4_t; -using AQuantGrouped = std::integral_constant; -using BQuantGrouped = std::integral_constant; -using RowColQuant = std::integral_constant; -using TensorQuant = std::integral_constant; -using GroupSize = ck_tile::QuantGroupShape>; -using GroupSize64 = ck_tile::QuantGroupShape>; - -// 2d block sizes for BQuant -using GroupSize2D8N = ck_tile::QuantGroupShape>; -using GroupSize2D16N = ck_tile::QuantGroupShape>; -using GroupSize2D32N = ck_tile::QuantGroupShape>; -using GroupSize2D64N = ck_tile::QuantGroupShape>; -using GroupSize2D128N = ck_tile::QuantGroupShape>; - -// Type combinations for each quantization type -// clang-format off -using AQuantTypes = ::testing::Types< - // PreshuffleQuant = false && TransposeC = false - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - // PreshuffleQuant = false && TransposeC = false && Prefill - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - // PreshuffleQuant = false && TransposeC = true - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - // PreshuffleQuant = true && TransposeC = false - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - // PreshuffleQuant = true && TransposeC = true - std::tuple, - std::tuple, - std::tuple, - std::tuple ->; -// clang-format on - -// clang-format off -using BQuantTypes = ::testing::Types< - // 1d cases with grouping only on k axis - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - // 2d cases with grouping also on the n axis - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple ->; -// clang-format on - -// clang-format off -using BPreshuffleBQuantTypes = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - std::tuple, - std::tuple, - std::tuple, - std::tuple ->; -// clang-format on - -// clang-format off -using RowColQuantTypes = ::testing::Types< - std::tuple, - std::tuple ->; -// clang-format on - -// clang-format off -using TensorQuantTypes = ::testing::Types< - std::tuple, - std::tuple ->; -// clang-format on - -// Test suites for each quantization type -TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantTypes); -TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantTypes); -TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleBQuantTypes); -TYPED_TEST_SUITE(TestCkTileGemmRowColQuant, RowColQuantTypes); -TYPED_TEST_SUITE(TestCkTileGemmTensorQuant, TensorQuantTypes); - -#include "test_gemm_quant_ut_cases.inc" diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_ut_cases.inc b/test/ck_tile/gemm_block_scale/test_gemm_quant_ut_cases.inc deleted file mode 100644 index a88483fe3e..0000000000 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_ut_cases.inc +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -// AQuant tests -TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest) -{ - this->run_test_with_validation(1024, 1024, 1024); -} - -// BQuant tests -TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) -{ - this->run_test_with_validation(1024, 1024, 1024); -} - -// BQuant tests -TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantGroupedTest) -{ - this->run_test_with_validation(1024, 1024, 1024); -} -// RowColQuant tests -TYPED_TEST(TestCkTileGemmRowColQuant, RowColQuantTest) -{ - this->run_test_with_validation(1024, 1024, 1024); -} - -// TensorQuant tests -TYPED_TEST(TestCkTileGemmTensorQuant, TensorQuantTest) -{ - this->run_test_with_validation(1024, 1024, 1024); -}