From abf9cc6c5cdb9b3b6f2e8ed25feebbacab59d198 Mon Sep 17 00:00:00 2001 From: Haocong WANG Date: Sat, 3 Dec 2022 01:41:13 +0800 Subject: [PATCH] [Navi3x-LWPCK-449] wmma_op + unit test (#484) * wmma_op + unit test * add arch limitation to wmma test * change arch limitation * Refactor + Add all type unit test(int4 compile failed) * Add f32_16x16x16_bf16 unit test * Remove int4 related * delete deprecated test Co-authored-by: Po Yen Chen Co-authored-by: Chao Liu --- include/ck/ck.hpp | 11 +- include/ck/utility/amd_wmma.hpp | 102 +++++++++ test/CMakeLists.txt | 3 + test/wmma_op/CMakeLists.txt | 2 + test/wmma_op/wmma_op.cpp | 67 ++++++ test/wmma_op/wmma_op_util.hpp | 369 ++++++++++++++++++++++++++++++++ 6 files changed, 553 insertions(+), 1 deletion(-) create mode 100644 include/ck/utility/amd_wmma.hpp create mode 100644 test/wmma_op/CMakeLists.txt create mode 100644 test/wmma_op/wmma_op.cpp create mode 100644 test/wmma_op/wmma_op_util.hpp diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index ddaef1db3b..4be2e85d50 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -25,7 +25,7 @@ // check GPU target #ifdef __HIP_DEVICE_COMPILE__ #if !(defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx1030__)) + defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__)) #error Not supported target #endif #endif @@ -38,6 +38,8 @@ #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 #elif defined(__gfx1030__) // for GPU code #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 +#elif defined(__gfx1100__) // for GPU code +#define CK_BUFFER_RESOURCE_3RD_DWORD 0x10020000 #endif // FMA instruction @@ -62,6 +64,13 @@ #define CK_USE_AMD_MFMA_BF16_1K_OP #endif +// WMMA instruction +#ifndef __HIP_DEVICE_COMPILE__ // for host code +#define CK_USE_AMD_WMMA +#elif defined(__gfx1100__) // for GPU code +#define CK_USE_AMD_WMMA +#endif + // buffer load #define CK_USE_AMD_BUFFER_LOAD 1 diff --git 
a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp new file mode 100644 index 0000000000..752876a769 --- /dev/null +++ b/include/ck/utility/amd_wmma.hpp @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_AMD_WMMA_HPP +#define CK_AMD_WMMA_HPP + +#include "data_type.hpp" +// TODO: Add arch limitation /* NOTE(review): in this rendering of the patch several template parameter and argument lists appear to be missing (bare "template", "AsType()"); confirm the exact signatures against the upstream header before relying on them. */ +namespace ck { + +// wave32 only +// src: fp16, dst: fp32 /* Run() passes reg_a, reg_b and element Number<0> of reg_c to __builtin_amdgcn_wmma_f32_16x16x16_f16_w32 and stores the builtin's result back into element Number<0> of reg_c. Only the <16, 16> specialization is defined here. */ +template +struct intrin_wmma_f32_16x16x16_f16_w32; + +template <> +struct intrin_wmma_f32_16x16x16_f16_w32<16, 16> +{ + template + __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); + } +}; + +// src: bf16, dst: fp32 /* Same in-place accumulate pattern as the fp16 wrapper, but takes bhalf16_t operands and calls __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32. */ +template +struct intrin_wmma_f32_16x16x16_bf16_w32; + +template <> +struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16> +{ + template + __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); + } +}; + +// src: fp16, dst: fp16 /* Partial specialization on <16, 16, Opsel>: Opsel is forwarded unchanged as the last argument of __builtin_amdgcn_wmma_f16_16x16x16_f16_w32. */ +template +struct intrin_wmma_f16_16x16x16_f16_w32; + +template +struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel> +{ + template + __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) + { + // opsel usage + // false: D0.[0:15] = result + // true : D0.[16:31]= result + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); + } +}; + +// src: bf16, dst: bf16 /* Partial specialization on <16, 16, Opsel>: Opsel is forwarded unchanged as the last argument of __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32. */ +template +struct intrin_wmma_bf16_16x16x16_bf16_w32; + +template +struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel> +{ + template + __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) + { + 
// opsel usage + // false: D0.[0:15] = result + // true : D0.[16:31]= result + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); + } +}; + +// src: iu8, dst: i32 /* Run() forwards the neg_a, neg_b and clamp template arguments to __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32 in the order (neg_a, a, neg_b, b, c, clamp); reg_a and reg_b go through bit_cast first (the target type is not visible in this rendering). */ +template +struct intrin_wmma_i32_16x16x16_iu8_w32; + +template +struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp> +{ + template + __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( + neg_a, + bit_cast(reg_a), + neg_b, + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + clamp); + } +}; + +} // namespace ck +#endif diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index a8347d9e38..b2e25e4ca7 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -55,3 +55,6 @@ add_subdirectory(normalization) add_subdirectory(data_type) add_subdirectory(elementwise_normalization) add_subdirectory(batchnorm) +if(GPU_TARGETS MATCHES "gfx1100") + add_subdirectory(wmma_op) +endif() diff --git a/test/wmma_op/CMakeLists.txt b/test/wmma_op/CMakeLists.txt new file mode 100644 index 0000000000..e553253c62 --- /dev/null +++ b/test/wmma_op/CMakeLists.txt @@ -0,0 +1,2 @@ +add_test_executable(test_wmma_op wmma_op.cpp) +target_link_libraries(test_wmma_op PRIVATE utility) diff --git a/test/wmma_op/wmma_op.cpp b/test/wmma_op/wmma_op.cpp new file mode 100644 index 0000000000..761c15f1dd --- /dev/null +++ b/test/wmma_op/wmma_op.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "test/wmma_op/wmma_op_util.hpp" + /* run_test: builds a tuple holding the two kernels ck::wmma_op_util::matmul and ck::wmma_op_util::matmul_swizzle_a, runs ck::wmma_op_util::TestWmma on each of them through ck::static_for<0, 2, 1>, and ANDs every result into pass. The layouts handed to TestWmma are Row, Col, Row and all three element-wise operations are PassThrough. NOTE(review): the template parameter list and the kernels' template arguments are not visible in this rendering. */ +template +bool run_test() +{ + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + bool pass = true; + + const auto matmul_default = ck::wmma_op_util::matmul; + const auto matmul_swizzle_a = + ck::wmma_op_util::matmul_swizzle_a; + + const auto wmma_kernel_container = std::make_tuple(matmul_default, matmul_swizzle_a); + + ck::static_for<0, 2, 1>{}([&](auto i) { + pass &= + ck::wmma_op_util::TestWmma{}>(wmma_kernel_container)), + SrcType, + SrcType, + DstType, + GPUAccType, + CPUAccType, + decltype(Row{}), + decltype(Col{}), + decltype(Row{}), + PassThrough, + PassThrough, + PassThrough, + AccNum>{}(std::get{}>(wmma_kernel_container)); + }); + /* pass is already a bool and the return type is bool, so the expression below yields the same value as returning pass directly. */ + return pass ? 1 : 0; +} /* main: calls run_test five times (the template arguments of each call are not visible in this rendering), prints SUCCESS or FAILURE after the literal label "TestGemm ..... ", and exits with 0 when every run passed and 1 otherwise. */ +int main(int, char*[]) +{ + bool pass = true; + // clang-format off + // |SrcType |DstType |GPUAccType |CPUAccType |AccNum + pass &= run_test(); + pass &= run_test(); + pass &= run_test(); + pass &= run_test(); + pass &= run_test(); + // clang-format on + + std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl; + return pass ? 0 : 1; +} diff --git a/test/wmma_op/wmma_op_util.hpp b/test/wmma_op/wmma_op_util.hpp new file mode 100644 index 0000000000..ef3f831abd --- /dev/null +++ b/test/wmma_op/wmma_op_util.hpp @@ -0,0 +1,369 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/utility/amd_wmma.hpp" + +namespace ck { +namespace wmma_op_util { + /* builtin_wmma_naive_selector: the primary template has an empty body, so any (src_vec, acc_vec) pair without a specialization below does nothing. Each specialization forwards reg_a, reg_b and reg_c.GetVectorTypeReference(Number<0>{}) to the matching intrin_wmma_* wrapper. NOTE(review): the template arguments of StaticBufferTupleOfVector are not visible in this rendering. */ +template +__device__ void builtin_wmma_naive_selector(const src_vec&, const src_vec&, acc_vec&) +{ +} + +template <> +__device__ void +builtin_wmma_naive_selector>( + const half16_t& reg_a, + const half16_t& reg_b, + StaticBufferTupleOfVector& reg_c) +{ + intrin_wmma_f32_16x16x16_f16_w32<16, 16>::Run( + reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{})); +} + +template <> +__device__ void +builtin_wmma_naive_selector>( + const bhalf16_t& reg_a, + const bhalf16_t& reg_b, + StaticBufferTupleOfVector& reg_c) +{ + intrin_wmma_f32_16x16x16_bf16_w32<16, 16>::Run( + reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{})); +} + /* In the two specializations that follow, the third template argument 0 is the wrapper's Opsel value. */ +template <> +__device__ void +builtin_wmma_naive_selector>( + const half16_t& reg_a, + const half16_t& reg_b, + StaticBufferTupleOfVector& reg_c) +{ + intrin_wmma_f16_16x16x16_f16_w32<16, 16, 0>::Run( + reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{})); +} + +template <> +__device__ void builtin_wmma_naive_selector< + bhalf16_t, + StaticBufferTupleOfVector>( + const bhalf16_t& reg_a, + const bhalf16_t& reg_b, + StaticBufferTupleOfVector& reg_c) +{ + intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, 0>::Run( + reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{})); +} + /* For the iu8 wrapper, true, true, false are its neg_a, neg_b and clamp arguments. */ +template <> +__device__ void +builtin_wmma_naive_selector>( + const int8x16_t& reg_a, + const int8x16_t& reg_b, + StaticBufferTupleOfVector& reg_c) +{ + intrin_wmma_i32_16x16x16_iu8_w32<16, 16, true, true, false>::Run( + reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{})); +} + +#ifdef 
CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 /* NOTE(review): intrin_wmma_i32_16x16x16_iu4_w32 is not among the wrappers this patch adds to amd_wmma.hpp; confirm it is defined elsewhere before enabling this macro. */ +template <> +__device__ void +builtin_wmma_naive_selector>( + const int4x16_t& reg_a, + const int4x16_t& reg_b, + StaticBufferTupleOfVector& reg_c) +{ + intrin_wmma_i32_16x16x16_iu4_w32<16, 16, true, true, false>::Run( + reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{})); +} +#endif + /* matmul kernel: each thread derives lane = threadIdx.x % 16, copies 16 consecutive elements starting at index 16 * lane from both a and b into a_frag and b_frag, calls builtin_wmma_naive_selector between two __syncthreads() calls, then writes 8 accumulator elements to c[16 * r + lane] with r = ele * 2 + threadIdx.x / 16. */ +template +__global__ void matmul(const src_t* a, const src_t* b, dst_t* c) +{ + const int lIdx = threadIdx.x; + // a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and + // b a_frag will store one column of the 16x16 matrix tile b_frag will store one row of the + // 16x16 matrix tile + using src_vec = typename vector_type::type; + src_vec a_frag = {}; + src_vec b_frag = {}; + // initialize c fragment to 0 + using acc_vec = StaticBufferTupleOfVector; + acc_vec c_thread_buf_; /* NOTE(review): c_thread_buf_ has no explicit initializer here; the zeroing mentioned above presumably relies on the type's default constructor - confirm. */ + + // lane is (0-31) mod 16 instead of 0-31 due to matrix replication in gfx11 + // see https://atlvsp3.amd.com/sp3_gfx11_5_instructions.pdf page 482 + // TODO: remove this dependency in gfx12 https://ontrack-internal.amd.com/browse/DEGFXSP3-101 + const int lane = lIdx % 16; + + for(int ele = 0; ele < 16; ++ele) + { + b_frag[ele] = b[16 * lane + ele]; + } + // follow origin design + for(int ele = 0; ele < 16; ++ele) + { + a_frag[ele] = a[16 * lane + ele]; + } + + // sync threads, similar to mma_sync + __syncthreads(); + builtin_wmma_naive_selector(a_frag, b_frag, c_thread_buf_); + __syncthreads(); + // wait for results, similar to mma_sync + static_for<0, 8, 1>{}([&](auto ele) { + const int r = ele * 2 + (lIdx / 16); + // store results from unpacked c_thread_buf_ output + c[16 * r + lane] = ck::type_convert(c_thread_buf_[Number{}]); + }); +} + /* matmul_swizzle_a kernel: same flow as matmul, except that a is read starting at 16 * offset_m with offset_m = ((lane & 1) << 3) | (lane >> 1), and results are written to c[16 * 8 * blk + 16 * r + lane] with blk = threadIdx.x / 16 and r = ele. */ +template +__global__ void matmul_swizzle_a(const src_t* a, const src_t* b, dst_t* c) +{ + const int lIdx = threadIdx.x; + + using src_vec = typename vector_type::type; + src_vec a_frag = {}; + src_vec b_frag = {}; + using acc_vec = StaticBufferTupleOfVector; + acc_vec c_thread_buf_; + + 
const int lane = lIdx % 16; + + for(int ele = 0; ele < 16; ++ele) + { + b_frag[ele] = b[16 * lane + ele]; + } + + const int offset_m = (((lane & 1) << 3) | (lane >> 1)); + for(int ele = 0; ele < 16; ++ele) + { + a_frag[ele] = a[16 * offset_m + ele]; + } + + __syncthreads(); + builtin_wmma_naive_selector(a_frag, b_frag, c_thread_buf_); + __syncthreads(); + + static_for<0, 8, 1>{}([&](auto ele) { + const int blk = lIdx / 16; + const int r = ele; + c[16 * 8 * blk + 16 * r + lane] = + ck::type_convert(c_thread_buf_[Number{}]); + }); +} + /* GemmParams: problem sizes and strides, all defaulting to 16, plus alpha = 1 and beta = 0. alpha and beta are not read by any code visible in this file. */ +struct GemmParams +{ + GemmParams() : M(16), N(16), K(16), StrideA(16), StrideB(16), StrideC(16), alpha(1), beta(0) {} + + ck::index_t M; + ck::index_t N; + ck::index_t K; + + ck::index_t StrideA; + ck::index_t StrideB; + ck::index_t StrideC; + + float alpha; + float beta; +}; + /* RunHostGEMM: default-constructs GemmInstance, builds its invoker and its argument from A, B, C and the three element-wise operations, and runs the invoker. C is the only non-const tensor handed to MakeArgument. */ +template +void RunHostGEMM(const Tensor& A, + const Tensor& B, + Tensor& C, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) +{ + auto ref_gemm = GemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + auto ref_argument = ref_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); +} + /* RunDeviceGEMM: allocates one DeviceMem per tensor sized from GetElementSpaceSize(), copies A and B to the device, launches kernel with <<<1, 32>>>, copies the C buffer back into C.mData, and always returns true. */ +template +bool RunDeviceGEMM(KernelType kernel, + const Tensor& A, + const Tensor& B, + Tensor& C) +{ + DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize()); + DeviceMem b_n_k_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(A.mData.data()); + b_n_k_device_buf.ToDevice(B.mData.data()); + kernel<<<1, 32>>>(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_n_k_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer())); + c_m_n_device_buf.FromDevice(C.mData.data()); + + return true; +} + /* TestWmma: functor that prepares host tensors, computes a host reference through RunHostGEMM, runs the given device kernel through RunDeviceGEMM, and compares the two C tensors with ck::utils::check_err. */ +template +struct TestWmma +{ + auto PrepareGemmTensor(const ck::wmma_op_util::GemmParams& 
params) + { + /* PrepareGemmTensor: builds the A (M x K), B (K x N) and two C (M x N) host tensors from params, fills A and B through GeneratorTensor_2 constructed with {-5, 5}, and returns all four as a tuple; neither C tensor is filled here. The descriptor lambda uses strides {stride, 1} when its layout test holds and {1, stride} otherwise. */ auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k( + f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); + Tensor b_n_k( + f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); + Tensor c_m_n_host_result( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + Tensor c_m_n_device_result( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + + auto f_generate_tensor_value = [](auto& tensor, auto type) { + using dataType = decltype(type); + + tensor.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + }; + + f_generate_tensor_value(a_m_k, ADataType{}); + f_generate_tensor_value(b_n_k, BDataType{}); + + return std::make_tuple(a_m_k, b_n_k, c_m_n_host_result, c_m_n_device_result); + } + /* operator(): prints the three layout names, sets every size and stride to 16, computes the host reference, runs the device kernel, and returns the check_err verdict for CDataType. */ + auto operator()(const DeviceWmma& wmma_kernel) + { + std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name + << ", CLayout = " << CLayout{}.name << std::endl; + + // Arrange + ck::wmma_op_util::GemmParams params; + params.M = 16; + params.N = 16; + params.K = 16; + params.StrideA = 16; + params.StrideB = 16; + params.StrideC = 16; + + auto host_tensors = PrepareGemmTensor(params); + + const Tensor& a = std::get<0>(host_tensors); + const Tensor& b = std::get<1>(host_tensors); + Tensor& c_host = std::get<2>(host_tensors); + Tensor& c_device = std::get<3>(host_tensors); + + auto a_element_op = AElementwiseOperation{}; + auto b_element_op = BElementwiseOperation{}; + auto c_element_op = CElementwiseOperation{}; + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemm; + ck::wmma_op_util::RunHostGEMM( + a, b, c_host, a_element_op, b_element_op, 
c_element_op); + + // Act + bool is_supported = ck::wmma_op_util::RunDeviceGEMM(wmma_kernel, a, b, c_device); + /* RunDeviceGEMM as written in this file always returns true, so the else branch at the end of this function is never taken. */ + if(is_supported) + { + // Assert /* res stays false when none of the type checks below matches, so an unsupported CDataType is reported as a failure. */ + bool res = false; + if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else if(std::is_same::value) + { + // 0.5 Pixel Error Tolerance is introduced by Accumulator difference. + // BF16 WMMA Accumulator is in BF16 Type while On Host-side Accumulator is Float. /* NOTE(review): the comment above mentions 0.5 while the trailing check_err arguments are 0 and 1.0; confirm which tolerance is intended. */ + res = ck::utils::check_err( + c_device.mData, c_host.mData, "Error: Incorrect results!", 0, 1.0); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else + { + std::cout << "UNSUPPORTED CDataType" << std::endl; + } + + return res; + } + else + { + return true; + } + } +}; + +} // namespace wmma_op_util +} // namespace ck