diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 97ac21eba5..00a7278ec9 100755 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -30,6 +30,7 @@ add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_v3) add_example_executable(example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp) add_example_executable(example_gemm_xdl_fp16_pk_i4_v3 gemm_xdl_fp16_pk_i4_v3.cpp) +add_example_executable(example_gemm_xdl_fp8_pk_i4_v3 gemm_xdl_fp8_pk_i4_v3.cpp) add_example_executable(example_gemm_xdl_fp16_pk_i4_v3_b_scale gemm_xdl_fp16_pk_i4_v3_b_scale.cpp) add_example_executable(example_gemm_xdl_bf16_pk_i4_v3 gemm_xdl_bf16_pk_i4_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_v3) diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index 9664c50b6e..58b15bda1d 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -369,3 +369,26 @@ inline __host__ __device__ constexpr double get_atol() return 1e-3; } } + +float i4_to_f32_gfx9(uint8_t i4) +{ + static std::unordered_map u = { + {0b1000, -0.5000f}, + {0b1001, -0.4375f}, + {0b1010, -0.3750f}, + {0b1011, -0.3125f}, + {0b1100, -0.2500f}, + {0b1101, -0.1875f}, + {0b1110, -0.1250f}, + {0b1111, -0.0625f}, + {0b0 , +0.0000f}, + {0b1 , +0.0625f}, + {0b10 , +0.1250f}, + {0b11 , +0.1875f}, + {0b100 , +0.2500f}, + {0b101 , +0.3125f}, + {0b110 , +0.3750f}, + {0b111 , +0.4375f}}; + + return u[i4]; +} diff --git a/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp b/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp new file mode 100644 index 0000000000..4dd92b1973 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp @@ -0,0 +1,348 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" + +using F8 = ck::f8_t; +using I4 = ck::pk_i4_t; +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F8; +using BDataType = I4; +using AccDataType = float; +using CShuffleDataType = F16; +using CDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr bool PermuteA = false; +static constexpr bool PermuteB = true; +static constexpr ck::index_t KPerBlock = 128; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 128, + 16, 128, + KPerBlock, 16, 32, + 16, 16, + 1, 4, + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, + S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 32, 32, 0, + 1, 1, S<1, 16, 1, 8>, 4, + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, ADataType, ADataType, PermuteA, PermuteB>; + +// clang-format on + +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + // weight permute + if constexpr(PermuteB) + { + int K1 = KPerBlock; + int K0 = K / KPerBlock; + + // int K0, N, K1 + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj)); + } + } + } + } + else + { + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j++) + { + b_k_n_permute(i * K + j) = b_k_n(i * K + j); + } + } + } + + // vector pk_i4x4 permute + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int input[8]; + + for(int k = 0; k < 4; k++) + { + int i4x2 = b_k_n_permute(j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int hi = input[2]; + int lo = input[0]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 0, i) = i4x2; + } + + { + int hi = input[6]; + int lo = input[4]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 2, i) = i4x2; + } + + { + int hi = input[3]; + int lo = input[1]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 4, i) = i4x2; + } + + { + int hi = input[7]; + int lo = input[5]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 6, i) = i4x2; + } + } + } + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data()); + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + Tensor b_k_n_f32({K, N}); + + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + ck::pk_i4_t i4x2 = b_k_n(k, n).data; + uint8_t i4 = 0; + + if(k % 2 == 1) + i4 = (i4x2.data >> 0) & 0xf; + else + i4 = (i4x2.data >> 4) & 0xf; + + float v_b = i4_to_f32_gfx9(i4); + b_k_n_f32(k, n) = v_b; + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n_f32, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / + (ck::is_same_v, ck::pk_i4_t> ? 2 : 1) + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + +#if 0 + printf("B Matrix:\n"); + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + ck::pk_i4_t i4x2 = b_k_n(k, n).data; + int8_t i4 = 0; + if(k % 2 == 1) + i4 = (i4x2.data >> 0) & 0xf; + else + i4 = (i4x2.data >> 4) & 0xf; + + printf("%f (%d),", i4_to_f32_gfx9(i4), static_cast(i4x2.data)); + } + printf("\n"); + } +#endif + + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config); +} + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index be4e68bffa..754967dc4d 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -79,6 +79,15 @@ __device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale) return res.template AsType()[Number<0>{}]; } +__device__ inline f8x8_t i4_to_fp8x8(int q) +{ + vector_type res; + + res.template AsType()(Number<0>{}) = amd_assembly_i4_to_fp8x2(q); + + return res.template AsType()[Number<0>{}]; +} + __device__ inline bhalf4_t i4_to_bhalf4(int q) { uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12); @@ -142,6 +151,19 @@ struct PassThroughPack8 #endif } + __host__ __device__ constexpr void operator()(ck::f8x8_t& y, const ck::pk_i4x4_t& x) const + { +#if CK_USE_PK4_LAYOUT_SHUFFLE + vector_type result; + + result.template AsType()(Number<0>{}) = i4_to_fp8x8(bit_cast(x)); + + y = result.template AsType()[Number<0>{}]; +#else + // Added pk_i4_t to f8x2_fnuz_t conversion +#endif + } + __host__ __device__ constexpr void operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x) const { #if CK_USE_PK4_LAYOUT_SHUFFLE diff --git a/include/ck/utility/amd_inline_asm.hpp b/include/ck/utility/amd_inline_asm.hpp index 113f3af4ae..ba27d63dbd 100644 --- a/include/ck/utility/amd_inline_asm.hpp +++ b/include/ck/utility/amd_inline_asm.hpp @@ -32,6 +32,44 @@ inline __device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b) return c; } +inline __device__ f8x8_t amd_assembly_i4_to_fp8x2(int a) +{ + uint32_t i4x8 = static_cast(a); + uint32_t fp8x4_0; + uint32_t fp8x4_1; + float tmp_0, tmp_1, tmp_2; + + asm volatile("v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_src], src0_sel:BYTE_2\n" + "v_cvt_pk_fp8_f32 %[v_dst_0], %[v_tmp_0], %[v_tmp_1]\n" + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_src], src0_sel:BYTE_3\n" + "v_cvt_pk_fp8_f32 %[v_dst_1], %[v_tmp_0], %[v_tmp_1]\n" + "v_lshrrev_b32 %[v_tmp_2], 4, %[v_src]\n" + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_tmp_2]\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_tmp_2], src0_sel:BYTE_2\n" + "v_cvt_pk_fp8_f32 %[v_dst_0], %[v_tmp_0], %[v_tmp_1], op_sel:[0, 0, 1]\n" + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_tmp_2], src0_sel:BYTE_1\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_tmp_2], src0_sel:BYTE_3\n" + "v_cvt_pk_fp8_f32 %[v_dst_1], %[v_tmp_0], %[v_tmp_1], op_sel:[0, 0, 1]\n" + : [v_tmp_0] "+v"(tmp_0), + [v_tmp_1] "+v"(tmp_1), + [v_tmp_2] "+v"(tmp_2), + [v_dst_0] "+v"(fp8x4_0), + [v_dst_1] "+v"(fp8x4_1), + [v_src] "+v"(i4x8) + :); + + union + { + uint64_t as_uint64; + f8x8_t as_f8x8; + } convert; + + convert.as_uint64 = (static_cast(fp8x4_1) << 32) | fp8x4_0; + return convert.as_f8x8; +} + // c0 += inner_product(a, b0) // c1 += inner_product(a, b1) __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)