From b7ba13afcd32cbbaf253cbea33dcf0f451f00110 Mon Sep 17 00:00:00 2001 From: ZhaoAn Date: Sun, 4 May 2025 15:36:42 +0000 Subject: [PATCH] fix: support int8 weight-only kernel with per-channel scale --- example/01_gemm/CMakeLists.txt | 2 +- .../01_gemm/gemm_xdl_fp16_i8_v3_b_scale.cpp | 521 ++++++++++++++++++ .../gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp | 2 +- .../threadwise_tensor_slice_transfer.hpp | 22 + .../gpu/gemm_b_scale.hpp | 65 ++- .../gpu/gemm_b_scale/CMakeLists.txt | 9 +- ...e_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn.hpp | 125 +++++ ...6_i8_f16_mk_nk_mn_cmp_default_instance.cpp | 32 ++ ...6_i8_f16_mk_nk_mn_lat_default_instance.cpp | 32 ++ ...8_f16_mk_nk_mn_mem_v3_default_instance.cpp | 32 ++ ...8_f16_mk_nk_mn_mem_v4_default_instance.cpp | 32 ++ .../profiler/profile_gemm_b_scale_impl.hpp | 48 +- profiler/src/profile_gemm_b_scale.cpp | 17 +- 13 files changed, 923 insertions(+), 16 deletions(-) create mode 100644 example/01_gemm/gemm_xdl_fp16_i8_v3_b_scale.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_cmp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_lat_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v3_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v4_default_instance.cpp diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 77f15a213c..51b144a870 100755 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -67,7 +67,7 @@ foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) add_example_executable(example_gemm_xdl_lds_direct_load_fp32 gemm_xdl_lds_direct_load_fp32.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32) - + add_example_executable(example_gemm_xdl_fp16_i8_v3_b_scale gemm_xdl_fp16_i8_v3_b_scale.cpp) add_example_executable(example_gemm_xdl_lds_direct_load_fp16 gemm_xdl_lds_direct_load_fp16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp16) set(target 1) diff --git a/example/01_gemm/gemm_xdl_fp16_i8_v3_b_scale.cpp b/example/01_gemm/gemm_xdl_fp16_i8_v3_b_scale.cpp new file mode 100644 index 0000000000..be16edbd13 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp16_i8_v3_b_scale.cpp @@ -0,0 +1,521 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp" + +#include +#include +#include + +std::vector LoadTxtTensor(const std::string& path) +{ + std::ifstream file(path); + std::vector data; + float val; + while (file >> val) + data.push_back(val); + return data; +} + + +std::vector LoadFp16Binary(const std::string& path, size_t count) { + std::vector data(count); + std::ifstream in(path, std::ios::binary); + if (!in) { + throw std::runtime_error("Failed to open file: " + path); + } + in.read(reinterpret_cast(data.data()), count * sizeof(ck::half_t)); + return data; +} + + + +std::vector LoadInt8Binary(const std::string& path, size_t count) { + std::vector data(count); + std::ifstream in(path, std::ios::binary); + in.read(reinterpret_cast(data.data()), count * sizeof(int8_t)); + return data; +} + + +// std::vector LoadTensorFromPt(const std::string& path) { +// torch::Tensor tensor = torch::load(path).to(torch::kCPU); +// tensor = tensor.contiguous(); +// auto* ptr = tensor.data_ptr(); +// return std::vector(ptr, ptr + tensor.numel()); +// } + + +void PrintGemmParams(const void* A, const void* B, const void* C, + int M, int N, int K, + int StrideA, int StrideB, int StrideC, + int ScaleStrideBN, const void* BScale, int KBatch) +{ +std::cout << "[W8Only_Gemm_Debug] A_input = " << A << std::endl; +std::cout << "[W8Only_Gemm_Debug] B_input = " << B << std::endl; +std::cout << "[W8Only_Gemm_Debug] C_output = " << C << std::endl; +std::cout << "[W8Only_Gemm_Debug] M = " << M << std::endl; +std::cout << "[W8Only_Gemm_Debug] N = " << N << std::endl; +std::cout << "[W8Only_Gemm_Debug] K = " << K << std::endl; +std::cout << "[W8Only_Gemm_Debug] StrideA = " << StrideA << std::endl; +std::cout << "[W8Only_Gemm_Debug] StrideB = " << StrideB << std::endl; +std::cout << "[W8Only_Gemm_Debug] StrideC = " << StrideC << std::endl; +std::cout << "[W8Only_Gemm_Debug] Scale_Stride_BN = " << ScaleStrideBN << std::endl; +std::cout << "[W8Only_Gemm_Debug] B_scale = " << BScale << std::endl; +std::cout << "[W8Only_Gemm_Debug] KBatch = " << KBatch << std::endl; +} + +using ADataType = ck::half_t; +using BDataType = int8_t; +using BScaleDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr bool PermuteA = false; +static constexpr bool PermuteB = false; + +static constexpr ck::index_t Scale_Block_N = 1; +static constexpr ck::index_t Scale_Block_K = 1024; + +static constexpr ck::index_t KPerBlock = 64; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 256, Scale_Block_N, Scale_Block_K, + 128, 128, + KPerBlock, 8, 32, + 32, 32, + 4, 1, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, CDataType, CDataType, PermuteA, PermuteB>; + +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + // ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + ck::index_t Scale_Stride_BN = 1; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K, + (N + Scale_Block_N - 1) / Scale_Block_N, + Scale_Stride_BN, + BLayout{})); + // Tensor b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K, + // (N + Scale_Block_N - 1) / Scale_Block_N, + // Scale_Stride_BN, + // ck::tensor_layout::gemm::RowMajor{})); + // Tensor b1_k_n(f_host_tensor_descriptor(1, 4096, 1, ck::tensor_layout::gemm::RowMajor{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 4: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 5: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.5, 0.5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + + // std::string pt_dir = "/mnt/raid0/zhaoan12/letao_gemm_pt/"; + + // std::vector A_input_actual = LoadTxtTensor(pt_dir + "A.txt"); + // std::vector B_input_actual = LoadTxtTensor(pt_dir + "B.txt"); + // std::vector D0_input_actual = LoadTxtTensor(pt_dir + "scale.txt"); + + + // // 拷贝 A_input_actual (float → ck::half_t) + // a_m_k.mData.resize(A_input_actual.size()); + // for (size_t i = 0; i < A_input_actual.size(); ++i) { + // a_m_k.mData[i] = ck::type_convert(A_input_actual[i]); + // } + + // // 拷贝 B_input_actual (int → int8_t) + // b_k_n.mData.resize(B_input_actual.size()); + // for (size_t i = 0; i < B_input_actual.size(); ++i) { + // b_k_n.mData[i] = static_cast(B_input_actual[i]); + // } + + // // 拷贝 D0_input_actual (float → ck::half_t) + // b1_k_n.mData.resize(D0_input_actual.size()); + // for (size_t i = 0; i < D0_input_actual.size(); ++i) { + // b1_k_n.mData[i] = ck::type_convert(D0_input_actual[i]); + // } + + + // std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + // std::cout << "D0_input_actual.size:" << D0_input_actual.size() << std::endl; + + + std::string pt_dir = "/mnt/raid0/zhaoan12/gemm_save_rocm/"; + + std::vector A_input_actual = LoadFp16Binary(pt_dir + "A_fp16.bin", M * K); + std::vector B_input_actual = LoadInt8Binary(pt_dir + "B_int8.bin", K * N); + std::vector D0_input_actual = LoadFp16Binary(pt_dir + "scale_fp16.bin", N); + + std::cout << "== Loaded A_fp16[0:10] ==" << std::endl; + for (int i = 0; i < 10; ++i) + std::cout << "A_input_actual[" << i << "] = " << ck::type_convert(A_input_actual[i]) << std::endl; + + std::cout << "== Loaded B_int8[0:10] ==" << std::endl; + for (int i = 0; i < 10; ++i) + std::cout << "B_input_actual[" << i << "] = " << static_cast(B_input_actual[i]) << std::endl; + + std::cout << "== Loaded scale_fp16[0:10] ==" << std::endl; + for (int i = 0; i < 10; ++i) + std::cout << "D0_input_actual[" << i << "] = " << ck::type_convert(D0_input_actual[i]) << std::endl; + + + // A: float → half + a_m_k.mData.resize(A_input_actual.size()); + for (size_t i = 0; i < A_input_actual.size(); ++i) + a_m_k.mData[i] = ck::type_convert(A_input_actual[i]); + + // B: float → int8 + b_k_n.mData.resize(B_input_actual.size()); + for (size_t i = 0; i < B_input_actual.size(); ++i) + b_k_n.mData[i] = static_cast(std::round(B_input_actual[i])); + + // Scale: float → half + b1_k_n.mData.resize(D0_input_actual.size()); + for (size_t i = 0; i < D0_input_actual.size(); ++i) + b1_k_n.mData[i] = ck::type_convert(D0_input_actual[i]); + + + + std::cout << "== Converted A_fp16 (a_m_k.mData)[0:10] ==" << std::endl; + for (size_t i = 0; i < 10; ++i) + std::cout << "a_m_k.mData[" << i << "] = " << ck::type_convert(a_m_k.mData[i]) << std::endl; + + std::cout << "== Converted B_int8 (b_k_n.mData)[0:10] ==" << std::endl; + for (size_t i = 0; i < 10; ++i) + std::cout << "b_k_n.mData[" << i << "] = " << static_cast(b_k_n.mData[i]) << std::endl; + + std::cout << "== Converted scale_fp16 (b1_k_n.mData)[0:10] ==" << std::endl; + for (size_t i = 0; i < 10; ++i) + std::cout << "b1_k_n.mData[" << i << "] = " << ck::type_convert(b1_k_n.mData[i]) << std::endl; + + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem b1_scale_device_buf(sizeof(BScaleDataType) * b1_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + // // weight permute + // if constexpr(PermuteB) + // { + // int K1 = KPerBlock; + // int K0 = K / KPerBlock; + + // // int K0, N, K1 + // for(int j = 0; j < K0; j++) + // { + // for(int i = 0; i < N; i++) + // { + // for(int jj = 0; jj < K1; jj++) + // { + // b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj)); + // } + // } + // } + // } + // else + // { + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j++) + { + b_k_n_permute(i * K + j) = b_k_n(i * K + j); + } + } + // } + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data()); + b1_scale_device_buf.ToDevice(b1_k_n.mData.data()); + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + +#if !defined(__HIP_DEVICE_COMPILE__) + std::cout << "[W8Only_Gemm_Debug] A_input = " << static_cast(a_m_k_device_buf.GetDeviceBuffer()) << std::endl; + std::cout << "[W8Only_Gemm_Debug] B_input = " << static_cast(b_k_n_device_buf.GetDeviceBuffer()) << std::endl; + std::cout << "[W8Only_Gemm_Debug] C_output = " << static_cast(c_m_n_device_buf.GetDeviceBuffer()) << std::endl; + std::cout << "[W8Only_Gemm_Debug] M = " << M << std::endl; + std::cout << "[W8Only_Gemm_Debug] N = " << N << std::endl; + std::cout << "[W8Only_Gemm_Debug] K = " << K << std::endl; + std::cout << "[W8Only_Gemm_Debug] StrideA = " << StrideA << std::endl; + std::cout << "[W8Only_Gemm_Debug] StrideB = " << StrideB << std::endl; + std::cout << "[W8Only_Gemm_Debug] StrideC = " << StrideC << std::endl; + std::cout << "[W8Only_Gemm_Debug] Scale_Stride_BN = " << Scale_Stride_BN << std::endl; + std::cout << "[W8Only_Gemm_Debug] B_scale = " << static_cast(b1_scale_device_buf.GetDeviceBuffer()) << std::endl; + std::cout << "[W8Only_Gemm_Debug] KBatch = " << KBatch << std::endl; +#endif + + + +PrintGemmParams( + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, N, K, + StrideA, StrideB, StrideC, + Scale_Stride_BN, + static_cast(b1_scale_device_buf.GetDeviceBuffer()), + KBatch); + + auto argument = + gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + Scale_Stride_BN, + static_cast(b1_scale_device_buf.GetDeviceBuffer()), + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + Tensor b_k_n_dequant({K, N}); + + float v_b = 0; + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + v_b = ck::type_convert(b_k_n(k, n)); + + b_k_n_dequant(k, n) = + ck::type_convert(v_b) * + ck::type_convert(b1_k_n(k / Scale_Block_K, n / Scale_Block_N)); + } + } + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n_dequant, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + + std::cout << "\n== First 10 values of b_k_n_dequant ==" << std::endl; + for(int i = 0; i < 15; ++i) + { + std::cout << "b_k_n_dequant[" << i << "] = " << ck::type_convert(b_k_n_dequant.mData[i]) << std::endl; + } + + std::cout << "\n== First 10 values of loaded scale ==\n"; + for (int i = 0; i < 10; ++i) { + std::cout << "b1_k_n(0, " << i << ") = " << ck::type_convert(b1_k_n(0, i)) << std::endl; + } + + + for(int i = 0; i < 10; i++) + { + std::cout << "data[" << i << "]: " << ck::type_convert(c_m_n_device_result.mData[i]) << std::endl; + } + + + std::cout << "\n== CK: Checking b_k_n(k,n), scale, and dequant ==" << std::endl; + std::vector> check_indices = { + {0, 0}, {1, 0}, {0, 1}, {1, 1}, {1023, 4095}, {100, 2000} + }; + + for (const auto& [k, n] : check_indices) { + v_b = ck::type_convert(b_k_n(k, n)); + float v_scale = ck::type_convert(b1_k_n(k / 1024, n)); + float v_dequant = ck::type_convert(b_k_n_dequant(k, n)); + std::cout << "b_k_n(" << k << "," << n << ") = " << v_b + << ", scale = " << v_scale + << ", dequant = " << v_dequant << std::endl; + } + + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / + (ck::is_same_v, ck::pk_i4_t> ? 2 : 1) + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + // problem_size.M = 8; + // problem_size.N = 3072; + // problem_size.K = 1024; + + problem_size.M = 8; + problem_size.N = 1024; + problem_size.K = 4096; + + config.do_verification = true; + config.init_method = 1; + config.time_kernel = true; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); +} + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp index bdb24c25a5..52c086280f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp @@ -1455,7 +1455,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 1, false>( b_scale_grid_desc_bn_ak, - make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n, + make_multi_index((block_n_id * NPerBlock + b_thread_offset_n) / ScaleBlockN, b_thread_offset_k / ScaleBlockK)); constexpr auto b_scale_thread_slice_copy_step = diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index bb1871ae62..10a04044ef 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1399,6 +1399,28 @@ struct ThreadwiseTensorSliceTransfer_v4 dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; }); } + else if constexpr(is_same, int8_t>::value && + is_same, half_t>::value) + { + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; + + // TODO: if SrcData and DstData are vetor type, then static_cast may not compile + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + ck::tensor_operation::element_wise::Scale{type_convert(scale)}( + dst_tmp_vector.template AsType()(i), + type_convert(src_tmp_vector.template AsType()[i])); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + } else { // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_b_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_b_scale.hpp index 93eed31bc5..d420dcf479 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_b_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_b_scale.hpp @@ -30,6 +30,58 @@ void add_device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances( PassThrough, PassThrough, PassThrough>>>& instances); +void add_device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v3_default_instances( +std::vector>>& instances); +void add_device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v4_default_instances( +std::vector>>& instances); +void add_device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_lat_default_instances( +std::vector>>& instances); +void add_device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_cmp_default_instances( +std::vector>>& instances); #endif template && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_cmp_default_instances(op_ptrs); + add_device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_lat_default_instances(op_ptrs); + add_device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v3_default_instances(op_ptrs); + add_device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v4_default_instances(op_ptrs); + } + } return op_ptrs; } }; diff --git a/library/src/tensor_operation_instance/gpu/gemm_b_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_b_scale/CMakeLists.txt index 424320fa8f..a7b7cb9ce1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_b_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_b_scale/CMakeLists.txt @@ -3,8 +3,15 @@ set(GEMM_B_SCALE_INSTANCES) list(APPEND GEMM_B_SCALE_INSTANCES device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp + device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v3_default_instance.cpp + device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v4_default_instance.cpp + device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_cmp_default_instance.cpp + device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_lat_default_instance.cpp ) set_source_files_properties(device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - +set_source_files_properties(device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v3_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v4_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_cmp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_lat_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") add_instance_library(device_gemm_b_scale_instance ${GEMM_B_SCALE_INSTANCES}) \ No newline at end of file diff --git a/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn.hpp new file mode 100644 index 0000000000..e3052fb8e9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn.hpp @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I8 = int8_t; +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +#if 0 +template +using device_gemm_xdl_b_scale_f16_i8_f16_mk_nk_mn_comp_instances = std::tuple< + +#endif + +template +using device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_cmp_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | Type| | | | Operation| Operation| Operation| | | N| K| | | | | |Wave| Wave| | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + //Compute friendly + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 32, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //0 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //1 + + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //3 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //4 + + //new Compute friendly kernel + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //35 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false> //36 + // clang-format on + >; +template +using device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_lat_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | Type| | | | Operation| Operation| Operation| | | N| K| | | | | |Wave| Wave| | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + //Latency friendly + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //5 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //6 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //7 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false> //8 + // clang-format on + >; +template +using device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v3_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | Type| | | | Operation| Operation| Operation| | | N| K| | | | | |Wave| Wave| | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // // Memory friendly v3 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 32, 128, 8, 32, 32, 32, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //9 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 16, 128, 8, 16, 16, 16, 4, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //10 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 32, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //11 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 16, 128, 8, 16, 16, 16, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //12 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //13 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //16 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 64, 128, 8, 32, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //17 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 64, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //18 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 128, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //19 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 128, 128, 8, 32, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //20 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 16, 256, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //21 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 32, 256, 128, 8, 32, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //22 + //new Memory friendly kernel + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 16, 64, 256, 8, 32, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false> + // clang-format on + >; +template +using device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v4_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | Type| | | | Operation| Operation| Operation| | | N| K| | | | | |Wave| Wave| | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Memory friendly v4 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 32, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //23 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 16, 128, 8, 16, 16, 16, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //24 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //25 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //26 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //28 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 64, 128, 8, 32, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //29 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 64, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //30 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 128, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //31 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 128, 128, 8, 32, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false> //32 + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck \ No newline at end of file diff --git a/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_cmp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_cmp_default_instance.cpp new file mode 100644 index 0000000000..b36d422ef5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_cmp_default_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_cmp_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_cmp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck \ No newline at end of file diff --git a/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_lat_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_lat_default_instance.cpp new file mode 100644 index 0000000000..5d652c0b28 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_lat_default_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_lat_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_lat_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck \ No newline at end of file diff --git a/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v3_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v3_default_instance.cpp new file mode 100644 index 0000000000..e8264f6aa3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v3_default_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v3_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v3_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck \ No newline at end of file diff --git a/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v4_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v4_default_instance.cpp new file mode 100644 index 0000000000..0e40023348 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v4_default_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v4_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v4_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck \ No newline at end of file diff --git a/profiler/include/profiler/profile_gemm_b_scale_impl.hpp b/profiler/include/profiler/profile_gemm_b_scale_impl.hpp index fe977e766e..db7ed1c628 100644 --- a/profiler/include/profiler/profile_gemm_b_scale_impl.hpp +++ b/profiler/include/profiler/profile_gemm_b_scale_impl.hpp @@ -65,6 +65,27 @@ bool profile_gemm_b_scale_impl(int do_verification, } }; + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); ck::index_t Scale_Stride_BN = ck::is_same_v ? ((K + ScaleBlockK - 1) / ScaleBlockK) : N; @@ -159,21 +180,28 @@ bool profile_gemm_b_scale_impl(int do_verification, { for(int k = 0; k < K; k++) { - ck::pk_i4_t i4x2 = b_k_n(k, n).data; - int8_t i4 = 0; - if(k % 2 == 1) - i4 = (i4x2.data >> 0) & 0xf; - else - i4 = (i4x2.data >> 4) & 0xf; - i4 = i4 - 8; - v_b = ck::type_convert(i4); + if constexpr(is_same_v) + { + ck::pk_i4_t i4x2 = b_k_n(k, n).data; + int8_t i4 = 0; + if(k % 2 == 1) + i4 = (i4x2.data >> 0) & 0xf; + else + i4 = (i4x2.data >> 4) & 0xf; + i4 = i4 - 8; + v_b = ck::type_convert(i4); + } + else if constexpr(is_same_v) + { + v_b = ck::type_convert(b_k_n(k, n)); + } b_k_n_dequant(k, n) = ck::type_convert(v_b) * ck::type_convert(b1_k_n(k / ScaleBlockK, n)); } } using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm && is_same_v) + if constexpr(is_same_v && is_same_v) { // vector pk_i4x4 permute for(int i = 0; i < N; i++) diff --git a/profiler/src/profile_gemm_b_scale.cpp b/profiler/src/profile_gemm_b_scale.cpp index 443ebff834..707072c0c2 100644 --- a/profiler/src/profile_gemm_b_scale.cpp +++ b/profiler/src/profile_gemm_b_scale.cpp @@ -28,6 +28,7 @@ enum struct GemmDataType F16_F16_F16_F8, // 6 F8_F8_BF16, // 7 F16_I4_F16, // 8 + F16_I8_F16, // 8 }; enum struct BScaleBlockTile @@ -46,7 +47,7 @@ int profile_gemm_b_scale(int argc, char* argv[]) printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: " "f16->f8; 7: f8->bf16, " - "comp f8; 8: f16@i4)\n"); + "comp f8; 8: f16@i4 9: f16@i8)\n"); printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); @@ -106,6 +107,7 @@ int profile_gemm_b_scale(int argc, char* argv[]) using F32 = float; using F16 = ck::half_t; using I4 = ck::pk_i4_t; + using I8 = int8_t; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -170,10 +172,21 @@ int profile_gemm_b_scale(int argc, char* argv[]) return profile( F16{}, I4{}, F16{}, F16{}, F32{}, F16{}, ck::Number<128>{}, Row{}, Col{}, Row{}); } + if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_NK_MN && + B_scale_block == BScaleBlockTile::K_128) + { + printf("F16_I8_F16 MK_NK_MN K_128\n"); + return profile( + F16{}, I8{}, F16{}, F16{}, F32{}, F16{}, ck::Number<128>{}, Row{}, Col{}, Row{}); + } else { std::cout << "this data_type & layout is not implemented" << std::endl; - + std::cout << "this data_type: " << static_cast(data_type) + << " & layout :" << static_cast(layout) + << " & B_scale_block:" << static_cast(B_scale_block) << " is not implemented" + << std::endl; + return 1; } }