Jing's contribution: prototype of mixed-precision FP16/BF16 x int4 GEMM (#1762)

* add a prototype of int4

* clean

* debug

* clean

* clean

* move packed into dynamic_buffer

* fixed coord reset

* add fast pki4 to half conversion

* fix

* fixed reference and host_tensor

* fixed tensor init

* format

* debug i4_to_f16_convert

* format

* fixed splitk

* weight permute

* add b tile permute

* clean

* weight permute with splitk

* format

* improve weight layout

* add and_or_b32

* fixed splitk crash

* add permute switch as a template

* recover v3r1

* clean

* failure with intrawave v2

* fixed

* fixed

* add ckProfiler

* add bf16 support

* add bf16 example

* fixed int4 to bhalf_t conversion

* format

* fixed int4 to bf16 conversion

* clean

* add instances for mem

* clean

* fixed host tensor size

* fixed

* debug

* fixed

* add pk_i4_t as a struct

* fix

* Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* revert

* Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* fixed comments

* revert

* clean

* revert

* revert

* fixed

* Update CMakeLists.txt

* Update script/cmake-ck-dev.sh

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* Update include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* Update CMakeLists.txt

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* fixed

* fixed

* fixed

* revert

* revert

* add comments

* format

* fixed assert

* fixed

* Fix I4 define in ckProfiler

* Fixed example_gemm_xdl_bf16_pk_i4_v3 test failure

---------

Co-authored-by: Jing Zhang <jizhan@fb.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: mtgu0705 <mtgu@amd.com>

[ROCm/composable_kernel commit: 1d8e4ec2ce]
This commit is contained in:
Adam Osewski
2025-01-02 04:48:06 +01:00
committed by GitHub
parent a860c20099
commit ac74520ff6
37 changed files with 1582 additions and 349 deletions

View File

@@ -585,7 +585,7 @@ if(NOT GPU_ARCHS AND USER_GPU_TARGETS)
)
add_subdirectory(example)
if(BUILD_TESTING)
add_subdirectory(test)
add_subdirectory(test)
endif()
endif()

View File

@@ -66,7 +66,7 @@ else()
-Wunreachable-code
-Wunused
-Wno-reserved-identifier
-Werror
-Werror
-Wno-option-ignored
-Wsign-compare
-Wno-extra-semi-stmt

View File

@@ -29,6 +29,8 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v3)
add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_v3)
add_example_executable(example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp)
add_example_executable(example_gemm_xdl_fp16_pk_i4_v3 gemm_xdl_fp16_pk_i4_v3.cpp)
add_example_executable(example_gemm_xdl_bf16_pk_i4_v3 gemm_xdl_bf16_pk_i4_v3.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_v3)
add_example_executable(example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3)

View File

@@ -287,3 +287,85 @@ bool parse_cmd_args<ProblemSizeSplitK>(int argc,
return true;
}
// Relative tolerance to use when comparing results whose element type is
// DataType. The candidates are tested in the order listed and the first match
// wins; any type that matches none of them gets 1e-3.
template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
    constexpr bool is_f32  = std::is_same_v<DataType, float>;
    constexpr bool is_f64  = std::is_same_v<DataType, double>;
    constexpr bool is_f16  = std::is_same_v<DataType, ck::half_t>;
    constexpr bool is_bf16 = std::is_same_v<DataType, ck::bhalf_t>;
    constexpr bool is_i32  = std::is_same_v<DataType, int32_t>;
    constexpr bool is_i8   = std::is_same_v<DataType, int8_t>;
    constexpr bool is_f8   = std::is_same_v<DataType, ck::f8_t>;
    constexpr bool is_bf8  = std::is_same_v<DataType, ck::bf8_t>;

    return is_f32    ? 1e-3
           : is_f64  ? 1e-6
           : is_f16  ? 1e-3
           : is_bf16 ? 5e-2
           : is_i32  ? 1e-1
           : is_i8   ? 1e-1
           : is_f8   ? 1e-1   // 240 and 224 are acceptable
           : is_bf8  ? 1.5e-1 // 57344 and 49152 are acceptable
                     : 1e-3;
}
// Absolute tolerance to use when comparing results whose element type is
// DataType. The candidates are tested in the order listed and the first match
// wins; any type that matches none of them gets 1e-3.
template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
    constexpr bool is_f32  = std::is_same_v<DataType, float>;
    constexpr bool is_f64  = std::is_same_v<DataType, double>;
    constexpr bool is_f16  = std::is_same_v<DataType, ck::half_t>;
    constexpr bool is_bf16 = std::is_same_v<DataType, ck::bhalf_t>;
    constexpr bool is_i32  = std::is_same_v<DataType, int32_t>;
    constexpr bool is_i8   = std::is_same_v<DataType, int8_t>;
    constexpr bool is_f8   = std::is_same_v<DataType, ck::f8_t>;
    constexpr bool is_bf8  = std::is_same_v<DataType, ck::bf8_t>;

    return is_f32    ? 1e-3
           : is_f64  ? 1e-6
           : is_f16  ? 1e-3
           : is_bf16 ? 5e-2
           : is_i32  ? 1e-1
           : is_i8   ? 1e-1
           : is_f8   ? 16.1   // 240 and 224 are acceptable
           : is_bf8  ? 8192.1 // 57344 and 49152 are acceptable
                     : 1e-3;
}

View File

@@ -0,0 +1,253 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
// Element types of this example: A, the shuffle stage and C use ck::bhalf_t,
// B uses the packed 4-bit type ck::pk_i4_t, and accumulation is done in float.
using ADataType = ck::bhalf_t;
using BDataType = ck::pk_i4_t;
using AccDataType = float;
using CShuffleDataType = ck::bhalf_t;
using CDataType = ck::bhalf_t;
// Memory layouts handed to the device instance and used by run_gemm to build
// the host tensor descriptors.
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
// No element-wise transformation is applied to A, B or C.
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// PermuteA / PermuteB are forwarded as the last two template arguments of the
// device instance; PermuteB additionally selects the host-side re-tiling of B
// in run_gemm.
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = true;
// Passed to the device instance and reused by run_gemm as the K1 tile length
// of the host-side weight permute, so both must stay in sync.
static constexpr ck::index_t KPerBlock = 128;
// Device GEMM instance under test. The arguments after the pipeline version
// are, in order: compute type for A, compute type for B, PermuteA, PermuteB.
// NOTE(review): the numeric arguments are positional tuning parameters; their
// meaning is defined by DeviceGemm_Xdl_CShuffleV3's template parameter list,
// which is not reproduced here.
// clang-format off
using DeviceGemmV2Instance =
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault,
128,
16, 64,
KPerBlock, 8, 32,
16, 16,
1, 2,
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0,
1, 1, S<1, 16, 1, 8>, 4,
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, ADataType, ADataType, PermuteA, PermuteB>;
// clang-format on
// Host reference GEMM used by run_gemm to produce the expected output.
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
// Runs the bf16 x packed-int4 GEMM example once: builds the host tensors,
// uploads A and a re-tiled copy of B, launches DeviceGemmV2Instance and then
// optionally verifies and/or times it.
//
// problem_size  provides M, N, K, StrideA, StrideB, StrideC and KBatch. A
//               stride of -1 is replaced by the packed default for the layout.
// config        provides init_method (how A and B are filled),
//               do_verification and time_kernel.
//
// Returns false only when verification ran and check_err reported a mismatch.
// A problem rejected by IsSupportedArgument is reported on stderr and the
// function returns true.
template <typename ProblemType>
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{
    using namespace ck::literals;

    auto M       = problem_size.M;
    auto N       = problem_size.N;
    auto K       = problem_size.K;
    auto StrideA = problem_size.StrideA;
    auto StrideB = problem_size.StrideB;
    auto StrideC = problem_size.StrideC;
    auto KBatch  = problem_size.KBatch;

    // 2-D host descriptor: strides {stride, 1} for RowMajor, {1, stride} otherwise.
    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
            {
                return HostTensorDescriptor({row, col}, {stride, 1_uz});
            }
            else
            {
                return HostTensorDescriptor({row, col}, {1_uz, stride});
            }
        };

    auto f_get_default_stride =
        [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
            if(stride == -1)
            {
                // give a chance if stride is -1, return a default packed stride
                if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
                {
                    return static_cast<std::size_t>(col);
                }
                else
                {
                    return static_cast<std::size_t>(row);
                }
            }
            else
                return static_cast<std::size_t>(stride);
        };

    StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
    StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
    StrideC = f_get_default_stride(M, N, StrideC, CLayout{});

    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
    // b_k_n feeds the host reference; b_k_n_permute is the copy sent to the device.
    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
    Tensor<BDataType> b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));

    switch(config.init_method)
    {
    case 0:
        a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
        b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
        break;
    case 1:
        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
        break;
    case 2:
        a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
        break;
    case 3:
        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
        break;
    default:
        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0, 1.0});
        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
    }

    Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
    Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
    std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;

    DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
    // Size the B buffer from the tensor's own storage size. Tensor<pk_i4_t> keeps
    // two elements per stored entry, so the descriptor's element count is about
    // twice the number of bytes actually held in mData. Sizing the buffer from the
    // descriptor made it larger than the host vector that ToDevice() copies from.
    // NOTE(review): presumes DeviceMem::ToDevice(const void*) transfers the whole
    // buffer size — confirm against DeviceMem.
    DeviceMem b_k_n_device_buf(b_k_n_permute.GetElementSpaceSizeInBytes());
    DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());

    // weight permute
    if constexpr(PermuteB)
    {
        int K1 = KPerBlock;
        int K0 = K / KPerBlock;

        // Re-tile B from N x K to K0 x N x K1. Both sides are addressed with a
        // single linear element index; for pk_i4_t the tensor halves that index,
        // so every assignment moves one stored entry (two 4-bit values).
        // NOTE(review): the source index i * K assumes StrideB == K, and K0 drops
        // any remainder of K / KPerBlock — verify both against the callers.
        // int K0, N, K1
        for(int j = 0; j < K0; j++)
        {
            for(int i = 0; i < N; i++)
            {
                for(int jj = 0; jj < K1; jj++)
                {
                    b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj));
                }
            }
        }
    }
    else
    {
        // No re-tiling requested: plain copy in the original order.
        for(int i = 0; i < N; i++)
        {
            for(int j = 0; j < K; j++)
            {
                b_k_n_permute(i * K + j) = b_k_n(i * K + j);
            }
        }
    }

    a_m_k_device_buf.ToDevice(a_m_k.mData.data());
    b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data());

    auto a_element_op = AElementOp{};
    auto b_element_op = BElementOp{};
    auto c_element_op = CElementOp{};

    // do GEMM
    auto gemm      = DeviceGemmV2Instance{};
    auto invoker   = gemm.MakeInvoker();
    float ave_time = 0;

    auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
                                      static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
                                      static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
                                      M,
                                      N,
                                      K,
                                      StrideA,
                                      StrideB,
                                      StrideC,
                                      KBatch,
                                      a_element_op,
                                      b_element_op,
                                      c_element_op);

    if(!gemm.IsSupportedArgument(argument))
    {
        std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;

        // An unsupported problem is reported but not counted as a failure.
        return true;
    }

    bool pass = true;
    if(config.do_verification)
    {
        // The host reference consumes the un-permuted B.
        auto ref_gemm    = ReferenceGemmInstance{};
        auto ref_invoker = ref_gemm.MakeInvoker();

        auto ref_argument = ref_gemm.MakeArgument(
            a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});

        ref_invoker.Run(ref_argument);

        ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0});
        c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());

        pass &= ck::utils::check_err(c_m_n_device_result,
                                     c_m_n_host_result,
                                     "Error: Incorrect results!",
                                     get_rtol<CDataType>(),
                                     get_atol<CDataType>());
    }

    if(config.time_kernel)
    {
        ave_time =
            invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50});

        std::size_t flop = 2_uz * M * N * K;
        // B contributes half as many bytes when it is the packed int4 type.
        std::size_t num_btype =
            sizeof(ADataType) * M * K +
            sizeof(BDataType) * K * N /
                (ck::is_same_v<ck::remove_cvref_t<BDataType>, ck::pk_i4_t> ? 2 : 1) +
            sizeof(CDataType) * M * N;

        float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

        float gb_per_sec = num_btype / 1.E6 / ave_time;

        std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
                  << " GB/s, " << gemm.GetTypeString() << std::endl;
    }
    return pass;
}
// Parses the command line into a split-K problem description and execution
// config, then runs the example. Returns false if parsing fails or if
// run_gemm returns false.
bool run_gemm_splitk_example(int argc, char* argv[])
{
    ProblemSizeSplitK problem_size;
    ExecutionConfig config;

    if(!parse_cmd_args(argc, argv, problem_size, config))
    {
        return false;
    }

    return run_gemm(problem_size, config);
}
// Exit status is 0 when the example reports success and 1 otherwise.
int main(int argc, char* argv[])
{
    const bool succeeded = run_gemm_splitk_example(argc, argv);
    return succeeded ? 0 : 1;
}

View File

@@ -1,12 +1,12 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
using ADataType = ck::f8_t;
using BDataType = ck::half_t;
using ADataType = ck::half_t;
using BDataType = ck::f8_t;
using AccDataType = float;
using CShuffleDataType = ck::half_t;
using CDataType = ck::half_t;
@@ -29,15 +29,15 @@ using DeviceGemmV2Instance =
AElementOp, BElementOp, CElementOp, GemmDefault,
64,
16, 16,
64, 16, 8,
256, 8, 16,
16, 16,
1, 1,
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 16, 16, 0,
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 16, 16, 0,
1, 1, S<1, 16, 1, 4>, 4,
ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v1>;
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,

View File

@@ -0,0 +1,303 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
// Element types of this example: A, the shuffle stage and C use ck::half_t,
// B uses the packed 4-bit type ck::pk_i4_t, and accumulation is done in float.
using ADataType = ck::half_t;
using BDataType = ck::pk_i4_t;
using AccDataType = float;
using CShuffleDataType = ck::half_t;
using CDataType = ck::half_t;
// Memory layouts handed to the device instance and used by run_gemm to build
// the host tensor descriptors.
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
// No element-wise transformation is applied to A, B or C.
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// PermuteA / PermuteB are forwarded as the last two template arguments of the
// device instance; PermuteB additionally selects the host-side re-tiling of B
// in run_gemm.
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = true;
// Passed to the device instance and reused by run_gemm as the K1 tile length
// of the host-side weight permute, so both must stay in sync.
static constexpr ck::index_t KPerBlock = 128;
// Device GEMM instance under test. The arguments after the pipeline version
// are, in order: compute type for A, compute type for B, PermuteA, PermuteB.
// NOTE(review): the numeric arguments are positional tuning parameters; their
// meaning is defined by DeviceGemm_Xdl_CShuffleV3's template parameter list,
// which is not reproduced here.
// clang-format off
using DeviceGemmV2Instance =
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault,
128,
16, 128,
KPerBlock, 8, 32,
16, 16,
1, 4,
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0,
1, 1, S<1, 16, 1, 8>, 4,
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, ADataType, ADataType, PermuteA, PermuteB>;
// clang-format on
// Host reference GEMM used by run_gemm to produce the expected output.
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
// Runs the fp16 x packed-int4 GEMM example once: builds the host tensors,
// uploads A and a re-tiled, nibble-shuffled copy of B, launches
// DeviceGemmV2Instance and then optionally verifies and/or times it.
//
// problem_size  provides M, N, K, StrideA, StrideB, StrideC and KBatch. A
//               stride of -1 is replaced by the packed default for the layout.
// config        provides init_method (how A and B are filled),
//               do_verification and time_kernel.
//
// Returns false only when verification ran and check_err reported a mismatch.
// A problem rejected by IsSupportedArgument is reported on stderr and the
// function returns true.
template <typename ProblemType>
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{
    using namespace ck::literals;

    auto M       = problem_size.M;
    auto N       = problem_size.N;
    auto K       = problem_size.K;
    auto StrideA = problem_size.StrideA;
    auto StrideB = problem_size.StrideB;
    auto StrideC = problem_size.StrideC;
    auto KBatch  = problem_size.KBatch;

    // 2-D host descriptor: strides {stride, 1} for RowMajor, {1, stride} otherwise.
    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
            {
                return HostTensorDescriptor({row, col}, {stride, 1_uz});
            }
            else
            {
                return HostTensorDescriptor({row, col}, {1_uz, stride});
            }
        };

    auto f_get_default_stride =
        [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
            if(stride == -1)
            {
                // give a chance if stride is -1, return a default packed stride
                if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
                {
                    return static_cast<std::size_t>(col);
                }
                else
                {
                    return static_cast<std::size_t>(row);
                }
            }
            else
                return static_cast<std::size_t>(stride);
        };

    StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
    StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
    StrideC = f_get_default_stride(M, N, StrideC, CLayout{});

    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
    // b_k_n feeds the host reference; b_k_n_permute is the copy sent to the device.
    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
    Tensor<BDataType> b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));

    switch(config.init_method)
    {
    case 0:
        a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
        b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
        break;
    case 1:
        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
        break;
    case 2:
        a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
        break;
    case 3:
        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
        break;
    default:
        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
    }

    Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
    Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
    std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;

    DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
    // Size the B buffer from the tensor's own storage size. Tensor<pk_i4_t> keeps
    // two elements per stored entry, so the descriptor's element count is about
    // twice the number of bytes actually held in mData. Sizing the buffer from the
    // descriptor made it larger than the host vector that ToDevice() copies from.
    // NOTE(review): presumes DeviceMem::ToDevice(const void*) transfers the whole
    // buffer size — confirm against DeviceMem.
    DeviceMem b_k_n_device_buf(b_k_n_permute.GetElementSpaceSizeInBytes());
    DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());

    // weight permute
    if constexpr(PermuteB)
    {
        int K1 = KPerBlock;
        int K0 = K / KPerBlock;

        // Re-tile B from N x K to K0 x N x K1. Both sides are addressed with a
        // single linear element index; for pk_i4_t the tensor halves that index,
        // so every assignment moves one stored entry (two 4-bit values).
        // NOTE(review): the source index i * K assumes StrideB == K, and K0 drops
        // any remainder of K / KPerBlock — verify both against the callers.
        // int K0, N, K1
        for(int j = 0; j < K0; j++)
        {
            for(int i = 0; i < N; i++)
            {
                for(int jj = 0; jj < K1; jj++)
                {
                    b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj));
                }
            }
        }
    }
    else
    {
        // No re-tiling requested: plain copy in the original order.
        for(int i = 0; i < N; i++)
        {
            for(int j = 0; j < K; j++)
            {
                b_k_n_permute(i * K + j) = b_k_n(i * K + j);
            }
        }
    }

    // vector pk_i4x4 permute
    // Works in place on groups of eight 4-bit values (four stored entries) along
    // the first index. Each entry is split into its high nibble (even slot of
    // input[]) and low nibble (odd slot), then the eight nibbles are re-packed in
    // the order given below.
    // NOTE(review): the group accesses j .. j + 6, so K is presumably required to
    // be a multiple of 8 here — nothing in this function enforces it.
    for(int i = 0; i < N; i++)
    {
        for(int j = 0; j < K; j += 8)
        {
            int input[8];

            for(int k = 0; k < 4; k++)
            {
                int i4x2         = b_k_n_permute(j + k * 2, i).data;
                input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
                input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
            }

            // permute 01234567->20643175
            {
                int hi   = input[2];
                int lo   = input[0];
                int i4x2 = (hi << 4) | lo;

                b_k_n_permute(j + 0, i) = i4x2;
            }

            {
                int hi   = input[6];
                int lo   = input[4];
                int i4x2 = (hi << 4) | lo;

                b_k_n_permute(j + 2, i) = i4x2;
            }

            {
                int hi   = input[3];
                int lo   = input[1];
                int i4x2 = (hi << 4) | lo;

                b_k_n_permute(j + 4, i) = i4x2;
            }

            {
                int hi   = input[7];
                int lo   = input[5];
                int i4x2 = (hi << 4) | lo;

                b_k_n_permute(j + 6, i) = i4x2;
            }
        }
    }

    a_m_k_device_buf.ToDevice(a_m_k.mData.data());
    b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data());

    auto a_element_op = AElementOp{};
    auto b_element_op = BElementOp{};
    auto c_element_op = CElementOp{};

    // do GEMM
    auto gemm      = DeviceGemmV2Instance{};
    auto invoker   = gemm.MakeInvoker();
    float ave_time = 0;

    auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
                                      static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
                                      static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
                                      M,
                                      N,
                                      K,
                                      StrideA,
                                      StrideB,
                                      StrideC,
                                      KBatch,
                                      a_element_op,
                                      b_element_op,
                                      c_element_op);

    if(!gemm.IsSupportedArgument(argument))
    {
        std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;

        // An unsupported problem is reported but not counted as a failure.
        return true;
    }

    bool pass = true;
    if(config.do_verification)
    {
        // The host reference consumes the un-permuted B.
        auto ref_gemm    = ReferenceGemmInstance{};
        auto ref_invoker = ref_gemm.MakeInvoker();

        auto ref_argument = ref_gemm.MakeArgument(
            a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});

        ref_invoker.Run(ref_argument);

        ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0});
        c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());

        pass &= ck::utils::check_err(c_m_n_device_result,
                                     c_m_n_host_result,
                                     "Error: Incorrect results!",
                                     get_rtol<CDataType>(),
                                     get_atol<CDataType>());
    }

    if(config.time_kernel)
    {
        ave_time =
            invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50});

        std::size_t flop = 2_uz * M * N * K;
        // B contributes half as many bytes when it is the packed int4 type.
        std::size_t num_btype =
            sizeof(ADataType) * M * K +
            sizeof(BDataType) * K * N /
                (ck::is_same_v<ck::remove_cvref_t<BDataType>, ck::pk_i4_t> ? 2 : 1) +
            sizeof(CDataType) * M * N;

        float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

        float gb_per_sec = num_btype / 1.E6 / ave_time;

        std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
                  << " GB/s, " << gemm.GetTypeString() << std::endl;
    }
    return pass;
}
// Parses the command line into a split-K problem description and execution
// config, then runs the example. Returns false if parsing fails or if
// run_gemm returns false.
bool run_gemm_splitk_example(int argc, char* argv[])
{
    ProblemSizeSplitK problem_size;
    ExecutionConfig config;

    if(!parse_cmd_args(argc, argv, problem_size, config))
    {
        return false;
    }

    return run_gemm(problem_size, config);
}
// Exit status is 0 when the example reports success and 1 otherwise.
int main(int argc, char* argv[])
{
    const bool succeeded = run_gemm_splitk_example(argc, argv);
    return succeeded ? 0 : 1;
}

View File

@@ -12,7 +12,7 @@ using CShuffleDataType = ck::half_t;
using CDataType = ck::half_t;
using ALayout = Row;
using BLayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
@@ -27,17 +27,17 @@ using DeviceGemmV2Instance =
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
PassThrough, PassThrough, PassThrough, GemmDefault,
256,
224, 256,
64, 8, 2,
64,
16, 16,
256, 8, 8,
16, 16,
7, 8,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
1, 1,
S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 8, 2, 0,
1, 2, S<1, 32, 1, 8>, 8,
ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>;
S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
1, 1, S<1, 16, 1, 4>, 4,
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::

View File

@@ -5,88 +5,6 @@
#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp"
template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 2e-1;
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 2e-1;
}
else
{
return 1e-3;
}
}
template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 2e-1;
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 2e-1;
}
else
{
return 1e-3;
}
}
template <typename ProblemType>
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{

View File

@@ -3,88 +3,6 @@
#pragma once
template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 1e-1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 1.5e-1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 16.1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 8192.1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
template <typename ProblemType>
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{

View File

@@ -3,88 +3,6 @@
#pragma once
template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 1e-1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 1.5e-1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 16.1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 8192.1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
template <typename ProblemType>
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{

View File

@@ -266,18 +266,18 @@ struct Tensor
using Data = std::vector<T>;
template <typename X>
Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(GetElementSpaceSize())
{
}
template <typename X, typename Y>
Tensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
: mDesc(lens, strides), mData(mDesc.GetElementSpaceSize())
: mDesc(lens, strides), mData(GetElementSpaceSize())
{
}
template <typename Lengths>
Tensor(const Lengths& lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
Tensor(const Lengths& lens) : mDesc(lens), mData(GetElementSpaceSize())
{
}
@@ -287,7 +287,7 @@ struct Tensor
{
}
Tensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {}
Tensor(const Descriptor& desc) : mDesc(desc), mData(GetElementSpaceSize()) {}
template <typename OutT>
Tensor<OutT> CopyAsType() const
@@ -322,7 +322,17 @@ struct Tensor
std::size_t GetElementSize() const { return mDesc.GetElementSize(); }
std::size_t GetElementSpaceSize() const { return mDesc.GetElementSpaceSize(); }
std::size_t GetElementSpaceSize() const
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return (mDesc.GetElementSpaceSize() + 1) / 2;
}
else
{
return mDesc.GetElementSpaceSize();
}
}
std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }
@@ -469,29 +479,64 @@ struct Tensor
template <typename... Is>
std::size_t GetOffsetFromMultiIndex(Is... is) const
{
return mDesc.GetOffsetFromMultiIndex(is...);
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return mDesc.GetOffsetFromMultiIndex(is...) / 2;
}
else
{
return mDesc.GetOffsetFromMultiIndex(is...);
}
}
template <typename... Is>
T& operator()(Is... is)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
}
}
template <typename... Is>
const T& operator()(Is... is) const
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
}
}
T& operator()(std::vector<std::size_t> idx)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
}
const T& operator()(std::vector<std::size_t> idx) const
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
}
typename Data::iterator begin() { return mData.begin(); }

View File

@@ -81,6 +81,20 @@ struct GeneratorTensor_1<int8_t>
}
};
// Constant generator for packed int4 tensors: every call returns the same
// byte, with (value + 8) placed in both the high and the low nibble.
// NOTE(review): the +8 offset presumably maps a signed 4-bit value onto an
// unsigned nibble — confirm against the pk_i4_t conversion routines.
template <>
struct GeneratorTensor_1<ck::pk_i4_t>
{
    int8_t value = 1;

    template <typename... Is>
    ck::pk_i4_t operator()(Is...)
    {
        int nibble = value + 8;
        int packed = ((nibble << 4) + nibble) & 0xff;

        ck::pk_i4_t r = packed;
        return r;
    }
};
template <typename T>
struct GeneratorTensor_2
{
@@ -121,6 +135,22 @@ struct GeneratorTensor_2<int8_t>
}
};
// Random generator for packed int4 tensors: each call draws two values with
// std::rand() in [min_value, max_value), adds 8 to each and packs them into
// the high and low nibble of one byte.
// NOTE(review): the +8 offset presumably maps a signed 4-bit value onto an
// unsigned nibble — confirm against the pk_i4_t conversion routines.
template <>
struct GeneratorTensor_2<ck::pk_i4_t>
{
    int min_value = 0;
    int max_value = 1;

    template <typename... Is>
    ck::pk_i4_t operator()(Is...)
    {
        // The high nibble is drawn first, then the low one; keep this order so
        // a seeded std::rand() sequence yields the same tensor contents.
        int hi = std::rand() % (max_value - min_value) + min_value + 8;
        int lo = std::rand() % (max_value - min_value) + min_value + 8;

        int packed = ((hi << 4) + lo) & 0xff;

        ck::pk_i4_t r = packed;
        return r;
    }
};
#if defined CK_ENABLE_FP8
template <>
struct GeneratorTensor_2<ck::f8_t>

View File

@@ -167,7 +167,7 @@ struct StaticTensorTupleOfVectorBuffer
// Idx is for S, not X. Idx should be aligned with X
template <typename X,
typename Idx,
typename enable_if<has_same_scalar_type<S, X>::value &&
typename enable_if<(has_same_scalar_type<S, X>::value || !is_native_type<S>()) &&
is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
bool>::type = false>
__host__ __device__ constexpr X GetAsType(Idx) const
@@ -201,7 +201,7 @@ struct StaticTensorTupleOfVectorBuffer
// Idx is for S, not X. Idx should be aligned with X
template <typename X,
typename Idx,
typename enable_if<has_same_scalar_type<S, X>::value &&
typename enable_if<(has_same_scalar_type<S, X>::value || !is_native_type<S>()) &&
is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
bool>::type = false>
__host__ __device__ constexpr void SetAsType(Idx, X x)

View File

@@ -36,6 +36,10 @@ struct DeviceGemmV2 : public BaseOperator
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual bool GetPermuteA() = 0;
virtual bool GetPermuteB() = 0;
virtual ck::index_t GetKPerBlock() = 0;
};
template <typename ALayout,

View File

@@ -64,7 +64,9 @@ template <typename ALayout,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
BLayout,
CLayout,
@@ -122,7 +124,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
ComputeTypeB,
PermuteA,
PermuteB>;
using Argument = typename GridwiseGemm::Argument;
@@ -633,6 +637,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
// Expose this instantiation's compile-time configuration through the virtual interface.

// K extent of one block tile (the KPerBlock template argument).
index_t GetKPerBlock() override
{
    return KPerBlock;
}

// Value of the PermuteA template argument.
bool GetPermuteA() override
{
    return PermuteA;
}

// Value of the PermuteB template argument.
bool GetPermuteB() override
{
    return PermuteB;
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,

View File

@@ -7,12 +7,177 @@
#include "ck/utility/math.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/utility/amd_inline_asm.hpp"
#include <cassert>
namespace ck {
// Fast int4x4 to half8_t data type conversion based on paper
// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production]
// (https://arxiv.org/abs/2211.10017) and implementation:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
// Converts four packed int4 values taken from `q` into four fp16 values.
//
// With n_i denoting the i-th nibble of q counted from the least significant bit, the
// masks below select n0, n4 (LO) and n1, n5 (HI); the other nibbles of q are ignored.
// Each stored nibble is interpreted as (value + 8), so the result is
//   { n0 - 8, n4 - 8, n1 - 8, n5 - 8 }.
__host__ __device__ inline half4_t pki4_to_half4(int q)
{
// Nibble-select masks and the fp16 bit pattern of 1024.0 replicated in both halves.
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Extract the two int4 values in the low nibbles and OR them into the mantissa of
// 1024.0, producing the two fp16 numbers (1024 + n).
int lo = amd_assembly_and_or_b32(q, LO, EX);
// Extract the two int4 values in the high nibbles the same way; these sit 4 bits
// higher in the mantissa, producing the two fp16 numbers (1024 + 16 * n).
int hi = amd_assembly_and_or_b32(q, HI, EX);
const int SUB = 0xE408E408; // half2 {-1032, -1032}
const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16}
const int ADD = 0xd480d480; // half2 {-72, -72}
vector_type<half_t, 4> res;
// For the two fp16 from the low nibbles: (1024 + n) - 1032 = n - 8.
res.template AsType<half2_t>()(Number<0>{}) =
amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
// For the two fp16 from the high nibbles: (1024 + 16 * n) / 16 - 72 = n - 8.
res.template AsType<half2_t>()(Number<1>{}) = amd_assembly_pk_fma_f16(
bit_cast<half2_t>(hi), bit_cast<half2_t>(MUL), bit_cast<half2_t>(ADD));
return res.template AsType<half4_t>()[Number<0>{}];
}
// Converts the two int4 values packed in one byte into two fp16 values.
// Each stored nibble is interpreted as (value + 8). Both branches produce
//   element 0 = high nibble - 8, element 1 = low nibble - 8.
__host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
{
#if 1
// Fast path: place each nibble in the mantissa of 1024.0 and remove the bias with a
// single packed add.
uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
// Low nibble goes to the upper 16-bit lane, high nibble to the lower 16-bit lane.
uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
const int EX = 0x64006400; // half2 {1024, 1024}
const int SUB = 0xE408E408; // half2 {-1032, -1032}: removes the 1024 bias and the +8 offset
int lo = i4s | EX;
return amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
#else
// Disabled reference path using plain integer arithmetic.
// Note the naming: x_h holds the LOW nibble and x_l the HIGH nibble.
uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
vector_type<half_t, 2> res;
half_t x_h = (x_u8 & 0x0f) - 8;
half_t x_l = ((x_u8 & 0xf0) >> 4) - 8;
res.template AsType<half_t>()(Number<0>{}) = x_l;
res.template AsType<half_t>()(Number<1>{}) = x_h;
return res.template AsType<half2_t>()[Number<0>{}];
#endif
}
// Converts the four int4 values packed in the low 16 bits of `q` into four bf16 values.
//
// With n_i denoting the i-th nibble of q counted from the least significant bit and each
// stored nibble interpreted as (value + 8), the result is
//   { n1 - 8, n0 - 8, n3 - 8, n2 - 8 }.
// The upper 16 bits of q are ignored.
//
// All float <-> integer reinterpretation is done by value with bit_cast; the float
// intermediates are never accessed through a uint32_t pointer, which would violate the
// strict-aliasing rules.
__host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
{
    // Spread nibbles 0..3 so that each occupies the low 4 bits of its own byte.
    const uint32_t i8s =
        (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);

    // fp32 bit pattern of 2^23: with a byte b in the lowest mantissa byte the value is
    // exactly 2^23 + b.
    static constexpr uint32_t fp32_base = 0x4B000000;
    // 2^23 + 8: removes the 2^23 bias and the +8 storage offset in one subtraction.
    constexpr float fp32_bias = 8388616.f;

    // Selector 0x765k keeps the upper three bytes of fp32_base and inserts byte k of i8s
    // as the lowest byte.
    const float f0 = bit_cast<float>(__byte_perm(i8s, fp32_base, 0x7650)) - fp32_bias;
    const float f1 = bit_cast<float>(__byte_perm(i8s, fp32_base, 0x7651)) - fp32_bias;
    const float f2 = bit_cast<float>(__byte_perm(i8s, fp32_base, 0x7652)) - fp32_bias;
    const float f3 = bit_cast<float>(__byte_perm(i8s, fp32_base, 0x7653)) - fp32_bias;

    // bf16 is the upper 16 bits of fp32. Selector 0x7632 packs the upper half of the first
    // operand into the low lane and the upper half of the second into the high lane.
    vector_type<bhalf_t, 4> res;
    res.template AsType<bhalf2_t>()(Number<0>{}) = bit_cast<bhalf2_t>(
        __byte_perm(bit_cast<uint32_t>(f1), bit_cast<uint32_t>(f0), 0x7632));
    res.template AsType<bhalf2_t>()(Number<1>{}) = bit_cast<bhalf2_t>(
        __byte_perm(bit_cast<uint32_t>(f3), bit_cast<uint32_t>(f2), 0x7632));

    return res.template AsType<bhalf4_t>()[Number<0>{}];
}
// Converts the two int4 values packed in one byte into two bf16 values.
// Each stored nibble is interpreted as (value + 8); the result is
//   { high nibble - 8, low nibble - 8 }.
__host__ __device__ inline bhalf2_t pki4_to_bhalf2(pk_i4_t q)
{
    const uint8_t bits = ck::bit_cast<uint8_t>(q);

    // Split the byte and undo the +8 storage offset in fp32.
    const float from_low_nibble  = static_cast<float>(bits & 0x0f) - 8.f;
    const float from_high_nibble = static_cast<float>((bits >> 4) & 0x0f) - 8.f;

    vector_type<bhalf_t, 2> out;
    out.template AsType<bhalf_t>()(Number<0>{}) = type_convert<bhalf_t>(from_high_nibble);
    out.template AsType<bhalf_t>()(Number<1>{}) = type_convert<bhalf_t>(from_low_nibble);

    return out.template AsType<bhalf2_t>()[Number<0>{}];
}
namespace tensor_operation {
namespace element_wise {
// Element-wise functor that converts eight values at a time.
// The overloads below unpack four pk_i4_t bytes (eight int4 values) into eight
// fp16 / bf16 values.
struct PassThroughPack8
{
    // Generic form: declared only, no definition is provided here.
    template <typename Y, typename X>
    __host__ __device__ void operator()(Y& y, const X& x) const;

    // 8 x int4 -> 8 x fp16.
    __host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x) const
    {
#if 1
        // Fast path: pki4_to_half4 reads nibbles 0, 1, 4 and 5 of its argument, so the
        // second call on the word shifted by one byte covers nibbles 2, 3, 6 and 7.
        const int packed = bit_cast<int>(x);

        vector_type<half_t, 8> unpacked;
        unpacked.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4(packed);
        unpacked.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4(packed >> 8);

        y = unpacked.template AsType<half8_t>()[Number<0>{}];
#else
        // Disabled reference path: convert one byte (two values) at a time.
        vector_type<half_t, 8> unpacked;
        vector_type<pk_i4_t, 4> bytes{x};

        unpacked.template AsType<half2_t>()(Number<0>{}) =
            pki4_to_half2(bytes.template AsType<pk_i4_t>()[Number<0>{}]);
        unpacked.template AsType<half2_t>()(Number<1>{}) =
            pki4_to_half2(bytes.template AsType<pk_i4_t>()[Number<1>{}]);
        unpacked.template AsType<half2_t>()(Number<2>{}) =
            pki4_to_half2(bytes.template AsType<pk_i4_t>()[Number<2>{}]);
        unpacked.template AsType<half2_t>()(Number<3>{}) =
            pki4_to_half2(bytes.template AsType<pk_i4_t>()[Number<3>{}]);

        y = unpacked.template AsType<half8_t>()[Number<0>{}];
#endif
    }

    // 8 x int4 -> 8 x bf16.
    __host__ __device__ constexpr void operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x) const
    {
#if 1
        // Fast path: pki4_to_bhalf4 reads the low 16 bits of its argument, so the second
        // call on the word shifted by 16 covers the upper two bytes.
        const int packed = bit_cast<int>(x);

        vector_type<bhalf_t, 8> unpacked;
        unpacked.template AsType<bhalf4_t>()(Number<0>{}) = pki4_to_bhalf4(packed);
        unpacked.template AsType<bhalf4_t>()(Number<1>{}) = pki4_to_bhalf4(packed >> 16);

        y = unpacked.template AsType<bhalf8_t>()[Number<0>{}];
#else
        // Disabled reference path: convert one byte (two values) at a time.
        vector_type<bhalf_t, 8> unpacked;
        vector_type<pk_i4_t, 4> bytes{x};

        unpacked.template AsType<bhalf2_t>()(Number<0>{}) =
            pki4_to_bhalf2(bytes.template AsType<pk_i4_t>()[Number<0>{}]);
        unpacked.template AsType<bhalf2_t>()(Number<1>{}) =
            pki4_to_bhalf2(bytes.template AsType<pk_i4_t>()[Number<1>{}]);
        unpacked.template AsType<bhalf2_t>()(Number<2>{}) =
            pki4_to_bhalf2(bytes.template AsType<pk_i4_t>()[Number<2>{}]);
        unpacked.template AsType<bhalf2_t>()(Number<3>{}) =
            pki4_to_bhalf2(bytes.template AsType<pk_i4_t>()[Number<3>{}]);

        y = unpacked.template AsType<bhalf8_t>()[Number<0>{}];
#endif
    }

    // Detection flag read by callers via `is_pack8_invocable`.
    constexpr const static bool is_pack8_invocable = true;
};
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wnon-virtual-dtor"
struct UnaryOpBase
@@ -49,6 +214,24 @@ struct PassThroughPack2
auto t = type_convert<float2_t>(x);
y = type_convert<half2_t>(t);
}
// Unpacks one pk_i4_t byte into two fp16 values.
// Active branch: y = { low nibble, high nibble }, each nibble zero-extended and converted
// as-is.
// NOTE(review): unlike pki4_to_half2, no 8 is subtracted here and the low nibble comes
// first -- confirm this difference is intended for the callers of this overload.
__host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::pk_i4_t& x) const
{
#if 1
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
uint8_t x_l = (x_u8 & 0x0f) >> 0;
uint8_t x_h = (x_u8 & 0xf0) >> 4;
auto l_f16 = ck::type_convert<ck::half_t>(x_l);
auto h_f16 = ck::type_convert<ck::half_t>(x_h);
y = {l_f16, h_f16};
#else
// Disabled alternative: reinterpret the zero-extended byte as the raw bits of a half2.
uint32_t t = ck::bit_cast<uint8_t>(x);
y = ck::bit_cast<half2_t>(t);
#endif
}
constexpr const static bool is_pack2_invocable = true;
};
@@ -76,6 +259,12 @@ struct PassThrough final : public UnaryOpBase
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
// pk_i4_t -> pk_i4_t: copies the packed byte unchanged (no unpacking or conversion).
template <>
__host__ __device__ void operator()<pk_i4_t, pk_i4_t>(pk_i4_t& y, const pk_i4_t& x) const
{
y = x;
}
template <>
__host__ __device__ void operator()<float, double>(float& y, const double& x) const
{

View File

@@ -127,7 +127,9 @@ template <typename ALayout,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
struct GridwiseGemm_xdl_cshuffle_v3
{
static constexpr auto I0 = Number<0>{};
@@ -151,6 +153,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// Number of logical A elements held by one ADataType object: 2 for pk_i4_t, otherwise 1.
static constexpr index_t APackedSize = is_same_v<remove_cvref_t<ADataType>, pk_i4_t> ? 2 : 1;
// Number of logical B elements held by one BDataType object: 2 for pk_i4_t, otherwise 1.
static constexpr index_t BPackedSize = is_same_v<remove_cvref_t<BDataType>, pk_i4_t> ? 2 : 1;
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
{
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
@@ -319,6 +335,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
GemmSpec != GemmSpecialization::Default),
"pk_i4_t does not support padding");
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
@@ -373,15 +393,39 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
else
{
// not pad N or K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
if constexpr(!PermuteB)
{
// not pad N or K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
return b_grid_desc_bk0_n_bk1;
}
else
{
// Pre-shuffled Weight
// BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1]
constexpr index_t BK01 = KPerBlock / BK1Value;
const index_t BK0_ = StrideB / BK1Value;
const index_t BK00 = BK0_ / BK01;
const auto b_grid_desc_bk00_n_bk01_bk1_permute =
make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value));
const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor(
b_grid_desc_bk00_n_bk01_bk1_permute,
make_tuple(make_merge_transform(make_tuple(BK00, BK01)),
make_pass_through_transform(make_tuple(N)),
make_pass_through_transform(BK1Value)),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return b_grid_desc_bk0_n_bk1_permute;
}
}
}
@@ -572,7 +616,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead;
a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
@@ -585,7 +629,15 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
b_k_split_offset = blockIdx.z * karg.KRead;
if constexpr(!PermuteB)
{
b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
}
else
{
const int k0_offset = karg.KRead * karg.N;
b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
}
}
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
@@ -625,9 +677,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
// in some cases.
else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1
? 1
: 32 * 4 / KPerBlock / sizeof(ADataType);
constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize;
constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
make_tuple(
AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
@@ -761,10 +812,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
// NLdsLayer * K0 as logical Bank
constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1
? 1
: 32 * 4 / KPerBlock / sizeof(BDataType);
;
constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize;
constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
make_tuple(
BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
@@ -946,8 +995,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned * sizeof(ADataType) +
b_block_space_size_aligned * sizeof(BDataType)),
return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize +
b_block_space_size_aligned * sizeof(BDataType) / BPackedSize),
c_block_size * sizeof(CShuffleDataType));
}
@@ -1312,8 +1361,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<BDataType*>(p_shared) +
a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType),
reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
sizeof(ADataType) /
APackedSize),
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
@@ -1706,16 +1756,16 @@ struct GridwiseGemm_xdl_cshuffle_v3
static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<BDataType*>(p_shared_0) +
a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType),
bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
a_block_space_size_aligned * sizeof(ADataType)),
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<BDataType*>(p_shared_1) +
a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType),
bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
a_block_space_size_aligned * sizeof(ADataType)),
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);

View File

@@ -1007,6 +1007,13 @@ struct ThreadwiseTensorSliceTransfer_v4
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
// Number of logical elements held by one SrcData object: 2 for pk_i4_t, otherwise 1.
static constexpr index_t PackedSize = is_same_v<remove_cvref_t<SrcData>, pk_i4_t> ? 2 : 1;
__device__ constexpr ThreadwiseTensorSliceTransfer_v4(const Index& src_ref_idx)
: src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
{
@@ -1015,6 +1022,11 @@ struct ThreadwiseTensorSliceTransfer_v4
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
"wrong! Not divisible");
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
{
static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
}
}
template <typename SrcRefToOriginDisplacement,
@@ -1109,7 +1121,7 @@ struct ThreadwiseTensorSliceTransfer_v4
move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;
vector_type_maker_t<SrcData, SrcScalarPerVector / PackedSize> src_tmp_vector;
using src_vector_t = typename decltype(src_tmp_vector)::type;
@@ -1120,7 +1132,8 @@ struct ThreadwiseTensorSliceTransfer_v4
if constexpr(SrcBuffer::IsDynamicBuffer())
{
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
src_buf.template Get<src_vector_t>(src_data_coord.GetOffset() / PackedSize,
is_src_valid);
}
else if constexpr(SrcBuffer::IsStaticBuffer())
{
@@ -1133,9 +1146,36 @@ struct ThreadwiseTensorSliceTransfer_v4
});
}
if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
is_same<remove_cvref_t<DstData>, half_t>::value &&
SrcScalarPerVector % 2 == 0)
if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value)
{
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
constexpr index_t pack_size = 8;
static_assert(SrcScalarPerVector % pack_size == 0, "");
using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
ck::tensor_operation::element_wise::PassThroughPack8{}(
dst_tmp_vector.template AsType<dst_v_t>()(i),
src_tmp_vector.template AsType<src_v_t>()[i]);
});
// copy data from dst_tmp_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
});
}
else if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
is_same<remove_cvref_t<DstData>, half_t>::value &&
SrcScalarPerVector % 2 == 0)
{
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)

View File

@@ -31,8 +31,8 @@ template <typename SliceLengths,
typename DstDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t DstScalarPerVector,
index_t SrcScalarPerVector_,
index_t DstScalarPerVector_,
index_t SrcScalarStrideInVector,
index_t DstScalarStrideInVector,
bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
@@ -55,6 +55,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
static constexpr auto I0 = Number<0>{};
// Number of logical elements held by one SrcData object: 2 for pk_i4_t, otherwise 1.
static constexpr index_t PackedSize = is_same_v<remove_cvref_t<SrcData>, pk_i4_t> ? 2 : 1;

// Vector widths counted in stored objects rather than logical elements: the
// user-facing *_ template arguments divided by PackedSize.
static constexpr auto SrcScalarPerVector = Number<SrcScalarPerVector_ / PackedSize>{};
static constexpr auto DstScalarPerVector = Number<DstScalarPerVector_ / PackedSize>{};
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(
const SrcDesc& src_desc,
const Index& src_slice_origin,
@@ -67,6 +77,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1
src_element_op_(src_element_op),
dst_element_op_(dst_element_op)
{
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
{
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
"SrcData != DstData");
static_assert(
SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0,
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose");
}
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
@@ -95,11 +116,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0,
static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0,
"SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
@@ -180,9 +201,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
auto src_vector_container =
src_vector_type{src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), true)};
using dst_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
using dst_vector_t = typename dst_vector_type::type;
dst_vector_type op_r_v;
@@ -193,17 +211,22 @@ struct ThreadwiseTensorSliceTransfer_v3r1
if constexpr(decltype(src_element_op_)::is_pack8_invocable)
return math::min(8, SrcScalarPerVector);
}
if constexpr(is_detected<is_pack4_invocable_t, decltype(src_element_op_)>::value)
else if constexpr(is_detected<is_pack4_invocable_t,
decltype(src_element_op_)>::value)
{
if constexpr(decltype(src_element_op_)::is_pack4_invocable)
return math::min(4, SrcScalarPerVector);
}
if constexpr(is_detected<is_pack2_invocable_t, decltype(src_element_op_)>::value)
else if constexpr(is_detected<is_pack2_invocable_t,
decltype(src_element_op_)>::value)
{
if constexpr(decltype(src_element_op_)::is_pack2_invocable)
return math::min(2, SrcScalarPerVector);
}
return 1;
else
{
return 1;
}
};
constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
@@ -211,6 +234,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
using src_elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
using dst_elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
auto src_vector_container = src_vector_type{
src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize, true)};
static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) {
// apply the src elementwise op and convert to DstData under the hood if needed
src_element_op_(op_r_v.template AsType<dst_elem_op_vec_t>()(idx),
@@ -276,10 +302,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
});
#else
// OOB Check
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -350,6 +375,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
(is_same<f8_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
{
static_assert(!is_same_v<remove_cvref_t<SrcData>, pk_i4_t>,
"in-register transpose is not supported for pk_i4_t");
// each transpose does
// DstScalarPerVector # of src vectors in src_thread_scratch_
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_
@@ -410,7 +437,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
}
else
{
static_ford<SliceLengths>{}([&](auto idx) {
constexpr auto packed_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, PackedSize>{}, Number<nDim>{});
constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access;
static_ford<decltype(packed_access_lengths)>{}([&](auto idx) {
dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
});
}
@@ -438,7 +470,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
// src scalar per access on each dim
// TODO: don't use this
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
@@ -526,13 +558,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
// apply DstElementwiseOperation
dst_element_op_(dst_v, dst_vector_container.template AsType<DstData>()[i]);
dst_vector_container.template AsType<DstData>()(i) = dst_v;
});
// copy data from dst_vector_container to dst_buf
dst_buf.template Set<dst_vector_t>(
dst_coord_.GetOffset(),
dst_coord_.GetOffset() / PackedSize,
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
@@ -586,7 +616,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -644,7 +674,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
@@ -730,7 +760,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
{
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -779,7 +809,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
__device__ static constexpr auto GetSrcOOBThreadScratchDescriptor()
{
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -790,7 +820,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
{
// 1st stage of transforms
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;

View File

@@ -429,7 +429,8 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
(is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
(is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, pk_i4_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");
using r_t = typename vector_type<T, N>::type;

View File

@@ -11,6 +11,27 @@
namespace ck {
// Wrapper around the v_and_or_b32 instruction: returns (a & b) | d.
// All operands and the result are constrained to vector registers ("v").
inline __device__ int amd_assembly_and_or_b32(int a, int b, int d)
{
int c;
asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(c) : "v"(a), "v"(b), "v"(d));
return c;
}
// Wrapper around the v_pk_fma_f16 instruction: packed fp16 fused multiply-add,
// returning a * b + c for each of the two lanes.
inline __device__ half2_t amd_assembly_pk_fma_f16(half2_t a, half2_t b, half2_t c)
{
half2_t d;
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
return d;
}
// Wrapper around the v_pk_add_f16 instruction: packed fp16 add, returning a + b for
// each of the two lanes.
inline __device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b)
{
half2_t c;
asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
return c;
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)

View File

@@ -12,6 +12,15 @@ using bhalf_t = ushort;
using half_t = _Float16;
using int4_t = _BitInt(4);
// custom data type - pack int4 data
// Storage type for packed int4 data: a single int8_t byte.
// The constructor from `type` is intentionally non-explicit so integer expressions can
// initialise a pk_i4_t directly.
struct pk_i4_t
{
// Underlying storage type.
using type = int8_t;
// Raw packed byte.
type data;
// Zero-initialises the byte.
__host__ __device__ constexpr pk_i4_t() : data{type{}} {}
// Wraps an already-packed byte.
__host__ __device__ constexpr pk_i4_t(type init) : data{init} {}
};
inline constexpr auto next_pow2(uint32_t x)
{
// Precondition: x > 1.
@@ -165,6 +174,13 @@ struct scalar_type<int4_t>
};
#endif
// pk_i4_t is treated as a scalar: its scalar type is itself and its vector size is 1
// (counted in pk_i4_t objects, not in int4 values).
template <>
struct scalar_type<pk_i4_t>
{
using type = pk_i4_t;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<f8_fnuz_t>
{
@@ -1044,6 +1060,12 @@ struct nnvb_data_t_selector<bf8_ocp_t>
using type = bf8_ocp_t::data_type;
};
// Selects pk_i4_t's underlying storage type (pk_i4_t::type) as the data type used by
// non_native_vector_base.
template <>
struct nnvb_data_t_selector<pk_i4_t>
{
using type = pk_i4_t::type;
};
template <typename T, index_t N>
struct non_native_vector_base<
T,
@@ -1163,6 +1185,14 @@ struct scalar_type<non_native_vector_base<bf8_ocp_t, N>>
static constexpr index_t vector_size = N;
};
// scalar_type for an N-wide non-native vector of pk_i4_t: the scalar is the vector
// base's data_t and the vector size is N (counted in pk_i4_t objects).
template <index_t N>
struct scalar_type<non_native_vector_base<pk_i4_t, N>>
{
using type = typename non_native_vector_base<pk_i4_t, N>::data_t;
static constexpr index_t vector_size = N;
};
// non-native vector_type implementation
template <typename T>
struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
@@ -1871,6 +1901,11 @@ using uint8x16_t = typename vector_type<uint8_t, 16>::type;
using uint8x32_t = typename vector_type<uint8_t, 32>::type;
using uint8x64_t = typename vector_type<uint8_t, 64>::type;
// Vectors of packed int4 bytes. The suffix counts pk_i4_t objects, not int4 values.
using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;
using pk_i4x8_t = typename vector_type<pk_i4_t, 8>::type;
template <typename T>
struct NumericLimits
{

View File

@@ -54,7 +54,8 @@ struct DynamicBuffer
template <typename X,
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
typename scalar_type<remove_cvref_t<T>>::type>::value ||
!is_native_type<X>(),
bool>::type = false>
__host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
{
@@ -195,7 +196,8 @@ struct DynamicBuffer
template <typename X,
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
typename scalar_type<remove_cvref_t<T>>::type>::value ||
!is_native_type<X>(),
bool>::type = false>
__host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
{

View File

@@ -116,7 +116,8 @@ struct StaticBufferTupleOfVector
// i is offset of S, not X. i should be aligned to X
template <typename X,
index_t I,
typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
typename enable_if<has_same_scalar_type<S, X>::value || !is_native_type<S>(),
bool>::type = false>
__host__ __device__ constexpr auto GetAsType(Number<I> i) const
{
constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
@@ -134,7 +135,8 @@ struct StaticBufferTupleOfVector
// i is offset of S, not X. i should be aligned to X
template <typename X,
index_t I,
typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
typename enable_if<has_same_scalar_type<S, X>::value || !is_native_type<S>(),
bool>::type = false>
__host__ __device__ constexpr void SetAsType(Number<I> i, X x)
{
constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};

View File

@@ -74,6 +74,17 @@ struct ReferenceGemm : public device::BaseOperator
{
ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
}
else if constexpr(is_same_v<ADataType, pk_i4_t>)
{
uint8_t i4x2 = arg.a_m_k_(m, k).data;
int8_t i4 = 0;
if(k % 2 == 1)
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
i4 = i4 - 8;
v_a = type_convert<ComputeTypeA>(i4);
}
else
{
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
@@ -84,6 +95,17 @@ struct ReferenceGemm : public device::BaseOperator
{
ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n));
}
else if constexpr(is_same_v<BDataType, pk_i4_t>)
{
uint8_t i4x2 = arg.b_k_n_(k, n).data;
int8_t i4 = 0;
if(k % 2 == 1)
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
i4 = i4 - 8;
v_b = type_convert<ComputeTypeB>(i4);
}
else
{
arg.b_element_op_(v_b, arg.b_k_n_(k, n));

View File

@@ -22,6 +22,7 @@ using I8 = int8_t;
// Short data-type aliases used in the instance declarations of this header.
using I32 = int32_t;
using F8 = ck::f8_t;
using BF8 = ck::bf8_t;
// I4 names ck::pk_i4_t; it appears as the B data type of the f16/bf16 x i4 GEMM instances.
using I4 = ck::pk_i4_t;
// Empty tuple type (ck::Tuple with no elements).
using Empty_Tuple = ck::Tuple<>;

View File

@@ -166,11 +166,22 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
// Declaration only: takes the caller's instance vector by reference for the
// F16 (A) x F8 (B) -> F16 (C), Row/Col/Row "mem_v2_default" instances.
// NOTE(review): presumably appends to `instances` like its siblings; the definition is not
// visible here.
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
// Declaration only: F16 (A) x I4 (B, ck::pk_i4_t) -> F16 (C), Row/Col/Row
// "mem_v2_default" instances. Called by the instance factory when the requested
// A/B/C types are half_t / pk_i4_t / half_t with Row/Col/Row layouts.
void add_device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, I4, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
// Declaration only: BF16 (A) x I4 (B, ck::pk_i4_t) -> BF16 (C), Row/Col/Row
// "mem_v2_default" instances. Called by the instance factory when the requested
// A/B/C types are bhalf_t / pk_i4_t / bhalf_t with Row/Col/Row layouts.
void add_device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, I4, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
@@ -810,6 +821,28 @@ struct DeviceOperationInstanceFactory<
}
}
#endif
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, pk_i4_t> &&
is_same_v<CDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
}
}
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, pk_i4_t> &&
is_same_v<CDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instances(
op_ptrs);
}
}
return op_ptrs;
}
};

View File

@@ -97,6 +97,9 @@ list(APPEND GEMM_UNIVERSAL_INSTANCES
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_f16_i4_f16/device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_bf16_i4_bf16/device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instance.cpp
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp

View File

@@ -0,0 +1,87 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Data-type shorthands used in the instance tables of this header.
using I4 = pk_i4_t;
using BF16 = bhalf_t;
using F32 = float;
// Tensor layout tags.
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
// S<...> abbreviates a compile-time Sequence of indices (used for cluster lengths / orders).
template <index_t... Is>
using S = Sequence<Is...>;
// Element-wise operation applied to A, B and C in every instance of this header.
using PassThrough = element_wise::PassThrough;
// GEMM specialization shorthands.
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
// Block-GEMM pipeline scheduler shorthands.
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
#if 0
template <GemmSpecialization GemmSpec>
using device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_comp_instances = std::tuple<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Version|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Compute friendly
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
// clang-format on
>;
#endif
// "mem" instance list for the BF16 (A) x I4 (B, pk_i4_t) -> BF16 (C) universal GEMM with
// row-major A, column-major B and row-major C.
//
// Template parameters:
//   BlkGemmPipeSched - block-GEMM pipeline scheduler forwarded to every entry.
//   GemmSpec         - GEMM specialization forwarded to every entry.
//
// Every entry accumulates in F32, uses PassThrough for all three element-wise operations,
// and passes bhalf_t / bhalf_t / false / true for the trailing
// ACompType / BCompType / APermute / BPermute arguments (see the column header below).
// NOTE(review): "mem" presumably means memory-bound-friendly tile shapes, by contrast with
// the disabled "Compute friendly" comp list in this file - confirm.
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_instances =
std::tuple<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| ACompType| BCompType| APermute| BPermute|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| | | | |
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Version| | | | |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Entries using BlockGemmPipelineVersion::v1.
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, bhalf_t, bhalf_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, bhalf_t, bhalf_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, bhalf_t, bhalf_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, bhalf_t, bhalf_t, false, true>,
// Entries using BlockGemmPipelineVersion::v2.
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 128, 8, 32, 32, 32, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 128, 8, 16, 16, 16, 4, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 128, 8, 16, 16, 16, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>,
// NOTE(review): the next two entries are token-for-token identical, so the same instance
// type appears twice in this tuple - confirm whether one of them was meant to differ.
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 128, 8, 32, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 128, 8, 32, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 128, 8, 32, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,24 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Appends the BF16 x I4 -> BF16 (Row/Col/Row) "mem" instance list, instantiated with the
// Interwave scheduler and the Default GEMM specialization, to `instances`.
void add_device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, I4, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
// Name the concrete instance tuple once so the registration call reads on a single line.
using InterwaveDefaultMemInstances =
device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_instances<Interwave, GemmDefault>;
add_device_operation_instances(instances, InterwaveDefaultMemInstances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,86 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Data-type shorthands used in the instance tables of this header.
using I4 = pk_i4_t;
using F16 = half_t;
using F32 = float;
// Tensor layout tags.
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
// S<...> abbreviates a compile-time Sequence of indices (used for cluster lengths / orders).
template <index_t... Is>
using S = Sequence<Is...>;
// Element-wise operation applied to A, B and C in every instance of this header.
using PassThrough = element_wise::PassThrough;
// GEMM specialization shorthands.
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
// Block-GEMM pipeline scheduler shorthands.
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
#if 0
template <GemmSpecialization GemmSpec>
using device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_comp_instances = std::tuple<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Version|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Compute friendly
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
// clang-format on
>;
#endif
// "mem" instance list for the F16 (A) x I4 (B, pk_i4_t) -> F16 (C) universal GEMM with
// row-major A, column-major B and row-major C.
//
// Template parameters:
//   BlkGemmPipeSched - block-GEMM pipeline scheduler forwarded to every entry.
//   GemmSpec         - GEMM specialization forwarded to every entry.
//
// Every entry accumulates in F32, uses PassThrough for all three element-wise operations,
// and passes half_t / half_t / false / true for the trailing
// ACompType / BCompType / APermute / BPermute arguments (see the column header below).
// NOTE(review): "mem" presumably means memory-bound-friendly tile shapes, by contrast with
// the disabled "Compute friendly" comp list in this file - confirm.
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_instances = std::tuple<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| ACompType| BCompType| APermute| BPermute|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| | | | |
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Version| | | | |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Entries using BlockGemmPipelineVersion::v1.
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, half_t, half_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, half_t, half_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, half_t, half_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, half_t, half_t, false, true>,
// Entries using BlockGemmPipelineVersion::v2.
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 128, 8, 32, 32, 32, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 128, 8, 16, 16, 16, 4, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 128, 8, 16, 16, 16, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
// NOTE(review): the next two entries are token-for-token identical, so the same instance
// type appears twice in this tuple - confirm whether one of them was meant to differ.
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 128, 8, 32, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 128, 8, 32, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 128, 8, 32, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,24 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Appends the F16 x I4 -> F16 (Row/Col/Row) "mem" instance list, instantiated with the
// Interwave scheduler and the Default GEMM specialization, to `instances`.
void add_device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, I4, F16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
// Name the concrete instance tuple once so the registration call reads on a single line.
using InterwaveDefaultMemInstances =
device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_instances<Interwave, GemmDefault>;
add_device_operation_instances(instances, InterwaveDefaultMemInstances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -65,11 +65,13 @@ bool profile_gemm_universal_impl(int do_verification,
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<BDataType> b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
int total_gemm_needed = a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes();
int rotating_count = std::max(
std::size_t total_gemm_needed =
a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes();
int rotating_count = std::max(
1,
std::min(n_iter,
static_cast<int>(std::ceil(static_cast<double>(rotating) / total_gemm_needed))));
@@ -86,9 +88,13 @@ bool profile_gemm_universal_impl(int do_verification,
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-1, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 2});
break;
default:
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
}
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
@@ -100,11 +106,10 @@ bool profile_gemm_universal_impl(int do_verification,
const auto c_element_op = CElementOp{};
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
using DeviceOp = ck::tensor_operation::device::DeviceGemmV2<ALayout,
BLayout,
@@ -153,6 +158,84 @@ bool profile_gemm_universal_impl(int do_verification,
// profile device GEMM instances
for(auto& op_ptr : op_ptrs)
{
// Prepare the B operand for this kernel instance on the host, then upload it.
const int KPerBlock = op_ptr->GetKPerBlock();

if(!op_ptr->GetPermuteB())
{
    // The instance does not request a permuted B: hand it the tensor unchanged.
    b_k_n_permute = b_k_n;
}
else
{
    // Split K into K0 blocks of K1 = KPerBlock values.
    // NOTE(review): K0 is an integer quotient, so when K is not a multiple of KPerBlock
    // the last K % KPerBlock values for each n are not written by the loop below.
    const int K1 = KPerBlock;
    const int K0 = K / KPerBlock;

    // Destination order is (K0, N, K1); the source is read as runs of K values per n.
    // Both sides use single-argument element access - presumably a linear offset;
    // confirm against Tensor::operator().
    for(int k_blk = 0; k_blk < K0; k_blk++)
    {
        for(int col = 0; col < N; col++)
        {
            for(int k_in = 0; k_in < K1; k_in++)
            {
                b_k_n_permute(k_blk * N * K1 + col * K1 + k_in) =
                    b_k_n(col * K + (k_blk * K1 + k_in));
            }
        }
    }

    if(is_same_v<BDataType, pk_i4_t> && is_same_v<ADataType, half_t>)
    {
        // vector pk_i4x4 permute
        //
        // Each group covers the four elements at k, k+2, k+4, k+6 for one n. Their eight
        // 4-bit fields are numbered 0..7 (high field of an element first) and written
        // back in the order 01234567 -> 20643175.
        // kNibbleSrc[slot] = {field used as high half, field used as low half} of
        // output element `slot`.
        constexpr int kNibbleSrc[4][2] = {{2, 0}, {6, 4}, {3, 1}, {7, 5}};

        for(int col = 0; col < N; col++)
        {
            for(int k_base = 0; k_base < K; k_base += 8)
            {
                // Unpack the four elements into eight 4-bit fields.
                int nib[8];
                for(int slot = 0; slot < 4; slot++)
                {
                    int packed        = b_k_n_permute(k_base + slot * 2, col);
                    nib[slot * 2 + 0] = (packed >> 4) & 0xf;
                    nib[slot * 2 + 1] = (packed >> 0) & 0xf;
                }

                // Repack in the shuffled order and store in place.
                for(int slot = 0; slot < 4; slot++)
                {
                    int packed = (nib[kNibbleSrc[slot][0]] << 4) | nib[kNibbleSrc[slot][1]];
                    b_k_n_permute(k_base + slot * 2, col) = packed;
                }
            }
        }
    }
}

// Upload whichever form of B was prepared above.
b_device_buf.ToDevice(b_k_n_permute.mData.data());
std::vector<int> kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38};
if(KBatch > 0)
@@ -240,7 +323,15 @@ bool profile_gemm_universal_impl(int do_verification,
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
// Divisor applied to B's byte count below: 2 when BDataType is pk_i4_t, otherwise 1
// (presumably because a pk_i4_t element carries two 4-bit values - confirm).
static constexpr index_t BPackedSize = is_same_v<remove_cvref_t<BDataType>, pk_i4_t> ? 2 : 1;

// Bytes accounted for each operand; B's share is scaled down by BPackedSize.
const std::size_t a_bytes_moved = sizeof(ADataType) * M * K;
const std::size_t b_bytes_moved = sizeof(BDataType) * K * N / BPackedSize;
const std::size_t c_bytes_moved = sizeof(CDataType) * M * N;
std::size_t num_btype           = a_bytes_moved + b_bytes_moved + c_bytes_moved;

// flop / 1e9 / ave_time; this is TFLOP/s only if ave_time is in milliseconds
// (assumption - confirm the unit of ave_time).
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

View File

@@ -177,5 +177,4 @@ if(DL_KERNELS)
# Link the grouped conv2d/conv3d backward-weight instance libraries into the profiler
# executable (PRIVATE: not propagated to dependents).
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance)
endif()
# Register the profiler executable for installation under the `profiler` component
# (rocm_install presumably comes from rocm-cmake - verify).
rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler)

View File

@@ -1,10 +1,10 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_gemm_universal_impl.hpp"
#include "profiler_operation_registry.hpp"
@@ -27,6 +27,8 @@ enum struct GemmDataType
F16_F8_F16, // 5
F16_F16_F16_F8, // 6
F8_F8_BF16, // 7
// Codes 8 and 9: the dispatch code passes I4 (ck::pk_i4_t) as the second type tag to
// profile() for these; the usage text labels them "f16@i4" and "bf16@i4".
F16_I4_F16, // 8
BF16_I4_BF16, // 9
};
#define OP_NAME "gemm_universal"
@@ -39,7 +41,7 @@ int profile_gemm_universal(int argc, char* argv[])
// Usage text. arg1 names the operation; arg2 lists the numeric data-type codes (8 and 9 are
// the fp16/bf16 x int4 combinations); arg3 lists the supported matrix layouts.
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
// NOTE(review): two endings of the arg2 string are visible below ("comp f8)" and the longer
// one listing codes 8 and 9). They look like the removed and added side of a diff - confirm
// only the longer one remains. The longer one was missing the ")" that closes the
// "(0: fp32; ..." list; it is added here.
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
"f16->f8; 7: f8->bf16, "
"comp f8)\n");
"comp f8; 8: f16@i4; 9: bf16@i4)\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
@@ -103,6 +105,7 @@ int profile_gemm_universal(int argc, char* argv[])
// Short tag types used to select the template instantiation to profile.
using BF16 = ck::bhalf_t;
// F8 and I4 are declared only when one of these two macros is defined, so every use of
// either alias must sit under the same condition.
// NOTE(review): I4 (ck::pk_i4_t) is tied to the FP8 guard here - confirm that is intended
// rather than incidental to where the line was added.
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
using F8 = ck::f8_t;
using I4 = ck::pk_i4_t;
#endif
using Row = ck::tensor_layout::gemm::RowMajor;
@@ -207,6 +210,14 @@ int profile_gemm_universal(int argc, char* argv[])
{
return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Col{}, Row{});
}
// F16_I4_F16: dispatched only for the MK_NK_MN layout. The type tags passed to profile()
// are F16, I4, F16, F32, F16 with layout tags Row, Col, Row (the meaning of each position is
// defined by profile(), which is outside this hunk).
else if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(F16{}, I4{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
}
// BF16_I4_BF16: same layout restriction, with BF16 in place of F16 in the type tags.
else if(data_type == GemmDataType::BF16_I4_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(BF16{}, I4{}, BF16{}, F32{}, BF16{}, Row{}, Col{}, Row{});
}
#endif
else
{

View File

@@ -17,7 +17,7 @@ fi
cmake \
-D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
-D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
-D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \
-D GPU_TARGETS=$GPU_TARGETS \