mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Added Int4 mixed batch gemm support (#1839)
* remove redundant kernels. * added batched_gemm_xdl_fp16int4_b_scale_v3 * Enabled the split K. * added the batched_gemm_b_scale ckProfiler, meet function issue * fix some typo * fix ckProfiler build issue * fix some bugs * updated some debug info * comment some code * Fix * fixed some bugs and refactor the code * fixed a function bug. * formatted files. * formatted * uncommented the ckProfiler CMakeLists * fixed. * fix ckProfiler for batched_gemm_b_scale --------- Co-authored-by: mtgu0705 <mtgu@amd.com> Co-authored-by: aska-0096 <haocwang@amd.com> Co-authored-by: Bartlomiej Kocot <barkocot@amd.com>
This commit is contained in:
@@ -22,3 +22,6 @@ if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp)
|
||||
add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4)
|
||||
endif()
|
||||
|
||||
add_example_executable(example_batched_gemm_xdl_fp16int4_b_scale_v3 batched_gemm_xdl_fp16int4_b_scale_v3.cpp)
|
||||
add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16int4_b_scale_v3)
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
#include <cstdlib>
|
||||
#include <initializer_list>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = ck::pk_i4_t;
|
||||
using BScaleDataType = ck::half_t;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F16;
|
||||
using CDataType = F16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto PermuteA = false;
|
||||
static constexpr bool PermuteB = false;
|
||||
|
||||
static constexpr ck::index_t Scale_Block_N = 1;
|
||||
static constexpr ck::index_t Scale_Block_K = 128;
|
||||
|
||||
static constexpr ck::index_t KPerBlock = 256;
|
||||
|
||||
// clang-format off
|
||||
using DeviceBatchedGemmV2Instance =
|
||||
ck::tensor_operation::device::DeviceBatchedGemm_Xdl_CShuffleV3_BScale<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CElementOp, GemmDefault,
|
||||
256, Scale_Block_N, Scale_Block_K,
|
||||
16, 64,
|
||||
KPerBlock, 8, 32,
|
||||
16, 16,
|
||||
1, 1,
|
||||
S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 32, 32, 0,
|
||||
1, 1, S<1, 16, 1, 8>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, CDataType, CDataType, PermuteA, PermuteB>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceBatchedGemmInstance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
#include "run_batched_gemm_example_fp16int4_b_scale.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_batched_gemm_fp16_int4_b_scale_example(argc, argv); }
|
||||
@@ -0,0 +1,578 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include <random>
|
||||
|
||||
#pragma once
|
||||
struct ProblemSize final
|
||||
{
|
||||
ck::index_t M = 128;
|
||||
ck::index_t N = 128;
|
||||
ck::index_t K = 384;
|
||||
|
||||
ck::index_t stride_A = K;
|
||||
ck::index_t stride_B = K;
|
||||
ck::index_t stride_C = N;
|
||||
|
||||
ck::index_t batch_stride_A = M * K;
|
||||
ck::index_t batch_stride_B = K * N;
|
||||
ck::index_t batch_stride_C = M * N;
|
||||
|
||||
// Batched Gemm count
|
||||
ck::index_t batch_count = 2;
|
||||
|
||||
// Split K count
|
||||
ck::index_t KBatch = 1;
|
||||
};
|
||||
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
inline __host__ __device__ constexpr double get_rtol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, double>)
|
||||
{
|
||||
return 1e-6;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::half_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
{
|
||||
return 1e-1; // 240 and 224 are acceptable
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
{
|
||||
return 1.5e-1; // 57344 and 49152 are acceptable
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
inline __host__ __device__ constexpr double get_atol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, double>)
|
||||
{
|
||||
return 1e-6;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::half_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
{
|
||||
return 16.1; // 240 and 224 are acceptable
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
{
|
||||
return 8192.1; // 57344 and 49152 are acceptable
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
using namespace ck::literals;
|
||||
|
||||
auto& [M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batch_stride_C,
|
||||
batch_count,
|
||||
KBatch] = problem_size;
|
||||
|
||||
auto f_host_tensor_descriptor = [](std::size_t batch_count_,
|
||||
std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
std::size_t batch_stride,
|
||||
auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, 1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K;
|
||||
ck::index_t batch_BScale_Stride =
|
||||
((K + Scale_Block_K - 1) / Scale_Block_K) * ((N + Scale_Block_N - 1) / Scale_Block_N);
|
||||
|
||||
Tensor<ADataType> a_g_m_k(
|
||||
f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, ALayout{}));
|
||||
Tensor<BDataType> b_g_k_n(
|
||||
f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, BLayout{}));
|
||||
Tensor<BDataType> b_g_k_n_permute(
|
||||
f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, BLayout{}));
|
||||
Tensor<BScaleDataType> b1_g_k_n(
|
||||
f_host_tensor_descriptor(batch_count,
|
||||
(K + Scale_Block_K - 1) / Scale_Block_K,
|
||||
(N + Scale_Block_N - 1) / Scale_Block_N,
|
||||
Scale_Stride_BN,
|
||||
batch_BScale_Stride,
|
||||
BLayout{}));
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b_g_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
b1_g_k_n.GenerateTensorValue(GeneratorTensor_1<BScaleDataType>{1});
|
||||
break;
|
||||
case 1:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
b1_g_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
|
||||
break;
|
||||
case 2:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
b1_g_k_n.GenerateTensorValue(GeneratorTensor_1<BScaleDataType>{1});
|
||||
break;
|
||||
case 3:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_g_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
b1_g_k_n.GenerateTensorValue(GeneratorTensor_1<BScaleDataType>{1});
|
||||
break;
|
||||
case 4:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b_g_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
b1_g_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
|
||||
break;
|
||||
case 5:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
b1_g_k_n.GenerateTensorValue(GeneratorTensor_1<BScaleDataType>{1});
|
||||
break;
|
||||
default:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.5, 0.5});
|
||||
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
b1_g_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
|
||||
}
|
||||
|
||||
Tensor<CDataType> c_g_m_n_host_result(
|
||||
f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{}));
|
||||
Tensor<CDataType> c_g_m_n_device_result(
|
||||
f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{}));
|
||||
|
||||
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
|
||||
std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
|
||||
std::cout << "b1_g_k_n: " << b1_g_k_n.mDesc << std::endl;
|
||||
std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_g_k_n_device_buf(sizeof(BDataType) * b_g_k_n_permute.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b1_g_scale_device_buf(sizeof(BScaleDataType) * b1_g_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_g_m_n_device_buf(sizeof(CDataType) *
|
||||
c_g_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
printf("a_g_m_k size: %zu, b_g_k_n size: %zu, b1_g_k_n size: %zu, c_g_m_n size: %zu\n",
|
||||
a_g_m_k.mDesc.GetElementSpaceSize(),
|
||||
b_g_k_n_permute.mDesc.GetElementSpaceSize(),
|
||||
b1_g_k_n.mDesc.GetElementSpaceSize(),
|
||||
c_g_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
// weight permute
|
||||
if constexpr(PermuteB)
|
||||
{
|
||||
printf("Permute B\n");
|
||||
int K1 = KPerBlock;
|
||||
int K0 = K / KPerBlock;
|
||||
|
||||
// int K0, N, K1
|
||||
for(int bs = 0; bs < batch_count; bs++)
|
||||
{
|
||||
for(int j = 0; j < K0; j++)
|
||||
{
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int jj = 0; jj < K1; jj++)
|
||||
{
|
||||
b_g_k_n_permute(bs * batch_stride_B + j * N * K1 + i * K1 + jj) =
|
||||
b_g_k_n(bs * batch_stride_B + i * K + (j * K1 + jj));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
b_g_k_n_permute = b_g_k_n;
|
||||
}
|
||||
|
||||
// vector pk_i4x4 permute
|
||||
for(int bs = 0; bs < batch_count; bs++)
|
||||
{
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int j = 0; j < K; j += 8)
|
||||
{
|
||||
int input[8];
|
||||
|
||||
for(int k = 0; k < 4; k++)
|
||||
{
|
||||
int i4x2 = b_g_k_n_permute(bs, j + k * 2, i).data;
|
||||
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
|
||||
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
|
||||
}
|
||||
|
||||
// permute 01234567->20643175
|
||||
{
|
||||
int hi = input[2];
|
||||
int lo = input[0];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_g_k_n_permute(bs, j + 0, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int hi = input[6];
|
||||
int lo = input[4];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_g_k_n_permute(bs, j + 2, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int hi = input[3];
|
||||
int lo = input[1];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_g_k_n_permute(bs, j + 4, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int hi = input[7];
|
||||
int lo = input[5];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_g_k_n_permute(bs, j + 6, i) = i4x2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
|
||||
b_g_k_n_device_buf.ToDevice(b_g_k_n_permute.mData.data());
|
||||
b1_g_scale_device_buf.ToDevice(b1_g_k_n.mData.data());
|
||||
DeviceMem workspace;
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceBatchedGemmV2Instance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
float ave_time = 0;
|
||||
|
||||
auto argument =
|
||||
gemm.MakeArgument(static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_g_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_g_m_n_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
Scale_Stride_BN,
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batch_stride_C,
|
||||
batch_BScale_Stride,
|
||||
static_cast<BScaleDataType*>(b1_g_scale_device_buf.GetDeviceBuffer()),
|
||||
batch_count, // batch count
|
||||
KBatch, // split K count
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
Tensor<float> b_g_k_n_dequant({batch_count, K, N});
|
||||
if(config.do_verification)
|
||||
{
|
||||
float v_b = 0;
|
||||
for(int bs = 0; bs < batch_count; bs++)
|
||||
{
|
||||
for(int n = 0; n < N; n++)
|
||||
{
|
||||
for(int k = 0; k < K; k++)
|
||||
{
|
||||
ck::pk_i4_t i4x2 = b_g_k_n(bs, k, n).data;
|
||||
int8_t i4 = 0;
|
||||
if(k % 2 == 1)
|
||||
i4 = (i4x2.data >> 0) & 0xf;
|
||||
else
|
||||
i4 = (i4x2.data >> 4) & 0xf;
|
||||
i4 = i4 - 8;
|
||||
v_b = ck::type_convert<float>(i4);
|
||||
|
||||
b_g_k_n_dequant(bs, k, n) =
|
||||
ck::type_convert<float>(v_b) *
|
||||
ck::type_convert<float>(b1_g_k_n(bs, k / Scale_Block_K, n / Scale_Block_N));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto ref_gemm = ReferenceBatchedGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(a_g_m_k,
|
||||
b_g_k_n_dequant,
|
||||
c_g_m_n_host_result,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0});
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
|
||||
c_g_m_n_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
|
||||
|
||||
pass &= ck::utils::check_err(c_g_m_n_device_result,
|
||||
c_g_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
get_rtol<CDataType>(),
|
||||
get_atol<CDataType>());
|
||||
}
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
|
||||
std::size_t flop = 2_uz * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K +
|
||||
sizeof(BDataType) * K * N /
|
||||
(ck::is_same_v<ck::remove_cvref_t<BDataType>, ck::pk_i4_t> ? 2 : 1) +
|
||||
sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
#if 0
|
||||
// print A matrix
|
||||
printf("A matrix:\n");
|
||||
for(int bs = 0; bs < batch_count; bs++)
|
||||
{
|
||||
printf("batch %d -> Address: %p\n", bs, static_cast<void*>(&a_g_m_k(bs, 0, 0)));
|
||||
for(int i = 0; i < M; i++)
|
||||
{
|
||||
for(int j = 0; j < K; j++)
|
||||
{
|
||||
printf("%.2f,", static_cast<float>(a_g_m_k(bs, i, j)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
// print B matrix original
|
||||
printf("B matrix original:\n");
|
||||
for(int bs = 0; bs < batch_count; bs++)
|
||||
{
|
||||
printf("batch %d -> Address: %p\n", bs, static_cast<void*>(&b_g_k_n(bs, 0, 0)));
|
||||
for(int n = 0; n < N; n++)
|
||||
{
|
||||
for(int k = 0; k < K; k++)
|
||||
{
|
||||
ck::pk_i4_t i4x2 = b_g_k_n(bs, k, n).data;
|
||||
int8_t i4 = 0;
|
||||
if(k % 2 == 1)
|
||||
i4 = (i4x2.data >> 0) & 0xf;
|
||||
else
|
||||
i4 = (i4x2.data >> 4) & 0xf;
|
||||
i4 = i4 - 8;
|
||||
printf("%d,", static_cast<int>(i4));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
// print B matrix
|
||||
printf("B matrix:\n");
|
||||
for(int bs = 0; bs < batch_count; bs++)
|
||||
{
|
||||
printf("batch %d -> Address: %p\n", bs, static_cast<void*>(&b_g_k_n_dequant(bs, 0, 0)));
|
||||
for(int i = 0; i < K; i++)
|
||||
{
|
||||
for(int j = 0; j < N; j++)
|
||||
{
|
||||
printf("%.2f, ", static_cast<float>(b_g_k_n_dequant(bs, i, j)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
// print B scale matrix
|
||||
printf("B Scale matrix:\n");
|
||||
for(int bs = 0; bs < batch_count; bs++)
|
||||
{
|
||||
printf("batch %d -> Address: %p\n", bs, static_cast<void*>(&b1_g_k_n(bs, 0, 0)));
|
||||
for(int i = 0; i < (K + Scale_Block_K - 1) / Scale_Block_K; i++)
|
||||
{
|
||||
for(int j = 0; j < (N + Scale_Block_N - 1) / Scale_Block_N; j++)
|
||||
{
|
||||
printf("%.2f, ", static_cast<float>(b1_g_k_n(bs, i, j)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
// print C matrix
|
||||
printf("C matrix:\n");
|
||||
for(int bs = 0; bs < batch_count; bs++)
|
||||
{
|
||||
printf(
|
||||
"batch %d -> Address: %p\n", bs, static_cast<void*>(&c_g_m_n_device_result(bs, 0, 0)));
|
||||
for(int i = 0; i < M; i++)
|
||||
{
|
||||
for(int j = 0; j < N; j++)
|
||||
{
|
||||
printf("%.2f, ", static_cast<float>(c_g_m_n_device_result(bs, i, j)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
printf("C reference matrix:\n");
|
||||
for(int bs = 0; bs < batch_count; bs++)
|
||||
{
|
||||
printf("batch %d -> Address: %p\n", bs, static_cast<void*>(&c_g_m_n_host_result(bs, 0, 0)));
|
||||
for(int i = 0; i < M; i++)
|
||||
{
|
||||
for(int j = 0; j < N; j++)
|
||||
{
|
||||
printf("%.2f, ", static_cast<float>(c_g_m_n_host_result(bs, i, j)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_batched_gemm_fp16_int4_b_scale_example(int argc, char* argv[])
|
||||
{
|
||||
ProblemSize problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
std::mt19937 gen(11939);
|
||||
std::uniform_int_distribution<int> dis(0, 15);
|
||||
|
||||
problem_size.M = 128 * (dis(gen) + 1);
|
||||
problem_size.N = 128 * (dis(gen) + 1);
|
||||
problem_size.K = 256 * (dis(gen) + 2);
|
||||
|
||||
problem_size.batch_count = 2;
|
||||
|
||||
if(argc == 4)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc >= 7)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
|
||||
problem_size.M = std::stoi(argv[4]);
|
||||
problem_size.N = std::stoi(argv[5]);
|
||||
problem_size.K = std::stoi(argv[6]);
|
||||
|
||||
if(argc >= 8)
|
||||
{
|
||||
problem_size.batch_count = std::stoi(argv[7]);
|
||||
}
|
||||
|
||||
if(argc >= 9)
|
||||
{
|
||||
problem_size.KBatch = std::stoi(argv[8]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
problem_size.stride_A = problem_size.K;
|
||||
problem_size.stride_B = problem_size.K;
|
||||
problem_size.stride_C = problem_size.N;
|
||||
|
||||
problem_size.batch_stride_A = problem_size.M * problem_size.K;
|
||||
problem_size.batch_stride_B = problem_size.K * problem_size.N;
|
||||
problem_size.batch_stride_C = problem_size.M * problem_size.N;
|
||||
|
||||
return run_batched_gemm(problem_size, config);
|
||||
}
|
||||
@@ -44,6 +44,48 @@ struct DeviceBatchedGemm : public BaseOperator
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename BScaleType,
|
||||
typename CDataType,
|
||||
index_t ScaleBlockN,
|
||||
index_t ScaleBlockK,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceBatchedGemmV2BScale : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
ck::index_t StrideScaleB,
|
||||
ck::index_t BatchStrideA,
|
||||
ck::index_t BatchStrideB,
|
||||
ck::index_t BatchStrideC,
|
||||
ck::index_t BatchStrideScaleB,
|
||||
const void* p_b_scale,
|
||||
ck::index_t Batch,
|
||||
ck::index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
|
||||
virtual bool GetPermuteB() = 0;
|
||||
virtual ck::index_t GetKPerBlock() = 0;
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -37,7 +37,7 @@ __global__ void
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
@@ -70,7 +70,7 @@ __global__ void
|
||||
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
@@ -638,45 +638,45 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
|
||||
__device__ SplitKBatchOffset(Argument& karg)
|
||||
__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
|
||||
{
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
|
||||
a_k_split_offset = k_id * karg.KRead / APackedSize;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
|
||||
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
|
||||
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
if constexpr(!PermuteB)
|
||||
{
|
||||
b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
|
||||
b_k_split_offset = k_id * karg.KRead / BPackedSize;
|
||||
}
|
||||
else
|
||||
{
|
||||
const int k0_offset = karg.KRead * karg.N;
|
||||
b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
|
||||
b_k_split_offset = k_id * k0_offset / BPackedSize;
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate B scale offset
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideB;
|
||||
scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideB;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK);
|
||||
scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK);
|
||||
}
|
||||
|
||||
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
|
||||
if(k_id < (karg.KBatch - 1))
|
||||
{
|
||||
karg.K = karg.KRead;
|
||||
}
|
||||
@@ -687,7 +687,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
if(karg.IsReduceAdd())
|
||||
{
|
||||
c_reduce_offset = blockIdx.z * karg.M * karg.N;
|
||||
c_reduce_offset = k_id * karg.M * karg.N;
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
|
||||
void add_device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceBatchedGemmV2BScale<Row,
|
||||
Col,
|
||||
Row,
|
||||
F16,
|
||||
I4,
|
||||
F16,
|
||||
F16,
|
||||
1,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
index_t ScaleBlockK>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatchedGemmV2BScale<
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
CDataType,
|
||||
1,
|
||||
ScaleBlockK,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>>
|
||||
{
|
||||
using DeviceOp = DeviceBatchedGemmV2BScale<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
CDataType,
|
||||
1,
|
||||
ScaleBlockK,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, pk_i4_t> &&
|
||||
is_same_v<CDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,10 @@
|
||||
# ONLY XDL_KERNELS
|
||||
set(BATCHED_GEMM_B_SCALE_INSTANCES)
|
||||
|
||||
list(APPEND BATCHED_GEMM_B_SCALE_INSTANCES
|
||||
device_batched_gemm_b_scale_xdl_f16_i4_f16/device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp
|
||||
)
|
||||
|
||||
set_source_files_properties(device_batched_gemm_b_scale_xdl_f16_i4_f16/device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
|
||||
add_instance_library(device_batched_gemm_b_scale_instance ${BATCHED_GEMM_B_SCALE_INSTANCES})
|
||||
@@ -0,0 +1,95 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using I4 = pk_i4_t;
|
||||
using F16 = half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
|
||||
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
|
||||
using device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| CLayout|AData| BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//#########################| | | | Type| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################| | | | | | Type| | | | Operation| Operation| Operation| | | N| K| | | | | |Wave| Wave| | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
//Compute friendly
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 32, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //0
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //1
|
||||
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //3
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //4
|
||||
|
||||
//Latency friendly
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //5
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //6
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //7
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //8
|
||||
|
||||
// Memory friendly v3
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 32, 128, 8, 32, 32, 32, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //9
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 16, 128, 8, 16, 16, 16, 4, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //10
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 32, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //11
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 16, 128, 8, 16, 16, 16, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //12
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //13
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //16
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 64, 128, 8, 32, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //17
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 64, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //18
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 128, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //19
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 128, 128, 8, 32, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //20
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 16, 256, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //21
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 32, 256, 128, 8, 32, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //22
|
||||
|
||||
// Memory friendly v4
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 32, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //23
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 16, 128, 8, 16, 16, 16, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //24
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //25
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //26
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //28
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 64, 128, 8, 32, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //29
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 64, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //30
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 128, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //31
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 128, 128, 8, 32, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //32
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 16, 256, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //33
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 32, 256, 128, 8, 32, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //34
|
||||
|
||||
//new Compute friendly kernel
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //35
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //36
|
||||
|
||||
//new Memory friendly kernel
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 16, 64, 256, 8, 32, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false> //37
|
||||
// clang-format on
|
||||
>;
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceBatchedGemmV2BScale<Row,
|
||||
Col,
|
||||
Row,
|
||||
F16,
|
||||
I4,
|
||||
F16,
|
||||
F16,
|
||||
1,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances<Intrawave,
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -48,55 +48,51 @@ using device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances = std::tuple<
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
//Compute friendly
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 32, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 32, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 32, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //0
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //1
|
||||
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //3
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //4
|
||||
|
||||
//Latency friendly
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //5
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //6
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //7
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //8
|
||||
|
||||
// Memory friendly v3
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 32, 128, 8, 32, 32, 32, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 16, 128, 8, 16, 16, 16, 4, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 32, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 16, 128, 8, 16, 16, 16, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 64, 128, 8, 32, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 64, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 128, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 128, 128, 8, 32, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 16, 256, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 32, 256, 128, 8, 32, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 32, 128, 8, 32, 32, 32, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //9
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 16, 128, 8, 16, 16, 16, 4, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //10
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 32, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //11
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 16, 128, 8, 16, 16, 16, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //12
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //13
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //16
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 64, 128, 8, 32, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //17
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 64, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //18
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 128, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //19
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 128, 128, 8, 32, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //20
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 16, 256, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //21
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 32, 256, 128, 8, 32, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //22
|
||||
|
||||
// Memory friendly v4
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 32, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 16, 128, 8, 16, 16, 16, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 64, 128, 8, 32, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 64, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 128, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 128, 128, 8, 32, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 16, 256, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 32, 256, 128, 8, 32, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 32, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //23
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 16, 128, 8, 16, 16, 16, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //24
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //25
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //26
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //28
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 64, 128, 8, 32, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //29
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 64, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //30
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 128, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //31
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 128, 128, 8, 32, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //32
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 16, 256, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //33
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 32, 256, 128, 8, 32, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //34
|
||||
|
||||
//new Compute friendly kernel
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //35
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //36
|
||||
|
||||
//new Memory friendly kernel
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 16, 64, 256, 8, 32, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 16, 64, 256, 8, 32, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false> //37
|
||||
// clang-format on
|
||||
>;
|
||||
} // namespace instance
|
||||
|
||||
488
profiler/include/profiler/profile_batched_gemm_b_scale_impl.hpp
Normal file
488
profiler/include/profiler/profile_batched_gemm_b_scale_impl.hpp
Normal file
@@ -0,0 +1,488 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <typeinfo>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_b_scale.hpp"
|
||||
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
index_t ScaleBlockK,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
bool profile_batched_gemm_b_scale_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int StrideA,
|
||||
int StrideB,
|
||||
int StrideC,
|
||||
int BatchStrideA,
|
||||
int BatchStrideB,
|
||||
int BatchStrideC,
|
||||
int BatchStrideScaleB,
|
||||
int BatchSize,
|
||||
int KBatch,
|
||||
int n_warmup,
|
||||
int n_iter,
|
||||
uint64_t rotating = 0)
|
||||
{
|
||||
bool pass = true;
|
||||
|
||||
auto f_host_tensor_descriptor = [](std::size_t batch_count,
|
||||
std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
std::size_t batch_stride,
|
||||
auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
ck::index_t Scale_Stride_BN = ck::is_same_v<BLayout, ck::tensor_layout::gemm::ColumnMajor>
|
||||
? ((K + ScaleBlockK - 1) / ScaleBlockK)
|
||||
: N;
|
||||
|
||||
Tensor<ADataType> a_g_m_k(
|
||||
f_host_tensor_descriptor(BatchSize, M, K, StrideA, BatchStrideA, ALayout{}));
|
||||
Tensor<BDataType> b_g_k_n(
|
||||
f_host_tensor_descriptor(BatchSize, K, N, StrideB, BatchStrideB, BLayout{}));
|
||||
Tensor<BDataType> b_g_k_n_permute(
|
||||
f_host_tensor_descriptor(BatchSize, K, N, StrideB, BatchStrideB, BLayout{}));
|
||||
Tensor<BScaleDataType> b1_g_k_n(f_host_tensor_descriptor(
|
||||
BatchSize,
|
||||
(K + ScaleBlockK - 1) / ScaleBlockK, // K direction group size is ScaleBlockK
|
||||
N, // N direction group size is 1
|
||||
Scale_Stride_BN,
|
||||
BatchStrideScaleB,
|
||||
BLayout{}));
|
||||
Tensor<CDataType> c_g_m_n_host_result(
|
||||
f_host_tensor_descriptor(BatchSize, M, N, StrideC, BatchStrideC, CLayout{}));
|
||||
Tensor<CDataType> c_g_m_n_device_result(
|
||||
f_host_tensor_descriptor(BatchSize, M, N, StrideC, BatchStrideC, CLayout{}));
|
||||
|
||||
int total_gemm_needed = a_g_m_k.GetElementSpaceSizeInBytes() +
|
||||
b_g_k_n.GetElementSpaceSizeInBytes() +
|
||||
b1_g_k_n.GetElementSpaceSizeInBytes();
|
||||
|
||||
int rotating_count = std::max(
|
||||
1,
|
||||
std::min(n_iter,
|
||||
static_cast<int>(std::ceil(static_cast<double>(rotating) / total_gemm_needed))));
|
||||
|
||||
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
|
||||
std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
|
||||
std::cout << "b1_g_k_n: " << b1_g_k_n.mDesc << std::endl;
|
||||
std::cout << "c_g_m_n: " << c_g_m_n_device_result.mDesc << std::endl;
|
||||
std::cout << "rotating count: " << rotating_count << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-1, 2});
|
||||
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 2});
|
||||
b1_g_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
|
||||
break;
|
||||
case 2:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_g_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
b1_g_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
|
||||
break;
|
||||
default:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
b1_g_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
|
||||
}
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
const auto a_element_op = AElementOp{};
|
||||
const auto b_element_op = BElementOp{};
|
||||
const auto c_element_op = CElementOp{};
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n_permute.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b1_device_buf(sizeof(BScaleDataType) * b1_g_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a_g_m_k.mData.data());
|
||||
b1_device_buf.ToDevice(b1_g_k_n.mData.data());
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemmV2BScale<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
CDataType,
|
||||
1,
|
||||
ScaleBlockK,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
// Run reference GEMM
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<float> b_g_k_n_dequant({K, N});
|
||||
|
||||
float v_b = 0;
|
||||
for(int bs = 0; bs < BatchSize; bs++)
|
||||
{
|
||||
for(int n = 0; n < N; n++)
|
||||
{
|
||||
for(int k = 0; k < K; k++)
|
||||
{
|
||||
ck::pk_i4_t i4x2 = b_g_k_n(bs, k, n).data;
|
||||
int8_t i4 = 0;
|
||||
if(k % 2 == 1)
|
||||
i4 = (i4x2.data >> 0) & 0xf;
|
||||
else
|
||||
i4 = (i4x2.data >> 4) & 0xf;
|
||||
i4 = i4 - 8;
|
||||
v_b = ck::type_convert<float>(i4);
|
||||
|
||||
b_g_k_n_dequant(bs, k, n) =
|
||||
ck::type_convert<float>(v_b) *
|
||||
ck::type_convert<float>(b1_g_k_n(bs, k / ScaleBlockK, n));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
ComputeDataType>;
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(a_g_m_k,
|
||||
b_g_k_n_dequant,
|
||||
c_g_m_n_host_result,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
}
|
||||
|
||||
std::string best_op_name;
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
float best_kbatch = 0;
|
||||
|
||||
// profile device GEMM instances
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
const int KPerBlock = op_ptr->GetKPerBlock();
|
||||
|
||||
if(op_ptr->GetPermuteB())
|
||||
{
|
||||
int K1 = KPerBlock;
|
||||
int K0 = K / KPerBlock;
|
||||
|
||||
// int K0, N, K1
|
||||
for(int bs = 0; bs < BatchSize; bs++)
|
||||
{
|
||||
for(int j = 0; j < K0; j++)
|
||||
{
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int jj = 0; jj < K1; jj++)
|
||||
{
|
||||
b_g_k_n_permute(bs * BatchStrideB + j * N * K1 + i * K1 + jj) =
|
||||
b_g_k_n(bs * BatchStrideB + i * K + (j * K1 + jj));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(is_same_v<BDataType, pk_i4_t> && is_same_v<ADataType, half_t>)
|
||||
{
|
||||
// vector pk_i4x4 permute
|
||||
for(int bs = 0; bs < BatchSize; bs++)
|
||||
{
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int j = 0; j < K; j += 8)
|
||||
{
|
||||
int input[8];
|
||||
|
||||
for(int k = 0; k < 4; k++)
|
||||
{
|
||||
int i4x2 = b_g_k_n_permute(bs, j + k * 2, i).data;
|
||||
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
|
||||
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
|
||||
}
|
||||
|
||||
// permute 01234567->20643175
|
||||
{
|
||||
int hi = input[2];
|
||||
int lo = input[0];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_g_k_n_permute(bs, j + 0, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int hi = input[6];
|
||||
int lo = input[4];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_g_k_n_permute(bs, j + 2, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int hi = input[3];
|
||||
int lo = input[1];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_g_k_n_permute(bs, j + 4, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int hi = input[7];
|
||||
int lo = input[5];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_g_k_n_permute(bs, j + 6, i) = i4x2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
b_g_k_n_permute = b_g_k_n;
|
||||
}
|
||||
|
||||
b_device_buf.ToDevice(b_g_k_n_permute.mData.data());
|
||||
|
||||
std::vector<int> kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38};
|
||||
|
||||
if(KBatch > 0)
|
||||
{
|
||||
kbatch_list = {KBatch};
|
||||
}
|
||||
|
||||
for(std::size_t i = 0; i < kbatch_list.size(); i++)
|
||||
{
|
||||
auto kbatch_curr = kbatch_list[i];
|
||||
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(
|
||||
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
Scale_Stride_BN,
|
||||
BatchStrideA,
|
||||
BatchStrideB,
|
||||
BatchStrideC,
|
||||
BatchStrideScaleB,
|
||||
static_cast<BScaleDataType*>(b1_device_buf.GetDeviceBuffer()),
|
||||
BatchSize, // Batch count
|
||||
kbatch_curr, // Split K count
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
|
||||
// re-init C to zero before profiling next kernel
|
||||
c_device_buf.SetZero();
|
||||
|
||||
// invoker_ptr->Run(argument_ptr.get(),
|
||||
// StreamConfig{nullptr, false, 0, n_warmup, n_iter});
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false, 0});
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
|
||||
|
||||
#if defined CK_ENABLE_FP8
|
||||
// set softer tolerances for fp8
|
||||
if constexpr(is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> ||
|
||||
is_same_v<CDataType, f8_t>)
|
||||
{
|
||||
std::string msg = "Error: Incorrect results!";
|
||||
double rtol = 1e-1;
|
||||
double atol = 1e-1;
|
||||
pass =
|
||||
pass & ck::utils::check_err(
|
||||
c_g_m_n_device_result, c_g_m_n_host_result, msg, rtol, atol);
|
||||
}
|
||||
else
|
||||
{
|
||||
#endif
|
||||
pass =
|
||||
pass & ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result);
|
||||
#if defined CK_ENABLE_FP8
|
||||
}
|
||||
#endif
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "c_host : ", c_g_m_n_host_result.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "c_device: ", c_g_m_n_device_result.mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
float ave_time = invoker_ptr->Run(argument_ptr.get(),
|
||||
StreamConfig{nullptr,
|
||||
time_kernel,
|
||||
0,
|
||||
n_warmup,
|
||||
n_iter,
|
||||
rotating_count > 1,
|
||||
rotating_count});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K * BatchSize;
|
||||
|
||||
static constexpr index_t BPackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
std::size_t num_btype = sizeof(ADataType) * M * K +
|
||||
sizeof(BDataType) * K * N / BPackedSize +
|
||||
sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
|
||||
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch "
|
||||
<< kbatch_curr << std::endl;
|
||||
|
||||
if(tflops > best_tflops && ave_time > 1e-10)
|
||||
{
|
||||
best_op_name = op_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
best_kbatch = kbatch_curr;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << op_ptr->GetTypeString() << " does not support this problem"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_same<CDataType, float>::value)
|
||||
{
|
||||
std::cout << "Best Perf for datatype = f32";
|
||||
}
|
||||
else if constexpr(is_same<CDataType, half_t>::value)
|
||||
{
|
||||
std::cout << "Best Perf for datatype = f16";
|
||||
}
|
||||
else if constexpr(is_same<CDataType, bhalf_t>::value)
|
||||
{
|
||||
std::cout << "Best Perf for datatype = bf16";
|
||||
}
|
||||
else if constexpr(is_same<CDataType, int8_t>::value)
|
||||
{
|
||||
std::cout << "Best Perf for datatype = int8";
|
||||
}
|
||||
|
||||
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
std::cout << " ALayout = RowMajor";
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value)
|
||||
{
|
||||
std::cout << " ALayout = ColumnMajor";
|
||||
}
|
||||
|
||||
if constexpr(is_same<BLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
std::cout << " BLayout = RowMajor";
|
||||
}
|
||||
else if constexpr(is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value)
|
||||
{
|
||||
std::cout << " BLayout = ColumnMajor";
|
||||
}
|
||||
|
||||
std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
|
||||
<< " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << best_kbatch
|
||||
<< " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
|
||||
<< " GB/s, " << best_op_name << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace ck
|
||||
@@ -59,6 +59,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_splitk.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_universal.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_b_scale.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_batched_gemm_b_scale.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_universal_batched.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_universal_reduce.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_universal_streamk.cpp)
|
||||
@@ -143,6 +144,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_b_scale_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_b_scale_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_batched_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_reduce_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_streamk_instance)
|
||||
|
||||
200
profiler/src/profile_batched_gemm_b_scale.cpp
Normal file
200
profiler/src/profile_batched_gemm_b_scale.cpp
Normal file
@@ -0,0 +1,200 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <initializer_list>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
|
||||
#include "profiler/profile_batched_gemm_b_scale_impl.hpp"
|
||||
#include "profiler_operation_registry.hpp"
|
||||
|
||||
enum struct GemmMatrixLayout
|
||||
{
|
||||
MK_KN_MN, // 0
|
||||
MK_NK_MN, // 1
|
||||
KM_KN_MN, // 2
|
||||
KM_NK_MN, // 3
|
||||
};
|
||||
|
||||
enum struct GemmDataType
|
||||
{
|
||||
F32_F32_F32, // 0
|
||||
F16_F16_F16, // 1
|
||||
BF16_BF16_BF16, // 2
|
||||
INT8_INT8_INT8, // 3
|
||||
F8_F16_F16, // 4
|
||||
F16_F8_F16, // 5
|
||||
F16_F16_F16_F8, // 6
|
||||
F8_F8_BF16, // 7
|
||||
F16_I4_F16, // 8
|
||||
};
|
||||
|
||||
enum struct BScaleBlockTile
|
||||
{
|
||||
K_64, // 0
|
||||
K_128, // 1
|
||||
};
|
||||
|
||||
#define OP_NAME "batched_gemm_b_scale"
|
||||
#define OP_DESC "Int4-dequant batched GEMM"
|
||||
|
||||
int profile_batched_gemm_b_scale(int argc, char* argv[])
|
||||
{
|
||||
if(argc != 17 && argc != 20)
|
||||
{
|
||||
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
|
||||
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
|
||||
"f16->f8; 7: f8->bf16, "
|
||||
"comp f8; 8: f16@i4)\n");
|
||||
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
|
||||
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
|
||||
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
|
||||
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
|
||||
printf("arg4: B scale block tile (0: 64, 1: 128):\n");
|
||||
printf("arg5: verification (0: no; 1: yes)\n");
|
||||
printf("arg6: initialization (0: no init; 1: integer value; 2: decimal value)\n");
|
||||
printf("arg7: print tensor value (0: no; 1: yes)\n");
|
||||
printf("arg8: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg9 to 15: M, N, K, StrideA, StrideB, StrideC, BatachCount\n");
|
||||
printf("arg16: split k into mulitiple batch\n");
|
||||
printf("optional:\n");
|
||||
printf("arg17: number of warm-up cycles (default 1)\n");
|
||||
printf("arg18: number of iterations (default 10)\n");
|
||||
printf("arg19: memory for rotating buffer (default 0, size in MB)\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
printf("Start profiling\n");
|
||||
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
|
||||
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
|
||||
const auto B_scale_block = static_cast<BScaleBlockTile>(std::stoi(argv[4]));
|
||||
const bool do_verification = std::stoi(argv[5]);
|
||||
const int init_method = std::stoi(argv[6]);
|
||||
const bool do_log = std::stoi(argv[7]);
|
||||
const bool time_kernel = std::stoi(argv[8]);
|
||||
|
||||
const int M = std::stoi(argv[9]);
|
||||
const int N = std::stoi(argv[10]);
|
||||
const int K = std::stoi(argv[11]);
|
||||
|
||||
const int StrideA = std::stoi(argv[12]);
|
||||
const int StrideB = std::stoi(argv[13]);
|
||||
const int StrideC = std::stoi(argv[14]);
|
||||
|
||||
const int BatchStrideA = M * N;
|
||||
const int BatchStrideB = N * K;
|
||||
const int BatchStrideC = M * N;
|
||||
const int BatchStrideScaleB =
|
||||
(K + static_cast<int>(B_scale_block) - 1) / static_cast<int>(B_scale_block) * N;
|
||||
const int BatchSize = std::stoi(argv[15]);
|
||||
const int KBatch = std::stoi(argv[16]);
|
||||
|
||||
printf("M:%d, N:%d, K:%d, StrideA:%d, StrideB:%d, StrideC:%d, BatchStrideA:%d, "
|
||||
"BatchStrideB:%d, BatchStrideC:%d, BatchStrideScaleB:%d, BatchSize:%d, KBatch:%d,\n",
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
BatchStrideA,
|
||||
BatchStrideB,
|
||||
BatchStrideC,
|
||||
BatchStrideScaleB,
|
||||
BatchSize,
|
||||
KBatch);
|
||||
|
||||
int n_warmup = 1;
|
||||
int n_iter = 10;
|
||||
uint64_t rotating = 0;
|
||||
if(argc == 20)
|
||||
{
|
||||
n_warmup = std::stoi(argv[17]);
|
||||
n_iter = std::stoi(argv[18]);
|
||||
rotating = std::stoull(argv[19]) * 1024 * 1024;
|
||||
|
||||
printf("n_warmup:%d, n_iter:%d, rotating:%lu\n", n_warmup, n_iter, rotating);
|
||||
}
|
||||
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using I4 = ck::pk_i4_t;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
auto profile = [&](auto a_type,
|
||||
auto b_type,
|
||||
auto b_scale_type,
|
||||
auto comp_type,
|
||||
auto acc_type,
|
||||
auto c_type,
|
||||
auto scale_block_k,
|
||||
auto a_layout,
|
||||
auto b_layout,
|
||||
auto c_layout) {
|
||||
using ADataType = decltype(a_type);
|
||||
using BDataType = decltype(b_type);
|
||||
using BScaleDataType = decltype(b_scale_type);
|
||||
using ComputeDataType = decltype(comp_type);
|
||||
using AccDataType = decltype(acc_type);
|
||||
using CDataType = decltype(c_type);
|
||||
|
||||
using ALayout = decltype(a_layout);
|
||||
using BLayout = decltype(b_layout);
|
||||
using CLayout = decltype(c_layout);
|
||||
|
||||
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
|
||||
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
|
||||
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
|
||||
|
||||
bool pass = ck::profiler::profile_batched_gemm_b_scale_impl<ADataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
scale_block_k,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA < 0) ? DefaultStrideA : StrideA,
|
||||
(StrideB < 0) ? DefaultStrideB : StrideB,
|
||||
(StrideC < 0) ? DefaultStrideC : StrideC,
|
||||
BatchStrideA,
|
||||
BatchStrideB,
|
||||
BatchStrideC,
|
||||
BatchStrideScaleB,
|
||||
BatchSize,
|
||||
KBatch,
|
||||
n_warmup,
|
||||
n_iter,
|
||||
rotating);
|
||||
|
||||
return pass ? 0 : 1;
|
||||
};
|
||||
|
||||
if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN &&
|
||||
B_scale_block == BScaleBlockTile::K_128)
|
||||
{
|
||||
printf("F16_I4_F16 MK_NK_MN K_128\n");
|
||||
return profile(
|
||||
F16{}, I4{}, F16{}, F16{}, F32{}, F16{}, ck::Number<128>{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "this data_type & layout is not implemented" << std::endl;
|
||||
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_batched_gemm_b_scale);
|
||||
Reference in New Issue
Block a user