Merge some updates for ck_tile headers (#3342)

* fix some issues from internal branch

* update cshuffle_epilogue

* update cshuffle_epilogue

* update cshuffle

* update warp_gemm
This commit is contained in:
joyeamd
2026-01-06 15:39:00 +08:00
committed by GitHub
parent 2b563ad048
commit b78563b3d3
14 changed files with 205 additions and 119 deletions

View File

@@ -306,6 +306,16 @@ using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIter
WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>,
2>>;
// fp8 x fp8 -> f32 warp GEMM alias, 16x16x64, using the transposed-C
// distribution attribute. It wraps the 16x16x32 fp8/fp8 MFMA impl in
// WarpGemmAttributeMfmaIterateKAndTransposedCDistribution with a trailing
// argument of 2.
// NOTE(review): the 2 is presumably the K-iteration count (2 x 32 = 64) --
// confirm against the attribute's definition.
using WarpGemmMfma_f32_16x16x64_fp8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
2>>;
// bf8 x bf8 -> f32 warp GEMM alias, 16x16x64, using the transposed-C
// distribution attribute. It wraps the 16x16x32 bf8/bf8 MFMA impl in
// WarpGemmAttributeMfmaIterateKAndTransposedCDistribution with a trailing
// argument of 2.
// NOTE(review): the 2 is presumably the K-iteration count (2 x 32 = 64) --
// confirm against the attribute's definition.
using WarpGemmMfma_f32_16x16x64_bf8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>,
2>>;
// Warp GEMM alias template for the 16x16x128 f8f6f4 MFMA impl.
//   A, B          - forwarded unchanged to
//                   WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<A, B>
//                   (presumably the A/B operand element types -- confirm).
//   AttrNumAccess - forwarded to WarpGemmAttributeMfma; defaults to
//                   WGAttrNumAccessEnum::Single.
template <typename A, typename B, WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_f8f6f4 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<A, B>, AttrNumAccess>>;

View File

@@ -68,6 +68,19 @@ struct WarpGemmAttributeWmma
{
using Impl = remove_cvref_t<WarpGemmAttributeWmmaImpl_>;
// When kTransC is true and A/B types differ, we need an impl with swapped types.
// In that case TransposedImpl is a WarpGemmAttributeWmmaImpl built on a
// WmmaTraits with the same arch, C type and kM/kN/kK as Impl but with
// ADataType and BDataType exchanged. In every other case (kTransC false, or
// A and B already the same type) TransposedImpl is just Impl.
// NOTE(review): this relies on Impl exposing TraitsType and on that traits
// type exposing ArchType; a WmmaTraits specialization for the swapped
// (B, A) pair must also exist wherever TransposedImpl is actually invoked.
using TransposedImpl =
std::conditional_t<kTransC &&
!std::is_same_v<typename Impl::ADataType, typename Impl::BDataType>,
WarpGemmAttributeWmmaImpl<WmmaTraits<typename Impl::TraitsType::ArchType,
typename Impl::BDataType,
typename Impl::ADataType,
typename Impl::CDataType,
Impl::kM,
Impl::kN,
Impl::kK>>,
Impl>;
// Element types re-exported from Impl (the un-swapped implementation), so
// these aliases keep Impl's A/B order regardless of kTransC.
using ADataType = typename Impl::ADataType;
using BDataType = typename Impl::BDataType;
using CDataType = typename Impl::CDataType;
@@ -104,7 +117,7 @@ struct WarpGemmAttributeWmma
{
if constexpr(kTransC)
{
Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
TransposedImpl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
}
else
{
@@ -117,7 +130,7 @@ struct WarpGemmAttributeWmma
{
if constexpr(kTransC)
{
return Impl{}(b_vec, a_vec);
return TransposedImpl{}(b_vec, a_vec);
}
else
{

View File

@@ -22,9 +22,10 @@ struct WmmaTraits;
template <typename Traits>
struct WarpGemmAttributeWmmaImpl
{
using ADataType = typename Traits::ADataType;
using BDataType = typename Traits::BDataType;
using CDataType = typename Traits::CDataType;
using TraitsType = Traits;
using ADataType = typename Traits::ADataType;
using BDataType = typename Traits::BDataType;
using CDataType = typename Traits::CDataType;
using AVecType = typename Traits::AVecType;
using BVecType = typename Traits::BVecType;

View File

@@ -10,6 +10,8 @@ template <>
struct WmmaTraits<gfx11_t, fp16_t, fp16_t, float, 16, 16, 16>
: WmmaTraitsBase<gfx11_t, fp16_t, fp16_t, float>
{
using ArchType = gfx11_t;
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
@@ -30,6 +32,8 @@ template <>
struct WmmaTraits<gfx11_t, bf16_t, bf16_t, float, 16, 16, 16>
: WmmaTraitsBase<gfx11_t, bf16_t, bf16_t, float>
{
using ArchType = gfx11_t;
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
@@ -50,6 +54,8 @@ template <>
struct WmmaTraits<gfx12_t, fp16_t, fp16_t, float, 16, 16, 16>
: WmmaTraitsBase<gfx12_t, fp16_t, fp16_t, float>
{
using ArchType = gfx12_t;
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
@@ -70,6 +76,8 @@ template <>
struct WmmaTraits<gfx12_t, bf16_t, bf16_t, float, 16, 16, 16>
: WmmaTraitsBase<gfx12_t, bf16_t, bf16_t, float>
{
using ArchType = gfx12_t;
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)

View File

@@ -10,6 +10,8 @@ template <>
struct WmmaTraits<gfx11_t, int8_t, int8_t, int32_t, 16, 16, 16>
: WmmaTraitsBase<gfx11_t, int8_t, int8_t, int32_t>
{
using ArchType = gfx11_t;
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
@@ -35,6 +37,8 @@ template <>
struct WmmaTraits<gfx12_t, int8_t, int8_t, int32_t, 16, 16, 16>
: WmmaTraitsBase<gfx12_t, int8_t, int8_t, int32_t>
{
using ArchType = gfx12_t;
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
@@ -60,6 +64,8 @@ template <>
struct WmmaTraits<gfx12_t, fp8_t, fp8_t, float, 16, 16, 16>
: WmmaTraitsBase<gfx12_t, fp8_t, fp8_t, float>
{
using ArchType = gfx12_t;
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
@@ -80,6 +86,8 @@ template <>
struct WmmaTraits<gfx12_t, bf8_t, bf8_t, float, 16, 16, 16>
: WmmaTraitsBase<gfx12_t, bf8_t, bf8_t, float>
{
using ArchType = gfx12_t;
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
@@ -100,6 +108,8 @@ template <>
struct WmmaTraits<gfx12_t, fp8_t, bf8_t, float, 16, 16, 16>
: WmmaTraitsBase<gfx12_t, fp8_t, bf8_t, float>
{
using ArchType = gfx12_t;
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)

View File

@@ -10,6 +10,8 @@ struct WmmaTraitsBase;
template <typename ADType, typename BDType, typename CDType>
struct WmmaTraitsBase<gfx11_t, ADType, BDType, CDType>
{
using ArchType = gfx11_t;
using ADataType = ADType;
using BDataType = BDType;
using CDataType = CDType;
@@ -57,6 +59,8 @@ struct WmmaTraitsBase<gfx11_t, ADType, BDType, CDType>
template <typename ADType, typename BDType, typename CDType>
struct WmmaTraitsBase<gfx12_t, ADType, BDType, CDType>
{
using ArchType = gfx12_t;
using ADataType = ADType;
using BDataType = BDType;
using CDataType = CDType;

View File

@@ -100,6 +100,7 @@ template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 16, false> { using Ty
// Dispatcher specializations mapping a template-argument tuple to a concrete
// WarpGemm alias via the nested `Type`.
// NOTE(review): the arguments appear to be <A type, B type, C type, M, N, K,
// transposed-C flag>; in the entries below a trailing `true` always selects a
// *_CTransposed alias -- confirm against the primary template.
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, true> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8_CTransposed; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 32, true> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8_CTransposed; };
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; };
@@ -113,6 +114,7 @@ template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 32, false> { using Ty
// bf8 x bf8 -> float Dispatcher specializations; each maps its argument tuple
// to a concrete WarpGemm alias via the nested `Type`.
// NOTE(review): the trailing bool appears to be the transposed-C flag (a
// `true` entry selects a *_CTransposed alias) -- confirm against the primary
// template.
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8; };
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 32, true> { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed; };
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; };
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 64, true> { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8_CTransposed; };
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; };
// scale mfma based f8f6f4