diff --git a/example/ck_tile/42_mx_gemm/mx_gemm.cpp b/example/ck_tile/42_mx_gemm/mx_gemm.cpp index c5cdc84689..f26a4c9ff5 100644 --- a/example/ck_tile/42_mx_gemm/mx_gemm.cpp +++ b/example/ck_tile/42_mx_gemm/mx_gemm.cpp @@ -24,13 +24,13 @@ static constexpr inline auto is_row_major(Layout layout_) template float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf, ck_tile::DeviceMem& b_dev_buf, @@ -42,36 +42,38 @@ float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf, ck_tile::index_t stride_B, ck_tile::index_t stride_C, ck_tile::index_t kbatch, - ScaleM scale_m, - ScaleN scale_n, + ck_tile::DeviceMem& scale_m, + ck_tile::DeviceMem& scale_n, int n_warmup, int n_repeat) { - MXGemmHostArgs args(a_dev_buf.GetDeviceBuffer(), - b_dev_buf.GetDeviceBuffer(), - c_dev_buf.GetDeviceBuffer(), - kbatch, - M, - N, - K, - stride_A, - stride_B, - stride_C, - scale_m, - scale_n); + ck_tile::MxGemmHostArgs<1, 1, 0> args({static_cast(a_dev_buf.GetDeviceBuffer())}, + {static_cast(scale_m.GetDeviceBuffer())}, + {static_cast(b_dev_buf.GetDeviceBuffer())}, + {static_cast(scale_n.GetDeviceBuffer())}, + {}, + c_dev_buf.GetDeviceBuffer(), + kbatch, + M, + N, + K, + {stride_A}, + {stride_B}, + {}, + stride_C); // Simplified invocation - comp_async handles hot loop and tail internally auto invoke_splitk_path = [&](auto split_k_) { return mx_gemm_calc( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); diff --git a/example/ck_tile/42_mx_gemm/mx_gemm.hpp b/example/ck_tile/42_mx_gemm/mx_gemm.hpp index 1ee491f7df..f35df974c3 100644 --- a/example/ck_tile/42_mx_gemm/mx_gemm.hpp +++ b/example/ck_tile/42_mx_gemm/mx_gemm.hpp @@ -9,49 +9,8 @@ #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" -#include "ck_tile/ops/gemm_mx.hpp" -#include "ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp" - -template -struct MXGemmHostArgs : ck_tile::UniversalGemmHostArgs<1, 1, 0> -{ - using Base = ck_tile::UniversalGemmHostArgs<1, 1, 0>; - - 
MXGemmHostArgs(const void* a_ptr, - const void* b_ptr, - void* c_ptr_, - ck_tile::index_t k_batch_, - ck_tile::index_t M_, - ck_tile::index_t N_, - ck_tile::index_t K_, - ck_tile::index_t stride_A_, - ck_tile::index_t stride_B_, - ck_tile::index_t stride_C_, - ScaleM scale_m_, - ScaleN scale_n_) - : Base({a_ptr}, - {b_ptr}, - {}, - c_ptr_, - k_batch_, - M_, - N_, - K_, - {stride_A_}, - {stride_B_}, - {}, - stride_C_), - scale_m(scale_m_), - scale_n(scale_n_) - { - } - - ScaleM scale_m; - ScaleN scale_n; -}; // GEMM config with 16x16 warp tile - struct MxGemmConfig { static constexpr ck_tile::index_t M_Tile = 128; @@ -78,12 +37,15 @@ struct MxGemmConfig static constexpr int TileParitionerM01 = 4; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; static constexpr ck_tile::index_t NumWaveGroups = 1; - static constexpr bool DoubleSmemBuffer = false; // comp_async uses double buffer + static constexpr bool DoubleSmemBuffer = true; // comp_async uses double buffer static constexpr bool Preshuffle = false; static constexpr ck_tile::index_t BContiguousItemsPerAccess = 16; static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; static constexpr bool TiledMMAPermuteN = false; + + using AScaleDataType = ck_tile::e8m0_t; + using BScaleDataType = ck_tile::e8m0_t; }; struct MX_GemmConfigEightWaves : MxGemmConfig diff --git a/example/ck_tile/42_mx_gemm/mx_gemm_instance.hpp b/example/ck_tile/42_mx_gemm/mx_gemm_instance.hpp index c7ea374862..a7b57fb330 100644 --- a/example/ck_tile/42_mx_gemm/mx_gemm_instance.hpp +++ b/example/ck_tile/42_mx_gemm/mx_gemm_instance.hpp @@ -5,10 +5,9 @@ #include "ck_tile/host.hpp" #include "mx_gemm.hpp" -#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp" -#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp" -#include "ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp" -#include "ck_tile/ops/gemm_mx/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp" +#include 
"ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_mx_pipeline_agmem_bgmem_creg_v1.hpp" template using is_row_major_t = ck_tile::bool_constant< @@ -17,16 +16,16 @@ using is_row_major_t = ck_tile::bool_constant< template -float mx_gemm_calc(const MXGemmHostArgs& args, const ck_tile::stream_config& s) +float mx_gemm_calc(const ck_tile::MxGemmHostArgs<1, 1, 0>& args, const ck_tile::stream_config& s) { using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -51,29 +50,38 @@ float mx_gemm_calc(const MXGemmHostArgs& args, const ck_tile::st static_assert(sizeof(ComputeDataType) >= sizeof(BDataType), "mixed_prec_gemm requires ADataType is a wider type than BDataType"); - using MXPipelineProblem = ck_tile::UniversalGemmPipelineProblem; + using AComputeDataType = ADataType; + using BComputeDataType = BDataType; + + using MXPipelineProblem = ck_tile::MxGemmPipelineProblem; - // Use the MX GEMM Preshuffle pipeline or - // the new MX comp_async pipeline with MX scaling support constexpr bool IsEightWave = (GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp) == 8; using MXGemmPipeline = std::conditional_t< GemmConfig::Preshuffle, ck_tile::MXGemmPreshufflePipelineAGmemBGmemCRegV1, std::conditional_t, - ck_tile::MXGemmPipelineAgBgCrCompAsync>>; + ck_tile::GemmPipelineAgBgCrCompAsyncEightWaves, + ck_tile::GemmPipelineAgBgCrCompAsync>>; using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner; + constexpr ck_tile::index_t BlockedXDLNPerWarp = GemmConfig::Preshuffle ? 
2 : 1; + using GemmEpilogue = std::conditional_t& args, const ck_tile::st GemmConfig::NumWaveGroups, false, // FixedVectorSize_ (Default) 1, // VectorSizeC_ (Default) - ck_tile::MXEpilogueTraits::BlockedXDLNPerWarp, + BlockedXDLNPerWarp, false, // DoubleSmemBuffer_ (Default) ComputeDataType, // AComputeDataType ComputeDataType, // BComputeDataType !GemmConfig::Preshuffle>>>; // TilesPacked_ (because of // packed scales) - using Kernel = ck_tile::MXGemmKernel; + using Kernel = ck_tile::MxGemmKernel; - auto kargs = Kernel::MakeKernelArgs(std::array{args.as_ptr}, - std::array{args.bs_ptr}, - std::array{}, - args.e_ptr, - args.k_batch, - args.M, - args.N, - args.K, - std::array{args.stride_As}, - std::array{args.stride_Bs}, - std::array{}, - args.stride_E, - args.scale_m, - args.scale_n); + auto kargs = Kernel::MakeKernelArgs(args); if(!Kernel::IsSupportedArgument(kargs)) { @@ -145,8 +140,10 @@ float mx_gemm_calc(const MXGemmHostArgs& args, const ck_tile::st "MX GEMM: unsupported shape/configuration (set CK_TILE_LOGGING=1 for details)."); } - const auto kernel = ck_tile::make_kernel( - Kernel{}, Kernel::GridSize(kargs), Kernel::BlockSize(), 0, kargs); + constexpr int kBlockPerCu = 1; + + const auto kernel = ck_tile::make_kernel( + Kernel{}, Kernel::GridSize(args.M, args.N, args.k_batch), Kernel::BlockSize(), 0, kargs); // For split-K (k_batch > 1) the kernel's epilogue uses atomic_add into C, so C must be // zeroed before every kernel launch -- not just once before the first warmup iteration. 
diff --git a/example/ck_tile/42_mx_gemm/run_mx_gemm.inc b/example/ck_tile/42_mx_gemm/run_mx_gemm.inc index d6b24e6758..52e28443d7 100644 --- a/example/ck_tile/42_mx_gemm/run_mx_gemm.inc +++ b/example/ck_tile/42_mx_gemm/run_mx_gemm.inc @@ -16,6 +16,8 @@ auto calculate_rtol_atol(const ck_tile::index_t K, const float max_accumulated_v template scale_a_host( + {static_cast(M), static_cast(scale_k_size)}, + {static_cast(scale_k_size), static_cast(1)}); + + // scale_b uses N as first dimension (col-major like B) + ck_tile::HostTensor scale_b_host( + {static_cast(N), static_cast(scale_k_size)}, + {static_cast(scale_k_size), static_cast(1)}); - ck_tile::HostTensor scale_a_host( - ck_tile::host_tensor_descriptor(M, scale_k_size, stride_scale_a, is_row_major(ALayout{}))); - ck_tile::HostTensor scale_b_host( - ck_tile::host_tensor_descriptor(scale_k_size, N, stride_scale_b, is_row_major(BLayout{}))); std::mt19937 gen(42); std::uniform_int_distribution fill_seed(0, 500); - auto gen_scales = [&](auto& scales, float range_min, float range_max) { - // e8m0_t is basically an exponent of float32 - ck_tile::HostTensor pow2(scales.get_lengths()); - ck_tile::FillUniformDistributionIntegerValue{range_min, range_max, fill_seed(gen)}( - pow2); - scales.ForEach([&](auto& self, const auto& i) { - self(i) = static_cast(std::exp2(pow2(i))); - }); - }; switch(init_method) { case 0: // Initialize A, B, and scales to random values ck_tile::FillUniformDistribution{-2.f, 2.f, fill_seed(gen)}(a_host); ck_tile::FillUniformDistribution{-2.f, 2.f, fill_seed(gen)}(b_host); - gen_scales(scale_a_host, -2, 2); - gen_scales(scale_b_host, -2, 2); + ck_tile::FillUniformScaleDistribution{0.125f, 2.0f, fill_seed(gen)}( + scale_a_host); + ck_tile::FillUniformScaleDistribution{0.125f, 2.0f, fill_seed(gen)}( + scale_b_host); break; case 1: // Initialize A, B, and scales to 1.0 ck_tile::FillConstant{ADataType(1.f)}(a_host); ck_tile::FillConstant{BDataType(1.f)}(b_host); - gen_scales(scale_a_host, 0, 0); - 
gen_scales(scale_b_host, 0, 0); + ck_tile::FillUniformScaleDistribution{1.0f, 1.0f, fill_seed(gen)}( + scale_a_host); + ck_tile::FillUniformScaleDistribution{1.0f, 1.0f, fill_seed(gen)}( + scale_b_host); break; case 2: // Initialize A and B with random values but with constant 1.0 scales ck_tile::FillUniformDistribution{-2.f, 2.f, fill_seed(gen)}(a_host); ck_tile::FillUniformDistribution{-2.f, 2.f, fill_seed(gen)}(b_host); - gen_scales(scale_a_host, 0, 0); - gen_scales(scale_b_host, 0, 0); + ck_tile::FillUniformScaleDistribution{1.0f, 1.0f, fill_seed(gen)}( + scale_a_host); + ck_tile::FillUniformScaleDistribution{1.0f, 1.0f, fill_seed(gen)}( + scale_b_host); break; } @@ -128,12 +125,31 @@ int run_mx_gemm_with_layouts(int argc, char* argv[], ALayout, BLayout, CLayout) constexpr ck_tile::index_t XdlMNThread = GemmConfig::M_Warp_Tile; constexpr ck_tile::index_t XdlKThread = 64 / XdlMNThread; - auto scale_a_packed = - ck_tile::packScalesMNxK(scale_a_host, - true); - auto scale_b_packed = - ck_tile::packScalesMNxK(scale_b_host, - false); + ck_tile::HostTensor scale_a_shuffled( + {static_cast(M / MXdlPackEff * 2), + static_cast(scale_k_size / KXdlPackEff * 2)}, + {static_cast(scale_k_size / KXdlPackEff * 2), static_cast(1)}); + + ck_tile::HostTensor scale_b_shuffled( + {static_cast(N / NXdlPackEff * 2), + static_cast(scale_k_size / KXdlPackEff * 2)}, + {static_cast(scale_k_size / KXdlPackEff * 2), static_cast(1)}); + + ck_tile::preShuffleScaleBuffer_gfx950( + scale_a_host.mData.data(), scale_a_shuffled.mData.data(), M, scale_k_size, true); + + if constexpr(GemmConfig::Preshuffle && GemmConfig::TiledMMAPermuteN) + { + ck_tile::preShuffleScaleBufferPermuteN_gfx950( + scale_b_host.mData.data(), scale_b_shuffled.mData.data(), N, scale_k_size, true); + } + else + { + ck_tile::preShuffleScaleBuffer_gfx950( + scale_b_host.mData.data(), scale_b_shuffled.mData.data(), N, scale_k_size, true); + } const auto b_host_for_device = [&]() { if constexpr(GemmConfig::Preshuffle) @@ 
-145,73 +161,29 @@ int run_mx_gemm_with_layouts(int argc, char* argv[], ALayout, BLayout, CLayout) return b_host; }(); - const auto scale_a_host_for_device = [&]() { - if constexpr(GemmConfig::Preshuffle) - return ck_tile::preShuffleScale(scale_a_host, true); - else - return scale_a_packed; - }(); - - constexpr ck_tile::index_t XdlNThread = GemmConfig::N_Warp_Tile; - constexpr ck_tile::index_t NPerBlock = GemmConfig::N_Tile; - constexpr ck_tile::index_t NWarp = GemmConfig::N_Warp; - - const auto scale_b_host_for_device = [&]() { - if constexpr(GemmConfig::Preshuffle) - { - if constexpr(GemmConfig::TiledMMAPermuteN) - return ck_tile::preShuffleScalePermuteN(scale_b_host, - false); - else - return ck_tile::preShuffleScale(scale_b_host, false); - } - else - return scale_b_packed; - }(); - - const auto scale_a_device_bytes = [&]() { - if constexpr(GemmConfig::Preshuffle) - return scale_a_host_for_device.get_element_space_size_in_bytes(); - else - return scale_a_host_for_device.size() * sizeof(int32_t); - }(); - - const auto scale_b_device_bytes = [&]() { - if constexpr(GemmConfig::Preshuffle) - return scale_b_host_for_device.get_element_space_size_in_bytes(); - else - return scale_b_host_for_device.size() * sizeof(int32_t); - }(); - // Device buffers for A, B, C, and packed scale tensors ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_dev_buf(b_host_for_device.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_dev_buf(c_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem scale_a_dev_buf(scale_a_device_bytes); - ck_tile::DeviceMem scale_b_dev_buf(scale_b_device_bytes); + ck_tile::DeviceMem scale_a_dev_buf(scale_a_shuffled.get_element_space_size_in_bytes()); + ck_tile::DeviceMem scale_b_dev_buf(scale_b_shuffled.get_element_space_size_in_bytes()); a_dev_buf.ToDevice(a_host.data()); b_dev_buf.ToDevice(b_host_for_device.data()); c_dev_buf.SetZero(); - scale_a_dev_buf.ToDevice(scale_a_host_for_device.data()); - 
scale_b_dev_buf.ToDevice(scale_b_host_for_device.data()); - - // Scale pointers - point to packed int32_t data, kernel reinterprets as int32_t* - using ScaleM = ck_tile::MXScalePointer; - using ScaleN = ck_tile::MXScalePointer; - ScaleM scale_m(reinterpret_cast(scale_a_dev_buf.GetDeviceBuffer())); - ScaleN scale_n(reinterpret_cast(scale_b_dev_buf.GetDeviceBuffer())); + scale_a_dev_buf.ToDevice(scale_a_shuffled.data()); + scale_b_dev_buf.ToDevice(scale_b_shuffled.data()); float ave_time = invoke_mx_gemm(a_dev_buf, b_dev_buf, c_dev_buf, @@ -222,8 +194,8 @@ int run_mx_gemm_with_layouts(int argc, char* argv[], ALayout, BLayout, CLayout) stride_B, stride_C, kbatch, - scale_m, - scale_n, + scale_a_dev_buf, + scale_b_dev_buf, n_warmup, n_repeat); @@ -240,9 +212,23 @@ int run_mx_gemm_with_layouts(int argc, char* argv[], ALayout, BLayout, CLayout) ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); c_m_n_host_ref.SetZero(); - ck_tile:: - reference_mx_gemm( - a_host, b_host, c_m_n_host_ref, scale_a_host, scale_b_host); + // Host reference computation using reference_mx_gemm + // reference_mx_gemm expects scale_a(M, K/ScaleBlockSize) and scale_b(K/ScaleBlockSize, N) + // We need to create scale_b in (K/ScaleBlockSize, N) format for the reference + ck_tile::HostTensor scale_b_ref( + {static_cast(scale_k_size), static_cast(N)}, + {static_cast(1), static_cast(scale_k_size)}); + // Copy scale_b data (our scale_b is (N, scale_k_size) row-major, + // reference expects (scale_k_size, N) col-major, which is the same memory layout) + std::copy(scale_b_host.mData.begin(), scale_b_host.mData.end(), scale_b_ref.mData.begin()); + + ck_tile::reference_mx_gemm( + a_host, b_host, c_m_n_host_ref, scale_a_host, scale_b_ref); const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); @@ -285,6 +271,8 @@ int run_mx_gemm_example(int argc, char* argv[]) { return run_mx_gemm_with_layouts(argc, argv, Row{}, Col{}, Row{}); 
@@ -293,6 +281,8 @@ int run_mx_gemm_example(int argc, char* argv[]) { return run_mx_gemm_with_layouts(argc, argv, Row{}, Col{}, Row{}); @@ -304,6 +294,8 @@ int run_mx_gemm_example(int argc, char* argv[]) { return run_mx_gemm_with_layouts(argc, argv, Row{}, Col{}, Row{}); @@ -312,6 +304,8 @@ int run_mx_gemm_example(int argc, char* argv[]) { return run_mx_gemm_with_layouts(argc, argv, Row{}, Col{}, Row{}); diff --git a/include/ck_tile/core/tensor/null_tensor.hpp b/include/ck_tile/core/tensor/null_tensor.hpp index f5cabbef5a..fbbef3ab51 100644 --- a/include/ck_tile/core/tensor/null_tensor.hpp +++ b/include/ck_tile/core/tensor/null_tensor.hpp @@ -3,10 +3,28 @@ #pragma once +#include "ck_tile/core/utility/type_traits.hpp" + namespace ck_tile { struct null_tensor { }; +// utility to check if this is a Null Tensor +namespace impl { +template +struct is_null_tensor : public std::false_type +{ +}; + +template <> +struct is_null_tensor : public std::true_type +{ +}; +} // namespace impl + +template +constexpr bool is_null_tensor_v = impl::is_null_tensor>::value; + } // namespace ck_tile diff --git a/include/ck_tile/host/mx_processing.hpp b/include/ck_tile/host/mx_processing.hpp index c15df6f509..e5eaf70fb8 100644 --- a/include/ck_tile/host/mx_processing.hpp +++ b/include/ck_tile/host/mx_processing.hpp @@ -7,148 +7,163 @@ namespace ck_tile { +/// @brief Pre-shuffle scale buffer for gfx1250 wmma mx scale instruction. +/// +/// Reorganizes the scale data from row-major (MN x K) layout to the hardware-specific +/// layout expected by the gfx1250 wmma instruction. +/// +/// @tparam ScaleType Scale data type (e.g., e8m0_t) +/// @tparam ScaleBlockSize The block size for microscaling (e.g., 32) +/// @tparam KStride Whether K is the fast-moving dimension +template +void preShuffleScaleBuffer_gfx1250(const ScaleType* src, + ScaleType* dst, + ck_tile::index_t MN, + ck_tile::index_t K) +{ + static_assert((ScaleBlockSize == 32 || ScaleBlockSize == 16) && sizeof(ScaleType) == 1, + "wrong! 
only support 8-bit scale with ScaleBlockSize=32 or 16"); + + // ScaleBlockSize == 16: the natural row-major scale layout already matches the gfx1250 + // wmma scale distribution (one e8m0 per 16 K-elements lands warp-aligned), so the + // device-side shuffle is the identity transform for all K. + if constexpr(ScaleBlockSize == 16) + { + for(ck_tile::long_index_t mn = 0; mn < MN; ++mn) + for(ck_tile::long_index_t k = 0; k < K; ++k) + { + if constexpr(KStride) + dst[mn * K + k] = src[mn * K + k]; + else + dst[mn * K + k] = src[k * MN + mn]; + } + return; + } + + constexpr ck_tile::long_index_t MPerXdlops = 16; + constexpr ck_tile::long_index_t KPerXdlops = 128; + + ck_tile::long_index_t MNPack = 2; + ck_tile::long_index_t KPack = 1; + + ck_tile::long_index_t MNStep = MPerXdlops; + ck_tile::long_index_t KStep = KPerXdlops / ScaleBlockSize; + + ck_tile::long_index_t K0 = K / KPack / KStep; + + for(ck_tile::long_index_t mn = 0; mn < MN; ++mn) + { + ck_tile::long_index_t iMNRepeat = mn / (MNStep * MNPack); + ck_tile::long_index_t tempmn = mn % (MNStep * MNPack); + + for(ck_tile::long_index_t k = 0; k < K; ++k) + { + ck_tile::long_index_t iKRepeat = k / (KStep * KPack); + ck_tile::long_index_t tempk = k % (KStep * KPack); + + ck_tile::long_index_t outputIndex = + (iMNRepeat * MNPack * MNStep) * (KStep * KPack * K0) + + (iKRepeat * KStep * KPack) * (MNStep * MNPack) + tempmn * (KStep * KPack) + tempk; + + if constexpr(KStride) + { + dst[outputIndex] = src[mn * K + k]; + } + else + dst[outputIndex] = src[k * MN + mn]; + } + } +} + // Pack [MN, K/32] e8m0_t scales into [MN/MNPack, K/32/KPack] int32_t // Each int32_t contains MNPack * KPack e8m0_t values with byte layout matching // the GPU tile distribution: values are XdlMNThread apart in M and XdlKThread apart in K. 
// byte[ik * MNPack + imn] = e8m0 at strided (mn, k) position // kLast=true for A scales (layout [M, K/32]), kLast=false for B scales (layout [K/32, N]) -template -auto packScalesMNxK(const HostTensor& src, const bool kLast) +template +void preShuffleScaleBuffer_gfx950(const ScaleType* src, + ScaleType* packed, + ck_tile::index_t MN, + ck_tile::index_t K_scale, + bool kLast) { - auto src_lengths = src.get_lengths(); - const index_t MN = kLast ? src_lengths[0] : src_lengths[1]; - const index_t K_scale = kLast ? src_lengths[1] : src_lengths[0]; - const index_t MN_packed = MN / MNPack; - const index_t K_packed = K_scale / KPack; + const ck_tile::long_index_t MN_packed = MN / MNPack; + const ck_tile::long_index_t K_packed = K_scale / KPack; + constexpr ck_tile::long_index_t NumScalesPerDword = 4 / sizeof(ScaleType); - // Output as flat vector of int32_t (row-major [MN/MNPack, K/32/KPack]) - HostTensor packed(HostTensorDescriptor( - {static_cast(MN_packed), static_cast(K_packed)}, - {static_cast(K_packed), static_cast(1)})); - - for(index_t packed_mn = 0; packed_mn < MN_packed; packed_mn++) + for(ck_tile::long_index_t packed_mn = 0; packed_mn < MN_packed; packed_mn++) { - for(index_t packed_k = 0; packed_k < K_packed; packed_k++) + for(ck_tile::long_index_t packed_k = 0; packed_k < K_packed; packed_k++) { - uint32_t val = 0; - index_t mn_lane = packed_mn % XdlMNThread; - index_t mn_group = packed_mn / XdlMNThread; - index_t k_lane = packed_k % XdlKThread; - index_t k_group = packed_k / XdlKThread; - for(index_t ik = 0; ik < KPack; ik++) + ck_tile::long_index_t mn_lane = packed_mn % XdlMNThread; + ck_tile::long_index_t mn_group = packed_mn / XdlMNThread; + ck_tile::long_index_t k_lane = packed_k % XdlKThread; + ck_tile::long_index_t k_group = packed_k / XdlKThread; + for(ck_tile::long_index_t ik = 0; ik < KPack; ik++) { - for(index_t imn = 0; imn < MNPack; imn++) + for(ck_tile::long_index_t imn = 0; imn < MNPack; imn++) { - index_t byteIdx = ik * MNPack + imn; - index_t 
orig_mn = mn_group * XdlMNThread * MNPack + imn * XdlMNThread + mn_lane; - index_t orig_k = k_group * XdlKThread * KPack + ik * XdlKThread + k_lane; + ck_tile::long_index_t byteIdx = ik * MNPack + imn; + ck_tile::long_index_t orig_mn = + mn_group * XdlMNThread * MNPack + imn * XdlMNThread + mn_lane; + ck_tile::long_index_t orig_k = + k_group * XdlKThread * KPack + ik * XdlKThread + k_lane; - e8m0_t v = kLast ? src(orig_mn, orig_k) : src(orig_k, orig_mn); - val |= (static_cast(v.get()) << (byteIdx * 8)); + ck_tile::long_index_t inputIndex = + kLast ? orig_k + orig_mn * K_scale : orig_mn + orig_k * MN; + ScaleType v = src[inputIndex]; + ck_tile::long_index_t outputIndex = + byteIdx + (packed_mn % XdlMNThread) * NumScalesPerDword + + packed_k * XdlMNThread * NumScalesPerDword + + (packed_mn / XdlMNThread) * XdlMNThread * NumScalesPerDword * K_packed; + packed[outputIndex] = v; } } - packed(packed_mn, packed_k) = static_cast(val); } } - return packed; } -template -auto preShuffleScale(ck_tile::HostTensor& src, const bool kLast) +template +auto preShuffleScaleBufferPermuteN_gfx950( + const ScaleType* src, ScaleType* shuffled, ck_tile::index_t MN, ck_tile::index_t K, bool kLast) { - auto src_lengths = src.get_lengths(); - const index_t MN = kLast ? src_lengths[0] : src_lengths[1]; - const index_t K = kLast ? src_lengths[1] : src_lengths[0]; - - constexpr index_t MNXdlPack = 2; - constexpr index_t KXdlPack = 2; - constexpr index_t XdlKThread = get_warp_size() / XdlMNThread; - - const auto MNPadded = integer_least_multiple(MN, XdlMNThread * MNXdlPack); - HostTensor shuffled(HostTensorDescriptor({static_cast(MNPadded * K)}, - {static_cast(1)})); + constexpr ck_tile::long_index_t MNXdlPack = 2; + constexpr ck_tile::long_index_t KXdlPack = 2; + constexpr ck_tile::long_index_t NRepeat = NPerBlock / NWarp / XdlMNThread; + constexpr ck_tile::long_index_t XdlKThread = ck_tile::get_warp_size() / XdlMNThread; if(K % (KXdlPack * XdlKThread) != 0) { throw std::runtime_error("wrong! 
K must be a multiple of (KXdlPack * XdlKThread)"); } + const ck_tile::long_index_t K0 = K / KXdlPack / XdlKThread; - const index_t K0 = K / KXdlPack / XdlKThread; - - for(index_t n = 0; n < MNPadded; ++n) + for(ck_tile::long_index_t n = 0; n < MN; ++n) { - for(index_t k = 0; k < K; ++k) + for(ck_tile::long_index_t k = 0; k < K; ++k) { - const index_t n0 = n / (XdlMNThread * MNXdlPack); - const index_t tempn = n % (XdlMNThread * MNXdlPack); - const index_t n1 = tempn % XdlMNThread; - const index_t n2 = tempn / XdlMNThread; + const ck_tile::long_index_t n0 = n / NPerBlock; + const ck_tile::long_index_t tempn0 = n % NPerBlock; + const ck_tile::long_index_t n1 = tempn0 / (XdlMNThread * NRepeat); + const ck_tile::long_index_t tempn1 = tempn0 % (XdlMNThread * NRepeat); + const ck_tile::long_index_t n2 = tempn1 / (NRepeat); + const ck_tile::long_index_t tempn2 = tempn1 % (NRepeat); + const ck_tile::long_index_t n3 = tempn2 % MNXdlPack; + const ck_tile::long_index_t n4 = tempn2 / MNXdlPack; - const index_t k0 = k / (XdlKThread * KXdlPack); - const index_t tempk = k % (XdlKThread * KXdlPack); - const index_t k1 = tempk % XdlKThread; - const index_t k2 = tempk / XdlKThread; + const ck_tile::long_index_t k0 = k / (XdlKThread * KXdlPack); + const ck_tile::long_index_t tempk = k % (XdlKThread * KXdlPack); + const ck_tile::long_index_t k1 = tempk % XdlKThread; + const ck_tile::long_index_t k2 = tempk / XdlKThread; - const index_t outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + - k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + - k1 * MNXdlPack * KXdlPack * XdlMNThread + - n1 * MNXdlPack * KXdlPack + k2 * MNXdlPack + n2; - - if(n < MN) - { - shuffled(outputIndex) = kLast ? src(n, k) : src(k, n); - } - else - { - shuffled(outputIndex) = dtype{}; - } - } - } - - return shuffled; -} - -template -auto preShuffleScalePermuteN(const HostTensor& src, const bool kLast) -{ - auto src_lengths = src.get_lengths(); - const index_t MN = kLast ? 
src_lengths[0] : src_lengths[1]; - const index_t K = kLast ? src_lengths[1] : src_lengths[0]; - - constexpr index_t MNXdlPack = 2; - constexpr index_t KXdlPack = 2; - constexpr index_t NRepeat = NPerBlock / NWarp / XdlMNThread; - constexpr index_t XdlKThread = get_warp_size() / XdlMNThread; // 4 - - const index_t MNPadded = integer_least_multiple(MN, NPerBlock); - HostTensor shuffled(HostTensorDescriptor({static_cast(MNPadded * K)}, - {static_cast(1)})); - - if(K % (KXdlPack * XdlKThread) != 0) - { - throw std::runtime_error("wrong! K must be a multiple of (KXdlPack * XdlKThread)"); - } - const index_t K0 = K / KXdlPack / XdlKThread; - - for(index_t n = 0; n < MNPadded; ++n) - { - for(index_t k = 0; k < K; ++k) - { - const index_t n0 = n / NPerBlock; - const index_t tempn0 = n % NPerBlock; - const index_t n1 = tempn0 / (XdlMNThread * NRepeat); - const index_t tempn1 = tempn0 % (XdlMNThread * NRepeat); - const index_t n2 = tempn1 / (NRepeat); - const index_t tempn2 = tempn1 % (NRepeat); - const index_t n3 = tempn2 % MNXdlPack; - const index_t n4 = tempn2 / MNXdlPack; - - const index_t k0 = k / (XdlKThread * KXdlPack); - const index_t tempk = k % (XdlKThread * KXdlPack); - const index_t k1 = tempk % XdlKThread; - const index_t k2 = tempk / XdlKThread; - - const index_t outputIndex = + const ck_tile::long_index_t outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 * NWarp * (NRepeat / MNXdlPack) + n1 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + @@ -156,13 +171,15 @@ auto preShuffleScalePermuteN(const HostTensor& src, const bool kLast) k1 * MNXdlPack * KXdlPack * XdlMNThread + k2 * MNXdlPack + n4 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 * NWarp + n3; + ck_tile::long_index_t inputIndex = kLast ? k + n * K : n + k * MN; + if(n < MN) { - shuffled(outputIndex) = kLast ? 
src(n, k) : src(k, n); + shuffled[outputIndex] = src[inputIndex]; } else { - shuffled(outputIndex) = dtype{}; + shuffled[outputIndex] = ScaleType{}; } } } diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 81987876ee..285e0f2d99 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -1184,4 +1184,131 @@ void reference_batched_gemm_gpu(ADataType* a_ptr, return; } +// GPU reference for MX (microscaling) GEMM with e8m0 block scales. +// +// This is the device counterpart of the host `reference_mx_gemm` above. It exists so the +// reference can be computed entirely on the GPU for large problems (e.g. M*N ~ 1e9) where the +// host reference is intractable and where copying the 39 GB of inputs back to host is not +// feasible. It is a faithful mirror of the host semantics: +// - per-element dot product over K, with each A/B element dequantized by its e8m0 block scale. +// - all addressing uses `long`; the existing `naive_gemm_kernel`/`blockwise_gemm_kernel` use +// `int` and silently overflow once M*N exceeds INT_MAX. +// +// Layout assumptions match the fp4 CompAsync grouped-GEMM test row (RowMajor A, ColumnMajor B, +// RowMajor C): +// - A is RowMajor: element (m,k) at linear offset m*K + k. +// - B is ColumnMajor: element (k,n) at linear offset n*K + k (K is the fast dimension). +// - C is RowMajor: element (m,n) at linear offset m*N + n. +// - scale_a is (M, num_scale_k) RowMajor: scale for (m, k) at m*num_scale_k + k/scale_block_size. +// - scale_b is (N, num_scale_k) RowMajor: scale for (k, n) at n*num_scale_k + k/scale_block_size. 
+template +__global__ void reference_mx_gemm_kernel(const ADataType* __restrict__ a_ptr, + const BDataType* __restrict__ b_ptr, + const AScaleDataType* __restrict__ scale_a_ptr, + const BScaleDataType* __restrict__ scale_b_ptr, + CDataType* __restrict__ c_ptr, + long M, + long N, + long K, + long num_scale_k, + long scale_block_size) +{ + const long total = M * N; + const long idx0 = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const long nthr = static_cast(gridDim.x) * blockDim.x; + + for(long idx = idx0; idx < total; idx += nthr) + { + const long m = idx / N; + const long n = idx % N; + + AccDataType acc = 0; + for(long k = 0; k < K; ++k) + { + // --- A element (RowMajor) --- + AccDataType a_val; + const long a_lin = m * K + k; + if constexpr(std::is_same_v) + { + const fp32x2_t a_f2 = pk_fp4_to_fp32x2(a_ptr[a_lin / 2], 1.0f); + a_val = type_convert((a_lin % 2 == 0) ? a_f2.lo : a_f2.hi); + } + else + { + a_val = type_convert(a_ptr[a_lin]); + } + const float a_sc = + type_convert(scale_a_ptr[m * num_scale_k + k / scale_block_size]); + + // --- B element (ColumnMajor, K fast) --- + AccDataType b_val; + const long b_lin = n * K + k; + if constexpr(std::is_same_v) + { + const fp32x2_t b_f2 = pk_fp4_to_fp32x2(b_ptr[b_lin / 2], 1.0f); + b_val = type_convert((b_lin % 2 == 0) ? 
b_f2.lo : b_f2.hi); + } + else + { + b_val = type_convert(b_ptr[b_lin]); + } + const float b_sc = + type_convert(scale_b_ptr[n * num_scale_k + k / scale_block_size]); + + acc += (a_val * type_convert(a_sc)) * + (b_val * type_convert(b_sc)); + } + c_ptr[m * N + n] = type_convert(acc); + } +} + +template +void reference_mx_gemm_gpu(const ADataType* a_ptr, + const BDataType* b_ptr, + const AScaleDataType* scale_a_ptr, + const BScaleDataType* scale_b_ptr, + CDataType* c_ptr, + index_t M, + index_t N, + index_t K, + index_t num_scale_k, + index_t scale_block_size, + hipStream_t stream = nullptr) +{ + const long total = static_cast(M) * N; + constexpr int threads = 256; + constexpr long max_blocks = 2097152; // grid-stride cap (~2M blocks) + const long needed = (total + threads - 1) / threads; + const long blocks = needed < max_blocks ? needed : max_blocks; + + reference_mx_gemm_kernel + <<(blocks)), dim3(threads), 0, stream>>>( + a_ptr, + b_ptr, + scale_a_ptr, + scale_b_ptr, + c_ptr, + static_cast(M), + static_cast(N), + static_cast(K), + static_cast(num_scale_k), + static_cast(scale_block_size)); + hip_check_error(hipGetLastError()); +} + } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 6779da556e..b93163c202 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -30,6 +30,7 @@ #include "ck_tile/ops/gemm/block/block_gemm_mx_areg_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_mx_areg_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" +#include "ck_tile/ops/gemm/block/block_mx_asmem_breg_creg.hpp" #include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp" #include "ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp" #include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp" @@ -78,6 +79,8 @@ #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include 
"ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_mx_pipeline_agmem_bgmem_creg_v1.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_mx_pipeline_agmem_bgmem_creg_v1_policy.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_tdm.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_tdm_policy.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_eight_waves_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_eight_waves_v1.hpp index 9f91c06e8e..76b5eede5c 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_eight_waves_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_eight_waves_v1.hpp @@ -35,6 +35,8 @@ struct BlockGemmARegBRegCRegEightWavesV1 static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; + static constexpr auto PackMNIter = Policy::PackMNIter; + static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); using WarpGemm = remove_cvref_t())>; @@ -95,6 +97,8 @@ struct BlockGemmARegBRegCRegEightWavesV1 static constexpr auto Scheduler = Traits::Scheduler; static constexpr bool TransposeC = Traits::TransposeC; + static constexpr bool PackMNIter = Traits::PackMNIter; + using AWarpDstr = typename WarpGemm::AWarpDstr; using BWarpDstr = typename WarpGemm::BWarpDstr; using CWarpDstr = typename WarpGemm::CWarpDstr; @@ -136,17 +140,34 @@ struct BlockGemmARegBRegCRegEightWavesV1 sequence, sequence>; - constexpr auto a_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, KIterSeq>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 1>>{}; - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + if 
constexpr(PackMNIter) + { + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, KIterSeq>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<1, 1>>{}; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); - return a_block_dstr_encode; + return a_block_dstr_encode; + } + else + { + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, KIterSeq>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 1>>{}; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + return a_block_dstr_encode; + } } CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() @@ -161,32 +182,66 @@ struct BlockGemmARegBRegCRegEightWavesV1 sequence, sequence>; - constexpr auto b_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, KIterSeq>, - tuple>, - tuple>, - sequence<>, - sequence<>>{}; + if constexpr(PackMNIter) + { + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, KIterSeq>, + tuple>, + tuple>, + sequence<>, + sequence<>>{}; - constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - return b_block_dstr_encode; + return b_block_dstr_encode; + } + else + { + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, KIterSeq>, + tuple>, + tuple>, + sequence<>, + sequence<>>{}; + + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + return b_block_dstr_encode; + } } CK_TILE_DEVICE 
static constexpr auto MakeCBlockDistributionEncode() { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence<2, NIterPerWarp, NWarp / 2>>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 1>>{}; - constexpr auto c_block_dstr_encoding = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - return c_block_dstr_encoding; + if constexpr(PackMNIter) + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence<2, NIterPerWarp, NWarp / 2>>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<1, 1>>{}; + constexpr auto c_block_dstr_encoding = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + return c_block_dstr_encoding; + } + else + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence<2, NIterPerWarp, NWarp / 2>>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 1>>{}; + constexpr auto c_block_dstr_encoding = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + return c_block_dstr_encoding; + } } CK_TILE_DEVICE static constexpr auto MakeCBlockTile() @@ -252,6 +307,101 @@ struct BlockGemmARegBRegCRegEightWavesV1 }); } + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ALdsTile& a_warp_tile_, + const BLdsTiles& b_warp_tiles_, + const ScaleATensor& scale_a_tensor, + const ScaleBTensor& scale_b_tensor) const + { + // checks + static_assert(std::is_same_v>, + "CDataType must be same as CBlockTensor::DataType!"); + static_assert( + std::is_same_v, + remove_cvref_t>, + "C distribution is wrong!"); + + // Effective XdlPack: fall back to 1 when iteration count is insufficient + constexpr index_t MXdlPack = + (MIterPerWarp >= MXdlPack_ && MIterPerWarp % MXdlPack_ == 0) ? 
MXdlPack_ : 1; + constexpr index_t NXdlPack = + (NIterPerWarp >= NXdlPack_ && NIterPerWarp % NXdlPack_ == 0) ? NXdlPack_ : 1; + constexpr index_t KXdlPack = + (KIterPerWarp >= KXdlPack_ && KIterPerWarp % KXdlPack_ == 0) ? KXdlPack_ : 1; + + constexpr index_t MPackIterPerWarp = MIterPerWarp / MXdlPack; + constexpr index_t NPackIterPerWarp = NIterPerWarp / NXdlPack; + constexpr index_t KPackIterPerWarp = KIterPerWarp / KXdlPack; + + // hot loop: + static_for_product, + number, + number>{}([&](auto ikpack, auto inpack, auto impack) { + // get A scale for this M-K tile using get_y_sliced_thread_data + auto scale_a_slice = scale_a_tensor.get_y_sliced_thread_data( + sequence{}, sequence<1, 1, 1>{}); + const int32_t a_scale_packed = bit_cast(scale_a_slice[number<0>{}]); + + // get B scale for this N-K tile using get_y_sliced_thread_data + auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data( + sequence{}, sequence<1, 1, 1>{}); + const int32_t b_scale_packed = bit_cast(scale_b_slice[number<0>{}]); + + // Inner loops: issue MFMAs within the pack group using OpSel + static_for_product, number, number>{}( + [&](auto ikxdl, auto inxdl, auto imxdl) { + constexpr auto kIter = ikpack * KXdlPack + ikxdl; + constexpr auto mIter = impack * MXdlPack + imxdl; + constexpr auto nIter = inpack * NXdlPack + inxdl; + + // OpSel for A: selects byte within packed int32_t + constexpr index_t kOpSelA = ikxdl * MXdlPack + imxdl; + + // OpSel for B: selects byte within packed int32_t + constexpr index_t kOpSelB = ikxdl * NXdlPack + inxdl; + + // read A warp tensor from A Block window + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = + b_warp_tiles_[number{}][number{}].get_thread_buffer(); + + // read C warp tensor 
from C block tensor + using c_iter_idx = sequence; + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM with MX scaling + WarpGemm{}.template operator(), OpSelB>( + c_warp_tensor, + a_warp_tensor, + b_warp_tensor, + a_scale_packed, + b_scale_packed); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + } + template CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, const ALdsTile& a_warp_tile_, diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp index 51fd49f88f..0465dcb2af 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp @@ -39,6 +39,8 @@ struct BlockGemmARegBRegCRegV1 static constexpr auto KSubTileNum = Policy::KSubTileNum; + static constexpr auto PackMNIter = Policy::PackMNIter; + static constexpr index_t MWarp = config.template at<1>(); static constexpr index_t NWarp = config.template at<2>(); static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); @@ -69,6 +71,8 @@ struct BlockGemmARegBRegCRegV1 static constexpr index_t KSubTileNum = Traits::KSubTileNum; + static constexpr bool PackMNIter = Traits::PackMNIter; + static constexpr index_t KPerSubTile = KIterPerWarp / KSubTileNum; static constexpr index_t MWarp = Traits::MWarp; @@ -94,17 +98,34 @@ struct BlockGemmARegBRegCRegV1 } else { - constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - constexpr auto 
a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + if constexpr(PackMNIter) + { + constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<1, 0>>{}; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); - return a_block_dstr_encode; + return a_block_dstr_encode; + } + else + { + constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + return a_block_dstr_encode; + } } } @@ -126,50 +147,103 @@ struct BlockGemmARegBRegCRegV1 } else { - constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + if constexpr(PackMNIter) + { + constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<1, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - return b_block_dstr_encode; + return b_block_dstr_encode; + } + else + { + constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto b_block_dstr_encode = 
detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + return b_block_dstr_encode; + } } } CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode() { using c_distr_ys_major = std::conditional_t, sequence<1, 2>>; - if constexpr(UseDefaultScheduler) + if constexpr(PackMNIter) { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple<>, - tuple<>, - c_distr_ys_major, - sequence<0, 0>>{}; - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + if constexpr(UseDefaultScheduler) + { + using c_distr_ys_minor = + std::conditional_t, sequence<0, 1>>; + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + c_distr_ys_major, + c_distr_ys_minor>{}; + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - return c_block_dstr_encode; + return c_block_dstr_encode; + } + else + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + c_distr_ys_major, + sequence<1, 1>>{}; + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + + return c_block_dstr_encode; + } } else { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - c_distr_ys_major, - sequence<0, 0>>{}; - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + if constexpr(UseDefaultScheduler) + { + constexpr auto c_block_outer_dstr_encoding = 
tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple<>, + tuple<>, + c_distr_ys_major, + sequence<0, 0>>{}; + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - return c_block_dstr_encode; + return c_block_dstr_encode; + } + else + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + c_distr_ys_major, + sequence<0, 0>>{}; + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + + return c_block_dstr_encode; + } } } @@ -411,9 +485,12 @@ struct BlockGemmARegBRegCRegV1 typename BBlockTensor, typename ScaleATensor, typename ScaleBTensor, - index_t MXdlPack_ = 2, - index_t NXdlPack_ = 2, - index_t KXdlPack_ = 2> + index_t MXdlPack_ = 2, + index_t NXdlPack_ = 2, + index_t KXdlPack_ = 2, + typename std::enable_if_t && + !is_null_tensor_v, + bool>* = nullptr> CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, const ABlockTensor& a_block_tensor, const BBlockTensor& b_block_tensor, @@ -423,7 +500,7 @@ struct BlockGemmARegBRegCRegV1 static_assert(std::is_same_v> && std::is_same_v> && std::is_same_v>, - "wrong!"); + "Datatypes do not match BlockTensor datatypes!"); // check ABC-block-distribution static_assert( @@ -545,41 +622,27 @@ struct BlockGemmARegBRegCRegV1 }); } + template < + typename CBlockTensor, + typename ABlockTensor, + typename BBlockTensor, + typename ScaleATensor, + typename ScaleBTensor, + typename std::enable_if_t && is_null_tensor_v, + bool>* = nullptr> + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockTensor& a_block_tensor, + const BBlockTensor& b_block_tensor, + const ScaleATensor&, + const ScaleBTensor&) const + { + operator()(c_block_tensor, a_block_tensor, b_block_tensor); + } + CK_TILE_DEVICE static constexpr auto 
MakeCBlockTile() { - using c_distr_ys_major = std::conditional_t, sequence<1, 2>>; - if constexpr(UseDefaultScheduler) - { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple<>, - tuple<>, - c_distr_ys_major, - sequence<0, 0>>{}; - - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); - auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); - return c_block_tensor; - } - else - { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - c_distr_ys_major, - sequence<0, 0>>{}; - - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); - auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); - return c_block_tensor; - } + return make_static_distributed_tensor( + make_static_tile_distribution(MakeCBlockDistributionEncode())); } // C = A * B diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp index 2a47d3b54d..a52c27e5a1 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp @@ -12,8 +12,8 @@ template // this variable is used for split K into multiple subtiles in - // order to reduce register usage per wave> + index_t KSubTileNum_ = 1, // this variable is used for split K into multiple subtiles in + bool PackMNIter_ = false> // order to reduce register usage per wave> struct BlockGemmARegBRegCRegV1CustomPolicy 
{ using AType = remove_cvref_t; @@ -29,6 +29,7 @@ struct BlockGemmARegBRegCRegV1CustomPolicy using WarpGemm = remove_cvref_t; static constexpr index_t KSubTileNum = KSubTileNum_; + static constexpr bool PackMNIter = PackMNIter_; template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp index 91ace17499..e02bf97b88 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp @@ -26,6 +26,8 @@ struct BlockGemmASmemBSmemCRegV1CustomPolicy static constexpr index_t kNWarps = BlockWarps::at(number<1>{}); static constexpr index_t kKWarps = BlockWarps::at(number<2>{}); + static constexpr bool PackMNIter = false; + using WarpGemm = remove_cvref_t; template diff --git a/include/ck_tile/ops/gemm_mx/block/block_mx_asmem_breg_creg.hpp b/include/ck_tile/ops/gemm/block/block_mx_asmem_breg_creg.hpp similarity index 100% rename from include/ck_tile/ops/gemm_mx/block/block_mx_asmem_breg_creg.hpp rename to include/ck_tile/ops/gemm/block/block_mx_asmem_breg_creg.hpp diff --git a/include/ck_tile/ops/gemm/kernel/mx_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/mx_gemm_kernel.hpp index bfc5673e11..d1b8d84b48 100644 --- a/include/ck_tile/ops/gemm/kernel/mx_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/mx_gemm_kernel.hpp @@ -3,6 +3,8 @@ #pragma once +#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" + namespace ck_tile { template @@ -72,6 +74,7 @@ struct MxGemmKernel using BaseKernel::PersistentKernel; using typename BaseKernel::AsLayout; using typename BaseKernel::BsLayout; + using typename BaseKernel::CLayout; using typename BaseKernel::DsLayout; using typename BaseKernel::ADataType; @@ -91,6 +94,8 @@ struct MxGemmKernel using BaseKernel::APackedSize; using 
BaseKernel::BPackedSize; + using BaseKernel::I1; + using AElementWise = remove_cvref_t; using BElementWise = remove_cvref_t; @@ -100,12 +105,48 @@ struct MxGemmKernel static constexpr int NThreadPerXdl = BlockGemmShape::WarpTile::at(number<1>{}); static constexpr int BlockScaleSize = MxGemmPipeline::ScaleBlockSize; - static_assert(BlockScaleSize == 16 || BlockScaleSize == 32, "unsupported BlockScaleSize"); - // Scale tensor element type is always int32_t (4 packed e8m0 bytes). - // For scale16, each thread needs 8 bytes = 2 int32_t elements. - // For scale32, each thread needs 4 bytes = 1 int32_t element. - static constexpr int ScalePackSize = 4; - using ScalePtrType = const int32_t*; + using ScalePtrType = const int32_t*; + // Padding flags pulled from pipeline so the kernel can pad the (unscaled) C and scale views + // consistently with the A/B views that the pipeline already pads via + // Underlying::MakeA/BBlockWindows. + static constexpr bool kPadM = MxGemmPipeline::kPadM; + static constexpr bool kPadN = MxGemmPipeline::kPadN; + static constexpr bool kPadK = MxGemmPipeline::kPadK; + + // ------------------------------------------------------------------ + // Compile-time padding-support invariants for the MX comp-async pipeline. + // + // - K padding is NOT supported: async_load_tile issues vector buffer reads whose + // OOB check is per-vector-start, so a vector that straddles the K pad boundary + // pulls in data from the adjacent row / next K tile rather than zero. The packed + // scale tile has the same vector-load property. Until the async path learns how + // to do per-element pad masking, we forbid kPadK at compile time. + // + // - kPadM / kPadN are supported only when the GEMM has at least one full block + // along that dimension; the CShuffleEpilogue's LDS shuffle uses thread positions + // that do not all participate when the entire dimension is smaller than a tile + // (resulting in zeros being written into in-range output rows). 
The "entire + // dimension < tile" case is rejected at runtime in IsSupportedArgument; we + // cannot statically catch it because M and N are runtime values. + // ------------------------------------------------------------------ + static_assert(!kPadK, + "MX GEMM (comp-async pipeline): K padding (kPadK = true) is not supported. " + "The async vector loads do not mask elements that straddle the K pad " + "boundary, so partial K tiles produce silently wrong results. Choose K so " + "that K is a multiple of KPerBlock * k_batch."); + + // Single source of truth for the split-K atomic-add precondition, shared by the runtime + // check in IsSupportedArgument and the atomic_add dispatch in operator(). Split-K + // accumulates each k_id's partial C tile with atomic_add; the CShuffle epilogue can only + // emit atomic_add for fp16/bf16 outputs when the C vector size is even. For an odd vector + // size that combination is not instantiated, so such a config cannot run split-K. For all + // shipped tile shapes GetVectorSizeC() is even, so this is defensive rather than reachable. 
+ static constexpr bool kSplitKAtomicAddSupported = + EpiloguePipeline::GetVectorSizeC() % 2 == 0 || !is_any_of::value; + + static constexpr index_t MXdlPackEff = MxGemmPipeline::MXdlPackEff; + static constexpr index_t NXdlPackEff = MxGemmPipeline::NXdlPackEff; + static constexpr index_t KXdlPackEff = MxGemmPipeline::KXdlPackEff; using KernelArgs = MxGemmKernelArgs; @@ -131,14 +172,57 @@ struct MxGemmKernel CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs) { - if(kargs.k_batch != 1) + const bool log = ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)); + + if(kargs.k_batch < 1) { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("SplitK (k_batch > 1) is not supported for MX GEMM!"); - } + if(log) + CK_TILE_ERROR("MX GEMM: k_batch must be >= 1."); return false; } + + // Split-K derives this k_id's logical K start from the row-major SplitKBatchOffset + // (as_k_split_offset[0]) to offset the packed-scale / flat-B windows; for column-major A + // that field is stride-scaled, so split-K with non-row-major A is not yet supported. + // (k_batch == 1 is unaffected -- the offset is 0 and unused.) When col-major A lands for + // non-preshuffle, extend the split-K K-offset here instead of this reject. + using ALayout = remove_cvref_t>; + if constexpr(!std::is_same_v) + { + if(kargs.k_batch > 1) + { + if(log) + CK_TILE_ERROR("MX GEMM: split-K (k_batch > 1) currently requires row-major A."); + return false; + } + } + + // Scales are granular in K: each packed int32_t covers BlockScaleSize * KXdlPackEff + // consecutive K elements. Every split-K boundary must land on that granularity so that + // each split can compute a packed-scale K offset. K1 is the WarpTile K, which is a + // multiple of that granularity for all shipped configs, but be defensive. + constexpr index_t scale_granularity_k = BlockScaleSize * KXdlPackEff; + if(kargs.k_batch > 1) + { + // splitk_batch_offset allocates K in units of K1 (warp-tile K). 
If K1 itself is + // not a multiple of the scale granularity, split-K is not safe. + constexpr index_t K1 = BlockGemmShape::WarpTile::at(number<2>{}); + static_assert(K1 % scale_granularity_k == 0, + "MX GEMM: WarpTile K must be a multiple of BlockScaleSize * KXdlPack " + "to support split-K."); + // Defensive runtime check: K must split evenly along K1 boundaries so that each + // k_id consumes a whole number of warp-tile K chunks (and therefore a whole + // number of packed-scale K elements). + if(kargs.K % (K1 * kargs.k_batch) != 0) + { + if(log) + CK_TILE_ERROR("MX GEMM: with k_batch > 1, K must be a multiple of WarpTile_K * " + "k_batch so that every split lands on a packed-scale boundary."); + return false; + } + } + + // Delegate the remaining shape/vector-size checks to the universal kernel. return BaseKernel::IsSupportedArgument(kargs); } @@ -146,10 +230,14 @@ struct MxGemmKernel CK_TILE_DEVICE static auto MakeScaleABlockWindow(const std::array& as_scale_ptr, const KernelArgs& kargs, - index_t block_idx_m) + index_t block_idx_m, + const index_t k_elem_offset = 0) { - const auto&& scale_packs_m = integer_divide_ceil(kargs.M, MThreadPerXdl); - const auto&& scale_packs_k = kargs.K / BlockScaleSize / ScalePackSize; + const auto&& scale_packs_m = integer_divide_ceil(kargs.M, MThreadPerXdl * MXdlPackEff); + const auto&& scale_packs_k = kargs.K / BlockScaleSize / KXdlPackEff; + + // For split-K (k_batch > 1) advance the scale origin into this k_id's packed-K slice. + const index_t k_scale_offset = k_elem_offset / BlockScaleSize / KXdlPackEff; // Scale16: descriptor order [packs_m, MThreadPerXdl, packs_k] -- K contiguous per M-row, // no pre-shuffle needed (natural row-major layout matches). 
@@ -184,14 +272,28 @@ struct MxGemmKernel return make_tensor_view(as_scale_ptr[i], scale_a_desc); }, number{}); + + // Pad the scale view so partial trailing tiles along M are handled safely (OOB scale + // loads return zero; with A also zero on the padded region the contribution is zero + // regardless of scale value). kPadK is statically disabled, so K never actually pads. + const auto& scale_a_pad_view = generate_tuple( + [&](auto i) { + return pad_tensor_view( + scale_a_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + }, + number{}); + const auto& scale_a_block_window = generate_tuple( [&](auto i) { return make_tile_window( - scale_a_tensor_view[i], + scale_a_pad_view[i], make_tuple( - number{}, - number{}), - {block_idx_m, 0}); + number{}, + number{}), + {block_idx_m / MXdlPackEff, k_scale_offset}); }, number{}); @@ -202,10 +304,14 @@ struct MxGemmKernel CK_TILE_DEVICE static auto MakeScaleBBlockWindow(const std::array& bs_scale_ptr, const KernelArgs& kargs, - index_t block_idx_n) + index_t block_idx_n, + const index_t k_elem_offset = 0) { - const auto&& scale_packs_n = integer_divide_ceil(kargs.N, NThreadPerXdl); - const auto&& scale_packs_k = kargs.K / BlockScaleSize / ScalePackSize; + const auto&& scale_packs_n = integer_divide_ceil(kargs.N, NThreadPerXdl * NXdlPackEff); + const auto&& scale_packs_k = kargs.K / BlockScaleSize / KXdlPackEff; + + // For split-K (k_batch > 1) advance the scale origin into this k_id's packed-K slice. + const index_t k_scale_offset = k_elem_offset / BlockScaleSize / KXdlPackEff; const auto scale_b_naive_desc = [&]() { if constexpr(BlockScaleSize == 16) @@ -236,33 +342,120 @@ struct MxGemmKernel return make_tensor_view(bs_scale_ptr[i], scale_b_desc); }, number{}); + + // Pad the scale view so partial trailing tiles along N are handled safely (OOB scale + // loads return zero; with B also zero on the padded region the contribution is zero + // regardless of scale value). 
kPadK is statically disabled, so K never actually pads. + const auto& scale_b_pad_view = generate_tuple( + [&](auto i) { + return pad_tensor_view( + scale_b_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + }, + number{}); + const auto& scale_b_block_window = generate_tuple( [&](auto i) { return make_tile_window( - scale_b_tensor_view[i], + scale_b_pad_view[i], make_tuple( - number{}, - number{}), - {block_idx_n, 0}); + number{}, + number{}), + {block_idx_n / NXdlPackEff, k_scale_offset}); }, number{}); return scale_b_block_window; } + CK_TILE_DEVICE static auto + MakeBFlatBlockWindows(const std::array& bs_ptr, + const KernelArgs& kargs, + const index_t i_n, + const index_t k_elem_offset = 0) + { + static_assert(NumBTensor == 1, "MX GEMM preshuffle currently supports one B tensor"); + + constexpr index_t kKPerBlock = MxGemmPipeline::kKPerBlock; + constexpr index_t kNWarpTile = BlockGemmShape::WarpTile::at(I1); + constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile; + const index_t kFlatKBlocks = kargs.K / kKPerBlock; + const index_t kFlatN = kargs.N / kNWarpTile; + + const index_t k_flat_offset = (k_elem_offset / kKPerBlock) * flatKPerBlock; + + auto b_flat_tensor_view = [&]() { + static_assert(flatKPerBlock % MxGemmPipeline::GetVectorSizeB() == 0, + "wrong! 
vector size for preshuffled B tensor"); + auto naive_desc = make_naive_tensor_descriptor_packed( + make_tuple(kFlatN, kFlatKBlocks, number{})); + auto desc = transform_tensor_descriptor( + naive_desc, + make_tuple(make_pass_through_transform(kFlatN), + make_merge_transform_v3_division_mod( + make_tuple(kFlatKBlocks, number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return make_tensor_view(bs_ptr[number<0>{}], desc); + }(); + + return generate_tuple( + [&](auto) { + return make_tile_window(b_flat_tensor_view, + make_tuple(number{}, + number{}), + {static_cast(i_n / BlockGemmShape::WarpTile::at(I1)), + static_cast(k_flat_offset)}); + }, + number{}); + } + + template CK_TILE_DEVICE static void RunGemm(const std::array& as_ptr, const std::array& bs_ptr, const std::array& ds_ptr, EDataType* e_ptr, void* smem_ptr, - const KernelArgs& kargs, + KernelArgs kargs, const SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, - const index_t block_idx_n) + const index_t block_idx_n, + const index_t k_elem_offset = 0) { std::array as_scale_ptr; - static_for<0, NumATensor, 1>{}([&](auto i) { - as_scale_ptr[i] = reinterpret_cast(kargs.as_scale_ptr[i]); - }); + std::array as_ptr_; + index_t block_idx_m_; + // Large tensor support (when M is large, N and K are relatively small) + using ALayout = remove_cvref_t>; + constexpr bool offset_ptrs_by_tile_coords = + std::is_same_v && + std::is_same_v && !BaseKernel::ClusterLaunch; + + if constexpr(offset_ptrs_by_tile_coords) + { + static_for<0, NumATensor, 1>{}([&](auto i) { + as_ptr_[i] = as_ptr[i] + static_cast(block_idx_m) * + kargs.stride_As[i] / APackedSize; + }); + e_ptr += static_cast(block_idx_m) * kargs.stride_E; + static_for<0, NumATensor, 1>{}([&](auto i) { + as_scale_ptr[i] = reinterpret_cast(kargs.as_scale_ptr[i]) + + static_cast(block_idx_m / MXdlPackEff) * + (kargs.K / BlockScaleSize / KXdlPackEff); + }); + + kargs.M = std::min(kargs.M - block_idx_m, 
TilePartitioner::MPerBlock); + block_idx_m_ = 0; + } + else + { + static_for<0, NumATensor, 1>{}([&](auto i) { + as_scale_ptr[i] = reinterpret_cast(kargs.as_scale_ptr[i]); + }); + static_for<0, NumATensor, 1>{}([&](auto i) { as_ptr_[i] = as_ptr[i]; }); + block_idx_m_ = block_idx_m; + } std::array bs_scale_ptr; static_for<0, NumBTensor, 1>{}([&](auto i) { @@ -272,18 +465,50 @@ struct MxGemmKernel // cluster launch pads grid to cluster boundaries; skip out-of-bound blocks if constexpr(BaseKernel::ClusterLaunch) { - if(block_idx_m >= kargs.M || block_idx_n >= kargs.N) + if(block_idx_m_ >= kargs.M || block_idx_n >= kargs.N) return; } - const auto& as_block_window = BaseKernel::MakeABlockWindows( - as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); - const auto& bs_block_window = BaseKernel::MakeBBlockWindows( - bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); + // The preshuffle A async-load (MakeMX_AAsyncLoadBytesDramWindow) rebuilds the A + // view with a packed descriptor, i.e. it assumes the leading (M) stride equals + // the view's K extent. That only holds when the extent equals stride_A, which is + // the case for k_batch == 1 (splitted_k == K) but NOT for split-K (splitted_k < K): + // a packed extent of splitted_k would stride M by splitted_k instead of stride_A + // and read the wrong rows (only row 0 lands correctly). Use the full K extent so + // the packed M stride matches stride_A. The as_ptr K-offset already selects this + // k_id's slice and num_loop bounds the blocks read, so reads stay within + // [as_k_split_offset, as_k_split_offset + splitted_k) <= K (in-allocation). 
+ const auto& as_block_window = [&]() { + if constexpr(MxGemmPipeline::Preshuffle) + { + return BaseKernel::MakeABlockWindows(as_ptr_, kargs, kargs.K, block_idx_m_); + } + else + { + return BaseKernel::MakeABlockWindows( + as_ptr_, kargs, splitk_batch_offset.splitted_k, block_idx_m_); + } + }(); + const auto& bs_block_window = [&]() { + if constexpr(MxGemmPipeline::Preshuffle) + { + return MakeBFlatBlockWindows(bs_ptr, kargs, block_idx_n, k_elem_offset); + } + else + { + return BaseKernel::MakeBBlockWindows( + bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); + } + }(); const auto& ds_block_window = - BaseKernel::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); - const auto& scale_a_block_window = MakeScaleABlockWindow(as_scale_ptr, kargs, block_idx_m); - const auto& scale_b_block_window = MakeScaleBBlockWindow(bs_scale_ptr, kargs, block_idx_n); + BaseKernel::MakeDBlockWindows(ds_ptr, kargs, block_idx_m_, block_idx_n); + + // Create scale block windows. For split-K (k_batch > 1), k_elem_offset advances the + // scale origin into the correct packed-K slice for this k_id; otherwise it is zero. + const auto& scale_a_block_window = + MakeScaleABlockWindow(as_scale_ptr, kargs, block_idx_m_, k_elem_offset); + const auto& scale_b_block_window = + MakeScaleBBlockWindow(bs_scale_ptr, kargs, block_idx_n, k_elem_offset); const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); @@ -297,10 +522,58 @@ struct MxGemmKernel num_loop, smem_ptr); - auto c_block_window = BaseKernel::template MakeCBlockWindows( - e_ptr, kargs, block_idx_m, block_idx_n); + // Dispatch epilogue: when k_batch > 1 each split accumulates a partial result into + // the same C tile, so we need atomic add (universal_gemm_kernel pattern). The + // fp16/bf16 even-vector-size precondition is captured once in kSplitKAtomicAddSupported + // and also rejected up front in IsSupportedArgument. 
+ // if(k_batch == 1) + auto c_block_window = BaseKernel::template MakeCBlockWindows( + e_ptr, kargs, block_idx_m_, block_idx_n); EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr); } + + CK_TILE_DEVICE static void RunGemm(const std::array& as_ptr, + const std::array& bs_ptr, + const std::array& ds_ptr, + EDataType* e_ptr, + void* smem_ptr, + const KernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n) + { + if(kargs.k_batch == 1) + { + RunGemm(as_ptr, + bs_ptr, + ds_ptr, + e_ptr, + smem_ptr, + kargs, + splitk_batch_offset, + block_idx_m, + block_idx_n); + } + else + { + // This k_id's logical K-element start. For row-major A, as_k_split_offset[0] is exactly + // that offset, so reuse it rather than recomputing the split formula; the packed-scale + // and flat-B K offsets are derived from it. Split-K with non-row-major A is rejected in + // IsSupportedArgument; for k_batch == 1 this value is 0 and unused for any layout. 
+ const index_t k_elem_offset = + amd_wave_read_first_lane(splitk_batch_offset.as_k_split_offset[number<0>{}]); + RunGemm(as_ptr, + bs_ptr, + ds_ptr, + e_ptr, + smem_ptr, + kargs, + splitk_batch_offset, + block_idx_m, + block_idx_n, + k_elem_offset); + } + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/mx_grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/mx_grouped_gemm_kernel.hpp index 63eda6b925..969de3c2d9 100644 --- a/include/ck_tile/ops/gemm/kernel/mx_grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/mx_grouped_gemm_kernel.hpp @@ -116,7 +116,7 @@ struct MxGroupedGemmKernel using P_ = GemmPipeline; return concat('_', "mx_gemm_grouped", gemm_prec_str(), concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock), - concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()), + concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB()), concat('x', P_::kPadM, P_::kPadN, P_::kPadK), (UsePersistentKernel ? "Persistent" : "NonPersistent"), (NumDTensor_ == 2 ? "MultiD" : "NoMultiD"), diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index f315c21cef..39718c3338 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -1280,8 +1280,18 @@ struct UniversalGemmKernel std::array bs_ptr; static_for<0, NumBTensor, 1>{}([&](auto i) { - bs_ptr[i] = static_cast(kargs.bs_ptr[i]) + - splitk_batch_offset.bs_k_split_offset[i] / BPackedSize; + if constexpr(GemmPipeline::Preshuffle) + { + // The preshuffle (flat-B) path applies the per-split K offset to the flat + // window origin when creating the window; bs_k_split_offset is derived from + // the logical B stride and would mis-offset the flat buffer. 
+ bs_ptr[i] = static_cast(kargs.bs_ptr[i]); + } + else + { + bs_ptr[i] = static_cast(kargs.bs_ptr[i]) + + splitk_batch_offset.bs_k_split_offset[i] / BPackedSize; + } }); // Calculate output offset from tile partitioner and apply to output pointer diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp index 01b94516e7..5247c3909e 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -15,9 +15,10 @@ namespace ck_tile { template struct BaseGemmPipelineAgBgCrCompAsync { - static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefetchStages = 3; static constexpr index_t PrefillStages = 1; static constexpr index_t GlobalBufferNum = 1; + static constexpr index_t UnrollHotLoop = 2; CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop) { @@ -30,13 +31,13 @@ struct BaseGemmPipelineAgBgCrCompAsync { return TailNumber::One; } - if(num_loop % PrefetchStages == 1) + if(num_loop % UnrollHotLoop == 0) { - return TailNumber::Three; + return TailNumber::Two; } else { - return TailNumber::Two; + return TailNumber::Three; } } @@ -130,10 +131,9 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync>::PackedSize; - using BlockGemm = remove_cvref_t())>; - using I0 = number<0>; - using I1 = number<1>; - using I2 = number<2>; + using I0 = number<0>; + using I1 = number<1>; + using I2 = number<2>; static constexpr bool LargeTensors = Problem::LargeTensors; @@ -176,6 +176,37 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}; static constexpr auto is_b_load_tr_v = bool_constant{}; + using BlockWarps = typename BlockGemmShape::BlockWarps; + using WarpTile = typename BlockGemmShape::WarpTile; + static constexpr index_t MWarp = BlockWarps::at(I0{}); + static constexpr index_t NWarp = 
BlockWarps::at(I1{}); + + // Compute effective XdlPack sizes (fall back to 1 when iter count < pack) + static constexpr index_t MPerXdl = WarpTile::at(I0{}); + static constexpr index_t NPerXdl = WarpTile::at(I1{}); + static constexpr index_t KPerXdl = WarpTile::at(I2{}); + static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl); + static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl); + static constexpr index_t KIterPerWarp = KPerBlock / KPerXdl; + + static constexpr index_t MXdlPackEff = + (MIterPerWarp >= Policy::MXdlPack && MIterPerWarp % Policy::MXdlPack == 0) + ? Policy::MXdlPack + : 1; + static constexpr index_t NXdlPackEff = + (NIterPerWarp >= Policy::NXdlPack && NIterPerWarp % Policy::NXdlPack == 0) + ? Policy::NXdlPack + : 1; + static constexpr index_t KXdlPackEff = + (KIterPerWarp >= Policy::KXdlPack && KIterPerWarp % Policy::KXdlPack == 0) + ? Policy::KXdlPack + : 1; + + static constexpr index_t ScaleBlockSize = 32; + + // Packed scale dimensions + static constexpr index_t ScaleKDimPerBlock = KPerBlock / ScaleBlockSize / KXdlPackEff; + [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() { // clang-format off @@ -246,6 +277,8 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync::value && is_detected::value, bool>* = nullptr> @@ -253,9 +286,16 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync && + !is_null_tile_window_v; + using BlockGemm = + remove_cvref_t())>; + // TODO support multi-ABD static_assert(1 == std::tuple_size_v); static_assert(1 == std::tuple_size_v); @@ -370,6 +410,23 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}], b_dram_tile_window_step); } + // Load scales for iteration 0 (ping) + load_scales_from_dram(scale_a_tile_ping, scale_b_tile_ping); + + // Load scales for iteration 1 (pong) if needed + if(num_loop > 1) + { + load_scales_from_dram(scale_a_tile_pong, scale_b_tile_pong); + } + if 
constexpr(HasHotLoop) { // we have had 3 global prefetches so far, indexed (0, 1, 2). @@ -482,8 +548,14 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}], b_dram_tile_window_step); // C(i-3) = A(i-3) @ B(i-3) - block_gemm(c_block_tile, a_block_tile0, b_block_tile0); + block_gemm(c_block_tile, + a_block_tile0, + b_block_tile0, + scale_a_tile_ping, + scale_b_tile_ping); HotLoopScheduler(); + // Load next scales after using current scales above + load_scales_from_dram(scale_a_tile_ping, scale_b_tile_ping); } // pong { @@ -503,8 +575,14 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}], b_dram_tile_window_step); // C(i-2) = A(i-2) @ B(i-2) - block_gemm(c_block_tile, a_block_tile1, b_block_tile1); + block_gemm(c_block_tile, + a_block_tile1, + b_block_tile1, + scale_a_tile_pong, + scale_b_tile_pong); HotLoopScheduler(); + // Load next scales after using current scales above + load_scales_from_dram(scale_a_tile_pong, scale_b_tile_pong); } i_global_read += 2; } while(i_global_read < num_loop); @@ -518,7 +596,13 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}, number<0>{}))); + + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + const ScaleADramBlockWindowTmp& scale_a_window, + const ScaleBDramBlockWindowTmp& scale_b_window, + index_t num_loop, + void* p_smem) const + { + // Scale tensor views and base origins for creating tile windows per iteration + const auto& scale_a_tensor_view = scale_a_window[number<0>{}].get_bottom_tensor_view(); + const auto& scale_b_tensor_view = scale_b_window[number<0>{}].get_bottom_tensor_view(); + auto scale_a_base_origin = scale_a_window[number<0>{}].get_window_origin(); + auto scale_b_base_origin = 
scale_b_window[number<0>{}].get_window_origin(); + + // Create scale windows with packed int32_t dimensions + auto scale_a_dram_window = make_tile_window( + scale_a_tensor_view, + make_tuple(number{}, number{}), + scale_a_base_origin, + Policy::template MakeMX_ScaleA_DramTileDistribution()); + + auto scale_b_dram_window = make_tile_window( + scale_b_tensor_view, + make_tuple(number{}, number{}), + scale_b_base_origin, + Policy::template MakeMX_ScaleB_DramTileDistribution()); + + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + scale_a_dram_window, + scale_b_dram_window, + num_loop, + p_smem); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); + } + template + static constexpr index_t AsyncVectorBytes = + sizeof(DataType) * KPack / numeric_traits>::PackedSize; + template static constexpr bool IsSupportedAsyncVectorWidth = - sizeof(DataType) * KPack == 4 || sizeof(DataType) * KPack == 12 || - sizeof(DataType) * KPack == 16; + AsyncVectorBytes == 4 || AsyncVectorBytes == 12 || + AsyncVectorBytes == 16; // XOR Swizzle: support FP8 / BF8 template @@ -57,10 +62,10 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy // Compute the number of LDS read accesses for A or B // IsLoadTr=true if ds_read_tr is used - template + template CK_TILE_HOST_DEVICE static constexpr auto CalculateWGAttrNumAccess() { - if constexpr(IsLoadTr) + if constexpr(IsLoadTr && !IsScale) { // Transpose-load path: ds_read_tr reads DS_READ_TR_SIZE bytes per instruction. 
constexpr index_t vector_size = @@ -91,32 +96,34 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy } } - template + template CK_TILE_HOST_DEVICE static constexpr auto GetAWGAttrNumAccess() { using WarpTile = typename Problem::BlockGemmShape::WarpTile; constexpr index_t thread_elements = WarpTile::at(I0) * WarpTile::at(I2) / get_warp_size(); return CalculateWGAttrNumAccess, typename Problem::ADataType, - thread_elements>(); + thread_elements, + IsScale>(); } - template + template CK_TILE_HOST_DEVICE static constexpr auto GetBWGAttrNumAccess() { using WarpTile = typename Problem::BlockGemmShape::WarpTile; constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size(); return CalculateWGAttrNumAccess, typename Problem::BDataType, - thread_elements>(); + thread_elements, + IsScale>(); } // Get number of accesses - template + template CK_TILE_HOST_DEVICE static constexpr auto GetWGAttrNumAccess() { - constexpr auto num_access_a = GetAWGAttrNumAccess(); - constexpr auto num_access_b = GetBWGAttrNumAccess(); + constexpr auto num_access_a = GetAWGAttrNumAccess(); + constexpr auto num_access_b = GetBWGAttrNumAccess(); if constexpr(num_access_a == WGAttrNumAccessEnum::Invalid || num_access_b == WGAttrNumAccessEnum::Invalid) @@ -127,6 +134,70 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy return num_access_b; } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeXorSwizzleABDramTileDistribution() + { + using BlockGemmShape = typename Problem::BlockGemmShape; + using BlockWarps = typename BlockGemmShape::BlockWarps; + using WarpTile = typename BlockGemmShape::WarpTile; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t KPerBlock = BlockGemmShape::kK; + constexpr index_t KWarps = BlockWarps::at(I2); + constexpr index_t K1 = WarpTile::at(I2) / K2; + constexpr index_t K0 = KPerBlock / (KWarps * K1 * K2); + + constexpr index_t warp_size = get_warp_size(); + constexpr index_t warp_num = BlockSize / warp_size; + + 
static_assert(KWarps == 1, "MX XOR swizzle currently supports KWarps == 1"); + static_assert(KWarps * K0 * K1 * K2 == KPerBlock, "Wrong!"); + + constexpr index_t M2 = warp_size / K1; + constexpr index_t M1 = warp_num / Problem::NumWaveGroups; + constexpr index_t M0 = MNPerBlock / (M1 * M2); + + static_assert(M0 * M1 * M2 == MNPerBlock, "Wrong!"); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 1>>, + sequence<1, 2, 2>, + sequence<0, 0, 2>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() + { + if constexpr(UseXorSwizzle) + { + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPack = Base::template GetSmemPackA(); + return MakeXorSwizzleABDramTileDistribution(); + } + else + { + return Base::template MakeADramTileDistribution(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() + { + if constexpr(UseXorSwizzle) + { + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPack = Base::template GetSmemPackB(); + return MakeXorSwizzleABDramTileDistribution(); + } + else + { + return Base::template MakeBDramTileDistribution(); + } + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution() + { + using BlockGemmShape = typename Problem::BlockGemmShape; + using BlockWarps = typename BlockGemmShape::BlockWarps; + using WarpTile = typename BlockGemmShape::WarpTile; + + constexpr index_t ScaleGranularityK = 32; + + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t MWarp = BlockWarps::at(number<0>{}); + constexpr index_t NWarp = BlockWarps::at(number<1>{}); + constexpr index_t MPerXdl = WarpTile::at(number<0>{}); + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K_Lane = get_warp_size() / MPerXdl; + constexpr index_t MIterPerWarp = MPerBlock / (MWarp 
* MPerXdl); + constexpr index_t KPerXdl = WarpTile::at(number<2>{}); + constexpr index_t KIterPerWarp = KPerBlock / KPerXdl; + constexpr index_t KPerLane = KPerXdl / ScaleGranularityK / K_Lane; + + // Effective pack sizes: fall back to 1 when iteration count < pack size + constexpr index_t MXdlPackEff = + (MIterPerWarp >= MXdlPack && MIterPerWarp % MXdlPack == 0) ? MXdlPack : 1; + constexpr index_t KXdlPackEff = + (KIterPerWarp >= KXdlPack && KIterPerWarp % KXdlPack == 0) ? KXdlPack : 1; + + constexpr index_t MIterPerWarp_packed = MIterPerWarp / MXdlPackEff; + constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 2>>, + sequence<2, 1, 2>, + sequence<0, 1, 2>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_DramTileDistribution() + { + using BlockGemmShape = typename Problem::BlockGemmShape; + using BlockWarps = typename BlockGemmShape::BlockWarps; + using WarpTile = typename BlockGemmShape::WarpTile; + + constexpr index_t ScaleGranularityK = 32; + + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t MWarp = BlockWarps::at(number<0>{}); + constexpr index_t NWarp = BlockWarps::at(number<1>{}); + constexpr index_t NPerXdl = WarpTile::at(number<1>{}); + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t K_Lane = get_warp_size() / NPerXdl; + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl); + + constexpr index_t KPerXdl = WarpTile::at(number<2>{}); + constexpr index_t KIterPerWarp = KPerBlock / KPerXdl; + constexpr index_t KPerLane = KPerXdl / ScaleGranularityK / K_Lane; + + // Effective pack sizes: fall back to 1 when iteration count < pack size + constexpr index_t NXdlPackEff = + (NIterPerWarp >= NXdlPack && NIterPerWarp % NXdlPack == 0) ? 
NXdlPack : 1; + constexpr index_t KXdlPackEff = + (KIterPerWarp >= KXdlPack && KIterPerWarp % KXdlPack == 0) ? KXdlPack : 1; + + constexpr index_t NIterPerWarp_packed = NIterPerWarp / NXdlPackEff; + constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 2>>, + sequence<2, 1, 2>, + sequence<0, 1, 2>>{}); + } + template CK_TILE_DEVICE static constexpr auto GetEstimatedVgprCount() { @@ -479,13 +639,13 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy return number{}; } - template + template CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() { using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; - constexpr auto wg_attr_num_access = GetWGAttrNumAccess(); + constexpr auto wg_attr_num_access = GetWGAttrNumAccess(); constexpr auto pipeline_tune_params = GetPipelineSubTileNum(); constexpr index_t sub_tile_num = EnableSubTile ? 
pipeline_tune_params.value : 1; @@ -506,7 +666,8 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy typename Problem::CDataType, BlockWarps, WarpGemm, - sub_tile_num>; + sub_tile_num, + PackMNIter>; return BlockGemmARegBRegCRegV1{}; } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp index 7c4e42c700..b35fc317e9 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp @@ -45,9 +45,6 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp static constexpr index_t APackedSize = ck_tile::numeric_traits::PackedSize; static constexpr index_t BPackedSize = ck_tile::numeric_traits::PackedSize; - using BlockGemm = remove_cvref_t())>; - using WarpGemm = typename BlockGemm::WarpGemm; - static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; static constexpr auto I2 = number<2>{}; @@ -66,9 +63,9 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp static constexpr index_t kflatKPerWarp = BlockGemmShape::flatKPerWarp; - static constexpr index_t MIterPerWarp = MPerBlock / (MWarps * WarpGemm::kM); - static constexpr index_t NIterPerWarp = NPerBlock / (NWarps * WarpGemm::kN); - static constexpr index_t KIterPerWarp = KPerBlock / (KWarps * WarpGemm::kK); + static constexpr index_t MXdlPackEff = Policy::template GetMXdlPackEff(); + static constexpr index_t NXdlPackEff = Policy::template GetNXdlPackEff(); + static constexpr index_t KXdlPackEff = Policy::template GetKXdlPackEff(); static constexpr bool Async = true; @@ -97,6 +94,8 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp static constexpr auto Scheduler = Problem::Scheduler; + static constexpr index_t ScaleBlockSize = 32; + [[nodiscard]] CK_TILE_HOST 
static const std::string GetPipelineName() { // clang-format off @@ -123,8 +122,6 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp return Policy::template GetSmemSize(); } - static constexpr index_t MFMA_INST = MIterPerWarp * NIterPerWarp * KIterPerWarp; - template struct PipelineImpl : public PipelineImplBase { @@ -141,6 +138,8 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp typename BsDramBlockWindowTmp, typename AElementFunction, typename BElementFunction, + typename ScaleADramBlockWindowTmp, + typename ScaleBDramBlockWindowTmp, typename std::enable_if_t::value && !is_detected::value, bool>* = nullptr> @@ -148,9 +147,23 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp const AElementFunction& a_element_func, const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, + const ScaleADramBlockWindowTmp& scale_a_window, + const ScaleBDramBlockWindowTmp& scale_b_window, index_t num_loop, void* __restrict__ p_smem) const { + constexpr bool IsScaledGemm = !is_null_tile_window_v && + !is_null_tile_window_v; + using BlockGemm = + remove_cvref_t())>; + using WarpGemm = typename BlockGemm::WarpGemm; + + constexpr index_t MIterPerWarp = MPerBlock / (MWarps * WarpGemm::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarps * WarpGemm::kN); + constexpr index_t KIterPerWarp = KPerBlock / (KWarps * WarpGemm::kK); + + constexpr index_t MFMA_INST = MIterPerWarp * NIterPerWarp * KIterPerWarp; + // TODO: A/B elementwise functions currently not supported ignore = a_element_func; ignore = b_element_func; @@ -183,12 +196,10 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp // Hot loop scheduler // ------------------ auto hot_loop_scheduler = [&]() { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - 
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, MIterPerWarp, 0); // MFMA s_waitcnt_lgkm<4>(); __builtin_amdgcn_sched_group_barrier(0x004, 1, 0); // lgkmcnt / SALU - static_for<0, MFMA_INST - 3, 1>{}([&](auto) { + static_for<0, MFMA_INST - MIterPerWarp, 1>{}([&](auto) { __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA }); __builtin_amdgcn_sched_barrier(0); @@ -201,10 +212,57 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp num_loop, a_dram_block_window_tmp, b_dram_block_window_tmp, + scale_a_window, + scale_b_window, hot_loop_scheduler); } }; + template < + typename AsDramBlockWindowTmp, + typename BsDramBlockWindowTmp, + typename AElementFunction, + typename BElementFunction, + typename ScaleADramBlockWindowTmp, + typename ScaleBDramBlockWindowTmp, + typename std::enable_if_t::value && + is_detected::value && + is_detected::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + const ScaleADramBlockWindowTmp& scale_a_window, + const ScaleBDramBlockWindowTmp& scale_b_window, + index_t num_loop, + void* p_smem) const + { + // TODO: A/B windows are tuple of windows, but the implementation doesn't take that into + // account yet and just the first element is passed + static_assert(AsDramBlockWindowTmp::size() == 1); + static_assert(BsDramBlockWindowTmp::size() == 1); + static_assert(ScaleADramBlockWindowTmp::size() == 1); + static_assert(ScaleBDramBlockWindowTmp::size() == 1); + + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp[I0], + 
a_element_func, + b_dram_block_window_tmp[I0], + b_element_func, + scale_a_window[I0], + scale_b_window[I0], + num_loop, + p_smem); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); + } + template {}, number<0>{}))); // TODO: A/B windows are tuple of windows, but the implementation doesn't take that into // account yet and just the first element is passed static_assert(AsDramBlockWindowTmp::size() == 1); @@ -231,6 +291,8 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp a_element_func, b_dram_block_window_tmp[I0], b_element_func, + NullTileWindowType{}, + NullTileWindowType{}, num_loop, p_smem); }; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp index 56709be910..4a6e1ab705 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp @@ -389,30 +389,85 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy static constexpr auto ATileAccessPattern = tile_distribution_pattern::warp_raked; static constexpr auto BTileAccessPattern = tile_distribution_pattern::warp_raked; - CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + // Scale part + static constexpr int BlockScaleSize = 32; + + // XdlPack: how many e8m0_t scale values are packed into one int32_t per dimension + // Host packs MXdlPack * KXdlPack e8m0_t into one int32_t for A scales + // Host packs NXdlPack * KXdlPack e8m0_t into one int32_t for B scales + static constexpr int MXdlPack = 2; + static constexpr int NXdlPack = 2; + static constexpr int KXdlPack = 2; + + // Compute effective XdlPack sizes (fall back to 1 when iter count < pack) + static constexpr index_t KPerXdl = WarpTile::at(I2); + static constexpr index_t KIterPerWarp = KPerBlock / KPerXdl; + + static 
constexpr index_t MXdlPackEff = + (MIterPerWarp >= MXdlPack && MIterPerWarp % MXdlPack == 0) ? MXdlPack : 1; + static constexpr index_t NXdlPackEff = + (NIterPerWarp >= NXdlPack && NIterPerWarp % NXdlPack == 0) ? NXdlPack : 1; + static constexpr index_t KXdlPackEff = + (KIterPerWarp >= KXdlPack && KIterPerWarp % KXdlPack == 0) ? KXdlPack : 1; + + static constexpr index_t KPerBlockScale = KPerBlock / BlockScaleSize / KXdlPackEff; + + CK_TILE_HOST_DEVICE static constexpr auto GetMXdlPackEff() { return MXdlPackEff; } + CK_TILE_HOST_DEVICE static constexpr auto GetNXdlPackEff() { return NXdlPackEff; } + CK_TILE_HOST_DEVICE static constexpr auto GetKXdlPackEff() { return KXdlPackEff; } + + CK_TILE_HOST_DEVICE static constexpr auto GetKStepAQ() { return KPerBlockScale; } + CK_TILE_HOST_DEVICE static constexpr auto GetKStepBQ() { return KPerBlockScale; } + + CK_TILE_HOST_DEVICE static constexpr auto GetInstCountAQ() { - // TODO: Fix for transpose - constexpr auto wg_attr_num_access = WGAccess; + return (MIterPerWarp / MXdlPackEff) * (KIterPerWarp / KXdlPackEff); + } - using WarpGemm = WarpGemmDispatcher; + CK_TILE_HOST_DEVICE static constexpr auto GetInstCountBQ() + { + return (NIterPerWarp / NXdlPackEff) * (KIterPerWarp / KXdlPackEff); + } - using BlockGemmPolicy = - BlockGemmARegBRegCRegV1CustomPolicy; + CK_TILE_HOST_DEVICE static constexpr auto MakeAQBlockDistribution() + { + constexpr index_t K_Lane = get_warp_size() / WarpTileM; - return BlockGemmARegBRegCRegEightWavesV1{}; + constexpr index_t KPerLane = WarpTileK / BlockScaleSize / K_Lane; + + constexpr index_t MIterPerWarp_packed = MIterPerWarp / MXdlPackEff; + constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, // repeat over MWarps + tuple, // M dimension (first) + sequence>, // K dimension (second) + tuple, sequence<2, 1>>, // , + tuple, sequence<1, 2>>, + sequence<2, 1, 2>, // + sequence<0, 1, 2>>{}); + } + + 
CK_TILE_HOST_DEVICE static constexpr auto MakeBQBlockDistribution() + { + constexpr index_t K_Lane = get_warp_size() / WarpTileN; + + constexpr index_t KPerLane = WarpTileK / BlockScaleSize / K_Lane; + + constexpr index_t NIterPerWarp_packed = NIterPerWarp / NXdlPackEff; + constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, // repeat over MWarps + tuple, // N dimension + // (first) + sequence>, // K dimension (second) + tuple, sequence<2, 1>>, // , + tuple, sequence<1, 3>>, + sequence<2, 1, 2>, // + sequence<0, 1, 2>>{}); } }; } // namespace detail @@ -447,8 +502,61 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy FORWARD_METHOD_(GetSmemPackA); FORWARD_METHOD_(GetSmemPackB); FORWARD_METHOD_(IsPreshuffle); + // Scale part + FORWARD_METHOD_(MakeAQBlockDistribution); + FORWARD_METHOD_(MakeBQBlockDistribution); + FORWARD_METHOD_(GetKStepAQ); + FORWARD_METHOD_(GetKStepBQ); + FORWARD_METHOD_(GetInstCountAQ); + FORWARD_METHOD_(GetInstCountBQ); + FORWARD_METHOD_(GetMXdlPackEff); + FORWARD_METHOD_(GetNXdlPackEff); + FORWARD_METHOD_(GetKXdlPackEff); #undef FORWARD_METHOD_ + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using BlockGemmShape = typename Problem::BlockGemmShape; + using BlockWarps = typename BlockGemmShape::BlockWarps; + using WarpTile = typename BlockGemmShape::WarpTile; + + using AComputeDataType = remove_cvref_t; + using BComputeDataType = remove_cvref_t; + static_assert(std::is_same_v); + using ComputeDataType = AComputeDataType; + + constexpr auto WGAccess = + std::is_same_v || std::is_same_v + ? 
WGAttrNumAccessEnum::Double + : WGAttrNumAccessEnum::Single; + + // TODO: Fix for transpose + constexpr auto wg_attr_num_access = WGAccess; + + using WarpGemm = WarpGemmDispatcher{}), + WarpTile::at(number<1>{}), + WarpTile::at(number<2>{}), + Problem::TransposeC, + false, + false, + wg_attr_num_access>; + + using BlockGemmPolicy = + BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegEightWavesV1{}; + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_tdm_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_tdm_v1.hpp index a88319927b..d50414eb75 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_tdm_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_tdm_v1.hpp @@ -138,6 +138,10 @@ struct GemmPipelineAgBgCrCompTDMV1 : public BaseGemmPipelineAgBgCrCompTDM(); // for these three functions, we always return 1 since TDM handles vectorization internally diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_eight_waves_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_eight_waves_base.hpp index 823c4eef32..49a8f84078 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_eight_waves_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_eight_waves_base.hpp @@ -17,9 +17,6 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase< using BlockGemmShape = remove_cvref_t; - using BlockGemm = remove_cvref_t())>; - using WarpGemm = typename BlockGemm::WarpGemm; - static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; static constexpr auto I2 = number<2>{}; @@ -42,10 +39,6 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase< static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp; static constexpr index_t WarpTileN = BlockGemmShape::WarpTile::at(I1); - static constexpr index_t 
MIterPerWarp = MPerBlock / (MWarps * WarpGemm::kM); - static constexpr index_t NIterPerWarp = NPerBlock / (NWarps * WarpGemm::kN); - static constexpr index_t KIterPerWarp = KPerBlock / (KWarps * WarpGemm::kK); - // Rely on the policy. In this way it works for both GEMM and blockscale static constexpr bool Preshuffle = Policy::template IsPreshuffle(); @@ -72,23 +65,30 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase< lds_tile_window.load(dst_block_tile, number<-1>{}, true_type{}, static_move_ys{}); } - template + template CK_TILE_DEVICE void LocalPrefetchB(DataType* smem, DstBlockTile& dst_block_tile, - SrcTileWindow& lds_tile_window) const + SrcTileWindow& lds_tile_window, + number = {}, + number = {}) const { + constexpr index_t NIterPerWarp = NPerBlock / (NWarps * NPerXdl); + constexpr index_t KIterPerWarp = KPerBlock / (KWarps * KPerXdl); // swizzle factor limitation using static_move_ys = std::conditional_t, false_type, true_type>; lds_tile_window.set_bottom_tensor_view_data_ptr(smem); static_for_product, number>{}( [&](auto nIter, auto kIter) { - lds_tile_window.load_with_offset( - number_tuple{}, - dst_block_tile[nIter][kIter], - number<-1>{}, - true_type{}, - static_move_ys{}); + lds_tile_window.load_with_offset(number_tuple{}, + dst_block_tile[nIter][kIter], + number<-1>{}, + true_type{}, + static_move_ys{}); }); } @@ -290,6 +290,12 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase< const BQDramBlockWindowTmp& bq_dram_block_window_tmp, SchedulerFunc&& scheduler_func) const { + constexpr bool IsScaledGemm = !is_null_tile_window_v && + !is_null_tile_window_v; + using BlockGemm = + remove_cvref_t())>; + using WarpGemm = typename BlockGemm::WarpGemm; + // Loop count constexpr index_t N_LOOP = HasHotLoop ? 4 : TailNum == TailNumber::One ? 
1 @@ -378,7 +384,11 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase< LocalPrefetchA(smem_a, a_block_tile, a_lds_gemm_window); BDataType* smem_b = reinterpret_cast(smem01[i] + lds_offset_b); - LocalPrefetchB(smem_b, b_block_tiles, b_lds_gemm_window); + LocalPrefetchB(smem_b, + b_block_tiles, + b_lds_gemm_window, + number{}, + number{}); }; auto calc_gemm = [&](index_t i) { @@ -418,7 +428,11 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase< GlobalPrefetchAsync(smem_b_tic, b_copy_lds_window, b_copy_dram_window); BDataType* smem_b_toc = reinterpret_cast(smem01[toc] + lds_offset_b); - LocalPrefetchB(smem_b_toc, b_block_tiles, b_lds_gemm_window); + LocalPrefetchB(smem_b_toc, + b_block_tiles, + b_lds_gemm_window, + number{}, + number{}); __builtin_amdgcn_sched_barrier(0); block_sync_lds_direct_load(); diff --git a/include/ck_tile/ops/gemm_mx/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/wp_mx_pipeline_agmem_bgmem_creg_v1.hpp similarity index 86% rename from include/ck_tile/ops/gemm_mx/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp rename to include/ck_tile/ops/gemm/pipeline/wp_mx_pipeline_agmem_bgmem_creg_v1.hpp index e9d80d73b7..a19a28ea95 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_mx_pipeline_agmem_bgmem_creg_v1.hpp @@ -8,16 +8,11 @@ #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" -#include "ck_tile/ops/gemm_mx/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_mx_pipeline_agmem_bgmem_creg_v1_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp" namespace ck_tile { -template -struct MXEpilogueTraits -{ - static constexpr index_t BlockedXDLNPerWarp = 
GemmConfig::Preshuffle ? 2 : 1; -}; - // This pipeline extends the existing universal GEMM machinery with preshuffled-B support. template struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 @@ -53,9 +48,9 @@ struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 static constexpr index_t BlockSize = Problem::kBlockSize; static constexpr index_t WaveSize = get_warp_size(); - static constexpr index_t kMPerBlock = BlockGemmShape::kM; - static constexpr index_t kNPerBlock = BlockGemmShape::kN; - static constexpr index_t kKPerBlock = BlockGemmShape::kK; + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp; static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp; @@ -69,12 +64,12 @@ struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 template static constexpr index_t GetVectorSizeA() { - return 32; + return PipelinePolicy::template GetVectorSizeA(); } template static constexpr index_t GetVectorSizeB() { - return 32; + return PipelinePolicy::template GetVectorSizeB(); } static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; } @@ -93,21 +88,26 @@ struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 static constexpr index_t MWarp = BlockGemm::MWarp; static constexpr index_t NWarp = BlockGemm::NWarp; - static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); - static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); - static constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; static constexpr index_t KFlatBytesPerBlockPerIter = flatKPerWarp * sizeof(BDataType) / BPackedSize; static constexpr index_t 
NFlatPerBlockPerIter = flatNPerWarp; - static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp; - static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp; + static constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; + static constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; - static constexpr index_t ScaleGranularityK = 32; - static constexpr index_t MXdlPack = 2; - static constexpr index_t NXdlPack = 2; - static constexpr index_t KXdlPack = 2; + static constexpr index_t ScaleBlockSize = 32; + static constexpr index_t MXdlPack = 2; + static constexpr index_t NXdlPack = 2; + static constexpr index_t KXdlPack = 2; + + // Preshuffle only supports this case as checked by static asserts + static constexpr index_t MXdlPackEff = MXdlPack; + static constexpr index_t NXdlPackEff = NXdlPack; + static constexpr index_t KXdlPackEff = KXdlPack; static constexpr index_t AK1 = 16 * APackedSize / sizeof(ADataType); static constexpr index_t BK1 = 16 * BPackedSize / sizeof(BDataType); @@ -125,12 +125,12 @@ struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 static constexpr index_t Aload_num_perK = dswrite_num_perK; static constexpr index_t Aload_rep = dswrite_rep; - static constexpr index_t Bload_num_perK = kNPerBlock * WarpGemm::kK / NWarp / BK1 / WaveSize; + static constexpr index_t Bload_num_perK = NPerBlock * WarpGemm::kK / NWarp / BK1 / WaveSize; static constexpr index_t Bload_num = Bload_num_perK * KIterPerWarp; static constexpr index_t ScaleBload_num = - kNPerBlock * kKPerBlock / NWarp / ScaleGranularityK / NXdlPack / KXdlPack / WaveSize; + NPerBlock * KPerBlock / NWarp / ScaleBlockSize / NXdlPack / KXdlPack / WaveSize; static constexpr index_t ScaleAload_num = - kMPerBlock * kKPerBlock / MWarp / ScaleGranularityK / MXdlPack / KXdlPack / WaveSize; + MPerBlock * KPerBlock / MWarp / ScaleBlockSize / MXdlPack / KXdlPack / WaveSize; static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2; static constexpr index_t Bload_rep = 
(Bload_num_perK + HalfMIter - 1) / HalfMIter; @@ -181,9 +181,9 @@ struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 std::is_same_v>, "wrong!"); - static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}], + static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!"); - static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + static_assert(KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); static_assert(MWarp == 1); @@ -194,7 +194,7 @@ struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 a_copy_dram_window_tmp); using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; constexpr ADramTileWindowStep a_dram_tile_window_step = - make_array(index_t{0}, index_t{kKPerBlock * sizeof(ADataType) / APackedSize}); + make_array(index_t{0}, index_t{KPerBlock * sizeof(ADataType) / APackedSize}); __builtin_amdgcn_sched_barrier(0); @@ -208,13 +208,13 @@ struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 auto a_store_lds_window_ping = make_tile_window(a_lds_block_ping, - make_tuple(number{}, - number{}), + make_tuple(number{}, + number{}), {0, 0}); auto a_store_lds_window_pong = make_tile_window(a_lds_block_pong, - make_tuple(number{}, - number{}), + make_tuple(number{}, + number{}), {0, 0}); auto a_warp_window_ping = make_tile_window( @@ -306,7 +306,7 @@ struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k); }); }); - move_tile_window(scale_a_dram_window, {0, kKPerBlock / (ScaleGranularityK * KXdlPack)}); + move_tile_window(scale_a_dram_window, {0, KPerBlock / (ScaleBlockSize * KXdlPack)}); static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) { static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) { @@ -315,7 +315,7 @@ struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k); }); }); - move_tile_window(scale_b_dram_window, {0, kKPerBlock / 
(ScaleGranularityK * KXdlPack)}); + move_tile_window(scale_b_dram_window, {0, KPerBlock / (ScaleBlockSize * KXdlPack)}); __builtin_amdgcn_sched_barrier(0); if constexpr(HasHotLoop || TailNum == TailNumber::Even) @@ -375,10 +375,8 @@ struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 Base::GlobalPrefetchAsync( a_store_lds_window_ping, a_dram_window, a_dram_tile_window_step); - move_tile_window(scale_a_dram_window, - {0, kKPerBlock / (ScaleGranularityK * KXdlPack)}); - move_tile_window(scale_b_dram_window, - {0, kKPerBlock / (ScaleGranularityK * KXdlPack)}); + move_tile_window(scale_a_dram_window, {0, KPerBlock / (ScaleBlockSize * KXdlPack)}); + move_tile_window(scale_b_dram_window, {0, KPerBlock / (ScaleBlockSize * KXdlPack)}); block_gemm.LocalPrefetch(a_load_windows_pong); HotLoopScheduler(); @@ -420,10 +418,8 @@ struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 Base::GlobalPrefetchAsync( a_store_lds_window_pong, a_dram_window, a_dram_tile_window_step); - move_tile_window(scale_a_dram_window, - {0, kKPerBlock / (ScaleGranularityK * KXdlPack)}); - move_tile_window(scale_b_dram_window, - {0, kKPerBlock / (ScaleGranularityK * KXdlPack)}); + move_tile_window(scale_a_dram_window, {0, KPerBlock / (ScaleBlockSize * KXdlPack)}); + move_tile_window(scale_b_dram_window, {0, KPerBlock / (ScaleBlockSize * KXdlPack)}); block_gemm.LocalPrefetch(a_load_windows_ping); HotLoopScheduler(); @@ -707,6 +703,43 @@ struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 } } + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_copy_dram_window_tmp, + const AElementFunction&, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + const BElementFunction&, + const ScaleADramBlockWindowTmp& scale_a_window, + const ScaleBDramBlockWindowTmp& scale_b_window, + index_t num_loop, + void* __restrict__ p_smem) const + { + static_assert(std::is_same_v); + static_assert(std::is_same_v); + + constexpr index_t smem_size = PipelinePolicy::template GetSmemSize(); + const auto smem = 
reinterpret_cast(p_smem); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_num = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_copy_dram_window_tmp[number<0>{}], + b_flat_dram_block_window_tmp[number<0>{}], + scale_a_window[number<0>{}], + scale_b_window[number<0>{}], + num_loop, + smem, + smem + smem_size); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_num); + } + template (); + const auto smem = reinterpret_cast(p_smem); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_num = Base::GetBlockLoopTailNum(num_loop); const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { return PipelineImpl{}.template operator()( @@ -729,8 +763,8 @@ struct MXGemmPreshufflePipelineAGmemBGmemCRegV1 scale_a_window, scale_b_window, num_loop, - p_smem_ping, - p_smem_pong); + smem, + smem + smem_size); }; return Base::TailHandler(RunPipeline, has_hot_loop, tail_num); diff --git a/include/ck_tile/ops/gemm_mx/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_mx_pipeline_agmem_bgmem_creg_v1_policy.hpp similarity index 98% rename from include/ck_tile/ops/gemm_mx/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp rename to include/ck_tile/ops/gemm/pipeline/wp_mx_pipeline_agmem_bgmem_creg_v1_policy.hpp index 04fac8f67a..8ac832f158 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_mx_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -4,7 +4,7 @@ #pragma once #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" -#include "ck_tile/ops/gemm_mx/block/block_mx_asmem_breg_creg.hpp" +#include "ck_tile/ops/gemm/block/block_mx_asmem_breg_creg.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" namespace ck_tile { @@ -57,6 +57,9 @@ 
struct MXGemmPipelineAgBgCrPolicy : UniversalGemmPipelineAgBgCrPolicy static constexpr index_t AK1 = DWORDx4 * APackedSize; static constexpr index_t BK1 = DWORDx4 * BPackedSize; + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA() { return AK1; } + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB() { return BK1; } + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() { using WarpGemm = WarpGemmDispatcher -struct BlockMXGemmARegBRegCRegEightWavesV1 -{ - private: - template - struct GemmTraits_ - { - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using AComputeDataType = remove_cvref_t; - using BComputeDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - - static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr auto Scheduler = Problem::Scheduler; - - static constexpr index_t MPerBlock = BlockGemmShape::kM; - static constexpr index_t NPerBlock = BlockGemmShape::kN; - static constexpr index_t KPerBlock = BlockGemmShape::kK; - - static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); - using WarpGemm = remove_cvref_t())>; - - static constexpr index_t MWarp = config.template at<1>(); - static constexpr index_t NWarp = config.template at<2>(); - static constexpr index_t KWarp = Problem::BlockGemmShape::BlockWarps::at(number<2>{}); - - using I0 = number<0>; - using I1 = number<1>; - - static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}), - "Error! WarpGemm's MWarp is not consistent with BlockGemmShape!"); - static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}), - "Error! WarpGemm's NWarp is not consistent with BlockGemmShape!"); - static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}), - "Error! WarpGemm's M is not consistent with BlockGemmShape!"); - static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}), - "Error! 
WarpGemm's N is not consistent with BlockGemmShape!"); - - static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); - static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); - static constexpr index_t KIterPerWarp = KPerBlock / (KWarp * WarpGemm::kK); - - // Controls how many MAC clusters (MFMA blocks) we have per wave - // If InterWaveSchedulingMacClusters = 1; - // Then we group all WarpGemms into single MAC cluster. - // But if InterWaveSchedulingMacClusters = 2, then we - // split the warp gemms into two groups. - static constexpr index_t InterWaveSchedulingMacClusters = 1; - - static constexpr index_t KPackA = WarpGemm::kAKPack; - static constexpr index_t KPackB = WarpGemm::kBKPack; - static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread; - static constexpr bool TransposeC = Problem::TransposeC; - }; - - public: - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using Traits = GemmTraits_; - - using WarpGemm = typename Traits::WarpGemm; - using BlockGemmShape = typename Traits::BlockGemmShape; - - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using AComputeDataType = remove_cvref_t; - using BComputeDataType = remove_cvref_t; - - static constexpr index_t KIterPerWarp = Traits::KIterPerWarp; - static constexpr index_t MIterPerWarp = Traits::MIterPerWarp; - static constexpr index_t NIterPerWarp = Traits::NIterPerWarp; - - static constexpr index_t MWarp = Traits::MWarp; - static constexpr index_t NWarp = Traits::NWarp; - static constexpr index_t KWarp = Traits::KWarp; - - static constexpr auto Scheduler = Traits::Scheduler; - static constexpr bool TransposeC = Traits::TransposeC; - - using AWarpDstr = typename WarpGemm::AWarpDstr; - using BWarpDstr = typename WarpGemm::BWarpDstr; - using CWarpDstr = typename WarpGemm::CWarpDstr; - - using AWarpTensor = typename WarpGemm::AWarpTensor; - using BWarpTensor = typename 
WarpGemm::BWarpTensor; - using CWarpTensor = typename WarpGemm::CWarpTensor; - - static_assert(std::is_same_v); - - static constexpr auto a_warp_y_lengths = - to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - static constexpr auto b_warp_y_lengths = - to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - static constexpr auto c_warp_y_lengths = - to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - - static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; - static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; - static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - - static constexpr index_t APackedSize = - ck_tile::numeric_traits>::PackedSize; - static constexpr index_t BPackedSize = - ck_tile::numeric_traits>::PackedSize; - - using I0 = number<0>; - using I1 = number<1>; - - // Note: distribution encodings have MIterPerWarp and NIterPerWarp contiguous because of scale - // packing. - - CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() - { - constexpr index_t KPerThread = Traits::KPerThread; - constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters; - - constexpr index_t KPerInnerLoop = - ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); - - constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread; - - using KIterSeq = std::conditional_t, - sequence>; - - constexpr auto a_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, KIterSeq>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<1, 1>>{}; - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); - - return a_block_dstr_encode; - } - - CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() - { - constexpr index_t KPerThread = Traits::KPerThread; - constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters; - 
constexpr index_t KPerInnerLoop = - ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); - constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread; - - using KIterSeq = std::conditional_t, - sequence>; - - constexpr auto b_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, KIterSeq>, - tuple>, - tuple>, - sequence<>, - sequence<>>{}; - - constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - - return b_block_dstr_encode; - } - - CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode() - { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence<2, NIterPerWarp, NWarp / 2>>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<1, 1>>{}; - constexpr auto c_block_dstr_encoding = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - return c_block_dstr_encoding; - } - - CK_TILE_DEVICE static constexpr auto MakeCBlockTile() - { - return make_static_distributed_tensor( - make_static_tile_distribution(MakeCBlockDistributionEncode())); - } - - using ALdsTile = decltype(make_static_distributed_tensor( - make_static_tile_distribution(MakeABlockDistributionEncode()))); - using BLdsTiles = statically_indexed_array< - statically_indexed_array( - make_static_tile_distribution( - MakeBBlockDistributionEncode()))), - KIterPerWarp>, - NIterPerWarp>; - - // C += A * B - template - CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, - const ALdsTile& a_warp_tile_, - const BLdsTiles& b_warp_tiles_, - const ScaleATensor& scale_a_tensor, - const ScaleBTensor& scale_b_tensor) const - { - // checks - static_assert(std::is_same_v>, - "CDataType must be same as CBlockTensor::DataType!"); - static_assert( - std::is_same_v, - remove_cvref_t>, - "C distribution is wrong!"); - - // Effective XdlPack: fall back to 1 when 
iteration count is insufficient - constexpr index_t MXdlPack = - (MIterPerWarp >= MXdlPack_ && MIterPerWarp % MXdlPack_ == 0) ? MXdlPack_ : 1; - constexpr index_t NXdlPack = - (NIterPerWarp >= NXdlPack_ && NIterPerWarp % NXdlPack_ == 0) ? NXdlPack_ : 1; - constexpr index_t KXdlPack = - (KIterPerWarp >= KXdlPack_ && KIterPerWarp % KXdlPack_ == 0) ? KXdlPack_ : 1; - - constexpr index_t MPackIterPerWarp = MIterPerWarp / MXdlPack; - constexpr index_t NPackIterPerWarp = NIterPerWarp / NXdlPack; - constexpr index_t KPackIterPerWarp = KIterPerWarp / KXdlPack; - - // hot loop: - static_for_product, - number, - number>{}([&](auto ikpack, auto inpack, auto impack) { - // get A scale for this M-K tile using get_y_sliced_thread_data - auto scale_a_slice = scale_a_tensor.get_y_sliced_thread_data( - sequence{}, sequence<1, 1, 1>{}); - const int32_t a_scale_packed = bit_cast(scale_a_slice[number<0>{}]); - - // get B scale for this N-K tile using get_y_sliced_thread_data - auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data( - sequence{}, sequence<1, 1, 1>{}); - const int32_t b_scale_packed = bit_cast(scale_b_slice[number<0>{}]); - - // Inner loops: issue MFMAs within the pack group using OpSel - static_for_product, number, number>{}( - [&](auto ikxdl, auto inxdl, auto imxdl) { - constexpr auto kIter = ikpack * KXdlPack + ikxdl; - constexpr auto mIter = impack * MXdlPack + imxdl; - constexpr auto nIter = inpack * NXdlPack + inxdl; - - // OpSel for A: selects byte within packed int32_t - constexpr index_t kOpSelA = ikxdl * MXdlPack + imxdl; - - // OpSel for B: selects byte within packed int32_t - constexpr index_t kOpSelB = ikxdl * NXdlPack + inxdl; - - // read A warp tensor from A Block window - AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - - // read B warp tensor from B block tensor - BWarpTensor 
b_warp_tensor; - b_warp_tensor.get_thread_buffer() = - b_warp_tiles_[number{}][number{}].get_thread_buffer(); - - // read C warp tensor from C block tensor - using c_iter_idx = sequence; - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM with MX scaling - WarpGemm{}.template operator(), OpSelB>( - c_warp_tensor, - a_warp_tensor, - b_warp_tensor, - a_scale_packed, - b_scale_packed); - - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - }); - } -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_v1.hpp b/include/ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_v1.hpp deleted file mode 100644 index 7e190dc8e1..0000000000 --- a/include/ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_v1.hpp +++ /dev/null @@ -1,324 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp" - -namespace ck_tile { - -// A is block distributed tensor -// B is block distributed tensor -// C is block distributed tensor -template -struct BlockMXGemmARegBRegCRegV1 -{ - private: - template - struct GemmTraits_ - { - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - - static constexpr index_t kBlockSize = Problem::kBlockSize; - - static constexpr index_t MPerBlock = BlockGemmShape::kM; - static constexpr index_t NPerBlock = BlockGemmShape::kN; - static constexpr index_t KPerBlock = BlockGemmShape::kK; - - static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); - using WarpGemm = remove_cvref_t())>; - - static constexpr index_t MWarp = config.template at<1>(); - static constexpr index_t NWarp = config.template at<2>(); - static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); - static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); - static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; - - static constexpr index_t KPackA = WarpGemm::kAKPack; - static constexpr index_t KPackB = WarpGemm::kBKPack; - }; - - public: - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - static constexpr bool TransposeC = TransposeC_; - - using Traits = GemmTraits_; - - using WarpGemm = typename Traits::WarpGemm; - using BlockGemmShape = typename Traits::BlockGemmShape; - - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - - static constexpr index_t KIterPerWarp = Traits::KIterPerWarp; - static constexpr index_t MIterPerWarp = Traits::MIterPerWarp; - static constexpr index_t NIterPerWarp = Traits::NIterPerWarp; - - 
static constexpr index_t MWarp = Traits::MWarp; - static constexpr index_t NWarp = Traits::NWarp; - static constexpr bool UseDefaultScheduler = (Problem::NumWaveGroups != 1); - - // Note: distribution encodings have MIterPerWarp and NIterPerWarp contiguous because of scale - // packing. - - CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() - { - if constexpr(UseDefaultScheduler) - { - constexpr auto a_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple<>, - tuple<>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); - - return a_block_dstr_encode; - } - else - { - constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<1, 0>>{}; - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); - - return a_block_dstr_encode; - } - } - - CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() - { - if constexpr(UseDefaultScheduler) - { - constexpr auto b_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple<>, - tuple<>, - sequence<1, 2>, - sequence<0, 0>>{}; - constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - - return b_block_dstr_encode; - } - else - { - constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<1, 0>>{}; - constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - - return b_block_dstr_encode; - } - } - - CK_TILE_DEVICE 
static constexpr auto MakeCBlockDistributionEncode() - { - using c_distr_ys_major = std::conditional_t, sequence<1, 2>>; - if constexpr(UseDefaultScheduler) - { - using c_distr_ys_minor = std::conditional_t, sequence<0, 1>>; - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - c_distr_ys_major, - c_distr_ys_minor>{}; - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - - return c_block_dstr_encode; - } - else - { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - c_distr_ys_major, - sequence<1, 1>>{}; - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - - return c_block_dstr_encode; - } - } - - // C += A * B with MX scaling and packed-in-two (XdlPack) optimization - // Scale tensors contain pre-packed int32_t: each int32_t holds MXdlPack * KXdlPack e8m0_t - // values (for A) or NXdlPack * KXdlPack (for B), packed on the host. - // Uses OpSel (0-3) to select which byte within the packed int32_t for each MFMA call. - // XdlPack template parameters default to 2; fall back to 1 when iteration count is too small. 
- template - CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, - const ABlockTensor& a_block_tensor, - const BBlockTensor& b_block_tensor, - const ScaleATensor& scale_a_tensor, - const ScaleBTensor& scale_b_tensor) const - { - static_assert(std::is_same_v> && - std::is_same_v> && - std::is_same_v>, - "Datatypes do not match BlockTensor datatypes!"); - - // check ABC-block-distribution - static_assert( - std::is_same_v, - remove_cvref_t>, - "A distribution is wrong!"); - static_assert( - std::is_same_v, - remove_cvref_t>, - "B distribution is wrong!"); - static_assert( - std::is_same_v, - remove_cvref_t>, - "C distribution is wrong!"); - - using AWarpDstr = typename WarpGemm::AWarpDstr; - using BWarpDstr = typename WarpGemm::BWarpDstr; - using CWarpDstr = typename WarpGemm::CWarpDstr; - - using AWarpTensor = typename WarpGemm::AWarpTensor; - using BWarpTensor = typename WarpGemm::BWarpTensor; - using CWarpTensor = typename WarpGemm::CWarpTensor; - - constexpr auto a_warp_y_lengths = - to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto b_warp_y_lengths = - to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto c_warp_y_lengths = - to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - - constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; - constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; - constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - - // Effective XdlPack: fall back to 1 when iteration count is insufficient - constexpr index_t MXdlPack = - (MIterPerWarp >= MXdlPack_ && MIterPerWarp % MXdlPack_ == 0) ? MXdlPack_ : 1; - constexpr index_t NXdlPack = - (NIterPerWarp >= NXdlPack_ && NIterPerWarp % NXdlPack_ == 0) ? NXdlPack_ : 1; - constexpr index_t KXdlPack = - (KIterPerWarp >= KXdlPack_ && KIterPerWarp % KXdlPack_ == 0) ? 
KXdlPack_ : 1; - - constexpr index_t MPackIterPerWarp = MIterPerWarp / MXdlPack; - constexpr index_t NPackIterPerWarp = NIterPerWarp / NXdlPack; - constexpr index_t KPackIterPerWarp = KIterPerWarp / KXdlPack; - - // hot loop with MX scaling and pre-packed int32_t scales: - // Outer loops iterate over pack groups (scale tile indices) - static_ford>{}([&](auto ii) { - constexpr auto ikpack = number{}]>{}; - constexpr auto impack = number{}]>{}; - // Get pre-packed int32_t A scale (already contains MXdlPack*KXdlPack e8m0_t) - auto scale_a_slice = scale_a_tensor.get_y_sliced_thread_data( - sequence{}, sequence<1, 1, 1>{}); - const int32_t a_scale_packed = bit_cast(scale_a_slice[number<0>{}]); - - static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) { - // Get pre-packed int32_t B scale - auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data( - sequence{}, sequence<1, 1, 1>{}); - const int32_t b_scale_packed = bit_cast(scale_b_slice[number<0>{}]); - - // Inner loops: issue MFMAs within the pack group using OpSel - static_ford>{}([&](auto jj) { - constexpr auto ikxdl = number{}]>{}; - constexpr auto imxdl = number{}]>{}; - constexpr auto kIter = ikpack * KXdlPack + ikxdl; - constexpr auto mIter = impack * MXdlPack + imxdl; - - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - - // OpSel for A: selects byte within packed int32_t - constexpr index_t kOpSelA = ikxdl * MXdlPack + imxdl; - - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto nIter = inpack * NXdlPack + inxdl; - - // read B warp tensor from B block tensor - BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - - // OpSel 
for B: selects byte within packed int32_t - constexpr index_t kOpSelB = ikxdl * NXdlPack + inxdl; - - // read C warp tensor from C block tensor - using c_iter_idx = std::conditional_t, - sequence>; - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM with MX scaling using pre-packed scale and OpSel - WarpGemm{}.template operator(), OpSelB>( - c_warp_tensor, - a_warp_tensor, - b_warp_tensor, - a_scale_packed, - b_scale_packed); - - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - }); - }); - }); - } - - CK_TILE_DEVICE static constexpr auto MakeCBlockTile() - { - return make_static_distributed_tensor( - make_static_tile_distribution(MakeCBlockDistributionEncode())); - } -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp b/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp deleted file mode 100644 index edeb1a5214..0000000000 --- a/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp +++ /dev/null @@ -1,863 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#pragma once - -#include -#include - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/common.hpp" -#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" -#include "ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp" - -namespace ck_tile { - -template -struct MXGemmPipelineAgBgCrCompAsyncEightWaves; - -namespace detail { -template -struct MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy; - -template -struct MXGemmKernelScaleTraits -{ - static constexpr index_t ScaleGranularityK = Pipeline::ScaleGranularityK; - static constexpr index_t MXdlPack = Pipeline::MXdlPack; - static constexpr index_t NXdlPack = Pipeline::NXdlPack; - static constexpr index_t KXdlPack = Pipeline::KXdlPack; -}; - -template -struct MXGemmKernelScaleTraits> -{ - using PolicyTraits = MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy; - - static constexpr index_t ScaleGranularityK = PolicyTraits::BlockScaleSize; - static constexpr index_t MXdlPack = PolicyTraits::MXdlPack; - static constexpr index_t NXdlPack = PolicyTraits::NXdlPack; - static constexpr index_t KXdlPack = PolicyTraits::KXdlPack; -}; -} // namespace detail - -template , - typename ScaleN = MXScalePointer, - index_t NumATensor = 1, - index_t NumBTensor = 1, - index_t NumDTensor = 0> -struct MXGemmKernelArgs : UniversalGemmKernelArgs -{ - using Base = UniversalGemmKernelArgs; - - CK_TILE_HOST MXGemmKernelArgs(const std::array& as_ptr_, - const std::array& bs_ptr_, - const std::array& ds_ptr_, - void* e_ptr_, - index_t k_batch_, - index_t M_, - index_t N_, - index_t K_, - const std::array& stride_As_, - const std::array& stride_Bs_, - const std::array& stride_Ds_, - index_t stride_E_, - ScaleM scale_m_ptr_, - ScaleN scale_n_ptr_) - : Base{as_ptr_, - bs_ptr_, - ds_ptr_, - e_ptr_, - M_, - N_, - K_, - stride_As_, - stride_Bs_, - stride_Ds_, - stride_E_, - k_batch_}, - scale_m_ptr(scale_m_ptr_), - scale_n_ptr(scale_n_ptr_) - { - } - - ScaleM scale_m_ptr; - ScaleN scale_n_ptr; -}; - -template -struct 
MXGemmKernel : UniversalGemmKernel -{ - using Underlying = UniversalGemmKernel; - - using TilePartitioner = remove_cvref_t; - using MXGemmPipeline = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - using EpiloguePipeline = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using ELayout = remove_cvref_t; - using DsLayout = remove_cvref_t; - using DsDataType = remove_cvref_t; - static constexpr index_t KernelBlockSize = MXGemmPipeline::BlockSize; - static constexpr bool UsePersistentKernel = MXGemmPipeline::UsePersistentKernel; - - // Below type is actually accumulation data type - the output of block GEMM. - using EDataType = remove_cvref_t; - - static constexpr auto I0 = number<0>(); - static constexpr auto I1 = number<1>(); - static constexpr auto I2 = number<2>(); - static constexpr auto I3 = number<3>(); - static constexpr auto I4 = number<4>(); - static constexpr auto I5 = number<5>(); - - static constexpr index_t NumATensor = Underlying::AsDataType::size(); - static constexpr index_t NumBTensor = Underlying::BsDataType::size(); - static constexpr index_t NumDTensor = Underlying::DsDataType::size(); - - using ADataType = remove_cvref_t>; - using BDataType = remove_cvref_t>; - - static constexpr auto MThreadPerXdl = BlockGemmShape::WarpTile::at(number<0>{}); - static constexpr auto NThreadPerXdl = BlockGemmShape::WarpTile::at(number<1>{}); - static constexpr auto KThreadPerXdl = 64 / MThreadPerXdl; - - static constexpr auto APackedSize = numeric_traits::PackedSize; - static constexpr auto BPackedSize = numeric_traits::PackedSize; - - // XdlPack: desired packing of e8m0_t scale values into int32_t - using ScaleTraits = detail::MXGemmKernelScaleTraits; - static constexpr index_t ScaleGranularityK = ScaleTraits::ScaleGranularityK; - static constexpr index_t MXdlPack = ScaleTraits::MXdlPack; - static constexpr index_t NXdlPack = ScaleTraits::NXdlPack; - static constexpr index_t KXdlPack = ScaleTraits::KXdlPack; - - // 
Effective pack sizes: fall back to 1 when dimension is too small - using BlockWarps_ = typename BlockGemmShape::BlockWarps; - static constexpr index_t MPerBlock_ = BlockGemmShape::kM; - static constexpr index_t NPerBlock_ = BlockGemmShape::kN; - static constexpr index_t KPerBlock_ = BlockGemmShape::kK; - static constexpr index_t MWarp_ = BlockWarps_::at(number<0>{}); - static constexpr index_t NWarp_ = BlockWarps_::at(number<1>{}); - static constexpr index_t KPerXdl_ = BlockGemmShape::WarpTile::at(number<2>{}); - static constexpr index_t MIterPerWarp_ = MPerBlock_ / (MWarp_ * MThreadPerXdl); - static constexpr index_t NIterPerWarp_ = NPerBlock_ / (NWarp_ * NThreadPerXdl); - static constexpr index_t KIterPerWarp_ = KPerBlock_ / KPerXdl_; - - static constexpr index_t MXdlPackEff = - (MIterPerWarp_ >= MXdlPack && MIterPerWarp_ % MXdlPack == 0) ? MXdlPack : 1; - static constexpr index_t NXdlPackEff = - (NIterPerWarp_ >= NXdlPack && NIterPerWarp_ % NXdlPack == 0) ? NXdlPack : 1; - static constexpr index_t KXdlPackEff = - (KIterPerWarp_ >= KXdlPack && KIterPerWarp_ % KXdlPack == 0) ? KXdlPack : 1; - - static constexpr int kBlockPerCu = 1; - - // Scale block size (same constant used by MXGemmPipeline): each e8m0 scale covers 32 K elements - static constexpr index_t ScaleBlockSize = 32; - - // Padding flags pulled from pipeline so the kernel can pad the (unscaled) C and scale views - // consistently with the A/B views that the pipeline already pads via - // Underlying::MakeA/BBlockWindows. - static constexpr bool kPadM = MXGemmPipeline::kPadM; - static constexpr bool kPadN = MXGemmPipeline::kPadN; - static constexpr bool kPadK = MXGemmPipeline::kPadK; - - static_assert(DsLayout::size() == DsDataType::size(), - "The size of DsLayout and DsDataType should be the same"); - - // ------------------------------------------------------------------ - // Compile-time padding-support invariants for the MX comp-async pipeline. 
- // - // - K padding is NOT supported: async_load_tile issues vector buffer reads whose - // OOB check is per-vector-start, so a vector that straddles the K pad boundary - // pulls in data from the adjacent row / next K tile rather than zero. The packed - // scale tile has the same vector-load property. Until the async path learns how - // to do per-element pad masking, we forbid kPadK at compile time. - // - // - kPadM / kPadN are supported only when the GEMM has at least one full block - // along that dimension; the CShuffleEpilogue's LDS shuffle uses thread positions - // that do not all participate when the entire dimension is smaller than a tile - // (resulting in zeros being written into in-range output rows). The "entire - // dimension < tile" case is rejected at runtime in IsSupportedArgument; we - // cannot statically catch it because M and N are runtime values. - // ------------------------------------------------------------------ - static_assert(!kPadK, - "MX GEMM (comp-async pipeline): K padding (kPadK = true) is not supported. " - "The async vector loads do not mask elements that straddle the K pad " - "boundary, so partial K tiles produce silently wrong results. Choose K so " - "that K is a multiple of KPerBlock * k_batch."); - - // Single source of truth for the split-K atomic-add precondition, shared by the runtime - // check in IsSupportedArgument and the atomic_add dispatch in operator(). Split-K - // accumulates each k_id's partial C tile with atomic_add; the CShuffle epilogue can only - // emit atomic_add for fp16/bf16 outputs when the C vector size is even. For an odd vector - // size that combination is not instantiated, so such a config cannot run split-K. For all - // shipped tile shapes GetVectorSizeC() is even, so this is defensive rather than reachable. 
- static constexpr bool kSplitKAtomicAddSupported = - EpiloguePipeline::GetVectorSizeC() % 2 == 0 || !is_any_of::value; - - [[nodiscard]] CK_TILE_HOST static const std::string GetName() - { - // clang-format off - return concat('_', "mx_gemm", gemm_prec_str, MXGemmPipeline::GetName()); - // clang-format on - } - - template - using KernelArgs = MXGemmKernelArgs; - - template - CK_TILE_HOST static auto MakeKernelArgs(const std::array& as_ptr, - const std::array& bs_ptr, - const std::array& ds_ptr, - void* e_ptr, - index_t k_batch, - index_t M, - index_t N, - index_t K, - const std::array& stride_As, - const std::array& stride_Bs, - const std::array& stride_Ds, - index_t stride_E, - ScaleM scale_m_ptr, - ScaleN scale_n_ptr) - { - return KernelArgs(as_ptr, - bs_ptr, - ds_ptr, - e_ptr, - k_batch, - M, - N, - K, - stride_As, - stride_Bs, - stride_Ds, - stride_E, - scale_m_ptr, - scale_n_ptr); - } - - template - CK_TILE_HOST static constexpr auto GridSize(const KernelArgs& kargs) - { - const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N); - - if constexpr(UsePersistentKernel) - { - hipDeviceProp_t prop; - int deviceId = 0; // default device - - int dync_smem_size = 0; - int maxActiveBlocksPerCU = 0; - - if(hipGetDeviceProperties(&prop, deviceId) != hipSuccess) - throw std::runtime_error(std::string("hipGetDeviceProperties failed: ") + - hipGetErrorName(hipGetLastError())); - - if(hipOccupancyMaxActiveBlocksPerMultiprocessor( - &maxActiveBlocksPerCU, - reinterpret_cast( - kentry<1, MXGemmKernel, remove_cvref_t>), - KernelBlockSize, - dync_smem_size) != hipSuccess) - throw std::runtime_error( - std::string("hipOccupancyMaxActiveBlocksPerMultiprocessor failed: ") + - hipGetErrorName(hipGetLastError())); - - const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU; - const int actual_grid_size = min(persistent_block_size, total_work_tile_cnt); - - // blockIdx.z selects the K split. 
For split-K, each k_id gets its own set of - // persistent blocks looping over the MxN tile space. - return dim3(actual_grid_size, 1, kargs.k_batch); - } - else - { - // Non-persistent: grid is (MxN tiles) x 1 x k_batch. blockIdx.z selects the K split. - return dim3(total_work_tile_cnt, 1, kargs.k_batch); - } - } - - template - CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs) - { - // Reject unsupported combinations early; the MX pipeline silently produces wrong - // results otherwise (OOB reads, partial-tile shuffle artifacts, mis-aligned splits). - // See the static_assert block at the top of MXGemmKernel for the rationale behind - // each constraint. - const bool log = ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)); - - if(kargs.k_batch < 1) - { - if(log) - CK_TILE_ERROR("MX GEMM: k_batch must be >= 1."); - return false; - } - - // Split-K needs the atomic_add epilogue; reject configs that cannot emit it (fp16/bf16 - // output with an odd C vector size) instead of silently skipping the accumulation. - if constexpr(!kSplitKAtomicAddSupported) - { - if(kargs.k_batch > 1) - { - if(log) - CK_TILE_ERROR("MX GEMM: split-K (k_batch > 1) requires an even C vector size " - "for fp16/bf16 outputs (atomic_add epilogue constraint)."); - return false; - } - } - - // Split-K derives this k_id's logical K start from the row-major SplitKBatchOffset - // (as_k_split_offset[0]) to offset the packed-scale / flat-B windows; for column-major A - // that field is stride-scaled, so split-K with non-row-major A is not yet supported. - // (k_batch == 1 is unaffected -- the offset is 0 and unused.) When col-major A lands for - // non-preshuffle, extend the split-K K-offset here instead of this reject. 
- if constexpr(!std::is_same_v) - { - if(kargs.k_batch > 1) - { - if(log) - CK_TILE_ERROR("MX GEMM: split-K (k_batch > 1) currently requires row-major A."); - return false; - } - } - - // Preshuffle split-K relies on each split starting on a K-block boundary that is also - // aligned to the host-preshuffled scale layout's packed-K granularity, so that the flat-B - // and scale windows start at the same logical K. Split boundaries are KPerBlock-aligned - // (enforced by the "K % (KPerBlock * k_batch)" check below); it therefore suffices that the - // preshuffled scale K-block granularity (ScaleGranularityK * KXdlPackEff * KThreadPerXdl) - // divides KPerBlock. - if constexpr(MXGemmPipeline::Preshuffle) - { - constexpr index_t preshuffle_scale_k_granularity = - ScaleGranularityK * KXdlPackEff * KThreadPerXdl; - if(kargs.k_batch > 1 && - (TilePartitioner::KPerBlock % preshuffle_scale_k_granularity != 0)) - { - if(log) - CK_TILE_ERROR("MX GEMM: preshuffle split-K requires KPerBlock to be a multiple " - "of ScaleGranularityK * KXdlPackEff * KThreadPerXdl."); - return false; - } - } - - // M / N must be a multiple of the block tile when padding is disabled. - if(!kPadM && (kargs.M % TilePartitioner::MPerBlock != 0)) - { - if(log) - CK_TILE_ERROR("MX GEMM: M must be a multiple of MPerBlock when kPadM is false. " - "Enable kPadM on the GEMM config to run this shape."); - return false; - } - if(!kPadN && (kargs.N % TilePartitioner::NPerBlock != 0)) - { - if(log) - CK_TILE_ERROR("MX GEMM: N must be a multiple of NPerBlock when kPadN is false. " - "Enable kPadN on the GEMM config to run this shape."); - return false; - } - - // CShuffleEpilogue cannot run with a single partial tile along M or N: the shuffle's - // LDS write/read pattern leaves some in-range output rows/cols at zero. Reject these - // pathological shapes whether or not kPadM/kPadN is enabled. - if(kargs.M < TilePartitioner::MPerBlock) - { - if(log) - CK_TILE_ERROR("MX GEMM: M must be >= MPerBlock. 
Partial-only M tiles are not " - "supported by the MX CShuffleEpilogue."); - return false; - } - if(kargs.N < TilePartitioner::NPerBlock) - { - if(log) - CK_TILE_ERROR("MX GEMM: N must be >= NPerBlock. Partial-only N tiles are not " - "supported by the MX CShuffleEpilogue."); - return false; - } - - // K padding is unconditionally rejected (kPadK is also a compile-time error -- see the - // static_assert at the top of MXGemmKernel). Every split must consume an exact number - // of K tiles, otherwise the async vector loads read garbage past the K boundary. - const index_t k_tile = TilePartitioner::KPerBlock; - if(kargs.K % (k_tile * kargs.k_batch) != 0) - { - if(log) - CK_TILE_ERROR( - "MX GEMM: K must be a multiple of KPerBlock * k_batch. The MX comp-async " - "pipeline does not currently support K padding (vector loads across the K " - "pad boundary read garbage); pick aligned K dimensions or change k_batch."); - return false; - } - - // Scales are granular in K: each packed int32_t covers ScaleBlockSize * KXdlPackEff - // consecutive K elements. Every split-K boundary must land on that granularity so that - // each split can compute a packed-scale K offset. K1 is the WarpTile K, which is a - // multiple of that granularity for all shipped configs, but be defensive. - constexpr index_t scale_granularity_k = ScaleBlockSize * KXdlPackEff; - if(kargs.k_batch > 1) - { - // splitk_batch_offset allocates K in units of K1 (warp-tile K). If K1 itself is - // not a multiple of the scale granularity, split-K is not safe. - constexpr index_t K1 = BlockGemmShape::WarpTile::at(number<2>{}); - static_assert(K1 % scale_granularity_k == 0, - "MX GEMM: WarpTile K must be a multiple of ScaleBlockSize * KXdlPack " - "to support split-K."); - // Defensive runtime check: K must split evenly along K1 boundaries so that each - // k_id consumes a whole number of warp-tile K chunks (and therefore a whole - // number of packed-scale K elements). 
- if(kargs.K % (K1 * kargs.k_batch) != 0) - { - if(log) - CK_TILE_ERROR("MX GEMM: with k_batch > 1, K must be a multiple of WarpTile_K * " - "k_batch so that every split lands on a packed-scale boundary."); - return false; - } - } - - // Delegate the remaining shape/vector-size checks to the universal kernel. All MX - // pipelines (comp-async, eight-waves, preshuffle) expose the templated - // GetVectorSize{A,B}() that UniversalGemmKernel::IsSupportedArgument requires. - return Underlying::IsSupportedArgument( - static_cast(kargs)); - } - - using SplitKBatchOffset = typename Underlying::SplitKBatchOffset; - - // Create C block window following UniversalGemmKernel pattern - template - CK_TILE_DEVICE static auto MakeCBlockWindows(EDataType* e_ptr, - const KernelArgs& kargs, - const index_t i_m, - const index_t i_n) - { - // Create tensor view for E/C tensor - constexpr index_t vector_size = EpiloguePipeline::GetVectorSizeC(); - const auto& e_tensor_view = [&]() -> auto { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - e_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.stride_E, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - e_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(1, kargs.stride_E), - number<1>{}, - number{}); - } - }(); - - // Pad both dims so OOB C writes (including partial trailing tiles where M < MPerBlock - // or N < NPerBlock) are masked by the pad transform. - const auto& e_pad_view = pad_tensor_view( - e_tensor_view, - make_tuple(number{}, number{}), - sequence{}); - - // Create block window - auto c_block_window = make_tile_window( - e_pad_view, - make_tuple(number{}, number{}), - {i_m, i_n}); - - return c_block_window; - } - - // Create scale A block windows with packed int32_t layout. - // Host packs (MXdlPack x KXdlPack) e8m0_t values into a single int32_t, producing a - // packed tensor of shape [M/MXdlPackEff, K/ScaleBlockSize/KXdlPackEff]. 
- // - // k_elem_offset: starting K element index for this block (0 unless split-K). - // Must be a multiple of ScaleBlockSize * KXdlPackEff. - template - CK_TILE_DEVICE static auto MakeScaleABlockWindows(const KernelArgs& kargs, - const index_t i_m, - const index_t k_elem_offset = 0) - { - auto scale_a = kargs.scale_m_ptr; - static_assert(ScaleM::GranularityK == ScaleGranularityK); - if constexpr(MXGemmPipeline::Preshuffle) - { - const auto scale_packs_m = integer_divide_ceil(kargs.M, (MXdlPackEff * MThreadPerXdl)); - const auto scale_packs_k = kargs.K / ScaleGranularityK / (KXdlPackEff * KThreadPerXdl); - - const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed( - make_tuple(scale_packs_m, scale_packs_k, KThreadPerXdl, MThreadPerXdl)); - const auto scale_a_desc = transform_tensor_descriptor( - scale_a_naive_desc, - make_tuple(make_merge_transform(make_tuple(scale_packs_m, MThreadPerXdl)), - make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))), - make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - const auto scale_a_tensor_view = make_tensor_view( - reinterpret_cast(scale_a.ptr), scale_a_desc); - - // For split-K (k_batch > 1) advance the scale origin into this k_id's packed-K slice. - // The merged-K axis of the preshuffled scale view, merge(scale_packs_k, KThreadPerXdl), - // has the same total extent K/(ScaleGranularityK*KXdlPackEff) as the non-preshuffle - // layout, and split boundaries are KPerBlock-aligned (see IsSupportedArgument), so the - // K-block offset is the same closed form used by the non-preshuffle branch. 
- const index_t k_scale_offset = k_elem_offset / ScaleGranularityK / KXdlPackEff; - return make_tile_window( - scale_a_tensor_view, - make_tuple( - number{}, - number{}), - {i_m / MXdlPackEff, k_scale_offset}); - } - else - { - const auto scale_k_packed = kargs.K / ScaleGranularityK / KXdlPackEff; - const auto scale_m_packed = kargs.M / MXdlPackEff; - - // A scale tensor view - layout [M/MXdlPackEff, K/32/KXdlPackEff] with int32_t elements - const auto scale_a_tensor_view = make_naive_tensor_view( - reinterpret_cast(scale_a.ptr), - make_tuple(scale_m_packed, scale_k_packed), - make_tuple(scale_k_packed, 1)); - - // Pad the scale view so partial trailing tiles along M are handled safely (OOB scale - // loads return zero; with A also zero on the padded region the contribution is zero - // regardless of scale value). kPadK is statically disabled, so K never actually pads. - const auto scale_a_pad_view = pad_tensor_view( - scale_a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - - // For split-K (k_batch > 1) advance the scale origin into this k_id's packed-K slice. 
- const index_t k_scale_offset = k_elem_offset / ScaleGranularityK / KXdlPackEff; - - // Tile window shape: [MPerBlock/MXdlPackEff, KPerBlock/32/KXdlPackEff] - return make_tile_window( - scale_a_pad_view, - make_tuple(number{}, - number{}), - {i_m / MXdlPackEff, k_scale_offset}); - } - } - - template - CK_TILE_DEVICE static auto - MakeBFlatBlockWindows(const std::array& bs_ptr, - const KernelArgs& kargs, - const index_t i_n, - const index_t k_elem_offset = 0) - { - static_assert(NumBTensor == 1, "MX GEMM preshuffle currently supports one B tensor"); - - constexpr index_t kKPerBlock = MXGemmPipeline::kKPerBlock; - constexpr index_t kNWarpTile = BlockGemmShape::WarpTile::at(I1); - constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile; - const index_t kFlatKBlocks = kargs.K / kKPerBlock; - const index_t kFlatN = kargs.N / kNWarpTile; - - // For split-K (k_batch > 1) advance the flat-B K origin into this k_id's K slice. The - // flat layout stores K as kFlatKBlocks blocks of flatKPerBlock elements each, and split - // boundaries are KPerBlock-aligned (enforced in IsSupportedArgument), so the offset lands - // on a clean K-block boundary. The universal bs_k_split_offset is not used here: it is - // derived from the logical B stride and does not match the preshuffled flat layout. - const index_t k_flat_offset = (k_elem_offset / kKPerBlock) * flatKPerBlock; - - auto b_flat_tensor_view = [&]() { - static_assert(flatKPerBlock % MXGemmPipeline::GetVectorSizeB() == 0, - "wrong! 
vector size for preshuffled B tensor"); - auto naive_desc = make_naive_tensor_descriptor_packed( - make_tuple(kFlatN, kFlatKBlocks, number{})); - auto desc = transform_tensor_descriptor( - naive_desc, - make_tuple(make_pass_through_transform(kFlatN), - make_merge_transform_v3_division_mod( - make_tuple(kFlatKBlocks, number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return make_tensor_view(bs_ptr[number<0>{}], desc); - }(); - - return generate_tuple( - [&](auto) { - return make_tile_window(b_flat_tensor_view, - make_tuple(number{}, - number{}), - {static_cast(i_n / BlockGemmShape::WarpTile::at(I1)), - static_cast(k_flat_offset)}); - }, - number{}); - } - - template - CK_TILE_DEVICE static auto MakeScaleBBlockWindows(const KernelArgs& kargs, - const index_t i_n, - const index_t k_elem_offset = 0) - { - auto scale_b = kargs.scale_n_ptr; - static_assert(ScaleN::GranularityK == ScaleGranularityK); - - if constexpr(MXGemmPipeline::Preshuffle) - { - const auto scale_packs_n = integer_divide_ceil(kargs.N, (NXdlPackEff * NThreadPerXdl)); - const auto scale_packs_k = kargs.K / ScaleGranularityK / (KXdlPackEff * KThreadPerXdl); - - const auto scale_b_naive_desc = make_naive_tensor_descriptor_packed( - make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl)); - const auto scale_b_desc = transform_tensor_descriptor( - scale_b_naive_desc, - make_tuple(make_merge_transform(make_tuple(scale_packs_n, NThreadPerXdl)), - make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))), - make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - const auto scale_b_tensor_view = make_tensor_view( - reinterpret_cast(scale_b.ptr), scale_b_desc); - - // For split-K (k_batch > 1) advance the scale origin into this k_id's packed-K slice. 
- // The merged-K axis of the preshuffled scale view, merge(scale_packs_k, KThreadPerXdl), - // has the same total extent K/(ScaleGranularityK*KXdlPackEff) as the non-preshuffle - // layout, and split boundaries are KPerBlock-aligned (see IsSupportedArgument), so the - // K-block offset is the same closed form used by the non-preshuffle branch. - const index_t k_scale_offset = k_elem_offset / ScaleGranularityK / KXdlPackEff; - return make_tile_window( - scale_b_tensor_view, - make_tuple( - number{}, - number{}), - {i_n / NXdlPackEff, k_scale_offset}); - } - else - { - const auto scale_k_packed = kargs.K / ScaleGranularityK / KXdlPackEff; - const auto scale_n_packed = kargs.N / NXdlPackEff; - - // B scale tensor view - [N/NXdlPackEff, K/32/KXdlPackEff] of int32_t - const auto scale_b_tensor_view = make_naive_tensor_view( - reinterpret_cast(scale_b.ptr), - make_tuple(scale_n_packed, scale_k_packed), - make_tuple(scale_k_packed, 1)); - - // Pad the scale view so partial trailing tiles along N are handled safely (OOB scale - // loads return zero; with B also zero on the padded region the contribution is zero - // regardless of scale value). kPadK is statically disabled, so K never actually pads. - const auto scale_b_pad_view = pad_tensor_view( - scale_b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - - // For split-K (k_batch > 1) advance the scale origin into this k_id's packed-K slice. 
- const index_t k_scale_offset = k_elem_offset / ScaleGranularityK / KXdlPackEff; - - // Tile window shape: [NPerBlock/NXdlPackEff, KPerBlock/32/KXdlPackEff] - return make_tile_window( - scale_b_pad_view, - make_tuple(number{}, - number{}), - {i_n / NXdlPackEff, k_scale_offset}); - } - } - - template - CK_TILE_DEVICE static void RunMxGemm(const std::array& as_ptr, - const std::array& bs_ptr, - const std::array& ds_ptr, - EDataType* e_ptr, - void* smem_ptr, - const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset, - const index_t i_m, - const index_t i_n, - const index_t k_elem_offset = 0) - { - // Create block windows directly, following the new pattern from UniversalGemmKernel - // i_m and i_n are element offsets (iM * MPerBlock, iN * NPerBlock), not tile indices - const auto& a_block_window = [&]() { - if constexpr(MXGemmPipeline::Preshuffle) - { - // The preshuffle A async-load (MakeMX_AAsyncLoadBytesDramWindow) rebuilds the A - // view with a *packed* descriptor, i.e. it assumes the leading (M) stride equals - // the view's K extent. That only holds when the extent equals stride_A, which is - // the case for k_batch == 1 (splitted_k == K) but NOT for split-K (splitted_k < K): - // a packed extent of splitted_k would stride M by splitted_k instead of stride_A - // and read the wrong rows (only row 0 lands correctly). Use the full K extent so - // the packed M stride matches stride_A. The as_ptr K-offset already selects this - // k_id's slice and num_loop bounds the blocks read, so reads stay within - // [as_k_split_offset, as_k_split_offset + splitted_k) <= K (in-allocation). 
- return Underlying::MakeABlockWindows(as_ptr, kargs, kargs.K, i_m); - } - else - { - return Underlying::MakeABlockWindows( - as_ptr, kargs, splitk_batch_offset.splitted_k, i_m); - } - }(); - const auto& b_block_window = [&]() { - if constexpr(MXGemmPipeline::Preshuffle) - { - return MakeBFlatBlockWindows(bs_ptr, kargs, i_n, k_elem_offset); - } - else - { - return Underlying::MakeBBlockWindows( - bs_ptr, kargs, splitk_batch_offset.splitted_k, i_n); - } - }(); - const auto& d_block_window = Underlying::MakeDBlockWindows(ds_ptr, kargs, i_m, i_n); - - // Create scale block windows. For split-K (k_batch > 1), k_elem_offset advances the - // scale origin into the correct packed-K slice for this k_id; otherwise it is zero. - const auto& scale_a_block_window = MakeScaleABlockWindows(kargs, i_m, k_elem_offset); - const auto& scale_b_block_window = MakeScaleBBlockWindows(kargs, i_n, k_elem_offset); - - const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); - - static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK - || ScaleM::GranularityMN == -1 // or ScaleA is disable - || ScaleN::GranularityMN == -1, // or ScaleB is disable - "ScaleM and ScaleN should have the same GranularityK"); - - const auto& c_block_tile = [&]() { - if constexpr(MXGemmPipeline::Preshuffle) - { - constexpr index_t smem_ping_pong_size = MXGemmPipeline::GetSmemSize() / 2; - return MXGemmPipeline{}(a_block_window[number<0>{}], - b_block_window[number<0>{}], - scale_a_block_window, - scale_b_block_window, - num_loop, - smem_ptr, - static_cast(smem_ptr) + smem_ping_pong_size); - } - else - { - return MXGemmPipeline{}(a_block_window[number<0>{}], - b_block_window[number<0>{}], - scale_a_block_window, - scale_b_block_window, - num_loop, - smem_ptr); - } - }(); - - // Run Epilogue Pipeline - create C block window with the requested memory op (set for - // k_batch == 1, atomic_add for split-K so partial results accumulate into the same tile). 
- auto c_block_window = MakeCBlockWindows(e_ptr, kargs, i_m, i_n); - EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr); - } - - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() - { - return max(MXGemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); - } - - template - CK_TILE_DEVICE void operator()(KernelArgs kargs, - int partition_idx = get_block_id()) const - { -#if !defined(__gfx950__) - static_assert(sizeof(MXGemmPipeline) == 0, "CKTile MX GEMM kernels require gfx950."); - ignore = kargs; - ignore = partition_idx; -#else - const int total_work_tile_cnt = - amd_wave_read_first_lane(TilePartitioner::GridSize(kargs.M, kargs.N)); - - // Allocate shared memory for ping pong buffers - __shared__ char smem_ptr[GetSmemSize()]; - - // Support both persistent and non-persistent modes - do - { - const auto [iM, iN] = - TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx); - const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); - const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); - - // SplitKBatchOffset defaults its k_id to blockIdx.z, selecting this split's K slice. - const SplitKBatchOffset splitk_batch_offset( - static_cast(kargs)); - - // This k_id's logical K-element start. For row-major A, as_k_split_offset[0] is exactly - // that offset, so reuse it rather than recomputing the split formula; the packed-scale - // and flat-B K offsets are derived from it. Split-K with non-row-major A is rejected in - // IsSupportedArgument; for k_batch == 1 this value is 0 and unused for any layout. 
- const index_t k_elem_offset = - amd_wave_read_first_lane(splitk_batch_offset.as_k_split_offset[I0]); - - EDataType* e_ptr = static_cast(kargs.e_ptr); - - std::array as_ptr; - static_for<0, NumATensor, 1>{}([&](auto i) { - as_ptr[i] = static_cast(kargs.as_ptr[i]) + - splitk_batch_offset.as_k_split_offset[i] / APackedSize; - }); - - std::array bs_ptr; - static_for<0, NumBTensor, 1>{}([&](auto i) { - if constexpr(MXGemmPipeline::Preshuffle) - { - // The preshuffle (flat-B) path applies the per-split K offset to the flat - // window origin in MakeBFlatBlockWindows; bs_k_split_offset is derived from - // the logical B stride and would mis-offset the flat buffer. - bs_ptr[i] = static_cast(kargs.bs_ptr[i]); - } - else - { - bs_ptr[i] = static_cast(kargs.bs_ptr[i]) + - splitk_batch_offset.bs_k_split_offset[i] / BPackedSize; - } - }); - - // Dispatch epilogue: when k_batch > 1 each split accumulates a partial result into - // the same C tile, so we need atomic add (universal_gemm_kernel pattern). The - // fp16/bf16 even-vector-size precondition is captured once in kSplitKAtomicAddSupported - // and also rejected up front in IsSupportedArgument. - if(kargs.k_batch == 1) - { - RunMxGemm(as_ptr, - bs_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr, - kargs, - splitk_batch_offset, - i_m, - i_n, - /*k_elem_offset=*/0); - } - else - { - if constexpr(kSplitKAtomicAddSupported) - { - RunMxGemm(as_ptr, - bs_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr, - kargs, - splitk_batch_offset, - i_m, - i_n, - k_elem_offset); - } - } - partition_idx += gridDim.x; - } while(UsePersistentKernel && partition_idx < total_work_tile_cnt); -#endif - } -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp b/include/ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp deleted file mode 100644 index 0214f26a9c..0000000000 --- a/include/ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#pragma once - -#include "ck_tile/core.hpp" - -#if __clang_major__ >= 23 -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" -#endif -namespace ck_tile { - -template -struct MXScalePointer -{ - static constexpr int GranularityMN = SharedGranularityMN; - static constexpr int GranularityK = SharedGranularityK; - - static_assert(GranularityK != 0, - "GranularityK cannot be zero in primary template; " - "use the partial specialization for GranularityK == 0"); - - const ScaleType* ptr; - - CK_TILE_HOST_DEVICE MXScalePointer() = default; - CK_TILE_HOST_DEVICE MXScalePointer(const ScaleType* ptr_) : ptr(ptr_) {} - CK_TILE_HOST_DEVICE MXScalePointer(const ScaleType* ptr_, [[maybe_unused]] index_t length_) - : ptr(ptr_) - { - } - - CK_TILE_HOST_DEVICE MXScalePointer operator+(index_t offset) const - { - MXScalePointer ret; - if constexpr(GranularityMN == 0) - { - ret.ptr = ptr + offset / GranularityK; - } - else - { - ret.ptr = ptr + offset / GranularityMN / GranularityK; - } - return ret; - } - - CK_TILE_HOST_DEVICE ScaleType operator[](index_t i) const = delete; -}; - -template -struct MXScalePointer -{ - static constexpr int GranularityMN = SharedGranularityMN; - static constexpr int GranularityK = 0; - - static_assert(GranularityMN != 0); - - const ScaleType* ptr; - index_t length; - - CK_TILE_HOST_DEVICE MXScalePointer() = default; - CK_TILE_HOST_DEVICE MXScalePointer(const ScaleType* ptr_) : ptr(ptr_), length(1) {} - CK_TILE_HOST_DEVICE MXScalePointer(const ScaleType* ptr_, index_t length_) - : ptr(ptr_), length(length_) - { - } - - CK_TILE_HOST_DEVICE MXScalePointer operator+(index_t offset) const - { - MXScalePointer ret; - if constexpr(GranularityMN == 1) - { - ret.ptr = ptr + offset; - ret.length = length - offset; - } - else - { - ret.ptr = ptr + offset / GranularityMN; - ret.length = length - offset / GranularityMN; - } - return ret; - } - - CK_TILE_HOST_DEVICE ScaleType 
operator[](index_t i) const - { - // with additional oob check - if constexpr(GranularityMN == 1) - return i < length ? ptr[i] : 0; - else - return i / GranularityMN < length ? ptr[i / GranularityMN] : 0; - } -}; - -// shared granularityMN = -1 means no scale -template -struct MXScalePointer -{ - static constexpr int GranularityMN = -1; - static constexpr int GranularityK = 0; - - const ScaleType* ptr = nullptr; - - CK_TILE_HOST_DEVICE constexpr MXScalePointer() = default; - CK_TILE_HOST_DEVICE constexpr MXScalePointer(const ScaleType*) {} - CK_TILE_HOST_DEVICE constexpr MXScalePointer(const ScaleType*, index_t) {} - - CK_TILE_HOST_DEVICE constexpr MXScalePointer operator+(index_t) const - { - return MXScalePointer{}; - } - CK_TILE_HOST_DEVICE constexpr ScaleType operator[](index_t) const - { - return 1; // alway return 1, it doesn't change the result - } -}; - -} // namespace ck_tile -#if __clang_major__ >= 23 -#pragma clang diagnostic pop -#endif diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp deleted file mode 100644 index 32a1afc3c8..0000000000 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ /dev/null @@ -1,782 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#pragma once -#include "ck_tile/core.hpp" -#include "ck_tile/core/arch/arch.hpp" -#include "ck_tile/core/tensor/load_tile.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" -#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp" - -namespace ck_tile { - -// A Tile Window: global memory -// B Tile Window: global memory -// C Distributed tensor: register -// MX scaling support with OpSel -template -struct BaseMXGemmPipelineAgBgCrCompAsync -{ - static constexpr index_t PrefetchStages = 2; - static constexpr index_t PrefillStages = 1; - static constexpr index_t GlobalBufferNum = 1; - - static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel; - - CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop) - { - // The prologue puts PrefetchStages + PrefillStages tiles in flight (2 LDS buffers + 1 - // register prefill) before the main loop, so the loop only runs when there is work - // beyond them; otherwise the tail drains the in-flight tiles. - return num_loop > PrefetchStages + PrefillStages; - } - - CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) - { - if(num_loop == 1) - { - return TailNumber::One; - } - if(num_loop % PrefetchStages == 1) - { - return TailNumber::Three; - } - else - { - return TailNumber::Two; - } - } - - template - CK_TILE_HOST_DEVICE static auto - TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number) - { - // Handle all the valid cases. 
- if(has_hot_loop) - { - if(tail_number == TailNumber::Three) - { - return run_func(bool_constant{}, - integral_constant{}); - } - else if(tail_number == TailNumber::Two) - { - return run_func(bool_constant{}, - integral_constant{}); - } - } - else - { - if(tail_number == TailNumber::Three) - { - return run_func(bool_constant{}, - integral_constant{}); - } - else if(tail_number == TailNumber::Two) - { - return run_func(bool_constant{}, - integral_constant{}); - } - else - { - return (run_func(bool_constant{}, - integral_constant{})); - } - } - // If execution reaches here, it's an invalid tail_number because it wasn't handled above. -#if defined(__HIP_DEVICE_COMPILE__) - __builtin_unreachable(); -#else - throw std::logic_error( - "Invalid TailNumber: Only TailNumber::Three and TailNumber::Two are supported"); -#endif - } -}; - -/** - * @brief MX GEMM compute optimized pipeline version async; which is based on V4. - * - * This pipeline introduces asynchronous load from global memory to LDS, - * skipping the intermediate loading into pipeline registers. - * Supports MX scaling with e8m0 packed values and OpSel. 
- */ -template -struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync -{ - using Base = BaseMXGemmPipelineAgBgCrCompAsync; - using PipelineImplBase = GemmPipelineAgBgCrImplBase; - - using AsDataType = remove_cvref_t; - using BsDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - - using AsLayout = remove_cvref_t; - using BsLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - - using AElementWise = remove_cvref_t; - using BElementWise = remove_cvref_t; - - using ALayout = remove_cvref_t>; - using BLayout = remove_cvref_t>; - - using ADataType = remove_cvref_t>; - using BDataType = remove_cvref_t>; - - static_assert(!std::is_same_v, "Not implemented"); - - static constexpr index_t ScaleGranularityK = Policy::ScaleGranularityK; - static constexpr index_t MXdlPack = Policy::MXdlPack; - static constexpr index_t NXdlPack = Policy::NXdlPack; - static constexpr index_t KXdlPack = Policy::KXdlPack; - - static constexpr index_t APackedSize = - ck_tile::numeric_traits>::PackedSize; - static constexpr index_t BPackedSize = - ck_tile::numeric_traits>::PackedSize; - - using BlockGemm = remove_cvref_t())>; - using I0 = number<0>; - using I1 = number<1>; - using I2 = number<2>; - - static constexpr index_t BlockSize = Problem::kBlockSize; - - static constexpr index_t MPerBlock = BlockGemmShape::kM; - static constexpr index_t NPerBlock = BlockGemmShape::kN; - static constexpr index_t KPerBlock = BlockGemmShape::kK; - - template - static constexpr index_t GetVectorSizeA() - { - return Policy::template GetVectorSizeA(); - } - template - static constexpr index_t GetVectorSizeB() - { - return Policy::template GetVectorSizeB(); - } - static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } - - static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } - static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } - - static 
constexpr index_t NumWaveGroups = Problem::NumWaveGroups; - static constexpr index_t Preshuffle = Problem::Preshuffle; - - static constexpr bool kPadM = Problem::kPadM; - static constexpr bool kPadN = Problem::kPadN; - static constexpr bool kPadK = Problem::kPadK; - - static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; - - static constexpr auto Scheduler = Problem::Scheduler; - - static constexpr auto is_a_load_tr_v = bool_constant{}; - static constexpr auto is_b_load_tr_v = bool_constant{}; - -#if defined(__gfx950__) - static_assert(!(std::is_same_v && - !PipelineImplBase::is_a_load_tr), - "A=ColumnMajor requires transpose load (ds_read_tr), but it is disabled for " - "this K warp tile size. Use a smaller K warp tile (e.g. 32x32x64 MFMA)."); - static_assert(!(std::is_same_v && - !PipelineImplBase::is_b_load_tr), - "B=RowMajor requires transpose load (ds_read_tr), but it is disabled for " - "this K warp tile size. Use a smaller K warp tile (e.g. 32x32x64 MFMA)."); -#endif - - [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() - { - // clang-format off - return "COMPUTE_ASYNC"; - // clang-format on - } - - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() - { - constexpr index_t smem_size = Policy::template GetSmemSize(); - return 2 * smem_size; - } - - CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() - { - return Policy::template IsTransposeC(); - } - - template - struct PipelineImpl : public PipelineImplBase - { - }; - - template <> - struct PipelineImpl : public PipelineImplBase - { - using Base = PipelineImplBase; - - CK_TILE_DEVICE static constexpr auto HotLoopScheduler() - { - constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{}); - constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{}); - constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{}); - - constexpr index_t WaveSize = get_warp_size(); - - constexpr index_t A_Buffer_Load_Inst_Num = - MPerBlock * KPerBlock / (BlockSize * 
GetVectorSizeA()); - constexpr index_t B_Buffer_Load_Inst_Num = - NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); - - constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / - (BlockSize / WaveSize) / - (MPerXDL * NPerXDL * KPerXDL); - - constexpr auto num_buffer_load_inst = A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num; - constexpr auto num_issue = num_buffer_load_inst; - - static_for<0, num_buffer_load_inst, 1>{}([&](auto i) { - // TODO: this will likely need to be redesigned after (1) changes to reading from - // LDS and (2) re-profiling - ignore = i; - __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA : 1 - __builtin_amdgcn_sched_group_barrier( - LLVMSchedGroupMask::DS_READ, 1, 0); // DS read : 1 - __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA: 1 - __builtin_amdgcn_sched_group_barrier( - LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read :1 - __builtin_amdgcn_sched_group_barrier( - LLVMSchedGroupMask::MFMA, C_MFMA_Inst_Num / num_issue - 2, 0); // MFMA : 6 - }); - __builtin_amdgcn_sched_barrier(0); - } - - template ::value && - is_detected::value, - bool>* = nullptr> - CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BsDramBlockWindowTmp& b_dram_block_window_tmp, - const BElementFunction& b_element_func, - const ScaleADramBlockWindowTmp& scale_a_window, - const ScaleBDramBlockWindowTmp& scale_b_window, - index_t num_loop, - void* __restrict__ p_smem_0, - void* __restrict__ p_smem_1) const - { - // TODO support multi-ABD - static_assert(1 == std::tuple_size_v); - static_assert(1 == std::tuple_size_v); - using ADramBlockWindowTmp = - remove_cvref_t{}, AsDramBlockWindowTmp>>; - using BDramBlockWindowTmp = - remove_cvref_t{}, BsDramBlockWindowTmp>>; - // TODO currently fused elementwise are not supported - ignore = a_element_func; - ignore = b_element_func; - static_assert(std::is_same_v, - 
element_wise::PassThrough>); - static_assert(std::is_same_v, - element_wise::PassThrough>); - static_assert( - std::is_same_v> && - std::is_same_v>, - "Data Type conflict on A and B matrix input data type."); - - constexpr bool is_a_col_major = - std::is_same_v; - constexpr bool is_b_row_major = std::is_same_v; - - static_assert(is_a_col_major - ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && - MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) - : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && - KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), - "A block window has incorrect lengths for defined ALayout!"); - static_assert(is_b_row_major - ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && - NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) - : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && - KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), - "B block window has incorrect lengths for defined BLayout!"); - - ////////////// global window & register ///////////////// - // A DRAM tile window(s) for load - auto a_tile_windows = generate_tuple( - [&](auto idx) { - return make_tile_window( - a_dram_block_window_tmp[number{}].get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp[number{}].get_window_origin(), - Policy::template MakeADramTileDistribution()); - }, - number{}); - // B DRAM window(s) for load - auto b_tile_windows = generate_tuple( - [&](auto idx) { - return make_tile_window( - b_dram_block_window_tmp[number{}].get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_dram_block_window_tmp[number{}].get_window_origin(), - Policy::template MakeBDramTileDistribution()); - }, - number{}); - - // for XOR swizzle: policy makes async global-to-LDS stores match LDS reads - // otherwise: no change to view - auto a_async_tile_windows = generate_tuple( - [&](auto idx) { - return 
make_tile_window(Policy::template MakeAsyncLoadADramWindow( - a_tile_windows[number{}]), - Policy::template MakeADramTileDistribution()); - }, - number{}); - - auto b_async_tile_windows = generate_tuple( - [&](auto idx) { - return make_tile_window(Policy::template MakeAsyncLoadBDramWindow( - b_tile_windows[number{}]), - Policy::template MakeBDramTileDistribution()); - }, - number{}); - - ////////////// MX Scale windows (pre-packed int32_t) ///////////////// - // Get WarpGemm configuration - using BlockWarps = typename BlockGemmShape::BlockWarps; - using WarpTile = typename BlockGemmShape::WarpTile; - constexpr index_t MWarp = BlockWarps::at(I0{}); - constexpr index_t NWarp = BlockWarps::at(I1{}); - - // Compute effective XdlPack sizes (fall back to 1 when iter count < pack) - constexpr index_t MPerXdl = WarpTile::at(I0{}); - constexpr index_t NPerXdl = WarpTile::at(I1{}); - constexpr index_t KPerXdl = WarpTile::at(I2{}); - constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl); - constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl); - constexpr index_t KIterPerWarp = KPerBlock / KPerXdl; - - constexpr index_t MXdlPackEff = - (MIterPerWarp >= Policy::MXdlPack && MIterPerWarp % Policy::MXdlPack == 0) - ? Policy::MXdlPack - : 1; - constexpr index_t NXdlPackEff = - (NIterPerWarp >= Policy::NXdlPack && NIterPerWarp % Policy::NXdlPack == 0) - ? Policy::NXdlPack - : 1; - constexpr index_t KXdlPackEff = - (KIterPerWarp >= Policy::KXdlPack && KIterPerWarp % Policy::KXdlPack == 0) - ? 
Policy::KXdlPack - : 1; - - // Packed scale dimensions - constexpr index_t ScaleKDimPerBlock = KPerBlock / ScaleGranularityK / KXdlPackEff; - - // Scale tensor views and base origins for creating tile windows per iteration - const auto& scale_a_tensor_view = scale_a_window.get_bottom_tensor_view(); - const auto& scale_b_tensor_view = scale_b_window.get_bottom_tensor_view(); - auto scale_a_base_origin = scale_a_window.get_window_origin(); - auto scale_b_base_origin = scale_b_window.get_window_origin(); - - // Create scale windows with packed int32_t dimensions - auto scale_a_dram_window = make_tile_window( - scale_a_tensor_view, - make_tuple(number{}, number{}), - scale_a_base_origin, - Policy::template MakeMX_ScaleA_DramTileDistribution()); - - auto scale_b_dram_window = make_tile_window( - scale_b_tensor_view, - make_tuple(number{}, number{}), - scale_b_base_origin, - Policy::template MakeMX_ScaleB_DramTileDistribution()); - - // this pipeline has a pair of LDS buffers per logical tile - auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0); - auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1); - - constexpr auto a_lds_shape = []() { - if constexpr(is_a_load_tr_v) - return make_tuple(number{}, number{}); - else - return make_tuple(number{}, number{}); - }(); - - constexpr auto b_lds_shape = []() { - if constexpr(is_b_load_tr_v) - return make_tuple(number{}, number{}); - else - return make_tuple(number{}, number{}); - }(); - - // LDS tile windows for storing, one per LDS buffer - auto a_copy_lds_window0 = make_tile_window(a_lds_block0, a_lds_shape, {0, 0}); - - auto a_copy_lds_window1 = make_tile_window(a_lds_block1, a_lds_shape, {0, 0}); - - auto b_copy_lds_window0 = make_tile_window(b_lds_block0, b_lds_shape, {0, 0}); - - auto b_copy_lds_window1 = make_tile_window(b_lds_block1, b_lds_shape, {0, 0}); - - // initialize DRAM window steps, used to advance the DRAM windows - using ADramTileWindowStep = typename 
ADramBlockWindowTmp::BottomTensorIndex; - using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; - constexpr ADramTileWindowStep a_dram_tile_window_step = - is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); - constexpr BDramTileWindowStep b_dram_tile_window_step = - is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); - - // read A(0), B(0) from DRAM to LDS window(0) - // and advance the DRAM windows - Base::GlobalPrefetchAsync( - a_copy_lds_window0, a_async_tile_windows[number<0>{}], a_dram_tile_window_step); - Base::GlobalPrefetchAsync( - b_copy_lds_window0, b_async_tile_windows[number<0>{}], b_dram_tile_window_step); - - // Initialize block gemm and C block tile - auto block_gemm = BlockGemm(); - auto c_block_tile = block_gemm.MakeCBlockTile(); - clear_tile(c_block_tile); - - // read A(1), B(1) from DRAM to LDS window(1) - // and advance the DRAM windows - Base::GlobalPrefetchAsync( - a_copy_lds_window1, a_async_tile_windows[number<0>{}], a_dram_tile_window_step); - Base::GlobalPrefetchAsync( - b_copy_lds_window1, b_async_tile_windows[number<0>{}], b_dram_tile_window_step); - - // tile distribution for the register tiles - using ALdsTileDistr = - decltype(make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode())); - using BLdsTileDistr = - decltype(make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode())); - - using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr{})); - using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr{})); - - // register tiles; double buffering -> a register tile corresponds to a LDS tile window - ALdsTile a_block_tile0, a_block_tile1; - BLdsTile b_block_tile0, b_block_tile1; - - // Some sanity checks on the LDS tile sizes - static_assert(sizeof(ALdsTile) == MPerBlock * - (KPerBlock * sizeof(ADataType) / APackedSize) * - NWarp / BlockSize, - "ALdsTile size is wrong!"); - static_assert(sizeof(BLdsTile) == NPerBlock * - 
(KPerBlock * sizeof(BDataType) / BPackedSize) * - MWarp / BlockSize, - "BLdsTile size is wrong!"); - static_assert(Policy::template GetSmemSizeA() >= - MPerBlock * (KPerBlock * sizeof(ADataType) / APackedSize), - "SmemSizeA size is wrong!"); - static_assert(Policy::template GetSmemSizeB() >= - (KPerBlock * sizeof(BDataType) / BPackedSize) * NPerBlock, - "SmemSizeB size is wrong!"); - - ////////////// MX Scale register tiles (ping-pong buffers) ///////////////// - // Scales are pre-packed int32_t: each int32_t holds 2M/N x 2K e8m0_t values - // Block GEMM uses OpSel (0-3) to select the right byte per MFMA call - - using ScaleATileType = decltype(load_tile(scale_a_dram_window)); - using ScaleBTileType = decltype(load_tile(scale_b_dram_window)); - ScaleATileType scale_a_tile_ping, scale_a_tile_pong; - ScaleBTileType scale_b_tile_ping, scale_b_tile_pong; - - // initialize Scale DRAM window steps, used to advance the Scale DRAM windows - using ScaleADramTileWindowStep = typename ScaleADramBlockWindowTmp::BottomTensorIndex; - using ScaleBDramTileWindowStep = typename ScaleBDramBlockWindowTmp::BottomTensorIndex; - constexpr ScaleADramTileWindowStep scale_a_dram_tile_window_step = - make_array(0, ScaleKDimPerBlock); - constexpr ScaleBDramTileWindowStep scale_b_dram_tile_window_step = - make_array(0, ScaleKDimPerBlock); - - // Helper function to load scales - auto load_scales_from_dram = [&](auto& scale_a, auto& scale_b) { - scale_a = load_tile(scale_a_dram_window); - scale_b = load_tile(scale_b_dram_window); - move_tile_window(scale_a_dram_window, scale_a_dram_tile_window_step); - move_tile_window(scale_b_dram_window, scale_b_dram_tile_window_step); - }; - - constexpr auto a_lds_input_tile_distr = []() { - if constexpr(is_a_load_tr_v) - return make_static_tile_distribution( - typename InputTileDistributionTraits< - typename ALdsTileDistr::DstrEncode, - typename Problem::ADataType>::TransposedDstrEncode{}); - else - return ALdsTileDistr{}; - }(); - constexpr auto 
b_lds_input_tile_distr = []() { - if constexpr(is_b_load_tr_v) - return make_static_tile_distribution( - typename InputTileDistributionTraits< - typename BLdsTileDistr::DstrEncode, - typename Problem::BDataType>::TransposedDstrEncode{}); - else - return BLdsTileDistr{}; - }(); - - // LDS tile windows for reading; - // they share the data pointer with the LDS windows for storing - // but also associate with a distribution to produce a register tile when reading - auto a_lds_ld_window0 = - make_tile_window(a_lds_block0, a_lds_shape, {0, 0}, a_lds_input_tile_distr); - auto a_lds_ld_window1 = - make_tile_window(a_lds_block1, a_lds_shape, {0, 0}, a_lds_input_tile_distr); - auto b_lds_ld_window0 = - make_tile_window(b_lds_block0, b_lds_shape, {0, 0}, b_lds_input_tile_distr); - auto b_lds_ld_window1 = - make_tile_window(b_lds_block1, b_lds_shape, {0, 0}, b_lds_input_tile_distr); - - static_assert(!(is_tile_window_linear_v) && - !(is_tile_window_linear_v) && - !(is_tile_window_linear_v) && - !(is_tile_window_linear_v), - "LDS windows must not be linear"); - - // write to LDS window(0) must complete before the local prefetch - block_sync_lds_direct_load(); - // read A(0), B(0) from LDS window(0) to pipeline registers(0) - Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v); - Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v); - // LDS window(0) contents are overwritten below by global prefetch, need to sync - block_sync_lds(); - // read A(2), B(2) from DRAM to LDS window(0) - // and advance the DRAM windows - Base::GlobalPrefetchAsync( - a_copy_lds_window0, a_async_tile_windows[number<0>{}], a_dram_tile_window_step); - Base::GlobalPrefetchAsync( - b_copy_lds_window0, b_async_tile_windows[number<0>{}], b_dram_tile_window_step); - - // Load scales for iteration 0 (ping) - load_scales_from_dram(scale_a_tile_ping, scale_b_tile_ping); - // Load scales for iteration 1 (pong) if needed - if(num_loop > 1) - { - 
load_scales_from_dram(scale_a_tile_pong, scale_b_tile_pong); - } - - if(HasHotLoop) - { - // we have had 3 global prefetches so far, indexed (0, 1, 2). - index_t i_global_read = amd_wave_read_first_lane(3); - // alternate ping: (read to register tile(1), use register tile(0) as gemm input) - // pong: (read to register tile(0), use register tile(1) as gemm input) - do - { - // ping - { - // read A(i-1), B(i-1) from LDS window(1) to pipeline registers(1) - Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v); - Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v); - // LDS window(1) contents are overwritten by global prefetch, need to sync - block_sync_lds(); - // read A(i), B(i) from DRAM to LDS window(1) - // and advance the DRAM windows - Base::GlobalPrefetchAsync(a_copy_lds_window1, - a_async_tile_windows[number<0>{}], - a_dram_tile_window_step); - Base::GlobalPrefetchAsync(b_copy_lds_window1, - b_async_tile_windows[number<0>{}], - b_dram_tile_window_step); - // C(i-3) = A(i-3) @ B(i-3) with MX scaling - block_gemm(c_block_tile, - a_block_tile0, - b_block_tile0, - scale_a_tile_ping, - scale_b_tile_ping); - HotLoopScheduler(); - // Load next scales after using current scales above - load_scales_from_dram(scale_a_tile_ping, scale_b_tile_ping); - } - // pong - { - // write to LDS window(0) must complete before the local prefetch - block_sync_lds_direct_load(); - // read A(i), B(i) from LDS window(0) to pipeline registers(0) - Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v); - Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v); - // LDS window(0) contents are overwritten by global prefetch, need to sync - block_sync_lds(); - // read A(i+1), B(i+1) from DRAM to LDS window(0) - // and advance the DRAM windows - Base::GlobalPrefetchAsync(a_copy_lds_window0, - a_async_tile_windows[number<0>{}], - a_dram_tile_window_step); - Base::GlobalPrefetchAsync(b_copy_lds_window0, - 
b_async_tile_windows[number<0>{}], - b_dram_tile_window_step); - // C(i-2) = A(i-2) @ B(i-2) with MX scaling - block_gemm(c_block_tile, - a_block_tile1, - b_block_tile1, - scale_a_tile_pong, - scale_b_tile_pong); - HotLoopScheduler(); - // Load next scales after using current scales above - load_scales_from_dram(scale_a_tile_pong, scale_b_tile_pong); - } - i_global_read += 2; - } while(i_global_read < num_loop); - } - - // 3 block gemms remaining - if constexpr(TailNum == TailNumber::Three) - { - { - // read A(num_loop-1), B(num_loop-1) from LDS window(1) to pipeline registers(1) - Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v); - Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v); - // C(num_loop-2) = A(num_loop-2) @ B(num_loop-2) with MX scaling - block_gemm(c_block_tile, - a_block_tile0, - b_block_tile0, - scale_a_tile_ping, - scale_b_tile_ping); - - // load last scales to ping for the last iteration to ping buffers - load_scales_from_dram(scale_a_tile_ping, scale_b_tile_ping); - } - { - // write to LDS window(0) must complete before the local prefetch - block_sync_lds_direct_load(); - // read A(num_loop), B(num_loop) from LDS window(0) to pipeline registers(0) - Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v); - Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v); - // C(num_loop-1) = A(num_loop-1) @ B(num_loop-1) with MX scaling - block_gemm(c_block_tile, - a_block_tile1, - b_block_tile1, - scale_a_tile_pong, - scale_b_tile_pong); - } - { - // C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling - block_gemm(c_block_tile, - a_block_tile0, - b_block_tile0, - scale_a_tile_ping, - scale_b_tile_ping); - } - } - else if(TailNum == TailNumber::Two) - // 2 block gemms remaining - { - { - // read A(num_loop), B(num_loop) from LDS window(1) to pipeline registers(1) - Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v); - Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, 
is_b_load_tr_v); - block_gemm(c_block_tile, - a_block_tile0, - b_block_tile0, - scale_a_tile_ping, - scale_b_tile_ping); - } - { - // C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling - block_gemm(c_block_tile, - a_block_tile1, - b_block_tile1, - scale_a_tile_pong, - scale_b_tile_pong); - } - } - else if(TailNum == TailNumber::One) - { - block_sync_lds(); - // C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling - block_gemm(c_block_tile, - a_block_tile0, - b_block_tile0, - scale_a_tile_ping, - scale_b_tile_ping); - __builtin_amdgcn_sched_barrier(0); - } - - return c_block_tile; - } - }; - - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, - const BElementFunction& b_element_func, - const ScaleADramBlockWindowTmp& scale_a_window, - const ScaleBDramBlockWindowTmp& scale_b_window, - index_t num_loop, - void* __restrict__ p_smem) const - { - constexpr index_t smem_size = Policy::template GetSmemSize(); - const auto smem = reinterpret_cast(p_smem); - - const bool has_hot_loop = Base::BlockHasHotloop(num_loop); - const auto tail_number = Base::GetBlockLoopTailNum(num_loop); - - const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - a_element_func, - b_dram_block_window_tmp, - b_element_func, - scale_a_window, - scale_b_window, - num_loop, - smem, - smem + smem_size); - }; - - return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); - } - - public: - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, - const ScaleADramBlockWindowTmp& scale_a_window, - const ScaleBDramBlockWindowTmp& scale_b_window, - const index_t num_loop, - void* __restrict__ p_smem) const - { - constexpr index_t smem_size = Policy::template GetSmemSize(); - const auto 
smem = reinterpret_cast(p_smem); - - const bool has_hot_loop = Base::BlockHasHotloop(num_loop); - const auto tail_number = Base::GetBlockLoopTailNum(num_loop); - - const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { - return PipelineImpl{}.template operator()( - make_tuple(a_dram_block_window_tmp), - element_wise::PassThrough{}, - make_tuple(b_dram_block_window_tmp), - element_wise::PassThrough{}, - scale_a_window, - scale_b_window, - num_loop, - smem, - smem + smem_size); - }; - - return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); - } -}; -} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp deleted file mode 100644 index bce312d1f9..0000000000 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp +++ /dev/null @@ -1,605 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/core/arch/arch.hpp" -#include "ck_tile/core/numeric/float8.hpp" -#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" -#include "ck_tile/ops/common/tensor_layout.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" -#include "ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_v1.hpp" -#include - -namespace ck_tile { -// Default policy for MXGemmPipelineAgBgCrCompAsync -// Customized methods: MakeALdsBlockDescriptor, MakeBLdsBlockDescriptor -// GetBlockGemm implementation is copied from GemmPipelineAgBgCrCompV4DefaultPolicy -// Adds MX scale tile distributions -struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy - : public UniversalGemmBasePolicy -{ - static constexpr auto ATileAccessPattern = tile_distribution_pattern::warp_raked; - static constexpr auto BTileAccessPattern = tile_distribution_pattern::warp_raked; - - // Async copy supports 32-bit, 96-bit, or 128-bit transfers (4, 12, 16 bytes) - // Take PackedSize into consideration (for example for FP4 support) - template - static constexpr index_t AsyncVectorBytes = - sizeof(DataType) * KPack / numeric_traits>::PackedSize; - - template - static constexpr bool IsSupportedAsyncVectorWidth = - AsyncVectorBytes == 4 || AsyncVectorBytes == 12 || - AsyncVectorBytes == 16; - - template - static constexpr bool IsF8XorSwizzleDataType = - std::is_same_v, fp8_t> || - std::is_same_v, bf8_t>; - - template - static constexpr bool IsFP4XorSwizzleDataType = - std::is_same_v, pk_fp4_t>; - - // XOR Swizzle: support F8/F8 and FP4/FP4. Mixed F8/FP4 stays on the plain path. 
- template - static constexpr bool IsSupportedXorSwizzleDataType = - (IsF8XorSwizzleDataType && - IsF8XorSwizzleDataType) || - (IsFP4XorSwizzleDataType && - IsFP4XorSwizzleDataType); - - // FP4 needs the XOR KPack in logical elements - // so the async transaction remains 16 bytes - template - static constexpr index_t GetXorSwizzleKPack() - { - return SmemPack * numeric_traits>::PackedSize; - } - - template - static constexpr index_t GetXorSwizzleKPackA() - { - return GetXorSwizzleKPack()>(); - } - - template - static constexpr index_t GetXorSwizzleKPackB() - { - return GetXorSwizzleKPack()>(); - } - - // Check that async vector store to LDS is supported - template - static constexpr bool IsSupportedXorSwizzleAsyncWidth = - IsSupportedAsyncVectorWidth()> && - IsSupportedAsyncVectorWidth()>; - - // gfx950 scales:16x16x128 warp tile, 16-element smem pack, KWarps==1 - template - static constexpr bool IsSupportedXorSwizzleShape = []() { - using BlockGemmShape = typename Problem::BlockGemmShape; - using BlockWarps = typename BlockGemmShape::BlockWarps; - using WarpTile = typename BlockGemmShape::WarpTile; - - return Problem::NumWaveGroups == 1 && BlockWarps::at(number<2>{}) == 1 && - WarpTile::at(number<0>{}) == 16 && WarpTile::at(number<1>{}) == 16 && - WarpTile::at(number<2>{}) == 128 && GetSmemPackA() == 16 && - GetSmemPackB() == 16; - }(); - - // Assume normal LDS layout, not transpose-load - template - static constexpr bool UseXorSwizzle = - !is_a_load_tr && !is_b_load_tr && - IsSupportedXorSwizzleDataType && IsSupportedXorSwizzleAsyncWidth && - IsSupportedXorSwizzleShape; - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeXorSwizzleABDramTileDistribution() - { - using BlockGemmShape = typename Problem::BlockGemmShape; - using BlockWarps = typename BlockGemmShape::BlockWarps; - using WarpTile = typename BlockGemmShape::WarpTile; - - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t KPerBlock = BlockGemmShape::kK; - constexpr index_t 
KWarps = BlockWarps::at(I2); - constexpr index_t K1 = WarpTile::at(I2) / K2; - constexpr index_t K0 = KPerBlock / (KWarps * K1 * K2); - - constexpr index_t warp_size = get_warp_size(); - constexpr index_t warp_num = BlockSize / warp_size; - - static_assert(KWarps == 1, "MX XOR swizzle currently supports KWarps == 1"); - static_assert(KWarps * K0 * K1 * K2 == KPerBlock, "Wrong!"); - - constexpr index_t M2 = warp_size / K1; - constexpr index_t M1 = warp_num / Problem::NumWaveGroups; - constexpr index_t M0 = MNPerBlock / (M1 * M2); - - static_assert(M0 * M1 * M2 == MNPerBlock, "Wrong!"); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 1>>, - sequence<1, 2, 2>, - sequence<0, 0, 2>>{}); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() - { - if constexpr(UseXorSwizzle) - { - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPack = GetXorSwizzleKPackA(); - return MakeXorSwizzleABDramTileDistribution(); - } - else - { - return UniversalGemmBasePolicy:: - template MakeADramTileDistribution(); - } - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() - { - if constexpr(UseXorSwizzle) - { - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPack = GetXorSwizzleKPackB(); - return MakeXorSwizzleABDramTileDistribution(); - } - else - { - return UniversalGemmBasePolicy:: - template MakeBDramTileDistribution(); - } - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeXorSwizzledABLdsBlockDescriptor() - { - using BlockGemmShape = typename Problem::BlockGemmShape; - using BlockWarps = typename BlockGemmShape::BlockWarps; - using WarpTile = typename BlockGemmShape::WarpTile; - - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t KPerBlock = BlockGemmShape::kK; - constexpr index_t KWarps = BlockWarps::at(I2); - constexpr index_t K1 
= WarpTile::at(I2) / K2; - constexpr index_t K0 = KPerBlock / (KWarps * K1 * K2); - - constexpr index_t warp_size = get_warp_size(); - constexpr index_t warp_num = BlockSize / warp_size; - constexpr index_t wg_attr_num_access_v = static_cast(WGAttrNumAccess); - - static_assert(warp_num * warp_size == BlockSize, "Wrong!"); - static_assert(KWarps * K0 * K1 * K2 == KPerBlock, "Wrong!"); - static_assert(KWarps == 1, "MX XOR swizzle currently supports KWarps == 1"); - static_assert(wg_attr_num_access_v == 1 || wg_attr_num_access_v == 2, - "MX XOR swizzle currently supports FP8, BF8, FP4"); - - constexpr index_t K2Pad = K2 < 16 ? 16 : K2; - constexpr index_t M3 = 4; - constexpr index_t M2 = warp_size / K1 / M3; - constexpr index_t M1 = WarpTileMN / (M2 * M3); - constexpr index_t M0 = MNPerBlock / (M1 * M2 * M3); - - static_assert(M0 * M1 * M2 * M3 == MNPerBlock, "Wrong!"); - - constexpr index_t PadSize = 4 * K2; - - constexpr auto desc_0 = make_naive_tensor_descriptor( - number_tuple{}, - number_tuple{}, - number{}, - number<1>{}); - - constexpr auto desc_1 = transform_tensor_descriptor( - desc_0, - make_tuple(make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_xor_transform(make_tuple(number{}, number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<4, 5>{}, - sequence<6>{}), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<4, 5>{}, - sequence<6>{})); - - constexpr auto desc_2 = transform_tensor_descriptor( - desc_1, - make_tuple(make_merge_transform_v3_division_mod(number_tuple{}), - make_merge_transform_v3_division_mod(number_tuple{})), - make_tuple(sequence<0, 2, 3, 4>{}, sequence<1, 5, 6>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return desc_2; - } - - // MX scaling configuration: each e8m0 scale covers 32 elements 
in K - static constexpr int ScaleGranularityK = 32; - - template > - CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() - { - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - if constexpr(is_a_load_tr) - { - // TODO: better LDS descriptor for performance - constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( // - make_tuple(number{}, number{}), - make_tuple(number{}, number<1>{}), - number{}, - number<1>{}); - return a_lds_block_desc_0; - } - else - { - if constexpr(UseXorSwizzle) - { - using WarpTile = typename Problem::BlockGemmShape::WarpTile; - constexpr index_t KPack = GetXorSwizzleKPackA(); - constexpr auto desc = - MakeXorSwizzledABLdsBlockDescriptor()>(); - static_assert(desc.get_element_space_size() >= MPerBlock * KPerBlock, - "XOR swizzle LDS allocation must cover the A tile"); - return desc; - } - else - { - constexpr index_t KPack = GetSmemPackA(); - - constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - return transform_tensor_descriptor( - a_lds_block_desc_0, - make_tuple(make_pass_through_transform(number{}), - make_merge_transform( - make_tuple(number{}, number{}))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - } - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() - { - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - if constexpr(is_b_load_tr) - { - // TODO: better LDS descriptor for performance - constexpr auto b_lds_block_desc_0 = - make_naive_tensor_descriptor(make_tuple(number{}, number{}), - make_tuple(number{}, number<1>{}), - number{}, - number<1>{}); - return b_lds_block_desc_0; - } - else - { - if constexpr(UseXorSwizzle) - { - using WarpTile = 
typename Problem::BlockGemmShape::WarpTile; - constexpr index_t KPack = GetXorSwizzleKPackB(); - constexpr auto desc = - MakeXorSwizzledABLdsBlockDescriptor()>(); - static_assert(desc.get_element_space_size() >= NPerBlock * KPerBlock, - "XOR swizzle LDS allocation must cover the B tile"); - return desc; - } - else - { - constexpr index_t KPack = GetSmemPackB(); - - constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - return transform_tensor_descriptor( - b_lds_block_desc_0, - make_tuple(make_pass_through_transform(number{}), - make_merge_transform( - make_tuple(number{}, number{}))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - } - } - - // MX GEMM: Double access for FP8/BF8, Single for FP4 - template - CK_TILE_HOST_DEVICE static constexpr auto CalculateWGAttrNumAccess() - { - using DataType = remove_cvref_t; - - if constexpr(std::is_same_v || std::is_same_v) - { - return WGAttrNumAccessEnum::Double; - } - else if constexpr(std::is_same_v) - { - return WGAttrNumAccessEnum::Single; - } - else - { - static_assert(sizeof(DataType) == 0, - "CalculateWGAttrNumAccess(): unsupported data type"); - return WGAttrNumAccessEnum::Invalid; - } - } - - // Get number of accesses - template - CK_TILE_HOST_DEVICE static constexpr auto GetWGAttrNumAccess() - { - constexpr auto num_access_a = CalculateWGAttrNumAccess(); - constexpr auto num_access_b = CalculateWGAttrNumAccess(); - - if constexpr(static_cast(num_access_a) >= static_cast(num_access_b)) - { - return num_access_a; - } - else - { - return num_access_b; - } - } - - template - CK_TILE_DEVICE static constexpr auto MakeAsyncLoadABDramWindow(const Window& window) - { - using BlockGemmShape = typename Problem::BlockGemmShape; - using BlockWarps = typename BlockGemmShape::BlockWarps; - using WarpTile = typename BlockGemmShape::WarpTile; - - 
constexpr auto ndims = std::decay_t::get_num_of_dimension(); - static_assert(ndims == 2, "only support 2D tensor"); - - constexpr index_t KPerBlock = BlockGemmShape::kK; - constexpr index_t KWarps = BlockWarps::at(I2); - constexpr index_t K1 = WarpTile::at(I2) / K2; - - static_assert(K1 * K2 == WarpTile::at(I2), "Wrong!"); - static_assert(KPerBlock % (KWarps * K1 * K2) == 0, "Wrong!"); - - constexpr index_t wg_attr_num_access_v = static_cast(WGAttrNumAccess); - - constexpr index_t M4 = 4; // same as MakeXorSwizzledABLdsBlockDescriptor::M3 - static_assert(get_warp_size() % (wg_attr_num_access_v * K1 * M4) == 0, - "warp_size must be divisible by (wg_attr_num_access_v * K1 * M4)"); - - auto&& tensor_view = window.get_bottom_tensor_view(); - const auto [rows, cols] = tensor_view.get_tensor_descriptor().get_lengths(); - - const index_t k_tiles = cols / (KWarps * K1 * K2); - const auto col_lens = make_tuple(k_tiles, number{}, number{}, number{}); - - const index_t M0 = integer_divide_ceil(rows, M4); - const auto row_lens = make_tuple(M0, number{}); - - const auto desc_0 = transform_tensor_descriptor( - tensor_view.get_tensor_descriptor(), - make_tuple(make_unmerge_transform(row_lens), make_unmerge_transform(col_lens)), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0, 1>{}, sequence<2, 3, 4, 5>{})); - - const auto desc_1 = transform_tensor_descriptor( - desc_0, - make_tuple(make_pass_through_transform(M0), - make_xor_transform(make_tuple(number{}, number{})), - make_pass_through_transform(k_tiles), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), - make_tuple( - sequence<0>{}, sequence<1, 4>{}, sequence<2>{}, sequence<3>{}, sequence<5>{}), - make_tuple( - sequence<0>{}, sequence<1, 4>{}, sequence<2>{}, sequence<3>{}, sequence<5>{})); - - const auto desc = - transform_tensor_descriptor(desc_1, - make_tuple(make_merge_transform_v3_division_mod(row_lens), - make_merge_transform_v3_division_mod(col_lens)), - 
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4, 5>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return make_tile_window( - make_tensor_view(&tensor_view.get_buffer_view()(0), desc), - window.get_window_lengths(), - window.get_window_origin()); - } - - template - CK_TILE_DEVICE static constexpr auto MakeAsyncLoadADramWindow(const Window& window) - { - if constexpr(UseXorSwizzle) - { - constexpr index_t KPack = GetXorSwizzleKPackA(); - return MakeAsyncLoadABDramWindow()>(window); - } - else - { - return make_tile_window(window.get_bottom_tensor_view(), - window.get_window_lengths(), - window.get_window_origin()); - } - } - - template - CK_TILE_DEVICE static constexpr auto MakeAsyncLoadBDramWindow(const Window& window) - { - if constexpr(UseXorSwizzle) - { - constexpr index_t KPack = GetXorSwizzleKPackB(); - return MakeAsyncLoadABDramWindow()>(window); - } - else - { - return make_tile_window(window.get_bottom_tensor_view(), - window.get_window_lengths(), - window.get_window_origin()); - } - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() - { - using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; - using WarpTile = typename Problem::BlockGemmShape::WarpTile; - - using ADataType = typename Problem::ADataType; - using BDataType = typename Problem::BDataType; - using CDataType = typename Problem::CDataType; - - // FP4 and FP8 require different layouts for the scaled mfma instructions - constexpr auto wg_attr_num_access = GetWGAttrNumAccess(); - - using WarpGemm = WarpGemmDispatcher; - - using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy; - - return BlockMXGemmARegBRegCRegV1{}; - } - - // XdlPack: how many e8m0_t scale values are packed into one int32_t per dimension - // Host packs MXdlPack * KXdlPack e8m0_t into one int32_t for A scales - // Host packs NXdlPack * KXdlPack e8m0_t into one int32_t for B scales - static constexpr int MXdlPack = 2; - static constexpr int NXdlPack = 2; - static constexpr int KXdlPack = 2; 
- - // MX Scale tile distributions for loading pre-packed int32_t from global memory - // Packed layout: [M/MXdlPack, K/32/KXdlPack] of int32_t - template - CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution() - { - using BlockGemmShape = typename Problem::BlockGemmShape; - using BlockWarps = typename BlockGemmShape::BlockWarps; - using WarpTile = typename BlockGemmShape::WarpTile; - - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t MWarp = BlockWarps::at(number<0>{}); - constexpr index_t NWarp = BlockWarps::at(number<1>{}); - constexpr index_t MPerXdl = WarpTile::at(number<0>{}); - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t K_Lane = get_warp_size() / MPerXdl; - constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl); - constexpr index_t KPerXdl = WarpTile::at(number<2>{}); - constexpr index_t KIterPerWarp = KPerBlock / KPerXdl; - constexpr index_t KPerLane = KPerXdl / ScaleGranularityK / K_Lane; - - // Effective pack sizes: fall back to 1 when iteration count < pack size - constexpr index_t MXdlPackEff = - (MIterPerWarp >= MXdlPack && MIterPerWarp % MXdlPack == 0) ? MXdlPack : 1; - constexpr index_t KXdlPackEff = - (KIterPerWarp >= KXdlPack && KIterPerWarp % KXdlPack == 0) ? 
KXdlPack : 1; - - constexpr index_t MIterPerWarp_packed = MIterPerWarp / MXdlPackEff; - constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff; - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, - sequence>, - tuple, sequence<2, 1>>, - tuple, sequence<1, 2>>, - sequence<2, 1, 2>, - sequence<0, 1, 2>>{}); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_DramTileDistribution() - { - using BlockGemmShape = typename Problem::BlockGemmShape; - using BlockWarps = typename BlockGemmShape::BlockWarps; - using WarpTile = typename BlockGemmShape::WarpTile; - - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t MWarp = BlockWarps::at(number<0>{}); - constexpr index_t NWarp = BlockWarps::at(number<1>{}); - constexpr index_t NPerXdl = WarpTile::at(number<1>{}); - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t K_Lane = get_warp_size() / NPerXdl; - constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl); - - constexpr index_t KPerXdl = WarpTile::at(number<2>{}); - constexpr index_t KIterPerWarp = KPerBlock / KPerXdl; - constexpr index_t KPerLane = KPerXdl / ScaleGranularityK / K_Lane; - - // Effective pack sizes: fall back to 1 when iteration count < pack size - constexpr index_t NXdlPackEff = - (NIterPerWarp >= NXdlPack && NIterPerWarp % NXdlPack == 0) ? NXdlPack : 1; - constexpr index_t KXdlPackEff = - (KIterPerWarp >= KXdlPack && KIterPerWarp % KXdlPack == 0) ? 
KXdlPack : 1; - - constexpr index_t NIterPerWarp_packed = NIterPerWarp / NXdlPackEff; - constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff; - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, - sequence>, - tuple, sequence<2, 1>>, - tuple, sequence<1, 2>>, - sequence<2, 1, 2>, - sequence<0, 1, 2>>{}); - } -}; -} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp deleted file mode 100644 index 3b25d6091a..0000000000 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp +++ /dev/null @@ -1,282 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once -#include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_eight_waves_base.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" -#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp" - -namespace ck_tile { - -/** - * @brief Compute optimized pipeline version async for 8 waves - * - * This pipeline introduces asynchronous load from global memory to LDS, - * skipping the intermediate loading into pipeline registers. 
- */ -template -struct MXGemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrCompV3 -{ - using Base = BaseGemmPipelineAgBgCrCompV3; - using PipelineImplBase = GemmPipelineAgBgCrEightWavesImplBase; - - using AsDataType = remove_cvref_t; - using BsDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - - using AsLayout = remove_cvref_t; - using BsLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - - using AElementWise = remove_cvref_t; - using BElementWise = remove_cvref_t; - - using ALayout = remove_cvref_t>; - using BLayout = remove_cvref_t>; - - using ADataType = remove_cvref_t>; - using BDataType = remove_cvref_t>; - - static_assert(!std::is_same_v, "Not implemented"); - - static constexpr index_t APackedSize = ck_tile::numeric_traits::PackedSize; - static constexpr index_t BPackedSize = ck_tile::numeric_traits::PackedSize; - - using BlockGemm = remove_cvref_t())>; - using WarpGemm = typename BlockGemm::WarpGemm; - - static constexpr auto I0 = number<0>{}; - static constexpr auto I1 = number<1>{}; - static constexpr auto I2 = number<2>{}; - - static constexpr index_t BlockSize = Problem::kBlockSize; - - static constexpr index_t MPerBlock = BlockGemmShape::kM; - static constexpr index_t NPerBlock = BlockGemmShape::kN; - static constexpr index_t KPerBlock = BlockGemmShape::kK; - - static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(I0); - static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(I1); - static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(I2); - - static constexpr index_t kflatKPerBlock = BlockGemmShape::flatKPerBlock; - - static constexpr index_t MIterPerWarp = MPerBlock / (MWarps * WarpGemm::kM); - static constexpr index_t NIterPerWarp = NPerBlock / (NWarps * WarpGemm::kN); - static constexpr index_t KIterPerWarp = KPerBlock / (KWarps * WarpGemm::kK); - - static constexpr bool Async = true; - - template - static constexpr index_t 
GetVectorSizeA() - { - return Policy::template GetVectorSizeA(); - } - template - static constexpr index_t GetVectorSizeB() - { - return Policy::template GetVectorSizeB(); - } - - static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } - static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } - - static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; - static constexpr index_t Preshuffle = Problem::Preshuffle; - - static constexpr bool kPadM = Problem::kPadM; - static constexpr bool kPadN = Problem::kPadN; - static constexpr bool kPadK = Problem::kPadK; - - static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; - - static constexpr auto Scheduler = Problem::Scheduler; - - [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() - { - // clang-format off - return "COMPUTE_ASYNC_EIGHT_WAVES"; - // clang-format on - } - - [[nodiscard]] CK_TILE_HOST static const std::string GetName() - { - // clang-format off - constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0); - constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1); - return concat('_', "pipeline_AgBgCrCompAsyncEightWaves", - concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize, - concat('x', GetVectorSizeA(), GetVectorSizeB()), - concat('x', WaveNumM, WaveNumN), - concat('x', kPadM, kPadN, kPadK), - Problem::GetName()); - // clang-format on - } - - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() - { - return Policy::template GetSmemSize(); - } - - static constexpr index_t MFMA_INST = MIterPerWarp * NIterPerWarp * KIterPerWarp; - - // Scales are packed so odd numbers of iterations greater than 1 are not supported - static_assert((MIterPerWarp == 1) || (MIterPerWarp % 2 == 0)); - static_assert((NIterPerWarp == 1) || (NIterPerWarp % 2 == 0)); - static_assert((KIterPerWarp == 1) || (KIterPerWarp % 2 == 0)); - - template - struct PipelineImpl : public PipelineImplBase - { - }; - - template <> - 
struct PipelineImpl : public PipelineImplBase - { - using Base = PipelineImplBase; - - template ::value && - !is_detected::value, - bool>* = nullptr> - CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BsDramBlockWindowTmp& b_dram_block_window_tmp, - const BElementFunction& b_element_func, - const ScaleADramBlockWindowTmp& scale_a_window, - const ScaleBDramBlockWindowTmp& scale_b_window, - index_t num_loop, - void* __restrict__ p_smem) const - { - // TODO: A/B elementwise functions currently not supported - ignore = a_element_func; - ignore = b_element_func; - - // ------ - // Checks - // ------ - static_assert( - std::is_same_v> && - std::is_same_v>, - "A/B Dram block window should have the same data type as appropriate " - "([A|B]DataType) defined in Problem definition!"); - - static_assert(std::is_same_v, "Wrong!"); - static_assert(std::is_same_v, "Wrong!"); - - static_assert((MPerBlock == AsDramBlockWindowTmp{}.get_window_lengths()[I0] && - KPerBlock == AsDramBlockWindowTmp{}.get_window_lengths()[I1]), - "A block window has incorrect lengths for defined ALayout!"); - static_assert(Preshuffle // - ? 
(NWarps == BsDramBlockWindowTmp{}.get_window_lengths()[I0] && - kflatKPerBlock == BsDramBlockWindowTmp{}.get_window_lengths()[I1]) - : (NPerBlock == BsDramBlockWindowTmp{}.get_window_lengths()[I0] && - KPerBlock == BsDramBlockWindowTmp{}.get_window_lengths()[I1]), - "B block window has incorrect lengths for defined BLayout!"); - - // ------------------ - // Hot loop scheduler - // ------------------ - auto hot_loop_scheduler = [&]() { - __builtin_amdgcn_sched_group_barrier(0x008, MIterPerWarp, 0); // MFMA - s_waitcnt_lgkm<4>(); - __builtin_amdgcn_sched_group_barrier(0x004, 1, 0); // lgkmcnt / SALU - static_for<0, MFMA_INST - MIterPerWarp, 1>{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - __builtin_amdgcn_sched_barrier(0); - }; - - // ------- - // Compute - // ------- - return Base::template Run_(p_smem, - num_loop, - a_dram_block_window_tmp, - b_dram_block_window_tmp, - scale_a_window, - scale_b_window, - hot_loop_scheduler); - } - }; - - template ::value && - !is_detected::value, - bool>* = nullptr> - CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BsDramBlockWindowTmp& b_dram_block_window_tmp, - const BElementFunction& b_element_func, - const ScaleADramBlockWindowTmp& scale_a_window, - const ScaleBDramBlockWindowTmp& scale_b_window, - index_t num_loop, - void* p_smem) const - { - const bool has_hot_loop = Base::BlockHasHotloop(num_loop); - const auto tail_number = Base::GetBlockLoopTailNum(num_loop); - const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - a_element_func, - b_dram_block_window_tmp, - b_element_func, - scale_a_window, - scale_b_window, - num_loop, - p_smem); - }; - - return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); - } - - template ::value && - !is_detected::value, - bool>* = nullptr> - CK_TILE_DEVICE auto operator()(const 
AsDramBlockWindowTmp& a_dram_block_window_tmp, - const BsDramBlockWindowTmp& b_dram_block_window_tmp, - const ScaleADramBlockWindowTmp& scale_a_window, - const ScaleBDramBlockWindowTmp& scale_b_window, - index_t num_loop, - void* p_smem) const - { - const bool has_hot_loop = Base::BlockHasHotloop(num_loop); - const auto tail_number = Base::GetBlockLoopTailNum(num_loop); - const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - identity{}, - b_dram_block_window_tmp, - identity{}, - scale_a_window, - scale_b_window, - num_loop, - p_smem); - }; - - return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); - } -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp deleted file mode 100644 index 519b7afcd3..0000000000 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp +++ /dev/null @@ -1,201 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#pragma once - -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp" -#include "ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_eight_waves_v1.hpp" - -namespace ck_tile { -namespace detail { - -template -struct MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy -{ - static constexpr auto I0 = number<0>{}; - static constexpr auto I1 = number<1>{}; - static constexpr auto I2 = number<2>{}; - - // MX scaling configuration: each e8m0 scale covers 32 elements in K - static constexpr int BlockScaleSize = 32; - - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using AComputeDataType = remove_cvref_t; - using BComputeDataType = remove_cvref_t; - using ComputeDataType = AComputeDataType; - static_assert(std::is_same_v, "Wrong!"); - static_assert(std::is_same_v, "Wrong!"); - static_assert(is_any_of::value); - static_assert(is_any_of::value); - static_assert(std::is_same_v); - static_assert(std::is_same_v); - - using BlockGemmShape = typename Problem::BlockGemmShape; - using BlockWarps = typename BlockGemmShape::BlockWarps; - using WarpTile = typename BlockGemmShape::WarpTile; - - static constexpr index_t BlockSize = Problem::kBlockSize; - static constexpr index_t MPerBlock = BlockGemmShape::kM; - static constexpr index_t NPerBlock = BlockGemmShape::kN; - static constexpr index_t KPerBlock = BlockGemmShape::kK; - static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(I0); - static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(I1); - static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(I2); - static constexpr index_t WarpTileM = WarpTile::at(I0); - static constexpr index_t WarpTileN = WarpTile::at(I1); - static constexpr index_t WarpTileK = WarpTile::at(I2); - static constexpr index_t MWarpTiles = MPerBlock / WarpTileM; - static constexpr 
index_t NWarpTiles = NPerBlock / WarpTileN; - static constexpr index_t KWarpTiles = KPerBlock / WarpTileK; - - // XdlPack: how many e8m0_t scale values are packed into one int32_t per dimension - // Host packs MXdlPack * KXdlPack e8m0_t into one int32_t for A scales - // Host packs NXdlPack * KXdlPack e8m0_t into one int32_t for B scales - static constexpr int MXdlPack = 2; - static constexpr int NXdlPack = 2; - static constexpr int KXdlPack = 2; - - // Compute effective XdlPack sizes (fall back to 1 when iter count < pack) - static constexpr index_t MPerXdl = WarpTile::at(I0); - static constexpr index_t NPerXdl = WarpTile::at(I1); - static constexpr index_t KPerXdl = WarpTile::at(I2); - static constexpr index_t MIterPerWarp = MPerBlock / (MWarps * MPerXdl); - static constexpr index_t NIterPerWarp = NPerBlock / (NWarps * NPerXdl); - static constexpr index_t KIterPerWarp = KPerBlock / KPerXdl; - - static constexpr index_t MXdlPackEff = - (MIterPerWarp >= MXdlPack && MIterPerWarp % MXdlPack == 0) ? MXdlPack : 1; - static constexpr index_t NXdlPackEff = - (NIterPerWarp >= NXdlPack && NIterPerWarp % NXdlPack == 0) ? NXdlPack : 1; - static constexpr index_t KXdlPackEff = - (KIterPerWarp >= KXdlPack && KIterPerWarp % KXdlPack == 0) ? 
KXdlPack : 1; - - static constexpr index_t KPerBlockScale = KPerBlock / BlockScaleSize / KXdlPackEff; - - static constexpr index_t KPerWarp = KPerBlock / KWarps; - static constexpr index_t NPerWarp = NPerBlock / NWarps; - static_assert(NWarps == 2, "NWarps == 2 for ping-pong!"); - - static constexpr index_t warp_size = get_warp_size(); - static constexpr index_t warp_num = BlockSize / warp_size; - static_assert(warp_size == 64, "Wrong!"); - static_assert(warp_num * warp_size == BlockSize, "Wrong!"); - - static_assert(sizeof(ADataType) == sizeof(BDataType), "Wrong!"); - static constexpr index_t ElementSize = sizeof(ADataType); - static constexpr index_t K2 = Problem::VectorLoadSize / ElementSize; // 16 - static constexpr index_t K1 = WarpTile::at(I2) / K2; // 8 - static constexpr index_t K0 = KPerWarp / (K1 * K2); - static_assert(K0 * K1 * K2 == KPerWarp, "Wrong!"); - - CK_TILE_HOST_DEVICE static constexpr auto GetKStepAQ() { return KPerBlockScale; } - CK_TILE_HOST_DEVICE static constexpr auto GetKStepBQ() { return KPerBlockScale; } - - CK_TILE_HOST_DEVICE static constexpr auto GetInstCountAQ() - { - return (MIterPerWarp / MXdlPackEff) * (KIterPerWarp / KXdlPackEff); - } - - CK_TILE_HOST_DEVICE static constexpr auto GetInstCountBQ() - { - return (NIterPerWarp / NXdlPackEff) * (KIterPerWarp / KXdlPackEff); - } - - CK_TILE_HOST_DEVICE static constexpr auto MakeAQBlockDistribution() - { - constexpr index_t K_Lane = get_warp_size() / WarpTileM; - - constexpr index_t KPerLane = WarpTileK / BlockScaleSize / K_Lane; - - constexpr index_t MIterPerWarp_packed = MIterPerWarp / MXdlPackEff; - constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff; - - return make_static_tile_distribution( - tile_distribution_encoding< - sequence, // repeat over MWarps - tuple, // M dimension (first) - sequence>, // K dimension (second) - tuple, sequence<2, 1>>, // , - tuple, sequence<1, 2>>, - sequence<2, 1, 2>, // - sequence<0, 1, 2>>{}); - } - - CK_TILE_HOST_DEVICE static 
constexpr auto MakeBQBlockDistribution() - { - constexpr index_t K_Lane = get_warp_size() / WarpTileN; - - constexpr index_t KPerLane = WarpTileK / BlockScaleSize / K_Lane; - - constexpr index_t NIterPerWarp_packed = NIterPerWarp / NXdlPackEff; - constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff; - - return make_static_tile_distribution( - tile_distribution_encoding< - sequence, // repeat over MWarps - tuple, // N dimension - // (first) - sequence>, // K dimension (second) - tuple, sequence<2, 1>>, // , - tuple, sequence<1, 3>>, - sequence<2, 1, 2>, // - sequence<0, 1, 2>>{}); - } - - CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() - { - constexpr auto wg_attr_num_access = - (std::is_same_v || std::is_same_v) - ? WGAttrNumAccessEnum::Double - : WGAttrNumAccessEnum::Single; - - using WarpGemm = WarpGemmDispatcher; - - using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy; - - return BlockMXGemmARegBRegCRegEightWavesV1{}; - } -}; -} // namespace detail - -struct MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy - : public GemmPipelineAgBgCrCompAsyncEightWavesPolicy -{ - -#define FORWARD_METHOD_(method) \ - template \ - CK_TILE_HOST_DEVICE static constexpr auto method(Args&&... 
args) \ - { \ - return detail::MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy::method( \ - std::forward(args)...); \ - } - - FORWARD_METHOD_(MakeAQBlockDistribution); - FORWARD_METHOD_(MakeBQBlockDistribution); - FORWARD_METHOD_(GetBlockGemm); - FORWARD_METHOD_(GetKStepAQ); - FORWARD_METHOD_(GetKStepBQ); - FORWARD_METHOD_(GetInstCountAQ); - FORWARD_METHOD_(GetInstCountBQ); - -#undef FORWARD_METHOD_ -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eight_waves_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eight_waves_policy.hpp index d52cb9ddc1..3723353233 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eight_waves_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eight_waves_policy.hpp @@ -115,31 +115,6 @@ struct GemmABQuantPipelineAgBgCrAsyncPolicy sequence<1, 2>, sequence<0, 1>>{}); } - - CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() - { - static_assert(Problem::BQuantGroupSize::kK % WarpTile::at(I2) == 0, - "KPerWarpGemm must be a multiple of QuantGroupSize::kK!"); - static_assert(Problem::TransposeC, "Wrong!"); - - using WarpGemm = WarpGemmDispatcher; - - using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy; - return ABQuantBlockUniversalGemmAsBsCrAsync{}; - } }; } // namespace detail @@ -165,6 +140,47 @@ struct GemmABQuantPipelineAgBgCrAsyncPolicy : public GemmPipelineAgBgCrCompAsync FORWARD_METHOD_(GetInstCountBQ); #undef FORWARD_METHOD_ + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + static_assert(Problem::BQuantGroupSize::kK % WarpTile::at(I2) == 0, + "KPerWarpGemm must be a multiple of QuantGroupSize::kK!"); + static_assert(Problem::TransposeC, "Wrong!"); + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using 
AComputeDataType = remove_cvref_t; + using BComputeDataType = remove_cvref_t; + + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + + constexpr index_t WarpTileM = WarpTile::at(I0); + constexpr index_t WarpTileN = WarpTile::at(I1); + constexpr index_t WarpTileK = WarpTile::at(I2); + + constexpr auto WGAccessDouble = WGAttrNumAccessEnum::Double; + + using WarpGemm = WarpGemmDispatcher; + + using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy; + return ABQuantBlockUniversalGemmAsBsCrAsync{}; + } }; } // namespace ck_tile diff --git a/test/ck_tile/gemm_mx/CMakeLists.txt b/test/ck_tile/gemm_mx/CMakeLists.txt index 4b6e6b795c..18316ec801 100644 --- a/test/ck_tile/gemm_mx/CMakeLists.txt +++ b/test/ck_tile/gemm_mx/CMakeLists.txt @@ -7,8 +7,16 @@ if(CK_USE_OCP_FP8) endif() if(GPU_TARGETS MATCHES "gfx95") - add_gtest_executable(test_ck_tile_mx_gemm_async test_mx_gemm_async.cpp) - target_compile_options(test_ck_tile_mx_gemm_async PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_mx_gemm_async_rcr test_mx_gemm_async_rcr.cpp) + target_compile_options(test_ck_tile_mx_gemm_async_rcr PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_mx_gemm_async_rrr test_mx_gemm_async_rrr.cpp) + target_compile_options(test_ck_tile_mx_gemm_async_rrr PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_mx_gemm_async_crr test_mx_gemm_async_crr.cpp) + target_compile_options(test_ck_tile_mx_gemm_async_crr PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_mx_gemm_async_ccr test_mx_gemm_async_ccr.cpp) + target_compile_options(test_ck_tile_mx_gemm_async_ccr PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_mx_gemm_async_rcr_large_cases test_mx_gemm_async_rcr_large_cases.cpp) + target_compile_options(test_ck_tile_mx_gemm_async_rcr_large_cases PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping ck_tile MX GEMM tests for current 
target") endif() diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_async.cpp b/test/ck_tile/gemm_mx/test_mx_gemm_async.cpp deleted file mode 100644 index 1804325013..0000000000 --- a/test/ck_tile/gemm_mx/test_mx_gemm_async.cpp +++ /dev/null @@ -1,201 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "test_mx_gemm_config.hpp" -#include "test_mx_gemm_util.hpp" - -using Row = ck_tile::tensor_layout::gemm::RowMajor; -using Col = ck_tile::tensor_layout::gemm::ColumnMajor; -using F4 = ck_tile::pk_fp4_t; -using F8 = ck_tile::fp8_t; -using B8 = ck_tile::bf8_t; - -// clang-format off -using MxTypes = ::testing::Types, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple>; -// clang-format on - -template -class TestMxGemm : public TestMxGemmUtil -{ -}; - -TYPED_TEST_SUITE(TestMxGemm, MxTypes); - -TYPED_TEST(TestMxGemm, Default) -{ - this->Run(128, 512, 256); - this->Run(256, 512, 512); - this->Run(1024, 1024, 1024); -} - -// 32x32x64 MFMA warp tile: enables all four A/B layout combinations via ds_read_tr -// transposed LDS loads. (16x16x128 stays Row/Col only above: KWarpTile=128 exceeds the -// ds_read_tr subtile limit, which disables transpose loads.) -// clang-format off -using MxTypesTranspose = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - // bf8/bf8 and mixed fp8/bf8 exercise the float8 paths newly consolidated into the generic - // 32x32x64 f8/f6/f4 dispatcher (previously distinct per-type code paths). 
- std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple>; -// clang-format on - -template -class TestMxGemmTranspose : public TestMxGemmUtil -{ -}; - -TYPED_TEST_SUITE(TestMxGemmTranspose, MxTypesTranspose); - -TYPED_TEST(TestMxGemmTranspose, BasicSizes) -{ - this->Run(128, 128, 256); - this->Run(128, 128, 512); -} - -TYPED_TEST(TestMxGemmTranspose, MultiBlockMN) -{ - this->Run(256, 128, 256); - this->Run(128, 256, 256); - this->Run(256, 256, 256); -} - -// Preshuffle split-K coverage. MxTypes already exercises the preshuffle configs on the -// non-split-K shapes (TestMxGemm.Default); this fixture pins the split-K shapes to the -// fp4/fp8 preshuffle configs. -using MxTypesPreshuffle = - ::testing::Types, - std::tuple>; - -template -class TestMxGemmPreshuffle : public TestMxGemmUtil -{ -}; - -TYPED_TEST_SUITE(TestMxGemmPreshuffle, MxTypesPreshuffle); - -// Split-K for the preshuffle pipeline: each k_id offsets the flat-B window and the -// host-preshuffled A/B scale windows into its own K slice (and accumulates via atomic-add). -// K is a multiple of K_Tile * k_batch (= 256 * k_batch); N is a multiple of 512 so the shapes -// are valid for both the fp4 (N_Tile = 512) and fp8 (N_Tile = 256) preshuffle configs. -TYPED_TEST(TestMxGemmPreshuffle, SplitK) -{ - this->Run(128, 512, 512, /*k_batch=*/2); - this->Run(128, 512, 1024, /*k_batch=*/2); - this->Run(128, 512, 1024, /*k_batch=*/4); - this->Run(256, 512, 2048, /*k_batch=*/4); -} - -// Regression coverage for the MX GEMM correctness fixes (PR #6663): num_loop == 3 hot-loop -// dispatch, split-K, and M/N padding. Shapes are pinned to fp8 x MX_GemmConfig16 (M_Tile = 64, -// N_Tile = 128, K_Tile = 256, default comp-async pipeline) so the regressions hit the intended -// code path -- e.g. K = 768 gives num_loop = K / K_Tile = 3. 
-using MxFp8Cfg16Types = ::testing::Types>; - -using MxFp8PadMNTypes = - ::testing::Types>; - -template -class TestMxGemmFp8Regression : public TestMxGemmUtil -{ -}; - -TYPED_TEST_SUITE(TestMxGemmFp8Regression, MxFp8Cfg16Types); - -// num_loop == 3 must not enter the hot loop: with K_Tile = 256, K = 768 gives num_loop = 3, -// which previously produced 5 gemm accumulations instead of 3 (deterministically wrong). -TYPED_TEST(TestMxGemmFp8Regression, HotLoopTailNumLoopThree) -{ - this->Run(64, 128, 768); - this->Run(128, 256, 768); - this->Run(256, 256, 768); -} - -// Split-K: exercises both the full_k_read and partial_k_read paths of SplitKBatchOffset together -// with the per-split scale-window K offset and the atomic-add epilogue. K is a multiple of -// K_Tile * k_batch and of WarpTile_K * k_batch (= 128 * k_batch) so every split lands on a -// packed-scale boundary. -TYPED_TEST(TestMxGemmFp8Regression, SplitK) -{ - this->Run(128, 256, 512, /*k_batch=*/2); - this->Run(128, 256, 1024, /*k_batch=*/2); - this->Run(128, 256, 1024, /*k_batch=*/4); - this->Run(256, 256, 2048, /*k_batch=*/4); -} - -// fp4 split-K (non-preshuffle). Same MX_GemmConfig16 tile shape as the fp8 regression above, so -// the K alignment requirements are identical; this verifies the packed (BPackedSize = 2) A/B -// pointer K-offset works under split-K + atomic-add for fp4. -using MxF4Cfg16Types = ::testing::Types>; - -template -class TestMxGemmFp4Regression : public TestMxGemmUtil -{ -}; - -TYPED_TEST_SUITE(TestMxGemmFp4Regression, MxF4Cfg16Types); - -TYPED_TEST(TestMxGemmFp4Regression, SplitK) -{ - this->Run(128, 256, 512, /*k_batch=*/2); - this->Run(128, 256, 1024, /*k_batch=*/2); - this->Run(128, 256, 1024, /*k_batch=*/4); - this->Run(256, 256, 2048, /*k_batch=*/4); -} - -template -class TestMxGemmFp8PadMN : public TestMxGemmUtil -{ -}; - -TYPED_TEST_SUITE(TestMxGemmFp8PadMN, MxFp8PadMNTypes); - -// M/N padding (kPadM = kPadN = true). M_Tile = 64, N_Tile = 128. 
Each of M and N must be >= its -// block tile (the CShuffleEpilogue cannot safely run a single partial tile along either -// dimension); K stays aligned because the MX async pipeline does not support K padding. -TYPED_TEST(TestMxGemmFp8PadMN, MNPaddingAligned) -{ - // Sanity: padding enabled but already-aligned M, N must not regress the normal path. - this->Run(64, 128, 256); -} - -TYPED_TEST(TestMxGemmFp8PadMN, MPadding) -{ - // M has a full tile + partial trailing tile (N aligned to N_Tile = 128). - this->Run(96, 128, 256); - this->Run(160, 128, 256); -} - -TYPED_TEST(TestMxGemmFp8PadMN, NPadding) -{ - // N has a full tile + partial trailing tile (M aligned to M_Tile = 64). - this->Run(64, 160, 256); - this->Run(64, 224, 256); -} - -TYPED_TEST(TestMxGemmFp8PadMN, MNPadding) -{ - // Both M and N unaligned (full + partial trailing tiles). - this->Run(96, 160, 256); - this->Run(160, 224, 512); -} diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_async_ccr.cpp b/test/ck_tile/gemm_mx/test_mx_gemm_async_ccr.cpp new file mode 100644 index 0000000000..0dab4b592b --- /dev/null +++ b/test/ck_tile/gemm_mx/test_mx_gemm_async_ccr.cpp @@ -0,0 +1,22 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "test_mx_gemm_pipeline_kernel_types.hpp" +#include "test_mx_gemm_pipeline_util.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileMxGemmPipelineCompAsyncCCR + : public TestCkTileMxGemmPipeline> +{ + public: + static constexpr bool check_data_type() { return true; } +}; + +#define TEST_SUITE_NAME TestCkTileMxGemmPipelineCompAsyncCCR + +TYPED_TEST_SUITE(TestCkTileMxGemmPipelineCompAsyncCCR, KernelTypesMxGemmCompAsyncCCR); + +#include "test_mx_gemm_pipeline_tr_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_async_crr.cpp b/test/ck_tile/gemm_mx/test_mx_gemm_async_crr.cpp new file mode 100644 index 0000000000..f95b286cd3 --- /dev/null +++ b/test/ck_tile/gemm_mx/test_mx_gemm_async_crr.cpp @@ -0,0 +1,22 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_mx_gemm_pipeline_kernel_types.hpp" +#include "test_mx_gemm_pipeline_util.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileMxGemmPipelineCompAsyncCRR + : public TestCkTileMxGemmPipeline> +{ + public: + static constexpr bool check_data_type() { return true; } +}; + +#define TEST_SUITE_NAME TestCkTileMxGemmPipelineCompAsyncCRR + +TYPED_TEST_SUITE(TestCkTileMxGemmPipelineCompAsyncCRR, KernelTypesMxGemmCompAsyncCRR); + +#include "test_mx_gemm_pipeline_tr_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_async_rcr.cpp b/test/ck_tile/gemm_mx/test_mx_gemm_async_rcr.cpp new file mode 100644 index 0000000000..0528295138 --- /dev/null +++ b/test/ck_tile/gemm_mx/test_mx_gemm_async_rcr.cpp @@ -0,0 +1,51 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "test_mx_gemm_pipeline_kernel_types.hpp" +#include "test_mx_gemm_pipeline_util.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileMxGemmPipelineCompAsyncRCR + : public TestCkTileMxGemmPipeline> +{ + public: + static constexpr bool check_data_type() { return true; } +}; + +#define TEST_SUITE_NAME TestCkTileMxGemmPipelineCompAsyncRCR + +TYPED_TEST_SUITE(TestCkTileMxGemmPipelineCompAsyncRCR, KernelTypesMxGemmCompAsyncRCR); + +#include "test_mx_gemm_pipeline_ut_cases.inc" + +TYPED_TEST(TEST_SUITE_NAME, MNPadding) +{ + if constexpr(TestFixture::PipelineType == MxGemmPipelineType::WeightPreshuffle || + TestFixture::PipelineType == MxGemmPipelineType::CompEightWaves) + { + return; + } + + std::vector Ms{96, 160, 224}; + std::vector Ns{96, 160, 224}; + std::vector Ks; + // K must be multiple of ScaleBlockSize (16 or 32) and K_Tile + for(auto K_count : {2, 3, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + for(int M : Ms) + { + for(int N : Ns) + { + for(int K : Ks) + { + this->template Run(M, N, K); + } + } + } +} + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_async_rcr_large_cases.cpp b/test/ck_tile/gemm_mx/test_mx_gemm_async_rcr_large_cases.cpp new file mode 100644 index 0000000000..faa7b4be8d --- /dev/null +++ b/test/ck_tile/gemm_mx/test_mx_gemm_async_rcr_large_cases.cpp @@ -0,0 +1,29 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "test_mx_gemm_pipeline_kernel_types.hpp" +#include "test_mx_gemm_pipeline_util.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileMxGemmPipelineCompAsyncRCR + : public TestCkTileMxGemmPipeline> +{ + public: + static constexpr bool check_data_type() { return true; } +}; + +#define TEST_SUITE_NAME TestCkTileMxGemmPipelineCompAsyncRCR + +TYPED_TEST_SUITE(TestCkTileMxGemmPipelineCompAsyncRCR, KernelTypesMxGemmCompAsyncRCRLargeCases); + +TYPED_TEST(TEST_SUITE_NAME, Large) +{ + int M = 6422528; + int N = 6144; + int K = 1024; + + this->RunAllGpu(M, N, K); +} + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_async_rrr.cpp b/test/ck_tile/gemm_mx/test_mx_gemm_async_rrr.cpp new file mode 100644 index 0000000000..490326a78f --- /dev/null +++ b/test/ck_tile/gemm_mx/test_mx_gemm_async_rrr.cpp @@ -0,0 +1,22 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_mx_gemm_pipeline_kernel_types.hpp" +#include "test_mx_gemm_pipeline_util.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileMxGemmPipelineCompAsyncRRR + : public TestCkTileMxGemmPipeline> +{ + public: + static constexpr bool check_data_type() { return true; } +}; + +#define TEST_SUITE_NAME TestCkTileMxGemmPipelineCompAsyncRRR + +TYPED_TEST_SUITE(TestCkTileMxGemmPipelineCompAsyncRRR, KernelTypesMxGemmCompAsyncRRR); + +#include "test_mx_gemm_pipeline_tr_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_config.hpp b/test/ck_tile/gemm_mx/test_mx_gemm_config.hpp deleted file mode 100644 index 1e49fdbe71..0000000000 --- a/test/ck_tile/gemm_mx/test_mx_gemm_config.hpp +++ /dev/null @@ -1,169 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/host/kernel_launch.hpp" -#include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/ops/gemm.hpp" -#include "ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp" - -template -struct MXGemmHostArgs : ck_tile::UniversalGemmHostArgs<1, 1, 0> -{ - using Base = ck_tile::UniversalGemmHostArgs<1, 1, 0>; - - MXGemmHostArgs(const void* a_ptr, - const void* b_ptr, - void* c_ptr_, - ck_tile::index_t k_batch_, - ck_tile::index_t M_, - ck_tile::index_t N_, - ck_tile::index_t K_, - ck_tile::index_t stride_A_, - ck_tile::index_t stride_B_, - ck_tile::index_t stride_C_, - ScaleM scale_m_, - ScaleN scale_n_) - : Base({a_ptr}, - {b_ptr}, - {}, - c_ptr_, - k_batch_, - M_, - N_, - K_, - {stride_A_}, - {stride_B_}, - {}, - stride_C_), - scale_m(scale_m_), - scale_n(scale_n_) - { - } - - ScaleM scale_m; - ScaleN scale_n; -}; - -struct MxGemmConfig -{ - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 512; - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 128; - - static constexpr bool kPadM = false; - static constexpr bool kPadN = false; - static constexpr bool kPadK = false; - - static constexpr bool TransposeC = false; - static constexpr bool UseStructuredSparsity = false; - - static constexpr int kBlockPerCu = 1; - static constexpr int TileParitionerGroupNum = 8; - static constexpr int TileParitionerM01 = 4; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - static constexpr ck_tile::index_t NumWaveGroups = 1; - static constexpr bool DoubleSmemBuffer = false; - static constexpr bool Preshuffle = false; - static constexpr 
ck_tile::index_t BContiguousItemsPerAccess = 16; - - static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; - static constexpr bool TiledMMAPermuteN = false; -}; - -struct MX_GemmConfig16 : MxGemmConfig -{ - static constexpr ck_tile::index_t M_Tile = 64; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 256; -}; - -struct MX_GemmConfigEightWaves : MxGemmConfig -{ - static constexpr ck_tile::index_t M_Warp = 4; - static constexpr ck_tile::index_t N_Warp = 2; // NWarps == 2 for ping-pong! - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128 * N_Warp; - static constexpr ck_tile::index_t K_Tile = 128 * K_Warp; - - static constexpr int kBlockPerCu = 2; -}; - -struct MXfp4_GemmConfig16_Preshuffle : MxGemmConfig -{ - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 512; - static constexpr ck_tile::index_t K_Tile = 256; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr bool Preshuffle = true; - static constexpr ck_tile::index_t BContiguousItemsPerAccess = 32; -}; - -struct MXfp4_GemmConfig16_PermuteN : MXfp4_GemmConfig16_Preshuffle -{ - static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; - static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; -}; - -struct MXfp8_GemmConfig16_Preshuffle : MxGemmConfig -{ - // For FP8 Preshuffle: - // The theoretical functional minimum is N_Tile = N_Warp * N_Warp_Tile * NXdlPack = 4*16*2 = - // 128 . For better performance, we would choose N_Repeat = 2 which would yield N_Tile = 128 * 2 - // = 256 . Note: If we use fewer waves, the minimum theoretical N_Tile can be even smaller, - // reduced to N_Tile = 32 for 1 single wave. 
- static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 256; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr bool Preshuffle = true; -}; - -struct MxGemmConfig32 : MxGemmConfig -{ - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 64; -}; - -struct MXfp4_GemmConfig32 : MxGemmConfig32 -{ - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 256; -}; - -struct MXfp8_GemmConfig32 : MxGemmConfig32 -{ - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 256; -}; - -// Variant with M/N padding enabled. Used to cover shapes where M/N are not multiples of -// the respective block tiles (MX_GemmConfig16 has M_Tile = 64, N_Tile = 128). K is still -// required to be a multiple of K_Tile -- the MX comp-async pipeline does not support K padding -// (see MXGemmKernel::IsSupportedArgument). -struct MXfp8_GemmConfig16_PadMN : MX_GemmConfig16 -{ - static constexpr bool kPadM = true; - static constexpr bool kPadN = true; -}; - -struct MXfp8_GemmConfig16_PermuteN : MXfp8_GemmConfig16_Preshuffle -{ - static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; - static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; -}; diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_instance.hpp b/test/ck_tile/gemm_mx/test_mx_gemm_instance.hpp deleted file mode 100644 index 84cea49af3..0000000000 --- a/test/ck_tile/gemm_mx/test_mx_gemm_instance.hpp +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#pragma once - -#include "ck_tile/host.hpp" -#include "ck_tile/ops/gemm_mx.hpp" -#include "test_mx_gemm_config.hpp" - -template -float mx_gemm_calc(const MXGemmHostArgs& args, const ck_tile::stream_config& s) -{ - using GemmShape = ck_tile::TileGemmShape< - ck_tile::sequence, - ck_tile::sequence, - ck_tile:: - sequence>; - - using MXGemmTraits = ck_tile::TileGemmUniversalTraits; - - using ComputeDataType = ADataType; - static_assert(sizeof(ComputeDataType) >= sizeof(BDataType), - "mixed_prec_gemm requires ADataType is a wider type than BDataType"); - - using MXPipelineProblem = ck_tile::UniversalGemmPipelineProblem; - - constexpr bool IsEightWave = - (GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp) == 8; - using MXGemmPipeline = std::conditional_t< - GemmConfig::Preshuffle, - ck_tile::MXGemmPreshufflePipelineAGmemBGmemCRegV1, - std::conditional_t, - ck_tile::MXGemmPipelineAgBgCrCompAsync>>; - - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; - - using GemmEpilogue = - std::conditional_t, // DsDataType - AccDataType, - CDataType, - ck_tile::tuple<>, // DsLayout - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - MXPipelineProblem::TransposeC, - false, // FixedVectorSize_ (Default) - 1>>, // VectorSizeC_ (Default) - ck_tile::CShuffleEpilogue, // DsDataType - AccDataType, - CDataType, - ck_tile::tuple<>, // DsLayout - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - MXPipelineProblem::TransposeC, - GemmConfig::NumWaveGroups, - false, - 1, - ck_tile::MXEpilogueTraits::BlockedXDLNPerWarp, - false, // DoubleSmemBuffer_ (Default) - 
ADataType, // AComputeDataType - BDataType, // BComputeDataType - !GemmConfig::Preshuffle>>>; // TilesPacked_ (because of - // packed scales) - - using Kernel = ck_tile::MXGemmKernel; - - auto kargs = Kernel::MakeKernelArgs(std::array{args.as_ptr}, - std::array{args.bs_ptr}, - std::array{}, - args.e_ptr, - args.k_batch, - args.M, - args.N, - args.K, - std::array{args.stride_As}, - std::array{args.stride_Bs}, - std::array{}, - args.stride_E, - args.scale_m, - args.scale_n); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error( - "MX GEMM: unsupported shape/configuration (set CK_TILE_LOGGING=1 for details)."); - } - - const auto kernel = ck_tile::make_kernel( - Kernel{}, Kernel::GridSize(kargs), Kernel::BlockSize(), 0, kargs); - - return ck_tile::launch_kernel(s, kernel); -} diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_kernel_types.hpp index 71d05e7656..ff9e42e2e4 100644 --- a/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_kernel_types.hpp @@ -10,19 +10,23 @@ #include "test_mx_gemm_pipeline_util.hpp" #include "test_mx_gemm_pipeline_prec_types.hpp" -using Row = ck_tile::tensor_layout::gemm::RowMajor; -using Col = ck_tile::tensor_layout::gemm::ColumnMajor; -using Intrawave = ck_tile::integral_constant; +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using CompTDMV1 = ck_tile::integral_constant; using CompTDMV2 = ck_tile::integral_constant; +using CompAsync = ck_tile::integral_constant; +using CompEightWaves = + ck_tile::integral_constant; +using WeightPreshuffle = + ck_tile::integral_constant; using I16 = ck_tile::number<16>; using I32 = ck_tile::number<32>; using I64 = ck_tile::number<64>; using I128 = ck_tile::number<128>; using I256 = ck_tile::number<256>; +using I512 = ck_tile::number<512>; using ClusterEnable = std::true_type; using ClusterDisable = 
std::false_type; @@ -33,48 +37,89 @@ using ClusterDisable = std::false_type; // ALayout, BLayout, CLayout, ADataType, BDataType, AScaleDataType, BScaleDataType, AccDataType, CDataType, M_BlockSize, N_BlockSize, K_BlockSize, M_TileSize, N_TileSize, Scheduler, PipelineType, ScaleBlockSize using KernelTypesMxGemmCompTDMWmma = ::testing::Types< // --- Scale32 (WarpTile=32, CompTDMV1) --- - std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32>, - std::tuple< Row, Col, Row, F4, F4, E5M3, E5M3, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32>, - std::tuple< Row, Col, Row, F4, F4, E4M3, E4M3, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32>, - std::tuple< Row, Col, Row, F8, F4, E8M0, E5M3, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32>, - std::tuple< Row, Col, Row, F4, F8, E5M3, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32>, - std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32>, - std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32>, + std::tuple< Row, Col, Row, F4, F4, E5M3, E5M3, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32>, + std::tuple< Row, Col, Row, F4, F4, E4M3, E4M3, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32>, + std::tuple< Row, Col, Row, F8, F4, E8M0, E5M3, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32>, + std::tuple< Row, Col, Row, F4, F8, E5M3, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32>, + std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32>, + std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32>, // --- Scale32 (WarpTile=32, CompTDMV2) --- - std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, 
I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I32>, - std::tuple< Row, Col, Row, F4, F4, E4M3, E4M3, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I32>, - std::tuple< Row, Col, Row, F8, F4, E8M0, E5M3, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I32>, - std::tuple< Row, Col, Row, F4, F8, E4M3, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I32>, - std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I32>, - std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I32>, - std::tuple< Row, Row, Row, F4, F4, E5M3, E5M3, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32>, - std::tuple< Col, Row, Row, F4, F8, E5M3, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I32>, + std::tuple< Row, Col, Row, F4, F4, E4M3, E4M3, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I32>, + std::tuple< Row, Col, Row, F8, F4, E8M0, E5M3, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I32>, + std::tuple< Row, Col, Row, F4, F8, E4M3, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I32>, + std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I32>, + std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I32>, + std::tuple< Row, Row, Row, F4, F4, E5M3, E5M3, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32>, + std::tuple< Col, Row, Row, F4, F8, E5M3, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32>, // --- Scale16 (WarpTile=16, CompTDMV1) --- - std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV1, I16>, - std::tuple< Row, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV1, I16>, - std::tuple< Row, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, 
I64, I64, I128, I16, I16, Intrawave, CompTDMV1, I16>, - std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV1, I16>, - std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV1, I16>, // RRR (non-RCR) layout + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV1, I16>, + std::tuple< Row, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV1, I16>, + std::tuple< Row, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV1, I16>, + std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV1, I16>, + std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV1, I16>, // RRR (non-RCR) layout // --- Scale16 (WarpTile=32, CompTDMV1) --- - std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I16>, - std::tuple< Row, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I16>, - std::tuple< Row, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I16>, - std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I16>, - std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I16>, // RRR (non-RCR) layout + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I16>, + std::tuple< Row, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I16>, + std::tuple< Row, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I16>, + std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I16>, + std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I16>, 
// RRR (non-RCR) layout // --- Scale16 (WarpTile=16, CompTDMV2) --- - std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV2, I16>, - std::tuple< Row, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV2, I16>, - std::tuple< Row, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV2, I16>, - std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV2, I16>, - std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV2, I16>, // RRR (non-RCR) layout + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV2, I16>, + std::tuple< Row, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV2, I16>, + std::tuple< Row, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV2, I16>, + std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV2, I16>, + std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV2, I16>, // RRR (non-RCR) layout // --- Scale16 (WarpTile=32, CompTDMV2) --- - std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I16>, - std::tuple< Row, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I16>, - std::tuple< Row, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I16>, - std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I16>, - std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I16>, // RRR (non-RCR) layout + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I16>, + std::tuple< Row, Col, Row, BF8, BF8, E8M0, E8M0, F32, 
F16, I64, I64, I128, I32, I32, CompTDMV2, I16>, + std::tuple< Row, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I16>, + std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I16>, + std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I16>, // RRR (non-RCR) layout // --- Scale32 cluster launch (from develop; ScaleBlockSize=I32 at idx 16, ClusterEnable at idx 17) --- - std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32, ClusterEnable>, - std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I32, ClusterEnable> + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32, std::false_type, ClusterEnable>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I32, std::false_type, ClusterEnable> >; + +using KernelTypesMxGemmCompAsyncRCR = ::testing::Types< + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I256, I16, I16, CompAsync, I32>, + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, I64, I64, I256, I16, I16, CompAsync, I32>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I128, I256, I128, I16, I16, CompEightWaves, I32>, + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, I128, I256, I128, I16, I16, CompEightWaves, I32>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I128, I256, I256, I16, I16, WeightPreshuffle, I32>, + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, I128, I512, I256, I16, I16, WeightPreshuffle, I32>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I128, I256, I256, I16, I16, WeightPreshuffle, I32, std::true_type>, + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, I128, I256, I256, I16, I16, WeightPreshuffle, I32, std::true_type>, + std::tuple< Row, Col, Row, F8, 
F8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>, + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>, + std::tuple< Row, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>, + std::tuple< Row, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32> +>; + +using KernelTypesMxGemmCompAsyncRCRLargeCases = ::testing::Types< + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, I64, I64, I256, I16, I16, CompAsync, I32> +>; + +using KernelTypesMxGemmCompAsyncRRR = ::testing::Types< + std::tuple< Row, Row, Row, F8, F8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>, + std::tuple< Row, Row, Row, F4, F4, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>, + std::tuple< Row, Row, Row, BF8, BF8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>, + std::tuple< Row, Row, Row, F8, BF8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32> +>; + +using KernelTypesMxGemmCompAsyncCRR = ::testing::Types< + std::tuple< Col, Row, Row, F8, F8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>, + std::tuple< Col, Row, Row, F4, F4, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>, + std::tuple< Col, Row, Row, BF8, BF8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>, + std::tuple< Col, Row, Row, F8, BF8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32> +>; + +using KernelTypesMxGemmCompAsyncCCR = ::testing::Types< + std::tuple< Col, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>, + std::tuple< Col, Col, Row, F4, F4, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>, + std::tuple< Col, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>, + std::tuple< Col, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, 
I32> +>; + // clang-format on diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_tr_cases.inc b/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_tr_cases.inc new file mode 100644 index 0000000000..1d5e359935 --- /dev/null +++ b/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_tr_cases.inc @@ -0,0 +1,27 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +TYPED_TEST(TEST_SUITE_NAME, Regular) +{ + std::vector Ms{128, 256}; + std::vector Ns{128, 256}; + std::vector Ks; + // K must be multiple of ScaleBlockSize (16 or 32) and K_Tile + for(auto K_count : {1, 2, 3, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + for(int M : Ms) + { + for(int N : Ns) + { + for(int K : Ks) + { + this->Run(M, N, K); + } + } + } +} diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_ut_cases.inc b/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_ut_cases.inc index f579766cf6..a4e62ad3f1 100644 --- a/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_ut_cases.inc +++ b/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_ut_cases.inc @@ -14,7 +14,7 @@ TYPED_TEST(TEST_SUITE_NAME, SingleTile) TYPED_TEST(TEST_SUITE_NAME, SmallM) { std::vector Ms{1, 2, 4, 8, 16}; - constexpr int N = 64; + constexpr int N = TestFixture::N_Tile; std::vector Ks; // K must be multiple of ScaleBlockSize (16 or 32) and K_Tile for(auto K_count : {2, 3, 4}) @@ -34,7 +34,7 @@ TYPED_TEST(TEST_SUITE_NAME, SmallM) TYPED_TEST(TEST_SUITE_NAME, MidLargeM) { std::vector Ms{32, 64, 128, 256}; - std::vector Ns{96, 128}; // 96 tests non-tile-aligned N + std::vector Ns{TestFixture::N_Tile}; std::vector Ks; for(auto K_count : {2, 3, 4}) { diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_util.hpp b/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_util.hpp index d6a6217b55..361921ce45 100644 --- a/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_util.hpp @@ -11,6 +11,7 @@ #include "ck_tile/ops/gemm.hpp" #include 
"ck_tile/ops/gemm/kernel/mx_gemm_kernel.hpp" #include "ck_tile/core/numeric/math.hpp" +#include "ck/library/utility/gpu_verification.hpp" template static constexpr inline auto is_row_major(Layout layout_) @@ -44,9 +45,9 @@ constexpr ck_tile::index_t get_k_warp_tile() #endif #else if constexpr(M_Warp_Tile == 32) - return 16; + return 64; else - return 32; + return 128; #endif } @@ -68,10 +69,64 @@ auto calculate_rtol_atol(const ck_tile::index_t K, return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } +// Deterministic per-element hash RNG for GPU data init. Returns a float in [-3, 3). +// The generic `fill_tensor_uniform_rand_fp_values` filler is NOT valid for ck_tile::pk_fp4_t +// (it converts a single float and duplicates it into both nibbles, and special-cases only the +// classic ck::f4x2_pk_t). We need two independent fp4 values per byte, so we fill directly. +// The narrow [-3,3) range keeps the fp16 GEMM output from overflowing at K up to 4096 (with the +// [0.25,1.0] scales used in RunAllGpu, worst case K*9 = 36864 < 65504). +__device__ inline float mx_fp4_fill_rand(unsigned int seed, unsigned long long idx) +{ + // splitmix64-style avalanche; deterministic given (seed, idx). + unsigned long long z = (idx + 1ULL) * 0x9E3779B97F4A7C15ULL + + static_cast(seed) * 0xD1B54A32D192ED03ULL; + z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ULL; + z = (z ^ (z >> 27)) * 0x94D049BB133111EBULL; + z ^= z >> 31; + const float u = + static_cast((z >> 40) & 0xFFFFFFULL) / static_cast(0x1000000); // [0,1) + return u * 6.0f - 3.0f; // [-3,3) +} + +// Fill a packed-fp4 buffer with two independent, deterministic random fp4 values per byte. +// `num_packed` is the number of pk_fp4_t elements (= total fp4 values / 2). 
+__global__ void +fill_pk_fp4_uniform_kernel(ck_tile::pk_fp4_t* __restrict__ ptr, long num_packed, unsigned int seed) +{ + const long idx0 = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const long nthr = static_cast(gridDim.x) * blockDim.x; + for(long i = idx0; i < num_packed; i += nthr) + { + const float lo_f = rintf(mx_fp4_fill_rand(seed, static_cast(i) * 2ULL)); + const float hi_f = + rintf(mx_fp4_fill_rand(seed, static_cast(i) * 2ULL + 1ULL)); + const auto lo = ck_tile::float_to_mxfp4(lo_f, 1.0f); + const auto hi = ck_tile::float_to_mxfp4(hi_f, 1.0f); + ptr[i] = ck_tile::pk_fp4_t::_pack(lo, hi); + } +} + +inline void fill_pk_fp4_uniform(ck_tile::pk_fp4_t* ptr, + long num_packed, + unsigned int seed, + hipStream_t stream = nullptr) +{ + constexpr int threads = 256; + constexpr long max_blocks = 65536; // grid-stride cap + const long needed = (num_packed + threads - 1) / threads; + const long blocks = needed < max_blocks ? needed : max_blocks; + fill_pk_fp4_uniform_kernel<<(blocks)), dim3(threads), 0, stream>>>( + ptr, num_packed, seed); + ck_tile::hip_check_error(hipGetLastError()); +} + enum struct MxGemmPipelineType { CompTDMV1, - CompTDMV2 + CompTDMV2, + CompAsync, + CompEightWaves, + WeightPreshuffle }; template @@ -95,88 +150,93 @@ struct MxGemmPipelineTypeSelector static constexpr auto GetName() { return "GemmPipelineAgBgCrCompTDMV2"; } }; -template +template +struct MxGemmPipelineTypeSelector +{ + using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompAsync; + using pipeline = ck_tile::GemmPipelineAgBgCrCompAsync; + + static constexpr auto GetName() { return "GemmPipelineAgBgCrCompAsync"; } +}; + +template +struct MxGemmPipelineTypeSelector +{ + using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + using pipeline = ck_tile::GemmPipelineAgBgCrCompAsyncEightWaves; + + static constexpr auto GetName() { return "GemmPipelineAgBgCrCompEightWaves"; } +}; + +template +struct MxGemmPipelineTypeSelector +{ + using base_pipeline = 
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2; + using pipeline = ck_tile::MXGemmPreshufflePipelineAGmemBGmemCRegV1; + + static constexpr auto GetName() { return "GemmPipelineAgBgCrWeightPreshuffle"; } +}; + +template struct MxGemmEpilogueTypeSelector +{ +}; + +template +struct MxGemmEpilogueTypeSelector { using epilogue = ck_tile::TdmEpilogue; }; +template +struct MxGemmEpilogueTypeSelector +{ + using epilogue = ck_tile::TdmEpilogue; +}; + +template +struct MxGemmEpilogueTypeSelector +{ + using epilogue = ck_tile::CShuffleEpilogue; +}; + +template +struct MxGemmEpilogueTypeSelector +{ + using epilogue = ck_tile::CShuffleEpilogue; +}; + +template +struct MxGemmEpilogueTypeSelector +{ + using epilogue = std::conditional_t, + ck_tile::CShuffleEpilogue>; +}; + template struct MxGemmPipelineDefaultParams { static constexpr bool PadM = false; static constexpr bool PadN = false; static constexpr bool PadK = false; - static constexpr bool Preshuffle = false; + static constexpr bool Preshuffle = PT == MxGemmPipelineType::WeightPreshuffle; }; -/// @brief Pre-shuffle scale buffer for gfx1250 wmma mx scale instruction. -/// -/// Reorganizes the scale data from row-major (MN x K) layout to the hardware-specific -/// layout expected by the gfx1250 wmma instruction. -/// -/// @tparam ScaleType Scale data type (e.g., e8m0_t) -/// @tparam ScaleBlockSize The block size for microscaling (e.g., 32) -/// @tparam KStride Whether K is the fast-moving dimension -template -void preShuffleScaleBuffer_gfx1250(const ScaleType* src, - ScaleType* dst, - ck_tile::index_t MN, - ck_tile::index_t K) +template +struct Config { - static_assert((ScaleBlockSize == 32 || ScaleBlockSize == 16) && sizeof(ScaleType) == 1, - "wrong! 
only support 8-bit scale with ScaleBlockSize=32 or 16"); - - // ScaleBlockSize == 16: the natural row-major scale layout already matches the gfx1250 - // wmma scale distribution (one e8m0 per 16 K-elements lands warp-aligned), so the - // device-side shuffle is the identity transform for all K. - if constexpr(ScaleBlockSize == 16) - { - for(ck_tile::index_t mn = 0; mn < MN; ++mn) - for(ck_tile::index_t k = 0; k < K; ++k) - { - if constexpr(KStride) - dst[mn * K + k] = src[mn * K + k]; - else - dst[mn * K + k] = src[k * MN + mn]; - } - return; - } - - constexpr ck_tile::index_t MPerXdlops = 16; - constexpr ck_tile::index_t KPerXdlops = 128; - - int MNPack = 2; - int KPack = 1; - - int MNStep = MPerXdlops; - int KStep = KPerXdlops / ScaleBlockSize; - - int K0 = K / KPack / KStep; - - for(int mn = 0; mn < MN; ++mn) - { - int iMNRepeat = mn / (MNStep * MNPack); - int tempmn = mn % (MNStep * MNPack); - - for(int k = 0; k < K; ++k) - { - int iKRepeat = k / (KStep * KPack); - int tempk = k % (KStep * KPack); - - int outputIndex = (iMNRepeat * MNPack * MNStep) * (KStep * KPack * K0) + - (iKRepeat * KStep * KPack) * (MNStep * MNPack) + - tempmn * (KStep * KPack) + tempk; - - if constexpr(KStride) - { - dst[outputIndex] = src[mn * K + k]; - } - else - dst[outputIndex] = src[k * MN + mn]; - } - } -} + static constexpr ck_tile::index_t N_Warp_Tile = N_Warp_Tile_; + static constexpr ck_tile::index_t K_Warp_Tile = K_Warp_Tile_; + static constexpr ck_tile::index_t N_Tile = N_Tile_; + static constexpr ck_tile::index_t N_Warp = N_Warp_; + static constexpr ck_tile::index_t BContiguousItemsPerAccess = + std::is_same_v ? 
32 : 16; +}; template class TestCkTileMxGemmPipeline : public ::testing::Test @@ -191,8 +251,10 @@ class TestCkTileMxGemmPipeline : public ::testing::Test using BScaleDataType = std::tuple_element_t<6, Tuple>; using AccDataType = std::tuple_element_t<7, Tuple>; using CDataType = std::tuple_element_t<8, Tuple>; - static constexpr auto Scheduler = std::tuple_element_t<14, Tuple>::value; - static constexpr auto PipelineType = std::tuple_element_t<15, Tuple>::value; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr auto PipelineType = std::tuple_element_t<14, Tuple>::value; + static constexpr bool PermuteN = + ck_tile::tuple_element_or_default_t::value; static constexpr ck_tile::index_t M_Tile = std::tuple_element_t<9, Tuple>{}; static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<10, Tuple>{}; @@ -213,17 +275,21 @@ class TestCkTileMxGemmPipeline : public ::testing::Test static constexpr bool ClusterLaunch = ck_tile::tuple_element_or_default_t::value; - static constexpr ck_tile::index_t ScaleBlockSize = std::tuple_element_t<16, Tuple>{}; + static constexpr ck_tile::index_t ScaleBlockSize = std::tuple_element_t<15, Tuple>{}; + + static constexpr ck_tile::index_t M_Warp = + PipelineType == MxGemmPipelineType::WeightPreshuffle + ? 1 + : (PipelineType == MxGemmPipelineType::CompEightWaves ? 4 : 2); + static constexpr ck_tile::index_t N_Warp = + PipelineType == MxGemmPipelineType::WeightPreshuffle ? 
4 : 2; + static constexpr ck_tile::index_t K_Warp = 1; protected: template void invoke_mx_gemm(const ck_tile::MxGemmHostArgs<1, 1, 0>& args, const ck_tile::stream_config& s) { - constexpr ck_tile::index_t M_Warp = 2; - constexpr ck_tile::index_t N_Warp = 2; - constexpr ck_tile::index_t K_Warp = 1; - // if cluster launch is enabled, set cluster dim to 2x2x1 constexpr ck_tile::index_t kClusterSizeM = std::conditional_t, ck_tile::number<1>>{}; @@ -240,11 +306,13 @@ class TestCkTileMxGemmPipeline : public ::testing::Test constexpr bool DoubleSmemBuffer = true; // TDM pipeline requires double smem buffer #if defined(CK_USE_GFX1250) + constexpr ck_tile::index_t BlockedXDLNPerWarp = 1; constexpr bool TransposeC = std::is_same_v && M_Warp_Tile == N_Warp_Tile; -#else - constexpr bool TransposeC = false; +#elif defined(CK_USE_GFX950) + constexpr ck_tile::index_t BlockedXDLNPerWarp = Preshuffle ? 2 : 1; + constexpr bool TransposeC = false; #endif static constexpr bool StructuredSparsity = false; static constexpr bool NumWaveGroup = 1; @@ -302,8 +370,26 @@ class TestCkTileMxGemmPipeline : public ::testing::Test using GemmPipeline = typename MxGemmPipelineTypeSelector::pipeline; - using GemmEpilogue = typename MxGemmEpilogueTypeSelector< - PipelineType, + using GemmEpilogueProblem = std::conditional_t< + PipelineType == MxGemmPipelineType::WeightPreshuffle && PermuteN, + ck_tile::PermuteNEpilogueProblem, /*VectorSizeC_*/ ck_tile::CShuffleEpilogueProblem>::epilogue; + 1, /*kNumWaveGroups_*/ + false, /*FixedVectorSize_*/ + 1, /*VectorSizeC_*/ + BlockedXDLNPerWarp, /*BlockedXDLN_PerWarp_*/ + DoubleSmemBuffer, /*DoubleSmemBuffer*/ + AComputeDataType, /*AComputeDataType_*/ + BComputeDataType, /*BComputeDataType_*/ + !preshuffle>>; + + using GemmEpilogue = typename MxGemmEpilogueTypeSelector::epilogue; using Kernel = ck_tile::MxGemmKernel; auto kargs = Kernel::MakeKernelArgs(args); @@ -360,12 +451,28 @@ class TestCkTileMxGemmPipeline : public ::testing::Test } public: + std::vector 
k_batches_; + void SetUp() override { if constexpr(!Derived::check_data_type()) { GTEST_SKIP() << "Unsupported data type combination for mx_gemm pipeline test."; } + // for TDM it's used tdm_epilogue which don't support split-k + if constexpr(PipelineType == MxGemmPipelineType::CompTDMV1 || + PipelineType == MxGemmPipelineType::CompTDMV2 || + std::is_same_v || + std::is_same_v) + { + // Only do k_batch = 1 + k_batches_ = {1}; + } + else + { + // Otherwise, use k_batch = 1 and 2 + k_batches_ = {1, 2}; + } } template ::PadM, @@ -381,7 +488,15 @@ class TestCkTileMxGemmPipeline : public ::testing::Test { if constexpr(Derived::check_data_type()) { - RunSingle(M, N, K, StrideA, StrideB, StrideC, 1); + for(auto kb : k_batches_) + { + // skip test when split k' number is not evenly distributed + if((K / K_Tile) % kb != 0) + { + continue; + } + RunSingle(M, N, K, StrideA, StrideB, StrideC, kb); + } } } @@ -422,16 +537,18 @@ class TestCkTileMxGemmPipeline : public ::testing::Test // so M must be padded to at least MNPack * MPerXdlops = 32. constexpr index_t ScaleShuffleAlign = 32; const index_t scale_padded_M = integer_least_multiple( - static_cast(M), - static_cast(ck_tile::max(M_Warp_Tile, ScaleShuffleAlign))); + static_cast(M), static_cast(ck_tile::max(M_Tile, ScaleShuffleAlign))); HostTensor scale_a( {static_cast(scale_padded_M), static_cast(num_scale_k)}, {static_cast(num_scale_k), static_cast(1)}); - // scale_b uses N as first dimension (col-major like B) + const index_t scale_padded_N = integer_least_multiple( + static_cast(N), static_cast(ck_tile::max(N_Tile, ScaleShuffleAlign))); + // Pre-shuffle interleaves 2 K-lanes (MNPack=2) with MPerXdlops=16 stride, + // so N must be padded to at least MNPack * NPerXdlops = 32. 
HostTensor scale_b( - {static_cast(N), static_cast(num_scale_k)}, + {static_cast(scale_padded_N), static_cast(num_scale_k)}, {static_cast(num_scale_k), static_cast(1)}); // Fill data @@ -485,38 +602,112 @@ class TestCkTileMxGemmPipeline : public ::testing::Test } // Pre-shuffle scale buffers for the hardware +#if defined(CK_USE_GFX1250) + static constexpr index_t NXdlPackEff = 1; + HostTensor scale_a_shuffled( {static_cast(scale_padded_M), static_cast(num_scale_k)}, {static_cast(num_scale_k), static_cast(1)}); HostTensor scale_b_shuffled( - {static_cast(N), static_cast(num_scale_k)}, + {static_cast(scale_padded_N), static_cast(num_scale_k)}, {static_cast(num_scale_k), static_cast(1)}); // Pre-shuffle for gfx1250 (WaveSize=32, WMMA) // Scales start in natural tensor layout and are pre-shuffled into the device layout // for both scale block sizes (the shuffle is the identity for ScaleBlockSize==16, // whose natural layout already matches the warp scale distribution). - preShuffleScaleBuffer_gfx1250( + ck_tile::preShuffleScaleBuffer_gfx1250( scale_a.mData.data(), scale_a_shuffled.mData.data(), scale_padded_M, num_scale_k); - preShuffleScaleBuffer_gfx1250( - scale_b.mData.data(), scale_b_shuffled.mData.data(), N, num_scale_k); + ck_tile::preShuffleScaleBuffer_gfx1250( + scale_b.mData.data(), scale_b_shuffled.mData.data(), scale_padded_N, num_scale_k); +#elif defined(CK_USE_GFX950) + constexpr ck_tile::index_t MPerXdl = M_Warp_Tile; + constexpr ck_tile::index_t NPerXdl = N_Warp_Tile; + constexpr ck_tile::index_t KPerXdl = K_Warp_Tile; + constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * MPerXdl); + constexpr ck_tile::index_t NIterPerWarp = N_Tile / (N_Warp * NPerXdl); + constexpr ck_tile::index_t KIterPerWarp = K_Tile / KPerXdl; + + constexpr ck_tile::index_t MXdlPackEff = + (MIterPerWarp >= 2 && MIterPerWarp % 2 == 0) ? 2 : 1; + constexpr ck_tile::index_t NXdlPackEff = + (NIterPerWarp >= 2 && NIterPerWarp % 2 == 0) ? 
2 : 1; + constexpr ck_tile::index_t KXdlPackEff = + (KIterPerWarp >= 2 && KIterPerWarp % 2 == 0) ? 2 : 1; + + constexpr ck_tile::index_t XdlMNThread = M_Warp_Tile; + constexpr ck_tile::index_t XdlKThread = 64 / XdlMNThread; + + HostTensor scale_a_shuffled( + {static_cast(scale_padded_M / MXdlPackEff * 2), + static_cast(num_scale_k / KXdlPackEff * 2)}, + {static_cast(num_scale_k / KXdlPackEff * 2), static_cast(1)}); + + HostTensor scale_b_shuffled( + {static_cast(scale_padded_N / NXdlPackEff * 2), + static_cast(num_scale_k / KXdlPackEff * 2)}, + {static_cast(num_scale_k / KXdlPackEff * 2), static_cast(1)}); + + ck_tile::preShuffleScaleBuffer_gfx950( + scale_a.mData.data(), scale_a_shuffled.mData.data(), scale_padded_M, num_scale_k, true); + + if constexpr(PipelineType == MxGemmPipelineType::WeightPreshuffle && PermuteN) + { + ck_tile::preShuffleScaleBufferPermuteN_gfx950( + scale_b.mData.data(), + scale_b_shuffled.mData.data(), + scale_padded_N, + num_scale_k, + true); + } + else + { + ck_tile:: + preShuffleScaleBuffer_gfx950( + scale_b.mData.data(), + scale_b_shuffled.mData.data(), + scale_padded_N, + num_scale_k, + true); + } +#endif // Allocate device memory DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); - DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); DeviceMem scale_a_dev_buf(scale_a_shuffled.get_element_space_size_in_bytes()); DeviceMem scale_b_dev_buf(scale_b_shuffled.get_element_space_size_in_bytes()); // Upload data to device a_m_k_dev_buf.ToDevice(a_m_k.data()); - b_k_n_dev_buf.ToDevice(b_k_n.data()); c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); scale_a_dev_buf.ToDevice(scale_a_shuffled.data()); scale_b_dev_buf.ToDevice(scale_b_shuffled.data()); + using GemmConfig = Config; + + const auto b_host_for_dev = [&]() { + if constexpr(Preshuffle) + { + if constexpr(PermuteN) + { + return ck_tile::shuffle_b_permuteN(b_k_n); + } + else + { + 
return ck_tile::shuffle_b(b_k_n); + } + } + else + { + return b_k_n; + } + }(); + DeviceMem b_k_n_dev_buf(b_host_for_dev.get_element_space_size_in_bytes()); + b_k_n_dev_buf.ToDevice(b_host_for_dev.data()); + // Create MxGemmHostArgs ck_tile::MxGemmHostArgs<1, 1, 0> args( {static_cast(a_m_k_dev_buf.GetDeviceBuffer())}, @@ -546,7 +737,14 @@ class TestCkTileMxGemmPipeline : public ::testing::Test {static_cast(1), static_cast(num_scale_k)}); // Copy scale_b data (our scale_b is (N, num_scale_k) row-major, // reference expects (num_scale_k, N) col-major, which is the same memory layout) - std::copy(scale_b.mData.begin(), scale_b.mData.end(), scale_b_ref.mData.begin()); + // Truncate scale_a to actual N (not padded) + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < num_scale_k; ++k) + { + scale_b_ref(k, n) = scale_b(n, k); + } + } // Truncate scale_a to actual M (not padded) HostTensor scale_a_ref( @@ -582,4 +780,275 @@ class TestCkTileMxGemmPipeline : public ::testing::Test rtol_atol.at(number<1>{})); EXPECT_TRUE(pass); } + + // All-GPU validation path for the fp4 (pk_fp4_t) MX GEMM. + // + // Unlike Run(), this never materializes the A/B/C tensors on the host: + // - A/B are generated directly on device with a deterministic fp4 fill. + // - the reference is computed on device by reference_mx_gemm_gpu. + // - the comparison is done on device by ck::profiler::gpu_verify. + // Only the tiny e8m0 scales touch the host (for pre-shuffle + an unshuffled copy that the + // device reference consumes). + void RunAllGpu(const int M, const int N, const int K, const int kbatch = 1) + { + if constexpr(!Derived::check_data_type()) + return; + + static_assert(std::is_same_v && + std::is_same_v, + "RunAllGpu currently supports pk_fp4_t A/B only."); + // The GPU reference (reference_mx_gemm_gpu) hardcodes these layouts; guard so it cannot be + // silently misused with a layout it does not handle. 
+ static_assert(std::is_same_v && + std::is_same_v && + std::is_same_v, + "RunAllGpu / reference_mx_gemm_gpu assume RowMajor-A, ColumnMajor-B, " + "RowMajor-C."); + + static_assert(PipelineType != MxGemmPipelineType::WeightPreshuffle); + +#if !defined(CK_USE_GFX950) + (void)M; + (void)N; + (void)K; + (void)kbatch; + GTEST_SKIP() << "RunAllGpu requires CK_USE_GFX950."; +#else + using namespace ck_tile::literals; + constexpr long kIntMax = 2147483647L; // INT_MAX + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + if constexpr(std::is_same_v) + return col; + else + return row; + } + else + return stride; + }; + + constexpr ck_tile::index_t psize = ck_tile::numeric_traits::PackedSize; // 2 + static_assert(psize == 2, + "RunAllGpu byte-sizing and reference_mx_gemm_kernel's a_ptr[a_lin/2] " + "addressing assume pk_fp4_t PackedSize == 2."); + + bool pass = true; + long total_MN = 0; + + // Strides are K/N here (small); keep them as index_t to match the kernel args, and + // make the size_t->index_t narrowing explicit. + const ck_tile::index_t stride_A = + static_cast(f_get_default_stride(M, K, 0, ALayout{})); // K + const ck_tile::index_t stride_B = + static_cast(f_get_default_stride(K, N, 0, BLayout{})); // K + const ck_tile::index_t stride_C = + static_cast(f_get_default_stride(M, N, 0, CLayout{})); // N + + ASSERT_EQ(K % ScaleBlockSize, 0) << "K must be a multiple of ScaleBlockSize for MX GEMM"; + const ck_tile::index_t num_scale_k = K / ScaleBlockSize; + ASSERT_EQ(num_scale_k % (K_Warp_Tile / ScaleBlockSize), 0) + << "K must be a multiple of K_Warp_Tile (" << K_Warp_Tile + << ") for MX GEMM. Pad the scale data."; + const ck_tile::index_t scale_padded_M = ck_tile::integer_least_multiple( + static_cast(M), static_cast(M_Tile)); + + // int32-safety: the property under test for the M-decomposition. The predicate is + // "largest 0-based element offset fits in a signed 32-bit int", i.e. 
offset <= INT_MAX. + const long MN = static_cast(M) * N; + const long A_elems = static_cast(M) * K; + const long B_elems = static_cast(K) * N; + const long C_off = static_cast(M - 1) * stride_C + (N - 1); + const long A_off = static_cast(M - 1) * stride_A + (K - 1); + const long B_off = static_cast(N - 1) * stride_B + (K - 1); + const long c_bytes = MN * static_cast(sizeof(CDataType)); + std::cout << "[int32-safety] M=" << M << " N=" << N << " K=" << K << " M*N=" << MN + << " A_elems=" << A_elems << " B_elems=" << B_elems << " C_off=" << C_off + << " A_off=" << A_off << " B_off=" << B_off << " C_bytes=" << c_bytes + << " (INT_MAX=" << kIntMax << ")" << std::endl; + // Note (not an assert): the C *byte* span can exceed INT_MAX even when the element + // count is int32-safe. We deliberately let the run proceed -- if any internal byte + // offset overflows, gpu_verify will flag it, which is exactly what we want to discover. + if(c_bytes > kIntMax) + std::cout << "[int32-safety][note] C byte span (" << c_bytes + << ") exceeds INT_MAX; if verification fails, byte-offset overflow is the " + "prime suspect." + << std::endl; + ASSERT_LE(B_off, kIntMax) << "B offset exceeds INT_MAX"; + total_MN += MN; + + const long a_bytes = (A_elems + psize - 1) / psize; + const long b_bytes = (B_elems + psize - 1) / psize; + + // Bound peak device memory (A + B + 2*C + scales). Skip cleanly rather + // than aborting via hip_check_error if the device cannot hold test shapes. 
+ { + std::size_t free_b = 0, total_b = 0; + ck_tile::hip_check_error(hipMemGetInfo(&free_b, &total_b)); + const std::size_t need = static_cast(a_bytes) + + static_cast(b_bytes) + + 2u * static_cast(c_bytes) + (64u << 20); + if(free_b < need) + GTEST_SKIP() << "insufficient device memory (need " << need << " B, free " << free_b + << " B)"; + } + + auto a_dev = std::make_unique(static_cast(a_bytes)); + auto b_dev = std::make_unique(static_cast(b_bytes)); + auto c_dev = std::make_unique(static_cast(c_bytes)); + auto c_ref_dev = std::make_unique(static_cast(c_bytes)); + c_dev->SetZero(); + c_ref_dev->SetZero(); + + // GPU fill A/B (deterministic, fp4-correct). Same device buffers feed both the kernel + // and the reference, so the fill need not bit-match any host RNG. + fill_pk_fp4_uniform( + reinterpret_cast(a_dev->GetDeviceBuffer()), a_bytes, 11939u); + fill_pk_fp4_uniform( + reinterpret_cast(b_dev->GetDeviceBuffer()), b_bytes, 11940u); + ck_tile::hip_check_error(hipDeviceSynchronize()); // surface fill faults at the fill site + + // e8m0 scales (tiny, host-built). The range is + // deliberately narrow ([0.25,1.0] scales, [-3,3) fp4 fill) so that K up to 4096 cannot + // overflow the fp16 output (worst case K*9 = 36864 < 65504); gpu_verify counts matched + // infinities as errors, so an overflow would otherwise be a false failure. + ck_tile::HostTensor scale_a( + {static_cast(scale_padded_M), static_cast(num_scale_k)}, + {static_cast(num_scale_k), static_cast(1)}); + ck_tile::HostTensor scale_b( + {static_cast(N), static_cast(num_scale_k)}, + {static_cast(num_scale_k), static_cast(1)}); + { + std::mt19937 gen(11941u); + std::uniform_real_distribution dist(0.25f, 1.0f); + for(auto& s : scale_a.mData) + s = AScaleDataType{dist(gen)}; + for(auto& s : scale_b.mData) + s = BScaleDataType{dist(gen)}; + } + + // gfx950 scale pre-shuffle. NOTE: this must stay in sync with the identical block in + // Run() -- the kernel-input layout and the reference-input layout must agree. 
+ constexpr ck_tile::index_t MPerXdl = M_Warp_Tile; + constexpr ck_tile::index_t NPerXdl = N_Warp_Tile; + constexpr ck_tile::index_t KPerXdl = K_Warp_Tile; + constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * MPerXdl); + constexpr ck_tile::index_t NIterPerWarp = N_Tile / (N_Warp * NPerXdl); + constexpr ck_tile::index_t KIterPerWarp = K_Tile / KPerXdl; + + constexpr ck_tile::index_t MXdlPackEff = + (MIterPerWarp >= 2 && MIterPerWarp % 2 == 0) ? 2 : 1; + constexpr ck_tile::index_t NXdlPackEff = + (NIterPerWarp >= 2 && NIterPerWarp % 2 == 0) ? 2 : 1; + constexpr ck_tile::index_t KXdlPackEff = + (KIterPerWarp >= 2 && KIterPerWarp % 2 == 0) ? 2 : 1; + + constexpr ck_tile::index_t XdlMNThread = M_Warp_Tile; + constexpr ck_tile::index_t XdlKThread = 64 / XdlMNThread; + + ck_tile::HostTensor scale_a_shuffled( + {static_cast(scale_padded_M / MXdlPackEff * 2), + static_cast(num_scale_k / KXdlPackEff * 2)}, + {static_cast(num_scale_k / KXdlPackEff * 2), static_cast(1)}); + ck_tile::HostTensor scale_b_shuffled( + {static_cast(N / NXdlPackEff * 2), + static_cast(num_scale_k / KXdlPackEff * 2)}, + {static_cast(num_scale_k / KXdlPackEff * 2), static_cast(1)}); + + ck_tile::preShuffleScaleBuffer_gfx950( + scale_a.mData.data(), scale_a_shuffled.mData.data(), scale_padded_M, num_scale_k, true); + ck_tile::preShuffleScaleBuffer_gfx950( + scale_b.mData.data(), scale_b_shuffled.mData.data(), N, num_scale_k, true); + + // Device scale buffers: shuffled feed the kernel, unshuffled feed the reference. 
+ auto scale_a_shuf_dev = std::make_unique( + scale_a_shuffled.get_element_space_size_in_bytes()); + auto scale_b_shuf_dev = std::make_unique( + scale_b_shuffled.get_element_space_size_in_bytes()); + scale_a_shuf_dev->ToDevice(scale_a_shuffled.data()); + scale_b_shuf_dev->ToDevice(scale_b_shuffled.data()); + + auto scale_a_ref_dev = + std::make_unique(scale_a.get_element_space_size_in_bytes()); + auto scale_b_ref_dev = + std::make_unique(scale_b.get_element_space_size_in_bytes()); + scale_a_ref_dev->ToDevice(scale_a.data()); + scale_b_ref_dev->ToDevice(scale_b.data()); + + // Launch kernel + ck_tile::MxGemmHostArgs<1, 1, 0> args( + {static_cast(a_dev->GetDeviceBuffer())}, + {static_cast(scale_a_shuf_dev->GetDeviceBuffer())}, + {static_cast(b_dev->GetDeviceBuffer())}, + {static_cast(scale_b_shuf_dev->GetDeviceBuffer())}, + {}, + c_dev->GetDeviceBuffer(), + kbatch, + M, + N, + K, + {stride_A}, + {stride_B}, + {}, + stride_C); + + invoke_mx_gemm(args, ck_tile::stream_config{nullptr, false}); + + ck_tile::hip_check_error(hipDeviceSynchronize()); + + // GPU reference on the same device A/B buffers. + ck_tile::reference_mx_gemm_gpu( + reinterpret_cast(a_dev->GetDeviceBuffer()), + reinterpret_cast(b_dev->GetDeviceBuffer()), + reinterpret_cast(scale_a_ref_dev->GetDeviceBuffer()), + reinterpret_cast(scale_b_ref_dev->GetDeviceBuffer()), + reinterpret_cast(c_ref_dev->GetDeviceBuffer()), + M, + N, + K, + num_scale_k, + ScaleBlockSize); + ck_tile::hip_check_error(hipDeviceSynchronize()); + + // GPU verify with explicit MX tolerance (auto tolerance defaults too tight for MX). + const float max_acc = ck::profiler::gpu_reduce_max(c_ref_dev->GetDeviceBuffer(), + static_cast(MN)); + // The reference must be non-degenerate, else error_count==0 is a vacuous pass. 
+ ASSERT_GT(max_acc, 0.0f) << "GPU reference output is all-zero"; + const auto rtol_atol = + calculate_rtol_atol(K, kbatch, max_acc); + const auto res = ck::profiler::gpu_verify(c_dev->GetDeviceBuffer(), + c_ref_dev->GetDeviceBuffer(), + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{}), + static_cast(MN)); + + // Positive liveness check on the *device* output. res.all_zero ANDs device- and + // reference-zeroness, and the reference is never zero here, so it cannot detect a no-op + // kernel on its own -- reduce the device buffer directly. + const float c_dev_absmax = ck::profiler::gpu_reduce_max( + c_dev->GetDeviceBuffer(), static_cast(MN)); + + std::cout << "[verify] errors=" << res.error_count << " max_error=" << res.max_error + << " c_dev_absmax=" << c_dev_absmax << " max_acc=" << max_acc + << " rtol=" << rtol_atol.at(ck_tile::number<0>{}) + << " atol=" << rtol_atol.at(ck_tile::number<1>{}) << std::endl; + + EXPECT_EQ(res.error_count, 0ull) << "produced mismatched results"; + EXPECT_GT(c_dev_absmax, 0.0f) << "produced an all-zero device output"; + pass &= (res.error_count == 0 && c_dev_absmax > 0.0f); + + std::cout << "[int32-safety] aggregate total_M*N=" << total_MN << " (INT_MAX=" << kIntMax + << ") -> decomposition is the variable under test" << std::endl; + EXPECT_TRUE(pass); +#endif // CK_USE_GFX950 + } }; diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_util.hpp b/test/ck_tile/gemm_mx/test_mx_gemm_util.hpp deleted file mode 100644 index f566ecb38b..0000000000 --- a/test/ck_tile/gemm_mx/test_mx_gemm_util.hpp +++ /dev/null @@ -1,220 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#pragma once - -#include - -#include "ck_tile/core.hpp" -#include "ck_tile/host.hpp" -#include "ck_tile/host/check_err.hpp" -#include "ck_tile/host/reference/reference_gemm.hpp" -#include "ck_tile/host/tensor_shuffle_utils.hpp" -#include "ck_tile/host/mx_processing.hpp" -#include "test_mx_gemm_config.hpp" -#include "test_mx_gemm_instance.hpp" - -template -static constexpr auto is_row_major(Layout) -{ - return ck_tile::bool_constant< - std::is_same_v, ck_tile::tensor_layout::gemm::RowMajor>>{}; -} - -template -auto calculate_rtol_atol_mx(ck_tile::index_t K, float max_accumulated_value) -{ - using ComputeType = - std::conditional_t; - const auto rtol = ck_tile::get_relative_threshold(K); - const auto atol = ck_tile::get_absolute_threshold( - max_accumulated_value, K); - return ck_tile::make_tuple(rtol, atol); -} - -template -class TestMxGemmUtil : public ::testing::Test -{ - protected: - using ADataType = std::tuple_element_t<0, Tuple>; - using BDataType = std::tuple_element_t<1, Tuple>; - using GemmConfig = std::tuple_element_t<2, Tuple>; - using ALayout = std::tuple_element_t<3, Tuple>; - using BLayout = std::tuple_element_t<4, Tuple>; - using CLayout = std::tuple_element_t<5, Tuple>; - - using AccDataType = float; - using CDataType = ck_tile::fp16_t; - using ScaleType = ck_tile::e8m0_t; - using ScaleM = ck_tile::MXScalePointer; - using ScaleN = ck_tile::MXScalePointer; - - void - Run(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, ck_tile::index_t k_batch = 1) - { - const ck_tile::index_t scale_k_size = K / 32; - const ck_tile::index_t stride_A = - ck_tile::get_default_stride(M, K, 0, is_row_major(ALayout{})); - const ck_tile::index_t stride_B = - ck_tile::get_default_stride(K, N, 0, is_row_major(BLayout{})); - const ck_tile::index_t stride_C = - ck_tile::get_default_stride(M, N, 0, is_row_major(CLayout{})); - // Scales use fixed layouts independent of A/B layout: - // scale A is row-major [M, K/32], and scale B is 
column-major [K/32, N]. - const ck_tile::index_t stride_scale_a = - ck_tile::get_default_stride(M, scale_k_size, 0, ck_tile::bool_constant{}); - const ck_tile::index_t stride_scale_b = - ck_tile::get_default_stride(scale_k_size, N, 0, ck_tile::bool_constant{}); - - ck_tile::HostTensor a_host( - ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_host( - ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(BLayout{}))); - ck_tile::HostTensor c_host( - ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); - ck_tile::HostTensor scale_a_host(ck_tile::host_tensor_descriptor( - M, scale_k_size, stride_scale_a, ck_tile::bool_constant{})); - ck_tile::HostTensor scale_b_host(ck_tile::host_tensor_descriptor( - scale_k_size, N, stride_scale_b, ck_tile::bool_constant{})); - - std::mt19937 gen(42); - std::uniform_int_distribution fill_seed(0, 500); - - auto gen_scales = [&](auto& scales, float range_min, float range_max) { - // e8m0_t is basically an exponent of float32 - ck_tile::HostTensor pow2(scales.get_lengths()); - ck_tile::FillUniformDistributionIntegerValue{ - range_min, range_max, fill_seed(gen)}(pow2); - scales.ForEach([&](auto& self, const auto& i) { - self(i) = static_cast(std::exp2(pow2(i))); - }); - }; - - ck_tile::FillUniformDistribution{-2.f, 2.f, fill_seed(gen)}(a_host); - ck_tile::FillUniformDistribution{-2.f, 2.f, fill_seed(gen)}(b_host); - gen_scales(scale_a_host, -2, 2); - gen_scales(scale_b_host, -2, 2); - - // Compute effective XdlPack sizes based on GemmConfig tile dimensions - constexpr ck_tile::index_t MPerXdl = GemmConfig::M_Warp_Tile; - constexpr ck_tile::index_t NPerXdl = GemmConfig::N_Warp_Tile; - constexpr ck_tile::index_t KPerXdl = GemmConfig::K_Warp_Tile; - constexpr ck_tile::index_t MIterPerWarp = - GemmConfig::M_Tile / (GemmConfig::M_Warp * MPerXdl); - constexpr ck_tile::index_t NIterPerWarp = - GemmConfig::N_Tile / (GemmConfig::N_Warp * NPerXdl); - constexpr 
ck_tile::index_t KIterPerWarp = GemmConfig::K_Tile / KPerXdl; - - constexpr ck_tile::index_t MXdlPackEff = - (MIterPerWarp >= 2 && MIterPerWarp % 2 == 0) ? 2 : 1; - constexpr ck_tile::index_t NXdlPackEff = - (NIterPerWarp >= 2 && NIterPerWarp % 2 == 0) ? 2 : 1; - constexpr ck_tile::index_t KXdlPackEff = - (KIterPerWarp >= 2 && KIterPerWarp % 2 == 0) ? 2 : 1; - - constexpr ck_tile::index_t XdlMNThread = GemmConfig::M_Warp_Tile; - constexpr ck_tile::index_t XdlKThread = 64 / XdlMNThread; - - // Pack scales into int32_t for GPU consumption - auto scale_a_packed = - ck_tile::packScalesMNxK(scale_a_host, - true); - auto scale_b_packed = - ck_tile::packScalesMNxK(scale_b_host, - false); - - const auto b_host_for_device = [&]() { - if constexpr(GemmConfig::Preshuffle) - if constexpr(GemmConfig::TiledMMAPermuteN) - return ck_tile::shuffle_b_permuteN(b_host); - else - return ck_tile::shuffle_b(b_host); - else - return b_host; - }(); - - const auto scale_a_host_for_device = [&]() { - if constexpr(GemmConfig::Preshuffle) - return ck_tile::preShuffleScale(scale_a_host, true); - else - return scale_a_packed; - }(); - - constexpr ck_tile::index_t XdlNThread = GemmConfig::N_Warp_Tile; - constexpr ck_tile::index_t NPerBlock = GemmConfig::N_Tile; - constexpr ck_tile::index_t NWarp = GemmConfig::N_Warp; - - const auto scale_b_host_for_device = [&]() { - if constexpr(GemmConfig::Preshuffle) - if constexpr(GemmConfig::TiledMMAPermuteN) - return ck_tile::preShuffleScalePermuteN( - scale_b_host, false); - else - return ck_tile::preShuffleScale(scale_b_host, false); - else - return scale_b_packed; - }(); - - ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem b_dev_buf(b_host_for_device.get_element_space_size_in_bytes()); - ck_tile::DeviceMem c_dev_buf(c_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem scale_a_dev_buf( - scale_a_host_for_device.get_element_space_size_in_bytes()); - ck_tile::DeviceMem scale_b_dev_buf( - 
scale_b_host_for_device.get_element_space_size_in_bytes()); - - a_dev_buf.ToDevice(a_host.data()); - b_dev_buf.ToDevice(b_host_for_device.data()); - c_dev_buf.SetZero(); - scale_a_dev_buf.ToDevice(scale_a_host_for_device.data()); - scale_b_dev_buf.ToDevice(scale_b_host_for_device.data()); - - ScaleM scale_m(reinterpret_cast(scale_a_dev_buf.GetDeviceBuffer())); - ScaleN scale_n(reinterpret_cast(scale_b_dev_buf.GetDeviceBuffer())); - - MXGemmHostArgs args(a_dev_buf.GetDeviceBuffer(), - b_dev_buf.GetDeviceBuffer(), - c_dev_buf.GetDeviceBuffer(), - k_batch, - M, - N, - K, - stride_A, - stride_B, - stride_C, - scale_m, - scale_n); - - mx_gemm_calc(args, ck_tile::stream_config{}); - - c_dev_buf.FromDevice(c_host.data()); - - ck_tile::HostTensor c_ref( - ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); - c_ref.SetZero(); - ck_tile:: - reference_mx_gemm( - a_host, b_host, c_ref, scale_a_host, scale_b_host); - - const float max_accumulated_value = ck_tile::type_convert(c_ref.max()); - const auto rtol_atol = calculate_rtol_atol_mx( - K, max_accumulated_value); - const double rtol = rtol_atol.at(ck_tile::number<0>{}); - const double atol = rtol_atol.at(ck_tile::number<1>{}); - - bool pass = ck_tile::check_err(c_host, c_ref, "MX GEMM: Incorrect results!", rtol, atol); - - EXPECT_TRUE(pass); - } -}; diff --git a/test/ck_tile/grouped_gemm_mx/CMakeLists.txt b/test/ck_tile/grouped_gemm_mx/CMakeLists.txt index 42bcb2a0ae..3cd34fe00d 100644 --- a/test/ck_tile/grouped_gemm_mx/CMakeLists.txt +++ b/test/ck_tile/grouped_gemm_mx/CMakeLists.txt @@ -11,13 +11,33 @@ endif() # Currently TDM is only supported on gfx1250 if(GPU_TARGETS MATCHES "gfx1250") - add_gtest_executable(test_ck_tile_grouped_gemm_mx_tdm test_mx_grouped_gemm.cpp) + add_gtest_executable(test_ck_tile_grouped_gemm_mx_tdm test_mx_grouped_gemm_wmma_tdm.cpp) # target_compile_options(test_ck_tile_grouped_gemm_mx_tdm PRIVATE --save-temps) add_gtest_executable(test_ck_tile_grouped_gemm_mx_flatmm_tdm 
test_grouped_gemm_mx_flatmm_tdm.cpp) target_compile_options(test_ck_tile_grouped_gemm_mx_flatmm_tdm PRIVATE ${GROUPED_MX_FLATMM_COMPILE_OPTIONS}) endif() +if(GPU_TARGETS MATCHES "gfx950") + add_gtest_executable(test_ck_tile_grouped_gemm_mx_comp_async test_mx_grouped_gemm_comp_async.cpp) + # target_compile_options(test_ck_tile_grouped_gemm_mx_tdm PRIVATE --save-temps) + + # Large-tensor / decomposition cases. Built so it does not bitrot, but NOT + # registered with ctest (plain add_executable, no add_test) -> excluded from the default CI + # test pass; run explicitly (mirrors the CK *_large_cases convention). Allocates multi-GB + # device buffers to exercise the int32 element-count / decomposition boundary. + set(_mx_large_cases_src test_mx_grouped_gemm_comp_async_large_cases.cpp) + set_source_files_properties(${_mx_large_cases_src} PROPERTIES LANGUAGE HIP) + add_executable(test_ck_tile_grouped_gemm_mx_comp_async_large_cases ${_mx_large_cases_src}) + set_property(TARGET test_ck_tile_grouped_gemm_mx_comp_async_large_cases + PROPERTY HIP_ARCHITECTURES ${SUPPORTED_GPU_TARGETS}) + target_compile_options(test_ck_tile_grouped_gemm_mx_comp_async_large_cases + PRIVATE -Wno-global-constructors -Wno-undef) + target_link_libraries(test_ck_tile_grouped_gemm_mx_comp_async_large_cases + PRIVATE gtest_main getopt::getopt) + add_dependencies(tests test_ck_tile_grouped_gemm_mx_comp_async_large_cases) +endif() + if(GPU_TARGETS MATCHES "gfx950|gfx1250") add_gtest_executable(test_ck_tile_grouped_gemm_mx_flatmm_non_tdm test_grouped_gemm_mx_flatmm_non_tdm.cpp) diff --git a/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm.cpp b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm.cpp deleted file mode 100644 index 3b5c9d6018..0000000000 --- a/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm.cpp +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#include - -#include "gtest/gtest.h" - -#include "ck_tile/host.hpp" -#include "test_mx_grouped_gemm_util.hpp" - -using F8 = ck_tile::fp8_t; -using BF8 = ck_tile::bf8_t; -using F16 = ck_tile::half_t; -using F32 = float; -using BF16 = ck_tile::bf16_t; -using Row = ck_tile::tensor_layout::gemm::RowMajor; -using Col = ck_tile::tensor_layout::gemm::ColumnMajor; -using True = ck_tile::bool_constant; -using False = ck_tile::bool_constant; -using E8M0 = ck_tile::e8m0_t; -using Intrawave = ck_tile::integral_constant; -using CompTDMV1 = ck_tile::integral_constant; -using CompTDMV2 = ck_tile::integral_constant; -template -using ScaleBS = ck_tile::integral_constant; - -// clang-format off -using KernelTypes = ::testing::Types< - // ALayout, BLayout, CLayout, ADataType, BDataType, AScaleDataType, BScaleDataType, AccDataType, CDataType, Persistent, Scheduler, PipelineType, ScaleBlockSize -std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, False, Intrawave, CompTDMV1, ScaleBS<32>>, -std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, False, Intrawave, CompTDMV1, ScaleBS<32>>, -std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, False, Intrawave, CompTDMV1, ScaleBS<32>>, -std::tuple< Col, Row, Row, F8, BF8, E8M0, E8M0, F32, F16, False, Intrawave, CompTDMV1, ScaleBS<32>>, -std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, False, Intrawave, CompTDMV2, ScaleBS<32>>, -std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, False, Intrawave, CompTDMV2, ScaleBS<32>>, -std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, False, Intrawave, CompTDMV2, ScaleBS<32>>, -std::tuple< Col, Row, Row, F8, BF8, E8M0, E8M0, F32, F16, False, Intrawave, CompTDMV2, ScaleBS<32>> ->; -// clang-format on - -TYPED_TEST_SUITE(TestCkTileMxGroupedGemm, KernelTypes); - -#include "test_mx_grouped_gemm_ut_cases.inc" diff --git a/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_comp_async.cpp 
b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_comp_async.cpp new file mode 100644 index 0000000000..5c5de3a1b3 --- /dev/null +++ b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_comp_async.cpp @@ -0,0 +1,26 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_mx_grouped_gemm_util.hpp" +#include "test_mx_grouped_gemm_pipeline_kernel_types.hpp" + +template +class TestCkTileMxGemmPipelineCompAsync + : public TestCkTileMxGroupedGemm> +{ + public: + static constexpr bool check_data_type() { return true; } +}; + +#define TEST_SUITE_NAME TestCkTileMxGemmPipelineCompAsync + +TYPED_TEST_SUITE(TestCkTileMxGemmPipelineCompAsync, KernelTypesMxGemmCompAsync); + +#include "test_mx_grouped_gemm_ut_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_comp_async_large_cases.cpp b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_comp_async_large_cases.cpp new file mode 100644 index 0000000000..5241812ed9 --- /dev/null +++ b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_comp_async_large_cases.cpp @@ -0,0 +1,34 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Large-tensor / decomposition cases for the fp4 (a4w4) grouped MX GEMM (ROCM-22075). +// +// This is a SEPARATE executable from test_ck_tile_grouped_gemm_mx_comp_async and is intentionally +// NOT registered with ctest (its CMake target uses add_executable, not add_gtest_executable), so +// it is excluded from the default CI test pass. It is run explicitly (mirroring the CK +// *_large_cases / RUN_*_LARGE_CASES_TESTS convention) because it allocates multi-GB device buffers +// (per-group C ~2.5 GB) to exercise the int32 element-count / decomposition boundary. 
+ +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_mx_grouped_gemm_util.hpp" +#include "test_mx_grouped_gemm_pipeline_kernel_types.hpp" + +template +class TestCkTileMxGemmPipelineCompAsync + : public TestCkTileMxGroupedGemm> +{ + public: + static constexpr bool check_data_type() { return true; } +}; + +#define TEST_SUITE_NAME TestCkTileMxGemmPipelineCompAsync + +TYPED_TEST_SUITE(TestCkTileMxGemmPipelineCompAsync, KernelTypesMxGemmCompAsync); + +#include "test_mx_grouped_gemm_largeM_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_largeM_cases.inc b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_largeM_cases.inc new file mode 100644 index 0000000000..0fc2dd1872 --- /dev/null +++ b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_largeM_cases.inc @@ -0,0 +1,67 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once + +// Large-tensor decomposition / INT_MAX element-count validation for fp4 (a4w4) grouped MX GEMM +// These cases exercise the all-GPU validation path (RunAllGpu): GPU data +// init, GPU reference, GPU compare. They apply only to the fp4 non-persistent CompAsync kernel +// row; for every other type in the suite they skip (RunAllGpu is fp4-only and the discarded +// if-constexpr branch is never instantiated for non-fp4 types). +// +// These tests are deliberately NOT registered with ctest (see the dedicated *_large_cases target +// in CMakeLists.txt) so they do not run in the default CI test pass; they are run explicitly +// (mirroring the CK *_large_cases / RUN_*_LARGE_CASES_TESTS convention). + +// Stage A: cross-validate the new GPU reference against the trusted host reference on identical +// small shapes. Both Run() (host reference + check_err) and RunAllGpu() (GPU reference + +// gpu_verify) must pass before the GPU reference can be trusted at large scale. 
+TYPED_TEST(TEST_SUITE_NAME, GpuRefCrossValidate_Small) +{ + using ADataType = typename TestFixture::ADataType; + if constexpr(std::is_same_v && + TestFixture::PipelineType == MxGemmPipelineType::CompAsync && + !TestFixture::Persistent) + { + const int group_count = 2; + const int kbatch = 1; + const std::vector Ms{256, 512}; + const std::vector Ns{512, 1024}; + const std::vector Ks{512, 512}; + + // Trusted host reference path first, then the new all-GPU path on the same shapes. + this->Run(Ms, Ns, Ks, kbatch, group_count); + this->RunAllGpu(Ms, Ns, Ks, kbatch, group_count); + } + else + { + GTEST_SKIP() << "GPU-reference cases apply only to fp4 non-persistent CompAsync."; + } +} + +// Stage B: the decomposition / int32-overflow case. Minimal shape that crosses the boundary: +// 2 groups, each per-group M*N = 100352*12288 = 1,233,125,376 < INT_MAX (and C byte span +// 2,466,250,752 > INT_MAX), aggregate total M*N = 2,466,250,752 > INT_MAX. This proves the +// host M-decomposition is what keeps every per-group buffer int32-safe (a single fused tensor +// would overflow the int32 element count), while the kernel still addresses C correctly even +// though the per-group C *byte* span exceeds INT_MAX. K is kept small (compute is irrelevant to +// the addressing property under test). 
+TYPED_TEST(TEST_SUITE_NAME, LargeM_decomposition_int32) +{ + using ADataType = typename TestFixture::ADataType; + if constexpr(std::is_same_v && + TestFixture::PipelineType == MxGemmPipelineType::CompAsync && + !TestFixture::Persistent) + { + const int group_count = 2; + const int kbatch = 1; + const std::vector Ms(group_count, 100352); + const std::vector Ns(group_count, 12288); + const std::vector Ks(group_count, 512); + + this->RunAllGpu(Ms, Ns, Ks, kbatch, group_count); + } + else + { + GTEST_SKIP() << "GPU-reference cases apply only to fp4 non-persistent CompAsync."; + } +} diff --git a/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_pipeline_kernel_types.hpp b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_pipeline_kernel_types.hpp new file mode 100644 index 0000000000..95b1cf4387 --- /dev/null +++ b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_pipeline_kernel_types.hpp @@ -0,0 +1,78 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_mx_grouped_gemm_util.hpp" + +using F4 = ck_tile::pk_fp4_t; +using F8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using F16 = ck_tile::half_t; +using F32 = float; +using BF16 = ck_tile::bf16_t; +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; +using True = ck_tile::bool_constant; +using False = ck_tile::bool_constant; +using E8M0 = ck_tile::e8m0_t; + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +using CompTDMV1 = ck_tile::integral_constant; +using CompTDMV2 = ck_tile::integral_constant; +using CompAsync = ck_tile::integral_constant; +using CompEightWaves = + ck_tile::integral_constant; +using WeightPreshuffle = + ck_tile::integral_constant; + +using I16 = ck_tile::number<16>; +using I32 = ck_tile::number<32>; +using I64 = ck_tile::number<64>; 
+using I128 = ck_tile::number<128>; +using I256 = ck_tile::number<256>; +using I512 = ck_tile::number<512>; + +template +using ScaleBS = ck_tile::integral_constant; + +// clang-format off +// MX GEMM kernel types using TDM pipeline with scale support +// Tuple format: +// ALayout, BLayout, CLayout, ADataType, BDataType, AScaleDataType, BScaleDataType, AccDataType, CDataType, Persistent, M_BlockSize, N_BlockSize, K_BlockSize, M_TileSize, N_TileSize, PipelineType +using KernelTypesMxGemmCompTDMWmma = ::testing::Types< + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, False, I64, I64, I128, I32, I32, CompTDMV1, ScaleBS<32>>, + std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, False, I64, I64, I128, I32, I32, CompTDMV1, ScaleBS<32>>, + std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, False, I64, I64, I128, I32, I32, CompTDMV1, ScaleBS<32>>, + std::tuple< Col, Row, Row, F8, BF8, E8M0, E8M0, F32, F16, False, I64, I64, I128, I32, I32, CompTDMV1, ScaleBS<32>>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, False, I64, I64, I128, I32, I32, CompTDMV2, ScaleBS<32>>, + std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, False, I64, I64, I128, I32, I32, CompTDMV2, ScaleBS<32>>, + std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, False, I64, I64, I128, I32, I32, CompTDMV2, ScaleBS<32>>, + std::tuple< Col, Row, Row, F8, BF8, E8M0, E8M0, F32, F16, False, I64, I64, I128, I32, I32, CompTDMV2, ScaleBS<32>> +>; + +using KernelTypesMxGemmCompAsync = ::testing::Types< + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, False, I64, I64, I256, I16, I16, CompAsync, ScaleBS<32>>, + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, False, I64, I64, I256, I16, I16, CompAsync, ScaleBS<32>>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, True, I64, I64, I256, I16, I16, CompAsync, ScaleBS<32>>, + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, True, I64, I64, I256, I16, I16, CompAsync, ScaleBS<32>>, + 
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, False, I128, I256, I128, I16, I16, CompEightWaves, ScaleBS<32>>, + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, False, I128, I256, I128, I16, I16, CompEightWaves, ScaleBS<32>>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, True, I128, I256, I128, I16, I16, CompEightWaves, ScaleBS<32>>, + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, True, I128, I256, I128, I16, I16, CompEightWaves, ScaleBS<32>>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, False, I128, I256, I256, I16, I16, WeightPreshuffle, ScaleBS<32>>, + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, False, I128, I512, I256, I16, I16, WeightPreshuffle, ScaleBS<32>>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, True, I128, I256, I256, I16, I16, WeightPreshuffle, ScaleBS<32>>, + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, True, I128, I512, I256, I16, I16, WeightPreshuffle, ScaleBS<32>>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, False, I128, I256, I256, I16, I16, WeightPreshuffle, ScaleBS<32>, True>, + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, False, I128, I512, I256, I16, I16, WeightPreshuffle, ScaleBS<32>, True>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, True, I128, I256, I256, I16, I16, WeightPreshuffle, ScaleBS<32>, True>, + std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, True, I128, I512, I256, I16, I16, WeightPreshuffle, ScaleBS<32>, True> +>; +// clang-format on diff --git a/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_ut_cases.inc b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_ut_cases.inc index c41350db33..ec4a7183e3 100644 --- a/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_ut_cases.inc +++ b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_ut_cases.inc @@ -3,7 +3,7 @@ #pragma once -TYPED_TEST(TestCkTileMxGroupedGemm, Basic) +TYPED_TEST(TEST_SUITE_NAME, Basic) { const int group_count = 4; 
const int kbatch = 1; @@ -14,8 +14,8 @@ TYPED_TEST(TestCkTileMxGroupedGemm, Basic) for(int i = 0; i < group_count; i++) { Ms.push_back(256 + 256 * i); - Ns.push_back(256 + 512 * i); - Ks.push_back(512 + 128 * i); + Ns.push_back(512 + 512 * i); + Ks.push_back(512 + TestFixture::K_Tile * i); } this->Run(Ms, Ns, Ks, kbatch, group_count); diff --git a/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_util.hpp b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_util.hpp index 3d22b6409a..797196911c 100644 --- a/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_util.hpp +++ b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_util.hpp @@ -14,11 +14,44 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/gemm/kernel/mx_grouped_gemm_kernel.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +#include "ck/library/utility/gpu_verification.hpp" + +template +constexpr ck_tile::index_t get_k_warp_tile() +{ +#if CK_TILE_USE_WMMA +#if defined(CK_USE_GFX1250) + constexpr bool is_8bit = std::is_same_v || + std::is_same_v || + std::is_same_v; + constexpr bool is_mxtype = + std::is_same_v || std::is_same_v; + if constexpr(M_Warp_Tile == 32 && is_mxtype) + { + return 128; + } + else + { + return is_8bit ? 64 : 32; + } +#else + return 16; +#endif +#else + if constexpr(M_Warp_Tile == 32) + return 64; + else + return 128; +#endif +} enum struct MxGemmPipelineType { CompTDMV1, - CompTDMV2 + CompTDMV2, + CompAsync, + CompEightWaves, + WeightPreshuffle }; template @@ -42,61 +75,137 @@ struct MxGemmPipelineTypeSelector static constexpr auto GetName() { return "GemmPipelineAgBgCrCompTDMV2"; } }; -/** - * @brief Pre-shuffle scale buffer for gfx1250 wmma mx scale instruction. - * - * Reorganizes the scale data from row-major (MN x K) layout to the hardware-specific - * layout expected by the gfx1250 wmma instruction. 
- * - * @tparam ScaleType Scale data type (e.g., e8m0_t) - * @tparam ScaleBlockSize The block size for microscaling (e.g., 32) - * @tparam KStride Whether K is the fast-moving dimension - */ -template -void preShuffleScaleBuffer_gfx1250(const ScaleType* src, - ScaleType* dst, - ck_tile::index_t MN, - ck_tile::index_t K) +template +struct MxGemmPipelineTypeSelector { - static_assert(ScaleBlockSize == 32 && sizeof(ScaleType) == 1, - "wrong! only support 8-bit scale with ScaleBlockSize=32"); + using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompAsync; + using pipeline = ck_tile::GemmPipelineAgBgCrCompAsync; - constexpr ck_tile::index_t MPerXdlops = 16; - constexpr ck_tile::index_t KPerXdlops = 128; + static constexpr auto GetName() { return "GemmPipelineAgBgCrCompAsync"; } +}; - int MNPack = 2; - int KPack = 1; +template +struct MxGemmPipelineTypeSelector +{ + using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + using pipeline = ck_tile::GemmPipelineAgBgCrCompAsyncEightWaves; - int MNStep = MPerXdlops; - int KStep = KPerXdlops / ScaleBlockSize; + static constexpr auto GetName() { return "GemmPipelineAgBgCrCompEightWaves"; } +}; - int K0 = K / KPack / KStep; +template +struct MxGemmPipelineTypeSelector +{ + using base_pipeline = ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2; + using pipeline = ck_tile::MXGemmPreshufflePipelineAGmemBGmemCRegV1; - for(int mn = 0; mn < MN; ++mn) + static constexpr auto GetName() { return "GemmPipelineAgBgCrWeightPreshuffle"; } +}; + +template +struct MxGemmEpilogueTypeSelector +{ +}; + +template +struct MxGemmEpilogueTypeSelector +{ + using epilogue = ck_tile::TdmEpilogue; +}; + +template +struct MxGemmEpilogueTypeSelector +{ + using epilogue = ck_tile::TdmEpilogue; +}; + +template +struct MxGemmEpilogueTypeSelector +{ + using epilogue = ck_tile::CShuffleEpilogue; +}; + +template +struct MxGemmEpilogueTypeSelector +{ + using epilogue = ck_tile::CShuffleEpilogue; +}; + +template +struct MxGemmEpilogueTypeSelector +{ 
+ using epilogue = std::conditional_t, + ck_tile::CShuffleEpilogue>; +}; + +template +struct Config +{ + static constexpr ck_tile::index_t N_Warp_Tile = N_Warp_Tile_; + static constexpr ck_tile::index_t K_Warp_Tile = K_Warp_Tile_; + static constexpr ck_tile::index_t N_Tile = N_Tile_; + static constexpr ck_tile::index_t N_Warp = N_Warp_; + static constexpr ck_tile::index_t BContiguousItemsPerAccess = + std::is_same_v ? 32 : 16; +}; + +// Deterministic per-element hash RNG for GPU data init. Returns a float in [-3, 3). +// The generic `fill_tensor_uniform_rand_fp_values` filler is NOT valid for ck_tile::pk_fp4_t +// (it converts a single float and duplicates it into both nibbles, and special-cases only the +// classic ck::f4x2_pk_t). We need two independent fp4 values per byte, so we fill directly. +// The narrow [-3,3) range keeps the fp16 GEMM output from overflowing at K up to 4096 (with the +// [0.25,1.0] scales used in RunAllGpu, worst case K*9 = 36864 < 65504). +__device__ inline float mx_fp4_fill_rand(unsigned int seed, unsigned long long idx) +{ + // splitmix64-style avalanche; deterministic given (seed, idx). + unsigned long long z = (idx + 1ULL) * 0x9E3779B97F4A7C15ULL + + static_cast(seed) * 0xD1B54A32D192ED03ULL; + z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ULL; + z = (z ^ (z >> 27)) * 0x94D049BB133111EBULL; + z ^= z >> 31; + const float u = + static_cast((z >> 40) & 0xFFFFFFULL) / static_cast(0x1000000); // [0,1) + return u * 6.0f - 3.0f; // [-3,3) +} + +// Fill a packed-fp4 buffer with two independent, deterministic random fp4 values per byte. +// `num_packed` is the number of pk_fp4_t elements (= total fp4 values / 2). 
+__global__ void +fill_pk_fp4_uniform_kernel(ck_tile::pk_fp4_t* __restrict__ ptr, long num_packed, unsigned int seed) +{ + const long idx0 = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const long nthr = static_cast(gridDim.x) * blockDim.x; + for(long i = idx0; i < num_packed; i += nthr) { - int iMNRepeat = mn / (MNStep * MNPack); - int tempmn = mn % (MNStep * MNPack); - - for(int k = 0; k < K; ++k) - { - int iKRepeat = k / (KStep * KPack); - int tempk = k % (KStep * KPack); - - int outputIndex = (iMNRepeat * MNPack * MNStep) * (KStep * KPack * K0) + - (iKRepeat * KStep * KPack) * (MNStep * MNPack) + - tempmn * (KStep * KPack) + tempk; - - if constexpr(KStride) - { - dst[outputIndex] = src[mn * K + k]; - } - else - dst[outputIndex] = src[k * MN + mn]; - } + const float lo_f = rintf(mx_fp4_fill_rand(seed, static_cast(i) * 2ULL)); + const float hi_f = + rintf(mx_fp4_fill_rand(seed, static_cast(i) * 2ULL + 1ULL)); + const auto lo = ck_tile::float_to_mxfp4(lo_f, 1.0f); + const auto hi = ck_tile::float_to_mxfp4(hi_f, 1.0f); + ptr[i] = ck_tile::pk_fp4_t::_pack(lo, hi); } } -template +inline void fill_pk_fp4_uniform(ck_tile::pk_fp4_t* ptr, + long num_packed, + unsigned int seed, + hipStream_t stream = nullptr) +{ + constexpr int threads = 256; + constexpr long max_blocks = 65536; // grid-stride cap + const long needed = (num_packed + threads - 1) / threads; + const long blocks = needed < max_blocks ? 
needed : max_blocks; + fill_pk_fp4_uniform_kernel<<(blocks)), dim3(threads), 0, stream>>>( + ptr, num_packed, seed); + ck_tile::hip_check_error(hipGetLastError()); +} + +template class TestCkTileMxGroupedGemm : public ::testing::Test { protected: @@ -111,9 +220,36 @@ class TestCkTileMxGroupedGemm : public ::testing::Test using CDataType = std::tuple_element_t<8, Tuple>; using PersistentType = std::tuple_element_t<9, Tuple>; static constexpr bool Persistent = PersistentType::value; - static constexpr auto Scheduler = std::tuple_element_t<10, Tuple>::value; - static constexpr auto PipelineType = std::tuple_element_t<11, Tuple>::value; - static constexpr ck_tile::index_t ScaleBlockSize = std::tuple_element_t<12, Tuple>::value; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr auto PipelineType = std::tuple_element_t<15, Tuple>::value; + static constexpr ck_tile::index_t ScaleBlockSize = std::tuple_element_t<16, Tuple>::value; + static constexpr bool PermuteN = + ck_tile::tuple_element_or_default_t::value; + + static constexpr ck_tile::index_t M_Tile = std::tuple_element_t<10, Tuple>{}; + static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<11, Tuple>{}; + static constexpr ck_tile::index_t K_Tile = std::tuple_element_t<12, Tuple>{}; + + static constexpr ck_tile::index_t M_Warp_Tile = std::tuple_element_t<13, Tuple>{}; + static constexpr ck_tile::index_t N_Warp_Tile = std::tuple_element_t<14, Tuple>{}; + static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::max( + get_k_warp_tile(), get_k_warp_tile()); + + static constexpr ck_tile::index_t M_Warp = + PipelineType == MxGemmPipelineType::WeightPreshuffle + ? 1 + : (PipelineType == MxGemmPipelineType::CompEightWaves ? 4 : 2); + static constexpr ck_tile::index_t N_Warp = + PipelineType == MxGemmPipelineType::WeightPreshuffle ? 
4 : 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr int kBlockPerCu = 1; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool Preshuffle = PipelineType == MxGemmPipelineType::WeightPreshuffle; // No D tensors for this test using DsLayout = ck_tile::tuple<>; @@ -123,42 +259,27 @@ class TestCkTileMxGroupedGemm : public ::testing::Test using AComputeDataType = ADataType; using BComputeDataType = BDataType; - struct GroupedGemKernelParam_Wmma - { - static const bool kPadM = false; - static const bool kPadN = false; - static const bool kPadK = false; - - static const int kBlockPerCu = 1; - static const ck_tile::index_t M_Tile = 64; - static const ck_tile::index_t N_Tile = 64; - static const ck_tile::index_t K_Tile = 128; - - static const ck_tile::index_t M_Warp = 2; - static const ck_tile::index_t N_Warp = 2; - static const ck_tile::index_t K_Warp = 1; - - static const ck_tile::index_t M_Warp_Tile = 32; - static const ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 128; - }; - using mx_grouped_gemm_kargs = ck_tile::MxGroupedGemmHostArgs<>; std::size_t get_workspace_size(const std::vector& gemm_descs) { return gemm_descs.size() * sizeof(ck_tile::MxGemmTransKernelArg<>); } - template + template bool invoke_mx_grouped_gemm(const std::vector& gemm_descs, const ck_tile::stream_config& s, void* kargs_ptr) { - constexpr bool preshuffle = false; constexpr bool DoubleSmemBuffer = true; // TDM pipeline requires double smem buffer +#if defined(CK_USE_GFX1250) + constexpr ck_tile::index_t BlockedXDLNPerWarp = 1; constexpr bool TransposeC = std::is_same_v && - GroupedGemKernelParam::M_Warp_Tile == GroupedGemKernelParam::N_Warp_Tile; + M_Warp_Tile == N_Warp_Tile; +#elif defined(CK_USE_GFX950) + constexpr ck_tile::index_t BlockedXDLNPerWarp = Preshuffle ? 
2 : 1; + constexpr bool TransposeC = false; +#endif static constexpr bool StructuredSparsity = false; static constexpr bool NumWaveGroup = 1; @@ -166,21 +287,15 @@ class TestCkTileMxGroupedGemm : public ::testing::Test constexpr ck_tile::index_t TileParitionerM01 = 4; using GemmShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + Preshuffle>; using UniversalGemmProblem = ck_tile::MxGemmPipelineProblem::pipeline; - using GemmEpilogue = ck_tile::TdmEpilogue< + using GemmEpilogueProblem = std::conditional_t< + Preshuffle && PermuteN, + ck_tile::PermuteNEpilogueProblem, /*VectorSizeC_*/ ck_tile::CShuffleEpilogueProblem>; + 1, /*kNumWaveGroups_*/ + false, /*FixedVectorSize_*/ + 1, /*VectorSizeC_*/ + BlockedXDLNPerWarp, /*BlockedXDLN_PerWarp_*/ + DoubleSmemBuffer, /*DoubleSmemBuffer*/ + AComputeDataType, /*AComputeDataType_*/ + BComputeDataType, /*BComputeDataType_*/ + !Preshuffle>>; + + using GemmEpilogue = typename MxGemmEpilogueTypeSelector::epilogue; using Kernel = ck_tile::MxGroupedGemmKernel; @@ -266,7 +405,7 @@ class TestCkTileMxGroupedGemm : public ::testing::Test ck_tile::ignore = ck_tile::launch_kernel(s, - ck_tile::make_kernel( + ck_tile::make_kernel( Kernel{}, grids, blocks, @@ -303,36 +442,9 @@ class TestCkTileMxGroupedGemm : public ::testing::Test return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } - static constexpr bool check_data_type() - { - - // Validate scale type / data type combination - constexpr bool a_is_f4 = std::is_same_v; - constexpr bool b_is_f4 = std::is_same_v; - constexpr bool a_scale_e8m0 = std::is_same_v; - constexpr bool b_scale_e8m0 = std::is_same_v; - if constexpr(!a_is_f4 && !a_scale_e8m0) - return false; - if constexpr(!b_is_f4 && !b_scale_e8m0) - return false; - - // 
Check hardware WMMA support for the fixed warp tile (32x32x128) -#if defined(CK_USE_GFX1250) - return ck_tile::has_wmma_traits_v; -#else - return false; -#endif - } - void SetUp() override { - if constexpr(!check_data_type()) + if constexpr(!Derived::check_data_type()) { GTEST_SKIP() << "Unsupported data type / layout combination for mx_grouped_gemm."; } @@ -345,7 +457,7 @@ class TestCkTileMxGroupedGemm : public ::testing::Test const int kbatch = 1, const int group_count = 16) { - if constexpr(!check_data_type()) + if constexpr(!Derived::check_data_type()) return; using namespace ck_tile::literals; @@ -445,8 +557,42 @@ class TestCkTileMxGroupedGemm : public ::testing::Test << " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc << " KBatch: " << kbatch << std::endl; - ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); - ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); + // For pk_fp4_t each byte packs two 4-bit elements; the generic filler + // converts a single float and duplicates it into both nibbles. + // Generate two independent random values per byte instead. 
+ if constexpr(std::is_same_v) + { + std::mt19937 gen(11939); + std::uniform_real_distribution dis(-5.f, 5.f); + for(auto& elem : a_m_k_tensors[i].mData) + { + auto lo = ck_tile::float_to_mxfp4(std::round(dis(gen)), 1.f); + auto hi = ck_tile::float_to_mxfp4(std::round(dis(gen)), 1.f); + elem = ck_tile::pk_fp4_t::_pack(lo, hi); + } + } + else + { + ck_tile::FillUniformDistributionIntegerValue{-5, 5, 11939}( + a_m_k_tensors[i]); + } + + if constexpr(std::is_same_v) + { + std::mt19937 gen(11940); + std::uniform_real_distribution dis(-5.f, 5.f); + for(auto& elem : b_k_n_tensors[i].mData) + { + auto lo = ck_tile::float_to_mxfp4(std::round(dis(gen)), 1.f); + auto hi = ck_tile::float_to_mxfp4(std::round(dis(gen)), 1.f); + elem = ck_tile::pk_fp4_t::_pack(lo, hi); + } + } + else + { + ck_tile::FillUniformDistributionIntegerValue{-5, 5, 11940}( + b_k_n_tensors[i]); + } // K must be a multiple of ScaleBlockSize if(K % ScaleBlockSize != 0) @@ -454,15 +600,13 @@ class TestCkTileMxGroupedGemm : public ::testing::Test GTEST_SKIP() << "K must be multiple of ScaleBlockSize for MX GEMM"; } const ck_tile::index_t num_scale_k = K / ScaleBlockSize; - if(num_scale_k % (GroupedGemKernelParam_Wmma::K_Warp_Tile / ScaleBlockSize) != 0) + if(num_scale_k % (K_Warp_Tile / ScaleBlockSize) != 0) { - GTEST_SKIP() << "K must be a multiple of K_Warp_Tile (" - << GroupedGemKernelParam_Wmma::K_Warp_Tile + GTEST_SKIP() << "K must be a multiple of K_Warp_Tile (" << K_Warp_Tile << ") for MX GEMM. 
Pad the scale data."; } const ck_tile::index_t scale_padded_M = ck_tile::integer_least_multiple( - static_cast(M), - static_cast(GroupedGemKernelParam_Wmma::M_Warp_Tile)); + static_cast(M), static_cast(M_Tile)); ck_tile::HostTensor scale_a( {static_cast(scale_padded_M), static_cast(num_scale_k)}, @@ -515,6 +659,9 @@ class TestCkTileMxGroupedGemm : public ::testing::Test } // Pre-shuffle scale buffers for the hardware +#if defined(CK_USE_GFX1250) + constexpr ck_tile::index_t NXdlPackEff = 1; + ck_tile::HostTensor scale_a_shuffled( {static_cast(scale_padded_M), static_cast(num_scale_k)}, {static_cast(num_scale_k), static_cast(1)}); @@ -523,11 +670,6 @@ class TestCkTileMxGroupedGemm : public ::testing::Test {static_cast(N), static_cast(num_scale_k)}, {static_cast(num_scale_k), static_cast(1)}); - std::cout << " scale_a: [scale_padded_M = " << scale_padded_M - << ", num_scale_k = " << num_scale_k << "]." << std::endl; - std::cout << " scale_b: [N = " << N << ", num_scale_k = " << num_scale_k << "]." - << std::endl; - // Pre-shuffle for gfx1250 (WaveSize=32, WMMA) preShuffleScaleBuffer_gfx1250( scale_a.mData.data(), scale_a_shuffled.mData.data(), scale_padded_M, num_scale_k); @@ -536,19 +678,95 @@ class TestCkTileMxGroupedGemm : public ::testing::Test // where N is the fast-changing dimension for col-major B preShuffleScaleBuffer_gfx1250( scale_b.mData.data(), scale_b_shuffled.mData.data(), N, num_scale_k); +#elif defined(CK_USE_GFX950) + constexpr ck_tile::index_t MPerXdl = M_Warp_Tile; + constexpr ck_tile::index_t NPerXdl = N_Warp_Tile; + constexpr ck_tile::index_t KPerXdl = K_Warp_Tile; + constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * MPerXdl); + constexpr ck_tile::index_t NIterPerWarp = N_Tile / (N_Warp * NPerXdl); + constexpr ck_tile::index_t KIterPerWarp = K_Tile / KPerXdl; + + constexpr ck_tile::index_t MXdlPackEff = + (MIterPerWarp >= 2 && MIterPerWarp % 2 == 0) ? 
2 : 1; + constexpr ck_tile::index_t NXdlPackEff = + (NIterPerWarp >= 2 && NIterPerWarp % 2 == 0) ? 2 : 1; + constexpr ck_tile::index_t KXdlPackEff = + (KIterPerWarp >= 2 && KIterPerWarp % 2 == 0) ? 2 : 1; + + constexpr ck_tile::index_t XdlMNThread = M_Warp_Tile; + constexpr ck_tile::index_t XdlKThread = 64 / XdlMNThread; + + ck_tile::HostTensor scale_a_shuffled( + {static_cast(scale_padded_M / MXdlPackEff * 2), + static_cast(num_scale_k / KXdlPackEff * 2)}, + {static_cast(num_scale_k / KXdlPackEff * 2), + static_cast(1)}); + + ck_tile::HostTensor scale_b_shuffled( + {static_cast(N / NXdlPackEff * 2), + static_cast(num_scale_k / KXdlPackEff * 2)}, + {static_cast(num_scale_k / KXdlPackEff * 2), + static_cast(1)}); + + ck_tile:: + preShuffleScaleBuffer_gfx950( + scale_a.mData.data(), + scale_a_shuffled.mData.data(), + scale_padded_M, + num_scale_k, + true); + + if constexpr(Preshuffle && PermuteN) + { + ck_tile::preShuffleScaleBufferPermuteN_gfx950( + scale_b.mData.data(), scale_b_shuffled.mData.data(), N, num_scale_k, true); + } + else + { + ck_tile:: + preShuffleScaleBuffer_gfx950( + scale_b.mData.data(), scale_b_shuffled.mData.data(), N, num_scale_k, true); + } +#endif + + std::cout << " scale_a: [scale_padded_M = " << scale_padded_M + << ", num_scale_k = " << num_scale_k << "]." << std::endl; + std::cout << " scale_b: [N = " << N << ", num_scale_k = " << num_scale_k << "]." 
+ << std::endl; scale_a_tensors.push_back(scale_a_shuffled); scale_b_tensors.push_back(scale_b_shuffled); + using GemmConfig = Config; + + const auto b_host_for_dev = [&]() { + if constexpr(Preshuffle) + { + if constexpr(PermuteN) + { + return ck_tile::shuffle_b_permuteN( + b_k_n_tensors[i]); + } + else + { + return ck_tile::shuffle_b(b_k_n_tensors[i]); + } + } + else + { + return b_k_n_tensors[i]; + } + }(); + a_m_k_dev_buf.push_back(std::make_unique( a_m_k_tensors[i].get_element_space_size_in_bytes())); b_k_n_dev_buf.push_back(std::make_unique( - b_k_n_tensors[i].get_element_space_size_in_bytes())); + b_host_for_dev.get_element_space_size_in_bytes())); c_m_n_dev_buf.push_back(std::make_unique( c_m_n_tensors[i].get_element_space_size_in_bytes())); a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data()); - b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data()); + b_k_n_dev_buf[i]->ToDevice(b_host_for_dev.data()); c_m_n_dev_buf[i]->SetZero(); c_m_n_tensors[i].SetZero(); @@ -584,7 +802,7 @@ class TestCkTileMxGroupedGemm : public ::testing::Test ck_tile::DeviceMem gemm_workspace; gemm_workspace.Realloc(get_workspace_size(gemm_descs)); - if(!invoke_mx_grouped_gemm( + if(!invoke_mx_grouped_gemm( gemm_descs, ck_tile::stream_config{nullptr, false, 1}, gemm_workspace.GetDeviceBuffer())) @@ -628,4 +846,321 @@ class TestCkTileMxGroupedGemm : public ::testing::Test } EXPECT_TRUE(pass); } + + // All-GPU validation path for the fp4 (pk_fp4_t) MX grouped GEMM. + // + // Unlike Run(), this never materializes the (potentially 39 GB) A/B/C tensors on the host: + // - A/B are generated directly on device with a deterministic fp4 fill. + // - the reference is computed on device by reference_mx_gemm_gpu. + // - the comparison is done on device by ck::profiler::gpu_verify. + // Only the tiny e8m0 scales touch the host (for pre-shuffle + an unshuffled copy that the + // device reference consumes). 
Groups are processed one at a time to bound peak device memory + // and to make any fault attributable to a specific group. + // + // Per group it logs and asserts that every int32-addressed quantity (M*N, A/B/C worst-case + // element offsets) stays < INT_MAX -- this is the property under test for the host-side + // M-decomposition that keeps each per-group buffer kernel-safe. + void RunAllGpu(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const int kbatch = 1, + const int group_count = 16) + { + if constexpr(!Derived::check_data_type()) + return; + + static_assert(std::is_same_v && + std::is_same_v, + "RunAllGpu currently supports pk_fp4_t A/B only."); + // The GPU reference (reference_mx_gemm_gpu) hardcodes these layouts; guard so it cannot be + // silently misused with a layout it does not handle. + static_assert(std::is_same_v && + std::is_same_v && + std::is_same_v, + "RunAllGpu / reference_mx_gemm_gpu assume RowMajor-A, ColumnMajor-B, " + "RowMajor-C."); + +#if !defined(CK_USE_GFX950) + (void)Ms; + (void)Ns; + (void)Ks; + (void)kbatch; + (void)group_count; + GTEST_SKIP() << "RunAllGpu requires CK_USE_GFX950."; +#else + using namespace ck_tile::literals; + constexpr long kIntMax = 2147483647L; // INT_MAX + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + if constexpr(std::is_same_v) + return col; + else + return row; + } + else + return stride; + }; + + constexpr ck_tile::index_t psize = ck_tile::numeric_traits::PackedSize; // 2 + static_assert(psize == 2, + "RunAllGpu byte-sizing and reference_mx_gemm_kernel's a_ptr[a_lin/2] " + "addressing assume pk_fp4_t PackedSize == 2."); + + bool pass = true; + long total_MN = 0; + + for(int i = 0; i < group_count; ++i) + { + const ck_tile::index_t M = Ms[i]; + const ck_tile::index_t N = Ns[i]; + const ck_tile::index_t K = Ks[i]; + + // Strides are K/N here (small); keep them as index_t to match the kernel args, 
and + // make the size_t->index_t narrowing explicit. + const ck_tile::index_t stride_A = + static_cast(f_get_default_stride(M, K, 0, ALayout{})); // K + const ck_tile::index_t stride_B = + static_cast(f_get_default_stride(K, N, 0, BLayout{})); // K + const ck_tile::index_t stride_C = + static_cast(f_get_default_stride(M, N, 0, CLayout{})); // N + + // Per-group shape guards. Fatal (ASSERT), not GTEST_SKIP: a skip mid-loop would + // silently report success for only a prefix of the groups already validated. + ASSERT_EQ(K % ScaleBlockSize, 0) + << "group " << i << ": K must be a multiple of ScaleBlockSize for MX GEMM"; + const ck_tile::index_t num_scale_k = K / ScaleBlockSize; + ASSERT_EQ(num_scale_k % (K_Warp_Tile / ScaleBlockSize), 0) + << "group " << i << ": K must be a multiple of K_Warp_Tile (" << K_Warp_Tile + << ") for MX GEMM. Pad the scale data."; + const ck_tile::index_t scale_padded_M = ck_tile::integer_least_multiple( + static_cast(M), static_cast(M_Tile)); + + // int32-safety: the property under test for the M-decomposition. The predicate is + // "largest 0-based element offset fits in a signed 32-bit int", i.e. offset <= INT_MAX. + const long MN = static_cast(M) * N; + const long A_elems = static_cast(M) * K; + const long B_elems = static_cast(K) * N; + const long C_off = static_cast(M - 1) * stride_C + (N - 1); + const long A_off = static_cast(M - 1) * stride_A + (K - 1); + const long B_off = static_cast(N - 1) * stride_B + (K - 1); + const long c_bytes = MN * static_cast(sizeof(CDataType)); + std::cout << "[int32-safety] group " << i << " M=" << M << " N=" << N << " K=" << K + << " M*N=" << MN << " A_elems=" << A_elems << " B_elems=" << B_elems + << " C_off=" << C_off << " A_off=" << A_off << " B_off=" << B_off + << " C_bytes=" << c_bytes << " (INT_MAX=" << kIntMax << ")" << std::endl; + // Note (not an assert): the C *byte* span can exceed INT_MAX even when the element + // count is int32-safe. 
We deliberately let the run proceed -- if any internal byte + // offset overflows, gpu_verify will flag it, which is exactly what we want to discover. + if(c_bytes > kIntMax) + std::cout + << "[int32-safety][note] group " << i << " C byte span (" << c_bytes + << ") exceeds INT_MAX; if verification fails, byte-offset overflow is the " + "prime suspect." + << std::endl; + ASSERT_LE(MN - 1, kIntMax) << "group " << i << " max C element index exceeds INT_MAX"; + ASSERT_LE(C_off, kIntMax) << "group " << i << " C offset exceeds INT_MAX"; + ASSERT_LE(A_off, kIntMax) << "group " << i << " A offset exceeds INT_MAX"; + ASSERT_LE(B_off, kIntMax) << "group " << i << " B offset exceeds INT_MAX"; + total_MN += MN; + + // Device buffers (no big host tensors). Round byte counts up: a stray odd fp4 element + // still occupies a full packed byte. + const long a_bytes = (A_elems + psize - 1) / psize; + const long b_bytes = (B_elems + psize - 1) / psize; + + // Bound peak device memory (A + B + 2*C + scales/workspace slack). Skip cleanly rather + // than aborting via hip_check_error if the device cannot hold one group. + { + std::size_t free_b = 0, total_b = 0; + ck_tile::hip_check_error(hipMemGetInfo(&free_b, &total_b)); + const std::size_t need = static_cast(a_bytes) + + static_cast(b_bytes) + + 2u * static_cast(c_bytes) + (64u << 20); + if(free_b < need) + GTEST_SKIP() << "group " << i << ": insufficient device memory (need " << need + << " B, free " << free_b << " B)"; + } + + auto a_dev = std::make_unique(static_cast(a_bytes)); + auto b_dev = std::make_unique(static_cast(b_bytes)); + auto c_dev = std::make_unique(static_cast(c_bytes)); + auto c_ref_dev = + std::make_unique(static_cast(c_bytes)); + c_dev->SetZero(); + c_ref_dev->SetZero(); + + // GPU fill A/B (deterministic, fp4-correct). Same device buffers feed both the kernel + // and the reference, so the fill need not bit-match any host RNG. Fold the group index + // into the seed so each group gets a distinct data pattern. 
+ fill_pk_fp4_uniform(reinterpret_cast(a_dev->GetDeviceBuffer()), + a_bytes, + 11939u + static_cast(i)); + fill_pk_fp4_uniform(reinterpret_cast(b_dev->GetDeviceBuffer()), + b_bytes, + 11940u + static_cast(i)); + ck_tile::hip_check_error( + hipDeviceSynchronize()); // surface fill faults at the fill site + + // e8m0 scales (tiny, host-built, fixed per-group seed for determinism). The range is + // deliberately narrow ([0.25,1.0] scales, [-3,3) fp4 fill) so that K up to 4096 cannot + // overflow the fp16 output (worst case K*9 = 36864 < 65504); gpu_verify counts matched + // infinities as errors, so an overflow would otherwise be a false failure. + ck_tile::HostTensor scale_a( + {static_cast(scale_padded_M), static_cast(num_scale_k)}, + {static_cast(num_scale_k), static_cast(1)}); + ck_tile::HostTensor scale_b( + {static_cast(N), static_cast(num_scale_k)}, + {static_cast(num_scale_k), static_cast(1)}); + { + std::mt19937 gen(11941u + static_cast(i)); + std::uniform_real_distribution dist(0.25f, 1.0f); + for(auto& s : scale_a.mData) + s = AScaleDataType{dist(gen)}; + for(auto& s : scale_b.mData) + s = BScaleDataType{dist(gen)}; + } + + // gfx950 scale pre-shuffle. NOTE: this must stay in sync with the identical block in + // Run() -- the kernel-input layout and the reference-input layout must agree. + constexpr ck_tile::index_t MPerXdl = M_Warp_Tile; + constexpr ck_tile::index_t NPerXdl = N_Warp_Tile; + constexpr ck_tile::index_t KPerXdl = K_Warp_Tile; + constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * MPerXdl); + constexpr ck_tile::index_t NIterPerWarp = N_Tile / (N_Warp * NPerXdl); + constexpr ck_tile::index_t KIterPerWarp = K_Tile / KPerXdl; + + constexpr ck_tile::index_t MXdlPackEff = + (MIterPerWarp >= 2 && MIterPerWarp % 2 == 0) ? 2 : 1; + constexpr ck_tile::index_t NXdlPackEff = + (NIterPerWarp >= 2 && NIterPerWarp % 2 == 0) ? 2 : 1; + constexpr ck_tile::index_t KXdlPackEff = + (KIterPerWarp >= 2 && KIterPerWarp % 2 == 0) ? 
2 : 1; + + constexpr ck_tile::index_t XdlMNThread = M_Warp_Tile; + constexpr ck_tile::index_t XdlKThread = 64 / XdlMNThread; + + ck_tile::HostTensor scale_a_shuffled( + {static_cast(scale_padded_M / MXdlPackEff * 2), + static_cast(num_scale_k / KXdlPackEff * 2)}, + {static_cast(num_scale_k / KXdlPackEff * 2), + static_cast(1)}); + ck_tile::HostTensor scale_b_shuffled( + {static_cast(N / NXdlPackEff * 2), + static_cast(num_scale_k / KXdlPackEff * 2)}, + {static_cast(num_scale_k / KXdlPackEff * 2), + static_cast(1)}); + + ck_tile:: + preShuffleScaleBuffer_gfx950( + scale_a.mData.data(), + scale_a_shuffled.mData.data(), + scale_padded_M, + num_scale_k, + true); + ck_tile:: + preShuffleScaleBuffer_gfx950( + scale_b.mData.data(), scale_b_shuffled.mData.data(), N, num_scale_k, true); + + // Device scale buffers: shuffled feed the kernel, unshuffled feed the reference. + auto scale_a_shuf_dev = std::make_unique( + scale_a_shuffled.get_element_space_size_in_bytes()); + auto scale_b_shuf_dev = std::make_unique( + scale_b_shuffled.get_element_space_size_in_bytes()); + scale_a_shuf_dev->ToDevice(scale_a_shuffled.data()); + scale_b_shuf_dev->ToDevice(scale_b_shuffled.data()); + + auto scale_a_ref_dev = + std::make_unique(scale_a.get_element_space_size_in_bytes()); + auto scale_b_ref_dev = + std::make_unique(scale_b.get_element_space_size_in_bytes()); + scale_a_ref_dev->ToDevice(scale_a.data()); + scale_b_ref_dev->ToDevice(scale_b.data()); + + // Launch the grouped kernel for this single group. 
+ std::vector gemm_descs; + gemm_descs.push_back(mx_grouped_gemm_kargs(a_dev->GetDeviceBuffer(), + scale_a_shuf_dev->GetDeviceBuffer(), + b_dev->GetDeviceBuffer(), + scale_b_shuf_dev->GetDeviceBuffer(), + {/*ds_ptr*/}, + c_dev->GetDeviceBuffer(), + kbatch, + M, + N, + K, + stride_A, + stride_B, + {/*stride_Ds*/}, + stride_C)); + + ck_tile::DeviceMem gemm_workspace; + gemm_workspace.Realloc(get_workspace_size(gemm_descs)); + if(!invoke_mx_grouped_gemm( + gemm_descs, + ck_tile::stream_config{nullptr, false, 1}, + gemm_workspace.GetDeviceBuffer())) + { + ADD_FAILURE() << "invoke_mx_grouped_gemm failed for group " << i; + pass = false; + continue; // DeviceMem frees cleanly at loop end; keep validating other groups + } + ck_tile::hip_check_error(hipDeviceSynchronize()); + + // GPU reference on the same device A/B buffers. + ck_tile::reference_mx_gemm_gpu( + reinterpret_cast(a_dev->GetDeviceBuffer()), + reinterpret_cast(b_dev->GetDeviceBuffer()), + reinterpret_cast(scale_a_ref_dev->GetDeviceBuffer()), + reinterpret_cast(scale_b_ref_dev->GetDeviceBuffer()), + reinterpret_cast(c_ref_dev->GetDeviceBuffer()), + M, + N, + K, + num_scale_k, + ScaleBlockSize); + ck_tile::hip_check_error(hipDeviceSynchronize()); + + // GPU verify with explicit MX tolerance (auto tolerance defaults too tight for MX). + const float max_acc = ck::profiler::gpu_reduce_max( + c_ref_dev->GetDeviceBuffer(), static_cast(MN)); + // The reference must be non-degenerate, else error_count==0 is a vacuous pass. + ASSERT_GT(max_acc, 0.0f) << "group " << i << ": GPU reference output is all-zero"; + const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_acc); + const auto res = ck::profiler::gpu_verify(c_dev->GetDeviceBuffer(), + c_ref_dev->GetDeviceBuffer(), + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{}), + static_cast(MN)); + + // Positive liveness check on the *device* output. 
res.all_zero ANDs device- and + // reference-zeroness, and the reference is never zero here, so it cannot detect a no-op + // kernel on its own -- reduce the device buffer directly. + const float c_dev_absmax = ck::profiler::gpu_reduce_max( + c_dev->GetDeviceBuffer(), static_cast(MN)); + + std::cout << "[verify] group " << i << " errors=" << res.error_count + << " max_error=" << res.max_error << " c_dev_absmax=" << c_dev_absmax + << " max_acc=" << max_acc << " rtol=" << rtol_atol.at(ck_tile::number<0>{}) + << " atol=" << rtol_atol.at(ck_tile::number<1>{}) << std::endl; + + EXPECT_EQ(res.error_count, 0ull) << "group " << i << " produced mismatched results"; + EXPECT_GT(c_dev_absmax, 0.0f) << "group " << i << " produced an all-zero device output"; + pass &= (res.error_count == 0 && c_dev_absmax > 0.0f); + // a_dev/b_dev/c_dev/... freed here (unique_ptr) before the next group. + } + + std::cout << "[int32-safety] aggregate total_M*N=" << total_MN << " (INT_MAX=" << kIntMax + << ") -> decomposition is the variable under test" << std::endl; + EXPECT_TRUE(pass); +#endif // CK_USE_GFX950 + } }; diff --git a/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_wmma_tdm.cpp b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_wmma_tdm.cpp new file mode 100644 index 0000000000..012598fbc1 --- /dev/null +++ b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_wmma_tdm.cpp @@ -0,0 +1,86 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_mx_grouped_gemm_util.hpp" +#include "test_mx_grouped_gemm_pipeline_kernel_types.hpp" + +template +class TestCkTileMxGemmPipelineCompTDMWmma + : public TestCkTileMxGroupedGemm> +{ + public: + static constexpr bool check_data_type() + { + using Base = TestCkTileMxGroupedGemm>; + + if constexpr(!is_valid_mx_scale_combination()) + { + return false; + } + +#if defined(CK_USE_GFX1250) + using DeviceIp = ck_tile::gfx125_t; +#else +#error "Unsupported architecture for WMMA MX GEMM" +#endif + + return ck_tile::has_wmma_traits_v::value, + ck_tile::constant::value, + ck_tile::constant::value>; + } + + private: + template + static constexpr bool is_valid_mx_scale_combination() + { + constexpr bool a_is_f4 = std::is_same_v; + constexpr bool b_is_f4 = std::is_same_v; + constexpr bool a_scale_e8m0 = std::is_same_v; + constexpr bool b_scale_e8m0 = std::is_same_v; + + // Non-F4 must use E8M0 scale + if constexpr(!a_is_f4 && !a_scale_e8m0) + return false; + if constexpr(!b_is_f4 && !b_scale_e8m0) + return false; + + // Both E8M0 -> always valid + if constexpr(a_scale_e8m0 && b_scale_e8m0) + return true; + + // Both non-E8M0 -> must match (both are F4 by rule 1) + if constexpr(!a_scale_e8m0 && !b_scale_e8m0) + return std::is_same_v; + + // One side non-E8M0: the E8M0 side must not be F4 + if constexpr(!a_scale_e8m0) + return !b_is_f4; + if constexpr(!b_scale_e8m0) + return !a_is_f4; + + return true; + } +}; + +#define TEST_SUITE_NAME TestCkTileMxGemmPipelineCompTDMWmma + +TYPED_TEST_SUITE(TestCkTileMxGemmPipelineCompTDMWmma, KernelTypesMxGemmCompTDMWmma); + +#include "test_mx_grouped_gemm_ut_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 9e71d6ea76..31148661f2 100644 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ 
b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -520,7 +520,7 @@ class GemmKernelBuilder: } elif self.kernel_name_prefix == "mx_gemm": pipeline_impl_map = { - "comp_async": "ck_tile::MXGemmPipelineAgBgCrCompAsync", + "comp_async": "ck_tile::GemmPipelineAgBgCrCompAsync", } base_pipeline_map = {} @@ -581,9 +581,6 @@ class GemmKernelBuilder: instance_code += """#include #include #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" -""" - elif self.kernel_name_prefix == "mx_gemm": - instance_code += """#include "ck_tile/ops/gemm_mx.hpp" """ return instance_code @@ -617,9 +614,7 @@ using CDataType = {get_dtype_string(c_type)};""" if self.kernel_name_prefix == "mx_gemm": instance_code += """ using ScaleType = ck_tile::e8m0_t; -using ScaleM = ck_tile::MXScalePointer; -using ScaleN = ck_tile::MXScalePointer; -using MxGemmHostArgs = ck_tile::MXGemmKernelArgs;""" +using MxGemmHostArgs = ck_tile::MxGemmHostArgs<1, 1, 0>;""" if self.kernel_name_prefix == "gemm_multi_d": instance_code += f""" @@ -684,7 +679,7 @@ struct SelectedKernel {{ static constexpr bool kPadN = {"true" if pad_n in [True, "true"] else "false"}; static constexpr bool kPadK = {"true" if pad_k in [True, "true"] else "false"}; static constexpr bool TransposeC = false; - static constexpr bool DoubleSmemBuffer = {"true" if pipeline in ["compv4", "preshufflev2"] else "false"};""" + static constexpr bool DoubleSmemBuffer = {"true" if pipeline in ["compv4", "preshufflev2", "comp_async"] else "false"};""" if self.kernel_name_prefix in [ "gemm_universal", @@ -1069,17 +1064,17 @@ struct SelectedKernel {{ instance_code += f""" // Kernel type - using Kernel = ck_tile::MXGemmKernel; + using Kernel = ck_tile::MxGemmKernel; // Kernel arguments - auto kargs = args; + auto kargs = Kernel::MakeKernelArgs(args); - if(!Kernel::Underlying::IsSupportedArgument(kargs)) {{ + if(!Kernel::IsSupportedArgument(kargs)) {{ throw std::runtime_error("Wrong! Arguments not supported! 
Skipping mx gemm!"); }} // Get grid and block sizes - const dim3 grids = Kernel::GridSize(kargs); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); const dim3 blocks = Kernel::BlockSize(); if(stream.log_level_ > 0) {{ @@ -1245,7 +1240,15 @@ struct SelectedKernel {{ WarpTileM, WarpTileN, WarpTileK, - TransposeC>; + TransposeC, + 1, // NumWaveGroups + false, // FixedVectorSize_ + 1, // VectorSizeC_ + 1, // BlockedXDLNPerWarp + false, // DoubleSmemBuffer_ + ADataType, // AComputeDataType + BDataType, // BComputeDataType + true>; // TilesPacked_ using GemmEpilogue = ck_tile::CShuffleEpilogue;""" return instance_code diff --git a/tile_engine/ops/gemm/mx_gemm/mx_gemm_profiler.hpp b/tile_engine/ops/gemm/mx_gemm/mx_gemm_profiler.hpp index 55d21b5733..1e6b2fee05 100644 --- a/tile_engine/ops/gemm/mx_gemm/mx_gemm_profiler.hpp +++ b/tile_engine/ops/gemm/mx_gemm/mx_gemm_profiler.hpp @@ -9,7 +9,6 @@ #include #include "ck_tile/host/device_prop.hpp" -#include "ck_tile/ops/gemm_mx.hpp" #include "gemm/gemm_profiler.hpp" #include "mx_gemm_benchmark.hpp" @@ -49,15 +48,13 @@ class MXGemmProfiler : public GemmProfiler scale_a_host(ck_tile::host_tensor_descriptor( - gemm_problem.m_, scale_k_size, stride_scale_a, is_row_major(layout_a))); - ck_tile::HostTensor scale_b_host(ck_tile::host_tensor_descriptor( - scale_k_size, gemm_problem.n_, stride_scale_b, is_row_major(layout_b))); + ck_tile::HostTensor scale_a_host( + {static_cast(gemm_problem.m_), static_cast(scale_k_size)}, + {static_cast(scale_k_size), static_cast(1)}); + ck_tile::HostTensor scale_b_host( + {static_cast(gemm_problem.n_), static_cast(scale_k_size)}, + {static_cast(scale_k_size), static_cast(1)}); if(setting_.init_method == 0) { @@ -109,31 +106,47 @@ class MXGemmProfiler : public GemmProfiler(scale_a_host, - true); - auto scale_b_packed = - pack_mx_scales_mn_x_k(scale_b_host, - false); + ck_tile::HostTensor scale_a_shuffled( + {static_cast(gemm_problem.m_ / m_xdl_pack * 2), + static_cast(scale_k_size / 
k_xdl_pack * 2)}, + {static_cast(scale_k_size / k_xdl_pack * 2), static_cast(1)}); + + ck_tile::HostTensor scale_b_shuffled( + {static_cast(gemm_problem.n_ / n_xdl_pack * 2), + static_cast(scale_k_size / k_xdl_pack * 2)}, + {static_cast(scale_k_size / k_xdl_pack * 2), static_cast(1)}); + + ck_tile::preShuffleScaleBuffer_gfx950( + scale_a_host.mData.data(), + scale_a_shuffled.mData.data(), + gemm_problem.m_, + scale_k_size, + true); + + ck_tile::preShuffleScaleBuffer_gfx950( + scale_b_host.mData.data(), + scale_b_shuffled.mData.data(), + gemm_problem.n_, + scale_k_size, + true); ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); - ck_tile::DeviceMem scale_a_dev_buf(scale_a_packed.size() * sizeof(int32_t)); - ck_tile::DeviceMem scale_b_dev_buf(scale_b_packed.size() * sizeof(int32_t)); + ck_tile::DeviceMem scale_a_dev_buf(scale_a_shuffled.get_element_space_size_in_bytes()); + ck_tile::DeviceMem scale_b_dev_buf(scale_b_shuffled.get_element_space_size_in_bytes()); a_m_k_dev_buf.ToDevice(a_m_k.data()); b_k_n_dev_buf.ToDevice(b_k_n.data()); c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - scale_a_dev_buf.ToDevice(scale_a_packed.data()); - scale_b_dev_buf.ToDevice(scale_b_packed.data()); - - ScaleM scale_m(reinterpret_cast(scale_a_dev_buf.GetDeviceBuffer())); - ScaleN scale_n(reinterpret_cast(scale_b_dev_buf.GetDeviceBuffer())); + scale_a_dev_buf.ToDevice(scale_a_shuffled.data()); + scale_b_dev_buf.ToDevice(scale_b_shuffled.data()); MxGemmHostArgs gemm_args({a_m_k_dev_buf.GetDeviceBuffer()}, + {scale_a_dev_buf.GetDeviceBuffer()}, {b_k_n_dev_buf.GetDeviceBuffer()}, + {scale_b_dev_buf.GetDeviceBuffer()}, {}, c_m_n_dev_buf.GetDeviceBuffer(), gemm_problem.split_k_, @@ -143,17 +156,26 @@ class MXGemmProfiler : public GemmProfiler c_m_n_host_result(ck_tile::host_tensor_descriptor( 
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c))); if(setting_.verify) { + // Host reference computation using reference_mx_gemm + // reference_mx_gemm expects scale_a(M, K/ScaleBlockSize) and scale_b(K/ScaleBlockSize, + // N) We need to create scale_b in (K/ScaleBlockSize, N) format for the reference + ck_tile::HostTensor scale_b_ref( + {static_cast(scale_k_size), static_cast(gemm_problem.n_)}, + {static_cast(1), static_cast(scale_k_size)}); + // Copy scale_b data (our scale_b is (N, scale_k_size) row-major, + // reference expects (scale_k_size, N) col-major, which is the same memory layout) + std::copy( + scale_b_host.mData.begin(), scale_b_host.mData.end(), scale_b_ref.mData.begin()); + mx_gemm_host_reference( - setting_.verify, a_m_k, b_k_n, c_m_n_host_result, scale_a_host, scale_b_host); + setting_.verify, a_m_k, b_k_n, c_m_n_host_result, scale_a_host, scale_b_ref); } for(auto& callable : callables)