From 3c1b791968def8121e80ab2651a514b4a289c3e3 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Tue, 4 Jul 2023 21:38:49 -0500 Subject: [PATCH] Add fp8 GEMM and an example for it (#767) * Add fp8 xdl gemm * Add example * Use int8 intrinsics for buffer load/store * Format * Update cmakelists [ROCm/composable_kernel commit: 1cf5003179970b3ce9f8252395962810adb50f76] --- example/01_gemm/CMakeLists.txt | 4 ++ example/01_gemm/gemm_xdl_f8.cpp | 38 +++++++++++ .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 62 +++++++++++++++++- include/ck/utility/amd_buffer_addressing.hpp | 57 ++++++++++++++--- include/ck/utility/amd_xdlops.hpp | 63 +++++++++++++++++++ 5 files changed, 212 insertions(+), 12 deletions(-) create mode 100644 example/01_gemm/gemm_xdl_f8.cpp diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index c5a8295188..66afe73c2a 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -44,3 +44,7 @@ if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS add_dependencies(example_gemm_wmma example_gemm_wmma_fp16) endif() +if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942") + add_example_executable(example_gemm_xdl_f8 gemm_xdl_f8.cpp) + add_dependencies(example_gemm_xdl example_gemm_xdl_f8) +endif() diff --git a/example/01_gemm/gemm_xdl_f8.cpp b/example/01_gemm/gemm_xdl_f8.cpp new file mode 100644 index 0000000000..1015926777 --- /dev/null +++ b/example/01_gemm/gemm_xdl_f8.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; +using CDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = ck::f8_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 16>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index faaa2c5a95..814969ef42 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -29,7 +29,9 @@ enum struct MfmaInstr mfma_i32_16x16x16i8, mfma_i32_32x32x16i8, mfma_i32_16x16x32i8, - mfma_f64_16x16x4f64 + mfma_f64_16x16x4f64, + mfma_f32_32x32x16f8f8, + mfma_f32_16x16x32f8f8 }; template @@ -454,6 +456,50 @@ struct mfma_type } }; +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16f8f8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32f8f8::Run(a, b, reg_c); + } +}; + template struct MfmaSelector { @@ -594,6 +640,18 @@ struct MfmaSelector } #endif + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x16f8f8; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x32f8f8; + } + static constexpr auto selected_mfma = mfma_type()>{}; __host__ __device__ constexpr MfmaSelector() @@ -794,7 +852,7 @@ struct XdlopsGemm { static_assert(is_same::value || is_same::value || is_same::value || is_same::value || - is_same::value, + is_same::value || is_same::value, "base base_type must be double, float, half, bfloat16, and int8_t!"); static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 38ee76d883..ea231154ea 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -1114,13 +1114,30 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000; - return amd_buffer_load_impl( - src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); + if constexpr(is_same::value) + { + auto tmp = amd_buffer_load_impl( + src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); + return bit_cast(tmp); + } + else + { + return amd_buffer_load_impl( + src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); + } #else - vector_t tmp = amd_buffer_load_impl( - src_wave_buffer_resource, src_thread_addr_offset, 0); - - return src_thread_element_valid ? tmp : vector_t(0); + if constexpr(is_same::value) + { + auto tmp = amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0); + return src_thread_element_valid ? bit_cast(tmp) : vector_t(0); + } + else + { + vector_t tmp = amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0); + return src_thread_element_valid ? tmp : vector_t(0); + } #endif } @@ -1179,13 +1196,33 @@ __device__ void amd_buffer_store(const typename vector_type_maker::type::t #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; - amd_buffer_store_impl( - src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); + if constexpr(is_same::value) + { + auto tmp = + bit_cast::type::type>(src_thread_data); + amd_buffer_store_impl( + tmp, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); + } + else + { + amd_buffer_store_impl( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); + } #else if(dst_thread_element_valid) { - amd_buffer_store_impl( - src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + if constexpr(is_same::value) + { + auto tmp = bit_cast::type::type>( + src_thread_data); + amd_buffer_store_impl( + tmp, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } + else + { + amd_buffer_store_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } } #endif } diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index d00b7cd078..ca38077cf5 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -354,5 +354,68 @@ struct intrin_mfma_f64_16x16x4f64<16, 16> #endif } }; + +template +struct intrin_mfma_f32_32x32x16f8f8; + +template <> +struct intrin_mfma_f32_32x32x16f8f8<32, 32> +{ + template + __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx940__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template +struct intrin_mfma_f32_16x16x32f8f8; + +template <> +struct intrin_mfma_f32_16x16x32f8f8<16, 16> +{ + template + __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx940__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; } // namespace ck #endif