From 3973caa48c132d5cbe7bbee25143b30666a1bddd Mon Sep 17 00:00:00 2001 From: illsilin Date: Tue, 7 Feb 2023 09:46:58 -0800 Subject: [PATCH] switch between intrinsic mfma routines on mi100/200 and mi300 --- .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 55 +++++++++++++++++++ include/ck/utility/amd_xdlops.hpp | 21 +++++++ 2 files changed, 76 insertions(+) diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 4d53f0d816..dcca5ce37f 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -9,6 +9,7 @@ namespace ck { +#if (defined(__gfx908__) || defined(__gfx90a__)) enum struct MfmaInstr { mfma_f32_32x32x1xf32 = 0, @@ -29,6 +30,28 @@ enum struct MfmaInstr mfma_i32_16x16x16i8, mfma_f64_16x16x4f64 }; +#elif (defined(__gfx940__)) +enum struct MfmaInstr +{ + mfma_f32_32x32x1xf32 = 0, + mfma_f32_16x16x1xf32, + mfma_f32_4x4x1xf32, + mfma_f32_32x32x2xf32, + mfma_f32_16x16x4xf32, + mfma_f32_32x32x4f16, + mfma_f32_16x16x4f16, + mfma_f32_4x4x4f16, + mfma_f32_32x32x8f16, + mfma_f32_16x16x16f16, + mfma_f32_32x32x8bf16_1k, + mfma_f32_16x16x16bf16_1k, + mfma_f32_32x32x4bf16, + mfma_f32_16x16x8bf16, + mfma_i32_32x32x16i8, + mfma_i32_16x16x16i8, + mfma_f64_16x16x4f64 +}; +#endif template struct mfma_type; @@ -342,6 +365,7 @@ struct mfma_type } }; +#if (defined(__gfx908__) || defined(__gfx90a__)) template <> struct mfma_type { @@ -363,6 +387,29 @@ struct mfma_type intrin_mfma_i32_32x32x8i8::Run(a, b, reg_c); } }; +#elif (defined(__gfx940__)) +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_i32_32x32x16i8::Run(a, b, reg_c); + } +}; +#endif template <> struct mfma_type @@ -524,11 +571,19 @@ struct MfmaSelector #endif } +#if (defined(__gfx908__) || defined(__gfx90a__)) template <> static constexpr auto GetMfma() { return MfmaInstr::mfma_i32_32x32x8i8; } +#elif (defined(__gfx940__)) + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_i32_32x32x16i8; + } +#endif template <> static constexpr auto GetMfma() diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index d17d866668..bc9676f1f7 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -259,6 +259,7 @@ struct intrin_mfma_f32_16x16x8bf16<16, 16> } }; +#if (defined(__gfx908__) || defined(__gfx90a__)) template struct intrin_mfma_i32_32x32x8i8; @@ -277,6 +278,26 @@ struct intrin_mfma_i32_32x32x8i8<32, 32> 0); } }; +#elif (defined(__gfx940__)) +template +struct intrin_mfma_i32_32x32x16i8; + +template <> +struct intrin_mfma_i32_32x32x16i8<32, 32> +{ + template + __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_i32_32x32x16_i8(bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; +#endif template struct intrin_mfma_i32_16x16x16i8;