mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 11:30:02 +00:00
Add FP64 XDL GEMM built-in function (#199)
* add intrin_mfma_f64_16x16x4f64
* add example
* gemm reference add double data type
* change init data
* fix M N PerXdlops
* fix ifdef
* add comparison config
* add conv fwd example
* format log out
* change rc matrix register layout
* reorganize example
* reorganize example 2
* format, because merge develop
* fix call impl adding acc data type
* lost ;
* add compiler warning
* change example tuning parameters
* add test for fp64
* add instance
* add test/gemm/gemm_fp64.cpp
* fix get name issue
* remove some tuning parameters
* fix conflict
* format
* use integer value for GEMM test
* add acc data type
* remove typeid because fp16
* fix streamconfig etc bug from merging develop
* format
* remove test_gemm_xdl_fp64
* add AccDataType
* AccDataType problem
Co-authored-by: qinletao <letaoqin@amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
[ROCm/composable_kernel commit: 3e6c2610ae]
This commit is contained in:
@@ -25,6 +25,7 @@ enum struct MfmaInstr
|
||||
mfma_f32_16x16x8bf16,
|
||||
mfma_i32_32x32x8i8,
|
||||
mfma_i32_16x16x16i8,
|
||||
mfma_f64_16x16x4f64
|
||||
};
|
||||
|
||||
template <MfmaInstr instr>
|
||||
@@ -383,12 +384,40 @@ struct mfma_type<MfmaInstr::mfma_i32_16x16x16i8>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
|
||||
{
|
||||
static constexpr index_t group_size = 1;
|
||||
static constexpr index_t num_groups_per_blk = 4;
|
||||
static constexpr index_t num_regs_per_blk = 4; // group_size * num_groups_per_blk;
|
||||
static constexpr index_t num_threads_per_blk = 16;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 4; // wave_size / num_threads_per_blk;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 16;
|
||||
static constexpr index_t n_per_blk = 16;
|
||||
static constexpr index_t k_per_blk = 1;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f64_16x16x4f64<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
|
||||
struct MfmaSelector
|
||||
{
|
||||
template <typename base_type_, index_t MPerXdlops_, index_t NPerXdlops_>
|
||||
static constexpr auto GetMfma();
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<double, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_f64_16x16x4f64;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<float, 64, 64>()
|
||||
{
|
||||
@@ -661,9 +690,10 @@ struct XdlopsGemm
|
||||
template <class FloatA, class FloatB, class FloatC>
|
||||
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
|
||||
{
|
||||
static_assert(is_same<base_type, float>::value || is_same<base_type, half_t>::value ||
|
||||
is_same<base_type, bhalf_t>::value || is_same<base_type, int8_t>::value,
|
||||
"base base_type must be float, half, bfloat16, and int8_t!");
|
||||
static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
|
||||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
|
||||
is_same<base_type, int8_t>::value,
|
||||
"base base_type must be double, float, half, bfloat16, and int8_t!");
|
||||
|
||||
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
|
||||
mfma_instr.template run<MPerXdlops, NPerXdlops>(p_a_wave[k], p_b_wave[k], p_c_thread);
|
||||
|
||||
@@ -294,5 +294,24 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
|
||||
}
|
||||
};
|
||||
|
||||
// Wrapper around the FP64 16x16x4 MFMA builtin.
// The primary template is only declared; a definition exists solely for the
// <16, 16> specialization below, so any other tile size fails to compile when
// Run is used.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f64_16x16x4f64;

template <>
struct intrin_mfma_f64_16x16x4f64<16, 16>
{
    // Computes reg_c = mfma(reg_a, reg_b, reg_c) on the double4_t view of reg_c
    // (element 0). The builtin is only emitted when compiling for gfx90a; for
    // every other target the arguments are discarded and reg_c is untouched.
    template <class FloatC>
    __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c)
    {
#if defined(__gfx90a__)
        // Read the current accumulator value, then overwrite it with the result.
        const auto acc_in = reg_c.template AsType<double4_t>()[Number<0>{}];

        reg_c.template AsType<double4_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_f64_16x16x4f64(reg_a, reg_b, acc_in, 0, 0, 0);
#else
        // Builtin unavailable for this target: consume the arguments so they
        // do not trigger unused-parameter warnings.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user