diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index d6b63dc47b..40f06ec97a 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -16,6 +16,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") gemm_aquant_quantgrouped_preshufflequant.cpp gemm_bquant_quantgrouped_bf8i4.cpp gemm_bquant_quantgrouped_fp8i4.cpp + gemm_bquant_quantgrouped_bf16mxfp4.cpp gemm_bquant_quantgrouped_bf8.cpp gemm_bquant_quantgrouped_fp8.cpp gemm_bquant_quantgrouped_preshuffleb.cpp diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md index 3a30c2bad3..eb36ae5800 100644 --- a/example/ck_tile/38_block_scale_gemm/README.md +++ b/example/ck_tile/38_block_scale_gemm/README.md @@ -23,7 +23,7 @@ This folder contains examples of quant GEMMs using the ck_tile tile-programming - **Preshuffled GEMM**: Shuffle the GEMM of B (weight) matrix in the warp layout and bypass the shared memory to do the GEMM calculation. Best performance solution for GEMM. - **TransposeC**: Transpose the C Matrix Output layout to have the best coalesced scale reading - **Preshuffled Quant**: Preshuffle the input matrix to load multiple Quant warp blocks along the selected dimension. -- **Precision**: Supports fp16, bf16, fp8, bf8, int4 (for B Matrix). +- **Precision**: Supports fp16, bf16, fp8, bf8, int4 (for B Matrix), uint8 (split into two fp4 in the pipeline (for B Matrix)). - **Validation**: CPU/GPU validation and error tolerance options. ## build @@ -53,7 +53,7 @@ args: -stride_b Tensor B stride (default:0) -stride_c Tensor C stride (default:0) -v 0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1) - -prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, or bf8i4 (default for both AQuant and Bquant: fp8) + -prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, bf8i4, or bf16fp4 (default for both AQuant and Bquant: fp8) -warmup Number of iterations before benchmarking the kernel (default:50) -repeat Number of iterations to benchmark the kernel (default:1000) -timer gpu:gpu timer, cpu:cpu timer (default:gpu) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp new file mode 100644 index 0000000000..a022ce18e1 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_bf16fp4_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = decltype(GemmQuantTypeConfig{}); + + lut[hash_multiple_strings( + {"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x32"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp index 45d2151d5e..669bce2995 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -32,7 +32,7 @@ auto create_args(int argc, char* argv[]) .insert("prec", "fp8", "Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, " - "or bf8i4") + "bf8i4 or bf16fp4") .insert("warmup", "50", "Number of iterations before benchmarking the kernel") .insert("repeat", "1000", "Number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") @@ -97,6 +97,8 @@ void bquant_quantgrouped_fp8i4_instance_factory( std::unordered_map>& lut); void bquant_quantgrouped_bf8i4_instance_factory( std::unordered_map>& lut); +void bquant_quantgrouped_bf16fp4_instance_factory( + std::unordered_map>& lut); void bquant_quantgrouped_preshuffleb_instance_factory( std::unordered_map>& lut); void bquant_quantgrouped_preshufflequant_instance_factory( @@ -128,6 +130,7 @@ int main(int argc, char* argv[]) bquant_quantgrouped_bf8_instance_factory(lut); bquant_quantgrouped_fp8i4_instance_factory(lut); bquant_quantgrouped_bf8i4_instance_factory(lut); + bquant_quantgrouped_bf16fp4_instance_factory(lut); bquant_quantgrouped_preshuffleb_instance_factory(lut); bquant_quantgrouped_preshufflequant_instance_factory(lut); bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(lut); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 2b2333b04c..aabbfff3bd 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -69,8 +69,10 @@ auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, const float max_accumulated_value) { - using ComputeType = - std::conditional_t; + using ComputeType = std::conditional_t< + std::is_same_v, + ADataType, + std::conditional_t>; // Calculate thresholds const auto rtol = ck_tile::get_relative_threshold( ck_tile::integer_divide_ceil(K, kbatch)); diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 8a0dd9bc08..fa5e1f12e3 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -136,9 +136,13 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str std::conditional_t, ck_tile::AQuantGemmPipelineAgBgCrMem>, - std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>; + std::conditional_t< + GemmConfig::PreshuffleB == true, + ck_tile::WPQuantBPipelineAgBgCrV2, + std::conditional_t< + std::is_same_v, + ck_tile::MxFp4GemmPipelineAgBgCrCompV3, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>>; constexpr bool TiledPermuteN = (QuantGroupSize::kN > 1) ? false : GemmConfig::TiledMMAPermuteN; @@ -147,28 +151,31 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str printf( "TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, QuantGroupSize::kN); } - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - typename TypeConfig::AccDataType, - typename TypeConfig::CDataType, - ck_tile::tuple<>, - CLayout, - CDEElementWise, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - transpose_c, - ck_tile::memory_operation_enum::set, - 1, - false, - 1, - TiledPermuteN>>; + using GemmEpilogue = ck_tile::CShuffleEpilogue, + typename TypeConfig::ADataType, + typename TypeConfig::BDataType>, + ck_tile::tuple<>, + typename TypeConfig::AccDataType, + typename TypeConfig::CDataType, + ck_tile::tuple<>, + CLayout, + CDEElementWise, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + transpose_c, + ck_tile::memory_operation_enum::set, + 1, + false, + 1, + TiledPermuteN>>; using Kernel = ck_tile::QuantGemmKernel; @@ -205,7 +212,11 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( args.M, args.K, args.stride_A, is_row_major(ALayout{}))); ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + std::is_same_v ? args.K / 2 + : args.K, + args.N, + args.stride_B, + is_row_major(BLayout{}))); auto size_a_buffer = a_m.get_element_space_size_in_bytes(); auto size_b_buffer = b_n.get_element_space_size_in_bytes(); @@ -427,7 +438,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, int rotating_count = arg_parser.get_int("rotating_count"); stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); - stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); + stride_B = ck_tile::get_default_stride( + (std::is_same_v) ? (K / 2) : K, + N, + stride_B, + is_row_major(b_layout)); stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); // Conditional stride calculation based on QuantMode @@ -454,8 +469,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, ck_tile::HostTensor a_m_k( ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); - ck_tile::HostTensor b_k_n( - ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); + ck_tile::HostTensor b_k_n(ck_tile::host_tensor_descriptor( + (std::is_same_v) ? (K / 2) : K, + N, + stride_B, + is_row_major(b_layout))); ck_tile::HostTensor c_m_n_dev_result( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); @@ -499,13 +517,22 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, { ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{125.f, 130.f, fill_seed(gen)}( + *bq_tensor_ptr); } else { ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); } - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); } else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) @@ -721,13 +748,23 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { - ck_tile::reference_gemm_quant(a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref); + if constexpr(std::is_same_v) + ck_tile::reference_mxfp4gemm_quant( + a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref); + else + ck_tile::reference_gemm_quant(a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref); } else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) { @@ -787,16 +824,18 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser) using Col = ck_tile::tensor_layout::gemm::ColumnMajor; if((QuantMode == ck_tile::QuantType::AQuantGrouped || - QuantMode == ck_tile::QuantType::RowColQuant) && + QuantMode == ck_tile::QuantType::RowColQuant || + std::is_same_v) && GemmConfig::PreshuffleB) { throw std::runtime_error( - "Preshuffling weight matrix is not supported for AQuant or RowColQuant"); + "Preshuffling weight matrix is not supported for AQuant, RowColQuant or bf16_fp4_gemm"); } if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v) + std::is_same_v || + std::is_same_v) { std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 8830adfdd9..9c2ce62856 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1550,9 +1550,10 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) || - (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16))), + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); using rtn_type = thread_buffer; diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index ac388992d1..a1be8027b2 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -52,9 +52,19 @@ template ::value, - "Warning: Unhandled ComputeDataType for setting up the relative threshold!"); + static_assert(is_any_of::value, + "Warning: Unhandled ComputeDataType for setting up the relative threshold!"); double compute_error = 0; if constexpr(is_any_of::value) @@ -113,9 +123,19 @@ CK_TILE_HOST double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1) { - static_assert( - is_any_of::value, - "Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); + static_assert(is_any_of::value, + "Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); auto expo = std::log2(std::abs(max_possible_num)); double compute_error = 0; diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 883b08fcaa..0aa296b8d9 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -246,6 +246,63 @@ CK_TILE_HOST void reference_gemm_tensor_quant(const HostTensor& a_m_k make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency()); } +template +CK_TILE_HOST void reference_mxfp4gemm_quant(const HostTensor& a_m_k, + const HostTensor& q, + const HostTensor& b_k_n, + HostTensor& c_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const ACCElementOp& acc_element_op = {}) +{ + const std::size_t M = a_m_k.get_length(0); + const std::size_t N = b_k_n.get_length(1); + const std::size_t K = a_m_k.get_length(1); + + auto f_mn = [&](auto m, auto n) { + AccDataType v_acc = 0; + AccDataType pasual = 0; + for(std::size_t k = 0; k < (K / 2); k++) + { + using ComputeType = float; + auto b_scale = type_convert(q((2 * k) / QuantGroupSize::kK, n)) - 127; + ComputeType v_a_0, v_a_1; + ComputeType v_b_0, v_b_1; + + v_a_0 = ck_tile::type_convert((a_element_op(a_m_k(m, 2 * k)))); + v_a_1 = ck_tile::type_convert((a_element_op(a_m_k(m, 2 * k + 1)))); + + if constexpr(std::is_same_v) + { + auto b_pack = type_convert(b_element_op(b_k_n(k, n))); + auto b_scale_fp4 = type_convert(std::pow(2.0f, b_scale)); + + auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); + auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); + + v_b_0 = type_convert(b_f4_lo) * b_scale_fp4; + v_b_1 = type_convert(b_f4_hi) * b_scale_fp4; + } + + pasual = v_a_0 * v_b_0 + v_a_1 * v_b_1; + v_acc += pasual; + } + c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc)); + }; + + make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency()); + std::cout << std::endl; +} + template struct DataTypeTraits { static constexpr const char * name = template <> struct DataTypeTraits { static constexpr const char * name = "int8"; }; template <> struct DataTypeTraits { static constexpr const char * name = "pk_int4"; }; template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp4"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp4_raw"; }; template struct memOpToStr; template <> struct memOpToStr { static constexpr const char * name = "set"; }; diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 9a7876f6a5..ad1862306a 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -92,11 +92,17 @@ struct CShuffleEpilogue using ADataType = remove_cvref_t{}, AsDataTypeTuple>>; using BDataType = remove_cvref_t{}, BsDataTypeTuple>>; - using ATypeToUse = - std::conditional_t, BDataType, ADataType>; + using ATypeToUse = std::conditional_t || + std::is_same_v, + BDataType, + ADataType>; // Used for weight-only quantization kernel, B would be dequantized to the same data type as A - using BTypeToUse = - std::conditional_t, ADataType, BDataType>; + using BTypeToUse = std::conditional_t || + std::is_same_v || + std::is_same_v, + ADataType, + BDataType>; + using ELayout = remove_cvref_t; using CDElementwise = remove_cvref_t; static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation; diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 8541ffa3a9..f6e26ad206 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -96,8 +96,10 @@ struct BlockUniversalGemmAsBsCr using ATypeToUse = std::conditional_t, BDataType, ADataType>; - using BTypeToUse = - std::conditional_t, ADataType, BDataType>; + using BTypeToUse = std::conditional_t || + std::is_same_v, + ADataType, + BDataType>; using WarpGemm = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index f39d41a653..343e37ed66 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -17,10 +17,12 @@ struct GemmPipelineAgBgCrImplBase using BsLayout = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using ADataType = remove_cvref_t{}, AsDataType>>; - using ALayout = remove_cvref_t{}, AsLayout>>; - using BDataType = remove_cvref_t{}, BsDataType>>; - using BLayout = remove_cvref_t{}, BsLayout>>; + using ADataType = remove_cvref_t{}, AsDataType>>; + using ALayout = remove_cvref_t{}, AsLayout>>; + using BInDataType = remove_cvref_t{}, BsDataType>>; + using BDataType = + std::conditional_t, ADataType, BInDataType>; + using BLayout = remove_cvref_t{}, BsLayout>>; static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t NPerBlock = BlockGemmShape::kN; @@ -270,12 +272,17 @@ struct GemmPipelineAgBgCrImplBase }(); auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}); + using BLdsDataType = + std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>; + auto b_lds_load_tile_distr = []() { if constexpr(is_b_load_tr) return make_static_tile_distribution( - typename InputTileDistributionTraits< - typename BLdsLoadTileDistr::DstrEncode, - typename Problem::BDataType>::TransposedDstrEncode{}); + typename InputTileDistributionTraits::TransposedDstrEncode{}); + else return BLdsLoadTileDistr{}; }(); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 76341af70b..a45d41189b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -303,8 +303,11 @@ struct UniversalGemmBasePolicy template CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor() { - using BLayout = remove_cvref_t; - using BDataType = remove_cvref_t; + using BLayout = remove_cvref_t; + using BDataType = + std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; @@ -585,9 +588,12 @@ struct UniversalGemmBasePolicy using BsDataType = remove_cvref_t; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + using BLayout = remove_cvref_t{}, BsLayout>>; + using BInDataType = remove_cvref_t{}, BsDataType>>; - using BLayout = remove_cvref_t{}, BsLayout>>; - using BDataType = remove_cvref_t{}, BsDataType>>; + using BDataType = std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>; if constexpr(Problem::FixedVectorSize) { @@ -729,13 +735,17 @@ struct UniversalGemmBasePolicy { constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + using BDataType = remove_cvref_t; + constexpr index_t KPerBlock = std::is_same_v + ? Problem::BlockGemmShape::kK / 2 + : Problem::BlockGemmShape::kK; constexpr index_t VecLoadSize = - Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB(); + std::is_same_v + ? 4 + : (Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB()); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; - - using BLayout = remove_cvref_t< - std::tuple_element_t{}, remove_cvref_t>>; + using BLayout = remove_cvref_t< + std::tuple_element_t{}, remove_cvref_t>>; // Tile: KPerBlock X NPerBlock if constexpr(std::is_same_v) { @@ -841,10 +851,12 @@ struct UniversalGemmBasePolicy template CK_TILE_DEVICE static constexpr index_t GetSmemSizeB() { - constexpr index_t smem_size_b = - integer_least_multiple(sizeof(typename Problem::BDataType) * - Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK, - 16); + using BDataType = + std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>; + constexpr index_t smem_size_b = integer_least_multiple( + sizeof(BDataType) * Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK, 16); return smem_size_b; } @@ -882,8 +894,10 @@ struct UniversalGemmPipelineAgBgCrPolicy using BDataType = remove_cvref_t; using ATypeToUse = std::conditional_t, BDataType, ADataType>; - using BTypeToUse = - std::conditional_t, ADataType, BDataType>; + using BTypeToUse = std::conditional_t || + std::is_same_v, + ADataType, + BDataType>; using WarpGemm = WarpGemmDispatcher( - b_ptr, - make_tuple(kargs.N, splitk_batch_offset.splitted_k), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); + if constexpr(std::is_same_v) + return make_naive_tensor_view( + b_ptr, + make_tuple(kargs.N, splitk_batch_offset.splitted_k / 2), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + else + return make_naive_tensor_view( + b_ptr, + make_tuple(kargs.N, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); } } } @@ -885,10 +893,16 @@ struct QuantGemmKernel const auto& b_tensor_view = views.at(I2); if constexpr(std::is_same_v) { - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); + if constexpr(std::is_same_v) + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + else + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); } else { @@ -1020,10 +1034,17 @@ struct QuantGemmKernel { if constexpr(std::is_same_v) { - return make_tile_window(b_pad_view, - make_tuple(number{}, - number{}), - {i_n, 0}); + if constexpr(std::is_same_v) + return make_tile_window( + b_pad_view, + make_tuple(number{}, + number{}), + {i_n, 0}); + else + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {i_n, 0}); } else { diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp new file mode 100644 index 0000000000..58019d703e --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" + +namespace ck_tile { + +template +struct GemmMxFp4PipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase +{ + using Base = GemmPipelineAgBgCrImplBase; + using ADataType = typename Base::ADataType; + using ALayout = typename Base::ALayout; + using BDataType = typename Base::BDataType; + using BLayout = typename Base::BLayout; + using BlockGemmShape = typename Base::BlockGemmShape; + using QuantGroupSize = remove_cvref_t; + + using BQLayout = remove_cvref_t; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t NPerBlockBQ = NPerBlock / QuantGroupSize::kN; + static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK; + + static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= QuantGroupSize"); + static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= QuantGroupSize"); + + static_assert(NPerBlock % QuantGroupSize::kN == 0, + "NPerBlock must be a multiple of QuantGroupSize::kN"); + static_assert(KPerBlock % QuantGroupSize::kK == 0, + "KPerBlock must be a multiple of QuantGroupSize::kK"); + + // Create DRAM tile window for BQ + template + CK_TILE_DEVICE constexpr auto + GetBQDramLoadWindow(const BQDramBlockWindowTmp& bq_dram_block_window_tmp) const + { + static_assert(std::is_same_v); + + using YPerTile = number; + using XPerTile = number; + + auto bq_copy_dram_window = + make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(YPerTile(), XPerTile()), + bq_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBQDramTileDistribution()); + return bq_copy_dram_window; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp new file mode 100644 index 0000000000..6ce2ff10fa --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "gemm_group_quant_utils.hpp" + +namespace ck_tile { + +struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy +{ + using Base = UniversalGemmPipelineAgBgCrPolicy; + using Base::I0; + using Base::I1; + using Base::I2; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBQ() + { + using BQLayout = remove_cvref_t; + using BQDataType = remove_cvref_t; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK; + + static_assert(std::is_same_v); + return GetABQGlobalVectorLoadSize(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBRegTileDistribution() + { + using BLayout = remove_cvref_t; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = + Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB(); + constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + // Tile: KPerBlock X NPerBlock + if constexpr(std::is_same_v) + { + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_2d_static_tile_distribution(); + } + // Tile: NPerBlock X KPerBlock + else + { + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_2d_static_tile_distribution(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBQDramTileDistribution() + { + // using BLayout = remove_cvref_t; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t KScale = KPerBlock / Problem::QuantGroupSize::kK; // k_scale num //2 + constexpr index_t VecLoadSize = + Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB(); + constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + + constexpr index_t warp_size = get_warp_size(); + constexpr index_t num_warps = BlockSize / get_warp_size(); + constexpr index_t LargestVec = (KPerBlock * NPerBlock) / (num_warps * warp_size); + constexpr index_t b_vec = VecLoadSize > LargestVec ? LargestVec : VecLoadSize; + constexpr index_t K0 = KPerBlock / b_vec; + constexpr index_t K1 = K0 / KScale; + constexpr index_t K3 = K0 / K1; + constexpr index_t K2 = 1; + + constexpr index_t N0 = num_warps / NumWaveGroups; + constexpr index_t N1 = warp_size / K0; + constexpr index_t N2 = NPerBlock / (N0 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2, 0>>, + tuple, sequence<1, 0, 0>>, + sequence<1, 2>, + sequence<2, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + + static_assert(Problem::QuantGroupSize::kK % WarpTile::at(I2) == 0, + "KPerWarpGemm must be a multiple of QuantGroupSize!"); + + using WarpGemm = WarpGemmDispatcher; + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v); + static_assert(std::is_same_v); + + using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy< + typename Problem::ADataType, + std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>, + typename Problem::CDataType, + BlockWarps, + WarpGemm>; + + return BlockUniversalGemmAsBsCr{}; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp new file mode 100644 index 0000000000..c113521d6b --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp @@ -0,0 +1,665 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/host/concat.hpp" + +namespace ck_tile { + +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed tensor: register + +template +struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 +{ + using Base = BaseGemmPipelineAgBgCrCompV3; + using PipelineImplBase = GemmMxFp4PipelineAgBgCrImplBase; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using BDqDataType = remove_cvref_t; + using BQDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + using QuantGroupSize = remove_cvref_t; + + static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!"); + + using I0 = number<0>; + using I1 = number<1>; + using I2 = number<2>; + + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + + static constexpr index_t BQPackedSize = + ck_tile::numeric_traits>::PackedSize; + + using ALayout = remove_cvref_t; + using BQLayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using BlockGemm = remove_cvref_t())>; + + static constexpr index_t BlockSize = Problem::kBlockSize; + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / QuantGroupSize::kN; + static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize::kK; + + static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } + static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } + static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } + static constexpr index_t GetVectorSizeBQ() + { + return Policy::template GetVectorSizeBQ(); + } + + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + + static constexpr bool HasHotLoop = Problem::HasHotLoop; + static constexpr auto TailNum = Problem::TailNum; + static constexpr auto Scheduler = Problem::Scheduler; + + using Base::PrefetchStages; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + return concat('_', "mxfp4gemm_pipeline_AgBgCrCompV3", + concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize, + concat('x', WaveNumM, WaveNumN), + concat('x', kPadM, kPadN, kPadK), + concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName()); + // clang-format on + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + CK_TILE_HOST static std::string Print() + { + constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; + constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; + constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = 64; + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + + constexpr index_t A_LDS_Read_Width = GetSmemPackA(); + constexpr index_t B_LDS_Read_Width = GetSmemPackB(); + + constexpr index_t A_LDS_Write_Width = GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = GetSmemPackB(); + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + constexpr index_t BQ_Buffer_Load_Inst_Num = + NPerBlock * KPerBlockBQ / (BlockSize * GetVectorSizeBQ()); + + constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width); + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width); + + constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + auto str = std::stringstream{}; + + str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << ", " + << "BQ vector size: " << GetVectorSizeBQ() << "\n" + << "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n" + << "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num + << ", " << "BQ buffer load inst: " << BQ_Buffer_Load_Inst_Num << "\n" + << "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num + << "\n" + << "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n" + << "C MFMA inst: " << C_MFMA_Inst_Num << "\n" + << "QuantGroupSize: " << QuantGroupSize::GetName() << "\n" + << "KPack: " << BlockGemm::Traits::KPack << "\n" + << "PrefetchStages: " << PrefetchStages << "\n"; + return str.str(); + } + + template + struct PipelineImpl : public PipelineImplBase + { + }; + + template <> + struct PipelineImpl : public PipelineImplBase + { + using Base = PipelineImplBase; + + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; + constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; + constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = 64; + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + + // Below should be equal to AK1|BK1 + constexpr index_t A_LDS_Read_Width = GetSmemPackA(); + constexpr index_t B_LDS_Read_Width = GetSmemPackB(); + + constexpr index_t A_LDS_Write_Width = GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = GetSmemPackB(); + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + + constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width); + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width); + + constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / + (MPerXDL * NPerXDL * KPerXDL); + + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num + : A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + B_LDS_Read_Width * sizeof(BDqDataType) / BPackedSize == 16 + ? B_LDS_Read_Inst_Num + : B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; + constexpr auto ds_read_a_issue_cycle = + A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + B_LDS_Read_Width * sizeof(BDqDataType) / BPackedSize == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + // Separate this part? + // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) > + // sizeof(ComputeDataType) / + // sizeof(BDataType) + // ? sizeof(ComputeDataType) / + // sizeof(ADataType) : sizeof(ComputeDataType) + // / sizeof(BDataType); + constexpr auto num_mfma_stage1 = + num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = + num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + const BQDramBlockWindowTmp& bq_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "A/B/BQ Dram block window should have the same data type as appropriate " + "([A|B|BQ]DataType) defined in Problem definition!"); + + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_bq_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)"); + static_assert(NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}], + "Bq block window has incorrect lengths for defined BqLayout!"); + + static_assert(is_a_col_major + ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert( + is_b_row_major + ? (KPerBlock / 2 == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock / 2 == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "B block window has incorrect lengths for defined BLayout!"); + + // ------------------------------------------------------------------------------------ + // Definitions of all needed tiles + // int b_block_stride = 0; + // A/B tiles in LDS + auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem); + + // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + // A DRAM tile window for load + // A LDS tile window for store + // A LDS tile for block GEMM + auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); + + // B DRAM tile window for load, (kN, kK/2) + // B LDS tile window for store, (kN, kK) + // B LDS tile for block GEMM + auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] = + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); + + // B scale DRAM tile window for load + // auto b_scale_copy_dram_window = + // make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(), + // bq_dram_block_window_tmp.get_window_lengths(), + // bq_dram_block_window_tmp.get_window_origin(), + // Policy::template GetBQDramLoadWindow()); + auto bq_copy_dram_window = Base::GetBQDramLoadWindow(bq_dram_block_window_tmp); + + auto bq_block_tile = decltype(load_tile(bq_copy_dram_window)){}; + + // Block GEMM + auto block_gemm = BlockGemm(); + auto c_block_tile = block_gemm.MakeCBlockTile(); + using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); + // using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + + using ABlockTile = + decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = + decltype(make_static_distributed_tensor(BBlockTileDistr{})); + + ABlockTile a_block_tile; + BBlockTile b_fp4_block_tile; + + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + + constexpr ADramTileWindowStep a_dram_tile_window_step = + is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = + is_b_row_major ? make_array(KPerBlock / 2, 0) : make_array(0, KPerBlock / 2); + + constexpr index_t b_scale_dram_tile_window_step = KPerBlock / QuantGroupSize::kK; + // ----------------------------------------------------------------------------------------- + // Gemm pipeline start + + // prefetch + // global read 0 + // auto a_scale_block_tile = decltype(load_tile(a_scale_copy_dram_window)){}; + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); + // BDataType + auto b_block_tile = make_static_distributed_tensor( + Policy::template MakeBRegTileDistribution()); + + bq_block_tile = load_tile(bq_copy_dram_window); + move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); + + constexpr auto idx1_js = tile_distributed_index<0>{}; + constexpr auto b_block = decltype(b_fp4_block_tile)::get_distributed_spans(); + sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { + sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); + auto b_scale_uint = type_convert(bq_block_tile(i_j_idx_scale)) - 127; + auto b_scale = type_convert(std::pow(2.0f, b_scale_uint)); + constexpr auto idx1_lo = tile_distributed_index{}; + constexpr auto idx1_hi = tile_distributed_index{}; + constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); + constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); + + auto b_pack = type_convert(b_fp4_block_tile(i_j_idx)); + auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); + auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); + b_block_tile(i_j_idx_lo) = + type_convert(type_convert(b_f4_lo) * b_scale); + b_block_tile(i_j_idx_hi) = + type_convert(type_convert(b_f4_hi) * b_scale); + }); + }); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + block_sync_lds(); + + // LDS write 0 + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } + + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); + + bq_block_tile = load_tile(bq_copy_dram_window); + move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); + + sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { + sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); + + auto b_scale_uint = type_convert(bq_block_tile(i_j_idx_scale)) - 127; + auto b_scale = type_convert(std::pow(2.0f, b_scale_uint)); + constexpr auto idx1_lo = tile_distributed_index{}; + constexpr auto idx1_hi = tile_distributed_index{}; + constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); + constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); + + auto b_pack = type_convert(b_fp4_block_tile(i_j_idx)); + auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); + auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); + b_block_tile(i_j_idx_lo) = + type_convert(type_convert(b_f4_lo) * b_scale); + b_block_tile(i_j_idx_hi) = + type_convert(type_convert(b_f4_hi) * b_scale); + }); + }); + + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasHotLoop) + { + index_t i = 0; + do + { + block_sync_lds(); + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } + + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch( + b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); + + bq_block_tile = load_tile(bq_copy_dram_window); + move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); + + sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { + sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); + + auto b_scale_uint = + type_convert(bq_block_tile(i_j_idx_scale)) - 127; + auto b_scale = type_convert(std::pow(2.0f, b_scale_uint)); + constexpr auto idx1_lo = tile_distributed_index{}; + constexpr auto idx1_hi = + tile_distributed_index{}; + constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); + constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); + + auto b_pack = type_convert(b_fp4_block_tile(i_j_idx)); + auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); + auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); + b_block_tile(i_j_idx_lo) = + type_convert(type_convert(b_f4_lo) * b_scale); + b_block_tile(i_j_idx_hi) = + type_convert(type_convert(b_f4_hi) * b_scale); + }); + }); + + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + i += 1; + // b_block_stride +=1; + } while(i < (num_loop - 1)); + } + // tile_elementwise_inout([](auto& c) { c = 0; }, acc_block_tile); + // tail + if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) + { + // Leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + } + else + { + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } + + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + } + __builtin_amdgcn_sched_barrier(0); + return c_block_tile; + } + }; + + /** + * @brief This function runs the pipeline using compile-time known hot loop and tail number. + * @param num_loop The number of loop iterations. This is determined at runtime due to e.g. + * SplitK. + * @note This is used by the kernel variants that are able to determine + * hot loop and tail number on the host side, e.g. non-persistent gemm kernel. + */ + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BQDramBlockWindowTmp& bq_dram_block_window_tmp, + index_t num_loop, + void* p_smem, + index_t n = 0) const + { + ck_tile::ignore = n; + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDqDataType& b) { return b; }, + bq_dram_block_window_tmp, + num_loop, + p_smem); + } +}; + +} // namespace ck_tile diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt old mode 100644 new mode 100755 diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 39a7c66f38..fe5d2bd7e1 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -131,8 +131,10 @@ class TestCkTileGemmQuantBase : public ::testing::Test const ck_tile::index_t kbatch, const float max_accumulated_value) { - using ComputeType = - std::conditional_t; + using ComputeType = std::conditional_t< + std::is_same_v, + ADataType_, + std::conditional_t>; // Calculate thresholds const auto rtol = ck_tile::get_relative_threshold( ck_tile::integer_divide_ceil(K, kbatch)); diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp index ef0d41909b..ec123364cb 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp @@ -16,9 +16,12 @@ using FP8 = ck_tile::fp8_t; using BF8 = ck_tile::bf8_t; using Half = ck_tile::half_t; using PkInt4 = ck_tile::pk_int4_t; +using BF16 = ck_tile::bf16_t; +using UInt8 = ck_tile::pk_fp4_raw_t; using BQuantGrouped = std::integral_constant; using GroupSize = ck_tile::QuantGroupShape>; using GroupSize64 = ck_tile::QuantGroupShape>; +using GroupSize32 = ck_tile::QuantGroupShape>; // 2d block sizes for BQuant using GroupSize2D8N = ck_tile::QuantGroupShape>; @@ -42,6 +45,9 @@ using BQuantTypes = ::testing::Types< std::tuple, std::tuple, std::tuple, + std::tuple, + + std::tuple, // 2d cases with grouping also on the n axis std::tuple, diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index bf9c7a138d..4f2edb3609 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -60,6 +60,13 @@ struct GemmConfigPrefill : public GemmConfigBase static constexpr ck_tile::index_t K_Tile = 128; }; +struct GemmConfigMxFp4 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; +}; + struct GemmConfigPreshuffleQuant : public GemmConfigBase { static constexpr bool PreshuffleQuant = true; @@ -403,7 +410,8 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase ? (K / 2) : K; const ck_tile::index_t stride_C = N; // BQuant uses block/grouped quantization for B matrix @@ -414,15 +422,27 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase a_m_k( ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{}))); - ck_tile::HostTensor b_k_n( - ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{}))); + ck_tile::HostTensor b_k_n(ck_tile::host_tensor_descriptor( + std::is_same_v ? K / 2 : K, + N, + stride_B, + this->is_row_major(BLayout{}))); ck_tile::HostTensor bq_bqk_bqn( ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, this->is_row_major(BQLayout{}))); // Initialize data with random values ck_tile::FillUniformDistribution{-0.5f, 0.5f}(a_m_k); - ck_tile::FillUniformDistribution{0.f, 1.f}(b_k_n); - ck_tile::FillUniformDistribution{-1.0f, 1.0f}(bq_bqk_bqn); + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f}(b_k_n); + ck_tile::FillUniformDistribution{125.f, 130.f}(bq_bqk_bqn); + } + else + { + ck_tile::FillUniformDistribution{0.f, 1.f}(b_k_n); + ck_tile::FillUniformDistribution{-1.0f, 1.0f}(bq_bqk_bqn); + } + // Allocate device memory ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType)); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType)); @@ -501,13 +521,22 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref); + if constexpr(std::is_same_v) + ck_tile::reference_mxfp4gemm_quant(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref); + else + ck_tile::reference_gemm_quant(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref); // Get device result ck_tile::HostTensor c_m_n_dev_result( @@ -580,33 +609,37 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase; - using GemmPipeline = - std::conditional_t, - ck_tile::WPQuantBPipelineAgBgCrV2>; + using GemmPipeline = std::conditional_t< + PreshuffleB == false, + std::conditional_t, + ck_tile::MxFp4GemmPipelineAgBgCrCompV3, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>, + ck_tile::WPQuantBPipelineAgBgCrV2>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - Base::M_Warp, - Base::N_Warp, - Base::M_Warp_Tile, - Base::N_Warp_Tile, - Base::K_Warp_Tile, - false, // transpose_c - ck_tile::memory_operation_enum::set, - 1, - false, - 1, - TiledMMAPermuteN>>; + using GemmEpilogue = ck_tile::CShuffleEpilogue, + ADataType, + BDataType>, + ck_tile::tuple<>, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + Base::M_Warp, + Base::N_Warp, + Base::M_Warp_Tile, + Base::N_Warp_Tile, + Base::K_Warp_Tile, + false, // transpose_c + ck_tile::memory_operation_enum::set, + 1, + false, + 1, + TiledMMAPermuteN>>; using Kernel = ck_tile::QuantGemmKernel