diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index 40547d0719..b6b35f1b53 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -13,6 +13,13 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a") list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) + + option(CK_TILE_EXAMPLE_ROUTE_DISPATCHER_TO_MAKE "Route ck_tile::WarpGemmDispatcher to MakeWarpGemm in examples" OFF) + + # When routing dispatcher to MakeWarpGemm is enabled, add the macro to the shared options + if(CK_TILE_EXAMPLE_ROUTE_DISPATCHER_TO_MAKE) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_ROUTE_WARP_GEMM_DISPATCHER_TO_MAKE) + endif() list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-unused-local-typedef) list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-gnu-line-marker) list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS --save-temps) diff --git a/include/ck_tile/ops/gemm/warp/detail/warp_gemm_attribute_mfma_compose.hpp b/include/ck_tile/ops/gemm/warp/detail/warp_gemm_attribute_mfma_compose.hpp new file mode 100644 index 0000000000..c2d9e68c22 --- /dev/null +++ b/include/ck_tile/ops/gemm/warp/detail/warp_gemm_attribute_mfma_compose.hpp @@ -0,0 +1,709 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp" +// smfmac (structured sparsity) +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp" +namespace ck_tile { +namespace detail { +namespace wg_attr_compose { + +// Basic config/state that policies transform at compile-time +struct NoneTag; + +enum class SwizzleKind +{ + None, + A, + B +}; + +template +struct State +{ + using Impl = remove_cvref_t; + static constexpr bool TransposeC = TransposeC_; + static constexpr SwizzleKind Swizzle = Swizzle_; + static constexpr index_t SFactor = SFactor_; + static constexpr index_t KIter = KIter_; + static constexpr WGAttrNumAccessEnum NumAccess = NumAccess_; + static constexpr index_t NumAccValue = static_cast(NumAccess); +}; + +// Default base state +template +using BaseState = State; + +// Helpers to compute derived constants and types from State + +// A/B data/vec types under transpose flag +template +struct ABTypes +{ + using ADataType = + std::conditional_t; + using BDataType = + std::conditional_t; + + using AVecBase = + std::conditional_t; + using BVecBase = + std::conditional_t; + + using AVecType = ext_vector_t::vector_size * S::KIter>; + using BVecType = ext_vector_t::vector_size * S::KIter>; +}; + +// C types/shape (C does not change type with policies) +template +struct CTypes +{ + using CDataType = typename S::Impl::CDataType; + using CVecType = typename S::Impl::CVecType; +}; + +// kM/kN/kK/kKPerThread under transpose and iterateK + +template +struct KShape +{ + static constexpr index_t kM = S::TransposeC ? S::Impl::kN : S::Impl::kM; + static constexpr index_t kN = S::TransposeC ? S::Impl::kM : S::Impl::kN; + static constexpr index_t kK = S::Impl::kK * S::KIter; + static constexpr index_t kKPerThread = S::Impl::kABKPerLane * S::KIter; + static constexpr index_t kCMLane = S::Impl::kCMLane; // unchanged +}; + +// Lane helpers + +template +struct Lanes +{ + static constexpr index_t AMLane = S::TransposeC ? S::Impl::kBNLane : S::Impl::kAMLane; + static constexpr index_t BNLane = S::TransposeC ? S::Impl::kAMLane : S::Impl::kBNLane; +}; + +// Encoding builders centralizing existing logic, parameterized by S + +// A encodings with NumAccess and IterateK folded, with multi-block and swizzle +template +CK_TILE_DEVICE static constexpr auto MakeAWarpDstrEncoding() +{ + constexpr index_t NumAccValue = S::NumAccValue; + + // Helper lambdas for the base A-encoding and the swizzled variant + auto base_enc = []() { + if constexpr(S::Impl::kAMBlock == 1 && S::Impl::kBNBlock == 1) + { + if constexpr(NumAccValue == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple::AMLane>, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else + { + static_assert(KShape::kKPerThread % NumAccValue == 0, + "kKPerThread must be divisible by NumAccess"); + return tile_distribution_encoding< + sequence<>, + tuple::AMLane>, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } + } + else if constexpr(S::Impl::kAMBlock == 1 && 1 < S::Impl::kBNBlock) + { + static_assert(NumAccValue == 1, + "Multiple access is not supported when using multi-block"); + return tile_distribution_encoding< + sequence, + tuple::AMLane>, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else // 1 < AMBlock && BNBlock == 1 + { + static_assert(NumAccValue == 1, + "Multiple access is not supported when using multi-block"); + return tile_distribution_encoding< + sequence<>, + tuple::AMLane>, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + }; + + auto swizzled_enc = []() { + return tile_distribution_encoding< + sequence<>, + tuple< + sequence, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + }; + + if constexpr(S::Swizzle == SwizzleKind::A) + return swizzled_enc(); + else + return base_enc(); +} + +// B encodings with NumAccess and IterateK folded, with multi-block and swizzle +template +CK_TILE_DEVICE static constexpr auto MakeBWarpDstrEncoding() +{ + constexpr index_t NumAccValue = S::NumAccValue; + + auto base_enc = []() { + if constexpr(S::Impl::kAMBlock == 1 && S::Impl::kBNBlock == 1) + { + if constexpr(NumAccValue == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple::BNLane>, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else + { + static_assert(KShape::kKPerThread % NumAccValue == 0, + "kKPerThread must be divisible by NumAccess"); + return tile_distribution_encoding< + sequence<>, + tuple::BNLane>, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } + } + else if constexpr(S::Impl::kAMBlock == 1 && 1 < S::Impl::kBNBlock) + { + static_assert(NumAccValue == 1, + "Multiple access is not supported when using multi-block"); + return tile_distribution_encoding< + sequence<>, + tuple::BNLane>, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else // 1 < AMBlock && BNBlock == 1 + { + static_assert(NumAccValue == 1, + "Multiple access is not supported when using multi-block"); + return tile_distribution_encoding< + sequence, + tuple::BNLane>, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + }; + + auto swizzled_enc = []() { + return tile_distribution_encoding< + sequence<>, + tuple< + sequence, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + }; + + if constexpr(S::Swizzle == SwizzleKind::B) + return swizzled_enc(); + else + return base_enc(); +} + +// C distribution encoding with transpose and multi-block + +template +CK_TILE_DEVICE static constexpr auto MakeCWarpDstrEncoding() +{ + constexpr bool HasSwizzle = (S::Swizzle == SwizzleKind::A) || (S::Swizzle == SwizzleKind::B); + + auto make_m_splits = []() { + if constexpr(HasSwizzle) + { + // Swizzled M splits + return sequence{}; + } + else + { + return sequence{}; + } + }; + + if constexpr(S::Impl::kAMBlock == 1 && S::Impl::kBNBlock == 1) + { + if constexpr(!S::TransposeC) + { + return tile_distribution_encoding< + sequence<>, + tuple>, + tuple>, + tuple>, + sequence<1, 1>, + sequence<0, 2>>{}; + } + else // TransposeC + { + return tile_distribution_encoding< + sequence<>, + tuple, decltype(make_m_splits())>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } + } + else if constexpr(S::Impl::kAMBlock == 1 && 1 < S::Impl::kBNBlock) + { + if constexpr(!S::TransposeC) + { + return tile_distribution_encoding< + sequence<>, + tuple>, + tuple>, + tuple>, + sequence<1, 1>, + sequence<0, 2>>{}; + } + else + { + return tile_distribution_encoding< + sequence<>, + tuple, decltype(make_m_splits())>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } + } + else if constexpr(1 < S::Impl::kAMBlock && S::Impl::kBNBlock == 1) + { + if constexpr(!S::TransposeC) + { + return tile_distribution_encoding, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 1>, + sequence<0, 2>>{}; + } + else + { + return tile_distribution_encoding, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } + } +} + +// Detect smfmac by Impl type: provide a small trait that is true for smfmac attribute impls +template +struct is_smfmac_impl : std::false_type +{ +}; +template +struct is_smfmac_impl< + WarpGemmAttributeSmfmacImpl> + : std::true_type +{ +}; + +// Final composed attribute +// Primary (MFMA) and SMFMA-specialized ComposedAttribute +template ::value> +struct ComposedAttribute +{ + using Impl = typename S::Impl; + using ATypes = ABTypes; + using CTypesT = CTypes; + + static constexpr index_t kM = KShape::kM; + static constexpr index_t kN = KShape::kN; + static constexpr index_t kK = KShape::kK; + static constexpr index_t kKPerThread = KShape::kKPerThread; + static constexpr index_t kCMLane = KShape::kCMLane; + + using ADataType = typename ATypes::ADataType; + using BDataType = typename ATypes::BDataType; + using CDataType = typename CTypesT::CDataType; + + using AVecType = typename ATypes::AVecType; + using BVecType = typename ATypes::BVecType; + using CVecType = typename CTypesT::CVecType; + + using AWarpDstrEncoding = decltype(MakeAWarpDstrEncoding()); + using BWarpDstrEncoding = decltype(MakeBWarpDstrEncoding()); + using CWarpDstrEncoding = decltype(MakeCWarpDstrEncoding()); + + CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return S::KIter; } + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + if constexpr(S::KIter == 1) + { + if constexpr(S::TransposeC) + { + // When TransposeC, composed A/B are swapped relative to Impl. + // Pass b_vec as Impl::AVecType and a_vec as Impl::BVecType. + Impl{}(c_vec, + reinterpret_cast(b_vec), + reinterpret_cast(a_vec), + bool_constant{}); + } + else + { + Impl{}(c_vec, + reinterpret_cast(a_vec), + reinterpret_cast(b_vec), + bool_constant{}); + } + } + else + { + using buf_a = thread_buffer; + using buf_b = thread_buffer; + + static_for<0, S::KIter, 1>{}([&](auto iKIter) { + if constexpr(S::TransposeC) + { + // Swap mapping: b_vec -> Impl::AVecType, a_vec -> Impl::BVecType + Impl{}(c_vec, + reinterpret_cast(b_vec) + .template get_as()[iKIter], + reinterpret_cast(a_vec) + .template get_as()[iKIter], + bool_constant{}); + } + else + { + Impl{}(c_vec, + reinterpret_cast(a_vec) + .template get_as()[iKIter], + reinterpret_cast(b_vec) + .template get_as()[iKIter], + bool_constant{}); + } + }); + } + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + if constexpr(S::KIter == 1) + { + if constexpr(S::TransposeC) + { + // Swap mapping: b_vec -> Impl::AVecType, a_vec -> Impl::BVecType + return Impl{}(reinterpret_cast(b_vec), + reinterpret_cast(a_vec)); + } + else + { + return Impl{}(reinterpret_cast(a_vec), + reinterpret_cast(b_vec)); + } + } + else + { + using buf_a = thread_buffer; + using buf_b = thread_buffer; + constexpr auto I0 = number<0>{}; + + CVecType c_vec; + if constexpr(S::TransposeC) + { + // Swap mapping: b_vec -> Impl::AVecType, a_vec -> Impl::BVecType + c_vec = Impl{}(reinterpret_cast(b_vec) + .template get_as()[I0], + reinterpret_cast(a_vec) + .template get_as()[I0]); + + static_for<1, S::KIter, 1>{}([&](auto iKIter) { + Impl{}(c_vec, + reinterpret_cast(b_vec) + .template get_as()[iKIter], + reinterpret_cast(a_vec) + .template get_as()[iKIter]); + }); + } + else + { + c_vec = Impl{}(reinterpret_cast(a_vec) + .template get_as()[I0], + reinterpret_cast(b_vec) + .template get_as()[I0]); + + static_for<1, S::KIter, 1>{}([&](auto iKIter) { + Impl{}(c_vec, + reinterpret_cast(a_vec) + .template get_as()[iKIter], + reinterpret_cast(b_vec) + .template get_as()[iKIter]); + }); + } + return c_vec; + } + } +}; + +// SMFMA specialization: forbid Transpose and Swizzle for now; KIter must be 1 +// TODO: enable swizzle, transpose for smfmac +template +struct ComposedAttribute +{ + using Impl = typename S::Impl; + + static_assert(!S::TransposeC, "smfmac TransposeC is not supported in composed attributes"); + static_assert(S::Swizzle == SwizzleKind::None, + "smfmac Swizzle is not supported in composed attributes"); + static_assert(S::KIter == 1, "smfmac IterateK is not supported (KIter must be 1)"); + + using ADataType = typename Impl::ADataType; + using BDataType = typename Impl::BDataType; + using CDataType = typename Impl::CDataType; + + using AVecType = typename Impl::AVecType; + using BVecType = typename Impl::BVecType; + using CVecType = typename Impl::CVecType; + + static constexpr index_t kM = Impl::kM; + static constexpr index_t kN = Impl::kN; + static constexpr index_t kK = Impl::kK; + static constexpr index_t kKPerThread = Impl::kABKPerLane; + static constexpr index_t kCMLane = Impl::kCMLane; + static constexpr index_t kCompressionRatio = Impl::CompressionRatio; + + // Reuse the encodings defined by the smfmac attribute wrapper for consistency + using AWarpDstrEncoding = typename WarpGemmAttributeSmfmac::AWarpDstrEncoding; + using BWarpDstrEncoding = typename WarpGemmAttributeSmfmac::BWarpDstrEncoding; + using CWarpDstrEncoding = typename WarpGemmAttributeSmfmac::CWarpDstrEncoding; + + CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } + + // c_vec += a_vec * b_vec[idx] + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + const int32_t& idx, + bool_constant = {}) const + { + Impl{}(c_vec, a_vec, b_vec, idx, bool_constant{}); + } + + // c_vec = a_vec * b_vec[idx] + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, + const BVecType& b_vec, + const int32_t& idx) const + { + CVecType c_vec{0}; + Impl{}(c_vec, a_vec, b_vec, idx); + return c_vec; + } +}; + +// Policy wrappers produce a new State from an old one + +template +struct PolicyIterateK +{ + template + using apply = + State; +}; + +struct PolicyTransposeC +{ + template + using apply = State; +}; + +template +struct PolicySwizzleA +{ + template + using apply = + State; +}; + +template +struct PolicySwizzleB +{ + template + using apply = + State; +}; + +} // namespace wg_attr_compose +} // namespace detail +} // namespace ck_tile + +// High-level alias to construct a composed WarpGemm attribute via CoreDispatcher and policies +#include "ck_tile/ops/gemm/warp/warp_gemm_core_dispatcher.hpp" + +namespace ck_tile { + +// Helper that uses core dispatcher to select MFMA vs smfmac and composes policies only +template +struct ComposePolicies +{ + using CD = WarpGemmCoreDispatcher; + using Impl = typename CD::Impl; + + static_assert(Impl::kK > 0, "Invalid K dimension"); + static_assert(Impl::kK <= KPerWave, "KPerWave must smaller and equal than Impl::kK"); + + static constexpr index_t KIter = KPerWave / Impl::kK; + + static constexpr bool IsSmfmac = detail::wg_attr_compose::is_smfmac_impl::value; + + // First, setup a default state as the base state. + using S0 = detail::wg_attr_compose::BaseState; + + // The order of policies matters so we compose them with the order: Kiter->TransposeC->SwizzleA. + using S1 = std::conditional_t< + (KIter == 1), + S0, + typename detail::wg_attr_compose::PolicyIterateK::template apply>; + using S2 = + std::conditional_t, + S1>; + // Match dispatcher behavior: when TransposeC is enabled, the swizzle applies to B + // (i.e., SwizzleBTransposedCDistribution). Otherwise apply swizzle to A. + using S3 = std::conditional_t< + SwizzleA && !TransposeC, + typename detail::wg_attr_compose::PolicySwizzleA<2>::template apply, + std::conditional_t::template apply, + S2>>; + + // For SMFMA, use the dedicated WarpGemmSmfmacImpl wrapper with the SMFMA attribute. + // For MFMA (default), use the regular WarpGemmImpl over the composed attribute. + using type = std::conditional_t< + IsSmfmac, + ck_tile::WarpGemmSmfmacImpl>, + ck_tile::WarpGemmImpl>>; +}; + +// Wrapper struct to match usage as MakeWarpGemm<...>::Type +template +struct MakeWarpGemm +{ + using Type = typename ComposePolicies::type; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index bd65f53383..d9bdaea300 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -22,6 +22,16 @@ enum class WGAttrCtlEnum // raw_a_a_a = 3, // c-agpr, a-agpr, b-agpr }; +// Primary template: generic MFMA warp-gemm attribute (to be specialized per supported shape) +template +struct WarpGemmAttributeMfmaImpl; // no definition, only specializations are provided + #define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_) \ if constexpr(post_nop_) \ { \ @@ -191,8 +201,8 @@ struct WarpGemmAttributeMfmaImplF32F32F32M32N32K2 }; // V_MFMA_F32_16x16x32_BF16 -template -struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32 +template +struct WarpGemmAttributeMfmaImpl { static constexpr WGAttrCtlEnum Ctrl = Ctrl_; using ADataType = bf16_t; @@ -254,8 +264,8 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32 } }; // FP16 -template -struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 +template +struct WarpGemmAttributeMfmaImpl { static constexpr WGAttrCtlEnum Ctrl = Ctrl_; using ADataType = fp16_t; @@ -317,8 +327,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 } }; -template -struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 +template +struct WarpGemmAttributeMfmaImpl { static constexpr WGAttrCtlEnum Ctrl = Ctrl_; using ADataType = fp16_t; @@ -380,8 +390,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 } }; -template -struct WarpGemmAttributeMfmaImplF16F16F32M16N16K32 +template +struct WarpGemmAttributeMfmaImpl { static constexpr WGAttrCtlEnum Ctrl = Ctrl_; using ADataType = fp16_t; @@ -443,8 +453,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K32 } }; -template -struct WarpGemmAttributeMfmaImplF16F16F32M4N64K4 +template +struct WarpGemmAttributeMfmaImpl { static constexpr WGAttrCtlEnum Ctrl = Ctrl_; using ADataType = fp16_t; @@ -507,8 +517,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M4N64K4 } }; -template -struct WarpGemmAttributeMfmaImplF16F16F32M64N4K4 +template +struct WarpGemmAttributeMfmaImpl { static constexpr WGAttrCtlEnum Ctrl = Ctrl_; using ADataType = fp16_t; @@ -572,8 +582,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M64N4K4 }; // Bf16 -template -struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 +template +struct WarpGemmAttributeMfmaImpl { static constexpr WGAttrCtlEnum Ctrl = Ctrl_; using ADataType = bf16_t; @@ -661,8 +671,8 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 } }; -template -struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 +template +struct WarpGemmAttributeMfmaImpl { static constexpr WGAttrCtlEnum Ctrl = Ctrl_; using ADataType = bf16_t; @@ -749,8 +759,8 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 } }; -template -struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4 +template +struct WarpGemmAttributeMfmaImpl { static constexpr WGAttrCtlEnum Ctrl = Ctrl_; using ADataType = bf16_t; @@ -839,8 +849,8 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4 } }; -template -struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4 +template +struct WarpGemmAttributeMfmaImpl { static constexpr WGAttrCtlEnum Ctrl = Ctrl_; using ADataType = bf16_t; @@ -930,8 +940,8 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4 }; // gfx950 -template -struct WarpGemmAttributeMfmaImplF16F16F32M32N32K16 +template +struct WarpGemmAttributeMfmaImpl { static constexpr WGAttrCtlEnum Ctrl = Ctrl_; using ADataType = fp16_t; @@ -1044,8 +1054,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K16 } }; -template -struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16 +template +struct WarpGemmAttributeMfmaImpl { static constexpr WGAttrCtlEnum Ctrl = Ctrl_; using ADataType = bf16_t; @@ -1158,6 +1168,60 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16 } }; +// --------------------------------------------------------------------------------- +// Backward-compatibility aliases (preserve original names) +// --------------------------------------------------------------------------------- + +// BF16 +template +using WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32 = + WarpGemmAttributeMfmaImpl; + +template +using WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 = + WarpGemmAttributeMfmaImpl; + +template +using WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 = + WarpGemmAttributeMfmaImpl; + +template +using WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4 = + WarpGemmAttributeMfmaImpl; + +template +using WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4 = + WarpGemmAttributeMfmaImpl; + +template +using WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16 = + WarpGemmAttributeMfmaImpl; + +// F16 +template +using WarpGemmAttributeMfmaImplF16F16F32M32N32K8 = + WarpGemmAttributeMfmaImpl; + +template +using WarpGemmAttributeMfmaImplF16F16F32M16N16K16 = + WarpGemmAttributeMfmaImpl; + +template +using WarpGemmAttributeMfmaImplF16F16F32M16N16K32 = + WarpGemmAttributeMfmaImpl; + +template +using WarpGemmAttributeMfmaImplF16F16F32M4N64K4 = + WarpGemmAttributeMfmaImpl; + +template +using WarpGemmAttributeMfmaImplF16F16F32M64N4K4 = + WarpGemmAttributeMfmaImpl; + +template +using WarpGemmAttributeMfmaImplF16F16F32M32N32K16 = + WarpGemmAttributeMfmaImpl; + // FP8 template struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base @@ -1318,6 +1382,31 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base } }; +// Map FP8/BF8 shapes to the primary template via specializations (reuse base impl) +template +struct WarpGemmAttributeMfmaImpl + : WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base +{ +}; + +template +struct WarpGemmAttributeMfmaImpl + : WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base +{ +}; + +template +struct WarpGemmAttributeMfmaImpl + : WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base +{ +}; + +template +struct WarpGemmAttributeMfmaImpl + : WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base +{ +}; + template struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base { @@ -1501,12 +1590,38 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base } }; +// Map FP8/BF8 32x32x16 shapes to the primary template via specializations (reuse base impl) +template +struct WarpGemmAttributeMfmaImpl + : WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base +{ +}; + +template +struct WarpGemmAttributeMfmaImpl + : WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base +{ +}; + +template +struct WarpGemmAttributeMfmaImpl + : WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base +{ +}; + +template +struct WarpGemmAttributeMfmaImpl + : WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base +{ +}; + +// Back-compat aliases now point to primary-template specializations template using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8 = - WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; + WarpGemmAttributeMfmaImpl; template using WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8 = - WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base; + WarpGemmAttributeMfmaImpl; template using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8 = WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; @@ -1516,7 +1631,7 @@ using WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_bf8 = template using WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8 = - WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base; + WarpGemmAttributeMfmaImpl; template using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8 = @@ -1555,6 +1670,130 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4 static constexpr index_t kCM0PerLane = 1; static constexpr index_t kCM1PerLane = 4; + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + //__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a, + // opsel, scale_b) +#if defined(__gfx950__) + if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, c_vec, 0, 0, 0, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, c_vec, 0, 1, 0, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, c_vec, 1, 0, 0, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, c_vec, 1, 1, 0, 0, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx950__) + if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 0, 0, 0, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 0, 1, 0, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 1, 0, 0, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 1, 1, 0, 0, 0, 0)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; + +// Map FP8/BF8 16x16x128 scale shapes to the primary template via specializations +template +struct WarpGemmAttributeMfmaImpl + : WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base +{ +}; + +template +struct WarpGemmAttributeMfmaImpl + : WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base +{ +}; + +template +struct WarpGemmAttributeMfmaImpl + : WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base +{ +}; + +template +struct WarpGemmAttributeMfmaImpl + : WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base +{ +}; + +// Back-compat aliases +template +using WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_fp8 = + WarpGemmAttributeMfmaImpl; + +template +using WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_bf8 = + WarpGemmAttributeMfmaImpl; + +template +using WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_fp8 = + WarpGemmAttributeMfmaImpl; + +template +using WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8 = + WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base; + +template +struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = pk_fp4_t; + using BDataType = pk_fp4_t; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 16; + static constexpr index_t kN = 16; + static constexpr index_t kK = 128; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 4; + static constexpr index_t kABKPerLane = 32; + + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + // c_vec += a_vec * b_vec template CK_TILE_DEVICE void operator()(CVecType& c_vec, @@ -1716,23 +1955,49 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base } }; +// Map FP8/BF8 32x32x64 scale shapes to the primary template via specializations +template +struct WarpGemmAttributeMfmaImpl + : WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base +{ +}; + +template +struct WarpGemmAttributeMfmaImpl + : WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base +{ +}; + +template +struct WarpGemmAttributeMfmaImpl + : WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base +{ +}; + +template +struct WarpGemmAttributeMfmaImpl + : WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base +{ +}; + +// Back-compat aliases template using WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8 = - WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base; + WarpGemmAttributeMfmaImpl; template using WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8 = - WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base; + WarpGemmAttributeMfmaImpl; template using WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8 = - WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base; + WarpGemmAttributeMfmaImpl; template using WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8 = - WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base; + WarpGemmAttributeMfmaImpl; -// int8 +// int8: map shapes to the primary template via specializations (reuse the concrete impl bodies) template struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8 { @@ -1980,6 +2245,31 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x32_i8 } }; +// Primary-template specializations delegating to the bodies above +template +struct WarpGemmAttributeMfmaImpl + : WarpGemmAttributeMfmaImpl_i32_32x32x16_i8 +{ +}; + +template +struct WarpGemmAttributeMfmaImpl + : WarpGemmAttributeMfmaImpl_i32_16x16x32_i8 +{ +}; + +template +struct WarpGemmAttributeMfmaImpl + : WarpGemmAttributeMfmaImpl_i32_16x16x64_i8 +{ +}; + +template +struct WarpGemmAttributeMfmaImpl + : WarpGemmAttributeMfmaImpl_i32_32x32x32_i8 +{ +}; + #undef DISPATCH_MFMA_ } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp index 72cbf37206..3c7edc6627 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp @@ -81,5 +81,12 @@ struct WarpGemmAttributeSmfmac { Impl{}(c_vec, a_vec, b_vec, idx, bool_constant{}); } + + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, + const BVecType& b_vec, + const int32_t& idx) const + { + return Impl{}(a_vec, b_vec, idx); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp index d45abae887..7a9909720d 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp @@ -7,10 +7,19 @@ namespace ck_tile { -// fp16 2:4 structured sparsity +// Primary template: generic Smfmac warp-gemm attribute (to be specialized per supported shape) +template +struct WarpGemmAttributeSmfmacImpl; // no definition, only specializations are provided -template -struct WarpGemmAttributeSmfmacImplF16F16F32M32N32K16 +// fp16 2:4 structured sparsity +template +struct WarpGemmAttributeSmfmacImpl { static constexpr WGAttrCtlEnum Ctrl = Ctrl_; using ADataType = fp16_t; @@ -56,12 +65,28 @@ struct WarpGemmAttributeSmfmacImplF16F16F32M32N32K16 ck_tile::ignore = a_vec; ck_tile::ignore = b_vec; ck_tile::ignore = idx; +#endif + } + + // c_vec = a_vec * b_vec[idx] + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, + const BVecType& b_vec, + const int32_t& idx) const + { +#if defined(__gfx94_) or defined(__gfx95_) + return bit_cast( + __builtin_amdgcn_smfmac_f32_32x32x16_f16(a_vec, b_vec, fp32x4_t{0.f}, idx, 0, 0)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + ck_tile::ignore = idx; + return CVecType{0.f}; #endif } }; -template -struct WarpGemmAttributeSmfmacImplF16F16F32M16N16K32 +template +struct WarpGemmAttributeSmfmacImpl { static constexpr WGAttrCtlEnum Ctrl = Ctrl_; using ADataType = fp16_t; @@ -107,8 +132,33 @@ struct WarpGemmAttributeSmfmacImplF16F16F32M16N16K32 ck_tile::ignore = a_vec; ck_tile::ignore = b_vec; ck_tile::ignore = idx; +#endif + } + + // c_vec = a_vec * b_vec[idx] + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, + const BVecType& b_vec, + const int32_t& idx) const + { +#if defined(__gfx94_) or defined(__gfx95_) + return bit_cast( + __builtin_amdgcn_smfmac_f32_16x16x32_f16(a_vec, b_vec, fp32x4_t{0.f}, idx, 0, 0)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + ck_tile::ignore = idx; + return CVecType{0.f}; #endif } }; +// Back-compat aliases +template +using WarpGemmAttributeSmfmacImplF16F16F32M16N16K32 = + WarpGemmAttributeSmfmacImpl; + +template +using WarpGemmAttributeSmfmacImplF16F16F32M32N32K16 = + WarpGemmAttributeSmfmacImpl; + } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_core_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_core_dispatcher.hpp new file mode 100644 index 0000000000..01af6b32ed --- /dev/null +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_core_dispatcher.hpp @@ -0,0 +1,271 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp" +// Add smfmac attribute impl for structured sparsity support +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp" + +namespace ck_tile { + +// WarpGemmCoreDispatcher: choose the underlying MFMA/SMFMAC attribute Impl +// based on A/B/Acc types and (M,N,K) per wave. +template +struct WarpGemmCoreDispatcher; + +// Generic specialization for MFMA +template +struct WarpGemmCoreDispatcher +{ + using Impl = WarpGemmAttributeMfmaImpl; +}; + +// Generic specialization for SMFMAC +// TODO: we also need to support smfmac for FP8/BF8 and I8 format +template +struct WarpGemmCoreDispatcher +{ + using Impl = WarpGemmAttributeSmfmacImpl; +}; + +// Specialization for special cases +template <> +struct WarpGemmCoreDispatcher +{ +#if defined(__gfx950__) + using Impl = WarpGemmAttributeMfmaImpl; +#else + using Impl = WarpGemmAttributeMfmaImpl; +#endif +}; + +template <> +struct WarpGemmCoreDispatcher +{ +#if defined(__gfx950__) + using Impl = WarpGemmAttributeMfmaImpl; +#else + using Impl = WarpGemmAttributeMfmaImpl; +#endif +}; + +template <> +struct WarpGemmCoreDispatcher +{ + using Impl = WarpGemmAttributeMfmaImpl; +}; + +template <> +struct WarpGemmCoreDispatcher +{ + using Impl = WarpGemmAttributeMfmaImpl; +}; + +template <> +struct WarpGemmCoreDispatcher +{ + using Impl = WarpGemmAttributeMfmaImpl; +}; + +template <> +struct WarpGemmCoreDispatcher +{ + using Impl = WarpGemmAttributeMfmaImpl; +}; + +template <> +struct WarpGemmCoreDispatcher +{ + using Impl = WarpGemmAttributeMfmaImpl; +}; + +template <> +struct WarpGemmCoreDispatcher +{ + using Impl = WarpGemmAttributeMfmaImpl; +}; + +template <> +struct WarpGemmCoreDispatcher +{ +#if defined(__gfx950__) + using Impl = WarpGemmAttributeMfmaImpl; +#else + using Impl = WarpGemmAttributeMfmaImpl; +#endif +}; + +template <> +struct WarpGemmCoreDispatcher +{ +#if defined(__gfx950__) + using Impl = WarpGemmAttributeMfmaImpl; +#else + using Impl = WarpGemmAttributeMfmaImpl; +#endif +}; + +template <> +struct WarpGemmCoreDispatcher +{ + using Impl = WarpGemmAttributeMfmaImpl; +}; + +template <> +struct WarpGemmCoreDispatcher +{ + using Impl = WarpGemmAttributeMfmaImpl; +}; + +// Iterate-K variants built from the direct impls +template <> +struct WarpGemmCoreDispatcher +{ + using Impl = WarpGemmAttributeMfmaImpl; +}; + +template <> +struct WarpGemmCoreDispatcher +{ + using Impl = WarpGemmAttributeMfmaImpl; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index d6c21e88b5..720b11b3f6 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -6,7 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp" - +#include "ck_tile/ops/gemm/warp/detail/warp_gemm_attribute_mfma_compose.hpp" namespace ck_tile { namespace impl { @@ -163,16 +163,30 @@ template -using WarpGemmDispatcher = typename impl::warp_gemm_dispatcher::Dispatcher< // - AType, - BType, - AccType, - MPerWave, - NPerWave, - KPerWave, - TransposeC, - SwizzleA, - UseStructuredSparsity, - AttrNumAccess>::Type; +using WarpGemmDispatcher = +#if defined(CK_TILE_ROUTE_WARP_GEMM_DISPATCHER_TO_MAKE) + typename MakeWarpGemm::Type; +#else + using WarpGemmDispatcher = typename impl::warp_gemm_dispatcher::Dispatcher< // + AType, + BType, + AccType, + MPerWave, + NPerWave, + KPerWave, + TransposeC, + SwizzleA, + UseStructuredSparsity, + AttrNumAccess>::Type; +#endif } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp index 3d64e148c4..c9828a0b43 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp @@ -105,6 +105,38 @@ struct WarpGemmSmfmacImpl c.get_thread_buffer().template set_as(I0, c_vec); } + + template + CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const + { + using CTensor = CWarpTensor; + static_assert(detail::is_similiar_distributed_tensor_v && + detail::is_similiar_distributed_tensor_v); + constexpr auto CompressionRatio = WarpGemmAttribute::kCompressionRatio; + + using AVec = ext_vector_t; + using AVecCompressed = + ext_vector_t; + using BVec = ext_vector_t; + using CVec = ext_vector_t; + + constexpr auto I0 = number<0>{}; + + auto a_vec = a.get_thread_buffer().template get_as()[I0]; + const auto b_vec = b.get_thread_buffer().template get_as()[I0]; + + const int32_t idx = compress_a(a_vec); + + // @TODO can we simply set a_vec_pruned to a_vec[0:3]? + const AVecCompressed a_vec_pruned = {a_vec[0], a_vec[1], a_vec[2], a_vec[3]}; + + // c_vec = a_vec * b_vec[idx] + auto c_vec = WarpGemmAttribute{}(a_vec_pruned, b_vec, idx); + + CTensor c; + c.get_thread_buffer().template set_as(I0, c_vec); + return c; + } }; } // namespace ck_tile diff --git a/test/ck_tile/gemm/CMakeLists.txt b/test/ck_tile/gemm/CMakeLists.txt index ee23ad2f63..b8b77e0b88 100644 --- a/test/ck_tile/gemm/CMakeLists.txt +++ b/test/ck_tile/gemm/CMakeLists.txt @@ -6,6 +6,7 @@ set(EXAMPLE_GEMM_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() +option(CK_TILE_TEST_ROUTE_DISPATCHER_TO_MAKE "Route ck_tile::WarpGemmDispatcher to MakeWarpGemm in tests" OFF) set(EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS) if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS -DCK_TILE_USE_OCP_FP8) @@ -16,6 +17,14 @@ list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS ) set(EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS}) +# Build routed variants of the common options when the switch is ON +set(EXAMPLE_GEMM_COMPILE_OPTIONS_WITHOUT_ROUTED ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +if(CK_TILE_TEST_ROUTE_DISPATCHER_TO_MAKE) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_ROUTE_WARP_GEMM_DISPATCHER_TO_MAKE) + list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS -DCK_TILE_ROUTE_WARP_GEMM_DISPATCHER_TO_MAKE) +endif() + + if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12") if(GPU_TARGETS MATCHES "gfx94|gfx95") add_gtest_executable(test_ck_tile_gemm_pipeline_mem test_gemm_pipeline_mem.cpp) @@ -59,3 +68,38 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12") else() message(DEBUG "Skipping ck_tile_gemm tests for current target test_ck_tile_gemm_pipeline") endif() + + +# MFMA (Attribute Compose) - STRICTLY CDNA (gfx90a, gfx94x, gfx950) +if(GPU_TARGETS MATCHES "gfx90a|gfx94|gfx95") + add_gtest_executable(test_ck_tile_gemm_attr_compose_int8_mfma test_warp_gemm_attr_compose_int8_mfma.cpp) + target_compile_options(test_ck_tile_gemm_attr_compose_int8_mfma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS_WITHOUT_ROUTED}) +endif() + +# MFMA (Attribute Compose) - STRICTLY CDNA3+ (gfx94x, gfx950) +if(GPU_TARGETS MATCHES "gfx94|gfx95") + add_gtest_executable(test_ck_tile_gemm_attr_compose_fp8_mfma test_warp_gemm_attr_compose_fp8_mfma.cpp) + target_compile_options(test_ck_tile_gemm_attr_compose_fp8_mfma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS_WITHOUT_ROUTED}) + + add_gtest_executable(test_ck_tile_gemm_attr_compose_bf8_mfma test_warp_gemm_attr_compose_bf8_mfma.cpp) + target_compile_options(test_ck_tile_gemm_attr_compose_bf8_mfma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS_WITHOUT_ROUTED}) + + add_gtest_executable(test_ck_tile_gemm_attr_compose_mixed_fp8_bf8_mfma test_warp_gemm_attr_compose_mixed_fp8_bf8_mfma.cpp) + target_compile_options(test_ck_tile_gemm_attr_compose_mixed_fp8_bf8_mfma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS_WITHOUT_ROUTED}) +endif() + +# MFMA (Attribute Compose) - STRICTLY CDNA (gfx90a, gfx94x, gfx950) +if(GPU_TARGETS MATCHES "gfx90a|gfx94|gfx95") + add_gtest_executable(test_ck_tile_gemm_attr_compose_fp16_mfma test_warp_gemm_attr_compose_fp16_mfma.cpp) + target_compile_options(test_ck_tile_gemm_attr_compose_fp16_mfma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS_WITHOUT_ROUTED}) + + add_gtest_executable(test_ck_tile_gemm_attr_compose_bf16_mfma test_warp_gemm_attr_compose_bf16_mfma.cpp) + target_compile_options(test_ck_tile_gemm_attr_compose_bf16_mfma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS_WITHOUT_ROUTED}) +endif() + +# --- SMFMAC Tests --- +# SMFMAC (Attribute Compose) - STRICTLY CDNA3+ (gfx94x, gfx950) +if(GPU_TARGETS MATCHES "gfx94|gfx95") + add_gtest_executable(test_ck_tile_gemm_attr_compose_fp16_smfmac test_warp_gemm_attr_compose_fp16_smfmac.cpp) + target_compile_options(test_ck_tile_gemm_attr_compose_fp16_smfmac PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS_WITHOUT_ROUTED}) +endif() \ No newline at end of file diff --git a/test/ck_tile/gemm/test_warp_gemm_attr_compose.hpp b/test/ck_tile/gemm/test_warp_gemm_attr_compose.hpp new file mode 100644 index 0000000000..c6f52f6f50 --- /dev/null +++ b/test/ck_tile/gemm/test_warp_gemm_attr_compose.hpp @@ -0,0 +1,301 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_TILE_TEST_WARP_GEMM_ATTR_COMPOSE_HPP +#define CK_TILE_TEST_WARP_GEMM_ATTR_COMPOSE_HPP + +#include +#include +#include +#include "ck_tile/ops/gemm/warp/detail/warp_gemm_attribute_mfma_compose.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp" +// For real kernel run and verification +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" + +// ---------------------- Typed tests for Dispatcher coverage ---------------------- +template +struct WGDispCase +{ + using AType = A; + using BType = B; + using AccType = Acc; + static constexpr ck_tile::index_t MPerWave = M; + static constexpr ck_tile::index_t NPerWave = N; + static constexpr ck_tile::index_t KPerWave = K; + static constexpr bool kTransposeC = TransposeC; + static constexpr bool kSwizzleA = SwizzleA; + static constexpr bool kUSS = UseStructuredSparsity; + static constexpr ck_tile::WGAttrNumAccessEnum kNA = NA; +}; + +template +class WGCompileTimeTest : public ::testing::Test +{ + public: + void RunTest() + { + // Dispatcher-selected WarpGemm + using Disp = typename ck_tile::WarpGemmDispatcher; + + // Factory-selected WarpGemm (MakeWarpGemm) + using Make = typename ck_tile::MakeWarpGemm::Type; + + // 1) Scalar compile-time constants must match + static_assert(Disp::kM == Make::kM, "kM differs between Dispatcher and MakeWarpGemm"); + static_assert(Disp::kN == Make::kN, "kN differs between Dispatcher and MakeWarpGemm"); + static_assert(Disp::kK == Make::kK, "kK differs between Dispatcher and MakeWarpGemm"); + static_assert(Disp::kKPerThread == Make::kKPerThread, + "kKPerThread differs between Dispatcher and MakeWarpGemm"); + static_assert(Disp::get_num_of_access() == Make::get_num_of_access(), + "get_num_of_access() differs between Dispatcher and MakeWarpGemm"); + + // 2) Data types must match + static_assert(std::is_same_v, + "ADataType differs between Dispatcher and MakeWarpGemm"); + static_assert(std::is_same_v, + "BDataType differs between Dispatcher and MakeWarpGemm"); + static_assert(std::is_same_v, + "CDataType differs between Dispatcher and MakeWarpGemm"); + + // 3) Distribution encodings must match (ensures identical warp tiling/layout) + static_assert( + std::is_same_v, + "AWarpDstrEncoding differs between Dispatcher and MakeWarpGemm"); + static_assert( + std::is_same_v, + "BWarpDstrEncoding differs between Dispatcher and MakeWarpGemm"); + static_assert( + std::is_same_v, + "CWarpDstrEncoding differs between Dispatcher and MakeWarpGemm"); + + // 4) Final tensor types must match (encodes DataType + Distribution) + static_assert(std::is_same_v, + "AWarpTensor differs between Dispatcher and MakeWarpGemm"); + static_assert(std::is_same_v, + "BWarpTensor differs between Dispatcher and MakeWarpGemm"); + static_assert(std::is_same_v, + "CWarpTensor differs between Dispatcher and MakeWarpGemm"); + + SUCCEED(); + } +}; + +// ---------------------- Runtime tests: Compare Dispatcher (MFMA and SMFMA only) vs MakeWarpGemm vs +// CPU ---------------------- +// ---------------------- Runtime operator() behavior tests on GPU ---------------------- +template +struct WarpGemmKernel +{ + static constexpr int kBlockSize = 64; + __device__ void operator()(const AType* A, const BType* B, CType* C) const + { + using WarpGemm = std::conditional_t::Type, + ck_tile::WarpGemmDispatcher>; + + // A: [M,K] row-major (packed) + const auto a_view = + ck_tile::make_naive_tensor_view_packed( + const_cast(A), ck_tile::make_tuple(M, K)); + // B: expose as logical [N,K] with strides (1, N) over the original row-major [K,N] buffer + const auto b_view = ck_tile::make_naive_tensor_view( + const_cast(B), ck_tile::make_tuple(N, K), ck_tile::make_tuple(1, N)); + // C: [M,N] row-major (packed) + const auto c_view = + ck_tile::make_naive_tensor_view_packed( + const_cast(C), ck_tile::make_tuple(M, N)); + + using AWarpTensor = typename WarpGemm::AWarpTensor; + using BWarpTensor = typename WarpGemm::BWarpTensor; + using CWarpTensor = typename WarpGemm::CWarpTensor; + + constexpr auto a_len = AWarpTensor::get_tile_distribution().get_lengths(); + constexpr auto b_len = BWarpTensor::get_tile_distribution().get_lengths(); + constexpr auto c_len = CWarpTensor::get_tile_distribution().get_lengths(); + + auto a_win = ck_tile::make_tile_window( + a_view, a_len, ck_tile::make_multi_index(0, 0), AWarpTensor::get_tile_distribution()); + auto b_win = ck_tile::make_tile_window( + b_view, b_len, ck_tile::make_multi_index(0, 0), BWarpTensor::get_tile_distribution()); + auto c_win = ck_tile::make_tile_window( + c_view, c_len, ck_tile::make_multi_index(0, 0), CWarpTensor::get_tile_distribution()); + + AWarpTensor a_tile; + BWarpTensor b_tile; + ck_tile::load_tile(a_tile, a_win); + ck_tile::load_tile(b_tile, b_win); + + CWarpTensor c_tile; + c_tile = WarpGemm{}(a_tile, b_tile); + ck_tile::store_tile(c_win, c_tile); + } +}; + +// ---------------------- New runtime helper: run a WG on device with given A/B into C +// ---------------------- +template +static void RunWarpGemmCase(const ck_tile::HostTensor& A, + const ck_tile::HostTensor& B, + ck_tile::HostTensor& C) +{ + using AType = typename Case::AType; + using BType = typename Case::BType; + using CType = typename Case::AccType; // CDataType equals Acc for these tests + + ck_tile::DeviceMem Ad(A.get_element_space_size_in_bytes()); + ck_tile::DeviceMem Bd(B.get_element_space_size_in_bytes()); + ck_tile::DeviceMem Cd(C.get_element_space_size_in_bytes()); + + Ad.ToDevice(A.data()); + Bd.ToDevice(B.data()); + Cd.SetZero(); + + dim3 grid(1); + dim3 block{64}; + + using Kernel = WarpGemmKernel; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{nullptr, true}, + ck_tile::make_kernel(Kernel{}, + grid, + block, + 0, + static_cast(Ad.GetDeviceBuffer()), + static_cast(Bd.GetDeviceBuffer()), + static_cast(Cd.GetDeviceBuffer()))); + + Cd.FromDevice(C.mData.data()); +} + +// enforce 2:4 sparsity on A for SMFMA runtime cases (only meaningful for half_t here) +template +static inline void make_2to4_sparse_A(ck_tile::HostTensor& A) +{ + // zero half the values in each consecutive group of 4 along K for each row m + const ck_tile::index_t M = A.mDesc.get_lengths()[0]; + const ck_tile::index_t K = A.mDesc.get_lengths()[1]; + for(ck_tile::index_t m = 0; m < M; ++m) + { + for(ck_tile::index_t k = 0; k + 3 < K; k += 4) + { + // keep entries at k and k+2, zero k+1 and k+3 (simple 2:4 pattern) + A(m, k + 1) = ck_tile::type_convert(0); + A(m, k + 3) = ck_tile::type_convert(0); + } + } +} + +template +class WGRuntimeTest : public ::testing::Test +{ + public: + void RunTest() + { + // Equivalent MakeWarpGemm + using AType = typename Case::AType; + using BType = typename Case::BType; + using CType = typename Case::AccType; + + constexpr ck_tile::index_t M = Case::MPerWave; + constexpr ck_tile::index_t N = Case::NPerWave; + constexpr ck_tile::index_t K = Case::KPerWave; + + ck_tile::HostTensor A({M, K}); + ck_tile::HostTensor B({K, N}); + ck_tile::HostTensor C_disp({M, N}); + ck_tile::HostTensor C_make({M, N}); + + for(ck_tile::index_t m = 0; m < M; ++m) + for(ck_tile::index_t k = 0; k < K; ++k) + A(m, k) = ck_tile::type_convert((m + 1) * 0.1f + (k + 1) * 0.01f); + + if constexpr(Case::kUSS) + { + // ensure A satisfies 2:4 sparsity for SMFMA + make_2to4_sparse_A(A); + } + + for(ck_tile::index_t k0 = 0; k0 < K; ++k0) + for(ck_tile::index_t n = 0; n < N; ++n) + B(k0, n) = ck_tile::type_convert((k0 + 1) * 0.2f - (n + 1) * 0.03f); + + C_disp.SetZero(); + C_make.SetZero(); + RunWarpGemmCase(A, B, C_disp); + RunWarpGemmCase(A, B, C_make); + + EXPECT_TRUE( + ck_tile::check_err(C_disp, C_make, "Dispatcher vs MakeWarpGemm mismatch", 0, 0)); + } +}; + +#endif // CK_TILE_TEST_WARP_GEMM_ATTR_COMPOSE_HPP diff --git a/test/ck_tile/gemm/test_warp_gemm_attr_compose_bf16_mfma.cpp b/test/ck_tile/gemm/test_warp_gemm_attr_compose_bf16_mfma.cpp new file mode 100644 index 0000000000..a9dcd0713b --- /dev/null +++ b/test/ck_tile/gemm/test_warp_gemm_attr_compose_bf16_mfma.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_warp_gemm_attr_compose.hpp" + +using WGDispatcherTypesList = ::testing::Types< + // clang-format off + + // bf16 + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + // WGDispCase, + WGDispCase, + WGDispCase>; +// clang-format on + +TYPED_TEST_SUITE(WGCompileTimeTest, WGDispatcherTypesList); +TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList); + +TYPED_TEST(WGCompileTimeTest, Instantiate) { this->RunTest(); } + +TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG) { this->RunTest(); } diff --git a/test/ck_tile/gemm/test_warp_gemm_attr_compose_bf8_mfma.cpp b/test/ck_tile/gemm/test_warp_gemm_attr_compose_bf8_mfma.cpp new file mode 100644 index 0000000000..b132fb15c2 --- /dev/null +++ b/test/ck_tile/gemm/test_warp_gemm_attr_compose_bf8_mfma.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_warp_gemm_attr_compose.hpp" + +using WGDispatcherTypesList = ::testing::Types< + // clang-format off + + // bf8 + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase>; +// clang-format on + +TYPED_TEST_SUITE(WGCompileTimeTest, WGDispatcherTypesList); +TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList); + +TYPED_TEST(WGCompileTimeTest, Instantiate) { this->RunTest(); } + +TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG) { this->RunTest(); } diff --git a/test/ck_tile/gemm/test_warp_gemm_attr_compose_fp16_mfma.cpp b/test/ck_tile/gemm/test_warp_gemm_attr_compose_fp16_mfma.cpp new file mode 100644 index 0000000000..faa59a7bbe --- /dev/null +++ b/test/ck_tile/gemm/test_warp_gemm_attr_compose_fp16_mfma.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_warp_gemm_attr_compose.hpp" + +using WGDispatcherTypesList = ::testing::Types< + // clang-format off + + // fp16 + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + // WGDispCase, + WGDispCase, + WGDispCase>; +// clang-format on + +TYPED_TEST_SUITE(WGCompileTimeTest, WGDispatcherTypesList); +TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList); + +TYPED_TEST(WGCompileTimeTest, Instantiate) { this->RunTest(); } + +TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG) { this->RunTest(); } diff --git a/test/ck_tile/gemm/test_warp_gemm_attr_compose_fp16_smfmac.cpp b/test/ck_tile/gemm/test_warp_gemm_attr_compose_fp16_smfmac.cpp new file mode 100644 index 0000000000..55b24ef532 --- /dev/null +++ b/test/ck_tile/gemm/test_warp_gemm_attr_compose_fp16_smfmac.cpp @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_warp_gemm_attr_compose.hpp" + +using WGDispatcherTypesList = ::testing::Types< + // clang-format off + + // fp16 2:4 structural sparsity + WGDispCase, + WGDispCase>; +// clang-format on + +TYPED_TEST_SUITE(WGCompileTimeTest, WGDispatcherTypesList); +TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList); + +TYPED_TEST(WGCompileTimeTest, Instantiate) { this->RunTest(); } + +TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG) { this->RunTest(); } diff --git a/test/ck_tile/gemm/test_warp_gemm_attr_compose_fp8_mfma.cpp b/test/ck_tile/gemm/test_warp_gemm_attr_compose_fp8_mfma.cpp new file mode 100644 index 0000000000..3deaea4c74 --- /dev/null +++ b/test/ck_tile/gemm/test_warp_gemm_attr_compose_fp8_mfma.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_warp_gemm_attr_compose.hpp" + +using WGDispatcherTypesList = ::testing::Types< + // clang-format off + + // fp8 + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase>; +// clang-format on + +TYPED_TEST_SUITE(WGCompileTimeTest, WGDispatcherTypesList); +TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList); + +TYPED_TEST(WGCompileTimeTest, Instantiate) { this->RunTest(); } + +TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG) { this->RunTest(); } diff --git a/test/ck_tile/gemm/test_warp_gemm_attr_compose_int8_mfma.cpp b/test/ck_tile/gemm/test_warp_gemm_attr_compose_int8_mfma.cpp new file mode 100644 index 0000000000..e7a80e414f --- /dev/null +++ b/test/ck_tile/gemm/test_warp_gemm_attr_compose_int8_mfma.cpp @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_warp_gemm_attr_compose.hpp" + +using WGDispatcherTypesList = ::testing::Types< + // clang-format off + + // int8 + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase>; +// clang-format on + +TYPED_TEST_SUITE(WGCompileTimeTest, WGDispatcherTypesList); +TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList); + +TYPED_TEST(WGCompileTimeTest, Instantiate) { this->RunTest(); } + +TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG) { this->RunTest(); } diff --git a/test/ck_tile/gemm/test_warp_gemm_attr_compose_mixed_fp8_bf8_mfma.cpp b/test/ck_tile/gemm/test_warp_gemm_attr_compose_mixed_fp8_bf8_mfma.cpp new file mode 100644 index 0000000000..0f5b4cbbb3 --- /dev/null +++ b/test/ck_tile/gemm/test_warp_gemm_attr_compose_mixed_fp8_bf8_mfma.cpp @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_warp_gemm_attr_compose.hpp" + +using WGDispatcherTypesList = ::testing::Types< + // clang-format off + + // mixed fp8/bf8 + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase, + WGDispCase>; +// clang-format on + +TYPED_TEST_SUITE(WGCompileTimeTest, WGDispatcherTypesList); +TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList); + +TYPED_TEST(WGCompileTimeTest, Instantiate) { this->RunTest(); } + +TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG) { this->RunTest(); }