mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
This reverts commit c51102144f.
This commit is contained in:
@@ -41,11 +41,11 @@ static constexpr bool scale_mfma_hw_support()
|
||||
|
||||
enum struct MfmaInstr
|
||||
{
|
||||
mfma_f32_32x32x1f32 = 0,
|
||||
mfma_f32_16x16x1f32,
|
||||
mfma_f32_4x4x1f32,
|
||||
mfma_f32_32x32x2f32,
|
||||
mfma_f32_16x16x4f32,
|
||||
mfma_f32_32x32x1xf32 = 0,
|
||||
mfma_f32_16x16x1xf32,
|
||||
mfma_f32_4x4x1xf32,
|
||||
mfma_f32_32x32x2xf32,
|
||||
mfma_f32_16x16x4xf32,
|
||||
mfma_f32_32x32x4f16,
|
||||
mfma_f32_16x16x4f16,
|
||||
mfma_f32_4x4x4f16,
|
||||
@@ -78,8 +78,6 @@ enum struct MfmaInstr
|
||||
mfma_f32_16x16x128f8f6f4,
|
||||
mfma_scale_f32_32x32x64f8f6f4,
|
||||
mfma_scale_f32_16x16x128f8f6f4,
|
||||
mfma_f32_16x16x8xf32, // tf32
|
||||
mfma_f32_32x32x4xf32,
|
||||
// gfx11
|
||||
wmma_f32_16x16x16_f16,
|
||||
wmma_f32_16x16x16_bf16,
|
||||
@@ -100,7 +98,7 @@ template <MfmaInstr instr>
|
||||
struct mfma_type;
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_32x32x1f32>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_32x32x1xf32>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 4;
|
||||
@@ -122,7 +120,7 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x1f32>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_32x32x2f32>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_32x32x2xf32>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 4;
|
||||
@@ -144,7 +142,7 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x2f32>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_16x16x4f32>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_16x16x4xf32>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 1;
|
||||
@@ -166,7 +164,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x4f32>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_16x16x1f32>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_16x16x1xf32>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 1;
|
||||
@@ -189,7 +187,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x1f32>
|
||||
|
||||
// treat 4x4x1 as a single-blk 4x64 mfma
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_4x4x1f32>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_4x4x1xf32>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 1;
|
||||
@@ -949,70 +947,6 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* num_threads_per_blk == n_per_blk
|
||||
* num_regs_per_blk * num_input_blks == m_per_blk
|
||||
* num_regs_per_blk * wave_size == m_per_blk * n_per_blk
|
||||
*
|
||||
* group_size * num_groups_per_blk == num_regs_per_blk
|
||||
*
|
||||
* num_regs_per_blk is output(CD) register size which is determined by the instruction.
|
||||
* k_per_blk(K1PerXdlops) is input(AB) register size which is determined by the instruction.
|
||||
* group_size is corresponding to CD rows mapping. see: GetBeginOfThreadBlk()
|
||||
*
|
||||
* is_k_reduction = (k_per_blk == KPerXdlops) ? false: true.
|
||||
*
|
||||
* if (is_k_reduction){
|
||||
* num_output_blks == 1;
|
||||
* } else {
|
||||
* num_input_blks == num_output_blks;
|
||||
* }
|
||||
*/
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_16x16x8xf32>
|
||||
{
|
||||
static constexpr index_t wave_size = 64; // fixed
|
||||
static constexpr index_t m_per_blk = 16; // from the instruction
|
||||
static constexpr index_t n_per_blk = 16; // from the instruction
|
||||
static constexpr index_t num_threads_per_blk = n_per_blk; // 16
|
||||
static constexpr index_t num_regs_per_blk = m_per_blk * n_per_blk / wave_size; // 4
|
||||
static constexpr index_t num_input_blks = m_per_blk / num_regs_per_blk; // 4
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 1;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t k_per_blk = 2; // k_per_blk(K1PerXdlops) should be 2.
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
// AB register size : 2, register size: 4
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_16x16x8xf32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_32x32x4xf32>
|
||||
{
|
||||
static constexpr index_t wave_size = 64; // fixed
|
||||
static constexpr index_t m_per_blk = 32; // from the instruction
|
||||
static constexpr index_t n_per_blk = 32; // from the instruction
|
||||
static constexpr index_t num_threads_per_blk = n_per_blk; // 32
|
||||
static constexpr index_t num_regs_per_blk = m_per_blk * n_per_blk / wave_size; // 16
|
||||
static constexpr index_t num_input_blks = m_per_blk / num_regs_per_blk; // 2
|
||||
static constexpr index_t group_size = 4; // corresponding to CD rows mapping
|
||||
static constexpr index_t num_groups_per_blk = 4;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t k_per_blk = 2;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
// AB register size: 2, CD register size: 16
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_32x32x4xf32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
// gfx11
|
||||
struct mfma_type_gfx11_base
|
||||
{
|
||||
@@ -1182,20 +1116,6 @@ struct mfma_type<MfmaInstr::wmma_unsupport_16x16_gfx12> : public mfma_type_gfx12
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @class MfmaSelector
|
||||
* @brief Selects the appropriate MFMA instruction type and configuration for given data types
|
||||
* and tile sizes on AMD GPUs.
|
||||
*
|
||||
* @tparam base_type The base data type for the matrix operation (e.g., float, half_t).
|
||||
* @tparam MPerXdlops The number of rows per XDLops tile.
|
||||
* @tparam NPerXdlops The number of columns per XDLops tile.
|
||||
* @tparam additional_type (Optional) Additional data type for mixed-precision or special cases.
|
||||
* Defaults to base_type.
|
||||
* @tparam is_single_rate_mfma (Optional) Whether to use single-rate MFMA instructions.
|
||||
* Defaults to false.
|
||||
* @tparam is_scale_mfma (Optional) Whether to use scale MFMA instructions. Defaults to false.
|
||||
*/
|
||||
template <typename base_type,
|
||||
index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
@@ -1227,37 +1147,37 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<float, 64, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x1f32;
|
||||
return MfmaInstr::mfma_f32_32x32x1xf32;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<float, 32, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x1f32;
|
||||
return MfmaInstr::mfma_f32_32x32x1xf32;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<float, 16, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x1f32;
|
||||
return MfmaInstr::mfma_f32_16x16x1xf32;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<float, 8, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_4x4x1f32;
|
||||
return MfmaInstr::mfma_f32_4x4x1xf32;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<float, 4, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_4x4x1f32;
|
||||
return MfmaInstr::mfma_f32_4x4x1xf32;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<float, 32, 32>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x2f32;
|
||||
return MfmaInstr::mfma_f32_32x32x2xf32;
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -1268,22 +1188,10 @@ struct MfmaSelector
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x4f32;
|
||||
return MfmaInstr::mfma_f32_16x16x4xf32;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<tf32_t, 32, 32>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x4xf32;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<tf32_t, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x8xf32;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<half_t, 64, 64>()
|
||||
{
|
||||
@@ -1988,7 +1896,7 @@ struct XdlopsGemm
|
||||
|
||||
__device__ __host__ static constexpr index_t GetRegSizePerXdlops()
|
||||
{
|
||||
return mfma_instr.num_regs_per_blk;
|
||||
return MPerXdlops * NPerXdlops / mfma_instr.wave_size;
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetWaveSize() { return mfma_instr.wave_size; }
|
||||
@@ -1998,12 +1906,12 @@ struct XdlopsGemm
|
||||
{
|
||||
static_assert(
|
||||
is_same<base_type, double>::value || is_same<base_type, float>::value ||
|
||||
is_same<base_type, tf32_t>::value || is_same<base_type, half_t>::value ||
|
||||
is_same<base_type, bhalf_t>::value || is_same<base_type, int8_t>::value ||
|
||||
is_same<base_type, f8_t>::value || is_same<base_type, bf8_t>::value ||
|
||||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
|
||||
is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value ||
|
||||
is_same<base_type, bf8_t>::value ||
|
||||
(is_same<base_type, f8_t>::value && is_same<additional_type, bf8_t>::value) ||
|
||||
(is_same<base_type, bf8_t>::value && is_same<additional_type, f8_t>::value),
|
||||
"base_type must be double, float, tf32_t, half, bfloat16, int8_t, f8_t or bf8_t!");
|
||||
"base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!");
|
||||
|
||||
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
|
||||
if constexpr(!TransposeC)
|
||||
|
||||
Reference in New Issue
Block a user