Refactor WarpGemm dispatcher and compose attributes

* Introduce WarpGemmCoreDispatcher to select MFMA/SMFMAC implementations by A/B/Acc types, (M,N,K) per wave, and structured sparsity flag

* Add composed attribute layer (detail/warp_gemm_attribute_mfma_compose.hpp) to fold policies (IterateK, TransposeC, Swizzle, NumAccess) and derive A/B/C encodings
Update warp_gemm_dispatcher.hpp to route through the new core dispatcher

* Align SMFMAC attribute and wrapper: expose A/B/C distribution encodings and unify get_num_of_access; adjust WarpGemmSmfmacImpl integration
Expand/adjust MFMA attribute impls to work with the new composition/dispatcher path

* Add unit test for attribute composition (test_warp_gemm_attr_compose.cpp) and wire test CMake; minor example CMake tweak (example/ck_tile/03_gemm)
This commit is contained in:
Jeff Huang
2025-11-07 12:06:15 +08:00
parent 51027474af
commit ab768af196
17 changed files with 1967 additions and 49 deletions

View File

@@ -13,6 +13,13 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a")
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
# Opt-in switch, OFF by default. It only controls whether the compile definition below is
# appended to EXAMPLE_GEMM_COMPILE_OPTIONS.
option(CK_TILE_EXAMPLE_ROUTE_DISPATCHER_TO_MAKE "Route ck_tile::WarpGemmDispatcher to MakeWarpGemm in examples" OFF)
# When routing dispatcher to MakeWarpGemm is enabled, add the macro to the shared options
if(CK_TILE_EXAMPLE_ROUTE_DISPATCHER_TO_MAKE)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_ROUTE_WARP_GEMM_DISPATCHER_TO_MAKE)
endif()
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-unused-local-typedef)
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-gnu-line-marker)
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS --save-temps)

View File

@@ -0,0 +1,709 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
// smfmac (structured sparsity)
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
namespace ck_tile {
namespace detail {
namespace wg_attr_compose {
// Basic config/state that policies transform at compile-time
struct NoneTag;
// Selects which operand, if any, gets the swizzled distribution encoding.
// MakeAWarpDstrEncoding reacts to A, MakeBWarpDstrEncoding reacts to B, and
// MakeCWarpDstrEncoding uses the swizzled M splits for either A or B.
enum class SwizzleKind
{
None, // no swizzle: the base encodings are used everywhere
A,    // swizzled encoding for the A operand
B     // swizzled encoding for the B operand
};
// Compile-time record of every setting a policy may rewrite:
// the wrapped Impl, the C-transpose flag, the swizzle kind and factor,
// the number of K iterations and the access mode.
template <typename ImplT,
          bool IsTransposedC,
          SwizzleKind SwizzleMode,
          index_t SwizzleFactor,
          index_t KIterCount,
          WGAttrNumAccessEnum AccessMode>
struct State
{
    // The implementation type with cv/ref qualifiers stripped.
    using Impl = remove_cvref_t<ImplT>;

    static constexpr bool TransposeC               = IsTransposedC;
    static constexpr SwizzleKind Swizzle           = SwizzleMode;
    static constexpr index_t SFactor               = SwizzleFactor;
    static constexpr index_t KIter                 = KIterCount;
    static constexpr WGAttrNumAccessEnum NumAccess = AccessMode;

    // Access mode converted to its integral value for use in arithmetic.
    static constexpr index_t NumAccValue = static_cast<index_t>(AccessMode);
};
// Starting state for policy composition: C not transposed, no swizzle, swizzle factor 1,
// one K iteration, and the requested access mode (Single unless specified).
template <typename ImplT, WGAttrNumAccessEnum AccessMode = WGAttrNumAccessEnum::Single>
using BaseState = State<ImplT,
                        /*TransposeC=*/false,
                        SwizzleKind::None,
                        /*SFactor=*/1,
                        /*KIter=*/1,
                        AccessMode>;
// Helpers to compute derived constants and types from State
// Element and vector types of the composed A/B operands.
// When S::TransposeC is set, the Impl's A and B types trade places. The composed vector
// length is the Impl's vector length multiplied by S::KIter.
template <class S>
struct ABTypes
{
    private:
    using ImplT = typename S::Impl;

    // Picks the first type when C is transposed, the second one otherwise.
    template <class WhenTransposed, class WhenPlain>
    using Select = std::conditional_t<S::TransposeC, WhenTransposed, WhenPlain>;

    public:
    using ADataType = Select<typename ImplT::BDataType, typename ImplT::ADataType>;
    using BDataType = Select<typename ImplT::ADataType, typename ImplT::BDataType>;

    // Per-instruction vector types before the K-iteration widening.
    using AVecBase = Select<typename ImplT::BVecType, typename ImplT::AVecType>;
    using BVecBase = Select<typename ImplT::AVecType, typename ImplT::BVecType>;

    // Widened vectors holding the data of all S::KIter iterations.
    using AVecType = ext_vector_t<ADataType, vector_traits<AVecBase>::vector_size * S::KIter>;
    using BVecType = ext_vector_t<BDataType, vector_traits<BVecBase>::vector_size * S::KIter>;
};
// Accumulator (C) element and vector types, forwarded from the Impl unchanged:
// no policy alters them.
template <class S>
struct CTypes
{
    private:
    using ImplT = typename S::Impl;

    public:
    using CDataType = typename ImplT::CDataType;
    using CVecType  = typename ImplT::CVecType;
};
// Shape constants of the composed attribute.
// TransposeC exchanges M and N; IterateK multiplies both K extents by S::KIter.
template <class S>
struct KShape
{
    private:
    using ImplT = typename S::Impl;

    public:
    static constexpr index_t kM = S::TransposeC ? ImplT::kN : ImplT::kM;
    static constexpr index_t kN = S::TransposeC ? ImplT::kM : ImplT::kN;

    static constexpr index_t kK          = ImplT::kK * S::KIter;
    static constexpr index_t kKPerThread = ImplT::kABKPerLane * S::KIter;

    // Taken from the Impl as-is; no policy changes it.
    static constexpr index_t kCMLane = ImplT::kCMLane;
};
// Lane counts seen by the composed A and B operands.
// With TransposeC the A side uses the Impl's N-lane count and the B side its M-lane count.
template <class S>
struct Lanes
{
    private:
    using ImplT = typename S::Impl;

    public:
    static constexpr index_t AMLane = S::TransposeC ? ImplT::kBNLane : ImplT::kAMLane;
    static constexpr index_t BNLane = S::TransposeC ? ImplT::kAMLane : ImplT::kBNLane;
};
// Encoding builders centralizing existing logic, parameterized by S
// A encodings with NumAccess and IterateK folded, with multi-block and swizzle
// Builds the A-operand warp distribution encoding for state S.
// Returns a default-constructed tile_distribution_encoding<...> object; ComposedAttribute
// only consumes its type through decltype.
//  - S::Swizzle == SwizzleKind::A selects the swizzled encoding.
//  - Otherwise the base encoding is chosen from S::Impl::kAMBlock / kBNBlock and
//    S::NumAccValue.
// In every variant the per-lane K length is S::Impl::kABKPerLane * S::KIter.
// NOTE(review): the swizzled variant reads neither NumAccValue nor the block counts, and it
// uses S::Impl::kAMLane directly rather than Lanes<S>::AMLane — confirm swizzle A is only
// combined with single-access, single-block, non-transposed states.
template <class S>
CK_TILE_DEVICE static constexpr auto MakeAWarpDstrEncoding()
{
constexpr index_t NumAccValue = S::NumAccValue;
// Helper lambdas for the base A-encoding and the swizzled variant
auto base_enc = []() {
// Single block in both M and N.
if constexpr(S::Impl::kAMBlock == 1 && S::Impl::kBNBlock == 1)
{
if constexpr(NumAccValue == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Lanes<S>::AMLane>,
sequence<S::Impl::kABKLane, S::Impl::kABKPerLane * S::KIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
}
else
{
// Multiple accesses: the K dimension gets an extra leading split of size NumAccValue,
// so the per-lane K length must divide evenly.
static_assert(KShape<S>::kKPerThread % NumAccValue == 0,
"kKPerThread must be divisible by NumAccess");
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Lanes<S>::AMLane>,
sequence<NumAccValue,
S::Impl::kABKLane,
S::Impl::kABKPerLane * S::KIter / NumAccValue>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
}
}
// Multi-block along N only: kBNBlock appears in the first (otherwise empty) sequence.
else if constexpr(S::Impl::kAMBlock == 1 && 1 < S::Impl::kBNBlock)
{
static_assert(NumAccValue == 1,
"Multiple access is not supported when using multi-block");
return tile_distribution_encoding<
sequence<S::Impl::kBNBlock>,
tuple<sequence<Lanes<S>::AMLane>,
sequence<S::Impl::kABKLane, S::Impl::kABKPerLane * S::KIter>>,
tuple<sequence<0, 2, 1>>,
tuple<sequence<0, 0, 0>>,
sequence<2>,
sequence<1>>{};
}
// NOTE(review): this is a plain else, so it is also taken when both block counts exceed 1
// — confirm that combination is rejected elsewhere.
else // 1 < AMBlock && BNBlock == 1
{
static_assert(NumAccValue == 1,
"Multiple access is not supported when using multi-block");
// Multi-block along M: kAMBlock becomes a leading split of the M dimension.
return tile_distribution_encoding<
sequence<>,
tuple<sequence<S::Impl::kAMBlock, Lanes<S>::AMLane>,
sequence<S::Impl::kABKLane, S::Impl::kABKPerLane * S::KIter>>,
tuple<sequence<1, 2, 1>>,
tuple<sequence<0, 0, 1>>,
sequence<2>,
sequence<1>>{};
}
};
// Swizzled variant: the M dimension is split four ways using kCMLane, S::SFactor and
// kCM1PerLane; the leading factor is whatever remains of kAMLane.
auto swizzled_enc = []() {
return tile_distribution_encoding<
sequence<>,
tuple<
sequence<S::Impl::kAMLane / (S::Impl::kCMLane * S::SFactor * S::Impl::kCM1PerLane),
S::Impl::kCMLane,
S::SFactor,
S::Impl::kCM1PerLane>,
sequence<S::Impl::kABKLane, S::Impl::kABKPerLane * S::KIter>>,
tuple<sequence<2, 1, 1, 1, 1>>,
tuple<sequence<0, 0, 2, 1, 3>>,
sequence<2>,
sequence<1>>{};
};
// Only SwizzleKind::A affects the A operand; None and B both use the base encoding.
if constexpr(S::Swizzle == SwizzleKind::A)
return swizzled_enc();
else
return base_enc();
}
// B encodings with NumAccess and IterateK folded, with multi-block and swizzle
// Builds the B-operand warp distribution encoding for state S.
// Returns a default-constructed tile_distribution_encoding<...> object; ComposedAttribute
// only consumes its type through decltype.
//  - S::Swizzle == SwizzleKind::B selects the swizzled encoding.
//  - Otherwise the base encoding is chosen from S::Impl::kAMBlock / kBNBlock and
//    S::NumAccValue. The multi-block branches mirror MakeAWarpDstrEncoding with the roles
//    of kAMBlock and kBNBlock exchanged.
// In every variant the per-lane K length is S::Impl::kABKPerLane * S::KIter.
// NOTE(review): the swizzled variant is built from S::Impl::kAMLane (not kBNLane) and reads
// neither NumAccValue nor the block counts — ComposePolicies only applies swizzle B together
// with TransposeC; confirm that is the only intended use.
template <class S>
CK_TILE_DEVICE static constexpr auto MakeBWarpDstrEncoding()
{
constexpr index_t NumAccValue = S::NumAccValue;
auto base_enc = []() {
// Single block in both M and N.
if constexpr(S::Impl::kAMBlock == 1 && S::Impl::kBNBlock == 1)
{
if constexpr(NumAccValue == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Lanes<S>::BNLane>,
sequence<S::Impl::kABKLane, S::Impl::kABKPerLane * S::KIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
}
else
{
// Multiple accesses: the K dimension gets an extra leading split of size NumAccValue,
// so the per-lane K length must divide evenly.
static_assert(KShape<S>::kKPerThread % NumAccValue == 0,
"kKPerThread must be divisible by NumAccess");
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Lanes<S>::BNLane>,
sequence<NumAccValue,
S::Impl::kABKLane,
S::Impl::kABKPerLane * S::KIter / NumAccValue>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
}
}
// Multi-block along N: kBNBlock becomes a leading split of the N dimension.
else if constexpr(S::Impl::kAMBlock == 1 && 1 < S::Impl::kBNBlock)
{
static_assert(NumAccValue == 1,
"Multiple access is not supported when using multi-block");
return tile_distribution_encoding<
sequence<>,
tuple<sequence<S::Impl::kBNBlock, Lanes<S>::BNLane>,
sequence<S::Impl::kABKLane, S::Impl::kABKPerLane * S::KIter>>,
tuple<sequence<1, 2, 1>>,
tuple<sequence<0, 0, 1>>,
sequence<2>,
sequence<1>>{};
}
// NOTE(review): this is a plain else, so it is also taken when both block counts exceed 1
// — confirm that combination is rejected elsewhere.
else // 1 < AMBlock && BNBlock == 1
{
static_assert(NumAccValue == 1,
"Multiple access is not supported when using multi-block");
// Multi-block along M only: kAMBlock appears in the first (otherwise empty) sequence.
return tile_distribution_encoding<
sequence<S::Impl::kAMBlock>,
tuple<sequence<Lanes<S>::BNLane>,
sequence<S::Impl::kABKLane, S::Impl::kABKPerLane * S::KIter>>,
tuple<sequence<0, 2, 1>>,
tuple<sequence<0, 0, 0>>,
sequence<2>,
sequence<1>>{};
}
};
// Swizzled variant: the first dimension is split four ways using kCMLane, S::SFactor and
// kCM1PerLane; the leading factor is whatever remains of kAMLane.
auto swizzled_enc = []() {
return tile_distribution_encoding<
sequence<>,
tuple<
sequence<S::Impl::kAMLane / (S::Impl::kCMLane * S::SFactor * S::Impl::kCM1PerLane),
S::Impl::kCMLane,
S::SFactor,
S::Impl::kCM1PerLane>,
sequence<S::Impl::kABKLane, S::Impl::kABKPerLane * S::KIter>>,
tuple<sequence<2, 1, 1, 1, 1>>,
tuple<sequence<0, 0, 2, 1, 3>>,
sequence<2>,
sequence<1>>{};
};
// Only SwizzleKind::B affects the B operand; None and A both use the base encoding.
if constexpr(S::Swizzle == SwizzleKind::B)
return swizzled_enc();
else
return base_enc();
}
// Builds the C (accumulator) warp distribution encoding for state S.
// Returns a default-constructed tile_distribution_encoding<...> object; ComposedAttribute
// only consumes its type through decltype.
//  - The branch is chosen from S::Impl::kAMBlock / kBNBlock. Exactly one of three layouts is
//    supported: single block, multi-block along N, or multi-block along M.
//  - S::TransposeC swaps the order of the M and N dimensions in the encoding.
//  - With swizzle A or B the M splits become
//    <kCM0PerLane / SFactor, kCMLane, kCM1PerLane * SFactor>.
// NOTE(review): the multi-block-along-M branch spells out the unswizzled M splits and thus
// ignores S::Swizzle — confirm swizzle is never combined with kAMBlock > 1.
template <class S>
CK_TILE_DEVICE static constexpr auto MakeCWarpDstrEncoding()
{
    // Without this check a state matching none of the branches below (for example multi-block
    // in both M and N) would make the function return void, and the failure would only
    // surface later as an unrelated error at the point where the encoding type is used.
    static_assert(0 < S::Impl::kAMBlock && 0 < S::Impl::kBNBlock &&
                      (S::Impl::kAMBlock == 1 || S::Impl::kBNBlock == 1),
                  "Multi-block on both M & N directions is not supported");

    constexpr bool HasSwizzle = (S::Swizzle == SwizzleKind::A) || (S::Swizzle == SwizzleKind::B);

    // Produces the M-dimension split sequence; only its type is used (via decltype).
    auto make_m_splits = []() {
        if constexpr(HasSwizzle)
        {
            // Swizzled M splits
            return sequence<S::Impl::kCM0PerLane / S::SFactor,
                            S::Impl::kCMLane,
                            S::Impl::kCM1PerLane * S::SFactor>{};
        }
        else
        {
            return sequence<S::Impl::kCM0PerLane, S::Impl::kCMLane, S::Impl::kCM1PerLane>{};
        }
    };

    // Single block in both M and N.
    if constexpr(S::Impl::kAMBlock == 1 && S::Impl::kBNBlock == 1)
    {
        if constexpr(!S::TransposeC)
        {
            return tile_distribution_encoding<
                sequence<>,
                tuple<decltype(make_m_splits()), sequence<S::Impl::kCNLane>>,
                tuple<sequence<1, 2>>,
                tuple<sequence<1, 0>>,
                sequence<1, 1>,
                sequence<0, 2>>{};
        }
        else // TransposeC
        {
            return tile_distribution_encoding<
                sequence<>,
                tuple<sequence<S::Impl::kCNLane>, decltype(make_m_splits())>,
                tuple<sequence<2, 1>>,
                tuple<sequence<1, 0>>,
                sequence<2, 2>,
                sequence<0, 2>>{};
        }
    }
    // Multi-block along N: the N lane count is scaled by kBNBlock.
    else if constexpr(S::Impl::kAMBlock == 1 && 1 < S::Impl::kBNBlock)
    {
        if constexpr(!S::TransposeC)
        {
            return tile_distribution_encoding<
                sequence<>,
                tuple<decltype(make_m_splits()), sequence<S::Impl::kBNBlock * S::Impl::kCNLane>>,
                tuple<sequence<1, 2>>,
                tuple<sequence<1, 0>>,
                sequence<1, 1>,
                sequence<0, 2>>{};
        }
        else
        {
            return tile_distribution_encoding<
                sequence<>,
                tuple<sequence<S::Impl::kBNBlock * S::Impl::kCNLane>, decltype(make_m_splits())>,
                tuple<sequence<2, 1>>,
                tuple<sequence<1, 0>>,
                sequence<2, 2>,
                sequence<0, 2>>{};
        }
    }
    // Multi-block along M: the M lane count is scaled by kAMBlock.
    else if constexpr(1 < S::Impl::kAMBlock && S::Impl::kBNBlock == 1)
    {
        if constexpr(!S::TransposeC)
        {
            return tile_distribution_encoding<sequence<>,
                                              tuple<sequence<S::Impl::kCM0PerLane,
                                                             S::Impl::kAMBlock * S::Impl::kCMLane,
                                                             S::Impl::kCM1PerLane>,
                                                    sequence<S::Impl::kCNLane>>,
                                              tuple<sequence<1, 2>>,
                                              tuple<sequence<1, 0>>,
                                              sequence<1, 1>,
                                              sequence<0, 2>>{};
        }
        else
        {
            return tile_distribution_encoding<sequence<>,
                                              tuple<sequence<S::Impl::kCNLane>,
                                                    sequence<S::Impl::kCM0PerLane,
                                                             S::Impl::kAMBlock * S::Impl::kCMLane,
                                                             S::Impl::kCM1PerLane>>,
                                              tuple<sequence<2, 1>>,
                                              tuple<sequence<1, 0>>,
                                              sequence<2, 2>,
                                              sequence<0, 2>>{};
        }
    }
}
// Detect smfmac by Impl type: provide a small trait that is true for smfmac attribute impls
// Fallback case of the trait: an arbitrary type is not an smfmac attribute implementation.
template <class T>
struct is_smfmac_impl : std::integral_constant<bool, false>
{
};
// Matching case of the trait: any instantiation of WarpGemmAttributeSmfmacImpl, whatever its
// element types, per-wave shape or control mode, is reported as an smfmac implementation.
template <typename AType_,
typename BType_,
typename AccType_,
index_t MPerWave_,
index_t NPerWave_,
index_t KPerWave_,
WGAttrCtlEnum C>
struct is_smfmac_impl<
WarpGemmAttributeSmfmacImpl<AType_, BType_, AccType_, MPerWave_, NPerWave_, KPerWave_, C>>
: std::true_type
{
};
// Final composed attribute
// Primary (MFMA) and SMFMA-specialized ComposedAttribute
// Composed attribute for non-smfmac implementations (the IsSmfmac == true case has its own
// specialization). Exposes the shape constants, operand types and distribution encodings
// derived from state S, and forwards the multiply-accumulate to S::Impl once per K iteration.
// IsSmfmac defaults to the result of the is_smfmac_impl trait on S::Impl.
template <class S, bool IsSmfmac = is_smfmac_impl<typename S::Impl>::value>
struct ComposedAttribute
{
using Impl = typename S::Impl;
using ATypes = ABTypes<S>;
using CTypesT = CTypes<S>;
// Shape constants with TransposeC / IterateK already folded in (see KShape).
static constexpr index_t kM = KShape<S>::kM;
static constexpr index_t kN = KShape<S>::kN;
static constexpr index_t kK = KShape<S>::kK;
static constexpr index_t kKPerThread = KShape<S>::kKPerThread;
static constexpr index_t kCMLane = KShape<S>::kCMLane;
using ADataType = typename ATypes::ADataType;
using BDataType = typename ATypes::BDataType;
using CDataType = typename CTypesT::CDataType;
// A/B vector types span all S::KIter iterations; the C vector type is the Impl's own.
using AVecType = typename ATypes::AVecType;
using BVecType = typename ATypes::BVecType;
using CVecType = typename CTypesT::CVecType;
// Only the types of the encoding builders' return values are used.
using AWarpDstrEncoding = decltype(MakeAWarpDstrEncoding<S>());
using BWarpDstrEncoding = decltype(MakeBWarpDstrEncoding<S>());
using CWarpDstrEncoding = decltype(MakeCWarpDstrEncoding<S>());
// Returns S::KIter. Note that S::NumAccess does not enter this value.
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return S::KIter; }
// c_vec += a_vec * b_vec
// post_nop_ is passed through to Impl unchanged as a bool_constant.
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
// Single K iteration: the composed vectors are reinterpreted as the Impl's vector types.
if constexpr(S::KIter == 1)
{
if constexpr(S::TransposeC)
{
// When TransposeC, composed A/B are swapped relative to Impl.
// Pass b_vec as Impl::AVecType and a_vec as Impl::BVecType.
Impl{}(c_vec,
reinterpret_cast<const typename Impl::AVecType&>(b_vec),
reinterpret_cast<const typename Impl::BVecType&>(a_vec),
bool_constant<post_nop_>{});
}
else
{
Impl{}(c_vec,
reinterpret_cast<const typename Impl::AVecType&>(a_vec),
reinterpret_cast<const typename Impl::BVecType&>(b_vec),
bool_constant<post_nop_>{});
}
}
else
{
// Several K iterations: each wide operand is viewed as a thread_buffer of S::KIter
// Impl-sized vectors and Impl is invoked once per index, accumulating into c_vec.
// NOTE(review): presumably get_as<T>() exposes the buffer as an array of T — confirm
// against thread_buffer.
using buf_a = thread_buffer<typename Impl::AVecType, S::KIter>;
using buf_b = thread_buffer<typename Impl::BVecType, S::KIter>;
static_for<0, S::KIter, 1>{}([&](auto iKIter) {
if constexpr(S::TransposeC)
{
// Swap mapping: b_vec -> Impl::AVecType, a_vec -> Impl::BVecType
Impl{}(c_vec,
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
bool_constant<post_nop_>{});
}
else
{
Impl{}(c_vec,
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
bool_constant<post_nop_>{});
}
});
}
}
// c_vec = a_vec * b_vec
// Returns the product by value instead of accumulating into an existing vector.
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
if constexpr(S::KIter == 1)
{
if constexpr(S::TransposeC)
{
// Swap mapping: b_vec -> Impl::AVecType, a_vec -> Impl::BVecType
return Impl{}(reinterpret_cast<const typename Impl::AVecType&>(b_vec),
reinterpret_cast<const typename Impl::BVecType&>(a_vec));
}
else
{
return Impl{}(reinterpret_cast<const typename Impl::AVecType&>(a_vec),
reinterpret_cast<const typename Impl::BVecType&>(b_vec));
}
}
else
{
// Several K iterations: iteration 0 uses Impl's two-argument form to create c_vec,
// iterations 1..KIter-1 accumulate into it with the three-argument form.
using buf_a = thread_buffer<typename Impl::AVecType, S::KIter>;
using buf_b = thread_buffer<typename Impl::BVecType, S::KIter>;
constexpr auto I0 = number<0>{};
CVecType c_vec;
if constexpr(S::TransposeC)
{
// Swap mapping: b_vec -> Impl::AVecType, a_vec -> Impl::BVecType
c_vec = Impl{}(reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::AVecType>()[I0],
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::BVecType>()[I0]);
static_for<1, S::KIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::BVecType>()[iKIter]);
});
}
else
{
c_vec = Impl{}(reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[I0],
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[I0]);
static_for<1, S::KIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter]);
});
}
return c_vec;
}
}
};
// SMFMA specialization: forbid Transpose and Swizzle for now; KIter must be 1
// TODO: enable swizzle, transpose for smfmac
template <class S>
struct ComposedAttribute<S, true>
{
using Impl = typename S::Impl;
static_assert(!S::TransposeC, "smfmac TransposeC is not supported in composed attributes");
static_assert(S::Swizzle == SwizzleKind::None,
"smfmac Swizzle is not supported in composed attributes");
static_assert(S::KIter == 1, "smfmac IterateK is not supported (KIter must be 1)");
using ADataType = typename Impl::ADataType;
using BDataType = typename Impl::BDataType;
using CDataType = typename Impl::CDataType;
using AVecType = typename Impl::AVecType;
using BVecType = typename Impl::BVecType;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kM;
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK;
static constexpr index_t kKPerThread = Impl::kABKPerLane;
static constexpr index_t kCMLane = Impl::kCMLane;
static constexpr index_t kCompressionRatio = Impl::CompressionRatio;
// Reuse the encodings defined by the smfmac attribute wrapper for consistency
using AWarpDstrEncoding = typename WarpGemmAttributeSmfmac<Impl>::AWarpDstrEncoding;
using BWarpDstrEncoding = typename WarpGemmAttributeSmfmac<Impl>::BWarpDstrEncoding;
using CWarpDstrEncoding = typename WarpGemmAttributeSmfmac<Impl>::CWarpDstrEncoding;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
// c_vec += a_vec * b_vec[idx]
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
const int32_t& idx,
bool_constant<post_nop_> = {}) const
{
Impl{}(c_vec, a_vec, b_vec, idx, bool_constant<post_nop_>{});
}
// c_vec = a_vec * b_vec[idx]
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec,
const BVecType& b_vec,
const int32_t& idx) const
{
CVecType c_vec{0};
Impl{}(c_vec, a_vec, b_vec, idx);
return c_vec;
}
};
// Policy wrappers produce a new State from an old one
// Policy: set the K-iteration count of a state to KIter, keeping every other field.
template <index_t KIter>
struct PolicyIterateK
{
    template <class Prev>
    using apply = State<typename Prev::Impl,
                        Prev::TransposeC,
                        Prev::Swizzle,
                        Prev::SFactor,
                        KIter,
                        Prev::NumAccess>;
};
// Policy: turn on the C-transpose flag of a state, keeping every other field.
struct PolicyTransposeC
{
    template <class Prev>
    using apply = State<typename Prev::Impl,
                        /*TransposeC=*/true,
                        Prev::Swizzle,
                        Prev::SFactor,
                        Prev::KIter,
                        Prev::NumAccess>;
};
// Policy: request the swizzled A encoding with swizzle factor SFactor,
// keeping every other field of the state.
template <index_t SFactor>
struct PolicySwizzleA
{
    template <class Prev>
    using apply = State<typename Prev::Impl,
                        Prev::TransposeC,
                        SwizzleKind::A,
                        SFactor,
                        Prev::KIter,
                        Prev::NumAccess>;
};
// Policy: request the swizzled B encoding with swizzle factor SFactor,
// keeping every other field of the state.
template <index_t SFactor>
struct PolicySwizzleB
{
    template <class Prev>
    using apply = State<typename Prev::Impl,
                        Prev::TransposeC,
                        SwizzleKind::B,
                        SFactor,
                        Prev::KIter,
                        Prev::NumAccess>;
};
} // namespace wg_attr_compose
} // namespace detail
} // namespace ck_tile
// High-level alias to construct a composed WarpGemm attribute via CoreDispatcher and policies
#include "ck_tile/ops/gemm/warp/warp_gemm_core_dispatcher.hpp"
namespace ck_tile {
// Selects the core instruction implementation through WarpGemmCoreDispatcher (MFMA vs
// smfmac) and folds the IterateK / TransposeC / Swizzle policies on top of it.
//
// Template parameters:
//   TransposeC            - compose with the transposed-C policy
//   SwizzleA              - request swizzling (applied to A, or to B when TransposeC is set)
//   UseStructuredSparsity - forwarded to the core dispatcher
//   AType/BType/AccType   - element types forwarded to the core dispatcher
//   MPerWave/NPerWave     - per-wave M/N extents forwarded to the core dispatcher
//   KPerWave              - per-wave K extent; must be a positive multiple of Impl::kK
//   NumAccess             - access mode stored in the base state
//
// Result: `type`, a WarpGemmSmfmacImpl or WarpGemmImpl over the composed attribute.
template <bool TransposeC,
          bool SwizzleA,
          bool UseStructuredSparsity,
          typename AType,
          typename BType,
          typename AccType,
          index_t MPerWave,
          index_t NPerWave,
          index_t KPerWave,
          WGAttrNumAccessEnum NumAccess>
struct ComposePolicies
{
    using CD = WarpGemmCoreDispatcher<AType,
                                      BType,
                                      AccType,
                                      MPerWave,
                                      NPerWave,
                                      KPerWave,
                                      UseStructuredSparsity>;
    using Impl = typename CD::Impl;

    static_assert(Impl::kK > 0, "Invalid K dimension");
    static_assert(Impl::kK <= KPerWave, "KPerWave must be greater than or equal to Impl::kK");
    // KIter is an integer quotient: a remainder would silently drop part of the K extent
    // (the composed kK = Impl::kK * KIter would be smaller than KPerWave).
    static_assert(KPerWave % Impl::kK == 0, "KPerWave must be a multiple of Impl::kK");

    // Number of Impl invocations needed to cover KPerWave.
    static constexpr index_t KIter = KPerWave / Impl::kK;
    static constexpr bool IsSmfmac = detail::wg_attr_compose::is_smfmac_impl<Impl>::value;

    // First, setup a default state as the base state.
    using S0 = detail::wg_attr_compose::BaseState<Impl, NumAccess>;

    // The order of policies matters so we compose them with the order:
    // KIter -> TransposeC -> SwizzleA.
    using S1 = std::conditional_t<
        (KIter == 1),
        S0,
        typename detail::wg_attr_compose::PolicyIterateK<KIter>::template apply<S0>>;
    using S2 =
        std::conditional_t<TransposeC,
                           typename detail::wg_attr_compose::PolicyTransposeC::template apply<S1>,
                           S1>;

    // Match dispatcher behavior: when TransposeC is enabled, the swizzle applies to B
    // (i.e., SwizzleBTransposedCDistribution). Otherwise apply swizzle to A.
    // The swizzle factor is fixed to 2 in both cases.
    using S3 = std::conditional_t<
        SwizzleA && !TransposeC,
        typename detail::wg_attr_compose::PolicySwizzleA<2>::template apply<S2>,
        std::conditional_t<SwizzleA && TransposeC,
                           typename detail::wg_attr_compose::PolicySwizzleB<2>::template apply<S2>,
                           S2>>;

    // For SMFMA, use the dedicated WarpGemmSmfmacImpl wrapper with the SMFMA attribute.
    // For MFMA (default), use the regular WarpGemmImpl over the composed attribute.
    using type = std::conditional_t<
        IsSmfmac,
        ck_tile::WarpGemmSmfmacImpl<detail::wg_attr_compose::ComposedAttribute<S3>>,
        ck_tile::WarpGemmImpl<detail::wg_attr_compose::ComposedAttribute<S3>>>;
};
// Wrapper struct to match usage as MakeWarpGemm<...>::Type
template <bool TransposeC,
bool SwizzleA,
typename AType,
typename BType,
typename AccType,
index_t MPerWave,
index_t NPerWave,
index_t KPerWave,
bool UseStructuredSparsity = false,
WGAttrNumAccessEnum NumAccess = WGAttrNumAccessEnum::Single>
struct MakeWarpGemm
{
using Type = typename ComposePolicies<TransposeC,
SwizzleA,
UseStructuredSparsity,
AType,
BType,
AccType,
MPerWave,
NPerWave,
KPerWave,
NumAccess>::type;
};
} // namespace ck_tile

View File

@@ -22,6 +22,16 @@ enum class WGAttrCtlEnum
// raw_a_a_a = 3, // c-agpr, a-agpr, b-agpr
};
// Primary template: generic MFMA warp-gemm attribute (to be specialized per supported shape)
// Keyed by the A/B/C element types, the (kM, kN, kK) shape and the WGAttrCtlEnum control
// mode. The template is only declared here, so using a combination that has no
// specialization is a compile-time error (incomplete type).
template <typename ADataType,
typename BDataType,
typename CDataType,
index_t kM,
index_t kN,
index_t kK,
WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl; // no definition, only specializations are provided
#define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_) \
if constexpr(post_nop_) \
{ \
@@ -191,8 +201,8 @@ struct WarpGemmAttributeMfmaImplF32F32F32M32N32K2
};
// V_MFMA_F32_16x16x32_BF16
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<bf16_t, bf16_t, float, 16, 16, 32, Ctrl_>
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = bf16_t;
@@ -254,8 +264,8 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32
}
};
// FP16
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<fp16_t, fp16_t, float, 32, 32, 8, Ctrl_>
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = fp16_t;
@@ -317,8 +327,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
}
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<fp16_t, fp16_t, float, 16, 16, 16, Ctrl_>
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = fp16_t;
@@ -380,8 +390,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
}
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF16F16F32M16N16K32
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<fp16_t, fp16_t, float, 16, 16, 32, Ctrl_>
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = fp16_t;
@@ -443,8 +453,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K32
}
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF16F16F32M4N64K4
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<fp16_t, fp16_t, float, 4, 64, 4, Ctrl_>
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = fp16_t;
@@ -507,8 +517,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M4N64K4
}
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF16F16F32M64N4K4
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<fp16_t, fp16_t, float, 64, 4, 4, Ctrl_>
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = fp16_t;
@@ -572,8 +582,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M64N4K4
};
// Bf16
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<bf16_t, bf16_t, float, 32, 32, 8, Ctrl_>
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = bf16_t;
@@ -661,8 +671,8 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
}
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<bf16_t, bf16_t, float, 16, 16, 16, Ctrl_>
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = bf16_t;
@@ -749,8 +759,8 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
}
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<bf16_t, bf16_t, float, 4, 64, 4, Ctrl_>
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = bf16_t;
@@ -839,8 +849,8 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4
}
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<bf16_t, bf16_t, float, 64, 4, 4, Ctrl_>
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = bf16_t;
@@ -930,8 +940,8 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4
};
// gfx950
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF16F16F32M32N32K16
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<fp16_t, fp16_t, float, 32, 32, 16, Ctrl_>
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = fp16_t;
@@ -1044,8 +1054,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K16
}
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<bf16_t, bf16_t, float, 32, 32, 16, Ctrl_>
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = bf16_t;
@@ -1158,6 +1168,60 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16
}
};
// ---------------------------------------------------------------------------------
// Backward-compatibility aliases (preserve original names)
// ---------------------------------------------------------------------------------
// Each alias maps a legacy per-shape name onto the matching specialization of
// WarpGemmAttributeMfmaImpl and carries the Ctrl_ = WGAttrCtlEnum::Default_ default
// argument, which a partial specialization cannot declare itself.
// BF16
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32 =
WarpGemmAttributeMfmaImpl<bf16_t, bf16_t, float, 16, 16, 32, Ctrl_>;
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 =
WarpGemmAttributeMfmaImpl<bf16_t, bf16_t, float, 32, 32, 8, Ctrl_>;
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 =
WarpGemmAttributeMfmaImpl<bf16_t, bf16_t, float, 16, 16, 16, Ctrl_>;
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4 =
WarpGemmAttributeMfmaImpl<bf16_t, bf16_t, float, 4, 64, 4, Ctrl_>;
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4 =
WarpGemmAttributeMfmaImpl<bf16_t, bf16_t, float, 64, 4, 4, Ctrl_>;
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16 =
WarpGemmAttributeMfmaImpl<bf16_t, bf16_t, float, 32, 32, 16, Ctrl_>;
// F16
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImplF16F16F32M32N32K8 =
WarpGemmAttributeMfmaImpl<fp16_t, fp16_t, float, 32, 32, 8, Ctrl_>;
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImplF16F16F32M16N16K16 =
WarpGemmAttributeMfmaImpl<fp16_t, fp16_t, float, 16, 16, 16, Ctrl_>;
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImplF16F16F32M16N16K32 =
WarpGemmAttributeMfmaImpl<fp16_t, fp16_t, float, 16, 16, 32, Ctrl_>;
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImplF16F16F32M4N64K4 =
WarpGemmAttributeMfmaImpl<fp16_t, fp16_t, float, 4, 64, 4, Ctrl_>;
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImplF16F16F32M64N4K4 =
WarpGemmAttributeMfmaImpl<fp16_t, fp16_t, float, 64, 4, 4, Ctrl_>;
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImplF16F16F32M32N32K16 =
WarpGemmAttributeMfmaImpl<fp16_t, fp16_t, float, 32, 32, 16, Ctrl_>;
// FP8
template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base
@@ -1318,6 +1382,31 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base
}
};
// Map FP8/BF8 shapes to the primary template via specializations (reuse base impl)
// One specialization per (A, B) combination of fp8_t / bf8_t for the 16x16x32 shape with a
// float accumulator. Each is an empty struct that inherits everything from
// WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base.
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<fp8_t, fp8_t, float, 16, 16, 32, Ctrl_>
: WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base<fp8_t, fp8_t, Ctrl_>
{
};
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<fp8_t, bf8_t, float, 16, 16, 32, Ctrl_>
: WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base<fp8_t, bf8_t, Ctrl_>
{
};
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<bf8_t, fp8_t, float, 16, 16, 32, Ctrl_>
: WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base<bf8_t, fp8_t, Ctrl_>
{
};
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<bf8_t, bf8_t, float, 16, 16, 32, Ctrl_>
: WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base<bf8_t, bf8_t, Ctrl_>
{
};
template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
{
@@ -1501,12 +1590,38 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
}
};
// Map FP8/BF8 32x32x16 shapes to the primary template via specializations (reuse base impl)
// One specialization per (A, B) combination of fp8_t / bf8_t with a float accumulator.
// Each is an empty struct that inherits everything from
// WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base.
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<fp8_t, fp8_t, float, 32, 32, 16, Ctrl_>
: WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t, Ctrl_>
{
};
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<fp8_t, bf8_t, float, 32, 32, 16, Ctrl_>
: WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, bf8_t, Ctrl_>
{
};
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<bf8_t, fp8_t, float, 32, 32, 16, Ctrl_>
: WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, fp8_t, Ctrl_>
{
};
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<bf8_t, bf8_t, float, 32, 32, 16, Ctrl_>
: WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, bf8_t, Ctrl_>
{
};
// Back-compat aliases now point to primary-template specializations
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8 =
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t, Ctrl_>;
WarpGemmAttributeMfmaImpl<fp8_t, fp8_t, float, 32, 32, 16, Ctrl_>;
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8 =
WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base<fp8_t, fp8_t, Ctrl_>;
WarpGemmAttributeMfmaImpl<fp8_t, fp8_t, float, 16, 16, 32, Ctrl_>;
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8 =
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, bf8_t, Ctrl_>;
@@ -1516,7 +1631,7 @@ using WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_bf8 =
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8 =
WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base<bf8_t, bf8_t, Ctrl_>;
WarpGemmAttributeMfmaImpl<bf8_t, bf8_t, float, 16, 16, 32, Ctrl_>;
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8 =
@@ -1555,6 +1670,130 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 4;
// Accumulating form: c_vec += a_vec * b_vec.
// Builtin argument order:
//   __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a,
//                                                    opsel, scale_b)
// cbsz follows the A element type and blgp follows the B element type: 0 for fp8_t,
// 1 for bf8_t. Every opsel/scale argument is passed as 0.
// When not compiling for gfx950 the operands are only marked as used and c_vec is left
// untouched.
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
                               const AVecType& a_vec,
                               const BVecType& b_vec,
                               bool_constant<post_nop_> = {}) const
{
#if defined(__gfx950__)
    constexpr bool a_is_fp8 = std::is_same_v<ADataType, fp8_t>;
    constexpr bool a_is_bf8 = std::is_same_v<ADataType, bf8_t>;
    constexpr bool b_is_fp8 = std::is_same_v<BDataType, fp8_t>;
    constexpr bool b_is_bf8 = std::is_same_v<BDataType, bf8_t>;
    // Only the four fp8/bf8 pairings issue the instruction; any other pairing compiles
    // to nothing.
    if constexpr((a_is_fp8 || a_is_bf8) && (b_is_fp8 || b_is_bf8))
    {
        constexpr int a_format = a_is_fp8 ? 0 : 1; // cbsz
        constexpr int b_format = b_is_fp8 ? 0 : 1; // blgp
        c_vec                  = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
            a_vec, b_vec, c_vec, a_format, b_format, 0, 0, 0, 0);
    }
#else
    ck_tile::ignore = c_vec;
    ck_tile::ignore = a_vec;
    ck_tile::ignore = b_vec;
#endif
}
// Non-accumulating form: returns a_vec * b_vec computed on a zero-initialised
// accumulator.
// The cbsz/blgp immediates follow the A/B element types (0 for fp8_t, 1 for bf8_t) and
// all opsel/scale arguments are 0. When not compiling for gfx950 a zero vector is
// returned.
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
#if defined(__gfx950__)
    constexpr bool a_is_fp8 = std::is_same_v<ADataType, fp8_t>;
    constexpr bool a_is_bf8 = std::is_same_v<ADataType, bf8_t>;
    constexpr bool b_is_fp8 = std::is_same_v<BDataType, fp8_t>;
    constexpr bool b_is_bf8 = std::is_same_v<BDataType, bf8_t>;
    // Only the four fp8/bf8 pairings produce a return statement here.
    if constexpr((a_is_fp8 || a_is_bf8) && (b_is_fp8 || b_is_bf8))
    {
        constexpr int a_format = a_is_fp8 ? 0 : 1; // cbsz
        constexpr int b_format = b_is_fp8 ? 0 : 1; // blgp
        return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
            a_vec, b_vec, CVecType{0.f}, a_format, b_format, 0, 0, 0, 0));
    }
#else
    ck_tile::ignore = a_vec;
    ck_tile::ignore = b_vec;
    return CVecType{0.f};
#endif
}
};
// Map FP8/BF8 16x16x128 scale shapes to the primary template via specializations
// FP8/BF8 16x16x128 family: one partial specialization of the primary
// WarpGemmAttributeMfmaImpl template per (A, B) element-type pairing. Each body is empty
// and inherits everything from the shared 16x16x128 f8/bf8 base.
// fp8 x fp8
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<fp8_t, fp8_t, float, 16, 16, 128, Ctrl_>
: WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base<fp8_t, fp8_t, Ctrl_>
{
};
// fp8 x bf8
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<fp8_t, bf8_t, float, 16, 16, 128, Ctrl_>
: WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base<fp8_t, bf8_t, Ctrl_>
{
};
// bf8 x fp8
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<bf8_t, fp8_t, float, 16, 16, 128, Ctrl_>
: WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base<bf8_t, fp8_t, Ctrl_>
{
};
// bf8 x bf8
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<bf8_t, bf8_t, float, 16, 16, 128, Ctrl_>
: WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base<bf8_t, bf8_t, Ctrl_>
{
};
// Back-compat aliases
// Each legacy 16x16x128 name forwards to the matching primary-template specialization.
// Ctrl_ defaults to WGAttrCtlEnum::Default_, so the alias can be used without arguments.
// fp8 x fp8
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_fp8 =
WarpGemmAttributeMfmaImpl<fp8_t, fp8_t, float, 16, 16, 128, Ctrl_>;
// fp8 x bf8
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_bf8 =
WarpGemmAttributeMfmaImpl<fp8_t, bf8_t, float, 16, 16, 128, Ctrl_>;
// bf8 x fp8
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_fp8 =
WarpGemmAttributeMfmaImpl<bf8_t, fp8_t, float, 16, 16, 128, Ctrl_>;
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8 =
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base<bf8_t, bf8_t, Ctrl_>;
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = pk_fp4_t;
using BDataType = pk_fp4_t;
using CDataType = float;
using AVecType = ext_vector_t<ADataType, 16>;
using BVecType = ext_vector_t<BDataType, 16>;
using CVecType = ext_vector_t<CDataType, 4>;
static constexpr index_t kM = 16;
static constexpr index_t kN = 16;
static constexpr index_t kK = 128;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 16;
static constexpr index_t kBNLane = 16;
static constexpr index_t kABKLane = 4;
static constexpr index_t kABKPerLane = 32;
static constexpr index_t kCMLane = 4;
static constexpr index_t kCNLane = 16;
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
template <index_t opselA, index_t opselB, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
@@ -1716,23 +1955,49 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base
}
};
// Map FP8/BF8 32x32x64 scale shapes to the primary template via specializations
// FP8/BF8 32x32x64 family: one partial specialization of the primary
// WarpGemmAttributeMfmaImpl template per (A, B) element-type pairing. Each body is empty
// and inherits everything from the shared 32x32x64 f8/bf8 base.
// fp8 x fp8
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<fp8_t, fp8_t, float, 32, 32, 64, Ctrl_>
: WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base<fp8_t, fp8_t, Ctrl_>
{
};
// fp8 x bf8
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<fp8_t, bf8_t, float, 32, 32, 64, Ctrl_>
: WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base<fp8_t, bf8_t, Ctrl_>
{
};
// bf8 x fp8
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<bf8_t, fp8_t, float, 32, 32, 64, Ctrl_>
: WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base<bf8_t, fp8_t, Ctrl_>
{
};
// bf8 x bf8
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<bf8_t, bf8_t, float, 32, 32, 64, Ctrl_>
: WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base<bf8_t, bf8_t, Ctrl_>
{
};
// Back-compat aliases
// Legacy 32x32x64 names. Each alias has a single target, the primary-template
// specialization for its (A, B) pairing; the superseded base-struct right-hand sides are
// removed. Ctrl_ defaults to WGAttrCtlEnum::Default_.
// fp8 x fp8
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8 =
    WarpGemmAttributeMfmaImpl<fp8_t, fp8_t, float, 32, 32, 64, Ctrl_>;
// fp8 x bf8
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8 =
    WarpGemmAttributeMfmaImpl<fp8_t, bf8_t, float, 32, 32, 64, Ctrl_>;
// bf8 x fp8
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8 =
    WarpGemmAttributeMfmaImpl<bf8_t, fp8_t, float, 32, 32, 64, Ctrl_>;
// bf8 x bf8
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8 =
    WarpGemmAttributeMfmaImpl<bf8_t, bf8_t, float, 32, 32, 64, Ctrl_>;
// int8
// int8: map shapes to the primary template via specializations (reuse the concrete impl bodies)
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
{
@@ -1980,6 +2245,31 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x32_i8
}
};
// Primary-template specializations delegating to the bodies above
// int8 x int8 -> int32 family: one partial specialization of the primary
// WarpGemmAttributeMfmaImpl template per supported (M, N, K) shape. Each body is empty
// and inherits everything from the concrete i8 implementation of the same shape.
// 32x32x16
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<int8_t, int8_t, int32_t, 32, 32, 16, Ctrl_>
: WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<Ctrl_>
{
};
// 16x16x32
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<int8_t, int8_t, int32_t, 16, 16, 32, Ctrl_>
: WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<Ctrl_>
{
};
// 16x16x64
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<int8_t, int8_t, int32_t, 16, 16, 64, Ctrl_>
: WarpGemmAttributeMfmaImpl_i32_16x16x64_i8<Ctrl_>
{
};
// 32x32x32
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeMfmaImpl<int8_t, int8_t, int32_t, 32, 32, 32, Ctrl_>
: WarpGemmAttributeMfmaImpl_i32_32x32x32_i8<Ctrl_>
{
};
#undef DISPATCH_MFMA_
} // namespace ck_tile

View File

@@ -81,5 +81,12 @@ struct WarpGemmAttributeSmfmac
{
Impl{}(c_vec, a_vec, b_vec, idx, bool_constant<post_nop_>{});
}
// Value-returning overload: forwards a_vec, b_vec and idx unchanged to a
// default-constructed Impl and returns whatever Impl's three-argument call returns.
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec,
const BVecType& b_vec,
const int32_t& idx) const
{
return Impl{}(a_vec, b_vec, idx);
}
};
} // namespace ck_tile

View File

@@ -7,10 +7,19 @@
namespace ck_tile {
// fp16 2:4 structured sparsity
// Primary template: generic Smfmac warp-gemm attribute (to be specialized per supported shape)
// Template parameters:
//   ADataType / BDataType / CDataType - element types of the A, B and accumulator operands
//   kM / kN / kK                      - shape handled by one instruction
//   Ctrl_                             - attribute control flag, defaults to Default_
// Only declared here: naming a combination that has no specialization yields an
// incomplete type.
template <typename ADataType,
typename BDataType,
typename CDataType,
index_t kM,
index_t kN,
index_t kK,
WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeSmfmacImpl; // no definition, only specializations are provided
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeSmfmacImplF16F16F32M32N32K16
// fp16 2:4 structured sparsity
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeSmfmacImpl<fp16_t, fp16_t, float, 32, 32, 16, Ctrl_>
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = fp16_t;
@@ -56,12 +65,28 @@ struct WarpGemmAttributeSmfmacImplF16F16F32M32N32K16
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = idx;
#endif
}
// c_vec = a_vec * b_vec[idx]
// Value-returning form: issues the 32x32x16 f16 smfmac builtin on a zero accumulator
// and returns the result. In the fallback branch the operands are only marked as used
// and a zero vector is returned.
// NOTE(review): the guard below tests __gfx94_ / __gfx95_ (single trailing underscore).
// Confirm these macros are actually defined by the project configuration; if the
// intended names carry two trailing underscores, this branch is never compiled and the
// function always returns zeros.
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec,
const BVecType& b_vec,
const int32_t& idx) const
{
#if defined(__gfx94_) or defined(__gfx95_)
return bit_cast<CVecType>(
__builtin_amdgcn_smfmac_f32_32x32x16_f16(a_vec, b_vec, fp32x4_t{0.f}, idx, 0, 0));
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = idx;
return CVecType{0.f};
#endif
}
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeSmfmacImplF16F16F32M16N16K32
template <WGAttrCtlEnum Ctrl_>
struct WarpGemmAttributeSmfmacImpl<fp16_t, fp16_t, float, 16, 16, 32, Ctrl_>
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = fp16_t;
@@ -107,8 +132,33 @@ struct WarpGemmAttributeSmfmacImplF16F16F32M16N16K32
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = idx;
#endif
}
// c_vec = a_vec * b_vec[idx]
// Value-returning form: issues the 16x16x32 f16 smfmac builtin on a zero accumulator
// and returns the result. In the fallback branch the operands are only marked as used
// and a zero vector is returned.
// NOTE(review): the guard below tests __gfx94_ / __gfx95_ (single trailing underscore).
// Confirm these macros are actually defined by the project configuration; if the
// intended names carry two trailing underscores, this branch is never compiled and the
// function always returns zeros.
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec,
const BVecType& b_vec,
const int32_t& idx) const
{
#if defined(__gfx94_) or defined(__gfx95_)
return bit_cast<CVecType>(
__builtin_amdgcn_smfmac_f32_16x16x32_f16(a_vec, b_vec, fp32x4_t{0.f}, idx, 0, 0));
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = idx;
return CVecType{0.f};
#endif
}
};
// Back-compat aliases
// The former per-shape struct names are kept as alias templates onto the
// WarpGemmAttributeSmfmacImpl specializations. Ctrl_ defaults to
// WGAttrCtlEnum::Default_.
// fp16 x fp16 -> float, 16x16x32
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeSmfmacImplF16F16F32M16N16K32 =
WarpGemmAttributeSmfmacImpl<fp16_t, fp16_t, float, 16, 16, 32, Ctrl_>;
// fp16 x fp16 -> float, 32x32x16
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeSmfmacImplF16F16F32M32N32K16 =
WarpGemmAttributeSmfmacImpl<fp16_t, fp16_t, float, 32, 32, 16, Ctrl_>;
} // namespace ck_tile

View File

@@ -0,0 +1,271 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
// Add smfmac attribute impl for structured sparsity support
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp"
namespace ck_tile {
// WarpGemmCoreDispatcher: choose the underlying MFMA/SMFMAC attribute Impl
// based on A/B/Acc types and (M,N,K) per wave.
// Template parameters:
//   AType / BType / AccType          - element types of A, B and the accumulator
//   MPerWave / NPerWave / KPerWave   - requested per-wave tile shape
//   UseStructuredSparsity            - true selects an SMFMAC Impl, false an MFMA Impl
// The primary template is only declared; the result is exposed as the nested type
// `Impl` by the partial and explicit specializations.
template <typename AType,
typename BType,
typename AccType,
index_t MPerWave,
index_t NPerWave,
index_t KPerWave,
bool UseStructuredSparsity = false>
struct WarpGemmCoreDispatcher;
// Generic specialization for MFMA
// Dense case (UseStructuredSparsity == false): the request is forwarded one-to-one to
// WarpGemmAttributeMfmaImpl with the same types and (M, N, K) and the Default_ control
// flag. Explicit specializations of WarpGemmCoreDispatcher take precedence over this
// mapping for the combinations they cover.
template <typename AType,
typename BType,
typename AccType,
index_t MPerWave,
index_t NPerWave,
index_t KPerWave>
struct WarpGemmCoreDispatcher<AType, BType, AccType, MPerWave, NPerWave, KPerWave, false>
{
using Impl = WarpGemmAttributeMfmaImpl<AType,
BType,
AccType,
MPerWave,
NPerWave,
KPerWave,
WGAttrCtlEnum::Default_>;
};
// Generic specialization for SMFMAC
// Structured-sparsity case (UseStructuredSparsity == true): the request is forwarded
// one-to-one to WarpGemmAttributeSmfmacImpl with the same types and (M, N, K) and the
// Default_ control flag.
// TODO: we also need to support smfmac for FP8/BF8 and I8 format
template <typename AType,
typename BType,
typename AccType,
index_t MPerWave,
index_t NPerWave,
index_t KPerWave>
struct WarpGemmCoreDispatcher<AType, BType, AccType, MPerWave, NPerWave, KPerWave, true>
{
using Impl = WarpGemmAttributeSmfmacImpl<AType,
BType,
AccType,
MPerWave,
NPerWave,
KPerWave,
WGAttrCtlEnum::Default_>;
};
// Specialization for special cases
// fp16, dense, 32x32x16 request.
// gfx950: selects the 32x32x16 Impl directly.
// Other targets: selects the 32x32x8 Impl; the requested K of 16 is presumably covered
// by K iteration in the composing layer - TODO confirm.
template <>
struct WarpGemmCoreDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false>
{
#if defined(__gfx950__)
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::half_t,
ck_tile::half_t,
float,
32,
32,
16,
WGAttrCtlEnum::Default_>;
#else
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::half_t,
ck_tile::half_t,
float,
32,
32,
8,
WGAttrCtlEnum::Default_>;
#endif
};
// fp16, dense, 16x16x32 request.
// gfx950: selects the 16x16x32 Impl directly.
// Other targets: selects the 16x16x16 Impl; the requested K of 32 is presumably covered
// by K iteration in the composing layer - TODO confirm.
template <>
struct WarpGemmCoreDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false>
{
#if defined(__gfx950__)
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::half_t,
ck_tile::half_t,
float,
16,
16,
32,
WGAttrCtlEnum::Default_>;
#else
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::half_t,
ck_tile::half_t,
float,
16,
16,
16,
WGAttrCtlEnum::Default_>;
#endif
};
// fp16, dense, 4x64x16 request: selects the 4x64x4 Impl on every target (the Impl's K
// is a quarter of the requested K).
template <>
struct WarpGemmCoreDispatcher<ck_tile::half_t, ck_tile::half_t, float, 4, 64, 16, false>
{
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::half_t,
ck_tile::half_t,
float,
4,
64,
4,
WGAttrCtlEnum::Default_>;
};
// fp16, dense, 64x4x16 request: selects the 64x4x4 Impl on every target (the Impl's K
// is a quarter of the requested K).
template <>
struct WarpGemmCoreDispatcher<ck_tile::half_t, ck_tile::half_t, float, 64, 4, 16, false>
{
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::half_t,
ck_tile::half_t,
float,
64,
4,
4,
WGAttrCtlEnum::Default_>;
};
// fp8 x fp8, dense, 32x32x32 request: selects the 32x32x16 Impl (half the requested K).
template <>
struct WarpGemmCoreDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 32, false>
{
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::fp8_t,
ck_tile::fp8_t,
float,
32,
32,
16,
WGAttrCtlEnum::Default_>;
};
// bf8 x bf8, dense, 32x32x32 request: selects the 32x32x16 Impl (half the requested K).
template <>
struct WarpGemmCoreDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 32, false>
{
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::bf8_t,
ck_tile::bf8_t,
float,
32,
32,
16,
WGAttrCtlEnum::Default_>;
};
// fp8 x fp8, dense, 16x16x64 request: selects the 16x16x32 Impl (half the requested K).
template <>
struct WarpGemmCoreDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 64, false>
{
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::fp8_t,
ck_tile::fp8_t,
float,
16,
16,
32,
WGAttrCtlEnum::Default_>;
};
// bf8 x bf8, dense, 16x16x64 request: selects the 16x16x32 Impl (half the requested K).
template <>
struct WarpGemmCoreDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 64, false>
{
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::bf8_t,
ck_tile::bf8_t,
float,
16,
16,
32,
WGAttrCtlEnum::Default_>;
};
// bf16, dense, 32x32x16 request.
// gfx950: selects the 32x32x16 Impl directly.
// Other targets: selects the 32x32x8 Impl; the requested K of 16 is presumably covered
// by K iteration in the composing layer - TODO confirm.
template <>
struct WarpGemmCoreDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false>
{
#if defined(__gfx950__)
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::bf16_t,
ck_tile::bf16_t,
float,
32,
32,
16,
WGAttrCtlEnum::Default_>;
#else
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::bf16_t,
ck_tile::bf16_t,
float,
32,
32,
8,
WGAttrCtlEnum::Default_>;
#endif
};
// bf16, dense, 16x16x32 request.
// gfx950: selects the 16x16x32 Impl directly.
// Other targets: selects the 16x16x16 Impl; the requested K of 32 is presumably covered
// by K iteration in the composing layer - TODO confirm.
template <>
struct WarpGemmCoreDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false>
{
#if defined(__gfx950__)
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::bf16_t,
ck_tile::bf16_t,
float,
16,
16,
32,
WGAttrCtlEnum::Default_>;
#else
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::bf16_t,
ck_tile::bf16_t,
float,
16,
16,
16,
WGAttrCtlEnum::Default_>;
#endif
};
// bf16, dense, 4x64x16 request: selects the 4x64x4 Impl on every target (the Impl's K
// is a quarter of the requested K).
template <>
struct WarpGemmCoreDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 4, 64, 16, false>
{
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::bf16_t,
ck_tile::bf16_t,
float,
4,
64,
4,
WGAttrCtlEnum::Default_>;
};
// bf16, dense, 64x4x16 request: selects the 64x4x4 Impl on every target (the Impl's K
// is a quarter of the requested K).
template <>
struct WarpGemmCoreDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 64, 4, 16, false>
{
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::bf16_t,
ck_tile::bf16_t,
float,
64,
4,
4,
WGAttrCtlEnum::Default_>;
};
// Iterate-K variants built from the direct impls
// int8 x int8 -> int32, dense, 32x32x32 request: selects the 32x32x16 Impl (half the
// requested K) rather than the generic one-to-one mapping.
template <>
struct WarpGemmCoreDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 32, false>
{
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::int8_t,
ck_tile::int8_t,
ck_tile::int32_t,
32,
32,
16,
WGAttrCtlEnum::Default_>;
};
// int8 x int8 -> int32, dense, 16x16x64 request: selects the 16x16x32 Impl (half the
// requested K) rather than the generic one-to-one mapping.
template <>
struct WarpGemmCoreDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 64, false>
{
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::int8_t,
ck_tile::int8_t,
ck_tile::int32_t,
16,
16,
32,
WGAttrCtlEnum::Default_>;
};
} // namespace ck_tile

View File

@@ -6,7 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp"
#include "ck_tile/ops/gemm/warp/detail/warp_gemm_attribute_mfma_compose.hpp"
namespace ck_tile {
namespace impl {
@@ -163,16 +163,30 @@ template <typename AType,
bool SwizzleA = false,
bool UseStructuredSparsity = false,
WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
// WarpGemmDispatcher: public alias that resolves to the warp-GEMM type for the given
// operand types, per-wave shape and policy flags.
//  - With CK_TILE_ROUTE_WARP_GEMM_DISPATCHER_TO_MAKE defined, the request is routed to
//    MakeWarpGemm<...>::Type. Note the different parameter order: TransposeC and
//    SwizzleA come first.
//  - Otherwise the table-based impl::warp_gemm_dispatcher::Dispatcher is used.
// Exactly one alias-declaration is emitted per configuration, so the alias is never
// defined twice and no branch contains a nested `using`.
#if defined(CK_TILE_ROUTE_WARP_GEMM_DISPATCHER_TO_MAKE)
using WarpGemmDispatcher = typename MakeWarpGemm<TransposeC,
                                                 SwizzleA,
                                                 AType,
                                                 BType,
                                                 AccType,
                                                 MPerWave,
                                                 NPerWave,
                                                 KPerWave,
                                                 UseStructuredSparsity,
                                                 AttrNumAccess>::Type;
#else
using WarpGemmDispatcher = typename impl::warp_gemm_dispatcher::Dispatcher< //
    AType,
    BType,
    AccType,
    MPerWave,
    NPerWave,
    KPerWave,
    TransposeC,
    SwizzleA,
    UseStructuredSparsity,
    AttrNumAccess>::Type;
#endif
} // namespace ck_tile

View File

@@ -105,6 +105,38 @@ struct WarpGemmSmfmacImpl
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
}
// Value-returning warp GEMM with structured sparsity: computes C = A * B for one warp
// tile and returns a freshly created CWarpTensor.
// a and b must be distributed like AWarpTensor / BWarpTensor (checked by the
// static_assert below).
template <typename ATensor, typename BTensor>
CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const
{
using CTensor = CWarpTensor;
static_assert(detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
constexpr auto CompressionRatio = WarpGemmAttribute::kCompressionRatio;
// Each thread's whole buffer is viewed as a single vector. AVecCompressed is the A
// vector shortened by the compression ratio.
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
using AVecCompressed =
ext_vector_t<ADataType, ATensor::get_thread_buffer_size() / CompressionRatio>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
constexpr auto I0 = number<0>{};
// a_vec is a mutable copy: presumably compress_a compacts it in place and returns the
// sparsity index - TODO confirm against compress_a.
auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
const int32_t idx = compress_a(a_vec);
// @TODO can we simply set a_vec_pruned to a_vec[0:3]?
// NOTE(review): exactly four elements are copied here, while AVecCompressed's length is
// derived from the thread buffer size and CompressionRatio - confirm the two always agree.
const AVecCompressed a_vec_pruned = {a_vec[0], a_vec[1], a_vec[2], a_vec[3]};
// c_vec = a_vec * b_vec[idx]
auto c_vec = WarpGemmAttribute{}(a_vec_pruned, b_vec, idx);
CTensor c;
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
return c;
}
};
} // namespace ck_tile

View File

@@ -6,6 +6,7 @@ set(EXAMPLE_GEMM_COMPILE_OPTIONS)
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
option(CK_TILE_TEST_ROUTE_DISPATCHER_TO_MAKE "Route ck_tile::WarpGemmDispatcher to MakeWarpGemm in tests" OFF)
set(EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS)
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS -DCK_TILE_USE_OCP_FP8)
@@ -16,6 +17,14 @@ list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS
)
set(EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
# Build routed variants of the common options when the switch is ON
# The snapshot is taken BEFORE the routing define is appended, so targets that use
# EXAMPLE_GEMM_COMPILE_OPTIONS_WITHOUT_ROUTED never see
# CK_TILE_ROUTE_WARP_GEMM_DISPATCHER_TO_MAKE, whatever the switch is set to.
# NOTE(review): EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS is copied from the V4 options
# above this point, so it does not receive the define appended here - confirm that is
# intended.
set(EXAMPLE_GEMM_COMPILE_OPTIONS_WITHOUT_ROUTED ${EXAMPLE_GEMM_COMPILE_OPTIONS})
if(CK_TILE_TEST_ROUTE_DISPATCHER_TO_MAKE)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_ROUTE_WARP_GEMM_DISPATCHER_TO_MAKE)
list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS -DCK_TILE_ROUTE_WARP_GEMM_DISPATCHER_TO_MAKE)
endif()
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12")
if(GPU_TARGETS MATCHES "gfx94|gfx95")
add_gtest_executable(test_ck_tile_gemm_pipeline_mem test_gemm_pipeline_mem.cpp)
@@ -59,3 +68,38 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12")
else()
message(DEBUG "Skipping ck_tile_gemm tests for current target test_ck_tile_gemm_pipeline")
endif()
# Attribute-compose tests. Every target below is compiled with the *_WITHOUT_ROUTED
# options, i.e. without CK_TILE_ROUTE_WARP_GEMM_DISPATCHER_TO_MAKE, so that
# WarpGemmDispatcher and MakeWarpGemm stay two distinct code paths inside these tests.
# MFMA (Attribute Compose) - STRICTLY CDNA (gfx90a, gfx94x, gfx950)
# int8
if(GPU_TARGETS MATCHES "gfx90a|gfx94|gfx95")
add_gtest_executable(test_ck_tile_gemm_attr_compose_int8_mfma test_warp_gemm_attr_compose_int8_mfma.cpp)
target_compile_options(test_ck_tile_gemm_attr_compose_int8_mfma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS_WITHOUT_ROUTED})
endif()
# MFMA (Attribute Compose) - STRICTLY CDNA3+ (gfx94x, gfx950)
# fp8, bf8 and mixed fp8/bf8
if(GPU_TARGETS MATCHES "gfx94|gfx95")
add_gtest_executable(test_ck_tile_gemm_attr_compose_fp8_mfma test_warp_gemm_attr_compose_fp8_mfma.cpp)
target_compile_options(test_ck_tile_gemm_attr_compose_fp8_mfma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS_WITHOUT_ROUTED})
add_gtest_executable(test_ck_tile_gemm_attr_compose_bf8_mfma test_warp_gemm_attr_compose_bf8_mfma.cpp)
target_compile_options(test_ck_tile_gemm_attr_compose_bf8_mfma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS_WITHOUT_ROUTED})
add_gtest_executable(test_ck_tile_gemm_attr_compose_mixed_fp8_bf8_mfma test_warp_gemm_attr_compose_mixed_fp8_bf8_mfma.cpp)
target_compile_options(test_ck_tile_gemm_attr_compose_mixed_fp8_bf8_mfma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS_WITHOUT_ROUTED})
endif()
# MFMA (Attribute Compose) - STRICTLY CDNA (gfx90a, gfx94x, gfx950)
# fp16 and bf16
if(GPU_TARGETS MATCHES "gfx90a|gfx94|gfx95")
add_gtest_executable(test_ck_tile_gemm_attr_compose_fp16_mfma test_warp_gemm_attr_compose_fp16_mfma.cpp)
target_compile_options(test_ck_tile_gemm_attr_compose_fp16_mfma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS_WITHOUT_ROUTED})
add_gtest_executable(test_ck_tile_gemm_attr_compose_bf16_mfma test_warp_gemm_attr_compose_bf16_mfma.cpp)
target_compile_options(test_ck_tile_gemm_attr_compose_bf16_mfma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS_WITHOUT_ROUTED})
endif()
# --- SMFMAC Tests ---
# SMFMAC (Attribute Compose) - STRICTLY CDNA3+ (gfx94x, gfx950)
# fp16 structured sparsity
if(GPU_TARGETS MATCHES "gfx94|gfx95")
add_gtest_executable(test_ck_tile_gemm_attr_compose_fp16_smfmac test_warp_gemm_attr_compose_fp16_smfmac.cpp)
target_compile_options(test_ck_tile_gemm_attr_compose_fp16_smfmac PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS_WITHOUT_ROUTED})
endif()

View File

@@ -0,0 +1,301 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_TILE_TEST_WARP_GEMM_ATTR_COMPOSE_HPP
#define CK_TILE_TEST_WARP_GEMM_ATTR_COMPOSE_HPP
#include <gtest/gtest.h>
#include <hip/hip_runtime.h>
#include <algorithm>
#include "ck_tile/ops/gemm/warp/detail/warp_gemm_attribute_mfma_compose.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
// For real kernel run and verification
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
// ---------------------- Typed tests for Dispatcher coverage ----------------------
// WGDispCase: compile-time description of one warp-GEMM configuration under test.
// It only re-exports its template arguments as member types and constants so the test
// fixtures can read them through a single `Case` type parameter.
//   A / B / Acc            - element types of A, B and the accumulator
//   M / N / K              - per-wave tile shape
//   TransposeC             - required; the remaining policy flags are optional
//   SwizzleA               - defaults to false
//   UseStructuredSparsity  - defaults to false
//   NA                     - number-of-access policy, defaults to Single
template <typename A,
typename B,
typename Acc,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
bool TransposeC,
bool SwizzleA = false,
bool UseStructuredSparsity = false,
ck_tile::WGAttrNumAccessEnum NA = ck_tile::WGAttrNumAccessEnum::Single>
struct WGDispCase
{
using AType = A;
using BType = B;
using AccType = Acc;
static constexpr ck_tile::index_t MPerWave = M;
static constexpr ck_tile::index_t NPerWave = N;
static constexpr ck_tile::index_t KPerWave = K;
static constexpr bool kTransposeC = TransposeC;
static constexpr bool kSwizzleA = SwizzleA;
static constexpr bool kUSS = UseStructuredSparsity;
static constexpr ck_tile::WGAttrNumAccessEnum kNA = NA;
};
// WGCompileTimeTest: typed fixture that checks, purely at compile time, that
// WarpGemmDispatcher and MakeWarpGemm resolve to equivalent warp-GEMM types for the
// configuration described by `Case`. Any mismatch is a build failure from one of the
// static_asserts; at run time the test only records success.
template <typename Case>
class WGCompileTimeTest : public ::testing::Test
{
public:
// Instantiates both selection paths for `Case` and compares their constants, data
// types, distribution encodings and warp tensor types.
void RunTest()
{
// Dispatcher-selected WarpGemm
using Disp = typename ck_tile::WarpGemmDispatcher<typename Case::AType,
typename Case::BType,
typename Case::AccType,
Case::MPerWave,
Case::NPerWave,
Case::KPerWave,
Case::kTransposeC,
Case::kSwizzleA,
Case::kUSS,
Case::kNA>;
// Factory-selected WarpGemm (MakeWarpGemm)
// Note the different argument order: TransposeC and SwizzleA come first.
using Make = typename ck_tile::MakeWarpGemm<Case::kTransposeC,
Case::kSwizzleA,
typename Case::AType,
typename Case::BType,
typename Case::AccType,
Case::MPerWave,
Case::NPerWave,
Case::KPerWave,
Case::kUSS,
Case::kNA>::Type;
// 1) Scalar compile-time constants must match
static_assert(Disp::kM == Make::kM, "kM differs between Dispatcher and MakeWarpGemm");
static_assert(Disp::kN == Make::kN, "kN differs between Dispatcher and MakeWarpGemm");
static_assert(Disp::kK == Make::kK, "kK differs between Dispatcher and MakeWarpGemm");
static_assert(Disp::kKPerThread == Make::kKPerThread,
"kKPerThread differs between Dispatcher and MakeWarpGemm");
static_assert(Disp::get_num_of_access() == Make::get_num_of_access(),
"get_num_of_access() differs between Dispatcher and MakeWarpGemm");
// 2) Data types must match
static_assert(std::is_same_v<typename Disp::ADataType, typename Make::ADataType>,
"ADataType differs between Dispatcher and MakeWarpGemm");
static_assert(std::is_same_v<typename Disp::BDataType, typename Make::BDataType>,
"BDataType differs between Dispatcher and MakeWarpGemm");
static_assert(std::is_same_v<typename Disp::CDataType, typename Make::CDataType>,
"CDataType differs between Dispatcher and MakeWarpGemm");
// 3) Distribution encodings must match (ensures identical warp tiling/layout)
static_assert(
std::is_same_v<typename Disp::AWarpDstrEncoding, typename Make::AWarpDstrEncoding>,
"AWarpDstrEncoding differs between Dispatcher and MakeWarpGemm");
static_assert(
std::is_same_v<typename Disp::BWarpDstrEncoding, typename Make::BWarpDstrEncoding>,
"BWarpDstrEncoding differs between Dispatcher and MakeWarpGemm");
static_assert(
std::is_same_v<typename Disp::CWarpDstrEncoding, typename Make::CWarpDstrEncoding>,
"CWarpDstrEncoding differs between Dispatcher and MakeWarpGemm");
// 4) Final tensor types must match (encodes DataType + Distribution)
static_assert(std::is_same_v<typename Disp::AWarpTensor, typename Make::AWarpTensor>,
"AWarpTensor differs between Dispatcher and MakeWarpGemm");
static_assert(std::is_same_v<typename Disp::BWarpTensor, typename Make::BWarpTensor>,
"BWarpTensor differs between Dispatcher and MakeWarpGemm");
static_assert(std::is_same_v<typename Disp::CWarpTensor, typename Make::CWarpTensor>,
"CWarpTensor differs between Dispatcher and MakeWarpGemm");
// Reaching this point means every compile-time check passed.
SUCCEED();
}
};
// ---------------------- Runtime tests: Compare Dispatcher (MFMA and SMFMA only) vs MakeWarpGemm vs
// CPU ----------------------
// ---------------------- Runtime operator() behavior tests on GPU ----------------------
// WarpGemmKernel: device functor that runs one warp-level GEMM, C = A * B, on a single
// tile and writes the result to global memory.
// UseMakeWarpGemm selects which factory provides the warp-GEMM type: MakeWarpGemm when
// true, WarpGemmDispatcher when false. The remaining template arguments are forwarded
// to that factory.
template <bool UseMakeWarpGemm,
typename AType,
typename BType,
typename CType,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
bool TransposeC,
bool SwizzleA,
bool UseStructuredSparsity,
ck_tile::WGAttrNumAccessEnum NumAccess>
struct WarpGemmKernel
{
// Threads per block; the host-side launcher in this file uses the same value (64).
static constexpr int kBlockSize = 64;
// A: M x K row-major, B: K x N row-major, C: M x N row-major (see the view setup
// below).
__device__ void operator()(const AType* A, const BType* B, CType* C) const
{
using WarpGemm = std::conditional_t<UseMakeWarpGemm,
typename ck_tile::MakeWarpGemm<TransposeC,
SwizzleA,
AType,
BType,
CType,
M,
N,
K,
UseStructuredSparsity,
NumAccess>::Type,
ck_tile::WarpGemmDispatcher<AType,
BType,
CType,
M,
N,
K,
TransposeC,
SwizzleA,
UseStructuredSparsity,
NumAccess>>;
// A: [M,K] row-major (packed)
const auto a_view =
ck_tile::make_naive_tensor_view_packed<ck_tile::address_space_enum::global>(
const_cast<AType*>(A), ck_tile::make_tuple(M, K));
// B: expose as logical [N,K] with strides (1, N) over the original row-major [K,N] buffer
const auto b_view = ck_tile::make_naive_tensor_view<ck_tile::address_space_enum::global>(
const_cast<BType*>(B), ck_tile::make_tuple(N, K), ck_tile::make_tuple(1, N));
// C: [M,N] row-major (packed)
// NOTE(review): C is already a non-const CType*, so this const_cast looks redundant.
const auto c_view =
ck_tile::make_naive_tensor_view_packed<ck_tile::address_space_enum::global>(
const_cast<CType*>(C), ck_tile::make_tuple(M, N));
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
// Window extents are taken from each warp tensor's own tile distribution.
constexpr auto a_len = AWarpTensor::get_tile_distribution().get_lengths();
constexpr auto b_len = BWarpTensor::get_tile_distribution().get_lengths();
constexpr auto c_len = CWarpTensor::get_tile_distribution().get_lengths();
// All three windows start at the origin (0, 0) of their view.
auto a_win = ck_tile::make_tile_window(
a_view, a_len, ck_tile::make_multi_index(0, 0), AWarpTensor::get_tile_distribution());
auto b_win = ck_tile::make_tile_window(
b_view, b_len, ck_tile::make_multi_index(0, 0), BWarpTensor::get_tile_distribution());
auto c_win = ck_tile::make_tile_window(
c_view, c_len, ck_tile::make_multi_index(0, 0), CWarpTensor::get_tile_distribution());
AWarpTensor a_tile;
BWarpTensor b_tile;
ck_tile::load_tile(a_tile, a_win);
ck_tile::load_tile(b_tile, b_win);
// Uses the value-returning call of the warp GEMM (c = a * b), then stores the tile.
CWarpTensor c_tile;
c_tile = WarpGemm{}(a_tile, b_tile);
ck_tile::store_tile(c_win, c_tile);
}
};
// ---------------------- Runtime helper ----------------------
// Runs the warp GEMM described by `Case` once on the device and copies the result back.
//   Case            - configuration (element types, per-wave shape, policy flags)
//   UseMakeWarpGemm - true: type comes from MakeWarpGemm; false: from WarpGemmDispatcher
//   A, B            - host inputs, uploaded unchanged
//   C               - host output; the device buffer is zeroed before the launch and
//                     copied into C afterwards
// A single block of 64 threads is launched. The value returned by launch_kernel is
// discarded.
template <typename Case, bool UseMakeWarpGemm>
static void RunWarpGemmCase(const ck_tile::HostTensor<typename Case::AType>& A,
                            const ck_tile::HostTensor<typename Case::BType>& B,
                            ck_tile::HostTensor<typename Case::AccType>& C)
{
    using AType = typename Case::AType;
    using BType = typename Case::BType;
    using CType = typename Case::AccType; // CDataType equals Acc for these tests

    using Kernel = WarpGemmKernel<UseMakeWarpGemm,
                                  AType,
                                  BType,
                                  CType,
                                  Case::MPerWave,
                                  Case::NPerWave,
                                  Case::KPerWave,
                                  Case::kTransposeC,
                                  Case::kSwizzleA,
                                  Case::kUSS,
                                  Case::kNA>;

    // Device buffers sized from the host tensors.
    ck_tile::DeviceMem a_dev(A.get_element_space_size_in_bytes());
    ck_tile::DeviceMem b_dev(B.get_element_space_size_in_bytes());
    ck_tile::DeviceMem c_dev(C.get_element_space_size_in_bytes());

    a_dev.ToDevice(A.data());
    b_dev.ToDevice(B.data());
    c_dev.SetZero();

    const auto* a_ptr = static_cast<const AType*>(a_dev.GetDeviceBuffer());
    const auto* b_ptr = static_cast<const BType*>(b_dev.GetDeviceBuffer());
    auto* c_ptr       = static_cast<CType*>(c_dev.GetDeviceBuffer());

    // One block; Kernel::kBlockSize is 64.
    const dim3 grid_dim(1);
    const dim3 block_dim(Kernel::kBlockSize);

    (void)ck_tile::launch_kernel(
        ck_tile::stream_config{nullptr, true},
        ck_tile::make_kernel(Kernel{}, grid_dim, block_dim, 0, a_ptr, b_ptr, c_ptr));

    c_dev.FromDevice(C.mData.data());
}
// Imposes a fixed 2:4 sparsity pattern on A in place, for the SMFMA runtime cases.
// A is indexed as A(row, k). Within every complete group of four consecutive k values,
// the entries at offsets 1 and 3 are overwritten with zero and those at offsets 0 and 2
// are left untouched. A trailing partial group (when the K extent is not a multiple of
// four) is not modified.
template <typename AType>
static inline void make_2to4_sparse_A(ck_tile::HostTensor<AType>& A)
{
    const ck_tile::index_t num_rows = A.mDesc.get_lengths()[0];
    const ck_tile::index_t k_extent = A.mDesc.get_lengths()[1];
    // Number of complete groups of four along K.
    const ck_tile::index_t num_groups = k_extent / 4;
    const AType zero                  = ck_tile::type_convert<AType>(0);

    for(ck_tile::index_t row = 0; row < num_rows; ++row)
    {
        for(ck_tile::index_t group = 0; group < num_groups; ++group)
        {
            const ck_tile::index_t base = group * 4;
            A(row, base + 1)            = zero;
            A(row, base + 3)            = zero;
        }
    }
}
// WGRuntimeTest: typed fixture that runs the same inputs through the
// WarpGemmDispatcher-selected and the MakeWarpGemm-selected warp GEMM on the device and
// requires the two outputs to match exactly (both tolerances passed to check_err are 0).
template <typename Case>
class WGRuntimeTest : public ::testing::Test
{
    public:
    // Fills A (M x K) and B (K x N) with deterministic values, makes A 2:4 sparse for
    // structured-sparsity cases, runs both variants and compares the results.
    void RunTest()
    {
        using AType = typename Case::AType;
        using BType = typename Case::BType;
        using CType = typename Case::AccType;

        constexpr ck_tile::index_t M = Case::MPerWave;
        constexpr ck_tile::index_t N = Case::NPerWave;
        constexpr ck_tile::index_t K = Case::KPerWave;

        ck_tile::HostTensor<AType> A({M, K});
        ck_tile::HostTensor<BType> B({K, N});
        ck_tile::HostTensor<CType> C_disp({M, N});
        ck_tile::HostTensor<CType> C_make({M, N});

        // A(row, col) depends only on its coordinates.
        for(ck_tile::index_t row = 0; row < M; ++row)
        {
            for(ck_tile::index_t col = 0; col < K; ++col)
            {
                A(row, col) = ck_tile::type_convert<AType>((row + 1) * 0.1f + (col + 1) * 0.01f);
            }
        }

        // B(row, col) likewise.
        for(ck_tile::index_t row = 0; row < K; ++row)
        {
            for(ck_tile::index_t col = 0; col < N; ++col)
            {
                B(row, col) = ck_tile::type_convert<BType>((row + 1) * 0.2f - (col + 1) * 0.03f);
            }
        }

        if constexpr(Case::kUSS)
        {
            // ensure A satisfies 2:4 sparsity for SMFMA
            make_2to4_sparse_A(A);
        }

        C_disp.SetZero();
        C_make.SetZero();

        RunWarpGemmCase<Case, /*UseMakeWarpGemm=*/false>(A, B, C_disp);
        RunWarpGemmCase<Case, /*UseMakeWarpGemm=*/true>(A, B, C_make);

        const bool outputs_match =
            ck_tile::check_err(C_disp, C_make, "Dispatcher vs MakeWarpGemm mismatch", 0, 0);
        EXPECT_TRUE(outputs_match);
    }
};
#endif // CK_TILE_TEST_WARP_GEMM_ATTR_COMPOSE_HPP

View File

@@ -0,0 +1,35 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_warp_gemm_attr_compose.hpp"
// Case list for this translation unit: every entry uses ck_tile::bf16_t for the first
// two type arguments and float for the third.
// NOTE(review): the argument order looks like <AType, BType, AccType, MPerWave, NPerWave,
// KPerWave, boolean policy flags..., WGAttrNumAccessEnum> -- confirm the meaning of each
// boolean against the WGDispCase definition in test_warp_gemm_attr_compose.hpp.
using WGDispatcherTypesList = ::testing::Types<
// clang-format off
// bf16
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false>,
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true>,
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false>,
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true>,
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, false, false, ck_tile::WGAttrNumAccessEnum::Double>,
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true, false, false, ck_tile::WGAttrNumAccessEnum::Double>,
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false>,
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true>,
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false, false, false, ck_tile::WGAttrNumAccessEnum::Double>,
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true, false, false, ck_tile::WGAttrNumAccessEnum::Double>,
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 4, 64, 16, false>,
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 64, 4, 16, false>,
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, false>,
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, true>,
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true>,
// NOTE(review): the <32, 32, 16, false, true> case below is disabled; the reason is not recorded here.
// WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true>,
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true, true>,
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true, true>>;
// clang-format on
// Bind both typed fixtures to the case list; each test body only calls the fixture's RunTest().
TYPED_TEST_SUITE(WGCompileTimeTest, WGDispatcherTypesList);
TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList);
TYPED_TEST(WGCompileTimeTest, Instantiate) { this->RunTest(); }
TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG) { this->RunTest(); }

View File

@@ -0,0 +1,27 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_warp_gemm_attr_compose.hpp"
// Case list for this translation unit: every entry uses ck_tile::bf8_t for the first
// two type arguments and float for the third.
// NOTE(review): the argument order looks like <AType, BType, AccType, MPerWave, NPerWave,
// KPerWave, boolean policy flags..., WGAttrNumAccessEnum> -- confirm the meaning of each
// boolean against the WGDispCase definition in test_warp_gemm_attr_compose.hpp.
using WGDispatcherTypesList = ::testing::Types<
// clang-format off
// bf8
WGDispCase<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, false>,
WGDispCase<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 32, false>,
WGDispCase<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 32, false>,
WGDispCase<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 32, true>,
WGDispCase<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 64, false>,
WGDispCase<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, true>,
WGDispCase<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 128, false>,
WGDispCase<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false>,
// The last two cases repeat the two preceding shapes with WGAttrNumAccessEnum::Quad.
WGDispCase<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false, false, false, ck_tile::WGAttrNumAccessEnum::Quad>,
WGDispCase<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 128, false, false, false, ck_tile::WGAttrNumAccessEnum::Quad>>;
// clang-format on
// Bind both typed fixtures to the case list; each test body only calls the fixture's RunTest().
TYPED_TEST_SUITE(WGCompileTimeTest, WGDispatcherTypesList);
TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList);
TYPED_TEST(WGCompileTimeTest, Instantiate) { this->RunTest(); }
TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG) { this->RunTest(); }

View File

@@ -0,0 +1,35 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_warp_gemm_attr_compose.hpp"
// Case list for this translation unit: every entry uses ck_tile::half_t for the first
// two type arguments and float for the third.
// NOTE(review): the argument order looks like <AType, BType, AccType, MPerWave, NPerWave,
// KPerWave, boolean policy flags..., WGAttrNumAccessEnum> -- confirm the meaning of each
// boolean against the WGDispCase definition in test_warp_gemm_attr_compose.hpp.
using WGDispatcherTypesList = ::testing::Types<
// clang-format off
// fp16
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false>,
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true>,
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false>,
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true>,
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, false, false, ck_tile::WGAttrNumAccessEnum::Double>,
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true, false, false, ck_tile::WGAttrNumAccessEnum::Double>,
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false>,
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true>,
WGDispCase<ck_tile::half_t, ck_tile::half_t, float,16, 16, 32, false, false, false, ck_tile::WGAttrNumAccessEnum::Double>,
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true, false, false, ck_tile::WGAttrNumAccessEnum::Double>,
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 4, 64, 16, false>,
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 64, 4, 16, false>,
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, false>,
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true>,
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true>,
// NOTE(review): the <32, 32, 16, false, true> case below is disabled; the reason is not recorded here.
// WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true>,
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true, true>,
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true, true>>;
// clang-format on
// Bind both typed fixtures to the case list; each test body only calls the fixture's RunTest().
TYPED_TEST_SUITE(WGCompileTimeTest, WGDispatcherTypesList);
TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList);
TYPED_TEST(WGCompileTimeTest, Instantiate) { this->RunTest(); }
TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG) { this->RunTest(); }

View File

@@ -0,0 +1,19 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_warp_gemm_attr_compose.hpp"
// Case list for this translation unit: the half_t cases labelled "2:4 structural sparsity".
// NOTE(review): the third boolean (true in both entries) is presumably the
// structured-sparsity flag that WGRuntimeTest reads as Case::kUSS -- confirm against the
// WGDispCase definition in test_warp_gemm_attr_compose.hpp.
using WGDispatcherTypesList = ::testing::Types<
// clang-format off
// fp16 2:4 structural sparsity
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, false, true>,
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false, false, true>>;
// clang-format on
// Bind both typed fixtures to the case list; each test body only calls the fixture's RunTest().
TYPED_TEST_SUITE(WGCompileTimeTest, WGDispatcherTypesList);
TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList);
TYPED_TEST(WGCompileTimeTest, Instantiate) { this->RunTest(); }
TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG) { this->RunTest(); }

View File

@@ -0,0 +1,27 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_warp_gemm_attr_compose.hpp"
// Case list for this translation unit: every entry uses ck_tile::fp8_t for the first
// two type arguments and float for the third.
// NOTE(review): the argument order looks like <AType, BType, AccType, MPerWave, NPerWave,
// KPerWave, boolean policy flags..., WGAttrNumAccessEnum> -- confirm the meaning of each
// boolean against the WGDispCase definition in test_warp_gemm_attr_compose.hpp.
using WGDispatcherTypesList = ::testing::Types<
// clang-format off
// fp8
WGDispCase<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false>,
WGDispCase<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 32, false>,
WGDispCase<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 32, false>,
WGDispCase<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 64, false>,
WGDispCase<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, true>,
WGDispCase<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 32, true>,
WGDispCase<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 128, false>,
WGDispCase<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 64, false>,
// The last two cases repeat the two preceding shapes with WGAttrNumAccessEnum::Quad.
WGDispCase<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 64, false, false, false, ck_tile::WGAttrNumAccessEnum::Quad>,
WGDispCase<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 128, false, false, false, ck_tile::WGAttrNumAccessEnum::Quad>>;
// clang-format on
// Bind both typed fixtures to the case list; each test body only calls the fixture's RunTest().
TYPED_TEST_SUITE(WGCompileTimeTest, WGDispatcherTypesList);
TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList);
TYPED_TEST(WGCompileTimeTest, Instantiate) { this->RunTest(); }
TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG) { this->RunTest(); }

View File

@@ -0,0 +1,21 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_warp_gemm_attr_compose.hpp"
// Case list for this translation unit: every entry uses ck_tile::int8_t for the first
// two type arguments and ck_tile::int32_t for the third.
// NOTE(review): the argument order looks like <AType, BType, AccType, MPerWave, NPerWave,
// KPerWave, boolean policy flag> -- confirm the meaning of the boolean against the
// WGDispCase definition in test_warp_gemm_attr_compose.hpp.
using WGDispatcherTypesList = ::testing::Types<
// clang-format off
// int8
WGDispCase<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, false>,
WGDispCase<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, true>,
WGDispCase<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, false>,
WGDispCase<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, true>>;
// clang-format on
// Bind both typed fixtures to the case list; each test body only calls the fixture's RunTest().
TYPED_TEST_SUITE(WGCompileTimeTest, WGDispatcherTypesList);
TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList);
TYPED_TEST(WGCompileTimeTest, Instantiate) { this->RunTest(); }
TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG) { this->RunTest(); }

View File

@@ -0,0 +1,29 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_warp_gemm_attr_compose.hpp"
// Case list for this translation unit: mixed 8-bit float inputs. Each shape appears in
// both orders of the first two type arguments (fp8_t/bf8_t and bf8_t/fp8_t), with float
// as the third.
// NOTE(review): the argument order looks like <AType, BType, AccType, MPerWave, NPerWave,
// KPerWave, boolean policy flags..., WGAttrNumAccessEnum> -- confirm the meaning of each
// boolean against the WGDispCase definition in test_warp_gemm_attr_compose.hpp.
using WGDispatcherTypesList = ::testing::Types<
// clang-format off
// mixed fp8/bf8
WGDispCase<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, false>,
WGDispCase<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, true>,
WGDispCase<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, false>,
WGDispCase<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, true>,
WGDispCase<ck_tile::fp8_t, ck_tile::bf8_t, float, 16, 16, 128, false>,
WGDispCase<ck_tile::bf8_t, ck_tile::fp8_t, float, 16, 16, 128, false>,
WGDispCase<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 64, false>,
WGDispCase<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false>,
// The remaining cases repeat the 32x32x64 and 16x16x128 shapes with WGAttrNumAccessEnum::Quad.
WGDispCase<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 64, false, false, false, ck_tile::WGAttrNumAccessEnum::Quad>,
WGDispCase<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false, false, false, ck_tile::WGAttrNumAccessEnum::Quad>,
WGDispCase<ck_tile::fp8_t, ck_tile::bf8_t, float, 16, 16, 128, false, false, false, ck_tile::WGAttrNumAccessEnum::Quad>,
WGDispCase<ck_tile::bf8_t, ck_tile::fp8_t, float, 16, 16, 128, false, false, false, ck_tile::WGAttrNumAccessEnum::Quad>>;
// clang-format on
// Bind both typed fixtures to the case list; each test body only calls the fixture's RunTest().
TYPED_TEST_SUITE(WGCompileTimeTest, WGDispatcherTypesList);
TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList);
TYPED_TEST(WGCompileTimeTest, Instantiate) { this->RunTest(); }
TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG) { this->RunTest(); }