mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
fix: support int8 weight-only kernel with per-channel scale
This commit is contained in:
@@ -67,7 +67,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_example_executable(example_gemm_xdl_lds_direct_load_fp32 gemm_xdl_lds_direct_load_fp32.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32)
|
||||
|
||||
add_example_executable(example_gemm_xdl_fp16_i8_v3_b_scale gemm_xdl_fp16_i8_v3_b_scale.cpp)
|
||||
add_example_executable(example_gemm_xdl_lds_direct_load_fp16 gemm_xdl_lds_direct_load_fp16.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp16)
|
||||
set(target 1)
|
||||
|
||||
521
example/01_gemm/gemm_xdl_fp16_i8_v3_b_scale.cpp
Normal file
521
example/01_gemm/gemm_xdl_fp16_i8_v3_b_scale.cpp
Normal file
@@ -0,0 +1,521 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp"
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
|
||||
std::vector<float> LoadTxtTensor(const std::string& path)
|
||||
{
|
||||
std::ifstream file(path);
|
||||
std::vector<float> data;
|
||||
float val;
|
||||
while (file >> val)
|
||||
data.push_back(val);
|
||||
return data;
|
||||
}
|
||||
|
||||
|
||||
std::vector<ck::half_t> LoadFp16Binary(const std::string& path, size_t count) {
|
||||
std::vector<ck::half_t> data(count);
|
||||
std::ifstream in(path, std::ios::binary);
|
||||
if (!in) {
|
||||
throw std::runtime_error("Failed to open file: " + path);
|
||||
}
|
||||
in.read(reinterpret_cast<char*>(data.data()), count * sizeof(ck::half_t));
|
||||
return data;
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::vector<int8_t> LoadInt8Binary(const std::string& path, size_t count) {
|
||||
std::vector<int8_t> data(count);
|
||||
std::ifstream in(path, std::ios::binary);
|
||||
in.read(reinterpret_cast<char*>(data.data()), count * sizeof(int8_t));
|
||||
return data;
|
||||
}
|
||||
|
||||
|
||||
// std::vector<float> LoadTensorFromPt(const std::string& path) {
|
||||
// torch::Tensor tensor = torch::load(path).to(torch::kCPU);
|
||||
// tensor = tensor.contiguous();
|
||||
// auto* ptr = tensor.data_ptr<float>();
|
||||
// return std::vector<float>(ptr, ptr + tensor.numel());
|
||||
// }
|
||||
|
||||
|
||||
void PrintGemmParams(const void* A, const void* B, const void* C,
|
||||
int M, int N, int K,
|
||||
int StrideA, int StrideB, int StrideC,
|
||||
int ScaleStrideBN, const void* BScale, int KBatch)
|
||||
{
|
||||
std::cout << "[W8Only_Gemm_Debug] A_input = " << A << std::endl;
|
||||
std::cout << "[W8Only_Gemm_Debug] B_input = " << B << std::endl;
|
||||
std::cout << "[W8Only_Gemm_Debug] C_output = " << C << std::endl;
|
||||
std::cout << "[W8Only_Gemm_Debug] M = " << M << std::endl;
|
||||
std::cout << "[W8Only_Gemm_Debug] N = " << N << std::endl;
|
||||
std::cout << "[W8Only_Gemm_Debug] K = " << K << std::endl;
|
||||
std::cout << "[W8Only_Gemm_Debug] StrideA = " << StrideA << std::endl;
|
||||
std::cout << "[W8Only_Gemm_Debug] StrideB = " << StrideB << std::endl;
|
||||
std::cout << "[W8Only_Gemm_Debug] StrideC = " << StrideC << std::endl;
|
||||
std::cout << "[W8Only_Gemm_Debug] Scale_Stride_BN = " << ScaleStrideBN << std::endl;
|
||||
std::cout << "[W8Only_Gemm_Debug] B_scale = " << BScale << std::endl;
|
||||
std::cout << "[W8Only_Gemm_Debug] KBatch = " << KBatch << std::endl;
|
||||
}
|
||||
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = int8_t;
|
||||
using BScaleDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
static constexpr bool PermuteA = false;
|
||||
static constexpr bool PermuteB = false;
|
||||
|
||||
static constexpr ck::index_t Scale_Block_N = 1;
|
||||
static constexpr ck::index_t Scale_Block_K = 1024;
|
||||
|
||||
static constexpr ck::index_t KPerBlock = 64;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmV2Instance =
|
||||
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CElementOp, GemmDefault,
|
||||
256, Scale_Block_N, Scale_Block_K,
|
||||
128, 128,
|
||||
KPerBlock, 8, 32,
|
||||
32, 32,
|
||||
4, 1,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 16, 16, 0,
|
||||
1, 1, S<1, 32, 1, 8>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, CDataType, CDataType, PermuteA, PermuteB>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
template <typename ProblemType>
|
||||
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
using namespace ck::literals;
|
||||
|
||||
auto M = problem_size.M;
|
||||
auto N = problem_size.N;
|
||||
auto K = problem_size.K;
|
||||
auto StrideA = problem_size.StrideA;
|
||||
auto StrideB = problem_size.StrideB;
|
||||
auto StrideC = problem_size.StrideC;
|
||||
auto KBatch = problem_size.KBatch;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
|
||||
if(stride == -1)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return static_cast<std::size_t>(col);
|
||||
}
|
||||
else
|
||||
{
|
||||
return static_cast<std::size_t>(row);
|
||||
}
|
||||
}
|
||||
else
|
||||
return static_cast<std::size_t>(stride);
|
||||
};
|
||||
|
||||
// ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K;
|
||||
ck::index_t Scale_Stride_BN = 1;
|
||||
|
||||
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
|
||||
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
|
||||
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<BDataType> b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<BScaleDataType> b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K,
|
||||
(N + Scale_Block_N - 1) / Scale_Block_N,
|
||||
Scale_Stride_BN,
|
||||
BLayout{}));
|
||||
// Tensor<BScaleDataType> b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K,
|
||||
// (N + Scale_Block_N - 1) / Scale_Block_N,
|
||||
// Scale_Stride_BN,
|
||||
// ck::tensor_layout::gemm::RowMajor{}));
|
||||
// Tensor<BScaleDataType> b1_k_n(f_host_tensor_descriptor(1, 4096, 1, ck::tensor_layout::gemm::RowMajor{}));
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_1<BScaleDataType>{1});
|
||||
break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
|
||||
break;
|
||||
case 2:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_1<BScaleDataType>{1});
|
||||
break;
|
||||
case 3:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_1<BScaleDataType>{1});
|
||||
break;
|
||||
case 4:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
|
||||
break;
|
||||
case 5:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_1<BScaleDataType>{1});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.5, 0.5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
|
||||
}
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl;
|
||||
|
||||
// std::string pt_dir = "/mnt/raid0/zhaoan12/letao_gemm_pt/";
|
||||
|
||||
// std::vector<float> A_input_actual = LoadTxtTensor(pt_dir + "A.txt");
|
||||
// std::vector<float> B_input_actual = LoadTxtTensor(pt_dir + "B.txt");
|
||||
// std::vector<float> D0_input_actual = LoadTxtTensor(pt_dir + "scale.txt");
|
||||
|
||||
|
||||
// // 拷贝 A_input_actual (float → ck::half_t)
|
||||
// a_m_k.mData.resize(A_input_actual.size());
|
||||
// for (size_t i = 0; i < A_input_actual.size(); ++i) {
|
||||
// a_m_k.mData[i] = ck::type_convert<ck::half_t>(A_input_actual[i]);
|
||||
// }
|
||||
|
||||
// // 拷贝 B_input_actual (int → int8_t)
|
||||
// b_k_n.mData.resize(B_input_actual.size());
|
||||
// for (size_t i = 0; i < B_input_actual.size(); ++i) {
|
||||
// b_k_n.mData[i] = static_cast<int8_t>(B_input_actual[i]);
|
||||
// }
|
||||
|
||||
// // 拷贝 D0_input_actual (float → ck::half_t)
|
||||
// b1_k_n.mData.resize(D0_input_actual.size());
|
||||
// for (size_t i = 0; i < D0_input_actual.size(); ++i) {
|
||||
// b1_k_n.mData[i] = ck::type_convert<ck::half_t>(D0_input_actual[i]);
|
||||
// }
|
||||
|
||||
|
||||
// std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
// std::cout << "D0_input_actual.size:" << D0_input_actual.size() << std::endl;
|
||||
|
||||
|
||||
std::string pt_dir = "/mnt/raid0/zhaoan12/gemm_save_rocm/";
|
||||
|
||||
std::vector<ck::half_t> A_input_actual = LoadFp16Binary(pt_dir + "A_fp16.bin", M * K);
|
||||
std::vector<int8_t> B_input_actual = LoadInt8Binary(pt_dir + "B_int8.bin", K * N);
|
||||
std::vector<ck::half_t> D0_input_actual = LoadFp16Binary(pt_dir + "scale_fp16.bin", N);
|
||||
|
||||
std::cout << "== Loaded A_fp16[0:10] ==" << std::endl;
|
||||
for (int i = 0; i < 10; ++i)
|
||||
std::cout << "A_input_actual[" << i << "] = " << ck::type_convert<float>(A_input_actual[i]) << std::endl;
|
||||
|
||||
std::cout << "== Loaded B_int8[0:10] ==" << std::endl;
|
||||
for (int i = 0; i < 10; ++i)
|
||||
std::cout << "B_input_actual[" << i << "] = " << static_cast<int>(B_input_actual[i]) << std::endl;
|
||||
|
||||
std::cout << "== Loaded scale_fp16[0:10] ==" << std::endl;
|
||||
for (int i = 0; i < 10; ++i)
|
||||
std::cout << "D0_input_actual[" << i << "] = " << ck::type_convert<float>(D0_input_actual[i]) << std::endl;
|
||||
|
||||
|
||||
// A: float → half
|
||||
a_m_k.mData.resize(A_input_actual.size());
|
||||
for (size_t i = 0; i < A_input_actual.size(); ++i)
|
||||
a_m_k.mData[i] = ck::type_convert<ck::half_t>(A_input_actual[i]);
|
||||
|
||||
// B: float → int8
|
||||
b_k_n.mData.resize(B_input_actual.size());
|
||||
for (size_t i = 0; i < B_input_actual.size(); ++i)
|
||||
b_k_n.mData[i] = static_cast<int8_t>(std::round(B_input_actual[i]));
|
||||
|
||||
// Scale: float → half
|
||||
b1_k_n.mData.resize(D0_input_actual.size());
|
||||
for (size_t i = 0; i < D0_input_actual.size(); ++i)
|
||||
b1_k_n.mData[i] = ck::type_convert<ck::half_t>(D0_input_actual[i]);
|
||||
|
||||
|
||||
|
||||
std::cout << "== Converted A_fp16 (a_m_k.mData)[0:10] ==" << std::endl;
|
||||
for (size_t i = 0; i < 10; ++i)
|
||||
std::cout << "a_m_k.mData[" << i << "] = " << ck::type_convert<float>(a_m_k.mData[i]) << std::endl;
|
||||
|
||||
std::cout << "== Converted B_int8 (b_k_n.mData)[0:10] ==" << std::endl;
|
||||
for (size_t i = 0; i < 10; ++i)
|
||||
std::cout << "b_k_n.mData[" << i << "] = " << static_cast<int>(b_k_n.mData[i]) << std::endl;
|
||||
|
||||
std::cout << "== Converted scale_fp16 (b1_k_n.mData)[0:10] ==" << std::endl;
|
||||
for (size_t i = 0; i < 10; ++i)
|
||||
std::cout << "b1_k_n.mData[" << i << "] = " << ck::type_convert<float>(b1_k_n.mData[i]) << std::endl;
|
||||
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b1_scale_device_buf(sizeof(BScaleDataType) * b1_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
// // weight permute
|
||||
// if constexpr(PermuteB)
|
||||
// {
|
||||
// int K1 = KPerBlock;
|
||||
// int K0 = K / KPerBlock;
|
||||
|
||||
// // int K0, N, K1
|
||||
// for(int j = 0; j < K0; j++)
|
||||
// {
|
||||
// for(int i = 0; i < N; i++)
|
||||
// {
|
||||
// for(int jj = 0; jj < K1; jj++)
|
||||
// {
|
||||
// b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj));
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int j = 0; j < K; j++)
|
||||
{
|
||||
b_k_n_permute(i * K + j) = b_k_n(i * K + j);
|
||||
}
|
||||
}
|
||||
// }
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data());
|
||||
b1_scale_device_buf.ToDevice(b1_k_n.mData.data());
|
||||
DeviceMem workspace;
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceGemmV2Instance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
float ave_time = 0;
|
||||
|
||||
#if !defined(__HIP_DEVICE_COMPILE__)
|
||||
std::cout << "[W8Only_Gemm_Debug] A_input = " << static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()) << std::endl;
|
||||
std::cout << "[W8Only_Gemm_Debug] B_input = " << static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()) << std::endl;
|
||||
std::cout << "[W8Only_Gemm_Debug] C_output = " << static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()) << std::endl;
|
||||
std::cout << "[W8Only_Gemm_Debug] M = " << M << std::endl;
|
||||
std::cout << "[W8Only_Gemm_Debug] N = " << N << std::endl;
|
||||
std::cout << "[W8Only_Gemm_Debug] K = " << K << std::endl;
|
||||
std::cout << "[W8Only_Gemm_Debug] StrideA = " << StrideA << std::endl;
|
||||
std::cout << "[W8Only_Gemm_Debug] StrideB = " << StrideB << std::endl;
|
||||
std::cout << "[W8Only_Gemm_Debug] StrideC = " << StrideC << std::endl;
|
||||
std::cout << "[W8Only_Gemm_Debug] Scale_Stride_BN = " << Scale_Stride_BN << std::endl;
|
||||
std::cout << "[W8Only_Gemm_Debug] B_scale = " << static_cast<BScaleDataType*>(b1_scale_device_buf.GetDeviceBuffer()) << std::endl;
|
||||
std::cout << "[W8Only_Gemm_Debug] KBatch = " << KBatch << std::endl;
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
PrintGemmParams(
|
||||
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
M, N, K,
|
||||
StrideA, StrideB, StrideC,
|
||||
Scale_Stride_BN,
|
||||
static_cast<BScaleDataType*>(b1_scale_device_buf.GetDeviceBuffer()),
|
||||
KBatch);
|
||||
|
||||
auto argument =
|
||||
gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
Scale_Stride_BN,
|
||||
static_cast<BScaleDataType*>(b1_scale_device_buf.GetDeviceBuffer()),
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
{
|
||||
Tensor<float> b_k_n_dequant({K, N});
|
||||
|
||||
float v_b = 0;
|
||||
for(int n = 0; n < N; n++)
|
||||
{
|
||||
for(int k = 0; k < K; k++)
|
||||
{
|
||||
v_b = ck::type_convert<float>(b_k_n(k, n));
|
||||
|
||||
b_k_n_dequant(k, n) =
|
||||
ck::type_convert<float>(v_b) *
|
||||
ck::type_convert<float>(b1_k_n(k / Scale_Block_K, n / Scale_Block_N));
|
||||
}
|
||||
}
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n_dequant, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0});
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
pass &= ck::utils::check_err(c_m_n_device_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
get_rtol<CDataType>(),
|
||||
get_atol<CDataType>());
|
||||
|
||||
std::cout << "\n== First 10 values of b_k_n_dequant ==" << std::endl;
|
||||
for(int i = 0; i < 15; ++i)
|
||||
{
|
||||
std::cout << "b_k_n_dequant[" << i << "] = " << ck::type_convert<float>(b_k_n_dequant.mData[i]) << std::endl;
|
||||
}
|
||||
|
||||
std::cout << "\n== First 10 values of loaded scale ==\n";
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
std::cout << "b1_k_n(0, " << i << ") = " << ck::type_convert<float>(b1_k_n(0, i)) << std::endl;
|
||||
}
|
||||
|
||||
|
||||
for(int i = 0; i < 10; i++)
|
||||
{
|
||||
std::cout << "data[" << i << "]: " << ck::type_convert<float>(c_m_n_device_result.mData[i]) << std::endl;
|
||||
}
|
||||
|
||||
|
||||
std::cout << "\n== CK: Checking b_k_n(k,n), scale, and dequant ==" << std::endl;
|
||||
std::vector<std::pair<int, int>> check_indices = {
|
||||
{0, 0}, {1, 0}, {0, 1}, {1, 1}, {1023, 4095}, {100, 2000}
|
||||
};
|
||||
|
||||
for (const auto& [k, n] : check_indices) {
|
||||
v_b = ck::type_convert<float>(b_k_n(k, n));
|
||||
float v_scale = ck::type_convert<float>(b1_k_n(k / 1024, n));
|
||||
float v_dequant = ck::type_convert<float>(b_k_n_dequant(k, n));
|
||||
std::cout << "b_k_n(" << k << "," << n << ") = " << v_b
|
||||
<< ", scale = " << v_scale
|
||||
<< ", dequant = " << v_dequant << std::endl;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
ave_time =
|
||||
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50});
|
||||
|
||||
std::size_t flop = 2_uz * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K +
|
||||
sizeof(BDataType) * K * N /
|
||||
(ck::is_same_v<ck::remove_cvref_t<BDataType>, ck::pk_i4_t> ? 2 : 1) +
|
||||
sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_gemm_splitk_example(int argc, char* argv[])
|
||||
{
|
||||
ProblemSizeSplitK problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
// problem_size.M = 8;
|
||||
// problem_size.N = 3072;
|
||||
// problem_size.K = 1024;
|
||||
|
||||
problem_size.M = 8;
|
||||
problem_size.N = 1024;
|
||||
problem_size.K = 4096;
|
||||
|
||||
config.do_verification = true;
|
||||
config.init_method = 1;
|
||||
config.time_kernel = true;
|
||||
|
||||
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config);
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }
|
||||
@@ -1455,7 +1455,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
1,
|
||||
false>(
|
||||
b_scale_grid_desc_bn_ak,
|
||||
make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n,
|
||||
make_multi_index((block_n_id * NPerBlock + b_thread_offset_n) / ScaleBlockN,
|
||||
b_thread_offset_k / ScaleBlockK));
|
||||
|
||||
constexpr auto b_scale_thread_slice_copy_step =
|
||||
|
||||
@@ -1399,6 +1399,28 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
|
||||
});
|
||||
}
|
||||
else if constexpr(is_same<remove_cvref_t<SrcData>, int8_t>::value &&
|
||||
is_same<remove_cvref_t<DstData>, half_t>::value)
|
||||
{
|
||||
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
|
||||
// DstData)
|
||||
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
|
||||
|
||||
// TODO: if SrcData and DstData are vetor type, then static_cast may not compile
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
ck::tensor_operation::element_wise::Scale{type_convert<float>(scale)}(
|
||||
dst_tmp_vector.template AsType<DstData>()(i),
|
||||
type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]));
|
||||
});
|
||||
|
||||
// copy data from dst_tmp_vector into dst_buf
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
||||
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
|
||||
|
||||
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
|
||||
|
||||
@@ -30,6 +30,58 @@ void add_device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
void add_device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v3_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmV2BScale<Row,
|
||||
Col,
|
||||
Row,
|
||||
F16,
|
||||
I8,
|
||||
F16,
|
||||
F16,
|
||||
1,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
void add_device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v4_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmV2BScale<Row,
|
||||
Col,
|
||||
Row,
|
||||
F16,
|
||||
I8,
|
||||
F16,
|
||||
F16,
|
||||
1,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
void add_device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_lat_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmV2BScale<Row,
|
||||
Col,
|
||||
Row,
|
||||
F16,
|
||||
I8,
|
||||
F16,
|
||||
F16,
|
||||
1,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
void add_device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_cmp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmV2BScale<Row,
|
||||
Col,
|
||||
Row,
|
||||
F16,
|
||||
I8,
|
||||
F16,
|
||||
F16,
|
||||
1,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
template <typename ADataType,
|
||||
@@ -80,7 +132,18 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmV2
|
||||
add_device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, int8_t> &&
|
||||
is_same_v<CDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_cmp_default_instances(op_ptrs);
|
||||
add_device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_lat_default_instances(op_ptrs);
|
||||
add_device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v3_default_instances(op_ptrs);
|
||||
add_device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v4_default_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -3,8 +3,15 @@ set(GEMM_B_SCALE_INSTANCES)
|
||||
|
||||
list(APPEND GEMM_B_SCALE_INSTANCES
|
||||
device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp
|
||||
device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v3_default_instance.cpp
|
||||
device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v4_default_instance.cpp
|
||||
device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_cmp_default_instance.cpp
|
||||
device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_lat_default_instance.cpp
|
||||
)
|
||||
|
||||
set_source_files_properties(device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
|
||||
set_source_files_properties(device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v3_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v4_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_cmp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_lat_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
add_instance_library(device_gemm_b_scale_instance ${GEMM_B_SCALE_INSTANCES})
|
||||
@@ -0,0 +1,125 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using I8 = int8_t;
|
||||
using F16 = half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
|
||||
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
#if 0
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_xdl_b_scale_f16_i8_f16_mk_nk_mn_comp_instances = std::tuple<
|
||||
|
||||
#endif
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
|
||||
using device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_cmp_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| CLayout|AData| BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//#########################| | | | Type| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################| | | | | | Type| | | | Operation| Operation| Operation| | | N| K| | | | | |Wave| Wave| | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
//Compute friendly
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 32, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //0
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //1
|
||||
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //3
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //4
|
||||
|
||||
//new Compute friendly kernel
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //35
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false> //36
|
||||
// clang-format on
|
||||
>;
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
|
||||
using device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_lat_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| CLayout|AData| BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//#########################| | | | Type| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################| | | | | | Type| | | | Operation| Operation| Operation| | | N| K| | | | | |Wave| Wave| | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
//Latency friendly
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //5
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //6
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //7
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false> //8
|
||||
// clang-format on
|
||||
>;
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
|
||||
using device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v3_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| CLayout|AData| BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//#########################| | | | Type| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################| | | | | | Type| | | | Operation| Operation| Operation| | | N| K| | | | | |Wave| Wave| | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
// // Memory friendly v3
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 32, 128, 8, 32, 32, 32, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //9
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 16, 128, 8, 16, 16, 16, 4, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //10
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 32, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //11
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 16, 128, 8, 16, 16, 16, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //12
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //13
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //16
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 64, 128, 8, 32, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //17
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 64, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //18
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 128, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //19
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 128, 128, 8, 32, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //20
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 16, 256, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //21
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 32, 256, 128, 8, 32, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //22
|
||||
//new Memory friendly kernel
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 16, 64, 256, 8, 32, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>
|
||||
// clang-format on
|
||||
>;
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
|
||||
using device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v4_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| CLayout|AData| BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//#########################| | | | Type| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################| | | | | | Type| | | | Operation| Operation| Operation| | | N| K| | | | | |Wave| Wave| | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
// Memory friendly v4
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 32, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //23
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 16, 128, 8, 16, 16, 16, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //24
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //25
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //26
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //28
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 64, 128, 8, 32, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //29
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 64, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //30
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 128, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //31
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 128, 128, 8, 32, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false> //32
|
||||
// clang-format on
|
||||
>;
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,32 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_cmp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmV2BScale<Row,
|
||||
Col,
|
||||
Row,
|
||||
F16,
|
||||
I8,
|
||||
F16,
|
||||
F16,
|
||||
1,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_cmp_instances<Intrawave, GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,32 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_lat_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmV2BScale<Row,
|
||||
Col,
|
||||
Row,
|
||||
F16,
|
||||
I8,
|
||||
F16,
|
||||
F16,
|
||||
1,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_lat_instances<Intrawave, GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,32 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v3_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmV2BScale<Row,
|
||||
Col,
|
||||
Row,
|
||||
F16,
|
||||
I8,
|
||||
F16,
|
||||
F16,
|
||||
1,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v3_instances<Intrawave, GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,32 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v4_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmV2BScale<Row,
|
||||
Col,
|
||||
Row,
|
||||
F16,
|
||||
I8,
|
||||
F16,
|
||||
F16,
|
||||
1,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_mem_v4_instances<Intrawave, GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -65,6 +65,27 @@ bool profile_gemm_b_scale_impl(int do_verification,
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
|
||||
if(stride == -1)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return static_cast<std::size_t>(col);
|
||||
}
|
||||
else
|
||||
{
|
||||
return static_cast<std::size_t>(row);
|
||||
}
|
||||
}
|
||||
else
|
||||
return static_cast<std::size_t>(stride);
|
||||
};
|
||||
|
||||
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
|
||||
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
|
||||
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
|
||||
ck::index_t Scale_Stride_BN = ck::is_same_v<BLayout, ck::tensor_layout::gemm::ColumnMajor>
|
||||
? ((K + ScaleBlockK - 1) / ScaleBlockK)
|
||||
: N;
|
||||
@@ -159,21 +180,28 @@ bool profile_gemm_b_scale_impl(int do_verification,
|
||||
{
|
||||
for(int k = 0; k < K; k++)
|
||||
{
|
||||
ck::pk_i4_t i4x2 = b_k_n(k, n).data;
|
||||
int8_t i4 = 0;
|
||||
if(k % 2 == 1)
|
||||
i4 = (i4x2.data >> 0) & 0xf;
|
||||
else
|
||||
i4 = (i4x2.data >> 4) & 0xf;
|
||||
i4 = i4 - 8;
|
||||
v_b = ck::type_convert<float>(i4);
|
||||
if constexpr(is_same_v<BDataType, pk_i4_t>)
|
||||
{
|
||||
ck::pk_i4_t i4x2 = b_k_n(k, n).data;
|
||||
int8_t i4 = 0;
|
||||
if(k % 2 == 1)
|
||||
i4 = (i4x2.data >> 0) & 0xf;
|
||||
else
|
||||
i4 = (i4x2.data >> 4) & 0xf;
|
||||
i4 = i4 - 8;
|
||||
v_b = ck::type_convert<float>(i4);
|
||||
}
|
||||
else if constexpr(is_same_v<BDataType, int8_t>)
|
||||
{
|
||||
v_b = ck::type_convert<float>(b_k_n(k, n));
|
||||
}
|
||||
|
||||
b_k_n_dequant(k, n) = ck::type_convert<float>(v_b) *
|
||||
ck::type_convert<float>(b1_k_n(k / ScaleBlockK, n));
|
||||
}
|
||||
}
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
float,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
@@ -218,7 +246,7 @@ bool profile_gemm_b_scale_impl(int do_verification,
|
||||
}
|
||||
}
|
||||
|
||||
if(is_same_v<BDataType, pk_i4_t> && is_same_v<ADataType, half_t>)
|
||||
if constexpr(is_same_v<BDataType, pk_i4_t> && is_same_v<ADataType, half_t>)
|
||||
{
|
||||
// vector pk_i4x4 permute
|
||||
for(int i = 0; i < N; i++)
|
||||
|
||||
@@ -28,6 +28,7 @@ enum struct GemmDataType
|
||||
F16_F16_F16_F8, // 6
|
||||
F8_F8_BF16, // 7
|
||||
F16_I4_F16, // 8
|
||||
F16_I8_F16, // 8
|
||||
};
|
||||
|
||||
enum struct BScaleBlockTile
|
||||
@@ -46,7 +47,7 @@ int profile_gemm_b_scale(int argc, char* argv[])
|
||||
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
|
||||
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
|
||||
"f16->f8; 7: f8->bf16, "
|
||||
"comp f8; 8: f16@i4)\n");
|
||||
"comp f8; 8: f16@i4 9: f16@i8)\n");
|
||||
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
|
||||
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
|
||||
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
|
||||
@@ -106,6 +107,7 @@ int profile_gemm_b_scale(int argc, char* argv[])
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using I4 = ck::pk_i4_t;
|
||||
using I8 = int8_t;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
@@ -170,10 +172,21 @@ int profile_gemm_b_scale(int argc, char* argv[])
|
||||
return profile(
|
||||
F16{}, I4{}, F16{}, F16{}, F32{}, F16{}, ck::Number<128>{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_NK_MN &&
|
||||
B_scale_block == BScaleBlockTile::K_128)
|
||||
{
|
||||
printf("F16_I8_F16 MK_NK_MN K_128\n");
|
||||
return profile(
|
||||
F16{}, I8{}, F16{}, F16{}, F32{}, F16{}, ck::Number<128>{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "this data_type & layout is not implemented" << std::endl;
|
||||
|
||||
std::cout << "this data_type: " << static_cast<int>(data_type)
|
||||
<< " & layout :" << static_cast<int>(layout)
|
||||
<< " & B_scale_block:" << static_cast<int>(B_scale_block) << " is not implemented"
|
||||
<< std::endl;
|
||||
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user