mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Merge some updates for ck_tile headers (#3342)
* fix some issues from internal branch
* update cshuffle_epilogue
* update cshuffle_epilogue
* update cshuffle
* update warp_gemm
This commit is contained in:
@@ -423,7 +423,7 @@ struct UniversalGemmKernel
|
||||
|
||||
const auto vectorSizeA = is_wave32() ? GemmPipeline::template GetVectorSizeA<true>()
|
||||
: GemmPipeline::template GetVectorSizeA<false>();
|
||||
bool AsTesnorIsValid = {true};
|
||||
bool AsTensorIsValid = {true};
|
||||
static_for<0, NumATensor, 1>{}([&](auto index) {
|
||||
using AiLayout = remove_cvref_t<std::tuple_element_t<index.value, AsLayout>>;
|
||||
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
|
||||
@@ -437,15 +437,27 @@ struct UniversalGemmKernel
|
||||
"Can't support K that is not a multiple of k_batch * KPerBlock "
|
||||
"without padding!");
|
||||
}
|
||||
AsTesnorIsValid = false;
|
||||
AsTensorIsValid = false;
|
||||
}
|
||||
if(kargs.K % vectorSizeA != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
const auto remainder = kargs.K % vectorSizeA;
|
||||
constexpr ck_tile::index_t APackedSize =
|
||||
ck_tile::numeric_traits<ADataType>::PackedSize;
|
||||
const auto remainder_in_bytes = remainder * sizeof(ADataType) / APackedSize;
|
||||
// oob can support to dword level
|
||||
if(remainder_in_bytes % 4 == 0)
|
||||
{
|
||||
CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
|
||||
AsTensorIsValid = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
|
||||
}
|
||||
AsTensorIsValid = false;
|
||||
}
|
||||
AsTesnorIsValid = false;
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -457,20 +469,33 @@ struct UniversalGemmKernel
|
||||
CK_TILE_ERROR(
|
||||
"Can't support M that is not a multiple of MPerBlock without padding!");
|
||||
}
|
||||
AsTesnorIsValid = false;
|
||||
AsTensorIsValid = false;
|
||||
}
|
||||
if(kargs.M % vectorSizeA != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
const auto remainder = kargs.M % vectorSizeA;
|
||||
constexpr ck_tile::index_t APackedSize =
|
||||
ck_tile::numeric_traits<ADataType>::PackedSize;
|
||||
const auto remainder_in_bytes = remainder * sizeof(ADataType) / APackedSize;
|
||||
// oob can support to dword level
|
||||
if(remainder_in_bytes % 4 == 0)
|
||||
{
|
||||
CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
|
||||
|
||||
AsTensorIsValid = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
|
||||
}
|
||||
AsTensorIsValid = false;
|
||||
}
|
||||
AsTesnorIsValid = false;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
bool BsTesnorIsValid = {true};
|
||||
bool BsTensorIsValid = {true};
|
||||
const auto vectorSizeB = is_wave32() ? GemmPipeline::template GetVectorSizeB<true>()
|
||||
: GemmPipeline::template GetVectorSizeB<false>();
|
||||
static_for<0, NumBTensor, 1>{}([&](auto index) {
|
||||
@@ -484,47 +509,72 @@ struct UniversalGemmKernel
|
||||
CK_TILE_ERROR(
|
||||
"Can't support N that is not a multiple of NPerBlock without padding!");
|
||||
}
|
||||
BsTesnorIsValid = false;
|
||||
BsTensorIsValid = false;
|
||||
}
|
||||
if(kargs.N % vectorSizeB != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
const auto remainder = kargs.N % vectorSizeB;
|
||||
constexpr ck_tile::index_t BPackedSize =
|
||||
ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
const auto remainder_in_bytes = remainder * sizeof(BDataType) / BPackedSize;
|
||||
// oob can support to dword level
|
||||
if(remainder_in_bytes % 4 == 0)
|
||||
{
|
||||
CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
|
||||
BsTensorIsValid = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
|
||||
}
|
||||
BsTensorIsValid = false;
|
||||
}
|
||||
BsTesnorIsValid = false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
|
||||
GemmPipeline::kPadK == false)
|
||||
else
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
|
||||
GemmPipeline::kPadK == false)
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Can't support K that is not a multiple of k_batch * KPerBlock "
|
||||
"without padding!");
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Can't support K that is not a multiple of k_batch * KPerBlock "
|
||||
"without padding!");
|
||||
}
|
||||
BsTensorIsValid = false;
|
||||
}
|
||||
BsTesnorIsValid = false;
|
||||
}
|
||||
if(kargs.K % vectorSizeB != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
if(kargs.K % vectorSizeB != 0)
|
||||
{
|
||||
CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
|
||||
const auto remainder = kargs.K % vectorSizeB;
|
||||
constexpr ck_tile::index_t BPackedSize =
|
||||
ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
const auto remainder_in_bytes = remainder * sizeof(BDataType) / BPackedSize;
|
||||
// oob can support to dword level
|
||||
if(remainder_in_bytes % 4 == 0)
|
||||
{
|
||||
BsTensorIsValid = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"K is not a multiple of vector load size for B tensor!");
|
||||
}
|
||||
BsTensorIsValid = false;
|
||||
}
|
||||
}
|
||||
BsTesnorIsValid = false;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
bool DTesnorIsValid = {true};
|
||||
bool DTensorIsValid = {true};
|
||||
static_for<0, NumDTensor, 1>{}([&](auto index) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
|
||||
if(std::is_same_v<DiLayout, CLayout> == false)
|
||||
{
|
||||
DTesnorIsValid = false;
|
||||
DTensorIsValid = false;
|
||||
}
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -535,7 +585,7 @@ struct UniversalGemmKernel
|
||||
CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
|
||||
"NPerBlock without padding!");
|
||||
}
|
||||
DTesnorIsValid = false;
|
||||
DTensorIsValid = false;
|
||||
}
|
||||
if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
|
||||
{
|
||||
@@ -543,7 +593,7 @@ struct UniversalGemmKernel
|
||||
{
|
||||
CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
|
||||
}
|
||||
DTesnorIsValid = false;
|
||||
DTensorIsValid = false;
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -555,7 +605,7 @@ struct UniversalGemmKernel
|
||||
CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
|
||||
"MPerBlock without padding!");
|
||||
}
|
||||
DTesnorIsValid = false;
|
||||
DTensorIsValid = false;
|
||||
}
|
||||
if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
|
||||
{
|
||||
@@ -563,7 +613,7 @@ struct UniversalGemmKernel
|
||||
{
|
||||
CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
|
||||
}
|
||||
DTesnorIsValid = false;
|
||||
DTensorIsValid = false;
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -608,7 +658,7 @@ struct UniversalGemmKernel
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return AsTesnorIsValid && BsTesnorIsValid && DTesnorIsValid;
|
||||
return AsTensorIsValid && BsTensorIsValid && DTensorIsValid;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
|
||||
@@ -845,10 +845,10 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSizeA()
|
||||
{
|
||||
constexpr index_t smem_size_a =
|
||||
integer_least_multiple(sizeof(typename Problem::ADataType) *
|
||||
Problem::BlockGemmShape::kM * Problem::BlockGemmShape::kK,
|
||||
16);
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
constexpr auto a_lds_block_desc = Derived::template MakeALdsBlockDescriptor<Problem>();
|
||||
constexpr index_t smem_size_a = integer_least_multiple(
|
||||
a_lds_block_desc.get_element_space_size() * sizeof(ADataType), 16);
|
||||
return smem_size_a;
|
||||
}
|
||||
|
||||
@@ -859,8 +859,9 @@ struct UniversalGemmBasePolicy
|
||||
std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_raw_t>,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
constexpr index_t smem_size_b = integer_least_multiple(
|
||||
sizeof(BDataType) * Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK, 16);
|
||||
constexpr auto b_lds_block_desc = Derived::template MakeBLdsBlockDescriptor<Problem>();
|
||||
constexpr index_t smem_size_b = integer_least_multiple(
|
||||
b_lds_block_desc.get_element_space_size() * sizeof(BDataType), 16);
|
||||
return smem_size_b;
|
||||
}
|
||||
|
||||
|
||||
@@ -53,11 +53,11 @@ struct TileGemmUniversalTraits
|
||||
static constexpr int _VectorSize = VectorSize_;
|
||||
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
|
||||
|
||||
using AsLayout = AsLayout_;
|
||||
using BsLayout = BsLayout_;
|
||||
using CLayout = CLayout_;
|
||||
using AsLayout = AsLayout_;
|
||||
using BsLayout = BsLayout_;
|
||||
using CLayout = CLayout_;
|
||||
static constexpr bool TransposeC = TransposeC_;
|
||||
|
||||
static constexpr bool TransposeC = TransposeC_;
|
||||
static constexpr bool UseStructuredSparsity = UseStructuredSparsity_;
|
||||
static constexpr bool UsePersistentKernel = UsePersistentKernel_;
|
||||
static constexpr index_t NumWaveGroups = NumWaveGroups_;
|
||||
|
||||
@@ -306,6 +306,16 @@ using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIter
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfma_f32_16x16x64_fp8_fp8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfma_f32_16x16x64_bf8_bf8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
template <typename A, typename B, WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_16x16x128_f8f6f4 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<A, B>, AttrNumAccess>>;
|
||||
|
||||
@@ -68,6 +68,19 @@ struct WarpGemmAttributeWmma
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeWmmaImpl_>;
|
||||
|
||||
// When kTransC is true and A/B types differ, we need an impl with swapped types
|
||||
using TransposedImpl =
|
||||
std::conditional_t<kTransC &&
|
||||
!std::is_same_v<typename Impl::ADataType, typename Impl::BDataType>,
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<typename Impl::TraitsType::ArchType,
|
||||
typename Impl::BDataType,
|
||||
typename Impl::ADataType,
|
||||
typename Impl::CDataType,
|
||||
Impl::kM,
|
||||
Impl::kN,
|
||||
Impl::kK>>,
|
||||
Impl>;
|
||||
|
||||
using ADataType = typename Impl::ADataType;
|
||||
using BDataType = typename Impl::BDataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
@@ -104,7 +117,7 @@ struct WarpGemmAttributeWmma
|
||||
{
|
||||
if constexpr(kTransC)
|
||||
{
|
||||
Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
|
||||
TransposedImpl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -117,7 +130,7 @@ struct WarpGemmAttributeWmma
|
||||
{
|
||||
if constexpr(kTransC)
|
||||
{
|
||||
return Impl{}(b_vec, a_vec);
|
||||
return TransposedImpl{}(b_vec, a_vec);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -22,9 +22,10 @@ struct WmmaTraits;
|
||||
template <typename Traits>
|
||||
struct WarpGemmAttributeWmmaImpl
|
||||
{
|
||||
using ADataType = typename Traits::ADataType;
|
||||
using BDataType = typename Traits::BDataType;
|
||||
using CDataType = typename Traits::CDataType;
|
||||
using TraitsType = Traits;
|
||||
using ADataType = typename Traits::ADataType;
|
||||
using BDataType = typename Traits::BDataType;
|
||||
using CDataType = typename Traits::CDataType;
|
||||
|
||||
using AVecType = typename Traits::AVecType;
|
||||
using BVecType = typename Traits::BVecType;
|
||||
|
||||
@@ -10,6 +10,8 @@ template <>
|
||||
struct WmmaTraits<gfx11_t, fp16_t, fp16_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx11_t, fp16_t, fp16_t, float>
|
||||
{
|
||||
using ArchType = gfx11_t;
|
||||
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
@@ -30,6 +32,8 @@ template <>
|
||||
struct WmmaTraits<gfx11_t, bf16_t, bf16_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx11_t, bf16_t, bf16_t, float>
|
||||
{
|
||||
using ArchType = gfx11_t;
|
||||
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
@@ -50,6 +54,8 @@ template <>
|
||||
struct WmmaTraits<gfx12_t, fp16_t, fp16_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx12_t, fp16_t, fp16_t, float>
|
||||
{
|
||||
using ArchType = gfx12_t;
|
||||
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
@@ -70,6 +76,8 @@ template <>
|
||||
struct WmmaTraits<gfx12_t, bf16_t, bf16_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx12_t, bf16_t, bf16_t, float>
|
||||
{
|
||||
using ArchType = gfx12_t;
|
||||
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
|
||||
@@ -10,6 +10,8 @@ template <>
|
||||
struct WmmaTraits<gfx11_t, int8_t, int8_t, int32_t, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx11_t, int8_t, int8_t, int32_t>
|
||||
{
|
||||
using ArchType = gfx11_t;
|
||||
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
@@ -35,6 +37,8 @@ template <>
|
||||
struct WmmaTraits<gfx12_t, int8_t, int8_t, int32_t, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx12_t, int8_t, int8_t, int32_t>
|
||||
{
|
||||
using ArchType = gfx12_t;
|
||||
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
@@ -60,6 +64,8 @@ template <>
|
||||
struct WmmaTraits<gfx12_t, fp8_t, fp8_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx12_t, fp8_t, fp8_t, float>
|
||||
{
|
||||
using ArchType = gfx12_t;
|
||||
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
@@ -80,6 +86,8 @@ template <>
|
||||
struct WmmaTraits<gfx12_t, bf8_t, bf8_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx12_t, bf8_t, bf8_t, float>
|
||||
{
|
||||
using ArchType = gfx12_t;
|
||||
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
@@ -100,6 +108,8 @@ template <>
|
||||
struct WmmaTraits<gfx12_t, fp8_t, bf8_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx12_t, fp8_t, bf8_t, float>
|
||||
{
|
||||
using ArchType = gfx12_t;
|
||||
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
|
||||
@@ -10,6 +10,8 @@ struct WmmaTraitsBase;
|
||||
template <typename ADType, typename BDType, typename CDType>
|
||||
struct WmmaTraitsBase<gfx11_t, ADType, BDType, CDType>
|
||||
{
|
||||
using ArchType = gfx11_t;
|
||||
|
||||
using ADataType = ADType;
|
||||
using BDataType = BDType;
|
||||
using CDataType = CDType;
|
||||
@@ -57,6 +59,8 @@ struct WmmaTraitsBase<gfx11_t, ADType, BDType, CDType>
|
||||
template <typename ADType, typename BDType, typename CDType>
|
||||
struct WmmaTraitsBase<gfx12_t, ADType, BDType, CDType>
|
||||
{
|
||||
using ArchType = gfx12_t;
|
||||
|
||||
using ADataType = ADType;
|
||||
using BDataType = BDType;
|
||||
using CDataType = CDType;
|
||||
|
||||
@@ -100,6 +100,7 @@ template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 16, false> { using Ty
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, true> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8_CTransposed; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 32, true> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8_CTransposed; };
|
||||
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; };
|
||||
@@ -113,6 +114,7 @@ template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 32, false> { using Ty
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 32, true> { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 64, true> { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8_CTransposed; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; };
|
||||
|
||||
// scale mfma based f8f6f4
|
||||
|
||||
Reference in New Issue
Block a user