mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Refactor WarpGemm dispatcher and compose attributes
* Introduce WarpGemmCoreDispatcher to select MFMA/SMFMAC Implmentations by A/B/Acc types, (M,N,K) per wave, and structured sparsity flag * Add composed attribute layer (detail/warp_gemm_attribute_mfma_compose.hpp) to fold policies (IterateK, TransposeC, Swizzle, NumAccess) and derive A/B/C encodings Update warp_gemm_dispatcher.hpp to route through the new core dispatcher * Align SMFMAC attribute and wrapper: expose A/B/C distribution encodings and unify get_num_of_access; adjust WarpGemmSmfmacImpl integration Expand/adjust MFMA attribute impls to work with the new composition/dispatcher path * Add unit test for attribute composition (test_warp_gemm_attr_compose.cpp) and wire test CMake; minor example CMake tweak (example/ck_tile/03_gemm)
This commit is contained in:
@@ -13,6 +13,13 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a")
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
|
||||
|
||||
option(CK_TILE_EXAMPLE_ROUTE_DISPATCHER_TO_MAKE "Route ck_tile::WarpGemmDispatcher to MakeWarpGemm in examples" OFF)
|
||||
|
||||
# When routing dispatcher to MakeWarpGemm is enabled, add the macro to the shared options
|
||||
if(CK_TILE_EXAMPLE_ROUTE_DISPATCHER_TO_MAKE)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_ROUTE_WARP_GEMM_DISPATCHER_TO_MAKE)
|
||||
endif()
|
||||
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-unused-local-typedef)
|
||||
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-gnu-line-marker)
|
||||
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS --save-temps)
|
||||
|
||||
@@ -0,0 +1,709 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
|
||||
// smfmac (structured sparsity)
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
|
||||
namespace ck_tile {
|
||||
namespace detail {
|
||||
namespace wg_attr_compose {
|
||||
|
||||
// Basic config/state that policies transform at compile-time
|
||||
struct NoneTag;
|
||||
|
||||
enum class SwizzleKind
|
||||
{
|
||||
None,
|
||||
A,
|
||||
B
|
||||
};
|
||||
|
||||
template <typename Impl_,
|
||||
bool TransposeC_,
|
||||
SwizzleKind Swizzle_,
|
||||
index_t SFactor_,
|
||||
index_t KIter_,
|
||||
WGAttrNumAccessEnum NumAccess_>
|
||||
struct State
|
||||
{
|
||||
using Impl = remove_cvref_t<Impl_>;
|
||||
static constexpr bool TransposeC = TransposeC_;
|
||||
static constexpr SwizzleKind Swizzle = Swizzle_;
|
||||
static constexpr index_t SFactor = SFactor_;
|
||||
static constexpr index_t KIter = KIter_;
|
||||
static constexpr WGAttrNumAccessEnum NumAccess = NumAccess_;
|
||||
static constexpr index_t NumAccValue = static_cast<index_t>(NumAccess);
|
||||
};
|
||||
|
||||
// Default base state
|
||||
template <typename Impl, WGAttrNumAccessEnum NumAccess = WGAttrNumAccessEnum::Single>
|
||||
using BaseState = State<Impl, false, SwizzleKind::None, 1, 1, NumAccess>;
|
||||
|
||||
// Helpers to compute derived constants and types from State
|
||||
|
||||
// A/B data/vec types under transpose flag
|
||||
template <class S>
|
||||
struct ABTypes
|
||||
{
|
||||
using ADataType =
|
||||
std::conditional_t<S::TransposeC, typename S::Impl::BDataType, typename S::Impl::ADataType>;
|
||||
using BDataType =
|
||||
std::conditional_t<S::TransposeC, typename S::Impl::ADataType, typename S::Impl::BDataType>;
|
||||
|
||||
using AVecBase =
|
||||
std::conditional_t<S::TransposeC, typename S::Impl::BVecType, typename S::Impl::AVecType>;
|
||||
using BVecBase =
|
||||
std::conditional_t<S::TransposeC, typename S::Impl::AVecType, typename S::Impl::BVecType>;
|
||||
|
||||
using AVecType = ext_vector_t<ADataType, vector_traits<AVecBase>::vector_size * S::KIter>;
|
||||
using BVecType = ext_vector_t<BDataType, vector_traits<BVecBase>::vector_size * S::KIter>;
|
||||
};
|
||||
|
||||
// C types/shape (C does not change type with policies)
|
||||
template <class S>
|
||||
struct CTypes
|
||||
{
|
||||
using CDataType = typename S::Impl::CDataType;
|
||||
using CVecType = typename S::Impl::CVecType;
|
||||
};
|
||||
|
||||
// kM/kN/kK/kKPerThread under transpose and iterateK
|
||||
|
||||
template <class S>
|
||||
struct KShape
|
||||
{
|
||||
static constexpr index_t kM = S::TransposeC ? S::Impl::kN : S::Impl::kM;
|
||||
static constexpr index_t kN = S::TransposeC ? S::Impl::kM : S::Impl::kN;
|
||||
static constexpr index_t kK = S::Impl::kK * S::KIter;
|
||||
static constexpr index_t kKPerThread = S::Impl::kABKPerLane * S::KIter;
|
||||
static constexpr index_t kCMLane = S::Impl::kCMLane; // unchanged
|
||||
};
|
||||
|
||||
// Lane helpers
|
||||
|
||||
template <class S>
|
||||
struct Lanes
|
||||
{
|
||||
static constexpr index_t AMLane = S::TransposeC ? S::Impl::kBNLane : S::Impl::kAMLane;
|
||||
static constexpr index_t BNLane = S::TransposeC ? S::Impl::kAMLane : S::Impl::kBNLane;
|
||||
};
|
||||
|
||||
// Encoding builders centralizing existing logic, parameterized by S
|
||||
|
||||
// A encodings with NumAccess and IterateK folded, with multi-block and swizzle
|
||||
template <class S>
|
||||
CK_TILE_DEVICE static constexpr auto MakeAWarpDstrEncoding()
|
||||
{
|
||||
constexpr index_t NumAccValue = S::NumAccValue;
|
||||
|
||||
// Helper lambdas for the base A-encoding and the swizzled variant
|
||||
auto base_enc = []() {
|
||||
if constexpr(S::Impl::kAMBlock == 1 && S::Impl::kBNBlock == 1)
|
||||
{
|
||||
if constexpr(NumAccValue == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Lanes<S>::AMLane>,
|
||||
sequence<S::Impl::kABKLane, S::Impl::kABKPerLane * S::KIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(KShape<S>::kKPerThread % NumAccValue == 0,
|
||||
"kKPerThread must be divisible by NumAccess");
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Lanes<S>::AMLane>,
|
||||
sequence<NumAccValue,
|
||||
S::Impl::kABKLane,
|
||||
S::Impl::kABKPerLane * S::KIter / NumAccValue>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
}
|
||||
else if constexpr(S::Impl::kAMBlock == 1 && 1 < S::Impl::kBNBlock)
|
||||
{
|
||||
static_assert(NumAccValue == 1,
|
||||
"Multiple access is not supported when using multi-block");
|
||||
return tile_distribution_encoding<
|
||||
sequence<S::Impl::kBNBlock>,
|
||||
tuple<sequence<Lanes<S>::AMLane>,
|
||||
sequence<S::Impl::kABKLane, S::Impl::kABKPerLane * S::KIter>>,
|
||||
tuple<sequence<0, 2, 1>>,
|
||||
tuple<sequence<0, 0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else // 1 < AMBlock && BNBlock == 1
|
||||
{
|
||||
static_assert(NumAccValue == 1,
|
||||
"Multiple access is not supported when using multi-block");
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<S::Impl::kAMBlock, Lanes<S>::AMLane>,
|
||||
sequence<S::Impl::kABKLane, S::Impl::kABKPerLane * S::KIter>>,
|
||||
tuple<sequence<1, 2, 1>>,
|
||||
tuple<sequence<0, 0, 1>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
};
|
||||
|
||||
auto swizzled_enc = []() {
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<
|
||||
sequence<S::Impl::kAMLane / (S::Impl::kCMLane * S::SFactor * S::Impl::kCM1PerLane),
|
||||
S::Impl::kCMLane,
|
||||
S::SFactor,
|
||||
S::Impl::kCM1PerLane>,
|
||||
sequence<S::Impl::kABKLane, S::Impl::kABKPerLane * S::KIter>>,
|
||||
tuple<sequence<2, 1, 1, 1, 1>>,
|
||||
tuple<sequence<0, 0, 2, 1, 3>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
};
|
||||
|
||||
if constexpr(S::Swizzle == SwizzleKind::A)
|
||||
return swizzled_enc();
|
||||
else
|
||||
return base_enc();
|
||||
}
|
||||
|
||||
// B encodings with NumAccess and IterateK folded, with multi-block and swizzle
|
||||
template <class S>
|
||||
CK_TILE_DEVICE static constexpr auto MakeBWarpDstrEncoding()
|
||||
{
|
||||
constexpr index_t NumAccValue = S::NumAccValue;
|
||||
|
||||
auto base_enc = []() {
|
||||
if constexpr(S::Impl::kAMBlock == 1 && S::Impl::kBNBlock == 1)
|
||||
{
|
||||
if constexpr(NumAccValue == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Lanes<S>::BNLane>,
|
||||
sequence<S::Impl::kABKLane, S::Impl::kABKPerLane * S::KIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(KShape<S>::kKPerThread % NumAccValue == 0,
|
||||
"kKPerThread must be divisible by NumAccess");
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Lanes<S>::BNLane>,
|
||||
sequence<NumAccValue,
|
||||
S::Impl::kABKLane,
|
||||
S::Impl::kABKPerLane * S::KIter / NumAccValue>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
}
|
||||
else if constexpr(S::Impl::kAMBlock == 1 && 1 < S::Impl::kBNBlock)
|
||||
{
|
||||
static_assert(NumAccValue == 1,
|
||||
"Multiple access is not supported when using multi-block");
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<S::Impl::kBNBlock, Lanes<S>::BNLane>,
|
||||
sequence<S::Impl::kABKLane, S::Impl::kABKPerLane * S::KIter>>,
|
||||
tuple<sequence<1, 2, 1>>,
|
||||
tuple<sequence<0, 0, 1>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else // 1 < AMBlock && BNBlock == 1
|
||||
{
|
||||
static_assert(NumAccValue == 1,
|
||||
"Multiple access is not supported when using multi-block");
|
||||
return tile_distribution_encoding<
|
||||
sequence<S::Impl::kAMBlock>,
|
||||
tuple<sequence<Lanes<S>::BNLane>,
|
||||
sequence<S::Impl::kABKLane, S::Impl::kABKPerLane * S::KIter>>,
|
||||
tuple<sequence<0, 2, 1>>,
|
||||
tuple<sequence<0, 0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
};
|
||||
|
||||
auto swizzled_enc = []() {
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<
|
||||
sequence<S::Impl::kAMLane / (S::Impl::kCMLane * S::SFactor * S::Impl::kCM1PerLane),
|
||||
S::Impl::kCMLane,
|
||||
S::SFactor,
|
||||
S::Impl::kCM1PerLane>,
|
||||
sequence<S::Impl::kABKLane, S::Impl::kABKPerLane * S::KIter>>,
|
||||
tuple<sequence<2, 1, 1, 1, 1>>,
|
||||
tuple<sequence<0, 0, 2, 1, 3>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
};
|
||||
|
||||
if constexpr(S::Swizzle == SwizzleKind::B)
|
||||
return swizzled_enc();
|
||||
else
|
||||
return base_enc();
|
||||
}
|
||||
|
||||
// C distribution encoding with transpose and multi-block
|
||||
|
||||
template <class S>
|
||||
CK_TILE_DEVICE static constexpr auto MakeCWarpDstrEncoding()
|
||||
{
|
||||
constexpr bool HasSwizzle = (S::Swizzle == SwizzleKind::A) || (S::Swizzle == SwizzleKind::B);
|
||||
|
||||
auto make_m_splits = []() {
|
||||
if constexpr(HasSwizzle)
|
||||
{
|
||||
// Swizzled M splits
|
||||
return sequence<S::Impl::kCM0PerLane / S::SFactor,
|
||||
S::Impl::kCMLane,
|
||||
S::Impl::kCM1PerLane * S::SFactor>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return sequence<S::Impl::kCM0PerLane, S::Impl::kCMLane, S::Impl::kCM1PerLane>{};
|
||||
}
|
||||
};
|
||||
|
||||
if constexpr(S::Impl::kAMBlock == 1 && S::Impl::kBNBlock == 1)
|
||||
{
|
||||
if constexpr(!S::TransposeC)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<decltype(make_m_splits()), sequence<S::Impl::kCNLane>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 1>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
else // TransposeC
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<S::Impl::kCNLane>, decltype(make_m_splits())>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
}
|
||||
else if constexpr(S::Impl::kAMBlock == 1 && 1 < S::Impl::kBNBlock)
|
||||
{
|
||||
if constexpr(!S::TransposeC)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<decltype(make_m_splits()), sequence<S::Impl::kBNBlock * S::Impl::kCNLane>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 1>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<S::Impl::kBNBlock * S::Impl::kCNLane>, decltype(make_m_splits())>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
}
|
||||
else if constexpr(1 < S::Impl::kAMBlock && S::Impl::kBNBlock == 1)
|
||||
{
|
||||
if constexpr(!S::TransposeC)
|
||||
{
|
||||
return tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<S::Impl::kCM0PerLane,
|
||||
S::Impl::kAMBlock * S::Impl::kCMLane,
|
||||
S::Impl::kCM1PerLane>,
|
||||
sequence<S::Impl::kCNLane>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 1>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<S::Impl::kCNLane>,
|
||||
sequence<S::Impl::kCM0PerLane,
|
||||
S::Impl::kAMBlock * S::Impl::kCMLane,
|
||||
S::Impl::kCM1PerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Detect smfmac by Impl type: provide a small trait that is true for smfmac attribute impls
|
||||
template <class T>
|
||||
struct is_smfmac_impl : std::false_type
|
||||
{
|
||||
};
|
||||
template <typename AType_,
|
||||
typename BType_,
|
||||
typename AccType_,
|
||||
index_t MPerWave_,
|
||||
index_t NPerWave_,
|
||||
index_t KPerWave_,
|
||||
WGAttrCtlEnum C>
|
||||
struct is_smfmac_impl<
|
||||
WarpGemmAttributeSmfmacImpl<AType_, BType_, AccType_, MPerWave_, NPerWave_, KPerWave_, C>>
|
||||
: std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
// Final composed attribute
|
||||
// Primary (MFMA) and SMFMA-specialized ComposedAttribute
|
||||
template <class S, bool IsSmfmac = is_smfmac_impl<typename S::Impl>::value>
|
||||
struct ComposedAttribute
|
||||
{
|
||||
using Impl = typename S::Impl;
|
||||
using ATypes = ABTypes<S>;
|
||||
using CTypesT = CTypes<S>;
|
||||
|
||||
static constexpr index_t kM = KShape<S>::kM;
|
||||
static constexpr index_t kN = KShape<S>::kN;
|
||||
static constexpr index_t kK = KShape<S>::kK;
|
||||
static constexpr index_t kKPerThread = KShape<S>::kKPerThread;
|
||||
static constexpr index_t kCMLane = KShape<S>::kCMLane;
|
||||
|
||||
using ADataType = typename ATypes::ADataType;
|
||||
using BDataType = typename ATypes::BDataType;
|
||||
using CDataType = typename CTypesT::CDataType;
|
||||
|
||||
using AVecType = typename ATypes::AVecType;
|
||||
using BVecType = typename ATypes::BVecType;
|
||||
using CVecType = typename CTypesT::CVecType;
|
||||
|
||||
using AWarpDstrEncoding = decltype(MakeAWarpDstrEncoding<S>());
|
||||
using BWarpDstrEncoding = decltype(MakeBWarpDstrEncoding<S>());
|
||||
using CWarpDstrEncoding = decltype(MakeCWarpDstrEncoding<S>());
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return S::KIter; }
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
if constexpr(S::KIter == 1)
|
||||
{
|
||||
if constexpr(S::TransposeC)
|
||||
{
|
||||
// When TransposeC, composed A/B are swapped relative to Impl.
|
||||
// Pass b_vec as Impl::AVecType and a_vec as Impl::BVecType.
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const typename Impl::AVecType&>(b_vec),
|
||||
reinterpret_cast<const typename Impl::BVecType&>(a_vec),
|
||||
bool_constant<post_nop_>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const typename Impl::AVecType&>(a_vec),
|
||||
reinterpret_cast<const typename Impl::BVecType&>(b_vec),
|
||||
bool_constant<post_nop_>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, S::KIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, S::KIter>;
|
||||
|
||||
static_for<0, S::KIter, 1>{}([&](auto iKIter) {
|
||||
if constexpr(S::TransposeC)
|
||||
{
|
||||
// Swap mapping: b_vec -> Impl::AVecType, a_vec -> Impl::BVecType
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter],
|
||||
bool_constant<post_nop_>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter],
|
||||
bool_constant<post_nop_>{});
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
if constexpr(S::KIter == 1)
|
||||
{
|
||||
if constexpr(S::TransposeC)
|
||||
{
|
||||
// Swap mapping: b_vec -> Impl::AVecType, a_vec -> Impl::BVecType
|
||||
return Impl{}(reinterpret_cast<const typename Impl::AVecType&>(b_vec),
|
||||
reinterpret_cast<const typename Impl::BVecType&>(a_vec));
|
||||
}
|
||||
else
|
||||
{
|
||||
return Impl{}(reinterpret_cast<const typename Impl::AVecType&>(a_vec),
|
||||
reinterpret_cast<const typename Impl::BVecType&>(b_vec));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, S::KIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, S::KIter>;
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
CVecType c_vec;
|
||||
if constexpr(S::TransposeC)
|
||||
{
|
||||
// Swap mapping: b_vec -> Impl::AVecType, a_vec -> Impl::BVecType
|
||||
c_vec = Impl{}(reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::AVecType>()[I0],
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::BVecType>()[I0]);
|
||||
|
||||
static_for<1, S::KIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
c_vec = Impl{}(reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[I0],
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[I0]);
|
||||
|
||||
static_for<1, S::KIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
});
|
||||
}
|
||||
return c_vec;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// SMFMA specialization: forbid Transpose and Swizzle for now; KIter must be 1
|
||||
// TODO: enable swizzle, transpose for smfmac
|
||||
template <class S>
|
||||
struct ComposedAttribute<S, true>
|
||||
{
|
||||
using Impl = typename S::Impl;
|
||||
|
||||
static_assert(!S::TransposeC, "smfmac TransposeC is not supported in composed attributes");
|
||||
static_assert(S::Swizzle == SwizzleKind::None,
|
||||
"smfmac Swizzle is not supported in composed attributes");
|
||||
static_assert(S::KIter == 1, "smfmac IterateK is not supported (KIter must be 1)");
|
||||
|
||||
using ADataType = typename Impl::ADataType;
|
||||
using BDataType = typename Impl::BDataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType = typename Impl::AVecType;
|
||||
using BVecType = typename Impl::BVecType;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kM;
|
||||
static constexpr index_t kN = Impl::kN;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane;
|
||||
static constexpr index_t kCMLane = Impl::kCMLane;
|
||||
static constexpr index_t kCompressionRatio = Impl::CompressionRatio;
|
||||
|
||||
// Reuse the encodings defined by the smfmac attribute wrapper for consistency
|
||||
using AWarpDstrEncoding = typename WarpGemmAttributeSmfmac<Impl>::AWarpDstrEncoding;
|
||||
using BWarpDstrEncoding = typename WarpGemmAttributeSmfmac<Impl>::BWarpDstrEncoding;
|
||||
using CWarpDstrEncoding = typename WarpGemmAttributeSmfmac<Impl>::CWarpDstrEncoding;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
|
||||
|
||||
// c_vec += a_vec * b_vec[idx]
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
const int32_t& idx,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
Impl{}(c_vec, a_vec, b_vec, idx, bool_constant<post_nop_>{});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec[idx]
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
const int32_t& idx) const
|
||||
{
|
||||
CVecType c_vec{0};
|
||||
Impl{}(c_vec, a_vec, b_vec, idx);
|
||||
return c_vec;
|
||||
}
|
||||
};
|
||||
|
||||
// Policy wrappers produce a new State from an old one
|
||||
|
||||
template <index_t KIter>
|
||||
struct PolicyIterateK
|
||||
{
|
||||
template <class S>
|
||||
using apply =
|
||||
State<typename S::Impl, S::TransposeC, S::Swizzle, S::SFactor, KIter, S::NumAccess>;
|
||||
};
|
||||
|
||||
struct PolicyTransposeC
|
||||
{
|
||||
template <class S>
|
||||
using apply = State<typename S::Impl, true, S::Swizzle, S::SFactor, S::KIter, S::NumAccess>;
|
||||
};
|
||||
|
||||
template <index_t SFactor>
|
||||
struct PolicySwizzleA
|
||||
{
|
||||
template <class S>
|
||||
using apply =
|
||||
State<typename S::Impl, S::TransposeC, SwizzleKind::A, SFactor, S::KIter, S::NumAccess>;
|
||||
};
|
||||
|
||||
template <index_t SFactor>
|
||||
struct PolicySwizzleB
|
||||
{
|
||||
template <class S>
|
||||
using apply =
|
||||
State<typename S::Impl, S::TransposeC, SwizzleKind::B, SFactor, S::KIter, S::NumAccess>;
|
||||
};
|
||||
|
||||
} // namespace wg_attr_compose
|
||||
} // namespace detail
|
||||
} // namespace ck_tile
|
||||
|
||||
// High-level alias to construct a composed WarpGemm attribute via CoreDispatcher and policies
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_core_dispatcher.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Helper that uses core dispatcher to select MFMA vs smfmac and composes policies only
|
||||
template <bool TransposeC,
|
||||
bool SwizzleA,
|
||||
bool UseStructuredSparsity,
|
||||
typename AType,
|
||||
typename BType,
|
||||
typename AccType,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
WGAttrNumAccessEnum NumAccess>
|
||||
struct ComposePolicies
|
||||
{
|
||||
using CD = WarpGemmCoreDispatcher<AType,
|
||||
BType,
|
||||
AccType,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
KPerWave,
|
||||
UseStructuredSparsity>;
|
||||
using Impl = typename CD::Impl;
|
||||
|
||||
static_assert(Impl::kK > 0, "Invalid K dimension");
|
||||
static_assert(Impl::kK <= KPerWave, "KPerWave must smaller and equal than Impl::kK");
|
||||
|
||||
static constexpr index_t KIter = KPerWave / Impl::kK;
|
||||
|
||||
static constexpr bool IsSmfmac = detail::wg_attr_compose::is_smfmac_impl<Impl>::value;
|
||||
|
||||
// First, setup a default state as the base state.
|
||||
using S0 = detail::wg_attr_compose::BaseState<Impl, NumAccess>;
|
||||
|
||||
// The order of policies matters so we compose them with the order: Kiter->TransposeC->SwizzleA.
|
||||
using S1 = std::conditional_t<
|
||||
(KIter == 1),
|
||||
S0,
|
||||
typename detail::wg_attr_compose::PolicyIterateK<KIter>::template apply<S0>>;
|
||||
using S2 =
|
||||
std::conditional_t<TransposeC,
|
||||
typename detail::wg_attr_compose::PolicyTransposeC::template apply<S1>,
|
||||
S1>;
|
||||
// Match dispatcher behavior: when TransposeC is enabled, the swizzle applies to B
|
||||
// (i.e., SwizzleBTransposedCDistribution). Otherwise apply swizzle to A.
|
||||
using S3 = std::conditional_t<
|
||||
SwizzleA && !TransposeC,
|
||||
typename detail::wg_attr_compose::PolicySwizzleA<2>::template apply<S2>,
|
||||
std::conditional_t<SwizzleA && TransposeC,
|
||||
typename detail::wg_attr_compose::PolicySwizzleB<2>::template apply<S2>,
|
||||
S2>>;
|
||||
|
||||
// For SMFMA, use the dedicated WarpGemmSmfmacImpl wrapper with the SMFMA attribute.
|
||||
// For MFMA (default), use the regular WarpGemmImpl over the composed attribute.
|
||||
using type = std::conditional_t<
|
||||
IsSmfmac,
|
||||
ck_tile::WarpGemmSmfmacImpl<detail::wg_attr_compose::ComposedAttribute<S3>>,
|
||||
ck_tile::WarpGemmImpl<detail::wg_attr_compose::ComposedAttribute<S3>>>;
|
||||
};
|
||||
|
||||
// Wrapper struct to match usage as MakeWarpGemm<...>::Type
|
||||
template <bool TransposeC,
|
||||
bool SwizzleA,
|
||||
typename AType,
|
||||
typename BType,
|
||||
typename AccType,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
bool UseStructuredSparsity = false,
|
||||
WGAttrNumAccessEnum NumAccess = WGAttrNumAccessEnum::Single>
|
||||
struct MakeWarpGemm
|
||||
{
|
||||
using Type = typename ComposePolicies<TransposeC,
|
||||
SwizzleA,
|
||||
UseStructuredSparsity,
|
||||
AType,
|
||||
BType,
|
||||
AccType,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
KPerWave,
|
||||
NumAccess>::type;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -22,6 +22,16 @@ enum class WGAttrCtlEnum
|
||||
// raw_a_a_a = 3, // c-agpr, a-agpr, b-agpr
|
||||
};
|
||||
|
||||
// Primary template: generic MFMA warp-gemm attribute (to be specialized per supported shape)
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
index_t kM,
|
||||
index_t kN,
|
||||
index_t kK,
|
||||
WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl; // no definition, only specializations are provided
|
||||
|
||||
#define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_) \
|
||||
if constexpr(post_nop_) \
|
||||
{ \
|
||||
@@ -191,8 +201,8 @@ struct WarpGemmAttributeMfmaImplF32F32F32M32N32K2
|
||||
};
|
||||
|
||||
// V_MFMA_F32_16x16x32_BF16
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<bf16_t, bf16_t, float, 16, 16, 32, Ctrl_>
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = bf16_t;
|
||||
@@ -254,8 +264,8 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32
|
||||
}
|
||||
};
|
||||
// FP16
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<fp16_t, fp16_t, float, 32, 32, 8, Ctrl_>
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = fp16_t;
|
||||
@@ -317,8 +327,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
|
||||
}
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<fp16_t, fp16_t, float, 16, 16, 16, Ctrl_>
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = fp16_t;
|
||||
@@ -380,8 +390,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
|
||||
}
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImplF16F16F32M16N16K32
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<fp16_t, fp16_t, float, 16, 16, 32, Ctrl_>
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = fp16_t;
|
||||
@@ -443,8 +453,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K32
|
||||
}
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImplF16F16F32M4N64K4
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<fp16_t, fp16_t, float, 4, 64, 4, Ctrl_>
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = fp16_t;
|
||||
@@ -507,8 +517,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M4N64K4
|
||||
}
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImplF16F16F32M64N4K4
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<fp16_t, fp16_t, float, 64, 4, 4, Ctrl_>
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = fp16_t;
|
||||
@@ -572,8 +582,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M64N4K4
|
||||
};
|
||||
|
||||
// Bf16
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<bf16_t, bf16_t, float, 32, 32, 8, Ctrl_>
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = bf16_t;
|
||||
@@ -661,8 +671,8 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
|
||||
}
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<bf16_t, bf16_t, float, 16, 16, 16, Ctrl_>
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = bf16_t;
|
||||
@@ -749,8 +759,8 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
|
||||
}
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<bf16_t, bf16_t, float, 4, 64, 4, Ctrl_>
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = bf16_t;
|
||||
@@ -839,8 +849,8 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4
|
||||
}
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<bf16_t, bf16_t, float, 64, 4, 4, Ctrl_>
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = bf16_t;
|
||||
@@ -930,8 +940,8 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4
|
||||
};
|
||||
|
||||
// gfx950
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImplF16F16F32M32N32K16
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<fp16_t, fp16_t, float, 32, 32, 16, Ctrl_>
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = fp16_t;
|
||||
@@ -1044,8 +1054,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K16
|
||||
}
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<bf16_t, bf16_t, float, 32, 32, 16, Ctrl_>
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = bf16_t;
|
||||
@@ -1158,6 +1168,60 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16
|
||||
}
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------------
|
||||
// Backward-compatibility aliases (preserve original names)
|
||||
// ---------------------------------------------------------------------------------
|
||||
|
||||
// BF16
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32 =
|
||||
WarpGemmAttributeMfmaImpl<bf16_t, bf16_t, float, 16, 16, 32, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 =
|
||||
WarpGemmAttributeMfmaImpl<bf16_t, bf16_t, float, 32, 32, 8, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 =
|
||||
WarpGemmAttributeMfmaImpl<bf16_t, bf16_t, float, 16, 16, 16, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4 =
|
||||
WarpGemmAttributeMfmaImpl<bf16_t, bf16_t, float, 4, 64, 4, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4 =
|
||||
WarpGemmAttributeMfmaImpl<bf16_t, bf16_t, float, 64, 4, 4, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16 =
|
||||
WarpGemmAttributeMfmaImpl<bf16_t, bf16_t, float, 32, 32, 16, Ctrl_>;
|
||||
|
||||
// F16
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImplF16F16F32M32N32K8 =
|
||||
WarpGemmAttributeMfmaImpl<fp16_t, fp16_t, float, 32, 32, 8, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImplF16F16F32M16N16K16 =
|
||||
WarpGemmAttributeMfmaImpl<fp16_t, fp16_t, float, 16, 16, 16, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImplF16F16F32M16N16K32 =
|
||||
WarpGemmAttributeMfmaImpl<fp16_t, fp16_t, float, 16, 16, 32, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImplF16F16F32M4N64K4 =
|
||||
WarpGemmAttributeMfmaImpl<fp16_t, fp16_t, float, 4, 64, 4, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImplF16F16F32M64N4K4 =
|
||||
WarpGemmAttributeMfmaImpl<fp16_t, fp16_t, float, 64, 4, 4, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImplF16F16F32M32N32K16 =
|
||||
WarpGemmAttributeMfmaImpl<fp16_t, fp16_t, float, 32, 32, 16, Ctrl_>;
|
||||
|
||||
// FP8
|
||||
template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base
|
||||
@@ -1318,6 +1382,31 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base
|
||||
}
|
||||
};
|
||||
|
||||
// Map FP8/BF8 shapes to the primary template via specializations (reuse base impl)
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<fp8_t, fp8_t, float, 16, 16, 32, Ctrl_>
|
||||
: WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base<fp8_t, fp8_t, Ctrl_>
|
||||
{
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<fp8_t, bf8_t, float, 16, 16, 32, Ctrl_>
|
||||
: WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base<fp8_t, bf8_t, Ctrl_>
|
||||
{
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<bf8_t, fp8_t, float, 16, 16, 32, Ctrl_>
|
||||
: WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base<bf8_t, fp8_t, Ctrl_>
|
||||
{
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<bf8_t, bf8_t, float, 16, 16, 32, Ctrl_>
|
||||
: WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base<bf8_t, bf8_t, Ctrl_>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
|
||||
{
|
||||
@@ -1501,12 +1590,38 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
|
||||
}
|
||||
};
|
||||
|
||||
// Map FP8/BF8 32x32x16 shapes to the primary template via specializations (reuse base impl)
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<fp8_t, fp8_t, float, 32, 32, 16, Ctrl_>
|
||||
: WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t, Ctrl_>
|
||||
{
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<fp8_t, bf8_t, float, 32, 32, 16, Ctrl_>
|
||||
: WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, bf8_t, Ctrl_>
|
||||
{
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<bf8_t, fp8_t, float, 32, 32, 16, Ctrl_>
|
||||
: WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, fp8_t, Ctrl_>
|
||||
{
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<bf8_t, bf8_t, float, 32, 32, 16, Ctrl_>
|
||||
: WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, bf8_t, Ctrl_>
|
||||
{
|
||||
};
|
||||
|
||||
// Back-compat aliases now point to primary-template specializations
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t, Ctrl_>;
|
||||
WarpGemmAttributeMfmaImpl<fp8_t, fp8_t, float, 32, 32, 16, Ctrl_>;
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base<fp8_t, fp8_t, Ctrl_>;
|
||||
WarpGemmAttributeMfmaImpl<fp8_t, fp8_t, float, 16, 16, 32, Ctrl_>;
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, bf8_t, Ctrl_>;
|
||||
@@ -1516,7 +1631,7 @@ using WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_bf8 =
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base<bf8_t, bf8_t, Ctrl_>;
|
||||
WarpGemmAttributeMfmaImpl<bf8_t, bf8_t, float, 16, 16, 32, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8 =
|
||||
@@ -1555,6 +1670,130 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
//__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a,
|
||||
// opsel, scale_b)
|
||||
#if defined(__gfx950__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
a_vec, b_vec, c_vec, 0, 0, 0, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
a_vec, b_vec, c_vec, 0, 1, 0, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
a_vec, b_vec, c_vec, 1, 0, 0, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
a_vec, b_vec, c_vec, 1, 1, 0, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
a_vec, b_vec, CVecType{0.f}, 0, 0, 0, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
a_vec, b_vec, CVecType{0.f}, 0, 1, 0, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
a_vec, b_vec, CVecType{0.f}, 1, 0, 0, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
a_vec, b_vec, CVecType{0.f}, 1, 1, 0, 0, 0, 0));
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// Map FP8/BF8 16x16x128 scale shapes to the primary template via specializations
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<fp8_t, fp8_t, float, 16, 16, 128, Ctrl_>
|
||||
: WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base<fp8_t, fp8_t, Ctrl_>
|
||||
{
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<fp8_t, bf8_t, float, 16, 16, 128, Ctrl_>
|
||||
: WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base<fp8_t, bf8_t, Ctrl_>
|
||||
{
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<bf8_t, fp8_t, float, 16, 16, 128, Ctrl_>
|
||||
: WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base<bf8_t, fp8_t, Ctrl_>
|
||||
{
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<bf8_t, bf8_t, float, 16, 16, 128, Ctrl_>
|
||||
: WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base<bf8_t, bf8_t, Ctrl_>
|
||||
{
|
||||
};
|
||||
|
||||
// Back-compat aliases
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_fp8 =
|
||||
WarpGemmAttributeMfmaImpl<fp8_t, fp8_t, float, 16, 16, 128, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_bf8 =
|
||||
WarpGemmAttributeMfmaImpl<fp8_t, bf8_t, float, 16, 16, 128, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_fp8 =
|
||||
WarpGemmAttributeMfmaImpl<bf8_t, fp8_t, float, 16, 16, 128, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base<bf8_t, bf8_t, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = pk_fp4_t;
|
||||
using BDataType = pk_fp4_t;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = ext_vector_t<ADataType, 16>;
|
||||
using BVecType = ext_vector_t<BDataType, 16>;
|
||||
using CVecType = ext_vector_t<CDataType, 4>;
|
||||
|
||||
static constexpr index_t kM = 16;
|
||||
static constexpr index_t kN = 16;
|
||||
static constexpr index_t kK = 128;
|
||||
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 4;
|
||||
static constexpr index_t kABKPerLane = 32;
|
||||
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <index_t opselA, index_t opselB, bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
@@ -1716,23 +1955,49 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base
|
||||
}
|
||||
};
|
||||
|
||||
// Map FP8/BF8 32x32x64 scale shapes to the primary template via specializations
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<fp8_t, fp8_t, float, 32, 32, 64, Ctrl_>
|
||||
: WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base<fp8_t, fp8_t, Ctrl_>
|
||||
{
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<fp8_t, bf8_t, float, 32, 32, 64, Ctrl_>
|
||||
: WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base<fp8_t, bf8_t, Ctrl_>
|
||||
{
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<bf8_t, fp8_t, float, 32, 32, 64, Ctrl_>
|
||||
: WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base<bf8_t, fp8_t, Ctrl_>
|
||||
{
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<bf8_t, bf8_t, float, 32, 32, 64, Ctrl_>
|
||||
: WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base<bf8_t, bf8_t, Ctrl_>
|
||||
{
|
||||
};
|
||||
|
||||
// Back-compat aliases
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base<fp8_t, fp8_t, Ctrl_>;
|
||||
WarpGemmAttributeMfmaImpl<fp8_t, fp8_t, float, 32, 32, 64, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base<fp8_t, bf8_t, Ctrl_>;
|
||||
WarpGemmAttributeMfmaImpl<fp8_t, bf8_t, float, 32, 32, 64, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base<bf8_t, fp8_t, Ctrl_>;
|
||||
WarpGemmAttributeMfmaImpl<bf8_t, fp8_t, float, 32, 32, 64, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base<bf8_t, bf8_t, Ctrl_>;
|
||||
WarpGemmAttributeMfmaImpl<bf8_t, bf8_t, float, 32, 32, 64, Ctrl_>;
|
||||
|
||||
// int8
|
||||
// int8: map shapes to the primary template via specializations (reuse the concrete impl bodies)
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
|
||||
{
|
||||
@@ -1980,6 +2245,31 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x32_i8
|
||||
}
|
||||
};
|
||||
|
||||
// Primary-template specializations delegating to the bodies above
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<int8_t, int8_t, int32_t, 32, 32, 16, Ctrl_>
|
||||
: WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<Ctrl_>
|
||||
{
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<int8_t, int8_t, int32_t, 16, 16, 32, Ctrl_>
|
||||
: WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<Ctrl_>
|
||||
{
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<int8_t, int8_t, int32_t, 16, 16, 64, Ctrl_>
|
||||
: WarpGemmAttributeMfmaImpl_i32_16x16x64_i8<Ctrl_>
|
||||
{
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeMfmaImpl<int8_t, int8_t, int32_t, 32, 32, 32, Ctrl_>
|
||||
: WarpGemmAttributeMfmaImpl_i32_32x32x32_i8<Ctrl_>
|
||||
{
|
||||
};
|
||||
|
||||
#undef DISPATCH_MFMA_
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -81,5 +81,12 @@ struct WarpGemmAttributeSmfmac
|
||||
{
|
||||
Impl{}(c_vec, a_vec, b_vec, idx, bool_constant<post_nop_>{});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
const int32_t& idx) const
|
||||
{
|
||||
return Impl{}(a_vec, b_vec, idx);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -7,10 +7,19 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// fp16 2:4 structured sparsity
|
||||
// Primary template: generic Smfmac warp-gemm attribute (to be specialized per supported shape)
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
index_t kM,
|
||||
index_t kN,
|
||||
index_t kK,
|
||||
WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeSmfmacImpl; // no definition, only specializations are provided
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeSmfmacImplF16F16F32M32N32K16
|
||||
// fp16 2:4 structured sparsity
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeSmfmacImpl<fp16_t, fp16_t, float, 32, 32, 16, Ctrl_>
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = fp16_t;
|
||||
@@ -56,12 +65,28 @@ struct WarpGemmAttributeSmfmacImplF16F16F32M32N32K16
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = idx;
|
||||
#endif
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec[idx]
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
const int32_t& idx) const
|
||||
{
|
||||
#if defined(__gfx94_) or defined(__gfx95_)
|
||||
return bit_cast<CVecType>(
|
||||
__builtin_amdgcn_smfmac_f32_32x32x16_f16(a_vec, b_vec, fp32x4_t{0.f}, idx, 0, 0));
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = idx;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeSmfmacImplF16F16F32M16N16K32
|
||||
template <WGAttrCtlEnum Ctrl_>
|
||||
struct WarpGemmAttributeSmfmacImpl<fp16_t, fp16_t, float, 16, 16, 32, Ctrl_>
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = fp16_t;
|
||||
@@ -107,8 +132,33 @@ struct WarpGemmAttributeSmfmacImplF16F16F32M16N16K32
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = idx;
|
||||
#endif
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec[idx]
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
const int32_t& idx) const
|
||||
{
|
||||
#if defined(__gfx94_) or defined(__gfx95_)
|
||||
return bit_cast<CVecType>(
|
||||
__builtin_amdgcn_smfmac_f32_16x16x32_f16(a_vec, b_vec, fp32x4_t{0.f}, idx, 0, 0));
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = idx;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// Back-compat aliases
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeSmfmacImplF16F16F32M16N16K32 =
|
||||
WarpGemmAttributeSmfmacImpl<fp16_t, fp16_t, float, 16, 16, 32, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeSmfmacImplF16F16F32M32N32K16 =
|
||||
WarpGemmAttributeSmfmacImpl<fp16_t, fp16_t, float, 32, 32, 16, Ctrl_>;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
271
include/ck_tile/ops/gemm/warp/warp_gemm_core_dispatcher.hpp
Normal file
271
include/ck_tile/ops/gemm/warp/warp_gemm_core_dispatcher.hpp
Normal file
@@ -0,0 +1,271 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
|
||||
// Add smfmac attribute impl for structured sparsity support
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// WarpGemmCoreDispatcher: choose the underlying MFMA/SMFMAC attribute Impl
|
||||
// based on A/B/Acc types and (M,N,K) per wave.
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename AccType,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
bool UseStructuredSparsity = false>
|
||||
struct WarpGemmCoreDispatcher;
|
||||
|
||||
// Generic specialization for MFMA
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename AccType,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave>
|
||||
struct WarpGemmCoreDispatcher<AType, BType, AccType, MPerWave, NPerWave, KPerWave, false>
|
||||
{
|
||||
using Impl = WarpGemmAttributeMfmaImpl<AType,
|
||||
BType,
|
||||
AccType,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
KPerWave,
|
||||
WGAttrCtlEnum::Default_>;
|
||||
};
|
||||
|
||||
// Generic specialization for SMFMAC
|
||||
// TODO: we also need to support smfmac for FP8/BF8 and I8 format
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename AccType,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave>
|
||||
struct WarpGemmCoreDispatcher<AType, BType, AccType, MPerWave, NPerWave, KPerWave, true>
|
||||
{
|
||||
using Impl = WarpGemmAttributeSmfmacImpl<AType,
|
||||
BType,
|
||||
AccType,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
KPerWave,
|
||||
WGAttrCtlEnum::Default_>;
|
||||
};
|
||||
|
||||
// Specialization for special cases
|
||||
template <>
|
||||
struct WarpGemmCoreDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false>
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::half_t,
|
||||
ck_tile::half_t,
|
||||
float,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
WGAttrCtlEnum::Default_>;
|
||||
#else
|
||||
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::half_t,
|
||||
ck_tile::half_t,
|
||||
float,
|
||||
32,
|
||||
32,
|
||||
8,
|
||||
WGAttrCtlEnum::Default_>;
|
||||
#endif
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WarpGemmCoreDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false>
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::half_t,
|
||||
ck_tile::half_t,
|
||||
float,
|
||||
16,
|
||||
16,
|
||||
32,
|
||||
WGAttrCtlEnum::Default_>;
|
||||
#else
|
||||
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::half_t,
|
||||
ck_tile::half_t,
|
||||
float,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
WGAttrCtlEnum::Default_>;
|
||||
#endif
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WarpGemmCoreDispatcher<ck_tile::half_t, ck_tile::half_t, float, 4, 64, 16, false>
|
||||
{
|
||||
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::half_t,
|
||||
ck_tile::half_t,
|
||||
float,
|
||||
4,
|
||||
64,
|
||||
4,
|
||||
WGAttrCtlEnum::Default_>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WarpGemmCoreDispatcher<ck_tile::half_t, ck_tile::half_t, float, 64, 4, 16, false>
|
||||
{
|
||||
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::half_t,
|
||||
ck_tile::half_t,
|
||||
float,
|
||||
64,
|
||||
4,
|
||||
4,
|
||||
WGAttrCtlEnum::Default_>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WarpGemmCoreDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 32, false>
|
||||
{
|
||||
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
float,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
WGAttrCtlEnum::Default_>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WarpGemmCoreDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 32, false>
|
||||
{
|
||||
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
float,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
WGAttrCtlEnum::Default_>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WarpGemmCoreDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 64, false>
|
||||
{
|
||||
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
float,
|
||||
16,
|
||||
16,
|
||||
32,
|
||||
WGAttrCtlEnum::Default_>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WarpGemmCoreDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 64, false>
|
||||
{
|
||||
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
float,
|
||||
16,
|
||||
16,
|
||||
32,
|
||||
WGAttrCtlEnum::Default_>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WarpGemmCoreDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false>
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::bf16_t,
|
||||
ck_tile::bf16_t,
|
||||
float,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
WGAttrCtlEnum::Default_>;
|
||||
#else
|
||||
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::bf16_t,
|
||||
ck_tile::bf16_t,
|
||||
float,
|
||||
32,
|
||||
32,
|
||||
8,
|
||||
WGAttrCtlEnum::Default_>;
|
||||
#endif
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WarpGemmCoreDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false>
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::bf16_t,
|
||||
ck_tile::bf16_t,
|
||||
float,
|
||||
16,
|
||||
16,
|
||||
32,
|
||||
WGAttrCtlEnum::Default_>;
|
||||
#else
|
||||
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::bf16_t,
|
||||
ck_tile::bf16_t,
|
||||
float,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
WGAttrCtlEnum::Default_>;
|
||||
#endif
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WarpGemmCoreDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 4, 64, 16, false>
|
||||
{
|
||||
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::bf16_t,
|
||||
ck_tile::bf16_t,
|
||||
float,
|
||||
4,
|
||||
64,
|
||||
4,
|
||||
WGAttrCtlEnum::Default_>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WarpGemmCoreDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 64, 4, 16, false>
|
||||
{
|
||||
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::bf16_t,
|
||||
ck_tile::bf16_t,
|
||||
float,
|
||||
64,
|
||||
4,
|
||||
4,
|
||||
WGAttrCtlEnum::Default_>;
|
||||
};
|
||||
|
||||
// Iterate-K variants built from the direct impls
|
||||
template <>
|
||||
struct WarpGemmCoreDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 32, false>
|
||||
{
|
||||
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::int8_t,
|
||||
ck_tile::int8_t,
|
||||
ck_tile::int32_t,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
WGAttrCtlEnum::Default_>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WarpGemmCoreDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 64, false>
|
||||
{
|
||||
using Impl = WarpGemmAttributeMfmaImpl<ck_tile::int8_t,
|
||||
ck_tile::int8_t,
|
||||
ck_tile::int32_t,
|
||||
16,
|
||||
16,
|
||||
32,
|
||||
WGAttrCtlEnum::Default_>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -6,7 +6,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp"
|
||||
|
||||
#include "ck_tile/ops/gemm/warp/detail/warp_gemm_attribute_mfma_compose.hpp"
|
||||
namespace ck_tile {
|
||||
|
||||
namespace impl {
|
||||
@@ -163,16 +163,30 @@ template <typename AType,
|
||||
bool SwizzleA = false,
|
||||
bool UseStructuredSparsity = false,
|
||||
WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmDispatcher = typename impl::warp_gemm_dispatcher::Dispatcher< //
|
||||
AType,
|
||||
BType,
|
||||
AccType,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
KPerWave,
|
||||
TransposeC,
|
||||
SwizzleA,
|
||||
UseStructuredSparsity,
|
||||
AttrNumAccess>::Type;
|
||||
using WarpGemmDispatcher =
|
||||
#if defined(CK_TILE_ROUTE_WARP_GEMM_DISPATCHER_TO_MAKE)
|
||||
typename MakeWarpGemm<TransposeC,
|
||||
SwizzleA,
|
||||
AType,
|
||||
BType,
|
||||
AccType,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
KPerWave,
|
||||
UseStructuredSparsity,
|
||||
AttrNumAccess>::Type;
|
||||
#else
|
||||
using WarpGemmDispatcher = typename impl::warp_gemm_dispatcher::Dispatcher< //
|
||||
AType,
|
||||
BType,
|
||||
AccType,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
KPerWave,
|
||||
TransposeC,
|
||||
SwizzleA,
|
||||
UseStructuredSparsity,
|
||||
AttrNumAccess>::Type;
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -105,6 +105,38 @@ struct WarpGemmSmfmacImpl
|
||||
|
||||
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
|
||||
}
|
||||
|
||||
template <typename ATensor, typename BTensor>
|
||||
CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const
|
||||
{
|
||||
using CTensor = CWarpTensor;
|
||||
static_assert(detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
|
||||
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
|
||||
constexpr auto CompressionRatio = WarpGemmAttribute::kCompressionRatio;
|
||||
|
||||
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
|
||||
using AVecCompressed =
|
||||
ext_vector_t<ADataType, ATensor::get_thread_buffer_size() / CompressionRatio>;
|
||||
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
|
||||
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
|
||||
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
|
||||
|
||||
const int32_t idx = compress_a(a_vec);
|
||||
|
||||
// @TODO can we simply set a_vec_pruned to a_vec[0:3]?
|
||||
const AVecCompressed a_vec_pruned = {a_vec[0], a_vec[1], a_vec[2], a_vec[3]};
|
||||
|
||||
// c_vec = a_vec * b_vec[idx]
|
||||
auto c_vec = WarpGemmAttribute{}(a_vec_pruned, b_vec, idx);
|
||||
|
||||
CTensor c;
|
||||
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
|
||||
return c;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,6 +6,7 @@ set(EXAMPLE_GEMM_COMPILE_OPTIONS)
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
option(CK_TILE_TEST_ROUTE_DISPATCHER_TO_MAKE "Route ck_tile::WarpGemmDispatcher to MakeWarpGemm in tests" OFF)
|
||||
set(EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS)
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
@@ -16,6 +17,14 @@ list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS
|
||||
)
|
||||
set(EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
|
||||
|
||||
# Build routed variants of the common options when the switch is ON
|
||||
set(EXAMPLE_GEMM_COMPILE_OPTIONS_WITHOUT_ROUTED ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
if(CK_TILE_TEST_ROUTE_DISPATCHER_TO_MAKE)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_ROUTE_WARP_GEMM_DISPATCHER_TO_MAKE)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS -DCK_TILE_ROUTE_WARP_GEMM_DISPATCHER_TO_MAKE)
|
||||
endif()
|
||||
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12")
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95")
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_mem test_gemm_pipeline_mem.cpp)
|
||||
@@ -59,3 +68,38 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12")
|
||||
else()
|
||||
message(DEBUG "Skipping ck_tile_gemm tests for current target test_ck_tile_gemm_pipeline")
|
||||
endif()
|
||||
|
||||
|
||||
# MFMA (Attribute Compose) - STRICTLY CDNA (gfx90a, gfx94x, gfx950)
|
||||
if(GPU_TARGETS MATCHES "gfx90a|gfx94|gfx95")
|
||||
add_gtest_executable(test_ck_tile_gemm_attr_compose_int8_mfma test_warp_gemm_attr_compose_int8_mfma.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_attr_compose_int8_mfma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS_WITHOUT_ROUTED})
|
||||
endif()
|
||||
|
||||
# MFMA (Attribute Compose) - STRICTLY CDNA3+ (gfx94x, gfx950)
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95")
|
||||
add_gtest_executable(test_ck_tile_gemm_attr_compose_fp8_mfma test_warp_gemm_attr_compose_fp8_mfma.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_attr_compose_fp8_mfma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS_WITHOUT_ROUTED})
|
||||
|
||||
add_gtest_executable(test_ck_tile_gemm_attr_compose_bf8_mfma test_warp_gemm_attr_compose_bf8_mfma.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_attr_compose_bf8_mfma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS_WITHOUT_ROUTED})
|
||||
|
||||
add_gtest_executable(test_ck_tile_gemm_attr_compose_mixed_fp8_bf8_mfma test_warp_gemm_attr_compose_mixed_fp8_bf8_mfma.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_attr_compose_mixed_fp8_bf8_mfma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS_WITHOUT_ROUTED})
|
||||
endif()
|
||||
|
||||
# MFMA (Attribute Compose) - STRICTLY CDNA (gfx90a, gfx94x, gfx950)
|
||||
if(GPU_TARGETS MATCHES "gfx90a|gfx94|gfx95")
|
||||
add_gtest_executable(test_ck_tile_gemm_attr_compose_fp16_mfma test_warp_gemm_attr_compose_fp16_mfma.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_attr_compose_fp16_mfma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS_WITHOUT_ROUTED})
|
||||
|
||||
add_gtest_executable(test_ck_tile_gemm_attr_compose_bf16_mfma test_warp_gemm_attr_compose_bf16_mfma.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_attr_compose_bf16_mfma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS_WITHOUT_ROUTED})
|
||||
endif()
|
||||
|
||||
# --- SMFMAC Tests ---
|
||||
# SMFMAC (Attribute Compose) - STRICTLY CDNA3+ (gfx94x, gfx950)
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95")
|
||||
add_gtest_executable(test_ck_tile_gemm_attr_compose_fp16_smfmac test_warp_gemm_attr_compose_fp16_smfmac.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_attr_compose_fp16_smfmac PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS_WITHOUT_ROUTED})
|
||||
endif()
|
||||
301
test/ck_tile/gemm/test_warp_gemm_attr_compose.hpp
Normal file
301
test/ck_tile/gemm/test_warp_gemm_attr_compose.hpp
Normal file
@@ -0,0 +1,301 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef CK_TILE_TEST_WARP_GEMM_ATTR_COMPOSE_HPP
|
||||
#define CK_TILE_TEST_WARP_GEMM_ATTR_COMPOSE_HPP
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <algorithm>
|
||||
#include "ck_tile/ops/gemm/warp/detail/warp_gemm_attribute_mfma_compose.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
|
||||
// For real kernel run and verification
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
|
||||
// ---------------------- Typed tests for Dispatcher coverage ----------------------
|
||||
template <typename A,
|
||||
typename B,
|
||||
typename Acc,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
bool TransposeC,
|
||||
bool SwizzleA = false,
|
||||
bool UseStructuredSparsity = false,
|
||||
ck_tile::WGAttrNumAccessEnum NA = ck_tile::WGAttrNumAccessEnum::Single>
|
||||
struct WGDispCase
|
||||
{
|
||||
using AType = A;
|
||||
using BType = B;
|
||||
using AccType = Acc;
|
||||
static constexpr ck_tile::index_t MPerWave = M;
|
||||
static constexpr ck_tile::index_t NPerWave = N;
|
||||
static constexpr ck_tile::index_t KPerWave = K;
|
||||
static constexpr bool kTransposeC = TransposeC;
|
||||
static constexpr bool kSwizzleA = SwizzleA;
|
||||
static constexpr bool kUSS = UseStructuredSparsity;
|
||||
static constexpr ck_tile::WGAttrNumAccessEnum kNA = NA;
|
||||
};
|
||||
|
||||
template <typename Case>
|
||||
class WGCompileTimeTest : public ::testing::Test
|
||||
{
|
||||
public:
|
||||
void RunTest()
|
||||
{
|
||||
// Dispatcher-selected WarpGemm
|
||||
using Disp = typename ck_tile::WarpGemmDispatcher<typename Case::AType,
|
||||
typename Case::BType,
|
||||
typename Case::AccType,
|
||||
Case::MPerWave,
|
||||
Case::NPerWave,
|
||||
Case::KPerWave,
|
||||
Case::kTransposeC,
|
||||
Case::kSwizzleA,
|
||||
Case::kUSS,
|
||||
Case::kNA>;
|
||||
|
||||
// Factory-selected WarpGemm (MakeWarpGemm)
|
||||
using Make = typename ck_tile::MakeWarpGemm<Case::kTransposeC,
|
||||
Case::kSwizzleA,
|
||||
typename Case::AType,
|
||||
typename Case::BType,
|
||||
typename Case::AccType,
|
||||
Case::MPerWave,
|
||||
Case::NPerWave,
|
||||
Case::KPerWave,
|
||||
Case::kUSS,
|
||||
Case::kNA>::Type;
|
||||
|
||||
// 1) Scalar compile-time constants must match
|
||||
static_assert(Disp::kM == Make::kM, "kM differs between Dispatcher and MakeWarpGemm");
|
||||
static_assert(Disp::kN == Make::kN, "kN differs between Dispatcher and MakeWarpGemm");
|
||||
static_assert(Disp::kK == Make::kK, "kK differs between Dispatcher and MakeWarpGemm");
|
||||
static_assert(Disp::kKPerThread == Make::kKPerThread,
|
||||
"kKPerThread differs between Dispatcher and MakeWarpGemm");
|
||||
static_assert(Disp::get_num_of_access() == Make::get_num_of_access(),
|
||||
"get_num_of_access() differs between Dispatcher and MakeWarpGemm");
|
||||
|
||||
// 2) Data types must match
|
||||
static_assert(std::is_same_v<typename Disp::ADataType, typename Make::ADataType>,
|
||||
"ADataType differs between Dispatcher and MakeWarpGemm");
|
||||
static_assert(std::is_same_v<typename Disp::BDataType, typename Make::BDataType>,
|
||||
"BDataType differs between Dispatcher and MakeWarpGemm");
|
||||
static_assert(std::is_same_v<typename Disp::CDataType, typename Make::CDataType>,
|
||||
"CDataType differs between Dispatcher and MakeWarpGemm");
|
||||
|
||||
// 3) Distribution encodings must match (ensures identical warp tiling/layout)
|
||||
static_assert(
|
||||
std::is_same_v<typename Disp::AWarpDstrEncoding, typename Make::AWarpDstrEncoding>,
|
||||
"AWarpDstrEncoding differs between Dispatcher and MakeWarpGemm");
|
||||
static_assert(
|
||||
std::is_same_v<typename Disp::BWarpDstrEncoding, typename Make::BWarpDstrEncoding>,
|
||||
"BWarpDstrEncoding differs between Dispatcher and MakeWarpGemm");
|
||||
static_assert(
|
||||
std::is_same_v<typename Disp::CWarpDstrEncoding, typename Make::CWarpDstrEncoding>,
|
||||
"CWarpDstrEncoding differs between Dispatcher and MakeWarpGemm");
|
||||
|
||||
// 4) Final tensor types must match (encodes DataType + Distribution)
|
||||
static_assert(std::is_same_v<typename Disp::AWarpTensor, typename Make::AWarpTensor>,
|
||||
"AWarpTensor differs between Dispatcher and MakeWarpGemm");
|
||||
static_assert(std::is_same_v<typename Disp::BWarpTensor, typename Make::BWarpTensor>,
|
||||
"BWarpTensor differs between Dispatcher and MakeWarpGemm");
|
||||
static_assert(std::is_same_v<typename Disp::CWarpTensor, typename Make::CWarpTensor>,
|
||||
"CWarpTensor differs between Dispatcher and MakeWarpGemm");
|
||||
|
||||
SUCCEED();
|
||||
}
|
||||
};
|
||||
|
||||
// ---------------------- Runtime tests: Compare Dispatcher (MFMA and SMFMA only) vs MakeWarpGemm vs
|
||||
// CPU ----------------------
|
||||
// ---------------------- Runtime operator() behavior tests on GPU ----------------------
|
||||
template <bool UseMakeWarpGemm,
|
||||
typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
bool TransposeC,
|
||||
bool SwizzleA,
|
||||
bool UseStructuredSparsity,
|
||||
ck_tile::WGAttrNumAccessEnum NumAccess>
|
||||
struct WarpGemmKernel
|
||||
{
|
||||
static constexpr int kBlockSize = 64;
|
||||
__device__ void operator()(const AType* A, const BType* B, CType* C) const
|
||||
{
|
||||
using WarpGemm = std::conditional_t<UseMakeWarpGemm,
|
||||
typename ck_tile::MakeWarpGemm<TransposeC,
|
||||
SwizzleA,
|
||||
AType,
|
||||
BType,
|
||||
CType,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
UseStructuredSparsity,
|
||||
NumAccess>::Type,
|
||||
ck_tile::WarpGemmDispatcher<AType,
|
||||
BType,
|
||||
CType,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
TransposeC,
|
||||
SwizzleA,
|
||||
UseStructuredSparsity,
|
||||
NumAccess>>;
|
||||
|
||||
// A: [M,K] row-major (packed)
|
||||
const auto a_view =
|
||||
ck_tile::make_naive_tensor_view_packed<ck_tile::address_space_enum::global>(
|
||||
const_cast<AType*>(A), ck_tile::make_tuple(M, K));
|
||||
// B: expose as logical [N,K] with strides (1, N) over the original row-major [K,N] buffer
|
||||
const auto b_view = ck_tile::make_naive_tensor_view<ck_tile::address_space_enum::global>(
|
||||
const_cast<BType*>(B), ck_tile::make_tuple(N, K), ck_tile::make_tuple(1, N));
|
||||
// C: [M,N] row-major (packed)
|
||||
const auto c_view =
|
||||
ck_tile::make_naive_tensor_view_packed<ck_tile::address_space_enum::global>(
|
||||
const_cast<CType*>(C), ck_tile::make_tuple(M, N));
|
||||
|
||||
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
constexpr auto a_len = AWarpTensor::get_tile_distribution().get_lengths();
|
||||
constexpr auto b_len = BWarpTensor::get_tile_distribution().get_lengths();
|
||||
constexpr auto c_len = CWarpTensor::get_tile_distribution().get_lengths();
|
||||
|
||||
auto a_win = ck_tile::make_tile_window(
|
||||
a_view, a_len, ck_tile::make_multi_index(0, 0), AWarpTensor::get_tile_distribution());
|
||||
auto b_win = ck_tile::make_tile_window(
|
||||
b_view, b_len, ck_tile::make_multi_index(0, 0), BWarpTensor::get_tile_distribution());
|
||||
auto c_win = ck_tile::make_tile_window(
|
||||
c_view, c_len, ck_tile::make_multi_index(0, 0), CWarpTensor::get_tile_distribution());
|
||||
|
||||
AWarpTensor a_tile;
|
||||
BWarpTensor b_tile;
|
||||
ck_tile::load_tile(a_tile, a_win);
|
||||
ck_tile::load_tile(b_tile, b_win);
|
||||
|
||||
CWarpTensor c_tile;
|
||||
c_tile = WarpGemm{}(a_tile, b_tile);
|
||||
ck_tile::store_tile(c_win, c_tile);
|
||||
}
|
||||
};
|
||||
|
||||
// ---------------------- New runtime helper: run a WG on device with given A/B into C
|
||||
// ----------------------
|
||||
template <typename Case, bool UseMakeWarpGemm>
|
||||
static void RunWarpGemmCase(const ck_tile::HostTensor<typename Case::AType>& A,
|
||||
const ck_tile::HostTensor<typename Case::BType>& B,
|
||||
ck_tile::HostTensor<typename Case::AccType>& C)
|
||||
{
|
||||
using AType = typename Case::AType;
|
||||
using BType = typename Case::BType;
|
||||
using CType = typename Case::AccType; // CDataType equals Acc for these tests
|
||||
|
||||
ck_tile::DeviceMem Ad(A.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem Bd(B.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem Cd(C.get_element_space_size_in_bytes());
|
||||
|
||||
Ad.ToDevice(A.data());
|
||||
Bd.ToDevice(B.data());
|
||||
Cd.SetZero();
|
||||
|
||||
dim3 grid(1);
|
||||
dim3 block{64};
|
||||
|
||||
using Kernel = WarpGemmKernel<UseMakeWarpGemm,
|
||||
AType,
|
||||
BType,
|
||||
CType,
|
||||
Case::MPerWave,
|
||||
Case::NPerWave,
|
||||
Case::KPerWave,
|
||||
Case::kTransposeC,
|
||||
Case::kSwizzleA,
|
||||
Case::kUSS,
|
||||
Case::kNA>;
|
||||
|
||||
(void)ck_tile::launch_kernel(
|
||||
ck_tile::stream_config{nullptr, true},
|
||||
ck_tile::make_kernel(Kernel{},
|
||||
grid,
|
||||
block,
|
||||
0,
|
||||
static_cast<const AType*>(Ad.GetDeviceBuffer()),
|
||||
static_cast<const BType*>(Bd.GetDeviceBuffer()),
|
||||
static_cast<CType*>(Cd.GetDeviceBuffer())));
|
||||
|
||||
Cd.FromDevice(C.mData.data());
|
||||
}
|
||||
|
||||
// enforce 2:4 sparsity on A for SMFMA runtime cases (only meaningful for half_t here)
|
||||
template <typename AType>
|
||||
static inline void make_2to4_sparse_A(ck_tile::HostTensor<AType>& A)
|
||||
{
|
||||
// zero half the values in each consecutive group of 4 along K for each row m
|
||||
const ck_tile::index_t M = A.mDesc.get_lengths()[0];
|
||||
const ck_tile::index_t K = A.mDesc.get_lengths()[1];
|
||||
for(ck_tile::index_t m = 0; m < M; ++m)
|
||||
{
|
||||
for(ck_tile::index_t k = 0; k + 3 < K; k += 4)
|
||||
{
|
||||
// keep entries at k and k+2, zero k+1 and k+3 (simple 2:4 pattern)
|
||||
A(m, k + 1) = ck_tile::type_convert<AType>(0);
|
||||
A(m, k + 3) = ck_tile::type_convert<AType>(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Case>
|
||||
class WGRuntimeTest : public ::testing::Test
|
||||
{
|
||||
public:
|
||||
void RunTest()
|
||||
{
|
||||
// Equivalent MakeWarpGemm
|
||||
using AType = typename Case::AType;
|
||||
using BType = typename Case::BType;
|
||||
using CType = typename Case::AccType;
|
||||
|
||||
constexpr ck_tile::index_t M = Case::MPerWave;
|
||||
constexpr ck_tile::index_t N = Case::NPerWave;
|
||||
constexpr ck_tile::index_t K = Case::KPerWave;
|
||||
|
||||
ck_tile::HostTensor<AType> A({M, K});
|
||||
ck_tile::HostTensor<BType> B({K, N});
|
||||
ck_tile::HostTensor<CType> C_disp({M, N});
|
||||
ck_tile::HostTensor<CType> C_make({M, N});
|
||||
|
||||
for(ck_tile::index_t m = 0; m < M; ++m)
|
||||
for(ck_tile::index_t k = 0; k < K; ++k)
|
||||
A(m, k) = ck_tile::type_convert<AType>((m + 1) * 0.1f + (k + 1) * 0.01f);
|
||||
|
||||
if constexpr(Case::kUSS)
|
||||
{
|
||||
// ensure A satisfies 2:4 sparsity for SMFMA
|
||||
make_2to4_sparse_A(A);
|
||||
}
|
||||
|
||||
for(ck_tile::index_t k0 = 0; k0 < K; ++k0)
|
||||
for(ck_tile::index_t n = 0; n < N; ++n)
|
||||
B(k0, n) = ck_tile::type_convert<BType>((k0 + 1) * 0.2f - (n + 1) * 0.03f);
|
||||
|
||||
C_disp.SetZero();
|
||||
C_make.SetZero();
|
||||
RunWarpGemmCase<Case, /*UseMakeWarpGemm=*/false>(A, B, C_disp);
|
||||
RunWarpGemmCase<Case, /*UseMakeWarpGemm=*/true>(A, B, C_make);
|
||||
|
||||
EXPECT_TRUE(
|
||||
ck_tile::check_err(C_disp, C_make, "Dispatcher vs MakeWarpGemm mismatch", 0, 0));
|
||||
}
|
||||
};
|
||||
|
||||
#endif // CK_TILE_TEST_WARP_GEMM_ATTR_COMPOSE_HPP
|
||||
35
test/ck_tile/gemm/test_warp_gemm_attr_compose_bf16_mfma.cpp
Normal file
35
test/ck_tile/gemm/test_warp_gemm_attr_compose_bf16_mfma.cpp
Normal file
@@ -0,0 +1,35 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_warp_gemm_attr_compose.hpp"
|
||||
|
||||
using WGDispatcherTypesList = ::testing::Types<
|
||||
// clang-format off
|
||||
|
||||
// bf16
|
||||
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false>,
|
||||
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true>,
|
||||
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false>,
|
||||
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true>,
|
||||
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, false, false, ck_tile::WGAttrNumAccessEnum::Double>,
|
||||
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true, false, false, ck_tile::WGAttrNumAccessEnum::Double>,
|
||||
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false>,
|
||||
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true>,
|
||||
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false, false, false, ck_tile::WGAttrNumAccessEnum::Double>,
|
||||
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true, false, false, ck_tile::WGAttrNumAccessEnum::Double>,
|
||||
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 4, 64, 16, false>,
|
||||
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 64, 4, 16, false>,
|
||||
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, false>,
|
||||
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, true>,
|
||||
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true>,
|
||||
// WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true>,
|
||||
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true, true>,
|
||||
WGDispCase<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true, true>>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(WGCompileTimeTest, WGDispatcherTypesList);
|
||||
TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList);
|
||||
|
||||
TYPED_TEST(WGCompileTimeTest, Instantiate) { this->RunTest(); }
|
||||
|
||||
TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG) { this->RunTest(); }
|
||||
27
test/ck_tile/gemm/test_warp_gemm_attr_compose_bf8_mfma.cpp
Normal file
27
test/ck_tile/gemm/test_warp_gemm_attr_compose_bf8_mfma.cpp
Normal file
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_warp_gemm_attr_compose.hpp"
|
||||
|
||||
using WGDispatcherTypesList = ::testing::Types<
|
||||
// clang-format off
|
||||
|
||||
// bf8
|
||||
WGDispCase<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, false>,
|
||||
WGDispCase<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 32, false>,
|
||||
WGDispCase<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 32, false>,
|
||||
WGDispCase<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 32, true>,
|
||||
WGDispCase<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 64, false>,
|
||||
WGDispCase<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, true>,
|
||||
WGDispCase<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 128, false>,
|
||||
WGDispCase<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false>,
|
||||
WGDispCase<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false, false, false, ck_tile::WGAttrNumAccessEnum::Quad>,
|
||||
WGDispCase<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 128, false, false, false, ck_tile::WGAttrNumAccessEnum::Quad>>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(WGCompileTimeTest, WGDispatcherTypesList);
|
||||
TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList);
|
||||
|
||||
TYPED_TEST(WGCompileTimeTest, Instantiate) { this->RunTest(); }
|
||||
|
||||
TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG) { this->RunTest(); }
|
||||
35
test/ck_tile/gemm/test_warp_gemm_attr_compose_fp16_mfma.cpp
Normal file
35
test/ck_tile/gemm/test_warp_gemm_attr_compose_fp16_mfma.cpp
Normal file
@@ -0,0 +1,35 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_warp_gemm_attr_compose.hpp"
|
||||
|
||||
using WGDispatcherTypesList = ::testing::Types<
|
||||
// clang-format off
|
||||
|
||||
// fp16
|
||||
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false>,
|
||||
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true>,
|
||||
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false>,
|
||||
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true>,
|
||||
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, false, false, ck_tile::WGAttrNumAccessEnum::Double>,
|
||||
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true, false, false, ck_tile::WGAttrNumAccessEnum::Double>,
|
||||
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false>,
|
||||
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true>,
|
||||
WGDispCase<ck_tile::half_t, ck_tile::half_t, float,16, 16, 32, false, false, false, ck_tile::WGAttrNumAccessEnum::Double>,
|
||||
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true, false, false, ck_tile::WGAttrNumAccessEnum::Double>,
|
||||
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 4, 64, 16, false>,
|
||||
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 64, 4, 16, false>,
|
||||
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, false>,
|
||||
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true>,
|
||||
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true>,
|
||||
// WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true>,
|
||||
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true, true>,
|
||||
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true, true>>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(WGCompileTimeTest, WGDispatcherTypesList);
|
||||
TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList);
|
||||
|
||||
TYPED_TEST(WGCompileTimeTest, Instantiate) { this->RunTest(); }
|
||||
|
||||
TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG) { this->RunTest(); }
|
||||
@@ -0,0 +1,19 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_warp_gemm_attr_compose.hpp"
|
||||
|
||||
using WGDispatcherTypesList = ::testing::Types<
|
||||
// clang-format off
|
||||
|
||||
// fp16 2:4 structural sparsity
|
||||
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, false, true>,
|
||||
WGDispCase<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false, false, true>>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(WGCompileTimeTest, WGDispatcherTypesList);
|
||||
TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList);
|
||||
|
||||
TYPED_TEST(WGCompileTimeTest, Instantiate) { this->RunTest(); }
|
||||
|
||||
TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG) { this->RunTest(); }
|
||||
27
test/ck_tile/gemm/test_warp_gemm_attr_compose_fp8_mfma.cpp
Normal file
27
test/ck_tile/gemm/test_warp_gemm_attr_compose_fp8_mfma.cpp
Normal file
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_warp_gemm_attr_compose.hpp"
|
||||
|
||||
using WGDispatcherTypesList = ::testing::Types<
|
||||
// clang-format off
|
||||
|
||||
// fp8
|
||||
WGDispCase<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false>,
|
||||
WGDispCase<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 32, false>,
|
||||
WGDispCase<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 32, false>,
|
||||
WGDispCase<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 64, false>,
|
||||
WGDispCase<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, true>,
|
||||
WGDispCase<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 32, true>,
|
||||
WGDispCase<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 128, false>,
|
||||
WGDispCase<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 64, false>,
|
||||
WGDispCase<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 64, false, false, false, ck_tile::WGAttrNumAccessEnum::Quad>,
|
||||
WGDispCase<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 128, false, false, false, ck_tile::WGAttrNumAccessEnum::Quad>>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(WGCompileTimeTest, WGDispatcherTypesList);
|
||||
TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList);
|
||||
|
||||
TYPED_TEST(WGCompileTimeTest, Instantiate) { this->RunTest(); }
|
||||
|
||||
TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG) { this->RunTest(); }
|
||||
21
test/ck_tile/gemm/test_warp_gemm_attr_compose_int8_mfma.cpp
Normal file
21
test/ck_tile/gemm/test_warp_gemm_attr_compose_int8_mfma.cpp
Normal file
@@ -0,0 +1,21 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_warp_gemm_attr_compose.hpp"
|
||||
|
||||
using WGDispatcherTypesList = ::testing::Types<
|
||||
// clang-format off
|
||||
|
||||
// int8
|
||||
WGDispCase<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, false>,
|
||||
WGDispCase<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, true>,
|
||||
WGDispCase<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, false>,
|
||||
WGDispCase<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, true>>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(WGCompileTimeTest, WGDispatcherTypesList);
|
||||
TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList);
|
||||
|
||||
TYPED_TEST(WGCompileTimeTest, Instantiate) { this->RunTest(); }
|
||||
|
||||
TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG) { this->RunTest(); }
|
||||
@@ -0,0 +1,29 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_warp_gemm_attr_compose.hpp"
|
||||
|
||||
using WGDispatcherTypesList = ::testing::Types<
|
||||
// clang-format off
|
||||
|
||||
// mixed fp8/bf8
|
||||
WGDispCase<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, false>,
|
||||
WGDispCase<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, true>,
|
||||
WGDispCase<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, false>,
|
||||
WGDispCase<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, true>,
|
||||
WGDispCase<ck_tile::fp8_t, ck_tile::bf8_t, float, 16, 16, 128, false>,
|
||||
WGDispCase<ck_tile::bf8_t, ck_tile::fp8_t, float, 16, 16, 128, false>,
|
||||
WGDispCase<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 64, false>,
|
||||
WGDispCase<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false>,
|
||||
WGDispCase<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 64, false, false, false, ck_tile::WGAttrNumAccessEnum::Quad>,
|
||||
WGDispCase<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false, false, false, ck_tile::WGAttrNumAccessEnum::Quad>,
|
||||
WGDispCase<ck_tile::fp8_t, ck_tile::bf8_t, float, 16, 16, 128, false, false, false, ck_tile::WGAttrNumAccessEnum::Quad>,
|
||||
WGDispCase<ck_tile::bf8_t, ck_tile::fp8_t, float, 16, 16, 128, false, false, false, ck_tile::WGAttrNumAccessEnum::Quad>>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(WGCompileTimeTest, WGDispatcherTypesList);
|
||||
TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList);
|
||||
|
||||
TYPED_TEST(WGCompileTimeTest, Instantiate) { this->RunTest(); }
|
||||
|
||||
TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG) { this->RunTest(); }
|
||||
Reference in New Issue
Block a user