mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
Merge some updates for ck_tile headers (#3342)
* fix some issues from internal branch
* update cshuffle_epilogue
* update cshuffle_epilogue
* update cshuffle
* update warp_gemm
This commit is contained in:
@@ -306,6 +306,16 @@ using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIter
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
// C-transposed counterpart of the fp8 x fp8 16x16x64 warp GEMM alias.
// Composes WarpGemmImpl with WarpGemmAttributeMfmaIterateKAndTransposedCDistribution
// over the 16x16x32 fp8/fp8 MFMA implementation (default control enum).
// NOTE(review): the trailing `2` is presumably the K-iteration count
// (2 x 32 = 64, matching the alias name) -- confirm against the attribute template.
using WarpGemmMfma_f32_16x16x64_fp8_fp8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
// C-transposed counterpart of the bf8 x bf8 16x16x64 warp GEMM alias.
// Same composition as the fp8 variant, but built on the 16x16x32 bf8/bf8
// MFMA implementation (default control enum).
// NOTE(review): the trailing `2` is presumably the K-iteration count
// (2 x 32 = 64, matching the alias name) -- confirm against the attribute template.
using WarpGemmMfma_f32_16x16x64_bf8_bf8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
// Alias template for the 16x16x128 f8f6f4 warp GEMM.
// A / B are forwarded to WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<A, B>;
// AttrNumAccess is forwarded to WarpGemmAttributeMfma and defaults to
// WGAttrNumAccessEnum::Single when the caller does not specify it.
template <typename A, typename B, WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_16x16x128_f8f6f4 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<A, B>, AttrNumAccess>>;
|
||||
|
||||
@@ -68,6 +68,19 @@ struct WarpGemmAttributeWmma
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeWmmaImpl_>;
|
||||
|
||||
// When kTransC is true and A/B types differ, we need an impl with swapped types
// (the kTransC call sites pass b_vec before a_vec, so an impl whose ADataType /
// BDataType are exchanged is selected here). In every other case TransposedImpl
// is simply Impl.
// The swapped impl is rebuilt from WmmaTraits using Impl's ArchType, CDataType
// and kM / kN / kK.
// NOTE(review): kM and kN are forwarded unswapped while the data types are
// swapped -- presumably fine for square tiles; confirm if non-square WMMA
// shapes are ever added.
|
||||
using TransposedImpl =
|
||||
std::conditional_t<kTransC &&
|
||||
!std::is_same_v<typename Impl::ADataType, typename Impl::BDataType>,
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<typename Impl::TraitsType::ArchType,
|
||||
typename Impl::BDataType,
|
||||
typename Impl::ADataType,
|
||||
typename Impl::CDataType,
|
||||
Impl::kM,
|
||||
Impl::kN,
|
||||
Impl::kK>>,
|
||||
Impl>;
|
||||
|
||||
using ADataType = typename Impl::ADataType;
|
||||
using BDataType = typename Impl::BDataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
@@ -104,7 +117,7 @@ struct WarpGemmAttributeWmma
|
||||
{
|
||||
if constexpr(kTransC)
|
||||
{
|
||||
Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
|
||||
TransposedImpl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -117,7 +130,7 @@ struct WarpGemmAttributeWmma
|
||||
{
|
||||
if constexpr(kTransC)
|
||||
{
|
||||
return Impl{}(b_vec, a_vec);
|
||||
return TransposedImpl{}(b_vec, a_vec);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -22,9 +22,10 @@ struct WmmaTraits;
|
||||
template <typename Traits>
|
||||
struct WarpGemmAttributeWmmaImpl
|
||||
{
|
||||
using ADataType = typename Traits::ADataType;
|
||||
using BDataType = typename Traits::BDataType;
|
||||
using CDataType = typename Traits::CDataType;
|
||||
using TraitsType = Traits;
|
||||
using ADataType = typename Traits::ADataType;
|
||||
using BDataType = typename Traits::BDataType;
|
||||
using CDataType = typename Traits::CDataType;
|
||||
|
||||
using AVecType = typename Traits::AVecType;
|
||||
using BVecType = typename Traits::BVecType;
|
||||
|
||||
@@ -10,6 +10,8 @@ template <>
|
||||
struct WmmaTraits<gfx11_t, fp16_t, fp16_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx11_t, fp16_t, fp16_t, float>
|
||||
{
|
||||
using ArchType = gfx11_t;
|
||||
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
@@ -30,6 +32,8 @@ template <>
|
||||
struct WmmaTraits<gfx11_t, bf16_t, bf16_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx11_t, bf16_t, bf16_t, float>
|
||||
{
|
||||
using ArchType = gfx11_t;
|
||||
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
@@ -50,6 +54,8 @@ template <>
|
||||
struct WmmaTraits<gfx12_t, fp16_t, fp16_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx12_t, fp16_t, fp16_t, float>
|
||||
{
|
||||
using ArchType = gfx12_t;
|
||||
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
@@ -70,6 +76,8 @@ template <>
|
||||
struct WmmaTraits<gfx12_t, bf16_t, bf16_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx12_t, bf16_t, bf16_t, float>
|
||||
{
|
||||
using ArchType = gfx12_t;
|
||||
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
|
||||
@@ -10,6 +10,8 @@ template <>
|
||||
struct WmmaTraits<gfx11_t, int8_t, int8_t, int32_t, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx11_t, int8_t, int8_t, int32_t>
|
||||
{
|
||||
using ArchType = gfx11_t;
|
||||
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
@@ -35,6 +37,8 @@ template <>
|
||||
struct WmmaTraits<gfx12_t, int8_t, int8_t, int32_t, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx12_t, int8_t, int8_t, int32_t>
|
||||
{
|
||||
using ArchType = gfx12_t;
|
||||
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
@@ -60,6 +64,8 @@ template <>
|
||||
struct WmmaTraits<gfx12_t, fp8_t, fp8_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx12_t, fp8_t, fp8_t, float>
|
||||
{
|
||||
using ArchType = gfx12_t;
|
||||
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
@@ -80,6 +86,8 @@ template <>
|
||||
struct WmmaTraits<gfx12_t, bf8_t, bf8_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx12_t, bf8_t, bf8_t, float>
|
||||
{
|
||||
using ArchType = gfx12_t;
|
||||
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
@@ -100,6 +108,8 @@ template <>
|
||||
struct WmmaTraits<gfx12_t, fp8_t, bf8_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx12_t, fp8_t, bf8_t, float>
|
||||
{
|
||||
using ArchType = gfx12_t;
|
||||
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
|
||||
@@ -10,6 +10,8 @@ struct WmmaTraitsBase;
|
||||
template <typename ADType, typename BDType, typename CDType>
|
||||
struct WmmaTraitsBase<gfx11_t, ADType, BDType, CDType>
|
||||
{
|
||||
using ArchType = gfx11_t;
|
||||
|
||||
using ADataType = ADType;
|
||||
using BDataType = BDType;
|
||||
using CDataType = CDType;
|
||||
@@ -57,6 +59,8 @@ struct WmmaTraitsBase<gfx11_t, ADType, BDType, CDType>
|
||||
template <typename ADType, typename BDType, typename CDType>
|
||||
struct WmmaTraitsBase<gfx12_t, ADType, BDType, CDType>
|
||||
{
|
||||
using ArchType = gfx12_t;
|
||||
|
||||
using ADataType = ADType;
|
||||
using BDataType = BDType;
|
||||
using CDataType = CDType;
|
||||
|
||||
@@ -100,6 +100,7 @@ template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 16, false> { using Ty
|
||||
// Dispatcher specializations with fp8_t as the A type and float accumulation.
// Each maps a (A type, B type, C type, three tile extents, bool) key to a
// concrete warp GEMM alias via the nested `Type`. In the entries below the
// three integers match the MxNxK in the alias name, and a trailing `true`
// selects the corresponding *_CTransposed alias.
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, true> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8_CTransposed; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 32, true> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8_CTransposed; };
|
||||
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; };
|
||||
@@ -113,6 +114,7 @@ template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 32, false> { using Ty
|
||||
// Dispatcher specializations for bf8_t x bf8_t inputs with float accumulation.
// As with the fp8 entries, the three integers match the MxNxK in the alias
// name and a trailing `true` selects the corresponding *_CTransposed alias.
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 32, true> { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 64, true> { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8_CTransposed; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; };
|
||||
|
||||
// scale mfma based f8f6f4
|
||||
|
||||
Reference in New Issue
Block a user