mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 20:21:23 +00:00
Extend XDL kernel to Support RDNA3/4 - Part 1 (#2606)
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/utility/math.hpp"
|
||||
#include "ck/utility/amd_xdlops.hpp"
|
||||
#include "ck/utility/amd_wmma.hpp"
|
||||
|
||||
namespace ck {
|
||||
/**
|
||||
@@ -76,7 +77,21 @@ enum struct MfmaInstr
|
||||
mfma_f32_32x32x64f8f6f4,
|
||||
mfma_f32_16x16x128f8f6f4,
|
||||
mfma_scale_f32_32x32x64f8f6f4,
|
||||
mfma_scale_f32_16x16x128f8f6f4
|
||||
mfma_scale_f32_16x16x128f8f6f4,
|
||||
// gfx11
|
||||
wmma_f32_16x16x16_f16,
|
||||
wmma_f32_16x16x16_bf16,
|
||||
wmma_i32_16x16x16_iu8,
|
||||
wmma_unsupport_16x16_gfx11,
|
||||
// gfx12
|
||||
wmma_f32_16x16x16_f16_gfx12,
|
||||
wmma_f32_16x16x16_bf16_gfx12,
|
||||
wmma_i32_16x16x16_iu8_gfx12,
|
||||
wmma_f32_16x16x16_f8f8_gfx12,
|
||||
wmma_f32_16x16x16_f8bf8_gfx12,
|
||||
wmma_f32_16x16x16_bf8f8_gfx12,
|
||||
wmma_f32_16x16x16_bf8bf8_gfx12,
|
||||
wmma_unsupport_16x16_gfx12,
|
||||
};
|
||||
|
||||
template <MfmaInstr instr>
|
||||
@@ -932,6 +947,175 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
|
||||
}
|
||||
};
|
||||
|
||||
// gfx11
|
||||
struct mfma_type_gfx11_base
|
||||
{
|
||||
static constexpr index_t group_size = 8;
|
||||
static constexpr index_t num_groups_per_blk = 1;
|
||||
static constexpr index_t num_regs_per_blk = 8;
|
||||
static constexpr index_t num_threads_per_blk = 16;
|
||||
static constexpr index_t wave_size = 32;
|
||||
static constexpr index_t num_input_blks = 1;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 16;
|
||||
static constexpr index_t n_per_blk = 16;
|
||||
static constexpr index_t k_per_blk = 16;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::wmma_f32_16x16x16_f16> : public mfma_type_gfx11_base
|
||||
{
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_wmma_f32_16x16x16_f16_w32<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::wmma_f32_16x16x16_bf16> : public mfma_type_gfx11_base
|
||||
{
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_wmma_f32_16x16x16_bf16_w32<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::wmma_i32_16x16x16_iu8> : public mfma_type_gfx11_base
|
||||
{
|
||||
template <index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC,
|
||||
bool neg_a = true,
|
||||
bool neg_b = true,
|
||||
bool clamp = false>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_wmma_i32_16x16x16_iu8_w32<MPerWmma, NPerWmma, neg_a, neg_b, clamp>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::wmma_unsupport_16x16_gfx11> : public mfma_type_gfx11_base
|
||||
{
|
||||
static constexpr index_t k_per_blk = 2;
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA&, const FloatB&, FloatC&) const
|
||||
{
|
||||
// empty for all unsupported types.
|
||||
}
|
||||
};
|
||||
|
||||
// gfx12
|
||||
struct mfma_type_gfx12_base
|
||||
{
|
||||
static constexpr index_t group_size = 8;
|
||||
static constexpr index_t num_groups_per_blk = 1;
|
||||
static constexpr index_t num_regs_per_blk = 8;
|
||||
static constexpr index_t num_threads_per_blk = 16;
|
||||
static constexpr index_t wave_size = 32;
|
||||
static constexpr index_t num_input_blks = 2;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 16;
|
||||
static constexpr index_t n_per_blk = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::wmma_f32_16x16x16_f16_gfx12> : public mfma_type_gfx12_base
|
||||
{
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_wmma_f32_16x16x16_f16_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::wmma_f32_16x16x16_bf16_gfx12> : public mfma_type_gfx12_base
|
||||
{
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_wmma_f32_16x16x16_bf16_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::wmma_i32_16x16x16_iu8_gfx12> : public mfma_type_gfx12_base
|
||||
{
|
||||
template <index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC,
|
||||
bool neg_a = true,
|
||||
bool neg_b = true,
|
||||
bool clamp = false>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_wmma_i32_16x16x16_iu8_w32_gfx12<MPerWmma, NPerWmma, neg_a, neg_b, clamp>::Run(
|
||||
a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::wmma_f32_16x16x16_f8f8_gfx12> : public mfma_type_gfx12_base
|
||||
{
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_wmma_f32_16x16x16_f8f8_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::wmma_f32_16x16x16_f8bf8_gfx12> : public mfma_type_gfx12_base
|
||||
{
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_wmma_f32_16x16x16_f8bf8_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::wmma_f32_16x16x16_bf8f8_gfx12> : public mfma_type_gfx12_base
|
||||
{
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_wmma_f32_16x16x16_bf8f8_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::wmma_f32_16x16x16_bf8bf8_gfx12> : public mfma_type_gfx12_base
|
||||
{
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_wmma_f32_16x16x16_bf8bf8_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::wmma_unsupport_16x16_gfx12> : public mfma_type_gfx12_base
|
||||
{
|
||||
static constexpr index_t k_per_blk = 2;
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA&, const FloatB&, FloatC&) const
|
||||
{
|
||||
// empty for all unsupported types.
|
||||
}
|
||||
};
|
||||
|
||||
template <typename base_type,
|
||||
index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
@@ -951,7 +1135,13 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<double, 16, 16>()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#else
|
||||
return MfmaInstr::mfma_f64_16x16x4f64;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -993,7 +1183,13 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<float, 16, 16>()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x4xf32;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -1026,7 +1222,11 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<half_t, 16, 16, half_t, false>()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_f16_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_f16;
|
||||
#elif defined(__gfx950__)
|
||||
return MfmaInstr::mfma_f32_16x16x32f16;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x16f16;
|
||||
@@ -1036,7 +1236,13 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<half_t, 16, 16, half_t, true>()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_f16_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_f16;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x16f16;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -1082,7 +1288,11 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, false>()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_bf16_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_bf16;
|
||||
#elif defined(__gfx950__)
|
||||
return MfmaInstr::mfma_f32_16x16x32bf16;
|
||||
#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
|
||||
return MfmaInstr::mfma_f32_16x16x16bf16_1k;
|
||||
@@ -1094,7 +1304,11 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, true>()
|
||||
{
|
||||
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_bf16_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_bf16;
|
||||
#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
|
||||
return MfmaInstr::mfma_f32_16x16x16bf16_1k;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x8bf16;
|
||||
@@ -1126,7 +1340,11 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<int8_t, 16, 16, int8_t, false>()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_i32_16x16x16_iu8_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_i32_16x16x16_iu8;
|
||||
#elif defined(__gfx950__)
|
||||
return MfmaInstr::mfma_i32_16x16x64i8;
|
||||
#elif defined(__gfx942__)
|
||||
return MfmaInstr::mfma_i32_16x16x32i8;
|
||||
@@ -1138,7 +1356,11 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<int8_t, 16, 16, int8_t, true>()
|
||||
{
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_i32_16x16x16_iu8_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_i32_16x16x16_iu8;
|
||||
#elif defined(__gfx942__) || defined(__gfx950__)
|
||||
return MfmaInstr::mfma_i32_16x16x32i8;
|
||||
#else
|
||||
return MfmaInstr::mfma_i32_16x16x16i8;
|
||||
@@ -1186,13 +1408,23 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 16, 16, f8_t, true, false>()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_f8f8_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x32f8f8;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 16, 16, f8_t, false, false>()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_f8f8_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#elif defined(__gfx950__)
|
||||
return MfmaInstr::mfma_f32_16x16x128f8f6f4;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x32f8f8;
|
||||
@@ -1263,13 +1495,23 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 16, 16, bf8_t, true, false>()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_bf8bf8_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x32bf8bf8;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 16, 16, bf8_t, false, false>()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_bf8bf8_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#elif defined(__gfx950__)
|
||||
return MfmaInstr::mfma_f32_16x16x128f8f6f4;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x32bf8bf8;
|
||||
@@ -1295,13 +1537,23 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 16, 16, bf8_t, true, false>()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_f8bf8_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x32f8bf8;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 16, 16, bf8_t, false, false>()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_f8bf8_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#elif defined(__gfx950__)
|
||||
return MfmaInstr::mfma_f32_16x16x128f8f6f4;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x32f8bf8;
|
||||
@@ -1327,13 +1579,23 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 16, 16, f8_t, true, false>()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_bf8f8_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x32bf8f8;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 16, 16, f8_t, false, false>()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_f32_16x16x16_bf8f8_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#elif defined(__gfx950__)
|
||||
return MfmaInstr::mfma_f32_16x16x128f8f6f4;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x32bf8f8;
|
||||
@@ -1355,10 +1617,18 @@ struct MfmaSelector
|
||||
|
||||
static_assert(selected_mfma.num_threads_per_blk == selected_mfma.n_per_blk,
|
||||
"n_per_blk != num_threads_per_blk");
|
||||
|
||||
#if defined(__gfx11__)
|
||||
if constexpr(MPerXdlops == 16 && NPerXdlops == 16)
|
||||
{
|
||||
static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks * 2 ==
|
||||
selected_mfma.m_per_blk,
|
||||
"m_per_blk != num_input_blks * num_regs_per_blk");
|
||||
}
|
||||
#else
|
||||
static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks ==
|
||||
selected_mfma.m_per_blk,
|
||||
"m_per_blk != num_input_blks * num_regs_per_blk");
|
||||
#endif
|
||||
|
||||
static_assert(selected_mfma.num_output_blks == selected_mfma.num_input_blks ||
|
||||
selected_mfma.num_output_blks == 1,
|
||||
@@ -1424,8 +1694,9 @@ struct XdlopsGemm
|
||||
static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 ||
|
||||
MPerXdlops == 64,
|
||||
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
|
||||
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack should be a multiple of k_per_blk");
|
||||
#endif
|
||||
}
|
||||
|
||||
// XDL output supporting C = A * B
|
||||
@@ -1434,10 +1705,11 @@ struct XdlopsGemm
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
|
||||
{
|
||||
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
|
||||
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
|
||||
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
|
||||
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
|
||||
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
|
||||
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
|
||||
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
|
||||
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
|
||||
constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
c_desc_m0_n0_m1_n1_m2_n2,
|
||||
@@ -1446,7 +1718,7 @@ struct XdlopsGemm
|
||||
make_pass_through_transform(M1),
|
||||
make_pass_through_transform(N1),
|
||||
make_unmerge_transform(make_tuple(Number<mfma_instr.num_groups_per_blk>{},
|
||||
Number<mfma_instr.num_input_blks>{},
|
||||
Number<num_blks>{},
|
||||
Number<mfma_instr.group_size>{})),
|
||||
make_pass_through_transform(Number<mfma_instr.num_threads_per_blk>{})),
|
||||
make_tuple(Sequence<0>{},
|
||||
@@ -1469,12 +1741,13 @@ struct XdlopsGemm
|
||||
__host__ __device__ static constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
|
||||
const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
|
||||
{
|
||||
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
|
||||
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
|
||||
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
|
||||
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
|
||||
const auto M2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
|
||||
const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
|
||||
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
|
||||
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
|
||||
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
|
||||
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
|
||||
const auto M2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
|
||||
const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
|
||||
constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
c_desc_m0_n0_m1_n1_m2_n2,
|
||||
@@ -1485,7 +1758,7 @@ struct XdlopsGemm
|
||||
make_pass_through_transform(M2),
|
||||
make_pass_through_transform(N2),
|
||||
make_unmerge_transform(make_tuple(Number<mfma_instr.num_groups_per_blk>{},
|
||||
Number<mfma_instr.num_input_blks>{},
|
||||
Number<num_blks>{},
|
||||
Number<mfma_instr.group_size>{})),
|
||||
make_pass_through_transform(Number<mfma_instr.num_threads_per_blk>{})),
|
||||
make_tuple(Sequence<0>{},
|
||||
@@ -1512,10 +1785,11 @@ struct XdlopsGemm
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
|
||||
{
|
||||
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
|
||||
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
|
||||
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
|
||||
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
|
||||
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
|
||||
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
|
||||
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
|
||||
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
|
||||
constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
c_desc_m0_n0_m1_n1_m2_n2,
|
||||
@@ -1525,7 +1799,7 @@ struct XdlopsGemm
|
||||
make_pass_through_transform(N1),
|
||||
make_pass_through_transform(Number<mfma_instr.num_threads_per_blk>{}),
|
||||
make_unmerge_transform(make_tuple(Number<mfma_instr.num_groups_per_blk>{},
|
||||
Number<mfma_instr.num_input_blks>{},
|
||||
Number<num_blks>{},
|
||||
Number<mfma_instr.group_size>{}))),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
@@ -1545,11 +1819,12 @@ struct XdlopsGemm
|
||||
__host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
|
||||
const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
|
||||
{
|
||||
const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0);
|
||||
const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1);
|
||||
const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2);
|
||||
const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3);
|
||||
const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4);
|
||||
const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0);
|
||||
const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1);
|
||||
const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2);
|
||||
const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3);
|
||||
const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4);
|
||||
constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
c_desc_g_m0_n0_m1_n1_m2_n2,
|
||||
@@ -1558,9 +1833,8 @@ struct XdlopsGemm
|
||||
make_pass_through_transform(N0),
|
||||
make_pass_through_transform(M1),
|
||||
make_pass_through_transform(N1),
|
||||
make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk,
|
||||
mfma_instr.num_input_blks,
|
||||
mfma_instr.group_size)),
|
||||
make_unmerge_transform(make_tuple(
|
||||
mfma_instr.num_groups_per_blk, num_blks, mfma_instr.group_size)),
|
||||
make_pass_through_transform(mfma_instr.num_threads_per_blk)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
@@ -1642,8 +1916,32 @@ struct XdlopsGemm
|
||||
|
||||
// Splits the calling lane's id into a (blk_id, blk_td) pair.
//
// The lane id from GetLaneId() is fed through a single merge transform with
// lengths (1, num_blks, num_threads_per_blk), where
// num_blks = m_per_blk / num_regs_per_blk. Components 1 and 2 of the resulting
// bottom index are returned as a tuple (blk_id, blk_td).
__device__ static auto GetBlkIdx()
{
    // Declared exactly once; a second declaration in this scope is a redeclaration error.
    const auto laneId = GetLaneId();

    constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;

    constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
        make_tuple(
            make_merge_transform(make_tuple(1, num_blks, mfma_instr.num_threads_per_blk))),
        make_tuple(Sequence<0, 1, 2>{}),
        make_tuple(Sequence<0>{}));

    const auto blk_idx =
        threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId));

    // Component 0 has length 1 and is not returned.
    const auto blk_id = blk_idx[I1];
    const auto blk_td = blk_idx[I2];

    return make_tuple(blk_id, blk_td);
}
|
||||
|
||||
template <bool SwizzleA>
|
||||
__device__ static auto GetGfx11InputBlkIdx()
|
||||
{
|
||||
const auto laneId = GetLaneId() % mfma_instr.num_threads_per_blk;
|
||||
if constexpr(SwizzleA)
|
||||
{
|
||||
laneId = ((laneId & 1) << 3) | (laneId >> 1);
|
||||
}
|
||||
constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(
|
||||
make_tuple(1, mfma_instr.num_input_blks, mfma_instr.num_threads_per_blk))),
|
||||
@@ -1661,8 +1959,12 @@ struct XdlopsGemm
|
||||
|
||||
__host__ __device__ static auto CalculateAThreadOriginDataIndex()
|
||||
{
|
||||
const auto laneId = GetLaneId();
|
||||
const auto laneId = GetLaneId();
|
||||
#if defined(__gfx11__)
|
||||
const auto blk_idx = GetGfx11InputBlkIdx<true>();
|
||||
#else
|
||||
const auto blk_idx = GetBlkIdx();
|
||||
#endif
|
||||
|
||||
const auto blk_id = blk_idx[I0];
|
||||
const auto blk_td = blk_idx[I1];
|
||||
@@ -1679,8 +1981,12 @@ struct XdlopsGemm
|
||||
|
||||
__host__ __device__ static auto CalculateBThreadOriginDataIndex()
|
||||
{
|
||||
const auto laneId = GetLaneId();
|
||||
const auto laneId = GetLaneId();
|
||||
#if defined(__gfx11__)
|
||||
const auto blk_idx = GetGfx11InputBlkIdx<false>();
|
||||
#else
|
||||
const auto blk_idx = GetBlkIdx();
|
||||
#endif
|
||||
|
||||
const auto blk_id = blk_idx[I0];
|
||||
const auto blk_td = blk_idx[I1];
|
||||
|
||||
Reference in New Issue
Block a user