diff --git a/example/ck_tile/39_gemm_microscaling/CMakeLists.txt b/example/ck_tile/39_gemm_microscaling/CMakeLists.txt new file mode 100644 index 0000000000..c99d6409e0 --- /dev/null +++ b/example/ck_tile/39_gemm_microscaling/CMakeLists.txt @@ -0,0 +1,13 @@ +set(EXAMPLE_GEMM_MX_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_MX_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + +list(APPEND EXAMPLE_GEMM_MX_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) + +if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") + add_executable(tile_example_gemm_mx_basic EXCLUDE_FROM_ALL gemm_mx_basic.cpp) + target_compile_options(tile_example_gemm_mx_basic PRIVATE ${EXAMPLE_GEMM_MX_COMPILE_OPTIONS}) +else() + message(DEBUG "Skipping ck_tile quant gemm tests for current target") +endif() diff --git a/example/ck_tile/39_gemm_microscaling/README.md b/example/ck_tile/39_gemm_microscaling/README.md new file mode 100644 index 0000000000..ea71072869 --- /dev/null +++ b/example/ck_tile/39_gemm_microscaling/README.md @@ -0,0 +1,35 @@ +# GEMM with Microscaling (MX) + +This folder contains an example of block-scale (microscaling, MX) GEMM implemented with the ck_tile tile-programming model. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ <arch> +# Build the microscaling (MX) GEMM example +make tile_example_gemm_mx_basic -j +``` +This will result in an executable `build/bin/tile_example_gemm_mx_basic`. + +## example +``` +args: + -b batch size (default:1) + -m m dimension (default:1024) + -n n dimension (default:2048) + -k k dimension (default:64) + -a_layout Tensor A data layout (default: R) + -b_layout Tensor B data layout (default: R) + -c_layout Tensor C data layout (default: R) + -stride_a Tensor A stride (default:0) + -stride_b Tensor B stride (default:0) + -stride_c Tensor C stride (default:0) + -v 0. No validation, 1. 
Validation on CPU, 2. Validation on GPU (default:2) + -e Absolute error tolerance (default:1e-5) + -prec data type. fp16/bf16/fp8/bf8/int8 (default:fp16) + -warmup number of iterations before benchmark the kernel (default:10) + -repeat number of iterations to benchmark the kernel (default:100) + -timer gpu:gpu timer, cpu:cpu timer (default:gpu) +``` diff --git a/example/ck_tile/39_gemm_microscaling/gemm_mx_fp4_basic.cpp b/example/ck_tile/39_gemm_microscaling/gemm_mx_fp4_basic.cpp index 305bbd6f1d..626dedea0c 100644 --- a/example/ck_tile/39_gemm_microscaling/gemm_mx_fp4_basic.cpp +++ b/example/ck_tile/39_gemm_microscaling/gemm_mx_fp4_basic.cpp @@ -13,16 +13,19 @@ #include "gemm_utils.hpp" template -float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s) + uint32_t BlockScaleSize> +float gemm_mx_calc(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s) { constexpr bool kPadM = false; constexpr bool kPadN = false; @@ -32,17 +35,17 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s static_assert(std::is_same_v); - constexpr ck_tile::index_t M_Tile = 16; + constexpr ck_tile::index_t M_Tile = 64; constexpr ck_tile::index_t N_Tile = 64; constexpr ck_tile::index_t K_Tile = 256; - constexpr ck_tile::index_t M_Warp = 1; - constexpr ck_tile::index_t N_Warp = 4; + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; constexpr ck_tile::index_t K_Warp = 1; constexpr ck_tile::index_t M_Warp_Tile = 16; constexpr ck_tile::index_t N_Warp_Tile = 16; - constexpr ck_tile::index_t K_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 128; using CodegenGemmShape = ck_tile::TileGemmShape, @@ -51,8 +54,14 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s using TilePartitioner = ck_tile::GemmTile1DPartitioner; - using CodegenGemmTraits = - ck_tile::TileGemmAQuantTraits; + using CodegenGemmTraits = ck_tile::TileGemmMXTraits; using 
GemmPipelineProblem = ck_tile::GemmPipelineProblemBase; - using BaseGemmPipeline = ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3; + using BaseGemmPipeline = ck_tile::BaseGemmMXPipelineAgBgCrCompV3; const ck_tile::index_t K_split = (args.K + K_Tile - 1) / K_Tile * K_Tile; const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); @@ -74,39 +83,39 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s constexpr auto tail_number_v = tail_number_.value; using CodegenPipelineProblem = - ck_tile::GemmAQuantPipelineProblem; - using CodegenGemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3; + ck_tile::GemmMXPipelineProblem; + using CodegenGemmPipeline = ck_tile::GemmMXPipelineAgBgCrCompV3; using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - CodegenPipelineProblem::kBlockSize, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - M_Warp, - N_Warp, - M_Warp_Tile, - N_Warp_Tile, - K_Warp_Tile, - transposed_warp_gemm, - ck_tile::memory_operation_enum::set>>; - using Kernel = - ck_tile::AQuantGemmKernel; + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + CodegenPipelineProblem::kBlockSize, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + M_Warp, + N_Warp, + M_Warp_Tile, + N_Warp_Tile, + K_Warp_Tile, + transposed_warp_gemm, + ck_tile::memory_operation_enum::set>>; + using Kernel = ck_tile::GemmMXKernel; auto kargs = Kernel::MakeKernelArgs(args); @@ -170,7 +179,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a return 0; } -int run_gemm_example(int argc, char* argv[]) +int run_gemm_mx_example(int argc, char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); if(!result) @@ -182,10 +191,11 @@ int run_gemm_example(int argc, char* argv[]) if(data_type == 
"fp4") { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmMXTypeConfig{}); return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); } else @@ -194,4 +204,4 @@ int run_gemm_example(int argc, char* argv[]) } } -int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } +int main(int argc, char* argv[]) { return !run_gemm_mx_example(argc, argv); } diff --git a/example/ck_tile/39_gemm_microscaling/gemm_utils.hpp b/example/ck_tile/39_gemm_microscaling/gemm_utils.hpp index 175f5f464f..2bd8cf99b4 100644 --- a/example/ck_tile/39_gemm_microscaling/gemm_utils.hpp +++ b/example/ck_tile/39_gemm_microscaling/gemm_utils.hpp @@ -375,86 +375,51 @@ struct GemmConfigPreshufle_2 : public GemmConfigBase static constexpr bool DoubleSmemBuffer = false; }; -template -struct GemmTypeConfig; +// template +// struct GemmTypeConfig; -template <> -struct GemmTypeConfig -{ - using ADataType = ck_tile::half_t; - using BDataType = ck_tile::half_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; - // ToDo: Add more bias config to support different categories of GEMM. 
-}; - -template <> -struct GemmTypeConfig -{ - using ADataType = ck_tile::bf16_t; - using BDataType = ck_tile::bf16_t; - using AccDataType = float; - using CDataType = ck_tile::bf16_t; -}; - -template <> -struct GemmTypeConfig -{ - using ADataType = ck_tile::fp8_t; - using BDataType = ck_tile::fp8_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; -}; - -template <> -struct GemmTypeConfig -{ - using ADataType = ck_tile::bf8_t; - using BDataType = ck_tile::bf8_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; -}; - -template <> -struct GemmTypeConfig -{ - using ADataType = ck_tile::half_t; - using BDataType = ck_tile::pk_int4_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; -}; - -template <> -struct GemmTypeConfig -{ - using ADataType = ck_tile::int8_t; - using BDataType = ck_tile::int8_t; - using AccDataType = int32_t; - using CDataType = int32_t; -}; +// template <> +// struct GemmTypeConfig +// { +// using ADataType = ck_tile::half_t; +// using BDataType = ck_tile::half_t; +// using AccDataType = float; +// using CDataType = ck_tile::half_t; +// // ToDo: Add more bias config to support different categories of GEMM MX. 
+// }; template + typename BDataType_ = ADataType_, + typename ScaleDataType_ = ADataType_, + typename ScalePackDataType_ = ScaleDataType_, + typename CDataType_ = ADataType_> struct GemmMXTypeConfig { - using ADataType = ADataType_; - using QDataType = QDataType_; - using BDataType = BDataType_; - using AccDataType = float; - using CDataType = CDataType_; + using ADataType = ADataType_; + using BDataType = BDataType_; + using ScaleDataType = ScaleDataType_; + using ScalePackDataType = ScalePackDataType_; + using AccDataType = float; + using CDataType = CDataType_; }; // microscaling gemm template <> -struct GemmMXTypeConfig +struct GemmMXTypeConfig { - using ADataType = ck_tile::pk_fp4_t; - using BDataType = ck_tile::pk_fp4_t; - using QDataType = ck_tile::e8m0_bexp_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; + using ADataType = ck_tile::pk_fp4_t; + using BDataType = ck_tile::pk_fp4_t; + using ScaleDataType = ck_tile::e8m0_bexp_t; + using ScalePackDataType = int32_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; } template diff --git a/example/ck_tile/39_gemm_microscaling/run_gemm_mx_example.inc b/example/ck_tile/39_gemm_microscaling/run_gemm_mx_example.inc index 8b8a32c002..d9822e7560 100644 --- a/example/ck_tile/39_gemm_microscaling/run_gemm_mx_example.inc +++ b/example/ck_tile/39_gemm_microscaling/run_gemm_mx_example.inc @@ -32,7 +32,6 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, - ck_tile::index_t AQK, ck_tile::index_t stride_A, ck_tile::index_t stride_AQ, ck_tile::index_t stride_B, @@ -41,46 +40,54 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, int n_warmup, int n_repeat) { - ck_tile::AQuantGemmHostArgs args; - args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); - args.aq_ptr = aq_m_aqk_dev_buf.GetDeviceBuffer(); - args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); - args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); - args.k_batch = kbatch; - args.M = 
M; - args.N = N; - args.K = K; - args.QK = AQK; - args.stride_A = stride_A; - args.stride_B = stride_B; - args.stride_C = stride_C; - args.stride_AQ = stride_AQ; + ck_tile::GemmMXKernelArgs args; + args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); + args.a_scale_ptr_ = a_m_k_scale_dev_buf.GetDeviceBuffer(); + args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); + args.b_scale_ptr_ = b_k_n_scale_dev_buf.GetDeviceBuffer(); + args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); + args.k_batch = kbatch; + args.M = M; + args.N = N; + args.K = K; + args.stride_A = stride_A; + args.stride_scale_A = stride_scale_A; // stride for A scale + args.stride_B = stride_B; + args.stride_scale_B = stride_scale_B; // stride for B scale + args.stride_C = stride_C; - float ave_time = gemm_calc_aquant( + float ave_time = gemm_mx_calc( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(AQDataType) * M * AQK + - sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; + std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / BlockScaleSize; + std::size_t num_byte = + sizeof(ADataType) * M * K / ck_tile::numeric_traits::PackedSize + + sizeof(BDataType) * K * N / ck_tile::numeric_traits::PackedSize + + sizeof(ck_tile::e8m0_bexp_t) * M * K / BlockScaleSize + + sizeof(ck_tile::e8m0_bexp_t) * K * N / BlockScaleSize + sizeof(CDataType) * M * N; + float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time; std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K - << " StrideA =" << stride_A << " StrideAQ =" << stride_AQ << " StrideB =" << stride_B + << " StrideA =" << stride_A << " StrideScaleA =" << stride_scale_A + << " StrideB =" << stride_B << " StrideScaleB =" << stride_scale_B << " StrideC =" << stride_C << " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name << 
" A_Type = " << DataTypeTraits::name - << " AQ_Type = " << DataTypeTraits::name + << " A_Scale_Type = " << DataTypeTraits::name << " B_Type = " << DataTypeTraits::name + << " B_Scale_Type = " << DataTypeTraits::name << " Acc_Type = " << DataTypeTraits::name << " C_Type = " << DataTypeTraits::name << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; @@ -107,12 +114,13 @@ int run_gemm_example_with_layouts(int argc, if(!result) return -1; - using ADataType = typename TypeConfig::ADataType; - using AScaleDataType = typename TypeConfig::QDataType; - using BDataType = typename TypeConfig::BDataType; - using BScaleDataType = typename TypeConfig::QDataType; - using AccDataType = typename TypeConfig::AccDataType; - using CDataType = typename TypeConfig::CDataType; + using ADataType = typename TypeConfig::ADataType; + using AScaleDataType = typename TypeConfig::ScaleDataType; + using BDataType = typename TypeConfig::BDataType; + using BScaleDataType = typename TypeConfig::ScaleDataType; + using XPackedDataType = typename TypeConfig::ScalePackDataType; + using AccDataType = typename TypeConfig::AccDataType; + using CDataType = typename TypeConfig::CDataType; ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t N = arg_parser.get_int("n"); @@ -177,23 +185,24 @@ int run_gemm_example_with_layouts(int argc, if(init_method == 0) { - ck_tile::FillUniformDistribution{-1.0f, 1.0f, fill_seed(gen)}(a_m_k); - ck_tile::FillUniformDistribution{-1.0f, 1.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{-1.0f, 1.0f, fill_seed(gen)}(a_m_k_scale); - ck_tile::FillUniformDistribution{-1.0f, 1.0f, fill_seed(gen)}(b_k_n_scale); + ck_tile::FillUniformDistribution{-1.0f, 1.0f}(a_m_k); + ck_tile::FillUniformDistribution{-1.0f, 1.0f}(b_k_n); + ck_tile::FillUniformDistribution{-1.0f, 1.0f}(a_m_k_scale); + ck_tile::FillUniformDistribution{-1.0f, 1.0f}(b_k_n_scale); } else if(init_method == 1) { - 
ck_tile::FillConstant{ck_tile::type_convert(ck_tile::float2_t(0.5f))}( - a_m_k); - ck_tile::FillConstant{static_cast(0.5f)}(aq_m_aqk); - ck_tile::FillConstant{static_cast(0x38)}(b_k_n); + ck_tile::FillConstant{1.0f, 1.0f}(a_m_k); + ck_tile::FillConstant{1.0f, 1.0f}(b_k_n); + ck_tile::FillConstant{1.0f, 1.0f}(a_m_k_scale); + ck_tile::FillConstant{1.0f, 1.0f}(b_k_n_scale); } else { a_m_k.SetZero(); - aq_m_aqk.SetZero(); b_k_n.SetZero(); + a_m_k_scale.SetZero(); + b_k_n_scale.SetZero(); } // Shuffle A, B scale tensors @@ -216,9 +225,9 @@ int run_gemm_example_with_layouts(int argc, c_m_n_dev_result.SetZero(); invoke_gemm +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/host/concat.hpp" + +namespace ck_tile { + +struct GemmMXProblem +{ + CK_TILE_HOST GemmMXProblem() = default; + CK_TILE_HOST GemmMXProblem(index_t M_, + index_t N_, + index_t K_, + index_t stride_A_, + index_t stride_B_, + index_t stride_C_, + index_t stride_scale_A_, + index_t stride_scale_B_) + : M(M_), + N(N_), + K(K_), + stride_A(stride_A_), + stride_B(stride_B_), + stride_C(stride_C_), + stride_scale_A(stride_scale_A_), + stride_scale_B(stride_scale_B_) + { + } + + index_t M; + index_t N; + index_t K; + index_t stride_A; + index_t stride_B; + index_t stride_C; + index_t stride_scale_A; + index_t stride_scale_B; +}; + +struct GemmMXHostArgs : public GemmMXProblem +{ + CK_TILE_HOST GemmMXHostArgs() = default; + CK_TILE_HOST GemmMXHostArgs(const void* a_ptr_, + const void* a_scale_ptr_, + const void* b_ptr_, + const void* b_scale_ptr_, + void* c_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + index_t stride_A_, + index_t stride_scale_A_, + index_t stride_B_, + index_t stride_scale_B_, + index_t stride_C_) + : GemmMXProblem( + M_, N_, K_, stride_A_, stride_B_, stride_C_, stride_scale_A_, stride_scale_B_), + a_ptr(a_ptr_), + a_scale_ptr_(a_scale_ptr_), + b_ptr(b_ptr_), + b_scale_ptr_(b_scale_ptr_), + c_ptr(c_ptr_), + k_batch(k_batch_) + 
{ + } + + const void* a_ptr; + const void* a_scale_ptr_; + const void* b_ptr; + const void* b_scale_ptr_; + void* c_ptr; + index_t k_batch; +}; + +struct GemmMXKernelArgs +{ + const void* a_ptr; + const void* a_scale_ptr; + const void* b_ptr; + const void* b_scale_ptr; + void* c_ptr; + index_t M; + index_t N; + index_t K; + index_t stride_A; + index_t stride_scale_A; + index_t stride_B; + index_t stride_scale_B; + index_t stride_C; + index_t k_batch; +}; + +template +struct GemmMXKernel +{ + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + using ALayout = remove_cvref_t; + using AScaleLayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using BScaleLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; + + using ADataType = remove_cvref_t; + using AScaleDataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using BScaleDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto I3 = number<3>(); + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "gemm", gemm_prec_str, GemmPipeline::GetName()); + // clang-format on + } + + CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) + { + return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + + CK_TILE_HOST static constexpr AQuantGemmKernelArgs + MakeKernelArgs(const AQuantGemmHostArgs& hostArgs) + { + return AQuantGemmKernelArgs{hostArgs.a_ptr, + hostArgs.b_ptr, + hostArgs.aq_ptr, + hostArgs.c_ptr, + hostArgs.M, + hostArgs.N, + hostArgs.K, + hostArgs.QK, + hostArgs.stride_A, + hostArgs.stride_B, + hostArgs.stride_C, + 
hostArgs.stride_AQ, + hostArgs.k_batch}; + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(const AQuantGemmKernelArgs& kargs, + const std::size_t k_id = blockIdx.z) + { + constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); + const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1); + const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1); + + if constexpr(std::is_same_v) + { + a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead); + } + else if constexpr(std::is_same_v) + { + a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_A); + } + + if constexpr(std::is_same_v) + { + b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_B); + } + else if constexpr(std::is_same_v) + { + b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead); + } + + if(k_id < static_cast(kargs.k_batch - 1)) + { + splitted_k = __builtin_amdgcn_readfirstlane(KRead); + } + else + { + splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1)); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t splitted_k; + }; + + CK_TILE_HOST static bool IsSupportedArgument(const AQuantGemmKernelArgs& kargs) + { + if(kargs.k_batch != 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Conditions not met for Kbatch >1 !"); + } + return false; + } + + static_assert(std::is_same_v); + if(kargs.QK % GemmPipeline::GetVectorSizeAQ() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!"); + } + return false; + } + + if constexpr(std::is_same_v) + { + if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 && + GemmPipeline::kPadK == false) + { 
+ if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock " + "without padding!"); + } + return false; + } + if(kargs.K % GemmPipeline::GetVectorSizeA() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!"); + } + return false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support M that is not a multiple of MPerBlock without padding!"); + } + return false; + } + if(kargs.M % GemmPipeline::GetVectorSizeA() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!"); + } + return false; + } + } + + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support N that is not a multiple of NPerBlock without padding!"); + } + return false; + } + if(kargs.N % GemmPipeline::GetVectorSizeB() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!"); + } + return false; + } + } + else + { + if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 && + GemmPipeline::kPadK == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock " + "without padding!"); + } + return false; + } + if(kargs.K % GemmPipeline::GetVectorSizeB() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!"); + } + return false; + } + } + + if constexpr(std::is_same_v) + { + if(kargs.N % 
TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support N that is not a multiple of NPerBlock without padding!"); + } + return false; + } + if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!"); + } + return false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support M that is not a multiple of MPerBlock without padding!"); + } + return false; + } + if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!"); + } + return false; + } + } + return true; + } + + template + CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr, + const BDataType* b_ptr, + const AQDataType* aq_ptr, + CDataType* c_ptr, + const AQuantGemmKernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset) + { + static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); + const auto& a_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + a_ptr, + make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + a_ptr, + make_tuple(splitk_batch_offset.splitted_k, kargs.M), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); + } + }(); + + const auto& aq_tensor_view = [&]() { + static_assert(std::is_same_v); + return make_naive_tensor_view( + aq_ptr, + make_tuple(kargs.M, kargs.QK), + make_tuple(kargs.stride_AQ, 1), + number{}, + number<1>{}); + }(); + + const auto& b_tensor_view = [&]() { + if 
constexpr(std::is_same_v) + { + if constexpr(TilePartitioner::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = splitk_batch_offset.splitted_k / K1; + constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = + make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), + make_tuple(kargs.N * K1, K1, I1), + number{}, + number<1>{}); + const auto b_n_k_desc = transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(kargs.N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return make_tensor_view(b_ptr, b_n_k_desc); + } + else + { + return make_naive_tensor_view( + b_ptr, + make_tuple(splitk_batch_offset.splitted_k, kargs.N), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + } + } + else + { + if constexpr(TilePartitioner::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = splitk_batch_offset.splitted_k / K1; + constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = + make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), + make_tuple(kargs.N * K1, K1, I1), + number{}, + number<1>{}); + const auto b_n_k_desc = transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(kargs.N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + return make_tensor_view(b_ptr, b_n_k_desc); + } + else + { + return make_naive_tensor_view( + b_ptr, + make_tuple(kargs.N, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + } + } + }(); + + // TODO: enable vector write for C in ColMajor + const auto& c_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( 
+ c_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.stride_C, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + c_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(1, kargs.stride_C), + number<1>{}, + number<1>{}); + } + }(); + + return make_tuple(a_tensor_view, aq_tensor_view, b_tensor_view, c_tensor_view); + } + /* MakeGemmPadViews: reads the A, AQ, B and C tensor views from views.at(I0) .. views.at(I3), wraps each one with pad_tensor_view and returns the four padded views as a tuple in that order. A, B and C choose between two pad_tensor_view calls with if constexpr; AQ has a single call (padding support for it is still a TODO below). NOTE(review): the template arguments of number, sequence and std::is_same_v are missing from this extract, so the tile lengths and padded dimensions cannot be confirmed here. */ + template + CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) + { + const auto& a_pad_view = [&]() { + const auto& a_tensor_view = views.at(I0); + if constexpr(std::is_same_v) + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + const auto& aq_pad_view = [&]() { + const auto& aq_tensor_view = views.at(I1); + static_assert(std::is_same_v); + return pad_tensor_view( + aq_tensor_view, + make_tuple(number{}, + number{}), + // TODO: Add support for padding. 
+ sequence{}); + }(); + + const auto& b_pad_view = [&]() { + const auto& b_tensor_view = views.at(I2); + if constexpr(std::is_same_v) + { + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + // TODO vector write in for C in ColMajor + const auto& c_pad_view = [&]() { + const auto& c_tensor_view = views.at(I3); + if constexpr(std::is_same_v) + { + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + return make_tuple(a_pad_view, aq_pad_view, b_pad_view, c_pad_view); + } + /* MakeGemmTileWindows: creates one tile window per padded view with make_tile_window and returns them as a tuple (A, AQ, B, C). i_m and i_n are used as window origins: A starts at {i_m, 0} or {0, i_m}, AQ at {i_m, 0}, B at {i_n, 0} or {0, i_n} and C at {i_m, i_n}; the A and B variants are selected with if constexpr. 
*/ + template + CK_TILE_DEVICE static auto + MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) + { + const auto& a_pad_view = views.at(I0); + const auto& aq_pad_view = views.at(I1); +
const auto& b_pad_view = views.at(I2); + const auto& c_pad_view = views.at(I3); + + const auto& a_block_window = [&]() { + if constexpr(std::is_same_v) + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {i_m, 0}); + } + else + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {0, i_m}); + } + }(); + + const auto& aq_block_window = [&]() { + static_assert(std::is_same_v); + return make_tile_window( + aq_pad_view, + make_tuple(number{}, + number{}), + {i_m, 0}); + }(); + + const auto& b_block_window = [&]() { + if constexpr(std::is_same_v) + { + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {i_n, 0}); + } + else + { + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {0, i_n}); + } + }(); + + auto c_block_window = make_tile_window( + c_pad_view, + make_tuple(number{}, number{}), + {i_m, i_n}); + + return make_tuple(a_block_window, aq_block_window, b_block_window, c_block_window); + } + + /** + * @brief Runs a single GEMM problem cooperatively by the whole workgroup. + * + * @param a_ptr input A pointer + * @param b_ptr input B pointer + * @param aq_ptr input AQ pointer + * @param c_ptr output C pointer + * @param smem_ptr_0 The start memory pointer of the shared memory block. + * @param kargs GEMM kernel arguments + * @param splitk_batch_offset Utility structure used to calculate k batch. + * @param block_idx_m Origin along M of the output tile processed by this workgroup (used as the tile window origin; operator() passes iM * MPerBlock). + * @param block_idx_n Origin along N of the output tile processed by this workgroup (used as the tile window origin; operator() passes iN * NPerBlock). + * + * @tparam DstInMemOp Destination memory operation (default: set).
+ */ + template + CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr, + const BDataType* b_ptr, + const AQDataType* aq_ptr, + CDataType* c_ptr, + void* smem_ptr_0, + const AQuantGemmKernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( + a_ptr, b_ptr, aq_ptr, c_ptr, kargs, splitk_batch_offset); + + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + const index_t num_loop = __builtin_amdgcn_readfirstlane( + TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + + // Run GEMM cooperatively by whole workgroup. + const auto& a_block_window = gemm_tile_windows.at(I0); + const auto& aq_block_window = gemm_tile_windows.at(I1); + const auto& b_block_window = gemm_tile_windows.at(I2); + + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr_0); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I3); + + EpiloguePipeline{}.template + operator()( + c_block_window, c_block_tile, c_block_window, smem_ptr_0); + } + /* Kernel entry point. Maps blockIdx.x to an output tile with TilePartitioner::GetOutputTileIndex, scales the tile indices by MPerBlock and NPerBlock to get the offsets i_m and i_n, casts the pointers stored in kargs to the element types, declares a __shared__ char buffer of GetSmemSize() elements and calls RunGemm. The assert requires kargs.k_batch == 1. 
*/ + CK_TILE_DEVICE void operator()(AQuantGemmKernelArgs kargs) const + { + const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x); + const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId); + const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + const SplitKBatchOffset splitk_batch_offset(kargs); + // options + const ADataType* a_ptr = static_cast(kargs.a_ptr); + const BDataType* b_ptr = static_cast(kargs.b_ptr); + const AQDataType* aq_ptr = static_cast(kargs.aq_ptr); + CDataType* c_ptr =
static_cast(kargs.c_ptr); + + // allocate LDS + __shared__ char smem_ptr_0[GetSmemSize()]; + + assert(kargs.k_batch == 1); + RunGemm(a_ptr, b_ptr, aq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_base.hpp new file mode 100644 index 0000000000..a3dd430faa --- /dev/null +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_base.hpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" + +namespace ck_tile { + /* Extends GemmPipelineAgBgCrImplBase with the AQ (A scale) pieces: KPerBlockAQ = KPerBlock / QuantGroupSize, which must divide evenly, and GetAQDramLoadWindow, which rebuilds the AQ DRAM window with the distribution from Policy::MakeAQDramTileDistribution. NOTE(review): QuantGroupSize is read from Problem::kQuantGroupSize, while the problem struct added by this patch declares kBlockScaleSize; confirm the intended name. 
*/ +template +struct GemmMXPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase +{ + using Base = GemmPipelineAgBgCrImplBase; + using ADataType = typename Base::ADataType; + using ALayout = typename Base::ALayout; + using BDataType = typename Base::BDataType; + using BLayout = typename Base::BLayout; + using BlockGemmShape = typename Base::BlockGemmShape; + + using AQLayout = remove_cvref_t; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize; + static constexpr index_t KPerBlockAQ = KPerBlock / QuantGroupSize; + + static_assert(KPerBlock % QuantGroupSize == 0, + "KPerBlock must be a multiple of QuantGroupSize"); + + // Create DRAM tile window for AQ + template + CK_TILE_DEVICE constexpr auto + GetAQDramLoadWindow(const AQDramBlockWindowTmp& aq_dram_block_window_tmp) const + { + static_assert(std::is_same_v); + + using YPerTile = number; + using XPerTile = number; + + auto aq_copy_dram_window =
make_tile_window(aq_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(YPerTile(), XPerTile()), + aq_dram_block_window_tmp.get_window_origin(), + Policy::template MakeAQDramTileDistribution()); + return aq_copy_dram_window; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp new file mode 100644 index 0000000000..3c165e2e91 --- /dev/null +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "gemm_mx_utils.hpp" + +namespace ck_tile { + /* Default policy for the MX pipeline. Inherits UniversalGemmPipelineAgBgCrPolicy and adds the A-scale vector size, the AQ DRAM tile distribution and the block GEMM selection. */ +struct GemmMXPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgCrPolicy +{ + using Base = UniversalGemmPipelineAgBgCrPolicy; + using Base::I0; + using Base::I1; + using Base::I2; + + using Base::ATileAccessPattern; + using Base::BTileAccessPattern; + /* Vector size for A-scale loads: derives KPerBlockScale = KPerBlock / Problem::kBlockScaleSize and returns the result of GetAScaleGlobalVectorLoadSize (the template arguments of that call are missing from this extract). 
*/ + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeAScale() + { + using AScaleLayout = remove_cvref_t; + using AScaleDataType = remove_cvref_t; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockScale = KPerBlock / Problem::kBlockScaleSize; + + static_assert(std::is_same_v); + return GetAScaleGlobalVectorLoadSize(); + } + /* Builds the DRAM tile distribution for AQ via TileDistributionEncodingPatternAQ::Make2DStaticTileDistribution. NOTE(review): this function reads Problem::kQuantGroupSize and calls GetVectorSizeAQ, while this struct defines GetVectorSizeAScale and the MX problem in this patch declares kBlockScaleSize; confirm these names resolve (for example through a base class). */ + template + CK_TILE_HOST_DEVICE static constexpr auto MakeAQDramTileDistribution() + { + using AQLayout = remove_cvref_t; + using BlockGemmShape = typename Problem::BlockGemmShape; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize; +
constexpr index_t VecLoadSize = GetVectorSizeAQ(); + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmMfmaDispatcher; + + static_assert(std::is_same_v); + using TileEncodingPattern = TileDistributionEncodingPatternAQ; + + return TileEncodingPattern::Make2DStaticTileDistribution(); + } + /* Picks the warp GEMM through WarpGemmMfmaDispatcher and returns a default-constructed AQuantBlockUniversalGemmAsBsCr. NOTE(review): the first static_assert below tests kQuantGroupSize % WarpTile::at(I2) == 0, i.e. that the quant group size is a multiple of the warp tile entry at I2, whereas its message states the opposite relation; confirm which one is intended. */ + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + + static_assert(Problem::kQuantGroupSize % WarpTile::at(I2) == 0, + "KPerWarpGemm must be a multiple of kQuantGroupSize!"); + + using WarpGemm = WarpGemmMfmaDispatcher; + static_assert(std::is_same_v || + std::is_same_v); + static_assert(std::is_same_v); + using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy; + return AQuantBlockUniversalGemmAsBsCr{}; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp new file mode 100644 index 0000000000..e2fdd7d443 --- /dev/null +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp @@ -0,0 +1,480 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+ +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/host/concat.hpp" + +namespace ck_tile { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + /* Adds TailHandler on top of BaseGemmPipelineAgBgCrCompV3. TailHandler dispatches on the runtime values has_hot_loop and tail_number (Full, Odd or Even) and calls run_func with a bool_constant and an integral_constant; their template arguments are missing from this extract, presumably they carry the same two values. Any other tail number throws std::runtime_error. NOTE(review): TailHandler is CK_TILE_HOST_DEVICE but uses throw; confirm it is only called from host code. 
*/ +template +struct BaseGemmMXPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 +{ + template + CK_TILE_HOST_DEVICE static auto + TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number) + { + if(has_hot_loop) + { + if(tail_number == ck_tile::TailNumber::Full) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_number == ck_tile::TailNumber::Odd) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_number == ck_tile::TailNumber::Even) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + throw std::runtime_error("Unsupported tail number for this operation !!!"); + } + } + else + { + if(tail_number == ck_tile::TailNumber::Full) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_number == ck_tile::TailNumber::Odd) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_number == ck_tile::TailNumber::Even) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + throw std::runtime_error("Unsupported tail number for this operation !!!"); + } + } + } +}; + /* MX GEMM pipeline (CompV3 variant). The implementation lives in the PipelineImpl specialization further down: A and B tiles are staged through LDS windows, while the AQ tiles are kept in a two-entry array and handed to block_gemm directly. */ +template +struct GemmMXPipelineAgBgCrCompV3 : public BaseGemmMXPipelineAgBgCrCompV3 +{ + using Base = BaseGemmPipelineAgBgCrCompV3; + using PipelineImplBase = GemmMXPipelineAgBgCrImplBase; + + using
ADataType = remove_cvref_t; + using AScaleDataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using BScaleDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + using I0 = number<0>; + using I1 = number<1>; + using I2 = number<2>; + + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + + static constexpr index_t AScalePackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BScalePackedSize = + ck_tile::numeric_traits>::PackedSize; + + using ALayout = remove_cvref_t; + using AScaleLayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using BScaleLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using BlockGemm = remove_cvref_t())>; + + static constexpr index_t BlockSize = Problem::kBlockSize; + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + static constexpr index_t BlockScaleSize = Problem::kBlockScaleSize; + static constexpr index_t KPerBlockScale = BlockGemmShape::kK / BlockScaleSize; + /* The vector-size and LDS-pack getters below forward to Policy. NOTE(review): GetVectorSizeScale forwards to Policy::GetVectorSizeAQ, whereas GemmMXPipelineAgBgCrDefaultPolicy defines GetVectorSizeAScale; confirm. 
*/ + static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } + static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } + static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } + static constexpr index_t GetVectorSizeScale() + { + return Policy::template GetVectorSizeAQ(); + } + + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr bool DoubleSmemBuffer =
Problem::DoubleSmemBuffer; + + static constexpr bool HasHotLoop = Problem::HasHotLoop; + static constexpr auto TailNum = Problem::TailNum; + static constexpr auto Scheduler = Problem::Scheduler; + + using Base::PrefetchStages; + /* Returns an identifier joined with underscores by concat: tile sizes, BlockSize, wave counts, warp GEMM sizes, padding flags, the BlockSize text label and QuantGroupSize. NOTE(review): QuantGroupSize is not declared in this struct (BlockScaleSize is); confirm. */ + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + return concat('_', "mx_pipeline_AgBgCrCompV3", + concat('x', MPerBlock, NPerBlock, KPerBlock), + BlockSize, + concat('x', WaveNumM, WaveNumN), + concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK), + concat('x', kPadM, kPadN, kPadK), "BlockSize", QuantGroupSize); + // clang-format on + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + /* Returns a multi-line text summary of derived values: vector sizes, LDS read widths, buffer-load, LDS write, LDS read and MFMA instruction counts, KPack and PrefetchStages. WaveSize is fixed to 64 inside this function. NOTE(review): KPerBlockAQ, GetVectorSizeAQ() and QuantGroupSize are used below but are not declared in this struct (KPerBlockScale, GetVectorSizeScale() and BlockScaleSize are); confirm. 
*/ + CK_TILE_HOST static std::string Print() + { + constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; + constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; + constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = 64; + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + + constexpr index_t A_LDS_Read_Width = GetSmemPackA(); + constexpr index_t B_LDS_Read_Width = GetSmemPackB(); + + constexpr index_t A_LDS_Write_Width = GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = GetSmemPackB(); + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + constexpr index_t AQ_Buffer_Load_Inst_Num = + MPerBlock * KPerBlockAQ / (BlockSize * GetVectorSizeAQ()); + + constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width); + constexpr index_t B_LDS_Write_Inst_Num = +
NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width); + + constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + auto str = std::stringstream{}; + + str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << ", " + << "AQ vector size: " << GetVectorSizeAQ() << "\n" + << "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n" + << "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num + << ", " + << "AQ buffer load inst: " << AQ_Buffer_Load_Inst_Num << "\n" + << "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num + << "\n" + << "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n" + << "C MFMA inst: " << C_MFMA_Inst_Num << "\n" + << "QuantGroupSize: " << QuantGroupSize << "\n" + << "KPack: " << BlockGemm::Traits::KPack << "\n" + << "PrefetchStages: " << PrefetchStages << "\n"; + return str.str(); + } + /* The primary PipelineImpl template is left empty; the explicit specialization that follows supplies operator(). */ + template + struct PipelineImpl : public PipelineImplBase + { + }; + + template <> + struct PipelineImpl : public PipelineImplBase + { + using Base = PipelineImplBase; + /* Runs the pipeline and returns the accumulated C block tile. 
Visible steps: (1) static_asserts on window data types and lengths; (2) creation of the A and B DRAM, LDS-copy and LDS-GEMM windows plus the AQ DRAM window; (3) prologue: prefetch the first A, B and AQ tiles, zero the accumulator, write A and B to LDS (transposing first when A is column-major or B is row-major), prefetch the next A and B tiles; (4) hot loop, compiled only when HasHotLoop, repeated while i < num_loop - 1; (5) tail: one block_gemm for TailNumber::Full or Odd, otherwise two block_gemm calls with an LDS refill in between. */ + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + const AQDramBlockWindowTmp& aq_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "A/B/AQ Dram block window should have the same data type as appropriate " + "([A|B|AQ]DataType) defined in Problem definition!"); + + constexpr bool
is_a_col_major = + std::is_same_v; + constexpr bool is_aq_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + /* Compile-time checks: AQ must be row-major, and each block window must have the lengths that its layout implies. */ + static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)"); + static_assert(MPerBlock == AQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlockAQ == AQDramBlockWindowTmp{}.get_window_lengths()[I1{}], + "Aq block window has incorrect lengths for defined AqLayout!"); + + static_assert(is_a_col_major + ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert(is_b_row_major + ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "B block window has incorrect lengths for defined BLayout!"); + + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + using AQDramTileWindowStep = typename AQDramBlockWindowTmp::BottomTensorIndex; + + auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem); + + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); + auto&& 
[b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] = + Base::GetBWindows(b_dram_block_window_tmp,
b_lds_block, b_lds_load_tile_distr); + auto aq_copy_dram_window = Base::GetAQDramLoadWindow(aq_dram_block_window_tmp); + + using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); + using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution()); + + using ABlockTile = + decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = + decltype(make_static_distributed_tensor(BBlockTileDistr{})); + using AQBlockTile = + decltype(make_static_distributed_tensor(AQBlockTileDistr{})); + + auto block_gemm = BlockGemm(); + /* Two AQ tiles are kept so that the next one can be prefetched while the current one is consumed; currIdx selects the tile passed to the next block_gemm call. */ + ABlockTile a_block_tile; + BBlockTile b_block_tile; + AQBlockTile aq_block_tile[2]; + int currIdx = 0; + + auto c_block_tile = block_gemm.MakeCBlockTile(); + + constexpr ADramTileWindowStep a_dram_tile_window_step = + is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = + is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr AQDramTileWindowStep aq_dram_tile_window_step = + is_aq_col_major ?
make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ); + + // DRAM prefetch (global read 0) + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch( + aq_block_tile[currIdx], aq_copy_dram_window, aq_dram_tile_window_step); + + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffled2DStaticTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffled2DStaticTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } + + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + + block_sync_lds(); + + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + + __builtin_amdgcn_sched_barrier(0); + /* Hot loop: each iteration refills LDS from the A and B tiles prefetched earlier, issues the next A, B and AQ global prefetches, then runs block_gemm with the AQ tile at currIdx. 
*/ + if constexpr(HasHotLoop) + { + index_t i = 0; + do + { + block_sync_lds(); + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template
MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } + + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2], + aq_copy_dram_window, + aq_dram_tile_window_step); + + block_gemm( + c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window); + /* Switch to the AQ tile that was prefetched above. */ + currIdx = (currIdx + 1) % 2; + + block_sync_lds(); + + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + __builtin_amdgcn_sched_barrier(0); + + i += 1; + } while(i < (num_loop - 1)); + } + // tail: Full and Odd run one more block_gemm; any other tail number runs two, with an LDS refill in between + if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) + { + block_gemm( + c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window); + } + else + { + Base::GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2], + aq_copy_dram_window, + aq_dram_tile_window_step); + block_gemm( + c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + + currIdx = (currIdx + 1) % 2; + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + 
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } +
block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm( + c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window); + } + return c_block_tile; + } + }; + /* Convenience overload: forwards to PipelineImpl with element functions for A and B that return their argument unchanged. */ template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const AQDramBlockWindowTmp& aq_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + aq_dram_block_window_tmp, + num_loop, + p_smem); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_problem.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_problem.hpp new file mode 100644 index 0000000000..35a775b4ec --- /dev/null +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_problem.hpp @@ -0,0 +1,126 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" + +#include + +namespace ck_tile { + /* Problem description for the MX pipeline. Extends GemmPipelineProblemBase with the A and B scale data types and layouts, kBlockScaleSize, and the Scheduler, HasHotLoop and TailNum settings. The static_asserts below require BlockGemmShape::kK to be a multiple of kBlockScaleSize and the scheduler to be GemmPipelineScheduler::Intrawave. */ +template +struct GemmMXPipelineProblemBase : public GemmPipelineProblemBase +{ + using Base = GemmPipelineProblemBase; + + using Traits = typename Base::Traits; + + using typename Base::ADataType; + using typename Base::BDataType; + using typename Base::CDataType; + using typename Base::ComputeDataType; + using AScaleDataType = remove_cvref_t; + using BScaleDataType = remove_cvref_t; + + using BlockGemmShape = typename Base::BlockGemmShape; + + using typename Base::ALayout; + using typename Base::BLayout; + using typename Base::CLayout; + + static constexpr bool TransposeC = false; + + using Base::kBlockSize; + + using Base::kPadK; + using Base::kPadM; + using Base::kPadN; + + using Base::DoubleSmemBuffer; + using Base::VectorLoadSize; + + using AScaleLayout = remove_cvref_t; + using BScaleLayout = remove_cvref_t; + + static constexpr uint32_t kBlockScaleSize = BlockScaleSize_; + static constexpr auto Scheduler = Scheduler_; + static constexpr auto HasHotLoop = HasHotLoop_; + static constexpr auto TailNum = TailNum_; + + static_assert(BlockGemmShape::kK % kBlockScaleSize == 0); + static_assert(Scheduler == GemmPipelineScheduler::Intrawave); + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "gemm_mx_problem", + concat('x', VectorLoadSize, kBlockSize), + concat('x', kPadM, kPadN, kPadK), + Scheduler, + "BlockScaleSize", + kBlockScaleSize); + // 
clang-format on + } + /* Returns VectorLoadSize divided by sizeof(AScaleDataType); presumably VectorLoadSize is expressed in bytes, confirm against GemmPipelineProblemBase. VectorSizeAScale below is 1 when kPadK is set and this value otherwise. */ + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentAScale() + { + static_assert(std::is_same_v); + return VectorLoadSize / sizeof(AScaleDataType); + } + + static constexpr index_t VectorSizeAScale = []() { + static_assert(std::is_same_v); + return kPadK ?
1 : GetAlignmentAScale(); + }(); +}; + +template +using GemmMXPipelineProblem = GemmMXPipelineProblemBase; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_utils.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_utils.hpp new file mode 100644 index 0000000000..9d7a8abaa2 --- /dev/null +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_utils.hpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" + +namespace ck_tile { + /* Picks the vector load width for the A-scale tile. Candidates are PackedSize * {32, 16, 8, 4, 2} / sizeof(DataType), tried from largest to smallest. A candidate is returned only if it is positive, divides both XPerTile and elements_per_thread, and differs from the last candidate (candidates[4]); otherwise the function falls back to PackedSize. NOTE(review): the is_valid expression repeats the three tests already made by the continue guard, so only the candidates[4] exclusion is new; confirm that excluding the smallest candidate is intended. NOTE(review): std::array is used but no array header include is visible in this file; presumably it arrives through the included policy header. 
*/ +template +CK_TILE_HOST_DEVICE static constexpr auto GetAScaleGlobalVectorLoadSize() +{ + using I1 = number<1>; + constexpr index_t NWarps = Problem::BlockGemmShape::BlockWarps::at(I1{}); + + constexpr index_t BlockSize = Problem::kBlockSize; + + // Data is replicated across warps along NWarps, so we divide BlockSize by NWarps + constexpr index_t elements_per_thread = (YPerTile * XPerTile) / (BlockSize / NWarps); + constexpr index_t PackedSize = ck_tile::numeric_traits>::PackedSize; + + // Define vector load candidates in descending order of priority + constexpr std::array candidates{ + PackedSize * 32 / sizeof(DataType), + PackedSize * 16 / sizeof(DataType), + PackedSize * 8 / sizeof(DataType), + PackedSize * 4 / sizeof(DataType), + PackedSize * 2 / sizeof(DataType), + }; + + for(const auto vec_size : candidates) + { + if(vec_size <= 0 || XPerTile % vec_size != 0 || elements_per_thread % vec_size != 0) + continue; + bool is_valid = (vec_size > 0) && (XPerTile % vec_size == 0) && + (elements_per_thread % vec_size == 0) && vec_size != candidates[4]; + if(is_valid) + { + return vec_size; + } + } + return PackedSize; // Absolute fallback +} + +// AQ holds groupquant scale data for A. Data is loaded from DRAM and partitioned across +// threads. Post mfma scales are shuffled across threads in the warp and applied to +// accum registers.
+template +struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPattern +{ + // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk! + static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!"); + static constexpr index_t warp_size = get_warp_size(); + static constexpr index_t num_warps = BlockSize / get_warp_size(); + + static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{}); + static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{}); + static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{}); + + static constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarps * WarpGemm::kM); + + static_assert(num_warps == MWarps * NWarps * KWarps); + + // KWarps > 1 isn't supported + static_assert(KWarps == 1); + + // # of elements per thread + static constexpr index_t X = XPerTile; + /* The Y extent is factored as Y0 * Y1 * Y2 * Y3 and must equal YPerTile: Y0 is 1, Y1 is MIterPerWarp (or 1 when MIterPerWarp is 0), Y2 is MWarps and Y3 is WarpGemm::kM. NOTE(review): Y3 is defined as WarpGemm::kM, so the Y3 >= WarpGemm::kM assertion below always holds. */ + static constexpr index_t Y0 = 1; + static constexpr index_t Y1 = MIterPerWarp ? MIterPerWarp : 1; + static constexpr index_t Y2 = MWarps; + static constexpr index_t Y3 = WarpGemm::kM; + static_assert(Y3 >= WarpGemm::kM, "Scales for all rows must be available within the warp."); + static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile, + "Y0, Y1, Y2, Y3 must cover the blocktile along Y."); + /* Returns the static tile distribution built from a tile_distribution_encoding; most of the encoding template arguments are missing from this extract, so the mapping of Y0..Y3 and X onto warps and lanes cannot be confirmed here. 
*/ + CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution() + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 1>>, + tuple, sequence<0, 3>>, + sequence<1, 2>, + sequence<1, 0>>{}); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_mx/pipeline/tile_gemm_mx_traits.hpp b/include/ck_tile/ops/gemm_mx/pipeline/tile_gemm_mx_traits.hpp new file mode 100644 index 0000000000..e95d1bd72b --- /dev/null +++ b/include/ck_tile/ops/gemm_mx/pipeline/tile_gemm_mx_traits.hpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+ +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + /* Compile-time traits bundle for the MX GEMM: re-exports the padding flags for M, N and K and the layouts of A, A scale, B, B scale and C from its template parameters, and fixes _VectorSize to 16, UseStructuredSparsity to false and NumWaveGroups to 1. The template parameter list itself is missing from this extract. */ +template +struct TileGemmMXTraits +{ + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + static constexpr bool kPadK = kPadK_; + + static constexpr int _VectorSize = 16; + + using ALayout = ALayout_; + using AScaleLayout = AScaleLayout_; + using BLayout = BLayout_; + using BScaleLayout = BScaleLayout_; + using CLayout = CLayout_; + + static constexpr bool UseStructuredSparsity = false; + static constexpr index_t NumWaveGroups = 1; +}; + +} // namespace ck_tile