From d559ec00a8c9829642ebf6efe53ef62a7825043c Mon Sep 17 00:00:00 2001 From: Enrico Degregori <73224202+EnricoDeg@users.noreply.github.com> Date: Wed, 1 Jul 2026 08:21:02 +0000 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#8554 (commit be9af54) refactor(ck): mx gemm kernel unification ## Motivation CK tile currently has two separate MX GEMM kernels for gfx950 and gfx1250. This pull request refactors and modernizes the MX GEMM kernel and example to use new scale tensor handling, improved kernel argument structures, and updated pipeline and kernel APIs. The changes simplify the interface and improve type safety. JIRA ID ROCM-26313 ## Technical Details - Add support for gfx950 in MX GEMM kernel for gfx1250 and remove unused kernel - Unify comp async pipeline for GEMM and MX GEMM - Unify eight waves pipeline for GEMM and MX GEMM - Move preshuffle MX GEMM pipeline to gemm ops and remove gemm_mx ops - Unify testing framework for MX GEMM - Add gfx950 tests for grouped MX GEMM ## Test Plan - `test_mx_gemm_async.cpp` for MX GEMM on gfx950 - `test_mx_grouped_gemm_comp_async.cpp` for grouped MX GEMM on gfx950 ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --- example/ck_tile/42_mx_gemm/mx_gemm.cpp | 38 +- example/ck_tile/42_mx_gemm/mx_gemm.hpp | 46 +- .../ck_tile/42_mx_gemm/mx_gemm_instance.hpp | 67 +- example/ck_tile/42_mx_gemm/run_mx_gemm.inc | 166 ++-- include/ck_tile/core/tensor/null_tensor.hpp | 18 + include/ck_tile/host/mx_processing.hpp | 247 ++--- .../ck_tile/host/reference/reference_gemm.hpp | 127 +++ include/ck_tile/ops/gemm.hpp | 3 + ...ock_gemm_areg_breg_creg_eight_waves_v1.hpp | 210 ++++- .../block/block_gemm_areg_breg_creg_v1.hpp | 219 +++-- ...k_gemm_areg_breg_creg_v1_custom_policy.hpp | 5 +- ...gemm_asmem_bsmem_creg_v1_custom_policy.hpp | 2 + .../block/block_mx_asmem_breg_creg.hpp | 0 .../ops/gemm/kernel/mx_gemm_kernel.hpp | 353 ++++++- .../gemm/kernel/mx_grouped_gemm_kernel.hpp | 2 +- .../ops/gemm/kernel/universal_gemm_kernel.hpp | 14 +- .../gemm_pipeline_ag_bg_cr_comp_async.hpp | 213 ++++- ...ine_ag_bg_cr_comp_async_default_policy.hpp | 189 +++- ...peline_ag_bg_cr_comp_async_eight_waves.hpp | 86 +- ...ag_bg_cr_comp_async_eight_waves_policy.hpp | 148 ++- .../gemm_pipeline_ag_bg_cr_comp_tdm_v1.hpp | 4 + ...emm_pipeline_ag_bg_cr_eight_waves_base.hpp | 48 +- .../wp_mx_pipeline_agmem_bgmem_creg_v1.hpp} | 128 ++- ...x_pipeline_agmem_bgmem_creg_v1_policy.hpp} | 7 +- include/ck_tile/ops/gemm_mx.hpp | 20 - include/ck_tile/ops/gemm_mx/README.md | 0 ..._mx_gemm_areg_breg_creg_eight_waves_v1.hpp | 310 ------- .../block/block_mx_gemm_areg_breg_creg_v1.hpp | 324 ------- .../ops/gemm_mx/kernel/gemm_mx_kernel.hpp | 863 ------------------ .../ops/gemm_mx/kernel/scale_pointer.hpp | 120 --- .../gemm_pipeline_ag_bg_cr_comp_async.hpp | 782 ---------------- ...ine_ag_bg_cr_comp_async_default_policy.hpp | 605 ------------ ...peline_ag_bg_cr_comp_async_eight_waves.hpp | 282 ------ ...ag_bg_cr_comp_async_eight_waves_policy.hpp | 201 ---- ...t_pipeline_ag_bg_cr_eight_waves_policy.hpp | 66 +- test/ck_tile/gemm_mx/CMakeLists.txt | 12 +- test/ck_tile/gemm_mx/test_mx_gemm_async.cpp | 201 ---- .../gemm_mx/test_mx_gemm_async_ccr.cpp | 22 + .../gemm_mx/test_mx_gemm_async_crr.cpp | 22 + .../gemm_mx/test_mx_gemm_async_rcr.cpp | 51 ++ .../test_mx_gemm_async_rcr_large_cases.cpp | 29 + .../gemm_mx/test_mx_gemm_async_rrr.cpp | 22 + test/ck_tile/gemm_mx/test_mx_gemm_config.hpp | 169 ---- .../ck_tile/gemm_mx/test_mx_gemm_instance.hpp | 143 --- .../test_mx_gemm_pipeline_kernel_types.hpp | 127 ++- .../test_mx_gemm_pipeline_tr_cases.inc | 27 + .../test_mx_gemm_pipeline_ut_cases.inc | 4 +- .../gemm_mx/test_mx_gemm_pipeline_util.hpp | 671 ++++++++++++-- test/ck_tile/gemm_mx/test_mx_gemm_util.hpp | 220 ----- test/ck_tile/grouped_gemm_mx/CMakeLists.txt | 22 +- .../grouped_gemm_mx/test_mx_grouped_gemm.cpp | 44 - .../test_mx_grouped_gemm_comp_async.cpp | 26 + ...mx_grouped_gemm_comp_async_large_cases.cpp | 34 + .../test_mx_grouped_gemm_largeM_cases.inc | 67 ++ ..._mx_grouped_gemm_pipeline_kernel_types.hpp | 78 ++ .../test_mx_grouped_gemm_ut_cases.inc | 6 +- .../test_mx_grouped_gemm_util.hpp | 821 ++++++++++++++--- .../test_mx_grouped_gemm_wmma_tdm.cpp | 86 ++ tile_engine/ops/gemm/gemm_instance_builder.py | 29 +- .../ops/gemm/mx_gemm/mx_gemm_profiler.hpp | 74 +- 60 files changed, 3703 insertions(+), 5217 deletions(-) rename include/ck_tile/ops/{gemm_mx => gemm}/block/block_mx_asmem_breg_creg.hpp (100%) rename include/ck_tile/ops/{gemm_mx/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp => gemm/pipeline/wp_mx_pipeline_agmem_bgmem_creg_v1.hpp} (86%) rename include/ck_tile/ops/{gemm_mx/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp => gemm/pipeline/wp_mx_pipeline_agmem_bgmem_creg_v1_policy.hpp} (98%) delete mode 100644 include/ck_tile/ops/gemm_mx.hpp delete mode 100644 include/ck_tile/ops/gemm_mx/README.md delete mode 100644 include/ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_eight_waves_v1.hpp delete mode 100644 include/ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_v1.hpp delete mode 100644 include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp delete mode 100644 include/ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp delete mode 100644 include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp delete mode 100644 include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp delete mode 100644 include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp delete mode 100644 include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp delete mode 100644 test/ck_tile/gemm_mx/test_mx_gemm_async.cpp create mode 100644 test/ck_tile/gemm_mx/test_mx_gemm_async_ccr.cpp create mode 100644 test/ck_tile/gemm_mx/test_mx_gemm_async_crr.cpp create mode 100644 test/ck_tile/gemm_mx/test_mx_gemm_async_rcr.cpp create mode 100644 test/ck_tile/gemm_mx/test_mx_gemm_async_rcr_large_cases.cpp create mode 100644 test/ck_tile/gemm_mx/test_mx_gemm_async_rrr.cpp delete mode 100644 test/ck_tile/gemm_mx/test_mx_gemm_config.hpp delete mode 100644 test/ck_tile/gemm_mx/test_mx_gemm_instance.hpp create mode 100644 test/ck_tile/gemm_mx/test_mx_gemm_pipeline_tr_cases.inc delete mode 100644 test/ck_tile/gemm_mx/test_mx_gemm_util.hpp delete mode 100644 test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm.cpp create mode 100644 test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_comp_async.cpp create mode 100644 test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_comp_async_large_cases.cpp create mode 100644 test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_largeM_cases.inc create mode 100644 test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_pipeline_kernel_types.hpp create mode 100644 test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_wmma_tdm.cpp diff --git a/example/ck_tile/42_mx_gemm/mx_gemm.cpp b/example/ck_tile/42_mx_gemm/mx_gemm.cpp index c5cdc84689..f26a4c9ff5 100644 --- a/example/ck_tile/42_mx_gemm/mx_gemm.cpp +++ b/example/ck_tile/42_mx_gemm/mx_gemm.cpp @@ -24,13 +24,13 @@ static constexpr inline auto is_row_major(Layout layout_) template float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf, ck_tile::DeviceMem& b_dev_buf, @@ -42,36 +42,38 @@ float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf, ck_tile::index_t stride_B, ck_tile::index_t stride_C, ck_tile::index_t kbatch, - ScaleM scale_m, - ScaleN scale_n, + ck_tile::DeviceMem& scale_m, + ck_tile::DeviceMem& scale_n, int n_warmup, int n_repeat) { - MXGemmHostArgs args(a_dev_buf.GetDeviceBuffer(), - b_dev_buf.GetDeviceBuffer(), - c_dev_buf.GetDeviceBuffer(), - kbatch, - M, - N, - K, - stride_A, - stride_B, - stride_C, - scale_m, - scale_n); + ck_tile::MxGemmHostArgs<1, 1, 0> args({static_cast(a_dev_buf.GetDeviceBuffer())}, + {static_cast(scale_m.GetDeviceBuffer())}, + {static_cast(b_dev_buf.GetDeviceBuffer())}, + {static_cast(scale_n.GetDeviceBuffer())}, + {}, + c_dev_buf.GetDeviceBuffer(), + kbatch, + M, + N, + K, + {stride_A}, + {stride_B}, + {}, + stride_C); // Simplified invocation - comp_async handles hot loop and tail internally auto invoke_splitk_path = [&](auto split_k_) { return mx_gemm_calc( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); diff --git a/example/ck_tile/42_mx_gemm/mx_gemm.hpp b/example/ck_tile/42_mx_gemm/mx_gemm.hpp index 1ee491f7df..f35df974c3 100644 --- a/example/ck_tile/42_mx_gemm/mx_gemm.hpp +++ b/example/ck_tile/42_mx_gemm/mx_gemm.hpp @@ -9,49 +9,8 @@ #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" -#include "ck_tile/ops/gemm_mx.hpp" -#include "ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp" - -template -struct MXGemmHostArgs : ck_tile::UniversalGemmHostArgs<1, 1, 0> -{ - using Base = ck_tile::UniversalGemmHostArgs<1, 1, 0>; - - MXGemmHostArgs(const void* a_ptr, - const void* b_ptr, - void* c_ptr_, - ck_tile::index_t k_batch_, - ck_tile::index_t M_, - ck_tile::index_t N_, - ck_tile::index_t K_, - ck_tile::index_t stride_A_, - ck_tile::index_t stride_B_, - ck_tile::index_t stride_C_, - ScaleM scale_m_, - ScaleN scale_n_) - : Base({a_ptr}, - {b_ptr}, - {}, - c_ptr_, - k_batch_, - M_, - N_, - K_, - {stride_A_}, - {stride_B_}, - {}, - stride_C_), - scale_m(scale_m_), - scale_n(scale_n_) - { - } - - ScaleM scale_m; - ScaleN scale_n; -}; // GEMM config with 16x16 warp tile - struct MxGemmConfig { static constexpr ck_tile::index_t M_Tile = 128; @@ -78,12 +37,15 @@ struct MxGemmConfig static constexpr int TileParitionerM01 = 4; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; static constexpr ck_tile::index_t NumWaveGroups = 1; - static constexpr bool DoubleSmemBuffer = false; // comp_async uses double buffer + static constexpr bool DoubleSmemBuffer = true; // comp_async uses double buffer static constexpr bool Preshuffle = false; static constexpr ck_tile::index_t BContiguousItemsPerAccess = 16; static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; static constexpr bool TiledMMAPermuteN = false; + + using AScaleDataType = ck_tile::e8m0_t; + using BScaleDataType = ck_tile::e8m0_t; }; struct MX_GemmConfigEightWaves : MxGemmConfig diff --git a/example/ck_tile/42_mx_gemm/mx_gemm_instance.hpp b/example/ck_tile/42_mx_gemm/mx_gemm_instance.hpp index c7ea374862..a7b57fb330 100644 --- a/example/ck_tile/42_mx_gemm/mx_gemm_instance.hpp +++ b/example/ck_tile/42_mx_gemm/mx_gemm_instance.hpp @@ -5,10 +5,9 @@ #include "ck_tile/host.hpp" #include "mx_gemm.hpp" -#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp" -#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp" -#include "ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp" -#include "ck_tile/ops/gemm_mx/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_mx_pipeline_agmem_bgmem_creg_v1.hpp" template using is_row_major_t = ck_tile::bool_constant< @@ -17,16 +16,16 @@ using is_row_major_t = ck_tile::bool_constant< template -float mx_gemm_calc(const MXGemmHostArgs& args, const ck_tile::stream_config& s) +float mx_gemm_calc(const ck_tile::MxGemmHostArgs<1, 1, 0>& args, const ck_tile::stream_config& s) { using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -51,29 +50,38 @@ float mx_gemm_calc(const MXGemmHostArgs& args, const ck_tile::st static_assert(sizeof(ComputeDataType) >= sizeof(BDataType), "mixed_prec_gemm requires ADataType is a wider type than BDataType"); - using MXPipelineProblem = ck_tile::UniversalGemmPipelineProblem; + using AComputeDataType = ADataType; + using BComputeDataType = BDataType; + + using MXPipelineProblem = ck_tile::MxGemmPipelineProblem; - // Use the MX GEMM Preshuffle pipeline or - // the new MX comp_async pipeline with MX scaling support constexpr bool IsEightWave = (GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp) == 8; using MXGemmPipeline = std::conditional_t< GemmConfig::Preshuffle, ck_tile::MXGemmPreshufflePipelineAGmemBGmemCRegV1, std::conditional_t, - ck_tile::MXGemmPipelineAgBgCrCompAsync>>; + ck_tile::GemmPipelineAgBgCrCompAsyncEightWaves, + ck_tile::GemmPipelineAgBgCrCompAsync>>; using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner; + constexpr ck_tile::index_t BlockedXDLNPerWarp = GemmConfig::Preshuffle ? 2 : 1; + using GemmEpilogue = std::conditional_t& args, const ck_tile::st GemmConfig::NumWaveGroups, false, // FixedVectorSize_ (Default) 1, // VectorSizeC_ (Default) - ck_tile::MXEpilogueTraits::BlockedXDLNPerWarp, + BlockedXDLNPerWarp, false, // DoubleSmemBuffer_ (Default) ComputeDataType, // AComputeDataType ComputeDataType, // BComputeDataType !GemmConfig::Preshuffle>>>; // TilesPacked_ (because of // packed scales) - using Kernel = ck_tile::MXGemmKernel; + using Kernel = ck_tile::MxGemmKernel; - auto kargs = Kernel::MakeKernelArgs(std::array{args.as_ptr}, - std::array{args.bs_ptr}, - std::array{}, - args.e_ptr, - args.k_batch, - args.M, - args.N, - args.K, - std::array{args.stride_As}, - std::array{args.stride_Bs}, - std::array{}, - args.stride_E, - args.scale_m, - args.scale_n); + auto kargs = Kernel::MakeKernelArgs(args); if(!Kernel::IsSupportedArgument(kargs)) { @@ -145,8 +140,10 @@ float mx_gemm_calc(const MXGemmHostArgs& args, const ck_tile::st "MX GEMM: unsupported shape/configuration (set CK_TILE_LOGGING=1 for details)."); } - const auto kernel = ck_tile::make_kernel( - Kernel{}, Kernel::GridSize(kargs), Kernel::BlockSize(), 0, kargs); + constexpr int kBlockPerCu = 1; + + const auto kernel = ck_tile::make_kernel( + Kernel{}, Kernel::GridSize(args.M, args.N, args.k_batch), Kernel::BlockSize(), 0, kargs); // For split-K (k_batch > 1) the kernel's epilogue uses atomic_add into C, so C must be // zeroed before every kernel launch -- not just once before the first warmup iteration. diff --git a/example/ck_tile/42_mx_gemm/run_mx_gemm.inc b/example/ck_tile/42_mx_gemm/run_mx_gemm.inc index d6b24e6758..52e28443d7 100644 --- a/example/ck_tile/42_mx_gemm/run_mx_gemm.inc +++ b/example/ck_tile/42_mx_gemm/run_mx_gemm.inc @@ -16,6 +16,8 @@ auto calculate_rtol_atol(const ck_tile::index_t K, const float max_accumulated_v template scale_a_host( + {static_cast(M), static_cast(scale_k_size)}, + {static_cast(scale_k_size), static_cast(1)}); + + // scale_b uses N as first dimension (col-major like B) + ck_tile::HostTensor scale_b_host( + {static_cast(N), static_cast(scale_k_size)}, + {static_cast(scale_k_size), static_cast(1)}); - ck_tile::HostTensor scale_a_host( - ck_tile::host_tensor_descriptor(M, scale_k_size, stride_scale_a, is_row_major(ALayout{}))); - ck_tile::HostTensor scale_b_host( - ck_tile::host_tensor_descriptor(scale_k_size, N, stride_scale_b, is_row_major(BLayout{}))); std::mt19937 gen(42); std::uniform_int_distribution fill_seed(0, 500); - auto gen_scales = [&](auto& scales, float range_min, float range_max) { - // e8m0_t is basically an exponent of float32 - ck_tile::HostTensor pow2(scales.get_lengths()); - ck_tile::FillUniformDistributionIntegerValue{range_min, range_max, fill_seed(gen)}( - pow2); - scales.ForEach([&](auto& self, const auto& i) { - self(i) = static_cast(std::exp2(pow2(i))); - }); - }; switch(init_method) { case 0: // Initialize A, B, and scales to random values ck_tile::FillUniformDistribution{-2.f, 2.f, fill_seed(gen)}(a_host); ck_tile::FillUniformDistribution{-2.f, 2.f, fill_seed(gen)}(b_host); - gen_scales(scale_a_host, -2, 2); - gen_scales(scale_b_host, -2, 2); + ck_tile::FillUniformScaleDistribution{0.125f, 2.0f, fill_seed(gen)}( + scale_a_host); + ck_tile::FillUniformScaleDistribution{0.125f, 2.0f, fill_seed(gen)}( + scale_b_host); break; case 1: // Initialize A, B, and scales to 1.0 ck_tile::FillConstant{ADataType(1.f)}(a_host); ck_tile::FillConstant{BDataType(1.f)}(b_host); - gen_scales(scale_a_host, 0, 0); - gen_scales(scale_b_host, 0, 0); + ck_tile::FillUniformScaleDistribution{1.0f, 1.0f, fill_seed(gen)}( + scale_a_host); + ck_tile::FillUniformScaleDistribution{1.0f, 1.0f, fill_seed(gen)}( + scale_b_host); break; case 2: // Initialize A and B with random values but with constant 1.0 scales ck_tile::FillUniformDistribution{-2.f, 2.f, fill_seed(gen)}(a_host); ck_tile::FillUniformDistribution{-2.f, 2.f, fill_seed(gen)}(b_host); - gen_scales(scale_a_host, 0, 0); - gen_scales(scale_b_host, 0, 0); + ck_tile::FillUniformScaleDistribution{1.0f, 1.0f, fill_seed(gen)}( + scale_a_host); + ck_tile::FillUniformScaleDistribution{1.0f, 1.0f, fill_seed(gen)}( + scale_b_host); break; } @@ -128,12 +125,31 @@ int run_mx_gemm_with_layouts(int argc, char* argv[], ALayout, BLayout, CLayout) constexpr ck_tile::index_t XdlMNThread = GemmConfig::M_Warp_Tile; constexpr ck_tile::index_t XdlKThread = 64 / XdlMNThread; - auto scale_a_packed = - ck_tile::packScalesMNxK(scale_a_host, - true); - auto scale_b_packed = - ck_tile::packScalesMNxK(scale_b_host, - false); + ck_tile::HostTensor scale_a_shuffled( + {static_cast(M / MXdlPackEff * 2), + static_cast(scale_k_size / KXdlPackEff * 2)}, + {static_cast(scale_k_size / KXdlPackEff * 2), static_cast(1)}); + + ck_tile::HostTensor scale_b_shuffled( + {static_cast(N / NXdlPackEff * 2), + static_cast(scale_k_size / KXdlPackEff * 2)}, + {static_cast(scale_k_size / KXdlPackEff * 2), static_cast(1)}); + + ck_tile::preShuffleScaleBuffer_gfx950( + scale_a_host.mData.data(), scale_a_shuffled.mData.data(), M, scale_k_size, true); + + if constexpr(GemmConfig::Preshuffle && GemmConfig::TiledMMAPermuteN) + { + ck_tile::preShuffleScaleBufferPermuteN_gfx950( + scale_b_host.mData.data(), scale_b_shuffled.mData.data(), N, scale_k_size, true); + } + else + { + ck_tile::preShuffleScaleBuffer_gfx950( + scale_b_host.mData.data(), scale_b_shuffled.mData.data(), N, scale_k_size, true); + } const auto b_host_for_device = [&]() { if constexpr(GemmConfig::Preshuffle) @@ -145,73 +161,29 @@ int run_mx_gemm_with_layouts(int argc, char* argv[], ALayout, BLayout, CLayout) return b_host; }(); - const auto scale_a_host_for_device = [&]() { - if constexpr(GemmConfig::Preshuffle) - return ck_tile::preShuffleScale(scale_a_host, true); - else - return scale_a_packed; - }(); - - constexpr ck_tile::index_t XdlNThread = GemmConfig::N_Warp_Tile; - constexpr ck_tile::index_t NPerBlock = GemmConfig::N_Tile; - constexpr ck_tile::index_t NWarp = GemmConfig::N_Warp; - - const auto scale_b_host_for_device = [&]() { - if constexpr(GemmConfig::Preshuffle) - { - if constexpr(GemmConfig::TiledMMAPermuteN) - return ck_tile::preShuffleScalePermuteN(scale_b_host, - false); - else - return ck_tile::preShuffleScale(scale_b_host, false); - } - else - return scale_b_packed; - }(); - - const auto scale_a_device_bytes = [&]() { - if constexpr(GemmConfig::Preshuffle) - return scale_a_host_for_device.get_element_space_size_in_bytes(); - else - return scale_a_host_for_device.size() * sizeof(int32_t); - }(); - - const auto scale_b_device_bytes = [&]() { - if constexpr(GemmConfig::Preshuffle) - return scale_b_host_for_device.get_element_space_size_in_bytes(); - else - return scale_b_host_for_device.size() * sizeof(int32_t); - }(); - // Device buffers for A, B, C, and packed scale tensors ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_dev_buf(b_host_for_device.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_dev_buf(c_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem scale_a_dev_buf(scale_a_device_bytes); - ck_tile::DeviceMem scale_b_dev_buf(scale_b_device_bytes); + ck_tile::DeviceMem scale_a_dev_buf(scale_a_shuffled.get_element_space_size_in_bytes()); + ck_tile::DeviceMem scale_b_dev_buf(scale_b_shuffled.get_element_space_size_in_bytes()); a_dev_buf.ToDevice(a_host.data()); b_dev_buf.ToDevice(b_host_for_device.data()); c_dev_buf.SetZero(); - scale_a_dev_buf.ToDevice(scale_a_host_for_device.data()); - scale_b_dev_buf.ToDevice(scale_b_host_for_device.data()); - - // Scale pointers - point to packed int32_t data, kernel reinterprets as int32_t* - using ScaleM = ck_tile::MXScalePointer; - using ScaleN = ck_tile::MXScalePointer; - ScaleM scale_m(reinterpret_cast(scale_a_dev_buf.GetDeviceBuffer())); - ScaleN scale_n(reinterpret_cast(scale_b_dev_buf.GetDeviceBuffer())); + scale_a_dev_buf.ToDevice(scale_a_shuffled.data()); + scale_b_dev_buf.ToDevice(scale_b_shuffled.data()); float ave_time = invoke_mx_gemm(a_dev_buf, b_dev_buf, c_dev_buf, @@ -222,8 +194,8 @@ int run_mx_gemm_with_layouts(int argc, char* argv[], ALayout, BLayout, CLayout) stride_B, stride_C, kbatch, - scale_m, - scale_n, + scale_a_dev_buf, + scale_b_dev_buf, n_warmup, n_repeat); @@ -240,9 +212,23 @@ int run_mx_gemm_with_layouts(int argc, char* argv[], ALayout, BLayout, CLayout) ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); c_m_n_host_ref.SetZero(); - ck_tile:: - reference_mx_gemm( - a_host, b_host, c_m_n_host_ref, scale_a_host, scale_b_host); + // Host reference computation using reference_mx_gemm + // reference_mx_gemm expects scale_a(M, K/ScaleBlockSize) and scale_b(K/ScaleBlockSize, N) + // We need to create scale_b in (K/ScaleBlockSize, N) format for the reference + ck_tile::HostTensor scale_b_ref( + {static_cast(scale_k_size), static_cast(N)}, + {static_cast(1), static_cast(scale_k_size)}); + // Copy scale_b data (our scale_b is (N, scale_k_size) row-major, + // reference expects (scale_k_size, N) col-major, which is the same memory layout) + std::copy(scale_b_host.mData.begin(), scale_b_host.mData.end(), scale_b_ref.mData.begin()); + + ck_tile::reference_mx_gemm( + a_host, b_host, c_m_n_host_ref, scale_a_host, scale_b_ref); const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); @@ -285,6 +271,8 @@ int run_mx_gemm_example(int argc, char* argv[]) { return run_mx_gemm_with_layouts(argc, argv, Row{}, Col{}, Row{}); @@ -293,6 +281,8 @@ int run_mx_gemm_example(int argc, char* argv[]) { return run_mx_gemm_with_layouts(argc, argv, Row{}, Col{}, Row{}); @@ -304,6 +294,8 @@ int run_mx_gemm_example(int argc, char* argv[]) { return run_mx_gemm_with_layouts(argc, argv, Row{}, Col{}, Row{}); @@ -312,6 +304,8 @@ int run_mx_gemm_example(int argc, char* argv[]) { return run_mx_gemm_with_layouts(argc, argv, Row{}, Col{}, Row{}); diff --git a/include/ck_tile/core/tensor/null_tensor.hpp b/include/ck_tile/core/tensor/null_tensor.hpp index f5cabbef5a..fbbef3ab51 100644 --- a/include/ck_tile/core/tensor/null_tensor.hpp +++ b/include/ck_tile/core/tensor/null_tensor.hpp @@ -3,10 +3,28 @@ #pragma once +#include "ck_tile/core/utility/type_traits.hpp" + namespace ck_tile { struct null_tensor { }; +// utility to check if this is a Null Tensor +namespace impl { +template +struct is_null_tensor : public std::false_type +{ +}; + +template <> +struct is_null_tensor : public std::true_type +{ +}; +} // namespace impl + +template +constexpr bool is_null_tensor_v = impl::is_null_tensor>::value; + } // namespace ck_tile diff --git a/include/ck_tile/host/mx_processing.hpp b/include/ck_tile/host/mx_processing.hpp index c15df6f509..e5eaf70fb8 100644 --- a/include/ck_tile/host/mx_processing.hpp +++ b/include/ck_tile/host/mx_processing.hpp @@ -7,148 +7,163 @@ namespace ck_tile { +/// @brief Pre-shuffle scale buffer for gfx1250 wmma mx scale instruction. +/// +/// Reorganizes the scale data from row-major (MN x K) layout to the hardware-specific +/// layout expected by the gfx1250 wmma instruction. +/// +/// @tparam ScaleType Scale data type (e.g., e8m0_t) +/// @tparam ScaleBlockSize The block size for microscaling (e.g., 32) +/// @tparam KStride Whether K is the fast-moving dimension +template +void preShuffleScaleBuffer_gfx1250(const ScaleType* src, + ScaleType* dst, + ck_tile::index_t MN, + ck_tile::index_t K) +{ + static_assert((ScaleBlockSize == 32 || ScaleBlockSize == 16) && sizeof(ScaleType) == 1, + "wrong! only support 8-bit scale with ScaleBlockSize=32 or 16"); + + // ScaleBlockSize == 16: the natural row-major scale layout already matches the gfx1250 + // wmma scale distribution (one e8m0 per 16 K-elements lands warp-aligned), so the + // device-side shuffle is the identity transform for all K. + if constexpr(ScaleBlockSize == 16) + { + for(ck_tile::long_index_t mn = 0; mn < MN; ++mn) + for(ck_tile::long_index_t k = 0; k < K; ++k) + { + if constexpr(KStride) + dst[mn * K + k] = src[mn * K + k]; + else + dst[mn * K + k] = src[k * MN + mn]; + } + return; + } + + constexpr ck_tile::long_index_t MPerXdlops = 16; + constexpr ck_tile::long_index_t KPerXdlops = 128; + + ck_tile::long_index_t MNPack = 2; + ck_tile::long_index_t KPack = 1; + + ck_tile::long_index_t MNStep = MPerXdlops; + ck_tile::long_index_t KStep = KPerXdlops / ScaleBlockSize; + + ck_tile::long_index_t K0 = K / KPack / KStep; + + for(ck_tile::long_index_t mn = 0; mn < MN; ++mn) + { + ck_tile::long_index_t iMNRepeat = mn / (MNStep * MNPack); + ck_tile::long_index_t tempmn = mn % (MNStep * MNPack); + + for(ck_tile::long_index_t k = 0; k < K; ++k) + { + ck_tile::long_index_t iKRepeat = k / (KStep * KPack); + ck_tile::long_index_t tempk = k % (KStep * KPack); + + ck_tile::long_index_t outputIndex = + (iMNRepeat * MNPack * MNStep) * (KStep * KPack * K0) + + (iKRepeat * KStep * KPack) * (MNStep * MNPack) + tempmn * (KStep * KPack) + tempk; + + if constexpr(KStride) + { + dst[outputIndex] = src[mn * K + k]; + } + else + dst[outputIndex] = src[k * MN + mn]; + } + } +} + // Pack [MN, K/32] e8m0_t scales into [MN/MNPack, K/32/KPack] int32_t // Each int32_t contains MNPack * KPack e8m0_t values with byte layout matching // the GPU tile distribution: values are XdlMNThread apart in M and XdlKThread apart in K. // byte[ik * MNPack + imn] = e8m0 at strided (mn, k) position // kLast=true for A scales (layout [M, K/32]), kLast=false for B scales (layout [K/32, N]) -template -auto packScalesMNxK(const HostTensor& src, const bool kLast) +template +void preShuffleScaleBuffer_gfx950(const ScaleType* src, + ScaleType* packed, + ck_tile::index_t MN, + ck_tile::index_t K_scale, + bool kLast) { - auto src_lengths = src.get_lengths(); - const index_t MN = kLast ? src_lengths[0] : src_lengths[1]; - const index_t K_scale = kLast ? src_lengths[1] : src_lengths[0]; - const index_t MN_packed = MN / MNPack; - const index_t K_packed = K_scale / KPack; + const ck_tile::long_index_t MN_packed = MN / MNPack; + const ck_tile::long_index_t K_packed = K_scale / KPack; + constexpr ck_tile::long_index_t NumScalesPerDword = 4 / sizeof(ScaleType); - // Output as flat vector of int32_t (row-major [MN/MNPack, K/32/KPack]) - HostTensor packed(HostTensorDescriptor( - {static_cast(MN_packed), static_cast(K_packed)}, - {static_cast(K_packed), static_cast(1)})); - - for(index_t packed_mn = 0; packed_mn < MN_packed; packed_mn++) + for(ck_tile::long_index_t packed_mn = 0; packed_mn < MN_packed; packed_mn++) { - for(index_t packed_k = 0; packed_k < K_packed; packed_k++) + for(ck_tile::long_index_t packed_k = 0; packed_k < K_packed; packed_k++) { - uint32_t val = 0; - index_t mn_lane = packed_mn % XdlMNThread; - index_t mn_group = packed_mn / XdlMNThread; - index_t k_lane = packed_k % XdlKThread; - index_t k_group = packed_k / XdlKThread; - for(index_t ik = 0; ik < KPack; ik++) + ck_tile::long_index_t mn_lane = packed_mn % XdlMNThread; + ck_tile::long_index_t mn_group = packed_mn / XdlMNThread; + ck_tile::long_index_t k_lane = packed_k % XdlKThread; + ck_tile::long_index_t k_group = packed_k / XdlKThread; + for(ck_tile::long_index_t ik = 0; ik < KPack; ik++) { - for(index_t imn = 0; imn < MNPack; imn++) + for(ck_tile::long_index_t imn = 0; imn < MNPack; imn++) { - index_t byteIdx = ik * MNPack + imn; - index_t orig_mn = mn_group * XdlMNThread * MNPack + imn * XdlMNThread + mn_lane; - index_t orig_k = k_group * XdlKThread * KPack + ik * XdlKThread + k_lane; + ck_tile::long_index_t byteIdx = ik * MNPack + imn; + ck_tile::long_index_t orig_mn = + mn_group * XdlMNThread * MNPack + imn * XdlMNThread + mn_lane; + ck_tile::long_index_t orig_k = + k_group * XdlKThread * KPack + ik * XdlKThread + k_lane; - e8m0_t v = kLast ? src(orig_mn, orig_k) : src(orig_k, orig_mn); - val |= (static_cast(v.get()) << (byteIdx * 8)); + ck_tile::long_index_t inputIndex = + kLast ? orig_k + orig_mn * K_scale : orig_mn + orig_k * MN; + ScaleType v = src[inputIndex]; + ck_tile::long_index_t outputIndex = + byteIdx + (packed_mn % XdlMNThread) * NumScalesPerDword + + packed_k * XdlMNThread * NumScalesPerDword + + (packed_mn / XdlMNThread) * XdlMNThread * NumScalesPerDword * K_packed; + packed[outputIndex] = v; } } - packed(packed_mn, packed_k) = static_cast(val); } } - return packed; } -template -auto preShuffleScale(ck_tile::HostTensor& src, const bool kLast) +template +auto preShuffleScaleBufferPermuteN_gfx950( + const ScaleType* src, ScaleType* shuffled, ck_tile::index_t MN, ck_tile::index_t K, bool kLast) { - auto src_lengths = src.get_lengths(); - const index_t MN = kLast ? src_lengths[0] : src_lengths[1]; - const index_t K = kLast ? src_lengths[1] : src_lengths[0]; - - constexpr index_t MNXdlPack = 2; - constexpr index_t KXdlPack = 2; - constexpr index_t XdlKThread = get_warp_size() / XdlMNThread; - - const auto MNPadded = integer_least_multiple(MN, XdlMNThread * MNXdlPack); - HostTensor shuffled(HostTensorDescriptor({static_cast(MNPadded * K)}, - {static_cast(1)})); + constexpr ck_tile::long_index_t MNXdlPack = 2; + constexpr ck_tile::long_index_t KXdlPack = 2; + constexpr ck_tile::long_index_t NRepeat = NPerBlock / NWarp / XdlMNThread; + constexpr ck_tile::long_index_t XdlKThread = ck_tile::get_warp_size() / XdlMNThread; if(K % (KXdlPack * XdlKThread) != 0) { throw std::runtime_error("wrong! K must be a multiple of (KXdlPack * XdlKThread)"); } + const ck_tile::long_index_t K0 = K / KXdlPack / XdlKThread; - const index_t K0 = K / KXdlPack / XdlKThread; - - for(index_t n = 0; n < MNPadded; ++n) + for(ck_tile::long_index_t n = 0; n < MN; ++n) { - for(index_t k = 0; k < K; ++k) + for(ck_tile::long_index_t k = 0; k < K; ++k) { - const index_t n0 = n / (XdlMNThread * MNXdlPack); - const index_t tempn = n % (XdlMNThread * MNXdlPack); - const index_t n1 = tempn % XdlMNThread; - const index_t n2 = tempn / XdlMNThread; + const ck_tile::long_index_t n0 = n / NPerBlock; + const ck_tile::long_index_t tempn0 = n % NPerBlock; + const ck_tile::long_index_t n1 = tempn0 / (XdlMNThread * NRepeat); + const ck_tile::long_index_t tempn1 = tempn0 % (XdlMNThread * NRepeat); + const ck_tile::long_index_t n2 = tempn1 / (NRepeat); + const ck_tile::long_index_t tempn2 = tempn1 % (NRepeat); + const ck_tile::long_index_t n3 = tempn2 % MNXdlPack; + const ck_tile::long_index_t n4 = tempn2 / MNXdlPack; - const index_t k0 = k / (XdlKThread * KXdlPack); - const index_t tempk = k % (XdlKThread * KXdlPack); - const index_t k1 = tempk % XdlKThread; - const index_t k2 = tempk / XdlKThread; + const ck_tile::long_index_t k0 = k / (XdlKThread * KXdlPack); + const ck_tile::long_index_t tempk = k % (XdlKThread * KXdlPack); + const ck_tile::long_index_t k1 = tempk % XdlKThread; + const ck_tile::long_index_t k2 = tempk / XdlKThread; - const index_t outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + - k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + - k1 * MNXdlPack * KXdlPack * XdlMNThread + - n1 * MNXdlPack * KXdlPack + k2 * MNXdlPack + n2; - - if(n < MN) - { - shuffled(outputIndex) = kLast ? src(n, k) : src(k, n); - } - else - { - shuffled(outputIndex) = dtype{}; - } - } - } - - return shuffled; -} - -template -auto preShuffleScalePermuteN(const HostTensor& src, const bool kLast) -{ - auto src_lengths = src.get_lengths(); - const index_t MN = kLast ? src_lengths[0] : src_lengths[1]; - const index_t K = kLast ? src_lengths[1] : src_lengths[0]; - - constexpr index_t MNXdlPack = 2; - constexpr index_t KXdlPack = 2; - constexpr index_t NRepeat = NPerBlock / NWarp / XdlMNThread; - constexpr index_t XdlKThread = get_warp_size() / XdlMNThread; // 4 - - const index_t MNPadded = integer_least_multiple(MN, NPerBlock); - HostTensor shuffled(HostTensorDescriptor({static_cast(MNPadded * K)}, - {static_cast(1)})); - - if(K % (KXdlPack * XdlKThread) != 0) - { - throw std::runtime_error("wrong! K must be a multiple of (KXdlPack * XdlKThread)"); - } - const index_t K0 = K / KXdlPack / XdlKThread; - - for(index_t n = 0; n < MNPadded; ++n) - { - for(index_t k = 0; k < K; ++k) - { - const index_t n0 = n / NPerBlock; - const index_t tempn0 = n % NPerBlock; - const index_t n1 = tempn0 / (XdlMNThread * NRepeat); - const index_t tempn1 = tempn0 % (XdlMNThread * NRepeat); - const index_t n2 = tempn1 / (NRepeat); - const index_t tempn2 = tempn1 % (NRepeat); - const index_t n3 = tempn2 % MNXdlPack; - const index_t n4 = tempn2 / MNXdlPack; - - const index_t k0 = k / (XdlKThread * KXdlPack); - const index_t tempk = k % (XdlKThread * KXdlPack); - const index_t k1 = tempk % XdlKThread; - const index_t k2 = tempk / XdlKThread; - - const index_t outputIndex = + const ck_tile::long_index_t outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 * NWarp * (NRepeat / MNXdlPack) + n1 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + @@ -156,13 +171,15 @@ auto preShuffleScalePermuteN(const HostTensor& src, const bool kLast) k1 * MNXdlPack * KXdlPack * XdlMNThread + k2 * MNXdlPack + n4 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 * NWarp + n3; + ck_tile::long_index_t inputIndex = kLast ? k + n * K : n + k * MN; + if(n < MN) { - shuffled(outputIndex) = kLast ? src(n, k) : src(k, n); + shuffled[outputIndex] = src[inputIndex]; } else { - shuffled(outputIndex) = dtype{}; + shuffled[outputIndex] = ScaleType{}; } } } diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 81987876ee..285e0f2d99 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -1184,4 +1184,131 @@ void reference_batched_gemm_gpu(ADataType* a_ptr, return; } +// GPU reference for MX (microscaling) GEMM with e8m0 block scales. +// +// This is the device counterpart of the host `reference_mx_gemm` above. It exists so the +// reference can be computed entirely on the GPU for large problems (e.g. M*N ~ 1e9) where the +// host reference is intractable and where copying the 39 GB of inputs back to host is not +// feasible. It is a faithful mirror of the host semantics: +// - per-element dot product over K, with each A/B element dequantized by its e8m0 block scale. +// - all addressing uses `long`; the existing `naive_gemm_kernel`/`blockwise_gemm_kernel` use +// `int` and silently overflow once M*N exceeds INT_MAX. +// +// Layout assumptions match the fp4 CompAsync grouped-GEMM test row (RowMajor A, ColumnMajor B, +// RowMajor C): +// - A is RowMajor: element (m,k) at linear offset m*K + k. +// - B is ColumnMajor: element (k,n) at linear offset n*K + k (K is the fast dimension). +// - C is RowMajor: element (m,n) at linear offset m*N + n. +// - scale_a is (M, num_scale_k) RowMajor: scale for (m, k) at m*num_scale_k + k/scale_block_size. +// - scale_b is (N, num_scale_k) RowMajor: scale for (k, n) at n*num_scale_k + k/scale_block_size. +template +__global__ void reference_mx_gemm_kernel(const ADataType* __restrict__ a_ptr, + const BDataType* __restrict__ b_ptr, + const AScaleDataType* __restrict__ scale_a_ptr, + const BScaleDataType* __restrict__ scale_b_ptr, + CDataType* __restrict__ c_ptr, + long M, + long N, + long K, + long num_scale_k, + long scale_block_size) +{ + const long total = M * N; + const long idx0 = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const long nthr = static_cast(gridDim.x) * blockDim.x; + + for(long idx = idx0; idx < total; idx += nthr) + { + const long m = idx / N; + const long n = idx % N; + + AccDataType acc = 0; + for(long k = 0; k < K; ++k) + { + // --- A element (RowMajor) --- + AccDataType a_val; + const long a_lin = m * K + k; + if constexpr(std::is_same_v) + { + const fp32x2_t a_f2 = pk_fp4_to_fp32x2(a_ptr[a_lin / 2], 1.0f); + a_val = type_convert((a_lin % 2 == 0) ? a_f2.lo : a_f2.hi); + } + else + { + a_val = type_convert(a_ptr[a_lin]); + } + const float a_sc = + type_convert(scale_a_ptr[m * num_scale_k + k / scale_block_size]); + + // --- B element (ColumnMajor, K fast) --- + AccDataType b_val; + const long b_lin = n * K + k; + if constexpr(std::is_same_v) + { + const fp32x2_t b_f2 = pk_fp4_to_fp32x2(b_ptr[b_lin / 2], 1.0f); + b_val = type_convert((b_lin % 2 == 0) ? b_f2.lo : b_f2.hi); + } + else + { + b_val = type_convert(b_ptr[b_lin]); + } + const float b_sc = + type_convert(scale_b_ptr[n * num_scale_k + k / scale_block_size]); + + acc += (a_val * type_convert(a_sc)) * + (b_val * type_convert(b_sc)); + } + c_ptr[m * N + n] = type_convert(acc); + } +} + +template +void reference_mx_gemm_gpu(const ADataType* a_ptr, + const BDataType* b_ptr, + const AScaleDataType* scale_a_ptr, + const BScaleDataType* scale_b_ptr, + CDataType* c_ptr, + index_t M, + index_t N, + index_t K, + index_t num_scale_k, + index_t scale_block_size, + hipStream_t stream = nullptr) +{ + const long total = static_cast(M) * N; + constexpr int threads = 256; + constexpr long max_blocks = 2097152; // grid-stride cap (~2M blocks) + const long needed = (total + threads - 1) / threads; + const long blocks = needed < max_blocks ? needed : max_blocks; + + reference_mx_gemm_kernel + <<(blocks)), dim3(threads), 0, stream>>>( + a_ptr, + b_ptr, + scale_a_ptr, + scale_b_ptr, + c_ptr, + static_cast(M), + static_cast(N), + static_cast(K), + static_cast(num_scale_k), + static_cast(scale_block_size)); + hip_check_error(hipGetLastError()); +} + } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 6779da556e..b93163c202 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -30,6 +30,7 @@ #include "ck_tile/ops/gemm/block/block_gemm_mx_areg_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_mx_areg_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" +#include "ck_tile/ops/gemm/block/block_mx_asmem_breg_creg.hpp" #include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp" #include "ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp" #include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp" @@ -78,6 +79,8 @@ #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_mx_pipeline_agmem_bgmem_creg_v1.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_mx_pipeline_agmem_bgmem_creg_v1_policy.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_tdm.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_tdm_policy.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_eight_waves_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_eight_waves_v1.hpp index 9f91c06e8e..76b5eede5c 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_eight_waves_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_eight_waves_v1.hpp @@ -35,6 +35,8 @@ struct BlockGemmARegBRegCRegEightWavesV1 static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; + static constexpr auto PackMNIter = Policy::PackMNIter; + static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); using WarpGemm = remove_cvref_t())>; @@ -95,6 +97,8 @@ struct BlockGemmARegBRegCRegEightWavesV1 static constexpr auto Scheduler = Traits::Scheduler; static constexpr bool TransposeC = Traits::TransposeC; + static constexpr bool PackMNIter = Traits::PackMNIter; + using AWarpDstr = typename WarpGemm::AWarpDstr; using BWarpDstr = typename WarpGemm::BWarpDstr; using CWarpDstr = typename WarpGemm::CWarpDstr; @@ -136,17 +140,34 @@ struct BlockGemmARegBRegCRegEightWavesV1 sequence, sequence>; - constexpr auto a_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, KIterSeq>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 1>>{}; - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + if constexpr(PackMNIter) + { + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, KIterSeq>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<1, 1>>{}; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); - return a_block_dstr_encode; + return a_block_dstr_encode; + } + else + { + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, KIterSeq>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 1>>{}; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + return a_block_dstr_encode; + } } CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() @@ -161,32 +182,66 @@ struct BlockGemmARegBRegCRegEightWavesV1 sequence, sequence>; - constexpr auto b_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, KIterSeq>, - tuple>, - tuple>, - sequence<>, - sequence<>>{}; + if constexpr(PackMNIter) + { + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, KIterSeq>, + tuple>, + tuple>, + sequence<>, + sequence<>>{}; - constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - return b_block_dstr_encode; + return b_block_dstr_encode; + } + else + { + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, KIterSeq>, + tuple>, + tuple>, + sequence<>, + sequence<>>{}; + + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + return b_block_dstr_encode; + } } CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode() { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence<2, NIterPerWarp, NWarp / 2>>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 1>>{}; - constexpr auto c_block_dstr_encoding = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - return c_block_dstr_encoding; + if constexpr(PackMNIter) + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence<2, NIterPerWarp, NWarp / 2>>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<1, 1>>{}; + constexpr auto c_block_dstr_encoding = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + return c_block_dstr_encoding; + } + else + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence<2, NIterPerWarp, NWarp / 2>>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 1>>{}; + constexpr auto c_block_dstr_encoding = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + return c_block_dstr_encoding; + } } CK_TILE_DEVICE static constexpr auto MakeCBlockTile() @@ -252,6 +307,101 @@ struct BlockGemmARegBRegCRegEightWavesV1 }); } + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ALdsTile& a_warp_tile_, + const BLdsTiles& b_warp_tiles_, + const ScaleATensor& scale_a_tensor, + const ScaleBTensor& scale_b_tensor) const + { + // checks + static_assert(std::is_same_v>, + "CDataType must be same as CBlockTensor::DataType!"); + static_assert( + std::is_same_v, + remove_cvref_t>, + "C distribution is wrong!"); + + // Effective XdlPack: fall back to 1 when iteration count is insufficient + constexpr index_t MXdlPack = + (MIterPerWarp >= MXdlPack_ && MIterPerWarp % MXdlPack_ == 0) ? MXdlPack_ : 1; + constexpr index_t NXdlPack = + (NIterPerWarp >= NXdlPack_ && NIterPerWarp % NXdlPack_ == 0) ? NXdlPack_ : 1; + constexpr index_t KXdlPack = + (KIterPerWarp >= KXdlPack_ && KIterPerWarp % KXdlPack_ == 0) ? KXdlPack_ : 1; + + constexpr index_t MPackIterPerWarp = MIterPerWarp / MXdlPack; + constexpr index_t NPackIterPerWarp = NIterPerWarp / NXdlPack; + constexpr index_t KPackIterPerWarp = KIterPerWarp / KXdlPack; + + // hot loop: + static_for_product, + number, + number>{}([&](auto ikpack, auto inpack, auto impack) { + // get A scale for this M-K tile using get_y_sliced_thread_data + auto scale_a_slice = scale_a_tensor.get_y_sliced_thread_data( + sequence{}, sequence<1, 1, 1>{}); + const int32_t a_scale_packed = bit_cast(scale_a_slice[number<0>{}]); + + // get B scale for this N-K tile using get_y_sliced_thread_data + auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data( + sequence{}, sequence<1, 1, 1>{}); + const int32_t b_scale_packed = bit_cast(scale_b_slice[number<0>{}]); + + // Inner loops: issue MFMAs within the pack group using OpSel + static_for_product, number, number>{}( + [&](auto ikxdl, auto inxdl, auto imxdl) { + constexpr auto kIter = ikpack * KXdlPack + ikxdl; + constexpr auto mIter = impack * MXdlPack + imxdl; + constexpr auto nIter = inpack * NXdlPack + inxdl; + + // OpSel for A: selects byte within packed int32_t + constexpr index_t kOpSelA = ikxdl * MXdlPack + imxdl; + + // OpSel for B: selects byte within packed int32_t + constexpr index_t kOpSelB = ikxdl * NXdlPack + inxdl; + + // read A warp tensor from A Block window + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = + b_warp_tiles_[number{}][number{}].get_thread_buffer(); + + // read C warp tensor from C block tensor + using c_iter_idx = sequence; + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM with MX scaling + WarpGemm{}.template operator(), OpSelB>( + c_warp_tensor, + a_warp_tensor, + b_warp_tensor, + a_scale_packed, + b_scale_packed); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + } + template CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, const ALdsTile& a_warp_tile_, diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp index 51fd49f88f..0465dcb2af 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp @@ -39,6 +39,8 @@ struct BlockGemmARegBRegCRegV1 static constexpr auto KSubTileNum = Policy::KSubTileNum; + static constexpr auto PackMNIter = Policy::PackMNIter; + static constexpr index_t MWarp = config.template at<1>(); static constexpr index_t NWarp = config.template at<2>(); static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); @@ -69,6 +71,8 @@ struct BlockGemmARegBRegCRegV1 static constexpr index_t KSubTileNum = Traits::KSubTileNum; + static constexpr bool PackMNIter = Traits::PackMNIter; + static constexpr index_t KPerSubTile = KIterPerWarp / KSubTileNum; static constexpr index_t MWarp = Traits::MWarp; @@ -94,17 +98,34 @@ struct BlockGemmARegBRegCRegV1 } else { - constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + if constexpr(PackMNIter) + { + constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<1, 0>>{}; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); - return a_block_dstr_encode; + return a_block_dstr_encode; + } + else + { + constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + return a_block_dstr_encode; + } } } @@ -126,50 +147,103 @@ struct BlockGemmARegBRegCRegV1 } else { - constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + if constexpr(PackMNIter) + { + constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<1, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - return b_block_dstr_encode; + return b_block_dstr_encode; + } + else + { + constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + return b_block_dstr_encode; + } } } CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode() { using c_distr_ys_major = std::conditional_t, sequence<1, 2>>; - if constexpr(UseDefaultScheduler) + if constexpr(PackMNIter) { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple<>, - tuple<>, - c_distr_ys_major, - sequence<0, 0>>{}; - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + if constexpr(UseDefaultScheduler) + { + using c_distr_ys_minor = + std::conditional_t, sequence<0, 1>>; + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + c_distr_ys_major, + c_distr_ys_minor>{}; + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - return c_block_dstr_encode; + return c_block_dstr_encode; + } + else + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + c_distr_ys_major, + sequence<1, 1>>{}; + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + + return c_block_dstr_encode; + } } else { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - c_distr_ys_major, - sequence<0, 0>>{}; - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + if constexpr(UseDefaultScheduler) + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple<>, + tuple<>, + c_distr_ys_major, + sequence<0, 0>>{}; + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - return c_block_dstr_encode; + return c_block_dstr_encode; + } + else + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + c_distr_ys_major, + sequence<0, 0>>{}; + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + + return c_block_dstr_encode; + } } } @@ -411,9 +485,12 @@ struct BlockGemmARegBRegCRegV1 typename BBlockTensor, typename ScaleATensor, typename ScaleBTensor, - index_t MXdlPack_ = 2, - index_t NXdlPack_ = 2, - index_t KXdlPack_ = 2> + index_t MXdlPack_ = 2, + index_t NXdlPack_ = 2, + index_t KXdlPack_ = 2, + typename std::enable_if_t && + !is_null_tensor_v, + bool>* = nullptr> CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, const ABlockTensor& a_block_tensor, const BBlockTensor& b_block_tensor, @@ -423,7 +500,7 @@ struct BlockGemmARegBRegCRegV1 static_assert(std::is_same_v> && std::is_same_v> && std::is_same_v>, - "wrong!"); + "Datatypes do not match BlockTensor datatypes!"); // check ABC-block-distribution static_assert( @@ -545,41 +622,27 @@ struct BlockGemmARegBRegCRegV1 }); } + template < + typename CBlockTensor, + typename ABlockTensor, + typename BBlockTensor, + typename ScaleATensor, + typename ScaleBTensor, + typename std::enable_if_t && is_null_tensor_v, + bool>* = nullptr> + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockTensor& a_block_tensor, + const BBlockTensor& b_block_tensor, + const ScaleATensor&, + const ScaleBTensor&) const + { + operator()(c_block_tensor, a_block_tensor, b_block_tensor); + } + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { - using c_distr_ys_major = std::conditional_t, sequence<1, 2>>; - if constexpr(UseDefaultScheduler) - { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple<>, - tuple<>, - c_distr_ys_major, - sequence<0, 0>>{}; - - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); - auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); - return c_block_tensor; - } - else - { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - c_distr_ys_major, - sequence<0, 0>>{}; - - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); - auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); - return c_block_tensor; - } + return make_static_distributed_tensor( + make_static_tile_distribution(MakeCBlockDistributionEncode())); } // C = A * B diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp index 2a47d3b54d..a52c27e5a1 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp @@ -12,8 +12,8 @@ template // this variable is used for split K into multiple subtiles in - // order to reduce register usage per wave> + index_t KSubTileNum_ = 1, // this variable is used for split K into multiple subtiles in + bool PackMNIter_ = false> // order to reduce register usage per wave> struct BlockGemmARegBRegCRegV1CustomPolicy { using AType = remove_cvref_t; @@ -29,6 +29,7 @@ struct BlockGemmARegBRegCRegV1CustomPolicy using WarpGemm = remove_cvref_t; static constexpr index_t KSubTileNum = KSubTileNum_; + static constexpr bool PackMNIter = PackMNIter_; template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp index 91ace17499..e02bf97b88 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp @@ -26,6 +26,8 @@ struct BlockGemmASmemBSmemCRegV1CustomPolicy static constexpr index_t kNWarps = BlockWarps::at(number<1>{}); static constexpr index_t kKWarps = BlockWarps::at(number<2>{}); + static constexpr bool PackMNIter = false; + using WarpGemm = remove_cvref_t; template diff --git a/include/ck_tile/ops/gemm_mx/block/block_mx_asmem_breg_creg.hpp b/include/ck_tile/ops/gemm/block/block_mx_asmem_breg_creg.hpp similarity index 100% rename from include/ck_tile/ops/gemm_mx/block/block_mx_asmem_breg_creg.hpp rename to include/ck_tile/ops/gemm/block/block_mx_asmem_breg_creg.hpp diff --git a/include/ck_tile/ops/gemm/kernel/mx_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/mx_gemm_kernel.hpp index bfc5673e11..d1b8d84b48 100644 --- a/include/ck_tile/ops/gemm/kernel/mx_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/mx_gemm_kernel.hpp @@ -3,6 +3,8 @@ #pragma once +#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" + namespace ck_tile { template @@ -72,6 +74,7 @@ struct MxGemmKernel using BaseKernel::PersistentKernel; using typename BaseKernel::AsLayout; using typename BaseKernel::BsLayout; + using typename BaseKernel::CLayout; using typename BaseKernel::DsLayout; using typename BaseKernel::ADataType; @@ -91,6 +94,8 @@ struct MxGemmKernel using BaseKernel::APackedSize; using BaseKernel::BPackedSize; + using BaseKernel::I1; + using AElementWise = remove_cvref_t; using BElementWise = remove_cvref_t; @@ -100,12 +105,48 @@ struct MxGemmKernel static constexpr int NThreadPerXdl = BlockGemmShape::WarpTile::at(number<1>{}); static constexpr int BlockScaleSize = MxGemmPipeline::ScaleBlockSize; - static_assert(BlockScaleSize == 16 || BlockScaleSize == 32, "unsupported BlockScaleSize"); - // Scale tensor element type is always int32_t (4 packed e8m0 bytes). - // For scale16, each thread needs 8 bytes = 2 int32_t elements. - // For scale32, each thread needs 4 bytes = 1 int32_t element. - static constexpr int ScalePackSize = 4; - using ScalePtrType = const int32_t*; + using ScalePtrType = const int32_t*; + // Padding flags pulled from pipeline so the kernel can pad the (unscaled) C and scale views + // consistently with the A/B views that the pipeline already pads via + // Underlying::MakeA/BBlockWindows. + static constexpr bool kPadM = MxGemmPipeline::kPadM; + static constexpr bool kPadN = MxGemmPipeline::kPadN; + static constexpr bool kPadK = MxGemmPipeline::kPadK; + + // ------------------------------------------------------------------ + // Compile-time padding-support invariants for the MX comp-async pipeline. + // + // - K padding is NOT supported: async_load_tile issues vector buffer reads whose + // OOB check is per-vector-start, so a vector that straddles the K pad boundary + // pulls in data from the adjacent row / next K tile rather than zero. The packed + // scale tile has the same vector-load property. Until the async path learns how + // to do per-element pad masking, we forbid kPadK at compile time. + // + // - kPadM / kPadN are supported only when the GEMM has at least one full block + // along that dimension; the CShuffleEpilogue's LDS shuffle uses thread positions + // that do not all participate when the entire dimension is smaller than a tile + // (resulting in zeros being written into in-range output rows). The "entire + // dimension < tile" case is rejected at runtime in IsSupportedArgument; we + // cannot statically catch it because M and N are runtime values. + // ------------------------------------------------------------------ + static_assert(!kPadK, + "MX GEMM (comp-async pipeline): K padding (kPadK = true) is not supported. " + "The async vector loads do not mask elements that straddle the K pad " + "boundary, so partial K tiles produce silently wrong results. Choose K so " + "that K is a multiple of KPerBlock * k_batch."); + + // Single source of truth for the split-K atomic-add precondition, shared by the runtime + // check in IsSupportedArgument and the atomic_add dispatch in operator(). Split-K + // accumulates each k_id's partial C tile with atomic_add; the CShuffle epilogue can only + // emit atomic_add for fp16/bf16 outputs when the C vector size is even. For an odd vector + // size that combination is not instantiated, so such a config cannot run split-K. For all + // shipped tile shapes GetVectorSizeC() is even, so this is defensive rather than reachable. + static constexpr bool kSplitKAtomicAddSupported = + EpiloguePipeline::GetVectorSizeC() % 2 == 0 || !is_any_of::value; + + static constexpr index_t MXdlPackEff = MxGemmPipeline::MXdlPackEff; + static constexpr index_t NXdlPackEff = MxGemmPipeline::NXdlPackEff; + static constexpr index_t KXdlPackEff = MxGemmPipeline::KXdlPackEff; using KernelArgs = MxGemmKernelArgs; @@ -131,14 +172,57 @@ struct MxGemmKernel CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs) { - if(kargs.k_batch != 1) + const bool log = ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)); + + if(kargs.k_batch < 1) { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("SplitK (k_batch > 1) is not supported for MX GEMM!"); - } + if(log) + CK_TILE_ERROR("MX GEMM: k_batch must be >= 1."); return false; } + + // Split-K derives this k_id's logical K start from the row-major SplitKBatchOffset + // (as_k_split_offset[0]) to offset the packed-scale / flat-B windows; for column-major A + // that field is stride-scaled, so split-K with non-row-major A is not yet supported. + // (k_batch == 1 is unaffected -- the offset is 0 and unused.) When col-major A lands for + // non-preshuffle, extend the split-K K-offset here instead of this reject. + using ALayout = remove_cvref_t>; + if constexpr(!std::is_same_v) + { + if(kargs.k_batch > 1) + { + if(log) + CK_TILE_ERROR("MX GEMM: split-K (k_batch > 1) currently requires row-major A."); + return false; + } + } + + // Scales are granular in K: each packed int32_t covers BlockScaleSize * KXdlPackEff + // consecutive K elements. Every split-K boundary must land on that granularity so that + // each split can compute a packed-scale K offset. K1 is the WarpTile K, which is a + // multiple of that granularity for all shipped configs, but be defensive. + constexpr index_t scale_granularity_k = BlockScaleSize * KXdlPackEff; + if(kargs.k_batch > 1) + { + // splitk_batch_offset allocates K in units of K1 (warp-tile K). If K1 itself is + // not a multiple of the scale granularity, split-K is not safe. + constexpr index_t K1 = BlockGemmShape::WarpTile::at(number<2>{}); + static_assert(K1 % scale_granularity_k == 0, + "MX GEMM: WarpTile K must be a multiple of BlockScaleSize * KXdlPack " + "to support split-K."); + // Defensive runtime check: K must split evenly along K1 boundaries so that each + // k_id consumes a whole number of warp-tile K chunks (and therefore a whole + // number of packed-scale K elements). + if(kargs.K % (K1 * kargs.k_batch) != 0) + { + if(log) + CK_TILE_ERROR("MX GEMM: with k_batch > 1, K must be a multiple of WarpTile_K * " + "k_batch so that every split lands on a packed-scale boundary."); + return false; + } + } + + // Delegate the remaining shape/vector-size checks to the universal kernel. return BaseKernel::IsSupportedArgument(kargs); } @@ -146,10 +230,14 @@ struct MxGemmKernel CK_TILE_DEVICE static auto MakeScaleABlockWindow(const std::array& as_scale_ptr, const KernelArgs& kargs, - index_t block_idx_m) + index_t block_idx_m, + const index_t k_elem_offset = 0) { - const auto&& scale_packs_m = integer_divide_ceil(kargs.M, MThreadPerXdl); - const auto&& scale_packs_k = kargs.K / BlockScaleSize / ScalePackSize; + const auto&& scale_packs_m = integer_divide_ceil(kargs.M, MThreadPerXdl * MXdlPackEff); + const auto&& scale_packs_k = kargs.K / BlockScaleSize / KXdlPackEff; + + // For split-K (k_batch > 1) advance the scale origin into this k_id's packed-K slice. + const index_t k_scale_offset = k_elem_offset / BlockScaleSize / KXdlPackEff; // Scale16: descriptor order [packs_m, MThreadPerXdl, packs_k] -- K contiguous per M-row, // no pre-shuffle needed (natural row-major layout matches). @@ -184,14 +272,28 @@ struct MxGemmKernel return make_tensor_view(as_scale_ptr[i], scale_a_desc); }, number{}); + + // Pad the scale view so partial trailing tiles along M are handled safely (OOB scale + // loads return zero; with A also zero on the padded region the contribution is zero + // regardless of scale value). kPadK is statically disabled, so K never actually pads. + const auto& scale_a_pad_view = generate_tuple( + [&](auto i) { + return pad_tensor_view( + scale_a_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + }, + number{}); + const auto& scale_a_block_window = generate_tuple( [&](auto i) { return make_tile_window( - scale_a_tensor_view[i], + scale_a_pad_view[i], make_tuple( - number{}, - number{}), - {block_idx_m, 0}); + number{}, + number{}), + {block_idx_m / MXdlPackEff, k_scale_offset}); }, number{}); @@ -202,10 +304,14 @@ struct MxGemmKernel CK_TILE_DEVICE static auto MakeScaleBBlockWindow(const std::array& bs_scale_ptr, const KernelArgs& kargs, - index_t block_idx_n) + index_t block_idx_n, + const index_t k_elem_offset = 0) { - const auto&& scale_packs_n = integer_divide_ceil(kargs.N, NThreadPerXdl); - const auto&& scale_packs_k = kargs.K / BlockScaleSize / ScalePackSize; + const auto&& scale_packs_n = integer_divide_ceil(kargs.N, NThreadPerXdl * NXdlPackEff); + const auto&& scale_packs_k = kargs.K / BlockScaleSize / KXdlPackEff; + + // For split-K (k_batch > 1) advance the scale origin into this k_id's packed-K slice. + const index_t k_scale_offset = k_elem_offset / BlockScaleSize / KXdlPackEff; const auto scale_b_naive_desc = [&]() { if constexpr(BlockScaleSize == 16) @@ -236,33 +342,120 @@ struct MxGemmKernel return make_tensor_view(bs_scale_ptr[i], scale_b_desc); }, number{}); + + // Pad the scale view so partial trailing tiles along N are handled safely (OOB scale + // loads return zero; with B also zero on the padded region the contribution is zero + // regardless of scale value). kPadK is statically disabled, so K never actually pads. + const auto& scale_b_pad_view = generate_tuple( + [&](auto i) { + return pad_tensor_view( + scale_b_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + }, + number{}); + const auto& scale_b_block_window = generate_tuple( [&](auto i) { return make_tile_window( - scale_b_tensor_view[i], + scale_b_pad_view[i], make_tuple( - number{}, - number{}), - {block_idx_n, 0}); + number{}, + number{}), + {block_idx_n / NXdlPackEff, k_scale_offset}); }, number{}); return scale_b_block_window; } + CK_TILE_DEVICE static auto + MakeBFlatBlockWindows(const std::array& bs_ptr, + const KernelArgs& kargs, + const index_t i_n, + const index_t k_elem_offset = 0) + { + static_assert(NumBTensor == 1, "MX GEMM preshuffle currently supports one B tensor"); + + constexpr index_t kKPerBlock = MxGemmPipeline::kKPerBlock; + constexpr index_t kNWarpTile = BlockGemmShape::WarpTile::at(I1); + constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile; + const index_t kFlatKBlocks = kargs.K / kKPerBlock; + const index_t kFlatN = kargs.N / kNWarpTile; + + const index_t k_flat_offset = (k_elem_offset / kKPerBlock) * flatKPerBlock; + + auto b_flat_tensor_view = [&]() { + static_assert(flatKPerBlock % MxGemmPipeline::GetVectorSizeB() == 0, + "wrong! vector size for preshuffled B tensor"); + auto naive_desc = make_naive_tensor_descriptor_packed( + make_tuple(kFlatN, kFlatKBlocks, number{})); + auto desc = transform_tensor_descriptor( + naive_desc, + make_tuple(make_pass_through_transform(kFlatN), + make_merge_transform_v3_division_mod( + make_tuple(kFlatKBlocks, number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return make_tensor_view(bs_ptr[number<0>{}], desc); + }(); + + return generate_tuple( + [&](auto) { + return make_tile_window(b_flat_tensor_view, + make_tuple(number{}, + number{}), + {static_cast(i_n / BlockGemmShape::WarpTile::at(I1)), + static_cast(k_flat_offset)}); + }, + number{}); + } + + template CK_TILE_DEVICE static void RunGemm(const std::array& as_ptr, const std::array& bs_ptr, const std::array& ds_ptr, EDataType* e_ptr, void* smem_ptr, - const KernelArgs& kargs, + KernelArgs kargs, const SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, - const index_t block_idx_n) + const index_t block_idx_n, + const index_t k_elem_offset = 0) { std::array as_scale_ptr; - static_for<0, NumATensor, 1>{}([&](auto i) { - as_scale_ptr[i] = reinterpret_cast(kargs.as_scale_ptr[i]); - }); + std::array as_ptr_; + index_t block_idx_m_; + // Large tensor support (when M is large, N and K are relatively small) + using ALayout = remove_cvref_t>; + constexpr bool offset_ptrs_by_tile_coords = + std::is_same_v && + std::is_same_v && !BaseKernel::ClusterLaunch; + + if constexpr(offset_ptrs_by_tile_coords) + { + static_for<0, NumATensor, 1>{}([&](auto i) { + as_ptr_[i] = as_ptr[i] + static_cast(block_idx_m) * + kargs.stride_As[i] / APackedSize; + }); + e_ptr += static_cast(block_idx_m) * kargs.stride_E; + static_for<0, NumATensor, 1>{}([&](auto i) { + as_scale_ptr[i] = reinterpret_cast(kargs.as_scale_ptr[i]) + + static_cast(block_idx_m / MXdlPackEff) * + (kargs.K / BlockScaleSize / KXdlPackEff); + }); + + kargs.M = std::min(kargs.M - block_idx_m, TilePartitioner::MPerBlock); + block_idx_m_ = 0; + } + else + { + static_for<0, NumATensor, 1>{}([&](auto i) { + as_scale_ptr[i] = reinterpret_cast(kargs.as_scale_ptr[i]); + }); + static_for<0, NumATensor, 1>{}([&](auto i) { as_ptr_[i] = as_ptr[i]; }); + block_idx_m_ = block_idx_m; + } std::array bs_scale_ptr; static_for<0, NumBTensor, 1>{}([&](auto i) { @@ -272,18 +465,50 @@ struct MxGemmKernel // cluster launch pads grid to cluster boundaries; skip out-of-bound blocks if constexpr(BaseKernel::ClusterLaunch) { - if(block_idx_m >= kargs.M || block_idx_n >= kargs.N) + if(block_idx_m_ >= kargs.M || block_idx_n >= kargs.N) return; } - const auto& as_block_window = BaseKernel::MakeABlockWindows( - as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); - const auto& bs_block_window = BaseKernel::MakeBBlockWindows( - bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); + // The preshuffle A async-load (MakeMX_AAsyncLoadBytesDramWindow) rebuilds the A + // view with a packed descriptor, i.e. it assumes the leading (M) stride equals + // the view's K extent. That only holds when the extent equals stride_A, which is + // the case for k_batch == 1 (splitted_k == K) but NOT for split-K (splitted_k < K): + // a packed extent of splitted_k would stride M by splitted_k instead of stride_A + // and read the wrong rows (only row 0 lands correctly). Use the full K extent so + // the packed M stride matches stride_A. The as_ptr K-offset already selects this + // k_id's slice and num_loop bounds the blocks read, so reads stay within + // [as_k_split_offset, as_k_split_offset + splitted_k) <= K (in-allocation). + const auto& as_block_window = [&]() { + if constexpr(MxGemmPipeline::Preshuffle) + { + return BaseKernel::MakeABlockWindows(as_ptr_, kargs, kargs.K, block_idx_m_); + } + else + { + return BaseKernel::MakeABlockWindows( + as_ptr_, kargs, splitk_batch_offset.splitted_k, block_idx_m_); + } + }(); + const auto& bs_block_window = [&]() { + if constexpr(MxGemmPipeline::Preshuffle) + { + return MakeBFlatBlockWindows(bs_ptr, kargs, block_idx_n, k_elem_offset); + } + else + { + return BaseKernel::MakeBBlockWindows( + bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); + } + }(); const auto& ds_block_window = - BaseKernel::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); - const auto& scale_a_block_window = MakeScaleABlockWindow(as_scale_ptr, kargs, block_idx_m); - const auto& scale_b_block_window = MakeScaleBBlockWindow(bs_scale_ptr, kargs, block_idx_n); + BaseKernel::MakeDBlockWindows(ds_ptr, kargs, block_idx_m_, block_idx_n); + + // Create scale block windows. For split-K (k_batch > 1), k_elem_offset advances the + // scale origin into the correct packed-K slice for this k_id; otherwise it is zero. + const auto& scale_a_block_window = + MakeScaleABlockWindow(as_scale_ptr, kargs, block_idx_m_, k_elem_offset); + const auto& scale_b_block_window = + MakeScaleBBlockWindow(bs_scale_ptr, kargs, block_idx_n, k_elem_offset); const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); @@ -297,10 +522,58 @@ struct MxGemmKernel num_loop, smem_ptr); - auto c_block_window = BaseKernel::template MakeCBlockWindows( - e_ptr, kargs, block_idx_m, block_idx_n); + // Dispatch epilogue: when k_batch > 1 each split accumulates a partial result into + // the same C tile, so we need atomic add (universal_gemm_kernel pattern). The + // fp16/bf16 even-vector-size precondition is captured once in kSplitKAtomicAddSupported + // and also rejected up front in IsSupportedArgument. + // if(k_batch == 1) + auto c_block_window = BaseKernel::template MakeCBlockWindows( + e_ptr, kargs, block_idx_m_, block_idx_n); EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr); } + + CK_TILE_DEVICE static void RunGemm(const std::array& as_ptr, + const std::array& bs_ptr, + const std::array& ds_ptr, + EDataType* e_ptr, + void* smem_ptr, + const KernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n) + { + if(kargs.k_batch == 1) + { + RunGemm(as_ptr, + bs_ptr, + ds_ptr, + e_ptr, + smem_ptr, + kargs, + splitk_batch_offset, + block_idx_m, + block_idx_n); + } + else + { + // This k_id's logical K-element start. For row-major A, as_k_split_offset[0] is exactly + // that offset, so reuse it rather than recomputing the split formula; the packed-scale + // and flat-B K offsets are derived from it. Split-K with non-row-major A is rejected in + // IsSupportedArgument; for k_batch == 1 this value is 0 and unused for any layout. + const index_t k_elem_offset = + amd_wave_read_first_lane(splitk_batch_offset.as_k_split_offset[number<0>{}]); + RunGemm(as_ptr, + bs_ptr, + ds_ptr, + e_ptr, + smem_ptr, + kargs, + splitk_batch_offset, + block_idx_m, + block_idx_n, + k_elem_offset); + } + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/mx_grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/mx_grouped_gemm_kernel.hpp index 63eda6b925..969de3c2d9 100644 --- a/include/ck_tile/ops/gemm/kernel/mx_grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/mx_grouped_gemm_kernel.hpp @@ -116,7 +116,7 @@ struct MxGroupedGemmKernel using P_ = GemmPipeline; return concat('_', "mx_gemm_grouped", gemm_prec_str(), concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock), - concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()), + concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB()), concat('x', P_::kPadM, P_::kPadN, P_::kPadK), (UsePersistentKernel ? "Persistent" : "NonPersistent"), (NumDTensor_ == 2 ? "MultiD" : "NoMultiD"), diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index f315c21cef..39718c3338 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -1280,8 +1280,18 @@ struct UniversalGemmKernel std::array bs_ptr; static_for<0, NumBTensor, 1>{}([&](auto i) { - bs_ptr[i] = static_cast(kargs.bs_ptr[i]) + - splitk_batch_offset.bs_k_split_offset[i] / BPackedSize; + if constexpr(GemmPipeline::Preshuffle) + { + // The preshuffle (flat-B) path applies the per-split K offset to the flat + // window origin in when creating the window; bs_k_split_offset is derived from + // the logical B stride and would mis-offset the flat buffer. + bs_ptr[i] = static_cast(kargs.bs_ptr[i]); + } + else + { + bs_ptr[i] = static_cast(kargs.bs_ptr[i]) + + splitk_batch_offset.bs_k_split_offset[i] / BPackedSize; + } }); // Calculate output offset from tile partitioner and apply to output pointer diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp index 01b94516e7..5247c3909e 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -15,9 +15,10 @@ namespace ck_tile { template struct BaseGemmPipelineAgBgCrCompAsync { - static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefetchStages = 3; static constexpr index_t PrefillStages = 1; static constexpr index_t GlobalBufferNum = 1; + static constexpr index_t UnrollHotLoop = 2; CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop) { @@ -30,13 +31,13 @@ struct BaseGemmPipelineAgBgCrCompAsync { return TailNumber::One; } - if(num_loop % PrefetchStages == 1) + if(num_loop % UnrollHotLoop == 0) { - return TailNumber::Three; + return TailNumber::Two; } else { - return TailNumber::Two; + return TailNumber::Three; } } @@ -130,10 +131,9 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync>::PackedSize; - using BlockGemm = remove_cvref_t())>; - using I0 = number<0>; - using I1 = number<1>; - using I2 = number<2>; + using I0 = number<0>; + using I1 = number<1>; + using I2 = number<2>; static constexpr bool LargeTensors = Problem::LargeTensors; @@ -176,6 +176,37 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}; static constexpr auto is_b_load_tr_v = bool_constant{}; + using BlockWarps = typename BlockGemmShape::BlockWarps; + using WarpTile = typename BlockGemmShape::WarpTile; + static constexpr index_t MWarp = BlockWarps::at(I0{}); + static constexpr index_t NWarp = BlockWarps::at(I1{}); + + // Compute effective XdlPack sizes (fall back to 1 when iter count < pack) + static constexpr index_t MPerXdl = WarpTile::at(I0{}); + static constexpr index_t NPerXdl = WarpTile::at(I1{}); + static constexpr index_t KPerXdl = WarpTile::at(I2{}); + static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl); + static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl); + static constexpr index_t KIterPerWarp = KPerBlock / KPerXdl; + + static constexpr index_t MXdlPackEff = + (MIterPerWarp >= Policy::MXdlPack && MIterPerWarp % Policy::MXdlPack == 0) + ? Policy::MXdlPack + : 1; + static constexpr index_t NXdlPackEff = + (NIterPerWarp >= Policy::NXdlPack && NIterPerWarp % Policy::NXdlPack == 0) + ? Policy::NXdlPack + : 1; + static constexpr index_t KXdlPackEff = + (KIterPerWarp >= Policy::KXdlPack && KIterPerWarp % Policy::KXdlPack == 0) + ? Policy::KXdlPack + : 1; + + static constexpr index_t ScaleBlockSize = 32; + + // Packed scale dimensions + static constexpr index_t ScaleKDimPerBlock = KPerBlock / ScaleBlockSize / KXdlPackEff; + [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() { // clang-format off @@ -246,6 +277,8 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync::value && is_detected::value, bool>* = nullptr> @@ -253,9 +286,16 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync && + !is_null_tile_window_v; + using BlockGemm = + remove_cvref_t())>; + // TODO support multi-ABD static_assert(1 == std::tuple_size_v); static_assert(1 == std::tuple_size_v); @@ -370,6 +410,23 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}], b_dram_tile_window_step); } + // Load scales for iteration 0 (ping) + load_scales_from_dram(scale_a_tile_ping, scale_b_tile_ping); + + // Load scales for iteration 1 (pong) if needed + if(num_loop > 1) + { + load_scales_from_dram(scale_a_tile_pong, scale_b_tile_pong); + } + if constexpr(HasHotLoop) { // we have had 3 global prefetches so far, indexed (0, 1, 2). @@ -482,8 +548,14 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}], b_dram_tile_window_step); // C(i-3) = A(i-3) @ B(i-3) - block_gemm(c_block_tile, a_block_tile0, b_block_tile0); + block_gemm(c_block_tile, + a_block_tile0, + b_block_tile0, + scale_a_tile_ping, + scale_b_tile_ping); HotLoopScheduler(); + // Load next scales after using current scales above + load_scales_from_dram(scale_a_tile_ping, scale_b_tile_ping); } // pong { @@ -503,8 +575,14 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}], b_dram_tile_window_step); // C(i-2) = A(i-2) @ B(i-2) - block_gemm(c_block_tile, a_block_tile1, b_block_tile1); + block_gemm(c_block_tile, + a_block_tile1, + b_block_tile1, + scale_a_tile_pong, + scale_b_tile_pong); HotLoopScheduler(); + // Load next scales after using current scales above + load_scales_from_dram(scale_a_tile_pong, scale_b_tile_pong); } i_global_read += 2; } while(i_global_read < num_loop); @@ -518,7 +596,13 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}, number<0>{}))); + + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + const ScaleADramBlockWindowTmp& scale_a_window, + const ScaleBDramBlockWindowTmp& scale_b_window, + index_t num_loop, + void* p_smem) const + { + // Scale tensor views and base origins for creating tile windows per iteration + const auto& scale_a_tensor_view = scale_a_window[number<0>{}].get_bottom_tensor_view(); + const auto& scale_b_tensor_view = scale_b_window[number<0>{}].get_bottom_tensor_view(); + auto scale_a_base_origin = scale_a_window[number<0>{}].get_window_origin(); + auto scale_b_base_origin = scale_b_window[number<0>{}].get_window_origin(); + + // Create scale windows with packed int32_t dimensions + auto scale_a_dram_window = make_tile_window( + scale_a_tensor_view, + make_tuple(number{}, number{}), + scale_a_base_origin, + Policy::template MakeMX_ScaleA_DramTileDistribution()); + + auto scale_b_dram_window = make_tile_window( + scale_b_tensor_view, + make_tuple(number{}, number{}), + scale_b_base_origin, + Policy::template MakeMX_ScaleB_DramTileDistribution()); + + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + scale_a_dram_window, + scale_b_dram_window, + num_loop, + p_smem); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); + } + template + static constexpr index_t AsyncVectorBytes = + sizeof(DataType) * KPack / numeric_traits>::PackedSize; + template static constexpr bool IsSupportedAsyncVectorWidth = - sizeof(DataType) * KPack == 4 || sizeof(DataType) * KPack == 12 || - sizeof(DataType) * KPack == 16; + AsyncVectorBytes == 4 || AsyncVectorBytes == 12 || + AsyncVectorBytes == 16; // XOR Swizzle: support FP8 / BF8 template @@ -57,10 +62,10 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy // Compute the number of LDS read accesses for A or B // IsLoadTr=true if ds_read_tr is used - template + template CK_TILE_HOST_DEVICE static constexpr auto CalculateWGAttrNumAccess() { - if constexpr(IsLoadTr) + if constexpr(IsLoadTr && !IsScale) { // Transpose-load path: ds_read_tr reads DS_READ_TR_SIZE bytes per instruction. constexpr index_t vector_size = @@ -91,32 +96,34 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy } } - template + template CK_TILE_HOST_DEVICE static constexpr auto GetAWGAttrNumAccess() { using WarpTile = typename Problem::BlockGemmShape::WarpTile; constexpr index_t thread_elements = WarpTile::at(I0) * WarpTile::at(I2) / get_warp_size(); return CalculateWGAttrNumAccess, typename Problem::ADataType, - thread_elements>(); + thread_elements, + IsScale>(); } - template + template CK_TILE_HOST_DEVICE static constexpr auto GetBWGAttrNumAccess() { using WarpTile = typename Problem::BlockGemmShape::WarpTile; constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size(); return CalculateWGAttrNumAccess, typename Problem::BDataType, - thread_elements>(); + thread_elements, + IsScale>(); } // Get number of accesses - template + template CK_TILE_HOST_DEVICE static constexpr auto GetWGAttrNumAccess() { - constexpr auto num_access_a = GetAWGAttrNumAccess(); - constexpr auto num_access_b = GetBWGAttrNumAccess(); + constexpr auto num_access_a = GetAWGAttrNumAccess(); + constexpr auto num_access_b = GetBWGAttrNumAccess(); if constexpr(num_access_a == WGAttrNumAccessEnum::Invalid || num_access_b == WGAttrNumAccessEnum::Invalid) @@ -127,6 +134,70 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy return num_access_b; } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeXorSwizzleABDramTileDistribution() + { + using BlockGemmShape = typename Problem::BlockGemmShape; + using BlockWarps = typename BlockGemmShape::BlockWarps; + using WarpTile = typename BlockGemmShape::WarpTile; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t KPerBlock = BlockGemmShape::kK; + constexpr index_t KWarps = BlockWarps::at(I2); + constexpr index_t K1 = WarpTile::at(I2) / K2; + constexpr index_t K0 = KPerBlock / (KWarps * K1 * K2); + + constexpr index_t warp_size = get_warp_size(); + constexpr index_t warp_num = BlockSize / warp_size; + + static_assert(KWarps == 1, "MX XOR swizzle currently supports KWarps == 1"); + static_assert(KWarps * K0 * K1 * K2 == KPerBlock, "Wrong!"); + + constexpr index_t M2 = warp_size / K1; + constexpr index_t M1 = warp_num / Problem::NumWaveGroups; + constexpr index_t M0 = MNPerBlock / (M1 * M2); + + static_assert(M0 * M1 * M2 == MNPerBlock, "Wrong!"); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 1>>, + sequence<1, 2, 2>, + sequence<0, 0, 2>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() + { + if constexpr(UseXorSwizzle) + { + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPack = Base::template GetSmemPackA(); + return MakeXorSwizzleABDramTileDistribution(); + } + else + { + return Base::template MakeADramTileDistribution(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() + { + if constexpr(UseXorSwizzle) + { + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPack = Base::template GetSmemPackB(); + return MakeXorSwizzleABDramTileDistribution(); + } + else + { + return Base::template MakeBDramTileDistribution(); + } + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution() + { + using BlockGemmShape = typename Problem::BlockGemmShape; + using BlockWarps = typename BlockGemmShape::BlockWarps; + using WarpTile = typename BlockGemmShape::WarpTile; + + constexpr index_t ScaleGranularityK = 32; + + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t MWarp = BlockWarps::at(number<0>{}); + constexpr index_t NWarp = BlockWarps::at(number<1>{}); + constexpr index_t MPerXdl = WarpTile::at(number<0>{}); + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K_Lane = get_warp_size() / MPerXdl; + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl); + constexpr index_t KPerXdl = WarpTile::at(number<2>{}); + constexpr index_t KIterPerWarp = KPerBlock / KPerXdl; + constexpr index_t KPerLane = KPerXdl / ScaleGranularityK / K_Lane; + + // Effective pack sizes: fall back to 1 when iteration count < pack size + constexpr index_t MXdlPackEff = + (MIterPerWarp >= MXdlPack && MIterPerWarp % MXdlPack == 0) ? MXdlPack : 1; + constexpr index_t KXdlPackEff = + (KIterPerWarp >= KXdlPack && KIterPerWarp % KXdlPack == 0) ? KXdlPack : 1; + + constexpr index_t MIterPerWarp_packed = MIterPerWarp / MXdlPackEff; + constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 2>>, + sequence<2, 1, 2>, + sequence<0, 1, 2>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_DramTileDistribution() + { + using BlockGemmShape = typename Problem::BlockGemmShape; + using BlockWarps = typename BlockGemmShape::BlockWarps; + using WarpTile = typename BlockGemmShape::WarpTile; + + constexpr index_t ScaleGranularityK = 32; + + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t MWarp = BlockWarps::at(number<0>{}); + constexpr index_t NWarp = BlockWarps::at(number<1>{}); + constexpr index_t NPerXdl = WarpTile::at(number<1>{}); + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t K_Lane = get_warp_size() / NPerXdl; + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl); + + constexpr index_t KPerXdl = WarpTile::at(number<2>{}); + constexpr index_t KIterPerWarp = KPerBlock / KPerXdl; + constexpr index_t KPerLane = KPerXdl / ScaleGranularityK / K_Lane; + + // Effective pack sizes: fall back to 1 when iteration count < pack size + constexpr index_t NXdlPackEff = + (NIterPerWarp >= NXdlPack && NIterPerWarp % NXdlPack == 0) ? NXdlPack : 1; + constexpr index_t KXdlPackEff = + (KIterPerWarp >= KXdlPack && KIterPerWarp % KXdlPack == 0) ? KXdlPack : 1; + + constexpr index_t NIterPerWarp_packed = NIterPerWarp / NXdlPackEff; + constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 2>>, + sequence<2, 1, 2>, + sequence<0, 1, 2>>{}); + } + template CK_TILE_DEVICE static constexpr auto GetEstimatedVgprCount() { @@ -479,13 +639,13 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy return number{}; } - template + template CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() { using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; - constexpr auto wg_attr_num_access = GetWGAttrNumAccess(); + constexpr auto wg_attr_num_access = GetWGAttrNumAccess(); constexpr auto pipeline_tune_params = GetPipelineSubTileNum(); constexpr index_t sub_tile_num = EnableSubTile ? pipeline_tune_params.value : 1; @@ -506,7 +666,8 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy typename Problem::CDataType, BlockWarps, WarpGemm, - sub_tile_num>; + sub_tile_num, + PackMNIter>; return BlockGemmARegBRegCRegV1{}; } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp index 7c4e42c700..b35fc317e9 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp @@ -45,9 +45,6 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp static constexpr index_t APackedSize = ck_tile::numeric_traits::PackedSize; static constexpr index_t BPackedSize = ck_tile::numeric_traits::PackedSize; - using BlockGemm = remove_cvref_t())>; - using WarpGemm = typename BlockGemm::WarpGemm; - static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; static constexpr auto I2 = number<2>{}; @@ -66,9 +63,9 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp static constexpr index_t kflatKPerWarp = BlockGemmShape::flatKPerWarp; - static constexpr index_t MIterPerWarp = MPerBlock / (MWarps * WarpGemm::kM); - static constexpr index_t NIterPerWarp = NPerBlock / (NWarps * WarpGemm::kN); - static constexpr index_t KIterPerWarp = KPerBlock / (KWarps * WarpGemm::kK); + static constexpr index_t MXdlPackEff = Policy::template GetMXdlPackEff(); + static constexpr index_t NXdlPackEff = Policy::template GetNXdlPackEff(); + static constexpr index_t KXdlPackEff = Policy::template GetKXdlPackEff(); static constexpr bool Async = true; @@ -97,6 +94,8 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp static constexpr auto Scheduler = Problem::Scheduler; + static constexpr index_t ScaleBlockSize = 32; + [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() { // clang-format off @@ -123,8 +122,6 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp return Policy::template GetSmemSize(); } - static constexpr index_t MFMA_INST = MIterPerWarp * NIterPerWarp * KIterPerWarp; - template struct PipelineImpl : public PipelineImplBase { @@ -141,6 +138,8 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp typename BsDramBlockWindowTmp, typename AElementFunction, typename BElementFunction, + typename ScaleADramBlockWindowTmp, + typename ScaleBDramBlockWindowTmp, typename std::enable_if_t::value && !is_detected::value, bool>* = nullptr> @@ -148,9 +147,23 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp const AElementFunction& a_element_func, const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, + const ScaleADramBlockWindowTmp& scale_a_window, + const ScaleBDramBlockWindowTmp& scale_b_window, index_t num_loop, void* __restrict__ p_smem) const { + constexpr bool IsScaledGemm = !is_null_tile_window_v && + !is_null_tile_window_v; + using BlockGemm = + remove_cvref_t())>; + using WarpGemm = typename BlockGemm::WarpGemm; + + constexpr index_t MIterPerWarp = MPerBlock / (MWarps * WarpGemm::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarps * WarpGemm::kN); + constexpr index_t KIterPerWarp = KPerBlock / (KWarps * WarpGemm::kK); + + constexpr index_t MFMA_INST = MIterPerWarp * NIterPerWarp * KIterPerWarp; + // TODO: A/B elementwise functions currently not supported ignore = a_element_func; ignore = b_element_func; @@ -183,12 +196,10 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp // Hot loop scheduler // ------------------ auto hot_loop_scheduler = [&]() { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, MIterPerWarp, 0); // MFMA s_waitcnt_lgkm<4>(); __builtin_amdgcn_sched_group_barrier(0x004, 1, 0); // lgkmcnt / SALU - static_for<0, MFMA_INST - 3, 1>{}([&](auto) { + static_for<0, MFMA_INST - MIterPerWarp, 1>{}([&](auto) { __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA }); __builtin_amdgcn_sched_barrier(0); @@ -201,10 +212,57 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp num_loop, a_dram_block_window_tmp, b_dram_block_window_tmp, + scale_a_window, + scale_b_window, hot_loop_scheduler); } }; + template < + typename AsDramBlockWindowTmp, + typename BsDramBlockWindowTmp, + typename AElementFunction, + typename BElementFunction, + typename ScaleADramBlockWindowTmp, + typename ScaleBDramBlockWindowTmp, + typename std::enable_if_t::value && + is_detected::value && + is_detected::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + const ScaleADramBlockWindowTmp& scale_a_window, + const ScaleBDramBlockWindowTmp& scale_b_window, + index_t num_loop, + void* p_smem) const + { + // TODO: A/B windows are tuple of windows, but the implementation doesn't take that into + // account yet and just the first element is passed + static_assert(AsDramBlockWindowTmp::size() == 1); + static_assert(BsDramBlockWindowTmp::size() == 1); + static_assert(ScaleADramBlockWindowTmp::size() == 1); + static_assert(ScaleBDramBlockWindowTmp::size() == 1); + + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp[I0], + a_element_func, + b_dram_block_window_tmp[I0], + b_element_func, + scale_a_window[I0], + scale_b_window[I0], + num_loop, + p_smem); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); + } + template {}, number<0>{}))); // TODO: A/B windows are tuple of windows, but the implementation doesn't take that into // account yet and just the first element is passed static_assert(AsDramBlockWindowTmp::size() == 1); @@ -231,6 +291,8 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp a_element_func, b_dram_block_window_tmp[I0], b_element_func, + NullTileWindowType{}, + NullTileWindowType{}, num_loop, p_smem); }; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp index 56709be910..4a6e1ab705 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp @@ -389,30 +389,85 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy static constexpr auto ATileAccessPattern = tile_distribution_pattern::warp_raked; static constexpr auto BTileAccessPattern = tile_distribution_pattern::warp_raked; - CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + // Scale part + static constexpr int BlockScaleSize = 32; + + // XdlPack: how many e8m0_t scale values are packed into one int32_t per dimension + // Host packs MXdlPack * KXdlPack e8m0_t into one int32_t for A scales + // Host packs NXdlPack * KXdlPack e8m0_t into one int32_t for B scales + static constexpr int MXdlPack = 2; + static constexpr int NXdlPack = 2; + static constexpr int KXdlPack = 2; + + // Compute effective XdlPack sizes (fall back to 1 when iter count < pack) + static constexpr index_t KPerXdl = WarpTile::at(I2); + static constexpr index_t KIterPerWarp = KPerBlock / KPerXdl; + + static constexpr index_t MXdlPackEff = + (MIterPerWarp >= MXdlPack && MIterPerWarp % MXdlPack == 0) ? MXdlPack : 1; + static constexpr index_t NXdlPackEff = + (NIterPerWarp >= NXdlPack && NIterPerWarp % NXdlPack == 0) ? NXdlPack : 1; + static constexpr index_t KXdlPackEff = + (KIterPerWarp >= KXdlPack && KIterPerWarp % KXdlPack == 0) ? KXdlPack : 1; + + static constexpr index_t KPerBlockScale = KPerBlock / BlockScaleSize / KXdlPackEff; + + CK_TILE_HOST_DEVICE static constexpr auto GetMXdlPackEff() { return MXdlPackEff; } + CK_TILE_HOST_DEVICE static constexpr auto GetNXdlPackEff() { return NXdlPackEff; } + CK_TILE_HOST_DEVICE static constexpr auto GetKXdlPackEff() { return KXdlPackEff; } + + CK_TILE_HOST_DEVICE static constexpr auto GetKStepAQ() { return KPerBlockScale; } + CK_TILE_HOST_DEVICE static constexpr auto GetKStepBQ() { return KPerBlockScale; } + + CK_TILE_HOST_DEVICE static constexpr auto GetInstCountAQ() { - // TODO: Fix for transpose - constexpr auto wg_attr_num_access = WGAccess; + return (MIterPerWarp / MXdlPackEff) * (KIterPerWarp / KXdlPackEff); + } - using WarpGemm = WarpGemmDispatcher; + CK_TILE_HOST_DEVICE static constexpr auto GetInstCountBQ() + { + return (NIterPerWarp / NXdlPackEff) * (KIterPerWarp / KXdlPackEff); + } - using BlockGemmPolicy = - BlockGemmARegBRegCRegV1CustomPolicy; + CK_TILE_HOST_DEVICE static constexpr auto MakeAQBlockDistribution() + { + constexpr index_t K_Lane = get_warp_size() / WarpTileM; - return BlockGemmARegBRegCRegEightWavesV1{}; + constexpr index_t KPerLane = WarpTileK / BlockScaleSize / K_Lane; + + constexpr index_t MIterPerWarp_packed = MIterPerWarp / MXdlPackEff; + constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, // repeat over MWarps + tuple, // M dimension (first) + sequence>, // K dimension (second) + tuple, sequence<2, 1>>, // , + tuple, sequence<1, 2>>, + sequence<2, 1, 2>, // + sequence<0, 1, 2>>{}); + } + + CK_TILE_HOST_DEVICE static constexpr auto MakeBQBlockDistribution() + { + constexpr index_t K_Lane = get_warp_size() / WarpTileN; + + constexpr index_t KPerLane = WarpTileK / BlockScaleSize / K_Lane; + + constexpr index_t NIterPerWarp_packed = NIterPerWarp / NXdlPackEff; + constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, // repeat over MWarps + tuple, // N dimension + // (first) + sequence>, // K dimension (second) + tuple, sequence<2, 1>>, // , + tuple, sequence<1, 3>>, + sequence<2, 1, 2>, // + sequence<0, 1, 2>>{}); } }; } // namespace detail @@ -447,8 +502,61 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy FORWARD_METHOD_(GetSmemPackA); FORWARD_METHOD_(GetSmemPackB); FORWARD_METHOD_(IsPreshuffle); + // Scale part + FORWARD_METHOD_(MakeAQBlockDistribution); + FORWARD_METHOD_(MakeBQBlockDistribution); + FORWARD_METHOD_(GetKStepAQ); + FORWARD_METHOD_(GetKStepBQ); + FORWARD_METHOD_(GetInstCountAQ); + FORWARD_METHOD_(GetInstCountBQ); + FORWARD_METHOD_(GetMXdlPackEff); + FORWARD_METHOD_(GetNXdlPackEff); + FORWARD_METHOD_(GetKXdlPackEff); #undef FORWARD_METHOD_ + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using BlockGemmShape = typename Problem::BlockGemmShape; + using BlockWarps = typename BlockGemmShape::BlockWarps; + using WarpTile = typename BlockGemmShape::WarpTile; + + using AComputeDataType = remove_cvref_t; + using BComputeDataType = remove_cvref_t; + static_assert(std::is_same_v); + using ComputeDataType = AComputeDataType; + + constexpr auto WGAccess = + std::is_same_v || std::is_same_v + ? WGAttrNumAccessEnum::Double + : WGAttrNumAccessEnum::Single; + + // TODO: Fix for transpose + constexpr auto wg_attr_num_access = WGAccess; + + using WarpGemm = WarpGemmDispatcher{}), + WarpTile::at(number<1>{}), + WarpTile::at(number<2>{}), + Problem::TransposeC, + false, + false, + wg_attr_num_access>; + + using BlockGemmPolicy = + BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegEightWavesV1{}; + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_tdm_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_tdm_v1.hpp index a88319927b..d50414eb75 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_tdm_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_tdm_v1.hpp @@ -138,6 +138,10 @@ struct GemmPipelineAgBgCrCompTDMV1 : public BaseGemmPipelineAgBgCrCompTDM(); // for these three functions, we always return 1 since TDM handles vectorization internally diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_eight_waves_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_eight_waves_base.hpp index 823c4eef32..49a8f84078 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_eight_waves_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_eight_waves_base.hpp @@ -17,9 +17,6 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase< using BlockGemmShape = remove_cvref_t; - using BlockGemm = remove_cvref_t())>; - using WarpGemm = typename BlockGemm::WarpGemm; - static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; static constexpr auto I2 = number<2>{}; @@ -42,10 +39,6 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase< static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp; static constexpr index_t WarpTileN = BlockGemmShape::WarpTile::at(I1); - static constexpr index_t MIterPerWarp = MPerBlock / (MWarps * WarpGemm::kM); - static constexpr index_t NIterPerWarp = NPerBlock / (NWarps * WarpGemm::kN); - static constexpr index_t KIterPerWarp = KPerBlock / (KWarps * WarpGemm::kK); - // Rely on the policy. In this way it works for both GEMM and blockscale static constexpr bool Preshuffle = Policy::template IsPreshuffle(); @@ -72,23 +65,30 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase< lds_tile_window.load(dst_block_tile, number<-1>{}, true_type{}, static_move_ys{}); } - template + template CK_TILE_DEVICE void LocalPrefetchB(DataType* smem, DstBlockTile& dst_block_tile, - SrcTileWindow& lds_tile_window) const + SrcTileWindow& lds_tile_window, + number = {}, + number = {}) const { + constexpr index_t NIterPerWarp = NPerBlock / (NWarps * NPerXdl); + constexpr index_t KIterPerWarp = KPerBlock / (KWarps * KPerXdl); // swizzle factor limitation using static_move_ys = std::conditional_t, false_type, true_type>; lds_tile_window.set_bottom_tensor_view_data_ptr(smem); static_for_product, number>{}( [&](auto nIter, auto kIter) { - lds_tile_window.load_with_offset( - number_tuple{}, - dst_block_tile[nIter][kIter], - number<-1>{}, - true_type{}, - static_move_ys{}); + lds_tile_window.load_with_offset(number_tuple{}, + dst_block_tile[nIter][kIter], + number<-1>{}, + true_type{}, + static_move_ys{}); }); } @@ -290,6 +290,12 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase< const BQDramBlockWindowTmp& bq_dram_block_window_tmp, SchedulerFunc&& scheduler_func) const { + constexpr bool IsScaledGemm = !is_null_tile_window_v && + !is_null_tile_window_v; + using BlockGemm = + remove_cvref_t())>; + using WarpGemm = typename BlockGemm::WarpGemm; + // Loop count constexpr index_t N_LOOP = HasHotLoop ? 4 : TailNum == TailNumber::One ? 1 @@ -378,7 +384,11 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase< LocalPrefetchA(smem_a, a_block_tile, a_lds_gemm_window); BDataType* smem_b = reinterpret_cast(smem01[i] + lds_offset_b); - LocalPrefetchB(smem_b, b_block_tiles, b_lds_gemm_window); + LocalPrefetchB(smem_b, + b_block_tiles, + b_lds_gemm_window, + number{}, + number{}); }; auto calc_gemm = [&](index_t i) { @@ -418,7 +428,11 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase< GlobalPrefetchAsync(smem_b_tic, b_copy_lds_window, b_copy_dram_window); BDataType* smem_b_toc = reinterpret_cast(smem01[toc] + lds_offset_b); - LocalPrefetchB(smem_b_toc, b_block_tiles, b_lds_gemm_window); + LocalPrefetchB(smem_b_toc, + b_block_tiles, + b_lds_gemm_window, + number{}, + number{}); __builtin_amdgcn_sched_barrier(0); block_sync_lds_direct_load(); diff --git a/include/ck_tile/ops/gemm_mx/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/wp_mx_pipeline_agmem_bgmem_creg_v1.hpp similarity index 86% rename from include/ck_tile/ops/gemm_mx/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp rename to include/ck_tile/ops/gemm/pipeline/wp_mx_pipeline_agmem_bgmem_creg_v1.hpp index e9d80d73b7..a19a28ea95 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_mx_pipeline_agmem_bgmem_creg_v1.hpp @@ -8,16 +8,11 @@ #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" -#include "ck_tile/ops/gemm_mx/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_mx_pipeline_agmem_bgmem_creg_v1_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp" namespace ck_tile { -template -struct MXEpilogueTraits -{ - static constexpr index_t BlockedXDLNPerWarp = GemmConfig::Preshuffle ? 2 : 1; -}; - // This pipeline extends the existing universal GEMM machinery with preshuffled-B support. template struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 @@ -53,9 +48,9 @@ struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 static constexpr index_t BlockSize = Problem::kBlockSize; static constexpr index_t WaveSize = get_warp_size(); - static constexpr index_t kMPerBlock = BlockGemmShape::kM; - static constexpr index_t kNPerBlock = BlockGemmShape::kN; - static constexpr index_t kKPerBlock = BlockGemmShape::kK; + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp; static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp; @@ -69,12 +64,12 @@ struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 template static constexpr index_t GetVectorSizeA() { - return 32; + return PipelinePolicy::template GetVectorSizeA(); } template static constexpr index_t GetVectorSizeB() { - return 32; + return PipelinePolicy::template GetVectorSizeB(); } static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; } @@ -93,21 +88,26 @@ struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 static constexpr index_t MWarp = BlockGemm::MWarp; static constexpr index_t NWarp = BlockGemm::NWarp; - static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); - static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); - static constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; static constexpr index_t KFlatBytesPerBlockPerIter = flatKPerWarp * sizeof(BDataType) / BPackedSize; static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp; - static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp; - static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp; + static constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; + static constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; - static constexpr index_t ScaleGranularityK = 32; - static constexpr index_t MXdlPack = 2; - static constexpr index_t NXdlPack = 2; - static constexpr index_t KXdlPack = 2; + static constexpr index_t ScaleBlockSize = 32; + static constexpr index_t MXdlPack = 2; + static constexpr index_t NXdlPack = 2; + static constexpr index_t KXdlPack = 2; + + // Preshuffle only supports this case as checked by static asserts + static constexpr index_t MXdlPackEff = MXdlPack; + static constexpr index_t NXdlPackEff = NXdlPack; + static constexpr index_t KXdlPackEff = KXdlPack; static constexpr index_t AK1 = 16 * APackedSize / sizeof(ADataType); static constexpr index_t BK1 = 16 * BPackedSize / sizeof(BDataType); @@ -125,12 +125,12 @@ struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 static constexpr index_t Aload_num_perK = dswrite_num_perK; static constexpr index_t Aload_rep = dswrite_rep; - static constexpr index_t Bload_num_perK = kNPerBlock * WarpGemm::kK / NWarp / BK1 / WaveSize; + static constexpr index_t Bload_num_perK = NPerBlock * WarpGemm::kK / NWarp / BK1 / WaveSize; static constexpr index_t Bload_num = Bload_num_perK * KIterPerWarp; static constexpr index_t ScaleBload_num = - kNPerBlock * kKPerBlock / NWarp / ScaleGranularityK / NXdlPack / KXdlPack / WaveSize; + NPerBlock * KPerBlock / NWarp / ScaleBlockSize / NXdlPack / KXdlPack / WaveSize; static constexpr index_t ScaleAload_num = - kMPerBlock * kKPerBlock / MWarp / ScaleGranularityK / MXdlPack / KXdlPack / WaveSize; + MPerBlock * KPerBlock / MWarp / ScaleBlockSize / MXdlPack / KXdlPack / WaveSize; static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2; static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter; @@ -181,9 +181,9 @@ struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 std::is_same_v>, "wrong!"); - static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}], + static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!"); - static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + static_assert(KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); static_assert(MWarp == 1); @@ -194,7 +194,7 @@ struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 a_copy_dram_window_tmp); using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; constexpr ADramTileWindowStep a_dram_tile_window_step = - make_array(index_t{0}, index_t{kKPerBlock * sizeof(ADataType) / APackedSize}); + make_array(index_t{0}, index_t{KPerBlock * sizeof(ADataType) / APackedSize}); __builtin_amdgcn_sched_barrier(0); @@ -208,13 +208,13 @@ struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 auto a_store_lds_window_ping = make_tile_window(a_lds_block_ping, - make_tuple(number{}, - number{}), + make_tuple(number{}, + number{}), {0, 0}); auto a_store_lds_window_pong = make_tile_window(a_lds_block_pong, - make_tuple(number{}, - number{}), + make_tuple(number{}, + number{}), {0, 0}); auto a_warp_window_ping = make_tile_window( @@ -306,7 +306,7 @@ struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k); }); }); - move_tile_window(scale_a_dram_window, {0, kKPerBlock / (ScaleGranularityK * KXdlPack)}); + move_tile_window(scale_a_dram_window, {0, KPerBlock / (ScaleBlockSize * KXdlPack)}); static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) { static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) { @@ -315,7 +315,7 @@ struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k); }); }); - move_tile_window(scale_b_dram_window, {0, kKPerBlock / (ScaleGranularityK * KXdlPack)}); + move_tile_window(scale_b_dram_window, {0, KPerBlock / (ScaleBlockSize * KXdlPack)}); __builtin_amdgcn_sched_barrier(0); if constexpr(HasHotLoop || TailNum == TailNumber::Even) @@ -375,10 +375,8 @@ struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 Base::GlobalPrefetchAsync( a_store_lds_window_ping, a_dram_window, a_dram_tile_window_step); - move_tile_window(scale_a_dram_window, - {0, kKPerBlock / (ScaleGranularityK * KXdlPack)}); - move_tile_window(scale_b_dram_window, - {0, kKPerBlock / (ScaleGranularityK * KXdlPack)}); + move_tile_window(scale_a_dram_window, {0, KPerBlock / (ScaleBlockSize * KXdlPack)}); + move_tile_window(scale_b_dram_window, {0, KPerBlock / (ScaleBlockSize * KXdlPack)}); block_gemm.LocalPrefetch(a_load_windows_pong); HotLoopScheduler(); @@ -420,10 +418,8 @@ struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 Base::GlobalPrefetchAsync( a_store_lds_window_pong, a_dram_window, a_dram_tile_window_step); - move_tile_window(scale_a_dram_window, - {0, kKPerBlock / (ScaleGranularityK * KXdlPack)}); - move_tile_window(scale_b_dram_window, - {0, kKPerBlock / (ScaleGranularityK * KXdlPack)}); + move_tile_window(scale_a_dram_window, {0, KPerBlock / (ScaleBlockSize * KXdlPack)}); + move_tile_window(scale_b_dram_window, {0, KPerBlock / (ScaleBlockSize * KXdlPack)}); block_gemm.LocalPrefetch(a_load_windows_ping); HotLoopScheduler(); @@ -707,6 +703,43 @@ struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 } } + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_copy_dram_window_tmp, + const AElementFunction&, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + const BElementFunction&, + const ScaleADramBlockWindowTmp& scale_a_window, + const ScaleBDramBlockWindowTmp& scale_b_window, + index_t num_loop, + void* __restrict__ p_smem) const + { + static_assert(std::is_same_v); + static_assert(std::is_same_v); + + constexpr index_t smem_size = PipelinePolicy::template GetSmemSize(); + const auto smem = reinterpret_cast(p_smem); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_num = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_copy_dram_window_tmp[number<0>{}], + b_flat_dram_block_window_tmp[number<0>{}], + scale_a_window[number<0>{}], + scale_b_window[number<0>{}], + num_loop, + smem, + smem + smem_size); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_num); + } + template (); + const auto smem = reinterpret_cast(p_smem); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_num = Base::GetBlockLoopTailNum(num_loop); const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { return PipelineImpl{}.template operator()( @@ -729,8 +763,8 @@ struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 scale_a_window, scale_b_window, num_loop, - p_smem_ping, - p_smem_pong); + smem, + smem + smem_size); }; return Base::TailHandler(RunPipeline, has_hot_loop, tail_num); diff --git a/include/ck_tile/ops/gemm_mx/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_mx_pipeline_agmem_bgmem_creg_v1_policy.hpp similarity index 98% rename from include/ck_tile/ops/gemm_mx/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp rename to include/ck_tile/ops/gemm/pipeline/wp_mx_pipeline_agmem_bgmem_creg_v1_policy.hpp index 04fac8f67a..8ac832f158 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_mx_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -4,7 +4,7 @@ #pragma once #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" -#include "ck_tile/ops/gemm_mx/block/block_mx_asmem_breg_creg.hpp" +#include "ck_tile/ops/gemm/block/block_mx_asmem_breg_creg.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" namespace ck_tile { @@ -57,6 +57,9 @@ struct MXGemmPipelineAgBgCrPolicy : UniversalGemmPipelineAgBgCrPolicy static constexpr index_t AK1 = DWORDx4 * APackedSize; static constexpr index_t BK1 = DWORDx4 * BPackedSize; + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA() { return AK1; } + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB() { return BK1; } + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() { using WarpGemm = WarpGemmDispatcher -struct BlockMXGemmARegBRegCRegEightWavesV1 -{ - private: - template - struct GemmTraits_ - { - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using AComputeDataType = remove_cvref_t; - using BComputeDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - - static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr auto Scheduler = Problem::Scheduler; - - static constexpr index_t MPerBlock = BlockGemmShape::kM; - static constexpr index_t NPerBlock = BlockGemmShape::kN; - static constexpr index_t KPerBlock = BlockGemmShape::kK; - - static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); - using WarpGemm = remove_cvref_t())>; - - static constexpr index_t MWarp = config.template at<1>(); - static constexpr index_t NWarp = config.template at<2>(); - static constexpr index_t KWarp = Problem::BlockGemmShape::BlockWarps::at(number<2>{}); - - using I0 = number<0>; - using I1 = number<1>; - - static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}), - "Error! WarpGemm's MWarp is not consistent with BlockGemmShape!"); - static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}), - "Error! WarpGemm's NWarp is not consistent with BlockGemmShape!"); - static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}), - "Error! WarpGemm's M is not consistent with BlockGemmShape!"); - static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}), - "Error! WarpGemm's N is not consistent with BlockGemmShape!"); - - static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); - static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); - static constexpr index_t KIterPerWarp = KPerBlock / (KWarp * WarpGemm::kK); - - // Controls how many MAC clusters (MFMA blocks) we have per wave - // If InterWaveSchedulingMacClusters = 1; - // Then we group all WarpGemms into single MAC cluster. - // But if InterWaveSchedulingMacClusters = 2, then we - // split the warp gemms into two groups. - static constexpr index_t InterWaveSchedulingMacClusters = 1; - - static constexpr index_t KPackA = WarpGemm::kAKPack; - static constexpr index_t KPackB = WarpGemm::kBKPack; - static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread; - static constexpr bool TransposeC = Problem::TransposeC; - }; - - public: - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using Traits = GemmTraits_; - - using WarpGemm = typename Traits::WarpGemm; - using BlockGemmShape = typename Traits::BlockGemmShape; - - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using AComputeDataType = remove_cvref_t; - using BComputeDataType = remove_cvref_t; - - static constexpr index_t KIterPerWarp = Traits::KIterPerWarp; - static constexpr index_t MIterPerWarp = Traits::MIterPerWarp; - static constexpr index_t NIterPerWarp = Traits::NIterPerWarp; - - static constexpr index_t MWarp = Traits::MWarp; - static constexpr index_t NWarp = Traits::NWarp; - static constexpr index_t KWarp = Traits::KWarp; - - static constexpr auto Scheduler = Traits::Scheduler; - static constexpr bool TransposeC = Traits::TransposeC; - - using AWarpDstr = typename WarpGemm::AWarpDstr; - using BWarpDstr = typename WarpGemm::BWarpDstr; - using CWarpDstr = typename WarpGemm::CWarpDstr; - - using AWarpTensor = typename WarpGemm::AWarpTensor; - using BWarpTensor = typename WarpGemm::BWarpTensor; - using CWarpTensor = typename WarpGemm::CWarpTensor; - - static_assert(std::is_same_v); - - static constexpr auto a_warp_y_lengths = - to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - static constexpr auto b_warp_y_lengths = - to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - static constexpr auto c_warp_y_lengths = - to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - - static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; - static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; - static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - - static constexpr index_t APackedSize = - ck_tile::numeric_traits>::PackedSize; - static constexpr index_t BPackedSize = - ck_tile::numeric_traits>::PackedSize; - - using I0 = number<0>; - using I1 = number<1>; - - // Note: distribution encodings have MIterPerWarp and NIterPerWarp contiguous because of scale - // packing. - - CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() - { - constexpr index_t KPerThread = Traits::KPerThread; - constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters; - - constexpr index_t KPerInnerLoop = - ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); - - constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread; - - using KIterSeq = std::conditional_t, - sequence>; - - constexpr auto a_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, KIterSeq>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<1, 1>>{}; - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); - - return a_block_dstr_encode; - } - - CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() - { - constexpr index_t KPerThread = Traits::KPerThread; - constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters; - constexpr index_t KPerInnerLoop = - ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); - constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread; - - using KIterSeq = std::conditional_t, - sequence>; - - constexpr auto b_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, KIterSeq>, - tuple>, - tuple>, - sequence<>, - sequence<>>{}; - - constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - - return b_block_dstr_encode; - } - - CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode() - { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence<2, NIterPerWarp, NWarp / 2>>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<1, 1>>{}; - constexpr auto c_block_dstr_encoding = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - return c_block_dstr_encoding; - } - - CK_TILE_DEVICE static constexpr auto MakeCBlockTile() - { - return make_static_distributed_tensor( - make_static_tile_distribution(MakeCBlockDistributionEncode())); - } - - using ALdsTile = decltype(make_static_distributed_tensor( - make_static_tile_distribution(MakeABlockDistributionEncode()))); - using BLdsTiles = statically_indexed_array< - statically_indexed_array( - make_static_tile_distribution( - MakeBBlockDistributionEncode()))), - KIterPerWarp>, - NIterPerWarp>; - - // C += A * B - template - CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, - const ALdsTile& a_warp_tile_, - const BLdsTiles& b_warp_tiles_, - const ScaleATensor& scale_a_tensor, - const ScaleBTensor& scale_b_tensor) const - { - // checks - static_assert(std::is_same_v>, - "CDataType must be same as CBlockTensor::DataType!"); - static_assert( - std::is_same_v, - remove_cvref_t>, - "C distribution is wrong!"); - - // Effective XdlPack: fall back to 1 when iteration count is insufficient - constexpr index_t MXdlPack = - (MIterPerWarp >= MXdlPack_ && MIterPerWarp % MXdlPack_ == 0) ? MXdlPack_ : 1; - constexpr index_t NXdlPack = - (NIterPerWarp >= NXdlPack_ && NIterPerWarp % NXdlPack_ == 0) ? NXdlPack_ : 1; - constexpr index_t KXdlPack = - (KIterPerWarp >= KXdlPack_ && KIterPerWarp % KXdlPack_ == 0) ? KXdlPack_ : 1; - - constexpr index_t MPackIterPerWarp = MIterPerWarp / MXdlPack; - constexpr index_t NPackIterPerWarp = NIterPerWarp / NXdlPack; - constexpr index_t KPackIterPerWarp = KIterPerWarp / KXdlPack; - - // hot loop: - static_for_product, - number, - number>{}([&](auto ikpack, auto inpack, auto impack) { - // get A scale for this M-K tile using get_y_sliced_thread_data - auto scale_a_slice = scale_a_tensor.get_y_sliced_thread_data( - sequence{}, sequence<1, 1, 1>{}); - const int32_t a_scale_packed = bit_cast(scale_a_slice[number<0>{}]); - - // get B scale for this N-K tile using get_y_sliced_thread_data - auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data( - sequence{}, sequence<1, 1, 1>{}); - const int32_t b_scale_packed = bit_cast(scale_b_slice[number<0>{}]); - - // Inner loops: issue MFMAs within the pack group using OpSel - static_for_product, number, number>{}( - [&](auto ikxdl, auto inxdl, auto imxdl) { - constexpr auto kIter = ikpack * KXdlPack + ikxdl; - constexpr auto mIter = impack * MXdlPack + imxdl; - constexpr auto nIter = inpack * NXdlPack + inxdl; - - // OpSel for A: selects byte within packed int32_t - constexpr index_t kOpSelA = ikxdl * MXdlPack + imxdl; - - // OpSel for B: selects byte within packed int32_t - constexpr index_t kOpSelB = ikxdl * NXdlPack + inxdl; - - // read A warp tensor from A Block window - AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - - // read B warp tensor from B block tensor - BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = - b_warp_tiles_[number{}][number{}].get_thread_buffer(); - - // read C warp tensor from C block tensor - using c_iter_idx = sequence; - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM with MX scaling - WarpGemm{}.template operator(), OpSelB>( - c_warp_tensor, - a_warp_tensor, - b_warp_tensor, - a_scale_packed, - b_scale_packed); - - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - }); - } -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_v1.hpp b/include/ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_v1.hpp deleted file mode 100644 index 7e190dc8e1..0000000000 --- a/include/ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_v1.hpp +++ /dev/null @@ -1,324 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp" - -namespace ck_tile { - -// A is block distributed tensor -// B is block distributed tensor -// C is block distributed tensor -template -struct BlockMXGemmARegBRegCRegV1 -{ - private: - template - struct GemmTraits_ - { - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - - static constexpr index_t kBlockSize = Problem::kBlockSize; - - static constexpr index_t MPerBlock = BlockGemmShape::kM; - static constexpr index_t NPerBlock = BlockGemmShape::kN; - static constexpr index_t KPerBlock = BlockGemmShape::kK; - - static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); - using WarpGemm = remove_cvref_t())>; - - static constexpr index_t MWarp = config.template at<1>(); - static constexpr index_t NWarp = config.template at<2>(); - static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); - static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); - static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; - - static constexpr index_t KPackA = WarpGemm::kAKPack; - static constexpr index_t KPackB = WarpGemm::kBKPack; - }; - - public: - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - static constexpr bool TransposeC = TransposeC_; - - using Traits = GemmTraits_; - - using WarpGemm = typename Traits::WarpGemm; - using BlockGemmShape = typename Traits::BlockGemmShape; - - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - - static constexpr index_t KIterPerWarp = Traits::KIterPerWarp; - static constexpr index_t MIterPerWarp = Traits::MIterPerWarp; - static constexpr index_t NIterPerWarp = Traits::NIterPerWarp; - - static constexpr index_t MWarp = Traits::MWarp; - static constexpr index_t NWarp = Traits::NWarp; - static constexpr bool UseDefaultScheduler = (Problem::NumWaveGroups != 1); - - // Note: distribution encodings have MIterPerWarp and NIterPerWarp contiguous because of scale - // packing. - - CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() - { - if constexpr(UseDefaultScheduler) - { - constexpr auto a_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple<>, - tuple<>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); - - return a_block_dstr_encode; - } - else - { - constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<1, 0>>{}; - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); - - return a_block_dstr_encode; - } - } - - CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() - { - if constexpr(UseDefaultScheduler) - { - constexpr auto b_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple<>, - tuple<>, - sequence<1, 2>, - sequence<0, 0>>{}; - constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - - return b_block_dstr_encode; - } - else - { - constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<1, 0>>{}; - constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - - return b_block_dstr_encode; - } - } - - CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode() - { - using c_distr_ys_major = std::conditional_t, sequence<1, 2>>; - if constexpr(UseDefaultScheduler) - { - using c_distr_ys_minor = std::conditional_t, sequence<0, 1>>; - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - c_distr_ys_major, - c_distr_ys_minor>{}; - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - - return c_block_dstr_encode; - } - else - { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - c_distr_ys_major, - sequence<1, 1>>{}; - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - - return c_block_dstr_encode; - } - } - - // C += A * B with MX scaling and packed-in-two (XdlPack) optimization - // Scale tensors contain pre-packed int32_t: each int32_t holds MXdlPack * KXdlPack e8m0_t - // values (for A) or NXdlPack * KXdlPack (for B), packed on the host. - // Uses OpSel (0-3) to select which byte within the packed int32_t for each MFMA call. - // XdlPack template parameters default to 2; fall back to 1 when iteration count is too small. - template - CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, - const ABlockTensor& a_block_tensor, - const BBlockTensor& b_block_tensor, - const ScaleATensor& scale_a_tensor, - const ScaleBTensor& scale_b_tensor) const - { - static_assert(std::is_same_v> && - std::is_same_v> && - std::is_same_v>, - "Datatypes do not match BlockTensor datatypes!"); - - // check ABC-block-distribution - static_assert( - std::is_same_v, - remove_cvref_t>, - "A distribution is wrong!"); - static_assert( - std::is_same_v, - remove_cvref_t>, - "B distribution is wrong!"); - static_assert( - std::is_same_v, - remove_cvref_t>, - "C distribution is wrong!"); - - using AWarpDstr = typename WarpGemm::AWarpDstr; - using BWarpDstr = typename WarpGemm::BWarpDstr; - using CWarpDstr = typename WarpGemm::CWarpDstr; - - using AWarpTensor = typename WarpGemm::AWarpTensor; - using BWarpTensor = typename WarpGemm::BWarpTensor; - using CWarpTensor = typename WarpGemm::CWarpTensor; - - constexpr auto a_warp_y_lengths = - to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto b_warp_y_lengths = - to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto c_warp_y_lengths = - to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - - constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; - constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; - constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - - // Effective XdlPack: fall back to 1 when iteration count is insufficient - constexpr index_t MXdlPack = - (MIterPerWarp >= MXdlPack_ && MIterPerWarp % MXdlPack_ == 0) ? MXdlPack_ : 1; - constexpr index_t NXdlPack = - (NIterPerWarp >= NXdlPack_ && NIterPerWarp % NXdlPack_ == 0) ? NXdlPack_ : 1; - constexpr index_t KXdlPack = - (KIterPerWarp >= KXdlPack_ && KIterPerWarp % KXdlPack_ == 0) ? KXdlPack_ : 1; - - constexpr index_t MPackIterPerWarp = MIterPerWarp / MXdlPack; - constexpr index_t NPackIterPerWarp = NIterPerWarp / NXdlPack; - constexpr index_t KPackIterPerWarp = KIterPerWarp / KXdlPack; - - // hot loop with MX scaling and pre-packed int32_t scales: - // Outer loops iterate over pack groups (scale tile indices) - static_ford>{}([&](auto ii) { - constexpr auto ikpack = number{}]>{}; - constexpr auto impack = number{}]>{}; - // Get pre-packed int32_t A scale (already contains MXdlPack*KXdlPack e8m0_t) - auto scale_a_slice = scale_a_tensor.get_y_sliced_thread_data( - sequence{}, sequence<1, 1, 1>{}); - const int32_t a_scale_packed = bit_cast(scale_a_slice[number<0>{}]); - - static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) { - // Get pre-packed int32_t B scale - auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data( - sequence{}, sequence<1, 1, 1>{}); - const int32_t b_scale_packed = bit_cast(scale_b_slice[number<0>{}]); - - // Inner loops: issue MFMAs within the pack group using OpSel - static_ford>{}([&](auto jj) { - constexpr auto ikxdl = number{}]>{}; - constexpr auto imxdl = number{}]>{}; - constexpr auto kIter = ikpack * KXdlPack + ikxdl; - constexpr auto mIter = impack * MXdlPack + imxdl; - - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - - // OpSel for A: selects byte within packed int32_t - constexpr index_t kOpSelA = ikxdl * MXdlPack + imxdl; - - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto nIter = inpack * NXdlPack + inxdl; - - // read B warp tensor from B block tensor - BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - - // OpSel for B: selects byte within packed int32_t - constexpr index_t kOpSelB = ikxdl * NXdlPack + inxdl; - - // read C warp tensor from C block tensor - using c_iter_idx = std::conditional_t, - sequence>; - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM with MX scaling using pre-packed scale and OpSel - WarpGemm{}.template operator(), OpSelB>( - c_warp_tensor, - a_warp_tensor, - b_warp_tensor, - a_scale_packed, - b_scale_packed); - - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - }); - }); - }); - } - - CK_TILE_DEVICE static constexpr auto MakeCBlockTile() - { - return make_static_distributed_tensor( - make_static_tile_distribution(MakeCBlockDistributionEncode())); - } -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp b/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp deleted file mode 100644 index edeb1a5214..0000000000 --- a/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp +++ /dev/null @@ -1,863 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include -#include - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/common.hpp" -#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" -#include "ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp" - -namespace ck_tile { - -template -struct MXGemmPipelineAgBgCrCompAsyncEightWaves; - -namespace detail { -template -struct MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy; - -template -struct MXGemmKernelScaleTraits -{ - static constexpr index_t ScaleGranularityK = Pipeline::ScaleGranularityK; - static constexpr index_t MXdlPack = Pipeline::MXdlPack; - static constexpr index_t NXdlPack = Pipeline::NXdlPack; - static constexpr index_t KXdlPack = Pipeline::KXdlPack; -}; - -template -struct MXGemmKernelScaleTraits> -{ - using PolicyTraits = MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy; - - static constexpr index_t ScaleGranularityK = PolicyTraits::BlockScaleSize; - static constexpr index_t MXdlPack = PolicyTraits::MXdlPack; - static constexpr index_t NXdlPack = PolicyTraits::NXdlPack; - static constexpr index_t KXdlPack = PolicyTraits::KXdlPack; -}; -} // namespace detail - -template , - typename ScaleN = MXScalePointer, - index_t NumATensor = 1, - index_t NumBTensor = 1, - index_t NumDTensor = 0> -struct MXGemmKernelArgs : UniversalGemmKernelArgs -{ - using Base = UniversalGemmKernelArgs; - - CK_TILE_HOST MXGemmKernelArgs(const std::array& as_ptr_, - const std::array& bs_ptr_, - const std::array& ds_ptr_, - void* e_ptr_, - index_t k_batch_, - index_t M_, - index_t N_, - index_t K_, - const std::array& stride_As_, - const std::array& stride_Bs_, - const std::array& stride_Ds_, - index_t stride_E_, - ScaleM scale_m_ptr_, - ScaleN scale_n_ptr_) - : Base{as_ptr_, - bs_ptr_, - ds_ptr_, - e_ptr_, - M_, - N_, - K_, - stride_As_, - stride_Bs_, - stride_Ds_, - stride_E_, - k_batch_}, - scale_m_ptr(scale_m_ptr_), - scale_n_ptr(scale_n_ptr_) - { - } - - ScaleM scale_m_ptr; - ScaleN scale_n_ptr; -}; - -template -struct MXGemmKernel : UniversalGemmKernel -{ - using Underlying = UniversalGemmKernel; - - using TilePartitioner = remove_cvref_t; - using MXGemmPipeline = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - using EpiloguePipeline = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using ELayout = remove_cvref_t; - using DsLayout = remove_cvref_t; - using DsDataType = remove_cvref_t; - static constexpr index_t KernelBlockSize = MXGemmPipeline::BlockSize; - static constexpr bool UsePersistentKernel = MXGemmPipeline::UsePersistentKernel; - - // Below type is actually accumulation data type - the output of block GEMM. - using EDataType = remove_cvref_t; - - static constexpr auto I0 = number<0>(); - static constexpr auto I1 = number<1>(); - static constexpr auto I2 = number<2>(); - static constexpr auto I3 = number<3>(); - static constexpr auto I4 = number<4>(); - static constexpr auto I5 = number<5>(); - - static constexpr index_t NumATensor = Underlying::AsDataType::size(); - static constexpr index_t NumBTensor = Underlying::BsDataType::size(); - static constexpr index_t NumDTensor = Underlying::DsDataType::size(); - - using ADataType = remove_cvref_t>; - using BDataType = remove_cvref_t>; - - static constexpr auto MThreadPerXdl = BlockGemmShape::WarpTile::at(number<0>{}); - static constexpr auto NThreadPerXdl = BlockGemmShape::WarpTile::at(number<1>{}); - static constexpr auto KThreadPerXdl = 64 / MThreadPerXdl; - - static constexpr auto APackedSize = numeric_traits::PackedSize; - static constexpr auto BPackedSize = numeric_traits::PackedSize; - - // XdlPack: desired packing of e8m0_t scale values into int32_t - using ScaleTraits = detail::MXGemmKernelScaleTraits; - static constexpr index_t ScaleGranularityK = ScaleTraits::ScaleGranularityK; - static constexpr index_t MXdlPack = ScaleTraits::MXdlPack; - static constexpr index_t NXdlPack = ScaleTraits::NXdlPack; - static constexpr index_t KXdlPack = ScaleTraits::KXdlPack; - - // Effective pack sizes: fall back to 1 when dimension is too small - using BlockWarps_ = typename BlockGemmShape::BlockWarps; - static constexpr index_t MPerBlock_ = BlockGemmShape::kM; - static constexpr index_t NPerBlock_ = BlockGemmShape::kN; - static constexpr index_t KPerBlock_ = BlockGemmShape::kK; - static constexpr index_t MWarp_ = BlockWarps_::at(number<0>{}); - static constexpr index_t NWarp_ = BlockWarps_::at(number<1>{}); - static constexpr index_t KPerXdl_ = BlockGemmShape::WarpTile::at(number<2>{}); - static constexpr index_t MIterPerWarp_ = MPerBlock_ / (MWarp_ * MThreadPerXdl); - static constexpr index_t NIterPerWarp_ = NPerBlock_ / (NWarp_ * NThreadPerXdl); - static constexpr index_t KIterPerWarp_ = KPerBlock_ / KPerXdl_; - - static constexpr index_t MXdlPackEff = - (MIterPerWarp_ >= MXdlPack && MIterPerWarp_ % MXdlPack == 0) ? MXdlPack : 1; - static constexpr index_t NXdlPackEff = - (NIterPerWarp_ >= NXdlPack && NIterPerWarp_ % NXdlPack == 0) ? NXdlPack : 1; - static constexpr index_t KXdlPackEff = - (KIterPerWarp_ >= KXdlPack && KIterPerWarp_ % KXdlPack == 0) ? KXdlPack : 1; - - static constexpr int kBlockPerCu = 1; - - // Scale block size (same constant used by MXGemmPipeline): each e8m0 scale covers 32 K elements - static constexpr index_t ScaleBlockSize = 32; - - // Padding flags pulled from pipeline so the kernel can pad the (unscaled) C and scale views - // consistently with the A/B views that the pipeline already pads via - // Underlying::MakeA/BBlockWindows. - static constexpr bool kPadM = MXGemmPipeline::kPadM; - static constexpr bool kPadN = MXGemmPipeline::kPadN; - static constexpr bool kPadK = MXGemmPipeline::kPadK; - - static_assert(DsLayout::size() == DsDataType::size(), - "The size of DsLayout and DsDataType should be the same"); - - // ------------------------------------------------------------------ - // Compile-time padding-support invariants for the MX comp-async pipeline. - // - // - K padding is NOT supported: async_load_tile issues vector buffer reads whose - // OOB check is per-vector-start, so a vector that straddles the K pad boundary - // pulls in data from the adjacent row / next K tile rather than zero. The packed - // scale tile has the same vector-load property. Until the async path learns how - // to do per-element pad masking, we forbid kPadK at compile time. - // - // - kPadM / kPadN are supported only when the GEMM has at least one full block - // along that dimension; the CShuffleEpilogue's LDS shuffle uses thread positions - // that do not all participate when the entire dimension is smaller than a tile - // (resulting in zeros being written into in-range output rows). The "entire - // dimension < tile" case is rejected at runtime in IsSupportedArgument; we - // cannot statically catch it because M and N are runtime values. - // ------------------------------------------------------------------ - static_assert(!kPadK, - "MX GEMM (comp-async pipeline): K padding (kPadK = true) is not supported. " - "The async vector loads do not mask elements that straddle the K pad " - "boundary, so partial K tiles produce silently wrong results. Choose K so " - "that K is a multiple of KPerBlock * k_batch."); - - // Single source of truth for the split-K atomic-add precondition, shared by the runtime - // check in IsSupportedArgument and the atomic_add dispatch in operator(). Split-K - // accumulates each k_id's partial C tile with atomic_add; the CShuffle epilogue can only - // emit atomic_add for fp16/bf16 outputs when the C vector size is even. For an odd vector - // size that combination is not instantiated, so such a config cannot run split-K. For all - // shipped tile shapes GetVectorSizeC() is even, so this is defensive rather than reachable. - static constexpr bool kSplitKAtomicAddSupported = - EpiloguePipeline::GetVectorSizeC() % 2 == 0 || !is_any_of::value; - - [[nodiscard]] CK_TILE_HOST static const std::string GetName() - { - // clang-format off - return concat('_', "mx_gemm", gemm_prec_str, MXGemmPipeline::GetName()); - // clang-format on - } - - template - using KernelArgs = MXGemmKernelArgs; - - template - CK_TILE_HOST static auto MakeKernelArgs(const std::array& as_ptr, - const std::array& bs_ptr, - const std::array& ds_ptr, - void* e_ptr, - index_t k_batch, - index_t M, - index_t N, - index_t K, - const std::array& stride_As, - const std::array& stride_Bs, - const std::array& stride_Ds, - index_t stride_E, - ScaleM scale_m_ptr, - ScaleN scale_n_ptr) - { - return KernelArgs(as_ptr, - bs_ptr, - ds_ptr, - e_ptr, - k_batch, - M, - N, - K, - stride_As, - stride_Bs, - stride_Ds, - stride_E, - scale_m_ptr, - scale_n_ptr); - } - - template - CK_TILE_HOST static constexpr auto GridSize(const KernelArgs& kargs) - { - const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N); - - if constexpr(UsePersistentKernel) - { - hipDeviceProp_t prop; - int deviceId = 0; // default device - - int dync_smem_size = 0; - int maxActiveBlocksPerCU = 0; - - if(hipGetDeviceProperties(&prop, deviceId) != hipSuccess) - throw std::runtime_error(std::string("hipGetDeviceProperties failed: ") + - hipGetErrorName(hipGetLastError())); - - if(hipOccupancyMaxActiveBlocksPerMultiprocessor( - &maxActiveBlocksPerCU, - reinterpret_cast( - kentry<1, MXGemmKernel, remove_cvref_t>), - KernelBlockSize, - dync_smem_size) != hipSuccess) - throw std::runtime_error( - std::string("hipOccupancyMaxActiveBlocksPerMultiprocessor failed: ") + - hipGetErrorName(hipGetLastError())); - - const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU; - const int actual_grid_size = min(persistent_block_size, total_work_tile_cnt); - - // blockIdx.z selects the K split. For split-K, each k_id gets its own set of - // persistent blocks looping over the MxN tile space. - return dim3(actual_grid_size, 1, kargs.k_batch); - } - else - { - // Non-persistent: grid is (MxN tiles) x 1 x k_batch. blockIdx.z selects the K split. - return dim3(total_work_tile_cnt, 1, kargs.k_batch); - } - } - - template - CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs) - { - // Reject unsupported combinations early; the MX pipeline silently produces wrong - // results otherwise (OOB reads, partial-tile shuffle artifacts, mis-aligned splits). - // See the static_assert block at the top of MXGemmKernel for the rationale behind - // each constraint. - const bool log = ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)); - - if(kargs.k_batch < 1) - { - if(log) - CK_TILE_ERROR("MX GEMM: k_batch must be >= 1."); - return false; - } - - // Split-K needs the atomic_add epilogue; reject configs that cannot emit it (fp16/bf16 - // output with an odd C vector size) instead of silently skipping the accumulation. - if constexpr(!kSplitKAtomicAddSupported) - { - if(kargs.k_batch > 1) - { - if(log) - CK_TILE_ERROR("MX GEMM: split-K (k_batch > 1) requires an even C vector size " - "for fp16/bf16 outputs (atomic_add epilogue constraint)."); - return false; - } - } - - // Split-K derives this k_id's logical K start from the row-major SplitKBatchOffset - // (as_k_split_offset[0]) to offset the packed-scale / flat-B windows; for column-major A - // that field is stride-scaled, so split-K with non-row-major A is not yet supported. - // (k_batch == 1 is unaffected -- the offset is 0 and unused.) When col-major A lands for - // non-preshuffle, extend the split-K K-offset here instead of this reject. - if constexpr(!std::is_same_v) - { - if(kargs.k_batch > 1) - { - if(log) - CK_TILE_ERROR("MX GEMM: split-K (k_batch > 1) currently requires row-major A."); - return false; - } - } - - // Preshuffle split-K relies on each split starting on a K-block boundary that is also - // aligned to the host-preshuffled scale layout's packed-K granularity, so that the flat-B - // and scale windows start at the same logical K. Split boundaries are KPerBlock-aligned - // (enforced by the "K % (KPerBlock * k_batch)" check below); it therefore suffices that the - // preshuffled scale K-block granularity (ScaleGranularityK * KXdlPackEff * KThreadPerXdl) - // divides KPerBlock. - if constexpr(MXGemmPipeline::Preshuffle) - { - constexpr index_t preshuffle_scale_k_granularity = - ScaleGranularityK * KXdlPackEff * KThreadPerXdl; - if(kargs.k_batch > 1 && - (TilePartitioner::KPerBlock % preshuffle_scale_k_granularity != 0)) - { - if(log) - CK_TILE_ERROR("MX GEMM: preshuffle split-K requires KPerBlock to be a multiple " - "of ScaleGranularityK * KXdlPackEff * KThreadPerXdl."); - return false; - } - } - - // M / N must be a multiple of the block tile when padding is disabled. - if(!kPadM && (kargs.M % TilePartitioner::MPerBlock != 0)) - { - if(log) - CK_TILE_ERROR("MX GEMM: M must be a multiple of MPerBlock when kPadM is false. " - "Enable kPadM on the GEMM config to run this shape."); - return false; - } - if(!kPadN && (kargs.N % TilePartitioner::NPerBlock != 0)) - { - if(log) - CK_TILE_ERROR("MX GEMM: N must be a multiple of NPerBlock when kPadN is false. " - "Enable kPadN on the GEMM config to run this shape."); - return false; - } - - // CShuffleEpilogue cannot run with a single partial tile along M or N: the shuffle's - // LDS write/read pattern leaves some in-range output rows/cols at zero. Reject these - // pathological shapes whether or not kPadM/kPadN is enabled. - if(kargs.M < TilePartitioner::MPerBlock) - { - if(log) - CK_TILE_ERROR("MX GEMM: M must be >= MPerBlock. Partial-only M tiles are not " - "supported by the MX CShuffleEpilogue."); - return false; - } - if(kargs.N < TilePartitioner::NPerBlock) - { - if(log) - CK_TILE_ERROR("MX GEMM: N must be >= NPerBlock. Partial-only N tiles are not " - "supported by the MX CShuffleEpilogue."); - return false; - } - - // K padding is unconditionally rejected (kPadK is also a compile-time error -- see the - // static_assert at the top of MXGemmKernel). Every split must consume an exact number - // of K tiles, otherwise the async vector loads read garbage past the K boundary. - const index_t k_tile = TilePartitioner::KPerBlock; - if(kargs.K % (k_tile * kargs.k_batch) != 0) - { - if(log) - CK_TILE_ERROR( - "MX GEMM: K must be a multiple of KPerBlock * k_batch. The MX comp-async " - "pipeline does not currently support K padding (vector loads across the K " - "pad boundary read garbage); pick aligned K dimensions or change k_batch."); - return false; - } - - // Scales are granular in K: each packed int32_t covers ScaleBlockSize * KXdlPackEff - // consecutive K elements. Every split-K boundary must land on that granularity so that - // each split can compute a packed-scale K offset. K1 is the WarpTile K, which is a - // multiple of that granularity for all shipped configs, but be defensive. - constexpr index_t scale_granularity_k = ScaleBlockSize * KXdlPackEff; - if(kargs.k_batch > 1) - { - // splitk_batch_offset allocates K in units of K1 (warp-tile K). If K1 itself is - // not a multiple of the scale granularity, split-K is not safe. - constexpr index_t K1 = BlockGemmShape::WarpTile::at(number<2>{}); - static_assert(K1 % scale_granularity_k == 0, - "MX GEMM: WarpTile K must be a multiple of ScaleBlockSize * KXdlPack " - "to support split-K."); - // Defensive runtime check: K must split evenly along K1 boundaries so that each - // k_id consumes a whole number of warp-tile K chunks (and therefore a whole - // number of packed-scale K elements). - if(kargs.K % (K1 * kargs.k_batch) != 0) - { - if(log) - CK_TILE_ERROR("MX GEMM: with k_batch > 1, K must be a multiple of WarpTile_K * " - "k_batch so that every split lands on a packed-scale boundary."); - return false; - } - } - - // Delegate the remaining shape/vector-size checks to the universal kernel. All MX - // pipelines (comp-async, eight-waves, preshuffle) expose the templated - // GetVectorSize{A,B}() that UniversalGemmKernel::IsSupportedArgument requires. - return Underlying::IsSupportedArgument( - static_cast(kargs)); - } - - using SplitKBatchOffset = typename Underlying::SplitKBatchOffset; - - // Create C block window following UniversalGemmKernel pattern - template - CK_TILE_DEVICE static auto MakeCBlockWindows(EDataType* e_ptr, - const KernelArgs& kargs, - const index_t i_m, - const index_t i_n) - { - // Create tensor view for E/C tensor - constexpr index_t vector_size = EpiloguePipeline::GetVectorSizeC(); - const auto& e_tensor_view = [&]() -> auto { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - e_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.stride_E, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - e_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(1, kargs.stride_E), - number<1>{}, - number{}); - } - }(); - - // Pad both dims so OOB C writes (including partial trailing tiles where M < MPerBlock - // or N < NPerBlock) are masked by the pad transform. - const auto& e_pad_view = pad_tensor_view( - e_tensor_view, - make_tuple(number{}, number{}), - sequence{}); - - // Create block window - auto c_block_window = make_tile_window( - e_pad_view, - make_tuple(number{}, number{}), - {i_m, i_n}); - - return c_block_window; - } - - // Create scale A block windows with packed int32_t layout. - // Host packs (MXdlPack x KXdlPack) e8m0_t values into a single int32_t, producing a - // packed tensor of shape [M/MXdlPackEff, K/ScaleBlockSize/KXdlPackEff]. - // - // k_elem_offset: starting K element index for this block (0 unless split-K). - // Must be a multiple of ScaleBlockSize * KXdlPackEff. - template - CK_TILE_DEVICE static auto MakeScaleABlockWindows(const KernelArgs& kargs, - const index_t i_m, - const index_t k_elem_offset = 0) - { - auto scale_a = kargs.scale_m_ptr; - static_assert(ScaleM::GranularityK == ScaleGranularityK); - if constexpr(MXGemmPipeline::Preshuffle) - { - const auto scale_packs_m = integer_divide_ceil(kargs.M, (MXdlPackEff * MThreadPerXdl)); - const auto scale_packs_k = kargs.K / ScaleGranularityK / (KXdlPackEff * KThreadPerXdl); - - const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed( - make_tuple(scale_packs_m, scale_packs_k, KThreadPerXdl, MThreadPerXdl)); - const auto scale_a_desc = transform_tensor_descriptor( - scale_a_naive_desc, - make_tuple(make_merge_transform(make_tuple(scale_packs_m, MThreadPerXdl)), - make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))), - make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - const auto scale_a_tensor_view = make_tensor_view( - reinterpret_cast(scale_a.ptr), scale_a_desc); - - // For split-K (k_batch > 1) advance the scale origin into this k_id's packed-K slice. - // The merged-K axis of the preshuffled scale view, merge(scale_packs_k, KThreadPerXdl), - // has the same total extent K/(ScaleGranularityK*KXdlPackEff) as the non-preshuffle - // layout, and split boundaries are KPerBlock-aligned (see IsSupportedArgument), so the - // K-block offset is the same closed form used by the non-preshuffle branch. - const index_t k_scale_offset = k_elem_offset / ScaleGranularityK / KXdlPackEff; - return make_tile_window( - scale_a_tensor_view, - make_tuple( - number{}, - number{}), - {i_m / MXdlPackEff, k_scale_offset}); - } - else - { - const auto scale_k_packed = kargs.K / ScaleGranularityK / KXdlPackEff; - const auto scale_m_packed = kargs.M / MXdlPackEff; - - // A scale tensor view - layout [M/MXdlPackEff, K/32/KXdlPackEff] with int32_t elements - const auto scale_a_tensor_view = make_naive_tensor_view( - reinterpret_cast(scale_a.ptr), - make_tuple(scale_m_packed, scale_k_packed), - make_tuple(scale_k_packed, 1)); - - // Pad the scale view so partial trailing tiles along M are handled safely (OOB scale - // loads return zero; with A also zero on the padded region the contribution is zero - // regardless of scale value). kPadK is statically disabled, so K never actually pads. - const auto scale_a_pad_view = pad_tensor_view( - scale_a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - - // For split-K (k_batch > 1) advance the scale origin into this k_id's packed-K slice. - const index_t k_scale_offset = k_elem_offset / ScaleGranularityK / KXdlPackEff; - - // Tile window shape: [MPerBlock/MXdlPackEff, KPerBlock/32/KXdlPackEff] - return make_tile_window( - scale_a_pad_view, - make_tuple(number{}, - number{}), - {i_m / MXdlPackEff, k_scale_offset}); - } - } - - template - CK_TILE_DEVICE static auto - MakeBFlatBlockWindows(const std::array& bs_ptr, - const KernelArgs& kargs, - const index_t i_n, - const index_t k_elem_offset = 0) - { - static_assert(NumBTensor == 1, "MX GEMM preshuffle currently supports one B tensor"); - - constexpr index_t kKPerBlock = MXGemmPipeline::kKPerBlock; - constexpr index_t kNWarpTile = BlockGemmShape::WarpTile::at(I1); - constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile; - const index_t kFlatKBlocks = kargs.K / kKPerBlock; - const index_t kFlatN = kargs.N / kNWarpTile; - - // For split-K (k_batch > 1) advance the flat-B K origin into this k_id's K slice. The - // flat layout stores K as kFlatKBlocks blocks of flatKPerBlock elements each, and split - // boundaries are KPerBlock-aligned (enforced in IsSupportedArgument), so the offset lands - // on a clean K-block boundary. The universal bs_k_split_offset is not used here: it is - // derived from the logical B stride and does not match the preshuffled flat layout. - const index_t k_flat_offset = (k_elem_offset / kKPerBlock) * flatKPerBlock; - - auto b_flat_tensor_view = [&]() { - static_assert(flatKPerBlock % MXGemmPipeline::GetVectorSizeB() == 0, - "wrong! vector size for preshuffled B tensor"); - auto naive_desc = make_naive_tensor_descriptor_packed( - make_tuple(kFlatN, kFlatKBlocks, number{})); - auto desc = transform_tensor_descriptor( - naive_desc, - make_tuple(make_pass_through_transform(kFlatN), - make_merge_transform_v3_division_mod( - make_tuple(kFlatKBlocks, number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return make_tensor_view(bs_ptr[number<0>{}], desc); - }(); - - return generate_tuple( - [&](auto) { - return make_tile_window(b_flat_tensor_view, - make_tuple(number{}, - number{}), - {static_cast(i_n / BlockGemmShape::WarpTile::at(I1)), - static_cast(k_flat_offset)}); - }, - number{}); - } - - template - CK_TILE_DEVICE static auto MakeScaleBBlockWindows(const KernelArgs& kargs, - const index_t i_n, - const index_t k_elem_offset = 0) - { - auto scale_b = kargs.scale_n_ptr; - static_assert(ScaleN::GranularityK == ScaleGranularityK); - - if constexpr(MXGemmPipeline::Preshuffle) - { - const auto scale_packs_n = integer_divide_ceil(kargs.N, (NXdlPackEff * NThreadPerXdl)); - const auto scale_packs_k = kargs.K / ScaleGranularityK / (KXdlPackEff * KThreadPerXdl); - - const auto scale_b_naive_desc = make_naive_tensor_descriptor_packed( - make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl)); - const auto scale_b_desc = transform_tensor_descriptor( - scale_b_naive_desc, - make_tuple(make_merge_transform(make_tuple(scale_packs_n, NThreadPerXdl)), - make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))), - make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - const auto scale_b_tensor_view = make_tensor_view( - reinterpret_cast(scale_b.ptr), scale_b_desc); - - // For split-K (k_batch > 1) advance the scale origin into this k_id's packed-K slice. - // The merged-K axis of the preshuffled scale view, merge(scale_packs_k, KThreadPerXdl), - // has the same total extent K/(ScaleGranularityK*KXdlPackEff) as the non-preshuffle - // layout, and split boundaries are KPerBlock-aligned (see IsSupportedArgument), so the - // K-block offset is the same closed form used by the non-preshuffle branch. - const index_t k_scale_offset = k_elem_offset / ScaleGranularityK / KXdlPackEff; - return make_tile_window( - scale_b_tensor_view, - make_tuple( - number{}, - number{}), - {i_n / NXdlPackEff, k_scale_offset}); - } - else - { - const auto scale_k_packed = kargs.K / ScaleGranularityK / KXdlPackEff; - const auto scale_n_packed = kargs.N / NXdlPackEff; - - // B scale tensor view - [N/NXdlPackEff, K/32/KXdlPackEff] of int32_t - const auto scale_b_tensor_view = make_naive_tensor_view( - reinterpret_cast(scale_b.ptr), - make_tuple(scale_n_packed, scale_k_packed), - make_tuple(scale_k_packed, 1)); - - // Pad the scale view so partial trailing tiles along N are handled safely (OOB scale - // loads return zero; with B also zero on the padded region the contribution is zero - // regardless of scale value). kPadK is statically disabled, so K never actually pads. - const auto scale_b_pad_view = pad_tensor_view( - scale_b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - - // For split-K (k_batch > 1) advance the scale origin into this k_id's packed-K slice. - const index_t k_scale_offset = k_elem_offset / ScaleGranularityK / KXdlPackEff; - - // Tile window shape: [NPerBlock/NXdlPackEff, KPerBlock/32/KXdlPackEff] - return make_tile_window( - scale_b_pad_view, - make_tuple(number{}, - number{}), - {i_n / NXdlPackEff, k_scale_offset}); - } - } - - template - CK_TILE_DEVICE static void RunMxGemm(const std::array& as_ptr, - const std::array& bs_ptr, - const std::array& ds_ptr, - EDataType* e_ptr, - void* smem_ptr, - const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset, - const index_t i_m, - const index_t i_n, - const index_t k_elem_offset = 0) - { - // Create block windows directly, following the new pattern from UniversalGemmKernel - // i_m and i_n are element offsets (iM * MPerBlock, iN * NPerBlock), not tile indices - const auto& a_block_window = [&]() { - if constexpr(MXGemmPipeline::Preshuffle) - { - // The preshuffle A async-load (MakeMX_AAsyncLoadBytesDramWindow) rebuilds the A - // view with a *packed* descriptor, i.e. it assumes the leading (M) stride equals - // the view's K extent. That only holds when the extent equals stride_A, which is - // the case for k_batch == 1 (splitted_k == K) but NOT for split-K (splitted_k < K): - // a packed extent of splitted_k would stride M by splitted_k instead of stride_A - // and read the wrong rows (only row 0 lands correctly). Use the full K extent so - // the packed M stride matches stride_A. The as_ptr K-offset already selects this - // k_id's slice and num_loop bounds the blocks read, so reads stay within - // [as_k_split_offset, as_k_split_offset + splitted_k) <= K (in-allocation). - return Underlying::MakeABlockWindows(as_ptr, kargs, kargs.K, i_m); - } - else - { - return Underlying::MakeABlockWindows( - as_ptr, kargs, splitk_batch_offset.splitted_k, i_m); - } - }(); - const auto& b_block_window = [&]() { - if constexpr(MXGemmPipeline::Preshuffle) - { - return MakeBFlatBlockWindows(bs_ptr, kargs, i_n, k_elem_offset); - } - else - { - return Underlying::MakeBBlockWindows( - bs_ptr, kargs, splitk_batch_offset.splitted_k, i_n); - } - }(); - const auto& d_block_window = Underlying::MakeDBlockWindows(ds_ptr, kargs, i_m, i_n); - - // Create scale block windows. For split-K (k_batch > 1), k_elem_offset advances the - // scale origin into the correct packed-K slice for this k_id; otherwise it is zero. - const auto& scale_a_block_window = MakeScaleABlockWindows(kargs, i_m, k_elem_offset); - const auto& scale_b_block_window = MakeScaleBBlockWindows(kargs, i_n, k_elem_offset); - - const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); - - static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK - || ScaleM::GranularityMN == -1 // or ScaleA is disable - || ScaleN::GranularityMN == -1, // or ScaleB is disable - "ScaleM and ScaleN should have the same GranularityK"); - - const auto& c_block_tile = [&]() { - if constexpr(MXGemmPipeline::Preshuffle) - { - constexpr index_t smem_ping_pong_size = MXGemmPipeline::GetSmemSize() / 2; - return MXGemmPipeline{}(a_block_window[number<0>{}], - b_block_window[number<0>{}], - scale_a_block_window, - scale_b_block_window, - num_loop, - smem_ptr, - static_cast(smem_ptr) + smem_ping_pong_size); - } - else - { - return MXGemmPipeline{}(a_block_window[number<0>{}], - b_block_window[number<0>{}], - scale_a_block_window, - scale_b_block_window, - num_loop, - smem_ptr); - } - }(); - - // Run Epilogue Pipeline - create C block window with the requested memory op (set for - // k_batch == 1, atomic_add for split-K so partial results accumulate into the same tile). - auto c_block_window = MakeCBlockWindows(e_ptr, kargs, i_m, i_n); - EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr); - } - - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() - { - return max(MXGemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); - } - - template - CK_TILE_DEVICE void operator()(KernelArgs kargs, - int partition_idx = get_block_id()) const - { -#if !defined(__gfx950__) - static_assert(sizeof(MXGemmPipeline) == 0, "CKTile MX GEMM kernels require gfx950."); - ignore = kargs; - ignore = partition_idx; -#else - const int total_work_tile_cnt = - amd_wave_read_first_lane(TilePartitioner::GridSize(kargs.M, kargs.N)); - - // Allocate shared memory for ping pong buffers - __shared__ char smem_ptr[GetSmemSize()]; - - // Support both persistent and non-persistent modes - do - { - const auto [iM, iN] = - TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx); - const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); - const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); - - // SplitKBatchOffset defaults its k_id to blockIdx.z, selecting this split's K slice. - const SplitKBatchOffset splitk_batch_offset( - static_cast(kargs)); - - // This k_id's logical K-element start. For row-major A, as_k_split_offset[0] is exactly - // that offset, so reuse it rather than recomputing the split formula; the packed-scale - // and flat-B K offsets are derived from it. Split-K with non-row-major A is rejected in - // IsSupportedArgument; for k_batch == 1 this value is 0 and unused for any layout. - const index_t k_elem_offset = - amd_wave_read_first_lane(splitk_batch_offset.as_k_split_offset[I0]); - - EDataType* e_ptr = static_cast(kargs.e_ptr); - - std::array as_ptr; - static_for<0, NumATensor, 1>{}([&](auto i) { - as_ptr[i] = static_cast(kargs.as_ptr[i]) + - splitk_batch_offset.as_k_split_offset[i] / APackedSize; - }); - - std::array bs_ptr; - static_for<0, NumBTensor, 1>{}([&](auto i) { - if constexpr(MXGemmPipeline::Preshuffle) - { - // The preshuffle (flat-B) path applies the per-split K offset to the flat - // window origin in MakeBFlatBlockWindows; bs_k_split_offset is derived from - // the logical B stride and would mis-offset the flat buffer. - bs_ptr[i] = static_cast(kargs.bs_ptr[i]); - } - else - { - bs_ptr[i] = static_cast(kargs.bs_ptr[i]) + - splitk_batch_offset.bs_k_split_offset[i] / BPackedSize; - } - }); - - // Dispatch epilogue: when k_batch > 1 each split accumulates a partial result into - // the same C tile, so we need atomic add (universal_gemm_kernel pattern). The - // fp16/bf16 even-vector-size precondition is captured once in kSplitKAtomicAddSupported - // and also rejected up front in IsSupportedArgument. - if(kargs.k_batch == 1) - { - RunMxGemm(as_ptr, - bs_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr, - kargs, - splitk_batch_offset, - i_m, - i_n, - /*k_elem_offset=*/0); - } - else - { - if constexpr(kSplitKAtomicAddSupported) - { - RunMxGemm(as_ptr, - bs_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr, - kargs, - splitk_batch_offset, - i_m, - i_n, - k_elem_offset); - } - } - partition_idx += gridDim.x; - } while(UsePersistentKernel && partition_idx < total_work_tile_cnt); -#endif - } -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp b/include/ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp deleted file mode 100644 index 0214f26a9c..0000000000 --- a/include/ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include "ck_tile/core.hpp" - -#if __clang_major__ >= 23 -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" -#endif -namespace ck_tile { - -template -struct MXScalePointer -{ - static constexpr int GranularityMN = SharedGranularityMN; - static constexpr int GranularityK = SharedGranularityK; - - static_assert(GranularityK != 0, - "GranularityK cannot be zero in primary template; " - "use the partial specialization for GranularityK == 0"); - - const ScaleType* ptr; - - CK_TILE_HOST_DEVICE MXScalePointer() = default; - CK_TILE_HOST_DEVICE MXScalePointer(const ScaleType* ptr_) : ptr(ptr_) {} - CK_TILE_HOST_DEVICE MXScalePointer(const ScaleType* ptr_, [[maybe_unused]] index_t length_) - : ptr(ptr_) - { - } - - CK_TILE_HOST_DEVICE MXScalePointer operator+(index_t offset) const - { - MXScalePointer ret; - if constexpr(GranularityMN == 0) - { - ret.ptr = ptr + offset / GranularityK; - } - else - { - ret.ptr = ptr + offset / GranularityMN / GranularityK; - } - return ret; - } - - CK_TILE_HOST_DEVICE ScaleType operator[](index_t i) const = delete; -}; - -template -struct MXScalePointer -{ - static constexpr int GranularityMN = SharedGranularityMN; - static constexpr int GranularityK = 0; - - static_assert(GranularityMN != 0); - - const ScaleType* ptr; - index_t length; - - CK_TILE_HOST_DEVICE MXScalePointer() = default; - CK_TILE_HOST_DEVICE MXScalePointer(const ScaleType* ptr_) : ptr(ptr_), length(1) {} - CK_TILE_HOST_DEVICE MXScalePointer(const ScaleType* ptr_, index_t length_) - : ptr(ptr_), length(length_) - { - } - - CK_TILE_HOST_DEVICE MXScalePointer operator+(index_t offset) const - { - MXScalePointer ret; - if constexpr(GranularityMN == 1) - { - ret.ptr = ptr + offset; - ret.length = length - offset; - } - else - { - ret.ptr = ptr + offset / GranularityMN; - ret.length = length - offset / GranularityMN; - } - return ret; - } - - CK_TILE_HOST_DEVICE ScaleType operator[](index_t i) const - { - // with additional oob check - if constexpr(GranularityMN == 1) - return i < length ? ptr[i] : 0; - else - return i / GranularityMN < length ? ptr[i / GranularityMN] : 0; - } -}; - -// shared granularityMN = -1 means no scale -template -struct MXScalePointer -{ - static constexpr int GranularityMN = -1; - static constexpr int GranularityK = 0; - - const ScaleType* ptr = nullptr; - - CK_TILE_HOST_DEVICE constexpr MXScalePointer() = default; - CK_TILE_HOST_DEVICE constexpr MXScalePointer(const ScaleType*) {} - CK_TILE_HOST_DEVICE constexpr MXScalePointer(const ScaleType*, index_t) {} - - CK_TILE_HOST_DEVICE constexpr MXScalePointer operator+(index_t) const - { - return MXScalePointer{}; - } - CK_TILE_HOST_DEVICE constexpr ScaleType operator[](index_t) const - { - return 1; // alway return 1, it doesn't change the result - } -}; - -} // namespace ck_tile -#if __clang_major__ >= 23 -#pragma clang diagnostic pop -#endif diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp deleted file mode 100644 index 32a1afc3c8..0000000000 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ /dev/null @@ -1,782 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once -#include "ck_tile/core.hpp" -#include "ck_tile/core/arch/arch.hpp" -#include "ck_tile/core/tensor/load_tile.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" -#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp" - -namespace ck_tile { - -// A Tile Window: global memory -// B Tile Window: global memory -// C Distributed tensor: register -// MX scaling support with OpSel -template -struct BaseMXGemmPipelineAgBgCrCompAsync -{ - static constexpr index_t PrefetchStages = 2; - static constexpr index_t PrefillStages = 1; - static constexpr index_t GlobalBufferNum = 1; - - static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel; - - CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop) - { - // The prologue puts PrefetchStages + PrefillStages tiles in flight (2 LDS buffers + 1 - // register prefill) before the main loop, so the loop only runs when there is work - // beyond them; otherwise the tail drains the in-flight tiles. - return num_loop > PrefetchStages + PrefillStages; - } - - CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) - { - if(num_loop == 1) - { - return TailNumber::One; - } - if(num_loop % PrefetchStages == 1) - { - return TailNumber::Three; - } - else - { - return TailNumber::Two; - } - } - - template - CK_TILE_HOST_DEVICE static auto - TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number) - { - // Handle all the valid cases. - if(has_hot_loop) - { - if(tail_number == TailNumber::Three) - { - return run_func(bool_constant{}, - integral_constant{}); - } - else if(tail_number == TailNumber::Two) - { - return run_func(bool_constant{}, - integral_constant{}); - } - } - else - { - if(tail_number == TailNumber::Three) - { - return run_func(bool_constant{}, - integral_constant{}); - } - else if(tail_number == TailNumber::Two) - { - return run_func(bool_constant{}, - integral_constant{}); - } - else - { - return (run_func(bool_constant{}, - integral_constant{})); - } - } - // If execution reaches here, it's an invalid tail_number because it wasn't handled above. -#if defined(__HIP_DEVICE_COMPILE__) - __builtin_unreachable(); -#else - throw std::logic_error( - "Invalid TailNumber: Only TailNumber::Three and TailNumber::Two are supported"); -#endif - } -}; - -/** - * @brief MX GEMM compute optimized pipeline version async; which is based on V4. - * - * This pipeline introduces asynchronous load from global memory to LDS, - * skipping the intermediate loading into pipeline registers. - * Supports MX scaling with e8m0 packed values and OpSel. - */ -template -struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync -{ - using Base = BaseMXGemmPipelineAgBgCrCompAsync; - using PipelineImplBase = GemmPipelineAgBgCrImplBase; - - using AsDataType = remove_cvref_t; - using BsDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - - using AsLayout = remove_cvref_t; - using BsLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - - using AElementWise = remove_cvref_t; - using BElementWise = remove_cvref_t; - - using ALayout = remove_cvref_t>; - using BLayout = remove_cvref_t>; - - using ADataType = remove_cvref_t>; - using BDataType = remove_cvref_t>; - - static_assert(!std::is_same_v, "Not implemented"); - - static constexpr index_t ScaleGranularityK = Policy::ScaleGranularityK; - static constexpr index_t MXdlPack = Policy::MXdlPack; - static constexpr index_t NXdlPack = Policy::NXdlPack; - static constexpr index_t KXdlPack = Policy::KXdlPack; - - static constexpr index_t APackedSize = - ck_tile::numeric_traits>::PackedSize; - static constexpr index_t BPackedSize = - ck_tile::numeric_traits>::PackedSize; - - using BlockGemm = remove_cvref_t())>; - using I0 = number<0>; - using I1 = number<1>; - using I2 = number<2>; - - static constexpr index_t BlockSize = Problem::kBlockSize; - - static constexpr index_t MPerBlock = BlockGemmShape::kM; - static constexpr index_t NPerBlock = BlockGemmShape::kN; - static constexpr index_t KPerBlock = BlockGemmShape::kK; - - template - static constexpr index_t GetVectorSizeA() - { - return Policy::template GetVectorSizeA(); - } - template - static constexpr index_t GetVectorSizeB() - { - return Policy::template GetVectorSizeB(); - } - static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } - - static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } - static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } - - static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; - static constexpr index_t Preshuffle = Problem::Preshuffle; - - static constexpr bool kPadM = Problem::kPadM; - static constexpr bool kPadN = Problem::kPadN; - static constexpr bool kPadK = Problem::kPadK; - - static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; - - static constexpr auto Scheduler = Problem::Scheduler; - - static constexpr auto is_a_load_tr_v = bool_constant{}; - static constexpr auto is_b_load_tr_v = bool_constant{}; - -#if defined(__gfx950__) - static_assert(!(std::is_same_v && - !PipelineImplBase::is_a_load_tr), - "A=ColumnMajor requires transpose load (ds_read_tr), but it is disabled for " - "this K warp tile size. Use a smaller K warp tile (e.g. 32x32x64 MFMA)."); - static_assert(!(std::is_same_v && - !PipelineImplBase::is_b_load_tr), - "B=RowMajor requires transpose load (ds_read_tr), but it is disabled for " - "this K warp tile size. Use a smaller K warp tile (e.g. 32x32x64 MFMA)."); -#endif - - [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() - { - // clang-format off - return "COMPUTE_ASYNC"; - // clang-format on - } - - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() - { - constexpr index_t smem_size = Policy::template GetSmemSize(); - return 2 * smem_size; - } - - CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() - { - return Policy::template IsTransposeC(); - } - - template - struct PipelineImpl : public PipelineImplBase - { - }; - - template <> - struct PipelineImpl : public PipelineImplBase - { - using Base = PipelineImplBase; - - CK_TILE_DEVICE static constexpr auto HotLoopScheduler() - { - constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{}); - constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{}); - constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{}); - - constexpr index_t WaveSize = get_warp_size(); - - constexpr index_t A_Buffer_Load_Inst_Num = - MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); - constexpr index_t B_Buffer_Load_Inst_Num = - NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); - - constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / - (BlockSize / WaveSize) / - (MPerXDL * NPerXDL * KPerXDL); - - constexpr auto num_buffer_load_inst = A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num; - constexpr auto num_issue = num_buffer_load_inst; - - static_for<0, num_buffer_load_inst, 1>{}([&](auto i) { - // TODO: this will likely need to be redesigned after (1) changes to reading from - // LDS and (2) re-profiling - ignore = i; - __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA : 1 - __builtin_amdgcn_sched_group_barrier( - LLVMSchedGroupMask::DS_READ, 1, 0); // DS read : 1 - __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA: 1 - __builtin_amdgcn_sched_group_barrier( - LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read :1 - __builtin_amdgcn_sched_group_barrier( - LLVMSchedGroupMask::MFMA, C_MFMA_Inst_Num / num_issue - 2, 0); // MFMA : 6 - }); - __builtin_amdgcn_sched_barrier(0); - } - - template ::value && - is_detected::value, - bool>* = nullptr> - CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BsDramBlockWindowTmp& b_dram_block_window_tmp, - const BElementFunction& b_element_func, - const ScaleADramBlockWindowTmp& scale_a_window, - const ScaleBDramBlockWindowTmp& scale_b_window, - index_t num_loop, - void* __restrict__ p_smem_0, - void* __restrict__ p_smem_1) const - { - // TODO support multi-ABD - static_assert(1 == std::tuple_size_v); - static_assert(1 == std::tuple_size_v); - using ADramBlockWindowTmp = - remove_cvref_t{}, AsDramBlockWindowTmp>>; - using BDramBlockWindowTmp = - remove_cvref_t{}, BsDramBlockWindowTmp>>; - // TODO currently fused elementwise are not supported - ignore = a_element_func; - ignore = b_element_func; - static_assert(std::is_same_v, - element_wise::PassThrough>); - static_assert(std::is_same_v, - element_wise::PassThrough>); - static_assert( - std::is_same_v> && - std::is_same_v>, - "Data Type conflict on A and B matrix input data type."); - - constexpr bool is_a_col_major = - std::is_same_v; - constexpr bool is_b_row_major = std::is_same_v; - - static_assert(is_a_col_major - ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && - MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) - : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && - KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), - "A block window has incorrect lengths for defined ALayout!"); - static_assert(is_b_row_major - ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && - NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) - : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && - KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), - "B block window has incorrect lengths for defined BLayout!"); - - ////////////// global window & register ///////////////// - // A DRAM tile window(s) for load - auto a_tile_windows = generate_tuple( - [&](auto idx) { - return make_tile_window( - a_dram_block_window_tmp[number{}].get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp[number{}].get_window_origin(), - Policy::template MakeADramTileDistribution()); - }, - number{}); - // B DRAM window(s) for load - auto b_tile_windows = generate_tuple( - [&](auto idx) { - return make_tile_window( - b_dram_block_window_tmp[number{}].get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_dram_block_window_tmp[number{}].get_window_origin(), - Policy::template MakeBDramTileDistribution()); - }, - number{}); - - // for XOR swizzle: policy makes async global-to-LDS stores match LDS reads - // otherwise: no change to view - auto a_async_tile_windows = generate_tuple( - [&](auto idx) { - return make_tile_window(Policy::template MakeAsyncLoadADramWindow( - a_tile_windows[number{}]), - Policy::template MakeADramTileDistribution()); - }, - number{}); - - auto b_async_tile_windows = generate_tuple( - [&](auto idx) { - return make_tile_window(Policy::template MakeAsyncLoadBDramWindow( - b_tile_windows[number{}]), - Policy::template MakeBDramTileDistribution()); - }, - number{}); - - ////////////// MX Scale windows (pre-packed int32_t) ///////////////// - // Get WarpGemm configuration - using BlockWarps = typename BlockGemmShape::BlockWarps; - using WarpTile = typename BlockGemmShape::WarpTile; - constexpr index_t MWarp = BlockWarps::at(I0{}); - constexpr index_t NWarp = BlockWarps::at(I1{}); - - // Compute effective XdlPack sizes (fall back to 1 when iter count < pack) - constexpr index_t MPerXdl = WarpTile::at(I0{}); - constexpr index_t NPerXdl = WarpTile::at(I1{}); - constexpr index_t KPerXdl = WarpTile::at(I2{}); - constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl); - constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl); - constexpr index_t KIterPerWarp = KPerBlock / KPerXdl; - - constexpr index_t MXdlPackEff = - (MIterPerWarp >= Policy::MXdlPack && MIterPerWarp % Policy::MXdlPack == 0) - ? Policy::MXdlPack - : 1; - constexpr index_t NXdlPackEff = - (NIterPerWarp >= Policy::NXdlPack && NIterPerWarp % Policy::NXdlPack == 0) - ? Policy::NXdlPack - : 1; - constexpr index_t KXdlPackEff = - (KIterPerWarp >= Policy::KXdlPack && KIterPerWarp % Policy::KXdlPack == 0) - ? Policy::KXdlPack - : 1; - - // Packed scale dimensions - constexpr index_t ScaleKDimPerBlock = KPerBlock / ScaleGranularityK / KXdlPackEff; - - // Scale tensor views and base origins for creating tile windows per iteration - const auto& scale_a_tensor_view = scale_a_window.get_bottom_tensor_view(); - const auto& scale_b_tensor_view = scale_b_window.get_bottom_tensor_view(); - auto scale_a_base_origin = scale_a_window.get_window_origin(); - auto scale_b_base_origin = scale_b_window.get_window_origin(); - - // Create scale windows with packed int32_t dimensions - auto scale_a_dram_window = make_tile_window( - scale_a_tensor_view, - make_tuple(number{}, number{}), - scale_a_base_origin, - Policy::template MakeMX_ScaleA_DramTileDistribution()); - - auto scale_b_dram_window = make_tile_window( - scale_b_tensor_view, - make_tuple(number{}, number{}), - scale_b_base_origin, - Policy::template MakeMX_ScaleB_DramTileDistribution()); - - // this pipeline has a pair of LDS buffers per logical tile - auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0); - auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1); - - constexpr auto a_lds_shape = []() { - if constexpr(is_a_load_tr_v) - return make_tuple(number{}, number{}); - else - return make_tuple(number{}, number{}); - }(); - - constexpr auto b_lds_shape = []() { - if constexpr(is_b_load_tr_v) - return make_tuple(number{}, number{}); - else - return make_tuple(number{}, number{}); - }(); - - // LDS tile windows for storing, one per LDS buffer - auto a_copy_lds_window0 = make_tile_window(a_lds_block0, a_lds_shape, {0, 0}); - - auto a_copy_lds_window1 = make_tile_window(a_lds_block1, a_lds_shape, {0, 0}); - - auto b_copy_lds_window0 = make_tile_window(b_lds_block0, b_lds_shape, {0, 0}); - - auto b_copy_lds_window1 = make_tile_window(b_lds_block1, b_lds_shape, {0, 0}); - - // initialize DRAM window steps, used to advance the DRAM windows - using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; - using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; - constexpr ADramTileWindowStep a_dram_tile_window_step = - is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); - constexpr BDramTileWindowStep b_dram_tile_window_step = - is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); - - // read A(0), B(0) from DRAM to LDS window(0) - // and advance the DRAM windows - Base::GlobalPrefetchAsync( - a_copy_lds_window0, a_async_tile_windows[number<0>{}], a_dram_tile_window_step); - Base::GlobalPrefetchAsync( - b_copy_lds_window0, b_async_tile_windows[number<0>{}], b_dram_tile_window_step); - - // Initialize block gemm and C block tile - auto block_gemm = BlockGemm(); - auto c_block_tile = block_gemm.MakeCBlockTile(); - clear_tile(c_block_tile); - - // read A(1), B(1) from DRAM to LDS window(1) - // and advance the DRAM windows - Base::GlobalPrefetchAsync( - a_copy_lds_window1, a_async_tile_windows[number<0>{}], a_dram_tile_window_step); - Base::GlobalPrefetchAsync( - b_copy_lds_window1, b_async_tile_windows[number<0>{}], b_dram_tile_window_step); - - // tile distribution for the register tiles - using ALdsTileDistr = - decltype(make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode())); - using BLdsTileDistr = - decltype(make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode())); - - using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr{})); - using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr{})); - - // register tiles; double buffering -> a register tile corresponds to a LDS tile window - ALdsTile a_block_tile0, a_block_tile1; - BLdsTile b_block_tile0, b_block_tile1; - - // Some sanity checks on the LDS tile sizes - static_assert(sizeof(ALdsTile) == MPerBlock * - (KPerBlock * sizeof(ADataType) / APackedSize) * - NWarp / BlockSize, - "ALdsTile size is wrong!"); - static_assert(sizeof(BLdsTile) == NPerBlock * - (KPerBlock * sizeof(BDataType) / BPackedSize) * - MWarp / BlockSize, - "BLdsTile size is wrong!"); - static_assert(Policy::template GetSmemSizeA() >= - MPerBlock * (KPerBlock * sizeof(ADataType) / APackedSize), - "SmemSizeA size is wrong!"); - static_assert(Policy::template GetSmemSizeB() >= - (KPerBlock * sizeof(BDataType) / BPackedSize) * NPerBlock, - "SmemSizeB size is wrong!"); - - ////////////// MX Scale register tiles (ping-pong buffers) ///////////////// - // Scales are pre-packed int32_t: each int32_t holds 2M/N x 2K e8m0_t values - // Block GEMM uses OpSel (0-3) to select the right byte per MFMA call - - using ScaleATileType = decltype(load_tile(scale_a_dram_window)); - using ScaleBTileType = decltype(load_tile(scale_b_dram_window)); - ScaleATileType scale_a_tile_ping, scale_a_tile_pong; - ScaleBTileType scale_b_tile_ping, scale_b_tile_pong; - - // initialize Scale DRAM window steps, used to advance the Scale DRAM windows - using ScaleADramTileWindowStep = typename ScaleADramBlockWindowTmp::BottomTensorIndex; - using ScaleBDramTileWindowStep = typename ScaleBDramBlockWindowTmp::BottomTensorIndex; - constexpr ScaleADramTileWindowStep scale_a_dram_tile_window_step = - make_array(0, ScaleKDimPerBlock); - constexpr ScaleBDramTileWindowStep scale_b_dram_tile_window_step = - make_array(0, ScaleKDimPerBlock); - - // Helper function to load scales - auto load_scales_from_dram = [&](auto& scale_a, auto& scale_b) { - scale_a = load_tile(scale_a_dram_window); - scale_b = load_tile(scale_b_dram_window); - move_tile_window(scale_a_dram_window, scale_a_dram_tile_window_step); - move_tile_window(scale_b_dram_window, scale_b_dram_tile_window_step); - }; - - constexpr auto a_lds_input_tile_distr = []() { - if constexpr(is_a_load_tr_v) - return make_static_tile_distribution( - typename InputTileDistributionTraits< - typename ALdsTileDistr::DstrEncode, - typename Problem::ADataType>::TransposedDstrEncode{}); - else - return ALdsTileDistr{}; - }(); - constexpr auto b_lds_input_tile_distr = []() { - if constexpr(is_b_load_tr_v) - return make_static_tile_distribution( - typename InputTileDistributionTraits< - typename BLdsTileDistr::DstrEncode, - typename Problem::BDataType>::TransposedDstrEncode{}); - else - return BLdsTileDistr{}; - }(); - - // LDS tile windows for reading; - // they share the data pointer with the LDS windows for storing - // but also associate with a distribution to produce a register tile when reading - auto a_lds_ld_window0 = - make_tile_window(a_lds_block0, a_lds_shape, {0, 0}, a_lds_input_tile_distr); - auto a_lds_ld_window1 = - make_tile_window(a_lds_block1, a_lds_shape, {0, 0}, a_lds_input_tile_distr); - auto b_lds_ld_window0 = - make_tile_window(b_lds_block0, b_lds_shape, {0, 0}, b_lds_input_tile_distr); - auto b_lds_ld_window1 = - make_tile_window(b_lds_block1, b_lds_shape, {0, 0}, b_lds_input_tile_distr); - - static_assert(!(is_tile_window_linear_v) && - !(is_tile_window_linear_v) && - !(is_tile_window_linear_v) && - !(is_tile_window_linear_v), - "LDS windows must not be linear"); - - // write to LDS window(0) must complete before the local prefetch - block_sync_lds_direct_load(); - // read A(0), B(0) from LDS window(0) to pipeline registers(0) - Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v); - Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v); - // LDS window(0) contents are overwritten below by global prefetch, need to sync - block_sync_lds(); - // read A(2), B(2) from DRAM to LDS window(0) - // and advance the DRAM windows - Base::GlobalPrefetchAsync( - a_copy_lds_window0, a_async_tile_windows[number<0>{}], a_dram_tile_window_step); - Base::GlobalPrefetchAsync( - b_copy_lds_window0, b_async_tile_windows[number<0>{}], b_dram_tile_window_step); - - // Load scales for iteration 0 (ping) - load_scales_from_dram(scale_a_tile_ping, scale_b_tile_ping); - // Load scales for iteration 1 (pong) if needed - if(num_loop > 1) - { - load_scales_from_dram(scale_a_tile_pong, scale_b_tile_pong); - } - - if(HasHotLoop) - { - // we have had 3 global prefetches so far, indexed (0, 1, 2). - index_t i_global_read = amd_wave_read_first_lane(3); - // alternate ping: (read to register tile(1), use register tile(0) as gemm input) - // pong: (read to register tile(0), use register tile(1) as gemm input) - do - { - // ping - { - // read A(i-1), B(i-1) from LDS window(1) to pipeline registers(1) - Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v); - Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v); - // LDS window(1) contents are overwritten by global prefetch, need to sync - block_sync_lds(); - // read A(i), B(i) from DRAM to LDS window(1) - // and advance the DRAM windows - Base::GlobalPrefetchAsync(a_copy_lds_window1, - a_async_tile_windows[number<0>{}], - a_dram_tile_window_step); - Base::GlobalPrefetchAsync(b_copy_lds_window1, - b_async_tile_windows[number<0>{}], - b_dram_tile_window_step); - // C(i-3) = A(i-3) @ B(i-3) with MX scaling - block_gemm(c_block_tile, - a_block_tile0, - b_block_tile0, - scale_a_tile_ping, - scale_b_tile_ping); - HotLoopScheduler(); - // Load next scales after using current scales above - load_scales_from_dram(scale_a_tile_ping, scale_b_tile_ping); - } - // pong - { - // write to LDS window(0) must complete before the local prefetch - block_sync_lds_direct_load(); - // read A(i), B(i) from LDS window(0) to pipeline registers(0) - Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v); - Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v); - // LDS window(0) contents are overwritten by global prefetch, need to sync - block_sync_lds(); - // read A(i+1), B(i+1) from DRAM to LDS window(0) - // and advance the DRAM windows - Base::GlobalPrefetchAsync(a_copy_lds_window0, - a_async_tile_windows[number<0>{}], - a_dram_tile_window_step); - Base::GlobalPrefetchAsync(b_copy_lds_window0, - b_async_tile_windows[number<0>{}], - b_dram_tile_window_step); - // C(i-2) = A(i-2) @ B(i-2) with MX scaling - block_gemm(c_block_tile, - a_block_tile1, - b_block_tile1, - scale_a_tile_pong, - scale_b_tile_pong); - HotLoopScheduler(); - // Load next scales after using current scales above - load_scales_from_dram(scale_a_tile_pong, scale_b_tile_pong); - } - i_global_read += 2; - } while(i_global_read < num_loop); - } - - // 3 block gemms remaining - if constexpr(TailNum == TailNumber::Three) - { - { - // read A(num_loop-1), B(num_loop-1) from LDS window(1) to pipeline registers(1) - Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v); - Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v); - // C(num_loop-2) = A(num_loop-2) @ B(num_loop-2) with MX scaling - block_gemm(c_block_tile, - a_block_tile0, - b_block_tile0, - scale_a_tile_ping, - scale_b_tile_ping); - - // load last scales to ping for the last iteration to ping buffers - load_scales_from_dram(scale_a_tile_ping, scale_b_tile_ping); - } - { - // write to LDS window(0) must complete before the local prefetch - block_sync_lds_direct_load(); - // read A(num_loop), B(num_loop) from LDS window(0) to pipeline registers(0) - Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v); - Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v); - // C(num_loop-1) = A(num_loop-1) @ B(num_loop-1) with MX scaling - block_gemm(c_block_tile, - a_block_tile1, - b_block_tile1, - scale_a_tile_pong, - scale_b_tile_pong); - } - { - // C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling - block_gemm(c_block_tile, - a_block_tile0, - b_block_tile0, - scale_a_tile_ping, - scale_b_tile_ping); - } - } - else if(TailNum == TailNumber::Two) - // 2 block gemms remaining - { - { - // read A(num_loop), B(num_loop) from LDS window(1) to pipeline registers(1) - Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v); - Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v); - block_gemm(c_block_tile, - a_block_tile0, - b_block_tile0, - scale_a_tile_ping, - scale_b_tile_ping); - } - { - // C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling - block_gemm(c_block_tile, - a_block_tile1, - b_block_tile1, - scale_a_tile_pong, - scale_b_tile_pong); - } - } - else if(TailNum == TailNumber::One) - { - block_sync_lds(); - // C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling - block_gemm(c_block_tile, - a_block_tile0, - b_block_tile0, - scale_a_tile_ping, - scale_b_tile_ping); - __builtin_amdgcn_sched_barrier(0); - } - - return c_block_tile; - } - }; - - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, - const BElementFunction& b_element_func, - const ScaleADramBlockWindowTmp& scale_a_window, - const ScaleBDramBlockWindowTmp& scale_b_window, - index_t num_loop, - void* __restrict__ p_smem) const - { - constexpr index_t smem_size = Policy::template GetSmemSize(); - const auto smem = reinterpret_cast(p_smem); - - const bool has_hot_loop = Base::BlockHasHotloop(num_loop); - const auto tail_number = Base::GetBlockLoopTailNum(num_loop); - - const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - a_element_func, - b_dram_block_window_tmp, - b_element_func, - scale_a_window, - scale_b_window, - num_loop, - smem, - smem + smem_size); - }; - - return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); - } - - public: - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, - const ScaleADramBlockWindowTmp& scale_a_window, - const ScaleBDramBlockWindowTmp& scale_b_window, - const index_t num_loop, - void* __restrict__ p_smem) const - { - constexpr index_t smem_size = Policy::template GetSmemSize(); - const auto smem = reinterpret_cast(p_smem); - - const bool has_hot_loop = Base::BlockHasHotloop(num_loop); - const auto tail_number = Base::GetBlockLoopTailNum(num_loop); - - const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { - return PipelineImpl{}.template operator()( - make_tuple(a_dram_block_window_tmp), - element_wise::PassThrough{}, - make_tuple(b_dram_block_window_tmp), - element_wise::PassThrough{}, - scale_a_window, - scale_b_window, - num_loop, - smem, - smem + smem_size); - }; - - return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); - } -}; -} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp deleted file mode 100644 index bce312d1f9..0000000000 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp +++ /dev/null @@ -1,605 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/core/arch/arch.hpp" -#include "ck_tile/core/numeric/float8.hpp" -#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" -#include "ck_tile/ops/common/tensor_layout.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" -#include "ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_v1.hpp" -#include - -namespace ck_tile { -// Default policy for MXGemmPipelineAgBgCrCompAsync -// Customized methods: MakeALdsBlockDescriptor, MakeBLdsBlockDescriptor -// GetBlockGemm implementation is copied from GemmPipelineAgBgCrCompV4DefaultPolicy -// Adds MX scale tile distributions -struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy - : public UniversalGemmBasePolicy -{ - static constexpr auto ATileAccessPattern = tile_distribution_pattern::warp_raked; - static constexpr auto BTileAccessPattern = tile_distribution_pattern::warp_raked; - - // Async copy supports 32-bit, 96-bit, or 128-bit transfers (4, 12, 16 bytes) - // Take PackedSize into consideration (for example for FP4 support) - template - static constexpr index_t AsyncVectorBytes = - sizeof(DataType) * KPack / numeric_traits>::PackedSize; - - template - static constexpr bool IsSupportedAsyncVectorWidth = - AsyncVectorBytes == 4 || AsyncVectorBytes == 12 || - AsyncVectorBytes == 16; - - template - static constexpr bool IsF8XorSwizzleDataType = - std::is_same_v, fp8_t> || - std::is_same_v, bf8_t>; - - template - static constexpr bool IsFP4XorSwizzleDataType = - std::is_same_v, pk_fp4_t>; - - // XOR Swizzle: support F8/F8 and FP4/FP4. Mixed F8/FP4 stays on the plain path. - template - static constexpr bool IsSupportedXorSwizzleDataType = - (IsF8XorSwizzleDataType && - IsF8XorSwizzleDataType) || - (IsFP4XorSwizzleDataType && - IsFP4XorSwizzleDataType); - - // FP4 needs the XOR KPack in logical elements - // so the async transaction remains 16 bytes - template - static constexpr index_t GetXorSwizzleKPack() - { - return SmemPack * numeric_traits>::PackedSize; - } - - template - static constexpr index_t GetXorSwizzleKPackA() - { - return GetXorSwizzleKPack()>(); - } - - template - static constexpr index_t GetXorSwizzleKPackB() - { - return GetXorSwizzleKPack()>(); - } - - // Check that async vector store to LDS is supported - template - static constexpr bool IsSupportedXorSwizzleAsyncWidth = - IsSupportedAsyncVectorWidth()> && - IsSupportedAsyncVectorWidth()>; - - // gfx950 scales:16x16x128 warp tile, 16-element smem pack, KWarps==1 - template - static constexpr bool IsSupportedXorSwizzleShape = []() { - using BlockGemmShape = typename Problem::BlockGemmShape; - using BlockWarps = typename BlockGemmShape::BlockWarps; - using WarpTile = typename BlockGemmShape::WarpTile; - - return Problem::NumWaveGroups == 1 && BlockWarps::at(number<2>{}) == 1 && - WarpTile::at(number<0>{}) == 16 && WarpTile::at(number<1>{}) == 16 && - WarpTile::at(number<2>{}) == 128 && GetSmemPackA() == 16 && - GetSmemPackB() == 16; - }(); - - // Assume normal LDS layout, not transpose-load - template - static constexpr bool UseXorSwizzle = - !is_a_load_tr && !is_b_load_tr && - IsSupportedXorSwizzleDataType && IsSupportedXorSwizzleAsyncWidth && - IsSupportedXorSwizzleShape; - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeXorSwizzleABDramTileDistribution() - { - using BlockGemmShape = typename Problem::BlockGemmShape; - using BlockWarps = typename BlockGemmShape::BlockWarps; - using WarpTile = typename BlockGemmShape::WarpTile; - - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t KPerBlock = BlockGemmShape::kK; - constexpr index_t KWarps = BlockWarps::at(I2); - constexpr index_t K1 = WarpTile::at(I2) / K2; - constexpr index_t K0 = KPerBlock / (KWarps * K1 * K2); - - constexpr index_t warp_size = get_warp_size(); - constexpr index_t warp_num = BlockSize / warp_size; - - static_assert(KWarps == 1, "MX XOR swizzle currently supports KWarps == 1"); - static_assert(KWarps * K0 * K1 * K2 == KPerBlock, "Wrong!"); - - constexpr index_t M2 = warp_size / K1; - constexpr index_t M1 = warp_num / Problem::NumWaveGroups; - constexpr index_t M0 = MNPerBlock / (M1 * M2); - - static_assert(M0 * M1 * M2 == MNPerBlock, "Wrong!"); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 1>>, - sequence<1, 2, 2>, - sequence<0, 0, 2>>{}); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() - { - if constexpr(UseXorSwizzle) - { - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPack = GetXorSwizzleKPackA(); - return MakeXorSwizzleABDramTileDistribution(); - } - else - { - return UniversalGemmBasePolicy:: - template MakeADramTileDistribution(); - } - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() - { - if constexpr(UseXorSwizzle) - { - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPack = GetXorSwizzleKPackB(); - return MakeXorSwizzleABDramTileDistribution(); - } - else - { - return UniversalGemmBasePolicy:: - template MakeBDramTileDistribution(); - } - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeXorSwizzledABLdsBlockDescriptor() - { - using BlockGemmShape = typename Problem::BlockGemmShape; - using BlockWarps = typename BlockGemmShape::BlockWarps; - using WarpTile = typename BlockGemmShape::WarpTile; - - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t KPerBlock = BlockGemmShape::kK; - constexpr index_t KWarps = BlockWarps::at(I2); - constexpr index_t K1 = WarpTile::at(I2) / K2; - constexpr index_t K0 = KPerBlock / (KWarps * K1 * K2); - - constexpr index_t warp_size = get_warp_size(); - constexpr index_t warp_num = BlockSize / warp_size; - constexpr index_t wg_attr_num_access_v = static_cast(WGAttrNumAccess); - - static_assert(warp_num * warp_size == BlockSize, "Wrong!"); - static_assert(KWarps * K0 * K1 * K2 == KPerBlock, "Wrong!"); - static_assert(KWarps == 1, "MX XOR swizzle currently supports KWarps == 1"); - static_assert(wg_attr_num_access_v == 1 || wg_attr_num_access_v == 2, - "MX XOR swizzle currently supports FP8, BF8, FP4"); - - constexpr index_t K2Pad = K2 < 16 ? 16 : K2; - constexpr index_t M3 = 4; - constexpr index_t M2 = warp_size / K1 / M3; - constexpr index_t M1 = WarpTileMN / (M2 * M3); - constexpr index_t M0 = MNPerBlock / (M1 * M2 * M3); - - static_assert(M0 * M1 * M2 * M3 == MNPerBlock, "Wrong!"); - - constexpr index_t PadSize = 4 * K2; - - constexpr auto desc_0 = make_naive_tensor_descriptor( - number_tuple{}, - number_tuple{}, - number{}, - number<1>{}); - - constexpr auto desc_1 = transform_tensor_descriptor( - desc_0, - make_tuple(make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_xor_transform(make_tuple(number{}, number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<4, 5>{}, - sequence<6>{}), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<4, 5>{}, - sequence<6>{})); - - constexpr auto desc_2 = transform_tensor_descriptor( - desc_1, - make_tuple(make_merge_transform_v3_division_mod(number_tuple{}), - make_merge_transform_v3_division_mod(number_tuple{})), - make_tuple(sequence<0, 2, 3, 4>{}, sequence<1, 5, 6>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return desc_2; - } - - // MX scaling configuration: each e8m0 scale covers 32 elements in K - static constexpr int ScaleGranularityK = 32; - - template > - CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() - { - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - if constexpr(is_a_load_tr) - { - // TODO: better LDS descriptor for performance - constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( // - make_tuple(number{}, number{}), - make_tuple(number{}, number<1>{}), - number{}, - number<1>{}); - return a_lds_block_desc_0; - } - else - { - if constexpr(UseXorSwizzle) - { - using WarpTile = typename Problem::BlockGemmShape::WarpTile; - constexpr index_t KPack = GetXorSwizzleKPackA(); - constexpr auto desc = - MakeXorSwizzledABLdsBlockDescriptor()>(); - static_assert(desc.get_element_space_size() >= MPerBlock * KPerBlock, - "XOR swizzle LDS allocation must cover the A tile"); - return desc; - } - else - { - constexpr index_t KPack = GetSmemPackA(); - - constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - return transform_tensor_descriptor( - a_lds_block_desc_0, - make_tuple(make_pass_through_transform(number{}), - make_merge_transform( - make_tuple(number{}, number{}))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - } - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() - { - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - if constexpr(is_b_load_tr) - { - // TODO: better LDS descriptor for performance - constexpr auto b_lds_block_desc_0 = - make_naive_tensor_descriptor(make_tuple(number{}, number{}), - make_tuple(number{}, number<1>{}), - number{}, - number<1>{}); - return b_lds_block_desc_0; - } - else - { - if constexpr(UseXorSwizzle) - { - using WarpTile = typename Problem::BlockGemmShape::WarpTile; - constexpr index_t KPack = GetXorSwizzleKPackB(); - constexpr auto desc = - MakeXorSwizzledABLdsBlockDescriptor()>(); - static_assert(desc.get_element_space_size() >= NPerBlock * KPerBlock, - "XOR swizzle LDS allocation must cover the B tile"); - return desc; - } - else - { - constexpr index_t KPack = GetSmemPackB(); - - constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - return transform_tensor_descriptor( - b_lds_block_desc_0, - make_tuple(make_pass_through_transform(number{}), - make_merge_transform( - make_tuple(number{}, number{}))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - } - } - - // MX GEMM: Double access for FP8/BF8, Single for FP4 - template - CK_TILE_HOST_DEVICE static constexpr auto CalculateWGAttrNumAccess() - { - using DataType = remove_cvref_t; - - if constexpr(std::is_same_v || std::is_same_v) - { - return WGAttrNumAccessEnum::Double; - } - else if constexpr(std::is_same_v) - { - return WGAttrNumAccessEnum::Single; - } - else - { - static_assert(sizeof(DataType) == 0, - "CalculateWGAttrNumAccess(): unsupported data type"); - return WGAttrNumAccessEnum::Invalid; - } - } - - // Get number of accesses - template - CK_TILE_HOST_DEVICE static constexpr auto GetWGAttrNumAccess() - { - constexpr auto num_access_a = CalculateWGAttrNumAccess(); - constexpr auto num_access_b = CalculateWGAttrNumAccess(); - - if constexpr(static_cast(num_access_a) >= static_cast(num_access_b)) - { - return num_access_a; - } - else - { - return num_access_b; - } - } - - template - CK_TILE_DEVICE static constexpr auto MakeAsyncLoadABDramWindow(const Window& window) - { - using BlockGemmShape = typename Problem::BlockGemmShape; - using BlockWarps = typename BlockGemmShape::BlockWarps; - using WarpTile = typename BlockGemmShape::WarpTile; - - constexpr auto ndims = std::decay_t::get_num_of_dimension(); - static_assert(ndims == 2, "only support 2D tensor"); - - constexpr index_t KPerBlock = BlockGemmShape::kK; - constexpr index_t KWarps = BlockWarps::at(I2); - constexpr index_t K1 = WarpTile::at(I2) / K2; - - static_assert(K1 * K2 == WarpTile::at(I2), "Wrong!"); - static_assert(KPerBlock % (KWarps * K1 * K2) == 0, "Wrong!"); - - constexpr index_t wg_attr_num_access_v = static_cast(WGAttrNumAccess); - - constexpr index_t M4 = 4; // same as MakeXorSwizzledABLdsBlockDescriptor::M3 - static_assert(get_warp_size() % (wg_attr_num_access_v * K1 * M4) == 0, - "warp_size must be divisible by (wg_attr_num_access_v * K1 * M4)"); - - auto&& tensor_view = window.get_bottom_tensor_view(); - const auto [rows, cols] = tensor_view.get_tensor_descriptor().get_lengths(); - - const index_t k_tiles = cols / (KWarps * K1 * K2); - const auto col_lens = make_tuple(k_tiles, number{}, number{}, number{}); - - const index_t M0 = integer_divide_ceil(rows, M4); - const auto row_lens = make_tuple(M0, number{}); - - const auto desc_0 = transform_tensor_descriptor( - tensor_view.get_tensor_descriptor(), - make_tuple(make_unmerge_transform(row_lens), make_unmerge_transform(col_lens)), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0, 1>{}, sequence<2, 3, 4, 5>{})); - - const auto desc_1 = transform_tensor_descriptor( - desc_0, - make_tuple(make_pass_through_transform(M0), - make_xor_transform(make_tuple(number{}, number{})), - make_pass_through_transform(k_tiles), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), - make_tuple( - sequence<0>{}, sequence<1, 4>{}, sequence<2>{}, sequence<3>{}, sequence<5>{}), - make_tuple( - sequence<0>{}, sequence<1, 4>{}, sequence<2>{}, sequence<3>{}, sequence<5>{})); - - const auto desc = - transform_tensor_descriptor(desc_1, - make_tuple(make_merge_transform_v3_division_mod(row_lens), - make_merge_transform_v3_division_mod(col_lens)), - make_tuple(sequence<0, 1>{}, sequence<2, 3, 4, 5>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return make_tile_window( - make_tensor_view(&tensor_view.get_buffer_view()(0), desc), - window.get_window_lengths(), - window.get_window_origin()); - } - - template - CK_TILE_DEVICE static constexpr auto MakeAsyncLoadADramWindow(const Window& window) - { - if constexpr(UseXorSwizzle) - { - constexpr index_t KPack = GetXorSwizzleKPackA(); - return MakeAsyncLoadABDramWindow()>(window); - } - else - { - return make_tile_window(window.get_bottom_tensor_view(), - window.get_window_lengths(), - window.get_window_origin()); - } - } - - template - CK_TILE_DEVICE static constexpr auto MakeAsyncLoadBDramWindow(const Window& window) - { - if constexpr(UseXorSwizzle) - { - constexpr index_t KPack = GetXorSwizzleKPackB(); - return MakeAsyncLoadABDramWindow()>(window); - } - else - { - return make_tile_window(window.get_bottom_tensor_view(), - window.get_window_lengths(), - window.get_window_origin()); - } - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() - { - using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; - using WarpTile = typename Problem::BlockGemmShape::WarpTile; - - using ADataType = typename Problem::ADataType; - using BDataType = typename Problem::BDataType; - using CDataType = typename Problem::CDataType; - - // FP4 and FP8 require different layouts for the scaled mfma instructions - constexpr auto wg_attr_num_access = GetWGAttrNumAccess(); - - using WarpGemm = WarpGemmDispatcher; - - using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy; - - return BlockMXGemmARegBRegCRegV1{}; - } - - // XdlPack: how many e8m0_t scale values are packed into one int32_t per dimension - // Host packs MXdlPack * KXdlPack e8m0_t into one int32_t for A scales - // Host packs NXdlPack * KXdlPack e8m0_t into one int32_t for B scales - static constexpr int MXdlPack = 2; - static constexpr int NXdlPack = 2; - static constexpr int KXdlPack = 2; - - // MX Scale tile distributions for loading pre-packed int32_t from global memory - // Packed layout: [M/MXdlPack, K/32/KXdlPack] of int32_t - template - CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution() - { - using BlockGemmShape = typename Problem::BlockGemmShape; - using BlockWarps = typename BlockGemmShape::BlockWarps; - using WarpTile = typename BlockGemmShape::WarpTile; - - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t MWarp = BlockWarps::at(number<0>{}); - constexpr index_t NWarp = BlockWarps::at(number<1>{}); - constexpr index_t MPerXdl = WarpTile::at(number<0>{}); - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t K_Lane = get_warp_size() / MPerXdl; - constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl); - constexpr index_t KPerXdl = WarpTile::at(number<2>{}); - constexpr index_t KIterPerWarp = KPerBlock / KPerXdl; - constexpr index_t KPerLane = KPerXdl / ScaleGranularityK / K_Lane; - - // Effective pack sizes: fall back to 1 when iteration count < pack size - constexpr index_t MXdlPackEff = - (MIterPerWarp >= MXdlPack && MIterPerWarp % MXdlPack == 0) ? MXdlPack : 1; - constexpr index_t KXdlPackEff = - (KIterPerWarp >= KXdlPack && KIterPerWarp % KXdlPack == 0) ? KXdlPack : 1; - - constexpr index_t MIterPerWarp_packed = MIterPerWarp / MXdlPackEff; - constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff; - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, - sequence>, - tuple, sequence<2, 1>>, - tuple, sequence<1, 2>>, - sequence<2, 1, 2>, - sequence<0, 1, 2>>{}); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_DramTileDistribution() - { - using BlockGemmShape = typename Problem::BlockGemmShape; - using BlockWarps = typename BlockGemmShape::BlockWarps; - using WarpTile = typename BlockGemmShape::WarpTile; - - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t MWarp = BlockWarps::at(number<0>{}); - constexpr index_t NWarp = BlockWarps::at(number<1>{}); - constexpr index_t NPerXdl = WarpTile::at(number<1>{}); - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t K_Lane = get_warp_size() / NPerXdl; - constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl); - - constexpr index_t KPerXdl = WarpTile::at(number<2>{}); - constexpr index_t KIterPerWarp = KPerBlock / KPerXdl; - constexpr index_t KPerLane = KPerXdl / ScaleGranularityK / K_Lane; - - // Effective pack sizes: fall back to 1 when iteration count < pack size - constexpr index_t NXdlPackEff = - (NIterPerWarp >= NXdlPack && NIterPerWarp % NXdlPack == 0) ? NXdlPack : 1; - constexpr index_t KXdlPackEff = - (KIterPerWarp >= KXdlPack && KIterPerWarp % KXdlPack == 0) ? KXdlPack : 1; - - constexpr index_t NIterPerWarp_packed = NIterPerWarp / NXdlPackEff; - constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff; - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, - sequence>, - tuple, sequence<2, 1>>, - tuple, sequence<1, 2>>, - sequence<2, 1, 2>, - sequence<0, 1, 2>>{}); - } -}; -} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp deleted file mode 100644 index 3b25d6091a..0000000000 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp +++ /dev/null @@ -1,282 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once -#include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_eight_waves_base.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" -#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp" - -namespace ck_tile { - -/** - * @brief Compute optimized pipeline version async for 8 waves - * - * This pipeline introduces asynchronous load from global memory to LDS, - * skipping the intermediate loading into pipeline registers. - */ -template -struct MXGemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrCompV3 -{ - using Base = BaseGemmPipelineAgBgCrCompV3; - using PipelineImplBase = GemmPipelineAgBgCrEightWavesImplBase; - - using AsDataType = remove_cvref_t; - using BsDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - - using AsLayout = remove_cvref_t; - using BsLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - - using AElementWise = remove_cvref_t; - using BElementWise = remove_cvref_t; - - using ALayout = remove_cvref_t>; - using BLayout = remove_cvref_t>; - - using ADataType = remove_cvref_t>; - using BDataType = remove_cvref_t>; - - static_assert(!std::is_same_v, "Not implemented"); - - static constexpr index_t APackedSize = ck_tile::numeric_traits::PackedSize; - static constexpr index_t BPackedSize = ck_tile::numeric_traits::PackedSize; - - using BlockGemm = remove_cvref_t())>; - using WarpGemm = typename BlockGemm::WarpGemm; - - static constexpr auto I0 = number<0>{}; - static constexpr auto I1 = number<1>{}; - static constexpr auto I2 = number<2>{}; - - static constexpr index_t BlockSize = Problem::kBlockSize; - - static constexpr index_t MPerBlock = BlockGemmShape::kM; - static constexpr index_t NPerBlock = BlockGemmShape::kN; - static constexpr index_t KPerBlock = BlockGemmShape::kK; - - static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(I0); - static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(I1); - static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(I2); - - static constexpr index_t kflatKPerBlock = BlockGemmShape::flatKPerBlock; - - static constexpr index_t MIterPerWarp = MPerBlock / (MWarps * WarpGemm::kM); - static constexpr index_t NIterPerWarp = NPerBlock / (NWarps * WarpGemm::kN); - static constexpr index_t KIterPerWarp = KPerBlock / (KWarps * WarpGemm::kK); - - static constexpr bool Async = true; - - template - static constexpr index_t GetVectorSizeA() - { - return Policy::template GetVectorSizeA(); - } - template - static constexpr index_t GetVectorSizeB() - { - return Policy::template GetVectorSizeB(); - } - - static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } - static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } - - static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; - static constexpr index_t Preshuffle = Problem::Preshuffle; - - static constexpr bool kPadM = Problem::kPadM; - static constexpr bool kPadN = Problem::kPadN; - static constexpr bool kPadK = Problem::kPadK; - - static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; - - static constexpr auto Scheduler = Problem::Scheduler; - - [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() - { - // clang-format off - return "COMPUTE_ASYNC_EIGHT_WAVES"; - // clang-format on - } - - [[nodiscard]] CK_TILE_HOST static const std::string GetName() - { - // clang-format off - constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0); - constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1); - return concat('_', "pipeline_AgBgCrCompAsyncEightWaves", - concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize, - concat('x', GetVectorSizeA(), GetVectorSizeB()), - concat('x', WaveNumM, WaveNumN), - concat('x', kPadM, kPadN, kPadK), - Problem::GetName()); - // clang-format on - } - - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() - { - return Policy::template GetSmemSize(); - } - - static constexpr index_t MFMA_INST = MIterPerWarp * NIterPerWarp * KIterPerWarp; - - // Scales are packed so odd numbers of iterations greater than 1 are not supported - static_assert((MIterPerWarp == 1) || (MIterPerWarp % 2 == 0)); - static_assert((NIterPerWarp == 1) || (NIterPerWarp % 2 == 0)); - static_assert((KIterPerWarp == 1) || (KIterPerWarp % 2 == 0)); - - template - struct PipelineImpl : public PipelineImplBase - { - }; - - template <> - struct PipelineImpl : public PipelineImplBase - { - using Base = PipelineImplBase; - - template ::value && - !is_detected::value, - bool>* = nullptr> - CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BsDramBlockWindowTmp& b_dram_block_window_tmp, - const BElementFunction& b_element_func, - const ScaleADramBlockWindowTmp& scale_a_window, - const ScaleBDramBlockWindowTmp& scale_b_window, - index_t num_loop, - void* __restrict__ p_smem) const - { - // TODO: A/B elementwise functions currently not supported - ignore = a_element_func; - ignore = b_element_func; - - // ------ - // Checks - // ------ - static_assert( - std::is_same_v> && - std::is_same_v>, - "A/B Dram block window should have the same data type as appropriate " - "([A|B]DataType) defined in Problem definition!"); - - static_assert(std::is_same_v, "Wrong!"); - static_assert(std::is_same_v, "Wrong!"); - - static_assert((MPerBlock == AsDramBlockWindowTmp{}.get_window_lengths()[I0] && - KPerBlock == AsDramBlockWindowTmp{}.get_window_lengths()[I1]), - "A block window has incorrect lengths for defined ALayout!"); - static_assert(Preshuffle // - ? (NWarps == BsDramBlockWindowTmp{}.get_window_lengths()[I0] && - kflatKPerBlock == BsDramBlockWindowTmp{}.get_window_lengths()[I1]) - : (NPerBlock == BsDramBlockWindowTmp{}.get_window_lengths()[I0] && - KPerBlock == BsDramBlockWindowTmp{}.get_window_lengths()[I1]), - "B block window has incorrect lengths for defined BLayout!"); - - // ------------------ - // Hot loop scheduler - // ------------------ - auto hot_loop_scheduler = [&]() { - __builtin_amdgcn_sched_group_barrier(0x008, MIterPerWarp, 0); // MFMA - s_waitcnt_lgkm<4>(); - __builtin_amdgcn_sched_group_barrier(0x004, 1, 0); // lgkmcnt / SALU - static_for<0, MFMA_INST - MIterPerWarp, 1>{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - __builtin_amdgcn_sched_barrier(0); - }; - - // ------- - // Compute - // ------- - return Base::template Run_(p_smem, - num_loop, - a_dram_block_window_tmp, - b_dram_block_window_tmp, - scale_a_window, - scale_b_window, - hot_loop_scheduler); - } - }; - - template ::value && - !is_detected::value, - bool>* = nullptr> - CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BsDramBlockWindowTmp& b_dram_block_window_tmp, - const BElementFunction& b_element_func, - const ScaleADramBlockWindowTmp& scale_a_window, - const ScaleBDramBlockWindowTmp& scale_b_window, - index_t num_loop, - void* p_smem) const - { - const bool has_hot_loop = Base::BlockHasHotloop(num_loop); - const auto tail_number = Base::GetBlockLoopTailNum(num_loop); - const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - a_element_func, - b_dram_block_window_tmp, - b_element_func, - scale_a_window, - scale_b_window, - num_loop, - p_smem); - }; - - return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); - } - - template ::value && - !is_detected::value, - bool>* = nullptr> - CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, - const BsDramBlockWindowTmp& b_dram_block_window_tmp, - const ScaleADramBlockWindowTmp& scale_a_window, - const ScaleBDramBlockWindowTmp& scale_b_window, - index_t num_loop, - void* p_smem) const - { - const bool has_hot_loop = Base::BlockHasHotloop(num_loop); - const auto tail_number = Base::GetBlockLoopTailNum(num_loop); - const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - identity{}, - b_dram_block_window_tmp, - identity{}, - scale_a_window, - scale_b_window, - num_loop, - p_smem); - }; - - return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); - } -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp deleted file mode 100644 index 519b7afcd3..0000000000 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp +++ /dev/null @@ -1,201 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp" -#include "ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_eight_waves_v1.hpp" - -namespace ck_tile { -namespace detail { - -template -struct MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy -{ - static constexpr auto I0 = number<0>{}; - static constexpr auto I1 = number<1>{}; - static constexpr auto I2 = number<2>{}; - - // MX scaling configuration: each e8m0 scale covers 32 elements in K - static constexpr int BlockScaleSize = 32; - - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using AComputeDataType = remove_cvref_t; - using BComputeDataType = remove_cvref_t; - using ComputeDataType = AComputeDataType; - static_assert(std::is_same_v, "Wrong!"); - static_assert(std::is_same_v, "Wrong!"); - static_assert(is_any_of::value); - static_assert(is_any_of::value); - static_assert(std::is_same_v); - static_assert(std::is_same_v); - - using BlockGemmShape = typename Problem::BlockGemmShape; - using BlockWarps = typename BlockGemmShape::BlockWarps; - using WarpTile = typename BlockGemmShape::WarpTile; - - static constexpr index_t BlockSize = Problem::kBlockSize; - static constexpr index_t MPerBlock = BlockGemmShape::kM; - static constexpr index_t NPerBlock = BlockGemmShape::kN; - static constexpr index_t KPerBlock = BlockGemmShape::kK; - static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(I0); - static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(I1); - static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(I2); - static constexpr index_t WarpTileM = WarpTile::at(I0); - static constexpr index_t WarpTileN = WarpTile::at(I1); - static constexpr index_t WarpTileK = WarpTile::at(I2); - static constexpr index_t MWarpTiles = MPerBlock / WarpTileM; - static constexpr index_t NWarpTiles = NPerBlock / WarpTileN; - static constexpr index_t KWarpTiles = KPerBlock / WarpTileK; - - // XdlPack: how many e8m0_t scale values are packed into one int32_t per dimension - // Host packs MXdlPack * KXdlPack e8m0_t into one int32_t for A scales - // Host packs NXdlPack * KXdlPack e8m0_t into one int32_t for B scales - static constexpr int MXdlPack = 2; - static constexpr int NXdlPack = 2; - static constexpr int KXdlPack = 2; - - // Compute effective XdlPack sizes (fall back to 1 when iter count < pack) - static constexpr index_t MPerXdl = WarpTile::at(I0); - static constexpr index_t NPerXdl = WarpTile::at(I1); - static constexpr index_t KPerXdl = WarpTile::at(I2); - static constexpr index_t MIterPerWarp = MPerBlock / (MWarps * MPerXdl); - static constexpr index_t NIterPerWarp = NPerBlock / (NWarps * NPerXdl); - static constexpr index_t KIterPerWarp = KPerBlock / KPerXdl; - - static constexpr index_t MXdlPackEff = - (MIterPerWarp >= MXdlPack && MIterPerWarp % MXdlPack == 0) ? MXdlPack : 1; - static constexpr index_t NXdlPackEff = - (NIterPerWarp >= NXdlPack && NIterPerWarp % NXdlPack == 0) ? NXdlPack : 1; - static constexpr index_t KXdlPackEff = - (KIterPerWarp >= KXdlPack && KIterPerWarp % KXdlPack == 0) ? KXdlPack : 1; - - static constexpr index_t KPerBlockScale = KPerBlock / BlockScaleSize / KXdlPackEff; - - static constexpr index_t KPerWarp = KPerBlock / KWarps; - static constexpr index_t NPerWarp = NPerBlock / NWarps; - static_assert(NWarps == 2, "NWarps == 2 for ping-pong!"); - - static constexpr index_t warp_size = get_warp_size(); - static constexpr index_t warp_num = BlockSize / warp_size; - static_assert(warp_size == 64, "Wrong!"); - static_assert(warp_num * warp_size == BlockSize, "Wrong!"); - - static_assert(sizeof(ADataType) == sizeof(BDataType), "Wrong!"); - static constexpr index_t ElementSize = sizeof(ADataType); - static constexpr index_t K2 = Problem::VectorLoadSize / ElementSize; // 16 - static constexpr index_t K1 = WarpTile::at(I2) / K2; // 8 - static constexpr index_t K0 = KPerWarp / (K1 * K2); - static_assert(K0 * K1 * K2 == KPerWarp, "Wrong!"); - - CK_TILE_HOST_DEVICE static constexpr auto GetKStepAQ() { return KPerBlockScale; } - CK_TILE_HOST_DEVICE static constexpr auto GetKStepBQ() { return KPerBlockScale; } - - CK_TILE_HOST_DEVICE static constexpr auto GetInstCountAQ() - { - return (MIterPerWarp / MXdlPackEff) * (KIterPerWarp / KXdlPackEff); - } - - CK_TILE_HOST_DEVICE static constexpr auto GetInstCountBQ() - { - return (NIterPerWarp / NXdlPackEff) * (KIterPerWarp / KXdlPackEff); - } - - CK_TILE_HOST_DEVICE static constexpr auto MakeAQBlockDistribution() - { - constexpr index_t K_Lane = get_warp_size() / WarpTileM; - - constexpr index_t KPerLane = WarpTileK / BlockScaleSize / K_Lane; - - constexpr index_t MIterPerWarp_packed = MIterPerWarp / MXdlPackEff; - constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff; - - return make_static_tile_distribution( - tile_distribution_encoding< - sequence, // repeat over MWarps - tuple, // M dimension (first) - sequence>, // K dimension (second) - tuple, sequence<2, 1>>, // , - tuple, sequence<1, 2>>, - sequence<2, 1, 2>, // - sequence<0, 1, 2>>{}); - } - - CK_TILE_HOST_DEVICE static constexpr auto MakeBQBlockDistribution() - { - constexpr index_t K_Lane = get_warp_size() / WarpTileN; - - constexpr index_t KPerLane = WarpTileK / BlockScaleSize / K_Lane; - - constexpr index_t NIterPerWarp_packed = NIterPerWarp / NXdlPackEff; - constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff; - - return make_static_tile_distribution( - tile_distribution_encoding< - sequence, // repeat over MWarps - tuple, // N dimension - // (first) - sequence>, // K dimension (second) - tuple, sequence<2, 1>>, // , - tuple, sequence<1, 3>>, - sequence<2, 1, 2>, // - sequence<0, 1, 2>>{}); - } - - CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() - { - constexpr auto wg_attr_num_access = - (std::is_same_v || std::is_same_v) - ? WGAttrNumAccessEnum::Double - : WGAttrNumAccessEnum::Single; - - using WarpGemm = WarpGemmDispatcher; - - using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy; - - return BlockMXGemmARegBRegCRegEightWavesV1{}; - } -}; -} // namespace detail - -struct MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy - : public GemmPipelineAgBgCrCompAsyncEightWavesPolicy -{ - -#define FORWARD_METHOD_(method) \ - template \ - CK_TILE_HOST_DEVICE static constexpr auto method(Args&&... args) \ - { \ - return detail::MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy::method( \ - std::forward(args)...); \ - } - - FORWARD_METHOD_(MakeAQBlockDistribution); - FORWARD_METHOD_(MakeBQBlockDistribution); - FORWARD_METHOD_(GetBlockGemm); - FORWARD_METHOD_(GetKStepAQ); - FORWARD_METHOD_(GetKStepBQ); - FORWARD_METHOD_(GetInstCountAQ); - FORWARD_METHOD_(GetInstCountBQ); - -#undef FORWARD_METHOD_ -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eight_waves_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eight_waves_policy.hpp index d52cb9ddc1..3723353233 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eight_waves_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eight_waves_policy.hpp @@ -115,31 +115,6 @@ struct GemmABQuantPipelineAgBgCrAsyncPolicy sequence<1, 2>, sequence<0, 1>>{}); } - - CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() - { - static_assert(Problem::BQuantGroupSize::kK % WarpTile::at(I2) == 0, - "KPerWarpGemm must be a multiple of QuantGroupSize::kK!"); - static_assert(Problem::TransposeC, "Wrong!"); - - using WarpGemm = WarpGemmDispatcher; - - using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy; - return ABQuantBlockUniversalGemmAsBsCrAsync{}; - } }; } // namespace detail @@ -165,6 +140,47 @@ struct GemmABQuantPipelineAgBgCrAsyncPolicy : public GemmPipelineAgBgCrCompAsync FORWARD_METHOD_(GetInstCountBQ); #undef FORWARD_METHOD_ + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + static_assert(Problem::BQuantGroupSize::kK % WarpTile::at(I2) == 0, + "KPerWarpGemm must be a multiple of QuantGroupSize::kK!"); + static_assert(Problem::TransposeC, "Wrong!"); + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using AComputeDataType = remove_cvref_t; + using BComputeDataType = remove_cvref_t; + + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + + constexpr index_t WarpTileM = WarpTile::at(I0); + constexpr index_t WarpTileN = WarpTile::at(I1); + constexpr index_t WarpTileK = WarpTile::at(I2); + + constexpr auto WGAccessDouble = WGAttrNumAccessEnum::Double; + + using WarpGemm = WarpGemmDispatcher; + + using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy; + return ABQuantBlockUniversalGemmAsBsCrAsync{}; + } }; } // namespace ck_tile diff --git a/test/ck_tile/gemm_mx/CMakeLists.txt b/test/ck_tile/gemm_mx/CMakeLists.txt index 4b6e6b795c..18316ec801 100644 --- a/test/ck_tile/gemm_mx/CMakeLists.txt +++ b/test/ck_tile/gemm_mx/CMakeLists.txt @@ -7,8 +7,16 @@ if(CK_USE_OCP_FP8) endif() if(GPU_TARGETS MATCHES "gfx95") - add_gtest_executable(test_ck_tile_mx_gemm_async test_mx_gemm_async.cpp) - target_compile_options(test_ck_tile_mx_gemm_async PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_mx_gemm_async_rcr test_mx_gemm_async_rcr.cpp) + target_compile_options(test_ck_tile_mx_gemm_async_rcr PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_mx_gemm_async_rrr test_mx_gemm_async_rrr.cpp) + target_compile_options(test_ck_tile_mx_gemm_async_rrr PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_mx_gemm_async_crr test_mx_gemm_async_crr.cpp) + target_compile_options(test_ck_tile_mx_gemm_async_crr PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_mx_gemm_async_ccr test_mx_gemm_async_ccr.cpp) + target_compile_options(test_ck_tile_mx_gemm_async_ccr PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_mx_gemm_async_rcr_large_cases test_mx_gemm_async_rcr_large_cases.cpp) + target_compile_options(test_ck_tile_mx_gemm_async_rcr_large_cases PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping ck_tile MX GEMM tests for current target") endif() diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_async.cpp b/test/ck_tile/gemm_mx/test_mx_gemm_async.cpp deleted file mode 100644 index 1804325013..0000000000 --- a/test/ck_tile/gemm_mx/test_mx_gemm_async.cpp +++ /dev/null @@ -1,201 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "test_mx_gemm_config.hpp" -#include "test_mx_gemm_util.hpp" - -using Row = ck_tile::tensor_layout::gemm::RowMajor; -using Col = ck_tile::tensor_layout::gemm::ColumnMajor; -using F4 = ck_tile::pk_fp4_t; -using F8 = ck_tile::fp8_t; -using B8 = ck_tile::bf8_t; - -// clang-format off -using MxTypes = ::testing::Types, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple>; -// clang-format on - -template -class TestMxGemm : public TestMxGemmUtil -{ -}; - -TYPED_TEST_SUITE(TestMxGemm, MxTypes); - -TYPED_TEST(TestMxGemm, Default) -{ - this->Run(128, 512, 256); - this->Run(256, 512, 512); - this->Run(1024, 1024, 1024); -} - -// 32x32x64 MFMA warp tile: enables all four A/B layout combinations via ds_read_tr -// transposed LDS loads. (16x16x128 stays Row/Col only above: KWarpTile=128 exceeds the -// ds_read_tr subtile limit, which disables transpose loads.) -// clang-format off -using MxTypesTranspose = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - // bf8/bf8 and mixed fp8/bf8 exercise the float8 paths newly consolidated into the generic - // 32x32x64 f8/f6/f4 dispatcher (previously distinct per-type code paths). - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple>; -// clang-format on - -template -class TestMxGemmTranspose : public TestMxGemmUtil -{ -}; - -TYPED_TEST_SUITE(TestMxGemmTranspose, MxTypesTranspose); - -TYPED_TEST(TestMxGemmTranspose, BasicSizes) -{ - this->Run(128, 128, 256); - this->Run(128, 128, 512); -} - -TYPED_TEST(TestMxGemmTranspose, MultiBlockMN) -{ - this->Run(256, 128, 256); - this->Run(128, 256, 256); - this->Run(256, 256, 256); -} - -// Preshuffle split-K coverage. MxTypes already exercises the preshuffle configs on the -// non-split-K shapes (TestMxGemm.Default); this fixture pins the split-K shapes to the -// fp4/fp8 preshuffle configs. -using MxTypesPreshuffle = - ::testing::Types, - std::tuple>; - -template -class TestMxGemmPreshuffle : public TestMxGemmUtil -{ -}; - -TYPED_TEST_SUITE(TestMxGemmPreshuffle, MxTypesPreshuffle); - -// Split-K for the preshuffle pipeline: each k_id offsets the flat-B window and the -// host-preshuffled A/B scale windows into its own K slice (and accumulates via atomic-add). -// K is a multiple of K_Tile * k_batch (= 256 * k_batch); N is a multiple of 512 so the shapes -// are valid for both the fp4 (N_Tile = 512) and fp8 (N_Tile = 256) preshuffle configs. -TYPED_TEST(TestMxGemmPreshuffle, SplitK) -{ - this->Run(128, 512, 512, /*k_batch=*/2); - this->Run(128, 512, 1024, /*k_batch=*/2); - this->Run(128, 512, 1024, /*k_batch=*/4); - this->Run(256, 512, 2048, /*k_batch=*/4); -} - -// Regression coverage for the MX GEMM correctness fixes (PR #6663): num_loop == 3 hot-loop -// dispatch, split-K, and M/N padding. Shapes are pinned to fp8 x MX_GemmConfig16 (M_Tile = 64, -// N_Tile = 128, K_Tile = 256, default comp-async pipeline) so the regressions hit the intended -// code path -- e.g. K = 768 gives num_loop = K / K_Tile = 3. -using MxFp8Cfg16Types = ::testing::Types>; - -using MxFp8PadMNTypes = - ::testing::Types>; - -template -class TestMxGemmFp8Regression : public TestMxGemmUtil -{ -}; - -TYPED_TEST_SUITE(TestMxGemmFp8Regression, MxFp8Cfg16Types); - -// num_loop == 3 must not enter the hot loop: with K_Tile = 256, K = 768 gives num_loop = 3, -// which previously produced 5 gemm accumulations instead of 3 (deterministically wrong). -TYPED_TEST(TestMxGemmFp8Regression, HotLoopTailNumLoopThree) -{ - this->Run(64, 128, 768); - this->Run(128, 256, 768); - this->Run(256, 256, 768); -} - -// Split-K: exercises both the full_k_read and partial_k_read paths of SplitKBatchOffset together -// with the per-split scale-window K offset and the atomic-add epilogue. K is a multiple of -// K_Tile * k_batch and of WarpTile_K * k_batch (= 128 * k_batch) so every split lands on a -// packed-scale boundary. -TYPED_TEST(TestMxGemmFp8Regression, SplitK) -{ - this->Run(128, 256, 512, /*k_batch=*/2); - this->Run(128, 256, 1024, /*k_batch=*/2); - this->Run(128, 256, 1024, /*k_batch=*/4); - this->Run(256, 256, 2048, /*k_batch=*/4); -} - -// fp4 split-K (non-preshuffle). Same MX_GemmConfig16 tile shape as the fp8 regression above, so -// the K alignment requirements are identical; this verifies the packed (BPackedSize = 2) A/B -// pointer K-offset works under split-K + atomic-add for fp4. -using MxF4Cfg16Types = ::testing::Types>; - -template -class TestMxGemmFp4Regression : public TestMxGemmUtil -{ -}; - -TYPED_TEST_SUITE(TestMxGemmFp4Regression, MxF4Cfg16Types); - -TYPED_TEST(TestMxGemmFp4Regression, SplitK) -{ - this->Run(128, 256, 512, /*k_batch=*/2); - this->Run(128, 256, 1024, /*k_batch=*/2); - this->Run(128, 256, 1024, /*k_batch=*/4); - this->Run(256, 256, 2048, /*k_batch=*/4); -} - -template -class TestMxGemmFp8PadMN : public TestMxGemmUtil -{ -}; - -TYPED_TEST_SUITE(TestMxGemmFp8PadMN, MxFp8PadMNTypes); - -// M/N padding (kPadM = kPadN = true). M_Tile = 64, N_Tile = 128. Each of M and N must be >= its -// block tile (the CShuffleEpilogue cannot safely run a single partial tile along either -// dimension); K stays aligned because the MX async pipeline does not support K padding. -TYPED_TEST(TestMxGemmFp8PadMN, MNPaddingAligned) -{ - // Sanity: padding enabled but already-aligned M, N must not regress the normal path. - this->Run(64, 128, 256); -} - -TYPED_TEST(TestMxGemmFp8PadMN, MPadding) -{ - // M has a full tile + partial trailing tile (N aligned to N_Tile = 128). - this->Run(96, 128, 256); - this->Run(160, 128, 256); -} - -TYPED_TEST(TestMxGemmFp8PadMN, NPadding) -{ - // N has a full tile + partial trailing tile (M aligned to M_Tile = 64). - this->Run(64, 160, 256); - this->Run(64, 224, 256); -} - -TYPED_TEST(TestMxGemmFp8PadMN, MNPadding) -{ - // Both M and N unaligned (full + partial trailing tiles). - this->Run(96, 160, 256); - this->Run(160, 224, 512); -} diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_async_ccr.cpp b/test/ck_tile/gemm_mx/test_mx_gemm_async_ccr.cpp new file mode 100644 index 0000000000..0dab4b592b --- /dev/null +++ b/test/ck_tile/gemm_mx/test_mx_gemm_async_ccr.cpp @@ -0,0 +1,22 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_mx_gemm_pipeline_kernel_types.hpp" +#include "test_mx_gemm_pipeline_util.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileMxGemmPipelineCompAsyncCCR + : public TestCkTileMxGemmPipeline> +{ + public: + static constexpr bool check_data_type() { return true; } +}; + +#define TEST_SUITE_NAME TestCkTileMxGemmPipelineCompAsyncCCR + +TYPED_TEST_SUITE(TestCkTileMxGemmPipelineCompAsyncCCR, KernelTypesMxGemmCompAsyncCCR); + +#include "test_mx_gemm_pipeline_tr_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_async_crr.cpp b/test/ck_tile/gemm_mx/test_mx_gemm_async_crr.cpp new file mode 100644 index 0000000000..f95b286cd3 --- /dev/null +++ b/test/ck_tile/gemm_mx/test_mx_gemm_async_crr.cpp @@ -0,0 +1,22 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_mx_gemm_pipeline_kernel_types.hpp" +#include "test_mx_gemm_pipeline_util.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileMxGemmPipelineCompAsyncCRR + : public TestCkTileMxGemmPipeline> +{ + public: + static constexpr bool check_data_type() { return true; } +}; + +#define TEST_SUITE_NAME TestCkTileMxGemmPipelineCompAsyncCRR + +TYPED_TEST_SUITE(TestCkTileMxGemmPipelineCompAsyncCRR, KernelTypesMxGemmCompAsyncCRR); + +#include "test_mx_gemm_pipeline_tr_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_async_rcr.cpp b/test/ck_tile/gemm_mx/test_mx_gemm_async_rcr.cpp new file mode 100644 index 0000000000..0528295138 --- /dev/null +++ b/test/ck_tile/gemm_mx/test_mx_gemm_async_rcr.cpp @@ -0,0 +1,51 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_mx_gemm_pipeline_kernel_types.hpp" +#include "test_mx_gemm_pipeline_util.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileMxGemmPipelineCompAsyncRCR + : public TestCkTileMxGemmPipeline> +{ + public: + static constexpr bool check_data_type() { return true; } +}; + +#define TEST_SUITE_NAME TestCkTileMxGemmPipelineCompAsyncRCR + +TYPED_TEST_SUITE(TestCkTileMxGemmPipelineCompAsyncRCR, KernelTypesMxGemmCompAsyncRCR); + +#include "test_mx_gemm_pipeline_ut_cases.inc" + +TYPED_TEST(TEST_SUITE_NAME, MNPadding) +{ + if constexpr(TestFixture::PipelineType == MxGemmPipelineType::WeightPreshuffle || + TestFixture::PipelineType == MxGemmPipelineType::CompEightWaves) + { + return; + } + + std::vector Ms{96, 160, 224}; + std::vector Ns{96, 160, 224}; + std::vector Ks; + // K must be multiple of ScaleBlockSize (16 or 32) and K_Tile + for(auto K_count : {2, 3, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + for(int M : Ms) + { + for(int N : Ns) + { + for(int K : Ks) + { + this->template Run(M, N, K); + } + } + } +} + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_async_rcr_large_cases.cpp b/test/ck_tile/gemm_mx/test_mx_gemm_async_rcr_large_cases.cpp new file mode 100644 index 0000000000..faa7b4be8d --- /dev/null +++ b/test/ck_tile/gemm_mx/test_mx_gemm_async_rcr_large_cases.cpp @@ -0,0 +1,29 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_mx_gemm_pipeline_kernel_types.hpp" +#include "test_mx_gemm_pipeline_util.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileMxGemmPipelineCompAsyncRCR + : public TestCkTileMxGemmPipeline> +{ + public: + static constexpr bool check_data_type() { return true; } +}; + +#define TEST_SUITE_NAME TestCkTileMxGemmPipelineCompAsyncRCR + +TYPED_TEST_SUITE(TestCkTileMxGemmPipelineCompAsyncRCR, KernelTypesMxGemmCompAsyncRCRLargeCases); + +TYPED_TEST(TEST_SUITE_NAME, Large) +{ + int M = 6422528; + int N = 6144; + int K = 1024; + + this->RunAllGpu(M, N, K); +} + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_async_rrr.cpp b/test/ck_tile/gemm_mx/test_mx_gemm_async_rrr.cpp new file mode 100644 index 0000000000..490326a78f --- /dev/null +++ b/test/ck_tile/gemm_mx/test_mx_gemm_async_rrr.cpp @@ -0,0 +1,22 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_mx_gemm_pipeline_kernel_types.hpp" +#include "test_mx_gemm_pipeline_util.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileMxGemmPipelineCompAsyncRRR + : public TestCkTileMxGemmPipeline> +{ + public: + static constexpr bool check_data_type() { return true; } +}; + +#define TEST_SUITE_NAME TestCkTileMxGemmPipelineCompAsyncRRR + +TYPED_TEST_SUITE(TestCkTileMxGemmPipelineCompAsyncRRR, KernelTypesMxGemmCompAsyncRRR); + +#include "test_mx_gemm_pipeline_tr_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_config.hpp b/test/ck_tile/gemm_mx/test_mx_gemm_config.hpp deleted file mode 100644 index 1e49fdbe71..0000000000 --- a/test/ck_tile/gemm_mx/test_mx_gemm_config.hpp +++ /dev/null @@ -1,169 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/host/kernel_launch.hpp" -#include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/ops/gemm.hpp" -#include "ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp" - -template -struct MXGemmHostArgs : ck_tile::UniversalGemmHostArgs<1, 1, 0> -{ - using Base = ck_tile::UniversalGemmHostArgs<1, 1, 0>; - - MXGemmHostArgs(const void* a_ptr, - const void* b_ptr, - void* c_ptr_, - ck_tile::index_t k_batch_, - ck_tile::index_t M_, - ck_tile::index_t N_, - ck_tile::index_t K_, - ck_tile::index_t stride_A_, - ck_tile::index_t stride_B_, - ck_tile::index_t stride_C_, - ScaleM scale_m_, - ScaleN scale_n_) - : Base({a_ptr}, - {b_ptr}, - {}, - c_ptr_, - k_batch_, - M_, - N_, - K_, - {stride_A_}, - {stride_B_}, - {}, - stride_C_), - scale_m(scale_m_), - scale_n(scale_n_) - { - } - - ScaleM scale_m; - ScaleN scale_n; -}; - -struct MxGemmConfig -{ - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 512; - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 128; - - static constexpr bool kPadM = false; - static constexpr bool kPadN = false; - static constexpr bool kPadK = false; - - static constexpr bool TransposeC = false; - static constexpr bool UseStructuredSparsity = false; - - static constexpr int kBlockPerCu = 1; - static constexpr int TileParitionerGroupNum = 8; - static constexpr int TileParitionerM01 = 4; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - static constexpr ck_tile::index_t NumWaveGroups = 1; - static constexpr bool DoubleSmemBuffer = false; - static constexpr bool Preshuffle = false; - static constexpr ck_tile::index_t BContiguousItemsPerAccess = 16; - - static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; - static constexpr bool TiledMMAPermuteN = false; -}; - -struct MX_GemmConfig16 : MxGemmConfig -{ - static constexpr ck_tile::index_t M_Tile = 64; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 256; -}; - -struct MX_GemmConfigEightWaves : MxGemmConfig -{ - static constexpr ck_tile::index_t M_Warp = 4; - static constexpr ck_tile::index_t N_Warp = 2; // NWarps == 2 for ping-pong! - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128 * N_Warp; - static constexpr ck_tile::index_t K_Tile = 128 * K_Warp; - - static constexpr int kBlockPerCu = 2; -}; - -struct MXfp4_GemmConfig16_Preshuffle : MxGemmConfig -{ - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 512; - static constexpr ck_tile::index_t K_Tile = 256; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr bool Preshuffle = true; - static constexpr ck_tile::index_t BContiguousItemsPerAccess = 32; -}; - -struct MXfp4_GemmConfig16_PermuteN : MXfp4_GemmConfig16_Preshuffle -{ - static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; - static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; -}; - -struct MXfp8_GemmConfig16_Preshuffle : MxGemmConfig -{ - // For FP8 Preshuffle: - // The theoretical functional minimum is N_Tile = N_Warp * N_Warp_Tile * NXdlPack = 4*16*2 = - // 128 . For better performance, we would choose N_Repeat = 2 which would yield N_Tile = 128 * 2 - // = 256 . Note: If we use fewer waves, the minimum theoretical N_Tile can be even smaller, - // reduced to N_Tile = 32 for 1 single wave. - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 256; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr bool Preshuffle = true; -}; - -struct MxGemmConfig32 : MxGemmConfig -{ - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 64; -}; - -struct MXfp4_GemmConfig32 : MxGemmConfig32 -{ - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 256; -}; - -struct MXfp8_GemmConfig32 : MxGemmConfig32 -{ - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 256; -}; - -// Variant with M/N padding enabled. Used to cover shapes where M/N are not multiples of -// the respective block tiles (MX_GemmConfig16 has M_Tile = 64, N_Tile = 128). K is still -// required to be a multiple of K_Tile -- the MX comp-async pipeline does not support K padding -// (see MXGemmKernel::IsSupportedArgument). -struct MXfp8_GemmConfig16_PadMN : MX_GemmConfig16 -{ - static constexpr bool kPadM = true; - static constexpr bool kPadN = true; -}; - -struct MXfp8_GemmConfig16_PermuteN : MXfp8_GemmConfig16_Preshuffle -{ - static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; - static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; -}; diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_instance.hpp b/test/ck_tile/gemm_mx/test_mx_gemm_instance.hpp deleted file mode 100644 index 84cea49af3..0000000000 --- a/test/ck_tile/gemm_mx/test_mx_gemm_instance.hpp +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include "ck_tile/host.hpp" -#include "ck_tile/ops/gemm_mx.hpp" -#include "test_mx_gemm_config.hpp" - -template -float mx_gemm_calc(const MXGemmHostArgs& args, const ck_tile::stream_config& s) -{ - using GemmShape = ck_tile::TileGemmShape< - ck_tile::sequence, - ck_tile::sequence, - ck_tile:: - sequence>; - - using MXGemmTraits = ck_tile::TileGemmUniversalTraits; - - using ComputeDataType = ADataType; - static_assert(sizeof(ComputeDataType) >= sizeof(BDataType), - "mixed_prec_gemm requires ADataType is a wider type than BDataType"); - - using MXPipelineProblem = ck_tile::UniversalGemmPipelineProblem; - - constexpr bool IsEightWave = - (GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp) == 8; - using MXGemmPipeline = std::conditional_t< - GemmConfig::Preshuffle, - ck_tile::MXGemmPreshufflePipelineAGmemBGmemCRegV1, - std::conditional_t, - ck_tile::MXGemmPipelineAgBgCrCompAsync>>; - - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; - - using GemmEpilogue = - std::conditional_t, // DsDataType - AccDataType, - CDataType, - ck_tile::tuple<>, // DsLayout - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - MXPipelineProblem::TransposeC, - false, // FixedVectorSize_ (Default) - 1>>, // VectorSizeC_ (Default) - ck_tile::CShuffleEpilogue, // DsDataType - AccDataType, - CDataType, - ck_tile::tuple<>, // DsLayout - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - MXPipelineProblem::TransposeC, - GemmConfig::NumWaveGroups, - false, - 1, - ck_tile::MXEpilogueTraits::BlockedXDLNPerWarp, - false, // DoubleSmemBuffer_ (Default) - ADataType, // AComputeDataType - BDataType, // BComputeDataType - !GemmConfig::Preshuffle>>>; // TilesPacked_ (because of - // packed scales) - - using Kernel = ck_tile::MXGemmKernel; - - auto kargs = Kernel::MakeKernelArgs(std::array{args.as_ptr}, - std::array{args.bs_ptr}, - std::array{}, - args.e_ptr, - args.k_batch, - args.M, - args.N, - args.K, - std::array{args.stride_As}, - std::array{args.stride_Bs}, - std::array{}, - args.stride_E, - args.scale_m, - args.scale_n); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error( - "MX GEMM: unsupported shape/configuration (set CK_TILE_LOGGING=1 for details)."); - } - - const auto kernel = ck_tile::make_kernel( - Kernel{}, Kernel::GridSize(kargs), Kernel::BlockSize(), 0, kargs); - - return ck_tile::launch_kernel(s, kernel); -} diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_kernel_types.hpp index 71d05e7656..ff9e42e2e4 100644 --- a/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_kernel_types.hpp @@ -10,19 +10,23 @@ #include "test_mx_gemm_pipeline_util.hpp" #include "test_mx_gemm_pipeline_prec_types.hpp" -using Row = ck_tile::tensor_layout::gemm::RowMajor; -using Col = ck_tile::tensor_layout::gemm::ColumnMajor; -using Intrawave = ck_tile::integral_constant; +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using CompTDMV1 = ck_tile::integral_constant; using CompTDMV2 = ck_tile::integral_constant; +using CompAsync = ck_tile::integral_constant; +using CompEightWaves = + ck_tile::integral_constant; +using WeightPreshuffle = + ck_tile::integral_constant; using I16 = ck_tile::number<16>; using I32 = ck_tile::number<32>; using I64 = ck_tile::number<64>; using I128 = ck_tile::number<128>; using I256 = ck_tile::number<256>; +using I512 = ck_tile::number<512>; using ClusterEnable = std::true_type; using ClusterDisable = std::false_type; @@ -33,48 +37,89 @@ using ClusterDisable = std::false_type; // ALayout, BLayout, CLayout, ADataType, BDataType, AScaleDataType, BScaleDataType, AccDataType, CDataType, M_BlockSize, N_BlockSize, K_BlockSize, M_TileSize, N_TileSize, Scheduler, PipelineType, ScaleBlockSize using KernelTypesMxGemmCompTDMWmma = ::testing::Types< // --- Scale32 (WarpTile=32, CompTDMV1) --- - std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32>, - std::tuple< Row, Col, Row, F4, F4, E5M3, E5M3, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32>, - std::tuple< Row, Col, Row, F4, F4, E4M3, E4M3, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32>, - std::tuple< Row, Col, Row, F8, F4, E8M0, E5M3, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32>, - std::tuple< Row, Col, Row, F4, F8, E5M3, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32>, - std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32>, - std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32>, + std::tuple< Row, Col, Row, F4, F4, E5M3, E5M3, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32>, + std::tuple< Row, Col, Row, F4, F4, E4M3, E4M3, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32>, + std::tuple< Row, Col, Row, F8, F4, E8M0, E5M3, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32>, + std::tuple< Row, Col, Row, F4, F8, E5M3, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32>, + std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32>, + std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32>, // --- Scale32 (WarpTile=32, CompTDMV2) --- - std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I32>, - std::tuple< Row, Col, Row, F4, F4, E4M3, E4M3, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I32>, - std::tuple< Row, Col, Row, F8, F4, E8M0, E5M3, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I32>, - std::tuple< Row, Col, Row, F4, F8, E4M3, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I32>, - std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I32>, - std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I32>, - std::tuple< Row, Row, Row, F4, F4, E5M3, E5M3, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32>, - std::tuple< Col, Row, Row, F4, F8, E5M3, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I32>, + std::tuple< Row, Col, Row, F4, F4, E4M3, E4M3, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I32>, + std::tuple< Row, Col, Row, F8, F4, E8M0, E5M3, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I32>, + std::tuple< Row, Col, Row, F4, F8, E4M3, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I32>, + std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I32>, + std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I32>, + std::tuple< Row, Row, Row, F4, F4, E5M3, E5M3, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32>, + std::tuple< Col, Row, Row, F4, F8, E5M3, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32>, // --- Scale16 (WarpTile=16, CompTDMV1) --- - std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV1, I16>, - std::tuple< Row, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV1, I16>, - std::tuple< Row, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV1, I16>, - std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV1, I16>, - std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV1, I16>, // RRR (non-RCR) layout + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV1, I16>, + std::tuple< Row, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV1, I16>, + std::tuple< Row, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV1, I16>, + std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV1, I16>, + std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV1, I16>, // RRR (non-RCR) layout // --- Scale16 (WarpTile=32, CompTDMV1) --- - std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I16>, - std::tuple< Row, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I16>, - std::tuple< Row, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I16>, - std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I16>, - std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I16>, // RRR (non-RCR) layout + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I16>, + std::tuple< Row, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I16>, + std::tuple< Row, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I16>, + std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I16>, + std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I16>, // RRR (non-RCR) layout // --- Scale16 (WarpTile=16, CompTDMV2) --- - std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV2, I16>, - std::tuple< Row, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV2, I16>, - std::tuple< Row, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV2, I16>, - std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV2, I16>, - std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV2, I16>, // RRR (non-RCR) layout + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV2, I16>, + std::tuple< Row, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV2, I16>, + std::tuple< Row, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV2, I16>, + std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV2, I16>, + std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV2, I16>, // RRR (non-RCR) layout // --- Scale16 (WarpTile=32, CompTDMV2) --- - std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I16>, - std::tuple< Row, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I16>, - std::tuple< Row, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I16>, - std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I16>, - std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I16>, // RRR (non-RCR) layout + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I16>, + std::tuple< Row, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I16>, + std::tuple< Row, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I16>, + std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I16>, + std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I16>, // RRR (non-RCR) layout // --- Scale32 cluster launch (from develop; ScaleBlockSize=I32 at idx 16, ClusterEnable at idx 17) --- - std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32, ClusterEnable>, - std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I32, ClusterEnable> + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32, std::false_type, ClusterEnable>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I32, std::false_type, ClusterEnable> >; + +using KernelTypesMxGemmCompAsyncRCR = ::testing::Types< + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I256, I16, I16, CompAsync, I32>, + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, I64, I64, I256, I16, I16, CompAsync, I32>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I128, I256, I128, I16, I16, CompEightWaves, I32>, + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, I128, I256, I128, I16, I16, CompEightWaves, I32>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I128, I256, I256, I16, I16, WeightPreshuffle, I32>, + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, I128, I512, I256, I16, I16, WeightPreshuffle, I32>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I128, I256, I256, I16, I16, WeightPreshuffle, I32, std::true_type>, + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, I128, I256, I256, I16, I16, WeightPreshuffle, I32, std::true_type>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>, + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>, + std::tuple< Row, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>, + std::tuple< Row, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32> +>; + +using KernelTypesMxGemmCompAsyncRCRLargeCases = ::testing::Types< + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, I64, I64, I256, I16, I16, CompAsync, I32> +>; + +using KernelTypesMxGemmCompAsyncRRR = ::testing::Types< + std::tuple< Row, Row, Row, F8, F8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>, + std::tuple< Row, Row, Row, F4, F4, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>, + std::tuple< Row, Row, Row, BF8, BF8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>, + std::tuple< Row, Row, Row, F8, BF8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32> +>; + +using KernelTypesMxGemmCompAsyncCRR = ::testing::Types< + std::tuple< Col, Row, Row, F8, F8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>, + std::tuple< Col, Row, Row, F4, F4, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>, + std::tuple< Col, Row, Row, BF8, BF8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>, + std::tuple< Col, Row, Row, F8, BF8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32> +>; + +using KernelTypesMxGemmCompAsyncCCR = ::testing::Types< + std::tuple< Col, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>, + std::tuple< Col, Col, Row, F4, F4, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>, + std::tuple< Col, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>, + std::tuple< Col, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32> +>; + // clang-format on diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_tr_cases.inc b/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_tr_cases.inc new file mode 100644 index 0000000000..1d5e359935 --- /dev/null +++ b/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_tr_cases.inc @@ -0,0 +1,27 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +TYPED_TEST(TEST_SUITE_NAME, Regular) +{ + std::vector Ms{128, 256}; + std::vector Ns{128, 256}; + std::vector Ks; + // K must be multiple of ScaleBlockSize (16 or 32) and K_Tile + for(auto K_count : {1, 2, 3, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + for(int M : Ms) + { + for(int N : Ns) + { + for(int K : Ks) + { + this->Run(M, N, K); + } + } + } +} diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_ut_cases.inc b/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_ut_cases.inc index f579766cf6..a4e62ad3f1 100644 --- a/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_ut_cases.inc +++ b/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_ut_cases.inc @@ -14,7 +14,7 @@ TYPED_TEST(TEST_SUITE_NAME, SingleTile) TYPED_TEST(TEST_SUITE_NAME, SmallM) { std::vector Ms{1, 2, 4, 8, 16}; - constexpr int N = 64; + constexpr int N = TestFixture::N_Tile; std::vector Ks; // K must be multiple of ScaleBlockSize (16 or 32) and K_Tile for(auto K_count : {2, 3, 4}) @@ -34,7 +34,7 @@ TYPED_TEST(TEST_SUITE_NAME, SmallM) TYPED_TEST(TEST_SUITE_NAME, MidLargeM) { std::vector Ms{32, 64, 128, 256}; - std::vector Ns{96, 128}; // 96 tests non-tile-aligned N + std::vector Ns{TestFixture::N_Tile}; std::vector Ks; for(auto K_count : {2, 3, 4}) { diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_util.hpp b/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_util.hpp index d6a6217b55..361921ce45 100644 --- a/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_util.hpp @@ -11,6 +11,7 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/gemm/kernel/mx_gemm_kernel.hpp" #include "ck_tile/core/numeric/math.hpp" +#include "ck/library/utility/gpu_verification.hpp" template static constexpr inline auto is_row_major(Layout layout_) @@ -44,9 +45,9 @@ constexpr ck_tile::index_t get_k_warp_tile() #endif #else if constexpr(M_Warp_Tile == 32) - return 16; + return 64; else - return 32; + return 128; #endif } @@ -68,10 +69,64 @@ auto calculate_rtol_atol(const ck_tile::index_t K, return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } +// Deterministic per-element hash RNG for GPU data init. Returns a float in [-3, 3). +// The generic `fill_tensor_uniform_rand_fp_values` filler is NOT valid for ck_tile::pk_fp4_t +// (it converts a single float and duplicates it into both nibbles, and special-cases only the +// classic ck::f4x2_pk_t). We need two independent fp4 values per byte, so we fill directly. +// The narrow [-3,3) range keeps the fp16 GEMM output from overflowing at K up to 4096 (with the +// [0.25,1.0] scales used in RunAllGpu, worst case K*9 = 36864 < 65504). +__device__ inline float mx_fp4_fill_rand(unsigned int seed, unsigned long long idx) +{ + // splitmix64-style avalanche; deterministic given (seed, idx). + unsigned long long z = (idx + 1ULL) * 0x9E3779B97F4A7C15ULL + + static_cast(seed) * 0xD1B54A32D192ED03ULL; + z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ULL; + z = (z ^ (z >> 27)) * 0x94D049BB133111EBULL; + z ^= z >> 31; + const float u = + static_cast((z >> 40) & 0xFFFFFFULL) / static_cast(0x1000000); // [0,1) + return u * 6.0f - 3.0f; // [-3,3) +} + +// Fill a packed-fp4 buffer with two independent, deterministic random fp4 values per byte. +// `num_packed` is the number of pk_fp4_t elements (= total fp4 values / 2). +__global__ void +fill_pk_fp4_uniform_kernel(ck_tile::pk_fp4_t* __restrict__ ptr, long num_packed, unsigned int seed) +{ + const long idx0 = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const long nthr = static_cast(gridDim.x) * blockDim.x; + for(long i = idx0; i < num_packed; i += nthr) + { + const float lo_f = rintf(mx_fp4_fill_rand(seed, static_cast(i) * 2ULL)); + const float hi_f = + rintf(mx_fp4_fill_rand(seed, static_cast(i) * 2ULL + 1ULL)); + const auto lo = ck_tile::float_to_mxfp4(lo_f, 1.0f); + const auto hi = ck_tile::float_to_mxfp4(hi_f, 1.0f); + ptr[i] = ck_tile::pk_fp4_t::_pack(lo, hi); + } +} + +inline void fill_pk_fp4_uniform(ck_tile::pk_fp4_t* ptr, + long num_packed, + unsigned int seed, + hipStream_t stream = nullptr) +{ + constexpr int threads = 256; + constexpr long max_blocks = 65536; // grid-stride cap + const long needed = (num_packed + threads - 1) / threads; + const long blocks = needed < max_blocks ? needed : max_blocks; + fill_pk_fp4_uniform_kernel<<(blocks)), dim3(threads), 0, stream>>>( + ptr, num_packed, seed); + ck_tile::hip_check_error(hipGetLastError()); +} + enum struct MxGemmPipelineType { CompTDMV1, - CompTDMV2 + CompTDMV2, + CompAsync, + CompEightWaves, + WeightPreshuffle }; template @@ -95,88 +150,93 @@ struct MxGemmPipelineTypeSelector static constexpr auto GetName() { return "GemmPipelineAgBgCrCompTDMV2"; } }; -template +template +struct MxGemmPipelineTypeSelector +{ + using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompAsync; + using pipeline = ck_tile::GemmPipelineAgBgCrCompAsync; + + static constexpr auto GetName() { return "GemmPipelineAgBgCrCompAsync"; } +}; + +template +struct MxGemmPipelineTypeSelector +{ + using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + using pipeline = ck_tile::GemmPipelineAgBgCrCompAsyncEightWaves; + + static constexpr auto GetName() { return "GemmPipelineAgBgCrCompEightWaves"; } +}; + +template +struct MxGemmPipelineTypeSelector +{ + using base_pipeline = ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2; + using pipeline = ck_tile::MXGemmPreshufflePipelineAGmemBGmemCRegV1; + + static constexpr auto GetName() { return "GemmPipelineAgBgCrWeightPreshuffle"; } +}; + +template struct MxGemmEpilogueTypeSelector +{ +}; + +template +struct MxGemmEpilogueTypeSelector { using epilogue = ck_tile::TdmEpilogue; }; +template +struct MxGemmEpilogueTypeSelector +{ + using epilogue = ck_tile::TdmEpilogue; +}; + +template +struct MxGemmEpilogueTypeSelector +{ + using epilogue = ck_tile::CShuffleEpilogue; +}; + +template +struct MxGemmEpilogueTypeSelector +{ + using epilogue = ck_tile::CShuffleEpilogue; +}; + +template +struct MxGemmEpilogueTypeSelector +{ + using epilogue = std::conditional_t, + ck_tile::CShuffleEpilogue>; +}; + template struct MxGemmPipelineDefaultParams { static constexpr bool PadM = false; static constexpr bool PadN = false; static constexpr bool PadK = false; - static constexpr bool Preshuffle = false; + static constexpr bool Preshuffle = PT == MxGemmPipelineType::WeightPreshuffle; }; -/// @brief Pre-shuffle scale buffer for gfx1250 wmma mx scale instruction. -/// -/// Reorganizes the scale data from row-major (MN x K) layout to the hardware-specific -/// layout expected by the gfx1250 wmma instruction. -/// -/// @tparam ScaleType Scale data type (e.g., e8m0_t) -/// @tparam ScaleBlockSize The block size for microscaling (e.g., 32) -/// @tparam KStride Whether K is the fast-moving dimension -template -void preShuffleScaleBuffer_gfx1250(const ScaleType* src, - ScaleType* dst, - ck_tile::index_t MN, - ck_tile::index_t K) +template +struct Config { - static_assert((ScaleBlockSize == 32 || ScaleBlockSize == 16) && sizeof(ScaleType) == 1, - "wrong! only support 8-bit scale with ScaleBlockSize=32 or 16"); - - // ScaleBlockSize == 16: the natural row-major scale layout already matches the gfx1250 - // wmma scale distribution (one e8m0 per 16 K-elements lands warp-aligned), so the - // device-side shuffle is the identity transform for all K. - if constexpr(ScaleBlockSize == 16) - { - for(ck_tile::index_t mn = 0; mn < MN; ++mn) - for(ck_tile::index_t k = 0; k < K; ++k) - { - if constexpr(KStride) - dst[mn * K + k] = src[mn * K + k]; - else - dst[mn * K + k] = src[k * MN + mn]; - } - return; - } - - constexpr ck_tile::index_t MPerXdlops = 16; - constexpr ck_tile::index_t KPerXdlops = 128; - - int MNPack = 2; - int KPack = 1; - - int MNStep = MPerXdlops; - int KStep = KPerXdlops / ScaleBlockSize; - - int K0 = K / KPack / KStep; - - for(int mn = 0; mn < MN; ++mn) - { - int iMNRepeat = mn / (MNStep * MNPack); - int tempmn = mn % (MNStep * MNPack); - - for(int k = 0; k < K; ++k) - { - int iKRepeat = k / (KStep * KPack); - int tempk = k % (KStep * KPack); - - int outputIndex = (iMNRepeat * MNPack * MNStep) * (KStep * KPack * K0) + - (iKRepeat * KStep * KPack) * (MNStep * MNPack) + - tempmn * (KStep * KPack) + tempk; - - if constexpr(KStride) - { - dst[outputIndex] = src[mn * K + k]; - } - else - dst[outputIndex] = src[k * MN + mn]; - } - } -} + static constexpr ck_tile::index_t N_Warp_Tile = N_Warp_Tile_; + static constexpr ck_tile::index_t K_Warp_Tile = K_Warp_Tile_; + static constexpr ck_tile::index_t N_Tile = N_Tile_; + static constexpr ck_tile::index_t N_Warp = N_Warp_; + static constexpr ck_tile::index_t BContiguousItemsPerAccess = + std::is_same_v ? 32 : 16; +}; template class TestCkTileMxGemmPipeline : public ::testing::Test @@ -191,8 +251,10 @@ class TestCkTileMxGemmPipeline : public ::testing::Test using BScaleDataType = std::tuple_element_t<6, Tuple>; using AccDataType = std::tuple_element_t<7, Tuple>; using CDataType = std::tuple_element_t<8, Tuple>; - static constexpr auto Scheduler = std::tuple_element_t<14, Tuple>::value; - static constexpr auto PipelineType = std::tuple_element_t<15, Tuple>::value; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr auto PipelineType = std::tuple_element_t<14, Tuple>::value; + static constexpr bool PermuteN = + ck_tile::tuple_element_or_default_t::value; static constexpr ck_tile::index_t M_Tile = std::tuple_element_t<9, Tuple>{}; static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<10, Tuple>{}; @@ -213,17 +275,21 @@ class TestCkTileMxGemmPipeline : public ::testing::Test static constexpr bool ClusterLaunch = ck_tile::tuple_element_or_default_t::value; - static constexpr ck_tile::index_t ScaleBlockSize = std::tuple_element_t<16, Tuple>{}; + static constexpr ck_tile::index_t ScaleBlockSize = std::tuple_element_t<15, Tuple>{}; + + static constexpr ck_tile::index_t M_Warp = + PipelineType == MxGemmPipelineType::WeightPreshuffle + ? 1 + : (PipelineType == MxGemmPipelineType::CompEightWaves ? 4 : 2); + static constexpr ck_tile::index_t N_Warp = + PipelineType == MxGemmPipelineType::WeightPreshuffle ? 4 : 2; + static constexpr ck_tile::index_t K_Warp = 1; protected: template void invoke_mx_gemm(const ck_tile::MxGemmHostArgs<1, 1, 0>& args, const ck_tile::stream_config& s) { - constexpr ck_tile::index_t M_Warp = 2; - constexpr ck_tile::index_t N_Warp = 2; - constexpr ck_tile::index_t K_Warp = 1; - // if cluster launch is enabled, set cluster dim to 2x2x1 constexpr ck_tile::index_t kClusterSizeM = std::conditional_t, ck_tile::number<1>>{}; @@ -240,11 +306,13 @@ class TestCkTileMxGemmPipeline : public ::testing::Test constexpr bool DoubleSmemBuffer = true; // TDM pipeline requires double smem buffer #if defined(CK_USE_GFX1250) + constexpr ck_tile::index_t BlockedXDLNPerWarp = 1; constexpr bool TransposeC = std::is_same_v && M_Warp_Tile == N_Warp_Tile; -#else - constexpr bool TransposeC = false; +#elif defined(CK_USE_GFX950) + constexpr ck_tile::index_t BlockedXDLNPerWarp = Preshuffle ? 2 : 1; + constexpr bool TransposeC = false; #endif static constexpr bool StructuredSparsity = false; static constexpr bool NumWaveGroup = 1; @@ -302,8 +370,26 @@ class TestCkTileMxGemmPipeline : public ::testing::Test using GemmPipeline = typename MxGemmPipelineTypeSelector::pipeline; - using GemmEpilogue = typename MxGemmEpilogueTypeSelector< - PipelineType, + using GemmEpilogueProblem = std::conditional_t< + PipelineType == MxGemmPipelineType::WeightPreshuffle && PermuteN, + ck_tile::PermuteNEpilogueProblem, /*VectorSizeC_*/ ck_tile::CShuffleEpilogueProblem>::epilogue; + 1, /*kNumWaveGroups_*/ + false, /*FixedVectorSize_*/ + 1, /*VectorSizeC_*/ + BlockedXDLNPerWarp, /*BlockedXDLN_PerWarp_*/ + DoubleSmemBuffer, /*DoubleSmemBuffer*/ + AComputeDataType, /*AComputeDataType_*/ + BComputeDataType, /*BComputeDataType_*/ + !preshuffle>>; + + using GemmEpilogue = typename MxGemmEpilogueTypeSelector::epilogue; using Kernel = ck_tile::MxGemmKernel; auto kargs = Kernel::MakeKernelArgs(args); @@ -360,12 +451,28 @@ class TestCkTileMxGemmPipeline : public ::testing::Test } public: + std::vector k_batches_; + void SetUp() override { if constexpr(!Derived::check_data_type()) { GTEST_SKIP() << "Unsupported data type combination for mx_gemm pipeline test."; } + // for TDM it's used tdm_epilogue which don't support split-k + if constexpr(PipelineType == MxGemmPipelineType::CompTDMV1 || + PipelineType == MxGemmPipelineType::CompTDMV2 || + std::is_same_v || + std::is_same_v) + { + // Only do k_batch = 1 + k_batches_ = {1}; + } + else + { + // Otherwise, use k_batch = 1 and 2 + k_batches_ = {1, 2}; + } } template ::PadM, @@ -381,7 +488,15 @@ class TestCkTileMxGemmPipeline : public ::testing::Test { if constexpr(Derived::check_data_type()) { - RunSingle(M, N, K, StrideA, StrideB, StrideC, 1); + for(auto kb : k_batches_) + { + // skip test when split k' number is not evenly distributed + if((K / K_Tile) % kb != 0) + { + continue; + } + RunSingle(M, N, K, StrideA, StrideB, StrideC, kb); + } } } @@ -422,16 +537,18 @@ class TestCkTileMxGemmPipeline : public ::testing::Test // so M must be padded to at least MNPack * MPerXdlops = 32. constexpr index_t ScaleShuffleAlign = 32; const index_t scale_padded_M = integer_least_multiple( - static_cast(M), - static_cast(ck_tile::max(M_Warp_Tile, ScaleShuffleAlign))); + static_cast(M), static_cast(ck_tile::max(M_Tile, ScaleShuffleAlign))); HostTensor scale_a( {static_cast(scale_padded_M), static_cast(num_scale_k)}, {static_cast(num_scale_k), static_cast(1)}); - // scale_b uses N as first dimension (col-major like B) + const index_t scale_padded_N = integer_least_multiple( + static_cast(N), static_cast(ck_tile::max(N_Tile, ScaleShuffleAlign))); + // Pre-shuffle interleaves 2 K-lanes (MNPack=2) with MPerXdlops=16 stride, + // so N must be padded to at least MNPack * NPerXdlops = 32. HostTensor scale_b( - {static_cast(N), static_cast(num_scale_k)}, + {static_cast(scale_padded_N), static_cast(num_scale_k)}, {static_cast(num_scale_k), static_cast(1)}); // Fill data @@ -485,38 +602,112 @@ class TestCkTileMxGemmPipeline : public ::testing::Test } // Pre-shuffle scale buffers for the hardware +#if defined(CK_USE_GFX1250) + static constexpr index_t NXdlPackEff = 1; + HostTensor scale_a_shuffled( {static_cast(scale_padded_M), static_cast(num_scale_k)}, {static_cast(num_scale_k), static_cast(1)}); HostTensor scale_b_shuffled( - {static_cast(N), static_cast(num_scale_k)}, + {static_cast(scale_padded_N), static_cast(num_scale_k)}, {static_cast(num_scale_k), static_cast(1)}); // Pre-shuffle for gfx1250 (WaveSize=32, WMMA) // Scales start in natural tensor layout and are pre-shuffled into the device layout // for both scale block sizes (the shuffle is the identity for ScaleBlockSize==16, // whose natural layout already matches the warp scale distribution). - preShuffleScaleBuffer_gfx1250( + ck_tile::preShuffleScaleBuffer_gfx1250( scale_a.mData.data(), scale_a_shuffled.mData.data(), scale_padded_M, num_scale_k); - preShuffleScaleBuffer_gfx1250( - scale_b.mData.data(), scale_b_shuffled.mData.data(), N, num_scale_k); + ck_tile::preShuffleScaleBuffer_gfx1250( + scale_b.mData.data(), scale_b_shuffled.mData.data(), scale_padded_N, num_scale_k); +#elif defined(CK_USE_GFX950) + constexpr ck_tile::index_t MPerXdl = M_Warp_Tile; + constexpr ck_tile::index_t NPerXdl = N_Warp_Tile; + constexpr ck_tile::index_t KPerXdl = K_Warp_Tile; + constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * MPerXdl); + constexpr ck_tile::index_t NIterPerWarp = N_Tile / (N_Warp * NPerXdl); + constexpr ck_tile::index_t KIterPerWarp = K_Tile / KPerXdl; + + constexpr ck_tile::index_t MXdlPackEff = + (MIterPerWarp >= 2 && MIterPerWarp % 2 == 0) ? 2 : 1; + constexpr ck_tile::index_t NXdlPackEff = + (NIterPerWarp >= 2 && NIterPerWarp % 2 == 0) ? 2 : 1; + constexpr ck_tile::index_t KXdlPackEff = + (KIterPerWarp >= 2 && KIterPerWarp % 2 == 0) ? 2 : 1; + + constexpr ck_tile::index_t XdlMNThread = M_Warp_Tile; + constexpr ck_tile::index_t XdlKThread = 64 / XdlMNThread; + + HostTensor scale_a_shuffled( + {static_cast(scale_padded_M / MXdlPackEff * 2), + static_cast(num_scale_k / KXdlPackEff * 2)}, + {static_cast(num_scale_k / KXdlPackEff * 2), static_cast(1)}); + + HostTensor scale_b_shuffled( + {static_cast(scale_padded_N / NXdlPackEff * 2), + static_cast(num_scale_k / KXdlPackEff * 2)}, + {static_cast(num_scale_k / KXdlPackEff * 2), static_cast(1)}); + + ck_tile::preShuffleScaleBuffer_gfx950( + scale_a.mData.data(), scale_a_shuffled.mData.data(), scale_padded_M, num_scale_k, true); + + if constexpr(PipelineType == MxGemmPipelineType::WeightPreshuffle && PermuteN) + { + ck_tile::preShuffleScaleBufferPermuteN_gfx950( + scale_b.mData.data(), + scale_b_shuffled.mData.data(), + scale_padded_N, + num_scale_k, + true); + } + else + { + ck_tile:: + preShuffleScaleBuffer_gfx950( + scale_b.mData.data(), + scale_b_shuffled.mData.data(), + scale_padded_N, + num_scale_k, + true); + } +#endif // Allocate device memory DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); - DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); DeviceMem scale_a_dev_buf(scale_a_shuffled.get_element_space_size_in_bytes()); DeviceMem scale_b_dev_buf(scale_b_shuffled.get_element_space_size_in_bytes()); // Upload data to device a_m_k_dev_buf.ToDevice(a_m_k.data()); - b_k_n_dev_buf.ToDevice(b_k_n.data()); c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); scale_a_dev_buf.ToDevice(scale_a_shuffled.data()); scale_b_dev_buf.ToDevice(scale_b_shuffled.data()); + using GemmConfig = Config; + + const auto b_host_for_dev = [&]() { + if constexpr(Preshuffle) + { + if constexpr(PermuteN) + { + return ck_tile::shuffle_b_permuteN(b_k_n); + } + else + { + return ck_tile::shuffle_b(b_k_n); + } + } + else + { + return b_k_n; + } + }(); + DeviceMem b_k_n_dev_buf(b_host_for_dev.get_element_space_size_in_bytes()); + b_k_n_dev_buf.ToDevice(b_host_for_dev.data()); + // Create MxGemmHostArgs ck_tile::MxGemmHostArgs<1, 1, 0> args( {static_cast(a_m_k_dev_buf.GetDeviceBuffer())}, @@ -546,7 +737,14 @@ class TestCkTileMxGemmPipeline : public ::testing::Test {static_cast(1), static_cast(num_scale_k)}); // Copy scale_b data (our scale_b is (N, num_scale_k) row-major, // reference expects (num_scale_k, N) col-major, which is the same memory layout) - std::copy(scale_b.mData.begin(), scale_b.mData.end(), scale_b_ref.mData.begin()); + // Truncate scale_a to actual N (not padded) + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < num_scale_k; ++k) + { + scale_b_ref(k, n) = scale_b(n, k); + } + } // Truncate scale_a to actual M (not padded) HostTensor scale_a_ref( @@ -582,4 +780,275 @@ class TestCkTileMxGemmPipeline : public ::testing::Test rtol_atol.at(number<1>{})); EXPECT_TRUE(pass); } + + // All-GPU validation path for the fp4 (pk_fp4_t) MX GEMM. + // + // Unlike Run(), this never materializes the A/B/C tensors on the host: + // - A/B are generated directly on device with a deterministic fp4 fill. + // - the reference is computed on device by reference_mx_gemm_gpu. + // - the comparison is done on device by ck::profiler::gpu_verify. + // Only the tiny e8m0 scales touch the host (for pre-shuffle + an unshuffled copy that the + // device reference consumes). + void RunAllGpu(const int M, const int N, const int K, const int kbatch = 1) + { + if constexpr(!Derived::check_data_type()) + return; + + static_assert(std::is_same_v && + std::is_same_v, + "RunAllGpu currently supports pk_fp4_t A/B only."); + // The GPU reference (reference_mx_gemm_gpu) hardcodes these layouts; guard so it cannot be + // silently misused with a layout it does not handle. + static_assert(std::is_same_v && + std::is_same_v && + std::is_same_v, + "RunAllGpu / reference_mx_gemm_gpu assume RowMajor-A, ColumnMajor-B, " + "RowMajor-C."); + + static_assert(PipelineType != MxGemmPipelineType::WeightPreshuffle); + +#if !defined(CK_USE_GFX950) + (void)M; + (void)N; + (void)K; + (void)kbatch; + GTEST_SKIP() << "RunAllGpu requires CK_USE_GFX950."; +#else + using namespace ck_tile::literals; + constexpr long kIntMax = 2147483647L; // INT_MAX + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + if constexpr(std::is_same_v) + return col; + else + return row; + } + else + return stride; + }; + + constexpr ck_tile::index_t psize = ck_tile::numeric_traits::PackedSize; // 2 + static_assert(psize == 2, + "RunAllGpu byte-sizing and reference_mx_gemm_kernel's a_ptr[a_lin/2] " + "addressing assume pk_fp4_t PackedSize == 2."); + + bool pass = true; + long total_MN = 0; + + // Strides are K/N here (small); keep them as index_t to match the kernel args, and + // make the size_t->index_t narrowing explicit. + const ck_tile::index_t stride_A = + static_cast(f_get_default_stride(M, K, 0, ALayout{})); // K + const ck_tile::index_t stride_B = + static_cast(f_get_default_stride(K, N, 0, BLayout{})); // K + const ck_tile::index_t stride_C = + static_cast(f_get_default_stride(M, N, 0, CLayout{})); // N + + ASSERT_EQ(K % ScaleBlockSize, 0) << "K must be a multiple of ScaleBlockSize for MX GEMM"; + const ck_tile::index_t num_scale_k = K / ScaleBlockSize; + ASSERT_EQ(num_scale_k % (K_Warp_Tile / ScaleBlockSize), 0) + << "K must be a multiple of K_Warp_Tile (" << K_Warp_Tile + << ") for MX GEMM. Pad the scale data."; + const ck_tile::index_t scale_padded_M = ck_tile::integer_least_multiple( + static_cast(M), static_cast(M_Tile)); + + // int32-safety: the property under test for the M-decomposition. The predicate is + // "largest 0-based element offset fits in a signed 32-bit int", i.e. offset <= INT_MAX. + const long MN = static_cast(M) * N; + const long A_elems = static_cast(M) * K; + const long B_elems = static_cast(K) * N; + const long C_off = static_cast(M - 1) * stride_C + (N - 1); + const long A_off = static_cast(M - 1) * stride_A + (K - 1); + const long B_off = static_cast(N - 1) * stride_B + (K - 1); + const long c_bytes = MN * static_cast(sizeof(CDataType)); + std::cout << "[int32-safety] M=" << M << " N=" << N << " K=" << K << " M*N=" << MN + << " A_elems=" << A_elems << " B_elems=" << B_elems << " C_off=" << C_off + << " A_off=" << A_off << " B_off=" << B_off << " C_bytes=" << c_bytes + << " (INT_MAX=" << kIntMax << ")" << std::endl; + // Note (not an assert): the C *byte* span can exceed INT_MAX even when the element + // count is int32-safe. We deliberately let the run proceed -- if any internal byte + // offset overflows, gpu_verify will flag it, which is exactly what we want to discover. + if(c_bytes > kIntMax) + std::cout << "[int32-safety][note] C byte span (" << c_bytes + << ") exceeds INT_MAX; if verification fails, byte-offset overflow is the " + "prime suspect." + << std::endl; + ASSERT_LE(B_off, kIntMax) << "B offset exceeds INT_MAX"; + total_MN += MN; + + const long a_bytes = (A_elems + psize - 1) / psize; + const long b_bytes = (B_elems + psize - 1) / psize; + + // Bound peak device memory (A + B + 2*C + scales). Skip cleanly rather + // than aborting via hip_check_error if the device cannot hold test shapes. + { + std::size_t free_b = 0, total_b = 0; + ck_tile::hip_check_error(hipMemGetInfo(&free_b, &total_b)); + const std::size_t need = static_cast(a_bytes) + + static_cast(b_bytes) + + 2u * static_cast(c_bytes) + (64u << 20); + if(free_b < need) + GTEST_SKIP() << "insufficient device memory (need " << need << " B, free " << free_b + << " B)"; + } + + auto a_dev = std::make_unique(static_cast(a_bytes)); + auto b_dev = std::make_unique(static_cast(b_bytes)); + auto c_dev = std::make_unique(static_cast(c_bytes)); + auto c_ref_dev = std::make_unique(static_cast(c_bytes)); + c_dev->SetZero(); + c_ref_dev->SetZero(); + + // GPU fill A/B (deterministic, fp4-correct). Same device buffers feed both the kernel + // and the reference, so the fill need not bit-match any host RNG. + fill_pk_fp4_uniform( + reinterpret_cast(a_dev->GetDeviceBuffer()), a_bytes, 11939u); + fill_pk_fp4_uniform( + reinterpret_cast(b_dev->GetDeviceBuffer()), b_bytes, 11940u); + ck_tile::hip_check_error(hipDeviceSynchronize()); // surface fill faults at the fill site + + // e8m0 scales (tiny, host-built). The range is + // deliberately narrow ([0.25,1.0] scales, [-3,3) fp4 fill) so that K up to 4096 cannot + // overflow the fp16 output (worst case K*9 = 36864 < 65504); gpu_verify counts matched + // infinities as errors, so an overflow would otherwise be a false failure. + ck_tile::HostTensor scale_a( + {static_cast(scale_padded_M), static_cast(num_scale_k)}, + {static_cast(num_scale_k), static_cast(1)}); + ck_tile::HostTensor scale_b( + {static_cast(N), static_cast(num_scale_k)}, + {static_cast(num_scale_k), static_cast(1)}); + { + std::mt19937 gen(11941u); + std::uniform_real_distribution dist(0.25f, 1.0f); + for(auto& s : scale_a.mData) + s = AScaleDataType{dist(gen)}; + for(auto& s : scale_b.mData) + s = BScaleDataType{dist(gen)}; + } + + // gfx950 scale pre-shuffle. NOTE: this must stay in sync with the identical block in + // Run() -- the kernel-input layout and the reference-input layout must agree. + constexpr ck_tile::index_t MPerXdl = M_Warp_Tile; + constexpr ck_tile::index_t NPerXdl = N_Warp_Tile; + constexpr ck_tile::index_t KPerXdl = K_Warp_Tile; + constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * MPerXdl); + constexpr ck_tile::index_t NIterPerWarp = N_Tile / (N_Warp * NPerXdl); + constexpr ck_tile::index_t KIterPerWarp = K_Tile / KPerXdl; + + constexpr ck_tile::index_t MXdlPackEff = + (MIterPerWarp >= 2 && MIterPerWarp % 2 == 0) ? 2 : 1; + constexpr ck_tile::index_t NXdlPackEff = + (NIterPerWarp >= 2 && NIterPerWarp % 2 == 0) ? 2 : 1; + constexpr ck_tile::index_t KXdlPackEff = + (KIterPerWarp >= 2 && KIterPerWarp % 2 == 0) ? 2 : 1; + + constexpr ck_tile::index_t XdlMNThread = M_Warp_Tile; + constexpr ck_tile::index_t XdlKThread = 64 / XdlMNThread; + + ck_tile::HostTensor scale_a_shuffled( + {static_cast(scale_padded_M / MXdlPackEff * 2), + static_cast(num_scale_k / KXdlPackEff * 2)}, + {static_cast(num_scale_k / KXdlPackEff * 2), static_cast(1)}); + ck_tile::HostTensor scale_b_shuffled( + {static_cast(N / NXdlPackEff * 2), + static_cast(num_scale_k / KXdlPackEff * 2)}, + {static_cast(num_scale_k / KXdlPackEff * 2), static_cast(1)}); + + ck_tile::preShuffleScaleBuffer_gfx950( + scale_a.mData.data(), scale_a_shuffled.mData.data(), scale_padded_M, num_scale_k, true); + ck_tile::preShuffleScaleBuffer_gfx950( + scale_b.mData.data(), scale_b_shuffled.mData.data(), N, num_scale_k, true); + + // Device scale buffers: shuffled feed the kernel, unshuffled feed the reference. + auto scale_a_shuf_dev = std::make_unique( + scale_a_shuffled.get_element_space_size_in_bytes()); + auto scale_b_shuf_dev = std::make_unique( + scale_b_shuffled.get_element_space_size_in_bytes()); + scale_a_shuf_dev->ToDevice(scale_a_shuffled.data()); + scale_b_shuf_dev->ToDevice(scale_b_shuffled.data()); + + auto scale_a_ref_dev = + std::make_unique(scale_a.get_element_space_size_in_bytes()); + auto scale_b_ref_dev = + std::make_unique(scale_b.get_element_space_size_in_bytes()); + scale_a_ref_dev->ToDevice(scale_a.data()); + scale_b_ref_dev->ToDevice(scale_b.data()); + + // Launch kernel + ck_tile::MxGemmHostArgs<1, 1, 0> args( + {static_cast(a_dev->GetDeviceBuffer())}, + {static_cast(scale_a_shuf_dev->GetDeviceBuffer())}, + {static_cast(b_dev->GetDeviceBuffer())}, + {static_cast(scale_b_shuf_dev->GetDeviceBuffer())}, + {}, + c_dev->GetDeviceBuffer(), + kbatch, + M, + N, + K, + {stride_A}, + {stride_B}, + {}, + stride_C); + + invoke_mx_gemm(args, ck_tile::stream_config{nullptr, false}); + + ck_tile::hip_check_error(hipDeviceSynchronize()); + + // GPU reference on the same device A/B buffers. + ck_tile::reference_mx_gemm_gpu( + reinterpret_cast(a_dev->GetDeviceBuffer()), + reinterpret_cast(b_dev->GetDeviceBuffer()), + reinterpret_cast(scale_a_ref_dev->GetDeviceBuffer()), + reinterpret_cast(scale_b_ref_dev->GetDeviceBuffer()), + reinterpret_cast(c_ref_dev->GetDeviceBuffer()), + M, + N, + K, + num_scale_k, + ScaleBlockSize); + ck_tile::hip_check_error(hipDeviceSynchronize()); + + // GPU verify with explicit MX tolerance (auto tolerance defaults too tight for MX). + const float max_acc = ck::profiler::gpu_reduce_max(c_ref_dev->GetDeviceBuffer(), + static_cast(MN)); + // The reference must be non-degenerate, else error_count==0 is a vacuous pass. + ASSERT_GT(max_acc, 0.0f) << "GPU reference output is all-zero"; + const auto rtol_atol = + calculate_rtol_atol(K, kbatch, max_acc); + const auto res = ck::profiler::gpu_verify(c_dev->GetDeviceBuffer(), + c_ref_dev->GetDeviceBuffer(), + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{}), + static_cast(MN)); + + // Positive liveness check on the *device* output. res.all_zero ANDs device- and + // reference-zeroness, and the reference is never zero here, so it cannot detect a no-op + // kernel on its own -- reduce the device buffer directly. + const float c_dev_absmax = ck::profiler::gpu_reduce_max( + c_dev->GetDeviceBuffer(), static_cast(MN)); + + std::cout << "[verify] errors=" << res.error_count << " max_error=" << res.max_error + << " c_dev_absmax=" << c_dev_absmax << " max_acc=" << max_acc + << " rtol=" << rtol_atol.at(ck_tile::number<0>{}) + << " atol=" << rtol_atol.at(ck_tile::number<1>{}) << std::endl; + + EXPECT_EQ(res.error_count, 0ull) << "produced mismatched results"; + EXPECT_GT(c_dev_absmax, 0.0f) << "produced an all-zero device output"; + pass &= (res.error_count == 0 && c_dev_absmax > 0.0f); + + std::cout << "[int32-safety] aggregate total_M*N=" << total_MN << " (INT_MAX=" << kIntMax + << ") -> decomposition is the variable under test" << std::endl; + EXPECT_TRUE(pass); +#endif // CK_USE_GFX950 + } }; diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_util.hpp b/test/ck_tile/gemm_mx/test_mx_gemm_util.hpp deleted file mode 100644 index f566ecb38b..0000000000 --- a/test/ck_tile/gemm_mx/test_mx_gemm_util.hpp +++ /dev/null @@ -1,220 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include - -#include "ck_tile/core.hpp" -#include "ck_tile/host.hpp" -#include "ck_tile/host/check_err.hpp" -#include "ck_tile/host/reference/reference_gemm.hpp" -#include "ck_tile/host/tensor_shuffle_utils.hpp" -#include "ck_tile/host/mx_processing.hpp" -#include "test_mx_gemm_config.hpp" -#include "test_mx_gemm_instance.hpp" - -template -static constexpr auto is_row_major(Layout) -{ - return ck_tile::bool_constant< - std::is_same_v, ck_tile::tensor_layout::gemm::RowMajor>>{}; -} - -template -auto calculate_rtol_atol_mx(ck_tile::index_t K, float max_accumulated_value) -{ - using ComputeType = - std::conditional_t; - const auto rtol = ck_tile::get_relative_threshold(K); - const auto atol = ck_tile::get_absolute_threshold( - max_accumulated_value, K); - return ck_tile::make_tuple(rtol, atol); -} - -template -class TestMxGemmUtil : public ::testing::Test -{ - protected: - using ADataType = std::tuple_element_t<0, Tuple>; - using BDataType = std::tuple_element_t<1, Tuple>; - using GemmConfig = std::tuple_element_t<2, Tuple>; - using ALayout = std::tuple_element_t<3, Tuple>; - using BLayout = std::tuple_element_t<4, Tuple>; - using CLayout = std::tuple_element_t<5, Tuple>; - - using AccDataType = float; - using CDataType = ck_tile::fp16_t; - using ScaleType = ck_tile::e8m0_t; - using ScaleM = ck_tile::MXScalePointer; - using ScaleN = ck_tile::MXScalePointer; - - void - Run(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, ck_tile::index_t k_batch = 1) - { - const ck_tile::index_t scale_k_size = K / 32; - const ck_tile::index_t stride_A = - ck_tile::get_default_stride(M, K, 0, is_row_major(ALayout{})); - const ck_tile::index_t stride_B = - ck_tile::get_default_stride(K, N, 0, is_row_major(BLayout{})); - const ck_tile::index_t stride_C = - ck_tile::get_default_stride(M, N, 0, is_row_major(CLayout{})); - // Scales use fixed layouts independent of A/B layout: - // scale A is row-major [M, K/32], and scale B is column-major [K/32, N]. - const ck_tile::index_t stride_scale_a = - ck_tile::get_default_stride(M, scale_k_size, 0, ck_tile::bool_constant{}); - const ck_tile::index_t stride_scale_b = - ck_tile::get_default_stride(scale_k_size, N, 0, ck_tile::bool_constant{}); - - ck_tile::HostTensor a_host( - ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_host( - ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(BLayout{}))); - ck_tile::HostTensor c_host( - ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); - ck_tile::HostTensor scale_a_host(ck_tile::host_tensor_descriptor( - M, scale_k_size, stride_scale_a, ck_tile::bool_constant{})); - ck_tile::HostTensor scale_b_host(ck_tile::host_tensor_descriptor( - scale_k_size, N, stride_scale_b, ck_tile::bool_constant{})); - - std::mt19937 gen(42); - std::uniform_int_distribution fill_seed(0, 500); - - auto gen_scales = [&](auto& scales, float range_min, float range_max) { - // e8m0_t is basically an exponent of float32 - ck_tile::HostTensor pow2(scales.get_lengths()); - ck_tile::FillUniformDistributionIntegerValue{ - range_min, range_max, fill_seed(gen)}(pow2); - scales.ForEach([&](auto& self, const auto& i) { - self(i) = static_cast(std::exp2(pow2(i))); - }); - }; - - ck_tile::FillUniformDistribution{-2.f, 2.f, fill_seed(gen)}(a_host); - ck_tile::FillUniformDistribution{-2.f, 2.f, fill_seed(gen)}(b_host); - gen_scales(scale_a_host, -2, 2); - gen_scales(scale_b_host, -2, 2); - - // Compute effective XdlPack sizes based on GemmConfig tile dimensions - constexpr ck_tile::index_t MPerXdl = GemmConfig::M_Warp_Tile; - constexpr ck_tile::index_t NPerXdl = GemmConfig::N_Warp_Tile; - constexpr ck_tile::index_t KPerXdl = GemmConfig::K_Warp_Tile; - constexpr ck_tile::index_t MIterPerWarp = - GemmConfig::M_Tile / (GemmConfig::M_Warp * MPerXdl); - constexpr ck_tile::index_t NIterPerWarp = - GemmConfig::N_Tile / (GemmConfig::N_Warp * NPerXdl); - constexpr ck_tile::index_t KIterPerWarp = GemmConfig::K_Tile / KPerXdl; - - constexpr ck_tile::index_t MXdlPackEff = - (MIterPerWarp >= 2 && MIterPerWarp % 2 == 0) ? 2 : 1; - constexpr ck_tile::index_t NXdlPackEff = - (NIterPerWarp >= 2 && NIterPerWarp % 2 == 0) ? 2 : 1; - constexpr ck_tile::index_t KXdlPackEff = - (KIterPerWarp >= 2 && KIterPerWarp % 2 == 0) ? 2 : 1; - - constexpr ck_tile::index_t XdlMNThread = GemmConfig::M_Warp_Tile; - constexpr ck_tile::index_t XdlKThread = 64 / XdlMNThread; - - // Pack scales into int32_t for GPU consumption - auto scale_a_packed = - ck_tile::packScalesMNxK(scale_a_host, - true); - auto scale_b_packed = - ck_tile::packScalesMNxK(scale_b_host, - false); - - const auto b_host_for_device = [&]() { - if constexpr(GemmConfig::Preshuffle) - if constexpr(GemmConfig::TiledMMAPermuteN) - return ck_tile::shuffle_b_permuteN(b_host); - else - return ck_tile::shuffle_b(b_host); - else - return b_host; - }(); - - const auto scale_a_host_for_device = [&]() { - if constexpr(GemmConfig::Preshuffle) - return ck_tile::preShuffleScale(scale_a_host, true); - else - return scale_a_packed; - }(); - - constexpr ck_tile::index_t XdlNThread = GemmConfig::N_Warp_Tile; - constexpr ck_tile::index_t NPerBlock = GemmConfig::N_Tile; - constexpr ck_tile::index_t NWarp = GemmConfig::N_Warp; - - const auto scale_b_host_for_device = [&]() { - if constexpr(GemmConfig::Preshuffle) - if constexpr(GemmConfig::TiledMMAPermuteN) - return ck_tile::preShuffleScalePermuteN( - scale_b_host, false); - else - return ck_tile::preShuffleScale(scale_b_host, false); - else - return scale_b_packed; - }(); - - ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem b_dev_buf(b_host_for_device.get_element_space_size_in_bytes()); - ck_tile::DeviceMem c_dev_buf(c_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem scale_a_dev_buf( - scale_a_host_for_device.get_element_space_size_in_bytes()); - ck_tile::DeviceMem scale_b_dev_buf( - scale_b_host_for_device.get_element_space_size_in_bytes()); - - a_dev_buf.ToDevice(a_host.data()); - b_dev_buf.ToDevice(b_host_for_device.data()); - c_dev_buf.SetZero(); - scale_a_dev_buf.ToDevice(scale_a_host_for_device.data()); - scale_b_dev_buf.ToDevice(scale_b_host_for_device.data()); - - ScaleM scale_m(reinterpret_cast(scale_a_dev_buf.GetDeviceBuffer())); - ScaleN scale_n(reinterpret_cast(scale_b_dev_buf.GetDeviceBuffer())); - - MXGemmHostArgs args(a_dev_buf.GetDeviceBuffer(), - b_dev_buf.GetDeviceBuffer(), - c_dev_buf.GetDeviceBuffer(), - k_batch, - M, - N, - K, - stride_A, - stride_B, - stride_C, - scale_m, - scale_n); - - mx_gemm_calc(args, ck_tile::stream_config{}); - - c_dev_buf.FromDevice(c_host.data()); - - ck_tile::HostTensor c_ref( - ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); - c_ref.SetZero(); - ck_tile:: - reference_mx_gemm( - a_host, b_host, c_ref, scale_a_host, scale_b_host); - - const float max_accumulated_value = ck_tile::type_convert(c_ref.max()); - const auto rtol_atol = calculate_rtol_atol_mx( - K, max_accumulated_value); - const double rtol = rtol_atol.at(ck_tile::number<0>{}); - const double atol = rtol_atol.at(ck_tile::number<1>{}); - - bool pass = ck_tile::check_err(c_host, c_ref, "MX GEMM: Incorrect results!", rtol, atol); - - EXPECT_TRUE(pass); - } -}; diff --git a/test/ck_tile/grouped_gemm_mx/CMakeLists.txt b/test/ck_tile/grouped_gemm_mx/CMakeLists.txt index 42bcb2a0ae..3cd34fe00d 100644 --- a/test/ck_tile/grouped_gemm_mx/CMakeLists.txt +++ b/test/ck_tile/grouped_gemm_mx/CMakeLists.txt @@ -11,13 +11,33 @@ endif() # Currently TDM is only supported on gfx1250 if(GPU_TARGETS MATCHES "gfx1250") - add_gtest_executable(test_ck_tile_grouped_gemm_mx_tdm test_mx_grouped_gemm.cpp) + add_gtest_executable(test_ck_tile_grouped_gemm_mx_tdm test_mx_grouped_gemm_wmma_tdm.cpp) # target_compile_options(test_ck_tile_grouped_gemm_mx_tdm PRIVATE --save-temps) add_gtest_executable(test_ck_tile_grouped_gemm_mx_flatmm_tdm test_grouped_gemm_mx_flatmm_tdm.cpp) target_compile_options(test_ck_tile_grouped_gemm_mx_flatmm_tdm PRIVATE ${GROUPED_MX_FLATMM_COMPILE_OPTIONS}) endif() +if(GPU_TARGETS MATCHES "gfx950") + add_gtest_executable(test_ck_tile_grouped_gemm_mx_comp_async test_mx_grouped_gemm_comp_async.cpp) + # target_compile_options(test_ck_tile_grouped_gemm_mx_tdm PRIVATE --save-temps) + + # Large-tensor / decomposition cases. Built so it does not bitrot, but NOT + # registered with ctest (plain add_executable, no add_test) -> excluded from the default CI + # test pass; run explicitly (mirrors the CK *_large_cases convention). Allocates multi-GB + # device buffers to exercise the int32 element-count / decomposition boundary. + set(_mx_large_cases_src test_mx_grouped_gemm_comp_async_large_cases.cpp) + set_source_files_properties(${_mx_large_cases_src} PROPERTIES LANGUAGE HIP) + add_executable(test_ck_tile_grouped_gemm_mx_comp_async_large_cases ${_mx_large_cases_src}) + set_property(TARGET test_ck_tile_grouped_gemm_mx_comp_async_large_cases + PROPERTY HIP_ARCHITECTURES ${SUPPORTED_GPU_TARGETS}) + target_compile_options(test_ck_tile_grouped_gemm_mx_comp_async_large_cases + PRIVATE -Wno-global-constructors -Wno-undef) + target_link_libraries(test_ck_tile_grouped_gemm_mx_comp_async_large_cases + PRIVATE gtest_main getopt::getopt) + add_dependencies(tests test_ck_tile_grouped_gemm_mx_comp_async_large_cases) +endif() + if(GPU_TARGETS MATCHES "gfx950|gfx1250") add_gtest_executable(test_ck_tile_grouped_gemm_mx_flatmm_non_tdm test_grouped_gemm_mx_flatmm_non_tdm.cpp) diff --git a/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm.cpp b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm.cpp deleted file mode 100644 index 3b5c9d6018..0000000000 --- a/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm.cpp +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include - -#include "gtest/gtest.h" - -#include "ck_tile/host.hpp" -#include "test_mx_grouped_gemm_util.hpp" - -using F8 = ck_tile::fp8_t; -using BF8 = ck_tile::bf8_t; -using F16 = ck_tile::half_t; -using F32 = float; -using BF16 = ck_tile::bf16_t; -using Row = ck_tile::tensor_layout::gemm::RowMajor; -using Col = ck_tile::tensor_layout::gemm::ColumnMajor; -using True = ck_tile::bool_constant; -using False = ck_tile::bool_constant; -using E8M0 = ck_tile::e8m0_t; -using Intrawave = ck_tile::integral_constant; -using CompTDMV1 = ck_tile::integral_constant; -using CompTDMV2 = ck_tile::integral_constant; -template -using ScaleBS = ck_tile::integral_constant; - -// clang-format off -using KernelTypes = ::testing::Types< - // ALayout, BLayout, CLayout, ADataType, BDataType, AScaleDataType, BScaleDataType, AccDataType, CDataType, Persistent, Scheduler, PipelineType, ScaleBlockSize -std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, False, Intrawave, CompTDMV1, ScaleBS<32>>, -std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, False, Intrawave, CompTDMV1, ScaleBS<32>>, -std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, False, Intrawave, CompTDMV1, ScaleBS<32>>, -std::tuple< Col, Row, Row, F8, BF8, E8M0, E8M0, F32, F16, False, Intrawave, CompTDMV1, ScaleBS<32>>, -std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, False, Intrawave, CompTDMV2, ScaleBS<32>>, -std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, False, Intrawave, CompTDMV2, ScaleBS<32>>, -std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, False, Intrawave, CompTDMV2, ScaleBS<32>>, -std::tuple< Col, Row, Row, F8, BF8, E8M0, E8M0, F32, F16, False, Intrawave, CompTDMV2, ScaleBS<32>> ->; -// clang-format on - -TYPED_TEST_SUITE(TestCkTileMxGroupedGemm, KernelTypes); - -#include "test_mx_grouped_gemm_ut_cases.inc" diff --git a/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_comp_async.cpp b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_comp_async.cpp new file mode 100644 index 0000000000..5c5de3a1b3 --- /dev/null +++ b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_comp_async.cpp @@ -0,0 +1,26 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_mx_grouped_gemm_util.hpp" +#include "test_mx_grouped_gemm_pipeline_kernel_types.hpp" + +template +class TestCkTileMxGemmPipelineCompAsync + : public TestCkTileMxGroupedGemm> +{ + public: + static constexpr bool check_data_type() { return true; } +}; + +#define TEST_SUITE_NAME TestCkTileMxGemmPipelineCompAsync + +TYPED_TEST_SUITE(TestCkTileMxGemmPipelineCompAsync, KernelTypesMxGemmCompAsync); + +#include "test_mx_grouped_gemm_ut_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_comp_async_large_cases.cpp b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_comp_async_large_cases.cpp new file mode 100644 index 0000000000..5241812ed9 --- /dev/null +++ b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_comp_async_large_cases.cpp @@ -0,0 +1,34 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Large-tensor / decomposition cases for the fp4 (a4w4) grouped MX GEMM (ROCM-22075). +// +// This is a SEPARATE executable from test_ck_tile_grouped_gemm_mx_comp_async and is intentionally +// NOT registered with ctest (its CMake target uses add_executable, not add_gtest_executable), so +// it is excluded from the default CI test pass. It is run explicitly (mirroring the CK +// *_large_cases / RUN_*_LARGE_CASES_TESTS convention) because it allocates multi-GB device buffers +// (per-group C ~2.5 GB) to exercise the int32 element-count / decomposition boundary. + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_mx_grouped_gemm_util.hpp" +#include "test_mx_grouped_gemm_pipeline_kernel_types.hpp" + +template +class TestCkTileMxGemmPipelineCompAsync + : public TestCkTileMxGroupedGemm> +{ + public: + static constexpr bool check_data_type() { return true; } +}; + +#define TEST_SUITE_NAME TestCkTileMxGemmPipelineCompAsync + +TYPED_TEST_SUITE(TestCkTileMxGemmPipelineCompAsync, KernelTypesMxGemmCompAsync); + +#include "test_mx_grouped_gemm_largeM_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_largeM_cases.inc b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_largeM_cases.inc new file mode 100644 index 0000000000..0fc2dd1872 --- /dev/null +++ b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_largeM_cases.inc @@ -0,0 +1,67 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once + +// Large-tensor decomposition / INT_MAX element-count validation for fp4 (a4w4) grouped MX GEMM +// These cases exercise the all-GPU validation path (RunAllGpu): GPU data +// init, GPU reference, GPU compare. They apply only to the fp4 non-persistent CompAsync kernel +// row; for every other type in the suite they skip (RunAllGpu is fp4-only and the discarded +// if-constexpr branch is never instantiated for non-fp4 types). +// +// These tests are deliberately NOT registered with ctest (see the dedicated *_large_cases target +// in CMakeLists.txt) so they do not run in the default CI test pass; they are run explicitly +// (mirroring the CK *_large_cases / RUN_*_LARGE_CASES_TESTS convention). + +// Stage A: cross-validate the new GPU reference against the trusted host reference on identical +// small shapes. Both Run() (host reference + check_err) and RunAllGpu() (GPU reference + +// gpu_verify) must pass before the GPU reference can be trusted at large scale. +TYPED_TEST(TEST_SUITE_NAME, GpuRefCrossValidate_Small) +{ + using ADataType = typename TestFixture::ADataType; + if constexpr(std::is_same_v && + TestFixture::PipelineType == MxGemmPipelineType::CompAsync && + !TestFixture::Persistent) + { + const int group_count = 2; + const int kbatch = 1; + const std::vector Ms{256, 512}; + const std::vector Ns{512, 1024}; + const std::vector Ks{512, 512}; + + // Trusted host reference path first, then the new all-GPU path on the same shapes. + this->Run(Ms, Ns, Ks, kbatch, group_count); + this->RunAllGpu(Ms, Ns, Ks, kbatch, group_count); + } + else + { + GTEST_SKIP() << "GPU-reference cases apply only to fp4 non-persistent CompAsync."; + } +} + +// Stage B: the decomposition / int32-overflow case. Minimal shape that crosses the boundary: +// 2 groups, each per-group M*N = 100352*12288 = 1,233,125,376 < INT_MAX (and C byte span +// 2,466,250,752 > INT_MAX), aggregate total M*N = 2,466,250,752 > INT_MAX. This proves the +// host M-decomposition is what keeps every per-group buffer int32-safe (a single fused tensor +// would overflow the int32 element count), while the kernel still addresses C correctly even +// though the per-group C *byte* span exceeds INT_MAX. K is kept small (compute is irrelevant to +// the addressing property under test). +TYPED_TEST(TEST_SUITE_NAME, LargeM_decomposition_int32) +{ + using ADataType = typename TestFixture::ADataType; + if constexpr(std::is_same_v && + TestFixture::PipelineType == MxGemmPipelineType::CompAsync && + !TestFixture::Persistent) + { + const int group_count = 2; + const int kbatch = 1; + const std::vector Ms(group_count, 100352); + const std::vector Ns(group_count, 12288); + const std::vector Ks(group_count, 512); + + this->RunAllGpu(Ms, Ns, Ks, kbatch, group_count); + } + else + { + GTEST_SKIP() << "GPU-reference cases apply only to fp4 non-persistent CompAsync."; + } +} diff --git a/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_pipeline_kernel_types.hpp b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_pipeline_kernel_types.hpp new file mode 100644 index 0000000000..95b1cf4387 --- /dev/null +++ b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_pipeline_kernel_types.hpp @@ -0,0 +1,78 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_mx_grouped_gemm_util.hpp" + +using F4 = ck_tile::pk_fp4_t; +using F8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using F16 = ck_tile::half_t; +using F32 = float; +using BF16 = ck_tile::bf16_t; +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; +using True = ck_tile::bool_constant; +using False = ck_tile::bool_constant; +using E8M0 = ck_tile::e8m0_t; + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +using CompTDMV1 = ck_tile::integral_constant; +using CompTDMV2 = ck_tile::integral_constant; +using CompAsync = ck_tile::integral_constant; +using CompEightWaves = + ck_tile::integral_constant; +using WeightPreshuffle = + ck_tile::integral_constant; + +using I16 = ck_tile::number<16>; +using I32 = ck_tile::number<32>; +using I64 = ck_tile::number<64>; +using I128 = ck_tile::number<128>; +using I256 = ck_tile::number<256>; +using I512 = ck_tile::number<512>; + +template +using ScaleBS = ck_tile::integral_constant; + +// clang-format off +// MX GEMM kernel types using TDM pipeline with scale support +// Tuple format: +// ALayout, BLayout, CLayout, ADataType, BDataType, AScaleDataType, BScaleDataType, AccDataType, CDataType, Persistent, M_BlockSize, N_BlockSize, K_BlockSize, M_TileSize, N_TileSize, PipelineType +using KernelTypesMxGemmCompTDMWmma = ::testing::Types< + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, False, I64, I64, I128, I32, I32, CompTDMV1, ScaleBS<32>>, + std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, False, I64, I64, I128, I32, I32, CompTDMV1, ScaleBS<32>>, + std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, False, I64, I64, I128, I32, I32, CompTDMV1, ScaleBS<32>>, + std::tuple< Col, Row, Row, F8, BF8, E8M0, E8M0, F32, F16, False, I64, I64, I128, I32, I32, CompTDMV1, ScaleBS<32>>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, False, I64, I64, I128, I32, I32, CompTDMV2, ScaleBS<32>>, + std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, False, I64, I64, I128, I32, I32, CompTDMV2, ScaleBS<32>>, + std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, False, I64, I64, I128, I32, I32, CompTDMV2, ScaleBS<32>>, + std::tuple< Col, Row, Row, F8, BF8, E8M0, E8M0, F32, F16, False, I64, I64, I128, I32, I32, CompTDMV2, ScaleBS<32>> +>; + +using KernelTypesMxGemmCompAsync = ::testing::Types< + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, False, I64, I64, I256, I16, I16, CompAsync, ScaleBS<32>>, + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, False, I64, I64, I256, I16, I16, CompAsync, ScaleBS<32>>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, True, I64, I64, I256, I16, I16, CompAsync, ScaleBS<32>>, + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, True, I64, I64, I256, I16, I16, CompAsync, ScaleBS<32>>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, False, I128, I256, I128, I16, I16, CompEightWaves, ScaleBS<32>>, + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, False, I128, I256, I128, I16, I16, CompEightWaves, ScaleBS<32>>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, True, I128, I256, I128, I16, I16, CompEightWaves, ScaleBS<32>>, + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, True, I128, I256, I128, I16, I16, CompEightWaves, ScaleBS<32>>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, False, I128, I256, I256, I16, I16, WeightPreshuffle, ScaleBS<32>>, + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, False, I128, I512, I256, I16, I16, WeightPreshuffle, ScaleBS<32>>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, True, I128, I256, I256, I16, I16, WeightPreshuffle, ScaleBS<32>>, + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, True, I128, I512, I256, I16, I16, WeightPreshuffle, ScaleBS<32>>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, False, I128, I256, I256, I16, I16, WeightPreshuffle, ScaleBS<32>, True>, + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, False, I128, I512, I256, I16, I16, WeightPreshuffle, ScaleBS<32>, True>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, True, I128, I256, I256, I16, I16, WeightPreshuffle, ScaleBS<32>, True>, + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, True, I128, I512, I256, I16, I16, WeightPreshuffle, ScaleBS<32>, True> +>; +// clang-format on diff --git a/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_ut_cases.inc b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_ut_cases.inc index c41350db33..ec4a7183e3 100644 --- a/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_ut_cases.inc +++ b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_ut_cases.inc @@ -3,7 +3,7 @@ #pragma once -TYPED_TEST(TestCkTileMxGroupedGemm, Basic) +TYPED_TEST(TEST_SUITE_NAME, Basic) { const int group_count = 4; const int kbatch = 1; @@ -14,8 +14,8 @@ TYPED_TEST(TestCkTileMxGroupedGemm, Basic) for(int i = 0; i < group_count; i++) { Ms.push_back(256 + 256 * i); - Ns.push_back(256 + 512 * i); - Ks.push_back(512 + 128 * i); + Ns.push_back(512 + 512 * i); + Ks.push_back(512 + TestFixture::K_Tile * i); } this->Run(Ms, Ns, Ks, kbatch, group_count); diff --git a/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_util.hpp b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_util.hpp index 3d22b6409a..797196911c 100644 --- a/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_util.hpp +++ b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_util.hpp @@ -14,11 +14,44 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/gemm/kernel/mx_grouped_gemm_kernel.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +#include "ck/library/utility/gpu_verification.hpp" + +template +constexpr ck_tile::index_t get_k_warp_tile() +{ +#if CK_TILE_USE_WMMA +#if defined(CK_USE_GFX1250) + constexpr bool is_8bit = std::is_same_v || + std::is_same_v || + std::is_same_v; + constexpr bool is_mxtype = + std::is_same_v || std::is_same_v; + if constexpr(M_Warp_Tile == 32 && is_mxtype) + { + return 128; + } + else + { + return is_8bit ? 64 : 32; + } +#else + return 16; +#endif +#else + if constexpr(M_Warp_Tile == 32) + return 64; + else + return 128; +#endif +} enum struct MxGemmPipelineType { CompTDMV1, - CompTDMV2 + CompTDMV2, + CompAsync, + CompEightWaves, + WeightPreshuffle }; template @@ -42,61 +75,137 @@ struct MxGemmPipelineTypeSelector static constexpr auto GetName() { return "GemmPipelineAgBgCrCompTDMV2"; } }; -/** - * @brief Pre-shuffle scale buffer for gfx1250 wmma mx scale instruction. - * - * Reorganizes the scale data from row-major (MN x K) layout to the hardware-specific - * layout expected by the gfx1250 wmma instruction. - * - * @tparam ScaleType Scale data type (e.g., e8m0_t) - * @tparam ScaleBlockSize The block size for microscaling (e.g., 32) - * @tparam KStride Whether K is the fast-moving dimension - */ -template -void preShuffleScaleBuffer_gfx1250(const ScaleType* src, - ScaleType* dst, - ck_tile::index_t MN, - ck_tile::index_t K) +template +struct MxGemmPipelineTypeSelector { - static_assert(ScaleBlockSize == 32 && sizeof(ScaleType) == 1, - "wrong! only support 8-bit scale with ScaleBlockSize=32"); + using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompAsync; + using pipeline = ck_tile::GemmPipelineAgBgCrCompAsync; - constexpr ck_tile::index_t MPerXdlops = 16; - constexpr ck_tile::index_t KPerXdlops = 128; + static constexpr auto GetName() { return "GemmPipelineAgBgCrCompAsync"; } +}; - int MNPack = 2; - int KPack = 1; +template +struct MxGemmPipelineTypeSelector +{ + using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + using pipeline = ck_tile::GemmPipelineAgBgCrCompAsyncEightWaves; - int MNStep = MPerXdlops; - int KStep = KPerXdlops / ScaleBlockSize; + static constexpr auto GetName() { return "GemmPipelineAgBgCrCompEightWaves"; } +}; - int K0 = K / KPack / KStep; +template +struct MxGemmPipelineTypeSelector +{ + using base_pipeline = ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2; + using pipeline = ck_tile::MXGemmPreshufflePipelineAGmemBGmemCRegV1; - for(int mn = 0; mn < MN; ++mn) + static constexpr auto GetName() { return "GemmPipelineAgBgCrWeightPreshuffle"; } +}; + +template +struct MxGemmEpilogueTypeSelector +{ +}; + +template +struct MxGemmEpilogueTypeSelector +{ + using epilogue = ck_tile::TdmEpilogue; +}; + +template +struct MxGemmEpilogueTypeSelector +{ + using epilogue = ck_tile::TdmEpilogue; +}; + +template +struct MxGemmEpilogueTypeSelector +{ + using epilogue = ck_tile::CShuffleEpilogue; +}; + +template +struct MxGemmEpilogueTypeSelector +{ + using epilogue = ck_tile::CShuffleEpilogue; +}; + +template +struct MxGemmEpilogueTypeSelector +{ + using epilogue = std::conditional_t, + ck_tile::CShuffleEpilogue>; +}; + +template +struct Config +{ + static constexpr ck_tile::index_t N_Warp_Tile = N_Warp_Tile_; + static constexpr ck_tile::index_t K_Warp_Tile = K_Warp_Tile_; + static constexpr ck_tile::index_t N_Tile = N_Tile_; + static constexpr ck_tile::index_t N_Warp = N_Warp_; + static constexpr ck_tile::index_t BContiguousItemsPerAccess = + std::is_same_v ? 32 : 16; +}; + +// Deterministic per-element hash RNG for GPU data init. Returns a float in [-3, 3). +// The generic `fill_tensor_uniform_rand_fp_values` filler is NOT valid for ck_tile::pk_fp4_t +// (it converts a single float and duplicates it into both nibbles, and special-cases only the +// classic ck::f4x2_pk_t). We need two independent fp4 values per byte, so we fill directly. +// The narrow [-3,3) range keeps the fp16 GEMM output from overflowing at K up to 4096 (with the +// [0.25,1.0] scales used in RunAllGpu, worst case K*9 = 36864 < 65504). +__device__ inline float mx_fp4_fill_rand(unsigned int seed, unsigned long long idx) +{ + // splitmix64-style avalanche; deterministic given (seed, idx). + unsigned long long z = (idx + 1ULL) * 0x9E3779B97F4A7C15ULL + + static_cast(seed) * 0xD1B54A32D192ED03ULL; + z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ULL; + z = (z ^ (z >> 27)) * 0x94D049BB133111EBULL; + z ^= z >> 31; + const float u = + static_cast((z >> 40) & 0xFFFFFFULL) / static_cast(0x1000000); // [0,1) + return u * 6.0f - 3.0f; // [-3,3) +} + +// Fill a packed-fp4 buffer with two independent, deterministic random fp4 values per byte. +// `num_packed` is the number of pk_fp4_t elements (= total fp4 values / 2). +__global__ void +fill_pk_fp4_uniform_kernel(ck_tile::pk_fp4_t* __restrict__ ptr, long num_packed, unsigned int seed) +{ + const long idx0 = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const long nthr = static_cast(gridDim.x) * blockDim.x; + for(long i = idx0; i < num_packed; i += nthr) { - int iMNRepeat = mn / (MNStep * MNPack); - int tempmn = mn % (MNStep * MNPack); - - for(int k = 0; k < K; ++k) - { - int iKRepeat = k / (KStep * KPack); - int tempk = k % (KStep * KPack); - - int outputIndex = (iMNRepeat * MNPack * MNStep) * (KStep * KPack * K0) + - (iKRepeat * KStep * KPack) * (MNStep * MNPack) + - tempmn * (KStep * KPack) + tempk; - - if constexpr(KStride) - { - dst[outputIndex] = src[mn * K + k]; - } - else - dst[outputIndex] = src[k * MN + mn]; - } + const float lo_f = rintf(mx_fp4_fill_rand(seed, static_cast(i) * 2ULL)); + const float hi_f = + rintf(mx_fp4_fill_rand(seed, static_cast(i) * 2ULL + 1ULL)); + const auto lo = ck_tile::float_to_mxfp4(lo_f, 1.0f); + const auto hi = ck_tile::float_to_mxfp4(hi_f, 1.0f); + ptr[i] = ck_tile::pk_fp4_t::_pack(lo, hi); } } -template +inline void fill_pk_fp4_uniform(ck_tile::pk_fp4_t* ptr, + long num_packed, + unsigned int seed, + hipStream_t stream = nullptr) +{ + constexpr int threads = 256; + constexpr long max_blocks = 65536; // grid-stride cap + const long needed = (num_packed + threads - 1) / threads; + const long blocks = needed < max_blocks ? needed : max_blocks; + fill_pk_fp4_uniform_kernel<<(blocks)), dim3(threads), 0, stream>>>( + ptr, num_packed, seed); + ck_tile::hip_check_error(hipGetLastError()); +} + +template class TestCkTileMxGroupedGemm : public ::testing::Test { protected: @@ -111,9 +220,36 @@ class TestCkTileMxGroupedGemm : public ::testing::Test using CDataType = std::tuple_element_t<8, Tuple>; using PersistentType = std::tuple_element_t<9, Tuple>; static constexpr bool Persistent = PersistentType::value; - static constexpr auto Scheduler = std::tuple_element_t<10, Tuple>::value; - static constexpr auto PipelineType = std::tuple_element_t<11, Tuple>::value; - static constexpr ck_tile::index_t ScaleBlockSize = std::tuple_element_t<12, Tuple>::value; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr auto PipelineType = std::tuple_element_t<15, Tuple>::value; + static constexpr ck_tile::index_t ScaleBlockSize = std::tuple_element_t<16, Tuple>::value; + static constexpr bool PermuteN = + ck_tile::tuple_element_or_default_t::value; + + static constexpr ck_tile::index_t M_Tile = std::tuple_element_t<10, Tuple>{}; + static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<11, Tuple>{}; + static constexpr ck_tile::index_t K_Tile = std::tuple_element_t<12, Tuple>{}; + + static constexpr ck_tile::index_t M_Warp_Tile = std::tuple_element_t<13, Tuple>{}; + static constexpr ck_tile::index_t N_Warp_Tile = std::tuple_element_t<14, Tuple>{}; + static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::max( + get_k_warp_tile(), get_k_warp_tile()); + + static constexpr ck_tile::index_t M_Warp = + PipelineType == MxGemmPipelineType::WeightPreshuffle + ? 1 + : (PipelineType == MxGemmPipelineType::CompEightWaves ? 4 : 2); + static constexpr ck_tile::index_t N_Warp = + PipelineType == MxGemmPipelineType::WeightPreshuffle ? 4 : 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr int kBlockPerCu = 1; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool Preshuffle = PipelineType == MxGemmPipelineType::WeightPreshuffle; // No D tensors for this test using DsLayout = ck_tile::tuple<>; @@ -123,42 +259,27 @@ class TestCkTileMxGroupedGemm : public ::testing::Test using AComputeDataType = ADataType; using BComputeDataType = BDataType; - struct GroupedGemKernelParam_Wmma - { - static const bool kPadM = false; - static const bool kPadN = false; - static const bool kPadK = false; - - static const int kBlockPerCu = 1; - static const ck_tile::index_t M_Tile = 64; - static const ck_tile::index_t N_Tile = 64; - static const ck_tile::index_t K_Tile = 128; - - static const ck_tile::index_t M_Warp = 2; - static const ck_tile::index_t N_Warp = 2; - static const ck_tile::index_t K_Warp = 1; - - static const ck_tile::index_t M_Warp_Tile = 32; - static const ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 128; - }; - using mx_grouped_gemm_kargs = ck_tile::MxGroupedGemmHostArgs<>; std::size_t get_workspace_size(const std::vector& gemm_descs) { return gemm_descs.size() * sizeof(ck_tile::MxGemmTransKernelArg<>); } - template + template bool invoke_mx_grouped_gemm(const std::vector& gemm_descs, const ck_tile::stream_config& s, void* kargs_ptr) { - constexpr bool preshuffle = false; constexpr bool DoubleSmemBuffer = true; // TDM pipeline requires double smem buffer +#if defined(CK_USE_GFX1250) + constexpr ck_tile::index_t BlockedXDLNPerWarp = 1; constexpr bool TransposeC = std::is_same_v && - GroupedGemKernelParam::M_Warp_Tile == GroupedGemKernelParam::N_Warp_Tile; + M_Warp_Tile == N_Warp_Tile; +#elif defined(CK_USE_GFX950) + constexpr ck_tile::index_t BlockedXDLNPerWarp = Preshuffle ? 2 : 1; + constexpr bool TransposeC = false; +#endif static constexpr bool StructuredSparsity = false; static constexpr bool NumWaveGroup = 1; @@ -166,21 +287,15 @@ class TestCkTileMxGroupedGemm : public ::testing::Test constexpr ck_tile::index_t TileParitionerM01 = 4; using GemmShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + Preshuffle>; using UniversalGemmProblem = ck_tile::MxGemmPipelineProblem::pipeline; - using GemmEpilogue = ck_tile::TdmEpilogue< + using GemmEpilogueProblem = std::conditional_t< + Preshuffle && PermuteN, + ck_tile::PermuteNEpilogueProblem, /*VectorSizeC_*/ ck_tile::CShuffleEpilogueProblem>; + 1, /*kNumWaveGroups_*/ + false, /*FixedVectorSize_*/ + 1, /*VectorSizeC_*/ + BlockedXDLNPerWarp, /*BlockedXDLN_PerWarp_*/ + DoubleSmemBuffer, /*DoubleSmemBuffer*/ + AComputeDataType, /*AComputeDataType_*/ + BComputeDataType, /*BComputeDataType_*/ + !Preshuffle>>; + + using GemmEpilogue = typename MxGemmEpilogueTypeSelector::epilogue; using Kernel = ck_tile::MxGroupedGemmKernel; @@ -266,7 +405,7 @@ class TestCkTileMxGroupedGemm : public ::testing::Test ck_tile::ignore = ck_tile::launch_kernel(s, - ck_tile::make_kernel( + ck_tile::make_kernel( Kernel{}, grids, blocks, @@ -303,36 +442,9 @@ class TestCkTileMxGroupedGemm : public ::testing::Test return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } - static constexpr bool check_data_type() - { - - // Validate scale type / data type combination - constexpr bool a_is_f4 = std::is_same_v; - constexpr bool b_is_f4 = std::is_same_v; - constexpr bool a_scale_e8m0 = std::is_same_v; - constexpr bool b_scale_e8m0 = std::is_same_v; - if constexpr(!a_is_f4 && !a_scale_e8m0) - return false; - if constexpr(!b_is_f4 && !b_scale_e8m0) - return false; - - // Check hardware WMMA support for the fixed warp tile (32x32x128) -#if defined(CK_USE_GFX1250) - return ck_tile::has_wmma_traits_v; -#else - return false; -#endif - } - void SetUp() override { - if constexpr(!check_data_type()) + if constexpr(!Derived::check_data_type()) { GTEST_SKIP() << "Unsupported data type / layout combination for mx_grouped_gemm."; } @@ -345,7 +457,7 @@ class TestCkTileMxGroupedGemm : public ::testing::Test const int kbatch = 1, const int group_count = 16) { - if constexpr(!check_data_type()) + if constexpr(!Derived::check_data_type()) return; using namespace ck_tile::literals; @@ -445,8 +557,42 @@ class TestCkTileMxGroupedGemm : public ::testing::Test << " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc << " KBatch: " << kbatch << std::endl; - ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); - ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); + // For pk_fp4_t each byte packs two 4-bit elements; the generic filler + // converts a single float and duplicates it into both nibbles. + // Generate two independent random values per byte instead. + if constexpr(std::is_same_v) + { + std::mt19937 gen(11939); + std::uniform_real_distribution dis(-5.f, 5.f); + for(auto& elem : a_m_k_tensors[i].mData) + { + auto lo = ck_tile::float_to_mxfp4(std::round(dis(gen)), 1.f); + auto hi = ck_tile::float_to_mxfp4(std::round(dis(gen)), 1.f); + elem = ck_tile::pk_fp4_t::_pack(lo, hi); + } + } + else + { + ck_tile::FillUniformDistributionIntegerValue{-5, 5, 11939}( + a_m_k_tensors[i]); + } + + if constexpr(std::is_same_v) + { + std::mt19937 gen(11940); + std::uniform_real_distribution dis(-5.f, 5.f); + for(auto& elem : b_k_n_tensors[i].mData) + { + auto lo = ck_tile::float_to_mxfp4(std::round(dis(gen)), 1.f); + auto hi = ck_tile::float_to_mxfp4(std::round(dis(gen)), 1.f); + elem = ck_tile::pk_fp4_t::_pack(lo, hi); + } + } + else + { + ck_tile::FillUniformDistributionIntegerValue{-5, 5, 11940}( + b_k_n_tensors[i]); + } // K must be a multiple of ScaleBlockSize if(K % ScaleBlockSize != 0) @@ -454,15 +600,13 @@ class TestCkTileMxGroupedGemm : public ::testing::Test GTEST_SKIP() << "K must be multiple of ScaleBlockSize for MX GEMM"; } const ck_tile::index_t num_scale_k = K / ScaleBlockSize; - if(num_scale_k % (GroupedGemKernelParam_Wmma::K_Warp_Tile / ScaleBlockSize) != 0) + if(num_scale_k % (K_Warp_Tile / ScaleBlockSize) != 0) { - GTEST_SKIP() << "K must be a multiple of K_Warp_Tile (" - << GroupedGemKernelParam_Wmma::K_Warp_Tile + GTEST_SKIP() << "K must be a multiple of K_Warp_Tile (" << K_Warp_Tile << ") for MX GEMM. Pad the scale data."; } const ck_tile::index_t scale_padded_M = ck_tile::integer_least_multiple( - static_cast(M), - static_cast(GroupedGemKernelParam_Wmma::M_Warp_Tile)); + static_cast(M), static_cast(M_Tile)); ck_tile::HostTensor scale_a( {static_cast(scale_padded_M), static_cast(num_scale_k)}, @@ -515,6 +659,9 @@ class TestCkTileMxGroupedGemm : public ::testing::Test } // Pre-shuffle scale buffers for the hardware +#if defined(CK_USE_GFX1250) + constexpr ck_tile::index_t NXdlPackEff = 1; + ck_tile::HostTensor scale_a_shuffled( {static_cast(scale_padded_M), static_cast(num_scale_k)}, {static_cast(num_scale_k), static_cast(1)}); @@ -523,11 +670,6 @@ class TestCkTileMxGroupedGemm : public ::testing::Test {static_cast(N), static_cast(num_scale_k)}, {static_cast(num_scale_k), static_cast(1)}); - std::cout << " scale_a: [scale_padded_M = " << scale_padded_M - << ", num_scale_k = " << num_scale_k << "]." << std::endl; - std::cout << " scale_b: [N = " << N << ", num_scale_k = " << num_scale_k << "]." - << std::endl; - // Pre-shuffle for gfx1250 (WaveSize=32, WMMA) preShuffleScaleBuffer_gfx1250( scale_a.mData.data(), scale_a_shuffled.mData.data(), scale_padded_M, num_scale_k); @@ -536,19 +678,95 @@ class TestCkTileMxGroupedGemm : public ::testing::Test // where N is the fast-changing dimension for col-major B preShuffleScaleBuffer_gfx1250( scale_b.mData.data(), scale_b_shuffled.mData.data(), N, num_scale_k); +#elif defined(CK_USE_GFX950) + constexpr ck_tile::index_t MPerXdl = M_Warp_Tile; + constexpr ck_tile::index_t NPerXdl = N_Warp_Tile; + constexpr ck_tile::index_t KPerXdl = K_Warp_Tile; + constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * MPerXdl); + constexpr ck_tile::index_t NIterPerWarp = N_Tile / (N_Warp * NPerXdl); + constexpr ck_tile::index_t KIterPerWarp = K_Tile / KPerXdl; + + constexpr ck_tile::index_t MXdlPackEff = + (MIterPerWarp >= 2 && MIterPerWarp % 2 == 0) ? 2 : 1; + constexpr ck_tile::index_t NXdlPackEff = + (NIterPerWarp >= 2 && NIterPerWarp % 2 == 0) ? 2 : 1; + constexpr ck_tile::index_t KXdlPackEff = + (KIterPerWarp >= 2 && KIterPerWarp % 2 == 0) ? 2 : 1; + + constexpr ck_tile::index_t XdlMNThread = M_Warp_Tile; + constexpr ck_tile::index_t XdlKThread = 64 / XdlMNThread; + + ck_tile::HostTensor scale_a_shuffled( + {static_cast(scale_padded_M / MXdlPackEff * 2), + static_cast(num_scale_k / KXdlPackEff * 2)}, + {static_cast(num_scale_k / KXdlPackEff * 2), + static_cast(1)}); + + ck_tile::HostTensor scale_b_shuffled( + {static_cast(N / NXdlPackEff * 2), + static_cast(num_scale_k / KXdlPackEff * 2)}, + {static_cast(num_scale_k / KXdlPackEff * 2), + static_cast(1)}); + + ck_tile:: + preShuffleScaleBuffer_gfx950( + scale_a.mData.data(), + scale_a_shuffled.mData.data(), + scale_padded_M, + num_scale_k, + true); + + if constexpr(Preshuffle && PermuteN) + { + ck_tile::preShuffleScaleBufferPermuteN_gfx950( + scale_b.mData.data(), scale_b_shuffled.mData.data(), N, num_scale_k, true); + } + else + { + ck_tile:: + preShuffleScaleBuffer_gfx950( + scale_b.mData.data(), scale_b_shuffled.mData.data(), N, num_scale_k, true); + } +#endif + + std::cout << " scale_a: [scale_padded_M = " << scale_padded_M + << ", num_scale_k = " << num_scale_k << "]." << std::endl; + std::cout << " scale_b: [N = " << N << ", num_scale_k = " << num_scale_k << "]." + << std::endl; scale_a_tensors.push_back(scale_a_shuffled); scale_b_tensors.push_back(scale_b_shuffled); + using GemmConfig = Config; + + const auto b_host_for_dev = [&]() { + if constexpr(Preshuffle) + { + if constexpr(PermuteN) + { + return ck_tile::shuffle_b_permuteN( + b_k_n_tensors[i]); + } + else + { + return ck_tile::shuffle_b(b_k_n_tensors[i]); + } + } + else + { + return b_k_n_tensors[i]; + } + }(); + a_m_k_dev_buf.push_back(std::make_unique( a_m_k_tensors[i].get_element_space_size_in_bytes())); b_k_n_dev_buf.push_back(std::make_unique( - b_k_n_tensors[i].get_element_space_size_in_bytes())); + b_host_for_dev.get_element_space_size_in_bytes())); c_m_n_dev_buf.push_back(std::make_unique( c_m_n_tensors[i].get_element_space_size_in_bytes())); a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data()); - b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data()); + b_k_n_dev_buf[i]->ToDevice(b_host_for_dev.data()); c_m_n_dev_buf[i]->SetZero(); c_m_n_tensors[i].SetZero(); @@ -584,7 +802,7 @@ class TestCkTileMxGroupedGemm : public ::testing::Test ck_tile::DeviceMem gemm_workspace; gemm_workspace.Realloc(get_workspace_size(gemm_descs)); - if(!invoke_mx_grouped_gemm( + if(!invoke_mx_grouped_gemm( gemm_descs, ck_tile::stream_config{nullptr, false, 1}, gemm_workspace.GetDeviceBuffer())) @@ -628,4 +846,321 @@ class TestCkTileMxGroupedGemm : public ::testing::Test } EXPECT_TRUE(pass); } + + // All-GPU validation path for the fp4 (pk_fp4_t) MX grouped GEMM. + // + // Unlike Run(), this never materializes the (potentially 39 GB) A/B/C tensors on the host: + // - A/B are generated directly on device with a deterministic fp4 fill. + // - the reference is computed on device by reference_mx_gemm_gpu. + // - the comparison is done on device by ck::profiler::gpu_verify. + // Only the tiny e8m0 scales touch the host (for pre-shuffle + an unshuffled copy that the + // device reference consumes). Groups are processed one at a time to bound peak device memory + // and to make any fault attributable to a specific group. + // + // Per group it logs and asserts that every int32-addressed quantity (M*N, A/B/C worst-case + // element offsets) stays < INT_MAX -- this is the property under test for the host-side + // M-decomposition that keeps each per-group buffer kernel-safe. + void RunAllGpu(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const int kbatch = 1, + const int group_count = 16) + { + if constexpr(!Derived::check_data_type()) + return; + + static_assert(std::is_same_v && + std::is_same_v, + "RunAllGpu currently supports pk_fp4_t A/B only."); + // The GPU reference (reference_mx_gemm_gpu) hardcodes these layouts; guard so it cannot be + // silently misused with a layout it does not handle. + static_assert(std::is_same_v && + std::is_same_v && + std::is_same_v, + "RunAllGpu / reference_mx_gemm_gpu assume RowMajor-A, ColumnMajor-B, " + "RowMajor-C."); + +#if !defined(CK_USE_GFX950) + (void)Ms; + (void)Ns; + (void)Ks; + (void)kbatch; + (void)group_count; + GTEST_SKIP() << "RunAllGpu requires CK_USE_GFX950."; +#else + using namespace ck_tile::literals; + constexpr long kIntMax = 2147483647L; // INT_MAX + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + if constexpr(std::is_same_v) + return col; + else + return row; + } + else + return stride; + }; + + constexpr ck_tile::index_t psize = ck_tile::numeric_traits::PackedSize; // 2 + static_assert(psize == 2, + "RunAllGpu byte-sizing and reference_mx_gemm_kernel's a_ptr[a_lin/2] " + "addressing assume pk_fp4_t PackedSize == 2."); + + bool pass = true; + long total_MN = 0; + + for(int i = 0; i < group_count; ++i) + { + const ck_tile::index_t M = Ms[i]; + const ck_tile::index_t N = Ns[i]; + const ck_tile::index_t K = Ks[i]; + + // Strides are K/N here (small); keep them as index_t to match the kernel args, and + // make the size_t->index_t narrowing explicit. + const ck_tile::index_t stride_A = + static_cast(f_get_default_stride(M, K, 0, ALayout{})); // K + const ck_tile::index_t stride_B = + static_cast(f_get_default_stride(K, N, 0, BLayout{})); // K + const ck_tile::index_t stride_C = + static_cast(f_get_default_stride(M, N, 0, CLayout{})); // N + + // Per-group shape guards. Fatal (ASSERT), not GTEST_SKIP: a skip mid-loop would + // silently report success for only a prefix of the groups already validated. + ASSERT_EQ(K % ScaleBlockSize, 0) + << "group " << i << ": K must be a multiple of ScaleBlockSize for MX GEMM"; + const ck_tile::index_t num_scale_k = K / ScaleBlockSize; + ASSERT_EQ(num_scale_k % (K_Warp_Tile / ScaleBlockSize), 0) + << "group " << i << ": K must be a multiple of K_Warp_Tile (" << K_Warp_Tile + << ") for MX GEMM. Pad the scale data."; + const ck_tile::index_t scale_padded_M = ck_tile::integer_least_multiple( + static_cast(M), static_cast(M_Tile)); + + // int32-safety: the property under test for the M-decomposition. The predicate is + // "largest 0-based element offset fits in a signed 32-bit int", i.e. offset <= INT_MAX. + const long MN = static_cast(M) * N; + const long A_elems = static_cast(M) * K; + const long B_elems = static_cast(K) * N; + const long C_off = static_cast(M - 1) * stride_C + (N - 1); + const long A_off = static_cast(M - 1) * stride_A + (K - 1); + const long B_off = static_cast(N - 1) * stride_B + (K - 1); + const long c_bytes = MN * static_cast(sizeof(CDataType)); + std::cout << "[int32-safety] group " << i << " M=" << M << " N=" << N << " K=" << K + << " M*N=" << MN << " A_elems=" << A_elems << " B_elems=" << B_elems + << " C_off=" << C_off << " A_off=" << A_off << " B_off=" << B_off + << " C_bytes=" << c_bytes << " (INT_MAX=" << kIntMax << ")" << std::endl; + // Note (not an assert): the C *byte* span can exceed INT_MAX even when the element + // count is int32-safe. We deliberately let the run proceed -- if any internal byte + // offset overflows, gpu_verify will flag it, which is exactly what we want to discover. + if(c_bytes > kIntMax) + std::cout + << "[int32-safety][note] group " << i << " C byte span (" << c_bytes + << ") exceeds INT_MAX; if verification fails, byte-offset overflow is the " + "prime suspect." + << std::endl; + ASSERT_LE(MN - 1, kIntMax) << "group " << i << " max C element index exceeds INT_MAX"; + ASSERT_LE(C_off, kIntMax) << "group " << i << " C offset exceeds INT_MAX"; + ASSERT_LE(A_off, kIntMax) << "group " << i << " A offset exceeds INT_MAX"; + ASSERT_LE(B_off, kIntMax) << "group " << i << " B offset exceeds INT_MAX"; + total_MN += MN; + + // Device buffers (no big host tensors). Round byte counts up: a stray odd fp4 element + // still occupies a full packed byte. + const long a_bytes = (A_elems + psize - 1) / psize; + const long b_bytes = (B_elems + psize - 1) / psize; + + // Bound peak device memory (A + B + 2*C + scales/workspace slack). Skip cleanly rather + // than aborting via hip_check_error if the device cannot hold one group. + { + std::size_t free_b = 0, total_b = 0; + ck_tile::hip_check_error(hipMemGetInfo(&free_b, &total_b)); + const std::size_t need = static_cast(a_bytes) + + static_cast(b_bytes) + + 2u * static_cast(c_bytes) + (64u << 20); + if(free_b < need) + GTEST_SKIP() << "group " << i << ": insufficient device memory (need " << need + << " B, free " << free_b << " B)"; + } + + auto a_dev = std::make_unique(static_cast(a_bytes)); + auto b_dev = std::make_unique(static_cast(b_bytes)); + auto c_dev = std::make_unique(static_cast(c_bytes)); + auto c_ref_dev = + std::make_unique(static_cast(c_bytes)); + c_dev->SetZero(); + c_ref_dev->SetZero(); + + // GPU fill A/B (deterministic, fp4-correct). Same device buffers feed both the kernel + // and the reference, so the fill need not bit-match any host RNG. Fold the group index + // into the seed so each group gets a distinct data pattern. + fill_pk_fp4_uniform(reinterpret_cast(a_dev->GetDeviceBuffer()), + a_bytes, + 11939u + static_cast(i)); + fill_pk_fp4_uniform(reinterpret_cast(b_dev->GetDeviceBuffer()), + b_bytes, + 11940u + static_cast(i)); + ck_tile::hip_check_error( + hipDeviceSynchronize()); // surface fill faults at the fill site + + // e8m0 scales (tiny, host-built, fixed per-group seed for determinism). The range is + // deliberately narrow ([0.25,1.0] scales, [-3,3) fp4 fill) so that K up to 4096 cannot + // overflow the fp16 output (worst case K*9 = 36864 < 65504); gpu_verify counts matched + // infinities as errors, so an overflow would otherwise be a false failure. + ck_tile::HostTensor scale_a( + {static_cast(scale_padded_M), static_cast(num_scale_k)}, + {static_cast(num_scale_k), static_cast(1)}); + ck_tile::HostTensor scale_b( + {static_cast(N), static_cast(num_scale_k)}, + {static_cast(num_scale_k), static_cast(1)}); + { + std::mt19937 gen(11941u + static_cast(i)); + std::uniform_real_distribution dist(0.25f, 1.0f); + for(auto& s : scale_a.mData) + s = AScaleDataType{dist(gen)}; + for(auto& s : scale_b.mData) + s = BScaleDataType{dist(gen)}; + } + + // gfx950 scale pre-shuffle. NOTE: this must stay in sync with the identical block in + // Run() -- the kernel-input layout and the reference-input layout must agree. + constexpr ck_tile::index_t MPerXdl = M_Warp_Tile; + constexpr ck_tile::index_t NPerXdl = N_Warp_Tile; + constexpr ck_tile::index_t KPerXdl = K_Warp_Tile; + constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * MPerXdl); + constexpr ck_tile::index_t NIterPerWarp = N_Tile / (N_Warp * NPerXdl); + constexpr ck_tile::index_t KIterPerWarp = K_Tile / KPerXdl; + + constexpr ck_tile::index_t MXdlPackEff = + (MIterPerWarp >= 2 && MIterPerWarp % 2 == 0) ? 2 : 1; + constexpr ck_tile::index_t NXdlPackEff = + (NIterPerWarp >= 2 && NIterPerWarp % 2 == 0) ? 2 : 1; + constexpr ck_tile::index_t KXdlPackEff = + (KIterPerWarp >= 2 && KIterPerWarp % 2 == 0) ? 2 : 1; + + constexpr ck_tile::index_t XdlMNThread = M_Warp_Tile; + constexpr ck_tile::index_t XdlKThread = 64 / XdlMNThread; + + ck_tile::HostTensor scale_a_shuffled( + {static_cast(scale_padded_M / MXdlPackEff * 2), + static_cast(num_scale_k / KXdlPackEff * 2)}, + {static_cast(num_scale_k / KXdlPackEff * 2), + static_cast(1)}); + ck_tile::HostTensor scale_b_shuffled( + {static_cast(N / NXdlPackEff * 2), + static_cast(num_scale_k / KXdlPackEff * 2)}, + {static_cast(num_scale_k / KXdlPackEff * 2), + static_cast(1)}); + + ck_tile:: + preShuffleScaleBuffer_gfx950( + scale_a.mData.data(), + scale_a_shuffled.mData.data(), + scale_padded_M, + num_scale_k, + true); + ck_tile:: + preShuffleScaleBuffer_gfx950( + scale_b.mData.data(), scale_b_shuffled.mData.data(), N, num_scale_k, true); + + // Device scale buffers: shuffled feed the kernel, unshuffled feed the reference. + auto scale_a_shuf_dev = std::make_unique( + scale_a_shuffled.get_element_space_size_in_bytes()); + auto scale_b_shuf_dev = std::make_unique( + scale_b_shuffled.get_element_space_size_in_bytes()); + scale_a_shuf_dev->ToDevice(scale_a_shuffled.data()); + scale_b_shuf_dev->ToDevice(scale_b_shuffled.data()); + + auto scale_a_ref_dev = + std::make_unique(scale_a.get_element_space_size_in_bytes()); + auto scale_b_ref_dev = + std::make_unique(scale_b.get_element_space_size_in_bytes()); + scale_a_ref_dev->ToDevice(scale_a.data()); + scale_b_ref_dev->ToDevice(scale_b.data()); + + // Launch the grouped kernel for this single group. + std::vector gemm_descs; + gemm_descs.push_back(mx_grouped_gemm_kargs(a_dev->GetDeviceBuffer(), + scale_a_shuf_dev->GetDeviceBuffer(), + b_dev->GetDeviceBuffer(), + scale_b_shuf_dev->GetDeviceBuffer(), + {/*ds_ptr*/}, + c_dev->GetDeviceBuffer(), + kbatch, + M, + N, + K, + stride_A, + stride_B, + {/*stride_Ds*/}, + stride_C)); + + ck_tile::DeviceMem gemm_workspace; + gemm_workspace.Realloc(get_workspace_size(gemm_descs)); + if(!invoke_mx_grouped_gemm( + gemm_descs, + ck_tile::stream_config{nullptr, false, 1}, + gemm_workspace.GetDeviceBuffer())) + { + ADD_FAILURE() << "invoke_mx_grouped_gemm failed for group " << i; + pass = false; + continue; // DeviceMem frees cleanly at loop end; keep validating other groups + } + ck_tile::hip_check_error(hipDeviceSynchronize()); + + // GPU reference on the same device A/B buffers. + ck_tile::reference_mx_gemm_gpu( + reinterpret_cast(a_dev->GetDeviceBuffer()), + reinterpret_cast(b_dev->GetDeviceBuffer()), + reinterpret_cast(scale_a_ref_dev->GetDeviceBuffer()), + reinterpret_cast(scale_b_ref_dev->GetDeviceBuffer()), + reinterpret_cast(c_ref_dev->GetDeviceBuffer()), + M, + N, + K, + num_scale_k, + ScaleBlockSize); + ck_tile::hip_check_error(hipDeviceSynchronize()); + + // GPU verify with explicit MX tolerance (auto tolerance defaults too tight for MX). + const float max_acc = ck::profiler::gpu_reduce_max( + c_ref_dev->GetDeviceBuffer(), static_cast(MN)); + // The reference must be non-degenerate, else error_count==0 is a vacuous pass. + ASSERT_GT(max_acc, 0.0f) << "group " << i << ": GPU reference output is all-zero"; + const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_acc); + const auto res = ck::profiler::gpu_verify(c_dev->GetDeviceBuffer(), + c_ref_dev->GetDeviceBuffer(), + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{}), + static_cast(MN)); + + // Positive liveness check on the *device* output. res.all_zero ANDs device- and + // reference-zeroness, and the reference is never zero here, so it cannot detect a no-op + // kernel on its own -- reduce the device buffer directly. + const float c_dev_absmax = ck::profiler::gpu_reduce_max( + c_dev->GetDeviceBuffer(), static_cast(MN)); + + std::cout << "[verify] group " << i << " errors=" << res.error_count + << " max_error=" << res.max_error << " c_dev_absmax=" << c_dev_absmax + << " max_acc=" << max_acc << " rtol=" << rtol_atol.at(ck_tile::number<0>{}) + << " atol=" << rtol_atol.at(ck_tile::number<1>{}) << std::endl; + + EXPECT_EQ(res.error_count, 0ull) << "group " << i << " produced mismatched results"; + EXPECT_GT(c_dev_absmax, 0.0f) << "group " << i << " produced an all-zero device output"; + pass &= (res.error_count == 0 && c_dev_absmax > 0.0f); + // a_dev/b_dev/c_dev/... freed here (unique_ptr) before the next group. + } + + std::cout << "[int32-safety] aggregate total_M*N=" << total_MN << " (INT_MAX=" << kIntMax + << ") -> decomposition is the variable under test" << std::endl; + EXPECT_TRUE(pass); +#endif // CK_USE_GFX950 + } }; diff --git a/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_wmma_tdm.cpp b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_wmma_tdm.cpp new file mode 100644 index 0000000000..012598fbc1 --- /dev/null +++ b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_wmma_tdm.cpp @@ -0,0 +1,86 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_mx_grouped_gemm_util.hpp" +#include "test_mx_grouped_gemm_pipeline_kernel_types.hpp" + +template +class TestCkTileMxGemmPipelineCompTDMWmma + : public TestCkTileMxGroupedGemm> +{ + public: + static constexpr bool check_data_type() + { + using Base = TestCkTileMxGroupedGemm>; + + if constexpr(!is_valid_mx_scale_combination()) + { + return false; + } + +#if defined(CK_USE_GFX1250) + using DeviceIp = ck_tile::gfx125_t; +#else +#error "Unsupported architecture for WMMA MX GEMM" +#endif + + return ck_tile::has_wmma_traits_v::value, + ck_tile::constant::value, + ck_tile::constant::value>; + } + + private: + template + static constexpr bool is_valid_mx_scale_combination() + { + constexpr bool a_is_f4 = std::is_same_v; + constexpr bool b_is_f4 = std::is_same_v; + constexpr bool a_scale_e8m0 = std::is_same_v; + constexpr bool b_scale_e8m0 = std::is_same_v; + + // Non-F4 must use E8M0 scale + if constexpr(!a_is_f4 && !a_scale_e8m0) + return false; + if constexpr(!b_is_f4 && !b_scale_e8m0) + return false; + + // Both E8M0 -> always valid + if constexpr(a_scale_e8m0 && b_scale_e8m0) + return true; + + // Both non-E8M0 -> must match (both are F4 by rule 1) + if constexpr(!a_scale_e8m0 && !b_scale_e8m0) + return std::is_same_v; + + // One side non-E8M0: the E8M0 side must not be F4 + if constexpr(!a_scale_e8m0) + return !b_is_f4; + if constexpr(!b_scale_e8m0) + return !a_is_f4; + + return true; + } +}; + +#define TEST_SUITE_NAME TestCkTileMxGemmPipelineCompTDMWmma + +TYPED_TEST_SUITE(TestCkTileMxGemmPipelineCompTDMWmma, KernelTypesMxGemmCompTDMWmma); + +#include "test_mx_grouped_gemm_ut_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 9e71d6ea76..31148661f2 100644 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -520,7 +520,7 @@ class GemmKernelBuilder: } elif self.kernel_name_prefix == "mx_gemm": pipeline_impl_map = { - "comp_async": "ck_tile::MXGemmPipelineAgBgCrCompAsync", + "comp_async": "ck_tile::GemmPipelineAgBgCrCompAsync", } base_pipeline_map = {} @@ -581,9 +581,6 @@ class GemmKernelBuilder: instance_code += """#include #include #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" -""" - elif self.kernel_name_prefix == "mx_gemm": - instance_code += """#include "ck_tile/ops/gemm_mx.hpp" """ return instance_code @@ -617,9 +614,7 @@ using CDataType = {get_dtype_string(c_type)};""" if self.kernel_name_prefix == "mx_gemm": instance_code += """ using ScaleType = ck_tile::e8m0_t; -using ScaleM = ck_tile::MXScalePointer; -using ScaleN = ck_tile::MXScalePointer; -using MxGemmHostArgs = ck_tile::MXGemmKernelArgs;""" +using MxGemmHostArgs = ck_tile::MxGemmHostArgs<1, 1, 0>;""" if self.kernel_name_prefix == "gemm_multi_d": instance_code += f""" @@ -684,7 +679,7 @@ struct SelectedKernel {{ static constexpr bool kPadN = {"true" if pad_n in [True, "true"] else "false"}; static constexpr bool kPadK = {"true" if pad_k in [True, "true"] else "false"}; static constexpr bool TransposeC = false; - static constexpr bool DoubleSmemBuffer = {"true" if pipeline in ["compv4", "preshufflev2"] else "false"};""" + static constexpr bool DoubleSmemBuffer = {"true" if pipeline in ["compv4", "preshufflev2", "comp_async"] else "false"};""" if self.kernel_name_prefix in [ "gemm_universal", @@ -1069,17 +1064,17 @@ struct SelectedKernel {{ instance_code += f""" // Kernel type - using Kernel = ck_tile::MXGemmKernel; + using Kernel = ck_tile::MxGemmKernel; // Kernel arguments - auto kargs = args; + auto kargs = Kernel::MakeKernelArgs(args); - if(!Kernel::Underlying::IsSupportedArgument(kargs)) {{ + if(!Kernel::IsSupportedArgument(kargs)) {{ throw std::runtime_error("Wrong! Arguments not supported! Skipping mx gemm!"); }} // Get grid and block sizes - const dim3 grids = Kernel::GridSize(kargs); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); const dim3 blocks = Kernel::BlockSize(); if(stream.log_level_ > 0) {{ @@ -1245,7 +1240,15 @@ struct SelectedKernel {{ WarpTileM, WarpTileN, WarpTileK, - TransposeC>; + TransposeC, + 1, // NumWaveGroups + false, // FixedVectorSize_ + 1, // VectorSizeC_ + 1, // BlockedXDLNPerWarp + false, // DoubleSmemBuffer_ + ADataType, // AComputeDataType + BDataType, // BComputeDataType + true>; // TilesPacked_ using GemmEpilogue = ck_tile::CShuffleEpilogue;""" return instance_code diff --git a/tile_engine/ops/gemm/mx_gemm/mx_gemm_profiler.hpp b/tile_engine/ops/gemm/mx_gemm/mx_gemm_profiler.hpp index 55d21b5733..1e6b2fee05 100644 --- a/tile_engine/ops/gemm/mx_gemm/mx_gemm_profiler.hpp +++ b/tile_engine/ops/gemm/mx_gemm/mx_gemm_profiler.hpp @@ -9,7 +9,6 @@ #include #include "ck_tile/host/device_prop.hpp" -#include "ck_tile/ops/gemm_mx.hpp" #include "gemm/gemm_profiler.hpp" #include "mx_gemm_benchmark.hpp" @@ -49,15 +48,13 @@ class MXGemmProfiler : public GemmProfiler scale_a_host(ck_tile::host_tensor_descriptor( - gemm_problem.m_, scale_k_size, stride_scale_a, is_row_major(layout_a))); - ck_tile::HostTensor scale_b_host(ck_tile::host_tensor_descriptor( - scale_k_size, gemm_problem.n_, stride_scale_b, is_row_major(layout_b))); + ck_tile::HostTensor scale_a_host( + {static_cast(gemm_problem.m_), static_cast(scale_k_size)}, + {static_cast(scale_k_size), static_cast(1)}); + ck_tile::HostTensor scale_b_host( + {static_cast(gemm_problem.n_), static_cast(scale_k_size)}, + {static_cast(scale_k_size), static_cast(1)}); if(setting_.init_method == 0) { @@ -109,31 +106,47 @@ class MXGemmProfiler : public GemmProfiler(scale_a_host, - true); - auto scale_b_packed = - pack_mx_scales_mn_x_k(scale_b_host, - false); + ck_tile::HostTensor scale_a_shuffled( + {static_cast(gemm_problem.m_ / m_xdl_pack * 2), + static_cast(scale_k_size / k_xdl_pack * 2)}, + {static_cast(scale_k_size / k_xdl_pack * 2), static_cast(1)}); + + ck_tile::HostTensor scale_b_shuffled( + {static_cast(gemm_problem.n_ / n_xdl_pack * 2), + static_cast(scale_k_size / k_xdl_pack * 2)}, + {static_cast(scale_k_size / k_xdl_pack * 2), static_cast(1)}); + + ck_tile::preShuffleScaleBuffer_gfx950( + scale_a_host.mData.data(), + scale_a_shuffled.mData.data(), + gemm_problem.m_, + scale_k_size, + true); + + ck_tile::preShuffleScaleBuffer_gfx950( + scale_b_host.mData.data(), + scale_b_shuffled.mData.data(), + gemm_problem.n_, + scale_k_size, + true); ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); - ck_tile::DeviceMem scale_a_dev_buf(scale_a_packed.size() * sizeof(int32_t)); - ck_tile::DeviceMem scale_b_dev_buf(scale_b_packed.size() * sizeof(int32_t)); + ck_tile::DeviceMem scale_a_dev_buf(scale_a_shuffled.get_element_space_size_in_bytes()); + ck_tile::DeviceMem scale_b_dev_buf(scale_b_shuffled.get_element_space_size_in_bytes()); a_m_k_dev_buf.ToDevice(a_m_k.data()); b_k_n_dev_buf.ToDevice(b_k_n.data()); c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - scale_a_dev_buf.ToDevice(scale_a_packed.data()); - scale_b_dev_buf.ToDevice(scale_b_packed.data()); - - ScaleM scale_m(reinterpret_cast(scale_a_dev_buf.GetDeviceBuffer())); - ScaleN scale_n(reinterpret_cast(scale_b_dev_buf.GetDeviceBuffer())); + scale_a_dev_buf.ToDevice(scale_a_shuffled.data()); + scale_b_dev_buf.ToDevice(scale_b_shuffled.data()); MxGemmHostArgs gemm_args({a_m_k_dev_buf.GetDeviceBuffer()}, + {scale_a_dev_buf.GetDeviceBuffer()}, {b_k_n_dev_buf.GetDeviceBuffer()}, + {scale_b_dev_buf.GetDeviceBuffer()}, {}, c_m_n_dev_buf.GetDeviceBuffer(), gemm_problem.split_k_, @@ -143,17 +156,26 @@ class MXGemmProfiler : public GemmProfiler c_m_n_host_result(ck_tile::host_tensor_descriptor( gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c))); if(setting_.verify) { + // Host reference computation using reference_mx_gemm + // reference_mx_gemm expects scale_a(M, K/ScaleBlockSize) and scale_b(K/ScaleBlockSize, + // N) We need to create scale_b in (K/ScaleBlockSize, N) format for the reference + ck_tile::HostTensor scale_b_ref( + {static_cast(scale_k_size), static_cast(gemm_problem.n_)}, + {static_cast(1), static_cast(scale_k_size)}); + // Copy scale_b data (our scale_b is (N, scale_k_size) row-major, + // reference expects (scale_k_size, N) col-major, which is the same memory layout) + std::copy( + scale_b_host.mData.begin(), scale_b_host.mData.end(), scale_b_ref.mData.begin()); + mx_gemm_host_reference( - setting_.verify, a_m_k, b_k_n, c_m_n_host_result, scale_a_host, scale_b_host); + setting_.verify, a_m_k, b_k_n, c_m_n_host_result, scale_a_host, scale_b_ref); } for(auto& callable : callables)